-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 9a2a8d6
Showing
47 changed files
with
7,673 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
wandb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
__pycache__/ | ||
.cache/ | ||
wandb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# NVIDIA CORPORATION and its licensors retain all intellectual property | ||
# and proprietary rights in and to this software, related documentation | ||
# and any modifications thereto. Any use, reproduction, disclosure or | ||
# distribution of this software and related documentation without an express | ||
# license agreement from NVIDIA CORPORATION is strictly prohibited. | ||
|
||
FROM nvcr.io/nvidia/pytorch:20.12-py3 as base | ||
|
||
ENV PYTHONDONTWRITEBYTECODE 1 | ||
ENV PYTHONUNBUFFERED 1 | ||
|
||
RUN pip install imageio-ffmpeg==0.4.3 pyspng==0.1.0 wandb requests | ||
|
||
WORKDIR /workspace | ||
|
||
# Unset TORCH_CUDA_ARCH_LIST and exec. This makes pytorch run-time | ||
# extension builds significantly faster as we only compile for the | ||
# currently active GPU configuration. | ||
RUN (printf '#!/bin/bash\nunset TORCH_CUDA_ARCH_LIST\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh | ||
ENTRYPOINT ["/entry.sh"] | ||
|
||
FROM base | ||
COPY . /workspace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
Copyright (c) 2021, NVIDIA Corporation. All rights reserved. | ||
|
||
|
||
NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) | ||
|
||
|
||
======================================================================= | ||
|
||
1. Definitions | ||
|
||
"Licensor" means any person or entity that distributes its Work. | ||
|
||
"Software" means the original work of authorship made available under | ||
this License. | ||
|
||
"Work" means the Software and any additions to or derivative works of | ||
the Software that are made available under this License. | ||
|
||
The terms "reproduce," "reproduction," "derivative works," and | ||
"distribution" have the meaning as provided under U.S. copyright law; | ||
provided, however, that for the purposes of this License, derivative | ||
works shall not include works that remain separable from, or merely | ||
link (or bind by name) to the interfaces of, the Work. | ||
|
||
Works, including the Software, are "made available" under this License | ||
by including in or with the Work either (a) a copyright notice | ||
referencing the applicability of this License to the Work, or (b) a | ||
copy of this License. | ||
|
||
2. License Grants | ||
|
||
2.1 Copyright Grant. Subject to the terms and conditions of this | ||
License, each Licensor grants to you a perpetual, worldwide, | ||
non-exclusive, royalty-free, copyright license to reproduce, | ||
prepare derivative works of, publicly display, publicly perform, | ||
sublicense and distribute its Work and any resulting derivative | ||
works in any form. | ||
|
||
3. Limitations | ||
|
||
3.1 Redistribution. You may reproduce or distribute the Work only | ||
if (a) you do so under this License, (b) you include a complete | ||
copy of this License with your distribution, and (c) you retain | ||
without modification any copyright, patent, trademark, or | ||
attribution notices that are present in the Work. | ||
|
||
3.2 Derivative Works. You may specify that additional or different | ||
terms apply to the use, reproduction, and distribution of your | ||
derivative works of the Work ("Your Terms") only if (a) Your Terms | ||
provide that the use limitation in Section 3.3 applies to your | ||
derivative works, and (b) you identify the specific derivative | ||
works that are subject to Your Terms. Notwithstanding Your Terms, | ||
this License (including the redistribution requirements in Section | ||
3.1) will continue to apply to the Work itself. | ||
|
||
3.3 Use Limitation. The Work and any derivative works thereof only | ||
may be used or intended for use non-commercially. Notwithstanding | ||
the foregoing, NVIDIA and its affiliates may use the Work and any | ||
derivative works commercially. As used herein, "non-commercially" | ||
means for research or evaluation purposes only. | ||
|
||
3.4 Patent Claims. If you bring or threaten to bring a patent claim | ||
against any Licensor (including any claim, cross-claim or | ||
counterclaim in a lawsuit) to enforce any patents that you allege | ||
are infringed by any Work, then your rights under this License from | ||
such Licensor (including the grant in Section 2.1) will terminate | ||
immediately. | ||
|
||
3.5 Trademarks. This License does not grant any rights to use any | ||
Licensor’s or its affiliates’ names, logos, or trademarks, except | ||
as necessary to reproduce the notices described in this License. | ||
|
||
3.6 Termination. If you violate any term of this License, then your | ||
rights under this License (including the grant in Section 2.1) will | ||
terminate immediately. | ||
|
||
4. Disclaimer of Warranty. | ||
|
||
THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF | ||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR | ||
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER | ||
THIS LICENSE. | ||
|
||
5. Limitation of Liability. | ||
|
||
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL | ||
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE | ||
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, | ||
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF | ||
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK | ||
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, | ||
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER | ||
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF | ||
THE POSSIBILITY OF SUCH DAMAGES. | ||
|
||
======================================================================= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
## StyleGAN2-UNet | ||
A style-based GAN with UNet-guided synthesis. | ||
|
||
 | ||
