# Source code for haive.core.contracts.tool_registry

"""Central tool registry with contract enforcement.

This module provides a centralized registry for tools with capability-based
lookup and contract enforcement, extracted from scattered tool management.
"""

from typing import Any, Dict, List, Optional, Set, Callable
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool, StructuredTool

from haive.core.contracts.tool_config import ToolContract, ToolCapability


class ToolMetadata(BaseModel):
    """Metadata about a registered tool.

    Attributes:
        name: Tool identifier.
        contract: Tool contract.
        tags: Categorization tags.
        version: Tool version.
        registered_at: Registration timestamp.
        usage_count: Number of times used.
        last_used: Last usage timestamp.
        performance_metrics: Performance statistics.
    """

    # --- Required: identity and contract -------------------------------
    name: str = Field(..., description="Tool identifier")
    contract: ToolContract = Field(..., description="Tool contract")

    # --- Optional descriptive data -------------------------------------
    tags: Set[str] = Field(
        default_factory=set,
        description="Categorization tags",
    )
    version: str = Field("1.0.0", description="Tool version")

    # --- Bookkeeping ----------------------------------------------------
    # Timestamps are plain strings; they stay None until something sets them.
    registered_at: Optional[str] = Field(
        None,
        description="Registration timestamp",
    )
    usage_count: int = Field(0, description="Usage count")
    last_used: Optional[str] = Field(
        None,
        description="Last usage timestamp",
    )
    performance_metrics: Dict[str, float] = Field(
        default_factory=dict,
        description="Performance metrics",
    )
