Source code for haive.core.schema.prebuilt.messages.token_usage
"""Token usage tracking schema for LLM interactions.

This module provides schemas for tracking token usage, costs, and capacity
across different LLM providers and models. It supports comprehensive metrics
including cached tokens, audio tokens, and reasoning tokens.
"""

from typing import Self

from langchain_core.messages import AIMessage, BaseMessage
from pydantic import BaseModel, Field, model_validator
class TokenUsage(BaseModel):
    """Comprehensive token usage tracking with cost calculation.

    This class tracks all aspects of token usage including:

    - Input/output/total tokens
    - Cached tokens (for providers that support caching)
    - Audio tokens (for multimodal models)
    - Reasoning tokens (for models with explicit reasoning steps)
    - Cost calculation based on provider pricing
    - Capacity percentage for context window management
    """

    input_tokens: int = Field(default=0, description="Number of input tokens")
    output_tokens: int = Field(default=0, description="Number of output tokens")
    total_tokens: int = Field(default=0, description="Total tokens (input + output)")
    input_tokens_cached: int | None = Field(
        default=None, description="Number of cached input tokens (if supported)"
    )
    audio_tokens: int | None = Field(
        default=None, description="Number of audio tokens (for multimodal models)"
    )
    reasoning_tokens: int | None = Field(
        default=None, description="Number of reasoning tokens (for reasoning models)"
    )
    input_token_cost: float = Field(default=0.0, description="Cost of input tokens")
    output_token_cost: float = Field(default=0.0, description="Cost of output tokens")
    total_cost: float = Field(default=0.0, description="Total cost")
    capacity_percentage: float = Field(
        default=0.0, description="Percentage of model's context window used"
    )

    @model_validator(mode="after")
    def validate_totals(self) -> Self:
        """Ensure total_tokens and total_cost are calculated if not set."""
        if self.total_tokens == 0:
            self.total_tokens = self.input_tokens + self.output_tokens
        if self.total_cost == 0.0:
            self.total_cost = self.input_token_cost + self.output_token_cost
        return self

    def add(self, other: "TokenUsage") -> "TokenUsage":
        """Add two TokenUsage instances together.

        Optional counters stay ``None`` only when both operands are ``None``;
        otherwise a missing side is treated as zero. ``capacity_percentage``
        keeps the larger of the two fills rather than summing.
        """

        def _opt_sum(a: int | None, b: int | None) -> int | None:
            # None + None -> None; otherwise treat a missing side as 0.
            if a is None and b is None:
                return None
            return (a or 0) + (b or 0)

        return TokenUsage(
            input_tokens=self.input_tokens + other.input_tokens,
            output_tokens=self.output_tokens + other.output_tokens,
            total_tokens=self.total_tokens + other.total_tokens,
            input_tokens_cached=_opt_sum(
                self.input_tokens_cached, other.input_tokens_cached
            ),
            audio_tokens=_opt_sum(self.audio_tokens, other.audio_tokens),
            reasoning_tokens=_opt_sum(self.reasoning_tokens, other.reasoning_tokens),
            input_token_cost=self.input_token_cost + other.input_token_cost,
            output_token_cost=self.output_token_cost + other.output_token_cost,
            total_cost=self.total_cost + other.total_cost,
            capacity_percentage=max(
                self.capacity_percentage, other.capacity_percentage
            ),
        )

    def __add__(self, other: "TokenUsage") -> "TokenUsage":
        """Support + operator for TokenUsage instances."""
        return self.add(other)
