1
2
3
4
5
6 """ handles doing cross validation with naive bayes models
7 and evaluation of individual models
8
9 """
10 from __future__ import print_function
11 from rdkit.ML.NaiveBayes.ClassificationModel import NaiveBayesClassifier
12 from rdkit.ML.Data import SplitData
13 try:
14 from rdkit.ML.FeatureSelect import CMIM
15 except ImportError:
16 CMIM=None
17
def makeNBClassificationModel(trainExamples, attrs, nPossibleValues, nQuantBounds,
                              mEstimateVal=-1.0,
                              useSigs=False,
                              ensemble=None, useCMIM=0,
                              **kwargs):
    """ constructs a NaiveBayesClassifier and trains it on a set of examples

      **Arguments**

        - trainExamples: the examples the classifier is trained on

        - attrs: the attributes handed to the classifier; replaced by
          _ensemble_ whenever that is non-empty

        - nPossibleValues, nQuantBounds: passed to the classifier unchanged

        - mEstimateVal: passed to the classifier as its _mEstimateVal_ option

        - useSigs: passed to the classifier as its _useSigs_ option; it also
          has to be true for CMIM feature selection to run

        - ensemble: optional explicit attribute selection; when it is set no
          CMIM selection is attempted

        - useCMIM: when > 0 (and the CMIM module imported successfully,
          _useSigs_ is true and no _ensemble_ was given) this value is passed
          to CMIM.SelectFeatures() and the result is used as the ensemble
          (presumably the number of features to select -- confirm against the
          CMIM module)

        - kwargs: accepted so that callers can pass extra options; not used here

      **Returns**

         the trained classifier

    """
    # feature selection is only attempted when the optional CMIM module is
    # available and the caller has not already picked an ensemble
    if (CMIM is not None and useCMIM > 0
            and useSigs and not ensemble):
        ensemble = CMIM.SelectFeatures(trainExamples, useCMIM, bvCol=1)

    # a non-empty ensemble takes precedence over the attributes passed in
    selectedAttrs = ensemble if ensemble else attrs

    classifier = NaiveBayesClassifier(selectedAttrs, nPossibleValues, nQuantBounds,
                                      mEstimateVal=mEstimateVal, useSigs=useSigs)
    classifier.SetTrainingExamples(trainExamples)
    classifier.trainModel()
    return classifier
34
36
37 nTest = len(testExamples)
38 assert nTest,'no test examples: %s'%str(testExamples)
39 badExamples = []
40 nBad = 0
41 preds = NBmodel.ClassifyExamples(testExamples, appendExamples)
42 assert len(preds) == nTest
43
44 for i in range(nTest):
45 testEg = testExamples[i]
46 trueRes = testEg[-1]
47 res = preds[i]
48
49 if (trueRes != res) :
50 badExamples.append(testEg)
51 nBad += 1
52 return float(nBad)/nTest, badExamples
53
58 nTot = len(examples)
59 if not kwargs.get('replacementSelection',0):
60 testIndices,trainIndices = SplitData.SplitIndices(nTot,holdOutFrac,
61 silent=1,legacy=1,
62 replacement=0)
63 else :
64 testIndices,trainIndices = SplitData.SplitIndices(nTot,holdOutFrac,
65 silent=1,legacy=0,
66 replacement=1)
67
68 trainExamples = [examples[x] for x in trainIndices]
69 testExamples = [examples[x] for x in testIndices]
70
71 NBmodel = modelBuilder(trainExamples, attrs, nPossibleValues, nQuantBounds,
72 mEstimateVal,**kwargs)
73
74 if not calcTotalError:
75 xValError, badExamples = CrossValidate(NBmodel, testExamples,appendExamples=1)
76 else:
77 xValError,badExamples = CrossValidate(NBmodel, examples,appendExamples=0)
78
79 if not silent:
80 print('Validation error was %%%4.2f'%(100*xValError))
81 NBmodel._trainIndices = trainIndices
82 return NBmodel, xValError
83