Halide  12.0.1
Halide compiler and libraries
Derivative.h
Go to the documentation of this file.
1 #ifndef HALIDE_DERIVATIVE_H
2 #define HALIDE_DERIVATIVE_H
3 
4 /** \file
5  * Automatic differentiation
6  */
7 
8 #include "Expr.h"
9 #include "Func.h"
10 
11 #include <map>
12 #include <string>
13 #include <vector>
14 
15 namespace Halide {
16 
17 /**
18  * Helper structure storing the adjoints Func.
19  * Use d(func) or d(buffer) to obtain the derivative Func.
20  */
21 class Derivative {
22 public:
23  // function name & update_id, for initialization update_id == -1
24  using FuncKey = std::pair<std::string, int>;
25 
26  explicit Derivative(const std::map<FuncKey, Func> &adjoints_in)
27  : adjoints(adjoints_in) {
28  }
29  explicit Derivative(std::map<FuncKey, Func> &&adjoints_in)
30  : adjoints(std::move(adjoints_in)) {
31  }
32 
33  // These all return an undefined Func if no derivative is found
34  // (typically, if the input Funcs aren't differentiable)
35  Func operator()(const Func &func, int update_id = -1) const;
36  Func operator()(const Buffer<> &buffer) const;
37  Func operator()(const Param<> &param) const;
38 
39 private:
40  const std::map<FuncKey, Func> adjoints;
41 };
42 
43 /**
44  * Given a Func and a corresponding adjoint, (back)propagate the
45  * adjoint to all dependent Funcs, buffers, and parameters.
46  * The bounds of output and adjoint need to be specified as a Region: one {min, extent} pair per dimension.
47  * For each Func the output depends on, and for the pure definition and
48  * each update of that Func, it generates a derivative Func stored in
49  * the Derivative.
50  */
52  const Func &adjoint,
53  const Region &output_bounds);
54 /**
55  * Given a Func and a corresponding adjoint buffer, (back)propagate the
56  * adjoint to all dependent Funcs, buffers, and parameters.
57  * For each Func the output depends on, and for the pure definition and
58  * each update of that Func, it generates a derivative Func stored in
59  * the Derivative.
60  */
62  const Buffer<float> &adjoint);
63 /**
64  * Given a scalar Func with size 1, (back)propagate the gradient
65  * to all dependent Funcs, buffers, and parameters.
66  * For each Func the output depends on, and for the pure definition and
67  * each update of that Func, it generates a derivative Func stored in
68  * the Derivative.
69  */
71 
72 } // namespace Halide
73 
74 #endif
Base classes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Defines Func - the front-end handle on a halide function, and related classes.
A Halide::Buffer is a named shared reference to a Halide::Runtime::Buffer.
Definition: Buffer.h:115
Helper structure storing the adjoint Funcs.
Definition: Derivative.h:21
Derivative(std::map< FuncKey, Func > &&adjoints_in)
Definition: Derivative.h:29
std::pair< std::string, int > FuncKey
Definition: Derivative.h:24
Func operator()(const Func &func, int update_id=-1) const
Func operator()(const Buffer<> &buffer) const
Func operator()(const Param<> &param) const
Derivative(const std::map< FuncKey, Func > &adjoints_in)
Definition: Derivative.h:26
A halide function.
Definition: Func.h:681
A scalar parameter to a halide pipeline.
Definition: Param.h:22
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Derivative propagate_adjoints(const Func &output, const Func &adjoint, const Region &output_bounds)
Given a Func and a corresponding adjoint, (back)propagate the adjoint to all dependent Funcs,...
std::vector< Range > Region
A multi-dimensional box.
Definition: Expr.h:343