maskrcnn-benchmark 源码分析
代码中抽象出来一个model类GeneralizedRCNN。输入images,target到model中,然后返回loss,loss的计算都放在forward函数中了。外层的train函数,只是做iter循环、loss backward、optimizer.step、记录日志等这类比较固化的通用代码。
GeneralizedRCNN里面由三部分组成:backbone,RPN,Roi_head。网络的构建大量使用工厂方法,基本上可以根据配置来创建不同的检测网络,能够支持多种组合。
- model GeneralizedRCNN
- backbone nn.Sequential
- body ResNet
- stem StemWithFixedBatchNorm //resnet的基础层
- module BottleneckWithFixedBatchNorm //resnet的bottleneck模块
- fpn
- inner_block 1*1的conv层
- layer_block 3*3的conv层
- backbone是RCNN的骨干网,就是图片的特征提取器,在resnet+fpn的网络下,提取出来的是各尺度(5个尺度)上的256-d的feature map,这个feature map作为下一阶段rpn的输入。
- body ResNet
- rpn RPNModule
- head RPNHead //rpn的cnn网络,对featuremap每个点计算对应的bbox和cls_logits
- anchor_generator AnchorGenerator //根据rpn_head的计算结果,找到最可能的bbox的坐标
- box_selector_train RPNPostProcessor
- loss_evaluator RPNLossComputation // 评估RPN网络的Loss
- rpn是一个cnn网络,输入是feature map,输出是bbox坐标。
- roi_heads CombinedROIHeads
- box ROIBoxHead
- feature_extractor FPN2MLPFeatureExtractor
- pooler
- fc6
- fc7
- predictor FPNPredictor
- cls_score nn.Linear
- bbox_pred nn.Linear
- post_processor
- loss_evaluator
- feature_extractor FPN2MLPFeatureExtractor
- mask ROIMaskHead
- box ROIBoxHead
- backbone nn.Sequential
整体上,就是backbone算特征,rpn算框框, roi_heads算物体类别。外部看封装的很好,细节都藏在具体的各个类里面了。明天继续看代码,但愿自学能学到一点高手的编码风格。