# Source code for haive.core.contracts.node_contracts

"""Node contracts for graph execution.

This module defines contracts for graph nodes, ensuring explicit
declaration of behavior and dependencies.
"""

from __future__ import annotations

import logging
from typing import Any, Callable, Dict, List, Optional

from pydantic import BaseModel, Field

from haive.core.contracts.boundaries import StateView

logger = logging.getLogger(__name__)


class NodeContract(BaseModel):
    """Contract for a graph node.

    Declares what a node requires, what it produces, and what it promises
    when it runs inside a graph workflow.

    Attributes:
        inputs: Required input fields from state
        outputs: Produced output fields to state
        transforms: How fields are transformed (source -> target)
        dependencies: Other nodes this depends on
        guarantees: What this node guarantees after execution
        error_handling: How errors are handled ("raise", "log" or "default")
        can_retry: Whether node can be retried on failure
        max_retries: Maximum retry attempts when can_retry is set

    Examples:
        >>> # Validation node contract
        >>> validation_contract = NodeContract(
        ...     inputs=["response", "criteria"],
        ...     outputs=["validated_response", "validation_score"],
        ...     transforms={"response": "validated_response"},
        ...     guarantees=["validation_score between 0 and 1"]
        ... )
        >>>
        >>> # Tool node contract
        >>> tool_contract = NodeContract(
        ...     inputs=["tool_calls", "tools"],
        ...     outputs=["tool_results"],
        ...     dependencies=["llm_node"],
        ...     can_retry=True
        ... )
    """

    # State fields consumed and produced
    inputs: List[str] = Field(
        default_factory=list,
        description="Required input fields",
    )
    outputs: List[str] = Field(
        default_factory=list,
        description="Produced output fields",
    )

    # Source -> target field mapping
    transforms: Dict[str, str] = Field(
        default_factory=dict,
        description="Field transformations (source -> target)",
    )

    # Ordering constraints and post-conditions
    dependencies: List[str] = Field(
        default_factory=list,
        description="Other nodes this depends on",
    )
    guarantees: List[str] = Field(
        default_factory=list,
        description="Post-execution guarantees",
    )

    # Failure policy
    error_handling: str = Field(
        default="raise",
        description="How to handle errors (raise, log, default)",
    )
    can_retry: bool = Field(
        default=False,
        description="Whether node can be retried",
    )
    max_retries: int = Field(
        default=3,
        description="Maximum retry attempts if can_retry",
    )

    def validate_dependencies(self, executed_nodes: List[str]) -> List[str]:
        """Report which declared dependencies have not run yet.

        Args:
            executed_nodes: Names of the nodes that have already executed

        Returns:
            The dependencies absent from ``executed_nodes``, in declaration
            order (empty when every dependency is satisfied)
        """
        unmet: List[str] = []
        for required in self.dependencies:
            if required in executed_nodes:
                continue
            unmet.append(required)
        return unmet

    def get_all_fields(self) -> Dict[str, str]:
        """Map every field the node touches to the way it is used.

        Later categories overwrite earlier ones for the same field name:
        inputs first, then outputs, then transform sources/targets.

        Returns:
            Dictionary mapping field name to one of "input", "output",
            "transform_source" or "transform_target"
        """
        roles: Dict[str, str] = dict.fromkeys(self.inputs, "input")
        roles.update(dict.fromkeys(self.outputs, "output"))
        for origin, destination in self.transforms.items():
            roles[origin] = "transform_source"
            roles[destination] = "transform_target"
        return roles
