Package rdkit :: Package ML :: Package Data :: Module SplitData
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.Data.SplitData

  1  ## Automatically adapted for numpy.oldnumeric Jun 27, 2008 by -c 
  2   
  3  # 
  4  #  Copyright (C) 2003-2008 Greg Landrum and Rational Discovery LLC 
  5  #    All Rights Reserved 
  6  # 
  7  from __future__ import print_function 
  8  import random 
  9  import os.path,sys 
 10   
 11  from rdkit import RDConfig,RDRandom 
 12  from rdkit.six.moves import xrange 
 13   
 14  SeqTypes=(list, tuple) 
 15   
def SplitIndices(nPts, frac, silent=1, legacy=0, replacement=0):
    """ splits the indices of a data set into 2 pieces

    **Arguments**

     - nPts: the total number of points

     - frac: the fraction of the data to be put in the first data set

     - silent: (optional) toggles display of stats

     - legacy: (optional) use the legacy splitting approach

     - replacement: (optional) use selection with replacement

    **Returns**

      a 2-tuple containing the two sets of indices.

    **Raises**

      ValueError if _frac_ is outside [0.0, 1.0]

    **Notes**

      - the _legacy_ splitting approach uses randomly-generated floats
        and compares them to _frac_.  This is provided for
        backwards-compatibility reasons.

      - the default splitting approach uses a random permutation of
        indices which is split into two parts.

      - selection with replacement can generate duplicates; the first
        returned list always holds int(nPts*frac) draws.


    **Usage**:

    We'll start with a set of indices and pick from them using
    the three different approaches:
    >>> from rdkit.ML.Data import DataUtils

    The base approach always returns the same number of compounds in
    each set and has no duplicates:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> test,train = SplitIndices(10,.5)
    >>> test
    [1, 5, 6, 4, 2]
    >>> train
    [3, 0, 7, 8, 9]

    >>> test,train = SplitIndices(10,.5)
    >>> test
    [5, 2, 9, 8, 7]
    >>> train
    [6, 0, 3, 1, 4]


    The legacy approach can return varying numbers, but still has no
    duplicates.  Note the indices come back ordered:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> test,train = SplitIndices(10,.5,legacy=1)
    >>> test
    [3, 5, 7, 8, 9]
    >>> train
    [0, 1, 2, 4, 6]

    >>> test,train = SplitIndices(10,.5,legacy=1)
    >>> test
    [0, 1, 2, 3, 5, 8, 9]
    >>> train
    [4, 6, 7]

    The replacement approach returns a fixed number in the training set,
    a variable number in the test set and can contain duplicates in the
    training set.
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> test,train = SplitIndices(10,.5,replacement=1)
    >>> test
    [9, 9, 8, 0, 5]
    >>> train
    [1, 2, 3, 4, 6, 7]
    >>> test,train = SplitIndices(10,.5,replacement=1)
    >>> test
    [4, 5, 1, 1, 4]
    >>> train
    [0, 2, 3, 6, 7, 8, 9]

    """
    if frac < 0. or frac > 1.:
        raise ValueError('frac must be between 0.0 and 1.0 (frac=%f)' % (frac))

    if replacement:
        nTrain = int(nPts * frac)
        resData = [None] * nTrain
        for i in range(nTrain):
            val = int(RDRandom.random() * nPts)
            # random()*nPts can round up to nPts for large nPts; clamp it
            if val == nPts:
                val = nPts - 1
            resData[i] = val
        # every index that was never drawn goes to the hold-out set
        # (a set keeps the membership test O(1) instead of a list scan)
        drawn = set(resData)
        resTest = [i for i in range(nPts) if i not in drawn]
    elif legacy:
        resData = []
        resTest = []
        for i in range(nPts):
            val = RDRandom.random()
            if val < frac:
                resData.append(i)
            else:
                resTest.append(i)
    else:
        perm = list(range(nPts))
        # Fisher-Yates shuffle driven by random.random().  This is exactly
        # what random.shuffle(perm, random=random.random) used to do; that
        # keyword was removed in Python 3.11, and spelling it out keeps the
        # generated permutations (and the doctests above) unchanged.
        for i in reversed(range(1, nPts)):
            j = int(random.random() * (i + 1))
            perm[i], perm[j] = perm[j], perm[i]
        nTrain = int(nPts * frac)

        resData = perm[:nTrain]
        resTest = perm[nTrain:]

    if not silent:
        print('Training with %d (of %d) points.' % (len(resData), nPts))
        print('\t%d points are in the hold-out set.' % (len(resTest)))
    return resData, resTest
def SplitDataSet(data, frac, silent=0):
    """ splits a data set into two pieces

    **Arguments**

     - data: a list of examples to be split

     - frac: the fraction of the data to be put in the first data set

     - silent: controls the amount of visual noise produced.

    **Returns**

      a 2-tuple containing the two new data sets.

    **Raises**

      ValueError if _frac_ is outside [0.0, 1.0]

    """
    # reject fractions outside [0, 1] before doing any work
    if frac < 0. or frac > 1.:
        raise ValueError('frac must be between 0.0 and 1.0')

    nOrig = len(data)
    train, test = SplitIndices(nOrig, frac, silent=1)
    resData = [data[x] for x in train]
    resTest = [data[x] for x in test]

    if not silent:
        print('Training with %d (of %d) points.' % (len(resData), nOrig))
        print('\t%d points are in the hold-out set.' % (len(resTest)))
    return resData, resTest
def _WhereClause(where):
    """Return *where* stripped and prefixed with 'where ', or '' if it is blank.

    A clause that already starts with 'where' is returned unchanged (stripped).
    """
    clause = where.strip()
    if clause and clause.find('where') != 0:
        clause = 'where ' + clause
    return clause


