forked from open-mmlab/mmsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpoint_head.py
367 lines (316 loc) · 14.9 KB
/
point_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
try:
from mmcv.ops import point_sample
except ModuleNotFoundError:
point_sample = None
from typing import List
from mmseg.registry import MODELS
from mmseg.utils import SampleList
from ..losses import accuracy
from ..utils import resize
from .cascade_decode_head import BaseCascadeDecodeHead
def calculate_uncertainty(seg_logits):
    """Estimate uncertainty based on seg logits.

    For each location of the prediction ``seg_logits`` we estimate
    uncertainty as the difference between the second highest and the highest
    predicted logits along the class dimension (dim 1). The result is always
    ``<= 0``; values closer to zero mean the two best classes are nearly
    tied, i.e. the location is more uncertain.

    When the class dimension has a single channel there is no second logit
    to compare with (``torch.topk`` with ``k=2`` would raise), so the
    uncertainty falls back to ``-abs(logit)``: a logit near zero is the most
    uncertain.

    Args:
        seg_logits (Tensor): Semantic segmentation logits with the classes
            on dim 1, e.g. shape (batch_size, num_classes, height, width).

    Returns:
        scores (Tensor): Uncertainty scores with the most uncertain
            locations having the highest uncertainty score. Same shape as
            ``seg_logits`` except that dim 1 has size 1, e.g.
            (batch_size, 1, height, width).
    """
    if seg_logits.shape[1] == 1:
        # Single-logit prediction: confidence grows with the magnitude of
        # the logit, so its negated absolute value is the uncertainty.
        return -seg_logits.abs()

    # topk returns values sorted in descending order: index 0 is the best
    # class score, index 1 the runner-up.
    top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
