Skip to content

Commit

Permalink
Merge pull request #2439 from hlohaus/retry
Browse files Browse the repository at this point in the history
Fix api with default providers, add unittests for RetryProvider
  • Loading branch information
hlohaus authored Nov 28, 2024
2 parents 971a01e + c31f543 commit 2cf2f86
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 91 deletions.
3 changes: 2 additions & 1 deletion etc/unittest/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
from .model import *
from .client import *
from .include import *
from .retry_provider import *

unittest.main()
unittest.main()
32 changes: 30 additions & 2 deletions etc/unittest/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,37 @@ def create_completion(

class YieldProviderMock(AsyncGeneratorProvider):
working = True

async def create_async_generator(
model, messages, stream, **kwargs
):
for message in messages:
yield message["content"]
yield message["content"]

class RaiseExceptionProviderMock(AbstractProvider):
working = True

@classmethod
def create_completion(
cls, model, messages, stream, **kwargs
):
raise RuntimeError(cls.__name__)
yield cls.__name__

class AsyncRaiseExceptionProviderMock(AsyncGeneratorProvider):
working = True

@classmethod
async def create_async_generator(
cls, model, messages, stream, **kwargs
):
raise RuntimeError(cls.__name__)
yield cls.__name__

class YieldNoneProviderMock(AsyncGeneratorProvider):
working = True

async def create_async_generator(
model, messages, stream, **kwargs
):
yield None
60 changes: 60 additions & 0 deletions etc/unittest/retry_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

import unittest

from g4f.client import AsyncClient, ChatCompletion, ChatCompletionChunk
from g4f.providers.retry_provider import IterListProvider
from .mocks import YieldProviderMock, RaiseExceptionProviderMock, AsyncRaiseExceptionProviderMock, YieldNoneProviderMock

DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]

class TestIterListProvider(unittest.IsolatedAsyncioTestCase):

async def test_skip_provider(self):
client = AsyncClient(provider=IterListProvider([RaiseExceptionProviderMock, YieldProviderMock], False))
response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
self.assertIsInstance(response, ChatCompletion)
self.assertEqual("Hello", response.choices[0].message.content)

async def test_only_one_result(self):
client = AsyncClient(provider=IterListProvider([YieldProviderMock, YieldProviderMock]))
response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
self.assertIsInstance(response, ChatCompletion)
self.assertEqual("Hello", response.choices[0].message.content)

async def test_stream_skip_provider(self):
client = AsyncClient(provider=IterListProvider([AsyncRaiseExceptionProviderMock, YieldProviderMock], False))
messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
response = client.chat.completions.create(messages, "Hello", stream=True)
async for chunk in response:
chunk: ChatCompletionChunk = chunk
self.assertIsInstance(chunk, ChatCompletionChunk)
if chunk.choices[0].delta.content is not None:
self.assertIsInstance(chunk.choices[0].delta.content, str)

async def test_stream_only_one_result(self):
client = AsyncClient(provider=IterListProvider([YieldProviderMock, YieldProviderMock], False))
messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You "]]
response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2)
response_list = []
async for chunk in response:
response_list.append(chunk)
self.assertEqual(len(response_list), 3)
for chunk in response_list:
if chunk.choices[0].delta.content is not None:
self.assertEqual(chunk.choices[0].delta.content, "You ")

async def test_skip_none(self):
client = AsyncClient(provider=IterListProvider([YieldNoneProviderMock, YieldProviderMock], False))
response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
self.assertIsInstance(response, ChatCompletion)
self.assertEqual("Hello", response.choices[0].message.content)

async def test_stream_skip_none(self):
client = AsyncClient(provider=IterListProvider([YieldNoneProviderMock, YieldProviderMock], False))
response = client.chat.completions.create(DEFAULT_MESSAGES, "", stream=True)
response_list = [chunk async for chunk in response]
self.assertEqual(len(response_list), 2)
for chunk in response_list:
if chunk.choices[0].delta.content is not None:
self.assertEqual(chunk.choices[0].delta.content, "Hello")
2 changes: 1 addition & 1 deletion g4f/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def create(
**kwargs
)

if not isinstance(response, AsyncIterator):
if not hasattr(response, "__aiter__"):
response = to_async_iterator(response)
response = async_iter_response(response, stream, response_format, max_tokens, stop)
response = async_iter_append_model_and_provider(response)
Expand Down
4 changes: 2 additions & 2 deletions g4f/client/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from ..models import Model, ModelUtils, default
from ..Provider import ProviderUtils
from ..providers.types import BaseRetryProvider, ProviderType
from ..providers.retry_provider import IterProvider
from ..providers.retry_provider import IterListProvider

