Merge pull request #2817 from hlohaus/20Mar

Update openai example in docs, add default api_key
This commit is contained in:
H Lohaus 2025-03-22 10:48:22 +01:00 committed by GitHub
commit 3cbcbe1047
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 316 additions and 116 deletions

View file

@ -8,7 +8,7 @@
- [From Repository](#from-repository)
- [Using the Interference API](#using-the-interference-api)
- [Basic Usage](#basic-usage)
- [With OpenAI Library](#with-openai-library)
- [Using the OpenAI Library](#using-the-openai-library)
- [With Requests Library](#with-requests-library)
- [Selecting a Provider](#selecting-a-provider)
- [Key Points](#key-points)
@ -95,35 +95,45 @@ curl -X POST "http://localhost:1337/v1/images/generate" \
}'
```
---
### With OpenAI Library
### Using the OpenAI Library
**To utilize the Inference API with the OpenAI Python library, you can specify the `base_url` to point to your endpoint:**
**You can use the Interference API with the OpenAI Python library by changing the `base_url`:**
```python
from openai import OpenAI
# Initialize the OpenAI client
client = OpenAI(
api_key="",
base_url="http://localhost:1337/v1"
api_key="secret", # Set an API key (use "secret" if your provider doesn't require one)
base_url="http://localhost:1337/v1" # Point to your local or custom API endpoint
)
# Create a chat completion request
response = client.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Write a poem about a tree"}],
stream=True,
model="gpt-4o-mini", # Specify the model to use
messages=[{"role": "user", "content": "Write a poem about a tree"}], # Define the input message
stream=True, # Enable streaming for real-time responses
)
# Handle the response
if isinstance(response, dict):
# Not streaming
# Non-streaming response
print(response.choices[0].message.content)
else:
# Streaming
# Streaming response
for token in response:
content = token.choices[0].delta.content
if content is not None:
print(content, end="", flush=True)
```
**Notes:**
- The `api_key` is required by the OpenAI Python library. If your provider does not require an API key, you can set it to `"secret"`. This value will be ignored by providers in G4F.
- Replace `"http://localhost:1337/v1"` with the appropriate URL for your custom or local inference API.
---
### With Requests Library

View file

@ -16,7 +16,7 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_image_prompt
class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://img-gen-prod.ai-arta.com"
url = "https://ai-arta.com"
auth_url = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/signupNewUser?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
token_refresh_url = "https://securetoken.googleapis.com/v1/token?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
image_generation_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image"

View file

@ -19,7 +19,6 @@ class LambdaChat(HuggingChat):
default_model,
reasoning_model,
"hermes-3-llama-3.1-405b-fp8",
"hermes3-405b-fp8-128k",
"llama3.1-nemotron-70b-instruct",
"lfm-40b",
"llama3.3-70b-instruct-fp8"
@ -27,7 +26,6 @@ class LambdaChat(HuggingChat):
model_aliases = {
"deepseek-v3": default_model,
"hermes-3": "hermes-3-llama-3.1-405b-fp8",
"hermes-3": "hermes3-405b-fp8-128k",
"nemotron-70b": "llama3.1-nemotron-70b-instruct",
"llama-3.3-70b": "llama3.3-70b-instruct-fp8"
}

View file

@ -157,13 +157,14 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
# Load model list
cls.get_models()
if not model and media is not None:
has_audio = False
for media_data, filename in media:
if is_data_an_audio(media_data, filename):
has_audio = True
break
model = next(iter(cls.audio_models)) if has_audio else cls.default_vision_model
if not model:
has_audio = "audio" in kwargs
if not has_audio and media is not None:
for media_data, filename in media:
if is_data_an_audio(media_data, filename):
has_audio = True
break
model = next(iter(cls.audio_models)) if has_audio else model
try:
model = cls.get_model(model)
except ModelNotFoundError:
@ -278,7 +279,7 @@ class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
}
for media_data, filename in media
]
last_message["content"] = image_content + [{"type": "text", "text": last_message["content"]}]
last_message["content"] = image_content + ([{"type": "text", "text": last_message["content"]}] if isinstance(last_message["content"], str) else image_content)
messages[-1] = last_message
async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:

View file

@ -92,17 +92,17 @@ class HuggingFaceAPI(OpenaiTemplate):
model = provider_mapping[provider_key]["providerId"]
yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})"})
break
start = calculate_lenght(messages)
if start > max_inputs_lenght:
if len(messages) > 6:
messages = messages[:3] + messages[-3:]
if calculate_lenght(messages) > max_inputs_lenght:
last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
if len(messages) > 2:
messages = [m for m in messages if m["role"] == "system"] + last_user_message
if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
messages = last_user_message
debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
# start = calculate_lenght(messages)
# if start > max_inputs_lenght:
# if len(messages) > 6:
# messages = messages[:3] + messages[-3:]
# if calculate_lenght(messages) > max_inputs_lenght:
# last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
# if len(messages) > 2:
# messages = [m for m in messages if m["role"] == "system"] + last_user_message
# if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
# messages = last_user_message
# debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, media=media, **kwargs):
yield chunk

View file

@ -36,7 +36,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages,
**kwargs
) -> AsyncResult:
if "tools" not in kwargs and "images" not in kwargs and random.random() >= 0.5:
if "tools" not in kwargs and "media" not in kwargs and random.random() >= 0.5:
try:
is_started = False
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):

View file

@ -465,8 +465,6 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
if not line.startswith(b"data: "):
return
elif line.startswith(b"data: [DONE]"):
if fields.finish_reason is None:
fields.finish_reason = "error"
return
try:
line = json.loads(line[6:])

View file

@ -25,7 +25,7 @@ class BackendApi(AsyncGeneratorProvider, ProviderModelMixin):
debug.log(f"{cls.__name__}: {api_key}")
if media is not None:
for i in range(len(media)):
media[i] = (to_data_uri(media[i][0]), media[i][1])
media[i] = (to_data_uri(media[i][0], media[i][1]), media[i][1])
async with StreamSession(
headers={"Accept": "text/event-stream", **cls.headers},
) as session:

View file

@ -1,6 +1,5 @@
from __future__ import annotations
import json
import requests
from ..helper import filter_none, format_image_prompt
@ -102,23 +101,19 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
if not model and hasattr(cls, "default_vision_model"):
model = cls.default_vision_model
last_message = messages[-1].copy()
last_message["content"] = [
*[
{
"type": "input_audio",
"input_audio": to_input_audio(media_data, filename)
}
if is_data_an_audio(media_data, filename) else {
"type": "image_url",
"image_url": {"url": to_data_uri(media_data, filename)}
}
for media_data, filename in media
],
image_content = [
{
"type": "text",
"text": last_message["content"]
} if isinstance(last_message["content"], str) else last_message["content"]
"type": "input_audio",
"input_audio": to_input_audio(media_data, filename)
}
if is_data_an_audio(media_data, filename) else {
"type": "image_url",
"image_url": {"url": to_data_uri(media_data)}
}
for media_data, filename in media
]
last_message["content"] = image_content + ([{"type": "text", "text": last_message["content"]}] if isinstance(last_message["content"], str) else image_content)
messages[-1] = last_message
extra_parameters = {key: kwargs[key] for key in extra_parameters if key in kwargs}
data = filter_none(
@ -145,7 +140,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
choice = data["choices"][0]
if "content" in choice["message"] and choice["message"]["content"]:
yield choice["message"]["content"].strip()
elif "tool_calls" in choice["message"]:
if "tool_calls" in choice["message"]:
yield ToolCalls(choice["message"]["tool_calls"])
if "usage" in data:
yield Usage(**data["usage"])
@ -155,26 +150,21 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
elif content_type.startswith("text/event-stream"):
await raise_for_status(response)
first = True
async for line in response.iter_lines():
if line.startswith(b"data: "):
chunk = line[6:]
if chunk == b"[DONE]":
break
data = json.loads(chunk)
cls.raise_error(data)
choice = data["choices"][0]
if "content" in choice["delta"] and choice["delta"]["content"]:
delta = choice["delta"]["content"]
if first:
delta = delta.lstrip()
if delta:
first = False
yield delta
if "usage" in data and data["usage"]:
yield Usage(**data["usage"])
if "finish_reason" in choice and choice["finish_reason"] is not None:
yield FinishReason(choice["finish_reason"])
break
async for data in response.sse():
cls.raise_error(data)
choice = data["choices"][0]
if "content" in choice["delta"] and choice["delta"]["content"]:
delta = choice["delta"]["content"]
if first:
delta = delta.lstrip()
if delta:
first = False
yield delta
if "usage" in data and data["usage"]:
yield Usage(**data["usage"])
if "finish_reason" in choice and choice["finish_reason"] is not None:
yield FinishReason(choice["finish_reason"])
break
else:
await raise_for_status(response)
raise ResponseError(f"Not supported content-type: {content_type}")

View file

@ -177,7 +177,7 @@ class Api:
if path.startswith("/v1") or path.startswith("/api/") or (AppConfig.demo and path == '/backend-api/v2/upload_cookies'):
if user_g4f_api_key is None:
return ErrorResponse.from_message("G4F API key required", HTTP_401_UNAUTHORIZED)
if not secrets.compare_digest(AppConfig.g4f_api_key, user_g4f_api_key):
if AppConfig.g4f_api_key is None or not secrets.compare_digest(AppConfig.g4f_api_key, user_g4f_api_key):
return ErrorResponse.from_message("Invalid G4F API key", HTTP_403_FORBIDDEN)
elif not AppConfig.demo and not path.startswith("/images/"):
if user_g4f_api_key is not None:
@ -305,10 +305,11 @@ class Api:
try:
if config.provider is None:
config.provider = AppConfig.provider if provider is None else provider
if credentials is not None:
if credentials is not None and credentials.credentials != "secret":
config.api_key = credentials.credentials
conversation = return_conversation = None
conversation = None
return_conversation = config.return_conversation
if conversation is not None:
conversation = JsonConversation(**conversation)
return_conversation = True
@ -328,7 +329,7 @@ class Api:
if config.media is not None:
for image in config.media:
try:
is_data_an_media(image[0])
is_data_an_media(image[0], image[1])
except ValueError as e:
example = json.dumps({"media": [["data:image/jpeg;base64,...", "filename.jpg"]]})
return ErrorResponse.from_message(f'The media you send must be a data URIs. Example: {example}', status_code=HTTP_422_UNPROCESSABLE_ENTITY)
@ -410,7 +411,7 @@ class Api:
config: ImageGenerationConfig,
credentials: Annotated[HTTPAuthorizationCredentials, Depends(Api.security)] = None
):
if credentials is not None:
if credentials is not None and credentials.credentials != "secret":
config.api_key = credentials.credentials
try:
response = await self.client.images.generate(
@ -637,11 +638,8 @@ def run_api(
port: int = None,
bind: str = None,
debug: bool = False,
workers: int = None,
use_colors: bool = None,
reload: bool = False,
ssl_keyfile: str = None,
ssl_certfile: str = None
**kwargs
) -> None:
print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]' + (" (debug)" if debug else ""))
@ -665,10 +663,7 @@ def run_api(
f"g4f.api:{method}",
host=host,
port=int(port),
workers=workers,
use_colors=use_colors,
factory=True,
reload=reload,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile
use_colors=use_colors,
**filter_none(**kwargs)
)

View file

@ -31,6 +31,7 @@ class ChatCompletionsConfig(BaseModel):
proxy: Optional[str] = None
conversation_id: Optional[str] = None
conversation: Optional[dict] = None
return_conversation: Optional[bool] = None
history_disabled: Optional[bool] = None
timeout: Optional[int] = None
tool_calls: list = Field(default=[], examples=[[
@ -43,6 +44,12 @@ class ChatCompletionsConfig(BaseModel):
}
]])
tools: list = None
parallel_tool_calls: bool = None
tool_choice: Optional[str] = None
reasoning_effort: Optional[str] = None
logit_bias: Optional[dict] = None
modalities: Optional[list[str]] = None
audio: Optional[dict] = None
response_format: Optional[dict] = None
class ImageGenerationConfig(BaseModel):

View file

@ -32,6 +32,7 @@ def get_api_parser():
api_parser.add_argument("--ssl-keyfile", type=str, default=None, help="Path to SSL key file for HTTPS.")
api_parser.add_argument("--ssl-certfile", type=str, default=None, help="Path to SSL certificate file for HTTPS.")
api_parser.add_argument("--log-config", type=str, default=None, help="Custom log config.")
return api_parser
@ -74,7 +75,8 @@ def run_api_args(args):
use_colors=not args.disable_colors,
reload=args.reload,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile
ssl_certfile=args.ssl_certfile,
log_config=args.log_config,
)
if __name__ == "__main__":

View file

@ -162,14 +162,15 @@ async def async_iter_response(
tool_calls = None
usage = None
provider: ProviderInfo = None
conversation: JsonConversation = None
try:
async for chunk in response:
if isinstance(chunk, FinishReason):
finish_reason = chunk.reason
break
elif isinstance(chunk, BaseConversation):
yield chunk
elif isinstance(chunk, JsonConversation):
conversation = chunk
continue
elif isinstance(chunk, ToolCalls):
tool_calls = chunk.get_list()
@ -228,7 +229,8 @@ async def async_iter_response(
content, finish_reason, completion_id, int(time.time()), usage=usage,
**filter_none(
tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]
) if tool_calls is not None else {}
) if tool_calls is not None else {},
conversation=None if conversation is None else conversation.get_dict()
)
if provider is not None:
chat_completion.provider = provider.name
@ -242,7 +244,6 @@ async def async_iter_append_model_and_provider(
last_model: str,
last_provider: ProviderType
) -> AsyncChatCompletionResponseType:
last_provider = None
try:
if isinstance(last_provider, BaseRetryProvider):
async for chunk in response:

View file

@ -132,6 +132,7 @@ class ChatCompletion(BaseModel):
provider: Optional[str]
choices: list[ChatCompletionChoice]
usage: UsageModel
conversation: dict
@classmethod
def model_construct(
@ -141,7 +142,8 @@ class ChatCompletion(BaseModel):
completion_id: str = None,
created: int = None,
tool_calls: list[ToolCallModel] = None,
usage: UsageModel = None
usage: UsageModel = None,
conversation: dict = None
):
return super().model_construct(
id=f"chatcmpl-{completion_id}" if completion_id else None,
@ -153,7 +155,7 @@ class ChatCompletion(BaseModel):
ChatCompletionMessage.model_construct(content, tool_calls),
finish_reason,
)],
**filter_none(usage=usage)
**filter_none(usage=usage, conversation=conversation)
)
class ChatCompletionDelta(BaseModel):

View file

@ -298,10 +298,12 @@
let oauthResult = localStorage.getItem("oauth");
if (oauthResult) {
let user;
try {
oauthResult = JSON.parse(oauthResult);
user = await hub.whoAmI({accessToken: oauthResult.accessToken});
} catch {
} catch (e) {
console.error(e);
oauthResult = null;
localStorage.removeItem("oauth");
localStorage.removeItem("HuggingFace-api_key");
@ -365,7 +367,7 @@
return;
}
const lower = data.prompt.toLowerCase();
const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", " text ", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
for (i in tags) {
if (lower.indexOf(tags[i]) != -1) {
console.log("Skipping image with tag: " + tags[i]);

127
g4f/gui/client/qrcode.html Normal file
View file

@ -0,0 +1,127 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>QR Scanner and QR Code Generator</title>
<script src="https://cdn.jsdelivr.net/npm/qrcodejs/qrcode.min.js"></script>
<style>
body { font-family: Arial, sans-serif; text-align: center; margin: 20px; }
video { width: 300px; height: 300px; border: 1px solid black; display: block; margin: auto; }
#qrcode { margin-top: 20px; }
#qrcode img, #qrcode canvas { margin: 0 auto; width: 300px; height: 300px; }
button { margin: 5px; padding: 10px; }
</style>
</head>
<body>
<h1>QR Scanner & QR Code Generator</h1>
<h2>QR Code Scanner</h2>
<video id="video"></video>
<button id="startCamera">Start Camera</button>
<button id="stopScan">Stop Scanning</button>
<button id="switchCamera">Switch Camera</button>
<button id="toggleFlash">Toggle Flash</button>
<p id="cam-status"></p>
<h2>Generate QR Code</h2>
<div id="qrcode"></div>
<button id="generateQRCode">Generate QR Code</button>
<script type="module">
import QrScanner from 'https://cdn.jsdelivr.net/npm/qr-scanner/qr-scanner.min.js';
function generate_uuid() {
function random16Hex() { return (0x10000 | Math.random() * 0x10000).toString(16).substr(1); }
return random16Hex() + random16Hex() +
"-" + random16Hex() +
"-" + random16Hex() +
"-" + random16Hex() +
"-" + random16Hex() + random16Hex() + random16Hex();
}
const videoElem = document.getElementById('video');
const camStatus = document.getElementById('cam-status');
let qrScanner;
document.getElementById('startCamera').addEventListener('click', async () => {
startCamera();
});
document.getElementById('stopScan').addEventListener('click', () => {
qrScanner.stop();
});
document.getElementById('toggleFlash').addEventListener('click', async () => {
if (qrScanner) {
const hasFlash = await qrScanner.hasFlash();
if (hasFlash) {
qrScanner.toggleFlash();
} else {
alert('Flash not supported on this camera.');
}
}
});
localStorage.getItem('device_id') || localStorage.setItem('device_id', generate_uuid());
const qrcode = new QRCode(document.getElementById("qrcode"), {
text: JSON.stringify({
date: Math.floor(Date.now() / 1000),
device_id: localStorage.getItem('device_id')
}),
width: 128,
height: 128,
colorDark: "#000000",
colorLight: "#ffffff",
correctLevel: QRCode.CorrectLevel.H
});
const switchButton = document.getElementById('switchCamera');
let currentStream = null;
let facingMode = 'user';
async function startCamera() {
try {
if (currentStream) {
currentStream.getTracks().forEach(track => track.stop());
}
const constraints = {
video: {
width: { ideal: 1280 },
height: { ideal: 720 },
facingMode: facingMode
},
audio: false
};
const stream = await navigator.mediaDevices.getUserMedia(constraints);
currentStream = stream;
video.srcObject = stream;
qrScanner = new QrScanner(videoElem, result => {
camStatus.innerText = 'Camera Success: ' + result;
console.log('decoded QR code:', result);
}, {
highlightScanRegion: true,
highlightCodeOutline: true,
});
await video.play();
await qrScanner.start();
} catch (error) {
console.error('Error accessing the camera:', error);
alert(`Could not access the camera: ${error.message}`);
}
}
switchButton.addEventListener('click', () => {
facingMode = facingMode === 'user' ? 'environment' : 'user';
startCamera();
});
</script>
</body>
</html>

View file

@ -1081,11 +1081,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
}
try {
let api_key;
if (is_demo && ["OpenaiChat", "DeepSeekAPI", "PollinationsAI", "Gemini"].includes(provider)) {
api_key = localStorage.getItem("user");
} else if (["HuggingSpace", "G4F"].includes(provider)) {
api_key = localStorage.getItem("HuggingSpace-api_key");
} else if (is_demo) {
if (is_demo && !provider) {
api_key = localStorage.getItem("HuggingFace-api_key");
if (!api_key) {
location.href = "/";
@ -2111,7 +2107,9 @@ async function on_api() {
}
providerSelect.innerHTML = `
<option value="" selected="selected">Demo Mode</option>
<option value="ARTA">ARTA Provider</option>
<option value="DeepSeekAPI">DeepSeek Provider</option>
<option value="Grok">Grok Provider</option>
<option value="OpenaiChat">OpenAI Provider</option>
<option value="PollinationsAI">Pollinations AI</option>
<option value="G4F">G4F framework</option>
@ -2323,6 +2321,7 @@ async function load_version() {
}
function renderMediaSelect() {
mediaSelect.classList.remove("hidden");
const oldImages = mediaSelect.querySelectorAll("a:has(img)");
oldImages.forEach((el)=>el.remove());
Object.entries(image_storage).forEach(([object_url, file]) => {
@ -2333,7 +2332,9 @@ function renderMediaSelect() {
img.onclick = () => {
img.remove();
delete image_storage[object_url];
URL.revokeObjectURL(object_url)
if (file instanceof File) {
URL.revokeObjectURL(object_url)
}
}
img.onload = () => {
link.title += `\n${img.naturalWidth}x${img.naturalHeight}`;
@ -2349,9 +2350,11 @@ imageInput.onclick = () => {
mediaSelect.querySelector(".close").onclick = () => {
if (Object.values(image_storage).length) {
for (key in image_storage) {
URL.revokeObjectURL(key);
}
Object.entries(image_storage).forEach(([object_url, file]) => {
if (file instanceof File) {
URL.revokeObjectURL(object_url)
}
});
image_storage = {};
renderMediaSelect();
} else {
@ -2459,13 +2462,27 @@ async function upload_files(fileInput) {
Array.from(fileInput.files).forEach(file => {
formData.append('files', file);
});
await fetch("/backend-api/v2/files/" + bucket_id, {
const response = await fetch("/backend-api/v2/files/" + bucket_id, {
method: 'POST',
body: formData
});
let do_refine = document.getElementById("refine")?.checked;
connectToSSE(`/backend-api/v2/files/${bucket_id}`, do_refine, bucket_id);
const result = await response.json()
const count = result.files.length + result.media.length;
inputCount.innerText = `${count} File(s) uploaded successfully`;
if (result.files.length > 0) {
let do_refine = document.getElementById("refine")?.checked;
connectToSSE(`/backend-api/v2/files/${bucket_id}`, do_refine, bucket_id);
} else {
paperclip.classList.remove("blink");
fileInput.value = "";
}
if (result.media) {
result.media.forEach((filename)=> {
const url = `/backend-api/v2/files/${bucket_id}/media/${filename}`;
image_storage[url] = {bucket_id: bucket_id, name: filename};
});
renderMediaSelect();
}
}
fileInput.addEventListener('change', async (event) => {
@ -2580,7 +2597,9 @@ async function api(ressource, args=null, files=null, message_id=null, scroll=tru
if (files.length > 0) {
const formData = new FormData();
for (const file of files) {
formData.append('files', file)
if (file instanceof File) {
formData.append('files', file)
}
}
formData.append('json', body);
body = formData;

View file

@ -9,7 +9,7 @@ import shutil
import random
import datetime
import tempfile
from flask import Flask, Response, request, jsonify, render_template
from flask import Flask, Response, request, jsonify, render_template, send_from_directory
from typing import Generator
from pathlib import Path
from urllib.parse import quote_plus
@ -23,6 +23,7 @@ from ...client.helper import filter_markdown
from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets
from ...tools.run_tools import iter_run_tools
from ...errors import ProviderNotFoundError
from ...image import is_allowed_extension
from ...cookies import get_cookies_dir
from ... import ChatCompletion
from ... import models
@ -66,6 +67,10 @@ class Backend_Api(Api):
def home():
return render_template('home.html')
@app.route('/qrcode', methods=['GET'])
def qrcode():
return render_template('qrcode.html')
@app.route('/backend-api/v2/models', methods=['GET'])
def jsonify_models(**kwargs):
response = get_demo_models() if app.demo else self.get_models(**kwargs)
@ -302,20 +307,38 @@ class Backend_Api(Api):
def upload_files(bucket_id: str):
bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id)
media_dir = os.path.join(bucket_dir, "media")
os.makedirs(bucket_dir, exist_ok=True)
filenames = []
media = []
for file in request.files.getlist('files'):
try:
filename = secure_filename(file.filename)
if supports_filename(filename):
with open(os.path.join(bucket_dir, filename), 'wb') as f:
shutil.copyfileobj(file.stream, f)
if is_allowed_extension(filename):
os.makedirs(media_dir, exist_ok=True)
newfile = os.path.join(media_dir, filename)
media.append(filename)
elif supports_filename(filename):
newfile = os.path.join(bucket_dir, filename)
filenames.append(filename)
else:
continue
with open(newfile, 'wb') as f:
shutil.copyfileobj(file.stream, f)
finally:
file.stream.close()
with open(os.path.join(bucket_dir, "files.txt"), 'w') as f:
[f.write(f"{filename}\n") for filename in filenames]
return {"bucket_id": bucket_id, "files": filenames}
return {"bucket_id": bucket_id, "files": filenames, "media": media}
@app.route('/backend-api/v2/files/<bucket_id>/media/<filename>', methods=['GET'])
def get_media(bucket_id, filename):
bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id)
media_dir = os.path.join(bucket_dir, "media")
if os.path.exists(media_dir):
return send_from_directory(os.path.abspath(media_dir), filename)
return "File not found", 404
@app.route('/backend-api/v2/files/<bucket_id>/<filename>', methods=['PUT'])
def upload_file(bucket_id, filename):

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import json
from aiohttp import ClientSession, ClientResponse, ClientTimeout, BaseConnector, FormData
from typing import AsyncIterator, Any, Optional
@ -18,6 +19,18 @@ class StreamResponse(ClientResponse):
async def json(self, content_type: str = None) -> Any:
return await super().json(content_type=content_type)
async def sse(self) -> AsyncIterator[dict]:
"""Asynchronously iterate over the Server-Sent Events of the response."""
async for line in self.content:
if line.startswith(b"data: "):
chunk = line[6:]
if chunk.startswith(b"[DONE]"):
break
try:
yield json.loads(chunk)
except json.JSONDecodeError:
continue
class StreamSession(ClientSession):
def __init__(
self,

View file

@ -47,6 +47,18 @@ class StreamResponse:
"""Asynchronously iterate over the response content."""
return self.inner.aiter_content()
async def sse(self) -> AsyncGenerator[dict, None]:
"""Asynchronously iterate over the Server-Sent Events of the response."""
async for line in self.iter_lines():
if line.startswith(b"data: "):
chunk = line[6:]
if chunk == b"[DONE]":
break
try:
yield json.loads(chunk)
except json.JSONDecodeError:
continue
async def __aenter__(self):
"""Asynchronously enter the runtime context for the response object."""
inner: Response = await self.inner