-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathsanterre.py
134 lines (116 loc) · 4.87 KB
/
santerre.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import numpy as np
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
Input,
Dense,
Dropout,
Conv1D,
Lambda,
Permute,
Multiply,
Flatten,
)
import tensorflow.keras.backend as K
import tensorflow as tf
from activations import Mish
from optimizers import Ranger
import losses as l
import callbacks as cb
from layers import Attention, LayerNormalization
from data import dataset
from generator import generator
# Load the dataset. The flags are left at their default settings so you can
# experiment with them.
data = dataset(
    "data/ninaPro",  # path to the dataset
    # Butterworth (presumably) high-pass filter, applied after rectification
    # because rectifying changes the power spectrum.
    butter=True,
    # Rectify the signal; shown to expose more information about the firing
    # rate, see https://pubmed.ncbi.nlm.nih.gov/12706845/
    rectify=True,
    # Moving-average window of 15 samples; not tuned, applied lazily.
    ma=15,
    # Stride of 5 observations between consecutive windows.
    step=5,
    # 52 time steps per window; at 200 Hz that is 260 ms. Chosen following
    # precedent in other papers, not tuned (shorter is better).
    window=52,
    # Which exercise sets to use. Don't change for now.
    exercises=["a", "b", "c"],
    # Hook for engineered classical features; never used.
    features=None,
)
# Split by repetition. Every second repetition starting at index 3 is held
# out; the last held-out repetition becomes the test set and the remaining
# held-out ones the validation set. All other repetitions are for training.
reps = np.unique(data.repetition)
val_reps = reps[3::2]
train_reps = reps[~np.isin(reps, val_reps)]
test_reps, val_reps = val_reps[-1].copy(), val_reps[:-1]
# The generator class indexes the dataset and lazily applies augmentation and
# the moving average. Indexing returns a tuple (inputs are read from [0] and
# targets, whose last axis is the class dimension, from [-1]); len() returns
# the number of batches.
train = generator(
    data,  # the dataset to use
    list(train_reps),  # the repetitions to use
    shuffle=True,
    batch_size=128,  # chosen arbitrarily
    imu=False,  # if you set this to True, set it to True on the other generators as well
    augment=True,  # adds noise at a range of SNRs; class 0 is not augmented, to combat imbalance
    ma=True,  # whether to apply the moving average (done in __getitem__)
)
# NOTE(review): validation/test rely on generator's defaults for imu, ma,
# batch_size and shuffle -- confirm these match the training preprocessing.
validation = generator(data, list(val_reps), augment=False)
# Pass the held-out repetition as a one-element list, like the other splits
# (`[test_reps][0]` would unwrap it back to a bare scalar).
test = generator(data, [test_reps], augment=False)

# Infer the model dimensions from a single training batch. Indexing the
# generator runs augmentation and the moving average, so do it only once.
first_batch = train[0]
n_time = first_batch[0].shape[1]
n_class = first_batch[-1].shape[-1]
n_features = first_batch[0].shape[-1]
del first_batch
# Keyword arguments for make_model(). The classifier widths were chosen
# arbitrarily and the dropout rate only loosely tuned; dense[i] is paired
# with drop[i].
model_pars = dict(
    n_time=n_time,
    n_class=n_class,
    n_features=n_features,
    dense=[500, 500, 2000],
    drop=[0.36, 0.36, 0.36],
)
# Cosine-annealing learning-rate scheduler (pairs well with the Ranger
# optimizer used below). Per the author's intent, the rate stays high for the
# first five epochs (epoch_start=5) and is then cosine-annealed, presumably
# from eta_max towards eta_min over T_max epochs. Not really tuned.
cosine = cb.CosineAnnealingScheduler(
    eta_max=1e-3,
    eta_min=1e-5,
    T_max=50,
    epoch_start=5,
    verbose=1,
)
# Focal loss, borrowed from computer vision. Compared with plain softmax
# cross-entropy, it weights easy examples (per the author, mostly those of
# class 0) less than difficult ones. The weighting uses how well the
# predicted probabilities align with the ground truth: for target [0, 1, 0],
# a prediction of [0, 0.9, 0.1] contributes less to the loss than
# [0, 0.6, 0.4]. In the standard formulation, gamma controls how strongly
# easy examples are down-weighted. The exact role of alpha depends on
# losses.focal_loss -- TODO confirm (the original note suggests it affects
# the steepness of the loss).
loss = l.focal_loss(alpha=6., gamma=3.)
def attention_simple(inputs, n_time):
    """Pool a sequence over time using learned softmax attention weights.

    A Dense layer applied along the time axis produces, for every feature,
    softmax weights over the ``n_time`` steps. The input is multiplied by
    these weights and summed over time, giving a per-feature weighted average.

    Args:
        inputs: Rank-3 tensor, presumably (batch, time, features); the
            Permute((2, 1)) and the sum over axis 1 require rank 3.
        n_time: Length of the time axis of ``inputs``. It must match so that
            the attention weights line up element-wise with ``inputs``.

    Returns:
        Tensor with the time axis summed out: (batch, features).

    Note:
        Layer names are fixed, so this can be applied at most once per model.
    """
    # (batch, time, features) -> (batch, features, time) so Dense acts over time.
    a = Permute((2, 1), name='temporalize')(inputs)
    # Softmax over the last axis: weights over the n_time steps for each feature.
    a = Dense(n_time, activation='softmax', name='attention_probs')(a)
    # Back to (batch, time, features) to line up with the inputs.
    a_probs = Permute((2, 1), name='attention_vec')(a)
    output_attention_mul = Multiply(name='focused_attention')([inputs, a_probs])
    output_flat = Lambda(lambda x: K.sum(x, axis=1), name='temporal_average')(output_attention_mul)
    return output_flat
