4 """ ID3 Decision Trees
5
6 contains an implementation of the ID3 decision tree algorithm
7 as described in Tom Mitchell's book "Machine Learning"
8
9 It relies upon the _Tree.TreeNode_ data structure (or something
10 with the same API) defined locally to represent the trees
11
12 """
13
14 import numpy
15 from rdkit.ML.DecTree import DecTree
16 from rdkit.ML.InfoTheory import entropy
17 from rdkit.six.moves import range, xrange


def CalcTotalEntropy(examples, nPossibleVals):
  """ Calculates the total entropy of the data set (w.r.t. the results)

  **Arguments**

    - examples: a list (nInstances long) of lists of variable values + instance
      values

    - nPossibleVals: a list (nVars long) of the number of possible values each
      variable can adopt.

  **Returns**

    a float containing the informational entropy of the data set.

  """
  nRes = nPossibleVals[-1]
  resList = numpy.zeros(nRes, 'i')
  for example in examples:
    res = int(example[-1])
    resList[res] += 1
  return entropy.InfoEntropy(resList)
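
# Illustrative usage sketch (not part of the original module; the data set
#  here is hypothetical).  For a toy data set whose result column is split
#  evenly between two result codes, the entropy is one bit:
#
#   >>> examples = [[0, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 0]]
#   >>> nPossibleVals = [2, 2, 2]
#   >>> CalcTotalEntropy(examples, nPossibleVals)
#   1.0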


def GenVarTable(examples, nPossibleVals, vars):
  """Generates a list of variable tables for the examples passed in.

  The table for a given variable records the number of times each possible
  value of that variable appears for each possible result of the function.

  **Arguments**

    - examples: a list (nInstances long) of lists of variable values + instance
      values

    - nPossibleVals: a list containing the number of possible values of
      each variable + the number of values of the function.

    - vars: a list of the variables to include in the var table

  **Returns**

    a list of variable result tables. Each table is a numpy array
    which is varValues x nResults

  """
  nVars = len(vars)
  res = [None] * nVars
  nFuncVals = nPossibleVals[-1]

  for i in range(nVars):
    res[i] = numpy.zeros((nPossibleVals[vars[i]], nFuncVals), 'i')
  for example in examples:
    val = int(example[-1])
    for i in range(nVars):
      res[i][int(example[vars[i]]), val] += 1

  return res
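
# Illustrative sketch (not part of the original module): with the toy data
#  set from above, the table for variable 0 counts how often each of its
#  values co-occurs with each result code (rows: variable values, columns:
#  result codes):
#
#   >>> examples = [[0, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 0]]
#   >>> GenVarTable(examples, [2, 2, 2], [0])[0]
#   array([[2, 0],
#          [0, 2]], dtype=int32)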


def ID3(examples, target, attrs, nPossibleVals, depth=0, maxDepth=-1, **kwargs):
  """ Implements the ID3 algorithm for constructing decision trees.

  From Mitchell's book, page 56

  This is *slightly* modified from Mitchell's book because it supports
  multivalued (non-binary) results.

  **Arguments**

    - examples: a list (nInstances long) of lists of variable values + instance
      values

    - target: an int with the index of the variable that was split upon to
      produce _examples_

    - attrs: a list of ints indicating which variables can be used in the tree

    - nPossibleVals: a list containing the number of possible values of
      every variable.

    - depth: (optional) the current depth in the tree

    - maxDepth: (optional) the maximum depth to which the tree
      will be grown

  **Returns**

    a DecTree.DecTreeNode with the decision tree

  **NOTE:** This code cannot bootstrap (start from nothing...);
    use _ID3Boot_ (below) for that.
  """
  varTable = GenVarTable(examples, nPossibleVals, attrs)
  tree = DecTree.DecTreeNode(None, 'node')

  # store the total entropy of the data set at this node; it is useful
  #  for inspecting the finished tree
  totEntropy = CalcTotalEntropy(examples, nPossibleVals)
  tree.SetData(totEntropy)

  # the table of results for the target variable:
  tMat = GenVarTable(examples, nPossibleVals, [target])[0]
  # counts of each result code across all examples:
  counts = tMat.sum(axis=0)
  nzCounts = numpy.nonzero(counts)[0]

  if len(nzCounts) == 1:
    # bottomed out because there is only one result code left with
    #  any counts (i.e. there's only one kind of example left):
    #  make this a terminal node with that result
    res = nzCounts[0]
    tree.SetLabel(res)
    tree.SetName(str(res))
    tree.SetTerminal(1)
  elif len(attrs) == 0 or (maxDepth >= 0 and depth >= maxDepth):
    # bottomed out because there are no variables left or the maximum
    #  depth has been hit: use the heuristic of picking the most
    #  prevalent result code
    v = numpy.argmax(counts)
    tree.SetLabel(v)
    tree.SetName('%d?' % v)
    tree.SetTerminal(1)
  else:
    # find the variable which gives the largest information gain
    gains = [entropy.InfoGain(x) for x in varTable]
    best = attrs[numpy.argmax(gains)]

    # remove that variable from the list available to the subtrees
    #  (unless variables are being recycled)
    nextAttrs = attrs[:]
    if not kwargs.get('recycleVars', 0):
      nextAttrs.remove(best)

    # set up an interior node labelled with the best variable
    tree.SetName('Var: %d' % best)
    tree.SetLabel(best)
    tree.SetTerminal(0)

    # loop over the possible values of the best variable and build
    #  a subtree for each one
    for val in range(nPossibleVals[best]):
      nextExamples = []
      for example in examples:
        if example[best] == val:
          nextExamples.append(example)
      if len(nextExamples) == 0:
        # this value of the variable has no examples, so there's no
        #  sense in recursing; add a terminal node with the most
        #  prevalent result code
        v = numpy.argmax(counts)
        tree.AddChild('%d' % v, label=v, data=0.0, isTerminal=1)
      else:
        # recurse on the examples that have this value
        tree.AddChildNode(
          ID3(nextExamples, best, nextAttrs, nPossibleVals, depth + 1, maxDepth, **kwargs))
  return tree


def ID3Boot(examples, attrs, nPossibleVals, initialVar=None, depth=0, maxDepth=-1, **kwargs):
  """ Bootstrapping code for the ID3 algorithm

    see ID3 for descriptions of the arguments

    If _initialVar_ is not set, the algorithm will automatically
    choose the first variable in the tree (the standard greedy
    approach).  Otherwise, _initialVar_ will be used as the first
    split.

  """
  totEntropy = CalcTotalEntropy(examples, nPossibleVals)
  varTable = GenVarTable(examples, nPossibleVals, attrs)

  tree = DecTree.DecTreeNode(None, 'node')
  # store the number of possible result codes so that the finished
  #  tree can be used for classification
  tree._nResultCodes = nPossibleVals[-1]

  # pick the first split: either the user-specified variable or the
  #  one with the largest information gain
  if initialVar is None:
    best = attrs[numpy.argmax([entropy.InfoGain(x) for x in varTable])]
  else:
    best = initialVar

  tree.SetName('Var: %d' % best)
  tree.SetData(totEntropy)
  tree.SetLabel(best)
  tree.SetTerminal(0)
  nextAttrs = list(attrs)
  if not kwargs.get('recycleVars', 0):
    nextAttrs.remove(best)

  # build a subtree for each possible value of the first variable
  for val in range(nPossibleVals[best]):
    nextExamples = []
    for example in examples:
      if example[best] == val:
        nextExamples.append(example)

    tree.AddChildNode(
      ID3(nextExamples, best, nextAttrs, nPossibleVals, depth, maxDepth, **kwargs))
  return tree
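
# Illustrative usage sketch (not part of the original module; the data set
#  and variable indices are hypothetical).  ID3Boot grows the whole tree,
#  and the resulting DecTree.DecTreeNode can then classify new examples via
#  its ClassifyExample() method:
#
#   >>> examples = [[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 0]]  # XOR
#   >>> tree = ID3Boot(examples, attrs=[0, 1], nPossibleVals=[2, 2, 2])
#   >>> tree.ClassifyExample([1, 0, -1])  # the result slot is not consulted
#   1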