# Model
check_suffix(weights, '.pt') # 检查权重文件后缀
pretrained = weights.endswith('.pt') # 检查权重文件后缀,返回布尔类型
# 载入模型
if pretrained:
with torch_distributed_zero_first(LOCAL_RANK):
weights = attempt_download(weights) # 如果本地找不到,则下载权重文件
# 使用预训练
# ---------------------------------------------------------#
# 加载模型及参数
ckpt = torch.load(weights, map_location='cpu') # 将检查点加载到 CPU 上,以避免 CUDA 内存泄漏
# 这里加载模型有两种方式,一种是通过opt.cfg 另一种是通过ckpt['model'].yaml
# 区别在于是否使用resume 如果使用resume会将opt.cfg设为空,按照ckpt['model'].yaml来创建模型
# 这也影响了下面是否除去anchor的key(也就是不加载anchor), 如果resume则不加载anchor
# 原因: 保存的模型会保存anchors,有时候用户自定义了anchor之后,再resume,则原来基于coco数据集的anchor会自己覆盖自己设定的anchor
# 详情参考: <https://github.com/ultralytics/yolov5/issues/459>
# 所以下面设置intersect_dicts()就是忽略exclude
model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # 创建模型
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # 排除的键列表
csd = ckpt['model'].float().state_dict() # 以 FP32 格式保存检查点的 state_dict
# 筛选字典中的键值对 把exclude删除
csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # 取交集
model.load_state_dict(csd, strict=False) # 加载模型参数
LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # 输出日志信息
else:
# 不使用预训练下创建模型
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
# 检查是否可以使用自动混合精度(AMP),若不支持则继续使用默认的浮点型精度进行训练
amp = check_amp(model) # check AMP
关于freeze冻结权重层