深度学习开发框架PyTorch(6)-- 权值初始化

权值初始化的目的是防止层激活输出在深度神经网络的正向传递过程中爆炸或消失。如果发生任何一种情况,损失梯度要么太大,要么太小,无法有利地向后流动,如果网络能够这样做,则需要更长的时间才能收敛。

一、权值初始化的影响

假设我们有一个简单的100层网络,没有激活,并且每个层都有一个矩阵a,其中包含该层的权重。为了完成单次前向传递,我们必须在每100层的输入和权重之间执行矩阵乘法,这将导致总共100个连续矩阵乘法。

1
2
3
4
5
6
7
8
9
10
import torch

x = torch.randn(512)

for i in range(100):
a = torch.randn(512, 512)
x = x @ a
if torch.isnan(x.std()):
print(i)
break

激活输出在第29个网络层中爆炸。我们显然将权重初始化为太大。
接下来我们调整权值,使其标准差从1变为0.01。

1
2
3
4
5
6
7
8
9
10
import torch

x = torch.randn(512)

for i in range(100):
a = torch.randn(512, 512) * 0.01 # torch.randn生成的是以0为均数、以1为标准差的正态分布
x = x @ a
if x.std() == 0:
print(i)
break

这回激活输出在第69个网络层中完全消失。
总而言之,如果初始化的权重过大,网络就不能很好地学习。当权重初始化过小时也会发生同样的情况。
我们能不能找到最佳的点?
有的!对于矩阵乘法而言,我们需要权值的标准差非常接近输入连接的数量的平方根,它在我们的例子中是 √512。

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
import math

x = torch.randn(512)

for i in range(100):
a = torch.randn(512, 512) * math.sqrt(1./512)
x = x @ a
if x.std() == 0 or torch.isnan(x.std()) == True:
print(i)
break

print(x.mean(), x.std())

看上去很不错,既没有爆炸也没有消失。但是如果我们给权值限定值域(激活函数一般是有界的),这种初始化方案就不那么理想了。

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
import math

x = torch.randn(512)

for i in range(100):
a = torch.randn(512, 512).uniform_(-1, 1) * math.sqrt(1./512)
x = x @ a
if x.std() == 0 or torch.isnan(x.std()) == True:
print(i)
break

print(x.mean(), x.std())

二、两种有效的权值初始化方案

1、Xavier初始化

当使用关于0对称且在[-1,1]内部有输出(如softsign和tanh)的激活函数时,适用Xavier初始化

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
import math

x = torch.randn(512)

def xaiver(m, h):
return torch.Tensor(m, h).uniform_(-1, 1) * math.sqrt(6./(m+h))

for i in range(100):
a = xaiver(512, 512)
x = torch.tanh(x @ a)

print(x.mean(), x.std())

2、Kaiming初始化

当使用非对称、非有界的激活函数时,适用Kaiming初始化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
import torch.nn.functional as F
import math

x = torch.randn(512)

def Kaiming(m, h):
return torch.randn(m, h) * math.sqrt(2./m)

for i in range(100):
a = Kaiming(512, 512)
x = torch.relu(x @ a)

print(x.mean(), x.std())

三、建模时加入权值初始化设置

1、单层初始化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch.nn as nn

class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__() # 调用父类的初始化函数
self.layer1 = nn.Linear(100, 200)
self.layer2 = nn.Linear(200, 10)
nn.init.xavier_uniform_(self.layer2.weight)
nn.init.constant_(self.layer2.bias, 0.1)

def forward(self, x):
x = self.layer1(x)
x = nn.relu(x)
x = self.layer2(x)
x = nn.softmax(x)
return x

print("MyNet:")
model = MyNet()
print(model)

2、模型初始化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch.nn as nn

class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__() # 调用父类的初始化函数
self.layer1 = nn.Linear(100, 200)
self.layer2 = nn.Linear(200, 10)
self._initialize_weights()

def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
nn.init.constant(m.weight, 1) # 常数值初始化
nn.init.constant(m.bias, 0)
elif isinstance(m, (nn.Conv2d, nn.Linear)):
# nn.init.uniform_(m.weight) # 均匀分布初始化
# nn.init.normal_(m.weight) # 正态分布初始化
# nn.init.eye_(m.weight) # 单位矩阵初始化
# nn.init.constant(m.weight, 1) # 常数值初始化
# nn.init.xavier_uniform_(m.weight) # xavier初始化
nn.init.kaiming_normal_(m.weight) # kaiming初始化
# nn.init.orthogonal(m.weight) # 正交初始化, 在RNN中经常使用的参数初始化方法
# nn.init.sparse_(m.weight, sparsity=0.1) # 稀疏矩阵初始化
if m.bias is not None:
nn.init.constant_(m.bias, 0)

def forward(self, x):
x = self.layer1(x)
x = nn.relu(x)
x = self.layer2(x)
x = nn.softmax(x)
return x

print("MyNet:")
model = MyNet()
print(model)
0%