Skip to content

Commit

Permalink
Added support to pass params for prompt service API calls (#69)
Browse files Browse the repository at this point in the history
* Added support to pass params for prompt service API calls

* Changed from positional args to keyword args

* SDK version bump
  • Loading branch information
Deepak-Kesavan authored Jul 4, 2024
1 parent 57d10f9 commit 3ad74d2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/unstract/sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.35.0"
__version__ = "0.36.0"


def get_sdk_version():
Expand Down
57 changes: 41 additions & 16 deletions src/unstract/sdk/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,52 @@ def __init__(
tool (AbstractTool): Instance of AbstractTool
prompt_host (str): Host of platform service
prompt_host (str): Port of platform service
"""
self.tool = tool
self.base_url = SdkHelper.get_platform_base_url(prompt_host, prompt_port)
self.bearer_token = tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)

def answer_prompt(self, payload: dict[str, Any]) -> dict[str, Any]:
return self._post_call("answer-prompt", payload)

def single_pass_extraction(self, payload: dict[str, Any]) -> dict[str, Any]:
return self._post_call("single-pass-extraction", payload)

def summarize(self, payload: dict[str, Any]) -> dict[str, Any]:
return self._post_call("summarize", payload)

def _post_call(self, url_path: str, payload: dict[str, Any]) -> dict[str, Any]:
def answer_prompt(
self, payload: dict[str, Any], params: Optional[dict[str, str]] = None
) -> dict[str, Any]:
return self._post_call(
url_path="answer-prompt",
payload=payload,
params=params,
)

def single_pass_extraction(
self, payload: dict[str, Any], params: Optional[dict[str, str]] = None
) -> dict[str, Any]:
return self._post_call(
url_path="single-pass-extraction",
payload=payload,
params=params,
)

def summarize(
self, payload: dict[str, Any], params: Optional[dict[str, str]] = None
) -> dict[str, Any]:
return self._post_call(url_path="summarize", payload=payload, params=params)

def _post_call(
self,
url_path: str,
payload: dict[str, Any],
params: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
"""Invokes and communicates to prompt service to fetch response for the
prompt.
Args:
file_name (str): File in which the prompt is processed
outputs (dict): dict of all input data for the tool
tool_id (str): Unique ID of the tool to be processed
url_path (str): URL path to the service endpoint
payload (dict): Payload to send in the request body
params (dict, optional): Query parameters to include in the request
Returns:
Sample return dict:
dict: Response from the prompt service
Sample Response:
{
"status": "OK",
"error": "",
Expand All @@ -68,7 +88,12 @@ def _post_call(self, url_path: str, payload: dict[str, Any]) -> dict[str, Any]:
headers: dict[str, str] = {"Authorization": f"Bearer {self.bearer_token}"}
response: Response = Response()
try:
response = requests.post(url, json=payload, headers=headers)
response = requests.post(
url=url,
json=payload,
headers=headers,
params=params,
)
response.raise_for_status()
result["status"] = "OK"
result["structure_output"] = response.text
Expand Down

0 comments on commit 3ad74d2

Please sign in to comment.