Poor reconstruction using DAE pre-trained weights #390

Open
OeslleLucena opened this issue Jun 14, 2024 · 1 comment

@OeslleLucena

I am trying to load the pretrained weights from the tutorial https://github.com/Project-MONAI/tutorials/blob/main/self_supervised_pretraining/vit_unetr_ssl/ssl_finetune.ipynb that were trained using the DAE method https://github.com/Project-MONAI/research-contributions/tree/main/DAE/Pretrain_full_contrast. The weights used are the "ssl_pretrained_weights.pth" available at https://github.com/Project-MONAI/MONAI-extra-test-data/releases.

I run the following code to see how well it reconstructs one sample from the HNSCC dataset:

```python
# Excerpt from a class method: `self` and `batch` come from the surrounding code.
import torch
import torch.nn.functional as F
from monai.networks.nets import SwinUNETR

swinUnetr = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=1,
    feature_size=48,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=True,
)

pretrained_path = "/home/ol18/Downloads/ssl_pretrained_weights.pth"
model_dict = torch.load(pretrained_path)["model"]
pretrained_weights_keys = model_dict.keys()
store_dict = swinUnetr.state_dict()
net_keys = store_dict.keys()

print(len(net_keys), len(pretrained_weights_keys))

# Drop the checkpoint entries that are not copied into SwinUNETR.
count = 0
del model_dict["encoder.mask_token"]
del model_dict["encoder.norm.weight"]
del model_dict["encoder.norm.bias"]

# Rename "encoder.*" keys to "swinViT.*" and copy the values into the
# network's state dict. Keys that match neither rule are printed.
for key, value in model_dict.items():
    if key[:8] == "encoder.":
        if key[8:19] == "patch_embed":
            new_key = "swinViT." + key[8:]
        else:
            # Skip the two characters at positions 18-19 of the key.
            new_key = "swinViT." + key[8:18] + key[20:]
        store_dict[new_key] = value
        count += 1
    elif key in net_keys:
        store_dict[key] = value
        count += 1
    else:
        print(key)

print(count)
swinUnetr.load_state_dict(store_dict)
swinUnetr.eval()  # inference only

for key in batch.keys():
    # Get inputs.
    inputs = batch[key]["image"]

    # Degrade the input: additive Gaussian noise (variance 0.1), then
    # downsample by 4 and upsample back to 96^3.
    noise = (0.1**0.5) * torch.randn_like(inputs)
    img_noisy = inputs + noise
    img_lowres = F.interpolate(img_noisy, size=(96 // 4, 96 // 4, 96 // 4))
    img_resam = F.interpolate(img_lowres, size=(96, 96, 96))

    swinUnetr.to(inputs.device)
    with torch.no_grad():  # no gradients needed for reconstruction
        x_rec = swinUnetr(img_resam)

    self._log_image_tensorboard(
        inputs[0],
        img_resam[0],
        x_rec[0],
        f"test_images_0_{key}",
    )
```
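One way to sanity-check the loading step is to list which network parameters actually receive a pretrained value. The sketch below is a minimal, pure-Python illustration that mirrors the renaming rule from the snippet above. The helper names `remap_key` and `coverage_report` and the example key names are hypothetical and are not taken from the real checkpoint.

```python
def remap_key(key):
    """Mirror the renaming rule used above; return None for non-encoder keys."""
    if not key.startswith("encoder."):
        return None
    if key[8:19] == "patch_embed":
        return "swinViT." + key[8:]
    # Skip the two characters at positions 18-19, as in the original loop.
    return "swinViT." + key[8:18] + key[20:]


def coverage_report(pretrained_keys, net_keys):
    """Report which network keys get a pretrained value and which do not."""
    net_keys = set(net_keys)
    loaded = set()
    unmatched = []  # checkpoint keys with no target in the network
    for key in pretrained_keys:
        target = remap_key(key) or key
        if target in net_keys:
            loaded.add(target)
        else:
            unmatched.append(key)
    return {
        "loaded": sorted(loaded),
        "not_loaded": sorted(net_keys - loaded),  # these keep their initial values
        "unmatched": unmatched,
    }


# Illustrative key names only (assumed, not read from the checkpoint).
pretrained = [
    "encoder.patch_embed.proj.weight",
    "encoder.layers1.0.0.blocks.0.norm1.weight",
    "decoder.foo.weight",
]
network = [
    "swinViT.patch_embed.proj.weight",
    "swinViT.layers1.0.blocks.0.norm1.weight",
    "decoder5.layer.conv1.weight",
    "out.conv.conv.weight",
]
report = coverage_report(pretrained, network)
for name, keys in report.items():
    print(name, keys)
```

With the real objects, `coverage_report(model_dict.keys(), swinUnetr.state_dict().keys())` would give the same kind of report. Any parameter listed under `not_loaded` keeps its initial value after `load_state_dict`, so if decoder or output layers show up there, that may be worth checking before judging the reconstructions.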

I get the following outputs for the original image, the noisy image, and the noisy + low-resolution image:

Original image as input: [image: no_augmentations]

Noisy image as input: [image: noisy]

Noisy + low-resolution image as input: [image: all_augmentations]

As I understand it, these reconstructions are fairly poor given the amount of pretraining that was done. Am I missing something here? Could someone point me in the right direction?
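To put a number on "fairly poor", one option is to compare the PSNR of the degraded input against the PSNR of the reconstruction, both measured against the original image. The sketch below is a minimal pure-Python version that works on flat sequences of floats (for example `tensor.flatten().tolist()`). The `data_range=1.0` default assumes intensities scaled to [0, 1], which is an assumption and not something stated in this issue.

```python
import math


def mse(reference, estimate):
    """Mean squared error between two equally long sequences of floats."""
    reference, estimate = list(reference), list(estimate)
    if not reference or len(reference) != len(estimate):
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)


def psnr(reference, estimate, data_range=1.0):
    """Peak signal-to-noise ratio in dB; data_range is the assumed intensity span."""
    err = mse(reference, estimate)
    if err == 0:
        return math.inf
    return 10 * math.log10(data_range**2 / err)


# Toy values only, to show the call pattern.
original = [0.0, 0.5, 1.0, 0.5]
degraded = [0.1, 0.4, 0.9, 0.6]
print(f"PSNR of degraded input: {psnr(original, degraded):.2f} dB")
```

If the PSNR of `x_rec` against `inputs` is not higher than the PSNR of `img_resam` against `inputs`, the network is not improving on its own input, which would back up the visual impression.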

Many thanks

@tangy5

@yalcintur

I got results similar to yours. Were you able to figure it out?
