From 10c8a18a924eb5aca0d4ca3c6c5670973b93c70d Mon Sep 17 00:00:00 2001 From: James McDermott Date: Sun, 17 Oct 2021 22:22:22 +0100 Subject: [PATCH] Fix our x[0] v x[:, 0] checks in the case of scalar yhat. Closes #103. --- src/fitness/supervised_learning/supervised_learning.py | 10 ++++++---- src/utilities/fitness/optimize_constants.py | 5 +++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/fitness/supervised_learning/supervised_learning.py b/src/fitness/supervised_learning/supervised_learning.py index 15a5eed9..81e54dc3 100644 --- a/src/fitness/supervised_learning/supervised_learning.py +++ b/src/fitness/supervised_learning/supervised_learning.py @@ -89,8 +89,9 @@ def evaluate(self, ind, **kwargs): # phen will refer to x (ie test_in), and possibly to c yhat = eval(phen) assert np.isrealobj(yhat) - if y.shape != yhat.shape: - raise ValueError(shape_mismatch_txt) + if not np.isscalar(yhat): + if y.shape != yhat.shape: + raise ValueError(shape_mismatch_txt) # let's always call the error function with the # true values first, the estimate second @@ -100,8 +101,9 @@ def evaluate(self, ind, **kwargs): # phenotype won't refer to C yhat = eval(ind.phenotype) assert np.isrealobj(yhat) - if y.shape != yhat.shape: - raise ValueError(shape_mismatch_txt) + if not np.isscalar(yhat): + if y.shape != yhat.shape: + raise ValueError(shape_mismatch_txt) # let's always call the error function with the true # values first, the estimate second diff --git a/src/utilities/fitness/optimize_constants.py b/src/utilities/fitness/optimize_constants.py index 65f64432..051461cc 100644 --- a/src/utilities/fitness/optimize_constants.py +++ b/src/utilities/fitness/optimize_constants.py @@ -40,8 +40,9 @@ def optimize_constants(x, y, ind): # ind doesn't refer to c: no need to optimize c = [] yhat = f(x, c) - if y.shape != yhat.shape: - raise ValueError(shape_mismatch_txt) + if not np.isscalar(yhat): + if y.shape != yhat.shape: + raise ValueError(shape_mismatch_txt) fitness = loss(y, yhat) ind.opt_consts = c return fitness