Update docs: Using the OpenAI Library

Add sse function to requests sessions
Small improvments in OpenaiChat and ARTA provider
This commit is contained in:
hlohaus 2025-03-22 07:32:30 +01:00
parent fa2344b031
commit 8f6efd5366
17 changed files with 291 additions and 86 deletions

View file

@ -8,7 +8,7 @@
- [From Repository](#from-repository) - [From Repository](#from-repository)
- [Using the Interference API](#using-the-interference-api) - [Using the Interference API](#using-the-interference-api)
- [Basic Usage](#basic-usage) - [Basic Usage](#basic-usage)
- [With OpenAI Library](#with-openai-library) - [Using the OpenAI Library](#using-the-openai-library)
- [With Requests Library](#with-requests-library) - [With Requests Library](#with-requests-library)
- [Selecting a Provider](#selecting-a-provider) - [Selecting a Provider](#selecting-a-provider)
- [Key Points](#key-points) - [Key Points](#key-points)
@ -95,35 +95,45 @@ curl -X POST "http://localhost:1337/v1/images/generate" \
}' }'
``` ```
---
### With OpenAI Library ### Using the OpenAI Library
**To utilize the Inference API with the OpenAI Python library, you can specify the `base_url` to point to your endpoint:**
**You can use the Interference API with the OpenAI Python library by changing the `base_url`:**
```python ```python
from openai import OpenAI from openai import OpenAI
# Initialize the OpenAI client
client = OpenAI( client = OpenAI(
api_key="secret", api_key="secret", # Set an API key (use "secret" if your provider doesn't require one)
base_url="http://localhost:1337/v1" base_url="http://localhost:1337/v1" # Point to your local or custom API endpoint
) )
# Create a chat completion request
response = client.chat.completions.create( response = client.chat.completions.create(
model="gpt-4o-mini", model="gpt-4o-mini", # Specify the model to use
messages=[{"role": "user", "content": "Write a poem about a tree"}], messages=[{"role": "user", "content": "Write a poem about a tree"}], # Define the input message
stream=True, stream=True, # Enable streaming for real-time responses
) )
# Handle the response
if isinstance(response, dict): if isinstance(response, dict):
# Not streaming # Non-streaming response
print(response.choices[0].message.content) print(response.choices[0].message.content)
else: else:
# Streaming # Streaming response
for token in response: for token in response:
content = token.choices[0].delta.content content = token.choices[0].delta.content
if content is not None: if content is not None:
print(content, end="", flush=True) print(content, end="", flush=True)
``` ```
**Notes:**
- The `api_key` is required by the OpenAI Python library. If your provider does not require an API key, you can set it to `"secret"`. This value will be ignored by providers in G4F.
- Replace `"http://localhost:1337/v1"` with the appropriate URL for your custom or local inference API.
---
### With Requests Library ### With Requests Library

View file

@ -16,7 +16,7 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_image_prompt from .helper import format_image_prompt
class ARTA(AsyncGeneratorProvider, ProviderModelMixin): class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://img-gen-prod.ai-arta.com" url = "https://ai-arta.com"
auth_url = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/signupNewUser?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ" auth_url = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/signupNewUser?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
token_refresh_url = "https://securetoken.googleapis.com/v1/token?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ" token_refresh_url = "https://securetoken.googleapis.com/v1/token?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
image_generation_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image" image_generation_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image"

View file

@ -92,17 +92,17 @@ class HuggingFaceAPI(OpenaiTemplate):
model = provider_mapping[provider_key]["providerId"] model = provider_mapping[provider_key]["providerId"]
yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})"}) yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})"})
break break
start = calculate_lenght(messages) # start = calculate_lenght(messages)
if start > max_inputs_lenght: # if start > max_inputs_lenght:
if len(messages) > 6: # if len(messages) > 6:
messages = messages[:3] + messages[-3:] # messages = messages[:3] + messages[-3:]
if calculate_lenght(messages) > max_inputs_lenght: # if calculate_lenght(messages) > max_inputs_lenght:
last_user_message = [{"role": "user", "content": get_last_user_message(messages)}] # last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
if len(messages) > 2: # if len(messages) > 2:
messages = [m for m in messages if m["role"] == "system"] + last_user_message # messages = [m for m in messages if m["role"] == "system"] + last_user_message
if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght: # if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
messages = last_user_message # messages = last_user_message
debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}") # debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, media=media, **kwargs): async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, media=media, **kwargs):
yield chunk yield chunk

View file

@ -36,7 +36,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages, messages: Messages,
**kwargs **kwargs
) -> AsyncResult: ) -> AsyncResult:
if "tools" not in kwargs and "images" not in kwargs and random.random() >= 0.5: if "tools" not in kwargs and "media" not in kwargs and random.random() >= 0.5:
try: try:
is_started = False is_started = False
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs): async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):