def convert_to_provider(provider: str) -> ProviderType:
if " " in provider:
provider_list = [ProviderUtils.convert[p] for p in provider.split() if p in ProviderUtils.convert]
if not provider_list:
raise ProviderNotFoundError(f'Providers not found: {provider}')
provider = IterProvider(provider_list)
provider = IterListProvider(provider_list, False)
elif provider in ProviderUtils.convert:
provider = ProviderUtils.convert[provider]
elif provider:
Expand Down
6 changes: 4 additions & 2 deletions g4f/providers/base_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ async def create_async(
loop = loop or asyncio.get_running_loop()

def create_func() -> str:
return "".join(cls.create_completion(model, messages, False, **kwargs))
chunks = [str(chunk) for chunk in cls.create_completion(model, messages, False, **kwargs) if chunk]
if chunks:
return "".join(chunks)

return await asyncio.wait_for(
loop.run_in_executor(executor, create_func),
Expand Down Expand Up @@ -205,7 +207,7 @@ async def create_async(
"""
return "".join([
str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
if not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
if chunk and not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
])

@staticmethod
Expand Down
145 changes: 62 additions & 83 deletions g4f/providers/retry_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError

DEFAULT_TIMEOUT = 60

class IterListProvider(BaseRetryProvider):
def __init__(
self,
Expand Down Expand Up @@ -50,12 +52,12 @@ def create_completion(

for provider in self.get_providers(stream):
self.last_provider = provider
debug.log(f"Using {provider.__name__} provider")
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
started = True
for chunk in provider.create_completion(model, messages, stream, **kwargs):
if chunk:
yield chunk
started = True
if started:
return
except Exception as e:
Expand Down Expand Up @@ -87,13 +89,14 @@ async def create_async(

for provider in self.get_providers(False):
self.last_provider = provider
debug.log(f"Using {provider.__name__} provider")
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
return await asyncio.wait_for(
chunk = await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", 60),
timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
)
if chunk:
return chunk
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
Expand All @@ -119,16 +122,21 @@ async def create_async_generator(

for provider in self.get_providers(stream):
self.last_provider = provider
debug.log(f"Using {provider.__name__} provider")
try:
if debug.logging:
print(f"Using {provider.__name__} provider")
if not stream:
yield await provider.create_async(model, messages, **kwargs)
started = True
elif hasattr(provider, "create_async_generator"):
async for token in provider.create_async_generator(model, messages, stream=stream, **kwargs):
yield token
chunk = await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
)
if chunk:
yield chunk
started = True
elif hasattr(provider, "create_async_generator"):
async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
if chunk:
yield chunk
started = True
else:
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
Expand All @@ -137,8 +145,7 @@ async def create_async_generator(
return
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e

Expand Down Expand Up @@ -243,76 +250,48 @@ async def create_async(
else:
return await super().create_async(model, messages, **kwargs)

class IterProvider(BaseRetryProvider):
__name__ = "IterProvider"

def __init__(
self,
providers: List[BaseProvider],
) -> None:
providers.reverse()
self.providers: List[BaseProvider] = providers
self.working: bool = True
self.last_provider: BaseProvider = None

def create_completion(
self,
model: str,
messages: Messages,
stream: bool = False,
**kwargs
) -> CreateResult:
exceptions: dict = {}
started: bool = False
for provider in self.iter_providers():
if stream and not provider.supports_stream:
continue
try:
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
started = True
if started:
return
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e
raise_exceptions(exceptions)

async def create_async(
async def create_async_generator(
self,
model: str,
messages: Messages,
stream: bool = True,
**kwargs
) -> str:
exceptions: dict = {}
for provider in self.iter_providers():
try:
return await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", 60)
)
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
raise_exceptions(exceptions)
) -> AsyncResult:
exceptions = {}
started = False

def iter_providers(self) -> Iterator[BaseProvider]:
used_provider = []
try:
while self.providers:
provider = self.providers.pop()
used_provider.append(provider)
self.last_provider = provider
if debug.logging:
print(f"Using {provider.__name__} provider")
yield provider
finally:
used_provider.reverse()
self.providers = [*used_provider, *self.providers]
if self.single_provider_retry:
provider = self.providers[0]
self.last_provider = provider
for attempt in range(self.max_retries):
try:
debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
if not stream:
chunk = await asyncio.wait_for(
provider.create_async(model, messages, **kwargs),
timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
)
if chunk:
started = True
elif hasattr(provider, "create_async_generator"):
async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
if chunk:
yield chunk
started = True
else:
for token in provider.create_completion(model, messages, stream, **kwargs):
yield token
started = True
if started:
return
except Exception as e:
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
raise_exceptions(exceptions)
else:
async for chunk in super().create_async_generator(model, messages, stream, **kwargs):
yield chunk

def raise_exceptions(exceptions: dict) -> None:
"""
Expand Down

0 comments on commit 2cf2f86

Please sign in to comment.