change axes order of params
lkct committed Jul 9, 2023
1 parent 30be740 commit 1985e09
Showing 3 changed files with 28 additions and 26 deletions.
12 changes: 6 additions & 6 deletions cirkit/layers/einsum/cp.py
@@ -39,9 +39,9 @@ def __init__( # type: ignore[misc]
         super().__init__(rg_nodes, num_input_units, num_output_units)
         self.prod_exp = prod_exp

-        self.params_left = nn.Parameter(torch.empty(num_input_units, rank, len(rg_nodes)))
-        self.params_right = nn.Parameter(torch.empty(num_input_units, rank, len(rg_nodes)))
-        self.params_out = nn.Parameter(torch.empty(num_output_units, rank, len(rg_nodes)))
+        self.params_left = nn.Parameter(torch.empty(len(rg_nodes), num_input_units, rank))
+        self.params_right = nn.Parameter(torch.empty(len(rg_nodes), num_input_units, rank))
+        self.params_out = nn.Parameter(torch.empty(len(rg_nodes), rank, num_output_units))

         # TODO: get torch.default_float_dtype
         # (float ** float) is not guaranteed to be float, but here we know it is
@@ -55,13 +55,13 @@ def __init__( # type: ignore[misc]

     # TODO: use bmm to replace einsum? also axis order?
     def _forward_left_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("fkr,fkb->frb", self.params_left.permute(2, 0, 1), x)
+        return torch.einsum("fkr,fkb->frb", self.params_left, x)

     def _forward_right_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("fkr,fkb->frb", self.params_right.permute(2, 0, 1), x)
+        return torch.einsum("fkr,fkb->frb", self.params_right, x)

     def _forward_out_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("frk,frb->fkb", self.params_out.permute(2, 1, 0), x)
+        return torch.einsum("frk,frb->fkb", self.params_out, x)

     def _forward_linear(self, left: Tensor, right: Tensor) -> Tensor:
         left_hidden = self._forward_left_linear(left)
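As a sanity check on this file's change (not part of the commit; the sizes F/K/R/B and the tensor names are made up), the fold-first layout computes exactly the same contraction as the old layout with the per-call permute, and it also maps directly onto a batched matmul, which the "use bmm" TODO above hints at:

```python
import torch

# Hypothetical sizes: F folds (regions), K input units, R rank, B batch size.
F, K, R, B = 3, 4, 5, 2

params_old = torch.randn(K, R, F)          # old layout: (units, rank, folds)
params_new = params_old.permute(2, 0, 1)   # new layout: (folds, units, rank)
x = torch.randn(F, K, B)

# The old forward permuted at every call; the new forward uses the tensor as stored.
out_old = torch.einsum("fkr,fkb->frb", params_old.permute(2, 0, 1), x)
out_new = torch.einsum("fkr,fkb->frb", params_new, x)
assert torch.allclose(out_old, out_new)

# With the fold axis leading, the same contraction is a plain batched matmul.
out_bmm = torch.bmm(params_new.transpose(1, 2), x)
assert torch.allclose(out_new, out_bmm)
```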
4 changes: 2 additions & 2 deletions cirkit/layers/einsum/mixing.py
@@ -71,7 +71,7 @@ def __init__(self, rg_nodes: List[RegionNode], num_output_units: int, max_compon

         # TODO: test best perf?
         # param_shape = (len(self.nodes), self.max_components) for better perf
-        self.params = nn.Parameter(torch.empty(num_output_units, len(rg_nodes), max_components))
+        self.params = nn.Parameter(torch.empty(max_components, len(rg_nodes), num_output_units))
         # TODO: what's the use of params_mask?
         self.register_buffer("params_mask", torch.ones_like(self.params))
         self.param_clamp_value["min"] = torch.finfo(self.params.dtype).smallest_normal
@@ -90,7 +90,7 @@ def apply_params_mask(self) -> None:
         self.params /= self.params.sum(dim=2, keepdim=True) # type: ignore[misc]

     def _forward_linear(self, x: Tensor) -> Tensor:
-        return torch.einsum("cfk,cfkb->fkb", self.params.permute(2, 1, 0), x)
+        return torch.einsum("cfk,cfkb->fkb", self.params, x)

     # TODO: make forward return something
     # pylint: disable-next=arguments-differ
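Likewise for the mixing layer, a quick sketch (not from the repository; C/F/K/B and the variable names are arbitrary) checking that storing the parameters component-first is equivalent to permuting the old (units, folds, components) tensor on every forward call:

```python
import torch

# Hypothetical sizes: C mixing components, F folds, K output units, B batch size.
C, F, K, B = 2, 3, 4, 5

params_old = torch.rand(K, F, C)           # old layout: (units, folds, components)
params_new = params_old.permute(2, 1, 0)   # new layout: (components, folds, units)
x = torch.rand(C, F, K, B)

out_old = torch.einsum("cfk,cfkb->fkb", params_old.permute(2, 1, 0), x)
out_new = torch.einsum("cfk,cfkb->fkb", params_new, x)
assert torch.allclose(out_old, out_new)
```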
38 changes: 20 additions & 18 deletions tests/models/test_einet.py
@@ -86,13 +86,13 @@ def _get_einet() -> TensorizedPC:
 def _get_param_shapes() -> Dict[str, Tuple[int, ...]]:
     return {
         "input_layer.params": (4, 1, 1, 2),
-        "inner_layers.0.params_left": (1, 1, 4),
-        "inner_layers.0.params_right": (1, 1, 4),
-        "inner_layers.0.params_out": (1, 1, 4),
-        "inner_layers.1.params_left": (1, 1, 2),
-        "inner_layers.1.params_right": (1, 1, 2),
-        "inner_layers.1.params_out": (1, 1, 2),
-        "inner_layers.2.params": (1, 1, 2),
+        "inner_layers.0.params_left": (4, 1, 1),
+        "inner_layers.0.params_right": (4, 1, 1),
+        "inner_layers.0.params_out": (4, 1, 1),
+        "inner_layers.1.params_left": (2, 1, 1),
+        "inner_layers.1.params_right": (2, 1, 1),
+        "inner_layers.1.params_out": (2, 1, 1),
+        "inner_layers.2.params": (2, 1, 1),
     }


@@ -109,15 +109,15 @@ def _set_params(einet: TensorizedPC) -> None:
                 [math.log(3), 0], # type: ignore[misc] # 3/4, 1/4
             ]
         ).reshape(4, 1, 1, 2),
-        "inner_layers.0.params_left": torch.ones(1, 1, 4) / 2,
-        "inner_layers.0.params_right": torch.ones(1, 1, 4) * 2,
-        "inner_layers.0.params_out": torch.ones(1, 1, 4),
-        "inner_layers.1.params_left": torch.ones(1, 1, 2) * 2,
-        "inner_layers.1.params_right": torch.ones(1, 1, 2) / 2,
-        "inner_layers.1.params_out": torch.ones(1, 1, 2),
+        "inner_layers.0.params_left": torch.ones(4, 1, 1) / 2,
+        "inner_layers.0.params_right": torch.ones(4, 1, 1) * 2,
+        "inner_layers.0.params_out": torch.ones(4, 1, 1),
+        "inner_layers.1.params_left": torch.ones(2, 1, 1) * 2,
+        "inner_layers.1.params_right": torch.ones(2, 1, 1) / 2,
+        "inner_layers.1.params_out": torch.ones(2, 1, 1),
         "inner_layers.2.params": torch.tensor(
             [1 / 3, 2 / 3], # type: ignore[misc]
-        ).reshape(1, 1, 2),
+        ).reshape(2, 1, 1),
     }
 )
 einet.load_state_dict(state_dict) # type: ignore[misc]
@@ -159,9 +159,9 @@ def test_einet_partition_func() -> None:
 @pytest.mark.parametrize( # type: ignore[misc]
     "rg_cls,kwargs,log_answer",
     [
-        (PoonDomingos, {"shape": [4, 4], "delta": 2}, 10.188161849975586),
-        (QuadTree, {"width": 4, "height": 4, "struct_decomp": False}, 51.31766128540039),
-        (RandomBinaryTree, {"num_vars": 16, "depth": 3, "num_repetitions": 2}, 24.198360443115234),
+        (PoonDomingos, {"shape": [4, 4], "delta": 2}, 10.935434341430664),
+        (QuadTree, {"width": 4, "height": 4, "struct_decomp": False}, 44.412864685058594),
+        (RandomBinaryTree, {"num_vars": 16, "depth": 3, "num_repetitions": 2}, 24.313674926757812),
         (PoonDomingos, {"shape": [3, 3], "delta": 2}, None),
         (QuadTree, {"width": 3, "height": 3, "struct_decomp": False}, None),
         (RandomBinaryTree, {"num_vars": 9, "depth": 3, "num_repetitions": 2}, None),
@@ -223,4 +223,6 @@ def test_einet_partition_function(

     assert torch.isclose(einet.partition_function(), sum_out, rtol=1e-6, atol=0)
     if log_answer is not None:
-        assert torch.isclose(sum_out, torch.tensor(log_answer), rtol=1e-6, atol=0)
+        assert torch.isclose(
+            sum_out, torch.tensor(log_answer), rtol=1e-6, atol=0
+        ), f"{sum_out.item()}"
