Package rdkit :: Package Chem :: Package Subshape :: Module SubshapeAligner
[hide private]
[frames] | no frames]

Source Code for Module rdkit.Chem.Subshape.SubshapeAligner

  1  # $Id$ 
  2  # 
  3  # Copyright (C) 2007-2008 by Greg Landrum  
  4  #  All rights reserved 
  5  # 
  6  from __future__ import print_function 
  7  from rdkit import RDLogger 
  8  logger = RDLogger.logger() 
  9  from rdkit import Chem,Geometry 
 10  import numpy 
 11  from rdkit.Numerics import Alignment 
 12  from rdkit.Chem.Subshape import SubshapeObjects 
 13   
class SubshapeAlignment(object):
    """Plain record describing one candidate alignment of a query onto a target.

    Every field is a class-level default; the aligner overwrites them per
    instance as the alignment moves through the pruning stages.
    """
    # transformation matrix returned by Alignment.GetAlignmentTransform
    transform = None
    # sum of squared deviations reported for the triangle alignment
    triangleSSD = None
    # (i, j, k) skeleton-point indices of the matched target triangle
    targetTri = None
    # (i, j, k) skeleton-point indices of the matched query triangle
    queryTri = None
    # not assigned anywhere in this module
    alignedConfId = -1
    # accumulated |dot product| of shape directions (set by direction pruning)
    dirMatch = 0.0
    # last shape distance that was computed (set by shape pruning)
    shapeDist = 0.0
22
23 -def _getAllTriangles(pts,orderedTraversal=False):
24 for i in range(len(pts)): 25 if orderedTraversal: 26 jStart=i+1 27 else: 28 jStart=0 29 for j in range(jStart,len(pts)): 30 if j==i: 31 continue 32 if orderedTraversal: 33 kStart=j+1 34 else: 35 kStart=0 36 for k in range(j+1,len(pts)): 37 if k==i or k==j: 38 continue 39 yield (i,j,k)
40
class SubshapeDistanceMetric(object):
    """Integer codes selecting the metric used to compare two shapes."""
    # any value other than PROTRUDE selects Geometry.TanimotoDistance
    TANIMOTO = 0
    # selects Geometry.ProtrudeDistance
    PROTRUDE = 1
# returns the distance between two shapes according to the provided metric
def GetShapeShapeDistance(s1, s2, distMetric):
    """Return the distance between the grids of two shapes.

    Arguments:
      - s1, s2: objects exposing a ``grid`` attribute
      - distMetric: a SubshapeDistanceMetric value

    For SubshapeDistanceMetric.PROTRUDE the grid whose occupancy vector has
    the smaller total value is passed as the first argument to
    Geometry.ProtrudeDistance (on a tie ``s2`` goes first).  Any other
    metric value yields Geometry.TanimotoDistance(s1.grid, s2.grid).
    """
    if distMetric != SubshapeDistanceMetric.PROTRUDE:
        return Geometry.TanimotoDistance(s1.grid, s2.grid)

    total1 = s1.grid.GetOccupancyVect().GetTotalVal()
    total2 = s2.grid.GetOccupancyVect().GetTotalVal()
    if total1 < total2:
        return Geometry.ProtrudeDistance(s1.grid, s2.grid)
    return Geometry.ProtrudeDistance(s2.grid, s1.grid)
# clusters a set of alignments and returns the centroid of each cluster
def ClusterAlignments(mol, alignments, builder,
                      neighborTol=0.1,
                      distMetric=SubshapeDistanceMetric.PROTRUDE,
                      tempConfId=1001):
    """Cluster alignments by shape distance and return one per cluster.

    Arguments:
      - mol: the molecule the alignments apply to; temporary conformers with
        ids ``tempConfId`` and ``tempConfId+1`` are added to it and removed
        again while the distances are computed
      - alignments: sequence of objects with a ``transform`` attribute
      - builder: object providing ``GenerateSubshapeShape()``
      - neighborTol: tolerance passed to Butina.ClusterData
      - distMetric: metric handed to GetShapeShapeDistance
      - tempConfId: first of the two scratch conformer ids

    Returns a list holding, for every cluster, the alignment whose index is
    listed first in that cluster.
    """
    from rdkit.ML.Cluster import Butina

    scratchI = tempConfId
    scratchJ = tempConfId + 1

    # lower-triangle distance matrix, flattened row by row: (1,0), (2,0), (2,1), ...
    pairDists = []
    for i, algI in enumerate(alignments):
        TransformMol(mol, algI.transform, newConfId=scratchI)
        shapeI = builder.GenerateSubshapeShape(mol, scratchI, addSkeleton=False)
        for j in range(i):
            TransformMol(mol, alignments[j].transform, newConfId=scratchJ)
            shapeJ = builder.GenerateSubshapeShape(mol, scratchJ, addSkeleton=False)
            pairDists.append(GetShapeShapeDistance(shapeI, shapeJ, distMetric))
            mol.RemoveConformer(scratchJ)
        mol.RemoveConformer(scratchI)

    clusters = Butina.ClusterData(pairDists, len(alignments), neighborTol,
                                  isDistData=True)
    return [alignments[members[0]] for members in clusters]
