
Merge pull request #340 from april-tools/remove-channels
Remove channels
loreloc authored Feb 5, 2025
2 parents 503dc96 + 2627047 commit 134cd5d
Showing 46 changed files with 755 additions and 1,093 deletions.
28 changes: 11 additions & 17 deletions cirkit/backend/torch/circuits.py
@@ -54,10 +54,16 @@ def lookup(
             if in_graph is None:
                 yield layer, ()
                 continue
-            # in_graph: An input batch (assignments to variables) of shape (B, C, D)
+            # in_graph: An input batch (assignments to variables) of shape (B, D)
             # scope_idx: The scope of the layers in each fold, a tensor of shape (F, D'), D' < D
-            # x: (B, C, D) -> (B, C, F, D') -> (F, C, B, D')
-            x = in_graph[..., layer.scope_idx].permute(2, 1, 0, 3)
+            # x: (B, D) -> (B, F, D') -> (F, B, D')
+            if len(in_graph.shape) != 2:
+                raise ValueError(
+                    "The input to the circuit should have shape (B, D), "
+                    "where B is the batch size and D is the number of variables "
+                    "the circuit is defined on"
+                )
+            x = in_graph[..., layer.scope_idx].permute(1, 0, 2)
             yield layer, (x,)
             continue
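
To make the new convention concrete, here is a minimal standalone sketch of the indexing performed above; the tensor sizes and the scope_idx value are made up for illustration and are not taken from the repository:

import torch

# Toy sizes (illustrative only): batch B = 4, D = 6 variables,
# F = 3 folds, each fold scoped over D' = 2 variables.
x = torch.randn(4, 6)                                # input batch of shape (B, D)
scope_idx = torch.tensor([[0, 1], [2, 3], [4, 5]])   # stand-in for layer.scope_idx, shape (F, D')

# Advanced indexing on the last axis yields (B, F, D');
# permute(1, 0, 2) then moves the fold axis first: (F, B, D').
folded = x[..., scope_idx].permute(1, 0, 2)
print(folded.shape)                                  # torch.Size([3, 4, 2])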

@@ -121,7 +127,6 @@ class AbstractTorchCircuit(TorchDiAcyclicGraph[TorchLayer]):
     def __init__(
         self,
         scope: Scope,
-        num_channels: int,
         layers: Sequence[TorchLayer],
         in_layers: dict[TorchLayer, Sequence[TorchLayer]],
         outputs: Sequence[TorchLayer],
@@ -133,7 +138,6 @@ def __init__(
         Args:
             scope: The variables scope.
-            num_channels: The number of channels per variable.
             layers: The sequence of layers.
             in_layers: A dictionary mapping layers to their inputs, if any.
             outputs: A list of output layers.
@@ -148,7 +152,6 @@ def __init__(
             fold_idx_info=fold_idx_info,
         )
         self._scope = scope
-        self._num_channels = num_channels
         self._properties = properties

     @property
@@ -169,15 +172,6 @@ def num_variables(self) -> int:
"""
return len(self.scope)

@property
def num_channels(self) -> int:
"""Retrieve the number of channels of each variable.
Returns:
The number of variables.
"""
return self._num_channels

@property
def properties(self) -> StructuralProperties:
"""Retrieve the structural properties of the circuit.
@@ -272,8 +266,8 @@ def forward(self, x: Tensor) -> Tensor:
         following the topological ordering.

         Args:
-            x: The tensor input of the circuit, with shape $(B, C, D)$, where B is the batch size,
-                $C$ is the number of channels, and $D$ is the number of variables.
+            x: The tensor input of the circuit, with shape $(B, D)$, where B is the batch size,
+                and $D$ is the number of variables.

         Returns:
             Tensor: The tensor output of the circuit, with shape $(B, O, K)$,
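
For callers migrating from the old (B, C, D) convention with a single channel, a hedged sketch of the required adjustment; circuit stands for any compiled circuit object and is deliberately not constructed here:

import torch

batch = torch.randint(0, 2, (32, 1, 10))   # old-style input: (B, C=1, D)
x = batch.squeeze(dim=1)                    # new-style input: (B, D)
assert x.ndim == 2                          # any other rank now raises ValueError in lookup()
# y = circuit(x)                            # would return a tensor of shape (B, O, K)
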
4 changes: 1 addition & 3 deletions cirkit/backend/torch/compiler.py
@@ -228,7 +228,6 @@ def _compile_circuit(self, sc: Circuit) -> AbstractTorchCircuit:
         layers = list(compiled_layers_map.values())
         cc = cc_cls(
             sc.scope,
-            sc.num_channels,
             layers=layers,
             in_layers=in_layers,
             outputs=outputs,
@@ -275,7 +274,6 @@ def _fold_circuit(compiler: TorchCompiler, cc: AbstractTorchCircuit) -> AbstractTorchCircuit:
     # Instantiate a folded circuit
     return type(cc)(
         cc.scope,
-        cc.num_channels,
         layers,
         in_layers,
         outputs,
@@ -507,7 +505,7 @@ def match_optimizer_fuse(match: LayerOptMatch) -> tuple[TorchLayer, ...]:
     if optimize_result is None:
         return cc, False
     layers, in_layers, outputs = optimize_result
-    cc = type(cc)(cc.scope, cc.num_channels, layers, in_layers, outputs, properties=cc.properties)
+    cc = type(cc)(cc.scope, layers, in_layers, outputs, properties=cc.properties)
     return cc, True


18 changes: 9 additions & 9 deletions cirkit/backend/torch/layers/inner.py
@@ -276,11 +276,11 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
         if negative or not normalized:
             raise TypeError("Sampling in sum layers only works with positive weights summing to 1")

-        # x: (F, H, C, Ki, num_samples, D) -> (F, C, H * Ki, num_samples, D)
-        x = x.permute(0, 2, 1, 3, 4, 5).flatten(2, 3)
-        c = x.shape[1]
-        num_samples = x.shape[3]
-        d = x.shape[4]
+        # x: (F, H, Ki, num_samples, D) -> (F, H * Ki, num_samples, D)
+        x = x.flatten(1, 2)
+
+        num_samples = x.shape[2]
+        d = x.shape[3]

         # mixing_distribution: (F, Ko, H * Ki)
         mixing_distribution = torch.distributions.Categorical(probs=weight)
@@ -289,9 +289,9 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
         mixing_samples = mixing_distribution.sample((num_samples,))
         mixing_samples = E.rearrange(mixing_samples, "n f k -> f k n")

-        # mixing_indices: (F, C, Ko, num_samples, D)
-        mixing_indices = E.repeat(mixing_samples, "f k n -> f c k n d", c=c, d=d)
+        # mixing_indices: (F, Ko, num_samples, D)
+        mixing_indices = E.repeat(mixing_samples, "f k n -> f k n d", d=d)

-        # x: (F, C, Ko, num_samples, D)
-        x = torch.gather(x, dim=2, index=mixing_indices)
+        # x: (F, Ko, num_samples, D)
+        x = torch.gather(x, dim=1, index=mixing_indices)
         return x, mixing_samples
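
A self-contained sketch of the channel-free sampling path with made-up sizes; weight is a random stand-in for the layer's normalized mixing weights, and einops is imported as E only to mirror the code above:

import torch
import einops as E

F, H, Ki, Ko, N, D = 2, 3, 4, 5, 7, 6       # illustrative fold / arity / unit / sample / variable sizes

x = torch.randn(F, H, Ki, N, D)             # per-input samples, no channel axis
x = x.flatten(1, 2)                         # (F, H * Ki, N, D)

weight = torch.softmax(torch.randn(F, Ko, H * Ki), dim=-1)       # stand-in normalized weights
mixing_distribution = torch.distributions.Categorical(probs=weight)

mixing_samples = mixing_distribution.sample((N,))                # (N, F, Ko)
mixing_samples = E.rearrange(mixing_samples, "n f k -> f k n")   # (F, Ko, N)

# Broadcast the chosen component index over the variable axis, then pick it out with gather.
mixing_indices = E.repeat(mixing_samples, "f k n -> f k n d", d=D)   # (F, Ko, N, D)
x = torch.gather(x, dim=1, index=mixing_indices)                     # (F, Ko, N, D)
print(x.shape)                              # torch.Size([2, 5, 7, 6])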
