Skip to content
Snippets Groups Projects
Select Git revision
  • 812da0921a8f6ad2ce93521440ec66345be59ac1
  • master default protected
  • release/1.1.4
  • release/1.1.3
  • release/1.1.1
  • 1.4.1
  • 1.4.0
  • 1.3.0
  • 1.2.1
  • 1.2.0
  • 1.1.5
  • 1.1.4
  • 1.1.3
  • 1.1.1
  • 1.1.0
  • 1.0.9
  • 1.0.8
  • 1.0.7
  • v1.0.5
  • 1.0.5
20 results

BOperation.java

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    evaluate.py 12.46 KiB
    """
    Evaluate NLU models on specified dataset
    Metric: dataset level Precision/Recall/F1
    Usage: python evaluate.py [MultiWOZ] [SCLSTM|TemplateNLG] [usr|sys]
    """
    
    import json
    import os
    import random
    import sys
    import itertools
    import zipfile
    import numpy
    from numpy.lib.shape_base import _put_along_axis_dispatcher
    from numpy.lib.twodim_base import triu_indices_from
    import torch
    from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
    from pprint import pprint
    from tqdm import tqdm
    
    
    def slot_error(dialog_acts, utts):
        """Count 'inform' acts whose value does not appear in the paired utterance.

        This is the coarse slot-error counter: an inform value is "missing"
        whenever its lowercased string is not a substring of the lowercased
        utterance (no synonym lookup — see fine_SER for that).

        Args:
            dialog_acts: per-utterance lists of acts, each act a 4-sequence
                [intent, domain, slot, value].
            utts: generated utterances, aligned index-by-index with dialog_acts.

        Returns:
            Tuple (missing, total): inform values absent from their utterance,
            and the total number of inform acts seen.
        """
        missing = 0
        total = 0
        for acts, utt in zip(dialog_acts, utts):
            utt_lower = utt.lower()
            for act in acts:
                # Only intent and value are used; domain/slot are ignored here.
                intent, _domain, _slot, value = [x.lower() for x in act]
                if intent == 'inform':
                    total += 1
                    if value not in utt_lower:
                        missing += 1
        return missing, total
    
    def fine_SER(dialog_acts, utts, mappings=None, possible_entity=None):
        """Fine-grained slot error rate: missing values and hallucinated keywords.

        For each (acts, utterance) pair, checks that every informable/requestable
        slot value is realized in the utterance (accepting synonyms from the
        label map), and flags leftover known keywords as hallucinations after
        all justified spans have been stripped from a working copy.

        Args:
            dialog_acts: per-utterance lists of acts, each a 4-sequence
                [intent, domain, slot, value].
            utts: utterances aligned with dialog_acts.
            mappings: optional value -> list-of-synonyms dict; loaded from
                template/multiwoz/label_maps.json when None.
            possible_entity: optional dict of entity-value lists; loaded from
                data/multiwoz/ontology_nlg_eval.json when None.

        Returns:
            Tuple (missing, hallucinate, total, hallucination_dialogs,
            missing_dialogs), where the two *_dialogs lists interleave the
            offending keyword/value, the acts, and the utterance(s) for
            debugging.
        """
        if mappings is None:
            # Value synonym map shipped with the template NLG resources.
            path = os.path.dirname(os.path.abspath(__file__))
            path = os.path.join(path, 'template', 'multiwoz', 'label_maps.json')
            with open(path, 'r') as mapping_file:
                mappings = json.load(mapping_file)

        if possible_entity is None:
            # Ontology of every possible entity value, used to spot hallucinations.
            path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
            path = os.path.join(path, 'data', 'multiwoz', 'ontology_nlg_eval.json')
            with open(path, 'r') as entity_file:
                possible_entity = json.load(entity_file)

        entity_list = []
        for key in possible_entity.keys():
            entity_list = entity_list + possible_entity[key]

        hallucinate = 0
        missing = 0
        total = 0

        missing_dialogs = []
        hallucination_dialogs = []

        # Keyword pools accumulated across the whole corpus, consulted by the
        # hallucination scan at the bottom of the outer loop.
        slot_span = []
        domain_span = []

        for acts, utt in zip(dialog_acts, utts):
            tmp_utt = utt.lower()     # progressively stripped of justified spans
            origin_utt = utt.lower()  # untouched reference copy
            legal_act_flag = False    # becomes True once a scoreable act is seen

            for act in acts:
                missing_fact = None
                missing_flag = False
                i, d, s, v = [x.lower() for x in act]

                if d not in domain_span:
                    domain_span.append(d)
                if s not in slot_span:
                    slot_span.append(s)
                # initializing all possible span keywords

                if i in ['inform', 'recommend', 'offerbook', 'offerbooked', 'book', 'select']:
                    legal_act_flag = True
                    total = total + 1
                    if v not in origin_utt and v != 'none':
                        exist_flag = False
                        try:
                            # Exact value absent; accept any known synonym.
                            for item in mappings[v]:
                                if item in origin_utt:
                                    exist_flag = True
                                    tmp_utt = tmp_utt.replace(item, '')
                                    tmp_utt = tmp_utt.replace(s, '')
                                    # remove span for hallucination detection
                        except KeyError:
                            pass  # value has no synonym entry in the label map
                        # NOTE(review): the two v=='none' branches below are
                        # unreachable — this path is only entered when v != 'none'.
                        # Kept byte-for-byte for parity with the original logic;
                        # confirm whether they were meant to sit on the else path.
                        if i in ['offerbook', 'offerbooked'] and v == 'none':
                            if 'book' in origin_utt:
                                exist_flag = True
                                tmp_utt = tmp_utt.replace('book', '')
                        if i in ['inform', 'recommend'] and v == 'none':
                            if d in origin_utt:
                                exist_flag = True
                                tmp_utt = tmp_utt.replace(d, '')
                        if not exist_flag:
                            missing_flag = True
                            missing_fact = v
                    else:
                        # Value (or 'none') realized: strip it so the
                        # hallucination scan does not re-flag it.
                        tmp_utt = tmp_utt.replace(v, '')
                        tmp_utt = tmp_utt.replace(s, '')

                    # Mentioning the slot name itself also counts as realized.
                    if s in origin_utt:
                        missing_flag = False
                    if s == 'booking' and ('book' in origin_utt or 'reserv' in origin_utt):
                        missing_flag = False

                elif i == 'request':
                    legal_act_flag = True
                    total = total + 1
                    if s == 'depart' or s == 'dest' or s == 'area':
                        # Location-type requests are satisfied by 'where' or the slot name.
                        if 'where' not in origin_utt:
                            if s in origin_utt:
                                tmp_utt = tmp_utt.replace(s, '')
                            else:
                                missing_flag = True
                                missing_fact = s
                    elif s == 'leave' or s == 'arrive':
                        # Time-type requests are satisfied by 'when' or 'what ... time'.
                        if 'when' not in origin_utt:
                            if not ('what' in origin_utt and 'time' in origin_utt):
                                missing_flag = True
                                missing_fact = s
                        else:
                            # Fix: the original discarded this replace() result
                            # (harmless only because 'time' is stripped again below).
                            tmp_utt = tmp_utt.replace('time', '')
                    else:
                        tmp_utt = tmp_utt.replace(s, '')
                        tmp_utt = tmp_utt.replace(d, '')

                    if s in origin_utt:
                        missing_flag = False
                    if s == 'booking' and ('book' in origin_utt or 'reserv' in origin_utt):
                        missing_flag = False

                # Strip this act's domain/slot mentions so they are not later
                # mistaken for hallucinated keywords.
                tmp_utt = tmp_utt.replace(d, '')
                tmp_utt = tmp_utt.replace(s, '')
                if 'arrive' in s or 'leave' in s:
                    tmp_utt = tmp_utt.replace('time', '')

                if missing_flag:
                    missing = missing + 1
                    missing_dialogs.append(missing_fact)
                    missing_dialogs.append(acts)
                    missing_dialogs.append(utt)

            # Any sufficiently long known keyword surviving the stripping is a
            # hallucination; count at most one per utterance.
            for keyword in slot_span + entity_list:
                if keyword in tmp_utt and len(keyword) >= 4 and legal_act_flag:
                    hallucinate = hallucinate + 1
                    hallucination_dialogs.append(keyword)
                    hallucination_dialogs.append(acts)
                    hallucination_dialogs.append(tmp_utt)
                    hallucination_dialogs.append(utt)
                    break

        return missing, hallucinate, total, hallucination_dialogs, missing_dialogs
    
    
    def get_bleu4(dialog_acts, golden_utts, gen_utts):
        """Corpus BLEU-4 of generated vs. reference utterances, grouped by acts.

        Utterances are delexicalized (slot values replaced by Act-Domain-Slot
        tokens), bucketed by a canonical key built from their dialog acts, and
        every reference sharing a bucket serves as a reference for each
        generated utterance in that bucket.
        """
        grouped = {}
        for das, ref_utt, gen_utt in zip(dialog_acts, golden_utts, gen_utts):
            ref_utt = ref_utt.lower()
            gen_utt = gen_utt.lower()
            for act, domain, slot, value in das:
                # Requests and general-domain acts carry no realizable value.
                if act == 'Request' or domain == 'general':
                    continue
                # Binary/absent values cannot be delexicalized.
                if slot in ('Internet', 'Parking') or slot == 'none' or value == 'none':
                    continue
                value = value.lower()
                token = f'{act}-{domain}-{slot}'
                if (' ' + value in ref_utt) or (value + ' ' in ref_utt):
                    ref_utt = ref_utt.replace(value, token, 1)
                if (' ' + value in gen_utt) or (value + ' ' in gen_utt):
                    gen_utt = gen_utt.replace(value, token, 1)
            # Canonical bucket key: sorted act triples, value dropped.
            key = ''.join('-'.join(da[:-1]) + ';'
                          for da in sorted(das, key=lambda x: x[0] + x[1] + x[2]))
            bucket = grouped.setdefault(key, {'refs': [], 'gens': []})
            bucket['refs'].append(ref_utt)
            bucket['gens'].append(gen_utt)

        refs, gens = [], []
        for bucket in grouped.values():
            tokenized_refs = [r.split() for r in bucket['refs']]
            for gen_utt in bucket['gens']:
                refs.append(tokenized_refs)
                gens.append(gen_utt.split())
        return corpus_bleu(refs, gens, weights=(0.25, 0.25, 0.25, 0.25),
                           smoothing_function=SmoothingFunction().method1)
    
    
    if __name__ == '__main__':
        # Seed every RNG the models may consult so generation is reproducible.
        seed = 2020
        random.seed(seed)
        numpy.random.seed(seed)
        torch.manual_seed(seed)

        if len(sys.argv) < 4:
            print("usage:")
            print("\t python evaluate.py dataset model role")
            print("\t dataset=MultiWOZ, CrossWOZ, or Camrest")
            print("\t model=SCLSTM, SCLSTM_NoUNK, SCGPT or TemplateNLG")
            print("\t role=usr/sys")
            print("\t [Optional] model_file")
            sys.exit()
        # Positional CLI arguments; model_file is only consumed by SCGPT.
        dataset_name = sys.argv[1]
        model_name = sys.argv[2]
        role = sys.argv[3]
        model_file = sys.argv[4] if len(sys.argv) >= 5 else None
        if dataset_name == 'MultiWOZ':
            # Imports are deferred per-branch so only the selected model's
            # (heavy) dependencies get loaded.
            if model_name == 'SCLSTM':
                from convlab2.nlg.sclstm.multiwoz import SCLSTM
                if role == 'usr':
                    model = SCLSTM(is_user=True, use_cuda=True, unk_suppress=False)
                elif role == 'sys':
                    model = SCLSTM(is_user=False, use_cuda=True, unk_suppress=False)
            elif model_name == 'SCLSTM_NoUNK':
                from convlab2.nlg.sclstm.multiwoz import SCLSTM
                if role == 'usr':
                    model = SCLSTM(is_user=True, use_cuda=True, unk_suppress=True)
                elif role == 'sys':
                    model = SCLSTM(is_user=False, use_cuda=True, unk_suppress=True)
            elif model_name == 'TemplateNLG':
                from convlab2.nlg.template.multiwoz import TemplateNLG
                if role == 'usr':
                    model = TemplateNLG(is_user=True)
                elif role == 'sys':
                    model = TemplateNLG(is_user=False)
            elif model_name == 'SCGPT':
                from convlab2.nlg.scgpt.multiwoz import SCGPT
                if model_file is not None:
                    print(f"load model at {model_file}")
                if role == 'usr':
                    model = SCGPT(model_file, is_user=True)
                elif role == 'sys':
                    model  = SCGPT(model_file, is_user=False)
            else:
                raise Exception("Available models: SCLSTM, SCGPT, TEMPLATE")

            from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader
            from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
            dataloader = SingleTurnNLGDataloader(dataset_dataloader=MultiWOZDataloader())
            # Evaluate on the held-out test split only.
            data = dataloader.load_data(data_key='all', role=role, session_id=True)['test']

            dialog_acts = []
            golden_utts = []
            gen_utts = []
            gen_slots = []

            sen_num = 0

            # sys.stdout = open(sys.argv[2] + '-' + sys.argv[3] + '-' + 'evaluate_logs_neo.txt','w')
            assert 'utterance' in data and 'dialog_act' in data and 'session_id' in data
            assert len(data['utterance']) == len(data['dialog_act']) == len(data['session_id'])

            # Turns during the same session should be contiguous, so we can call init_session at the first turn of a new session.
            # This is necessary for SCGPT, but unnecessary for SCLSTM and TemplateNLG.
            is_first_turn = []
            for _, iterator in itertools.groupby(data['session_id']):
                is_first_turn.append(True)
                next(iterator)
                is_first_turn.extend(False for _ in iterator)
            # Generate an utterance for every test turn from its gold dialog acts.
            for i in tqdm(range(len(data['utterance']))):
                if is_first_turn[i]:
                    model.init_session()
                dialog_acts.append(data['dialog_act'][i])
                golden_utts.append(data['utterance'][i])
                gen_utts.append(model.generate(data['dialog_act'][i]))
            #     print(dialog_acts[-1])
            #     print(golden_utts[-1])
            #     print(gen_utts[-1])

            # SER on the human (golden) responses gives a reference floor for the metric.
            print("Calculate SER for golden responses")
            missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(dialog_acts, golden_utts)
            print("Golden response Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(missing, total, hallucinate, missing/total))

            # SER on the model's generated responses.
            print("Calculate SER")
            missing, hallucinate, total, hallucination_dialogs, missing_dialogs = fine_SER(dialog_acts, gen_utts)
            # with open('{}-{}-genutt_neo.txt'.format(sys.argv[2], sys.argv[3]), mode='wt', encoding='utf-8') as gen_diag:
            #     for x in gen_utts:
            #         gen_diag.writelines(str(x)+'\n')


            # with open('{}-{}-hallucinate_neo.txt'.format(sys.argv[2], sys.argv[3]), mode='wt', encoding='utf-8') as hal_diag:
            #     for x in hallucination_dialogs:
            #         hal_diag.writelines(str(x)+'\n')

            # with open('{}-{}-missing_neo.txt'.format(sys.argv[2], sys.argv[3]), mode='wt', encoding='utf-8') as miss_diag:
            #     for x in missing_dialogs:
            #         miss_diag.writelines(str(x)+'\n')
            print("{} Missing acts: {}, Total acts: {}, Hallucinations {}, SER {}".format(sys.argv[2], missing, total, hallucinate, missing/total))
            print("Calculate bleu-4")
            bleu4 = get_bleu4(dialog_acts, golden_utts, gen_utts)
            print("BLEU-4: %.4f" % bleu4)
            print('Model on {} sentences role={}'.format(len(data['utterance']), role))
            # sys.stdout.close()

        else:
            raise Exception("currently supported dataset: MultiWOZ")