def _ActivityCondition(actCol, act, nActs, actBounds):
    """Return the SQL condition that selects rows in activity class *act*.

    Without *actBounds* the activity column is tested for equality with
    *act*; otherwise class *act* covers [actBounds[act-1], actBounds[act]),
    with the first and last classes open-ended.
    """
    if not actBounds:
        return '%s=%d' % (actCol, act)
    cond = ''
    if act != 0:
        cond += '%s>=%f ' % (actCol, actBounds[act - 1])
    if act < nActs - 1:
        if act != 0:
            cond += 'and '
        cond += '%s<%f' % (actCol, actBounds[act])
    return cond


def SplitDbData(conn, fracs, table='', fields='*', where='', join='',
                labelCol='',
                useActs=0, nActs=2, actCol='', actBounds=(),
                silent=0):
    """ "splits" a data set held in a DB by returning lists of ids

    **Arguments**:

      - conn: a DbConnect object

      - fracs: the split fraction.  This can optionally be specified as a
        sequence with a different fraction for each activity value.

      - table,fields,where,join: (optional) SQL query parameters.
        _where_ may be given with or without the leading 'where'.

      - labelCol: (optional) currently unused

      - useActs: (optional) toggles splitting based on activities
        (ensuring that a given fraction of each activity class ends
        up in the hold-out set)
        Defaults to 0

      - nActs: (optional) number of possible activity values, only
        used if _useActs_ is nonzero
        Defaults to 2

      - actCol: (optional) name of the activity column
        Defaults to use the last column returned by the query

      - actBounds: (optional) sequence of activity bounds
        (for cases where the activity isn't quantized in the db)
        Defaults to an empty sequence

      - silent: currently unused; this function prints nothing.

    **Returns**:

      a 2-tuple of id lists; the ids are taken from the first column
      returned by the query.

    **Raises**:

      ValueError if a fraction is outside [0.0, 1.0], if _actBounds_ does
      not have nActs-1 entries, or if fewer than _nActs_ fractions are
      supplied when _useActs_ is set.

    **Usage**:

    Set up the db connection, the simple tables we're using have actives with even
    ids and inactives with odd ids:
    >>> from rdkit.ML.Data import DataUtils
    >>> from rdkit.Dbase.DbConnection import DbConnect
    >>> conn = DbConnect(RDConfig.RDTestDatabase)

    Pull a set of points from a simple table... take 33% of all points:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,1./3.,'basic_2class')
    >>> [str(x) for x in train]
    ['id-7', 'id-6', 'id-2', 'id-8']

    ...take 50% of actives and 50% of inactives:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'basic_2class',useActs=1)
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10', 'id-8']


    Notice how the results came out sorted by activity

    We can be asymmetrical: take 33% of actives and 50% of inactives:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,[.5,1./3.],'basic_2class',useActs=1)
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10']

    And we can pull from tables with non-quantized activities by providing
    activity quantization bounds:
    >>> DataUtils.InitRandomNumbers((23,42))
    >>> train,test = SplitDbData(conn,.5,'float_2class',useActs=1,actBounds=[1.0])
    >>> [str(x) for x in train]
    ['id-5', 'id-3', 'id-1', 'id-4', 'id-10', 'id-8']


    """
    if not table:
        table = conn.tableName
    if actBounds and len(actBounds) != nActs - 1:
        raise ValueError('activity bounds list length incorrect')
    if useActs:
        if not isinstance(fracs, SeqTypes):
            fracs = tuple([fracs] * nActs)
        # fail up front instead of with an IndexError part-way through the queries
        if len(fracs) < nActs:
            raise ValueError('one fraction is required per activity value')
        for frac in fracs:
            if frac < 0.0 or frac > 1.0:
                raise ValueError('fractions must be between 0.0 and 1.0')
    else:
        if isinstance(fracs, SeqTypes):
            frac = fracs[0]
            if frac < 0.0 or frac > 1.0:
                raise ValueError('fractions must be between 0.0 and 1.0')
        else:
            frac = fracs
    # start by getting the name of the ID column (first column of the query):
    colNames = conn.GetColumnNames(table=table, what=fields, join=join)
    idCol = colNames[0]
    userWhere = _WhereClause(where)

    if not useActs:
        # honor the caller's where clause here too; only pass it when present
        # so the query is unchanged for callers that don't supply one
        queryArgs = dict(table=table, fields=idCol, join=join)
        if userWhere:
            queryArgs['where'] = userWhere
        d = conn.GetData(**queryArgs)
        ids = [x[0] for x in d]
        train, test = SplitIndices(len(ids), frac, silent=1)
        trainPts = [ids[x] for x in train]
        testPts = [ids[x] for x in test]
    else:
        trainPts = []
        testPts = []
        if not actCol:
            actCol = colNames[-1]
        # the activity condition is appended to the user's clause (if any)
        whereBase = userWhere + ' and ' if userWhere else 'where '
        for act in range(nActs):
            whereTxt = whereBase + _ActivityCondition(actCol, act, nActs, actBounds)
            d = conn.GetData(table=table, fields=idCol, join=join, where=whereTxt)
            ids = [x[0] for x in d]
            train, test = SplitIndices(len(ids), fracs[act], silent=1)
            trainPts.extend([ids[x] for x in train])
            testPts.extend([ids[x] for x in test])

    return trainPts, testPts
298 -def _test():
299 import doctest,sys 300 return doctest.testmod(sys.modules["__main__"])
if __name__ == '__main__':
    import sys
    # the process exit status is the number of failed doctests
    results = _test()
    failed = results[0]
    sys.exit(failed)