1
2
3
4
5
6 """ code for dealing with composite models
7
8 For a model to be usable here, it should support the following API:
9
10 - _ClassifyExample(example)_, returns a classification
11
12 Other compatibility notes:
13
14 1) To use _Composite.Grow_ there must be a builder function
15 which returns a 2-tuple containing (model, error).
16
17 2) The models should be pickleable
18
19 3) The models should support equality comparison (__eq__, or __cmp__ under Python 2)
20 so that the membership tests used to keep models unique work.
21
22
23
24 """
25 from __future__ import print_function
26 import math
27 import numpy
28 from rdkit.six.moves import cPickle
29 from rdkit.ML.Data import DataUtils
30
32 """a composite model
33
34
35 **Notes**
36
37 - adding a model which is already present just results in its count
38 field being incremented and the errors being averaged.
39
40 - typical usage:
41
42 1) grow the composite with AddModel until happy with it
43
44 2) call AverageErrors to calculate the average error values
45
46 3) call SortModels to put things in order by either error or count
47
48 - Composites can support individual models requiring either quantized or
49 nonquantized data. This is done by keeping a set of quantization bounds
50 (_QuantBounds_) in the composite and quantizing data passed in when required.
51 Quantization bounds can be set and interrogated using the
52 _Get/SetQuantBounds()_ methods. When models are added to the composite,
53 it can be indicated whether or not they require quantization.
54
55 - Composites are also capable of extracting relevant variables from longer lists.
56 This is accessible using _SetDescriptorNames()_ to register the descriptors about
57 which the composite cares and _SetInputOrder()_ to tell the composite what the
58 ordering of input vectors will be. **Note** there is a limitation on this: each
59 model needs to take the same set of descriptors as inputs. This could be changed.
60
61 """
63 self.modelList=[]
64 self.errList=[]
65 self.countList=[]
66 self.modelVotes=[]
67 self.quantBounds = None
68 self.nPossibleVals = None
69 self.quantizationRequirements=[]
70 self._descNames = []
71 self._mapOrder = None
72 self.activityQuant=[]
73
75 self._modelFilterFrac = modelFilterFrac
76 self._modelFilterVal = modelFilterVal
77
79 """ registers the names of the descriptors this composite uses
80
81 **Arguments**
82
83 - names: a list of descriptor names (strings).
84
85 **NOTE**
86
87 the _names_ list is not
88 copied, so if you modify it later, the composite itself will also be modified.
89
90 """
91 self._descNames = names
93 """ returns the names of the descriptors this composite uses
94
95 """
96 return self._descNames
97
99 """ sets the quantization bounds that the composite will use
100
101 **Arguments**
102
103 - qBounds: a list of quantization bounds, each quantbound is a
104 list of boundaries
105
106 - nPossible: a list of integers indicating how many possible values
107 each descriptor can take on.
108
109 **NOTE**
110
111 - if the two lists are of different lengths, this will assert out
112
113 - neither list is copied, so if you modify it later, the composite
114 itself will also be modified.
115
116
117 """
118 if nPossible is not None:
119 assert len(qBounds)==len(nPossible),'qBounds/nPossible mismatch'
120 self.quantBounds = qBounds
121 self.nPossibleVals = nPossible
122
124 """ returns the quantization bounds
125
126 **Returns**
127
128 a 2-tuple consisting of:
129
130 1) the list of quantization bounds
131
132 2) the nPossibleVals list
133
134 """
135 return self.quantBounds,self.nPossibleVals
136
138 if not hasattr(self,'activityQuant'):
139 self.activityQuant=[]
140 return self.activityQuant
142 self.activityQuant=bounds
144 if activityQuant is None:
145 activityQuant=self.activityQuant
146 if activityQuant:
147 example = example[:]
148 act = example[actCol]
149 for box in range(len(activityQuant)):
150 if act < activityQuant[box]:
151 act = box
152 break
153 else:
154 act = box + 1
155 example[actCol] = act
156 return example
157
159 """ quantizes an example
160
161 **Arguments**
162
163 - example: a data point (list, tuple or numpy array)
164
165 - quantBounds: a list of quantization bounds, each quantbound is a
166 list of boundaries. If this argument is not provided, the composite
167 will use its own quantBounds
168
169 **Returns**
170
171 the quantized example as a list
172
173 **Notes**
174
175 - If _example_ is different in length from _quantBounds_, this will
176 assert out.
177
178 - This is primarily intended for internal use
179
180 """
181 if quantBounds is None:
182 quantBounds = self.quantBounds
183 assert len(example)==len(quantBounds),'example/quantBounds mismatch'
184 quantExample = [None]*len(example)
185 for i in range(len(quantBounds)):
186 bounds = quantBounds[i]
187 p = example[i]
188 if len(bounds):
189 for box in range(len(bounds)):
190 if p < bounds[box]:
191 p = box
192 break
193 else:
194 p = box + 1
195 else:
196 if i != 0:
197 p = int(p)
198 quantExample[i] = p
199 return quantExample
200
202 """ creates a histogram of error/count pairs
203
204 **Returns**
205
206 the histogram as a series of (error, count) 2-tuples
207
208
209 """
210 nExamples = len(self.modelList)
211 histo = []
212 i = 1
213 lastErr = self.errList[0]
214 countHere = self.countList[0]
215 eps = 0.001
216 while i < nExamples:
217 if self.errList[i]-lastErr > eps:
218 histo.append((lastErr,countHere))
219 lastErr = self.errList[i]
220 countHere = self.countList[i]
221 else:
222 countHere = countHere + self.countList[i]
223 i = i + 1
224
225 return histo
226
def CollectVotes(self, example, quantExample, appendExample=0,
                 onlyModels=None):
    """ collects votes across every member of the composite for the given example

      **Arguments**

        - example: the example to be voted upon (used by models that do
          not require quantized data)

        - quantExample: the quantized form of the example (used by models
          registered as requiring quantized data)

        - appendExample: toggles saving the example on the models; forwarded
          to each model's _ClassifyExample_ as _appendExamples_

        - onlyModels: if provided, this should be a sequence of model
          indices.  Only the specified models will be used in the
          prediction.  None or an empty sequence means "use every model".

      **Returns**

        a list with one entry per member of the composite: the model's vote
        rounded to an int for consulted models, -1 for models excluded by
        _onlyModels_

    """
    if not onlyModels:
        onlyModels = range(len(self))

    votes = [-1] * len(self)
    for i in onlyModels:
        # feed each model the representation it was registered with
        # (quantizationRequirements holds AddModel's needsQuantization flag)
        data = quantExample if self.quantizationRequirements[i] else example
        votes[i] = int(round(self.modelList[i].ClassifyExample(
            data, appendExamples=appendExample)))
    return votes
