Package rdkit :: Package ML :: Package DecTree :: Module TreeVis
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.DecTree.TreeVis

  1  # $Id$ 
  2  # 
  3  #  Copyright (C) 2002,2003  Greg Landrum and Rational Discovery LLC 
  4  #    All Rights Reserved 
  5  # 
  6  """ functionality for drawing trees on sping canvases 
  7   
  8  """     
  9  from rdkit.sping import pid as piddle 
 10  import math 
 11   
class VisOpts(object):
    """Drawing parameters shared by the tree-rendering functions in this module.

    The values are plain class attributes; the functions below read them
    through the module-level ``visOpts`` instance.
    """
    # --- node geometry ---
    # radius used for non-terminal nodes and for unscaled terminal nodes
    circRad = 10
    # radius range used for terminal nodes when leaves are scaled
    minCircRad = 4
    maxCircRad = 16

    # --- layout ---
    # each leaf is allotted (horizOffset + circRad) units of horizontal space
    horizOffset = 10
    # vertical distance between a node and its children
    vertOffset = 50

    # --- connecting lines ---
    lineColor = piddle.Color(0, 0, 0)
    # also passed as the outline width when node ellipses are drawn
    lineWidth = 2

    # --- node colors ---
    # fill for non-terminal nodes
    circColor = piddle.Color(0.6, 0.6, 0.9)
    # outline for terminal nodes that hold no examples
    terminalEmptyColor = piddle.Color(0.8, 0.8, 0.2)
    # terminal fill is a blend between the "off" and "on" colors
    terminalOnColor = piddle.Color(0.8, 0.8, 0.8)
    terminalOffColor = piddle.Color(0.2, 0.2, 0.2)
    # outline for every other node
    outlineColor = piddle.transparent

    # --- text ---
    labelFont = piddle.Font(face='helvetica', size=10)

    # --- highlighting (not referenced elsewhere in this module) ---
    highlightColor = piddle.Color(1.0, 1.0, 0.4)
    highlightWidth = 2


# module-wide options instance consulted by all of the drawing functions
visOpts = VisOpts()
def CalcTreeNodeSizes(node):
    """Recursively annotate a tree with the bookkeeping used for layout.

    For `node` and every node underneath it, the following attributes are set:

      - nExamples: len(node.GetExamples())

      - totNChildren: the number of childless nodes in the subtree rooted
        at the node (a childless node counts itself, i.e. 1)

      - nLevelsBelow: the number of levels in the subtree rooted at the
        node, including the node itself (1 for a childless node)

    """
    kids = node.GetChildren()
    if len(kids) == 0:
        leafCount = 1
        deepest = 0
    else:
        leafCount = 0
        deepest = 0
        for kid in kids:
            # children must be sized first; their totals feed ours
            CalcTreeNodeSizes(kid)
            leafCount += kid.totNChildren
            deepest = max(deepest, kid.nLevelsBelow)

    node.nExamples = len(node.GetExamples())
    node.totNChildren = leafCount
    node.nLevelsBelow = deepest + 1
53
def _ExampleCounter(node, min, max):
    """Find the smallest and largest `nExamples` among the terminal nodes.

    `min` and `max` are the running extremes seen so far; the (possibly
    updated) pair is returned as a 2-tuple.  Only nodes for which
    GetTerminal() is true contribute their `nExamples`; other nodes simply
    merge the results from their children.
    """
    if not node.GetTerminal():
        for kid in node.GetChildren():
            lowest, highest = _ExampleCounter(kid, min, max)
            if lowest < min:
                min = lowest
            if highest > max:
                max = highest
        return min, max

    count = node.nExamples
    if count < min:
        min = count
    if count > max:
        max = count
    return min, max
65
def _ApplyNodeScales(node, min, max):
    """Store each terminal node's relative example count in `_scaleLoc`.

    For terminal nodes, `_scaleLoc` is (nExamples - min)/(max - min), or
    0.5 when `min` and `max` are equal.  Non-terminal nodes are not
    modified; the call is just passed on to their children.
    """
    if not node.GetTerminal():
        for kid in node.GetChildren():
            _ApplyNodeScales(kid, min, max)
        return

    if max == min:
        # no spread to scale against, so park the node mid-scale
        node._scaleLoc = .5
    else:
        node._scaleLoc = float(node.nExamples - min) / (max - min)
