diff --git a/nnunetv2/utilities/plans_handling/plans_handler.py b/nnunetv2/utilities/plans_handling/plans_handler.py index 11b76dfcb..eab65885a 100644 --- a/nnunetv2/utilities/plans_handling/plans_handler.py +++ b/nnunetv2/utilities/plans_handling/plans_handler.py @@ -54,6 +54,8 @@ def __init__(self, configuration_dict: dict): conv_op = convert_dim_to_conv_op(dim) instnorm = get_matching_instancenorm(dimension=dim) + convs_or_blocks = "n_conv_per_stage" if unet_class_name == "PlainConvUNet" else "n_blocks_per_stage" + arch_dict = { 'network_class_name': network_class_name, 'arch_kwargs': { @@ -64,7 +66,7 @@ def __init__(self, configuration_dict: dict): "conv_op": conv_op.__module__ + '.' + conv_op.__name__, "kernel_sizes": deepcopy(self.configuration["conv_kernel_sizes"]), "strides": deepcopy(self.configuration["pool_op_kernel_sizes"]), - "n_conv_per_stage": deepcopy(self.configuration["n_conv_per_stage_encoder"]), + convs_or_blocks: deepcopy(self.configuration["n_conv_per_stage_encoder"]), "n_conv_per_stage_decoder": deepcopy(self.configuration["n_conv_per_stage_decoder"]), "conv_bias": True, "norm_op": instnorm.__module__ + '.' + instnorm.__name__,