权值初始化的目的是防止层激活输出在深度神经网络的正向传递过程中爆炸或消失。如果发生任何一种情况,损失梯度要么太大,要么太小,无法有利地向后流动,如果网络能够这样做,则需要更长的时间才能收敛。
一、权值初始化的影响 假设我们有一个简单的100层网络,没有激活,并且每个层都有一个矩阵a,其中包含该层的权重。为了完成单次前向传递,我们必须在每100层的输入和权重之间执行矩阵乘法,这将导致总共100个连续矩阵乘法。
1 2 3 4 5 6 7 8 9 10 import torchx = torch.randn(512 ) for i in range(100 ): a = torch.randn(512 , 512 ) x = x @ a if torch.isnan(x.std()): print(i) break
激活输出在第29个网络层中爆炸。我们显然将权重初始化为太大。 接下来我们调整权值,使其标准差从1变为0.01。
1 2 3 4 5 6 7 8 9 10 import torchx = torch.randn(512 ) for i in range(100 ): a = torch.randn(512 , 512 ) * 0.01 x = x @ a if x.std() == 0 : print(i) break
这回激活输出在第69个网络层中完全消失。 总而言之,如果初始化的权重过大,网络就不能很好地学习。当权重初始化过小时也会发生同样的情况。 我们能不能找到最佳的点? 有的!对于矩阵乘法而言,我们需要权值的标准差非常接近输入连接的数量的平方根,它在我们的例子中是 √512。
1 2 3 4 5 6 7 8 9 10 11 12 13 import torchimport mathx = torch.randn(512 ) for i in range(100 ): a = torch.randn(512 , 512 ) * math.sqrt(1. /512 ) x = x @ a if x.std() == 0 or torch.isnan(x.std()) == True : print(i) break print(x.mean(), x.std())
看上去很不错,既没有爆炸也没有消失。但是如果我们给权值限定值域(激活函数一般是有界的),这种初始化方案就不那么理想了。
1 2 3 4 5 6 7 8 9 10 11 12 13 import torchimport mathx = torch.randn(512 ) for i in range(100 ): a = torch.randn(512 , 512 ).uniform_(-1 , 1 ) * math.sqrt(1. /512 ) x = x @ a if x.std() == 0 or torch.isnan(x.std()) == True : print(i) break print(x.mean(), x.std())
二、两种有效的权值初始化方案 1、Xavier初始化 当使用关于0对称且在[-1,1]内部有输出(如softsign和tanh)的激活函数时,适用Xavier初始化
1 2 3 4 5 6 7 8 9 10 11 12 13 import torchimport mathx = torch.randn(512 ) def xaiver (m, h) : return torch.Tensor(m, h).uniform_(-1 , 1 ) * math.sqrt(6. /(m+h)) for i in range(100 ): a = xaiver(512 , 512 ) x = torch.tanh(x @ a) print(x.mean(), x.std())
2、Kaiming初始化 当使用非对称、非有界的激活函数时,适用Kaiming初始化
1 2 3 4 5 6 7 8 9 10 11 12 13 14 import torchimport torch.nn.functional as Fimport mathx = torch.randn(512 ) def Kaiming (m, h) : return torch.randn(m, h) * math.sqrt(2. /m) for i in range(100 ): a = Kaiming(512 , 512 ) x = torch.relu(x @ a) print(x.mean(), x.std())
三、建模时加入权值初始化设置 1、单层初始化 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 import torch.nn as nnclass MyNet (nn.Module) : def __init__ (self) : super(MyNet, self).__init__() self.layer1 = nn.Linear(100 , 200 ) self.layer2 = nn.Linear(200 , 10 ) nn.init.xavier_uniform_(self.layer2.weight) nn.init.constant_(self.layer2.bias, 0.1 ) def forward (self, x) : x = self.layer1(x) x = nn.relu(x) x = self.layer2(x) x = nn.softmax(x) return x print("MyNet:" ) model = MyNet() print(model)
2、模型初始化 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 import torch.nn as nnclass MyNet (nn.Module) : def __init__ (self) : super(MyNet, self).__init__() self.layer1 = nn.Linear(100 , 200 ) self.layer2 = nn.Linear(200 , 10 ) self._initialize_weights() def _initialize_weights (self) : for m in self.modules(): if isinstance(m, nn.BatchNorm2d): nn.init.constant(m.weight, 1 ) nn.init.constant(m.bias, 0 ) elif isinstance(m, (nn.Conv2d, nn.Linear)): nn.init.kaiming_normal_(m.weight) if m.bias is not None : nn.init.constant_(m.bias, 0 ) def forward (self, x) : x = self.layer1(x) x = nn.relu(x) x = self.layer2(x) x = nn.softmax(x) return x print("MyNet:" ) model = MyNet() print(model)