Skip to content

Commit

Permalink
[Platform] Add drain callback (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
rokatyy authored Dec 7, 2023
1 parent 24f2927 commit 55c20f6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
25 changes: 19 additions & 6 deletions nuclio_sdk/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(

self._control_callback = on_control_callback
self._termination_callback = None
self._drain_callback = None

async def explicit_ack(self, qualified_offset):
"""
Expand All @@ -51,6 +52,16 @@ async def explicit_ack(self, qualified_offset):
"Cannot send explicit ack since control callback was not initialized"
)

def set_drain_callback(self, callback):
"""
Register a callback to be called when the platform is draining (rebalance happening in stream).
If already registered, the callback will be replaced.
When called, the callback will be called with zero arguments.
:param callback: the callback to call when terminating
"""
self._drain_callback = callback

def set_termination_callback(self, callback):
"""
Register a callback to be called when the platform is terminating.
Expand All @@ -64,7 +75,6 @@ def set_termination_callback(self, callback):
def call_function(
self, function_name, event, node=None, timeout=None, service_name_override=None
):

# get connection from provider
connection = self._connection_provider(
self._get_function_url(function_name, service_name_override),
Expand Down Expand Up @@ -106,7 +116,7 @@ def call_function(
response_headers = {}

# get response headers as lowercase
for (name, value) in response.getheaders():
for name, value in response.getheaders():
response_headers[name.lower()] = value

# if content type exists, use it
Expand All @@ -124,7 +134,6 @@ def call_function(
)

def _get_function_url(self, function_name, service_name_override=None):

# local envs prefix namespace
if self.kind == "local":
service_name = service_name_override or "nuclio-{0}-{1}".format(
Expand All @@ -134,10 +143,14 @@ def _get_function_url(self, function_name, service_name_override=None):
service_name = service_name_override or "nuclio-{0}".format(function_name)
return "{0}:8080".format(service_name)

def _on_signal(self):
def _on_signal(self, callback_type="termination"):
"""
When a signal is received, call the termination callback as a hook before exiting
When a signal is received, call the termination/drain callback as a hook before exiting
If not set, the callback will be a no-op
:arg callback_type:str - callback type, can be "termination" or "drain"
"""
if self._termination_callback:
if callback_type == "termination" and self._termination_callback:
self._termination_callback()
elif callback_type == "drain" and self._drain_callback:
self._drain_callback()
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def get_version():
version = f.read().strip()
if version.startswith("v"):
version = version[1:]
if version == "":
return "0.0.0-dev0"
return version


Expand Down

0 comments on commit 55c20f6

Please sign in to comment.