class ToolRegistry(BaseModel):
    """Central registry for tools with contract enforcement.

    Provides:
        - Tool registration with contracts
        - Capability-based tool discovery
        - Permission validation
        - Usage tracking
        - Performance monitoring

    Attributes:
        tools: Registered tools by name.
        metadata: Tool metadata by name.
        capability_index: Tool names indexed by capability key. Keys have
            the form ``"<field>=<value>"`` for boolean capability fields and
            ``"cost=<value>"`` for the ``computational_cost`` field.
        tag_index: Tool names indexed by tag.
        permission_requirements: Global permission requirements, i.e. the
            permissions required by the contracts of registered tools.

    Examples:
        Basic registration:

        >>> registry = ToolRegistry()
        >>> registry.register(
        ...     name="calculator",
        ...     tool=calculator_tool,
        ...     contract=calculator_contract
        ... )

        Capability-based lookup:

        >>> safe_tools = registry.find_by_capability(
        ...     "can_write_state", False
        ... )
    """

    tools: Dict[str, Any] = Field(
        default_factory=dict,
        description="Registered tools by name"
    )
    metadata: Dict[str, ToolMetadata] = Field(
        default_factory=dict,
        description="Tool metadata by name"
    )
    capability_index: Dict[str, Set[str]] = Field(
        default_factory=dict,
        description="Tools indexed by capability"
    )
    tag_index: Dict[str, Set[str]] = Field(
        default_factory=dict,
        description="Tools indexed by tag"
    )
    permission_requirements: Set[str] = Field(
        default_factory=set,
        description="Global permission requirements"
    )

    def register(
        self,
        name: str,
        tool: Any,
        contract: Optional[ToolContract] = None,
        tags: Optional[Set[str]] = None,
        version: str = "1.0.0"
    ) -> "ToolRegistry":
        """Register a tool with contract.

        Registering under a name that is already present replaces the
        previous entry, including its metadata and index entries.

        Args:
            name: Tool name.
            tool: Tool instance.
            contract: Tool contract. When omitted, a default contract is
                built from ``name`` and the tool's description.
            tags: Tool tags.
            version: Tool version.

        Returns:
            Self for chaining.
        """
        from datetime import datetime

        # Build everything that can raise *before* touching registry state,
        # so a failed registration never removes an existing entry.
        if contract is None:
            contract = ToolContract(
                name=name,
                description=self._get_tool_description(tool)
            )
        tool_tags = set(tags) if tags else set()
        metadata = ToolMetadata(
            name=name,
            contract=contract,
            tags=tool_tags,
            version=version,
            registered_at=datetime.now().isoformat()
        )

        # Drop a previous registration first; otherwise its capability keys,
        # tags and permissions would linger and lookups would return the new
        # tool under the old tool's capabilities.
        if name in self.tools or name in self.metadata:
            self.unregister(name)

        self.tools[name] = tool
        self.metadata[name] = metadata

        # Update indices
        self._update_capability_index(name, contract.capabilities)
        self._update_tag_index(name, tool_tags)

        # Update global permissions
        self.permission_requirements.update(contract.required_permissions)

        return self

    def unregister(self, name: str) -> "ToolRegistry":
        """Unregister a tool.

        Removes the tool, its metadata, its capability and tag index
        entries, and any global permission that only this tool required.
        Unknown names are ignored.

        Args:
            name: Tool name to unregister.

        Returns:
            Self for chaining.
        """
        self.tools.pop(name, None)
        metadata = self.metadata.pop(name, None)
        if metadata is None:
            return self

        capability_keys = self._get_capability_keys(metadata.contract.capabilities)
        for capability_key in capability_keys:
            self._discard_from_index(self.capability_index, capability_key, name)

        for tag in metadata.tags:
            self._discard_from_index(self.tag_index, tag, name)

        # Permissions still needed by any remaining tool must be kept.
        still_required: Set[str] = set()
        for remaining in self.metadata.values():
            still_required.update(remaining.contract.required_permissions)
        stale = set(metadata.contract.required_permissions) - still_required
        self.permission_requirements.difference_update(stale)

        return self

    def get_tool(self, name: str) -> Optional[Any]:
        """Get a tool by name.

        Args:
            name: Tool name.

        Returns:
            Tool instance or None.
        """
        return self.tools.get(name)

    def find_by_capability(
        self,
        capability: str,
        value: Any = True
    ) -> List[Any]:
        """Find tools by capability.

        Args:
            capability: Capability field name. ``"computational_cost"`` is
                accepted as an alias for the ``"cost"`` index prefix.
            value: Expected capability value.

        Returns:
            List of matching tools, ordered by tool name.
        """
        # The index stores computational_cost under the shorter "cost" prefix
        # (see _get_capability_keys), so translate the field name here.
        index_field = "cost" if capability == "computational_cost" else capability
        capability_key = f"{index_field}={value}"
        tool_names = self.capability_index.get(capability_key, ())
        # Sort so the result does not depend on set iteration order.
        return [self.tools[name] for name in sorted(tool_names) if name in self.tools]

    def find_by_tag(self, tag: str) -> List[Any]:
        """Find tools by tag.

        Args:
            tag: Tag to search for.

        Returns:
            List of matching tools, ordered by tool name.
        """
        tool_names = self.tag_index.get(tag, ())
        return [self.tools[name] for name in sorted(tool_names) if name in self.tools]

    def find_safe_tools(self) -> List[Any]:
        """Find tools without side effects.

        A tool counts as safe when its contract declares no side effects and
        its capabilities allow neither state writes nor external calls.

        Returns:
            List of safe tools.
        """
        safe_tools = []
        for name, metadata in self.metadata.items():
            # Guard against metadata without a tool, like the other finders.
            if name not in self.tools:
                continue
            contract = metadata.contract
            capabilities = contract.capabilities
            if (
                not contract.side_effects
                and not capabilities.can_write_state
                and not capabilities.can_call_external
            ):
                safe_tools.append(self.tools[name])
        return safe_tools

    def find_stateful_tools(self) -> List[Any]:
        """Find tools that maintain state.

        Returns:
            List of stateful tools.
        """
        return self.find_by_capability("is_stateful", True)

    def find_async_tools(self) -> List[Any]:
        """Find async-capable tools.

        Returns:
            List of async tools.
        """
        return self.find_by_capability("is_async", True)

    def validate_permissions(
        self,
        tool_name: str,
        available_permissions: Set[str]
    ) -> tuple[bool, List[str]]:
        """Validate tool permissions.

        Args:
            tool_name: Tool to validate.
            available_permissions: Available permissions.

        Returns:
            Tuple of (is_valid, missing_permissions). For an unknown tool the
            result is ``(False, ["Tool '<name>' not found"])``. Missing
            permissions are returned sorted.
        """
        metadata = self.metadata.get(tool_name)
        if not metadata:
            return False, [f"Tool '{tool_name}' not found"]

        # Normalise to a set so any iterable of permissions works.
        required = set(metadata.contract.required_permissions)
        missing = required.difference(available_permissions)

        return not missing, sorted(missing)

    def track_usage(self, tool_name: str, execution_time: float = 0.0) -> None:
        """Track tool usage.

        Every call increments ``usage_count`` and refreshes ``last_used``.
        Only calls with a positive ``execution_time`` contribute to
        ``avg_execution_time``; the number of such calls is kept under the
        ``timed_executions`` metric so untimed calls cannot skew the average.
        Unknown tool names are ignored.

        Args:
            tool_name: Tool that was used.
            execution_time: Execution time in seconds.
        """
        from datetime import datetime

        metadata = self.metadata.get(tool_name)
        if metadata is None:
            return

        metadata.usage_count += 1
        metadata.last_used = datetime.now().isoformat()

        # "not >" (rather than "<=") also rejects NaN.
        if not execution_time > 0:
            return

        metrics = metadata.performance_metrics
        previous_avg = metrics.get("avg_execution_time")
        if previous_avg is None:
            samples = 0
            previous_avg = 0.0
        elif "timed_executions" in metrics:
            samples = int(metrics["timed_executions"])
        else:
            # An average exists without a sample count: weight it by the
            # calls seen before this one.
            samples = max(metadata.usage_count - 1, 1)

        samples += 1
        # Incremental mean over timed executions only.
        metrics["avg_execution_time"] = (
            previous_avg + (execution_time - previous_avg) / samples
        )
        metrics["timed_executions"] = float(samples)

    def get_usage_stats(self) -> Dict[str, Dict[str, Any]]:
        """Get usage statistics for all tools.

        Returns:
            Usage statistics by tool name. Each entry has the keys
            ``usage_count``, ``last_used`` and ``performance``; the
            performance dict is a copy of the tool's metrics.
        """
        return {
            name: {
                "usage_count": metadata.usage_count,
                "last_used": metadata.last_used,
                "performance": dict(metadata.performance_metrics)
            }
            for name, metadata in self.metadata.items()
        }

    def get_capability_summary(self) -> Dict[str, int]:
        """Get summary of tool capabilities.

        Returns:
            Count of tools by capability key.
        """
        return {
            capability_key: len(tool_names)
            for capability_key, tool_names in self.capability_index.items()
        }

    def _update_capability_index(
        self,
        tool_name: str,
        capabilities: ToolCapability
    ) -> None:
        """Update capability index.

        Args:
            tool_name: Tool name.
            capabilities: Tool capabilities.
        """
        for capability_key in self._get_capability_keys(capabilities):
            self.capability_index.setdefault(capability_key, set()).add(tool_name)

    def _update_tag_index(self, tool_name: str, tags: Set[str]) -> None:
        """Update tag index.

        Args:
            tool_name: Tool name.
            tags: Tool tags.
        """
        for tag in tags:
            self.tag_index.setdefault(tag, set()).add(tool_name)

    def _discard_from_index(
        self,
        index: Dict[str, Set[str]],
        key: str,
        tool_name: str
    ) -> None:
        """Remove a tool name from one index bucket, dropping empty buckets.

        Args:
            index: Index mapping to modify in place.
            key: Bucket key.
            tool_name: Tool name to remove.
        """
        bucket = index.get(key)
        if bucket is None:
            return
        bucket.discard(tool_name)
        if not bucket:
            # Empty buckets would show up as zero counts in summaries.
            del index[key]

    def _get_capability_keys(self, capabilities: ToolCapability) -> List[str]:
        """Get capability index keys.

        Boolean fields produce ``"<field>=<value>"``; the
        ``computational_cost`` field produces ``"cost=<value>"``. All other
        fields are not indexed.

        Args:
            capabilities: Tool capabilities.

        Returns:
            List of capability keys.
        """
        keys = []
        for field_name, field_value in capabilities.model_dump().items():
            if isinstance(field_value, bool):
                keys.append(f"{field_name}={field_value}")
            elif field_name == "computational_cost":
                keys.append(f"cost={field_value}")
        return keys

    def _get_tool_description(self, tool: Any) -> str:
        """Extract tool description.

        Uses the tool's ``description`` attribute when it is a non-empty
        string; otherwise falls back to the first non-blank line of the
        tool's docstring.

        Args:
            tool: Tool to get description from.

        Returns:
            Tool description or empty string.
        """
        description = getattr(tool, "description", None)
        if isinstance(description, str) and description:
            return description

        doc = getattr(tool, "__doc__", None)
        if isinstance(doc, str):
            # Strip first so docstrings opening with a newline still yield
            # their summary line.
            stripped = doc.strip()
            if stripped:
                return stripped.splitlines()[0].strip()
        return ""

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization.

        Tool objects themselves are not included, only their names.
        Metadata is dumped with ``model_dump()`` in its default mode.

        Returns:
            Registry as dictionary with the keys ``tools``, ``metadata``,
            ``capability_summary``, ``usage_stats`` and
            ``permission_requirements`` (sorted).
        """
        return {
            "tools": list(self.tools.keys()),
            "metadata": {
                name: metadata.model_dump()
                for name, metadata in self.metadata.items()
            },
            "capability_summary": self.get_capability_summary(),
            "usage_stats": self.get_usage_stats(),
            "permission_requirements": sorted(self.permission_requirements)
        }