class ContrastiveGenerator(DirectGenerator):
    """Generate checklists by comparing candidate responses (RLCF candidate modes).

    Extends DirectGenerator with candidate auto-generation. Candidates are
    generated by smaller models and included in the prompt for contrastive
    analysis.

    Two modes:
    - rlcf_candidate: input + reference + candidates
    - rlcf_candidates_only: input + candidates (no reference)
    """

    def __init__(
        self,
        candidate_models: Optional[List[str]] = None,
        num_candidates: int = 4,
        generate_candidates: Optional[bool] = None,
        candidate_provider: Optional[str] = None,
        candidate_base_url: Optional[str] = None,
        candidate_api_key: Optional[str] = None,
        candidate_api_format: Optional[str] = None,
        **kwargs: Any,
    ):
        """Initialize the contrastive generator.

        Args:
            candidate_models: Models used to auto-generate candidates. With
                several models, one candidate is sampled per model; with a
                single model, ``num_candidates`` samples are drawn from it.
            num_candidates: Number of samples drawn when only one candidate
                model is configured.
            generate_candidates: Whether to auto-generate candidates when
                none are passed to :meth:`generate`. ``None`` defers to the
                pipeline preset for this method (defaulting to ``True``).
            candidate_provider: Optional provider override for the candidate
                generation client.
            candidate_base_url: Optional base URL override for the candidate
                generation client.
            candidate_api_key: Optional API key override for the candidate
                generation client.
            candidate_api_format: Optional API format override for the
                candidate generation client.
            **kwargs: Forwarded to ``DirectGenerator.__init__``.
        """
        super().__init__(**kwargs)
        # Local import: avoids a circular import between the generator
        # module and the pipeline presets at module load time.
        from .pipeline_presets import PIPELINE_PRESETS

        preset = PIPELINE_PRESETS.get(self._method_name, {})
        # An explicit constructor argument wins over the preset value.
        if generate_candidates is None:
            self.generate_candidates = preset.get("generate_candidates", True)
        else:
            self.generate_candidates = generate_candidates
        self.candidate_models = candidate_models
        self.num_candidates = num_candidates
        self._candidate_provider = candidate_provider
        self._candidate_base_url = candidate_base_url
        self._candidate_api_key = candidate_api_key
        self._candidate_api_format = candidate_api_format

    def generate(
        self,
        input: str,
        target: Optional[str] = None,
        reference: Optional[str] = None,
        candidates: Optional[Union[List[str], Dict[str, str]]] = None,
        **kwargs: Any,
    ) -> Checklist:
        """Generate checklist from input + candidates.

        Args:
            input: The instruction/query
            target: Alias for reference
            reference: Expert/reference target (optional for candidates_only)
            candidates: Candidate responses. Can be:
                - List[str]: multiple candidates (RLCF or listwise)
                - Dict with "chosen"/"rejected" keys (pairwise CRG)
                - None: auto-generated if candidate_models is set
            **kwargs: Additional arguments

        Raises:
            ValueError: If no candidates are given and auto-generation is
                disabled or unconfigured, or if the resolved candidate
                collection is empty.
        """
        # Honor the documented alias: 'target' stands in for 'reference'
        # when the latter is not given explicitly. (Previously 'target'
        # was accepted but silently ignored.)
        if reference is None:
            reference = target
        # Resolve candidates: auto-generate only when enabled AND models
        # are configured; otherwise the caller must supply them.
        if candidates is None:
            if self.generate_candidates and self.candidate_models:
                candidates = self._generate_candidates(input)
            else:
                raise ValueError(
                    f"{self.method_name} requires 'candidates' argument."
                )
        # Fail fast on an empty list/dict — it would otherwise produce a
        # degenerate prompt containing no candidate blocks at all.
        if not candidates:
            raise ValueError(
                f"{self.method_name} received an empty 'candidates' collection."
            )
        # Delegate prompt construction and the model call.
        checklist = self._generate_with_candidates(
            input=input,
            candidates=candidates,
            reference=reference,
            **kwargs,
        )
        # Record raw candidates and their count for downstream analysis.
        if isinstance(candidates, dict):
            # Pairwise mode: exactly 'chosen' + 'rejected' (validated in
            # _generate_with_candidates), hence a fixed count of 2.
            checklist.metadata["candidates"] = list(candidates.values())
            checklist.metadata["num_candidates"] = 2
        else:
            checklist.metadata["candidates"] = candidates
            checklist.metadata["num_candidates"] = len(candidates)
        return checklist

    def _generate_with_candidates(
        self,
        input: str,
        candidates: Union[List[str], Dict[str, str]],
        reference: Optional[str] = None,
        **kwargs: Any,
    ) -> Checklist:
        """Build prompt with candidates and call model.

        Routes candidates to template placeholders based on type and template:
        - Dict → {chosen} + {rejected} placeholders (pairwise CRG)
        - List + {responses} placeholder → numbered Response blocks (listwise)
        - List + {candidates} placeholder → numbered Candidate blocks (RLCF)

        Raises:
            ValueError: When the candidate type does not match the template's
                placeholders, when a dict lacks exactly the 'chosen' and
                'rejected' keys, or when the template requires a reference
                that was not supplied.
        """
        placeholders = self._template._placeholders
        format_kwargs: dict[str, str] = {"input": input}
        # --- Route candidates to placeholders ---
        if isinstance(candidates, dict):
            # Pairwise: dict must have chosen+rejected, template must have
            # those placeholders — reject every mismatched combination
            # explicitly so prompt bugs surface early.
            if "candidates" in placeholders:
                raise ValueError(
                    "Template has {candidates} placeholder but received dict candidates. "
                    "Use {chosen}/{rejected} placeholders for pairwise, or pass a list."
                )
            if not {"chosen", "rejected"} <= placeholders:
                raise ValueError(
                    "Template must have {chosen} and {rejected} placeholders for dict candidates."
                )
            if set(candidates.keys()) != {"chosen", "rejected"}:
                raise ValueError(
                    "Dict candidates must have exactly 'chosen' and 'rejected' keys, "
                    f"got: {set(candidates.keys())}"
                )
            format_kwargs["chosen"] = candidates["chosen"]
            format_kwargs["rejected"] = candidates["rejected"]
        else:
            # List candidates: pairwise-only templates cannot consume them.
            if "chosen" in placeholders or "rejected" in placeholders:
                raise ValueError(
                    "Template has {chosen}/{rejected} placeholders but received list candidates. "
                    "Pass a dict with 'chosen' and 'rejected' keys instead."
                )
            # {responses} takes priority over {candidates} when both exist.
            if "responses" in placeholders:
                format_kwargs["responses"] = self._format_ordered_responses(candidates)
            elif "candidates" in placeholders:
                format_kwargs["candidates"] = self._format_candidates(candidates)
            else:
                raise ValueError(
                    "Template must have {candidates} or {responses} placeholder for list candidates."
                )
        # --- Handle optional placeholders ---
        if "context" in placeholders:
            # Popped (not read) so it is not forwarded twice via **kwargs.
            format_kwargs["context"] = kwargs.pop("context", "")
        if "reference" in placeholders:
            # A template that declares {reference} makes the reference mandatory.
            if reference is None:
                raise ValueError(
                    f"{self.method_name} requires a reference target."
                )
            format_kwargs["reference"] = reference
        # Load format instructions (skip for custom schemas).
        format_text = load_format(self._format_name) if self._format_name else ""
        # Inject format inline if template has {format_instructions} placeholder,
        # otherwise append after the prompt (default).
        if "format_instructions" in placeholders:
            format_kwargs["format_instructions"] = format_text
            full_prompt = self._template.format(**format_kwargs)
        else:
            prompt = self._template.format(**format_kwargs)
            full_prompt = prompt + "\n\n" + format_text
        response_format = to_response_format(
            self._response_schema, self._method_name
        )
        raw = self._call_model(full_prompt, response_format=response_format)
        items = self._parse_structured(raw)
        return Checklist(
            items=items,
            source_method=self.method_name,
            generation_level=self.generation_level,
            input=input,
            metadata={"raw_response": raw},
        )

    def _get_candidate_client(self) -> Any:
        """Get client for candidate generation.

        If any candidate_* provider param is set, creates a separate client.
        Otherwise falls back to the main client via _get_or_create_client().
        """
        if any([
            self._candidate_provider,
            self._candidate_base_url,
            self._candidate_api_key,
            self._candidate_api_format,
        ]):
            return get_client(
                # Unset overrides inherit the main generator's provider.
                provider=self._candidate_provider or self._provider,
                base_url=self._candidate_base_url,
                api_key=self._candidate_api_key,
                model=self.model,
                api_format=self._candidate_api_format,
            )
        return self._get_or_create_client()

    def _generate_candidates(self, input: str) -> List[str]:
        """Generate candidate responses using smaller models.

        With multiple configured models, one candidate is sampled per model
        at temperature 0.7; with a single model, ``num_candidates`` samples
        are drawn at temperature 0.9 for extra diversity.

        Raises:
            ValueError: If no candidate models are configured.
        """
        # Explicit guard: a direct call with no models previously crashed
        # with TypeError on len(None) — raise a clear error instead.
        if not self.candidate_models:
            raise ValueError(
                "candidate_models must be configured to auto-generate candidates."
            )
        client = self._get_candidate_client()
        if len(self.candidate_models) > 1:
            # One sample per model; moderate temperature — inter-model
            # variation already provides diversity.
            return [
                self._sample_candidate(client, model, input, temperature=0.7)
                for model in self.candidate_models
            ]
        # Single model: rely on a higher temperature to diversify samples.
        model = self.candidate_models[0]
        return [
            self._sample_candidate(client, model, input, temperature=0.9)
            for _ in range(self.num_candidates)
        ]

    def _sample_candidate(
        self, client: Any, model: str, input: str, temperature: float
    ) -> str:
        """Sample one candidate response from ``model`` via ``client``."""
        resp = client.chat_completion(
            model=model,
            messages=[{"role": "user", "content": input}],
            temperature=temperature,
            max_tokens=1024,
        )
        # Standard chat-completion response shape: first choice's message text.
        return resp["choices"][0]["message"]["content"]

    def _format_ordered_responses(self, responses: List[str]) -> str:
        """Format responses as numbered Response blocks for listwise CRG."""
        return "\n\n".join(
            f"### Response {i}\n{response}"
            for i, response in enumerate(responses, 1)
        )

    def _format_candidates(self, candidates: List[str]) -> str:
        """Format candidate responses for prompt injection."""
        return "\n\n".join(
            f"### Candidate {i}\n{candidate}"
            for i, candidate in enumerate(candidates, 1)
        )