View file

@ -465,8 +465,6 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
if not line.startswith(b"data: "): if not line.startswith(b"data: "):
return return
elif line.startswith(b"data: [DONE]"): elif line.startswith(b"data: [DONE]"):
if fields.finish_reason is None:
fields.finish_reason = "error"
return return
try: try:
line = json.loads(line[6:]) line = json.loads(line[6:])

View file

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import json
import requests import requests
from ..helper import filter_none, format_image_prompt from ..helper import filter_none, format_image_prompt
@ -141,7 +140,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
choice = data["choices"][0] choice = data["choices"][0]
if "content" in choice["message"] and choice["message"]["content"]: if "content" in choice["message"] and choice["message"]["content"]:
yield choice["message"]["content"].strip() yield choice["message"]["content"].strip()
elif "tool_calls" in choice["message"]: if "tool_calls" in choice["message"]:
yield ToolCalls(choice["message"]["tool_calls"]) yield ToolCalls(choice["message"]["tool_calls"])
if "usage" in data: if "usage" in data:
yield Usage(**data["usage"]) yield Usage(**data["usage"])
@ -151,26 +150,21 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
elif content_type.startswith("text/event-stream"): elif content_type.startswith("text/event-stream"):
await raise_for_status(response) await raise_for_status(response)
first = True first = True
async for line in response.iter_lines(): async for data in response.sse():
if line.startswith(b"data: "): cls.raise_error(data)
chunk = line[6:] choice = data["choices"][0]
if chunk == b"[DONE]": if "content" in choice["delta"] and choice["delta"]["content"]:
break delta = choice["delta"]["content"]
data = json.loads(chunk) if first:
cls.raise_error(data) delta = delta.lstrip()
choice = data["choices"][0] if delta:
if "content" in choice["delta"] and choice["delta"]["content"]: first = False
delta = choice["delta"]["content"] yield delta
if first: if "usage" in data and data["usage"]:
delta = delta.lstrip() yield Usage(**data["usage"])
if delta: if "finish_reason" in choice and choice["finish_reason"] is not None:
first = False yield FinishReason(choice["finish_reason"])
yield delta break
if "usage" in data and data["usage"]:
yield Usage(**data["usage"])
if "finish_reason" in choice and choice["finish_reason"] is not None:
yield FinishReason(choice["finish_reason"])
break
else: else:
await raise_for_status(response) await raise_for_status(response)
raise ResponseError(f"Not supported content-type: {content_type}") raise ResponseError(f"Not supported content-type: {content_type}")

View file

@ -308,7 +308,8 @@ class Api:
if credentials is not None and credentials.credentials != "secret": if credentials is not None and credentials.credentials != "secret":
config.api_key = credentials.credentials config.api_key = credentials.credentials
conversation = return_conversation = None conversation = None
return_conversation = config.return_conversation
if conversation is not None: if conversation is not None:
conversation = JsonConversation(**conversation) conversation = JsonConversation(**conversation)
return_conversation = True return_conversation = True
@ -637,11 +638,8 @@ def run_api(
port: int = None, port: int = None,
bind: str = None, bind: str = None,
debug: bool = False, debug: bool = False,
workers: int = None,
use_colors: bool = None, use_colors: bool = None,
reload: bool = False, **kwargs
ssl_keyfile: str = None,
ssl_certfile: str = None
) -> None: ) -> None:
print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]' + (" (debug)" if debug else "")) print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]' + (" (debug)" if debug else ""))
@ -665,10 +663,7 @@ def run_api(
f"g4f.api:{method}", f"g4f.api:{method}",
host=host, host=host,
port=int(port), port=int(port),
workers=workers,
use_colors=use_colors,
factory=True, factory=True,
reload=reload, use_colors=use_colors,
ssl_keyfile=ssl_keyfile, **filter_none(**kwargs)
ssl_certfile=ssl_certfile
) )

View file

@ -31,6 +31,7 @@ class ChatCompletionsConfig(BaseModel):
proxy: Optional[str] = None proxy: Optional[str] = None
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
conversation: Optional[dict] = None conversation: Optional[dict] = None
return_conversation: Optional[bool] = None
history_disabled: Optional[bool] = None history_disabled: Optional[bool] = None
timeout: Optional[int] = None timeout: Optional[int] = None
tool_calls: list = Field(default=[], examples=[[ tool_calls: list = Field(default=[], examples=[[
@ -43,6 +44,12 @@ class ChatCompletionsConfig(BaseModel):
} }
]]) ]])
tools: list = None tools: list = None
parallel_tool_calls: bool = None
tool_choice: Optional[str] = None
reasoning_effort: Optional[str] = None
logit_bias: Optional[dict] = None
modalities: Optional[list[str]] = None
audio: Optional[dict] = None
response_format: Optional[dict] = None response_format: Optional[dict] = None
class ImageGenerationConfig(BaseModel): class ImageGenerationConfig(BaseModel):

View file

@ -32,6 +32,7 @@ def get_api_parser():
api_parser.add_argument("--ssl-keyfile", type=str, default=None, help="Path to SSL key file for HTTPS.") api_parser.add_argument("--ssl-keyfile", type=str, default=None, help="Path to SSL key file for HTTPS.")
api_parser.add_argument("--ssl-certfile", type=str, default=None, help="Path to SSL certificate file for HTTPS.") api_parser.add_argument("--ssl-certfile", type=str, default=None, help="Path to SSL certificate file for HTTPS.")
api_parser.add_argument("--log-config", type=str, default=None, help="Custom log config.")
return api_parser return api_parser
@ -74,7 +75,8 @@ def run_api_args(args):
use_colors=not args.disable_colors, use_colors=not args.disable_colors,
reload=args.reload, reload=args.reload,
ssl_keyfile=args.ssl_keyfile, ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile ssl_certfile=args.ssl_certfile,
log_config=args.log_config,
) )
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -162,14 +162,15 @@ async def async_iter_response(
tool_calls = None tool_calls = None
usage = None usage = None
provider: ProviderInfo = None provider: ProviderInfo = None
conversation: JsonConversation = None
try: try:
async for chunk in response: async for chunk in response:
if isinstance(chunk, FinishReason): if isinstance(chunk, FinishReason):
finish_reason = chunk.reason finish_reason = chunk.reason
break break
elif isinstance(chunk, BaseConversation): elif isinstance(chunk, JsonConversation):
yield chunk conversation = chunk
continue continue
elif isinstance(chunk, ToolCalls): elif isinstance(chunk, ToolCalls):
tool_calls = chunk.get_list() tool_calls = chunk.get_list()
@ -228,7 +229,8 @@ async def async_iter_response(
content, finish_reason, completion_id, int(time.time()), usage=usage, content, finish_reason, completion_id, int(time.time()), usage=usage,
**filter_none( **filter_none(
tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls] tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]
) if tool_calls is not None else {} ) if tool_calls is not None else {},
conversation=None if conversation is None else conversation.get_dict()
) )
if provider is not None: if provider is not None:
chat_completion.provider = provider.name chat_completion.provider = provider.name
@ -242,7 +244,6 @@ async def async_iter_append_model_and_provider(
last_model: str, last_model: str,
last_provider: ProviderType last_provider: ProviderType
) -> AsyncChatCompletionResponseType: ) -> AsyncChatCompletionResponseType:
last_provider = None
try: try:
if isinstance(last_provider, BaseRetryProvider): if isinstance(last_provider, BaseRetryProvider):
async for chunk in response: async for chunk in response:

