from __future__ import print_function
import random
import os.path, sys

from rdkit import RDConfig, RDRandom
from rdkit.six.moves import xrange

SeqTypes = (list, tuple)


def SplitIndices(nPts, frac, silent=1, legacy=0, replacement=0):
  """ splits the indices of a data set into 2 pieces

  **Arguments**

    - nPts: the total number of points

    - frac: the fraction of the data to be put in the first data set

    - silent: (optional) toggles display of stats

    - legacy: (optional) use the legacy splitting approach

    - replacement: (optional) use selection with replacement

  **Returns**

    a 2-tuple containing the two sets of indices.

  **Notes**

    - the _legacy_ splitting approach uses randomly-generated floats
      and compares them to _frac_. This is provided for
      backwards-compatibility reasons.

    - the default splitting approach uses a random permutation of
      the indices, which is split into two parts.

    - selection with replacement can generate duplicates.


  **Usage**:

  We'll start with a set of indices and pick from them using
  the three different approaches:

  >>> from rdkit.ML.Data import DataUtils

  The base approach always returns the same number of indices in
  each set and has no duplicates:

  >>> DataUtils.InitRandomNumbers((23,42))
  >>> test,train = SplitIndices(10,.5)
  >>> test
  [1, 5, 6, 4, 2]
  >>> train
  [3, 0, 7, 8, 9]

  >>> test,train = SplitIndices(10,.5)
  >>> test
  [5, 2, 9, 8, 7]
  >>> train
  [6, 0, 3, 1, 4]

  The legacy approach can return varying numbers, but still has no
  duplicates. Note that the indices come back ordered:

  >>> DataUtils.InitRandomNumbers((23,42))
  >>> test,train = SplitIndices(10,.5,legacy=1)
  >>> test
  [3, 5, 7, 8, 9]
  >>> train
  [0, 1, 2, 4, 6]

  >>> test,train = SplitIndices(10,.5,legacy=1)
  >>> test
  [0, 1, 2, 3, 5, 8, 9]
  >>> train
  [4, 6, 7]

  The replacement approach returns a fixed number of indices in the first
  set and a variable number in the second, and the first set can contain
  duplicates:

  >>> DataUtils.InitRandomNumbers((23,42))
  >>> test,train = SplitIndices(10,.5,replacement=1)
  >>> test
  [9, 9, 8, 0, 5]
  >>> train
  [1, 2, 3, 4, 6, 7]
  >>> test,train = SplitIndices(10,.5,replacement=1)
  >>> test
  [4, 5, 1, 1, 4]
  >>> train
  [0, 2, 3, 6, 7, 8, 9]

  """
  if frac < 0. or frac > 1.:
    raise ValueError('frac must be between 0.0 and 1.0 (frac=%f)' % (frac))

  if replacement:
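    # sample nTrain indices uniformly at random *with* replacement; any index
    # that is never drawn ends up in the hold-out set, so its size can vary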
    nTrain = int(nPts * frac)
    resData = [None] * nTrain
    resTest = []
    for i in range(nTrain):
      val = int(RDRandom.random() * nPts)
      if val == nPts:
        val = nPts - 1
      resData[i] = val
    for i in range(nPts):
      if i not in resData:
        resTest.append(i)
  elif legacy:
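    # legacy behavior: each index is assigned independently by comparing a
    # uniform random number against frac, so the split sizes vary from run to run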
    resData = []
    resTest = []
    for i in range(nPts):
      val = RDRandom.random()
      if val < frac:
        resData.append(i)
      else:
        resTest.append(i)
  else:
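    # default behavior: shuffle a permutation of all the indices and cut it at
    # nPts*frac, giving fixed-size pieces with no duplicates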
    perm = list(xrange(nPts))
    random.shuffle(perm, random=random.random)
    nTrain = int(nPts * frac)

    resData = list(perm[:nTrain])
    resTest = list(perm[nTrain:])

  if not silent:
    print('Training with %d (of %d) points.' % (len(resData), nPts))
    print('\t%d points are in the hold-out set.' % (len(resTest)))
  return resData, resTest


def SplitDataSet(data, frac, silent=1):
  """ splits a data set into two pieces

  **Arguments**

    - data: a list of examples to be split

    - frac: the fraction of the data to be put in the first data set

    - silent: controls the amount of visual noise produced.

  **Returns**

    a 2-tuple containing the two new data sets.

  """
  if frac < 0. or frac > 1.:
    raise ValueError('frac must be between 0.0 and 1.0')

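  # split the index range, then use the two index lists to pull the
  # corresponding examples out of data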
  nOrig = len(data)
  train, test = SplitIndices(nOrig, frac, silent=1)
  resData = [data[x] for x in train]
  resTest = [data[x] for x in test]

  if not silent:
    print('Training with %d (of %d) points.' % (len(resData), nOrig))
    print('\t%d points are in the hold-out set.' % (len(resTest)))
  return resData, resTest


def SplitDbData(conn, fracs, table='', fields='*', where='', join='', labelCol='', useActs=0,
                nActs=2, actCol='', actBounds=[], silent=0):
171 """ "splits" a data set held in a DB by returning lists of ids
172
173 **Arguments**:
174
175 - conn: a DbConnect object
176
177 - frac: the split fraction. This can optionally be specified as a
178 sequence with a different fraction for each activity value.
179
180 - table,fields,where,join: (optional) SQL query parameters
181
182 - useActs: (optional) toggles splitting based on activities
183 (ensuring that a given fraction of each activity class ends
184 up in the hold-out set)
185 Defaults to 0
186
187 - nActs: (optional) number of possible activity values, only
188 used if _useActs_ is nonzero
189 Defaults to 2
190
191 - actCol: (optional) name of the activity column
192 Defaults to use the last column returned by the query
193
194 - actBounds: (optional) sequence of activity bounds
195 (for cases where the activity isn't quantized in the db)
196 Defaults to an empty sequence
197
198 - silent: controls the amount of visual noise produced.
199
200 **Usage**:
201
202 Set up the db connection, the simple tables we're using have actives with even
203 ids and inactives with odd ids:
204 >>> from rdkit.ML.Data import DataUtils
205 >>> from rdkit.Dbase.DbConnection import DbConnect
206 >>> conn = DbConnect(RDConfig.RDTestDatabase)
207
208 Pull a set of points from a simple table... take 33% of all points:
209 >>> DataUtils.InitRandomNumbers((23,42))
210 >>> train,test = SplitDbData(conn,1./3.,'basic_2class')
211 >>> [str(x) for x in train]
212 ['id-7', 'id-6', 'id-2', 'id-8']
213
214 ...take 50% of actives and 50% of inactives:
215 >>> DataUtils.InitRandomNumbers((23,42))
216 >>> train,test = SplitDbData(conn,.5,'basic_2class',useActs=1)
217 >>> [str(x) for x in train]
218 ['id-5', 'id-3', 'id-1', 'id-4', 'id-10', 'id-8']
219
220
221 Notice how the results came out sorted by activity
222
223 We can be asymmetrical: take 33% of actives and 50% of inactives:
224 >>> DataUtils.InitRandomNumbers((23,42))
225 >>> train,test = SplitDbData(conn,[.5,1./3.],'basic_2class',useActs=1)
226 >>> [str(x) for x in train]
227 ['id-5', 'id-3', 'id-1', 'id-4', 'id-10']
228
229 And we can pull from tables with non-quantized activities by providing
230 activity quantization bounds:
231 >>> DataUtils.InitRandomNumbers((23,42))
232 >>> train,test = SplitDbData(conn,.5,'float_2class',useActs=1,actBounds=[1.0])
233 >>> [str(x) for x in train]
234 ['id-5', 'id-3', 'id-1', 'id-4', 'id-10', 'id-8']
235
236
237 """
  if not table:
    table = conn.tableName
  if actBounds and len(actBounds) != nActs - 1:
    raise ValueError('activity bounds list length incorrect')
  if useActs:
    if type(fracs) not in SeqTypes:
      fracs = tuple([fracs] * nActs)
    for frac in fracs:
      if frac < 0.0 or frac > 1.0:
        raise ValueError('fractions must be between 0.0 and 1.0')
  else:
    if type(fracs) in SeqTypes:
      frac = fracs[0]
      if frac < 0.0 or frac > 1.0:
        raise ValueError('fractions must be between 0.0 and 1.0')
    else:
      frac = fracs

  colNames = conn.GetColumnNames(table=table, what=fields, join=join)
  idCol = colNames[0]

  if not useActs:
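    # no activity-based stratification: just split the full list of ids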
    d = conn.GetData(table=table, fields=idCol, join=join)
    ids = [x[0] for x in d]
    nRes = len(ids)
    train, test = SplitIndices(nRes, frac, silent=1)
    trainPts = [ids[x] for x in train]
    testPts = [ids[x] for x in test]
  else:
    trainPts = []
    testPts = []
    if not actCol:
      actCol = colNames[-1]
    whereBase = where.strip()
    if whereBase.find('where') != 0:
      whereBase = 'where ' + whereBase
    if where:
      whereBase += ' and '
    for act in range(nActs):
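      # build a WHERE clause that selects only this activity class (either by
      # exact value or using the quantization bounds), then split that class on its own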
      frac = fracs[act]
      if not actBounds:
        whereTxt = whereBase + '%s=%d' % (actCol, act)
      else:
        whereTxt = whereBase
        if act != 0:
          whereTxt += '%s>=%f ' % (actCol, actBounds[act - 1])
        if act < nActs - 1:
          if act != 0:
            whereTxt += 'and '
          whereTxt += '%s<%f' % (actCol, actBounds[act])
      d = conn.GetData(table=table, fields=idCol, join=join, where=whereTxt)
      ids = [x[0] for x in d]
      nRes = len(ids)
      train, test = SplitIndices(nRes, frac, silent=1)
      trainPts.extend([ids[x] for x in train])
      testPts.extend([ids[x] for x in test])

  return trainPts, testPts


def _test():
  import doctest, sys
  return doctest.testmod(sys.modules["__main__"])


if __name__ == '__main__':
  import sys
  failed, tried = _test()
  sys.exit(failed)