class ContractualNode:
    """Node that declares and enforces its contract.

    This ensures nodes explicitly declare their behavior, making graph
    composition predictable and debuggable.

    Attributes:
        name: Node identifier
        contract: Node's contract specification
        execute_fn: Function to execute
        _execution_count: Number of successful ``execute_fn`` invocations
        _contract_violations: List of recorded contract violations
        _execution_history: One entry per call that completed without raising
        _retry_count: Number of retries performed (re-attempts after a failure)

    Examples:
        >>> # Create validation node
        >>> def validate(state_view):
        ...     response = state_view.get("response")
        ...     # Validation logic
        ...     return {"validated_response": response, "score": 0.95}
        >>>
        >>> node = ContractualNode(
        ...     name="validator",
        ...     contract=validation_contract,
        ...     execute_fn=validate
        ... )
        >>>
        >>> # Execute with state view
        >>> result = node(state_view)
    """

    def __init__(
        self,
        name: str,
        contract: NodeContract,
        execute_fn: Callable[[StateView], Dict[str, Any]]
    ):
        """Initialize contractual node.

        Args:
            name: Node identifier
            contract: Node's contract
            execute_fn: Function to execute with StateView
        """
        self.name = name
        self.contract = contract
        self.execute_fn = execute_fn
        self._execution_count = 0
        self._contract_violations: List[Dict[str, Any]] = []
        self._execution_history: List[Dict[str, Any]] = []
        self._retry_count = 0

    def __call__(self, state_view: StateView) -> Dict[str, Any]:
        """Execute node with contract enforcement.

        Validates the required inputs, runs ``execute_fn`` (retrying when the
        contract allows it), validates the produced outputs, applies the
        contract's transforms and records the execution.

        Args:
            state_view: Bounded view of state

        Returns:
            Execution results

        Raises:
            ContractViolation: If required inputs are unavailable, if
                execution fails under the "raise" policy, or if required
                outputs are missing from the result
        """
        from datetime import datetime

        start_time = datetime.now()

        if not self._validate_inputs(state_view):
            raise self._record_violation(
                "input", f"Missing required inputs: {self.contract.inputs}"
            )

        # Always make at least one attempt: a non-positive max_retries must
        # not silently skip execution altogether.
        attempts = max(1, self.contract.max_retries) if self.contract.can_retry else 1

        result = None
        for attempt in range(attempts):
            try:
                result = self.execute_fn(state_view)
            except Exception as exc:
                if attempt + 1 < attempts:
                    # Only count failures that actually lead to another attempt.
                    self._retry_count += 1
                    logger.warning(
                        "Node '%s' attempt %d failed, retrying: %s",
                        self.name, attempt + 1, exc,
                    )
                    continue

                # Final failure: apply the contract's error policy.
                policy = self.contract.error_handling
                if policy == "raise":
                    raise self._record_violation("execution", str(exc)) from exc
                if policy == "default":
                    result = {field: None for field in self.contract.outputs}
                else:
                    if policy != "log":
                        # Unknown policies used to swallow the error silently;
                        # keep the lenient outcome but make it visible.
                        logger.error(
                            "Node '%s' has unknown error_handling policy '%s'; "
                            "treating it as 'log'",
                            self.name, policy,
                        )
                    logger.error("Node '%s' execution failed: %s", self.name, exc)
                    result = {}
            else:
                self._execution_count += 1
                break

        if result is None:
            result = {}

        if not self._validate_outputs(result):
            raise self._record_violation(
                "output", f"Missing required outputs: {self.contract.outputs}"
            )

        self._apply_transforms(state_view, result)

        execution_time = (datetime.now() - start_time).total_seconds()
        self._execution_history.append({
            "timestamp": datetime.now().isoformat(),
            "duration": execution_time,
            "status": "success",
            "retry_count": self._retry_count
        })

        return result

    def _record_violation(self, kind: str, details: str) -> ContractViolation:
        """Store a violation record and build the exception for it.

        Args:
            kind: Violation type ("input", "execution" or "output")
            details: Human-readable description

        Returns:
            The ContractViolation to raise (the caller raises it so the
            traceback points at the failing step)
        """
        from datetime import datetime

        violation = {
            "node": self.name,
            "type": kind,
            "details": details,
            "timestamp": datetime.now().isoformat()
        }
        self._contract_violations.append(violation)
        return ContractViolation(violation)

    def _validate_inputs(self, state_view: StateView) -> bool:
        """Validate all required inputs are available.

        Args:
            state_view: State view to validate

        Returns:
            True if every contract input is readable and retrievable
        """
        for field in self.contract.inputs:
            if not state_view.has_permission("read", field):
                logger.error(
                    "Node '%s' cannot read required field '%s'", self.name, field
                )
                return False
            try:
                # Fetch the field to make sure it actually exists.
                state_view.get(field)
            except (PermissionError, KeyError):
                logger.error(
                    "Node '%s' required field '%s' is unavailable", self.name, field
                )
                return False
        return True

    def _validate_outputs(self, result: Dict[str, Any]) -> bool:
        """Validate all required outputs are produced.

        Args:
            result: Execution result

        Returns:
            True if ``result`` is a dict containing every contract output
        """
        if not isinstance(result, dict):
            return False

        for field in self.contract.outputs:
            if field not in result:
                logger.error(
                    "Node '%s' missing required output '%s'", self.name, field
                )
                return False
        return True

    def _apply_transforms(self, state_view: StateView, result: Dict[str, Any]) -> None:
        """Apply field transformations to state.

        For each ``source -> target`` pair whose source is a key of
        ``result``, writes ``result[source]`` to ``target`` when the view
        grants write permission; otherwise logs a warning.

        Args:
            state_view: State view to update
            result: Execution result
        """
        for source, target in self.contract.transforms.items():
            if source not in result:
                continue
            if state_view.has_permission("write", target):
                state_view.set(target, result[source])
            else:
                logger.warning(
                    "Node '%s' cannot write transform target '%s'",
                    self.name, target,
                )

    def get_contract_summary(self) -> Dict[str, Any]:
        """Get human-readable contract summary.

        ``success_rate`` is the fraction of calls that completed without
        raising: completed calls (history entries) over completed calls plus
        recorded violations. It is 0 when the node has never been called.

        Returns:
            Contract details and execution statistics
        """
        completed = len(self._execution_history)
        violations = len(self._contract_violations)
        total_calls = completed + violations

        return {
            "name": self.name,
            "contract": {
                "inputs": self.contract.inputs,
                "outputs": self.contract.outputs,
                "transforms": self.contract.transforms,
                "dependencies": self.contract.dependencies,
                "guarantees": self.contract.guarantees,
                "can_retry": self.contract.can_retry
            },
            "statistics": {
                "executions": self._execution_count,
                "violations": violations,
                "retries": self._retry_count,
                "success_rate": completed / total_calls if total_calls else 0
            },
            "recent_violations": self._contract_violations[-3:]
        }

    def reset_statistics(self) -> None:
        """Reset execution statistics."""
        self._execution_count = 0
        self._contract_violations = []
        self._execution_history = []
        self._retry_count = 0
