1
2
3
4 """ Implements a class used to represent N-ary trees
5
6 """
7 from __future__ import print_function
8 import numpy
9 from rdkit.six.moves import cPickle
10 from rdkit.six import cmp
11
12
13
14
15
class TreeNode:
    """ This is your bog standard Tree class.

      the root of the tree is just a TreeNode like all other members.
    """

    def __init__(self, parent, name, label=None, data=None, level=0, isTerminal=0):
        """ constructor

          **Arguments**

            - parent: the parent of this node in the tree

            - name: the name of the node

            - label: the node's label (should be an integer)

            - data: an optional data field

            - level: an integer indicating the level of this node in the hierarchy
              (used for printing)

            - isTerminal: flags a node as being terminal.  This is useful for those
              times when it's useful to know such things.

        """
        self.children = []
        self.parent = parent
        self.name = name
        self.data = data
        self.terminalNode = isTerminal
        self.label = label
        self.level = level
        # not read anywhere in this class; presumably populated by callers/subclasses
        self.examples = []

    def NameTree(self, varNames):
        """ Set the names of each node in the tree from a list of variable names.

          **Arguments**

            - varNames: a list of names to be assigned

          **Notes**

            1) this works its magic by recursively traversing all children

            2) The assumption is made here that the varNames list can be indexed
               by the labels of tree nodes

            3) terminal nodes (and everything below them) keep their names

        """
        if self.GetTerminal():
            return
        for child in self.GetChildren():
            child.NameTree(varNames)
        self.SetName(varNames[self.GetLabel()])

    NameModel = NameTree

    def AddChildNode(self, node):
        """ Adds a TreeNode to the local list of children

          **Arguments**

            - node: the node to be added

          **Note**

            the level of the node (used in printing) is set as well

        """
        # NOTE(review): only node's own level is updated; node.parent and the
        # levels of node's descendants are left unchanged, so a multi-level
        # subtree added here may print with inconsistent indentation.
        node.SetLevel(self.level + 1)
        self.children.append(node)

    def AddChild(self, name, label=None, data=None, isTerminal=0):
        """ Creates a new TreeNode and adds a child to the tree

          **Arguments**

            - name: the name of the new node

            - label: the label of the new node (should be an integer)

            - data: the data to be stored in the new node

            - isTerminal: a toggle to indicate whether or not the new node is
              a terminal (leaf) node.

          **Returns**

            the _TreeNode_ which is constructed

        """
        child = TreeNode(self, name, label, data, level=self.level + 1, isTerminal=isTerminal)
        self.children.append(child)
        return child

    def PruneChild(self, child):
        """ Removes the child node

          **Arguments**

            - child: a TreeNode

          **Note**

            removal uses list.remove, i.e. the first child that is identical or
            compares equal to _child_ is removed; ValueError if there is none.

        """
        self.children.remove(child)

    def ReplaceChildIndex(self, index, newChild):
        """ Replaces a given child with a new one

          **Arguments**

            - index: an integer

            - newChild: a TreeNode

          **Note**

            newChild's parent and level are not modified.

        """
        self.children[index] = newChild

    def GetChildren(self):
        """ Returns a python list of the children of this node

          (the internal list itself, not a copy)
        """
        return self.children

    def Destroy(self):
        """ Destroys this node and all of its children

          Recursively clears the children list (set to None) and the parent
          reference of every node in the subtree.
        """
        for child in self.children:
            child.Destroy()
        self.children = None
        self.parent = None

    def GetName(self):
        """ Returns the name of this node

        """
        return self.name

    def SetName(self, name):
        """ Sets the name of this node

        """
        self.name = name

    def GetData(self):
        """ Returns the data stored at this node

        """
        return self.data

    def SetData(self, data):
        """ Sets the data stored at this node

        """
        self.data = data

    def GetTerminal(self):
        """ Returns whether or not this node is terminal

        """
        return self.terminalNode

    def SetTerminal(self, isTerminal):
        """ Sets whether or not this node is terminal

        """
        self.terminalNode = isTerminal

    def GetLabel(self):
        """ Returns the label of this node

        """
        return self.label

    def SetLabel(self, label):
        """ Sets the label of this node (should be an integer)

        """
        self.label = label

    def GetLevel(self):
        """ Returns the level of this node

        """
        return self.level

    def SetLevel(self, level):
        """ Sets the level of this node (children are not updated)

        """
        self.level = level

    def GetParent(self):
        """ Returns the parent of this node

        """
        return self.parent

    def SetParent(self, parent):
        """ Sets the parent of this node

        """
        self.parent = parent

    def Print(self, level=0, showData=0):
        """ Pretty prints the tree

          **Arguments**

            - level: sets the number of spaces to be added at the beginning of the output

            - showData: if this is nonzero, the node's _data_ value will be printed as well

          **Note**

            this works recursively

        """
        if showData:
            print('%s%s: %s' % (' ' * level, self.name, str(self.data)))
        else:
            print('%s%s' % (' ' * level, self.name))

        for child in self.children:
            child.Print(level + 1, showData=showData)

    def Pickle(self, fileName='foo.pkl'):
        """ Pickles the tree and writes it to disk

          **Arguments**

            - fileName: path of the file to (over)write

        """
        # the file is only written, so plain binary-write mode is sufficient
        with open(fileName, 'wb') as pFile:
            cPickle.dump(self, pFile)

    def __str__(self):
        """ returns a string representation of the tree

          One line per node: the node's name indented by its stored _level_.

          **Note**

            this works recursively

        """
        pieces = ['%s%s\n' % (' ' * self.level, self.name)]
        pieces.extend(str(child) for child in self.children)
        return ''.join(pieces)

    def __cmp__(self, other):
        """ three-way comparison (Python 2 cmp protocol)

          Returns -1 if self < other, otherwise 1 if other < self, otherwise 0.
          Python 3 does not use __cmp__; there __lt__/__eq__ apply.

          **Note**

            This works recursively
        """
        return (self < other) * -1 or (other < self) * 1

    def __lt__(self, other):
        """ allows tree1 < tree2

          Nodes are compared on, in order: the string form of their type,
          their name, their label (a label of None sorts before any set
          label), their number of children, and then their children in order.

          **Note**

            This works recursively

            If an AttributeError is raised while comparing (e.g. _other_ is not
            a TreeNode), this returns True.
        """
        try:
            nChildren = len(self.children)
            oChildren = len(other.children)
            sType = str(type(self))
            oType = str(type(other))
            if sType < oType:
                return True
            if oType < sType:
                return False
            if self.name < other.name:
                return True
            if other.name < self.name:
                return False
            if self.label is not None:
                if other.label is None:
                    # a set label sorts after an unset one
                    return False
                if self.label < other.label:
                    return True
                if other.label < self.label:
                    return False
            elif other.label is not None:
                return True
            if nChildren < oChildren:
                return True
            if nChildren > oChildren:
                return False
            # NOTE(review): children are only tested for "strictly less"; a
            # greater child does not stop the scan.  Adding the reverse test
            # would compare equal subtrees twice per level (exponential in
            # depth) while still going through each child's own __lt__, so the
            # original per-child behaviour is kept.
            for mine, theirs in zip(self.children, other.children):
                if mine < theirs:
                    return True
        except AttributeError:
            return True
        return False

    def __eq__(self, other):
        """ allows tree1 == tree2

          Two nodes are equal when neither compares less than the other.
          (Defining __eq__ without __hash__ makes instances unhashable in
          Python 3.)
        """
        return not self < other and not other < self
294
if __name__ == '__main__':
    # Small smoke-test / demo of the TreeNode API: build, print, pickle,
    # prune, deep-copy and compare a tiny tree.
    import copy

    root = TreeNode(None, 'root')
    for idx in range(3):
        root.AddChild('child %d' % idx)
    print(root)

    # give the middle child three grandchildren
    middle = root.GetChildren()[1]
    for gcName in ('grandchild', 'grandchild2', 'grandchild3'):
        middle.AddChild(gcName)
    print(root)

    root.Pickle('save.pkl')
    print('prune')
    root.PruneChild(middle)
    print('done')
    print(root)

    clone = copy.deepcopy(root)
    print('tree==tree2', root == clone)

    holder = [root]
    print('tree in [tree]:', root in holder, holder.index(root))
    print('tree2 in [tree]:', clone in holder, holder.index(clone))

    # diverge the copy and compare again
    clone.GetChildren()[1].AddChild('grandchild4')
    print('tree==tree2', root == clone)
    root.Destroy()
321