-
Notifications
You must be signed in to change notification settings - Fork 66
/
Copy pathmodules.py
504 lines (381 loc) · 17.4 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
import utils
import numpy as np
import torch
from torch import nn
class ContrastiveSWM(nn.Module):
    """Main module for a Contrastively-trained Structured World Model (C-SWM).

    Args:
        embedding_dim: Dimensionality of abstract state space.
        input_dims: Shape of input observation. ``input_dims[0]`` is the
            channel count; ``input_dims[1:]`` is the spatial size (stored as
            ``width``/``height`` after any encoder downsampling).
        hidden_dim: Number of hidden units in encoder and transition model.
        action_dim: Dimensionality of action space.
        num_objects: Number of object slots.
        hinge: Margin of the hinge term applied to negative samples.
        sigma: Energy scale; energies are multiplied by ``0.5 / sigma**2``.
        encoder: Object extractor to use: ``'small'``, ``'medium'`` or
            ``'large'``.
        ignore_action: Forwarded to ``TransitionGNN``.
        copy_action: Forwarded to ``TransitionGNN``.

    Raises:
        ValueError: If ``encoder`` is not one of the supported names.
    """

    def __init__(self, embedding_dim, input_dims, hidden_dim, action_dim,
                 num_objects, hinge=1., sigma=0.5, encoder='large',
                 ignore_action=False, copy_action=False):
        # Fail fast: an unknown name used to leave ``obj_extractor`` unset and
        # only surfaced as an AttributeError on the first forward pass.
        if encoder not in ('small', 'medium', 'large'):
            raise ValueError(
                "encoder must be 'small', 'medium' or 'large', got %r"
                % (encoder,))
        super(ContrastiveSWM, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.action_dim = action_dim
        self.num_objects = num_objects
        self.hinge = hinge
        self.sigma = sigma
        self.ignore_action = ignore_action
        self.copy_action = copy_action

        # Most recent loss terms, overwritten by ``contrastive_loss``.
        self.pos_loss = 0
        self.neg_loss = 0

        num_channels = input_dims[0]
        width_height = input_dims[1:]

        if encoder == 'small':
            self.obj_extractor = EncoderCNNSmall(
                input_dim=num_channels,
                hidden_dim=hidden_dim // 16,
                num_objects=num_objects)
            # The small encoder's 10x10 / stride-10 conv shrinks each
            # spatial dimension by a factor of 10.
            width_height = np.array(width_height)
            width_height = width_height // 10
        elif encoder == 'medium':
            self.obj_extractor = EncoderCNNMedium(
                input_dim=num_channels,
                hidden_dim=hidden_dim // 16,
                num_objects=num_objects)
            # The medium encoder's final 5x5 / stride-5 conv shrinks each
            # spatial dimension by a factor of 5.
            width_height = np.array(width_height)
            width_height = width_height // 5
        else:  # 'large': padded 3x3 convs keep the spatial size unchanged.
            self.obj_extractor = EncoderCNNLarge(
                input_dim=num_channels,
                hidden_dim=hidden_dim // 16,
                num_objects=num_objects)

        # Each object's feature map is flattened and embedded by an MLP.
        self.obj_encoder = EncoderMLP(
            input_dim=np.prod(width_height),
            hidden_dim=hidden_dim,
            output_dim=embedding_dim,
            num_objects=num_objects)

        self.transition_model = TransitionGNN(
            input_dim=embedding_dim,
            hidden_dim=hidden_dim,
            action_dim=action_dim,
            num_objects=num_objects,
            ignore_action=ignore_action,
            copy_action=copy_action)

        self.width = width_height[0]
        self.height = width_height[1]

    def energy(self, state, action, next_state, no_trans=False):
        """Energy function based on normalized squared L2 norm.

        Args:
            state: Object embeddings, indexed as ``[batch, object, feature]``.
            action: Actions passed to the transition model.
            next_state: Target embeddings with the same shape as ``state``.
            no_trans: If true, compare ``state`` and ``next_state`` directly
                instead of ``state + transition(state, action)``.

        Returns:
            Per-sample energy: squared error summed over the feature axis
            and averaged over objects, scaled by ``0.5 / sigma**2``.
        """
        norm = 0.5 / (self.sigma**2)

        if no_trans:
            diff = state - next_state
        else:
            # The transition model predicts a state delta, not the next state.
            pred_trans = self.transition_model(state, action)
            diff = state + pred_trans - next_state

        return norm * diff.pow(2).sum(2).mean(1)

    def transition_loss(self, state, action, next_state):
        """Return the batch-mean energy of the predicted transition."""
        return self.energy(state, action, next_state).mean()

    def contrastive_loss(self, obs, action, next_obs):
        """Return the C-SWM loss for a batch of transitions.

        The positive term is the transition energy. The negative term is a
        hinge on the energy between each state and a randomly permuted state
        from the same batch. Both terms are stored on ``self.pos_loss`` and
        ``self.neg_loss``.
        """
        objs = self.obj_extractor(obs)
        next_objs = self.obj_extractor(next_obs)

        state = self.obj_encoder(objs)
        next_state = self.obj_encoder(next_objs)

        # Sample negative state across episodes at random. A permutation may
        # leave some indices in place, pairing a sample with itself.
        batch_size = state.size(0)
        perm = np.random.permutation(batch_size)
        neg_state = state[perm]

        self.pos_loss = self.energy(state, action, next_state)
        zeros = torch.zeros_like(self.pos_loss)

        self.pos_loss = self.pos_loss.mean()
        self.neg_loss = torch.max(
            zeros, self.hinge - self.energy(
                state, action, neg_state, no_trans=True)).mean()

        loss = self.pos_loss + self.neg_loss

        return loss

    def forward(self, obs):
        """Encode observations into per-object embeddings."""
        return self.obj_encoder(self.obj_extractor(obs))
class TransitionGNN(torch.nn.Module):
    """GNN-based transition function.

    Runs one round of message passing over a fully-connected graph whose
    nodes are the object slots of each sample. It returns one vector of size
    ``input_dim`` per node; ``ContrastiveSWM`` adds this to the current state.

    Args:
        input_dim: Per-object embedding size (also the output size).
        hidden_dim: Hidden units of the edge and node MLPs; also the size of
            each edge message.
        action_dim: Size of the per-node action vector (treated as 0 when
            ``ignore_action`` is set).
        num_objects: Number of object slots (stored as ``self.num_objects``;
            ``forward`` takes the node count from the input tensor).
        ignore_action: If true, actions are not fed to the node MLP.
        copy_action: If true, the same one-hot action is attached to every
            node. Otherwise a one-hot over ``action_dim * num_nodes`` is split
            into ``num_nodes`` chunks of size ``action_dim``, one per node.
        act_fn: Activation name passed to ``utils.get_act_fn``.
    """

    def __init__(self, input_dim, hidden_dim, action_dim, num_objects,
                 ignore_action=False, copy_action=False, act_fn='relu'):
        super(TransitionGNN, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_objects = num_objects
        self.ignore_action = ignore_action
        self.copy_action = copy_action

        if self.ignore_action:
            self.action_dim = 0
        else:
            self.action_dim = action_dim

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_dim*2, hidden_dim),
            utils.get_act_fn(act_fn),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            utils.get_act_fn(act_fn),
            nn.Linear(hidden_dim, hidden_dim))

        # Node input = [state, action, aggregated edge messages].
        node_input_dim = hidden_dim + input_dim + self.action_dim

        self.node_mlp = nn.Sequential(
            nn.Linear(node_input_dim, hidden_dim),
            utils.get_act_fn(act_fn),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            utils.get_act_fn(act_fn),
            nn.Linear(hidden_dim, input_dim))

        # Cached COO edge list, rebuilt by _get_edge_list_fully_connected.
        self.edge_list = None
        self.batch_size = 0

    def _edge_model(self, source, target, edge_attr):
        """Compute one message per edge from its two endpoint attributes."""
        del edge_attr  # Unused.
        out = torch.cat([source, target], dim=1)
        return self.edge_mlp(out)

    def _node_model(self, node_attr, edge_index, edge_attr):
        """Update every node from its attributes and its summed edge messages.

        Messages are summed per node, grouped by the edge's first index
        (``row``) via ``utils.unsorted_segment_sum``. With no edges (a single
        object slot) the sum is over an empty set, so a zero block of size
        ``hidden_dim`` is used. ``node_mlp``'s input width always includes
        ``hidden_dim``; omitting the block would raise a shape error.
        """
        if edge_attr is not None:
            row, _ = edge_index
            agg = utils.unsorted_segment_sum(
                edge_attr, row, num_segments=node_attr.size(0))
        else:
            agg = node_attr.new_zeros(node_attr.size(0), self.hidden_dim)
        out = torch.cat([node_attr, agg], dim=1)
        return self.node_mlp(out)

    def _get_edge_list_fully_connected(self, batch_size, num_objects, cuda):
        """Return a ``[2, batch_size * num_objects * (num_objects - 1)]`` edge list.

        Each sample's nodes occupy the index range
        ``[b * num_objects, (b + 1) * num_objects)``. Every ordered pair of
        distinct nodes within a sample is an edge.

        The result is cached. It is rebuilt when the batch size, the object
        count (detected from the cached edge count) or the device changes.
        Keying only on the batch size could return a stale or wrong-device
        tensor.
        """
        num_edges = batch_size * num_objects * (num_objects - 1)
        stale = (self.edge_list is None
                 or self.batch_size != batch_size
                 or self.edge_list.size(1) != num_edges
                 or self.edge_list.is_cuda != cuda)
        if stale:
            # Create fully-connected adjacency matrix for single sample.
            adj_full = torch.ones(num_objects, num_objects)

            # Remove diagonal (no self-edges).
            adj_full -= torch.eye(num_objects)
            edge_list = adj_full.nonzero()

            # Copy `batch_size` times and add a per-sample node offset.
            edge_list = edge_list.repeat(batch_size, 1)
            offset = torch.arange(
                0, batch_size * num_objects, num_objects).unsqueeze(-1)
            offset = offset.expand(batch_size, num_objects * (num_objects - 1))
            offset = offset.contiguous().view(-1)
            edge_list += offset.unsqueeze(-1)

            # Transpose to COO format -> Shape: [2, num_edges].
            edge_list = edge_list.transpose(0, 1)

            if cuda:
                edge_list = edge_list.cuda()

            self.edge_list = edge_list
            self.batch_size = batch_size

        return self.edge_list

    def forward(self, states, action):
        """Predict a per-object output from ``states`` and ``action``.

        Args:
            states: ``[batch_size, num_objects, input_dim]`` tensor.
            action: Action indices, one-hot encoded via ``utils.to_one_hot``
                unless ``ignore_action`` is set.

        Returns:
            ``[batch_size, num_objects, input_dim]`` tensor.
        """
        cuda = states.is_cuda
        batch_size = states.size(0)
        num_nodes = states.size(1)

        # node_attr: Flatten states tensor to [B * num_objects, input_dim]
        node_attr = states.view(-1, self.input_dim)

        edge_attr = None
        edge_index = None

        if num_nodes > 1:
            # edge_index: [2, B * num_objects * (num_objects - 1)] COO list
            edge_index = self._get_edge_list_fully_connected(
                batch_size, num_nodes, cuda)

            row, col = edge_index
            edge_attr = self._edge_model(
                node_attr[row], node_attr[col], edge_attr)

        if not self.ignore_action:
            # Both branches use the runtime node count so that action_vec has
            # exactly one row per entry of node_attr.
            if self.copy_action:
                action_vec = utils.to_one_hot(
                    action, self.action_dim).repeat(1, num_nodes)
            else:
                action_vec = utils.to_one_hot(
                    action, self.action_dim * num_nodes)
            action_vec = action_vec.view(-1, self.action_dim)

            # Attach action to each state
            node_attr = torch.cat([node_attr, action_vec], dim=-1)

        node_attr = self._node_model(
            node_attr, edge_index, edge_attr)

        # node_mlp outputs input_dim features: [batch_size, num_nodes, input_dim]
        return node_attr.view(batch_size, num_nodes, -1)
