Module rdkit.ML.AnalyzeComposite
11 """ command line utility to report on the contributions of descriptors to
12 tree-based composite models
13
14 Usage: AnalyzeComposite [optional args] <models>
15
16 <models>: file name(s) of pickled composite model(s)
17 (this is the name of the db table if using a database)
18
19 Optional Arguments:
20
21 -n number: the number of levels of each model to consider
22
23 -d dbname: the database from which to read the models
24
25 -N Note: the note string to search for to pull models from the database
26
27 -v: be verbose whilst screening
28 """
from __future__ import print_function

import numpy
import sys

from rdkit.six.moves import cPickle
from rdkit.ML.DecTree import TreeUtils, Tree
from rdkit.ML.Data import Stats
from rdkit.Dbase.DbConnection import DbConnect
from rdkit.ML import ScreenComposite

__VERSION_STRING = "2.2.0"


def ProcessIt(composites, nToConsider=3, verbose=0):
  composite = composites[0]
  nComposites = len(composites)
  ns = composite.GetDescriptorNames()

  if len(ns) > 2:
    globalRes = {}

    nDone = 1
    descNames = {}
    for composite in composites:
      if verbose > 0:
        print('#------------------------------------')
        print('Doing: ', nDone)
      nModels = len(composite)
      nDone += 1
      res = {}
      for i in range(len(composite)):
        model = composite.GetModel(i)
        if isinstance(model, Tree.TreeNode):
          levels = TreeUtils.CollectLabelLevels(model, {}, 0, nToConsider)
          TreeUtils.CollectDescriptorNames(model, descNames, 0, nToConsider)
          for descId in levels.keys():
            v = res.get(descId, numpy.zeros(nToConsider, float))
            v[levels[descId]] += 1. / nModels
            res[descId] = v
      for k in res:
        v = globalRes.get(k, numpy.zeros(nToConsider, float))
        v += res[k] / nComposites
        globalRes[k] = v
      if verbose > 0:
        for k in res.keys():
          name = descNames[k]
          strRes = ', '.join(['%4.2f' % x for x in res[k]])
          print('%s,%s,%5.4f' % (name, strRes, sum(res[k])))

        print()

    if verbose >= 0:
      print('# Average Descriptor Positions')
    retVal = []
    for k in globalRes.keys():
      name = descNames[k]
      if verbose >= 0:
        strRes = ', '.join(['%4.2f' % x for x in globalRes[k]])
        print('%s,%s,%5.4f' % (name, strRes, sum(globalRes[k])))
      tmp = [name]
      tmp.extend(globalRes[k])
      tmp.append(sum(globalRes[k]))
      retVal.append(tmp)
    if verbose >= 0:
      print()
  else:
    retVal = []
  return retVal
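
# Minimal programmatic sketch of ProcessIt (kept as comments so the module's
# behavior is unchanged); 'composite.pkl' is a hypothetical pickled composite
# built from tree models:
#
#   with open('composite.pkl', 'rb') as inF:
#     comp = cPickle.load(inF)
#   rows = ProcessIt([comp], nToConsider=2, verbose=-1)
#   # each row is [descriptorName, frac_at_level_0, ..., frac_at_level_n-1, total],
#   # where the fractions are averaged over all models in the composite(s)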


def ErrorStats(conn, where, enrich=1):
  fields = ('overall_error,holdout_error,overall_result_matrix,holdout_result_matrix,'
            'overall_correct_conf,overall_incorrect_conf,holdout_correct_conf,'
            'holdout_incorrect_conf')
  try:
    data = conn.GetData(fields=fields, where=where)
  except Exception:
    import traceback
    traceback.print_exc()
    return None
  nPts = len(data)
  if not nPts:
    sys.stderr.write('no runs found\n')
    return None
  overall = numpy.zeros(nPts, float)
  overallEnrich = numpy.zeros(nPts, float)
  oCorConf = 0.0
  oInCorConf = 0.0
  holdout = numpy.zeros(nPts, float)
  holdoutEnrich = numpy.zeros(nPts, float)
  hCorConf = 0.0
  hInCorConf = 0.0
  overallMatrix = None
  holdoutMatrix = None
  for i in range(nPts):
    if data[i][0] is not None:
      overall[i] = data[i][0]
      oCorConf += data[i][4]
      oInCorConf += data[i][5]
    if data[i][1] is not None:
      holdout[i] = data[i][1]
      haveHoldout = 1
    else:
      haveHoldout = 0
    tmpOverall = 1. * eval(data[i][2])
    if enrich >= 0:
      overallEnrich[i] = ScreenComposite.CalcEnrichment(tmpOverall, tgt=enrich)
    if haveHoldout:
      tmpHoldout = 1. * eval(data[i][3])
      if enrich >= 0:
        holdoutEnrich[i] = ScreenComposite.CalcEnrichment(tmpHoldout, tgt=enrich)
    if overallMatrix is None:
      if data[i][2] is not None:
        overallMatrix = tmpOverall
      if haveHoldout and data[i][3] is not None:
        holdoutMatrix = tmpHoldout
    else:
      overallMatrix += tmpOverall
      if haveHoldout:
        holdoutMatrix += tmpHoldout
    if haveHoldout:
      hCorConf += data[i][6]
      hInCorConf += data[i][7]

  avgOverall = sum(overall) / nPts
  oCorConf /= nPts
  oInCorConf /= nPts
  overallMatrix /= nPts
  oSort = numpy.argsort(overall)
  oMin = overall[oSort[0]]
  overall -= avgOverall
  devOverall = numpy.sqrt(sum(overall**2) / (nPts - 1))
  res = {}
  res['oAvg'] = 100 * avgOverall
  res['oDev'] = 100 * devOverall
  res['oCorrectConf'] = 100 * oCorConf
  res['oIncorrectConf'] = 100 * oInCorConf
  res['oResultMat'] = overallMatrix
  res['oBestIdx'] = oSort[0]
  res['oBestErr'] = 100 * oMin

  if enrich >= 0:
    mean, dev = Stats.MeanAndDev(overallEnrich)
    res['oAvgEnrich'] = mean
    res['oDevEnrich'] = dev

  if haveHoldout:
    avgHoldout = sum(holdout) / nPts
    hCorConf /= nPts
    hInCorConf /= nPts
    holdoutMatrix /= nPts
    hSort = numpy.argsort(holdout)
    hMin = holdout[hSort[0]]
    holdout -= avgHoldout
    devHoldout = numpy.sqrt(sum(holdout**2) / (nPts - 1))
    res['hAvg'] = 100 * avgHoldout
    res['hDev'] = 100 * devHoldout
    res['hCorrectConf'] = 100 * hCorConf
    res['hIncorrectConf'] = 100 * hInCorConf
    res['hResultMat'] = holdoutMatrix
    res['hBestIdx'] = hSort[0]
    res['hBestErr'] = 100 * hMin
    if enrich >= 0:
      mean, dev = Stats.MeanAndDev(holdoutEnrich)
      res['hAvgEnrich'] = mean
      res['hDevEnrich'] = dev
  return res
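
# For reference, the dict returned by ErrorStats (and consumed by ShowStats
# below) carries these keys: 'oAvg', 'oDev', 'oCorrectConf', 'oIncorrectConf',
# 'oResultMat', 'oBestIdx', 'oBestErr', plus 'oAvgEnrich'/'oDevEnrich' when
# enrich >= 0, and the corresponding 'h*' keys when holdout results are present.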


def ShowStats(statD, enrich=1):
  statD = statD.copy()
  statD['oBestIdx'] = statD['oBestIdx'] + 1
  txt = """
# Error Statistics:
\tOverall: %(oAvg)6.3f%% (%(oDev)6.3f) %(oCorrectConf)4.1f/%(oIncorrectConf)4.1f
\t\tBest: %(oBestIdx)d %(oBestErr)6.3f%%""" % (statD)
  if 'hAvg' in statD:
    statD['hBestIdx'] = statD['hBestIdx'] + 1
    txt += """
\tHoldout: %(hAvg)6.3f%% (%(hDev)6.3f) %(hCorrectConf)4.1f/%(hIncorrectConf)4.1f
\t\tBest: %(hBestIdx)d %(hBestErr)6.3f%%
""" % (statD)
  print(txt)
  print()
  print('# Results matrices:')
  print('\tOverall:')
  tmp = numpy.transpose(statD['oResultMat'])
  colCounts = numpy.sum(tmp, 0)
  rowCounts = numpy.sum(tmp, 1)
  for i in range(len(tmp)):
    if rowCounts[i] == 0:
      rowCounts[i] = 1
    row = tmp[i]
    print('\t\t', end='')
    for j in range(len(row)):
      print('% 6.2f' % row[j], end='')
    print('\t| % 4.2f' % (100. * tmp[i, i] / rowCounts[i]))
  print('\t\t', end='')
  for i in range(len(tmp)):
    print('------', end='')
  print()
  print('\t\t', end='')
  for i in range(len(tmp)):
    if colCounts[i] == 0:
      colCounts[i] = 1
    print('% 6.2f' % (100. * tmp[i, i] / colCounts[i]), end='')
  print()
  if enrich > -1 and 'oAvgEnrich' in statD:
    print('\t\tEnrich(%d): %.3f (%.3f)' % (enrich, statD['oAvgEnrich'], statD['oDevEnrich']))

  if 'hResultMat' in statD:
    print('\tHoldout:')
    tmp = numpy.transpose(statD['hResultMat'])
    colCounts = numpy.sum(tmp, 0)
    rowCounts = numpy.sum(tmp, 1)
    for i in range(len(tmp)):
      if rowCounts[i] == 0:
        rowCounts[i] = 1
      row = tmp[i]
      print('\t\t', end='')
      for j in range(len(row)):
        print('% 6.2f' % row[j], end='')
      print('\t| % 4.2f' % (100. * tmp[i, i] / rowCounts[i]))
    print('\t\t', end='')
    for i in range(len(tmp)):
      print('------', end='')
    print()
    print('\t\t', end='')
    for i in range(len(tmp)):
      if colCounts[i] == 0:
        colCounts[i] = 1
      print('% 6.2f' % (100. * tmp[i, i] / colCounts[i]), end='')
    print()
    if enrich > -1 and 'hAvgEnrich' in statD:
      print('\t\tEnrich(%d): %.3f (%.3f)' % (enrich, statD['hAvgEnrich'], statD['hDevEnrich']))

  return


def Usage():
  print(__doc__)
  sys.exit(-1)


if __name__ == "__main__":
  import getopt
  try:
    args, extras = getopt.getopt(sys.argv[1:], 'n:d:N:vX', ('skip', 'enrich='))
  except:
    Usage()

  count = 3
  db = None
  note = ''
  verbose = 0
  skip = 0
  enrich = 1
  for arg, val in args:
    if arg == '-n':
      count = int(val) + 1
    elif arg == '-d':
      db = val
    elif arg == '-N':
      note = val
    elif arg == '-v':
      verbose = 1
    elif arg == '--skip':
      skip = 1
    elif arg == '--enrich':
      enrich = int(val)

  composites = []
  if db is None:
    for arg in extras:
      composite = cPickle.load(open(arg, 'rb'))
      composites.append(composite)
  else:
    tbl = extras[0]
    conn = DbConnect(db, tbl)
    if note:
      where = "where note='%s'" % (note)
    else:
      where = ''
    if not skip:
      pkls = conn.GetData(fields='model', where=where)
      composites = []
      for pkl in pkls:
        pkl = str(pkl[0])
        comp = cPickle.loads(pkl)
        composites.append(comp)

  if len(composites):
    ProcessIt(composites, count, verbose=verbose)
  elif not skip:
    print('ERROR: no composite models found')
    sys.exit(-1)

  if db:
    res = ErrorStats(conn, where, enrich=enrich)
    if res:
      ShowStats(res)