-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
85 lines (70 loc) · 2.35 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
import torch
import numpy as np
from PIL import Image
import albumentations
import pandas as pd
from lightning_model import LitClassification
# Human-readable class names, taken from the "Labels" column of the CSV.
# Predicted class indices are used to index into this list.
df = pd.read_csv("imagenet_class_labels.csv")
class_labels = df["Labels"].tolist()

# Build the network, restore the trained weights onto the CPU, and switch to
# evaluation mode for inference.
model = LitClassification()
checkpoint = torch.load(
    "bestmodel-epoch=46-monitor-val_acc1=63.7760009765625.ckpt",
    map_location="cpu",
)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
# Inference-time transform pipeline: resize to 224x224, then normalize each
# channel with the given mean/std (pixel values scaled by max_pixel_value).
_resize = albumentations.Resize(224, 224, p=1)
_normalize = albumentations.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    max_pixel_value=255.0,
    p=1.0,
)
valid_aug = albumentations.Compose([_resize, _normalize], p=1.0)
def preprocess_image(image):
    """Turn a PIL image into a normalized, batched float tensor.

    The image is forced to RGB, trimmed by 4% on every edge (so the central
    92% of each dimension is kept), run through ``valid_aug``, and returned
    in channels-first layout with a leading batch axis of size 1.

    Args:
        image: A PIL image (anything exposing ``.mode`` and ``.convert``).

    Returns:
        A ``torch.float`` tensor with a leading batch dimension.
    """
    # Grayscale / RGBA / palette inputs are normalized to three channels.
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    pixels = np.array(rgb)

    # Central crop: drop a 4% margin from each side of both axes.
    height, width, _ = pixels.shape
    top, bottom = int(0.04 * height), int(0.96 * height)
    left, right = int(0.04 * width), int(0.96 * width)
    cropped = pixels[top:bottom, left:right, :]

    # Resize + normalize via the module-level albumentations pipeline.
    normalized = valid_aug(image=cropped)["image"]

    # HWC -> CHW, then prepend the batch axis.
    channels_first = normalized.transpose(2, 0, 1)
    return torch.tensor(channels_first, dtype=torch.float).unsqueeze(0)
def predict(image):
    """Classify an image and return its five most probable class labels.

    Args:
        image: The PIL image supplied by the Gradio image input, or ``None``
            when the form was submitted without an image.

    Returns:
        A dict mapping class label to softmax probability for the five
        highest-scoring classes.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Without this guard a submission with no image fails inside
    # preprocessing with an opaque AttributeError on ``image.mode``.
    if image is None:
        raise gr.Error("Please upload an image to classify.")

    batch = preprocess_image(image)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(batch)
        # NOTE(review): assumes the model returns (batch, num_classes) logits.
        probabilities = torch.nn.functional.softmax(logits, dim=1)

    # Five largest probabilities along the last dimension.
    top5_prob, top5_indices = torch.topk(probabilities, 5)

    # Row 0 is the single image in the batch. ``.tolist()`` yields plain
    # Python floats/ints, so the label list is indexed with real ints rather
    # than 0-d tensors.
    return {
        class_labels[idx]: prob
        for prob, idx in zip(top5_prob[0].tolist(), top5_indices[0].tolist())
    }
# Gradio UI wiring: a PIL image goes into ``predict`` and the resulting
# label -> probability mapping is shown as a top-5 label widget.
_image_input = gr.Image(type="pil")
_label_output = gr.Label(num_top_classes=5)

iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    examples=["sample_imgs/stock-photo-large-hot-dog.jpg"],
    title="ImageNet Classification with ResNet50",
    description="Upload an image to classify it into one of 1000 ImageNet categories.",
)

# Only start the web server when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()