Package rdkit :: Package ML :: Module GrowComposite
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.GrowComposite

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2003-2006  greg Landrum and Rational Discovery LLC 
  4  # 
  5  #   @@ All Rights Reserved @@ 
  6  #  This file is part of the RDKit. 
  7  #  The contents are covered by the terms of the BSD license 
  8  #  which is included in the file license.txt, found at the root 
  9  #  of the RDKit source tree. 
 10  # 
 11   
 12  """ command line utility for growing composite models 
 13   
 14  **Usage** 
 15   
 16    _GrowComposite [optional args] filename_ 
 17   
 18  **Command Line Arguments** 
 19   
 20    - -n *count*: number of new models to build 
 21   
 22    - -C *pickle file name*:  name of file containing composite upon which to build. 
 23   
 24    - --inNote *note*: note to be used in loading composite models from the database 
 25        for growing 
 26   
 27    - --balTable *table name*:  table from which to take the original data set 
 28       (for balancing) 
 29   
 30    - --balWeight *weight*: (between 0 and 1) weighting factor for the new data 
 31       (for balancing). OR, *weight* can be a list of weights 
 32   
 33    - --balCnt *count*: number of individual models in the balanced composite 
 34       (for balancing) 
 35   
 36    - --balH: use only the holdout set from the original data set in the balancing 
 37       (for balancing) 
 38   
 39    - --balT: use only the training set from the original data set in the balancing 
 40       (for balancing) 
 41   
 42    - -S: shuffle the original data set 
 43       (for balancing) 
 44   
 45    - -r: randomize the activities of the original data set 
 46       (for balancing) 
 47   
 48    - -N *note*: note to be attached to the grown composite when it's saved in the 
 49       database 
 50   
 51    - --outNote *note*: equivalent to -N 
 52   
 53    - -o *filename*: name of an output file to hold the pickled composite after 
 54       it has been grown. 
 55       If multiple balance weights are used, the weights will be added to 
 56       the filenames. 
 57   
 58    - -L *limit*: provide an (integer) limit on individual model complexity 
 59     
 60    - -d *database name*: instead of reading the data from a QDAT file, 
 61       pull it from a database.  In this case, the _filename_ argument 
 62       provides the name of the database table containing the data set. 
 63   
 64    - -p *tablename*: store persistence data in the database 
 65       in table *tablename* 
 66   
 67    - -l: locks the random number generator to give consistent sets 
 68       of training and hold-out data.  This is primarily intended 
 69       for testing purposes. 
 70   
 71    - -g: be less greedy when training the models. 
 72   
 73    - -G *number*: force trees to be rooted at descriptor *number*. 
 74   
 75    - -D: show a detailed breakdown of the composite model performance 
 76       across the training and, when appropriate, hold-out sets. 
 77        
 78    - -t *threshold value*: use high-confidence predictions for the final 
 79       analysis of the hold-out data. 
 80   
 81    - -q *list string*:  Add QuantTrees to the composite and use the list 
 82       specified in *list string* as the number of target quantization 
 83       bounds for each descriptor.  Don't forget to include 0's at the 
 84       beginning and end of *list string* for the name and value fields. 
 85       For example, if there are 4 descriptors and you want 2 quant bounds 
 86       apiece, you would use _-q "[0,2,2,2,2,0]"_. 
 87       Two special cases: 
 88         1) If you would like to ignore a descriptor in the model building, 
 89            use '-1' for its number of quant bounds. 
 90         2) If you have integer valued data that should not be quantized 
 91            further, enter 0 for that descriptor. 
 92   
 93    - -V: print the version number and exit 
 94   
 95  """ 
 96  from __future__ import print_function 
 97  from rdkit import RDConfig 
 98  import numpy 
 99  from rdkit.ML.Data import DataUtils,SplitData 
