新开一个专题来介绍一下矩阵计算相关的内容, 从最基本的算法,到Cutlass这些线性代数模版库, 特别是Layout代数相关的内容, 后面再逐渐细化到一些硬件实现访存优化和一些算子融合相关的话题, 准备工作闲暇时间有点空就补一点, 做个长期的专栏.
1. GEMM概述
1.1 GEMM定义
对于一个矩阵乘法, 我们定义如下:
如下图所示:
1.1.1 内积形式
因此我们可以构建一个最简单的算法
for (int i = 0; i < M; ++i)
for (int j = 0; j < N; ++j)
for (int k = 0; k < K; ++k)
C[i][j] += A[i][k] * B[k][j];
这种乘法是也被称为矩阵乘法的内积形式
我们可以注意整个过程中随着循环, B矩阵的乘法空间局部性很差,存在多次访问, 因此我们尽量需要缓存一些数据来避免缓存颠簸(cache thrashing)
1.1.2 外积形式
换一种思路, 如果我们按照如下方法构建乘法
其中
即我们可以把K维度放在最外面, 这样A和B矩阵都可以按照列和行整个一块的读取.
for (int k = 0; k < K; ++k) //dim-k at outer loop
//outer-product for C_i
for (int i = 0; i < M; ++i)
for (int j = 0; j < N; ++j)
C[i][j] += A[i][k] * B[k][j];
2. 分块矩阵乘法
2.1 分块乘法的原因
如果矩阵维度(M,K,N)规模很大时, 需要大量的片上缓存数据结果,计算效率很低. 矩阵分块乘法(Block Matrix Multiplication)在计算机科学和数学领域中是一种非常实用的技术,尤其当处理大规模矩阵时,它提供了几个关键优势:
内存限制:对于非常大的矩阵,可能无法一次性将整个矩阵加载到内存中。通过将大矩阵分成较小的块(子矩阵),可以只加载一部分到内存中进行计算,然后交换出其他部分,从而管理有限的内存资源。 并行计算:现代处理器和计算架构,如多核CPU、GPU以及分布式系统,都支持并行计算。矩阵分块乘法允许将矩阵乘法任务分解成更小的独立任务,这些任务可以在不同的处理器核心或节点上同时进行,从而加速计算过程。 缓存优化:计算机的缓存层次结构意味着访问连续或接近的数据比访问随机分布的数据更快。通过适当地分块矩阵,可以确保计算过程中频繁访问的数据位于缓存中,减少缓存缺失,提高计算效率。 易于实现:从编程的角度来看,分块乘法往往更容易理解和实现,尤其是当涉及到并行编程时。它提供了一种直观的方法来划分工作负载和数据。
2.2 分块乘法
通常我们可以把一个矩阵分成多个块, 例如
我们可以将其划分为 4个块
分块后的矩阵记为
分块矩阵乘法如下所示:
划分不一定需要完全等间隔, 只需要满足子矩阵乘法规则即可, 例如
更一般的来讲, 如下图所示:
给定一个的矩阵切分为行列
另一个的矩阵切分为行列,
则它们的乘积计算如下:
相应的乘法循环代码如下
for (int m = 0; m < M; m += Mtile) // iterate over M dimension
for (int n = 0; n < N; n += Ntile) // iterate over N dimension
for (int k = 0; k < K; ++k)
for (int i = 0; i < Mtile; ++i) // compute one tile
for (int j = 0; j < Ntile; ++j) {
int row = m + i;
int col = n + j;
C[row][col] += A[row][k] * B[k][col];
}
3. 硬件的视角看GEMM
Standford的CS217课程是一个很好的参考资料.
3.1 分块乘法的内存层次架构
分块矩阵乘法如下所示, 通过将矩阵分块拆分,能够在处理器的Cache和寄存器内存放进行快速计算.计算完成后写回主存.
首先所有的数据都在主内存中,如下图所示:
然后在分块加载到L2Cache中, 完成分块子任务
在进入计算核内部时,将向量Block进一步拆分成更小的Tile进行计算
整个计算过程的访存层次化结构如下:
熟悉这个过程就对英伟达的矩阵乘法流程清楚了
它也通过多次将矩阵逐渐分为更小的Tile进行计算, 原理是相通的
3.2 Memory Layout
我们注意到在计算的过程中, 为了保证访存的连续性, A矩阵需要按行排序, 而B矩阵需要按列排序
而在分块的内部进行块乘法时,访存顺序变为列优先/行优先的方式, 因此矩阵的Layout变成了一种Z字排列, 如下所示:
4. GEMM in action
这一节我们给两个手工实现的Naive GEMM和Block GEMM的例子来解释矩阵分块乘法的原理和性能影响, 可以看到性能差距接近53倍. 按照测试的A10 GPU峰值FP32算力31TFFLOPS来算, 最朴素的算法由于访存效率的问题, 浮点算力仅为峰值的1%
# ./naive
AveragePerformance 0.2336 Tflops
# ./block
AveragePerformance 10.7669 Tflops
在下一篇文章我们再来详细谈谈矩阵乘法优化相关的内容, 再到第三篇文章引出Cutlass.
4.1 Naive GEMM
最简单的矩阵乘法如下:
#define OFFSET(row, col, stride) ((row) * (stride) + (col))
__global__ void basic_gemm(
float * A, float * B, float * C,
const int M, const int N, const int K) {
int _x = blockIdx.x * blockDim.x + threadIdx.x;
int _y = blockIdx.y * blockDim.y + threadIdx.y;
if (_x < M && _y < N) {
float sum = 0.0;
for (int k = 0; k < K; k++) {
sum +=A[OFFSET(_x, k, K)] * B[OFFSET(k , _y, N)];
}
C[OFFSET(_x, _y, N)] = sum;
}
}
在A10上测试其FLOS大概仅有233GFlops
int main() {
const int M = 4096;
const int K = 1024;
const int N = 4096;
const int ITER = 100;
dim3 gridDim(ceil(M/32), ceil(N/32), 1);
dim3 blockDim(32, 32, 1);
float *d_a, *d_b, *d_c ;
cudaMalloc(&d_a, M * K * sizeof(float));
cudaMalloc(&d_b, K * N * sizeof(float));
cudaMalloc(&d_c, M * N * sizeof(float));
cudaEvent_t start, end;
cudaEventCreate(&start);
cudaEventCreate(&end);
cudaEventRecord(start);
for (int i = 0; i < ITER; i++)
basic_gemm<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, N, K);
cudaEventRecord(end);
cudaEventSynchronize(end);
float msec;
cudaEventElapsedTime(&msec, start, end);
long workload = long(M) * N * K * 2 * ITER;
double avg_Tflops = ((double)workload / 1e12 ) / (double(msec)/ 1e3);
printf("AveragePerformance %6.4lf Tflops\n",avg_Tflops);
cudaFree(d_a);
cudaFree(d_b);
cudaFree(d_c);
}
4.2 Block GEMM
代码来自于Anthropic Performance & Kernel团队的siboehm 《How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog》[1] 相关的代码画了一个容易理解的示意图
详细的内容我们将在下一篇讲述.
__global__ void block2d_gemm(const float *A, const float *B, float *C,
int M, int N, int K) {
const int BM = 128;
const int BN = 128;
const int BK = 8;
const int TM = 8;
const int TN = 8;
const uint cRow = blockIdx.y;
const uint cCol = blockIdx.x;
const uint totalResultsBlocktile = BM * BN;
// A thread is responsible for calculating TM*TN elements in the blocktile
const uint numThreadsBlocktile = totalResultsBlocktile / (TM * TN);
// BN/TN are the number of threads to span a column
const int threadCol = threadIdx.x % (BN / TN);
const int threadRow = threadIdx.x / (BN / TN);
// allocate space for the current blocktile in smem
__shared__ float As[BM * BK];
__shared__ float Bs[BK * BN];
// Move blocktile to beginning of A's row and B's column
A += cRow * BM * K;
B += cCol * BN;
C += cRow * BM * N + cCol * BN;
// calculating the indices that this thread will load into SMEM
const uint innerRowA = threadIdx.x / BK;
const uint innerColA = threadIdx.x % BK;
// calculates the number of rows of As that are being loaded in a single step
// by a single block
const uint strideA = numThreadsBlocktile / BK;
const uint innerRowB = threadIdx.x / BN;
const uint innerColB = threadIdx.x % BN;
// for both As and Bs we want each load to span the full column-width, for
// better GMEM coalescing (as opposed to spanning full row-width and iterating
// across columns)
const uint strideB = numThreadsBlocktile / BN;
// allocate thread-local cache for results in registerfile
float threadResults[TM * TN] = {0.0};
// register caches for As and Bs
float regM[TM] = {0.0};
float regN[TN] = {0.0};
// outer-most loop over block tiles
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
// populate the SMEM caches
for (uint loadOffset = 0; loadOffset < BM; loadOffset += strideA) {
As[(innerRowA + loadOffset) * BK + innerColA] =
A[(innerRowA + loadOffset) * K + innerColA];
}
for (uint loadOffset = 0; loadOffset < BK; loadOffset += strideB) {
Bs[(innerRowB + loadOffset) * BN + innerColB] =
B[(innerRowB + loadOffset) * N + innerColB];
}
__syncthreads();
// advance blocktile
A += BK; // move BK columns to right
B += BK * N; // move BK rows down
// calculate per-thread results
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
// block into registers
for (uint i = 0; i < TM; ++i) {
regM[i] = As[(threadRow * TM + i) * BK + dotIdx];
}
for (uint i = 0; i < TN; ++i) {
regN[i] = Bs[dotIdx * BN + threadCol * TN + i];
}
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
threadResults[resIdxM * TN + resIdxN] +=
regM[resIdxM] * regN[resIdxN];
}
}
}
__syncthreads();
}
// write out the results
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN] =
threadResults[resIdxM * TN + resIdxN] +
C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN];
}
}
}
How to Optimize a CUDA Matmul Kernel for cuBLAS-like Performance: a Worklog: https://siboehm.com/articles/22/CUDA-MMM