Pytorch网络构造Tips
关于网络的parameters
Pytorch中,每一个网络继承于nn.Module
类,当实例化之后,是通过维护一下8个字典来实现各种网络功能的:
1 | _parameters |
在网络初始化__init__
过程中,所有类内变量的定义都会通过__setattr__
方法在__dict__
中进行注册,而nn.Module
重写了注册方法,将所有类内变量中,类型派生于Parameter
的变量归属到_parameters
字典中,这就解释了为什么使用一个list
来存放网络的每一层会导致网络中的parameters
为空。
此外,在获取参数时,nn.Module
是通过遍历整个_modules
字典来实现的,因此在定义时可以使用nn.ModuleList
类型来替代list
类型存放多个网络层。
修改一个实例化后的网络
在拿到别人预训练好的模型后,有时需要对网络进行修改,例如重置参数或替换网络中的某些模块重新训练,就需要对一个已经实例化之后的网络进行修改。一个很直接的思路就是,在类中找到对应的属性直接进行赋值替换即可。
首先是修改参数,这里有两种思路,由于修改参数不改变网络结构,所以可以直接利用pytorch
提供的网络参数字典和导入方法来完成,即:
1 | state_dict = model.state_dict() |
还有一种思路就是通过named_parameters()
来进行对应修改:
1 | for name, params in model.named_parameters(): |
以上是修改参数的方法,相对来说,修改参数比较简单,pytorch
也提供了一些方法来进行便捷的参数替换。而对于修改模型,实现思路就比较直接了,在model
中找到层所在的位置,对其进行替换即可,这里使用到了named_modules()
方法。以开源大语言模型llama-2-13b
为例,在模型加载完成后,先查看各层的命名:
1 | for name,layer in model.named_modules(): |
然后在其中找到需要替换的层的名字,对其进行替换即可,以embeding
层为例
1 | from torch import nn |
可以看出,embedding
层被成功替换为了一个全连接层。其他层的修改以此类推。
Reference
[1]Gemfield, “详解Pytorch中的网络构造,” 知乎专栏, Jan. 04, 2019. https://zhuanlan.zhihu.com/p/53927068 (accessed Sep. 26, 2022).