Support for conditional sampling #344

Open · wants to merge 1 commit into main
21 changes: 16 additions & 5 deletions cirkit/backend/torch/layers/inner.py
@@ -60,14 +60,15 @@ def forward(self, x: Tensor) -> Tensor:
is the number of output units.
"""

def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
def sample(self, x: Tensor, conditional: bool = False) -> tuple[Tensor, Tensor | None]:
"""Perform a forward sampling step.

Args:
x: A tensor representing the input variable assignments, having shape
$(F, H, C, K, N, D)$, where $F$ is the number of folds, $H$ is the arity,
$C$ is the number of channels, $K$ is the number of input units, $N$ is the number
of samples, $D$ is the number of variables.
conditional: Whether the sample is drawn conditionally, i.e., conditioned on the input specified in the sampling query.

Returns:
Tensor: A new tensor representing the new variable assignments the layer gives
@@ -123,7 +124,7 @@ def config(self) -> Mapping[str, Any]:
def forward(self, x: Tensor) -> Tensor:
return self.semiring.prod(x, dim=1, keepdim=False) # shape (F, H, B, K) -> (F, B, K).

def sample(self, x: Tensor) -> tuple[Tensor, None]:
def sample(self, x: Tensor, conditional: bool = False) -> tuple[Tensor, None]:
# Concatenate samples over disjoint variables through a sum
# x: (F, H, C, K, num_samples, D)
x = torch.sum(x, dim=1) # (F, C, K, num_samples, D)
@@ -182,7 +183,7 @@ def forward(self, x: Tensor) -> Tensor:
# shape (F, B, Ki, Ki) -> (F, B, Ko=Ki**2).
return self.semiring.mul(x0, x1).flatten(start_dim=-2)

def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
def sample(self, x: Tensor, conditional: bool = False) -> tuple[Tensor, Tensor | None]:
# x: (F, H, C, K, num_samples, D)
x0 = x[:, 0].unsqueeze(dim=3) # (F, C, Ki, 1, num_samples, D)
x1 = x[:, 1].unsqueeze(dim=2) # (F, C, 1, Ki, num_samples, D)
@@ -234,6 +235,7 @@ def __init__(
f"and shape {self._weight_shape} for 'weight', found"
f"{weight.num_folds} and {weight.shape}, respectively"
)
self.input = None  # cache of the most recent forward input, used for conditional sampling
self.weight = weight

def _valid_weight_shape(self, w: TorchParameter) -> bool:
@@ -261,12 +263,13 @@ def forward(self, x: Tensor) -> Tensor:
# x: (F, H, B, Ki) -> (F, B, H * Ki)
# weight: (F, Ko, H * Ki)
x = x.permute(0, 2, 1, 3).flatten(start_dim=2)
self.input = x  # cache the flattened layer input for conditional sampling
weight = self.weight()
return self.semiring.einsum(
"fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
) # shape (F, B, K_o).

def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
def sample(self, x: Tensor, conditional: bool = False) -> tuple[Tensor, Tensor]:
weight = self.weight()
negative = torch.any(weight < 0.0)
normalized = torch.allclose(torch.sum(weight, dim=-1), torch.ones(1, device=weight.device))
@@ -280,7 +283,15 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
d = x.shape[4]

# mixing_distribution: (F, Ko, H * Ki)
mixing_distribution = torch.distributions.Categorical(probs=weight)
if conditional:
prior = torch.log(torch.clamp(weight, min=1e-10))
posterior = prior + self.input
normalized_posterior = torch.exp(
posterior - torch.logsumexp(posterior, 2, keepdim=True)
)
mixing_distribution = torch.distributions.Categorical(probs=normalized_posterior)
else:
mixing_distribution = torch.distributions.Categorical(probs=weight)

# mixing_samples: (num_samples, F, Ko) -> (F, Ko, num_samples)
mixing_samples = mixing_distribution.sample((num_samples,))
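The `conditional` branch above implements a Bayes-rule update of each sum unit's mixture weights: the prior weights $w_k$ are combined with the child log-likelihoods of the evidence cached in `self.input` during the forward pass, and renormalised in log-space via `logsumexp` for numerical stability, giving $p(k \mid e) \propto w_k \, p_k(e)$. Below is a minimal standalone sketch of that step; `conditional_mixing_weights` is a hypothetical helper written for illustration and is not part of this PR.

```python
import torch

def conditional_mixing_weights(weight: torch.Tensor, child_ll: torch.Tensor) -> torch.Tensor:
    # weight: prior mixing weights, shape (F, Ko, H * Ki), each row summing to one
    # child_ll: cached child log-likelihoods of the evidence, shape (F, 1, H * Ki)
    prior = torch.log(torch.clamp(weight, min=1e-10))  # log w_k, clamped to avoid log(0)
    posterior = prior + child_ll  # log w_k + log p_k(e), broadcast over output units
    # Normalise over the children (dim 2) in log-space, then exponentiate
    return torch.exp(posterior - torch.logsumexp(posterior, dim=2, keepdim=True))

# Example: one fold, two sum units mixing over three children
w = torch.softmax(torch.randn(1, 2, 3), dim=-1)
ll = torch.randn(1, 1, 3)
probs = conditional_mixing_weights(w, ll)
assert torch.allclose(probs.sum(dim=-1), torch.ones(1))
```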
2 changes: 1 addition & 1 deletion cirkit/backend/torch/layers/optimized.py
@@ -162,7 +162,7 @@ def forward(self, x: Tensor) -> Tensor:
"fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
)

def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
def sample(self, x: Tensor, conditional: bool = False) -> tuple[Tensor, Tensor]:
weight = self.weight()
negative = torch.any(weight < 0.0)
if negative:
52 changes: 46 additions & 6 deletions cirkit/backend/torch/queries.py
@@ -193,8 +193,7 @@ class SamplingQuery(Query):
def __init__(self, circuit: TorchCircuit) -> None:
"""Initialize a sampling query object. Currently, only sampling from the joint distribution
is supported, i.e., sampling won't work in the case of circuits obtained by
marginalization, or by observing evidence. Conditional sampling is currently not
implemented.
marginalization, or by observing evidence.

Args:
circuit: The circuit to sample from.
@@ -211,11 +210,22 @@ def __init__(self, circuit: TorchCircuit) -> None:
super().__init__()
self._circuit = circuit

def __call__(self, num_samples: int = 1) -> tuple[Tensor, list[Tensor]]:
def __call__(
self, num_samples: int = 1, x: Tensor | None = None, integrate_vars: Scope | None = None
) -> tuple[Tensor, list[Tensor]]:
"""Sample a number of data points.

Args:
num_samples: The number of samples to return.
x: An input batch of shape $(B, C, D)$, where $B$ is the batch size, $C$ is the number
of channels per variable, and $D$ is the number of variables.
integrate_vars: The variables to integrate. It must be a subset of the variables on
which the circuit given in the constructor is defined on.
The format can be one of the following:
1. Tensor of shape (B, D) where B is the batch size and D is the number of
variables in the scope of the circuit. Its dtype should be torch.bool
and have True in the positions of random variables that should be
marginalised out and False elsewhere.

Return:
A pair (samples, mixture_samples), consisting of (i) an assignment to the observed
@@ -224,10 +234,27 @@ def __call__(self, num_samples: int = 1) -> tuple[Tensor, list[Tensor]]:
tensor of shape (num_samples, num_channels, num_variables).

Raises:
ValueError: if the number of samples is not a positive number.
ValueError: if the number of samples is not a positive number, or if only one of x and integrate_vars is specified.
"""
if num_samples <= 0:
raise ValueError("The number of samples must be a positive number")
if (integrate_vars is None) ^ (x is None):
raise ValueError(
"For conditional sampling, both the input to condition on (x) and the "
"scope to integrate out (integrate_vars) must be specified"
)

conditional_sampling = (x is not None) and (integrate_vars is not None)

if conditional_sampling:
# The circuit could be in eval mode; temporarily switch it to train mode,
# run the integration query so that each sum layer caches its input
# (the evidence to condition on), then restore the previous mode.
is_training = self._circuit.training
self._circuit.train()
int_query = IntegrateQuery(self._circuit)
int_query(x, integrate_vars=integrate_vars)
self._circuit.train(is_training)

mixture_samples: list[Tensor] = []
# samples: (O, C, K, num_samples, D)
samples = self._circuit.evaluate(
functools.partial(
self._layer_fn,
num_samples=num_samples,
mixture_samples=mixture_samples,
conditional=conditional_sampling,
),
)
# samples: (num_samples, O, K, C, D)
samples = samples.permute(3, 0, 2, 1, 4)
# TODO: fix for the case of multi-output circuits, i.e., O != 1 or K != 1
samples = samples[:, 0, 0] # (num_samples, C, D)
# if integration scopes are given, combine the conditioned input with drawn samples
if conditional_sampling:
x = x.clone()  # avoid mutating the caller's evidence tensor in place
marginalized_scope_ids = [i for i in range(x.shape[2]) if i in integrate_vars]
non_marginalized_scope_ids = [i for i in range(x.shape[2]) if i not in integrate_vars]
x[..., marginalized_scope_ids] = 0.0
samples[..., non_marginalized_scope_ids] = 0.0
samples = samples + x
return samples, mixture_samples
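
For reference, a hypothetical end-to-end use of the new conditional mode (the circuit construction is elided; `circuit`, the variable indices, and the observed values are illustrative assumptions, and the `Scope` import path is assumed from the library's current layout):

```python
import torch
from cirkit.backend.torch.queries import SamplingQuery
from cirkit.utils.scope import Scope

# Assume `circuit` is a compiled TorchCircuit over D = 4 variables with C = 1 channel.
x = torch.zeros(1, 1, 4)  # (B, C, D) evidence batch
x[..., 0] = 3.0  # observed value of variable 0
x[..., 1] = 7.0  # observed value of variable 1

query = SamplingQuery(circuit)
# Draw 16 samples of variables 2 and 3 conditioned on variables 0 and 1;
# integrate_vars lists the variables to be sampled, i.e., those marginalised
# out when scoring the evidence.
samples, mixture_samples = query(num_samples=16, x=x, integrate_vars=Scope([2, 3]))
# samples: (16, 1, 4), with variables 0 and 1 fixed to their observed values
```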

def _layer_fn(
self, layer: TorchLayer, *inputs: Tensor, num_samples: int, mixture_samples: list[Tensor]
self,
layer: TorchLayer,
*inputs: Tensor,
num_samples: int,
mixture_samples: list[Tensor],
conditional: bool = False,
) -> Tensor:
# Sample from an input layer
if not inputs:
@@ -257,7 +297,7 @@ def _layer_fn(

# Sample through an inner layer
assert isinstance(layer, TorchInnerLayer)
samples, mix_samples = layer.sample(*inputs)
samples, mix_samples = layer.sample(*inputs, conditional=conditional)
if mix_samples is not None:
mixture_samples.append(mix_samples)
return samples