Optimizing a MatMul with CuTe C++ (Part 1)
If you like CUDA kernels, you may be aware of the canonical blog post by Simon Boehm in which he iteratively optimizes a matrix multiplication kernel. In this post, I implement these kernels using CuTe, a header-only library with useful primitives for expressing tensor layouts and indexing. The CuTe docs are great, but I struggled to understand the inner workings and design choices of CuTe until I implemented Simon’s kernels with it. Reproducing Simon’s kernels down to the same access/computation pattern, but using CuTe’s idioms, helped me deeply understand the framework, and this post is an attempt to distill and convey that.
My goal with this article is not to explain the matmuls themselves; Simon’s article already does a great job of that, and this post assumes the reader is familiar with it. Rather, I will go through a few of the most important optimizations and show how they would be implemented using CuTe, which is sufficient to showcase the framework and learn its core concepts.
In this first post we will write Simon’s 2D blocktiling kernel with CuTe. It’s the first relatively complex kernel and a great starting point.
Preliminaries
The goal with each of these kernels is to perform a single-precision general matrix multiply (SGEMM) with the following operands:
- A: (M, K)
- B: (K, N)
The result is:
- C = alpha * (A @ B) + beta * C
- C is shaped (M, N); alpha and beta are scalar constants
2D Blocktiling
It’s worth a quick review of the kernel we are trying to implement with CuTe. The 2D blocktiling kernel can be summarized as follows.
Computation pattern
| Unit of parallelism | Responsibility |
|---|---|
| block | Calculates a (BM, BN) blocktile of C using a (BM, BK) blocktile from A and a (BK, BN) blocktile from B |
| thread | Calculates a (TM, TN) threadtile of the C blocktile via a matmul of (TM, BK) @ (BK, TN) threadtiles from A and B respectively. The matmul is done by summing outer products of (TM, 1) slices from A and (1, TN) slices from B |
For our purposes we can make problem dimensions and tile dimensions concrete.
- M = N = K = 4096 (though in figures this may be reduced)
- BM = BN = 64, BK = 8
- TM = TN = 8
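As a quick sanity check on these numbers, here is the launch arithmetic they imply, written as plain C++ constants. The k-prefixed names are illustrative and do not come from either kernel. The grid mapping assumes one block per C blocktile, with blockIdx.x selecting the column and blockIdx.y the row, as the reference kernel does.

```cpp
#include <cassert>
#include <cstddef>

// Problem and tile sizes from the text.
constexpr int M = 4096, N = 4096, K = 4096;
constexpr int BM = 64, BN = 64, BK = 8;
constexpr int TM = 8, TN = 8;

// Each thread computes one TM x TN threadtile of the BM x BN blocktile.
constexpr int kThreadsPerBlock = (BM * BN) / (TM * TN);
// One block per C blocktile: blockIdx.x picks the blocktile column, blockIdx.y the row.
constexpr int kGridX = N / BN;
constexpr int kGridY = M / BM;
// Shared memory holds one A blocktile and one B blocktile of floats.
constexpr std::size_t kSmemBytes = (BM * BK + BK * BN) * sizeof(float);
// Number of blocktiles along K that each block iterates over.
constexpr int kTilesAlongK = K / BK;

static_assert(M % BM == 0 && N % BN == 0 && K % BK == 0, "blocktiles must divide the problem");
static_assert(BM % TM == 0 && BN % TN == 0, "threadtiles must divide the blocktile");
```

With these sizes, each block runs 64 threads, the grid is 64 x 64 blocks, each block stages 4096 bytes of smem, and the loop over K runs 512 times.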
Memory access pattern
| Unit of parallelism | Responsibility |
|---|---|
| block | Each block loads blocktiles from A/B from global memory (gmem) to a shared memory buffer (smem) |
| thread | Each thread loads threadtile slices of the blocktiles from smem to register memory (rmem) |
1. All threads in a block cooperatively load the A and B blocktiles from gmem -> smem.
2. Each thread loads its threadtile slices of the A/B blocktiles from smem -> rmem.
3. Each thread does one iteration of outer product accumulation into its (TM, TN) threadtile.
4. Repeat steps 2-3, advancing the threadtile slices along the BK dimension.
5. Repeat steps 1-4, advancing the blocktiles along the K dimension.
6. Write the results from rmem -> gmem.
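The six steps above can be emulated on the CPU to make the schedule concrete. This is a hypothetical plain C++ sketch: the emu namespace and the function names are mine. It uses deliberately small tile sizes so it can be checked against a naive SGEMM. It copies the reference kernel’s indexing but runs the "threads" one after another, so the __syncthreads() barriers are implicit: each phase finishes for every thread before the next one begins.

```cpp
#include <cassert>
#include <vector>

// Hypothetical CPU emulation of the 2D blocktiling schedule for row-major A, B, C.
// Loops over "blocks" and "threads" stand in for the CUDA grid and thread block.
namespace emu {
// Deliberately small tiles so the emulation can be checked quickly.
constexpr int BM = 8, BN = 8, BK = 4, TM = 2, TN = 2;
constexpr int kThreads = (BM * BN) / (TM * TN);

// Reference: C = alpha * (A @ B) + beta * C, computed naively.
inline void sgemm_naive(int M, int N, int K, float alpha, const std::vector<float>& A,
                        const std::vector<float>& B, float beta, std::vector<float>& C) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      C[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
}

// Same result via blocktiles ("smem") and threadtiles ("rmem").
// Assumes M, N, K are multiples of BM, BN, BK.
inline void sgemm_blocktiled(int M, int N, int K, float alpha, const std::vector<float>& A,
                             const std::vector<float>& B, float beta, std::vector<float>& C) {
  for (int cRow = 0; cRow < M / BM; ++cRow)
    for (int cCol = 0; cCol < N / BN; ++cCol) {
      std::vector<float> results(kThreads * TM * TN, 0.f);  // per-thread accumulators
      float As[BM * BK], Bs[BK * BN];                       // blocktile staging buffers
      // Step 5: slide the blocktiles along K.
      for (int bk = 0; bk < K; bk += BK) {
        // Step 1: gmem -> smem copy of both blocktiles (done cooperatively on the GPU).
        for (int r = 0; r < BM; ++r)
          for (int c = 0; c < BK; ++c) As[r * BK + c] = A[(cRow * BM + r) * K + bk + c];
        for (int r = 0; r < BK; ++r)
          for (int c = 0; c < BN; ++c) Bs[r * BN + c] = B[(bk + r) * N + cCol * BN + c];
        // Steps 2-4: every thread accumulates BK outer products into its threadtile.
        for (int t = 0; t < kThreads; ++t) {
          const int threadRow = t / (BN / TN), threadCol = t % (BN / TN);
          float* acc = &results[t * TM * TN];
          for (int dot = 0; dot < BK; ++dot) {
            float regM[TM], regN[TN];
            for (int i = 0; i < TM; ++i) regM[i] = As[(threadRow * TM + i) * BK + dot];
            for (int j = 0; j < TN; ++j) regN[j] = Bs[dot * BN + threadCol * TN + j];
            for (int i = 0; i < TM; ++i)
              for (int j = 0; j < TN; ++j) acc[i * TN + j] += regM[i] * regN[j];
          }
        }
      }
      // Step 6: rmem -> gmem epilogue applying alpha and beta.
      for (int t = 0; t < kThreads; ++t) {
        const int threadRow = t / (BN / TN), threadCol = t % (BN / TN);
        for (int i = 0; i < TM; ++i)
          for (int j = 0; j < TN; ++j) {
            float& out = C[(cRow * BM + threadRow * TM + i) * N + cCol * BN + threadCol * TN + j];
            out = alpha * results[t * TM * TN + i * TN + j] + beta * out;
          }
      }
    }
}
}  // namespace emu
```

The test inputs are small integers, so both versions produce exactly the same floats despite summing in a different order.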
The following visualization shows the hierarchy and access patterns of this matmul.
Now, let’s implement it using CuTe! I will build up the implementation snippet by snippet, showing Simon’s reference kernel first and the CuTe version after it.
First, the reference kernel:
- Defines some bookkeeping variables used for indexing calculations later.
- Allocates shared memory to load the blocktiles into.
- Calculates which blocktile this block handles.
__global__ void sgemm2DBlocktiling(
int M, int N, int K,
float alpha, const float *A, const float *B, float beta, float *C
) {
const uint cRow = blockIdx.y;
const uint cCol = blockIdx.x;
const uint totalResultsBlocktile = BM * BN;
// A thread is responsible for calculating TM*TN elements in the blocktile
const uint numThreadsBlocktile = totalResultsBlocktile / (TM * TN);
// ResultsPerBlock / ResultsPerThread == ThreadsPerBlock
assert(numThreadsBlocktile == blockDim.x);
// BN/TN are the number of threads to span a column
const int threadCol = threadIdx.x % (BN / TN);
const int threadRow = threadIdx.x / (BN / TN);
// allocate space for the current blocktile in smem
__shared__ float As[BM * BK];
__shared__ float Bs[BK * BN];
// Move blocktile to beginning of A's row and B's column
A += cRow * BM * K;
B += cCol * BN;
C += cRow * BM * N + cCol * BN;
CuTe version:
__global__ void sgemm2DBlocktilingCute(
int M, int N, int K,
float alpha, const float *A, const float *B, float beta, float *C
) {
using namespace cute;
// create Tensors from data + Layout
auto A_gmem_layout = make_layout(make_shape(M, K), LayoutRight{});
auto B_gmem_layout = make_layout(make_shape(K, N), LayoutRight{});
auto C_gmem_layout = make_layout(make_shape(M, N), LayoutRight{});
Tensor mA = make_tensor(make_gmem_ptr(A), A_gmem_layout);
Tensor mB = make_tensor(make_gmem_ptr(B), B_gmem_layout);
Tensor mC = make_tensor(make_gmem_ptr(C), C_gmem_layout);
// blocktiles in gmem
auto A_blocktile = make_shape(BM, BK);
auto B_blocktile = make_shape(BK, BN);
auto C_blocktile = make_shape(BM, BN);
Tensor gA = local_tile(mA, A_blocktile, make_coord(blockIdx.y, _)); // (BM, BK, k)
Tensor gB = local_tile(mB, B_blocktile, make_coord(_, blockIdx.x)); // (BK, BN, k)
Tensor gC = local_tile(mC, C_blocktile, make_coord(blockIdx.y, blockIdx.x)); // (BM, BN)
// blocktiles in smem
auto A_smem_layout = make_layout(A_blocktile, LayoutRight{});
auto B_smem_layout = make_layout(B_blocktile, LayoutRight{});
__shared__ float A_smem[cosize_v<decltype(A_smem_layout)>];
__shared__ float B_smem[cosize_v<decltype(B_smem_layout)>];
Tensor sA = make_tensor(make_smem_ptr(A_smem), A_smem_layout);
Tensor sB = make_tensor(make_smem_ptr(B_smem), B_smem_layout);
A few things to note about the CuTe version:
- There are fewer indexing variables like cRow or threadCol in CuTe. Instead of indexing via complex expressions involving these variables, we defer that calculation to a Layout on a Tensor.
- We create a Tensor for every object we want to work with, i.e., the blocktiles, the smem space for the blocktiles, later the threadtiles, etc. For the most part, anything we index or copy should be a Tensor so we can take advantage of CuTe’s utilities.
Layout and Tensor
We’ve just seen two foundational CuTe concepts. First is the Layout.
A Layout is a pair (shape, stride), where shape and stride are n-tuples, often written as shape:stride.
Fundamentally, a Layout is just a function that maps an N-D logical coordinate to a 1D offset in a buffer, e.g., (0, 1) -> 1 for a row-major layout.
The shape defines the dimensions of the Layout’s coordinate system.
The stride defines how many elements you have to skip in the buffer to get to the next element along each dimension.
For example, a row major 2x4 matrix would have the layout (2, 4):(4, 1). Again, that layout defines a function which you can call:
// layout - (2, 4):(4, 1)
layout(0, 0) = 0; layout(0, 1) = 1; layout(0, 2) = 2; layout(0, 3) = 3;
layout(1, 0) = 4; layout(1, 1) = 5; layout(1, 2) = 6; layout(1, 3) = 7;
...
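Because a layout is just a function, the whole idea fits in a few lines: the offset is the dot product of the coordinate with the stride. Here is a minimal plain C++ model. Layout2D is hypothetical and far simpler than cute::Layout, which supports arbitrary rank, nested modes, and compile-time integers.

```cpp
#include <cassert>

// Minimal model of a rank-2 layout: a (shape, stride) pair that maps a 2D
// coordinate to a 1D offset. This is NOT cute::Layout, just the math behind it.
struct Layout2D {
  int shape0, shape1;    // extent of each mode
  int stride0, stride1;  // offset step when that coordinate increases by one

  // The layout "function": offset = row * stride0 + col * stride1.
  constexpr int operator()(int row, int col) const { return row * stride0 + col * stride1; }
  // Number of coordinates in the domain.
  constexpr int size() const { return shape0 * shape1; }
};

// (2, 4):(4, 1) is row major; (2, 4):(1, 2) is column major over the same shape.
constexpr Layout2D row_major{2, 4, 4, 1};
constexpr Layout2D col_major{2, 4, 1, 2};

static_assert(row_major(1, 2) == 6, "matches layout(1, 2) = 6 above");
static_assert(col_major(1, 2) == 5, "column major walks down columns first");
```

Swapping the stride from (4, 1) to (1, 2) turns the same (2, 4) shape into a column-major layout without touching anything else.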
A Tensor is just a data buffer combined with a layout that tells you how coordinates map to offsets in the buffer. So to index a Tensor, CuTe simply uses the Layout to calculate what offset to find the data at instead of us doing it manually.
Vanilla:
// declare buffer, we treat it as (M, N) row major
int buffer[M * N];
// index buffer[row, col]
int data = buffer[row * N + col];
With a layout:
// declare buffer, we treat it as (M, N) row major
int buffer[M * N];
auto layout = make_layout(make_shape(M, N), make_stride(N, 1));
// index buffer[row, col]
int data = buffer[layout(row, col)];
With a tensor:
// declare a tensor combining buffer + layout
// we treat it as (M, N) row major
int buffer[M * N];
auto tensor = make_tensor(
buffer,
make_layout(make_shape(M, N), make_stride(N, 1))
);
// index buffer[row, col]
int data = tensor(row, col);
This is a simple example, but with more complex indexing we will see that letting the Layout compute offsets, rather than calculating them manually, pays off.
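Because all of the indexing lives in the layout, reinterpreting a buffer only requires a different layout. The plain C++ model below shows a transposed view that moves no data. View2D is hypothetical, not a CuTe type, and it omits the shape for brevity.

```cpp
#include <cassert>

// Hypothetical model of a non-owning Tensor: a pointer plus a rank-2 layout
// (shape omitted). cute::Tensor is far more general; this only shows "buffer + layout".
struct View2D {
  const int* data;
  int stride0, stride1;
  // Indexing asks the layout for the offset, then reads the buffer there.
  int operator()(int row, int col) const { return data[row * stride0 + col * stride1]; }
};

// Transposing a view swaps its strides; the underlying buffer is untouched.
inline View2D transpose(View2D v) { return View2D{v.data, v.stride1, v.stride0}; }
```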
Now, let’s return to the CuTe snippet. Given the context on Tensors and Layouts, we can now understand this portion:
// create Tensors from data + Layout
auto A_gmem_layout = make_layout(make_shape(M, K), LayoutRight{});
auto B_gmem_layout = make_layout(make_shape(K, N), LayoutRight{});
auto C_gmem_layout = make_layout(make_shape(M, N), LayoutRight{});
Tensor mA = make_tensor(make_gmem_ptr(A), A_gmem_layout);
Tensor mB = make_tensor(make_gmem_ptr(B), B_gmem_layout);
Tensor mC = make_tensor(make_gmem_ptr(C), C_gmem_layout);
// blocktiles in smem
auto A_smem_layout = make_layout(A_blocktile, LayoutRight{});
auto B_smem_layout = make_layout(B_blocktile, LayoutRight{});
__shared__ float A_smem[cosize_v<decltype(A_smem_layout)>];
__shared__ float B_smem[cosize_v<decltype(B_smem_layout)>];
Tensor sA = make_tensor(make_smem_ptr(A_smem), A_smem_layout);
Tensor sB = make_tensor(make_smem_ptr(B_smem), B_smem_layout);
First we declare the layouts of our IO tensors with the shapes you’d expect and LayoutRight{}, which denotes a row-major stride in CuTe (for A, make_stride(K, 1) would be equivalent).
Then we declare mA, mB, mC as Tensors backed by the data buffers A/B/C, with the layouts passed as the second argument. make_gmem_ptr tags each pointer as pointing to global memory.
Next we declare our smem layouts with the shape of our blocktiles (since the blocktiles will be staged here) and a row major stride.
Finally we declare sA, sB as our smem Tensors, combining the layout with a shared buffer.
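Note cosize_v<LayoutT>, which evaluates at compile time to the codomain size of the layout type LayoutT. This means the layout must be static, i.e., fully known at compile time.
Since a layout is just a function, its codomain size is the smallest buffer size for which every offset the layout can produce is in bounds, i.e., the largest offset plus one. For compact layouts like our row-major smem layouts, that is simply the number of elements (BM * BK for sA).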
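The difference between size (the number of coordinates) and cosize (the buffer length needed) only appears for layouts with gaps. Here is a plain C++ sketch. It assumes non-negative strides, so the largest offset sits at the last coordinate. The padded (8, 8):(9, 1) layout is a hypothetical example, not one this kernel uses.

```cpp
#include <cassert>

// Plain C++ model of size vs. cosize for a rank-2 layout with non-negative strides.
// Not CuTe code; cute::cosize computes this for any layout.
struct Layout2D {
  int shape0, shape1, stride0, stride1;
  constexpr int operator()(int r, int c) const { return r * stride0 + c * stride1; }
  // Number of coordinates in the domain.
  constexpr int size() const { return shape0 * shape1; }
  // Largest offset plus one: the buffer length that keeps every offset in bounds.
  constexpr int cosize() const { return (*this)(shape0 - 1, shape1 - 1) + 1; }
};

// The row-major (BM, BK) = (64, 8) smem layout is compact: size == cosize.
constexpr Layout2D A_smem{64, 8, 8, 1};
// A padded (8, 8):(9, 1) layout skips one slot after each row, so it needs
// a larger buffer than its element count.
constexpr Layout2D padded{8, 8, 9, 1};
```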
Now let’s look at that last part of our CuTe snippet.
// blocktiles in gmem
auto A_blocktile = make_shape(BM, BK);
auto B_blocktile = make_shape(BK, BN);
auto C_blocktile = make_shape(BM, BN);
Tensor gA = local_tile(mA, A_blocktile, make_coord(blockIdx.y, _)); // (BM, BK, k)
Tensor gB = local_tile(mB, B_blocktile, make_coord(_, blockIdx.x)); // (BK, BN, k)
Tensor gC = local_tile(mC, C_blocktile, make_coord(blockIdx.y, blockIdx.x)); // (BM, BN)
This is where we determine what blocktiles this block actually needs to read.
Simon’s kernel achieves this by advancing the A/B/C pointers via a manual offset calculation to the start of the first blocktile the block needs to read (and then incrementing the pointers in the loop to get the next blocktiles).
With CuTe we don’t need to do that manually; we get a nice utility, local_tile, for pulling out the part of A/B/C that this block needs to see.
This is idiomatic CuTe: declare useful layouts on the data, then use the handy utilities to pull out this thread/block/warp’s workload.
The local_tile signature looks something like this: local_tile(tensor, shape, coordinate).
The function will divide the tensor into tiles of shape shape and let you yank out one of those tiles.
Formally, local_tile first reshapes the tensor into tiles, giving data shaped like ((tile_w, tile_h), (rest_w, rest_h)), and coordinate then indexes into the second “rest” mode to select a tile.
For C, we use make_coord(blockIdx.y, blockIdx.x) to grab the output blocktile this block should compute. For A/B we use _ in the coordinate to grab all blocktiles along the K dimension, which will be reduced over to calculate the output blocktile.
This results in a 3D shape for A/B where the first 2 dimensions represent the shape of one blocktile, and the last dimension represents the number of blocktiles along K. The resulting tensor shapes are mentioned in the comments.
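Under the hood, the tiling reduces to the index arithmetic that the reference kernel does by hand. Below is a plain C++ model of the offsets gA refers to, assuming the row-major (M, K) layout of A from above. gA_offset and reference_offset are hypothetical helpers, not CuTe functions.

```cpp
#include <cassert>

// Plain C++ model of gA = local_tile(mA, (BM, BK), (cRow, _)) for a row-major (M, K)
// matrix A: element (i, j, k) of gA is row cRow * BM + i, column k * BK + j of A.
constexpr int BM = 64, BK = 8;

inline int gA_offset(int K, int cRow, int i, int j, int k) {
  const int row = cRow * BM + i;  // blocktile row, fixed by blockIdx.y
  const int col = k * BK + j;     // the "_" mode: k selects the blocktile along K
  return row * K + col;           // LayoutRight, i.e. stride (K, 1)
}

// The reference kernel reaches the same element with pointer bumps:
// A += cRow * BM * K once, then A += BK after each of the k previous blocktiles.
inline int reference_offset(int K, int cRow, int i, int j, int k) {
  const int base = cRow * BM * K + k * BK;
  return base + i * K + j;  // matches A[(innerRowA + loadOffset) * K + innerColA]
}
```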
Great, so we’ve achieved everything our reference kernel does so far. Let’s see the next snippet from the reference kernel.
- Calculate each thread’s share of the blocktile it should load from gmem to smem.
- Declare register memory for storing computation results, and for staging parts of the threadtiles for computation.
// calculating the indices that this thread will load into SMEM
const uint innerRowA = threadIdx.x / BK;
const uint innerColA = threadIdx.x % BK;
// calculates the number of rows of As that are being loaded in a single step
// by a single block
const uint strideA = numThreadsBlocktile / BK;
const uint innerRowB = threadIdx.x / BN;
const uint innerColB = threadIdx.x % BN;
// for both As and Bs we want each load to span the full column-width, for
// better GMEM coalescing (as opposed to spanning full row-width and iterating
// across columns)
const uint strideB = numThreadsBlocktile / BN;
// allocate thread-local cache for results in registerfile
float threadResults[TM * TN] = {0.0};
// register caches for As and Bs
float regM[TM] = {0.0};
float regN[TN] = {0.0};
CuTe version:
// part of gA each thread loads to sA
Tensor gA_to_r = local_partition(gA, A_thread_layout, threadIdx.x);
Tensor sA_to_w = local_partition(sA, A_thread_layout, threadIdx.x);
// part of gB each thread loads to sB
Tensor gB_to_r = local_partition(gB, B_thread_layout, threadIdx.x);
Tensor sB_to_w = local_partition(sB, B_thread_layout, threadIdx.x);
// part of sA, sB each thread reads for computation
auto thread_row_C = threadIdx.x / (BN / TN);
auto thread_col_C = threadIdx.x % (BN / TN);
auto A_col_shape = make_shape(TM, 1);
auto B_row_shape = make_shape(1, TN);
Tensor sA_to_r = local_tile(
sA, A_col_shape, make_coord(thread_row_C, _)); // (TM, 1, BK)
Tensor sB_to_r = local_tile(
sB, B_row_shape, make_coord(_, thread_col_C)); // (1, TN, BK)
// part of gC each thread writes results
Tensor gC_to_w = local_tile(
gC, shape(C_thread_layout),
make_coord(thread_row_C, thread_col_C)); // (TM, TN)
// rmem
Tensor thread_results = make_tensor_like(gC_to_w);
clear(thread_results);
Tensor tmp_A = make_tensor_like<float>(make_layout(make_shape(TM)));
Tensor tmp_B = make_tensor_like<float>(make_layout(make_shape(TN)));
We achieve the 2 points above, but additionally we do the following up front, which Simon’s kernel does implicitly later via indexing expressions:
- Calculate which parts of the smem blocktiles the thread needs to read in order to calculate its threadtile.
- Calculate which parts of global memory this thread should write its results to. Simon’s kernel does this too; it just shows up later as indexing calculations, such as As[(threadRow * TM + i) * BK + dotIdx] in the hot loop or the C index in the final write-out.
Let’s take a closer look.
// thread layouts for partitioning the gmem -> smem copies;
// C_thread_layout only supplies the (TM, TN) threadtile shape
auto A_thread_layout = make_layout(make_shape(Int<8>{}, Int<8>{}), LayoutRight{});
auto B_thread_layout = make_layout(make_shape(Int<1>{}, Int<64>{}), LayoutRight{});
auto C_thread_layout = make_layout(make_shape(Int<8>{}, Int<8>{}));
// part of gA each thread loads to sA
Tensor gA_to_r = local_partition(gA, A_thread_layout, threadIdx.x);
Tensor sA_to_w = local_partition(sA, A_thread_layout, threadIdx.x);
// part of gB each thread loads to sB
Tensor gB_to_r = local_partition(gB, B_thread_layout, threadIdx.x);
Tensor sB_to_w = local_partition(sB, B_thread_layout, threadIdx.x);
The local_partition function is similar to the local_tile function, but rather than letting you pick out a tile by coordinate, it lets you pick out one element of each tile. Formally, local_partition(tensor, layout, index) tiles tensor according to shape(layout), producing a tensor of shape ((tile_w, tile_h), (rest_w, rest_h)). However, instead of indexing into mode 1, which would grab a contiguous tile, it indexes into mode 0. For example, local_partition(tensor, layout, 0) grabs the first element from each tile, giving you a tensor of shape (rest_w, rest_h).
You might have noticed we pass local_partition a layout rather than just a shape for the tile.
This is because local_partition uses the stride of the layout to determine how the 1D index we pass in maps to an N-D coordinate within the tile.
By default (a layout built from a shape alone) this mapping is column-major, like the visualization above.
But in the snippet we pass LayoutRight{}, so the mapping becomes row-major: consecutive threads load consecutive columns of the same row, in both A and B, which leads to coalesced gmem access.
Now that we understand local_partition, it is clear what this code does. For gA and sA, it tiles the tensors into tiles whose size equals the number of threads (64 in this case) and gives each thread one element in each of those tiles. So looping over gA_to_r, for example, yields all the elements a given thread is responsible for loading. We will see that this leads to a much cleaner loop later on. In the reference kernel, by contrast, we have to deal with a lot more manual bookkeeping, such as strideA and strideB, to know how much we can load at once and how to advance the iterator.
One more thing to note. As we saw before, the shape of gA is 3D: the first two modes are the blocktile shape and the last is the number of blocktiles along K. local_partition only partitions the first two modes and preserves the last, so gA_to_r is also 3D. With our sizes its shape is (8, 1, k), i.e., 8 elements per thread per blocktile.
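To see that local_partition reproduces the reference kernel’s load pattern, here is a plain C++ model of the coordinates each thread copies from the A blocktile. It assumes the row-major (8, 8) thread layout maps thread index tid to (tid / 8, tid % 8) within each tile. partition_coord and reference_coord are hypothetical helpers.

```cpp
#include <cassert>

// Plain C++ model of local_partition(gA, A_thread_layout, tid) for the (BM, BK) = (64, 8)
// A blocktile and the row-major (8, 8) thread layout. Not CuTe code.
constexpr int BM = 64, BK = 8;
constexpr int kThreadRows = 8, kThreadCols = 8;  // shape of A_thread_layout

// Row-major thread layout: thread tid sits at (tid / 8, tid % 8) inside every (8, 8) tile.
// Its partition has shape (BM / 8, BK / 8) = (8, 1); element r lives in the r-th tile down.
inline void partition_coord(int tid, int r, int& row, int& col) {
  row = r * kThreadRows + tid / kThreadCols;
  col = tid % kThreadCols;
}

// The reference kernel's bookkeeping for the same copy (64 threads, loadOffset = r * strideA).
inline void reference_coord(int tid, int r, int& row, int& col) {
  const int numThreads = 64;
  const int innerRowA = tid / BK, innerColA = tid % BK;
  const int strideA = numThreads / BK;  // rows of As covered per load step
  row = innerRowA + r * strideA;
  col = innerColA;
}
```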
Next up we create tensors for the parts of smem we load to rmem for computation.
// part of sA, sB each thread reads for computation
auto thread_row_C = threadIdx.x / (BN / TN);
auto thread_col_C = threadIdx.x % (BN / TN);
auto A_col_shape = make_shape(TM, 1);
auto B_row_shape = make_shape(1, TN);
Tensor sA_to_r = local_tile(sA, A_col_shape, make_coord(thread_row_C, _)); // (TM, 1, BK)
Tensor sB_to_r = local_tile(sB, B_row_shape, make_coord(_, thread_col_C)); // (1, TN, BK)
// part of gC each thread writes results
Tensor gC_to_w = local_tile(gC, shape(C_thread_layout), make_coord(thread_row_C, thread_col_C)); // (TM, TN)
Since we now want to grab contiguous blocks, we’re back to using local_tile.
Here we grab the slices of rows/cols from A/B that each thread needs to calculate its threadtile.
Additionally we grab the tile of gC the thread should write its results to.
Below is a visualization of the local_tile calls happening:
The coordinate for local_tile is supplied by thread_row_C and thread_col_C which are the coordinates of the threadtile for the current thread.
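Both local_tile calls reduce to the formulas that the reference kernel writes inline. The plain C++ sketch below assumes the row-major smem layouts declared earlier. tile_coord, sA_to_r, and sB_to_r here are hypothetical functions that return offsets; they are not CuTe tensors.

```cpp
#include <cassert>

// Plain C++ model of the per-thread smem -> rmem tiles. Helper names are hypothetical.
constexpr int BN = 64, BK = 8, TM = 8, TN = 8;

struct Coord { int row, col; };

// Generic rank-2 tiling: element (e0, e1) of tile (t0, t1), where tiles have shape
// (s0, s1), is coordinate (t0 * s0 + e0, t1 * s1 + e1) of the parent tensor.
constexpr Coord tile_coord(int s0, int s1, int t0, int t1, int e0, int e1) {
  return Coord{t0 * s0 + e0, t1 * s1 + e1};
}

// Row-major smem layouts: sA is (BM, BK):(BK, 1) and sB is (BK, BN):(BN, 1).
constexpr int sA_offset(Coord c) { return c.row * BK + c.col; }
constexpr int sB_offset(Coord c) { return c.row * BN + c.col; }

// sA_to_r(i, 0, dot): element (i, 0) of tile (thread_row_C, dot), tile shape (TM, 1).
constexpr int sA_to_r(int thread_row_C, int i, int dot) {
  return sA_offset(tile_coord(TM, 1, thread_row_C, dot, i, 0));
}

// sB_to_r(0, j, dot): element (0, j) of tile (dot, thread_col_C), tile shape (1, TN).
constexpr int sB_to_r(int thread_col_C, int j, int dot) {
  return sB_offset(tile_coord(1, TN, dot, thread_col_C, 0, j));
}
```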
Finally declaring the register memory in CuTe looks a little different but is functionally the same.
Tensor thread_results = make_tensor_like(gC_to_w);
clear(thread_results);
Tensor tmp_A = make_tensor_like<float>(make_layout(make_shape(TM)));
Tensor tmp_B = make_tensor_like<float>(make_layout(make_shape(TN)));
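The make_tensor_like function declares a buffer in rmem and returns a tensor backed by that buffer, with the same shape as the tensor or layout it is given. That shape must be static, so the buffer size is known at compile time. clear then zeroes thread_results, matching threadResults[TM * TN] = {0.0} in the reference kernel.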
Great, we have the entire preamble of our reference kernel implemented in CuTe. Now we can look at the interesting part, the hot loop. Here it is from our reference kernel. Conceptually this does:
1. Load blocktile gmem -> smem.
2. Load threadtile slices from A/B from smem -> rmem.
3. Accumulate one outer product into thread_results in rmem.
4. Repeat 2-3 BK times.
5. Repeat 1-4 sliding blocktiles along K.
// outer-most loop over block tiles
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
// populate the SMEM caches
for (uint loadOffset = 0; loadOffset < BM; loadOffset += strideA) {
As[(innerRowA + loadOffset) * BK + innerColA] =
A[(innerRowA + loadOffset) * K + innerColA];
}
for (uint loadOffset = 0; loadOffset < BK; loadOffset += strideB) {
Bs[(innerRowB + loadOffset) * BN + innerColB] =
B[(innerRowB + loadOffset) * N + innerColB];
}
__syncthreads();
// advance blocktile
A += BK; // move BK columns to right
B += BK * N; // move BK rows down
// calculate per-thread results
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
// block into registers
for (uint i = 0; i < TM; ++i) {
regM[i] = As[(threadRow * TM + i) * BK + dotIdx];
}
for (uint i = 0; i < TN; ++i) {
regN[i] = Bs[dotIdx * BN + threadCol * TN + i];
}
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
threadResults[resIdxM * TN + resIdxN] +=
regM[resIdxM] * regN[resIdxN];
}
}
}
__syncthreads();
}
CuTe version:
auto max_tile_idx = shape<2>(gA);
for (int tile_idx = 0; tile_idx < max_tile_idx; tile_idx++) {
// load tiles
Tensor gA_tile = gA_to_r(_, _, tile_idx);
CUTE_UNROLL
for (int i = 0; i < size(gA_tile); i++) {
sA_to_w(i) = gA_tile(i);
}
Tensor gB_tile = gB_to_r(_, _, tile_idx);
CUTE_UNROLL
for (int i = 0; i < size(gB_tile); i++) {
sB_to_w(i) = gB_tile(i);
}
__syncthreads();
// compute partial results for this tile
CUTE_UNROLL
for (int dot_idx = 0; dot_idx < shape<2>(sA_to_r); dot_idx++) {
// load row/col from smem to rmem
Tensor sA_col = sA_to_r(_, _, dot_idx);
CUTE_UNROLL
for (int i = 0; i < size(tmp_A); i++) {
tmp_A(i) = sA_col(i);
}
Tensor sB_row = sB_to_r(_, _, dot_idx);
CUTE_UNROLL
for (int i = 0; i < size(tmp_B); i++) {
tmp_B(i) = sB_row(i);
}
// outer product
CUTE_UNROLL
for (int i = 0; i < shape<0>(thread_results); i++) {
CUTE_UNROLL
for (int j = 0; j < shape<1>(thread_results); j++) {
thread_results(i, j) += tmp_A(i) * tmp_B(j);
}
}
}
__syncthreads();
}
A few things to take note of in the CuTe snippet:
- Tensors are indexed via tensor(i) rather than tensor[i] as you may be used to seeing.
- Loop bounds are determined by the shapes of tensors. For example, the outer loop index tile_idx ranges over 0..shape<2>(gA). Here shape(gA) = (BM, BK, k), where k = K / BK is the number of blocktiles along a row of A, so shape<2>(gA) = k and the outer loop iterates over the blocktile index as we expect. Since gA_to_r is just a local_partition of gA, it also has 3 dimensions, and we index into the last one to grab the correct blocktile. Simon’s kernel achieves this by keeping pointers to the start of the right tile and advancing them throughout the loop.
- Indexing is trivial since everything is handled by the tensors and layouts we defined earlier. In the entire CuTe snippet you will not see a single indexing calculation.
Let’s zoom in on one loop:
Tensor gA_tile = gA_to_r(_, _, tile_idx);
CUTE_UNROLL
for (int i = 0; i < size(gA_tile); i++) {
sA_to_w(i) = gA_tile(i);
}
Notice that rather than writing complex indexing expressions, we’ve defined two tensors, gA_tile and sA_to_w, which represent the source and destination of our copy and are of equal size. So all that’s left to do is iterate over the size of the tensor and issue each copy instruction. Above, I wrote it explicitly, but CuTe actually provides a utility for this:
Tensor gA_tile = gA_to_r(_, _, tile_idx);
copy(gA_tile, sA_to_w);
which expands to the code above, and even does things like vectorize loads if possible (more on this in a future post). Things are so easy now because we already did the heavy lifting of defining layouts for the src/dst and using local_tile/local_partition to divvy them up so each thread has a chunk of the copying to do.
We see the same pattern again when loading smem -> rmem:
Tensor sA_col = sA_to_r(_, _, dot_idx);
CUTE_UNROLL
for (int i = 0; i < size(tmp_A); i++) {
tmp_A(i) = sA_col(i);
}
For comparison here is what it looks like in the reference kernel:
for (uint i = 0; i < TM; ++i) {
regM[i] = As[(threadRow * TM + i) * BK + dotIdx];
}
But remember, the CuTe loop can be replaced by the copy call. I’ve just unsugared it for clarity.
The reference kernel’s indexing expression has two parts:
- (threadRow * TM + i) * BK: this selects the correct row of As (sA in CuTe), which increments with i since we are moving down a column of sA. In CuTe this is handled by the local_tile call that tiles sA.
- + dotIdx: this selects the column we are copying. sA_to_r has shape (TM, 1, BK), so indexing its last mode with dot_idx selects one (TM, 1) column slice.
And finally, the outer product loop is basically the same, but CuTe again gives us nicer indexing ergonomics: thread_results(i, j) instead of threadResults[resIdxM * TN + resIdxN].
Now we arrive at the last snippet: storing the results back to gmem.
// write out the results
for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) {
for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) {
C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN] =
alpha * threadResults[resIdxM * TN + resIdxN] +
beta * C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN];
}
}
CuTe version:
// write results back to gmem
CUTE_UNROLL
for (int i = 0; i < shape<0>(thread_results); i++) {
CUTE_UNROLL
for (int j = 0; j < shape<1>(thread_results); j++) {
gC_to_w(i, j) = alpha * thread_results(i, j)
+ beta * gC_to_w(i, j);
}
}
Again, much nicer indexing, courtesy of the up-front work we did defining layouts.
Instead of indexing directly into C, we index into gC_to_w, this thread’s tile of C produced by local_tile. Recalling how gC_to_w was defined, we can see that it is ultimately equivalent to the indexing expression from Simon’s kernel.
Tensor gC = local_tile(mC, C_blocktile, make_coord(blockIdx.y, blockIdx.x)); // (BM, BN)
Tensor gC_to_w = local_tile(gC, shape(C_thread_layout), make_coord(thread_row_C, thread_col_C)); // (TM, TN)
First we created gC, the (BM, BN) blocktile of C that this block needs to handle. In Simon’s kernel this manifests as:
C += cRow * BM * N + cCol * BN;
Then within gC we further tile into the threadtiles, to finally extract this thread’s portion of the blocktile. In Simon’s kernel that manifests as:
(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN
The threadRow * TM and threadCol * TN terms select this thread’s portion of the blocktile; this is what the second local_tile does for us. Finally, we loop over the threadtile with nice 2D indexing, which in Simon’s kernel is handled by the + resIdxM and + resIdxN terms.
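The equivalence can also be checked mechanically. The plain C++ sketch below assumes row-major C. gC_to_w_offset and reference_offset are hypothetical helpers. The split into a block base, a thread base, and an in-tile index mirrors the two local_tile calls and the final 2D loop.

```cpp
#include <cassert>

// Plain C++ model of gC_to_w = local_tile(local_tile(mC, (BM, BN), (cRow, cCol)),
// (TM, TN), (threadRow, threadCol)) for a row-major (M, N) matrix C.
// Helper names are hypothetical.
constexpr int BM = 64, BN = 64, TM = 8, TN = 8;

// Offset of element (i, j) of this thread's (TM, TN) tile, built from the two tilings.
inline int gC_to_w_offset(int N, int cRow, int cCol, int threadRow, int threadCol, int i, int j) {
  const int blockBase = cRow * BM * N + cCol * BN;            // first local_tile (gC)
  const int threadBase = (threadRow * TM) * N + threadCol * TN;  // second local_tile, inside gC
  return blockBase + threadBase + i * N + j;                   // 2D index into the threadtile
}

// The reference kernel's expression after C += cRow * BM * N + cCol * BN.
inline int reference_offset(int N, int cRow, int cCol, int threadRow, int threadCol, int i, int j) {
  return cRow * BM * N + cCol * BN + (threadRow * TM + i) * N + threadCol * TN + j;
}
```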
Conclusion (for now)
I’m releasing these blog posts as a multipart series. In the next post I’ll dive into some of the more optimized kernels in Simon’s article. Implementing Simon’s kernel 6 (the vectorized loads kernel), we will dig into some of the library internals to see how copy automatically vectorizes when possible. Implementing kernel 7 (warp tiling), we will see more features of CuTe, including advanced copying, and gain a deeper understanding of the framework by writing the exact same access/computation patterns as Simon’s kernel with CuTe idioms.
In this post we just scratched the surface. Even in the reference kernel, indexing patterns aren’t too complex (it gets much worse with warp tiling) so CuTe might seem overkill. The main takeaway here is that CuTe gives us a powerful language to describe shapes, layouts, and access patterns along with an ergonomic set of utilities to extract a given thread’s workload from that description. Fundamentally, that’s all we need to write matmuls.
The final kernel sources: