From d236aec3007454d62b5e9aa05c87bd80f6a54216 Mon Sep 17 00:00:00 2001 From: extreme4all <40169115+extreme4all@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:39:12 +0100 Subject: [PATCH] no session support --- osrs/asyncio/osrs/hiscores.py | 17 +++++++++++--- osrs/asyncio/osrs/itemdb.py | 39 ++++++++++++++++++++++++-------- osrs/asyncio/wiki/prices.py | 28 +++++++++++++++++------ tests/test_async_osrs_hiscore.py | 12 ++++++++++ tests/test_async_osrs_itemdb.py | 18 +++++++++++++++ 5 files changed, 94 insertions(+), 20 deletions(-) diff --git a/osrs/asyncio/osrs/hiscores.py b/osrs/asyncio/osrs/hiscores.py index ccdcd6b..51c30d3 100644 --- a/osrs/asyncio/osrs/hiscores.py +++ b/osrs/asyncio/osrs/hiscores.py @@ -49,7 +49,12 @@ def __init__( self.proxy = proxy self.rate_limiter = rate_limiter - async def get(self, mode: Mode, player: str, session: ClientSession) -> PlayerStats: + async def get( + self, + player: str, + mode: Mode = Mode.OLDSCHOOL, + session: ClientSession | None = None, + ) -> PlayerStats: """ Fetches player stats from the OSRS hiscores API. @@ -73,7 +78,9 @@ async def get(self, mode: Mode, player: str, session: ClientSession) -> PlayerSt url = f"{self.BASE_URL}/m={mode.value}/index_lite.json" params = {"player": player} - async with session.get(url, proxy=self.proxy, params=params) as response: + _session = ClientSession() if session is None else session + + async with _session.get(url, proxy=self.proxy, params=params) as response: # when the HS are down it will redirect to the main page. # after redirction it will return a 200, so we must check for redirection first if response.history and any(r.status == 302 for r in response.history): @@ -88,4 +95,8 @@ async def get(self, mode: Mode, player: str, session: ClientSession) -> PlayerSt response.raise_for_status() raise Undefined() data = await response.json() - return PlayerStats(**data) + + if session is None: + await _session.close() + + return PlayerStats(**data) diff --git a/osrs/asyncio/osrs/itemdb.py b/osrs/asyncio/osrs/itemdb.py index d7f81f8..d738cc7 100644 --- a/osrs/asyncio/osrs/itemdb.py +++ b/osrs/asyncio/osrs/itemdb.py @@ -98,10 +98,10 @@ def __init__( async def get_items( self, - session: ClientSession, alpha: str, page: int | None = 1, mode: Mode = Mode.OLDSCHOOL, + session: ClientSession | None = None, category: int = 1, ) -> Items: """Fetch items from the RuneScape item catalog based on alphabetical filter. @@ -124,16 +124,22 @@ async def get_items( logger.debug(f"[GET]: {url=}, {params=}") - async with session.get(url, proxy=self.proxy, params=params) as response: + _session = ClientSession() if session is None else session + + async with _session.get(url, proxy=self.proxy, params=params) as response: response.raise_for_status() data = await response.text() - return Items(**json.loads(data)) + + if session is None: + await _session.close() + + return Items(**json.loads(data)) async def get_detail( self, - session: ClientSession, item_id: int, mode: Mode = Mode.OLDSCHOOL, + session: ClientSession | None = None, ) -> Detail: """Fetch detailed information about a specific item. @@ -152,17 +158,24 @@ async def get_detail( logger.debug(f"[GET]: {url=}, {params=}") - async with session.get(url, proxy=self.proxy, params=params) as response: + _session = ClientSession() if session is None else session + + async with _session.get(url, proxy=self.proxy, params=params) as response: response.raise_for_status() data = await response.text() - return Detail(**json.loads(data)) + + if session is None: + await _session.close() + return Detail(**json.loads(data)) class Graph: BASE_URL = "https://secure.runescape.com" def __init__( - self, proxy: str = "", rate_limiter: RateLimiter = RateLimiter() + self, + proxy: str = "", + rate_limiter: RateLimiter = RateLimiter(), ) -> None: """Initialize the Catalogue with an optional proxy and rate limiter. @@ -176,9 +189,9 @@ def __init__( async def get_graph( self, - session: ClientSession, item_id: int, mode: Mode = Mode.OLDSCHOOL, + session: ClientSession | None = None, ) -> TradeHistory: """Fetch trade history graph data for a specific item. @@ -196,7 +209,13 @@ async def get_graph( logger.debug(f"[GET]: {url=}") - async with session.get(url, proxy=self.proxy) as response: + _session = ClientSession() if session is None else session + + async with _session.get(url, proxy=self.proxy) as response: response.raise_for_status() data = await response.text() - return TradeHistory(**json.loads(data)) + + if session is None: + await _session.close() + + return TradeHistory(**json.loads(data)) diff --git a/osrs/asyncio/wiki/prices.py b/osrs/asyncio/wiki/prices.py index 8ff895f..bfe640d 100644 --- a/osrs/asyncio/wiki/prices.py +++ b/osrs/asyncio/wiki/prices.py @@ -94,7 +94,12 @@ def __init__( raise Exception("invalid input") self.user_agent = inp - async def fetch_data(self, session: ClientSession, url: str, params: dict = {}): + async def fetch_data( + self, + url: str, + session: ClientSession | None = None, + params: dict = {}, + ): """ Utility method to fetch data from a specific endpoint, with ratelimiter, and basic error handling @@ -109,16 +114,23 @@ async def fetch_data(self, session: ClientSession, url: str, params: dict = {}): """ await self.rate_limiter.check() - async with session.get(url, proxy=self.proxy, params=params) as response: + _session = ClientSession() if session is None else session + + async with _session.get(url, proxy=self.proxy, params=params) as response: if response.status == 400: error = await response.json() raise Exception(error) elif response.status != 200: response.raise_for_status() raise Undefined("Unexpected error.") - return await response.json() + data = await response.json() + + if session is None: + await _session.close() + + return data - async def get_mapping(self, session: ClientSession): + async def get_mapping(self, session: ClientSession | None = None): """ Fetches item mappings containing metadata. @@ -138,7 +150,9 @@ async def get_mapping(self, session: ClientSession): data = await self.fetch_data(session=session, url=url) return [ItemMapping(**item) for item in data] - async def get_latest_prices(self, session: ClientSession) -> LatestPrices: + async def get_latest_prices( + self, session: ClientSession | None = None + ) -> LatestPrices: """ Fetches the latest prices for all items. @@ -160,8 +174,8 @@ async def get_latest_prices(self, session: ClientSession) -> LatestPrices: async def get_average_prices( self, - session: ClientSession, interval: Interval, + session: ClientSession | None = None, timestamp: int | None = None, ) -> AveragePrices: """ @@ -187,7 +201,7 @@ async def get_average_prices( return AveragePrices(**data) async def get_time_series( - self, session: ClientSession, item_id: int, timestep: Interval + self, item_id: int, timestep: Interval, session: ClientSession | None = None ) -> TimeSeries: """ Fetches time-series data for a specific item and timestep. diff --git a/tests/test_async_osrs_hiscore.py b/tests/test_async_osrs_hiscore.py index a88e6fa..98e590a 100644 --- a/tests/test_async_osrs_hiscore.py +++ b/tests/test_async_osrs_hiscore.py @@ -34,3 +34,15 @@ async def test_get_invalid(): player="This_is_not_a_valid_name", session=session, ) + + +@pytest.mark.asyncio +async def test_get_default_no_session(): + hiscore_instance = Hiscore() + player_stats = await hiscore_instance.get(player="extreme4all") + # Assertions to confirm the response is correct + assert isinstance( + player_stats, PlayerStats + ), "The returned object is not of type PlayerStats" + assert player_stats.skills, "Skills data should not be empty" + assert player_stats.activities, "Activities data should not be empty" diff --git a/tests/test_async_osrs_itemdb.py b/tests/test_async_osrs_itemdb.py index e400df7..a6a1d6d 100644 --- a/tests/test_async_osrs_itemdb.py +++ b/tests/test_async_osrs_itemdb.py @@ -95,6 +95,24 @@ async def test_get_graph_valid(): assert trade_history.average, "Average trade history should not be empty" +@pytest.mark.asyncio +async def test_get_graph_valid_no_session(): + """Test fetching trade history for a valid item ID""" + catalogue_instance = Graph() + + item_id = 4151 # Assume this is a valid item ID + trade_history = await catalogue_instance.get_graph( + item_id=item_id, mode=ItemDBMode.OLDSCHOOL + ) + + # Assertions to confirm the response is correct + assert isinstance( + trade_history, TradeHistory + ), "The returned object is not of type TradeHistory" + assert trade_history.daily, "Daily trade history should not be empty" + assert trade_history.average, "Average trade history should not be empty" + + @pytest.mark.asyncio async def test_get_graph_invalid(): """Test fetching trade history for an invalid item ID"""