78
def TransformMol(mol, tform, confId=-1, newConfId=100):
    """Add a transformed copy of one conformer of ``mol`` to ``mol``.

    Arguments:
      - mol: the molecule to modify (in place)
      - tform: transformation matrix applied to the homogeneous coordinates
        ``(x, y, z, 1)`` of every atom via ``numpy.dot``; only the first
        three components of each result are used
      - confId: id of the conformer supplying the reference coordinates
      - newConfId: id given to the new conformer; any conformer already
        stored under this id is removed first

    The other conformers of the molecule are left untouched, so the
    molecule does NOT end up with a single conformer.  Nothing is returned.
    """
    newConf = Chem.Conformer()
    refConf = mol.GetConformer(confId)
    for i in range(refConf.GetNumAtoms()):
        # promote to homogeneous coordinates so tform can carry a translation
        pos = list(refConf.GetAtomPosition(i))
        pos.append(1.0)
        newPos = numpy.dot(tform, numpy.array(pos))
        newConf.SetAtomPosition(i, list(newPos)[:3])
    newConf.SetId(newConfId)
    # drop a stale conformer with the same id before adding the new one
    mol.RemoveConformer(newConfId)
    mol.AddConformer(newConf, assignId=False)
95
class SubshapeAligner(object):
    """Finds alignments of a query subshape onto a target subshape.

    Candidate alignments are generated by matching triangles of skeleton
    points and are then filtered by (optionally) feature compatibility,
    shape-direction agreement and finally grid-based shape distance.

    Tunable class attributes:
      - triangleRMSTol: triangle alignments are accepted when their SSD is
        at most ``triangleRMSTol**2 * 9``
      - distMetric: SubshapeDistanceMetric value used for shape distances
      - shapeDistTol: maximum acceptable shape distance
      - numFeatThresh: number of triangle vertices whose features must match
      - dirThresh: minimum summed |dot product| of the vertex directions
      - edgeTol: its square bounds the difference of squared edge lengths
      - coarseGridToleranceMult / medGridToleranceMult: multipliers applied
        to shapeDistTol for the coarse and medium grid checks
    """
    triangleRMSTol = 1.0
    distMetric = SubshapeDistanceMetric.PROTRUDE
    shapeDistTol = 0.2
    numFeatThresh = 3
    dirThresh = 2.6
    edgeTol = 6.0
    #coarseGridToleranceMult=1.5
    #medGridToleranceMult=1.25
    coarseGridToleranceMult = 1.0
    medGridToleranceMult = 1.0

    @staticmethod
    def _recordPrune(pruneStats, key):
        """Increment ``pruneStats[key]`` when a statistics dict was supplied."""
        if pruneStats is not None:
            pruneStats[key] = pruneStats.get(key, 0) + 1

    @staticmethod
    def _squaredEdgeLengths(pts):
        """Map every index pair ``(i, j)`` with ``i < j`` to the squared
        distance between ``pts[i].location`` and ``pts[j].location``."""
        res = {}
        for i in range(len(pts)):
            for j in range(i + 1, len(pts)):
                res[(i, j)] = (pts[i].location - pts[j].location).LengthSq()
        return res

    def GetTriangleMatches(self, target, query):
        """ this is a generator function returning the possible triangle
        matches between the two shapes

        Arguments:
          - target, query: shapes exposing ``skelPts``, a sequence of points
            with a ``location`` attribute

        Yields SubshapeAlignment objects with ``transform``, ``triangleSSD``,
        ``targetTri``, ``queryTri`` and a running ``_seqNo`` filled in.

        NOTE(review): edge compatibility compares the difference of *squared*
        lengths against ``edgeTol**2``.  Also, the edge tables only hold keys
        ``(i, j)`` with ``i < j``, so query triangles whose indices are not
        ascending never pass the compatibility test.  Both behaviours are
        preserved from the original code.
        """
        ssdTol = (self.triangleRMSTol ** 2) * 9
        tgtPts = target.skelPts
        queryPts = query.skelPts
        tgtLs = self._squaredEdgeLengths(tgtPts)
        queryLs = self._squaredEdgeLengths(queryPts)

        # set of (targetEdge, queryEdge) pairs with similar squared lengths;
        # .items() / "in" work on both Python 2 and 3 (iteritems / has_key do not)
        compatEdges = set()
        tol2 = self.edgeTol * self.edgeTol
        for tk, tv in tgtLs.items():
            for qk, qv in queryLs.items():
                if abs(tv - qv) < tol2:
                    compatEdges.add((tk, qk))

        seqNo = 0
        for tgtTri in _getAllTriangles(tgtPts, orderedTraversal=True):
            tgtLocs = [tgtPts[x].location for x in tgtTri]
            t0, t1, t2 = tgtTri
            for queryTri in _getAllTriangles(queryPts, orderedTraversal=False):
                q0, q1, q2 = queryTri
                if ((t0, t1), (q0, q1)) not in compatEdges or \
                   ((t0, t2), (q0, q2)) not in compatEdges or \
                   ((t1, t2), (q1, q2)) not in compatEdges:
                    continue
                queryLocs = [queryPts[x].location for x in queryTri]
                ssd, tf = Alignment.GetAlignmentTransform(tgtLocs, queryLocs)
                if ssd <= ssdTol:
                    alg = SubshapeAlignment()
                    alg.transform = tf
                    alg.triangleSSD = ssd
                    alg.targetTri = tgtTri
                    alg.queryTri = queryTri
                    alg._seqNo = seqNo
                    seqNo += 1
                    yield alg

    def _checkMatchFeatures(self, targetPts, queryPts, alignment):
        """Return True if at least ``numFeatThresh`` triangle vertices match.

        A vertex pair matches when neither point has ``molFeatures`` or when
        at least one of the target point's features is also among the query
        point's features.
        """
        nMatched = 0
        for i in range(3):
            tgtFeats = targetPts[alignment.targetTri[i]].molFeatures
            qFeats = queryPts[alignment.queryTri[i]].molFeatures
            if not tgtFeats and not qFeats:
                nMatched += 1
            elif any(feat in qFeats for feat in tgtFeats):
                nMatched += 1
            if nMatched >= self.numFeatThresh:
                break
        return nMatched >= self.numFeatThresh

    def PruneMatchesUsingFeatures(self, target, query, alignments, pruneStats=None):
        """Remove, in place, the alignments that fail the feature check.

        Each removal increments ``pruneStats['features']`` when ``pruneStats``
        is provided.
        """
        targetPts = target.skelPts
        queryPts = query.skelPts
        i = 0
        while i < len(alignments):
            if self._checkMatchFeatures(targetPts, queryPts, alignments[i]):
                i += 1
            else:
                self._recordPrune(pruneStats, 'features')
                del alignments[i]

    def _checkMatchDirections(self, targetPts, queryPts, alignment):
        """Return True if the shape directions of the matched vertices agree.

        The first shape direction of each query vertex is rotated by the
        upper-left 3x3 part of ``alignment.transform`` and dotted with the
        first shape direction of the corresponding target vertex.  The
        absolute values are summed and compared with ``dirThresh``; the sum
        reached (possibly after stopping early) is stored in
        ``alignment.dirMatch``.
        """
        tform = alignment.transform
        dot = 0.0
        for i in range(3):
            tgtPt = targetPts[alignment.targetTri[i]]
            queryPt = queryPts[alignment.queryTri[i]]
            qv = queryPt.shapeDirs[0]
            tv = tgtPt.shapeDirs[0]
            rotV = [tform[r, 0] * qv[0] + tform[r, 1] * qv[1] + tform[r, 2] * qv[2]
                    for r in range(3)]
            dot += abs(rotV[0] * tv[0] + rotV[1] * tv[1] + rotV[2] * tv[2])
            if dot >= self.dirThresh:
                # already above the threshold, no need to continue
                break
        alignment.dirMatch = dot
        return dot >= self.dirThresh

    def PruneMatchesUsingDirection(self, target, query, alignments, pruneStats=None):
        """Remove, in place, the alignments that fail the direction check.

        Each removal increments ``pruneStats['direction']`` when
        ``pruneStats`` is provided.
        """
        tgtPts = target.skelPts
        queryPts = query.skelPts
        i = 0
        while i < len(alignments):
            if self._checkMatchDirections(tgtPts, queryPts, alignments[i]):
                i += 1
            else:
                self._recordPrune(pruneStats, 'direction')
                del alignments[i]

    def _addCoarseAndMediumGrids(self, mol, tgt, confId, builder):
        """Attach ``medGrid`` (1.5x spacing) and ``coarseGrid`` (2x spacing)
        to ``tgt``.

        With a molecule the shapes are regenerated from ``mol``; otherwise
        they are obtained from ``builder.SampleSubshape``.  The builder's
        ``gridSpacing`` is restored afterwards, even if generation fails.
        """
        oSpace = builder.gridSpacing
        if mol:
            try:
                builder.gridSpacing = oSpace * 1.5
                tgt.medGrid = builder.GenerateSubshapeShape(mol, confId, addSkeleton=False)
                builder.gridSpacing = oSpace * 2
                tgt.coarseGrid = builder.GenerateSubshapeShape(mol, confId, addSkeleton=False)
            finally:
                builder.gridSpacing = oSpace
        else:
            tgt.medGrid = builder.SampleSubshape(tgt, oSpace * 1.5)
            tgt.coarseGrid = builder.SampleSubshape(tgt, oSpace * 2.0)

    def _checkMatchShape(self, targetMol, target, queryMol, query, alignment, builder,
                         targetConf, queryConf, pruneStats=None, tConfId=1001):
        """Return True if the aligned query passes the shape-distance checks.

        The query molecule is transformed into a temporary conformer
        (``tConfId``) and compared with the target at coarse (2x), medium
        (1.5x) and original grid spacing, stopping at the first failure.  The
        failing stage is counted in ``pruneStats`` under 'coarseGrid',
        'medGrid' or 'fineGrid'.  The last distance computed is stored in
        ``alignment.shapeDist``.

        ``target`` must already carry ``coarseGrid`` and ``medGrid``.  The
        temporary conformer is removed and ``builder.gridSpacing`` restored
        even when an exception is raised.
        """
        matchOk = True
        TransformMol(queryMol, alignment.transform, confId=queryConf, newConfId=tConfId)
        oSpace = builder.gridSpacing
        try:
            builder.gridSpacing = oSpace * 2
            coarseGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
            d = GetShapeShapeDistance(coarseGrid, target.coarseGrid, self.distMetric)
            if d > self.shapeDistTol * self.coarseGridToleranceMult:
                matchOk = False
                self._recordPrune(pruneStats, 'coarseGrid')
            else:
                builder.gridSpacing = oSpace * 1.5
                medGrid = builder.GenerateSubshapeShape(queryMol, tConfId, addSkeleton=False)
                d = GetShapeShapeDistance(medGrid, target.medGrid, self.distMetric)
                if d > self.shapeDistTol * self.medGridToleranceMult:
                    matchOk = False
                    self._recordPrune(pruneStats, 'medGrid')
                else:
                    builder.gridSpacing = oSpace
                    fineGrid = builder.GenerateSubshapeShape(queryMol, tConfId,
                                                             addSkeleton=False)
                    d = GetShapeShapeDistance(fineGrid, target, self.distMetric)
                    if d > self.shapeDistTol:
                        matchOk = False
                        self._recordPrune(pruneStats, 'fineGrid')
            alignment.shapeDist = d
        finally:
            queryMol.RemoveConformer(tConfId)
            builder.gridSpacing = oSpace
        return matchOk

    def PruneMatchesUsingShape(self, targetMol, target, queryMol, query, builder,
                               alignments, tgtConf=-1, queryConf=-1,
                               pruneStats=None):
        """Remove, in place, the alignments that fail the shape check.

        The target's medium and coarse grids are created on first use.
        Progress is logged after every 100 alignments examined.
        """
        if not hasattr(target, 'medGrid'):
            self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

        logger.info("Shape-based Pruning")
        i = 0
        nOrig = len(alignments)
        nDone = 0
        while i < len(alignments):
            alg = alignments[i]
            nDone += 1
            if not nDone % 100:
                nLeft = len(alignments)
                logger.info(' processed %d of %d. %d alignments remain' % ((nDone,
                                                                           nOrig,
                                                                           nLeft)))
            if not self._checkMatchShape(targetMol, target, queryMol, query, alg, builder,
                                         targetConf=tgtConf, queryConf=queryConf,
                                         pruneStats=pruneStats):
                del alignments[i]
            else:
                i += 1

    def GetSubshapeAlignments(self, targetMol, target, queryMol, query, builder,
                              tgtConf=-1, queryConf=-1, pruneStats=None):
        """Return the list of alignments surviving all pruning stages.

        All triangle matches are generated first and then pruned by features
        (only when ``builder.featFactory`` is set), direction and shape.

        If ``pruneStats`` is a dict it receives the per-stage prune counters
        and the timings 'gtm_time', 'feats_time' (feature stage only),
        'direction_time' and 'shape_time' in seconds.
        """
        import time
        if pruneStats is None:
            pruneStats = {}
        logger.info("Generating triangle matches")
        t1 = time.time()
        res = list(self.GetTriangleMatches(target, query))
        t2 = time.time()
        logger.info("Got %d possible alignments in %.1f seconds" % (len(res), t2 - t1))
        pruneStats['gtm_time'] = t2 - t1
        if builder.featFactory:
            logger.info("Doing feature pruning")
            t1 = time.time()
            self.PruneMatchesUsingFeatures(target, query, res, pruneStats=pruneStats)
            t2 = time.time()
            pruneStats['feats_time'] = t2 - t1
            logger.info("%d possible alignments remain. (%.1f seconds required)" % (len(res), t2 - t1))
        logger.info("Doing direction pruning")
        t1 = time.time()
        self.PruneMatchesUsingDirection(target, query, res, pruneStats=pruneStats)
        t2 = time.time()
        pruneStats['direction_time'] = t2 - t1
        logger.info("%d possible alignments remain. (%.1f seconds required)" % (len(res), t2 - t1))
        t1 = time.time()
        self.PruneMatchesUsingShape(targetMol, target, queryMol, query, builder, res,
                                    tgtConf=tgtConf, queryConf=queryConf,
                                    pruneStats=pruneStats)
        t2 = time.time()
        pruneStats['shape_time'] = t2 - t1
        return res

    def __call__(self, targetMol, target, queryMol, query, builder,
                 tgtConf=-1, queryConf=-1, pruneStats=None):
        """Lazily yield the alignments that pass every check.

        Each triangle match is tested for features (only when
        ``builder.featFactory`` is set), direction and shape in turn; the
        target's medium and coarse grids are built the first time a match
        reaches the shape check.
        """
        for alignment in self.GetTriangleMatches(target, query):
            if builder.featFactory and \
               not self._checkMatchFeatures(target.skelPts, query.skelPts, alignment):
                self._recordPrune(pruneStats, 'features')
                continue
            if not self._checkMatchDirections(target.skelPts, query.skelPts, alignment):
                self._recordPrune(pruneStats, 'direction')
                continue

            if not hasattr(target, 'medGrid'):
                self._addCoarseAndMediumGrids(targetMol, target, tgtConf, builder)

            if not self._checkMatchShape(targetMol, target, queryMol, query, alignment, builder,
                                         targetConf=tgtConf, queryConf=queryConf,
                                         pruneStats=pruneStats):
                continue
            # if we made it this far, it's a good alignment
            yield alignment
