1
2
3
4
5
6
7
8 """
9
10 """
11 from __future__ import print_function
12 import numpy
13 from rdkit.ML.DecTree import SigTree
14 from rdkit.ML import InfoTheory
15 try:
16 from rdkit.ML.FeatureSelect import CMIM
17 except ImportError:
18 CMIM=None
19 from rdkit.DataStructs.VectCollection import VectCollection
20 import copy
21 import random
def _GenerateRandomEnsemble(nToInclude, nBits):
    """ Generates a random subset of a group of indices

    **Arguments**

      - nToInclude: the size of the desired set

      - nBits: the number of indices to choose from; the candidates
        are 0 .. nBits-1 (nBits itself is never returned)

    **Returns**

      a list of distinct indices, in random order.  If nToInclude is
      larger than nBits, every index is returned instead of raising.

    """
    # random.sample() raises ValueError when asked for more items than the
    # population holds; asking for "at least everything" simply means
    # "everything", so clamp the request.
    return random.sample(range(nBits), min(nToInclude, nBits))
40
def BuildSigTree(examples, nPossibleRes, ensemble=None, random=0,
                 metric=InfoTheory.InfoType.BIASENTROPY,
                 biasList=[1],
                 depth=0, maxDepth=-1,
                 useCMIM=0, allowCollections=False,
                 verbose=0, **kwargs):
    """ Recursively builds a decision tree that splits on signature bits

    **Arguments**

      - examples: the examples to be classified.  Each example
        should be a sequence at least three entries long, with
        entry 0 being a label, entry 1 a BitVector and entry -1
        an activity value

      - nPossibleRes: the number of result codes possible

      - ensemble: (optional) if this argument is provided, it
        should be a sequence which is used to limit the bits
        which are actually considered as potential descriptors.
        The default is None (use all bits).

      - random: (optional) If this argument is nonzero, it
        specifies the number of bits to be randomly selected
        for consideration at this node (i.e. this toggles the
        growth of Random Trees).  If an ensemble is in effect the
        random selection is drawn from the ensemble.
        The default is 0 (no random descriptor selection)

      - metric: (optional) This is an _InfoTheory.InfoType_ and
        sets the metric used to rank the bits.
        The default is _InfoTheory.InfoType.BIASENTROPY_

      - biasList: (optional) If provided, this provides a bias
        list for the bit ranker.
        See the _InfoTheory.InfoBitRanker_ docs for an explanation
        of bias.
        The default value is [1], which biases towards actives.
        (The default list is shared between calls; it is only read
        here, never modified.)

      - maxDepth: (optional) the maximum depth to which the tree
        will be grown; a node whose depth exceeds maxDepth becomes a
        terminal node labelled with the majority result code.
        The default is -1 (no depth limit).

      - useCMIM: (optional) if this is >0, the CMIM algorithm
        (conditional mutual information maximization) will be
        used to select the descriptors used to build the trees.
        The value of the variable should be set to the number
        of descriptors to be used.  This option and the
        ensemble option are mutually exclusive (CMIM will not be
        used if the ensemble is set), but it happily coexists
        with the random argument (to only consider random subsets
        of the top N CMIM bits)
        The default is 0 (do not use CMIM)

      - allowCollections: (optional) if true and the signatures are
        VectCollection instances, each example that has the chosen
        bit set is handed to the child node with a copy of its
        collection on which DetachVectsNotMatchingBit() was called.
        The default is False.

      - depth: (optional) the current depth in the tree
        This is used in the recursion and should not be set
        by the client.

      - verbose: (optional) if nonzero, progress is printed.

      - kwargs: accepted for call compatibility and ignored.

    **Returns**

      a SigTree.SigTreeNode with the root of the decision tree

    """
    if verbose:
        print(' ' * depth, 'Build')
    tree = SigTree.SigTreeNode(None, 'node', level=depth)
    tree.SetData(-666)

    # Tally how many examples carry each result code.
    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1

    nzCounts = numpy.nonzero(counts)[0]
    if verbose:
        print(' ' * depth, '\tcounts:', counts)

    if len(nzCounts) == 1:
        # Only one result code is present: this node is a pure leaf.
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
        return tree

    if maxDepth >= 0 and depth > maxDepth:
        # Depth limit reached: stop here and vote with the majority.
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
        return tree

    # Otherwise rank the candidate bits and split on the best one.
    fp = examples[0][1]
    nBits = fp.GetNumBits()
    ranker = InfoTheory.InfoBitRanker(nBits, nPossibleRes, metric)
    if biasList:
        ranker.SetBiasList(biasList)
    if CMIM is not None and useCMIM > 0 and not ensemble:
        ensemble = CMIM.SelectFeatures(examples, useCMIM, bvCol=1)

    # Work out which bits the ranker is allowed to look at.
    if random:
        if ensemble:
            if len(ensemble) > random:
                # picks are positions within the ensemble; translate them
                # back into the bit ids stored there.
                picks = _GenerateRandomEnsemble(random, len(ensemble))
                availBits = [ensemble[idx] for idx in picks]
            else:
                # The ensemble is no larger than the requested sample, so
                # every ensemble bit (not its position) is available.
                availBits = list(ensemble)
        else:
            availBits = _GenerateRandomEnsemble(random, nBits)
    elif ensemble:
        # The ensemble restricts the candidate bits even when no random
        # sub-selection was requested.
        availBits = list(ensemble)
    else:
        availBits = None
    if availBits:
        ranker.SetMaskBits(availBits)

    useCollections = isinstance(examples[0][1], VectCollection)
    for example in examples:
        if not useCollections:
            ranker.AccumulateVotes(example[1], example[-1])
        else:
            # Reset() is called before reading orVect; presumably it
            # refreshes the combined vector -- confirm in VectCollection.
            example[1].Reset()
            ranker.AccumulateVotes(example[1].orVect, example[-1])

    try:
        bitInfo = ranker.GetTopN(1)[0]
        best = int(bitInfo[0])
        gain = bitInfo[1]
    except Exception:
        # Best effort: report the failure and fall through to the
        # "no useful split" leaf below instead of aborting the build.
        import traceback
        traceback.print_exc()
        print('get top n failed')
        gain = -1.0
    if gain <= 0.0:
        # No bit yields a positive gain: make a majority-vote leaf.
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('?%d?' % v)
        tree.SetTerminal(1)
        return tree

    if verbose:
        print(' ' * depth, '\tbest:', bitInfo)

    tree.SetName('Bit-%d' % (best))
    tree.SetLabel(best)
    tree.SetTerminal(0)

    # Partition the examples on the chosen bit.
    onExamples = []
    offExamples = []
    for example in examples:
        if example[1][best]:
            if allowCollections and useCollections:
                # Work on a copy so the caller's collection is untouched.
                sig = copy.copy(example[1])
                sig.DetachVectsNotMatchingBit(best)
                trimmed = [example[0], sig]
                if len(example) > 2:
                    trimmed.extend(example[2:])
                example = trimmed
            onExamples.append(example)
        else:
            offExamples.append(example)

    # Children are added "off" branch first, then "on" branch.
    for subset in (offExamples, onExamples):
        if len(subset) == 0:
            v = numpy.argmax(counts)
            tree.AddChild('%d??' % v, label=v, data=0.0, isTerminal=1)
        else:
            # NOTE(review): useCMIM and allowCollections are not passed
            # down, so they only take effect at the node where the caller
            # supplied them (the CMIM-selected ensemble itself is passed
            # on).  Confirm whether allowCollections should propagate.
            child = BuildSigTree(subset, nPossibleRes, random=random,
                                 ensemble=ensemble,
                                 metric=metric, biasList=biasList,
                                 depth=depth + 1, maxDepth=maxDepth,
                                 verbose=verbose)
            if child is None:
                # Defensive fallback; kept from the original code.
                v = numpy.argmax(counts)
                tree.AddChild('%d???' % v, label=v, data=0.0, isTerminal=1)
            else:
                tree.AddChildNode(child)
    return tree
225
226
def SigTreeBuilder(examples, attrs, nPossibleVals, initialVar=None, ensemble=None,
                   randomDescriptors=0,
                   **kwargs):
    """ Thin adapter that hands the examples to BuildSigTree

    **Arguments**

      - examples: the examples to be classified (passed straight through)

      - attrs: accepted but not used here

      - nPossibleVals: a sequence whose last entry is taken as the
        number of possible result codes

      - initialVar: accepted but not used here

      - ensemble: accepted but not used here
        NOTE(review): this value is not forwarded to BuildSigTree, so
        an ensemble supplied by the caller has no effect -- confirm
        whether that is intended.

      - randomDescriptors: forwarded as BuildSigTree's _random_ argument

      - kwargs: forwarded unchanged to BuildSigTree

    **Returns**

      whatever BuildSigTree returns (the root of the tree)

    """
    return BuildSigTree(examples, nPossibleVals[-1],
                        random=randomDescriptors, **kwargs)
232