#!/usr/bin/env python
# Current Author: Aaron M. Cohen (cohenaa@ohsu.edu)
# Calculations by Aaron Cohen and William Hersh
# 2.0 Initial version requiring Ur parameter,
# Based on cat_eval by Ravi Teja Bhupatiraju.
"""Evaluate a run (data) file against a gold standard file.

Usage: cat_eval2.py data_file gold_standard_file Urelevant [-tab]

Both inputs are tab-delimited text.  A gold standard line must have a
numeric PMID in its first column; a data line must have a numeric PMID in
its second column.  All other columns must be non-numeric.  A data line is
compared with the gold standard after dropping its first and last columns,
and the run tag is the last column of the data file's first line.

Reported measures: TP/FP/FN counts, precision, recall, F-score, and the
utility  Ur * TP - FP  together with its maximum  Ur * (TP + FN)  and the
normalised ratio of the two.

The code is kept compatible with both Python 2 (2.6+) and Python 3.
"""

import sys
from os.path import exists
import re
import string


def write(msg):
    """Write msg followed by a newline to standard output."""
    # for now we don't use stderr, not even for error messages
    sys.stdout.write(msg + '\n')


def err(msg):
    """Report msg through write() and terminate with exit status 1."""
    write(msg)
    sys.exit(1)


def isnum(num):
    """Return True if int() accepts num (e.g. '123' or ' 42\\n'), else False."""
    try:
        int(num)
    except (TypeError, ValueError, OverflowError):
        return False
    return True


def CheckDataFile(data_file, type):
    """Perform basic syntax checks on an input file; exit on the first problem.

    data_file -- path of the tab-delimited file to check
    type      -- 'gold' (PMID expected in the first column) or 'data'
                 (PMID expected in the second column)

    Every column other than the PMID must be non-numeric.  Line numbers in
    the error messages are 1-based.  Problems are reported through err(),
    which exits the program.
    """
    pmid_columns = {'gold': 0, 'data': 1}
    hints = {
        'gold': 'Are you sure this is a gold standard file?',
        'data': 'Are you sure this is a data file?',
    }
    if type not in pmid_columns:
        err('Invalid Type')
    pop_index = pmid_columns[type]

    with open(data_file) as f:
        for line_no, line in enumerate(f, 1):
            items = line.split('\t')
            # A line too short to hold the PMID column is reported the same
            # way as a non-numeric PMID (instead of crashing with IndexError).
            if len(items) <= pop_index or not isnum(items.pop(pop_index)):
                # err() exits, so the hint must be part of the same message.
                err('Missing PMID in ' + data_file + ' at ' + str(line_no)
                    + '\n' + hints[type])
            for item in items:
                if isnum(item):
                    err('Numeric value in a non-numeric location '
                        + str(line_no) + ' in ' + data_file)


def GetLines(file, slice=False):
    """Return the distinct lines of a tab-delimited file as a set of strings.

    file  -- path of the file to read
    slice -- False to keep whole lines, or a (start, stop) pair: each line
             is split on tabs, only fields [start:stop] are kept, and they
             are re-joined with tabs

    Trailing line terminators ('\\n', and '\\r' from CRLF files) are removed.
    A set is returned so that callers can use set operators; since that
    collapses duplicates, a warning is written when the file has any.
    """
    items = []
    with open(file) as f:
        for line in f:
            # split on tab characters only
            fields = line.rstrip('\r\n').split('\t')
            if slice:
                fields = fields[slice[0]:slice[1]]
            items.append('\t'.join(fields))
    unique = set(items)
    if len(items) != len(unique):
        write('The file %s contains redundant elements' % file)
    return unique


def GetGoldStandard(gold_standard_file):
    """Return the set of whole lines of the gold standard file."""
    return GetLines(gold_standard_file)


def GetRetrieved(retrieved_file):
    """Return the set of data-file lines with their first and last columns removed."""
    # ignore the first and last elements
    return GetLines(retrieved_file, (1, -1))


def GetTestSet(crosswalk_file):
    """Return the set of second-column values of a tab-delimited file.

    Not used by this script's main program.
    """
    with open(crosswalk_file) as f:
        return set(line.rstrip('\r\n').split('\t')[1] for line in f)