@MODELS.register_module()
class PointHead(BaseCascadeDecodeHead):
    """A mask point head use in PointRend.

    This head is an implementation of `PointRend: Image Segmentation as
    Rendering <https://arxiv.org/abs/1912.08193>`_.
    ``PointHead`` uses a shared multi-layer perceptron (equivalent to
    nn.Conv1d with kernel size 1) to predict the logits of input points. The
    fine-grained features and the coarse features are concatenated together
    for prediction.

    Args:
        num_fcs (int): Number of fc layers in the head. Default: 3.
        coarse_pred_each_layer (bool): Whether to concatenate the coarse
            feature with the output of each fc layer. Default: True.
        conv_cfg (dict|None): Dictionary to construct and config conv layer.
            Default: dict(type='Conv1d').
        norm_cfg (dict|None): Dictionary to construct and config norm layer.
            Default: None.
        act_cfg (dict|None): Dictionary to construct and config the
            activation layer. Default: dict(type='ReLU', inplace=False).
        **kwargs: Forwarded unchanged to ``BaseCascadeDecodeHead.__init__``.
            NOTE(review): the head afterwards reads ``self.in_channels``
            (summed, so a sequence of ints), ``self.channels``,
            ``self.num_classes``, ``self.dropout_ratio``,
            ``self.align_corners``, ``self.ignore_index`` and
            ``self.loss_decode``; presumably the base class sets these from
            ``kwargs`` - confirm against the base class.

    Raises:
        RuntimeError: If ``mmcv.ops.point_sample`` could not be imported.
    """

    def __init__(self,
                 num_fcs=3,
                 coarse_pred_each_layer=True,
                 conv_cfg=dict(type='Conv1d'),
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU', inplace=False),
                 **kwargs):
        super().__init__(
            input_transform='multiple_select',
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            init_cfg=dict(
                type='Normal', std=0.01, override=dict(name='fc_seg')),
            **kwargs)
        if point_sample is None:
            raise RuntimeError('Please install mmcv-full for '
                               'point_sample ops')

        self.num_fcs = num_fcs
        self.coarse_pred_each_layer = coarse_pred_each_layer

        # The first fc layer sees the concatenation of all selected
        # fine-grained feature levels and the coarse prediction.
        fc_in_channels = sum(self.in_channels) + self.num_classes
        fc_channels = self.channels
        self.fcs = nn.ModuleList()
        for _ in range(num_fcs):
            fc = ConvModule(
                fc_in_channels,
                fc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
            self.fcs.append(fc)
            # Every following layer (and the classifier) receives the fc
            # output, optionally re-concatenated with the coarse prediction
            # (see ``forward``).
            fc_in_channels = fc_channels
            fc_in_channels += self.num_classes if self.coarse_pred_each_layer \
                else 0
        self.fc_seg = nn.Conv1d(
            fc_in_channels,
            self.num_classes,
            kernel_size=1,
            stride=1,
            padding=0)
        if self.dropout_ratio > 0:
            # Replace the dropout created by the base class with a plain
            # nn.Dropout for the point features handled by Conv1d layers.
            self.dropout = nn.Dropout(self.dropout_ratio)
        # The classifier of this head is ``fc_seg``; drop the ``conv_seg``
        # attribute left by the base class.
        delattr(self, 'conv_seg')

    def cls_seg(self, feat):
        """Classify each point with ``fc_seg``.

        Args:
            feat (Tensor): Point features fed to the classifier.

        Returns:
            Tensor: Output of ``fc_seg``; dropout is applied first when
                ``self.dropout`` is not None.
        """
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.fc_seg(feat)
        return output

    def forward(self, fine_grained_point_feats, coarse_point_feats):
        """Predict point logits from fine-grained and coarse point features.

        Args:
            fine_grained_point_feats (Tensor): Sampled fine-grained features
                (see ``_get_fine_grained_point_feats``).
            coarse_point_feats (Tensor): Sampled coarse prediction (see
                ``_get_coarse_point_feats``). Both inputs are concatenated
                along dim 1, so they must agree on all other dims.

        Returns:
            Tensor: Point logits produced by ``cls_seg``.
        """
        x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)
        for fc in self.fcs:
            x = fc(x)
            if self.coarse_pred_each_layer:
                # Re-inject the coarse prediction after every fc layer.
                x = torch.cat((x, coarse_point_feats), dim=1)
        return self.cls_seg(x)

    def _get_fine_grained_point_feats(self, x, points):
        """Sample from fine grained features.

        Args:
            x (list[Tensor]): Feature pyramid from by neck or backbone.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            fine_grained_feats (Tensor): Sampled fine grained feature,
                shape (batch_size, sum(channels of x), num_points).
        """
        fine_grained_feats_list = [
            point_sample(feat, points, align_corners=self.align_corners)
            for feat in x
        ]
        if len(fine_grained_feats_list) > 1:
            fine_grained_feats = torch.cat(fine_grained_feats_list, dim=1)
        else:
            fine_grained_feats = fine_grained_feats_list[0]

        return fine_grained_feats

    def _get_coarse_point_feats(self, prev_output, points):
        """Sample from the coarse prediction of the previous decode head.

        Args:
            prev_output (Tensor): Prediction of previous decode head.
            points (Tensor): Point coordinates, shape (batch_size,
                num_points, 2).

        Returns:
            coarse_feats (Tensor): Sampled coarse feature, shape (batch_size,
                num_classes, num_points).
        """
        coarse_feats = point_sample(
            prev_output, points, align_corners=self.align_corners)

        return coarse_feats

    def loss(self, inputs, prev_output, batch_data_samples: SampleList,
             train_cfg, **kwargs):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_data_samples (list[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `img_metas` or `gt_semantic_seg`.
            train_cfg (dict): The training config. ``num_points``,
                ``oversample_ratio`` and ``importance_sample_ratio`` are read
                from it by ``get_points_train``.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        x = self._transform_inputs(inputs)
        # Choosing where to sample is not part of the differentiable graph.
        with torch.no_grad():
            points = self.get_points_train(
                prev_output, calculate_uncertainty, cfg=train_cfg)
        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, points)
        coarse_point_feats = self._get_coarse_point_feats(prev_output, points)
        point_logits = self.forward(fine_grained_point_feats,
                                    coarse_point_feats)

        losses = self.loss_by_feat(point_logits, points, batch_data_samples)

        return losses

    def predict(self, inputs, prev_output, batch_img_metas: List[dict],
                test_cfg, **kwargs):
        """Forward function for testing.

        Starting from ``prev_output``, the prediction is repeatedly
        upsampled and the logits of its most uncertain locations are
        replaced by the point head prediction.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            batch_img_metas (list[dict]): List of image info dicts, passed
                on to ``predict_by_feat``.
            test_cfg (dict): The testing config. ``subdivision_steps`` and
                ``scale_factor`` are read here, ``subdivision_num_points``
                in ``get_points_test``.

        Returns:
            Tensor: Output segmentation map.
        """
        x = self._transform_inputs(inputs)
        # Work on a copy: the refinement below writes in place.
        refined_seg_logits = prev_output.clone()
        for _ in range(test_cfg.subdivision_steps):
            refined_seg_logits = resize(
                refined_seg_logits,
                scale_factor=test_cfg.scale_factor,
                mode='bilinear',
                align_corners=self.align_corners)
            batch_size, channels, height, width = refined_seg_logits.shape
            point_indices, points = self.get_points_test(
                refined_seg_logits, calculate_uncertainty, cfg=test_cfg)
            fine_grained_point_feats = self._get_fine_grained_point_feats(
                x, points)
            coarse_point_feats = self._get_coarse_point_feats(
                prev_output, points)
            point_logits = self.forward(fine_grained_point_feats,
                                        coarse_point_feats)

            # Broadcast the flat indices over the class dim so the logits of
            # every class are overwritten at the selected locations.
            point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
            refined_seg_logits = refined_seg_logits.reshape(
                batch_size, channels, height * width)
            refined_seg_logits = refined_seg_logits.scatter_(
                2, point_indices, point_logits)
            refined_seg_logits = refined_seg_logits.view(
                batch_size, channels, height, width)

        return self.predict_by_feat(refined_seg_logits, batch_img_metas,
                                    **kwargs)

    def loss_by_feat(self, point_logits, points, batch_data_samples, **kwargs):
        """Compute segmentation loss.

        Args:
            point_logits (Tensor): Logits predicted for the sampled points.
            points (Tensor): Point coordinates used to sample the ground
                truth, shape (batch_size, num_points, 2).
            batch_data_samples (list[:obj:`SegDataSample`]): The seg data
                samples the ground truth is stacked from.

        Returns:
            dict[str, Tensor]: One entry ``'point' + loss_name`` per distinct
                loss name (losses sharing a name are summed) plus
                ``'acc_point'``.
        """
        gt_semantic_seg = self._stack_batch_gt(batch_data_samples)
        # Labels are sampled with nearest interpolation so that they stay
        # valid class ids; the float round trip is only for the sampler.
        point_label = point_sample(
            gt_semantic_seg.float(),
            points,
            mode='nearest',
            align_corners=self.align_corners)
        point_label = point_label.squeeze(1).long()

        loss = dict()
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_module in losses_decode:
            loss_key = 'point' + loss_module.loss_name
            loss_value = loss_module(
                point_logits, point_label, ignore_index=self.ignore_index)
            if loss_key in loss:
                # Several loss modules may share a name; sum them instead of
                # silently discarding the earlier ones.
                loss[loss_key] = loss[loss_key] + loss_value
            else:
                loss[loss_key] = loss_value

        loss['acc_point'] = accuracy(
            point_logits, point_label, ignore_index=self.ignore_index)
        return loss

    def get_points_train(self, seg_logits, uncertainty_func, cfg):
        """Sample points for training.

        Sample points in [0, 1] x [0, 1] coordinate space based on their
        uncertainty. The uncertainties are calculated for each point using
        'uncertainty_func' function that takes point's logit prediction as
        input.

        Args:
            seg_logits (Tensor): Semantic segmentation logits, shape (
                batch_size, num_classes, height, width).
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Training config of point head. ``num_points``,
                ``oversample_ratio`` (must be >= 1) and
                ``importance_sample_ratio`` (must be in [0, 1]) are read.

        Returns:
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains the coordinates of ``num_points`` sampled
                points.
        """
        num_points = cfg.num_points
        oversample_ratio = cfg.oversample_ratio
        importance_sample_ratio = cfg.importance_sample_ratio
        assert oversample_ratio >= 1
        assert 0 <= importance_sample_ratio <= 1
        batch_size = seg_logits.shape[0]
        # Draw more random candidates than needed, then keep the most
        # uncertain ones.
        num_sampled = int(num_points * oversample_ratio)
        point_coords = torch.rand(
            batch_size, num_sampled, 2, device=seg_logits.device)
        point_logits = point_sample(seg_logits, point_coords)
        # It is crucial to calculate uncertainty based on the sampled
        # prediction value for the points. Calculating uncertainties of the
        # coarse predictions first and sampling them for points leads to
        # incorrect results.  To illustrate this: assume uncertainty func(
        # logits)=-abs(logits), a sampled point between two coarse
        # predictions with -1 and 1 logits has 0 logits, and therefore 0
        # uncertainty value. However, if we calculate uncertainties for the
        # coarse predictions first, both will have -1 uncertainty,
        # and sampled point will get -1 uncertainty.
        point_uncertainties = uncertainty_func(point_logits)
        num_uncertain_points = int(importance_sample_ratio * num_points)
        num_random_points = num_points - num_uncertain_points
        idx = torch.topk(
            point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
        # Turn per-sample indices into indices of the flattened
        # (batch_size * num_sampled) candidate list.
        shift = num_sampled * torch.arange(
            batch_size, dtype=torch.long, device=seg_logits.device)
        idx += shift[:, None]
        point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
            batch_size, num_uncertain_points, 2)
        if num_random_points > 0:
            # Fill the remaining budget with uniformly random points.
            rand_point_coords = torch.rand(
                batch_size, num_random_points, 2, device=seg_logits.device)
            point_coords = torch.cat((point_coords, rand_point_coords), dim=1)
        return point_coords

    def get_points_test(self, seg_logits, uncertainty_func, cfg):
        """Sample points for testing.

        Find ``num_points`` most uncertain points from ``uncertainty_map``.

        Args:
            seg_logits (Tensor): A tensor of shape (batch_size, num_classes,
                height, width) for class-specific or class-agnostic
                prediction.
            uncertainty_func (func): uncertainty calculation function.
            cfg (dict): Testing config of point head.
                ``subdivision_num_points`` is read.

        Returns:
            point_indices (Tensor): A tensor of shape (batch_size, num_points)
                that contains indices from [0, height x width) of the most
                uncertain points.
            point_coords (Tensor): A tensor of shape (batch_size, num_points,
                2) that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the ``height x width`` grid.
        """
        num_points = cfg.subdivision_num_points
        uncertainty_map = uncertainty_func(seg_logits)
        batch_size, _, height, width = uncertainty_map.shape
        h_step = 1.0 / height
        w_step = 1.0 / width

        uncertainty_map = uncertainty_map.view(batch_size, height * width)
        # Never ask for more points than there are grid locations.
        num_points = min(height * width, num_points)
        point_indices = uncertainty_map.topk(num_points, dim=1)[1]
        point_coords = torch.zeros(
            batch_size,
            num_points,
            2,
            dtype=torch.float,
            device=seg_logits.device)
        # Convert flat indices to normalized (x, y) coordinates of the cell
        # centres: column -> x, row -> y.
        point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
                                                width).float() * w_step
        point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
                                                width).float() * h_step
        return point_indices, point_coords