Halide  12.0.1
Halide compiler and libraries
IRMatch.h
Go to the documentation of this file.
1 #ifndef HALIDE_IR_MATCH_H
2 #define HALIDE_IR_MATCH_H
3 
4 /** \file
5  * Defines a method to match a fragment of IR against a pattern containing wildcards
6  */
7 
8 #include <map>
9 #include <random>
10 #include <set>
11 #include <vector>
12 
13 #include "IR.h"
14 #include "IREquality.h"
15 #include "IROperator.h"
16 
17 namespace Halide {
18 namespace Internal {
19 
20 /** Does the first expression have the same structure as the second?
21  * Variables in the first expression with the name * are interpreted
22  * as wildcards, and their matching equivalent in the second
23  * expression is placed in the vector give as the third argument.
24  * Wildcards require the types to match. For the type bits and width,
25  * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26  * integer vectors of any width (including scalars), and a UInt(0, 0)
27  * will match any unsigned integer type.
28  *
29  * For example:
30  \code
31  Expr x = Variable::make(Int(32), "*");
32  match(x + x, 3 + (2*k), result)
33  \endcode
34  * should return true, and set result[0] to 3 and
35  * result[1] to 2*k.
36  */
37 bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38 
39 /** Does the first expression have the same structure as the second?
40  * Variables are matched consistently. The first time a variable is
41  * matched, it assumes the value of the matching part of the second
42  * expression. Subsequent matches must be equal to the first match.
43  *
44  * For example:
45  \code
46  Var x("x"), y("y");
47  match(x*(x + y), a*(a + b), result)
48  \endcode
49  * should return true, and set result["x"] = a, and result["y"] = b.
50  */
51 bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52 
53 /** Rewrite the expression x to have `lanes` lanes. This is useful
54  * for substituting the results of expr_match into a pattern expression. */
55 Expr with_lanes(const Expr &x, int lanes);
56 
58 
59 /** An alternative template-metaprogramming approach to expression
60  * matching. Potentially more efficient. We lift the expression
61  * pattern into a type, and then use force-inlined functions to
62  * generate efficient matching and reconstruction code for any
63  * pattern. Pattern elements are either one of the classes in the
64  * namespace IRMatcher, or are non-null Exprs (represented as
65  * BaseExprNode &).
66  *
67  * Pattern elements that are fully specified by their pattern can be
68  * built into an expression using the make method. Some patterns,
69  * such as a broadcast that matches any number of lanes, don't have
70  * enough information to recreate an Expr.
71  */
72 namespace IRMatcher {
73 
// Exclusive upper bound on wildcard indices: the WildConst* matchers
// static_assert that their index i satisfies 0 <= i < max_wild.
74 constexpr int max_wild = 6;
75 
// Type recorded for a constant slot that is bound from a raw int64_t
// (WildConstInt::match(int64_t, ...)) rather than from an IR immediate.
76 static const halide_type_t i64_type = {halide_type_int, 64, 1};
77 
78 /** To save stack space, the matcher objects are largely stateless and
79  * immutable. This state object is built up during matching and then
80  * consumed when constructing a replacement Expr.
81  */
82 struct MatcherState {
    // NOTE(review): the storage used below (bindings, bound_const,
    // bound_const_type) is not declared anywhere in this listing; the
    // declaring lines (original 83-84) appear to have been lost — confirm
    // against the original header.
85 
86  // values of the lanes field with special meaning.
    // A folded constant whose type has this bit set in `lanes` is not a real
    // value; make_const_expr hands such types to make_const_special_expr,
    // which builds a signed-integer-overflow Expr.
87  static constexpr uint16_t signed_integer_overflow = 0x8000;
88  static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89 
91 
    // Record that Expr slot i matched node n. Only the address of n is
    // stored, so n must stay alive for as long as the binding is used.
93  void set_binding(int i, const BaseExprNode &n) noexcept {
94  bindings[i] = &n;
95  }
96 
    // The node previously recorded for Expr slot i by set_binding.
98  const BaseExprNode *get_binding(int i) const noexcept {
99  return bindings[i];
100  }
101 
    // Record constant slot i as a signed 64-bit payload with type t.
103  void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104  bound_const[i].u.i64 = s;
105  bound_const_type[i] = t;
106  }
107 
    // Record constant slot i as an unsigned 64-bit payload with type t.
109  void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110  bound_const[i].u.u64 = u;
111  bound_const_type[i] = t;
112  }
113 
    // Record constant slot i as a double payload with type t.
115  void set_bound_const(int i, double f, halide_type_t t) noexcept {
116  bound_const[i].u.f64 = f;
117  bound_const_type[i] = t;
118  }
119 
    // Record constant slot i from an already-assembled scalar value.
121  void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept {
122  bound_const[i] = val;
123  bound_const_type[i] = t;
124  }
125 
    // Read back the payload and type previously recorded for constant slot i.
127  void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128  val = bound_const[i];
129  type = bound_const_type[i];
130  }
131 
133  // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
134  MatcherState() noexcept {
135  }
136 };
137 
// SFINAE helper: substituting the defaulted second template parameter fails
// unless T (with references stripped) declares a nested pattern_tag, so this
// template is only usable with pattern types.
// NOTE(review): the line naming this struct (original line 140) is missing
// from this listing, which leaves the nested `struct type {};` and the
// closing `};` without a visible opening — confirm against the original.
138 template<typename T,
139  typename = typename std::remove_reference<T>::type::pattern_tag>
141  struct type {};
142 };
143 
// Compile-time bitmask of the wildcard slots bound by pattern type T, read
// from T's static `binds` member after stripping any reference from T.
template<typename T>
struct bindings {
    static constexpr uint32_t mask =
        std::remove_reference<T>::type::binds;
};
148 
// NOTE(review): the opening lines of this function (original lines 149-150
// and 152: its signature and the condition that guards the overflow return)
// are missing from this listing; presumably this is make_const_special_expr,
// which the function below calls — confirm. The visible body clears the
// special-value bit in ty.lanes and builds a signed-integer-overflow Expr.
151  ty.lanes &= ~MatcherState::special_values_mask;
153  return make_signed_integer_overflow(ty);
154  }
155  // unreachable
156  return Expr();
157 }
158 
// Build an Expr for a folded constant with payload `val` and type `ty`.
// NOTE(review): the signature (original lines 159-160) is missing from this
// listing; presumably this is make_const_expr, which the WildConst* make()
// methods call with (val, type) — confirm.
161  halide_type_t scalar_type = ty;
    // Lane counts carrying the special-value bit do not describe a real
    // constant; delegate them to the special-value builder.
162  if (scalar_type.lanes & MatcherState::special_values_mask) {
163  return make_const_special_expr(scalar_type);
164  }
165 
    // Build the scalar immediate first, then broadcast it if ty is a vector.
166  const int lanes = scalar_type.lanes;
167  scalar_type.lanes = 1;
168 
169  Expr e;
170  switch (scalar_type.code) {
171  case halide_type_int:
172  e = IntImm::make(scalar_type, val.u.i64);
173  break;
174  case halide_type_uint:
175  e = UIntImm::make(scalar_type, val.u.u64);
176  break;
177  case halide_type_float:
178  case halide_type_bfloat:
    // bfloat shares the f64 payload with float.
179  e = FloatImm::make(scalar_type, val.u.f64);
180  break;
181  default:
182  // Unreachable
183  return Expr();
184  }
185  if (lanes > 1) {
186  e = Broadcast::make(e, lanes);
187  }
188  return e;
189 }
190 
191 bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept;
192 
193 // A fast version of expression equality that assumes a well-typed non-null expression tree.
195 bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept {
196  // Early out
197  return (&a == &b) ||
198  ((a.type == b.type) &&
199  (a.node_type == b.node_type) &&
200  equal_helper(a, b));
201 }
202 
203 // A pattern that matches a specific expression
// NOTE(review): the member holding that expression (`expr`, used below as a
// const BaseExprNode) and the min/max_node_type declarations introduced by
// the comment at original line 209 (lines 210-211, 214, 221) are missing
// from this listing — confirm against the original header.
204 struct SpecificExpr {
205  struct pattern_tag {};
206 
    // Binds no wildcard slots.
207  constexpr static uint32_t binds = 0;
208 
209  // What is the weakest and strongest IR node this could possibly be
212  constexpr static bool canonical = true;
213 
215 
    // Succeeds iff e equals the wrapped node (see equal()); state is untouched.
216  template<uint32_t bound>
217  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
218  return equal(expr, e);
219  }
220 
    // Wraps the referenced node back into an Expr; state and type_hint are unused.
222  Expr make(MatcherState &state, halide_type_t type_hint) const {
223  return Expr(&expr);
224  }
225 
    // An arbitrary expression cannot be constant-folded.
226  constexpr static bool foldable = false;
227 };
228 
229 inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
230  s << Expr(&e.expr);
231  return s;
232 }
233 
// Matches an integer immediate (IntImm), or a Broadcast of one, and binds its
// value and type to constant slot i (bit i of the binds mask).
// NOTE(review): several lines of this struct are missing from this listing
// (original 240-241, 256, 269, 278, 280, 288-289), including the declarations
// of the local `val` used below and the signature of make_folded_const —
// confirm against the original header.
234 template<int i>
235 struct WildConstInt {
236  struct pattern_tag {};
237 
238  constexpr static uint32_t binds = 1 << i;
239 
242  constexpr static bool canonical = true;
243 
    // Match against an IR node. The first occurrence of slot i records the
    // value; later occurrences (slot bit already set in `bound`) must agree on
    // both value and type. The type recorded/compared is e.type, i.e. the full
    // vector type when e is a Broadcast, not the immediate's scalar type.
244  template<uint32_t bound>
245  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
246  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
247  const BaseExprNode *op = &e;
248  if (op->node_type == IRNodeType::Broadcast) {
249  op = ((const Broadcast *)op)->value.get();
250  }
251  if (op->node_type != IRNodeType::IntImm) {
252  return false;
253  }
254  int64_t value = ((const IntImm *)op)->value;
255  if (bound & binds) {
257  halide_type_t type;
258  state.get_bound_const(i, val, type);
259  return (halide_type_t)e.type == type && value == val.u.i64;
260  }
261  state.set_bound_const(i, value, e.type);
262  return true;
263  }
264 
    // Match against a raw int64_t value; the slot is typed as i64_type.
265  template<uint32_t bound>
266  HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
267  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
268  if (bound & binds) {
270  halide_type_t type;
271  state.get_bound_const(i, val, type);
272  return type == i64_type && value == val.u.i64;
273  }
274  state.set_bound_const(i, value, i64_type);
275  return true;
276  }
277 
    // Rebuild the bound constant with its recorded type; type_hint is unused.
279  Expr make(MatcherState &state, halide_type_t type_hint) const {
281  halide_type_t type;
282  state.get_bound_const(i, val, type);
283  return make_const_expr(val, type);
284  }
285 
286  constexpr static bool foldable = true;
287 
    // (make_folded_const; signature line missing) Folding reads back the
    // bound payload and type unchanged.
290  state.get_bound_const(i, val, ty);
291  }
292 };
293 
294 template<int i>
295 std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
296  s << "ci" << i;
297  return s;
298 }
299 
// Matches an unsigned immediate (UIntImm), or a Broadcast of one, and binds
// its value and type to constant slot i.
// NOTE(review): the struct header line (original 301, presumably
// `struct WildConstUInt {`) and lines 306-307, 322, 331, 333, 341 are missing
// from this listing — confirm against the original header.
300 template<int i>
302  struct pattern_tag {};
303 
304  constexpr static uint32_t binds = 1 << i;
305 
308  constexpr static bool canonical = true;
309 
    // First occurrence records value and e.type; later occurrences must agree
    // on both. e.type is the vector type when e is a Broadcast.
310  template<uint32_t bound>
311  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
312  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
313  const BaseExprNode *op = &e;
314  if (op->node_type == IRNodeType::Broadcast) {
315  op = ((const Broadcast *)op)->value.get();
316  }
317  if (op->node_type != IRNodeType::UIntImm) {
318  return false;
319  }
320  uint64_t value = ((const UIntImm *)op)->value;
321  if (bound & binds) {
323  halide_type_t type;
324  state.get_bound_const(i, val, type);
325  return (halide_type_t)e.type == type && value == val.u.u64;
326  }
327  state.set_bound_const(i, value, e.type);
328  return true;
329  }
330 
    // Rebuild the bound constant with its recorded type; type_hint is unused.
332  Expr make(MatcherState &state, halide_type_t type_hint) const {
334  halide_type_t type;
335  state.get_bound_const(i, val, type);
336  return make_const_expr(val, type);
337  }
338 
339  constexpr static bool foldable = true;
340 
    // Folding reads back the bound payload and type unchanged.
342  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
343  state.get_bound_const(i, val, ty);
344  }
345 };
346 
347 template<int i>
348 std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
349  s << "cu" << i;
350  return s;
351 }
352 
// Matches a floating-point immediate (FloatImm), or a Broadcast of one, and
// binds its value and type to constant slot i.
// NOTE(review): the struct header line (original 354, presumably
// `struct WildConstFloat {`) and lines 359-360, 375, 384, 386, 394 are missing
// from this listing — confirm against the original header.
353 template<int i>
355  struct pattern_tag {};
356 
357  constexpr static uint32_t binds = 1 << i;
358 
361  constexpr static bool canonical = true;
362 
    // First occurrence records value and e.type; later occurrences must agree
    // on both. The comparison uses ==, so a slot bound to NaN never matches a
    // second NaN.
