Skip to content

Commit

Permalink
Add get_models to GeminiPro provider
Browse files Browse the repository at this point in the history
  • Loading branch information
hlohaus committed Dec 16, 2024
1 parent 68c7a92 commit ec9df59
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 24 deletions.
11 changes: 4 additions & 7 deletions g4f/Provider/Cloudflare.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import asyncio
import json
import uuid

from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
Expand Down Expand Up @@ -37,18 +36,16 @@ def get_models(cls) -> str:
if not cls.models:
if cls._args is None:
get_running_loop(check_nested=True)
args = get_args_from_nodriver(cls.url, cookies={
'__cf_bm': uuid.uuid4().hex,
})
args = get_args_from_nodriver(cls.url)
cls._args = asyncio.run(args)
with Session(**cls._args) as session:
response = session.get(cls.models_url)
cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response)
try:
raise_for_status(response)
except ResponseStatusError as e:
except ResponseStatusError:
cls._args = None
raise e
raise
json_data = response.json()
cls.models = [model.get("name") for model in json_data.get("models")]
return cls.models
Expand All @@ -64,9 +61,9 @@ async def create_async_generator(
timeout: int = 300,
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
if cls._args is None:
cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
model = cls.get_model(model)
data = {
"messages": messages,
"lora": None,
Expand Down
2 changes: 1 addition & 1 deletion g4f/Provider/PollinationsAI.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class PollinationsAI(OpenaiAPI):
}

@classmethod
def get_models(cls):
def get_models(cls, **kwargs):
if not hasattr(cls, 'image_models'):
cls.image_models = []
if not cls.image_models:
Expand Down
2 changes: 1 addition & 1 deletion g4f/Provider/needs_auth/DeepInfra.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class DeepInfra(OpenaiAPI):
default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct"

@classmethod
def get_models(cls):
def get_models(cls, **kwargs):
if not cls.models:
url = 'https://api.deepinfra.com/models/featured'
models = requests.get(url).json()
Expand Down
36 changes: 29 additions & 7 deletions g4f/Provider/needs_auth/GeminiPro.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,52 @@

import base64
import json
import requests
from aiohttp import ClientSession, BaseConnector

from ...typing import AsyncResult, Messages, ImagesType
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...image import to_bytes, is_accepted_format
from ...errors import MissingAuthError
from ...requests.raise_for_status import raise_for_status
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_connector
from ... import debug

class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
label = "Google Gemini API"
url = "https://ai.google.dev"

api_base = "https://generativelanguage.googleapis.com/v1beta"

working = True
supports_message_history = True
needs_auth = True

default_model = "gemini-1.5-pro"
default_vision_model = default_model
models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
fallback_models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
model_aliases = {
"gemini-flash": "gemini-1.5-flash",
"gemini-flash": "gemini-1.5-flash-8b",
}

@classmethod
def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
    """Return the model names available from the Gemini API.

    The list is fetched once from ``{api_base}/models`` and cached on the
    class; only models supporting the ``generateContent`` method are kept,
    sorted alphabetically. On any error (network failure, bad key, …) the
    error is logged and ``fallback_models`` is used instead.
    """
    if cls.models:
        return cls.models
    try:
        response = requests.get(f"{api_base}/models?key={api_key}")
        raise_for_status(response)
        data = response.json()
        names = [
            entry.get("name").split("/")[-1]
            for entry in data.get("models")
            if "generateContent" in entry.get("supportedGenerationMethods")
        ]
        names.sort()
        cls.models = names
    except Exception as e:
        # Best effort: never fail model discovery, fall back to the static list.
        debug.log(e)
        cls.models = cls.fallback_models
    return cls.models

@classmethod
async def create_async_generator(
cls,
Expand All @@ -34,17 +56,17 @@ async def create_async_generator(
stream: bool = False,
proxy: str = None,
api_key: str = None,
api_base: str = "https://generativelanguage.googleapis.com/v1beta",
api_base: str = api_base,
use_auth_header: bool = False,
images: ImagesType = None,
connector: BaseConnector = None,
**kwargs
) -> AsyncResult:
model = cls.get_model(model)

if not api_key:
raise MissingAuthError('Add a "api_key"')

model = cls.get_model(model, api_key=api_key, api_base=api_base)

headers = params = None
if use_auth_header:
headers = {"Authorization": f"Bearer {api_key}"}
Expand Down
8 changes: 4 additions & 4 deletions g4f/Provider/needs_auth/OpenaiAPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
fallback_models = []

@classmethod
def get_models(cls, api_key: str = None):
def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
if not cls.models:
try:
headers = {}
if api_key is not None:
headers["authorization"] = f"Bearer {api_key}"
response = requests.get(f"{cls.api_base}/models", headers=headers)
response = requests.get(f"{api_base}/models", headers=headers)
raise_for_status(response)
data = response.json()
cls.models = [model.get("id") for model in data.get("data")]
Expand Down Expand Up @@ -82,7 +82,7 @@ async def create_async_generator(
) as session:
data = filter_none(
messages=messages,
model=cls.get_model(model),
model=cls.get_model(model, api_key=api_key, api_base=api_base),
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
Expand Down Expand Up @@ -147,4 +147,4 @@ def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) ->
if api_key is not None else {}
),
**({} if headers is None else headers)
}
}
9 changes: 5 additions & 4 deletions g4f/providers/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,19 +243,20 @@ class ProviderModelMixin:
last_model: str = None

@classmethod
def get_models(cls, **kwargs) -> list[str]:
    """Return the provider's known model names.

    Falls back to a one-item list containing ``default_model`` when no
    models have been populated yet. Extra keyword arguments (for example
    ``api_key``/``api_base`` used by subclasses that fetch models from an
    API) are accepted for interface compatibility and ignored here.
    """
    # NOTE(review): the diff rendering left both the old and the new `def`
    # lines in place; resolved to the committed `**kwargs` signature.
    if not cls.models and cls.default_model is not None:
        return [cls.default_model]
    return cls.models

@classmethod
def get_model(cls, model: str, **kwargs) -> str:
    """Resolve *model* to a concrete model name for this provider.

    Resolution order: empty/falsy input maps to ``default_model``; a known
    alias is expanded via ``model_aliases``; otherwise the name must appear
    in ``get_models(**kwargs)`` (checked only when a model list exists),
    or :class:`ModelNotSupportedError` is raised. The chosen name is
    recorded on ``cls.last_model`` and ``debug.last_model`` before being
    returned. Extra keyword arguments (e.g. ``api_key``) are forwarded to
    ``get_models`` so dynamic providers can fetch their list on demand.
    """
    # NOTE(review): the diff rendering kept both the removed `elif` form and
    # the added `else`/`if` form; resolved to the committed new version.
    if not model and cls.default_model is not None:
        model = cls.default_model
    elif model in cls.model_aliases:
        model = cls.model_aliases[model]
    else:
        if model not in cls.get_models(**kwargs) and cls.models:
            raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
    cls.last_model = model
    debug.last_model = model
    return model

0 comments on commit ec9df59

Please sign in to comment.