262
def ClassifyExample(self, example, threshold=0, appendExample=0,
                    onlyModels=None):
    """ classifies the given example using the entire composite

      **Arguments**

        - example: the data to be classified

        - threshold: if this is a number greater than zero, then a
          classification will only be returned if the confidence is
          above _threshold_.  Anything lower is returned as -1.

        - appendExample: toggles saving the example on the models

        - onlyModels: if provided, this should be a sequence of model
          indices.  Only the specified models will be used in the
          prediction.

      **Returns**

        a (result,confidence) tuple

      **Notes**

        - this is treated as a voting problem: each model's vote is weighted
          by its count, and the confidence is the fraction of the weighted
          votes that went to the winning result.  No statistical confidence
          interval is computed.

        - the number of possible results is taken from the last entry of
          _nPossibleVals_.

        - if the selected models contribute no weighted votes at all
          (e.g. all their counts are zero), (-1, 0.0) is returned.

        - the individual model votes are stored and can be retrieved
          with _GetVoteDetails()_.

    """
    if self._mapOrder is not None:
        example = self._RemapInput(example)
    if self.GetActivityQuantBounds():
        example = self.QuantizeActivity(example)
    if self.quantBounds is not None and 1 in self.quantizationRequirements:
        quantExample = self.QuantizeExample(example, self.quantBounds)
    else:
        quantExample = []

    if not onlyModels:
        onlyModels = range(len(self))
    self.modelVotes = self.CollectVotes(example, quantExample,
                                        appendExample=appendExample,
                                        onlyModels=onlyModels)

    # tally the votes per result, weighting each model by its count
    votes = [0] * self.nPossibleVals[-1]
    for i in onlyModels:
        res = self.modelVotes[i]
        votes[res] += self.countList[i]

    totVotes = sum(votes)
    if not totVotes:
        # nothing to base a decision on; avoid dividing by zero
        return -1, 0.0
    res = numpy.argmax(votes)
    conf = float(votes[res]) / float(totVotes)
    if conf > threshold:
        return res, conf
    return -1, conf
