Skip to content

Commit

Permalink
Refactor inference server logic
Browse files Browse the repository at this point in the history
  • Loading branch information
desi-ivanov committed Nov 4, 2022
1 parent 4e34cf3 commit 5594d6d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 31 deletions.
46 changes: 17 additions & 29 deletions srv.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import random
from srv_predictor import Predictor
from flask import Flask, request, jsonify, make_response
import threading
import time
import os
EXIT_SECONDS_AFTER_INACTIVITY = 60 * 5
import time

predictor = Predictor(
os.environ.get('SG_SEG_DATASET_PATH'),
Expand All @@ -24,26 +23,25 @@ def infer():
return _build_cors_preflight_response()
elif request.method == "POST":
content = request.json
in_img = content['img'].split(',')[-1]
raw_z = content['z']
return _corsify_actual_response(jsonify(predictor.infer(in_img, raw_z)))


@app.route("/example", methods=['POST', 'OPTIONS'])
def example():
latest_request_ptr[0] = time.time()
if request.method == "OPTIONS":
return _build_cors_preflight_response()
elif request.method == "POST":
content = request.json
raw_z = content['z']
return _corsify_actual_response(jsonify(predictor.example(raw_z)))
if content['action'] == 'infer':
in_img = content['img'].split(',')[-1]
raw_z = content['z']
return _corsify_actual_response(jsonify(predictor.infer(in_img, raw_z)))
elif content['action'] == 'example':
raw_z = content['z'] if 'z' in content else [random.gauss(0, 1) for _ in range(512)]
return _corsify_actual_response(jsonify(predictor.example(raw_z)))
else:
return _corsify_actual_response(jsonify({'error': 'unknown action'}))

@app.route("/health", methods=['GET'])
def health():
latest_request_ptr[0] = time.time()
def ping():
return "OK"

@app.route("/latest_request", methods=['GET'])
def latest_request():
return str(latest_request_ptr[0])


def _build_cors_preflight_response():
response = make_response()
response.headers["Access-Control-Allow-Origin"] = "*"
Expand All @@ -54,13 +52,3 @@ def _build_cors_preflight_response():
def _corsify_actual_response(response):
response.headers.add("Access-Control-Allow-Origin", "*")
return response


def keep_checking_if_inactive():
print(f'CHECKING {time.time() - latest_request_ptr[0]}s since last request')
if time.time() - latest_request_ptr[0] > EXIT_SECONDS_AFTER_INACTIVITY:
os._exit(0)
t = threading.Timer(10, keep_checking_if_inactive)
t.start()

keep_checking_if_inactive()
16 changes: 14 additions & 2 deletions srv_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,23 @@
import legacy
from training.dataset import ZipDataset, WithSegSplit

class Wrapper(torch.nn.Module):
def __init__(self, E, G):
super().__init__()
self.E = E
self.G = G
def forward(self, z, x, c, **G_kwargs):
skips = self.E(x)
return self.G(z, c, imgs=skips, **G_kwargs)

def load_ae(model_path):
with dnnlib.util.open_url(model_path) as f:
data = legacy.load_network_pkl(f)
model = (data['AE'] if 'AE' in data else data['VAE']).requires_grad_(False).cuda().eval()
return model
if 'AE' not in data and 'VAE' not in data:
model = Wrapper(data['E'], data['G'])
else:
model = (data['AE'] if 'AE' in data else data['VAE'])
return model.requires_grad_(False).cuda().eval()

def load_clf(clf_path):
clf = resnet18()
Expand Down

0 comments on commit 5594d6d

Please sign in to comment.