Error when the train set has 1 example #73

Open
LeoGrin opened this issue Jan 13, 2025 · 0 comments

LeoGrin (Collaborator) commented Jan 13, 2025

Example to reproduce:

from tabpfn_client import TabPFNClassifier
import numpy as np
import pandas as pd

# Create minimal example with just one training sample and two features
X_train = pd.DataFrame({
    "feature1": [0.5],
    "feature2": [0.7]
})
X_test = pd.DataFrame({
    "feature1": np.random.rand(10),
    "feature2": np.random.rand(10)
})

# Single training label and 10 test labels
y_train = np.array([1])  # Single class label
y_test = np.random.randint(0, 2, size=10)  # Random binary labels for testing

# Initialize and fit TabPFN
model = TabPFNClassifier()
model.fit(X_train, y_train)

# Make predictions
y_pred = model.predict(X_test)
y_pred_proba = model.predict_proba(X_test)

# Calculate accuracy
accuracy = np.mean(y_pred == y_test)
print(f"Test accuracy: {accuracy:.4f}")

Traceback:

ERROR:tabpfn_client.client:Fail to call fit, response status: 500
Traceback (most recent call last):
  File "/scratch/lgrinszt/lm_tab/scripts/../test_one_example.py", line 24, in <module>
    model.fit(X_train, y_train)
  File "/scratch/lgrinszt/micromamba/envs/lm_tab/lib/python3.10/site-packages/tabpfn_client/estimator.py", line 146, in fit
    self.last_train_set_uid = InferenceClient.fit(X, y, config=estimator_param)
  File "/scratch/lgrinszt/micromamba/envs/lm_tab/lib/python3.10/site-packages/tabpfn_client/service_wrapper.py", line 225, in fit
    return ServiceClient.fit(X, y, config=config)
  File "/scratch/lgrinszt/micromamba/envs/lm_tab/lib/python3.10/site-packages/tabpfn_client/client.py", line 237, in fit
    cls._validate_response(response, "fit")
  File "/scratch/lgrinszt/micromamba/envs/lm_tab/lib/python3.10/site-packages/tabpfn_client/client.py", line 477, in _validate_response
    raise RuntimeError(
RuntimeError: Fail to call fit with error: 500, reason: Internal Server Error and text: Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/starlette/middleware/errors.py", line 165, in __call__
    await self.app(scope, receive, _send)
  File "/usr/local/lib/python3.10/site-packages/starlette/middleware/exceptions.py", line 62, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 715, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 735, in app
    await route.handle(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 288, in handle
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 76, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/usr/local/lib/python3.10/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/site-packages/starlette/routing.py", line 73, in app
    response = await f(request)
  File "/usr/local/lib/python3.10/site-packages/fastapi/routing.py", line 301, in app
    raw_response = await run_endpoint_function(
  File "/usr/local/lib/python3.10/site-packages/fastapi/routing.py", line 212, in run_endpoint_function
    return await dependant.call(**values)
  File "/code/tabpfn-server/app/routers/fit.py", line 70, in fit
    train_set_schema = await upload_train_set(
  File "/code/tabpfn-server/app/routers/fit.py", line 39, in upload_train_set
    user_train_set_mapping = await dataset_serv.add_train_set(
  File "/code/tabpfn-server/app/services/dataset_repo_service.py", line 327, in add_train_set
    content[FileType.Y_TRAIN] = self.preprocess_y_train(content[FileType.Y_TRAIN])
  File "/code/tabpfn-server/app/services/dataset_repo_service.py", line 312, in preprocess_y_train
    return y_train.to_csv(index=False).encode()
AttributeError: 'numpy.int64' object has no attribute 'to_csv'
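The server code before `preprocess_y_train` is not shown here. The error is consistent with a one-element label column being collapsed to a NumPy scalar before `to_csv` is called. One common way this happens in pandas is `DataFrame.squeeze()` with no `axis`: it drops every length-1 dimension, so a 1x1 frame becomes a scalar. The sketch below assumes that cause. `load_y_train_naive`, `load_y_train_safe`, and this `preprocess_y_train` are hypothetical stand-ins, not the actual tabpfn-server code.

```python
import numpy as np
import pandas as pd


def load_y_train_naive(df: pd.DataFrame):
    # Hypothetical: squeeze() with no axis drops every length-1 dimension,
    # so a frame with one row and one column collapses to a scalar.
    return df.squeeze()


def load_y_train_safe(df: pd.DataFrame) -> pd.Series:
    # Hypothetical fix: squeeze only the column axis, which always
    # returns a Series, even when there is a single row.
    return df.squeeze(axis="columns")


def preprocess_y_train(y_train: pd.Series) -> bytes:
    # Hypothetical stand-in for the server helper named in the traceback.
    return y_train.to_csv(index=False).encode()


one_row = pd.DataFrame({"y": [1]})

naive = load_y_train_naive(one_row)
print(type(naive))  # a NumPy integer scalar, which has no .to_csv

safe = load_y_train_safe(one_row)
print(preprocess_y_train(safe))  # CSV bytes for a one-element Series
```

With two or more training rows, both loaders return a Series. That explains why the failure only shows up for a single example.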