Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Platform] Add drain callback #57

Merged
merged 2 commits into from
Dec 7, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions nuclio_sdk/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(

self._control_callback = on_control_callback
self._termination_callback = None
self._drain_callback = None

async def explicit_ack(self, qualified_offset):
"""
Expand All @@ -51,6 +52,16 @@ async def explicit_ack(self, qualified_offset):
"Cannot send explicit ack since control callback was not initialized"
)

def set_drain_callback(self, callback):
"""
Register a callback to be called when the platform is draining (rebalance happening in stream).
If already registered, the callback will be replaced.
When called, the callback will be called with zero arguments.

:param callback: the callback to call when terminating
"""
self._drain_callback = callback

def set_termination_callback(self, callback):
"""
Register a callback to be called when the platform is terminating.
Expand All @@ -64,7 +75,6 @@ def set_termination_callback(self, callback):
def call_function(
self, function_name, event, node=None, timeout=None, service_name_override=None
):

# get connection from provider
connection = self._connection_provider(
self._get_function_url(function_name, service_name_override),
Expand Down Expand Up @@ -106,7 +116,7 @@ def call_function(
response_headers = {}

# get response headers as lowercase
for (name, value) in response.getheaders():
for name, value in response.getheaders():
response_headers[name.lower()] = value

# if content type exists, use it
Expand All @@ -124,7 +134,6 @@ def call_function(
)

def _get_function_url(self, function_name, service_name_override=None):

# local envs prefix namespace
if self.kind == "local":
service_name = service_name_override or "nuclio-{0}-{1}".format(
Expand All @@ -134,10 +143,14 @@ def _get_function_url(self, function_name, service_name_override=None):
service_name = service_name_override or "nuclio-{0}".format(function_name)
return "{0}:8080".format(service_name)

def _on_signal(self):
def _on_signal(self, callback_type="termination"):
"""
When a signal is received, call the termination callback as a hook before exiting
When a signal is received, call the termination/drain callback as a hook before exiting
If not set, the callback will be a no-op

:arg callback_type:str - callback type, can be "termination" or "drain"
"""
if self._termination_callback:
if callback_type == "termination" and self._termination_callback:
self._termination_callback()
elif callback_type == "drain" and self._drain_callback:
self._drain_callback()
Loading