-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmake_scaler.py
84 lines (79 loc) · 3.42 KB
/
make_scaler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import logging
import pickle
import glob
#import enlighten
import numpy as np
import pandas as pd
from sklearn import preprocessing
from ROOT import TFile, TTree
from root_numpy import tree2array, rec2array
import parameters
def MakeScaler(data=None,list_inputs=None,generator=False,batch=100000):
    """Create, fit and pickle a StandardScaler for the given input branches,
    or load it from disk if it was already generated.

    Parameters
    ----------
    data : pandas.DataFrame or None
        In-memory dataset; if given, the scaler is fitted on data[list_inputs]
        and afterwards used to sanity-check the (possibly loaded) scaler.
    list_inputs : list of str or None
        Branch/column names to scale. Defaults to an empty list.
    generator : bool
        If True, compute mean/std in two streaming passes over the ROOT files
        in parameters.path_gen_training instead of fitting in memory.
    batch : int
        Number of tree entries read per chunk in generator mode.

    Returns
    -------
    None (the scaler is pickled to parameters.main_path).

    Raises
    ------
    ValueError : if the loaded scaler cannot transform `data` (stale scaler).
    RuntimeError : if the scaler's output is not ~N(0,1) on `data`.
    """
    if list_inputs is None:  # avoid mutable default argument
        list_inputs = []
    # Generate scaler #
    scaler_name = 'scaler_'+parameters.suffix+'.pkl'
    scaler_path = os.path.join(parameters.main_path,scaler_name)
    scaler = preprocessing.StandardScaler()
    if not os.path.exists(scaler_path):
        # Not generator #
        if data is not None:
            scaler.fit(data[list_inputs])
        # For generator #
        if generator:
            logging.info("-"*80)
            logging.info("Computing mean")
            # Mean Loop : first streaming pass, accumulate sum(x) over all files #
            mean = np.zeros(len(list_inputs))
            Ntot = 0
            for f in glob.glob(parameters.path_gen_training+'/*root'):
                file_handle = TFile.Open(f)
                tree = file_handle.Get('tree')
                N = tree.GetEntries()
                Ntot += N
                logging.info("Opening file %s (%d entries)"%(f,N))
                # Loop over batches #
                for i in range(0, N, batch):
                    array = rec2array(tree2array(tree,branches=list_inputs,start=i,stop=i+batch))
                    mean += np.sum(array,axis=0)
                file_handle.Close()  # release the ROOT file handle
            mean /= Ntot
            # Var Loop : second pass, accumulate sum((x-mean)^2) #
            logging.info("-"*80)
            logging.info("Computing std")
            std = np.zeros(len(list_inputs))
            for f in glob.glob(parameters.path_gen_training+'/*root'):
                file_handle = TFile.Open(f)
                tree = file_handle.Get('tree')
                N = tree.GetEntries()
                logging.info("Opening file %s (%d entries)"%(f,N))
                # Loop over batches #
                for i in range(0, N, batch):
                    array = rec2array(tree2array(tree,branches=list_inputs,start=i,stop=i+batch))
                    std += np.sum(np.square(array-mean),axis=0)
                file_handle.Close()  # release the ROOT file handle
            std = np.sqrt(std/Ntot)
            # Set manually : StandardScaler.transform only needs mean_ and scale_ #
            scaler.mean_ = mean
            scaler.scale_ = std
        # Save #
        with open(scaler_path, 'wb') as handle:
            pickle.dump(scaler, handle)
        logging.info('Scaler %s has been created'%scaler_name)
    # If exists, will import it #
    else:
        with open(scaler_path, 'rb') as handle:
            scaler = pickle.load(handle)
        logging.info('Scaler %s has been imported'%scaler_name)
    # Test the scaler : scaled data should be ~N(0,1) if the scaler matches #
    if data is not None:
        try:
            mean_scale = np.mean(scaler.transform(data[list_inputs]))
            var_scale = np.var(scaler.transform(data[list_inputs]))
        except ValueError:
            logging.critical("Problem with the scaler '%s' you imported, has the data changed since it was generated ?"%scaler_name)
            raise  # re-raise the original ValueError with its message/traceback
        if abs(mean_scale)>0.01 or abs((var_scale-1)/var_scale)>0.01: # Check that scaling is correct to 1%
            logging.critical("Something is wrong with scaler '%s' (mean = %0.6f, var = %0.6f), maybe you loaded an incorrect scaler"%(scaler_name,mean_scale,var_scale))
            # Original code raised the misspelled name 'RunTimeError' (a NameError at runtime)
            raise RuntimeError("Scaler '%s' does not normalize the data to ~N(0,1)"%scaler_name)