From 5818e14ebf98005c7e3b13a3f18fade501065c23 Mon Sep 17 00:00:00 2001
From: zz-jacob <zhangz.goal@gmail.com>
Date: Mon, 21 Mar 2022 15:41:56 +0800
Subject: [PATCH] fix bugs

---
 convlab2/nlg/scgpt/evaluate.py |  2 +-
 convlab2/nlg/scgpt/main.py     | 10 ++++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/convlab2/nlg/scgpt/evaluate.py b/convlab2/nlg/scgpt/evaluate.py
index c54b079f..f7435b68 100644
--- a/convlab2/nlg/scgpt/evaluate.py
+++ b/convlab2/nlg/scgpt/evaluate.py
@@ -246,7 +246,7 @@ class GentScorer(object):
     ## 2. Compute slot error rate
     ## 3. Detailed illustraction of how differet split
     ##    of data affect performance
-    def __init__(self, detectfile):
+    def __init__(self):
         self.bleuscorer = BLEUScorer()
 
     def scoreERR(self, parallel_pairs):
diff --git a/convlab2/nlg/scgpt/main.py b/convlab2/nlg/scgpt/main.py
index a48bf7b4..2d69ba2f 100644
--- a/convlab2/nlg/scgpt/main.py
+++ b/convlab2/nlg/scgpt/main.py
@@ -221,6 +221,8 @@ def test(model, nlg_data, ontology, model_path):
     test_data = nlg_data['test']
     dialog_acts = [act2str(item['dialogue_acts']) for item in test_data]
     golden_responses = [item['utterance'] for item in test_data]
+    # dialog_acts = dialog_acts[:10]
+    # golden_responses = golden_responses[:10]
     outputs = inference_sents(model, dialog_acts)
     if dist.get_rank() == 0:
         output_file = './test_output.txt'
@@ -241,13 +243,15 @@ def test(model, nlg_data, ontology, model_path):
         domain = ontology['domains'][domain_name]
         for slot_name in domain['slots']:
             slot = domain['slots'][slot_name]
+            if 'possible_values' not in slot:
+                continue
             possible_vals = slot['possible_values']
             if len(possible_vals) > 0:
                 for val in possible_vals:
                     val2ds_dict[val] = f'{domain_name}-{slot_name}'
     ## missing values
     score_list = []
-    for item in nlg_data:
+    for item in test_data:
         da = item['dialogue_acts']
         utterance = item['utterance']
         missing_count = 0
@@ -263,11 +267,13 @@ def test(model, nlg_data, ontology, model_path):
                     if value.strip().lower() not in utterance.lower():
                         missing_count += 1
                     all_count += 1
+        if all_count == 0:
+            continue
         ## redundant values
         for val in val2ds_dict:
             if f' {val.strip().lower()} ' in f' {utterance.strip().lower()} ' and val.strip().lower() not in all_values:
                 redundant_count += 1
-        item_score = float(redundant_count + all_count) / all_count
+        item_score = float(redundant_count + redundant_count) / all_count
         score_list.append(item_score)
     ERR_Score = np.mean(score_list)
     print(f'BLEU: {BLEU_Score}\nERR_Score: {ERR_Score}')
-- 
GitLab