From ec9df59828ce6c244e155650a88c99538ddafb91 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Mon, 16 Dec 2024 01:59:30 +0100 Subject: [PATCH] Add get_models to GeminiPro provider --- g4f/Provider/Cloudflare.py | 11 ++++----- g4f/Provider/PollinationsAI.py | 2 +- g4f/Provider/needs_auth/DeepInfra.py | 2 +- g4f/Provider/needs_auth/GeminiPro.py | 36 ++++++++++++++++++++++------ g4f/Provider/needs_auth/OpenaiAPI.py | 8 +++---- g4f/providers/base_provider.py | 9 +++---- 6 files changed, 44 insertions(+), 24 deletions(-) diff --git a/g4f/Provider/Cloudflare.py b/g4f/Provider/Cloudflare.py index 7d477d57327..4416f7a315e 100644 --- a/g4f/Provider/Cloudflare.py +++ b/g4f/Provider/Cloudflare.py @@ -2,7 +2,6 @@ import asyncio import json -import uuid from ..typing import AsyncResult, Messages, Cookies from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop @@ -37,18 +36,16 @@ def get_models(cls) -> str: if not cls.models: if cls._args is None: get_running_loop(check_nested=True) - args = get_args_from_nodriver(cls.url, cookies={ - '__cf_bm': uuid.uuid4().hex, - }) + args = get_args_from_nodriver(cls.url) cls._args = asyncio.run(args) with Session(**cls._args) as session: response = session.get(cls.models_url) cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response) try: raise_for_status(response) - except ResponseStatusError as e: + except ResponseStatusError: cls._args = None - raise e + raise json_data = response.json() cls.models = [model.get("name") for model in json_data.get("models")] return cls.models @@ -64,9 +61,9 @@ async def create_async_generator( timeout: int = 300, **kwargs ) -> AsyncResult: - model = cls.get_model(model) if cls._args is None: cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies) + model = cls.get_model(model) data = { "messages": messages, "lora": None, diff --git a/g4f/Provider/PollinationsAI.py b/g4f/Provider/PollinationsAI.py index 9520674a188..31a7e7e436b 100644 --- a/g4f/Provider/PollinationsAI.py +++ b/g4f/Provider/PollinationsAI.py @@ -40,7 +40,7 @@ class PollinationsAI(OpenaiAPI): } @classmethod - def get_models(cls): + def get_models(cls, **kwargs): if not hasattr(cls, 'image_models'): cls.image_models = [] if not cls.image_models: diff --git a/g4f/Provider/needs_auth/DeepInfra.py b/g4f/Provider/needs_auth/DeepInfra.py index 35e7ca7f85e..035effb072c 100644 --- a/g4f/Provider/needs_auth/DeepInfra.py +++ b/g4f/Provider/needs_auth/DeepInfra.py @@ -14,7 +14,7 @@ class DeepInfra(OpenaiAPI): default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct" @classmethod - def get_models(cls): + def get_models(cls, **kwargs): if not cls.models: url = 'https://api.deepinfra.com/models/featured' models = requests.get(url).json() diff --git a/g4f/Provider/needs_auth/GeminiPro.py b/g4f/Provider/needs_auth/GeminiPro.py index 36c906563c6..22c9c015885 100644 --- a/g4f/Provider/needs_auth/GeminiPro.py +++ b/g4f/Provider/needs_auth/GeminiPro.py @@ -2,30 +2,52 @@ import base64 import json +import requests from aiohttp import ClientSession, BaseConnector from ...typing import AsyncResult, Messages, ImagesType -from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ...image import to_bytes, is_accepted_format from ...errors import MissingAuthError +from ...requests.raise_for_status import raise_for_status +from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import get_connector +from ... import debug class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): label = "Google Gemini API" url = "https://ai.google.dev" - + api_base = "https://generativelanguage.googleapis.com/v1beta" + working = True supports_message_history = True needs_auth = True - + default_model = "gemini-1.5-pro" default_vision_model = default_model - models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"] + fallback_models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"] model_aliases = { "gemini-flash": "gemini-1.5-flash", "gemini-flash": "gemini-1.5-flash-8b", } + @classmethod + def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]: + if not cls.models: + try: + response = requests.get(f"{api_base}/models?key={api_key}") + raise_for_status(response) + data = response.json() + cls.models = [ + model.get("name").split("/").pop() + for model in data.get("models") + if "generateContent" in model.get("supportedGenerationMethods") + ] + cls.models.sort() + except Exception as e: + debug.log(e) + cls.models = cls.fallback_models + return cls.models + @classmethod async def create_async_generator( cls, @@ -34,17 +56,17 @@ async def create_async_generator( stream: bool = False, proxy: str = None, api_key: str = None, - api_base: str = "https://generativelanguage.googleapis.com/v1beta", + api_base: str = api_base, use_auth_header: bool = False, images: ImagesType = None, connector: BaseConnector = None, **kwargs ) -> AsyncResult: - model = cls.get_model(model) - if not api_key: raise MissingAuthError('Add a "api_key"') + model = cls.get_model(model, api_key=api_key, api_base=api_base) + headers = params = None if use_auth_header: headers = {"Authorization": f"Bearer {api_key}"} diff --git a/g4f/Provider/needs_auth/OpenaiAPI.py b/g4f/Provider/needs_auth/OpenaiAPI.py index ebc4d5192d9..a61115eaab9 100644 --- a/g4f/Provider/needs_auth/OpenaiAPI.py +++ b/g4f/Provider/needs_auth/OpenaiAPI.py @@ -23,13 +23,13 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin): fallback_models = [] @classmethod - def get_models(cls, api_key: str = None): + def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]: if not cls.models: try: headers = {} if api_key is not None: headers["authorization"] = f"Bearer {api_key}" - response = requests.get(f"{cls.api_base}/models", headers=headers) + response = requests.get(f"{api_base}/models", headers=headers) raise_for_status(response) data = response.json() cls.models = [model.get("id") for model in data.get("data")] @@ -82,7 +82,7 @@ async def create_async_generator( ) as session: data = filter_none( messages=messages, - model=cls.get_model(model), + model=cls.get_model(model, api_key=api_key, api_base=api_base), temperature=temperature, max_tokens=max_tokens, top_p=top_p, @@ -147,4 +147,4 @@ def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> if api_key is not None else {} ), **({} if headers is None else headers) - } + } \ No newline at end of file diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py index 0cdcde90e67..e2c356e338b 100644 --- a/g4f/providers/base_provider.py +++ b/g4f/providers/base_provider.py @@ -243,19 +243,20 @@ class ProviderModelMixin: last_model: str = None @classmethod - def get_models(cls) -> list[str]: + def get_models(cls, **kwargs) -> list[str]: if not cls.models and cls.default_model is not None: return [cls.default_model] return cls.models @classmethod - def get_model(cls, model: str) -> str: + def get_model(cls, model: str, **kwargs) -> str: if not model and cls.default_model is not None: model = cls.default_model elif model in cls.model_aliases: model = cls.model_aliases[model] - elif model not in cls.get_models() and cls.models: - raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") + else: + if model not in cls.get_models(**kwargs) and cls.models: + raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") cls.last_model = model debug.last_model = model return model \ No newline at end of file