"""PromptTemplateMixin: Advanced prompt template integration for engine classes.
This module provides the PromptTemplateMixin class, which adds sophisticated
prompt template management capabilities to any engine class. The mixin enables
automatic input schema derivation, prompt template validation, and seamless
composition with existing engine functionality.
The PromptTemplateMixin is designed to integrate with Haive's engine architecture,
particularly AugLLMConfig, to provide dynamic schema generation based on prompt
template requirements while preserving existing engine behaviors.
Key Features:
- Automatic conversion of prompt templates to InvokableEngines
- Dynamic input schema derivation with intelligent composition
- Prompt template validation and preprocessing
- Schema composition with existing engine schemas
- Field-level validation integration via Pydantic validators
- Support for both override and composition approaches
Architecture:
The mixin uses method override patterns to integrate with engine classes:
- Overrides derive_input_schema() to incorporate prompt template variables
- Provides field validators for prompt template preprocessing
- Offers helper methods for prompt formatting and variable management
Integration Patterns:
1. Method Override: derive_input_schema() method is overridden to check for
prompt templates and compose schemas when present
2. Field Validation: @field_validator decorators preprocess prompt templates
3. Composition: Existing schemas are preserved and extended, not replaced
Examples:
Basic integration with an engine class:
.. code-block:: python
from haive.core.common.mixins.prompt_template_mixin import PromptTemplateMixin
from haive.core.engine.base import InvokableEngine
class MyEngine(PromptTemplateMixin, InvokableEngine):
prompt_template: Optional[BasePromptTemplate] = None
# The mixin automatically enhances input schema derivation
pass
# Usage
engine = MyEngine(prompt_template=my_template)
schema = engine.derive_input_schema() # Includes prompt variables
Advanced usage with schema composition:
.. code-block:: python
# Engine with existing input schema
class AdvancedEngine(PromptTemplateMixin, InvokableEngine):
def get_base_input_schema(self):
return MyExistingSchema
# The mixin will compose prompt variables with existing schema
engine = AdvancedEngine(prompt_template=chat_template)
combined_schema = engine.derive_input_schema()
Classes:
PromptTemplateMixin: Main mixin class for prompt template integration
Dependencies:
- langchain_core: For prompt template functionality and message types
- pydantic: For schema generation, validation, and field validation
- typing: For type hints and optional typing support
Author:
Haive Core Team
Version:
1.0.0
See Also:
- haive.core.engine.prompt_template.PromptTemplateEngine: Standalone engine
- haive.core.engine.aug_llm.config.AugLLMConfig: Primary integration target
- haive.core.schema.schema_composer.SchemaComposer: Schema composition utilities
"""
import contextlib
from typing import TYPE_CHECKING, Any, Optional
from langchain_core.messages import AnyMessage
from langchain_core.prompts import BasePromptTemplate
from pydantic import BaseModel, field_validator
if TYPE_CHECKING:
from haive.core.engine.prompt_template import PromptTemplateEngine
[docs]
class PromptTemplateMixin:
    """Advanced mixin for integrating prompt template functionality into engine classes.

    This mixin provides comprehensive prompt template management capabilities,
    enabling any engine class to automatically derive input schemas from prompt
    templates while preserving existing functionality through intelligent composition.

    The mixin is designed to be non-invasive and compatible with existing engine
    architectures. It overrides key methods like derive_input_schema() to enhance
    functionality rather than replace it, ensuring backward compatibility.

    Key Capabilities:
        - Automatic prompt template to engine conversion
        - Dynamic input schema derivation with type inference
        - Intelligent schema composition (prompt + existing schemas)
        - Prompt template validation and preprocessing
        - Field-level integration via Pydantic validators
        - Configurable behavior (enable/disable prompt schema usage)

    Required Fields:
        Classes using this mixin must define:
        - prompt_template: Optional[BasePromptTemplate] = None

    Optional Configuration:
        - _use_prompt_for_input_schema: bool = True (control schema derivation)
        - _prompt_engine: Internal cache for prompt template engine

    Method Override Pattern:
        The mixin overrides derive_input_schema() using a safe pattern:
        1. Attempts to call parent class implementation
        2. Checks if prompt template schema derivation is enabled
        3. Derives schema from prompt template if present
        4. Composes schemas intelligently (prompt + parent)
        5. Falls back gracefully on any errors

    Examples:
        Basic usage:

        .. code-block:: python

            class MyEngine(PromptTemplateMixin, InvokableEngine):
                prompt_template: Optional[BasePromptTemplate] = None

            engine = MyEngine(prompt_template=template)
            schema = engine.derive_input_schema()  # Enhanced with prompt fields

        With existing schema:

        .. code-block:: python

            class ComplexEngine(PromptTemplateMixin, SomeOtherMixin, InvokableEngine):
                # Existing schema logic preserved and enhanced
                pass

    Integration Notes:
        - Safe to use with multiple inheritance
        - Preserves existing derive_input_schema() behavior
        - Graceful error handling prevents disruption
        - Can be enabled/disabled at runtime

    See Also:
        - PromptTemplateEngine: Standalone engine implementation
        - AugLLMConfig: Primary usage example with full integration
    """

    # Fields that subclasses need to have.
    # The template whose input_variables drive schema derivation.
    prompt_template: BasePromptTemplate | None
    # Lazily-populated cache; see get_prompt_engine(). Invalidated by
    # update_prompt_partials().
    # NOTE(review): on a Pydantic v2 model, underscore-prefixed attributes are
    # treated as private attrs — confirm host classes declare these via
    # PrivateAttr if assignment validation matters.
    _prompt_engine: Optional["PromptTemplateEngine"] = None
    # Runtime toggle consulted during input-schema derivation; see
    # enable_prompt_schema_derivation().
    _use_prompt_for_input_schema: bool = True
@field_validator("prompt_template", mode="before")
@classmethod
def validate_prompt_template(cls, v: Any) -> Any:
    """Validate and normalize a prompt template before assignment.

    If the incoming template exposes ``input_variables`` but it is ``None``,
    this attempts to recover the variable list from the raw template text via
    LangChain's internal extraction helper, falling back to an empty list so
    downstream code that iterates ``input_variables`` does not crash.

    Args:
        v: Candidate prompt template, or ``None``.

    Returns:
        The (possibly normalized) value, passed through unchanged otherwise.
    """
    if v is None:
        return v
    if hasattr(v, "input_variables") and v.input_variables is None:
        # Try to extract variables from the template text if possible.
        if hasattr(v, "template") and hasattr(v, "_get_template_variables"):
            try:
                v.input_variables = v._get_template_variables()
            except Exception:
                # Fix: was `except BaseException`, which also swallowed
                # KeyboardInterrupt/SystemExit. This is best-effort recovery
                # only, so ordinary exceptions are the right net.
                v.input_variables = []
    return v
def get_prompt_engine(self) -> Optional["PromptTemplateEngine"]:
    """Return a cached PromptTemplateEngine wrapping the current template.

    The engine is built lazily on first call and stored in ``_prompt_engine``
    so repeated calls reuse the same instance.

    Returns:
        Optional[PromptTemplateEngine]: The wrapper engine, or ``None`` when
        no prompt template is configured.
    """
    if not self.prompt_template:
        return None

    cached = self._prompt_engine
    if cached is not None:
        return cached

    # Imported here to avoid a hard dependency at module import time.
    from haive.core.engine.prompt_template import PromptTemplateEngine

    # Name the wrapper after the owning config so it is easy to trace.
    owner = getattr(self, "name", "config")
    engine = PromptTemplateEngine(
        name=f"{owner}_prompt",
        prompt_template=self.prompt_template,
    )
    self._prompt_engine = engine
    return engine
def derive_prompt_output_schema(self) -> type[BaseModel] | None:
    """Return the output schema of the prompt engine, or ``None`` if absent."""
    engine = self.get_prompt_engine()
    return engine.derive_output_schema() if engine else None
def get_prompt_variables(self) -> dict[str, Any]:
    """Summarize the prompt template's variable configuration.

    Returns:
        dict[str, Any]: Keys ``input_variables``, ``optional_variables``,
        ``partial_variables`` (names only, not values) and
        ``template_format``; an empty dict when no template is set.
    """
    template = self.prompt_template
    if not template:
        return {}

    partials = template.partial_variables or {}
    return {
        "input_variables": list(template.input_variables or []),
        "optional_variables": list(
            getattr(template, "optional_variables", []) or []
        ),
        # Only the names are reported; bound values stay private.
        "partial_variables": list(partials),
        "template_format": getattr(template, "template_format", "f-string"),
    }
def update_prompt_partials(self, **partials) -> bool:
    """Bind partial variables into the prompt template.

    Replaces ``prompt_template`` with a ``.partial(...)`` copy and drops the
    cached engine so the next access rebuilds it against the new template.

    Args:
        **partials: Variable name/value pairs to pre-bind.

    Returns:
        bool: ``True`` on success; ``False`` when no template is configured
        or when binding fails for any reason (best-effort semantics).
    """
    if not self.prompt_template:
        return False
    try:
        self.prompt_template = self.prompt_template.partial(**partials)
        # Invalidate the cache; get_prompt_engine() will recreate it.
        self._prompt_engine = None
    except Exception:
        # Deliberately swallow errors: callers get a boolean, not a raise.
        return False
    return True
def compose_with_prompt_schema(
    self, base_schema: type[BaseModel]
) -> type[BaseModel]:
    """Merge the prompt template's input schema into ``base_schema``.

    Prompt-only fields are placed first, followed by every base field, so on
    a name collision the base schema's definition wins.

    Args:
        base_schema: Existing Pydantic model to extend.

    Returns:
        type[BaseModel]: A new model named ``<Base>WithPrompt`` combining
        both field sets, or ``base_schema`` unchanged when no prompt schema
        can be derived.
    """
    prompt_schema = self.derive_prompt_input_schema()
    if not prompt_schema:
        return base_schema

    from pydantic import create_model

    base_fields = base_schema.model_fields
    # Prompt fields not shadowed by the base schema come first...
    merged = {
        name: (info.annotation, info)
        for name, info in prompt_schema.model_fields.items()
        if name not in base_fields
    }
    # ...then all base fields, so base definitions take precedence.
    merged.update(
        (name, (info.annotation, info)) for name, info in base_fields.items()
    )
    return create_model(f"{base_schema.__name__}WithPrompt", **merged)
[docs]
def enable_prompt_schema_derivation(self, enabled: bool = True) -> None:
    """Enable or disable prompt template schema derivation.

    Args:
        enabled: When ``True`` (default), input-schema derivation folds the
            prompt template's variables in; when ``False``, prompt-based
            derivation is skipped entirely.
    """
    self._use_prompt_for_input_schema = enabled
def get_missing_prompt_vars(self, input_data: dict[str, Any]) -> list[str]:
    """List required prompt variables that ``input_data`` does not supply.

    Variables already bound via ``partial_variables`` are never considered
    missing.

    Args:
        input_data: Candidate invocation payload.

    Returns:
        list[str]: Missing variable names (unordered); empty when no
        template is configured or nothing is missing.
    """
    template = self.prompt_template
    if not template:
        return []

    required = set(template.input_variables or [])
    if template.partial_variables:
        # Pre-bound partials never need to appear in the input payload.
        required -= set(template.partial_variables)
    return list(required - set(input_data))
def validate_with_prompt_schema(self, input_data: dict[str, Any]) -> dict[str, Any]:
    """Validate ``input_data`` against the effective input schema.

    Args:
        input_data: Raw invocation payload.

    Returns:
        dict[str, Any]: The schema-validated payload (``model_dump`` of the
        constructed model), or the original data unchanged when no effective
        schema is available. Pydantic validation errors propagate to the
        caller.
    """
    schema = self.get_effective_input_schema()
    if not schema:
        return input_data
    return schema(**input_data).model_dump()