1
2
3
""" Contains functionality for reduced-error pruning of decision trees

"""
from __future__ import print_function

import copy
from collections import Counter

import numpy

from rdkit.ML.DecTree import CrossValidate, DecTree
from rdkit.six.moves import range
12
# Module-wide debug switch: when non-zero, _Pruner prints a trace of the
# candidate replacements it tries (the __main__ block sets it to 1).
_verbose = 0
14
def MaxCount(examples):
    """ given a set of examples, returns the most common result code

     **Arguments**

       examples: a sequence of examples to be counted; the result code of
         each example is its last element

     **Returns**

       the most common result code.  When several codes are equally
       common, the smallest of them is returned.

     **Raises**

       ValueError if _examples_ is empty

    """
    counts = Counter(example[-1] for example in examples)
    if not counts:
        raise ValueError('MaxCount() requires at least one example')
    # Highest count wins; ties go to the smallest code, which matches the
    # first-maximum behavior of numpy.argmax over the codes 0..max(code).
    return min(counts, key=lambda code: (-counts[code], code))
34
36 nWrong = 0
37 for example in node.GetExamples():
38 pred = node.ClassifyExample(example,appendExamples=0)
39 if pred != example[-1]:
40 nWrong +=1
41
42 return nWrong
43
def _PrunerLog(level, *args):
    """Print one _Pruner trace line, prefixed by the recursion level, when _verbose is set."""
    if _verbose:
        print(' ' * level, '<%d> ' % level, *args)


def _Pruner(node, level=0):
    """Recursively finds and removes the nodes whose removals improve classification

     **Arguments**

       - node: the tree to be pruned.  The pruning data should already be
         contained within node (i.e. node.GetExamples() should return the
         pruning data)

       - level: (optional) the level of recursion, used only in _verbose printing


     **Returns**

       the pruned version of node.  It is assembled from deep copies; the
       children of _node_ itself are never replaced.


     **Notes**

       - This uses a greedy algorithm which basically does a DFS traversal of
         the tree, removing nodes whenever possible.

       - If removing a node does not affect the accuracy, it *will be*
         removed.  We favor smaller trees.

       - Children that received no pruning examples are left intact.

    """
    _PrunerLog(level, '>>> Pruner')
    children = node.GetChildren()[:]

    bestTree = copy.deepcopy(node)
    # Start from the error of the unpruned node: a candidate replacement is
    # accepted only if it does not increase the local error (ties accepted,
    # which is what biases the result towards smaller trees).
    bestErr = _GetLocalError(bestTree)

    for i, child in enumerate(children):
        examples = child.GetExamples()
        if _verbose:
            _PrunerLog(level, ' Child:', i, child.GetLabel())
            bestTree.Print()
            print()
        if not len(examples):
            # Nothing to judge this child by, so it is kept unchanged.
            _PrunerLog(level, ' No Examples', len(examples))
            continue
        _PrunerLog(level, ' Examples', len(examples))
        if child.GetTerminal():
            _PrunerLog(level, ' Terminal')
            continue
        _PrunerLog(level, ' Nonterminal')

        # Each candidate is built on a fresh copy of the best tree so far.
        workTree = copy.deepcopy(bestTree)

        # Candidate 1: keep the child but prune below it (recursively).
        newNode = _Pruner(child, level=level + 1)
        workTree.ReplaceChildIndex(i, newNode)
        tempErr = _GetLocalError(workTree)
        if tempErr <= bestErr:
            bestErr = tempErr
            bestTree = copy.deepcopy(workTree)
            if _verbose:
                _PrunerLog(level, '>->->->->->')
                _PrunerLog(level, 'replacing:', i, child.GetLabel())
                child.Print()
                _PrunerLog(level, 'with:')
                newNode.Print()
                _PrunerLog(level, '<-<-<-<-<-<')
        else:
            workTree.ReplaceChildIndex(i, child)

        # Candidate 2: collapse the child into a leaf that predicts the most
        # common result code among the examples that reached it.
        bestGuess = MaxCount(examples)
        newNode = DecTree.DecTreeNode(workTree, 'L:%d' % (bestGuess),
                                      label=bestGuess, isTerminal=1)
        newNode.SetExamples(examples)
        workTree.ReplaceChildIndex(i, newNode)
        if _verbose:
            _PrunerLog(level, 'ATTEMPT:')
            workTree.Print()
        newErr = _GetLocalError(workTree)
        _PrunerLog(level, '---> ', newErr, bestErr)
        if newErr <= bestErr:
            bestErr = newErr
            bestTree = copy.deepcopy(workTree)
            if _verbose:
                _PrunerLog(level, 'PRUNING:')
                workTree.Print()
        else:
            _PrunerLog(level, 'FAIL')
            # the leaf made things worse: put the original child back
            workTree.ReplaceChildIndex(i, child)

    _PrunerLog(level, '<<< out')
    return bestTree
