-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinference.py
119 lines (104 loc) · 4.78 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from utils.cleegn import CLEEGN
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import from_numpy as np2TT
from torchinfo import summary
from matplotlib.colors import rgb2hex
import matplotlib.pyplot as plt
from scipy.io import loadmat
from scipy.io import savemat
from scipy import signal
import numpy as np
import math
import json
import time
import sys
import os
# Compute device used for inference: the current CUDA device when one is
# visible to torch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Channel names of the 56-channel montage. plotEEG looks channels up by name
# with electrode.index(...) and uses that position as the row index into the
# EEG arrays, so the ORDER of this list must match the data's row order.
# Grouped below by scalp region, front to back.
electrode = [
    'Fp1', 'Fp2',
    'AF7', 'AF3', 'AF4', 'AF8',
    'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8',
    'FT7', 'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'FT8',
    'T7', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'T8',
    'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'TP8',
    'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8',
    'PO7', 'POz', 'PO8',
    'O1', 'O2',
]
def plotEEG(tstmps, data_colle, ref_i, electrode, titles=None, colors=None, alphas=None, ax=None):
    """Pyplot waveform visualization of a fixed subset of EEG channels.

    Eight channels (Fp1, Fp2, T7, T8, O1, O2, Fz, Pz) are drawn as stacked
    rows, Fp1 at the top. In each row every recording in ``data_colle`` is
    overlaid. All recordings in a row share one scale factor, chosen so that
    the reference recording ``data_colle[ref_i]`` peaks at 0.25 row heights.

    Args:
        tstmps: x-coordinates, one per sample (only the first and last are
            used for the x-limits).
        data_colle: list of 2-D arrays indexed ``[channel, sample]`` whose
            rows follow the order of ``electrode``. The list and its arrays
            are not modified.
        ref_i: index into ``data_colle`` of the recording used for scaling.
        electrode: channel names in the row order of the arrays.
        titles: legend label per recording (default: empty strings). Only
            the lines in the first row get a label, so each recording is
            listed at most once in the legend.
        colors: color per recording (default: sampled from ``tab20``).
        alphas: opacity per recording (default 0.5). Recordings with
            alpha > 0.6 are drawn with a thicker line.
        ax: Axes to draw on (default: ``plt.subplot()``).

    Returns:
        The Axes that was drawn on.
    """
    n_data = len(data_colle)
    titles = ["" for _ in range(n_data)] if titles is None else titles
    alphas = [0.5 for _ in range(n_data)] if alphas is None else alphas
    if colors is None:
        # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap accepts
        # the same (name, lut) arguments.
        cmap_ = plt.get_cmap("tab20", n_data)
        colors = [rgb2hex(cmap_(di)) for di in range(n_data)]

    picks_chs = ["Fp1", "Fp2", "T7", "T8", "O1", "O2", "Fz", "Pz"]
    picks = [electrode.index(c) for c in picks_chs]
    # Work on a new list so the caller's data_colle is left untouched.
    picked = [np.asarray(eeg)[picks, :] for eeg in data_colle]

    if ax is None:
        ax = plt.subplot()
    n_rows = len(picks_chs)
    for ii in range(n_rows):
        # Row ii is drawn at y = n_rows - 1 - ii, so the first pick is on top.
        offset = n_rows - ii - 1
        peak = np.abs(picked[ref_i][ii]).max()
        # A flat (or NaN) reference channel would give an infinite/NaN scale.
        norm_coef = 0.25 / peak if peak > 0 else 1.0
        for di, eeg_dt in enumerate(picked):
            ax.plot(
                tstmps, eeg_dt[ii] * norm_coef + offset,
                label=None if ii else titles[di],
                color=colors[di], alpha=alphas[di],
                linewidth=3 if alphas[di] > 0.6 else 1.5,  # default=1.5
            )

    ax.set_xlim(tstmps[0], tstmps[-1])
    ax.set_ylim(-0.5, n_rows - 0.5)
    ax.set_xticks([])
    ax.set_yticks(np.arange(n_rows))
    # Tick k sits at y = k, which holds channel picks_chs[n_rows - 1 - k].
    ax.set_yticklabels(picks_chs[::-1], fontsize=20)
    ax.legend(
        bbox_to_anchor=(0, 1.02, 1, 0.2),
        loc="lower right", borderaxespad=0, ncol=3, fontsize=20,
    )
    return ax
