详解 torch.triu:上三角矩阵的高效构造(中英双语)

news/2025/2/24 16:12:48

详解 torch.triu:上三角矩阵的高效构造

在深度学习和矩阵运算中,我们经常需要构造上三角矩阵(Upper Triangular Matrix),其中主对角线以下的元素全部设为 0。PyTorch 提供了一个高效的函数 torch.triu(),用于生成上三角矩阵,并允许我们灵活地调整对角线的偏移量。

在本篇博客中,我们将深入探讨:

  • torch.triu() 的基本用法
  • 第二个参数 diagonal 如何影响结果
  • torch.triu(all_ones, -1 * 2 + 1) 会生成什么
  • 代码示例与应用场景

1. torch.triu 的基本用法

1.1 语法

torch.triu(input, diagonal=0)
  • input:输入张量(必须是 2D 矩阵
  • diagonal:指定从哪条对角线开始保留元素:
    • diagonal=0(默认):保留主对角线及其上的元素
    • diagonal>0:向上偏移 diagonal
    • diagonal<0:向下偏移 diagonal

1.2 示例:默认 diagonal=0

import torch

A = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])

B = torch.triu(A)
print(B)

输出:

tensor([[1, 2, 3],
        [0, 5, 6],
        [0, 0, 9]])

解释

  • 主对角线(1, 5, 9)及其上方元素(2, 3, 6)被保留
  • 下三角部分(4, 7, 8)被置为 0

2. diagonal 参数的作用

2.1 diagonal > 0:向上偏移

B = torch.triu(A, diagonal=1)
print(B)

输出:

tensor([[0, 2, 3],
        [0, 0, 6],
        [0, 0, 0]])

解释

  • diagonal=1 表示从主对角线上方一行开始保留元素
  • 主对角线元素(1, 5, 9)被置为 0
  • 仅保留 2, 3, 6

2.2 diagonal < 0:向下偏移

B = torch.triu(A, diagonal=-1)
print(B)

输出:

tensor([[1, 2, 3],
        [4, 5, 6],
        [0, 8, 9]])

解释

  • diagonal=-1 表示从主对角线下一行开始保留元素
  • 主对角线以上元素仍保留
  • 下三角部分的 7 变成 0,但 4, 8 仍然保留

3. torch.triu(all_ones, -1 * 2 + 1) 解析

假设:

all_ones = torch.ones(5, 5)
B = torch.triu(all_ones, -1 * 2 + 1)
print(B)

让我们拆解 diagonal 参数:

  • -1 * 2 + 1 = -1
  • 这等价于 torch.triu(all_ones, -1)

all_ones 矩阵

tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]])

torch.triu(all_ones, -1) 结果:

tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1],
        [0, 0, 1, 1, 1],
        [0, 0, 0, 1, 1]])

解释

  • diagonal=-1 意味着主对角线及其上一行都保留
  • 低于 -1 的部分被置 0

4. torch.triu() 的应用场景

4.1 生成注意力掩码(Transformer)

在 Transformer 的自回归解码过程中,我们使用 torch.triu() 生成上三角掩码(mask),避免未来信息泄露:

seq_len = 5
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
print(mask)

输出(掩码矩阵):

tensor([[  0., -inf, -inf, -inf, -inf],
        [  0.,   0., -inf, -inf, -inf],
        [  0.,   0.,   0., -inf, -inf],
        [  0.,   0.,   0.,   0., -inf],
        [  0.,   0.,   0.,   0.,   0.]])

用于 softmax 计算,使模型只能关注当前及之前的 token


4.2 计算上三角矩阵的和

A = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])
upper_sum = torch.triu(A).sum()
print(upper_sum)  # 26

解释

  • 只保留 1, 2, 3, 5, 6, 9
  • 1 + 2 + 3 + 5 + 6 + 9 = 26

4.3 生成 Pascal 三角形

n = 5
pascal = torch.triu(torch.ones(n, n), diagonal=0)
for i in range(1, n):
    for j in range(1, i+1):
        pascal[i, j] = pascal[i-1, j-1] + pascal[i-1, j]
print(pascal)

输出:

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 2., 1., 0., 0.],
        [1., 3., 3., 1., 0.],
        [1., 4., 6., 4., 1.]])

