Package rdkit :: Package ML :: Package Neural :: Module CrossValidate
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.Neural.CrossValidate

  1  # 
  2  #  Copyright (C) 2000  greg Landrum 
  3  # 
  4  """ handles doing cross validation with neural nets 
  5   
  6  This is, perhaps, a little misleading.  For the purposes of this module, 
  7  cross validation == evaluating the accuracy of a net. 
  8   
  9  """ 
 10  from __future__ import print_function 
 11  from rdkit.ML.Neural import Network,Trainers 
 12  from rdkit.ML.Data import SplitData 
 13  import math 
 14   
def CrossValidate(net, testExamples, tolerance, appendExamples=0):
    """ Determines the classification error for the testExamples

    **Arguments**

      - net: a neural net (or anything supporting a _ClassifyExample()_ method)

      - testExamples: a list of examples to be used for testing.  The last
        element of each example is taken to be its true result.

      - tolerance: the largest absolute difference between the true result
        and the value returned by the net for which an example is still
        counted as correctly classified

      - appendExamples: a toggle which is ignored, it's just here to maintain
        the same API as the decision tree code.

    **Returns**

      a 2-tuple consisting of:

        1) the error of the net, as the fraction (0.0 - 1.0) of the examples
           which were misclassified.  This is 0.0 if there are no examples.

        2) a list of misclassified examples

    **Note**
      At the moment, this is specific to nets with only one output

    """
    nTest = len(testExamples)
    if not nTest:
        # nothing to test (e.g. an empty hold-out set): report no error
        # instead of dividing by zero below
        return 0.0, []

    badExamples = [ex for ex in testExamples
                   if math.fabs(ex[-1] - net.ClassifyExample(ex)) > tolerance]
    return float(len(badExamples)) / nTest, badExamples
49
def CrossValidationDriver(examples, attrs=[], nPossibleVals=[], holdOutFrac=.3, silent=0,
                          tolerance=0.3, calcTotalError=0, hiddenSizes=None,
                          **kwargs):
    """ Builds and trains a net on part of the examples, then evaluates it

    **Arguments**

      - examples: the full set of examples.  The last element of each example
        is its result; the remaining elements are the inputs to the net.

      - attrs: a list of attributes to consider in the tree building
         *This argument is ignored*

      - nPossibleVals: a list of the number of possible values each variable can adopt
         *This argument is ignored*

      - holdOutFrac: the fraction of the data which should be reserved for the hold-out set
         (used to calculate the error)

      - silent: a toggle used to control how much visual noise this makes as it goes.

      - tolerance: the tolerance for convergence of the net.  The same value
         is passed to _CrossValidate()_ as the classification tolerance.

      - calcTotalError: if this is true the entire data set is used to calculate
         accuracy of the net

      - hiddenSizes: a sequence containing the size(s) of the hidden layers in the network.
         if _hiddenSizes_ is None, one hidden layer containing the same number of nodes
         as the input layer will be used

      - kwargs: only _replacementSelection_ is read here; if it is true the
         training/hold-out split is made by selection with replacement

    **Returns**

       a 2-tuple containing:

         1) the net (the indices of the training examples are stored on it
            as _net._trainIndices_)

         2) the cross-validation error of the net

    **Note**
      At the moment, this is specific to nets with only one output

    """
    # NOTE: the mutable defaults of attrs and nPossibleVals are kept so the
    # signature is unchanged; neither argument is read or modified here.
    nTot = len(examples)
    if not kwargs.get('replacementSelection', 0):
        testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                           silent=1, legacy=1,
                                                           replacement=0)
    else:
        testIndices, trainIndices = SplitData.SplitIndices(nTot, holdOutFrac,
                                                           silent=1, legacy=0,
                                                           replacement=1)
    trainExamples = [examples[x] for x in trainIndices]
    testExamples = [examples[x] for x in testIndices]

    nTrain = len(trainExamples)
    if not silent:
        print('Training with %d examples' % (nTrain))

    # every element of an example except the last is an input
    nInput = len(examples[0]) - 1
    nOutput = 1
    if hiddenSizes is None:
        nHidden = nInput
        netSize = [nInput, nHidden, nOutput]
    else:
        # list() lets callers pass a tuple (or any other sequence) of sizes
        netSize = [nInput] + list(hiddenSizes) + [nOutput]
    net = Network.Network(netSize)
    t = Trainers.BackProp()
    t.TrainOnLine(trainExamples, net, errTol=tolerance, useAvgErr=0, silent=silent)

    nTest = len(testExamples)
    if not silent:
        print('Testing with %d examples' % nTest)
    if not calcTotalError:
        xValError, _ = CrossValidate(net, testExamples, tolerance)
    else:
        # the whole data set is _examples_; the name previously used here
        # (allExamples) was never defined and raised a NameError
        xValError, _ = CrossValidate(net, examples, tolerance)
    if not silent:
        print('Validation error was %%%4.2f' % (100 * xValError))
    net._trainIndices = trainIndices
    return net, xValError
129