Package rdkit :: Package ML :: Package Cluster :: Module ClusterVis
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.Cluster.ClusterVis

  1  # $Id$ 
  2  # 
  3  # Copyright (C) 2001-2006  greg Landrum and Rational Discovery LLC 
  4  # 
  5  #   @@ All Rights Reserved @@ 
  6  #  This file is part of the RDKit. 
  7  #  The contents are covered by the terms of the BSD license 
  8  #  which is included in the file license.txt, found at the root 
  9  #  of the RDKit source tree. 
 10  # 
 11  """Cluster tree visualization using Sping 
 12   
 13  """ 
 14   
 15  try: 
 16    from rdkit.sping import pid 
 17    piddle = pid 
 18  except ImportError: 
 19    from rdkit.piddle import piddle 
 20  import ClusterUtils 
 21   
 22  import numpy 
 23   
class VisOpts(object):
  """ drawing settings shared by the cluster-tree visualization code

    The settings live on the class itself and are read directly from it
    (e.g. VisOpts.lineWidth); no instance is needed.

    **Class attributes**

      - xOffset / yOffset: margin between the drawing and the canvas edges

      - lineColor / lineWidth: color and width of normally drawn tree lines

      - hideColor / hideWidth: color and width of "hidden" (collapsed or
        de-emphasized) tree lines

      - terminalColors: a palette of ten colors

      - nodeRad / nodeColor: size and fill color of centroid markers (the
        marker's bounding box is nodeRad wide)

      - highlightColor / highlightRad: highlight color and size

  """
  # margins around the drawing
  xOffset = 20
  yOffset = 20

  # regular tree lines
  lineColor = piddle.Color(0, 0, 0)
  lineWidth = 2

  # "hidden" tree lines
  hideColor = piddle.Color(.8, .8, .8)
  hideWidth = 1.1

  # a generator expression (not a list comprehension) keeps the loop
  # variables from leaking into the class namespace under Python 2
  terminalColors = list(
    piddle.Color(r, g, b) for r, g, b in (
      (1, 0, 0), (0, 0, 1), (1, 1, 0),
      (0, .5, .5), (0, .8, 0), (.5, .5, .5),
      (.8, .3, .3), (.3, .3, .8), (.8, .8, .3),
      (.3, .8, .8)))

  # centroid markers
  nodeRad = 15
  nodeColor = piddle.Color(1., .4, .4)

  # highlighting
  highlightColor = piddle.Color(1., 1., .4)
  highlightRad = 10
def _scaleMetric(val, power=2, min=1e-4):
  """ maps a metric value onto a logarithmic scale

    **Arguments**

      - val: the value to scale (anything float() accepts)

      - power: exponent applied to the value before taking the log

      - min: floor; raised values below it map to 0.0

    **Returns**

      0.0 when float(val)**power < min, otherwise log(float(val)**power / min)

  """
  raised = float(val) ** power
  return 0.0 if raised < min else numpy.log(raised / min)
