Add cf_ipcountry header to Api class and enhance output_tokens handling in Usage class

This commit is contained in:
hlohaus 2025-10-04 22:17:07 +02:00
parent 8366a2956a
commit d309e0df1d
3 changed files with 9 additions and 4 deletions

View file

@ -423,7 +423,8 @@ class Api:
credentials: Annotated[HTTPAuthorizationCredentials, Depends(Api.security)] = None,
provider: str = None,
conversation_id: str = None,
x_user: Annotated[str | None, Header()] = None
x_user: Annotated[str | None, Header()] = None,
cf_ipcountry: Annotated[str | None, Header()] = None
):
if provider is not None and provider not in Provider.__map__:
if provider in model_map:
@ -477,7 +478,7 @@ class Api:
**{
"conversation_id": None,
"conversation": conversation,
"user": x_user,
"user": f"{cf_ipcountry}:{x_user}" if cf_ipcountry else x_user,
}
},
ignored=AppConfig.ignored_providers

View file

@ -203,6 +203,7 @@ class Usage(JsonMixin, HiddenResponse):
completionTokens: int = None,
input_tokens: int = None,
output_tokens: int = None,
output_tokens_details: Dict = None,
**kwargs
):
if promptTokens is not None:
@ -213,6 +214,9 @@ class Usage(JsonMixin, HiddenResponse):
kwargs["prompt_tokens"] = input_tokens
if output_tokens is not None:
kwargs["completion_tokens"] = output_tokens
if output_tokens_details is not None:
for key, value in output_tokens_details.items():
kwargs[key] = value
if "total_tokens" not in kwargs and "prompt_tokens" in kwargs and "completion_tokens" in kwargs:
kwargs["total_tokens"] = kwargs["prompt_tokens"] + kwargs["completion_tokens"]
return super().__init__(**kwargs)

View file

@ -251,7 +251,7 @@ async def async_iter_run_tools(
elif isinstance(chunk, Sources):
sources = None
elif isinstance(chunk, str):
completion_tokens += 1
completion_tokens += round(len(chunk.encode("utf-8"))/4)
elif isinstance(chunk, ProviderInfo):
usage_model = getattr(chunk, "model", usage_model)
usage_provider = getattr(chunk, "name", usage_provider)
@ -360,7 +360,7 @@ def iter_run_tools(
elif isinstance(chunk, Sources):
sources = None
elif isinstance(chunk, str):
completion_tokens += 1
completion_tokens += round(len(chunk.encode("utf-8"))/4)
elif isinstance(chunk, ProviderInfo):
usage_model = getattr(chunk, "model", usage_model)
usage_provider = getattr(chunk, "name", usage_provider)