5. 总结

  • torch.triu() 用于生成上三角矩阵,对角线以下的元素设为 0。
  • diagonal 控制保留的最小对角线
    • diagonal=0:默认保留主对角线及以上
    • diagonal>0:向上偏移,更多元素变 0
    • diagonal<0:向下偏移,更多元素被保留
  • torch.triu(all_ones, -1 * 2 + 1) 生成 diagonal=-1 的上三角矩阵
  • 常见应用
    • Transformer 掩码
    • 矩阵运算
    • 构造 Pascal 三角形

🚀 torch.triu()矩阵计算和深度学习中必不可少的函数,掌握它可以优化你的 PyTorch 代码!

Understanding torch.triu: Constructing Upper Triangular Matrices in PyTorch

In deep learning and matrix computations, upper triangular matrices are widely used, where all elements below the main diagonal are set to zero. PyTorch provides the efficient function torch.triu() to generate upper triangular matrices and allows flexible control over which diagonal to retain.

In this blog post, we will explore:

  • The basic usage of torch.triu()
  • How the second parameter diagonal affects the output
  • What torch.triu(all_ones, -1 * 2 + 1) generates
  • Practical examples and applications

1. Introduction to torch.triu

1.1 Syntax

torch.triu(input, diagonal=0)
  • input: The input tensor (must be a 2D matrix).
  • diagonal: Specifies which diagonal to retain:
    • diagonal=0 (default): Retains the main diagonal and elements above it.
    • diagonal>0: Shifts retention upwards.
    • diagonal<0: Shifts retention downwards.

1.2 Example: Default diagonal=0

import torch

A = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])

B = torch.triu(A)
print(B)

Output:

tensor([[1, 2, 3],
        [0, 5, 6],
        [0, 0, 9]])

Explanation:

  • The main diagonal (1, 5, 9) and elements above it (2, 3, 6) are retained.
  • The lower triangular part (4, 7, 8) is set to 0.

2. Understanding the diagonal Parameter

2.1 diagonal > 0: Shift upwards

B = torch.triu(A, diagonal=1)
print(B)

Output:

tensor([[0, 2, 3],
        [0, 0, 6],
        [0, 0, 0]])

Explanation:

  • diagonal=1 retains elements from one row above the main diagonal.
  • The main diagonal (1, 5, 9) is set to 0.
  • Only elements 2, 3, 6 are preserved.

2.2 diagonal < 0: Shift downwards

B = torch.triu(A, diagonal=-1)
print(B)

Output:

tensor([[1, 2, 3],
        [4, 5, 6],
        [0, 8, 9]])

Explanation:

  • diagonal=-1 retains elements from one row below the main diagonal.
  • The main diagonal and upper part remain unchanged.
  • The lowest element 7 is set to 0, but 4, 8 are retained.

3. What does torch.triu(all_ones, -1 * 2 + 1) generate?

Assume:

all_ones = torch.ones(5, 5)
B = torch.triu(all_ones, -1 * 2 + 1)
print(B)

Breaking down diagonal:

  • -1 * 2 + 1 = -1
  • Equivalent to torch.triu(all_ones, -1)

all_ones matrix:

tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]])

torch.triu(all_ones, -1) result:

tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1],
        [0, 0, 1, 1, 1],
        [0, 0, 0, 1, 1]])

Explanation:

  • diagonal=-1 means retaining the main diagonal and one row below it.
  • Elements below -1 are set to 0.

4. Applications of torch.triu()

4.1 Generating Attention Masks (Transformers)

In Transformers, upper triangular masks are used to prevent future information leakage during autoregressive decoding:

seq_len = 5
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
mask = mask.masked_fill(mask == 1, float('-inf'))
print(mask)

Output (Mask Matrix):

tensor([[  0., -inf, -inf, -inf, -inf],
        [  0.,   0., -inf, -inf, -inf],
        [  0.,   0.,   0., -inf, -inf],
        [  0.,   0.,   0.,   0., -inf],
        [  0.,   0.,   0.,   0.,   0.]])

This ensures that the model can only attend to current and past tokens.


4.2 Summing the Upper Triangular Matrix

A = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])
upper_sum = torch.triu(A).sum()
print(upper_sum)  # 26

Explanation:

  • Retains only 1, 2, 3, 5, 6, 9
  • 1 + 2 + 3 + 5 + 6 + 9 = 26

4.3 Constructing Pascal’s Triangle

n = 5
pascal = torch.triu(torch.ones(n, n), diagonal=0)
for i in range(1, n):
    for j in range(1, i+1):
        pascal[i, j] = pascal[i-1, j-1] + pascal[i-1, j]
