From 9cc922664b084bed266f3ffa4c3142b31cca354c Mon Sep 17 00:00:00 2001
From: Michael Heck <heckmi@hhu.de>
Date: Fri, 11 Aug 2023 10:00:40 +0000
Subject: [PATCH] Fix in spanless training with none tag targets

---
 DO.example.spanless | 7 +++----
 run_dst.py          | 4 ++++
 utils_dst.py        | 2 +-
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/DO.example.spanless b/DO.example.spanless
index 564b77e..87a703e 100644
--- a/DO.example.spanless
+++ b/DO.example.spanless
@@ -58,15 +58,15 @@ for x in ${SEEDS}; do
 	    if [ "$phase" = 0 ]; then
 		ep=50
 		warmup=0.1
-		args_add_0="--no_append_history"
+		args_add_0="--no_append_history --tag_none_target"
 	    fi
 	    args_add_1=""
 	    if [ "$phase" = 1 ]; then
-		args_add_1="--no_append_history"
+		args_add_1="--no_append_history --tag_none_target"
 	    fi
 	    args_add_2=""
 	    if [ "$phase" = 2 ]; then
-		args_add_2="--cache_suffix=_auto_${x}"
+		args_add_2="--use_none_target_tags --cache_suffix=_auto_${x}"
 	    fi
 
 	    python3 run_dst.py \
@@ -92,7 +92,6 @@ for x in ${SEEDS}; do
 		--fp16 \
 		--value_matching_weight=${VALUE_MATCHING_WEIGHT} \
 		--none_weight=0.1 \
-		--tag_none_target \
 		--use_td \
 		--td_ratio=0.2 \
 		--training_phase=${phase} \
diff --git a/run_dst.py b/run_dst.py
index 0f59253..99eacc7 100644
--- a/run_dst.py
+++ b/run_dst.py
@@ -191,6 +191,8 @@ def main():
     # Spanless training
     parser.add_argument('--tag_none_target', action='store_true',
                         help="Use <none>/[NONE] as target when tagging negative samples")
+    parser.add_argument('--use_none_target_tags', action='store_true',
+                        help="Use <none>/[NONE] as target during spanless training")
     parser.add_argument("--rand_seq_max_len", type=int, default=4,
                         help="Maximum length of random sequences for proto DST training")
     parser.add_argument("--proto_neg_sample_ratio", type=float, default=0.1,
@@ -208,6 +210,8 @@ def main():
     assert args.td_ratio >= 0.0 and args.td_ratio <= 1.0
     assert args.proto_neg_sample_ratio >= 0.0 and args.proto_neg_sample_ratio <= 1.0
     assert args.training_phase in [-1, 0, 1, 2]
+    assert not args.tag_none_target or args.training_phase in [0, 1]
+    assert not args.use_none_target_tags or args.training_phase == 2
 
     task_name = args.task_name.lower()
     if task_name not in PROCESSORS:
diff --git a/utils_dst.py b/utils_dst.py
index 53a833d..00728dd 100644
--- a/utils_dst.py
+++ b/utils_dst.py
@@ -905,7 +905,7 @@ class TrippyDataset(Dataset):
             if automatic_labels is not None:
                 for slot in slot_list:
                     # Case where <none> target was used during pre-training/tagging
-                    if self.args.tag_none_target:
+                    if self.args.use_none_target_tags:
                         if model_specs['MODEL_TYPE'] == 'roberta':
                             a_start = 4
                         else:
-- 
GitLab