100  from rdkit.ML import ScreenComposite,BuildComposite 
101  from rdkit.ML.Composite import AdjustComposite 
102  from rdkit.Dbase.DbConnection import DbConnect 
103  from rdkit.ML import CompositeRun 
104  from rdkit.six.moves import cPickle 
105  import sys,time,types 
106   
107  _runDetails = CompositeRun.CompositeRun() 
108   
109  __VERSION_STRING="0.5.0" 
110   
111  _verbose = 1 
def message(msg):
  """ writes one line of text to _sys.stdout_

    Modules which import this one may replace this function in order to
    redirect the output somewhere else.

    Nothing is written when the module-level _verbose_ flag is false.

  **Arguments**

    - msg: the text to be displayed (a newline is appended)

  """
  if not _verbose:
    return
  sys.stdout.write('%s\n' % msg)
122
def GrowIt(details,composite,progressCallback=None,
           saveIt=1,setDescNames=0,data=None):
  """ does the actual work of growing a composite model

  **Arguments**

    - details:  a _CompositeRun.CompositeRun_ object containing details
      (options, parameters, etc.) about the run

    - composite: the composite model to grow

    - progressCallback: (optional) passed through to _composite.Grow()_ when
      trees are being built.

    - saveIt: (optional) accepted for backwards compatibility; this function
      does not currently use it (pickling is left to the caller).

    - setDescNames: (optional) if nonzero and trees are being used, the
      composite's _SetInputOrder()_ method is called with the results of the
      data set's _GetVarNames()_ method.

    - data: (optional) the data set to be used.  If this is not provided, the
      data set described in details will be used.

  **Returns**

    the enlarged composite model (the same object which was passed in)

  **Notes**

    - _details.rundate_ is always updated; _details.outName_ and
      _details.tableName_ may be updated when the data set is loaded here.

    - the summary statistics (_overall_error_, etc.) are written to the
      module-level __runDetails_ object, not to _details_.
      NOTE(review): this looks unintentional when a different details object
      is passed in, but it is preserved because callers may rely on it.

  """
  details.rundate = time.asctime()

  if data is None:
    fName = details.tableName.strip()
    if details.outName == '':
      details.outName = fName + '.pkl'
    if details.dbName == '':
      data = DataUtils.BuildQuantDataSet(fName)
    elif details.qBounds != []:
      details.tableName = fName
      data = details.GetDataSet()
    else:
      data = DataUtils.DBToQuantData(details.dbName,fName,quantName=details.qTableName,
                                     user=details.dbUser,password=details.dbPassword)

  # reuse the seed stored on the composite before any randomization is done
  DataUtils.InitRandomNumbers(composite._randomSeed)
  if details.shuffleActivities == 1:
    DataUtils.RandomizeActivities(data,shuffle=1,runDetails=details)
  elif details.randomActivities == 1:
    DataUtils.RandomizeActivities(data,shuffle=0,runDetails=details)

  # the entire data set is used for training
  namedExamples = data.GetNamedData()
  trainExamples = namedExamples
  nExamples = len(trainExamples)
  message('Training with %d examples'%(nExamples))
  message('\t%d descriptors'%(len(trainExamples[0])-2))
  nVars = data.GetNVars()
  nPossibleVals = composite.nPossibleVals
  # materialize the list: under python3 range() is a lazy object
  attrs = list(range(1,nVars+1))

  if details.useTrees:
    from rdkit.ML.DecTree import CrossValidate,PruneTree
    if details.qBounds != []:
      from rdkit.ML.DecTree import BuildQuantTree
      builder = BuildQuantTree.QuantTreeBoot
    else:
      from rdkit.ML.DecTree import ID3
      builder = ID3.ID3Boot
    driver = CrossValidate.CrossValidationDriver
    pruner = PruneTree.PruneTree

    if setDescNames:
      composite.SetInputOrder(data.GetVarNames())
    composite.Grow(trainExamples,attrs,[0]+nPossibleVals,
                   buildDriver=driver,
                   pruner=pruner,
                   nTries=details.nModels,pruneIt=details.pruneIt,
                   lessGreedy=details.lessGreedy,needsQuantization=0,
                   treeBuilder=builder,nQuantBounds=details.qBounds,
                   startAt=details.startAt,
                   maxDepth=details.limitDepth,
                   progressCallback=progressCallback,
                   silent=not _verbose)
  else:
    from rdkit.ML.Neural import CrossValidate
    driver = CrossValidate.CrossValidationDriver
    composite.Grow(trainExamples,attrs,[0]+nPossibleVals,nTries=details.nModels,
                   buildDriver=driver,needsQuantization=0)

  composite.AverageErrors()
  composite.SortModels()
  modelList,counts,avgErrs = composite.GetAllData()
  counts = numpy.array(counts)
  avgErrs = numpy.array(avgErrs)
  composite._varNames = data.GetVarNames()

  for model in modelList:
    model.NameModel(composite._varNames)

  # do final statistics: count-weighted mean error and mean absolute deviation
  weightedErrs = counts*avgErrs
  averageErr = sum(weightedErrs)/sum(counts)
  devs = (avgErrs - averageErr)
  devs = devs * counts
  devs = numpy.sqrt(devs*devs)
  avgDev = sum(devs)/sum(counts)
  if _verbose:
    message('# Overall Average Error: %%% 5.2f, Average Deviation: %%% 6.2f'%(100.*averageErr,100.*avgDev))

  if details.bayesModel:
    composite.Train(trainExamples,verbose=0)

  if not details.detailedRes:
    badExamples = []
    if _verbose:
      message('Testing all examples')
    wrong = BuildComposite.testall(composite,namedExamples,badExamples)
    if _verbose:
      message('%d examples (%% %5.2f) were misclassified'%(len(wrong),100.*float(len(wrong))/float(len(namedExamples))))
    _runDetails.overall_error = float(len(wrong))/len(namedExamples)
  else:
    if _verbose:
      message('\nEntire data set:')
    resTup = ScreenComposite.ShowVoteResults(list(range(data.GetNPts())),data,composite,
                                             nPossibleVals[-1],details.threshold)
    nGood,nBad,nSkip,avgGood,avgBad,avgSkip,voteTab = resTup
    nPts = len(namedExamples)
    nClass = nGood+nBad
    _runDetails.overall_error = float(nBad) / nClass
    _runDetails.overall_correct_conf = avgGood
    _runDetails.overall_incorrect_conf = avgBad
    _runDetails.overall_result_matrix = repr(voteTab)
    # points which were not classified (e.g. below the confidence threshold).
    # The original computed nClass-nPts, which can never be positive, so the
    # dropped fraction was never recorded.
    nRej = nPts-nClass
    if nRej > 0:
      _runDetails.overall_fraction_dropped = float(nRej)/nPts

  return composite
268
def GetComposites(details):
  """ loads the composite model(s) which are to be grown

  **Arguments**

    - details: a run-details object.  The attributes read here are
      _persistTblName_, _inNote_ and _dbName_ (database source) and
      _composFileName_ (pickle file source).

  **Returns**

    a list of unpickled composites:

      - if both _persistTblName_ and _inNote_ are set: one entry for each
        row of the persistence table whose note matches _inNote_

      - otherwise, if _composFileName_ is set: a single entry loaded from
        that file

      - otherwise: an empty list

  **Notes**

    - unpickling can execute arbitrary code; only load models from sources
      which are trusted.

  """
  res = []
  if details.persistTblName and details.inNote:
    conn = DbConnect(details.dbName,details.persistTblName)
    # NOTE(review): the note text is interpolated directly into the SQL
    #   where clause; a note containing a quote character will break (or
    #   alter) the query.
    mdls = conn.GetData(fields='MODEL',where="where note='%s'"%(details.inNote))
    for row in mdls:
      rawD = row[0]
      # NOTE(review): str() on the raw column value presumably assumes a
      #   python2-style buffer/str result; confirm under python3.
      res.append(cPickle.loads(str(rawD)))
  elif details.composFileName:
    # use a context manager so that the file handle is always closed
    with open(details.composFileName,'rb') as inF:
      res.append(cPickle.load(inF))
  return res
280
def _SeedAndRandomize(details,composite,data):
  """ syncs the split fraction and random seed of details with the composite,
  re-initializes the random number generator and, if requested in details,
  shuffles or randomizes the activities of data (in place)

  """
  details.splitFrac = composite._splitFrac
  details.randomSeed = composite._randomSeed
  DataUtils.InitRandomNumbers(details.randomSeed)
  if details.shuffleActivities == 1:
    DataUtils.RandomizeActivities(data,shuffle=1,runDetails=details)
  elif details.randomActivities == 1:
    DataUtils.RandomizeActivities(data,shuffle=0,runDetails=details)


def BalanceComposite(details,composite,data1=None,data2=None):
  """ balances the composite using the parameters provided in details

  **Arguments**

    - details: a _CompositeRun.CompositeRun_ object

    - composite: the composite model to be balanced

    - data1: (optional) if provided, this should be the
      data set used to construct the original models

    - data2: (optional) if provided, this should be the
      data set used to construct the new individual models

  **Returns**

    - the unmodified _composite_ itself if no balancing is done, i.e. if
      _details.balCnt_ is zero or larger than the composite, or if either
      data set could not be loaded

    - otherwise a list with one balanced composite for each weight in
      _details.balWeight_

  **Notes**

    - _details.splitFrac_ and _details.randomSeed_ are overwritten with the
      values stored on the composite.

  """
  if not details.balCnt or details.balCnt > len(composite):
    return composite
  message("Balancing Composite")

  #
  # start by getting data set 1: which is the data set used to build the
  #  original models
  #
  if data1 is None:
    message("\tReading First Data Set")
    fName = details.balTable.strip()
    origTable = details.tableName
    origDb = details.dbName
    details.tableName = fName
    details.dbName = details.balDb
    try:
      data1 = details.GetDataSet()
    finally:
      # always put the original table/db names back, even if loading fails
      details.tableName = origTable
      details.dbName = origDb
  if data1 is None:
    return composite
  _SeedAndRandomize(details,composite,data1)
  namedExamples = data1.GetNamedData()
  if details.balDoHoldout or details.balDoTrain:
    # reproduce the train/holdout split of the original data
    trainIdx,testIdx = SplitData.SplitIndices(len(namedExamples),details.splitFrac,
                                              silent=1)
    trainExamples = [namedExamples[x] for x in trainIdx]
    testExamples = [namedExamples[x] for x in testIdx]
    if details.filterFrac != 0.0:
      # examples rejected by the filter are moved to the holdout set
      trainIdx,temp = DataUtils.FilterData(trainExamples,details.filterVal,
                                           details.filterFrac,-1,
                                           indicesOnly=1)
      kept = [trainExamples[x] for x in trainIdx]
      testExamples += [trainExamples[x] for x in temp]
      trainExamples = kept
    if details.balDoHoldout:
      # balance against the holdout set instead of the training set
      testExamples,trainExamples = trainExamples,testExamples
  else:
    trainExamples = namedExamples
  dataSet1 = trainExamples
  cols1 = [x.upper() for x in data1.GetVarNames()]
  data1 = None

  #
  # now grab data set 2: the data used to build the new individual models
  #
  if data2 is None:
    message("\tReading Second Data Set")
    data2 = details.GetDataSet()
  if data2 is None:
    return composite
  _SeedAndRandomize(details,composite,data2)
  dataSet2 = data2.GetNamedData()
  cols2 = [x.upper() for x in data2.GetVarNames()]
  data2 = None

  # and balance it:
  res = []
  weights = details.balWeight
  # isinstance() works under python2 and python3; types.TupleType and
  #  types.ListType no longer exist in python3
  if not isinstance(weights,(tuple,list)):
    weights = (weights,)
  for weight in weights:
    message("\tBalancing with Weight: %.4f"%(weight))
    res.append(AdjustComposite.BalanceComposite(composite,dataSet1,dataSet2,
                                                weight,
                                                details.balCnt,
                                                names1=cols1,names2=cols2))
  return res
376
def ShowVersion(includeArgs=0):
  """ prints the version number of this script

  **Arguments**

    - includeArgs: (optional) if nonzero, the command line used to
      invoke the script is printed as well

  """
  import sys
  lines = ['This is GrowComposite.py version %s'%(__VERSION_STRING)]
  if includeArgs:
    lines.append('command line was:')
    lines.append(' '.join(sys.argv))
  for line in lines:
    print(line)
386
def Usage():
  """ prints the module documentation (the list of command line arguments)
  and then terminates the interpreter with exit status -1

  """
  usageText = __doc__
  print(usageText)
  # equivalent to sys.exit(-1)
  raise SystemExit(-1)
394
def SetDefaults(runDetails=None):
  """ fills a details object with default values

  **Arguments**

    - runDetails: (optional) a _CompositeRun.CompositeRun_ object.
      If this is not provided, the module-level __runDetails_ is used.

  **Returns**

    whatever _CompositeRun.SetDefaults()_ returns for that object.

  """
  target = _runDetails if runDetails is None else runDetails
  return CompositeRun.SetDefaults(target)
411
def ParseArgs(runDetails):
  """ parses command line arguments (from _sys.argv_) and updates _runDetails_

  **Arguments**

    - runDetails:  a _CompositeRun.CompositeRun_ object.

  **Notes**

    - the balancing/loading attributes (_inNote_, _composFileName_,
      _balTable_, _balWeight_, _balCnt_, _balDoHoldout_, _balDoTrain_,
      _balDb_) are always reset to their defaults before parsing.

    - the first non-option argument is stored in _runDetails.tableName_;
      if there is none, an error is printed and _Usage()_ is called.

    - _-h_, _-V_ and unrecognized options terminate the program.

    - the values of _--balWeight_, _-q_ and _-Q_ are passed through
      _eval()_; only use this with trusted command lines.

  """
  import getopt
  args,extra = getopt.getopt(sys.argv[1:],'P:o:n:p:b:sf:F:v:hlgd:rSTt:Q:q:DVG:L:C:N:',
                             ['inNote=','outNote=','balTable=','balWeight=','balCnt=',
                              'balH','balT','balDb=',])
  runDetails.inNote=''
  runDetails.composFileName=''
  runDetails.balTable=''
  runDetails.balWeight=(0.5,)
  runDetails.balCnt=0
  runDetails.balDoHoldout=0
  runDetails.balDoTrain=0
  runDetails.balDb=''
  for arg,val in args:
    if arg == '-n':
      runDetails.nModels = int(val)
    elif arg == '-C':
      runDetails.composFileName=val
    elif arg=='--balTable':
      runDetails.balTable=val
    elif arg=='--balWeight':
      # SECURITY NOTE: eval() of command line text
      runDetails.balWeight=eval(val)
      # isinstance() instead of types.TupleType/ListType, which do not exist
      #  in python3
      if not isinstance(runDetails.balWeight,(tuple,list)):
        runDetails.balWeight=(runDetails.balWeight,)
    elif arg=='--balCnt':
      runDetails.balCnt=int(val)
    elif arg=='--balH':
      runDetails.balDoHoldout=1
    elif arg=='--balT':
      runDetails.balDoTrain=1
    elif arg=='--balDb':
      runDetails.balDb=val
    elif arg == '--inNote':
      runDetails.inNote=val
    elif arg == '-N' or arg=='--outNote':
      runDetails.note=val
    elif arg == '-o':
      runDetails.outName = val
    elif arg == '-p':
      runDetails.persistTblName=val
    elif arg == '-r':
      runDetails.randomActivities = 1
    elif arg == '-S':
      runDetails.shuffleActivities = 1
    elif arg == '-h':
      Usage()
    elif arg == '-l':
      runDetails.lockRandom = 1
    elif arg == '-g':
      runDetails.lessGreedy=1
    elif arg == '-G':
      runDetails.startAt = int(val)
    elif arg == '-d':
      runDetails.dbName=val
    elif arg == '-T':
      runDetails.useTrees = 0
    elif arg == '-t':
      runDetails.threshold=float(val)
    elif arg == '-D':
      runDetails.detailedRes = 1
    elif arg == '-L':
      runDetails.limitDepth = int(val)
    elif arg == '-q':
      # SECURITY NOTE: eval() of command line text
      qBounds = eval(val)
      assert isinstance(qBounds,(tuple,list)),'bad argument type for -q, specify a list as a string'
      runDetails.qBoundCount=val
      runDetails.qBounds = qBounds
    elif arg == '-Q':
      # SECURITY NOTE: eval() of command line text
      qBounds = eval(val)
      assert isinstance(qBounds,(tuple,list)),'bad argument type for -Q, specify a list as a string'
      runDetails.activityBounds=qBounds
      runDetails.activityBoundsVals=val
    elif arg == '-V':
      ShowVersion()
      sys.exit(0)
    else:
      print('bad argument:',arg,file=sys.stderr)
      Usage()
  if not extra:
    # report the problem instead of failing with a bare IndexError below
    print('no data file or table name provided',file=sys.stderr)
    Usage()
  runDetails.tableName=extra[0]
  if not runDetails.balDb:
    runDetails.balDb=runDetails.dbName
