Source code for haive.core.schema.composer.field.field_manager

"""Field management for schema composition."""

import logging
from collections import defaultdict
from collections.abc import Callable
from typing import Any

from haive.core.schema.field_definition import FieldDefinition

logger = logging.getLogger(__name__)


[docs] class FieldManagerMixin: """Mixin that handles field management in SchemaComposer. This mixin provides: - Field definition storage and management - Field metadata tracking (sources, sharing, reducers) - Field validation and processing - Engine I/O field mapping """ def __init__(self, *args, **kwargs) -> None: """Initialize field tracking structures.""" super().__init__(*args, **kwargs) # Core field storage self.fields: dict[str, FieldDefinition] = {} # Field metadata tracking self.shared_fields: set = set() self.field_sources: dict[str, set] = defaultdict(set) self.nested_schemas: dict[str, type] = {} # Engine I/O tracking self.input_fields: dict[str, set] = defaultdict(set) self.output_fields: dict[str, set] = defaultdict(set) self.engine_io_mappings: dict[str, dict[str, list[str]]] = defaultdict( lambda: {"inputs": [], "outputs": []} ) # Structured output tracking self.structured_models: dict[str, type] = {} self.structured_model_fields: dict[str, list[str]] = defaultdict(list)
[docs] def add_field( self, name: str, field_type: type, default: Any = None, default_factory: Callable[[], Any] | None = None, description: str | None = None, shared: bool = False, reducer: Callable | None = None, source: str | None = None, input_for: list[str] | None = None, output_from: list[str] | None = None, ) -> "FieldManagerMixin": """Add a field definition to the schema. Args: name: Field name field_type: Type of the field default: Default value for the field default_factory: Optional factory function for creating default values description: Optional field description for documentation shared: Whether field is shared with parent graph reducer: Optional reducer function for merging field values source: Optional source identifier input_for: Optional list of engines this field is input for output_from: Optional list of engines this field is output from Returns: Self for method chaining """ # Skip special fields if name in {"__runnable_config__", "runnable_config"}: logger.warning(f"Skipping special field {name}") return self # Validate field type if field_type is None: field_type = Any elif not isinstance(field_type, type) and not hasattr(field_type, "__origin__"): if isinstance(field_type, dict | list | tuple) and not hasattr( field_type, "__origin__" ): logger.warning( f"Invalid field type for '{name}': {field_type}, using Any" ) field_type = Any # Check if field is provided by base class if hasattr(self, "detected_base_class") and hasattr(self, "base_class_fields"): if self.detected_base_class and name in self.base_class_fields: logger.debug( f"Field '{name}' is provided by base class, updating metadata only" ) # Still track metadata for the field if shared: self.shared_fields.add(name) if source: self.field_sources[name].add(source) if input_for: for engine in input_for: self.input_fields[engine].add(name) if engine not in self.engine_io_mappings: self.engine_io_mappings[engine] = { "inputs": [], "outputs": [], } if name not in self.engine_io_mappings[engine]["inputs"]: self.engine_io_mappings[engine]["inputs"].append(name) if output_from: for engine in output_from: self.output_fields[engine].add(name) if engine not in self.engine_io_mappings: self.engine_io_mappings[engine] = { "inputs": [], "outputs": [], } if name not in self.engine_io_mappings[engine]["outputs"]: self.engine_io_mappings[engine]["outputs"].append(name) return self # Create field definition field_def = FieldDefinition( name=name, field_type=field_type, default=default, default_factory=default_factory, description=description, shared=shared, reducer=reducer, input_for=input_for or [], output_from=output_from or [], ) # Check for nested schemas import inspect from typing import Union from haive.core.schema.state_schema import StateSchema if inspect.isclass(field_type) and issubclass(field_type, StateSchema): logger.debug(f"Field '{name}' is a StateSchema type: {field_type.__name__}") self.nested_schemas[name] = field_type elif getattr(field_type, "__origin__", None) is Union and any( inspect.isclass(arg) and issubclass(arg, StateSchema) for arg in field_type.__args__ if inspect.isclass(arg) ): # Handle Optional[StateSchema] and Union types containing # StateSchema for arg in field_type.__args__: if inspect.isclass(arg) and issubclass(arg, StateSchema): logger.debug( f"Field '{name}' contains StateSchema type: {arg.__name__}" ) self.nested_schemas[name] = arg break # Update detection flags if name == "messages": self.has_messages = True logger.debug( "Added 'messages' field - will use MessagesState as base class" ) if name == "tools": self.has_tools = True logger.debug("Added 'tools' field - will use ToolState as base class") # Store the field self.fields[name] = field_def # Track input/output relationships if input_for: for engine in input_for: self.input_fields[engine].add(name) if engine not in self.engine_io_mappings: self.engine_io_mappings[engine] = {"inputs": [], "outputs": []} if name not in self.engine_io_mappings[engine]["inputs"]: self.engine_io_mappings[engine]["inputs"].append(name) if output_from: for engine in output_from: self.output_fields[engine].add(name) if engine not in self.engine_io_mappings: self.engine_io_mappings[engine] = {"inputs": [], "outputs": []} if name not in self.engine_io_mappings[engine]["outputs"]: self.engine_io_mappings[engine]["outputs"].append(name) # Track additional metadata if shared: self.shared_fields.add(name) logger.debug(f"Marked field '{name}' as shared") if source: self.field_sources[name].add(source) logger.debug(f"Field '{name}' source: {source}") # Add tracking entry if hasattr(self, "processing_history"): self.processing_history.append( { "action": "add_field", "field_name": name, "field_type": str(field_type), "shared": shared, "has_reducer": reducer is not None, } ) logger.debug(f"Added field '{name}' of type {field_type}") return self
[docs] def get_field_count(self) -> int: """Get the number of fields defined.""" return len(self.fields)
[docs] def has_field(self, name: str) -> bool: """Check if a field is defined.""" return name in self.fields
[docs] def get_field_names(self) -> list[str]: """Get list of all field names.""" return list(self.fields.keys())
[docs] def get_shared_fields(self) -> list[str]: """Get list of shared field names.""" return list(self.shared_fields)
[docs] def get_engine_io_mapping(self, engine_name: str) -> dict[str, list[str]]: """Get input/output mapping for a specific engine.""" return self.engine_io_mappings.get(engine_name, {"inputs": [], "outputs": []})