Showing 12 changed files with 711 additions and 1 deletion.
@@ -0,0 +1,2 @@
results/
multimodal-datasets/
@@ -142,3 +142,7 @@ checkpoints*

# Pytorch models
*.pt

# pretrained models
emoberta-base/
emoberta-large/
@@ -0,0 +1,128 @@
"""Emoberta app"""
import argparse
import logging
import os

import jsonpickle
import torch
from flask import Flask, request
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


# ---------------------- GLOBAL VARIABLES ---------------------- #
emotions = [
    "neutral",
    "joy",
    "surprise",
    "anger",
    "sadness",
    "disgust",
    "fear",
]
id2emotion = {idx: emotion for idx, emotion in enumerate(emotions)}

tokenizer = None
model = None
device = None

app = Flask(__name__)
# --------------------------------------------------------------- #


def load_tokenizer_model(model_type: str, device_: str) -> None:
    """Load tokenizer and model.
    Args
    ----
    model_type: Should be either "emoberta-base" or "emoberta-large"
    device_: "cpu" or "cuda"
    """
    if "large" in model_type.lower():
        model_type = "emoberta-large"
    elif "base" in model_type.lower():
        model_type = "emoberta-base"
    else:
        raise ValueError(
            f"{model_type} is not a valid model type! Should be 'base' or 'large'."
        )

    if not os.path.isdir(model_type):
        model_type = f"tae898/{model_type}"

    global device
    device = device_
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_type)
    global model
    model = AutoModelForSequenceClassification.from_pretrained(model_type)
    model.eval()
    model.to(device)


@app.route("/", methods=["POST"])
def run_emoberta():
    """Receive everything in json!!!"""
    app.logger.debug("Receiving data ...")
    data = request.json
    data = jsonpickle.decode(data)

    text = data["text"]

    app.logger.info(f"raw text received: {text}")

    tokens = tokenizer(text, truncation=True)

    tokens["input_ids"] = torch.tensor(tokens["input_ids"]).view(1, -1).to(device)
    tokens["attention_mask"] = (
        torch.tensor(tokens["attention_mask"]).view(1, -1).to(device)
    )

    outputs = model(**tokens)
    outputs = torch.softmax(outputs["logits"].detach().cpu(), dim=1).squeeze().numpy()
    outputs = {id2emotion[idx]: prob.item() for idx, prob in enumerate(outputs)}
    app.logger.info(f"prediction: {outputs}")

    response = jsonpickle.encode(outputs)
    app.logger.info("json-pickle is done.")

    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="emoberta app.")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="host ip address",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=10006,
        help="port number",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="cpu or cuda",
    )
    parser.add_argument(
        "--model-type",
        type=str,
        default="emoberta-base",
        help="should be either emoberta-base or emoberta-large",
    )

    args = parser.parse_args()
    load_tokenizer_model(args.model_type, args.device)

    app.run(host=args.host, port=args.port)
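A note that is not part of the commit itself: because run_emoberta() calls jsonpickle.decode() on request.json, the request body must be a jsonpickle-encoded string posted as JSON, and the response decodes back into a dict mapping the seven labels in emotions to softmax probabilities. A minimal round-trip sketch under those assumptions (the probability values mentioned are purely illustrative):

import jsonpickle
import requests

# Encode the payload with jsonpickle first, mirroring what the server expects.
payload = jsonpickle.encode({"text": "This is awesome!"})
response = requests.post("http://127.0.0.1:10006/", json=payload)

# Decodes to something like {"joy": 0.9, "neutral": 0.05, ...} -- illustrative values only.
probs = jsonpickle.decode(response.text)
print(probs)

The client example in the next file does exactly this, wrapped in a small argparse CLI.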
@@ -0,0 +1,45 @@
"""
This is just a simple client example. Hack it as much as you want.
"""
import argparse
import logging

import jsonpickle
import requests

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)


def run_text(text: str, url_emoberta: str) -> None:
    """Send data to the flask server.
    Args
    ----
    text: raw text
    url_emoberta: e.g., http://127.0.0.1:10006/
    """
    data = {"text": text}

    logging.debug("sending text to server...")
    data = jsonpickle.encode(data)
    response = requests.post(url_emoberta, json=data)
    logging.info(f"got {response} from server!...")
    print(response.text)
    response = jsonpickle.decode(response.text)

    logging.info(f"emoberta results: {response}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Classify room type")
    parser.add_argument("--url-emoberta", type=str, default="http://127.0.0.1:10006/")
    parser.add_argument("--text", type=str, required=True)

    args = vars(parser.parse_args())

    run_text(**args)
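Not part of the commit, but for orientation: with the server from app.py started first (for example, python app.py --model-type emoberta-base --device cpu, using the argparse defaults above), this client would be invoked along the lines of python client.py --text "Happy birthday!", where client.py is an assumed file name since the actual name is not visible in this view; --url-emoberta only needs to be passed when the server is not at http://127.0.0.1:10006/.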
@@ -0,0 +1,15 @@
FROM python:3.8.12
ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /app

COPY emoberta-base ./emoberta-base
COPY utils ./utils
COPY app.py ./
COPY requirements-deploy.txt ./

RUN apt update
RUN python3.8 -m pip install --upgrade pip
RUN python3.8 -m pip install -r requirements-deploy.txt

CMD ["python3.8", "app.py", "--model-type", "emoberta-base", "--device", "cpu"]
@@ -0,0 +1,21 @@
FROM nvidia/cuda:11.2.2-cudnn8-runtime-ubuntu20.04

ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /app

COPY emoberta-base ./emoberta-base
COPY utils ./utils
COPY app.py ./
COPY requirements-deploy.txt ./

RUN apt update
RUN apt install software-properties-common -y
RUN add-apt-repository ppa:deadsnakes/ppa -y
RUN apt update
RUN apt install python3.8 python3.8-dev python3-pip -y
RUN python3.8 -m pip install --upgrade pip
RUN python3.8 -m pip install -r requirements-deploy.txt

CMD ["python3.8", "app.py", "--model-type", "emoberta-base", "--device", "cuda"]
@@ -0,0 +1,15 @@
FROM python:3.8.12
ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /app

COPY emoberta-large ./emoberta-large
COPY utils ./utils
COPY app.py ./
COPY requirements-deploy.txt ./

RUN apt update
RUN python3.8 -m pip install --upgrade pip
RUN python3.8 -m pip install -r requirements-deploy.txt

CMD ["python3.8", "app.py", "--model-type", "emoberta-large", "--device", "cpu"]
@@ -0,0 +1,21 @@
FROM nvidia/cuda:11.2.2-cudnn8-runtime-ubuntu20.04

ENV DEBIAN_FRONTEND=noninteractive

WORKDIR /app

COPY emoberta-large ./emoberta-large
COPY utils ./utils
COPY app.py ./
COPY requirements-deploy.txt ./

RUN apt update
RUN apt install software-properties-common -y
RUN add-apt-repository ppa:deadsnakes/ppa -y
RUN apt update
RUN apt install python3.8 python3.8-dev python3-pip -y
RUN python3.8 -m pip install --upgrade pip
RUN python3.8 -m pip install -r requirements-deploy.txt

CMD ["python3.8", "app.py", "--model-type", "emoberta-large", "--device", "cuda"]
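A closing observation rather than part of the diff: the four Dockerfiles differ only in which pretrained directory they copy (emoberta-base vs. emoberta-large) and in the base image plus --device flag (python:3.8.12 with cpu vs. nvidia/cuda:11.2.2-cudnn8-runtime-ubuntu20.04 with cuda). Running the CUDA variants with GPU access additionally assumes the host has the NVIDIA Container Toolkit set up; that is an assumption about the deployment environment, not something stated in this diff.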