From c854a517b012a5c204e1160e0abc551813a5f9da Mon Sep 17 00:00:00 2001
From: heck <heckmi@hhu.de>
Date: Tue, 17 Nov 2020 09:46:30 +0000
Subject: [PATCH] fixes in dataset_*

---
 dataset_multiwoz21.py |  7 +++----
 dataset_sim.py        | 30 +++++++++++++++++++++++++++++-
 2 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/dataset_multiwoz21.py b/dataset_multiwoz21.py
index 269c2bb..e514638 100644
--- a/dataset_multiwoz21.py
+++ b/dataset_multiwoz21.py
@@ -423,10 +423,13 @@ def create_examples(input_file, acts_file, set_type, slot_list,
 
                 # Get dialog act annotations
                 inform_label = list(['none'])
+                inform_slot_dict[slot] = 0
                 if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
                     inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]])
+                    inform_slot_dict[slot] = 1
                 elif (str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1]) in sys_inform_dict:
                     inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1])]])
+                    inform_slot_dict[slot] = 1
 
                 (informed_value,
                  referred_slot,
@@ -440,10 +443,6 @@ def create_examples(input_file, acts_file, set_type, slot_list,
                                               slot_last_occurrence=True)
 
                 inform_dict[slot] = informed_value
-                if informed_value != 'none':
-                    inform_slot_dict[slot] = 1
-                else:
-                    inform_slot_dict[slot] = 0
 
                 # Generally don't use span prediction on sys utterance (but inform prediction instead).
                 sys_utt_tok_label = [0 for _ in sys_utt_tok]
diff --git a/dataset_sim.py b/dataset_sim.py
index 6685f4e..575a29f 100644
--- a/dataset_sim.py
+++ b/dataset_sim.py
@@ -22,6 +22,29 @@ import json
 from utils_dst import (DSTExample)
 
 
+def load_acts(input_file):
+    """Map (dialogue_id, turn index, slot) to the system-informed value.
+
+    Reads the JSON file at `input_file`, scans each turn's
+    "system_acts" and collects every act that carries a "value". If a
+    slot is informed more than once in a turn, the first value is kept.
+    """
+    # Explicit UTF-8, matching how create_examples opens the same file.
+    with open(input_file, "r", encoding='utf-8') as f:
+        acts = json.load(f)
+    s_dict = {}
+    for d in acts:
+        d_id = d["dialogue_id"]
+        # Turns without a "system_acts" annotation contribute nothing.
+        for t_id, t in enumerate(d["turns"]):
+            for a in t.get("system_acts", ()):
+                if "value" in a:
+                    # setdefault keeps the first informed value; use plain
+                    # assignment instead to keep the last one.
+                    s_dict.setdefault((d_id, t_id, a["slot"]), a["value"])
+    return s_dict
+
+
 def dialogue_state_to_sv_dict(sv_list):
     sv_dict = {}
     for d in sv_list:
@@ -98,7 +121,7 @@ def delex_utt(utt, values):
     return utt_delex
     
 
-def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id,
+def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, sys_inform_dict,
                    delexicalize_sys_utts=False, slot_last_occurrence=True):
     """Make turn_label a dictionary of slot with value positions or being dontcare / none:
     Turn label contains:
@@ -126,6 +149,7 @@ def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id,
             slot_last_occurrence=slot_last_occurrence)
         if sum(sys_utt_tok_label) > 0:
             inform_label_dict[slot_type] = cur_ds_dict[slot_type]
+        if (dial_id, turn_id, slot_type) in sys_inform_dict:
             inform_slot_label_dict[slot_type] = 1
         sys_utt_tok_label = [0 for _ in sys_utt_tok_label] # Don't use token labels for sys utt
         sys_utt_tok_label_dict[slot_type] = sys_utt_tok_label
@@ -150,6 +174,9 @@ def create_examples(input_file, set_type, slot_list,
                     delexicalize_sys_utts=False,
                     analyze=False):
     """Read a DST json file into a list of DSTExample."""
+
+    sys_inform_dict = load_acts(input_file)
+
     with open(input_file, "r", encoding='utf-8') as reader:
         input_data = json.load(reader)
 
@@ -178,6 +205,7 @@ def create_examples(input_file, set_type, slot_list,
                                            slot_list,
                                            dial_id,
                                            turn_id,
+                                           sys_inform_dict,
                                            delexicalize_sys_utts=delexicalize_sys_utts,
                                            slot_last_occurrence=True)
 
-- 
GitLab