class ClusterRenderer(object):
  """ lays out and draws a cluster tree on a Sping/Piddle canvas

    Layout is done in two passes. First, the terminal points are spread
    evenly along the x axis. Then each internal node is placed at the mean x
    of its children, working bottom-up. Heights come from the nodes'
    GetMetric() values, optionally log-scaled via _scaleMetric. Drawing then
    walks the tree from the root down.

    Trees with fewer than two points are not drawn. Drawing only down to
    centroids (stopAtCentroids) is not implemented.

  """

  def __init__(self, canvas, size, ptColors=None, lineWidth=None, showIndices=0, showNodes=1,
               stopAtCentroids=0, logScale=0, tooClose=-1):
    """
      **Arguments**

        - canvas: the canvas to draw on; its drawLine, drawString and
          stringWidth methods are used

        - size: (width, height) of the canvas

        - ptColors: optional sequence of colors indexed by a child's
          GetData() value; None or empty means all lines use
          VisOpts.lineColor

        - lineWidth: width of the tree lines; VisOpts.lineWidth when None

        - showIndices: if true, terminal nodes are labeled with their name
          (falling back to their index)

        - showNodes: stored for API compatibility; not used by this renderer

        - stopAtCentroids: if true, DrawTree raises NotImplementedError

        - logScale: if > 0, node heights are passed through _scaleMetric
          using this value as the power

        - tooClose: when the first two children of a node are no more than
          this far apart in x, the subtree is drawn collapsed as a single
          "hidden" line

    """
    self.canvas = canvas
    self.size = size
    # None instead of a shared mutable [] default
    self.ptColors = ptColors if ptColors is not None else []
    self.lineWidth = lineWidth
    self.showIndices = showIndices
    self.showNodes = showNodes
    self.stopAtCentroids = stopAtCentroids
    self.logScale = logScale
    self.tooClose = tooClose

  def _NodeHeight(self, node):
    """ returns the (optionally log-scaled) metric of _node_ """
    if self.logScale > 0:
      return _scaleMetric(node.GetMetric(), self.logScale)
    return float(node.GetMetric())

  def _AssignPointLocations(self, cluster, terminalOffset=4):
    """ sets _drawPos on every terminal point of _cluster_

      Points are spread evenly between the left and right margins, in
      GetPoints() order. Each point's y comes from its own metric, with
      _terminalOffset_ added. Requires self.ySpace, which DrawTree sets.

    """
    self.pts = cluster.GetPoints()
    self.nPts = len(self.pts)
    if self.nPts > 1:
      self.xSpace = float(self.size[0] - 2 * VisOpts.xOffset) / float(self.nPts - 1)
    else:
      # zero or one point: nothing to spread out, and no division by zero
      self.xSpace = 0.0
    ySize = self.size[1]
    for i, pt in enumerate(self.pts):
      v = self._NodeHeight(pt)
      pt._drawPos = (VisOpts.xOffset + i * self.xSpace,
                     ySize - (v * self.ySpace + VisOpts.yOffset) + terminalOffset)

  def _AssignClusterLocations(self, cluster):
    """ sets _drawPos on every non-terminal node of _cluster_

      Each node is centered in x over its children, so nodes are handled
      bottom-up. A node's y comes from its own metric.

    """
    # collect the nodes that have children, top-down (breadth first)...
    toDo = [cluster]
    examine = cluster.GetChildren()[:]
    while examine:
      node = examine.pop(0)
      children = node.GetChildren()
      if children:
        toDo.append(node)
        examine.extend(child for child in children if not child.IsTerminal())
    # ...then walk them in reverse so children are placed before their parents
    for node in reversed(toDo):
      childLocs = [child._drawPos[0] for child in node.GetChildren()]
      if childLocs:
        xp = sum(childLocs) / float(len(childLocs))
        yp = self.size[1] - (self._NodeHeight(node) * self.ySpace + VisOpts.yOffset)
        node._drawPos = (xp, yp)

  def _DrawToLimit(self, cluster):
    """ draws the tree rooted at _cluster_, top-down

      Assumes _drawPos has already been assigned to every node. A node whose
      first two children are no more than self.tooClose apart in x is not
      expanded. Instead, a "hidden" line is drawn from it down to the bottom
      margin.

    """
    if self.lineWidth is None:
      lineWidth = VisOpts.lineWidth
    else:
      lineWidth = self.lineWidth

    examine = [cluster]
    while examine:
      node = examine.pop(0)
      xp, yp = node._drawPos
      children = node.GetChildren()
      if not children:
        continue
      if len(children) > 1:
        expand = abs(children[1]._drawPos[0] - children[0]._drawPos[0]) > self.tooClose
      else:
        # a lone child has no sibling to be too close to
        expand = True

      if expand:
        # horizontal bar joining the children
        self.canvas.drawLine(children[0]._drawPos[0], yp, children[-1]._drawPos[0], yp,
                             VisOpts.lineColor, lineWidth)
        # vertical drop from the bar to each child
        for child in children:
          if self.ptColors and child.GetData() is not None:
            drawColor = self.ptColors[child.GetData()]
          else:
            drawColor = VisOpts.lineColor
          cxp, cyp = child._drawPos
          self.canvas.drawLine(cxp, yp, cxp, cyp, drawColor, lineWidth)
          if not child.IsTerminal():
            examine.append(child)
          elif self.showIndices and not self.stopAtCentroids:
            try:
              txt = str(child.GetName())
            except Exception:
              txt = str(child.GetIndex())
            self.canvas.drawString(txt, cxp - self.canvas.stringWidth(txt) / 2, cyp)
      else:
        # collapsed subtree: a "hidden" line down to the bottom margin
        self.canvas.drawLine(xp, yp, xp, self.size[1] - VisOpts.yOffset, VisOpts.hideColor,
                             lineWidth)

  def DrawTree(self, cluster, minHeight=2.0):
    """ lays out and draws _cluster_ on the canvas

      **Arguments**

        - cluster: root of the cluster tree

        - minHeight: height used for the vertical scale when the root's
          (possibly log-scaled) metric is not positive

      **Notes**

        - trees with fewer than two points are not drawn

        - if stopAtCentroids was set, positions are assigned and then
          NotImplementedError is raised

    """
    v = self._NodeHeight(cluster)
    if v <= 0:
      v = minHeight
    self.ySpace = float(self.size[1] - 2 * VisOpts.yOffset) / v

    self._AssignPointLocations(cluster)
    if self.nPts < 2:
      # nothing meaningful to draw for an empty or single-point tree
      return
    self._AssignClusterLocations(cluster)
    if not self.stopAtCentroids:
      self._DrawToLimit(cluster)
    else:
      raise NotImplementedError('stopAtCentroids drawing not yet implemented')
