diff --git a/README.md b/README.md
index e701d0fbf..d16effd9b 100644
--- a/README.md
+++ b/README.md
@@ -279,6 +279,50 @@ python -m training.main \
     --name "10_unfrozen" \
     --report-to "tensorboard" \
 ```
+### Training with pre-trained image tower and custom text tower:
+Here is an example that initializes the image tower with a ViT-B-32 pretrained by LAION while using `bert-base-uncased` as the text tower.
+```bash
+#!/bin/bash
+python \
+    -m training.main \
+    --pretrained laion2b_s34b_b79k \
+    --pretrained-image \
+    --pretrained-cache-dir ./laion-pretrained-models \
+    --model bert-base-uncased-laion-ViT-B-32 \
+    --lock-image \
+    --lock-image-freeze-bn-stats \
+    --lock-text \
+    --train-data="pipe:aws s3 cp s3://s-mas/cc3m/{00000..00329}.tar -" \
+    --train-num-samples 3000000 \
+    --val-data="pipe:aws s3 cp s3://s-mas/cc3m/{00330..00331}.tar -" \
+    --val-num-samples 10000 \
+    --dataset-type webdataset \
+    --batch-size 256 \
+    --warmup 2000 \
+    --epochs 10 \
+    --lr 5e-4 \
+    --precision amp \
+    --workers 6 \
+    --gather-with-grad \
+    --local-loss
+```
+A few of these arguments deserve clarification. Initializing the image tower from supported pretrained weights is triggered by setting `--pretrained` and `--pretrained-image` together. The `--model` argument should point to a config JSON file whose `vision_cfg` uses the following syntax:
+```json
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "model_name": "ViT-B-32",
+        "pretrained": "laion2b_s34b_b79k"
+    },
+    "text_cfg": {
+        "hf_model_name": "bert-base-uncased",
+        "hf_tokenizer_name": "bert-base-uncased",
+        "proj": "mlp",
+        "pooler_type": "mean_pooler"
+    }
+}
+```
+Here the `model_name` attribute of `vision_cfg` corresponds to the architecture of the pretrained model. The `pretrained` attribute defaults to `openai` and overrides the value passed via `--pretrained` when `--pretrained-image` is also set.
 
 ### Loss Curves
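[Editor's note, not part of the patch: a minimal sketch of how the same initialization could be exercised from Python for a quick sanity check before launching training, assuming this PR's branch is installed. The config name and pretrained tag come from the README above; `example.jpg` and the use of `open_clip.create_model_and_transforms`/`open_clip.get_tokenizer` with this custom config are assumptions for illustration.]

```python
import torch
from PIL import Image
import open_clip

# Build the model from the new config; `pretrained_image=True` plus the
# config's `vision_cfg.pretrained` tag triggers the image-tower
# initialization path added in factory.py below.
model, _, preprocess = open_clip.create_model_and_transforms(
    'bert-base-uncased-laion-ViT-B-32',
    pretrained='laion2b_s34b_b79k',
    pretrained_image=True,
    cache_dir='./laion-pretrained-models',  # hypothetical cache location
)
# The tokenizer is resolved from `hf_tokenizer_name` in the config.
tokenizer = open_clip.get_tokenizer('bert-base-uncased-laion-ViT-B-32')

image = preprocess(Image.open('example.jpg')).unsqueeze(0)  # hypothetical image
text = tokenizer(['a diagram', 'a dog', 'a cat'])

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    print(image_features.shape, text_features.shape)  # expect [1, 512] and [3, 512]
```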
diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index bf07009bc..d108fb4c7 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -1,7 +1,7 @@
 import json
 import logging
 import os
-import pathlib
+import torch
 import re
 from copy import deepcopy
 from pathlib import Path
@@ -87,15 +87,90 @@ def load_state_dict(checkpoint_path: str, map_location='cpu'):
     return state_dict
 
 
-def load_checkpoint(model, checkpoint_path, strict=True):
+def load_checkpoint(
+        model,
+        checkpoint_path=None,
+        strict=True,
+        which_pretrained_image_tower=None,
+        pretrained_image_tower=None,
+):
+    # an already-built image tower (e.g. extracted from an OpenAI model)
+    # is attached directly instead of loading weights from disk
+    if pretrained_image_tower is not None:
+        setattr(model, 'visual', pretrained_image_tower)
+        return None
+
     state_dict = load_state_dict(checkpoint_path)
     # detect old format and make compatible with new format
     if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
         state_dict = convert_to_custom_text_state_dict(state_dict)
+    resize_pos_embed(state_dict, model)
+
+    if which_pretrained_image_tower is not None:
+        # keep only image-tower weights and strip the "visual." prefix so the
+        # keys match model.visual.state_dict()
+        state_dict = filter(lambda i: i[0].startswith("visual"), state_dict.items())
+        prefix_length = len("visual")
+        state_dict = map(lambda i: (i[0][prefix_length + 1:], i[1]), state_dict)
+        return model.visual.load_state_dict(dict(state_dict), strict=strict)
+
     incompatible_keys = model.load_state_dict(state_dict, strict=strict)
     return incompatible_keys
 
 
+def get_cfg_and_handle_error(model_name):
+    model_cfg = get_model_config(model_name)
+    if model_cfg is not None:
+        logging.info(f'Loaded {model_name} model config.')
+    else:
+        logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
+        raise RuntimeError(f'Model config for {model_name} not found.')
+    return model_cfg
+
+
+def load_and_prepare_cfg(
+        model_name,
+        force_quick_gelu,
+        force_patch_dropout,
+        pretrained_image,
+        force_custom_text,
+        pretrained_hf,
+):
+    '''Decouples cfg loading and updating.'''
+    model_cfg = get_cfg_and_handle_error(model_name)
+
+    # handle pretrained image
+    which_pretrained_image_tower = None
+    vision_cfg = model_cfg.get('vision_cfg', {})
+    pretrained_image = pretrained_image or 'model_name' in vision_cfg or 'pretrained' in vision_cfg
+
+    if pretrained_image:
+        if 'timm_model_name' in vision_cfg:
+            # pretrained weight loading for timm models set via vision_cfg
+            model_cfg['vision_cfg']['timm_model_pretrained'] = True
+
+        # elif init image tower from a pre-defined and pre-trained model
+        elif 'model_name' in vision_cfg:
+            pretrained_image_model_name = vision_cfg.get('model_name')
+            pretrained_image_model_cfg = get_cfg_and_handle_error(pretrained_image_model_name)
+            model_name = pretrained_image_model_name
+
+            which_pretrained_image_tower = vision_cfg.get('pretrained', 'openai')
+            model_cfg["vision_cfg"] = pretrained_image_model_cfg["vision_cfg"]
+
+        else:
+            assert False, 'Unintended logic triggered, please debug or implement this block.'
+
+    if force_quick_gelu:
+        # override for use of QuickGELU on non-OpenAI transformer models
+        model_cfg["quick_gelu"] = True
+
+    if force_patch_dropout is not None:
+        # override the default patch dropout value
+        model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
+
+    # for `custom_text`
+    custom_text = model_cfg.pop('custom_text', False) or force_custom_text or ('hf_model_name' in model_cfg.get('text_cfg', {}))
+    if custom_text:
+        if 'hf_model_name' in model_cfg.get('text_cfg', {}):
+            model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
+
+    return model_name, model_cfg, custom_text, which_pretrained_image_tower
 
 
 def create_model(
     model_name: str,
@@ -111,10 +186,28 @@ def create_model(
     cache_dir: Optional[str] = None,
 ):
     model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
+
     if isinstance(device, str):
         device = torch.device(device)
 
-    if pretrained and pretrained.lower() == 'openai':
+    cast_dtype = get_cast_dtype(precision)
+
+    # load and prepare model config
+    model_name, model_cfg, custom_text, which_pretrained_image_tower = load_and_prepare_cfg(
+        model_name=model_name,
+        force_quick_gelu=force_quick_gelu,
+        force_patch_dropout=force_patch_dropout,
+        pretrained_image=pretrained_image,
+        force_custom_text=force_custom_text,
+        pretrained_hf=pretrained_hf,
+    )
+
+    extract_openai_image_tower = which_pretrained_image_tower is not None and which_pretrained_image_tower.lower() == 'openai'
+    pure_openai = pretrained and pretrained.lower() == 'openai' and not extract_openai_image_tower
+    process_openai = extract_openai_image_tower or pure_openai
+    pretrained_image_tower = None
+
+    if process_openai:
         logging.info(f'Loading pretrained {model_name} from OpenAI.')
         model = load_openai_model(
             model_name,
@@ -123,57 +216,41 @@ def create_model(
             jit=jit,
             cache_dir=cache_dir,
         )
-    else:
-        model_cfg = get_model_config(model_name)
-        if model_cfg is not None:
-            logging.info(f'Loaded {model_name} model config.')
-        else:
-            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
-            raise RuntimeError(f'Model config for {model_name} not found.')
-
-        if force_quick_gelu:
-            # override for use of QuickGELU on non-OpenAI transformer models
-            model_cfg["quick_gelu"] = True
-
-        if force_patch_dropout is not None:
-            # override the default patch dropout value
-            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
-
-        if pretrained_image:
-            if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
-                # pretrained weight loading for timm models set via vision_cfg
-                model_cfg['vision_cfg']['timm_model_pretrained'] = True
-            else:
-                assert False, 'pretrained image towers currently only supported for timm models'
-
-        cast_dtype = get_cast_dtype(precision)
-        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or ('hf_model_name' in model_cfg.get('text_cfg', {}))
+    pretrained_image_tower = getattr(model, 'visual', None) if extract_openai_image_tower else None
+
+    if not pure_openai:
         if custom_text:
-            if 'hf_model_name' in model_cfg.get('text_cfg', {}):
-                model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
             model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
         else:
             model = CLIP(**model_cfg, cast_dtype=cast_dtype)
+
+    # This seems unnecessary, since only the image tower will be referenced
+    # torch.cuda.empty_cache()
 
     pretrained_cfg = {}
-    if pretrained:
-        checkpoint_path = ''
+    pretrained = which_pretrained_image_tower if which_pretrained_image_tower is not None else pretrained
+    checkpoint_path = None
+    if pretrained and not process_openai:
         pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
         if pretrained_cfg:
             checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
         elif os.path.exists(pretrained):
             checkpoint_path = pretrained
 
-        if checkpoint_path:
-            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
-            load_checkpoint(model, checkpoint_path)
-        else:
-            error_str = (
-                f'Pretrained weights ({pretrained}) not found for model {model_name}.'
-                f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
-            logging.warning(error_str)
-            raise RuntimeError(error_str)
+    if checkpoint_path or extract_openai_image_tower:
+        logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
+        load_checkpoint(
+            model,
+            checkpoint_path,
+            which_pretrained_image_tower=which_pretrained_image_tower,
+            pretrained_image_tower=pretrained_image_tower,
+        )
+    elif pretrained:
+        error_str = (
+            f'Pretrained weights ({pretrained}) not found for model {model_name}. '
+            f'Available pretrained tags: {list_pretrained_tags_by_model(model_name)}.')
+        logging.warning(error_str)
+        raise RuntimeError(error_str)
 
     model.to(device=device)
     if precision in ("fp16", "bf16"):
diff --git a/src/open_clip/hf_configs.py b/src/open_clip/hf_configs.py
index e236222ba..d902f230b 100644
--- a/src/open_clip/hf_configs.py
+++ b/src/open_clip/hf_configs.py
@@ -1,5 +1,18 @@
 # HF architecture dict:
 arch_dict = {
+    # https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertConfig
+    "bert": {
+        "config_names": {
+            "context_length": "max_position_embeddings",
+            "vocab_size": "vocab_size",
+            "width": "hidden_size",
+            "heads": "num_attention_heads",
+            "layers": "num_hidden_layers",
+            "layer_attr": "layer",
+            "token_embeddings_attr": "embeddings"
+        },
+        "pooler": "mean_pooler",
+    },
     # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
     "roberta": {
         "config_names": {
diff --git a/src/open_clip/model_configs/bert-base-uncased-laion-ViT-B-32.json b/src/open_clip/model_configs/bert-base-uncased-laion-ViT-B-32.json
new file mode 100644
index 000000000..2e5a1e861
--- /dev/null
+++ b/src/open_clip/model_configs/bert-base-uncased-laion-ViT-B-32.json
@@ -0,0 +1,13 @@
+{
+    "embed_dim": 512,
+    "vision_cfg": {
+        "model_name": "ViT-B-32",
+        "pretrained": "laion2b_s34b_b79k"
+    },
+    "text_cfg": {
+        "hf_model_name": "bert-base-uncased",
+        "hf_tokenizer_name": "bert-base-uncased",
+        "proj": "mlp",
+        "pooler_type": "mean_pooler"
+    }
+}
\ No newline at end of file
diff --git a/src/training/main.py b/src/training/main.py
index 3514d130b..2edecf086 100644
--- a/src/training/main.py
+++ b/src/training/main.py
@@ -175,6 +175,7 @@ def main(args):
         pretrained_image=args.pretrained_image,
         image_mean=args.image_mean,
         image_std=args.image_std,
+        cache_dir=args.pretrained_cache_dir,
     )
 
     random_seed(args.seed, args.rank)
diff --git a/src/training/params.py b/src/training/params.py
index abc07dd50..3ae172787 100644
--- a/src/training/params.py
+++ b/src/training/params.py
@@ -163,6 +163,12 @@ def parse_args(args):
         type=str,
         help="Use a pretrained CLIP model weights with the specified tag or file path.",
     )
+    parser.add_argument(
+        "--pretrained-cache-dir",
+        default=None,
+        type=str,
+        help="Cache dir for storing downloaded pretrained weights",
+    )
     parser.add_argument(
         "--pretrained-image",
         default=False,
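[Editor's note, not part of the patch: a standalone sketch of the "visual." key filtering that the new `load_checkpoint()` in factory.py performs when `which_pretrained_image_tower` is set. The state-dict keys below are made up for illustration.]

```python
# A toy state_dict mixing image-tower, text-tower, and shared parameters.
state_dict = {
    'visual.conv1.weight': 'w0',
    'visual.transformer.resblocks.0.attn.in_proj_weight': 'w1',
    'text.transformer.embeddings.word_embeddings.weight': 'w2',
    'logit_scale': 'w3',
}

# Keep only image-tower entries, then strip the "visual." prefix so the
# remaining keys line up with model.visual.state_dict().
items = filter(lambda i: i[0].startswith('visual'), state_dict.items())
prefix_length = len('visual')
items = map(lambda i: (i[0][prefix_length + 1:], i[1]), items)
visual_state_dict = dict(items)

print(visual_state_dict)
# {'conv1.weight': 'w0', 'transformer.resblocks.0.attn.in_proj_weight': 'w1'}
```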