class ContractViolation(Exception):
    """Error signalling that a node broke its contract.

    Attributes:
        violation: Dictionary containing violation details
    """

    def __init__(self, violation: Dict[str, Any]):
        """Build the exception from a violation record.

        Args:
            violation: Violation information; kept on the instance and
                embedded in the exception message
        """
        message = "Contract violation: {}".format(violation)
        super().__init__(message)
        self.violation = violation
class NodeChain:
    """Chain of contractual nodes with dependency validation.

    Ensures nodes are executed in valid order with all dependencies met.

    Attributes:
        nodes: Dictionary of nodes by name
        execution_order: Order to execute nodes
        _executed: Names of the nodes executed during the current/last run

    Examples:
        >>> # Create node chain
        >>> chain = NodeChain()
        >>> chain.add_node(llm_node)
        >>> chain.add_node(validation_node)
        >>> chain.add_node(output_node)
        >>>
        >>> # Validate and execute
        >>> issues = chain.validate_chain()
        >>> if not issues:
        ...     result = chain.execute(state_view)
    """

    def __init__(self):
        """Initialize empty node chain."""
        self.nodes: Dict[str, ContractualNode] = {}
        self.execution_order: List[str] = []
        self._executed: set[str] = set()

    def add_node(self, node: ContractualNode) -> None:
        """Add node to chain.

        The node is registered under ``node.name`` and appended to the end
        of the execution order.

        Args:
            node: Node to add
        """
        self.nodes[node.name] = node
        self.execution_order.append(node.name)

    def validate_chain(self) -> List[str]:
        """Validate the node chain for consistency.

        Two passes are made over the execution order: first every node's
        dependencies must appear earlier in the order, then every node after
        the first must find its inputs among the outputs and transform
        targets of earlier nodes. The first node is exempt from the input
        check.

        Returns:
            List of validation issues (empty when the chain is consistent)
        """
        issues: List[str] = []

        # Pass 1: dependencies must precede the node. A running prefix
        # replaces the per-node index() lookup, which was quadratic and
        # picked the wrong prefix if a name occurred twice.
        preceding: List[str] = []
        for name in self.execution_order:
            missing = self.nodes[name].contract.validate_dependencies(preceding)
            if missing:
                issues.append(
                    f"Node '{name}' has unmet dependencies: {missing}"
                )
            preceding.append(name)

        # Pass 2: inputs must be produced by an earlier node.
        available_outputs: set = set()
        for position, name in enumerate(self.execution_order):
            contract = self.nodes[name].contract

            missing_inputs = set(contract.inputs) - available_outputs
            if missing_inputs and position > 0:
                issues.append(
                    f"Node '{name}' requires {missing_inputs} but they're not produced"
                )

            available_outputs.update(contract.outputs)
            available_outputs.update(contract.transforms.values())

        return issues

    def execute(self, state_view: StateView) -> Dict[str, Any]:
        """Execute the node chain.

        Nodes run in ``execution_order``; each node's result is merged into
        the combined result, later nodes overwriting earlier keys.

        Args:
            state_view: State view for execution

        Returns:
            Combined results from all nodes

        Raises:
            ContractViolation: If a node's dependencies have not executed
                earlier in this run, or if any node violates its contract
        """
        # Start every run from a clean slate so names left over from an
        # earlier (possibly aborted) run cannot satisfy dependency checks.
        self._executed = set()
        combined_result: Dict[str, Any] = {}

        for name in self.execution_order:
            node = self.nodes[name]

            missing = node.contract.validate_dependencies(list(self._executed))
            if missing:
                raise ContractViolation({
                    "node": name,
                    "type": "dependency",
                    "details": f"Missing dependencies: {missing}"
                })

            combined_result.update(node(state_view))
            self._executed.add(name)

        return combined_result

    def get_execution_plan(self) -> List[Dict[str, Any]]:
        """Get execution plan showing order and dependencies.

        Returns:
            List of execution steps, one per entry in the execution order,
            with a 1-based ``step`` number
        """
        plan: List[Dict[str, Any]] = []
        for step, name in enumerate(self.execution_order, start=1):
            contract = self.nodes[name].contract
            plan.append({
                "step": step,
                "node": name,
                "inputs": contract.inputs,
                "outputs": contract.outputs,
                "dependencies": contract.dependencies
            })
        return plan