diff --git a/g4f/mcp/server.py b/g4f/mcp/server.py index 1ea11346..9eea3dbe 100644 --- a/g4f/mcp/server.py +++ b/g4f/mcp/server.py @@ -12,7 +12,7 @@ from __future__ import annotations import sys import json import asyncio -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass, asdict from .tools import WebSearchTool, WebScrapeTool, ImageGenerationTool @@ -22,7 +22,7 @@ from .tools import WebSearchTool, WebScrapeTool, ImageGenerationTool class MCPRequest: """MCP request following JSON-RPC 2.0 format""" jsonrpc: str = "2.0" - id: Optional[int | str] = None + id: Optional[Union[int, str]] = None method: Optional[str] = None params: Optional[Dict[str, Any]] = None @@ -31,7 +31,7 @@ class MCPRequest: class MCPResponse: """MCP response following JSON-RPC 2.0 format""" jsonrpc: str = "2.0" - id: Optional[int | str] = None + id: Optional[Union[int, str]] = None result: Optional[Any] = None error: Optional[Dict[str, Any]] = None diff --git a/g4f/mcp/tools.py b/g4f/mcp/tools.py index 0e434833..2ba498f2 100644 --- a/g4f/mcp/tools.py +++ b/g4f/mcp/tools.py @@ -29,8 +29,15 @@ class MCPTool(ABC): pass @abstractmethod - async def execute(self, arguments: Dict[str, Any]) -> Any: - """Execute the tool with given arguments""" + async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Execute the tool with given arguments + + Args: + arguments: Tool input arguments matching the input_schema + + Returns: + Dict containing either results or an error key with error message + """ pass @@ -60,7 +67,11 @@ class WebSearchTool(MCPTool): } async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]: - """Execute web search""" + """Execute web search + + Returns: + Dict[str, Any]: Search results or error message + """ from ..tools.web_search import do_search query = arguments.get("query", "") @@ -72,7 +83,8 @@ class WebSearchTool(MCPTool): } try: - # Perform search + # Perform search - query parameter is used for search execution + # and prompt parameter holds the content to be searched result, sources = await do_search( prompt=query, query=query, @@ -127,7 +139,11 @@ class WebScrapeTool(MCPTool): } async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]: - """Execute web scraping""" + """Execute web scraping + + Returns: + Dict[str, Any]: Scraped content or error message + """ from ..tools.fetch_and_scrape import fetch_and_scrape from aiohttp import ClientSession @@ -202,7 +218,11 @@ class ImageGenerationTool(MCPTool): } async def execute(self, arguments: Dict[str, Any]) -> Dict[str, Any]: - """Execute image generation""" + """Execute image generation + + Returns: + Dict[str, Any]: Generated image data or error message + """ from ..client import AsyncClient from ..image import to_data_uri import base64 @@ -228,35 +248,49 @@ class ImageGenerationTool(MCPTool): height=height ) - # Get the image data - if response and hasattr(response, 'data') and response.data: - image_data = response.data[0] - - # Convert to base64 if needed - if hasattr(image_data, 'url'): - image_url = image_data.url - - # Check if it's already a data URI - if image_url.startswith('data:'): - return { - "prompt": prompt, - "model": model, - "width": width, - "height": height, - "image": image_url - } - else: - return { - "prompt": prompt, - "model": model, - "width": width, - "height": height, - "image_url": image_url - } + # Get the image data with proper validation + if not response: + return { + "error": "Image generation failed: No response from provider" + } - return { - "error": "Image generation failed: No image data in response" - } + if not hasattr(response, 'data') or not response.data: + return { + "error": "Image generation failed: No image data in response" + } + + if len(response.data) == 0: + return { + "error": "Image generation failed: Empty image data array" + } + + image_data = response.data[0] + + # Check if image_data has url attribute + if not hasattr(image_data, 'url'): + return { + "error": "Image generation failed: No URL in image data" + } + + image_url = image_data.url + + # Return result based on URL type + if image_url.startswith('data:'): + return { + "prompt": prompt, + "model": model, + "width": width, + "height": height, + "image": image_url + } + else: + return { + "prompt": prompt, + "model": model, + "width": width, + "height": height, + "image_url": image_url + } except Exception as e: return {