gpu_context_common.h
#include "printer.h"
#include "scoped_mutex_lock.h"

namespace Halide {
namespace Internal {

template<typename ContextT, typename ModuleStateT>
class GPUCompilationCache {
    struct CachedCompilation {
        ContextT context{};
        ModuleStateT module_state{};
        uint32_t kernel_id{};
        uint32_t use_count{0};

        CachedCompilation(ContextT context, ModuleStateT module_state,
                          uint32_t kernel_id, uint32_t use_count)
            : context(context), module_state(module_state),
              kernel_id(kernel_id), use_count(use_count) {
        }
    };

    halide_mutex mutex;

    static constexpr float kLoadFactor{.5f};
    static constexpr int kInitialTableBits{7};
    int log2_compilations_size{0};  // number of bits in index into compilations table.
    CachedCompilation *compilations{nullptr};
    int count{0};

    static constexpr uint32_t kInvalidId{0};
    static constexpr uint32_t kDeletedId{1};

    uint32_t unique_id{2};  // zero is an invalid id

public:
    static ALWAYS_INLINE uintptr_t kernel_hash(ContextT context, uint32_t id, uint32_t bits) {
        uintptr_t addr = (uintptr_t)context + id;
        // Fibonacci hashing. The golden ratio is 1.9E3779B97F4A7C15F39...
        // in hexadecimal.
        if (sizeof(uintptr_t) >= 8) {
            return (addr * (uintptr_t)0x9E3779B97F4A7C15) >> (64 - bits);
        } else {
            return (addr * (uintptr_t)0x9E3779B9) >> (32 - bits);
        }
    }
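
    // The hash is only a starting slot: insert() and find_internal() below probe
    // forward from kernel_hash(context, id, bits), wrapping modulo the table
    // size, so collisions are resolved by open addressing.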

    HALIDE_MUST_USE_RESULT bool insert(const CachedCompilation &entry) {
        if (log2_compilations_size == 0) {
            if (!resize_table(kInitialTableBits)) {
                return false;
            }
        }
        if ((count + 1) > (1 << log2_compilations_size) * kLoadFactor) {
            if (!resize_table(log2_compilations_size + 1)) {
                return false;
            }
        }
        count += 1;
        uintptr_t index = kernel_hash(entry.context, entry.kernel_id, log2_compilations_size);
        for (int i = 0; i < (1 << log2_compilations_size); i++) {
            uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
            if (compilations[effective_index].kernel_id <= kDeletedId) {
                compilations[effective_index] = entry;
                return true;
            }
        }
        // This is a logic error that should never occur. It means the table is
        // full, but it should have been resized.
        halide_assert(nullptr, false);
        return false;
    }
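
    // A slot is free if its kernel_id is kInvalidId (never used) or kDeletedId
    // (a tombstone left by release_context()), which is what the
    // kernel_id <= kDeletedId test above relies on.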

    HALIDE_MUST_USE_RESULT bool find_internal(ContextT context, uint32_t id,
                                              ModuleStateT *&module_state, int increment) {
        if (log2_compilations_size == 0) {
            return false;
        }
        uintptr_t index = kernel_hash(context, id, log2_compilations_size);
        for (int i = 0; i < (1 << log2_compilations_size); i++) {
            uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);

            if (compilations[effective_index].kernel_id == kInvalidId) {
                return false;
            }
            if (compilations[effective_index].context == context &&
                compilations[effective_index].kernel_id == id) {
                module_state = &compilations[effective_index].module_state;
                if (increment != 0) {
                    compilations[effective_index].use_count += increment;
                }
                return true;
            }
        }
        return false;
    }
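
    // The probe stops at the first never-used slot (kInvalidId) but walks past
    // tombstones (kDeletedId), so deleting an entry does not hide entries that
    // were inserted after a collision. A nonzero increment adjusts the matching
    // entry's use_count: +1 takes a hold, -1 releases one.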

    HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state) {
        ScopedMutexLock lock_guard(&mutex);
        uint32_t id = (uint32_t)(uintptr_t)state_ptr;
        ModuleStateT *mod_ptr;
        if (find_internal(context, id, mod_ptr, 0)) {
            module_state = *mod_ptr;
            return true;
        }
        return false;
    }

    HALIDE_MUST_USE_RESULT bool resize_table(int size_bits) {
        if (size_bits != log2_compilations_size) {
            int new_size = (1 << size_bits);
            int old_size = (1 << log2_compilations_size);
            CachedCompilation *new_table = (CachedCompilation *)malloc(new_size * sizeof(CachedCompilation));
            if (new_table == nullptr) {
                // signal error.
                return false;
            }
            memset(new_table, 0, new_size * sizeof(CachedCompilation));
            CachedCompilation *old_table = compilations;
            compilations = new_table;
            log2_compilations_size = size_bits;

            if (count > 0) {  // Mainly to catch empty initial table case
                // The insert() calls below each bump count, so reset it first;
                // afterwards it again equals the number of live entries.
                count = 0;
                for (int32_t i = 0; i < old_size; i++) {
                    if (old_table[i].kernel_id != kInvalidId &&
                        old_table[i].kernel_id != kDeletedId) {
                        bool result = insert(old_table[i]);
                        halide_assert(nullptr, result);  // Resizing the table while resizing the table is a logic error.
                    }
                }
            }
            free(old_table);
        }
        return true;
    }
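
    // Growing the table re-inserts only live entries, so tombstones left by
    // release_context() are discarded whenever the table is resized.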

    template<typename FreeModuleT>
    void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f) {
        if (count == 0) {
            return;
        }

        for (int i = 0; i < (1 << log2_compilations_size); i++) {
            if (compilations[i].kernel_id > kDeletedId &&  // skip unused slots and tombstones
                (all || (compilations[i].context == context)) &&
                compilations[i].use_count == 0) {
                debug(user_context) << "Releasing cached compilation: " << compilations[i].module_state
                                    << " id " << compilations[i].kernel_id
                                    << " context " << compilations[i].context << "\n";
                f(compilations[i].module_state);
                compilations[i].module_state = nullptr;
                compilations[i].kernel_id = kDeletedId;
                count--;
            }
        }
    }
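
    // Entries that are still held (use_count != 0) are skipped above; they are
    // only reclaimed by a later release pass once release_hold() has dropped the
    // last hold on them.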

    template<typename FreeModuleT>
    void delete_context(void *user_context, ContextT context, FreeModuleT &f) {
        ScopedMutexLock lock_guard(&mutex);

        release_context(user_context, false, context, f);
    }

    template<typename FreeModuleT>
    void release_all(void *user_context, FreeModuleT &f) {
        ScopedMutexLock lock_guard(&mutex);

        release_context(user_context, true, nullptr, f);
        // Some items may have been in use, so can't free.
        if (count == 0) {
            free(compilations);
            compilations = nullptr;
            log2_compilations_size = 0;
        }
    }
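
    // The backing array itself is freed only if everything could be released;
    // if some compilations are still held, the table is kept so they remain
    // findable until the holds are dropped.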

    template<typename CompileModuleT, typename... Args>
    HALIDE_MUST_USE_RESULT bool kernel_state_setup(void *user_context, void **state_ptr,
                                                   ContextT context, ModuleStateT &result,
                                                   CompileModuleT f,
                                                   Args... args) {
        ScopedMutexLock lock_guard(&mutex);

        uint32_t *id_ptr = (uint32_t *)state_ptr;
        if (*id_ptr == 0) {
            *id_ptr = unique_id++;
        }

        ModuleStateT *mod;
        if (find_internal(context, *id_ptr, mod, 1)) {
            result = *mod;
            return true;
        }

        // TODO(zvookin): figure out the calling signature here...
        ModuleStateT compiled_module = f(args...);
        debug(user_context) << "Caching compiled kernel: " << compiled_module
                            << " id " << *id_ptr << " context " << context << "\n";
        if (compiled_module == nullptr) {
            return false;
        }

        if (!insert({context, compiled_module, *id_ptr, 1})) {
            return false;
        }
        result = compiled_module;

        return true;
    }
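
    // On success the cached entry's use_count has been incremented, either via
    // find_internal(..., 1) on a hit or by inserting the new entry with
    // use_count 1, so callers are expected to balance this with release_hold().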

    void release_hold(void *user_context, ContextT context, void *state_ptr) {
        ScopedMutexLock lock_guard(&mutex);

        uint32_t id = (uint32_t)(uintptr_t)state_ptr;
        ModuleStateT *mod;
        bool result = find_internal(context, id, mod, -1);
        halide_assert(user_context, result);  // Value must be in cache to be released
    }
};

} // namespace Internal
} // namespace Halide
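
Usage note: the cache is consumed by the GPU runtime backends rather than by user code; each backend keeps one GPUCompilationCache instance parameterized on its driver's context and module handle types. The sketch below is purely illustrative and is not part of this header: FakeContext, FakeModule, compile_source(), unload_module() and the my_* functions are hypothetical stand-ins for a backend's real types and driver calls, and the code assumes it is compiled inside the Halide runtime with this header available.

// Illustrative sketch only -- hypothetical names, mirroring the pattern the
// GPU backends use with GPUCompilationCache.
struct FakeContextRec;
typedef FakeContextRec *FakeContext;
struct FakeModuleRec;
typedef FakeModuleRec *FakeModule;

// Hypothetical: a real backend would invoke its driver's compiler here and
// return a valid module handle (nullptr signals failure).
FakeModule compile_source(void *user_context, const char *src) {
    return nullptr;
}

// Hypothetical: a real backend would unload/destroy the driver module here.
void unload_module(FakeModule module) {
}

// Each backend keeps a single cache instance for its handle types.
Halide::Internal::GPUCompilationCache<FakeContext, FakeModule> compilation_cache;

// Pipeline setup: compiles on a cache miss, otherwise reuses the cached module.
// Either way the entry's use count is incremented (see kernel_state_setup above).
int my_initialize_kernels(void *user_context, void **state_ptr,
                          FakeContext ctx, const char *src) {
    FakeModule module{};
    if (!compilation_cache.kernel_state_setup(user_context, state_ptr, ctx, module,
                                              compile_source, user_context, src)) {
        return -1;
    }
    return 0;
}

// Kernel launch: fetches the cached module without changing its use count.
int my_run_kernel(void *user_context, void *state_ptr, FakeContext ctx) {
    FakeModule module{};
    if (!compilation_cache.lookup(ctx, state_ptr, module)) {
        return -1;
    }
    // ... look up and launch a kernel from `module` ...
    return 0;
}

// Pipeline teardown: drops the hold taken in my_initialize_kernels.
void my_finalize_kernels(void *user_context, void *state_ptr, FakeContext ctx) {
    compilation_cache.release_hold(user_context, ctx, state_ptr);
}

// Device release: frees cached modules for this context that are no longer held.
void my_device_release(void *user_context, FakeContext ctx) {
    compilation_cache.delete_context(user_context, ctx, unload_module);
}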