mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-06 02:30:41 -08:00
Address code review feedback: improve type hints, validation, and documentation
Co-authored-by: hlohaus <983577+hlohaus@users.noreply.github.com>
This commit is contained in:
parent
1e895fbb6a
commit
0a72ce961c
2 changed files with 71 additions and 37 deletions
|
|
@ -12,7 +12,7 @@ from __future__ import annotations
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict
|
||||||
|
|
||||||
from .tools import WebSearchTool, WebScrapeTool, ImageGenerationTool
|
from .tools import WebSearchTool, WebScrapeTool, ImageGenerationTool
|
||||||
|
|
@ -22,7 +22,7 @@ from .tools import WebSearchTool, WebScrapeTool, ImageGenerationTool
|
||||||
class MCPRequest:
|
class MCPRequest:
|
||||||
"""MCP request following JSON-RPC 2.0 format"""
|
"""MCP request following JSON-RPC 2.0 format"""
|
||||||
jsonrpc: str = "2.0"
|
jsonrpc: str = "2.0"
|
||||||
id: Optional[int | str] = None
|
id: Optional[Union[int, str]] = None
|
||||||
method: Optional[str] = None
|
method: Optional[str] = None
|
||||||
params: Optional[Dict[str, Any]] = None
|
params: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
@ -31,7 +31,7 @@ class MCPRequest:
|
||||||
class MCPResponse:
|
class MCPResponse:
|
||||||
"""MCP response following JSON-RPC 2.0 format"""
|
"""MCP response following JSON-RPC 2.0 format"""
|
||||||
jsonrpc: str = "2.0"
|
jsonrpc: str = "2.0"
|
||||||
id: Optional[int | str] = None
|
id: Optional[Union[int, str]] = None
|
||||||
result: Optional[Any] = None
|
result: Optional[Any] = None
|
||||||
error: Optional[Dict[str, Any]] = None
|
error: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
|
||||||
102
g4f/mcp/tools.py
102
g4f/mcp/tools.py
|
|
@ -29,8 +29,15 @@ class MCPTool(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute(self, arguments: Dict[str, Any]) -> Any:
|
async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Execute the tool with given arguments"""
|
"""Execute the tool with given arguments
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments: Tool input arguments matching the input_schema
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict containing either results or an error key with error message
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -60,7 +67,11 @@ class WebSearchTool(MCPTool):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Execute web search"""
|
"""Execute web search
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Search results or error message
|
||||||
|
"""
|
||||||
from ..tools.web_search import do_search
|
from ..tools.web_search import do_search
|
||||||
|
|
||||||
query = arguments.get("query", "")
|
query = arguments.get("query", "")
|
||||||
|
|
@ -72,7 +83,8 @@ class WebSearchTool(MCPTool):
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Perform search
|
# Perform search - query parameter is used for search execution
|
||||||
|
# and prompt parameter holds the content to be searched
|
||||||
result, sources = await do_search(
|
result, sources = await do_search(
|
||||||
prompt=query,
|
prompt=query,
|
||||||
query=query,
|
query=query,
|
||||||
|
|
@ -127,7 +139,11 @@ class WebScrapeTool(MCPTool):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Execute web scraping"""
|
"""Execute web scraping
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Scraped content or error message
|
||||||
|
"""
|
||||||
from ..tools.fetch_and_scrape import fetch_and_scrape
|
from ..tools.fetch_and_scrape import fetch_and_scrape
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
|
|
@ -202,7 +218,11 @@ class ImageGenerationTool(MCPTool):
|
||||||
}
|
}
|
||||||
|
|
||||||
async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Execute image generation"""
|
"""Execute image generation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Generated image data or error message
|
||||||
|
"""
|
||||||
from ..client import AsyncClient
|
from ..client import AsyncClient
|
||||||
from ..image import to_data_uri
|
from ..image import to_data_uri
|
||||||
import base64
|
import base64
|
||||||
|
|
@ -228,35 +248,49 @@ class ImageGenerationTool(MCPTool):
|
||||||
height=height
|
height=height
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the image data
|
# Get the image data with proper validation
|
||||||
if response and hasattr(response, 'data') and response.data:
|
if not response:
|
||||||
image_data = response.data[0]
|
return {
|
||||||
|
"error": "Image generation failed: No response from provider"
|
||||||
# Convert to base64 if needed
|
}
|
||||||
if hasattr(image_data, 'url'):
|
|
||||||
image_url = image_data.url
|
|
||||||
|
|
||||||
# Check if it's already a data URI
|
|
||||||
if image_url.startswith('data:'):
|
|
||||||
return {
|
|
||||||
"prompt": prompt,
|
|
||||||
"model": model,
|
|
||||||
"width": width,
|
|
||||||
"height": height,
|
|
||||||
"image": image_url
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"prompt": prompt,
|
|
||||||
"model": model,
|
|
||||||
"width": width,
|
|
||||||
"height": height,
|
|
||||||
"image_url": image_url
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
if not hasattr(response, 'data') or not response.data:
|
||||||
"error": "Image generation failed: No image data in response"
|
return {
|
||||||
}
|
"error": "Image generation failed: No image data in response"
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(response.data) == 0:
|
||||||
|
return {
|
||||||
|
"error": "Image generation failed: Empty image data array"
|
||||||
|
}
|
||||||
|
|
||||||
|
image_data = response.data[0]
|
||||||
|
|
||||||
|
# Check if image_data has url attribute
|
||||||
|
if not hasattr(image_data, 'url'):
|
||||||
|
return {
|
||||||
|
"error": "Image generation failed: No URL in image data"
|
||||||
|
}
|
||||||
|
|
||||||
|
image_url = image_data.url
|
||||||
|
|
||||||
|
# Return result based on URL type
|
||||||
|
if image_url.startswith('data:'):
|
||||||
|
return {
|
||||||
|
"prompt": prompt,
|
||||||
|
"model": model,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"image": image_url
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"prompt": prompt,
|
||||||
|
"model": model,
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"image_url": image_url
|
||||||
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue