Skip to content

Commit

Permalink
minor fix device
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Feb 5, 2025
1 parent 0cd00d8 commit 72394ca
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions cirkit/backend/torch/layers/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def forward(self, x: Tensor) -> Tensor:
x = x.long() # The input to Embedding should be discrete
x = x.squeeze(dim=2) # (F, B)
weight = self.weight()
idx_fold = torch.arange(self.num_folds)
idx_fold = torch.arange(self.num_folds, device=weight.device)
x = weight[idx_fold[:, None], :, x]
x = self.semiring.map_from(x, SumProductSemiring)
return x # (F, B, K)
Expand Down Expand Up @@ -404,7 +404,7 @@ def log_unnormalized_likelihood(self, x: Tensor) -> Tensor:
x = x.squeeze(dim=2)
# logits: (F, K, N)
logits = torch.log(self.probs()) if self.logits is None else self.logits()
idx_fold = torch.arange(self.num_folds)
idx_fold = torch.arange(self.num_folds, device=logits.device)
x = logits[idx_fold[:, None], :, x]
return x

Expand Down

0 comments on commit 72394ca

Please sign in to comment.