if __name__ == '__main__':
    # Demo: load a pickled target, query and builder from the working
    # directory, report the number of alignments found and show both
    # molecules and their subshapes in PyMol.
    from rdkit.six.moves import cPickle

    # NOTE: unpickling can execute arbitrary code; only use trusted files.
    # open() replaces the Python-2-only file() builtin, and the context
    # managers make sure the handles are closed.
    with open('target.pkl', 'rb') as inF:
        tgtMol, tgtShape = cPickle.load(inF)
    with open('query.pkl', 'rb') as inF:
        queryMol, queryShape = cPickle.load(inF)
    with open('builder.pkl', 'rb') as inF:
        builder = cPickle.load(inF)

    aligner = SubshapeAligner()
    algs = aligner.GetSubshapeAlignments(tgtMol, tgtShape, queryMol, queryShape, builder)
    print(len(algs))

    from rdkit.Chem.PyMol import MolViewer
    v = MolViewer()
    v.ShowMol(tgtMol, name='Target', showOnly=True)
    v.ShowMol(queryMol, name='Query', showOnly=False)
    SubshapeObjects.DisplaySubshape(v, tgtShape, 'target_shape', color=(.8, .2, .2))
    SubshapeObjects.DisplaySubshape(v, queryShape, 'query_shape', color=(.2, .2, .8))