mirror of https://github.com/xtekky/gpt4free.git (synced 2025-12-05 18:20:35 -08:00)

commit edfc0e7c79 (parent c5670047b6)

    Enhance Copilot provider with cookie handling and user identity support;
    add Kimi provider; refactor usage tracking in run_tools.

7 changed files with 188 additions and 76 deletions

@@ -30,6 +30,7 @@ from ..image import to_bytes, is_accepted_format
 from .helper import get_last_user_message
 from ..files import get_bucket_dir
 from ..tools.files import read_bucket
+from ..cookies import get_cookies
 from pathlib import Path
 from .. import debug
 
@@ -54,31 +55,38 @@ def extract_bucket_items(messages: Messages) -> list[dict]:
 class Copilot(AsyncAuthedProvider, ProviderModelMixin):
     label = "Microsoft Copilot"
     url = "https://copilot.microsoft.com"
+    cookie_domain = ".microsoft.com"
+    anon_cookie_name = "__Host-copilot-anon"
 
-    working = True
     supports_stream = True
+    working = has_curl_cffi
     use_nodriver = has_nodriver
     active_by_default = True
 
     default_model = "Copilot"
-    models = [default_model, "Think Deeper", "Smart (GPT-5)"]
+    models = [default_model, "Think Deeper", "Smart (GPT-5)", "Study"]
     model_aliases = {
         "o1": "Think Deeper",
         "gpt-4": default_model,
         "gpt-4o": default_model,
         "gpt-5": "GPT-5",
+        "study": "Study",
     }
 
     websocket_url = "wss://copilot.microsoft.com/c/api/chat?api-version=2"
     conversation_url = f"{url}/c/api/conversations"
 
     _access_token: str = None
+    _useridentitytype: str = None
     _cookies: dict = {}
 
     @classmethod
-    async def on_auth_async(cls, **kwargs) -> AsyncIterator:
+    async def on_auth_async(cls, api_key: str = None, **kwargs) -> AsyncIterator:
+        cookies = cls.cookies_to_dict()
+        if api_key:
+            cookies[cls.anon_cookie_name] = api_key
         yield AuthResult(
-            api_key=cls._access_token,
-            cookies=cls.cookies_to_dict()
+            access_token=cls._access_token,
+            cookies=cls.cookies_to_dict() or get_cookies(cls.cookie_domain, False)
         )
 
     @classmethod
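Note: `on_auth_async` now accepts the anonymous Copilot cookie as `api_key` and reports the token as `access_token` on the `AuthResult`, falling back to browser cookies for `.microsoft.com` when none are cached. A minimal usage sketch; the call pattern and the placeholder cookie value are assumptions, not part of the diff:

```python
import asyncio
from g4f.Provider import Copilot  # assumption: public import path

async def main():
    # api_key carries the "__Host-copilot-anon" cookie value, not a bearer token.
    async for result in Copilot.on_auth_async(api_key="<anon-cookie-value>"):
        print(result)  # AuthResult(access_token=..., cookies={...})

asyncio.run(main())
```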
@@ -89,7 +97,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
         auth_result: AuthResult,
         **kwargs
     ) -> AsyncResult:
-        cls._access_token = getattr(auth_result, "api_key")
+        cls._access_token = getattr(auth_result, "access_token", None)
         cls._cookies = getattr(auth_result, "cookies")
         async for chunk in cls.create(model, messages, **kwargs):
             yield chunk
@@ -110,8 +118,6 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
         media: MediaListType = None,
         conversation: BaseConversation = None,
         return_conversation: bool = True,
-        useridentitytype: str = "google",
-        api_key: str = None,
         **kwargs
     ) -> AsyncResult:
         if not has_curl_cffi:
@@ -120,20 +126,19 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
         websocket_url = cls.websocket_url
         headers = None
         if cls._access_token or cls.needs_auth:
-            if api_key is not None:
-                cls._access_token = api_key
             if cls._access_token is None:
                 try:
-                    cls._access_token, cls._cookies = readHAR(cls.url)
+                    cls._access_token, cls._useridentitytype, cls._cookies = readHAR(cls.url)
                 except NoValidHarFileError as h:
                     debug.log(f"Copilot: {h}")
                     if has_nodriver:
                         yield RequestLogin(cls.label, os.environ.get("G4F_LOGIN_URL", ""))
-                        cls._access_token, cls._cookies = await get_access_token_and_cookies(cls.url, proxy)
+                        cls._access_token, cls._useridentitytype, cls._cookies = await get_access_token_and_cookies(cls.url, proxy)
                     else:
                         raise h
-            websocket_url = f"{websocket_url}&accessToken={quote(cls._access_token)}&X-UserIdentityType={quote(useridentitytype)}"
+            websocket_url = f"{websocket_url}&accessToken={quote(cls._access_token)}" + (f"&X-UserIdentityType={quote(cls._useridentitytype)}" if cls._useridentitytype else "")
             headers = {"authorization": f"Bearer {cls._access_token}"}
 
         async with AsyncSession(
             timeout=timeout,
@@ -142,31 +147,64 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
             headers=headers,
             cookies=cls._cookies,
         ) as session:
-            if cls._access_token is not None:
-                cls._cookies = session.cookies.jar if hasattr(session.cookies, "jar") else session.cookies
-                response = await session.get("https://copilot.microsoft.com/c/api/user?api-version=2", headers={"x-useridentitytype": useridentitytype})
-                if response.status_code == 401:
-                    raise MissingAuthError("Status 401: Invalid access token")
-                response.raise_for_status()
-                user = response.json().get('firstName')
-                if user is None:
-                    if cls.needs_auth:
-                        raise MissingAuthError("No user found, please login first")
-                    cls._access_token = None
-                else:
-                    debug.log(f"Copilot: User: {user}")
+            cls._cookies = session.cookies.jar if hasattr(session.cookies, "jar") else session.cookies
             if conversation is None:
-                response = await session.post(cls.conversation_url, headers={"x-useridentitytype": useridentitytype} if cls._access_token else {})
+                # har_file = os.path.join(os.path.dirname(__file__), "copilot", "copilot.microsoft.com.har")
+                # with open(har_file, "r") as f:
+                #     har_entries = json.load(f).get("log", {}).get("entries", [])
+                # conversationId = ""
+                # for har_entry in har_entries:
+                #     if har_entry.get("request"):
+                #         if "/c/api/" in har_entry.get("request").get("url", ""):
+                #             try:
+                #                 response = await getattr(session, har_entry.get("request").get("method").lower())(
+                #                     har_entry.get("request").get("url", "").replace("cvqBJw7kyPAp1RoMTmzC6", conversationId),
+                #                     data=har_entry.get("request").get("postData", {}).get("text"),
+                #                     headers={header["name"]: header["value"] for header in har_entry.get("request").get("headers")}
+                #                 )
+                #                 response.raise_for_status()
+                #                 if response.headers.get("content-type", "").startswith("application/json"):
+                #                     conversationId = response.json().get("currentConversationId", conversationId)
+                #             except Exception as e:
+                #                 debug.log(f"Copilot: Failed request to {har_entry.get('request').get('url', '')}: {e}")
+                data = {
+                    "timeZone": "America/Los_Angeles",
+                    "startNewConversation": True,
+                    "teenSupportEnabled": True,
+                    "correctPersonalizationSetting": True,
+                    "performUserMerge": True,
+                    "deferredDataUseCapable": True
+                }
+                response = await session.post(
+                    "https://copilot.microsoft.com/c/api/start",
+                    headers={
+                        "content-type": "application/json",
+                        **({"x-useridentitytype": cls._useridentitytype} if cls._useridentitytype else {}),
+                        **(headers or {})
+                    },
+                    json=data
+                )
                 response.raise_for_status()
-                conversation_id = response.json().get("id")
-                conversation = Conversation(conversation_id)
-                debug.log(f"Copilot: Created conversation: {conversation_id}")
+                conversation = Conversation(response.json().get("currentConversationId"))
+                debug.log(f"Copilot: Created conversation: {conversation.conversation_id}")
             else:
-                conversation_id = conversation.conversation_id
-                debug.log(f"Copilot: Use conversation: {conversation_id}")
+                debug.log(f"Copilot: Use conversation: {conversation.conversation_id}")
             if return_conversation:
                 yield conversation
 
+            # response = await session.get("https://copilot.microsoft.com/c/api/user?api-version=4", headers={"x-useridentitytype": useridentitytype} if cls._access_token else {})
+            # if response.status_code == 401:
+            #     raise MissingAuthError("Status 401: Invalid session")
+            # response.raise_for_status()
+            # print(response.json())
+            # user = response.json().get('firstName')
+            # if user is None:
+            #     if cls.needs_auth:
+            #         raise MissingAuthError("No user found, please login first")
+            #     cls._access_token = None
+            # else:
+            #     debug.log(f"Copilot: User: {user}")
 
             uploaded_attachments = []
             if cls._access_token is not None:
                 # Upload regular media (images)
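Note: conversation creation moves from `POST {url}/c/api/conversations` to `POST /c/api/start` with a JSON body, and the id is now read from `currentConversationId` instead of `id`. A standalone sketch of the new request, assuming `curl_cffi` and a valid bearer token (the provider itself goes through its shared `AsyncSession`):

```python
from curl_cffi.requests import AsyncSession

async def start_conversation(access_token: str, useridentitytype: str = None) -> str:
    # Payload fields copied from the diff; their server-side meaning is not documented here.
    payload = {
        "timeZone": "America/Los_Angeles",
        "startNewConversation": True,
        "teenSupportEnabled": True,
        "correctPersonalizationSetting": True,
        "performUserMerge": True,
        "deferredDataUseCapable": True,
    }
    headers = {
        "content-type": "application/json",
        "authorization": f"Bearer {access_token}",
        **({"x-useridentitytype": useridentitytype} if useridentitytype else {}),
    }
    async with AsyncSession() as session:
        response = await session.post("https://copilot.microsoft.com/c/api/start",
                                      json=payload, headers=headers)
        response.raise_for_status()
        # The conversation id now comes from "currentConversationId".
        return response.json().get("currentConversationId")
```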
@@ -178,7 +216,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
                     headers={
                         "content-type": is_accepted_format(data),
                         "content-length": str(len(data)),
-                        **({"x-useridentitytype": useridentitytype} if cls._access_token else {})
+                        **({"x-useridentitytype": cls._useridentitytype} if cls._useridentitytype else {})
                     },
                     data=data
                 )
@@ -201,7 +239,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
                 response = await session.post(
                     "https://copilot.microsoft.com/c/api/attachments",
                     multipart=data,
-                    headers={"x-useridentitytype": useridentitytype}
+                    headers={"x-useridentitytype": cls._useridentitytype} if cls._useridentitytype else {}
                 )
                 response.raise_for_status()
                 data = response.json()
@@ -225,7 +263,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
                 mode = "chat"
             await wss.send(json.dumps({
                 "event": "send",
-                "conversationId": conversation_id,
+                "conversationId": conversation.conversation_id,
                 "content": [*uploaded_attachments, {
                     "type": "text",
                     "text": prompt,
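Note: the websocket frame now references `conversation.conversation_id` directly instead of the dropped local `conversation_id`. A hedged sketch of the `send` event, limited to the keys visible in the diff; any further frame schema is an assumption:

```python
import json

async def send_prompt(wss, conversation_id: str, prompt: str,
                      attachments: list, mode: str = "chat") -> None:
    # Mirrors the frame built in create(); "mode" is set to "chat" just above
    # this hunk, other mode values are an assumption.
    await wss.send(json.dumps({
        "event": "send",
        "conversationId": conversation_id,
        "content": [*attachments, {"type": "text", "text": prompt}],
        "mode": mode,
    }))
```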
@@ -285,6 +323,7 @@ async def get_access_token_and_cookies(url: str, proxy: str = None):
     try:
         page = await browser.get(url)
         access_token = None
+        useridentitytype = None
         while access_token is None:
             for _ in range(2):
                 await asyncio.sleep(3)
@@ -292,9 +331,12 @@ async def get_access_token_and_cookies(url: str, proxy: str = None):
                 (() => {
                     for (var i = 0; i < localStorage.length; i++) {
                         try {
-                            item = JSON.parse(localStorage.getItem(localStorage.key(i)));
+                            const key = localStorage.key(i);
+                            const item = JSON.parse(localStorage.getItem(key));
                             if (item?.body?.access_token) {
-                                return item.body.access_token;
+                                return ["" + item?.body?.access_token, "google"];
+                            } else if (key.includes("chatai")) {
+                                return "" + item.secret;
                             }
                         } catch(e) {}
                     }
@@ -302,16 +344,24 @@ async def get_access_token_and_cookies(url: str, proxy: str = None):
                 """)
+                if access_token is None:
+                    await asyncio.sleep(1)
+                    continue
+                if isinstance(access_token, list):
+                    access_token, useridentitytype = access_token
                 access_token = access_token.get("value") if isinstance(access_token, dict) else access_token
+                useridentitytype = useridentitytype.get("value") if isinstance(useridentitytype, dict) else None
+                print(f"Got access token: {access_token[:10]}..., useridentitytype: {useridentitytype}")
                 break
         cookies = {}
         for c in await page.send(nodriver.cdp.network.get_cookies([url])):
             cookies[c.name] = c.value
         stop_browser()
-        return access_token, cookies
+        return access_token, useridentitytype, cookies
     finally:
         stop_browser()
 
 def readHAR(url: str):
     api_key = None
+    useridentitytype = None
     cookies = None
     for path in get_har_files():
         with open(path, 'rb') as file:
@@ -325,9 +375,11 @@ def readHAR(url: str):
                     v_headers = get_headers(v)
                     if "authorization" in v_headers:
                         api_key = v_headers["authorization"].split(maxsplit=1).pop()
+                    if "x-useridentitytype" in v_headers:
+                        useridentitytype = v_headers["x-useridentitytype"]
                     if v['request']['cookies']:
                         cookies = {c['name']: c['value'] for c in v['request']['cookies']}
     if api_key is None:
         raise NoValidHarFileError("No access token found in .har files")
 
-    return api_key, cookies
+    return api_key, useridentitytype, cookies
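Note: `readHAR` now returns a 3-tuple `(api_key, useridentitytype, cookies)`, and both call sites above unpack it accordingly. A hedged consumption sketch; the import path for `NoValidHarFileError` is an assumption (it is referenced, not defined, in this file):

```python
from g4f.Provider.Copilot import readHAR, NoValidHarFileError  # assumption: names reachable here

try:
    access_token, useridentitytype, cookies = readHAR("https://copilot.microsoft.com")
    print(f"token={access_token[:10]}..., identity={useridentitytype}, "
          f"{len(cookies or {})} cookies")
except NoValidHarFileError as e:
    print(f"No usable .har capture: {e}")
```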
@@ -39,6 +39,7 @@ except ImportError as e:
 from .deprecated.ARTA import ARTA
 from .deprecated.Blackbox import Blackbox
 from .deprecated.DuckDuckGo import DuckDuckGo
+from .deprecated.Kimi import Kimi
 from .deprecated.PerplexityLabs import PerplexityLabs
 
 from .ApiAirforce import ApiAirforce
@@ -48,7 +49,6 @@ from .Copilot import Copilot
 from .DeepInfra import DeepInfra
 from .EasyChat import EasyChat
 from .GLM import GLM
-from .Kimi import Kimi
 from .LambdaChat import LambdaChat
 from .Mintlify import Mintlify
 from .OIVSCodeSer import OIVSCodeSer2, OIVSCodeSer0501
@@ -3,16 +3,16 @@ from __future__ import annotations
 import random
 from typing import AsyncIterator
 
-from .base_provider import AsyncAuthedProvider, ProviderModelMixin
-from ..providers.helper import get_last_user_message
-from ..requests import StreamSession, sse_stream, raise_for_status
-from ..providers.response import AuthResult, TitleGeneration, JsonConversation, FinishReason
-from ..typing import AsyncResult, Messages
-from ..errors import MissingAuthError
+from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
+from ...providers.helper import get_last_user_message
+from ...requests import StreamSession, sse_stream, raise_for_status
+from ...providers.response import AuthResult, TitleGeneration, JsonConversation, FinishReason
+from ...typing import AsyncResult, Messages
+from ...errors import MissingAuthError
 
 class Kimi(AsyncAuthedProvider, ProviderModelMixin):
     url = "https://www.kimi.com"
-    working = True
+    working = False
     active_by_default = True
     default_model = "kimi-k2"
     models = [default_model]
@@ -29,7 +29,7 @@ class CopilotAccount(Copilot):
             debug.log(f"Copilot: {h}")
             if has_nodriver:
                 yield RequestLogin(cls.label, os.environ.get("G4F_LOGIN_URL", ""))
-                cls._access_token, cls._cookies = await get_access_token_and_cookies(cls.url, proxy)
+                cls._access_token, cls._useridentitytype, cls._cookies = await get_access_token_and_cookies(cls.url, proxy)
             else:
                 raise h
         yield AuthResult(
@@ -125,9 +125,9 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
                     yield ImageResponse([image["url"] for image in data["data"]], prompt)
                     return
 
-        if stream:
+        if stream or stream is None:
             kwargs.setdefault("stream_options", {"include_usage": True})
-        extra_parameters = filter_none(**{key: kwargs.get(key) for key in extra_parameters})
+        extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
         if extra_body is None:
             extra_body = {}
         data = filter_none(
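Note: two behavior changes here. First, `stream_options` (usage reporting in the final SSE chunk) is now also requested when `stream` is left unset. Second, `extra_parameters` only forwards keys the caller actually passed, so an explicit `None` survives instead of being dropped by `filter_none`. Illustration, with `filter_none` re-implemented locally for the comparison:

```python
def filter_none(**kwargs):
    # Local stand-in for g4f's helper: drops None values.
    return {k: v for k, v in kwargs.items() if v is not None}

extra_parameters = ["temperature", "top_p", "tools"]
kwargs = {"temperature": 0.7, "tools": None}  # caller explicitly disables tools

old = filter_none(**{key: kwargs.get(key) for key in extra_parameters})
new = {key: kwargs[key] for key in extra_parameters if key in kwargs}

print(old)  # {'temperature': 0.7}                -> explicit None was dropped
print(new)  # {'temperature': 0.7, 'tools': None} -> explicit None is forwarded
```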
@@ -55,7 +55,7 @@ class ConversationManager:
         try:
             with open(self.file_path, 'r', encoding='utf-8') as f:
                 data = json.load(f)
-                if self.model is None:
+                if self.provider is None and self.model is None:
                     self.model = data.get("model")
                 if self.provider is None:
                     self.provider = data.get("provider")
@@ -3,12 +3,19 @@ from __future__ import annotations
 import os
 import re
 import json
 import math
 import asyncio
+import time
 import datetime
 from pathlib import Path
 from typing import Optional, AsyncIterator, Iterator, Dict, Any, Tuple, List, Union
 
+try:
+    from aiofile import async_open
+    has_aiofile = True
+except ImportError:
+    has_aiofile = False
+
 from ..typing import Messages
 from ..providers.helper import filter_none
 from ..providers.asyncio import to_async_iterator
@@ -254,25 +261,45 @@ async def async_iter_run_tools(
     response = to_async_iterator(provider.async_create_function(model=model, messages=messages, **kwargs))
 
     try:
-        model_info = model
+        usage_model = model
+        usage_provider = provider.__name__
+        completion_tokens = 0
+        usage = None
         async for chunk in response:
-            if isinstance(chunk, ProviderInfo):
-                model_info = getattr(chunk, 'model', model_info)
+            if isinstance(chunk, FinishReason):
+                if sources is not None:
+                    yield sources
+                    sources = None
+                yield chunk
+                continue
+            elif isinstance(chunk, Sources):
+                sources = None
+            elif isinstance(chunk, str):
+                completion_tokens += 1
+            elif isinstance(chunk, ProviderInfo):
+                usage_model = getattr(chunk, "model", usage_model)
+                usage_provider = getattr(chunk, "name", usage_provider)
             elif isinstance(chunk, Usage):
-                usage = {"user": kwargs.get("user"), "model": model_info, "provider": provider.get_parent(), **chunk.get_dict()}
-                usage_dir = Path(get_cookies_dir()) / ".usage"
-                usage_file = usage_dir / f"{datetime.date.today()}.jsonl"
-                usage_dir.mkdir(parents=True, exist_ok=True)
-                with usage_file.open("a" if usage_file.exists() else "w") as f:
-                    f.write(f"{json.dumps(usage)}\n")
+                usage = chunk
             yield chunk
-        provider.live += 1
+        if has_aiofile:
+            if usage is None:
+                usage = get_usage(messages, completion_tokens)
+                yield usage
+            usage = {"user": kwargs.get("user"), "model": usage_model, "provider": usage_provider, **usage.get_dict()}
+            usage_dir = Path(get_cookies_dir()) / ".usage"
+            usage_file = usage_dir / f"{datetime.date.today()}.jsonl"
+            usage_dir.mkdir(parents=True, exist_ok=True)
+            async with async_open(usage_file, "a") as f:
+                await f.write(f"{json.dumps(usage)}\n")
+        if completion_tokens > 0:
+            provider.live += 1
     except:
         provider.live -= 1
         raise
 
     # Yield sources if available
-    if sources:
+    if sources is not None:
         yield sources
 
 def iter_run_tools(
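Note: on success, one JSON record per response is now appended to `<cookies_dir>/.usage/<YYYY-MM-DD>.jsonl` (written via `aiofile` in the async path, and only when that package is installed); if the provider never emitted a `Usage` chunk, tokens are estimated from the streamed text with `get_usage`, defined at the end of this diff. A hedged sketch of reading that log back:

```python
import json
import datetime
from pathlib import Path
from g4f.cookies import get_cookies_dir  # same directory helper the diff uses

usage_file = Path(get_cookies_dir()) / ".usage" / f"{datetime.date.today()}.jsonl"
totals = {}
if usage_file.exists():
    for line in usage_file.read_text().splitlines():
        # Record keys mirror the diff: user, model, provider, plus Usage fields.
        record = json.loads(line)
        key = f"{record.get('provider')}/{record.get('model')}"
        totals[key] = totals.get(key, 0) + record.get("total_tokens", 0)
print(totals)  # e.g. {"Copilot/Copilot": 1234}
```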
@@ -340,10 +367,13 @@ def iter_run_tools(
         messages[-1]["content"] = last_message + BUCKET_INSTRUCTIONS
 
     # Process response chunks
-    thinking_start_time = 0
-    processor = ThinkingProcessor()
-    model_info = model
     try:
+        thinking_start_time = 0
+        processor = ThinkingProcessor()
+        usage_model = model
+        usage_provider = provider.__name__
+        completion_tokens = 0
+        usage = None
         for chunk in provider.create_function(model=model, messages=messages, provider=provider, **kwargs):
             if isinstance(chunk, FinishReason):
                 if sources is not None:
@@ -353,28 +383,58 @@ def iter_run_tools(
                 continue
             elif isinstance(chunk, Sources):
                 sources = None
+            elif isinstance(chunk, str):
+                completion_tokens += 1
             elif isinstance(chunk, ProviderInfo):
-                model_info = getattr(chunk, 'model', model_info)
+                usage_model = getattr(chunk, "model", usage_model)
+                usage_provider = getattr(chunk, "name", usage_provider)
             elif isinstance(chunk, Usage):
-                usage = {"user": kwargs.get("user"), "model": model_info, "provider": provider.get_parent(), **chunk.get_dict()}
-                usage_dir = Path(get_cookies_dir()) / ".usage"
-                usage_file = usage_dir / f"{datetime.date.today()}.jsonl"
-                usage_dir.mkdir(parents=True, exist_ok=True)
-                with usage_file.open("a" if usage_file.exists() else "w") as f:
-                    f.write(f"{json.dumps(usage)}\n")
+                usage = chunk
             if not isinstance(chunk, str):
                 yield chunk
                 continue
 
             thinking_start_time, results = processor.process_thinking_chunk(chunk, thinking_start_time)
 
             for result in results:
                 yield result
 
-        provider.live += 1
+        if usage is None:
+            usage = get_usage(messages, completion_tokens)
+            yield usage
+        usage = {"user": kwargs.get("user"), "model": usage_model, "provider": usage_provider, **usage.get_dict()}
+        usage_dir = Path(get_cookies_dir()) / ".usage"
+        usage_file = usage_dir / f"{datetime.date.today()}.jsonl"
+        usage_dir.mkdir(parents=True, exist_ok=True)
+        with usage_file.open("a") as f:
+            f.write(f"{json.dumps(usage)}\n")
+        if completion_tokens > 0:
+            provider.live += 1
     except:
         provider.live -= 1
         raise
 
     if sources is not None:
         yield sources
 
+def caculate_prompt_tokens(messages: Messages) -> int:
+    """Calculate the total number of tokens in messages"""
+    token_count = 1 # Bos Token
+    for message in messages:
+        if isinstance(message.get("content"), str):
+            token_count += math.floor(len(message["content"].encode("utf-8")) / 4)
+            token_count += 4 # Role and start/end message token
+        elif isinstance(message.get("content"), list):
+            for item in message["content"]:
+                if isinstance(item, str):
+                    token_count += math.floor(len(item.encode("utf-8")) / 4)
+                elif isinstance(item, dict) and "text" in item and isinstance(item["text"], str):
+                    token_count += math.floor(len(item["text"].encode("utf-8")) / 4)
+            token_count += 4 # Role and start/end message token
+    return token_count
+
+def get_usage(messages: Messages, completion_tokens: int) -> Usage:
+    prompt_tokens = caculate_prompt_tokens(messages)
+    return Usage(
+        completion_tokens=completion_tokens,
+        prompt_tokens=prompt_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
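Note: `caculate_prompt_tokens` (identifier spelled as in the diff) estimates prompt size as utf-8 bytes divided by 4, plus 4 tokens per message for role and start/end markers, plus 1 BOS token. A worked example of the arithmetic:

```python
import math

messages = [
    {"role": "user", "content": "Hello, Copilot!"},           # 15 bytes -> 3 tokens
    {"role": "assistant", "content": [{"text": "Hi there"}]}, # 8 bytes  -> 2 tokens
]

token_count = 1  # BOS token
for message in messages:
    content = message["content"]
    if isinstance(content, str):
        token_count += math.floor(len(content.encode("utf-8")) / 4)
    else:
        for item in content:
            text = item if isinstance(item, str) else item.get("text", "")
            token_count += math.floor(len(text.encode("utf-8")) / 4)
    token_count += 4  # role + start/end markers

print(token_count)  # 1 + (3 + 4) + (2 + 4) = 14
```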