# NOTE: the following lines were web-page residue accidentally captured with the
# source ("Select Git revision" / "BOperation.java" / "Code owners" /
# approver note / "evaluate.py 12.46 KiB"); kept here as a comment so the
# module remains importable.
"""
Evaluate NLU models on specified dataset
Metric: dataset level Precision/Recall/F1
Usage: python evaluate.py [MultiWOZ] [SCLSTM|SCLSTM_NoUNK|SCGPT|TemplateNLG] [usr|sys] [model_file]
"""
import json
import os
import random
import sys
import itertools
import zipfile
import numpy
from numpy.lib.shape_base import _put_along_axis_dispatcher
from numpy.lib.twodim_base import triu_indices_from
import torch
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from pprint import pprint
from tqdm import tqdm
def slot_error(dialog_acts, utts):
    """Compute a coarse Slot Error Rate numerator/denominator.

    Over all 'Inform' acts, counts how many act values do not occur
    verbatim (case-insensitively) in the paired utterance.

    Args:
        dialog_acts: list of act lists; each act is a 4-item sequence
            [intent, domain, slot, value].
        utts: list of utterance strings, aligned with ``dialog_acts``.

    Returns:
        (missing, total): number of inform values absent from their
        utterance, and total number of inform acts checked.
    """
    missing = 0
    total = 0
    for acts, utt in zip(dialog_acts, utts):
        # Lower-case once per utterance; matching is case-insensitive.
        utt_lower = utt.lower()
        for act in acts:
            intent, _domain, _slot, value = (x.lower() for x in act)
            if intent == 'inform':
                total += 1
                if value not in utt_lower:
                    missing += 1
    return missing, total
def fine_SER(dialog_acts, utts):
    """Compute fine-grained Slot Error Rate (SER) statistics.

    For each (dialog-act list, utterance) pair, counts how many informable
    act values are missing from the utterance (consulting a synonym map for
    paraphrases) and how many slot/entity keywords appear in the utterance
    without being licensed by any act (hallucinations).

    Args:
        dialog_acts: list of act lists; each act is a 4-item sequence
            [intent, domain, slot, value].
        utts: list of utterance strings, aligned with ``dialog_acts``.

    Returns:
        (missing, hallucinate, total, hallucination_dialogs, missing_dialogs)
        where ``total`` is the number of checked acts and the two dialog
        lists hold flat debugging records (keyword/acts/utterance entries).
    """
    # Value synonym map (e.g. value -> list of acceptable paraphrases)
    # shipped next to this module.
    path = os.path.dirname(os.path.abspath(__file__))
    path = os.path.join(path, 'template', 'multiwoz', 'label_maps.json')
    with open(path, 'r') as mapping_file:
        mappings = json.load(mapping_file)
    # Ontology of all possible entity values, used for hallucination scanning.
    path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    path = os.path.join(path, 'data', 'multiwoz', 'ontology_nlg_eval.json')
    with open(path, 'r') as entity_file:
        possible_entity = json.load(entity_file)
    entity_list = []
    for key in possible_entity.keys():
        entity_list = entity_list + possible_entity[key]
    hallucinate = 0
    missing = 0
    total = 0
    missing_dialogs = []
    hallucination_dialogs = []
    slot_span = []
    domain_span = []
    for acts, utt in zip(dialog_acts, utts):
        hallucination_flag = False
        tmp_utt = utt.lower()     # progressively stripped of licensed spans
        origin_utt = utt.lower()  # kept intact for membership tests
        legal_act_flag = False    # True once any checkable act is seen
        for act in acts:
            missing_fact = None
            missing_flag = False
            tmp_act = [x.lower() for x in act]
            i, d, s, v = tmp_act
            # Accumulate every observed domain/slot; slot_span later feeds
            # the hallucination keyword scan.
            if not(d in domain_span):
                domain_span.append(d)
            if not(s in slot_span):
                slot_span.append(s)
            if i in ['inform', 'recommend', 'offerbook', 'offerbooked', 'book', 'select']:
                legal_act_flag = True
                total = total + 1
                if not (v in origin_utt) and v != 'none':
                    exist_flag = False
                    # Exact value absent; accept any known synonym instead.
                    try:
                        synoyms = mappings[v]
                        for item in synoyms:
                            if item in origin_utt:
                                exist_flag = True
                                # Remove matched spans so they are not later
                                # counted as hallucinations.
                                tmp_utt = tmp_utt.replace(item, '')
                                tmp_utt = tmp_utt.replace(s, '')
                    except KeyError:
                        # Value has no synonym entry in the map.
                        pass
                    # Valueless booking acts are satisfied by the word "book".
                    if i in ['offerbook', 'offerbooked'] and v == 'none':
                        if 'book' in origin_utt:
                            exist_flag = True
                            tmp_utt = tmp_utt.replace('book', '')
                    # Valueless inform/recommend are satisfied by the domain name.
                    if i in ['inform', 'recommend'] and v == 'none':
                        if d in origin_utt:
                            exist_flag = True
                            tmp_utt = tmp_utt.replace(d, '')
                    if exist_flag == False:
                        missing_flag = True
                        missing_fact = v
                else:
                    # Value present (or 'none'): strip it from the
                    # hallucination view of the utterance.
                    tmp_utt = tmp_utt.replace(v, '')
                    tmp_utt = tmp_utt.replace(s, '')
                # Mentioning the slot itself also realizes the act.
                if s in origin_utt:
                    missing_flag = False
                if s == 'booking' and ('book' in origin_utt or 'reserv' in origin_utt):
                    missing_flag = False
            elif i == 'request':
                legal_act_flag = True
                total = total + 1
                # Location slots may be requested with "where".
                if s == 'depart' or s == 'dest' or s == 'area':
                    if not ('where' in origin_utt):
                        if s in origin_utt:
                            tmp_utt = tmp_utt.replace(s, '')
                        else:
                            missing_flag = True
                            missing_fact = s
                # Time slots may be requested with "when" or "what ... time".
                elif s == 'leave' or s == 'arrive':
                    if (not 'when' in origin_utt):
                        if not ('what' in origin_utt and 'time' in origin_utt):
                            missing_flag = True
                            missing_fact = s
                        else:
                            # BUGFIX: the replace() result was previously
                            # discarded, leaving "time" to trigger spurious
                            # hallucination hits below.
                            tmp_utt = tmp_utt.replace('time', '')
                else:
                    tmp_utt = tmp_utt.replace(s, '')
                    tmp_utt = tmp_utt.replace(d, '')
                    if s in origin_utt:
                        missing_flag = False
                    if s == 'booking' and ('book' in origin_utt or 'reserv' in origin_utt):
                        missing_flag = False
            # Strip this act's own domain/slot words before the keyword scan.
            try:
                tmp_utt = tmp_utt.replace(d, '')
                tmp_utt = tmp_utt.replace(s, '')
                if 'arrive' in s or 'leave' in s:
                    tmp_utt = tmp_utt.replace('time', '')
            except Exception:
                pass
            if missing_flag == True:
                missing = missing + 1
                missing_dialogs.append(missing_fact)
                missing_dialogs.append(acts)
                missing_dialogs.append(utt)
        # Any leftover long keyword in the stripped utterance counts as one
        # hallucination for this utterance (first hit only, hence the break).
        for keyword in slot_span + entity_list:
            if keyword in tmp_utt and len(keyword) >= 4 and legal_act_flag == True:
                hallucination_flag = True
                hallucinate = hallucinate + 1
                hallucination_dialogs.append(keyword)
                hallucination_dialogs.append(acts)
                hallucination_dialogs.append(tmp_utt)
                hallucination_dialogs.append(utt)
                break
    return missing, hallucinate, total, hallucination_dialogs, missing_dialogs