148
def PruneTree(tree, trainExamples, testExamples, minimizeTestErrorOnly=1):
    """ implements a reduced-error pruning of decision trees

    This algorithm is described on page 69 of Mitchell's book.

    Pruning can be done using just the set of testExamples (the validation
    set) or both the testExamples and the trainExamples by setting
    minimizeTestErrorOnly to 0.

     **Arguments**

       - tree: the initial tree to be pruned.  As a side effect its stored
         examples are cleared and then re-populated by screening the
         pruning examples through it.

       - trainExamples: the examples used to train the tree

       - testExamples: the examples held out for testing the tree

       - minimizeTestErrorOnly: if this toggle is zero, all examples (i.e.
         _trainExamples_ + _testExamples_) will be used to evaluate the error.

     **Returns**

       a 2-tuple containing:

         1) the best tree

         2) the best error (the one which corresponds to that tree)

    """
    if minimizeTestErrorOnly:
        testSet = testExamples
    else:
        # Concatenate as lists so tuples or numpy arrays are joined rather
        # than (for arrays) added element-wise.
        testSet = list(trainExamples) + list(testExamples)

    # remove any examples left over from earlier use of the tree
    tree.ClearExamples()

    # Screen the pruning data through the tree (appendExamples=1) so that
    # each node holds the examples reaching it; _Pruner works from those.
    CrossValidate.CrossValidate(tree, testSet, appendExamples=1)

    newTree = _Pruner(tree)

    # recompute the error and the misclassified examples for the pruned tree
    totErr, badEx = CrossValidate.CrossValidate(newTree, testSet)
    newTree.SetBadExamples(badEx)

    return newTree, totErr
205
206
207
208
209
def _testRandom():
    """Grow a tree on random examples, prune it, and print both errors.

    The original and pruned trees are pickled to the relative paths
    'orig.pkl' and 'prune.pkl'.
    """
    from rdkit.ML.DecTree import randomtest

    examples, attrs, nPossibleVals = randomtest.GenRandomExamples(
        nVars=10, randScale=0.5, nExamples=200)
    tree, frac = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals)
    tree.Print()
    tree.Pickle('orig.pkl')
    print('original error is:', frac)

    print('----Pruning')
    newTree, frac2 = PruneTree(tree, tree.GetTrainingExamples(),
                               tree.GetTestExamples())
    newTree.Print()
    print('pruned error is:', frac2)
    newTree.Pickle('prune.pkl')
224
225
def _testSpecific():
    """Prune a small ID3 tree and compare its holdout error before and after."""
    from rdkit.ML.DecTree import ID3

    # each row: three binary attributes followed by the result code
    oPts = [
        [0, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
    ]
    # holdout set: the training points plus two that contradict [0, 1, 1, 1]
    tPts = oPts + [[0, 1, 1, 0], [0, 1, 1, 0]]

    # pass a real list: rdkit.six.moves.range is lazy on Python 3 and lacks
    # list methods that the tree builder may rely on
    tree = ID3.ID3Boot(oPts, attrs=list(range(3)), nPossibleVals=[2] * 4)
    tree.Print()
    err, badEx = CrossValidate.CrossValidate(tree, oPts)
    print('original error:', err)

    err, badEx = CrossValidate.CrossValidate(tree, tPts)
    print('original holdout error:', err)
    newTree, frac2 = PruneTree(tree, oPts, tPts)
    newTree.Print()
    err, badEx = CrossValidate.CrossValidate(newTree, tPts)
    print('pruned holdout error is:', err)
    print(badEx)

    print(len(tree), len(newTree))
252
def _testChain():
    """Prune an ID3 tree built from a small fixed data set and print the errors."""
    from rdkit.ML.DecTree import ID3

    # each row: four binary attributes followed by the result code
    oPts = [
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [1, 0, 0, 0, 1],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 1, 1],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0],
        [0, 1, 0, 0, 1],
    ]
    # the "holdout" set is the training set itself here
    tPts = oPts

    nAttrs = len(oPts[0]) - 1
    # pass a real list: rdkit.six.moves.range is lazy on Python 3
    tree = ID3.ID3Boot(oPts, attrs=list(range(nAttrs)),
                       nPossibleVals=[2] * len(oPts[0]))
    tree.Print()
    err, badEx = CrossValidate.CrossValidate(tree, oPts)
    print('original error:', err)

    err, badEx = CrossValidate.CrossValidate(tree, tPts)
    print('original holdout error:', err)
    newTree, frac2 = PruneTree(tree, oPts, tPts)
    newTree.Print()
    err, badEx = CrossValidate.CrossValidate(newTree, tPts)
    print('pruned holdout error is:', err)
    print(badEx)
285
286
if __name__ == '__main__':
    # enable the _Pruner trace for this demo run only
    _verbose = 1
    _testChain()
292