From 9cc922664b084bed266f3ffa4c3142b31cca354c Mon Sep 17 00:00:00 2001 From: Michael Heck <heckmi@hhu.de> Date: Fri, 11 Aug 2023 10:00:40 +0000 Subject: [PATCH] Fix in spanless training with none tag targets --- DO.example.spanless | 7 +++---- run_dst.py | 4 ++++ utils_dst.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/DO.example.spanless b/DO.example.spanless index 564b77e..87a703e 100644 --- a/DO.example.spanless +++ b/DO.example.spanless @@ -58,15 +58,15 @@ for x in ${SEEDS}; do if [ "$phase" = 0 ]; then ep=50 warmup=0.1 - args_add_0="--no_append_history" + args_add_0="--no_append_history --tag_none_target" fi args_add_1="" if [ "$phase" = 1 ]; then - args_add_1="--no_append_history" + args_add_1="--no_append_history --tag_none_target" fi args_add_2="" if [ "$phase" = 2 ]; then - args_add_2="--cache_suffix=_auto_${x}" + args_add_2="--use_none_target_tags --cache_suffix=_auto_${x}" fi python3 run_dst.py \ @@ -92,7 +92,6 @@ for x in ${SEEDS}; do --fp16 \ --value_matching_weight=${VALUE_MATCHING_WEIGHT} \ --none_weight=0.1 \ - --tag_none_target \ --use_td \ --td_ratio=0.2 \ --training_phase=${phase} \ diff --git a/run_dst.py b/run_dst.py index 0f59253..99eacc7 100644 --- a/run_dst.py +++ b/run_dst.py @@ -191,6 +191,8 @@ def main(): # Spanless training parser.add_argument('--tag_none_target', action='store_true', help="Use <none>/[NONE] as target when tagging negative samples") + parser.add_argument('--use_none_target_tags', action='store_true', + help="Use <none>/[NONE] as target during spanless training") parser.add_argument("--rand_seq_max_len", type=int, default=4, help="Maximum length of random sequences for proto DST training") parser.add_argument("--proto_neg_sample_ratio", type=float, default=0.1, @@ -208,6 +210,8 @@ def main(): assert args.td_ratio >= 0.0 and args.td_ratio <= 1.0 assert args.proto_neg_sample_ratio >= 0.0 and args.proto_neg_sample_ratio <= 1.0 assert args.training_phase in [-1, 0, 1, 2] + assert not args.tag_none_target or args.training_phase in [0, 1] + assert not args.use_none_target_tags or args.training_phase == 2 task_name = args.task_name.lower() if task_name not in PROCESSORS: diff --git a/utils_dst.py b/utils_dst.py index 53a833d..00728dd 100644 --- a/utils_dst.py +++ b/utils_dst.py @@ -905,7 +905,7 @@ class TrippyDataset(Dataset): if automatic_labels is not None: for slot in slot_list: # Case where <none> target was used during pre-training/tagging - if self.args.tag_none_target: + if self.args.use_none_target_tags: if model_specs['MODEL_TYPE'] == 'roberta': a_start = 4 else: -- GitLab