Package rdkit :: Package ML :: Package DecTree :: Module BuildQuantTree
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.DecTree.BuildQuantTree

  1  ## Automatically adapted for numpy.oldnumeric Jun 27, 2008 by -c 
  2   
  3  # $Id$ 
  4  # 
  5  #  Copyright (C) 2001-2008  greg Landrum and Rational Discovery LLC 
  6  #  All Rights Reserved 
  7  # 
  8  """  
  9   
 10  """ 
 11  from __future__ import print_function 
 12  import numpy 
 13  import random 
 14  from rdkit.ML.DecTree import QuantTree, ID3 
 15  from rdkit.ML.InfoTheory import entropy 
 16  from rdkit.ML.Data import Quantize 
 17  from rdkit.six.moves import range 
 18   
def FindBest(resCodes, examples, nBoundsPerVar, nPossibleRes,
             nPossibleVals, attrs, exIndices=None, **kwargs):
    """ Picks the variable (and its quantization bounds) with the largest
    information gain over the selected examples.

    **Arguments**

      - resCodes: the result codes of the selected examples; handed to
        Quantize.FindVarMultQuantBounds for quantized variables

      - examples: the examples, indexed as examples[i][var]

      - nBoundsPerVar: for each variable, how it is scored:
          > 0: that many bounds are requested from
               Quantize.FindVarMultQuantBounds
          = 0: the gain is entropy.InfoGain of the table from
               ID3.GenVarTable (no bounds)
          < 0: the variable gets a gain of -1e6 and no bounds

      - nPossibleRes: the number of possible result codes

      - nPossibleVals: the number of possible values of each variable
        (used for variables with 0 bounds)

      - attrs: the candidate variable indices

      - exIndices: (optional) indices into _examples_ of the examples to
        use; defaults to all of them

      - randomDescriptors: (optional keyword) if > 0 and smaller than
        len(attrs), only a random subset of that many candidates is scored

    **Returns**

      a 3-tuple: (best variable, its gain, its quantization bounds).
      If nothing is selected the tuple is (-1, -1e6, []).

    """
    bestGain = -1e6
    best = -1
    bestBounds = []

    if exIndices is None:
        exIndices = list(range(len(examples)))

    if not len(exIndices):
        return best, bestGain, bestBounds

    nToTake = kwargs.get('randomDescriptors', 0)
    if nToTake > 0:
        nAttrs = len(attrs)
        if nToTake < nAttrs:
            ids = list(range(nAttrs))
            # the optional `random` argument of shuffle() no longer exists
            # in current Python; the default generator is the same source.
            random.shuffle(ids)
            attrs = [attrs[x] for x in ids[:nToTake]]

    for var in attrs:
        nBounds = nBoundsPerVar[var]
        if nBounds > 0:
            try:
                vTable = [examples[x][var] for x in exIndices]
            except IndexError:
                print('index error retrieving variable: %d' % var)
                raise
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(vTable, nBounds,
                                                                resCodes, nPossibleRes)
        elif nBounds == 0:
            vTable = ID3.GenVarTable((examples[x] for x in exIndices),
                                     nPossibleVals, [var])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

        if gainHere > bestGain:
            bestGain = gainHere
            bestBounds = qBounds
            best = var
        elif bestGain == gainHere:
            # on a tie, prefer the variable that needs fewer bounds
            if len(qBounds) < len(bestBounds):
                best = var
                bestBounds = qBounds

    if best == -1:
        # diagnostics only: no candidate beat the initial gain
        print('best unaltered')
        print('\tattrs:', attrs)
        print('\tnBounds:', [nBoundsPerVar[x] for x in attrs])
        print('\texamples:')
        for example in (examples[x] for x in exIndices):
            print('\t\t', example)

    return best, bestGain, bestBounds
89 90
def BuildQuantTree(examples, target, attrs, nPossibleVals, nBoundsPerVar,
                   depth=0, maxDepth=-1, exIndices=None, **kwargs):
    """
      **Arguments**

        - examples: a list of lists (nInstances x nVariables+1) of variable
          values + instance values

        - target: an int (not read by this function; recursive calls pass
          the variable that was just split on)

        - attrs: a list of ints indicating which variables can be used in the tree

        - nPossibleVals: a list containing the number of possible values of
          every variable.

        - nBoundsPerVar: the number of bounds to include for each variable

        - depth: (optional) the current depth in the tree

        - maxDepth: (optional) the maximum depth to which the tree
          will be grown

        - exIndices: (optional) indices into _examples_ of the examples that
          belong to this node; defaults to all of them

        - kwargs: forwarded to FindBest and to the recursive calls. If
          _recycleVars_ is set, the chosen variable stays available to the
          subtrees.

      **Returns**

        a QuantTree.QuantTreeNode with the decision tree

      **NOTE:** This code cannot bootstrap (start from nothing...)
            use _QuantTreeBoot_ (below) for that.
    """
    tree = QuantTree.QuantTreeNode(None, 'node')
    tree.SetData(-666)
    nPossibleRes = nPossibleVals[-1]

    if exIndices is None:
        exIndices = list(range(len(examples)))

    # counts of each result code:
    resCodes = [int(examples[idx][-1]) for idx in exIndices]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1
    nzCounts = numpy.nonzero(counts)[0]

    if len(nzCounts) == 1:
        # bottomed out because there is only one result code left
        # with any counts (i.e. there's only one type of example
        # left... this is GOOD!).
        res = nzCounts[0]
        tree.SetLabel(res)
        tree.SetName(str(res))
        tree.SetTerminal(1)
    elif len(attrs) == 0 or (maxDepth >= 0 and depth > maxDepth):
        # Bottomed out: no variables left or max depth hit
        # We don't really know what to do here, so
        # use the heuristic of picking the most prevalent
        # result
        v = numpy.argmax(counts)
        tree.SetLabel(v)
        tree.SetName('%d?' % v)
        tree.SetTerminal(1)
    else:
        # find the variable which gives us the largest information gain
        best, bestGain, bestBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                              nPossibleRes, nPossibleVals, attrs,
                                              exIndices=exIndices,
                                              **kwargs)
        # remove that variable from the lists of possible variables
        nextAttrs = list(attrs)
        if not kwargs.get('recycleVars', 0):
            nextAttrs.remove(best)

        # set some info at this node
        tree.SetName('Var: %d' % (best))
        tree.SetLabel(best)
        tree.SetQuantBounds(bestBounds)
        tree.SetTerminal(0)

        def addBranch(subset):
            # one child per bucket: recurse when the bucket has examples,
            # otherwise add a terminal node with the most prevalent result
            # at this node (empty buckets can, and do, happen).
            if len(subset) == 0:
                v = numpy.argmax(counts)
                tree.AddChild('%d' % v, label=v, data=0.0, isTerminal=1)
            else:
                tree.AddChildNode(BuildQuantTree(examples, best,
                                                 nextAttrs, nPossibleVals,
                                                 nBoundsPerVar,
                                                 depth=depth + 1, maxDepth=maxDepth,
                                                 exIndices=subset,
                                                 **kwargs))

        if len(bestBounds) > 0:
            # peel off, bound by bound, the examples below each bound.
            # A single pass per bound keeps the original ordering without
            # repeatedly deleting from the middle of a list.
            remaining = list(exIndices)
            for bound in bestBounds:
                below = []
                rest = []
                for index in remaining:
                    if examples[index][best] < bound:
                        below.append(index)
                    else:
                        rest.append(index)
                remaining = rest
                addBranch(below)
            # add the last points remaining
            addBranch(remaining)
        else:
            # unquantized variable: one child per possible value
            for val in range(nPossibleVals[best]):
                addBranch([idx for idx in exIndices if examples[idx][best] == val])
    return tree
224
def QuantTreeBoot(examples, attrs, nPossibleVals, nBoundsPerVar, initialVar=None,
                  maxDepth=-1, **kwargs):
    """ Bootstrapping code for the QuantTree

      If _initialVar_ is not set, the algorithm will automatically
      choose the first variable in the tree (the standard greedy
      approach).  Otherwise, _initialVar_ will be used as the first
      split.

      **Arguments**

        - examples: the examples; the last entry of each one is its result code

        - attrs: the variables which may be used in the tree. Variables
          whose entry in _nBoundsPerVar_ is -1 are dropped from a copy of
          this list; the argument itself is not modified.

        - nPossibleVals: the number of possible values of every variable;
          the last entry is the number of possible result codes

        - nBoundsPerVar: the number of bounds to include for each variable

        - initialVar: (optional) the variable to use for the first split

        - maxDepth: (optional) passed on to BuildQuantTree

        - kwargs: forwarded to FindBest and BuildQuantTree. If
          _recycleVars_ is set, the first variable stays available to the
          subtrees.

      **Returns**

        the root QuantTree.QuantTreeNode

    """
    attrs = list(attrs)
    for i in range(len(nBoundsPerVar)):
        if nBoundsPerVar[i] == -1 and i in attrs:
            attrs.remove(i)

    tree = QuantTree.QuantTreeNode(None, 'node')
    nPossibleRes = nPossibleVals[-1]
    tree._nResultCodes = nPossibleRes

    resCodes = [int(x[-1]) for x in examples]
    counts = [0] * nPossibleRes
    for res in resCodes:
        counts[res] += 1

    if initialVar is None:
        best, gainHere, qBounds = FindBest(resCodes, examples, nBoundsPerVar,
                                           nPossibleRes, nPossibleVals, attrs,
                                           **kwargs)
    else:
        best = initialVar
        if nBoundsPerVar[best] > 0:
            # a real list, not a lazy map() object, so the values can be
            # indexed and traversed more than once
            vTable = [x[best] for x in examples]
            qBounds, gainHere = Quantize.FindVarMultQuantBounds(vTable, nBoundsPerVar[best],
                                                                resCodes, nPossibleRes)
        elif nBoundsPerVar[best] == 0:
            vTable = ID3.GenVarTable(examples, nPossibleVals, [best])[0]
            gainHere = entropy.InfoGain(vTable)
            qBounds = []
        else:
            gainHere = -1e6
            qBounds = []

    tree.SetName('Var: %d' % (best))
    tree.SetData(gainHere)
    tree.SetLabel(best)
    tree.SetTerminal(0)
    tree.SetQuantBounds(qBounds)
    nextAttrs = list(attrs)
    if not kwargs.get('recycleVars', 0):
        nextAttrs.remove(best)

    def addBranch(subset):
        # one child per bucket: grow a subtree when the bucket has examples,
        # otherwise add a terminal node with the overall most prevalent result.
        if len(subset) != 0:
            tree.AddChildNode(BuildQuantTree(subset, best,
                                             nextAttrs, nPossibleVals,
                                             nBoundsPerVar,
                                             depth=1, maxDepth=maxDepth,
                                             **kwargs))
        else:
            v = numpy.argmax(counts)
            tree.AddChild('%d??' % (v), label=v, data=0.0, isTerminal=1)

    if len(qBounds) > 0:
        # peel off, bound by bound, the examples below each bound. A single
        # order-preserving pass per bound replaces deleting entries one at a
        # time from an index list.
        remaining = list(examples)
        for bound in qBounds:
            below = []
            rest = []
            for ex in remaining:
                if ex[best] < bound:
                    below.append(ex)
                else:
                    rest.append(ex)
            remaining = rest
            addBranch(below)
        # add the last points remaining
        addBranch(remaining)
    else:
        # unquantized variable: one child per possible value
        for val in range(nPossibleVals[best]):
            addBranch([example for example in examples if example[best] == val])
    return tree
323 324
def TestTree():
    """ demo code: grows a depth-limited ID3 tree on a small named
    data set and prints it

    """
    # columns: name, three two-valued variables, result code
    values = [(0, 1, 0, 0),
              (0, 0, 0, 1),
              (0, 0, 1, 2),
              (0, 1, 1, 2),
              (1, 0, 0, 2),
              (1, 0, 1, 2),
              (1, 1, 0, 2),
              (1, 1, 1, 0)]
    examples1 = [['p%d' % (pos + 1)] + list(row) for pos, row in enumerate(values)]

    # every column between the name and the result is a usable variable
    nColumns = len(examples1[0])
    attrs = list(range(1, nColumns - 1))
    nPossibleVals = [0, 2, 2, 2, 3]

    t1 = ID3.ID3Boot(examples1, attrs, nPossibleVals, maxDepth=1)
    t1.Print()
342 343
def TestQuantTree():
    """ demo code: grows quantized trees (one variable with a single
    bound) with and without a depth limit; each tree is pickled and printed

    """
    # columns: name, two two-valued variables, one float variable, result code
    values = [(0, 1, 0.1, 0),
              (0, 0, 0.1, 1),
              (0, 0, 1.1, 2),
              (0, 1, 1.1, 2),
              (1, 0, 0.1, 2),
              (1, 0, 1.1, 2),
              (1, 1, 0.1, 2),
              (1, 1, 1.1, 0)]
    examples1 = [['p%d' % (pos + 1)] + list(row) for pos, row in enumerate(values)]

    nColumns = len(examples1[0])
    attrs = list(range(1, nColumns - 1))
    nPossibleVals = [0, 2, 2, 0, 3]
    boundsPerVar = [0, 0, 0, 1, 0]

    # both runs write to the same pickle file, so the second overwrites the first
    runs = (('base', {}),
            ('depth limit', {'maxDepth': 1}))
    for title, extraArgs in runs:
        print(title)
        t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar, **extraArgs)
        t1.Pickle('test_data/QuantTree1.pkl')
        t1.Print()
370
def TestQuantTree2():
    """ demo code: grows a quantized tree with two single-bound float
    variables, prints and pickles it, then classifies every example

    """
    # columns: name, float variable, two-valued variable, float variable, result code
    values = [(0.1, 1, 0.1, 0),
              (0.1, 0, 0.1, 1),
              (0.1, 0, 1.1, 2),
              (0.1, 1, 1.1, 2),
              (1.1, 0, 0.1, 2),
              (1.1, 0, 1.1, 2),
              (1.1, 1, 0.1, 2),
              (1.1, 1, 1.1, 0)]
    examples1 = [['p%d' % (pos + 1)] + list(row) for pos, row in enumerate(values)]

    nColumns = len(examples1[0])
    attrs = list(range(1, nColumns - 1))
    nPossibleVals = [0, 0, 2, 0, 3]
    boundsPerVar = [0, 1, 0, 1, 0]

    t1 = QuantTreeBoot(examples1, attrs, nPossibleVals, boundsPerVar)
    t1.Print()
    t1.Pickle('test_data/QuantTree2.pkl')

    # show what the tree predicts for each of the examples
    for row in examples1:
        prediction = t1.ClassifyExample(row)
        print(row, prediction)
if __name__ == "__main__":
    # run the demos in order when executed as a script;
    # TestQuantTree2 is not part of the default run:
    # TestQuantTree2()
    for _runDemo in (TestTree, TestQuantTree):
        _runDemo()