print(pascal)

Output:

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 2., 1., 0., 0.],
        [1., 3., 3., 1., 0.],
        [1., 4., 6., 4., 1.]])

5. Conclusion

  • torch.triu() constructs upper triangular matrices, setting elements below the specified diagonal to zero.
  • The diagonal parameter controls which diagonal to retain:
    • diagonal=0: Retains the main diagonal and above.
    • diagonal>0: Shifts upwards, removing more elements.
    • diagonal<0: Shifts downwards, keeping more elements.
  • torch.triu(all_ones, -1 * 2 + 1) generates an upper triangular matrix with diagonal=-1.
  • Common use cases:
    • Transformers attention masks
    • Matrix computations
    • Constructing Pascal’s triangle

🚀 torch.triu() is an essential function for matrix computations and deep learning, making PyTorch code more efficient and readable!

后记

2025年2月23日14点50分于上海,在GPT4o大模型辅助下完成。


http://www.niftyadmin.cn/n/5864564.html

相关文章

*PyCharm 安装教程

PyCharm 安装教程&#xff0c;适用于 Windows、macOS 和 Linux 系统&#xff1a; 1. 下载 PyCharm 官网地址&#xff1a;https://www.jetbrains.com/pycharm/版本选择&#xff1a; Community&#xff08;社区版&#xff09;&#xff1a;免费&#xff0c;适合基础 Python 开发…

在聚类算法的领域特定语言(DSL)中添加一个度量矩阵组件

以下是一个详细的步骤和示例代码&#xff0c;用于在聚类算法的领域特定语言&#xff08;DSL&#xff09;中添加一个度量矩阵组件&#xff0c;同时满足处理数据集能达到完美聚类且改进后查询次数少于改进前的要求。 整体思路 定义DSL和原聚类算法&#xff1a;首先&#xff0c;…

【拥抱AI】GPT Researcher 源码试跑成功的心得与总结

一、引言 在人工智能领域&#xff0c;自然语言处理&#xff08;NLP&#xff09;技术的发展日新月异。GPT Researcher 是一个基于大型语言模型&#xff08;LLM&#xff09;的开源研究工具&#xff0c;旨在帮助用户快速生成高质量的研究报告。通过自动化的方式&#xff0c;它能够…

pycharm 创建数据库 以及增删改查

一&#xff0c;数据库 1&#xff0c;介绍&#xff1a; 数据库&#xff08;Database&#xff09;是一个有组织的数据集合&#xff0c;它通常用于存储和管理电子化的信息。这些数据可以是结构化的&#xff0c;如表格中的行和列&#xff0c;也可以是非结构化的&#xff0c;如文本…

深度剖析 C 语言函数递归:原理、应用与优化

在 C 语言的函数世界里&#xff0c;递归是一个独特且强大的概念。它不仅仅是函数调用自身这么简单&#xff0c;背后还蕴含着丰富的思想和广泛的应用。今天&#xff0c;让我们跟随这份课件&#xff0c;深入探索函数递归的奥秘。 一、递归基础&#xff1a;概念与思想 递归是一种…

ubuntu系统 pycharm 卡死了,我用资源监视器将其杀死后,再打开就变成了直接卡死 且在点击Quit Windows无法关闭,只能再次杀死

1. 问题分析&#xff1a; ubuntu系统中 pycharm意外卡死了&#xff0c;我用资源监视器将其杀死后&#xff0c;再打开就变成了直接卡死 且在点击Quit Windows无法关闭此时&#xff0c;只能通过再次杀死Java进程来关掉&#xff0c;但是关掉之后&#xff0c;再打开还是卡死。我必…

C语言番外篇(3)------------>break、continue

看到我的封面图的时候&#xff0c;部分读者可能认为这和编程有什么关系呢&#xff1f; 实际上这个三个人指的是本篇文章有三个部分组成。 在之前的博客中我们提及到了while循环和for循环&#xff0c;在这里面我们学习了它们的基本语法。今天我们要提及的是关于while循环和for…

MyBatis Plus扩展功能

一、代码生成器 二、逻辑删除 三、枚举处理器 像状态字段我们一般会定义一个枚举&#xff0c;做业务判断的时候就可以直接基于枚举做比较。但是我们数据库采用的是int类型&#xff0c;对应的PO也是Integer。因此业务操作时必须手动把枚举与Integer转换&#xff0c;非常麻烦。 …