Package rdkit :: Package Dbase :: Module DbUtils
[hide private]
[frames | no frames]

Source Code for Module rdkit.Dbase.DbUtils

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2000-2006  greg Landrum and Rational Discovery LLC 
  4  # 
  5  #   @@ All Rights Reserved @@ 
  6  #  This file is part of the RDKit. 
  7  #  The contents are covered by the terms of the BSD license 
  8  #  which is included in the file license.txt, found at the root 
  9  #  of the RDKit source tree. 
 10  # 
 11  """ a set of functions for interacting with databases 
 12   
 13   When possible, it's probably preferable to use a _DbConnection.DbConnect_ object 
 14   
 15  """ 
 16  from __future__ import print_function 
 17  from rdkit import RDConfig 
 18  from rdkit.Dbase.DbResultSet import DbResultSet,RandomAccessDbResultSet 
19 -def _take(fromL,what):
20 return map(lambda x,y=fromL:y[x],what)
21 22 from rdkit.Dbase import DbModule 23 import sys 24 from rdkit.Dbase import DbInfo 25 from rdkit.six.moves import xrange #@UnresolvedImport #pylint: disable=F0401 26 from rdkit.six import string_types 27
def GetColumns(dBase, table, fieldString, user='sysdba', password='masterkey',
               join='', cn=None):
    """ gets a set of data from a table

    **Arguments**

      - dBase: database name

      - table: table name

      - fieldString: a string with the names of the fields to be extracted,
        this should be a comma delimited list

      - user and password: used to open the connection when _cn_ is not given

      - join: a join clause; the verb 'join' is added in front of it unless
        the clause already starts with it

      - cn: (optional) an already open connection, used instead of
        connecting to _dBase_

    **Returns**

      - a list of the data (whatever the cursor's fetchall() returns)

    """
    connection = cn if cn else DbModule.connect(dBase, user, password)
    cursor = connection.cursor()

    # the query is assembled from its pieces and glued together with spaces
    pieces = ['select %s from %s' % (fieldString, table)]
    if join:
        if not join.strip().startswith('join'):
            join = 'join %s' % (join)
        pieces.append(join)

    cursor.execute(' '.join(pieces))
    return cursor.fetchall()
61
def GetData(dBase, table, fieldString='*', whereString='', user='sysdba', password='masterkey',
            removeDups=-1, join='', forceList=0, transform=None, randomAccess=1, extras=None,
            cn=None):
    """ a more flexible method to get a set of data from a table

    **Arguments**

      - dBase: database name (only used when _cn_ is not provided)

      - table: table name

      - fieldString: a string with the names of the fields to be extracted,
        this should be a comma delimited list

      - whereString: the SQL where clause to be used with the DB query
        ('where' is added in front unless the clause already starts with it)

      - user and password: used to open the connection when _cn_ is not given

      - removeDups indicates the column which should be used to screen
        out duplicates.  Only the first appearance of a duplicate will
        be left in the dataset.

      - join: a join clause ('join' is added in front unless the clause
        already starts with it)

      - forceList: if set, the query is run immediately and a plain list
        of rows is returned instead of a result-set object

      - transform: passed through to the result-set object; not compatible
        with _forceList_

      - randomAccess: selects RandomAccessDbResultSet (set) or DbResultSet
        (unset) for the return value; must be set when _forceList_ is used

      - extras: (optional) parameters handed to the cursor's execute()

      - cn: (optional) an already open connection

    **Returns**

      - forceList set: a list of the data, or None if executing the
        query raised an exception (the traceback goes to stderr)

      - otherwise: a DbResultSet / RandomAccessDbResultSet for the query

    **Raises**

      - ValueError if _forceList_ is combined with _transform_ or with
        _randomAccess_ being unset

    **Notes**

      - with forceList, duplicates are only screened when removeDups>0
        NOTE(review): column 0 therefore cannot be used for screening in
        list mode; confirm whether that is intended.

    """
    if forceList:
        # argument problems are reported before the database is touched
        if transform is not None:
            raise ValueError('forceList and transform arguments are not compatible')
        if not randomAccess:
            raise ValueError('when forceList is set, randomAccess must also be used')

    if not cn:
        cn = DbModule.connect(dBase, user, password)
    c = cn.cursor()
    cmd = 'select %s from %s' % (fieldString, table)
    if join:
        if join.strip().find('join') != 0:
            join = 'join %s' % (join)
        cmd += ' ' + join
    if whereString:
        if whereString.strip().find('where') != 0:
            whereString = 'where %s' % (whereString)
        cmd += ' ' + whereString

    if not forceList:
        klass = RandomAccessDbResultSet if randomAccess else DbResultSet
        return klass(c, cn, cmd, removeDups=removeDups, transform=transform, extras=extras)

    try:
        if not extras:
            c.execute(cmd)
        else:
            c.execute(cmd, extras)
    except Exception:
        sys.stderr.write('the command "%s" generated errors:\n' % (cmd))
        import traceback
        traceback.print_exc()
        return None

    data = c.fetchall()
    if removeDups > 0:
        data = _dropDuplicateRows(data, removeDups)
    return data