363  template<uint32_t bound>
364  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
365  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
366  const BaseExprNode *op = &e;
367  if (op->node_type == IRNodeType::Broadcast) {
368  op = ((const Broadcast *)op)->value.get();
369  }
370  if (op->node_type != IRNodeType::FloatImm) {
371  return false;
372  }
373  double value = ((const FloatImm *)op)->value;
374  if (bound & binds) {
376  halide_type_t type;
377  state.get_bound_const(i, val, type);
378  return (halide_type_t)e.type == type && value == val.u.f64;
379  }
380  state.set_bound_const(i, value, e.type);
381  return true;
382  }
383 
    // Rebuild the bound constant with its recorded type; type_hint is unused.
385  Expr make(MatcherState &state, halide_type_t type_hint) const {
387  halide_type_t type;
388  state.get_bound_const(i, val, type);
389  return make_const_expr(val, type);
390  }
391 
392  constexpr static bool foldable = true;
393 
    // Folding reads back the bound payload and type unchanged.
395  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
396  state.get_bound_const(i, val, ty);
397  }
398 };
399 
400 template<int i>
401 std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
402  s << "cf" << i;
403  return s;
404 }
405 
// Matches and binds to any constant Expr: an IntImm, UIntImm or FloatImm,
// possibly wrapped in a Broadcast. It is foldable: make_folded_const reads
// back the bound payload and type.
// NOTE(review): lines 413-414, 429 (presumably `case IRNodeType::FloatImm:`),
// 442, 444 and 452 are missing from this listing — confirm against the
// original header.
407 template<int i>
408 struct WildConst {
409  struct pattern_tag {};
410 
411  constexpr static uint32_t binds = 1 << i;
412 
415  constexpr static bool canonical = true;
416 
    // Peeks through a Broadcast to pick the right typed matcher, but passes
    // the original e on, so the delegate redoes the unwrap and records e's
    // (possibly vector) type.
417  template<uint32_t bound>
418  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
419  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
420  const BaseExprNode *op = &e;
421  if (op->node_type == IRNodeType::Broadcast) {
422  op = ((const Broadcast *)op)->value.get();
423  }
424  switch (op->node_type) {
425  case IRNodeType::IntImm:
426  return WildConstInt<i>().template match<bound>(e, state);
427  case IRNodeType::UIntImm:
428  return WildConstUInt<i>().template match<bound>(e, state);
430  return WildConstFloat<i>().template match<bound>(e, state);
431  default:
432  return false;
433  }
434  }
435 
    // Raw int64_t values are treated as integer constants.
436  template<uint32_t bound>
437  HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
438  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
439  return WildConstInt<i>().template match<bound>(e, state);
440  }
441 
    // Rebuild the bound constant with its recorded type; type_hint is unused.
443  Expr make(MatcherState &state, halide_type_t type_hint) const {
445  halide_type_t type;
446  state.get_bound_const(i, val, type);
447  return make_const_expr(val, type);
448  }
449 
450  constexpr static bool foldable = true;
451 
453  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
454  state.get_bound_const(i, val, ty);
455  }
456 };
457 
458 template<int i>
459 std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
460  s << "c" << i;
461  return s;
462 }
463 
464 // Matches and binds to any Expr
// NOTE(review): lines 471-472, 484, 490 and 501 (presumably
// `case IRNodeType::FloatImm:`) are missing from this listing — confirm.
465 template<int i>
466 struct Wild {
467  struct pattern_tag {};
468 
    // Expr slots use bits 16 and up of the binds mask, keeping them separate
    // from the WildConst* slots, which use bits 0 and up.
469  constexpr static uint32_t binds = 1 << (i + 16);
470 
473  constexpr static bool canonical = true;
474 
    // First occurrence records a pointer to e; later occurrences (slot bit
    // already in `bound`) must be equal() to the recorded node. Unlike the
    // WildConst* matchers, there is no static_assert on the range of i.
475  template<uint32_t bound>
476  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
477  if (bound & binds) {
478  return equal(*state.get_binding(i), e);
479  }
480  state.set_binding(i, e);
481  return true;
482  }
483 
    // Returns the recorded node as an Expr; type_hint is unused.
485  Expr make(MatcherState &state, halide_type_t type_hint) const {
486  return state.get_binding(i);
487  }
488 
489  constexpr static bool foldable = true;
    // Only meaningful when the bound node is an immediate; see the default case.
491  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
492  const auto *e = state.get_binding(i);
493  ty = e->type;
494  switch (e->node_type) {
495  case IRNodeType::UIntImm:
496  val.u.u64 = ((const UIntImm *)e)->value;
497  return;
498  case IRNodeType::IntImm:
499  val.u.i64 = ((const IntImm *)e)->value;
500  return;
502  val.u.f64 = ((const FloatImm *)e)->value;
503  return;
504  default:
505  // The function is noexcept, so silent failure. You
506  // shouldn't be calling this if you haven't already
507  // checked it's going to be a constant (e.g. with
508  // is_const, or because you manually bound a constant Expr
509  // to the state).
510  val.u.u64 = 0;
511  }
512  }
513 };
514 
515 template<int i>
516 std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
517  s << "_" << i;
518  return s;
519 }
520 
521 // Matches a specific constant or broadcast of that constant. The
522 // constant must be representable as an int64_t.
// NOTE(review): lines 525 and 533 (presumably the declaration of the member
// `v`), 529-530, 549 (presumably `case IRNodeType::FloatImm:`), 566 and 573
// are missing from this listing — confirm against the original header.
523 struct IntLiteral {
524  struct pattern_tag {};
526 
    // Binds no wildcard slots.
527  constexpr static uint32_t binds = 0;
528 
531  constexpr static bool canonical = true;
532 
534  explicit IntLiteral(int64_t v)
535  : v(v) {
536  }
537 
    // Compares v against an immediate (or a Broadcast of one) after converting
    // v to that immediate's payload kind. For UIntImm the conversion is to
    // uint64_t, so a negative v compares against its two's-complement bits.
538  template<uint32_t bound>
539  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
540  const BaseExprNode *op = &e;
541  if (e.node_type == IRNodeType::Broadcast) {
542  op = ((const Broadcast *)op)->value.get();
543  }
544  switch (op->node_type) {
545  case IRNodeType::IntImm:
546  return ((const IntImm *)op)->value == (int64_t)v;
547  case IRNodeType::UIntImm:
548  return ((const UIntImm *)op)->value == (uint64_t)v;
550  return ((const FloatImm *)op)->value == (double)v;
551  default:
552  return false;
553  }
554  }
555 
    // Comparison against a raw integer: plain equality.
556  template<uint32_t bound>
557  HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
558  return v == val;
559  }
560 
    // Comparison against another literal pattern: plain equality.
561  template<uint32_t bound>
562  HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
563  return v == b.v;
564  }
565 
    // A literal has no type of its own; type_hint supplies it.
567  Expr make(MatcherState &state, halide_type_t type_hint) const {
568  return make_const(type_hint, v);
569  }
570 
571  constexpr static bool foldable = true;
572 
    // Reads ty but never sets it: the caller must have set ty already (BinOp
    // and CmpOp fold the other operand first when A is an IntLiteral).
574  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
575  // Assume type is already correct
576  switch (ty.code) {
577  case halide_type_int:
578  val.u.i64 = v;
579  break;
580  case halide_type_uint:
581  val.u.u64 = (uint64_t)v;
582  break;
583  case halide_type_float:
584  case halide_type_bfloat:
585  val.u.f64 = (double)v;
586  break;
587  default:
588  // Unreachable
589  ;
590  }
591  }
592 };
593 
// NOTE(review): the signature line of this function (original 594) is missing
// from this listing; the visible body returns an IntLiteral's stored value.
595  return t.v;
596 }
597 
598 // Convert a provided pattern, expr, or constant int into the internal
599 // representation we use in the matcher trees.
// Pattern nodes (types with a nested pattern_tag) pass through unchanged.
// NOTE(review): signature line (original 602) missing from this listing.
600 template<typename T,
601  typename = typename std::decay<T>::type::pattern_tag>
603  return t;
604 }
// Integers become IntLiteral patterns.
// NOTE(review): signature lines (original 605-606) missing from this listing.
607  return IntLiteral{x};
608 }
609 
// Compile-time check that T is not an rvalue Expr: per the message, matcher
// objects capture Exprs by reference, so a temporary would leave a dangling
// capture. Presumably named assert_is_lvalue_if_expr (the name the operator
// builders below call); signature line (original 611) missing from this listing.
610 template<typename T>
612  static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
613  "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
614 }
615 
// Exprs are wrapped by reference to the node they point at (return type not
// visible here; presumably SpecificExpr). Signature line (original 616)
// missing from this listing.
617  return {*e.get()};
618 }
619 
620 // Helpers to deref SpecificExprs to const BaseExprNode & rather than
621 // passing them by value anywhere (incurring lots of refcounting)
// This overload returns every other pattern node unchanged.
// NOTE(review): signature line (original 627) missing from this listing.
622 template<typename T,
623  // T must be a pattern node
624  typename = typename std::decay<T>::type::pattern_tag,
625  // But T may not be SpecificExpr
626  typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
628  return t;
629 }
630 
632 const BaseExprNode &unwrap(const SpecificExpr &e) {
633  return e.expr;
634 }
635 
636 inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
637  s << op.v;
638  return s;
639 }
640 
// Per-operator constant-folding hooks, specialized below for each Op. They
// may update the type's lanes field (e.g. to set the overflow flag).
// NOTE(review): the int64_t and uint64_t declarations belonging to the first
// two template heads (original lines 642 and 645) are missing from this
// listing; BinOp::make_folded_const calls all three payload kinds.
641 template<typename Op>
643 
644 template<typename Op>
646 
647 template<typename Op>
648 double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
649 
650 constexpr bool commutative(IRNodeType t) {
651  return (t == IRNodeType::Add ||
652  t == IRNodeType::Mul ||
653  t == IRNodeType::And ||
654  t == IRNodeType::Or ||
655  t == IRNodeType::Min ||
656  t == IRNodeType::Max ||
657  t == IRNodeType::EQ ||
658  t == IRNodeType::NE);
659 }
660 
661 // Matches one of the binary operators
// NOTE(review): lines 668 (presumably the `binds` declaration), 698 and 741
// are missing from this listing — confirm against the original header.
662 template<typename Op, typename A, typename B>
663 struct BinOp {
664  struct pattern_tag {};
665  A a;
666  B b;
667 
669 
670  constexpr static IRNodeType min_node_type = Op::_node_type;
671  constexpr static IRNodeType max_node_type = Op::_node_type;
672 
673  // For commutative bin ops, we expect the weaker IR node type on
674  // the right. That is, for the rule to be canonical it must be
675  // possible that A is at least as strong as B.
676  constexpr static bool canonical =
677  A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
678 
    // Match against IR: the node kind must be Op, then a is matched first and
    // b is matched with every slot bound by A marked as already bound, so a
    // wildcard repeated across both sides must match equal subtrees.
679  template<uint32_t bound>
680  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
681  if (e.node_type != Op::_node_type) {
682  return false;
683  }
684  const Op &op = (const Op &)e;
685  return (a.template match<bound>(*op.a.get(), state) &&
686  b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
687  }
688 
    // Match against another BinOp pattern (same binding discipline as above).
689  template<uint32_t bound, typename Op2, typename A2, typename B2>
690  HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
691  return (std::is_same<Op, Op2>::value &&
692  a.template match<bound>(unwrap(op.a), state) &&
693  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
694  }
695 
696  constexpr static bool foldable = A::foldable && B::foldable;
697 
    // Folds both operands and combines them with constant_fold_bin_op<Op>.
    // An IntLiteral has no type, so when A is one, b is folded first to set
    // ty. For And, a zero operand, and for Or, an operand equal to 1, is
    // returned without folding the other side. ty.lanes is ORed across both
    // operands so the overflow flag survives.
699  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
700  halide_scalar_value_t val_a, val_b;
701  if (std::is_same<A, IntLiteral>::value) {
702  b.make_folded_const(val_b, ty, state);
703  if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
704  (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
705  // Short circuit
706  val = val_b;
707  return;
708  }
709  const uint16_t l = ty.lanes;
710  a.make_folded_const(val_a, ty, state);
711  ty.lanes |= l; // Make sure the overflow bits are sticky
712  } else {
713  a.make_folded_const(val_a, ty, state);
714  if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
715  (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
716  // Short circuit
717  val = val_a;
718  return;
719  }
720  const uint16_t l = ty.lanes;
721  b.make_folded_const(val_b, ty, state);
722  ty.lanes |= l;
723  }
724  switch (ty.code) {
725  case halide_type_int:
726  val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
727  break;
728  case halide_type_uint:
729  val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
730  break;
731  case halide_type_float:
732  case halide_type_bfloat:
733  val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
734  break;
735  default:
736  // unreachable
737  ;
738  }
739  }
740 
    // Builds the operands (the typed side first, so it can supply the type
    // hint for an IntLiteral side), broadcasts a scalar operand to match a
    // vector one, then builds the Op node.
    // NOTE(review): declared noexcept although Op::make/Broadcast::make
    // presumably allocate; CmpOp::make is not noexcept — confirm intent.