def get_bleu4(dialog_acts, golden_utts, gen_utts):
    """Corpus-level BLEU-4 between generated and golden utterances.

    Utterances are first delexicalized (contentful act values replaced by
    an intent-domain-slot placeholder), then grouped by their dialog-act
    signature so every golden utterance sharing the same acts serves as a
    reference for each generated sentence with those acts.
    """
    grouped = {}
    for das, ref_utt, gen_utt in zip(dialog_acts, golden_utts, gen_utts):
        ref_utt = ref_utt.lower()
        gen_utt = gen_utt.lower()
        for da in das:
            intent, domain, slot, value = da
            # Skip requests, general-domain acts, and valueless/boolean slots:
            # only contentful inform-style values get delexicalized.
            if intent == 'Request' or domain == 'general':
                continue
            if slot in ('Internet', 'Parking', 'none') or value == 'none':
                continue
            value = value.lower()
            placeholder = '{}-{}'.format(intent + '-' + domain, slot)
            # Replace only word-boundary-ish occurrences, first match only.
            if (' ' + value in ref_utt) or (value + ' ' in ref_utt):
                ref_utt = ref_utt.replace(value, placeholder, 1)
            if (' ' + value in gen_utt) or (value + ' ' in gen_utt):
                gen_utt = gen_utt.replace(value, placeholder, 1)
        # Order-independent signature of the act set (value excluded).
        signature = ''
        for da in sorted(das, key=lambda x: x[0] + x[1] + x[2]):
            signature += '-'.join(da[:-1]) + ';'
        bucket = grouped.setdefault(signature, {'refs': [], 'gens': []})
        bucket['refs'].append(ref_utt)
        bucket['gens'].append(gen_utt)
    # pprint(grouped)
    references, hypotheses = [], []
    for bucket in grouped.values():
        tokenized_refs = [ref.split() for ref in bucket['refs']]
        for gen_utt in bucket['gens']:
            references.append(tokenized_refs)
            hypotheses.append(gen_utt.split())
    return corpus_bleu(references, hypotheses,
                       weights=(0.25, 0.25, 0.25, 0.25),
                       smoothing_function=SmoothingFunction().method1)
if __name__ == '__main__':
    # Seed every RNG used downstream so generation is reproducible.
    seed = 2020
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    if len(sys.argv) < 4:
        print("usage:")
        print("\t python evaluate.py dataset model role")
        print("\t dataset=MultiWOZ, CrossWOZ, or Camrest")
        print("\t model=SCLSTM, SCLSTM_NoUNK, SCGPT or TemplateNLG")
        print("\t role=usr/sys")
        print("\t [Optional] model_file")
        sys.exit()
    dataset_name = sys.argv[1]
    model_name = sys.argv[2]
    role = sys.argv[3]
    # Optional 4th argument: checkpoint path (only consumed by SCGPT below).
    model_file = sys.argv[4] if len(sys.argv) >= 5 else None
    if dataset_name == 'MultiWOZ':
        # Instantiate the requested NLG model for the given speaker role.
        # NOTE(review): a role other than 'usr'/'sys' leaves `model` unbound
        # here and only fails later with NameError — consider validating
        # `role` up front; confirm callers always pass a valid role.
        if model_name == 'SCLSTM':
            from convlab2.nlg.sclstm.multiwoz import SCLSTM
            if role == 'usr':
                model = SCLSTM(is_user=True, use_cuda=True, unk_suppress=False)
            elif role == 'sys':
                model = SCLSTM(is_user=False, use_cuda=True, unk_suppress=False)
        elif model_name == 'SCLSTM_NoUNK':
            from convlab2.nlg.sclstm.multiwoz import SCLSTM
            if role == 'usr':
                model = SCLSTM(is_user=True, use_cuda=True, unk_suppress=True)
            elif role == 'sys':
                model = SCLSTM(is_user=False, use_cuda=True, unk_suppress=True)
        elif model_name == 'TemplateNLG':
            from convlab2.nlg.template.multiwoz import TemplateNLG
            if role == 'usr':
                model = TemplateNLG(is_user=True)
            elif role == 'sys':
                model = TemplateNLG(is_user=False)
        elif model_name == 'SCGPT':
            from convlab2.nlg.scgpt.multiwoz import SCGPT
            if model_file is not None:
                print(f"load model at {model_file}")
            if role == 'usr':
                model = SCGPT(model_file, is_user=True)
            elif role == 'sys':
                model = SCGPT(model_file, is_user=False)
        else:
            raise Exception("Available models: SCLSTM, SCGPT, TEMPLATE")
        from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader
        from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
        dataloader = SingleTurnNLGDataloader(dataset_dataloader=MultiWOZDataloader())
        # Evaluate on the test split only.
        data = dataloader.load_data(data_key='all', role=role, session_id=True)['test']
        dialog_acts = []
        golden_utts = []
        gen_utts = []
        gen_slots = []
        sen_num = 0
        # sys.stdout = open(sys.argv[2] + '-' + sys.argv[3] + '-' + 'evaluate_logs_neo.txt','w')
        assert 'utterance' in data and 'dialog_act' in data and 'session_id' in data
        assert len(data['utterance']) == len(data['dialog_act']) == len(data['session_id'])
        # Turns during the same session should be contiguous, so we can call init_session at the first turn of a new session.
        # This is necessary for SCGPT, but unnecessary for SCLSTM and TemplateNLG.
        is_first_turn = []
        for _, iterator in itertools.groupby(data['session_id']):
            is_first_turn.append(True)
            next(iterator)
            is_first_turn.extend(False for _ in iterator)
        # Generate one utterance per turn from the gold dialog acts.
        for i in tqdm(range(len(data['utterance']))):
            if is_first_turn[i]:
                model.init_session()
            dialog_acts.append(data['dialog_act'][i])
            golden_utts.append(data['utterance'][i])
            gen_utts.append(model.generate(data['dialog_act'][i]))
            # print(dialog_acts[-1])
            # print(golden_utts[-1])
            # print(gen_utts[-1])
        # SER on the human (golden) responses gives a noise floor for the metric.
        print("Calculate SER for golden responses")
        missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(dialog_acts, golden_utts)
        print("Golden response Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(missing, total, hallucinate, missing/total))
        # SER on the model's generated responses.
        print("Calculate SER")
        missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(dialog_acts, gen_utts)
        # with open('{}-{}-genutt_neo.txt'.format(sys.argv[2], sys.argv[3]), mode='wt', encoding='utf-8') as gen_diag:
        #     for x in gen_utts:
        #         gen_diag.writelines(str(x)+'\n')
        # with open('{}-{}-hallucinate_neo.txt'.format(sys.argv[2], sys.argv[3]), mode='wt', encoding='utf-8') as hal_diag:
        #     for x in hallucination_dialogs:
        #         hal_diag.writelines(str(x)+'\n')
        # with open('{}-{}-missing_neo.txt'.format(sys.argv[2], sys.argv[3]), mode='wt', encoding='utf-8') as miss_diag:
        #     for x in missing_dialogs:
        #         miss_diag.writelines(str(x)+'\n')
        print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(sys.argv[2], missing, total, hallucinate, missing/total))
        print("Calculate bleu-4")
        bleu4 = get_bleu4(dialog_acts, golden_utts, gen_utts)
        print("BLEU-4: %.4f" % bleu4)
        print('Model on {} sentences role={}'.format(len(data['utterance']), role))
        # sys.stdout.close()
    else:
        raise Exception("currently supported dataset: MultiWOZ")