Skip to content

Commit

Permalink
Fix our x[0] v x[:, 0] checks in the case of scalar yhat. Closes #103.
Browse files Browse the repository at this point in the history
  • Loading branch information
jmmcd committed Oct 17, 2021
1 parent e1b9197 commit 10c8a18
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
10 changes: 6 additions & 4 deletions src/fitness/supervised_learning/supervised_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ def evaluate(self, ind, **kwargs):
# phen will refer to x (ie test_in), and possibly to c
yhat = eval(phen)
assert np.isrealobj(yhat)
if y.shape != yhat.shape:
raise ValueError(shape_mismatch_txt)
if not np.isscalar(yhat):
if y.shape != yhat.shape:
raise ValueError(shape_mismatch_txt)

# let's always call the error function with the
# true values first, the estimate second
Expand All @@ -100,8 +101,9 @@ def evaluate(self, ind, **kwargs):
# phenotype won't refer to C
yhat = eval(ind.phenotype)
assert np.isrealobj(yhat)
if y.shape != yhat.shape:
raise ValueError(shape_mismatch_txt)
if not np.isscalar(yhat):
if y.shape != yhat.shape:
raise ValueError(shape_mismatch_txt)

# let's always call the error function with the true
# values first, the estimate second
Expand Down
5 changes: 3 additions & 2 deletions src/utilities/fitness/optimize_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ def optimize_constants(x, y, ind):
# ind doesn't refer to c: no need to optimize
c = []
yhat = f(x, c)
if y.shape != yhat.shape:
raise ValueError(shape_mismatch_txt)
if not np.isscalar(yhat):
if y.shape != yhat.shape:
raise ValueError(shape_mismatch_txt)
fitness = loss(y, yhat)
ind.opt_consts = c
return fitness
Expand Down

0 comments on commit 10c8a18

Please sign in to comment.