Adamw thread coarsening kernel #753

Open · wants to merge 3 commits into master
59 changes: 54 additions & 5 deletions dev/cuda/adamw.cu
@@ -25,13 +25,14 @@ thread coarsening/ILP
#include <cuda_runtime.h>
#include "common.h"

#define COARSE_FACTOR 4  // number of parameters each thread processes in adamw_kernel3

// ----------------------------------------------------------------------------
// CPU code reference

void adamw_cpu(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, int t, long num_parameters, float learning_rate=1e-3, float beta1=0.9, float beta2=0.999, float eps=1e-8, float weight_decay=0.0) {
    // adapted from: train_gpt2.c

    for (int i = 0; i < num_parameters; i++) {
        float param = params_memory[i];
        float grad = grads_memory[i];
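
For reference, the update that both the CPU loop and the kernels below implement is standard AdamW (Adam with decoupled weight decay); g_t is the gradient, lr the learning rate, and lambda the weight_decay:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2$$
$$\hat{m}_t = m_t / (1-\beta_1^t), \qquad \hat{v}_t = v_t / (1-\beta_2^t)$$
$$\theta_t = \theta_{t-1} - \mathrm{lr}\left(\hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) + \lambda\,\theta_{t-1}\right)$$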
@@ -97,6 +98,41 @@ __global__ void adamw_kernel2(float* params_memory, const float* grads_memory, f
    params_memory[i] -= learning_rate * (m / (sqrtf(v) + eps) + weight_decay * params_memory[i]);
}

// kernel with thread coarsening: each thread updates COARSE_FACTOR parameters,
// trading grid width for more independent work per thread (ILP)
__global__ void adamw_kernel3(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
                              float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    long tid = blockIdx.x * blockDim.x + threadIdx.x;
    long stride = blockDim.x * gridDim.x;

    // the outer step must be stride * COARSE_FACTOR: the inner loop already covers
    // COARSE_FACTOR stride-spaced elements, so a plain += stride would revisit them
    for (long i = tid; i < num_parameters; i += stride * COARSE_FACTOR) {
        #pragma unroll
        for (int c = 0; c < COARSE_FACTOR; c++) {
            long idx = i + c * stride;
            if (idx >= num_parameters) break;

            float grad = grads_memory[idx];
            float m = m_memory[idx];
            float v = v_memory[idx];
            float param = params_memory[idx];

            // Update first moment (momentum): lerp(grad, m, beta1) = beta1*m + (1-beta1)*grad
            m = lerp(grad, m, beta1);
            m_memory[idx] = m;

            // Update second moment (RMSprop): beta2*v + (1-beta2)*grad^2
            v = lerp(grad * grad, v, beta2);
            v_memory[idx] = v;

            // Compute bias-corrected moments (the 1 - beta^t corrections are precomputed by the caller)
            float m_hat = m / beta1_correction;
            float v_hat = v / beta2_correction;

            // Update parameters (decoupled weight decay)
            param -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);
            params_memory[idx] = param;
        }
    }
}
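
For contrast, here is a sketch of the other common coarsening layout, in which each thread owns COARSE_FACTOR consecutive elements (hypothetical, not part of this PR; adamw_kernel3_chunked is illustrative only). For a memory-bound update like AdamW the strided layout above is usually the better choice: on every inner step consecutive threads touch consecutive addresses, so global loads and stores stay coalesced, whereas per-thread chunks make each warp span a COARSE_FACTOR-times wider address range per access.

// hypothetical contiguous-chunk variant, for illustration only:
// thread k owns elements [k*COARSE_FACTOR, (k+1)*COARSE_FACTOR)
__global__ void adamw_kernel3_chunked(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
                                      float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    long base = ((long)blockIdx.x * blockDim.x + threadIdx.x) * COARSE_FACTOR;
    #pragma unroll
    for (int c = 0; c < COARSE_FACTOR; c++) {
        long idx = base + c;
        if (idx >= num_parameters) return;   // idx only grows with c, so we can stop here
        float grad = grads_memory[idx];
        float m = lerp(grad, m_memory[idx], beta1);         // beta1*m + (1-beta1)*grad
        float v = lerp(grad * grad, v_memory[idx], beta2);  // beta2*v + (1-beta2)*grad^2
        m_memory[idx] = m;
        v_memory[idx] = v;
        params_memory[idx] -= learning_rate * (m / beta1_correction / (sqrtf(v / beta2_correction) + eps) + weight_decay * params_memory[idx]);
    }
}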

// ----------------------------------------------------------------------------
// kernel launcher
@@ -106,7 +142,8 @@ void adamw_dispatch1(float* params_memory, const float* grads_memory, float* m_m
                     float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    unsigned int block_size = 512;
    unsigned int num_blocks = ceil_div(num_parameters, (long) block_size);
-   adamw_kernel1<<<num_blocks, block_size>>>(params_memory, grads_memory, m_memory, v_memory, num_parameters,
+   // note: adamw_kernel1 declares no dynamic shared memory, so this allocation is
+   // presumably (assumption, not stated in the diff) for the occupancy experiment
+   // captured in the attached .ncu-rep profile
+   unsigned int shared_size = 4*block_size*sizeof(float);
+   adamw_kernel1<<<num_blocks, block_size, shared_size>>>(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
    cudaCheck(cudaGetLastError());
}
@@ -121,6 +158,16 @@ void adamw_dispatch2(float* params_memory, const float* grads_memory, float* m_m
    cudaCheck(cudaGetLastError());
}

// version 3: dispatch to the thread coarsening kernel
void adamw_dispatch3(float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, long num_parameters,
                     float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay) {
    unsigned int block_size = 512;
    // each thread covers COARSE_FACTOR elements, so proportionally fewer blocks are
    // needed; launching ceil_div(num_parameters, block_size) blocks would give every
    // thread a single element and make the coarsening loop a no-op
    unsigned int num_blocks = ceil_div(num_parameters, (long) (block_size * COARSE_FACTOR));
    adamw_kernel3<<<num_blocks, block_size>>>(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
    cudaCheck(cudaGetLastError());
}
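
As a quick sanity check of the launch arithmetic (numbers chosen for illustration only): with num_parameters = 1,048,576, block_size = 512 and COARSE_FACTOR = 4, ceil_div yields 512 blocks, so stride = 512 * 512 = 262,144 and thread tid touches tid, tid + 262,144, tid + 524,288 and tid + 786,432; the four stride-spaced passes tile the full range exactly once.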

void adamw(int kernel_num,
           float* params_memory, const float* grads_memory, float* m_memory, float* v_memory, int t, long num_parameters,
           float learning_rate=1e-3, float beta1=0.9, float beta2=0.999, float eps=1e-8, float weight_decay=0.0) {
@@ -136,14 +183,16 @@ void adamw(int kernel_num,
            adamw_dispatch2(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
            break;
        case 3:
            adamw_dispatch3(params_memory, grads_memory, m_memory, v_memory, num_parameters,
                            learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
    setup_main();

@@ -205,7 +254,7 @@ int main(int argc, char **argv) {
printf("All results match.\n\n");

// now benchmark the kernel
int repeat_times = 1000;
int repeat_times = 100;
float elapsed_time = benchmark_kernel(repeat_times, adamw, kernel_num,
d_params_memory, d_grads_memory, d_m_memory, d_v_memory, t, num_parameters,
learning_rate, beta1, beta2, eps, weight_decay);
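
Assuming the usual workflow for kernels in dev/cuda (not shown in this diff), the coarsened kernel would be validated and benchmarked with something like nvcc -O3 --use_fast_math adamw.cu -o adamw followed by ./adamw 3.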
Binary file added dev/cuda/profiles/adamw_coalsecing_shared.ncu-rep
Binary file not shown.