Halide 16.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
IRMatch.h
Go to the documentation of this file.
1#ifndef HALIDE_IR_MATCH_H
2#define HALIDE_IR_MATCH_H
3
4/** \file
5 * Defines a method to match a fragment of IR against a pattern containing wildcards
6 */
7
8#include <map>
9#include <random>
10#include <set>
11#include <vector>
12
13#include "IR.h"
14#include "IREquality.h"
15#include "IROperator.h"
16
17namespace Halide {
18namespace Internal {
19
20/** Does the first expression have the same structure as the second?
21 * Variables in the first expression with the name * are interpreted
22 * as wildcards, and their matching equivalent in the second
23 * expression is placed in the vector give as the third argument.
24 * Wildcards require the types to match. For the type bits and width,
25 * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26 * integer vectors of any width (including scalars), and a UInt(0, 0)
27 * will match any unsigned integer type.
28 *
29 * For example:
30 \code
31 Expr x = Variable::make(Int(32), "*");
32 match(x + x, 3 + (2*k), result)
33 \endcode
34 * should return true, and set result[0] to 3 and
35 * result[1] to 2*k.
36 */
37bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38
39/** Does the first expression have the same structure as the second?
40 * Variables are matched consistently. The first time a variable is
41 * matched, it assumes the value of the matching part of the second
42 * expression. Subsequent matches must be equal to the first match.
43 *
44 * For example:
45 \code
46 Var x("x"), y("y");
47 match(x*(x + y), a*(a + b), result)
48 \endcode
49 * should return true, and set result["x"] = a, and result["y"] = b.
50 */
51bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52
53/** Rewrite the expression x to have `lanes` lanes. This is useful
54 * for substituting the results of expr_match into a pattern expression. */
55Expr with_lanes(const Expr &x, int lanes);
56
58
59/** An alternative template-metaprogramming approach to expression
60 * matching. Potentially more efficient. We lift the expression
61 * pattern into a type, and then use force-inlined functions to
62 * generate efficient matching and reconstruction code for any
63 * pattern. Pattern elements are either one of the classes in the
64 * namespace IRMatcher, or are non-null Exprs (represented as
65 * BaseExprNode &).
66 *
67 * Pattern elements that are fully specified by their pattern can be
68 * built into an expression using the make method. Some patterns,
69 * such as a broadcast that matches any number of lanes, don't have
70 * enough information to recreate an Expr.
71 */
72namespace IRMatcher {
73
74constexpr int max_wild = 6;
75
76static const halide_type_t i64_type = {halide_type_int, 64, 1};
77
78/** To save stack space, the matcher objects are largely stateless and
79 * immutable. This state object is built up during matching and then
80 * consumed when constructing a replacement Expr.
81 */
85
86 // values of the lanes field with special meaning.
87 static constexpr uint16_t signed_integer_overflow = 0x8000;
88 static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89
91
93 void set_binding(int i, const BaseExprNode &n) noexcept {
94 bindings[i] = &n;
95 }
96
98 const BaseExprNode *get_binding(int i) const noexcept {
99 return bindings[i];
100 }
101
103 void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104 bound_const[i].u.i64 = s;
105 bound_const_type[i] = t;
106 }
107
109 void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110 bound_const[i].u.u64 = u;
111 bound_const_type[i] = t;
112 }
113
115 void set_bound_const(int i, double f, halide_type_t t) noexcept {
116 bound_const[i].u.f64 = f;
117 bound_const_type[i] = t;
118 }
119
122 bound_const[i] = val;
123 bound_const_type[i] = t;
124 }
125
127 void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128 val = bound_const[i];
129 type = bound_const_type[i];
130 }
131
133 // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
136};
137
138template<typename T,
139 typename = typename std::remove_reference<T>::type::pattern_tag>
141 struct type {};
142};
143
// Compile-time bitmask of the wildcard slots bound by pattern type T.
// Reference qualifiers on T are ignored.
template<typename T>
struct bindings {
    static constexpr uint32_t mask = std::remove_reference_t<T>::binds;
};
148
150 const uint16_t flags = ty.lanes & MatcherState::special_values_mask;
151 ty.lanes &= ~MatcherState::special_values_mask;
154 }
155 // unreachable
156 return Expr();
157}
158
164 }
165
166 const int lanes = scalar_type.lanes;
167 scalar_type.lanes = 1;
168
169 Expr e;
170 switch (scalar_type.code) {
171 case halide_type_int:
172 e = IntImm::make(scalar_type, val.u.i64);
173 break;
174 case halide_type_uint:
175 e = UIntImm::make(scalar_type, val.u.u64);
176 break;
180 break;
181 default:
182 // Unreachable
183 return Expr();
184 }
185 if (lanes > 1) {
186 e = Broadcast::make(e, lanes);
187 }
188 return e;
189}
190
191bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept;
192
193// A fast version of expression equality that assumes a well-typed non-null expression tree.
195bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept {
196 // Early out
197 return (&a == &b) ||
198 ((a.type == b.type) &&
199 (a.node_type == b.node_type) &&
200 equal_helper(a, b));
201}
202
203// A pattern that matches a specific expression
205 struct pattern_tag {};
206
207 constexpr static uint32_t binds = 0;
208
209 // What is the weakest and strongest IR node this could possibly be
212 constexpr static bool canonical = true;
213
215
216 template<uint32_t bound>
217 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
218 return equal(expr, e);
219 }
220
223 return Expr(&expr);
224 }
225
226 constexpr static bool foldable = false;
227};
228
229inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
230 s << Expr(&e.expr);
231 return s;
232}
233
234template<int i>
236 struct pattern_tag {};
237
238 constexpr static uint32_t binds = 1 << i;
239
242 constexpr static bool canonical = true;
243
244 template<uint32_t bound>
245 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
246 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
247 const BaseExprNode *op = &e;
248 if (op->node_type == IRNodeType::Broadcast) {
249 op = ((const Broadcast *)op)->value.get();
250 }
251 if (op->node_type != IRNodeType::IntImm) {
252 return false;
253 }
254 int64_t value = ((const IntImm *)op)->value;
255 if (bound & binds) {
257 halide_type_t type;
258 state.get_bound_const(i, val, type);
259 return (halide_type_t)e.type == type && value == val.u.i64;
260 }
261 state.set_bound_const(i, value, e.type);
262 return true;
263 }
264
265 template<uint32_t bound>
266 HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
267 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
268 if (bound & binds) {
270 halide_type_t type;
271 state.get_bound_const(i, val, type);
272 return type == i64_type && value == val.u.i64;
273 }
274 state.set_bound_const(i, value, i64_type);
275 return true;
276 }
277
281 halide_type_t type;
282 state.get_bound_const(i, val, type);
283 return make_const_expr(val, type);
284 }
285
286 constexpr static bool foldable = true;
287
290 state.get_bound_const(i, val, ty);
291 }
292};
293
294template<int i>
295std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
296 s << "ci" << i;
297 return s;
298}
299
300template<int i>
302 struct pattern_tag {};
303
304 constexpr static uint32_t binds = 1 << i;
305
308 constexpr static bool canonical = true;
309
310 template<uint32_t bound>
311 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
312 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
313 const BaseExprNode *op = &e;
314 if (op->node_type == IRNodeType::Broadcast) {
315 op = ((const Broadcast *)op)->value.get();
316 }
317 if (op->node_type != IRNodeType::UIntImm) {
318 return false;
319 }
320 uint64_t value = ((const UIntImm *)op)->value;
321 if (bound & binds) {
323 halide_type_t type;
324 state.get_bound_const(i, val, type);
325 return (halide_type_t)e.type == type && value == val.u.u64;
326 }
327 state.set_bound_const(i, value, e.type);
328 return true;
329 }
330
334 halide_type_t type;
335 state.get_bound_const(i, val, type);
336 return make_const_expr(val, type);
337 }
338
339 constexpr static bool foldable = true;
340
343 state.get_bound_const(i, val, ty);
344 }
345};
346
347template<int i>
348std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
349 s << "cu" << i;
350 return s;
351}
352
353template<int i>
355 struct pattern_tag {};
356
357 constexpr static uint32_t binds = 1 << i;
358
361 constexpr static bool canonical = true;
362
363 template<uint32_t bound>
364 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
365 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
366 const BaseExprNode *op = &e;
367 if (op->node_type == IRNodeType::Broadcast) {
368 op = ((const Broadcast *)op)->value.get();
369 }
370 if (op->node_type != IRNodeType::FloatImm) {
371 return false;
372 }
373 double value = ((const FloatImm *)op)->value;
374 if (bound & binds) {
376 halide_type_t type;
377 state.get_bound_const(i, val, type);
378 return (halide_type_t)e.type == type && value == val.u.f64;
379 }
380 state.set_bound_const(i, value, e.type);
381 return true;
382 }
383
387 halide_type_t type;
388 state.get_bound_const(i, val, type);
389 return make_const_expr(val, type);
390 }
391
392 constexpr static bool foldable = true;
393
396 state.get_bound_const(i, val, ty);
397 }
398};
399
400template<int i>
401std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
402 s << "cf" << i;
403 return s;
404}
405
406// Matches and binds to any constant Expr (an IntImm, UIntImm or FloatImm, or a broadcast of one). Constant folding reads back the constant bound to slot i.
407template<int i>
408struct WildConst {
409 struct pattern_tag {};
410
411 constexpr static uint32_t binds = 1 << i;
412
415 constexpr static bool canonical = true;
416
417 template<uint32_t bound>
418 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
419 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
420 const BaseExprNode *op = &e;
421 if (op->node_type == IRNodeType::Broadcast) {
422 op = ((const Broadcast *)op)->value.get();
423 }
424 switch (op->node_type) {
426 return WildConstInt<i>().template match<bound>(e, state);
428 return WildConstUInt<i>().template match<bound>(e, state);
430 return WildConstFloat<i>().template match<bound>(e, state);
431 default:
432 return false;
433 }
434 }
435
436 template<uint32_t bound>
437 HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
438 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
439 return WildConstInt<i>().template match<bound>(e, state);
440 }
441
445 halide_type_t type;
446 state.get_bound_const(i, val, type);
447 return make_const_expr(val, type);
448 }
449
450 constexpr static bool foldable = true;
451
454 state.get_bound_const(i, val, ty);
455 }
456};
457
458template<int i>
459std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
460 s << "c" << i;
461 return s;
462}
463
464// Matches and binds to any Expr
465template<int i>
466struct Wild {
467 struct pattern_tag {};
468
469 constexpr static uint32_t binds = 1 << (i + 16);
470
473 constexpr static bool canonical = true;
474
475 template<uint32_t bound>
476 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
477 if (bound & binds) {
478 return equal(*state.get_binding(i), e);
479 }
480 state.set_binding(i, e);
481 return true;
482 }
483
486 return state.get_binding(i);
487 }
488
489 constexpr static bool foldable = false;
490};
491
492template<int i>
493std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
494 s << "_" << i;
495 return s;
496}
497
498// Matches a specific constant or broadcast of that constant. The
499// constant must be representable as an int64_t.
// Tag type marking this struct as an IRMatcher pattern node.
struct pattern_tag {};