def _PickleComposites(runDetails,composites):
  """ pickles the composites to the file name(s) derived from runDetails.outName

  A single composite goes to _outName_ itself; multiple (balanced) composites
  go to files which include the balance weight (as a percentage) in the name.

  """
  if len(composites)==1:
    composites[0].Pickle(runDetails.outName)
    return
  baseName = runDetails.outName.split('.pkl')[0]
  # pair weights with models instead of indexing by the number of weights,
  #  which failed when fewer models than weights were available
  for weight,model in zip(runDetails.balWeight,composites):
    model.Pickle('%s.%d.pkl'%(baseName,int(100*weight)))


def _StoreComposites(runDetails,composites):
  """ stores each composite, along with the run details, in the results table

  """
  message('Updating results table %s:%s'%(runDetails.dbName,runDetails.persistTblName))
  if len(composites)>1:
    message('WARNING: updating results table with models having different weights')
  for model in composites:
    runDetails.model = cPickle.dumps(model)
    runDetails.Store(db=runDetails.dbName,table=runDetails.persistTblName)


def _GrowAndSave(runDetails,initModel):
  """ grows (and optionally balances) one composite, then saves the results
  to a pickle file and/or the database as requested in runDetails

  """
  composite = GrowIt(runDetails,initModel,setDescNames=1)
  composites = [composite]
  if runDetails.balTable and runDetails.balCnt:
    balanced = BalanceComposite(runDetails,composite)
    # BalanceComposite() hands back the composite itself (not a list) when
    #  it decides not to balance; only use the result if it is a new object
    if balanced is not composite:
      composites = balanced
  for mdl in composites:
    mdl.ClearModelExamples()
  if runDetails.outName:
    _PickleComposites(runDetails,composites)
  if runDetails.persistTblName and runDetails.dbName:
    _StoreComposites(runDetails,composites)


if __name__ == '__main__':
  if len(sys.argv) < 2:
    Usage()

  _runDetails.cmd = ' '.join(sys.argv)
  SetDefaults(_runDetails)
  ParseArgs(_runDetails)

  ShowVersion(includeArgs=1)

  initModels = GetComposites(_runDetails)
  nModels = len(initModels)
  if not nModels:
    message("No models found")
  # NOTE(review): when several models are grown they are all pickled to the
  #  same output file name(s), so later models overwrite earlier ones.
  for idx,initModel in enumerate(initModels):
    if nModels>1:
      sys.stderr.write('---------------------------------\n\tDoing %d of %d\n---------------------------------\n'%(idx+1,nModels))
    _GrowAndSave(_runDetails,initModel)