
learnable_codebook and in-place optimization #97

Open

daraha76 opened this issue Dec 27, 2023 · 1 comment

Comments

@daraha76

Hi, I tried to use both the codebook loss and the commitment loss instead of the EMA update, but I was confused about how to apply the codebook loss.

If 'learnable_codebook' is True, then 'commit_quantize' is not detached from 'quantize':

            maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity

            commit_quantize = maybe_detach(quantize)

and 'commit_loss' will serve as both the codebook loss and the commitment loss.

commit_loss = F.mse_loss(commit_quantize, x)

So, is setting 'learnable_codebook' to True all I need to apply the codebook loss, or do I need to implement the codebook loss separately?
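
For reference, a minimal pure-PyTorch sketch (a toy stand-in, not the library's actual code) of what that detach decides: when the quantized output is left attached, as with learnable_codebook = True, the single MSE term sends gradients into the codebook as well as into the encoder output; with the detach, only the encoder output receives a gradient.

    import torch
    import torch.nn.functional as F

    embed = torch.randn(16, 8, requires_grad=True)   # stand-in for a learnable codebook
    x = torch.randn(4, 8, requires_grad=True)        # stand-in for the encoder output

    # nearest-neighbour lookup (the indices are non-differentiable, the gather is)
    ind = torch.cdist(x, embed).argmin(dim = -1)
    quantize = embed[ind]

    # no detach (learnable_codebook = True): gradients reach codebook and encoder
    F.mse_loss(quantize, x).backward()
    print(embed.grad.abs().sum() > 0, x.grad.abs().sum() > 0)   # tensor(True) tensor(True)

    embed.grad, x.grad = None, None

    # with detach (learnable_codebook = False): only the encoder output gets a gradient
    ind = torch.cdist(x, embed).argmin(dim = -1)
    F.mse_loss(embed[ind].detach(), x).backward()
    print(embed.grad, x.grad.abs().sum() > 0)                   # None tensor(True)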

Also, if 'commit_loss' serves as the codebook loss, then I think 'commit_quantize' should be detached when using the in-place update, so that the codebook will not be updated twice! (A toy sketch of this follows the code excerpt below.)

        # one step in-place update

        if should_inplace_optimize and self.training and not freeze_codebook:

            if exists(mask):
                loss = F.mse_loss(quantize, x.detach(), reduction = 'none')

                loss_mask = mask
                if is_multiheaded:
                    loss_mask = repeat(mask, 'b n -> c (b h) n', c = loss.shape[0], h = loss.shape[1] // mask.shape[0])

                loss = loss[loss_mask].mean()

            else:
                loss = F.mse_loss(quantize, x.detach())

            loss.backward()
            self.in_place_codebook_optimizer.step()
            self.in_place_codebook_optimizer.zero_grad()

            # quantize again

            quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)
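
To illustrate the point above with a self-contained toy example (pure PyTorch, not the library's code; the name in_place_codebook_optimizer just mirrors the excerpt): the in-place step already pulls the selected codebook vectors toward the detached encoder output, so if commit_quantize were left attached afterwards, the subsequent commit_loss would push a second gradient into the same codebook.

    import torch
    import torch.nn.functional as F

    embed = torch.nn.Parameter(torch.randn(16, 8))   # toy learnable codebook
    x = torch.randn(4, 8)                            # toy encoder output
    in_place_codebook_optimizer = torch.optim.SGD([embed], lr = 1e-1)

    # one in-place update step, mirroring the excerpt: pull the selected
    # codebook vectors toward the (detached) encoder output
    ind = torch.cdist(x, embed).argmin(dim = -1)
    F.mse_loss(embed[ind], x.detach()).backward()
    in_place_codebook_optimizer.step()
    in_place_codebook_optimizer.zero_grad()

    # quantize again with the freshly updated codebook
    ind = torch.cdist(x, embed).argmin(dim = -1)
    quantize = embed[ind]

    # if commit_quantize is NOT detached here, the commitment loss would send a
    # second gradient into `embed` on the main backward pass, i.e. the codebook
    # would be updated twice per step; detaching restricts it to the encoder
    commit_quantize = quantize.detach()
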
@DarkDawn233

I think that F.mse_loss(commit_quantize, x) is equivalent (in terms of gradients) to F.mse_loss(commit_quantize.detach(), x) + F.mse_loss(commit_quantize, x.detach()). However, the coefficients of these two terms cannot be configured separately in the combined formula. If coefficient adjustment is necessary, perhaps an additional update could be performed using the in-place optimizer?
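
That gradient equivalence is easy to verify numerically (a quick sketch in plain PyTorch): the combined loss and the sum of the two stop-gradient terms produce identical gradients, and only the split form lets you weight the commitment term separately, e.g. loss = F.mse_loss(quantize, x.detach()) + beta * F.mse_loss(quantize.detach(), x), the classic VQ-VAE formulation.

    import torch
    import torch.nn.functional as F

    quantize = torch.randn(4, 8, requires_grad=True)   # stand-in for the codebook output
    x = torch.randn(4, 8, requires_grad=True)          # stand-in for the encoder output

    # combined loss: gradients flow into both tensors
    F.mse_loss(quantize, x).backward()
    g_q, g_x = quantize.grad.clone(), x.grad.clone()
    quantize.grad, x.grad = None, None

    # split form: codebook term (stop-gradient on x) + commitment term
    # (stop-gradient on quantize), each with coefficient 1
    (F.mse_loss(quantize, x.detach()) + F.mse_loss(quantize.detach(), x)).backward()

    # the gradients match (the loss *values* differ by a factor of 2)
    assert torch.allclose(quantize.grad, g_q)
    assert torch.allclose(x.grad, g_x)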
