MMA, Matrix Multiply and Accumulate

MMA = Matrix Multiply and Accumulate

核心公式:

D = A × B + C
項目 說明
A 輸入矩陣(例如 activations)
B 權重矩陣(例如 weight)
C 累加基底(bias 或上一輪的部分和)
D 輸出矩陣

→ 矩陣相乘之後再累加,一次搞定


為什麼重要?Transformer 裡到處都是 MMA

Transformer 幾乎所有的計算都是矩陣乘法:

Self-Attention:
  Q = X · Wq    ← MMA
  K = X · Wk    ← MMA
  V = X · Wv    ← MMA
  scores = Q · Kᵀ  ← MMA
  out = scores · V  ← MMA

FFN(Feed-Forward Network):
  hidden = X · W1 + b1  ← MMA
  out    = hidden · W2 + b2  ← MMA

→ MMA 是整個深度學習的底層核心操作

參考:Matrix Multiplication Background — NVIDIA Docs


GEMM = General Matrix Multiplication

MMA 的完整形式叫做 GEMM

C = α × (A @ B) + β × C

α、β 是縮放係數,通常設為 1 和 0,退化成:

C = A @ B

GEMM 是 BLAS(Basic Linear Algebra Subprograms)的標準介面
→ cuBLAS、cuDNN 底層全部在跑 GEMM


Tensor Core 是什麼?

GPU 的一般 CUDA Core 每次只做一個純量乘法(scalar FMA)
Tensor Core 是專門加速 MMA 的硬體單元

CUDA Core:每次計算 a × b + c(一個數字)

Tensor Core:每次計算整塊矩陣的 MMA
  → 一個 Tensor Core 一個 clock 完成 4×4×4 的矩陣乘法
  → 64 個 FMA 同時執行

Tensor Core 最早在 Volta 架構(V100)引入
後續 Ampere(A100)、Hopper(H100)持續進化

參考:Programming Tensor Cores in CUDA 9 — NVIDIA Developer Blog


Tensor Core 怎麼做 MMA:Warp Tile

GPU 裡 32 個 thread 組成一個 warp,Tensor Core 操作是 warp 層級:

整個輸出矩陣
      ↓ 切成多個 tile
每個 CTA(thread block)負責一個 tile
      ↓ 再切
每個 warp 計算一個子 tile(例如 16×16×16)
      ↓
warp 內 32 個 thread 協作讀取 A、B 的 fragment
      ↓
Tensor Core 執行 MMA
      ↓
結果累加到輸出矩陣

關鍵:Tensor Core 需要 整個 warp 一起合作,不能單一 thread 獨立呼叫

參考:NVIDIA Tensor Core Programming — Lei Mao's Log Book


Mixed Precision MMA

Tensor Core 原生支援混合精度:

輸入 A、B:FP16(或 BF16)← 低精度,省記憶體
計算過程:FP32              ← 高精度,避免精度損失
輸出 D:FP32 或 FP16

為什麼不全用 FP16?
→ 連續累加時誤差會疊加,FP32 累加保住精度
→ 這就是「Mixed Precision Training」的核心原理

Arch 支援的 MMA 精度
Volta(V100) FP16 × FP16 → FP32
Ampere(A100) FP16、BF16、TF32、INT8
Hopper(H100) 以上 + FP8

參考:Programming Tensor Cores in CUDA 9 — NVIDIA Developer Blog


總結

MMA:D = A × B + C
  ↓
GEMM:C = α(A @ B) + βC(通用矩陣乘法)
  ↓
Tensor Core:硬體加速 MMA 的專用單元
  - 每個 clock 做 64 個 FMA
  - Warp 層級協作執行
  - Mixed Precision:FP16 輸入 × FP32 累加
  ↓
Transformer 裡每個線性變換、Attention 都在跑 MMA
Powered by Forestry.md