class EncoderCNNSmall(nn.Module):
    """CNN encoder producing one feature map per object slot.

    A 10x10 convolution with stride 10 downsamples each spatial dimension by
    a factor of 10. It is followed by BatchNorm and the hidden activation,
    then a 1x1 convolution maps to ``num_objects`` channels and applies the
    final activation.

    Args:
        input_dim: Number of input channels.
        hidden_dim: Channels of the intermediate feature map.
        num_objects: Number of output channels (object slots).
        act_fn: Final activation name for ``utils.get_act_fn``.
        act_fn_hid: Hidden activation name for ``utils.get_act_fn``.
    """

    def __init__(self, input_dim, hidden_dim, num_objects, act_fn='sigmoid',
                 act_fn_hid='relu'):
        super().__init__()
        # Keep the registration order: it fixes parameter initialisation
        # order and state_dict key order.
        self.cnn1 = nn.Conv2d(in_channels=input_dim, out_channels=hidden_dim,
                              kernel_size=10, stride=10)
        self.cnn2 = nn.Conv2d(in_channels=hidden_dim,
                              out_channels=num_objects, kernel_size=1)
        # A BatchNorm2d despite the ``ln`` name; renaming would break
        # existing state_dicts.
        self.ln1 = nn.BatchNorm2d(num_features=hidden_dim)
        self.act1 = utils.get_act_fn(act_fn_hid)
        self.act2 = utils.get_act_fn(act_fn)

    def forward(self, obs):
        features = self.cnn1(obs)
        features = self.act1(self.ln1(features))
        slot_maps = self.cnn2(features)
        return self.act2(slot_maps)
