CPU version
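As a quick recap, RMSNorm scales every element of the input vector by the reciprocal root-mean-square of the whole vector and then applies a learned per-element weight; all versions below compute the same thing:

$$\mathrm{RMSNorm}(x)_i = w_i \cdot \frac{x_i}{\sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \varepsilon}}$$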

Initial version

#include <cmath>

void rmsnorm_kernel_cpu(const tensor::Tensor& input, const tensor::Tensor& weight,
                        const tensor::Tensor& output, void* stream)
{
  const float* in_ptr = input.ptr<float>();
  float* out_ptr = const_cast<float*>(output.ptr<float>());

  const int size = static_cast<int>(input.size());
  // Sum of squares over the whole vector.
  float sum = 0.f;
  for (int i = 0; i < size; i++) {
    const float input_value = input.index<float>(i);
    sum += input_value * input_value;
  }
  const float eps = 1e-5f;
  const float mean = sum / static_cast<float>(size) + eps;
  const float rsqrt = 1.f / std::sqrt(mean);
  // Scale each element by the reciprocal RMS and the per-element weight.
  for (int i = 0; i < size; i++) {
    *(out_ptr + i) = weight.index<float>(i) * (rsqrt * (*(in_ptr + i)));
  }
}

Final version

#include <armadillo>

void rmsnorm_kernel_cpu(const tensor::Tensor& input, const tensor::Tensor& weight,
                        const tensor::Tensor& output)
{
  CHECK(!input.empty());
  CHECK(!weight.empty());
  CHECK(!output.empty());

  CHECK(input.device_type() == base::DeviceType::KDeviceCPU &&
        weight.device_type() == base::DeviceType::KDeviceCPU &&
        output.device_type() == base::DeviceType::KDeviceCPU);

  const float* in_ptr = input.ptr<float>();
  const float* wei_ptr = weight.ptr<float>();
  const float* out_ptr = output.ptr<float>();
  const int32_t dim = static_cast<int32_t>(input.size());

  // The third constructor argument (copy_aux_mem) controls whether the data is copied:
  // with false, the arma::fvec uses the external buffer directly instead of copying it,
  // so writes to the vector go straight to the tensor's memory.
  // The fourth argument (strict) controls whether the vector stays bound to that external
  // buffer: with true, its size cannot change and it never reallocates; Armadillo never
  // frees the external memory in either case.
  arma::fvec in_tensor(const_cast<float*>(in_ptr), dim, false, true);
  arma::fvec out_tensor(const_cast<float*>(out_ptr), dim, false, true);
  arma::fvec wei_tensor(const_cast<float*>(wei_ptr), dim, false, true);

  const float eps = 1e-5f;

  // mean(x^2) + eps, then scale by the reciprocal RMS and the weights.
  const float mean = arma::as_scalar(arma::mean(arma::pow(in_tensor, 2))) + eps;
  const float rsqrt = 1.f / std::sqrt(mean);
  out_tensor = wei_tensor % (in_tensor * rsqrt);
}
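A minimal standalone snippet (independent of the tensor classes above) that demonstrates this aliasing behaviour of the advanced arma::fvec constructor:

#include <armadillo>
#include <iostream>

int main() {
  float buf[4] = {1.f, 2.f, 3.f, 4.f};
  // copy_aux_mem = false: the vector wraps buf directly, no copy is made.
  // strict = true: the vector stays bound to buf and cannot be resized.
  arma::fvec v(buf, 4, false, true);
  v *= 2.f;                                       // writes go straight into buf
  std::cout << buf[0] << ' ' << buf[3] << '\n';   // prints: 2 8
  return 0;
}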

CUDA version

Simple version

This code is problematic: the launcher uses 128 threads while warpSize is 32, so several threads end up with the same lane_id. Because CUDA schedules in units of warps, threads 0, 32, 64 and 96 all get lane_id 0 and all iterate over the same elements, so the same data is read, summed and written multiple times.

#include <cub/cub.cuh>

static __global__ void row_rmsnorm_f32(float* in, float* wei, float* out, int size, float eps)
{
  const int tid = threadIdx.x;
  // warpSize is normally 32
  const int lane_id = tid % warpSize;

  // Each lane accumulates a warp-strided partial sum of squares.
  float sum = 0.f;
  for (int i = lane_id; i < size; i += warpSize) {
    sum += in[i] * in[i];
  }

  using WarpReduce = cub::WarpReduce<float, 32>;
  __shared__ typename WarpReduce::TempStorage temp;
  __shared__ float shared_val;
  sum = WarpReduce(temp).Reduce(sum, cub::Sum());
  // The reduced value is only valid in lane 0, so broadcast it via shared memory.
  if (lane_id == 0) {
    shared_val = sum;
  }
  __syncthreads();
  sum = shared_val;

  const float scale = rsqrtf(sum / static_cast<float>(size) + eps);
  for (int i = lane_id; i < size; i += warpSize) {
    out[i] = in[i] * wei[i] * scale;
  }
}

Launching the RMSNorm kernel

void rmsnorm_kernel_cu(const tensor::Tensor& input, const tensor::Tensor& weight,
                       const tensor::Tensor& output, void* stream)
{
  CHECK(!input.is_empty());
  CHECK(!output.is_empty());
  CHECK(!weight.is_empty());

  CHECK(input.device_type() == base::DeviceType::KDeviceGPU &&
        output.device_type() == base::DeviceType::KDeviceGPU &&
        weight.device_type() == base::DeviceType::KDeviceGPU);

  const float eps = 1e-5f;
  const int32_t size = static_cast<int32_t>(input.size());
  float* in_ptr = const_cast<float*>(input.ptr<float>());
  float* out_ptr = const_cast<float*>(output.ptr<float>());
  float* wei_ptr = const_cast<float*>(weight.ptr<float>());

  constexpr int threads_num = 128;
  if (stream) {
    cudaStream_t stream_ = static_cast<cudaStream_t>(stream);
    row_rmsnorm_f32<<<1, threads_num, 0, stream_>>>(in_ptr, wei_ptr, out_ptr, size, eps);
  } else {
    row_rmsnorm_f32<<<1, threads_num>>>(in_ptr, wei_ptr, out_ptr, size, eps);
  }
}
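Kernel launches are asynchronous, so a bad launch only surfaces later; during development it can help to check for launch errors right after the call (a small addition of my own, not part of the original launcher):

// Check that the kernel launch itself was accepted by the runtime.
cudaError_t err = cudaGetLastError();
CHECK(err == cudaSuccess) << cudaGetErrorString(err);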

Optimized version

This optimized version no longer reduces per warp but per block. Since CUDA still schedules in units of warps, the reduction takes two stages. Suppose there are 1024 values to process and a block contains 128 threads: each thread then handles 8 values.

After each thread has accumulated its 8 values, a first reduction is done inside its own warp, i.e. one reduction per 32 threads; with 128 threads that is four warp reductions, whose four partial results are stored in shared memory. Finally those four values are reduced once more to obtain the block-wide result. A hand-rolled sketch of these two stages is shown below; the kernel itself simply delegates this to cub::BlockReduce.
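For illustration only, here is roughly what that two-stage block reduction looks like when written by hand with warp shuffles (a sketch of what cub::BlockReduce::Sum does conceptually, assuming blockDim.x is a multiple of warpSize and at most 1024; the actual kernel below keeps using CUB):

// Illustration only: a manual block-wide sum, equivalent in spirit to
// cub::BlockReduce<float, BLOCK_DIM>::Sum.
__device__ float block_reduce_sum(float val) {
  // Stage 1: reduce within each warp using register shuffles.
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  // Lane 0 of every warp now holds that warp's partial sum; stash it.
  __shared__ float warp_sums[32];            // at most 1024 / 32 warps
  const int lane_id = threadIdx.x % warpSize;
  const int warp_id = threadIdx.x / warpSize;
  if (lane_id == 0) {
    warp_sums[warp_id] = val;
  }
  __syncthreads();
  // Stage 2: the first warp reduces the per-warp partial sums.
  const int num_warps = blockDim.x / warpSize;
  val = (threadIdx.x < num_warps) ? warp_sums[threadIdx.x] : 0.f;
  if (warp_id == 0) {
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
      val += __shfl_down_sync(0xffffffff, val, offset);
    }
  }
  return val;  // valid in thread 0 of the block
}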

template <int32_t BLOCK_DIM>
static __global__ void row_rmsnorm_f32(float* in, float* wei, float* out, int size, float eps)
{
  const int tid = threadIdx.x;

  // Each thread accumulates a block-strided partial sum of squares.
  float sum = 0.f;
  for (int i = tid; i < size; i += blockDim.x) {
    sum += in[i] * in[i];
  }

  using BlockReduce = cub::BlockReduce<float, BLOCK_DIM>;
  __shared__ typename BlockReduce::TempStorage temp;
  __shared__ float shared_val;
  sum = BlockReduce(temp).Sum(sum);
  // The block-wide sum is only valid in thread 0, so broadcast it via shared memory.
  if (threadIdx.x == 0) {
    shared_val = sum;
  }
  __syncthreads();
  sum = shared_val;

  const float scale = rsqrtf(sum / static_cast<float>(size) + eps);
  for (int i = tid; i < size; i += blockDim.x) {
    out[i] = scale * in[i] * wei[i];
  }
}
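Since BLOCK_DIM is now a compile-time template parameter, the launcher shown earlier has to instantiate the kernel explicitly; a minimal sketch of the changed launch line (assuming threads_num stays at 128):

constexpr int threads_num = 128;
row_rmsnorm_f32<threads_num><<<1, threads_num, 0, stream_>>>(in_ptr, wei_ptr, out_ptr, size, eps);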

Version with vectorized loads/stores (even better)

Use float4 for reads and writes.

template <int32_t BLOCK_DIM>
static __global__ void row_rmsnorm_f32(float* in, float* wei, float* out, int size, float eps) {
  const int tid = threadIdx.x;

  // Process the bulk of the data as float4 packs; the tail is handled element-wise.
  constexpr int pack_size = 4;
  const int pack_num = size / pack_size;
  const int pack_off = pack_size * pack_num;

  float sum = 0.0f;
  float4* in_pack = reinterpret_cast<float4*>(in);
  for (int i = tid; i < pack_num; i += blockDim.x) {
    float4 in_float4 = *(in_pack + i);
    sum += in_float4.x * in_float4.x;
    sum += in_float4.y * in_float4.y;
    sum += in_float4.z * in_float4.z;
    sum += in_float4.w * in_float4.w;
  }

  // Remainder elements that do not fill a whole float4.
  for (int i = pack_off + tid; i < size; i += blockDim.x) {
    sum += in[i] * in[i];
  }

  using BlockReduce = cub::BlockReduce<float, BLOCK_DIM>;
  __shared__ typename BlockReduce::TempStorage temp;
  __shared__ float shared_val;
  sum = BlockReduce(temp).Sum(sum);
  if (threadIdx.x == 0) {
    shared_val = sum;
  }
  __syncthreads();
  sum = shared_val;
  const float scale = rsqrtf(sum / static_cast<float>(size) + eps);

  // Write the normalized result back, again four floats at a time.
  float4* wei_pack = reinterpret_cast<float4*>(wei);
  float4* out_pack = reinterpret_cast<float4*>(out);
  for (int i = tid; i < pack_num; i += blockDim.x) {
    float4 in_float4 = *(in_pack + i);
    float4 wei_float4 = *(wei_pack + i);
    *(out_pack + i) =
        make_float4(scale * in_float4.x * wei_float4.x, scale * in_float4.y * wei_float4.y,
                    scale * in_float4.z * wei_float4.z, scale * in_float4.w * wei_float4.w);
  }

  for (int i = pack_off + tid; i < size; i += blockDim.x) {
    out[i] = wei[i] * in[i] * scale;
  }
}
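One caveat with the float4 path: reinterpret_cast<float4*> is only safe when the pointers are 16-byte aligned. cudaMalloc guarantees this for the start of an allocation, but not for arbitrary offsets into a larger buffer, so a defensive check in the launcher can be worthwhile (my own addition, not in the original code):

// Vectorized float4 access requires 16-byte alignment of all three pointers.
CHECK(reinterpret_cast<uintptr_t>(in_ptr) % sizeof(float4) == 0);
CHECK(reinterpret_cast<uintptr_t>(wei_ptr) % sizeof(float4) == 0);
CHECK(reinterpret_cast<uintptr_t>(out_ptr) % sizeof(float4) == 0);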