1
2
3
4 """ code for dealing with forests (collections) of decision trees
5
6 **NOTE** This code should be obsolete now that ML.Composite.Composite is up and running.
7
8 """
9 from __future__ import print_function
10 from rdkit.six.moves import cPickle
11 import numpy
12 from rdkit.ML.DecTree import CrossValidate,PruneTree
13
15 """a forest of unique decision trees.
16
17 adding an existing tree just results in its count field being incremented
18 and the errors being averaged.
19
20 typical usage:
21
22 1) grow the forest with AddTree until happy with it
23
24 2) call AverageErrors to calculate the average error values
25
26 3) call SortTrees to put things in order by either error or count
27
28 """
30 """ creates a histogram of error/count pairs
31
32 """
33 nExamples = len(self.treeList)
34 histo = []
35 i = 1
36 lastErr = self.errList[0]
37 countHere = self.countList[0]
38 eps = 0.001
39 while i < nExamples:
40 if self.errList[i]-lastErr > eps:
41 histo.append((lastErr,countHere))
42 lastErr = self.errList[i]
43 countHere = self.countList[i]
44 else:
45 countHere = countHere + self.countList[i]
46 i = i + 1
47
48 return histo
49
51 """ collects votes across every member of the forest for the given example
52
53 **Returns**
54
55 a list of the results
56
57 """
58 nTrees = len(self.treeList)
59 votes = [0]*nTrees
60 for i in range(nTrees):
61 votes[i] = self.treeList[i].ClassifyExample(example)
62 return votes
63
65 """ classifies the given example using the entire forest
66
67 **returns** a result and a measure of confidence in it.
68
69 **FIX:** statistics sucks... I'm not seeing an obvious way to get
70 the confidence intervals. For that matter, I'm not seeing
71 an unobvious way.
72
73 For now, this is just treated as a voting problem with the confidence
74 measure being the percent of trees which voted for the winning result.
75 """
76 self.treeVotes = self.CollectVotes(example)
77 votes = [0]*len(self._nPossible)
78 for i in range(len(self.treeList)):
79 res = self.treeVotes[i]
80 votes[res] = votes[res] + self.countList[i]
81
82 totVotes = sum(votes)
83 res = argmax(votes)
84
85 return res,float(votes[res])/float(totVotes)
86
88 """ Returns the details of the last vote the forest conducted
89
90 this will be an empty list if no voting has yet been done
91
92 """
93 return self.treeVotes
94
def Grow(self, examples, attrs, nPossibleVals, nTries=10, pruneIt=0,
         lessGreedy=0):
    """ Grows the forest by adding trees

    **Arguments**

      - examples: the examples to be used for training

      - attrs: a list of the attributes to be used in training

      - nPossibleVals: a list with the number of possible values each variable
        (as well as the result) can take on

      - nTries: the number of new trees to add

      - pruneIt: a toggle for whether or not the tree should be pruned

      - lessGreedy: toggles the use of a less greedy construction algorithm where
        each possible tree root is used.  The best tree from each step is actually
        added to the forest.

    **Notes**

      - nPossibleVals is remembered on the instance (as _nPossible) for later use

      - each tree is built with CrossValidate.CrossValidationDriver; the value
        it returns alongside the tree (or, when pruning, the one returned by
        PruneTree.PruneTree) is what gets handed to AddTree as the tree's error

      - a 'Cycle' progress line is printed about ten times per call

    """
    self._nPossible = nPossibleVals

    # Progress interval.  Floor division keeps this an integer regardless of
    # the division semantics in effect, and the lower bound of 1 keeps the
    # modulus below well defined when fewer than ten trees are requested
    # (nTries/10 would otherwise be 0, or a fraction).
    reportEvery = max(nTries // 10, 1)

    for i in range(nTries):
        tree, frac = CrossValidate.CrossValidationDriver(examples, attrs, nPossibleVals,
                                                         silent=1, calcTotalError=1,
                                                         lessGreedy=lessGreedy)
        if pruneIt:
            tree, frac2 = PruneTree.PruneTree(tree, tree.GetTrainingExamples(),
                                              tree.GetTestExamples(),
                                              minimizeTestErrorOnly=0)
            print('prune: ', frac, frac2)
            # from here on the pruned tree's value replaces the unpruned one
            frac = frac2
        self.AddTree(tree, frac)
        if i % reportEvery == 0:
            print('Cycle: % 4d' % (i))
131
def Pickle(self, fileName='foo.pkl'):
    """ Writes this forest off to a file so that it can be easily loaded later

    **Arguments**

      - fileName: the name of the file to be written

    **Notes**

      - the file is opened with mode 'wb+', so any existing contents are
        replaced

      - the object is dumped using pickle protocol 1

      - the file is closed even if pickling fails part way through

    """
    # the context manager guarantees the handle is released on every exit
    # path, including an exception raised by dump()
    with open(fileName, 'wb+') as pFile:
        cPickle.dump(self, pFile, 1)
143
145 """ Adds a tree to the forest
146
147 If an identical tree is already present, its count is incremented
148
149 **Arguments**
150
151 - tree: the new tree
152
153 - error: its error value
154
155 **NOTE:** the errList is run as an accumulator,
156 you probably want to call AverageErrors after finishing the forest
157
158 """
159 if tree in self.treeList:
160 idx = self.treeList.index(tree)
161 self.errList[idx] = self.errList[idx]+error
162 self.countList[idx] = self.countList[idx] + 1
163 else:
164 self.treeList.append(tree)
165 self.errList.append(error)
166 self.countList.append(1)
167
169 """ convert summed error to average error
170
171 This does the conversion in place
172 """
173 self.errList = [x/y for x,y in zip(self.errList,self.countList)]
174
176 """ sorts the list of trees
177
178 **Arguments**
179
180 sortOnError: toggles sorting on the trees' errors rather than their counts
181
182 """
183 if sortOnError:
184 order = numpy.argsort(self.errList)
185 else:
186 order = numpy.argsort(self.countList)
187
188
189
190 self.treeList = [self.treeList[x] for x in order]
191 self.countList = [self.countList[x] for x in order]
192 self.errList = [self.errList[x] for x in order]
193
195 return self.treeList[i]
197 self.treeList[i] = val
198
200 return self.countList[i]
202 self.countList[i] = val
203
205 return self.errList[i]
207 self.errList[i] = val
208
210 """ returns all relevant data about a particular tree in the forest
211
212 **Arguments**
213
214 i: an integer indicating which tree should be returned
215
216 **Returns**
217
218 a 3-tuple consisting of:
219
220 1) the tree
221
222 2) its count
223
224 3) its error
225 """
226 return (self.treeList[i],self.countList[i],self.errList[i])
227
229 """ sets all relevant data for a particular tree in the forest
230
231 **Arguments**
232
233 - i: an integer indicating which tree should be returned
234
235 - tup: a 3-tuple consisting of:
236
237 1) the tree
238
239 2) its count
240
241 3) its error
242 """
243 self.treeList[i],self.countList[i],self.errList[i] = tup
244
246 """ Returns everything we know
247
248 **Returns**
249
250 a 3-tuple consisting of:
251
252 1) our list of trees
253
254 2) our list of tree counts
255
256 3) our list of tree errors
257
258 """
259 return (self.treeList,self.countList,self.errList)
260
262 """ allows len(forest) to work
263
264 """
265 return len(self.treeList)
266
268 """ allows forest[i] to work. return the data tuple
269
270 """
271 return self.GetDataTuple(which)
272
274 """ allows the forest to show itself as a string
275
276 """
277 outStr= 'Forest\n'
278 for i in range(len(self.treeList)):
279 outStr = outStr + \
280 ' Tree % 4d: % 5d occurances %%% 5.2f average error\n'%(i,self.countList[i],
281 100.*self.errList[i])
282 return outStr
283
285 self.treeList=[]
286 self.errList=[]
287 self.countList=[]
288 self.treeVotes=[]
289
if __name__ == '__main__':
    # Minimal demonstration: register the same node twice, convert the
    # accumulated errors to averages, sort, and print the forest.
    from rdkit.ML.DecTree import DecTree
    demoNode = DecTree.DecTreeNode(None, 'foo')
    forest = Forest()
    for _ in range(2):
        forest.AddTree(demoNode, 0.5)
    forest.AverageErrors()
    forest.SortTrees()
    print(forest)
299