Skip to content

Commit

Permalink
Merge pull request #2488 from hlohaus/ccccc
Browse files Browse the repository at this point in the history
Add get_models to GeminiPro provider
  • Loading branch information
hlohaus authored Dec 16, 2024
2 parents f317fea + ec9df59 commit 5da55b3
Show file tree
Hide file tree
Showing 12 changed files with 82 additions and 58 deletions.
21 changes: 10 additions & 11 deletions g4f/Provider/Blackbox2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
from pathlib import Path
from aiohttp import ClientSession
from typing import AsyncGenerator
from typing import AsyncIterator

from ..typing import AsyncResult, Messages
from ..image import ImageResponse
Expand All @@ -21,12 +21,12 @@ class Blackbox2(AsyncGeneratorProvider, ProviderModelMixin):
"llama-3.1-70b": "https://www.blackbox.ai/api/improve-prompt",
"flux": "https://www.blackbox.ai/api/image-generator"
}

working = True
supports_system_message = True
supports_message_history = True
supports_stream = False

default_model = 'llama-3.1-70b'
chat_models = ['llama-3.1-70b']
image_models = ['flux']
Expand Down Expand Up @@ -97,15 +97,14 @@ async def create_async_generator(
messages: Messages,
prompt: str = None,
proxy: str = None,
prompt: str = None,
max_retries: int = 3,
delay: int = 1,
max_tokens: int = None,
**kwargs
) -> AsyncGenerator[str, None]:
) -> AsyncResult:
if not model:
model = cls.default_model

if model in cls.chat_models:
async for result in cls._generate_text(model, messages, proxy, max_retries, delay, max_tokens):
yield result
Expand All @@ -125,13 +124,13 @@ async def _generate_text(
max_retries: int = 3,
delay: int = 1,
max_tokens: int = None,
) -> AsyncGenerator[str, None]:
) -> AsyncIterator[str]:
headers = cls._get_headers()

async with ClientSession(headers=headers) as session:
license_key = await cls._get_license_key(session)
api_endpoint = cls.api_endpoints[model]

data = {
"messages": messages,
"max_tokens": max_tokens,
Expand Down Expand Up @@ -162,19 +161,19 @@ async def _generate_image(
model: str,
prompt: str,
proxy: str = None
) -> AsyncGenerator[ImageResponse, None]:
) -> AsyncIterator[ImageResponse]:
headers = cls._get_headers()
api_endpoint = cls.api_endpoints[model]

async with ClientSession(headers=headers) as session:
data = {
"query": prompt
}

async with session.post(api_endpoint, headers=headers, json=data, proxy=proxy) as response:
response.raise_for_status()
response_data = await response.json()

if 'markdown' in response_data:
image_url = response_data['markdown'].split('(')[1].split(')')[0]
yield ImageResponse(images=image_url, alt=prompt)
Expand Down
11 changes: 4 additions & 7 deletions g4f/Provider/Cloudflare.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import asyncio
import json
import uuid

from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
Expand Down Expand Up @@ -37,18 +36,16 @@ def get_models(cls) -> str:
if not cls.models:
if cls._args is None:
get_running_loop(check_nested=True)
args = get_args_from_nodriver(cls.url, cookies={
'__cf_bm': uuid.uuid4().hex,
})
args = get_args_from_nodriver(cls.url)
cls._args = asyncio.run(args)
with Session(**cls._args) as session:
response = session.get(cls.models_url)
cls._args["cookies"] = merge_cookies(cls._args["cookies"] , response)
try:
raise_for_status(response)
except ResponseStatusError as e:
except ResponseStatusError:
cls._args = None
raise e
raise
json_data = response.json()
cls.models = [model.get("name") for model in json_data.get("models")]
return cls.models
Expand All @@ -64,9 +61,9 @@ async def create_async_generator(
timeout: int = 300,
**kwargs
) -> AsyncResult:
model = cls.get_model(model)
if cls._args is None:
cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
model = cls.get_model(model)
data = {
"messages": messages,
"lora": None,
Expand Down
2 changes: 1 addition & 1 deletion g4f/Provider/PollinationsAI.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class PollinationsAI(OpenaiAPI):
}

@classmethod
def get_models(cls):
def get_models(cls, **kwargs):
if not hasattr(cls, 'image_models'):
cls.image_models = []
if not cls.image_models:
Expand Down
18 changes: 10 additions & 8 deletions g4f/Provider/needs_auth/Cerebras.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class Cerebras(OpenaiAPI):
models = [
"llama3.1-70b",
"llama3.1-8b",
"llama-3.3-70b"
]
model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"}

