Skip to content

Commit

Permalink
more experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommie Kerssies committed Mar 17, 2023
1 parent 41f3a18 commit df3225d
Show file tree
Hide file tree
Showing 10 changed files with 516 additions and 851 deletions.
5 changes: 5 additions & 0 deletions feature_extractor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
The code in this file is borrowed from PatchCore (https://github.com/amazon-science/patchcore-inspection) (Apache-2.0 License)
"""


import contextlib
from copy import deepcopy
from torch.nn import Module, Sequential
Expand Down
12 changes: 12 additions & 0 deletions jobs_ablation.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/bin/bash
name="ablation_2023_03_17_13_09_00"
n_trials="2000"
for k in "1"; do
for seed in "0" "1" "2"; do
for category in "carpet" "grid" "leather" "tile" "wood" "bottle" "cable" "capsule" "hazelnut" "metal_nut" "pill" "screw" "toothbrush" "transistor" "zipper"; do
for test_set_search in "False"; do
sbatch search.sh "${name}_n${n_trials}_k${k}_s${seed}_${category}_${test_set_search}" "${n_trials}" "${k}" "${seed}" "${category}" "${test_set_search}" "ofa_mbv3_d234_e346_k357_w1.2" "7" "6"
done
done
done
done
2 changes: 1 addition & 1 deletion jobs.sh → jobs_main_results.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ for k in "1" "2" "4"; do
for seed in "0" "1" "2"; do
for category in "carpet" "grid" "leather" "tile" "wood" "bottle" "cable" "capsule" "hazelnut" "metal_nut" "pill" "screw" "toothbrush" "transistor" "zipper"; do
for test_set_search in "False" "True"; do
sbatch search.sh "${name}_n${n_trials}_k${k}_s${seed}_${category}_${test_set_search}" "${n_trials}" "${k}" "${seed}" "${category}" "${test_set_search}"
sbatch search.sh "${name}_n${n_trials}_k${k}_s${seed}_${category}_${test_set_search}" "${n_trials}" "${k}" "${seed}" "${category}" "${test_set_search}" "" "" ""
done
done
done
Expand Down
20 changes: 15 additions & 5 deletions patchcore.py → model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from typing import Tuple
import torch
from torch.nn import Linear, Module
from torchmetrics_v1_9_3 import precision_recall_curve
from torchmetrics_v1_9_3 import _auroc_compute, precision_recall_curve
from skimage.measure import label, regionprops


class PatchCore(LightningModule):
class Model(LightningModule):
def __init__(
self,
backbone: Module,
Expand Down Expand Up @@ -132,6 +132,15 @@ def _compute_metrics(
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
y_hat = torch.cat([y_hat for y_hat, _, _ in outputs], dim=0).cpu()
y = torch.cat([y for _, y, _ in outputs], dim=0).cpu()

self.AUROC = _auroc_compute(y_hat.flatten(), y.flatten(), "binary").item()
self.partial_AUROC = _auroc_compute(
y_hat.flatten(), y.flatten(), "binary", max_fpr=0.3
).item()

precision, recall, _ = precision_recall_curve(y_hat.flatten(), y.flatten())
self.AP = -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]).item()

x_type = []
for _, _, x_type_i in outputs:
x_type.extend(x_type_i)
Expand All @@ -158,11 +167,12 @@ def _compute_metrics(
mean_region_area / region.area
) * (mean_region_count / region_count_per_type[x_type[i]])

precision, recall, _ = precision_recall_curve(
weighted_precision, weighted_recall, _ = precision_recall_curve(
y_hat.flatten(), y.flatten(), sample_weights=sample_weights.flatten()
)

self.wAP = -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]).item()
self.wAP = -torch.sum(
(weighted_recall[1:] - weighted_recall[:-1]) * weighted_precision[:-1]
).item()

def validation_epoch_end(self, outputs) -> None:
self._compute_metrics(outputs)
Expand Down
710 changes: 102 additions & 608 deletions notebook.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
## Acknowledgement
We borrow some code from PatchCore (https://github.com/amazon-science/patchcore-inspection) (Apache-2.0 License)
We borrow some code from PatchCore (https://github.com/amazon-science/patchcore-inspection) (Apache-2.0 License)
We borrow some code from TorchMetrics (https://github.com/Lightning-AI/metrics) (Apache-2.0 License)
547 changes: 331 additions & 216 deletions results.ipynb

Large diffs are not rendered by default.

49 changes: 35 additions & 14 deletions search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from optuna import create_study
from torch import set_float32_matmul_precision
from mvtec import MVTecDataModule
from patchcore import PatchCore
from model import Model
from pytorch_lightning import Trainer, seed_everything
from ofa.model_zoo import ofa_net
from feature_extractor import FeatureExtractor
Expand All @@ -17,14 +17,19 @@ def objective(
datamodule,
trainer_kwargs,
img_size,
fixed_supernet_name=None,
fixed_kernel_size=None,
fixed_expand_ratio=None,
test_set_search=False,
return_patchcore=False,
return_model=False,
):
supernet = ofa_net(
trial.suggest_categorical(
"supernet_name",
["ofa_mbv3_d234_e346_k357_w1.0", "ofa_mbv3_d234_e346_k357_w1.2"],
),
)
if fixed_supernet_name is None
else fixed_supernet_name,
pretrained=True,
)

Expand All @@ -35,11 +40,15 @@ def objective(
stage_patch_size = {}

for stage_idx, stage_blocks in enumerate(supernet.block_group_info):
stage_kernel_size[stage_idx] = trial.suggest_int(
f"stage_{stage_idx}_kernel_size", 3, 7, step=2
stage_kernel_size[stage_idx] = (
trial.suggest_int(f"stage_{stage_idx}_kernel_size", 3, 7, step=2)
if fixed_kernel_size is None
else fixed_kernel_size
)
stage_expand_ratio[stage_idx] = trial.suggest_categorical(
f"stage_{stage_idx}_expand_ratio", [3, 4, 6]
stage_expand_ratio[stage_idx] = (
trial.suggest_categorical(f"stage_{stage_idx}_expand_ratio", [3, 4, 6])
if fixed_expand_ratio is None
else fixed_expand_ratio
)
stage_block[stage_idx] = trial.suggest_categorical(
f"stage_{stage_idx}_block", [None, *stage_blocks]
Expand Down Expand Up @@ -87,7 +96,7 @@ def objective(
)
trainer = Trainer(**trainer_kwargs)

patchcore = PatchCore(
model = Model(
feature_extractor,
img_size,
patch_sizes=[
Expand All @@ -101,16 +110,22 @@ def objective(
)

info("Fitting...")
trainer.fit(patchcore, datamodule=datamodule)
trainer.fit(model, datamodule=datamodule)
if not test_set_search:
trial.set_user_attr("val_wAP", patchcore.wAP)
trial.set_user_attr("val_AUROC", model.AUROC)
trial.set_user_attr("val_partial_AUROC", model.partial_AUROC)
trial.set_user_attr("val_AP", model.AP)
trial.set_user_attr("val_wAP", model.wAP)

info("Testing...")
trainer.test(patchcore, datamodule=datamodule)
trial.set_user_attr("test_wAP", patchcore.wAP)
trainer.test(model, datamodule=datamodule)
trial.set_user_attr("test_AUROC", model.AUROC)
trial.set_user_attr("test_partial_AUROC", model.partial_AUROC)
trial.set_user_attr("test_AP", model.AP)
trial.set_user_attr("test_wAP", model.wAP)

if return_patchcore:
return patchcore
if return_model:
return model

if test_set_search:
return [flops, trial.user_attrs["test_wAP"]]
Expand Down Expand Up @@ -150,6 +165,9 @@ def main(args, trainer_kwargs):
datamodule,
trainer_kwargs,
args.img_size,
args.fixed_supernet_name,
args.fixed_kernel_size,
args.fixed_expand_ratio,
args.test_set_search,
),
n_trials=args.n_trials,
Expand All @@ -171,6 +189,9 @@ def main(args, trainer_kwargs):
parser.add_argument("--batch_size", type=int, default=391)
parser.add_argument("--img_size", type=int, default=224)
parser.add_argument("--category", type=str)
parser.add_argument("--fixed_supernet_name", type=str)
parser.add_argument("--fixed_kernel_size", type=int)
parser.add_argument("--fixed_expand_ratio", type=int)
parser.add_argument(
"--dataset_dir", type=str, default="/dataB1/tommie_kerssies/MVTec"
)
Expand Down
2 changes: 1 addition & 1 deletion search.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
source /home/tommie_kerssies/miniconda3/etc/profile.d/conda.sh
conda activate AutoPatch

srun python search.py --accelerator auto --study_name $1 --n_trials $2 --k $3 --seed $4 --category $5 --test_set_search $6
srun python search.py --accelerator auto --study_name $1 --n_trials $2 --k $3 --seed $4 --category $5 --test_set_search $6 --fixed_supernet_name $7 --fixed_kernel_size $8 --fixed_expand_ratio $9
17 changes: 12 additions & 5 deletions torchmetrics_v1_9_3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
"""
The code in this file is borrowed from TorchMetrics v1.9.3 (https://github.com/Lightning-AI/metrics/tree/v0.9.3) (Apache-2.0 License)
"""

from typing import List, Optional, Sequence, Tuple, Union, Any, no_type_check, Dict
import torch
from torch import Tensor, tensor
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_9
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
import warnings
Expand Down Expand Up @@ -1316,11 +1321,13 @@ def _auroc_compute(
f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}"
)

if _TORCH_LOWER_1_6:
raise RuntimeError(
"`max_fpr` argument requires `torch.bucketize` which"
" is not available below PyTorch version 1.6"
)
if not _TORCH_GREATER_EQUAL_1_9:
raise RuntimeError()
# if _TORCH_LOWER_1_6:
# raise RuntimeError(
# "`max_fpr` argument requires `torch.bucketize` which"
# " is not available below PyTorch version 1.6"
# )

# max_fpr parameter is only support for binary
if mode != DataType.BINARY:
Expand Down

0 comments on commit df3225d

Please sign in to comment.