View file

@ -132,6 +132,7 @@ class ChatCompletion(BaseModel):
provider: Optional[str] provider: Optional[str]
choices: list[ChatCompletionChoice] choices: list[ChatCompletionChoice]
usage: UsageModel usage: UsageModel
conversation: dict
@classmethod @classmethod
def model_construct( def model_construct(
@ -141,7 +142,8 @@ class ChatCompletion(BaseModel):
completion_id: str = None, completion_id: str = None,
created: int = None, created: int = None,
tool_calls: list[ToolCallModel] = None, tool_calls: list[ToolCallModel] = None,
usage: UsageModel = None usage: UsageModel = None,
conversation: dict = None
): ):
return super().model_construct( return super().model_construct(
id=f"chatcmpl-{completion_id}" if completion_id else None, id=f"chatcmpl-{completion_id}" if completion_id else None,
@ -153,7 +155,7 @@ class ChatCompletion(BaseModel):
ChatCompletionMessage.model_construct(content, tool_calls), ChatCompletionMessage.model_construct(content, tool_calls),
finish_reason, finish_reason,
)], )],
**filter_none(usage=usage) **filter_none(usage=usage, conversation=conversation)
) )
class ChatCompletionDelta(BaseModel): class ChatCompletionDelta(BaseModel):

View file

@ -298,10 +298,12 @@
let oauthResult = localStorage.getItem("oauth"); let oauthResult = localStorage.getItem("oauth");
if (oauthResult) { if (oauthResult) {
let user;
try { try {
oauthResult = JSON.parse(oauthResult); oauthResult = JSON.parse(oauthResult);
user = await hub.whoAmI({accessToken: oauthResult.accessToken}); user = await hub.whoAmI({accessToken: oauthResult.accessToken});
} catch { } catch (e) {
console.error(e);
oauthResult = null; oauthResult = null;
localStorage.removeItem("oauth"); localStorage.removeItem("oauth");
localStorage.removeItem("HuggingFace-api_key"); localStorage.removeItem("HuggingFace-api_key");
@ -365,7 +367,7 @@
return; return;
} }
const lower = data.prompt.toLowerCase(); const lower = data.prompt.toLowerCase();
const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", " text ", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"]; const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
for (i in tags) { for (i in tags) {
if (lower.indexOf(tags[i]) != -1) { if (lower.indexOf(tags[i]) != -1) {
console.log("Skipping image with tag: " + tags[i]); console.log("Skipping image with tag: " + tags[i]);

127
g4f/gui/client/qrcode.html Normal file
View file

@ -0,0 +1,127 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>QR Scanner and QR Code Generator</title>
<script src="https://cdn.jsdelivr.net/npm/qrcodejs/qrcode.min.js"></script>
<style>
body { font-family: Arial, sans-serif; text-align: center; margin: 20px; }
video { width: 300px; height: 300px; border: 1px solid black; display: block; margin: auto; }
#qrcode { margin-top: 20px; }
#qrcode img, #qrcode canvas { margin: 0 auto; width: 300px; height: 300px; }
button { margin: 5px; padding: 10px; }
</style>
</head>
<body>
<h1>QR Scanner & QR Code Generator</h1>
<h2>QR Code Scanner</h2>
<video id="video"></video>
<button id="startCamera">Start Camera</button>
<button id="stopScan">Stop Scanning</button>
<button id="switchCamera">Switch Camera</button>
<button id="toggleFlash">Toggle Flash</button>
<p id="cam-status"></p>
<h2>Generate QR Code</h2>
<div id="qrcode"></div>
<button id="generateQRCode">Generate QR Code</button>
<script type="module">
import QrScanner from 'https://cdn.jsdelivr.net/npm/qr-scanner/qr-scanner.min.js';
function generate_uuid() {
function random16Hex() { return (0x10000 | Math.random() * 0x10000).toString(16).substr(1); }
return random16Hex() + random16Hex() +
"-" + random16Hex() +
"-" + random16Hex() +
"-" + random16Hex() +
"-" + random16Hex() + random16Hex() + random16Hex();
}
const videoElem = document.getElementById('video');
const camStatus = document.getElementById('cam-status');
let qrScanner;
document.getElementById('startCamera').addEventListener('click', async () => {
startCamera();
});
document.getElementById('stopScan').addEventListener('click', () => {
qrScanner.stop();
});
document.getElementById('toggleFlash').addEventListener('click', async () => {
if (qrScanner) {
const hasFlash = await qrScanner.hasFlash();
if (hasFlash) {
qrScanner.toggleFlash();
} else {
alert('Flash not supported on this camera.');
}
}
});
localStorage.getItem('device_id') || localStorage.setItem('device_id', generate_uuid());
const qrcode = new QRCode(document.getElementById("qrcode"), {
text: JSON.stringify({
date: Math.floor(Date.now() / 1000),
device_id: localStorage.getItem('device_id')
}),
width: 128,
height: 128,
colorDark: "#000000",
colorLight: "#ffffff",
correctLevel: QRCode.CorrectLevel.H
});
const switchButton = document.getElementById('switchCamera');
let currentStream = null;
let facingMode = 'user';
async function startCamera() {
try {
if (currentStream) {
currentStream.getTracks().forEach(track => track.stop());
}
const constraints = {
video: {
width: { ideal: 1280 },
height: { ideal: 720 },
facingMode: facingMode
},
audio: false
};
const stream = await navigator.mediaDevices.getUserMedia(constraints);
currentStream = stream;
video.srcObject = stream;
qrScanner = new QrScanner(videoElem, result => {
camStatus.innerText = 'Camera Success: ' + result;
console.log('decoded QR code:', result);
}, {
highlightScanRegion: true,
highlightCodeOutline: true,
});
await video.play();
await qrScanner.start();
} catch (error) {
console.error('Error accessing the camera:', error);
alert(`Could not access the camera: ${error.message}`);
}
}
switchButton.addEventListener('click', () => {
facingMode = facingMode === 'user' ? 'environment' : 'user';
startCamera();
});
</script>
</body>
</html>

View file

@ -1081,11 +1081,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
} }
try { try {
let api_key; let api_key;
if (is_demo && ["OpenaiChat", "DeepSeekAPI", "PollinationsAI", "Gemini"].includes(provider)) { if (is_demo && !provider) {
api_key = localStorage.getItem("user");
} else if (["HuggingSpace", "G4F"].includes(provider)) {
api_key = localStorage.getItem("HuggingSpace-api_key");
} else if (is_demo) {
api_key = localStorage.getItem("HuggingFace-api_key"); api_key = localStorage.getItem("HuggingFace-api_key");
if (!api_key) { if (!api_key) {
location.href = "/"; location.href = "/";
@ -2111,7 +2107,9 @@ async function on_api() {
} }
providerSelect.innerHTML = ` providerSelect.innerHTML = `
<option value="" selected="selected">Demo Mode</option> <option value="" selected="selected">Demo Mode</option>
<option value="ARTA">ARTA Provider</option>
<option value="DeepSeekAPI">DeepSeek Provider</option> <option value="DeepSeekAPI">DeepSeek Provider</option>
<option value="Grok">Grok Provider</option>
<option value="OpenaiChat">OpenAI Provider</option> <option value="OpenaiChat">OpenAI Provider</option>
<option value="PollinationsAI">Pollinations AI</option> <option value="PollinationsAI">Pollinations AI</option>
<option value="G4F">G4F framework</option> <option value="G4F">G4F framework</option>
@ -2323,6 +2321,7 @@ async function load_version() {
} }
function renderMediaSelect() { function renderMediaSelect() {
mediaSelect.classList.remove("hidden");
const oldImages = mediaSelect.querySelectorAll("a:has(img)"); const oldImages = mediaSelect.querySelectorAll("a:has(img)");
oldImages.forEach((el)=>el.remove()); oldImages.forEach((el)=>el.remove());
Object.entries(image_storage).forEach(([object_url, file]) => { Object.entries(image_storage).forEach(([object_url, file]) => {
@ -2333,7 +2332,9 @@ function renderMediaSelect() {
img.onclick = () => { img.onclick = () => {
img.remove(); img.remove();
delete image_storage[object_url]; delete image_storage[object_url];
URL.revokeObjectURL(object_url) if (file instanceof File) {
URL.revokeObjectURL(object_url)
}
} }
img.onload = () => { img.onload = () => {
link.title += `\n${img.naturalWidth}x${img.naturalHeight}`; link.title += `\n${img.naturalWidth}x${img.naturalHeight}`;
@ -2349,9 +2350,11 @@ imageInput.onclick = () => {
mediaSelect.querySelector(".close").onclick = () => { mediaSelect.querySelector(".close").onclick = () => {
if (Object.values(image_storage).length) { if (Object.values(image_storage).length) {
for (key in image_storage) { Object.entries(image_storage).forEach(([object_url, file]) => {
URL.revokeObjectURL(key); if (file instanceof File) {
} URL.revokeObjectURL(object_url)
}
});
image_storage = {}; image_storage = {};
renderMediaSelect(); renderMediaSelect();
} else { } else {
@ -2459,13 +2462,27 @@ async function upload_files(fileInput) {
Array.from(fileInput.files).forEach(file => { Array.from(fileInput.files).forEach(file => {
formData.append('files', file); formData.append('files', file);
}); });
await fetch("/backend-api/v2/files/" + bucket_id, { const response = await fetch("/backend-api/v2/files/" + bucket_id, {
method: 'POST', method: 'POST',
body: formData body: formData
}); });
const result = await response.json()
let do_refine = document.getElementById("refine")?.checked; const count = result.files.length + result.media.length;
connectToSSE(`/backend-api/v2/files/${bucket_id}`, do_refine, bucket_id); inputCount.innerText = `${count} File(s) uploaded successfully`;
if (result.files.length > 0) {
let do_refine = document.getElementById("refine")?.checked;
connectToSSE(`/backend-api/v2/files/${bucket_id}`, do_refine, bucket_id);
} else {
paperclip.classList.remove("blink");
fileInput.value = "";
}
if (result.media) {
result.media.forEach((filename)=> {
const url = `/backend-api/v2/files/${bucket_id}/media/${filename}`;
image_storage[url] = {bucket_id: bucket_id, name: filename};
});
renderMediaSelect();
}
} }
fileInput.addEventListener('change', async (event) => { fileInput.addEventListener('change', async (event) => {
@ -2580,7 +2597,9 @@ async function api(ressource, args=null, files=null, message_id=null, scroll=tru
if (files.length > 0) { if (files.length > 0) {
const formData = new FormData(); const formData = new FormData();
for (const file of files) { for (const file of files) {
formData.append('files', file) if (file instanceof File) {
formData.append('files', file)
}
} }
formData.append('json', body); formData.append('json', body);
body = formData; body = formData;

View file

@ -9,7 +9,7 @@ import shutil
import random import random
import datetime import datetime
import tempfile import tempfile
from flask import Flask, Response, request, jsonify, render_template from flask import Flask, Response, request, jsonify, render_template, send_from_directory
from typing import Generator from typing import Generator
from pathlib import Path from pathlib import Path
from urllib.parse import quote_plus from urllib.parse import quote_plus
@ -23,6 +23,7 @@ from ...client.helper import filter_markdown
from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets
from ...tools.run_tools import iter_run_tools from ...tools.run_tools import iter_run_tools
from ...errors import ProviderNotFoundError from ...errors import ProviderNotFoundError
from ...image import is_allowed_extension
from ...cookies import get_cookies_dir from ...cookies import get_cookies_dir
from ... import ChatCompletion from ... import ChatCompletion
from ... import models from ... import models
@ -66,6 +67,10 @@ class Backend_Api(Api):
def home(): def home():
return render_template('home.html') return render_template('home.html')
@app.route('/qrcode', methods=['GET'])
def qrcode():
return render_template('qrcode.html')
@app.route('/backend-api/v2/models', methods=['GET']) @app.route('/backend-api/v2/models', methods=['GET'])
def jsonify_models(**kwargs): def jsonify_models(**kwargs):
response = get_demo_models() if app.demo else self.get_models(**kwargs) response = get_demo_models() if app.demo else self.get_models(**kwargs)
@ -302,20 +307,38 @@ class Backend_Api(Api):
def upload_files(bucket_id: str): def upload_files(bucket_id: str):
bucket_id = secure_filename(bucket_id) bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id) bucket_dir = get_bucket_dir(bucket_id)
media_dir = os.path.join(bucket_dir, "media")
os.makedirs(bucket_dir, exist_ok=True) os.makedirs(bucket_dir, exist_ok=True)
filenames = [] filenames = []
media = []
for file in request.files.getlist('files'): for file in request.files.getlist('files'):
try: try:
filename = secure_filename(file.filename) filename = secure_filename(file.filename)
if supports_filename(filename): if is_allowed_extension(filename):
with open(os.path.join(bucket_dir, filename), 'wb') as f: os.makedirs(media_dir, exist_ok=True)
shutil.copyfileobj(file.stream, f) newfile = os.path.join(media_dir, filename)
media.append(filename)
elif supports_filename(filename):
newfile = os.path.join(bucket_dir, filename)
filenames.append(filename) filenames.append(filename)
else:
continue
with open(newfile, 'wb') as f:
shutil.copyfileobj(file.stream, f)
finally: finally:
file.stream.close() file.stream.close()
with open(os.path.join(bucket_dir, "files.txt"), 'w') as f: with open(os.path.join(bucket_dir, "files.txt"), 'w') as f:
[f.write(f"{filename}\n") for filename in filenames] [f.write(f"{filename}\n") for filename in filenames]
return {"bucket_id": bucket_id, "files": filenames} return {"bucket_id": bucket_id, "files": filenames, "media": media}
@app.route('/backend-api/v2/files/<bucket_id>/media/<filename>', methods=['GET'])
def get_media(bucket_id, filename):
bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id)
media_dir = os.path.join(bucket_dir, "media")
if os.path.exists(media_dir):
return send_from_directory(os.path.abspath(media_dir), filename)
return "File not found", 404
@app.route('/backend-api/v2/files/<bucket_id>/<filename>', methods=['PUT']) @app.route('/backend-api/v2/files/<bucket_id>/<filename>', methods=['PUT'])
def upload_file(bucket_id, filename): def upload_file(bucket_id, filename):

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import json
from aiohttp import ClientSession, ClientResponse, ClientTimeout, BaseConnector, FormData from aiohttp import ClientSession, ClientResponse, ClientTimeout, BaseConnector, FormData
from typing import AsyncIterator, Any, Optional from typing import AsyncIterator, Any, Optional
@ -18,6 +19,18 @@ class StreamResponse(ClientResponse):
async def json(self, content_type: str = None) -> Any: async def json(self, content_type: str = None) -> Any:
return await super().json(content_type=content_type) return await super().json(content_type=content_type)
async def sse(self) -> AsyncIterator[dict]:
"""Asynchronously iterate over the Server-Sent Events of the response."""
async for line in self.content:
if line.startswith(b"data: "):
chunk = line[6:]
if chunk.startswith(b"[DONE]"):
break
try:
yield json.loads(chunk)
except json.JSONDecodeError:
continue
class StreamSession(ClientSession): class StreamSession(ClientSession):
def __init__( def __init__(
self, self,

View file

@ -47,6 +47,18 @@ class StreamResponse:
"""Asynchronously iterate over the response content.""" """Asynchronously iterate over the response content."""
return self.inner.aiter_content() return self.inner.aiter_content()
async def sse(self) -> AsyncGenerator[dict, None]:
"""Asynchronously iterate over the Server-Sent Events of the response."""
async for line in self.iter_lines():
if line.startswith(b"data: "):
chunk = line[6:]
if chunk == b"[DONE]":
break
try:
yield json.loads(chunk)
except json.JSONDecodeError:
continue
async def __aenter__(self): async def __aenter__(self):
"""Asynchronously enter the runtime context for the response object.""" """Asynchronously enter the runtime context for the response object."""
inner: Response = await self.inner inner: Response = await self.inner