"""Enhanced prompt configuration with full feature parity to AugLLMConfig.
This module provides complete prompt management including few-shot learning,
messages placeholders, format instructions, and dynamic template creation.
"""
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from langchain_core.prompts import (
BasePromptTemplate,
ChatPromptTemplate,
PromptTemplate,
FewShotPromptTemplate,
FewShotChatMessagePromptTemplate,
MessagesPlaceholder
)
from langchain_core.output_parsers import PydanticOutputParser
from haive.core.contracts.prompt_config import PromptContract, PromptVariable
[docs]
class FewShotConfig(BaseModel):
"""Configuration for few-shot prompting.
Attributes:
examples: List of example input-output pairs.
example_prompt: Template for formatting examples.
prefix: Text before examples.
suffix: Text after examples.
example_separator: Separator between examples.
example_selector: Optional selector for dynamic examples.
"""
examples: List[Dict[str, str]] = Field(
default_factory=list,
description="Few-shot examples"
)
example_prompt: Optional[Union[PromptTemplate, ChatPromptTemplate]] = Field(
default=None,
description="Template for examples"
)
prefix: Optional[str] = Field(
default=None,
description="Text before examples"
)
suffix: Optional[str] = Field(
default=None,
description="Text after examples"
)
example_separator: str = Field(
default="\n\n",
description="Separator between examples"
)
example_selector: Optional[Any] = Field(
default=None,
description="Dynamic example selector"
)
[docs]
class Config:
"""Pydantic configuration."""
arbitrary_types_allowed = True
[docs]
class MessagesConfig(BaseModel):
"""Configuration for messages placeholder handling.
Attributes:
add_messages_placeholder: Whether to add placeholder.
messages_placeholder_name: Variable name for messages.
force_messages_optional: Force optional messages.
uses_messages_field: Whether template uses messages.
messages_position: Where to insert placeholder.
"""
add_messages_placeholder: bool = Field(
default=True,
description="Add messages placeholder"
)
messages_placeholder_name: str = Field(
default="messages",
description="Variable name for messages"
)
force_messages_optional: bool = Field(
default=False,
description="Force optional messages"
)
uses_messages_field: bool = Field(
default=False,
description="Template uses messages field"
)
messages_position: Literal["start", "before_human", "end"] = Field(
default="before_human",
description="Where to insert placeholder"
)
[docs]
class TemplateManager(BaseModel):
"""Manages multiple prompt templates.
Attributes:
stored_templates: Named templates.
active_template: Currently active template.
template_history: History of used templates.
fallback_template: Fallback if main fails.
"""
stored_templates: Dict[str, BasePromptTemplate] = Field(
default_factory=dict,
description="Named templates"
)
active_template: Optional[str] = Field(
default=None,
description="Active template name"
)
template_history: List[str] = Field(
default_factory=list,
description="Template usage history"
)
fallback_template: Optional[BasePromptTemplate] = Field(
default=None,
description="Fallback template"
)
[docs]
class Config:
"""Pydantic configuration."""
arbitrary_types_allowed = True
[docs]
class EnhancedPromptConfig(BaseModel):
"""Complete prompt configuration with full AugLLMConfig feature parity.
This provides all prompt management features from AugLLMConfig:
- Few-shot prompting with examples
- Messages placeholder management
- Dynamic template creation
- Format instructions integration
- Template storage and switching
- Input variable management
Attributes:
prompt_template: Main prompt template.
system_message: System message for chat models.
few_shot: Few-shot configuration.
messages: Messages configuration.
format_instructions: Format instructions config.
template_manager: Template management.
input_variables: Required input variables.
optional_variables: Optional input variables.
partial_variables: Partial variables.
contracts: Prompt contracts.
"""
prompt_template: Optional[BasePromptTemplate] = Field(
default=None,
description="Main prompt template"
)
system_message: Optional[str] = Field(
default=None,
description="System message"
)
few_shot: FewShotConfig = Field(
default_factory=FewShotConfig,
description="Few-shot config"
)
messages: MessagesConfig = Field(
default_factory=MessagesConfig,
description="Messages config"
)
format_instructions: FormatInstructionsConfig = Field(
default_factory=FormatInstructionsConfig,
description="Format instructions"
)
template_manager: TemplateManager = Field(
default_factory=TemplateManager,
description="Template manager"
)
input_variables: List[str] = Field(
default_factory=list,
description="Required variables"
)
optional_variables: List[str] = Field(
default_factory=list,
description="Optional variables"
)
partial_variables: Dict[str, Any] = Field(
default_factory=dict,
description="Partial variables"
)
contracts: Dict[str, PromptContract] = Field(
default_factory=dict,
description="Prompt contracts"
)
[docs]
class Config:
"""Pydantic configuration."""
arbitrary_types_allowed = True
[docs]
def create_template(self) -> Optional[BasePromptTemplate]:
"""Create appropriate template based on configuration.
Returns:
Created template or None.
"""
# Return existing template if available
if self.prompt_template:
return self.prompt_template
# Try few-shot template creation
if self.few_shot.examples and self.few_shot.example_prompt:
return self._create_few_shot_template()
# Try chat template from system message
if self.system_message:
return self._create_chat_from_system()
# Create default with messages placeholder
if self.messages.add_messages_placeholder:
return self._create_default_with_placeholder()
# Fallback to simple template
return PromptTemplate.from_template("{input}")
def _create_few_shot_template(self) -> BasePromptTemplate:
"""Create few-shot template.
Returns:
FewShotPromptTemplate or FewShotChatMessagePromptTemplate.
"""
# Check for chat-based few-shot
if isinstance(self.few_shot.example_prompt, ChatPromptTemplate):
return FewShotChatMessagePromptTemplate(
examples=self.few_shot.examples,
example_prompt=self.few_shot.example_prompt,
input_variables=self.input_variables or ["input"]
)
# Standard few-shot with prefix/suffix
if self.few_shot.prefix and self.few_shot.suffix:
return FewShotPromptTemplate(
examples=self.few_shot.examples,
example_prompt=self.few_shot.example_prompt,
prefix=self.few_shot.prefix,
suffix=self.few_shot.suffix,
example_separator=self.few_shot.example_separator,
input_variables=self.input_variables or ["input"],
example_selector=self.few_shot.example_selector
)
# Fallback to basic few-shot
return FewShotPromptTemplate(
examples=self.few_shot.examples,
example_prompt=self.few_shot.example_prompt,
input_variables=self.input_variables or ["input"]
)
def _create_chat_from_system(self) -> ChatPromptTemplate:
"""Create chat template from system message.
Returns:
ChatPromptTemplate with system message.
"""
messages = [("system", self.system_message)]
# Add messages placeholder if configured
if self.messages.add_messages_placeholder:
placeholder = MessagesPlaceholder(
variable_name=self.messages.messages_placeholder_name,
optional=self.messages.force_messages_optional or
self.messages.messages_placeholder_name in self.optional_variables
)
if self.messages.messages_position == "start":
messages.insert(0, placeholder)
elif self.messages.messages_position == "before_human":
messages.append(placeholder)
# Add human message
messages.append(("human", "{input}"))
# Add after if position is end
if (self.messages.add_messages_placeholder and
self.messages.messages_position == "end"):
messages.append(placeholder)
return ChatPromptTemplate.from_messages(messages)
def _create_default_with_placeholder(self) -> ChatPromptTemplate:
"""Create default template with messages placeholder.
Returns:
Default ChatPromptTemplate.
"""
messages = []
# Add placeholder
placeholder = MessagesPlaceholder(
variable_name=self.messages.messages_placeholder_name,
optional=True
)
messages.append(placeholder)
# Add human message
messages.append(("human", "{input}"))
return ChatPromptTemplate.from_messages(messages)
[docs]
def ensure_messages_placeholder(self) -> None:
"""Ensure messages placeholder exists in template."""
if not self.prompt_template:
return
if not isinstance(self.prompt_template, ChatPromptTemplate):
return
messages = list(self.prompt_template.messages)
has_placeholder = any(
isinstance(msg, MessagesPlaceholder) and
getattr(msg, "variable_name", "") == self.messages.messages_placeholder_name
for msg in messages
)
if not has_placeholder and self.messages.add_messages_placeholder:
# Add placeholder
placeholder = MessagesPlaceholder(
variable_name=self.messages.messages_placeholder_name,
optional=self.messages.force_messages_optional
)
# Insert at appropriate position
if self.messages.messages_position == "start":
messages.insert(0, placeholder)
elif self.messages.messages_position == "before_human":
# Find last human message
for i in range(len(messages) - 1, -1, -1):
if messages[i][0] == "human":
messages.insert(i, placeholder)
break
else:
messages.append(placeholder)
else:
messages.append(placeholder)
# Update template
self.prompt_template = ChatPromptTemplate.from_messages(messages)
self.messages.uses_messages_field = True
[docs]
def apply_partial_variables(self) -> None:
"""Apply partial variables to template."""
if self.prompt_template and self.partial_variables:
self.prompt_template = self.prompt_template.partial(**self.partial_variables)
[docs]
def store_template(self, name: str, template: BasePromptTemplate) -> None:
"""Store a template for later use.
Args:
name: Template name.
template: Template to store.
"""
self.template_manager.stored_templates[name] = template
self.template_manager.template_history.append(name)
[docs]
def use_template(self, name: str) -> bool:
"""Switch to a stored template.
Args:
name: Template name to use.
Returns:
True if switched successfully.
"""
if name in self.template_manager.stored_templates:
self.prompt_template = self.template_manager.stored_templates[name]
self.template_manager.active_template = name
self.template_manager.template_history.append(name)
return True
return False
[docs]
def remove_template(self, name: Optional[str] = None) -> bool:
"""Remove a stored template.
Args:
name: Template to remove (current if None).
Returns:
True if removed successfully.
"""
if name is None:
name = self.template_manager.active_template
if name and name in self.template_manager.stored_templates:
del self.template_manager.stored_templates[name]
if self.template_manager.active_template == name:
self.template_manager.active_template = None
return True
return False
[docs]
def list_templates(self) -> List[str]:
"""List all stored template names.
Returns:
List of template names.
"""
return list(self.template_manager.stored_templates.keys())
[docs]
def validate_configuration(self) -> Dict[str, str]:
"""Validate the configuration.
Returns:
Dictionary of validation errors.
"""
errors = {}
# Check few-shot configuration
if self.few_shot.examples and not self.few_shot.example_prompt:
errors["few_shot"] = "Examples provided but no example_prompt"
# Check format instructions
if (self.format_instructions.include_format_instructions and
self.format_instructions.format_instructions_key not in self.partial_variables and
not self.format_instructions.auto_generate):
errors["format_instructions"] = "Format instructions enabled but not provided"
# Check input variables
required_vars = self.compute_input_variables()
if not required_vars and not self.optional_variables:
errors["variables"] = "No input variables defined"
return errors
[docs]
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary.
Returns:
Configuration as dictionary.
"""
return {
"has_template": self.prompt_template is not None,
"system_message": self.system_message,
"few_shot": {
"enabled": bool(self.few_shot.examples),
"example_count": len(self.few_shot.examples)
},
"messages": self.messages.model_dump(),
"format_instructions": {
"enabled": self.format_instructions.include_format_instructions,
"has_text": self.format_instructions.format_instructions_text is not None
},
"templates": {
"stored_count": len(self.template_manager.stored_templates),
"active": self.template_manager.active_template
},
"variables": {
"input": self.input_variables,
"optional": self.optional_variables,
"partial": list(self.partial_variables.keys())
}
}