321
323 """ returns the votes from the last classification
324
325 This will be _None_ if nothing has yet be classified
326 """
327 return self.modelVotes
328
360
366
407
def Grow(self, examples, attrs, nPossibleVals, buildDriver, pruner=None,
         nTries=10, pruneIt=0,
         needsQuantization=1, progressCallback=None,
         **buildArgs):
    """ Grows the composite

      **Arguments**

        - examples: a list of examples to be used in training

        - attrs: a list of the variables to be used in training

        - nPossibleVals: this is used to provide a list of the number
          of possible values for each variable.  It is used if the
          local quantBounds have not been set (for example for when you
          are working with data which is already quantized).

        - buildDriver: the function to call to build the new models.  It is
          called as _buildDriver(trainSet, attrs, nPossibleVals, **buildArgs)_
          and must return a (model, error) 2-tuple.

        - pruner: a function used to "prune" (reduce the complexity of)
          the resulting model.  Must return a (model, error) 2-tuple.

        - nTries: the number of new models to add

        - pruneIt: toggles whether or not pruning is done

        - needsQuantization: used to indicate whether or not this type of model
          requires quantized data

        - progressCallback: if provided, called with the cycle index after
          each model is added

        - **buildArgs: all other keyword args are passed to _buildDriver_.
          If _silent_ is among them it controls the progress printing done
          here; the builder itself is always called with silent=1 and
          calcTotalError=1.

      **Note**

        - new models are *added* to the existing ones

        - when activity quantization bounds are set, the last entry of
          _nPossibleVals_ is overwritten in the caller's list, and (if no
          input remapping is active) the caller's _examples_ list has its
          entries replaced by activity-quantized copies.

    """
    silent = buildArgs.get('silent', 0)
    buildArgs['silent'] = 1
    buildArgs['calcTotalError'] = 1

    if self._mapOrder is not None:
        # materialize: examples is indexed and assigned to below, which a
        # Python 3 map object does not support
        examples = list(map(self._RemapInput, examples))
    if self.GetActivityQuantBounds():
        # NOTE(review): writes into the caller's lists (see docstring);
        # callers may rely on this, so it is left as-is
        for i in range(len(examples)):
            examples[i] = self.QuantizeActivity(examples[i])
        nPossibleVals[-1] = len(self.GetActivityQuantBounds()) + 1
    if self.nPossibleVals is None:
        self.nPossibleVals = nPossibleVals[:]
    if needsQuantization:
        nPossibleVals = self.nPossibleVals
        trainExamples = [self.QuantizeExample(ex, self.quantBounds)
                         for ex in examples]
    else:
        trainExamples = examples

    # composites created before model filtering existed lack the attribute
    useFilter = hasattr(self, '_modelFilterFrac') and self._modelFilterFrac != 0
    # report progress roughly ten times over the run (integer step)
    printEvery = max(nTries // 10, 1)
    for i in range(nTries):
        if useFilter:
            trainIdx, _ = DataUtils.FilterData(trainExamples, self._modelFilterVal,
                                               self._modelFilterFrac, -1,
                                               indicesOnly=1)
            trainSet = [trainExamples[x] for x in trainIdx]
        else:
            trainSet = trainExamples

        model, frac = buildDriver(trainSet, attrs, nPossibleVals, **buildArgs)
        if pruneIt:
            model, frac = pruner(model, model.GetTrainingExamples(),
                                 model.GetTestExamples(),
                                 minimizeTestErrorOnly=0)
        if useFilter and hasattr(model, '_trainIndices'):
            # map indices into the filtered subset back onto the full
            # training set
            model._trainIndices = [trainIdx[x] for x in model._trainIndices]

        self.AddModel(model, frac, needsQuantization)
        if not silent and (nTries < 10 or i % printEvery == 0):
            print('Cycle: % 4d' % (i))
        if progressCallback is not None:
            progressCallback(i)
496
497
499 for i in range(len(self)):
500 m = self.GetModel(i)
501 try:
502 m.ClearExamples()
503 except AttributeError:
504 pass
505
def Pickle(self, fileName='foo.pkl', saveExamples=0):
    """ Writes this composite off to a file so that it can be easily loaded later

      **Arguments**

        - fileName: the name of the file to be written

        - saveExamples: if this is zero, the individual models will have
          their stored examples cleared.  Note that this clears them on the
          composite itself (via _ClearModelExamples()_), not only in the
          written copy.

      **Notes**

        - the composite is written with pickle protocol 1.

    """
    if not saveExamples:
        self.ClearModelExamples()

    # the context manager closes the file even if pickling raises
    with open(fileName, 'wb+') as pFile:
        cPickle.dump(self, pFile, 1)
523
def AddModel(self, model, error, needsQuantization=1):
    """ Adds a model to the composite

      **Arguments**

        - model: the model to be added

        - error: the model's error

        - needsQuantization: a toggle to indicate whether or not this model
          requires quantized inputs

      **NOTE**

        - this can be used as an alternative to _Grow()_ if you already have
          some models constructed

        - a model equal (==) to one already present is not stored again:
          its count is incremented and its error is added to the running
          total.  The needsQuantization flag of the first registration is
          kept.

        - the errList is run as an accumulator,
          you probably want to call _AverageErrors_ after finishing the forest

    """
    # a single list scan; list.index raises ValueError when absent
    try:
        idx = self.modelList.index(model)
    except ValueError:
        self.modelList.append(model)
        self.errList.append(error)
        self.countList.append(1)
        self.quantizationRequirements.append(needsQuantization)
    else:
        self.errList[idx] += error
        self.countList[idx] += 1
562
564 """ convert local summed error to average error
565
566 """
567 self.errList = list(map(lambda x,y:x/y,self.errList,self.countList))
568
570 """ sorts the list of models
571
572 **Arguments**
573
574 sortOnError: toggles sorting on the models' errors rather than their counts
575
576
577 """
578 if sortOnError:
579 order = numpy.argsort(self.errList)
580 else:
581 order = numpy.argsort(self.countList)
582
583
584
585
586 self.modelList = [self.modelList[x] for x in order]
587 self.countList = [self.countList[x] for x in order]
588 self.errList = [self.errList[x] for x in order]
589
590
592 """ returns a particular model
593
594 """
595 return self.modelList[i]
597 """ replaces a particular model
598
599 **Note**
600
601 This is included for the sake of completeness, but you need to be
602 *very* careful when you use it.
603
604 """
605 self.modelList[i] = val
606
608 """ returns the count of the _i_th model
609
610 """
611 return self.countList[i]
613 """ sets the count of the _i_th model
614
615 """
616 self.countList[i] = val
617
619 """ returns the error of the _i_th model
620
621 """
622 return self.errList[i]
624 """ sets the error of the _i_th model
625
626 """
627 self.errList[i] = val
628
630 """ returns all relevant data about a particular model
631
632 **Arguments**
633
634 i: an integer indicating which model should be returned
635
636 **Returns**
637
638 a 3-tuple consisting of:
639
640 1) the model
641
642 2) its count
643
644 3) its error
645 """
646 return (self.modelList[i],self.countList[i],self.errList[i])
648 """ sets all relevant data for a particular tree in the forest
649
650 **Arguments**
651
652 - i: an integer indicating which model should be returned
653
654 - tup: a 3-tuple consisting of:
655
656 1) the model
657
658 2) its count
659
660 3) its error
661
662 **Note**
663
664 This is included for the sake of completeness, but you need to be
665 *very* careful when you use it.
666
667 """
668 self.modelList[i],self.countList[i],self.errList[i] = tup
669
671 """ Returns everything we know
672
673 **Returns**
674
675 a 3-tuple consisting of:
676
677 1) our list of models
678
679 2) our list of model counts
680
681 3) our list of model errors
682
683 """
684 return (self.modelList,self.countList,self.errList)
685
687 """ allows len(composite) to work
688
689 """
690 return len(self.modelList)
691
693 """ allows composite[i] to work, returns the data tuple
694
695 """
696 return self.GetDataTuple(which)
697
699 """ returns a string representation of the composite
700
701 """
702 outStr= 'Composite\n'
703 for i in range(len(self.modelList)):
704 outStr = outStr + \
705 ' Model % 4d: % 5d occurances %%% 5.2f average error\n'%(i,self.countList[i],
706 100.*self.errList[i])
707 return outStr
708
if __name__ == '__main__':
    # Manual smoke demo; the guard below keeps it switched off.
    if False:
        from rdkit.ML.DecTree import DecTree
        comp = Composite()
        node = DecTree.DecTreeNode(None, 'foo')
        # adding the same model twice exercises the duplicate-count path
        for _ in range(2):
            comp.AddModel(node, 0.5)
        comp.AverageErrors()
        comp.SortModels()
        print(comp)

        bounds = [[], [.5, 1, 1.5]]
        samples = [['foo', 0], ['foo', .4], ['foo', .6], ['foo', 1.1], ['foo', 2.0]]
        print('quantBounds:', bounds)
        for sample in samples:
            print(sample, comp.QuantizeExample(sample, bounds))
728