1
2
3
4 """ handles doing cross validation with neural nets
5
6 This is, perhaps, a little misleading. For the purposes of this module,
7 cross validation == evaluating the accuracy of a net.
8
9 """
10 from __future__ import print_function
11 from rdkit.ML.Neural import Network,Trainers
12 from rdkit.ML.Data import SplitData
13 import math
14
def CrossValidate(net, testExamples, tolerance, appendExamples=0):
    """ Determines the classification error of a net on a set of test examples

    **Arguments**

      - net: a neural net (or anything supporting a _ClassifyExample()_ method)

      - testExamples: a sequence of examples to be used for testing; the last
        entry of each example is taken as its true result

      - tolerance: an example is counted as misclassified when the absolute
        difference between its true result and the net's prediction is
        greater than this value

      - appendExamples: a toggle which is ignored, it's just here to maintain
        the same API as the decision tree code.

    **Returns**

      a 2-tuple consisting of:

        1) the fraction (0.0 - 1.0) of examples the net misclassified;
           0.0 when _testExamples_ is empty

        2) a list of misclassified examples

    **Note**
      At the moment, this is specific to nets with only one output
    """
    badExamples = []
    for testEx in testExamples:
        trueRes = testEx[-1]
        # the whole example (true result included) is handed to the net
        res = net.ClassifyExample(testEx)
        if math.fabs(trueRes - res) > tolerance:
            badExamples.append(testEx)

    nTest = len(testExamples)
    if not nTest:
        # an empty test set has no misclassifications; avoids ZeroDivisionError
        return 0.0, badExamples
    return float(len(badExamples)) / nTest, badExamples
49
def CrossValidationDriver(examples, attrs=None, nPossibleVals=None, holdOutFrac=.3, silent=0,
                          tolerance=0.3, calcTotalError=0, hiddenSizes=None,
                          **kwargs):
    """ Trains a net on part of a data set and measures its error on the rest

    **Arguments**

      - examples: the full set of examples; the last entry of each example is
        its true result, the remaining entries are the net's inputs

      - attrs: a list of attributes to consider in the tree building
        *This argument is ignored*

      - nPossibleVals: a list of the number of possible values each variable can adopt
        *This argument is ignored*

      - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
        (used to calculate the error)

      - silent: a toggle used to control how much visual noise this makes as it goes.

      - tolerance: the tolerance for convergence of the net; also used as the
        threshold for counting an example as misclassified when the error
        is calculated

      - calcTotalError: if this is true the entire data set (training examples
        included) is used to calculate accuracy of the net

      - hiddenSizes: a sequence containing the size(s) of the hidden layers in the network.
        if _hiddenSizes_ is None, one hidden layer containing the same number of nodes
        as the input layer will be used

      - kwargs: if _replacementSelection_ is true, the hold-out split is drawn
        with replacement (non-legacy SplitData mode); other keywords are ignored

    **Returns**

      a 2-tuple containing:

        1) the net; its _\\_trainIndices_ attribute holds the indices of the
           examples it was trained on

        2) the cross-validation error of the net (fraction of misclassified examples)

    **Note**
      At the moment, this is specific to nets with only one output

    """
    nTot = len(examples)
    useReplacement = bool(kwargs.get('replacementSelection', 0))
    # legacy splitting is only used when sampling without replacement
    testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                       silent=1,
                                                       legacy=int(not useReplacement),
                                                       replacement=int(useReplacement))
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    nTrain = len(trainExamples)
    if not silent:
        print('Training with %d examples' % (nTrain))

    nInput = len(examples[0]) - 1
    nOutput = 1
    if hiddenSizes is None:
        hiddenSizes = [nInput]
    # list() so tuples (or any sequence) of hidden sizes are accepted
    netSize = [nInput] + list(hiddenSizes) + [nOutput]
    net = Network.Network(netSize)
    t = Trainers.BackProp()
    t.TrainOnLine(trainExamples, net, errTol=tolerance, useAvgErr=0, silent=silent)

    nTest = len(testExamples)
    if not silent:
        print('Testing with %d examples' % nTest)
    evalExamples = examples if calcTotalError else testExamples
    xValError, _ = CrossValidate(net, evalExamples, tolerance)
    if not silent:
        print('Validation error was %%%4.2f' % (100 * xValError))
    net._trainIndices = trainIndices
    return net, xValError
129