diff --git a/workers/fund_public_goods/lib/strategy/utils/fetch_matching_projects.py b/workers/fund_public_goods/lib/strategy/utils/fetch_matching_projects.py index e8d597b..709327c 100644 --- a/workers/fund_public_goods/lib/strategy/utils/fetch_matching_projects.py +++ b/workers/fund_public_goods/lib/strategy/utils/fetch_matching_projects.py @@ -1,13 +1,8 @@ from fund_public_goods.db.entities import Projects -from fund_public_goods.db.tables.projects import get_projects_by_ids from fund_public_goods.lib.strategy.models.answer import Answer from fund_public_goods.lib.strategy.utils.get_top_matching_projects import get_top_matching_projects def fetch_matching_projects(prompt: str) -> list[tuple[Projects, list[Answer]]]: matching_projects = get_top_matching_projects(prompt)[:10] - matched_ids = [p.id for p in matching_projects] - - matching_projects_with_answers = get_projects_by_ids(matched_ids) - - return matching_projects_with_answers \ No newline at end of file + return matching_projects \ No newline at end of file diff --git a/workers/fund_public_goods/lib/strategy/utils/get_top_matching_projects.py b/workers/fund_public_goods/lib/strategy/utils/get_top_matching_projects.py index c94c9b8..18a926c 100644 --- a/workers/fund_public_goods/lib/strategy/utils/get_top_matching_projects.py +++ b/workers/fund_public_goods/lib/strategy/utils/get_top_matching_projects.py @@ -37,10 +37,10 @@ Projects: {projects} """ -def rerank_top_projects(prompt: str, projects: list[Projects]) -> list[Projects]: +def rerank_top_projects(prompt: str, projects: list[tuple[Projects, list[Answer]]]) -> list[tuple[Projects, list[Answer]]]: separator = "\n-----\n" formatted_projects = separator.join([ - f"ID: {i} - Description: {projects[i].description}\n" + f"ID: {i} - Description: {projects[i][0].description}\n" for i in range(len(projects)) ]) formatted_prompt = reranking_prompt_template.format(prompt=prompt, separator=separator, projects=formatted_projects) @@ -56,7 +56,7 @@ def rerank_top_projects(prompt: str, projects: list[Projects]) -> list[Projects] raw_response = str(response.choices[0].message.content) json_response = json.loads(raw_response) top_ids = json_response['project_ids'] - reranked_projects: list[Projects] = [] + reranked_projects: list[tuple[Projects, list[Answer]]] = [] for i in range(len(top_ids)): id = top_ids[i] @@ -82,7 +82,7 @@ def remove_duplicates_and_preserve_order(lst: list[str]) -> list[str]: return result -def get_top_matching_projects(prompt: str) -> list[Projects]: +def get_top_matching_projects(prompt: str) -> list[tuple[Projects, list[Answer]]]: env = load_env() vectorstore = Pinecone( index_name=env.pinecone_index, @@ -99,6 +99,6 @@ def get_top_matching_projects(prompt: str) -> list[Projects]: matched_projects: list[tuple[Projects, list[Answer]]] = get_projects_by_ids(total_unique_ids[:target_unique_ids]) - reranked_projects = rerank_top_projects(prompt=prompt, projects=[p for (p, _) in matched_projects]) + reranked_projects = rerank_top_projects(prompt=prompt, projects=matched_projects) return reranked_projects \ No newline at end of file