|
||
https://user-images.githubusercontent.com/20824840/198903666-7a130b1d-4bc3-49d5-8f41-e964515b2adb.mp4 | ||
|
||
## Training new networks | ||
### Data preparation | ||
You need to put your training images into one folder and your segmentation masks into another folder. The names of the images and masks must be paired together in a lexicographical order. | ||
|
||
### Training | ||
To train a network (or resume training), you must specify the path to the segmentation masks through the `seg_data` option and additionally provide the RGB colors for each class through the `seg_colors` option encoded in JSON format. | ||
|
||
```bash | ||
python3 train.py \ | ||
--resume=$HOME/training-runs/00000-train-auto1-resumecustom/network-snapshot-001000.pkl \ | ||
--outdir=$HOME/training-runs \ | ||
--data=$HOME/datasets/crops2-512-filtered/train \ | ||
--seg_data=$HOME/datasets/segmented-512-filtered/train \ | ||
--seg_colors="[[0,0,255],[0,128,0],[0,191,191],[191,0,191],[255,0,0],[255,255,255]]" \ | ||
--image_snapshot_ticks=1 \ | ||
--wandb_project=sgunet \ | ||
--snap=50 \ | ||
--batch=16 \ | ||
--batch_gpu=4 \ | ||
--gpus=1 | ||
``` | ||
|
||
In this example, the results are saved to a newly created directory `~/training-runs/<ID>-mydataset-auto1`, controlled by `--outdir`. The training exports network pickles (`network-snapshot-<INT>.pkl`) and example images (`fakes<INT>.png`) at regular intervals (controlled by `--snap`). For each pickle, it also evaluates FID (controlled by `--metrics`) and logs the resulting scores in `metric-fid50k_full.jsonl` (as well as TFEvents if TensorBoard is installed). | ||
|
||
The name of the output directory reflects the training configuration. For example, `00000-mydataset-auto1` indicates that the *base configuration* was `auto1`, meaning that the hyperparameters were selected automatically for training on one GPU. The base configuration is controlled by `--cfg`: | ||
|
||
| Base config | Description | ||
| :-------------------- | :---------- | ||
| `auto` (default) | Automatically select reasonable defaults based on resolution and GPU count. Serves as a good starting point for new datasets but does not necessarily lead to optimal results. | ||
| `stylegan2` | Reproduce results for StyleGAN2 config F at 1024x1024 using 1, 2, 4, or 8 GPUs. | ||
| `paper256` | Reproduce results for FFHQ and LSUN Cat at 256x256 using 1, 2, 4, or 8 GPUs. | ||
| `paper512` | Reproduce results for BreCaHAD and AFHQ at 512x512 using 1, 2, 4, or 8 GPUs. | ||
| `paper1024` | Reproduce results for MetFaces at 1024x1024 using 1, 2, 4, or 8 GPUs. | ||
| `cifar` | Reproduce results for CIFAR-10 (tuned configuration) using 1 or 2 GPUs. | ||
|
||
The training configuration can be further customized with additional command line options: | ||
|
||
* `--aug=noaug` disables ADA. | ||
* `--cond=1` enables class-conditional training (requires a dataset with labels). | ||
* `--mirror=1` amplifies the dataset with x-flips. Often beneficial, even with ADA. | ||
* `--resume=ffhq1024 --snap=10` performs transfer learning from FFHQ trained at 1024x1024. | ||
* `--resume=~/training-runs/<NAME>/network-snapshot-<INT>.pkl` resumes a previous training run. | ||
* `--gamma=10` overrides R1 gamma. We recommend trying a couple of different values for each new dataset. | ||
* `--aug=ada --target=0.7` adjusts ADA target value (default: 0.6). | ||
* `--augpipe=blit` enables pixel blitting but disables all other augmentations. | ||
* `--augpipe=bgcfnc` enables all available augmentations (blit, geom, color, filter, noise, cutout). | ||
|
||
Please refer to [`python train.py --help`](./docs/train-help.txt) for the full list. | ||
|
||
## Inference Server | ||
You can spin up an API server that can be used for inference. Refer to [srv.py](./srv.py) for more details. | ||
|
||
Example usage: | ||
```bash | ||
export SG_SEG_DATASET_PATH="$HOME/datasets/segmented-512-filtered/train" | ||
export SG_REAL_DATASET_PATH="$HOME/datasets/crops2-512-filtered/train" | ||
export SG_MODEL_PATH="$HOME/training-runs/00000-train-auto1-resumecustom/network-snapshot-001000.pkl" | ||
export SG_SEG_COLORS_JSON="[[0,0,255],[0,128,0],[0,191,191],[191,0,191],[255,0,0],[255,255,255]]" | ||
FLASK_APP=srv.py python -m flask run --host=0.0.0.0 | ||
``` | ||
|
||
## Acknowledgements | ||
|
||
This work was supported by [EIDOSLab](https://github.com/EIDOSlab). The code is heavily based on [StyleGAN2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# NVIDIA CORPORATION and its licensors retain all intellectual property | ||
# and proprietary rights in and to this software, related documentation | ||
# and any modifications thereto. Any use, reproduction, disclosure or | ||
# distribution of this software and related documentation without an express | ||
# license agreement from NVIDIA CORPORATION is strictly prohibited. | ||
|
||
from .util import EasyDict, make_cache_dir_path |
Oops, something went wrong.