Skip to content

Commit

Permalink
Fix unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
hlohaus committed Jan 14, 2024
1 parent 32252de commit 55e5cf1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ jobs:
- name: Install requirements
- run: pip install -r requirements.txt
- name: Run tests
run: python -m etc.unittest.main
- run: python -m etc.unittest.main
52 changes: 40 additions & 12 deletions g4f/Provider/needs_auth/OpenaiChat.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ async def create(
image: ImageType = None,
**kwargs
) -> Response:
"""Create a new conversation or continue an existing one
"""
Create a new conversation or continue an existing one
Args:
prompt: The user input to start or continue the conversation
Expand Down Expand Up @@ -96,7 +97,8 @@ async def _upload_image(
headers: dict,
image: ImageType
) -> ImageResponse:
"""Upload an image to the service and get the download URL
"""
Upload an image to the service and get the download URL
Args:
session: The StreamSession object to use for requests
Expand Down Expand Up @@ -149,7 +151,8 @@ async def _upload_image(

@classmethod
async def _get_default_model(cls, session: StreamSession, headers: dict):
"""Get the default model name from the service
"""
Get the default model name from the service
Args:
session: The StreamSession object to use for requests
Expand All @@ -172,7 +175,8 @@ async def _get_default_model(cls, session: StreamSession, headers: dict):

@classmethod
def _create_messages(cls, prompt: str, image_response: ImageResponse = None):
"""Create a list of messages for the user input
"""
Create a list of messages for the user input
Args:
prompt: The user input as a string
Expand Down Expand Up @@ -222,10 +226,20 @@ async def _get_generated_image(cls, session: StreamSession, headers: dict, line:
"""
Retrieves the image response based on the message content.
:param session: The StreamSession object.
:param headers: HTTP headers for the request.
:param line: The line of response containing image information.
:return: An ImageResponse object with the image details.
This method processes the message content to extract image information and retrieves the
corresponding image from the backend API. It then returns an ImageResponse object containing
the image URL and the prompt used to generate the image.
Args:
session (StreamSession): The StreamSession object used for making HTTP requests.
headers (dict): HTTP headers to be used for the request.
line (dict): A dictionary representing the line of response that contains image information.
Returns:
ImageResponse: An object containing the image URL and the prompt, or None if no image is found.
Raises:
RuntimeError: If there'san error in downloading the image, including issues with the HTTP request or response.
"""
if "parts" not in line["message"]["content"]:
return
Expand All @@ -244,6 +258,20 @@ async def _get_generated_image(cls, session: StreamSession, headers: dict, line:

@classmethod
async def _delete_conversation(cls, session: StreamSession, headers: dict, conversation_id: str):
"""
Deletes a conversation by setting its visibility to False.
This method sends an HTTP PATCH request to update the visibility of a conversation.
It's used to effectively delete a conversation from being accessed or displayed in the future.
Args:
session (StreamSession): The StreamSession object used for making HTTP requests.
headers (dict): HTTP headers to be used for the request.
conversation_id (str): The unique identifier of the conversation to be deleted.
Raises:
HTTPError: If the HTTP request fails or returns an unsuccessful status code.
"""
async with session.patch(
f"{cls.url}/backend-api/conversation/{conversation_id}",
json={"is_visible": False},
Expand Down Expand Up @@ -283,7 +311,7 @@ async def create_async_generator(
history_disabled (bool): Flag to disable history and training.
action (str): Type of action ('next', 'continue', 'variant').
conversation_id (str): ID of the conversation.
parent_id (str): ID of the parent message.
parent_id (str): ID of the parent message.
image (ImageType): Image to include in the conversation.
response_fields (bool): Flag to include response fields in the output.
**kwargs: Additional keyword arguments.
Expand Down Expand Up @@ -397,7 +425,7 @@ async def create_async_generator(
await cls._delete_conversation(session, headers, conversation_id)

@classmethod
def _browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
def _browse_access_token(cls, proxy: str = None, timeout: int = 1200) -> tuple[str, dict]:
"""
Browse to obtain an access token.
Expand All @@ -410,7 +438,7 @@ def _browse_access_token(cls, proxy: str = None) -> tuple[str, dict]:
driver = get_browser(proxy=proxy)
try:
driver.get(f"{cls.url}/")
WebDriverWait(driver, 1200).until(EC.presence_of_element_located((By.ID, "prompt-textarea")))
WebDriverWait(driver, timeout).until(EC.presence_of_element_located((By.ID, "prompt-textarea")))
access_token = driver.execute_script(
"let session = await fetch('/api/auth/session');"
"let data = await session.json();"
Expand Down Expand Up @@ -471,7 +499,7 @@ def __init__(self, conversation_id: str, message_id: str, end_turn: EndTurn):
self.conversation_id = conversation_id
self.message_id = message_id
self._end_turn = end_turn

class Response():
"""
Class to encapsulate a response from the chat service.
Expand Down

0 comments on commit 55e5cf1

Please sign in to comment.