forked from lcicek/Critic-VAE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcrafter_extension_dataset.py
78 lines (49 loc) · 2.22 KB
/
crafter_extension_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from torch.utils.data import Dataset
from crafter_extension_utils import choose, interpolate_simple
import torch
import numpy as np
class CrafterCriticDataset(Dataset):
    """Dataset of ``(observation, target)`` pairs for training a critic.

    Two labelling modes are supported:

    * ``interpolate_real=False``: ``Y`` is used as given (via ``torch.tensor``).
      With ``oversample=True``, samples labelled 1 and 0 are rebalanced.
      Negatives are capped at ``int(dataset_size / 2)`` (drawn with
      ``choose(..., replace=False)``). Positives are drawn with ``choose`` to
      match the negative count, and the targets become float ones/zeros.
    * ``interpolate_real=True``: ``Y`` is smoothed with ``interpolate_simple``
      using window ``real_windowsize``. With ``oversample=True``, the data is
      resampled with replacement so that low (<= 0.3), medium (0.3 .. 0.7
      exclusive) and high (>= 0.7) targets each contribute
      ``int(dataset_size / 3)`` samples.

    ``dataset_size`` is only used when ``oversample=True``.
    """

    # Target thresholds separating the low / medium / high strata used when
    # oversampling interpolated targets.
    _LOW_THRESHOLD = 0.3
    _HIGH_THRESHOLD = 0.7

    def __init__(self, X, Y, oversample=False, dataset_size=50000, interpolate_real=False, real_windowsize=5):
        """Build the dataset.

        Args:
            X: Observations, one row per sample. Must support boolean-mask and
                integer-array indexing (e.g. a torch tensor).
            Y: Per-sample targets, one per row of ``X``.
            oversample: Rebalance the data (see class docstring).
            dataset_size: Target size of the rebalanced dataset.
            interpolate_real: Smooth ``Y`` with ``interpolate_simple`` before use.
            real_windowsize: Window size passed to ``interpolate_simple``.

        Raises:
            ValueError: If ``X`` and ``Y`` end up with different lengths, or a
                class / stratum required for oversampling is empty.
        """
        if interpolate_real:
            # Keep the *interpolated* targets: both the stratification and the
            # stored labels must see the smoothed values, not the raw ones.
            # NOTE(review): return type of interpolate_simple is not visible
            # here; as_tensor accepts lists, numpy arrays and tensors.
            Y = torch.as_tensor(interpolate_simple(Y, windowsize=real_windowsize))
            if oversample:
                X, Y = self._stratified_sample(X, Y, dataset_size)
        else:
            Y = torch.tensor(Y)
            if oversample:
                X, Y = self._balance_binary(X, Y, dataset_size)

        if len(X) != len(Y):
            raise ValueError('?! error X shape does not match Y shape')
        self.X = X
        self.Y = Y
        self.len = len(self.X)

    @staticmethod
    def _balance_binary(X, Y, dataset_size):
        """Return ``(X, Y)`` with positives resampled to match capped negatives."""
        X_positive = X[Y == 1]
        X_negative = X[Y == 0]
        if len(X_positive) == 0:
            raise ValueError('cannot oversample: no samples with Y == 1')
        no_negative = len(X_negative)
        # Cap negatives at half the requested size so the balanced set does
        # not exceed dataset_size.
        if no_negative > dataset_size / 2:
            no_negative = int(dataset_size / 2)
            X_negative = choose(X_negative, no_negative, replace=False)
        X_positive = choose(X_positive, no_negative)
        X = torch.vstack((X_positive, X_negative))
        Y = torch.cat((torch.ones(len(X_positive)), torch.zeros(len(X_negative))))
        return X, Y

    @classmethod
    def _stratified_sample(cls, X, Y, dataset_size):
        """Resample ``(X, Y)`` with replacement, equal counts per target stratum."""
        y_values = np.asarray(Y)
        per_stratum = int(dataset_size / 3)
        strata = (
            ('low', np.where(y_values <= cls._LOW_THRESHOLD)[0]),
            ('medium', np.where((y_values > cls._LOW_THRESHOLD) & (y_values < cls._HIGH_THRESHOLD))[0]),
            ('high', np.where(y_values >= cls._HIGH_THRESHOLD)[0]),
        )
        chosen = []
        for name, indices in strata:
            if len(indices) == 0 and per_stratum > 0:
                raise ValueError(f'cannot oversample: no {name} targets to draw from')
            # Sampling with replacement lets small strata be upsampled.
            chosen.append(np.random.choice(indices, per_stratum))
        samples_ix = np.concatenate(chosen)
        # Index X and Y with the same indices so pairs stay aligned.
        return X[samples_ix], Y[samples_ix]

    def __len__(self):
        """Return the number of samples."""
        return self.len

    def __getitem__(self, idx):
        """Return the ``(observation, target)`` pair at ``idx``."""
        return (self.X[idx], self.Y[idx])