41 #include <pcl/ml/branch_estimator.h>
42 #include <pcl/ml/stats_estimator.h>
// Template parameter list with two type parameters: FeatureType and LabelType.
// NOTE(review): the declaration these parameters belong to is not visible here;
// presumably it is the node type whose (de)serialization code follows — confirm.
50 template <
class FeatureType,
class LabelType>
// Writes this node to `stream` in a fixed field order:
//   feature, threshold, value, variance, sub-node count, then every sub-node.
// The matching read code consumes the fields in exactly this order, so any
// reordering here must be mirrored on the read side.
// The feature writes itself through its own serialize() call.
66 feature.serialize(stream);
// threshold, value and variance are dumped as raw bytes of their in-memory
// representation (sizeof bytes each).
// NOTE(review): the resulting format presumably depends on the host's type
// sizes and byte order — confirm whether cross-platform files are required.
68 stream.write(reinterpret_cast<const char*>(&threshold),
sizeof(threshold));
70 stream.write(reinterpret_cast<const char*>(&value),
sizeof(value));
71 stream.write(reinterpret_cast<const char*>(&variance),
sizeof(variance));
// The sub-node count is narrowed to int before being written, so the on-disk
// count field has sizeof(int) bytes.
73 const int num_of_sub_nodes = static_cast<int>(sub_nodes.size());
74 stream.write(reinterpret_cast<const char*>(&num_of_sub_nodes),
75 sizeof(num_of_sub_nodes));
// Each sub-node is written by its own serialize() call, in index order.
76 for (
int sub_node_index = 0; sub_node_index < num_of_sub_nodes; ++sub_node_index) {
77 sub_nodes[sub_node_index].serialize(stream);
// Reads this node from `stream`, consuming the fields in the order
//   feature, threshold, value, variance, sub-node count, then every sub-node.
// The feature reads itself through its own deserialize() call.
88 feature.deserialize(stream);
// threshold, value and variance are read back as raw bytes (sizeof bytes each)
// directly into the members.
90 stream.read(reinterpret_cast<char*>(&threshold),
sizeof(threshold));
92 stream.read(reinterpret_cast<char*>(&value),
sizeof(value));
93 stream.read(reinterpret_cast<char*>(&variance),
sizeof(variance));
// The sub-node count is read as raw bytes and used directly to size sub_nodes.
// NOTE(review): the declaration of num_of_sub_nodes is not visible here
// (presumably an int, matching the write side — confirm). No stream-state
// check is visible before the count is used; a failed or corrupt read would
// feed an arbitrary value into resize() — verify how callers guard against it.
96 stream.read(reinterpret_cast<char*>(&num_of_sub_nodes),
sizeof(num_of_sub_nodes));
97 sub_nodes.resize(num_of_sub_nodes);
// The `> 0` guard duplicates the loop condition below (the index starts at 0
// and must be < num_of_sub_nodes), so it does not change which sub-nodes are read.
99 if (num_of_sub_nodes > 0) {
100 for (
int sub_node_index = 0; sub_node_index < num_of_sub_nodes;
// Each freshly resized sub-node fills itself from the stream, in index order.
102 sub_nodes[sub_node_index].deserialize(stream);
// Template parameter list with four type parameters:
//   LabelDataType - element type of the label vectors used by the statistics code;
//                   that code converts it to float via static_cast and
//                   initializes accumulators of this type from 0 / 0.0f.
//   NodeType      - type of the node object that receives the computed variance.
//   DataSet       - not referenced in the visible lines.
//   ExampleIndex  - element type of the `examples` vectors.
// NOTE(review): the declaration these parameters belong to is not visible here;
// presumably the RegressionVarianceStatsEstimator named in the error strings — confirm.
125 template <
class LabelDataType,
class NodeType,
class DataSet,
class ExampleIndex>
// Member initializer: stores the supplied branch estimator in branch_estimator_.
// The member is later dereferenced with `->` for all branch decisions.
// NOTE(review): no null check is visible here; presumably callers must pass a
// valid estimator — confirm.
131 : branch_estimator_(branch_estimator)
// Forwards the query to the stored branch estimator and returns its branch count.
142 return branch_estimator_->getNumOfBranches();
// Computes a variance-reduction score for a split: the variance of all labels
// minus the count-weighted variances of the labels falling into each branch.
//
// Parameters visible here:
//   examples   - only its size is used (number of examples to process).
//   label_data - labels, indexed by position (0 .. examples.size()-1), not by
//                the values stored in `examples`.
//   results    - per-position feature results passed to the branch decision.
//   flags      - per-position flags passed to the branch decision.
//   threshold  - threshold passed to the branch decision.
// Returns the accumulated information_gain value as float.
166 std::vector<ExampleIndex>& examples,
167 std::vector<LabelDataType>& label_data,
168 std::vector<float>& results,
169 std::vector<unsigned char>& flags,
170 const float threshold)
const
172 const std::size_t num_of_examples = examples.size();
173 const std::size_t num_of_branches = getNumOfBranches();
// One accumulator slot per branch plus one extra slot (index num_of_branches)
// that collects the totals over all branches.
176 std::vector<LabelDataType> sums(num_of_branches + 1, 0);
177 std::vector<LabelDataType> sqr_sums(num_of_branches + 1, 0);
178 std::vector<std::size_t> branch_element_count(num_of_branches + 1, 0);
// Every branch count starts at 1 (and the total is raised accordingly), so the
// per-branch divisions below never divide by zero for an empty branch.
// NOTE(review): the sums start at 0, so these initial counts slightly bias the
// means and variances — presumably an accepted trade-off; confirm.
180 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
181 branch_element_count[branch_index] = 1;
182 ++branch_element_count[num_of_branches];
// Accumulate label, squared label and element count per branch and in total.
185 for (std::size_t example_index = 0; example_index < num_of_examples;
187 unsigned char branch_index;
// Arguments of the branch decision for this position; branch_index receives
// the result. NOTE(review): the callee name is not visible here — presumably
// computeBranchIndex; confirm.
189 results[example_index], flags[example_index], threshold, branch_index);
191 LabelDataType label = label_data[example_index];
193 sums[branch_index] += label;
194 sums[num_of_branches] += label;
196 sqr_sums[branch_index] += label * label;
197 sqr_sums[num_of_branches] += label * label;
199 ++branch_element_count[branch_index];
200 ++branch_element_count[num_of_branches];
// Variance per slot via E[x^2] - (E[x])^2, using the slot's element count.
203 std::vector<float> variances(num_of_branches + 1, 0);
204 for (std::size_t branch_index = 0; branch_index < num_of_branches + 1;
206 const float mean_sum =
207 static_cast<float>(sums[branch_index]) / branch_element_count[branch_index];
208 const float mean_sqr_sum = static_cast<float>(sqr_sums[branch_index]) /
209 branch_element_count[branch_index];
210 variances[branch_index] = mean_sqr_sum - mean_sum * mean_sum;
// Start from the total variance and subtract each branch's variance weighted
// by that branch's share of the total element count.
213 float information_gain = variances[num_of_branches];
214 for (std::size_t branch_index = 0; branch_index < num_of_branches; ++branch_index) {
217 const float weight = static_cast<float>(branch_element_count[branch_index]) /
218 static_cast<float>(branch_element_count[num_of_branches]);
219 information_gain -= weight * variances[branch_index];
222 return information_gain;
// Computes one branch index per entry of `results` and stores them in
// `branch_indices`, which is resized to results.size().
//
// Parameters visible here:
//   flags          - per-result flags passed to the branch decision.
//   threshold      - threshold passed to the branch decision.
//   branch_indices - output; element i receives the branch index for result i.
// NOTE(review): the `results` parameter declaration is not visible here.
234 std::vector<unsigned char>& flags,
235 const float threshold,
236 std::vector<unsigned char>& branch_indices)
const
238 const std::size_t num_of_results = results.size();
// NOTE(review): num_of_branches is not used in any of the visible lines —
// possibly an unused local; confirm against the full function.
239 const std::size_t num_of_branches = getNumOfBranches();
241 branch_indices.resize(num_of_results);
242 for (std::size_t result_index = 0; result_index < num_of_results; ++result_index) {
243 unsigned char branch_index;
// Arguments of the branch decision for this result; branch_index receives the
// outcome. NOTE(review): the callee name is not visible here — presumably
// computeBranchIndex; confirm.
245 results[result_index], flags[result_index], threshold, branch_index);
246 branch_indices[result_index] = branch_index;
// Determines the branch for a single result by delegating to the stored branch
// estimator.
//
// Parameters visible here:
//   flag         - flag belonging to the result.
//   threshold    - threshold forwarded to the estimator.
//   branch_index - non-const reference forwarded to the estimator; presumably
//                  it receives the computed index — confirm against the
//                  estimator's interface.
// NOTE(review): the `result` parameter declaration is not visible here.
259 const unsigned char flag,
260 const float threshold,
261 unsigned char& branch_index)
const
263 branch_estimator_->computeBranchIndex(result, flag, threshold, branch_index);
// Computes label statistics over the given examples and stores the variance
// (mean of squares minus squared mean) in `node.variance`.
//
// Parameters:
//   examples   - only its size is used (number of labels to process).
//   label_data - labels, indexed by position (0 .. examples.size()-1).
//   node       - output; its `variance` member is assigned.
277 std::vector<ExampleIndex>& examples,
278 std::vector<LabelDataType>& label_data,
279 NodeType& node)
const
281 const std::size_t num_of_examples = examples.size();
// Accumulators use LabelDataType, initialized from a float literal.
283 LabelDataType sum = 0.0f;
284 LabelDataType sqr_sum = 0.0f;
285 for (std::size_t example_index = 0; example_index < num_of_examples;
287 const LabelDataType label = label_data[example_index];
// NOTE(review): the statement accumulating `sum` is not visible here —
// presumably `sum += label;` on a neighbouring line; confirm.
290 sqr_sum += label * label;
// Turn the accumulated totals into means.
// NOTE(review): no guard for an empty `examples` vector is visible; with
// num_of_examples == 0 these divisions divide by zero — verify that callers
// never pass an empty set.
293 sum /= num_of_examples;
294 sqr_sum /= num_of_examples;
// Variance via E[x^2] - (E[x])^2.
296 const float variance = sqr_sum - sum * sum;
299 node.variance = variance;
// Code generation is not supported: instead of code, a fixed error message is
// written into the output stream.
310 stream <<
"ERROR: RegressionVarianceStatsEstimator does not implement "
311 "generateCodeForBranchIndex(...)";
// Code generation is not supported: instead of code, a fixed error message is
// written into the output stream.
// NOTE(review): this message is character-for-character the same as the one
// emitted a few lines earlier and names generateCodeForBranchIndex(...). If
// this statement belongs to a different function (its signature is not visible
// here), the message names the wrong function — confirm and correct the text.
322 stream <<
"ERROR: RegressionVarianceStatsEstimator does not implement "
323 "generateCodeForBranchIndex(...)";