76
def SetNodeScales(node):
    """Compute the leaf scaling information for the tree rooted at `node`.

    The smallest and largest `nExamples` found by _ExampleCounter() are
    stored on `node` as the 2-tuple `_scales`, and _ApplyNodeScales() is then
    used to position every terminal node within that range.
    """
    # seed the search with sentinels that any real count will replace
    fewest, most = _ExampleCounter(node, 1e8, -1e8)
    node._scales = fewest, most
    _ApplyNodeScales(node, fewest, most)
82 83
def DrawTreeNode(node, loc, canvas, nRes=2, scaleLeaves=False, showPurity=False):
    """Recursively displays the given tree node and all its children on the canvas

    **Arguments**

      - node: the node to draw; it needs to provide GetChildren(),
        GetTerminal(), GetLabel() and GetExamples()

      - loc: (x,y) position of the center of the node on the canvas

      - canvas: the canvas to draw on

      - nRes: terminal-node labels are divided by (nRes-1) to pick the blend
        between visOpts.terminalOffColor and visOpts.terminalOnColor, so this
        must not be 1 when terminal nodes are drawn

      - scaleLeaves: if set, the radius of each terminal node is interpolated
        between visOpts.minCircRad and visOpts.maxCircRad using the node's
        _scaleLoc attribute (0.5 is used if the node does not have one)

      - showPurity: if set, each terminal node gets a line from its center to
        the point on its circle at angle purity*pi, where purity is the
        fraction of the node's examples whose last element matches the node
        label (1.0 if the node has no examples)

    **Side effects**

      - CalcTreeNodeSizes() is called on the node if its totNChildren
        attribute is missing or None

      - the node's bounding box is stored as node._bBox = (x1,y1,x2,y2)

      - canvas.defaultFont is set to visOpts.labelFont

    """
    if getattr(node, 'totNChildren', None) is None:
        CalcTreeNodeSizes(node)

    isTerminal = node.GetTerminal()
    if scaleLeaves and isTerminal:
        # use the fallback value for nodes that were never scaled instead of
        # re-reading node._scaleLoc, which would raise for exactly those nodes
        scaleLoc = getattr(node, '_scaleLoc', 0.5)
        rad = visOpts.minCircRad + scaleLoc * (visOpts.maxCircRad - visOpts.minCircRad)
    else:
        rad = visOpts.circRad

    x1 = loc[0] - rad
    y1 = loc[1] - rad
    x2 = loc[0] + rad
    y2 = loc[1] + rad

    if showPurity and isTerminal:
        examples = node.GetExamples()
        nEx = len(examples)
        if nEx:
            tgtVal = int(node.GetLabel())
            nMatching = sum(1 for ex in examples if int(ex[-1]) == tgtVal)
            # float() keeps this a true division under Python 2 as well
            purity = float(nMatching) / nEx
        else:
            purity = 1.0

        deg = purity * math.pi
        pureX = loc[0] + rad * math.sin(deg)
        pureY = loc[1] + rad * math.cos(deg)

    # horizontal space allotted to each leaf
    slotWidth = visOpts.horizOffset + visOpts.circRad
    # just move down one level
    childY = loc[1] + visOpts.vertOffset
    # this is the left-hand side of the leftmost span
    childX = loc[0] - (slotWidth * node.totNChildren) / 2
    for child in node.GetChildren():
        # center on this child's space
        halfWidth = (slotWidth * child.totNChildren) / 2
        childX = childX + halfWidth
        childLoc = [childX, childY]
        # the connecting line goes down first so that the nodes are drawn over it
        canvas.drawLine(loc[0], loc[1], childLoc[0], childLoc[1],
                        visOpts.lineColor, visOpts.lineWidth)
        DrawTreeNode(child, childLoc, canvas, nRes=nRes, scaleLeaves=scaleLeaves,
                     showPurity=showPurity)

        # and move over to the leftmost point of the next child
        childX = childX + halfWidth

    if isTerminal:
        cFac = float(node.GetLabel()) / float(nRes - 1)
        theColor = (1. - cFac) * visOpts.terminalOffColor + cFac * visOpts.terminalOnColor
        # terminal nodes without examples are flagged with a different outline
        if hasattr(node, 'GetExamples') and node.GetExamples():
            outlColor = visOpts.outlineColor
        else:
            outlColor = visOpts.terminalEmptyColor
        canvas.drawEllipse(x1, y1, x2, y2,
                           outlColor, visOpts.lineWidth,
                           theColor)
        if showPurity:
            canvas.drawLine(loc[0], loc[1], pureX, pureY, piddle.Color(1, 1, 1), 2)
    else:
        canvas.drawEllipse(x1, y1, x2, y2,
                           visOpts.outlineColor, visOpts.lineWidth,
                           visOpts.circColor)

    # this does not need to be done every time
    canvas.defaultFont = visOpts.labelFont

    labelStr = str(node.GetLabel())
    strLoc = (loc[0] - canvas.stringWidth(labelStr) / 2,
              loc[1] + canvas.fontHeight() / 4)

    canvas.drawString(labelStr, strLoc[0], strLoc[1])
    node._bBox = (x1, y1, x2, y2)
