Shark machine learning library
About Shark
News!
Contribute
Credits and copyright
Downloads
Getting Started
Installation
Using the docs
Documentation
Tutorials
Quick references
Class list
Global functions
FAQ
Showroom
Main Page
Related Pages
Modules
Namespaces
Classes
Files
File List
File Members
obj-x86_64-linux-gnu
examples
Supervised
quickstartTutorial.cpp
Go to the documentation of this file.
#include <cstdlib>
#include <exception>
#include <iostream>

#include <shark/Data/Csv.h>
#include <shark/Algorithms/Trainers/LDA.h>

using namespace shark;
using namespace std;
int
main
(
int
argc,
char
**argv){
11
//create a Dataset from the file "quickstartData"
12
if
(argc < 2) {
13
cerr <<
"usage: "
<< argv[0] <<
" (filename)"
<< endl;
14
exit(EXIT_FAILURE);
15
}
16
17
ClassificationDataset
data;
18
try
{
19
importCSV
(data, argv[1],
LAST_COLUMN
,
' '
);
20
}
21
catch
(...) {
22
cerr <<
"unable to read data from file "
<< argv[1] << endl;
23
exit(EXIT_FAILURE);
24
}
25
26
//create a test and training partition of the data
27
ClassificationDataset
test =
splitAtElement
(data,static_cast<std::size_t>(0.8*data.
numberOfElements
()));
28
29
//create a classifier for the problem
30
LinearClassifier<>
classifier;
31
//create the lda trainer
32
LDA
lda;
33
//train the classifier using the training portion of the Data
34
lda.
train
(classifier,data);
35
36
37
//now use the test data to evaluate the model
38
//loop over all points of the test set
39
//be aware that in this example a single point consists of an input and a label
40
//this code here is just for illustration purposes
41
unsigned
int
correct = 0;
42
BOOST_FOREACH(
ClassificationDataset::element_reference
point, test.
elements
()){
43
unsigned
int
result = classifier(point.input);
44
if
(result == point.label){
45
correct++;
46
}
47
}
48
49
//print results
50
cout <<
"RESULTS: "
<< endl;
51
cout <<
"========\n"
<< endl;
52
cout <<
"test data size: "
<< test.
numberOfElements
() << endl;
53
cout <<
"correct classification: "
<< correct << endl;
54
cout <<
"error rate: "
<< 1.0 - double(correct)/test.
numberOfElements
() << endl;
55
}