From 732944e24eb53c997f9961960fff4fc286f9bf37 Mon Sep 17 00:00:00 2001
From: heck <heckmi@hhu.de>
Date: Fri, 13 Nov 2020 10:01:40 +0000
Subject: [PATCH] fixes for dataset_*

---
 dataset_multiwoz21.py | 15 +++++++--------
 dataset_sim.py        | 21 +++++++++++----------
 dataset_woz2.py       |  9 +++++----
 3 files changed, 23 insertions(+), 22 deletions(-)

diff --git a/dataset_multiwoz21.py b/dataset_multiwoz21.py
index 9751f5c..269c2bb 100644
--- a/dataset_multiwoz21.py
+++ b/dataset_multiwoz21.py
@@ -264,6 +264,7 @@ def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, se
         class_type = value_label
     else:
         in_usr, usr_pos = check_label_existence(value_label, usr_utt_tok)
+        is_informed, informed_value = check_slot_inform(value_label, inform_label)
         if in_usr:
             class_type = 'copy_value'
             if slot_last_occurrence:
@@ -274,16 +275,14 @@ def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, se
                 for (s, e) in usr_pos:
                     for i in range(s, e):
                         usr_utt_tok_label[i] = 1
+        elif is_informed:
+            class_type = 'inform'
         else:
-            is_informed, informed_value = check_slot_inform(value_label, inform_label)
-            if is_informed:
-                class_type = 'inform'
+            referred_slot = check_slot_referral(value_label, slot, seen_slots)
+            if referred_slot != 'none':
+                class_type = 'refer'
             else:
-                referred_slot = check_slot_referral(value_label, slot, seen_slots)
-                if referred_slot != 'none':
-                    class_type = 'refer'
-                else:
-                    class_type = 'unpointable'
+                class_type = 'unpointable'
     return informed_value, referred_slot, usr_utt_tok_label, class_type
 
 
diff --git a/dataset_sim.py b/dataset_sim.py
index e30ed09..6685f4e 100644
--- a/dataset_sim.py
+++ b/dataset_sim.py
@@ -62,23 +62,24 @@ def get_tok_label(prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok,
             for label_d in usr_slot_label:
                 if label_d['slot'] == slot_type and value == ' '.join(
                         usr_utt_tok[label_d['start']:label_d['exclusive_end']]):
-
                     for idx in range(label_d['start'], label_d['exclusive_end']):
                         usr_utt_tok_label[idx] = 1
                     in_usr = True
                     class_type = 'copy_value'
                     if slot_last_occurrence:
                         break
-            if not in_usr or not slot_last_occurrence:
-                for label_d in sys_slot_label:
-                    if label_d['slot'] == slot_type and value == ' '.join(
-                            sys_utt_tok[label_d['start']:label_d['exclusive_end']]):
-                        for idx in range(label_d['start'], label_d['exclusive_end']):
-                            sys_utt_tok_label[idx] = 1
-                        in_sys = True
+
+            for label_d in sys_slot_label:
+                if label_d['slot'] == slot_type and value == ' '.join(
+                        sys_utt_tok[label_d['start']:label_d['exclusive_end']]):
+                    for idx in range(label_d['start'], label_d['exclusive_end']):
+                        sys_utt_tok_label[idx] = 1
+                    in_sys = True
+                    if not in_usr or not slot_last_occurrence:
                         class_type = 'inform'
-                        if slot_last_occurrence:
-                            break
+                    if slot_last_occurrence:
+                        break
+
             if not in_usr and not in_sys:
                 assert sum(usr_utt_tok_label + sys_utt_tok_label) == 0
                 if (slot_type not in prev_ds_dict or value != prev_ds_dict[slot_type]):
diff --git a/dataset_woz2.py b/dataset_woz2.py
index c32af05..fcb802f 100644
--- a/dataset_woz2.py
+++ b/dataset_woz2.py
@@ -87,7 +87,7 @@ def get_turn_label(label, sys_utt_tok, usr_utt_tok, slot_last_occurrence):
             class_type = 'inform'
         else:
             class_type = 'unpointable'
-    return usr_utt_tok_label, class_type
+    return usr_utt_tok_label, class_type, in_sys
 
 
 def tokenize(utt):
@@ -174,15 +174,16 @@ def create_examples(input_file, set_type, slot_list,
                     label = diag_seen_slots_value_dict[slot]
 
                 (usr_utt_tok_label,
-                 class_type) = get_turn_label(label,
+                 class_type,
+                 is_informed) = get_turn_label(label,
                                               sys_utt_tok,
                                               usr_utt_tok,
                                               slot_last_occurrence=True)
 
                 if class_type == 'inform':
                     inform_dict[slot] = label
-                    if label != 'none':
-                        inform_slot_dict[slot] = 1
+                if is_informed and label != 'none':
+                    inform_slot_dict[slot] = 1
 
                 referral_dict[slot] = 'none' # Referral is not present in woz2 data
 
-- 
GitLab