Path: blob/main/apex/csrc/multi_tensor_sgd_kernel.cu
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include "compat.h"

#include <assert.h>
#include <cuda_runtime.h>

#define BLOCK_SIZE 512
#define ILP 4

/**
 * Perform fused SGD on multiple buffers
 * N: number of tensors
 * tl[0] : gradients
 * tl[1] : weights
 * tl[2] : momentum buffers
 * tl[3] : fp16 weights (if appropriate)
 * wd : weight_decay (scalar)
 * momentum : momentum (scalar)
 * dampening : momentum dampening (scalar)
 * lr : learning rate (scalar)
 * nesterov : enable nesterov (bool)
 * first_run : necessary for proper momentum handling & init
 * wd_after_momentum : apply weight decay _after_ momentum instead of before
 **/
template<int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<N>& tl,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum,
    float scale)
  {
    // Early exit if we don't need to do anything
    if (*noop_gmem) return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
    grad_in += chunk_idx*chunk_size;

    T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
    weight_in += chunk_idx*chunk_size;

    T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
    mom_in += chunk_idx*chunk_size;

    at::Half *model_weights_out = nullptr;
    if(N == 4)
    {
      model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
      model_weights_out += chunk_idx*chunk_size;
    }

    n -= chunk_idx*chunk_size;

    // Non-divergent exit condition for the __syncthreads
    float incoming_grads[ILP];
    float incoming_weights[ILP];
    float incoming_moms[ILP];
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        incoming_grads[ii] = 0;
        incoming_weights[ii] = 0;
        incoming_moms[ii] = 0;
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          incoming_grads[ii] = static_cast<float>(grad_in[i])*scale;
          incoming_weights[ii] = static_cast<float>(weight_in[i]);
          incoming_moms[ii] = static_cast<float>(mom_in[i]);
        }
      }

      // note for clarification to future michael:
      // From a pure memory dependency perspective, there's likely no point unrolling
      // the write loop, since writes just fire off once their LDGs arrive.
      // Put another way, the STGs are dependent on the LDGs, but not on each other.
      // There is still compute ILP benefit from unrolling the loop though.
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          // apply weight decay before momentum if necessary
          if(wd != 0.f && !wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          if(momentum != 0.f)
          {
            if(!first_run)
              incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
            else // initialize momentums to current incoming grads
              incoming_moms[ii] = incoming_grads[ii];

            if(nesterov)
              incoming_grads[ii] += momentum * incoming_moms[ii];
            else
              incoming_grads[ii] = incoming_moms[ii];
          }

          // Apply WD after momentum if desired
          if(wd != 0.f && wd_after_momentum)
            incoming_grads[ii] += wd * incoming_weights[ii];

          // adjust the weight and write out
          weight_in[i] += (-lr * incoming_grads[ii]);

          // if necessary, write out an fp16 copy of the weights
          if(N == 4)
            model_weights_out[i] = static_cast<at::Half>(weight_in[i]);

          // also write out the new momentum
          if(momentum != 0.f)
            mom_in[i] = incoming_moms[ii];
        }
      }
    }
  }
};

void multi_tensor_sgd_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  float wd,
  float momentum,
  float dampening,
  float lr,
  bool nesterov,
  bool first_run,
  bool wd_after_momentum,
  float scale)
{
  auto num_tensors = tensor_lists.size();
  auto grad_type = tensor_lists[0][0].scalar_type();
  auto weight_type = tensor_lists[1][0].scalar_type();

  if(num_tensors == 4)
    for(int i = 0; i < tensor_lists[3].size(); i++)
      TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
                  "Additional output tensors should always be fp16.");

  TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");

  // We have 4 possibilities to handle here, in terms of
  // grad_type, param_type, momentum_type, requires_fp16_copy
  // 1. fp16, fp16, fp16, No
  // 2. fp32, fp32, fp32, No
  // 3. fp16, fp32, fp32, Yes
  // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
  // It's easier to hardcode these possibilities than to use
  // switches etc. to handle the cross-product of cases where
  // we don't want the majority of them.

  // Case 1. fp16, fp16, fp16, No
  if(grad_type == at::ScalarType::Half &&
     weight_type == at::ScalarType::Half &&
     num_tensors == 3)
  {
    multi_tensor_apply<3>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<3, at::Half, at::Half>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
  }
  // Case 2. fp16, fp32, fp32, No
  // else if (grad_type == at::ScalarType::Half &&
  //          weight_type == at::ScalarType::Float &&
  //          num_tensors == 3) {
  //   multi_tensor_apply<3>(
  //     BLOCK_SIZE,
  //     chunk_size,
  //     noop_flag,
  //     tensor_lists,
  //     SGDFunctor<3, at::Half, float>(),
  //     wd,
  //     momentum,
  //     dampening,
  //     lr,
  //     nesterov,
  //     first_run,
  //     wd_after_momentum);
  // }
  // Case 2. fp32, fp32, fp32, No
  else if(grad_type == at::ScalarType::Float &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 3)
  {
    multi_tensor_apply<3>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<3, float, float>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
  }
  // Case 3. fp16, fp32, fp32, Yes
  else if(grad_type == at::ScalarType::Half &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 4)
  {
    multi_tensor_apply<4>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<4, at::Half, float>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
  }
  // Case 4. fp32, fp32, fp32, Yes
  else if(grad_type == at::ScalarType::Float &&
          weight_type == at::ScalarType::Float &&
          num_tensors == 4)
  {
    multi_tensor_apply<4>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      SGDFunctor<4, float, float>(),
      wd,
      momentum,
      dampening,
      lr,
      nesterov,
      first_run,
      wd_after_momentum,
      scale);
  }
  else
  {
    AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
             "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
  }

  AT_CUDA_CHECK(cudaGetLastError());
}