均方层归一化RMSNorm(均方根标准化)

一、理论基础

layer normalization 重要的两个部分是平移不变性和缩放不变性。 Root Mean Square Layer Normalization 认为 layer normalization 取得成功重要的是缩放不变性,而不是平移不变性。因此,去除了计算过程中的平移,只保留了缩放,进行了简化,提出了RMS Norm(Root Mean Square Layer Normalization),即均方根 norm。

均方层归一化RMSNorm(均方根标准化)

Layer Normalization (LayerNorm)Root Mean Square Layer Normalization (RMSNorm)
对特征张量按照某一维度或某几个维度进行0均值,1方差的归一化 操作
LayerNorm 是一种标准化方法,它计算一个样本的均值和方差,然后使用这些来对样本进行归一化。这种方法是独立于批量大小的,使得模型更加稳定。
RMSNorm是对LayerNorm的一个改进,没有做re-center操作(移除了其中的均值项),可以看作LayerNorm在均值为0时的一个特例。论文通过实验证明,re-center操作不重要。RMSNorm 也是一种标准化方法,但与 LayerNorm 不同,它不是使用整个样本的均值和方差,而是使用平方根的均值来归一化,这样做可以降低噪声的影响。

二、代码实现


# 均方根标准化 class RMSNorm(torch.nn.Module): def __init__(self,normalized_shape,eps=1e-5,devices=None,dtype=None,**kwargs): super().__init__() self.weight=torch.nn.Parameter(torch.empty(size=normalized_shape,device=devices,dtype=dtype)) #待训练的参数 self.eps=eps def forward(self,hidden_state:torch.Tensor): input_type=hidden_state.dtype variace=hidden_state.to(torch.float32).pow(2).mean(-1,keepdim=True) hidden_state=hidden_state*torch.rsqrt(variace+self.eps) return (hidden_state*self.weight).to(input_type) if __name__ == '__main__': x=RMSNorm(normalized_shape=[3,4]) y=x(torch.randn(size=(3,4))) print(y)

https://arxiv.org/pdf/1910.07467

原创文章。转载请注明: 作者:JiangYuan 网址: https://www.icnma.com
Like (0)
JiangYuan管理
Previous 21/05/2025 21:01
Next 17/06/2025 10:48

猜你想看

  • ollama run Model on Hugging Face Hub

    之前写了篇比较全的ollama使用文档:https://www.icnma.com/ollama-tutorial/ 本篇主要是如何使用ollama直接运行huggingface上的gguf模型。 直接使用Ollama在Hugging Face上任何GGUF quant model,而无需创建新的Mo…

    25/11/2024
    01.6K0
  • Attention:MLA、MHA、MQA与GQA

    多头注意力机制(Multi-Head Attention,MHA) 多头注意力(Multi-Head Attention, MHA)是Transformer模型的核心机制,通过并行计算多个注意力头,使模型能够同时关注输入序列中不同位置的特征。其核心思想是将输…

    17/06/2025
    01500
  • Instruction-tuning Llama2大模型文本分类微调示例

    本文介绍通过微调 Meta 的 Llama 2 7B 模型对18 个不同类别的新闻文章进行分类,本教程将详细解释每个步骤,涵盖使用的所有类、函数和参数。 安装所需库 加载所需库  Bitsandbytes 配置 定义一个函数&nbs…

    30/08/2023
    04.7K0
  • 文本生成模型解码策略和采样方法对比分析(13种)

    本文主要探讨如下两个方面: 1、outputs = model.generate(**inputs, ...) generate()中各个参数是什么含义? 2、我们的模型在文本生成的时候,最终的结果文本是如何产生的?常见的解码策略有哪些? 常用参数释义 m…

    20/10/2024
    01.4K0
  • DeepSeek-R1是怎样炼成的?

    DeepSeek-R1反响非常大,主要是因为使用较低的成本得到了OpenAI O1的效果。开源还便宜。 在这篇文章中,我们将了解它是如何构建的。 目录: DeepSeek-R1 的训练方法 1. 大规模推理导向强化学习 (R1-Zero) 2. R1 …

    28/01/2025
    01.7K0