super(XXX,self).init()的作用
在使用pytorch框架自定义网络结构时,总要有必不可少的第一句,super(XXX,self).__init__()。其实它的作用就是————对继承自父类的属性进行初始化,并且用父类的初始化方法初始化继承的属性。
一个简单的例子:
class Person():
def __init__(self,name,gender) -> None:
# 初始化name,gender属性
self.name = name
self.gender = gender
def printinfo(self):
print(self.name,self.gender)
class Stu(Person):
def __init__(self, name, gender,school): #继承父类Person属性
# 使用父类的初始化方法来初始化子类name和gender属性
super().__init__(name, gender)
self.school = school
# 对父类的printinfo方法进行重写
def printinfo(self):
print(self.name,self.gender,self.school)
if __name__ == '__main__':
pe = Person('Mike','male')
pe.printinfo()
stu = Stu('Lucy','female','Tsinghua')
stu.printinfo()
输出:
Mike male
Lucy female Tsinghua
当然,如果初始化的逻辑和父类不同,也可以自己初始化子类的属性,比如:
class Person():
def __init__(self,name,gender) -> None:
# 初始化name,gender属性
self.name = name
self.gender = gender
def printinfo(self):
print(self.name,self.gender)
class Stu(Person):
def __init__(self, name, gender,school):
super().__init__(name, gender)
self.school = school
#也可以对父类属性name,gender进行改写
self.name = name.upper() #lower()
self.gender = gender.title()
def printinfo(self):
print(self.name,self.gender,self.school)
if __name__ == '__main__':
stu = Stu('Lucy','female','Tsinghua')
stu.printinfo()
输出:
LUCY Female Tsinghua
然后,我们以一个简单的卷积神经网络模型lenet5为例:
class Lenet5(torch.nn.Module):
# constructor function
def __init__(self):
super(Lenet5,self).__init__()
self.layer1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5,5),padding=2)
self.layer2 = torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5,5),padding=0)
self.layer3 = torch.nn.Conv2d(in_channels=16, out_channels=120, kernel_size=(5,5),padding=0)
self.layer4 = torch.nn.Linear(120,84) #通常用于设置网络中的全连接层
self.layer5 = torch.nn.Linear(84,10)
这里super(Lenet5, self).__init__()的含义:子类Lenet5类继承父类nn.Module,super(Lenet5, self).__init__就是对继承自父类nn.Module的属性进行初始化。并且是用nn.Module的初始化方法来初始化继承的属性。就是用父类nn.Module的方法初始化子类Lenet5的属性。因为nn.Module的方法是pytorch框架中已经写好的,我们直接调用就可以,不然还要初始化各种权重和参数,太过复杂。
注意:在我们创建一个类后,通常会再创建一个 __init__()初始化方法,当我们创建一个类的实例时,该方法就会被自动执行。
例如,我们只是创建一个person实例,也不调用其它方法,他也会自动执行__init__()中的内容:
class Person():
def __init__(self,name,gender) -> None:
# 初始化name,gender属性
self.name = name
self.gender = gender
print('running')
def printinfo(self):
print(self.name,self.gender)
if __name__=='__main__':
person = Person('Mike','female')
输出:
running