class EncoderCNNMedium(nn.Module):
    """CNN encoder producing one feature map per object slot.

    A 9x9 convolution with padding 4 keeps the spatial size and is followed
    by BatchNorm and the hidden activation. A 5x5 convolution with stride 5
    then downsamples by a factor of 5 and emits ``num_objects`` channels.

    Args:
        input_dim: Number of input channels.
        hidden_dim: Channels of the intermediate feature map.
        num_objects: Number of output channels (object slots).
        act_fn: Final activation name for ``utils.get_act_fn``.
        act_fn_hid: Hidden activation name for ``utils.get_act_fn``.
    """

    def __init__(self, input_dim, hidden_dim, num_objects, act_fn='sigmoid',
                 act_fn_hid='leaky_relu'):
        super().__init__()
        # Registration order is unchanged so that parameter initialisation
        # and state_dict key order stay the same.
        self.cnn1 = nn.Conv2d(in_channels=input_dim, out_channels=hidden_dim,
                              kernel_size=9, padding=4)
        self.act1 = utils.get_act_fn(act_fn_hid)
        # BatchNorm2d despite the ``ln`` name (kept for state_dict keys).
        self.ln1 = nn.BatchNorm2d(num_features=hidden_dim)
        self.cnn2 = nn.Conv2d(in_channels=hidden_dim,
                              out_channels=num_objects,
                              kernel_size=5, stride=5)
        self.act2 = utils.get_act_fn(act_fn)

    def forward(self, obs):
        hidden = self.cnn1(obs)
        hidden = self.act1(self.ln1(hidden))
        return self.act2(self.cnn2(hidden))
