Python:torch.nn.Conv1d(), torch.nn.Conv2d()和torch.nn.Conv3d()函数理解

news/2024/7/23 16:28:15 标签: python, pytorch, 开发语言, 深度学习, Conv2d

Conv2dtorchnnConv3d_0">Python:torch.nn.Conv1d(), torch.nn.Conv2d()和torch.nn.Conv3d()函数理解

1. 函数参数

在torch中的卷积操作有三个,torch.nn.Conv1d(),torch.nn.Conv2d()还有torch.nn.Conv3d(),这是搭建网络过程中常用的网络层,为了用好卷积层,需要知道这些参数代表的含义。

这三种不同的卷积的输入参数是相同的,所以只看一个就可以。

def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',  # TODO: refine this type
        device=None,
        dtype=None

这里面的参数网上有很多说明,重点是怎么理解和使用。

2. 参数理解

这里面重点是in_channels参数,这个是代表数据输入的通道,很多说明这个通道是利用torch.nn.Conv2d处理图片数据来进行说明的,代表的是图片的通道数,然后面的两个参数对应着图片的长度和宽度。

下面是本人对这参数的理解过程:

  • 首先对于torch.nn.Conv函数,所接受的数据是可以带有batch维度的,也可以不带有batch维度,这就表示对于torch.nn.Conv2d可以接受的数据包括3维数据或者4维数据,

如:

conv2 = torch.nn.Conv2d(16, 120, 3, stride=2)
input2_3 = torch.randn(16, 5, 5)
output2_3 = conv2(input2_3)
print(output2_3.shape)

input2_4 = torch.randn(20, 16, 5, 5)
output2_4 = conv2(input2_4)
print(output2_4.shape)

该段得到的输出为:

torch.Size([120, 2, 2])
torch.Size([20, 120, 2, 2])

这是因为input2_4只是多了一个维度batch在第一个维度上,如果输入的数据是2维的或者5维的,就会提示如下的错误:指明只能接受3维的数据或者4维的数据.

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [20, 20, 16, 5, 5]

这其实就说明了根据自己数据维度选择合适的torch.nn.Conv, 例如,如果数据是2维的,那么就选择torch.nn.Conv1d,这个可以接收传入的数据维度可以是2维,或者是带有batch维度的3维数据。

之后需要注意的是in_channels参数其实对应的就是传入数据的第一个维度(不带有batch)或者带有batch的第二个维度,这个要和in_channels参数相同。

可以理解成这个in_channels就是表示了有多个卷积核在参与计算,那么剩下的维度正好就是卷积核的维度,

如对于torch.nn.Conv3d,传入的数据最少是4维数据,(不带有batch),那么第一维的数据应该等于in_channels,然后剩下三维正好的是卷积核的维度。
如:

conv3 = torch.nn.Conv3d(16, 120, 3, stride=2)
input3 = torch.randn(16, 5, 5, 5)
output3 = conv3(input3)
print(output3.shape)

会得到

torch.Size([120, 2, 2, 2])

这个卷积核是333,相当于有16个卷积核,每个卷积核在16维的数据上依次计算。

其他的作为输出影响的是数据的维度大小,但是out_channels又决定了输出数据的第一个维度,(不带有batch),就可以依然用这个方式思考。

针对后面几维数据的大小,由其他的参数决定,这个有公式可以计算,懒得算也可以直接打印输出看一下维度。


http://www.niftyadmin.cn/n/5069191.html

相关文章

python和go相互调用的两种方法

前言 Python 和 Go 语言是两种不同的编程语言,它们分别有自己的优势和适用场景。在一些项目中,由于团队内已有的技术栈或者某一部分业务的需求,可能需要 Python 和 Go 相互调用,以此来提升效率和性能。 性能优势 Go 通常比 Python 更高效&…

Java基础--泛型详解

一、背景 java推出泛型之前,集合元素类型可以是object类型,能够存储任意的数据类型对象,但是在使用过程中,如果不知道集合里面的各个元素的类型,在进行类型转换的时候就很容易引发ClassCastException异常。 二、概念 …

【小沐学Vulkan】Vulkan入门简介与开发环境配置

文章目录 1、简介2、下载和安装3、代码示例3.1 简单测试(glfwglm) 结语 1、简介 https://www.vulkan.org/ Vulkan 是新一代图形和计算 API,用于高效、跨平台访问 GPU。 Vulkan是一个跨平台的2D和3D绘图应用程序接口(API&#xff…

整体网络架构p22

1. 两次卷积,一次池化。得到一个三维特征图,然后让三维的特征图,三个值进行相乘拉成特征向量,把得到的结果需要靠全连接层。 带参数计算才算一层 算conv的个数FC全连接层就得到卷积神经网络的层数 FC:全连接层 2. 3.reset网络&a…

手机图片合成gif怎么操作?用这个网站试试

制作gif动图的工具越来越多,但是很多时候使用电脑并不方便,想要在手机上制作gif动图的时候应该怎么办呢?很简单,给大家分享一款无需下载手机浏览器就能操作的gif制作(https://www.gif.cn/)工具-GIF中文网&a…

[开源项目推荐]privateGPT使用体验和修改

文章目录 一.跑通简述二.解读ingest.py1.导入库和读取环境变量2.文档加载3.文档处理(调用文件加载函数和文本分割函数) 三.injest.py效果演示1.main函数解读2.测试 四.修改代码,使之适配多知识库需求1.修改配置文件:constants.py2…

Postman接口测试学习之常用断言

什么是断言? 断言——就是结果中的特定属性或值与预期做对比,如果一致,则用例通过,如果不一致,断言失败,用例失败。断言,是一个完整测试用例所不可或缺的一部分,没有断言的测试用例…

数据结构-顺序存储二叉树

文章目录 目录 文章目录 前言 一 . 什么是顺序存储二叉树 二 . 模拟实现 前序遍历 总结 前言 大家好,今天给大家讲一下顺序存储二叉树 一 . 什么是顺序存储二叉树 顺序存储二叉树是一种将二叉树的节点按照从上到下、从左到右的顺序存储在数组中的方法。具体来说,顺…