Update docs: Using the OpenAI Library

Add sse function to requests sessions
Small improvements in OpenaiChat and ARTA providers
hlohaus 2025-03-22 07:32:30 +01:00
parent fa2344b031
commit 8f6efd5366
17 changed files with 291 additions and 86 deletions

View file

@@ -8,7 +8,7 @@
- [From Repository](#from-repository)
- [Using the Interference API](#using-the-interference-api)
- [Basic Usage](#basic-usage)
- [With OpenAI Library](#with-openai-library)
- [Using the OpenAI Library](#using-the-openai-library)
- [With Requests Library](#with-requests-library)
- [Selecting a Provider](#selecting-a-provider)
- [Key Points](#key-points)
@@ -95,36 +95,46 @@ curl -X POST "http://localhost:1337/v1/images/generate" \
}'
```
---
### With OpenAI Library
### Using the OpenAI Library
**To utilize the Inference API with the OpenAI Python library, you can specify the `base_url` to point to your endpoint:**
**You can use the Interference API with the OpenAI Python library by changing the `base_url`:**
```python
from openai import OpenAI
# Initialize the OpenAI client
client = OpenAI(
api_key="secret",
base_url="http://localhost:1337/v1"
api_key="secret", # Set an API key (use "secret" if your provider doesn't require one)
base_url="http://localhost:1337/v1" # Point to your local or custom API endpoint
)
# Create a chat completion request
response = client.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Write a poem about a tree"}],
stream=True,
model="gpt-4o-mini", # Specify the model to use
messages=[{"role": "user", "content": "Write a poem about a tree"}], # Define the input message
stream=True, # Enable streaming for real-time responses
)
# Handle the response
if isinstance(response, dict):
# Not streaming
# Non-streaming response
print(response.choices[0].message.content)
else:
# Streaming
# Streaming response
for token in response:
content = token.choices[0].delta.content
if content is not None:
print(content, end="", flush=True)
```
**Notes:**
- The `api_key` is required by the OpenAI Python library. If your provider does not require an API key, you can set it to `"secret"`. This value will be ignored by providers in G4F.
- Replace `"http://localhost:1337/v1"` with the appropriate URL for your custom or local inference API.
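The asynchronous client follows the same pattern. A minimal sketch, assuming the same local endpoint and placeholder key as above:
```python
import asyncio
from openai import AsyncOpenAI

async def main():
    # Same configuration as above, but with the async client
    client = AsyncOpenAI(
        api_key="secret",
        base_url="http://localhost:1337/v1"
    )
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "Write a poem about a tree"}],
        stream=True,
    )
    async for token in stream:
        content = token.choices[0].delta.content
        if content is not None:
            print(content, end="", flush=True)

asyncio.run(main())
```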
---
### With Requests Library

View file

@@ -16,7 +16,7 @@ from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_image_prompt
class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://img-gen-prod.ai-arta.com"
url = "https://ai-arta.com"
auth_url = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/signupNewUser?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
token_refresh_url = "https://securetoken.googleapis.com/v1/token?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
image_generation_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image"

View file

@@ -92,17 +92,17 @@ class HuggingFaceAPI(OpenaiTemplate):
model = provider_mapping[provider_key]["providerId"]
yield ProviderInfo(**{**cls.get_dict(), "label": f"HuggingFace ({provider_key})"})
break
start = calculate_lenght(messages)
if start > max_inputs_lenght:
if len(messages) > 6:
messages = messages[:3] + messages[-3:]
if calculate_lenght(messages) > max_inputs_lenght:
last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
if len(messages) > 2:
messages = [m for m in messages if m["role"] == "system"] + last_user_message
if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
messages = last_user_message
debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
# start = calculate_lenght(messages)
# if start > max_inputs_lenght:
# if len(messages) > 6:
# messages = messages[:3] + messages[-3:]
# if calculate_lenght(messages) > max_inputs_lenght:
# last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
# if len(messages) > 2:
# messages = [m for m in messages if m["role"] == "system"] + last_user_message
# if len(messages) > 1 and calculate_lenght(messages) > max_inputs_lenght:
# messages = last_user_message
# debug.log(f"Messages trimmed from: {start} to: {calculate_lenght(messages)}")
async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, media=media, **kwargs):
yield chunk

View file

@@ -36,7 +36,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages,
**kwargs
) -> AsyncResult:
if "tools" not in kwargs and "images" not in kwargs and random.random() >= 0.5:
if "tools" not in kwargs and "media" not in kwargs and random.random() >= 0.5:
try:
is_started = False
async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):

View file

@@ -465,8 +465,6 @@ class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
if not line.startswith(b"data: "):
return
elif line.startswith(b"data: [DONE]"):
if fields.finish_reason is None:
fields.finish_reason = "error"
return
try:
line = json.loads(line[6:])

View file

@@ -1,6 +1,5 @@
from __future__ import annotations
import json
import requests
from ..helper import filter_none, format_image_prompt
@@ -141,7 +140,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
choice = data["choices"][0]
if "content" in choice["message"] and choice["message"]["content"]:
yield choice["message"]["content"].strip()
elif "tool_calls" in choice["message"]:
if "tool_calls" in choice["message"]:
yield ToolCalls(choice["message"]["tool_calls"])
if "usage" in data:
yield Usage(**data["usage"])
@@ -151,12 +150,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
elif content_type.startswith("text/event-stream"):
await raise_for_status(response)
first = True
async for line in response.iter_lines():
if line.startswith(b"data: "):
chunk = line[6:]
if chunk == b"[DONE]":
break
data = json.loads(chunk)
async for data in response.sse():
cls.raise_error(data)
choice = data["choices"][0]
if "content" in choice["delta"] and choice["delta"]["content"]:

View file

@@ -308,7 +308,8 @@ class Api:
if credentials is not None and credentials.credentials != "secret":
config.api_key = credentials.credentials
conversation = return_conversation = None
conversation = None
return_conversation = config.return_conversation
if conversation is not None:
conversation = JsonConversation(**conversation)
return_conversation = True
@@ -637,11 +638,8 @@ def run_api(
port: int = None,
bind: str = None,
debug: bool = False,
workers: int = None,
use_colors: bool = None,
reload: bool = False,
ssl_keyfile: str = None,
ssl_certfile: str = None
**kwargs
) -> None:
print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]' + (" (debug)" if debug else ""))
@@ -665,10 +663,7 @@ def run_api(
f"g4f.api:{method}",
host=host,
port=int(port),
workers=workers,
use_colors=use_colors,
factory=True,
reload=reload,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile
use_colors=use_colors,
**filter_none(**kwargs)
)
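With this refactor the extra uvicorn options (workers, reload, SSL files, and the new log config) travel through `**kwargs`, and `filter_none` strips whatever was left unset before the call reaches `uvicorn.run`. A minimal sketch of an equivalent call (paths are illustrative):
```python
from g4f.api import run_api

# ssl_keyfile/ssl_certfile are no longer named parameters of run_api;
# they are forwarded via **kwargs and dropped by filter_none() when None.
run_api(
    port=1337,
    ssl_keyfile="key.pem",         # illustrative path
    ssl_certfile="cert.pem",       # illustrative path
    log_config="log_config.json",  # pairs with the new --log-config CLI flag
)
```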

View file

@@ -31,6 +31,7 @@ class ChatCompletionsConfig(BaseModel):
proxy: Optional[str] = None
conversation_id: Optional[str] = None
conversation: Optional[dict] = None
return_conversation: Optional[bool] = None
history_disabled: Optional[bool] = None
timeout: Optional[int] = None
tool_calls: list = Field(default=[], examples=[[
@@ -43,6 +44,12 @@ class ChatCompletionsConfig(BaseModel):
}
]])
tools: list = None
parallel_tool_calls: bool = None
tool_choice: Optional[str] = None
reasoning_effort: Optional[str] = None
logit_bias: Optional[dict] = None
modalities: Optional[list[str]] = None
audio: Optional[dict] = None
response_format: Optional[dict] = None
class ImageGenerationConfig(BaseModel):
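For reference, a request body exercising some of the new optional fields might look like this (a sketch against the local Interference API; values are illustrative):
```python
import requests

payload = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "Summarize this text."}],
    "return_conversation": True,   # new field
    "reasoning_effort": "low",     # new field
    "tool_choice": "auto",         # new field
    "parallel_tool_calls": False,  # new field
    "modalities": ["text"],        # new field
}
response = requests.post("http://localhost:1337/v1/chat/completions", json=payload)
print(response.json())
```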

View file

@@ -32,6 +32,7 @@ def get_api_parser():
api_parser.add_argument("--ssl-keyfile", type=str, default=None, help="Path to SSL key file for HTTPS.")
api_parser.add_argument("--ssl-certfile", type=str, default=None, help="Path to SSL certificate file for HTTPS.")
api_parser.add_argument("--log-config", type=str, default=None, help="Custom log config.")
return api_parser
@@ -74,7 +75,8 @@ def run_api_args(args):
use_colors=not args.disable_colors,
reload=args.reload,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile
ssl_certfile=args.ssl_certfile,
log_config=args.log_config,
)
if __name__ == "__main__":

View file

@@ -162,14 +162,15 @@ async def async_iter_response(
tool_calls = None
usage = None
provider: ProviderInfo = None
conversation: JsonConversation = None
try:
async for chunk in response:
if isinstance(chunk, FinishReason):
finish_reason = chunk.reason
break
elif isinstance(chunk, BaseConversation):
yield chunk
elif isinstance(chunk, JsonConversation):
conversation = chunk
continue
elif isinstance(chunk, ToolCalls):
tool_calls = chunk.get_list()
@@ -228,7 +229,8 @@ async def async_iter_response(
content, finish_reason, completion_id, int(time.time()), usage=usage,
**filter_none(
tool_calls=[ToolCallModel.model_construct(**tool_call) for tool_call in tool_calls]
) if tool_calls is not None else {}
) if tool_calls is not None else {},
conversation=None if conversation is None else conversation.get_dict()
)
if provider is not None:
chat_completion.provider = provider.name
@@ -242,7 +244,6 @@ async def async_iter_append_model_and_provider(
last_model: str,
last_provider: ProviderType
) -> AsyncChatCompletionResponseType:
last_provider = None
try:
if isinstance(last_provider, BaseRetryProvider):
async for chunk in response:

View file

@@ -132,6 +132,7 @@ class ChatCompletion(BaseModel):
provider: Optional[str]
choices: list[ChatCompletionChoice]
usage: UsageModel
conversation: dict
@classmethod
def model_construct(
@@ -141,7 +142,8 @@
completion_id: str = None,
created: int = None,
tool_calls: list[ToolCallModel] = None,
usage: UsageModel = None
usage: UsageModel = None,
conversation: dict = None
):
return super().model_construct(
id=f"chatcmpl-{completion_id}" if completion_id else None,
@@ -153,7 +155,7 @@
ChatCompletionMessage.model_construct(content, tool_calls),
finish_reason,
)],
**filter_none(usage=usage)
**filter_none(usage=usage, conversation=conversation)
)
class ChatCompletionDelta(BaseModel):
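Together with `return_conversation` in `ChatCompletionsConfig`, the new `conversation` field lets a client round-trip provider conversation state. A sketch over HTTP (endpoint and model are illustrative):
```python
import requests

url = "http://localhost:1337/v1/chat/completions"
first = requests.post(url, json={
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "Remember the number 42."}],
    "return_conversation": True,    # ask for the provider's conversation state
}).json()

state = first.get("conversation")   # new field on the ChatCompletion model
follow_up = requests.post(url, json={
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "Which number did I mention?"}],
    "conversation": state,          # fed back in via ChatCompletionsConfig
}).json()
print(follow_up["choices"][0]["message"]["content"])
```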

View file

@@ -298,10 +298,12 @@
let oauthResult = localStorage.getItem("oauth");
if (oauthResult) {
let user;
try {
oauthResult = JSON.parse(oauthResult);
user = await hub.whoAmI({accessToken: oauthResult.accessToken});
} catch {
} catch (e) {
console.error(e);
oauthResult = null;
localStorage.removeItem("oauth");
localStorage.removeItem("HuggingFace-api_key");
@@ -365,7 +367,7 @@
return;
}
const lower = data.prompt.toLowerCase();
const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", " text ", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
const tags = ["nsfw", "timeline", "feet", "blood", "soap", "orally", "heel", "latex", "bathroom", "boobs", "charts", "gel", "logo", "infographic", "warts", " bra ", "prostitute", "curvy", "breasts", "written", "bodies", "naked", "classroom", "malone", "dirty", "shoes", "shower", "banner", "fat", "nipples", "couple", "sexual", "sandal", "supplier", "overlord", "succubus", "platinum", "cracy", "crazy", "hemale", "oprah", "lamic", "ropes", "cables", "wires", "dirty", "messy", "cluttered", "chaotic", "disorganized", "disorderly", "untidy", "unorganized", "unorderly", "unsystematic", "disarranged", "disarrayed", "disheveled", "disordered", "jumbled", "muddled", "scattered", "shambolic", "sloppy", "unkept", "unruly"];
for (i in tags) {
if (lower.indexOf(tags[i]) != -1) {
console.log("Skipping image with tag: " + tags[i]);

g4f/gui/client/qrcode.html (new file, 127 lines)
View file

@@ -0,0 +1,127 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>QR Scanner and QR Code Generator</title>
<script src="https://cdn.jsdelivr.net/npm/qrcodejs/qrcode.min.js"></script>
<style>
body { font-family: Arial, sans-serif; text-align: center; margin: 20px; }
video { width: 300px; height: 300px; border: 1px solid black; display: block; margin: auto; }
#qrcode { margin-top: 20px; }
#qrcode img, #qrcode canvas { margin: 0 auto; width: 300px; height: 300px; }
button { margin: 5px; padding: 10px; }
</style>
</head>
<body>
<h1>QR Scanner & QR Code Generator</h1>
<h2>QR Code Scanner</h2>
<video id="video"></video>
<button id="startCamera">Start Camera</button>
<button id="stopScan">Stop Scanning</button>
<button id="switchCamera">Switch Camera</button>
<button id="toggleFlash">Toggle Flash</button>
<p id="cam-status"></p>
<h2>Generate QR Code</h2>
<div id="qrcode"></div>
<button id="generateQRCode">Generate QR Code</button>
<script type="module">
import QrScanner from 'https://cdn.jsdelivr.net/npm/qr-scanner/qr-scanner.min.js';
function generate_uuid() {
function random16Hex() { return (0x10000 | Math.random() * 0x10000).toString(16).substr(1); }
return random16Hex() + random16Hex() +
"-" + random16Hex() +
"-" + random16Hex() +
"-" + random16Hex() +
"-" + random16Hex() + random16Hex() + random16Hex();
}
const videoElem = document.getElementById('video');
const camStatus = document.getElementById('cam-status');
let qrScanner;
document.getElementById('startCamera').addEventListener('click', async () => {
startCamera();
});
document.getElementById('stopScan').addEventListener('click', () => {
qrScanner.stop();
});
document.getElementById('toggleFlash').addEventListener('click', async () => {
if (qrScanner) {
const hasFlash = await qrScanner.hasFlash();
if (hasFlash) {
qrScanner.toggleFlash();
} else {
alert('Flash not supported on this camera.');
}
}
});
localStorage.getItem('device_id') || localStorage.setItem('device_id', generate_uuid());
const qrcode = new QRCode(document.getElementById("qrcode"), {
text: JSON.stringify({
date: Math.floor(Date.now() / 1000),
device_id: localStorage.getItem('device_id')
}),
width: 128,
height: 128,
colorDark: "#000000",
colorLight: "#ffffff",
correctLevel: QRCode.CorrectLevel.H
});
const switchButton = document.getElementById('switchCamera');
let currentStream = null;
let facingMode = 'user';
async function startCamera() {
try {
if (currentStream) {
currentStream.getTracks().forEach(track => track.stop());
}
const constraints = {
video: {
width: { ideal: 1280 },
height: { ideal: 720 },
facingMode: facingMode
},
audio: false
};
const stream = await navigator.mediaDevices.getUserMedia(constraints);
currentStream = stream;
video.srcObject = stream;
qrScanner = new QrScanner(videoElem, result => {
camStatus.innerText = 'Camera Success: ' + result;
console.log('decoded QR code:', result);
}, {
highlightScanRegion: true,
highlightCodeOutline: true,
});
await video.play();
await qrScanner.start();
} catch (error) {
console.error('Error accessing the camera:', error);
alert(`Could not access the camera: ${error.message}`);
}
}
switchButton.addEventListener('click', () => {
facingMode = facingMode === 'user' ? 'environment' : 'user';
startCamera();
});
</script>
</body>
</html>

View file

@@ -1081,11 +1081,7 @@ const ask_gpt = async (message_id, message_index = -1, regenerate = false, provi
}
try {
let api_key;
if (is_demo && ["OpenaiChat", "DeepSeekAPI", "PollinationsAI", "Gemini"].includes(provider)) {
api_key = localStorage.getItem("user");
} else if (["HuggingSpace", "G4F"].includes(provider)) {
api_key = localStorage.getItem("HuggingSpace-api_key");
} else if (is_demo) {
if (is_demo && !provider) {
api_key = localStorage.getItem("HuggingFace-api_key");
if (!api_key) {
location.href = "/";
@@ -2111,7 +2107,9 @@ async function on_api() {
}
providerSelect.innerHTML = `
<option value="" selected="selected">Demo Mode</option>
<option value="ARTA">ARTA Provider</option>
<option value="DeepSeekAPI">DeepSeek Provider</option>
<option value="Grok">Grok Provider</option>
<option value="OpenaiChat">OpenAI Provider</option>
<option value="PollinationsAI">Pollinations AI</option>
<option value="G4F">G4F framework</option>
@@ -2323,6 +2321,7 @@ async function load_version() {
}
function renderMediaSelect() {
mediaSelect.classList.remove("hidden");
const oldImages = mediaSelect.querySelectorAll("a:has(img)");
oldImages.forEach((el)=>el.remove());
Object.entries(image_storage).forEach(([object_url, file]) => {
@@ -2333,8 +2332,10 @@ function renderMediaSelect() {
img.onclick = () => {
img.remove();
delete image_storage[object_url];
if (file instanceof File) {
URL.revokeObjectURL(object_url)
}
}
img.onload = () => {
link.title += `\n${img.naturalWidth}x${img.naturalHeight}`;
};
@@ -2349,9 +2350,11 @@ imageInput.onclick = () => {
mediaSelect.querySelector(".close").onclick = () => {
if (Object.values(image_storage).length) {
for (key in image_storage) {
URL.revokeObjectURL(key);
Object.entries(image_storage).forEach(([object_url, file]) => {
if (file instanceof File) {
URL.revokeObjectURL(object_url)
}
});
image_storage = {};
renderMediaSelect();
} else {
@@ -2459,13 +2462,27 @@ async function upload_files(fileInput) {
Array.from(fileInput.files).forEach(file => {
formData.append('files', file);
});
await fetch("/backend-api/v2/files/" + bucket_id, {
const response = await fetch("/backend-api/v2/files/" + bucket_id, {
method: 'POST',
body: formData
});
const result = await response.json()
const count = result.files.length + result.media.length;
inputCount.innerText = `${count} File(s) uploaded successfully`;
if (result.files.length > 0) {
let do_refine = document.getElementById("refine")?.checked;
connectToSSE(`/backend-api/v2/files/${bucket_id}`, do_refine, bucket_id);
} else {
paperclip.classList.remove("blink");
fileInput.value = "";
}
if (result.media) {
result.media.forEach((filename)=> {
const url = `/backend-api/v2/files/${bucket_id}/media/${filename}`;
image_storage[url] = {bucket_id: bucket_id, name: filename};
});
renderMediaSelect();
}
}
fileInput.addEventListener('change', async (event) => {
@@ -2580,8 +2597,10 @@ async function api(ressource, args=null, files=null, message_id=null, scroll=tru
if (files.length > 0) {
const formData = new FormData();
for (const file of files) {
if (file instanceof File) {
formData.append('files', file)
}
}
formData.append('json', body);
body = formData;
} else {

View file

@@ -9,7 +9,7 @@ import shutil
import random
import datetime
import tempfile
from flask import Flask, Response, request, jsonify, render_template
from flask import Flask, Response, request, jsonify, render_template, send_from_directory
from typing import Generator
from pathlib import Path
from urllib.parse import quote_plus
@@ -23,6 +23,7 @@ from ...client.helper import filter_markdown
from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets
from ...tools.run_tools import iter_run_tools
from ...errors import ProviderNotFoundError
from ...image import is_allowed_extension
from ...cookies import get_cookies_dir
from ... import ChatCompletion
from ... import models
@@ -66,6 +67,10 @@ class Backend_Api(Api):
def home():
return render_template('home.html')
@app.route('/qrcode', methods=['GET'])
def qrcode():
return render_template('qrcode.html')
@app.route('/backend-api/v2/models', methods=['GET'])
def jsonify_models(**kwargs):
response = get_demo_models() if app.demo else self.get_models(**kwargs)
@@ -302,20 +307,38 @@ class Backend_Api(Api):
def upload_files(bucket_id: str):
bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id)
media_dir = os.path.join(bucket_dir, "media")
os.makedirs(bucket_dir, exist_ok=True)
filenames = []
media = []
for file in request.files.getlist('files'):
try:
filename = secure_filename(file.filename)
if supports_filename(filename):
with open(os.path.join(bucket_dir, filename), 'wb') as f:
shutil.copyfileobj(file.stream, f)
if is_allowed_extension(filename):
os.makedirs(media_dir, exist_ok=True)
newfile = os.path.join(media_dir, filename)
media.append(filename)
elif supports_filename(filename):
newfile = os.path.join(bucket_dir, filename)
filenames.append(filename)
else:
continue
with open(newfile, 'wb') as f:
shutil.copyfileobj(file.stream, f)
finally:
file.stream.close()
with open(os.path.join(bucket_dir, "files.txt"), 'w') as f:
[f.write(f"{filename}\n") for filename in filenames]
return {"bucket_id": bucket_id, "files": filenames}
return {"bucket_id": bucket_id, "files": filenames, "media": media}
@app.route('/backend-api/v2/files/<bucket_id>/media/<filename>', methods=['GET'])
def get_media(bucket_id, filename):
bucket_id = secure_filename(bucket_id)
bucket_dir = get_bucket_dir(bucket_id)
media_dir = os.path.join(bucket_dir, "media")
if os.path.exists(media_dir):
return send_from_directory(os.path.abspath(media_dir), filename)
return "File not found", 404
@app.route('/backend-api/v2/files/<bucket_id>/<filename>', methods=['PUT'])
def upload_file(bucket_id, filename):
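A sketch of how a client could exercise the new media handling (host, port, and file name are illustrative):
```python
import uuid
import requests

bucket_id = str(uuid.uuid4())
with open("photo.jpg", "rb") as f:
    result = requests.post(
        f"http://localhost:8080/backend-api/v2/files/{bucket_id}",
        files=[("files", ("photo.jpg", f, "image/jpeg"))],
    ).json()

# Allowed image extensions land in the bucket's media/ directory and are
# reported separately, e.g. {"bucket_id": "...", "files": [], "media": ["photo.jpg"]}
for filename in result.get("media", []):
    url = f"http://localhost:8080/backend-api/v2/files/{bucket_id}/media/{filename}"
    print(url, requests.get(url).status_code)
```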

View file

@@ -1,5 +1,6 @@
from __future__ import annotations
import json
from aiohttp import ClientSession, ClientResponse, ClientTimeout, BaseConnector, FormData
from typing import AsyncIterator, Any, Optional
@@ -18,6 +19,18 @@ class StreamResponse(ClientResponse):
async def json(self, content_type: str = None) -> Any:
return await super().json(content_type=content_type)
async def sse(self) -> AsyncIterator[dict]:
"""Asynchronously iterate over the Server-Sent Events of the response."""
async for line in self.content:
if line.startswith(b"data: "):
chunk = line[6:]
if chunk.startswith(b"[DONE]"):
break
try:
yield json.loads(chunk)
except json.JSONDecodeError:
continue
class StreamSession(ClientSession):
def __init__(
self,
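For context, the new helper can be consumed roughly like this (a sketch assuming `StreamSession` uses `StreamResponse` as its response class, as elsewhere in this module; the import path and endpoint are illustrative):
```python
import asyncio
from g4f.requests.aiohttp import StreamSession  # module path assumed from this diff

async def main():
    async with StreamSession() as session:
        async with session.post(
            "http://localhost:1337/v1/chat/completions",
            json={
                "model": "gpt-4o-mini",
                "messages": [{"role": "user", "content": "Hello"}],
                "stream": True,
            },
        ) as response:
            async for data in response.sse():  # yields parsed JSON events
                content = data["choices"][0]["delta"].get("content")
                if content:
                    print(content, end="", flush=True)

asyncio.run(main())
```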

View file

@@ -47,6 +47,18 @@ class StreamResponse:
"""Asynchronously iterate over the response content."""
return self.inner.aiter_content()
async def sse(self) -> AsyncGenerator[dict, None]:
"""Asynchronously iterate over the Server-Sent Events of the response."""
async for line in self.iter_lines():
if line.startswith(b"data: "):
chunk = line[6:]
if chunk == b"[DONE]":
break
try:
yield json.loads(chunk)
except json.JSONDecodeError:
continue
async def __aenter__(self):
"""Asynchronously enter the runtime context for the response object."""
inner: Response = await self.inner