Package rdkit :: Package ML :: Package ModelPackage :: Module PackageUtils
[hide private]
[frames] | no frames]

Source Code for Module rdkit.ML.ModelPackage.PackageUtils

  1  # 
  2  # Copyright (C) 2003 Rational Discovery LLC 
  3  # All rights are reserved. 
  4  # 
  5  from __future__ import print_function 
  6  from elementtree.ElementTree import ElementTree,Element,SubElement 
  7  import time 
  8   
  9   
def _ConvertModelPerformance(perf, modelPerf):
    """Append the statistics in modelPerf to perf as child XML elements.

    Arguments:
      - perf: the XML element that receives the new children
      - modelPerf: a sequence of the form
          (accuracy, avgCorrectConf, avgIncorrectConf, confusionMatrix,
           thresh, avgSkipConf, nSkipped)
        Only the first three entries are required.  Each optional entry is
        written only if the sequence is long enough to contain it.

    The children are added in this order (optional ones in brackets):
      [ScreenThreshold], NumScreened, [NumSkipped], Accuracy,
      AvgCorrectConf, AvgIncorrectConf, [AvgSkipConf], [ConfusionMatrix]

    NumScreened is the sum of all confusion matrix entries, or 'N/A' when no
    confusion matrix is supplied.  Nothing is returned; perf is modified in
    place.
    """
    nFields = len(modelPerf)

    def _addText(tag, value):
        # every performance entry is a leaf element holding a string
        SubElement(perf, tag).text = str(value)

    if nFields > 3:
        confMat = modelPerf[3]
        numScreened = sum(entry for row in confMat for entry in row)
    else:
        confMat = None
        numScreened = 'N/A'

    if nFields > 4:
        _addText("ScreenThreshold", modelPerf[4])
    _addText("NumScreened", numScreened)
    # nSkipped lives at index 6, so a length check of >4 is not sufficient:
    # 5- and 6-element tuples would raise an IndexError here.
    if nFields > 6:
        _addText("NumSkipped", modelPerf[6])
    _addText("Accuracy", modelPerf[0])
    _addText("AvgCorrectConf", modelPerf[1])
    _addText("AvgIncorrectConf", modelPerf[2])
    # avgSkipConf lives at index 5 and needs its own guard for the same reason
    if nFields > 5:
        _addText("AvgSkipConf", modelPerf[5])
    # Test the length instead of the truth value so that array-like matrices
    # whose truth value is ambiguous (e.g. numpy arrays) are accepted; an
    # empty matrix is still left out.
    if confMat is not None and len(confMat) > 0:
        _addText("ConfusionMatrix", confMat)
42
def PackageToXml(pkg, summary="N/A", trainingDataId='N/A',
                 dataPerformance=(),
                 recommendedThreshold=None,
                 classDescriptions=(),
                 modelType=None,
                 modelOrganism=None):
    """ generates XML for a package that follows the RD_Model.dtd

    Arguments:
      - pkg: the model package; its GetNotes() supplies the model name
        ("Unnamed model" is used when the notes are empty) and its
        GetCalculator() supplies the descriptor names, summaries and
        functions
      - summary: text for the ModelSummary element
      - trainingDataId: text for the TrainingDataId element
      - dataPerformance: if provided, should be a sequence of 2-tuples:
          ( note, performance )
        where performance is of the form:
          ( accuracy, avgCorrectConf, avgIncorrectConf, confusionMatrix,
            thresh, avgSkipConf, nSkipped )
        the last four elements are optional
      - recommendedThreshold: if not None, written (via str()) as the
        RecommendedThreshold element
      - classDescriptions: optional sequence of ( value, text ) 2-tuples
      - modelType: optional text for the ModelType element
      - modelOrganism: optional text for the ModelOrganism element

    Returns an ElementTree rooted at an RDModelInfo element.  A ModelHistory
    entry stamped with the current local date and the note "Created" is
    always appended.

    Raises ValueError if a descriptor function has a `version` attribute that
    is not made up of exactly three dot-separated parts.

    """
    def _addText(parent, tag, text):
        # create a leaf element holding text and hand it back
        node = SubElement(parent, tag)
        node.text = text
        return node

    head = Element("RDModelInfo")
    _addText(head, "ModelName", pkg.GetNotes() or "Unnamed model")
    _addText(head, "ModelSummary", summary)

    calc = pkg.GetCalculator()
    descrs = SubElement(head, "ModelDescriptors")
    descrInfo = zip(calc.GetDescriptorNames(),
                    calc.GetDescriptorSummaries(),
                    calc.GetDescriptorFuncs())
    # distinct loop names keep the `summary` argument from being rebound
    for descrName, descrSummary, func in descrInfo:
        descr = SubElement(descrs, "Descriptor")
        _addText(descr, "DescriptorName", descrName)
        _addText(descr, "DescriptorDetail", descrSummary)
        if hasattr(func, 'version'):
            vers = SubElement(descr, "DescriptorVersion")
            major, minor, patch = func.version.split('.')
            _addText(vers, "VersionMajor", major)
            _addText(vers, "VersionMinor", minor)
            _addText(vers, "VersionPatch", patch)

    _addText(head, "TrainingDataId", trainingDataId)

    for description, perfData in dataPerformance or ():
        dataNode = SubElement(head, "ValidationData")
        _addText(dataNode, 'ScreenNote', description)
        perf = SubElement(dataNode, "PerformanceData")
        _ConvertModelPerformance(perf, perfData)

    # Compare against None explicitly: a threshold of 0 is a real value and
    # must not be dropped just because it is falsy.
    if recommendedThreshold is not None:
        _addText(head, "RecommendedThreshold", str(recommendedThreshold))

    if classDescriptions:
        classNode = SubElement(head, "ClassDescriptions")
        for val, text in classDescriptions:
            descr = SubElement(classNode, 'ClassDescription')
            _addText(descr, 'ClassVal', str(val))
            _addText(descr, 'ClassText', str(text))

    if modelType:
        _addText(head, "ModelType", modelType)
    if modelOrganism:
        _addText(head, "ModelOrganism", modelOrganism)

    hist = SubElement(head, "ModelHistory")
    revision = SubElement(hist, "Revision")
    year, month, day = time.localtime()[:3]
    date = SubElement(revision, "RevisionDate")
    _addText(date, "Year", str(year))
    _addText(date, "Month", str(month))
    _addText(date, "Day", str(day))
    _addText(revision, "RevisionNote", "Created")

    return ElementTree(head)
# Command-line demo: load a pickled model package from the file named by the
# first argument and print its XML description (with a DOCTYPE header) to
# stdout, using a fixed example performance tuple as the validation data.
if __name__ == '__main__':
    import io
    import sys
    from rdkit.six.moves import cPickle

    # NOTE(review): unpickling can execute arbitrary code; only run this on
    # package files that come from a trusted source.
    with open(sys.argv[1], 'rb') as inF:
        pkg = cPickle.load(inF)
    perf = (.80, .95, .70, [[4, 1], [1, 4]])
    tree = PackageToXml(pkg, dataPerformance=[('training data performance', perf)])
    # ElementTree.write() emits encoded (us-ascii) bytes by default, so they
    # are collected in a bytes buffer and decoded; cStringIO, used
    # previously, only exists on Python 2.
    buf = io.BytesIO()
    tree.write(buf)
    txt = buf.getvalue().decode('ascii')
    header = """<?xml version="1.0"?>
<!DOCTYPE RDModelInfo PUBLIC "-//RD//DTD RDModelInfo //EN" "RD_Model.dtd">
"""
    print(header)
    # put every element on its own line to make the output readable
    print(txt.replace('><', '>\n<'))