Merge pull request #335 from april-tools/address-book-gpu
Made address book fold indices buffers
loreloc authored Feb 5, 2025
2 parents 1b39231 + a97e903 commit c446634
Showing 13 changed files with 427 additions and 391 deletions.
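
In brief: each address book's fold index tensors are now registered as PyTorch buffers, so moving a compiled circuit to an accelerator also moves its book-keeping indices, and no host-to-device copy is needed at every evaluation. A minimal sketch of the pattern this commit relies on (illustrative names, not the cirkit API):

import torch
from torch import nn

class FoldGather(nn.Module):
    """Illustrative module holding a gather index as a registered buffer."""

    def __init__(self, fold_idx: torch.Tensor) -> None:
        super().__init__()
        # Buffers are non-trainable tensors that follow .to()/.cuda()
        # and are saved in state_dict, unlike plain tensor attributes
        self.register_buffer("fold_idx", fold_idx)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # fold_idx already lives on x.device after module.to(device),
        # so no implicit CPU-to-device copy happens per forward call
        return x[self.fold_idx]

gather = FoldGather(torch.tensor([0, 2, 1]))
# gather = gather.to("cuda")  # would move the index together with the module
print(gather(torch.randn(3, 4)).shape)  # torch.Size([3, 4])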
22 changes: 12 additions & 10 deletions cirkit/backend/torch/circuits.py
@@ -31,22 +31,24 @@ def lookup(
         self, module_outputs: list[Tensor], *, in_graph: Tensor | None = None
     ) -> Iterator[tuple[TorchLayer | None, tuple]]:
         # Loop through the entries and yield inputs
-        for entry in self._entries:
+        for entry in self:
+            layer = entry.module
+            in_layer_ids = entry.in_module_ids
+            in_fold_idx = entry.in_fold_idx
             # Catch the case there are some inputs coming from other modules
-            if entry.in_module_ids:
-                (in_fold_idx,) = entry.in_fold_idx
-                (in_module_ids,) = entry.in_module_ids
-                if len(in_module_ids) == 1:
-                    x = module_outputs[in_module_ids[0]]
+            if in_layer_ids:
+                in_fold_idx_h = in_fold_idx[0]
+                in_layer_ids_h = in_layer_ids[0]
+                if len(in_layer_ids_h) == 1:
+                    x = module_outputs[in_layer_ids_h[0]]
                 else:
-                    x = torch.cat([module_outputs[mid] for mid in in_module_ids], dim=0)
-                x = x[in_fold_idx]
-                yield entry.module, (x,)
+                    x = torch.cat([module_outputs[mid] for mid in in_layer_ids_h], dim=0)
+                x = x[in_fold_idx_h]
+                yield layer, (x,)
                 continue

             # Catch the case there are no inputs coming from other modules
             # That is, we are gathering the inputs of input layers
-            layer = entry.module
             assert isinstance(layer, TorchInputLayer)
             if layer.num_variables:
                 if in_graph is None:
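
To see what this gather computes, here is a sketch with made-up shapes: the outputs of already-evaluated modules, each of shape (folds, batch, units), are concatenated along the fold dimension and then indexed with a fold index tensor, as in the torch.cat branch above.

import torch

# Hypothetical outputs of two evaluated layers: (folds, batch, units)
out_a = torch.randn(2, 4, 5)
out_b = torch.randn(3, 4, 5)
module_outputs = [out_a, out_b]

# Concatenate along the fold dimension, then gather the folds that
# feed each fold of the consumer layer (indices are illustrative)
x = torch.cat(module_outputs, dim=0)            # (5, 4, 5)
in_fold_idx_h = torch.tensor([[0, 3], [1, 4]])
x = x[in_fold_idx_h]                            # (2, 2, 4, 5)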
36 changes: 20 additions & 16 deletions cirkit/backend/torch/graph/folding.py
@@ -165,10 +165,11 @@ def build_address_book_stacked_entry(
     in_module_ids = list(dict.fromkeys(idx[0] for fi in in_fold_idx for idx in fi))

     # Compute the cumulative indices of the folded inputs
+    module_fold_sizes = [num_folds[mid] for mid in in_module_ids]
     cum_module_ids = dict(
         zip(
             in_module_ids,
-            itertools.accumulate([0] + [num_folds[mid] for mid in in_module_ids]),
+            itertools.accumulate([0] + module_fold_sizes),
         )
     )

@@ -183,14 +184,17 @@ def build_address_book_stacked_entry(
         return AddressBookEntry(module, [in_module_ids], [cum_fold_idx_t])

     # If we are computing a non-output stacked address book entry,
-    # then check if the fold index would be equivalent to an 'unsqueeze' on dimension 0.
-    # If so, then replace the fold index with None as this would result in a more efficient unsqueeze
-    useless_fold_idx = False
-    if len(cum_fold_idx) == 1:
-        fold_size = sum(num_folds[mid] for mid in in_module_ids)
-        useless_fold_idx = cum_fold_idx[0] == list(range(fold_size))
-    cum_fold_idx_t = None if useless_fold_idx else torch.tensor(cum_fold_idx)
-
+    # then check if the fold index would be equivalent to an 'unsqueeze' on dimensions 0 or 1.
+    # If so, then replace the fold index with a more efficient unsqueezing operation
+    fold_size = sum(module_fold_sizes)
+    if [i for idx in cum_fold_idx for i in idx] == list(range(fold_size)):
+        if len(cum_fold_idx) == 1 and len(cum_fold_idx[0]) == fold_size:
+            # Equivalent to .unsqueeze(dim=0)
+            return AddressBookEntry(module, [in_module_ids], [(None,)])
+        elif len(cum_fold_idx) == fold_size and len(cum_fold_idx[0]) == 1:
+            # Equivalent to .unsqueeze(dim=1)
+            return AddressBookEntry(module, [in_module_ids], [(slice(None), None)])
+    cum_fold_idx_t = torch.tensor(cum_fold_idx)
     return AddressBookEntry(module, [in_module_ids], [cum_fold_idx_t])
@@ -213,18 +217,18 @@ def build_address_book_entry(
         dict(zip(mids, itertools.accumulate([0] + [num_folds[mid] for mid in mids])))
         for mids in in_module_ids
     ]
-    cum_fold_idx_t: list[Tensor | None] = []
+    cum_fold_idx_t: list[Tensor | tuple] = []
     for i, hi in enumerate(in_fold_idx):
         cum_fold_i_idx: list[int] = [cum_module_ids[i][idx[0]] + idx[1] for idx in hi]

         # The following checks whether using the fold index would yield the same tensor
         # If so, then avoid indexing at all
         module_id = hi[0][0]
-        if all(idx[0] == module_id for idx in hi):
-            fold_size = num_folds[module_id]
-            useless_fold_idx = cum_fold_i_idx == list(range(fold_size))
+        if all(idx[0] == module_id for idx in hi) and cum_fold_i_idx == list(
+            range(num_folds[module_id])
+        ):
+            cum_fold_i_idx_t = ()
         else:
-            useless_fold_idx = False
-        cum_fold_i_idx_t = None if useless_fold_idx else torch.tensor(cum_fold_i_idx)
+            cum_fold_i_idx_t = torch.tensor(cum_fold_i_idx)
         cum_fold_idx_t.append(cum_fold_i_idx_t)

     return AddressBookEntry(module, in_module_ids, cum_fold_idx_t)
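
The early returns above work because advanced indexing with (None,) or (slice(None), None) reproduces an unsqueeze, which is cheaper than gathering with an identity index. A quick plain-PyTorch sketch of the equivalences, independent of cirkit:

import torch

x = torch.randn(3, 4)

# The tuples stored by the entries above act as unsqueeze operations
assert torch.equal(x[(None,)], x.unsqueeze(dim=0))              # (1, 3, 4)
assert torch.equal(x[(slice(None), None)], x.unsqueeze(dim=1))  # (3, 1, 4)

# By contrast, an identity fold index materializes a full gather:
# same values, but a copy instead of a cheap view
idx = torch.arange(3)
assert torch.equal(x[idx], x)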
47 changes: 39 additions & 8 deletions cirkit/backend/torch/graph/modules.py
@@ -94,10 +94,11 @@ class AddressBookEntry:
     compute the output of the whole computational graph."""
     in_module_ids: list[list[int]]
     """For each input module, it stores the list of other module indices."""
-    in_fold_idx: list[Tensor | None]
+    in_fold_idx: list[Tensor | tuple[slice | None, ...]]
     """For each input module, it stores the fold index tensor used to gather the
-    input tensors to each fold. It is None whether there is no need of gathering the
-    input tensors, i.e., if the indexing operation would act as an identity function."""
+    input tensors to each fold. It is a tuple of optional slices whether there is no need of
+    gathering the input tensors, i.e., if the indexing operation would act as an unsqueezing
+    operation that can be much more efficient."""


 class AddressBook(nn.Module, ABC):
@@ -133,27 +134,57 @@ def __init__(self, entries: list[AddressBookEntry]) -> None:
                 "The last entry of the address book must have only one fold index tensor"
             )
         (out_fold_idx,) = last_entry.in_fold_idx
-        if len(out_fold_idx.shape) != 1:
+        if not isinstance(out_fold_idx, Tensor) or len(out_fold_idx.shape) != 1:
             raise ValueError("The output fold index tensor should be a 1-dimensional tensor")
         super().__init__()
-        self._entries = entries
         self._num_outputs = out_fold_idx.shape[0]
+        self._entry_modules: list[TorchModule | None] = [e.module for e in entries]
+        self._entry_in_module_ids: list[list[list[int]]] = [e.in_module_ids for e in entries]
+        # We register the book-keeping tensor indices as buffers.
+        # By doing so they are automatically transferred to the device
+        # This reduces CPU-device communications required to transfer these indices
+        #
+        # TODO: Perhaps this can be made more elegant in the future, if someone
+        # decides to introduce a nn.BufferList container in torch
+        self._entry_in_fold_idx_targets: list[list[str]] = []
+        for i, e in enumerate(entries):
+            self._entry_in_fold_idx_targets.append([])
+            for j, fi in enumerate(e.in_fold_idx):
+                in_fold_idx_target = f"_in_fold_idx_{i}_{j}"
+                if isinstance(fi, Tensor):
+                    self.register_buffer(in_fold_idx_target, fi)
+                else:
+                    setattr(self, in_fold_idx_target, fi)
+                self._entry_in_fold_idx_targets[-1].append(in_fold_idx_target)

     def __len__(self) -> int:
         """Retrieve the length of the address book.
         Returns:
             The number of address book entries.
         """
-        return len(self._entries)
+        return len(self._entry_modules)

     def __iter__(self) -> Iterator[AddressBookEntry]:
-        """Retrieve an iterator over address book entries.
+        """Retrieve an iterator over address book entries, i.e., a tuple consisting of
+        three objects: (i) the torch module to evaluate (it can be None if the entry
+        is needed to return the output of the computational graph); (ii) for each input
+        to the module (i.e., depending on the arity) we have the list of ids to the
+        outputs of other modules (it can be empty if the module is an input module); and
+        (iii) for each input to the module we have the fold indexing, which
+        is used to retrieve the inputs to a module, even if they are folded modules.
         Returns:
             An iterator over address book entries.
         """
-        return iter(self._entries)
+        for module, in_module_ids_hs, in_fold_idx_targets in zip(
+            self._entry_modules, self._entry_in_module_ids, self._entry_in_fold_idx_targets
+        ):
+            yield AddressBookEntry(
+                module,
+                in_module_ids_hs,
+                [getattr(self, target) for target in in_fold_idx_targets],
+            )

     @property
     def num_outputs(self) -> int:
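
Since torch offers nn.ParameterList but, as the TODO above notes, no nn.BufferList, each index tensor is registered under a generated attribute name and later recovered with getattr. A self-contained sketch of that workaround (hypothetical class, not part of cirkit):

import torch
from torch import nn

class BufferList(nn.Module):
    """Hypothetical container mimicking the register-by-name workaround."""

    def __init__(self, tensors: list[torch.Tensor]) -> None:
        super().__init__()
        self._names: list[str] = []
        for i, t in enumerate(tensors):
            name = f"_buffer_{i}"
            # Each tensor becomes a buffer, so .to(device) moves all of them
            self.register_buffer(name, t)
            self._names.append(name)

    def __getitem__(self, i: int) -> torch.Tensor:
        # getattr resolves to the buffer on the module's current device
        return getattr(self, self._names[i])

buffers = BufferList([torch.arange(3), torch.arange(5)])
print(buffers[1])  # tensor([0, 1, 2, 3, 4])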
12 changes: 6 additions & 6 deletions cirkit/backend/torch/layers/input.py
@@ -281,12 +281,12 @@ def forward(self, x: Tensor) -> Tensor:
         x = x.squeeze(dim=3)  # (F, C, B)
         weight = self.weight()
         if self.num_channels == 1:
-            idx_fold = torch.arange(self.num_folds)
+            idx_fold = torch.arange(self.num_folds, device=weight.device)
             x = weight[:, :, 0][idx_fold[:, None], :, x[:, 0]]
             x = self.semiring.map_from(x, SumProductSemiring)
         else:
-            idx_fold = torch.arange(self.num_folds)[:, None, None]
-            idx_channel = torch.arange(self.num_channels)[None, :, None]
+            idx_fold = torch.arange(self.num_folds, device=weight.device)[:, None, None]
+            idx_channel = torch.arange(self.num_channels, device=weight.device)[None, :, None]
             x = weight[idx_fold, :, idx_channel, x]
             x = self.semiring.map_from(x, SumProductSemiring)
             x = self.semiring.prod(x, dim=1)
@@ -434,11 +434,11 @@ def log_unnormalized_likelihood(self, x: Tensor) -> Tensor:
         # logits: (F, K, C, N)
         logits = torch.log(self.probs()) if self.logits is None else self.logits()
         if self.num_channels == 1:
-            idx_fold = torch.arange(self.num_folds)
+            idx_fold = torch.arange(self.num_folds, device=logits.device)
             x = logits[:, :, 0][idx_fold[:, None], :, x[:, 0]]
         else:
-            idx_fold = torch.arange(self.num_folds)[:, None, None]
-            idx_channel = torch.arange(self.num_channels)[None, :, None]
+            idx_fold = torch.arange(self.num_folds, device=logits.device)[:, None, None]
+            idx_channel = torch.arange(self.num_channels, device=logits.device)[None, :, None]
             x = torch.sum(logits[idx_fold, :, idx_channel, x], dim=1)
         return x
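
The device=... arguments matter because torch.arange allocates on the CPU by default, so using these indices against a GPU weight tensor would trigger an implicit transfer inside every forward call. A sketch with assumed shapes (F=6 folds, K=5 units, C=3 channels, N=256 categories, B=32 batch):

import torch

weight = torch.randn(6, 5, 3, 256)   # assumed (F, K, C, N) layout
x = torch.randint(256, (6, 3, 32))   # assumed (F, C, B) integer inputs

# Build the advanced indices on the same device as the indexed tensor
idx_fold = torch.arange(6, device=weight.device)[:, None, None]
idx_channel = torch.arange(3, device=weight.device)[None, :, None]

# Advanced indexing over the F, C and N dimensions; the broadcast index
# shape (F, C, B) moves to the front and the K slice trails
out = weight[idx_fold, :, idx_channel, x]
print(out.shape)  # torch.Size([6, 3, 32, 5])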
33 changes: 16 additions & 17 deletions cirkit/backend/torch/parameters/parameter.py
@@ -29,30 +29,29 @@ class ParameterAddressBook(AddressBook):
     def lookup(
         self, module_outputs: list[Tensor], *, in_graph: Tensor | None = None
     ) -> Iterator[tuple[TorchParameterNode | None, tuple]]:
-        # Loop through the entries and yield inputs
-        for entry in self._entries:
-            in_module_ids = entry.in_module_ids
+        def _select_index(mids: list[int], idx: Tensor | tuple[slice | None, ...]) -> Tensor:
+            # A useful function combining the modules outputs, and then possibly applying an index
+            if len(mids) == 1:
+                t = module_outputs[mids[0]]
+            else:
+                t = torch.cat([module_outputs[mid] for mid in mids], dim=0)
+            return t[idx]

+        # Loop through the entries and yield inputs
+        for entry in self:
+            node = entry.module
+            in_node_ids = entry.in_module_ids
+            in_fold_idx = entry.in_fold_idx
             # Catch the case there are some inputs coming from other modules
-            if in_module_ids:
+            if in_node_ids:
                 x = tuple(
-                    ParameterAddressBook._select_index(module_outputs, mids, in_idx)
-                    for mids, in_idx in zip(in_module_ids, entry.in_fold_idx)
+                    _select_index(mids, in_idx) for mids, in_idx in zip(in_node_ids, in_fold_idx)
                 )
-                yield entry.module, x
+                yield node, x
                 continue

             # Catch the case there are no inputs coming from other modules
-            yield entry.module, ()
-
-    @staticmethod
-    def _select_index(node_outputs: list[Tensor], mids: list[int], idx: Tensor | None) -> Tensor:
-        # A useful function combining the modules outputs, and then possibly applying an index
-        if len(mids) == 1:
-            t = node_outputs[mids[0]]
-        else:
-            t = torch.cat([node_outputs[mid] for mid in mids], dim=0)
-        return t if idx is None else t[idx]
+            yield node, ()

     @classmethod
     def from_index_info(cls, fold_idx_info: FoldIndexInfo) -> "ParameterAddressBook":
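
Note that _select_index now always applies t[idx]: the no-gather case needs no None check because indexing a tensor with an empty tuple returns it unchanged. A quick sanity sketch:

import torch

t = torch.randn(4, 3)

# An empty tuple index acts as the identity
assert torch.equal(t[()], t)

# A tensor index gathers folds; a (None,)-style tuple unsqueezes instead
assert torch.equal(t[torch.tensor([2, 0])], torch.stack([t[2], t[0]]))
assert t[(None,)].shape == (1, 4, 3)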
20 changes: 10 additions & 10 deletions notebooks/compilation-options.ipynb
@@ -203,7 +203,7 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "CPU times: user 4.78 s, sys: 1.14 s, total: 5.92 s\n",
+    "CPU times: user 4.74 s, sys: 1.13 s, total: 5.87 s\n",
     "Wall time: 5.76 s\n"
    ]
   }
@@ -273,7 +273,7 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "1.29 s ± 21.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+    "1.42 s ± 16.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
    ]
   }
  ],
@@ -338,8 +338,8 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "CPU times: user 4.68 s, sys: 915 ms, total: 5.6 s\n",
-    "Wall time: 5.54 s\n"
+    "CPU times: user 4.98 s, sys: 809 ms, total: 5.79 s\n",
+    "Wall time: 5.69 s\n"
    ]
   }
  ],
@@ -420,7 +420,7 @@
   "id": "f074e168-dee4-4234-8eae-afd28fae317f",
   "metadata": {},
   "source": [
-   "As we see in the next code snippet, enabling folding provided an (approximately) **19.9x speed-up** for feed-forward circuit evaluations."
+   "As we see in the next code snippet, enabling folding provided an (approximately) **28.9x speed-up** for feed-forward circuit evaluations."
   ]
  },
 {
@@ -433,7 +433,7 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "58.9 ms ± 28.8 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+    "49.1 ms ± 33 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
    ]
   }
  ],
@@ -527,8 +527,8 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "CPU times: user 4.55 s, sys: 758 ms, total: 5.31 s\n",
-    "Wall time: 5.25 s\n"
+    "CPU times: user 5.06 s, sys: 1.07 s, total: 6.12 s\n",
+    "Wall time: 6.02 s\n"
    ]
   }
  ],
@@ -591,7 +591,7 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "25.7 ms ± 27.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+    "25.4 ms ± 8.21 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
    ]
   }
  ],
@@ -608,7 +608,7 @@
   "id": "11d95c02-2c66-4414-b676-0dec303f2aa9",
   "metadata": {},
   "source": [
-   "Note that, we achieved an (approximately) **2.3x speed-up**, when compared to the folded circuit compiled above, and an (approximately) **45.7x speed-up**, when compared to the circuit compiled with no folding and no optimizations."
+   "Note that, we achieved an (approximately) **1.9x speed-up**, when compared to the folded circuit compiled above, and an (approximately) **55.9x speed-up**, when compared to the circuit compiled with no folding and no optimizations."
   ]
  },
4 changes: 2 additions & 2 deletions notebooks/compression-cp-factorization.ipynb

Large diffs are not rendered by default.

116 changes: 57 additions & 59 deletions notebooks/generative-vs-discriminative-circuit.ipynb

Large diffs are not rendered by default.

26 changes: 13 additions & 13 deletions notebooks/learning-a-circuit-with-pic.ipynb
@@ -327,17 +327,17 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "Step 200: Average NLL: 800.724\n",
-    "Step 400: Average NLL: 704.707\n",
-    "Step 600: Average NLL: 684.442\n",
-    "Step 800: Average NLL: 675.654\n",
-    "Step 1000: Average NLL: 666.133\n",
-    "Step 1200: Average NLL: 656.394\n",
-    "Step 1400: Average NLL: 653.244\n",
-    "Step 1600: Average NLL: 651.271\n",
-    "Step 1800: Average NLL: 650.399\n",
-    "Step 2000: Average NLL: 648.530\n",
-    "Step 2200: Average NLL: 647.928\n"
+    "Step 200: Average NLL: 798.530\n",
+    "Step 400: Average NLL: 702.238\n",
+    "Step 600: Average NLL: 685.940\n",
+    "Step 800: Average NLL: 679.610\n",
+    "Step 1000: Average NLL: 673.089\n",
+    "Step 1200: Average NLL: 661.166\n",
+    "Step 1400: Average NLL: 656.975\n",
+    "Step 1600: Average NLL: 654.494\n",
+    "Step 1800: Average NLL: 653.448\n",
+    "Step 2000: Average NLL: 651.315\n",
+    "Step 2200: Average NLL: 650.697\n"
    ]
   }
  ],
@@ -393,8 +393,8 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "Average test LL: -642.912\n",
-    "Bits per dimension: 1.183\n"
+    "Average test LL: -645.790\n",
+    "Bits per dimension: 1.188\n"
    ]
   }
  ],
4 changes: 2 additions & 2 deletions notebooks/learning-a-circuit.ipynb
@@ -176,8 +176,8 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "CPU times: user 2.65 s, sys: 322 ms, total: 2.97 s\n",
-    "Wall time: 2.89 s\n"
+    "CPU times: user 3.54 s, sys: 292 ms, total: 3.84 s\n",
+    "Wall time: 3.75 s\n"
    ]
   }
  ],