-
Notifications
You must be signed in to change notification settings - Fork 92
Expand file tree
/
Copy pathmatmul_tiled.cu
More file actions
186 lines (155 loc) · 5.27 KB
/
matmul_tiled.cu
File metadata and controls
186 lines (155 loc) · 5.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
// Integer ceiling division: how many N-sized chunks are needed to cover M
// (used to size the launch grid). Yields the true ceiling only for M >= 0 and
// N > 0. The argument N appears twice in the expansion, so it must not have
// side effects.
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
// Reference CPU SGEMM: C = A * B with row-major A (M x K), B (K x N) and
// C (M x N). Every entry of C is overwritten; each dot product is accumulated
// in float in increasing k order.
void sgemm_naive_cpu(float *A, float *B, float *C, int M, int N, int K)
{
    for (int row = 0; row < M; ++row)
    {
        const float *a_row = A + row * K;  // start of row `row` of A
        float *c_row = C + row * N;        // start of row `row` of C
        for (int col = 0; col < N; ++col)
        {
            float acc = 0.0f;
            for (int i = 0; i < K; ++i)
                acc += a_row[i] * B[i * N + col];
            c_row[col] = acc;
        }
    }
}
// 1D block-tiled SGEMM: C = A * B with row-major A (M x K), B (K x N), C (M x N).
//
// Launch layout:
//   grid  : blockIdx.y selects the BM-row tile of C, blockIdx.x the BN-column
//           tile, i.e. grid = (CEIL_DIV(N, BN), CEIL_DIV(M, BM)).
//   block : 1-D with blockDim.x == BM * BN / TM threads. Thread t owns column
//           t % BN of the output tile and the TM consecutive rows starting at
//           (t / BN) * TM.
// Shared memory: static, (BM * BK + BK * BN) floats.
// Preconditions (asserted): blockDim.x == BM * BK == BN * BK == BM * BN / TM,
// so each thread loads exactly one element of each shared tile per K-step.
// M, N and K need not be multiples of the tile sizes. Out-of-range A/B elements
// are loaded as 0, and out-of-range C entries are never written.
template <const int BM, const int BN, const int BK, const int TM>
__global__ void sgemm_blocktiling_1d_kernel(float *A, float *B, float *C, int M, int N, int K)
{
    static_assert(BM % TM == 0, "BM must be a multiple of TM");

    // the output tile that this threadblock computes
    const uint c_row = blockIdx.y;
    const uint c_col = blockIdx.x;

    // tiles of A (BM x BK) and B (BK x BN), both row-major
    __shared__ float A_shared[BM * BK];
    __shared__ float B_shared[BK * BN];

    // position of this thread's TM x 1 strip inside the output tile
    const uint thread_row = threadIdx.x / BN;
    const uint thread_col = threadIdx.x % BN;

    // one A element and one B element per thread per K-step, and the compute
    // phase covers the full BM x BN tile with BM * BN / TM threads
    assert(BM * BK == blockDim.x);
    assert(BN * BK == blockDim.x);
    assert((BM * BN) / TM == blockDim.x);

    // consecutive threads read consecutive columns -> coalesced GMEM loads
    const uint A_inner_row = threadIdx.x / BK;
    const uint A_inner_col = threadIdx.x % BK;
    const uint B_inner_row = threadIdx.x / BN;
    const uint B_inner_col = threadIdx.x % BN;

    // global row of A and global column of B this thread loads (fixed across K-steps)
    const uint a_row = c_row * BM + A_inner_row;
    const uint b_col = c_col * BN + B_inner_col;

    // per-thread accumulators for the TM outputs
    float thread_results[TM] = {0.0f};

    // outer loop over K in steps of BK
    for (uint bk_idx = 0; bk_idx < K; bk_idx += BK)
    {
        // Guard rows and columns separately. A flattened-index check would let
        // the A load run into the next row when K is not a multiple of BK.
        const uint a_col = bk_idx + A_inner_col;
        const uint b_row = bk_idx + B_inner_row;
        A_shared[A_inner_row * BK + A_inner_col] =
            (a_row < M && a_col < K) ? A[a_row * K + a_col] : 0.0f;
        B_shared[B_inner_row * BN + B_inner_col] =
            (b_row < K && b_col < N) ? B[b_row * N + b_col] : 0.0f;
        // tiles must be fully written before anyone reads them
        __syncthreads();

        for (uint dot_idx = 0; dot_idx < BK; dot_idx++)
        {
            // dot-product index outermost so one B value is reused for all TM rows
            const float tmp_b = B_shared[dot_idx * BN + thread_col];
            for (uint res_idx = 0; res_idx < TM; res_idx++)
            {
                thread_results[res_idx] += A_shared[(thread_row * TM + res_idx) * BK + dot_idx] * tmp_b;
            }
        }
        // everyone must finish reading before the next step overwrites the tiles
        __syncthreads();
    }

    // write back only the entries that lie inside C
    const uint c_global_col = c_col * BN + thread_col;
    for (uint res_idx = 0; res_idx < TM; res_idx++)
    {
        const uint c_global_row = c_row * BM + thread_row * TM + res_idx;
        if (c_global_row < M && c_global_col < N)
        {
            C[c_global_row * N + c_global_col] = thread_results[res_idx];
        }
    }
}
// Launch sgemm_blocktiling_1d_kernel on the default stream for C = A * B.
// A, B, C must be device pointers (A: m x k, B: k x n, C: m x n, row-major).
// The launch is asynchronous and no error is checked here; callers should
// query cudaGetLastError() afterwards.
void run_sgemm_blocktiling_1d(float *A, float *B, float *C, int m, int n, int k)
{
    // Tile shape: BM x BN output tile per block, BK-wide K-steps, TM outputs per thread.
    constexpr uint BM = 64, BN = 64, BK = 8, TM = 8;
    constexpr uint threads_per_block = (BM * BN) / TM;

    const dim3 grid(CEIL_DIV(n, BN), CEIL_DIV(m, BM)); // x: column tiles, y: row tiles
    const dim3 block(threads_per_block);
    sgemm_blocktiling_1d_kernel<BM, BN, BK, TM><<<grid, block>>>(A, B, C, m, n, k);
}
// Fill mat[0..N) with whole numbers in [0, 99] taken from rand() % 100.
// Uses the global rand() state, so results depend on the current seed.
void randomize_matrix(float *mat, int N)
{
    int idx = 0;
    while (idx < N)
    {
        const int value = rand() % 100;
        mat[idx++] = (float)value;
    }
}
// Print a diagnostic for a failed CUDA runtime call.
// Returns true when `err` is cudaSuccess, false otherwise.
static bool cuda_ok(cudaError_t err, const char *what)
{
    if (err == cudaSuccess)
    {
        return true;
    }
    fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(err));
    return false;
}

