26 bool do_weighted_averaging,
30 REQUIRE(model != NULL && labs != NULL,
31 "%s::CStochasticSOSVM(): model and labels cannot be NULL!\n",
get_name());
34 "%s::CStochasticSOSVM(): number of labels should be greater than 0!\n",
get_name());
38 m_do_weighted_averaging = do_weighted_averaging;
42 void CStochasticSOSVM::init()
52 m_do_weighted_averaging =
true;
53 m_debug_multiplier = 0;
68 SG_DEBUG(
"Entering CStochasticSOSVM::train_machine.\n");
76 SG_DEBUG(
"The training setup is correct.\n");
90 if (m_do_weighted_averaging)
103 int32_t debug_iter = 1;
104 if (m_debug_multiplier == 0)
107 m_debug_multiplier = 100;
114 for (int32_t pi = 0; pi < m_num_iter; ++pi)
116 for (int32_t si = 0; si < N; ++si)
133 w_s.
scale(1.0 / (N*m_lambda));
143 if (m_do_weighted_averaging)
157 if (m_do_weighted_averaging)
158 w_debug = w_avg.
clone();
165 SG_DEBUG(
"pass %d (iteration %d), SVM primal = %f, train_error = %f \n",
166 pi, k, primal, train_error);
170 debug_iter =
CMath::min(debug_iter+N, debug_iter*(1+m_debug_multiplier/100));
175 if (m_do_weighted_averaging)
181 SG_DEBUG(
"Leaving CStochasticSOSVM::train_machine.\n");
202 m_num_iter = num_iter;
207 return m_debug_multiplier;
212 m_debug_multiplier = multiplier;
222 m_rand_seed = rand_seed;
SGVector< float64_t > psi_truth
Base class of the labels used in Structured Output (SO) problems
uint32_t get_rand_seed() const
void set_debug_multiplier(int32_t multiplier)
virtual bool train_machine(CFeatures *data=NULL)
static float64_t primal_objective(SGVector< float64_t > w, CStructuredModel *model, float64_t lbda)
CStructuredModel * m_model
virtual int32_t get_dim() const =0
void set_num_iter(int32_t num_iter)
int32_t get_debug_multiplier() const
void add(const SGVector< T > x)
static float64_t average_loss(SGVector< float64_t > w, CStructuredModel *model)
float64_t get_lambda() const
int32_t get_num_iter() const
void set_rand_seed(uint32_t rand_seed)
virtual void init_training()
void scale(T alpha)
scale vector inplace
virtual const char * get_name() const
void set_features(CFeatures *f)
static void init_random(uint32_t initseed=0)
class CSOSVMHelper contains helper functions to compute primal objectives, dual objectives, average training losses, duality gaps etc. These values will be recorded to check convergence. This class is inspired by the matlab implementation of the block coordinate Frank-Wolfe SOSVM solver [1].
Class CStructuredModel that represents the application specific model and contains most of the applic...
all of classes and functions are contained in the shogun namespace
virtual CResultSet * argmax(SGVector< float64_t > w, int32_t feat_idx, bool const training=true)=0
virtual void add_debug_info(float64_t primal, float64_t eff_pass, float64_t train_error, float64_t dual=-1, float64_t dgap=-1)
The class Features is the base class of all feature objects.
static T min(T a, T b)
return the minimum of two integers
SGVector< float64_t > m_w
virtual int32_t get_num_labels() const
void set_lambda(float64_t lbda)
SGVector< float64_t > psi_pred
virtual EMachineType get_classifier_type()
static CStructuredLabels * to_structured(CLabels *base_labels)
virtual bool check_training_setup() const
SGVector< T > clone() const