【怎样获取源码包】【删除支付页面源码】【开源写字机源码】源码解析pytorch

时间:2025-01-18 16:59:08 分类:溯源码燕碎图片 来源:Nemo源码解析

1.PyTorch显存管理介绍与源码解析(一)
2.PyTorch - DataLoader 源码解析(一)
3.源码详解Pytorch的源码state_dict和load_state_dict
4.Pytorch源码剖析:nn.Module功能介绍及实现原理
5.PyTorch 源码分析(一):torch.nn.Module
6.Pytorch之Dataparallel源码解析

源码解析pytorch

PyTorch显存管理介绍与源码解析(一)

       GPU作为一种通用的数据处理设备,在设计时比较开放,解析API在满足客户需求的源码同时,也使得维护成本降低。解析然而,源码对于显存的解析怎样获取源码包精细管理需要由上层应用来完成。在PyTorch框架中,源码有一套专门的解析显存管理逻辑,能够更好地满足框架的源码需求,相比原生的解析CUDA API,该机制在管理细化和使用效率上更胜一筹。源码本文将主要讲解PyTorch1.版显存管理的解析逻辑,通过分析设计原理,源码帮助读者理解PyTorch的解析显存管理机制,以便在使用过程中遇到相关问题时能够更加得心应手。源码

       显存管理的主要任务是解决当用户创建一个数据(例如张量)时,需要一个确定大小的内存块,管理机制如何合理分配显存块给这个请求,并处理多对多的关系:请求、设备、用户进程。主流AI框架的显存管理方式基本遵循这一逻辑,主要步骤包括:

       1. 管理器申请整块显存,然后将其切分成多个小的显存块;

       2. 上层应用向管理器请求显存,管理器搜索并返回最佳的小显存块给应用;

       3. 管理器将多个闲置的小显存块进行合并,必要时释放这些显存块。

       在这一过程中,涉及的关键动作有申请、切分、搜索匹配、合并、释放等。接下来,我们将分别介绍显存申请/释放方式、设计要素、实现思路。

       ### 显存申请/释放方式

       #### 1.1.1 mallloc方式:`cudaMalloc/Free`

       `cudaMalloc/Free`是最常用的显存申请方式,其操作与CPU的`malloc/free`类似。用户只需指定指针和数据大小,即可调用API获取指定大小的显存块,并返回给`void*devPtr`指针。

       #### 关键问题:时间开销

       `cudaMalloc/Free`的API调用时间并不小,尤其是在框架使用的数据非常零碎且数量多时,频繁调用会直接影响程序整体性能。因此,尽量减少`cudaMalloc/Free`的调用频率是优化的关键。

       #### 1.1.2 统一内存:`cudaMallocManaged`

       `cudaMallocManaged`是一种与CPU内存统一管理的使用方式,允许使用系统内存充当“显存”,从而增加可用显存量,但会导致运行速度降低。

       #### 1.1.3 虚拟内存管理:`cuMemCreate`

       `cuMemCreate`是一种cu driver层的API,提供了一个独立的地址空间,支持显存块大小的动态调整,满足了用户增加显存大小的需求。但同样面临时间开销大和调用不够灵活的问题。

       ### 设计要素

       显存的申请有两种常见方式:动态申请和一次性申请。PyTorch框架采用动态申请方式,实时调整显存使用,避免了过量的删除支付页面源码显存占用。这种方式的优点是方便多人同时使用设备,但也带来了如何设计申请频率、处理API时间消耗和管理机制带来的碎片问题等挑战。

       ### 实现思路

       PyTorch1.版本显存管理主要采用`cudaMalloc`方式,通过考虑的问题和实现方式来优化显存使用。具体实现逻辑包括:

       #### 2.1 管理逻辑1: size触发创建

       管理机制根据申请的`size`决定创建多大的`segment`以及是否进行切分。

       #### 2.2 管理逻辑2:显存池

       申请显存后,多余显存会被放入显存池中。框架运行时会创建多个显存池,根据显存块的`size`将其映射到不同的池中。

       #### 2.3 管理逻辑3:块融合回收

       用户不需要使用的显存块不直接释放,而是回收到`blockPool`中。当整个`segment`未被使用时,可以触发`cudaFree`操作释放显存。同时,实施一种块融合机制,当释放一个`Block`时,寻找相邻的空闲`Block`进行合并,降低显存碎片问题。

       #### 2.4 整体逻辑

       通过上述介绍,可以整理出一个整体的运行逻辑,包括查找、创建、切分、保存、返回、回收和释放等步骤。在当前机制下,存在的问题是显存可视化,PyTorch支持将操作数据存储下来并进行分析,以便更好地理解`segment`和`block`的关系。

       ### 显存可视化

       在PyTorch2.x中,可以通过Snapshot将显存消耗进行可视化,这有助于了解`segment`和`block`之间的关系。例如,系统创建了一个MB的`segment`,该`segment`可以满足1~MB的`block`需求。通过可视化数据,可以追溯一个`segment`的消耗全过程。

       ### 结论

       通过以上内容的介绍和分析,我们可以了解到PyTorch显存管理机制的设计原理、实现思路以及存在的问题。这一机制旨在优化显存使用效率,满足框架需求,提供灵活且高效的显存管理方案。随着技术的不断发展,显存管理机制也将不断优化,以适应更多复杂场景的需要。

PyTorch - DataLoader 源码解析(一)

       本文为作者基于个人经验进行的初步解析,由于能力有限,可能存在遗漏或错误,敬请各位批评指正。

       本文并未全面解析 DataLoader 的全部源码,仅对 DataLoader 与 Sampler 之间的联系进行了分析。以下内容均基于单线程迭代器代码展开,多线程情况将在后续文章中阐述。

       以一个简单的数据集遍历代码为例,在循环中,开源写字机源码数据是如何从 loader 中被取出的?通过断点调试,我们发现循环时,代码进入了 torch.utils.data.DataLoader 类的 __iter__() 方法,具体内容如下:

       可以看到,该函数返回了一个迭代器,主要由 self._get_iterator() 和 self._iterator._reset(self) 提供。接下来,我们进入 self._get_iterator() 方法查看迭代器的产生过程。

       在此方法中,根据 self.num_workers 的数量返回了不同的迭代器,主要区别在于多线程处理方式不同,但这两种迭代器都是继承自 _BaseDataLoaderIter 类。这里我们先看单线程下的例子,进入 _SingleProcessDataLoaderIter(self)。

       构造函数并不复杂,在父类的构造器中执行了大量初始化属性,然后在自己的构造器中获得了一个 self._dataset_fetcher。此时继续单步前进断点,发现程序进入到了父类的 __next__() 方法中。

       在分析代码之前,我们先整理一下目前得到的信息:

       下面是 __next__() 方法的内容:

       可以看到最后返回的是变量 data,而 data 是由 self._next_data() 生成的,进入这个方法,我们发现这个方法由子类负责实现。

       在这个方法中,我们可以看到数据从 self._dataset_fecther.fetch() 中得到,需要依赖参数 index,而这个 index 由 self._next_index() 提供。进入这个方法可以发现它是由父类实现的。

       而前面的 index 实际上是由这个 self._sampler_iter 迭代器提供的。查找 self._sampler_iter 的定义,我们发现其在构造函数中。

       仔细观察,我们可以在倒数第 4 行发现 self._sampler_iter = iter(self._index_sampler),这个迭代器就是这里的 self._index_sampler 提供的,而 self._index_sampler 来自 loader._index_sampler。这个 loader 就是最外层的 DataLoader。因此我们回到 DataLoader 类中查看这个 _index_sampler 是如何得到的。

       我们可以发现 _index_sampler 是一个由 @property 装饰得到的属性,会根据 self._auto_collation 来返回 self.batch_sampler 或者 self.sampler。再次整理已知信息,我们可以得到:

       因此,只要知道 batch_sampler 和 sampler 如何返回 index,就能了解整个流程。

       首先发现这两个属性来自 DataLoader 的构造函数,因此下面先分析构造函数。

       由于构造函数代码量较大,因此这里只关注与 Sampler 相关的部分,代码如下:

       在这里我们只关注以下部分:

       代码首先检查了参数的合法性,然后进行了一轮初始化属性,接着判断了 dataset 的类型,处理完特殊情况。接下来,函数对参数冲突进行了判断,共判断了 3 种参数冲突:

       检查完参数冲突后,函数开始创建 sampler 和 batch_sampler,如下图所示:

       注意,仅当未指定 sampler 时才会创建 sampler;同理,小浣熊资源码仅在未指定 batch_sampler 且存在 batch_size 时才会创建 batch_sampler。

       在 DataLoader 的构造函数中,如果不指定参数 batch_sampler,则默认创建 BatchSampler 对象。该对象需要一个 Sampler 对象作为参数参与构造。这也是在构造函数中,batch_sampler 与 sampler 冲突的原因之一。因为传入一个 batch_sampler 时,说明 sampler 已经作为参数完成了 batch_sampler 的构造,若再将 sampler 传入 DataLoader 是多余的。

       以第一节中的简单代码为例,此时并未指定 Sampler 和 batch_sampler,也未指定 batch_size,默认为 1,因此在 DataLoader 构造时,创建了一个 SequencialSampler,并传入了 BatchSampler 进行构建。继续第一节中的断点,可以发现:

       具体使用 sampler 还是 batch_sampler 来生成 index,取决于 _auto_collation,而从上面的代码发现,只要存在 self.batch_sampler 就永远使用 batch_sampler 来生成。batch_sampler 与 sampler 冲突的原因之二:若不设置冲突,那么使用者试图同时指定 batch_sampler 与 sampler 后,尤其是在使用者继承了新的 Sampler 子类后, sampler 在获取数据的时候完全没有被使用,这对开发者来说是一个困惑的现象,容易引起不易察觉的 BUG。

       继续断点发现程序进入了 BatchSampler 的 __iter__() 方法,代码如下:

       从代码中可以发现,程序不停地从 self.sampler 中获取 idx 加入列表,直到填满一个 batch 的量,并将这一整个 batch 的 index 返回到迭代器的 _next_data()。

       此处由 self._dataset_fetcher.fetch(index) 来获取真正的数据,进入函数后看到:

       这里依然根据 self.auto_collation(来自 DataLoader._auto_collation)进行分别处理,但是总体逻辑都是通过 self.dataset[] 来调用 Dataset 对象的 __getitem__() 方法。

       此处的 Dataset 是来自 torchvision 的 DatasetFolder 对象,这里读取文件路径中的后,经过转换变为 Tensor 对象,与标签 target 一起返回。参数中的 index 是由迭代器的 self._dataset_fetcher.fetch() 传入。

       整个获取数据的流程可以用以下流程图简略表示:

       注意:

       另附:

       对于一条循环语句,在执行过程中发生了以下事件:

源码详解Pytorch的state_dict和load_state_dict

       在Pytorch中,保存和加载模型的一种方式是通过调用model.state_dict(),该函数返回的是一个OrderDict,包含网络结构的名称及其对应的参数。要深入了解实现细节,我们先关注其内部逻辑。

       在state_dict函数中,主要遍历了四个元素:_parameters,_buffers,_modules和_state_dict_hooks。前三种在先前的文章中已有详细介绍,而最后一种在读取state_dict时执行特定操作,通常为空,因此不必过多考虑。重要的一点是,当读取Module时,闪送程序源码采用递归方式,并以.作为分割符号,方便后续load_state_dict加载参数。

       最后,该函数输出了三种关键参数。

       接下来,让我们深入load_state_dict函数,它主要分为两部分。

       首先,load(self)函数会递归地恢复模型参数。其中,_load_from_state_dict源码在文末附上。

       在load_state_dict中,state_dict表示你之前保存的模型参数序列,而local_state表示你当前模型的结构。

       load_state_dict的主要作用在于,假设我们需恢复名为conv.weight的子模块参数,它会以递归方式先检查conv是否存在于state_dict和local_state中。如果不在,则将conv添加到unexpected_keys中;如果在,则进一步检查conv.weight是否存在,如果都存在,则执行param.copy_(input_param),完成参数拷贝。

       在if strict部分中,主要判断参数拷贝过程中是否有unexpected_keys或missing_keys,如有,则抛出错误,终止执行。当然,当strict=False时,会忽略这些细节。

       总结而言,state_dict和load_state_dict是Pytorch中用于保存和加载模型参数的关键函数,它们通过递归方式确保模型参数的准确恢复。

Pytorch源码剖析:nn.Module功能介绍及实现原理

       nn.Module作为Pytorch的核心类,是构建模型的基础。它提供了一系列功能,包括记录模型的参数,实现网络的前向传播,加载和保存模型数据,以及进行设备和数据类型转换等。这些功能在模型的训练和应用中起到关键作用。

       在训练与评估模式间切换,模块的行为会有所不同,如rrelu、dropout、batchnorm等操作在两种模式下表现不同。可学习的参数,如权重和偏置,需要通过梯度下降进行更新。非学习参数,比如batchnorm的running_mean,是训练过程中的统计结果。_buffers包含的Tensor不作为模型的一部分保存。

       模块内部包含一系列钩子(hook)函数,用于在特定的前向传播或反向传播阶段执行自定义操作。子模块列表用于存储模型中的所有子模块。

       魔术函数__init__在声明对象时自动调用,优化性能的关键在于使用super().__setattr__而非直接赋值。super调用父类的方法,避免不必要的检查,提高效率。使用register_buffer为模块注册可变的中间结果,例如BatchNorm的running_mean。register_parameter用于注册需要梯度下降更新的参数。

       递归应用函数用于对模型进行操作,如参数初始化。可以将模型移动到指定设备,转换数据类型,以及注册钩子函数以实现对网络的扩展和修改。

       调用魔术方法__call__执行前向传播。nn.Module未实现forward函数,子类需要提供此方法的具体实现。对于线性层等,forward函数定义了特定的运算流程。从检查点加载参数时,模块自动处理兼容性问题,确保模型结构与参数值的兼容。

       模块的__setattr__方法被重写,以区别对待Parameter、Module和Buffer。当尝试设置这些特定类型的属性时,执行注册或更新操作。其他属性的设置遵循标准的Python行为。

       模块的save方法用于保存模型参数和状态,确保模型结构和参数值在不同设备间转移时的一致性。改变训练状态(如将模型切换到训练或评估模式)是模块管理过程的重要组成部分。

PyTorch 源码分析(一):torch.nn.Module

       nn.Module是PyTorch中最核心和基础的结构,它是操作符/损失函数的基类,同时也是组成各种网络结构的基类(实际上是由多个module组合而成的一个module)。

       在Python侧,2.1回调函数注册,2.2 module类定义中,有以下几个重点函数:

       重点函数一:将模型的参数移动到CUDA上,内部会遍历其子module。

       重点函数二:将模型的参数移动到CPU上,内部会遍历其子module。

       重点函数三:将模型的参数转化为fp或者fp等,内部会遍历其子module。

       重点函数四:forward函数调用。

       重点函数五:返回该net的所有layer。

       在类图中,PyTorch的算子都是module的子类,包括自定义算子和整网定义。

       在C++侧,3.1 module.to("cuda")详细分析中,本质是将module的parameter&buffer等tensor移动到CUDA上,最终调用的是tensor.to(cuda)。

       3.2 module.load/save逻辑中,PyTorch模型保存分为两种,一种是纯参数,一种是带模型结构(PyTorch中的模型结构,本质上是由module、sub-module构造的一个计算图)。

       parameter、buffer是通过key-value的形式来存储和检索的,key为module的.name,value为存储具体数据的tensor。

       InputArchive/OutputArchive的write和read逻辑。

       通过Module,PyTorch将op/loss/opt等串联起来,类似于一个计算图。基于PyTorch构建的ResNet等模型,是逐个算子进行计算的,tensor在CPU和GPU之间来回流动,而不是整个计算都在GPU上完成(即中间计算结果不出GPU)。实际上,在进行推理时,可以构建一个计算图,让整个计算图的计算都在GPU上完成,不知道是否可行(如果GPU上有一个CPU就可以完成这个操作,不知道tensorrt是否是这样的操作)。

Pytorch之Dataparallel源码解析

       深入解析Pytorch之Dataparallel源码

       在深入理解Dataparallel原理之前,需要明白它的使用场景和目的。Dataparallel设计用于在多GPU环境下并行处理数据,提高模型训练效率。

       初始化阶段,Dataparallel需要实例化一个模型。这一步中,模型的参数会被复制到所有可用的GPU上,从而实现并行计算。

       在前向传播阶段,Dataparallel的核心作用体现出来。它会将输入数据分割成多个小批次,然后分别发送到各个GPU上。在每个GPU上执行前向传播操作后,结果会被收集并汇总。这样,即便模型在多GPU上运行,输出结果也如同在单GPU上运行一样。

       具体实现中,Dataparallel会利用Python的多重继承和数据并行策略。它继承自nn.Module,同时调用nn.DataParallel的构造函数,从而实现并行计算。

       对于那些需要在GPU间共享的状态或变量,Dataparallel还提供了相应的管理机制,确保数据的一致性和计算的正确性。这样的设计使得模型能够高效地在多GPU环境下运行,同时保持代码的简洁性和易读性。

       总结而言,Dataparallel通过分割数据、并行执行前向传播和收集结果的机制,实现了高效的数据并行训练。理解其源码有助于开发者更好地利用多GPU资源,提升模型训练效率。

PyTorch 源码解读之 torch.utils.data:解析数据处理全流程

       文@

       目录

       0 前言

       1 Dataset

       1.1 Map-style dataset

       1.2 Iterable-style dataset

       1.3 其他 dataset

       2 Sampler

       3 DataLoader

       3.1 三者关系 (Dataset, Sampler, Dataloader)

       3.2 批处理

       3.2.1 自动批处理(默认)

       3.2.2 关闭自动批处理

       3.2.3 collate_fn

       3.3 多进程处理 (multi-process)

       4 单进程

       5 多进程

       6 锁页内存 (Memory Pinning)

       7 预取 (prefetch)

       8 代码讲解

       0 前言

       本文以 PyTorch 1.7 版本为例,解析 torch.utils.data 模块在数据处理流程中的应用。

       理解 Python 中的迭代器是解读 PyTorch 数据处理逻辑的关键。Dataset、Sampler 和 DataLoader 三者共同构建数据处理流程。

       迭代器通过实现 __iter__() 和 __next__() 方法,支持数据的循环访问。Dataset 提供数据获取接口,Sampler 控制遍历顺序,DataLoader 负责加载和批处理数据。

       1 Dataset

       Dataset 包括 Map-style 和 Iterable-style 两种,分别用于索引访问和迭代访问数据。

       Map-style dataset 通过实现 __getitem__() 和 __len__() 方法,支持通过索引获取数据。

       Iterable-style dataset 实现 __iter__() 方法,适用于随机访问且批次大小依赖于获取数据的场景。

       2 Sampler

       Sampler 用于定义数据遍历的顺序,支持用户自定义和 PyTorch 提供的内置实现。

       3 DataLoader

       DataLoader 是数据加载的核心,支持 Map-style 和 Iterable-style Dataset,提供单多进程处理和批处理等功能。

       通过参数配置,如 batch_size、drop_last、collate_fn 等,DataLoader 实现了数据的自动和手动批处理。

       4 批处理

       3.2.1 自动批处理(默认)

       DataLoader 默认使用自动批处理,通过参数控制批次生成和样本整理。

       3.2.2 关闭自动批处理

       关闭自动批处理,允许用户自定义批处理逻辑或处理单个样本。

       3.2.3 collate_fn

       collate_fn 是手动批处理时的关键,用于整理单个样本为批次。

       5 多进程

       多进程处理通过 num_workers 参数启用,加速数据加载。

       6 单进程

       单进程模式下,数据加载可能影响计算流程,适用于数据量小且无需多进程的场景。

       7 锁页内存 (Memory Pinning)

       Memory Pinning 技术确保数据在 GPU 加速过程中快速传输,提高性能。

       8 代码讲解

       通过具体代码分析,展示了 DataLoader 的初始化、迭代和数据获取过程,涉及迭代器、Sampler 和 Dataset 的交互。

PyTorch 源码分析(三):torch.nn.Norm类算子

       PyTorch源码详解(三):torch.nn.Norm类算子深入解析

       Norm类算子在PyTorch中扮演着关键角色,它们包括BN(BatchNorm)、LayerNorm和InstanceNorm。

       1. BN/LayerNorm/InstanceNorm详解

       BatchNorm(BN)的核心功能是对每个通道(C通道)的数据进行标准化,确保数据在每个批次后保持一致的尺度。它通过学习得到的gamma和beta参数进行缩放和平移,保持输入和输出形状一致,同时让数据分布更加稳定。

       gamma和beta作为动态调整权重的参数,它们在BN的学习过程中起到至关重要的作用。

       2. Norm算子源码分析

       继承关系:Norm类在PyTorch中具有清晰的继承结构,子类如BatchNorm和InstanceNorm分别继承了其特有的功能。

       BN与InstanceNorm实现:在Python代码中,BatchNorm和InstanceNorm的实例化和计算逻辑都包含对输入数据的2D转换,即将其分割为M*N的矩阵。

       计算过程:在计算过程中,首先计算每个通道的均值和方差,这是这些标准化方法的基础步骤。

       C++侧的源码洞察

       C++实现中,对于BatchNorm和LayerNorm,代码着重于处理数据的标准化操作,同时确保线程安全,通过高效的数据视图和线程视图处理来提高性能。

