调用流程
-
- 获得op_ptr,ck有个工厂模式:
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
-
- 设置参数,这些参数包括输入输出,以及其他必要的配置
auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
-
- 获得invoker_ptr:auto invoker_ptr = op_ptr->MakeInvokerPointer();
-
- run:float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
-
- 结果后处理
Invoker
- 有一个基类BaseInvoker,定义了默认构造、拷贝构造、拷贝赋值,一个虚函数Run(用于算子运行),以及一个虚析构
- 地址:include/ck/tensor_operation/gpu/device/device_base.hpp
- 然后每个算子里面会实现一个Invoker,来实现run的操作
/// Polymorphic base class for the per-operator invokers.
///
/// Default construction, copy construction and copy assignment are all
/// compiler-generated. Derived invokers override Run().
struct BaseInvoker
{
    BaseInvoker()                              = default;
    BaseInvoker(const BaseInvoker&)            = default;
    BaseInvoker& operator=(const BaseInvoker&) = default;

    /// Base implementation: ignores both parameters and returns 0.
    /// The stream configuration defaults to a default-constructed StreamConfig.
    virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{}) { return 0.0f; }

    /// Virtual so that derived invokers are destroyed correctly through a base pointer.
    virtual ~BaseInvoker() = default;
};
// Operator-specific invoker (excerpt: the kernel launch itself is elided in these notes).
struct Invoker : public BaseInvoker
{
// Typed entry point taking the operator's own Argument type.
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
// run kernel ....
// cost time ....
// NOTE(review): the elided body must return a float; as excerpted here nothing is returned.
};
// Type-erased entry point: overrides BaseInvoker::Run and forwards to the typed overload.
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
// NOTE(review): dynamic_cast yields nullptr when p_arg is null or does not point to an
// Argument; the result is dereferenced here without a check.
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
};
};
Argument
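- 同样有个基类BaseArgument,有一个void*类型的成员p_workspace_(默认为nullptr);从下文BaseOperator的GetWorkSpaceSize/SetWorkSpacePointer来看,它应该是指向算子所需额外工作空间(workspace)的指针,由SetWorkSpacePointer负责设置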
- 地址:include/ck/tensor_operation/gpu/device/device_base.hpp
- 而每个Operator中都会定义一个Argument子类,里面存一些输入输出,配置等参数
/// Common polymorphic base of every operator-specific Argument.
///
/// Copyable and default-constructible; carries a single untyped workspace pointer.
struct BaseArgument
{
    BaseArgument()                               = default;
    BaseArgument(const BaseArgument&)            = default;
    BaseArgument& operator=(const BaseArgument&) = default;
    virtual ~BaseArgument()                      = default;

    /// Untyped workspace pointer; stays null until something assigns it.
    void* p_workspace_{nullptr};
};
struct Argument : public ck::tensor_operation::device::BaseArgument
{
Argument(const Tensor<ADataType>& a_gs_ms_ks,
const Tensor<BDataType>& b_gs_ns_ks,
Tensor<EDataType>& e_gs_ms_ns,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: a_gs_ms_ks_{a_gs_ms_ks},
b_gs_ns_ks_{b_gs_ns_ks},
e_gs_ms_ns_{e_gs_ms_ns},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op}
{
}
const Tensor<ADataType>& a_gs_ms_ks_;
const Tensor<BDataType>& b_gs_ns_ks_;
Tensor<EDataType>& e_gs_ms_ns_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
};
Operator
- 基类叫BaseOperator,定义了如下函数,都是一些比较通用的基础接口:
- IsSupportedArgument
- GetTypeString
- GetTypeIdName
- GetTypeIdHashCode
- GetWorkSpaceSize
- SetWorkSpacePointer
- 通常子类中需要有定义:
- struct Argument/MakeArgumentPointer
- struct Invoker/MakeInvokerPointer
/// Polymorphic base of all device operators; every member function is virtual
/// and has a simple default that derived operators may override.
struct BaseOperator
{
    BaseOperator()                               = default;
    BaseOperator(const BaseOperator&)            = default;
    BaseOperator& operator=(const BaseOperator&) = default;

    /// Default: no argument is reported as supported.
    virtual bool IsSupportedArgument(const BaseArgument*) { return false; }

    /// Default: empty description string.
    virtual std::string GetTypeString() const { return std::string{}; }

    /// Implementation-defined name of the dynamic type of this object.
    virtual std::string GetTypeIdName() const
    {
        const std::type_info& dynamic_type = typeid(*this);
        return dynamic_type.name();
    }

    /// Hash code of the dynamic type, formatted as hexadecimal text.
    virtual std::string GetTypeIdHashCode() const
    {
        std::ostringstream hex_stream;
        hex_stream << std::hex;
        hex_stream << typeid(*this).hash_code();
        return hex_stream.str();
    }

    /// Default: the operator needs no workspace (size 0).
    virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }

    /// Stores `p_workspace` in the argument's p_workspace_ member.
    /// `p_arg` must be non-null; this is only checked by assert.
    virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
    {
        assert(p_arg);
        p_arg->p_workspace_ = p_workspace;
    }

    virtual ~BaseOperator() = default;
};
DeviceOperationInstanceFactory
- library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
// Primary template: only declared here, with no definition. `Tag` defaults to `void`.
// Usage elsewhere in these notes calls the static GetInstances() on
// DeviceOperationInstanceFactory<DeviceOp>, so a definition providing it is presumably
// supplied per operator type — confirm in the library headers.
template <typename DeviceOp, typename Tag = void>
struct DeviceOperationInstanceFactory;
- library/include/ck/library/tensor_operation_instance/add_device_operation_instance.hpp
- 这里面有个add_device_operation_instances方法,定义了将op实现加入到vector(instance)中
- 在这之上,有一些函数是用于添加这些instance的,比如device_gemm_dl_f16_f16_f16_km_kn_mn_instances
- 位于library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
- 原理就是把tuple中的元素在add_device_operation_instances中全部加到vector中去
// Tuple type whose elements are the DeviceGemmDl configurations to register
// (template arguments and the remaining entries are elided in these notes).
using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple<
// MPerBlock=8, NPerBlock=8
DeviceGemmDl<.....>,
DeviceGemmDl<.....>,
DeviceGemmDl<.....>,
.....
>;
// Registers the DeviceGemmDl instances into the caller-provided vector.
// `instances` is a vector of owning pointers to the DeviceGemm interface
// (Col/Row/Row layouts, F16 data, PassThrough element-wise ops).
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
// Passes a default-constructed value of the tuple type; per the surrounding notes the helper
// appends one instance per tuple element — confirm against add_device_operation_instance.hpp.
add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{});
}
- 然后这个函数会在DeviceOperationInstanceFactory的GetInstances中被调用到,于是就得到了一个vector,里面装满了各个op实例的指针(也就是op_ptr)
- 对于上面这个例子,在这个文件中被调用到:library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp
案例
- client_example/01_gemm/gemm.cpp
- 在这个example中有这样一句代码:
// DeviceOp names the DeviceGemm interface for the chosen layouts and data types,
// with PassThrough used for all three element-wise operations.
using DeviceOp =
ck::tensor_operation::device::DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
// get device op instances
// The factory is keyed on the DeviceOp interface type; the result is held as `const auto`.
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
- DeviceGemm这个operator长这样,当然这也是个抽象基类(含纯虚函数),真正的实现是在Impl文件夹中定义的:
/// Abstract GEMM operator interface, parameterised on the layouts, data types and
/// element-wise operations of A, B and C. Both member functions are pure virtual,
/// so the actual behaviour lives in derived classes.
template <typename ALayout, typename BLayout, typename CLayout,
          typename ADataType, typename BDataType, typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemm : public BaseOperator
{
    /// Builds the type-erased argument object for one GEMM problem.
    /// `p_a` and `p_b` are read-only pointers, `p_c` is writable; M/N/K and the three
    /// strides describe the problem, followed by the three element-wise operations.
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a, const void* p_b, void* p_c,
                        ck::index_t M, ck::index_t N, ck::index_t K,
                        ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) = 0;

    /// Builds the type-erased invoker matching this operator.
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
- 然后会在下一级子类中真正实现:
// Derived operator inheriting from the DeviceGemm interface with the same nine template
// arguments (excerpt: its own template header and the class body are elided in these notes).
struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
........
- 然后通过工厂类的GetInstances拿到op_ptrs,接下来就是遍历,在for的过程中需要经过:
- auto argument_ptr = op_ptr->MakeArgumentPointer
- auto invoker_ptr = op_ptr->MakeInvokerPointer
- invoker_ptr->Run
- …
- 这就是这个example干的事儿,实际上在调用的过程中factory应该可以不用,而直接使用实例化的op_ptr
特有名词
以PassThrough为例
- 这是一个传值操作,代码实现位于:include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
- 下面展示了一部分,可以看到函数的作用就是传值(y = x)
// Element-wise functor that writes x into y unchanged
// (excerpt: further specializations are elided in these notes).
struct PassThrough
{
// Generic call operator: declared only; behaviour comes from the explicit specializations below.
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
// double -> double: plain assignment.
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{
y = x;
}
// float -> float: plain assignment.
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = x;
}
....
};