1 #ifndef HALIDE_IR_MATCH_H
2 #define HALIDE_IR_MATCH_H
139 typename =
typename std::remove_reference<T>::type::pattern_tag>
146 constexpr
static uint32_t mask = std::remove_reference<T>::type::binds;
166 const int lanes = scalar_type.
lanes;
167 scalar_type.
lanes = 1;
170 switch (scalar_type.
code) {
198 ((a.type == b.type) &&
199 (a.node_type == b.node_type) &&
216 template<u
int32_t bound>
244 template<u
int32_t bound>
246 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
249 op = ((
const Broadcast *)op)->value.get();
258 state.get_bound_const(i, val, type);
261 state.set_bound_const(i, value, e.type);
265 template<u
int32_t bound>
267 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
271 state.get_bound_const(i, val, type);
272 return type == i64_type && value == val.
u.
i64;
274 state.set_bound_const(i, value, i64_type);
310 template<u
int32_t bound>
312 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
315 op = ((
const Broadcast *)op)->value.get();
324 state.get_bound_const(i, val, type);
327 state.set_bound_const(i, value, e.type);
343 state.get_bound_const(i, val, ty);
363 template<u
int32_t bound>
365 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
368 op = ((
const Broadcast *)op)->value.get();
373 double value = ((
const FloatImm *)op)->value;
377 state.get_bound_const(i, val, type);
380 state.set_bound_const(i, value, e.type);
396 state.get_bound_const(i, val, ty);
417 template<u
int32_t bound>
419 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
422 op = ((
const Broadcast *)op)->value.get();
436 template<u
int32_t bound>
438 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
454 state.get_bound_const(i, val, ty);
475 template<u
int32_t bound>
478 return equal(*state.get_binding(i), e);
480 state.set_binding(i, e);
492 const auto *e = state.get_binding(i);
494 switch (e->node_type) {
496 val.u.u64 = ((
const UIntImm *)e)->value;
499 val.u.i64 = ((
const IntImm *)e)->value;
502 val.u.f64 = ((
const FloatImm *)e)->value;
538 template<u
int32_t bound>
542 op = ((
const Broadcast *)op)->value.get();
550 return ((
const FloatImm *)op)->value == (
double)
v;
556 template<u
int32_t bound>
561 template<u
int32_t bound>
585 val.u.f64 = (double)
v;
601 typename =
typename std::decay<T>::type::pattern_tag>
612 static_assert(!std::is_same<
typename std::decay<T>::type,
Expr>::value || std::is_lvalue_reference<T>::value,
613 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
624 typename =
typename std::decay<T>::type::pattern_tag,
626 typename =
typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
641 template<
typename Op>
644 template<
typename Op>
647 template<
typename Op>
662 template<
typename Op,
typename A,
typename B>
677 A::canonical && B::canonical && (!
commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
679 template<u
int32_t bound>
681 if (e.node_type != Op::_node_type) {
684 const Op &op = (
const Op &)e;
685 return (
a.template match<bound>(*op.a.get(), state) &&
686 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
689 template<u
int32_t bound,
typename Op2,
typename A2,
typename B2>
691 return (std::is_same<Op, Op2>::value &&
692 a.template match<bound>(
unwrap(op.a), state) &&
696 constexpr
static bool foldable = A::foldable && B::foldable;
701 if (std::is_same<A, IntLiteral>::value) {
702 b.make_folded_const(val_b, ty, state);
703 if ((std::is_same<Op, And>::value && val_b.
u.
u64 == 0) ||
704 (std::is_same<Op, Or>::value && val_b.
u.
u64 == 1)) {
710 a.make_folded_const(val_a, ty, state);
713 a.make_folded_const(val_a, ty, state);
714 if ((std::is_same<Op, And>::value && val_a.
u.
u64 == 0) ||
715 (std::is_same<Op, Or>::value && val_a.
u.
u64 == 1)) {
721 b.make_folded_const(val_b, ty, state);
726 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.
u.
i64, val_b.
u.
i64);
729 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.
u.
u64, val_b.
u.
u64);
733 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.
u.
f64, val_b.
u.
f64);
744 if (std::is_same<A, IntLiteral>::value) {
745 eb =
b.make(state, type_hint);
746 ea =
a.make(state, eb.
type());
748 ea =
a.make(state, type_hint);
749 eb =
b.make(state, ea.
type());
759 return Op::make(std::move(ea), std::move(eb));
763 template<
typename Op>
766 template<
typename Op>
769 template<
typename Op>
773 template<
typename Op,
typename A,
typename B>
785 (!
commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
789 template<u
int32_t bound>
791 if (e.node_type != Op::_node_type) {
794 const Op &op = (
const Op &)e;
795 return (
a.template match<bound>(*op.a.get(), state) &&
796 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
799 template<u
int32_t bound,
typename Op2,
typename A2,
typename B2>
801 return (std::is_same<Op, Op2>::value &&
802 a.template match<bound>(
unwrap(op.a), state) &&
806 constexpr
static bool foldable = A::foldable && B::foldable;
812 if (std::is_same<A, IntLiteral>::value) {
813 b.make_folded_const(val_b, ty, state);
815 a.make_folded_const(val_a, ty, state);
818 a.make_folded_const(val_a, ty, state);
820 b.make_folded_const(val_b, ty, state);
825 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
i64, val_b.
u.
i64);
828 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
u64, val_b.
u.
u64);
832 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
f64, val_b.
u.
f64);
846 if (std::is_same<A, IntLiteral>::value) {
847 eb =
b.make(state, {});
848 ea =
a.make(state, eb.
type());
850 ea =
a.make(state, {});
851 eb =
b.make(state, ea.
type());
861 return Op::make(std::move(ea), std::move(eb));
865 template<
typename A,
typename B>
867 s <<
"(" << op.
a <<
" + " << op.
b <<
")";
871 template<
typename A,
typename B>
873 s <<
"(" << op.
a <<
" - " << op.
b <<
")";
877 template<
typename A,
typename B>
879 s <<
"(" << op.
a <<
" * " << op.
b <<
")";
883 template<
typename A,
typename B>
885 s <<
"(" << op.
a <<
" / " << op.
b <<
")";
889 template<
typename A,
typename B>
891 s <<
"(" << op.
a <<
" && " << op.
b <<
")";
895 template<
typename A,
typename B>
897 s <<
"(" << op.
a <<
" || " << op.
b <<
")";
901 template<
typename A,
typename B>
903 s <<
"min(" << op.
a <<
", " << op.
b <<
")";
907 template<
typename A,
typename B>
909 s <<
"max(" << op.
a <<
", " << op.
b <<
")";
913 template<
typename A,
typename B>
915 s <<
"(" << op.
a <<
" <= " << op.
b <<
")";
919 template<
typename A,
typename B>
921 s <<
"(" << op.
a <<
" < " << op.
b <<
")";
925 template<
typename A,
typename B>
927 s <<
"(" << op.
a <<
" >= " << op.
b <<
")";
931 template<
typename A,
typename B>
933 s <<
"(" << op.
a <<
" > " << op.
b <<
")";
937 template<
typename A,
typename B>
939 s <<
"(" << op.
a <<
" == " << op.
b <<
")";
943 template<
typename A,
typename B>
945 s <<
"(" << op.
a <<
" != " << op.
b <<
")";
949 template<
typename A,
typename B>
951 s <<
"(" << op.
a <<
" % " << op.
b <<
")";
955 template<
typename A,
typename B>
957 assert_is_lvalue_if_expr<A>();
958 assert_is_lvalue_if_expr<B>();
962 template<
typename A,
typename B>
964 assert_is_lvalue_if_expr<A>();
965 assert_is_lvalue_if_expr<B>();
972 int dead_bits = 64 - t.bits;
980 return (a + b) & (ones >> (64 - t.bits));
988 template<
typename A,
typename B>
990 assert_is_lvalue_if_expr<A>();
991 assert_is_lvalue_if_expr<B>();
995 template<
typename A,
typename B>
997 assert_is_lvalue_if_expr<A>();
998 assert_is_lvalue_if_expr<B>();
1006 int dead_bits = 64 - t.bits;
1013 return (a - b) & (ones >> (64 - t.bits));
1021 template<
typename A,
typename B>
1023 assert_is_lvalue_if_expr<A>();
1024 assert_is_lvalue_if_expr<B>();
1028 template<
typename A,
typename B>
1030 assert_is_lvalue_if_expr<A>();
1031 assert_is_lvalue_if_expr<B>();
1038 int dead_bits = 64 - t.bits;
1046 return (a * b) & (ones >> (64 - t.bits));
1054 template<
typename A,
typename B>
1056 assert_is_lvalue_if_expr<A>();
1057 assert_is_lvalue_if_expr<B>();
1061 template<
typename A,
typename B>
1081 template<
typename A,
typename B>
1083 assert_is_lvalue_if_expr<A>();
1084 assert_is_lvalue_if_expr<B>();
1088 template<
typename A,
typename B>
1090 assert_is_lvalue_if_expr<A>();
1091 assert_is_lvalue_if_expr<B>();
1110 template<
typename A,
typename B>
1112 assert_is_lvalue_if_expr<A>();
1113 assert_is_lvalue_if_expr<B>();
1132 template<
typename A,
typename B>
1134 assert_is_lvalue_if_expr<A>();
1135 assert_is_lvalue_if_expr<B>();
1154 template<
typename A,
typename B>
1159 template<
typename A,
typename B>
1179 template<
typename A,
typename B>
1184 template<
typename A,
typename B>
1204 template<
typename A,
typename B>
1209 template<
typename A,
typename B>
1229 template<
typename A,
typename B>
1234 template<
typename A,
typename B>
1254 template<
typename A,
typename B>
1259 template<
typename A,
typename B>
1279 template<
typename A,
typename B>
1284 template<
typename A,
typename B>
1304 template<
typename A,
typename B>
1309 template<
typename A,
typename B>
1330 template<
typename A,
typename B>
1335 template<
typename A,
typename B>
1360 template<
typename... Args>
1369 template<
typename... Args>
1376 return a < b ? a : b;
1379 template<
typename... Args>
1393 typename =
typename std::enable_if<(i <
sizeof...(Args))>::type>
1395 using T = decltype(std::get<i>(
args));
1396 return (std::get<i>(
args).
template match<bound>(*c.args[i].get(), state) &&
1400 template<
int i, u
int32_t binds>
1405 template<u
int32_t bound>
1415 typename =
typename std::enable_if<(i <
sizeof...(Args))>::type>
1417 s << std::get<i>(
args);
1418 if (i + 1 <
sizeof...(Args)) {
1421 print_args<i + 1>(0, s);
1430 print_args<0>(0, s);
1435 Expr arg0 = std::get<0>(
args).make(state, type_hint);
1446 return absd(arg0, arg1);
1466 return arg0 << arg1;
1468 return arg0 >> arg1;
1494 template<
typename... Args>
1502 template<
typename... Args>
1507 template<
typename A,
typename B>
1511 template<
typename A,
typename B>
1515 template<
typename A,
typename B>
1519 template<
typename A,
typename B>
1523 template<
typename A,
typename B>
1527 template<
typename A,
typename B>
1531 template<
typename A,
typename B>
1535 template<
typename A,
typename B>
1539 template<
typename A,
typename B>
1543 template<
typename A,
typename B>
1547 template<
typename A,
typename B>
1551 template<
typename A,
typename B>
1555 template<
typename A,
typename B>
1559 template<
typename A,
typename B,
typename C>
1563 template<
typename A,
typename B,
typename C>
1568 template<
typename A>
1579 template<u
int32_t bound>
1584 const Not &op = (
const Not &)e;
1585 return (
a.template match<bound>(*op.
a.
get(), state));
1588 template<u
int32_t bound,
typename A2>
1590 return a.template match<bound>(
unwrap(op.a), state);
1600 template<
typename A1 = A>
1602 a.make_folded_const(val, ty, state);
1603 val.u.u64 = ~val.u.u64;
1608 template<
typename A>
1610 assert_is_lvalue_if_expr<A>();
1614 template<
typename A>
1616 assert_is_lvalue_if_expr<A>();
1620 template<
typename A>
1622 s <<
"!(" << op.
a <<
")";
1626 template<
typename C,
typename T,
typename F>
1638 constexpr
static bool canonical = C::canonical && T::canonical && F::canonical;
1640 template<u
int32_t bound>
1646 return (
c.template match<bound>(*op.
condition.
get(), state) &&
1647 t.template match<bound | bindings<C>::mask>(*op.
true_value.
get(), state) &&
1650 template<u
int32_t bound,
typename C2,
typename T2,
typename F2>
1652 return (
c.template match<bound>(
unwrap(instance.c), state) &&
1659 return Select::make(
c.make(state, {}),
t.make(state, type_hint),
f.make(state, type_hint));
1662 constexpr
static bool foldable = C::foldable && T::foldable && F::foldable;
1664 template<
typename C1 = C>
1668 c.make_folded_const(c_val, c_ty, state);
1669 if ((c_val.
u.
u64 & 1) == 1) {
1670 t.make_folded_const(val, ty, state);
1672 f.make_folded_const(val, ty, state);
1678 template<
typename C,
typename T,
typename F>
1680 s <<
"select(" << op.
c <<
", " << op.
t <<
", " << op.
f <<
")";
1684 template<
typename C,
typename T,
typename F>
1686 assert_is_lvalue_if_expr<C>();
1687 assert_is_lvalue_if_expr<T>();
1688 assert_is_lvalue_if_expr<F>();
1692 template<
typename A,
typename B>
1703 constexpr
static bool canonical = A::canonical && B::canonical;
1705 template<u
int32_t bound>
1709 if (
a.template match<bound>(*op.
value.
get(), state) &&
1710 lanes.template match<bound>(op.
lanes, state)) {
1717 template<u
int32_t bound,
typename A2,
typename B2>
1719 return (
a.template match<bound>(
unwrap(op.a), state) &&
1727 lanes.make_folded_const(lanes_val, ty, state);
1729 type_hint.
lanes /= l;
1730 Expr val =
a.make(state, type_hint);
1740 template<
typename A1 = A>
1744 lanes.make_folded_const(lanes_val, lanes_ty, state);
1746 a.make_folded_const(val, ty, state);
1751 template<
typename A,
typename B>
1753 s <<
"broadcast(" << op.
a <<
", " << op.
lanes <<
")";
1757 template<
typename A,
typename B>
1759 assert_is_lvalue_if_expr<A>();
1763 template<
typename A,
typename B,
typename C>
1775 constexpr
static bool canonical = A::canonical && B::canonical && C::canonical;
1777 template<u
int32_t bound>
1783 if (
a.template match<bound>(*op.
base.
get(), state) &&
1784 b.template match<bound | bindings<A>::mask>(*op.
stride.
get(), state) &&
1792 template<u
int32_t bound,
typename A2,
typename B2,
typename C2>
1794 return (
a.template match<bound>(
unwrap(op.a), state) &&
1803 lanes.make_folded_const(lanes_val, ty, state);
1805 type_hint.
lanes /= l;
1807 eb =
b.make(state, type_hint);
1808 ea =
a.make(state, eb.type());
1815 template<
typename A,
typename B,
typename C>
1817 s <<
"ramp(" << op.
a <<
", " << op.
b <<
", " << op.
lanes <<
")";
1821 template<
typename A,
typename B,
typename C>
1823 assert_is_lvalue_if_expr<A>();
1824 assert_is_lvalue_if_expr<B>();
1825 assert_is_lvalue_if_expr<C>();
1829 template<
typename A,
typename B, VectorReduce::Operator reduce_op>
1841 template<u
int32_t bound>
1845 if (op.
op == reduce_op &&
1846 a.template match<bound>(*op.
value.
get(), state) &&
1847 lanes.template match<bound | bindings<A>::mask>(op.
type.
lanes(), state)) {
1854 template<u
int32_t bound,
typename A2,
typename B2, VectorReduce::Operator reduce_op_2>
1856 return (reduce_op == reduce_op_2 &&
1857 a.template match<bound>(
unwrap(op.a), state) &&
1865 lanes.make_folded_const(lanes_val, ty, state);
1866 int l = (int)lanes_val.
u.
i64;
1873 template<
typename A,
typename B, VectorReduce::Operator reduce_op>
1875 s <<
"vector_reduce(" << reduce_op <<
", " << op.
a <<
", " << op.
lanes <<
")";
1879 template<
typename A,
typename B>
1881 assert_is_lvalue_if_expr<A>();
1885 template<
typename A,
typename B>
1887 assert_is_lvalue_if_expr<A>();
1891 template<
typename A,
typename B>
1893 assert_is_lvalue_if_expr<A>();
1897 template<
typename A,
typename B>
1899 assert_is_lvalue_if_expr<A>();
1903 template<
typename A,
typename B>
1905 assert_is_lvalue_if_expr<A>();
1909 template<
typename A>
1921 template<u
int32_t bound>
1926 const Sub &op = (
const Sub &)e;
1927 return (
a.template match<bound>(*op.
b.
get(), state) &&
1931 template<u
int32_t bound,
typename A2>
1933 return a.template match<bound>(
unwrap(p.a), state);
1938 Expr ea =
a.make(state, type_hint);
1940 return Sub::make(std::move(z), std::move(ea));
1945 template<
typename A1 = A>
1947 a.make_folded_const(val, ty, state);
1948 int dead_bits = 64 - ty.bits;
1951 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1960 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
1964 val.u.f64 = -val.u.f64;
1973 template<
typename A>
1979 template<
typename A>
1981 assert_is_lvalue_if_expr<A>();
1985 template<
typename A>
1987 assert_is_lvalue_if_expr<A>();
1991 template<
typename A>
2003 template<u
int32_t bound>
2009 return (e.type ==
t &&
2010 a.template match<bound>(*op.
value.
get(), state));
2012 template<u
int32_t bound,
typename A2>
2014 return t == op.t &&
a.template match<bound>(
unwrap(op.a), state);
2019 return cast(
t,
a.make(state, {}));
2025 template<
typename A>
2027 s <<
"cast(" << op.
t <<
", " << op.
a <<
")";
2031 template<
typename A>
2033 assert_is_lvalue_if_expr<A>();
2037 template<
typename A>
2052 a.make_folded_const(c, ty, state);
2058 if (type_hint.bits) {
2062 c.
u.
f64 = (double)x;
2064 ty.
code = type_hint.code;
2065 ty.
bits = type_hint.bits;
2074 template<
typename A1 = A>
2076 a.make_folded_const(val, ty, state);
2080 template<
typename A>
2082 assert_is_lvalue_if_expr<A>();
2086 template<
typename A>
2088 s <<
"fold(" << op.
a <<
")";
2092 template<
typename A>
2107 template<
typename A1 = A>
2109 a.make_folded_const(val, ty, state);
2117 template<
typename A>
2119 assert_is_lvalue_if_expr<A>();
2123 template<
typename A>
2125 s <<
"overflows(" << op.
a <<
")";
2139 template<u
int32_t bound>
2168 template<
typename A>
2185 template<
typename A1 = A>
2187 Expr e =
a.make(state, {});
2199 template<
typename A>
2201 assert_is_lvalue_if_expr<A>();
2205 template<
typename A>
2207 assert_is_lvalue_if_expr<A>();
2211 template<
typename A>
2214 s <<
"is_const(" << op.
a <<
")";
2216 s <<
"is_const(" << op.
a <<
", " << op.
v <<
")";
2221 template<
typename A,
typename Prover>
2238 Expr condition =
a.make(state, {});
2239 condition =
prover->mutate(condition,
nullptr);
2247 template<
typename A,
typename Prover>
2249 assert_is_lvalue_if_expr<A>();
2253 template<
typename A,
typename Prover>
2255 s <<
"can_prove(" << op.
a <<
")";
2259 template<
typename A>
2276 Type t =
a.make(state, {}).type();
2284 template<
typename A>
2286 assert_is_lvalue_if_expr<A>();
2290 template<
typename A>
2292 s <<
"is_float(" << op.
a <<
")";
2296 template<
typename A>
2314 Type t =
a.make(state, {}).type();
2322 template<
typename A>
2324 assert_is_lvalue_if_expr<A>();
2328 template<
typename A>
2330 s <<
"is_int(" << op.
a;
2332 s <<
", " << op.
bits;
2338 template<
typename A>
2356 Type t =
a.make(state, {}).type();
2364 template<
typename A>
2366 assert_is_lvalue_if_expr<A>();
2370 template<
typename A>
2372 s <<
"is_uint(" << op.
a;
2374 s <<
", " << op.
bits;
2380 template<
typename A>
2397 Type t =
a.make(state, {}).type();
2405 template<
typename A>
2407 assert_is_lvalue_if_expr<A>();
2411 template<
typename A>
2428 a.make_folded_const(val, ty, state);
2431 val.
u.
u64 = (val.
u.
u64 == max_bits);
2440 template<
typename A>
2442 assert_is_lvalue_if_expr<A>();
2446 template<
typename A>
2463 a.make_folded_const(val, ty, state);
2466 val.
u.
u64 = (val.
u.
u64 == min_bits);
2477 template<
typename A>
2479 assert_is_lvalue_if_expr<A>();
2483 template<
typename A>
2485 s <<
"is_scalar(" << op.
a <<
")";
2490 template<
typename Before,
2493 typename =
typename std::enable_if<std::decay<Before>::type::foldable &&
2494 std::decay<After>::type::foldable>::type>
2499 wildcard_type.lanes = output_type.lanes = 1;
2502 static std::set<uint32_t> tested;
2504 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2509 debug(0) <<
"validate('" << before <<
"', '" << after <<
"', '" << pred <<
"', " <<
Type(wildcard_type) <<
", " <<
Type(output_type) <<
")\n";
2514 static std::mt19937_64 rng(0);
2519 for (
int trials = 0; trials < 100; trials++) {
2523 int shift = (int)(rng() & (wildcard_type.bits - 1));
2525 for (
int i = 0; i <
max_wild; i++) {
2527 switch (wildcard_type.code) {
2547 double val = ((
int64_t)(rng() & 15) - 8) / 2.0;
2549 val = ((
int64_t)(rng() & 15) - 8) / 2.0;
2563 before.make_folded_const(val_before, type, state);
2565 after.make_folded_const(val_after, type, state);
2566 lanes |= type.
lanes;
2573 switch (output_type.code) {
2588 ok &= (error < 0.01 ||
2589 val_before.
u.
u64 == val_after.
u.
u64 ||
2590 std::isnan(val_before.
u.
f64));
2598 debug(0) <<
"Fails with values:\n";
2599 for (
int i = 0; i <
max_wild; i++) {
2604 for (
int i = 0; i <
max_wild; i++) {
2609 debug(0) << val_before.
u.
u64 <<
" " << val_after.
u.
u64 <<
"\n";
2615 template<
typename Before,
2618 typename =
typename std::enable_if<!(std::decay<Before>::type::foldable &&
2619 std::decay<After>::type::foldable)>::type>
2630 template<
typename Pattern,
2631 typename =
typename enable_if_pattern<Pattern>::type>
2635 p.make_folded_const(c, ty, state);
2643 #define HALIDE_DEBUG_MATCHED_RULES 0
2644 #define HALIDE_DEBUG_UNMATCHED_RULES 0
2650 #define HALIDE_FUZZ_TEST_RULES 0
2652 template<
typename Instance>
2665 template<
typename After>
2670 template<
typename Before,
2675 static_assert((Before::binds & After::binds) == After::binds,
"Rule result uses unbound values");
2676 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2677 static_assert(After::canonical,
"RHS of rewrite rule should be in canonical form");
2678 #if HALIDE_FUZZ_TEST_RULES
2682 #if HALIDE_DEBUG_MATCHED_RULES
2688 #if HALIDE_DEBUG_UNMATCHED_RULES
2689 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2695 template<
typename Before,
2698 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2701 #if HALIDE_DEBUG_MATCHED_RULES
2706 #if HALIDE_DEBUG_UNMATCHED_RULES
2707 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2713 template<
typename Before,
2716 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2717 #if HALIDE_FUZZ_TEST_RULES
2722 #if HALIDE_DEBUG_MATCHED_RULES
2727 #if HALIDE_DEBUG_UNMATCHED_RULES
2728 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2734 template<
typename Before,
2741 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2742 static_assert((Before::binds & After::binds) == After::binds,
"Rule result uses unbound values");
2743 static_assert((Before::binds & Predicate::binds) == Predicate::binds,
"Rule predicate uses unbound values");
2744 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2745 static_assert(After::canonical,
"RHS of rewrite rule should be in canonical form");
2747 #if HALIDE_FUZZ_TEST_RULES
2752 #if HALIDE_DEBUG_MATCHED_RULES
2753 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2758 #if HALIDE_DEBUG_UNMATCHED_RULES
2759 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2765 template<
typename Before,
2770 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2771 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2776 #if HALIDE_DEBUG_MATCHED_RULES
2777 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2781 #if HALIDE_DEBUG_UNMATCHED_RULES
2782 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2788 template<
typename Before,
2793 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2794 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2795 #if HALIDE_FUZZ_TEST_RULES
2801 #if HALIDE_DEBUG_MATCHED_RULES
2802 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2806 #if HALIDE_DEBUG_UNMATCHED_RULES
2807 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2831 template<
typename Instance,
2832 typename =
typename enable_if_pattern<Instance>::type>
2834 return {
pattern_arg(instance), output_type, wildcard_type};
2837 template<
typename Instance,
2838 typename =
typename enable_if_pattern<Instance>::type>
2840 return {
pattern_arg(instance), output_type, output_type};
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0) noexcept -> IsInt< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
constexpr bool and_reduce()
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
constexpr bool commutative(IRNodeType t)
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
constexpr uint32_t bitwise_or_reduce()
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
constexpr int const_min(int a, int b)
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
auto rounding_halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
constexpr IRNodeType StrongestExprNodeType
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have lanes lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the inner loop with an if statement that prevents evaluation beyond the original extent,...
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
Expr abs(Expr a)
Returns the absolute value of a signed integer or floating-point expression.
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
The sum of two expressions.
Logical and - are both expressions true.
A base class for expression nodes.
A vector with 'lanes' elements, in which every element is 'value'.
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
@ signed_integer_overflow
@ rounding_mul_shift_right
bool is_intrinsic() const
static const IRNodeType _node_type
The actual IR nodes begin here.
static const IRNodeType _node_type
The ratio of two expressions.
Is the first expression equal to the second.
Floating point constants.
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Is the first expression greater than the second.
constexpr static uint32_t binds
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
constexpr static bool canonical
constexpr static bool foldable
constexpr static bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType max_node_type
constexpr static bool canonical
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static bool foldable
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
constexpr static uint32_t binds
constexpr static bool canonical
constexpr static bool canonical
constexpr static IRNodeType max_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
constexpr static bool foldable
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
constexpr static bool foldable
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
constexpr static bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static bool canonical
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
constexpr static bool foldable
static constexpr uint32_t binds
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
std::tuple< Args... > args
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static bool foldable
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static bool canonical
constexpr static uint32_t binds
constexpr static bool foldable
constexpr static IRNodeType max_node_type
constexpr static uint32_t binds
constexpr static bool foldable
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static IRNodeType max_node_type
constexpr static bool canonical
constexpr static bool canonical
constexpr static bool foldable
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static bool canonical
constexpr static uint32_t binds
constexpr static bool foldable
constexpr static bool foldable
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
constexpr static bool foldable
constexpr static IRNodeType max_node_type
To save stack space, the matcher objects are largely stateless and immutable.
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
static constexpr uint16_t special_values_mask
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
halide_type_t bound_const_type[max_wild]
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
HALIDE_ALWAYS_INLINE MatcherState() noexcept
halide_scalar_value_t bound_const[max_wild]
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
static constexpr uint16_t signed_integer_overflow
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static uint32_t binds
constexpr static bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static bool foldable
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
constexpr static uint32_t binds
constexpr static IRNodeType max_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static bool canonical
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static IRNodeType max_node_type
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
constexpr static bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static IRNodeType max_node_type
constexpr static bool canonical
constexpr static IRNodeType min_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_NEVER_INLINE void build_replacement(After after)
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
halide_type_t wildcard_type
halide_type_t output_type
constexpr static IRNodeType max_node_type
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static IRNodeType min_node_type
constexpr static bool canonical
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static uint32_t binds
const BaseExprNode & expr
constexpr static IRNodeType max_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
constexpr static uint32_t binds
constexpr static IRNodeType max_node_type
constexpr static bool canonical
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool foldable
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static uint32_t binds
constexpr static bool canonical
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static uint32_t binds
constexpr static bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
constexpr static bool canonical
constexpr static bool foldable
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
constexpr static IRNodeType max_node_type
constexpr static uint32_t binds
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static bool foldable
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static uint32_t mask
IRNodeType node_type
Each IR node subclass has a unique identifier.
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Is the first expression less than the second.
The greater of two values.
The lesser of two values.
The product of two expressions.
Is the first expression not equal to the second.
Logical not - true if the expression false.
Logical or - is at least one of the expression true.
A linear ramp vector node.
static const IRNodeType _node_type
static Expr make(Expr base, Expr stride, int lanes)
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
The difference of two expressions.
static const IRNodeType _node_type
static Expr make(Expr a, Expr b)
Unsigned integer constants.
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
static const IRNodeType _node_type
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
HALIDE_ALWAYS_INLINE bool is_vector() const
Is this type a vector type? (lanes() != 1).
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@3 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.