SHOGUN  v3.2.0
StructuredOutputMachine.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Shell Hu
8  * Written (W) 2013 Thoralf Klein
9  * Written (W) 2012 Fernando José Iglesias García
10  * Copyright (C) 2012 Fernando José Iglesias García
11  */
12 
14 
15 using namespace shogun;
16 
// Constructor (its signature, listing line 17, is missing from this extraction;
// presumably the default constructor). Starts with no structured model and no
// surrogate loss; register_parameters() then registers the members and sets
// m_verbose = false and m_helper = NULL.
18 : CMachine(), m_model(NULL), m_surrogate_loss(NULL)
19 {
20  register_parameters();
21 }
22 
// Constructor from a structured model and structured labels (the opening line of
// the signature, listing line 23, is missing from this extraction).
24  CStructuredModel* model,
25  CStructuredLabels* labs)
26 : CMachine(), m_model(model), m_surrogate_loss(NULL)
27 {
// Take a reference on the stored model; presumably released by the destructor /
// set_model() (those lines are not visible here) - confirm.
28  SG_REF(m_model);
// The labels are set after m_model is assigned: this presumably dispatches to this
// class's set_labels() override, which REQUIREs a non-NULL model.
29  set_labels(labs);
30  register_parameters();
31 }
32 
// NOTE(review): the signature (listing line 33) and body (listing lines 35-37) of
// this definition are missing from this extraction. From its position after the
// constructors it is presumably the destructor, releasing the references this class
// takes (m_model, m_surrogate_loss, m_helper) - confirm against the source.
34 {
38 }
39 
// set_model(CStructuredModel* model): replaces the structured model. The signature
// (listing line 40) is missing from this extraction. Takes a reference on the new
// model before storing it.
// NOTE(review): listing line 43 was also dropped; presumably it releases the
// previous model (SG_UNREF(m_model)) before the assignment - confirm.
41 {
42  SG_REF(model);
44  m_model = model;
45 }
46 
// get_model() const: returns the stored structured model. The signature (listing
// line 47) is missing from this extraction. SG_REF is applied before returning, so
// the caller presumably owns a reference and should SG_UNREF it when done.
48 {
49  SG_REF(m_model);
50  return m_model;
51 }
52 
// Registers this machine's members through SG_ADD and initialises the members that
// the constructors' initialiser lists do not set (m_verbose, m_helper).
53 void CStructuredOutputMachine::register_parameters()
54 {
// Every member is registered with MS_NOT_AVAILABLE (presumably: not exposed to
// model selection).
55  SG_ADD((CSGObject**)&m_model, "m_model", "Structured model", MS_NOT_AVAILABLE);
56  SG_ADD((CSGObject**)&m_surrogate_loss, "m_surrogate_loss", "Surrogate loss", MS_NOT_AVAILABLE);
57  SG_ADD(&m_verbose, "verbose", "Verbosity flag", MS_NOT_AVAILABLE);
58  SG_ADD((CSGObject**)&m_helper, "helper", "Training helper", MS_NOT_AVAILABLE);
59 
// Defaults: not verbose, and no training helper until one is created.
60  m_verbose = false;
61  m_helper = NULL;
62 }
63 
// set_labels(CLabels* lab): the signature (listing line 64) and listing lines 66
// and 68 are missing from this extraction. The visible check requires a model to be
// set first; the dropped lines presumably store the labels and forward them to
// m_model - confirm.
65 {
67  REQUIRE(m_model != NULL, "please call set_model() before set_labels()\n");
69 }
70 
// NOTE(review): the signature (listing line 71) and the only body line (listing
// line 73) are missing from this extraction. Going by its neighbours, this is
// presumably set_features(CFeatures*), forwarding the features to m_model - confirm.
72 {
74 }
75 
// get_features(): delegates to the structured model. The signature (listing line
// 76) is missing from this extraction.
// NOTE(review): m_model is dereferenced without a NULL check. Callers in this file
// SG_UNREF the returned pointer, so it presumably carries a reference.
77 {
78  return m_model->get_features();
79 }
80 
// set_surrogate_loss(CLossFunction* loss): the signature (listing line 81) is missing
// from this extraction. Takes a reference on the new loss before storing it.
// NOTE(review): listing line 84 was dropped; presumably it releases the previous
// loss (SG_UNREF(m_surrogate_loss)) - confirm.
82 {
83  SG_REF(loss);
85  m_surrogate_loss = loss;
86 }
87 
// get_surrogate_loss() const: returns the surrogate loss. The signature (listing
// line 88) is missing from this extraction.
// NOTE(review): listing line 90 was dropped; by analogy with get_model() it is
// presumably SG_REF(m_surrogate_loss) - confirm.
89 {
91  return m_surrogate_loss;
92 }
93 
// risk_nslack_margin_rescale(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
// (presumably; the signature, listing line 94, is missing from this extraction).
// Returns the sum of result->score from m_model->argmax() over the examples
// [from, to). Overwrites subgrad[0..dim) with sum_i (psi_pred_i - psi_truth_i).
// Example range: with info, [m_from, m_from + m_N), or [m_from, N) when m_N == 0;
// without info, all N examples.
95 {
96  int32_t dim = m_model->get_dim();
97 
98  int32_t from=0, to=0;
99  CFeatures* features = get_features();
// NOTE(review): `to` is not clamped to features->get_num_vectors(); callers
// presumably guarantee m_from + m_N <= N - confirm.
100  if (info)
101  {
102  from = info->m_from;
103  to = (info->m_N == 0) ? features->get_num_vectors() : from+info->m_N;
104  }
105  else
106  {
107  from = 0;
108  to = features->get_num_vectors();
109  }
110  SG_UNREF(features);
111 
112  float64_t R = 0.0;
// subgrad is overwritten rather than accumulated into, so zero it first.
113  for (int32_t i=0; i<dim; i++)
114  subgrad[i] = 0;
115 
116  for (int32_t i=from; i<to; i++)
117  {
// argmax in training mode (third argument true) at the current weights. W is
// wrapped without copying; the trailing `false` presumably means the SGVector does
// not take ownership of W.
118  CResultSet* result = m_model->argmax(SGVector<float64_t>(W,dim,false), i, true);
119  SGVector<float64_t> psi_pred = result->psi_pred;
120  SGVector<float64_t> psi_truth = result->psi_truth;
// subgrad += psi_pred - psi_truth
121  SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, 1.0, psi_pred.vector, dim);
122  SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, -1.0, psi_truth.vector, dim);
123  R += result->score;
124  SG_UNREF(result);
125  }
126 
127  return R;
128 }
129 
// Risk formulations not implemented in this base class: risk_nslack_slack_rescale,
// risk_1slack_margin_rescale, risk_1slack_slack_rescale and
// risk_customized_formulation. Their signatures (listing lines 130, 136, 142 and 148)
// are missing from this extraction. Each raises SG_ERROR naming the method. The
// trailing `return 0.0;` only satisfies the return type; it is presumably
// unreachable if SG_ERROR aborts or throws - confirm.
131 {
132  SG_ERROR("%s::risk_nslack_slack_rescale() has not been implemented!\n", get_name());
133  return 0.0;
134 }
135 
137 {
138  SG_ERROR("%s::risk_1slack_margin_rescale() has not been implemented!\n", get_name());
139  return 0.0;
140 }
141 
143 {
144  SG_ERROR("%s::risk_1slack_slack_rescale() has not been implemented!\n", get_name());
145  return 0.0;
146 }
147 
149 {
150  SG_ERROR("%s::risk_customized_formulation() has not been implemented!\n", get_name());
151  return 0.0;
152 }
153 
155  TMultipleCPinfo* info, EStructRiskType rtype)
156 {
// risk(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info, EStructRiskType rtype):
// dispatches to the risk formulation selected by rtype (declared default:
// N_SLACK_MARGIN_RESCALING) and returns its value.
// NOTE(review): the first signature line (listing line 154) and four case labels
// (listing lines 160, 163, 166, 169) are missing from this extraction; presumably
// they are the n-slack/1-slack x margin/slack rescaling enum values - confirm.
157  float64_t ret = 0.0;
158  switch(rtype)
159  {
161  ret = risk_nslack_margin_rescale(subgrad, W, info);
162  break;
164  ret = risk_nslack_slack_rescale(subgrad, W, info);
165  break;
167  ret = risk_1slack_margin_rescale(subgrad, W, info);
168  break;
170  ret = risk_1slack_slack_rescale(subgrad, W, info);
171  break;
172  case CUSTOMIZED_RISK:
173  ret = risk_customized_formulation(subgrad, W, info);
174  break;
// Unknown rtype: SG_ERROR, then -1 (presumably unreachable if SG_ERROR aborts).
175  default:
176  SG_ERROR("%s::risk(): cannot recognize the risk type!\n", get_name());
177  ret = -1;
178  break;
179  }
180  return ret;
181 }
182 
// get_helper(): returns the training helper (m_helper, registered as "Training
// helper"). The signature (listing line 183) is missing from this extraction.
// SG_REF is applied before returning, so the caller presumably owns a reference.
// Raises SG_ERROR if no helper exists; per the message, a helper is only created
// when verbose is enabled before training.
184 {
185  if (m_helper == NULL)
186  {
// The two adjacent literals are concatenated, so the first one needs a trailing
// space; otherwise the message reads "...created!Please set verbose...".
187  SG_ERROR("%s::get_helper(): no helper has been created! "
188  "Please set verbose before training!\n", get_name());
189  }
190 
191  SG_REF(m_helper);
192  return m_helper;
193 }
194 
// Setter for the verbosity flag. The signature (listing line 195) is missing from
// this extraction; presumably set_verbose(bool verbose). get_helper()'s error
// message says verbose must be set before training for a helper to exist.
196 {
197  m_verbose = verbose;
198 }
199 
// Getter for the verbosity flag. The signature (listing line 200) is missing from
// this extraction; presumably get_verbose().
201 {
202  return m_verbose;
203 }
SGVector< float64_t > psi_truth
Base class of the labels used in Structured Output (SO) problems.
Class CLossFunction is the base class of all loss functions.
Definition: LossFunction.h:53
void set_labels(CStructuredLabels *labs)
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:35
#define SG_UNREF(x)
Definition: SGRefObject.h:35
virtual float64_t risk_customized_formulation(float64_t *subgrad, float64_t *W, TMultipleCPinfo *info=0)
virtual int32_t get_num_vectors() const =0
#define SG_ERROR(...)
Definition: SGIO.h:131
#define REQUIRE(x,...)
Definition: SGIO.h:208
virtual float64_t risk_1slack_slack_rescale(float64_t *subgrad, float64_t *W, TMultipleCPinfo *info=0)
virtual int32_t get_dim() const =0
CLossFunction * get_surrogate_loss() const
virtual const char * get_name() const
A generic learning machine interface.
Definition: Machine.h:138
void set_features(CFeatures *feats)
virtual float64_t risk_nslack_margin_rescale(float64_t *subgrad, float64_t *W, TMultipleCPinfo *info=0)
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:102
void set_model(CStructuredModel *model)
double float64_t
Definition: common.h:48
#define SG_REF(x)
Definition: SGRefObject.h:34
class CSOSVMHelper contains helper functions to compute primal objectives, dual objectives, average training losses, duality gaps etc. These values will be recorded to check convergence. This class is inspired by the matlab implementation of the block coordinate Frank-Wolfe SOSVM solver [1].
Definition: SOSVMHelper.h:29
virtual float64_t risk_nslack_slack_rescale(float64_t *subgrad, float64_t *W, TMultipleCPinfo *info=0)
virtual float64_t risk(float64_t *subgrad, float64_t *W, TMultipleCPinfo *info=0, EStructRiskType rtype=N_SLACK_MARGIN_RESCALING)
Class CStructuredModel that represents the application specific model and contains most of the applic...
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:16
virtual CResultSet * argmax(SGVector< float64_t > w, int32_t feat_idx, bool const training=true)=0
The class Features is the base class of all feature objects.
Definition: Features.h:62
SGVector< float64_t > psi_pred
CStructuredModel * get_model() const
static CStructuredLabels * to_structured(CLabels *base_labels)
#define SG_ADD(...)
Definition: SGObject.h:71
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:75
static void vec1_plus_scalar_times_vec2(T *vec1, const T scalar, const T *vec2, int32_t n)
In-place update vec1 = vec1 + scalar * vec2 over n elements (x = x + alpha*y).
Definition: SGVector.cpp:580
virtual float64_t risk_1slack_margin_rescale(float64_t *subgrad, float64_t *W, TMultipleCPinfo *info=0)
void set_surrogate_loss(CLossFunction *loss)

SHOGUN Machine Learning Toolbox - Documentation