From abded2cbd69958030a17ac46746679bedbf6cd99 Mon Sep 17 00:00:00 2001
From: Michael Heck <michael.heck@hhu.de>
Date: Thu, 17 Feb 2022 13:54:43 +0100
Subject: [PATCH] Fixes for mwoz 2.0 and 2.2 compatibility

---
 dataset_multiwoz21.py | 26 ++++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/dataset_multiwoz21.py b/dataset_multiwoz21.py
index d842f68..0f04afd 100644
--- a/dataset_multiwoz21.py
+++ b/dataset_multiwoz21.py
@@ -77,10 +77,18 @@ def load_acts(input_file):
         for t in acts[d]:
             # Only process, if turn has annotation
             if isinstance(acts[d][t], dict):
-                for a in acts[d][t]:
+                is_22_format = False
+                if 'dialog_act' in acts[d][t]:
+                    is_22_format = True
+                    acts_list = acts[d][t]['dialog_act']
+                    if int(t) % 2 == 0:
+                        continue
+                else:
+                    acts_list = acts[d][t]
+                for a in acts_list:
                     aa = a.lower().split('-')
-                    if aa[1] == 'inform' or aa[1] == 'recommend' or aa[1] == 'select' or aa[1] == 'book':
-                        for i in acts[d][t][a]:
+                    if aa[1] in ['inform', 'recommend', 'select', 'book']:
+                        for i in acts_list[a]:
                             s = i[0].lower()
                             v = i[1].lower().strip()
                             if s == 'none' or v == '?' or v == 'none':
@@ -88,7 +96,13 @@ def load_acts(input_file):
                             slot = aa[0] + '-' + s
                             if slot in ACTS_DICT:
                                 slot = ACTS_DICT[slot]
-                            key = d + '.json', t, slot
+                            if is_22_format:
+                                t_key = str(int(int(t) / 2 + 1))
+                                d_key = d
+                            else:
+                                t_key = t
+                                d_key = d + '.json'
+                            key = d_key, t_key, slot
                             # In case of multiple mentioned values...
                             # ... Option 1: Keep first informed value
                             if key not in s_dict:
@@ -147,6 +161,10 @@ def normalize_label(slot, value_label):
     if value_label == '' or value_label == "not mentioned":
         return "none"
 
+    # Normalization of 'dontcare'
+    if value_label == 'dont care':
+        return "dontcare"
+
     # Normalization of time slots
     if "leaveAt" in slot or "arriveBy" in slot or slot == 'restaurant-book_time':
         return normalize_time(value_label)
-- 
GitLab