Add login_url to authed providers

Create a home page for the GUI
Fix CopyButton in Code Highlight
Heiner Lohaus 2025-01-03 02:40:21 +01:00
parent 5b74a22cf9
commit 486e9a9122
32 changed files with 502 additions and 88 deletions

View file

@@ -70,8 +70,7 @@ def read_json(text: str) -> dict:
try:
return json.loads(text.strip())
except json.JSONDecodeError:
print("No valid json:", text)
return {}
raise RuntimeError(f"Invalid JSON: {text}")
def read_text(text: str) -> str:
"""
@@ -86,7 +85,8 @@ def read_text(text: str) -> str:
match = re.search(r"```(markdown|)\n(?P<text>[\S\s]+?)\n```", text)
if match:
return match.group("text")
return text
else:
raise RuntimeError(f"Invalid markdown: {text}")
def get_ai_response(prompt: str, as_json: bool = True) -> Union[dict, str]:
"""
@@ -197,6 +197,7 @@ def create_review_prompt(pull: PullRequest, diff: str):
return f"""Your task is to review a pull request. Instructions:
- Write in the name of g4f copilot. Don't use placeholders.
- Write the review in GitHub Markdown format.
- Enclose your response in backticks ```response```
- Thank the author for contributing to the project.
Pull request author: {pull.user.name}
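
Since read_json and read_text now raise RuntimeError instead of silently returning fallbacks, callers have to handle the failure path. A minimal caller-side sketch (hypothetical retry loop; get_ai_response is the helper defined above):

def get_review_json(prompt: str, retries: int = 2) -> dict:
    # Model output is not guaranteed to be valid JSON on the first attempt,
    # so retry a few times before giving up.
    for attempt in range(retries + 1):
        try:
            return get_ai_response(prompt, as_json=True)
        except RuntimeError:
            if attempt == retries:
                raise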

View file

@@ -1,36 +1,11 @@
import unittest
import asyncio
import g4f
from g4f import ChatCompletion, get_last_provider
import g4f.version
from g4f.errors import VersionNotFoundError
from g4f.Provider import RetryProvider
from .mocks import ProviderMock
DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
class TestGetLastProvider(unittest.TestCase):
def test_get_last_provider(self):
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
self.assertEqual(get_last_provider(), ProviderMock)
def test_get_last_provider_retry(self):
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, RetryProvider([ProviderMock]))
self.assertEqual(get_last_provider(), ProviderMock)
def test_get_last_provider_async(self):
coroutine = ChatCompletion.create_async(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
asyncio.run(coroutine)
self.assertEqual(get_last_provider(), ProviderMock)
def test_get_last_provider_as_dict(self):
ChatCompletion.create(g4f.models.default, DEFAULT_MESSAGES, ProviderMock)
last_provider_dict = get_last_provider(True)
self.assertIsInstance(last_provider_dict, dict)
self.assertIn('name', last_provider_dict)
self.assertEqual(ProviderMock.__name__, last_provider_dict['name'])
def test_get_latest_version(self):
try:
self.assertIsInstance(g4f.version.utils.current_version, str)

View file

@@ -35,7 +35,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
supports_system_message = True
supports_message_history = True
default_model = "gpt-4o-mini"
default_model = "llama-3.1-70b-chat"
default_image_model = "flux"
models = []
@@ -113,7 +113,7 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
@classmethod
def get_model(cls, model: str) -> str:
"""Get the actual model name from alias"""
return cls.model_aliases.get(model, model)
return cls.model_aliases.get(model, model or cls.default_model)
@classmethod
async def check_api_key(cls, api_key: str) -> bool:
@@ -162,6 +162,9 @@ class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
"""
Filters the full response to remove system errors and other unwanted text.
"""
if "Model not found or too long input. Or any other error (xD)" in response:
raise ValueError(response)
filtered_response = re.sub(r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", '', response) # any-uncensored
filtered_response = re.sub(r'<\|im_end\|>', '', filtered_response) # remove <|im_end|> token
filtered_response = re.sub(r'</s>', '', filtered_response) # neural-chat-7b-v3-1
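
The get_model change above means an empty model name now falls back to the provider default instead of being passed through verbatim. A standalone illustration of the lookup (made-up alias table):

aliases = {"gpt-4o-mini": "gpt-4o-mini"}  # made-up aliases for illustration
default_model = "llama-3.1-70b-chat"

def get_model(model: str) -> str:
    # Unknown names pass through; empty or None names resolve to the default.
    return aliases.get(model, model or default_model)

assert get_model("") == "llama-3.1-70b-chat"
assert get_model("mistral-7b") == "mistral-7b"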

View file

@@ -246,6 +246,8 @@ class BlackboxCreateAgent(AsyncGeneratorProvider, ProviderModelMixin):
Returns:
AsyncResult: The response from the provider
"""
if not model:
model = cls.default_model
if model in cls.chat_models:
async for text in cls._generate_text(model, messages, proxy=proxy, **kwargs):
return text

View file

@@ -80,4 +80,6 @@ class ChatGptEs(AsyncGeneratorProvider, ProviderModelMixin):
async with session.post(cls.api_endpoint, headers=headers, data=payload) as response:
response.raise_for_status()
result = await response.json()
if "Du musst das Kästchen anklicken!" in result['data']:
raise ValueError(result['data'])
yield result['data']

View file

@@ -127,8 +127,8 @@ class Copilot(AbstractProvider, ProviderModelMixin):
response = session.post(cls.conversation_url)
raise_for_status(response)
conversation_id = response.json().get("id")
conversation = Conversation(conversation_id)
if return_conversation:
conversation = Conversation(conversation_id)
yield conversation
if prompt is None:
prompt = format_prompt_max_length(messages, 10000)

View file

@@ -15,7 +15,8 @@ from .needs_auth.OpenaiAPI import OpenaiAPI
"""
class Mhystical(OpenaiAPI):
url = "https://api.mhystical.cc"
label = "Mhystical"
url = "https://mhystical.cc"
api_endpoint = "https://api.mhystical.cc/v1/completions"
working = True
needs_auth = False

View file

@@ -16,6 +16,7 @@ from .OpenaiAPI import OpenaiAPI
class Anthropic(OpenaiAPI):
label = "Anthropic API"
url = "https://console.anthropic.com"
login_url = "https://console.anthropic.com/settings/keys"
working = True
api_base = "https://api.anthropic.com/v1"
needs_auth = True

View file

@@ -10,6 +10,7 @@ from ...cookies import get_cookies
class Cerebras(OpenaiAPI):
label = "Cerebras Inference"
url = "https://inference.cerebras.ai/"
login_url = "https://cloud.cerebras.ai"
api_base = "https://api.cerebras.ai/v1"
working = True
default_model = "llama3.1-70b"

View file

@@ -16,6 +16,7 @@ from ... import debug
class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
label = "Google Gemini API"
url = "https://ai.google.dev"
login_url = "https://aistudio.google.com/u/0/apikey"
api_base = "https://generativelanguage.googleapis.com/v1beta"
working = True

View file

@@ -5,6 +5,7 @@ from .OpenaiAPI import OpenaiAPI
class Groq(OpenaiAPI):
label = "Groq"
url = "https://console.groq.com/playground"
login_url = "https://console.groq.com/keys"
api_base = "https://api.groq.com/openai/v1"
working = True
default_model = "mixtral-8x7b-32768"

View file

@@ -6,6 +6,7 @@ from .HuggingChat import HuggingChat
class HuggingFaceAPI(OpenaiAPI):
label = "HuggingFace (Inference API)"
url = "https://api-inference.huggingface.co"
login_url = "https://huggingface.co/settings/tokens"
api_base = "https://api-inference.huggingface.co/v1"
working = True
default_model = "meta-llama/Llama-3.2-11B-Vision-Instruct"

View file

@@ -15,6 +15,7 @@ from ... import debug
class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
label = "OpenAI API"
url = "https://platform.openai.com"
login_url = "https://platform.openai.com/settings/organization/api-keys"
api_base = "https://api.openai.com/v1"
working = True
needs_auth = True

View file

@@ -5,6 +5,7 @@ from .OpenaiAPI import OpenaiAPI
class PerplexityApi(OpenaiAPI):
label = "Perplexity API"
url = "https://www.perplexity.ai"
login_url = "https://www.perplexity.ai/settings/api"
working = True
api_base = "https://api.perplexity.ai"
default_model = "llama-3-sonar-large-32k-online"

View file

@@ -9,6 +9,7 @@ from ...errors import ResponseError, MissingAuthError
class Replicate(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://replicate.com"
login_url = "https://replicate.com/account/api-tokens"
working = True
needs_auth = True
default_model = "meta/meta-llama-3-70b-instruct"

View file

@@ -5,6 +5,7 @@ from .OpenaiAPI import OpenaiAPI
class glhfChat(OpenaiAPI):
label = "glhf.chat"
url = "https://glhf.chat"
login_url = "https://glhf.chat/users/settings/api"
api_base = "https://glhf.chat/api/openai/v1"
working = True
model_aliases = {

View file

@@ -5,5 +5,6 @@ from .OpenaiAPI import OpenaiAPI
class xAI(OpenaiAPI):
label = "xAI"
url = "https://console.x.ai"
login_url = "https://console.x.ai"
api_base = "https://api.x.ai/v1"
working = True
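
With login_url now set on the authenticated providers above, a front end can point users at the exact key-management page when authentication is missing. A minimal sketch (hypothetical helper, not part of this commit):

def auth_hint(provider: type) -> str:
    # Fall back to the provider home page when no login_url is defined.
    login_url = getattr(provider, "login_url", None) or getattr(provider, "url", "")
    return f"{provider.__name__} requires an API key. Create one at: {login_url}"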

View file

@@ -12,7 +12,7 @@ from .errors import StreamNotSupportedError
from .cookies import get_cookies, set_cookies
from .providers.types import ProviderType
from .providers.helper import concat_chunks
from .client.service import get_model_and_provider, get_last_provider
from .client.service import get_model_and_provider
#Configure "g4f" logger
logger = logging.getLogger(__name__)
@@ -47,7 +47,8 @@ class ChatCompletion:
if ignore_stream:
kwargs["ignore_stream"] = True
result = provider.create_completion(model, messages, stream=stream, **kwargs)
create_method = provider.create_authed if hasattr(provider, "create_authed") else provider.create_completion
result = create_method(model, messages, stream=stream, **kwargs)
return result if stream else concat_chunks(result)
@@ -72,7 +73,9 @@
kwargs["ignore_stream"] = True
if stream:
if hasattr(provider, "create_async_generator"):
if hasattr(provider, "create_async_authed_generator"):
return provider.create_async_authed_generator(model, messages, **kwargs)
elif hasattr(provider, "create_async_generator"):
return provider.create_async_generator(model, messages, **kwargs)
raise StreamNotSupportedError(f'{provider.__name__} does not support "stream" argument in "create_async"')
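
The dispatch added above prefers an authed entry point when the provider defines one. The same resolution order as a standalone sketch (illustrative only; StreamNotSupportedError comes from g4f.errors as imported above):

def resolve_handler(provider: type, stream: bool):
    # Prefer the authed variants introduced in this commit, then fall back.
    names = (
        ("create_async_authed_generator", "create_async_generator")
        if stream else
        ("create_authed", "create_completion")
    )
    for name in names:
        if hasattr(provider, name):
            return getattr(provider, name)
    raise StreamNotSupportedError(f'{provider.__name__} does not support "stream"')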

View file

@@ -197,11 +197,10 @@ class Api:
)
def register_routes(self):
@self.app.get("/")
async def read_root():
if AppConfig.gui:
return RedirectResponse("/chat/", 302)
return RedirectResponse("/v1", 302)
if not AppConfig.gui:
@self.app.get("/")
async def read_root():
return RedirectResponse("/v1", 302)
@self.app.get("/v1")
async def read_root_v1():

View file

@@ -7,9 +7,9 @@ import string
import asyncio
import aiohttp
import base64
from typing import Union, AsyncIterator, Iterator, Coroutine, Optional
from typing import Union, AsyncIterator, Iterator, Awaitable, Optional
from ..image import ImageResponse, copy_images, images_dir
from ..image import ImageResponse, copy_images
from ..typing import Messages, ImageType
from ..providers.types import ProviderType, BaseRetryProvider
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage
@@ -22,7 +22,7 @@ from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
from .. import debug
ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
@@ -220,7 +220,7 @@ class Completions:
ignore_working: Optional[bool] = False,
ignore_stream: Optional[bool] = False,
**kwargs
) -> IterResponse:
) -> ChatCompletion:
model, provider = get_model_and_provider(
model,
self.provider if provider is None else provider,
@@ -236,7 +236,7 @@
kwargs["ignore_stream"] = True
response = iter_run_tools(
provider.create_completion,
provider.create_authed if hasattr(provider, "create_authed") else provider.create_completion,
model,
messages,
stream=stream,
@@ -248,9 +248,6 @@
),
**kwargs
)
if asyncio.iscoroutinefunction(provider.create_completion):
# Run the asynchronous function in an event loop
response = asyncio.run(response)
if stream and hasattr(response, '__aiter__'):
# It's an async generator, wrap it into a sync iterator
response = to_sync_generator(response)
@@ -264,6 +261,14 @@
else:
return next(response)
def stream(
self,
messages: Messages,
model: str,
**kwargs
) -> IterResponse:
return self.create(messages, model, stream=True, **kwargs)
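
Usage sketch for the new Completions.stream() shorthand (assumes a default Client; the chunk shape follows the OpenAI-style stubs used by the client, so delta.content may be None):

from g4f.client import Client

client = Client()
for chunk in client.chat.completions.stream(
    [{"role": "user", "content": "Hello"}],
    model="gpt-4o-mini",
):
    print(chunk.choices[0].delta.content or "", end="")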
class Chat:
completions: Completions
@@ -507,7 +512,7 @@ class AsyncCompletions:
ignore_working: Optional[bool] = False,
ignore_stream: Optional[bool] = False,
**kwargs
) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk, BaseConversation]]:
) -> Awaitable[ChatCompletion]:
model, provider = get_model_and_provider(
model,
self.provider if provider is None else provider,
@@ -521,6 +526,8 @@
kwargs["images"] = [(image, image_name)]
if ignore_stream:
kwargs["ignore_stream"] = True
if hasattr(provider, "create_async_authed_generator"):
create_handler = provider.create_async_authed_generator
if hasattr(provider, "create_async_generator"):
create_handler = provider.create_async_generator
else:
@@ -538,10 +545,20 @@
),
**kwargs
)
if not hasattr(response, '__aiter__'):
response = to_async_iterator(response)
response = async_iter_response(response, stream, response_format, max_tokens, stop)
response = async_iter_append_model_and_provider(response, model, provider)
return response if stream else anext(response)
def stream(
self,
messages: Messages,
model: str,
**kwargs
) -> AsyncIterator[Union[ChatCompletionChunk, BaseConversation]]:
return self.create(messages, model, stream=True, **kwargs)
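
And the async counterpart, which yields chunks directly (usage sketch, assuming AsyncClient from g4f.client):

import asyncio
from g4f.client import AsyncClient

async def main():
    client = AsyncClient()
    async for chunk in client.chat.completions.stream(
        [{"role": "user", "content": "Hello"}],
        model="gpt-4o-mini",
    ):
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())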
class AsyncImages(Images):
def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
self.client: AsyncClient = client

View file

@@ -5,6 +5,22 @@ import logging
from typing import AsyncIterator, Iterator, AsyncGenerator, Optional
def filter_markdown(text: str, allowed_types=None, default=None) -> str:
"""
Parses a code block from a string.
Args:
text (str): A string containing a code block.
Returns:
str: The content of the code block, or `default` if no allowed block is found.
"""
match = re.search(r"```(.*)\n(?P<code>[\S\s]+?)(\n```|$)", text)
if match:
if allowed_types is None or match.group(1) in allowed_types:
return match.group("code")
return default
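
Example behavior of the generalized helper (illustrative values):

text = "Here you go:\n```html\n<h1>Hi</h1>\n```"
filter_markdown(text, ["html"])          # -> "<h1>Hi</h1>"
filter_markdown(text, ["json"])          # -> None (language not in allowed_types)
filter_markdown(text, None, "fallback")  # any language accepted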
def filter_json(text: str) -> str:
"""
Parses JSON code block from a string.
@@ -15,10 +31,7 @@ def filter_json(text: str) -> str:
Returns:
dict: A dictionary parsed from the JSON code block.
"""
match = re.search(r"```(json|)\n(?P<code>[\S\s]+?)\n```", text)
if match:
return match.group("code")
return text
return filter_markdown(text, ["", "json"], text)
def find_stop(stop: Optional[list[str]], content: str, chunk: str = None):
first = -1

g4f/gui/client/home.html Normal file (230 lines)
View file

@@ -0,0 +1,230 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>G4F GUI</title>
<style>
@import url("https://fonts.googleapis.com/css2?family=Inter:wght@100;200;300;400;500;600;700;800;900&display=swap");
:root {
--colour-1: #000000;
--colour-2: #ccc;
--colour-3: #e4d4ff;
--colour-4: #f0f0f0;
--colour-5: #181818;
--colour-6: #242424;
--accent: #8b3dff;
--gradient: #1a1a1a;
--background: #16101b;
--size: 70vw;
--top: 50%;
--blur: 40px;
--opacity: 0.6;
}
.gradient {
position: absolute;
z-index: -1;
left: 50vw;
border-radius: 50%;
background: radial-gradient(circle at center, var(--accent), var(--gradient));
width: var(--size);
height: var(--size);
top: var(--top);
transform: translate(-50%, -50%);
filter: blur(var(--blur)) opacity(var(--opacity));
animation: zoom_gradient 6s infinite alternate;
display: none;
max-height: 100%;
transition: max-height 0.25s ease-in;
}
.gradient.hidden {
max-height: 0;
transition: max-height 0.15s ease-out;
}
@media only screen and (min-width: 40em) {
body .gradient{
display: block;
}
}
@keyframes zoom_gradient {
0% {
transform: translate(-50%, -50%) scale(1);
}
100% {
transform: translate(-50%, -50%) scale(1.2);
}
}
/* Body and text color */
body {
background: var(--background);
color: var(--colour-3);
font-family: "Inter", sans-serif;
height: 100vh;
margin: 0;
padding: 0;
overflow: hidden;
font-weight: bold;
}
/* Container for the main content */
.container {
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
height: 100%;
text-align: center;
z-index: 1;
}
header {
font-size: 3rem;
text-transform: uppercase;
margin: 20px;
color: var(--colour-4);
}
iframe {
background: transparent;
width: 100%;
border: none;
}
#background {
height: 100%;
position: absolute;
z-index: -1;
}
iframe.stream {
max-height: 0;
transition: max-height 0.15s ease-out;
}
iframe.stream.show {
max-height: 1000px;
height: 1000px;
transition: max-height 0.25s ease-in;
background: rgba(255,255,255,0.7);
border-top: 2px solid rgba(255,255,255,0.5);
}
.description {
font-size: 1.2rem;
margin-bottom: 30px;
color: var(--colour-2);
}
.input-field {
width: 80%;
max-width: 400px;
padding: 12px;
margin: 10px 0;
border: 2px solid var(--colour-6);
background-color: var(--colour-5);
color: var(--colour-3);
border-radius: 8px;
font-size: 1.1rem;
}
.input-field:focus {
outline: none;
border-color: var(--accent);
}
.button {
background-color: var(--accent);
color: var(--colour-3);
border: none;
padding: 15px 30px;
font-size: 1.1rem;
border-radius: 8px;
cursor: pointer;
transition: background-color 0.3s ease;
margin-top: 15px;
width: 100%;
max-width: 400px;
font-weight: bold;
}
.button:hover {
background-color: #7a2ccd;
}
.footer {
margin-top: 30px;
font-size: 0.9rem;
color: var(--colour-2);
}
/* Animation for the gradient circle */
@keyframes zoom_gradient {
0% {
transform: translate(-50%, -50%) scale(1);
}
100% {
transform: translate(-50%, -50%) scale(1.5);
}
}
</style>
</head>
<body>
<iframe id="background"></iframe>
<!-- Gradient Background Circle -->
<div class="gradient"></div>
<!-- Main Content -->
<div class="container">
<header>
G4F GUI
</header>
<div class="description">
Welcome to the G4F GUI! <br>
Your AI assistant is ready to assist you.
</div>
<!-- Input and Button -->
<form action="/chat/">
<!--
<input type="text" name="prompt" class="input-field" placeholder="Enter your query...">
-->
<button class="button">Open Chat</button>
</form>
<!-- Footer -->
<div class="footer">
<p>&copy; 2025 G4F. All Rights Reserved.</p>
<p>Powered by the G4F framework</p>
</div>
<iframe id="stream-widget" class="stream" data-src="/backend-api/v2/create?prompt=Create of overview of the news in plain text&stream=1&web_search=news in " class="" frameborder="0"></iframe>
</div>
<script>
const iframe = document.getElementById('stream-widget');
iframe.src = iframe.dataset.src + navigator.language;
setTimeout(()=>iframe.classList.add('show'), 5000);
(async () => {
const prompt = `
Today is ${new Date().toJSON().slice(0, 10)}.
Create a single-page HTML screensaver reflecting the current season (based on the date).
For example, if it's Spring, it might use floral patterns or pastel colors.
Avoid using any text. Consider a subtle animation or transition effect.`;
const response = await fetch(`/backend-api/v2/create?prompt=${encodeURIComponent(prompt)}&filter_markdown=html`);
const text = await response.text();
document.getElementById('background').src = `data:text/html;charset=utf-8,${encodeURIComponent(text)}`;
const gradient = document.querySelector('.gradient');
gradient.classList.add('hidden');
})();
</script>
</body>
</html>

View file

@@ -137,14 +137,13 @@ class HtmlRenderPlugin {
}
}
}
if (window.hljs) {
hljs.addPlugin(new HtmlRenderPlugin())
hljs.addPlugin(new CopyButtonPlugin());
}
let typesetPromise = Promise.resolve();
const highlight = (container) => {
if (window.hljs) {
hljs.addPlugin(new HtmlRenderPlugin())
if (window.CopyButtonPlugin) {
hljs.addPlugin(new CopyButtonPlugin());
}
container.querySelectorAll('code:not(.hljs)').forEach((el) => {
if (el.className != "hljs") {
hljs.highlightElement(el);
@@ -542,7 +541,6 @@ async function load_provider_parameters(provider) {
if (old_form) {
provider_forms.removeChild(old_form);
}
console.log(provider, parameters_storage[provider]);
Object.entries(parameters_storage[provider]).forEach(([key, value]) => {
let el_id = `${provider}-${key}`;
let saved_value = appStorage.getItem(el_id);
@@ -1012,8 +1010,9 @@ const load_conversation = async (conversation_id, scroll=true) => {
let lines = buffer.trim().split("\n");
let lastLine = lines[lines.length - 1];
let newContent = item.content;
if (newContent.startsWith("```\n")) {
newContent = item.content.substring(4);
if (newContent.startsWith("```")) {
const index = str.indexOf("\n");
newContent = newContent.substring(index);
}
if (newContent.startsWith(lastLine)) {
newContent = newContent.substring(lastLine.length);
@@ -1763,7 +1762,7 @@ async function load_version() {
new_version = document.createElement("div");
new_version.classList.add("new_version");
const link = `<a href="${release_url}" target="_blank" title="${title}">v${versions["latest_version"]}</a>`;
new_version.innerHTML = `g4f ${link}&nbsp;&nbsp;🆕`;
new_version.innerHTML = `G4F ${link}&nbsp;&nbsp;🆕`;
new_version.addEventListener("click", ()=>new_version.parentElement.removeChild(new_version));
document.body.appendChild(new_version);
} else {

View file

@@ -7,7 +7,7 @@ class CopyButtonPlugin {
el,
text
}) {
if (el.classList.contains("language-plaintext")) {
if (el.parentElement.tagName != "PRE") {
return;
}
let button = Object.assign(document.createElement("button"), {

View file

@@ -6,4 +6,6 @@ def create_app() -> Flask:
template_folder = os.path.join(sys._MEIPASS, "client")
else:
template_folder = "../client"
return Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static")
app = Flask(__name__, template_folder=template_folder, static_folder=f"{template_folder}/static")
app.config["TEMPLATES_AUTO_RELOAD"] = True # Enable auto reload in debug mode
return app

View file

@@ -14,9 +14,12 @@ from werkzeug.utils import secure_filename
from ...image import is_allowed_extension, to_image
from ...client.service import convert_to_provider
from ...providers.asyncio import to_sync_generator
from ...client.helper import filter_markdown
from ...tools.files import supports_filename, get_streaming, get_bucket_dir, get_buckets
from ...tools.run_tools import iter_run_tools
from ...errors import ProviderNotFoundError
from ...cookies import get_cookies_dir
from ... import ChatCompletion
from .api import Api
logger = logging.getLogger(__name__)
@@ -101,6 +104,44 @@ class Backend_Api(Api):
}
}
@app.route('/backend-api/v2/create', methods=['GET', 'POST'])
def create():
try:
tool_calls = [{
"function": {
"name": "bucket_tool"
},
"type": "function"
}]
web_search = request.args.get("web_search")
if web_search:
tool_calls.append({
"function": {
"name": "search_tool",
"arguments": {"query": web_search, "instructions": ""} if web_search != "true" else {}
},
"type": "function"
})
do_filter_markdown = request.args.get("filter_markdown")
response = iter_run_tools(
ChatCompletion.create,
model=request.args.get("model"),
messages=[{"role": "user", "content": request.args.get("prompt")}],
provider=request.args.get("provider", None),
stream=not do_filter_markdown,
ignore_stream=not request.args.get("stream"),
tool_calls=tool_calls,
)
if do_filter_markdown:
return Response(filter_markdown(response, do_filter_markdown), mimetype='text/plain')
def cast_str():
for chunk in response:
yield str(chunk)
return Response(cast_str(), mimetype='text/plain')
except Exception as e:
logger.exception(e)
return jsonify({"error": {"message": f"{type(e).__name__}: {e}"}}), 500
@app.route('/backend-api/v2/buckets', methods=['GET'])
def list_buckets():
try:

View file

@@ -9,7 +9,7 @@ class Website:
self.app = app
self.routes = {
'/': {
'function': redirect_home,
'function': self._home,
'methods': ['GET', 'POST']
},
'/chat/': {
@@ -41,3 +41,6 @@ class Website:
def _index(self):
return render_template('index.html', chat_id=str(uuid.uuid4()))
def _home(self):
return render_template('home.html')

View file

@@ -5,8 +5,10 @@ import asyncio
from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
import json
from inspect import signature, Parameter
from typing import Optional, _GenericAlias
from typing import Optional, Awaitable, _GenericAlias
from pathlib import Path
try:
from types import NoneType
except ImportError:
@@ -15,9 +17,10 @@ except ImportError:
from ..typing import CreateResult, AsyncResult, Messages
from .types import BaseProvider
from .asyncio import get_running_loop, to_sync_generator
from .response import BaseConversation
from .response import BaseConversation, AuthResult
from .helper import concat_chunks, async_concat_chunks
from ..errors import ModelNotSupportedError, ResponseError
from ..cookies import get_cookies_dir
from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError
from .. import debug
SAFE_PARAMETERS = [
@@ -308,6 +311,12 @@ class AsyncGeneratorProvider(AsyncProvider):
"""
raise NotImplementedError()
create_authed = create_completion
create_authed_async = create_async
create_async_authed = create_async_generator
class ProviderModelMixin:
default_model: str = None
models: list[str] = []
@@ -347,3 +356,98 @@ class RaiseErrorMixin():
raise ResponseError(data["error"]["message"])
else:
raise ResponseError(data["error"])
class AuthedMixin():
@classmethod
def on_auth(cls, **kwargs) -> Optional[AuthResult]:
if "api_key" not in kwargs:
raise MissingAuthError(f"API key is required for {cls.__name__}")
return None
@classmethod
def create_authed(
cls,
model: str,
messages: Messages,
**kwargs
) -> CreateResult:
auth_result = {}
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_completion(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return cls.create_completion(model, messages, **kwargs, **({} if auth_result is None else auth_result.get_dict()))
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
class AsyncAuthedMixin(AuthedMixin):
@classmethod
async def create_async_authed(
cls,
model: str,
messages: Messages,
**kwargs
) -> str:
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
auth_result = {}
with cache_file.open("r") as f:
auth_result = json.load(f)
return await cls.create_async(model, messages, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return await cls.create_async(model, messages, **kwargs, **({} if auth_result is None else auth_result.get_dict()))
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
class AsyncAuthedGeneratorMixin(AsyncAuthedMixin):
@classmethod
async def create_async_authed(
cls,
model: str,
messages: Messages,
**kwargs
) -> str:
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
auth_result = {}
with cache_file.open("r") as f:
auth_result = json.load(f)
return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs, **auth_result))
auth_result = cls.on_auth(**kwargs)
try:
return await async_concat_chunks(cls.create_async_generator(model, messages, stream=False, **kwargs, **({} if auth_result is None else auth_result.get_dict())))
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
@classmethod
def create_async_authed_generator(
cls,
model: str,
messages: Messages,
stream: bool = True,
**kwargs
) -> Awaitable[AsyncResult]:
cache_file = Path(get_cookies_dir()) / f"auth_{cls.__name__}.json"
if cache_file.exists():
auth_result = {}
with cache_file.open("r") as f:
auth_result = json.load(f)
return cls.create_async_generator(model, messages, stream=stream, **kwargs, **auth_result)
auth_result = cls.on_auth(**kwargs)
try:
return cls.create_async_generator(model, messages, stream=stream, **kwargs, **({} if auth_result is None else auth_result.get_dict()))
finally:
if auth_result is not None:
cache_file.parent.mkdir(parents=True, exist_ok=True)
cache_file.write_text(json.dumps(auth_result.get_dict()))
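
A sketch of a provider opting into the cached-auth flow (hypothetical provider; assumes AuthResult accepts keyword fields via its JsonMixin base, and import paths as used in this commit's diff):

from g4f.providers.base_provider import AsyncGeneratorProvider, ProviderModelMixin, AsyncAuthedGeneratorMixin
from g4f.providers.response import AuthResult
from g4f.errors import MissingAuthError

class ExampleAuthed(AsyncAuthedGeneratorMixin, AsyncGeneratorProvider, ProviderModelMixin):
    label = "Example (authed)"
    url = "https://example.com"
    login_url = "https://example.com/settings/keys"
    working = True
    needs_auth = True
    default_model = "example-model"

    @classmethod
    def on_auth(cls, api_key: str = None, **kwargs) -> AuthResult:
        if not api_key:
            raise MissingAuthError(f"API key is required for {cls.__name__}")
        # Cached to auth_ExampleAuthed.json and replayed on later calls.
        return AuthResult(api_key=api_key)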

View file

@@ -105,6 +105,10 @@ class Usage(ResponseType, JsonMixin):
def __str__(self) -> str:
return ""
class AuthResult(JsonMixin):
def __str__(self) -> str:
return ""
class TitleGeneration(ResponseType):
def __init__(self, title: str) -> None:
self.title = title
@@ -182,4 +186,5 @@ class ImagePreview(ImageResponse):
return super().__str__()
class Parameters(ResponseType, JsonMixin):
pass
def __str__(self):
return ""

View file

@@ -40,7 +40,7 @@ async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls:
)
elif tool.get("function", {}).get("name") == "continue":
last_line = messages[-1]["content"].strip().splitlines()[-1]
content = f"Continue after this line.\n{last_line}"
content = f"Carry on from this point:\n{last_line}"
messages.append({"role": "user", "content": content})
elif tool.get("function", {}).get("name") == "bucket_tool":
def on_bucket(match):
@@ -90,7 +90,7 @@ def iter_run_tools(
elif tool.get("function", {}).get("name") == "continue_tool":
if provider not in ("OpenaiAccount", "HuggingFace"):
last_line = messages[-1]["content"].strip().splitlines()[-1]
content = f"Continue after this line:\n{last_line}"
content = f"Carry on from this point:\n{last_line}"
messages.append({"role": "user", "content": content})
else:
# Enable provider native continue

View file

@@ -5,6 +5,7 @@ import json
import hashlib
from pathlib import Path
from urllib.parse import urlparse
from datetime import datetime
import datetime
import asyncio
@@ -65,7 +66,7 @@ class SearchResultEntry():
def set_text(self, text: str):
self.text = text
def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
def scrape_text(html: str, max_words: int = None, add_source=True) -> Iterator[str]:
source = BeautifulSoup(html, "html.parser")
soup = source
for selector in [
@@ -87,7 +88,7 @@ def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
if select:
select.extract()
for paragraph in soup.select("p, table, ul, h1, h2, h3, h4, h5, h6"):
for paragraph in soup.select("p, table:not(:has(p)), ul:not(:has(p)), h1, h2, h3, h4, h5, h6"):
for line in paragraph.text.splitlines():
words = [word for word in line.replace("\t", " ").split(" ") if word]
count = len(words)
@@ -99,24 +100,25 @@ def scrape_text(html: str, max_words: int = None) -> Iterator[str]:
break
yield " ".join(words) + "\n"
canonical_link = source.find("link", rel="canonical")
if canonical_link and "href" in canonical_link.attrs:
link = canonical_link["href"]
domain = urlparse(link).netloc
yield f"\nSource: [{domain}]({link})"
if add_source:
canonical_link = source.find("link", rel="canonical")
if canonical_link and "href" in canonical_link.attrs:
link = canonical_link["href"]
domain = urlparse(link).netloc
yield f"\nSource: [{domain}]({link})"
async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None) -> str:
async def fetch_and_scrape(session: ClientSession, url: str, max_words: int = None, add_source: bool = False) -> str:
try:
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "fetch_and_scrape"
bucket_dir.mkdir(parents=True, exist_ok=True)
md5_hash = hashlib.md5(url.encode()).hexdigest()
cache_file = bucket_dir / f"{url.split('/')[3]}.{datetime.date.today()}.{md5_hash}.txt"
cache_file = bucket_dir / f"{url.split('?')[0].split('//')[1].replace('/', '+')}.{datetime.date.today()}.{md5_hash}.txt"
if cache_file.exists():
return cache_file.read_text()
async with session.get(url) as response:
if response.status == 200:
html = await response.text()
text = "".join(scrape_text(html, max_words))
text = "".join(scrape_text(html, max_words, add_source))
with open(cache_file, "w") as f:
f.write(text)
return text
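
Worked example of the new cache-file naming (hypothetical URL):

url = "https://example.com/news/today?id=1"
stem = url.split('?')[0].split('//')[1].replace('/', '+')
# stem == "example.com+news+today"
# resulting cache file: example.com+news+today.<today's date>.<md5-of-url>.txt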
@@ -136,6 +138,8 @@ async def search(query: str, max_results: int = 5, max_words: int = 2500, backen
max_results=max_results,
backend=backend,
):
if ".google.com" in result["href"]:
continue
results.append(SearchResultEntry(
result["title"],
result["href"],
@@ -146,7 +150,7 @@
requests = []
async with ClientSession(timeout=ClientTimeout(timeout)) as session:
for entry in results:
requests.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1))))
requests.append(fetch_and_scrape(session, entry.url, int(max_words / (max_results - 1)), False))
texts = await asyncio.gather(*requests)
formatted_results = []
@@ -173,7 +177,7 @@ async def do_search(prompt: str, query: str = None, instructions: str = DEFAULT_
query = spacy_get_keywords(prompt)
json_bytes = json.dumps({"query": query, **kwargs}, sort_keys=True).encode()
md5_hash = hashlib.md5(json_bytes).hexdigest()
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / "web_search"
bucket_dir: Path = Path(get_cookies_dir()) / ".scrape_cache" / f"web_search:{datetime.date.today()}"
bucket_dir.mkdir(parents=True, exist_ok=True)
cache_file = bucket_dir / f"{query[:20]}.{md5_hash}.txt"
if cache_file.exists():