
How to make predictions deterministic? #18

Open
abdulmeral opened this issue Mar 18, 2024 · 8 comments
Labels
FAQ Frequently asked question

Comments

@abdulmeral

Hello everyone,
How can I prevent the predictions from changing after each training run? Of course it is an LLM, but the problem is that the results change every time; for example, my MAPE value fluctuates a lot, sometimes better, sometimes worse.
Do you have any recommendations?

Thank you in advance.
abdül

@lostella changed the title from "Seed Number .." to "How to make predictions deterministic?" on Mar 19, 2024
@lostella
Contributor

@abdulmeral since the models are based on PyTorch, you can refer to PyTorch documentation about reproducibility.

In particular, setting the random number generation seed with

import torch
torch.manual_seed(0) # or some other number

before doing predictions should make the behavior deterministic.

Another option to stabilize predictions (at least the "central tendencies" like mean and median) is to increase the number of samples by passing num_samples when calling .predict (the default is 20); this, however, comes at the cost of slower predictions.
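
Putting the two suggestions together might look something like this (a minimal sketch, not from this thread: the checkpoint name, context, and prediction length are placeholders, and num_samples only applies to the sampling-based Chronos pipelines, not the Bolt variants):

import torch
from chronos import ChronosPipeline

torch.manual_seed(0)  # fix the RNG state right before predicting

pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",  # placeholder checkpoint
    device_map="cpu",
    torch_dtype=torch.float32,
)

context = torch.arange(100, dtype=torch.float32)  # toy series, just for illustration
forecast = pipeline.predict(
    context,
    prediction_length=12,
    num_samples=100,  # more samples -> more stable mean/median, but slower inference
)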

@lostella added the FAQ (Frequently asked question) label on Mar 19, 2024
@abdulfatir
Contributor

@lostella I think it may be better to use transformers.set_seed for seeding. It seeds everything under the sun, so you will get consistent results.
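
A minimal sketch of that approach (the seed value is arbitrary):

from transformers import set_seed

set_seed(42)  # seeds Python's random module, NumPy, and PyTorch (CPU and CUDA) in one call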

@abdulmeral
Author

Thank you very much, guys. Both ideas are working.
My colleague also recommends making predictions with different seed numbers so we can get more generic results.

@abdulfatir
Contributor

@abdulmeral glad that it helped. On another note, you should probably stop using MAPE because it's not a good metric. Check out something like MASE if you're working with point forecasts.
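
For reference, a minimal sketch of MASE (not from this thread; y_true, y_pred, y_train, and the seasonal period m are placeholders):

import numpy as np

def mase(y_true, y_pred, y_train, m=1):
    # scale by the in-sample MAE of the seasonal-naive forecast (lag m)
    scale = np.mean(np.abs(y_train[m:] - y_train[:-m]))
    return np.mean(np.abs(np.asarray(y_true) - np.asarray(y_pred))) / scale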

@lostella
Contributor

> @lostella I think it may be better to use transformers.set_seed for seeding. It seeds everything under the sun, so you will get consistent results.

TIL

@UncleChen2018

Hi,
In a Jupyter notebook, with transformers.set_seed(0) set, I've noticed that within the same session, running the same cell a second time gives a different result. Since no logic changes between runs and the seed is fixed, why does that happen?
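
For context, a minimal sketch of the pattern (assumption: set_seed resets the generators only at the moment it is called, so the state keeps advancing across re-runs unless it is reset):

import torch
import transformers

transformers.set_seed(0)
print(torch.rand(2))  # first run of the block
print(torch.rand(2))  # "re-running" the block: different values, because the RNG state has advanced

transformers.set_seed(0)
print(torch.rand(2))  # resetting the seed reproduces the first values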

@EgShes

EgShes commented Nov 28, 2024

@UncleChen2018 have you tried transformers.set_seed(0, deterministic=True)?
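
A minimal sketch of that suggestion (assuming a transformers version recent enough that set_seed accepts deterministic; re-seed at the top of the cell so repeated runs start from the same state):

import transformers

# Run this at the top of the cell before every prediction; deterministic=True
# additionally enables PyTorch's deterministic algorithms where available.
transformers.set_seed(0, deterministic=True)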

@guilhermeparreira

guilhermeparreira commented Dec 27, 2024

I have this issue right now, and it is haunting me. I followed all the steps, except for passing num_samples to predict, because that option was unavailable.

Here is a reproducible example for multiple time series (multiple ids to predict).
I get different results when I provide a 1D tensor compared to a 2D tensor. Shouldn't they be the same? Can (should) we try to make both results the same?

import pandas as pd
import numpy as np
import torch
import os
# os.system('export HF_HOME="/mnt/data/gui"')  # Set the environment variable
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
# os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
os.environ["HF_HOME"] = (
    "/mnt/data/gui"  ## define diretório de cache do hugginface para carregamento dos modelos - Don't know if it is necessary
)
from chronos import BaseChronosPipeline
import random
import gc
import transformers
# Set device to use the GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
random.seed(42)
np.random.seed(42)
os.environ["PYTHONHASHSEED"] = str(42)
torch.manual_seed(42)
transformers.set_seed(42, deterministic=True)
df = pd.read_parquet("https://datasets-nixtla.s3.amazonaws.com/m4-hourly.parquet")
df = df[df["unique_id"].isin([f"H{i}" for i in range(1, 10)])]
df.rename(columns={"unique_id": "id"}, inplace=True)
forecast_steps = 12
type_model = "amazon/chronos-bolt-base"  # Parameters: 205M - loading time: 1m 11s
# local_model_path = "/mnt/data/gui/hub/models--amazon--chronos-bolt-tiny/snapshots/f6ff2d2ba9168d498c015bc8dd07e3b395b31b3f"
local_model_path = "/mnt/data/gui/hub/models--amazon--chronos-bolt-base/snapshots/6f8ced46a499ae1dfd399981f551152d756cf4f6"
try:
    # Define the path where the model is stored locally
    pipeline = BaseChronosPipeline.from_pretrained(
        local_model_path,
        device_map="cuda",  # use "cpu" for CPU inference
        torch_dtype=torch.bfloat16,
    )
except (FileNotFoundError, OSError, NameError):
    # fall back to downloading the model from the Hub if the local copy is not available
    pipeline = BaseChronosPipeline.from_pretrained(
        type_model,
        device_map="cuda",  # use "cpu" for CPU inference
        torch_dtype=torch.bfloat16,
    )
########## TENSOR 2D

# Group the data by 'id'
grouped = df.groupby("id")
# Prepare the context as a 2D tensor (batch_size, sequence_length)
context = []
ids = []
for id_value, group in grouped:
    # Store the ID for later reconstruction
    ids.append(id_value)
    # Convert the 'y' values to a tensor
    context.append(group["y"].values)

# Convert the list of arrays into a 2D tensor
context_tensor = torch.tensor(context, dtype=torch.float32)

# Define the prediction length and quantile levels
quantile_levels = [0.1, 0.5, 0.9]

# Make the predictions for all ids at once
quantiles, mean = pipeline.predict_quantiles(
    context=context_tensor,
    prediction_length=forecast_steps,
    quantile_levels=quantile_levels,
)
# Convert predictions back to a dataframe
results = []

for i, id_value in enumerate(ids):
    for t in range(forecast_steps):
        results.append(
            {
                "id": id_value,
                "time_step": t,
                "mean": mean[i, t].item(),
                **{f"quantile_{q}": quantiles[i, t, j].item() for j, q in enumerate(quantile_levels)},
            }
        )

# Create a dataframe from the results
results_df = pd.DataFrame(results)
########### 1D TENSOR
ids = df["id"].unique()

y_preds = []
products = []
for id in ids:
    data = df[df['id']==id]
    ## FORECAST
    quantiles, mean = pipeline.predict_quantiles(
        context=torch.tensor(np.array(data["y"]), dtype=torch.float32),
        prediction_length=forecast_steps,
        quantile_levels=[0.1, 0.5, 0.9],
    )
    y_pred = mean.cpu().numpy().flatten()
    y_preds.append(y_pred)
    products.append(np.repeat(id, forecast_steps))
df_single = pd.DataFrame({"id":np.concatenate(products).tolist(), "mean":np.concatenate(y_preds).tolist()})
df_single[df_single['id']=="H9"]

[screenshot: forecast table for id "H9" from df_single (1D tensor)]

results_df[results_df['id']=="H9"]

[screenshot: forecast table for id "H9" from results_df (2D tensor)]

observations:

  • If I start my df with a single time series, both results are the same.
  • If I start my df with only two time series, I also get the same results.
  • If I change the value of CUBLAS_WORKSPACE_CONFIG, I also get different results.

So, is there another step to make the results deterministic?

Thank you in advance!

env:

tensorflow            2.17.0
tensorflow-base       2.17.0
tensorflow-estimator  2.17.0
tensorflow-gpu        2.17.0
pytorch               2.4.1
chronos-forecasting   1.4.1

Ubuntu 22.04.5 LTS
NVIDIA-SMI 535.183.01
CUDA Version: 12.2
