#ifndef HL_PYTORCH_WRAPPER_H
#define HL_PYTORCH_WRAPPER_H

/** \file
 * Utility functions to wrap PyTorch tensors into Halide buffers, making sure
 * the data is on the correct device (CPU/CUDA). This header is included by
 * the ops generated by Halide's PyTorch codegen.
 */

#include <vector>

#include "HalideBuffer.h"

#include "torch/extension.h"

#ifdef HL_PT_CUDA
#include "HalideRuntimeCuda.h"
#include "cuda_runtime.h"
#endif

using Halide::Runtime::Buffer;

#define HLPT_CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define HLPT_CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define HLPT_CHECK_DEVICE(x, dev) AT_ASSERTM(x.is_cuda() && x.get_device() == dev, #x " must be a CUDA tensor on device " #dev)
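
// Illustrative note (not in the original header): generated op wrappers can
// guard their tensor arguments with these macros before wrapping them, e.g.
// HLPT_CHECK_CONTIGUOUS(input) or HLPT_CHECK_DEVICE(input, device_id), where
// `input` and `device_id` are hypothetical names.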

inline std::vector<int> get_dims(const at::Tensor tensor) {
    int ndims = tensor.ndimension();
    std::vector<int> dims(ndims, 0);
    // PyTorch lists dimensions outermost-first; Halide expects innermost-first,
    // so reverse the order.
    for (int dim = 0; dim < ndims; ++dim) {
        dims[dim] = tensor.size(ndims - 1 - dim);
    }
    return dims;
}
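
// Illustrative example (not in the original header): a tensor of shape
// (N, C, H, W) yields dims {W, H, C, N}, i.e. the sizes in Halide's
// innermost-first order.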

// Fallback overload: reaching this means the tensor's scalar type has no
// matching specialization generated below.
template<class scalar_t>
inline void check_type(at::Tensor &tensor) {
    AT_ERROR("Scalar type ", tensor.scalar_type(),
             " not handled by Halide's PyTorch wrapper");
}

// Use the presence of AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS, which was
// added around the PyTorch API-13 breaking changes, as a proxy for the API version.
#ifdef AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
#define HL_PYTORCH_API_VERSION 13
#else
#define HL_PYTORCH_API_VERSION 12
#endif

#if HL_PYTORCH_API_VERSION >= 13

// Specialize check_type<> for each scalar type PyTorch knows about.
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype)                                                     \
    template<>                                                                                   \
    inline void check_type<ctype>(at::Tensor &tensor) {                                          \
        AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar type does not match"); \
    }

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(HL_PT_DEFINE_TYPECHECK);

#undef HL_PT_DEFINE_TYPECHECK

#else  // HL_PYTORCH_API_VERSION < 13

// Older PyTorch passes a third (unused) argument to the per-type macro.
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype, _3)                                                 \
    template<>                                                                                   \
    inline void check_type<ctype>(at::Tensor &tensor) {                                          \
        AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar type does not match"); \
    }

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(HL_PT_DEFINE_TYPECHECK);

#undef HL_PT_DEFINE_TYPECHECK

#endif  // HL_PYTORCH_API_VERSION
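
// For illustration (not in the original header), expanding
// HL_PT_DEFINE_TYPECHECK for `float` produces a specialization equivalent to:
//
//   template<>
//   inline void check_type<float>(at::Tensor &tensor) {
//       AT_ASSERTM(tensor.scalar_type() == at::ScalarType::Float, "scalar type does not match");
//   }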

// Wrap an at::Tensor in a Halide buffer without copying; CUDA tensors are wrapped as device allocations.
template<class scalar_t>
inline Buffer<scalar_t> wrap(at::Tensor &tensor) {
    check_type<scalar_t>(tensor);
    std::vector<int> dims = get_dims(tensor);
#if HL_PYTORCH_API_VERSION >= 13
    scalar_t *pData = tensor.data_ptr<scalar_t>();
#else
    scalar_t *pData = tensor.data<scalar_t>();
#endif
    Buffer<scalar_t> buffer;

    if (tensor.is_cuda()) {
#ifdef HL_PT_CUDA
        buffer = Buffer<scalar_t>(dims);
        const halide_device_interface_t *cuda_interface = halide_cuda_device_interface();
        int err = buffer.device_wrap_native(cuda_interface, (uint64_t)pData);
        AT_ASSERTM(err == 0, "halide_device_wrap failed");
        buffer.set_device_dirty();
#else
        AT_ERROR("Trying to wrap a CUDA tensor, but HL_PT_CUDA was not defined: CUDA is not available");
#endif  // HL_PT_CUDA
    } else {
        buffer = Buffer<scalar_t>(pData, dims);
    }

    return buffer;
}

#endif  // HL_PYTORCH_WRAPPER_H