0%

maskrcnn-benchmark源码分析(1)_网络构建.md

maskrcnn-benchmark 源码分析

代码中抽象出来一个model类GeneralizedRCNN。输入images,target到model中,然后返回loss,loss的计算都放在forward函数中了。外层的train函数,只是做iter循环、loss backward、optimizer.step、记录日志等这类比较固化的通用代码。

GeneralizedRCNN里面由三部分组成:backbone,RPN,Roi_head。网络的构建大量使用工厂方法,基本上可以根据配置来创建不同的检测网络,能够支持多种组合。

  • model GeneralizedRCNN
    • backbone nn.Sequential
      • body ResNet
        • stem StemWithFixedBatchNorm //resnet的基础层
        • module BottleneckWithFixedBatchNorm //resnet的bottleneck模块
      • fpn
        • inner_block 1*1的conv层
        • layer_block 3*3的conv层
      • backbone是RCNN的骨干网,就是图片的特征提取器,在resnet+fpn的网络下,提取出来的是各尺度(5个尺度)上的256-d的feature map,这个feature map作为下一阶段rpn的输入。
    • rpn RPNModule
      • head RPNHead //rpn的cnn网络,对featuremap每个点计算对应的bbox和cls_logits
      • anchor_generator AnchorGenerator //根据rpn_head的计算结果,找到最可能的bbox的坐标
      • box_selector_train RPNPostProcessor
      • loss_evaluator RPNLossComputation // 评估RPN网络的Loss
      • rpn是一个cnn网络,输入是feature map,输出是bbox坐标。
    • roi_heads CombinedROIHeads
      • box ROIBoxHead
        • feature_extractor FPN2MLPFeatureExtractor
          • pooler
          • fc6
          • fc7
        • predictor FPNPredictor
          • cls_score nn.Linear
          • bbox_pred nn.Linear
        • post_processor
        • loss_evaluator
      • mask ROIMaskHead

整体上,就是backbone算特征,rpn算框框, roi_heads算物体类别。外部看封装的很好,细节都藏在具体的各个类里面了。明天继续看代码,但愿自学能学到一点高手的编码风格。