Halide 16.0.0
Halide compiler and libraries
Expr.h
Go to the documentation of this file.
1#ifndef HALIDE_EXPR_H
2#define HALIDE_EXPR_H
3
4/** \file
5 * Base classes for Halide expressions (\ref Halide::Expr) and statements (\ref Halide::Internal::Stmt)
6 */
7
8#include <string>
9#include <vector>
10
11#include "IntrusivePtr.h"
12#include "Type.h"
13
14namespace Halide {
15
16struct bfloat16_t;
17struct float16_t;
18
19namespace Internal {
20
21class IRMutator;
22class IRVisitor;
23
24/** All our IR node types get unique IDs for the purposes of RTTI */
25enum class IRNodeType {
26 // Exprs, in order of strength. Code in IRMatch.h and the
27 // simplifier relies on this order for canonicalization of
28 // expressions, so you may need to update those modules if you
29 // change this list.
30 IntImm,
31 UIntImm,
35 Cast,
38 Add,
39 Sub,
40 Mod,
41 Mul,
42 Div,
43 Min,
44 Max,
45 EQ,
46 NE,
47 LT,
48 LE,
49 GT,
50 GE,
51 And,
52 Or,
53 Not,
54 Select,
55 Load,
56 Ramp,
57 Call,
58 Let,
59 Shuffle,
61 // Stmts
62 LetStmt,
65 For,
66 Acquire,
67 Store,
68 Provide,
70 Free,
71 Realize,
72 Block,
73 Fork,
77 Atomic
78};
79
81
82/** The abstract base class for a node in the Halide IR. */
83struct IRNode {
84
85 /** We use the visitor pattern to traverse IR nodes throughout the
86 * compiler, so we have a virtual accept method which accepts
87 * visitors.
88 */
89 virtual void accept(IRVisitor *v) const = 0;
91 : node_type(t) {
92 }
93 virtual ~IRNode() = default;
94
95 /** These classes are all managed with intrusive reference
96 * counting, so we also track a reference count. It's mutable
97 * so that we can do reference counting even through const
98 * references to IR nodes.
99 */
101
102 /** Each IR node subclass has a unique identifier. We can compare
103 * these values to do runtime type identification. We don't
104 * compile with rtti because that injects run-time type
105 * identification stuff everywhere (and often breaks when linking
106 * external libraries compiled without it), and we only want it
107 * for IR nodes. One might want to put this value in the vtable,
108 * but that adds another level of indirection, and for Exprs we
109 * have 32 free bits in between the ref count and the Type
110 * anyway, so this doesn't increase the memory footprint of an IR node.
111 */
113};
114
/** Specialization of ref_count for IR nodes: returns a reference to the
 * RefCount stored inside the node itself. The node is taken by const
 * pointer; the IRNode documentation notes that the count is mutable so
 * that it can be updated through const references. Presumably this is
 * the hook IntrusivePtr uses to find the count - confirm in
 * IntrusivePtr.h. */
115template<>
116inline RefCount &ref_count<IRNode>(const IRNode *t) noexcept {
117 return t->ref_count;
118}
119
/** Specialization of destroy for IR nodes: deletes the node through its
 * IRNode pointer. IRNode declares a virtual destructor, so the
 * destructor of the concrete node type runs. Presumably called by
 * IntrusivePtr once the last reference is dropped - confirm in
 * IntrusivePtr.h. */
120template<>
121inline void destroy<IRNode>(const IRNode *t) {
122 delete t;
123}
124
125/** IR nodes are split into expressions and statements. These are
126 similar to expressions and statements in C - expressions
127 represent some value and have some type (e.g. x + 3), and
128 statements are side-effecting pieces of code that do not
129 represent a value (e.g. assert(x > 3)) */
130
131/** A base class for statement nodes. They have no properties or
132 methods beyond base IR nodes for now. */
133struct BaseStmtNode : public IRNode {
135 : IRNode(t) {
136 }
137 virtual Stmt mutate_stmt(IRMutator *v) const = 0;
138};
139
140/** A base class for expression nodes. They all contain their types
141 * (e.g. Int(32), Float(32)) */
142struct BaseExprNode : public IRNode {
144 : IRNode(t) {
145 }
146 virtual Expr mutate_expr(IRMutator *v) const = 0;
148};
149
150/** We use the "curiously recurring template pattern" to avoid
151 duplicated code in the IR Nodes. These classes live between the
152 abstract base classes and the actual IR Nodes in the
153 inheritance hierarchy. It provides an implementation of the
154 accept function necessary for the visitor pattern to work, and
155 a concrete instantiation of a unique IRNodeType per class. */
156template<typename T>
157struct ExprNode : public BaseExprNode {
158 void accept(IRVisitor *v) const override;
159 Expr mutate_expr(IRMutator *v) const override;
161 : BaseExprNode(T::_node_type) {
162 }
163 ~ExprNode() override = default;
164};
165
166template<typename T>
167struct StmtNode : public BaseStmtNode {
168 void accept(IRVisitor *v) const override;
169 Stmt mutate_stmt(IRMutator *v) const override;
171 : BaseStmtNode(T::_node_type) {
172 }
173 ~StmtNode() override = default;
174};
175
176/** IR nodes are passed around as opaque handles to them. This is a
177 base class for those handles. It manages the reference count,
178 and dispatches visitors. */
179struct IRHandle : public IntrusivePtr<const IRNode> {
181 IRHandle() = default;
182
184 IRHandle(const IRNode *p)
185 : IntrusivePtr<const IRNode>(p) {
186 }
187
188 /** Dispatch to the correct visitor method for this node. E.g. if
189 * this node is actually an Add node, then this will call
190 * IRVisitor::visit(const Add *) */
191 void accept(IRVisitor *v) const {
192 ptr->accept(v);  // NOTE(review): ptr is dereferenced without a null check (as() does check it) - presumably the handle must be defined; confirm
193 }
194
195 /** Downcast this IR node to its actual type (e.g. Add, or
196 * Select). This returns nullptr if the handle holds no node (ptr is
197 * null) or the node is not of the requested type. Example usage:
198 *
199 * if (const Add *add = node->as<Add>()) {
200 * // This is an add node
201 * }
202 */
203 template<typename T>
204 const T *as() const {
205 if (ptr && ptr->node_type == T::_node_type) {
206 return (const T *)ptr;  // unchecked cast; guarded only by the node_type comparison above
207 }
208 return nullptr;
209 }
210
212 return ptr->node_type;
213 }
214};
215
216/** Integer constants */
217struct IntImm : public ExprNode<IntImm> {
219
220 static const IntImm *make(Type t, int64_t value);
221
223};
224
225/** Unsigned integer constants */
226struct UIntImm : public ExprNode<UIntImm> {
228
229 static const UIntImm *make(Type t, uint64_t value);
230
232};
233
234/** Floating point constants */
235struct FloatImm : public ExprNode<FloatImm> {
236 double value;
237
238 static const FloatImm *make(Type t, double value);
239
241};
242
243/** String constants */
244struct StringImm : public ExprNode<StringImm> {
245 std::string value;
246
247 static const StringImm *make(const std::string &val);
248
250};
251
252} // namespace Internal
253
254/** A fragment of Halide syntax. It's implemented as a reference-counted
255 * handle to a concrete expression node, but it's immutable, so you
256 * can treat it as a value type. */
257struct Expr : public Internal::IRHandle {
258 /** Make an undefined expression */
260 Expr() = default;
261
262 /** Make an expression from a concrete expression node pointer (e.g. Add) */
265 : IRHandle(n) {
266 }
267
268 /** Make an expression representing numeric constants of various types. */
269 // @{
270 explicit Expr(int8_t x)
271 : IRHandle(Internal::IntImm::make(Int(8), x)) {
272 }
273 explicit Expr(int16_t x)
274 : IRHandle(Internal::IntImm::make(Int(16), x)) {
275 }
277 : IRHandle(Internal::IntImm::make(Int(32), x)) {
278 }
279 explicit Expr(int64_t x)
280 : IRHandle(Internal::IntImm::make(Int(64), x)) {
281 }
282 explicit Expr(uint8_t x)
283 : IRHandle(Internal::UIntImm::make(UInt(8), x)) {
284 }
285 explicit Expr(uint16_t x)
286 : IRHandle(Internal::UIntImm::make(UInt(16), x)) {
287 }
288 explicit Expr(uint32_t x)
289 : IRHandle(Internal::UIntImm::make(UInt(32), x)) {
290 }
291 explicit Expr(uint64_t x)
292 : IRHandle(Internal::UIntImm::make(UInt(64), x)) {
293 }
295 : IRHandle(Internal::FloatImm::make(Float(16), (double)x)) {
296 }
298 : IRHandle(Internal::FloatImm::make(BFloat(16), (double)x)) {
299 }
300 Expr(float x)
301 : IRHandle(Internal::FloatImm::make(Float(32), x)) {
302 }
303 explicit Expr(double x)
304 : IRHandle(Internal::FloatImm::make(Float(64), x)) {
305 }
306 // @}
307
308 /** Make an expression representing a const string (i.e. a StringImm) */
309 Expr(const std::string &s)
310 : IRHandle(Internal::StringImm::make(s)) {
311 }
312
313 /** Override get() to return a BaseExprNode * instead of an IRNode * */
316 return (const Internal::BaseExprNode *)ptr;
317 }
318
319 /** Get the type of this expression node */
321 Type type() const {
322 return get()->type;
323 }
324};
325
326/** This lets you use an Expr as a key in a map of the form
327 * map<Expr, Foo, ExprCompare> */
329 bool operator()(const Expr &a, const Expr &b) const {
330 return a.get() < b.get();
331 }
332};
333
334/** A single-dimensional span. Includes all numbers between min and
335 * (min + extent - 1). */
336struct Range {
338
339 Range() = default;
340 Range(const Expr &min_in, const Expr &extent_in);
341};
342
343/** A multi-dimensional box: the outer product of its per-dimension Ranges. */
344typedef std::vector<Range> Region;
345
346/** An enum describing different address spaces to be used with Func::store_in. */
347enum class MemoryType {
348 /** Let Halide select a storage type automatically */
349 Auto,
350
351 /** Heap/global memory. Allocated using halide_malloc, or
352 * halide_device_malloc */
353 Heap,
354
355 /** Stack memory. Allocated using alloca. Requires a constant
356 * size. Corresponds to per-thread local memory on the GPU. If all
357 * accesses are at constant coordinates, may be promoted into the
358 * register file at the discretion of the register allocator. */
359 Stack,
360
361 /** Register memory. The allocation should be promoted into the
362 * register file. All stores must be at constant coordinates. May
363 * be spilled to the stack at the discretion of the register
364 * allocator. */
365 Register,
366
367 /** Allocation is stored in GPU shared memory. Also known as
368 * "local" in OpenCL, and "threadgroup" in Metal. Can be shared
369 * across GPU threads within the same block. */
370 GPUShared,
371
372 /** Allocation is stored in GPU texture memory and accessed through
373 * a hardware sampler */
375
376 /** Allocate Locked Cache Memory to act as local memory */
378 /** Vector Tightly Coupled Memory. HVX (Hexagon) local memory available on
379 * v65+. This memory has higher performance and lower power. Ideal for
380 * intermediate buffers. Necessary for vgather-vscatter instructions
381 * on Hexagon */
382 VTCM,
383
384 /** AMX Tile register for X86. Any data that would be used in an AMX matrix
385 * multiplication must first be loaded into an AMX tile register. */
386 AMXTile,
387};
388
389namespace Internal {
390
391/** An enum describing a type of loop traversal. Used in schedules,
392 * and in the For loop IR node. Serial is a conventional ordered for
393 * loop. Iterations occur in increasing order, and each iteration must
394 * appear to have finished before the next begins. Parallel, GPUBlock,
395 * and GPUThread are parallel and unordered: iterations may occur in
396 * any order, and multiple iterations may occur
397 * simultaneously. Vectorized and GPULane are parallel and
398 * synchronous: they act as if all iterations occur at the same time
399 * in lockstep. */
400enum class ForType {
401 Serial,
402 Parallel,
404 Unrolled,
405 Extern,
406 GPUBlock,
407 GPUThread,
408 GPULane,
409};
410
411/** Check if for_type executes for loop iterations in parallel and unordered. */
413
414/** Returns true if for_type executes for loop iterations in parallel. */
415bool is_parallel(ForType for_type);
416
417/** A reference-counted handle to a statement node. */
418struct Stmt : public IRHandle {
419 Stmt() = default;
421 : IRHandle(n) {
422 }
423
424 /** Override get() to return a BaseStmtNode * instead of an IRNode * */
426 const BaseStmtNode *get() const {
427 return (const Internal::BaseStmtNode *)ptr;
428 }
429
430 /** This lets you use a Stmt as a key in a map of the form
431 * map<Stmt, Foo, Stmt::Compare>. Keys are ordered by node pointer, not by statement contents. */
432 struct Compare {
433 bool operator()(const Stmt &a, const Stmt &b) const {
434 return a.ptr < b.ptr;
435 }
436 };
437};
438
439} // namespace Internal
440} // namespace Halide
441
442#endif
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:40
Support classes for reference-counting via intrusive shared pointers.
Defines halide types.
A base class for passes over the IR which modify it (e.g.
Definition: IRMutator.h:26
A base class for algorithms that need to recursively walk over the IR.
Definition: IRVisitor.h:19
A class representing a reference count to be used with IntrusivePtr.
Definition: IntrusivePtr.h:19
constexpr IRNodeType StrongestExprNodeType
Definition: Expr.h:80
ForType
An enum describing a type of loop traversal.
Definition: Expr.h:400
RefCount & ref_count< IRNode >(const IRNode *t) noexcept
Definition: Expr.h:116
bool is_unordered_parallel(ForType for_type)
Check if for_type executes for loop iterations in parallel and unordered.
bool is_parallel(ForType for_type)
Returns true if for_type executes for loop iterations in parallel.
void destroy< IRNode >(const IRNode *t)
Definition: Expr.h:121
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition: Expr.h:25
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Type BFloat(int bits, int lanes=1)
Construct a floating-point type in the bfloat format.
Definition: Type.h:541
Type UInt(int bits, int lanes=1)
Constructing an unsigned integer type.
Definition: Type.h:531
Type Float(int bits, int lanes=1)
Construct a floating-point type.
Definition: Type.h:536
@ Internal
Not visible externally, similar to 'static' linkage in C.
Type Int(int bits, int lanes=1)
Constructing a signed integer type.
Definition: Type.h:526
std::vector< Range > Region
A multi-dimensional box.
Definition: Expr.h:344
MemoryType
An enum describing different address spaces to be used with Func::store_in.
Definition: Expr.h:347
@ Auto
Let Halide select a storage type automatically.
@ Register
Register memory.
@ Stack
Stack memory.
@ VTCM
Vector Tightly Coupled Memory.
@ AMXTile
AMX Tile register for X86.
@ LockedCache
Allocate Locked Cache Memory to act as local memory.
@ Heap
Heap/global memory.
@ GPUTexture
Allocation is stored in GPU texture memory and accessed through hardware sampler.
@ GPUShared
Allocation is stored in GPU shared memory.
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT8_TYPE__ uint8_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
signed __INT16_TYPE__ int16_t
signed __INT8_TYPE__ int8_t
This lets you use an Expr as a key in a map of the form map<Expr, Foo, ExprCompare>
Definition: Expr.h:328
bool operator()(const Expr &a, const Expr &b) const
Definition: Expr.h:329
A fragment of Halide syntax.
Definition: Expr.h:257
Expr(float x)
Definition: Expr.h:300
HALIDE_ALWAYS_INLINE Expr()=default
Make an undefined expression.
Expr(int32_t x)
Definition: Expr.h:276
Expr(bfloat16_t x)
Definition: Expr.h:297
Expr(uint32_t x)
Definition: Expr.h:288
Expr(const std::string &s)
Make an expression representing a const string (i.e.
Definition: Expr.h:309
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition: Expr.h:321
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:315
Expr(int64_t x)
Definition: Expr.h:279
Expr(int16_t x)
Definition: Expr.h:273
Expr(uint64_t x)
Definition: Expr.h:291
Expr(uint16_t x)
Definition: Expr.h:285
Expr(double x)
Definition: Expr.h:303
Expr(int8_t x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:270
HALIDE_ALWAYS_INLINE Expr(const Internal::BaseExprNode *n)
Make an expression from a concrete expression node pointer (e.g.
Definition: Expr.h:264
Expr(uint8_t x)
Definition: Expr.h:282
Expr(float16_t x)
Definition: Expr.h:294
The sum of two expressions.
Definition: IR.h:48
Allocate a scratch area called with the given name, type, and size.
Definition: IR.h:363
Logical and - are both expressions true.
Definition: IR.h:167
If the 'condition' is false, then evaluate and return the message, which should be a call to an error...
Definition: IR.h:286
Lock all the Store nodes in the body statement.
Definition: IR.h:911
A base class for expression nodes.
Definition: Expr.h:142
virtual Expr mutate_expr(IRMutator *v) const =0
BaseExprNode(IRNodeType t)
Definition: Expr.h:143
IR nodes are split into expressions and statements.
Definition: Expr.h:133
BaseStmtNode(IRNodeType t)
Definition: Expr.h:134
virtual Stmt mutate_stmt(IRMutator *v) const =0
A sequence of statements to be executed in-order.
Definition: IR.h:434
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:251
A function call.
Definition: IR.h:482
The actual IR nodes begin here.
Definition: IR.h:29
The ratio of two expressions.
Definition: IR.h:75
Is the first expression equal to the second.
Definition: IR.h:113
Evaluate and discard an expression, presumably because it has some side-effect.
Definition: IR.h:468
We use the "curiously recurring template pattern" to avoid duplicated code in the IR Nodes.
Definition: Expr.h:157
~ExprNode() override=default
Expr mutate_expr(IRMutator *v) const override
void accept(IRVisitor *v) const override
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
Floating point constants.
Definition: Expr.h:235
static const IRNodeType _node_type
Definition: Expr.h:240
static const FloatImm * make(Type t, double value)
A for loop.
Definition: IR.h:788
A pair of statements executed concurrently.
Definition: IR.h:449
Is the first expression greater than or equal to the second.
Definition: IR.h:158
Is the first expression greater than the second.
Definition: IR.h:149
IR nodes are passed around opaque handles to them.
Definition: Expr.h:179
void accept(IRVisitor *v) const
Dispatch to the correct visitor method for this node.
Definition: Expr.h:191
HALIDE_ALWAYS_INLINE IRHandle()=default
const T * as() const
Downcast this ir node to its actual type (e.g.
Definition: Expr.h:204
IRNodeType node_type() const
Definition: Expr.h:211
HALIDE_ALWAYS_INLINE IRHandle(const IRNode *p)
Definition: Expr.h:184
The abstract base classes for a node in the Halide IR.
Definition: Expr.h:83
virtual ~IRNode()=default
virtual void accept(IRVisitor *v) const =0
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:112
RefCount ref_count
These classes are all managed with intrusive reference counting, so we also track a reference count.
Definition: Expr.h:100
IRNode(IRNodeType t)
Definition: Expr.h:90
An if-then-else block.
Definition: IR.h:458
Integer constants.
Definition: Expr.h:217
static const IRNodeType _node_type
Definition: Expr.h:222
static const IntImm * make(Type t, int64_t value)
Intrusive shared pointers have a reference count (a RefCount object) stored in the class itself.
Definition: IntrusivePtr.h:68
Is the first expression less than or equal to the second.
Definition: IR.h:140
Is the first expression less than the second.
Definition: IR.h:131
A let expression, like you might find in a functional language.
Definition: IR.h:263
The statement form of a let node.
Definition: IR.h:274
Load a value from a named symbol if predicate is true.
Definition: IR.h:209
The greater of two values.
Definition: IR.h:104
The lesser of two values.
Definition: IR.h:95
The remainder of a / b.
Definition: IR.h:86
The product of two expressions.
Definition: IR.h:66
Is the first expression not equal to the second.
Definition: IR.h:122
Logical not - true if the expression is false.
Definition: IR.h:185
Logical or - is at least one of the expressions true.
Definition: IR.h:176
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
Definition: IR.h:888
This node is a helpful annotation to do with permissions.
Definition: IR.h:307
This defines the value of a function at a multi-dimensional location.
Definition: IR.h:346
A linear ramp vector node.
Definition: IR.h:239
Allocate a multi-dimensional buffer of the given type and size.
Definition: IR.h:419
Reinterpret value as another type, without affecting any of the bits (on little-endian systems).
Definition: IR.h:39
A ternary operator.
Definition: IR.h:196
Construct a new vector by taking elements from another sequence of vectors.
Definition: IR.h:819
This lets you use a Stmt as a key in a map of the form map<Stmt, Foo, Stmt::Compare>
Definition: Expr.h:432
bool operator()(const Stmt &a, const Stmt &b) const
Definition: Expr.h:433
A reference-counted handle to a statement node.
Definition: Expr.h:418
Stmt(const BaseStmtNode *n)
Definition: Expr.h:420
HALIDE_ALWAYS_INLINE const BaseStmtNode * get() const
Override get() to return a BaseStmtNode * instead of an IRNode *.
Definition: Expr.h:426
void accept(IRVisitor *v) const override
We use the visitor pattern to traverse IR nodes throughout the compiler, so we have a virtual accept ...
Stmt mutate_stmt(IRMutator *v) const override
~StmtNode() override=default
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
Definition: IR.h:325
String constants.
Definition: Expr.h:244
static const StringImm * make(const std::string &val)
static const IRNodeType _node_type
Definition: Expr.h:249
The difference of two expressions.
Definition: IR.h:57
Unsigned integer constants.
Definition: Expr.h:226
static const IRNodeType _node_type
Definition: Expr.h:231
static const UIntImm * make(Type t, uint64_t value)
A named variable.
Definition: IR.h:741
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition: IR.h:929
A single-dimensional span.
Definition: Expr.h:336
Range()=default
Expr min
Definition: Expr.h:337
Expr extent
Definition: Expr.h:337
Range(const Expr &min_in, const Expr &extent_in)
Types in the halide type system.
Definition: Type.h:276
Class that provides a type that implements half precision floating point using the bfloat16 format.
Definition: Float16.h:158
Class that provides a type that implements half precision floating point (IEEE754 2008 binary16) in s...
Definition: Float16.h:17