def _format_utility(value):
    """Format a utility value: as an integer when integral, else with 4 decimals."""
    # Ur is a float, so utilities can be fractional; '%d' alone would truncate.
    if value.is_integer():
        return '%d' % value
    return '%0.4f' % value


def ShowOutput(gold_standard_file, retrieved_file, csv, Ur):
    """Compute evaluation measures for a run and write them to stdout.

    gold_standard_file -- gold standard file (read with GetGoldStandard)
    retrieved_file     -- data/run file (read with GetRetrieved); the run tag
                          is the last tab field of its first line
    csv                -- if true, write a header row and a value row; despite
                          the name, the output is tab-delimited
    Ur                 -- utility factor: the credit for each true positive,
                          while each false positive costs 1

    When nothing was retrieved or the gold standard is empty, only the raw
    counts are written, prefixed by "Not computable".
    """
    gold = GetGoldStandard(gold_standard_file)
    retrieved = GetRetrieved(retrieved_file)
    with open(retrieved_file) as f:
        tag = f.readline().rstrip('\r\n').split('\t')[-1]

    ap = len(gold)                      # all positives (gold standard size)
    tp = len(retrieved & gold)
    fp = len(retrieved - gold)
    # tp == 0 is allowed; it simply yields zero precision/recall/F-score.
    fn = ap - tp
    ufactor = Ur

    counts = "tp=%d; fp=%d; fn=%d;" % (tp, fp, fn)
    if tp + fp == 0 or ap == 0:
        write("Not computable: %s" % counts)
        return

    precision = float(tp) / (tp + fp)
    recall = float(tp) / ap
    if tp == 0:
        fscore = 0.0
    else:
        fscore = 2.0 * precision * recall / (precision + recall)
    uraw = float(ufactor * tp - fp)
    # best possible utility: every gold item retrieved, no false positives
    umax = float(ufactor * ap)
    unorm = uraw / umax

    if csv:
        header = ['Run', 'TP', 'FP', 'FN', 'Precision', 'Recall', 'F-Score',
                  'Utility Factor', 'Raw Utility', 'Max Utility',
                  'Normalized Utility']
        row = [tag, str(tp), str(fp), str(fn),
               "%0.4f" % precision, "%0.4f" % recall, "%0.4f" % fscore,
               str(ufactor), str(uraw), str(umax), "%0.4f" % unorm]
        write('\t'.join(header))
        write('\t'.join(row))
    else:
        write("Run: %s" % tag)
        write("Counts: %s" % counts)
        write("Precision: %0.4f" % precision)
        write("Recall: %0.4f" % recall)
        write("F-score: %0.4f" % fscore)
        write("Utility Factor: %0.2f" % ufactor)
        write("Raw Utility: %s" % _format_utility(uraw))
        write("Max Utility: %s" % _format_utility(umax))
        write("Normalized Utility: %0.4f" % unorm)


# ----------------------------- Main -----------------------------

usage = 'Usage: cat_eval2.py data_file gold_standard_file Urelevant [-tab]'


def main(argv=None):
    """Command-line entry point.

    argv -- argument vector (defaults to sys.argv):
            [prog, data_file, gold_standard_file, Urelevant, optional '-tab']

    Validates the arguments and input files, then writes the evaluation via
    ShowOutput.  Any problem is reported through err(), which exits.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) not in (4, 5):
        err('Missing required arguments.\n%s' % usage)

    # grab parameters...
    data_file, gold_standard_file = argv[1:3]
    try:
        Ur = float(argv[3])
    except ValueError:
        err('Third argument, Urelevant, must be a float.\n%s' % usage)
    # Max utility is Ur * |gold|, so Ur must be positive for normalisation
    # (this also rejects NaN, which fails every comparison).
    if not Ur > 0:
        err('Third argument, Urelevant, must be a positive float.\n%s' % usage)

    if not exists(data_file):
        err('Cannot read data file.')
    if not exists(gold_standard_file):
        err('Cannot read gold standard file.')

    CheckDataFile(data_file, 'data')
    CheckDataFile(gold_standard_file, 'gold')

    # process special options...
    csv = False
    for option in argv[4:]:
        if option == '-tab':
            csv = True
        else:
            err('Unrecognized option %s.\n%s' % (option, usage))

    # compute and produce output...
    ShowOutput(gold_standard_file, data_file, csv, Ur)


if __name__ == '__main__':
    main()