// A literal binds no wildcard slots.
static constexpr uint32_t binds = 0;

static constexpr bool canonical = true;
509
512 : v(v) {
513 }
514
515 template<uint32_t bound>
516 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
517 const BaseExprNode *op = &e;
518 if (e.node_type == IRNodeType::Broadcast) {
519 op = ((const Broadcast *)op)->value.get();
520 }
521 switch (op->node_type) {
523 return ((const IntImm *)op)->value == (int64_t)v;
525 return ((const UIntImm *)op)->value == (uint64_t)v;
527 return ((const FloatImm *)op)->value == (double)v;
528 default:
529 return false;
530 }
531 }
532
533 template<uint32_t bound>
534 HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
535 return v == val;
536 }
537
538 template<uint32_t bound>
539 HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
540 return v == b.v;
541 }
542
545 return make_const(type_hint, v);
546 }
547
548 constexpr static bool foldable = true;
549
552 // Assume type is already correct
553 switch (ty.code) {
554 case halide_type_int:
555 val.u.i64 = v;
556 break;
557 case halide_type_uint:
558 val.u.u64 = (uint64_t)v;
559 break;
562 val.u.f64 = (double)v;
563 break;
564 default:
565 // Unreachable
566 ;
567 }
568 }
569};
570
574
575// Convert a provided pattern, expr, or constant int into the internal
576// representation we use in the matcher trees.
577template<typename T,
578 typename = typename std::decay<T>::type::pattern_tag>
580 return t;
581}
584 return IntLiteral{x};
585}
586
587template<typename T>
589 static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
590 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
591}
592
594 return {*e.get()};
595}
596
597// Helpers to deref SpecificExprs to const BaseExprNode & rather than
598// passing them by value anywhere (incurring lots of refcounting)
599template<typename T,
600 // T must be a pattern node
601 typename = typename std::decay<T>::type::pattern_tag,
602 // But T may not be SpecificExpr
603 typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
605 return t;
606}
607
610 return e.expr;
611}
612
613inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
614 s << op.v;
615 return s;
616}
617
618template<typename Op>
620
621template<typename Op>
623
624template<typename Op>
625double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
626
627constexpr bool commutative(IRNodeType t) {
628 return (t == IRNodeType::Add ||
629 t == IRNodeType::Mul ||
630 t == IRNodeType::And ||
631 t == IRNodeType::Or ||
632 t == IRNodeType::Min ||
633 t == IRNodeType::Max ||
634 t == IRNodeType::EQ ||
635 t == IRNodeType::NE);
636}
637
638// Matches one of the binary operators
639template<typename Op, typename A, typename B>
640struct BinOp {
641 struct pattern_tag {};
642 A a;
643 B b;
644
646
647 constexpr static IRNodeType min_node_type = Op::_node_type;
648 constexpr static IRNodeType max_node_type = Op::_node_type;
649
650 // For commutative bin ops, we expect the weaker IR node type on
651 // the right. That is, for the rule to be canonical it must be
652 // possible that A is at least as strong as B.
653 constexpr static bool canonical =
654 A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
655
656 template<uint32_t bound>
657 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
658 if (e.node_type != Op::_node_type) {
659 return false;
660 }
661 const Op &op = (const Op &)e;
662 return (a.template match<bound>(*op.a.get(), state) &&
663 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
664 }
665
666 template<uint32_t bound, typename Op2, typename A2, typename B2>
667 HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
668 return (std::is_same<Op, Op2>::value &&
669 a.template match<bound>(unwrap(op.a), state) &&
670 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
671 }
672
673 constexpr static bool foldable = A::foldable && B::foldable;
674
678 if (std::is_same<A, IntLiteral>::value) {
679 b.make_folded_const(val_b, ty, state);
680 if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
681 (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
682 // Short circuit
683 val = val_b;
684 return;
685 }
686 const uint16_t l = ty.lanes;
687 a.make_folded_const(val_a, ty, state);
688 ty.lanes |= l; // Make sure the overflow bits are sticky
689 } else {
690 a.make_folded_const(val_a, ty, state);
691 if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
692 (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
693 // Short circuit
694 val = val_a;
695 return;
696 }
697 const uint16_t l = ty.lanes;
698 b.make_folded_const(val_b, ty, state);
699 ty.lanes |= l;
700 }
701 switch (ty.code) {
702 case halide_type_int:
703 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
704 break;
705 case halide_type_uint:
706 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
707 break;
710 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
711 break;
712 default:
713 // unreachable
714 ;
715 }
716 }
717
719 Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
720 Expr ea, eb;
721 if (std::is_same<A, IntLiteral>::value) {
722 eb = b.make(state, type_hint);
723 ea = a.make(state, eb.type());
724 } else {
725 ea = a.make(state, type_hint);
726 eb = b.make(state, ea.type());
727 }
728 // We sometimes mix vectors and scalars in the rewrite rules,
729 // so insert a broadcast if necessary.
730 if (ea.type().is_vector() && !eb.type().is_vector()) {
731 eb = Broadcast::make(eb, ea.type().lanes());
732 }
733 if (eb.type().is_vector() && !ea.type().is_vector()) {
734 ea = Broadcast::make(ea, eb.type().lanes());
735 }
736 return Op::make(std::move(ea), std::move(eb));
737 }
738};
739
740template<typename Op>
742
743template<typename Op>
745
746template<typename Op>
747uint64_t constant_fold_cmp_op(double, double) noexcept;
748
749// Matches one of the comparison operators
750template<typename Op, typename A, typename B>
751struct CmpOp {
752 struct pattern_tag {};
753 A a;
754 B b;
755
757
758 constexpr static IRNodeType min_node_type = Op::_node_type;
759 constexpr static IRNodeType max_node_type = Op::_node_type;
760 constexpr static bool canonical = (A::canonical &&
761 B::canonical &&
762 (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
763 (Op::_node_type != IRNodeType::GE) &&
764 (Op::_node_type != IRNodeType::GT));
765
766 template<uint32_t bound>
767 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
768 if (e.node_type != Op::_node_type) {
769 return false;
770 }
771 const Op &op = (const Op &)e;
772 return (a.template match<bound>(*op.a.get(), state) &&
773 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
774 }
775
776 template<uint32_t bound, typename Op2, typename A2, typename B2>
777 HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
778 return (std::is_same<Op, Op2>::value &&
779 a.template match<bound>(unwrap(op.a), state) &&
780 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
781 }
782
783 constexpr static bool foldable = A::foldable && B::foldable;
784
788 // If one side is an untyped const, evaluate the other side first to get a type hint.
789 if (std::is_same<A, IntLiteral>::value) {
790 b.make_folded_const(val_b, ty, state);
791 const uint16_t l = ty.lanes;
792 a.make_folded_const(val_a, ty, state);
793 ty.lanes |= l;
794 } else {
795 a.make_folded_const(val_a, ty, state);
796 const uint16_t l = ty.lanes;
797 b.make_folded_const(val_b, ty, state);
798 ty.lanes |= l;
799 }
800 switch (ty.code) {
801 case halide_type_int:
802 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
803 break;
804 case halide_type_uint:
805 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
806 break;
809 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
810 break;
811 default:
812 // unreachable
813 ;
814 }
815 ty.code = halide_type_uint;
816 ty.bits = 1;
817 }
818
821 // If one side is an untyped const, evaluate the other side first to get a type hint.
822 Expr ea, eb;
823 if (std::is_same<A, IntLiteral>::value) {
824 eb = b.make(state, {});
825 ea = a.make(state, eb.type());
826 } else {
827 ea = a.make(state, {});
828 eb = b.make(state, ea.type());
829 }
830 // We sometimes mix vectors and scalars in the rewrite rules,
831 // so insert a broadcast if necessary.
832 if (ea.type().is_vector() && !eb.type().is_vector()) {
833 eb = Broadcast::make(eb, ea.type().lanes());
834 }
835 if (eb.type().is_vector() && !ea.type().is_vector()) {
836 ea = Broadcast::make(ea, eb.type().lanes());
837 }
838 return Op::make(std::move(ea), std::move(eb));
839 }
840};
841
842template<typename A, typename B>
843std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
844 s << "(" << op.a << " + " << op.b << ")";
845 return s;
846}
847
848template<typename A, typename B>
849std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
850 s << "(" << op.a << " - " << op.b << ")";
851 return s;
852}
853
854template<typename A, typename B>
855std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
856 s << "(" << op.a << " * " << op.b << ")";
857 return s;
858}
859
860template<typename A, typename B>
861std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
862 s << "(" << op.a << " / " << op.b << ")";
863 return s;
864}
865
866template<typename A, typename B>
867std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
868 s << "(" << op.a << " && " << op.b << ")";
869 return s;
870}
871
872template<typename A, typename B>
873std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
874 s << "(" << op.a << " || " << op.b << ")";
875 return s;
876}
877
878template<typename A, typename B>
879std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
880 s << "min(" << op.a << ", " << op.b << ")";
881 return s;
882}
883
884template<typename A, typename B>
885std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
886 s << "max(" << op.a << ", " << op.b << ")";
887 return s;
888}
889
890template<typename A, typename B>
891std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
892 s << "(" << op.a << " <= " << op.b << ")";
893 return s;
894}
895
896template<typename A, typename B>
897std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
898 s << "(" << op.a << " < " << op.b << ")";
899 return s;
900}
901
902template<typename A, typename B>
903std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
904 s << "(" << op.a << " >= " << op.b << ")";
905 return s;
906}
907
908template<typename A, typename B>
909std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
910 s << "(" << op.a << " > " << op.b << ")";
911 return s;
912}
913
914template<typename A, typename B>
915std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
916 s << "(" << op.a << " == " << op.b << ")";
917 return s;
918}
919
920template<typename A, typename B>
921std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
922 s << "(" << op.a << " != " << op.b << ")";
923 return s;
924}
925
926template<typename A, typename B>
927std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
928 s << "(" << op.a << " % " << op.b << ")";
929 return s;
930}
931
932template<typename A, typename B>
933HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
936 return {pattern_arg(a), pattern_arg(b)};
937}
938
939template<typename A, typename B>
940HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
943 return IRMatcher::operator+(a, b);
944}
945
946template<>
948 t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
949 int dead_bits = 64 - t.bits;
950 // Drop the high bits then sign-extend them back
951 return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
952}
953
954template<>
956 uint64_t ones = (uint64_t)(-1);
957 return (a + b) & (ones >> (64 - t.bits));
958}
959
960template<>
961HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
962 return a + b;
963}
964
965template<typename A, typename B>
966HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
969 return {pattern_arg(a), pattern_arg(b)};
970}
971
972template<typename A, typename B>
973HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
976 return IRMatcher::operator-(a, b);
977}
978
979template<>
981 t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
982 // Drop the high bits then sign-extend them back
983 int dead_bits = 64 - t.bits;
984 return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
985}
986
987template<>
989 uint64_t ones = (uint64_t)(-1);
990 return (a - b) & (ones >> (64 - t.bits));
991}
992
993template<>
994HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
995 return a - b;
996}
997
998template<typename A, typename B>
999HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1002 return {pattern_arg(a), pattern_arg(b)};
1003}
1004
1005template<typename A, typename B>
1006HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
1009 return IRMatcher::operator*(a, b);
1010}
1011
1012template<>
1014 t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
1015 int dead_bits = 64 - t.bits;
1016 // Drop the high bits then sign-extend them back
1017 return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
1018}
1019
1020template<>
1022 uint64_t ones = (uint64_t)(-1);
1023 return (a * b) & (ones >> (64 - t.bits));
1024}
1025
1026template<>
1027HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1028 return a * b;
1029}
1030
1031template<typename A, typename B>
1032HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1035 return {pattern_arg(a), pattern_arg(b)};
1036}
1037
1038template<typename A, typename B>
1039HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1040 return IRMatcher::operator/(a, b);
1041}
1042
1043template<>
1047
1048template<>
1052
1053template<>
1054HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1055 return div_imp(a, b);
1056}
1057
1058template<typename A, typename B>
1059HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1062 return {pattern_arg(a), pattern_arg(b)};
1063}
1064
1065template<typename A, typename B>
1066HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1069 return IRMatcher::operator%(a, b);
1070}
1071
1072template<>
1076
1077template<>
1081
1082template<>
1083HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1084 return mod_imp(a, b);
1085}
1086
1087template<typename A, typename B>
1088HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1091 return {pattern_arg(a), pattern_arg(b)};
1092}
1093
1094template<>
1096 return std::min(a, b);
1097}
1098
1099template<>
1101 return std::min(a, b);
1102}
1103
1104template<>
1105HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1106 return std::min(a, b);
1107}
1108
1109template<typename A, typename B>
1110HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1113 return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1114}
1115
1116template<>
1118 return std::max(a, b);
1119}
1120
1121template<>
1123 return std::max(a, b);
1124}
1125
1126template<>
1127HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1128 return std::max(a, b);
1129}
1130
1131template<typename A, typename B>
1132HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1133 return {pattern_arg(a), pattern_arg(b)};
1134}
1135
1136template<typename A, typename B>
1137HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1138 return IRMatcher::operator<(a, b);
1139}
1140
1141template<>
1145
1146template<>
1150
1151template<>
1153 return a < b;
1154}
1155
1156template<typename A, typename B>
1157HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1158 return {pattern_arg(a), pattern_arg(b)};
1159}
1160
1161template<typename A, typename B>
1162HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1163 return IRMatcher::operator>(a, b);
1164}
1165
1166template<>
1170
1171template<>
1175
1176template<>
1178 return a > b;
1179}
1180
1181template<typename A, typename B>
1182HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1183 return {pattern_arg(a), pattern_arg(b)};
1184}
1185
1186template<typename A, typename B>
1187HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1188 return IRMatcher::operator<=(a, b);
1189}
1190
1191template<>
1193 return a <= b;
1194}
1195
1196template<>
1200
1201template<>
1203 return a <= b;
1204}
1205
1206template<typename A, typename B>
1207HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1208 return {pattern_arg(a), pattern_arg(b)};
1209}
1210
1211template<typename A, typename B>
1212HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1213 return IRMatcher::operator>=(a, b);
1214}
1215
1216template<>
1218 return a >= b;
1219}
1220
1221template<>
1225
1226template<>
1228 return a >= b;
1229}
1230
1231template<typename A, typename B>
1232HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1233 return {pattern_arg(a), pattern_arg(b)};
1234}
1235
1236template<typename A, typename B>
1237HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1238 return IRMatcher::operator==(a, b);
1239}
1240
1241template<>
1243 return a == b;
1244}
1245
1246template<>
1250
1251template<>
1253 return a == b;
1254}
1255
1256template<typename A, typename B>
1257HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1258 return {pattern_arg(a), pattern_arg(b)};
1259}
1260
1261template<typename A, typename B>
1262HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1263 return IRMatcher::operator!=(a, b);
1264}
1265
1266template<>
1268 return a != b;
1269}
1270
1271template<>
1275
1276template<>
1278 return a != b;
1279}
1280
1281template<typename A, typename B>
1282HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1283 return {pattern_arg(a), pattern_arg(b)};
1284}
1285
1286template<typename A, typename B>
1287HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1288 return IRMatcher::operator||(a, b);
1289}
1290
1291template<>
1293 return (a | b) & 1;
1294}
1295
1296template<>
1298 return (a | b) & 1;
1299}
1300
1301template<>
1302HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1303 // Unreachable, as it would be a type mismatch.
1304 return 0;
1305}
1306
1307template<typename A, typename B>
1308HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1309 return {pattern_arg(a), pattern_arg(b)};
1310}
1311
1312template<typename A, typename B>
1313HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1314 return IRMatcher::operator&&(a, b);
1315}
1316
1317template<>
1319 return a & b & 1;
1320}
1321
1322template<>
1326
1327template<>
1328HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1329 // Unreachable
1330 return 0;
1331}
1332
/** Compile-time bitwise OR over any number of uint32_t values.
 *  The empty reduction yields 0, the identity element for `|`. */
