Skip to content
Snippets Groups Projects
Commit 9a7c8c25 authored by Carel van Niekerk's avatar Carel van Niekerk
Browse files

Add new nlg evaluation script including slot error rate and hallucination

parent 67809f17
No related branches found
No related tags found
No related merge requests found
...@@ -5,16 +5,172 @@ Usage: python evaluate.py [MultiWOZ] [SCLSTM|TemplateNLG] [usr|sys] ...@@ -5,16 +5,172 @@ Usage: python evaluate.py [MultiWOZ] [SCLSTM|TemplateNLG] [usr|sys]
""" """
import json import json
import os
import random import random
import sys import sys
import zipfile import zipfile
import numpy import numpy
from numpy.lib.shape_base import _put_along_axis_dispatcher
from numpy.lib.twodim_base import triu_indices_from
import torch import torch
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from pprint import pprint from pprint import pprint
from tqdm import tqdm from tqdm import tqdm
def slot_error(dialog_acts, utts):
halucination = []
halucinate = 0
missing = 0
total = 0
for acts,utt in zip(dialog_acts, utts):
for act in acts:
tmp_act = [x.lower() for x in act]
tmp_utt = utt.lower()
i, d, s, v = tmp_act
if i == 'inform':
total = total + 1
if not (v in tmp_utt):
missing = missing + 1
return missing, total
def fine_SER(dialog_acts, utts):
path = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(path, 'template', 'multiwoz', 'label_maps.json')
with open(path, 'r') as mapping_file:
mappings = json.load(mapping_file)
mapping_file.close()
path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
path = os.path.join(path, 'data', 'multiwoz', 'ontology_nlg_eval.json')
with open(path, 'r') as entity_file:
possible_entity = json.load(entity_file)
entity_file.close()
entity_list = []
for key in possible_entity.keys():
entity_list = entity_list + possible_entity[key]
hallucinate = 0
missing = 0
total = 0
unk_token_count = 0
missing_dialogs = []
hallucination_dialogs = []
slot_span = []
domain_span = []
for acts,utt in zip(dialog_acts, utts):
hallucination_flag = False
tmp_utt = utt.lower()
origin_utt = utt.lower()
legal_act_flag = False
for act in acts:
missing_fact = None
missing_flag = False
tmp_act = [x.lower() for x in act]
i, d, s, v = tmp_act
if not(d in domain_span):
domain_span.append(d)
if not(s in slot_span):
slot_span.append(s)
#intializing all possible span keyword
if i in ['inform', 'recommend', 'offerbook', 'offerbooked','book','select']:
legal_act_flag = True
total = total + 1
if not (v in origin_utt) and v!='none':
exist_flag = False
try:
synoyms = mappings[v]
for item in synoyms:
if item in origin_utt:
exist_flag = True
tmp_utt = tmp_utt.replace(item,'')
tmp_utt = tmp_utt.replace(s,'')
#remove span for hallucination detection
except:
pass
if i in ['offerbook', 'offerbooked'] and v =='none':
if 'book' in origin_utt:
exist_flag = True
tmp_utt = tmp_utt.replace('book','')
if i in ['inform','recommend'] and v=='none':
if d in origin_utt:
exist_flag = True
tmp_utt = tmp_utt.replace(d, '')
if exist_flag == False:
missing_flag = True
missing_fact = v
else:
tmp_utt = tmp_utt.replace(v,'')
tmp_utt = tmp_utt.replace(s,'')
if s in origin_utt:
missing_flag = False
if s =='booking' and ('book' in origin_utt or 'reserv' in origin_utt):
missing_flag = False
elif i == 'request':
legal_act_flag = True
total = total + 1
if s=='depart' or s=='dest' or s=='area':
if not ('where' in origin_utt):
if s in origin_utt:
tmp_utt = tmp_utt.replace(s,'')
else:
missing_flag = True
missing_fact = s
elif s=='leave' or s=='arrive':
if (not 'when' in origin_utt):
if not ('what' in origin_utt and 'time' in origin_utt):
missing_flag = True
missing_fact = s
else:
tmp_utt.replace('time', '')
else:
tmp_utt = tmp_utt.replace(s,'')
tmp_utt = tmp_utt.replace(d,'')
if s in origin_utt:
missing_flag = False
if s =='booking' and ('book' in origin_utt or 'reserv' in origin_utt):
missing_flag = False
try:
tmp_utt = tmp_utt.replace(d,'')
tmp_utt = tmp_utt.replace(s,'')
if 'arrive' in s or 'leave' in s:
tmp_utt = tmp_utt.replace('time', '')
except:
pass
if missing_flag == True:
missing = missing + 1
missing_dialogs.append(missing_fact)
missing_dialogs.append(acts)
missing_dialogs.append(utt)
for keyword in slot_span + entity_list:
if keyword in tmp_utt and len(keyword) >= 4 and legal_act_flag == True:
hallucination_flag = True
hallucinate = hallucinate + 1
hallucination_dialogs.append(keyword)
hallucination_dialogs.append(acts)
hallucination_dialogs.append(tmp_utt)
hallucination_dialogs.append(utt)
break
return missing, hallucinate, total, hallucination_dialogs, missing_dialogs
def get_bleu4(dialog_acts, golden_utts, gen_utts): def get_bleu4(dialog_acts, golden_utts, gen_utts):
das2utts = {} das2utts = {}
for das, utt, gen in zip(dialog_acts, golden_utts, gen_utts): for das, utt, gen in zip(dialog_acts, golden_utts, gen_utts):
...@@ -59,7 +215,7 @@ if __name__ == '__main__': ...@@ -59,7 +215,7 @@ if __name__ == '__main__':
print("usage:") print("usage:")
print("\t python evaluate.py dataset model role") print("\t python evaluate.py dataset model role")
print("\t dataset=MultiWOZ, CrossWOZ, or Camrest") print("\t dataset=MultiWOZ, CrossWOZ, or Camrest")
print("\t model=SCLSTM, or TemplateNLG") print("\t model=SCLSTM, SCLSTM_NoUNK, SCGPT or TemplateNLG")
print("\t role=usr/sys") print("\t role=usr/sys")
sys.exit() sys.exit()
dataset_name = sys.argv[1] dataset_name = sys.argv[1]
...@@ -69,22 +225,34 @@ if __name__ == '__main__': ...@@ -69,22 +225,34 @@ if __name__ == '__main__':
if model_name == 'SCLSTM': if model_name == 'SCLSTM':
from convlab2.nlg.sclstm.multiwoz import SCLSTM from convlab2.nlg.sclstm.multiwoz import SCLSTM
if role == 'usr': if role == 'usr':
model = SCLSTM(is_user=True, use_cuda=True) model = SCLSTM(is_user=True, use_cuda=True, unk_suppress=False)
elif role == 'sys':
model = SCLSTM(is_user=False, use_cuda=True, unk_suppress=False)
elif model_name == 'SCLSTM_NoUNK':
from convlab2.nlg.sclstm.multiwoz import SCLSTM
if role == 'usr':
model = SCLSTM(is_user=True, use_cuda=True, unk_suppress=True)
elif role == 'sys': elif role == 'sys':
model = SCLSTM(is_user=False, use_cuda=True) model = SCLSTM(is_user=False, use_cuda=True, unk_suppress=True)
elif model_name == 'TemplateNLG': elif model_name == 'TemplateNLG':
from convlab2.nlg.template.multiwoz import TemplateNLG from convlab2.nlg.template.multiwoz import TemplateNLG
if role == 'usr': if role == 'usr':
model = TemplateNLG(is_user=True) model = TemplateNLG(is_user=True)
elif role == 'sys': elif role == 'sys':
model = TemplateNLG(is_user=False) model = TemplateNLG(is_user=False)
elif model_name == 'SCGPT':
from convlab2.nlg.scgpt.multiwoz import SCGPT
if role == 'usr':
model = SCGPT(is_user=True)
elif role == 'sys':
model = SCGPT(is_user=False, model_file='scgpt/trained_output/multiwoz/')
else: else:
raise Exception("Available models: SCLSTM, TEMPLATE") raise Exception("Available models: SCLSTM, SCGPT, TEMPLATE")
from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader
from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
dataloader = SingleTurnNLGDataloader(dataset_dataloader=MultiWOZDataloader()) dataloader = SingleTurnNLGDataloader(dataset_dataloader=MultiWOZDataloader())
data = dataloader.load_data(data_key='test', role=role)['test'] data = dataloader.load_data(data_key='all', role=role)['test']
dialog_acts = [] dialog_acts = []
golden_utts = [] golden_utts = []
...@@ -93,17 +261,39 @@ if __name__ == '__main__': ...@@ -93,17 +261,39 @@ if __name__ == '__main__':
sen_num = 0 sen_num = 0
# sys.stdout = open(sys.argv[2] + '-' + sys.argv[3] + '-' + 'evaluate_logs_neo.txt','w')
for i in tqdm(range(len(data['utterance']))): for i in tqdm(range(len(data['utterance']))):
dialog_acts.append(data['dialog_act'][i]) dialog_acts.append(data['dialog_act'][i])
golden_utts.append(data['utterance'][i]) golden_utts.append(data['utterance'][i])
gen_utts.append(model.generate(data['dialog_act'][i])) gen_utts.append(model.generate(data['dialog_act'][i]))
# print(dialog_acts[-1])
# print(golden_utts[-1])
# print(gen_utts[-1])
bleu4 = get_bleu4(dialog_acts, golden_utts, gen_utts) print("Calculate SER for golden responses")
missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(dialog_acts, golden_utts)
print("Golden response Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(missing, total, hallucinate, missing/total))
print("Calculate SER")
missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(dialog_acts, gen_utts)
# with open('{}-{}-genutt_neo.txt'.format(sys.argv[2], sys.argv[3]), mode='wt', encoding='utf-8') as gen_diag:
# for x in gen_utts:
# gen_diag.writelines(str(x)+'\n')
# with open('{}-{}-hallucinate_neo.txt'.format(sys.argv[2], sys.argv[3]), mode='wt', encoding='utf-8') as hal_diag:
# for x in hallucination_dialogs:
# hal_diag.writelines(str(x)+'\n')
# with open('{}-{}-missing_neo.txt'.format(sys.argv[2], sys.argv[3]), mode='wt', encoding='utf-8') as miss_diag:
# for x in missing_dialogs:
# miss_diag.writelines(str(x)+'\n')
print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(sys.argv[2], missing, total, hallucinate, missing/total))
print("Calculate bleu-4") print("Calculate bleu-4")
bleu4 = get_bleu4(dialog_acts, golden_utts, gen_utts)
print("BLEU-4: %.4f" % bleu4) print("BLEU-4: %.4f" % bleu4)
print('Model on {} sentences role={}'.format(len(data['utterance']), role)) print('Model on {} sentences role={}'.format(len(data['utterance']), role))
# sys.stdout.close()
else: else:
raise Exception("currently supported dataset: MultiWOZ") raise Exception("currently supported dataset: MultiWOZ")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment