Skip to content

Commit

Permalink
apply ruff again
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoGrin committed Jan 16, 2025
1 parent a419bb0 commit 83a6360
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 88 deletions.
48 changes: 24 additions & 24 deletions src/tabpfn/model/bar_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ def cdf(self, logits: torch.Tensor, ys: torch.Tensor) -> torch.Tensor:
# bring new borders to the same dim as logits up to the last dim
ys = ys.repeat(logits.shape[:-1] + (1,))
else:
assert ys.shape[:-1] == logits.shape[:-1], (
f"ys.shape: {ys.shape} logits.shape: {logits.shape}"
)
assert (
ys.shape[:-1] == logits.shape[:-1]
), f"ys.shape: {ys.shape} logits.shape: {logits.shape}"
probs = torch.softmax(logits, dim=-1)
buckets_of_ys = self.map_to_bucket_idx(ys).clamp(0, self.num_bars - 1)

Expand Down Expand Up @@ -190,9 +190,9 @@ def forward(
ignore_loss_mask = self.ignore_init(y)
target_sample = self.map_to_bucket_idx(y)
assert (target_sample >= 0).all()
assert (target_sample < self.num_bars).all(), (
f"y {y} not in support set for borders (min_y, max_y) {self.borders}"
)
assert (
target_sample < self.num_bars
).all(), f"y {y} not in support set for borders (min_y, max_y) {self.borders}"

last_dim = logits.shape[-1]
assert last_dim == self.num_bars, f"{last_dim} v {self.num_bars}"
Expand Down Expand Up @@ -462,9 +462,9 @@ def __init__(

def assert_support(self, *, allow_zero_bucket_left: bool = False) -> None:
if allow_zero_bucket_left:
assert self.bucket_widths[-1] > 0, (
f"Half Normal weight must be > 0 (got -1:{self.bucket_widths[-1]})."
)
assert (
self.bucket_widths[-1] > 0
), f"Half Normal weight must be > 0 (got -1:{self.bucket_widths[-1]})."
# This fixes the distribution if the half normal at zero is width zero
if self.bucket_widths[0] == 0:
self.borders[0] = self.borders[0] - 1
Expand Down Expand Up @@ -505,13 +505,13 @@ def forward(
target_sample = self.map_to_bucket_idx(y) # shape: T x B (same as y)
target_sample.clamp_(0, self.num_bars - 1)

assert logits.shape[-1] == self.num_bars, (
f"{logits.shape[-1]} vs {self.num_bars}"
)
assert (
logits.shape[-1] == self.num_bars
), f"{logits.shape[-1]} vs {self.num_bars}"
assert (target_sample >= 0).all()
assert (target_sample < self.num_bars).all(), (
f"y {y} not in support set for borders (min_y, max_y) {self.borders}"
)
assert (
target_sample < self.num_bars
).all(), f"y {y} not in support set for borders (min_y, max_y) {self.borders}"
last_dim = logits.shape[-1]
assert last_dim == self.num_bars, f"{last_dim} vs {self.num_bars}"
# ignore all position with nan values
Expand Down Expand Up @@ -544,9 +544,9 @@ def forward(
nll_loss = -log_probs

if mean_prediction_logits is not None: # TO BE REMOVED AFTER BO PAPER IS DONE
assert not ignore_loss_mask.any(), (
"Ignoring examples is not implemented with mean pred."
)
assert (
not ignore_loss_mask.any()
), "Ignoring examples is not implemented with mean pred."
if not torch.is_grad_enabled():
pass
nll_loss = torch.cat(
Expand Down Expand Up @@ -783,16 +783,16 @@ def get_bucket_limits(
If set, the bucket limits are widened by this factor.
This allows to have a slightly larger range than the actual data.
"""
assert (ys is None) != (full_range is None), (
"Either full_range or ys must be passed."
)
assert (ys is None) != (
full_range is None
), "Either full_range or ys must be passed."

if ys is not None:
ys = ys.flatten()
ys = ys[~torch.isnan(ys)]
assert len(ys) > num_outputs, (
f"Number of ys :{len(ys)} must be larger than num_outputs: {num_outputs}"
)
assert (
len(ys) > num_outputs
), f"Number of ys :{len(ys)} must be larger than num_outputs: {num_outputs}"
if len(ys) % num_outputs:
ys = ys[: -(len(ys) % num_outputs)]
ys_per_bucket = len(ys) // num_outputs
Expand Down
18 changes: 9 additions & 9 deletions src/tabpfn/model/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def normalize_data(
"""
# TODO(eddiebergman): I feel like this function is easier to just do what you need
# where you need it, rather than supporting all these variations
assert (mean is None) == (std is None), (
"Either both or none of mean and std must be given"
)
assert (mean is None) == (
std is None
), "Either both or none of mean and std must be given"
if mean is None:
if normalize_positions is not None and normalize_positions > 0:
mean = torch_nanmean(data[:normalize_positions], axis=0) # type: ignore
Expand Down Expand Up @@ -759,9 +759,9 @@ def _transform(
x = to_ranking_low_mem(x)

if self.remove_outliers:
assert self.remove_outliers_sigma > 1.0, (
"remove_outliers_sigma must be > 1.0"
)
assert (
self.remove_outliers_sigma > 1.0
), "remove_outliers_sigma must be > 1.0"

x, _ = remove_outliers(
x,
Expand Down Expand Up @@ -965,9 +965,9 @@ def flatten_targets(y: torch.Tensor, unique_ys: torch.Tensor | None = None):

def _transform(self, y: torch.Tensor, single_eval_pos: int | None = None):
assert len(y.shape) == 3 and (y.shape[-1] == 1), "y must be of shape (T, B, 1)"
assert not (y.isnan().any() and self.training), (
"NaNs are not allowed in the target at this point during training (set to model.eval() if not in training)"
)
assert not (
y.isnan().any() and self.training
), "NaNs are not allowed in the target at this point during training (set to model.eval() if not in training)"
y_new = y.clone()
for B in range(y.shape[1]):
y_new[:, B, :] = self.flatten_targets(y[:, B, :], self.unique_ys_[B])
Expand Down
12 changes: 6 additions & 6 deletions src/tabpfn/model/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,9 @@ def forward( # noqa: C901
Returns:
The transformer state passed through the encoder layer.
"""
assert len(state.shape) == 4, (
"src must be of shape (batch_size, num_items, num feature blocks, d_model)"
)
assert (
len(state.shape) == 4
), "src must be of shape (batch_size, num_items, num feature blocks, d_model)"
if single_eval_pos is None:
single_eval_pos = 0

Expand All @@ -313,9 +313,9 @@ def forward( # noqa: C901
save_peak_mem_factor = None

if att_src is not None:
assert not self.multiquery_item_attention_for_test_set, (
"Not implemented yet."
)
assert (
not self.multiquery_item_attention_for_test_set
), "Not implemented yet."
assert not cache_trainset_representation, "Not implemented yet."
assert not single_eval_pos, (
"single_eval_pos should not be set, as the train representation"
Expand Down
3 changes: 2 additions & 1 deletion src/tabpfn/model/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def method_(
) -> torch.Tensor:
assert isinstance(self, torch.nn.Module)
assert save_peak_mem_factor is None or allow_inplace, (
"The parameter save_peak_mem_factor only supported with 'allow_inplace' set."
"The parameter 'save_peak_mem_factor' is only supported with "
"'allow_inplace' set."
)
assert isinstance(x, torch.Tensor)

Expand Down
24 changes: 12 additions & 12 deletions src/tabpfn/model/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,17 +299,17 @@ def forward(
Else, keys and values are attained by applying the respective linear
transformations to 'x' (self attention).
"""
assert not (cache_kv and use_cached_kv), (
"Cannot cache and use cached keys and values at the same time."
)
assert not (
cache_kv and use_cached_kv
), "Cannot cache and use cached keys and values at the same time."
if use_second_set_of_queries:
assert self.two_sets_of_queries, (
"Two sets of queries are not supported."
"Please set 'two_sets_of_queries' to True."
)
assert not x.requires_grad or (not self.has_cached_kv and not cache_kv), (
"Saving keys and values is only supported during inference."
)
assert not x.requires_grad or (
not self.has_cached_kv and not cache_kv
), "Saving keys and values is only supported during inference."
x, x_kv, x_shape_after_transpose = self._rearrange_inputs_to_flat_batch(x, x_kv)

nhead_kv = 1 if reuse_first_head_kv else self._nhead_kv
Expand Down Expand Up @@ -387,9 +387,9 @@ def compute_qkv( # noqa: PLR0912, C901
torch.Tensor | None,
torch.Tensor | None,
]:
assert not (cache_kv and use_cached_kv), (
"You cannot both cache new KV and use the cached KV at once."
)
assert not (
cache_kv and use_cached_kv
), "You cannot both cache new KV and use the cached KV at once."
if reuse_first_head_kv:
assert x is not x_kv, (
"x and x_kv must be different tensors. That means reuse_first_head_kv"
Expand All @@ -400,9 +400,9 @@ def compute_qkv( # noqa: PLR0912, C901

k = v = kv = None
if use_cached_kv:
assert self.has_cached_kv, (
"You try to use cached keys and values but the cache is empty."
)
assert (
self.has_cached_kv
), "You try to use cached keys and values but the cache is empty."
k = k_cache
v = v_cache
kv = kv_cache
Expand Down
18 changes: 9 additions & 9 deletions src/tabpfn/model/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,9 @@ def fit(self, X: np.ndarray, categorical_features: list[int]) -> Self:
X: 2d array of shape (n_samples, n_features)
categorical_features: list of indices of categorical feature.
"""
assert len(self) > 0, (
"The SequentialFeatureTransformer must have at least one step."
)
assert (
len(self) > 0
), "The SequentialFeatureTransformer must have at least one step."
self.fit_transform(X, categorical_features)
return self

Expand All @@ -422,9 +422,9 @@ def transform(self, X: np.ndarray) -> _TransformResult:
Args:
X: 2d array of shape (n_samples, n_features).
"""
assert len(self) > 0, (
"The SequentialFeatureTransformer must have at least one step."
)
assert (
len(self) > 0
), "The SequentialFeatureTransformer must have at least one step."
assert self.categorical_features_ is not None, (
"The SequentialFeatureTransformer must be fit before it"
" can be used to transform."
Expand Down Expand Up @@ -565,9 +565,9 @@ def _fit(self, X: np.ndarray, categorical_features: list[int]) -> list[int]:
@override
def _transform(self, X: np.ndarray, *, is_test: bool = False) -> np.ndarray:
assert self.index_permutation_ is not None, "You must call fit first"
assert len(self.index_permutation_) == X.shape[1], (
"The number of features must not change after fit"
)
assert (
len(self.index_permutation_) == X.shape[1]
), "The number of features must not change after fit"
return X[:, self.index_permutation_]


Expand Down
18 changes: 9 additions & 9 deletions src/tabpfn/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def forward(
**kwargs: Any,
) -> torch.Tensor:
if half_layers:
assert self.min_num_layers_layer_dropout == self.num_layers, (
"half_layers only works without layer dropout"
)
assert (
self.min_num_layers_layer_dropout == self.num_layers
), "half_layers only works without layer dropout"
n_layers = self.num_layers // 2
else:
n_layers = torch.randint(
Expand Down Expand Up @@ -688,9 +688,9 @@ def add_embeddings( # noqa: C901, PLR0912
use_cached_embeddings: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if use_cached_embeddings and self.cached_embeddings is not None:
assert data_dags is None, (
"Caching embeddings is not supported with data_dags at this point."
)
assert (
data_dags is None
), "Caching embeddings is not supported with data_dags at this point."
x += self.cached_embeddings[None, None]
return x, y

Expand Down Expand Up @@ -748,9 +748,9 @@ def add_embeddings( # noqa: C901, PLR0912

self.cached_embeddings = None
if cache_embeddings and embs is not None:
assert data_dags is None, (
"Caching embeddings is not supported with data_dags at this point."
)
assert (
data_dags is None
), "Caching embeddings is not supported with data_dags at this point."
self.cached_embeddings = embs

# TODO(old) should this go into encoder?
Expand Down
6 changes: 3 additions & 3 deletions src/tabpfn/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,9 +611,9 @@ def predict( # noqa: C901, PLR0912
if quantiles is None:
quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
else:
assert all((0 <= q <= 1) and (isinstance(q, float)) for q in quantiles), (
"All quantiles must be between 0 and 1 and floats."
)
assert all(
(0 <= q <= 1) and (isinstance(q, float)) for q in quantiles
), "All quantiles must be between 0 and 1 and floats."
if output_type not in self._USABLE_OUTPUT_TYPES:
raise ValueError(f"Invalid output type: {output_type}")

Expand Down
16 changes: 10 additions & 6 deletions tests/test_classifier_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,11 @@ def test_balanced_probabilities(X_y: tuple[np.ndarray, np.ndarray]) -> None:
# Check that the mean probability for each class is roughly equal
mean_probs = probabilities.mean(axis=0)
expected_mean = 1.0 / len(np.unique(y))
assert np.allclose(mean_probs, expected_mean, rtol=0.1), (
"Class probabilities are not properly balanced"
)
assert np.allclose(
mean_probs,
expected_mean,
rtol=0.1,
), "Class probabilities are not properly balanced"


def test_classifier_in_pipeline(X_y: tuple[np.ndarray, np.ndarray]) -> None:
Expand Down Expand Up @@ -162,6 +164,8 @@ def test_classifier_in_pipeline(X_y: tuple[np.ndarray, np.ndarray]) -> None:
# Check that the mean probability for each class is roughly equal
mean_probs = probabilities.mean(axis=0)
expected_mean = 1.0 / len(np.unique(y))
assert np.allclose(mean_probs, expected_mean, rtol=0.1), (
"Class probabilities are not properly balanced in pipeline"
)
assert np.allclose(
mean_probs,
expected_mean,
rtol=0.1,
), "Class probabilities are not properly balanced in pipeline"
18 changes: 9 additions & 9 deletions tests/test_regressor_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,18 +140,18 @@ def test_regressor_in_pipeline(X_y: tuple[np.ndarray, np.ndarray]) -> None:

# Test different prediction modes through the pipeline
predictions_median = pipeline.predict(X, output_type="median")
assert predictions_median.shape == (X.shape[0],), (
"Median predictions shape is incorrect"
)
assert predictions_median.shape == (
X.shape[0],
), "Median predictions shape is incorrect"

predictions_mode = pipeline.predict(X, output_type="mode")
assert predictions_mode.shape == (X.shape[0],), (
"Mode predictions shape is incorrect"
)
assert predictions_mode.shape == (
X.shape[0],
), "Mode predictions shape is incorrect"

quantiles = pipeline.predict(X, output_type="quantiles", quantiles=[0.1, 0.9])
assert isinstance(quantiles, list)
assert len(quantiles) == 2
assert quantiles[0].shape == (X.shape[0],), (
"Quantile predictions shape is incorrect"
)
assert quantiles[0].shape == (
X.shape[0],
), "Quantile predictions shape is incorrect"

0 comments on commit 83a6360

Please sign in to comment.