def make_model(n_time, n_class, n_features, dense, drop):
    """Build the dense window classifier.

    The (n_time, n_features) input window is flattened and passed through one
    Dropout -> Dense(Mish) -> LayerNormalization stage per entry of ``dense``.
    A final softmax Dense layer over ``n_class`` classes follows.

    Args:
        n_time: Number of time steps per input window.
        n_class: Number of output classes.
        n_features: Number of features per time step.
        dense: Units of each hidden Dense layer.
        drop: Dropout rate applied before each hidden layer, paired
            element-wise with ``dense``.

    Returns:
        The uncompiled Keras ``Model`` (its summary is printed as a side effect).

    Raises:
        ValueError: If ``dense`` and ``drop`` have different lengths; ``zip``
            would otherwise silently discard the extra entries.
    """
    dense, drop = list(dense), list(drop)
    if len(dense) != len(drop):
        raise ValueError(
            f"dense and drop must have the same length, "
            f"got {len(dense)} and {len(drop)}"
        )
    inputs = Input((n_time, n_features))
    # Alternative attention front end, kept for experimentation. These layers
    # need the rank-3 (batch, n_time, n_features) tensor, so to try them,
    # use them in place of the Flatten below rather than after it:
    # x = Conv1D(filters=n_features, kernel_size=3, padding="same", activation=Mish())(inputs)
    # x = LayerNormalization()(x)
    # x = attention_simple(x, n_time)
    x = Flatten()(inputs)
    for units, rate in zip(dense, drop):
        x = Dropout(rate)(x)
        x = Dense(units, activation=Mish())(x)
        x = LayerNormalization()(x)
    outputs = Dense(n_class, activation="softmax")(x)
    model = Model(inputs, outputs)
    # summary() prints itself and returns None; wrapping it in print() would
    # emit a stray "None".
    model.summary()
    return model
model = make_model(**model_pars)
# Ranger optimizer: combination of rectified Adam and Lookahead.
model.compile(Ranger(learning_rate=1e-3), loss=loss, metrics=["accuracy"])
model.fit(
    train,
    epochs=55,
    validation_data=validation,
    callbacks=[
        # Keep only the weights from the epoch with the lowest val_loss.
        # `save_best_only` is ModelCheckpoint's argument for this; it has no
        # `keep_best_only` parameter.
        ModelCheckpoint(
            "fiddle.h5",
            monitor="val_loss",
            save_best_only=True,
            save_weights_only=True,
        ),
        cosine,
    ],
    # Shuffling is done by the generator; per the author, shuffling here as
    # well makes training far slower.
    shuffle=False,
)

# Drop into the debugger after training so the model can be inspected
# interactively.
import pdb  # XXX BREAKPOINT
pdb.set_trace()