一、理论基础
layer normalization 重要的两个部分是平移不变性和缩放不变性。 Root Mean Square Layer Normalization 认为 layer normalization 取得成功重要的是缩放不变性,而不是平移不变性。因此,去除了计算过程中的平移,只保留了缩放,进行了简化,提出了RMS Norm(Root Mean Square Layer Normalization),即均方根 norm。
Layer Normalization (LayerNorm) | Root Mean Square Layer Normalization (RMSNorm) |
对特征张量按照某一维度或某几个维度进行0均值,1方差的归一化 操作 LayerNorm 是一种标准化方法,它计算一个样本的均值和方差,然后使用这些来对样本进行归一化。这种方法是独立于批量大小的,使得模型更加稳定。 | RMSNorm是对LayerNorm的一个改进,没有做re-center操作(移除了其中的均值项),可以看作LayerNorm在均值为0时的一个特例。论文通过实验证明,re-center操作不重要。RMSNorm 也是一种标准化方法,但与 LayerNorm 不同,它不是使用整个样本的均值和方差,而是使用平方根的均值来归一化,这样做可以降低噪声的影响。 |
二、代码实现
# 均方根标准化
class RMSNorm(torch.nn.Module):
def __init__(self,normalized_shape,eps=1e-5,devices=None,dtype=None,**kwargs):
super().__init__()
self.weight=torch.nn.Parameter(torch.empty(size=normalized_shape,device=devices,dtype=dtype)) #待训练的参数
self.eps=eps
def forward(self,hidden_state:torch.Tensor):
input_type=hidden_state.dtype
variace=hidden_state.to(torch.float32).pow(2).mean(-1,keepdim=True)
hidden_state=hidden_state*torch.rsqrt(variace+self.eps)
return (hidden_state*self.weight).to(input_type)
if __name__ == '__main__':
x=RMSNorm(normalized_shape=[3,4])
y=x(torch.randn(size=(3,4)))
print(y)
https://arxiv.org/pdf/1910.07467
原创文章。转载请注明:
作者:JiangYuan
网址: https://www.icnma.com