Package rdkit :: Package ML :: Package DecTree :: Module Tree
[hide private]
[frames | no frames]

Source Code for Module rdkit.ML.DecTree.Tree

  1  # 
  2  #  Copyright (C) 2000-2008  greg Landrum and Rational Discovery LLC 
  3  # 
  4  """ Implements a class used to represent N-ary trees 
  5   
  6  """ 
  7  from __future__ import print_function 
  8  import numpy 
  9  from rdkit.six.moves import cPickle 
 10  from rdkit.six import cmp 
 11   
 12  # FIX: the TreeNode class has not been updated to new-style classes 
 13  # (RD Issue380) because that would break all of our legacy pickled 
 14  # data. Until a solution is found for this breakage, an update is 
 15  # impossible. 
class TreeNode:
    """ A general purpose N-ary tree node.

    The root of a tree is just a TreeNode like all other members.

    **Notes**

      - the class is deliberately declared without an explicit base class so
        that legacy pickled trees keep loading

      - comparisons (`<`, `==`) look at the node type, name, label and the
        children (recursively); the _data_ field is not compared

    """

    def __init__(self, parent, name, label=None, data=None, level=0, isTerminal=0):
        """ constructor

        **Arguments**

          - parent: the parent of this node in the tree

          - name: the name of the node

          - label: the node's label (should be an integer)

          - data: an optional data field

          - level: an integer indicating the level of this node in the
            hierarchy (used for printing)

          - isTerminal: flags a node as being terminal. This is useful for
            those times when it's useful to know such things.

        """
        self.children = []
        self.parent = parent
        self.name = name
        self.data = data
        self.terminalNode = isTerminal
        self.label = label
        self.level = level
        self.examples = []

    def NameTree(self, varNames):
        """ Set the names of each node in the tree from a list of variable names.

        **Arguments**

          - varNames: a list of names to be assigned

        **Notes**

          1) this works its magic by recursively traversing all children

          2) The assumption is made here that the varNames list can be indexed
             by the labels of tree nodes

          3) terminal nodes (and everything below them) are left untouched

        """
        if self.GetTerminal():
            return
        for child in self.GetChildren():
            child.NameTree(varNames)
        self.SetName(varNames[self.GetLabel()])

    NameModel = NameTree

    def AddChildNode(self, node):
        """ Adds a TreeNode to the local list of children

        **Arguments**

          - node: the node to be added

        **Note**

          the level of the node (used in printing) is set as well; the node's
          parent and the levels of its own children are not modified

        """
        node.SetLevel(self.level + 1)
        self.children.append(node)

    def AddChild(self, name, label=None, data=None, isTerminal=0):
        """ Creates a new TreeNode and adds a child to the tree

        **Arguments**

          - name: the name of the new node

          - label: the label of the new node (should be an integer)

          - data: the data to be stored in the new node

          - isTerminal: a toggle to indicate whether or not the new node is
            a terminal (leaf) node.

        **Returns**

          the _TreeNode_ which is constructed

        """
        child = TreeNode(self, name, label, data, level=self.level + 1,
                         isTerminal=isTerminal)
        self.children.append(child)
        return child

    def PruneChild(self, child):
        """ Removes the child node

        **Arguments**

          - child: a TreeNode

        **Note**

          raises ValueError (from list.remove) if _child_ is not a child of
          this node

        """
        self.children.remove(child)

    def ReplaceChildIndex(self, index, newChild):
        """ Replaces a given child with a new one

        **Arguments**

          - index: an integer

          - newChild: a TreeNode

        """
        self.children[index] = newChild

    def GetChildren(self):
        """ Returns a python list of the children of this node

        **Note**

          this is the node's own list, not a copy

        """
        return self.children

    def Destroy(self):
        """ Destroys this node and all of its children

        **Note**

          the child list is reset to an empty list (rather than None) so that
          calling Destroy() more than once, or printing/comparing a destroyed
          node, does not raise

        """
        for child in self.children:
            child.Destroy()
        self.children = []
        # clean up circular references
        self.parent = None

    def GetName(self):
        """ Returns the name of this node

        """
        return self.name

    def SetName(self, name):
        """ Sets the name of this node

        """
        self.name = name

    def GetData(self):
        """ Returns the data stored at this node

        """
        return self.data

    def SetData(self, data):
        """ Sets the data stored at this node

        """
        self.data = data

    def GetTerminal(self):
        """ Returns whether or not this node is terminal

        """
        return self.terminalNode

    def SetTerminal(self, isTerminal):
        """ Sets whether or not this node is terminal

        """
        self.terminalNode = isTerminal

    def GetLabel(self):
        """ Returns the label of this node

        """
        return self.label

    def SetLabel(self, label):
        """ Sets the label of this node (should be an integer)

        """
        self.label = label

    def GetLevel(self):
        """ Returns the level of this node

        """
        return self.level

    def SetLevel(self, level):
        """ Sets the level of this node

        """
        self.level = level

    def GetParent(self):
        """ Returns the parent of this node

        """
        return self.parent

    def SetParent(self, parent):
        """ Sets the parent of this node

        """
        self.parent = parent

    def Print(self, level=0, showData=0):
        """ Pretty prints the tree

        **Arguments**

          - level: sets the number of spaces to be added at the beginning of
            the output

          - showData: if this is nonzero, the node's _data_ value will be
            printed as well

        **Note**

          this works recursively

        """
        # NOTE(review): the indent unit is reproduced as it appears in the
        # extracted listing; whitespace may have been collapsed - confirm.
        indent = ' ' * level
        if showData:
            print('%s%s: %s' % (indent, self.name, str(self.data)))
        else:
            print('%s%s' % (indent, self.name))

        for child in self.children:
            child.Print(level + 1, showData=showData)

    def Pickle(self, fileName='foo.pkl'):
        """ Pickles the tree and writes it to disk

        **Arguments**

          - fileName: the name of the file to be written (any existing
            contents are overwritten)

        """
        with open(fileName, 'wb+') as pFile:
            cPickle.dump(self, pFile)

    def __str__(self):
        """ returns a string representation of the tree

        **Note**

          this works recursively

        """
        # NOTE(review): indent unit as it appears in the extracted listing;
        # whitespace may have been collapsed - confirm.
        pieces = ['%s%s\n' % (' ' * self.level, self.name)]
        pieces.extend(str(child) for child in self.children)
        return ''.join(pieces)

    def __cmp__(self, other):
        """ allows tree1 == tree2 (three-way comparison, Python 2 only)

        **Note**

          This works recursively
        """
        return (self < other) * -1 or (other < self) * 1

    def __lt__(self, other):
        """ allows tree1 < tree2

        The comparison is lexicographic over, in order: the string form of
        the node type, the name, the label (a missing label sorts before any
        label), the number of children and finally the children themselves.
        The first key that differs decides the result, so at most one of
        `a < b` and `b < a` is true for two tree nodes.

        **Notes**

          - This works recursively

          - if _other_ lacks the attributes of a tree node, the node is
            considered to be less than _other_

        """
        try:
            nChildren = len(self.children)
            oChildren = len(other.children)

            selfType = str(type(self))
            otherType = str(type(other))
            if selfType < otherType:
                return True
            if otherType < selfType:
                return False

            if self.name < other.name:
                return True
            if other.name < self.name:
                return False

            if self.label is not None:
                if other.label is None:
                    return False
                if self.label < other.label:
                    return True
                if other.label < self.label:
                    return False
            elif other.label is not None:
                return True

            if nChildren != oChildren:
                return nChildren < oChildren

            for mine, theirs in zip(self.children, other.children):
                if mine < theirs:
                    return True
                if theirs < mine:
                    return False
        except AttributeError:
            return True
        return False

    def __eq__(self, other):
        """ allows tree1 == tree2

        Two nodes are equal when neither one is less than the other.

        **Notes**

          - This works recursively

          - on Python 3, defining __eq__ without __hash__ makes instances
            unhashable

        """
        return not self < other and not other < self
if __name__ == '__main__':
    # Small smoke test / demo: build a tree, pickle it, prune it, then
    # exercise copying, equality and list membership.
    import copy

    tree = TreeNode(None, 'root')
    for idx in range(3):
        tree.AddChild('child %d' % idx)
    print(tree)

    # hang three grandchildren off the middle child
    middle = tree.GetChildren()[1]
    for kidName in ('grandchild', 'grandchild2', 'grandchild3'):
        middle.AddChild(kidName)
    print(tree)

    tree.Pickle('save.pkl')

    print('prune')
    tree.PruneChild(tree.GetChildren()[1])
    print('done')
    print(tree)

    # a deep copy compares equal to the original ...
    tree2 = copy.deepcopy(tree)
    print('tree==tree2', tree == tree2)

    # ... so it is also found in a list holding the original
    foo = [tree]
    print('tree in [tree]:', tree in foo, foo.index(tree))
    print('tree2 in [tree]:', tree2 in foo, foo.index(tree2))

    # changing the copy breaks the equality
    tree2.GetChildren()[1].AddChild('grandchild4')
    print('tree==tree2', tree == tree2)

    tree.Destroy()