Skip to content

Commit

Permalink
fix predict
Browse files Browse the repository at this point in the history
  • Loading branch information
rafa9811 committed Jun 21, 2023
1 parent 679875d commit ed03263
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions skfda/ml/regression/_linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ..._utils import nquad_vec
from ..._utils._sklearn_adapter import BaseEstimator, RegressorMixin
from ...misc._math import inner_product
from ...misc._math import inner_product_matrix
from ...misc.lstsq import solve_regularized_weighted_lstsq
from ...misc.regularization import L2Regularization, compute_penalty_matrix
from ...representation import FData, FDataBasis
Expand Down Expand Up @@ -446,22 +446,25 @@ def predict( # noqa: D102
X = self._argcheck_X(X)
result = []

if self.functional_response:
""" if self.functional_response:
# Predict each covariate (multivariate or functional)
if len(X) > 1:
return list(
itertools.chain.from_iterable(
[self.predict(x) for x in X],
),
)
X = X[0]
X = X[0] """

for coef, x, coef_info in zip(self.coef_, X, self._coef_info):
if self.functional_response:
def prediction(arg, x_eval=x, coef_eval=coef): # noqa: WPS430
if isinstance(x_eval, Callable):
x_eval = x_eval(arg) # noqa: WPS220

#TODO: MIRAR ESTO BIEN
x_eval = x_eval.reshape((-1, 1, 1))

return coef_eval(arg) * x_eval

result.append(
Expand All @@ -470,11 +473,11 @@ def prediction(arg, x_eval=x, coef_eval=coef): # noqa: WPS430
else:
result.append(coef_info.inner_product(coef, x))

result = np.sum(result, axis=0)
result = sum(result[1:], start=result[0])

if self.fit_intercept:
if self.functional_response:
result[0] = result[0] + self._change_function_basis(
result = result + self._change_function_basis(
self.intercept_, self.y_basis,
)
else:
Expand Down Expand Up @@ -722,14 +725,14 @@ def _change_function_basis(
if isinstance(f, FDataBasis) and f.basis == new_basis:
return f

inner_prod = inner_product(
f,
inner_prod = inner_product_matrix(
new_basis,
f,
_domain_range=new_basis.domain_range,
)[:, np.newaxis]
)

gram_matrix = new_basis.gram_matrix()

coefs = np.linalg.solve(gram_matrix, inner_prod).flatten()
coefs = np.linalg.solve(gram_matrix, inner_prod)

return FDataBasis(new_basis, coefs)
return FDataBasis(new_basis, coefs.T)

0 comments on commit ed03263

Please sign in to comment.