SOLO Code Walkthrough
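SOLO is a paradigm that directly predicts instance masks. It departs from the two mainstream approaches to instance segmentation, top-down (detect-then-segment) and bottom-up (embed-then-group), which makes the pipeline simpler and more intuitive. This post uses the demo shipped with the official code to briefly walk through SOLO's inference flow. The whole codebase is built on mmdetection.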
First, demo.inference_demo.py:
from mmdet.apis import init_detector, inference_detector, show_result_ins

config_file = '../configs/solo/decoupled_solo_r50_fpn_8gpu_3x.py'
# download the checkpoint from model zoo and put it in `checkpoints/`
checkpoint_file = '../checkpoints/DECOUPLED_SOLO_R50_3x.pth'
# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:0')
# test a single image
img = 'demo.jpg'
result = inference_detector(model, img)
show_result_ins(img, result, model.CLASSES, score_thr=0.25, out_file="demo_out.jpg")
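The code above is simple: init_detector builds the model, inference_detector runs the forward pass, and show_result_ins draws the final result. The core pieces are init_detector and inference_detector, both defined in mmdet.apis. Let's look at that module: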
mmdet.apis.inference.py
import warnings

import mmcv
import torch
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint

from mmdet.core import get_classes
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector


def init_detector(config, checkpoint=None, device='cuda:0'):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the model
            will not load any weights.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    config.model.pretrained = None
    model = build_detector(config.model, test_cfg=config.test_cfg)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint)
        if 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model
def inference_detector(model, img):
    """Inference an image with the detector.

    Args:
        model (nn.Module): The loaded detector.
        img (str or ndarray): Either an image file path or a loaded image.

    Returns:
        The detection results.
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline; LoadImage is a small helper class defined in
    # this module (not shown) that replaces the first step of the test pipeline
    test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img)
    data = test_pipeline(data)
    data = scatter(collate([data], samples_per_gpu=1), [device])[0]
    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result
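The core of init_detector is build_detector, which constructs the model from the config; the checkpoint weights are then loaded into it. inference_detector is even simpler: it runs the image through the test data pipeline (loading plus preprocessing such as resizing and normalization, as configured), then calls the model to do inference.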
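The Compose step in inference_detector chains the pipeline transforms over a single results dict. A rough sketch of that idea in plain Python, not mmdetection's class, assuming each transform takes the dict and returns it, or returns None to drop the sample. ComposeSketch, load_image_sketch and resize_sketch are hypothetical names, and the resize rule is deliberately simplified:

```python
from typing import Callable, Dict, List, Optional

Transform = Callable[[Dict], Optional[Dict]]


class ComposeSketch:
    """Apply transforms in order to a results dict (illustrative, not mmdet's Compose)."""

    def __init__(self, transforms: List[Transform]):
        self.transforms = transforms

    def __call__(self, data: Dict) -> Optional[Dict]:
        for t in self.transforms:
            data = t(data)
            if data is None:  # a transform may reject the sample
                return None
        return data


def load_image_sketch(results: Dict) -> Dict:
    # Hypothetical loader: here 'img' is just an (H, W) size tuple.
    results["ori_shape"] = results["img"]
    results["img_shape"] = results["img"]
    return results


def resize_sketch(results: Dict, short: int = 800) -> Dict:
    # Hypothetical keep-ratio resize so that the short side equals `short`.
    h, w = results["img_shape"]
    scale = short / min(h, w)
    results["img_shape"] = (round(h * scale), round(w * scale))
    results["scale_factor"] = scale
    return results


pipeline = ComposeSketch([load_image_sketch, resize_sketch])
out = pipeline({"img": (480, 640)})
print(out["ori_shape"], out["img_shape"], out["scale_factor"])
```

Keeping ori_shape and scale_factor in the same dict is what later lets the model map masks back to the original image.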
From here the story splits into two branches: how build_detector constructs the model, and how that model runs inference. We'll take them one at a time.
build_detector
build_detector is defined in mmdet.models.builder.py:
from torch import nn

from mmdet.utils import build_from_cfg
from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                       ROI_EXTRACTORS, SHARED_HEADS)


def build(cfg, registry, default_args=None):
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_backbone(cfg):
    return build(cfg, BACKBONES)


def build_neck(cfg):
    return build(cfg, NECKS)


def build_roi_extractor(cfg):
    return build(cfg, ROI_EXTRACTORS)


def build_shared_head(cfg):
    return build(cfg, SHARED_HEADS)


def build_head(cfg):
    return build(cfg, HEADS)


def build_loss(cfg):
    return build(cfg, LOSSES)


def build_detector(cfg, train_cfg=None, test_cfg=None):
    return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
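build_detector calls build, which in turn calls build_from_cfg. Note that build_detector passes the DETECTORS registry into build. Registry is a class and DETECTORS is one instance of it; each kind of component (backbone, neck such as FPN, head, loss, etc.) has its own Registry instance. For now, think of a registry as a name-to-class table that lets each kind of module be created and managed separately.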
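To make the registry idea concrete, here is a minimal sketch of a name-to-class registry with a decorator-style register_module. It is modeled loosely on the Registry used by mmdetection v1 but is not that code; ToyResNet and ToySOLO are hypothetical classes:

```python
class Registry:
    """A minimal name -> class table (a sketch, not mmcv's Registry)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self, cls):
        # Used as a bare decorator: @REGISTRY.register_module
        key = cls.__name__
        if key in self._module_dict:
            raise KeyError(f"{key} is already registered in {self.name}")
        self._module_dict[key] = cls
        return cls  # return the class unchanged so the decorated name still works

    def get(self, key):
        return self._module_dict.get(key)


BACKBONES = Registry("backbone")
DETECTORS = Registry("detector")


@BACKBONES.register_module
class ToyResNet:
    def __init__(self, depth=50):
        self.depth = depth


@DETECTORS.register_module
class ToySOLO:
    pass


print(BACKBONES.get("ToyResNet"), DETECTORS.get("ToyResNet"))
```

Because each component kind has its own instance, a name is only found in the registry it was registered into, which is why build_backbone and build_detector pass different registries.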
Next, look at build_from_cfg in mmdet.utils.registry.py:
import inspect

import mmcv


def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        obj: The constructed object.
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    assert isinstance(default_args, dict) or default_args is None
    args = cfg.copy()
    obj_type = args.pop('type')
    if mmcv.is_str(obj_type):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError('{} is not in the {} registry'.format(
                obj_type, registry.name))
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError('type must be a str or valid type, but got {}'.format(
            type(obj_type)))
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)
    return obj_cls(**args)
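This is where a module is actually built from a registry: the "type" string in the config dict is looked up in the given registry to obtain the class obj_cls (e.g. SOLO, ResNet, FPN), and the remaining keys, plus any default_args the config did not set, are passed to its constructor. The lookup only succeeds if a class with that name was registered in that registry beforehand; otherwise a KeyError is raised. At the top level obj_cls is SOLO (see the type field in the config file), so next we need the file that defines the SOLO model: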
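A stripped-down sketch of the same lookup-and-instantiate logic, assuming a plain dict as the registry (a dict also provides .get); build_from_cfg_sketch and ToyDetector are hypothetical. It also shows that default_args, like the train_cfg/test_cfg passed by build_detector, only fill in keys the config leaves unset:

```python
import copy
import inspect


def build_from_cfg_sketch(cfg, registry, default_args=None):
    """Look up cfg['type'] in a dict-like registry and instantiate it with the other keys."""
    assert isinstance(cfg, dict) and "type" in cfg
    args = copy.deepcopy(cfg)  # never mutate the caller's config
    obj_type = args.pop("type")
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(f"{obj_type} is not in the registry")
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(f"type must be a str or class, got {type(obj_type)}")
    # default_args only fill keys the config did not set explicitly
    for name, value in (default_args or {}).items():
        args.setdefault(name, value)
    return obj_cls(**args)


class ToyDetector:
    def __init__(self, backbone, train_cfg=None, test_cfg=None):
        self.backbone = backbone
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg


DETECTORS = {"ToyDetector": ToyDetector}
cfg = dict(type="ToyDetector", backbone=dict(type="ResNet", depth=50))
det = build_from_cfg_sketch(cfg, DETECTORS, dict(train_cfg=None, test_cfg=dict(score_thr=0.1)))
print(type(det).__name__, det.test_cfg)
```

Note that the nested backbone config is still a plain dict at this level; building it is left to the detector's own __init__.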
mmdet.models.detectors.solo.py
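from ..registry import DETECTORS
from .single_stage_ins import SingleStageInsDetector


@DETECTORS.register_module
class SOLO(SingleStageInsDetector):

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        # None is passed for mask_feat_head: SOLO has no separate mask feature head
        super(SOLO, self).__init__(backbone, neck, bbox_head, None, train_cfg,
                                   test_cfg, pretrained)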
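The class is decorated with @DETECTORS.register_module. The decorator runs when the class statement is executed, i.e. when the module is imported, not when a SOLO instance is created: it receives the SOLO class as its argument and registers it in the DETECTORS registry. SOLO inherits from SingleStageInsDetector, so that class is the one to focus on next: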
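A small stand-alone check of when a registering decorator runs, using a hypothetical dict registry and register_module function: the registration is recorded as soon as the class statement executes, before any instance exists.

```python
registry = {}
events = []


def register_module(cls):
    # Runs once, when the `class` statement below is executed
    events.append(f"registered {cls.__name__}")
    registry[cls.__name__] = cls
    return cls


@register_module
class ToySOLO:
    def __init__(self):
        events.append("instantiated ToySOLO")


print(events)  # ['registered ToySOLO']: no instance has been created yet
model = registry["ToySOLO"]()
print(events)  # ['registered ToySOLO', 'instantiated ToySOLO']
```

In practice this means the module defining a class must be imported before build_from_cfg can find it by name; as far as I can tell, mmdetection arranges this by importing its model modules from the package __init__.py files.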
mmdet.models.detectors.single_stage_ins.py
from .. import builder
from ..registry import DETECTORS
from .base import BaseDetector


@DETECTORS.register_module
class SingleStageInsDetector(BaseDetector):

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 mask_feat_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SingleStageInsDetector, self).__init__()
        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)
        if mask_feat_head is not None:
            self.mask_feat_head = builder.build_head(mask_feat_head)
        self.bbox_head = builder.build_head(bbox_head)
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)
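Above is the core of SingleStageInsDetector; the args dict produced by build_from_cfg arrives here as the constructor arguments. Following the config, it builds the backbone, the neck and the bbox_head in turn and stores train_cfg and test_cfg (only test_cfg matters for inference). SOLO passes None for mask_feat_head, so that branch is skipped. Each component is created by the matching builder function, each backed by its own Registry, using the parameters from the config, and the resulting modules all become attributes of this single SingleStageInsDetector. I won't paste the code for each individual module: it simply passes the args through and instantiates the corresponding existing class.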
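The recursive pattern can be mimicked with toy classes: the top-level config becomes keyword arguments, and the detector's __init__ builds every sub-config from its own registry. Everything here (build, ToyResNet, ToyFPN, ToyHead, ToySingleStageInsDetector) is hypothetical; the numbers mirror the SOLO settings visible in this post (81 classes including background, grids 40/36/24/16/12) and the usual ResNet-50 channel counts:

```python
def build(cfg, registry):
    args = dict(cfg)  # shallow copy so pop() does not touch the caller's dict
    cls = registry[args.pop("type")]
    return cls(**args)


class ToyResNet:
    def __init__(self, depth):
        self.depth = depth


class ToyFPN:
    def __init__(self, in_channels, out_channels, num_outs):
        self.in_channels, self.out_channels, self.num_outs = in_channels, out_channels, num_outs


class ToyHead:
    def __init__(self, num_classes, num_grids):
        self.num_classes, self.num_grids = num_classes, num_grids


BACKBONES = {"ToyResNet": ToyResNet}
NECKS = {"ToyFPN": ToyFPN}
HEADS = {"ToyHead": ToyHead}


class ToySingleStageInsDetector:
    """Same shape as SingleStageInsDetector.__init__: one registry per component kind."""

    def __init__(self, backbone, neck=None, bbox_head=None, mask_feat_head=None,
                 train_cfg=None, test_cfg=None):
        self.backbone = build(backbone, BACKBONES)
        if neck is not None:
            self.neck = build(neck, NECKS)
        if mask_feat_head is not None:
            self.mask_feat_head = build(mask_feat_head, HEADS)
        self.bbox_head = build(bbox_head, HEADS)
        self.train_cfg, self.test_cfg = train_cfg, test_cfg


model_cfg = dict(
    backbone=dict(type="ToyResNet", depth=50),
    neck=dict(type="ToyFPN", in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5),
    bbox_head=dict(type="ToyHead", num_classes=81, num_grids=[40, 36, 24, 16, 12]),
)
det = ToySingleStageInsDetector(**model_cfg, test_cfg=dict(score_thr=0.1))
print(det.backbone.depth, det.neck.num_outs, det.bbox_head.num_grids, hasattr(det, "mask_feat_head"))
```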
That is roughly it for model construction. Next, let's look at inference.
Inference
SOLO inherits forward from BaseDetector; the relevant methods are:
from mmdet.core import auto_fp16


def forward_test(self, imgs, img_metas, **kwargs):
    ...
    if num_augs == 1:
        return self.simple_test(imgs[0], img_metas[0], **kwargs)
    else:
        return self.aug_test(imgs, img_metas, **kwargs)


@auto_fp16(apply_to=('img', ))
def forward(self, img, img_meta, return_loss=True, **kwargs):
    if return_loss:
        return self.forward_train(img, img_meta, **kwargs)
    else:
        return self.forward_test(img, img_meta, **kwargs)
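A toy model of this dispatch, assuming (as in mmdetection v1's BaseDetector) that imgs and img_metas are lists with one entry per test-time augmentation, so num_augs is their length. The class and its return strings are hypothetical:

```python
class ToyBaseDetector:
    """return_loss picks train vs. test; the number of augmented views picks simple vs. aug test."""

    def forward(self, img, img_meta, return_loss=True, **kwargs):
        if return_loss:
            return self.forward_train(img, img_meta, **kwargs)
        return self.forward_test(img, img_meta, **kwargs)

    def forward_train(self, img, img_meta, **kwargs):
        return "forward_train"

    def forward_test(self, imgs, img_metas, **kwargs):
        # one list entry per test-time augmentation (scale / flip)
        num_augs = len(imgs)
        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        return self.aug_test(imgs, img_metas, **kwargs)

    def simple_test(self, img, img_meta, rescale=False):
        return f"simple_test(rescale={rescale})"

    def aug_test(self, imgs, img_metas, rescale=False):
        return f"aug_test({len(imgs)} views)"


det = ToyBaseDetector()
print(det.forward(["img"], [{}], return_loss=False, rescale=True))         # simple_test(rescale=True)
print(det.forward(["img", "img_flipped"], [{}, {}], return_loss=False))   # aug_test(2 views)
```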
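With a single test-time view (no multi-scale or flip augmentation, so num_augs == 1), forward_test calls simple_test, which SingleStageInsDetector overrides as follows: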
def extract_feat(self, img):
    x = self.backbone(img)
    if self.with_neck:
        x = self.neck(x)
    return x


def simple_test(self, img, img_meta, rescale=False):
    x = self.extract_feat(img)
    outs = self.bbox_head(x, eval=True)
    if self.with_mask_feat_head:
        mask_feat_pred = self.mask_feat_head(
            x[self.mask_feat_head.
              start_level:self.mask_feat_head.end_level + 1])
        seg_inputs = outs + (mask_feat_pred, img_meta, self.test_cfg, rescale)
    else:
        seg_inputs = outs + (img_meta, self.test_cfg, rescale)
    seg_result = self.bbox_head.get_seg(*seg_inputs)
    return seg_result
Inference runs backbone -> neck -> bbox_head: the backbone is ResNet-50, the neck is FPN, and bbox_head is the (decoupled) SOLO head. The feature-extraction part is straightforward, so I won't dwell on it; the interesting part is bbox_head:
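The printed feature shapes in the head's forward can be reproduced with plain arithmetic. This sketch assumes the input was resized and padded to 800×1216 (inferred from the printed shapes) and that the fifth FPN level comes from a stride-2 max-pool with kernel size 1 on the fourth, which is how mmdetection's FPN adds extra levels when no extra convs are configured. fpn_level_sizes and split_feats_sizes are hypothetical helpers:

```python
import math


def fpn_level_sizes(h, w, strides=(4, 8, 16, 32)):
    """Spatial sizes of the five FPN outputs for an (h, w) input padded to a multiple of 32."""
    sizes = [(math.ceil(h / s), math.ceil(w / s)) for s in strides]
    # Assumed extra level: kernel-1, stride-2 max-pool, output floor((n - 1) / 2) + 1
    ph, pw = sizes[-1]
    sizes.append(((ph - 1) // 2 + 1, (pw - 1) // 2 + 1))
    return sizes


def split_feats_sizes(sizes):
    """Sizes after split_feats: halve level 0, resize level 4 to level 3's size."""
    h0, w0 = sizes[0]
    # F.interpolate with scale_factor=0.5 floors the output size
    new_sizes = [(math.floor(h0 * 0.5), math.floor(w0 * 0.5))] + sizes[1:4] + [sizes[3]]
    upsampled_size = (new_sizes[0][0] * 2, new_sizes[0][1] * 2)
    return new_sizes, upsampled_size


levels = fpn_level_sizes(800, 1216)
new_levels, upsampled = split_feats_sizes(levels)
print(levels)                 # [(200, 304), (100, 152), (50, 76), (25, 38), (13, 19)]
print(new_levels, upsampled)  # [(100, 152), (100, 152), (50, 76), (25, 38), (25, 38)] (200, 304)
```

Halving the finest level and enlarging the coarsest one means the head effectively works at strides 8, 8, 16, 32, 32, which matches the self.strides value printed in get_seg_single.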
mmdet.models.anchor_heads.decoupled_solo_head.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmdet.core import matrix_nms, multi_apply
from ..builder import build_loss
from ..registry import HEADS
from ..utils import ConvModule
# points_nms, used in forward_single, is a small helper from the SOLO head code (import not shown)


@HEADS.register_module
class DecoupledSOLOHead(nn.Module):

    def __init__(self,
                 num_classes,
                 in_channels,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
                 sigma=0.4,
                 num_grids=None,
                 cate_down_pos=0,
                 with_deform=False,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None):
        super(DecoupledSOLOHead, self).__init__()
        self.num_classes = num_classes
        self.seg_num_grids = num_grids
        self.cate_out_channels = self.num_classes - 1  # num_classes includes background
        self.in_channels = in_channels
        self.seg_feat_channels = seg_feat_channels
        self.stacked_convs = stacked_convs
        self.strides = strides
        self.sigma = sigma
        self.cate_down_pos = cate_down_pos
        self.base_edge_list = base_edge_list
        self.scale_ranges = scale_ranges
        self.with_deform = with_deform
        self.loss_cate = build_loss(loss_cate)
        self.ins_loss_weight = loss_ins['loss_weight']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self._init_layers()

    def _init_layers(self):
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.ins_convs_x = nn.ModuleList()
        self.ins_convs_y = nn.ModuleList()
        self.cate_convs = nn.ModuleList()
        for i in range(self.stacked_convs):
            # first layer: +1 input channel for the CoordConv coordinate map
            # (+2 in the non-decoupled head, which concatenates both x and y)
            chn = self.in_channels + 1 if i == 0 else self.seg_feat_channels
            # conv + norm block of the ins_x branch
            self.ins_convs_x.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))
            # conv + norm block of the ins_y branch
            self.ins_convs_y.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))
            chn = self.in_channels if i == 0 else self.seg_feat_channels
            # conv + norm block of the category branch (no coordinate channel)
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))
        self.dsolo_ins_list_x = nn.ModuleList()
        self.dsolo_ins_list_y = nn.ModuleList()
        # each level has a different num_grid, so each level gets its own output conv
        for seg_num_grid in self.seg_num_grids:
            self.dsolo_ins_list_x.append(
                nn.Conv2d(
                    self.seg_feat_channels, seg_num_grid, 3, padding=1))
            self.dsolo_ins_list_y.append(
                nn.Conv2d(
                    self.seg_feat_channels, seg_num_grid, 3, padding=1))
        self.dsolo_cate = nn.Conv2d(
            self.seg_feat_channels, self.cate_out_channels, 3, padding=1)
    def forward(self, feats, eval=False):
        # for i in feats:
        #     print(i.shape)
        # torch.Size([1, 256, 200, 304])
        # torch.Size([1, 256, 100, 152])
        # torch.Size([1, 256, 50, 76])
        # torch.Size([1, 256, 25, 38])
        # torch.Size([1, 256, 13, 19])
        new_feats = self.split_feats(feats)
        # for i in new_feats:
        #     print(i[0].shape)
        # torch.Size([256, 100, 152])
        # torch.Size([256, 100, 152])
        # torch.Size([256, 50, 76])
        # torch.Size([256, 25, 38])
        # torch.Size([256, 25, 38])
        featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
        # print(upsampled_size)  (200, 304)
        ins_pred_x, ins_pred_y, cate_pred = multi_apply(self.forward_single, new_feats,
                                                        list(range(len(self.seg_num_grids))),
                                                        eval=eval, upsampled_size=upsampled_size)
        return ins_pred_x, ins_pred_y, cate_pred

    def split_feats(self, feats):
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'),
                feats[1],
                feats[2],
                feats[3],
                F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))
    def forward_single(self, x, idx, eval=False, upsampled_size=None):
        ins_feat = x
        cate_feat = x
        # ins branch
        # concat coord
        x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device)
        y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([ins_feat.shape[0], 1, -1, -1])
        x = x.expand([ins_feat.shape[0], 1, -1, -1])
        # print(ins_feat.shape)
        # print(x.shape)
        ins_feat_x = torch.cat([ins_feat, x], 1)
        ins_feat_y = torch.cat([ins_feat, y], 1)
        # print(ins_feat_x.shape)  (1, 256 + 1, ?, ?)
        for ins_layer_x, ins_layer_y in zip(self.ins_convs_x, self.ins_convs_y):
            ins_feat_x = ins_layer_x(ins_feat_x)
            ins_feat_y = ins_layer_y(ins_feat_y)
        ins_feat_x = F.interpolate(ins_feat_x, scale_factor=2, mode='bilinear')
        ins_feat_y = F.interpolate(ins_feat_y, scale_factor=2, mode='bilinear')
        ins_pred_x = self.dsolo_ins_list_x[idx](ins_feat_x)
        ins_pred_y = self.dsolo_ins_list_y[idx](ins_feat_y)
        # print(ins_pred_x.shape): one channel per grid of this level,
        # (1, 256, ?, ?) -> (1, 40/36/24/16/12, ?, ?)
        # cate branch
        for i, cate_layer in enumerate(self.cate_convs):
            if i == self.cate_down_pos:
                seg_num_grid = self.seg_num_grids[idx]  # idx is the feature level; each level has its own num_grid
                cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear')
            cate_feat = cate_layer(cate_feat)
        cate_pred = self.dsolo_cate(cate_feat)
        # print(cate_pred.shape)  (1, 80, num_grid, num_grid)
        if eval:
            ins_pred_x = F.interpolate(ins_pred_x.sigmoid(), size=upsampled_size, mode='bilinear')
            ins_pred_y = F.interpolate(ins_pred_y.sigmoid(), size=upsampled_size, mode='bilinear')
            cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
        return ins_pred_x, ins_pred_y, cate_pred
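Two steps in forward_single are easy to check outside PyTorch. The NumPy sketch below is a port for illustration, not the repo's code: coord_maps builds the normalized coordinate map that is concatenated as the extra CoordConv channel, and points_nms_sketch reimplements points_nms under the assumption, based on my reading of the SOLO code, that it is a 2×2 max-pool with stride 1 and padding 1 cropped back to the input size, keeping a score only where it equals that local maximum. Both function names are hypothetical:

```python
import numpy as np


def coord_maps(n, h, w):
    """x and y coordinate maps in [-1, 1], each shaped (n, 1, h, w), as built in forward_single."""
    x_range = np.linspace(-1, 1, w)
    y_range = np.linspace(-1, 1, h)
    # indexing="ij" matches torch.meshgrid's default: the first output varies along rows
    y, x = np.meshgrid(y_range, x_range, indexing="ij")
    return np.broadcast_to(x, (n, 1, h, w)), np.broadcast_to(y, (n, 1, h, w))


def points_nms_sketch(heat):
    """heat: (C, H, W) scores. Keep a score only if it is the max of the 2x2 window
    formed by the cell and its top, left and top-left neighbours."""
    c, h, w = heat.shape
    padded = np.full((c, h + 1, w + 1), -np.inf)
    padded[:, 1:, 1:] = heat
    window_max = np.maximum.reduce([
        padded[:, :-1, :-1],  # top-left neighbour
        padded[:, :-1, 1:],   # top neighbour
        padded[:, 1:, :-1],   # left neighbour
        padded[:, 1:, 1:],    # the cell itself
    ])
    return heat * (window_max == heat)


feat = np.zeros((1, 256, 4, 6))
x, y = coord_maps(1, 4, 6)
ins_feat_x = np.concatenate([feat, x], axis=1)  # the "+1" input channel from _init_layers
heat = np.array([[[0.1, 0.9],
                  [0.8, 0.3]]])
print(ins_feat_x.shape)          # (1, 257, 4, 6)
print(points_nms_sketch(heat))   # 0.3 is suppressed by its larger neighbours
```

Because the window only looks up and left, this is a cheap local-peak filter on the S×S category grid rather than a full 3×3 NMS.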
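The code above produces the raw outputs of the SOLO head's forward pass: ins_pred_x, ins_pred_y and cate_pred. This is not yet complete inference; producing the final masks requires the two functions below: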
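The trickiest part of get_seg_single is the index bookkeeping with trans_diff, seg_diff and num_grids. The per-index sketch below is meant to be equivalent, assuming the grid numbers 40/36/24/16/12 from the printed shapes: row k of the concatenated (3872, 80) category map is mapped to its level, its grid column (an index into the 128 concatenated x-maps) and its grid row (an index into the 128 y-maps), and the soft mask is the product of the two maps. flat_to_channels and soft_mask_for_cell are hypothetical names, and the 4×6 random maps stand in for 200×304:

```python
import numpy as np

NUM_GRIDS = [40, 36, 24, 16, 12]  # per-level S, read from the printed shapes


def flat_to_channels(k, num_grids=NUM_GRIDS):
    """Map row k of the concatenated category map to (level, x_channel, y_channel)."""
    cell_offset = 0  # role of trans_diff: grid cells in all previous levels
    chan_offset = 0  # role of seg_diff: x (or y) channels in all previous levels
    for level, s in enumerate(num_grids):
        if k < cell_offset + s * s:
            row, col = divmod(k - cell_offset, s)  # row -> y_inds, col -> x_inds
            return level, chan_offset + col, chan_offset + row
        cell_offset += s * s
        chan_offset += s
    raise IndexError(f"row {k} is outside the {cell_offset} grid cells")


def soft_mask_for_cell(seg_x, seg_y, k):
    """Decoupled SOLO: the soft mask of grid cell k is an x-map times a y-map."""
    _, x_ch, y_ch = flat_to_channels(k)
    return seg_x[x_ch] * seg_y[y_ch]


print(sum(s * s for s in NUM_GRIDS), sum(NUM_GRIDS))  # 3872 128
print(flat_to_channels(1600 + 2 * 36 + 5))            # (1, 45, 42)
rng = np.random.default_rng(0)
seg_x = rng.random((128, 4, 6))
seg_y = rng.random((128, 4, 6))
mask = soft_mask_for_cell(seg_x, seg_y, 1677)          # shape (4, 6)
```

This is why the decoupled head only needs 2 × 128 output maps instead of 3872: each of the S² masks is assembled on demand from one column map and one row map.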
def get_seg(self, seg_preds_x, seg_preds_y, cate_preds, img_metas, cfg, rescale=None):
    assert len(seg_preds_x) == len(cate_preds)
    num_levels = len(cate_preds)
    # print(num_levels)  5
    featmap_size = seg_preds_x[0].size()[-2:]
    # print(featmap_size)  [200, 304]
    # for i in range(5):
    #     print(seg_preds_x[i].shape)
    #     print(cate_preds[i].shape)
    # torch.Size([1, 40, 200, 304])
    # torch.Size([1, 40, 40, 80])
    # torch.Size([1, 36, 200, 304])
    # torch.Size([1, 36, 36, 80])
    # torch.Size([1, 24, 200, 304])
    # torch.Size([1, 24, 24, 80])
    # torch.Size([1, 16, 200, 304])
    # torch.Size([1, 16, 16, 80])
    # torch.Size([1, 12, 200, 304])
    # torch.Size([1, 12, 12, 80])
    result_list = []
    # this is a demo, so there is only one image
    for img_id in range(len(img_metas)):
        cate_pred_list = [
            cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels)
        ]
        # print(cate_pred_list[0].shape)  (num_grid*num_grid, 80)
        seg_pred_list_x = [
            seg_preds_x[i][img_id].detach() for i in range(num_levels)
        ]
        # print(seg_pred_list_x[0].shape)  (num_grid, 200, 304)
        seg_pred_list_y = [
            seg_preds_y[i][img_id].detach() for i in range(num_levels)
        ]
        img_shape = img_metas[img_id]['img_shape']
        scale_factor = img_metas[img_id]['scale_factor']
        ori_shape = img_metas[img_id]['ori_shape']
        cate_pred_list = torch.cat(cate_pred_list, dim=0)  # (3872, 80) == (40^2+36^2+24^2+16^2+12^2, 80)
        seg_pred_list_x = torch.cat(seg_pred_list_x, dim=0)  # (128, 200, 304) == (40+36+24+16+12, 200, 304)
        # print(seg_pred_list_x.shape)
        seg_pred_list_y = torch.cat(seg_pred_list_y, dim=0)
        result = self.get_seg_single(cate_pred_list, seg_pred_list_x, seg_pred_list_y,
                                     featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale)
        result_list.append(result)
    return result_list
def get_seg_single(self,
                   cate_preds,
                   seg_preds_x,
                   seg_preds_y,
                   featmap_size,
                   img_shape,
                   ori_shape,
                   scale_factor,
                   cfg,
                   rescale=False, debug=False):
    # overall info.
    h, w, _ = img_shape
    upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4)  # (800, 1216): the padded network input size
    # per-cell bookkeeping: trans_diff / seg_diff hold the cell and channel offsets of each cell's level
    trans_size = torch.Tensor(self.seg_num_grids).pow(2).cumsum(0).long()  # [1600, 2896, 3472, 3728, 3872]
    trans_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
    num_grids = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
    seg_size = torch.Tensor(self.seg_num_grids).cumsum(0).long()
    seg_diff = torch.ones(trans_size[-1].item(), device=cate_preds.device).long()
    strides = torch.ones(trans_size[-1].item(), device=cate_preds.device)  # [1, 1, ..., 1]
    n_stage = len(self.seg_num_grids)
    trans_diff[:trans_size[0]] *= 0
    seg_diff[:trans_size[0]] *= 0
    num_grids[:trans_size[0]] *= self.seg_num_grids[0]
    # print(self.strides)  [8, 8, 16, 32, 32]
    strides[:trans_size[0]] *= self.strides[0]
    for ind_ in range(1, n_stage):
        trans_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= trans_size[ind_ - 1]
        seg_diff[trans_size[ind_ - 1]:trans_size[ind_]] *= seg_size[ind_ - 1]
        num_grids[trans_size[ind_ - 1]:trans_size[ind_]] *= self.seg_num_grids[ind_]
        strides[trans_size[ind_ - 1]:trans_size[ind_]] *= self.strides[ind_]
    # strides per cell: [0-1599: 8, 1600-2895: 8, 2896-3471: 16, 3472-3871: 32]
    # process.
    inds = (cate_preds > cfg.score_thr)
    # print(inds.shape)  # [3872, 80] boolean mask
    cate_scores = cate_preds[inds]
    # print(cate_scores.shape)  # (n,): the scores above score_thr
    inds = inds.nonzero()
    # print(inds.shape)  # (n, 2): (grid cell index, class index) of each score above the threshold
    trans_diff = torch.index_select(trans_diff, dim=0, index=inds[:, 0])
    seg_diff = torch.index_select(seg_diff, dim=0, index=inds[:, 0])
    num_grids = torch.index_select(num_grids, dim=0, index=inds[:, 0])
    strides = torch.index_select(strides, dim=0, index=inds[:, 0])
    y_inds = (inds[:, 0] - trans_diff) // num_grids
    x_inds = (inds[:, 0] - trans_diff) % num_grids
    y_inds += seg_diff
    x_inds += seg_diff
    cate_labels = inds[:, 1]
    # print(cate_labels)  # n-dim vector of class indices
    seg_masks_soft = seg_preds_x[x_inds, ...] * seg_preds_y[y_inds, ...]  # [n, 200, 304]
    seg_masks = seg_masks_soft > cfg.mask_thr
    sum_masks = seg_masks.sum((1, 2)).float()  # (n,): mask area in pixels
    keep = sum_masks > strides  # further filtering: drop masks whose area does not exceed their level's stride
    # print(keep)
    seg_masks_soft = seg_masks_soft[keep, ...]
    seg_masks = seg_masks[keep, ...]
    cate_scores = cate_scores[keep]
    sum_masks = sum_masks[keep]
    cate_labels = cate_labels[keep]
    # maskness
    seg_score = (seg_masks_soft * seg_masks.float()).sum((1, 2)) / sum_masks
    cate_scores *= seg_score
    if len(cate_scores) == 0:
        return None
    # sort and keep top nms_pre
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > cfg.nms_pre:
        sort_inds = sort_inds[:cfg.nms_pre]
    seg_masks_soft = seg_masks_soft[sort_inds, :, :]
    seg_masks = seg_masks[sort_inds, :, :]
    cate_scores = cate_scores[sort_inds]
    sum_masks = sum_masks[sort_inds]
    cate_labels = cate_labels[sort_inds]
    # print(cate_scores)
    # Matrix NMS
    cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores,
                             kernel=cfg.kernel, sigma=cfg.sigma, sum_masks=sum_masks)
    # print(cate_scores)  # shape unchanged: scores of masks overlapping higher-scoring ones are decayed, similar to Soft-NMS
    keep = cate_scores >= cfg.update_thr
    seg_masks_soft = seg_masks_soft[keep, :, :]
    cate_scores = cate_scores[keep]
    # print(cate_scores.shape)  # some instances are filtered out
    cate_labels = cate_labels[keep]
    # sort and keep top_k
    sort_inds = torch.argsort(cate_scores, descending=True)
    if len(sort_inds) > cfg.max_per_img:  # COCO evaluation keeps at most 100 instances per image
        sort_inds = sort_inds[:cfg.max_per_img]
    seg_masks_soft = seg_masks_soft[sort_inds, :, :]
    cate_scores = cate_scores[sort_inds]
    cate_labels = cate_labels[sort_inds]
    # restore the masks to the original image resolution: upsample to the padded
    # input size, crop away the padding ([:h, :w]), then resize to ori_shape
    seg_masks_soft = F.interpolate(seg_masks_soft.unsqueeze(0),
                                   size=upsampled_size_out,
                                   mode='bilinear')[:, :, :h, :w]
    seg_masks = F.interpolate(seg_masks_soft,
                              size=ori_shape[:2],
                              mode='bilinear').squeeze(0)
    seg_masks = seg_masks > cfg.mask_thr
    return seg_masks, cate_labels, cate_scores
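The rescoring steps can also be sketched in NumPy: maskness (the mean soft probability inside each binarized mask) and the Gaussian variant of Matrix NMS described in the SOLOv2 paper. This is my own port for illustration, assuming masks already sorted by descending score, non-empty masks, and sigma=2.0; the repo's matrix_nms also offers a linear kernel that is not reproduced here, and maskness and matrix_nms_sketch are hypothetical names:

```python
import numpy as np


def maskness(soft_masks, mask_thr=0.5):
    """Mean soft value inside each binarized mask (seg_score in get_seg_single)."""
    binary = soft_masks > mask_thr
    area = binary.sum(axis=(1, 2))
    return (soft_masks * binary).sum(axis=(1, 2)) / area


def matrix_nms_sketch(binary_masks, labels, scores, sigma=2.0):
    """Gaussian Matrix NMS; inputs sorted by score (descending), masks non-empty."""
    n = len(scores)
    flat = binary_masks.reshape(n, -1).astype(np.float64)
    inter = flat @ flat.T
    area = flat.sum(axis=1)
    union = area[:, None] + area[None, :] - inter
    # iou[i, j] for i < j: IoU of higher-ranked mask i with lower-ranked mask j
    iou = np.triu(inter / union, k=1)
    same_class = np.triu(labels[:, None] == labels[None, :], k=1)
    decay_iou = iou * same_class
    # compensate[i]: largest IoU of mask i with any higher-ranked mask of the same class
    compensate = decay_iou.max(axis=0)
    decay = np.exp(-sigma * decay_iou ** 2) / np.exp(-sigma * compensate[:, None] ** 2)
    return scores * decay.min(axis=0)


soft = np.array([[[0.9, 0.6], [0.2, 0.1]]])
print(maskness(soft))  # about 0.75: mean of 0.9 and 0.6
m = np.zeros((3, 4, 4), dtype=bool)
m[0, :2, :2] = True  # highest score
m[1, :2, :2] = True  # duplicate of mask 0, same class: decayed
m[2, 2:, 2:] = True  # disjoint: untouched
scores = np.array([0.9, 0.8, 0.7])
print(matrix_nms_sketch(m, np.array([0, 0, 0]), scores))  # mask 1 drops to 0.8 * exp(-2)
```

Unlike hard NMS, nothing is removed at this stage: each score is multiplied by a decay factor derived from its IoU with higher-ranked masks of the same class, and update_thr does the actual filtering afterwards.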
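Finally, in the demo config the threshold applied after Matrix NMS (update_thr) is 0.05. That value is quite small, so many images containing small objects end up with the full 100 instances, yet the demo then visualizes the results with a threshold of 0.25. Isn't that somewhat contradictory? A likely explanation is that the test config is tuned for COCO mask AP, where keeping extra low-score instances generally does not lower AP and can raise recall, while the 0.25 in show_result_ins only serves to produce a readable visualization.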