def _dropDuplicateRows(rows, col):
    """ *For Internal Use*

    returns a new list with only the first row seen for each value of column _col_

    """
    seenHashable = set()
    seenOther = []
    kept = []
    for row in rows:
        key = row[col]
        try:
            isDup = key in seenHashable
            if not isDup:
                seenHashable.add(key)
        except TypeError:
            # unhashable value (e.g. a list): fall back to a linear search
            isDup = key in seenOther
            if not isDup:
                seenOther.append(key)
        if not isDup:
            kept.append(row)
    return kept
132
def DatabaseToText(dBase, table, fields='*', join='', where='',
                   user='sysdba', password='masterkey', delim=',', cn=None):
    """ Pulls the contents of a database and makes a delimited text file from them

    **Arguments**

      - dBase: the name of the DB file to be used

      - table: the name of the table to query

      - fields: the fields to select with the SQL query

      - join: the join clause of the SQL query
        (e.g. 'join foo on foo.bar=base.bar')

      - where: the where clause of the SQL query
        (e.g. 'where foo = 2' or 'where bar > 17.6')

      - user: the username for DB access

      - password: the password to be used for DB access

      - delim: the delimiter placed between the fields of each output line

      - cn: (optional) an already open connection, used instead of
        connecting to _dBase_

    **Returns**

      - the CSV data (as text); the first line holds the column names.
        Columns whose type code is listed in DbInfo.sqlBinTypes are left out.

    """
    # 'where' has to open the clause to count: a plain substring search is
    # fooled by values or sub-selects which merely contain the word
    if where and not where.strip().lower().startswith('where'):
        where = 'where %s' % (where)
    # the join test stays a containment test so that clauses like
    # 'left join foo on ...' are passed through untouched
    if join and 'join' not in join.lower():
        join = 'join %s' % (join)
    sqlCommand = 'select %s from %s %s %s' % (fields, table, join, where)
    if not cn:
        cn = DbModule.connect(dBase, user, password)
    c = cn.cursor()
    c.execute(sqlCommand)

    # the description field of the cursor carries around info about the columns
    # of the table
    headers = []
    colsToTake = []
    for i, item in enumerate(c.description):
        if item[1] not in DbInfo.sqlBinTypes:
            colsToTake.append(i)
            headers.append(item[0])

    lines = [delim.join(headers)]

    # grab the data
    for res in c.fetchall():
        lines.append(delim.join([str(res[i]) for i in colsToTake]))

    return '\n'.join(lines)
188 189
def TypeFinder(data, nRows, nCols, nullMarker=None):
    """

    finds the types of the columns in _data_

    if nullMarker is not None, elements of the data table which are
    equal to nullMarker will not count towards setting the type of
    their columns.

    **Arguments**

      - data: the table, indexed as data[row][col]

      - nRows, nCols: the number of rows and columns to examine

      - nullMarker: (optional) the text used to mark missing values

    **Returns**

      - a list with one [type, width] pair per column.  type is float,
        int or str (or -1 if the column held no usable values); width is
        the length of the longest text form seen (at least 1).

    **Notes**

      - values which are not exactly of type float or int are converted
        to text with str()

      - a float with no fractional part counts as an int

      - once a column has been found to contain text it stays a text column

    """
    priorities = {float: 3, int: 2, str: 1, -1: -1}
    res = [None] * nCols
    for col in range(nCols):
        typeHere = [-1, 1]
        for row in range(nRows):
            d = data[row][col]
            if d is None:
                continue
            locType = type(d)
            if locType != float and locType != int:
                # anything non-numeric is treated as text
                try:
                    d = str(d)
                except UnicodeError as msg:
                    print('cannot convert text from row %d col %d to a string' % (row + 2, col))
                    print('\t>%s' % (repr(d)))
                    raise UnicodeError(msg)
                if nullMarker is None or d != nullMarker:
                    typeHere = [str, max(len(d), typeHere[1])]
                continue

            typeHere[1] = max(typeHere[1], len(str(d)))
            try:
                fD = float(int(d))
            except (OverflowError, ValueError):
                # infinities, NaN and integers too large for a float
                locType = float
            else:
                if fD == d:
                    locType = int
            # a text column must never be promoted back to a numeric one
            if typeHere[0] is not str and \
               priorities[locType] > priorities[typeHere[0]]:
                typeHere[0] = locType
        res[col] = typeHere
    return res
