Add training instructions
wolny committed Jun 23, 2024
Showing 2 changed files with 144 additions and 13 deletions.

README.md

## Installation

Check out the repo and set up the conda environment:

```bash
conda env create -f environment.yaml
```

Activate the new environment:

```bash
conda activate spoco
```

## Training

This implementation uses `DistributedDataParallel` training. In order to restrict the number of GPUs used for training,
use `CUDA_VISIBLE_DEVICES`, e.g. `CUDA_VISIBLE_DEVICES=0 python spoco_train.py ...` will execute training on `GPU:0`.

### CVPPP dataset

We used the A1 subset of the [CVPPP2017_LSC challenge](https://competitions.codalab.org/competitions/18405) for training. In order to train with 10% of randomly selected objects, run:

```bash
python spoco_train.py \
--spoco \
    ...
```

`CVPPP_ROOT_DIR` is assumed to have the following subdirectories:

```
- train:
    - A1:
        - ...
    - ...
- val:
    - A1:
        - ...
    - ...
```
Since the CVPPP dataset consists of only `training` and `testing` subdirectories, one has to create the train/val split manually using the `training` subdir.
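
The repository does not ship a split script, so below is a minimal sketch of one way to create the split. It assumes the challenge files are stored as `plantXXX_rgb.png` / `plantXXX_label.png` pairs under `training/A1` and that a 90/10 train/val split is acceptable; the paths, file suffixes and split ratio are illustrative assumptions, not part of the repository.

```python
import random
import shutil
from pathlib import Path

src = Path('CVPPP_ROOT_DIR/training/A1')   # original challenge 'training' subdir
dst = Path('CVPPP_ROOT_DIR')               # will receive train/A1 and val/A1
val_fraction = 0.1                         # assumed size of the validation split

# one sample = all files sharing the same plantXXX prefix (rgb image, label, ...)
plants = sorted({p.name.split('_')[0] for p in src.glob('plant*_rgb.png')})
random.seed(42)
random.shuffle(plants)

num_val = max(1, int(val_fraction * len(plants)))
splits = {'val': plants[:num_val], 'train': plants[num_val:]}

for split, plant_ids in splits.items():
    out_dir = dst / split / 'A1'
    out_dir.mkdir(parents=True, exist_ok=True)
    for plant in plant_ids:
        # copy every file belonging to this plant (rgb image, label, ...)
        for f in src.glob(f'{plant}_*'):
            shutil.copy(f, out_dir / f.name)
```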

### Cityscapes Dataset
Download the images `leftImg8bit_trainvaltest.zip` and the labels `gtFine_trainvaltest.zip` from
the [Cityscapes website](https://www.cityscapes-dataset.com/downloads)
and extract them into the `CITYSCAPES_ROOT_DIR` of your choice, so it has the following structure:

```
- gtFine:
- train
        - ...
    - val
    - test
- leftImg8bit:
    - train
    - val
    - test
```

Create random samplings of each class using the [cityscapesampler.py](spoco/datasets/cityscapesampler.py) script:

```bash
python spoco/datasets/cityscapesampler.py --base_dir CITYSCAPES_ROOT_DIR --class_names person rider car truck bus train motorcycle bicycle
```
This will randomly sample 10%, 20%, ..., 90% of objects from the specified class(es) and save the results in dedicated directories,
e.g. `CITYSCAPES_ROOT_DIR/gtFine/train/darmstadt/car/0.4` will contain a random 40% of objects of class `car`.

One can also sample from all of the objects (people, riders, cars, trucks, buses, trains, motorcycles, bicycles) collectively by simply running:

```bash
python spoco/datasets/cityscapesampler.py --base_dir CITYSCAPES_ROOT_DIR
```

This will randomly sample 10%, 20%, ..., 90% of **all** objects and save the results in dedicated directories,
e.g. `CITYSCAPES_ROOT_DIR/gtFine/train/darmstadt/all/0.4` will contain a random 40% of all objects.

In order to train with 40% of randomly selected objects of class `car`, run:

```bash
python spoco_train.py \
--spoco \
    ...
--log-after-iters 500 --max-num-iterations 90000
```

In order to train with a random 40% of all ground truth objects, just remove the `--things-class` argument from the command above.

## Prediction

Given a model trained on the CVPPP dataset, run the prediction using the following command:

```bash
python spoco_predict.py \
--spoco \
    ...
--model-feature-maps 16 32 64 128 256 512 \
--output-dir OUTPUT_DIR
```

Results will be saved in the given `OUTPUT_DIR` directory. For each test input image `plantXXX_rgb.png` the following
3 output files will be saved in the `OUTPUT_DIR`:
* `plantXXX_rgb_predictions.h5` - HDF5 file with datasets `/raw` (input image), `/embeddings1` (output from the `f` embedding network), `/embeddings2` (output from the `g` momentum contrast network)
* `plantXXX_rgb_predictions_1.png` - output from the `f` embedding network PCA-projected into the RGB-space
* `plantXXX_rgb_predictions_2.png` - output from the `g` momentum contrast network PCA-projected into the RGB-space
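
For reference, here is a minimal sketch of how such a prediction file could be inspected with `h5py` and how an embedding volume can be PCA-projected to three channels for visualisation; the file name is illustrative, and the `(C, H, W)` embedding layout and the use of `scikit-learn` are assumptions:

```python
import h5py
import numpy as np
from sklearn.decomposition import PCA

# illustrative file name: any *_predictions.h5 from OUTPUT_DIR will do
with h5py.File('OUTPUT_DIR/plant001_rgb_predictions.h5', 'r') as f:
    raw = f['raw'][:]            # input image
    emb = f['embeddings1'][:]    # assumed (C, H, W) output of the f embedding network

# project the C-dimensional pixel embeddings onto 3 principal components
c, h, w = emb.shape
flat = emb.reshape(c, -1).T                    # (H*W, C), one point per pixel
rgb = PCA(n_components=3).fit_transform(flat)  # (H*W, 3)
# rescale each channel to [0, 1] so it can be viewed as an RGB image
rgb = (rgb - rgb.min(0)) / (rgb.max(0) - rgb.min(0) + 1e-8)
rgb = rgb.reshape(h, w, 3)
```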

And similarly for the Cityscapes dataset:

```bash
python spoco_predict.py \
--spoco \
    ...
```

## Clustering

To produce the final segmentation one needs to cluster the embeddings with an algorithm of choice. Supported
algorithms: mean-shift, HDBSCAN and Consistency Clustering (as described in the paper). E.g. to cluster CVPPP with
HDBSCAN, run:

```bash
python cluster_predictions.py \
--ds-name cvppp \
    ...
```

Where `PREDICTION_DIR` is the directory where h5 files containing network predictions are stored. Resulting segmentation
will be saved as a separate dataset (named `segmentation`) inside each of the H5 prediction files.
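
Conceptually, the clustering treats every pixel as a point in embedding space. Below is a stand-alone sketch of the HDBSCAN variant (this is not the repository's `cluster_predictions.py`; it assumes the `hdbscan` package, an illustrative file name, a `(C, H, W)` embedding layout and an arbitrary `min_cluster_size`):

```python
import h5py
import hdbscan
import numpy as np

with h5py.File('PREDICTION_DIR/plant001_rgb_predictions.h5', 'r') as f:
    emb = f['embeddings1'][:]            # assumed (C, H, W) pixel embeddings

c, h, w = emb.shape
points = emb.reshape(c, -1).T            # one C-dimensional point per pixel

# cluster pixels in embedding space; HDBSCAN marks noise/background as -1
# (for large images this can be slow; the repository's script is the reference)
labels = hdbscan.HDBSCAN(min_cluster_size=50).fit_predict(points)

# shift labels so that 0 is background and 1..N are object instances
segmentation = (labels + 1).reshape(h, w).astype(np.uint16)
```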

In order to cluster the Cityscapes predictions, extract the instances of class `car` and compute the segmentation
scores on the validation set, run:

```bash
python cluster_predictions.py \
--ds-name cityscapes \
    ...
--things-class car \
--clustering msplus --delta-var 0.5 --delta-dist 2.0
```

Where `SEM_PREDICTION_DIR` is the directory containing the semantic segmentation predictions for your validation images.
We used a pre-trained DeepLabv3 model from [here](https://github.com/VainF/DeepLabV3Plus-Pytorch).


## Training and inference on MitoEM dataset

Download the MitoEM-R dataset from https://mitoem.grand-challenge.org and split the h5 file containing 500 slices
into training and validation sets: the training file should be named `train.h5` and contain 400 slices, and the
validation file should be named `val.h5` and contain 100 slices.
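
A minimal sketch of such a split with `h5py` follows; the source file name and the `raw`/`label` dataset names inside the downloaded volume are assumptions here, so check them (e.g. with `h5ls`) before running:

```python
import h5py

src_path = 'MITOEM_ROOT_DIR/mitoem_rat.h5'   # hypothetical name of the downloaded volume

with h5py.File(src_path, 'r') as src:
    raw = src['raw'][:]        # assumed (500, H, W) image stack
    label = src['label'][:]    # assumed (500, H, W) instance labels

# first 400 slices -> train.h5, last 100 slices -> val.h5
splits = {'train.h5': slice(0, 400), 'val.h5': slice(400, 500)}
for name, sl in splits.items():
    with h5py.File(f'MITOEM_ROOT_DIR/{name}', 'w') as dst:
        dst.create_dataset('raw', data=raw[sl], compression='gzip')
        dst.create_dataset('label', data=label[sl], compression='gzip')
```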

Then create the random 1%, 5%, 10% samplings of instances using the [mitoemsampler.py](spoco/datasets/mitoemsampler.py)
script:

```bash
python spoco/datasets/mitoemsampler.py --dataset_dir MITOEM_ROOT_DIR --instance_ratios 0.01 0.05 0.1
```

This will create the following additional datasets inside `MITOEM_ROOT_DIR/train.h5`:

```
- label_0.01
- label_0.05
- label_0.1
```
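
A quick sanity check that the sampled label datasets were written (a sketch; adjust the path as needed):

```python
import h5py

with h5py.File('MITOEM_ROOT_DIR/train.h5', 'r') as f:
    # expect 'label_0.01', 'label_0.05' and 'label_0.1' next to the original 'label'
    print(sorted(f.keys()))
```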

### Training on MitoEM

In order to train with 1% of randomly selected instances, run:

```bash
python spoco_train.py \
--spoco \
--ds-name mitoem --ds-path MITOEM_ROOT_DIR \
--instance-ratio 0.01 \
--batch-size 16 \
--model-name UNet2D \
--model-feature-maps 16 32 64 128 256 512 \
--learning-rate 0.0002 \
--weight-decay 0.00001 \
--cos \
--loss-delta-var 0.5 \
--loss-delta-dist 2.0 \
--loss-unlabeled-push 1.0 \
--loss-instance-weight 1.0 \
--loss-consistency-weight 1.0 \
--kernel-threshold 0.5 \
--checkpoint-dir CHECKPOINT_DIR \
--log-after-iters 256 --max-num-iterations 100000
```

### Prediction on MitoEM

The prediction script converts the embeddings to affinities using the formula defined in the paper (see eq. 12 in
Appendix 4).
TODO
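
Until the full instructions are added, here is a rough sketch of what such a conversion could look like; the Gaussian kernel below is an illustrative assumption, so consult eq. 12 in Appendix 4 of the paper for the exact formula and kernel parameters:

```python
import numpy as np

def embeddings_to_affinities(emb, offsets=((0, 1), (1, 0)), delta_var=0.5):
    """Sketch: affinity between each pixel and its offset neighbour from embedding distance.

    `emb` is assumed to be a (C, H, W) embedding volume; the Gaussian kernel used here
    is an illustrative choice, not necessarily the exact formula from the paper.
    """
    _, h, w = emb.shape
    affs = np.zeros((len(offsets), h, w), dtype=np.float32)
    for i, (dy, dx) in enumerate(offsets):
        # distance between every pixel embedding and its (dy, dx) neighbour
        shifted = np.roll(emb, shift=(-dy, -dx), axis=(1, 2))
        dist2 = ((emb - shifted) ** 2).sum(axis=0)
        # high affinity for nearby embeddings, low affinity across instance boundaries
        affs[i] = np.exp(-dist2 / (2 * delta_var ** 2))
    # note: np.roll wraps around, so the last row/column of each offset is not meaningful
    return affs
```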

spoco/datasets/mitoemsampler.py

```python
import argparse
from pathlib import Path

import h5py
import numpy as np


def mitoem_sample_instances(label, instance_ratio, random_state):
    label_img = np.copy(label)
    # skip the first unique id, which is assumed to be the background (0)
    unique_ids = np.unique(label)[1:]
    random_state.shuffle(unique_ids)
    # pick instance_ratio objects
    num_objects = round(instance_ratio * len(unique_ids))
    assert num_objects > 0, 'No objects to sample'
    print(f'Sampled {num_objects} out of {len(unique_ids)} objects. Instance ratio: {instance_ratio}')
    # create a set of object ids left for training
    sampled_ids = set(unique_ids[:num_objects])
    # zero-out (i.e. mark as unlabeled) all objects that were not sampled
    for obj_id in unique_ids:
        if obj_id not in sampled_ids:
            label_img[label_img == obj_id] = 0
    return label_img


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, help='MitoEM dir containing train.h5 and val.h5 files',
                        required=True)
    parser.add_argument('--instance_ratios', nargs="+", type=float,
                        help='fraction of ground truth objects to sample.', required=True)

    args = parser.parse_args()

    # load label dataset from the train.h5 file
    train_file = Path(args.dataset_dir) / 'train.h5'
    assert train_file.exists(), f'{train_file} does not exist'

    with h5py.File(train_file, 'r+') as f:
        label = f['label'][:]

        for instance_ratio in args.instance_ratios:
            assert 0.0 <= instance_ratio <= 1.0, 'Instance ratio must be in [0, 1]'

            ir = float(instance_ratio)
            rs = np.random.RandomState(47)
            print(f'Sampling {ir * 100}% of mitoEM instances')

            label_sampled = mitoem_sample_instances(label, ir, rs)

            # save the sampled label dataset, e.g. label_0.1 for a 10% sampling
            f.create_dataset(f'label_{instance_ratio}', data=label_sampled, compression='gzip')
```