178
def CalcTreeWidth(tree):
    """Returns the horizontal space required to draw the tree

    The width is tree.totNChildren*(visOpts.circRad+visOpts.horizOffset),
    so it reflects the current settings in visOpts.

    CalcTreeNodeSizes() is called first if the tree's totNChildren attribute
    is either missing or None (ResetTree() sets it to None).
    """
    if getattr(tree, 'totNChildren', None) is None:
        CalcTreeNodeSizes(tree)
    return tree.totNChildren * (visOpts.circRad + visOpts.horizOffset)
186
def DrawTree(tree, canvas, nRes=2, scaleLeaves=False, allowShrink=True, showPurity=False):
    """Draws a full tree on the canvas

    The root is centered horizontally (using canvas.size) and placed
    visOpts.vertOffset from the top.

    **Arguments**

      - tree: the root node of the tree to draw

      - canvas: the canvas to draw on

      - nRes, scaleLeaves, showPurity: passed on to DrawTreeNode()

      - allowShrink: if set, visOpts.circRad and visOpts.horizOffset are
        repeatedly halved until CalcTreeWidth() no longer exceeds the canvas
        width.  NOTE: this modifies the module-level visOpts object, so the
        smaller values remain in effect after the call.

    """
    dims = canvas.size
    loc = (dims[0] / 2, visOpts.vertOffset)

    # Make sure the size bookkeeping exists before anything needs it:
    # SetNodeScales() reads node.nExamples, which this module only sets in
    # CalcTreeNodeSizes(), and a tree that has been through ResetTree() has
    # totNChildren set to None.
    if getattr(tree, 'totNChildren', None) is None:
        CalcTreeNodeSizes(tree)

    if scaleLeaves:
        # the scales are always recalculated rather than reusing tree._scales
        SetNodeScales(tree)
    if allowShrink:
        treeWid = CalcTreeWidth(tree)
        while treeWid > dims[0]:
            visOpts.circRad /= 2
            visOpts.horizOffset /= 2
            treeWid = CalcTreeWidth(tree)
    DrawTreeNode(tree, loc, canvas, nRes, scaleLeaves=scaleLeaves,
                 showPurity=showPurity)
205
def ResetTree(tree):
    """Clear the cached layout information on `tree` and all of its descendants.

    Every node reached through GetChildren() gets its `_scales` and
    `totNChildren` attributes set to None.
    """
    pending = [tree]
    while pending:
        current = pending.pop()
        current._scales = None
        current.totNChildren = None
        # push in reverse so that nodes are visited in the same (pre-order,
        # left-to-right) sequence a recursive walk would use
        pending.extend(reversed(list(current.GetChildren())))
211
def _simpleTest(canv):
    """Builds a small five-node tree and draws it on the canvas

    The tree has a root ('r') with one non-terminal child ('l1_1') and one
    terminal child (label 1); 'l1_1' in turn has two terminal children
    (labels 0 and 1).  The root is drawn at (150, visOpts.vertOffset).
    """
    # Tree lives in the same package as this module; the absolute form is
    # used because the implicit relative form (from Tree import ...) is not
    # available in Python 3
    from rdkit.ML.DecTree.Tree import TreeNode as Node
    root = Node(None, 'r', label='r')
    c1 = root.AddChild('l1_1', label='l1_1')
    root.AddChild('l1_2', isTerminal=1, label=1)
    c1.AddChild('l2_1', isTerminal=1, label=0)
    c1.AddChild('l2_2', isTerminal=1, label=1)

    DrawTreeNode(root, (150, visOpts.vertOffset), canv)
if __name__ == '__main__':
    # When run as a script: render the small demonstration tree onto a
    # 300x300 PIL-backed canvas named 'test.png' and save it.
    from rdkit.sping.PIL.pidPIL import PILCanvas

    canv = PILCanvas(size=(300, 300), name='test.png')
    _simpleTest(canv)
    canv.save()