From 5818e14ebf98005c7e3b13a3f18fade501065c23 Mon Sep 17 00:00:00 2001
From: zz-jacob <zhangz.goal@gmail.com>
Date: Mon, 21 Mar 2022 15:41:56 +0800
Subject: [PATCH] fix bugs

---
 convlab2/nlg/scgpt/evaluate.py |  2 +-
 convlab2/nlg/scgpt/main.py     | 10 ++++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/convlab2/nlg/scgpt/evaluate.py b/convlab2/nlg/scgpt/evaluate.py
index c54b079f..f7435b68 100644
--- a/convlab2/nlg/scgpt/evaluate.py
+++ b/convlab2/nlg/scgpt/evaluate.py
@@ -246,7 +246,7 @@ class GentScorer(object):
     ## 2. Compute slot error rate
     ## 3. Detailed illustraction of how differet split
     ## of data affect performance
-    def __init__(self, detectfile):
+    def __init__(self):
         self.bleuscorer = BLEUScorer()
 
     def scoreERR(self, parallel_pairs):
diff --git a/convlab2/nlg/scgpt/main.py b/convlab2/nlg/scgpt/main.py
index a48bf7b4..2d69ba2f 100644
--- a/convlab2/nlg/scgpt/main.py
+++ b/convlab2/nlg/scgpt/main.py
@@ -221,6 +221,8 @@ def test(model, nlg_data, ontology, model_path):
     test_data = nlg_data['test']
     dialog_acts = [act2str(item['dialogue_acts']) for item in test_data]
     golden_responses = [item['utterance'] for item in test_data]
+    # dialog_acts = dialog_acts[:10]
+    # golden_responses = golden_responses[:10]
     outputs = inference_sents(model, dialog_acts)
     if dist.get_rank() == 0:
         output_file = './test_output.txt'
@@ -241,13 +243,15 @@ def test(model, nlg_data, ontology, model_path):
         domain = ontology['domains'][domain_name]
         for slot_name in domain['slots']:
             slot = domain['slots'][slot_name]
+            if 'possible_values' not in slot:
+                continue
             possible_vals = slot['possible_values']
             if len(possible_vals) > 0:
                 for val in possible_vals:
                     val2ds_dict[val] = f'{domain_name}-{slot_name}'
     ## missing values
     score_list = []
-    for item in nlg_data:
+    for item in test_data:
         da = item['dialogue_acts']
         utterance = item['utterance']
         missing_count = 0
@@ -263,11 +267,13 @@ def test(model, nlg_data, ontology, model_path):
                     if value.strip().lower() not in utterance.lower():
                         missing_count += 1
                     all_count += 1
+        if all_count == 0:
+            continue
         ## redundant values
         for val in val2ds_dict:
             if f' {val.strip().lower()} ' in f' {utterance.strip().lower()} ' and val.strip().lower() not in all_values:
                 redundant_count += 1
-        item_score = float(redundant_count + all_count) / all_count
+        item_score = float(missing_count + redundant_count) / all_count
         score_list.append(item_score)
     ERR_Score = np.mean(score_list)
     print(f'BLEU: {BLEU_Score}\nERR_Score: {ERR_Score}')
--
GitLab
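
Note on the ERR change above: the corrected metric is the slot error rate,
ERR = (missing + redundant slot values) / total gold slot values per turn.
A minimal standalone sketch of that computation, using the same
string-matching heuristics as test(); err_score and the toy data below are
illustrative, not part of the repo:

    def err_score(utterance, gold_values, ontology_values):
        """ERR = (missing + redundant) / number of gold slot values."""
        gold = {v.strip().lower() for v in gold_values}
        # A gold value is missing if it does not occur in the utterance.
        missing = sum(1 for v in gold if v not in utterance.lower())
        # An ontology value is redundant if it occurs in the utterance
        # (space-padded match, as in test()) but is not a gold value.
        padded = f' {utterance.strip().lower()} '
        redundant = sum(1 for v in ontology_values
                        if f' {v.strip().lower()} ' in padded
                        and v.strip().lower() not in gold)
        return float(missing + redundant) / len(gold) if gold else None

    # 'north' is missing, 'cambridge' appears but is not gold:
    # ERR = (1 + 1) / 2 = 1.0
    print(err_score('i booked a cheap hotel in cambridge',
                    ['cheap', 'north'],
                    ['cheap', 'north', 'cambridge']))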