diff --git a/DO.example.spanless b/DO.example.spanless index 564b77e01cfb6a8d81e3dff6b999296ffb687458..87a703ee0e63c7c05093524a3168d11880626c3d 100644 --- a/DO.example.spanless +++ b/DO.example.spanless @@ -58,15 +58,15 @@ for x in ${SEEDS}; do if [ "$phase" = 0 ]; then ep=50 warmup=0.1 - args_add_0="--no_append_history" + args_add_0="--no_append_history --tag_none_target" fi args_add_1="" if [ "$phase" = 1 ]; then - args_add_1="--no_append_history" + args_add_1="--no_append_history --tag_none_target" fi args_add_2="" if [ "$phase" = 2 ]; then - args_add_2="--cache_suffix=_auto_${x}" + args_add_2="--use_none_target_tags --cache_suffix=_auto_${x}" fi python3 run_dst.py \ @@ -92,7 +92,6 @@ for x in ${SEEDS}; do --fp16 \ --value_matching_weight=${VALUE_MATCHING_WEIGHT} \ --none_weight=0.1 \ - --tag_none_target \ --use_td \ --td_ratio=0.2 \ --training_phase=${phase} \ diff --git a/run_dst.py b/run_dst.py index 0f5925345f545175f6166a22e2b4d0f3356a772e..99eacc722602cae18161b46adb841085dbebe773 100644 --- a/run_dst.py +++ b/run_dst.py @@ -191,6 +191,8 @@ def main(): # Spanless training parser.add_argument('--tag_none_target', action='store_true', help="Use <none>/[NONE] as target when tagging negative samples") + parser.add_argument('--use_none_target_tags', action='store_true', + help="Use <none>/[NONE] as target during spanless training") parser.add_argument("--rand_seq_max_len", type=int, default=4, help="Maximum length of random sequences for proto DST training") parser.add_argument("--proto_neg_sample_ratio", type=float, default=0.1, @@ -208,6 +210,8 @@ def main(): assert args.td_ratio >= 0.0 and args.td_ratio <= 1.0 assert args.proto_neg_sample_ratio >= 0.0 and args.proto_neg_sample_ratio <= 1.0 assert args.training_phase in [-1, 0, 1, 2] + assert not args.tag_none_target or args.training_phase in [0, 1] + assert not args.use_none_target_tags or args.training_phase == 2 task_name = args.task_name.lower() if task_name not in PROCESSORS: diff --git a/utils_dst.py b/utils_dst.py index 53a833db41c3d338a69882d312022b87bbf8fce6..00728dd0488a6f22dc4e315eb6bbabcc4a0804b4 100644 --- a/utils_dst.py +++ b/utils_dst.py @@ -905,7 +905,7 @@ class TrippyDataset(Dataset): if automatic_labels is not None: for slot in slot_list: # Case where <none> target was used during pre-training/tagging - if self.args.tag_none_target: + if self.args.use_none_target_tags: if model_specs['MODEL_TYPE'] == 'roberta': a_start = 4 else: