Conv2dtorchnnConv3d_0">Python:torch.nn.Conv1d(), torch.nn.Conv2d()和torch.nn.Conv3d()函数理解
1. 函数参数
在torch中的卷积操作有三个,torch.nn.Conv1d(),torch.nn.Conv2d()还有torch.nn.Conv3d(),这是搭建网络过程中常用的网络层,为了用好卷积层,需要知道这些参数代表的含义。
这三种不同的卷积的输入参数是相同的,所以只看一个就可以。
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: _size_2_t,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 0,
dilation: _size_2_t = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros', # TODO: refine this type
device=None,
dtype=None
这里面的参数网上有很多说明,重点是怎么理解和使用。
2. 参数理解
这里面重点是in_channels
参数,这个是代表数据输入的通道,很多说明这个通道是利用torch.nn.Conv2d
处理图片数据来进行说明的,代表的是图片的通道数,然后面的两个参数对应着图片的长度和宽度。
下面是本人对这参数的理解过程:
- 首先对于
torch.nn.Conv
函数,所接受的数据是可以带有batch
维度的,也可以不带有batch
维度,这就表示对于torch.nn.Conv2d
可以接受的数据包括3维数据或者4维数据,
如:
conv2 = torch.nn.Conv2d(16, 120, 3, stride=2)
input2_3 = torch.randn(16, 5, 5)
output2_3 = conv2(input2_3)
print(output2_3.shape)
input2_4 = torch.randn(20, 16, 5, 5)
output2_4 = conv2(input2_4)
print(output2_4.shape)
该段得到的输出为:
torch.Size([120, 2, 2])
torch.Size([20, 120, 2, 2])
这是因为input2_4
只是多了一个维度batch
在第一个维度上,如果输入的数据是2维的或者5维的,就会提示如下的错误:指明只能接受3维的数据或者4维的数据.
RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [20, 20, 16, 5, 5]
这其实就说明了根据自己数据维度选择合适的torch.nn.Conv
, 例如,如果数据是2维的,那么就选择torch.nn.Conv1d
,这个可以接收传入的数据维度可以是2维,或者是带有batch
维度的3维数据。
之后需要注意的是in_channels
参数其实对应的就是传入数据的第一个维度(不带有batch
)或者带有batch
的第二个维度,这个要和in_channels
参数相同。
可以理解成这个in_channels
就是表示了有多个卷积核在参与计算,那么剩下的维度正好就是卷积核的维度,
如对于torch.nn.Conv3d
,传入的数据最少是4维数据,(不带有batch
),那么第一维的数据应该等于in_channels
,然后剩下三维正好的是卷积核的维度。
如:
conv3 = torch.nn.Conv3d(16, 120, 3, stride=2)
input3 = torch.randn(16, 5, 5, 5)
output3 = conv3(input3)
print(output3.shape)
会得到
torch.Size([120, 2, 2, 2])
这个卷积核是333,相当于有16个卷积核,每个卷积核在16维的数据上依次计算。
其他的作为输出影响的是数据的维度大小,但是out_channels
又决定了输出数据的第一个维度,(不带有batch
),就可以依然用这个方式思考。
针对后面几维数据的大小,由其他的参数决定,这个有公式可以计算,懒得算也可以直接打印输出看一下维度。