def extract_token_usage_from_message(
    message: BaseMessage, provider: str | None = None
) -> TokenUsage | None:
    """Extract token usage information from a message.

    Sources are tried in order of reliability: LangChain's normalized
    ``usage_metadata``, then provider ``response_metadata``, then raw
    ``additional_kwargs``.

    Args:
        message: The message to extract usage from
        provider: Optional provider name for provider-specific handling
            (currently unused; reserved for future provider-specific parsing)

    Returns:
        TokenUsage instance if usage info found, None otherwise
    """
    # Only AI responses carry usage accounting.
    if not isinstance(message, AIMessage):
        return None

    # Preferred: LangChain's normalized usage_metadata.
    if hasattr(message, "usage_metadata") and message.usage_metadata:
        metadata = message.usage_metadata
        return TokenUsage(
            input_tokens=metadata.get("input_tokens", 0),
            output_tokens=metadata.get("output_tokens", 0),
            total_tokens=metadata.get("total_tokens", 0),
            input_tokens_cached=metadata.get("input_tokens_cached"),
            audio_tokens=metadata.get("audio_tokens"),
            reasoning_tokens=metadata.get("reasoning_tokens"),
            input_token_cost=metadata.get("input_token_cost", 0.0),
            output_token_cost=metadata.get("output_token_cost", 0.0),
            total_cost=metadata.get("total_cost", 0.0),
        )

    # Fallback: provider response_metadata.
    if hasattr(message, "response_metadata") and message.response_metadata:
        metadata = message.response_metadata
        if "usage" in metadata:
            usage = metadata["usage"]
            # BUG FIX: the original had two consecutive `if "usage" in metadata:`
            # branches — one for OpenAI-style keys (prompt_tokens/completion_tokens)
            # and one for Anthropic-style keys (input_tokens/output_tokens) — so
            # the second was unreachable. Merged: accept either key convention,
            # and fall back to the computed sum when no total is reported.
            input_tokens = usage.get("prompt_tokens", usage.get("input_tokens", 0))
            output_tokens = usage.get(
                "completion_tokens", usage.get("output_tokens", 0)
            )
            return TokenUsage(
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                total_tokens=usage.get("total_tokens", input_tokens + output_tokens),
                input_tokens_cached=usage.get("prompt_tokens_cached"),
            )
        # Some providers put the counters directly on response_metadata.
        if "input_tokens" in metadata or "prompt_tokens" in metadata:
            return TokenUsage(
                input_tokens=metadata.get(
                    "input_tokens", metadata.get("prompt_tokens", 0)
                ),
                output_tokens=metadata.get(
                    "output_tokens", metadata.get("completion_tokens", 0)
                ),
                total_tokens=metadata.get("total_tokens", 0),
            )

    # Last resort: raw provider kwargs.
    if hasattr(message, "additional_kwargs") and message.additional_kwargs:
        kwargs = message.additional_kwargs
        if "usage" in kwargs:
            usage = kwargs["usage"]
            if isinstance(usage, dict):
                return TokenUsage(
                    input_tokens=usage.get(
                        "input_tokens", usage.get("prompt_tokens", 0)
                    ),
                    output_tokens=usage.get(
                        "output_tokens", usage.get("completion_tokens", 0)
                    ),
                    total_tokens=usage.get("total_tokens", 0),
                )

    return None
def aggregate_token_usage(messages: list[BaseMessage]) -> TokenUsage:
    """Aggregate token usage across multiple messages.

    Messages without extractable usage info are skipped.

    Args:
        messages: List of messages to aggregate usage from

    Returns:
        Combined TokenUsage instance
    """
    combined = TokenUsage()
    for msg in messages:
        per_message = extract_token_usage_from_message(msg)
        if per_message is not None:
            combined = combined + per_message
    return combined
def calculate_token_cost(
    usage: TokenUsage,
    input_cost_per_1k: float,
    output_cost_per_1k: float,
    cached_input_cost_per_1k: float | None = None,
) -> TokenUsage:
    """Calculate costs based on token usage and pricing.

    Args:
        usage: TokenUsage instance to calculate costs for
        input_cost_per_1k: Cost per 1000 input tokens
        output_cost_per_1k: Cost per 1000 output tokens
        cached_input_cost_per_1k: Optional cost per 1000 cached input tokens

    Returns:
        New TokenUsage instance with calculated costs
    """
    output_cost = usage.output_tokens / 1000 * output_cost_per_1k

    cached = usage.input_tokens_cached
    if cached and cached_input_cost_per_1k is not None:
        # Split input pricing: cached tokens at the discounted rate,
        # the remainder at the full rate.
        input_cost = (
            cached / 1000 * cached_input_cost_per_1k
            + (usage.input_tokens - cached) / 1000 * input_cost_per_1k
        )
    else:
        input_cost = usage.input_tokens / 1000 * input_cost_per_1k

    # Return a copy so the input instance stays untouched.
    return usage.model_copy(
        update={
            "input_token_cost": input_cost,
            "output_token_cost": output_cost,
            "total_cost": input_cost + output_cost,
        }
    )