Expand All @@ -29,14 +30,15 @@ async def create_async_generator(
cookies: Cookies = None,
**kwargs
) -> AsyncResult:
if api_key is None and cookies is None:
cookies = get_cookies(".cerebras.ai")
async with ClientSession(cookies=cookies) as session:
async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
raise_for_status(response)
data = await response.json()
if data:
api_key = data.get("user", {}).get("demoApiKey")
if api_key is None:
if cookies is None:
cookies = get_cookies(".cerebras.ai")
async with ClientSession(cookies=cookies) as session:
async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
await raise_for_status(response)
data = await response.json()
if data:
api_key = data.get("user", {}).get("demoApiKey")
async for chunk in super().create_async_generator(
model, messages,
api_base=api_base,
Expand Down
2 changes: 1 addition & 1 deletion g4f/Provider/needs_auth/DeepInfra.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class DeepInfra(OpenaiAPI):
default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct"

@classmethod
def get_models(cls):
def get_models(cls, **kwargs):
if not cls.models:
url = 'https://api.deepinfra.com/models/featured'
models = requests.get(url).json()
Expand Down
36 changes: 29 additions & 7 deletions g4f/Provider/needs_auth/GeminiPro.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,52 @@

import base64
import json
import requests
from aiohttp import ClientSession, BaseConnector

from ...typing import AsyncResult, Messages, ImagesType
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...image import to_bytes, is_accepted_format
from ...errors import MissingAuthError
from ...requests.raise_for_status import raise_for_status
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_connector
from ... import debug

class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
label = "Google Gemini API"
url = "https://ai.google.dev"

api_base = "https://generativelanguage.googleapis.com/v1beta"

working = True
supports_message_history = True
needs_auth = True

default_model = "gemini-1.5-pro"
default_vision_model = default_model
models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
fallback_models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
model_aliases = {
"gemini-flash": "gemini-1.5-flash",
"gemini-flash": "gemini-1.5-flash-8b",
}

@classmethod
def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
    """Return the model names supported by the Gemini API.

    The list is fetched once and cached on ``cls.models``; on any failure
    (network error, missing/invalid API key) the static ``fallback_models``
    list is used instead, so callers always receive a usable list.

    Args:
        api_key: Google AI Studio API key, sent as the ``key`` query
            parameter. May be ``None``, in which case the request fails
            and the fallback list is returned.
        api_base: Base URL of the Generative Language API.

    Returns:
        Sorted list of model ids that support ``generateContent``
        (e.g. ``"gemini-1.5-pro"``).
    """
    if not cls.models:
        try:
            # A bounded timeout prevents hanging forever on an
            # unresponsive endpoint (requests has no default timeout).
            response = requests.get(f"{api_base}/models?key={api_key}", timeout=10)
            raise_for_status(response)
            data = response.json()
            # The API returns names like "models/gemini-1.5-pro";
            # keep only the final path segment, and only models that
            # support text generation.
            cls.models = [
                model.get("name").rsplit("/", 1)[-1]
                for model in data.get("models")
                if "generateContent" in model.get("supportedGenerationMethods")
            ]
            cls.models.sort()
        except Exception as e:
            # Best-effort: log and fall back rather than breaking callers.
            debug.log(e)
            cls.models = cls.fallback_models
    return cls.models

@classmethod
async def create_async_generator(
cls,
Expand All @@ -34,17 +56,17 @@ async def create_async_generator(
stream: bool = False,
proxy: str = None,
api_key: str = None,
api_base: str = "https://generativelanguage.googleapis.com/v1beta",
api_base: str = api_base,
use_auth_header: bool = False,
images: ImagesType = None,
connector: BaseConnector = None,
**kwargs
) -> AsyncResult:
model = cls.get_model(model)

if not api_key:
raise MissingAuthError('Add a "api_key"')

model = cls.get_model(model, api_key=api_key, api_base=api_base)

headers = params = None
if use_auth_header:
headers = {"Authorization": f"Bearer {api_key}"}
Expand Down
8 changes: 4 additions & 4 deletions g4f/Provider/needs_auth/OpenaiAPI.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
fallback_models = []

@classmethod
def get_models(cls, api_key: str = None):
def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
if not cls.models:
try:
headers = {}
if api_key is not None:
headers["authorization"] = f"Bearer {api_key}"
response = requests.get(f"{cls.api_base}/models", headers=headers)
response = requests.get(f"{api_base}/models", headers=headers)
raise_for_status(response)
data = response.json()
cls.models = [model.get("id") for model in data.get("data")]
Expand Down Expand Up @@ -82,7 +82,7 @@ async def create_async_generator(
) as session:
data = filter_none(
messages=messages,
model=cls.get_model(model),
model=cls.get_model(model, api_key=api_key, api_base=api_base),
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
Expand Down Expand Up @@ -147,4 +147,4 @@ def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) ->
if api_key is not None else {}
),
**({} if headers is None else headers)
}
}
2 changes: 1 addition & 1 deletion g4f/Provider/needs_auth/OpenaiChat.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ async def synthesize(cls, params: dict) -> AsyncIterator[bytes]:
await cls.login()
async with StreamSession(
impersonate="chrome",
timeout=900
timeout=0
) as session:
async with session.get(
f"{cls.url}/backend-api/synthesize",
Expand Down
26 changes: 13 additions & 13 deletions g4f/gui/client/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -150,47 +150,47 @@ <h3>Settings</h3>
<label for="recognition-language" class="label" title="">Speech recognition language</label>
<input type="text" id="recognition-language" value="" placeholder="navigator.language"/>
</div>
<div class="field box">
<div class="field box hidden">
<label for="BingCreateImages-api_key" class="label" title="">Microsoft Designer in Bing:</label>
<textarea id="BingCreateImages-api_key" name="BingCreateImages[api_key]" placeholder="&quot;_U&quot; cookie"></textarea>
</div>
<div class="field box">
<div class="field box hidden">
<label for="Cerebras-api_key" class="label" title="">Cerebras Inference:</label>
<textarea id="Cerebras-api_key" name="Cerebras[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box">
<div class="field box hidden">
<label for="DeepInfra-api_key" class="label" title="">DeepInfra:</label>
<textarea id="DeepInfra-api_key" name="DeepInfra[api_key]" class="DeepInfraImage-api_key" placeholder="api_key"></textarea>
</div>
<div class="field box">
<div class="field box hidden">
<label for="GeminiPro-api_key" class="label" title="">Gemini API:</label>
<textarea id="GeminiPro-api_key" name="GeminiPro[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box">
<div class="field box hidden">
<label for="Groq-api_key" class="label" title="">Groq:</label>
<textarea id="Groq-api_key" name="Groq[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box">
<div class="field box hidden">
<label for="HuggingFace-api_key" class="label" title="">HuggingFace:</label>
<textarea id="HuggingFace-api_key" name="HuggingFace[api_key]" class="HuggingFaceAPI-api_key" placeholder="api_key"></textarea>
</div>
<div class="field box">
<label for="Openai-api_key" class="label" title="">OpenAI API:</label>
<textarea id="Openai-api_key" name="Openai[api_key]" placeholder="api_key"></textarea>
<div class="field box hidden">
<label for="OpenaiAPI-api_key" class="label" title="">OpenAI API:</label>
<textarea id="OpenaiAPI-api_key" name="OpenaiAPI[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box">
<div class="field box hidden">
<label for="OpenRouter-api_key" class="label" title="">OpenRouter:</label>
<textarea id="OpenRouter-api_key" name="OpenRouter[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box">
<div class="field box hidden">
<label for="PerplexityApi-api_key" class="label" title="">Perplexity API:</label>
<textarea id="PerplexityApi-api_key" name="PerplexityApi[api_key]" placeholder="api_key"></textarea>
</div>
<div class="field box">
<div class="field box hidden">
<label for="Replicate-api_key" class="label" title="">Replicate:</label>
<textarea id="Replicate-api_key" name="Replicate[api_key]" class="ReplicateImage-api_key" placeholder="api_key"></textarea>
</div>
<div class="field box">
<div class="field box hidden">
<label for="xAI-api_key" class="label" title="">xAI:</label>
<textarea id="xAI-api_key" name="xAI[api_key]" placeholder="api_key"></textarea>
</div>
Expand Down
3 changes: 2 additions & 1 deletion g4f/gui/client/static/css/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,8 @@ ul {
}

.settings h3 {
padding-left: 50px;
padding-left: 54px;
padding-top: 18px;
}

.buttons {
Expand Down
2 changes: 2 additions & 0 deletions g4f/gui/client/static/js/chat.v1.js
Original file line number Diff line number Diff line change
Expand Up @@ -1293,6 +1293,7 @@ const load_provider_option = (input, provider_name) => {
providerSelect.querySelectorAll(`option[data-parent="${provider_name}"]`).forEach(
(el) => el.removeAttribute("disabled")
);
settings.querySelector(`.field:has(#${provider_name}-api_key)`)?.classList.remove("hidden");
} else {
modelSelect.querySelectorAll(`option[data-providers*="${provider_name}"]`).forEach(
(el) => {
Expand All @@ -1307,6 +1308,7 @@ const load_provider_option = (input, provider_name) => {
providerSelect.querySelectorAll(`option[data-parent="${provider_name}"]`).forEach(
(el) => el.setAttribute("disabled", "disabled")
);
settings.querySelector(`.field:has(#${provider_name}-api_key)`)?.classList.add("hidden");
}
};

Expand Down
Loading

0 comments on commit 5da55b3

Please sign in to comment.