742  Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
743  Expr ea, eb;
744  if (std::is_same<A, IntLiteral>::value) {
745  eb = b.make(state, type_hint);
746  ea = a.make(state, eb.type());
747  } else {
748  ea = a.make(state, type_hint);
749  eb = b.make(state, ea.type());
750  }
751  // We sometimes mix vectors and scalars in the rewrite rules,
752  // so insert a broadcast if necessary.
753  if (ea.type().is_vector() && !eb.type().is_vector()) {
754  eb = Broadcast::make(eb, ea.type().lanes());
755  }
756  if (eb.type().is_vector() && !ea.type().is_vector()) {
757  ea = Broadcast::make(ea, eb.type().lanes());
758  }
759  return Op::make(std::move(ea), std::move(eb));
760  }
761 };
762 
// Per-comparison constant-folding hooks; each returns the boolean result as a
// uint64_t payload (CmpOp::make_folded_const stores it in val.u.u64).
// NOTE(review): the int64_t and uint64_t declarations belonging to the first
// two template heads (original lines 764 and 767) are missing from this listing.
763 template<typename Op>
765 
766 template<typename Op>
768 
769 template<typename Op>
770 uint64_t constant_fold_cmp_op(double, double) noexcept;
771 
772 // Matches one of the comparison operators
// NOTE(review): lines 779 (presumably the `binds` declaration), 808 and 842
// are missing from this listing — confirm against the original header.
773 template<typename Op, typename A, typename B>
774 struct CmpOp {
775  struct pattern_tag {};
776  A a;
777  B b;
778 
780 
781  constexpr static IRNodeType min_node_type = Op::_node_type;
782  constexpr static IRNodeType max_node_type = Op::_node_type;
    // GE and GT patterns are never canonical; canonical rules use LE/LT.
783  constexpr static bool canonical = (A::canonical &&
784  B::canonical &&
785  (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
786  (Op::_node_type != IRNodeType::GE) &&
787  (Op::_node_type != IRNodeType::GT));
788 
    // Same binding discipline as BinOp: slots bound by A count as bound for B.
789  template<uint32_t bound>
790  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
791  if (e.node_type != Op::_node_type) {
792  return false;
793  }
794  const Op &op = (const Op &)e;
795  return (a.template match<bound>(*op.a.get(), state) &&
796  b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
797  }
798 
799  template<uint32_t bound, typename Op2, typename A2, typename B2>
800  HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
801  return (std::is_same<Op, Op2>::value &&
802  a.template match<bound>(unwrap(op.a), state) &&
803  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
804  }
805 
806  constexpr static bool foldable = A::foldable && B::foldable;
807 
    // Compares the folded operands using the operand type, then rewrites ty
    // to a 1-bit unsigned (boolean) type. Lanes, including the sticky
    // overflow bit ORed from both sides, are kept.
809  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
810  halide_scalar_value_t val_a, val_b;
811  // If one side is an untyped const, evaluate the other side first to get a type hint.
812  if (std::is_same<A, IntLiteral>::value) {
813  b.make_folded_const(val_b, ty, state);
814  const uint16_t l = ty.lanes;
815  a.make_folded_const(val_a, ty, state);
816  ty.lanes |= l;
817  } else {
818  a.make_folded_const(val_a, ty, state);
819  const uint16_t l = ty.lanes;
820  b.make_folded_const(val_b, ty, state);
821  ty.lanes |= l;
822  }
823  switch (ty.code) {
824  case halide_type_int:
825  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
826  break;
827  case halide_type_uint:
828  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
829  break;
830  case halide_type_float:
831  case halide_type_bfloat:
832  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
833  break;
834  default:
835  // unreachable
836  ;
837  }
838  ty.code = halide_type_uint;
839  ty.bits = 1;
840  }
841 
    // type_hint is not used: the operands take their types from each other
    // (an empty hint is passed to the first side built), not from the result.
843  Expr make(MatcherState &state, halide_type_t type_hint) const {
844  // If one side is an untyped const, evaluate the other side first to get a type hint.
845  Expr ea, eb;
846  if (std::is_same<A, IntLiteral>::value) {
847  eb = b.make(state, {});
848  ea = a.make(state, eb.type());
849  } else {
850  ea = a.make(state, {});
851  eb = b.make(state, ea.type());
852  }
853  // We sometimes mix vectors and scalars in the rewrite rules,
854  // so insert a broadcast if necessary.
855  if (ea.type().is_vector() && !eb.type().is_vector()) {
856  eb = Broadcast::make(eb, ea.type().lanes());
857  }
858  if (eb.type().is_vector() && !ea.type().is_vector()) {
859  ea = Broadcast::make(ea, eb.type().lanes());
860  }
861  return Op::make(std::move(ea), std::move(eb));
862  }
863 };
864 
865 template<typename A, typename B>
866 std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
867  s << "(" << op.a << " + " << op.b << ")";
868  return s;
869 }
870 
871 template<typename A, typename B>
872 std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
873  s << "(" << op.a << " - " << op.b << ")";
874  return s;
875 }
876 
877 template<typename A, typename B>
878 std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
879  s << "(" << op.a << " * " << op.b << ")";
880  return s;
881 }
882 
883 template<typename A, typename B>
884 std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
885  s << "(" << op.a << " / " << op.b << ")";
886  return s;
887 }
888 
889 template<typename A, typename B>
890 std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
891  s << "(" << op.a << " && " << op.b << ")";
892  return s;
893 }
894 
895 template<typename A, typename B>
896 std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
897  s << "(" << op.a << " || " << op.b << ")";
898  return s;
899 }
900 
901 template<typename A, typename B>
902 std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
903  s << "min(" << op.a << ", " << op.b << ")";
904  return s;
905 }
906 
907 template<typename A, typename B>
908 std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
909  s << "max(" << op.a << ", " << op.b << ")";
910  return s;
911 }
912 
913 template<typename A, typename B>
914 std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
915  s << "(" << op.a << " <= " << op.b << ")";
916  return s;
917 }
918 
919 template<typename A, typename B>
920 std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
921  s << "(" << op.a << " < " << op.b << ")";
922  return s;
923 }
924 
925 template<typename A, typename B>
926 std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
927  s << "(" << op.a << " >= " << op.b << ")";
928  return s;
929 }
930 
931 template<typename A, typename B>
932 std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
933  s << "(" << op.a << " > " << op.b << ")";
934  return s;
935 }
936 
937 template<typename A, typename B>
938 std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
939  s << "(" << op.a << " == " << op.b << ")";
940  return s;
941 }
942 
943 template<typename A, typename B>
944 std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
945  s << "(" << op.a << " != " << op.b << ")";
946  return s;
947 }
948 
949 template<typename A, typename B>
950 std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
951  s << "(" << op.a << " % " << op.b << ")";
952  return s;
953 }
954 
955 template<typename A, typename B>
956 HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
957  assert_is_lvalue_if_expr<A>();
958  assert_is_lvalue_if_expr<B>();
959  return {pattern_arg(a), pattern_arg(b)};
960 }
961 
962 template<typename A, typename B>
963 HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
964  assert_is_lvalue_if_expr<A>();
965  assert_is_lvalue_if_expr<B>();
966  return IRMatcher::operator+(a, b);
967 }
968 
// Add folding. Signed: the sum is computed in uint64_t (avoiding signed
// overflow in the addition itself), truncated to t.bits and sign-extended
// back; for types of 32 bits or more an overflow also sets the
// signed_integer_overflow flag in t.lanes. Unsigned: the sum wraps modulo
// 2^t.bits. Float: plain addition.
// NOTE(review): the signature lines of the int64_t and uint64_t
// specializations (original 970 and 978) are missing from this listing.
969 template<>
971  t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
972  int dead_bits = 64 - t.bits;
973  // Drop the high bits then sign-extend them back
974  return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
975 }
976 
977 template<>
979  uint64_t ones = (uint64_t)(-1);
980  return (a + b) & (ones >> (64 - t.bits));
981 }
982 
983 template<>
984 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
985  return a + b;
986 }
987 
988 template<typename A, typename B>
989 HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
990  assert_is_lvalue_if_expr<A>();
991  assert_is_lvalue_if_expr<B>();
992  return {pattern_arg(a), pattern_arg(b)};
993 }
994 
995 template<typename A, typename B>
996 HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
997  assert_is_lvalue_if_expr<A>();
998  assert_is_lvalue_if_expr<B>();
999  return IRMatcher::operator-(a, b);
1000 }
1001 
// Sub folding, mirroring Add: the signed difference is computed in uint64_t,
// truncated and sign-extended, with the overflow flag set for 32+ bit types;
// unsigned wraps modulo 2^t.bits; float is plain subtraction.
// NOTE(review): the signature lines of the int64_t and uint64_t
// specializations (original 1003 and 1011) are missing from this listing.
1002 template<>
1004  t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
1005  // Drop the high bits then sign-extend them back
1006  int dead_bits = 64 - t.bits;
1007  return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
1008 }
1009 
1010 template<>
1012  uint64_t ones = (uint64_t)(-1);
1013  return (a - b) & (ones >> (64 - t.bits));
1014 }
1015 
1016 template<>
1017 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
1018  return a - b;
1019 }
1020 
1021 template<typename A, typename B>
1022 HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1023  assert_is_lvalue_if_expr<A>();
1024  assert_is_lvalue_if_expr<B>();
1025  return {pattern_arg(a), pattern_arg(b)};
1026 }
1027 
1028 template<typename A, typename B>
1029 HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
1030  assert_is_lvalue_if_expr<A>();
1031  assert_is_lvalue_if_expr<B>();
1032  return IRMatcher::operator*(a, b);
1033 }
1034 
// Mul folding, mirroring Add: the signed product is computed in uint64_t,
// truncated and sign-extended, with the overflow flag set for 32+ bit types;
// unsigned wraps modulo 2^t.bits; float is plain multiplication.
// NOTE(review): the signature lines of the int64_t and uint64_t
// specializations (original 1036 and 1044) are missing from this listing.
1035 template<>
1037  t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
1038  int dead_bits = 64 - t.bits;
1039  // Drop the high bits then sign-extend them back
1040  return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
1041 }
1042 
1043 template<>
1045  uint64_t ones = (uint64_t)(-1);
1046  return (a * b) & (ones >> (64 - t.bits));
1047 }
1048 
1049 template<>
1050 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1051  return a * b;
1052 }
1053 
1054 template<typename A, typename B>
1055 HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1056  assert_is_lvalue_if_expr<A>();
1057  assert_is_lvalue_if_expr<B>();
1058  return {pattern_arg(a), pattern_arg(b)};
1059 }
1060 
1061 template<typename A, typename B>
1062 HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1063  return IRMatcher::operator/(a, b);
1064 }
1065 
// Div folding delegates to div_imp for each payload kind. Unlike Add/Sub/Mul,
// the visible bodies neither truncate to t.bits nor set the overflow flag.
// NOTE(review): the signature lines of the int64_t and uint64_t
// specializations (original 1067 and 1072) are missing from this listing.
1066 template<>
1068  return div_imp(a, b);
1069 }
1070 
1071 template<>
1073  return div_imp(a, b);
1074 }
1075 
1076 template<>
1077 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1078  return div_imp(a, b);
1079 }
1080 
1081 template<typename A, typename B>
1082 HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1083  assert_is_lvalue_if_expr<A>();
1084  assert_is_lvalue_if_expr<B>();
1085  return {pattern_arg(a), pattern_arg(b)};
1086 }
1087 
1088 template<typename A, typename B>
1089 HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1090  assert_is_lvalue_if_expr<A>();
1091  assert_is_lvalue_if_expr<B>();
1092  return IRMatcher::operator%(a, b);
1093 }
1094 
// Constant-fold specializations for Mod. The first two signature lines are
// missing from this listing (presumably the int64_t and uint64_t overloads —
// confirm). All three delegate to mod_imp, which is defined elsewhere, so
// sign and zero-divisor behaviour are not visible here.
1095 template<>
1097  return mod_imp(a, b);
1098 }
1099 
1100 template<>
1102  return mod_imp(a, b);
1103 }
1104 
// Floating-point Mod fold; `t` is unused.
1105 template<>
1106 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1107  return mod_imp(a, b);
1108 }
1109 
1110 template<typename A, typename B>
1111 HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1112  assert_is_lvalue_if_expr<A>();
1113  assert_is_lvalue_if_expr<B>();
1114  return {pattern_arg(a), pattern_arg(b)};
1115 }
1116 
// Constant-fold specializations for Min. The first two signature lines are
// missing from this listing (presumably the int64_t and uint64_t overloads —
// confirm). Each returns std::min of the operands; no overflow flag is involved.
1117 template<>
1119  return std::min(a, b);
1120 }
1121 
1122 template<>
1124  return std::min(a, b);
1125 }
1126 
// Floating-point Min fold; `t` is unused.
1127 template<>
1128 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1129  return std::min(a, b);
1130 }
1131 
1132 template<typename A, typename B>
1133 HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1134  assert_is_lvalue_if_expr<A>();
1135  assert_is_lvalue_if_expr<B>();
1136  return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1137 }
1138 
// Constant-fold specializations for Max. The first two signature lines are
// missing from this listing (presumably the int64_t and uint64_t overloads —
// confirm). Each returns std::max of the operands.
1139 template<>
1141  return std::max(a, b);
1142 }
1143 
1144 template<>
1146  return std::max(a, b);
1147 }
1148 
// Floating-point Max fold; `t` is unused.
1149 template<>
1150 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1151  return std::max(a, b);
1152 }
1153 
1154 template<typename A, typename B>
1155 HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1156  return {pattern_arg(a), pattern_arg(b)};
1157 }
1158 
1159 template<typename A, typename B>
1160 HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1161  return IRMatcher::operator<(a, b);
1162 }
1163 
// Constant-fold specializations for LT. All three signature lines are missing
// from this listing, so the operand and result types are not visible here.
// Presumably they are the int64_t, uint64_t and double overloads, as for the
// arithmetic ops — confirm. Each returns the result of `a < b`.
1164 template<>
1166  return a < b;
1167 }
1168 
1169 template<>
1171  return a < b;
1172 }
1173 
1174 template<>
1176  return a < b;
1177 }
1178 
1179 template<typename A, typename B>
1180 HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1181  return {pattern_arg(a), pattern_arg(b)};
1182 }
1183 
1184 template<typename A, typename B>
1185 HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1186  return IRMatcher::operator>(a, b);
1187 }
1188 
// Constant-fold specializations for GT. All three signature lines are missing
// from this listing (presumably the int64_t, uint64_t and double overloads —
// confirm). Each returns the result of `a > b`.
1189 template<>
1191  return a > b;
1192 }
1193 
1194 template<>
1196  return a > b;
1197 }
1198 
1199 template<>
1201  return a > b;
1202 }
1203 
1204 template<typename A, typename B>
1205 HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1206  return {pattern_arg(a), pattern_arg(b)};
1207 }
1208 
1209 template<typename A, typename B>
1210 HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1211  return IRMatcher::operator<=(a, b);
1212 }
1213 
// Constant-fold specializations for LE. All three signature lines are missing
// from this listing (presumably the int64_t, uint64_t and double overloads —
// confirm). Each returns the result of `a <= b`.
1214 template<>
1216  return a <= b;
1217 }
1218 
1219 template<>
1221  return a <= b;
1222 }
1223 
1224 template<>
1226  return a <= b;
1227 }
1228 
1229 template<typename A, typename B>
1230 HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1231  return {pattern_arg(a), pattern_arg(b)};
1232 }
1233 
1234 template<typename A, typename B>
1235 HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1236  return IRMatcher::operator>=(a, b);
1237 }
1238 
// Constant-fold specializations for GE. All three signature lines are missing
// from this listing (presumably the int64_t, uint64_t and double overloads —
// confirm). Each returns the result of `a >= b`.
1239 template<>
1241  return a >= b;
1242 }
1243 
1244 template<>
1246  return a >= b;
1247 }
1248 
1249 template<>
1251  return a >= b;
1252 }
1253 
1254 template<typename A, typename B>
1255 HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1256  return {pattern_arg(a), pattern_arg(b)};
1257 }
1258 
1259 template<typename A, typename B>
1260 HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1261  return IRMatcher::operator==(a, b);
1262 }
1263 
// Constant-fold specializations for EQ. All three signature lines are missing
// from this listing (presumably the int64_t, uint64_t and double overloads —
// confirm). Each returns the result of `a == b`; for the double overload this
// is an exact floating-point comparison.
1264 template<>
1266  return a == b;
1267 }
1268 
1269 template<>
1271  return a == b;
1272 }
1273 
1274 template<>
1276  return a == b;
1277 }
1278 
1279 template<typename A, typename B>
1280 HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1281  return {pattern_arg(a), pattern_arg(b)};
1282 }
1283 
1284 template<typename A, typename B>
1285 HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1286  return IRMatcher::operator!=(a, b);
1287 }
1288 
// Constant-fold specializations for NE. All three signature lines are missing
// from this listing (presumably the int64_t, uint64_t and double overloads —
// confirm). Each returns the result of `a != b`.
1289 template<>
1291  return a != b;
1292 }
1293 
1294 template<>
1296  return a != b;
1297 }
1298 
1299 template<>
1301  return a != b;
1302 }
1303 
1304 template<typename A, typename B>
1305 HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1306  return {pattern_arg(a), pattern_arg(b)};
1307 }
1308 
1309 template<typename A, typename B>
1310 HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1311  return IRMatcher::operator||(a, b);
1312 }
1313 
// Constant-fold specializations for Or. The first two signature lines are
// missing from this listing (presumably the int64_t and uint64_t overloads —
// confirm). Booleans are folded as integers: the operands are ORed and the
// result is reduced to its low bit.
1314 template<>
1316  return (a | b) & 1;
1317 }
1318 
1319 template<>
1321  return (a | b) & 1;
1322 }
1323 
// The double overload exists only so every operand kind has a specialization;
// per the comment below it is never reached.
1324 template<>
1325 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1326  // Unreachable, as it would be a type mismatch.
1327  return 0;
1328 }
1329 
1330 template<typename A, typename B>
1331 HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1332  return {pattern_arg(a), pattern_arg(b)};
1333 }
1334 
1335 template<typename A, typename B>
1336 HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1337  return IRMatcher::operator&&(a, b);
1338 }
1339 
// Constant-fold specializations for And. The first two signature lines are
// missing from this listing (presumably the int64_t and uint64_t overloads —
// confirm). Booleans are folded as integers: the operands are ANDed and the
// result is reduced to its low bit.
1340 template<>
1342  return a & b & 1;
1343 }
1344 
1345 template<>
1347  return a & b & 1;
1348 }
1349 
// The double overload is marked unreachable below; it returns 0.
1350 template<>
1351 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1352  // Unreachable
1353  return 0;
1354 }
1355 
/** Base case: the bitwise OR of no masks is 0. */
constexpr inline uint32_t bitwise_or_reduce() {
    return 0u;
}