PyTorch 源码解读之 torch.optim:优化算法接口详解

       本文深入解读了 PyTorch 中的优化算法接口 torch.optim,主要包括优化器 Optimizer、学习率调整策略 LRScheduler 及 SWA 相关优化策略。以下为详细内容:

       Optimizer 是所有优化器的基类,提供了初始化、更新参数、设置初始学习率等基本方法。在初始化优化器时,需要传入模型的可学习参数和超参数。Optimizer 的核心方法包括:

       1. 初始化函数:创建优化器时,需指定模型的可学习参数和超参数,如学习率、动量等。

       2. add_param_group:允许为模型的不同可学习参数组设置不同的超参数,以适应不同的学习需求。

       3. step:执行一次模型参数更新,需要闭包提供损失函数的梯度信息。

       4. zero_grad:在更新参数前,清空参数的梯度信息。

       5. state_dict 和 load_state_dict:用于序列化和反序列化优化器的状态,便于保存和加载模型的训练状态。

       Optimizer 包括常见的优化器如 SGD、Adagrad、RMSprop 和 Adam,各有特点,适用于不同的应用场景。例如,SGD 适用于简单场景,而 Adam 则在处理大数据集时表现更优。

       学习率调节器 lr_scheduler 则负责在训练过程中调整学习率,以适应模型的收敛过程。PyTorch 提供了多种学习率调整策略,如 StepLR、MultiStepLR、ExponentialLR 等,每种策略都有其特点和应用场景,如 StepLR 用于周期性调整学习率,以加速收敛。

       SWA(随机权重平均)是一种优化算法,通过在训练过程中计算模型参数的平均值,可以得到更稳定的模型,提高泛化性能。SWA 涉及 AveragedModel 类,用于更新模型的平均参数,以及 update_bn 函数,用于在训练过程中更新批量归一化参数。

       总结,torch.optim 提供了丰富的优化算法接口,可以根据模型训练的需求灵活选择和配置,以达到最佳的训练效果和泛化性能。通过深入理解这些优化器和学习率调整策略,开发者可以更有效地训练深度学习模型。

PyTorch 源码解读之 BN & SyncBN:BN 与 多卡同步 BN 详解

       BatchNorm原理

       BatchNorm最早在全连接网络中提出,旨在对每个神经元的输入进行归一化操作。在卷积神经网络(CNN)中,这一原理被扩展为对每个卷积核的输入进行归一化,即在channel维度之外的所有维度上进行归一化。BatchNorm带来的优势包括提高网络的收敛速度、稳定训练过程、减少过拟合现象等。

       BatchNorm的数学表达式为公式[1],引入缩放因子γ和移位因子β,作者在文章中解释了它们的作用。

       PyTorch中与BatchNorm相关的类主要位于torch.nn.modules.batchnorm模块中,包括如下的类:_NormBase、BatchNormNd。

       具体实现细节如下:

       _NormBase类定义了BN相关的一些属性。

       初始化过程。

       模拟BN的forward过程。

       running_mean、running_var的更新逻辑。

       γ、β参数的更新方式。

       BN在eval模式下的行为。

       BatchNormNd类包括BatchNorm1d、BatchNorm2d、BatchNorm3d,它们的区别在于检查输入的合法性,BatchNorm1d接受2D或3D的输入,BatchNorm2d接受4D的输入,BatchNorm3d接受5D的输入。

       接着,介绍SyncBatchNorm的实现。

       BN性能与batch size密切相关。在batch size较小的场景中,如检测任务,内存占用较高,单张显卡难以处理较多,导致BN效果不佳。SyncBatchNorm提供了解决方案,其原理是所有计算设备共享同一组BN参数,从而获得全局统计量。

       SyncBatchNorm在torch/nn/modules/batchnorm.py和torch/nn/modules/_functions.py中实现,前者负责输入合法性检查以及参数设置,后者负责单卡统计量计算和进程间通信。

       SyncBatchNorm的forward过程。

       复习方差计算方式。

       单卡计算均值、方差,进行归一化处理。

       同步所有卡的数据,得到全局均值mean_all和逆标准差invstd_all,计算全局统计量。

       接着,介绍SyncBatchNorm的backward过程。

       在backward过程中,需要在BN前后进行进程间通信。这在_functions.SyncBatchNorm中实现。

       计算weight、bias的梯度以及γ、β,进一步用于计算梯度。