mirror of https://github.com/xtekky/gpt4free.git (synced 2025-12-15 14:51:19 -08:00)

Commit edfc0e7c79 (parent c5670047b6): Enhance Copilot provider with cookie handling and user identity support; add Kimi provider; refactor usage tracking in run_tools.

7 changed files with 188 additions and 76 deletions

@@ -30,6 +30,7 @@ from ..image import to_bytes, is_accepted_format
 from .helper import get_last_user_message
 from ..files import get_bucket_dir
 from ..tools.files import read_bucket
+from ..cookies import get_cookies
 from pathlib import Path
 from .. import debug
 
@@ -54,31 +55,38 @@ def extract_bucket_items(messages: Messages) -> list[dict]:
 class Copilot(AsyncAuthedProvider, ProviderModelMixin):
     label = "Microsoft Copilot"
     url = "https://copilot.microsoft.com"
+    cookie_domain = ".microsoft.com"
+    anon_cookie_name = "__Host-copilot-anon"
 
-    working = True
-    supports_stream = True
+    working = has_curl_cffi
+    use_nodriver = has_nodriver
     active_by_default = True
 
     default_model = "Copilot"
-    models = [default_model, "Think Deeper", "Smart (GPT-5)"]
+    models = [default_model, "Think Deeper", "Smart (GPT-5)", "Study"]
     model_aliases = {
         "o1": "Think Deeper",
         "gpt-4": default_model,
         "gpt-4o": default_model,
         "gpt-5": "GPT-5",
+        "study": "Study",
     }
 
     websocket_url = "wss://copilot.microsoft.com/c/api/chat?api-version=2"
     conversation_url = f"{url}/c/api/conversations"
 
     _access_token: str = None
+    _useridentitytype: str = None
     _cookies: dict = {}
 
     @classmethod
-    async def on_auth_async(cls, **kwargs) -> AsyncIterator:
+    async def on_auth_async(cls, api_key: str = None, **kwargs) -> AsyncIterator:
+        cookies = cls.cookies_to_dict()
+        if api_key:
+            cookies[cls.anon_cookie_name] = api_key
         yield AuthResult(
-            api_key=cls._access_token,
-            cookies=cls.cookies_to_dict()
+            access_token=cls._access_token,
+            cookies=cls.cookies_to_dict() or get_cookies(cls.cookie_domain, False)
         )
 
     @classmethod
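
The reworked `on_auth_async` accepts an optional `api_key` and seeds it as the `__Host-copilot-anon` cookie, and the yielded `AuthResult` now carries `access_token` instead of `api_key`. A minimal standalone sketch of that cookie-seeding logic; the `AuthResult` dataclass below is a stand-in for illustration, not g4f's own class:

```python
from dataclasses import dataclass, field

ANON_COOKIE_NAME = "__Host-copilot-anon"  # matches anon_cookie_name in the hunk

@dataclass
class AuthResult:
    access_token: str | None = None
    cookies: dict = field(default_factory=dict)

def build_auth_result(saved_cookies: dict, api_key: str | None = None,
                      access_token: str | None = None) -> AuthResult:
    """Seed the anonymous Copilot cookie from an explicit api_key, if given."""
    cookies = dict(saved_cookies)
    if api_key:
        cookies[ANON_COOKIE_NAME] = api_key
    return AuthResult(access_token=access_token, cookies=cookies)

print(build_auth_result({}, api_key="abc123"))
# AuthResult(access_token=None, cookies={'__Host-copilot-anon': 'abc123'})
```

Note that in the hunk above the locally built `cookies` dict is not the one passed to `AuthResult`; the yield re-reads `cls.cookies_to_dict()` and falls back to `get_cookies(cls.cookie_domain, False)`.
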
@@ -89,7 +97,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
         auth_result: AuthResult,
         **kwargs
     ) -> AsyncResult:
-        cls._access_token = getattr(auth_result, "api_key")
+        cls._access_token = getattr(auth_result, "access_token", None)
         cls._cookies = getattr(auth_result, "cookies")
         async for chunk in cls.create(model, messages, **kwargs):
             yield chunk
@@ -110,8 +118,6 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
         media: MediaListType = None,
         conversation: BaseConversation = None,
         return_conversation: bool = True,
-        useridentitytype: str = "google",
-        api_key: str = None,
         **kwargs
     ) -> AsyncResult:
         if not has_curl_cffi:
@@ -120,21 +126,20 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
         websocket_url = cls.websocket_url
         headers = None
         if cls._access_token or cls.needs_auth:
-            if api_key is not None:
-                cls._access_token = api_key
             if cls._access_token is None:
                 try:
-                    cls._access_token, cls._cookies = readHAR(cls.url)
+                    cls._access_token, cls._useridentitytype, cls._cookies = readHAR(cls.url)
                 except NoValidHarFileError as h:
                     debug.log(f"Copilot: {h}")
                     if has_nodriver:
                         yield RequestLogin(cls.label, os.environ.get("G4F_LOGIN_URL", ""))
-                        cls._access_token, cls._cookies = await get_access_token_and_cookies(cls.url, proxy)
+                        cls._access_token, cls._useridentitytype, cls._cookies = await get_access_token_and_cookies(cls.url, proxy)
                     else:
                         raise h
-            websocket_url = f"{websocket_url}&accessToken={quote(cls._access_token)}&X-UserIdentityType={quote(useridentitytype)}"
+            websocket_url = f"{websocket_url}&accessToken={quote(cls._access_token)}" + (f"&X-UserIdentityType={quote(cls._useridentitytype)}" if cls._useridentitytype else "")
             headers = {"authorization": f"Bearer {cls._access_token}"}
 
+
         async with AsyncSession(
             timeout=timeout,
             proxy=proxy,
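
With `useridentitytype` promoted from a keyword argument to class state, the websocket URL only gains an `X-UserIdentityType` query parameter when an identity type is actually known, instead of always assuming "google". A small sketch of the URL construction, assuming the same parameter names as the hunk above:

```python
from urllib.parse import quote

WEBSOCKET_URL = "wss://copilot.microsoft.com/c/api/chat?api-version=2"

def build_websocket_url(access_token: str, useridentitytype: str | None = None) -> str:
    url = f"{WEBSOCKET_URL}&accessToken={quote(access_token)}"
    if useridentitytype:  # appended only when the identity type is known
        url += f"&X-UserIdentityType={quote(useridentitytype)}"
    return url

print(build_websocket_url("tok en", "google"))
# wss://copilot.microsoft.com/c/api/chat?api-version=2&accessToken=tok%20en&X-UserIdentityType=google
```
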
@@ -142,31 +147,64 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
             headers=headers,
             cookies=cls._cookies,
         ) as session:
-            if cls._access_token is not None:
             cls._cookies = session.cookies.jar if hasattr(session.cookies, "jar") else session.cookies
-                response = await session.get("https://copilot.microsoft.com/c/api/user?api-version=2", headers={"x-useridentitytype": useridentitytype})
-                if response.status_code == 401:
-                    raise MissingAuthError("Status 401: Invalid access token")
-                response.raise_for_status()
-                user = response.json().get('firstName')
-                if user is None:
-                    if cls.needs_auth:
-                        raise MissingAuthError("No user found, please login first")
-                    cls._access_token = None
-                else:
-                    debug.log(f"Copilot: User: {user}")
             if conversation is None:
-                response = await session.post(cls.conversation_url, headers={"x-useridentitytype": useridentitytype} if cls._access_token else {})
+                # har_file = os.path.join(os.path.dirname(__file__), "copilot", "copilot.microsoft.com.har")
+                # with open(har_file, "r") as f:
+                #     har_entries = json.load(f).get("log", {}).get("entries", [])
+                # conversationId = ""
+                # for har_entry in har_entries:
+                #     if har_entry.get("request"):
+                #         if "/c/api/" in har_entry.get("request").get("url", ""):
+                #             try:
+                #                 response = await getattr(session, har_entry.get("request").get("method").lower())(
+                #                     har_entry.get("request").get("url", "").replace("cvqBJw7kyPAp1RoMTmzC6", conversationId),
+                #                     data=har_entry.get("request").get("postData", {}).get("text"),
+                #                     headers={header["name"]: header["value"] for header in har_entry.get("request").get("headers")}
+                #                 )
+                #                 response.raise_for_status()
+                #                 if response.headers.get("content-type", "").startswith("application/json"):
+                #                     conversationId = response.json().get("currentConversationId", conversationId)
+                #             except Exception as e:
+                #                 debug.log(f"Copilot: Failed request to {har_entry.get('request').get('url', '')}: {e}")
+                data = {
+                    "timeZone": "America/Los_Angeles",
+                    "startNewConversation": True,
+                    "teenSupportEnabled": True,
+                    "correctPersonalizationSetting": True,
+                    "performUserMerge": True,
+                    "deferredDataUseCapable": True
+                }
+                response = await session.post(
+                    "https://copilot.microsoft.com/c/api/start",
+                    headers={
+                        "content-type": "application/json",
+                        **({"x-useridentitytype": cls._useridentitytype} if cls._useridentitytype else {}),
+                        **(headers or {})
+                    },
+                    json=data
+                )
                 response.raise_for_status()
-                conversation_id = response.json().get("id")
-                conversation = Conversation(conversation_id)
-                debug.log(f"Copilot: Created conversation: {conversation_id}")
+                conversation = Conversation(response.json().get("currentConversationId"))
+                debug.log(f"Copilot: Created conversation: {conversation.conversation_id}")
             else:
-                conversation_id = conversation.conversation_id
-                debug.log(f"Copilot: Use conversation: {conversation_id}")
+                debug.log(f"Copilot: Use conversation: {conversation.conversation_id}")
             if return_conversation:
                 yield conversation
 
+            # response = await session.get("https://copilot.microsoft.com/c/api/user?api-version=4", headers={"x-useridentitytype": useridentitytype} if cls._access_token else {})
+            # if response.status_code == 401:
+            #     raise MissingAuthError("Status 401: Invalid session")
+            # response.raise_for_status()
+            # print(response.json())
+            # user = response.json().get('firstName')
+            # if user is None:
+            #     if cls.needs_auth:
+            #         raise MissingAuthError("No user found, please login first")
+            #     cls._access_token = None
+            # else:
+            #     debug.log(f"Copilot: User: {user}")
+
             uploaded_attachments = []
             if cls._access_token is not None:
                 # Upload regular media (images)
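
Conversation bootstrap moves from POST `{url}/c/api/conversations` (reading `id`) to POST `/c/api/start` with a settings payload (reading `currentConversationId`); the up-front user check is commented out. A sketch of the payload and response handling with the HTTP call itself elided, since it needs live credentials:

```python
# Payload and response field taken from the hunk above; values are the ones
# the diff hard-codes, not a documented API contract.
START_URL = "https://copilot.microsoft.com/c/api/start"

START_PAYLOAD = {
    "timeZone": "America/Los_Angeles",
    "startNewConversation": True,
    "teenSupportEnabled": True,
    "correctPersonalizationSetting": True,
    "performUserMerge": True,
    "deferredDataUseCapable": True,
}

def conversation_id_from_start(response_json: dict) -> str | None:
    # The old endpoint returned the id under "id"; /c/api/start reports it
    # as "currentConversationId".
    return response_json.get("currentConversationId")

print(conversation_id_from_start({"currentConversationId": "abc-123"}))  # abc-123
```
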
@@ -178,7 +216,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
                     headers={
                         "content-type": is_accepted_format(data),
                         "content-length": str(len(data)),
-                        **({"x-useridentitytype": useridentitytype} if cls._access_token else {})
+                        **({"x-useridentitytype": cls._useridentitytype} if cls._useridentitytype else {})
                     },
                     data=data
                 )
@@ -201,7 +239,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
                 response = await session.post(
                     "https://copilot.microsoft.com/c/api/attachments",
                     multipart=data,
-                    headers={"x-useridentitytype": useridentitytype}
+                    headers={"x-useridentitytype": cls._useridentitytype} if cls._useridentitytype else {}
                 )
                 response.raise_for_status()
                 data = response.json()
@@ -225,7 +263,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
             mode = "chat"
             await wss.send(json.dumps({
                 "event": "send",
-                "conversationId": conversation_id,
+                "conversationId": conversation.conversation_id,
                 "content": [*uploaded_attachments, {
                     "type": "text",
                     "text": prompt,
@@ -285,6 +323,7 @@ async def get_access_token_and_cookies(url: str, proxy: str = None):
     try:
         page = await browser.get(url)
         access_token = None
+        useridentitytype = None
         while access_token is None:
             for _ in range(2):
                 await asyncio.sleep(3)
@@ -292,9 +331,12 @@ async def get_access_token_and_cookies(url: str, proxy: str = None):
                 (() => {
                     for (var i = 0; i < localStorage.length; i++) {
                         try {
-                            item = JSON.parse(localStorage.getItem(localStorage.key(i)));
+                            const key = localStorage.key(i);
+                            const item = JSON.parse(localStorage.getItem(key));
                             if (item?.body?.access_token) {
-                                return item.body.access_token;
+                                return ["" + item?.body?.access_token, "google"];
+                            } else if (key.includes("chatai")) {
+                                return "" + item.secret;
                             }
                         } catch(e) {}
                     }
@@ -302,16 +344,24 @@ async def get_access_token_and_cookies(url: str, proxy: str = None):
                 """)
                 if access_token is None:
                     await asyncio.sleep(1)
+                    continue
+                if isinstance(access_token, list):
+                    access_token, useridentitytype = access_token
+                access_token = access_token.get("value") if isinstance(access_token, dict) else access_token
+                useridentitytype = useridentitytype.get("value") if isinstance(useridentitytype, dict) else None
+                print(f"Got access token: {access_token[:10]}..., useridentitytype: {useridentitytype}")
+                break
         cookies = {}
         for c in await page.send(nodriver.cdp.network.get_cookies([url])):
             cookies[c.name] = c.value
         stop_browser()
-        return access_token, cookies
+        return access_token, useridentitytype, cookies
     finally:
         stop_browser()
 
 def readHAR(url: str):
     api_key = None
+    useridentitytype = None
     cookies = None
     for path in get_har_files():
         with open(path, 'rb') as file:
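
The `page.evaluate()` result can now be a bare secret string, a `[token, identity_type]` pair, or (depending on how the driver wraps return values) dict entries of the form `{"value": ...}`; the new unpacking normalizes all of them. A standalone mirror of that normalization:

```python
# Mirrors the token normalization in the hunk above, outside the browser loop.
def normalize_token(result):
    useridentitytype = None
    if isinstance(result, list):
        result, useridentitytype = result
    access_token = result.get("value") if isinstance(result, dict) else result
    useridentitytype = useridentitytype.get("value") if isinstance(useridentitytype, dict) else None
    return access_token, useridentitytype

print(normalize_token("secret"))           # ('secret', None)
print(normalize_token(["tok", "google"]))  # ('tok', None) -- plain strings are dropped
print(normalize_token([{"value": "tok"}, {"value": "google"}]))  # ('tok', 'google')
```

Note the trailing `else None`: a plain-string identity type is discarded, so only dict-wrapped values survive. That matches the diff as written, even though the JS side returns `"google"` as a bare string.
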
@@ -325,9 +375,11 @@ def readHAR(url: str):
                     v_headers = get_headers(v)
                     if "authorization" in v_headers:
                         api_key = v_headers["authorization"].split(maxsplit=1).pop()
+                    if "x-useridentitytype" in v_headers:
+                        useridentitytype = v_headers["x-useridentitytype"]
                     if v['request']['cookies']:
                         cookies = {c['name']: c['value'] for c in v['request']['cookies']}
     if api_key is None:
         raise NoValidHarFileError("No access token found in .har files")
 
-    return api_key, useridentitytype, cookies
+    return api_key, useridentitytype, cookies
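
`readHAR` now also captures the `x-useridentitytype` request header and returns a three-tuple, matching the widened signature of `get_access_token_and_cookies` so both auth paths feed the same class state. A simplified sketch of the scan over a pre-parsed HAR "entries" list, with header lookup reduced to lowercase names:

```python
def scan_har_entries(entries: list[dict]):
    api_key = useridentitytype = cookies = None
    for entry in entries:
        headers = {h["name"].lower(): h["value"] for h in entry["request"].get("headers", [])}
        if "authorization" in headers:
            api_key = headers["authorization"].split(maxsplit=1).pop()  # strip "Bearer "
        if "x-useridentitytype" in headers:
            useridentitytype = headers["x-useridentitytype"]
        if entry["request"].get("cookies"):
            cookies = {c["name"]: c["value"] for c in entry["request"]["cookies"]}
    return api_key, useridentitytype, cookies

entries = [{"request": {
    "headers": [{"name": "Authorization", "value": "Bearer tok"},
                {"name": "X-UserIdentityType", "value": "google"}],
    "cookies": [{"name": "session", "value": "s1"}],
}}]
print(scan_har_entries(entries))  # ('tok', 'google', {'session': 's1'})
```
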

@@ -39,6 +39,7 @@ except ImportError as e:
 from .deprecated.ARTA import ARTA
 from .deprecated.Blackbox import Blackbox
 from .deprecated.DuckDuckGo import DuckDuckGo
+from .deprecated.Kimi import Kimi
 from .deprecated.PerplexityLabs import PerplexityLabs
 
 from .ApiAirforce import ApiAirforce
@@ -48,7 +49,6 @@ from .Copilot import Copilot
 from .DeepInfra import DeepInfra
 from .EasyChat import EasyChat
 from .GLM import GLM
-from .Kimi import Kimi
 from .LambdaChat import LambdaChat
 from .Mintlify import Mintlify
 from .OIVSCodeSer import OIVSCodeSer2, OIVSCodeSer0501

@@ -3,16 +3,16 @@ from __future__ import annotations
 import random
 from typing import AsyncIterator
 
-from .base_provider import AsyncAuthedProvider, ProviderModelMixin
-from ..providers.helper import get_last_user_message
-from ..requests import StreamSession, sse_stream, raise_for_status
-from ..providers.response import AuthResult, TitleGeneration, JsonConversation, FinishReason
-from ..typing import AsyncResult, Messages
-from ..errors import MissingAuthError
+from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
+from ...providers.helper import get_last_user_message
+from ...requests import StreamSession, sse_stream, raise_for_status
+from ...providers.response import AuthResult, TitleGeneration, JsonConversation, FinishReason
+from ...typing import AsyncResult, Messages
+from ...errors import MissingAuthError
 
 class Kimi(AsyncAuthedProvider, ProviderModelMixin):
     url = "https://www.kimi.com"
-    working = True
+    working = False
     active_by_default = True
     default_model = "kimi-k2"
     models = [default_model]

@@ -29,7 +29,7 @@ class CopilotAccount(Copilot):
             debug.log(f"Copilot: {h}")
             if has_nodriver:
                 yield RequestLogin(cls.label, os.environ.get("G4F_LOGIN_URL", ""))
-                cls._access_token, cls._cookies = await get_access_token_and_cookies(cls.url, proxy)
+                cls._access_token, cls._useridentitytype, cls._cookies = await get_access_token_and_cookies(cls.url, proxy)
             else:
                 raise h
         yield AuthResult(

@@ -125,9 +125,9 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
                 yield ImageResponse([image["url"] for image in data["data"]], prompt)
                 return
 
-        if stream:
+        if stream or stream is None:
             kwargs.setdefault("stream_options", {"include_usage": True})
-        extra_parameters = filter_none(**{key: kwargs.get(key) for key in extra_parameters})
+        extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
         if extra_body is None:
             extra_body = {}
         data = filter_none(
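
The `extra_parameters` rewrite changes semantics, not just style: the old `filter_none` form dropped any `None`, including one the caller passed explicitly, while the new comprehension forwards every key actually present in `kwargs`. A quick comparison, with `filter_none` re-implemented locally for the demo:

```python
def filter_none(**kwargs):
    # Local stand-in for g4f's helper: drop None values.
    return {k: v for k, v in kwargs.items() if v is not None}

extra_parameters = ["tools", "temperature", "response_format"]
kwargs = {"temperature": 0.0, "response_format": None}

old = filter_none(**{key: kwargs.get(key) for key in extra_parameters})
new = {key: kwargs[key] for key in extra_parameters if key in kwargs}

print(old)  # {'temperature': 0.0} -- an explicit None is silently dropped
print(new)  # {'temperature': 0.0, 'response_format': None} -- explicit None survives
```
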

@@ -55,7 +55,7 @@ class ConversationManager:
         try:
             with open(self.file_path, 'r', encoding='utf-8') as f:
                 data = json.load(f)
-                if self.model is None:
+                if self.provider is None and self.model is None:
                     self.model = data.get("model")
                 if self.provider is None:
                     self.provider = data.get("provider")

@@ -3,12 +3,19 @@ from __future__ import annotations
 import os
 import re
 import json
+import math
 import asyncio
 import time
 import datetime
 from pathlib import Path
 from typing import Optional, AsyncIterator, Iterator, Dict, Any, Tuple, List, Union
 
+try:
+    from aiofile import async_open
+    has_aiofile = True
+except ImportError:
+    has_aiofile = False
+
 from ..typing import Messages
 from ..providers.helper import filter_none
 from ..providers.asyncio import to_async_iterator
@@ -254,25 +261,45 @@ async def async_iter_run_tools(
         response = to_async_iterator(provider.async_create_function(model=model, messages=messages, **kwargs))
 
     try:
-        model_info = model
+        usage_model = model
+        usage_provider = provider.__name__
+        completion_tokens = 0
+        usage = None
         async for chunk in response:
-            if isinstance(chunk, ProviderInfo):
-                model_info = getattr(chunk, 'model', model_info)
+            if isinstance(chunk, FinishReason):
+                if sources is not None:
+                    yield sources
+                    sources = None
+                yield chunk
+                continue
+            elif isinstance(chunk, Sources):
+                sources = None
+            elif isinstance(chunk, str):
+                completion_tokens += 1
+            elif isinstance(chunk, ProviderInfo):
+                usage_model = getattr(chunk, "model", usage_model)
+                usage_provider = getattr(chunk, "name", usage_provider)
             elif isinstance(chunk, Usage):
-                usage = {"user": kwargs.get("user"), "model": model_info, "provider": provider.get_parent(), **chunk.get_dict()}
+                usage = chunk
+            yield chunk
+        if has_aiofile:
+            if usage is None:
+                usage = get_usage(messages, completion_tokens)
+                yield usage
+            usage = {"user": kwargs.get("user"), "model": usage_model, "provider": usage_provider, **usage.get_dict()}
             usage_dir = Path(get_cookies_dir()) / ".usage"
             usage_file = usage_dir / f"{datetime.date.today()}.jsonl"
             usage_dir.mkdir(parents=True, exist_ok=True)
-            with usage_file.open("a" if usage_file.exists() else "w") as f:
-                f.write(f"{json.dumps(usage)}\n")
-            yield chunk
+            async with async_open(usage_file, "a") as f:
+                await f.write(f"{json.dumps(usage)}\n")
+        if completion_tokens > 0:
             provider.live += 1
     except:
         provider.live -= 1
         raise
 
     # Yield sources if available
-    if sources:
+    if sources is not None:
         yield sources
 
 def iter_run_tools(
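
Usage persistence is reworked here: string chunks are counted as a rough completion-token proxy, `get_usage` fills in an estimate when the provider never reports a `Usage` chunk, and one JSON record per request is appended to `<cookies_dir>/.usage/<date>.jsonl`, via `aiofile` in the async path. A synchronous sketch of the sink with an illustrative record (field values are made up for the demo):

```python
import json
import datetime
import tempfile
from pathlib import Path

def append_usage_record(base_dir: str, record: dict) -> Path:
    usage_dir = Path(base_dir) / ".usage"
    usage_file = usage_dir / f"{datetime.date.today()}.jsonl"
    usage_dir.mkdir(parents=True, exist_ok=True)
    with usage_file.open("a") as f:  # the async path uses aiofile.async_open instead
        f.write(f"{json.dumps(record)}\n")
    return usage_file

print(append_usage_record(tempfile.gettempdir(), {
    "user": None,
    "model": "Copilot",
    "provider": "Copilot",
    "prompt_tokens": 12,
    "completion_tokens": 34,
    "total_tokens": 46,
}))
```

One line per request keeps the file append-only, so a day's spend can be tallied by summing `total_tokens` over the JSONL.
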
@@ -340,10 +367,13 @@ def iter_run_tools(
         messages[-1]["content"] = last_message + BUCKET_INSTRUCTIONS
 
     # Process response chunks
+    try:
         thinking_start_time = 0
         processor = ThinkingProcessor()
-    model_info = model
-    try:
+        usage_model = model
+        usage_provider = provider.__name__
+        completion_tokens = 0
+        usage = None
         for chunk in provider.create_function(model=model, messages=messages, provider=provider, **kwargs):
             if isinstance(chunk, FinishReason):
                 if sources is not None:
@@ -353,24 +383,30 @@ def iter_run_tools(
                 continue
             elif isinstance(chunk, Sources):
                 sources = None
+            elif isinstance(chunk, str):
+                completion_tokens += 1
             elif isinstance(chunk, ProviderInfo):
-                model_info = getattr(chunk, 'model', model_info)
+                usage_model = getattr(chunk, "model", usage_model)
+                usage_provider = getattr(chunk, "name", usage_provider)
             elif isinstance(chunk, Usage):
-                usage = {"user": kwargs.get("user"), "model": model_info, "provider": provider.get_parent(), **chunk.get_dict()}
-                usage_dir = Path(get_cookies_dir()) / ".usage"
-                usage_file = usage_dir / f"{datetime.date.today()}.jsonl"
-                usage_dir.mkdir(parents=True, exist_ok=True)
-                with usage_file.open("a" if usage_file.exists() else "w") as f:
-                    f.write(f"{json.dumps(usage)}\n")
+                usage = chunk
             if not isinstance(chunk, str):
                 yield chunk
                 continue
 
             thinking_start_time, results = processor.process_thinking_chunk(chunk, thinking_start_time)
 
             for result in results:
                 yield result
+        if usage is None:
+            usage = get_usage(messages, completion_tokens)
+            yield usage
+        usage = {"user": kwargs.get("user"), "model": usage_model, "provider": usage_provider, **usage.get_dict()}
+        usage_dir = Path(get_cookies_dir()) / ".usage"
+        usage_file = usage_dir / f"{datetime.date.today()}.jsonl"
+        usage_dir.mkdir(parents=True, exist_ok=True)
+        with usage_file.open("a") as f:
+            f.write(f"{json.dumps(usage)}\n")
+        if completion_tokens > 0:
             provider.live += 1
     except:
         provider.live -= 1
@@ -378,3 +414,27 @@ def iter_run_tools(
 
     if sources is not None:
         yield sources
+
+def caculate_prompt_tokens(messages: Messages) -> int:
+    """Calculate the total number of tokens in messages"""
+    token_count = 1  # Bos Token
+    for message in messages:
+        if isinstance(message.get("content"), str):
+            token_count += math.floor(len(message["content"].encode("utf-8")) / 4)
+            token_count += 4  # Role and start/end message token
+        elif isinstance(message.get("content"), list):
+            for item in message["content"]:
+                if isinstance(item, str):
+                    token_count += math.floor(len(item.encode("utf-8")) / 4)
+                elif isinstance(item, dict) and "text" in item and isinstance(item["text"], str):
+                    token_count += math.floor(len(item["text"].encode("utf-8")) / 4)
+            token_count += 4  # Role and start/end message token
+    return token_count
+
+def get_usage(messages: Messages, completion_tokens: int) -> Usage:
+    prompt_tokens = caculate_prompt_tokens(messages)
+    return Usage(
+        completion_tokens=completion_tokens,
+        prompt_tokens=prompt_tokens,
+        total_tokens=prompt_tokens + completion_tokens
+    )
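
The new `caculate_prompt_tokens` (name as in the diff) is a plain heuristic: about one token per four UTF-8 bytes of text, plus four tokens of per-message overhead for role and start/end markers, plus one BOS token for the whole prompt. A worked example:

```python
import math

messages = [
    {"role": "system", "content": "Be brief."},    # 9 bytes  -> floor(9/4) = 2, +4 overhead
    {"role": "user", "content": "Hello, world!"},  # 13 bytes -> floor(13/4) = 3, +4 overhead
]

token_count = 1  # BOS token
for message in messages:
    token_count += math.floor(len(message["content"].encode("utf-8")) / 4) + 4

print(token_count)  # 1 + (2 + 4) + (3 + 4) = 14
```

On the completion side, the refactor counts each streamed string chunk as one token, so both halves of the fallback `Usage` are approximations that apply only when the provider reports no usage of its own.
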