class EncoderCNNLarge(nn.Module):
    """CNN encoder producing one feature map per object slot.

    Four 3x3 convolutions with padding 1 keep the spatial size unchanged.
    The first three are each followed by BatchNorm and the hidden
    activation. The last maps to ``num_objects`` channels and applies the
    final activation.

    Args:
        input_dim: Number of input channels.
        hidden_dim: Channels of the intermediate feature maps.
        num_objects: Number of output channels (object slots).
        act_fn: Final activation name for ``utils.get_act_fn``.
        act_fn_hid: Hidden activation name for ``utils.get_act_fn``.
    """

    def __init__(self, input_dim, hidden_dim, num_objects, act_fn='sigmoid',
                 act_fn_hid='relu'):
        super().__init__()
        # Registration order (conv, act, norm per stage) is preserved so that
        # parameter initialisation and state_dict key order are unchanged.
        # The ``ln*`` layers are BatchNorm2d; names kept for state_dict keys.
        self.cnn1 = nn.Conv2d(in_channels=input_dim, out_channels=hidden_dim,
                              kernel_size=3, padding=1)
        self.act1 = utils.get_act_fn(act_fn_hid)
        self.ln1 = nn.BatchNorm2d(num_features=hidden_dim)

        self.cnn2 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim,
                              kernel_size=3, padding=1)
        self.act2 = utils.get_act_fn(act_fn_hid)
        self.ln2 = nn.BatchNorm2d(num_features=hidden_dim)

        self.cnn3 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim,
                              kernel_size=3, padding=1)
        self.act3 = utils.get_act_fn(act_fn_hid)
        self.ln3 = nn.BatchNorm2d(num_features=hidden_dim)

        self.cnn4 = nn.Conv2d(in_channels=hidden_dim,
                              out_channels=num_objects,
                              kernel_size=3, padding=1)
        self.act4 = utils.get_act_fn(act_fn)

    def forward(self, obs):
        stages = (
            (self.cnn1, self.ln1, self.act1),
            (self.cnn2, self.ln2, self.act2),
            (self.cnn3, self.ln3, self.act3),
        )
        features = obs
        for conv, norm, activation in stages:
            features = activation(norm(conv(features)))
        return self.act4(self.cnn4(features))