constexpr inline uint32_t bitwise_or_reduce() {
    return 0u;
}

/** Recursive case: reduce the tail, then OR in the head. The body is a
 *  single return statement so it remains a valid C++11 constexpr function. */
template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) {
    return bitwise_or_reduce(rest...) | first;
}
1341
/** Compile-time logical AND over any number of bools.
 *  The empty reduction yields true, the identity element for `&&`. */
constexpr inline bool and_reduce() {
    return true;
}

/** Recursive case: a false head ends the reduction immediately;
 *  otherwise the result is the reduction of the tail. */
template<typename... Args>
constexpr bool and_reduce(bool first, Args... rest) {
    return first ? and_reduce(rest...) : false;
}
1350
// TODO: std::min is only constexpr from C++14 onwards; replace this helper
// with it once C++14 or later is required.
/** Compile-time minimum of two ints. */
constexpr int const_min(int a, int b) {
    return (b < a) ? b : a;
}
1355
1356template<typename... Args>
1357struct Intrin {
1358 struct pattern_tag {};
1360 std::tuple<Args...> args;
1361 // The type of the output of the intrinsic node.
1362 // Only necessary in cases where it can't be inferred
1363 // from the input types (e.g. saturating_cast).
1365
1367
1370 constexpr static bool canonical = and_reduce((Args::canonical)...);
1371
1372 template<int i,
1373 uint32_t bound,
1374 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1375 HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1376 using T = decltype(std::get<i>(args));
1377 return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1379 }
1380
1381 template<int i, uint32_t binds>
1382 HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1383 return true;
1384 }
1385
1386 template<uint32_t bound>
1387 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1388 if (e.node_type != IRNodeType::Call) {
1389 return false;
1390 }
1391 const Call &c = (const Call &)e;
1392 return (c.is_intrinsic(intrin) &&
1393 ((optional_type_hint == Type()) || optional_type_hint == e.type) &&
1394 match_args<0, bound>(0, c, state));
1395 }
1396
1397 template<int i,
1398 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1399 HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1401 if (i + 1 < sizeof...(Args)) {
1402 s << ", ";
1403 }
1404 print_args<i + 1>(0, s);
1405 }
1406
1407 template<int i>
1408 HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1409 }
1410
1412 void print_args(std::ostream &s) const {
1413 print_args<0>(0, s);
1414 }
1415
1418 Expr arg0 = std::get<0>(args).make(state, type_hint);
1419 if (intrin == Call::likely) {
1420 return likely(arg0);
1421 } else if (intrin == Call::likely_if_innermost) {
1422 return likely_if_innermost(arg0);
1423 } else if (intrin == Call::abs) {
1424 return abs(arg0);
1425 } else if (intrin == Call::saturating_cast) {
1427 }
1428
1429 Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1430 if (intrin == Call::absd) {
1431 return absd(arg0, arg1);
1432 } else if (intrin == Call::widen_right_add) {
1433 return widen_right_add(arg0, arg1);
1434 } else if (intrin == Call::widen_right_mul) {
1435 return widen_right_mul(arg0, arg1);
1436 } else if (intrin == Call::widen_right_sub) {
1437 return widen_right_sub(arg0, arg1);
1438 } else if (intrin == Call::widening_add) {
1439 return widening_add(arg0, arg1);
1440 } else if (intrin == Call::widening_sub) {
1441 return widening_sub(arg0, arg1);
1442 } else if (intrin == Call::widening_mul) {
1443 return widening_mul(arg0, arg1);
1444 } else if (intrin == Call::saturating_add) {
1445 return saturating_add(arg0, arg1);
1446 } else if (intrin == Call::saturating_sub) {
1447 return saturating_sub(arg0, arg1);
1448 } else if (intrin == Call::halving_add) {
1449 return halving_add(arg0, arg1);
1450 } else if (intrin == Call::halving_sub) {
1451 return halving_sub(arg0, arg1);
1452 } else if (intrin == Call::rounding_halving_add) {
1454 } else if (intrin == Call::shift_left) {
1455 return arg0 << arg1;
1456 } else if (intrin == Call::shift_right) {
1457 return arg0 >> arg1;
1458 } else if (intrin == Call::rounding_shift_left) {
1459 return rounding_shift_left(arg0, arg1);
1460 } else if (intrin == Call::rounding_shift_right) {
1462 }
1463
1464 Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1466 return mul_shift_right(arg0, arg1, arg2);
1467 } else if (intrin == Call::rounding_mul_shift_right) {
1469 }
1470
1471 internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1472 return Expr();
1473 }
1474
1475 constexpr static bool foldable = true;
1476
1479 // Assuming the args have the same type as the intrinsic is incorrect in
1480 // general. But for the intrinsics we can fold (just shifts), the LHS
1481 // has the same type as the intrinsic, and we can always treat the RHS
1482 // as a signed int, because we're using 64 bits for it.
1483 std::get<0>(args).make_folded_const(val, ty, state);
1486 // We can just directly get the second arg here, because we only want to
1487 // instantiate this method for shifts, which have two args.
1488 std::get<1>(args).make_folded_const(arg1, signed_ty, state);
1489
1490 if (intrin == Call::shift_left) {
1491 if (arg1.u.i64 < 0) {
1492 if (ty.code == halide_type_int) {
1493 // Arithmetic shift
1494 val.u.i64 >>= -arg1.u.i64;
1495 } else {
1496 // Logical shift
1497 val.u.u64 >>= -arg1.u.i64;
1498 }
1499 } else {
1500 val.u.u64 <<= arg1.u.i64;
1501 }
1502 } else if (intrin == Call::shift_right) {
1503 if (arg1.u.i64 > 0) {
1504 if (ty.code == halide_type_int) {
1505 // Arithmetic shift
1506 val.u.i64 >>= arg1.u.i64;
1507 } else {
1508 // Logical shift
1509 val.u.u64 >>= arg1.u.i64;
1510 }
1511 } else {
1512 val.u.u64 <<= -arg1.u.i64;
1513 }
1514 } else {
1515 internal_error << "Folding not implemented for intrinsic: " << intrin;
1516 }
1517 }
1518
1523};
1524
1525template<typename... Args>
1526std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
1527 s << op.intrin << "(";
1528 op.print_args(s);
1529 s << ")";
1530 return s;
1531}
1532
1533template<typename... Args>
1534HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
1535 return {intrinsic_op, pattern_arg(args)...};
1536}
1537
1538template<typename A, typename B>
1539auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1541}
1542template<typename A, typename B>
1543auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1545}
1546template<typename A, typename B>
1547auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1549}
1550
1551template<typename A, typename B>
1552auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1554}
1555template<typename A, typename B>
1556auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1558}
1559template<typename A, typename B>
1560auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1562}
1563template<typename A, typename B>
1564auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1566}
1567template<typename A, typename B>
1568auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1570}
1571template<typename A>
1572auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<decltype(pattern_arg(a))> {
1575 return p;
1576}
1577template<typename A, typename B>
1578auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1579 return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
1580}
1581template<typename A, typename B>
1582auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1583 return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
1584}
1585template<typename A, typename B>
1586auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1588}
1589template<typename A, typename B>
1590auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1591 return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
1592}
1593template<typename A, typename B>
1594auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1595 return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
1596}
1597template<typename A, typename B>
1598auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1600}
1601template<typename A, typename B>
1602auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1604}
1605template<typename A, typename B, typename C>
1606auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1608}
1609template<typename A, typename B, typename C>
1610auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1612}
1613
1614template<typename A>
1615struct NotOp {
1616 struct pattern_tag {};
1617 A a;
1618
1620
1623 constexpr static bool canonical = A::canonical;
1624
1625 template<uint32_t bound>
1626 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1627 if (e.node_type != IRNodeType::Not) {
1628 return false;
1629 }
1630 const Not &op = (const Not &)e;
1631 return (a.template match<bound>(*op.a.get(), state));
1632 }
1633
1634 template<uint32_t bound, typename A2>
1635 HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1636 return a.template match<bound>(unwrap(op.a), state);
1637 }
1638
1641 return Not::make(a.make(state, type_hint));
1642 }
1643
1644 constexpr static bool foldable = A::foldable;
1645
1646 template<typename A1 = A>
1648 a.make_folded_const(val, ty, state);
1649 val.u.u64 = ~val.u.u64;
1650 val.u.u64 &= 1;
1651 }
1652};
1653
1654template<typename A>
1655HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1657 return {pattern_arg(a)};
1658}
1659
1660template<typename A>
1665
1666template<typename A>
1667inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1668 s << "!(" << op.a << ")";
1669 return s;
1670}
1671
1672template<typename C, typename T, typename F>
1673struct SelectOp {
1674 struct pattern_tag {};
1676 T t;
1678
1680
1683
1684 constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1685
1686 template<uint32_t bound>
1687 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1688 if (e.node_type != Select::_node_type) {
1689 return false;
1690 }
1691 const Select &op = (const Select &)e;
1692 return (c.template match<bound>(*op.condition.get(), state) &&
1693 t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
1694 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
1695 }
1696 template<uint32_t bound, typename C2, typename T2, typename F2>
1697 HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1698 return (c.template match<bound>(unwrap(instance.c), state) &&
1699 t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
1700 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
1701 }
1702
1705 return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1706 }
1707
1708 constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1709
1710 template<typename C1 = C>
1714 c.make_folded_const(c_val, c_ty, state);
1715 if ((c_val.u.u64 & 1) == 1) {
1716 t.make_folded_const(val, ty, state);
1717 } else {
1718 f.make_folded_const(val, ty, state);
1719 }
1721 }
1722};
1723
1724template<typename C, typename T, typename F>
1725std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1726 s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1727 return s;
1728}
1729
1730template<typename C, typename T, typename F>
1731HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1735 return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1736}
1737
1738template<typename A, typename B>
1740 struct pattern_tag {};
1741 A a;
1743
1745
1748
1749 constexpr static bool canonical = A::canonical && B::canonical;
1750
1751 template<uint32_t bound>
1752 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1753 if (e.node_type == Broadcast::_node_type) {
1754 const Broadcast &op = (const Broadcast &)e;
1755 if (a.template match<bound>(*op.value.get(), state) &&
1756 lanes.template match<bound>(op.lanes, state)) {
1757 return true;
1758 }
1759 }
1760 return false;
1761 }
1762
1763 template<uint32_t bound, typename A2, typename B2>
1764 HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1765 return (a.template match<bound>(unwrap(op.a), state) &&
1766 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1767 }
1768
1773 lanes.make_folded_const(lanes_val, ty, state);
1774 int32_t l = (int32_t)lanes_val.u.i64;
1775 type_hint.lanes /= l;
1776 Expr val = a.make(state, type_hint);
1777 if (l == 1) {
1778 return val;
1779 } else {
1780 return Broadcast::make(std::move(val), l);
1781 }
1782 }
1783
1784 constexpr static bool foldable = false;
1785
1786 template<typename A1 = A>
1790 lanes.make_folded_const(lanes_val, lanes_ty, state);
1791 uint16_t l = (uint16_t)lanes_val.u.i64;
1792 a.make_folded_const(val, ty, state);
1793 ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1794 }
1795};
1796
1797template<typename A, typename B>
1798inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1799 s << "broadcast(" << op.a << ", " << op.lanes << ")";
1800 return s;
1801}
1802
1803template<typename A, typename B>
1804HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1806 return {pattern_arg(a), pattern_arg(lanes)};
1807}
1808
1809template<typename A, typename B, typename C>
1810struct RampOp {
1811 struct pattern_tag {};
1812 A a;
1813 B b;
1815
1817
1820
1821 constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1822
1823 template<uint32_t bound>
1824 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1825 if (e.node_type != Ramp::_node_type) {
1826 return false;
1827 }
1828 const Ramp &op = (const Ramp &)e;
1829 if (a.template match<bound>(*op.base.get(), state) &&
1830 b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
1831 lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
1832 return true;
1833 } else {
1834 return false;
1835 }
1836 }
1837
1838 template<uint32_t bound, typename A2, typename B2, typename C2>
1839 HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1840 return (a.template match<bound>(unwrap(op.a), state) &&
1841 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
1842 lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
1843 }
1844
1849 lanes.make_folded_const(lanes_val, ty, state);
1850 int32_t l = (int32_t)lanes_val.u.i64;
1851 type_hint.lanes /= l;
1852 Expr ea, eb;
1853 eb = b.make(state, type_hint);
1854 ea = a.make(state, eb.type());
1855 return Ramp::make(ea, eb, l);
1856 }
1857
1858 constexpr static bool foldable = false;
1859};
1860
1861template<typename A, typename B, typename C>
1862std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1863 s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1864 return s;
1865}
1866
1867template<typename A, typename B, typename C>
1868HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1872 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1873}
1874
1875template<typename A, typename B, VectorReduce::Operator reduce_op>
1877 struct pattern_tag {};
1878 A a;
1880
1882
1885 constexpr static bool canonical = A::canonical;
1886
1887 template<uint32_t bound>
1888 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1889 if (e.node_type == VectorReduce::_node_type) {
1890 const VectorReduce &op = (const VectorReduce &)e;
1891 if (op.op == reduce_op &&
1892 a.template match<bound>(*op.value.get(), state) &&
1893 lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
1894 return true;
1895 }
1896 }
1897 return false;
1898 }
1899
1900 template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1902 return (reduce_op == reduce_op_2 &&
1903 a.template match<bound>(unwrap(op.a), state) &&
1904 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1905 }
1906
1911 lanes.make_folded_const(lanes_val, ty, state);
1912 int l = (int)lanes_val.u.i64;
1913 return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1914 }
1915
1916 constexpr static bool foldable = false;
1917};
1918
1919template<typename A, typename B, VectorReduce::Operator reduce_op>
1920inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1921 s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1922 return s;
1923}
1924
1925template<typename A, typename B>
1926HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1928 return {pattern_arg(a), pattern_arg(lanes)};
1929}
1930
1931template<typename A, typename B>
1932HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1934 return {pattern_arg(a), pattern_arg(lanes)};
1935}
1936
1937template<typename A, typename B>
1938HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1940 return {pattern_arg(a), pattern_arg(lanes)};
1941}
1942
1943template<typename A, typename B>
1944HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
1946 return {pattern_arg(a), pattern_arg(lanes)};
1947}
1948
1949template<typename A, typename B>
1950HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
1952 return {pattern_arg(a), pattern_arg(lanes)};
1953}
1954
1955template<typename A>
1956struct NegateOp {
1957 struct pattern_tag {};
1958 A a;
1959
1961
1964
1965 constexpr static bool canonical = A::canonical;
1966
1967 template<uint32_t bound>
1968 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1969 if (e.node_type != Sub::_node_type) {
1970 return false;
1971 }
1972 const Sub &op = (const Sub &)e;
1973 return (a.template match<bound>(*op.b.get(), state) &&
1974 is_const_zero(op.a));
1975 }
1976
1977 template<uint32_t bound, typename A2>
1978 HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1979 return a.template match<bound>(unwrap(p.a), state);
1980 }
1981
1984 Expr ea = a.make(state, type_hint);
1985 Expr z = make_zero(ea.type());
1986 return Sub::make(std::move(z), std::move(ea));
1987 }
1988
1989 constexpr static bool foldable = A::foldable;
1990
1991 template<typename A1 = A>
1993 a.make_folded_const(val, ty, state);
1994 int dead_bits = 64 - ty.bits;
1995 switch (ty.code) {
1996 case halide_type_int:
1997 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1998 // Trying to negate the most negative signed int for a no-overflow type.
2000 } else {
2001 // Negate, drop the high bits, and then sign-extend them back
2002 val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
2003 }
2004 break;
2005 case halide_type_uint:
2006 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
2007 break;
2008 case halide_type_float:
2009 case halide_type_bfloat:
2010 val.u.f64 = -val.u.f64;
2011 break;
2012 default:
2013 // unreachable
2014 ;
2015 }
2016 }
2017};
2018
2019template<typename A>
2020std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
2021 s << "-" << op.a;
2022 return s;
2023}
2024
2025template<typename A>
2026HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
2028 return {pattern_arg(a)};
2029}
2030
2031template<typename A>
2036
2037template<typename A>
2038struct CastOp {
2039 struct pattern_tag {};
2041 A a;
2042
2044
2047 constexpr static bool canonical = A::canonical;
2048
2049 template<uint32_t bound>
2050 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2051 if (e.node_type != Cast::_node_type) {
2052 return false;
2053 }
2054 const Cast &op = (const Cast &)e;
2055 return (e.type == t &&
2056 a.template match<bound>(*op.value.get(), state));
2057 }
2058 template<uint32_t bound, typename A2>
2059 HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2060 return t == op.t && a.template match<bound>(unwrap(op.a), state);
2061 }
2062
2065 return cast(t, a.make(state, {}));
2066 }
2067
2068 constexpr static bool foldable = false;
2069};
2070
2071template<typename A>
2072std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2073 s << "cast(" << op.t << ", " << op.a << ")";
2074 return s;
2075}
2076
2077template<typename A>
2078HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2080 return {t, pattern_arg(a)};
2081}
2082
2083template<typename Vec, typename Base, typename Stride, typename Lanes>
2084struct SliceOp {
2085 struct pattern_tag {};
2090
2091 static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2092
2095 constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2096
2097 template<uint32_t bound>
2098 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2099 if (e.node_type != IRNodeType::Shuffle) {
2100 return false;
2101 }
2102 const Shuffle &v = (const Shuffle &)e;
2103 return v.vectors.size() == 1 &&
2104 vec.template match<bound>(*v.vectors[0].get(), state) &&
2105 base.template match<bound | bindings<Vec>::mask>(v.slice_begin(), state) &&
2106 stride.template match<bound | bindings<Vec>::mask | bindings<Base>::mask>(v.slice_stride(), state) &&
2108 }
2109
2114 base.make_folded_const(base_val, ty, state);
2115 int b = (int)base_val.u.i64;
2116 stride.make_folded_const(stride_val, ty, state);
2117 int s = (int)stride_val.u.i64;
2118 lanes.make_folded_const(lanes_val, ty, state);
2119 int l = (int)lanes_val.u.i64;
2120 return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
2121 }
2122
2123 constexpr static bool foldable = false;
2124
2127 : vec(v), base(b), stride(s), lanes(l) {
2128 static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
2129 static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
2130 static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
2131 }
2132};
2133
2134template<typename Vec, typename Base, typename Stride, typename Lanes>
2135std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
2136 s << "slice(" << op.vec << ", " << op.base << ", " << op.stride << ", " << op.lanes << ")";
2137 return s;
2138}
2139
2140template<typename Vec, typename Base, typename Stride, typename Lanes>
2141HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
2142 -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
2143 return {pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
2144}
2145
2146template<typename A>
2147struct Fold {
2148 struct pattern_tag {};
2149 A a;
2150
2152
2155 constexpr static bool canonical = true;
2156
2161 a.make_folded_const(c, ty, state);
2162
2163 // The result of the fold may have an underspecified type
2164 // (e.g. because it's from an int literal). Make the type code
2165 // and bits match the required type, if there is one (we can
2166 // tell from the bits field).
2167 if (type_hint.bits) {
2168 if (((int)ty.code == (int)halide_type_int) &&
2169 ((int)type_hint.code == (int)halide_type_float)) {
2170 int64_t x = c.u.i64;
2171 c.u.f64 = (double)x;
2172 }
2173 ty.code = type_hint.code;
2174 ty.bits = type_hint.bits;
2175 }
2176
2177 Expr e = make_const_expr(c, ty);
2178 return e;
2179 }
2180
2181 constexpr static bool foldable = A::foldable;
2182
2183 template<typename A1 = A>
2185 a.make_folded_const(val, ty, state);
2186 }
2187};
2188
2189template<typename A>
2190HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2192 return {pattern_arg(a)};
2193}
2194
2195template<typename A>
2196std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2197 s << "fold(" << op.a << ")";
2198 return s;
2199}
2200
2201template<typename A>
2203 struct pattern_tag {};
2204 A a;
2205
2207
2208 // This rule is a predicate, so it always evaluates to a boolean,
2209 // which has IRNodeType UIntImm
2212 constexpr static bool canonical = true;
2213
2214 constexpr static bool foldable = A::foldable;
2215
2216 template<typename A1 = A>
2218 a.make_folded_const(val, ty, state);
2219 ty.code = halide_type_uint;
2220 ty.bits = 64;
2221 val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2222 ty.lanes = 1;
2223 }
2224};
2225
2226template<typename A>
2227HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2229 return {pattern_arg(a)};
2230}
2231
2232template<typename A>
2233std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2234 s << "overflows(" << op.a << ")";
2235 return s;
2236}
2237
2238struct Overflow {
2239 struct pattern_tag {};
2240
2241 constexpr static uint32_t binds = 0;
2242
2243 // Overflow is an intrinsic, represented as a Call node
2246 constexpr static bool canonical = true;
2247
2248 template<uint32_t bound>
2249 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2250 if (e.node_type != Call::_node_type) {
2251 return false;
2252 }
2253 const Call &op = (const Call &)e;
2255 }
2256
2262
2263 constexpr static bool foldable = true;
2264
2267 val.u.u64 = 0;
2269 }
2270};
2271
2272inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2273 s << "overflow()";
2274 return s;
2275}
2276
2277template<typename A>
2278struct IsConst {
2279 struct pattern_tag {};
2280
2282
2283 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2286 constexpr static bool canonical = true;
2287
2288 A a;
2291
2292 constexpr static bool foldable = true;
2293
2294 template<typename A1 = A>
2296 Expr e = a.make(state, {});
2297 ty.code = halide_type_uint;
2298 ty.bits = 64;
2299 ty.lanes = 1;
2300 if (check_v) {
2301 val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2302 } else {
2303 val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2304 }
2305 }
2306};
2307
2308template<typename A>
2309HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2311 return {pattern_arg(a), false, 0};
2312}
2313
2314template<typename A>
2315HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2317 return {pattern_arg(a), true, value};
2318}
2319
2320template<typename A>
2321std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2322 if (op.check_v) {
2323 s << "is_const(" << op.a << ")";
2324 } else {
2325 s << "is_const(" << op.a << ", " << op.v << ")";
2326 }
2327 return s;
2328}
2329
2330template<typename A, typename Prover>
2331struct CanProve {
2332 struct pattern_tag {};
2333 A a;
2334 Prover *prover; // An existing simplifying mutator
2335
2337
2338 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2341 constexpr static bool canonical = true;
2342
2343 constexpr static bool foldable = true;
2344
2345 // Includes a raw call to an inlined make method, so don't inline.
2347 Expr condition = a.make(state, {});
2348 condition = prover->mutate(condition, nullptr);
2349 val.u.u64 = is_const_one(condition);
2350 ty.code = halide_type_uint;
2351 ty.bits = 1;
2352 ty.lanes = condition.type().lanes();
2353 }
2354};
2355
2356template<typename A, typename Prover>
2357HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2359 return {pattern_arg(a), p};
2360}
2361
2362template<typename A, typename Prover>
2363std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2364 s << "can_prove(" << op.a << ")";
2365 return s;
2366}
2367
2368template<typename A>
2369struct IsFloat {
2370 struct pattern_tag {};
2371 A a;
2372
2374
2375 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2378 constexpr static bool canonical = true;
2379
2380 constexpr static bool foldable = true;
2381
2384 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2385 Type t = a.make(state, {}).type();
2386 val.u.u64 = t.is_float();
2387 ty.code = halide_type_uint;
2388 ty.bits = 1;
2389 ty.lanes = t.lanes();
2390 }
2391};
2392
// is_float(a): build an IsFloat predicate over pattern `a`.
// NOTE(review): source line 2395 is missing from this extracted listing.
2393template<typename A>
2394HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2396 return {pattern_arg(a)};
2397}
2398
2399template<typename A>
2400std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2401 s << "is_float(" << op.a << ")";
2402 return s;
2403}
2404
// IsInt: foldable predicate pattern built by is_int(a, bits, lanes). Folding
// instantiates `a` and yields 1 iff all of the following hold:
//   - its type is a signed integer;
//   - bits is 0 or matches the type's bit width;
//   - lanes is 0 or matches its lane count.
// A bits or lanes value of 0 therefore means "any". The result is a 1-bit
// unsigned value carrying a's lane count.
// NOTE(review): source lines 2409, 2411, 2414-2415 and 2420-2421 are missing
// from this extracted listing. The declarations of the `bits`/`lanes` members
// read below are presumably among them.
2405template<typename A>
2406struct IsInt {
2407 struct pattern_tag {};
2408 A a;
2410
2412
2413 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2416 constexpr static bool canonical = true;
2417
2418 constexpr static bool foldable = true;
2419
2422 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2423 Type t = a.make(state, {}).type();
2424 val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2425 ty.code = halide_type_uint;
2426 ty.bits = 1;
2427 ty.lanes = t.lanes();
2428 }
2429};
2430
// is_int(a, bits, lanes): build an IsInt predicate. A bits or lanes of 0 (the
// default) accepts any bit width or lane count respectively.
// NOTE(review): source line 2433 is missing from this extracted listing.
2431template<typename A>
2432HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0, int lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2434 return {pattern_arg(a), bits, lanes};
2435}
2436
2437template<typename A>
2438std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2439 s << "is_int(" << op.a;
2440 if (op.bits > 0) {
2441 s << ", " << op.bits;
2442 }
2443 if (op.lanes > 0) {
2444 s << ", " << op.lanes;
2445 }
2446 s << ")";
2447 return s;
2448}
2449
// IsUInt: foldable predicate pattern built by is_uint(a, bits, lanes).
// Folding instantiates `a` and yields 1 iff all of the following hold:
//   - its type is an unsigned integer;
//   - bits is 0 or matches the type's bit width;
//   - lanes is 0 or matches its lane count.
// A bits or lanes value of 0 therefore means "any". The result is a 1-bit
// unsigned value carrying a's lane count.
// NOTE(review): source lines 2454, 2456, 2459-2460 and 2465-2466 are missing
// from this extracted listing. The declarations of the `bits`/`lanes` members
// read below are presumably among them.
2450template<typename A>
2451struct IsUInt {
2452 struct pattern_tag {};
2453 A a;
2455
2457
2458 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2461 constexpr static bool canonical = true;
2462
2463 constexpr static bool foldable = true;
2464
2467 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2468 Type t = a.make(state, {}).type();
2469 val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2470 ty.code = halide_type_uint;
2471 ty.bits = 1;
2472 ty.lanes = t.lanes();
2473 }
2474};
2475
// is_uint(a, bits, lanes): build an IsUInt predicate. A bits or lanes of 0
// (the default) accepts any bit width or lane count respectively.
// NOTE(review): source line 2478 is missing from this extracted listing.
2476template<typename A>
2477HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0, int lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2479 return {pattern_arg(a), bits, lanes};
2480}
2481
2482template<typename A>
2483std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2484 s << "is_uint(" << op.a;
2485 if (op.bits > 0) {
2486 s << ", " << op.bits;
2487 }
2488 if (op.lanes > 0) {
2489 s << ", " << op.lanes;
2490 }
2491 s << ")";
2492 return s;
2493}
2494
// IsScalar: foldable predicate pattern built by is_scalar(a). Folding
// instantiates `a` and yields 1 if its type is scalar (Type::is_scalar()),
// else 0. The result is a 1-bit unsigned value carrying a's lane count.
// NOTE(review): source lines 2500, 2503-2504 and 2509-2510 are missing from
// this extracted listing.
2495template<typename A>
2496struct IsScalar {
2497 struct pattern_tag {};
2498 A a;
2499
2501
2502 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2505 constexpr static bool canonical = true;
2506
2507 constexpr static bool foldable = true;
2508
2511 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2512 Type t = a.make(state, {}).type();
2513 val.u.u64 = t.is_scalar();
2514 ty.code = halide_type_uint;
2515 ty.bits = 1;
2516 ty.lanes = t.lanes();
2517 }
2518};
2519
// is_scalar(a): build an IsScalar predicate over pattern `a`.
// NOTE(review): source line 2522 is missing from this extracted listing.
2520template<typename A>
2521HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2523 return {pattern_arg(a)};
2524}
2525
2526template<typename A>
2527std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2528 s << "is_scalar(" << op.a << ")";
2529 return s;
2530}
2531
// IsMaxValue: foldable predicate pattern built by is_max_value(a). Folding
// constant-folds `a` and then checks the result against its type:
//   - int:  1 iff the value equals the type's maximum. The mask is all-ones
//           shifted right by 64 - bits + 1, i.e. 2^(bits-1) - 1.
//   - uint: 1 iff the value equals the type's maximum. The mask is all-ones
//           shifted right by 64 - bits, i.e. 2^bits - 1.
//   - any other type code: 0.
// ty becomes a 1-bit unsigned value. ty.lanes is left as a's fold produced
// it, so any MatcherState::special_values_mask flags survive; evaluate_predicate
// treats those as failure.
// NOTE(review): source lines 2533 (presumably `struct IsMaxValue {`, given
// is_max_value's return type), 2537, 2540-2541 and 2546-2547 are missing
// from this extracted listing.
2532template<typename A>
2534 struct pattern_tag {};
2535 A a;
2536
2538
2539 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2542 constexpr static bool canonical = true;
2543
2544 constexpr static bool foldable = true;
2545
2548 // a is almost certainly a very simple pattern (e.g. a wild); constant-fold it and compare against the type's maximum.
2549 a.make_folded_const(val, ty, state);
2550 const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2551 if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2552 val.u.u64 = (val.u.u64 == max_bits);
2553 } else {
2554 val.u.u64 = 0;
2555 }
2556 ty.code = halide_type_uint;
2557 ty.bits = 1;
2558 }
2559};
2560
// is_max_value(a): build an IsMaxValue predicate over pattern `a`.
// NOTE(review): source line 2563 is missing from this extracted listing.
2561template<typename A>
2562HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2564 return {pattern_arg(a)};
2565}
2566
2567template<typename A>
2568std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2569 s << "is_max_value(" << op.a << ")";
2570 return s;
2571}
2572
// IsMinValue: foldable predicate pattern built by is_min_value(a). Folding
// constant-folds `a` and then checks the result against its type:
//   - int:  1 iff the value's bit pattern equals all-ones << (bits - 1), the
//           sign-extended representation of -2^(bits-1).
//   - uint: 1 iff the value is 0.
//   - any other type code: 0.
// ty becomes a 1-bit unsigned value. ty.lanes is left as a's fold produced
// it, so any MatcherState::special_values_mask flags survive.
// NOTE(review): source lines 2574 (presumably `struct IsMinValue {`, given
// is_min_value's return type), 2578, 2581-2582 and 2587-2588 are missing
// from this extracted listing.
2573template<typename A>
2575 struct pattern_tag {};
2576 A a;
2577
2579
2580 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2583 constexpr static bool canonical = true;
2584
2585 constexpr static bool foldable = true;
2586
2589 // a is almost certainly a very simple pattern (e.g. a wild); constant-fold it and compare against the type's minimum.
2590 a.make_folded_const(val, ty, state);
2591 if (ty.code == halide_type_int) {
2592 const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2593 val.u.u64 = (val.u.u64 == min_bits);
2594 } else if (ty.code == halide_type_uint) {
2595 val.u.u64 = (val.u.u64 == 0);
2596 } else {
2597 val.u.u64 = 0;
2598 }
2599 ty.code = halide_type_uint;
2600 ty.bits = 1;
2601 }
2602};
2603
// is_min_value(a): build an IsMinValue predicate over pattern `a`.
// NOTE(review): source line 2606 is missing from this extracted listing.
2604template<typename A>
2605HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2607 return {pattern_arg(a)};
2608}
2609
2610template<typename A>
2611std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2612 s << "is_min_value(" << op.a << ")";
2613 return s;
2614}
2615
// LanesOf: foldable pattern built by lanes_of(a). Folding instantiates `a`
// and yields its lane count as a scalar 32-bit unsigned constant. Unlike the
// is_* predicates, this is not a boolean.
// NOTE(review): source lines 2621, 2624-2625 and 2630-2631 are missing from
// this extracted listing.
2616template<typename A>
2617struct LanesOf {
2618 struct pattern_tag {};
2619 A a;
2620
2622
2623 // This rule folds to a 32-bit unsigned scalar (the lane count of a), not a boolean; unsigned constants are UIntImm.
2626 constexpr static bool canonical = true;
2627
2628 constexpr static bool foldable = true;
2629
2632 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2633 Type t = a.make(state, {}).type();
2634 val.u.u64 = t.lanes();
2635 ty.code = halide_type_uint;
2636 ty.bits = 32;
2637 ty.lanes = 1;
2638 }
2639};
2640
// lanes_of(a): build a LanesOf pattern over pattern `a`.
// NOTE(review): source line 2643 is missing from this extracted listing.
2641template<typename A>
2642HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
2644 return {pattern_arg(a)};
2645}
2646
2647template<typename A>
2648std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
2649 s << "lanes_of(" << op.a << ")";
2650 return s;
2651}
2652
// fuzz_test_rule: spot-check a constant-foldable rewrite rule
// (before -> after, when pred).
//   - Lanes are forced to 1 for both the wildcard type and the output type.
//   - The `tested` set is a function-local static of this template, so each
//     rule instantiation is checked at most once per wildcard type.
//   - The rule is printed as validate(...) for an external python/z3
//     checker.
//   - The trial loop runs 100 trials. Each trial binds every wildcard (both
//     the constant and the expression slots) to random scalar constants, skips
//     trials where pred is false, and compares the folded values of
//     `before` and `after`.
//   - Non-numeric wildcard or output type codes (e.g. handles) end the test
//     early.
// NOTE(review): this listing is an extracted documentation page. Source lines
// 2659, 2722, 2732, 2764 and 2774 are missing. The symbol index gives the
// function as
//   HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after,
//       Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
// Restore the missing lines from the original header.
2653// Verify properties of each rewrite rule. Currently just fuzz tests them.
2654template<typename Before,
2655 typename After,
2656 typename Predicate,
2657 typename = typename std::enable_if<std::decay<Before>::type::foldable &&
2658 std::decay<After>::type::foldable>::type>
2660 halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2661
2662 // We only validate the rules in the scalar case
2663 wildcard_type.lanes = output_type.lanes = 1;
2664
2665 // Track which types this rule has been tested for before
2666 static std::set<uint32_t> tested;
2667
2668 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2669 return;
2670 }
2671
2672 // Print it in a form where it can be piped into a python/z3 validator
2673 debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2674
2675 // Substitute some random constants into the before and after
2676 // expressions and see if the rule holds true. This should catch
2677 // silly errors, but not necessarily corner cases.
2678 static std::mt19937_64 rng(0);
2679 MatcherState state;
2680
2681 Expr exprs[max_wild];
2682
2683 for (int trials = 0; trials < 100; trials++) {
2684 // We want to test small constants more frequently than
2685 // large ones, otherwise we'll just get coverage of
2686 // overflow rules.
// The mask gives a shift in [0, bits-1] when bits is a power of two.
2687 int shift = (int)(rng() & (wildcard_type.bits - 1));
2688
2689 for (int i = 0; i < max_wild; i++) {
2690 // Bind all the exprs and constants
2691 switch (wildcard_type.code) {
2692 case halide_type_uint: {
2693 // Normalize to the type's range by adding zero
2694 uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2695 state.set_bound_const(i, val, wildcard_type);
2696 val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2697 exprs[i] = make_const(wildcard_type, val);
// NOTE(review): redundant; set_binding(i, ...) is called again for every case after the switch.
2698 state.set_binding(i, *exprs[i].get());
2699 } break;
2700 case halide_type_int: {
2701 int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2702 state.set_bound_const(i, val, wildcard_type);
2703 val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2704 exprs[i] = make_const(wildcard_type, val);
2705 } break;
2706 case halide_type_float:
2707 case halide_type_bfloat: {
2708 // Use a very narrow range of precise floats, so
2709 // that none of the rules a human is likely to
2710 // write have instabilities.
// Values are multiples of 0.5 in [-4, 3.5].
2711 double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2712 state.set_bound_const(i, val, wildcard_type);
2713 val = ((int64_t)(rng() & 15) - 8) / 2.0;
2714 exprs[i] = make_const(wildcard_type, val);
2715 } break;
2716 default:
2717 return; // Don't care about handles
2718 }
2719 state.set_binding(i, *exprs[i].get());
2720 }
2721
// NOTE(review): elided line 2722 presumably declares val_before / val_after.
2723 halide_type_t type = output_type;
2724 if (!evaluate_predicate(pred, state)) {
2725 continue;
2726 }
2727 before.make_folded_const(val_before, type, state);
2728 uint16_t lanes = type.lanes;
2729 after.make_folded_const(val_after, type, state);
2730 lanes |= type.lanes;
2731
// NOTE(review): elided line 2732 is the condition for this `continue`.
// Since `lanes` ORs both sides' folded lanes fields, it presumably skips
// trials where either side produced a MatcherState::special_values_mask
// flag (e.g. overflow). Confirm against the original.
2733 continue;
2734 }
2735
2736 bool ok = true;
2737 switch (output_type.code) {
2738 case halide_type_uint:
2739 // Compare normalized representations
2740 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2741 constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2742 break;
2743 case halide_type_int:
2744 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2745 constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2746 break;
2747 case halide_type_float:
2748 case halide_type_bfloat: {
2749 double error = std::abs(val_before.u.f64 - val_after.u.f64);
2750 // We accept an equal bit pattern (e.g. inf vs inf),
2751 // a small floating point difference, or turning a nan into not-a-nan.
2752 ok &= (error < 0.01 ||
2753 val_before.u.u64 == val_after.u.u64 ||
2754 std::isnan(val_before.u.f64));
2755 break;
2756 }
2757 default:
2758 return;
2759 }
2760
2761 if (!ok) {
2762 debug(0) << "Fails with values:\n";
2763 for (int i = 0; i < max_wild; i++) {
// NOTE(review): elided line 2764 presumably declares `val` used below.
2765 state.get_bound_const(i, val, wildcard_type);
2766 debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2767 }
2768 for (int i = 0; i < max_wild; i++) {
2769 debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2770 }
2771 debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2772 debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2773 debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
// NOTE(review): elided line 2774; the symbol index lists internal_error,
// which this line presumably uses to abort on failure. Confirm.
2775 }
2776 }
2777}
2778
// Fallback fuzz_test_rule, selected when before or after is not
// constant-foldable. Its enable_if is the negation of the primary
// overload's. Such rules cannot be fuzz-tested, so this does nothing. The
// trailing `int dummy` parameter presumably keeps the two overloads'
// signatures distinct.
// NOTE(review): source line 2784 (the start of the parameter list) is missing
// from this extracted listing.
2779template<typename Before,
2780 typename After,
2781 typename Predicate,
2782 typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2783 std::decay<After>::type::foldable)>::type>
2785 halide_type_t, halide_type_t, int dummy = 0) noexcept {
2786 // We can't verify rewrite rules that can't be constant-folded.
2787}
2788
// Trivial predicate overload: a plain bool is returned unchanged; the
// matcher state is ignored.
// NOTE(review): source line 2789 is missing from this extracted listing. The
// symbol index gives this overload as
//   HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
2790bool evaluate_predicate(bool x, MatcherState &) noexcept {
2791 return x;
2792}
2793
// Pattern predicate overload: constant-fold predicate `p` under the bindings
// in `state`. The result is true iff the folded value is nonzero and no
// MatcherState::special_values_mask flag (e.g. overflow) is set in the folded
// type's lanes field.
// NOTE(review): source lines 2796-2798 (the signature and the declarations of
// the folded value `c` and type `ty`) are missing from this extracted listing.
2794template<typename Pattern,
2795 typename = typename enable_if_pattern<Pattern>::type>
2799 p.make_folded_const(c, ty, state);
2800 // Overflow counts as a failed predicate
2801 return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2802}
2803
2804// #defines for testing
2805
2806// Set HALIDE_DEBUG_MATCHED_RULES to 1 to print every successful match, and HALIDE_DEBUG_UNMATCHED_RULES to 1 to print every failed match (via debug(0) in Rewriter::operator()).
2807#define HALIDE_DEBUG_MATCHED_RULES 0
2808#define HALIDE_DEBUG_UNMATCHED_RULES 0
2809
2810// Set to 1 if you want to fuzz test every rewrite passed to
2811// operator() to ensure the input and the output have the same value
2812// for lots of random values of the wildcards. Run
2813// correctness_simplify with this on.
2814#define HALIDE_FUZZ_TEST_RULES 0
2815
// Rewriter: applies candidate rewrite rules to one fixed instance. Each
// operator() overload below:
//   - tries to match `before` against unwrap(instance), recording wildcard
//     bindings in `state`;
//   - on success sets `result` and returns true, otherwise returns false.
// `result` is visibly set via `result = after;` in two overloads; in the
// others the assignment is presumably on an elided line.
// The HALIDE_DEBUG_* macros add match tracing. HALIDE_FUZZ_TEST_RULES adds a
// fuzz-test hook to the overloads that have one.
// Instances are created through rewriter(...), which brace-initialises from
// (instance, output_type, wildcard_type).
// NOTE(review): this listing is an extracted documentation page and many
// source lines are missing. These include the members and initialisation
// around 2818-2832, every operator() signature, the fuzz-test calls, and
// several result-construction lines. Restore them from the original header.
2816template<typename Instance>
2817struct Rewriter {
2823
2828
2829 template<typename After>
2833
// Overload: before -> after, both patterns. after may only use wildcards
// bound by before, and both sides must be in canonical form.
2834 template<typename Before,
2835 typename After,
2836 typename = typename enable_if_pattern<Before>::type,
2837 typename = typename enable_if_pattern<After>::type>
2839 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2840 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2841 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2842#if HALIDE_FUZZ_TEST_RULES
2844#endif
2845 if (before.template match<0>(unwrap(instance), state)) {
2847#if HALIDE_DEBUG_MATCHED_RULES
2848 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2849#endif
2850 return true;
2851 } else {
2852#if HALIDE_DEBUG_UNMATCHED_RULES
2853 debug(0) << instance << " does not match " << before << "\n";
2854#endif
2855 return false;
2856 }
2857 }
2858
// Overload: pattern before -> non-pattern after, stored directly with
// `result = after;`. It has no fuzz-test hook.
2859 template<typename Before,
2860 typename = typename enable_if_pattern<Before>::type>
2862 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2863 if (before.template match<0>(unwrap(instance), state)) {
2864 result = after;
2865#if HALIDE_DEBUG_MATCHED_RULES
2866 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2867#endif
2868 return true;
2869 } else {
2870#if HALIDE_DEBUG_UNMATCHED_RULES
2871 debug(0) << instance << " does not match " << before << "\n";
2872#endif
2873 return false;
2874 }
2875 }
2876
// Overload: pattern before -> non-pattern after. Its signature (2879), the
// fuzz-test call (2882) and the result construction (2885) are elided.
2877 template<typename Before,
2878 typename = typename enable_if_pattern<Before>::type>
2880 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2881#if HALIDE_FUZZ_TEST_RULES
2883#endif
2884 if (before.template match<0>(unwrap(instance), state)) {
2886#if HALIDE_DEBUG_MATCHED_RULES
2887 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2888#endif
2889 return true;
2890 } else {
2891#if HALIDE_DEBUG_UNMATCHED_RULES
2892 debug(0) << instance << " does not match " << before << "\n";
2893#endif
2894 return false;
2895 }
2896 }
2897
// Overload: before -> after when pred, all three patterns. pred must be
// constant-foldable, and both after and pred may only use wildcards bound
// by before. The predicate check and result construction (2915-2916) are
// elided.
2898 template<typename Before,
2899 typename After,
2900 typename Predicate,
2901 typename = typename enable_if_pattern<Before>::type,
2902 typename = typename enable_if_pattern<After>::type,
2903 typename = typename enable_if_pattern<Predicate>::type>
2905 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2906 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2907 static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
2908 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2909 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2910
2911#if HALIDE_FUZZ_TEST_RULES
2913#endif
2914 if (before.template match<0>(unwrap(instance), state) &&
2917#if HALIDE_DEBUG_MATCHED_RULES
2918 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2919#endif
2920 return true;
2921 } else {
2922#if HALIDE_DEBUG_UNMATCHED_RULES
2923 debug(0) << instance << " does not match " << before << "\n";
2924#endif
2925 return false;
2926 }
2927 }
2928
// Overload: before -> non-pattern after when pred, stored with
// `result = after;`. It has no fuzz-test hook; line 2938 (the rest of the
// condition) is elided.
2929 template<typename Before,
2930 typename Predicate,
2931 typename = typename enable_if_pattern<Before>::type,
2932 typename = typename enable_if_pattern<Predicate>::type>
2934 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2935 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2936
2937 if (before.template match<0>(unwrap(instance), state) &&
2939 result = after;
2940#if HALIDE_DEBUG_MATCHED_RULES
2941 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2942#endif
2943 return true;
2944 } else {
2945#if HALIDE_DEBUG_UNMATCHED_RULES
2946 debug(0) << instance << " does not match " << before << "\n";
2947#endif
2948 return false;
2949 }
2950 }
2951
// Overload: before -> non-pattern after when pred, with a fuzz-test hook.
// The signature and lines 2963-2964 are elided.
2952 template<typename Before,
2953 typename Predicate,
2954 typename = typename enable_if_pattern<Before>::type,
2955 typename = typename enable_if_pattern<Predicate>::type>
2957 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2958 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2959#if HALIDE_FUZZ_TEST_RULES
2961#endif
2962 if (before.template match<0>(unwrap(instance), state) &&
2965#if HALIDE_DEBUG_MATCHED_RULES
2966 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2967#endif
2968 return true;
2969 } else {
2970#if HALIDE_DEBUG_UNMATCHED_RULES
2971 debug(0) << instance << " does not match " << before << "\n";
2972#endif
2973 return false;
2974 }
2975 }
2976};
2977
2978/** Construct a rewriter for the given instance, which may be a pattern
2979 * with concrete expressions as leaves, or just an expression. The
2980 * second optional argument (wildcard_type) is a hint as to what the
2981 * type of the wildcards is likely to be. If omitted it uses the same
2982 * type as the expression itself. They are not required to be this
2983 * type, but the rule will only be tested for wildcards of that type
2984 * when testing is enabled.
2985 *
2986 * The rewriter can be used to check to see if the instance is one of
2987 * some number of patterns and if so rewrite it into another form,
2988 * using its operator() method. See Simplify.cpp for a bunch of
2989 * example usage.
2990 *
2991 * Important: Any Exprs in patterns are captured by reference, not by
2992 * value, so ensure they outlive the rewriter.
2993 */
2994// @{
2995template<typename Instance,
2996 typename = typename enable_if_pattern<Instance>::type>
2997HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
2998 return {pattern_arg(instance), output_type, wildcard_type};
2999}
3000
3001template<typename Instance,
3002 typename = typename enable_if_pattern<Instance>::type>
3003HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3004 return {pattern_arg(instance), output_type, output_type};
3005}
3006
// Expr overload with an explicit wildcard-type hint. The output type is
// taken from e.type(). Exprs are captured by reference, so `e` must outlive
// the returned Rewriter.
// NOTE(review): source line 3007 (just before this signature) is missing from
// this extracted listing.
3008auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3009 return {pattern_arg(e), e.type(), wildcard_type};
3010}
3011
// Expr overload with no hints. Both the output type and the wildcard type
// are taken from e.type(). Exprs are captured by reference, so `e` must
// outlive the returned Rewriter.
// NOTE(review): source line 3012 (just before this signature) is missing from
// this extracted listing.
3013auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3014 return {pattern_arg(e), e.type(), e.type()};
3015}
3016// @}
3017
3018} // namespace IRMatcher
3019
3020} // namespace Internal
3021} // namespace Halide
3022
3023#endif
#define internal_error
Definition Errors.h:23
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
Definition Debug.h:49
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1598
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1590
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition IRMatch.h:2997
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition IRMatch.h:579
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1539
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition IRMatch.h:1287
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition IRMatch.h:1655
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition IRMatch.h:2790
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1044
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition IRMatch.h:1262
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition IRMatch.h:2032
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1182
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:933
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2562
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition IRMatch.h:229
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition IRMatch.h:1313
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition IRMatch.h:1944
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition IRMatch.h:1162
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition IRMatch.h:2309
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
Definition IRMatch.h:1534
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1192
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:999
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1586
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1602
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1547
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition IRMatch.h:940
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition IRMatch.h:1039
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1564
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition IRMatch.h:1006
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:2141
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1868
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1032
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1560
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1073
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1318
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition IRMatch.h:571
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1157
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition IRMatch.h:2078
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition IRMatch.h:2227
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1552
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition IRMatch.h:588
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1059
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:980
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition IRMatch.h:2521
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition IRMatch.h:2190
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition IRMatch.h:1661
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1578
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1117
constexpr bool and_reduce()
Definition IRMatch.h:1342
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1282
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1556
constexpr int max_wild
Definition IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1257
HALIDE_ALWAYS_INLINE bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept
Definition IRMatch.h:195
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition IRMatch.h:2394
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1207
bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1132
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1308
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0, int lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2477
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition IRMatch.h:1950
constexpr bool commutative(IRNodeType t)
Definition IRMatch.h:627
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1543
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition IRMatch.h:973
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition IRMatch.h:1938
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:1804
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0, int lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2432
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition IRMatch.h:1731
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2605
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1095
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition IRMatch.h:2659
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1167
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1582
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1568
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1013
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1606
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1594
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1217
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:966
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition IRMatch.h:1187
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition IRMatch.h:1137
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
Definition IRMatch.h:2642
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1142
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition IRMatch.h:1932
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition IRMatch.h:1926
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1292
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition IRMatch.h:1333
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1610
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1242
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition IRMatch.h:149
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition IRMatch.h:1212
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< decltype(pattern_arg(a))>
Definition IRMatch.h:1572
constexpr int const_min(int a, int b)
Definition IRMatch.h:1352
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1267
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition IRMatch.h:1066
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1232
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:947
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition IRMatch.h:2357
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition IRMatch.h:1237
T div_imp(T a, T b)
Definition IROperator.h:260
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
constexpr IRNodeType StrongestExprNodeType
Definition Expr.h:80
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition IROperator.h:239
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have `lanes` lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition Expr.h:25
bool is_const(const Expr &e)
Is the expression either an IntImm, a FloatImm, a StringImm, or a Cast of the same,...
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
Expr min(const FuncRef &a, const FuncRef &b)
Explicit overloads of min and max for FuncRef.
Definition Func.h:584
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
@ C
No name mangling.
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
Expr abs(Expr a)
Returns the absolute value of a signed integer or floating-point expression.
Expr max(const FuncRef &a, const FuncRef &b)
Definition Func.h:587
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition Expr.h:257
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition Expr.h:321
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition Expr.h:315
The sum of two expressions.
Definition IR.h:48
Logical and - are both expressions true.
Definition IR.h:167
A base class for expression nodes.
Definition Expr.h:142
A vector with 'lanes' elements, in which every element is 'value'.
Definition IR.h:251
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition IR.h:257
A function call.
Definition IR.h:482
bool is_intrinsic() const
Definition IR.h:690
static const IRNodeType _node_type
Definition IR.h:735
The actual IR nodes begin here.
Definition IR.h:29
static const IRNodeType _node_type
Definition IR.h:34
The ratio of two expressions.
Definition IR.h:75
Is the first expression equal to the second.
Definition IR.h:113
Floating point constants.
Definition Expr.h:235
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Definition IR.h:158
Is the first expression greater than the second.
Definition IR.h:149
static constexpr bool canonical
Definition IRMatch.h:653
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:676
static constexpr uint32_t binds
Definition IRMatch.h:645
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:657
static constexpr bool foldable
Definition IRMatch.h:673
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:719
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:667
static constexpr IRNodeType max_node_type
Definition IRMatch.h:648
static constexpr IRNodeType min_node_type
Definition IRMatch.h:647
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1746
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1770
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1764
static constexpr uint32_t binds
Definition IRMatch.h:1744
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1752
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1747
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1787
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2346
static constexpr uint32_t binds
Definition IRMatch.h:2336
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2339
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2340
static constexpr bool foldable
Definition IRMatch.h:2343
static constexpr bool canonical
Definition IRMatch.h:2341
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2046
static constexpr bool foldable
Definition IRMatch.h:2068
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2050
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2045
static constexpr uint32_t binds
Definition IRMatch.h:2043
static constexpr bool canonical
Definition IRMatch.h:2047
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:2059
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2064
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:820
static constexpr IRNodeType max_node_type
Definition IRMatch.h:759
static constexpr uint32_t binds
Definition IRMatch.h:756
static constexpr bool canonical
Definition IRMatch.h:760
static constexpr bool foldable
Definition IRMatch.h:783
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:767
static constexpr IRNodeType min_node_type
Definition IRMatch.h:758
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:786
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:777
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2154
static constexpr uint32_t binds
Definition IRMatch.h:2151
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2153
static constexpr bool canonical
Definition IRMatch.h:2155
static constexpr bool foldable
Definition IRMatch.h:2181
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:2158
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2184
static constexpr IRNodeType max_node_type
Definition IRMatch.h:507
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:516
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition IRMatch.h:511
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition IRMatch.h:539
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:551
static constexpr IRNodeType min_node_type
Definition IRMatch.h:506
static constexpr bool canonical
Definition IRMatch.h:508
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:544
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition IRMatch.h:534
static constexpr uint32_t binds
Definition IRMatch.h:504
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1382
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1369
static constexpr bool canonical
Definition IRMatch.h:1370
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1417
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition IRMatch.h:1412
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1477
static constexpr uint32_t binds
Definition IRMatch.h:1366
static constexpr bool foldable
Definition IRMatch.h:1475
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1375
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition IRMatch.h:1399
std::tuple< Args... > args
Definition IRMatch.h:1360
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1387
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition IRMatch.h:1408
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
Definition IRMatch.h:1520
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1368
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2284
static constexpr bool canonical
Definition IRMatch.h:2286
static constexpr bool foldable
Definition IRMatch.h:2292
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2285
static constexpr uint32_t binds
Definition IRMatch.h:2281
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2295
static constexpr bool foldable
Definition IRMatch.h:2380
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2383
static constexpr bool canonical
Definition IRMatch.h:2378
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2376
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2377
static constexpr uint32_t binds
Definition IRMatch.h:2373
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2415
static constexpr bool foldable
Definition IRMatch.h:2418
static constexpr uint32_t binds
Definition IRMatch.h:2411
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2421
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2414
static constexpr bool canonical
Definition IRMatch.h:2416
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2540
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2541
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2547
static constexpr uint32_t binds
Definition IRMatch.h:2537
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2581
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2588
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2582
static constexpr uint32_t binds
Definition IRMatch.h:2578
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2504
static constexpr uint32_t binds
Definition IRMatch.h:2500
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2510
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2503
static constexpr bool foldable
Definition IRMatch.h:2507
static constexpr bool canonical
Definition IRMatch.h:2505
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2466
static constexpr bool foldable
Definition IRMatch.h:2463
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2459
static constexpr bool canonical
Definition IRMatch.h:2461
static constexpr uint32_t binds
Definition IRMatch.h:2456
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2460
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2625
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2631
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2624
static constexpr bool foldable
Definition IRMatch.h:2628
static constexpr uint32_t binds
Definition IRMatch.h:2621
static constexpr bool canonical
Definition IRMatch.h:2626
To save stack space, the matcher objects are largely stateless and immutable.
Definition IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition IRMatch.h:134
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition IRMatch.h:98
halide_scalar_value_t bound_const[max_wild]
Definition IRMatch.h:84
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition IRMatch.h:87
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1968
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1983
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition IRMatch.h:1978
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1992
static constexpr uint32_t binds
Definition IRMatch.h:1960
static constexpr bool canonical
Definition IRMatch.h:1965
static constexpr bool foldable
Definition IRMatch.h:1989
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1963
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1962
static constexpr uint32_t binds
Definition IRMatch.h:1619
static constexpr bool foldable
Definition IRMatch.h:1644
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1626
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1622
static constexpr bool canonical
Definition IRMatch.h:1623
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1635
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1640
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1647
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1621
static constexpr uint32_t binds
Definition IRMatch.h:2241
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2245
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2249
static constexpr bool canonical
Definition IRMatch.h:2246
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2258
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2266
static constexpr bool foldable
Definition IRMatch.h:2263
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2244
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2217
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2210
static constexpr uint32_t binds
Definition IRMatch.h:2206
static constexpr bool canonical
Definition IRMatch.h:2212
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2211
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1846
static constexpr bool canonical
Definition IRMatch.h:1821
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1819
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1818
static constexpr uint32_t binds
Definition IRMatch.h:1816
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1839
static constexpr bool foldable
Definition IRMatch.h:1858
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1824
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition IRMatch.h:2830
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition IRMatch.h:2904
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition IRMatch.h:2879
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition IRMatch.h:2825
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition IRMatch.h:2933
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition IRMatch.h:2861
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition IRMatch.h:2956
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition IRMatch.h:2838
static constexpr uint32_t binds
Definition IRMatch.h:1679
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1711
static constexpr bool foldable
Definition IRMatch.h:1708
static constexpr bool canonical
Definition IRMatch.h:1684
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition IRMatch.h:1697
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1687
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1704
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1682
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1681
static constexpr bool canonical
Definition IRMatch.h:2095
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2094
static constexpr bool foldable
Definition IRMatch.h:2123
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
Definition IRMatch.h:2126
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2093
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2098
static constexpr uint32_t binds
Definition IRMatch.h:2091
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2111
static constexpr IRNodeType min_node_type
Definition IRMatch.h:210
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:217
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:222
static constexpr IRNodeType max_node_type
Definition IRMatch.h:211
static constexpr uint32_t binds
Definition IRMatch.h:207
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1901
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1883
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1888
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1884
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1908
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:364
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:385
static constexpr IRNodeType max_node_type
Definition IRMatch.h:360
static constexpr IRNodeType min_node_type
Definition IRMatch.h:359
static constexpr uint32_t binds
Definition IRMatch.h:357
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:395
static constexpr bool canonical
Definition IRMatch.h:415
static constexpr IRNodeType max_node_type
Definition IRMatch.h:414
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:443
static constexpr uint32_t binds
Definition IRMatch.h:411
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:418
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:453
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition IRMatch.h:437
static constexpr IRNodeType min_node_type
Definition IRMatch.h:413
static constexpr bool foldable
Definition IRMatch.h:450
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:279
static constexpr uint32_t binds
Definition IRMatch.h:238
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:289
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:245
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition IRMatch.h:266
static constexpr IRNodeType min_node_type
Definition IRMatch.h:240
static constexpr IRNodeType max_node_type
Definition IRMatch.h:241
static constexpr uint32_t binds
Definition IRMatch.h:304
static constexpr IRNodeType max_node_type
Definition IRMatch.h:307
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:311
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:342
static constexpr IRNodeType min_node_type
Definition IRMatch.h:306
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:332
static constexpr IRNodeType min_node_type
Definition IRMatch.h:471
static constexpr uint32_t binds
Definition IRMatch.h:469
static constexpr IRNodeType max_node_type
Definition IRMatch.h:472
static constexpr bool canonical
Definition IRMatch.h:473
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:485
static constexpr bool foldable
Definition IRMatch.h:489
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:476
static constexpr uint32_t mask
Definition IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition Expr.h:112
Integer constants.
Definition Expr.h:217
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Definition IR.h:140
Is the first expression less than the second.
Definition IR.h:131
The greater of two values.
Definition IR.h:104
The lesser of two values.
Definition IR.h:95
The remainder of a / b.
Definition IR.h:86
The product of two expressions.
Definition IR.h:66
Is the first expression not equal to the second.
Definition IR.h:122
Logical not - true if the expression is false.
Definition IR.h:185
static Expr make(Expr a)
Logical or - is at least one of the expressions true.
Definition IR.h:176
A linear ramp vector node.
Definition IR.h:239
static const IRNodeType _node_type
Definition IR.h:245
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition IR.h:196
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition IR.h:201
Construct a new vector by taking elements from another sequence of vectors.
Definition IR.h:819
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
std::vector< Expr > vectors
Definition IR.h:820
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition IR.h:874
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition IR.h:871
The difference of two expressions.
Definition IR.h:57
static const IRNodeType _node_type
Definition IR.h:62
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition Expr.h:226
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition IR.h:929
static const IRNodeType _node_type
Definition IR.h:948
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Definition Type.h:276
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition Type.h:424
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition Type.h:344
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition Type.h:430
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition Type.h:338
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition Type.h:406
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition Type.h:412
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@4 u
A runtime tag for a type in the halide type system.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.