Pytorch网络构造Tips

关于网络的parameters

Pytorch中,每一个网络继承于nn.Module类,当实例化之后,是通过维护一下8个字典来实现各种网络功能的:

1
2
3
4
5
6
7
8
_parameters
_buffers
_backward_hooks
_forward_hooks
_forward_pre_hooks
_state_dict_hooks
_load_state_dict_pre_hooks
_modules

在网络初始化__init__过程中,所有类内变量的定义都会通过__setattr__方法在__dict__中进行注册,而nn.Module重写了注册方法,将所有类内变量中,类型派生于Parameter的变量归属到_parameters字典中,这就解释了为什么使用一个list来存放网络的每一层会导致网络中的parameters为空。

此外,在获取参数时,nn.Module是通过遍历整个_modules字典来实现的,因此在定义时可以使用nn.ModuleList类型来替代list类型存放多个网络层。

修改一个实例化后的网络

在拿到别人预训练好的模型后,有时需要对网络进行修改,例如重置参数或替换网络中的某些模块重新训练,就需要对一个已经实例化之后的网络进行修改。一个很直接的思路就是,在类中找到对应的属性直接进行赋值替换即可。

首先是修改参数,这里有两种思路,由于修改参数不改变网络结构,所以可以直接利用pytorch提供的网络参数字典和导入方法来完成,即:

1
2
3
state_dict = model.state_dict()
# change some parameters in state_dict
model.load_state_dict(state_dict)

还有一种思路就是通过named_parameters()来进行对应修改:

1
2
3
for name, params in model.named_parameters():
if name == "需要修改参数的层的名字":
params.data=torch.zeros(params.data.shape)

以上是修改参数的方法,相对来说,修改参数比较简单,pytorch也提供了一些方法来进行便捷的参数替换。而对于修改模型,实现思路就比较直接了,在model中找到层所在的位置,对其进行替换即可,这里使用到了named_modules()方法。以开源大语言模型llama-2-13b为例,在模型加载完成后,先查看各层的命名:

1
2
3
4
5
for name,layer in model.named_modules():
print(name, layer)

# 如果对网络结构熟悉的话也可以只看名字
print([name for (name,module) in model.named_modules()])

然后在其中找到需要替换的层的名字,对其进行替换即可,以embeding层为例

1
2
from torch import nn
model.model.embed_tokens = nn.Linear(32000, 5120)

可以看出,embedding层被成功替换为了一个全连接层。其他层的修改以此类推。

Reference

[1]Gemfield, “详解Pytorch中的网络构造,” 知乎专栏, Jan. 04, 2019. https://zhuanlan.zhihu.com/p/53927068 (accessed Sep. 26, 2022).