Skip to content

Commit

Permalink
Merge pull request #2536 from hlohaus/cont
Browse files Browse the repository at this point in the history
Support continue messages in Airforce
  • Loading branch information
hlohaus authored Jan 3, 2025
2 parents 48c4183 + 6e0bc14 commit 63a81fd
Show file tree
Hide file tree
Showing 17 changed files with 284 additions and 341 deletions.
14 changes: 12 additions & 2 deletions g4f/Provider/Airforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ..typing import AsyncResult, Messages
from ..image import ImageResponse
from ..providers.response import FinishReason, Usage
from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin

Expand Down Expand Up @@ -232,17 +233,19 @@ async def generate_text(
data = {
"messages": final_messages,
"model": model,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
"stream": stream,
}
if max_tokens != 512:
data["max_tokens"] = max_tokens

async with ClientSession(headers=headers) as session:
async with session.post(cls.api_endpoint_completions, json=data, proxy=proxy) as response:
await raise_for_status(response)

if stream:
idx = 0
async for line in response.content:
line = line.decode('utf-8').strip()
if line.startswith('data: '):
Expand All @@ -255,11 +258,18 @@ async def generate_text(
chunk = cls._filter_response(delta['content'])
if chunk:
yield chunk
idx += 1
except json.JSONDecodeError:
continue
if idx == 512:
yield FinishReason("length")
else:
# Non-streaming response
result = await response.json()
if "usage" in result:
yield Usage(**result["usage"])
if result["usage"]["completion_tokens"] == 512:
yield FinishReason("length")
if 'choices' in result and result['choices']:
message = result['choices'][0].get('message', {})
content = message.get('content', '')
Expand All @@ -273,7 +283,7 @@ async def create_async_generator(
messages: Messages,
prompt: str = None,
proxy: str = None,
max_tokens: int = 4096,
max_tokens: int = 512,
temperature: float = 1,
top_p: float = 1,
stream: bool = True,
Expand Down
4 changes: 3 additions & 1 deletion g4f/Provider/Copilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def create_completion(
cls._access_token, cls._cookies = asyncio.run(get_access_token_and_cookies(cls.url, proxy))
else:
raise h
yield Parameters(**{"api_key": cls._access_token, "cookies": cls._cookies})
yield Parameters(**{"api_key": cls._access_token, "cookies": cls._cookies if isinstance(cls._cookies, dict) else {c.name: c.value for c in cls._cookies}})
websocket_url = f"{websocket_url}&accessToken={quote(cls._access_token)}"
headers = {"authorization": f"Bearer {cls._access_token}"}

Expand Down Expand Up @@ -191,6 +191,8 @@ def create_completion(
yield ImageResponse(msg.get("url"), image_prompt, {"preview": msg.get("thumbnailUrl")})
elif msg.get("event") == "done":
break
elif msg.get("event") == "replaceText":
yield msg.get("text")
elif msg.get("event") == "error":
raise RuntimeError(f"Error: {msg}")
elif msg.get("event") not in ["received", "startMessage", "citation", "partCompleted"]:
Expand Down
4 changes: 2 additions & 2 deletions g4f/Provider/needs_auth/Gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,13 +323,13 @@ async def iter_filter_base64(chunks: AsyncIterator[bytes]) -> AsyncIterator[byte
async for chunk in chunks:
if is_started:
if end_with in chunk:
yield chunk.split(end_with, 1, maxsplit=1).pop(0)
yield chunk.split(end_with, maxsplit=1).pop(0)
break
else:
yield chunk
elif search_for in chunk:
is_started = True
yield chunk.split(search_for, 1, maxsplit=1).pop()
yield chunk.split(search_for, maxsplit=1).pop()
else:
raise ValueError(f"Response: {chunk}")

Expand Down
2 changes: 1 addition & 1 deletion g4f/Provider/needs_auth/HuggingFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def create_async_generator(
messages = [m for m in messages if m["role"] == "system"] + [messages[-1]]
inputs = get_inputs(messages, model_data, model_type, do_continue)
debug.log(f"New len: {len(inputs)}")
if model_type == "gpt2" and max_new_tokens >= 1024:
if model_type == "gpt2" and max_tokens >= 1024:
params["max_new_tokens"] = 512
payload = {"inputs": inputs, "parameters": params, "stream": stream}

Expand Down
100 changes: 57 additions & 43 deletions g4f/Provider/needs_auth/OpenaiChat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
except ImportError:
has_nodriver = False

from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..base_provider import AsyncAuthedProvider, ProviderModelMixin
from ...typing import AsyncResult, Messages, Cookies, ImagesType
from ...requests.raise_for_status import raise_for_status
from ...requests import StreamSession
from ...requests import get_nodriver
from ...image import ImageResponse, ImageRequest, to_image, to_bytes, is_accepted_format
from ...errors import MissingAuthError, NoValidHarFileError
from ...providers.response import JsonConversation, FinishReason, SynthesizeData
from ...providers.response import JsonConversation, FinishReason, SynthesizeData, AuthResult
from ...providers.response import Sources, TitleGeneration, RequestLogin, Parameters
from ..helper import format_cookies
from ..openai.har_file import get_request_config
Expand Down Expand Up @@ -85,7 +85,7 @@
"user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
}

class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
class OpenaiChat(AsyncAuthedProvider, ProviderModelMixin):
"""A class for creating and managing conversations with OpenAI chat service"""

label = "OpenAI ChatGPT"
Expand All @@ -104,6 +104,20 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
_cookies: Cookies = None
_expires: int = None

@classmethod
async def on_auth_async(cls, **kwargs) -> AuthResult:
if cls.needs_auth:
async for _ in cls.login():
pass
return AuthResult(
api_key=cls._api_key,
cookies=cls._cookies or RequestConfig.cookies or {},
headers=cls._headers or RequestConfig.headers or cls.get_default_headers(),
expires=cls._expires,
proof_token=RequestConfig.proof_token,
turnstile_token=RequestConfig.turnstile_token
)

@classmethod
def get_models(cls, proxy: str = None, timeout: int = 180) -> List[str]:
if not cls.models:
Expand Down Expand Up @@ -135,7 +149,7 @@ def get_models(cls, proxy: str = None, timeout: int = 180) -> List[str]:
async def upload_images(
cls,
session: StreamSession,
headers: dict,
auth_result: AuthResult,
images: ImagesType,
) -> ImageRequest:
"""
Expand All @@ -160,8 +174,8 @@ async def upload_image(image, image_name):
"use_case": "multimodal"
}
# Post the image data to the service and get the image data
async with session.post(f"{cls.url}/backend-api/files", json=data, headers=headers) as response:
cls._update_request_args(session)
async with session.post(f"{cls.url}/backend-api/files", json=data, headers=auth_result.headers) as response:
cls._update_request_args(auth_result, session)
await raise_for_status(response, "Create file failed")
image_data = {
**data,
Expand Down Expand Up @@ -189,9 +203,9 @@ async def upload_image(image, image_name):
async with session.post(
f"{cls.url}/backend-api/files/{image_data['file_id']}/uploaded",
json={},
headers=headers
headers=auth_result.headers
) as response:
cls._update_request_args(session)
cls._update_request_args(auth_result, session)
await raise_for_status(response, "Get download url failed")
image_data["download_url"] = (await response.json())["download_url"]
return ImageRequest(image_data)
Expand Down Expand Up @@ -248,7 +262,7 @@ def create_messages(cls, messages: Messages, image_requests: ImageRequest = None
return messages

@classmethod
async def get_generated_image(cls, session: StreamSession, headers: dict, element: dict, prompt: str = None) -> ImageResponse:
async def get_generated_image(cls, auth_result: AuthResult, session: StreamSession, element: dict, prompt: str = None) -> ImageResponse:
try:
prompt = element["metadata"]["dalle"]["prompt"]
file_id = element["asset_pointer"].split("file-service://", 1)[1]
Expand All @@ -257,19 +271,20 @@ async def get_generated_image(cls, session: StreamSession, headers: dict, elemen
except Exception as e:
raise RuntimeError(f"No Image: {e.__class__.__name__}: {e}")
try:
async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=headers) as response:
cls._update_request_args(session)
async with session.get(f"{cls.url}/backend-api/files/{file_id}/download", headers=auth_result.headers) as response:
cls._update_request_args(auth_result, session)
await raise_for_status(response)
download_url = (await response.json())["download_url"]
return ImageResponse(download_url, prompt)
except Exception as e:
raise RuntimeError(f"Error in downloading image: {e}")

@classmethod
async def create_async_generator(
async def create_authed(
cls,
model: str,
messages: Messages,
auth_result: AuthResult,
proxy: str = None,
timeout: int = 180,
auto_continue: bool = False,
Expand All @@ -279,7 +294,7 @@ async def create_async_generator(
conversation: Conversation = None,
images: ImagesType = None,
return_conversation: bool = False,
max_retries: int = 3,
max_retries: int = 0,
web_search: bool = False,
**kwargs
) -> AsyncResult:
Expand All @@ -306,9 +321,6 @@ async def create_async_generator(
Raises:
RuntimeError: If an error occurs during processing.
"""
if cls.needs_auth:
async for message in cls.login(proxy, **kwargs):
yield message
async with StreamSession(
proxy=proxy,
impersonate="chrome",
Expand All @@ -319,15 +331,18 @@ async def create_async_generator(
if cls._headers is None:
cls._create_request_args(cls._cookies)
async with session.get(cls.url, headers=INIT_HEADERS) as response:
cls._update_request_args(session)
cls._update_request_args(auth_result, session)
await raise_for_status(response)
else:
print(cls._headers)
async with session.get(cls.url, headers=cls._headers) as response:
cls._update_request_args(session)
if cls._headers is None:
cls._create_request_args(auth_result.cookies, auth_result.headers)
if not cls._set_api_key(auth_result.api_key):
raise MissingAuthError("Access token is not valid")
async with session.get(cls.url, headers=auth_result.headers) as response:
cls._update_request_args(auth_result, session)
await raise_for_status(response)
try:
image_requests = await cls.upload_images(session, cls._headers, images) if images else None
image_requests = await cls.upload_images(session, auth_result, images) if images else None
except Exception as e:
debug.log("OpenaiChat: Upload image failed")
debug.log(f"{e.__class__.__name__}: {e}")
Expand All @@ -345,36 +360,36 @@ async def create_async_generator(
f"{cls.url}/backend-anon/sentinel/chat-requirements"
if cls._api_key is None else
f"{cls.url}/backend-api/sentinel/chat-requirements",
json={"p": get_requirements_token(RequestConfig.proof_token) if RequestConfig.proof_token else None},
json={"p": None if auth_result.proof_token is None else get_requirements_token(auth_result.proof_token)},
headers=cls._headers
) as response:
if response.status == 401:
cls._headers = cls._api_key = None
else:
cls._update_request_args(session)
cls._update_request_args(auth_result, session)
await raise_for_status(response)
chat_requirements = await response.json()
need_turnstile = chat_requirements.get("turnstile", {}).get("required", False)
need_arkose = chat_requirements.get("arkose", {}).get("required", False)
chat_token = chat_requirements.get("token")

if need_arkose and RequestConfig.arkose_token is None:
await get_request_config(proxy)
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
cls._set_api_key(RequestConfig.access_token)
if RequestConfig.arkose_token is None:
raise MissingAuthError("No arkose token found in .har file")
# if need_arkose and RequestConfig.arkose_token is None:
# await get_request_config(proxy)
# cls._create_request_args(auth_result.cookies, auth_result.headers)
# cls._set_api_key(auth_result.access_token)
# if auth_result.arkose_token is None:
# raise MissingAuthError("No arkose token found in .har file")

if "proofofwork" in chat_requirements:
if RequestConfig.proof_token is None:
RequestConfig.proof_token = get_config(cls._headers.get("user-agent"))
if auth_result.proof_token is None:
auth_result.proof_token = get_config(auth_result.headers.get("user-agent"))
proofofwork = generate_proof_token(
**chat_requirements["proofofwork"],
user_agent=cls._headers.get("user-agent"),
proof_token=RequestConfig.proof_token
user_agent=auth_result.headers.get("user-agent"),
proof_token=auth_result.proof_token
)
[debug.log(text) for text in (
f"Arkose: {'False' if not need_arkose else RequestConfig.arkose_token[:12]+'...'}",
#f"Arkose: {'False' if not need_arkose else auth_result.arkose_token[:12]+'...'}",
f"Proofofwork: {'False' if proofofwork is None else proofofwork[:12]+'...'}",
f"AccessToken: {'False' if cls._api_key is None else cls._api_key[:12]+'...'}",
)]
Expand Down Expand Up @@ -414,20 +429,20 @@ async def create_async_generator(
"content-type": "application/json",
"openai-sentinel-chat-requirements-token": chat_token,
}
if RequestConfig.arkose_token:
headers["openai-sentinel-arkose-token"] = RequestConfig.arkose_token
#if RequestConfig.arkose_token:
# headers["openai-sentinel-arkose-token"] = RequestConfig.arkose_token
if proofofwork is not None:
headers["openai-sentinel-proof-token"] = proofofwork
if need_turnstile and RequestConfig.turnstile_token is not None:
headers['openai-sentinel-turnstile-token'] = RequestConfig.turnstile_token
if need_turnstile and auth_result.turnstile_token is not None:
headers['openai-sentinel-turnstile-token'] = auth_result.turnstile_token
async with session.post(
f"{cls.url}/backend-anon/conversation"
if cls._api_key is None else
f"{cls.url}/backend-api/conversation",
json=data,
headers=headers
) as response:
cls._update_request_args(session)
cls._update_request_args(auth_result, session)
if response.status in (403, 404) and max_retries > 0:
max_retries -= 1
debug.log(f"Retry: Error {response.status}: {await response.text()}")
Expand Down Expand Up @@ -462,7 +477,7 @@ def replacer(match):
yield sources
if return_conversation:
yield conversation
if not history_disabled and cls._api_key is not None:
if not history_disabled and auth_result.api_key is not None:
yield SynthesizeData(cls.__name__, {
"conversation_id": conversation.conversation_id,
"message_id": conversation.message_id,
Expand Down Expand Up @@ -587,7 +602,6 @@ async def login(
try:
await get_request_config(proxy)
cls._create_request_args(RequestConfig.cookies, RequestConfig.headers)
print(RequestConfig.access_token)
if RequestConfig.access_token is not None or cls.needs_auth:
if not cls._set_api_key(RequestConfig.access_token):
raise NoValidHarFileError(f"Access token is not valid: {RequestConfig.access_token}")
Expand Down Expand Up @@ -673,9 +687,9 @@ def _create_request_args(cls, cookies: Cookies = None, headers: dict = None, use
cls._update_cookie_header()

@classmethod
def _update_request_args(cls, session: StreamSession):
def _update_request_args(cls, auth_result: AuthResult, session: StreamSession):
for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
cls._cookies[getattr(c, "key", getattr(c, "name", ""))] = c.value
auth_result.cookies[getattr(c, "key", getattr(c, "name", ""))] = c.value
cls._update_cookie_header()

@classmethod
Expand Down
Loading

0 comments on commit 63a81fd

Please sign in to comment.