mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
Enhance MCP server tests to reflect updated tool count; improve model fetching with timeout handling in providers
This commit is contained in:
parent
006b8c8d50
commit
af56ac0c03
6 changed files with 19 additions and 16 deletions
|
|
@ -22,7 +22,7 @@ class TestMCPServer(unittest.IsolatedAsyncioTestCase):
|
||||||
server = MCPServer()
|
server = MCPServer()
|
||||||
self.assertIsNotNone(server)
|
self.assertIsNotNone(server)
|
||||||
self.assertEqual(server.server_info["name"], "gpt4free-mcp-server")
|
self.assertEqual(server.server_info["name"], "gpt4free-mcp-server")
|
||||||
self.assertEqual(len(server.tools), 3)
|
self.assertEqual(len(server.tools), 5)
|
||||||
self.assertIn('web_search', server.tools)
|
self.assertIn('web_search', server.tools)
|
||||||
self.assertIn('web_scrape', server.tools)
|
self.assertIn('web_scrape', server.tools)
|
||||||
self.assertIn('image_generation', server.tools)
|
self.assertIn('image_generation', server.tools)
|
||||||
|
|
@ -57,7 +57,7 @@ class TestMCPServer(unittest.IsolatedAsyncioTestCase):
|
||||||
self.assertEqual(response.id, 2)
|
self.assertEqual(response.id, 2)
|
||||||
self.assertIsNotNone(response.result)
|
self.assertIsNotNone(response.result)
|
||||||
self.assertIn("tools", response.result)
|
self.assertIn("tools", response.result)
|
||||||
self.assertEqual(len(response.result["tools"]), 3)
|
self.assertEqual(len(response.result["tools"]), 5)
|
||||||
|
|
||||||
# Check tool structure
|
# Check tool structure
|
||||||
tool_names = [tool["name"] for tool in response.result["tools"]]
|
tool_names = [tool["name"] for tool in response.result["tools"]]
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
from typing import Type
|
from typing import Type
|
||||||
import asyncio
|
from requests.exceptions import RequestException
|
||||||
|
|
||||||
from g4f.models import __models__
|
from g4f.models import __models__
|
||||||
from g4f.providers.base_provider import BaseProvider, ProviderModelMixin
|
from g4f.providers.base_provider import BaseProvider, ProviderModelMixin
|
||||||
|
|
@ -15,12 +15,15 @@ class TestProviderHasModel(unittest.TestCase):
|
||||||
if provider.needs_auth:
|
if provider.needs_auth:
|
||||||
continue
|
continue
|
||||||
if issubclass(provider, ProviderModelMixin):
|
if issubclass(provider, ProviderModelMixin):
|
||||||
provider.get_models() # Update models
|
try:
|
||||||
if model.name in provider.model_aliases:
|
provider.get_models(timeout=5) # Update models
|
||||||
model_name = provider.model_aliases[model.name]
|
if model.name in provider.model_aliases:
|
||||||
else:
|
model_name = provider.model_aliases[model.name]
|
||||||
model_name = model.get_long_name()
|
else:
|
||||||
self.provider_has_model(provider, model_name)
|
model_name = model.get_long_name()
|
||||||
|
self.provider_has_model(provider, model_name)
|
||||||
|
except RequestException:
|
||||||
|
continue
|
||||||
|
|
||||||
def provider_has_model(self, provider: Type[BaseProvider], model: str):
|
def provider_has_model(self, provider: Type[BaseProvider], model: str):
|
||||||
if provider.__name__ not in self.cache:
|
if provider.__name__ not in self.cache:
|
||||||
|
|
|
||||||
|
|
@ -502,7 +502,7 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
|
||||||
_models_loaded = False
|
_models_loaded = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_models(cls) -> list[str]:
|
def get_models(cls, timeout: int = None) -> list[str]:
|
||||||
if not cls._models_loaded and has_curl_cffi:
|
if not cls._models_loaded and has_curl_cffi:
|
||||||
cache_file = cls.get_cache_file()
|
cache_file = cls.get_cache_file()
|
||||||
args = {}
|
args = {}
|
||||||
|
|
@ -516,7 +516,7 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
|
||||||
args = {}
|
args = {}
|
||||||
if not args:
|
if not args:
|
||||||
return cls.models
|
return cls.models
|
||||||
response = curl_cffi.get(f"{cls.url}/?mode=direct", **args)
|
response = curl_cffi.get(f"{cls.url}/?mode=direct", **args, timeout=timeout)
|
||||||
if response.ok:
|
if response.ok:
|
||||||
for line in response.text.splitlines():
|
for line in response.text.splitlines():
|
||||||
if "initialModels" in line:
|
if "initialModels" in line:
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||||
max_tokens: int = None
|
max_tokens: int = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_models(cls, api_key: str = None, api_base: str = None) -> list[str]:
|
def get_models(cls, api_key: str = None, api_base: str = None, timeout: int = None) -> list[str]:
|
||||||
if not cls.models:
|
if not cls.models:
|
||||||
try:
|
try:
|
||||||
if api_base is None:
|
if api_base is None:
|
||||||
|
|
@ -42,7 +42,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||||
api_key = AuthManager.load_api_key(cls)
|
api_key = AuthManager.load_api_key(cls)
|
||||||
if cls.models_needs_auth and not api_key:
|
if cls.models_needs_auth and not api_key:
|
||||||
raise MissingAuthError('Add a "api_key"')
|
raise MissingAuthError('Add a "api_key"')
|
||||||
response = requests.get(f"{api_base}/models", headers=cls.get_headers(False, api_key), verify=cls.ssl)
|
response = requests.get(f"{api_base}/models", headers=cls.get_headers(False, api_key), verify=cls.ssl, timeout=timeout)
|
||||||
raise_for_status(response)
|
raise_for_status(response)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
data = data.get("data", data.get("models")) if isinstance(data, dict) else data
|
data = data.get("data", data.get("models")) if isinstance(data, dict) else data
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,6 @@ from dataclasses import dataclass
|
||||||
|
|
||||||
from ..debug import enable_logging
|
from ..debug import enable_logging
|
||||||
|
|
||||||
enable_logging()
|
|
||||||
|
|
||||||
from .tools import MarkItDownTool, TextToAudioTool, WebSearchTool, WebScrapeTool, ImageGenerationTool
|
from .tools import MarkItDownTool, TextToAudioTool, WebSearchTool, WebScrapeTool, ImageGenerationTool
|
||||||
from .tools import WebSearchTool, WebScrapeTool, ImageGenerationTool
|
from .tools import WebSearchTool, WebScrapeTool, ImageGenerationTool
|
||||||
|
|
||||||
|
|
@ -214,6 +212,8 @@ class MCPServer:
|
||||||
sys.stderr.write("Error: aiohttp is required for HTTP transport\n")
|
sys.stderr.write("Error: aiohttp is required for HTTP transport\n")
|
||||||
sys.stderr.write("Install it with: pip install aiohttp\n")
|
sys.stderr.write("Install it with: pip install aiohttp\n")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
enable_logging()
|
||||||
|
|
||||||
async def handle_mcp_request(request: web.Request) -> web.Response:
|
async def handle_mcp_request(request: web.Request) -> web.Response:
|
||||||
nonlocal origin
|
nonlocal origin
|
||||||
|
|
|
||||||
|
|
@ -213,7 +213,7 @@ gpt_4o = VisionModel(
|
||||||
gpt_4o_mini = Model(
|
gpt_4o_mini = Model(
|
||||||
name = 'gpt-4o-mini',
|
name = 'gpt-4o-mini',
|
||||||
base_provider = 'OpenAI',
|
base_provider = 'OpenAI',
|
||||||
best_provider = IterListProvider([Chatai, OIVSCodeSer2, Startnest, OpenaiChat, OIVSCodeSer0501])
|
best_provider = IterListProvider([Chatai, OIVSCodeSer2, Startnest, OpenaiChat])
|
||||||
)
|
)
|
||||||
|
|
||||||
gpt_4o_mini_audio = AudioModel(
|
gpt_4o_mini_audio = AudioModel(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue