Source code for haive.core.contracts.tool_config

"""Tool configuration system with contracts and capabilities.

This module provides a focused tool management system extracted from AugLLMConfig,
reducing complexity while adding explicit contracts and capability-based routing.
"""

from typing import Any, Dict, List, Literal, Optional, Set, Union, Callable
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool, StructuredTool


[docs] class ToolCapability(BaseModel): """Defines what a tool can do at runtime. Attributes: can_read_state: Whether tool can read from state. can_write_state: Whether tool can write to state. can_call_external: Whether tool makes external calls. is_stateful: Whether tool maintains internal state. is_async: Whether tool supports async execution. requires_confirmation: Whether tool needs user confirmation. computational_cost: Relative cost of tool execution. """ can_read_state: bool = Field(default=False, description="Can read from state") can_write_state: bool = Field(default=False, description="Can write to state") can_call_external: bool = Field(default=False, description="Makes external calls") is_stateful: bool = Field(default=False, description="Maintains internal state") is_async: bool = Field(default=False, description="Supports async execution") requires_confirmation: bool = Field(default=False, description="Needs user confirmation") computational_cost: Literal["low", "medium", "high"] = Field( default="low", description="Relative computational cost" )
[docs] class ToolContract(BaseModel): """Contract defining tool's behavior and requirements. Attributes: name: Tool identifier. description: Human-readable description. capabilities: What the tool can do. input_schema: Expected input structure. output_schema: Expected output structure. side_effects: List of potential side effects. required_permissions: Permissions needed to execute. """ name: str = Field(..., description="Tool identifier") description: str = Field(..., description="Human-readable description") capabilities: ToolCapability = Field( default_factory=ToolCapability, description="Tool capabilities" ) input_schema: Optional[type[BaseModel]] = Field( default=None, description="Expected input structure" ) output_schema: Optional[type[BaseModel]] = Field( default=None, description="Expected output structure" ) side_effects: List[str] = Field( default_factory=list, description="Potential side effects" ) required_permissions: Set[str] = Field( default_factory=set, description="Required permissions" )
[docs] class ToolConfig(BaseModel): """Focused tool configuration with contracts. This replaces the scattered tool management in AugLLMConfig (~266 lines) with a focused, contract-based approach. Attributes: tools: List of tools to configure. contracts: Tool contracts by name. routing_strategy: How to route tool calls. force_tool_use: Whether to force tool usage. specific_tool: Force specific tool selection. tool_choice_mode: Tool selection mode. allow_parallel: Whether to allow parallel tool execution. max_retries: Maximum retry attempts for failed tools. timeout_seconds: Tool execution timeout. Examples: Basic tool configuration: >>> config = ToolConfig( ... tools=[calculator, web_search], ... routing_strategy="capability" ... ) With contracts: >>> config = ToolConfig( ... tools=[data_processor], ... contracts={ ... "data_processor": ToolContract( ... name="data_processor", ... description="Process data", ... capabilities=ToolCapability( ... can_write_state=True, ... computational_cost="high" ... ) ... ) ... } ... ) """ tools: List[Union[BaseTool, StructuredTool, type[BaseModel], Callable, str]] = Field( default_factory=list, description="List of tools to configure" ) contracts: Dict[str, ToolContract] = Field( default_factory=dict, description="Tool contracts by name" ) routing_strategy: Literal["auto", "capability", "priority", "manual"] = Field( default="auto", description="Tool routing strategy" ) force_tool_use: bool = Field( default=False, description="Force tool usage" ) specific_tool: Optional[str] = Field( default=None, description="Force specific tool" ) tool_choice_mode: Literal["auto", "required", "none"] = Field( default="auto", description="Tool selection mode" ) allow_parallel: bool = Field( default=False, description="Allow parallel tool execution" ) max_retries: int = Field( default=3, ge=0, description="Maximum retry attempts" ) timeout_seconds: Optional[float] = Field( default=30.0, gt=0, description="Tool execution timeout" )
[docs] def add_tool( self, tool: Any, contract: Optional[ToolContract] = None ) -> "ToolConfig": """Add a tool with optional contract. Args: tool: Tool to add. contract: Optional tool contract. Returns: Self for chaining. """ if tool not in self.tools: self.tools.append(tool) # Get the tool name for consistent indexing tool_name = self._get_tool_name(tool) if contract: # Use the tool's actual name for indexing, not the contract name if tool_name: self.contracts[tool_name] = contract else: self.contracts[contract.name] = contract else: # Auto-generate basic contract if tool_name and tool_name not in self.contracts: self.contracts[tool_name] = ToolContract( name=tool_name, description=self._get_tool_description(tool) ) return self
[docs] def remove_tool(self, tool: Any) -> "ToolConfig": """Remove a tool and its contract. Args: tool: Tool to remove. Returns: Self for chaining. """ if tool in self.tools: self.tools.remove(tool) tool_name = self._get_tool_name(tool) if tool_name in self.contracts: del self.contracts[tool_name] return self
[docs] def get_tools_by_capability( self, capability: str, value: bool = True ) -> List[Any]: """Get tools matching a capability. Args: capability: Capability field name. value: Expected capability value. Returns: List of matching tools. """ matching = [] for tool in self.tools: tool_name = self._get_tool_name(tool) if tool_name in self.contracts: contract = self.contracts[tool_name] if getattr(contract.capabilities, capability, None) == value: matching.append(tool) return matching
[docs] def validate_permissions( self, tool: Any, available_permissions: Set[str] ) -> bool: """Check if tool has required permissions. Args: tool: Tool to validate. available_permissions: Available permissions. Returns: True if permissions are satisfied. """ tool_name = self._get_tool_name(tool) if tool_name not in self.contracts: return True # No contract means no restrictions contract = self.contracts[tool_name] return contract.required_permissions.issubset(available_permissions)
[docs] def get_safe_tools(self) -> List[Any]: """Get tools that don't have side effects. Returns: List of safe tools. """ safe = [] for tool in self.tools: tool_name = self._get_tool_name(tool) if tool_name in self.contracts: contract = self.contracts[tool_name] if not contract.side_effects and not contract.capabilities.can_write_state: safe.append(tool) else: # No contract means we can't guarantee safety continue return safe
def _get_tool_name(self, tool: Any) -> Optional[str]: """Extract tool name from various tool types. Args: tool: Tool to get name from. Returns: Tool name or None. """ if isinstance(tool, str): return tool elif hasattr(tool, "name"): return tool.name elif hasattr(tool, "__name__"): return tool.__name__ elif isinstance(tool, type): return tool.__name__ return None def _get_tool_description(self, tool: Any) -> str: """Extract tool description. Args: tool: Tool to get description from. Returns: Tool description or empty string. """ if hasattr(tool, "description"): return tool.description elif hasattr(tool, "__doc__") and tool.__doc__: return tool.__doc__.split("\n")[0].strip() return ""
[docs] def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for serialization. Returns: Configuration as dictionary. """ return { "tools": [self._get_tool_name(t) for t in self.tools], "contracts": { name: contract.model_dump() for name, contract in self.contracts.items() }, "routing_strategy": self.routing_strategy, "force_tool_use": self.force_tool_use, "specific_tool": self.specific_tool, "tool_choice_mode": self.tool_choice_mode, "allow_parallel": self.allow_parallel, "max_retries": self.max_retries, "timeout_seconds": self.timeout_seconds }