-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathweb_interface.py
52 lines (39 loc) · 1.56 KB
/
web_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import gradio as gr
import torch
from torchvision import transforms
from p3_analysis.interpretations import interpret_with_GradCAM
# Import model
from p2_models.models import *
# Class names in the order of the model's output columns; ``predict`` pairs
# them positionally with the model's scores.
LABELS = [
    'Atelectasis',
    'Cardiomegaly',
    'Consolidation',
    'Edema',
    'Effusion',
    'Emphysema',
    'Fibrosis',
    'Hernia',
    'Infiltration',
    'Mass',
    'No Finding',
    'Nodule',
    'Pleural_Thickening',
    'Pneumonia',
    'Pneumothorax',
]
def predict(model, inp, device):
    """Run the model on one image and map its scores onto ``LABELS``.

    :param model: model exposing ``get_infer_transforms()`` and callable on a
        batched tensor.
    :param inp: raw image as received from the UI; fed to the model's
        inference transform.
    :param device: torch device (or device string) used for inference.
    :return: dict of ``{label: float score}``, built by pairing ``LABELS``
        positionally with the first row of the model output.
    """
    preprocess = model.get_infer_transforms()
    # Prepend a batch dimension of size 1 so the model sees a batch.
    batch = preprocess(inp).to(device).unsqueeze(0)

    # NOTE(review): the model's train/eval mode is not changed here; confirm
    # the caller has put it in eval mode before serving predictions.
    with torch.no_grad():
        raw = model(batch)  # [1, num_diseases]

    first_row = raw.detach().cpu().numpy()[0]
    # Presumably per-class probabilities (gr.Label shows them as confidences);
    # confirm the model applies its own sigmoid/softmax.
    return {name: float(score) for name, score in zip(LABELS, first_row)}
def show_web_interface(model, device):
    """
    Run a Gradio web interface for the lung disease classification model.

    The page takes an input image. On "Get predictions" (or when an example
    is selected and run), it shows the per-label scores from ``predict`` and
    the image returned by ``interpret_with_GradCAM`` in a separate output.

    :param model: the classification model; it must support the calls made
        by ``predict`` (``get_infer_transforms()``, forward pass) and by
        ``interpret_with_GradCAM``.
    :param device: device: 'cuda'/'cpu'.
    """

    def predict_fn(img):
        # Without an uploaded image the transform would fail with an opaque
        # traceback; show a readable message in the UI instead.
        if img is None:
            raise gr.Error("Please upload an image first.")
        return (predict(model, img, device),
                interpret_with_GradCAM(model, img, device))

    with gr.Blocks() as demo:
        gr.Markdown("# Lung Disease Classification")
        with gr.Row():
            im = gr.Image()
            # Grad-CAM gets its own output component. Writing it back into
            # `im` made the next click classify the overlay, not the upload.
            cam = gr.Image(label="Grad-CAM", interactive=False)
            lbl = gr.Label()
        btn = gr.Button(value="Get predictions")
        btn.click(predict_fn, inputs=[im], outputs=[lbl, cam])
        gr.Markdown("## Image Examples")
        gr.Examples(
            examples=["data/images/00000001_000.png"],
            inputs=im,
            # predict_fn returns two values, so two output components are
            # required when examples are run or cached.
            outputs=[lbl, cam],
            fn=predict_fn,
        )
    demo.launch()