# Image size 图像大小
gs = max(int(model.stride.max()), 32) # 网格大小(最大步长)
imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # 验证 imgsz 是否为 gs 的倍数
# Batch size 批大小
# 判断条件 RANK == -1 and batch_size == -1 是否成立。如果成立,说明当前是单 GPU 的情况,并且批大小未指定。
if RANK == -1 and batch_size == -1: # 仅适用于单 GPU,估计最佳批大小
batch_size = check_train_batch_size(model, imgsz, amp) # 估计最佳批处理大小并赋值给batch_size
loggers.on_params_update({'batch_size': batch_size})
这段代码的作用是设置图像尺寸和批大小。
首先,通过计算模型的最大步长 model.stride.max()
和 32 的较大值,得到网格大小 gs
。
然后,使用 check_img_size
函数验证 opt.imgsz
是否为 gs
的倍数,并将结果赋值给 imgsz
。该函数的作用是确保图像尺寸 imgsz
是网格大小 gs
的倍数,以保证在进行特征提取和预测时能够正确对齐。
接下来,判断条件 RANK == -1 and batch_size == -1
是否成立。如果成立,说明当前是单 GPU 的情况,并且批大小未指定。
在这种情况下,调用 check_train_batch_size
函数估计最佳的批大小。该函数的作用是根据模型、图像尺寸和是否使用自动混合精度(AMP)来估计最佳的训练批大小。然后,将得到的最佳批大小赋值给 batch_size
。
最后,使用 loggers.on_params_update({'batch_size': batch_size})
更新日志记录器中的批大小参数。
这段代码的目的是设置图像尺寸和批大小,确保它们符合模型的要求,并根据情况估计最佳的批大小。这样可以在训练过程中充分利用 GPU 的计算资源,提高训练效率。