/** Bitwise OR of any number of uint32_t masks, evaluable at compile time.
 * The remaining masks are reduced first and then merged with `first`. */
template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) {
    return bitwise_or_reduce(rest...) | first;
}
1364 
/** Base case: the conjunction of no values is true. */
constexpr inline bool and_reduce() {
    return true;
}

/** Compile-time logical AND of any number of bools. Evaluation stops at the
 * first false value. */
template<typename... Args>
constexpr bool and_reduce(bool first, Args... rest) {
    return first ? and_reduce(rest...) : false;
}
1373 
// TODO: this can be replaced with std::min() once we require C++14 or later
/** constexpr minimum of two ints, usable in template arguments (see the
 * std::get index clamping in Intrin::make). */
constexpr int const_min(int a, int b) {
    return (b < a) ? b : a;
}
1378 
// Pattern node for an intrinsic Call with a fixed list of sub-patterns.
// Some declaration lines of this struct are absent from this listing: the
// `intrin` member, `binds`, the node-type bounds, two HALIDE_ALWAYS_INLINE
// attribute lines and the constructor's signature.
1379 template<typename... Args>
1380 struct Intrin {
1381  struct pattern_tag {};
// One sub-pattern per intrinsic argument, in call-argument order.
1383  std::tuple<Args...> args;
1384 
1386 
// Canonical only if every argument pattern is canonical.
1389  constexpr static bool canonical = and_reduce((Args::canonical)...);
1390 
// Matches argument i against c.args[i], then recurses on i + 1. `bound`
// accumulates the binding masks of the arguments already matched. The call
// sites pass the literal 0 (an int), so this overload is preferred while
// i < sizeof...(Args). Once enable_if removes it, the `double` overload
// below ends the recursion.
// NOTE(review): c.args[i] is indexed without checking c.args.size(); this
// presumably relies on each intrinsic having a fixed arity matching the pattern.
1391  template<int i,
1392  uint32_t bound,
1393  typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1394  HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1395  using T = decltype(std::get<i>(args));
1396  return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1397  match_args<i + 1, bound | bindings<T>::mask>(0, c, state));
1398  }
1399 
// Recursion terminator: all arguments have matched.
1400  template<int i, uint32_t binds>
1401  HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1402  return true;
1403  }
1404 
// Matches only Call nodes whose intrinsic equals `intrin`, then all arguments.
1405  template<uint32_t bound>
1406  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1407  if (e.node_type != IRNodeType::Call) {
1408  return false;
1409  }
1410  const Call &c = (const Call &)e;
1411  return (c.is_intrinsic(intrin) && match_args<0, bound>(0, c, state));
1412  }
1413 
// Prints the arguments comma-separated, using the same int/double overload
// recursion as match_args.
// NOTE(review): `i + 1 < sizeof...(Args)` compares int with size_t.
1414  template<int i,
1415  typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1416  HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1417  s << std::get<i>(args);
1418  if (i + 1 < sizeof...(Args)) {
1419  s << ", ";
1420  }
1421  print_args<i + 1>(0, s);
1422  }
1423 
1424  template<int i>
1425  HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1426  }
1427 
1429  void print_args(std::ostream &s) const {
1430  print_args<0>(0, s);
1431  }
1432 
// Rebuilds an Expr for the intrinsic from the bound sub-patterns. Arguments
// are built by arity: arg0 always, then arg1 and arg2 only if no earlier
// branch returned. For patterns with fewer arguments, const_min clamps the
// std::get index into range, so the last argument is simply rebuilt again.
// Intrinsics not listed here hit internal_error.
// NOTE(review): these unqualified calls take Exprs and are compiled before
// the same-named IRMatcher pattern builders declared after this struct;
// presumably they resolve to the Expr-building functions from IROperator.h —
// keep that in mind if the builders are moved above this struct.
1434  Expr make(MatcherState &state, halide_type_t type_hint) const {
1435  Expr arg0 = std::get<0>(args).make(state, type_hint);
1436  if (intrin == Call::likely) {
1437  return likely(arg0);
1438  } else if (intrin == Call::likely_if_innermost) {
1439  return likely_if_innermost(arg0);
1440  } else if (intrin == Call::abs) {
1441  return abs(arg0);
1442  }
1443 
1444  Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1445  if (intrin == Call::absd) {
1446  return absd(arg0, arg1);
1447  } else if (intrin == Call::widening_add) {
1448  return widening_add(arg0, arg1);
1449  } else if (intrin == Call::widening_sub) {
1450  return widening_sub(arg0, arg1);
1451  } else if (intrin == Call::widening_mul) {
1452  return widening_mul(arg0, arg1);
1453  } else if (intrin == Call::saturating_add) {
1454  return saturating_add(arg0, arg1);
1455  } else if (intrin == Call::saturating_sub) {
1456  return saturating_sub(arg0, arg1);
1457  } else if (intrin == Call::halving_add) {
1458  return halving_add(arg0, arg1);
1459  } else if (intrin == Call::halving_sub) {
1460  return halving_sub(arg0, arg1);
1461  } else if (intrin == Call::rounding_halving_add) {
1462  return rounding_halving_add(arg0, arg1);
1463  } else if (intrin == Call::rounding_halving_sub) {
1464  return rounding_halving_sub(arg0, arg1);
1465  } else if (intrin == Call::shift_left) {
1466  return arg0 << arg1;
1467  } else if (intrin == Call::shift_right) {
1468  return arg0 >> arg1;
1469  } else if (intrin == Call::rounding_shift_left) {
1470  return rounding_shift_left(arg0, arg1);
1471  } else if (intrin == Call::rounding_shift_right) {
1472  return rounding_shift_right(arg0, arg1);
1473  }
1474 
1475  Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1476  if (intrin == Call::mul_shift_right) {
1477  return mul_shift_right(arg0, arg1, arg2);
1478  } else if (intrin == Call::rounding_mul_shift_right) {
1479  return rounding_mul_shift_right(arg0, arg1, arg2);
1480  }
1481 
1482  internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1483  return Expr();
1484  }
1485 
// Intrinsic patterns are never folded to constants.
1486  constexpr static bool foldable = false;
1487 
// Constructor (signature line not shown): stores the intrinsic op and copies
// the argument patterns into the tuple.
1490  : intrin(intrin), args(args...) {
1491  }
1492 };
1493 
1494 template<typename... Args>
1495 std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
1496  s << op.intrin << "(";
1497  op.print_args(s);
1498  s << ")";
1499  return s;
1500 }
1501 
1502 template<typename... Args>
1503 HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
1504  return {intrinsic_op, pattern_arg(args)...};
1505 }
1506 
1507 template<typename A, typename B>
1508 auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1509  return {Call::widening_add, pattern_arg(a), pattern_arg(b)};
1510 }
1511 template<typename A, typename B>
1512 auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1513  return {Call::widening_sub, pattern_arg(a), pattern_arg(b)};
1514 }
1515 template<typename A, typename B>
1516 auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1517  return {Call::widening_mul, pattern_arg(a), pattern_arg(b)};
1518 }
1519 template<typename A, typename B>
1520 auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1521  return {Call::saturating_add, pattern_arg(a), pattern_arg(b)};
1522 }
1523 template<typename A, typename B>
1524 auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1525  return {Call::saturating_sub, pattern_arg(a), pattern_arg(b)};
1526 }
1527 template<typename A, typename B>
1528 auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1529  return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
1530 }
1531 template<typename A, typename B>
1532 auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1533  return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
1534 }
// Intrin pattern builders for Call::rounding_halving_add / _sub. Their return
// statements are missing from this listing; presumably they mirror the
// builders above, e.g. `return {Call::rounding_halving_add, pattern_arg(a),
// pattern_arg(b)};` — confirm against the full header.
1535 template<typename A, typename B>
1536 auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1538 }
1539 template<typename A, typename B>
1540 auto rounding_halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1542 }
1543 template<typename A, typename B>
1544 auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1545  return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
1546 }
1547 template<typename A, typename B>
1548 auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1549  return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
1550 }
// Intrin pattern builders for rounding_shift_left/right (two arguments) and
// mul_shift_right / rounding_mul_shift_right (three arguments). Their return
// statements are missing from this listing; presumably each builds
// `{Call::<name>, pattern_arg(a), pattern_arg(b)[, pattern_arg(c)]}` like
// the builders above — confirm against the full header.
1551 template<typename A, typename B>
1552 auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1554 }
1555 template<typename A, typename B>
1556 auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1558 }
1559 template<typename A, typename B, typename C>
1560 auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1562 }
1563 template<typename A, typename B, typename C>
1564 auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1566 }
1567 
// Pattern node for logical negation (IR `Not`) of a single sub-pattern `a`.
// Some declaration lines (node-type bounds, an attribute line and the
// make_folded_const signature) are absent from this listing.
1568 template<typename A>
1569 struct NotOp {
1570  struct pattern_tag {};
1571  A a;
1572 
1573  constexpr static uint32_t binds = bindings<A>::mask;
1574 
1577  constexpr static bool canonical = A::canonical;
1578 
// Matches a Not node whose operand matches `a`.
1579  template<uint32_t bound>
1580  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1581  if (e.node_type != IRNodeType::Not) {
1582  return false;
1583  }
1584  const Not &op = (const Not &)e;
1585  return (a.template match<bound>(*op.a.get(), state));
1586  }
1587 
// Pattern-vs-pattern match against another NotOp instance.
1588  template<uint32_t bound, typename A2>
1589  HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1590  return a.template match<bound>(unwrap(op.a), state);
1591  }
1592 
1594  Expr make(MatcherState &state, halide_type_t type_hint) const {
1595  return Not::make(a.make(state, type_hint));
1596  }
1597 
1598  constexpr static bool foldable = A::foldable;
1599 
// Folds `a`, then applies logical not: invert all bits and keep only the low
// bit, so the result is 0 or 1.
1600  template<typename A1 = A>
1602  a.make_folded_const(val, ty, state);
1603  val.u.u64 = ~val.u.u64;
1604  val.u.u64 &= 1;
1605  }
1606 };
1607 
1608 template<typename A>
1609 HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1610  assert_is_lvalue_if_expr<A>();
1611  return {pattern_arg(a)};
1612 }
1613 
1614 template<typename A>
1615 HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a)) {
1616  assert_is_lvalue_if_expr<A>();
1617  return IRMatcher::operator!(a);
1618 }
1619 
1620 template<typename A>
1621 inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1622  s << "!(" << op.a << ")";
1623  return s;
1624 }
1625 
// Pattern node for `select(c, t, f)`. Some declaration lines (binds,
// node-type bounds, an attribute line and the make_folded_const signature)
// are absent from this listing.
1626 template<typename C, typename T, typename F>
1627 struct SelectOp {
1628  struct pattern_tag {};
1629  C c;
1630  T t;
1631  F f;
1632 
1634 
1637 
1638  constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1639 
// Matches condition, true value, then false value, in that order. Each later
// operand is matched with the binding masks of the earlier ones ORed into
// `bound`.
1640  template<uint32_t bound>
1641  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1642  if (e.node_type != Select::_node_type) {
1643  return false;
1644  }
1645  const Select &op = (const Select &)e;
1646  return (c.template match<bound>(*op.condition.get(), state) &&
1647  t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
1648  f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
1649  }
// Pattern-vs-pattern match against another SelectOp instance, same order.
1650  template<uint32_t bound, typename C2, typename T2, typename F2>
1651  HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1652  return (c.template match<bound>(unwrap(instance.c), state) &&
1653  t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
1654  f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
1655  }
1656 
// The condition is built with an empty type hint; both branches get the
// caller's type_hint.
1658  Expr make(MatcherState &state, halide_type_t type_hint) const {
1659  return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1660  }
1661 
1662  constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1663 
// Folds the condition and then only the branch selected by its low bit. Any
// special-value flags (e.g. overflow) carried in the condition's lanes field
// are propagated into the result.
// NOTE(review): t_val and f_val are declared but unused.
1664  template<typename C1 = C>
1666  halide_scalar_value_t c_val, t_val, f_val;
1667  halide_type_t c_ty;
1668  c.make_folded_const(c_val, c_ty, state);
1669  if ((c_val.u.u64 & 1) == 1) {
1670  t.make_folded_const(val, ty, state);
1671  } else {
1672  f.make_folded_const(val, ty, state);
1673  }
1674  ty.lanes |= c_ty.lanes & MatcherState::special_values_mask;
1675  }
1676 };
1677 
1678 template<typename C, typename T, typename F>
1679 std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1680  s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1681  return s;
1682 }
1683 
1684 template<typename C, typename T, typename F>
1685 HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1686  assert_is_lvalue_if_expr<C>();
1687  assert_is_lvalue_if_expr<T>();
1688  assert_is_lvalue_if_expr<F>();
1689  return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1690 }
1691 
// Pattern node for `broadcast(a, lanes)`: `a` matches the broadcast value and
// `lanes` matches its lane count. The `lanes` member, binds and node-type
// declarations are absent from this listing.
1692 template<typename A, typename B>
1693 struct BroadcastOp {
1694  struct pattern_tag {};
1695  A a;
1697 
1699 
1702 
1703  constexpr static bool canonical = A::canonical && B::canonical;
1704 
1705  template<uint32_t bound>
1706  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1707  if (e.node_type == Broadcast::_node_type) {
1708  const Broadcast &op = (const Broadcast &)e;
1709  if (a.template match<bound>(*op.value.get(), state) &&
1710  lanes.template match<bound>(op.lanes, state)) {
1711  return true;
1712  }
1713  }
1714  return false;
1715  }
1716 
// Pattern-vs-pattern match against another BroadcastOp instance.
1717  template<uint32_t bound, typename A2, typename B2>
1718  HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1719  return (a.template match<bound>(unwrap(op.a), state) &&
1720  lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1721  }
1722 
// Folds `lanes` to a constant l and divides the hint's lane count by l, so
// the hint applies to the scalar/narrow value. Then builds the value; it is
// returned unchanged when l == 1 and wrapped in Broadcast::make otherwise.
// NOTE(review): assumes `lanes` folds to a non-zero value.
1724  Expr make(MatcherState &state, halide_type_t type_hint) const {
1725  halide_scalar_value_t lanes_val;
1726  halide_type_t ty;
1727  lanes.make_folded_const(lanes_val, ty, state);
1728  int32_t l = (int32_t)lanes_val.u.i64;
1729  type_hint.lanes /= l;
1730  Expr val = a.make(state, type_hint);
1731  if (l == 1) {
1732  return val;
1733  } else {
1734  return Broadcast::make(std::move(val), l);
1735  }
1736  }
1737 
1738  constexpr static bool foldable = false;
1739 
// Folds `a` and sets the lane count to the folded `lanes` value, keeping any
// special-value flag bits that `a`'s fold set in ty.lanes.
1740  template<typename A1 = A>
1742  halide_scalar_value_t lanes_val;
1743  halide_type_t lanes_ty;
1744  lanes.make_folded_const(lanes_val, lanes_ty, state);
1745  uint16_t l = (uint16_t)lanes_val.u.i64;
1746  a.make_folded_const(val, ty, state);
1747  ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1748  }
1749 };
1750 
1751 template<typename A, typename B>
1752 inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1753  s << "broadcast(" << op.a << ", " << op.lanes << ")";
1754  return s;
1755 }
1756 
1757 template<typename A, typename B>
1758 HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1759  assert_is_lvalue_if_expr<A>();
1760  return {pattern_arg(a), pattern_arg(lanes)};
1761 }
1762 
// Pattern node for `ramp(a, b, lanes)`: `a` is the base, `b` the stride and
// `lanes` the lane count. The `lanes` member, binds, node-type declarations
// and an attribute line are absent from this listing.
1763 template<typename A, typename B, typename C>
1764 struct RampOp {
1765  struct pattern_tag {};
1766  A a;
1767  B b;
1769 
1771 
1774 
1775  constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1776 
// Matches base, stride and lane count in order. Each later operand sees the
// binding masks of the earlier ones.
1777  template<uint32_t bound>
1778  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1779  if (e.node_type != Ramp::_node_type) {
1780  return false;
1781  }
1782  const Ramp &op = (const Ramp &)e;
1783  if (a.template match<bound>(*op.base.get(), state) &&
1784  b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
1785  lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
1786  return true;
1787  } else {
1788  return false;
1789  }
1790  }
1791 
// Pattern-vs-pattern match against another RampOp instance.
1792  template<uint32_t bound, typename A2, typename B2, typename C2>
1793  HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1794  return (a.template match<bound>(unwrap(op.a), state) &&
1795  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
1796  lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
1797  }
1798 
// Folds `lanes` to l and divides the hint's lane count by l. The stride is
// built first, and its type is then used as the hint for the base, so both
// share a type.
// NOTE(review): assumes `lanes` folds to a non-zero value.
1800  Expr make(MatcherState &state, halide_type_t type_hint) const {
1801  halide_scalar_value_t lanes_val;
1802  halide_type_t ty;
1803  lanes.make_folded_const(lanes_val, ty, state);
1804  int32_t l = (int32_t)lanes_val.u.i64;
1805  type_hint.lanes /= l;
1806  Expr ea, eb;
1807  eb = b.make(state, type_hint);
1808  ea = a.make(state, eb.type());
1809  return Ramp::make(ea, eb, l);
1810  }
1811 
1812  constexpr static bool foldable = false;
1813 };
1814 
1815 template<typename A, typename B, typename C>
1816 std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1817  s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1818  return s;
1819 }
1820 
1821 template<typename A, typename B, typename C>
1822 HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1823  assert_is_lvalue_if_expr<A>();
1824  assert_is_lvalue_if_expr<B>();
1825  assert_is_lvalue_if_expr<C>();
1826  return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1827 }
1828 
// Pattern node for a VectorReduce of kind `reduce_op` applied to `a`, whose
// result has `lanes` lanes. The struct-head line, the `lanes` member,
// node-type declarations, the second match signature and an attribute line
// are absent from this listing.
// NOTE(review): `binds` includes only bindings<A>::mask, although `lanes`
// (type B) may also bind a wildcard; confirm whether bindings<B>::mask
// should be ORed in.
1829 template<typename A, typename B, VectorReduce::Operator reduce_op>
1831  struct pattern_tag {};
1832  A a;
1834 
1835  constexpr static uint32_t binds = bindings<A>::mask;
1836 
1839  constexpr static bool canonical = A::canonical;
1840 
// Matches only VectorReduce nodes with the same operator. `lanes` is compared
// with the reduction's output lane count (op.type.lanes()).
1841  template<uint32_t bound>
1842  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1843  if (e.node_type == VectorReduce::_node_type) {
1844  const VectorReduce &op = (const VectorReduce &)e;
1845  if (op.op == reduce_op &&
1846  a.template match<bound>(*op.value.get(), state) &&
1847  lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
1848  return true;
1849  }
1850  }
1851  return false;
1852  }
1853 
// Pattern-vs-pattern match (signature line not shown): the operators must be
// equal, then `a` and `lanes` are matched.
1854  template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1856  return (reduce_op == reduce_op_2 &&
1857  a.template match<bound>(unwrap(op.a), state) &&
1858  lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1859  }
1860 
// Folds `lanes` to a constant and builds the reduction of `a` to that many
// lanes. `a` is built with the caller's (output) type_hint unchanged.
1862  Expr make(MatcherState &state, halide_type_t type_hint) const {
1863  halide_scalar_value_t lanes_val;
1864  halide_type_t ty;
1865  lanes.make_folded_const(lanes_val, ty, state);
1866  int l = (int)lanes_val.u.i64;
1867  return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1868  }
1869 
1870  constexpr static bool foldable = false;
1871 };
1872 
1873 template<typename A, typename B, VectorReduce::Operator reduce_op>
1874 inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1875  s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1876  return s;
1877 }
1878 
1879 template<typename A, typename B>
1880 HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1881  assert_is_lvalue_if_expr<A>();
1882  return {pattern_arg(a), pattern_arg(lanes)};
1883 }
1884 
1885 template<typename A, typename B>
1886 HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1887  assert_is_lvalue_if_expr<A>();
1888  return {pattern_arg(a), pattern_arg(lanes)};
1889 }
1890 
1891 template<typename A, typename B>
1892 HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1893  assert_is_lvalue_if_expr<A>();
1894  return {pattern_arg(a), pattern_arg(lanes)};
1895 }
1896 
1897 template<typename A, typename B>
1898 HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
1899  assert_is_lvalue_if_expr<A>();
1900  return {pattern_arg(a), pattern_arg(lanes)};
1901 }
1902 
1903 template<typename A, typename B>
1904 HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
1905  assert_is_lvalue_if_expr<A>();
1906  return {pattern_arg(a), pattern_arg(lanes)};
1907 }
1908 
// Pattern node for arithmetic negation. In the IR, negation is represented
// as `0 - a` (a Sub node). The node-type declarations are absent from this
// listing.
1909 template<typename A>
1910 struct NegateOp {
1911  struct pattern_tag {};
1912  A a;
1913 
1914  constexpr static uint32_t binds = bindings<A>::mask;
1915 
1918 
1919  constexpr static bool canonical = A::canonical;
1920 
// Matches a Sub node whose right operand matches `a` and whose left operand
// is a constant zero (checked with is_const_zero after matching `a`).
1921  template<uint32_t bound>
1922  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1923  if (e.node_type != Sub::_node_type) {
1924  return false;
1925  }
1926  const Sub &op = (const Sub &)e;
1927  return (a.template match<bound>(*op.b.get(), state) &&
1928  is_const_zero(op.a));
1929  }
1930 
1931  template<uint32_t bound, typename A2>
1932  HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1933  return a.template match<bound>(unwrap(p.a), state);
1934  }
1935 
// Builds `0 - a`, using a zero of the same type as the built operand.
1937  Expr make(MatcherState &state, halide_type_t type_hint) const {
1938  Expr ea = a.make(state, type_hint);
1939  Expr z = make_zero(ea.type());
1940  return Sub::make(std::move(z), std::move(ea));
1941  }
1942 
1943  constexpr static bool foldable = A::foldable;
1944 
// Folds `a`, then negates it according to its type code (signature line not
// shown). Signed integers: for types of 32 bits or more, a non-zero value
// whose bits above the sign bit are all zero after `<< (65 - bits)` is the
// most negative value. That case takes a separate branch whose statement is
// missing from this listing (presumably it sets an overflow flag — confirm).
// Otherwise the value is negated, truncated to ty.bits and sign-extended.
// Unsigned: negated and truncated to ty.bits (wraps). Float/bfloat: sign flip.
1945  template<typename A1 = A>
1947  a.make_folded_const(val, ty, state);
1948  int dead_bits = 64 - ty.bits;
1949  switch (ty.code) {
1950  case halide_type_int:
1951  if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1952  // Trying to negate the most negative signed int for a no-overflow type.
1954  } else {
1955  // Negate, drop the high bits, and then sign-extend them back
1956  val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
1957  }
1958  break;
1959  case halide_type_uint:
1960  val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
1961  break;
1962  case halide_type_float:
1963  case halide_type_bfloat:
1964  val.u.f64 = -val.u.f64;
1965  break;
1966  default:
1967  // unreachable
1968  ;
1969  }
1970  }
1971 };
1972 
1973 template<typename A>
1974 std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
1975  s << "-" << op.a;
1976  return s;
1977 }
1978 
1979 template<typename A>
1980 HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
1981  assert_is_lvalue_if_expr<A>();
1982  return {pattern_arg(a)};
1983 }
1984 
1985 template<typename A>
1986 HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a)) {
1987  assert_is_lvalue_if_expr<A>();
1988  return IRMatcher::operator-(a);
1989 }
1990 
// Pattern node for a Cast of sub-pattern `a` to a fixed type. The type member
// (referred to as `t` below), node-type declarations and an attribute line
// are absent from this listing.
1991 template<typename A>
1992 struct CastOp {
1993  struct pattern_tag {};
1995  A a;
1996 
1997  constexpr static uint32_t binds = bindings<A>::mask;
1998 
2001  constexpr static bool canonical = A::canonical;
2002 
// Matches a Cast node whose result type equals `t` exactly and whose operand
// matches `a`.
2003  template<uint32_t bound>
2004  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2005  if (e.node_type != Cast::_node_type) {
2006  return false;
2007  }
2008  const Cast &op = (const Cast &)e;
2009  return (e.type == t &&
2010  a.template match<bound>(*op.value.get(), state));
2011  }
// Pattern-vs-pattern match: target types must be equal, then operands match.
2012  template<uint32_t bound, typename A2>
2013  HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2014  return t == op.t && a.template match<bound>(unwrap(op.a), state);
2015  }
2016 
// The operand is built with an empty type hint, since its type is unrelated
// to the cast's target type.
2018  Expr make(MatcherState &state, halide_type_t type_hint) const {
2019  return cast(t, a.make(state, {}));
2020  }
2021 
2022  constexpr static bool foldable = false;
2023 };
2024 
2025 template<typename A>
2026 std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2027  s << "cast(" << op.t << ", " << op.a << ")";
2028  return s;
2029 }
2030 
2031 template<typename A>
2032 HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2033  assert_is_lvalue_if_expr<A>();
2034  return {t, pattern_arg(a)};
2035 }
2036 
// Pattern wrapper that constant-folds its sub-pattern when the rewrite result
// is built. Node-type declarations, two attribute lines, the declaration of
// the scalar `c` and the make_folded_const signature are absent from this
// listing.
2037 template<typename A>
2038 struct Fold {
2039  struct pattern_tag {};
2040  A a;
2041 
2042  constexpr static uint32_t binds = bindings<A>::mask;
2043 
2046  constexpr static bool canonical = true;
2047 
// Folds `a` into a scalar (ty starts out as type_hint) and emits it via
// make_const_expr. When the hint has non-zero bits, an int fold result
// destined for a float type is converted to double. The hint's code and bits
// then override the folded type's.
// NOTE(review): only the int -> float case converts the value; other code
// changes only relabel the bits. Presumably other mismatches cannot occur —
// confirm.
2049  Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
2051  halide_type_t ty = type_hint;
2052  a.make_folded_const(c, ty, state);
2053 
2054  // The result of the fold may have an underspecified type
2055  // (e.g. because it's from an int literal). Make the type code
2056  // and bits match the required type, if there is one (we can
2057  // tell from the bits field).
2058  if (type_hint.bits) {
2059  if (((int)ty.code == (int)halide_type_int) &&
2060  ((int)type_hint.code == (int)halide_type_float)) {
2061  int64_t x = c.u.i64;
2062  c.u.f64 = (double)x;
2063  }
2064  ty.code = type_hint.code;
2065  ty.bits = type_hint.bits;
2066  }
2067 
2068  Expr e = make_const_expr(c, ty);
2069  return e;
2070  }
2071 
2072  constexpr static bool foldable = A::foldable;
2073 
// Folding a Fold simply folds the wrapped pattern.
2074  template<typename A1 = A>
2076  a.make_folded_const(val, ty, state);
2077  }
2078 };
2079 
2080 template<typename A>
2081 HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2082  assert_is_lvalue_if_expr<A>();
2083  return {pattern_arg(a)};
2084 }
2085 
2086 template<typename A>
2087 std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2088  s << "fold(" << op.a << ")";
2089  return s;
2090 }
2091 
// Predicate pattern: true when constant-folding `a` produced a special value,
// e.g. an overflow flag. The node-type declarations and the
// make_folded_const signature are absent from this listing.
2092 template<typename A>
2093 struct Overflows {
2094  struct pattern_tag {};
2095  A a;
2096 
2097  constexpr static uint32_t binds = bindings<A>::mask;
2098 
2099  // This rule is a predicate, so it always evaluates to a boolean,
2100  // which has IRNodeType UIntImm
2103  constexpr static bool canonical = true;
2104 
2105  constexpr static bool foldable = A::foldable;
2106 
// Folds `a`, then replaces the result with a scalar 64-bit unsigned value:
// 1 if any special-value flag bit was set in ty.lanes, otherwise 0.
2107  template<typename A1 = A>
2109  a.make_folded_const(val, ty, state);
2110  ty.code = halide_type_uint;
2111  ty.bits = 64;
2112  val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2113  ty.lanes = 1;
2114  }
2115 };
2116 
2117 template<typename A>
2118 HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2119  assert_is_lvalue_if_expr<A>();
2120  return {pattern_arg(a)};
2121 }
2122 
2123 template<typename A>
2124 std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2125  s << "overflows(" << op.a << ")";
2126  return s;
2127 }
2128 
2129 struct Overflow {
2130  struct pattern_tag {};
2131 
2132  constexpr static uint32_t binds = 0;
2133 
2134  // Overflow is an intrinsic, represented as a Call node
2137  constexpr static bool canonical = true;
2138 
2139  template<uint32_t bound>
2140  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2141  if (e.node_type != Call::_node_type) {
2142  return false;
2143  }
2144  const Call &op = (const Call &)e;
2146  }
2147 
2149  Expr make(MatcherState &state, halide_type_t type_hint) const {
2151  return make_const_special_expr(type_hint);
2152  }
2153 
2154  constexpr static bool foldable = true;
2155 
2157  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
2158  val.u.u64 = 0;
2160  }
2161 };
2162 
2163 inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2164  s << "overflow()";
2165  return s;
2166 }
2167 
2168 template<typename A>
2169 struct IsConst {
2170  struct pattern_tag {};
2171 
2172  constexpr static uint32_t binds = bindings<A>::mask;
2173 
2174  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2177  constexpr static bool canonical = true;
2178 
2179  A a;
2180  bool check_v;
2182 
2183  constexpr static bool foldable = true;
2184 
2185  template<typename A1 = A>
2187  Expr e = a.make(state, {});
2188  ty.code = halide_type_uint;
2189  ty.bits = 64;
2190  ty.lanes = 1;
2191  if (check_v) {
2192  val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2193  } else {
2194  val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2195  }
2196  }
2197 };
2198 
2199 template<typename A>
2200 HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2201  assert_is_lvalue_if_expr<A>();
2202  return {pattern_arg(a), false, 0};
2203 }
2204 
2205 template<typename A>
2206 HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2207  assert_is_lvalue_if_expr<A>();
2208  return {pattern_arg(a), true, value};
2209 }
2210 
2211 template<typename A>
2212 std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2213  if (op.check_v) {
2214  s << "is_const(" << op.a << ")";
2215  } else {
2216  s << "is_const(" << op.a << ", " << op.v << ")";
2217  }
2218  return s;
2219 }
2220 
2221 template<typename A, typename Prover>
2222 struct CanProve {
2223  struct pattern_tag {};
2224  A a;
2225  Prover *prover; // An existing simplifying mutator
2226 
2227  constexpr static uint32_t binds = bindings<A>::mask;
2228 
2229  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2232  constexpr static bool canonical = true;
2233 
2234  constexpr static bool foldable = true;
2235 
2236  // Includes a raw call to an inlined make method, so don't inline.
2238  Expr condition = a.make(state, {});
2239  condition = prover->mutate(condition, nullptr);
2240  val.u.u64 = is_const_one(condition);
2241  ty.code = halide_type_uint;
2242  ty.bits = 1;
2243  ty.lanes = condition.type().lanes();
2244  }
2245 };
2246 
2247 template<typename A, typename Prover>
2248 HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2249  assert_is_lvalue_if_expr<A>();
2250  return {pattern_arg(a), p};
2251 }
2252 
2253 template<typename A, typename Prover>
2254 std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2255  s << "can_prove(" << op.a << ")";
2256  return s;
2257 }
2258 
2259 template<typename A>
2260 struct IsFloat {
2261  struct pattern_tag {};
2262  A a;
2263 
2264  constexpr static uint32_t binds = bindings<A>::mask;
2265 
2266  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2269  constexpr static bool canonical = true;
2270 
2271  constexpr static bool foldable = true;
2272 
2275  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2276  Type t = a.make(state, {}).type();
2277  val.u.u64 = t.is_float();
2278  ty.code = halide_type_uint;
2279  ty.bits = 1;
2280  ty.lanes = t.lanes();
2281  }
2282 };
2283 
2284 template<typename A>
2285 HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2286  assert_is_lvalue_if_expr<A>();
2287  return {pattern_arg(a)};
2288 }
2289 
2290 template<typename A>
2291 std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2292  s << "is_float(" << op.a << ")";
2293  return s;
2294 }
2295 
2296 template<typename A>
2297 struct IsInt {
2298  struct pattern_tag {};
2299  A a;
2300  int bits;
2301 
2302  constexpr static uint32_t binds = bindings<A>::mask;
2303 
2304  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2307  constexpr static bool canonical = true;
2308 
2309  constexpr static bool foldable = true;
2310 
2313  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2314  Type t = a.make(state, {}).type();
2315  val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits);
2316  ty.code = halide_type_uint;
2317  ty.bits = 1;
2318  ty.lanes = t.lanes();
2319  }
2320 };
2321 
2322 template<typename A>
2323 HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2324  assert_is_lvalue_if_expr<A>();
2325  return {pattern_arg(a), bits};
2326 }
2327 
2328 template<typename A>
2329 std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2330  s << "is_int(" << op.a;
2331  if (op.bits > 0) {
2332  s << ", " << op.bits;
2333  }
2334  s << ")";
2335  return s;
2336 }
2337 
2338 template<typename A>
2339 struct IsUInt {
2340  struct pattern_tag {};
2341  A a;
2342  int bits;
2343 
2344  constexpr static uint32_t binds = bindings<A>::mask;
2345 
2346  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2349  constexpr static bool canonical = true;
2350 
2351  constexpr static bool foldable = true;
2352 
2355  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2356  Type t = a.make(state, {}).type();
2357  val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits);
2358  ty.code = halide_type_uint;
2359  ty.bits = 1;
2360  ty.lanes = t.lanes();
2361  }
2362 };
2363 
2364 template<typename A>
2365 HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2366  assert_is_lvalue_if_expr<A>();
2367  return {pattern_arg(a), bits};
2368 }
2369 
2370 template<typename A>
2371 std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2372  s << "is_uint(" << op.a;
2373  if (op.bits > 0) {
2374  s << ", " << op.bits;
2375  }
2376  s << ")";
2377  return s;
2378 }
2379 
2380 template<typename A>
2381 struct IsScalar {
2382  struct pattern_tag {};
2383  A a;
2384 
2385  constexpr static uint32_t binds = bindings<A>::mask;
2386 
2387  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2390  constexpr static bool canonical = true;
2391 
2392  constexpr static bool foldable = true;
2393 
2396  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2397  Type t = a.make(state, {}).type();
2398  val.u.u64 = t.is_scalar();
2399  ty.code = halide_type_uint;
2400  ty.bits = 1;
2401  ty.lanes = t.lanes();
2402  }
2403 };
2404 
2405 template<typename A>
2406 HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2407  assert_is_lvalue_if_expr<A>();
2408  return {pattern_arg(a)};
2409 }
2410 
2411 template<typename A>
2412 struct IsMaxValue {
2413  struct pattern_tag {};
2414  A a;
2415 
2416  constexpr static uint32_t binds = bindings<A>::mask;
2417 
2418  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2421  constexpr static bool canonical = true;
2422 
2423  constexpr static bool foldable = true;
2424 
2427  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2428  a.make_folded_const(val, ty, state);
2429  const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2430  if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2431  val.u.u64 = (val.u.u64 == max_bits);
2432  } else {
2433  val.u.u64 = 0;
2434  }
2435  ty.code = halide_type_uint;
2436  ty.bits = 1;
2437  }
2438 };
2439 
2440 template<typename A>
2441 HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2442  assert_is_lvalue_if_expr<A>();
2443  return {pattern_arg(a)};
2444 }
2445 
2446 template<typename A>
2447 struct IsMinValue {
2448  struct pattern_tag {};
2449  A a;
2450 
2451  constexpr static uint32_t binds = bindings<A>::mask;
2452 
2453  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2456  constexpr static bool canonical = true;
2457 
2458  constexpr static bool foldable = true;
2459 
2462  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2463  a.make_folded_const(val, ty, state);
2464  if (ty.code == halide_type_int) {
2465  const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2466  val.u.u64 = (val.u.u64 == min_bits);
2467  } else if (ty.code == halide_type_uint) {
2468  val.u.u64 = (val.u.u64 == 0);
2469  } else {
2470  val.u.u64 = 0;
2471  }
2472  ty.code = halide_type_uint;
2473  ty.bits = 1;
2474  }
2475 };
2476 
2477 template<typename A>
2478 HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2479  assert_is_lvalue_if_expr<A>();
2480  return {pattern_arg(a)};
2481 }
2482 
2483 template<typename A>
2484 std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2485  s << "is_scalar(" << op.a << ")";
2486  return s;
2487 }
2488 
2489 // Verify properties of each rewrite rule. Currently just fuzz tests them.
2490 template<typename Before,
2491  typename After,
2492  typename Predicate,
2493  typename = typename std::enable_if<std::decay<Before>::type::foldable &&
2494  std::decay<After>::type::foldable>::type>
2495 HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2496  halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2497 
2498  // We only validate the rules in the scalar case
2499  wildcard_type.lanes = output_type.lanes = 1;
2500 
2501  // Track which types this rule has been tested for before
2502  static std::set<uint32_t> tested;
2503 
2504  if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2505  return;
2506  }
2507 
2508  // Print it in a form where it can be piped into a python/z3 validator
2509  debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2510 
2511  // Substitute some random constants into the before and after
2512  // expressions and see if the rule holds true. This should catch
2513  // silly errors, but not necessarily corner cases.
2514  static std::mt19937_64 rng(0);
2515  MatcherState state;
2516 
2517  Expr exprs[max_wild];
2518 
2519  for (int trials = 0; trials < 100; trials++) {
2520  // We want to test small constants more frequently than
2521  // large ones, otherwise we'll just get coverage of
2522  // overflow rules.
2523  int shift = (int)(rng() & (wildcard_type.bits - 1));
2524 
2525  for (int i = 0; i < max_wild; i++) {
2526  // Bind all the exprs and constants
2527  switch (wildcard_type.code) {
2528  case halide_type_uint: {
2529  // Normalize to the type's range by adding zero
2530  uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2531  state.set_bound_const(i, val, wildcard_type);
2532  val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2533  exprs[i] = make_const(wildcard_type, val);
2534  state.set_binding(i, *exprs[i].get());
2535  } break;
2536  case halide_type_int: {
2537  int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2538  state.set_bound_const(i, val, wildcard_type);
2539  val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2540  exprs[i] = make_const(wildcard_type, val);
2541  } break;
2542  case halide_type_float:
2543  case halide_type_bfloat: {
2544  // Use a very narrow range of precise floats, so
2545  // that none of the rules a human is likely to
2546  // write have instabilities.
2547  double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2548  state.set_bound_const(i, val, wildcard_type);
2549  val = ((int64_t)(rng() & 15) - 8) / 2.0;
2550  exprs[i] = make_const(wildcard_type, val);
2551  } break;
2552  default:
2553  return; // Don't care about handles
2554  }
2555  state.set_binding(i, *exprs[i].get());
2556  }
2557 
2558  halide_scalar_value_t val_pred, val_before, val_after;
2559  halide_type_t type = output_type;
2560  if (!evaluate_predicate(pred, state)) {
2561  continue;
2562  }
2563  before.make_folded_const(val_before, type, state);
2564  uint16_t lanes = type.lanes;
2565  after.make_folded_const(val_after, type, state);
2566  lanes |= type.lanes;
2567 
2568  if (lanes & MatcherState::special_values_mask) {
2569  continue;
2570  }
2571 
2572  bool ok = true;
2573  switch (output_type.code) {
2574  case halide_type_uint:
2575  // Compare normalized representations
2576  ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2577  constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2578  break;
2579  case halide_type_int:
2580  ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2581  constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2582  break;
2583  case halide_type_float:
2584  case halide_type_bfloat: {
2585  double error = std::abs(val_before.u.f64 - val_after.u.f64);
2586  // We accept an equal bit pattern (e.g. inf vs inf),
2587  // a small floating point difference, or turning a nan into not-a-nan.
2588  ok &= (error < 0.01 ||
2589  val_before.u.u64 == val_after.u.u64 ||
2590  std::isnan(val_before.u.f64));
2591  break;
2592  }
2593  default:
2594  return;
2595  }
2596 
2597  if (!ok) {
2598  debug(0) << "Fails with values:\n";
2599  for (int i = 0; i < max_wild; i++) {
2601  state.get_bound_const(i, val, wildcard_type);
2602  debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2603  }
2604  for (int i = 0; i < max_wild; i++) {
2605  debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2606  }
2607  debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2608  debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2609  debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
2611  }
2612  }
2613 }
2614 
2615 template<typename Before,
2616  typename After,
2617  typename Predicate,
2618  typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2619  std::decay<After>::type::foldable)>::type>
2620 HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2621  halide_type_t, halide_type_t, int dummy = 0) noexcept {
2622  // We can't verify rewrite rules that can't be constant-folded.
2623 }
2624 
2626 bool evaluate_predicate(bool x, MatcherState &) noexcept {
2627  return x;
2628 }
2629 
2630 template<typename Pattern,
2631  typename = typename enable_if_pattern<Pattern>::type>
2634  halide_type_t ty = halide_type_of<bool>();
2635  p.make_folded_const(c, ty, state);
2636  // Overflow counts as a failed predicate
2637  return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2638 }
2639 
// #defines for testing. Each one may be overridden from the build
// (e.g. -DHALIDE_FUZZ_TEST_RULES=1) instead of editing this header;
// by default they are all off.

// Print all successful or failed matches
#ifndef HALIDE_DEBUG_MATCHED_RULES
#define HALIDE_DEBUG_MATCHED_RULES 0
#endif
#ifndef HALIDE_DEBUG_UNMATCHED_RULES
#define HALIDE_DEBUG_UNMATCHED_RULES 0
#endif

// Set to 1 if you want to fuzz test every rewrite passed to
// operator() to ensure the input and the output have the same value
// for lots of random values of the wildcards. Run
// correctness_simplify with this on.
#ifndef HALIDE_FUZZ_TEST_RULES
#define HALIDE_FUZZ_TEST_RULES 0
#endif
2651 
2652 template<typename Instance>
2653 struct Rewriter {
2654  Instance instance;
2658  bool validate;
2659 
2662  : instance(std::move(instance)), output_type(ot), wildcard_type(wt) {
2663  }
2664 
2665  template<typename After>
2667  result = after.make(state, output_type);
2668  }
2669 
2670  template<typename Before,
2671  typename After,
2672  typename = typename enable_if_pattern<Before>::type,
2673  typename = typename enable_if_pattern<After>::type>
2674  HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) {
2675  static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2676  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2677  static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2678 #if HALIDE_FUZZ_TEST_RULES
2679  fuzz_test_rule(before, after, true, wildcard_type, output_type);
2680 #endif
2681  if (before.template match<0>(unwrap(instance), state)) {
2682 #if HALIDE_DEBUG_MATCHED_RULES
2683  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2684 #endif
2685  build_replacement(after);
2686  return true;
2687  } else {
2688 #if HALIDE_DEBUG_UNMATCHED_RULES
2689  debug(0) << instance << " does not match " << before << "\n";
2690 #endif
2691  return false;
2692  }
2693  }
2694 
2695  template<typename Before,
2696  typename = typename enable_if_pattern<Before>::type>
2697  HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept {
2698  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2699  if (before.template match<0>(unwrap(instance), state)) {
2700  result = after;
2701 #if HALIDE_DEBUG_MATCHED_RULES
2702  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2703 #endif
2704  return true;
2705  } else {
2706 #if HALIDE_DEBUG_UNMATCHED_RULES
2707  debug(0) << instance << " does not match " << before << "\n";
2708 #endif
2709  return false;
2710  }
2711  }
2712 
2713  template<typename Before,
2714  typename = typename enable_if_pattern<Before>::type>
2715  HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept {
2716  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2717 #if HALIDE_FUZZ_TEST_RULES
2718  fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type);
2719 #endif
2720  if (before.template match<0>(unwrap(instance), state)) {
2721  result = make_const(output_type, after);
2722 #if HALIDE_DEBUG_MATCHED_RULES
2723  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2724 #endif
2725  return true;
2726  } else {
2727 #if HALIDE_DEBUG_UNMATCHED_RULES
2728  debug(0) << instance << " does not match " << before << "\n";
2729 #endif
2730  return false;
2731  }
2732  }
2733 
2734  template<typename Before,
2735  typename After,
2736  typename Predicate,
2737  typename = typename enable_if_pattern<Before>::type,
2738  typename = typename enable_if_pattern<After>::type,
2739  typename = typename enable_if_pattern<Predicate>::type>
2740  HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) {
2741  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2742  static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2743  static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
2744  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2745  static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2746 
2747 #if HALIDE_FUZZ_TEST_RULES
2748  fuzz_test_rule(before, after, pred, wildcard_type, output_type);
2749 #endif
2750  if (before.template match<0>(unwrap(instance), state) &&
2751  evaluate_predicate(pred, state)) {
2752 #if HALIDE_DEBUG_MATCHED_RULES
2753  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2754 #endif
2755  build_replacement(after);
2756  return true;
2757  } else {
2758 #if HALIDE_DEBUG_UNMATCHED_RULES
2759  debug(0) << instance << " does not match " << before << "\n";
2760 #endif
2761  return false;
2762  }
2763  }
2764 
2765  template<typename Before,
2766  typename Predicate,
2767  typename = typename enable_if_pattern<Before>::type,
2768  typename = typename enable_if_pattern<Predicate>::type>
2769  HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) {
2770  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2771  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2772 
2773  if (before.template match<0>(unwrap(instance), state) &&
2774  evaluate_predicate(pred, state)) {
2775  result = after;
2776 #if HALIDE_DEBUG_MATCHED_RULES
2777  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2778 #endif
2779  return true;
2780  } else {
2781 #if HALIDE_DEBUG_UNMATCHED_RULES
2782  debug(0) << instance << " does not match " << before << "\n";
2783 #endif
2784  return false;
2785  }
2786  }
2787 
2788  template<typename Before,
2789  typename Predicate,
2790  typename = typename enable_if_pattern<Before>::type,
2791  typename = typename enable_if_pattern<Predicate>::type>
2792  HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) {
2793  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2794  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2795 #if HALIDE_FUZZ_TEST_RULES
2796  fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type);
2797 #endif
2798  if (before.template match<0>(unwrap(instance), state) &&
2799  evaluate_predicate(pred, state)) {
2800  result = make_const(output_type, after);
2801 #if HALIDE_DEBUG_MATCHED_RULES
2802  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2803 #endif
2804  return true;
2805  } else {
2806 #if HALIDE_DEBUG_UNMATCHED_RULES
2807  debug(0) << instance << " does not match " << before << "\n";
2808 #endif
2809  return false;
2810  }
2811  }
2812 };
2813 
/** Construct a rewriter for the given instance, which may be a pattern
 * with concrete expressions as leaves, or just an expression. The
 * optional wildcard_type argument is a hint as to what the type of the
 * wildcards is likely to be. If omitted, it defaults to the output type
 * (for the Expr overloads, the type of the expression itself). The
 * wildcards are not required to be this type, but the rule will only be
 * tested for wildcards of that type when testing is enabled.
 *
 * The rewriter can be used to check whether the instance matches one of
 * a number of patterns and, if so, rewrite it into another form, using
 * its operator() method. See Simplify.cpp for many usage examples.
 *
 * Important: Any Exprs in patterns are captured by reference, not by
 * value, so ensure they outlive the rewriter.
 */
