#!/bin/bash

# Parameters ------------------------------------------------------

# Dataset to train/evaluate on. Supported values:
#   sim-m              Sim-M dataset
#   sim-r              Sim-R dataset
#   woz2               WOZ 2.0 dataset
#   multiwoz21_legacy  MultiWOZ 2.1 legacy version dataset
#   multiwoz21         MultiWOZ 2.1 dataset
#   unified            MultiWOZ 2.1 in ConvLab3's unified data format
# DATA_DIR and DATASET_CONFIG are derived from TASK below, so switching
# datasets only requires changing this one line.
TASK="multiwoz21"

case "${TASK}" in
    sim-m)
        DATA_DIR="data/simulated-dialogue/sim-M"
        DATASET_CONFIG="dataset_config/sim-m.json"
        ;;
    sim-r)
        DATA_DIR="data/simulated-dialogue/sim-R"
        DATASET_CONFIG="dataset_config/sim-r.json"
        ;;
    woz2)
        DATA_DIR="data/woz2"
        DATASET_CONFIG="dataset_config/woz2.json"
        ;;
    multiwoz21_legacy)
        DATA_DIR="data/MULTIWOZ2.1"
        DATASET_CONFIG="dataset_config/multiwoz21.json"
        ;;
    multiwoz21)
        DATA_DIR="data/multiwoz/data/MultiWOZ_2.1"
        DATASET_CONFIG="dataset_config/multiwoz21.json"
        ;;
    unified)
        # No local data directory is given for the unified format
        # (presumably the data is loaded through ConvLab3 - confirm).
        DATA_DIR=""
        DATASET_CONFIG="dataset_config/unified_multiwoz21.json"
        ;;
    *)
        echo "Unknown TASK: ${TASK}" >&2
        exit 1
        ;;
esac

SEEDS="42" # Whitespace-separated list; one run and output dir per seed
TRAIN_PHASES="-1" # -1: regular training, 0: proto training, 1: tagging, 2: spanless training
VALUE_MATCHING_WEIGHT=0.1 # When 0.0, value matching is not used

# Project paths etc. ----------------------------------------------

# Each seed writes to its own directory "${OUT_DIR}.<seed>" (e.g. results.42).
OUT_DIR=results
# SEEDS is a whitespace-separated list; word splitting here is intentional.
for x in ${SEEDS}; do
    # Fail early: every later step tees its log into this directory.
    mkdir -p -- "${OUT_DIR}.${x}" || {
        echo "Cannot create output directory ${OUT_DIR}.${x}" >&2
        exit 1
    }
done

# Main ------------------------------------------------------------

# Without pipefail, a "python3 ... | tee log" pipeline reports tee's exit
# status, so a crashed run would go unnoticed and later steps would run
# on missing or stale output.
set -o pipefail

# Print an error message to stderr and abort the script.
die() {
    echo "$*" >&2
    exit 1
}

# Print the space-separated confidence thresholds to evaluate with:
# only "1.0" when value matching is disabled (weight == 0), otherwise a
# sweep from 1.0 down to 0.5.
# Arguments: $1 - value matching weight (decimal number, e.g. 0.1)
confidence_thresholds() {
    # Compare numerically via awk: inside [[ ]], ">" is a lexicographic
    # string comparison ("0.00" > "0.0" is true, ".5" > "0.0" is false).
    if LC_ALL=C awk -v w="$1" 'BEGIN { exit !(w + 0 > 0) }'; then
        echo "1.0 0.9 0.8 0.7 0.6 0.5"
    else
        echo "1.0"
    fi
}

# SEEDS is a whitespace-separated list; word splitting is intentional.
for x in ${SEEDS}; do
    seed_dir="${OUT_DIR}.${x}"
    for step in train dev test; do
        # Step-specific run_dst.py arguments and the phases to run.
        # Evaluation steps always run with training phase -1.
        step_args=()
        phases="-1"
        if [[ "$step" == "train" ]]; then
            step_args=(--do_train --predict_type=dev --svd=0.1 --hd=0.1)
            phases=${TRAIN_PHASES}
        elif [[ "$step" == "dev" || "$step" == "test" ]]; then
            step_args=(--do_eval "--predict_type=${step}")
        fi

        # TRAIN_PHASES may list several phases; word splitting is intended.
        for phase in ${phases}; do
            # Hooks for phase-specific extra arguments; empty by default.
            phase_args=()
            case "$phase" in
                0) phase_args=() ;; # proto training
                1) phase_args=() ;; # tagging
                2) phase_args=() ;; # spanless training
            esac

            if ! python3 run_dst.py \
                --task_name="${TASK}" \
                --data_dir="${DATA_DIR}" \
                --dataset_config="${DATASET_CONFIG}" \
                --model_type="roberta" \
                --model_name_or_path="roberta-base" \
                --seed="${x}" \
                --do_lower_case \
                --learning_rate=5e-5 \
                --num_train_epochs=20 \
                --max_seq_length=180 \
                --per_gpu_train_batch_size=32 \
                --per_gpu_eval_batch_size=32 \
                --output_dir="${seed_dir}" \
                --patience=10 \
                --evaluate_during_training \
                --eval_all_checkpoints \
                --warmup_proportion=0.05 \
                --adam_epsilon=1e-6 \
                --weight_decay=0.01 \
                --fp16 \
                --value_matching_weight="${VALUE_MATCHING_WEIGHT}" \
                --none_weight=0.1 \
                --use_td \
                --td_ratio=0.2 \
                --training_phase="${phase}" \
                "${step_args[@]}" \
                "${phase_args[@]}" \
                2>&1 | tee "${seed_dir}/${step}.${phase}.log"; then
                die "run_dst.py failed (seed ${x}, step ${step}, phase ${phase})"
            fi
        done

        if [[ "$step" == "dev" || "$step" == "test" ]]; then
            read -ra thresholds <<< "$(confidence_thresholds "${VALUE_MATCHING_WEIGHT}")"
            for dist_conf_threshold in "${thresholds[@]}"; do
                # The file_list pattern is quoted on purpose so the shell
                # passes it through unexpanded (presumably metric_dst.py
                # expands it itself - confirm).
                if ! python3 metric_dst.py \
                    --dataset_config="${DATASET_CONFIG}" \
                    --confidence_threshold="${dist_conf_threshold}" \
                    --file_list="${seed_dir}/pred_res.${step}*json" \
                    2>&1 | tee "${seed_dir}/eval_pred_${step}.${dist_conf_threshold}.log"; then
                    die "metric_dst.py failed (seed ${x}, step ${step}, threshold ${dist_conf_threshold})"
                fi
            done
        fi
    done
done