def ar_through_model(eeg_data, model, window_size, stride):
    """Run ``model`` over a recording in sliding windows and overlap-add the output.

    Windows of ``window_size`` samples start every ``stride`` samples. The
    final window is clamped so that it ends exactly at the last sample, so
    every sample is covered. Each window's prediction is weighted by a Hann
    taper. The accumulated sum is then divided by the accumulated taper
    weights.

    Args:
        eeg_data: 2-D array indexed ``[channel, sample]``.
        model: torch module. It is switched to eval mode and called, under
            ``torch.no_grad()``, on a float tensor of shape
            ``(1, 1, n_channels, window_size)`` placed on the module-level
            ``device``. Its output is assumed to hold
            ``n_channels * window_size`` values (TODO confirm for other models).
        window_size: samples per inference window.
        stride: hop between consecutive window starts, in samples.

    Returns:
        float32 array with the same shape as ``eeg_data``.

    Raises:
        ValueError: if ``window_size`` or ``stride`` is not positive, or the
            recording is shorter than ``window_size``.
    """
    n_chan, n_samples = eeg_data.shape
    if window_size <= 0 or stride <= 0:
        raise ValueError(
            f"window_size and stride must be positive, got {window_size} and {stride}"
        )
    if n_samples < window_size:
        raise ValueError(
            f"recording has {n_samples} samples, fewer than window_size={window_size}"
        )

    model.eval()
    noiseless_eeg = np.zeros(eeg_data.shape, dtype=np.float32)
    hcoef = np.zeros(n_samples, dtype=np.float32)
    # scipy.signal.hann was removed in SciPy 1.13; scipy.signal.windows.hann
    # is the supported location. The tiny offset keeps the weight sum
    # non-zero at window edges, where the Hann taper is exactly 0.
    hwin = signal.windows.hann(window_size) + 1e-9

    for start in range(0, n_samples, stride):
        # Clamp the last window so it ends at the final sample.
        tstap = min(start, n_samples - window_size)
        segment = eeg_data[:, tstap: tstap + window_size]
        with torch.no_grad():
            data = np2TT(segment[np.newaxis, np.newaxis])
            data = data.to(device, dtype=torch.float)
            pred_segment = model(data).detach().cpu().numpy().astype(np.float32)
        # Explicit reshape instead of squeeze(): keeps the (channel, sample)
        # layout even for single-channel input.
        noiseless_eeg[:, tstap: tstap + window_size] += (
            pred_segment.reshape(n_chan, window_size) * hwin
        )
        hcoef[tstap: tstap + window_size] += hwin
        if tstap + window_size >= n_samples:
            # This window reaches the end. A later start would only re-run
            # this same clamped window and double its weight.
            break

    noiseless_eeg /= hcoef
    return noiseless_eeg
if __name__ == "__main__":
    import argparse

    # Sampling rate (Hz) and channel count passed to the CLEEGN constructor
    # below; also used to size the inference windows and the plot time axis.
    FS = 128.0
    N_CHAN = 56

    parser = argparse.ArgumentParser(description="removal artifact from multi-channel EEG data")
    parser.add_argument("--mat-path", required=True, type=str, help="path to EEG data (.mat)")
    parser.add_argument("--model-path", required=True, type=str, help="path to pre-trained model (.pth)")
    args = parser.parse_args()

    mat = loadmat(args.mat_path)
    # x_test: the input that is cleaned and plotted as "Original";
    # y_test: the reference plotted as "Reference".
    dt_polluted, dt_ref = mat["x_test"], mat["y_test"]

    ### temporary fixed mode
    # NOTE(review): torch.load unpickles the checkpoint and can execute
    # arbitrary code from it; only load trusted files (PyTorch versions that
    # support it also accept weights_only=True).
    state = torch.load(args.model_path, map_location="cpu")
    model = CLEEGN(n_chan=N_CHAN, fs=FS, N_F=N_CHAN).to(device)
    model.load_state_dict(state["state_dict"])

    # 4 s windows with a 1 s hop.
    dt_cleegn = ar_through_model(
        dt_polluted, model, math.ceil(4.0 * FS), math.ceil(1.0 * FS)
    )

    # Plot a 1000-sample excerpt starting at sample 1500.
    x_min, x_max = 1500, 1500 + 1000
    x_data = dt_polluted[:, x_min: x_max]
    y_data = dt_ref[:, x_min: x_max]
    p_data = dt_cleegn[:, x_min: x_max]

    _, ax = plt.subplots(1, 1, figsize=(16, 9))
    plotEEG(
        # True sample times in seconds (n / FS). linspace(0, ceil(n / FS), n)
        # would stretch the axis to a whole number of seconds.
        np.arange(x_data.shape[-1]) / FS,
        [x_data, y_data, y_data, p_data], 1, electrode,
        titles=["Original", "", "Reference", "CLEEGN"], colors=["gray", "gray", "red", "blue"], alphas=[0.5, 0, 0.8, 0.8], ax=ax
    )
    plt.show()