// @{
2831 template<typename Instance,
2832  typename = typename enable_if_pattern<Instance>::type>
2833 HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
2834  return {pattern_arg(instance), output_type, wildcard_type};
2835 }
2836 
2837 template<typename Instance,
2838  typename = typename enable_if_pattern<Instance>::type>
2839 HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
2840  return {pattern_arg(instance), output_type, output_type};
2841 }
2842 
2844 auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
2845  return {pattern_arg(e), e.type(), wildcard_type};
2846 }
2847 
2849 auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
2850  return {pattern_arg(e), e.type(), e.type()};
2851 }
2852 // @}
2853 
2854 } // namespace IRMatcher
2855 
2856 } // namespace Internal
2857 } // namespace Halide
2858 
2859 #endif
#define internal_error
Definition: Errors.h:23
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
Definition: HalideRuntime.h:39
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:38
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
Definition: Debug.h:49
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition: IRMatch.h:229
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1552
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1544
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition: IRMatch.h:2833
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition: IRMatch.h:602
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition: IRMatch.h:1310
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:1609
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1111
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2323
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition: IRMatch.h:2626
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1067
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition: IRMatch.h:1285
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition: IRMatch.h:1986
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1205
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:956
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2441
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition: IRMatch.h:1336
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition: IRMatch.h:1898
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition: IRMatch.h:1185
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2200
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
Definition: IRMatch.h:1503
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1215
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1022
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1536
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1556
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition: IRMatch.h:963
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition: IRMatch.h:1062
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1520
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition: IRMatch.h:1029
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1133
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1822
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1055
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1516
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1096
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1341
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition: IRMatch.h:594
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1180
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:2032
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition: IRMatch.h:2118
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1508
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition: IRMatch.h:611
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1082
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1003
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition: IRMatch.h:2406
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition: IRMatch.h:2081
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition: IRMatch.h:1615
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1528
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1140
constexpr bool and_reduce()
Definition: IRMatch.h:1365
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1305
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1512
constexpr int max_wild
Definition: IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1280
HALIDE_ALWAYS_INLINE bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept
Definition: IRMatch.h:195
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition: IRMatch.h:2285
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1230
bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1155
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1331
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition: IRMatch.h:1904
constexpr bool commutative(IRNodeType t)
Definition: IRMatch.h:650
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition: IRMatch.h:996
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition: IRMatch.h:1892
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition: IRMatch.h:1758
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition: IRMatch.h:1685
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2478
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1118
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition: IRMatch.h:2495
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1190
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1532
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1524
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1036
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2365
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1560
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1548
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1240
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:989
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition: IRMatch.h:1210
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition: IRMatch.h:1160
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2206
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1165
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition: IRMatch.h:1886
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition: IRMatch.h:1880
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1315
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition: IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition: IRMatch.h:1356
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1564
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1265
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition: IRMatch.h:149
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition: IRMatch.h:1235
constexpr int const_min(int a, int b)
Definition: IRMatch.h:1375
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1290
auto rounding_halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1540
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition: IRMatch.h:1089
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1255
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:970
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition: IRMatch.h:2248
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition: IRMatch.h:1260
T div_imp(T a, T b)
Definition: IROperator.h:257
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
constexpr IRNodeType StrongestExprNodeType
Definition: Expr.h:79
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition: IROperator.h:236
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have `lanes` lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition: Expr.h:25
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the inner loop with an if statement that prevents evaluation beyond the original extent,...
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
@ C
No name mangling.
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
Expr abs(Expr a)
Returns the absolute value of a signed integer or floating-point expression.
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition: Expr.h:256
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition: Expr.h:320
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:314
The sum of two expressions.
Definition: IR.h:38
Logical and - are both expressions true.
Definition: IR.h:157
A base class for expression nodes.
Definition: Expr.h:141
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:241
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition: IR.h:247
A function call.
Definition: IR.h:464
@ signed_integer_overflow
Definition: IR.h:553
@ rounding_mul_shift_right
Definition: IR.h:544
bool is_intrinsic() const
Definition: IR.h:645
static const IRNodeType _node_type
Definition: IR.h:677
The actual IR nodes begin here.
Definition: IR.h:29
static const IRNodeType _node_type
Definition: IR.h:34
The ratio of two expressions.
Definition: IR.h:65
Is the first expression equal to the second.
Definition: IR.h:103
Floating point constants.
Definition: Expr.h:234
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Definition: IR.h:148
Is the first expression greater than the second.
Definition: IR.h:139
constexpr static uint32_t binds
Definition: IRMatch.h:668
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:671
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:699
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:680
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:670
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:742
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:690
constexpr static bool canonical
Definition: IRMatch.h:676
constexpr static bool foldable
Definition: IRMatch.h:696
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1724
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1718
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1706
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1700
constexpr static uint32_t binds
Definition: IRMatch.h:1698
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1741
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1701
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2237
constexpr static bool foldable
Definition: IRMatch.h:2234
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2230
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2231
constexpr static uint32_t binds
Definition: IRMatch.h:2227
constexpr static bool canonical
Definition: IRMatch.h:2232
constexpr static bool canonical
Definition: IRMatch.h:2001
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2000
constexpr static bool foldable
Definition: IRMatch.h:2022
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2004
constexpr static uint32_t binds
Definition: IRMatch.h:1997
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1999
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:2013
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2018
constexpr static bool canonical
Definition: IRMatch.h:783
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:843
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:781
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:782
constexpr static bool foldable
Definition: IRMatch.h:806
constexpr static uint32_t binds
Definition: IRMatch.h:779
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:790
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:809
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:800
constexpr static bool foldable
Definition: IRMatch.h:2072
constexpr static uint32_t binds
Definition: IRMatch.h:2042
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2044
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2045
constexpr static bool canonical
Definition: IRMatch.h:2046
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:2049
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2075
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:539
constexpr static bool canonical
Definition: IRMatch.h:531
constexpr static uint32_t binds
Definition: IRMatch.h:527
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition: IRMatch.h:534
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition: IRMatch.h:562
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:574
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:529
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:530
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:567
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition: IRMatch.h:557
constexpr static bool foldable
Definition: IRMatch.h:571
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1401
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1434
constexpr static bool canonical
Definition: IRMatch.h:1389
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition: IRMatch.h:1429
constexpr static bool foldable
Definition: IRMatch.h:1486
static constexpr uint32_t binds
Definition: IRMatch.h:1385
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1394
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition: IRMatch.h:1416
std::tuple< Args... > args
Definition: IRMatch.h:1383
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1406
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition: IRMatch.h:1425
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
Definition: IRMatch.h:1489
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1388
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1387
constexpr static bool canonical
Definition: IRMatch.h:2177
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2186
constexpr static bool foldable
Definition: IRMatch.h:2183
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2176
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2175
constexpr static uint32_t binds
Definition: IRMatch.h:2172
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2267
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2274
constexpr static bool canonical
Definition: IRMatch.h:2269
constexpr static uint32_t binds
Definition: IRMatch.h:2264
constexpr static bool foldable
Definition: IRMatch.h:2271
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2268
constexpr static uint32_t binds
Definition: IRMatch.h:2302
constexpr static bool foldable
Definition: IRMatch.h:2309
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2305
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2312
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2306
constexpr static bool canonical
Definition: IRMatch.h:2307
constexpr static bool canonical
Definition: IRMatch.h:2421
constexpr static bool foldable
Definition: IRMatch.h:2423
constexpr static uint32_t binds
Definition: IRMatch.h:2416
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2419
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2420
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2426
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2454
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2455
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2461
constexpr static bool canonical
Definition: IRMatch.h:2456
constexpr static uint32_t binds
Definition: IRMatch.h:2451
constexpr static bool foldable
Definition: IRMatch.h:2458
constexpr static bool foldable
Definition: IRMatch.h:2392
constexpr static bool canonical
Definition: IRMatch.h:2390
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2395
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2389
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2388
constexpr static uint32_t binds
Definition: IRMatch.h:2385
constexpr static bool canonical
Definition: IRMatch.h:2349
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2354
constexpr static uint32_t binds
Definition: IRMatch.h:2344
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2347
constexpr static bool foldable
Definition: IRMatch.h:2351
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2348
To save stack space, the matcher objects are largely stateless and immutable.
Definition: IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition: IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition: IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition: IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition: IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition: IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition: IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition: IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition: IRMatch.h:134
halide_scalar_value_t bound_const[max_wild]
Definition: IRMatch.h:84
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition: IRMatch.h:98
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition: IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition: IRMatch.h:87
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1916
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1917
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1922
constexpr static uint32_t binds
Definition: IRMatch.h:1914
constexpr static bool canonical
Definition: IRMatch.h:1919
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1937
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition: IRMatch.h:1932
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1946
constexpr static bool foldable
Definition: IRMatch.h:1943
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1575
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1580
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1589
constexpr static uint32_t binds
Definition: IRMatch.h:1573
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1576
constexpr static bool foldable
Definition: IRMatch.h:1598
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1594
constexpr static bool canonical
Definition: IRMatch.h:1577
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1601
constexpr static bool canonical
Definition: IRMatch.h:2137
constexpr static bool foldable
Definition: IRMatch.h:2154
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2140
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2149
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2136
constexpr static uint32_t binds
Definition: IRMatch.h:2132
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2157
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2135
constexpr static bool foldable
Definition: IRMatch.h:2105
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2108
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2102
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2101
constexpr static uint32_t binds
Definition: IRMatch.h:2097
constexpr static bool canonical
Definition: IRMatch.h:2103
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1800
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1773
constexpr static bool canonical
Definition: IRMatch.h:1775
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1772
constexpr static bool foldable
Definition: IRMatch.h:1812
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1793
constexpr static uint32_t binds
Definition: IRMatch.h:1770
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1778
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition: IRMatch.h:2666
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition: IRMatch.h:2740
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition: IRMatch.h:2715
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition: IRMatch.h:2661
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition: IRMatch.h:2769
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition: IRMatch.h:2697
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition: IRMatch.h:2792
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition: IRMatch.h:2674
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1636
constexpr static bool canonical
Definition: IRMatch.h:1638
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1665
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1635
constexpr static uint32_t binds
Definition: IRMatch.h:1633
constexpr static bool foldable
Definition: IRMatch.h:1662
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition: IRMatch.h:1651
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1641
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1658
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:210
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:217
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:222
constexpr static uint32_t binds
Definition: IRMatch.h:207
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:211
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1855
constexpr static uint32_t binds
Definition: IRMatch.h:1835
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1838
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1842
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1837
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1862
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:364
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:360
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:385
constexpr static uint32_t binds
Definition: IRMatch.h:357
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:359
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:395
constexpr static uint32_t binds
Definition: IRMatch.h:411
constexpr static bool foldable
Definition: IRMatch.h:450
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:443
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:414
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:418
constexpr static bool canonical
Definition: IRMatch.h:415
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:453
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:413
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition: IRMatch.h:437
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:240
constexpr static uint32_t binds
Definition: IRMatch.h:238
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:279
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:289
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:245
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition: IRMatch.h:266
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:241
constexpr static uint32_t binds
Definition: IRMatch.h:304
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:307
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:306
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:311
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:342
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:332
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:472
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:491
constexpr static bool foldable
Definition: IRMatch.h:489
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:485
constexpr static bool canonical
Definition: IRMatch.h:473
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:471
constexpr static uint32_t binds
Definition: IRMatch.h:469
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:476
constexpr static uint32_t mask
Definition: IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:111
Integer constants.
Definition: Expr.h:216
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Definition: IR.h:130
Is the first expression less than the second.
Definition: IR.h:121
The greater of two values.
Definition: IR.h:94
The lesser of two values.
Definition: IR.h:85
The remainder of a / b.
Definition: IR.h:76
The product of two expressions.
Definition: IR.h:56
Is the first expression not equal to the second.
Definition: IR.h:112
Logical not - true if the expression is false.
Definition: IR.h:175
static Expr make(Expr a)
Logical or - is at least one of the expression true.
Definition: IR.h:166
A linear ramp vector node.
Definition: IR.h:229
static const IRNodeType _node_type
Definition: IR.h:235
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition: IR.h:186
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition: IR.h:191
The difference of two expressions.
Definition: IR.h:47
static const IRNodeType _node_type
Definition: IR.h:52
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition: Expr.h:225
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition: IR.h:871
static const IRNodeType _node_type
Definition: IR.h:890
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Definition: Type.h:269
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition: Type.h:406
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition: Type.h:337
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition: Type.h:412
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition: Type.h:331
HALIDE_ALWAYS_INLINE bool is_vector() const
Is this type a vector type? (lanes() != 1).
Definition: Type.h:381
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition: Type.h:388
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition: Type.h:394
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@3 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.