diff --git a/run_dst_mtl.py b/run_dst_mtl.py index c077c7bd253c7de5f36117e10fe36dca6f064b5e..8e50306eefaf01ae49a88895f1627a5f7c6025d2 100644 --- a/run_dst_mtl.py +++ b/run_dst_mtl.py @@ -752,12 +752,11 @@ def main(): help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." "See details at https://nvidia.github.io/apex/amp.html") - # TODO parser.add_argument('--mtl_use', action='store_true', help="") - parser.add_argument('--mtl_task_def', type=str, default="/home/heckmi_hhu/tools/trippy/aux_task_def.json", help="") # TODO + parser.add_argument('--mtl_task_def', type=str, default="aux_task_def.json", help="") parser.add_argument('--mtl_train_dataset', type=str, default="", help="cola|mnli|mrpc|qnli|qqp|rte|sst|wnli|squad|squad-v2") - parser.add_argument("--mtl_data_dir", type=str, default="/home/heckmi_hhu/data/glue/canonical_data/bert_base_uncased_lower", help="") # TODO - parser.add_argument("--mtl_ratio", type=float, default=1.0, help="") # TODO + parser.add_argument("--mtl_data_dir", type=str, default="data/aux/bert_base_uncased_lower", help="") + parser.add_argument("--mtl_ratio", type=float, default=1.0, help="") parser.add_argument("--mtl_diff_window", type=int, default=10) parser.add_argument('--mtl_print_loss_diff', action='store_true', help="")