Mirror of https://github.com/xtekky/gpt4free.git
Enhance Copilot provider to support user identity type in API requests and improve bucket item handling
This commit is contained in:
parent 37b79b6df8
commit 03eef2a226

2 changed files with 59 additions and 88 deletions
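The change threads a new useridentitytype keyword (default "google") through the Copilot provider: it is appended to the websocket URL as X-UserIdentityType and sent as an x-useridentitytype header on the REST calls. A minimal usage sketch, assuming the g4f AsyncClient forwards extra keyword arguments down to the provider and that "Copilot" is a valid model alias for it:

import asyncio
from g4f.client import AsyncClient
from g4f.Provider import Copilot

async def main():
    # The client is only a convenience wrapper here; the new keyword is assumed
    # to reach the provider through **kwargs pass-through.
    client = AsyncClient(provider=Copilot)
    response = await client.chat.completions.create(
        model="Copilot",
        messages=[{"role": "user", "content": "Summarize the attached notes."}],
        useridentitytype="google",  # new in this commit, defaults to "google"
    )
    print(response.choices[0].message.content)

asyncio.run(main())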
g4f/Provider/Copilot.py

@@ -9,7 +9,7 @@ from urllib.parse import quote
 try:
     from curl_cffi.requests import AsyncSession
-    from curl_cffi import CurlWsFlag
+    from curl_cffi import CurlWsFlag, CurlMime
     has_curl_cffi = True
 except ImportError:
     has_curl_cffi = False
@@ -20,7 +20,6 @@ except ImportError:
     has_nodriver = False
 
 from .base_provider import AsyncAuthedProvider, ProviderModelMixin
-from .helper import format_prompt_max_length
 from .openai.har_file import get_headers, get_har_files
 from ..typing import AsyncResult, Messages, MediaListType
 from ..errors import MissingRequirementsError, NoValidHarFileError, MissingAuthError
@@ -30,7 +29,7 @@ from ..requests import get_nodriver
 from ..image import to_bytes, is_accepted_format
 from .helper import get_last_user_message
 from ..files import get_bucket_dir
-from ..tools.files import get_filenames, read_bucket
+from ..tools.files import read_bucket
 from pathlib import Path
 from .. import debug
 
@@ -46,8 +45,10 @@ def extract_bucket_items(messages: Messages) -> list[dict]:
     for message in messages:
         if isinstance(message, dict) and isinstance(message.get("content"), list):
             for content_item in message["content"]:
-                if isinstance(content_item, dict) and ("bucket_id" in content_item or "bucket" in content_item):
+                if isinstance(content_item, dict) and "bucket_id" in content_item and "name" not in content_item:
                     bucket_items.append(content_item)
+        if message.get("role") == "assistant":
+            bucket_items = []
     return bucket_items
 
 class Copilot(AsyncAuthedProvider, ProviderModelMixin):
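A short illustration of the revised collection rule (the message payload below is made up): a content item is kept only if it carries a bucket_id and no name, and the collected list is cleared whenever an assistant message is seen, so only buckets attached after the last assistant reply are uploaded. This assumes the reset sits inside the per-message loop, which is how the diff placement reads.

# Illustrative input for the patched extract_bucket_items above.
messages = [
    {"role": "user", "content": [
        {"type": "text", "text": "Check this file"},
        {"bucket_id": "abc123"},                       # kept: bucket_id, no name
        {"bucket_id": "abc123", "name": "notes.pdf"},  # skipped: has a name
    ]},
    {"role": "assistant", "content": "Done."},         # clears the list
    {"role": "user", "content": [
        {"type": "text", "text": "And this one"},
        {"bucket_id": "def456"},
    ]},
]

print(extract_bucket_items(messages))
# -> [{'bucket_id': 'def456'}]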
@@ -109,6 +110,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
         media: MediaListType = None,
         conversation: BaseConversation = None,
         return_conversation: bool = True,
+        useridentitytype: str = "google",
         api_key: str = None,
         **kwargs
     ) -> AsyncResult:
@@ -130,7 +132,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
                     cls._access_token, cls._cookies = await get_access_token_and_cookies(cls.url, proxy)
                 else:
                     raise h
-        websocket_url = f"{websocket_url}&accessToken={quote(cls._access_token)}"
+        websocket_url = f"{websocket_url}&accessToken={quote(cls._access_token)}&X-UserIdentityType={quote(useridentitytype)}"
         headers = {"authorization": f"Bearer {cls._access_token}"}
 
         async with AsyncSession(
@@ -142,7 +144,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
         ) as session:
             if cls._access_token is not None:
                 cls._cookies = session.cookies.jar if hasattr(session.cookies, "jar") else session.cookies
-                response = await session.get("https://copilot.microsoft.com/c/api/user")
+                response = await session.get("https://copilot.microsoft.com/c/api/user?api-version=2", headers={"x-useridentitytype": useridentitytype})
                 if response.status_code == 401:
                     raise MissingAuthError("Status 401: Invalid access token")
                 response.raise_for_status()
@@ -154,17 +156,13 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
             else:
                 debug.log(f"Copilot: User: {user}")
             if conversation is None:
-                response = await session.post(cls.conversation_url)
+                response = await session.post(cls.conversation_url, headers={"x-useridentitytype": useridentitytype})
                 response.raise_for_status()
                 conversation_id = response.json().get("id")
                 conversation = Conversation(conversation_id)
-                if prompt is None:
-                    prompt = format_prompt_max_length(messages, 10000)
                 debug.log(f"Copilot: Created conversation: {conversation_id}")
             else:
                 conversation_id = conversation.conversation_id
-                if prompt is None:
-                    prompt = get_last_user_message(messages)
                 debug.log(f"Copilot: Use conversation: {conversation_id}")
             if return_conversation:
                 yield conversation
@@ -180,6 +178,7 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
                     headers={
                         "content-type": is_accepted_format(data),
                         "content-length": str(len(data)),
+                        "x-useridentitytype": useridentitytype
                     },
                     data=data
                 )
@@ -191,69 +190,33 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
             bucket_items = extract_bucket_items(messages)
             for item in bucket_items:
                 try:
-                    if "name" in item:
-                        # Handle specific file from bucket with name
-                        file_path = Path(get_bucket_dir(item["bucket_id"], "media", item["name"]))
-                        if file_path.exists() and file_path.is_file():
-                            with open(file_path, "rb") as f:
-                                file_data = f.read()
-
-                            filename = item["name"]
-                            # Determine content type based on file extension
-                            content_type = "application/octet-stream"
-                            if filename.endswith(".pdf"):
-                                content_type = "application/pdf"
-                            elif filename.endswith(".docx"):
-                                content_type = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
-                            elif filename.endswith(".txt"):
-                                content_type = "text/plain"
-                            elif filename.endswith(".md"):
-                                content_type = "text/markdown"
-                            elif filename.endswith(".json"):
-                                content_type = "application/json"
-
-                            response = await session.post(
-                                "https://copilot.microsoft.com/c/api/attachments",
-                                headers={
-                                    "content-type": content_type,
-                                    "content-length": str(len(file_data)),
-                                },
-                                data=file_data
-                            )
-                            response.raise_for_status()
-                            file_url = response.json().get("url")
-                            uploaded_attachments.append({"type": "file", "url": file_url, "name": filename})
-                            debug.log(f"Copilot: Uploaded bucket file: {filename}")
-                        else:
-                            debug.log(f"Copilot: Bucket file not found: {item.get('name')}")
-                    else:
-                        # Handle plain text content from bucket
-                        bucket_path = Path(get_bucket_dir(item["bucket"]))
-                        plain_text_content = ""
-                        for text_chunk in read_bucket(bucket_path):
-                            plain_text_content += text_chunk
-
-                        if plain_text_content.strip():
-                            # Upload plain text as a text file
-                            text_data = plain_text_content.encode('utf-8')
-                            response = await session.post(
-                                "https://copilot.microsoft.com/c/api/attachments",
-                                headers={
-                                    "content-type": "text/plain",
-                                    "content-length": str(len(text_data)),
-                                },
-                                data=text_data
-                            )
-                            response.raise_for_status()
-                            file_url = response.json().get("url")
-                            uploaded_attachments.append({"type": "file", "url": file_url, "name": f"bucket_{item['bucket']}.txt"})
-                            debug.log(f"Copilot: Uploaded bucket text content: {item['bucket']}")
-                        else:
-                            debug.log(f"Copilot: No text content found in bucket: {item['bucket']}")
+                    # Handle plain text content from bucket
+                    bucket_path = Path(get_bucket_dir(item["bucket_id"]))
+                    for text_chunk in read_bucket(bucket_path):
+                        if text_chunk.strip():
+                            # Upload plain text as a text file
+                            text_data = text_chunk.encode('utf-8')
+                            data = CurlMime()
+                            data.addpart("file", filename=f"bucket_{item['bucket_id']}.txt", content_type="text/plain", data=text_data)
+                            response = await session.post(
+                                "https://copilot.microsoft.com/c/api/attachments",
+                                multipart=data,
+                                headers={"x-useridentitytype": useridentitytype}
+                            )
+                            response.raise_for_status()
+                            data = response.json()
+                            uploaded_attachments.append({"type": "document", "attachmentId": data.get("id")})
+                            debug.log(f"Copilot: Uploaded bucket text content: {item['bucket_id']}")
+                        else:
+                            debug.log(f"Copilot: No text content found in bucket: {item['bucket_id']}")
                 except Exception as e:
-                    debug.log(f"Copilot: Failed to upload bucket item {item}: {e}")
+                    debug.log(f"Copilot: Failed to upload bucket item: {item}")
+                    debug.error(e)
 
-            wss = await session.ws_connect(cls.websocket_url, timeout=3)
+            if prompt is None:
+                prompt = get_last_user_message(messages, False)
+
+            wss = await session.ws_connect(websocket_url, timeout=3)
             if "Think" in model:
                 mode = "reasoning"
             elif model.startswith("gpt-5") or "GPT-5" in model:
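The upload path now builds a multipart request with curl_cffi's CurlMime and references the result by attachment id instead of by URL. A self-contained sketch of that flow, assuming a valid bearer token and the endpoint and payload shapes shown in the diff; the helper name upload_bucket_text is illustrative and not part of the provider:

from curl_cffi.requests import AsyncSession
from curl_cffi import CurlMime

async def upload_bucket_text(access_token: str, bucket_id: str, text: str,
                             useridentitytype: str = "google") -> dict:
    # Wrap the bucket text as a text/plain file part, as the patched provider does.
    mime = CurlMime()
    mime.addpart(
        "file",
        filename=f"bucket_{bucket_id}.txt",
        content_type="text/plain",
        data=text.encode("utf-8"),
    )
    async with AsyncSession(headers={"authorization": f"Bearer {access_token}"}) as session:
        response = await session.post(
            "https://copilot.microsoft.com/c/api/attachments",
            multipart=mime,
            headers={"x-useridentitytype": useridentitytype},
        )
        response.raise_for_status()
        # The chat message later references the upload as a "document" attachment.
        return {"type": "document", "attachmentId": response.json().get("id")}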
@@ -317,32 +280,32 @@ class Copilot(AsyncAuthedProvider, ProviderModelMixin):
             if not wss.closed:
                 await wss.close()
 
-async def get_access_token_and_cookies(url: str, proxy: str = None, target: str = "ChatAI",):
-    browser, stop_browser = await get_nodriver(proxy=proxy, user_data_dir="copilot")
+async def get_access_token_and_cookies(url: str, proxy: str = None):
+    browser, stop_browser = await get_nodriver(proxy=proxy)
     try:
         page = await browser.get(url)
         access_token = None
         while access_token is None:
-            access_token = await page.evaluate("""
-                (() => {
-                    for (var i = 0; i < localStorage.length; i++) {
-                        try {
-                            item = JSON.parse(localStorage.getItem(localStorage.key(i)));
-                            if (item.credentialType == "AccessToken"
-                            && item.expiresOn > Math.floor(Date.now() / 1000)
-                            && item.target.includes("target")) {
-                                return item.secret;
-                            }
-                        } catch(e) {}
-                    }
-                })()
-            """.replace('"target"', json.dumps(target)))
-            if access_token is None:
-                await asyncio.sleep(1)
+            for _ in range(2):
+                await asyncio.sleep(3)
+                access_token = await page.evaluate("""
+                    (() => {
+                        for (var i = 0; i < localStorage.length; i++) {
+                            try {
+                                item = JSON.parse(localStorage.getItem(localStorage.key(i)));
+                                if (item?.body?.access_token) {
+                                    return item.body.access_token;
+                                }
+                            } catch(e) {}
+                        }
+                    })()
+                """)
+                if access_token is None:
+                    await asyncio.sleep(1)
         cookies = {}
         for c in await page.send(nodriver.cdp.network.get_cookies([url])):
            cookies[c.name] = c.value
-        await page.close()
+        stop_browser()
         return access_token, cookies
     finally:
         stop_browser()
g4f/Provider/helper.py

@@ -66,15 +66,23 @@ def format_prompt(messages: Messages, add_special_tokens: bool = False, do_conti
 def get_system_prompt(messages: Messages) -> str:
     return "\n".join([m["content"] for m in messages if m["role"] in ("developer", "system")])
 
-def get_last_user_message(messages: Messages) -> str:
+def get_last_user_message(messages: Messages, include_buckets: bool = True) -> str:
     user_messages = []
     for message in messages[::-1]:
         if message.get("role") == "user" or not user_messages:
             if message.get("role") != "user":
                 continue
-            content = to_string(message.get("content")).strip()
-            if content:
+            content = message.get("content")
+            if include_buckets:
+                content = to_string(content).strip()
+            if isinstance(content, str):
                 user_messages.append(content)
+            else:
+                for content_item in content:
+                    if content_item.get("type") == "text":
+                        content = content_item.get("text").strip()
+                        if content:
+                            user_messages.append(content)
         else:
             return "\n".join(user_messages[::-1])
     return "\n".join(user_messages[::-1])
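With the new include_buckets flag (the Copilot provider now passes False), list-style message content is reduced to its "text" parts, so bucket references stay out of the websocket prompt while still being uploaded as attachments. A small illustrative call with made-up content:

# Illustrative input for the updated get_last_user_message above.
messages = [
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": [
        {"type": "text", "text": "Summarize the attached notes."},
        {"bucket_id": "abc123"},   # uploaded separately, not part of the prompt
    ]},
]

print(get_last_user_message(messages, False))
# -> Summarize the attached notes.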