def DrawClusterTree(cluster, canvas, size, ptColors=None, lineWidth=None, showIndices=0,
                    showNodes=1, stopAtCentroids=0, logScale=0, tooClose=-1):
  """ handles the work of drawing a cluster tree on a Sping canvas

    **Arguments**

      - cluster: the cluster tree to be drawn

      - canvas: the Sping canvas on which to draw

      - size: the size of _canvas_

      - ptColors: if this is specified, the _colors_ will be used to color
        the terminal nodes of the cluster tree.  (color == _pid.Color_)

      - lineWidth: if specified, it will be used for the widths of the lines
        used to draw the tree

      - showIndices: if true, terminal nodes are labeled

      - showNodes: passed through to ClusterRenderer

      - stopAtCentroids: passed through to ClusterRenderer, which does not
        implement it (its DrawTree raises NotImplementedError)

      - logScale: if > 0, node heights are log-scaled with _scaleMetric,
        using this value as the power

      - tooClose: sibling pairs no farther apart (in x) than this are drawn
        collapsed

    **Notes**

      - _Canvas_ is neither _save_d nor _flush_ed at the end of this

      - if _ptColors_ is the wrong length for the number of possible terminal
        node types, this will throw an IndexError

      - terminal node types are determined using their _GetData()_ methods

  """
  if ptColors is None:
    # avoid a shared mutable default; an empty list means "no custom colors"
    ptColors = []
  renderer = ClusterRenderer(canvas, size, ptColors=ptColors, lineWidth=lineWidth,
                             showIndices=showIndices, showNodes=showNodes,
                             stopAtCentroids=stopAtCentroids, logScale=logScale,
                             tooClose=tooClose)
  renderer.DrawTree(cluster)
