-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathload_sprites.py
59 lines (54 loc) · 2.23 KB
/
load_sprites.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os, time
import numpy as np
def sprites_act(path, seed=0, return_labels=False):
    """Load the sprites frame arrays from ``<path>/npy`` and shuffle them.

    For every (action, direction) pair the files
    ``<action>_<direction>_frames_{train,test}.npy`` are loaded and
    concatenated along axis 0.  The train and test sets are then shuffled
    independently with permutations drawn after ``np.random.seed(seed)``.

    Args:
        path: Directory that contains the ``npy`` sub-directory.  A trailing
            path separator is optional.
        seed: Value passed to ``np.random.seed``.  NOTE: this re-seeds
            NumPy's *global* random state, which is visible to the caller.
        return_labels: When true, also load
            ``<action>_<direction>_attributes_{train,test}.npy`` and build
            one-hot action/direction labels.

    Returns:
        ``(X_train, X_test)`` when ``return_labels`` is false, otherwise
        ``(X_train, X_test, A_train, A_test, D_train, D_test)``.  ``A_*`` are
        the attribute arrays; ``D_*`` are float one-hot arrays of shape
        ``(a.shape[0], a.shape[1], len(actions) * len(directions))`` where
        the hot index is ``len(directions) * action_index + direction_index``.
        All label arrays are permuted with the same indices as the matching
        ``X_*`` array, so rows stay aligned.
    """
    directions = ['front', 'left', 'right']
    actions = ['walk', 'spellcard', 'slash']
    n_classes = len(actions) * len(directions)
    start = time.time()
    # os.path.join copes with ``path`` given with or without a trailing slash.
    npy_dir = os.path.join(path, 'npy')

    def load(action, direction, kind, split):
        # File naming scheme: <action>_<direction>_<kind>_<split>.npy
        filename = '%s_%s_%s_%s.npy' % (action, direction, kind, split)
        return np.load(os.path.join(npy_dir, filename))

    def one_hot(attributes, label):
        # One label vector per entry along the first two attribute axes
        # (presumably sequence and time step -- confirm against the data).
        encoded = np.zeros([attributes.shape[0], attributes.shape[1], n_classes])
        encoded[:, :, label] = 1
        return encoded

    X_train, X_test = [], []
    A_train, A_test = [], []
    D_train, D_test = [], []
    for act, action in enumerate(actions):
        for i, direction in enumerate(directions):
            label = len(directions) * act + i
            # Single pre-formatted argument keeps the output identical on
            # Python 2 and Python 3.
            print('%s %s %s %s %s' % (action, direction, act, i, label))
            X_train.append(load(action, direction, 'frames', 'train'))
            X_test.append(load(action, direction, 'frames', 'test'))
            if return_labels:
                a = load(action, direction, 'attributes', 'train')
                A_train.append(a)
                D_train.append(one_hot(a, label))
                a = load(action, direction, 'attributes', 'test')
                A_test.append(a)
                D_test.append(one_hot(a, label))

    X_train = np.concatenate(X_train, axis=0)
    X_test = np.concatenate(X_test, axis=0)

    # Seed once, then draw the train permutation followed by the test
    # permutation; this order determines the resulting shuffles.
    np.random.seed(seed)
    ind = np.random.permutation(X_train.shape[0])
    X_train = X_train[ind]
    if return_labels:
        A_train = np.concatenate(A_train, axis=0)[ind]
        D_train = np.concatenate(D_train, axis=0)[ind]

    ind = np.random.permutation(X_test.shape[0])
    X_test = X_test[ind]
    if return_labels:
        A_test = np.concatenate(A_test, axis=0)[ind]
        D_test = np.concatenate(D_test, axis=0)[ind]
        print('%s %s %s shapes' % (A_test.shape, D_test.shape, X_test.shape))

    print('%s %s %s' % (X_train.shape, X_test.min(), X_test.max()))
    end = time.time()
    print('data loaded in %.2f seconds...' % (end - start))

    if return_labels:
        return X_train, X_test, A_train, A_test, D_train, D_test
    return X_train, X_test