235
236 -def _AdjustColHeadings(colHeadings,maxColLabelLen):
237 """ *For Internal Use* 238 239 removes illegal characters from column headings 240 and truncates those which are too long. 241 242 """ 243 for i in xrange(len(colHeadings)): 244 # replace unallowed characters and strip extra white space 245 colHeadings[i] = colHeadings[i].strip() 246 colHeadings[i] = colHeadings[i].replace(' ','_') 247 colHeadings[i] = colHeadings[i].replace('-','_') 248 colHeadings[i] = colHeadings[i].replace('.','_') 249 250 if len(colHeadings[i]) > maxColLabelLen: 251 # interbase (at least) has a limit on the maximum length of a column name 252 newHead = colHeadings[i].replace('_','') 253 newHead = newHead[:maxColLabelLen] 254 print('\tHeading %s too long, changed to %s'%(colHeadings[i],newHead)) 255 colHeadings[i] = newHead 256 return colHeadings
257
def GetTypeStrings(colHeadings, colTypes, keyCol=None):
    """ returns a list of SQL type strings

    **Arguments**

      - colHeadings: the column names

      - colTypes: one (type, width) pair per column; float columns become
        'double precision', int columns 'integer' and everything else
        'varchar(width)'

      - keyCol: (optional) the name of the column which should be
        declared 'not null primary key'

    """
    typeStrs = []
    for idx, typ in enumerate(colTypes):
        kind = typ[0]
        name = colHeadings[idx]
        if kind == float:
            decl = '%s double precision' % name
        elif kind == int:
            decl = '%s integer' % name
        else:
            decl = '%s varchar(%d)' % (name, typ[1])
        if name == keyCol:
            decl = '%s not null primary key' % (decl)
        typeStrs.append(decl)
    return typeStrs
273
274 -def _insertBlock(conn,sqlStr,block,silent=False):
275 try: 276 conn.cursor().executemany(sqlStr,block) 277 except: 278 res = 0 279 conn.commit() 280 for row in block: 281 try: 282 conn.cursor().execute(sqlStr,tuple(row)) 283 res += 1 284 except: 285 if not silent: 286 import traceback 287 traceback.print_exc() 288 print('insert failed:',sqlStr) 289 print('\t',repr(row)) 290 else: 291 conn.commit() 292 else: 293 res = len(block) 294 return res
295
def _AddDataToDb(dBase, table, user, password, colDefs, colTypes, data,
                 nullMarker=None, blockSize=100, cn=None):
    """ *For Internal Use*

    (drops and) creates a table and then inserts the values

    **Arguments**

      - dBase, user, password: used to open a connection when _cn_ is
        not provided

      - table: the name of the table to (re)create

      - colDefs: the column definitions, as text, for 'create table'

      - colTypes: one (type, width) pair per column; values in float and
        int columns are converted with float()/int(), everything else
        with str()

      - data: a sequence of rows

      - nullMarker: (optional) values equal to this are stored as NULL
        (as are values which are None)

      - blockSize: the number of rows handed to each bulk insert

      - cn: (optional) an already open connection

    **Notes**

      - if the table cannot be created the problem is printed and
        nothing is inserted

      - empty _data_ leaves an empty (but created) table behind

    """
    if not cn:
        cn = DbModule.connect(dBase, user, password)
    c = cn.cursor()
    try:
        c.execute('drop table %s' % (table))
    except Exception:
        print('cannot drop table %s' % (table))
    sqlStr = 'create table %s (%s)' % (table, colDefs)
    try:
        c.execute(sqlStr)
    except Exception:
        print('create table failed: ', sqlStr)
        print('here is the exception:')
        import traceback
        traceback.print_exc()
        return
    cn.commit()
    c = None

    if len(data) == 0:
        # nothing to insert (and no first row to size the statement from)
        return

    dStr = ','.join([DbModule.placeHolder] * len(data[0]))
    sqlStr = 'insert into %s values (%s)' % (table, dStr)

    def _flush(rows):
        _insertBlock(cn, sqlStr, rows)
        # only a connection which really is in autocommit mode (True) can
        # skip the commit; other values (e.g. the -1 sentinel sqlite3 uses
        # for its legacy transaction handling) still need one
        if getattr(cn, 'autocommit', False) is not True:
            cn.commit()

    block = []
    for row in data:
        entries = [None] * len(row)
        for col, val in enumerate(row):
            if val is None or (nullMarker is not None and val == nullMarker):
                continue
            colType = colTypes[col][0]
            if colType == float:
                entries[col] = float(val)
            elif colType == int:
                entries[col] = int(val)
            else:
                entries[col] = str(val)
        block.append(tuple(entries))
        if len(block) >= blockSize:
            _flush(block)
            block = []
    if block:
        _flush(block)
