-
Notifications
You must be signed in to change notification settings - Fork 174
/
sconet.py
53 lines (44 loc) · 1.86 KB
/
sconet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks
from einops import rearrange
import numpy as np
class ScoNet(BaseModel):
    """Set-wise backbone followed by part-based FC and BN-neck heads.

    ``forward`` produces two outputs:

    * the FC-head embeddings, returned under ``'triplet'``;
    * the BN-neck logits, returned under ``'softmax'`` and reused as the
      inference feature.

    The softmax labels are integer indices derived from string status labels.
    """

    # Maps a string status from ``class_id`` to an integer class index.
    # Any status not listed here falls back to index 0 (same as the original
    # inline ``... else 0`` expression).
    _STATUS_TO_LABEL = {'positive': 1, 'neutral': 2}

    def build_network(self, model_cfg):
        """Instantiate the sub-modules from ``model_cfg``.

        Reads the keys ``'backbone_cfg'``, ``'SeparateFCs'``,
        ``'SeparateBNNecks'`` and ``'bin_num'``.
        """
        # The raw backbone is wrapped by SetBlockWrapper; presumably this
        # applies it frame-by-frame over the sequence axis -- TODO confirm.
        self.Backbone = SetBlockWrapper(
            self.get_backbone(model_cfg['backbone_cfg']))
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        # Temporal pooling via torch.max; the pooled axis is chosen in forward.
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

    def forward(self, inputs):
        """Run one batch through the network.

        Args:
            inputs: 5-tuple ``(ipts, labs, class_id, _, seqL)``.

                * ``ipts[0]`` is the silhouette tensor. It is either 4-D (a
                  size-1 channel axis is inserted at dim 1) or 5-D laid out
                  as ``n s c h w``.
                * ``class_id`` is an iterable of status strings.
                * ``seqL`` is forwarded to temporal pooling.

        Returns:
            dict with three keys:

            * ``'training_feat'``: triplet embeddings paired with ``labs``,
              and softmax logits paired with int64 status indices.
            * ``'visual_summary'``: the silhouettes flattened over
              batch x frames.
            * ``'inference_feat'``: the logits.
        """
        ipts, labs, class_id, _, seqL = inputs

        sils = ipts[0]
        if sils.dim() == 4:
            sils = sils.unsqueeze(1)
        else:
            sils = rearrange(sils, 'n s c h w -> n c s h w')
        del ipts

        # Encode status strings as class indices. Building the tensor
        # directly with dtype=torch.long avoids NumPy's platform-dependent
        # default integer, which is int32 on Windows with NumPy < 2.0.
        # PyTorch's cross-entropy expects int64 class indices. It also keeps
        # an empty batch integer-typed rather than float64.
        # Using sils.device instead of a hard-coded .cuda() keeps the labels
        # on the same device as the inputs, including CPU.
        class_id = torch.tensor(
            [self._STATUS_TO_LABEL.get(status, 0) for status in class_id],
            dtype=torch.long,
            device=sils.device,
        )

        outs = self.Backbone(sils)  # [n, c, s, h, w]
        # Temporal Pooling, TP. torch.max over a dim returns (values, indices);
        # [0] presumably keeps the pooled values -- TODO confirm against
        # PackSequenceWrapper.
        outs = self.TP(outs, seqL, options={"dim": 2})[0]  # [n, c, h, w]
        # Horizontal Pooling Matching, HPM
        feat = self.HPP(outs)  # [n, c, p]
        embed_1 = self.FCs(feat)  # [n, c, p]
        # The BN-neck's own embedding output is not used; only its logits
        # are kept.
        _, logits = self.BNNecks(embed_1)

        return {
            'training_feat': {
                'triplet': {'embeddings': embed_1, 'labels': labs},
                'softmax': {'logits': logits, 'labels': class_id},
            },
            'visual_summary': {
                'image/sils': rearrange(sils, 'n c s h w -> (n s) c h w')
            },
            'inference_feat': {
                'embeddings': logits
            }
        }