def _DrawClusterTree(cluster, canvas, size, ptColors=None, lineWidth=None, showIndices=0,
                     showNodes=1, stopAtCentroids=0, logScale=0, tooClose=-1, minHeight=2.0):
  """ handles the work of drawing a cluster tree on a Sping canvas
    (self-contained variant with centroid support)

    **Arguments**

      - cluster: the cluster tree to be drawn

      - canvas: the Sping canvas on which to draw

      - size: the size of _canvas_

      - ptColors: if this is specified, the _colors_ will be used to color
        the terminal nodes of the cluster tree.  (color == _pid.Color_)

      - lineWidth: if specified, it will be used for the widths of the lines
        used to draw the tree

      - showIndices: if true, node indices are drawn

      - showNodes: if true, nodes with an _isCentroid attribute are drawn as
        filled circles labeled with their _clustID

      - stopAtCentroids: if true, the nodes come from
        ClusterUtils.GetNodesDownToCentroids(cluster) instead of
        ClusterUtils.GetNodeList(cluster). Nodes must then carry an
        _aboveCentroid attribute; those with _aboveCentroid <= 0 have their
        children drawn in the "hidden" style.

      - logScale: if > 0, heights are passed through _scaleMetric using this
        value as the power

      - tooClose: accepted for signature compatibility; not used here

      - minHeight: height used for the vertical scale when the root's
        (possibly scaled) metric is not positive

    **Notes**

      - _Canvas_ is neither _save_d nor _flush_ed at the end of this

      - trees with fewer than two points are not drawn

      - if _ptColors_ is the wrong length for the number of possible terminal
        node types, this will throw an IndexError

      - terminal node types are determined using their _GetData()_ methods

  """
  if ptColors is None:
    ptColors = []
  if lineWidth is None:
    lineWidth = VisOpts.lineWidth
  pts = cluster.GetPoints()
  nPts = len(pts)
  if nPts <= 1:
    return

  def height(node):
    # (optionally log-scaled) metric of a node
    if logScale > 0:
      return _scaleMetric(node.GetMetric(), logScale)
    return float(node.GetMetric())

  xSpace = float(size[0] - 2 * VisOpts.xOffset) / float(nPts - 1)
  v = height(cluster)
  if v <= 0:
    # flat tree: avoid dividing by zero
    v = minHeight
  ySpace = float(size[1] - 2 * VisOpts.yOffset) / v

  # The node list must exist before the point loop below prunes it. The
  # original code referenced allNodes before assigning it.
  if not stopAtCentroids:
    allNodes = ClusterUtils.GetNodeList(cluster)
  else:
    allNodes = ClusterUtils.GetNodesDownToCentroids(cluster)

  # Place the terminal points. Points that should not be handled as nodes in
  # the main loop (all of them, or only the non-centroids when
  # stopAtCentroids is set) are pruned from allNodes. Pruning is by identity
  # and tolerates points that are not in the list.
  skip = set()
  for i, pt in enumerate(pts):
    pt._drawPos = (VisOpts.xOffset + i * xSpace,
                   size[1] - (height(pt) * ySpace + VisOpts.yOffset))
    if not stopAtCentroids or not hasattr(pt, '_isCentroid'):
      skip.add(id(pt))
  allNodes = [n for n in allNodes if id(n) not in skip]

  for node in allNodes:
    children = node.GetChildren()
    if children:
      # center the node over its children; children precede parents in allNodes
      yp = size[1] - (height(node) * ySpace + VisOpts.yOffset)
      childLocs = [child._drawPos[0] for child in children]
      xp = sum(childLocs) / float(len(childLocs))
      node._drawPos = (xp, yp)
      if not stopAtCentroids or node._aboveCentroid > 0:
        for child in children:
          if ptColors and child.GetData() is not None:
            drawColor = ptColors[child.GetData()]
          else:
            drawColor = VisOpts.lineColor
          cx, cy = child._drawPos
          if showNodes and hasattr(child, '_isCentroid'):
            # end the line at the centroid marker rather than its center
            cy = cy - VisOpts.nodeRad / 2
          canvas.drawLine(cx, cy, cx, yp, drawColor, lineWidth)
        canvas.drawLine(children[0]._drawPos[0], yp, children[-1]._drawPos[0], yp,
                        VisOpts.lineColor, lineWidth)
      else:
        for child in children:
          cx, cy = child._drawPos
          canvas.drawLine(cx, cy, cx, yp, VisOpts.hideColor, VisOpts.hideWidth)
        canvas.drawLine(children[0]._drawPos[0], yp, children[-1]._drawPos[0], yp,
                        VisOpts.hideColor, VisOpts.hideWidth)

    if showIndices and (not stopAtCentroids or node._aboveCentroid >= 0):
      txt = str(node.GetIndex())
      if hasattr(node, '_isCentroid'):
        txtColor = piddle.Color(1, .2, .2)
      else:
        txtColor = piddle.Color(0, 0, 0)
      canvas.drawString(txt, node._drawPos[0] - canvas.stringWidth(txt) / 2,
                        node._drawPos[1] + canvas.fontHeight() / 4, color=txtColor)

    if showNodes and hasattr(node, '_isCentroid'):
      rad = VisOpts.nodeRad
      nx, ny = node._drawPos
      canvas.drawEllipse(nx - rad / 2, ny - rad / 2, nx + rad / 2, ny + rad / 2,
                         piddle.transparent, fillColor=VisOpts.nodeColor)
      txt = str(node._clustID)
      canvas.drawString(txt, nx - canvas.stringWidth(txt) / 2, ny + canvas.fontHeight() / 4,
                        color=piddle.Color(0, 0, 0))

  if showIndices and not stopAtCentroids:
    for pt in pts:
      txt = str(pt.GetIndex())
      canvas.drawString(txt, pt._drawPos[0] - canvas.stringWidth(txt) / 2, pt._drawPos[1])
