1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
base::Status VecAddLayer::base_forward() {
auto status = this->check();
if (!status) {
return status;
}
auto input1 = this->get_input(0);
auto input2 = this->get_input(1);
auto output = this->get_output(0);
kernel::get_add_kernel(device_type_)(input1, input2, output);
return base::error::Success();
}

AddKernel get_add_kernel(base::DeviceType device_type) {
if (device_type == base::DeviceType::kDeviceCPU) {
return add_kernel_cpu; // 返回一个具体的函数指针
} else if (device_type == base::DeviceType::kDeviceCUDA) {
return add_kernel_cu;
} else {
LOG(FATAL) << "Unknown device type for get a add kernel.";
return nullptr;
}
}

上面的VecAddLayer的base_forward函数中调用了kernel::get_add_kernel(device_type_)(input1, input2, output),具体的函数如下面的get_add_kernel函数所示,可根据不同的device_type返回不同的执行函数,但参数必须跟下面定义的函数指针所指定的参数一致。

其中AddKernel类型的定义是一个函数指针,如下面的代码所示:

1
typedef void (*AddKernel)(const tensor::Tensor& input1, const tensor::Tensor& input2, const tensor::Tensor& output, void* stream);