Swin-UNETR pretraining model weights do not load into the model #346

Open
Saqeeb95 opened this issue Dec 14, 2023 · 5 comments

Comments


Saqeeb95 commented Dec 14, 2023

Describe the bug
The provided pre-trained Swin-UNETR weights do not load into a newly instantiated SSLHead model object. The state_dict key names in the provided weights differ from those of the instantiated SSLHead. Even after renaming the keys, some layers still mismatch: the weights file contains fully connected layers named "fc1"/"fc2", whereas the instantiated model expects layers named "linear1"/"linear2". I think the SSLHead architecture in ssl_head.py differs from the architecture of the model that the provided model_swinvit.pt file comes from.

To Reproduce
Steps to reproduce the behavior:

  1. Clone the repo from here
  2. Download the model_swinvit.pt file from here
  3. Place the model_swinvit.pt file in the /models/weights/ directory inside the /SwinUNETR/Pretrain/ directory (create the dirs if needed)
  4. Set up the datasets as per the instructions in the same repo
  5. cd into the /SwinUNETR/Pretrain/ directory and run the following command:
    python main.py --use_checkpoint --batch_size=3 --num_steps=450 --lrdecay --eval_num=100 --logdir="test8" --lr=0.000004 --roi_x=96 --roi_y=96 --roi_z=96 --lr_schedule="poly" --noamp --epochs=15000 --resume="./models/weights/model_swinvit.pt"

Expected behavior
I expected the model to train when the "--resume" argument pointed to the model_swinvit.pt file obtained from the Pre-trained Models section of the Swin-UNETR pretraining page.

Screenshots
Here is the mismatched-keys error I get when trying to load the weights into a newly initialized SSLHead model:
[Screenshot: full mismatched-keys error]
It is hard to read, but the gist is that the key naming scheme differs between the provided weights and a freshly initialized SSLHead model. There are also more "Unexpected" keys listed than I could fit into the screenshot.
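Rather than reading the mismatch off a screenshot, the two key sets can be compared directly. The sketch below only needs key names, so it works on the keys of the loaded checkpoint and on `model.state_dict().keys()`. The helper names are hypothetical, and the assumption that the file keeps its weights under a "state_dict" entry should be checked against the actual checkpoint.

```python
from collections import Counter
from typing import Dict, Iterable, List, Tuple


def top_level_prefixes(keys: Iterable[str]) -> Dict[str, int]:
    """Count state_dict keys by their first dotted component (e.g. "module", "swinViT")."""
    return dict(Counter(key.split(".", 1)[0] for key in keys))


def diff_state_dict_keys(
    checkpoint_keys: Iterable[str], model_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected), using PyTorch's meaning of those words.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not expect
    """
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)


# With the real objects (needs torch; the "state_dict" entry is an assumption):
#   ckpt = torch.load("./models/weights/model_swinvit.pt", map_location="cpu")
#   missing, unexpected = diff_state_dict_keys(ckpt["state_dict"].keys(),
#                                              model.state_dict().keys())
if __name__ == "__main__":
    # Illustrative key names modelled on this issue, not a full state_dict.
    ckpt_keys = ["module.layers1.0.blocks.0.mlp.fc1.weight", "module.norm.weight"]
    model_keys = ["swinViT.layers1.0.blocks.0.mlp.linear1.weight"]
    print(top_level_prefixes(ckpt_keys), top_level_prefixes(model_keys))
    missing, unexpected = diff_state_dict_keys(ckpt_keys, model_keys)
    print("missing:", missing)
    print("unexpected:", unexpected)
```

The prefix histogram makes a "module" vs "swinViT" naming difference obvious at a glance, and the sorted lists can be printed in full instead of being cut off.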

I noticed that in many of the keys the "swinViT" prefix had simply been replaced with "module", so I renamed the dict keys to match as closely as I could. Here is a screenshot of the keys that still mismatch afterwards:
[Screenshot: remaining mismatched keys after renaming]
The remaining mismatches seem to come from the MLP in each Swin transformer block: the SSLHead model expects two layers named "linear1" and "linear2" (e.g. "swinViT.layers1.0.blocks.0.mlp.linear1.weight"), whereas the provided weights name them "fc1" and "fc2" (e.g. "swinViT.layers1.0.blocks.0.mlp.fc1.weight"). The last mismatches are "swinViT.norm.weight" and "swinViT.norm.bias", which the model does not appear to need.
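Combining the renaming described above with the fc/linear observation, a checkpoint could be remapped before loading. This is a hypothetical sketch, not a confirmed fix: it assumes "module." in the file corresponds to "swinViT." in the model, and that "fc1"/"fc2" correspond to "linear1"/"linear2" with identical shapes (i.e. only the attribute names changed). Nothing in this thread confirms that. `remap_swinvit_keys` is an illustrative name, not part of the repo.

```python
from typing import Dict, Iterable, Mapping, Optional, TypeVar

V = TypeVar("V")

# Assumed correspondences, based only on the key names quoted in this issue.
PREFIX_RENAMES = {"module.": "swinViT."}
MLP_RENAMES = {".mlp.fc1.": ".mlp.linear1.", ".mlp.fc2.": ".mlp.linear2."}


def remap_swinvit_keys(
    state_dict: Mapping[str, V], model_keys: Optional[Iterable[str]] = None
) -> Dict[str, V]:
    """Rename checkpoint keys toward SSLHead naming; optionally drop keys the model lacks."""
    expected = set(model_keys) if model_keys is not None else None
    out: Dict[str, V] = {}
    for key, value in state_dict.items():
        # Only rewrite the leading prefix, never a "module." elsewhere in the key.
        for old, new in PREFIX_RENAMES.items():
            if key.startswith(old):
                key = new + key[len(old):]
        # Rename the two MLP sublayers inside every transformer block.
        for old, new in MLP_RENAMES.items():
            key = key.replace(old, new)
        # Extras such as "swinViT.norm.weight" are skipped if the model has no slot for them.
        if expected is None or key in expected:
            out[key] = value
    return out


# Hypothetical usage (needs torch and an SSLHead instance named `model`):
#   remapped = remap_swinvit_keys(ckpt["state_dict"], model.state_dict().keys())
#   result = model.load_state_dict(remapped, strict=False)
#   print(result.missing_keys, result.unexpected_keys)
```

`load_state_dict` still raises on shape mismatches even with `strict=False`, so a successful load at least indicates that the renamed layers have matching shapes. It does not prove the weights are semantically equivalent.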


Additional context
The model can train from scratch just fine, and I can also load checkpoints from my own previous runs. I just can't load the state_dict from the provided weights file.

Saqeeb95 changed the title from "Swin-UNETR pretraining model weight do not load into" to "Swin-UNETR pretraining model weights do not load into the model" on Dec 14, 2023.
prateekgrover-in commented:

Hi Saqeeb, were you able to find a solution to this issue? I'm facing the same problem!

Saqeeb95 (author) commented:

Unfortunately not. I tried renaming the dict keys in the loaded weights to match the ones the model expects, but some of them still don't match up. Judging by the differences between the expected dict keys and the ones in the weights file, I think the provided weights are actually for a slightly different architecture.

prateekgrover-in commented:

Yeah, for me it copies only 126 of the 159 weights from the file. I think fine-tuning is very hard with those weights missing. Were you able to find another way?
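The thread doesn't say whether "126 of 159" counts checkpoint entries or model parameters. One hedged way to reproduce a comparable tally, assuming it counts checkpoint entries whose name and shape both match the model, is sketched below. The helper name is hypothetical. Listing the skipped names shows exactly which weights are lost.

```python
from typing import List, Mapping, Tuple

Shape = Tuple[int, ...]


def split_by_loadability(
    ckpt_shapes: Mapping[str, Shape], model_shapes: Mapping[str, Shape]
) -> Tuple[List[str], List[str]]:
    """Split checkpoint entries into (loadable, skipped).

    An entry counts as loadable only if the model has a parameter or buffer
    with the same name and the same shape.
    """
    loadable: List[str] = []
    skipped: List[str] = []
    for name, shape in ckpt_shapes.items():
        if model_shapes.get(name) == shape:
            loadable.append(name)
        else:
            skipped.append(name)
    return loadable, skipped


# Hypothetical usage with real tensors (needs torch):
#   ckpt_shapes = {k: tuple(v.shape) for k, v in ckpt["state_dict"].items()}
#   model_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
#   loadable, skipped = split_by_loadability(ckpt_shapes, model_shapes)
#   print(f"{len(loadable)} of {len(ckpt_shapes)} entries loadable")
```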

Saqeeb95 (author) commented:

No, sorry, I wasn't. I had my own dataset of ~1500 images to pretrain on, so I moved on and ran my own pretraining.


3 participants
@Saqeeb95 @prateekgrover-in and others