Halide  14.0.0
Halide compiler and libraries
Simplify_Internal.h
Go to the documentation of this file.
1 #ifndef HALIDE_SIMPLIFY_VISITORS_H
2 #define HALIDE_SIMPLIFY_VISITORS_H
3 
4 /** \file
5  * The simplifier is separated into multiple compilation units with
6  * this single shared header to speed up the build. This file is not
7  * exported in Halide.h. */
8 
9 #include "Bounds.h"
10 #include "IRMatch.h"
11 #include "IRVisitor.h"
12 #include "Scope.h"
13 
14 // Because this file is only included by the simplify methods and
15 // doesn't go into Halide.h, we're free to use any old names for our
16 // macros.
17 
18 #define LOG_EXPR_MUTATIONS 0
19 #define LOG_STMT_MUTATIONS 0
20 
21 // On old compilers, some visitors would use large stack frames,
22 // because they use expression templates that generate large numbers
23 // of temporary objects when they are built and matched against. If we
24 // wrap the expressions that imply lots of temporaries in a lambda, we
25 // can get these large frames out of the recursive path.
26 #define EVAL_IN_LAMBDA(x) (([&]() HALIDE_NEVER_INLINE { return (x); })())
27 
28 namespace Halide {
29 namespace Internal {
30 
32  if (mul_would_overflow(64, a, b)) {
33  if ((a > 0) == (b > 0)) {
34  return INT64_MAX;
35  } else {
36  return INT64_MIN;
37  }
38  } else {
39  return a * b;
40  }
41 }
42 
43 class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
45 
46 public:
47  Simplify(bool r, const Scope<Interval> *bi, const Scope<ModulusRemainder> *ai);
48 
49  struct ExprInfo {
50  // We track constant integer bounds when they exist
51  // TODO: Use ConstantInterval?
52  int64_t min = 0, max = 0;
53  bool min_defined = false, max_defined = false;
54  // And the alignment of integer variables
56 
58  if (alignment.modulus == 0) {
59  min_defined = max_defined = true;
61  } else if (alignment.modulus > 1) {
62  if (min_defined) {
64  if (new_min < min) {
65  new_min += alignment.modulus;
66  }
67  min = new_min;
68  }
69  if (max_defined) {
71  if (new_max > max) {
72  new_max -= alignment.modulus;
73  }
74  max = new_max;
75  }
76  }
77 
78  if (min_defined && max_defined && min == max) {
79  alignment.modulus = 0;
81  }
82  }
83 
84  // Mix in existing knowledge about this Expr
85  void intersect(const ExprInfo &other) {
86  if (min_defined && other.min_defined) {
87  min = std::max(min, other.min);
88  } else if (other.min_defined) {
89  min_defined = true;
90  min = other.min;
91  }
92 
93  if (max_defined && other.max_defined) {
94  max = std::min(max, other.max);
95  } else if (other.max_defined) {
96  max_defined = true;
97  max = other.max;
98  }
99 
101 
103  }
104  };
105 
108  if (b) {
109  *b = ExprInfo{};
110  }
111  }
112 
113 #if (LOG_EXPR_MUTATORIONS || LOG_STMT_MUTATIONS)
114  static int debug_indent;
115 #endif
116 
117 #if LOG_EXPR_MUTATIONS
118  Expr mutate(const Expr &e, ExprInfo *b) {
119  const std::string spaces(debug_indent, ' ');
120  debug(1) << spaces << "Simplifying Expr: " << e << "\n";
121  debug_indent++;
122  Expr new_e = Super::dispatch(e, b);
123  debug_indent--;
124  if (!new_e.same_as(e)) {
125  debug(1)
126  << spaces << "Before: " << e << "\n"
127  << spaces << "After: " << new_e << "\n";
128  }
129  internal_assert(e.type() == new_e.type());
130  return new_e;
131  }
132 
133 #else
135  Expr mutate(const Expr &e, ExprInfo *b) {
136  // This gets inlined into every call to mutate, so do not add any code here.
137  return Super::dispatch(e, b);
138  }
139 #endif
140 
141 #if LOG_STMT_MUTATIONS
142  Stmt mutate(const Stmt &s) {
143  const std::string spaces(debug_indent, ' ');
144  debug(1) << spaces << "Simplifying Stmt: " << s << "\n";
145  debug_indent++;
146  Stmt new_s = Super::dispatch(s);
147  debug_indent--;
148  if (!new_s.same_as(s)) {
149  debug(1)
150  << spaces << "Before: " << s << "\n"
151  << spaces << "After: " << new_s << "\n";
152  }
153  return new_s;
154  }
155 #else
156  Stmt mutate(const Stmt &s) {
157  return Super::dispatch(s);
158  }
159 #endif
160 
163 
165  bool may_simplify(const Type &t) const {
166  return !no_float_simplify || !t.is_float();
167  }
168 
169  // Returns true iff t is an integral type where overflow is undefined
172  return t.is_int() && t.bits() >= 32;
173  }
174 
177  return t.is_scalar() && no_overflow_int(t);
178  }
179 
180  // Returns true iff t does not have a well defined overflow behavior.
182  bool no_overflow(Type t) {
183  return t.is_float() || no_overflow_int(t);
184  }
185 
186  struct VarInfo {
189  };
190 
191  // Tracked for all let vars
193 
194  // Only tracked for integer let vars
196 
197  // Symbols used by rewrite rules
210 
211  // Tracks whether or not we're inside a vector loop. Certain
212  // transformations are not a good idea if the code is to be
213  // vectorized.
214  bool in_vector_loop = false;
215 
216  // Tracks whether or not the current IR is unconditionally unreachable.
217  bool in_unreachable = false;
218 
219  // If we encounter a reference to a buffer (a Load, Store, Call,
220  // or Provide), there's an implicit dependence on some associated
221  // symbols.
222  void found_buffer_reference(const std::string &name, size_t dimensions = 0);
223 
224  // Wrappers for as_const_foo that are more convenient to use in
225  // the large chains of conditions in the visit methods below.
226  bool const_float(const Expr &e, double *f);
227  bool const_int(const Expr &e, int64_t *i);
228  bool const_uint(const Expr &e, uint64_t *u);
229 
230  // Put the args to a commutative op in a canonical order
232  bool should_commute(const Expr &a, const Expr &b) {
233  if (a.node_type() < b.node_type()) {
234  return true;
235  }
236  if (a.node_type() > b.node_type()) {
237  return false;
238  }
239 
240  if (a.node_type() == IRNodeType::Variable) {
241  const Variable *va = a.as<Variable>();
242  const Variable *vb = b.as<Variable>();
243  return va->name.compare(vb->name) > 0;
244  }
245 
246  return false;
247  }
248 
249  std::set<Expr, IRDeepCompare> truths, falsehoods;
250 
251  struct ScopedFact {
253 
254  std::vector<const Variable *> pop_list;
255  std::vector<const Variable *> bounds_pop_list;
256  std::vector<Expr> truths, falsehoods;
257 
258  void learn_false(const Expr &fact);
259  void learn_true(const Expr &fact);
260  void learn_upper_bound(const Variable *v, int64_t val);
261  void learn_lower_bound(const Variable *v, int64_t val);
262 
263  // Replace exprs known to be truths or falsehoods with const_true or const_false.
266 
268  : simplify(s) {
269  }
271 
272  // allow move but not copy
273  ScopedFact(const ScopedFact &that) = delete;
274  ScopedFact(ScopedFact &&that) = default;
275  };
276 
277  // Tell the simplifier to learn from and exploit a boolean
278  // condition, over the lifetime of the returned object.
279  ScopedFact scoped_truth(const Expr &fact) {
280  ScopedFact f(this);
281  f.learn_true(fact);
282  return f;
283  }
284 
285  // Tell the simplifier to assume a boolean condition is false over
286  // the lifetime of the returned object.
288  ScopedFact f(this);
289  f.learn_false(fact);
290  return f;
291  }
292 
293  template<typename T>
295 
297  return mutate(s);
298  }
299  Expr mutate_let_body(const Expr &e, ExprInfo *bounds) {
300  return mutate(e, bounds);
301  }
302 
303  template<typename T, typename Body>
304  Body simplify_let(const T *op, ExprInfo *bounds);
305 
306  Expr visit(const IntImm *op, ExprInfo *bounds);
307  Expr visit(const UIntImm *op, ExprInfo *bounds);
308  Expr visit(const FloatImm *op, ExprInfo *bounds);
309  Expr visit(const StringImm *op, ExprInfo *bounds);
310  Expr visit(const Broadcast *op, ExprInfo *bounds);
311  Expr visit(const Cast *op, ExprInfo *bounds);
312  Expr visit(const Variable *op, ExprInfo *bounds);
313  Expr visit(const Add *op, ExprInfo *bounds);
314  Expr visit(const Sub *op, ExprInfo *bounds);
315  Expr visit(const Mul *op, ExprInfo *bounds);
316  Expr visit(const Div *op, ExprInfo *bounds);
317  Expr visit(const Mod *op, ExprInfo *bounds);
318  Expr visit(const Min *op, ExprInfo *bounds);
319  Expr visit(const Max *op, ExprInfo *bounds);
320  Expr visit(const EQ *op, ExprInfo *bounds);
321  Expr visit(const NE *op, ExprInfo *bounds);
322  Expr visit(const LT *op, ExprInfo *bounds);
323  Expr visit(const LE *op, ExprInfo *bounds);
324  Expr visit(const GT *op, ExprInfo *bounds);
325  Expr visit(const GE *op, ExprInfo *bounds);
326  Expr visit(const And *op, ExprInfo *bounds);
327  Expr visit(const Or *op, ExprInfo *bounds);
328  Expr visit(const Not *op, ExprInfo *bounds);
329  Expr visit(const Select *op, ExprInfo *bounds);
330  Expr visit(const Ramp *op, ExprInfo *bounds);
331  Stmt visit(const IfThenElse *op);
332  Expr visit(const Load *op, ExprInfo *bounds);
333  Expr visit(const Call *op, ExprInfo *bounds);
334  Expr visit(const Shuffle *op, ExprInfo *bounds);
335  Expr visit(const VectorReduce *op, ExprInfo *bounds);
336  Expr visit(const Let *op, ExprInfo *bounds);
337  Stmt visit(const LetStmt *op);
338  Stmt visit(const AssertStmt *op);
339  Stmt visit(const For *op);
340  Stmt visit(const Provide *op);
341  Stmt visit(const Store *op);
342  Stmt visit(const Allocate *op);
343  Stmt visit(const Evaluate *op);
345  Stmt visit(const Block *op);
346  Stmt visit(const Realize *op);
347  Stmt visit(const Prefetch *op);
348  Stmt visit(const Free *op);
349  Stmt visit(const Acquire *op);
350  Stmt visit(const Fork *op);
351  Stmt visit(const Atomic *op);
352 
353  std::pair<std::vector<Expr>, bool> mutate_with_changes(const std::vector<Expr> &old_exprs, ExprInfo *bounds);
354 };
355 
356 } // namespace Internal
357 } // namespace Halide
358 
359 #endif
Methods for computing the upper and lower bounds of an expression, and the regions of a function read...
#define internal_assert(c)
Definition: Errors.h:19
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:38
Defines a method to match a fragment of IR against a pattern containing wildcards.
Defines the base class for things that recursively walk over the IR.
Defines the Scope class, which is used for keeping track of names in a scope while traversing IR.
A common pattern when traversing Halide IR is that you need to keep track of stuff when you find a Le...
Definition: Scope.h:94
Expr visit(const Min *op, ExprInfo *bounds)
Stmt visit(const ProducerConsumer *op)
HALIDE_ALWAYS_INLINE Expr mutate(const Expr &e, ExprInfo *b)
Scope< ExprInfo > bounds_and_alignment_info
bool const_uint(const Expr &e, uint64_t *u)
IRMatcher::WildConst< 5 > c5
void found_buffer_reference(const std::string &name, size_t dimensions=0)
Expr visit(const Cast *op, ExprInfo *bounds)
Expr visit(const LT *op, ExprInfo *bounds)
Stmt visit(const Block *op)
Expr visit(const VectorReduce *op, ExprInfo *bounds)
Stmt visit(const AssertStmt *op)
Expr visit(const UIntImm *op, ExprInfo *bounds)
Stmt mutate(const Stmt &s)
HALIDE_ALWAYS_INLINE void clear_bounds_info(ExprInfo *b)
Expr visit(const Load *op, ExprInfo *bounds)
Stmt visit(const Evaluate *op)
Expr visit(const Not *op, ExprInfo *bounds)
Body simplify_let(const T *op, ExprInfo *bounds)
Simplify(bool r, const Scope< Interval > *bi, const Scope< ModulusRemainder > *ai)
Stmt visit(const Prefetch *op)
HALIDE_ALWAYS_INLINE bool no_overflow(Type t)
Expr visit(const Div *op, ExprInfo *bounds)
IRMatcher::WildConst< 1 > c1
Expr visit(const Let *op, ExprInfo *bounds)
Expr hoist_slice_vector(Expr e)
Expr visit(const And *op, ExprInfo *bounds)
Stmt visit(const IfThenElse *op)
Expr visit(const NE *op, ExprInfo *bounds)
Expr visit(const FloatImm *op, ExprInfo *bounds)
Expr visit(const Shuffle *op, ExprInfo *bounds)
Expr visit(const Add *op, ExprInfo *bounds)
IRMatcher::WildConst< 0 > c0
ScopedFact scoped_truth(const Expr &fact)
IRMatcher::WildConst< 3 > c3
IRMatcher::WildConst< 2 > c2
Expr visit(const Ramp *op, ExprInfo *bounds)
Expr visit(const IntImm *op, ExprInfo *bounds)
Expr visit(const Max *op, ExprInfo *bounds)
Expr visit(const Variable *op, ExprInfo *bounds)
HALIDE_ALWAYS_INLINE bool may_simplify(const Type &t) const
Stmt visit(const For *op)
Stmt visit(const Atomic *op)
bool const_float(const Expr &e, double *f)
Expr visit(const GT *op, ExprInfo *bounds)
Stmt visit(const Provide *op)
Expr visit(const Sub *op, ExprInfo *bounds)
Expr visit(const LE *op, ExprInfo *bounds)
Expr visit(const Call *op, ExprInfo *bounds)
Stmt mutate_let_body(const Stmt &s, ExprInfo *)
Stmt visit(const Acquire *op)
Expr visit(const Broadcast *op, ExprInfo *bounds)
Stmt visit(const Fork *op)
HALIDE_ALWAYS_INLINE bool no_overflow_int(Type t)
Expr visit(const StringImm *op, ExprInfo *bounds)
std::set< Expr, IRDeepCompare > truths
Expr visit(const Select *op, ExprInfo *bounds)
ScopedFact scoped_falsehood(const Expr &fact)
HALIDE_ALWAYS_INLINE bool should_commute(const Expr &a, const Expr &b)
Expr visit(const Or *op, ExprInfo *bounds)
Expr visit(const Mul *op, ExprInfo *bounds)
Expr mutate_let_body(const Expr &e, ExprInfo *bounds)
Stmt visit(const Store *op)
Expr visit(const Mod *op, ExprInfo *bounds)
HALIDE_ALWAYS_INLINE bool no_overflow_scalar_int(Type t)
bool const_int(const Expr &e, int64_t *i)
Stmt visit(const Free *op)
IRMatcher::WildConst< 4 > c4
Stmt visit(const Allocate *op)
Stmt visit(const Realize *op)
std::pair< std::vector< Expr >, bool > mutate_with_changes(const std::vector< Expr > &old_exprs, ExprInfo *bounds)
Stmt visit(const LetStmt *op)
std::set< Expr, IRDeepCompare > falsehoods
Expr visit(const EQ *op, ExprInfo *bounds)
Expr visit(const GE *op, ExprInfo *bounds)
A visitor/mutator capable of passing arbitrary arguments to the visit methods using CRTP and returnin...
Definition: IRVisitor.h:157
HALIDE_ALWAYS_INLINE Stmt dispatch(const Stmt &s, Args &&...args)
Definition: IRVisitor.h:325
For optional debugging during codegen, use the debug class as follows:
Definition: Debug.h:49
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition: IROperator.h:239
bool mul_would_overflow(int bits, int64_t a, int64_t b)
int64_t saturating_mul(int64_t a, int64_t b)
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Expr min(const FuncRef &a, const FuncRef &b)
Explicit overloads of min and max for FuncRef.
Definition: Func.h:600
Expr max(const FuncRef &a, const FuncRef &b)
Definition: Func.h:603
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
A fragment of Halide syntax.
Definition: Expr.h:256
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition: Expr.h:320
The sum of two expressions.
Definition: IR.h:38
Allocate a scratch area called with the given name, type, and size.
Definition: IR.h:353
Logical and - are both expressions true.
Definition: IR.h:157
If the 'condition' is false, then evaluate and return the message, which should be a call to an error...
Definition: IR.h:276
Lock all the Store nodes in the body statement.
Definition: IR.h:870
A sequence of statements to be executed in-order.
Definition: IR.h:418
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:241
A function call.
Definition: IR.h:466
The actual IR nodes begin here.
Definition: IR.h:29
The ratio of two expressions.
Definition: IR.h:65
Is the first expression equal to the second.
Definition: IR.h:103
Evaluate and discard an expression, presumably because it has some side-effect.
Definition: IR.h:452
Floating point constants.
Definition: Expr.h:234
A for loop.
Definition: IR.h:747
A pair of statements executed concurrently.
Definition: IR.h:433
Free the resources associated with the given buffer.
Definition: IR.h:389
Is the first expression greater than or equal to the second.
Definition: IR.h:148
Is the first expression greater than the second.
Definition: IR.h:139
IRNodeType node_type() const
Definition: Expr.h:210
const T * as() const
Downcast this ir node to its actual type (e.g.
Definition: Expr.h:203
An if-then-else block.
Definition: IR.h:442
Integer constants.
Definition: Expr.h:216
HALIDE_ALWAYS_INLINE bool same_as(const IntrusivePtr &other) const
Definition: IntrusivePtr.h:168
Is the first expression less than or equal to the second.
Definition: IR.h:130
Is the first expression less than the second.
Definition: IR.h:121
A let expression, like you might find in a functional language.
Definition: IR.h:253
The statement form of a let node.
Definition: IR.h:264
Load a value from a named symbol if predicate is true.
Definition: IR.h:199
The greater of two values.
Definition: IR.h:94
The lesser of two values.
Definition: IR.h:85
The remainder of a / b.
Definition: IR.h:76
The result of modulus_remainder analysis.
static ModulusRemainder intersect(const ModulusRemainder &a, const ModulusRemainder &b)
The product of two expressions.
Definition: IR.h:56
Is the first expression not equal to the second.
Definition: IR.h:112
Logical not - true if the expression false.
Definition: IR.h:175
Logical or - is at least one of the expression true.
Definition: IR.h:166
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
Definition: IR.h:847
This node is a helpful annotation to do with permissions.
Definition: IR.h:297
This defines the value of a function at a multi-dimensional location.
Definition: IR.h:336
A linear ramp vector node.
Definition: IR.h:229
Allocate a multi-dimensional buffer of the given type and size.
Definition: IR.h:403
A ternary operator.
Definition: IR.h:186
Construct a new vector by taking elements from another sequence of vectors.
Definition: IR.h:778
void intersect(const ExprInfo &other)
ScopedFact(ScopedFact &&that)=default
void learn_false(const Expr &fact)
std::vector< const Variable * > bounds_pop_list
ScopedFact(const ScopedFact &that)=delete
std::vector< const Variable * > pop_list
void learn_lower_bound(const Variable *v, int64_t val)
void learn_upper_bound(const Variable *v, int64_t val)
A reference-counted handle to a statement node.
Definition: Expr.h:417
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
Definition: IR.h:315
String constants.
Definition: Expr.h:243
The difference of two expressions.
Definition: IR.h:47
Unsigned integer constants.
Definition: Expr.h:225
A named variable.
Definition: IR.h:700
std::string name
Definition: IR.h:701
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition: IR.h:888
Types in the halide type system.
Definition: Type.h:266
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition: Type.h:414
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition: Type.h:328
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition: Type.h:396
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition: Type.h:402