class EncoderMLP(nn.Module):
    """Per-object MLP that embeds each slot's flattened feature map.

    The input is reshaped to ``[-1, num_objects, input_dim]``. A 3-layer MLP
    (LayerNorm after the second layer) is then applied to every object
    independently.

    Args:
        input_dim: Flattened per-object input size.
        output_dim: Embedding size per object.
        hidden_dim: Hidden units of the MLP.
        num_objects: Number of object slots.
        act_fn: Activation name for ``utils.get_act_fn``.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, num_objects,
                 act_fn='relu'):
        super().__init__()
        self.num_objects = num_objects
        self.input_dim = input_dim

        # Registration order preserved (initialisation and state_dict order).
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc3 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.ln = nn.LayerNorm(hidden_dim)

        self.act1 = utils.get_act_fn(act_fn)
        self.act2 = utils.get_act_fn(act_fn)

    def forward(self, ins):
        per_object = ins.view(-1, self.num_objects, self.input_dim)
        hidden = self.act1(self.fc1(per_object))
        hidden = self.fc2(hidden)
        hidden = self.act2(self.ln(hidden))
        return self.fc3(hidden)
class DecoderMLP(nn.Module):
    """MLP decoder, maps latent state to image.

    Each object embedding is concatenated with a one-hot object id and
    decoded by a 3-layer MLP to a flattened image. The per-object outputs
    are summed over the object axis.

    Args:
        input_dim: Per-object embedding size.
        hidden_dim: Hidden units of the MLP.
        num_objects: Number of object slots (also the one-hot id size).
        output_size: Output image shape; ``forward`` reshapes to
            ``(-1, output_size[0], output_size[1], output_size[2])``.
        act_fn: Activation name for ``utils.get_act_fn``.
    """

    def __init__(self, input_dim, hidden_dim, num_objects, output_size,
                 act_fn='relu'):
        super(DecoderMLP, self).__init__()

        self.fc1 = nn.Linear(input_dim + num_objects, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, np.prod(output_size))

        self.input_dim = input_dim
        self.num_objects = num_objects
        self.output_size = output_size

        self.act1 = utils.get_act_fn(act_fn)
        self.act2 = utils.get_act_fn(act_fn)

    def forward(self, ins):
        """Decode object embeddings into images.

        ``ins`` is assumed to be ``[batch, num_objects, input_dim]``. The
        one-hot ids are tiled ``ins.size(0)`` times and concatenated on the
        last axis.
        """
        obj_ids = torch.arange(self.num_objects)
        obj_ids = utils.to_one_hot(obj_ids, self.num_objects).unsqueeze(0)
        # Move to the input's device. ``Tensor.get_device()`` returns -1 for
        # CPU tensors and ``.to(-1)`` raises, so the old code failed on CPU.
        obj_ids = obj_ids.repeat((ins.size(0), 1, 1)).to(ins.device)

        h = torch.cat((ins, obj_ids), -1)
        h = self.act1(self.fc1(h))
        h = self.act2(self.fc2(h))
        # Sum the per-object image contributions.
        h = self.fc3(h).sum(1)
        return h.view(-1, self.output_size[0], self.output_size[1],
                      self.output_size[2])
class DecoderCNNSmall(nn.Module):
    """CNN decoder that inverts ``EncoderCNNSmall``'s 10x downsampling.

    A per-object MLP (LayerNorm after the second layer) emits one small map
    per object with spatial size ``output_size[1:] // 10``. The maps are
    stacked as channels. A 1x1 transposed conv plus activation follows, then
    a 10x10 / stride-10 transposed conv upsamples to ``output_size[0]``
    channels.

    Args:
        input_dim: Per-object embedding size.
        hidden_dim: Hidden units of the MLP and channels of the deconv stack.
        num_objects: Number of object slots (channels of the small maps).
        output_size: Target image shape; indices 1 and 2 are the spatial
            sizes (called width and height here).
        act_fn: Activation name for ``utils.get_act_fn``.
    """

    def __init__(self, input_dim, hidden_dim, num_objects, output_size,
                 act_fn='relu'):
        super().__init__()

        map_w = output_size[1] // 10
        map_h = output_size[2] // 10

        # Registration order preserved (initialisation and state_dict order).
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc3 = nn.Linear(in_features=hidden_dim,
                             out_features=map_w * map_h)
        self.ln = nn.LayerNorm(hidden_dim)

        self.deconv1 = nn.ConvTranspose2d(in_channels=num_objects,
                                          out_channels=hidden_dim,
                                          kernel_size=1, stride=1)
        self.deconv2 = nn.ConvTranspose2d(in_channels=hidden_dim,
                                          out_channels=output_size[0],
                                          kernel_size=10, stride=10)

        self.input_dim = input_dim
        self.num_objects = num_objects
        self.map_size = (output_size[0], map_w, map_h)

        self.act1 = utils.get_act_fn(act_fn)
        self.act2 = utils.get_act_fn(act_fn)
        self.act3 = utils.get_act_fn(act_fn)

    def forward(self, ins):
        hidden = self.act1(self.fc1(ins))
        hidden = self.act2(self.ln(self.fc2(hidden)))
        _, map_w, map_h = self.map_size
        maps = self.fc3(hidden).view(-1, self.num_objects, map_w, map_h)
        upsampled = self.act3(self.deconv1(maps))
        return self.deconv2(upsampled)
class DecoderCNNMedium(nn.Module):
    """CNN decoder that inverts ``EncoderCNNMedium``'s 5x downsampling.

    A per-object MLP (LayerNorm after the second layer) emits one map per
    object with spatial size ``output_size[1:] // 5``. A 5x5 / stride-5
    transposed conv with BatchNorm and activation upsamples the stacked
    maps. A padded 9x9 transposed conv then produces ``output_size[0]``
    channels at the same size.

    Args:
        input_dim: Per-object embedding size.
        hidden_dim: Hidden units of the MLP and channels of the deconv stack.
        num_objects: Number of object slots (channels of the small maps).
        output_size: Target image shape; indices 1 and 2 are the spatial
            sizes (called width and height here).
        act_fn: Activation name for ``utils.get_act_fn``.
    """

    def __init__(self, input_dim, hidden_dim, num_objects, output_size,
                 act_fn='relu'):
        super().__init__()

        map_w = output_size[1] // 5
        map_h = output_size[2] // 5

        # Registration order preserved (initialisation and state_dict order).
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.fc3 = nn.Linear(in_features=hidden_dim,
                             out_features=map_w * map_h)
        self.ln = nn.LayerNorm(hidden_dim)

        self.deconv1 = nn.ConvTranspose2d(in_channels=num_objects,
                                          out_channels=hidden_dim,
                                          kernel_size=5, stride=5)
        self.deconv2 = nn.ConvTranspose2d(in_channels=hidden_dim,
                                          out_channels=output_size[0],
                                          kernel_size=9, padding=4)

        # BatchNorm2d despite the ``ln`` name (kept for state_dict keys).
        self.ln1 = nn.BatchNorm2d(num_features=hidden_dim)

        self.input_dim = input_dim
        self.num_objects = num_objects
        self.map_size = (output_size[0], map_w, map_h)

        self.act1 = utils.get_act_fn(act_fn)
        self.act2 = utils.get_act_fn(act_fn)
        self.act3 = utils.get_act_fn(act_fn)

    def forward(self, ins):
        hidden = self.act1(self.fc1(ins))
        hidden = self.act2(self.ln(self.fc2(hidden)))
        _, map_w, map_h = self.map_size
        maps = self.fc3(hidden).view(-1, self.num_objects, map_w, map_h)
        upsampled = self.act3(self.ln1(self.deconv1(maps)))
        return self.deconv2(upsampled)
class DecoderCNNLarge(nn.Module):
    """CNN decoder, maps latent state to image (mirror of ``EncoderCNNLarge``).

    A per-object MLP (LayerNorm after the second layer) emits one full-size
    map per object. The stacked maps pass through three padded 3x3
    transposed convs, each followed by its own BatchNorm and activation. A
    final padded 3x3 transposed conv produces ``output_size[0]`` channels.
    Spatial size is unchanged throughout.

    Args:
        input_dim: Per-object embedding size.
        hidden_dim: Hidden units of the MLP and channels of the deconv stack.
        num_objects: Number of object slots (channels of the stacked maps).
        output_size: Target image shape; indices 1 and 2 are the spatial
            sizes (called width and height here).
        act_fn: Activation name for ``utils.get_act_fn``.
    """

    def __init__(self, input_dim, hidden_dim, num_objects, output_size,
                 act_fn='relu'):
        super(DecoderCNNLarge, self).__init__()

        width, height = output_size[1], output_size[2]
        output_dim = width * height

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.ln = nn.LayerNorm(hidden_dim)

        self.deconv1 = nn.ConvTranspose2d(num_objects, hidden_dim,
                                          kernel_size=3, padding=1)
        self.deconv2 = nn.ConvTranspose2d(hidden_dim, hidden_dim,
                                          kernel_size=3, padding=1)
        self.deconv3 = nn.ConvTranspose2d(hidden_dim, hidden_dim,
                                          kernel_size=3, padding=1)
        self.deconv4 = nn.ConvTranspose2d(hidden_dim, output_size[0],
                                          kernel_size=3, padding=1)

        # One BatchNorm2d per hidden deconv layer (``ln`` names kept for
        # state_dict compatibility).
        self.ln1 = nn.BatchNorm2d(hidden_dim)
        self.ln2 = nn.BatchNorm2d(hidden_dim)
        self.ln3 = nn.BatchNorm2d(hidden_dim)

        self.input_dim = input_dim
        self.num_objects = num_objects
        self.map_size = output_size[0], width, height

        self.act1 = utils.get_act_fn(act_fn)
        self.act2 = utils.get_act_fn(act_fn)
        self.act3 = utils.get_act_fn(act_fn)
        self.act4 = utils.get_act_fn(act_fn)
        self.act5 = utils.get_act_fn(act_fn)

    def forward(self, ins):
        """Decode ``[batch, num_objects, input_dim]`` embeddings to images."""
        h = self.act1(self.fc1(ins))
        h = self.act2(self.ln(self.fc2(h)))
        h = self.fc3(h)

        h_conv = h.view(-1, self.num_objects, self.map_size[1],
                        self.map_size[2])
        # Each deconv gets its own BatchNorm. Previously ``ln1`` was applied
        # after all three layers, so one set of running statistics was shared
        # across different activations and ``ln2``/``ln3`` were never used.
        # NOTE: decoders trained with the old wiring hold untrained
        # ``ln2``/``ln3`` statistics and will behave differently after this
        # fix; retrain them.
        h = self.act3(self.ln1(self.deconv1(h_conv)))
        h = self.act4(self.ln2(self.deconv2(h)))
        h = self.act5(self.ln3(self.deconv3(h)))
        return self.deconv4(h)