""" Evaluate NLU models on specified dataset Metric: dataset level Precision/Recall/F1 Usage: python evaluate.py [MultiWOZ] [SCLSTM|TemplateNLG] [usr|sys] """ import json import os import random import sys import itertools import zipfile import numpy from numpy.lib.shape_base import _put_along_axis_dispatcher from numpy.lib.twodim_base import triu_indices_from import torch from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction from pprint import pprint from tqdm import tqdm def slot_error(dialog_acts, utts): halucination = [] halucinate = 0 missing = 0 total = 0 for acts,utt in zip(dialog_acts, utts): for act in acts: tmp_act = [x.lower() for x in act] tmp_utt = utt.lower() i, d, s, v = tmp_act if i == 'inform': total = total + 1 if not (v in tmp_utt): missing = missing + 1 return missing, total def fine_SER(dialog_acts, utts): path = os.path.dirname(os.path.abspath(__file__)) path = os.path.join(path, 'template', 'multiwoz', 'label_maps.json') with open(path, 'r') as mapping_file: mappings = json.load(mapping_file) mapping_file.close() path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) path = os.path.join(path, 'data', 'multiwoz', 'ontology_nlg_eval.json') with open(path, 'r') as entity_file: possible_entity = json.load(entity_file) entity_file.close() entity_list = [] for key in possible_entity.keys(): entity_list = entity_list + possible_entity[key] hallucinate = 0 missing = 0 total = 0 unk_token_count = 0 missing_dialogs = [] hallucination_dialogs = [] slot_span = [] domain_span = [] for acts,utt in zip(dialog_acts, utts): hallucination_flag = False tmp_utt = utt.lower() origin_utt = utt.lower() legal_act_flag = False for act in acts: missing_fact = None missing_flag = False tmp_act = [x.lower() for x in act] i, d, s, v = tmp_act if not(d in domain_span): domain_span.append(d) if not(s in slot_span): slot_span.append(s) #intializing all possible span keyword if i in ['inform', 'recommend', 'offerbook', 'offerbooked','book','select']: legal_act_flag = True total = total + 1 if not (v in origin_utt) and v!='none': exist_flag = False try: synoyms = mappings[v] for item in synoyms: if item in origin_utt: exist_flag = True tmp_utt = tmp_utt.replace(item,'') tmp_utt = tmp_utt.replace(s,'') #remove span for hallucination detection except: pass if i in ['offerbook', 'offerbooked'] and v =='none': if 'book' in origin_utt: exist_flag = True tmp_utt = tmp_utt.replace('book','') if i in ['inform','recommend'] and v=='none': if d in origin_utt: exist_flag = True tmp_utt = tmp_utt.replace(d, '') if exist_flag == False: missing_flag = True missing_fact = v else: tmp_utt = tmp_utt.replace(v,'') tmp_utt = tmp_utt.replace(s,'') if s in origin_utt: missing_flag = False if s =='booking' and ('book' in origin_utt or 'reserv' in origin_utt): missing_flag = False elif i == 'request': legal_act_flag = True total = total + 1 if s=='depart' or s=='dest' or s=='area': if not ('where' in origin_utt): if s in origin_utt: tmp_utt = tmp_utt.replace(s,'') else: missing_flag = True missing_fact = s elif s=='leave' or s=='arrive': if (not 'when' in origin_utt): if not ('what' in origin_utt and 'time' in origin_utt): missing_flag = True missing_fact = s else: tmp_utt.replace('time', '') else: tmp_utt = tmp_utt.replace(s,'') tmp_utt = tmp_utt.replace(d,'') if s in origin_utt: missing_flag = False if s =='booking' and ('book' in origin_utt or 'reserv' in origin_utt): missing_flag = False try: tmp_utt = tmp_utt.replace(d,'') tmp_utt = tmp_utt.replace(s,'') if 'arrive' in s or 'leave' in s: tmp_utt = tmp_utt.replace('time', '') except: pass if missing_flag == True: missing = missing + 1 missing_dialogs.append(missing_fact) missing_dialogs.append(acts) missing_dialogs.append(utt) for keyword in slot_span + entity_list: if keyword in tmp_utt and len(keyword) >= 4 and legal_act_flag == True: hallucination_flag = True hallucinate = hallucinate + 1 hallucination_dialogs.append(keyword) hallucination_dialogs.append(acts) hallucination_dialogs.append(tmp_utt) hallucination_dialogs.append(utt) break return missing, hallucinate, total, hallucination_dialogs, missing_dialogs def get_bleu4(dialog_acts, golden_utts, gen_utts): das2utts = {} for das, utt, gen in zip(dialog_acts, golden_utts, gen_utts): utt = utt.lower() gen = gen.lower() for da in das: act, domain, s, v = da if act == 'Request' or domain == 'general': continue else: if s == 'Internet' or s == 'Parking' or s == 'none' or v == 'none': continue else: v = v.lower() if (' ' + v in utt) or (v + ' ' in utt): utt = utt.replace(v, '{}-{}'.format(act + '-' + domain, s), 1) if (' ' + v in gen) or (v + ' ' in gen): gen = gen.replace(v, '{}-{}'.format(act + '-' + domain, s), 1) hash_key = '' for da in sorted(das, key=lambda x: x[0] + x[1] + x[2]): hash_key += '-'.join(da[:-1]) + ';' das2utts.setdefault(hash_key, {'refs': [], 'gens': []}) das2utts[hash_key]['refs'].append(utt) das2utts[hash_key]['gens'].append(gen) # pprint(das2utts) refs, gens = [], [] for das in das2utts.keys(): for gen in das2utts[das]['gens']: refs.append([s.split() for s in das2utts[das]['refs']]) gens.append(gen.split()) bleu = corpus_bleu(refs, gens, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=SmoothingFunction().method1) return bleu if __name__ == '__main__': seed = 2020 random.seed(seed) numpy.random.seed(seed) torch.manual_seed(seed) if len(sys.argv) < 4: print("usage:") print("\t python evaluate.py dataset model role") print("\t dataset=MultiWOZ, CrossWOZ, or Camrest") print("\t model=SCLSTM, SCLSTM_NoUNK, SCGPT or TemplateNLG") print("\t role=usr/sys") print("\t [Optional] model_file") sys.exit() dataset_name = sys.argv[1] model_name = sys.argv[2] role = sys.argv[3] model_file = sys.argv[4] if len(sys.argv) >= 5 else None if dataset_name == 'MultiWOZ': if model_name == 'SCLSTM': from convlab2.nlg.sclstm.multiwoz import SCLSTM if role == 'usr': model = SCLSTM(is_user=True, use_cuda=True, unk_suppress=False) elif role == 'sys': model = SCLSTM(is_user=False, use_cuda=True, unk_suppress=False) elif model_name == 'SCLSTM_NoUNK': from convlab2.nlg.sclstm.multiwoz import SCLSTM if role == 'usr': model = SCLSTM(is_user=True, use_cuda=True, unk_suppress=True) elif role == 'sys': model = SCLSTM(is_user=False, use_cuda=True, unk_suppress=True) elif model_name == 'TemplateNLG': from convlab2.nlg.template.multiwoz import TemplateNLG if role == 'usr': model = TemplateNLG(is_user=True) elif role == 'sys': model = TemplateNLG(is_user=False) elif model_name == 'SCGPT': from convlab2.nlg.scgpt.multiwoz import SCGPT if model_file is not None: print(f"load model at {model_file}") if role == 'usr': model = SCGPT(model_file, is_user=True) elif role == 'sys': model = SCGPT(model_file, is_user=False) else: raise Exception("Available models: SCLSTM, SCGPT, TEMPLATE") from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader dataloader = SingleTurnNLGDataloader(dataset_dataloader=MultiWOZDataloader()) data = dataloader.load_data(data_key='all', role=role, session_id=True)['test'] dialog_acts = [] golden_utts = [] gen_utts = [] gen_slots = [] sen_num = 0 # sys.stdout = open(sys.argv[2] + '-' + sys.argv[3] + '-' + 'evaluate_logs_neo.txt','w') assert 'utterance' in data and 'dialog_act' in data and 'session_id' in data assert len(data['utterance']) == len(data['dialog_act']) == len(data['session_id']) # Turns during the same session should be contiguous, so we can call init_session at the first turn of a new session. # This is necessary for SCGPT, but unnecessary for SCLSTM and TemplateNLG. is_first_turn = [] for _, iterator in itertools.groupby(data['session_id']): is_first_turn.append(True) next(iterator) is_first_turn.extend(False for _ in iterator) for i in tqdm(range(len(data['utterance']))): if is_first_turn[i]: model.init_session() dialog_acts.append(data['dialog_act'][i]) golden_utts.append(data['utterance'][i]) gen_utts.append(model.generate(data['dialog_act'][i])) # print(dialog_acts[-1]) # print(golden_utts[-1]) # print(gen_utts[-1]) print("Calculate SER for golden responses") missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(dialog_acts, golden_utts) print("Golden response Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(missing, total, hallucinate, missing/total)) print("Calculate SER") missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(dialog_acts, gen_utts) # with open('{}-{}-genutt_neo.txt'.format(sys.argv[2], sys.argv[3]), mode='wt', encoding='utf-8') as gen_diag: # for x in gen_utts: # gen_diag.writelines(str(x)+'\n') # with open('{}-{}-hallucinate_neo.txt'.format(sys.argv[2], sys.argv[3]), mode='wt', encoding='utf-8') as hal_diag: # for x in hallucination_dialogs: # hal_diag.writelines(str(x)+'\n') # with open('{}-{}-missing_neo.txt'.format(sys.argv[2], sys.argv[3]), mode='wt', encoding='utf-8') as miss_diag: # for x in missing_dialogs: # miss_diag.writelines(str(x)+'\n') print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(sys.argv[2], missing, total, hallucinate, missing/total)) print("Calculate bleu-4") bleu4 = get_bleu4(dialog_acts, golden_utts, gen_utts) print("BLEU-4: %.4f" % bleu4) print('Model on {} sentences role={}'.format(len(data['utterance']), role)) # sys.stdout.close() else: raise Exception("currently supported dataset: MultiWOZ")