# Source code for haive.core.utils.tool_list
"""Tool list implementation for Haive Core.
This module provides specialized tool collection management and utilities.
"""
import inspect
import logging
from collections.abc import Callable, Sequence
from typing import Any

from langchain_core.tools import BaseTool, BaseToolkit, StructuredTool
from pydantic import BaseModel, Field, model_validator

from haive.core.utils.haive_collections import NamedDict
class ToolList(NamedDict):
    """A collection of tools that inherits from NamedDict.

    Provides specialized handling for:

    - BaseTool classes and instances
    - BaseToolkit instances and classes (automatically expanded into their tools)
    - StructuredTool instances
    - Pydantic BaseModel classes (kept as classes) and instances
    - Callable functions

    A list/tuple of tools may be passed positionally, e.g.
    ``ToolList([tool_a, tool_b])``; it is treated like ``tools=[...]``.
    """

    # Override default name attributes to include function names
    name_attrs: list[str] = Field(default=["name", "__name__", "func_name"])

    # Type label (see ``_determine_tool_type``) for each stored tool, by key
    tool_types: dict[str, str] = Field(
        default_factory=dict, description="Type information for each tool"
    )

    # Flat view of the stored tools; kept in sync with ``values`` after every
    # mutation performed by this class.
    tools: Sequence[
        type[BaseTool]
        | type[BaseModel]
        | Callable
        | StructuredTool
        | BaseModel
        | BaseTool
        | BaseToolkit
    ] = Field(
        default_factory=list,
        description="The tools to use (BaseTool, BaseToolkit, BaseModel, or Callable)",
    )

    model_config = {"arbitrary_types_allowed": True}

    def __init__(self, arg: Any = None, /, **kwargs: Any) -> None:
        """Initialize, accepting an optional positional tool sequence.

        Python passes the original call arguments to ``__init__`` regardless
        of what ``__new__`` does with its own ``kwargs``, so the positional
        argument has to be converted here.

        Args:
            arg: Optional positional argument. A list/tuple is used as
                ``tools``; any other non-dict value is treated as a single
                tool. Dicts are forwarded positionally unchanged.
            **kwargs: Field values forwarded to the parent initializer.
        """
        if arg is None:
            super().__init__(**kwargs)
        elif isinstance(arg, dict) or "tools" in kwargs:
            # Mappings, and calls that already name ``tools``, are forwarded
            # exactly as given.
            super().__init__(arg, **kwargs)
        else:
            kwargs["tools"] = list(arg) if isinstance(arg, list | tuple) else [arg]
            super().__init__(**kwargs)

    @staticmethod
    def _expand_toolkits(items: Sequence[Any]) -> list[Any]:
        """Return *items* with every BaseToolkit instance replaced by its tools.

        Best effort: a toolkit whose ``get_tools()`` raises is skipped, and
        the failure is logged instead of being silently discarded.
        """
        expanded: list[Any] = []
        for item in items:
            if not isinstance(item, BaseToolkit):
                expanded.append(item)
                continue
            try:
                toolkit_tools = item.get_tools()
            except Exception:
                logging.getLogger(__name__).warning(
                    "Skipping toolkit %r: get_tools() failed", item, exc_info=True
                )
                continue
            expanded.extend(toolkit_tools)
        return expanded

    @model_validator(mode="before")
    @classmethod
    def process_tools(cls, data: Any) -> Any:
        """Expand toolkits in the raw input before field validation.

        - list/tuple input: toolkits are expanded and the resulting list is
          returned for NamedDict's own validation to key.
        - dict input with a list/tuple ``tools`` entry: toolkits are expanded
          and, if ``values`` was not supplied, it is built by extracting each
          tool's name via ``_extract_key`` (tools without a name are skipped).

        Any other input is returned unchanged.
        """
        if isinstance(data, list | tuple):
            return cls._expand_toolkits(data)

        if isinstance(data, dict) and isinstance(data.get("tools"), list | tuple):
            expanded_tools = cls._expand_toolkits(data["tools"])
            # Shallow copy so the caller's dict is not mutated.
            data = {**data, "tools": expanded_tools}
            if "values" not in data:
                name_attrs = data.get("name_attrs", ["name", "__name__", "func_name"])
                values = {}
                for tool in expanded_tools:
                    name = cls._extract_key(tool, name_attrs)
                    if name:
                        values[name] = tool
                data["values"] = values

        return data

    def model_post_init(self, context: Any, /) -> None:
        """Build tool type information after initialization.

        Runs the parent's post-init hook first, then classifies every stored
        tool, syncs ``tools`` with ``values``, and expands any toolkits that
        made it into ``values``.
        """
        super().model_post_init(context)
        self.tool_types = {
            name: self._determine_tool_type(tool) for name, tool in self.values.items()
        }
        self.tools = list(self.values.values())
        self._process_tool_types()

    @classmethod
    def _determine_tool_type(cls, tool: Any) -> str:
        """Determine the type label of a tool.

        The order of checks matters. StructuredTool subclasses BaseTool, and
        in langchain_core both BaseTool and BaseToolkit derive from pydantic
        BaseModel. The most specific types must therefore be tested first;
        otherwise the structured/toolkit labels are unreachable.

        Args:
            tool: The tool to analyze.

        Returns:
            One of ``"structured_tool_instance"``, ``"base_tool_instance"``,
            ``"toolkit"``, ``"model_instance"``, ``"base_tool_class"``,
            ``"toolkit_class"``, ``"model_class"``, ``"callable"``, or
            ``"unknown"``.
        """
        # Instances
        if isinstance(tool, StructuredTool):
            return "structured_tool_instance"
        if isinstance(tool, BaseTool):
            return "base_tool_instance"
        if isinstance(tool, BaseToolkit):
            return "toolkit"
        if isinstance(tool, BaseModel):
            return "model_instance"

        # Classes
        if inspect.isclass(tool):
            if issubclass(tool, BaseTool):
                return "base_tool_class"
            if issubclass(tool, BaseToolkit):
                return "toolkit_class"
            if issubclass(tool, BaseModel):
                return "model_class"

        # Any other callable, including unrelated classes
        if callable(tool):
            return "callable"
        return "unknown"

    def _process_tool_types(self) -> None:
        """Replace stored toolkits (instances or classes) with their tools.

        Toolkit classes are instantiated with no arguments. If instantiation
        or ``get_tools()`` fails, the toolkit entry is kept as is and the
        failure is logged. Model classes are left untouched.
        """
        for name, tool_type in list(self.tool_types.items()):
            if tool_type not in ("toolkit", "toolkit_class") or name not in self.values:
                continue
            tool = self.values[name]
            try:
                toolkit = tool() if tool_type == "toolkit_class" else tool
                toolkit_tools = toolkit.get_tools()
            except Exception:
                logging.getLogger(__name__).warning(
                    "Could not expand toolkit %r; keeping it as a single entry",
                    name,
                    exc_info=True,
                )
                continue
            # Drop the toolkit entry, then add each of its tools
            del self.values[name]
            del self.tool_types[name]
            for t in toolkit_tools:
                self.add(t)

        self.tools = list(self.values.values())

    def add(self, tool: Any, key: str | None = None) -> str:
        """Add a tool with an automatic or explicit key.

        A BaseToolkit instance is expanded and each of its tools is added
        under its own automatic key (``key`` is ignored in that case). If the
        toolkit's ``get_tools()`` raises, the toolkit itself is stored instead.

        Args:
            tool: Tool to add.
            key: Optional explicit key.

        Returns:
            Key used for the tool. For a toolkit, this is the key of its first
            tool, or ``""`` if it has none.
        """
        if isinstance(tool, BaseToolkit):
            try:
                toolkit_tools = tool.get_tools()
            except Exception:
                logging.getLogger(__name__).warning(
                    "Could not expand toolkit %r; adding it as a single entry",
                    tool,
                    exc_info=True,
                )
            else:
                added_keys = [self.add(t) for t in toolkit_tools]
                return added_keys[0] if added_keys else ""

        tool_key = super().add(tool, key)
        self.tool_types[tool_key] = self._determine_tool_type(tool)
        self.tools = list(self.values.values())
        return tool_key

    def update(self, items: Any) -> None:
        """Update with new tools.

        Args:
            items: Dictionary or sequence of tools. Toolkits in a sequence are
                expanded before being handed to the parent ``update``.
        """
        if isinstance(items, list | tuple):
            super().update(self._expand_toolkits(items))
        else:
            super().update(items)

        # Re-classify every entry. Otherwise, keys overwritten with a
        # different tool would keep a stale type label.
        for key, value in self.values.items():
            self.tool_types[key] = self._determine_tool_type(value)

        self.tools = list(self.values.values())
        # Expand toolkits; model classes stay classes
        self._process_tool_types()

    def get_tool_type(self, name: str) -> str | None:
        """Get the type label of a specific tool.

        Args:
            name: Tool name.

        Returns:
            Tool type string, or None if not found.
        """
        return self.tool_types.get(name)

    def get_by_tool_type(self, tool_type: str) -> list[Any]:
        """Get all tools of a specified type.

        Args:
            tool_type: Type label to filter by.

        Returns:
            List of tools matching the type.
        """
        return [
            self.values[name]
            for name, type_value in self.tool_types.items()
            if type_value == tool_type and name in self.values
        ]

    def get_tool_type_mapping(self) -> dict[str, list[str]]:
        """Get a mapping of tool types to tool names.

        Returns:
            Dictionary mapping each type label to the list of tool names.
        """
        result: dict[str, list[str]] = {}
        for name, tool_type in self.tool_types.items():
            result.setdefault(tool_type, []).append(name)
        return result

    def get_tool(self, name: str) -> Any | None:
        """Get a tool by name.

        Args:
            name: Tool name.

        Returns:
            Tool if found, None otherwise.
        """
        return self.get(name)

    def get_tool_info(self, name: str) -> dict[str, Any]:
        """Get comprehensive information about a tool.

        Args:
            name: Tool name.

        Returns:
            ``{"found": False}`` if the tool is missing. Otherwise a dict with
            ``found``, ``name``, ``tool_type`` and ``tool``, plus, when
            applicable, ``description``, ``has_schema``/``schema``,
            ``is_model`` and (for model classes) ``fields``.
        """
        if name not in self.values:
            return {"found": False}

        tool = self.values[name]
        tool_type = self.tool_types.get(name, "unknown")
        info: dict[str, Any] = {
            "found": True,
            "name": name,
            "tool_type": tool_type,
            "tool": tool,
        }

        if hasattr(tool, "description"):
            info["description"] = tool.description
        if hasattr(tool, "args_schema"):
            schema = tool.args_schema
            # args_schema may be None; only report a schema that is really set
            info["has_schema"] = schema is not None
            info["schema"] = schema

        if tool_type in ("model_class", "model_instance"):
            info["is_model"] = True
            if (
                tool_type == "model_class"
                and inspect.isclass(tool)
                and issubclass(tool, BaseModel)
            ):
                info["fields"] = {
                    field_name: {
                        "type": str(field.annotation),
                        "required": field.is_required(),
                        "description": field.description or "",
                    }
                    for field_name, field in tool.model_fields.items()
                }

        return info

    def get_model_classes(self) -> dict[str, type[BaseModel]]:
        """Get all model classes in the tool list.

        Returns:
            Dictionary mapping name to model class.
        """
        return {
            name: self.values[name]
            for name, tool_type in self.tool_types.items()
            if tool_type == "model_class" and name in self.values
        }

    def get_model_instances(self) -> dict[str, BaseModel]:
        """Get all model instances in the tool list.

        Returns:
            Dictionary mapping name to model instance.
        """
        return {
            name: self.values[name]
            for name, tool_type in self.tool_types.items()
            if tool_type == "model_instance" and name in self.values
        }

    def get_tools_by_category(self) -> dict[str, dict[str, Any]]:
        """Get tools organized by category.

        Returns:
            Dictionary with ``"tools"`` (BaseTool/StructuredTool instances and
            BaseTool classes), ``"models"`` (model classes and instances) and
            ``"callables"``. Other types are omitted.
        """
        categories: dict[str, dict[str, Any]] = {
            "tools": {},
            "models": {},
            "callables": {},
        }
        for name, tool in self.values.items():
            tool_type = self.tool_types.get(name)
            if tool_type in (
                "base_tool_instance",
                "structured_tool_instance",
                "base_tool_class",
            ):
                categories["tools"][name] = tool
            elif tool_type in ("model_class", "model_instance"):
                categories["models"][name] = tool
            elif tool_type == "callable":
                categories["callables"][name] = tool
        return categories

    def to_list(self) -> list[Any]:
        """Convert to a simple list of tools.

        Returns:
            List of all tools.
        """
        return list(self.values.values())

    def __delitem__(self, key: str) -> None:
        """Delete a tool by name and drop its type entry."""
        super().__delitem__(key)
        self.tool_types.pop(key, None)
        self.tools = list(self.values.values())