"""Enhanced validation state for tool message routing and conditional branching."""
import time
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
[docs]
class ValidationStatus(str, Enum):
"""Status of tool validation."""
PENDING = "pending"
VALID = "valid"
INVALID = "invalid"
ERROR = "error"
SKIPPED = "skipped"
[docs]
class RouteRecommendation(str, Enum):
"""Routing recommendations for validated tools."""
EXECUTE = "execute" # Execute the tool
RETRY = "retry" # Retry with corrections
SKIP = "skip" # Skip this tool
REDIRECT = "redirect" # Redirect to different tool
AGENT = "agent" # Return to agent for clarification
END = "end" # End processing
[docs]
class ValidationRoutingState(BaseModel):
"""State for managing validation results and routing decisions."""
# Validation results
tool_validations: dict[str, ToolValidationResult] = Field(
default_factory=dict, description="Validation results keyed by tool_call_id"
)
# Routing decisions
valid_tool_calls: list[str] = Field(
default_factory=list, description="Tool call IDs that passed validation"
)
invalid_tool_calls: list[str] = Field(
default_factory=list, description="Tool call IDs that failed validation"
)
error_tool_calls: list[str] = Field(
default_factory=list, description="Tool call IDs that had validation errors"
)
# Routing state
next_action: RouteRecommendation = Field(
default=RouteRecommendation.EXECUTE,
description="Overall recommendation for next action",
)
target_nodes: set[str] = Field(
default_factory=set, description="Set of target nodes for routing"
)
# Message updates
tool_message_updates: dict[str, dict[str, Any]] = Field(
default_factory=dict, description="Updates to apply to tool messages"
)
# Branch decision data
branch_data: dict[str, Any] = Field(
default_factory=dict, description="Data for conditional branching decisions"
)
# Statistics
total_tools: int = Field(default=0, description="Total number of tool calls")
validation_duration: float = Field(default=0.0, description="Total validation time")
[docs]
def add_validation_result(self, result: ToolValidationResult) -> None:
"""Add a validation result and update routing state."""
self.tool_validations[result.tool_call_id] = result
# Update routing lists
if result.status == ValidationStatus.VALID:
self.valid_tool_calls.append(result.tool_call_id)
elif result.status == ValidationStatus.INVALID:
self.invalid_tool_calls.append(result.tool_call_id)
elif result.status == ValidationStatus.ERROR:
self.error_tool_calls.append(result.tool_call_id)
# Update target nodes
if result.target_node:
self.target_nodes.add(result.target_node)
# Update total count
self.total_tools = len(self.tool_validations)
# Update overall recommendation
self._update_next_action()
# Prepare message updates
self._prepare_message_updates(result)
def _update_next_action(self) -> None:
"""Update the overall next action based on validation results."""
if not self.tool_validations:
self.next_action = RouteRecommendation.END
return
# Check for errors first
if self.error_tool_calls:
self.next_action = RouteRecommendation.AGENT
return
# If we have valid tools, execute them
if self.valid_tool_calls:
self.next_action = RouteRecommendation.EXECUTE
return
# If all invalid, retry or return to agent
if len(self.invalid_tool_calls) == self.total_tools:
# Check if we have corrections
has_corrections = any(
result.corrected_args is not None
for result in self.tool_validations.values()
if result.status == ValidationStatus.INVALID
)
if has_corrections:
self.next_action = RouteRecommendation.RETRY
else:
self.next_action = RouteRecommendation.AGENT
return
# Default to execute
self.next_action = RouteRecommendation.EXECUTE
def _prepare_message_updates(self, result: ToolValidationResult) -> None:
"""Prepare message updates for a validation result."""
updates = {}
# Add validation status to message metadata
updates["validation_status"] = result.status.value
updates["validation_time"] = result.validation_time
# Add errors/warnings if present
if result.errors:
updates["validation_errors"] = result.errors
if result.warnings:
updates["validation_warnings"] = result.warnings
# Add routing information
updates["route_recommendation"] = result.route_recommendation.value
if result.target_node:
updates["target_node"] = result.target_node
if result.engine_name:
updates["engine_name"] = result.engine_name
# Add corrected args if available
if result.corrected_args:
updates["corrected_args"] = result.corrected_args
self.tool_message_updates[result.tool_call_id] = updates
[docs]
def get_routing_decision(self) -> dict[str, Any]:
"""Get routing decision data for conditional branching."""
return {
"next_action": self.next_action.value,
"target_nodes": list(self.target_nodes),
"valid_count": len(self.valid_tool_calls),
"invalid_count": len(self.invalid_tool_calls),
"error_count": len(self.error_tool_calls),
"total_count": self.total_tools,
"has_corrections": any(
"corrected_args" in updates
for updates in self.tool_message_updates.values()
),
"validation_duration": self.validation_duration,
"branch_data": self.branch_data,
}
[docs]
def should_continue_execution(self) -> bool:
"""Check if execution should continue based on validation results."""
return (
self.next_action in [RouteRecommendation.EXECUTE, RouteRecommendation.RETRY]
and len(self.valid_tool_calls) > 0
)
[docs]
def should_return_to_agent(self) -> bool:
"""Check if processing should return to agent."""
return self.next_action == RouteRecommendation.AGENT
[docs]
def should_end_processing(self) -> bool:
"""Check if processing should end."""
return self.next_action == RouteRecommendation.END
[docs]
def get_routing_summary(self) -> str:
"""Get a human-readable summary of routing decisions."""
summary_parts = [
f"Validated {self.total_tools} tool calls",
f"Valid: {len(self.valid_tool_calls)}",
f"Invalid: {len(self.invalid_tool_calls)}",
f"Errors: {len(self.error_tool_calls)}",
]
if self.target_nodes:
summary_parts.append(f"Target nodes: {', '.join(self.target_nodes)}")
summary_parts.append(f"Next action: {self.next_action.value}")
return " | ".join(summary_parts)
[docs]
class ValidationStateManager:
"""Manager for validation state operations."""
[docs]
@staticmethod
def create_validation_result(
tool_call_id: str,
tool_name: str,
status: ValidationStatus,
route_recommendation: RouteRecommendation = RouteRecommendation.EXECUTE,
errors: list[str] | None = None,
warnings: list[str] | None = None,
corrected_args: dict[str, Any] | None = None,
target_node: str | None = None,
engine_name: str | None = None,
priority: int = 0,
metadata: dict[str, Any] | None = None,
) -> ToolValidationResult:
"""Create a validation result with all parameters."""
return ToolValidationResult(
tool_call_id=tool_call_id,
tool_name=tool_name,
status=status,
route_recommendation=route_recommendation,
errors=errors or [],
warnings=warnings or [],
corrected_args=corrected_args,
target_node=target_node,
engine_name=engine_name,
priority=priority,
metadata=metadata or {},
)
[docs]
@staticmethod
def create_routing_state() -> ValidationRoutingState:
"""Create a new validation routing state."""
return ValidationRoutingState()
[docs]
@staticmethod
def merge_routing_states(
states: list[ValidationRoutingState],
) -> ValidationRoutingState:
"""Merge multiple routing states into one."""
merged = ValidationRoutingState()
for state in states:
# Merge validation results
merged.tool_validations.update(state.tool_validations)
# Merge routing lists
merged.valid_tool_calls.extend(state.valid_tool_calls)
merged.invalid_tool_calls.extend(state.invalid_tool_calls)
merged.error_tool_calls.extend(state.error_tool_calls)
# Merge target nodes
merged.target_nodes.update(state.target_nodes)
# Merge message updates
merged.tool_message_updates.update(state.tool_message_updates)
# Merge branch data
merged.branch_data.update(state.branch_data)
# Update duration
merged.validation_duration += state.validation_duration
# Update totals and next action
merged.total_tools = len(merged.tool_validations)
merged._update_next_action()
return merged