1
2
3
4
5
6
7
8
9
10
11
12 """Command line tool to construct an enrichment plot from saved composite models
13
14 Usage: EnrichPlot [optional args] -d dbname -t tablename <models>
15
16 Required Arguments:
17 -d "dbName": the name of the database for screening
18
19 -t "tablename": provide the name of the table with the data to be screened
20
21 <models>: file name(s) of pickled composite model(s).
22 If the -p argument is also provided (see below), this argument is ignored.
23
24 Optional Arguments:
25 - -a "list": the list of result codes to be considered active. This will be
26 eval'ed, so be sure that it evaluates as a list or sequence of
27 integers. For example, -a "[1,2]" will consider activity values 1 and 2
28 to be active
29
30 - --enrich "list": identical to the -a argument above.
31
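- --thresh: sets a confidence threshold for the plot. Predictions whose
confidence is at or below this value are not picked; because predictions
are sorted by confidence (unless --noSort is used), this terminates picking.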
34
35 - -H: screen only the hold out set (works only if a version of
36 BuildComposite more recent than 1.2.2 was used).
37
38 - -T: screen only the training set (works only if a version of
39 BuildComposite more recent than 1.2.2 was used).
40
41 - -S: shuffle activity values before screening
42
43 - -R: randomize activity values before screening
44
45 - -F *filter frac*: filters the data before training to change the
46 distribution of activity values in the training set. *filter frac*
47 is the fraction of the training set that should have the target value.
48 **See note in BuildComposite help about data filtering**
49
50 - -v *filter value*: filters the data before training to change the
51 distribution of activity values in the training set. *filter value*
52 is the target value to use in filtering.
53 **See note in BuildComposite help about data filtering**
54
55 - -p "tableName": provides the name of a db table containing the
56 models to be screened. If you use this argument, you should also
57 use the -N argument (below) to specify a note value.
58
59 - -N "note": provides a note to be used to pull models from a db table.
60
61 - --plotFile "filename": writes the data to an output text file (filename.dat)
62 and creates a gnuplot input file (filename.gnu) to plot it
63
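- --showPlot: causes the gnuplot plot constructed using --plotFile to be
displayed in gnuplot.

- --pickleCol *col*: the (1-based) column of the data table containing
pickled data, which is unpickled as ExplicitBitVect objects.

- --OOB: out-of-bag screening: each point is classified using only those
models in the composite that were not trained on it.

- --noSort: do not sort predictions by confidence before accumulating picks.

- --pickBase "base": writes the ids of each model's picks to the file
base.N.picks, where N is the (1-based) position of the model.

- --doROC, --rocThresh: accepted, but not currently used.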
66
67 """
68 from __future__ import print_function
69 from rdkit import RDConfig
70 import numpy
71 import copy
72 from rdkit.six.moves import cPickle
73
74 from rdkit.ML.Data import DataUtils,SplitData,Stats
75 from rdkit.Dbase.DbConnection import DbConnect
76 from rdkit import DataStructs
77 from rdkit.ML import CompositeRun
78 import sys,os,types
79 from rdkit.six import cmp
80
__VERSION_STRING = "2.4.0"


def message(msg, noRet=0, dest=sys.stderr):
    """ writes a progress message to a stream (_sys.stderr_ by default)

    Modules that import this one may override this function to redirect
    the script's progress output.

    **Arguments**

      - msg: the object to display; it is rendered with `%s` formatting

      - noRet: (optional) if true, the message is followed by a single
        space rather than a newline, so the next message continues the line

      - dest: (optional) the file-like object to write to. Note that the
        default is the _sys.stderr_ object bound when this module was loaded.

    """
    fmt = '%s ' if noRet else '%s\n'
    dest.write(fmt % (msg))
def error(msg, dest=None):
    """ emits an error message, prefixed with 'ERROR: ', to a stream

    Modules that import this one may override this function to redirect
    error output.

    **Arguments**

      - msg: the object to display; it is rendered with `%s` formatting

      - dest: (optional) the file-like object to write to. If not provided,
        the *current* value of _sys.stderr_ is used.

    """
    # The previous version silently ignored `dest`. It is honored now, and
    # the default is resolved at call time so that a later redirection of
    # sys.stderr still takes effect.
    if dest is None:
        dest = sys.stderr
    dest.write('ERROR: %s\n' % (msg))
105
def ScreenModel(mdl, descs, data, picking=None, indices=None, errorEstimate=0):
    """ collects the results of screening an individual composite model that match
    a particular value

    **Arguments**

      - mdl: the composite model

      - descs: a list of descriptor names corresponding to the data set

      - data: the data set, a sequence of points to be screened. The first
        element of each point is used as its id and the last element as its
        true result.

      - picking: (Optional) a sequence of result values that are to be
        collected (defaults to [1]).
        For example, if you want an enrichment plot for picking the values
        1 and 2, you'd use picking=[1,2].

      - indices: (Optional) the indices of the points in _data_ to screen.
        If this is empty or not provided, every point is screened.

      - errorEstimate: (Optional) if true, each point is classified using
        only those models in the composite whose _trainIndices_ do not
        contain the point's index (an out-of-bag estimate).

    **Returns**

      a 2-tuple:

        - the number of screened points whose true result is in _picking_

        - a list of 4-tuples, one for each screened point whose predicted
          result is in _picking_, containing:

          - the id of the point

          - the true result (from the data set)

          - the predicted result

          - the confidence value for the prediction

    """
    if picking is None:
        picking = [1]
    mdl.SetInputOrder(descs)

    # Fetch each member model once. Their _trainIndices are normalized to
    # dicts so that the out-of-bag membership test below is a cheap lookup.
    submodels = [mdl.GetModel(j) for j in range(len(mdl))]
    for sub in submodels:
        trainIndices = getattr(sub, '_trainIndices', None)
        if trainIndices is not None and not isinstance(trainIndices, dict):
            sub._trainIndices = dict.fromkeys(trainIndices, 1)

    needsQuant = bool(mdl.GetQuantBounds())

    # an empty (or missing) index list means "screen every point"
    if not indices:
        indices = range(len(data))

    nTrueActives = 0
    res = []
    for i in indices:
        if errorEstimate:
            # only use the models that were not trained on this point
            use = [j for j, sub in enumerate(submodels)
                   if not sub._trainIndices.get(i, 0)]
        else:
            use = None
        pt = data[i]
        pred, conf = mdl.ClassifyExample(pt, onlyModels=use)
        if needsQuant:
            # pass a copy so any in-place change made while quantizing
            # does not touch the data set itself
            pt = mdl.QuantizeActivity(pt[:])
        trueRes = pt[-1]
        if trueRes in picking:
            nTrueActives += 1
        if pred in picking:
            res.append((pt[0], trueRes, pred, conf))
    return nTrueActives, res
172
def AccumulateCounts(predictions, thresh=0, sortIt=1):
    """ Accumulates the data for the enrichment plot for a single model

    **Arguments**

      - predictions: a list of 4-tuples (id, true result, predicted result,
        confidence), as returned in the second element of _ScreenModel_'s
        result

      - thresh: (optional) a threshold for the confidence level. Predictions
        whose confidence is at or below this threshold are not considered.

      - sortIt: (optional) toggles sorting on confidence levels. If set,
        _predictions_ is sorted *in place* from highest to lowest confidence.

    **Returns**

      - a list of 3-tuples, one per prediction above the threshold:

        - the id of the point picked here

        - number of correct picks (predicted result == true result) so far

        - number of picks made so far

    """
    if sortIt:
        # Highest confidence first. list.sort is stable (also with
        # reverse=True), so ties keep their original relative order, just as
        # with the old cmp-based sort, which Python 3 no longer accepts.
        predictions.sort(key=lambda p: p[3], reverse=True)
    res = []
    nCorrect = 0
    nPts = 0
    for ptId, real, pred, conf in predictions:
        if conf > thresh:
            if pred == real:
                nCorrect += 1
            nPts += 1
            res.append((ptId, nCorrect, nPts))
    return res
211
def MakePlot(details, final, counts, pickVects, nModels, nTrueActs=-1):
    """ writes enrichment data to a text file plus a gnuplot script to plot it

    Nothing is done unless _details_ has a non-empty _plotFile_ attribute.

    **Arguments**

      - details: the run details. _plotFile_ provides the base name of the
        output files (_plotFile_.dat and _plotFile_.gnu); if _showPlot_ is
        also set, the plot is displayed using the Gnuplot module.

      - final: per-pick accumulated sums; row i holds the summed number of
        correct picks and the summed number of picks at pick i+1

      - counts: the number of models contributing to each row of _final_;
        output stops at the first row with a zero count

      - pickVects: maps a row index to the list of per-model correct-pick
        counts at that row (used for the 90% confidence interval when
        _nModels_ > 1)

      - nModels: the number of models that were screened

      - nTrueActs: (optional) if positive, used as the upper limit of the
        plot's y axis

    """
    if not getattr(details, 'plotFile', None):
        return

    dataFileName = '%s.dat' % (details.plotFile)
    i = 0
    with open(dataFileName, 'w+') as outF:
        while i < len(final) and counts[i] != 0:
            avgCorrect = final[i][0] / counts[i]
            avgPicked = final[i][1] / counts[i]
            if nModels > 1:
                _, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                outF.write('%d %f %f %d %f\n' % (i + 1, avgCorrect, avgPicked,
                                                 counts[i], confInterval))
            else:
                outF.write('%d %f %f %d\n' % (i + 1, avgCorrect, avgPicked, counts[i]))
            i += 1

    plotFileName = '%s.gnu' % (details.plotFile)
    gnuHdr = """# Generated by EnrichPlot.py version: %s
set size square 0.7
set xr [0:]
set data styl points
set ylab 'Num Correct Picks'
set xlab 'Num Picks'
set grid
set nokey
set term postscript enh color solid "Helvetica" 16
set term X
""" % (__VERSION_STRING)
    with open(plotFileName, 'w+') as gnuF:
        print(gnuHdr, file=gnuF)
        if nTrueActs > 0:
            print('set yr [0:%d]' % nTrueActs, file=gnuF)
        print('plot x with lines', file=gnuF)
        if nModels > 1:
            # Draw an error bar on roughly every 20th row. gnuplot's `every`
            # needs a positive integer increment, so never go below 1. The old
            # i/20 gave 0 for fewer than 20 rows, and a float on Python 3.
            everyGap = max(1, i // 20)
            print('replot "%s" using 1:2 with lines,' % (dataFileName), end='', file=gnuF)
            print('"%s" every %d using 1:2:5 with yerrorbars' % (dataFileName,
                                                                 everyGap), file=gnuF)
        else:
            print('replot "%s" with points' % (dataFileName), file=gnuF)

    if getattr(details, 'showPlot', False):
        try:
            from Gnuplot import Gnuplot
            p = Gnuplot()
            p('load "%s"' % (plotFileName))
            try:
                waitForUser = raw_input  # Python 2
            except NameError:
                waitForUser = input  # Python 3
            waitForUser('press return to continue...\n')
        except Exception:
            # Displaying the plot is best-effort: report the problem. The data
            # and script files have already been written.
            import traceback
            traceback.print_exc()
267
268
269
270
def Usage():
    """ writes the module documentation (the usage message) to stderr and
    terminates the program with exit status -1 """
    usageText = __doc__
    sys.stderr.write(usageText)
    raise SystemExit(-1)
275
if __name__ == '__main__':
    import getopt
    try:
        args, extras = getopt.getopt(sys.argv[1:], 'd:t:a:N:p:cSRTHF:v:',
                                     ('thresh=', 'plotFile=', 'showPlot',
                                      'pickleCol=', 'OOB', 'noSort', 'pickBase=',
                                      'doROC', 'rocThresh=', 'enrich='))
    except getopt.GetoptError:
        import traceback
        traceback.print_exc()
        Usage()

    details = CompositeRun.CompositeRun()
    CompositeRun.SetDefaults(details)

    # script-specific defaults (some of these override CompositeRun's)
    details.activeTgt = [1]
    details.doTraining = 0
    details.doHoldout = 0
    details.dbTableName = ''
    details.plotFile = ''
    details.showPlot = 0
    details.pickleCol = -1
    details.errorEstimate = 0
    details.sortIt = 1
    details.pickBase = ''
    details.doROC = 0
    details.rocThresh = -1
    details.randomActivities = 0
    for arg, val in args:
        if arg == '-d':
            details.dbName = val
        elif arg == '-t':
            details.dbTableName = val
        elif arg in ('-a', '--enrich'):
            # As documented, the value is eval'ed, so only ever pass trusted
            # input here.
            details.activeTgt = eval(val)
            if not isinstance(details.activeTgt, (tuple, list)):
                details.activeTgt = (details.activeTgt,)
        elif arg == '--thresh':
            details.threshold = float(val)
        elif arg == '-N':
            details.note = val
        elif arg == '-p':
            details.persistTblName = val
        elif arg == '-S':
            details.shuffleActivities = 1
        elif arg == '-R':
            # documented in the usage message but previously rejected by getopt
            details.randomActivities = 1
        elif arg == '-H':
            details.doTraining = 0
            details.doHoldout = 1
        elif arg == '-T':
            details.doTraining = 1
            details.doHoldout = 0
        elif arg == '-F':
            details.filterFrac = float(val)
        elif arg == '-v':
            details.filterVal = float(val)
        elif arg == '--plotFile':
            details.plotFile = val
        elif arg == '--showPlot':
            details.showPlot = 1
        elif arg == '--pickleCol':
            # the command-line value is 1-based; stored 0-based
            details.pickleCol = int(val) - 1
        elif arg == '--OOB':
            details.errorEstimate = 1
        elif arg == '--noSort':
            details.sortIt = 0
        elif arg == '--doROC':
            details.doROC = 1
        elif arg == '--rocThresh':
            details.rocThresh = int(val)
        elif arg == '--pickBase':
            details.pickBase = val
        # NOTE(review): '-c' is accepted by getopt but has no handler here

    if not details.dbName or not details.dbTableName:
        # print the hint first: Usage() exits, so nothing after it would run
        print('*******Please provide both the -d and -t arguments')
        Usage()

    message('Building Data set\n')
    dataSet = DataUtils.DBToData(details.dbName, details.dbTableName,
                                 user=RDConfig.defaultDBUser,
                                 password=RDConfig.defaultDBPassword,
                                 pickleCol=details.pickleCol,
                                 pickleClass=DataStructs.ExplicitBitVect)

    descs = dataSet.GetVarNames()
    nPts = dataSet.GetNPts()
    message('npts: %d\n' % (nPts))
    # numpy.float (removed in NumPy 1.24) and numpy.integer (an abstract,
    # deprecated-as-dtype type) are replaced by the builtin types
    final = numpy.zeros((nPts, 2), float)
    counts = numpy.zeros(nPts, int)
    selPts = [None] * nPts

    # NOTE: models are unpickled; only load them from trusted sources
    models = []
    if details.persistTblName:
        conn = DbConnect(details.dbName, details.persistTblName)
        message('-> Retrieving models from database')
        curs = conn.GetCursor()
        # NOTE(review): the note comes straight from the command line and is
        # interpolated into the SQL text, so a quote in it breaks the query.
        # Switch to a parameterized query once the driver's paramstyle is
        # known.
        curs.execute("select model from %s where note='%s'" % (details.persistTblName, details.note))
        message('-> Reconstructing models')
        try:
            blob = curs.fetchone()
        except Exception:
            blob = None
        while blob:
            message(' Building model %d' % len(models))
            blob = blob[0]
            try:
                # bytes() yields the raw blob contents on both Python 2 (where it
                # is str) and Python 3; str() would produce a "b'...'" repr there
                models.append(cPickle.loads(bytes(blob)))
            except Exception:
                import traceback
                traceback.print_exc()
                print('Model failed')
            else:
                message(' <-Done')
            try:
                blob = curs.fetchone()
            except Exception:
                blob = None
        curs = None
    else:
        for modelName in extras:
            try:
                with open(modelName, 'rb') as inF:
                    model = cPickle.load(inF)
            except Exception:
                import traceback
                print('problems with model %s:' % modelName)
                traceback.print_exc()
            else:
                models.append(model)
    nModels = len(models)
    if not nModels:
        error('no models could be loaded; nothing to screen')
        sys.exit(-1)

    pickVects = {}
    halfwayPts = [1e8] * nModels
    for whichModel, model in enumerate(models):
        tmpD = dataSet
        try:
            seed = model._randomSeed
        except AttributeError:
            pass
        else:
            DataUtils.InitRandomNumbers(seed)
        if details.shuffleActivities:
            DataUtils.RandomizeActivities(tmpD, shuffle=1)
        if details.randomActivities:
            DataUtils.RandomizeActivities(tmpD, shuffle=0)
        if hasattr(model, '_splitFrac') and (details.doHoldout or details.doTraining):
            trainIdx, testIdx = SplitData.SplitIndices(tmpD.GetNPts(), model._splitFrac,
                                                       silent=1)
            if details.filterFrac != 0.0:
                trainFilt, temp = DataUtils.FilterData(tmpD, details.filterVal,
                                                       details.filterFrac, -1,
                                                       indicesToUse=trainIdx,
                                                       indicesOnly=1)
                testIdx += temp
                trainIdx = trainFilt
            if details.doTraining:
                testIdx, trainIdx = trainIdx, testIdx
        else:
            testIdx = range(tmpD.GetNPts())

        message('screening %d examples' % (len(testIdx)))
        nTrueActives, screenRes = ScreenModel(model, descs, tmpD, picking=details.activeTgt,
                                              indices=testIdx,
                                              errorEstimate=details.errorEstimate)
        message('accumulating')
        runningCounts = AccumulateCounts(screenRes,
                                         sortIt=details.sortIt,
                                         thresh=details.threshold)
        if details.pickBase:
            pickFile = open('%s.%d.picks' % (details.pickBase, whichModel + 1), 'w+')
        else:
            pickFile = None

        try:
            for i, entry in enumerate(runningCounts):
                selPts[i] = entry[0]
                final[i][0] += entry[1]
                final[i][1] += entry[2]
                pickVects.setdefault(i, []).append(entry[1])
                counts[i] += 1
                if pickFile:
                    pickFile.write('%s\n' % (entry[0]))
                # first pick count at which half of the true actives were found
                if entry[1] >= nTrueActives / 2 and entry[2] < halfwayPts[whichModel]:
                    halfwayPts[whichModel] = entry[2]
        finally:
            if pickFile is not None:
                pickFile.close()
        message('Halfway point: %d\n' % halfwayPts[whichModel])

    if details.plotFile:
        # NOTE(review): nTrueActives here is the value from the last model
        MakePlot(details, final, counts, pickVects, nModels, nTrueActs=nTrueActives)
    else:
        if nModels > 1:
            print('#Index\tAvg_num_correct\tConf90Pct\tAvg_num_picked\tNum_picks\tlast_selection')
        else:
            print('#Index\tAvg_num_correct\tAvg_num_picked\tNum_picks\tlast_selection')

        i = 0
        while i < nPts and counts[i] != 0:
            if nModels > 1:
                _, sd = Stats.MeanAndDev(pickVects[i])
                confInterval = Stats.GetConfidenceInterval(sd, len(pickVects[i]), level=90)
                print('%d\t%f\t%f\t%f\t%d\t%s' % (i + 1, final[i][0] / counts[i], confInterval,
                                                  final[i][1] / counts[i],
                                                  counts[i], str(selPts[i])))
            else:
                print('%d\t%f\t%f\t%d\t%s' % (i + 1, final[i][0] / counts[i],
                                              final[i][1] / counts[i],
                                              counts[i], str(selPts[i])))
            i += 1

    mean, sd = Stats.MeanAndDev(halfwayPts)
    print('Halfway point: %.2f(%.2f)' % (mean, sd))
486