def ClusterToPDF(cluster, fileName, size=(300, 300), ptColors=None, lineWidth=None, showIndices=0,
                 stopAtCentroids=0, logScale=0):
  """ handles the work of drawing a cluster tree to a PDF file

    **Arguments**

      - cluster: the cluster tree to be drawn

      - fileName: the name of the file to be created

      - size: the size of output canvas

      - ptColors: if this is specified, the _colors_ will be used to color
        the terminal nodes of the cluster tree.  (color == _pid.Color_)

      - lineWidth: if specified, it will be used for the widths of the lines
        used to draw the tree

      - showIndices, stopAtCentroids, logScale: passed through to
        DrawClusterTree

    **Returns**

      the PDF canvas that was drawn on

    **Notes**

      - the canvas is saved only when _fileName_ is non-empty

      - if _ptColors_ is the wrong length for the number of possible terminal
        node types, this will throw an IndexError

      - terminal node types are determined using their _GetData()_ methods

  """
  try:
    from rdkit.sping.PDF import pidPDF
  except ImportError:
    from rdkit.piddle import piddlePDF
    pidPDF = piddlePDF

  if ptColors is None:
    # avoid a shared mutable default
    ptColors = []
  canvas = pidPDF.PDFCanvas(size, fileName)
  if lineWidth is None:
    lineWidth = VisOpts.lineWidth
  DrawClusterTree(cluster, canvas, size, ptColors=ptColors, lineWidth=lineWidth,
                  showIndices=showIndices, stopAtCentroids=stopAtCentroids, logScale=logScale)
  if fileName:
    canvas.save()
  return canvas
def ClusterToSVG(cluster, fileName, size=(300, 300), ptColors=None, lineWidth=None, showIndices=0,
                 stopAtCentroids=0, logScale=0):
  """ handles the work of drawing a cluster tree to an SVG file

    **Arguments**

      - cluster: the cluster tree to be drawn

      - fileName: the name of the file to be created

      - size: the size of output canvas

      - ptColors: if this is specified, the _colors_ will be used to color
        the terminal nodes of the cluster tree.  (color == _pid.Color_)

      - lineWidth: if specified, it will be used for the widths of the lines
        used to draw the tree

      - showIndices, stopAtCentroids, logScale: passed through to
        DrawClusterTree

    **Returns**

      the SVG canvas that was drawn on

    **Notes**

      - the canvas is saved only when _fileName_ is non-empty

      - if _ptColors_ is the wrong length for the number of possible terminal
        node types, this will throw an IndexError

      - terminal node types are determined using their _GetData()_ methods

  """
  try:
    from rdkit.sping.SVG import pidSVG
  except ImportError:
    from rdkit.piddle.piddleSVG import piddleSVG
    pidSVG = piddleSVG

  if ptColors is None:
    # avoid a shared mutable default
    ptColors = []
  canvas = pidSVG.SVGCanvas(size, fileName)

  if lineWidth is None:
    lineWidth = VisOpts.lineWidth
  DrawClusterTree(cluster, canvas, size, ptColors=ptColors, lineWidth=lineWidth,
                  showIndices=showIndices, stopAtCentroids=stopAtCentroids, logScale=logScale)
  if fileName:
    canvas.save()
  return canvas
def ClusterToImg(cluster, fileName, size=(300, 300), ptColors=None, lineWidth=None, showIndices=0,
                 stopAtCentroids=0, logScale=0):
  """ handles the work of drawing a cluster tree to an image file

    **Arguments**

      - cluster: the cluster tree to be drawn

      - fileName: the name of the file to be created

      - size: the size of output canvas

      - ptColors: if this is specified, the _colors_ will be used to color
        the terminal nodes of the cluster tree.  (color == _pid.Color_)

      - lineWidth: if specified, it will be used for the widths of the lines
        used to draw the tree

      - showIndices, stopAtCentroids, logScale: passed through to
        DrawClusterTree

    **Returns**

      the PIL canvas that was drawn on

    **Notes**

      - The extension on _fileName_ determines the type of image file created.
        All formats supported by PIL can be used.

      - the canvas is saved only when _fileName_ is non-empty

      - if _ptColors_ is the wrong length for the number of possible terminal
        node types, this will throw an IndexError

      - terminal node types are determined using their _GetData()_ methods

  """
  try:
    from rdkit.sping.PIL import pidPIL
  except ImportError:
    from rdkit.piddle import piddlePIL
    pidPIL = piddlePIL

  if ptColors is None:
    # avoid a shared mutable default
    ptColors = []
  canvas = pidPIL.PILCanvas(size, fileName)
  if lineWidth is None:
    lineWidth = VisOpts.lineWidth
  DrawClusterTree(cluster, canvas, size, ptColors=ptColors, lineWidth=lineWidth,
                  showIndices=showIndices, stopAtCentroids=stopAtCentroids, logScale=logScale)
  if fileName:
    canvas.save()
  return canvas