CCF-CSP真题《202305-2 矩阵运算》思路+python,c++满分题解

想查看其他题的真题及题解的同学可以前往查看:CCF-CSP真题附题解大全

试题编号:202305-2
试题名称:矩阵运算
时间限制:5.0s
内存限制:512.0MB
问题描述:

题目背景

Softmax(Q×KTd)×V 是 Transformer 中注意力模块的核心算式,其中 Q、K 和 V 均是 n 行 d 列的矩阵,KT 表示矩阵 K 的转置,× 表示矩阵乘法。

问题描述

为了方便计算,顿顿同学将 Softmax 简化为了点乘一个大小为 n 的一维向量 W:
(W⋅(Q×KT))×V
点乘即对应位相乘,记 W(i) 为向量 W 的第 i 个元素,即将 (Q×KT) 第 i 行中的每个元素都与 W(i) 相乘。

现给出矩阵 Q、K 和 V 和向量 W,试计算顿顿按简化的算式计算的结果。

输入格式

从标准输入读入数据。

输入的第一行包含空格分隔的两个正整数 n 和 d,表示矩阵的大小。

接下来依次输入矩阵 Q、K 和 V。每个矩阵输入 n 行,每行包含空格分隔的 d 个整数,其中第 i 行的第 j 个数对应矩阵的第 i 行、第 j 列。

最后一行输入 n 个整数,表示向量 W。

输出格式

输出到标准输出中。

输出共 n 行,每行包含空格分隔的 d 个整数,表示计算的结果。

样例输入

3 2
1 2
3 4
5 6
10 10
-20 -20
30 30
6 5
4 3
2 1
4 0 -5

样例输出

480 240
0 0
-2200 -1100

子任务

70 的测试数据满足:n≤100 且 d≤10;输入矩阵、向量中的元素均为整数,且绝对值均不超过 30。

全部的测试数据满足:n≤104 且 d≤20;输入矩阵、向量中的元素均为整数,且绝对值均不超过 1000。

提示

请谨慎评估矩阵乘法运算后的数值范围,并使用适当数据类型存储矩阵中的整数。

真题来源:矩阵运算

 感兴趣的同学可以如此编码进去进行练习提交

思路讲解:

这道题也不难,再纸上推一下规律就能找到循环去计算的规律。这道题的重点在于时间复杂度,如果先算QK矩阵相乘,会得到n * n的矩阵,会显示超时,所以要先算后面两个矩阵,时间复杂度是可以过的。

c++满分题解:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 10010, D = 30;
LL tmp[D][D], ans[N][N];
int n, d;
int Q[N][D], K[N][D], V[N][D], W[N];
int main()
{
    cin >> n >> d;
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= d; j ++)
            cin >> Q[i][j];
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= d; j ++)
            cin >> K[i][j];
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= d; j ++)
            cin >> V[i][j];
    for (int i = 1; i <= n; i ++) cin >> W[i];
    
	// 计算 K的转置 * V = tmp
    for (int i = 1; i <= d; i ++)
        for (int j = 1; j <= d; j ++)
            for (int k = 1; k <= n; k ++)
                tmp[i][j] += K[k][i] * V[k][j];
                
    // 计算 Q * tmp = ans
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= d; j ++)
        {
            for (int k = 1; k <= d; k ++)
                ans[i][j] += Q[i][k] * tmp[k][j];
            ans[i][j] *= (LL) W[i];
        }
        
    for (int i = 1; i <= n; i ++)
    {
        for (int j = 1; j <= d; j ++)
            cout << ans[i][j] << " ";
        cout << endl;
    }
    return 0;
}

 运行结果:

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年7月29日
下一篇 2023年7月29日

相关推荐