350
def TextFileToDatabase(dBase, table, inF, delim=',',
                       user='sysdba', password='masterkey',
                       maxColLabelLen=31, keyCol=None, nullMarker=None):
    """loads the contents of the text file into a database.

    **Arguments**

      - dBase: the name of the DB to use

      - table: the name of the table to create/overwrite
        (dashes and spaces in the name are replaced with underscores)

      - inF: the file like object from which the data should
        be pulled (must support readline())

      - delim: the delimiter used to separate fields

      - user: the user name to use in connecting to the DB

      - password: the password to use in connecting to the DB

      - maxColLabelLen: the maximum length a column label should be
        allowed to have (truncation otherwise)

      - keyCol: the column to be used as an index for the db

      - nullMarker: (optional) entries equal to this text are ignored
        when the column types are determined and are passed on to the
        insertion step as the marker for missing values

    **Raises**

      - AssertionError if a line does not have the same number of fields
        as the header line

    **Notes**

      - if _table_ already exists, it is destroyed before we write
        the new data

      - we assume that the first row of the file contains the column names

      - entries are converted to int or float whenever the text allows it

    """
    # str.replace returns a new string, so the result has to be kept
    table = table.replace('-', '_').replace(' ', '_')

    colHeadings = inF.readline().split(delim)
    _AdjustColHeadings(colHeadings, maxColLabelLen)
    nCols = len(colHeadings)

    def _convert(entry):
        # most specific interpretation first: int, then float, then text
        for converter in (int, float):
            try:
                return converter(entry)
            except ValueError:
                pass
        return entry

    data = []
    inL = inF.readline()
    while inL:
        inL = inL.replace('\r', '').replace('\n', '')
        splitL = inL.split(delim)
        if len(splitL) != nCols:
            print('>>>', repr(inL))
            # raised explicitly so that the check survives python -O
            raise AssertionError('unequal length')
        data.append([_convert(entry) for entry in splitL])
        inL = inF.readline()
    nRows = len(data)

    # determine the types of each column
    colTypes = TypeFinder(data, nRows, nCols, nullMarker=nullMarker)
    typeStrs = GetTypeStrings(colHeadings, colTypes, keyCol=keyCol)
    colDefs = ','.join(typeStrs)

    _AddDataToDb(dBase, table, user, password, colDefs, colTypes, data,
                 nullMarker=nullMarker)
420 421
def DatabaseToDatabase(fromDb, fromTbl, toDb, toTbl,
                       fields='*', join='', where='',
                       user='sysdba', password='masterkey', keyCol=None, nullMarker='None'):
    """ copies the results of a query on one database into a table of another

    FIX: at the moment this is a hack

    The data is dumped to text with DatabaseToText() and that text is
    then loaded with TextFileToDatabase().

    **Arguments**

      - fromDb, fromTbl: the database and table to read from

      - toDb, toTbl: the database and table to create/overwrite

      - fields, join, where: the pieces of the query run on _fromTbl_

      - user, password: used for both connections

      - keyCol: the column to be used as an index in the new table

      - nullMarker: the text marking missing values; the default 'None'
        is what str() gives for a NULL in the text dump

    """
    from io import StringIO
    # a StringIO built from its contents starts out positioned at the
    # beginning, ready to be read (seeking to -1 is an error for io.StringIO)
    sio = StringIO(DatabaseToText(fromDb, fromTbl, fields=fields, join=join, where=where,
                                  user=user, password=password))
    TextFileToDatabase(toDb, toTbl, sio, user=user, password=password, keyCol=keyCol,
                       nullMarker=nullMarker)
if __name__ == '__main__':
    # small smoke test: load a three column table from in-memory text
    # into the test database which ships with the code
    import os
    from io import StringIO

    from rdkit import RDConfig

    sio = StringIO()
    sio.writelines([
        'foo,bar,baz\n',
        '1,2,3\n',
        '1.1,4,5\n',
        '4,foo,6\n',
    ])
    sio.seek(0)
    dirLoc = os.path.join(RDConfig.RDCodeDir, 'Dbase', 'TEST.GDB')

    TextFileToDatabase(dirLoc, 'fromtext', sio)