Skip to content

Commit

Permalink
no session support
Browse files Browse the repository at this point in the history
  • Loading branch information
extreme4all committed Nov 5, 2024
1 parent 586ff0e commit d236aec
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 20 deletions.
17 changes: 14 additions & 3 deletions osrs/asyncio/osrs/hiscores.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ def __init__(
self.proxy = proxy
self.rate_limiter = rate_limiter

async def get(self, mode: Mode, player: str, session: ClientSession) -> PlayerStats:
async def get(
self,
player: str,
mode: Mode = Mode.OLDSCHOOL,
session: ClientSession | None = None,
) -> PlayerStats:
"""
Fetches player stats from the OSRS hiscores API.
Expand All @@ -73,7 +78,9 @@ async def get(self, mode: Mode, player: str, session: ClientSession) -> PlayerSt
url = f"{self.BASE_URL}/m={mode.value}/index_lite.json"
params = {"player": player}

async with session.get(url, proxy=self.proxy, params=params) as response:
_session = ClientSession() if session is None else session

async with _session.get(url, proxy=self.proxy, params=params) as response:
# when the HS are down it will redirect to the main page.
# after redirction it will return a 200, so we must check for redirection first
if response.history and any(r.status == 302 for r in response.history):
Expand All @@ -88,4 +95,8 @@ async def get(self, mode: Mode, player: str, session: ClientSession) -> PlayerSt
response.raise_for_status()
raise Undefined()
data = await response.json()
return PlayerStats(**data)

if session is None:
await _session.close()

return PlayerStats(**data)
39 changes: 29 additions & 10 deletions osrs/asyncio/osrs/itemdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ def __init__(

async def get_items(
self,
session: ClientSession,
alpha: str,
page: int | None = 1,
mode: Mode = Mode.OLDSCHOOL,
session: ClientSession | None = None,
category: int = 1,
) -> Items:
"""Fetch items from the RuneScape item catalog based on alphabetical filter.
Expand All @@ -124,16 +124,22 @@ async def get_items(

logger.debug(f"[GET]: {url=}, {params=}")

async with session.get(url, proxy=self.proxy, params=params) as response:
_session = ClientSession() if session is None else session

async with _session.get(url, proxy=self.proxy, params=params) as response:
response.raise_for_status()
data = await response.text()
return Items(**json.loads(data))

if session is None:
await _session.close()

return Items(**json.loads(data))

async def get_detail(
self,
session: ClientSession,
item_id: int,
mode: Mode = Mode.OLDSCHOOL,
session: ClientSession | None = None,
) -> Detail:
"""Fetch detailed information about a specific item.
Expand All @@ -152,17 +158,24 @@ async def get_detail(

logger.debug(f"[GET]: {url=}, {params=}")

async with session.get(url, proxy=self.proxy, params=params) as response:
_session = ClientSession() if session is None else session

async with _session.get(url, proxy=self.proxy, params=params) as response:
response.raise_for_status()
data = await response.text()
return Detail(**json.loads(data))

if session is None:
await _session.close()
return Detail(**json.loads(data))


class Graph:
BASE_URL = "https://secure.runescape.com"

def __init__(
self, proxy: str = "", rate_limiter: RateLimiter = RateLimiter()
self,
proxy: str = "",
rate_limiter: RateLimiter = RateLimiter(),
) -> None:
"""Initialize the Catalogue with an optional proxy and rate limiter.
Expand All @@ -176,9 +189,9 @@ def __init__(

async def get_graph(
self,
session: ClientSession,
item_id: int,
mode: Mode = Mode.OLDSCHOOL,
session: ClientSession | None = None,
) -> TradeHistory:
"""Fetch trade history graph data for a specific item.
Expand All @@ -196,7 +209,13 @@ async def get_graph(

logger.debug(f"[GET]: {url=}")

async with session.get(url, proxy=self.proxy) as response:
_session = ClientSession() if session is None else session

async with _session.get(url, proxy=self.proxy) as response:
response.raise_for_status()
data = await response.text()
return TradeHistory(**json.loads(data))

if session is None:
await _session.close()

return TradeHistory(**json.loads(data))
28 changes: 21 additions & 7 deletions osrs/asyncio/wiki/prices.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,12 @@ def __init__(
raise Exception("invalid input")
self.user_agent = inp

async def fetch_data(self, session: ClientSession, url: str, params: dict = {}):
async def fetch_data(
self,
url: str,
session: ClientSession | None = None,
params: dict = {},
):
"""
Utility method to fetch data from a specific endpoint, with ratelimiter,
and basic error handling
Expand All @@ -109,16 +114,23 @@ async def fetch_data(self, session: ClientSession, url: str, params: dict = {}):
"""
await self.rate_limiter.check()

async with session.get(url, proxy=self.proxy, params=params) as response:
_session = ClientSession() if session is None else session

async with _session.get(url, proxy=self.proxy, params=params) as response:
if response.status == 400:
error = await response.json()
raise Exception(error)
elif response.status != 200:
response.raise_for_status()
raise Undefined("Unexpected error.")
return await response.json()
data = await response.json()

if session is None:
await _session.close()

return data

async def get_mapping(self, session: ClientSession):
async def get_mapping(self, session: ClientSession | None = None):
"""
Fetches item mappings containing metadata.
Expand All @@ -138,7 +150,9 @@ async def get_mapping(self, session: ClientSession):
data = await self.fetch_data(session=session, url=url)
return [ItemMapping(**item) for item in data]

async def get_latest_prices(self, session: ClientSession) -> LatestPrices:
async def get_latest_prices(
self, session: ClientSession | None = None
) -> LatestPrices:
"""
Fetches the latest prices for all items.
Expand All @@ -160,8 +174,8 @@ async def get_latest_prices(self, session: ClientSession) -> LatestPrices:

async def get_average_prices(
self,
session: ClientSession,
interval: Interval,
session: ClientSession | None = None,
timestamp: int | None = None,
) -> AveragePrices:
"""
Expand All @@ -187,7 +201,7 @@ async def get_average_prices(
return AveragePrices(**data)

async def get_time_series(
self, session: ClientSession, item_id: int, timestep: Interval
self, item_id: int, timestep: Interval, session: ClientSession | None = None
) -> TimeSeries:
"""
Fetches time-series data for a specific item and timestep.
Expand Down
12 changes: 12 additions & 0 deletions tests/test_async_osrs_hiscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,15 @@ async def test_get_invalid():
player="This_is_not_a_valid_name",
session=session,
)


@pytest.mark.asyncio
async def test_get_default_no_session():
hiscore_instance = Hiscore()
player_stats = await hiscore_instance.get(player="extreme4all")
# Assertions to confirm the response is correct
assert isinstance(
player_stats, PlayerStats
), "The returned object is not of type PlayerStats"
assert player_stats.skills, "Skills data should not be empty"
assert player_stats.activities, "Activities data should not be empty"
18 changes: 18 additions & 0 deletions tests/test_async_osrs_itemdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,24 @@ async def test_get_graph_valid():
assert trade_history.average, "Average trade history should not be empty"


@pytest.mark.asyncio
async def test_get_graph_valid_no_session():
"""Test fetching trade history for a valid item ID"""
catalogue_instance = Graph()

item_id = 4151 # Assume this is a valid item ID
trade_history = await catalogue_instance.get_graph(
item_id=item_id, mode=ItemDBMode.OLDSCHOOL
)

# Assertions to confirm the response is correct
assert isinstance(
trade_history, TradeHistory
), "The returned object is not of type TradeHistory"
assert trade_history.daily, "Daily trade history should not be empty"
assert trade_history.average, "Average trade history should not be empty"


@pytest.mark.asyncio
async def test_get_graph_invalid():
"""Test fetching trade history for an invalid item ID"""
Expand Down

0 comments on commit d236aec

Please sign in to comment.