diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index a900ef1d1..ef5fae7e0 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -112,6 +112,10 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str, self.configuration_manager = configuration_manager self.list_of_parameters = parameters self.network = network + + # initialize network with first set of parameters, also see https://github.com/MIC-DKFZ/nnUNet/issues/2520 + network.load_state_dict(parameters[0]) + self.dataset_json = dataset_json self.trainer_name = trainer_name self.allowed_mirroring_axes = inference_allowed_mirroring_axes