/**
 * @brief Runs the add layer on its bound tensors.
 *
 * Validates the layer via check(), then takes input slots 0 and 1 and output
 * slot 0 and hands them to the add kernel that kernel::get_add_kernel selects
 * for this layer's device_type_.
 * NOTE(review): presumably the kernel writes the element-wise sum of the two
 * inputs into the output tensor; confirm against the kernel implementations.
 *
 * @return The failing status from check() if validation fails; otherwise
 *         base::error::Success(). The kernel call itself reports no status here.
 */
base::Status VecAddLayer::base_forward() {
  auto check_status = check();
  if (!check_status) {
    // Validation failed: propagate check()'s status unchanged.
    return check_status;
  }

  auto lhs = get_input(0);
  auto rhs = get_input(1);
  auto out = get_output(0);

  // Pick the implementation for the layer's device, then invoke it.
  auto add_kernel = kernel::get_add_kernel(device_type_);
  add_kernel(lhs, rhs, out);

  return base::error::Success();
}
/**
 * @brief Selects the add-kernel implementation for a device type.
 *
 * @param device_type Device the kernel should run on.
 * @return add_kernel_cpu for base::DeviceType::kDeviceCPU, add_kernel_cu for
 *         base::DeviceType::kDeviceCUDA. Any other value is logged with
 *         LOG(FATAL); the trailing nullptr is only there so every path returns
 *         a value.
 */
AddKernel get_add_kernel(base::DeviceType device_type) {
  if (device_type == base::DeviceType::kDeviceCPU) {
    return add_kernel_cpu;  // Returns a concrete function pointer.
  } else if (device_type == base::DeviceType::kDeviceCUDA) {
    return add_kernel_cu;
  } else {
    LOG(FATAL) << "Unknown device type for get a add kernel.";
    return nullptr;
  }
}