From 21c884390c56b650b8a408ba994ebd5b3d338039 Mon Sep 17 00:00:00 2001 From: Bingxin Date: Sat, 25 May 2024 23:38:56 +0200 Subject: [PATCH] [ADD] training code --- README.md | 43 +- config/dataset/data_hypersim_train.yaml | 4 + config/dataset/data_hypersim_val.yaml | 4 + config/dataset/data_kitti_val.yaml | 6 + config/dataset/data_nyu_train.yaml | 5 + config/dataset/data_vkitti_train.yaml | 6 + config/dataset/data_vkitti_val.yaml | 6 + config/dataset/dataset_train.yaml | 18 + config/dataset/dataset_val.yaml | 45 ++ config/dataset/dataset_vis.yaml | 9 + config/logging.yaml | 5 + config/model_sdv2.yaml | 4 + config/train_debug.yaml | 12 + config/train_marigold.yaml | 94 +++ config/wandb.yaml | 3 + .../filename_list_val_filtered_small_80.txt | 80 +++ data_split/hypersim/selected_vis_sample.txt | 3 + .../kitti/eigen_val_from_train_sub_100.txt | 100 +++ .../labeled/filename_list_train_small_100.txt | 100 +++ infer.py | 2 +- requirements++.txt | 6 + script/eval/11_infer_nyu.sh | 6 +- script/eval/12_eval_nyu.sh | 5 +- script/eval/21_infer_kitti.sh | 8 +- script/eval/22_eval_kitti.sh | 5 +- script/eval/31_infer_eth3d.sh | 8 +- script/eval/32_eval_eth3d.sh | 5 +- script/eval/41_infer_scannet.sh | 8 +- script/eval/42_eval_scannet.sh | 5 +- script/eval/51_infer_diode.sh | 8 +- script/eval/52_eval_diode.sh | 5 +- src/dataset/__init__.py | 36 +- src/dataset/base_depth_dataset.py | 72 +- src/dataset/diode_dataset.py | 21 +- src/dataset/eth3d_dataset.py | 21 +- src/dataset/hypersim_dataset.py | 45 ++ src/dataset/kitti_dataset.py | 21 +- src/dataset/mixed_sampler.py | 149 ++++ src/dataset/nyu_dataset.py | 22 +- src/dataset/scannet_dataset.py | 21 +- src/dataset/vkitti_dataset.py | 98 +++ src/trainer/__init__.py | 13 + src/trainer/marigold_trainer.py | 674 ++++++++++++++++++ src/util/config_util.py | 49 ++ src/util/data_loader.py | 111 +++ src/util/depth_transform.py | 19 +- src/util/logging_util.py | 102 +++ src/util/loss.py | 124 ++++ src/util/lr_scheduler.py | 48 ++ src/util/multi_res_noise.py | 75 ++ src/util/{seed_all.py => seeding.py} | 21 + src/util/slurm_util.py | 15 + train.py | 363 ++++++++++ 53 files changed, 2677 insertions(+), 61 deletions(-) create mode 100644 config/dataset/data_hypersim_train.yaml create mode 100644 config/dataset/data_hypersim_val.yaml create mode 100644 config/dataset/data_kitti_val.yaml create mode 100644 config/dataset/data_nyu_train.yaml create mode 100644 config/dataset/data_vkitti_train.yaml create mode 100644 config/dataset/data_vkitti_val.yaml create mode 100644 config/dataset/dataset_train.yaml create mode 100644 config/dataset/dataset_val.yaml create mode 100644 config/dataset/dataset_vis.yaml create mode 100644 config/logging.yaml create mode 100644 config/model_sdv2.yaml create mode 100644 config/train_debug.yaml create mode 100644 config/train_marigold.yaml create mode 100644 config/wandb.yaml create mode 100644 data_split/hypersim/filename_list_val_filtered_small_80.txt create mode 100644 data_split/hypersim/selected_vis_sample.txt create mode 100644 data_split/kitti/eigen_val_from_train_sub_100.txt create mode 100644 data_split/nyu/labeled/filename_list_train_small_100.txt create mode 100644 requirements++.txt create mode 100644 src/dataset/hypersim_dataset.py create mode 100644 src/dataset/mixed_sampler.py create mode 100644 src/dataset/vkitti_dataset.py create mode 100644 src/trainer/__init__.py create mode 100644 src/trainer/marigold_trainer.py create mode 100644 src/util/config_util.py create mode 100644 src/util/data_loader.py create mode 100644 src/util/logging_util.py create mode 100644 src/util/loss.py create mode 100644 src/util/lr_scheduler.py create mode 100644 src/util/multi_res_noise.py rename src/util/{seed_all.py => seeding.py} (72%) create mode 100644 src/util/slurm_util.py create mode 100644 train.py diff --git a/README.md b/README.md index 8e45365..13efaa6 100644 --- a/README.md +++ b/README.md @@ -160,9 +160,9 @@ The default settings are optimized for the best result. However, the behavior of - `--processing_res`: the processing resolution; set as 0 to process the input resolution directly. When unassigned (`None`), will read default setting from model config. Default: ~~768~~ `None`. - `--output_processing_res`: produce output at the processing resolution instead of upsampling it to the input resolution. Default: False. - - `--resample_method`: resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`. + - `--resample_method`: the resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic`, or `nearest`. Default: `bilinear`. -- `--half_precision` or `--fp16`: Run with half-precision (16-bit float) to reduce VRAM usage, might lead to suboptimal result. +- `--half_precision` or `--fp16`: Run with half-precision (16-bit float) to reduce VRAM usage, which might lead to suboptimal results. - `--seed`: Random seed can be set to ensure additional reproducibility. Default: None (unseeded). Note: forcing `--batch_size 1` helps to increase reproducibility. To ensure full reproducibility, [deterministic mode](https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms) needs to be used. - `--batch_size`: Batch size of repeated inference. Default: 0 (best value determined automatically). - `--color_map`: [Colormap](https://matplotlib.org/stable/users/explain/colors/colormaps.html) used to colorize the depth prediction. Default: Spectral. Set to `None` to skip colored depth map generation. @@ -196,7 +196,7 @@ python run.py \ --output_dir output/in-the-wild_example ``` -## 🦿 Evaluation on test datasets +## 🦿 Evaluation on test datasets Install additional dependencies: @@ -224,6 +224,43 @@ bash script/eval/12_eval_nyu.sh Note: although the seed has been set, the results might still be slightly different on different hardware. +## 🏋️ Training + +Based on the previously created environment, install extended requirements: + +```bash +pip install -r requirements++.txt -r requirements+.txt -r requirements.txt +``` + +Set environment parameters for the data directory: + +```bash +export BASE_DATA_DIR=YOUR_DATA_DIR # directory of training data +export BASE_CKPT_DIR=YOUR_CHECKPOINT_DIR # directory of pretrained checkpoint +``` + +Download Stable Diffusion v2 [checkpoint](https://huggingface.co/stabilityai/stable-diffusion-2) into `${BASE_CKPT_DIR}` + +Prepare for [Hypersim](https://github.com/apple/ml-hypersim) and [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/) datasets and save into `${BASE_DATA_DIR}`. Please refer to [this README](script/dataset_preprocess/hypersim/README.md) for Hypersim preprocessing. + +Run training script + +```bash +python train.py --config config/train_marigold.yaml +``` + +Resume from a checkpoint, e.g. + +```bash +python train.py --resume_from output/marigold_base/checkpoint/latest +``` + +Evaluating results + +Only the U-Net is updated and saved during training. To use the inference pipeline with your training result, replace `unet` folder in Marigold checkpoints with that in the `checkpoint` output folder. Then refer to [this section](#evaluation) for evaluation. + +**Note**: Although random seeds have been set, the training result might be slightly different on different hardwares. It's recommended to train without interruption. + ## ✏️ Contributing Please refer to [this](CONTRIBUTING.md) instruction. diff --git a/config/dataset/data_hypersim_train.yaml b/config/dataset/data_hypersim_train.yaml new file mode 100644 index 0000000..589a189 --- /dev/null +++ b/config/dataset/data_hypersim_train.yaml @@ -0,0 +1,4 @@ +name: hypersim +disp_name: hypersim_train +dir: hypersim/hypersim_processed_train.tar +filenames: data_split/hypersim/filename_list_train_filtered.txt \ No newline at end of file diff --git a/config/dataset/data_hypersim_val.yaml b/config/dataset/data_hypersim_val.yaml new file mode 100644 index 0000000..fe26bdc --- /dev/null +++ b/config/dataset/data_hypersim_val.yaml @@ -0,0 +1,4 @@ +name: hypersim +disp_name: hypersim_val +dir: hypersim/hypersim_processed_val.tar +filenames: data_split/hypersim/filename_list_val_filtered.txt \ No newline at end of file diff --git a/config/dataset/data_kitti_val.yaml b/config/dataset/data_kitti_val.yaml new file mode 100644 index 0000000..fc6c3a2 --- /dev/null +++ b/config/dataset/data_kitti_val.yaml @@ -0,0 +1,6 @@ +name: kitti +disp_name: kitti_val800_from_eigen_train +dir: kitti/kitti_sampled_val_800.tar +filenames: data_split/kitti/eigen_val_from_train_800.txt +kitti_bm_crop: true +valid_mask_crop: eigen \ No newline at end of file diff --git a/config/dataset/data_nyu_train.yaml b/config/dataset/data_nyu_train.yaml new file mode 100644 index 0000000..185ff58 --- /dev/null +++ b/config/dataset/data_nyu_train.yaml @@ -0,0 +1,5 @@ +name: nyu_v2 +disp_name: nyu_train_full +dir: nyuv2/nyu_labeled_extracted.tar +filenames: data_split/nyu/labeled/filename_list_train.txt +eigen_valid_mask: true \ No newline at end of file diff --git a/config/dataset/data_vkitti_train.yaml b/config/dataset/data_vkitti_train.yaml new file mode 100644 index 0000000..8c089c4 --- /dev/null +++ b/config/dataset/data_vkitti_train.yaml @@ -0,0 +1,6 @@ +name: vkitti +disp_name: vkitti_train +dir: vkitti/vkitti.tar +filenames: data_split/vkitti/vkitti_train.txt +kitti_bm_crop: true +valid_mask_crop: null # no valid_mask_crop for training \ No newline at end of file diff --git a/config/dataset/data_vkitti_val.yaml b/config/dataset/data_vkitti_val.yaml new file mode 100644 index 0000000..257e9f1 --- /dev/null +++ b/config/dataset/data_vkitti_val.yaml @@ -0,0 +1,6 @@ +name: vkitti +disp_name: vkitti_val +dir: vkitti/vkitti.tar +filenames: data_split/vkitti/vkitti_val.txt +kitti_bm_crop: true +valid_mask_crop: eigen \ No newline at end of file diff --git a/config/dataset/dataset_train.yaml b/config/dataset/dataset_train.yaml new file mode 100644 index 0000000..0381948 --- /dev/null +++ b/config/dataset/dataset_train.yaml @@ -0,0 +1,18 @@ +dataset: + train: + name: mixed + prob_ls: [0.9, 0.1] + dataset_list: + - name: hypersim + disp_name: hypersim_train + dir: hypersim/hypersim_processed_train.tar + filenames: data_split/hypersim/filename_list_train_filtered.txt + resize_to_hw: + - 480 + - 640 + - name: vkitti + disp_name: vkitti_train + dir: vkitti/vkitti.tar + filenames: data_split/vkitti/vkitti_train.txt + kitti_bm_crop: true + valid_mask_crop: null \ No newline at end of file diff --git a/config/dataset/dataset_val.yaml b/config/dataset/dataset_val.yaml new file mode 100644 index 0000000..f70d90d --- /dev/null +++ b/config/dataset/dataset_val.yaml @@ -0,0 +1,45 @@ +dataset: + val: + # - name: hypersim + # disp_name: hypersim_val + # dir: hypersim/hypersim_processed_val.tar + # filenames: data_split/hypersim/filename_list_val_filtered.txt + # resize_to_hw: + # - 480 + # - 640 + + # - name: nyu_v2 + # disp_name: nyu_train_full + # dir: nyuv2/nyu_labeled_extracted.tar + # filenames: data_split/nyu/labeled/filename_list_train.txt + # eigen_valid_mask: true + + # - name: kitti + # disp_name: kitti_val800_from_eigen_train + # dir: kitti/kitti_sampled_val_800.tar + # filenames: data_split/kitti/eigen_val_from_train_800.txt + # kitti_bm_crop: true + # valid_mask_crop: eigen + + # Smaller subsets for faster validation during training + # The first dataset is used to calculate main eval metric. + - name: hypersim + disp_name: hypersim_val_small_80 + dir: hypersim/hypersim_processed_val.tar + filenames: data_split/hypersim/filename_list_val_filtered_small_80.txt + resize_to_hw: + - 480 + - 640 + + - name: nyu_v2 + disp_name: nyu_train_small_100 + dir: nyuv2/nyu_labeled_extracted.tar + filenames: data_split/nyu/labeled/filename_list_train_small_100.txt + eigen_valid_mask: true + + - name: kitti + disp_name: kitti_val_from_train_sub_100 + dir: kitti/kitti_sampled_val_800.tar + filenames: data_split/kitti/eigen_val_from_train_sub_100.txt + kitti_bm_crop: true + valid_mask_crop: eigen \ No newline at end of file diff --git a/config/dataset/dataset_vis.yaml b/config/dataset/dataset_vis.yaml new file mode 100644 index 0000000..c11dc8e --- /dev/null +++ b/config/dataset/dataset_vis.yaml @@ -0,0 +1,9 @@ +dataset: + vis: + - name: hypersim + disp_name: hypersim_vis + dir: hypersim/hypersim_processed_val.tar + filenames: data_split/hypersim/selected_vis_sample.txt + resize_to_hw: + - 480 + - 640 diff --git a/config/logging.yaml b/config/logging.yaml new file mode 100644 index 0000000..8cecbae --- /dev/null +++ b/config/logging.yaml @@ -0,0 +1,5 @@ +logging: + filename: logging.log + format: ' %(asctime)s - %(levelname)s -%(filename)s - %(funcName)s >> %(message)s' + console_level: 20 + file_level: 10 diff --git a/config/model_sdv2.yaml b/config/model_sdv2.yaml new file mode 100644 index 0000000..ce58c5e --- /dev/null +++ b/config/model_sdv2.yaml @@ -0,0 +1,4 @@ +model: + name: marigold_pipeline + pretrained_path: stable-diffusion-2 + latent_scale_factor: 0.18215 \ No newline at end of file diff --git a/config/train_debug.yaml b/config/train_debug.yaml new file mode 100644 index 0000000..7c21ed1 --- /dev/null +++ b/config/train_debug.yaml @@ -0,0 +1,12 @@ +base_config: +- config/train_marigold.yaml + + +# Training settings +trainer: + save_period: 5 + backup_period: 10 + validation_period: 5 + visualization_period: 5 + +max_iter: 50 \ No newline at end of file diff --git a/config/train_marigold.yaml b/config/train_marigold.yaml new file mode 100644 index 0000000..defb4c7 --- /dev/null +++ b/config/train_marigold.yaml @@ -0,0 +1,94 @@ +base_config: +- config/logging.yaml +- config/wandb.yaml +- config/dataset/dataset_train.yaml +- config/dataset/dataset_val.yaml +- config/dataset/dataset_vis.yaml +- config/model_sdv2.yaml + + +pipeline: + name: MarigoldPipeline + kwargs: + scale_invariant: true + shift_invariant: true + +depth_normalization: + type: scale_shift_depth + clip: true + norm_min: -1.0 + norm_max: 1.0 + min_max_quantile: 0.02 + +augmentation: + lr_flip_p: 0.5 + +dataloader: + num_workers: 2 + effective_batch_size: 32 + max_train_batch_size: 2 + seed: 2024 # to ensure continuity when resuming from checkpoint + +# Training settings +trainer: + name: MarigoldTrainer + training_noise_scheduler: + pretrained_path: stable-diffusion-2 + init_seed: 2024 # use null to train w/o seeding + save_period: 50 + backup_period: 2000 + validation_period: 2000 + visualization_period: 2000 + +multi_res_noise: + strength: 0.9 + annealed: true + downscale_strategy: original + +gt_depth_type: depth_raw_norm +gt_mask_type: valid_mask_raw + +max_epoch: 10000 # a large enough number +max_iter: 30000 # usually converges at around 20k + +optimizer: + name: Adam + +loss: + name: mse_loss + kwargs: + reduction: mean + +lr: 3.0e-05 +lr_scheduler: + name: IterExponential + kwargs: + total_iter: 25000 + final_ratio: 0.01 + warmup_steps: 100 + +# Validation (and visualization) settings +validation: + denoising_steps: 50 + ensemble_size: 1 # simplified setting for on-training validation + processing_res: 0 + match_input_res: false + resample_method: bilinear + main_val_metric: abs_relative_difference + main_val_metric_goal: minimize + init_seed: 2024 + +eval: + alignment: least_square + align_max_res: null + eval_metrics: + - abs_relative_difference + - squared_relative_difference + - rmse_linear + - rmse_log + - log10 + - delta1_acc + - delta2_acc + - delta3_acc + - i_rmse + - silog_rmse diff --git a/config/wandb.yaml b/config/wandb.yaml new file mode 100644 index 0000000..5631cfb --- /dev/null +++ b/config/wandb.yaml @@ -0,0 +1,3 @@ +wandb: + # entity: your_entity + project: marigold \ No newline at end of file diff --git a/data_split/hypersim/filename_list_val_filtered_small_80.txt b/data_split/hypersim/filename_list_val_filtered_small_80.txt new file mode 100644 index 0000000..b240c2d --- /dev/null +++ b/data_split/hypersim/filename_list_val_filtered_small_80.txt @@ -0,0 +1,80 @@ +ai_003_010/rgb_cam_00_fr0047.png ai_003_010/depth_plane_cam_00_fr0047.png +ai_003_010/rgb_cam_00_fr0048.png ai_003_010/depth_plane_cam_00_fr0048.png +ai_003_010/rgb_cam_01_fr0098.png ai_003_010/depth_plane_cam_01_fr0098.png +ai_004_003/rgb_cam_01_fr0008.png ai_004_003/depth_plane_cam_01_fr0008.png +ai_004_004/rgb_cam_00_fr0025.png ai_004_004/depth_plane_cam_00_fr0025.png +ai_004_004/rgb_cam_00_fr0046.png ai_004_004/depth_plane_cam_00_fr0046.png +ai_004_004/rgb_cam_00_fr0049.png ai_004_004/depth_plane_cam_00_fr0049.png +ai_004_004/rgb_cam_01_fr0023.png ai_004_004/depth_plane_cam_01_fr0023.png +ai_005_005/rgb_cam_00_fr0032.png ai_005_005/depth_plane_cam_00_fr0032.png +ai_006_007/rgb_cam_00_fr0022.png ai_006_007/depth_plane_cam_00_fr0022.png +ai_006_007/rgb_cam_00_fr0095.png ai_006_007/depth_plane_cam_00_fr0095.png +ai_007_001/rgb_cam_00_fr0044.png ai_007_001/depth_plane_cam_00_fr0044.png +ai_007_001/rgb_cam_00_fr0048.png ai_007_001/depth_plane_cam_00_fr0048.png +ai_009_007/rgb_cam_00_fr0017.png ai_009_007/depth_plane_cam_00_fr0017.png +ai_009_007/rgb_cam_00_fr0097.png ai_009_007/depth_plane_cam_00_fr0097.png +ai_009_009/rgb_cam_00_fr0094.png ai_009_009/depth_plane_cam_00_fr0094.png +ai_015_001/rgb_cam_00_fr0058.png ai_015_001/depth_plane_cam_00_fr0058.png +ai_015_001/rgb_cam_00_fr0089.png ai_015_001/depth_plane_cam_00_fr0089.png +ai_017_007/rgb_cam_01_fr0064.png ai_017_007/depth_plane_cam_01_fr0064.png +ai_018_005/rgb_cam_00_fr0014.png ai_018_005/depth_plane_cam_00_fr0014.png +ai_018_005/rgb_cam_00_fr0059.png ai_018_005/depth_plane_cam_00_fr0059.png +ai_022_010/rgb_cam_00_fr0097.png ai_022_010/depth_plane_cam_00_fr0097.png +ai_022_010/rgb_cam_00_fr0099.png ai_022_010/depth_plane_cam_00_fr0099.png +ai_023_003/rgb_cam_00_fr0013.png ai_023_003/depth_plane_cam_00_fr0013.png +ai_023_003/rgb_cam_00_fr0015.png ai_023_003/depth_plane_cam_00_fr0015.png +ai_023_003/rgb_cam_00_fr0036.png ai_023_003/depth_plane_cam_00_fr0036.png +ai_023_003/rgb_cam_00_fr0095.png ai_023_003/depth_plane_cam_00_fr0095.png +ai_023_003/rgb_cam_01_fr0029.png ai_023_003/depth_plane_cam_01_fr0029.png +ai_023_003/rgb_cam_01_fr0036.png ai_023_003/depth_plane_cam_01_fr0036.png +ai_023_003/rgb_cam_01_fr0071.png ai_023_003/depth_plane_cam_01_fr0071.png +ai_032_007/rgb_cam_00_fr0031.png ai_032_007/depth_plane_cam_00_fr0031.png +ai_032_007/rgb_cam_00_fr0040.png ai_032_007/depth_plane_cam_00_fr0040.png +ai_032_007/rgb_cam_00_fr0075.png ai_032_007/depth_plane_cam_00_fr0075.png +ai_035_003/rgb_cam_00_fr0054.png ai_035_003/depth_plane_cam_00_fr0054.png +ai_035_004/rgb_cam_00_fr0077.png ai_035_004/depth_plane_cam_00_fr0077.png +ai_038_009/rgb_cam_00_fr0031.png ai_038_009/depth_plane_cam_00_fr0031.png +ai_038_009/rgb_cam_01_fr0010.png ai_038_009/depth_plane_cam_01_fr0010.png +ai_038_009/rgb_cam_01_fr0088.png ai_038_009/depth_plane_cam_01_fr0088.png +ai_039_003/rgb_cam_01_fr0042.png ai_039_003/depth_plane_cam_01_fr0042.png +ai_039_003/rgb_cam_01_fr0097.png ai_039_003/depth_plane_cam_01_fr0097.png +ai_044_001/rgb_cam_00_fr0043.png ai_044_001/depth_plane_cam_00_fr0043.png +ai_044_001/rgb_cam_01_fr0018.png ai_044_001/depth_plane_cam_01_fr0018.png +ai_044_003/rgb_cam_01_fr0082.png ai_044_003/depth_plane_cam_01_fr0082.png +ai_044_003/rgb_cam_01_fr0087.png ai_044_003/depth_plane_cam_01_fr0087.png +ai_044_003/rgb_cam_02_fr0086.png ai_044_003/depth_plane_cam_02_fr0086.png +ai_044_003/rgb_cam_03_fr0022.png ai_044_003/depth_plane_cam_03_fr0022.png +ai_044_003/rgb_cam_03_fr0063.png ai_044_003/depth_plane_cam_03_fr0063.png +ai_045_008/rgb_cam_00_fr0015.png ai_045_008/depth_plane_cam_00_fr0015.png +ai_045_008/rgb_cam_00_fr0030.png ai_045_008/depth_plane_cam_00_fr0030.png +ai_045_008/rgb_cam_01_fr0029.png ai_045_008/depth_plane_cam_01_fr0029.png +ai_045_008/rgb_cam_01_fr0052.png ai_045_008/depth_plane_cam_01_fr0052.png +ai_045_008/rgb_cam_01_fr0088.png ai_045_008/depth_plane_cam_01_fr0088.png +ai_047_009/rgb_cam_00_fr0097.png ai_047_009/depth_plane_cam_00_fr0097.png +ai_048_001/rgb_cam_00_fr0014.png ai_048_001/depth_plane_cam_00_fr0014.png +ai_048_001/rgb_cam_00_fr0088.png ai_048_001/depth_plane_cam_00_fr0088.png +ai_048_001/rgb_cam_01_fr0045.png ai_048_001/depth_plane_cam_01_fr0045.png +ai_048_001/rgb_cam_02_fr0031.png ai_048_001/depth_plane_cam_02_fr0031.png +ai_048_001/rgb_cam_03_fr0005.png ai_048_001/depth_plane_cam_03_fr0005.png +ai_048_001/rgb_cam_03_fr0045.png ai_048_001/depth_plane_cam_03_fr0045.png +ai_048_001/rgb_cam_03_fr0054.png ai_048_001/depth_plane_cam_03_fr0054.png +ai_048_001/rgb_cam_03_fr0061.png ai_048_001/depth_plane_cam_03_fr0061.png +ai_050_002/rgb_cam_01_fr0016.png ai_050_002/depth_plane_cam_01_fr0016.png +ai_050_002/rgb_cam_02_fr0053.png ai_050_002/depth_plane_cam_02_fr0053.png +ai_050_002/rgb_cam_03_fr0082.png ai_050_002/depth_plane_cam_03_fr0082.png +ai_050_002/rgb_cam_04_fr0033.png ai_050_002/depth_plane_cam_04_fr0033.png +ai_051_004/rgb_cam_00_fr0028.png ai_051_004/depth_plane_cam_00_fr0028.png +ai_051_004/rgb_cam_01_fr0065.png ai_051_004/depth_plane_cam_01_fr0065.png +ai_051_004/rgb_cam_02_fr0054.png ai_051_004/depth_plane_cam_02_fr0054.png +ai_051_004/rgb_cam_02_fr0056.png ai_051_004/depth_plane_cam_02_fr0056.png +ai_051_004/rgb_cam_03_fr0037.png ai_051_004/depth_plane_cam_03_fr0037.png +ai_051_004/rgb_cam_04_fr0083.png ai_051_004/depth_plane_cam_04_fr0083.png +ai_051_004/rgb_cam_05_fr0003.png ai_051_004/depth_plane_cam_05_fr0003.png +ai_052_001/rgb_cam_00_fr0008.png ai_052_001/depth_plane_cam_00_fr0008.png +ai_052_003/rgb_cam_00_fr0097.png ai_052_003/depth_plane_cam_00_fr0097.png +ai_052_003/rgb_cam_01_fr0081.png ai_052_003/depth_plane_cam_01_fr0081.png +ai_052_007/rgb_cam_01_fr0001.png ai_052_007/depth_plane_cam_01_fr0001.png +ai_053_003/rgb_cam_00_fr0005.png ai_053_003/depth_plane_cam_00_fr0005.png +ai_053_005/rgb_cam_00_fr0080.png ai_053_005/depth_plane_cam_00_fr0080.png +ai_055_009/rgb_cam_01_fr0070.png ai_055_009/depth_plane_cam_01_fr0070.png +ai_055_009/rgb_cam_01_fr0086.png ai_055_009/depth_plane_cam_01_fr0086.png \ No newline at end of file diff --git a/data_split/hypersim/selected_vis_sample.txt b/data_split/hypersim/selected_vis_sample.txt new file mode 100644 index 0000000..0ebd78d --- /dev/null +++ b/data_split/hypersim/selected_vis_sample.txt @@ -0,0 +1,3 @@ +ai_015_004/rgb_cam_00_fr0002.png ai_015_004/depth_plane_cam_00_fr0002.png (val) +ai_044_003/rgb_cam_01_fr0063.png ai_044_003/depth_plane_cam_01_fr0063.png (val) +ai_052_003/rgb_cam_01_fr0076.png ai_052_003/depth_plane_cam_01_fr0076.png (val) \ No newline at end of file diff --git a/data_split/kitti/eigen_val_from_train_sub_100.txt b/data_split/kitti/eigen_val_from_train_sub_100.txt new file mode 100644 index 0000000..fa0ebc4 --- /dev/null +++ b/data_split/kitti/eigen_val_from_train_sub_100.txt @@ -0,0 +1,100 @@ +2011_09_26/2011_09_26_drive_0001_sync/image_02/data/0000000046.png 2011_09_26_drive_0001_sync/proj_depth/groundtruth/image_02/0000000046.png 721.5377 +2011_09_26/2011_09_26_drive_0005_sync/image_02/data/0000000148.png 2011_09_26_drive_0005_sync/proj_depth/groundtruth/image_02/0000000148.png 721.5377 +2011_09_26/2011_09_26_drive_0014_sync/image_02/data/0000000076.png 2011_09_26_drive_0014_sync/proj_depth/groundtruth/image_02/0000000076.png 721.5377 +2011_09_26/2011_09_26_drive_0015_sync/image_02/data/0000000019.png 2011_09_26_drive_0015_sync/proj_depth/groundtruth/image_02/0000000019.png 721.5377 +2011_09_26/2011_09_26_drive_0015_sync/image_02/data/0000000194.png 2011_09_26_drive_0015_sync/proj_depth/groundtruth/image_02/0000000194.png 721.5377 +2011_09_26/2011_09_26_drive_0018_sync/image_02/data/0000000106.png 2011_09_26_drive_0018_sync/proj_depth/groundtruth/image_02/0000000106.png 721.5377 +2011_09_26/2011_09_26_drive_0019_sync/image_02/data/0000000263.png 2011_09_26_drive_0019_sync/proj_depth/groundtruth/image_02/0000000263.png 721.5377 +2011_09_26/2011_09_26_drive_0019_sync/image_02/data/0000000274.png 2011_09_26_drive_0019_sync/proj_depth/groundtruth/image_02/0000000274.png 721.5377 +2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000015.png 2011_09_26_drive_0022_sync/proj_depth/groundtruth/image_02/0000000015.png 721.5377 +2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000123.png 2011_09_26_drive_0022_sync/proj_depth/groundtruth/image_02/0000000123.png 721.5377 +2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000149.png 2011_09_26_drive_0022_sync/proj_depth/groundtruth/image_02/0000000149.png 721.5377 +2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000308.png 2011_09_26_drive_0022_sync/proj_depth/groundtruth/image_02/0000000308.png 721.5377 +2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000553.png 2011_09_26_drive_0022_sync/proj_depth/groundtruth/image_02/0000000553.png 721.5377 +2011_09_26/2011_09_26_drive_0022_sync/image_02/data/0000000691.png 2011_09_26_drive_0022_sync/proj_depth/groundtruth/image_02/0000000691.png 721.5377 +2011_09_26/2011_09_26_drive_0028_sync/image_02/data/0000000270.png 2011_09_26_drive_0028_sync/proj_depth/groundtruth/image_02/0000000270.png 721.5377 +2011_09_26/2011_09_26_drive_0035_sync/image_02/data/0000000085.png 2011_09_26_drive_0035_sync/proj_depth/groundtruth/image_02/0000000085.png 721.5377 +2011_09_26/2011_09_26_drive_0039_sync/image_02/data/0000000326.png 2011_09_26_drive_0039_sync/proj_depth/groundtruth/image_02/0000000326.png 721.5377 +2011_09_26/2011_09_26_drive_0051_sync/image_02/data/0000000429.png 2011_09_26_drive_0051_sync/proj_depth/groundtruth/image_02/0000000429.png 721.5377 +2011_09_26/2011_09_26_drive_0057_sync/image_02/data/0000000010.png 2011_09_26_drive_0057_sync/proj_depth/groundtruth/image_02/0000000010.png 721.5377 +2011_09_26/2011_09_26_drive_0060_sync/image_02/data/0000000020.png 2011_09_26_drive_0060_sync/proj_depth/groundtruth/image_02/0000000020.png 721.5377 +2011_09_26/2011_09_26_drive_0061_sync/image_02/data/0000000223.png 2011_09_26_drive_0061_sync/proj_depth/groundtruth/image_02/0000000223.png 721.5377 +2011_09_26/2011_09_26_drive_0061_sync/image_02/data/0000000262.png 2011_09_26_drive_0061_sync/proj_depth/groundtruth/image_02/0000000262.png 721.5377 +2011_09_26/2011_09_26_drive_0061_sync/image_02/data/0000000291.png 2011_09_26_drive_0061_sync/proj_depth/groundtruth/image_02/0000000291.png 721.5377 +2011_09_26/2011_09_26_drive_0061_sync/image_02/data/0000000523.png 2011_09_26_drive_0061_sync/proj_depth/groundtruth/image_02/0000000523.png 721.5377 +2011_09_26/2011_09_26_drive_0061_sync/image_02/data/0000000524.png 2011_09_26_drive_0061_sync/proj_depth/groundtruth/image_02/0000000524.png 721.5377 +2011_09_26/2011_09_26_drive_0070_sync/image_02/data/0000000063.png 2011_09_26_drive_0070_sync/proj_depth/groundtruth/image_02/0000000063.png 721.5377 +2011_09_26/2011_09_26_drive_0070_sync/image_02/data/0000000320.png 2011_09_26_drive_0070_sync/proj_depth/groundtruth/image_02/0000000320.png 721.5377 +2011_09_26/2011_09_26_drive_0087_sync/image_02/data/0000000313.png 2011_09_26_drive_0087_sync/proj_depth/groundtruth/image_02/0000000313.png 721.5377 +2011_09_26/2011_09_26_drive_0087_sync/image_02/data/0000000316.png 2011_09_26_drive_0087_sync/proj_depth/groundtruth/image_02/0000000316.png 721.5377 +2011_09_26/2011_09_26_drive_0087_sync/image_02/data/0000000363.png 2011_09_26_drive_0087_sync/proj_depth/groundtruth/image_02/0000000363.png 721.5377 +2011_09_26/2011_09_26_drive_0087_sync/image_02/data/0000000438.png 2011_09_26_drive_0087_sync/proj_depth/groundtruth/image_02/0000000438.png 721.5377 +2011_09_26/2011_09_26_drive_0091_sync/image_02/data/0000000137.png 2011_09_26_drive_0091_sync/proj_depth/groundtruth/image_02/0000000137.png 721.5377 +2011_09_26/2011_09_26_drive_0091_sync/image_02/data/0000000143.png 2011_09_26_drive_0091_sync/proj_depth/groundtruth/image_02/0000000143.png 721.5377 +2011_09_26/2011_09_26_drive_0091_sync/image_02/data/0000000278.png 2011_09_26_drive_0091_sync/proj_depth/groundtruth/image_02/0000000278.png 721.5377 +2011_09_26/2011_09_26_drive_0091_sync/image_02/data/0000000312.png 2011_09_26_drive_0091_sync/proj_depth/groundtruth/image_02/0000000312.png 721.5377 +2011_09_26/2011_09_26_drive_0095_sync/image_02/data/0000000160.png 2011_09_26_drive_0095_sync/proj_depth/groundtruth/image_02/0000000160.png 721.5377 +2011_09_26/2011_09_26_drive_0104_sync/image_02/data/0000000011.png 2011_09_26_drive_0104_sync/proj_depth/groundtruth/image_02/0000000011.png 721.5377 +2011_09_26/2011_09_26_drive_0113_sync/image_02/data/0000000052.png 2011_09_26_drive_0113_sync/proj_depth/groundtruth/image_02/0000000052.png 721.5377 +2011_09_26/2011_09_26_drive_0113_sync/image_02/data/0000000055.png 2011_09_26_drive_0113_sync/proj_depth/groundtruth/image_02/0000000055.png 721.5377 +2011_09_29/2011_09_29_drive_0004_sync/image_02/data/0000000065.png 2011_09_29_drive_0004_sync/proj_depth/groundtruth/image_02/0000000065.png 718.3351 +2011_09_30/2011_09_30_drive_0020_sync/image_02/data/0000000325.png 2011_09_30_drive_0020_sync/proj_depth/groundtruth/image_02/0000000325.png 707.0912 +2011_09_30/2011_09_30_drive_0020_sync/image_02/data/0000000959.png 2011_09_30_drive_0020_sync/proj_depth/groundtruth/image_02/0000000959.png 707.0912 +2011_09_30/2011_09_30_drive_0020_sync/image_02/data/0000001004.png 2011_09_30_drive_0020_sync/proj_depth/groundtruth/image_02/0000001004.png 707.0912 +2011_09_30/2011_09_30_drive_0020_sync/image_02/data/0000001054.png 2011_09_30_drive_0020_sync/proj_depth/groundtruth/image_02/0000001054.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000000545.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000000545.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000000920.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000000920.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000001593.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000001593.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000001692.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000001692.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000001806.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000001806.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000001905.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000001905.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000002714.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000002714.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000002812.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000002812.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000002838.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000002838.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000003402.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000003402.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000003700.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000003700.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000004016.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000004016.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000004276.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000004276.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000004664.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000004664.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000004772.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000004772.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000004782.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000004782.png 707.0912 +2011_09_30/2011_09_30_drive_0028_sync/image_02/data/0000005095.png 2011_09_30_drive_0028_sync/proj_depth/groundtruth/image_02/0000005095.png 707.0912 +2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000000319.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000000319.png 707.0912 +2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000000355.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000000355.png 707.0912 +2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000000500.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000000500.png 707.0912 +2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000000682.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000000682.png 707.0912 +2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000000710.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000000710.png 707.0912 +2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000000896.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000000896.png 707.0912 +2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000001197.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000001197.png 707.0912 +2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000001508.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000001508.png 707.0912 +2011_09_30/2011_09_30_drive_0033_sync/image_02/data/0000001512.png 2011_09_30_drive_0033_sync/proj_depth/groundtruth/image_02/0000001512.png 707.0912 +2011_09_30/2011_09_30_drive_0034_sync/image_02/data/0000000029.png 2011_09_30_drive_0034_sync/proj_depth/groundtruth/image_02/0000000029.png 707.0912 +2011_09_30/2011_09_30_drive_0034_sync/image_02/data/0000000171.png 2011_09_30_drive_0034_sync/proj_depth/groundtruth/image_02/0000000171.png 707.0912 +2011_09_30/2011_09_30_drive_0034_sync/image_02/data/0000000193.png 2011_09_30_drive_0034_sync/proj_depth/groundtruth/image_02/0000000193.png 707.0912 +2011_09_30/2011_09_30_drive_0034_sync/image_02/data/0000000389.png 2011_09_30_drive_0034_sync/proj_depth/groundtruth/image_02/0000000389.png 707.0912 +2011_09_30/2011_09_30_drive_0034_sync/image_02/data/0000001141.png 2011_09_30_drive_0034_sync/proj_depth/groundtruth/image_02/0000001141.png 707.0912 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000000138.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000000138.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000000593.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000000593.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000001046.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000001046.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000001151.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000001151.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000001255.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000001255.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000001283.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000001283.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000001737.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000001737.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000001999.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000001999.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000002012.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000002012.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000002089.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000002089.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000002324.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000002324.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000002902.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000002902.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000002971.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000002971.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000003299.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000003299.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000003366.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000003366.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000003427.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000003427.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000003440.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000003440.png 718.856 +2011_10_03/2011_10_03_drive_0034_sync/image_02/data/0000004060.png 2011_10_03_drive_0034_sync/proj_depth/groundtruth/image_02/0000004060.png 718.856 +2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000000525.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000000525.png 718.856 +2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000000538.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000000538.png 718.856 +2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000000648.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000000648.png 718.856 +2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000000776.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000000776.png 718.856 +2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000000779.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000000779.png 718.856 +2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000001087.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000001087.png 718.856 +2011_10_03/2011_10_03_drive_0042_sync/image_02/data/0000001107.png 2011_10_03_drive_0042_sync/proj_depth/groundtruth/image_02/0000001107.png 718.856 \ No newline at end of file diff --git a/data_split/nyu/labeled/filename_list_train_small_100.txt b/data_split/nyu/labeled/filename_list_train_small_100.txt new file mode 100644 index 0000000..61103d4 --- /dev/null +++ b/data_split/nyu/labeled/filename_list_train_small_100.txt @@ -0,0 +1,100 @@ +train/bathroom_0007/rgb_0649.png train/bathroom_0007/depth_0649.png train/bathroom_0007/filled_0649.png +train/bathroom_0010/rgb_0653.png train/bathroom_0010/depth_0653.png train/bathroom_0010/filled_0653.png +train/bathroom_0041/rgb_0719.png train/bathroom_0041/depth_0719.png train/bathroom_0041/filled_0719.png +train/bathroom_0045/rgb_0729.png train/bathroom_0045/depth_0729.png train/bathroom_0045/filled_0729.png +train/bathroom_0048/rgb_0736.png train/bathroom_0048/depth_0736.png train/bathroom_0048/filled_0736.png +train/bathroom_0056/rgb_0505.png train/bathroom_0056/depth_0505.png train/bathroom_0056/filled_0505.png +train/bedroom_0004/rgb_0178.png train/bedroom_0004/depth_0178.png train/bedroom_0004/filled_0178.png +train/bedroom_0016/rgb_0071.png train/bedroom_0016/depth_0071.png train/bedroom_0016/filled_0071.png +train/bedroom_0025/rgb_0910.png train/bedroom_0025/depth_0910.png train/bedroom_0025/filled_0910.png +train/bedroom_0026/rgb_0914.png train/bedroom_0026/depth_0914.png train/bedroom_0026/filled_0914.png +train/bedroom_0031/rgb_0929.png train/bedroom_0031/depth_0929.png train/bedroom_0031/filled_0929.png +train/bedroom_0034/rgb_0939.png train/bedroom_0034/depth_0939.png train/bedroom_0034/filled_0939.png +train/bedroom_0040/rgb_0954.png train/bedroom_0040/depth_0954.png train/bedroom_0040/filled_0954.png +train/bedroom_0042/rgb_0958.png train/bedroom_0042/depth_0958.png train/bedroom_0042/filled_0958.png +train/bedroom_0050/rgb_0978.png train/bedroom_0050/depth_0978.png train/bedroom_0050/filled_0978.png +train/bedroom_0051/rgb_0984.png train/bedroom_0051/depth_0984.png train/bedroom_0051/filled_0984.png +train/bedroom_0056/rgb_0997.png train/bedroom_0056/depth_0997.png train/bedroom_0056/filled_0997.png +train/bedroom_0060/rgb_1008.png train/bedroom_0060/depth_1008.png train/bedroom_0060/filled_1008.png +train/bedroom_0067/rgb_1029.png train/bedroom_0067/depth_1029.png train/bedroom_0067/filled_1029.png +train/bedroom_0072/rgb_1045.png train/bedroom_0072/depth_1045.png train/bedroom_0072/filled_1045.png +train/bedroom_0072/rgb_1046.png train/bedroom_0072/depth_1046.png train/bedroom_0072/filled_1046.png +train/bedroom_0079/rgb_1062.png train/bedroom_0079/depth_1062.png train/bedroom_0079/filled_1062.png +train/bedroom_0081/rgb_1072.png train/bedroom_0081/depth_1072.png train/bedroom_0081/filled_1072.png +train/bedroom_0096/rgb_1112.png train/bedroom_0096/depth_1112.png train/bedroom_0096/filled_1112.png +train/bedroom_0118/rgb_1173.png train/bedroom_0118/depth_1173.png train/bedroom_0118/filled_1173.png +train/bedroom_0129/rgb_1197.png train/bedroom_0129/depth_1197.png train/bedroom_0129/filled_1197.png +train/bedroom_0136/rgb_0527.png train/bedroom_0136/depth_0527.png train/bedroom_0136/filled_0527.png +train/bookstore_0000/rgb_0105.png train/bookstore_0000/depth_0105.png train/bookstore_0000/filled_0105.png +train/bookstore_0000/rgb_0107.png train/bookstore_0000/depth_0107.png train/bookstore_0000/filled_0107.png +train/bookstore_0002/rgb_0101.png train/bookstore_0002/depth_0101.png train/bookstore_0002/filled_0101.png +train/bookstore_0002/rgb_0103.png train/bookstore_0002/depth_0103.png train/bookstore_0002/filled_0103.png +train/classroom_0010/rgb_0305.png train/classroom_0010/depth_0305.png train/classroom_0010/filled_0305.png +train/classroom_0012/rgb_0309.png train/classroom_0012/depth_0309.png train/classroom_0012/filled_0309.png +train/conference_room_0001/rgb_0339.png train/conference_room_0001/depth_0339.png train/conference_room_0001/filled_0339.png +train/conference_room_0001/rgb_0341.png train/conference_room_0001/depth_0341.png train/conference_room_0001/filled_0341.png +train/dining_room_0002/rgb_1346.png train/dining_room_0002/depth_1346.png train/dining_room_0002/filled_1346.png +train/dining_room_0008/rgb_1363.png train/dining_room_0008/depth_1363.png train/dining_room_0008/filled_1363.png +train/dining_room_0012/rgb_1371.png train/dining_room_0012/depth_1371.png train/dining_room_0012/filled_1371.png +train/dining_room_0014/rgb_1377.png train/dining_room_0014/depth_1377.png train/dining_room_0014/filled_1377.png +train/dining_room_0015/rgb_1379.png train/dining_room_0015/depth_1379.png train/dining_room_0015/filled_1379.png +train/dining_room_0016/rgb_1382.png train/dining_room_0016/depth_1382.png train/dining_room_0016/filled_1382.png +train/dining_room_0031/rgb_1425.png train/dining_room_0031/depth_1425.png train/dining_room_0031/filled_1425.png +train/dining_room_0031/rgb_1426.png train/dining_room_0031/depth_1426.png train/dining_room_0031/filled_1426.png +train/dining_room_0033/rgb_1436.png train/dining_room_0033/depth_1436.png train/dining_room_0033/filled_1436.png +train/dining_room_0037/rgb_0548.png train/dining_room_0037/depth_0548.png train/dining_room_0037/filled_0548.png +train/furniture_store_0001/rgb_0224.png train/furniture_store_0001/depth_0224.png train/furniture_store_0001/filled_0224.png +train/furniture_store_0001/rgb_0237.png train/furniture_store_0001/depth_0237.png train/furniture_store_0001/filled_0237.png +train/furniture_store_0002/rgb_0249.png train/furniture_store_0002/depth_0249.png train/furniture_store_0002/filled_0249.png +train/home_office_0005/rgb_0368.png train/home_office_0005/depth_0368.png train/home_office_0005/filled_0368.png +train/home_office_0006/rgb_0374.png train/home_office_0006/depth_0374.png train/home_office_0006/filled_0374.png +train/home_office_0008/rgb_0380.png train/home_office_0008/depth_0380.png train/home_office_0008/filled_0380.png +train/home_office_0013/rgb_0554.png train/home_office_0013/depth_0554.png train/home_office_0013/filled_0554.png +train/kitchen_0010/rgb_0138.png train/kitchen_0010/depth_0138.png train/kitchen_0010/filled_0138.png +train/kitchen_0019/rgb_0750.png train/kitchen_0019/depth_0750.png train/kitchen_0019/filled_0750.png +train/kitchen_0019/rgb_0757.png train/kitchen_0019/depth_0757.png train/kitchen_0019/filled_0757.png +train/kitchen_0028/rgb_0788.png train/kitchen_0028/depth_0788.png train/kitchen_0028/filled_0788.png +train/kitchen_0028/rgb_0793.png train/kitchen_0028/depth_0793.png train/kitchen_0028/filled_0793.png +train/kitchen_0029/rgb_0799.png train/kitchen_0029/depth_0799.png train/kitchen_0029/filled_0799.png +train/kitchen_0033/rgb_0815.png train/kitchen_0033/depth_0815.png train/kitchen_0033/filled_0815.png +train/kitchen_0033/rgb_0816.png train/kitchen_0033/depth_0816.png train/kitchen_0033/filled_0816.png +train/kitchen_0037/rgb_0832.png train/kitchen_0037/depth_0832.png train/kitchen_0037/filled_0832.png +train/kitchen_0041/rgb_0849.png train/kitchen_0041/depth_0849.png train/kitchen_0041/filled_0849.png +train/kitchen_0047/rgb_0875.png train/kitchen_0047/depth_0875.png train/kitchen_0047/filled_0875.png +train/kitchen_0050/rgb_0887.png train/kitchen_0050/depth_0887.png train/kitchen_0050/filled_0887.png +train/kitchen_0051/rgb_0892.png train/kitchen_0051/depth_0892.png train/kitchen_0051/filled_0892.png +train/kitchen_0051/rgb_0893.png train/kitchen_0051/depth_0893.png train/kitchen_0051/filled_0893.png +train/kitchen_0052/rgb_0899.png train/kitchen_0052/depth_0899.png train/kitchen_0052/filled_0899.png +train/kitchen_0059/rgb_0573.png train/kitchen_0059/depth_0573.png train/kitchen_0059/filled_0573.png +train/living_room_0000/rgb_0050.png train/living_room_0000/depth_0050.png train/living_room_0000/filled_0050.png +train/living_room_0010/rgb_0156.png train/living_room_0010/depth_0156.png train/living_room_0010/filled_0156.png +train/living_room_0010/rgb_0158.png train/living_room_0010/depth_0158.png train/living_room_0010/filled_0158.png +train/living_room_0010/rgb_0159.png train/living_room_0010/depth_0159.png train/living_room_0010/filled_0159.png +train/living_room_0011/rgb_0162.png train/living_room_0011/depth_0162.png train/living_room_0011/filled_0162.png +train/living_room_0019/rgb_0258.png train/living_room_0019/depth_0258.png train/living_room_0019/filled_0258.png +train/living_room_0042/rgb_1251.png train/living_room_0042/depth_1251.png train/living_room_0042/filled_1251.png +train/living_room_0046/rgb_1268.png train/living_room_0046/depth_1268.png train/living_room_0046/filled_1268.png +train/living_room_0047/rgb_1272.png train/living_room_0047/depth_1272.png train/living_room_0047/filled_1272.png +train/living_room_0058/rgb_1301.png train/living_room_0058/depth_1301.png train/living_room_0058/filled_1301.png +train/living_room_0062/rgb_1310.png train/living_room_0062/depth_1310.png train/living_room_0062/filled_1310.png +train/living_room_0063/rgb_1313.png train/living_room_0063/depth_1313.png train/living_room_0063/filled_1313.png +train/living_room_0083/rgb_0588.png train/living_room_0083/depth_0588.png train/living_room_0083/filled_0588.png +train/living_room_0086/rgb_0601.png train/living_room_0086/depth_0601.png train/living_room_0086/filled_0601.png +train/office_0003/rgb_0004.png train/office_0003/depth_0004.png train/office_0003/filled_0004.png +train/office_0023/rgb_0623.png train/office_0023/depth_0623.png train/office_0023/filled_0623.png +train/office_0024/rgb_0627.png train/office_0024/depth_0627.png train/office_0024/filled_0627.png +train/office_kitchen_0003/rgb_0415.png train/office_kitchen_0003/depth_0415.png train/office_kitchen_0003/filled_0415.png +train/playroom_0002/rgb_0418.png train/playroom_0002/depth_0418.png train/playroom_0002/filled_0418.png +train/playroom_0003/rgb_0423.png train/playroom_0003/depth_0423.png train/playroom_0003/filled_0423.png +train/playroom_0003/rgb_0424.png train/playroom_0003/depth_0424.png train/playroom_0003/filled_0424.png +train/playroom_0003/rgb_0425.png train/playroom_0003/depth_0425.png train/playroom_0003/filled_0425.png +train/playroom_0004/rgb_0426.png train/playroom_0004/depth_0426.png train/playroom_0004/filled_0426.png +train/printer_room_0001/rgb_0451.png train/printer_room_0001/depth_0451.png train/printer_room_0001/filled_0451.png +train/reception_room_0001/rgb_0456.png train/reception_room_0001/depth_0456.png train/reception_room_0001/filled_0456.png +train/reception_room_0002/rgb_0459.png train/reception_room_0002/depth_0459.png train/reception_room_0002/filled_0459.png +train/reception_room_0004/rgb_0468.png train/reception_room_0004/depth_0468.png train/reception_room_0004/filled_0468.png +train/student_lounge_0001/rgb_0641.png train/student_lounge_0001/depth_0641.png train/student_lounge_0001/filled_0641.png +train/study_0003/rgb_0478.png train/study_0003/depth_0478.png train/study_0003/filled_0478.png +train/study_0005/rgb_0485.png train/study_0005/depth_0485.png train/study_0005/filled_0485.png +train/study_0008/rgb_0646.png train/study_0008/depth_0646.png train/study_0008/filled_0646.png +train/study_room_0004/rgb_0274.png train/study_room_0004/depth_0274.png train/study_room_0004/filled_0274.png \ No newline at end of file diff --git a/infer.py b/infer.py index 5db976f..4f88721 100644 --- a/infer.py +++ b/infer.py @@ -31,7 +31,7 @@ from tqdm.auto import tqdm from marigold import MarigoldPipeline -from src.util.seed_all import seed_all +from src.util.seeding import seed_all from src.dataset import ( BaseDepthDataset, DatasetMode, diff --git a/requirements++.txt b/requirements++.txt new file mode 100644 index 0000000..492dd06 --- /dev/null +++ b/requirements++.txt @@ -0,0 +1,6 @@ +h5py +opencv-python +tensorboard +wandb +xformers==0.0.21 + diff --git a/script/eval/11_infer_nyu.sh b/script/eval/11_infer_nyu.sh index 193d302..ad36be5 100644 --- a/script/eval/11_infer_nyu.sh +++ b/script/eval/11_infer_nyu.sh @@ -2,12 +2,16 @@ set -e set -x +# Use specified checkpoint path, otherwise, default value +ckpt=${1:-"prs-eth/marigold-v1-0"} +subfolder=${2:-"eval"} python infer.py \ + --checkpoint $ckpt \ --seed 1234 \ --base_data_dir $BASE_DATA_DIR \ --denoise_steps 50 \ --ensemble_size 10 \ --processing_res 0 \ --dataset_config config/dataset/data_nyu_test.yaml \ - --output_dir output/nyu_test/prediction \ + --output_dir output/${subfolder}/nyu_test/prediction \ diff --git a/script/eval/12_eval_nyu.sh b/script/eval/12_eval_nyu.sh index 61f7d30..69e52b9 100644 --- a/script/eval/12_eval_nyu.sh +++ b/script/eval/12_eval_nyu.sh @@ -2,10 +2,11 @@ set -e set -x +subfolder=${1:-"eval"} python eval.py \ --base_data_dir $BASE_DATA_DIR \ --dataset_config config/dataset/data_nyu_test.yaml \ --alignment least_square \ - --prediction_dir output/nyu_test/prediction \ - --output_dir output/nyu_test/eval_metric \ + --prediction_dir output/${subfolder}/nyu_test/prediction \ + --output_dir output/${subfolder}/nyu_test/eval_metric \ diff --git a/script/eval/21_infer_kitti.sh b/script/eval/21_infer_kitti.sh index 924ccf6..0eb6634 100644 --- a/script/eval/21_infer_kitti.sh +++ b/script/eval/21_infer_kitti.sh @@ -2,12 +2,16 @@ set -e set -x +# Use specified checkpoint path, otherwise, default value +ckpt=${1:-"prs-eth/marigold-v1-0"} +subfolder=${2:-"eval"} -python infer.py \ +python infer.py \ + --checkpoint $ckpt \ --seed 1234 \ --base_data_dir $BASE_DATA_DIR \ --denoise_steps 50 \ --ensemble_size 10 \ --processing_res 0 \ --dataset_config config/dataset/data_kitti_eigen_test.yaml \ - --output_dir output/kitti_eigen_test/prediction \ + --output_dir output/${subfolder}/kitti_eigen_test/prediction \ diff --git a/script/eval/22_eval_kitti.sh b/script/eval/22_eval_kitti.sh index 42793d2..69828e7 100644 --- a/script/eval/22_eval_kitti.sh +++ b/script/eval/22_eval_kitti.sh @@ -2,10 +2,11 @@ set -e set -x +subfolder=${1:-"eval"} python eval.py \ --base_data_dir $BASE_DATA_DIR \ --dataset_config config/dataset/data_kitti_eigen_test.yaml \ --alignment least_square \ - --prediction_dir output/kitti_eigen_test/prediction \ - --output_dir output/kitti_eigen_test/eval_metric \ + --prediction_dir output/${subfolder}/kitti_eigen_test/prediction \ + --output_dir output/${subfolder}/kitti_eigen_test/eval_metric \ diff --git a/script/eval/31_infer_eth3d.sh b/script/eval/31_infer_eth3d.sh index ff7ea2f..1dc7efa 100644 --- a/script/eval/31_infer_eth3d.sh +++ b/script/eval/31_infer_eth3d.sh @@ -2,13 +2,17 @@ set -e set -x +# Use specified checkpoint path, otherwise, default value +ckpt=${1:-"prs-eth/marigold-v1-0"} +subfolder=${2:-"eval"} -python infer.py \ +python infer.py \ + --checkpoint $ckpt \ --seed 1234 \ --base_data_dir $BASE_DATA_DIR \ --denoise_steps 50 \ --ensemble_size 10 \ --dataset_config config/dataset/data_eth3d.yaml \ - --output_dir output/eth3d/prediction \ + --output_dir output/${subfolder}/eth3d/prediction \ --processing_res 756 \ --resample_method bilinear \ \ No newline at end of file diff --git a/script/eval/32_eval_eth3d.sh b/script/eval/32_eval_eth3d.sh index a1aa554..f25c346 100644 --- a/script/eval/32_eval_eth3d.sh +++ b/script/eval/32_eval_eth3d.sh @@ -2,11 +2,12 @@ set -e set -x +subfolder=${1:-"eval"} python eval.py \ --base_data_dir $BASE_DATA_DIR \ --dataset_config config/dataset/data_eth3d.yaml \ --alignment least_square \ - --prediction_dir output/eth3d/prediction \ - --output_dir output/eth3d/eval_metric \ + --prediction_dir output/${subfolder}/eth3d/prediction \ + --output_dir output/${subfolder}/eth3d/eval_metric \ --alignment_max_res 1024 \ \ No newline at end of file diff --git a/script/eval/41_infer_scannet.sh b/script/eval/41_infer_scannet.sh index 0734bae..15006ea 100644 --- a/script/eval/41_infer_scannet.sh +++ b/script/eval/41_infer_scannet.sh @@ -2,12 +2,16 @@ set -e set -x +# Use specified checkpoint path, otherwise, default value +ckpt=${1:-"prs-eth/marigold-v1-0"} +subfolder=${2:-"eval"} -python infer.py \ +python infer.py \ + --checkpoint $ckpt \ --seed 1234 \ --base_data_dir $BASE_DATA_DIR \ --denoise_steps 50 \ --ensemble_size 10 \ --processing_res 0 \ --dataset_config config/dataset/data_scannet_val.yaml \ - --output_dir output/scannet/prediction \ + --output_dir output/${subfolder}/scannet/prediction \ diff --git a/script/eval/42_eval_scannet.sh b/script/eval/42_eval_scannet.sh index ea78c9a..da4c784 100644 --- a/script/eval/42_eval_scannet.sh +++ b/script/eval/42_eval_scannet.sh @@ -2,10 +2,11 @@ set -e set -x +subfolder=${1:-"eval"} python eval.py \ --base_data_dir $BASE_DATA_DIR \ --dataset_config config/dataset/data_scannet_val.yaml \ --alignment least_square \ - --prediction_dir output/scannet/prediction \ - --output_dir output/scannet/eval_metric \ + --prediction_dir output/${subfolder}/scannet/prediction \ + --output_dir output/${subfolder}/scannet/eval_metric \ diff --git a/script/eval/51_infer_diode.sh b/script/eval/51_infer_diode.sh index d0c9fca..2ec6cc2 100644 --- a/script/eval/51_infer_diode.sh +++ b/script/eval/51_infer_diode.sh @@ -2,13 +2,17 @@ set -e set -x +# Use specified checkpoint path, otherwise, default value +ckpt=${1:-"prs-eth/marigold-v1-0"} +subfolder=${2:-"eval"} -python infer.py \ +python infer.py \ + --checkpoint $ckpt \ --seed 1234 \ --base_data_dir $BASE_DATA_DIR \ --denoise_steps 50 \ --ensemble_size 10 \ --dataset_config config/dataset/data_diode_all.yaml \ - --output_dir output/diode/prediction \ + --output_dir output/${subfolder}/diode/prediction \ --processing_res 640 \ --resample_method bilinear \ diff --git a/script/eval/52_eval_diode.sh b/script/eval/52_eval_diode.sh index 9674353..c10c672 100644 --- a/script/eval/52_eval_diode.sh +++ b/script/eval/52_eval_diode.sh @@ -2,10 +2,11 @@ set -e set -x +subfolder=${1:-"eval"} python eval.py \ --base_data_dir $BASE_DATA_DIR \ --dataset_config config/dataset/data_diode_all.yaml \ --alignment least_square \ - --prediction_dir output/diode/prediction \ - --output_dir output/diode/eval_metric \ + --prediction_dir output/${subfolder}/diode/prediction \ + --output_dir output/${subfolder}/diode/eval_metric \ diff --git a/src/dataset/__init__.py b/src/dataset/__init__.py index dd7d502..6c57add 100644 --- a/src/dataset/__init__.py +++ b/src/dataset/__init__.py @@ -1,17 +1,40 @@ -# Author: Bingxin Ke -# Last modified: 2024-03-30 +# Last modified: 2024-04-16 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- import os from .base_depth_dataset import BaseDepthDataset, get_pred_name, DatasetMode # noqa: F401 from .diode_dataset import DIODEDataset from .eth3d_dataset import ETH3DDataset +from .hypersim_dataset import HypersimDataset from .kitti_dataset import KITTIDataset from .nyu_dataset import NYUDataset from .scannet_dataset import ScanNetDataset +from .vkitti_dataset import VirtualKITTIDataset dataset_name_class_dict = { + "hypersim": HypersimDataset, + "vkitti": VirtualKITTIDataset, "nyu_v2": NYUDataset, "kitti": KITTIDataset, "eth3d": ETH3DDataset, @@ -23,7 +46,14 @@ def get_dataset( cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs ) -> BaseDepthDataset: - if cfg_data_split.name in dataset_name_class_dict.keys(): + if "mixed" == cfg_data_split.name: + assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets." + dataset_ls = [ + get_dataset(_cfg, base_data_dir, mode, **kwargs) + for _cfg in cfg_data_split.dataset_list + ] + return dataset_ls + elif cfg_data_split.name in dataset_name_class_dict.keys(): dataset_class = dataset_name_class_dict[cfg_data_split.name] dataset = dataset_class( mode=mode, diff --git a/src/dataset/base_depth_dataset.py b/src/dataset/base_depth_dataset.py index 878c03a..11efb7f 100644 --- a/src/dataset/base_depth_dataset.py +++ b/src/dataset/base_depth_dataset.py @@ -1,11 +1,31 @@ -# Author: Bingxin Ke -# Last modified: 2024-04-15 +# Last modified: 2024-04-30 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- import io import os import random import tarfile from enum import Enum +from typing import Union import numpy as np import torch @@ -13,6 +33,8 @@ from torch.utils.data import Dataset from torchvision.transforms import InterpolationMode, Resize +from src.util.depth_transform import DepthNormalizerBase + class DatasetMode(Enum): RGB_ONLY = "rgb_only" @@ -20,6 +42,15 @@ class DatasetMode(Enum): TRAIN = "train" +class DepthFileNameMode(Enum): + """Prediction file naming modes""" + + id = 1 # id.png + rgb_id = 2 # rgb_id.png + i_d_rgb = 3 # i_d_1_rgb.png + rgb_i_d = 4 + + def read_image_from_tar(tar_obj, img_rel_path): image = tar_obj.extractfile("./" + img_rel_path) image = image.read() @@ -33,11 +64,11 @@ def __init__( filename_ls_path: str, dataset_dir: str, disp_name: str, - min_depth, - max_depth, - has_filled_depth, - name_mode, - depth_transform=None, + min_depth: float, + max_depth: float, + has_filled_depth: bool, + name_mode: DepthFileNameMode, + depth_transform: Union[DepthNormalizerBase, None] = None, augmentation_args: dict = None, resize_to_hw=None, move_invalid_to_far_plane: bool = True, @@ -49,6 +80,9 @@ def __init__( # dataset info self.filename_ls_path = filename_ls_path self.dataset_dir = dataset_dir + assert os.path.exists( + self.dataset_dir + ), f"Dataset does not exist at: {self.dataset_dir}" self.disp_name = disp_name self.has_filled_depth = has_filled_depth self.name_mode: DepthFileNameMode = name_mode @@ -56,7 +90,7 @@ def __init__( self.max_depth = max_depth # training arguments - self.depth_transform = depth_transform + self.depth_transform: DepthNormalizerBase = depth_transform self.augm_args = augmentation_args self.resize_to_hw = resize_to_hw self.rgb_transform = rgb_transform @@ -118,9 +152,11 @@ def _get_data_item(self, index): def _load_rgb_data(self, rgb_rel_path): # Read RGB data rgb = self._read_rgb_file(rgb_rel_path) + rgb_norm = rgb / 255.0 * 2.0 - 1.0 # [0, 255] -> [-1, 1] outputs = { "rgb_int": torch.from_numpy(rgb).int(), + "rgb_norm": torch.from_numpy(rgb_norm).float(), } return outputs @@ -157,12 +193,12 @@ def _read_image(self, img_rel_path) -> np.ndarray: if self.is_tar: if self.tar_obj is None: self.tar_obj = tarfile.open(self.dataset_dir) - image = self.tar_obj.extractfile("./" + img_rel_path) - image = image.read() - image = Image.open(io.BytesIO(image)) # [H, W, rgb] + image_to_read = self.tar_obj.extractfile("./" + img_rel_path) + image_to_read = image_to_read.read() + image_to_read = io.BytesIO(image_to_read) else: - img_path = os.path.join(self.dataset_dir, img_rel_path) - image = Image.open(img_path) + image_to_read = os.path.join(self.dataset_dir, img_rel_path) + image = Image.open(image_to_read) # [H, W, rgb] image = np.asarray(image) return image @@ -226,19 +262,11 @@ def _augment_data(self, rasters_dict): return rasters_dict def __del__(self): - if self.tar_obj is not None: + if hasattr(self, "tar_obj") and self.tar_obj is not None: self.tar_obj.close() self.tar_obj = None -# Prediction file naming modes -class DepthFileNameMode(Enum): - id = 1 # id.png - rgb_id = 2 # rgb_id.png - i_d_rgb = 3 # i_d_1_rgb.png - rgb_i_d = 4 - - def get_pred_name(rgb_basename, name_mode, suffix=".png"): if DepthFileNameMode.rgb_id == name_mode: pred_basename = "pred_" + rgb_basename.split("_")[1] diff --git a/src/dataset/diode_dataset.py b/src/dataset/diode_dataset.py index 81dc62f..509fb10 100644 --- a/src/dataset/diode_dataset.py +++ b/src/dataset/diode_dataset.py @@ -1,5 +1,24 @@ -# Author: Bingxin Ke # Last modified: 2024-02-26 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- import os import tarfile diff --git a/src/dataset/eth3d_dataset.py b/src/dataset/eth3d_dataset.py index 34acce0..02810b8 100644 --- a/src/dataset/eth3d_dataset.py +++ b/src/dataset/eth3d_dataset.py @@ -1,5 +1,24 @@ -# Author: Bingxin Ke # Last modified: 2024-02-08 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- import torch import tarfile diff --git a/src/dataset/hypersim_dataset.py b/src/dataset/hypersim_dataset.py new file mode 100644 index 0000000..886a2be --- /dev/null +++ b/src/dataset/hypersim_dataset.py @@ -0,0 +1,45 @@ +# Last modified: 2024-02-08 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + + +from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode + + +class HypersimDataset(BaseDepthDataset): + def __init__( + self, + **kwargs, + ) -> None: + super().__init__( + # Hypersim data parameter + min_depth=1e-5, + max_depth=65.0, + has_filled_depth=False, + name_mode=DepthFileNameMode.rgb_i_d, + **kwargs, + ) + + def _read_depth_file(self, rel_path): + depth_in = self._read_image(rel_path) + # Decode Hypersim depth + depth_decoded = depth_in / 1000.0 + return depth_decoded diff --git a/src/dataset/kitti_dataset.py b/src/dataset/kitti_dataset.py index 7cde0bb..5daa760 100644 --- a/src/dataset/kitti_dataset.py +++ b/src/dataset/kitti_dataset.py @@ -1,5 +1,24 @@ -# Author: Bingxin Ke # Last modified: 2024-02-08 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- import torch diff --git a/src/dataset/mixed_sampler.py b/src/dataset/mixed_sampler.py new file mode 100644 index 0000000..3abc60f --- /dev/null +++ b/src/dataset/mixed_sampler.py @@ -0,0 +1,149 @@ +# Last modified: 2024-04-18 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +import torch +from torch.utils.data import ( + BatchSampler, + RandomSampler, + SequentialSampler, +) + + +class MixedBatchSampler(BatchSampler): + """Sample one batch from a selected dataset with given probability. + Compatible with datasets at different resolution + """ + + def __init__( + self, src_dataset_ls, batch_size, drop_last, shuffle, prob=None, generator=None + ): + self.base_sampler = None + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + self.generator = generator + + self.src_dataset_ls = src_dataset_ls + self.n_dataset = len(self.src_dataset_ls) + + # Dataset length + self.dataset_length = [len(ds) for ds in self.src_dataset_ls] + self.cum_dataset_length = [ + sum(self.dataset_length[:i]) for i in range(self.n_dataset) + ] # cumulative dataset length + + # BatchSamplers for each source dataset + if self.shuffle: + self.src_batch_samplers = [ + BatchSampler( + sampler=RandomSampler( + ds, replacement=False, generator=self.generator + ), + batch_size=self.batch_size, + drop_last=self.drop_last, + ) + for ds in self.src_dataset_ls + ] + else: + self.src_batch_samplers = [ + BatchSampler( + sampler=SequentialSampler(ds), + batch_size=self.batch_size, + drop_last=self.drop_last, + ) + for ds in self.src_dataset_ls + ] + self.raw_batches = [ + list(bs) for bs in self.src_batch_samplers + ] # index in original dataset + self.n_batches = [len(b) for b in self.raw_batches] + self.n_total_batch = sum(self.n_batches) + + # sampling probability + if prob is None: + # if not given, decide by dataset length + self.prob = torch.tensor(self.n_batches) / self.n_total_batch + else: + self.prob = torch.as_tensor(prob) + + def __iter__(self): + """_summary_ + + Yields: + list(int): a batch of indics, corresponding to ConcatDataset of src_dataset_ls + """ + for _ in range(self.n_total_batch): + idx_ds = torch.multinomial( + self.prob, 1, replacement=True, generator=self.generator + ).item() + # if batch list is empty, generate new list + if 0 == len(self.raw_batches[idx_ds]): + self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds]) + # get a batch from list + batch_raw = self.raw_batches[idx_ds].pop() + # shift by cumulative dataset length + shift = self.cum_dataset_length[idx_ds] + batch = [n + shift for n in batch_raw] + + yield batch + + def __len__(self): + return self.n_total_batch + + +# Unit test +if "__main__" == __name__: + from torch.utils.data import ConcatDataset, DataLoader, Dataset + + class SimpleDataset(Dataset): + def __init__(self, start, len) -> None: + super().__init__() + self.start = start + self.len = len + + def __len__(self): + return self.len + + def __getitem__(self, index): + return self.start + index + + dataset_1 = SimpleDataset(0, 10) + dataset_2 = SimpleDataset(200, 20) + dataset_3 = SimpleDataset(1000, 50) + + concat_dataset = ConcatDataset( + [dataset_1, dataset_2, dataset_3] + ) # will directly concatenate + + mixed_sampler = MixedBatchSampler( + src_dataset_ls=[dataset_1, dataset_2, dataset_3], + batch_size=4, + drop_last=True, + shuffle=False, + prob=[0.6, 0.3, 0.1], + generator=torch.Generator().manual_seed(0), + ) + + loader = DataLoader(concat_dataset, batch_sampler=mixed_sampler) + + for d in loader: + print(d) diff --git a/src/dataset/nyu_dataset.py b/src/dataset/nyu_dataset.py index e073c4c..e3f1e80 100644 --- a/src/dataset/nyu_dataset.py +++ b/src/dataset/nyu_dataset.py @@ -1,6 +1,24 @@ -# Author: Bingxin Ke # Last modified: 2024-02-08 - +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- import torch diff --git a/src/dataset/scannet_dataset.py b/src/dataset/scannet_dataset.py index 251ba66..c401f92 100644 --- a/src/dataset/scannet_dataset.py +++ b/src/dataset/scannet_dataset.py @@ -1,5 +1,24 @@ -# Author: Bingxin Ke # Last modified: 2024-02-08 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode diff --git a/src/dataset/vkitti_dataset.py b/src/dataset/vkitti_dataset.py new file mode 100644 index 0000000..ceb7903 --- /dev/null +++ b/src/dataset/vkitti_dataset.py @@ -0,0 +1,98 @@ +# Last modified: 2024-02-08 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +import torch + +from .base_depth_dataset import BaseDepthDataset, DepthFileNameMode +from .kitti_dataset import KITTIDataset + + +class VirtualKITTIDataset(BaseDepthDataset): + def __init__( + self, + kitti_bm_crop, # Crop to KITTI benchmark size + valid_mask_crop, # Evaluation mask. [None, garg or eigen] + **kwargs, + ) -> None: + super().__init__( + # virtual KITTI data parameter + min_depth=1e-5, + max_depth=80, # 655.35 + has_filled_depth=False, + name_mode=DepthFileNameMode.id, + **kwargs, + ) + self.kitti_bm_crop = kitti_bm_crop + self.valid_mask_crop = valid_mask_crop + assert self.valid_mask_crop in [ + None, + "garg", # set evaluation mask according to Garg ECCV16 + "eigen", # set evaluation mask according to Eigen NIPS14 + ], f"Unknown crop type: {self.valid_mask_crop}" + + # Filter out empty depth + self.filenames = [f for f in self.filenames if "None" != f[1]] + + def _read_depth_file(self, rel_path): + depth_in = self._read_image(rel_path) + # Decode vKITTI depth + depth_decoded = depth_in / 100.0 + return depth_decoded + + def _load_rgb_data(self, rgb_rel_path): + rgb_data = super()._load_rgb_data(rgb_rel_path) + if self.kitti_bm_crop: + rgb_data = { + k: KITTIDataset.kitti_benchmark_crop(v) for k, v in rgb_data.items() + } + return rgb_data + + def _load_depth_data(self, depth_rel_path, filled_rel_path): + depth_data = super()._load_depth_data(depth_rel_path, filled_rel_path) + if self.kitti_bm_crop: + depth_data = { + k: KITTIDataset.kitti_benchmark_crop(v) for k, v in depth_data.items() + } + return depth_data + + def _get_valid_mask(self, depth: torch.Tensor): + # reference: https://github.com/cleinc/bts/blob/master/pytorch/bts_eval.py + valid_mask = super()._get_valid_mask(depth) # [1, H, W] + + if self.valid_mask_crop is not None: + eval_mask = torch.zeros_like(valid_mask.squeeze()).bool() + gt_height, gt_width = eval_mask.shape + + if "garg" == self.valid_mask_crop: + eval_mask[ + int(0.40810811 * gt_height) : int(0.99189189 * gt_height), + int(0.03594771 * gt_width) : int(0.96405229 * gt_width), + ] = 1 + elif "eigen" == self.valid_mask_crop: + eval_mask[ + int(0.3324324 * gt_height) : int(0.91351351 * gt_height), + int(0.0359477 * gt_width) : int(0.96405229 * gt_width), + ] = 1 + + eval_mask.reshape(valid_mask.shape) + valid_mask = torch.logical_and(valid_mask, eval_mask) + return valid_mask diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py new file mode 100644 index 0000000..435ea49 --- /dev/null +++ b/src/trainer/__init__.py @@ -0,0 +1,13 @@ +# Author: Bingxin Ke +# Last modified: 2024-05-17 + +from .marigold_trainer import MarigoldTrainer + + +trainer_cls_name_dict = { + "MarigoldTrainer": MarigoldTrainer, +} + + +def get_trainer_cls(trainer_name): + return trainer_cls_name_dict[trainer_name] diff --git a/src/trainer/marigold_trainer.py b/src/trainer/marigold_trainer.py new file mode 100644 index 0000000..a1596f8 --- /dev/null +++ b/src/trainer/marigold_trainer.py @@ -0,0 +1,674 @@ +# An official reimplemented version of Marigold training script. +# Last modified: 2024-04-29 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold. +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + + +import logging +import os +import shutil +from datetime import datetime +from typing import List, Union + +import numpy as np +import torch +from diffusers import DDPMScheduler +from omegaconf import OmegaConf +from torch.nn import Conv2d +from torch.nn.parameter import Parameter +from torch.optim import Adam +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader +from tqdm import tqdm +from PIL import Image + +from marigold.marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput +from src.util import metric +from src.util.data_loader import skip_first_batches +from src.util.logging_util import tb_logger, eval_dic_to_text +from src.util.loss import get_loss +from src.util.lr_scheduler import IterExponential +from src.util.metric import MetricTracker +from src.util.multi_res_noise import multi_res_noise_like +from src.util.alignment import align_depth_least_square +from src.util.seeding import generate_seed_sequence + + +class MarigoldTrainer: + def __init__( + self, + cfg: OmegaConf, + model: MarigoldPipeline, + train_dataloader: DataLoader, + device, + base_ckpt_dir, + out_dir_ckpt, + out_dir_eval, + out_dir_vis, + accumulation_steps: int, + val_dataloaders: List[DataLoader] = None, + vis_dataloaders: List[DataLoader] = None, + ): + self.cfg: OmegaConf = cfg + self.model: MarigoldPipeline = model + self.device = device + self.seed: Union[int, None] = ( + self.cfg.trainer.init_seed + ) # used to generate seed sequence, set to `None` to train w/o seeding + self.out_dir_ckpt = out_dir_ckpt + self.out_dir_eval = out_dir_eval + self.out_dir_vis = out_dir_vis + self.train_loader: DataLoader = train_dataloader + self.val_loaders: List[DataLoader] = val_dataloaders + self.vis_loaders: List[DataLoader] = vis_dataloaders + self.accumulation_steps: int = accumulation_steps + + # Adapt input layers + if 8 != self.model.unet.config["in_channels"]: + self._replace_unet_conv_in() + + # Encode empty text prompt + self.model.encode_empty_text() + self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device) + + self.model.unet.enable_xformers_memory_efficient_attention() + + # Trainability + self.model.vae.requires_grad_(False) + self.model.text_encoder.requires_grad_(False) + self.model.unet.requires_grad_(True) + + # Optimizer !should be defined after input layer is adapted + lr = self.cfg.lr + self.optimizer = Adam(self.model.unet.parameters(), lr=lr) + + # LR scheduler + lr_func = IterExponential( + total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter, + final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio, + warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps, + ) + self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func) + + # Loss + self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs) + + # Training noise scheduler + self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained( + os.path.join( + base_ckpt_dir, + cfg.trainer.training_noise_scheduler.pretrained_path, + "scheduler", + ) + ) + self.prediction_type = self.training_noise_scheduler.config.prediction_type + assert ( + self.prediction_type == self.model.scheduler.config.prediction_type + ), "Different prediction types" + self.scheduler_timesteps = ( + self.training_noise_scheduler.config.num_train_timesteps + ) + + # Eval metrics + self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics] + self.train_metrics = MetricTracker(*["loss"]) + self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs]) + # main metric for best checkpoint saving + self.main_val_metric = cfg.validation.main_val_metric + self.main_val_metric_goal = cfg.validation.main_val_metric_goal + assert ( + self.main_val_metric in cfg.eval.eval_metrics + ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics." + self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8 + + # Settings + self.max_epoch = self.cfg.max_epoch + self.max_iter = self.cfg.max_iter + self.gradient_accumulation_steps = accumulation_steps + self.gt_depth_type = self.cfg.gt_depth_type + self.gt_mask_type = self.cfg.gt_mask_type + self.save_period = self.cfg.trainer.save_period + self.backup_period = self.cfg.trainer.backup_period + self.val_period = self.cfg.trainer.validation_period + self.vis_period = self.cfg.trainer.visualization_period + + # Multi-resolution noise + self.apply_multi_res_noise = self.cfg.multi_res_noise is not None + if self.apply_multi_res_noise: + self.mr_noise_strength = self.cfg.multi_res_noise.strength + self.annealed_mr_noise = self.cfg.multi_res_noise.annealed + self.mr_noise_downscale_strategy = ( + self.cfg.multi_res_noise.downscale_strategy + ) + + # Internal variables + self.epoch = 1 + self.n_batch_in_epoch = 0 # batch index in the epoch, used when resume training + self.effective_iter = 0 # how many times optimizer.step() is called + self.in_evaluation = False + self.global_seed_sequence: List = [] # consistent global seed sequence, used to seed random generator, to ensure consistency when resuming + + def _replace_unet_conv_in(self): + # replace the first layer to accept 8 in_channels + _weight = self.model.unet.conv_in.weight.clone() # [320, 4, 3, 3] + _bias = self.model.unet.conv_in.bias.clone() # [320] + _weight = _weight.repeat((1, 2, 1, 1)) # Keep selected channel(s) + # half the activation magnitude + _weight *= 0.5 + # new conv_in channel + _n_convin_out_channel = self.model.unet.conv_in.out_channels + _new_conv_in = Conv2d( + 8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ) + _new_conv_in.weight = Parameter(_weight) + _new_conv_in.bias = Parameter(_bias) + self.model.unet.conv_in = _new_conv_in + logging.info("Unet conv_in layer is replaced") + # replace config + self.model.unet.config["in_channels"] = 8 + logging.info("Unet config is updated") + return + + def train(self, t_end=None): + logging.info("Start training") + + device = self.device + self.model.to(device) + + if self.in_evaluation: + logging.info( + "Last evaluation was not finished, will do evaluation before continue training." + ) + self.validate() + + self.train_metrics.reset() + accumulated_step = 0 + + for epoch in range(self.epoch, self.max_epoch + 1): + self.epoch = epoch + logging.debug(f"epoch: {self.epoch}") + + # Skip previous batches when resume + for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch): + self.model.unet.train() + + # globally consistent random generators + if self.seed is not None: + local_seed = self._get_next_seed() + rand_num_generator = torch.Generator(device=device) + rand_num_generator.manual_seed(local_seed) + else: + rand_num_generator = None + + # >>> With gradient accumulation >>> + + # Get data + rgb = batch["rgb_norm"].to(device) + depth_gt_for_latent = batch[self.gt_depth_type].to(device) + + if self.gt_mask_type is not None: + valid_mask_for_latent = batch[self.gt_mask_type].to(device) + invalid_mask = ~valid_mask_for_latent + valid_mask_down = ~torch.max_pool2d( + invalid_mask.float(), 8, 8 + ).bool() + valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1)) + else: + raise NotImplementedError + + batch_size = rgb.shape[0] + + with torch.no_grad(): + # Encode image + rgb_latent = self.model.encode_rgb(rgb) # [B, 4, h, w] + # Encode GT depth + gt_depth_latent = self.encode_depth( + depth_gt_for_latent + ) # [B, 4, h, w] + + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + self.scheduler_timesteps, + (batch_size,), + device=device, + generator=rand_num_generator, + ).long() # [B] + + # Sample noise + if self.apply_multi_res_noise: + strength = self.mr_noise_strength + if self.annealed_mr_noise: + # calculate strength depending on t + strength = strength * (timesteps / self.scheduler_timesteps) + noise = multi_res_noise_like( + gt_depth_latent, + strength=strength, + downscale_strategy=self.mr_noise_downscale_strategy, + generator=rand_num_generator, + device=device, + ) + else: + noise = torch.randn( + gt_depth_latent.shape, + device=device, + generator=rand_num_generator, + ) # [B, 4, h, w] + + # Add noise to the latents (diffusion forward process) + noisy_latents = self.training_noise_scheduler.add_noise( + gt_depth_latent, noise, timesteps + ) # [B, 4, h, w] + + # Text embedding + text_embed = self.empty_text_embed.to(device).repeat( + (batch_size, 1, 1) + ) # [B, 77, 1024] + + # Concat rgb and depth latents + cat_latents = torch.cat( + [rgb_latent, noisy_latents], dim=1 + ) # [B, 8, h, w] + cat_latents = cat_latents.float() + + # Predict the noise residual + model_pred = self.model.unet( + cat_latents, timesteps, text_embed + ).sample # [B, 4, h, w] + if torch.isnan(model_pred).any(): + logging.warning("model_pred contains NaN.") + + # Get the target for loss depending on the prediction type + if "sample" == self.prediction_type: + target = gt_depth_latent + elif "epsilon" == self.prediction_type: + target = noise + elif "v_prediction" == self.prediction_type: + target = self.training_noise_scheduler.get_velocity( + gt_depth_latent, noise, timesteps + ) # [B, 4, h, w] + else: + raise ValueError(f"Unknown prediction type {self.prediction_type}") + + # Masked latent loss + if self.gt_mask_type is not None: + latent_loss = self.loss( + model_pred[valid_mask_down].float(), + target[valid_mask_down].float(), + ) + else: + latent_loss = self.loss(model_pred.float(), target.float()) + + loss = latent_loss.mean() + + self.train_metrics.update("loss", loss.item()) + + loss = loss / self.gradient_accumulation_steps + loss.backward() + accumulated_step += 1 + + self.n_batch_in_epoch += 1 + # Practical batch end + + # Perform optimization step + if accumulated_step >= self.gradient_accumulation_steps: + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + accumulated_step = 0 + + self.effective_iter += 1 + + # Log to tensorboard + accumulated_loss = self.train_metrics.result()["loss"] + tb_logger.log_dic( + { + f"train/{k}": v + for k, v in self.train_metrics.result().items() + }, + global_step=self.effective_iter, + ) + tb_logger.writer.add_scalar( + "lr", + self.lr_scheduler.get_last_lr()[0], + global_step=self.effective_iter, + ) + tb_logger.writer.add_scalar( + "n_batch_in_epoch", + self.n_batch_in_epoch, + global_step=self.effective_iter, + ) + logging.info( + f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}" + ) + self.train_metrics.reset() + + # Per-step callback + self._train_step_callback() + + # End of training + if self.max_iter > 0 and self.effective_iter >= self.max_iter: + self.save_checkpoint( + ckpt_name=self._get_backup_ckpt_name(), + save_train_state=False, + ) + logging.info("Training ended.") + return + # Time's up + elif t_end is not None and datetime.now() >= t_end: + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + logging.info("Time is up, training paused.") + return + + torch.cuda.empty_cache() + # <<< Effective batch end <<< + + # Epoch end + self.n_batch_in_epoch = 0 + + def encode_depth(self, depth_in): + # stack depth into 3-channel + stacked = self.stack_depth_images(depth_in) + # encode using VAE encoder + depth_latent = self.model.encode_rgb(stacked) + return depth_latent + + @staticmethod + def stack_depth_images(depth_in): + if 4 == len(depth_in.shape): + stacked = depth_in.repeat(1, 3, 1, 1) + elif 3 == len(depth_in.shape): + stacked = depth_in.unsqueeze(1) + stacked = depth_in.repeat(1, 3, 1, 1) + return stacked + + def _train_step_callback(self): + """Executed after every iteration""" + # Save backup (with a larger interval, without training states) + if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period: + self.save_checkpoint( + ckpt_name=self._get_backup_ckpt_name(), save_train_state=False + ) + + _is_latest_saved = False + # Validation + if self.val_period > 0 and 0 == self.effective_iter % self.val_period: + self.in_evaluation = True # flag to do evaluation in resume run if validation is not finished + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + _is_latest_saved = True + self.validate() + self.in_evaluation = False + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + + # Save training checkpoint (can be resumed) + if ( + self.save_period > 0 + and 0 == self.effective_iter % self.save_period + and not _is_latest_saved + ): + self.save_checkpoint(ckpt_name="latest", save_train_state=True) + + # Visualization + if self.vis_period > 0 and 0 == self.effective_iter % self.vis_period: + self.visualize() + + def validate(self): + for i, val_loader in enumerate(self.val_loaders): + val_dataset_name = val_loader.dataset.disp_name + val_metric_dic = self.validate_single_dataset( + data_loader=val_loader, metric_tracker=self.val_metrics + ) + logging.info( + f"Iter {self.effective_iter}. Validation metrics on `{val_dataset_name}`: {val_metric_dic}" + ) + tb_logger.log_dic( + {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()}, + global_step=self.effective_iter, + ) + # save to file + eval_text = eval_dic_to_text( + val_metrics=val_metric_dic, + dataset_name=val_dataset_name, + sample_list_path=val_loader.dataset.filename_ls_path, + ) + _save_to = os.path.join( + self.out_dir_eval, + f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt", + ) + with open(_save_to, "w+") as f: + f.write(eval_text) + + # Update main eval metric + if 0 == i: + main_eval_metric = val_metric_dic[self.main_val_metric] + if ( + "minimize" == self.main_val_metric_goal + and main_eval_metric < self.best_metric + or "maximize" == self.main_val_metric_goal + and main_eval_metric > self.best_metric + ): + self.best_metric = main_eval_metric + logging.info( + f"Best metric: {self.main_val_metric} = {self.best_metric} at iteration {self.effective_iter}" + ) + # Save a checkpoint + self.save_checkpoint( + ckpt_name=self._get_backup_ckpt_name(), save_train_state=False + ) + + def visualize(self): + for val_loader in self.vis_loaders: + vis_dataset_name = val_loader.dataset.disp_name + vis_out_dir = os.path.join( + self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name + ) + os.makedirs(vis_out_dir, exist_ok=True) + _ = self.validate_single_dataset( + data_loader=val_loader, + metric_tracker=self.val_metrics, + save_to_dir=vis_out_dir, + ) + + @torch.no_grad() + def validate_single_dataset( + self, + data_loader: DataLoader, + metric_tracker: MetricTracker, + save_to_dir: str = None, + ): + self.model.to(self.device) + metric_tracker.reset() + + # Generate seed sequence for consistent evaluation + val_init_seed = self.cfg.validation.init_seed + val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader)) + + for i, batch in enumerate( + tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"), + start=1, + ): + assert 1 == data_loader.batch_size + # Read input image + rgb_int = batch["rgb_int"].squeeze() # [3, H, W] + # GT depth + depth_raw_ts = batch["depth_raw_linear"].squeeze() + depth_raw = depth_raw_ts.numpy() + depth_raw_ts = depth_raw_ts.to(self.device) + valid_mask_ts = batch["valid_mask_raw"].squeeze() + valid_mask = valid_mask_ts.numpy() + valid_mask_ts = valid_mask_ts.to(self.device) + + # Random number generator + seed = val_seed_ls.pop() + if seed is None: + generator = None + else: + generator = torch.Generator(device=self.device) + generator.manual_seed(seed) + + # Predict depth + pipe_out: MarigoldDepthOutput = self.model( + rgb_int, + denoising_steps=self.cfg.validation.denoising_steps, + ensemble_size=self.cfg.validation.ensemble_size, + processing_res=self.cfg.validation.processing_res, + match_input_res=self.cfg.validation.match_input_res, + generator=generator, + batch_size=1, # use batch size 1 to increase reproducibility + color_map=None, + show_progress_bar=False, + resample_method=self.cfg.validation.resample_method, + ) + + depth_pred: np.ndarray = pipe_out.depth_np + + if "least_square" == self.cfg.eval.alignment: + depth_pred, scale, shift = align_depth_least_square( + gt_arr=depth_raw, + pred_arr=depth_pred, + valid_mask_arr=valid_mask, + return_scale_shift=True, + max_resolution=self.cfg.eval.align_max_res, + ) + else: + raise RuntimeError(f"Unknown alignment type: {self.cfg.eval.alignment}") + + # Clip to dataset min max + depth_pred = np.clip( + depth_pred, + a_min=data_loader.dataset.min_depth, + a_max=data_loader.dataset.max_depth, + ) + + # clip to d > 0 for evaluation + depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) + + # Evaluate + sample_metric = [] + depth_pred_ts = torch.from_numpy(depth_pred).to(self.device) + + for met_func in self.metric_funcs: + _metric_name = met_func.__name__ + _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).item() + sample_metric.append(_metric.__str__()) + metric_tracker.update(_metric_name, _metric) + + # Save as 16-bit uint png + if save_to_dir is not None: + img_name = batch["rgb_relative_path"][0].replace("/", "_") + png_save_path = os.path.join(save_to_dir, f"{img_name}.png") + depth_to_save = (pipe_out.depth_np * 65535.0).astype(np.uint16) + Image.fromarray(depth_to_save).save(png_save_path, mode="I;16") + + return metric_tracker.result() + + def _get_next_seed(self): + if 0 == len(self.global_seed_sequence): + self.global_seed_sequence = generate_seed_sequence( + initial_seed=self.seed, + length=self.max_iter * self.gradient_accumulation_steps, + ) + logging.info( + f"Global seed sequence is generated, length={len(self.global_seed_sequence)}" + ) + return self.global_seed_sequence.pop() + + def save_checkpoint(self, ckpt_name, save_train_state): + ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name) + logging.info(f"Saving checkpoint to: {ckpt_dir}") + # Backup previous checkpoint + temp_ckpt_dir = None + if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir): + temp_ckpt_dir = os.path.join( + os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}" + ) + if os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + os.rename(ckpt_dir, temp_ckpt_dir) + logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}") + + # Save UNet + unet_path = os.path.join(ckpt_dir, "unet") + self.model.unet.save_pretrained(unet_path, safe_serialization=False) + logging.info(f"UNet is saved to: {unet_path}") + + if save_train_state: + state = { + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.lr_scheduler.state_dict(), + "config": self.cfg, + "effective_iter": self.effective_iter, + "epoch": self.epoch, + "n_batch_in_epoch": self.n_batch_in_epoch, + "best_metric": self.best_metric, + "in_evaluation": self.in_evaluation, + "global_seed_sequence": self.global_seed_sequence, + } + train_state_path = os.path.join(ckpt_dir, "trainer.ckpt") + torch.save(state, train_state_path) + # iteration indicator + f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w") + f.close() + + logging.info(f"Trainer state is saved to: {train_state_path}") + + # Remove temp ckpt + if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir): + shutil.rmtree(temp_ckpt_dir, ignore_errors=True) + logging.debug("Old checkpoint backup is removed.") + + def load_checkpoint( + self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True + ): + logging.info(f"Loading checkpoint from: {ckpt_path}") + # Load UNet + _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin") + self.model.unet.load_state_dict( + torch.load(_model_path, map_location=self.device) + ) + self.model.unet.to(self.device) + logging.info(f"UNet parameters are loaded from {_model_path}") + + # Load training states + if load_trainer_state: + checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt")) + self.effective_iter = checkpoint["effective_iter"] + self.epoch = checkpoint["epoch"] + self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"] + self.in_evaluation = checkpoint["in_evaluation"] + self.global_seed_sequence = checkpoint["global_seed_sequence"] + + self.best_metric = checkpoint["best_metric"] + + self.optimizer.load_state_dict(checkpoint["optimizer"]) + logging.info(f"optimizer state is loaded from {ckpt_path}") + + if resume_lr_scheduler: + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + logging.info(f"LR scheduler state is loaded from {ckpt_path}") + + logging.info( + f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})" + ) + return + + def _get_backup_ckpt_name(self): + return f"iter_{self.effective_iter:06d}" diff --git a/src/util/config_util.py b/src/util/config_util.py new file mode 100644 index 0000000..0e9fa45 --- /dev/null +++ b/src/util/config_util.py @@ -0,0 +1,49 @@ +# Author: Bingxin Ke +# Last modified: 2024-02-14 + +import omegaconf +from omegaconf import OmegaConf + + +def recursive_load_config(config_path: str) -> OmegaConf: + conf = OmegaConf.load(config_path) + + output_conf = OmegaConf.create({}) + + # Load base config. Later configs on the list will overwrite previous + base_configs = conf.get("base_config", default_value=None) + if base_configs is not None: + assert isinstance(base_configs, omegaconf.listconfig.ListConfig) + for _path in base_configs: + assert ( + _path != config_path + ), "Circulate merging, base_config should not include itself." + _base_conf = recursive_load_config(_path) + output_conf = OmegaConf.merge(output_conf, _base_conf) + + # Merge configs and overwrite values + output_conf = OmegaConf.merge(output_conf, conf) + + return output_conf + + +def find_value_in_omegaconf(search_key, config): + result_list = [] + + if isinstance(config, omegaconf.DictConfig): + for key, value in config.items(): + if key == search_key: + result_list.append(value) + elif isinstance(value, (omegaconf.DictConfig, omegaconf.ListConfig)): + result_list.extend(find_value_in_omegaconf(search_key, value)) + elif isinstance(config, omegaconf.ListConfig): + for item in config: + if isinstance(item, (omegaconf.DictConfig, omegaconf.ListConfig)): + result_list.extend(find_value_in_omegaconf(search_key, item)) + + return result_list + + +if "__main__" == __name__: + conf = recursive_load_config("config/train_base.yaml") + print(OmegaConf.to_yaml(conf)) diff --git a/src/util/data_loader.py b/src/util/data_loader.py new file mode 100644 index 0000000..0fe42ab --- /dev/null +++ b/src/util/data_loader.py @@ -0,0 +1,111 @@ +# Copied from https://github.com/huggingface/accelerate/blob/e2ae254008061b3e53fc1c97f88d65743a857e75/src/accelerate/data_loader.py + +from torch.utils.data import BatchSampler, DataLoader, IterableDataset + +# kwargs of the DataLoader in min version 1.4.0. +_PYTORCH_DATALOADER_KWARGS = { + "batch_size": 1, + "shuffle": False, + "sampler": None, + "batch_sampler": None, + "num_workers": 0, + "collate_fn": None, + "pin_memory": False, + "drop_last": False, + "timeout": 0, + "worker_init_fn": None, + "multiprocessing_context": None, + "generator": None, + "prefetch_factor": 2, + "persistent_workers": False, +} + + +class SkipBatchSampler(BatchSampler): + """ + A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`. + """ + + def __init__(self, batch_sampler, skip_batches=0): + self.batch_sampler = batch_sampler + self.skip_batches = skip_batches + + def __iter__(self): + for index, samples in enumerate(self.batch_sampler): + if index >= self.skip_batches: + yield samples + + @property + def total_length(self): + return len(self.batch_sampler) + + def __len__(self): + return len(self.batch_sampler) - self.skip_batches + + +class SkipDataLoader(DataLoader): + """ + Subclass of a PyTorch `DataLoader` that will skip the first batches. + + Args: + dataset (`torch.utils.data.dataset.Dataset`): + The dataset to use to build this datalaoder. + skip_batches (`int`, *optional*, defaults to 0): + The number of batches to skip at the beginning. + kwargs: + All other keyword arguments to pass to the regular `DataLoader` initialization. + """ + + def __init__(self, dataset, skip_batches=0, **kwargs): + super().__init__(dataset, **kwargs) + self.skip_batches = skip_batches + + def __iter__(self): + for index, batch in enumerate(super().__iter__()): + if index >= self.skip_batches: + yield batch + + +# Adapted from https://github.com/huggingface/accelerate +def skip_first_batches(dataloader, num_batches=0): + """ + Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. + """ + dataset = dataloader.dataset + sampler_is_batch_sampler = False + if isinstance(dataset, IterableDataset): + new_batch_sampler = None + else: + sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler) + batch_sampler = ( + dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler + ) + new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches) + + # We ignore all of those since they are all dealt with by our new_batch_sampler + ignore_kwargs = [ + "batch_size", + "shuffle", + "sampler", + "batch_sampler", + "drop_last", + ] + + kwargs = { + k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) + for k in _PYTORCH_DATALOADER_KWARGS + if k not in ignore_kwargs + } + + # Need to provide batch_size as batch_sampler is None for Iterable dataset + if new_batch_sampler is None: + kwargs["drop_last"] = dataloader.drop_last + kwargs["batch_size"] = dataloader.batch_size + + if new_batch_sampler is None: + # Need to manually skip batches in the dataloader + dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs) + else: + dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) + + return dataloader diff --git a/src/util/depth_transform.py b/src/util/depth_transform.py index 6062f59..ac9d626 100644 --- a/src/util/depth_transform.py +++ b/src/util/depth_transform.py @@ -1,7 +1,8 @@ # Author: Bingxin Ke -# Last modified: 2024-02-08 +# Last modified: 2024-04-18 import torch +import logging def get_depth_normalizer(cfg_normalizer): @@ -12,8 +13,8 @@ def identical(x): depth_transform = identical - elif "near_far_metric" == cfg_normalizer.type: - depth_transform = NearFarMetricNormalizer( + elif "scale_shift_depth" == cfg_normalizer.type: + depth_transform = ScaleShiftDepthNormalizer( norm_min=cfg_normalizer.norm_min, norm_max=cfg_normalizer.norm_max, min_max_quantile=cfg_normalizer.min_max_quantile, @@ -25,7 +26,7 @@ def identical(x): class DepthNormalizerBase: - is_relative = None + is_absolute = None far_plane_at_max = None def __init__( @@ -46,12 +47,15 @@ def denormalize(self, depth_norm, **kwargs): raise NotImplementedError -class NearFarMetricNormalizer(DepthNormalizerBase): +class ScaleShiftDepthNormalizer(DepthNormalizerBase): """ - depth in [0, d_max] -> [-1, 1] + Use near and far plane to linearly normalize depth, + i.e. d' = d * s + t, + where near plane is mapped to `norm_min`, and far plane is mapped to `norm_max` + Near and far planes are determined by taking quantile values. """ - is_relative = True + is_absolute = False far_plane_at_max = True def __init__( @@ -95,4 +99,5 @@ def scale_back(self, depth_norm): return depth_linear def denormalize(self, depth_norm, **kwargs): + logging.warning(f"{self.__class__} is not revertible without GT") return self.scale_back(depth_norm=depth_norm) diff --git a/src/util/logging_util.py b/src/util/logging_util.py new file mode 100644 index 0000000..37dd103 --- /dev/null +++ b/src/util/logging_util.py @@ -0,0 +1,102 @@ +# Author: Bingxin Ke +# Last modified: 2024-03-12 + +import logging +import os +import sys +import wandb +from tabulate import tabulate +from torch.utils.tensorboard import SummaryWriter + + +def config_logging(cfg_logging, out_dir=None): + file_level = cfg_logging.get("file_level", 10) + console_level = cfg_logging.get("console_level", 10) + + log_formatter = logging.Formatter(cfg_logging["format"]) + + root_logger = logging.getLogger() + root_logger.handlers.clear() + + root_logger.setLevel(min(file_level, console_level)) + + if out_dir is not None: + _logging_file = os.path.join( + out_dir, cfg_logging.get("filename", "logging.log") + ) + file_handler = logging.FileHandler(_logging_file) + file_handler.setFormatter(log_formatter) + file_handler.setLevel(file_level) + root_logger.addHandler(file_handler) + + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(log_formatter) + console_handler.setLevel(console_level) + root_logger.addHandler(console_handler) + + # Avoid pollution by packages + logging.getLogger("PIL").setLevel(logging.INFO) + logging.getLogger("matplotlib").setLevel(logging.INFO) + + +class MyTrainingLogger: + """Tensorboard + wandb logger""" + + writer: SummaryWriter + is_initialized = False + + def __init__(self) -> None: + pass + + def set_dir(self, tb_log_dir): + if self.is_initialized: + raise ValueError("Do not initialize writer twice") + self.writer = SummaryWriter(tb_log_dir) + self.is_initialized = True + + def log_dic(self, scalar_dic, global_step, walltime=None): + for k, v in scalar_dic.items(): + self.writer.add_scalar(k, v, global_step=global_step, walltime=walltime) + return + + +# global instance +tb_logger = MyTrainingLogger() + + +# -------------- wandb tools -------------- +def init_wandb(enable: bool, **kwargs): + if enable: + run = wandb.init(sync_tensorboard=True, **kwargs) + else: + run = wandb.init(mode="disabled") + return run + + +def log_slurm_job_id(step): + global tb_logger + _jobid = os.getenv("SLURM_JOB_ID") + if _jobid is None: + _jobid = -1 + tb_logger.writer.add_scalar("job_id", int(_jobid), global_step=step) + logging.debug(f"Slurm job_id: {_jobid}") + + +def load_wandb_job_id(out_dir): + with open(os.path.join(out_dir, "WANDB_ID"), "r") as f: + wandb_id = f.read() + return wandb_id + + +def save_wandb_job_id(run, out_dir): + with open(os.path.join(out_dir, "WANDB_ID"), "w+") as f: + f.write(run.id) + + +def eval_dic_to_text(val_metrics: dict, dataset_name: str, sample_list_path: str): + eval_text = f"Evaluation metrics:\n\ + on dataset: {dataset_name}\n\ + over samples in: {sample_list_path}\n" + + eval_text += tabulate([val_metrics.keys(), val_metrics.values()]) + return eval_text diff --git a/src/util/loss.py b/src/util/loss.py new file mode 100644 index 0000000..ee6dace --- /dev/null +++ b/src/util/loss.py @@ -0,0 +1,124 @@ +# Author: Bingxin Ke +# Last modified: 2024-02-22 + +import torch + + +def get_loss(loss_name, **kwargs): + if "silog_mse" == loss_name: + criterion = SILogMSELoss(**kwargs) + elif "silog_rmse" == loss_name: + criterion = SILogRMSELoss(**kwargs) + elif "mse_loss" == loss_name: + criterion = torch.nn.MSELoss(**kwargs) + elif "l1_loss" == loss_name: + criterion = torch.nn.L1Loss(**kwargs) + elif "l1_loss_with_mask" == loss_name: + criterion = L1LossWithMask(**kwargs) + elif "mean_abs_rel" == loss_name: + criterion = MeanAbsRelLoss() + else: + raise NotImplementedError + + return criterion + + +class L1LossWithMask: + def __init__(self, batch_reduction=False): + self.batch_reduction = batch_reduction + + def __call__(self, depth_pred, depth_gt, valid_mask=None): + diff = depth_pred - depth_gt + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = depth_gt.shape[-2] * depth_gt.shape[-1] + + loss = torch.sum(torch.abs(diff)) / n + if self.batch_reduction: + loss = loss.mean() + return loss + + +class MeanAbsRelLoss: + def __init__(self) -> None: + # super().__init__() + pass + + def __call__(self, pred, gt): + diff = pred - gt + rel_abs = torch.abs(diff / gt) + loss = torch.mean(rel_abs, dim=0) + return loss + + +class SILogMSELoss: + def __init__(self, lamb, log_pred=True, batch_reduction=True): + """Scale Invariant Log MSE Loss + + Args: + lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> L2 loss + log_pred (bool, optional): True if model prediction is logarithmic depht. Will not do log for depth_pred + """ + super(SILogMSELoss, self).__init__() + self.lamb = lamb + self.pred_in_log = log_pred + self.batch_reduction = batch_reduction + + def __call__(self, depth_pred, depth_gt, valid_mask=None): + log_depth_pred = ( + depth_pred if self.pred_in_log else torch.log(torch.clip(depth_pred, 1e-8)) + ) + log_depth_gt = torch.log(depth_gt) + + diff = log_depth_pred - log_depth_gt + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = depth_gt.shape[-2] * depth_gt.shape[-1] + + diff2 = torch.pow(diff, 2) + + first_term = torch.sum(diff2, (-1, -2)) / n + second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) + loss = first_term - second_term + if self.batch_reduction: + loss = loss.mean() + return loss + + +class SILogRMSELoss: + def __init__(self, lamb, alpha, log_pred=True): + """Scale Invariant Log RMSE Loss + + Args: + lamb (_type_): lambda, lambda=1 -> scale invariant, lambda=0 -> L2 loss + alpha: + log_pred (bool, optional): True if model prediction is logarithmic depht. Will not do log for depth_pred + """ + super(SILogRMSELoss, self).__init__() + self.lamb = lamb + self.alpha = alpha + self.pred_in_log = log_pred + + def __call__(self, depth_pred, depth_gt, valid_mask): + log_depth_pred = depth_pred if self.pred_in_log else torch.log(depth_pred) + log_depth_gt = torch.log(depth_gt) + # borrowed from https://github.com/aliyun/NeWCRFs + # diff = log_depth_pred[valid_mask] - log_depth_gt[valid_mask] + # return torch.sqrt((diff ** 2).mean() - self.lamb * (diff.mean() ** 2)) * self.alpha + + diff = log_depth_pred - log_depth_gt + if valid_mask is not None: + diff[~valid_mask] = 0 + n = valid_mask.sum((-1, -2)) + else: + n = depth_gt.shape[-2] * depth_gt.shape[-1] + + diff2 = torch.pow(diff, 2) + first_term = torch.sum(diff2, (-1, -2)) / n + second_term = self.lamb * torch.pow(torch.sum(diff, (-1, -2)), 2) / (n**2) + loss = torch.sqrt(first_term - second_term).mean() * self.alpha + return loss diff --git a/src/util/lr_scheduler.py b/src/util/lr_scheduler.py new file mode 100644 index 0000000..cd2d67f --- /dev/null +++ b/src/util/lr_scheduler.py @@ -0,0 +1,48 @@ +# Author: Bingxin Ke +# Last modified: 2024-02-22 + +import numpy as np + + +class IterExponential: + def __init__(self, total_iter_length, final_ratio, warmup_steps=0) -> None: + """ + Customized iteration-wise exponential scheduler. + Re-calculate for every step, to reduce error accumulation + + Args: + total_iter_length (int): Expected total iteration number + final_ratio (float): Expected LR ratio at n_iter = total_iter_length + """ + self.total_length = total_iter_length + self.effective_length = total_iter_length - warmup_steps + self.final_ratio = final_ratio + self.warmup_steps = warmup_steps + + def __call__(self, n_iter) -> float: + if n_iter < self.warmup_steps: + alpha = 1.0 * n_iter / self.warmup_steps + elif n_iter >= self.total_length: + alpha = self.final_ratio + else: + actual_iter = n_iter - self.warmup_steps + alpha = np.exp( + actual_iter / self.effective_length * np.log(self.final_ratio) + ) + return alpha + + +if "__main__" == __name__: + lr_scheduler = IterExponential( + total_iter_length=50000, final_ratio=0.01, warmup_steps=200 + ) + lr_scheduler = IterExponential( + total_iter_length=50000, final_ratio=0.01, warmup_steps=0 + ) + + x = np.arange(100000) + alphas = [lr_scheduler(i) for i in x] + import matplotlib.pyplot as plt + + plt.plot(alphas) + plt.savefig("lr_scheduler.png") diff --git a/src/util/multi_res_noise.py b/src/util/multi_res_noise.py new file mode 100644 index 0000000..e4d0ee0 --- /dev/null +++ b/src/util/multi_res_noise.py @@ -0,0 +1,75 @@ +# Author: Bingxin Ke +# Last modified: 2024-04-18 + +import torch +import math + + +# adapted from: https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2?s=31 +def multi_res_noise_like( + x, strength=0.9, downscale_strategy="original", generator=None, device=None +): + if torch.is_tensor(strength): + strength = strength.reshape((-1, 1, 1, 1)) + b, c, w, h = x.shape + + if device is None: + device = x.device + + up_sampler = torch.nn.Upsample(size=(w, h), mode="bilinear") + noise = torch.randn(x.shape, device=x.device, generator=generator) + + if "original" == downscale_strategy: + for i in range(10): + r = ( + torch.rand(1, generator=generator, device=device) * 2 + 2 + ) # Rather than always going 2x, + w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i))) + noise += ( + up_sampler( + torch.randn(b, c, w, h, generator=generator, device=device).to(x) + ) + * strength**i + ) + if w == 1 or h == 1: + break # Lowest resolution is 1x1 + elif "every_layer" == downscale_strategy: + for i in range(int(math.log2(min(w, h)))): + w, h = max(1, int(w / 2)), max(1, int(h / 2)) + noise += ( + up_sampler( + torch.randn(b, c, w, h, generator=generator, device=device).to(x) + ) + * strength**i + ) + elif "power_of_two" == downscale_strategy: + for i in range(10): + r = 2 + w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i))) + noise += ( + up_sampler( + torch.randn(b, c, w, h, generator=generator, device=device).to(x) + ) + * strength**i + ) + if w == 1 or h == 1: + break # Lowest resolution is 1x1 + elif "random_step" == downscale_strategy: + for i in range(10): + r = ( + torch.rand(1, generator=generator, device=device) * 2 + 2 + ) # Rather than always going 2x, + w, h = max(1, int(w / (r))), max(1, int(h / (r))) + noise += ( + up_sampler( + torch.randn(b, c, w, h, generator=generator, device=device).to(x) + ) + * strength**i + ) + if w == 1 or h == 1: + break # Lowest resolution is 1x1 + else: + raise ValueError(f"unknown downscale strategy: {downscale_strategy}") + + noise = noise / noise.std() # Scaled back to roughly unit variance + return noise diff --git a/src/util/seed_all.py b/src/util/seeding.py similarity index 72% rename from src/util/seed_all.py rename to src/util/seeding.py index 9579565..b63a778 100644 --- a/src/util/seed_all.py +++ b/src/util/seeding.py @@ -21,6 +21,7 @@ import numpy as np import random import torch +import logging def seed_all(seed: int = 0): @@ -31,3 +32,23 @@ def seed_all(seed: int = 0): np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) + + +def generate_seed_sequence( + initial_seed: int, + length: int, + min_val=-0x8000_0000_0000_0000, + max_val=0xFFFF_FFFF_FFFF_FFFF, +): + if initial_seed is None: + logging.warning("initial_seed is None, reproducibility is not guaranteed") + random.seed(initial_seed) + + seed_sequence = [] + + for _ in range(length): + seed = random.randint(min_val, max_val) + + seed_sequence.append(seed) + + return seed_sequence diff --git a/src/util/slurm_util.py b/src/util/slurm_util.py new file mode 100644 index 0000000..a983d86 --- /dev/null +++ b/src/util/slurm_util.py @@ -0,0 +1,15 @@ +# Author: Bingxin Ke +# Last modified: 2024-02-22 + +import os + + +def is_on_slurm(): + cluster_name = os.getenv("SLURM_CLUSTER_NAME") + is_on_slurm = cluster_name is not None + return is_on_slurm + + +def get_local_scratch_dir(): + local_scratch_dir = os.getenv("TMPDIR") + return local_scratch_dir diff --git a/train.py b/train.py new file mode 100644 index 0000000..47bd87d --- /dev/null +++ b/train.py @@ -0,0 +1,363 @@ +# An official reimplemented version of Marigold training script +# Last modified: 2024-05-17 +# +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +import argparse +import logging +import os +import shutil +from datetime import datetime, timedelta +from typing import List + +import torch +from omegaconf import OmegaConf +from torch.utils.data import ConcatDataset, DataLoader +from tqdm import tqdm + +from marigold.marigold_pipeline import MarigoldPipeline +from src.dataset import BaseDepthDataset, DatasetMode, get_dataset +from src.dataset.mixed_sampler import MixedBatchSampler +from src.trainer import get_trainer_cls +from src.util.config_util import ( + find_value_in_omegaconf, + recursive_load_config, +) +from src.util.depth_transform import ( + DepthNormalizerBase, + get_depth_normalizer, +) +from src.util.logging_util import ( + config_logging, + init_wandb, + load_wandb_job_id, + log_slurm_job_id, + save_wandb_job_id, + tb_logger, +) +from src.util.slurm_util import get_local_scratch_dir, is_on_slurm + +if "__main__" == __name__: + t_start = datetime.now() + print(f"start at {t_start}") + + # -------------------- Arguments -------------------- + parser = argparse.ArgumentParser(description="Train your cute model!") + parser.add_argument( + "--config", + type=str, + default="config/train_marigold.yaml", + help="Path to config file.", + ) + parser.add_argument( + "--resume_run", + action="store", + default=None, + help="Path of checkpoint to be resumed. If given, will ignore --config, and checkpoint in the config", + ) + parser.add_argument( + "--output_dir", type=str, default=None, help="directory to save checkpoints" + ) + parser.add_argument("--no_cuda", action="store_true", help="Do not use cuda.") + parser.add_argument( + "--exit_after", + type=int, + default=-1, + help="Save checkpoint and exit after X minutes.", + ) + parser.add_argument("--no_wandb", action="store_true", help="run without wandb") + parser.add_argument( + "--do_not_copy_data", + action="store_true", + help="On Slurm cluster, do not copy data to local scratch", + ) + parser.add_argument( + "--base_data_dir", type=str, default=None, help="directory of training data" + ) + parser.add_argument( + "--base_ckpt_dir", + type=str, + default=None, + help="directory of pretrained checkpoint", + ) + parser.add_argument( + "--add_datetime_prefix", + action="store_true", + help="Add datetime to the output folder name", + ) + + args = parser.parse_args() + resume_run = args.resume_run + output_dir = args.output_dir + base_data_dir = ( + args.base_data_dir + if args.base_data_dir is not None + else os.environ["BASE_DATA_DIR"] + ) + base_ckpt_dir = ( + args.base_ckpt_dir + if args.base_ckpt_dir is not None + else os.environ["BASE_CKPT_DIR"] + ) + + # -------------------- Initialization -------------------- + # Resume previous run + if resume_run is not None: + print(f"Resume run: {resume_run}") + out_dir_run = os.path.dirname(os.path.dirname(resume_run)) + job_name = os.path.basename(out_dir_run) + # Resume config file + cfg = OmegaConf.load(os.path.join(out_dir_run, "config.yaml")) + else: + # Run from start + cfg = recursive_load_config(args.config) + # Full job name + pure_job_name = os.path.basename(args.config).split(".")[0] + # Add time prefix + if args.add_datetime_prefix: + job_name = f"{t_start.strftime('%y_%m_%d-%H_%M_%S')}-{pure_job_name}" + else: + job_name = pure_job_name + + # Output dir + if output_dir is not None: + out_dir_run = os.path.join(output_dir, job_name) + else: + out_dir_run = os.path.join("./output", job_name) + os.makedirs(out_dir_run, exist_ok=False) + + cfg_data = cfg.dataset + + # Other directories + out_dir_ckpt = os.path.join(out_dir_run, "checkpoint") + if not os.path.exists(out_dir_ckpt): + os.makedirs(out_dir_ckpt) + out_dir_tb = os.path.join(out_dir_run, "tensorboard") + if not os.path.exists(out_dir_tb): + os.makedirs(out_dir_tb) + out_dir_eval = os.path.join(out_dir_run, "evaluation") + if not os.path.exists(out_dir_eval): + os.makedirs(out_dir_eval) + out_dir_vis = os.path.join(out_dir_run, "visualization") + if not os.path.exists(out_dir_vis): + os.makedirs(out_dir_vis) + + # -------------------- Logging settings -------------------- + config_logging(cfg.logging, out_dir=out_dir_run) + logging.debug(f"config: {cfg}") + + # Initialize wandb + if not args.no_wandb: + if resume_run is not None: + wandb_id = load_wandb_job_id(out_dir_run) + wandb_cfg_dic = { + "id": wandb_id, + "resume": "must", + **cfg.wandb, + } + else: + wandb_cfg_dic = { + "config": dict(cfg), + "name": job_name, + "mode": "online", + **cfg.wandb, + } + wandb_cfg_dic.update({"dir": out_dir_run}) + wandb_run = init_wandb(enable=True, **wandb_cfg_dic) + save_wandb_job_id(wandb_run, out_dir_run) + else: + init_wandb(enable=False) + + # Tensorboard (should be initialized after wandb) + tb_logger.set_dir(out_dir_tb) + + log_slurm_job_id(step=0) + + # -------------------- Device -------------------- + cuda_avail = torch.cuda.is_available() and not args.no_cuda + device = torch.device("cuda" if cuda_avail else "cpu") + logging.info(f"device = {device}") + + # -------------------- Snapshot of code and config -------------------- + if resume_run is None: + _output_path = os.path.join(out_dir_run, "config.yaml") + with open(_output_path, "w+") as f: + OmegaConf.save(config=cfg, f=f) + logging.info(f"Config saved to {_output_path}") + # Copy and tar code on the first run + _temp_code_dir = os.path.join(out_dir_run, "code_tar") + _code_snapshot_path = os.path.join(out_dir_run, "code_snapshot.tar") + os.system( + f"rsync --relative -arhvz --quiet --filter=':- .gitignore' --exclude '.git' . '{_temp_code_dir}'" + ) + os.system(f"tar -cf {_code_snapshot_path} {_temp_code_dir}") + os.system(f"rm -rf {_temp_code_dir}") + logging.info(f"Code snapshot saved to: {_code_snapshot_path}") + + # -------------------- Copy data to local scratch (Slurm) -------------------- + if is_on_slurm() and (not args.do_not_copy_data): + # local scratch dir + original_data_dir = base_data_dir + base_data_dir = os.path.join(get_local_scratch_dir(), "Marigold_data") + # copy data + required_data_list = find_value_in_omegaconf("dir", cfg_data) + # if cfg_train.visualize.init_latent_path is not None: + # required_data_list.append(cfg_train.visualize.init_latent_path) + required_data_list = list(set(required_data_list)) + logging.info(f"Required_data_list: {required_data_list}") + for d in tqdm(required_data_list, desc="Copy data to local scratch"): + ori_dir = os.path.join(original_data_dir, d) + dst_dir = os.path.join(base_data_dir, d) + os.makedirs(os.path.dirname(dst_dir), exist_ok=True) + if os.path.isfile(ori_dir): + shutil.copyfile(ori_dir, dst_dir) + elif os.path.isdir(ori_dir): + shutil.copytree(ori_dir, dst_dir) + logging.info(f"Data copied to: {base_data_dir}") + + # -------------------- Gradient accumulation steps -------------------- + eff_bs = cfg.dataloader.effective_batch_size + accumulation_steps = eff_bs / cfg.dataloader.max_train_batch_size + assert int(accumulation_steps) == accumulation_steps + accumulation_steps = int(accumulation_steps) + + logging.info( + f"Effective batch size: {eff_bs}, accumulation steps: {accumulation_steps}" + ) + + # -------------------- Data -------------------- + loader_seed = cfg.dataloader.seed + if loader_seed is None: + loader_generator = None + else: + loader_generator = torch.Generator().manual_seed(loader_seed) + + # Training dataset + depth_transform: DepthNormalizerBase = get_depth_normalizer( + cfg_normalizer=cfg.depth_normalization + ) + train_dataset: BaseDepthDataset = get_dataset( + cfg_data.train, + base_data_dir=base_data_dir, + mode=DatasetMode.TRAIN, + augmentation_args=cfg.augmentation, + depth_transform=depth_transform, + ) + logging.debug("Augmentation: ", cfg.augmentation) + if "mixed" == cfg_data.train.name: + dataset_ls = train_dataset + assert len(cfg_data.train.prob_ls) == len( + dataset_ls + ), "Lengths don't match: `prob_ls` and `dataset_list`" + concat_dataset = ConcatDataset(dataset_ls) + mixed_sampler = MixedBatchSampler( + src_dataset_ls=dataset_ls, + batch_size=cfg.dataloader.max_train_batch_size, + drop_last=True, + prob=cfg_data.train.prob_ls, + shuffle=True, + generator=loader_generator, + ) + train_loader = DataLoader( + concat_dataset, + batch_sampler=mixed_sampler, + num_workers=cfg.dataloader.num_workers, + ) + else: + train_loader = DataLoader( + dataset=train_dataset, + batch_size=cfg.dataloader.max_train_batch_size, + num_workers=cfg.dataloader.num_workers, + shuffle=True, + generator=loader_generator, + ) + # Validation dataset + val_loaders: List[DataLoader] = [] + for _val_dic in cfg_data.val: + _val_dataset = get_dataset( + _val_dic, + base_data_dir=base_data_dir, + mode=DatasetMode.EVAL, + ) + _val_loader = DataLoader( + dataset=_val_dataset, + batch_size=1, + shuffle=False, + num_workers=cfg.dataloader.num_workers, + ) + val_loaders.append(_val_loader) + + # Visualization dataset + vis_loaders: List[DataLoader] = [] + for _vis_dic in cfg_data.vis: + _vis_dataset = get_dataset( + _vis_dic, + base_data_dir=base_data_dir, + mode=DatasetMode.EVAL, + ) + _vis_loader = DataLoader( + dataset=_vis_dataset, + batch_size=1, + shuffle=False, + num_workers=cfg.dataloader.num_workers, + ) + vis_loaders.append(_vis_loader) + + # -------------------- Model -------------------- + _pipeline_kwargs = cfg.pipeline.kwargs if cfg.pipeline.kwargs is not None else {} + model = MarigoldPipeline.from_pretrained( + os.path.join(base_ckpt_dir, cfg.model.pretrained_path), **_pipeline_kwargs + ) + + # -------------------- Trainer -------------------- + # Exit time + if args.exit_after > 0: + t_end = t_start + timedelta(minutes=args.exit_after) + logging.info(f"Will exit at {t_end}") + else: + t_end = None + + trainer_cls = get_trainer_cls(cfg.trainer.name) + logging.debug(f"Trainer: {trainer_cls}") + trainer = trainer_cls( + cfg=cfg, + model=model, + train_dataloader=train_loader, + device=device, + base_ckpt_dir=base_ckpt_dir, + out_dir_ckpt=out_dir_ckpt, + out_dir_eval=out_dir_eval, + out_dir_vis=out_dir_vis, + accumulation_steps=accumulation_steps, + val_dataloaders=val_loaders, + vis_dataloaders=vis_loaders, + ) + + # -------------------- Checkpoint -------------------- + if resume_run is not None: + trainer.load_checkpoint( + resume_run, load_trainer_state=True, resume_lr_scheduler=True + ) + + # -------------------- Training & Evaluation Loop -------------------- + try: + trainer.train(t_end=t_end) + except Exception as e: + logging.exception(e)