diff --git a/DO.example.recommended b/DO.example.recommended new file mode 100755 index 0000000000000000000000000000000000000000..12491b394d364bc779bc32310dd67dbae0ec4fec --- /dev/null +++ b/DO.example.recommended @@ -0,0 +1,62 @@ +#!/bin/bash + +# Parameters ------------------------------------------------------ + +#TASK="sim-m" +#DATA_DIR="data/simulated-dialogue/sim-M" +#TASK="sim-r" +#DATA_DIR="data/simulated-dialogue/sim-R" +#TASK="woz2" +#DATA_DIR="data/woz2" +TASK="multiwoz21" +DATA_DIR="data/MULTIWOZ2.1" + +# Project paths etc. ---------------------------------------------- + +OUT_DIR=results +mkdir -p ${OUT_DIR} + +# Main ------------------------------------------------------------ + +for step in train dev test; do + args_add="" + if [ "$step" = "train" ]; then + args_add="--do_train --predict_type=dummy" # INFO: For sim-M, we recommend to add "--svd=0.3" + elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then + args_add="--do_eval --predict_type=${step}" + fi + + python3 run_dst.py \ + --task_name=${TASK} \ + --data_dir=${DATA_DIR} \ + --dataset_config=dataset_config/${TASK}.json \ + --model_type="roberta" \ + --model_name_or_path="roberta-base" \ + --do_lower_case \ + --learning_rate=1e-4 \ + --num_train_epochs=10 \ + --max_seq_length=180 \ + --per_gpu_train_batch_size=16 \ + --per_gpu_eval_batch_size=16 \ + --output_dir=${OUT_DIR} \ + --save_epochs=2 \ + --logging_steps=10 \ + --warmup_proportion=0.1 \ + --eval_all_checkpoints \ + --adam_epsilon=1e-6 \ + --weight_decay=0.01 \ + --label_value_repetitions \ + --swap_utterances \ + --append_history \ + --use_history_labels \ + ${args_add} \ + 2>&1 | tee ${OUT_DIR}/${step}.log + + if [ "$step" = "dev" ] || [ "$step" = "test" ]; then + python3 metric_bert_dst.py \ + ${TASK} \ + dataset_config/${TASK}.json \ + "${OUT_DIR}/pred_res.${step}*json" \ + 2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log + fi +done diff --git a/dataset_config/multiwoz21.json b/dataset_config/multiwoz21.json index c420933e40a8ab743700affbdd9b8fe15d11e63a..94f04066669dc8ad821540e65ff36ece4efdd3e3 100644 --- a/dataset_config/multiwoz21.json +++ b/dataset_config/multiwoz21.json @@ -160,6 +160,7 @@ "twelve" ], "architecture": [ + "architectures", "architectural", "architecturally", "architect" @@ -206,7 +207,8 @@ "entertaining" ], "gallery": [ - "museum" + "museum", + "galleries" ], "gastropubs": [ "gastropub" @@ -230,14 +232,23 @@ "club", "clubs" ], + "nightclub": [ + "night club", + "night clubs", + "nightclubs", + "club", + "clubs" + ], "park": [ "parks" ], "pool": [ "swimming pool", + "swimming pools", "swimming", "pools", "swimmingpool", + "swimmingpools", "water", "swim" ], @@ -253,6 +264,7 @@ "pool", "pools", "swimmingpool", + "swimmingpools", "water", "swim" ], @@ -484,7 +496,9 @@ ], "saint catharines college": [ "saint catharine's college", - "saint catharine's" + "saint catharine's", + "saint catherine's college", + "saint catherine's" ], "saint johns college": [ "saint john's college", diff --git a/dataset_config/woz2.json b/dataset_config/woz2.json index 361446e7c9936f99643e8e91ba449578fa26c868..9220a219fb84c807da7e3b13155ecb9d8302abc0 100644 --- a/dataset_config/woz2.json +++ b/dataset_config/woz2.json @@ -18,6 +18,13 @@ "down town", "middle" ], + "centre": [ + "center", + "downtown", + "central", + "down town", + "middle" + ], "south": [ "southern", "southside" @@ -89,7 +96,8 @@ "upscale", "nice", "fine dining", - "expensively priced" + "expensively priced", + "not some cheapie" ], "afghan": [ "afghanistan" diff --git a/dataset_multiwoz21.py b/dataset_multiwoz21.py index 92bdd2c57ad778134af66476dfac1d963df9c81b..d842f68224c073c20441fe43a8dc09030745d06d 100644 --- a/dataset_multiwoz21.py +++ b/dataset_multiwoz21.py @@ -46,6 +46,21 @@ ACTS_DICT = {'taxi-depart': 'taxi-departure', 'booking-day': 'booking-book_day', 'booking-stay': 'booking-book_stay', 'booking-time': 'booking-book_time', + 'taxi-leaveat': 'taxi-leaveAt', + 'taxi-arriveby': 'taxi-arriveBy', + 'train-leaveat': 'train-leaveAt', + 'train-arriveby': 'train-arriveBy', + 'train-bookpeople': 'train-book_people', + 'restaurant-bookpeople': 'restaurant-book_people', + 'restaurant-bookday': 'restaurant-book_day', + 'restaurant-booktime': 'restaurant-book_time', + 'hotel-bookpeople': 'hotel-book_people', + 'hotel-bookday': 'hotel-book_day', + 'hotel-bookstay': 'hotel-book_stay', + 'booking-bookpeople': 'booking-book_people', + 'booking-bookday': 'booking-book_day', + 'booking-bookstay': 'booking-book_stay', + 'booking-booktime': 'booking-book_time' } @@ -109,12 +124,25 @@ def normalize_text(text): text = re.sub("guesthouse", "guest house", text) # Normalization text = re.sub("(^| )b ?& ?b([.,? ]|$)", r"\1bed and breakfast\2", text) # Normalization text = re.sub("bed & breakfast", "bed and breakfast", text) # Normalization + text = re.sub("\t", " ", text) # Error + text = re.sub("\n", " ", text) # Error return text # This should only contain label normalizations. All other mappings should # be defined in LABEL_MAPS. def normalize_label(slot, value_label): + # Normalization of capitalization + if isinstance(value_label, str): + value_label = value_label.lower().strip() + elif isinstance(value_label, list): + if len(value_label) > 1: + value_label = value_label[0] # TODO: Workaround. Note that Multiwoz 2.2 supports variants directly in the labels. + elif len(value_label) == 1: + value_label = value_label[0] + elif len(value_label) == 0: + value_label = "" + # Normalization of empty slots if value_label == '' or value_label == "not mentioned": return "none" @@ -125,6 +153,7 @@ def normalize_label(slot, value_label): # Normalization if "type" in slot or "name" in slot or "destination" in slot or "departure" in slot: + value_label = re.sub(" ?'s", "s", value_label) value_label = re.sub("guesthouse", "guest house", value_label) # Map to boolean slots @@ -134,6 +163,8 @@ def normalize_label(slot, value_label): if value_label == "no": return "false" if slot == 'hotel-type': + if value_label in ["bed and breakfast", "guest houses"]: + value_label = "guest house" if value_label == "hotel": return "true" if value_label == "guest house": @@ -289,10 +320,14 @@ def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, se def tokenize(utt): utt_lower = convert_to_unicode(utt).lower() utt_lower = normalize_text(utt_lower) - utt_tok = [tok for tok in map(str.strip, re.split("(\W+)", utt_lower)) if len(tok) > 0] + utt_tok = utt_to_token(utt_lower) return utt_tok +def utt_to_token(utt): + return [tok for tok in map(lambda x: re.sub(" ", "", x), re.split("(\W+)", utt)) if len(tok) > 0] + + def create_examples(input_file, acts_file, set_type, slot_list, label_maps={}, append_history=False, diff --git a/dataset_sim.py b/dataset_sim.py index e8caa3b1e486c581e3bfcfc205be4b384d71a305..43fe298dfbeecb3738673565dcf520c1ec108145 100644 --- a/dataset_sim.py +++ b/dataset_sim.py @@ -34,7 +34,7 @@ def load_acts(input_file): # Only process, if turn has annotation if "system_acts" in t: for a in t["system_acts"]: - if "value" in a: + if "value" in a and a["type"] not in ["NEGATE", "NOTIFY_FAILURE"]: key = d_id, t_id, a["slot"] # In case of multiple mentioned values... # ... Option 1: Keep first informed value diff --git a/dataset_woz2.py b/dataset_woz2.py index 3c605d8bd7d4c3a95db1965c846e5ea11bbdfc78..d13b31720e04b0feb56dcc1077b28d64809f7349 100644 --- a/dataset_woz2.py +++ b/dataset_woz2.py @@ -24,7 +24,7 @@ from utils_dst import (DSTExample, convert_to_unicode) LABEL_MAPS = {} # Loaded from file -LABEL_FIX = {'centre': 'center', 'areas': 'area', 'phone number': 'number', 'price range': 'price_range'} +LABEL_FIX = {'areas': 'area', 'phone number': 'number', 'price range': 'price_range', 'center': 'centre', 'east side': 'east', 'corsican': 'corsica'} def delex_utt(utt, values, unk_token="[UNK]"): @@ -93,10 +93,14 @@ def get_turn_label(label, sys_utt_tok, usr_utt_tok, slot_last_occurrence): def tokenize(utt): utt_lower = convert_to_unicode(utt).lower() - utt_tok = [tok for tok in map(str.strip, re.split("(\W+)", utt_lower)) if len(tok) > 0] + utt_tok = utt_to_token(utt_lower) return utt_tok +def utt_to_token(utt): + return [tok for tok in map(lambda x: re.sub(" ", "", x), re.split("(\W+)", utt)) if len(tok) > 0] + + def create_examples(input_file, set_type, slot_list, label_maps={}, append_history=False,