Skip to content

Commit

Permalink
[ADD] training code
Browse files Browse the repository at this point in the history
  • Loading branch information
markkua committed May 25, 2024
1 parent ed281c7 commit 21c8843
Show file tree
Hide file tree
Showing 53 changed files with 2,677 additions and 61 deletions.
43 changes: 40 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ The default settings are optimized for the best result. However, the behavior of
- `--processing_res`: the processing resolution; set as 0 to process the input resolution directly. When unassigned (`None`), will read default setting from model config. Default: ~~768~~ `None`.
- `--output_processing_res`: produce output at the processing resolution instead of upsampling it to the input resolution. Default: False.
- `--resample_method`: resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`.
- `--resample_method`: the resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic`, or `nearest`. Default: `bilinear`.
- `--half_precision` or `--fp16`: Run with half-precision (16-bit float) to reduce VRAM usage, might lead to suboptimal result.
- `--half_precision` or `--fp16`: Run with half-precision (16-bit float) to reduce VRAM usage, which might lead to suboptimal results.
- `--seed`: Random seed can be set to ensure additional reproducibility. Default: None (unseeded). Note: forcing `--batch_size 1` helps to increase reproducibility. To ensure full reproducibility, [deterministic mode](https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms) needs to be used.
- `--batch_size`: Batch size of repeated inference. Default: 0 (best value determined automatically).
- `--color_map`: [Colormap](https://matplotlib.org/stable/users/explain/colors/colormaps.html) used to colorize the depth prediction. Default: Spectral. Set to `None` to skip colored depth map generation.
Expand Down Expand Up @@ -196,7 +196,7 @@ python run.py \
--output_dir output/in-the-wild_example
```
## 🦿 Evaluation on test datasets
## 🦿 Evaluation on test datasets <a name="evaluation"></a>
Install additional dependencies:
Expand Down Expand Up @@ -224,6 +224,43 @@ bash script/eval/12_eval_nyu.sh
Note: although the seed has been set, the results might still be slightly different on different hardware.
## 🏋️ Training
Based on the previously created environment, install extended requirements:
```bash
pip install -r requirements++.txt -r requirements+.txt -r requirements.txt
```
Set environment parameters for the data directory:
```bash
export BASE_DATA_DIR=YOUR_DATA_DIR # directory of training data
export BASE_CKPT_DIR=YOUR_CHECKPOINT_DIR # directory of pretrained checkpoint
```
Download Stable Diffusion v2 [checkpoint](https://huggingface.co/stabilityai/stable-diffusion-2) into `${BASE_CKPT_DIR}`
Prepare for [Hypersim](https://github.com/apple/ml-hypersim) and [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/) datasets and save into `${BASE_DATA_DIR}`. Please refer to [this README](script/dataset_preprocess/hypersim/README.md) for Hypersim preprocessing.
Run training script
```bash
python train.py --config config/train_marigold.yaml
```
Resume from a checkpoint, e.g.
```bash
python train.py --resume_from output/marigold_base/checkpoint/latest
```
Evaluating results
Only the U-Net is updated and saved during training. To use the inference pipeline with your training result, replace `unet` folder in Marigold checkpoints with that in the `checkpoint` output folder. Then refer to [this section](#evaluation) for evaluation.
**Note**: Although random seeds have been set, the training result might be slightly different on different hardwares. It's recommended to train without interruption.

## ✏️ Contributing

Please refer to [this](CONTRIBUTING.md) instruction.
Expand Down
4 changes: 4 additions & 0 deletions config/dataset/data_hypersim_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
name: hypersim
disp_name: hypersim_train
dir: hypersim/hypersim_processed_train.tar
filenames: data_split/hypersim/filename_list_train_filtered.txt
4 changes: 4 additions & 0 deletions config/dataset/data_hypersim_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
name: hypersim
disp_name: hypersim_val
dir: hypersim/hypersim_processed_val.tar
filenames: data_split/hypersim/filename_list_val_filtered.txt
6 changes: 6 additions & 0 deletions config/dataset/data_kitti_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
name: kitti
disp_name: kitti_val800_from_eigen_train
dir: kitti/kitti_sampled_val_800.tar
filenames: data_split/kitti/eigen_val_from_train_800.txt
kitti_bm_crop: true
valid_mask_crop: eigen
5 changes: 5 additions & 0 deletions config/dataset/data_nyu_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name: nyu_v2
disp_name: nyu_train_full
dir: nyuv2/nyu_labeled_extracted.tar
filenames: data_split/nyu/labeled/filename_list_train.txt
eigen_valid_mask: true
6 changes: 6 additions & 0 deletions config/dataset/data_vkitti_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
name: vkitti
disp_name: vkitti_train
dir: vkitti/vkitti.tar
filenames: data_split/vkitti/vkitti_train.txt
kitti_bm_crop: true
valid_mask_crop: null # no valid_mask_crop for training
6 changes: 6 additions & 0 deletions config/dataset/data_vkitti_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
name: vkitti
disp_name: vkitti_val
dir: vkitti/vkitti.tar
filenames: data_split/vkitti/vkitti_val.txt
kitti_bm_crop: true
valid_mask_crop: eigen
18 changes: 18 additions & 0 deletions config/dataset/dataset_train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
dataset:
train:
name: mixed
prob_ls: [0.9, 0.1]
dataset_list:
- name: hypersim
disp_name: hypersim_train
dir: hypersim/hypersim_processed_train.tar
filenames: data_split/hypersim/filename_list_train_filtered.txt
resize_to_hw:
- 480
- 640
- name: vkitti
disp_name: vkitti_train
dir: vkitti/vkitti.tar
filenames: data_split/vkitti/vkitti_train.txt
kitti_bm_crop: true
valid_mask_crop: null
45 changes: 45 additions & 0 deletions config/dataset/dataset_val.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
dataset:
val:
# - name: hypersim
# disp_name: hypersim_val
# dir: hypersim/hypersim_processed_val.tar
# filenames: data_split/hypersim/filename_list_val_filtered.txt
# resize_to_hw:
# - 480
# - 640

# - name: nyu_v2
# disp_name: nyu_train_full
# dir: nyuv2/nyu_labeled_extracted.tar
# filenames: data_split/nyu/labeled/filename_list_train.txt
# eigen_valid_mask: true

# - name: kitti
# disp_name: kitti_val800_from_eigen_train
# dir: kitti/kitti_sampled_val_800.tar
# filenames: data_split/kitti/eigen_val_from_train_800.txt
# kitti_bm_crop: true
# valid_mask_crop: eigen

# Smaller subsets for faster validation during training
# The first dataset is used to calculate main eval metric.
- name: hypersim
disp_name: hypersim_val_small_80
dir: hypersim/hypersim_processed_val.tar
filenames: data_split/hypersim/filename_list_val_filtered_small_80.txt
resize_to_hw:
- 480
- 640

- name: nyu_v2
disp_name: nyu_train_small_100
dir: nyuv2/nyu_labeled_extracted.tar
filenames: data_split/nyu/labeled/filename_list_train_small_100.txt
eigen_valid_mask: true

- name: kitti
disp_name: kitti_val_from_train_sub_100
dir: kitti/kitti_sampled_val_800.tar
filenames: data_split/kitti/eigen_val_from_train_sub_100.txt
kitti_bm_crop: true
valid_mask_crop: eigen
9 changes: 9 additions & 0 deletions config/dataset/dataset_vis.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
dataset:
vis:
- name: hypersim
disp_name: hypersim_vis
dir: hypersim/hypersim_processed_val.tar
filenames: data_split/hypersim/selected_vis_sample.txt
resize_to_hw:
- 480
- 640
5 changes: 5 additions & 0 deletions config/logging.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
logging:
filename: logging.log
format: ' %(asctime)s - %(levelname)s -%(filename)s - %(funcName)s >> %(message)s'
console_level: 20
file_level: 10
4 changes: 4 additions & 0 deletions config/model_sdv2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
model:
name: marigold_pipeline
pretrained_path: stable-diffusion-2
latent_scale_factor: 0.18215
12 changes: 12 additions & 0 deletions config/train_debug.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
base_config:
- config/train_marigold.yaml


# Training settings
trainer:
save_period: 5
backup_period: 10
validation_period: 5
visualization_period: 5

max_iter: 50
94 changes: 94 additions & 0 deletions config/train_marigold.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
base_config:
- config/logging.yaml
- config/wandb.yaml
- config/dataset/dataset_train.yaml
- config/dataset/dataset_val.yaml
- config/dataset/dataset_vis.yaml
- config/model_sdv2.yaml


pipeline:
name: MarigoldPipeline
kwargs:
scale_invariant: true
shift_invariant: true

depth_normalization:
type: scale_shift_depth
clip: true
norm_min: -1.0
norm_max: 1.0
min_max_quantile: 0.02

augmentation:
lr_flip_p: 0.5

dataloader:
num_workers: 2
effective_batch_size: 32
max_train_batch_size: 2
seed: 2024 # to ensure continuity when resuming from checkpoint

# Training settings
trainer:
name: MarigoldTrainer
training_noise_scheduler:
pretrained_path: stable-diffusion-2
init_seed: 2024 # use null to train w/o seeding
save_period: 50
backup_period: 2000
validation_period: 2000
visualization_period: 2000

multi_res_noise:
strength: 0.9
annealed: true
downscale_strategy: original

gt_depth_type: depth_raw_norm
gt_mask_type: valid_mask_raw

max_epoch: 10000 # a large enough number
max_iter: 30000 # usually converges at around 20k

optimizer:
name: Adam

loss:
name: mse_loss
kwargs:
reduction: mean

lr: 3.0e-05
lr_scheduler:
name: IterExponential
kwargs:
total_iter: 25000
final_ratio: 0.01
warmup_steps: 100

# Validation (and visualization) settings
validation:
denoising_steps: 50
ensemble_size: 1 # simplified setting for on-training validation
processing_res: 0
match_input_res: false
resample_method: bilinear
main_val_metric: abs_relative_difference
main_val_metric_goal: minimize
init_seed: 2024

eval:
alignment: least_square
align_max_res: null
eval_metrics:
- abs_relative_difference
- squared_relative_difference
- rmse_linear
- rmse_log
- log10
- delta1_acc
- delta2_acc
- delta3_acc
- i_rmse
- silog_rmse
3 changes: 3 additions & 0 deletions config/wandb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
wandb:
# entity: your_entity
project: marigold
80 changes: 80 additions & 0 deletions data_split/hypersim/filename_list_val_filtered_small_80.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
ai_003_010/rgb_cam_00_fr0047.png ai_003_010/depth_plane_cam_00_fr0047.png
ai_003_010/rgb_cam_00_fr0048.png ai_003_010/depth_plane_cam_00_fr0048.png
ai_003_010/rgb_cam_01_fr0098.png ai_003_010/depth_plane_cam_01_fr0098.png
ai_004_003/rgb_cam_01_fr0008.png ai_004_003/depth_plane_cam_01_fr0008.png
ai_004_004/rgb_cam_00_fr0025.png ai_004_004/depth_plane_cam_00_fr0025.png
ai_004_004/rgb_cam_00_fr0046.png ai_004_004/depth_plane_cam_00_fr0046.png
ai_004_004/rgb_cam_00_fr0049.png ai_004_004/depth_plane_cam_00_fr0049.png
ai_004_004/rgb_cam_01_fr0023.png ai_004_004/depth_plane_cam_01_fr0023.png
ai_005_005/rgb_cam_00_fr0032.png ai_005_005/depth_plane_cam_00_fr0032.png
ai_006_007/rgb_cam_00_fr0022.png ai_006_007/depth_plane_cam_00_fr0022.png
ai_006_007/rgb_cam_00_fr0095.png ai_006_007/depth_plane_cam_00_fr0095.png
ai_007_001/rgb_cam_00_fr0044.png ai_007_001/depth_plane_cam_00_fr0044.png
ai_007_001/rgb_cam_00_fr0048.png ai_007_001/depth_plane_cam_00_fr0048.png
ai_009_007/rgb_cam_00_fr0017.png ai_009_007/depth_plane_cam_00_fr0017.png
ai_009_007/rgb_cam_00_fr0097.png ai_009_007/depth_plane_cam_00_fr0097.png
ai_009_009/rgb_cam_00_fr0094.png ai_009_009/depth_plane_cam_00_fr0094.png
ai_015_001/rgb_cam_00_fr0058.png ai_015_001/depth_plane_cam_00_fr0058.png
ai_015_001/rgb_cam_00_fr0089.png ai_015_001/depth_plane_cam_00_fr0089.png
ai_017_007/rgb_cam_01_fr0064.png ai_017_007/depth_plane_cam_01_fr0064.png
ai_018_005/rgb_cam_00_fr0014.png ai_018_005/depth_plane_cam_00_fr0014.png
ai_018_005/rgb_cam_00_fr0059.png ai_018_005/depth_plane_cam_00_fr0059.png
ai_022_010/rgb_cam_00_fr0097.png ai_022_010/depth_plane_cam_00_fr0097.png
ai_022_010/rgb_cam_00_fr0099.png ai_022_010/depth_plane_cam_00_fr0099.png
ai_023_003/rgb_cam_00_fr0013.png ai_023_003/depth_plane_cam_00_fr0013.png
ai_023_003/rgb_cam_00_fr0015.png ai_023_003/depth_plane_cam_00_fr0015.png
ai_023_003/rgb_cam_00_fr0036.png ai_023_003/depth_plane_cam_00_fr0036.png
ai_023_003/rgb_cam_00_fr0095.png ai_023_003/depth_plane_cam_00_fr0095.png
ai_023_003/rgb_cam_01_fr0029.png ai_023_003/depth_plane_cam_01_fr0029.png
ai_023_003/rgb_cam_01_fr0036.png ai_023_003/depth_plane_cam_01_fr0036.png
ai_023_003/rgb_cam_01_fr0071.png ai_023_003/depth_plane_cam_01_fr0071.png
ai_032_007/rgb_cam_00_fr0031.png ai_032_007/depth_plane_cam_00_fr0031.png
ai_032_007/rgb_cam_00_fr0040.png ai_032_007/depth_plane_cam_00_fr0040.png
ai_032_007/rgb_cam_00_fr0075.png ai_032_007/depth_plane_cam_00_fr0075.png
ai_035_003/rgb_cam_00_fr0054.png ai_035_003/depth_plane_cam_00_fr0054.png
ai_035_004/rgb_cam_00_fr0077.png ai_035_004/depth_plane_cam_00_fr0077.png
ai_038_009/rgb_cam_00_fr0031.png ai_038_009/depth_plane_cam_00_fr0031.png
ai_038_009/rgb_cam_01_fr0010.png ai_038_009/depth_plane_cam_01_fr0010.png
ai_038_009/rgb_cam_01_fr0088.png ai_038_009/depth_plane_cam_01_fr0088.png
ai_039_003/rgb_cam_01_fr0042.png ai_039_003/depth_plane_cam_01_fr0042.png
ai_039_003/rgb_cam_01_fr0097.png ai_039_003/depth_plane_cam_01_fr0097.png
ai_044_001/rgb_cam_00_fr0043.png ai_044_001/depth_plane_cam_00_fr0043.png
ai_044_001/rgb_cam_01_fr0018.png ai_044_001/depth_plane_cam_01_fr0018.png
ai_044_003/rgb_cam_01_fr0082.png ai_044_003/depth_plane_cam_01_fr0082.png
ai_044_003/rgb_cam_01_fr0087.png ai_044_003/depth_plane_cam_01_fr0087.png
ai_044_003/rgb_cam_02_fr0086.png ai_044_003/depth_plane_cam_02_fr0086.png
ai_044_003/rgb_cam_03_fr0022.png ai_044_003/depth_plane_cam_03_fr0022.png
ai_044_003/rgb_cam_03_fr0063.png ai_044_003/depth_plane_cam_03_fr0063.png
ai_045_008/rgb_cam_00_fr0015.png ai_045_008/depth_plane_cam_00_fr0015.png
ai_045_008/rgb_cam_00_fr0030.png ai_045_008/depth_plane_cam_00_fr0030.png
ai_045_008/rgb_cam_01_fr0029.png ai_045_008/depth_plane_cam_01_fr0029.png
ai_045_008/rgb_cam_01_fr0052.png ai_045_008/depth_plane_cam_01_fr0052.png
ai_045_008/rgb_cam_01_fr0088.png ai_045_008/depth_plane_cam_01_fr0088.png
ai_047_009/rgb_cam_00_fr0097.png ai_047_009/depth_plane_cam_00_fr0097.png
ai_048_001/rgb_cam_00_fr0014.png ai_048_001/depth_plane_cam_00_fr0014.png
ai_048_001/rgb_cam_00_fr0088.png ai_048_001/depth_plane_cam_00_fr0088.png
ai_048_001/rgb_cam_01_fr0045.png ai_048_001/depth_plane_cam_01_fr0045.png
ai_048_001/rgb_cam_02_fr0031.png ai_048_001/depth_plane_cam_02_fr0031.png
ai_048_001/rgb_cam_03_fr0005.png ai_048_001/depth_plane_cam_03_fr0005.png
ai_048_001/rgb_cam_03_fr0045.png ai_048_001/depth_plane_cam_03_fr0045.png
ai_048_001/rgb_cam_03_fr0054.png ai_048_001/depth_plane_cam_03_fr0054.png
ai_048_001/rgb_cam_03_fr0061.png ai_048_001/depth_plane_cam_03_fr0061.png
ai_050_002/rgb_cam_01_fr0016.png ai_050_002/depth_plane_cam_01_fr0016.png
ai_050_002/rgb_cam_02_fr0053.png ai_050_002/depth_plane_cam_02_fr0053.png
ai_050_002/rgb_cam_03_fr0082.png ai_050_002/depth_plane_cam_03_fr0082.png
ai_050_002/rgb_cam_04_fr0033.png ai_050_002/depth_plane_cam_04_fr0033.png
ai_051_004/rgb_cam_00_fr0028.png ai_051_004/depth_plane_cam_00_fr0028.png
ai_051_004/rgb_cam_01_fr0065.png ai_051_004/depth_plane_cam_01_fr0065.png
ai_051_004/rgb_cam_02_fr0054.png ai_051_004/depth_plane_cam_02_fr0054.png
ai_051_004/rgb_cam_02_fr0056.png ai_051_004/depth_plane_cam_02_fr0056.png
ai_051_004/rgb_cam_03_fr0037.png ai_051_004/depth_plane_cam_03_fr0037.png
ai_051_004/rgb_cam_04_fr0083.png ai_051_004/depth_plane_cam_04_fr0083.png
ai_051_004/rgb_cam_05_fr0003.png ai_051_004/depth_plane_cam_05_fr0003.png
ai_052_001/rgb_cam_00_fr0008.png ai_052_001/depth_plane_cam_00_fr0008.png
ai_052_003/rgb_cam_00_fr0097.png ai_052_003/depth_plane_cam_00_fr0097.png
ai_052_003/rgb_cam_01_fr0081.png ai_052_003/depth_plane_cam_01_fr0081.png
ai_052_007/rgb_cam_01_fr0001.png ai_052_007/depth_plane_cam_01_fr0001.png
ai_053_003/rgb_cam_00_fr0005.png ai_053_003/depth_plane_cam_00_fr0005.png
ai_053_005/rgb_cam_00_fr0080.png ai_053_005/depth_plane_cam_00_fr0080.png
ai_055_009/rgb_cam_01_fr0070.png ai_055_009/depth_plane_cam_01_fr0070.png
ai_055_009/rgb_cam_01_fr0086.png ai_055_009/depth_plane_cam_01_fr0086.png
3 changes: 3 additions & 0 deletions data_split/hypersim/selected_vis_sample.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ai_015_004/rgb_cam_00_fr0002.png ai_015_004/depth_plane_cam_00_fr0002.png (val)
ai_044_003/rgb_cam_01_fr0063.png ai_044_003/depth_plane_cam_01_fr0063.png (val)
ai_052_003/rgb_cam_01_fr0076.png ai_052_003/depth_plane_cam_01_fr0076.png (val)
Loading

0 comments on commit 21c8843

Please sign in to comment.