Support SDXL. Update sd_scripts to v0.8.4
liasece committed Mar 15, 2024
1 parent 2f7c1d4 commit a71f6e9
Showing 8 changed files with 74 additions and 36 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -100,6 +100,14 @@ English (TODO) [google translate](https://github-com.translate.goog/liasece/sd-w

The training process and its progress can be seen in the command line where stable-diffusion-webui is running.

### Using SD2 or SDXL as the base model

Make sure to check `Base on Stable Diffusion V2` or `Base on Stable Diffusion XL` correctly; otherwise training will fail.

### Appending or overriding advanced sd_script arguments

Enter arguments in the `Append or override the sd_script args.` text box. Every argument must start with `--`, for example: `--lr_scheduler="constant_with_warmup" --max_grad_norm=0.0`. The plugin code uses `--` as the separator between arguments, as the sketch below illustrates.
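As a rough illustration of what that splitting implies (this mirrors the `ext_sd_script_args` decode step shown in `ArgsList.py` further down; the standalone function here is illustrative, not part of the plugin):

```python
def split_ext_args(text: str) -> list[str]:
    """Split the text-box string on the `--` delimiter, as the plugin does."""
    args = []
    for chunk in text.split("--"):
        if not chunk:
            continue  # skip the empty piece before the leading `--`
        args.append(f"--{chunk.strip()}")
    return args

# split_ext_args('--lr_scheduler="constant_with_warmup" --max_grad_norm=0.0')
# -> ['--lr_scheduler="constant_with_warmup"', '--max_grad_norm=0.0']
```

One consequence of splitting on `--`: a value that itself contains `--` cannot be passed through this text box.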

### Train multiple parameter combinations in one run, and make full use of your GPU while you sleep

Sometimes a single training configuration is not the best one, and waiting for one run to finish before starting the next is inefficient. You can therefore configure several values for multiple parameters at once; with a single click on Train, the different values are combined automatically and each combination is trained, as sketched below.
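A minimal sketch of what that combination presumably amounts to, given the multi-select `Learning rate` and `Optimizer type` fields in the UI (a Cartesian product; the loop below is hypothetical, not the plugin's actual code):

```python
import itertools

# Hypothetical: one training run per combination of multi-select values.
learning_rates = ["0.0001", "0.0002"]    # from "Learning rate (multi-select)"
optimizer_types = ["Lion", "AdamW8bit"]  # from "Optimizer type (multi-select)"

for lr, opt in itertools.product(learning_rates, optimizer_types):
    print(f"training run: learning_rate={lr}, optimizer_type={opt}")
```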
69 changes: 40 additions & 29 deletions liasece_sd_webui_train_tools/ArgsList.py
@@ -13,6 +13,7 @@
import gc
import torch

from liasece_sd_webui_train_tools.util import *
import liasece_sd_webui_train_tools.PythonContextWarper as pc
with pc.PythonContextWarper(
to_module_path= os.path.abspath(os.path.join(os.path.dirname(__file__), "sd_scripts")),
@@ -30,7 +31,6 @@ def __init__(self):
self.save_json_folder: Union[str, None] = None
self.save_json_name: Union[str, None] = None
self.load_json_path: Union[str, None] = None
self.multi_run_folder: Union[str, None] = None
self.reg_img_folder: Union[str, None] = None
self.sample_prompts: Union[str, None] = None # path to a txt file that has all of the sample prompts in it,
# one per line. Only goes to 75 tokens, will cut off the rest. Just place the prompts into the txt file per line
@@ -188,6 +188,11 @@ def __init__(self):
self.locon_dim: Union[int, None] = None # deprecated
self.locon_alpha: Union[int, None] = None # deprecated
self.locon: bool = False # deprecated
self.use_sdxl: bool = False # use the sdxl trainer
self.no_half_vae: bool = False # Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
self.cache_text_encoder_outputs: bool = False # Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
self.cache_text_encoder_outputs_to_disk: bool = False # Cache the outputs of the text encoders to disk. This option is useful to further reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
self.ext_sd_script_args: str = "" # Append or override the sd_script args. (e.g. `--lr_scheduler="constant_with_warmup" --max_grad_norm=0.0`)

# Creates the dict that is used for the rest of the code, to facilitate easier json saving and loading
def convert_args_to_dict(self):
@@ -196,32 +201,11 @@ def convert_args_to_dict(self):
def create_args(self) -> argparse.Namespace:
parser = Parser()
args = self.convert_args_to_dict()
multi_path = args['multi_run_folder']
if multi_path and ensure_path(multi_path, "multi_run_folder"):
for file in os.listdir(multi_path):
if os.path.isdir(file) or file.split(".")[-1] != "json":
continue
args = self.convert_args_to_dict()
args['json_load_skip_list'] = None
try:
ensure_file_paths(args)
except FileNotFoundError:
print("failed to find one or more folders or paths, skipping.")
continue
if args['tag_occurrence_txt_file']:
get_occurrence_of_tags(args)
args = parser.create_args(self.change_dict_to_internal_names(args))
train_network.train(args)
gc.collect()
torch.cuda.empty_cache()
if not os.path.exists(os.path.join(multi_path, "complete")):
os.makedirs(os.path.join(multi_path, "complete"))
os.rename(os.path.join(multi_path, file), os.path.join(multi_path, "complete", file))
print("completed all training")
quit()
ensure_file_paths(args)
if args['tag_occurrence_txt_file']:
get_occurrence_of_tags(args)
if self.use_sdxl:
self.no_half_vae = True
args = parser.create_args(self.change_dict_to_internal_names(args))
return args

@@ -283,7 +267,7 @@ def find_max_steps(args: dict) -> int:

def ensure_file_paths(args: dict) -> None:
failed_to_find = False
folders_to_check = ['img_folder', 'output_folder', 'save_json_folder', 'multi_run_folder',
folders_to_check = ['img_folder', 'output_folder', 'save_json_folder',
'reg_img_folder', 'log_dir']
for folder in folders_to_check:
if folder in args and args[folder] is not None:
@@ -360,22 +344,49 @@ def ensure_path(path, name, ext_list=None) -> bool:

class Parser:
def __init__(self) -> None:
self.parser = train_network.setup_parser()
parser = train_network.setup_parser()
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)
parser.add_argument(
"--cache_text_encoder_outputs_to_disk",
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
self.parser = parser

def create_args(self, args: dict) -> argparse.Namespace:
remove_epochs = False
args_list = []
skip_list = ["save_json_folder", "load_json_path", "multi_run_folder", "json_load_skip_list",
args_list: list[str] = []
skip_list = ["save_json_folder", "load_json_path", "json_load_skip_list",
"tag_occurrence_txt_file", "sort_tag_occurrence_alphabetically", "save_json_only",
"warmup_lr_ratio", "optimizer_args", "locon_dim", "locon_alpha", "locon", "lyco", "network_args",
"resolution", "height_resolution"]
"resolution", "height_resolution", "use_sdxl", "ext_sd_script_args"]

# decode ext_sd_script_args
if "ext_sd_script_args" in args and args["ext_sd_script_args"]:
ext_sd_script_args = args["ext_sd_script_args"].split("--")
for arg in ext_sd_script_args:
if not arg:
continue
args_list.append(f"--{arg}")

for key, value in args.items():
if not value:
continue
if key in skip_list:
continue
if key == "max_train_steps":
remove_epochs = True
# check key is in the parser
already_exists = False
for arg in args_list:
if arg.startswith(f"--{key}"):
already_exists = True
break
if already_exists:
printD(f"Skipping {key} as it already exists")
continue
if isinstance(value, bool):
args_list.append(f"--{key}")
else:
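Taken together, the decode step and the duplicate check above mean a key supplied in the text box takes precedence over the plugin's own value for that key. A hedged distillation of that behavior (`merge_args` is a standalone rework for illustration, not a real function in the plugin, and the exact non-boolean formatting is cut off in the diff above):

```python
def merge_args(ext_args: list[str], plugin_args: dict) -> list[str]:
    """Text-box args go first; a plugin key already covered by them is skipped."""
    args_list = list(ext_args)
    for key, value in plugin_args.items():
        if not value:
            continue
        if any(arg.startswith(f"--{key}") for arg in args_list):
            continue  # the text-box value wins for this key
        # booleans become bare flags; other values assumed to be --key=value
        args_list.append(f"--{key}" if isinstance(value, bool) else f"--{key}={value}")
    return args_list

# merge_args(['--lr_scheduler="constant_with_warmup"'],
#            {"lr_scheduler": "cosine", "xformers": True})
# -> ['--lr_scheduler="constant_with_warmup"', '--xformers']
```

Because the check uses `startswith`, a text-box argument also shadows any plugin key that is a prefix of it (e.g. `--lr_scheduler=...` would incidentally shadow a key named `lr`).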
2 changes: 2 additions & 0 deletions liasece_sd_webui_train_tools/config_file.py
@@ -35,12 +35,14 @@
{"train_finish_generate_all_checkpoint_preview": True},
{"train_optimizer_type": ["Lion"]},
{"train_learning_rate": "0.0001"},
{"sd_script_args": ""},
{"train_net_dim": 128},
{"train_alpha": 64},
{"train_clip_skip": 2},
{"train_mixed_precision": "fp16"},
{"train_xformers": True},
{"train_base_on_sd_v2": False},
{"use_sdxl": False},
# preview
{"preview_include_sub_img": False},
{"preview_txt2img_prompt": "best quality,Amazing,finely detail,extremely detailed CG unity 8k wallpaper"},
2 changes: 1 addition & 1 deletion liasece_sd_webui_train_tools/sd_scripts
Submodule sd_scripts updated 86 files
+7 −0 .github/dependabot.yml
+2 −2 .github/workflows/typos.yml
+35 −48 README-ja.md
+236 −231 README.md
+123 −128 XTI_hijack.py
+19 −1 _typos.toml
+ bitsandbytes_windows/libbitsandbytes_cuda118.dll
+166 −166 bitsandbytes_windows/main.py
+4 −0 docs/config_README-ja.md
+33 −0 docs/gen_img_README-ja.md
+4 −0 docs/train_README-ja.md
+214 −0 docs/train_lllite_README-ja.md
+217 −0 docs/train_lllite_README.md
+13 −8 docs/train_network_README-ja.md
+99 −79 fine_tune.py
+5 −1 finetune/blip/blip.py
+23 −19 finetune/clean_captions_and_tags.py
+21 −11 finetune/make_captions.py
+22 −11 finetune/make_captions_by_git.py
+12 −8 finetune/merge_captions_to_metadata.py
+12 −8 finetune/merge_dd_tags_to_metadata.py
+37 −95 finetune/prepare_buckets_latents.py
+109 −26 finetune/tag_images_by_wd14_tagger.py
+3,326 −0 gen_img.py
+676 −434 gen_img_diffusers.py
+227 −0 library/attention_processors.py
+572 −420 library/config_util.py
+80 −6 library/custom_train_functions.py
+84 −0 library/device_utils.py
+11 −8 library/huggingface_util.py
+223 −0 library/hypernetwork.py
+179 −0 library/ipex/__init__.py
+177 −0 library/ipex/attention.py
+312 −0 library/ipex/diffusers.py
+183 −0 library/ipex/gradscaler.py
+298 −0 library/ipex/hijacks.py
+97 −43 library/lpw_stable_diffusion.py
+234 −54 library/model_util.py
+1,919 −0 library/original_unet.py
+309 −0 library/sai_model_spec.py
+1,347 −0 library/sdxl_lpw_stable_diffusion.py
+577 −0 library/sdxl_model_util.py
+1,284 −0 library/sdxl_original_unet.py
+373 −0 library/sdxl_train_util.py
+23 −20 library/slicing_vae.py
+2,195 −882 library/train_util.py
+261 −1 library/utils.py
+31 −22 networks/check_lora_weights.py
+449 −0 networks/control_net_lllite.py
+505 −0 networks/control_net_lllite_for_train.py
+44 −16 networks/dylora.py
+10 −7 networks/extract_lora_from_dylora.py
+336 −165 networks/extract_lora_from_models.py
+158 −77 networks/lora.py
+616 −0 networks/lora_diffusers.py
+1,244 −0 networks/lora_fa.py
+18 −11 networks/lora_interrogator.py
+147 −30 networks/merge_lora.py
+16 −11 networks/merge_lora_old.py
+433 −0 networks/oft.py
+278 −226 networks/resize_lora.py
+351 −0 networks/sdxl_merge_lora.py
+238 −168 networks/svd_merge_lora.py
+24 −15 requirements.txt
+3,210 −0 sdxl_gen_img.py
+329 −0 sdxl_minimal_inference.py
+792 −0 sdxl_train.py
+616 −0 sdxl_train_control_net_lllite.py
+584 −0 sdxl_train_control_net_lllite_old.py
+184 −0 sdxl_train_network.py
+138 −0 sdxl_train_textual_inversion.py
+197 −0 tools/cache_latents.py
+194 −0 tools/cache_text_encoder_outputs.py
+5 −1 tools/canny.py
+42 −12 tools/convert_diffusers20_original_sd.py
+13 −9 tools/detect_face_rotate.py
+16 −10 tools/latent_upscaler.py
+171 −0 tools/merge_models.py
+289 −257 tools/original_control_net.py
+6 −3 tools/resize_images_to_resolution.py
+23 −0 tools/show_metadata.py
+620 −0 train_controlnet.py
+99 −81 train_db.py
+879 −694 train_network.py
+626 −454 train_textual_inversion.py
+101 −74 train_textual_inversion_XTI.py
9 changes: 8 additions & 1 deletion liasece_sd_webui_train_tools/train.py
@@ -6,9 +6,11 @@
import pkg_resources

from liasece_sd_webui_train_tools.ArgsList import ArgStore
from liasece_sd_webui_train_tools.util import *
from modules import script_loading

import liasece_sd_webui_train_tools.sd_scripts.train_network as train_network
import liasece_sd_webui_train_tools.sd_scripts.sdxl_train_network as sdxl_train_network

import liasece_sd_webui_train_tools.PythonContextWarper as pc
import liasece_sd_webui_train_tools.util as util
@@ -34,4 +36,9 @@ def train(cfg: ArgStore) -> None:
sub_module=["library", "networks"],
):
# begin training
train_network.train(args)
if cfg.use_sdxl:
trainer = sdxl_train_network.SdxlNetworkTrainer()
else:
trainer = train_network.NetworkTrainer()
printD("train begin", args)
trainer.train(args)
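For context, the end-to-end call after this change presumably looks something like the sketch below, using the fields this commit adds to `ArgStore` (folder setup is elided; values are placeholders):

```python
from liasece_sd_webui_train_tools.ArgsList import ArgStore
from liasece_sd_webui_train_tools.train import train

cfg = ArgStore()
cfg.use_sdxl = True   # dispatch to sdxl_train_network.SdxlNetworkTrainer
cfg.ext_sd_script_args = '--lr_scheduler="constant_with_warmup" --max_grad_norm=0.0'
# ... img_folder, output_folder, etc. set as usual ...
train(cfg)            # create_args() forces no_half_vae=True when use_sdxl is set
```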
6 changes: 6 additions & 0 deletions liasece_sd_webui_train_tools/train_ui.py
@@ -31,12 +31,14 @@ def on_train_begin_click(id: str, project: str, version: str,
train_finish_generate_all_checkpoint_preview: bool,
train_optimizer_type: list[str],
train_learning_rate: str,
sd_script_args: str,
train_net_dim: int,
train_alpha: int,
train_clip_skip: int,
train_mixed_precision: str,
train_xformers: bool,
train_base_on_sd_v2: bool,
use_sdxl: bool, # use sdxl
# preview view config
preview_include_sub_img: bool,
# txt2txt
@@ -61,12 +63,14 @@ def on_train_begin_click(id: str, project: str, version: str,
"train_finish_generate_all_checkpoint_preview": train_finish_generate_all_checkpoint_preview,
"train_optimizer_type": train_optimizer_type,
"train_learning_rate": train_learning_rate,
"sd_script_args": sd_script_args,
"train_net_dim": int(train_net_dim),
"train_alpha": int(train_alpha),
"train_clip_skip": int(train_clip_skip),
"train_mixed_precision": train_mixed_precision,
"train_xformers": train_xformers,
"train_base_on_sd_v2": train_base_on_sd_v2,
"use_sdxl": use_sdxl,
})
save_preview_config(project, version, {
# preview view config
@@ -119,6 +123,8 @@ def on_train_begin_click(id: str, project: str, version: str,
cfg.mixed_precision = train_mixed_precision
cfg.xformers = train_xformers
cfg.v2 = train_base_on_sd_v2
cfg.use_sdxl = use_sdxl
cfg.ext_sd_script_args = sd_script_args
# check if reg path exist
if os.path.exists(os.path.join(processed_path, "..", "reg")):
cfg.reg_img_folder = os.path.abspath(os.path.join(processed_path, "..", "reg"))
8 changes: 6 additions & 2 deletions liasece_sd_webui_train_tools/ui.py
@@ -130,6 +130,7 @@ def new_ui():
train_base_model_refresh_button = ui.ToolButton(value=ui.refresh_symbol, elem_id="train_base_model_refresh_button")
with gr.Row():
train_base_on_sd_v2 = gr.Checkbox(label="Base on Stable Diffusion V2", value=False, elem_id="train_base_on_sd_v2", interactive = True)
use_sdxl = gr.Checkbox(label="Base on Stable Diffusion XL", value=False, elem_id="use_sdxl", interactive = True)
with gr.Row():
train_xformers = gr.Checkbox(label="Use xformers", value=True, elem_id="train_xformers", interactive = True)
with gr.Row():
@@ -139,11 +140,12 @@
with gr.Column():
train_batch_size = gr.Number(value=1, label="Batch size", elem_id="train_batch_size", interactive = True)
train_num_epochs = gr.Number(value=40, label="Number of epochs", elem_id="train_num_epochs", interactive = True)
train_learning_rate = gr.Textbox(value="0.0001", label="Learning rate", elem_id="train_learning_rate", interactive = True)
train_learning_rate = gr.Textbox(value="0.0001", label="Learning rate (multi-select, e.g. 0.0001,0.0002)", elem_id="train_learning_rate", interactive = True)
sd_script_args = gr.Textbox(value="", label="Append or override the sd_script args. (e.g. `--lr_scheduler=\"constant_with_warmup\" --max_grad_norm=0.0`)", elem_id="sd_script_args", interactive = True)
with gr.Column():
train_net_dim = gr.Number(value=128, label="Net dim (128 ~ 144MB)", elem_id="train_net_dim", interactive = True)
train_alpha = gr.Number(value=64, label="Alpha (default is half of Net dim)", elem_id="train_alpha", interactive = True)
train_optimizer_type = gr.Dropdown(label="Optimizer type",value=["Lion"], choices=["Adam", "AdamW", "AdamW8bit", "Lion", "SGDNesterov", "SGDNesterov8bit", "DAdaptation", "AdaFactor"], multiselect = True, interactive = True, elem_id="train_optimizer_type")
train_optimizer_type = gr.Dropdown(label="Optimizer type (multi-select)", value=["Lion"], choices=["Adam", "AdamW", "AdamW8bit", "Lion", "SGDNesterov", "SGDNesterov8bit", "DAdaptation", "AdaFactor"], multiselect = True, interactive = True, elem_id="train_optimizer_type")
train_mixed_precision = gr.Dropdown(label="Mixed precision (If your graphics card supports bf16 better)",value="fp16", choices=["fp16", "bf16"], interactive = True, elem_id="train_mixed_precision")
with gr.Row():
with gr.Column(scale=2):
@@ -248,12 +250,14 @@ def train_config_inputs():
train_finish_generate_all_checkpoint_preview,
train_optimizer_type,
train_learning_rate,
sd_script_args,
train_net_dim,
train_alpha,
train_clip_skip,
train_mixed_precision,
train_xformers,
train_base_on_sd_v2,
use_sdxl,
]
def preview_config_inputs():
return [
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,10 +1,10 @@
-accelerate==0.15.0
-transformers
+accelerate==0.25.0
+diffusers[torch]==0.25.0
+transformers==4.36.2
 ftfy
 albumentations
 opencv-python
 einops
-diffusers[torch]==0.10.2
 pytorch-lightning
 bitsandbytes
 tensorboard