// Test driver: multiplies random m x k and k x n matrices on the GPU with the
// 1D block-tiled kernel and checks the result against sgemm_naive_cpu.
// Returns 0 on success, 1 on any CUDA error or result mismatch.
int main()
{
    const int m = 256;
    const int n = 256;
    const int k = 256;
    const size_t bytes_A = (size_t)m * k * sizeof(float);
    const size_t bytes_B = (size_t)k * n * sizeof(float);
    const size_t bytes_C = (size_t)m * n * sizeof(float);

    // Host buffers come from new[] and therefore must be released with delete[].
    float *A = new float[m * k];
    float *B = new float[k * n];
    float *C = new float[m * n];
    float *C_ref = new float[m * n]; // CPU reference result
    float *d_A = nullptr, *d_B = nullptr, *d_C = nullptr;

    randomize_matrix(A, m * k);
    randomize_matrix(B, k * n);

    // Short-circuit &&: each step runs only if all previous ones succeeded.
    bool ok = cuda_ok(cudaMalloc((void **)&d_A, bytes_A), "cudaMalloc(d_A)") &&
              cuda_ok(cudaMalloc((void **)&d_B, bytes_B), "cudaMalloc(d_B)") &&
              cuda_ok(cudaMalloc((void **)&d_C, bytes_C), "cudaMalloc(d_C)") &&
              cuda_ok(cudaMemcpy(d_A, A, bytes_A, cudaMemcpyHostToDevice), "cudaMemcpy(d_A)") &&
              cuda_ok(cudaMemcpy(d_B, B, bytes_B, cudaMemcpyHostToDevice), "cudaMemcpy(d_B)") &&
              // host C is uninitialised; zero the device output instead of copying it
              cuda_ok(cudaMemset(d_C, 0, bytes_C), "cudaMemset(d_C)");

    if (ok)
    {
        run_sgemm_blocktiling_1d(d_A, d_B, d_C, m, n, k);
        // Launch-config errors show up in cudaGetLastError(); faults during
        // execution surface at the blocking copy back to the host.
        ok = cuda_ok(cudaGetLastError(), "sgemm_blocktiling_1d_kernel launch") &&
             cuda_ok(cudaMemcpy(C, d_C, bytes_C, cudaMemcpyDeviceToHost), "cudaMemcpy(C)");
    }

    int status = ok ? 0 : 1;
    if (ok)
    {
        sgemm_naive_cpu(A, B, C_ref, m, n, k);
        // GPU and CPU may accumulate in a different order (and with or without
        // FMA), so compare with a tolerance rather than bit-exact equality.
        for (int i = 0; i < m * n; i++)
        {
            const float tol = 1e-3f + 1e-5f * fabsf(C_ref[i]);
            if (!(fabsf(C[i] - C_ref[i]) <= tol)) // negated form also rejects NaN
            {
                printf("Error: mismatch at index %d, expected %f, got %f\n", i, C_ref[i], C[i]);
                status = 1;
                break;
            }
        }
    }

    // Release everything on every path; cudaFree(nullptr) is a no-op.
    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
    delete[] A;
    delete[] B;
    delete[] C;
    delete[] C_ref;

    if (status == 0)
    {
        printf("Success!\n");
    }
    return status;
}