Skip to content

Commit

Permalink
Add pagination to treasury app (#622)
Browse files Browse the repository at this point in the history
  • Loading branch information
kongzii authored Jan 7, 2025
1 parent f9b4a85 commit acc1ecd
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
ReceiveMessage,
SendPaidMessageToAnotherAgent,
)
from prediction_market_agent.db.agent_communication import fetch_unseen_transactions
from prediction_market_agent.db.agent_communication import (
fetch_count_unprocessed_transactions,
fetch_unseen_transactions,
)
from prediction_market_agent.db.long_term_memory_table_handler import (
LongTermMemories,
LongTermMemoryTableHandler,
Expand Down Expand Up @@ -178,12 +181,44 @@ def customized_chat_message(
st.markdown(parsed_function_output_body)


@st.fragment(run_every=timedelta(seconds=5))
def show_function_calls_part(nft_agent: type[DeployableAgentNFTGameAbstract]) -> None:
st.markdown(f"""### Agent's actions""")

n_total_messages = long_term_memory_table_handler(nft_agent.identifier).count()
messages_per_page = 50
if "page_number" not in st.session_state:
st.session_state.page_number = 0

col1, col2, col3 = st.columns(3)
with col1:
if st.button("Previous page", disabled=st.session_state.page_number == 0):
st.session_state.page_number -= 1
with col2:
if st.button(
"Next page",
disabled=st.session_state.page_number
== n_total_messages // messages_per_page,
):
st.session_state.page_number += 1
with col3:
st.write(f"Current page {st.session_state.page_number + 1}")

show_function_calls_part_messages(
nft_agent, messages_per_page, st.session_state.page_number
)


@st.fragment(run_every=timedelta(seconds=10))
def show_function_calls_part_messages(
nft_agent: type[DeployableAgentNFTGameAbstract],
messages_per_page: int,
page_number: int,
) -> None:
with st.spinner("Loading agent's actions..."):
calls = long_term_memory_table_handler(nft_agent.identifier).search()
calls = long_term_memory_table_handler(nft_agent.identifier).search(
offset=page_number * messages_per_page,
limit=messages_per_page,
)

if not calls:
st.markdown("No actions yet.")
Expand All @@ -205,7 +240,7 @@ def show_function_calls_part(nft_agent: type[DeployableAgentNFTGameAbstract]) ->
customized_chat_message(function_call, function_output)


@st.fragment(run_every=timedelta(seconds=5))
@st.fragment(run_every=timedelta(seconds=10))
def show_about_agent_part(nft_agent: type[DeployableAgentNFTGameAbstract]) -> None:
system_prompt = (
system_prompt_from_db.prompt
Expand Down Expand Up @@ -240,23 +275,28 @@ def show_about_agent_part(nft_agent: type[DeployableAgentNFTGameAbstract]) -> No
)
st.markdown("---")
with st.popover("Show unprocessed incoming messages"):
messages = fetch_unseen_transactions(nft_agent.wallet_address)
show_n = 10
n_messages = fetch_count_unprocessed_transactions(nft_agent.wallet_address)
messages = fetch_unseen_transactions(nft_agent.wallet_address, n=show_n)

if not messages:
st.info("No unprocessed messages")
else:
for message in messages:
st.markdown(
f"""
**From:** {message.sender}
**From:** {message.sender}
**Message:** {unzip_message_else_do_nothing(message.message.hex())}
**Value:** {wei_to_xdai(message.value)} xDai
"""
)
st.divider()

if n_messages > show_n:
st.write(f"... and another {n_messages - show_n} unprocessed messages.")


@st.fragment(run_every=timedelta(seconds=5))
@st.fragment(run_every=timedelta(seconds=10))
def show_treasury_part() -> None:
treasury_xdai_balance = get_balances(TREASURY_SAFE_ADDRESS).xdai
st.markdown(
Expand Down
9 changes: 8 additions & 1 deletion prediction_market_agent/db/agent_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,20 @@

def fetch_unseen_transactions(
consumer_address: ChecksumAddress,
n: int | None = None,
) -> list[MessageContainer]:
agent_comm_contract = AgentCommunicationContract()

count_unseen_messages = fetch_count_unprocessed_transactions(consumer_address)

message_containers = par_map(
items=list(range(count_unseen_messages)),
items=list(
range(
min(n, count_unseen_messages)
if n is not None
else count_unseen_messages
)
),
func=lambda idx: agent_comm_contract.get_at_index(
agent_address=consumer_address, idx=idx
),
Expand Down
28 changes: 20 additions & 8 deletions prediction_market_agent/db/long_term_memory_table_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import typing as t

from prediction_market_agent_tooling.tools.utils import DatetimeUTC, utcnow
from sqlalchemy.sql.elements import ColumnElement
from sqlmodel import col

from prediction_market_agent.agents.identifiers import AgentIdentifier
Expand Down Expand Up @@ -44,25 +45,36 @@ def save_answer_with_scenario(
) -> None:
return self.save_history([answer_with_scenario.model_dump()])

def search(
self,
from_: DatetimeUTC | None = None,
to_: DatetimeUTC | None = None,
limit: int | None = None,
) -> list[LongTermMemories]:
"""Searches the LongTermMemoryTableHandler for entries within a specified datetime range that match
self.task_description."""
def _get_query_filters(
self, from_: DatetimeUTC | None, to_: DatetimeUTC | None
) -> list[ColumnElement[bool]]:
query_filters = [
col(LongTermMemories.task_description) == self.task_description
]
if from_ is not None:
query_filters.append(col(LongTermMemories.datetime_) >= from_)
if to_ is not None:
query_filters.append(col(LongTermMemories.datetime_) <= to_)
return query_filters

def search(
self,
from_: DatetimeUTC | None = None,
to_: DatetimeUTC | None = None,
offset: int = 0,
limit: int | None = None,
) -> list[LongTermMemories]:
"""Searches the LongTermMemoryTableHandler for entries within a specified datetime range that match
self.task_description."""
query_filters = self._get_query_filters(from_, to_)
return self.sql_handler.get_with_filter_and_order(
query_filters=query_filters,
order_by_column_name=LongTermMemories.datetime_.key, # type: ignore[attr-defined]
order_desc=True,
offset=offset,
limit=limit,
)

def count(self) -> int:
query_filters = self._get_query_filters(None, None)
return self.sql_handler.count(query_filters=query_filters)
13 changes: 13 additions & 0 deletions prediction_market_agent/db/sql_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def get_with_filter_and_order(
query_filters: t.Sequence[ColumnElement[bool] | BinaryExpression[bool]] = (),
order_by_column_name: str | None = None,
order_desc: bool = True,
offset: int = 0,
limit: int | None = None,
) -> list[SQLModelType]:
with self.db_manager.get_session() as session:
Expand All @@ -47,7 +48,19 @@ def get_with_filter_and_order(
if order_desc
else asc(order_by_column_name)
)
if offset:
query = query.offset(offset)
if limit:
query = query.limit(limit)
results = query.all()
return results

def count(
self,
query_filters: t.Sequence[ColumnElement[bool] | BinaryExpression[bool]] = (),
) -> int:
with self.db_manager.get_session() as session:
query = session.query(self.table)
for exp in query_filters:
query = query.where(exp)
return query.count()

0 comments on commit acc1ecd

Please sign in to comment.