From 92501f68f55be7d9242bed73ae1564b9fa6fda01 Mon Sep 17 00:00:00 2001 From: Michael Heck <heckmi_hhu@heckmi-dev-pytorch110-cpu-1.europe-west1-b.c.hhu-cs-ds-prod.internal> Date: Tue, 8 Aug 2023 20:02:22 +0000 Subject: [PATCH] initial commit --- .gitattributes | 118 ++ DO.example | 115 ++ DO.example.spanless | 120 +++ LICENSE | 201 ++++ README.md | 89 +- data_processors.py | 231 ++++ dataset_config/multiwoz21.json | 1375 ++++++++++++++++++++++++ dataset_config/sim-m.json | 20 + dataset_config/sim-r.json | 24 + dataset_config/unified_multiwoz21.json | 1346 +++++++++++++++++++++++ dataset_config/woz2.json | 228 ++++ dataset_multiwoz21.py | 680 ++++++++++++ dataset_multiwoz21_legacy.py | 358 ++++++ dataset_sim.py | 268 +++++ dataset_unified.py | 350 ++++++ dataset_woz2.py | 288 +++++ dst_proto.py | 341 ++++++ dst_tag.py | 188 ++++ dst_train.py | 763 +++++++++++++ metric_dst.py | 566 ++++++++++ modeling_dst.py | 415 +++++++ run_dst.py | 431 ++++++++ tensorlistdataset.py | 57 + utils_dst.py | 1056 ++++++++++++++++++ utils_run.py | 146 +++ 25 files changed, 9770 insertions(+), 4 deletions(-) create mode 100644 .gitattributes create mode 100644 DO.example create mode 100644 DO.example.spanless create mode 100644 LICENSE create mode 100644 data_processors.py create mode 100644 dataset_config/multiwoz21.json create mode 100644 dataset_config/sim-m.json create mode 100644 dataset_config/sim-r.json create mode 100644 dataset_config/unified_multiwoz21.json create mode 100644 dataset_config/woz2.json create mode 100644 dataset_multiwoz21.py create mode 100644 dataset_multiwoz21_legacy.py create mode 100644 dataset_sim.py create mode 100644 dataset_unified.py create mode 100644 dataset_woz2.py create mode 100644 dst_proto.py create mode 100644 dst_tag.py create mode 100644 dst_train.py create mode 100644 metric_dst.py create mode 100644 modeling_dst.py create mode 100644 run_dst.py create mode 100644 tensorlistdataset.py create mode 100644 utils_dst.py create mode 100644 utils_run.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..49b6d14 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,118 @@ +# Store binaries in LFS +## Custom paths +results/ filter=lfs diff=lfs merge=lfs -text +data/ filter=lfs diff=lfs merge=lfs -text + +## Archive/Compressed +*.7z filter=lfs diff=lfs merge=lfs -text +*.cpio filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.iso filter=lfs diff=lfs merge=lfs -text +*.bz filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.bzip filter=lfs diff=lfs merge=lfs -text +*.bzip2 filter=lfs diff=lfs merge=lfs -text +*.cab filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.gzip filter=lfs diff=lfs merge=lfs -text +*.lz filter=lfs diff=lfs merge=lfs -text +*.lzma filter=lfs diff=lfs merge=lfs -text +*.lzo filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.z filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.ace filter=lfs diff=lfs merge=lfs -text +*.dmg filter=lfs diff=lfs merge=lfs -text +*.dd filter=lfs diff=lfs merge=lfs -text +*.apk filter=lfs diff=lfs merge=lfs -text +*.ear filter=lfs diff=lfs merge=lfs -text +*.jar filter=lfs diff=lfs merge=lfs -text +*.deb filter=lfs diff=lfs merge=lfs -text +*.cue filter=lfs diff=lfs merge=lfs -text +*.dump filter=lfs diff=lfs merge=lfs -text + +## Image +*.jpg filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.psd filter=lfs diff=lfs merge=lfs -text +*.bmp filter=lfs diff=lfs merge=lfs -text +*.dng filter=lfs diff=lfs merge=lfs -text +*.cdr filter=lfs diff=lfs merge=lfs -text +*.indd filter=lfs diff=lfs merge=lfs -text +*.tiff filter=lfs diff=lfs merge=lfs -text +*.tif filter=lfs diff=lfs merge=lfs -text +*.psp filter=lfs diff=lfs merge=lfs -text +*.tga filter=lfs diff=lfs merge=lfs -text +*.eps filter=lfs diff=lfs merge=lfs -text +*.svg filter=lfs diff=lfs merge=lfs -text + +## Documents +*.pdf filter=lfs diff=lfs merge=lfs -text +*.doc filter=lfs diff=lfs merge=lfs -text +*.docx filter=lfs diff=lfs merge=lfs -text +*.xls filter=lfs diff=lfs merge=lfs -text +*.xlsx filter=lfs diff=lfs merge=lfs -text +*.ppt filter=lfs diff=lfs merge=lfs -text +*.pptx filter=lfs diff=lfs merge=lfs -text +*.ppz filter=lfs diff=lfs merge=lfs -text +*.dot filter=lfs diff=lfs merge=lfs -text +*.dotx filter=lfs diff=lfs merge=lfs -text +*.lwp filter=lfs diff=lfs merge=lfs -text +*.odm filter=lfs diff=lfs merge=lfs -text +*.odt filter=lfs diff=lfs merge=lfs -text +*.ott filter=lfs diff=lfs merge=lfs -text +*.ods filter=lfs diff=lfs merge=lfs -text +*.ots filter=lfs diff=lfs merge=lfs -text +*.odp filter=lfs diff=lfs merge=lfs -text +*.otp filter=lfs diff=lfs merge=lfs -text +*.odg filter=lfs diff=lfs merge=lfs -text +*.otg filter=lfs diff=lfs merge=lfs -text +*.wps filter=lfs diff=lfs merge=lfs -text +*.wpd filter=lfs diff=lfs merge=lfs -text +*.wpt filter=lfs diff=lfs merge=lfs -text +*.xps filter=lfs diff=lfs merge=lfs -text +*.ttf filter=lfs diff=lfs merge=lfs -text +*.otf filter=lfs diff=lfs merge=lfs -text +*.dvi filter=lfs diff=lfs merge=lfs -text +*.pages filter=lfs diff=lfs merge=lfs -text +*.key filter=lfs diff=lfs merge=lfs -text + +## Audio/Video +*.mpg filter=lfs diff=lfs merge=lfs -text +*.mpeg filter=lfs diff=lfs merge=lfs -text +*.mp3 filter=lfs diff=lfs merge=lfs -text +*.mp4 filter=lfs diff=lfs merge=lfs -text +*.avi filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text +*.mkv filter=lfs diff=lfs merge=lfs -text +*.3gp filter=lfs diff=lfs merge=lfs -text +*.flv filter=lfs diff=lfs merge=lfs -text +*.m4v filter=lfs diff=lfs merge=lfs -text +*.ogg filter=lfs diff=lfs merge=lfs -text +*.mov filter=lfs diff=lfs merge=lfs -text +*.wmv filter=lfs diff=lfs merge=lfs -text +*.webm filter=lfs diff=lfs merge=lfs -text + +## VM +*.vfd filter=lfs diff=lfs merge=lfs -text +*.vhd filter=lfs diff=lfs merge=lfs -text +*.vmdk filter=lfs diff=lfs merge=lfs -text +*.vmsd filter=lfs diff=lfs merge=lfs -text +*.vmsn filter=lfs diff=lfs merge=lfs -text +*.vmss filter=lfs diff=lfs merge=lfs -text +*.dsk filter=lfs diff=lfs merge=lfs -text +*.vdi filter=lfs diff=lfs merge=lfs -text +*.cow filter=lfs diff=lfs merge=lfs -text +*.qcow filter=lfs diff=lfs merge=lfs -text +*.qcow2 filter=lfs diff=lfs merge=lfs -text +*.qed filter=lfs diff=lfs merge=lfs -text + +## Other +*.exe filter=lfs diff=lfs merge=lfs -text +*.sxi filter=lfs diff=lfs merge=lfs -text +*.dat filter=lfs diff=lfs merge=lfs -text +*.data filter=lfs diff=lfs merge=lfs -text diff --git a/DO.example b/DO.example new file mode 100644 index 0000000..dfc3d72 --- /dev/null +++ b/DO.example @@ -0,0 +1,115 @@ +#!/bin/bash + +# Parameters ------------------------------------------------------ + +# --- Sim-M dataset +#TASK="sim-m" +#DATA_DIR="data/simulated-dialogue/sim-M" +#DATASET_CONFIG="dataset_config/sim-m.json" +# --- Sim-R dataset +#TASK="sim-r" +#DATA_DIR="data/simulated-dialogue/sim-R" +#DATASET_CONFIG="dataset_config/sim-r.json" +# --- WOZ 2.0 dataset +#TASK="woz2" +#DATA_DIR="data/woz2" +#DATASET_CONFIG="dataset_config/woz2.json" +# --- MultiWOZ 2.1 legacy version dataset +#TASK="multiwoz21_legacy" +#DATA_DIR="data/MULTIWOZ2.1" +#DATASET_CONFIG="dataset_config/multiwoz21.json" +# --- MultiWOZ 2.1 dataset +TASK="multiwoz21" +DATA_DIR="data/multiwoz/data/MultiWOZ_2.1" +DATASET_CONFIG="dataset_config/multiwoz21.json" +# --- MultiWOZ 2.1 in ConvLab3's unified data format +#TASK="unified" +#DATA_DIR="" +#DATASET_CONFIG="dataset_config/unified_multiwoz21.json" + +SEEDS="42" +TRAIN_PHASES="-1" # -1: regular training, 0: proto training, 1: tagging, 2: spanless training +VALUE_MATCHING_WEIGHT=0.1 # When 0.0, value matching is not used + +# Project paths etc. ---------------------------------------------- + +OUT_DIR=results +for x in ${SEEDS}; do + mkdir -p ${OUT_DIR}.${x} +done + +# Main ------------------------------------------------------------ + +for x in ${SEEDS}; do + for step in train dev test; do + args_add="" + phases="-1" + if [ "$step" = "train" ]; then + args_add="--do_train --predict_type=dev --svd=0.1 --hd=0.1" + phases=${TRAIN_PHASES} + elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then + args_add="--do_eval --predict_type=${step}" + fi + + for phase in ${phases}; do + args_add_0="" + if [ "$phase" = 0 ]; then + args_add_0="" + fi + args_add_1="" + if [ "$phase" = 1 ]; then + args_add_1="" + fi + args_add_2="" + if [ "$phase" = 2 ]; then + args_add_2="" + fi + + python3 run_dst.py \ + --task_name=${TASK} \ + --data_dir=${DATA_DIR} \ + --dataset_config=${DATASET_CONFIG} \ + --model_type="roberta" \ + --model_name_or_path="roberta-base" \ + --seed=${x} \ + --do_lower_case \ + --learning_rate=5e-5 \ + --num_train_epochs=20 \ + --max_seq_length=180 \ + --per_gpu_train_batch_size=32 \ + --per_gpu_eval_batch_size=32 \ + --output_dir=${OUT_DIR}.${x} \ + --patience=10 \ + --evaluate_during_training \ + --eval_all_checkpoints \ + --warmup_proportion=0.05 \ + --adam_epsilon=1e-6 \ + --weight_decay=0.01 \ + --fp16 \ + --value_matching_weight=${VALUE_MATCHING_WEIGHT} \ + --none_weight=0.1 \ + --use_td \ + --td_ratio=0.2 \ + --training_phase=${phase} \ + ${args_add} \ + ${args_add_0} \ + ${args_add_1} \ + ${args_add_2} \ + 2>&1 | tee ${OUT_DIR}.${x}/${step}.${phase}.log + done + + if [ "$step" = "dev" -o "$step" = "test" ]; then + confidence=1.0 + if [[ ${VALUE_MATCHING_WEIGHT} > 0.0 ]]; then + confidence="1.0 0.9 0.8 0.7 0.6 0.5" + fi + for dist_conf_threshold in ${confidence}; do + python3 metric_dst.py \ + --dataset_config=${DATASET_CONFIG} \ + --confidence_threshold=${dist_conf_threshold} \ + --file_list="${OUT_DIR}.${x}/pred_res.${step}*json" \ + 2>&1 | tee ${OUT_DIR}.${x}/eval_pred_${step}.${dist_conf_threshold}.log + done + fi + done +done diff --git a/DO.example.spanless b/DO.example.spanless new file mode 100644 index 0000000..564b77e --- /dev/null +++ b/DO.example.spanless @@ -0,0 +1,120 @@ +#!/bin/bash + +# Parameters ------------------------------------------------------ + +# --- Sim-M dataset +#TASK="sim-m" +#DATA_DIR="data/simulated-dialogue/sim-M" +#DATASET_CONFIG="dataset_config/sim-m.json" +# --- Sim-R dataset +#TASK="sim-r" +#DATA_DIR="data/simulated-dialogue/sim-R" +#DATASET_CONFIG="dataset_config/sim-r.json" +# --- WOZ 2.0 dataset +#TASK="woz2" +#DATA_DIR="data/woz2" +#DATASET_CONFIG="dataset_config/woz2.json" +# --- MultiWOZ 2.1 legacy version dataset +#TASK="multiwoz21_legacy" +#DATA_DIR="data/MULTIWOZ2.1" +#DATASET_CONFIG="dataset_config/multiwoz21.json" +# --- MultiWOZ 2.1 dataset +TASK="multiwoz21" +DATA_DIR="data/multiwoz/data/MultiWOZ_2.1" +DATASET_CONFIG="dataset_config/multiwoz21.json" +# --- MultiWOZ 2.1 in ConvLab3's unified data format +#TASK="unified" +#DATA_DIR="" +#DATASET_CONFIG="dataset_config/unified_multiwoz21.json" + +SEEDS="42" +TRAIN_PHASES="-1" # -1: regular training, 0: proto training, 1: tagging, 2: spanless training +VALUE_MATCHING_WEIGHT=0.1 # When 0.0, value matching is not used + +# Project paths etc. ---------------------------------------------- + +OUT_DIR=results +for x in ${SEEDS}; do + mkdir -p ${OUT_DIR}.${x} +done + +# Main ------------------------------------------------------------ + +for x in ${SEEDS}; do + for step in train dev test; do + args_add="" + phases="-1" + if [ "$step" = "train" ]; then + args_add="--do_train --predict_type=dev --svd=0.1 --hd=0.1" + phases=${TRAIN_PHASES} + elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then + args_add="--do_eval --predict_type=${step} --no_cache" + fi + + for phase in ${phases}; do + ep=20 + warmup=0.05 + args_add_0="" + if [ "$phase" = 0 ]; then + ep=50 + warmup=0.1 + args_add_0="--no_append_history" + fi + args_add_1="" + if [ "$phase" = 1 ]; then + args_add_1="--no_append_history" + fi + args_add_2="" + if [ "$phase" = 2 ]; then + args_add_2="--cache_suffix=_auto_${x}" + fi + + python3 run_dst.py \ + --task_name=${TASK} \ + --data_dir=${DATA_DIR} \ + --dataset_config=${DATASET_CONFIG} \ + --model_type="roberta" \ + --model_name_or_path="roberta-base" \ + --seed=${x} \ + --do_lower_case \ + --learning_rate=5e-5 \ + --num_train_epochs=${ep} \ + --max_seq_length=180 \ + --per_gpu_train_batch_size=32 \ + --per_gpu_eval_batch_size=32 \ + --output_dir=${OUT_DIR}.${x} \ + --patience=10 \ + --evaluate_during_training \ + --eval_all_checkpoints \ + --warmup_proportion=${warmup} \ + --adam_epsilon=1e-6 \ + --weight_decay=0.01 \ + --fp16 \ + --value_matching_weight=${VALUE_MATCHING_WEIGHT} \ + --none_weight=0.1 \ + --tag_none_target \ + --use_td \ + --td_ratio=0.2 \ + --training_phase=${phase} \ + ${args_add} \ + ${args_add_0} \ + ${args_add_1} \ + ${args_add_2} \ + 2>&1 | tee ${OUT_DIR}.${x}/${step}.${phase}.log + done + + if [ "$step" = "dev" -o "$step" = "test" ]; then + confidence=1.0 + if [[ ${VALUE_MATCHING_WEIGHT} > 0.0 ]]; then + confidence="1.0 0.9 0.8 0.7 0.6 0.5" + fi + for dist_conf_threshold in ${confidence}; do + python3 metric_dst.py \ + --dataset_config=${DATASET_CONFIG} \ + --confidence_threshold=${dist_conf_threshold} \ + --file_list="${OUT_DIR}.${x}/pred_res.${step}*json" \ + 2>&1 | tee ${OUT_DIR}.${x}/eval_pred_${step}.${dist_conf_threshold}.log + done + fi + done +done diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e38336a --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020 Heinrich Heine University Duesseldorf + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index fff1738..429681a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,87 @@ -# TripPy-R - Public +## Introduction -This is the code repository for our journal paper titled "Robust Dialogue State Tracking with Weak Supervision and Sparse Data", which is accepted for publication in TACL. -The pre-print is available at [arXiv](https://doi.org/10.48550/arXiv.2202.03354). +Generalising dialogue state tracking (DST) to new data is especially challenging due to the strong reliance on abundant and fine-grained supervision during training. Sample sparsity, distributional shift and the occurrence of new concepts and topics frequently lead to severe performance degradation during inference. TripPy-R (pronounced "Trippier"), robust triple copy strategy DST, can use a training strategy to build extractive DST models without the need for fine-grained manual span labels ("spanless training"). Further, two novel input-level dropout methods mitigate the negative impact of sample sparsity. TripPy-R uses a new model architecture with a unified encoder that supports value as well as slot independence by leveraging the attention mechanism, making it zero-shot capable. The framework combines the strengths of triple copy strategy DST and value matching to benefit from complementary predictions without violating the principle of ontology independence. In our paper we demonstrate that an extractive DST model can be trained without manual span labels. Our architecture and training strategies improve robustness towards sample sparsity, new concepts and topics, leading to state-of-the-art performance on a range of benchmarks. -We will publish the code here after our journal paper is published in TACL. +## Recent updates + +- 2023.08.08: Initial commit + +## How to run + +Two example scripts are provided for how to use TripPy-R. + +`DO.example` will train and evaluate a model with recommended settings with the default supervised training strategy. + +`DO.example.spanless` will train and evaluate a model with recommended settings with the novel spanless training strategy. The training consists of three steps: 1) Training a proto-DST that learns to tag the positions of queried subsequences in an input sequence. 2) Applying the proto-DST to tag the positions of slot-value occurrences in the training data. 3) Training the DST using the automatic labels produced by the previous step. + +See below table for expected performance per dataset and training strategy. Our scripts use the parameters that were used for experiments in our paper "Robust Dialogue State Tracking with Weak Supervision and Sparse Data". Thus, performance will be similar to the reported ones. For more challenging datasets with longer dialogues, better performance may be achieved by using the maximum sequence length of 512. + +## Trouble-shooting + +When conducting spanless training, the training of the proto-DST (Step 1 of 3, see above) is rather sensitive to the training hyperparameters such as learning rate, warm-up ratio and max. number of epochs, as well as the random model initialization. We recommend the hyperparameters as listed in the example script above. If the proto-DST's tagging performance (Step 2 of 3) remains below expectations for one or more slots, try running the training with a different random initialization, i.e. pick a different random seed, while using the recommended hyperparameters. + +## Datasets + +Supported datasets are: +- sim-M (https://github.com/google-research-datasets/simulated-dialogue.git) +- sim-R (https://github.com/google-research-datasets/simulated-dialogue.git) +- WOZ 2.0 (see https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public.git) +- MultiWOZ 2.0 (https://github.com/budzianowski/multiwoz.git) +- MultiWOZ 2.1 (https://github.com/budzianowski/multiwoz.git) +- MultiWOZ 2.1 legacy version (see https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public.git) +- MultiWOZ 2.2 (https://github.com/budzianowski/multiwoz.git) +- MultiWOZ 2.3 (https://github.com/lexmen318/MultiWOZ-coref.git) +- MultiWOZ 2.4 (https://github.com/smartyfh/MultiWOZ2.4.git) +- Unified data format (currently supported: MultiWOZ 2.1) (see https://github.com/ConvLab/ConvLab-3) + +See the [README file](https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public/-/blob/master/data/README.md) in 'data/' in the original [TripPy repo](https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public) for more details how to obtain and prepare the datasets for use in TripPy-R. + +The ```--task_name``` is +- 'sim-m', for sim-M +- 'sim-r', for sim-R +- 'woz2', for WOZ 2.0 +- 'multiwoz21', for MultiWOZ 2.0-2.4 +- 'multiwoz21_legacy', for MultiWOZ 2.1 legacy version +- 'unified', for ConvLab-3's unified data format + +With a sequence length of 180, you should expect the following average JGA: + +| Dataset | Normal training | Spanless training | +| ------ | ------ | ------ | +| MultiWOZ 2.0 | 51% | tbd | +| MultiWOZ 2.1 | 56% | 55% | +| MultiWOZ 2.1 legacy | 56% | 55% | +| MultiWOZ 2.2 | 56% | tbd | +| MultiWOZ 2.3 | 62% | tbd | +| MultiWOZ 2.4 | 69% | tbd | +| sim-M | 95% | 95% | +| sim-R | 92% | 92% | +| WOZ 2.0 | 92% | 91% | + +## Requirements + +- torch (tested: 1.12.1) +- transformers (tested: 4.18.0) +- tensorboardX (tested: 2.5.1) + +## Citation + +This work is published as [Robust Dialogue State Tracking with Weak Supervision and Sparse Data ](https://doi.org/10.1162/tacl_a_00513) + +If you use TripPy-R in your own work, please cite our work as follows: + +``` +@article{heck-etal-2022-robust, + title = "Robust Dialogue State Tracking with Weak Supervision and Sparse Data", + author = "Heck, Michael and Lubis, Nurul and van Niekerk, Carel and + Feng, Shutong and Geishauser, Christian and Lin, Hsien-Chin and Ga{\v{s}}i{\'c}, Milica", + journal = "Transactions of the Association for Computational Linguistics", + volume = "10", + year = "2022", + address = "Cambridge, MA", + publisher = "MIT Press", + url = "https://aclanthology.org/2022.tacl-1.68", + doi = "10.1162/tacl_a_00513", + pages = "1175--1192", +} +``` diff --git a/data_processors.py b/data_processors.py new file mode 100644 index 0000000..58cf927 --- /dev/null +++ b/data_processors.py @@ -0,0 +1,231 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json + +import dataset_woz2 +import dataset_sim +import dataset_multiwoz21 +import dataset_multiwoz21_legacy +import dataset_unified + + +class DataProcessor(object): + data_dir = "" + dataset_name = "" + class_types = [] + slot_list = {} + noncategorical = [] + boolean = [] + label_maps = {} + value_list = {'train': {}, 'dev': {}, 'test': {}} + + def __init__(self, dataset_config, data_dir): + self.data_dir = data_dir + # Load dataset config file. + with open(dataset_config, "r", encoding='utf-8') as f: + raw_config = json.load(f) + self.dataset_name = raw_config['dataset_name'] if 'dataset_name' in raw_config else "" + self.class_types = raw_config['class_types'] # Required + self.slot_list = raw_config['slots'] if 'slots' in raw_config else {} + self.noncategorical = raw_config['noncategorical'] if 'noncategorical' in raw_config else [] + self.boolean = raw_config['boolean'] if 'boolean' in raw_config else [] + self.label_maps = raw_config['label_maps'] if 'label_maps' in raw_config else {} + # If not slot list is provided, generate from data. + if len(self.slot_list) == 0: + self.slot_list = self._get_slot_list() + + def _add_dummy_value_to_value_list(self): + for dset in self.value_list: + for s in self.value_list[dset]: + if len(self.value_list[dset][s]) == 0: + self.value_list[dset][s] = {'dummy': 1} + + def _remove_dummy_value_from_value_list(self): + for dset in self.value_list: + for s in self.value_list[dset]: + if self.value_list[dset][s] == {'dummy': 1}: + self.value_list[dset][s] = {} + + def _merge_with_train_value_list(self, new_value_list): + self._remove_dummy_value_from_value_list() + for s in new_value_list: + if s not in self.value_list['train']: + self.value_list['train'][s] = new_value_list[s] + else: + for v in new_value_list[s]: + if v not in self.value_list['train'][s]: + self.value_list['train'][s][v] = new_value_list[s][v] + else: + self.value_list['train'][s][v] += new_value_list[s][v] + self._add_dummy_value_to_value_list() + + def _get_slot_list(self): + raise NotImplementedError() + + def prediction_normalization(self, slot, value): + return value + + def get_train_examples(self): + raise NotImplementedError() + + def get_dev_examples(self): + raise NotImplementedError() + + def get_test_examples(self): + raise NotImplementedError() + + +class Woz2Processor(DataProcessor): + def __init__(self, dataset_config, data_dir): + super(Woz2Processor, self).__init__(dataset_config, data_dir) + self.value_list['train'] = dataset_woz2.get_value_list(os.path.join(self.data_dir, 'woz_train_en.json'), + self.slot_list) + self.value_list['dev'] = dataset_woz2.get_value_list(os.path.join(self.data_dir, 'woz_validate_en.json'), + self.slot_list) + self.value_list['test'] = dataset_woz2.get_value_list(os.path.join(self.data_dir, 'woz_test_en.json'), + self.slot_list) + + def get_train_examples(self, args): + return dataset_woz2.create_examples(os.path.join(self.data_dir, 'woz_train_en.json'), + 'train', self.slot_list, self.label_maps, **args) + + def get_dev_examples(self, args): + return dataset_woz2.create_examples(os.path.join(self.data_dir, 'woz_validate_en.json'), + 'dev', self.slot_list, self.label_maps, **args) + + def get_test_examples(self, args): + return dataset_woz2.create_examples(os.path.join(self.data_dir, 'woz_test_en.json'), + 'test', self.slot_list, self.label_maps, **args) + + +class Multiwoz21Processor(DataProcessor): + def __init__(self, dataset_config, data_dir): + super(Multiwoz21Processor, self).__init__(dataset_config, data_dir) + self.value_list['train'] = dataset_multiwoz21.get_value_list(os.path.join(self.data_dir, 'train_dials.json'), + self.slot_list) + self.value_list['dev'] = dataset_multiwoz21.get_value_list(os.path.join(self.data_dir, 'val_dials.json'), + self.slot_list) + self.value_list['test'] = dataset_multiwoz21.get_value_list(os.path.join(self.data_dir, 'test_dials.json'), + self.slot_list) + self._add_dummy_value_to_value_list() + + def prediction_normalization(self, slot, value): + return dataset_multiwoz21.prediction_normalization(slot, value) + + def get_train_examples(self, args): + return dataset_multiwoz21.create_examples(os.path.join(self.data_dir, 'train_dials.json'), + 'train', self.class_types, self.slot_list, self.label_maps, **args) + + def get_dev_examples(self, args): + return dataset_multiwoz21.create_examples(os.path.join(self.data_dir, 'val_dials.json'), + 'dev', self.class_types, self.slot_list, self.label_maps, **args) + + def get_test_examples(self, args): + return dataset_multiwoz21.create_examples(os.path.join(self.data_dir, 'test_dials.json'), + 'test', self.class_types, self.slot_list, self.label_maps, **args) + + +class Multiwoz21LegacyProcessor(DataProcessor): + def __init__(self, dataset_config, data_dir): + super(Multiwoz21LegacyProcessor, self).__init__(dataset_config, data_dir) + self.value_list['train'] = dataset_multiwoz21_legacy.get_value_list(os.path.join(self.data_dir, 'train_dials.json'), + self.slot_list) + self.value_list['dev'] = dataset_multiwoz21_legacy.get_value_list(os.path.join(self.data_dir, 'val_dials.json'), + self.slot_list) + self.value_list['test'] = dataset_multiwoz21_legacy.get_value_list(os.path.join(self.data_dir, 'test_dials.json'), + self.slot_list) + self._add_dummy_value_to_value_list() + + def prediction_normalization(self, slot, value): + return dataset_multiwoz21.prediction_normalization(slot, value) + + def get_train_examples(self, args): + return dataset_multiwoz21_legacy.create_examples(os.path.join(self.data_dir, 'train_dials.json'), + os.path.join(self.data_dir, 'dialogue_acts.json'), + 'train', self.slot_list, self.label_maps, **args) + + def get_dev_examples(self, args): + return dataset_multiwoz21_legacy.create_examples(os.path.join(self.data_dir, 'val_dials.json'), + os.path.join(self.data_dir, 'dialogue_acts.json'), + 'dev', self.slot_list, self.label_maps, **args) + + def get_test_examples(self, args): + return dataset_multiwoz21_legacy.create_examples(os.path.join(self.data_dir, 'test_dials.json'), + os.path.join(self.data_dir, 'dialogue_acts.json'), + 'test', self.slot_list, self.label_maps, **args) + + +class SimProcessor(DataProcessor): + def __init__(self, dataset_config, data_dir): + super(SimProcessor, self).__init__(dataset_config, data_dir) + self.value_list['train'] = dataset_sim.get_value_list(os.path.join(self.data_dir, 'train.json'), + self.slot_list) + self.value_list['dev'] = dataset_sim.get_value_list(os.path.join(self.data_dir, 'dev.json'), + self.slot_list) + self.value_list['test'] = dataset_sim.get_value_list(os.path.join(self.data_dir, 'test.json'), + self.slot_list) + + def get_train_examples(self, args): + return dataset_sim.create_examples(os.path.join(self.data_dir, 'train.json'), + 'train', self.slot_list, **args) + + def get_dev_examples(self, args): + return dataset_sim.create_examples(os.path.join(self.data_dir, 'dev.json'), + 'dev', self.slot_list, **args) + + def get_test_examples(self, args): + return dataset_sim.create_examples(os.path.join(self.data_dir, 'test.json'), + 'test', self.slot_list, **args) + + +class UnifiedDatasetProcessor(DataProcessor): + def __init__(self, dataset_config, data_dir): + super(UnifiedDatasetProcessor, self).__init__(dataset_config, data_dir) + self.value_list['train'] = dataset_unified.get_value_list(self.dataset_name, self.slot_list) + self.value_list['dev'] = dataset_unified.get_value_list(self.dataset_name, self.slot_list) + self.value_list['test'] = dataset_unified.get_value_list(self.dataset_name, self.slot_list) + self._add_dummy_value_to_value_list() + + def prediction_normalization(self, slot, value): + return dataset_unified.prediction_normalization(self.dataset_name, slot, value) + + def _get_slot_list(self): + return dataset_unified.get_slot_list(self.dataset_name) + + def get_train_examples(self, args): + return dataset_unified.create_examples('train', self.dataset_name, self.class_types, + self.slot_list, self.label_maps, **args) + + def get_dev_examples(self, args): + return dataset_unified.create_examples('validation', self.dataset_name, self.class_types, + self.slot_list, self.label_maps, **args) + + def get_test_examples(self, args): + return dataset_unified.create_examples('test', self.dataset_name, self.class_types, + self.slot_list, self.label_maps, **args) + + +PROCESSORS = {"woz2": Woz2Processor, + "sim-m": SimProcessor, + "sim-r": SimProcessor, + "multiwoz21": Multiwoz21Processor, + "multiwoz21_legacy": Multiwoz21LegacyProcessor, + "unified": UnifiedDatasetProcessor} diff --git a/dataset_config/multiwoz21.json b/dataset_config/multiwoz21.json new file mode 100644 index 0000000..b389ab7 --- /dev/null +++ b/dataset_config/multiwoz21.json @@ -0,0 +1,1375 @@ +{ + "class_types": [ + "none", + "dontcare", + "copy_value", + "true", + "false", + "refer", + "inform" + ], + "slots": { + "taxi-leaveAt": "the time of the taxi departure", + "taxi-destination": "the name of the taxi destination", + "taxi-departure": "the name of the taxi departure", + "taxi-arriveBy": "the time of the taxi arrival", + "restaurant-book_people": "the number of people for the restaurant", + "restaurant-book_day": "the day for the restaurant", + "restaurant-book_time": "the time for the restaurant", + "restaurant-food": "the type of the restaurant", + "restaurant-pricerange": "the price range of the restaurant", + "restaurant-name": "the name of the restaurant", + "restaurant-area": "the area of the restaurant", + "hotel-book_people": "the number of people for the hotel", + "hotel-book_day": "the day for the hotel", + "hotel-book_stay": "the number of nights for the hotel", + "hotel-name": "the name of the hotel", + "hotel-area": "the area of the hotel", + "hotel-parking": "the parking of the hotel", + "hotel-pricerange": "the price range of the hotel", + "hotel-stars": "the number of stars of the hotel", + "hotel-internet": "the internet of the hotel", + "hotel-type": "the type of the hotel", + "attraction-type": "the type of the attraction", + "attraction-name": "the name of the attraction", + "attraction-area": "the area of the attraction", + "train-book_people": "the number of people for the train", + "train-leaveAt": "the time of the train departure", + "train-destination": "the name of the train destination", + "train-day": "the day for the train", + "train-arriveBy": "the time of the train arrival", + "train-departure": "the name of the train departure" + }, + "noncategorical": [ + "taxi-leaveAt", + "taxi-destination", + "taxi-departure", + "taxi-arriveBy", + "restaurant-book_time", + "restaurant-food", + "restaurant-name", + "hotel-name", + "attraction-name", + "train-leaveAt", + "train-arriveBy" + ], + "boolean": [ + "hotel-parking", + "hotel-internet", + "hotel-type" + ], + "label_maps": { + "no parking": [ + "not need free parking", + "not need parking", + "not need to have free parking", + "not need to have parking", + "not require free parking", + "not require parking", + "not have free parking", + "not have parking", + "no free parking" + ], + "no internet": [ + "not need free internet", + "not need internet", + "not need to have free internet", + "not need to have internet", + "not require free internet", + "not require internet", + "not have free internet", + "not have internet", + "no free internet", + "not need free wifi", + "not need wifi", + "not need to have free wifi", + "not need to have wifi", + "not require free wifi", + "not require wifi", + "not have free wifi", + "not have wifi", + "no free wifi", + "no wifi", + "not need free wi-fi", + "not need wi-fi", + "not need to have free wi-fi", + "not need to have wi-fi", + "not require free wi-fi", + "not require wi-fi", + "not have free wi-fi", + "not have wi-fi", + "no free wi-fi", + "no wi-fi" + ], + "internet": [ + "wifi", + "wi-fi" + ], + "guest house": [ + "guest houses" + ], + "hotel": [ + "hotels" + ], + "centre": [ + "center", + "downtown" + ], + "north": [ + "northern", + "northside", + "northend" + ], + "east": [ + "eastern", + "eastside", + "eastend" + ], + "west": [ + "western", + "westside", + "westend" + ], + "south": [ + "southern", + "southside", + "southend" + ], + "cheap": [ + "inexpensive", + "lower price", + "lower range", + "cheaply", + "cheaper", + "cheapest", + "very affordable" + ], + "moderate": [ + "moderately", + "reasonable", + "reasonably", + "affordable", + "mid range", + "mid-range", + "priced moderately", + "decently priced", + "mid price", + "mid-price", + "mid priced", + "mid-priced", + "middle price", + "medium price", + "medium priced", + "not too expensive", + "not too cheap" + ], + "expensive": [ + "high end", + "high-end", + "high class", + "high-class", + "high scale", + "high-scale", + "high price", + "high priced", + "higher price", + "fancy", + "upscale", + "nice", + "expensively", + "luxury" + ], + "0": [ + "zero" + ], + "1": [ + "one", + "just me", + "for me", + "myself", + "alone", + "me" + ], + "2": [ + "two" + ], + "3": [ + "three" + ], + "4": [ + "four" + ], + "5": [ + "five" + ], + "6": [ + "six" + ], + "7": [ + "seven" + ], + "8": [ + "eight" + ], + "9": [ + "nine" + ], + "10": [ + "ten" + ], + "11": [ + "eleven" + ], + "12": [ + "twelve" + ], + "architecture": [ + "architectures", + "architectural", + "architecturally", + "architect" + ], + "boat": [ + "boating", + "boats", + "camboats" + ], + "boating": [ + "boat", + "boats", + "camboats" + ], + "camboats": [ + "boating", + "boat", + "boats" + ], + "cinema": [ + "cinemas", + "movie", + "films", + "film" + ], + "college": [ + "colleges" + ], + "concert": [ + "concert hall", + "concert halls", + "concerthall", + "concerthalls", + "concerts" + ], + "concerthall": [ + "concert hall", + "concert halls", + "concerthalls", + "concerts", + "concert" + ], + "entertainment": [ + "entertaining" + ], + "gallery": [ + "museum", + "galleries" + ], + "gastropubs": [ + "gastropub" + ], + "multiple sports": [ + "multiple sport", + "multi sport", + "multi sports", + "sports", + "sporting" + ], + "museum": [ + "museums", + "gallery", + "galleries" + ], + "night club": [ + "night clubs", + "nightclub", + "nightclubs", + "club", + "clubs" + ], + "nightclub": [ + "night club", + "night clubs", + "nightclubs", + "club", + "clubs" + ], + "park": [ + "parks" + ], + "pool": [ + "swimming pool", + "swimming pools", + "swimming", + "pools", + "swimmingpool", + "swimmingpools", + "water", + "swim" + ], + "sports": [ + "multiple sport", + "multi sport", + "multi sports", + "multiple sports", + "sporting" + ], + "swimming pool": [ + "swimming", + "pool", + "pools", + "swimmingpool", + "swimmingpools", + "water", + "swim" + ], + "theater": [ + "theatre", + "theatres", + "theaters" + ], + "theatre": [ + "theater", + "theatres", + "theaters" + ], + "abbey pool and astroturf pitch": [ + "abbey pool and astroturf", + "abbey pool" + ], + "abbey pool and astroturf": [ + "abbey pool and astroturf pitch", + "abbey pool" + ], + "abbey pool": [ + "abbey pool and astroturf pitch", + "abbey pool and astroturf" + ], + "adc theatre": [ + "adc theater", + "adc" + ], + "adc": [ + "adc theatre", + "adc theater" + ], + "addenbrookes hospital": [ + "addenbrooke's hospital" + ], + "cafe jello gallery": [ + "cafe jello" + ], + "cambridge and county folk museum": [ + "cambridge and country folk museum", + "county folk museum" + ], + "cambridge and country folk museum": [ + "cambridge and county folk museum", + "county folk museum" + ], + "county folk museum": [ + "cambridge and county folk museum", + "cambridge and country folk museum" + ], + "cambridge arts theatre": [ + "cambridge arts theater" + ], + "cambridge book and print gallery": [ + "book and print gallery" + ], + "cambridge contemporary art": [ + "cambridge contemporary art museum", + "contemporary art museum" + ], + "cambridge contemporary art museum": [ + "cambridge contemporary art", + "contemporary art museum" + ], + "cambridge corn exchange": [ + "the cambridge corn exchange" + ], + "the cambridge corn exchange": [ + "cambridge corn exchange" + ], + "cambridge museum of technology": [ + "museum of technology" + ], + "cambridge punter": [ + "the cambridge punter", + "cambridge punters" + ], + "cambridge punters": [ + "the cambridge punter", + "cambridge punter" + ], + "the cambridge punter": [ + "cambridge punter", + "cambridge punters" + ], + "cambridge university botanic gardens": [ + "cambridge university botanical gardens", + "cambridge university botanical garden", + "cambridge university botanic garden", + "cambridge botanic gardens", + "cambridge botanical gardens", + "cambridge botanic garden", + "cambridge botanical garden", + "botanic gardens", + "botanical gardens", + "botanic garden", + "botanical garden" + ], + "cambridge botanic gardens": [ + "cambridge university botanic gardens", + "cambridge university botanical gardens", + "cambridge university botanical garden", + "cambridge university botanic garden", + "cambridge botanical gardens", + "cambridge botanic garden", + "cambridge botanical garden", + "botanic gardens", + "botanical gardens", + "botanic garden", + "botanical garden" + ], + "botanic gardens": [ + "cambridge university botanic gardens", + "cambridge university botanical gardens", + "cambridge university botanical garden", + "cambridge university botanic garden", + "cambridge botanic gardens", + "cambridge botanical gardens", + "cambridge botanic garden", + "cambridge botanical garden", + "botanical gardens", + "botanic garden", + "botanical garden" + ], + "cherry hinton village centre": [ + "cherry hinton village center" + ], + "cherry hinton village center": [ + "cherry hinton village centre" + ], + "cherry hinton hall and grounds": [ + "cherry hinton hall" + ], + "cherry hinton hall": [ + "cherry hinton hall and grounds" + ], + "cherry hinton water play": [ + "cherry hinton water play park" + ], + "cherry hinton water play park": [ + "cherry hinton water play" + ], + "christ college": [ + "christ's college", + "christs college" + ], + "christs college": [ + "christ college", + "christ's college" + ], + "churchills college": [ + "churchill's college", + "churchill college" + ], + "cineworld cinema": [ + "cineworld" + ], + "clair hall": [ + "clare hall" + ], + "clare hall": [ + "clair hall" + ], + "the fez club": [ + "fez club" + ], + "great saint marys church": [ + "great saint mary's church", + "great saint mary's", + "great saint marys" + ], + "jesus green outdoor pool": [ + "jesus green" + ], + "jesus green": [ + "jesus green outdoor pool" + ], + "kettles yard": [ + "kettle's yard" + ], + "kings college": [ + "king's college" + ], + "kings hedges learner pool": [ + "king's hedges learner pool", + "king hedges learner pool" + ], + "king hedges learner pool": [ + "king's hedges learner pool", + "kings hedges learner pool" + ], + "little saint marys church": [ + "little saint mary's church", + "little saint mary's", + "little saint marys" + ], + "mumford theatre": [ + "mumford theater" + ], + "museum of archaelogy": [ + "museum of archaeology" + ], + "museum of archaelogy and anthropology": [ + "museum of archaeology and anthropology" + ], + "peoples portraits exhibition": [ + "people's portraits exhibition at girton college", + "peoples portraits exhibition at girton college", + "people's portraits exhibition" + ], + "peoples portraits exhibition at girton college": [ + "people's portraits exhibition at girton college", + "people's portraits exhibition", + "peoples portraits exhibition" + ], + "queens college": [ + "queens' college", + "queen's college" + ], + "riverboat georgina": [ + "riverboat" + ], + "saint barnabas": [ + "saint barbabas press gallery" + ], + "saint barnabas press gallery": [ + "saint barbabas" + ], + "saint catharines college": [ + "saint catharine's college", + "saint catharine's", + "saint catherine's college", + "saint catherine's" + ], + "saint johns college": [ + "saint john's college", + "st john's college", + "st johns college" + ], + "scott polar": [ + "scott polar museum" + ], + "scott polar museum": [ + "scott polar" + ], + "scudamores punting co": [ + "scudamore's punting co", + "scudamores punting", + "scudamore's punting", + "scudamores", + "scudamore's", + "scudamore" + ], + "scudamore": [ + "scudamore's punting co", + "scudamores punting co", + "scudamores punting", + "scudamore's punting", + "scudamores", + "scudamore's" + ], + "sheeps green and lammas land park fen causeway": [ + "sheep's green and lammas land park fen causeway", + "sheep's green and lammas land park", + "sheeps green and lammas land park", + "lammas land park", + "sheep's green", + "sheeps green" + ], + "sheeps green and lammas land park": [ + "sheep's green and lammas land park fen causeway", + "sheeps green and lammas land park fen causeway", + "sheep's green and lammas land park", + "lammas land park", + "sheep's green", + "sheeps green" + ], + "lammas land park": [ + "sheep's green and lammas land park fen causeway", + "sheeps green and lammas land park fen causeway", + "sheep's green and lammas land park", + "sheeps green and lammas land park", + "sheep's green", + "sheeps green" + ], + "sheeps green": [ + "sheep's green and lammas land park fen causeway", + "sheeps green and lammas land park fen causeway", + "sheep's green and lammas land park", + "sheeps green and lammas land park", + "lammas land park", + "sheep's green" + ], + "soul tree nightclub": [ + "soul tree night club", + "soul tree", + "soultree" + ], + "soultree": [ + "soul tree nightclub", + "soul tree night club", + "soul tree" + ], + "the man on the moon": [ + "man on the moon" + ], + "man on the moon": [ + "the man on the moon" + ], + "the junction": [ + "junction theatre", + "junction theater" + ], + "junction theatre": [ + "the junction", + "junction theater" + ], + "old schools": [ + "old school" + ], + "vue cinema": [ + "vue" + ], + "wandlebury country park": [ + "the wandlebury" + ], + "the wandlebury": [ + "wandlebury country park" + ], + "whipple museum of the history of science": [ + "whipple museum", + "history of science museum" + ], + "history of science museum": [ + "whipple museum of the history of science", + "whipple museum" + ], + "williams art and antique": [ + "william's art and antique" + ], + "alimentum": [ + "restaurant alimentum" + ], + "restaurant alimentum": [ + "alimentum" + ], + "bedouin": [ + "the bedouin" + ], + "the bedouin": [ + "bedouin" + ], + "bloomsbury restaurant": [ + "bloomsbury" + ], + "cafe uno": [ + "caffe uno", + "caffee uno" + ], + "caffe uno": [ + "cafe uno", + "caffee uno" + ], + "caffee uno": [ + "cafe uno", + "caffe uno" + ], + "cambridge lodge restaurant": [ + "cambridge lodge" + ], + "chiquito": [ + "chiquito restaurant bar", + "chiquito restaurant" + ], + "chiquito restaurant bar": [ + "chiquito restaurant", + "chiquito" + ], + "city stop restaurant": [ + "city stop" + ], + "cityr": [ + "cityroomz" + ], + "citiroomz": [ + "cityroomz" + ], + "clowns cafe": [ + "clown's cafe" + ], + "cow pizza kitchen and bar": [ + "the cow pizza kitchen and bar", + "cow pizza" + ], + "the cow pizza kitchen and bar": [ + "cow pizza kitchen and bar", + "cow pizza" + ], + "darrys cookhouse and wine shop": [ + "darry's cookhouse and wine shop", + "darry's cookhouse", + "darrys cookhouse" + ], + "de luca cucina and bar": [ + "de luca cucina and bar riverside brasserie", + "luca cucina and bar", + "de luca cucina", + "luca cucina" + ], + "de luca cucina and bar riverside brasserie": [ + "de luca cucina and bar", + "luca cucina and bar", + "de luca cucina", + "luca cucina" + ], + "da vinci pizzeria": [ + "da vinci pizza", + "da vinci" + ], + "don pasquale pizzeria": [ + "don pasquale pizza", + "don pasquale", + "pasquale pizzeria", + "pasquale pizza" + ], + "efes": [ + "efes restaurant" + ], + "efes restaurant": [ + "efes" + ], + "fitzbillies restaurant": [ + "fitzbillies" + ], + "frankie and bennys": [ + "frankie and benny's" + ], + "funky": [ + "funky fun house" + ], + "funky fun house": [ + "funky" + ], + "gardenia": [ + "the gardenia" + ], + "the gardenia": [ + "gardenia" + ], + "grafton hotel restaurant": [ + "the grafton hotel", + "grafton hotel" + ], + "the grafton hotel": [ + "grafton hotel restaurant", + "grafton hotel" + ], + "grafton hotel": [ + "grafton hotel restaurant", + "the grafton hotel" + ], + "hotel du vin and bistro": [ + "hotel du vin", + "du vin" + ], + "Kohinoor": [ + "kohinoor", + "the kohinoor" + ], + "kohinoor": [ + "the kohinoor" + ], + "the kohinoor": [ + "kohinoor" + ], + "lan hong house": [ + "lan hong", + "ian hong house", + "ian hong" + ], + "ian hong": [ + "lan hong house", + "lan hong", + "ian hong house" + ], + "lovel": [ + "the lovell lodge", + "lovell lodge" + ], + "lovell lodge": [ + "lovell" + ], + "the lovell lodge": [ + "lovell lodge", + "lovell" + ], + "mahal of cambridge": [ + "mahal" + ], + "mahal": [ + "mahal of cambridge" + ], + "maharajah tandoori restaurant": [ + "maharajah tandoori" + ], + "the maharajah tandoor": [ + "maharajah tandoori restaurant", + "maharajah tandoori" + ], + "meze bar": [ + "meze bar restaurant", + "the meze bar" + ], + "meze bar restaurant": [ + "the meze bar", + "meze bar" + ], + "the meze bar": [ + "meze bar restaurant", + "meze bar" + ], + "michaelhouse cafe": [ + "michael house cafe" + ], + "midsummer house restaurant": [ + "midsummer house" + ], + "missing sock": [ + "the missing sock" + ], + "the missing sock": [ + "missing sock" + ], + "nandos": [ + "nando's city centre", + "nando's city center", + "nandos city centre", + "nandos city center", + "nando's" + ], + "nandos city centre": [ + "nando's city centre", + "nando's city center", + "nandos city center", + "nando's", + "nandos" + ], + "oak bistro": [ + "the oak bistro" + ], + "the oak bistro": [ + "oak bistro" + ], + "one seven": [ + "restaurant one seven" + ], + "restaurant one seven": [ + "one seven" + ], + "river bar steakhouse and grill": [ + "the river bar steakhouse and grill", + "the river bar steakhouse", + "river bar steakhouse" + ], + "the river bar steakhouse and grill": [ + "river bar steakhouse and grill", + "the river bar steakhouse", + "river bar steakhouse" + ], + "pipasha restaurant": [ + "pipasha" + ], + "pizza hut city centre": [ + "pizza hut city center" + ], + "pizza hut fenditton": [ + "pizza hut fen ditton", + "pizza express fen ditton" + ], + "restaurant two two": [ + "two two", + "restaurant 22" + ], + "saffron brasserie": [ + "saffron" + ], + "saint johns chop house": [ + "saint john's chop house", + "st john's chop house", + "st johns chop house" + ], + "sesame restaurant and bar": [ + "sesame restaurant", + "sesame" + ], + "shanghai": [ + "shanghai family restaurant" + ], + "shanghai family restaurant": [ + "shanghai" + ], + "sitar": [ + "sitar tandoori" + ], + "sitar tandoori": [ + "sitar" + ], + "slug and lettuce": [ + "the slug and lettuce" + ], + "the slug and lettuce": [ + "slug and lettuce" + ], + "st johns chop house": [ + "saint john's chop house", + "st john's chop house", + "saint johns chop house" + ], + "stazione restaurant and coffee bar": [ + "stazione restaurant", + "stazione" + ], + "thanh binh": [ + "thanh", + "binh" + ], + "thanh": [ + "thanh binh", + "binh" + ], + "binh": [ + "thanh binh", + "thanh" + ], + "the hotpot": [ + "the hotspot", + "hotpot", + "hotspot" + ], + "hotpot": [ + "the hotpot", + "the hotpot", + "hotspot" + ], + "the lucky star": [ + "lucky star" + ], + "lucky star": [ + "the lucky star" + ], + "the peking restaurant: ": [ + "peking restaurant" + ], + "the varsity restaurant": [ + "varsity restaurant", + "the varsity", + "varsity" + ], + "two two": [ + "restaurant two two", + "restaurant 22" + ], + "restaurant 22": [ + "restaurant two two", + "two two" + ], + "zizzi cambridge": [ + "zizzi" + ], + "american": [ + "americas" + ], + "asian oriental": [ + "asian", + "oriental" + ], + "australian": [ + "australasian" + ], + "barbeque": [ + "barbecue", + "bbq" + ], + "corsica": [ + "corsican" + ], + "indian": [ + "tandoori" + ], + "italian": [ + "pizza", + "pizzeria" + ], + "japanese": [ + "sushi" + ], + "latin american": [ + "latin-american", + "latin" + ], + "malaysian": [ + "malay" + ], + "middle eastern": [ + "middle-eastern" + ], + "traditional american": [ + "american" + ], + "modern american": [ + "american modern", + "american" + ], + "modern european": [ + "european modern", + "european" + ], + "north american": [ + "north-american", + "american" + ], + "portuguese": [ + "portugese" + ], + "portugese": [ + "portuguese" + ], + "seafood": [ + "sea food" + ], + "singaporean": [ + "singapore" + ], + "steakhouse": [ + "steak house", + "steak" + ], + "the americas": [ + "american", + "americas" + ], + "a and b guest house": [ + "a & b guest house", + "a and b", + "a & b" + ], + "the acorn guest house": [ + "acorn guest house", + "acorn" + ], + "acorn guest house": [ + "the acorn guest house", + "acorn" + ], + "alexander bed and breakfast": [ + "alexander" + ], + "allenbell": [ + "the allenbell" + ], + "the allenbell": [ + "allenbell" + ], + "alpha-milton guest house": [ + "the alpha-milton", + "alpha-milton" + ], + "the alpha-milton": [ + "alpha-milton guest house", + "alpha-milton" + ], + "arbury lodge guest house": [ + "arbury lodge", + "arbury" + ], + "archway house": [ + "archway" + ], + "ashley hotel": [ + "the ashley hotel", + "ashley" + ], + "the ashley hotel": [ + "ashley hotel", + "ashley" + ], + "aylesbray lodge guest house": [ + "aylesbray lodge", + "aylesbray" + ], + "aylesbray lodge guest": [ + "aylesbray lodge guest house", + "aylesbray lodge", + "aylesbray" + ], + "alesbray lodge guest house": [ + "aylesbray lodge guest house", + "aylesbray lodge", + "aylesbray" + ], + "alyesbray lodge hotel": [ + "aylesbray lodge guest house", + "aylesbray lodge", + "aylesbray" + ], + "bridge guest house": [ + "bridge house" + ], + "cambridge belfry": [ + "the cambridge belfry", + "belfry hotel", + "belfry" + ], + "the cambridge belfry": [ + "cambridge belfry", + "belfry hotel", + "belfry" + ], + "belfry hotel": [ + "the cambridge belfry", + "cambridge belfry", + "belfry" + ], + "carolina bed and breakfast": [ + "carolina" + ], + "city centre north": [ + "city centre north bed and breakfast" + ], + "north b and b": [ + "city centre north bed and breakfast" + ], + "city centre north b and b": [ + "city centre north bed and breakfast" + ], + "el shaddia guest house": [ + "el shaddai guest house", + "el shaddai", + "el shaddia" + ], + "el shaddai guest house": [ + "el shaddia guest house", + "el shaddai", + "el shaddia" + ], + "express by holiday inn cambridge": [ + "express by holiday inn", + "holiday inn cambridge", + "holiday inn" + ], + "holiday inn": [ + "express by holiday inn cambridge", + "express by holiday inn", + "holiday inn cambridge" + ], + "finches bed and breakfast": [ + "finches" + ], + "gonville hotel": [ + "gonville" + ], + "hamilton lodge": [ + "the hamilton lodge", + "hamilton" + ], + "the hamilton lodge": [ + "hamilton lodge", + "hamilton" + ], + "hobsons house": [ + "hobson's house", + "hobson's" + ], + "huntingdon marriott hotel": [ + "huntington marriott hotel", + "huntington marriot hotel", + "huntingdon marriot hotel", + "huntington marriott", + "huntingdon marriott", + "huntington marriot", + "huntingdon marriot", + "huntington", + "huntingdon" + ], + "kirkwood": [ + "kirkwood house" + ], + "kirkwood house": [ + "kirkwood" + ], + "lensfield hotel": [ + "the lensfield hotel", + "lensfield" + ], + "the lensfield hotel": [ + "lensfield hotel", + "lensfield" + ], + "leverton house": [ + "leverton" + ], + "marriot hotel": [ + "marriott hotel", + "marriott" + ], + "rosas bed and breakfast": [ + "rosa's bed and breakfast", + "rosa's", + "rosas" + ], + "university arms hotel": [ + "university arms" + ], + "warkworth house": [ + "warkworth hotel", + "warkworth" + ], + "warkworth hotel": [ + "warkworth house", + "warkworth" + ], + "wartworth": [ + "warkworth house", + "warkworth hotel", + "warkworth" + ], + "worth house": [ + "the worth house" + ], + "the worth house": [ + "worth house" + ], + "birmingham new street": [ + "birmingham new street train station" + ], + "birmingham new street train station": [ + "birmingham new street" + ], + "bishops stortford": [ + "bishops stortford train station" + ], + "bishops stortford train station": [ + "bishops stortford" + ], + "broxbourne": [ + "broxbourne train station" + ], + "broxbourne train station": [ + "broxbourne" + ], + "cambridge": [ + "cambridge train station" + ], + "cambridge train station": [ + "cambridge" + ], + "ely": [ + "ely train station" + ], + "ely train station": [ + "ely" + ], + "kings lynn": [ + "king's lynn", + "king's lynn train station", + "kings lynn train station" + ], + "kings lynn train station": [ + "kings lynn", + "king's lynn", + "king's lynn train station" + ], + "leicester": [ + "leicester train station" + ], + "leicester train station": [ + "leicester" + ], + "london kings cross": [ + "kings cross", + "king's cross", + "london king's cross", + "kings cross train station", + "king's cross train station", + "london king's cross train station", + "london kings cross train station" + ], + "london kings cross train station": [ + "kings cross", + "king's cross", + "london king's cross", + "london kings cross", + "kings cross train station", + "king's cross train station", + "london king's cross train station" + ], + "london liverpool": [ + "liverpool street", + "london liverpool street", + "london liverpool train station", + "liverpool street train station", + "london liverpool street train station" + ], + "london liverpool street": [ + "london liverpool", + "liverpool street", + "london liverpool train station", + "liverpool street train station", + "london liverpool street train station" + ], + "london liverpool street train station": [ + "london liverpool", + "liverpool street", + "london liverpool street", + "london liverpool train station", + "liverpool street train station" + ], + "norwich": [ + "norwich train station" + ], + "norwich train station": [ + "norwich" + ], + "peterborough": [ + "peterborough train station" + ], + "peterborough train station": [ + "peterborough" + ], + "stansted airport": [ + "stansted airport train station" + ], + "stansted airport train station": [ + "stansted airport" + ], + "stevenage": [ + "stevenage train station" + ], + "stevenage train station": [ + "stevenage" + ] + } +} diff --git a/dataset_config/sim-m.json b/dataset_config/sim-m.json new file mode 100644 index 0000000..c114445 --- /dev/null +++ b/dataset_config/sim-m.json @@ -0,0 +1,20 @@ +{ + "class_types": [ + "none", + "dontcare", + "copy_value", + "inform" + ], + "slots": { + "date": "the date for which to book the movie", + "movie": "the name of the movie", + "time": "the time for which to book the movie", + "num_tickets": "the amount of tickets for the movie", + "theatre_name": "the name of the movie theatre" + }, + "noncategorical": [ + "movie" + ], + "boolean": [], + "label_maps": {} +} diff --git a/dataset_config/sim-r.json b/dataset_config/sim-r.json new file mode 100644 index 0000000..d7400e5 --- /dev/null +++ b/dataset_config/sim-r.json @@ -0,0 +1,24 @@ +{ + "class_types": [ + "none", + "dontcare", + "copy_value", + "inform" + ], + "slots": { + "category": "the category of the restaurant", + "rating": "the rating of the restaurant", + "num_people": "the amount of people to book the restaurant for", + "location": "the location of the restaurant", + "restaurant_name": "the name of the restaurant", + "time": "the time for which to book the restaurant", + "date": "the date for which to book the restaurant", + "price_range": "the price range of the restaurant", + "meal": "the food type of the restaurant" + }, + "noncategorical": [ + "restaurant_name" + ], + "boolean": [], + "label_maps": {} +} diff --git a/dataset_config/unified_multiwoz21.json b/dataset_config/unified_multiwoz21.json new file mode 100644 index 0000000..77140bc --- /dev/null +++ b/dataset_config/unified_multiwoz21.json @@ -0,0 +1,1346 @@ +{ + "dataset_name": "multiwoz21", + "class_types": [ + "none", + "dontcare", + "copy_value", + "true", + "false", + "refer", + "inform", + "request" + ], + "slots": [], + "noncategorical": [ + "taxi-leaveAt", + "taxi-destination", + "taxi-departure", + "taxi-arriveBy", + "restaurant-book_time", + "restaurant-food", + "restaurant-name", + "hotel-name", + "attraction-name", + "train-leaveAt", + "train-arriveBy" + ], + "boolean": [ + "hotel-parking", + "hotel-internet", + "hotel-type" + ], + "label_maps": { + "no parking": [ + "not need free parking", + "not need parking", + "not need to have free parking", + "not need to have parking", + "not require free parking", + "not require parking", + "not have free parking", + "not have parking", + "no free parking" + ], + "no internet": [ + "not need free internet", + "not need internet", + "not need to have free internet", + "not need to have internet", + "not require free internet", + "not require internet", + "not have free internet", + "not have internet", + "no free internet", + "not need free wifi", + "not need wifi", + "not need to have free wifi", + "not need to have wifi", + "not require free wifi", + "not require wifi", + "not have free wifi", + "not have wifi", + "no free wifi", + "no wifi", + "not need free wi-fi", + "not need wi-fi", + "not need to have free wi-fi", + "not need to have wi-fi", + "not require free wi-fi", + "not require wi-fi", + "not have free wi-fi", + "not have wi-fi", + "no free wi-fi", + "no wi-fi" + ], + "internet": [ + "wifi", + "wi-fi" + ], + "guest house": [ + "guest houses" + ], + "hotel": [ + "hotels" + ], + "centre": [ + "center", + "downtown" + ], + "north": [ + "northern", + "northside", + "northend" + ], + "east": [ + "eastern", + "eastside", + "eastend" + ], + "west": [ + "western", + "westside", + "westend" + ], + "south": [ + "southern", + "southside", + "southend" + ], + "cheap": [ + "inexpensive", + "lower price", + "lower range", + "cheaply", + "cheaper", + "cheapest", + "very affordable" + ], + "moderate": [ + "moderately", + "reasonable", + "reasonably", + "affordable", + "mid range", + "mid-range", + "priced moderately", + "decently priced", + "mid price", + "mid-price", + "mid priced", + "mid-priced", + "middle price", + "medium price", + "medium priced", + "not too expensive", + "not too cheap" + ], + "expensive": [ + "high end", + "high-end", + "high class", + "high-class", + "high scale", + "high-scale", + "high price", + "high priced", + "higher price", + "fancy", + "upscale", + "nice", + "expensively", + "luxury" + ], + "0": [ + "zero" + ], + "1": [ + "one", + "just me", + "for me", + "myself", + "alone", + "me" + ], + "2": [ + "two" + ], + "3": [ + "three" + ], + "4": [ + "four" + ], + "5": [ + "five" + ], + "6": [ + "six" + ], + "7": [ + "seven" + ], + "8": [ + "eight" + ], + "9": [ + "nine" + ], + "10": [ + "ten" + ], + "11": [ + "eleven" + ], + "12": [ + "twelve" + ], + "architecture": [ + "architectures", + "architectural", + "architecturally", + "architect" + ], + "boat": [ + "boating", + "boats", + "camboats" + ], + "boating": [ + "boat", + "boats", + "camboats" + ], + "camboats": [ + "boating", + "boat", + "boats" + ], + "cinema": [ + "cinemas", + "movie", + "films", + "film" + ], + "college": [ + "colleges" + ], + "concert": [ + "concert hall", + "concert halls", + "concerthall", + "concerthalls", + "concerts" + ], + "concerthall": [ + "concert hall", + "concert halls", + "concerthalls", + "concerts", + "concert" + ], + "entertainment": [ + "entertaining" + ], + "gallery": [ + "museum", + "galleries" + ], + "gastropubs": [ + "gastropub" + ], + "multiple sports": [ + "multiple sport", + "multi sport", + "multi sports", + "sports", + "sporting" + ], + "museum": [ + "museums", + "gallery", + "galleries" + ], + "night club": [ + "night clubs", + "nightclub", + "nightclubs", + "club", + "clubs" + ], + "nightclub": [ + "night club", + "night clubs", + "nightclubs", + "club", + "clubs" + ], + "park": [ + "parks" + ], + "pool": [ + "swimming pool", + "swimming pools", + "swimming", + "pools", + "swimmingpool", + "swimmingpools", + "water", + "swim" + ], + "sports": [ + "multiple sport", + "multi sport", + "multi sports", + "multiple sports", + "sporting" + ], + "swimming pool": [ + "swimming", + "pool", + "pools", + "swimmingpool", + "swimmingpools", + "water", + "swim" + ], + "theater": [ + "theatre", + "theatres", + "theaters" + ], + "theatre": [ + "theater", + "theatres", + "theaters" + ], + "abbey pool and astroturf pitch": [ + "abbey pool and astroturf", + "abbey pool" + ], + "abbey pool and astroturf": [ + "abbey pool and astroturf pitch", + "abbey pool" + ], + "abbey pool": [ + "abbey pool and astroturf pitch", + "abbey pool and astroturf" + ], + "adc theatre": [ + "adc theater", + "adc" + ], + "adc": [ + "adc theatre", + "adc theater" + ], + "addenbrookes hospital": [ + "addenbrooke's hospital" + ], + "cafe jello gallery": [ + "cafe jello" + ], + "cambridge and county folk museum": [ + "cambridge and country folk museum", + "county folk museum" + ], + "cambridge and country folk museum": [ + "cambridge and county folk museum", + "county folk museum" + ], + "county folk museum": [ + "cambridge and county folk museum", + "cambridge and country folk museum" + ], + "cambridge arts theatre": [ + "cambridge arts theater" + ], + "cambridge book and print gallery": [ + "book and print gallery" + ], + "cambridge contemporary art": [ + "cambridge contemporary art museum", + "contemporary art museum" + ], + "cambridge contemporary art museum": [ + "cambridge contemporary art", + "contemporary art museum" + ], + "cambridge corn exchange": [ + "the cambridge corn exchange" + ], + "the cambridge corn exchange": [ + "cambridge corn exchange" + ], + "cambridge museum of technology": [ + "museum of technology" + ], + "cambridge punter": [ + "the cambridge punter", + "cambridge punters" + ], + "cambridge punters": [ + "the cambridge punter", + "cambridge punter" + ], + "the cambridge punter": [ + "cambridge punter", + "cambridge punters" + ], + "cambridge university botanic gardens": [ + "cambridge university botanical gardens", + "cambridge university botanical garden", + "cambridge university botanic garden", + "cambridge botanic gardens", + "cambridge botanical gardens", + "cambridge botanic garden", + "cambridge botanical garden", + "botanic gardens", + "botanical gardens", + "botanic garden", + "botanical garden" + ], + "cambridge botanic gardens": [ + "cambridge university botanic gardens", + "cambridge university botanical gardens", + "cambridge university botanical garden", + "cambridge university botanic garden", + "cambridge botanical gardens", + "cambridge botanic garden", + "cambridge botanical garden", + "botanic gardens", + "botanical gardens", + "botanic garden", + "botanical garden" + ], + "botanic gardens": [ + "cambridge university botanic gardens", + "cambridge university botanical gardens", + "cambridge university botanical garden", + "cambridge university botanic garden", + "cambridge botanic gardens", + "cambridge botanical gardens", + "cambridge botanic garden", + "cambridge botanical garden", + "botanical gardens", + "botanic garden", + "botanical garden" + ], + "cherry hinton village centre": [ + "cherry hinton village center" + ], + "cherry hinton village center": [ + "cherry hinton village centre" + ], + "cherry hinton hall and grounds": [ + "cherry hinton hall" + ], + "cherry hinton hall": [ + "cherry hinton hall and grounds" + ], + "cherry hinton water play": [ + "cherry hinton water play park" + ], + "cherry hinton water play park": [ + "cherry hinton water play" + ], + "christ college": [ + "christ's college", + "christs college" + ], + "christs college": [ + "christ college", + "christ's college" + ], + "churchills college": [ + "churchill's college", + "churchill college" + ], + "cineworld cinema": [ + "cineworld" + ], + "clair hall": [ + "clare hall" + ], + "clare hall": [ + "clair hall" + ], + "the fez club": [ + "fez club" + ], + "great saint marys church": [ + "great saint mary's church", + "great saint mary's", + "great saint marys" + ], + "jesus green outdoor pool": [ + "jesus green" + ], + "jesus green": [ + "jesus green outdoor pool" + ], + "kettles yard": [ + "kettle's yard" + ], + "kings college": [ + "king's college" + ], + "kings hedges learner pool": [ + "king's hedges learner pool", + "king hedges learner pool" + ], + "king hedges learner pool": [ + "king's hedges learner pool", + "kings hedges learner pool" + ], + "little saint marys church": [ + "little saint mary's church", + "little saint mary's", + "little saint marys" + ], + "mumford theatre": [ + "mumford theater" + ], + "museum of archaelogy": [ + "museum of archaeology" + ], + "museum of archaelogy and anthropology": [ + "museum of archaeology and anthropology" + ], + "peoples portraits exhibition": [ + "people's portraits exhibition at girton college", + "peoples portraits exhibition at girton college", + "people's portraits exhibition" + ], + "peoples portraits exhibition at girton college": [ + "people's portraits exhibition at girton college", + "people's portraits exhibition", + "peoples portraits exhibition" + ], + "queens college": [ + "queens' college", + "queen's college" + ], + "riverboat georgina": [ + "riverboat" + ], + "saint barnabas": [ + "saint barbabas press gallery" + ], + "saint barnabas press gallery": [ + "saint barbabas" + ], + "saint catharines college": [ + "saint catharine's college", + "saint catharine's", + "saint catherine's college", + "saint catherine's" + ], + "saint johns college": [ + "saint john's college", + "st john's college", + "st johns college" + ], + "scott polar": [ + "scott polar museum" + ], + "scott polar museum": [ + "scott polar" + ], + "scudamores punting co": [ + "scudamore's punting co", + "scudamores punting", + "scudamore's punting", + "scudamores", + "scudamore's", + "scudamore" + ], + "scudamore": [ + "scudamore's punting co", + "scudamores punting co", + "scudamores punting", + "scudamore's punting", + "scudamores", + "scudamore's" + ], + "sheeps green and lammas land park fen causeway": [ + "sheep's green and lammas land park fen causeway", + "sheep's green and lammas land park", + "sheeps green and lammas land park", + "lammas land park", + "sheep's green", + "sheeps green" + ], + "sheeps green and lammas land park": [ + "sheep's green and lammas land park fen causeway", + "sheeps green and lammas land park fen causeway", + "sheep's green and lammas land park", + "lammas land park", + "sheep's green", + "sheeps green" + ], + "lammas land park": [ + "sheep's green and lammas land park fen causeway", + "sheeps green and lammas land park fen causeway", + "sheep's green and lammas land park", + "sheeps green and lammas land park", + "sheep's green", + "sheeps green" + ], + "sheeps green": [ + "sheep's green and lammas land park fen causeway", + "sheeps green and lammas land park fen causeway", + "sheep's green and lammas land park", + "sheeps green and lammas land park", + "lammas land park", + "sheep's green" + ], + "soul tree nightclub": [ + "soul tree night club", + "soul tree", + "soultree" + ], + "soultree": [ + "soul tree nightclub", + "soul tree night club", + "soul tree" + ], + "the man on the moon": [ + "man on the moon" + ], + "man on the moon": [ + "the man on the moon" + ], + "the junction": [ + "junction theatre", + "junction theater" + ], + "junction theatre": [ + "the junction", + "junction theater" + ], + "old schools": [ + "old school" + ], + "vue cinema": [ + "vue" + ], + "wandlebury country park": [ + "the wandlebury" + ], + "the wandlebury": [ + "wandlebury country park" + ], + "whipple museum of the history of science": [ + "whipple museum", + "history of science museum" + ], + "history of science museum": [ + "whipple museum of the history of science", + "whipple museum" + ], + "williams art and antique": [ + "william's art and antique" + ], + "alimentum": [ + "restaurant alimentum" + ], + "restaurant alimentum": [ + "alimentum" + ], + "bedouin": [ + "the bedouin" + ], + "the bedouin": [ + "bedouin" + ], + "bloomsbury restaurant": [ + "bloomsbury" + ], + "cafe uno": [ + "caffe uno", + "caffee uno" + ], + "caffe uno": [ + "cafe uno", + "caffee uno" + ], + "caffee uno": [ + "cafe uno", + "caffe uno" + ], + "cambridge lodge restaurant": [ + "cambridge lodge" + ], + "chiquito": [ + "chiquito restaurant bar", + "chiquito restaurant" + ], + "chiquito restaurant bar": [ + "chiquito restaurant", + "chiquito" + ], + "city stop restaurant": [ + "city stop" + ], + "cityr": [ + "cityroomz" + ], + "citiroomz": [ + "cityroomz" + ], + "clowns cafe": [ + "clown's cafe" + ], + "cow pizza kitchen and bar": [ + "the cow pizza kitchen and bar", + "cow pizza" + ], + "the cow pizza kitchen and bar": [ + "cow pizza kitchen and bar", + "cow pizza" + ], + "darrys cookhouse and wine shop": [ + "darry's cookhouse and wine shop", + "darry's cookhouse", + "darrys cookhouse" + ], + "de luca cucina and bar": [ + "de luca cucina and bar riverside brasserie", + "luca cucina and bar", + "de luca cucina", + "luca cucina" + ], + "de luca cucina and bar riverside brasserie": [ + "de luca cucina and bar", + "luca cucina and bar", + "de luca cucina", + "luca cucina" + ], + "da vinci pizzeria": [ + "da vinci pizza", + "da vinci" + ], + "don pasquale pizzeria": [ + "don pasquale pizza", + "don pasquale", + "pasquale pizzeria", + "pasquale pizza" + ], + "efes": [ + "efes restaurant" + ], + "efes restaurant": [ + "efes" + ], + "fitzbillies restaurant": [ + "fitzbillies" + ], + "frankie and bennys": [ + "frankie and benny's" + ], + "funky": [ + "funky fun house" + ], + "funky fun house": [ + "funky" + ], + "gardenia": [ + "the gardenia" + ], + "the gardenia": [ + "gardenia" + ], + "grafton hotel restaurant": [ + "the grafton hotel", + "grafton hotel" + ], + "the grafton hotel": [ + "grafton hotel restaurant", + "grafton hotel" + ], + "grafton hotel": [ + "grafton hotel restaurant", + "the grafton hotel" + ], + "hotel du vin and bistro": [ + "hotel du vin", + "du vin" + ], + "Kohinoor": [ + "kohinoor", + "the kohinoor" + ], + "kohinoor": [ + "the kohinoor" + ], + "the kohinoor": [ + "kohinoor" + ], + "lan hong house": [ + "lan hong", + "ian hong house", + "ian hong" + ], + "ian hong": [ + "lan hong house", + "lan hong", + "ian hong house" + ], + "lovel": [ + "the lovell lodge", + "lovell lodge" + ], + "lovell lodge": [ + "lovell" + ], + "the lovell lodge": [ + "lovell lodge", + "lovell" + ], + "mahal of cambridge": [ + "mahal" + ], + "mahal": [ + "mahal of cambridge" + ], + "maharajah tandoori restaurant": [ + "maharajah tandoori" + ], + "the maharajah tandoor": [ + "maharajah tandoori restaurant", + "maharajah tandoori" + ], + "meze bar": [ + "meze bar restaurant", + "the meze bar" + ], + "meze bar restaurant": [ + "the meze bar", + "meze bar" + ], + "the meze bar": [ + "meze bar restaurant", + "meze bar" + ], + "michaelhouse cafe": [ + "michael house cafe" + ], + "midsummer house restaurant": [ + "midsummer house" + ], + "missing sock": [ + "the missing sock" + ], + "the missing sock": [ + "missing sock" + ], + "nandos": [ + "nando's city centre", + "nando's city center", + "nandos city centre", + "nandos city center", + "nando's" + ], + "nandos city centre": [ + "nando's city centre", + "nando's city center", + "nandos city center", + "nando's", + "nandos" + ], + "oak bistro": [ + "the oak bistro" + ], + "the oak bistro": [ + "oak bistro" + ], + "one seven": [ + "restaurant one seven" + ], + "restaurant one seven": [ + "one seven" + ], + "river bar steakhouse and grill": [ + "the river bar steakhouse and grill", + "the river bar steakhouse", + "river bar steakhouse" + ], + "the river bar steakhouse and grill": [ + "river bar steakhouse and grill", + "the river bar steakhouse", + "river bar steakhouse" + ], + "pipasha restaurant": [ + "pipasha" + ], + "pizza hut city centre": [ + "pizza hut city center" + ], + "pizza hut fenditton": [ + "pizza hut fen ditton", + "pizza express fen ditton" + ], + "restaurant two two": [ + "two two", + "restaurant 22" + ], + "saffron brasserie": [ + "saffron" + ], + "saint johns chop house": [ + "saint john's chop house", + "st john's chop house", + "st johns chop house" + ], + "sesame restaurant and bar": [ + "sesame restaurant", + "sesame" + ], + "shanghai": [ + "shanghai family restaurant" + ], + "shanghai family restaurant": [ + "shanghai" + ], + "sitar": [ + "sitar tandoori" + ], + "sitar tandoori": [ + "sitar" + ], + "slug and lettuce": [ + "the slug and lettuce" + ], + "the slug and lettuce": [ + "slug and lettuce" + ], + "st johns chop house": [ + "saint john's chop house", + "st john's chop house", + "saint johns chop house" + ], + "stazione restaurant and coffee bar": [ + "stazione restaurant", + "stazione" + ], + "thanh binh": [ + "thanh", + "binh" + ], + "thanh": [ + "thanh binh", + "binh" + ], + "binh": [ + "thanh binh", + "thanh" + ], + "the hotpot": [ + "the hotspot", + "hotpot", + "hotspot" + ], + "hotpot": [ + "the hotpot", + "the hotpot", + "hotspot" + ], + "the lucky star": [ + "lucky star" + ], + "lucky star": [ + "the lucky star" + ], + "the peking restaurant: ": [ + "peking restaurant" + ], + "the varsity restaurant": [ + "varsity restaurant", + "the varsity", + "varsity" + ], + "two two": [ + "restaurant two two", + "restaurant 22" + ], + "restaurant 22": [ + "restaurant two two", + "two two" + ], + "zizzi cambridge": [ + "zizzi" + ], + "american": [ + "americas" + ], + "asian oriental": [ + "asian", + "oriental" + ], + "australian": [ + "australasian" + ], + "barbeque": [ + "barbecue", + "bbq" + ], + "corsica": [ + "corsican" + ], + "indian": [ + "tandoori" + ], + "italian": [ + "pizza", + "pizzeria" + ], + "japanese": [ + "sushi" + ], + "latin american": [ + "latin-american", + "latin" + ], + "malaysian": [ + "malay" + ], + "middle eastern": [ + "middle-eastern" + ], + "traditional american": [ + "american" + ], + "modern american": [ + "american modern", + "american" + ], + "modern european": [ + "european modern", + "european" + ], + "north american": [ + "north-american", + "american" + ], + "portuguese": [ + "portugese" + ], + "portugese": [ + "portuguese" + ], + "seafood": [ + "sea food" + ], + "singaporean": [ + "singapore" + ], + "steakhouse": [ + "steak house", + "steak" + ], + "the americas": [ + "american", + "americas" + ], + "a and b guest house": [ + "a & b guest house", + "a and b", + "a & b" + ], + "the acorn guest house": [ + "acorn guest house", + "acorn" + ], + "acorn guest house": [ + "the acorn guest house", + "acorn" + ], + "alexander bed and breakfast": [ + "alexander" + ], + "allenbell": [ + "the allenbell" + ], + "the allenbell": [ + "allenbell" + ], + "alpha-milton guest house": [ + "the alpha-milton", + "alpha-milton" + ], + "the alpha-milton": [ + "alpha-milton guest house", + "alpha-milton" + ], + "arbury lodge guest house": [ + "arbury lodge", + "arbury" + ], + "archway house": [ + "archway" + ], + "ashley hotel": [ + "the ashley hotel", + "ashley" + ], + "the ashley hotel": [ + "ashley hotel", + "ashley" + ], + "aylesbray lodge guest house": [ + "aylesbray lodge", + "aylesbray" + ], + "aylesbray lodge guest": [ + "aylesbray lodge guest house", + "aylesbray lodge", + "aylesbray" + ], + "alesbray lodge guest house": [ + "aylesbray lodge guest house", + "aylesbray lodge", + "aylesbray" + ], + "alyesbray lodge hotel": [ + "aylesbray lodge guest house", + "aylesbray lodge", + "aylesbray" + ], + "bridge guest house": [ + "bridge house" + ], + "cambridge belfry": [ + "the cambridge belfry", + "belfry hotel", + "belfry" + ], + "the cambridge belfry": [ + "cambridge belfry", + "belfry hotel", + "belfry" + ], + "belfry hotel": [ + "the cambridge belfry", + "cambridge belfry", + "belfry" + ], + "carolina bed and breakfast": [ + "carolina" + ], + "city centre north": [ + "city centre north bed and breakfast" + ], + "north b and b": [ + "city centre north bed and breakfast" + ], + "city centre north b and b": [ + "city centre north bed and breakfast" + ], + "el shaddia guest house": [ + "el shaddai guest house", + "el shaddai", + "el shaddia" + ], + "el shaddai guest house": [ + "el shaddia guest house", + "el shaddai", + "el shaddia" + ], + "express by holiday inn cambridge": [ + "express by holiday inn", + "holiday inn cambridge", + "holiday inn" + ], + "holiday inn": [ + "express by holiday inn cambridge", + "express by holiday inn", + "holiday inn cambridge" + ], + "finches bed and breakfast": [ + "finches" + ], + "gonville hotel": [ + "gonville" + ], + "hamilton lodge": [ + "the hamilton lodge", + "hamilton" + ], + "the hamilton lodge": [ + "hamilton lodge", + "hamilton" + ], + "hobsons house": [ + "hobson's house", + "hobson's" + ], + "huntingdon marriott hotel": [ + "huntington marriott hotel", + "huntington marriot hotel", + "huntingdon marriot hotel", + "huntington marriott", + "huntingdon marriott", + "huntington marriot", + "huntingdon marriot", + "huntington", + "huntingdon" + ], + "kirkwood": [ + "kirkwood house" + ], + "kirkwood house": [ + "kirkwood" + ], + "lensfield hotel": [ + "the lensfield hotel", + "lensfield" + ], + "the lensfield hotel": [ + "lensfield hotel", + "lensfield" + ], + "leverton house": [ + "leverton" + ], + "marriot hotel": [ + "marriott hotel", + "marriott" + ], + "rosas bed and breakfast": [ + "rosa's bed and breakfast", + "rosa's", + "rosas" + ], + "university arms hotel": [ + "university arms" + ], + "warkworth house": [ + "warkworth hotel", + "warkworth" + ], + "warkworth hotel": [ + "warkworth house", + "warkworth" + ], + "wartworth": [ + "warkworth house", + "warkworth hotel", + "warkworth" + ], + "worth house": [ + "the worth house" + ], + "the worth house": [ + "worth house" + ], + "birmingham new street": [ + "birmingham new street train station" + ], + "birmingham new street train station": [ + "birmingham new street" + ], + "bishops stortford": [ + "bishops stortford train station" + ], + "bishops stortford train station": [ + "bishops stortford" + ], + "broxbourne": [ + "broxbourne train station" + ], + "broxbourne train station": [ + "broxbourne" + ], + "cambridge": [ + "cambridge train station" + ], + "cambridge train station": [ + "cambridge" + ], + "ely": [ + "ely train station" + ], + "ely train station": [ + "ely" + ], + "kings lynn": [ + "king's lynn", + "king's lynn train station", + "kings lynn train station" + ], + "kings lynn train station": [ + "kings lynn", + "king's lynn", + "king's lynn train station" + ], + "leicester": [ + "leicester train station" + ], + "leicester train station": [ + "leicester" + ], + "london kings cross": [ + "kings cross", + "king's cross", + "london king's cross", + "kings cross train station", + "king's cross train station", + "london king's cross train station", + "london kings cross train station" + ], + "london kings cross train station": [ + "kings cross", + "king's cross", + "london king's cross", + "london kings cross", + "kings cross train station", + "king's cross train station", + "london king's cross train station" + ], + "london liverpool": [ + "liverpool street", + "london liverpool street", + "london liverpool train station", + "liverpool street train station", + "london liverpool street train station" + ], + "london liverpool street": [ + "london liverpool", + "liverpool street", + "london liverpool train station", + "liverpool street train station", + "london liverpool street train station" + ], + "london liverpool street train station": [ + "london liverpool", + "liverpool street", + "london liverpool street", + "london liverpool train station", + "liverpool street train station" + ], + "norwich": [ + "norwich train station" + ], + "norwich train station": [ + "norwich" + ], + "peterborough": [ + "peterborough train station" + ], + "peterborough train station": [ + "peterborough" + ], + "stansted airport": [ + "stansted airport train station" + ], + "stansted airport train station": [ + "stansted airport" + ], + "stevenage": [ + "stevenage train station" + ], + "stevenage train station": [ + "stevenage" + ] + } +} diff --git a/dataset_config/woz2.json b/dataset_config/woz2.json new file mode 100644 index 0000000..f97eda8 --- /dev/null +++ b/dataset_config/woz2.json @@ -0,0 +1,228 @@ +{ + "class_types": [ + "none", + "dontcare", + "copy_value", + "inform" + ], + "slots": { + "area": "the area of the restaurant", + "food": "the food type of the restaurant", + "price_range": "the price range of the restaurant" + }, + "noncategorical": [ + "food" + ], + "categorical": [ + "area", + "price_range" + ], + "boolean": [], + "label_maps": { + "center": [ + "centre", + "downtown", + "central", + "down town", + "middle" + ], + "centre": [ + "center", + "downtown", + "central", + "down town", + "middle" + ], + "south": [ + "southern", + "southside" + ], + "north": [ + "northern", + "uptown", + "northside" + ], + "west": [ + "western", + "westside" + ], + "east": [ + "eastern", + "eastside" + ], + "east side": [ + "eastern", + "eastside" + ], + "cheap": [ + "low price", + "inexpensive", + "cheaper", + "low priced", + "affordable", + "nothing too expensive", + "without costing a fortune", + "cheapest", + "good deals", + "low prices", + "afford", + "on a budget", + "fair prices", + "less expensive", + "cheapeast", + "not cost an arm and a leg" + ], + "moderate": [ + "moderately", + "medium priced", + "medium price", + "fair price", + "fair prices", + "reasonable", + "reasonably priced", + "mid price", + "fairly priced", + "not outrageous", + "not too expensive", + "on a budget", + "mid range", + "reasonable priced", + "less expensive", + "not too pricey", + "nothing too expensive", + "nothing cheap", + "not overpriced", + "medium", + "inexpensive" + ], + "expensive": [ + "high priced", + "high end", + "high class", + "high quality", + "fancy", + "upscale", + "nice", + "fine dining", + "expensively priced", + "not some cheapie" + ], + "afghan": [ + "afghanistan" + ], + "african": [ + "africa" + ], + "asian oriental": [ + "asian", + "oriental" + ], + "australasian": [ + "australian asian", + "austral asian" + ], + "australian": [ + "aussie" + ], + "barbeque": [ + "barbecue", + "bbq" + ], + "basque": [ + "bask" + ], + "belgian": [ + "belgium" + ], + "british": [ + "cotto" + ], + "canapes": [ + "canopy", + "canape", + "canap" + ], + "catalan": [ + "catalonian" + ], + "corsican": [ + "corsica" + ], + "crossover": [ + "cross over", + "over" + ], + "gastropub": [ + "gastro pub", + "gastro", + "gastropubs" + ], + "hungarian": [ + "goulash" + ], + "indian": [ + "india", + "indians", + "nirala" + ], + "international": [ + "all types of food" + ], + "italian": [ + "prezzo" + ], + "jamaican": [ + "jamaica" + ], + "japanese": [ + "sushi", + "beni hana" + ], + "korean": [ + "korea" + ], + "lebanese": [ + "lebanse" + ], + "north american": [ + "american", + "hamburger" + ], + "portuguese": [ + "portugese" + ], + "seafood": [ + "sea food", + "shellfish", + "fish" + ], + "singaporean": [ + "singapore" + ], + "steakhouse": [ + "steak house", + "steak" + ], + "thai": [ + "thailand", + "bangkok" + ], + "traditional": [ + "old fashioned", + "plain" + ], + "turkish": [ + "turkey" + ], + "unusual": [ + "unique and strange" + ], + "venetian": [ + "vanessa" + ], + "vietnamese": [ + "vietnam", + "thanh binh" + ] + } +} diff --git a/dataset_multiwoz21.py b/dataset_multiwoz21.py new file mode 100644 index 0000000..d57f1c9 --- /dev/null +++ b/dataset_multiwoz21.py @@ -0,0 +1,680 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from tqdm import tqdm + +from utils_dst import (DSTExample, convert_to_unicode) + + +# Required for mapping slot names in dialogue_acts.json file +# to proper designations. +ACTS_DICT = {'taxi-depart': 'taxi-departure', + 'taxi-dest': 'taxi-destination', + 'taxi-leave': 'taxi-leaveAt', + 'taxi-arrive': 'taxi-arriveBy', + 'train-depart': 'train-departure', + 'train-dest': 'train-destination', + 'train-leave': 'train-leaveAt', + 'train-arrive': 'train-arriveBy', + 'train-people': 'train-book_people', + 'restaurant-price': 'restaurant-pricerange', + 'restaurant-people': 'restaurant-book_people', + 'restaurant-day': 'restaurant-book_day', + 'restaurant-time': 'restaurant-book_time', + 'hotel-price': 'hotel-pricerange', + 'hotel-people': 'hotel-book_people', + 'hotel-day': 'hotel-book_day', + 'hotel-stay': 'hotel-book_stay', + 'booking-people': 'booking-book_people', + 'booking-day': 'booking-book_day', + 'booking-stay': 'booking-book_stay', + 'booking-time': 'booking-book_time', + 'taxi-leaveat': 'taxi-leaveAt', + 'taxi-arriveby': 'taxi-arriveBy', + 'train-leaveat': 'train-leaveAt', + 'train-arriveby': 'train-arriveBy', + 'train-bookpeople': 'train-book_people', + 'restaurant-bookpeople': 'restaurant-book_people', + 'restaurant-bookday': 'restaurant-book_day', + 'restaurant-booktime': 'restaurant-book_time', + 'hotel-bookpeople': 'hotel-book_people', + 'hotel-bookday': 'hotel-book_day', + 'hotel-bookstay': 'hotel-book_stay', + 'booking-bookpeople': 'booking-book_people', + 'booking-bookday': 'booking-book_day', + 'booking-bookstay': 'booking-book_stay', + 'booking-booktime': 'booking-book_time' +} + + +def prediction_normalization(slot, value): + def _normalize_time(text): + informed = False + if text[:2] == '§§': + informed = True + text = text[2:] + text = re.sub("noon", r"12:00", text) # noon + text = re.sub("(\d{1})(a\.?m\.?|p\.?m\.?)", r"\1 \2", text) # am/pm without space + text = re.sub("(^| )(\d{1,2}) ?[^0-9]? ?(\d{2})", r"\1\2:\3", text) # Missing/wrong separator + text = re.sub("(^| )(\d{1,2})( |$)", r"\1\2:00\3", text) # normalize simple full hour time + text = re.sub("(^| )(\d{1}:\d{2})", r"\g<1>0\2", text) # Add missing leading 0 + # Map 12 hour times to 24 hour times + text = re.sub("(\d{2})(:\d{2}) ?p\.?m\.?", lambda x: str(int(x.groups()[0]) + 12 if int(x.groups()[0]) < 12 else int(x.groups()[0])) + x.groups()[1], text) + text = re.sub("(^| )24:(\d{2})", r"\g<1>00:\2", text) # Correct times that use 24 as hour + final = re.match(".*((before|after) \d{2}:\d{2})", text) + result = text + if final is not None: + result = final[1] + final = re.match(".*(\d{2}:\d{2})", text) + if final is not None: + result = final[1] + if informed and result[:2] != '§§': + return '§§' + result + return result + + def _normalize_value(text): + text = re.sub(" ?' ?s", "s", text) + return text + + if "leave" in slot or "arrive" in slot or "time" in slot: + if isinstance(value, list): + for e_itr in range(len(value)): + for ee_itr in range(len(value[e_itr])): + tmp = list(value[e_itr][ee_itr]) + tmp[0] = _normalize_time(tmp[0]) + value[e_itr][ee_itr] = tuple(tmp) + else: + value = _normalize_time(value) + else: + value = _normalize_value(value) + return value + + +def normalize_time(text): + text = re.sub("(\d{1})(a\.?m\.?|p\.?m\.?)", r"\1 \2", text) # am/pm without space + text = re.sub("(^| )(\d{1,2}) (a\.?m\.?|p\.?m\.?)", r"\1\2:00 \3", text) # am/pm short to long form + text = re.sub("(^| )(at|from|by|until|after) ?(\d{1,2}) ?(\d{2})([^0-9]|$)", r"\1\2 \3:\4\5", text) # Missing separator + text = re.sub("(^| )(\d{2})[;.,](\d{2})", r"\1\2:\3", text) # Wrong separator + text = re.sub("(^| )(at|from|by|until|after) ?(\d{1,2})([;., ]|$)", r"\1\2 \3:00\4", text) # normalize simple full hour time + text = re.sub("(^| )(\d{1}:\d{2})", r"\g<1>0\2", text) # Add missing leading 0 + # Map 12 hour times to 24 hour times + text = re.sub("(\d{2})(:\d{2}) ?p\.?m\.?", lambda x: str(int(x.groups()[0]) + 12 if int(x.groups()[0]) < 12 else int(x.groups()[0])) + x.groups()[1], text) + text = re.sub("(^| )24:(\d{2})", r"\g<1>00:\2", text) # Correct times that use 24 as hour + return text + + +def normalize_text(text): + text = normalize_time(text) + text = re.sub("n't", " not", text) + text = re.sub("(^| )zero(-| )star([s.,? ]|$)", r"\g<1>0 star\3", text) + text = re.sub("(^| )one(-| )star([s.,? ]|$)", r"\g<1>1 star\3", text) + text = re.sub("(^| )two(-| )star([s.,? ]|$)", r"\g<1>2 star\3", text) + text = re.sub("(^| )three(-| )star([s.,? ]|$)", r"\g<1>3 star\3", text) + text = re.sub("(^| )four(-| )star([s.,? ]|$)", r"\g<1>4 star\3", text) + text = re.sub("(^| )five(-| )star([s.,? ]|$)", r"\g<1>5 star\3", text) + text = re.sub("archaelogy", "archaeology", text) # Systematic typo + text = re.sub("guesthouse", "guest house", text) # Normalization + text = re.sub("(^| )b ?& ?b([.,? ]|$)", r"\1bed and breakfast\2", text) # Normalization + text = re.sub("bed & breakfast", "bed and breakfast", text) # Normalization + text = re.sub("\t", " ", text) # Error + text = re.sub("\n", " ", text) # Error + return text + + +# This should only contain label normalizations, no label mappings. +def normalize_label(slot, value_label, boolean_slots=True): + # Normalization of capitalization + if isinstance(value_label, str): + value_label = value_label.lower().strip() + elif isinstance(value_label, list): + if len(value_label) > 1: + value_label = value_label[0] # TODO: Workaround. Note that Multiwoz 2.2 supports variants directly in the labels. + elif len(value_label) == 1: + value_label = value_label[0] + elif len(value_label) == 0: + value_label = "" + + # Normalization of empty slots + if value_label == '' or value_label == "not mentioned": + return "none" + + # Normalization of 'dontcare' + if value_label == 'dont care': + return "dontcare" + + # Normalization of time slots + if "leave" in slot or "arrive" in slot or "time" in slot: + return normalize_time(value_label) + + # Normalization + if "type" in slot or "name" in slot or "destination" in slot or "departure" in slot: + value_label = re.sub(" ?'s", "s", value_label) + value_label = re.sub("guesthouse", "guest house", value_label) + + # Map to boolean slots + if slot in ["hotel-parking", "hotel-internet"]: + slot_name = slot.split("-")[1] + if value_label in ["yes", "free"]: + if boolean_slots: + return "true" + else: + return slot_name + if value_label == "no": + if boolean_slots: + return "false" + else: + return "no " + slot_name + if slot == "hotel-type": + if value_label in ["bed and breakfast", "guest houses"]: + value_label = "guest house" + if boolean_slots: + if value_label == "hotel": + return "true" + if value_label == "guest house": + return "false" + + return value_label + + +def get_token_pos(tok_list, value_label): + find_pos = [] + found = False + label_list = [item for item in map(str.strip, re.split("(\W+)", value_label)) if len(item) > 0] + len_label = len(label_list) + for i in range(len(tok_list) + 1 - len_label): + if tok_list[i:i + len_label] == label_list: + find_pos.append((i, i + len_label)) # start, exclusive_end + found = True + return found, find_pos + + +def check_label_existence(value_label, usr_utt_tok, label_maps={}): + in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label) + # If no hit even though there should be one, check for value label variants + if not in_usr and value_label in label_maps: + for value_label_variant in label_maps[value_label]: + in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label_variant) + if in_usr: + break + return in_usr, usr_pos + + +def check_slot_referral(value_label, slot, seen_slots, label_maps={}): + referred_slot = 'none' + if slot == 'hotel-stars' or slot == 'hotel-internet' or slot == 'hotel-parking': + return referred_slot + for s in seen_slots: + # Avoid matches for slots that share values with different meaning. + # hotel-internet and -parking are handled separately as Boolean slots. + if s == 'hotel-stars' or s == 'hotel-internet' or s == 'hotel-parking': + continue + if re.match("(hotel|restaurant)-book_people", s) and slot == 'hotel-book_stay': + continue + if re.match("(hotel|restaurant)-book_people", slot) and s == 'hotel-book_stay': + continue + if slot != s and (slot not in seen_slots or seen_slots[slot] != value_label): + if seen_slots[s] == value_label: + referred_slot = s + break + elif value_label in label_maps: + for value_label_variant in label_maps[value_label]: + if seen_slots[s] == value_label_variant: + referred_slot = s + break + return referred_slot + + +def is_in_list(tok, value): + found = False + tok_list = [item for item in map(str.strip, re.split("(\W+)", tok)) if len(item) > 0] + value_list = [item for item in map(str.strip, re.split("(\W+)", value)) if len(item) > 0] + tok_len = len(tok_list) + value_len = len(value_list) + for i in range(tok_len + 1 - value_len): + if tok_list[i:i + value_len] == value_list: + found = True + break + return found + + +def delex_utt(utt, values, unk_token="[UNK]"): + utt_norm = tokenize(utt) + for s, vals in values.items(): + for v in vals: + if v != 'none': + v_norm = tokenize(v) + v_len = len(v_norm) + for i in range(len(utt_norm) + 1 - v_len): + if utt_norm[i:i + v_len] == v_norm: + utt_norm[i:i + v_len] = [unk_token] * v_len + return utt_norm + + +# Fuzzy matching to label informed slot values +def check_slot_inform(value_label, inform_label, label_maps={}): + result = False + informed_value = 'none' + vl = ' '.join(tokenize(value_label)) + for il in inform_label: + if vl == il: + result = True + elif is_in_list(il, vl): + result = True + elif is_in_list(vl, il): + result = True + elif il in label_maps: + for il_variant in label_maps[il]: + if vl == il_variant: + result = True + break + elif is_in_list(il_variant, vl): + result = True + break + elif is_in_list(vl, il_variant): + result = True + break + elif vl in label_maps: + for value_label_variant in label_maps[vl]: + if value_label_variant == il: + result = True + break + elif is_in_list(il, value_label_variant): + result = True + break + elif is_in_list(value_label_variant, il): + result = True + break + if result: + informed_value = il + break + return result, informed_value + + +def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, seen_slots, slot_last_occurrence, label_maps={}): + usr_utt_tok_label = [0 for _ in usr_utt_tok] + informed_value = 'none' + referred_slot = 'none' + if value_label == 'none' or value_label == 'dontcare' or value_label == 'true' or value_label == 'false': + class_type = value_label + else: + in_usr, usr_pos = check_label_existence(value_label, usr_utt_tok, label_maps) + is_informed, informed_value = check_slot_inform(value_label, inform_label, label_maps) + if in_usr: + class_type = 'copy_value' + if slot_last_occurrence: + (s, e) = usr_pos[-1] + for i in range(s, e): + usr_utt_tok_label[i] = 1 + else: + for (s, e) in usr_pos: + for i in range(s, e): + usr_utt_tok_label[i] = 1 + elif is_informed: + class_type = 'inform' + else: + referred_slot = check_slot_referral(value_label, slot, seen_slots, label_maps) + if referred_slot != 'none': + class_type = 'refer' + else: + class_type = 'unpointable' + return informed_value, referred_slot, usr_utt_tok_label, class_type + + +# Requestable slots, general acts and domain indicator slots +def is_request(slot, user_act, turn_domains): + if slot in user_act: + if isinstance(user_act[slot], list): + for act in user_act[slot]: + if act['intent'] in ['request', 'bye', 'thank', 'greet']: + return True + elif user_act[slot]['intent'] in ['request', 'bye', 'thank', 'greet']: + return True + do, sl = slot.split('-') + if sl == 'none' and do in turn_domains: + return True + return False + + +def tokenize(utt): + utt_lower = convert_to_unicode(utt).lower() + utt_lower = normalize_text(utt_lower) + utt_tok = utt_to_token(utt_lower) + return utt_tok + + +def utt_to_token(utt): + return [tok for tok in map(lambda x: re.sub(" ", "", x), re.split("(\W+)", utt)) if len(tok) > 0] + + +def create_examples(input_file, set_type, class_types, slot_list, + label_maps={}, + no_label_value_repetitions=False, + swap_utterances=False, + delexicalize_sys_utts=False, + unk_token="[UNK]", + boolean_slots=True, + analyze=False): + """Read a DST json file into a list of DSTExample.""" + + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader) + + examples = [] + for d_itr, dialog_id in enumerate(tqdm(input_data)): + entry = input_data[dialog_id] + utterances = entry['log'] + + # Collects all slot changes throughout the dialog + cumulative_labels = {slot: 'none' for slot in slot_list} + + # First system utterance is empty, since multiwoz starts with user input + utt_tok_list = [[]] + mod_slots_list = [{}] + inform_dict_list = [{}] + user_act_dict_list = [{}] + mod_domains_list = [{}] + + # Collect all utterances and their metadata + usr_sys_switch = True + turn_itr = 0 + for utt in utterances: + # Assert that system and user utterances alternate + is_sys_utt = utt['metadata'] != {} + if usr_sys_switch == is_sys_utt: + print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id)) + break + usr_sys_switch = is_sys_utt + + if is_sys_utt: + turn_itr += 1 + + # Extract dialog_act information for sys and usr utts. + inform_dict = {} + user_act_dict = {} + modified_slots = {} + modified_domains = set() + if 'dialog_act' in utt: + for a in utt['dialog_act']: + aa = a.lower().split('-') + for i in utt['dialog_act'][a]: + s = i[0].lower() + # Some special intents are modeled as slots in TripPy + if aa[0] == 'general': + cs = "%s-%s" % (aa[0], aa[1]) + else: + cs = "%s-%s" % (aa[0], s) + if cs in ACTS_DICT: + cs = ACTS_DICT[cs] + v = normalize_label(cs, i[1].lower().strip()) + if cs in ['hotel-internet', 'hotel-parking']: + v = 'true' + modified_domains.add(aa[0]) # Remember domains + if is_sys_utt and aa[1] in ['inform', 'recommend', 'select', 'book'] and v != 'none': + if cs not in inform_dict: + inform_dict[cs] = [] + inform_dict[cs].append(v) + elif not is_sys_utt: + if cs not in user_act_dict: + user_act_dict[cs] = [] + user_act_dict[cs] = {'domain': aa[0], 'intent': aa[1], 'slot': s, 'value': v} + # INFO: Since the model has no mechanism to predict + # one among several informed value candidates, we + # keep only one informed value. For fairness, we + # apply a global rule: + for e in inform_dict: + # ... Option 1: Always keep first informed value + inform_dict[e] = list([inform_dict[e][0]]) + # ... Option 2: Always keep last informed value + #inform_dict[e] = list([inform_dict[e][-1]]) + else: + print("WARN: dialogue %s is missing dialog_act information." % dialog_id) + + # If sys utt, extract metadata (identify and collect modified slots) + if is_sys_utt: + for d in utt['metadata']: + booked = utt['metadata'][d]['book']['booked'] + booked_slots = {} + # Check the booked section + if booked != []: + for s in booked[0]: + booked_slots[s] = normalize_label('%s-%s' % (d, s), booked[0][s], boolean_slots) + # Check the semi and the inform slots + for category in ['book', 'semi']: + for s in utt['metadata'][d][category]: + cs = '%s-book_%s' % (d, s) if category == 'book' else '%s-%s' % (d, s) + value_label = normalize_label(cs, utt['metadata'][d][category][s], boolean_slots) + # Prefer the slot value as stored in the booked section + if s in booked_slots: + value_label = booked_slots[s] + # Remember modified slots and entire dialog state + if cs in slot_list and cumulative_labels[cs] != value_label: + modified_slots[cs] = value_label + cumulative_labels[cs] = value_label + modified_domains.add(cs.split("-")[0]) # Remember domains + + # Delexicalize sys utterance + if delexicalize_sys_utts and is_sys_utt: + utt_tok_list.append(delex_utt(utt['text'], inform_dict, unk_token)) # normalizes utterances + else: + utt_tok_list.append(tokenize(utt['text'])) # normalizes utterances + + inform_dict_list.append(inform_dict.copy()) + user_act_dict_list.append(user_act_dict.copy()) + mod_slots_list.append(modified_slots.copy()) + modified_domains = list(modified_domains) + modified_domains.sort() + mod_domains_list.append(modified_domains) + + # Form proper (usr, sys) turns + turn_itr = 0 + diag_seen_slots_dict = {} + diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list} + diag_state = {slot: 'none' for slot in slot_list} + sys_utt_tok = [] + usr_utt_tok = [] + for i in range(1, len(utt_tok_list) - 1, 2): + sys_utt_tok_label_dict = {} + usr_utt_tok_label_dict = {} + value_dict = {} + inform_dict = {} + inform_slot_dict = {} + referral_dict = {} + class_type_dict = {} + updated_slots = {slot: 0 for slot in slot_list} + + # Collect turn data + sys_utt_tok = utt_tok_list[i - 1] + usr_utt_tok = utt_tok_list[i] + turn_slots = mod_slots_list[i + 1] + inform_mem = inform_dict_list[i - 1] + user_act = user_act_dict_list[i] + turn_domains = mod_domains_list[i + 1] + + guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr)) + + if analyze: + print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, ' '.join(sys_utt_tok), ' '.join(usr_utt_tok))) + print("%15s %2s [" % (dialog_id, turn_itr), end='') + + new_diag_state = diag_state.copy() + + for slot in slot_list: + value_label = 'none' + if slot in turn_slots: + value_label = turn_slots[slot] + # We keep the original labels so as to not + # overlook unpointable values, as well as to not + # modify any of the original labels for test sets, + # since this would make comparison difficult. + value_dict[slot] = value_label + elif not no_label_value_repetitions and slot in diag_seen_slots_dict: + value_label = diag_seen_slots_value_dict[slot] + + # Get dialog act annotations + inform_label = list(['none']) + inform_slot_dict[slot] = 0 + booking_slot = 'booking-' + slot.split('-')[1] + if slot in inform_mem: + inform_label = inform_mem[slot] + inform_slot_dict[slot] = 1 + elif booking_slot in inform_mem: + inform_label = inform_mem[booking_slot] + inform_slot_dict[slot] = 1 + + (informed_value, + referred_slot, + usr_utt_tok_label, + class_type) = get_turn_label(value_label, + inform_label, + sys_utt_tok, + usr_utt_tok, + slot, + diag_seen_slots_value_dict, + slot_last_occurrence=True, + label_maps=label_maps) + + inform_dict[slot] = informed_value + + # Requestable slots, domain indicator slots and general slots + # should have class_type 'request', if they ought to be predicted. + # Give other class_types preference. + if 'request' in class_types: + if class_type in ['none', 'unpointable'] and is_request(slot, user_act, turn_domains): + class_type = 'request' + + # Generally don't use span prediction on sys utterance (but inform prediction instead). + sys_utt_tok_label = [0 for _ in sys_utt_tok] + + # Determine what to do with value repetitions. + # If value is unique in seen slots, then tag it, otherwise not, + # since correct slot assignment can not be guaranteed anymore. + if not no_label_value_repetitions and slot in diag_seen_slots_dict: + if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1: + class_type = 'none' + usr_utt_tok_label = [0 for _ in usr_utt_tok_label] + + sys_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label_dict[slot] = usr_utt_tok_label + + if diag_seen_slots_value_dict[slot] != value_label: + updated_slots[slot] = 1 + + # For now, we map all occurences of unpointable slot values + # to none. However, since the labels will still suggest + # a presence of unpointable slot values, the task of the + # DST is still to find those values. It is just not + # possible to do that via span prediction on the current input. + if class_type == 'unpointable': + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + if analyze: + if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]: + print("(%s): %s, " % (slot, value_label), end='') + elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform': + # If slot has seen before and its class type did not change, label this slot a not present, + # assuming that the slot has not actually been mentioned in this turn. + # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform, + # this must mean there is evidence in the original labels, therefore consider + # them as mentioned again. + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + else: + class_type_dict[slot] = class_type + referral_dict[slot] = referred_slot + # Remember that this slot was mentioned during this dialog already. + if class_type != 'none': + diag_seen_slots_dict[slot] = class_type + diag_seen_slots_value_dict[slot] = value_label + new_diag_state[slot] = class_type + # Unpointable is not a valid class, therefore replace with + # some valid class for now... + if class_type == 'unpointable': + new_diag_state[slot] = 'copy_value' + + if analyze: + print("]") + + if not swap_utterances: + txt_a = usr_utt_tok + txt_b = sys_utt_tok + txt_a_lbl = usr_utt_tok_label_dict + txt_b_lbl = sys_utt_tok_label_dict + else: + txt_a = sys_utt_tok + txt_b = usr_utt_tok + txt_a_lbl = sys_utt_tok_label_dict + txt_b_lbl = usr_utt_tok_label_dict + examples.append(DSTExample( + guid=guid, + text_a=txt_a, + text_b=txt_b, + text_a_label=txt_a_lbl, + text_b_label=txt_b_lbl, + values=diag_seen_slots_value_dict.copy(), + inform_label=inform_dict, + inform_slot_label=inform_slot_dict, + refer_label=referral_dict, + diag_state=diag_state, + slot_update=updated_slots, + class_label=class_type_dict)) + + # Update some variables. + diag_state = new_diag_state.copy() + + turn_itr += 1 + + if analyze: + print("----------------------------------------------------------------------") + + return examples + + +def get_value_list(input_file, slot_list, boolean_slots=True): + exclude = ['none', 'dontcare'] + if not boolean_slots: + exclude += ['true', 'false'] + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader) + value_list = {slot: {} for slot in slot_list} + for dialog_id in input_data: + for utt in input_data[dialog_id]['log']: + if utt['metadata'] != {}: + for d in utt['metadata']: + booked = utt['metadata'][d]['book']['booked'] + booked_slots = {} + # Check the booked section + if booked != []: + for s in booked[0]: + booked_slots[s] = normalize_label('%s-%s' % (d, s), booked[0][s], boolean_slots) + # Check the semi and the inform slots + for category in ['book', 'semi']: + for s in utt['metadata'][d][category]: + cs = '%s-book_%s' % (d, s) if category == 'book' else '%s-%s' % (d, s) + value_label = normalize_label(cs, utt['metadata'][d][category][s], boolean_slots) + # Prefer the slot value as stored in the booked section + if s in booked_slots: + value_label = booked_slots[s] + if cs in slot_list and value_label not in exclude: + if "|" not in value_label and "<" not in value_label and ">" not in value_label: + if value_label not in value_list[cs]: + value_list[cs][value_label] = 0 + value_list[cs][value_label] += 1 + return value_list diff --git a/dataset_multiwoz21_legacy.py b/dataset_multiwoz21_legacy.py new file mode 100644 index 0000000..c6731a8 --- /dev/null +++ b/dataset_multiwoz21_legacy.py @@ -0,0 +1,358 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from tqdm import tqdm + +from utils_dst import (DSTExample) + +from dataset_multiwoz21 import (ACTS_DICT, + tokenize, normalize_label, + get_turn_label, delex_utt) + + +# Loads the dialogue_acts.json and returns a list +# of slot-value pairs. +def load_acts(input_file, boolean_slots=True): + with open(input_file) as f: + acts = json.load(f) + s_dict = {} + for d in acts: + for t in acts[d]: + # Only process, if turn has annotation + if isinstance(acts[d][t], dict): + is_22_format = False + if 'dialog_act' in acts[d][t]: + is_22_format = True + acts_list = acts[d][t]['dialog_act'] + if int(t) % 2 == 0: + continue + else: + acts_list = acts[d][t] + for a in acts_list: + aa = a.lower().split('-') + if aa[1] in ['inform', 'recommend', 'select', 'book']: + for i in acts_list[a]: + s = i[0].lower() + v = i[1].lower().strip() + if s == 'none' or v == '?': + continue + if v == 'none': + if s in ['parking', 'internet']: + if boolean_slots: + v = 'true' + else: + v = s + else: + continue + if d == 'hotel' and s == 'type': + if v in ['hotel', 'hotels']: + if boolean_slots: + v = 'true' + else: + v = 'hotel' + else: + if boolean_slots: + v = 'false' + else: + v = 'guest house' + slot = aa[0] + '-' + s + if slot in ACTS_DICT: + slot = ACTS_DICT[slot] + if is_22_format: + t_key = str(int(int(t) / 2 + 1)) + d_key = d + else: + t_key = t + d_key = d + '.json' + key = d_key, t_key, slot + # INFO: Since the model has no mechanism to predict + # one among several informed value candidates, we + # keep only one informed value. For fairness, we + # apply a global rule: + # ... Option 1: Keep first informed value + if key not in s_dict: + s_dict[key] = list([v]) + # ... Option 2: Keep last informed value + #s_dict[key] = list([v]) + return s_dict + + +def create_examples(input_file, acts_file, set_type, slot_list, + label_maps={}, + no_label_value_repetitions=False, + swap_utterances=False, + delexicalize_sys_utts=False, + unk_token="[UNK]", + boolean_slots=True, + analyze=False): + """Read a DST json file into a list of DSTExample.""" + + sys_inform_dict = load_acts(acts_file, boolean_slots) + + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader) + + examples = [] + for d_itr, dialog_id in enumerate(tqdm(input_data)): + entry = input_data[dialog_id] + utterances = entry['log'] + + # Collects all slot changes throughout the dialog + cumulative_labels = {slot: 'none' for slot in slot_list} + + # First system utterance is empty, since multiwoz starts with user input + utt_tok_list = [[]] + mod_slots_list = [{}] + + # Collect all utterances and their metadata + usr_sys_switch = True + turn_itr = 0 + for utt in utterances: + # Assert that system and user utterances alternate + is_sys_utt = utt['metadata'] != {} + if usr_sys_switch == is_sys_utt: + print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id)) + break + usr_sys_switch = is_sys_utt + + if is_sys_utt: + turn_itr += 1 + + # Delexicalize sys utterance + if delexicalize_sys_utts and is_sys_utt: + inform_dict = {slot: 'none' for slot in slot_list} + for slot in slot_list: + if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict: + inform_dict[slot] = sys_inform_dict[(str(dialog_id), str(turn_itr), slot)] + utt_tok_list.append(delex_utt(utt['text'], inform_dict, unk_token)) # normalizes utterances + else: + utt_tok_list.append(tokenize(utt['text'])) # normalizes utterances + + modified_slots = {} + + # If sys utt, extract metadata (identify and collect modified slots) + if is_sys_utt: + for d in utt['metadata']: + booked = utt['metadata'][d]['book']['booked'] + booked_slots = {} + # Check the booked section + if booked != []: + for s in booked[0]: + booked_slots[s] = normalize_label('%s-%s' % (d, s), booked[0][s], boolean_slots) # normalize labels + # Check the semi and the inform slots + for category in ['book', 'semi']: + for s in utt['metadata'][d][category]: + cs = '%s-book_%s' % (d, s) if category == 'book' else '%s-%s' % (d, s) + value_label = normalize_label(cs, utt['metadata'][d][category][s], boolean_slots) # normalize labels + # Prefer the slot value as stored in the booked section + if s in booked_slots: + value_label = booked_slots[s] + # Remember modified slots and entire dialog state + if cs in slot_list and cumulative_labels[cs] != value_label: + modified_slots[cs] = value_label + cumulative_labels[cs] = value_label + + mod_slots_list.append(modified_slots.copy()) + + # Form proper (usr, sys) turns + turn_itr = 0 + diag_seen_slots_dict = {} + diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list} + diag_state = {slot: 'none' for slot in slot_list} + sys_utt_tok = [] + usr_utt_tok = [] + for i in range(1, len(utt_tok_list) - 1, 2): + sys_utt_tok_label_dict = {} + usr_utt_tok_label_dict = {} + value_dict = {} + inform_dict = {} + inform_slot_dict = {} + referral_dict = {} + class_type_dict = {} + updated_slots = {slot: 0 for slot in slot_list} + + # Collect turn data + sys_utt_tok = utt_tok_list[i - 1] + usr_utt_tok = utt_tok_list[i] + turn_slots = mod_slots_list[i + 1] + + guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr)) + + if analyze: + print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, ' '.join(sys_utt_tok), ' '.join(usr_utt_tok))) + print("%15s %2s [" % (dialog_id, turn_itr), end='') + + new_diag_state = diag_state.copy() + + for slot in slot_list: + value_label = 'none' + if slot in turn_slots: + value_label = turn_slots[slot] + # We keep the original labels so as to not + # overlook unpointable values, as well as to not + # modify any of the original labels for test sets, + # since this would make comparison difficult. + value_dict[slot] = value_label + elif not no_label_value_repetitions and slot in diag_seen_slots_dict: + value_label = diag_seen_slots_value_dict[slot] + + # Get dialog act annotations + inform_label = list(['none']) + inform_slot_dict[slot] = 0 + if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict: + inform_label = list([normalize_label(slot, i, boolean_slots) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]]) + inform_slot_dict[slot] = 1 + elif (str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1]) in sys_inform_dict: + inform_label = list([normalize_label(slot, i, boolean_slots) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1])]]) + inform_slot_dict[slot] = 1 + + (informed_value, + referred_slot, + usr_utt_tok_label, + class_type) = get_turn_label(value_label, + inform_label, + sys_utt_tok, + usr_utt_tok, + slot, + diag_seen_slots_value_dict, + slot_last_occurrence=True, + label_maps=label_maps) + + inform_dict[slot] = informed_value + + # Generally don't use span prediction on sys utterance (but inform prediction instead). + sys_utt_tok_label = [0 for _ in sys_utt_tok] + + # Determine what to do with value repetitions. + # If value is unique in seen slots, then tag it, otherwise not, + # since correct slot assignment can not be guaranteed anymore. + if not no_label_value_repetitions and slot in diag_seen_slots_dict: + if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1: + class_type = 'none' + usr_utt_tok_label = [0 for _ in usr_utt_tok_label] + + sys_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label_dict[slot] = usr_utt_tok_label + + if diag_seen_slots_value_dict[slot] != value_label: + updated_slots[slot] = 1 + + # For now, we map all occurences of unpointable slot values + # to none. However, since the labels will still suggest + # a presence of unpointable slot values, the task of the + # DST is still to find those values. It is just not + # possible to do that via span prediction on the current input. + if class_type == 'unpointable': + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + if analyze: + if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]: + print("(%s): %s, " % (slot, value_label), end='') + elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform': + # If slot has seen before and its class type did not change, label this slot a not present, + # assuming that the slot has not actually been mentioned in this turn. + # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform, + # this must mean there is evidence in the original labels, therefore consider + # them as mentioned again. + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + else: + class_type_dict[slot] = class_type + referral_dict[slot] = referred_slot + # Remember that this slot was mentioned during this dialog already. + if class_type != 'none': + diag_seen_slots_dict[slot] = class_type + diag_seen_slots_value_dict[slot] = value_label + new_diag_state[slot] = class_type + # Unpointable is not a valid class, therefore replace with + # some valid class for now... + if class_type == 'unpointable': + new_diag_state[slot] = 'copy_value' + + if analyze: + print("]") + + if not swap_utterances: + txt_a = usr_utt_tok + txt_b = sys_utt_tok + txt_a_lbl = usr_utt_tok_label_dict + txt_b_lbl = sys_utt_tok_label_dict + else: + txt_a = sys_utt_tok + txt_b = usr_utt_tok + txt_a_lbl = sys_utt_tok_label_dict + txt_b_lbl = usr_utt_tok_label_dict + examples.append(DSTExample( + guid=guid, + text_a=txt_a, + text_b=txt_b, + text_a_label=txt_a_lbl, + text_b_label=txt_b_lbl, + values=diag_seen_slots_value_dict.copy(), + inform_label=inform_dict, + inform_slot_label=inform_slot_dict, + refer_label=referral_dict, + diag_state=diag_state, + slot_update=updated_slots, + class_label=class_type_dict)) + + # Update some variables. + diag_state = new_diag_state.copy() + + turn_itr += 1 + + if analyze: + print("----------------------------------------------------------------------") + + return examples + + +def get_value_list(input_file, slot_list, boolean_slots=True): + exclude = ['none', 'dontcare'] + if not boolean_slots: + exclude += ['true', 'false'] + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader) + value_list = {slot: {} for slot in slot_list} + for dialog_id in input_data: + for utt in input_data[dialog_id]['log']: + if utt['metadata'] != {}: + for d in utt['metadata']: + booked = utt['metadata'][d]['book']['booked'] + booked_slots = {} + # Check the booked section + if booked != []: + for s in booked[0]: + booked_slots[s] = normalize_label('%s-%s' % (d, s), booked[0][s], boolean_slots) + # Check the semi and the inform slots + for category in ['book', 'semi']: + for s in utt['metadata'][d][category]: + cs = '%s-book_%s' % (d, s) if category == 'book' else '%s-%s' % (d, s) + value_label = normalize_label(cs, utt['metadata'][d][category][s], boolean_slots) + # Prefer the slot value as stored in the booked section + if s in booked_slots: + value_label = booked_slots[s] + if cs in slot_list and value_label not in exclude: + if "|" not in value_label and "<" not in value_label and ">" not in value_label: + if value_label not in value_list[cs]: + value_list[cs][value_label] = 0 + value_list[cs][value_label] += 1 + return value_list diff --git a/dataset_sim.py b/dataset_sim.py new file mode 100644 index 0000000..74ab3ba --- /dev/null +++ b/dataset_sim.py @@ -0,0 +1,268 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from utils_dst import (DSTExample) + + +# Loads the dialogue_acts.json and returns a list +# of slot-value pairs. +def load_acts(input_file): + with open(input_file) as f: + acts = json.load(f) + s_dict = {} + for d in acts: + d_id = d["dialogue_id"] + for t_id, t in enumerate(d["turns"]): + # Only process, if turn has annotation + if "system_acts" in t: + for a in t["system_acts"]: + if "value" in a and a["type"] not in ["NEGATE", "NOTIFY_FAILURE"]: + key = d_id, t_id, a["slot"] + # In case of multiple mentioned values... + # ... Option 1: Keep first informed value + if key not in s_dict: + s_dict[key] = a["value"] + # ... Option 2: Keep last informed value + #s_dict[key] = a["value"] + return s_dict + + +def dialogue_state_to_sv_dict(sv_list): + sv_dict = {} + for d in sv_list: + sv_dict[d['slot']] = d['value'] + return sv_dict + + +def get_token_and_slot_label(turn): + if 'system_utterance' in turn: + sys_utt_tok = turn['system_utterance']['tokens'] + sys_slot_label = turn['system_utterance']['slots'] + else: + sys_utt_tok = [] + sys_slot_label = [] + + usr_utt_tok = turn['user_utterance']['tokens'] + usr_slot_label = turn['user_utterance']['slots'] + return sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label + + +def get_tok_label(prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok, + sys_slot_label, usr_utt_tok, usr_slot_label, dial_id, + turn_id, slot_last_occurrence=True): + """The position of the last occurrence of the slot value will be used.""" + sys_utt_tok_label = [0 for _ in sys_utt_tok] + usr_utt_tok_label = [0 for _ in usr_utt_tok] + if slot_type not in cur_ds_dict: + class_type = 'none' + else: + value = cur_ds_dict[slot_type] + if value == 'dontcare' and (slot_type not in prev_ds_dict or prev_ds_dict[slot_type] != 'dontcare'): + # Only label dontcare at its first occurrence in the dialog + class_type = 'dontcare' + else: # If not none or dontcare, we have to identify whether + # there is a span, or if the value is informed + in_usr = False + in_sys = False + for label_d in usr_slot_label: + if label_d['slot'] == slot_type and value == ' '.join( + usr_utt_tok[label_d['start']:label_d['exclusive_end']]): + for idx in range(label_d['start'], label_d['exclusive_end']): + usr_utt_tok_label[idx] = 1 + in_usr = True + class_type = 'copy_value' + if slot_last_occurrence: + break + + for label_d in sys_slot_label: + if label_d['slot'] == slot_type and value == ' '.join( + sys_utt_tok[label_d['start']:label_d['exclusive_end']]): + for idx in range(label_d['start'], label_d['exclusive_end']): + sys_utt_tok_label[idx] = 1 + in_sys = True + if not in_usr or not slot_last_occurrence: + class_type = 'inform' + if slot_last_occurrence: + break + + if not in_usr and not in_sys: + assert sum(usr_utt_tok_label + sys_utt_tok_label) == 0 + if (slot_type not in prev_ds_dict or value != prev_ds_dict[slot_type]): + raise ValueError('Copy value cannot found in Dial %s Turn %s' % (str(dial_id), str(turn_id))) + else: + class_type = 'none' + else: + assert sum(usr_utt_tok_label + sys_utt_tok_label) > 0 + return sys_utt_tok_label, usr_utt_tok_label, class_type + + +def delex_utt(utt, values, unk_token="[UNK]"): + utt_delex = utt.copy() + for v in values: + utt_delex[v['start']:v['exclusive_end']] = [unk_token] * (v['exclusive_end'] - v['start']) + return utt_delex + + +def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, sys_inform_dict, + delexicalize_sys_utts=False, unk_token="[UNK]", slot_last_occurrence=True): + """Make turn_label a dictionary of slot with value positions or being dontcare / none: + Turn label contains: + (1) the updates from previous to current dialogue state, + (2) values in current dialogue state explicitly mentioned in system or user utterance.""" + prev_ds_dict = dialogue_state_to_sv_dict(prev_dialogue_state) + cur_ds_dict = dialogue_state_to_sv_dict(turn['dialogue_state']) + + (sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label) = get_token_and_slot_label(turn) + + sys_utt_tok_label_dict = {} + usr_utt_tok_label_dict = {} + inform_label_dict = {} + inform_slot_label_dict = {} + referral_label_dict = {} + class_type_dict = {} + updated_slots_dict = {} + + for slot_type in slot_list: + updated_slots_dict[slot_type] = 0 + if slot_type in cur_ds_dict and slot_type in prev_ds_dict and cur_ds_dict[slot_type] != prev_ds_dict[slot_type]: + updated_slots_dict[slot_type] = 1 + inform_label_dict[slot_type] = 'none' + inform_slot_label_dict[slot_type] = 0 + referral_label_dict[slot_type] = 'none' # Referral is not present in sim data + sys_utt_tok_label, usr_utt_tok_label, class_type = get_tok_label( + prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok, sys_slot_label, + usr_utt_tok, usr_slot_label, dial_id, turn_id, + slot_last_occurrence=slot_last_occurrence) + if (dial_id, turn_id, slot_type) in sys_inform_dict: + inform_label_dict[slot_type] = sys_inform_dict[(dial_id, turn_id, slot_type)] + inform_slot_label_dict[slot_type] = 1 + if class_type == 'inform' and inform_label_dict[slot_type] != cur_ds_dict[slot_type]: + class_type = 'none' + sys_utt_tok_label = [0 for _ in sys_utt_tok_label] # Don't use token labels for sys utt + sys_utt_tok_label_dict[slot_type] = sys_utt_tok_label + usr_utt_tok_label_dict[slot_type] = usr_utt_tok_label + class_type_dict[slot_type] = class_type + + if delexicalize_sys_utts: + sys_utt_tok = delex_utt(sys_utt_tok, sys_slot_label, unk_token) + + return (sys_utt_tok, sys_utt_tok_label_dict, + usr_utt_tok, usr_utt_tok_label_dict, + inform_label_dict, inform_slot_label_dict, + referral_label_dict, cur_ds_dict, class_type_dict, + updated_slots_dict) + + +def create_examples(input_file, set_type, slot_list, + no_label_value_repetitions=False, + swap_utterances=False, + delexicalize_sys_utts=False, + unk_token="[UNK]", + boolean_slots=True, + analyze=False): + """Read a DST json file into a list of DSTExample.""" + + sys_inform_dict = load_acts(input_file) + + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader) + + examples = [] + for entry in input_data: + dial_id = entry['dialogue_id'] + prev_ds = [] + prev_ds_lbl_dict = {slot: 'none' for slot in slot_list} + for turn_id, turn in enumerate(entry['turns']): + guid = '%s-%s-%s' % (set_type, dial_id, str(turn_id)) + ds_lbl_dict = prev_ds_lbl_dict.copy() + (text_a, + text_a_label, + text_b, + text_b_label, + inform_label, + inform_slot_label, + referral_label, + cur_ds_dict, + class_label, + updated_slots) = get_turn_label(turn, + prev_ds, + slot_list, + dial_id, + turn_id, + sys_inform_dict, + delexicalize_sys_utts=delexicalize_sys_utts, + unk_token=unk_token, + slot_last_occurrence=True) + + if not swap_utterances: + txt_a = text_b + txt_b = text_a + txt_a_lbl = text_b_label + txt_b_lbl = text_a_label + else: + txt_a = text_a + txt_b = text_b + txt_a_lbl = text_a_label + txt_b_lbl = text_b_label + + value_dict = {} + for slot in slot_list: + if slot in cur_ds_dict: + value_dict[slot] = cur_ds_dict[slot] + else: + value_dict[slot] = 'none' + if class_label[slot] != 'none': + ds_lbl_dict[slot] = class_label[slot] + + examples.append(DSTExample( + guid=guid, + text_a=txt_a, + text_b=txt_b, + text_a_label=txt_a_lbl, + text_b_label=txt_b_lbl, + values=value_dict, + inform_label=inform_label, + inform_slot_label=inform_slot_label, + refer_label=referral_label, + diag_state=prev_ds_lbl_dict, + slot_update=updated_slots, + class_label=class_label)) + + prev_ds = turn['dialogue_state'] + prev_ds_lbl_dict = ds_lbl_dict.copy() + + return examples + + +def get_value_list(input_file, slot_list): + exclude = ['none', 'dontcare'] + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader) + value_list = {slot: {} for slot in slot_list} + for entry in input_data: + for turn in entry['turns']: + cur_ds_dict = dialogue_state_to_sv_dict(turn['dialogue_state']) + for slot in slot_list: + if slot in cur_ds_dict and cur_ds_dict[slot] not in exclude: + if cur_ds_dict[slot] not in value_list[slot]: + value_list[slot][cur_ds_dict[slot]] = 0 + value_list[slot][cur_ds_dict[slot]] += 1 + return value_list diff --git a/dataset_unified.py b/dataset_unified.py new file mode 100644 index 0000000..8fc992a --- /dev/null +++ b/dataset_unified.py @@ -0,0 +1,350 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from tqdm import tqdm + +from utils_dst import (DSTExample) + +try: + from convlab.util import (load_dataset, load_ontology, load_dst_data) +except ModuleNotFoundError as e: + print(e) + print("Ignore this error if you don't intend to use the data processor for ConvLab3's unified data format.") + print("Otherwise, make sure you have ConvLab3 installed and added to your PYTHONPATH.") + + +def get_ontology_slots(ontology): + domains = [domain for domain in ontology['domains']] + ontology_slots = dict() + for domain in domains: + if domain not in ontology_slots: + ontology_slots[domain] = dict() + for slot in ontology['domains'][domain]['slots']: + ontology_slots[domain][slot] = ontology['domains'][domain]['slots'][slot]['description'] + return ontology_slots + + +def get_slot_list(dataset_name): + slot_list = {} + ontology = load_ontology(dataset_name) + dataset_slot_list = get_ontology_slots(ontology) + for domain in dataset_slot_list: + for slot in dataset_slot_list[domain]: + slot_list["%s-%s" % (domain, slot)] = dataset_slot_list[domain][slot] + slot_list["%s-none" % (domain)] = "the topic is %s" % (domain) + # Some special intents are modeled as 'request' slots in TripPy + if 'bye' in ontology['intents']: + slot_list["general-bye"] = ontology['intents']['bye']['description'] + if 'thank' in ontology['intents']: + slot_list["general-thank"] = ontology['intents']['thank']['description'] + if 'greet' in ontology['intents']: + slot_list["general-greet"] = ontology['intents']['greet']['description'] + return slot_list + + +def get_value_list(dataset_name, slot_list): + value_list = {slot: {} for slot in slot_list} + ontology = load_ontology(dataset_name) + for slot in slot_list: + d, s = slot.split('-') + if d in ontology['domains']: + if s in ontology['domains'][d]['slots']: + if ontology['domains'][d]['slots'][s]['is_categorical']: + for v in ontology['domains'][d]['slots'][s]['possible_values']: + value_list[slot][v] = 1 + return value_list + + +def create_examples(set_type, dataset_name="multiwoz21", class_types=[], slot_list=[], label_maps={}, + no_label_value_repetitions=False, + swap_utterances=False, + delexicalize_sys_utts=False, + unk_token="[UNK]", + boolean_slots=True, + analyze=False): + """Read a DST json file into a list of DSTExample.""" + + # TODO: Make sure normalization etc. will be compatible with or suitable for SGD and + # other datasets as well. + if dataset_name == "multiwoz21": + from dataset_multiwoz21 import (tokenize, normalize_label, + get_turn_label, delex_utt, + is_request) + else: + raise ValueError("Unknown dataset_name.") + + dataset_args = {"dataset_name": dataset_name} + dataset_dict = load_dataset(**dataset_args) + + if slot_list == []: + slot_list = get_slot_list() + + data = load_dst_data(dataset_dict, data_split=set_type, speaker='all', dialogue_acts=True, split_to_turn=False) + + examples = [] + for d_itr, entry in enumerate(tqdm(data[set_type])): + dialog_id = entry['dialogue_id'] + #dialog_id = entry['original_id'] + original_id = entry['original_id'] + domains = entry['domains'] + turns = entry['turns'] + + # Collects all slot changes throughout the dialog + cumulative_labels = {slot: 'none' for slot in slot_list} + + # First system utterance is empty, since multiwoz starts with user input + utt_tok_list = [[]] + mod_slots_list = [{}] + inform_dict_list = [{}] + user_act_dict_list = [{}] + mod_domains_list = [{}] + + # Collect all utterances and their metadata + usr_sys_switch = True + for turn in turns: + utterance = turn['utterance'] + state = turn['state'] if 'state' in turn else {} + acts = [item for sublist in list(turn['dialogue_acts'].values()) for item in sublist] # flatten list + + # Assert that system and user utterances alternate + is_sys_utt = turn['speaker'] in ['sys', 'system'] + if usr_sys_switch == is_sys_utt: + print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id)) + break + usr_sys_switch = is_sys_utt + + # Extract metadata: identify modified slots and values informed by the system + inform_dict = {} + user_act_dict = {} + modified_slots = {} + modified_domains = set() + for act in acts: + slot = "%s-%s" % (act['domain'], act['slot'] if act['slot'] != '' else 'none') + if act['intent'] in ['bye', 'thank', 'hello']: + slot = "general-%s" % (act['intent']) + value_label = act['value'] if 'value' in act else 'yes' if act['slot'] != '' else 'none' + value_label = normalize_label(slot, value_label) + modified_domains.add(act['domain']) # Remember domains + if is_sys_utt and act['intent'] in ['inform', 'recommend', 'select', 'book'] and value_label != 'none': + if slot not in inform_dict: + inform_dict[slot] = [] + inform_dict[slot].append(value_label) + elif not is_sys_utt: + if slot not in user_act_dict: + user_act_dict[slot] = [] + user_act_dict[slot].append(act) + # INFO: Since the model has no mechanism to predict + # one among several informed value candidates, we + # keep only one informed value. For fairness, we + # apply a global rule: + for e in inform_dict: + # ... Option 1: Always keep first informed value + inform_dict[e] = list([inform_dict[e][0]]) + # ... Option 2: Always keep last informed value + #inform_dict[e] = list([inform_dict[e][-1]]) + for d in state: + for s in state[d]: + slot = "%s-%s" % (d, s) + value_label = normalize_label(slot, state[d][s]) + # Remember modified slots and entire dialog state + if slot in slot_list and cumulative_labels[slot] != value_label: + modified_slots[slot] = value_label + cumulative_labels[slot] = value_label + modified_domains.add(d) # Remember domains + + # Delexicalize sys utterance + if delexicalize_sys_utts and is_sys_utt: + utt_tok_list.append(delex_utt(utterance, inform_dict, unk_token)) # normalizes utterances + else: + utt_tok_list.append(tokenize(utterance)) # normalizes utterances + + inform_dict_list.append(inform_dict.copy()) + user_act_dict_list.append(user_act_dict.copy()) + mod_slots_list.append(modified_slots.copy()) + modified_domains = list(modified_domains) + modified_domains.sort() + mod_domains_list.append(modified_domains) + + # Form proper (usr, sys) turns + turn_itr = 0 + diag_seen_slots_dict = {} + diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list} + diag_state = {slot: 'none' for slot in slot_list} + sys_utt_tok = [] + usr_utt_tok = [] + for i in range(1, len(utt_tok_list) - 1, 2): + sys_utt_tok_label_dict = {} + usr_utt_tok_label_dict = {} + value_dict = {} + inform_dict = {} + inform_slot_dict = {} + referral_dict = {} + class_type_dict = {} + updated_slots = {slot: 0 for slot in slot_list} + + # Collect turn data + sys_utt_tok = utt_tok_list[i - 1] + usr_utt_tok = utt_tok_list[i] + turn_slots = mod_slots_list[i] + inform_mem = inform_dict_list[i - 1] + user_act = user_act_dict_list[i] + turn_domains = mod_domains_list[i] + + guid = '%s-%s' % (dialog_id, turn_itr) + + if analyze: + print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, ' '.join(sys_utt_tok), ' '.join(usr_utt_tok))) + print("%15s %2s [" % (dialog_id, turn_itr), end='') + + new_diag_state = diag_state.copy() + for slot in slot_list: + value_label = 'none' + if slot in turn_slots: + value_label = turn_slots[slot] + # We keep the original labels so as to not + # overlook unpointable values, as well as to not + # modify any of the original labels for test sets, + # since this would make comparison difficult. + value_dict[slot] = value_label + elif not no_label_value_repetitions and slot in diag_seen_slots_dict: + value_label = diag_seen_slots_value_dict[slot] + + # Get dialog act annotations + inform_label = list(['none']) + inform_slot_dict[slot] = 0 + if slot in inform_mem: + inform_label = inform_mem[slot] + inform_slot_dict[slot] = 1 + + (informed_value, + referred_slot, + usr_utt_tok_label, + class_type) = get_turn_label(value_label, + inform_label, + sys_utt_tok, + usr_utt_tok, + slot, + diag_seen_slots_value_dict, + slot_last_occurrence=True, + label_maps=label_maps) + + inform_dict[slot] = informed_value + + # Requestable slots, domain indicator slots and general slots + # should have class_type 'request', if they ought to be predicted. + # Give other class_types preference. + if 'request' in class_types: + if class_type in ['none', 'unpointable'] and is_request(slot, user_act, turn_domains): + class_type = 'request' + + # Generally don't use span prediction on sys utterance (but inform prediction instead). + sys_utt_tok_label = [0 for _ in sys_utt_tok] + + # Determine what to do with value repetitions. + # If value is unique in seen slots, then tag it, otherwise not, + # since correct slot assignment can not be guaranteed anymore. + if not no_label_value_repetitions and slot in diag_seen_slots_dict: + if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1: + class_type = 'none' + usr_utt_tok_label = [0 for _ in usr_utt_tok_label] + + sys_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label_dict[slot] = usr_utt_tok_label + + if diag_seen_slots_value_dict[slot] != value_label: + updated_slots[slot] = 1 + + # For now, we map all occurences of unpointable slot values + # to none. However, since the labels will still suggest + # a presence of unpointable slot values, the task of the + # DST is still to find those values. It is just not + # possible to do that via span prediction on the current input. + if class_type == 'unpointable': + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + if analyze: + if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]: + print("(%s): %s, " % (slot, value_label), end='') + elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform': + # If slot has seen before and its class type did not change, label this slot a not present, + # assuming that the slot has not actually been mentioned in this turn. + # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform, + # this must mean there is evidence in the original labels, therefore consider + # them as mentioned again. + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + else: + class_type_dict[slot] = class_type + referral_dict[slot] = referred_slot + # Remember that this slot was mentioned during this dialog already. + if class_type != 'none': + diag_seen_slots_dict[slot] = class_type + diag_seen_slots_value_dict[slot] = value_label + new_diag_state[slot] = class_type + # Unpointable is not a valid class, therefore replace with + # some valid class for now... + if class_type == 'unpointable': + new_diag_state[slot] = 'copy_value' + + if analyze: + print("]") + + if not swap_utterances: + txt_a = usr_utt_tok + txt_b = sys_utt_tok + txt_a_lbl = usr_utt_tok_label_dict + txt_b_lbl = sys_utt_tok_label_dict + else: + txt_a = sys_utt_tok + txt_b = usr_utt_tok + txt_a_lbl = sys_utt_tok_label_dict + txt_b_lbl = usr_utt_tok_label_dict + examples.append(DSTExample( + guid=guid, + text_a=txt_a, + text_b=txt_b, + text_a_label=txt_a_lbl, + text_b_label=txt_b_lbl, + values=diag_seen_slots_value_dict.copy(), + inform_label=inform_dict, + inform_slot_label=inform_slot_dict, + refer_label=referral_dict, + diag_state=diag_state, + slot_update=updated_slots, + class_label=class_type_dict)) + + # Update some variables. + diag_state = new_diag_state.copy() + + turn_itr += 1 + + if analyze: + print("----------------------------------------------------------------------") + + return examples + + +def prediction_normalization(dataset_name, slot, value): + if dataset_name == "multiwoz21": + from dataset_multiwoz21 import prediction_normalization as pred_norm + return pred_norm(slot, value) + else: + return value diff --git a/dataset_woz2.py b/dataset_woz2.py new file mode 100644 index 0000000..6d75386 --- /dev/null +++ b/dataset_woz2.py @@ -0,0 +1,288 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re + +from utils_dst import (DSTExample, convert_to_unicode) + + +LABEL_MAPS = {} # Loaded from file +LABEL_FIX = {'areas': 'area', 'phone number': 'number', 'price range': 'price_range', 'center': 'centre', 'east side': 'east', 'corsican': 'corsica'} + + +def delex_utt(utt, values, unk_token="[UNK]"): + utt_norm = utt.copy() + for s, v in values.items(): + if v != 'none': + v_norm = tokenize(v) + v_len = len(v_norm) + for i in range(len(utt_norm) + 1 - v_len): + if utt_norm[i:i + v_len] == v_norm: + utt_norm[i:i + v_len] = [unk_token] * v_len + return utt_norm + + +def get_token_pos(tok_list, label): + find_pos = [] + found = False + label_list = [item for item in map(str.strip, re.split("(\W+)", label)) if len(item) > 0] + len_label = len(label_list) + for i in range(len(tok_list) + 1 - len_label): + if tok_list[i:i + len_label] == label_list: + find_pos.append((i, i + len_label)) # start, exclusive_end + found = True + return found, find_pos + + +def check_label_existence(label, usr_utt_tok, sys_utt_tok): + in_usr, usr_pos = get_token_pos(usr_utt_tok, label) + if not in_usr and label in LABEL_MAPS: + for tmp_label in LABEL_MAPS[label]: + in_usr, usr_pos = get_token_pos(usr_utt_tok, tmp_label) + if in_usr: + break + in_sys, sys_pos = get_token_pos(sys_utt_tok, label) + if not in_sys and label in LABEL_MAPS: + for tmp_label in LABEL_MAPS[label]: + in_sys, sys_pos = get_token_pos(sys_utt_tok, tmp_label) + if in_sys: + break + return in_usr, usr_pos, in_sys, sys_pos + + +def get_turn_label(label, sys_utt_tok, usr_utt_tok, slot_last_occurrence): + usr_utt_tok_label = [0 for _ in usr_utt_tok] + in_sys = False + if label == 'none' or label == 'dontcare': + class_type = label + else: + in_usr, usr_pos, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok) + if in_usr: + class_type = 'copy_value' + if slot_last_occurrence: + (s, e) = usr_pos[-1] + for i in range(s, e): + usr_utt_tok_label[i] = 1 + else: + for (s, e) in usr_pos: + for i in range(s, e): + usr_utt_tok_label[i] = 1 + elif in_sys: + class_type = 'inform' + else: + class_type = 'unpointable' + return usr_utt_tok_label, class_type, in_sys + + +def tokenize(utt): + utt_lower = convert_to_unicode(utt).lower() + utt_tok = utt_to_token(utt_lower) + return utt_tok + + +def utt_to_token(utt): + return [tok for tok in map(lambda x: re.sub(" ", "", x), re.split("(\W+)", utt)) if len(tok) > 0] + + +def create_examples(input_file, set_type, slot_list, + label_maps={}, + asr=False, + no_label_value_repetitions=False, + swap_utterances=False, + delexicalize_sys_utts=False, + unk_token="[UNK]", + boolean_slots=True, + analyze=False): + """Read a DST json file into a list of DSTExample.""" + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader) + + global LABEL_MAPS + LABEL_MAPS = label_maps + + examples = [] + for entry in input_data: + diag_seen_slots_dict = {} + diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list} + diag_state = {slot: 'none' for slot in slot_list} + sys_utt_tok = [] + sys_utt_tok_delex = [] + usr_utt_tok = [] + for turn in entry['dialogue']: + sys_utt_tok_label_dict = {} + usr_utt_tok_label_dict = {} + inform_dict = {slot: 'none' for slot in slot_list} + inform_slot_dict = {slot: 0 for slot in slot_list} + referral_dict = {} + class_type_dict = {} + updated_slots = {slot: 0 for slot in slot_list} + + sys_utt_tok = tokenize(turn['system_transcript']) + if asr: + # The model always expects a non-empty user input. + # This can not be guaranteed for ASR hypotheses. + # In case of an empty hypo, use a generic hello. + for asr_hypo in turn['asr']: + usr_utt_tok = tokenize(asr_hypo[0]) + if len(usr_utt_tok) > 0: + break + if len(usr_utt_tok) == 0: + usr_utt_tok = ['hello'] + else: + usr_utt_tok = tokenize(turn['transcript']) + turn_label = {LABEL_FIX.get(s.strip(), s.strip()): LABEL_FIX.get(v.strip(), v.strip()) for s, v in turn['turn_label']} + + guid = '%s-%s-%s' % (set_type, str(entry['dialogue_idx']), str(turn['turn_idx'])) + + # Create delexicalized sys utterances. + if delexicalize_sys_utts: + delex_dict = {} + for slot in slot_list: + delex_dict[slot] = 'none' + label = 'none' + if slot in turn_label: + label = turn_label[slot] + elif not no_label_value_repetitions and slot in diag_seen_slots_dict: + label = diag_seen_slots_value_dict[slot] + if label != 'none' and label != 'dontcare': + _, _, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok) + if in_sys: + delex_dict[slot] = label + sys_utt_tok_delex = delex_utt(sys_utt_tok, delex_dict, unk_token) + + new_diag_state = diag_state.copy() + for slot in slot_list: + label = 'none' + if slot in turn_label: + label = turn_label[slot] + elif not no_label_value_repetitions and slot in diag_seen_slots_dict: + label = diag_seen_slots_value_dict[slot] + + (usr_utt_tok_label, + class_type, + is_informed) = get_turn_label(label, + sys_utt_tok, + usr_utt_tok, + slot_last_occurrence=True) + + if class_type == 'inform': + inform_dict[slot] = label + if is_informed and label != 'none': + inform_slot_dict[slot] = 1 + + referral_dict[slot] = 'none' # Referral is not present in woz2 data + + # Generally don't use span prediction on sys utterance (but inform prediction instead). + if delexicalize_sys_utts: + sys_utt_tok_label = [0 for _ in sys_utt_tok_delex] + else: + sys_utt_tok_label = [0 for _ in sys_utt_tok] + + # Determine what to do with value repetitions. + # If value is unique in seen slots, then tag it, otherwise not, + # since correct slot assignment can not be guaranteed anymore. + if not no_label_value_repetitions and slot in diag_seen_slots_dict: + if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(label) > 1: + class_type = 'none' + usr_utt_tok_label = [0 for _ in usr_utt_tok_label] + + sys_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label_dict[slot] = usr_utt_tok_label + + if diag_seen_slots_value_dict[slot] != label: + updated_slots[slot] = 1 + + # For now, we map all occurences of unpointable slot values + # to none. However, since the labels will still suggest + # a presence of unpointable slot values, the task of the + # DST is still to find those values. It is just not + # possible to do that via span prediction on the current input. + if class_type == 'unpointable': + class_type_dict[slot] = 'none' + elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform': + # If slot has seen before and its class type did not change, label this slot a not present, + # assuming that the slot has not actually been mentioned in this turn. + # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform, + # this must mean there is evidence in the original labels, therefore consider + # them as mentioned again. + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + else: + class_type_dict[slot] = class_type + # Remember that this slot was mentioned during this dialog already. + if class_type != 'none': + diag_seen_slots_dict[slot] = class_type + diag_seen_slots_value_dict[slot] = label + new_diag_state[slot] = class_type + # Unpointable is not a valid class, therefore replace with + # some valid class for now... + if class_type == 'unpointable': + new_diag_state[slot] = 'copy_value' + + if not swap_utterances: + txt_a = usr_utt_tok + if delexicalize_sys_utts: + txt_b = sys_utt_tok_delex + else: + txt_b = sys_utt_tok + txt_a_lbl = usr_utt_tok_label_dict + txt_b_lbl = sys_utt_tok_label_dict + else: + if delexicalize_sys_utts: + txt_a = sys_utt_tok_delex + else: + txt_a = sys_utt_tok + txt_b = usr_utt_tok + txt_a_lbl = sys_utt_tok_label_dict + txt_b_lbl = usr_utt_tok_label_dict + examples.append(DSTExample( + guid=guid, + text_a=txt_a, + text_b=txt_b, + text_a_label=txt_a_lbl, + text_b_label=txt_b_lbl, + values=diag_seen_slots_value_dict.copy(), + inform_label=inform_dict, + inform_slot_label=inform_slot_dict, + refer_label=referral_dict, + diag_state=diag_state, + slot_update=updated_slots, + class_label=class_type_dict)) + + # Update some variables. + diag_state = new_diag_state.copy() + + return examples + + +def get_value_list(input_file, slot_list): + exclude = ['none', 'dontcare'] + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader) + value_list = {slot: {} for slot in slot_list} + for entry in input_data: + for turn in entry['dialogue']: + turn_label = {LABEL_FIX.get(s.strip(), s.strip()): LABEL_FIX.get(v.strip(), v.strip()) for s, v in turn['turn_label']} + for slot in turn_label: + if slot in slot_list and turn_label[slot] not in exclude: + if turn_label[slot] not in value_list[slot]: + value_list[slot][turn_label[slot]] = 0 + value_list[slot][turn_label[slot]] += 1 + return value_list diff --git a/dst_proto.py b/dst_proto.py new file mode 100644 index 0000000..b6c759e --- /dev/null +++ b/dst_proto.py @@ -0,0 +1,341 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import math + +import numpy as np +import torch +from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler) +from torch.utils.data.distributed import DistributedSampler +from torch.optim import (AdamW) +from tqdm import tqdm, trange + +from tensorboardX import SummaryWriter +from transformers import (get_linear_schedule_with_warmup) +from utils_run import (set_seed, to_device, from_device, + save_checkpoint, load_and_cache_examples, + dilate_and_erode) + +logger = logging.getLogger(__name__) + + +def train_proto(args, train_dataset, dev_dataset, model, tokenizer, processor): + """ Train the proto model """ + if args.local_rank in [-1, 0]: + tb_writer = SummaryWriter() + + # This controls the item return function (__getitem__). + train_dataset.proto() + if dev_dataset is not None: + dev_dataset.proto() + + model.eval() # No dropout + + # If sequences were not tokenized yet, do so now, then save tokenized and encoded sequences. + if not train_dataset.load_tokenized_sequences(args.output_dir): + train_dataset.tokenize_sequences(max_len=args.rand_seq_max_len) + train_dataset.save_tokenized_sequences(args.output_dir, overwrite=False) + if dev_dataset is not None and not dev_dataset.load_tokenized_sequences(args.output_dir): + dev_dataset.tokenize_sequences(max_len=args.rand_seq_max_len) + dev_dataset.save_tokenized_sequences(args.output_dir, overwrite=False) + + args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) + train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) + + if args.max_steps > 0: + t_total = args.max_steps + args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 + else: + t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs + + if args.save_epochs > 0: + args.save_steps = t_total // args.num_train_epochs * args.save_epochs + assert args.save_steps == 0 or args.patience < 0 + + num_warmup_steps = int(t_total * args.warmup_proportion) + if args.patience > 0: + patience = args.patience + cur_min_loss = math.inf + + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ['bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, + {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total) + scaler = torch.cuda.amp.GradScaler() + if 'cuda' in args.device.type: + autocast = torch.cuda.amp.autocast(enabled=args.fp16) + else: + autocast = torch.cpu.amp.autocast(enabled=args.fp16) + + # multi-gpu training + model_single_gpu = model + if args.n_gpu > 1: + model = torch.nn.DataParallel(model_single_gpu) + + # Distributed training + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True) + + # Pretrain! + logger.info("***** Running proto training *****") + logger.info(" Num examples = %d", len(train_dataset)) + logger.info(" Num Epochs = %d", args.num_train_epochs) + logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) + logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", t_total) + logger.info(" Warmup steps = %d", num_warmup_steps) + + global_step = 0 + tr_loss, logging_loss = 0.0, 0.0 + model.zero_grad() + train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) + set_seed(args) # Added here for reproducibility (even between python 2 and 3) + + for e_itr, _ in enumerate(train_iterator): + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) + train_dataset.update_samples_for_proto(max_len=args.rand_seq_max_len) + + all_train_results = [] + for step, batch in enumerate(epoch_iterator): + model.train() + + batch = to_device(batch, args.device) + with autocast: + outputs = model(batch, step=step, mode="proto") # calls the "forward" def. + loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc) + outputs = from_device(outputs) + batch = from_device(batch) + + train_results = eval_metric_proto(args, model, tokenizer, batch, outputs, threshold=0.5) + all_train_results.append(train_results) + + if args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + epoch_iterator.set_postfix({'loss': loss.item()}) + + tr_loss += loss.item() + if (step + 1) % args.gradient_accumulation_steps == 0: + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + scaler.step(optimizer) + scaler.update() + scheduler.step() # Update learning rate schedule + model.zero_grad() + global_step += 1 + + # Log metrics + if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: + tb_writer.add_scalar('lr', scheduler.get_last_lr()[0], global_step) + tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step) + logging_loss = tr_loss + + # Save model checkpoint + if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: + save_checkpoint(args, global_step, model, prefix='proto') + + if args.max_steps > 0 and global_step > args.max_steps: + epoch_iterator.close() + break + + # Generate final results + final_train_results = {'loss': torch.tensor(0), 'accuracy': torch.tensor(0)} + if len(all_train_results) > 0: + for k in all_train_results[0].keys(): + final_train_results[k] = torch.stack([r[k] for r in all_train_results]).sum() / len(train_dataset) + + # Only evaluate when single GPU otherwise metrics may not average well + if args.local_rank == -1 and dev_dataset is not None: + results = evaluate_proto(args, dev_dataset, model_single_gpu, tokenizer, processor, no_print=True, prefix=global_step) + for key, value in results.items(): + tb_writer.add_scalar('eval_{}'.format(key), value, global_step) + + # Patience + if args.patience > 0: + if args.early_stop_criterion == "loss": + criterion = results['proto_loss'].item() + elif args.early_stop_criterion == "goal": + criterion = -1 * results['proto_accuracy'].item() + else: + logger.warn("Early stopping criterion %s not known. Aborting" % (args.early_stop_criterion)) + if criterion > cur_min_loss: + patience -= 1 + else: + # Save model checkpoint + patience = args.patience + save_checkpoint(args, global_step, model, prefix='proto', keep_only_last_checkpoint=True) + cur_min_loss = criterion + train_iterator.set_postfix({'patience': patience, + 'train loss': final_train_results['proto_loss'].item(), + 'eval loss': results['proto_loss'].item(), + 'train acc': final_train_results['proto_accuracy'].item(), + 'eval acc': results['proto_accuracy'].item()}) + if patience == 0: + train_iterator.close() + break + + if args.max_steps > 0 and global_step > args.max_steps: + train_iterator.close() + break + + if args.local_rank in [-1, 0]: + tb_writer.close() + + # This controls the item return function (__getitem__). + train_dataset.reset() + if dev_dataset is not None: + dev_dataset.reset() + + return global_step, tr_loss / global_step + + +def evaluate_proto(args, dataset, model, tokenizer, processor, no_print=False, prefix=""): + if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: + os.makedirs(args.output_dir) + + dataset.proto() # This controls the item return function (__getitem__). + + model.eval() # No dropout + + if not dataset.load_tokenized_sequences(args.output_dir): + dataset.tokenize_sequences(max_len=args.rand_seq_max_len) + + dataset.update_samples_for_proto(max_len=args.rand_seq_max_len) + + args.eval_batch_size = args.per_gpu_eval_batch_size + eval_sampler = SequentialSampler(dataset) # Note that DistributedSampler samples randomly + eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + + # Eval! + logger.info("***** Running evaluation of proto training {} *****".format(prefix)) + logger.info(" Num examples = %d", len(dataset)) + logger.info(" Batch size = %d", args.eval_batch_size) + all_results = [] + for batch in tqdm(eval_dataloader, desc="Evaluating"): + model.eval() + + with torch.no_grad(): + batch = to_device(batch, args.device) + outputs = model(batch, mode="proto") + outputs = from_device(outputs) + batch = from_device(batch) + + unique_ids = [dataset.features[i.item()].guid for i in batch['example_id']] + values = [dataset.features[i.item()].values for i in batch['example_id']] + input_ids = [dataset.features[i.item()].input_ids for i in batch['example_id']] + inform = [dataset.features[i.item()].inform for i in batch['example_id']] + + results = eval_metric_proto(args, model, tokenizer, batch, outputs, threshold=0.5) + all_results.append(results) + if not no_print: + predict_and_print_proto(args, model, tokenizer, batch, outputs, unique_ids, input_ids, values, inform) + + # Generate final results + final_results = {} + for k in all_results[0].keys(): + final_results[k] = torch.stack([r[k] for r in all_results]).sum() / len(dataset) + + dataset.reset() # This controls the item return function (__getitem__). + + return final_results + + +def eval_metric_proto(args, model, tokenizer, batch, outputs, threshold=0.0, dae=False): + loss = outputs[0] + logits = outputs[1] + + input_ids = [] + for i in range(len(batch['input_ids'])): + clipped = batch['input_ids'][i].tolist() + clipped = clipped[:len(clipped) - clipped[::-1].index(tokenizer.sep_token_id)] + input_ids.append(clipped) + + metric_dict = {} + + mean = [] + for i in range(len(batch['input_ids'])): + mean.append(torch.mean(logits[i][:len(input_ids[i])])) + mean = torch.stack(mean) + norm_logits = torch.clamp(logits - mean.unsqueeze(1), min=0) / logits.max(1)[0].unsqueeze(1) + + start_pos = batch['start_pos'] + + if dae: + token_prediction = dilate_and_erode(norm_logits, threshold) + else: + token_prediction = norm_logits > threshold + token_prediction[:, 0] = False # Ignore <s> + token_correctness = torch.all(torch.eq(token_prediction, start_pos), 1).float() + token_accuracy = token_correctness.sum() + metric_dict['proto_loss'] = loss + metric_dict['proto_accuracy'] = token_accuracy + return metric_dict + + +def predict_and_print_proto(args, model, tokenizer, batch, outputs, ids, input_ids_unmasked, values, inform): + per_slot_start_logits = outputs[1] + + for i in range(len(ids)): + input_tokens = tokenizer.convert_ids_to_tokens(input_ids_unmasked[i]) + + input_ids = batch['input_ids'][i].tolist() + input_ids = input_ids[:len(input_ids) - input_ids[::-1].index(tokenizer.sep_token_id)] + + token_norm_weights = {} + pos_i = batch['start_pos'][i].tolist() + token_weights = per_slot_start_logits[i][:len(input_ids)] + token_norm_weights = torch.clamp(token_weights - torch.mean(token_weights), min=0) / torch.max(token_weights) + + print(ids[i]) + for k in range(len(input_ids)): + bold = False + print(" ", end="") + t_weight = token_norm_weights[k] + if t_weight == 0.0: + print(" ", end="") + elif t_weight < 0.25: + print("\u2591 ", end="") + elif t_weight < 0.5: + print("\u2592 ", end="") + elif t_weight < 0.75: + print("\u2593 ", end="") + else: + print("\u2588 ", end="") + if pos_i[k]: + bold = True + if bold: + print("\033[1m%s\033[0m" % (input_tokens[k])) + else: + print("%s" % (input_tokens[k])) diff --git a/dst_tag.py b/dst_tag.py new file mode 100644 index 0000000..cbc09b3 --- /dev/null +++ b/dst_tag.py @@ -0,0 +1,188 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os + +import torch +from torch.utils.data import (DataLoader, SequentialSampler) +from tqdm import tqdm + +from utils_run import (set_seed, to_device, from_device, + save_checkpoint, load_and_cache_examples, + dilate_and_erode) + +logger = logging.getLogger(__name__) + + +def tag_values(args, dataset, model, tokenizer, processor, no_print=False, prefix="", threshold=0.0, dae=False): + if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: + os.makedirs(args.output_dir) + + model.eval() # No dropout + + dataset.tag() # This controls the item return function (__getitem__). + + dataset.encode_slot_values(val_rep_mode="encode", val_rep="v") + + args.eval_batch_size = args.per_gpu_eval_batch_size + eval_sampler = SequentialSampler(dataset) # Note that DistributedSampler samples randomly + eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + + # Tag! + logger.info("***** Running value tagging {} *****".format(prefix)) + logger.info(" Num examples = %d", len(dataset)) + logger.info(" Batch size = %d", args.eval_batch_size) + all_labels = [] + all_results = [] + for batch in tqdm(eval_dataloader, desc="Tagging"): + model.eval() + + with torch.no_grad(): + batch['encoded_slot_values'] = dataset.encoded_slot_values + batch = to_device(batch, args.device) + outputs = model(batch, mode="tag") + outputs = from_device(outputs) + batch = from_device(batch) + + unique_ids = [dataset.features[i.item()].guid for i in batch['example_id']] + values = [dataset.features[i.item()].values for i in batch['example_id']] + input_ids = [dataset.features[i.item()].input_ids for i in batch['example_id']] + inform = [dataset.features[i.item()].inform for i in batch['example_id']] + + auto_labels, results = label_and_eval_tags(args, model, tokenizer, batch, outputs, values, threshold=threshold, dae=dae) + all_labels.append(auto_labels) + all_results.append(results) + if not no_print: + predict_and_print_tags(args, model, tokenizer, batch, outputs, unique_ids, input_ids, values, inform) + + # Generate final labels + final_labels = {slot: {} for slot in model.slot_list} + for b in all_labels: + for s in b: + final_labels[s].update(b[s]) + + # Generate final results + final_results = {} + for k in all_results[0].keys(): + final_results[k] = (torch.stack([r[k] for r in all_results]).sum() / len(dataset)).item() + + dataset.reset() # This controls the item return function (__getitem__). + + return final_labels, final_results + + +def label_and_eval_tags(args, model, tokenizer, batch, outputs, values, threshold=0.0, dae=False): + per_slot_start_logits = outputs[0] + + input_ids = [] + for i in range(len(batch['input_ids'])): + clipped = batch['input_ids'][i].tolist() + clipped = clipped[:len(clipped) - clipped[::-1].index(tokenizer.sep_token_id)] + input_ids.append(clipped) + + auto_labels = {} + metric_dict = {} + per_slot_correctness = {} + for s_itr, slot in enumerate(model.slot_list): + start_logits = per_slot_start_logits[:, s_itr] + mean = [] + for i in range(len(batch['input_ids'])): + mean.append(torch.mean(start_logits[i][:len(input_ids[i])])) + mean = torch.stack(mean) + norm_logits = torch.clamp(start_logits - mean.unsqueeze(1), min=0) / start_logits.max(1)[0].unsqueeze(1) + + start_pos = batch['start_pos'][slot] + + # "is pointable" means whether there is a span to be detected. + token_is_pointable = (start_pos.sum(1) > 0).float() + + if dae: + token_prediction = dilate_and_erode(norm_logits, threshold) + else: + token_prediction = norm_logits > threshold + token_prediction[:, 0] = False # Ignore [CLS]/<s> + token_correctness = torch.all(torch.eq(token_prediction, start_pos), 1).float() + token_accuracy = (token_correctness * token_is_pointable).sum() + (1 - token_is_pointable).sum() + total_correctness = token_correctness * token_is_pointable + (1 - token_is_pointable) + + metric_dict['eval_accuracy_%s' % slot] = token_accuracy + per_slot_correctness[slot] = total_correctness + + auto_labels[slot] = {} + for i in range(len(batch['input_ids'])): + auto_labels[slot][int(batch['example_id'][i])] = token_prediction[i] * token_is_pointable[i] + + goal_correctness = torch.stack([c for c in per_slot_correctness.values()], 1).prod(1) + goal_accuracy = goal_correctness.sum() + metric_dict['eval_accuracy_goal'] = goal_accuracy + return auto_labels, metric_dict + + +def predict_and_print_tags(args, model, tokenizer, batch, outputs, ids, input_ids_unmasked, values, inform): + per_slot_start_logits = outputs[0] + + class_types = model.class_types + + for i in range(len(ids)): + input_tokens = tokenizer.convert_ids_to_tokens(input_ids_unmasked[i]) + + input_ids = batch['input_ids'][i].tolist() + input_ids = input_ids[:len(input_ids) - input_ids[::-1].index(tokenizer.sep_token_id)] + + pos_i = {} + clb_i = {} + token_norm_weights = {} + for s_itr, slot in enumerate(model.slot_list): + pos_i[slot] = batch['start_pos'][slot][i].tolist() + clb_i[slot] = batch['class_label_id'][slot][i] + token_weights = per_slot_start_logits[i][s_itr][:len(input_ids)] + token_norm_weights[slot] = torch.clamp(token_weights - torch.mean(token_weights), min=0) / max(token_weights) + + print(ids[i]) + print(" ", end="") + for slot in model.slot_list: + if clb_i[slot] == class_types.index('copy_value'): + print("\033[1m%s\033[0m " % (slot[0]), end="") + else: + print(slot[0] + " ", end="") + print() + for k in range(len(input_ids)): + bold = False + print(" ", end="") + for slot in model.slot_list: + t_weight = token_norm_weights[slot][k] + if t_weight == 0.0: + print(" ", end="") + elif t_weight < 0.25: + print("\u2591 ", end="") + elif t_weight < 0.5: + print("\u2592 ", end="") + elif t_weight < 0.75: + print("\u2593 ", end="") + else: + print("\u2588 ", end="") + if pos_i[slot][k]: + bold = True + if bold: + print("\033[1m%s\033[0m" % (input_tokens[k])) + else: + print("%s" % (input_tokens[k])) diff --git a/dst_train.py b/dst_train.py new file mode 100644 index 0000000..b01ba7b --- /dev/null +++ b/dst_train.py @@ -0,0 +1,763 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import json +import math +import re + +import numpy as np +import torch +from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler) +from torch.utils.data.distributed import DistributedSampler +from torch.optim import (AdamW) +from tqdm import tqdm, trange + +from tensorboardX import SummaryWriter +from transformers import (get_linear_schedule_with_warmup) +from utils_run import (set_seed, to_device, from_device, + save_checkpoint, load_and_cache_examples, + dilate_and_erode) + +logger = logging.getLogger(__name__) + + +def train(args, train_dataset, dev_dataset, automatic_labels, model, tokenizer, processor): + """ Train the model """ + if args.local_rank in [-1, 0]: + tb_writer = SummaryWriter() + + model.eval() # No dropout + + args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) + train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) + + if args.max_steps > 0: + t_total = args.max_steps + args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 + else: + t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs + + if args.save_epochs > 0: + args.save_steps = t_total // args.num_train_epochs * args.save_epochs + assert args.save_steps == 0 or args.patience < 0 + + num_warmup_steps = int(t_total * args.warmup_proportion) + if args.patience > 0: + patience = args.patience + cur_min_loss = math.inf + + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ['bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, + {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total) + scaler = torch.cuda.amp.GradScaler() + if 'cuda' in args.device.type: + autocast = torch.cuda.amp.autocast(enabled=args.fp16) + else: + autocast = torch.cpu.amp.autocast(enabled=args.fp16) + + # multi-gpu training + model_single_gpu = model + if args.n_gpu > 1: + model = torch.nn.DataParallel(model_single_gpu) + + # Distributed training + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True) + + # Train! + logger.info("***** Running training *****") + logger.info(" Num examples = %d", len(train_dataset)) + logger.info(" Num Epochs = %d", args.num_train_epochs) + logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) + logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", t_total) + logger.info(" Warmup steps = %d", num_warmup_steps) + + global_step = 0 + tr_loss, logging_loss = 0.0, 0.0 + model.zero_grad() + train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) + set_seed(args) # Added here for reproductibility (even between python 2 and 3) + + for e_itr, _ in enumerate(train_iterator): + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) + train_dataset.dropout_input() + train_dataset.encode_slots() + train_dataset.encode_slot_values() + + for step, batch in enumerate(epoch_iterator): + model.train() + + # Add tokenized or encoded slot descriptions and encoded values to batch. + # We do this here instead of in TrippyDataset.__getitem__() because we only + # need them once, and not once for the entire batch. + batch['slot_ids'] = [] + batch['slot_mask'] = [] + for slot in model.slot_list: + batch['slot_ids'].append(train_dataset.encoded_slots_ids[slot][0]) + batch['slot_mask'].append(train_dataset.encoded_slots_ids[slot][1]) + batch['slot_ids'] = torch.stack(batch['slot_ids']) + batch['slot_mask'] = torch.stack(batch['slot_mask']) + batch['encoded_slot_values'] = train_dataset.encoded_slot_values + + batch = to_device(batch, args.device) + with autocast: + outputs = model(batch, step=step) # calls the "forward" def. + loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc) + outputs = from_device(outputs) + batch = from_device(batch) + + cl_loss = outputs[1] + tk_loss = outputs[2] + tp_loss = outputs[3] + + if args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + epoch_iterator.set_postfix({'loss': loss.item(), 'cl': cl_loss.item(), 'tk': tk_loss.item(), 'tp': tp_loss.item()}) + + tr_loss += loss.item() + if (step + 1) % args.gradient_accumulation_steps == 0: + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + scaler.step(optimizer) + scaler.update() + scheduler.step() # Update learning rate schedule + model.zero_grad() + global_step += 1 + + # Log metrics + if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: + tb_writer.add_scalar('lr', scheduler.get_last_lr()[0], global_step) + tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step) + logging_loss = tr_loss + + # Save model checkpoint + if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: + save_checkpoint(args, global_step, model) + + if args.max_steps > 0 and global_step > args.max_steps: + epoch_iterator.close() + break + + # Only evaluate when single GPU otherwise metrics may not average well + if args.local_rank == -1 and dev_dataset is not None: + results = evaluate(args, dev_dataset, model_single_gpu, tokenizer, processor, no_print=True, prefix=global_step) + for key, value in results.items(): + tb_writer.add_scalar('eval_{}'.format(key), value, global_step) + + # Patience + if args.patience > 0: + if args.early_stop_criterion == "loss": + criterion = results['loss'].item() + elif args.early_stop_criterion == "goal": + criterion = -1 * results['eval_accuracy_goal'].item() + else: + logger.warn("Early stopping criterion %s not known. Aborting" % (args.early_stop_criterion)) + if criterion > cur_min_loss: + patience -= 1 + else: + # Save model checkpoint + patience = args.patience + save_checkpoint(args, global_step, model, keep_only_last_checkpoint=True) + cur_min_loss = criterion + train_iterator.set_postfix({'patience': patience, + 'eval loss': results['loss'].item(), + 'cl': results['cl_loss'].item(), + 'tk': results['tk_loss'].item(), + 'tp': results['tp_loss'].item(), + 'eval goal': results['eval_accuracy_goal'].item()}) + if patience == 0: + train_iterator.close() + break + + if args.max_steps > 0 and global_step > args.max_steps: + train_iterator.close() + break + + if args.local_rank in [-1, 0]: + tb_writer.close() + + return global_step, tr_loss / global_step + + +def evaluate(args, dataset, model, tokenizer, processor, no_print=False, prefix=""): + if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: + os.makedirs(args.output_dir) + + model.eval() # No dropout + + dataset.encode_slots() + dataset.save_encoded_slots(args.output_dir) + + dataset.encode_slot_values() + dataset.save_encoded_slot_values(args.output_dir) + + args.eval_batch_size = args.per_gpu_eval_batch_size + eval_sampler = SequentialSampler(dataset) # Note that DistributedSampler samples randomly + eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + + # Eval! + logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(" Num examples = %d", len(dataset)) + logger.info(" Batch size = %d", args.eval_batch_size) + all_results = [] + all_preds = [] + ds = {slot: 'none' for slot in model.slot_list} + diag_state = {slot: torch.tensor([0 for _ in range(args.eval_batch_size)]).to(args.device) for slot in model.slot_list} + for batch in tqdm(eval_dataloader, desc="Evaluating"): + model.eval() + + # Reset dialog state if turn is first in the dialog. + turn_itrs = [dataset.features[i.item()].guid.split('-')[2] for i in batch['example_id']] + reset_diag_state = np.where(np.array(turn_itrs) == '0')[0] + for slot in model.slot_list: + for i in reset_diag_state: + diag_state[slot][i] = 0 + + with torch.no_grad(): + batch['diag_state'] = diag_state # Update + batch['encoded_slots_pooled'] = dataset.encoded_slots_pooled + batch['encoded_slots_seq'] = dataset.encoded_slots_seq + batch['encoded_slot_values'] = dataset.encoded_slot_values + batch = to_device(batch, args.device) + outputs = model(batch) + outputs = from_device(outputs) + batch = from_device(batch) + + unique_ids = [dataset.features[i.item()].guid for i in batch['example_id']] + values = [dataset.features[i.item()].values for i in batch['example_id']] + input_ids = [dataset.features[i.item()].input_ids for i in batch['example_id']] + inform = [dataset.features[i.item()].inform for i in batch['example_id']] + + # Update dialog state for next turn. + for slot in model.slot_list: + updates = outputs[8][slot].max(1)[1] + for i, u in enumerate(updates): + if u != 0: + diag_state[slot][i] = u + + value_match = dataset.query_values(outputs[13]) + + results = eval_metric(args, model, tokenizer, batch, outputs, 0.5, False, values, value_match) + all_results.append(results) + if not no_print: + predict_and_print(args, model, tokenizer, batch, outputs, unique_ids, input_ids, values, inform, ds, value_match) + preds, ds = predict_and_format(args, model, tokenizer, processor, batch, outputs, unique_ids, input_ids, values, inform, ds, value_match) + all_preds.append(preds) + + if not no_print: + all_preds = [item for sublist in all_preds for item in sublist] # Flatten list + + # Generate final results + final_results = {} + for k in all_results[0].keys(): + final_results[k] = torch.stack([r[k] for r in all_results]).sum() / len(dataset) + + # Write final predictions (for evaluation with external tool) + output_prediction_file = os.path.join(args.output_dir, "pred_res%s.%s.%s.json" % (args.cache_suffix, args.predict_type, prefix)) + if not no_print: + with open(output_prediction_file, "w") as f: + json.dump(all_preds, f, indent=2) + + return final_results + + +def eval_metric(args, model, tokenizer, batch, outputs, threshold=0.0, dae=False, values=None, value_match=None): + total_loss = outputs[0] + total_cl_loss = outputs[1] + total_tk_loss = outputs[2] + total_tp_loss = outputs[3] + per_slot_per_example_loss = outputs[4] + per_slot_per_example_cl_loss = outputs[5] + per_slot_per_example_tk_loss = outputs[6] + per_slot_per_example_tp_loss = outputs[7] + per_slot_class_logits = outputs[8] + per_slot_start_logits = outputs[9] + per_slot_value_logits = outputs[10] + per_slot_refer_logits = outputs[11] + + class_types = model.class_types + + input_ids = [] + for i in range(len(batch['input_ids'])): + clipped = batch['input_ids'][i].tolist() + clipped = clipped[:len(clipped) - clipped[::-1].index(tokenizer.sep_token_id)] + input_ids.append(clipped) + + metric_dict = {} + per_slot_correctness = {} + for slot in model.slot_list: + per_example_loss = per_slot_per_example_loss[slot] + per_example_cl_loss = per_slot_per_example_cl_loss[slot] + per_example_tk_loss = per_slot_per_example_tk_loss[slot] + per_example_tp_loss = per_slot_per_example_tp_loss[slot] + class_logits = per_slot_class_logits[slot] + start_logits = per_slot_start_logits[slot] + value_logits = per_slot_value_logits[slot] + refer_logits = per_slot_refer_logits[slot] + + mean = [] + is_value_match = [] + for i in range(len(batch['input_ids'])): + mean.append(torch.mean(start_logits[i][:len(input_ids[i])])) + is_value_match.append(value_match[slot][i][0] == values[i][slot]) + mean = torch.stack(mean) + is_value_match = torch.tensor(is_value_match) + norm_logits = torch.clamp(start_logits - mean.unsqueeze(1), min=0) / start_logits.max(1)[0].unsqueeze(1) + + class_label_id = batch['class_label_id'][slot] + start_pos = batch['start_pos'][slot] + value_label_id = batch['value_labels'][slot] if slot in batch['value_labels'] else None + refer_id = batch['refer_id'][slot] + + _, class_prediction = class_logits.max(1) + class_correctness = torch.eq(class_prediction, class_label_id).float() + class_accuracy = class_correctness.sum() + + # "is pointable" means whether class label is "copy_value", + # i.e., that there is a span to be detected. + token_is_pointable = torch.eq(class_label_id, class_types.index('copy_value')).float() # TODO: which is better? + #token_is_pointable = (start_pos.sum(1) > 0).float() + if dae: + token_prediction = dilate_and_erode(norm_logits, threshold) + else: + token_prediction = norm_logits > threshold + token_correctness = torch.all(torch.eq(token_prediction, start_pos), 1).float() + token_accuracy = (token_correctness * token_is_pointable).sum() + (1 - token_is_pointable).sum() + + value_correctness = is_value_match + value_accuracy = (value_correctness * token_is_pointable).sum() + (1 - token_is_pointable).sum() + + token_is_referrable = torch.eq(class_label_id, class_types.index('refer') if 'refer' in class_types else -1).float() + _, refer_prediction = refer_logits.max(1) + refer_correctness = torch.eq(refer_prediction, refer_id).float() + refer_accuracy = (refer_correctness * token_is_referrable).sum() + (1 - token_is_referrable).sum() + # NaNs mean that none of the examples in this batch contain referrals. -> division by 0 + # The accuracy therefore is 1 by default. -> replace NaNs + #if math.isnan(refer_accuracy) or math.isinf(refer_accuracy): + # refer_accuracy = torch.tensor(1.0, device=refer_accuracy.device) + + if args.value_matching_weight > 0.0: + total_correctness = class_correctness * \ + (token_is_pointable * token_correctness + (1 - token_is_pointable)) * \ + (token_is_pointable * value_correctness + (1 - token_is_pointable)) * \ + (token_is_referrable * refer_correctness + (1 - token_is_referrable)) + else: + total_correctness = class_correctness * \ + (token_is_pointable * token_correctness + (1 - token_is_pointable)) * \ + (token_is_referrable * refer_correctness + (1 - token_is_referrable)) + total_accuracy = total_correctness.sum() + + loss = per_example_loss.sum() + cl_loss = per_example_cl_loss.sum() + tk_loss = per_example_tk_loss.sum() + tp_loss = per_example_tp_loss.sum() + metric_dict['eval_accuracy_class_%s' % slot] = class_accuracy + metric_dict['eval_accuracy_token_%s' % slot] = token_accuracy + metric_dict['eval_accuracy_value_%s' % slot] = value_accuracy + metric_dict['eval_accuracy_refer_%s' % slot] = refer_accuracy + metric_dict['eval_accuracy_%s' % slot] = total_accuracy + metric_dict['eval_loss_%s' % slot] = loss + metric_dict['eval_cl_loss_%s' % slot] = cl_loss + metric_dict['eval_tk_loss_%s' % slot] = tk_loss + metric_dict['eval_tp_loss_%s' % slot] = tp_loss + per_slot_correctness[slot] = total_correctness + + goal_correctness = torch.stack([c for c in per_slot_correctness.values()], 1).prod(1) + goal_accuracy = goal_correctness.sum() + metric_dict['eval_accuracy_goal'] = goal_accuracy + metric_dict['loss'] = total_loss + metric_dict['cl_loss'] = total_cl_loss + metric_dict['tk_loss'] = total_tk_loss + metric_dict['tp_loss'] = total_tp_loss + return metric_dict + + +def get_spans(pred, norm_logits, input_tokens, usr_utt_spans): + span_indices = [i for i in range(len(pred)) if pred[i]] + prev_si = None + spans = [] + for si in span_indices: + if prev_si is None or si - prev_si > 1: + spans.append(([], [], [])) + spans[-1][0].append(si) + spans[-1][1].append(input_tokens[si]) + spans[-1][2].append(norm_logits[si]) + prev_si = si + spans = [(min(i), max(i), ' '.join(t for t in s), (sum(c) / len(c)).item()) for (i, s, c) in spans] + final_spans = {} + for s in spans: + for us_itr, us in enumerate(usr_utt_spans): + if s[0] >= us[0] and s[1] <= us[1]: + if us_itr not in final_spans: + final_spans[us_itr] = [] + final_spans[us_itr].append(s[2:]) + break + final_spans = list(final_spans.values()) + return final_spans + + +def get_usr_utt_spans(usr_mask): + span_indices = [i for i in range(len(usr_mask)) if usr_mask[i]] + prev_si = None + spans = [] + for si in span_indices: + if prev_si is None or si - prev_si > 1: + spans.append([]) + spans[-1].append(si) + prev_si = si + spans = [[min(s), max(s)] for s in spans] + return spans + + +def smooth_roberta_predictions(pred, input_tokens, tokenizer): + smoothed_pred = pred.detach().clone() + # Forward + span = False + i = 0 + while i < len(pred): + if pred[i] > 0: + span = True + + elif span and input_tokens[i][0] != "\u0120" and input_tokens[i] not in [tokenizer.unk_token, tokenizer.bos_token, + tokenizer.eos_token, tokenizer.unk_token, + tokenizer.sep_token, tokenizer.pad_token, + tokenizer.cls_token, tokenizer.mask_token]: + smoothed_pred[i] = 1 # use label for in-span tokens + elif span and (input_tokens[i][0] == "\u0120" or input_tokens[i] in [tokenizer.unk_token, tokenizer.bos_token, + tokenizer.eos_token, tokenizer.unk_token, + tokenizer.sep_token, tokenizer.pad_token, + tokenizer.cls_token, tokenizer.mask_token]): + span = False + i += 1 + # Backward + span = False + i = len(pred) - 1 + while i >= 0: + if pred[i] > 0: + span = True + if span and input_tokens[i][0] != "\u0120" and input_tokens[i] not in [tokenizer.unk_token, tokenizer.bos_token, + tokenizer.eos_token, tokenizer.unk_token, + tokenizer.sep_token, tokenizer.pad_token, + tokenizer.cls_token, tokenizer.mask_token]: + smoothed_pred[i] = 1 # use label for in-span tokens + elif span and input_tokens[i][0] == "\u0120": + smoothed_pred[i] = 1 # use label for in-span tokens + span = False + i -= 1 + return smoothed_pred + + +def smooth_bert_predictions(pred, input_tokens, tokenizer): + smoothed_pred = pred.detach().clone() + # Forward + span = False + i = 0 + while i < len(pred): + if pred[i] > 0: + span = True + elif span and input_tokens[i][0:2] == "##": + smoothed_pred[i] = 1 # use label for in-span tokens + else: + span = False + i += 1 + # Backward + span = False + i = len(pred) - 1 + while i >= 0: + if pred[i] > 0: + span = True + if span and input_tokens[i + 1][0:2] == "##": + smoothed_pred[i] = 1 # use label for in-span tokens + else: + span = False + i -= 1 + return smoothed_pred + + +def predict_and_format(args, model, tokenizer, processor, batch, outputs, ids, input_ids_unmasked, values, inform, ds, value_match, dae=False): + def _tokenize(text): + if "\u0120" in text: + text = re.sub(" ", "", text) + text = re.sub("\u0120", " ", text) + else: + text = re.sub(" ##", "", text) + text = text.strip() + return ' '.join([tok for tok in map(str.strip, re.split("(\W+)", text)) if len(tok) > 0]) + + per_slot_class_logits = outputs[8] + per_slot_start_logits = outputs[9] + per_slot_value_logits = outputs[10] + per_slot_refer_logits = outputs[11] + + class_types = model.class_types + + prediction_list = [] + dialog_state = ds + for i in range(len(ids)): + if int(ids[i].split("-")[2]) == 0: + dialog_state = {slot: 'none' for slot in model.slot_list} + + input_tokens = tokenizer.convert_ids_to_tokens(input_ids_unmasked[i]) + + prediction = {} + prediction_addendum = {} + + prediction['guid'] = ids[i].split("-") + input_ids = batch['input_ids'][i].tolist() + input_ids = input_ids[:len(input_ids) - input_ids[::-1].index(tokenizer.sep_token_id)] + prediction['input_ids'] = input_ids + + # assign identified spans to their respective usr turns (simply append spans as list of lists) + usr_utt_spans = get_usr_utt_spans(batch['usr_mask'][i][1:]) + + for slot in model.slot_list: + class_logits = per_slot_class_logits[slot][i] + start_logits = per_slot_start_logits[slot][i] + value_logits = per_slot_value_logits[slot][i] if per_slot_value_logits[slot] is not None else None + refer_logits = per_slot_refer_logits[slot][i] + + weights = start_logits[:len(input_ids)] + norm_logits = torch.clamp(weights - torch.mean(weights), min=0) / torch.max(weights) + + class_label_id = int(batch['class_label_id'][slot][i]) + start_pos = batch['start_pos'][slot][i].tolist() + refer_id = int(batch['refer_id'][slot][i]) + + class_prediction = int(class_logits.argmax()) + + if dae: + start_prediction = dilate_and_erode(norm_logits.unsqueeze(0), 0.0).squeeze(0) + else: + start_prediction = norm_logits > 0.0 + if "roberta" in args.model_type: + start_prediction = smooth_roberta_predictions(start_prediction, input_tokens, tokenizer) + else: + start_prediction = smooth_bert_predictions(start_prediction, input_tokens, tokenizer) + start_prediction[0] = False # Ignore <s> + + value_label_id = [] + if slot in batch['value_labels']: + value_label_id = batch['value_labels'][slot][i] > 0.0 + if value_logits is not None: + value_logits /= sum(value_logits) # Scale + value_prediction = value_logits >= (1.0 / len(value_logits)) # For attention based value matching + value_logits = value_logits.tolist() + + refer_prediction = int(refer_logits.argmax()) + + prediction['class_prediction_%s' % slot] = class_prediction + prediction['class_label_id_%s' % slot] = class_label_id + prediction['start_prediction_%s' % slot] = [i for i in range(len(start_prediction)) if start_prediction[i] > 0] + prediction['start_confidence_%s' % slot] = [norm_logits[j].item() for j in range(len(start_prediction)) if start_prediction[j] > 0] + prediction['start_pos_%s' % slot] = [i for i in range(len(start_pos)) if start_pos[i] > 0] + prediction['value_label_id_%s' % slot] = [i for i in range(len(value_label_id)) if value_label_id[i] > 0] + prediction['value_prediction_%s' % slot] = [] + prediction['value_confidence_%s' % slot] = [] + if value_logits is not None: + prediction['value_prediction_%s' % slot] = [i for i in range(len(value_prediction)) if value_prediction[i] > 0] + prediction['value_confidence_%s' % slot] = [value_logits[i] for i in range(len(value_logits)) if value_prediction[i] > 0] + prediction['refer_prediction_%s' % slot] = refer_prediction + prediction['refer_id_%s' % slot] = refer_id + + if class_prediction == class_types.index('dontcare'): + dialog_state[slot] = 'dontcare' + elif class_prediction == class_types.index('copy_value'): + spans = get_spans(start_prediction[1:], norm_logits[1:], input_tokens[1:], usr_utt_spans) + if len(spans) > 0: + for e_itr in range(len(spans)): + for ee_itr in range(len(spans[e_itr])): + tmp = list(spans[e_itr][ee_itr]) + tmp[0] = _tokenize(tmp[0]) + spans[e_itr][ee_itr] = tuple(tmp) + dialog_state[slot] = spans + else: + dialog_state[slot] = "none" + elif 'true' in model.class_types and class_prediction == class_types.index('true'): + dialog_state[slot] = 'true' + elif 'false' in model.class_types and class_prediction == class_types.index('false'): + dialog_state[slot] = 'false' + elif class_prediction == class_types.index('inform'): + dialog_state[slot] = '§§' + inform[i][slot] # TODO: implement handling of multiple informed values + elif 'request' in model.class_types and class_prediction == model.class_types.index('request'): + # Don't carry over requested slots, except of type Boolean + if slot in processor.boolean: + dialog_state[slot] = 'true' + # Referral case is handled below + + prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot] + prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot] + prediction_addendum['slot_dist_prediction_%s' % slot] = value_match[slot][i][0] + prediction_addendum['slot_dist_confidence_%s' % slot] = value_match[slot][i][2] + prediction_addendum['slot_dist_similarity_%s' % slot] = value_match[slot][i][1] + prediction_addendum['slot_value_prediction_%s' % slot] = "" + prediction_addendum['slot_value_confidence_%s' % slot] = 1.0 + if len(prediction['value_prediction_%s' % slot]) > 0: + top_conf = np.argmax(prediction['value_confidence_%s' % slot]) + top_pred = prediction['value_prediction_%s' % slot][top_conf] + top_val = list(batch['encoded_slot_values'][slot].keys())[top_pred] + prediction_addendum['slot_value_prediction_%s' % slot] = top_val + prediction_addendum['slot_value_confidence_%s' % slot] = np.max(prediction['value_confidence_%s' % slot]) + + # Referral case. All other slot values need to be seen first in order + # to be able to do this correctly. + for slot in model.slot_list: + class_logits = per_slot_class_logits[slot][i] + refer_logits = per_slot_refer_logits[slot][i] + + class_prediction = int(class_logits.argmax()) + refer_prediction = int(refer_logits.argmax()) + + if 'refer' in class_types and class_prediction == class_types.index('refer'): + # Only slots that have been mentioned before can be referred to. + # One can think of a situation where one slot is referred to in the same utterance. + # This phenomenon is however currently not properly covered in the training data + # label generation process. + dialog_state[slot] = dialog_state[list(model.slot_list.keys())[refer_prediction]] + prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot] # Value update + + # Normalize value predictions + for slot in model.slot_list: + if isinstance(dialog_state[slot], list): + for e_itr in range(len(dialog_state[slot])): + for f_itr in range(len(dialog_state[slot][e_itr])): + tmp_state = list(dialog_state[slot][e_itr][f_itr]) + tmp_state[0] = processor.prediction_normalization(slot, tmp_state[0]) + dialog_state[slot][e_itr][f_itr] = tuple(tmp_state) + else: + dialog_state[slot] = processor.prediction_normalization(slot, dialog_state[slot]) + prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot] # Value update + + prediction.update(prediction_addendum) + prediction_list.append(prediction) + + return prediction_list, dialog_state + + +def predict_and_print(args, model, tokenizer, batch, outputs, ids, input_ids_unmasked, values, inform, ds, value_match): + per_slot_class_logits = outputs[8] + per_slot_start_logits = outputs[9] + per_slot_att_weights = outputs[12] + + class_types = model.class_types + + for i in range(len(ids)): + input_tokens = tokenizer.convert_ids_to_tokens(input_ids_unmasked[i]) + + input_ids = batch['input_ids'][i].tolist() + input_ids = input_ids[:len(input_ids) - input_ids[::-1].index(tokenizer.sep_token_id)] + + print(ids[i]) + + pos_i = {} + clb_i = {} + class_norm_weights = {} + token_norm_weights = {} + is_value_match = {} + for s in model.slot_list: + pos_i[s] = batch['start_pos'][s][i].tolist() + clb_i[s] = batch['class_label_id'][s][i] + if per_slot_att_weights[s] is not None: + class_weights = per_slot_att_weights[s][i][:len(input_ids)] + class_norm_weights[s] = torch.clamp(class_weights - torch.mean(class_weights), min=0) / torch.max(class_weights) + token_weights = per_slot_start_logits[s][i][:len(input_ids)] + token_norm_weights[s] = torch.clamp(token_weights - torch.mean(token_weights), min=0) / torch.max(token_weights) + is_value_match[s] = value_match[s][i][0] == values[i][s] + # Print value matching results + sorted_dists = value_match[s][i][3] + print("%20s: %s %s ..." % (s, values[i][s], sorted_dists[:3])) + + print(" ", end="") + if per_slot_att_weights[s] is not None: + for s in model.slot_list: + if clb_i[s] != class_types.index('none'): + print("\033[1m%s\033[0m" % (s[0]), end="") + else: + print(s[0], end="") + print(" | ", end="") + missed = "" + no_value_match = "" + for s in model.slot_list: + class_prediction = int(per_slot_class_logits[s][i].argmax()) + if clb_i[s] != class_types.index('none') and clb_i[s] != class_prediction: + missed += "%s: %d -> %d " % (s[0], clb_i[s], class_prediction) + if clb_i[s] == class_types.index('copy_value'): + print("\033[1m%s\033[0m " % (s[0]), end="") + else: + print(s[0] + " ", end="") + if clb_i[s] == class_types.index('copy_value') and not is_value_match[s]: + no_value_match += "%s (%s)" % (s[0], value_match[s][i][0]) + if len(missed) > 0: + print("| missed: " + missed, end="") + if len(no_value_match) > 0: + print("| wrong value match: %s" % no_value_match, end="") + print() + for k in range(len(input_ids)): + bold = False + print(" ", end="") + if per_slot_att_weights[s] is not None: + for s in model.slot_list: + c_weight = class_norm_weights[s][k] + if c_weight == 0.0: + print(" ", end="") + elif c_weight < 0.25: + print("\u2591", end="") + elif c_weight < 0.5: + print("\u2592", end="") + elif c_weight < 0.75: + print("\u2593", end="") + else: + print("\u2588", end="") + print(" | ", end="") + for s in model.slot_list: + t_weight = token_norm_weights[s][k] + if t_weight == 0.0: + print(" ", end="") + elif t_weight < 0.25: + print("\u2591 ", end="") + elif t_weight < 0.5: + print("\u2592 ", end="") + elif t_weight < 0.75: + print("\u2593 ", end="") + else: + print("\u2588 ", end="") + if pos_i[s][k]: + bold = True + if bold: + print("\033[1m%s\033[0m" % (input_tokens[k])) + else: + print("%s" % (input_tokens[k])) diff --git a/metric_dst.py b/metric_dst.py new file mode 100644 index 0000000..0603d08 --- /dev/null +++ b/metric_dst.py @@ -0,0 +1,566 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import json +import sys +import numpy as np +import re +import math +import argparse + + +def load_dataset_config(dataset_config): + with open(dataset_config, "r", encoding='utf-8') as f: + raw_config = json.load(f) + return raw_config['class_types'], raw_config['slots'], raw_config['label_maps'], raw_config['noncategorical'], raw_config['boolean'] + + +def tokenize(text): + if "\u0120" in text: + text = re.sub(" ", "", text) + text = re.sub("\u0120", " ", text) + else: + text = re.sub(" ##", "", text) + text = text.strip() + return ' '.join([tok for tok in map(str.strip, re.split("(\W+)", text)) if len(tok) > 0]) + + +def filter_sequences(seqs, mode="first"): + if mode == "first": + return tokenize(seqs[0][0][0]) + elif mode == "max_first": + max_conf = 0 + max_idx = 0 + for e_itr, e in enumerate(seqs[0]): + if e[1] > max_conf: + max_conf = e[1] + max_idx = e_itr + return tokenize(seqs[0][max_idx][0]) + elif mode == "max": + max_conf = 0 + max_t_idx = 0 + for t_itr, t in enumerate(seqs): + for e_itr, e in enumerate(t): + if e[1] > max_conf: + max_conf = e[1] + max_t_idx = t_itr + max_idx = e_itr + return tokenize(seqs[max_t_idx][max_idx][0]) + else: + print("WARN: mode %s unknown. Aborting." % mode) + exit() + + +def is_in_list(tok, value): + found = False + tok_list = [item for item in map(str.strip, re.split("(\W+)", tok)) if len(item) > 0] + value_list = [item for item in map(str.strip, re.split("(\W+)", value)) if len(item) > 0] + tok_len = len(tok_list) + value_len = len(value_list) + for i in range(tok_len + 1 - value_len): + if tok_list[i:i + value_len] == value_list: + found = True + break + return found + + +def check_slot_inform(value_label, inform_label, label_maps): + value = inform_label + if value_label == inform_label: + value = value_label + elif is_in_list(inform_label, value_label): + value = value_label + elif is_in_list(value_label, inform_label): + value = value_label + elif inform_label in label_maps: + for inform_label_variant in label_maps[inform_label]: + if value_label == inform_label_variant: + value = value_label + break + elif is_in_list(inform_label_variant, value_label): + value = value_label + break + elif is_in_list(value_label, inform_label_variant): + value = value_label + break + elif value_label in label_maps: + for value_label_variant in label_maps[value_label]: + if value_label_variant == inform_label: + value = value_label + break + elif is_in_list(inform_label, value_label_variant): + value = value_label + break + elif is_in_list(value_label_variant, inform_label): + value = value_label + break + return value + + +def match(gt, pd, label_maps): + # We want to be as conservative as possible here. + # We only allow maps according to label_maps and + # tolerate the absence/presence of the definite article. + if pd[:4] == "the " and gt == pd[4:]: + return True + if gt[:4] == "the " and gt[4:] == pd: + return True + if gt in label_maps: + for variant in label_maps[gt]: + if variant == pd: + return True + return False + + +def get_joint_slot_correctness(fp, args, class_types, label_maps, + key_class_label_id='class_label_id', + key_class_prediction='class_prediction', + key_start_pos='start_pos', + key_start_prediction='start_prediction', + key_start_confidence='start_confidence', + key_refer_id='refer_id', + key_refer_prediction='refer_prediction', + key_slot_groundtruth='slot_groundtruth', + key_slot_prediction='slot_prediction', + key_slot_dist_prediction='slot_dist_prediction', + key_slot_dist_confidence='slot_dist_confidence', + key_value_prediction='value_prediction', + key_value_groundtruth='value_groundtruth', + key_value_confidence='value_confidence', + key_slot_value_prediction='slot_value_prediction', + key_slot_value_confidence='slot_value_confidence', + noncategorical=False, boolean=False): + with open(fp) as f: + preds = json.load(f) + class_correctness = [[] for cl in range(len(class_types) + 1)] + confusion_matrix = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))] + pos_correctness = [] + refer_correctness = [] + val_correctness = [] + total_correctness = [] + c_tp = {ct: 0 for ct in range(len(class_types))} + c_tn = {ct: 0 for ct in range(len(class_types))} + c_fp = {ct: 0 for ct in range(len(class_types))} + c_fn = {ct: 0 for ct in range(len(class_types))} + s_confidence_bins = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} + s_confidence_cnts = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} + confidence_bins = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} + confidence_cnts = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} + a_confidence_bins = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} + a_confidence_cnts = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} + + value_match_cnt = 0 + for pred in preds: + guid = pred['guid'] # List: set_type, dialogue_idx, turn_idx + turn_gt_class = pred[key_class_label_id] + turn_pd_class = pred[key_class_prediction] + gt_start_pos = pred[key_start_pos] + pd_start_pos = pred[key_start_prediction] + pd_start_conf = pred[key_start_confidence] + gt_refer = pred[key_refer_id] + pd_refer = pred[key_refer_prediction] + gt_slot = tokenize(pred[key_slot_groundtruth]) + pd_slot = pred[key_slot_prediction] + pd_slot_dist_pred = tokenize(pred[key_slot_dist_prediction]) + pd_slot_dist_conf = float(pred[key_slot_dist_confidence]) + pd_slot_value_pred = tokenize(pred[key_slot_value_prediction]) + pd_slot_value_conf = pred[key_slot_value_confidence] + + pd_slot_raw = pd_slot + if isinstance(pd_slot, list): + pd_slot = filter_sequences(pd_slot, mode="max") + else: + pd_slot = tokenize(pd_slot) + + # Make sure the true turn labels are contained in the prediction json file! + joint_gt_slot = gt_slot + + # Sequence tagging confidence + if len(pd_start_pos) > 0: + avg_s_conf = np.mean(pd_start_conf) + if avg_s_conf == 0.0: + avg_s_conf += 1e-8 + s_c_bin = "%.1f" % (math.ceil(avg_s_conf * 10) / 10) + if gt_start_pos == pd_start_pos: + s_confidence_bins[s_c_bin] += 1 + s_confidence_cnts[s_c_bin] += 1 + + # Distance based value matching confidence + if pd_slot_dist_conf == 0.0: + pd_slot_dist_conf += 1e-8 + c_bin = "%.1f" % (math.ceil(pd_slot_dist_conf * 10) / 10) + if joint_gt_slot == pd_slot_dist_pred: + confidence_bins[c_bin] += 1 + confidence_cnts[c_bin] += 1 + + # Attention based value matching confidence + if pd_slot_value_conf == 0.0: + pd_slot_value_conf += 1e-8 + c_bin = "%.1f" % (math.ceil(pd_slot_value_conf * 10) / 10) + if joint_gt_slot == pd_slot_value_pred: + a_confidence_bins[c_bin] += 1 + a_confidence_cnts[c_bin] += 1 + + if guid[-1] == '0': # First turn, reset the slots + joint_pd_slot = 'none' + + # If turn_pd_class or a value to be copied is "none", do not update the dialog state. + if turn_pd_class == class_types.index('none'): + pass + elif turn_pd_class == class_types.index('dontcare'): + if not boolean: + joint_pd_slot = 'dontcare' + elif turn_pd_class == class_types.index('copy_value'): + if not boolean: + if pd_slot not in ["< none >", "[ NONE ]"]: + joint_pd_slot = pd_slot + elif 'true' in class_types and turn_pd_class == class_types.index('true'): + if boolean: + joint_pd_slot = 'true' + elif 'false' in class_types and turn_pd_class == class_types.index('false'): + if boolean: + joint_pd_slot = 'false' + elif 'refer' in class_types and turn_pd_class == class_types.index('refer'): + if not boolean: + if pd_slot[0:2] == "§§": + if pd_slot[2:].strip() != 'none': + joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:].strip(), label_maps) + elif pd_slot != 'none': + joint_pd_slot = pd_slot + elif 'inform' in class_types and turn_pd_class == class_types.index('inform'): + if not boolean: + if pd_slot[0:2] == "§§": + if pd_slot[2:].strip() != 'none': + joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:].strip(), label_maps) + elif 'request' in class_types and turn_pd_class == class_types.index('request'): + pass + else: + print("ERROR: Unexpected class_type. Aborting.") + exit() + + # Value matching + if args.confidence_threshold < 1.0 and turn_pd_class == class_types.index('copy_value') and not boolean: + # Treating categorical slots + if not noncategorical: + max_conf = max(np.mean(pd_start_conf), pd_slot_dist_conf, pd_slot_value_conf) + if max_conf == pd_slot_dist_conf and max_conf > args.confidence_threshold: + joint_pd_slot = tokenize(pd_slot_dist_pred) + value_match_cnt += 1 + elif max_conf == pd_slot_value_conf and max_conf > args.confidence_threshold: + joint_pd_slot = tokenize(pd_slot_value_pred) + value_match_cnt += 1 + # Treating all slots (including categorical slots) + if pd_slot_dist_conf > args.confidence_threshold: + joint_pd_slot = tokenize(pd_slot_dist_pred) + value_match_cnt += 1 + + total_correct = True + + # Check the per turn correctness of the class_type prediction + if turn_gt_class == turn_pd_class: + class_correctness[turn_gt_class].append(1.0) + class_correctness[-1].append(1.0) + c_tp[turn_gt_class] += 1 + # Only where there is a span, we check its per turn correctness + if turn_gt_class == class_types.index('copy_value'): + if gt_start_pos == pd_start_pos: + pos_correctness.append(1.0) + else: + pos_correctness.append(0.0) + # Only where there is a referral, we check its per turn correctness + if 'refer' in class_types and turn_gt_class == class_types.index('refer'): + if gt_refer == pd_refer: + refer_correctness.append(1.0) + print(" [%s] Correct referral: %s | %s" % (guid, gt_refer, pd_refer)) + else: + refer_correctness.append(0.0) + print(" [%s] Incorrect referral: %s | %s" % (guid, gt_refer, pd_refer)) + else: + if turn_gt_class == class_types.index('copy_value'): + pos_correctness.append(0.0) + if 'refer' in class_types and turn_gt_class == class_types.index('refer'): + refer_correctness.append(0.0) + class_correctness[turn_gt_class].append(0.0) + class_correctness[-1].append(0.0) + confusion_matrix[turn_gt_class][turn_pd_class].append(1.0) + c_fn[turn_gt_class] += 1 + c_fp[turn_pd_class] += 1 + for cc in range(len(class_types)): + if cc != turn_gt_class and cc != turn_pd_class: + c_tn[cc] += 1 + + # Check the joint slot correctness. + # If the value label is not none, then we need to have a value prediction. + # Even if the class_type is 'none', there can still be a value label, + # it might just not be pointable in the current turn. It might however + # be referrable and thus predicted correctly. + if joint_gt_slot == joint_pd_slot: + val_correctness.append(1.0) + elif joint_gt_slot != 'none' and joint_gt_slot != 'dontcare' and joint_gt_slot != 'true' and joint_gt_slot != 'false': + is_match = match(joint_gt_slot, joint_pd_slot, label_maps) + if not is_match: + val_correctness.append(0.0) + total_correct = False + print(" [%s] Incorrect value (variant): %s (turn class: %s) | %s (turn class: %s) | %.2f %s %.2f %s %s %s" % (guid, + joint_gt_slot, turn_gt_class, + joint_pd_slot, turn_pd_class, + np.mean(pd_start_conf), pd_slot_raw, + pd_slot_dist_conf, pd_slot_dist_pred, + "%.2f" % pd_slot_value_conf if pd_slot_value_pred != "" else "", + pd_slot_value_pred)) + else: + val_correctness.append(1.0) + else: + val_correctness.append(0.0) + total_correct = False + print(" [%s] Incorrect value: %s (turn class: %s) | %s (turn class: %s) | %.2f %s %.2f %s %s %s" % (guid, + joint_gt_slot, turn_gt_class, + joint_pd_slot, turn_pd_class, + np.mean(pd_start_conf), pd_slot_raw, + pd_slot_dist_conf, pd_slot_dist_pred, + "%.2f" % pd_slot_value_conf if pd_slot_value_pred != "" else "", + pd_slot_value_pred)) + + total_correctness.append(1.0 if total_correct else 0.0) + + # Account for empty lists (due to no instances of spans or referrals being seen) + if pos_correctness == []: + pos_correctness.append(1.0) + if refer_correctness == []: + refer_correctness.append(1.0) + + for ct in range(len(class_types)): + if c_tp[ct] + c_fp[ct] > 0: + precision = c_tp[ct] / (c_tp[ct] + c_fp[ct]) + else: + precision = 1.0 + if c_tp[ct] + c_fn[ct] > 0: + recall = c_tp[ct] / (c_tp[ct] + c_fn[ct]) + else: + recall = 1.0 + if precision + recall > 0: + f1 = 2 * ((precision * recall) / (precision + recall)) + else: + f1 = 1.0 + if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0: + acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct]) + else: + acc = 1.0 + print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" % + (class_types[ct], ct, recall, np.sum(class_correctness[ct]), len(class_correctness[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct])) + + print("Confusion matrix:") + for cl in range(len(class_types)): + print(" %s" % (cl), end="") + print("") + for cl_a in range(len(class_types)): + print("%s " % (cl_a), end="") + for cl_b in range(len(class_types)): + if len(class_correctness[cl_a]) > 0: + print("%.2f " % (np.sum(confusion_matrix[cl_a][cl_b]) / len(class_correctness[cl_a])), end="") + else: + print("---- ", end="") + print("") + + print("Confidence bins for sequence tagging:") + print(" bin cor") + for c in s_confidence_bins: + print(" %s %.2f (%d of %d)" % (c, s_confidence_bins[c] / (s_confidence_cnts[c] + 1e-8), s_confidence_bins[c], s_confidence_cnts[c])) + + print("Confidence bins for distance based value matching:") + print(" bin cor") + for c in confidence_bins: + print(" %s %.2f (%d of %d)" % (c, confidence_bins[c] / (confidence_cnts[c] + 1e-8), confidence_bins[c], confidence_cnts[c])) + + print("Confidence bins for attention based value matching:") + print(" bin cor") + for c in a_confidence_bins: + print(" %s %.2f (%d of %d)" % (c, a_confidence_bins[c] / (a_confidence_cnts[c] + 1e-8), a_confidence_bins[c], a_confidence_cnts[c])) + + print("Values replaced by value matching:", value_match_cnt) + + return np.asarray(total_correctness), np.asarray(val_correctness), np.asarray(class_correctness), np.asarray(pos_correctness), np.asarray(refer_correctness), np.asarray(confusion_matrix), c_tp, c_tn, c_fp, c_fn, s_confidence_bins, s_confidence_cnts, confidence_bins, confidence_cnts, a_confidence_bins, a_confidence_cnts + + +if __name__ == "__main__": + acc_list = [] + s_acc_list = [] + key_class_label_id = 'class_label_id_%s' + key_class_prediction = 'class_prediction_%s' + key_start_pos = 'start_pos_%s' + key_start_prediction = 'start_prediction_%s' + key_start_confidence = 'start_confidence_%s' + key_refer_id = 'refer_id_%s' + key_refer_prediction = 'refer_prediction_%s' + key_slot_groundtruth = 'slot_groundtruth_%s' + key_slot_prediction = 'slot_prediction_%s' + key_slot_dist_prediction = 'slot_dist_prediction_%s' + key_slot_dist_confidence = 'slot_dist_confidence_%s' + key_value_prediction = 'value_prediction_%s' + key_value_groundtruth = 'value_label_id_%s' + key_value_confidence = 'value_confidence_%s' + key_slot_value_prediction = 'slot_value_prediction_%s' + key_slot_value_confidence = 'slot_value_confidence_%s' + + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--dataset_config", default=None, type=str, required=True, + help="Dataset configuration file.") + parser.add_argument("--file_list", default=None, type=str, required=True, + help="List of input files.") + + # Other parameters + parser.add_argument("--confidence_threshold", default=1.0, type=float, + help="Threshold for value matching confidence. 1.0 means no value matching is used.") + + args = parser.parse_args() + + assert args.confidence_threshold >= 0.0 and args.confidence_threshold <= 1.0 + + class_types, slots, label_maps, noncategorical, boolean = load_dataset_config(args.dataset_config) + + # Prepare label_maps + label_maps_tmp = {} + for v in label_maps: + label_maps_tmp[tokenize(v)] = [tokenize(nv) for nv in label_maps[v]] + label_maps = label_maps_tmp + + for fp in sorted(glob.glob(args.file_list)): + # Infer slot list from data if not provided. + if len(slots) == 0: + with open(fp) as f: + preds = json.load(f) + for e in preds[0]: + slot = re.match("^slot_groundtruth_(.*)$", e) + slot = slot[1] if slot else None + if slot and slot not in slots: + slots.append(slot) + print(fp) + goal_correctness = 1.0 + cls_acc = [[] for cl in range(len(class_types))] + cls_conf = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))] + c_tp = {ct: 0 for ct in range(len(class_types))} + c_tn = {ct: 0 for ct in range(len(class_types))} + c_fp = {ct: 0 for ct in range(len(class_types))} + c_fn = {ct: 0 for ct in range(len(class_types))} + s_confidence_bins = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} + s_confidence_cnts = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} + confidence_bins = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} + confidence_cnts = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} + a_confidence_bins = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} + a_confidence_cnts = {"%.1f" % (0.1 + b * 0.1): 0 for b in range(10)} + for slot in slots: + tot_cor, joint_val_cor, cls_cor, pos_cor, ref_cor, \ + conf_mat, ctp, ctn, cfp, cfn, \ + scbins, sccnts, cbins, ccnts, acbins, accnts = get_joint_slot_correctness(fp, args, class_types, label_maps, + key_class_label_id=(key_class_label_id % slot), + key_class_prediction=(key_class_prediction % slot), + key_start_pos=(key_start_pos % slot), + key_start_prediction=(key_start_prediction % slot), + key_start_confidence=(key_start_confidence % slot), + key_refer_id=(key_refer_id % slot), + key_refer_prediction=(key_refer_prediction % slot), + key_slot_groundtruth=(key_slot_groundtruth % slot), + key_slot_prediction=(key_slot_prediction % slot), + key_slot_dist_prediction=(key_slot_dist_prediction % slot), + key_slot_dist_confidence=(key_slot_dist_confidence % slot), + key_value_prediction=(key_value_prediction % slot), + key_value_groundtruth=(key_value_groundtruth % slot), + key_value_confidence=(key_value_confidence % slot), + key_slot_value_prediction=(key_slot_value_prediction % slot), + key_slot_value_confidence=(key_slot_value_confidence % slot), + noncategorical=slot in noncategorical, + boolean=slot in boolean) + print('%s: joint slot acc: %g, joint value acc: %g, turn class acc: %g, turn position acc: %g, turn referral acc: %g' % + (slot, np.mean(tot_cor), np.mean(joint_val_cor), np.mean(cls_cor[-1]), np.mean(pos_cor), np.mean(ref_cor))) + goal_correctness *= tot_cor + for cl_a in range(len(class_types)): + cls_acc[cl_a] += cls_cor[cl_a] + for cl_b in range(len(class_types)): + cls_conf[cl_a][cl_b] += list(conf_mat[cl_a][cl_b]) + c_tp[cl_a] += ctp[cl_a] + c_tn[cl_a] += ctn[cl_a] + c_fp[cl_a] += cfp[cl_a] + c_fn[cl_a] += cfn[cl_a] + for c in scbins: + s_confidence_bins[c] += scbins[c] + s_confidence_cnts[c] += sccnts[c] + for c in cbins: + confidence_bins[c] += cbins[c] + confidence_cnts[c] += ccnts[c] + for c in cbins: + a_confidence_bins[c] += acbins[c] + a_confidence_cnts[c] += accnts[c] + + for ct in range(len(class_types)): + if c_tp[ct] + c_fp[ct] > 0: + precision = c_tp[ct] / (c_tp[ct] + c_fp[ct]) + else: + precision = 1.0 + if c_tp[ct] + c_fn[ct] > 0: + recall = c_tp[ct] / (c_tp[ct] + c_fn[ct]) + else: + recall = 1.0 + if precision + recall > 0: + f1 = 2 * ((precision * recall) / (precision + recall)) + else: + f1 = 1.0 + if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0: + acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct]) + else: + acc = 1.0 + print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" % + (class_types[ct], ct, recall, np.sum(cls_acc[ct]), len(cls_acc[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct])) + + print("Confusion matrix:") + for cl in range(len(class_types)): + print(" %s" % (cl), end="") + print("") + for cl_a in range(len(class_types)): + print("%s " % (cl_a), end="") + for cl_b in range(len(class_types)): + if len(cls_acc[cl_a]) > 0: + print("%.2f " % (np.sum(cls_conf[cl_a][cl_b]) / len(cls_acc[cl_a])), end="") + else: + print("---- ", end="") + print("") + + print("Confidence bins for sequence tagging:") + print(" bin cor") + for c in s_confidence_bins: + print(" %s %.2f (%d of %d)" % (c, s_confidence_bins[c] / (s_confidence_cnts[c] + 1e-8), s_confidence_bins[c], s_confidence_cnts[c])) + + print("Confidence bins for distance based value matching:") + print(" bin cor") + for c in confidence_bins: + print(" %s %.2f (%d of %d)" % (c, confidence_bins[c] / (confidence_cnts[c] + 1e-8), confidence_bins[c], confidence_cnts[c])) + + print("Confidence bins for attention based value matching:") + print(" bin cor") + for c in a_confidence_bins: + print(" %s %.2f (%d of %d)" % (c, a_confidence_bins[c] / (a_confidence_cnts[c] + 1e-8), a_confidence_bins[c], a_confidence_cnts[c])) + + acc = np.mean(goal_correctness) + acc_list.append((fp, acc)) + + acc_list_s = sorted(acc_list, key=lambda tup: tup[1], reverse=True) + for (fp, acc) in acc_list_s: + print('Joint goal acc: %g, %s' % (acc, fp)) diff --git a/modeling_dst.py b/modeling_dst.py new file mode 100644 index 0000000..72bc8d4 --- /dev/null +++ b/modeling_dst.py @@ -0,0 +1,415 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.nn import MultiheadAttention +import torch.nn.functional as F + +from transformers import (BertModel, BertPreTrainedModel, + RobertaModel, RobertaPreTrainedModel, + ElectraModel, ElectraPreTrainedModel) + +PARENT_CLASSES = { + 'bert': BertPreTrainedModel, + 'roberta': RobertaPreTrainedModel, + 'electra': ElectraPreTrainedModel +} + +MODEL_CLASSES = { + BertPreTrainedModel: BertModel, + RobertaPreTrainedModel: RobertaModel, + ElectraPreTrainedModel: ElectraModel +} + +class ElectraPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +def TransformerForDST(parent_name): + if parent_name not in PARENT_CLASSES: + raise ValueError("Unknown model %s" % (parent_name)) + + class TransformerForDST(PARENT_CLASSES[parent_name]): + def __init__(self, config): + assert config.model_type in PARENT_CLASSES + super(TransformerForDST, self).__init__(config) + self.model_type = config.model_type + self.slot_list = config.slot_list + self.noncategorical = config.noncategorical + self.class_types = config.class_types + self.class_labels = config.class_labels + self.class_loss_ratio = config.class_loss_ratio + self.slot_attention_heads = config.slot_attention_heads + self.tag_none_target = config.tag_none_target + self.value_matching_weight = config.value_matching_weight + self.none_weight = config.none_weight + self.proto_loss_function = config.proto_loss_function + self.token_loss_function = config.token_loss_function + self.value_loss_function = config.value_loss_function + + config.output_hidden_states = True + + # Make sure this module has the same name as in the pretrained checkpoint you want to load! + self.add_module(self.model_type, MODEL_CLASSES[PARENT_CLASSES[self.model_type]](config)) + if self.model_type == "electra": + self.pooler = ElectraPooler(config) + + self.dropout = nn.Dropout(config.dropout_rate) + self.gelu = nn.GELU() + + # Only use refer loss if refer class is present in dataset. + if 'refer' in self.class_types: + self.refer_index = self.class_types.index('refer') + else: + self.refer_index = -1 + + # Attention for slot gates + self.class_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) + + # Conditioned sequence tagging + self.token_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) + self.refer_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) + self.value_att = MultiheadAttention(config.hidden_size, self.slot_attention_heads) + + self.token_layer_norm_proto = nn.LayerNorm(config.hidden_size) + self.token_layer_norm = nn.LayerNorm(config.hidden_size) + self.class_layer_norm = nn.LayerNorm(config.hidden_size) + + # Conditioned slot gate + self.h1c = nn.Linear(config.hidden_size, config.hidden_size) + self.h2c = nn.Linear(config.hidden_size * 2, config.hidden_size * 2) + self.llc = nn.Linear(config.hidden_size * 2, self.class_labels) + + # Conditioned refer gate + self.h2r = nn.Linear(config.hidden_size * 2, config.hidden_size * 1) + + # Loss functions + self.binary_cross_entropy = F.binary_cross_entropy + self.mse = nn.MSELoss(reduction="none") + self.refer_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=len(self.slot_list)) # Ignore 'none' target + if self.none_weight != 1.0: + none_weight = self.none_weight + weight_mass = none_weight + (self.class_labels - 1) + none_weight /= weight_mass + other_weights = 1 / weight_mass + self.clweights = torch.tensor([other_weights] * self.class_labels) + self.clweights[self.class_types.index('none')] = none_weight + self.class_loss_fct = CrossEntropyLoss(weight=self.clweights, reduction='none') + else: + self.class_loss_fct = CrossEntropyLoss(reduction='none') + + self.init_weights() + + def forward(self, batch, step=None, mode=None): + assert mode in [None, "proto", "tag", "encode", "represent"] + + # Required + input_ids = batch['input_ids'] + input_mask = batch['input_mask'] + # Optional + segment_ids = batch['segment_ids'] if 'segment_ids' in batch else None + usr_mask = batch['usr_mask'] if 'usr_mask' in batch else None + # For loss computation + token_pos = batch['start_pos'] if 'start_pos' in batch else None + refer_id = batch['refer_id'] if 'refer_id' in batch else None + class_label_id = batch['class_label_id'] if 'class_label_id' in batch else None + # Dynamic elements + slot_ids = batch['slot_ids'] if 'slot_ids' in batch else None + slot_mask = batch['slot_mask'] if 'slot_mask' in batch else None + value_labels = batch['value_labels'] if 'value_labels' in batch else None + dropout_value_feat = batch['dropout_value_feat'] if 'dropout_value_feat' in batch else None + + batch_input_mask = input_mask + if slot_ids is not None and slot_mask is not None: + input_ids = torch.cat((input_ids, slot_ids)) + input_mask = torch.cat((input_mask, slot_mask)) + + outputs = getattr(self, self.model_type)( + input_ids, + attention_mask=input_mask, + token_type_ids=None, + position_ids=None, + head_mask=None + ) + + sequence_output = outputs[0] + if self.model_type == "electra": + pooled_output = self.pooler(sequence_output) + else: + pooled_output = outputs[1] + + if slot_ids is not None and slot_mask is not None: + encoded_slots_seq = sequence_output[-1 * len(slot_ids):, :, :] + sequence_output = sequence_output[:-1 * len(slot_ids), :, :] + encoded_slots_pooled = pooled_output[-1 * len(slot_ids):, :] + pooled_output = pooled_output[:-1 * len(slot_ids), :] + + sequence_output = self.dropout(sequence_output) + pooled_output = self.dropout(pooled_output) + + inverted_input_mask = ~(batch_input_mask.bool()) + if usr_mask is None: + usr_mask = input_mask + inverted_usr_mask = ~(usr_mask.bool()) + + # Create vector representations only + if mode == "encode": + return pooled_output, sequence_output, None + + # Proto-DST + if mode == "proto": + pos_vectors = {} + pos_weights = {} + pos_vectors, pos_weights = self.token_att( + query=encoded_slots_pooled.squeeze(1).unsqueeze(0), + key=sequence_output.transpose(0, 1), + value=sequence_output.transpose(0, 1), + key_padding_mask=inverted_input_mask, + need_weights=True) + pos_vectors = pos_vectors.squeeze(0) + pos_vectors = self.token_layer_norm_proto(pos_vectors) + pos_weights = pos_weights.squeeze(1) + + pos_labels_clipped = torch.clamp(token_pos.float(), min=0, max=1) + pos_labels_clipped_scaled = pos_labels_clipped / torch.clamp(pos_labels_clipped.sum(1).unsqueeze(1), min=1) + if self.proto_loss_function == "mse": + pos_token_loss = self.mse(pos_weights, pos_labels_clipped_scaled) # MSE should be better for scaled targets + else: + pos_token_loss = self.binary_cross_entropy(pos_weights, pos_labels_clipped_scaled, reduction="none") + pos_token_loss = pos_token_loss.sum(1) + + per_example_loss = pos_token_loss + total_loss = per_example_loss.sum() + + return (total_loss, pos_weights,) + + # Value tagging with proto-DST + if mode == "tag": + _, tag_weights = self.token_att(query=torch.stack(list(batch['value_reps'].values())).squeeze(2), + key=sequence_output.transpose(0, 1), + value=sequence_output.transpose(0, 1), + key_padding_mask=inverted_input_mask + inverted_usr_mask, + need_weights=True) + return (tag_weights,) + + # Attention for sequence tagging + vectors = {} + weights = {} + for s_itr, slot in enumerate(self.slot_list): + if slot_ids is not None and slot_mask is not None: + encoded_slot_seq = encoded_slots_seq[s_itr] + encoded_slot_pooled = encoded_slots_pooled[s_itr] + else: + encoded_slot_seq = batch['encoded_slots_seq'][slot] + encoded_slot_pooled = batch['encoded_slots_pooled'][slot] + query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0) + vectors[slot], weights[slot] = self.token_att( + query=query, + key=sequence_output.transpose(0, 1), + value=sequence_output.transpose(0, 1), + key_padding_mask=inverted_input_mask + inverted_usr_mask, + need_weights=True) + vectors[slot] = vectors[slot].squeeze(0) + vectors[slot] = self.token_layer_norm(vectors[slot]) + weights[slot] = weights[slot].squeeze(1) + + # Create vector representations only (alternative) + if mode == "represent": + return vectors, None, weights + + # ---- + # MAIN + # ---- + + total_loss = 0 + total_cl_loss = 0 + total_tk_loss = 0 + total_tp_loss = 0 + per_slot_per_example_loss = {} + per_slot_per_example_cl_loss = {} + per_slot_per_example_tk_loss = {} + per_slot_per_example_tp_loss = {} + per_slot_class_logits = {} + per_slot_token_weights = {} + per_slot_value_weights = {} + per_slot_refer_logits = {} + per_slot_att_weights = {} + for s_itr, slot in enumerate(self.slot_list): + if slot_ids is not None and slot_mask is not None: + encoded_slot_seq = encoded_slots_seq[s_itr] + encoded_slot_pooled = encoded_slots_pooled[s_itr] + else: + encoded_slot_seq = batch['encoded_slots_seq'][slot] + encoded_slot_pooled = batch['encoded_slots_pooled'][slot] + + # Attention for slot gates + query = encoded_slot_pooled.expand(pooled_output.size()).unsqueeze(0) + att_output, c_weights = self.class_att( + query=query, + key=sequence_output.transpose(0, 1), + value=sequence_output.transpose(0, 1), + key_padding_mask=inverted_input_mask, + need_weights=True) + att_output = self.class_layer_norm(att_output) + per_slot_att_weights[slot] = c_weights.squeeze(1) + + # Conditioned slot gate + slot_gate_feats = self.gelu(self.h1c(att_output.squeeze(0))) + slot_gate_input = self.gelu(self.h2c(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), slot_gate_feats), 1))) + class_logits = self.llc(slot_gate_input) + + # Conditioned refer gate + slot_refer_input = self.gelu(self.h2r(torch.cat((encoded_slot_pooled.expand(pooled_output.size()), slot_gate_feats), 1))) + + # Sequence tagging + token_weights = weights[slot] + + # Value matching + if self.value_matching_weight > 0.0: + slot_values = torch.stack(list(batch['encoded_slot_values'][slot].values())) + slot_values = slot_values.expand((-1, pooled_output.size(0), -1)) + is_dropout_sample = (batch['dropout_value_feat'][slot].sum(2) > 0.0) + v_lbl = batch['value_labels'][slot] * is_dropout_sample + v_lbl_orig = v_lbl == 0 + orig_feats = slot_values * v_lbl_orig.transpose(0, 1).unsqueeze(2) + v_lbl_dropout = v_lbl == 1 + dropout_feats = batch['dropout_value_feat'][slot].expand(-1, slot_values.size(0), -1) * v_lbl_dropout.unsqueeze(2) + slot_values = orig_feats + dropout_feats.transpose(0, 1) + _, value_weights = self.value_att( + query=vectors[slot].unsqueeze(0), + key=slot_values, + value=slot_values, + need_weights=True) + value_weights = value_weights.squeeze(1) + + # Refer gate + if slot_ids is not None and slot_mask is not None: + refer_slots = encoded_slots_pooled.unsqueeze(1).expand(-1, pooled_output.size()[0], -1) + else: + refer_slots = torch.stack(list(batch['encoded_slots_pooled'].values())).expand(-1, pooled_output.size()[0], -1) + _, refer_weights = self.refer_att( + query=slot_refer_input.unsqueeze(0), + key=refer_slots, + value=refer_slots, + need_weights=True) + refer_weights = refer_weights.squeeze(1) + refer_logits = refer_weights + + per_slot_class_logits[slot] = class_logits + per_slot_token_weights[slot] = token_weights + per_slot_value_weights[slot] = value_weights + per_slot_refer_logits[slot] = refer_logits + + # If there are no labels, don't compute loss + if class_label_id is not None and token_pos is not None and refer_id is not None: + # If we are on multi-GPU, split add a dimension + if len(token_pos[slot].size()) > 1: + token_pos[slot] = token_pos[slot].squeeze(-1) + + # Sequence tagging loss + labels_clipped = torch.clamp(token_pos[slot].float(), min=0, max=1) + labels_clipped_scaled = labels_clipped / torch.clamp(labels_clipped.sum(1).unsqueeze(1), min=1) + no_seq_mask = labels_clipped_scaled.sum(1) == 0 + no_seq_w = 1 / batch_input_mask.sum(1) + labels_clipped_scaled += batch_input_mask * (no_seq_mask * no_seq_w).unsqueeze(1) + if self.token_loss_function == "mse": + token_loss = self.mse(token_weights, labels_clipped_scaled) # MSE should be better for scaled targets + else: + token_loss = self.binary_cross_entropy(token_weights, labels_clipped_scaled, reduction="none") + + # TODO: subsample negative examples due to their large number? + token_loss = token_loss.sum(1) + token_is_pointable = (token_pos[slot].sum(1) > 0).float() + token_loss *= token_is_pointable + + # Value matching loss + value_loss = torch.zeros(token_is_pointable.size(), device=token_is_pointable.device) + if self.value_matching_weight > 0.0: + value_labels_clipped = torch.clamp(value_labels[slot].float(), min=0, max=1) + value_labels_clipped /= torch.clamp(value_labels_clipped.sum(1).unsqueeze(1), min=1) + value_no_seq_mask = value_labels_clipped.sum(1) == 0 + value_no_seq_w = 1 / value_labels_clipped.size(1) + value_labels_clipped += (value_no_seq_mask * value_no_seq_w).unsqueeze(1) + if self.value_loss_function == "mse": + value_loss = self.mse(value_weights, value_labels_clipped) + else: + value_loss = self.binary_cross_entropy(value_weights, value_labels_clipped, reduction="none") + value_loss = value_loss.sum(1) + token_is_matchable = token_is_pointable + if self.tag_none_target: + token_is_matchable *= (token_pos[slot][:, 1] == 0).float() + value_loss *= token_is_matchable + + # Refer loss + # Re-definition is necessary here to make slot-independent prediction possible + self.refer_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=len(self.slot_list)) # Ignore 'none' target + refer_loss = self.refer_loss_fct(refer_logits, refer_id[slot]) + token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float() + refer_loss *= token_is_referrable + + # Class loss (i.e., slot gate loss) + class_loss = self.class_loss_fct(class_logits, class_label_id[slot]) + + if self.refer_index > -1: + per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss + self.value_matching_weight * value_loss + else: + per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss + self.value_matching_weight * value_loss + + total_loss += per_example_loss.sum() + total_cl_loss += class_loss.sum() + total_tk_loss += token_loss.sum() + total_tp_loss += value_loss.sum() + per_slot_per_example_loss[slot] = per_example_loss + per_slot_per_example_cl_loss[slot] = class_loss + per_slot_per_example_tk_loss[slot] = token_loss + per_slot_per_example_tp_loss[slot] = value_loss + + # add hidden states and attention if they are here + outputs = (total_loss, + total_cl_loss, + total_tk_loss, + total_tp_loss, + per_slot_per_example_loss, + per_slot_per_example_cl_loss, + per_slot_per_example_tk_loss, + per_slot_per_example_tp_loss, + per_slot_class_logits, + per_slot_token_weights, + per_slot_value_weights, + per_slot_refer_logits, + per_slot_att_weights,) + (vectors,) + + return outputs + + return TransformerForDST diff --git a/run_dst.py b/run_dst.py new file mode 100644 index 0000000..0f59253 --- /dev/null +++ b/run_dst.py @@ -0,0 +1,431 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import glob +import json +import pickle + +import torch + +from transformers import (WEIGHTS_NAME, + BertConfig, BertTokenizer, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + RobertaConfig, RobertaTokenizer, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, + ElectraConfig, ElectraTokenizer, ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP) + +from modeling_dst import (TransformerForDST) +from data_processors import PROCESSORS +from utils_run import (print_header, set_seed, load_and_cache_examples) + +from dst_proto import (train_proto, evaluate_proto) +from dst_tag import (tag_values) +from dst_train import (train, evaluate) + +logger = logging.getLogger(__name__) + +ALL_MODELS = tuple(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP) +ALL_MODELS += tuple(ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP) +ALL_MODELS += tuple(ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP) + +class BertForDST(TransformerForDST('bert')): pass +class RobertaForDST(TransformerForDST('roberta')): pass +class ElectraForDST(TransformerForDST('electra')): pass + +MODEL_CLASSES = { + 'bert': (BertConfig, BertForDST, BertTokenizer), + 'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer), + 'electra': (ElectraConfig, ElectraForDST, ElectraTokenizer), +} + + +def main(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--task_name", default=None, type=str, required=True, + help="Name of the task (e.g., multiwoz21).") + parser.add_argument("--data_dir", default=None, type=str, required=True, + help="Task database.") + parser.add_argument("--dataset_config", default=None, type=str, required=True, + help="Dataset configuration file.") + parser.add_argument("--predict_type", default=None, type=str, required=True, + help="Portion of the data to perform prediction on (e.g., dev, test).") + parser.add_argument("--model_type", default=None, type=str, required=True, + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) + parser.add_argument("--model_name_or_path", default=None, type=str, required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) + parser.add_argument("--output_dir", default=None, type=str, required=True, + help="The output directory where the model checkpoints and predictions will be written.") + + # Other parameters + parser.add_argument("--config_name", default="", type=str, + help="Pretrained config name or path if not the same as model_name") + parser.add_argument("--tokenizer_name", default="", type=str, + help="Pretrained tokenizer name or path if not the same as model_name") + + parser.add_argument("--max_seq_length", default=384, type=int, + help="Maximum input length after tokenization. Longer sequences will be truncated, shorter ones padded.") + parser.add_argument("--do_train", action='store_true', + help="Whether to run training.") + parser.add_argument("--do_eval", action='store_true', + help="Whether to run eval on the <predict_type> set.") + parser.add_argument("--evaluate_during_training", action='store_true', + help="Rul evaluation during training at each logging step.") + parser.add_argument("--do_lower_case", action='store_true', + help="Set this flag if you are using an uncased model.") + + parser.add_argument("--dropout_rate", default=0.3, type=float, + help="Dropout rate for transformer representations.") + parser.add_argument("--class_loss_ratio", default=0.8, type=float, + help="The ratio applied on class loss in total loss calculation. " + "Should be a value in [0.0, 1.0]. " + "The ratio applied on token loss is (1-class_loss_ratio)/2. " + "The ratio applied on refer loss is (1-class_loss_ratio)/2.") + parser.add_argument("--proto_loss_function", type=str, default="mse", + help="Loss function for proto DST training (mse|ce). Default 'mse'") + parser.add_argument("--token_loss_function", type=str, default="mse", + help="Loss function for sequence tagging (mse|ce). Default 'mse'") + parser.add_argument("--value_loss_function", type=str, default="mse", + help="Loss function for value matching (mse|ce). Default 'mse'") + parser.add_argument("--slot_attention_heads", type=int, default=8, + help="Number of heads in multihead attention") + + parser.add_argument("--no_append_history", action='store_true', + help="Do not append the dialog history to each turn.") + parser.add_argument("--no_use_history_labels", action='store_true', + help="Do not label the history as well.") + parser.add_argument("--no_label_value_repetitions", action='store_true', + help="Do not label values that have been mentioned before.") + parser.add_argument("--swap_utterances", action='store_true', + help="Swap the turn utterances (default: usr|sys, swapped: sys|usr).") + parser.add_argument("--delexicalize_sys_utts", action='store_true', + help="Delexicalize the system utterances.") + parser.add_argument("--none_weight", type=float, default=1.0, + help="Weight for the none class of the slot gates (default: 1.0)") + + parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, + help="Batch size per GPU/CPU for training.") + parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, + help="Batch size per GPU/CPU for evaluation.") + parser.add_argument("--learning_rate", default=5e-5, type=float, + help="The initial learning rate for Adam.") + parser.add_argument('--gradient_accumulation_steps', type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument("--weight_decay", default=0.0, type=float, + help="Weight deay if we apply some.") + parser.add_argument("--adam_epsilon", default=1e-8, type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument("--max_grad_norm", default=1.0, type=float, + help="Max gradient norm.") + parser.add_argument("--num_train_epochs", default=3.0, type=float, + help="Total number of training epochs to perform.") + parser.add_argument("--max_steps", default=-1, type=int, + help="If > 0: set total number of training steps to perform. Overwrites num_train_epochs.") + parser.add_argument("--warmup_proportion", default=0.0, type=float, + help="Linear warmup over warmup_proportion * steps.") + parser.add_argument("--patience", type=int, default=-1, + help="Patience for early stopping. When -1, no patience is used") + parser.add_argument("--early_stop_criterion", type=str, default="goal", + help="Early stopping criterion (goal|loss). Default 'goal'") + + parser.add_argument('--logging_steps', type=int, default=10, + help="Log every X updates steps.") + parser.add_argument('--save_steps', type=int, default=0, + help="Save checkpoint every X updates steps. Overwritten by --save_epochs.") + parser.add_argument('--save_epochs', type=int, default=0, + help="Save checkpoint every X epochs. Overwrites --save_steps.") + parser.add_argument("--eval_all_checkpoints", action='store_true', + help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number") + parser.add_argument("--no_cuda", action='store_true', + help="Whether not to use CUDA when available") + parser.add_argument('--overwrite_cache', action='store_true', + help="Overwrite the cached training and evaluation sets") + parser.add_argument('--no_cache', action='store_true', + help="Don't use cached training and evaluation sets") + parser.add_argument('--cache_suffix', default="", type=str, + help="Optionally add a suffix to the cache files (use trailing _).") + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + + parser.add_argument("--local_rank", type=int, default=-1, + help="local_rank for distributed training on gpus") + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit (mixed) precision instead of 32-bit") + parser.add_argument('--local_files_only', action='store_true', + help="Whether to only load local model files (useful when working offline).") + + parser.add_argument("--training_phase", type=int, default=-1, help="-1: regular training, 0: proto training, 1: tagging, 2: spanless training") + + # Slot value dropout, token noising, history dropout + parser.add_argument("--svd", default=0.0, type=float, + help="Slot value dropout ratio (default: 0.0)") + parser.add_argument('--use_td', action='store_true', + help="Do slot value dropout using random tokens, i.e., token noising. Requires --svd") + parser.add_argument("--td_ratio", type=float, default=1.0, + help="Fraction of token vocabulary to draw replacements from for token noising. Requires --use_td") + parser.add_argument('--svd_for_all_slots', action='store_true', + help="By default, SVD/TD is used for noncategorical slots only. Set to use SVD/TD for categorical slots as well") + parser.add_argument("--hd", type=float, default=0.0, + help="History dropout ratio") + + # Spanless training + parser.add_argument('--tag_none_target', action='store_true', + help="Use <none>/[NONE] as target when tagging negative samples") + parser.add_argument("--rand_seq_max_len", type=int, default=4, + help="Maximum length of random sequences for proto DST training") + parser.add_argument("--proto_neg_sample_ratio", type=float, default=0.1, + help="Negative sample ratio for proto DST training. Requires --tag_none_target") + + # Value matching + parser.add_argument("--value_matching_weight", type=float, default=0.0, + help="Value matching weight. When 0.0, value matching is not used") + + args = parser.parse_args() + + assert args.warmup_proportion >= 0.0 and args.warmup_proportion <= 1.0 + assert args.svd >= 0.0 and args.svd <= 1.0 + assert args.hd >= 0.0 and args.hd <= 1.0 + assert args.td_ratio >= 0.0 and args.td_ratio <= 1.0 + assert args.proto_neg_sample_ratio >= 0.0 and args.proto_neg_sample_ratio <= 1.0 + assert args.training_phase in [-1, 0, 1, 2] + + task_name = args.task_name.lower() + if task_name not in PROCESSORS: + raise ValueError("Task not found: %s" % (task_name)) + + processor = PROCESSORS[task_name](args.dataset_config, args.data_dir) + slot_list = processor.slot_list + noncategorical = processor.noncategorical + class_types = processor.class_types + class_labels = len(class_types) + + # Setup CUDA, GPU & distributed training + if args.local_rank == -1 or args.no_cuda: + device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend='nccl') + args.n_gpu = 1 + args.device = device + + # Setup logging + logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt = '%m/%d/%Y %H:%M:%S', + level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) + logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) + + # Set seed + set_seed(args) + + # Load pretrained model and tokenizer + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab + + args.model_type = args.model_type.lower() + config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, local_files_only=args.local_files_only) + + # Add DST specific parameters to config + config.max_seq_length = args.max_seq_length + config.dropout_rate = args.dropout_rate + config.class_loss_ratio = args.class_loss_ratio + config.slot_list = slot_list + config.noncategorical = noncategorical + config.class_types = class_types + config.class_labels = class_labels + config.tag_none_target = args.tag_none_target + config.value_matching_weight = args.value_matching_weight + config.none_weight = args.none_weight + config.proto_loss_function = args.proto_loss_function + config.token_loss_function = args.token_loss_function + config.value_loss_function = args.value_loss_function + config.slot_attention_heads = args.slot_attention_heads + + tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case, local_files_only=args.local_files_only) + model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config, local_files_only=args.local_files_only) + + if args.tag_none_target: + if args.model_type == "roberta": + tokenizer.add_special_tokens({'additional_special_tokens': ['<none>']}) + else: + tokenizer.add_special_tokens({'additional_special_tokens': ['[NONE]']}) + model.resize_token_embeddings(len(tokenizer)) + config.vocab_size = len(tokenizer) + + logger.info("Updated model config: %s" % config) + + if args.local_rank == 0: + torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab + + model.to(args.device) + + logger.info("Training/evaluation parameters %s", args) + + # Training + if args.do_train: + global_step = 0 + proto_checkpoints = [] + checkpoints = [] + if os.path.exists(args.output_dir) and os.listdir(args.output_dir): + proto_checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/proto_checkpoint*/' + WEIGHTS_NAME, recursive=True))) + checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/checkpoint*/' + WEIGHTS_NAME, recursive=True))) + + if args.training_phase in [-1, 0, 1]: + train_dataset = load_and_cache_examples(args, model, tokenizer, processor, dset="train", evaluate=False) + dev_dataset = None + if args.do_train and args.evaluate_during_training: + dev_dataset = load_and_cache_examples(args, model, tokenizer, processor, dset=args.predict_type, evaluate=True) + + # Step 1: Pretrain attention layer for random sequence tagging. + if args.training_phase == 0: + if len(proto_checkpoints) == 0: + global_step, tr_loss = train_proto(args, train_dataset, dev_dataset, model, tokenizer, processor) + logger.info(" global_step = %s, average loss = %s (proto)", global_step, tr_loss) + else: + logger.warning(" Preconditions for proto training not fulfilled! Skipping.") + + # Step 2: Get labels for slot values. + if args.training_phase == 1: + if len(checkpoints) == 0 and len(proto_checkpoints) > 0: + # Load correct proto checkpoint, otherwise last model state is used, which is not desired. + proto_checkpoint = proto_checkpoints[-1] + model = model_class.from_pretrained(proto_checkpoint) + model.to(args.device) + train_dataset.update_model(model) + dev_dataset.update_model(model) + max_tag_goal = 0.0 + max_tag_thresh = 0.0 + max_dae = True # default should be true + if not os.path.exists(os.path.join(args.output_dir, "tagging_threshold.txt")): + for tagging_threshold in [0.2, 0.3, 0.4]: + max_dae = True # default should be true + for dae in [True]: + file_name = os.path.join(args.output_dir, "automatic_labels_%s_%s.pickle" % (tagging_threshold, dae)) + automatic_labels, tag_eval = tag_values(args, train_dataset, model, tokenizer, processor, no_print=(tagging_threshold > 0.0 or dae is True), prefix=global_step, threshold=tagging_threshold, dae=dae) + logger.info("tagging_threshold: %s, dae: %s %s" % (tagging_threshold, dae, tag_eval)) + tag_goal = tag_eval['eval_accuracy_goal'] + pickle.dump(automatic_labels, open(file_name, "wb")) + if tag_goal > max_tag_goal: + max_tag_goal = tag_goal + max_tag_thresh = tagging_threshold + max_dae = dae + with open(os.path.join(args.output_dir, 'tagging_threshold.txt'), 'w') as f: + f.write("%f %d" % (max_tag_thresh, max_dae)) + else: + logger.warning(" Preconditions for tagging not fulfilled! Skipping.") + + # Step 3: Train full model. + if args.training_phase == 2: + if len(checkpoints) == 0 and os.path.exists(os.path.join(args.output_dir, "tagging_threshold.txt")): + with open(os.path.join(args.output_dir, 'tagging_threshold.txt'), 'r') as f: + max_tag_thresh, max_dae = f.readline().split() + max_tag_thresh = float(max_tag_thresh) + max_dae = bool(int(max_dae)) + al_file_name = os.path.join(args.output_dir, "automatic_labels_%s_%s.pickle" % (max_tag_thresh, max_dae)) + logger.info("Loading automatic labels: %s" % (al_file_name)) + automatic_labels = pickle.load(open(al_file_name, "rb")) + train_dataset = load_and_cache_examples(args, model, tokenizer, processor, dset="train", evaluate=False, automatic_labels=automatic_labels) + dev_dataset = None + if args.do_train and args.evaluate_during_training: + dev_dataset = load_and_cache_examples(args, model, tokenizer, processor, dset=args.predict_type, evaluate=True) + train_dataset.compute_vectors() + dev_dataset.compute_vectors() + global_step, tr_loss = train(args, train_dataset, dev_dataset, automatic_labels, model, tokenizer, processor) + logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) + else: + logger.warning(" Preconditions for spanless training not fulfilled! Skipping.") + + # Train full model with original training. + if args.training_phase == -1: + if len(checkpoints) == 0: + train_dataset.compute_vectors() + dev_dataset.compute_vectors() + global_step, tr_loss = train(args, train_dataset, dev_dataset, None, model, tokenizer, processor) + logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) + else: + logger.warning(" Preconditions for training not fulfilled! Skipping.") + + # Save the trained model and the tokenizer + if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): + # Create output directory if needed + if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: + os.makedirs(args.output_dir) + + if args.save_steps == 0 and args.save_epochs == 0 and args.patience < 0: + logger.info("Saving model checkpoint to %s", args.output_dir) + # Save a trained model, configuration and tokenizer using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training + model_to_save.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) + + # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory + results = [] + if args.do_eval and args.local_rank in [-1, 0]: + dataset = load_and_cache_examples(args, model, tokenizer, processor, dset=args.predict_type, evaluate=True) + dataset.compute_vectors() + + output_eval_file = os.path.join(args.output_dir, "eval_res%s.%s.json" % (args.cache_suffix, args.predict_type)) + checkpoints = [args.output_dir] + if args.eval_all_checkpoints: + checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) + logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs + + logger.info("Evaluate the following checkpoints: %s", checkpoints) + + for cItr, checkpoint in enumerate(checkpoints): + # Reload the model + checkpoint_name = checkpoint.split('/')[-1] + model = model_class.from_pretrained(checkpoint) + model.slot_list = slot_list # Necessary for slot independent DST, as slots might differ during evaluation + model.noncategorical = noncategorical + model.to(args.device) + dataset.update_model(model) + + # Evaluate + if "proto" in checkpoint_name: + result = evaluate_proto(args, dataset, model, tokenizer, processor, prefix=checkpoint_name) + else: + result = evaluate(args, dataset, model, tokenizer, processor, prefix=checkpoint_name) + + result_dict = {k: float(v) for k, v in result.items()} + result_dict["checkpoint_name"] = checkpoint_name + results.append(result_dict) + + for key in sorted(result_dict.keys()): + logger.info("%s = %s", key, str(result_dict[key])) + + with open(output_eval_file, "w") as f: + json.dump(results, f, indent=2) + + return results + + +if __name__ == "__main__": + main() diff --git a/tensorlistdataset.py b/tensorlistdataset.py new file mode 100644 index 0000000..238a02e --- /dev/null +++ b/tensorlistdataset.py @@ -0,0 +1,57 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch.utils.data import Dataset + + +class TensorListDataset(Dataset): + r"""Dataset wrapping tensors, tensor dicts and tensor lists. + + Arguments: + *data (Tensor or dict or list of Tensors): tensors that have the same size + of the first dimension. + """ + + def __init__(self, *data): + if isinstance(data[0], dict): + size = list(data[0].values())[0].size(0) + elif isinstance(data[0], list): + size = data[0][0].size(0) + else: + size = data[0].size(0) + for element in data: + if isinstance(element, dict): + assert all(size == tensor.size(0) for name, tensor in element.items()) # dict of tensors + elif isinstance(element, list): + assert all(size == tensor.size(0) for tensor in element) # list of tensors + else: + assert size == element.size(0) # tensor + self.size = size + self.data = data + + def __getitem__(self, index): + result = [] + for element in self.data: + if isinstance(element, dict): + result.append({k: v[index] for k, v in element.items()}) + elif isinstance(element, list): + result.append(v[index] for v in element) + else: + result.append(element[index]) + return tuple(result) + + def __len__(self): + return self.size diff --git a/utils_dst.py b/utils_dst.py new file mode 100644 index 0000000..53a833d --- /dev/null +++ b/utils_dst.py @@ -0,0 +1,1056 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import six +import numpy as np +import os +import re +import pickle +import random +import copy +from scipy.stats import invgauss +from tqdm import tqdm + +import torch +from torch.utils.data import Dataset + +logger = logging.getLogger(__name__) + + +class DSTExample(object): + """ + A single training/test example for the DST dataset. + """ + + def __init__(self, + guid, + text_a, + text_b, + text_a_label=None, + text_b_label=None, + values=None, + inform_label=None, + inform_slot_label=None, + refer_label=None, + diag_state=None, + slot_update=None, + class_label=None): + self.guid = guid + self.text_a = text_a + self.text_b = text_b + self.text_a_label = text_a_label + self.text_b_label = text_b_label + self.values = values + self.inform_label = inform_label + self.inform_slot_label = inform_slot_label + self.refer_label = refer_label + self.diag_state = diag_state + self.slot_update = slot_update + self.class_label = class_label + + def __str__(self): + return self.__repr__() + + def __repr__(self): + s = "" + s += "guid: %s" % (self.guid) + s += ", text_a: %s" % (self.text_a) + s += ", text_b: %s" % (self.text_b) + if self.text_a_label: + s += ", text_a_label: %s" % (self.text_a_label) + if self.text_b_label: + s += ", text_b_label: %s" % (self.text_b_label) + if self.values: + s += ", values: %s" % (self.values) + if self.inform_label: + s += ", inform_label: %s" % (self.inform_label) + if self.inform_slot_label: + s += ", inform_slot_label: %s" % (self.inform_slot_label) + if self.refer_label: + s += ", refer_label: %s" % (self.refer_label) + if self.diag_state: + s += ", diag_state: %s" % (self.diag_state) + if self.slot_update: + s += ", slot_update: %s" % (self.slot_update) + if self.class_label: + s += ", class_label: %s" % (self.class_label) + return s + + +class InputFeatures(object): + """A single set of features of data.""" + + def __init__(self, + input_ids, + input_mask, + segment_ids, + usr_mask, + start_pos=None, + values=None, + inform=None, + inform_slot=None, + refer_id=None, + diag_state=None, + class_label_id=None, + hst_boundaries=None, + guid="NONE"): + self.guid = guid + self.input_ids = input_ids + self.input_mask = input_mask + self.segment_ids = segment_ids + self.usr_mask = usr_mask + self.start_pos = start_pos + self.values = values + self.inform = inform + self.inform_slot = inform_slot + self.refer_id = refer_id + self.diag_state = diag_state + self.class_label_id = class_label_id + self.hst_boundaries = hst_boundaries + + +class TrippyDataset(Dataset): + def __init__(self, args, examples, model, tokenizer, processor, dset="train", evaluate=False, automatic_labels=None): + self.args = args + self.examples = examples + self.automatic_labels = automatic_labels + self.model = model + self.tokenizer = tokenizer + self.slot_list = model.slot_list + self.slot_dict = model.slot_list + self.noncategorical = model.noncategorical + self.class_list = model.class_types + self.class_dict = model.class_types + self.evaluate = evaluate + self.encoded_slots_pooled = None + self.encoded_slots_seq = None + self.encoded_slot_values = None + self.negative_samples = None + self.tokenized_sequences_ids = None + self.dropout_value_seq = None + self.dropout_value_list = None + self.dset = dset + self.mode = "default" # default, proto, tag + + self.label_maps = copy.deepcopy(processor.label_maps) + self.value_list = copy.deepcopy(processor.value_list['train']) + if evaluate: + for s in processor.value_list[dset]: + for v in processor.value_list[dset][s]: + if v not in self.value_list[s]: + self.value_list[s][v] = processor.value_list[dset][s][v] + else: + self.value_list[s][v] += processor.value_list[dset][s][v] + + if examples is None: + logger.warn("Creating empty dataset. You should load or build features before use.") + self.features = None + else: + self.features = self._convert_examples_to_features(examples=examples, + slot_list=self.slot_list, + class_list=self.class_list, + model_type=self.args.model_type, + max_seq_length=self.args.max_seq_length, + automatic_labels=self.automatic_labels) + + def proto(self): + self.mode = "proto" + + def tag(self): + self.mode = "tag" + + def reset(self): + self.mode = "default" + + def update_model(self, model): + self.model = model + self.slot_list = model.slot_list + self.slot_dict = model.slot_list + self.class_list = model.class_types + self.class_dict = model.class_types + + def load_features_from_file(self, cached_file): + logger.info("Loading features from cached file %s", cached_file) + self.features, self.examples = torch.load(cached_file) + self.size = len(self.features) + self._build_dataset() + + def save_features_to_file(self, cached_file): + logger.info("Saving features into cached file %s", cached_file) + torch.save((self.features, self.examples), cached_file) + + def build_features_from_examples(self, examples): + self.examples = examples + self.features = self._convert_examples_to_features(examples=self.examples, + slot_list=self.slot_list, + class_list=self.class_list, + model_type=self.args.model_type, + max_seq_length=self.args.max_seq_length, + automatic_labels=self.automatic_labels) + self.size = len(self.features) + self._build_dataset() + + def _build_dataset(self): + assert self.features is not None + # Convert to Tensors and build dataset + all_input_ids = torch.tensor([f.input_ids for f in self.features], dtype=torch.long) + all_input_mask = torch.tensor([f.input_mask for f in self.features], dtype=torch.long) + all_segment_ids = torch.tensor([f.segment_ids for f in self.features], dtype=torch.long) + all_usr_mask = torch.tensor([f.usr_mask for f in self.features], dtype=torch.long) + all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) + f_start_pos = [f.start_pos for f in self.features] + f_inform_slot_ids = [f.inform_slot for f in self.features] + f_refer_ids = [f.refer_id for f in self.features] + f_diag_state = [f.diag_state for f in self.features] + f_class_label_ids = [f.class_label_id for f in self.features] + all_start_positions = {} + all_inform_slot_ids = {} + all_refer_ids = {} + all_diag_state = {} + all_class_label_ids = {} + for s in self.slot_list: + all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos], dtype=torch.long) + all_inform_slot_ids[s] = torch.tensor([f[s] for f in f_inform_slot_ids], dtype=torch.long) + all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids], dtype=torch.long) + all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state], dtype=torch.long) + all_class_label_ids[s] = torch.tensor([f[s] for f in f_class_label_ids], dtype=torch.long) + data = {'input_ids': all_input_ids, 'input_mask': all_input_mask, + 'segment_ids': all_segment_ids, 'usr_mask': all_usr_mask, + 'start_pos': all_start_positions, + 'inform_slot_id': all_inform_slot_ids, 'refer_id': all_refer_ids, + 'diag_state': all_diag_state if not self.evaluate else {}, + 'class_label_id': all_class_label_ids, 'example_id': all_example_index} + for _, element in data.items(): + if isinstance(element, dict): + assert all(self.size == tensor.size(0) for name, tensor in element.items()) # dict of tensors + elif isinstance(element, list): + assert all(self.size == tensor.size(0) for tensor in element) # list of tensors + else: + assert self.size == element.size(0) # tensor + self.data = data + + def encode_slot_values(self, val_rep_mode="represent", val_rep="full"): + def get_val_desc(val_rep, slot, value): + if val_rep == "full": + text = "%s is %s ." % (self.slot_dict[slot], value) + elif val_rep == "v": + text = value + else: + logger.error("Unknown val_rep (%s). Aborting." % (val_rep)) + exit(1) + return text + + # Separate values by slots, because some slots share values + self.encoded_slot_values = {} + self.encoded_slot_values_variants = {} # For tagging only. Keep separate to not break rest of code + self.encoded_dropout_slot_values = {} # For training with token dropout + for slot in self.slot_dict: + self.encoded_slot_values[slot] = {} + self.encoded_slot_values_variants[slot] = {} + self.encoded_dropout_slot_values[slot] = {} + for value in self.value_list[slot]: + # Encode value variants, if existent. + if value in self.label_maps: + for variant in self.label_maps[value]: + text = get_val_desc(val_rep, slot, variant) + input_ids, input_mask = self._build_input(text) + encoded_slot_value, _ = self._encode_text(text, input_ids.unsqueeze(0), input_mask.unsqueeze(0), mode=val_rep_mode) + if isinstance(encoded_slot_value, dict): + encoded_slot_value = encoded_slot_value[slot] # Keep only slot-specific encoding + encoded_slot_value = encoded_slot_value.cpu() + self.encoded_slot_values_variants[slot][variant] = encoded_slot_value + # Encode regular values. + text = get_val_desc(val_rep, slot, value) + input_ids, input_mask = self._build_input(text) + encoded_slot_value, _ = self._encode_text(text, input_ids.unsqueeze(0), input_mask.unsqueeze(0), mode=val_rep_mode) + if isinstance(encoded_slot_value, dict): + encoded_slot_value = encoded_slot_value[slot] # Keep only slot-specific encoding + encoded_slot_value = encoded_slot_value.cpu() + self.encoded_slot_values[slot][value] = encoded_slot_value + # Encode dropped out values. + if self.dropout_value_list is not None and slot in self.dropout_value_list and value in self.dropout_value_list[slot]: + for dropout_value_seq in self.dropout_value_list[slot][value]: + v_dropped_out = ''.join(self.tokenizer.convert_ids_to_tokens(dropout_value_seq)) + if "\u0120" in v_dropped_out: + v_dropped_out = re.sub("\u0120", " ", v_dropped_out) + v_dropped_out = v_dropped_out.strip() + else: + v_dropped_out = re.sub("(^| )##", "", v_dropped_out) + assert "\u0122" not in value + v_tmp = re.sub(self.tokenizer.unk_token, "\u0122", v_dropped_out) + v_tmp = re.sub(" ", "", v_tmp) + text_dropped_out = get_val_desc(val_rep, slot, v_dropped_out) + input_ids_dropped_out, input_mask_dropped_out = self._build_input(text_dropped_out) + encoded_slot_value_dropped_out, _ = self._encode_text(text, input_ids.unsqueeze(0), input_mask.unsqueeze(0), mode=val_rep_mode) + if isinstance(encoded_slot_value_dropped_out, dict): + encoded_slot_value_dropped_out = encoded_slot_value_dropped_out[slot] # Keep only slot-specific encoding + encoded_slot_value_dropped_out = encoded_slot_value_dropped_out.cpu() + self.encoded_dropout_slot_values[slot][tuple(dropout_value_seq)] = encoded_slot_value_dropped_out + logger.info("Slot values encoded") + + def save_encoded_slot_values(self, dir_name=""): + file_name = "encoded_slot_values_%s.pickle" % self.dset + pickle.dump(self.encoded_slot_values, open(os.path.join(dir_name, file_name), "wb")) + logger.info("Saved encoded slot values to %s" % os.path.join(dir_name, file_name)) + + def load_encoded_slot_values(self, dir_name=""): + result = False + file_name = os.path.join(dir_name, "encoded_slot_values_%s.pickle" % self.dset) + if os.path.exists(file_name): + result = True + self.encoded_slot_values = pickle.load(open(file_name, "rb")) + logger.info("Loaded encoded slot values from %s -> %s" % (file_name, result)) + return result + + def encode_slots(self, train=False): + self.encoded_slots_pooled = {} + self.encoded_slots_seq = {} + self.encoded_slots_ids = {} + for slot in self.slot_dict: + text = slot + " . " + self.slot_dict[slot] + " ." + input_ids, input_mask = self._build_input(text) + encoded_slot_pooled, encoded_slot_seq = self._encode_text(text, + input_ids.unsqueeze(0), input_mask.unsqueeze(0), + mode="encode", train=train) + self.encoded_slots_pooled[slot] = encoded_slot_pooled + self.encoded_slots_seq[slot] = encoded_slot_seq + self.encoded_slots_ids[slot] = (input_ids, input_mask) + logger.info("Slots encoded") + + def save_encoded_slots(self, dir_name=""): + pickle.dump(self.encoded_slots_pooled, open(os.path.join(dir_name, "encoded_slots_pooled.pickle"), "wb")) + pickle.dump(self.encoded_slots_seq, open(os.path.join(dir_name, "encoded_slots_seq.pickle"), "wb")) + pickle.dump(self.encoded_slots_ids, open(os.path.join(dir_name, "encoded_slots_ids.pickle"), "wb")) + logger.info("Saved encoded slots to %s" % dir_name) + + def load_encoded_slots(self, dir_name=""): + result = False + try: + self.encoded_slots_pooled = pickle.load(open(os.path.join(dir_name, "encoded_slots_pooled.pickle"), "rb")) + self.encoded_slots_seq = pickle.load(open(os.path.join(dir_name, "encoded_slots_seq.pickle"), "rb")) + self.encoded_slots_ids = pickle.load(open(os.path.join(dir_name, "encoded_slots_ids.pickle"), "rb")) + result = True + except FileNotFoundError: + logger.warn("Loading encoded slots from %s failed" % dir_name) + logger.info("Loaded encoded slots from %s" % dir_name) + return result + + def compute_vectors(self): + self.model.eval() # No dropout + if not self.load_encoded_slots(self.args.output_dir): + self.encode_slots() + + def distance(self, x, y): + return torch.dist(x, y, p=2) # Euclidean/L2 + + def query_values(self, turn_representation): + def confidence(d, idx): + d_lol = d[:idx] + d[idx + 1:] + return max(1 - (d[idx] / ((sum(d_lol) + 1e-8) / (len(d_lol) + 1e-8))), 0) + + result = {} + for slot in self.slot_list: + result[slot] = [] + for e in turn_representation[slot]: + distances = [] + keys = [] + for v in self.value_list[slot]: + distances.append(self.distance(self.encoded_slot_values[slot][v], e).item()) + keys.append(v) + idx = np.argmin(distances) + conf = confidence(distances, idx) + sorted_dists = sorted(zip(keys, distances), key = lambda t: t[1]) + result[slot].append((keys[idx], "%.4f" % distances[idx], "%.4f" % conf, sorted_dists)) + return result + + def tokenize_sequences(self, max_len=1, train=False): + self.tokenized_sequences_ids = {} + self.tokenized_sequences_list = [] + self.seqs_per_sample = {} + for f in tqdm(self.features, desc="Tokenize sequences"): + seq = f.input_ids[1:f.input_ids.index(self.tokenizer.pad_token_id) if self.tokenizer.pad_token_id in f.input_ids else -1] + example_id = f.guid + self.seqs_per_sample[example_id] = [] + + # Consider full words for max_len, not just tokens. + token_seq = self.tokenizer.convert_ids_to_tokens(seq) + word_list = [] + idx_list = [] + for t_itr, t in enumerate(token_seq): + if ("roberta" in self.args.model_type and t[0] == "\u0120") or t[0:2] != "##" or \ + t in [self.tokenizer.unk_token, self.tokenizer.bos_token, self.tokenizer.eos_token, + self.tokenizer.sep_token, self.tokenizer.pad_token, self.tokenizer.cls_token, + self.tokenizer.mask_token] or t in self.tokenizer.additional_special_tokens: + word_list.append([t]) + idx_list.append([t_itr]) + else: + word_list[-1].append(t) + idx_list[-1].append(t_itr) + + # Keep list of all sequences in each sample. + seq_len = len(word_list) + for start in range(seq_len): + for offset in range(1, 1 + max_len): + if start + offset <= seq_len: + subseq = seq[idx_list[start][0]:idx_list[start + offset - 1][-1] + 1] + if self.tokenizer.sep_token_id in subseq: + continue + if tuple(subseq) not in self.tokenized_sequences_ids: + input_ids, input_mask = self._build_input(subseq, is_token_ids=True) + self.tokenized_sequences_ids[tuple(subseq)] = (input_ids.cpu(), input_mask.cpu()) + self.seqs_per_sample[example_id].append(subseq) + self.tokenized_sequences_list = list(self.tokenized_sequences_ids) + + def save_tokenized_sequences(self, dir_name="", overwrite=True): + file_name = os.path.join(dir_name, "tokenized_sequences_ids_%s.pickle" % self.dset) + if overwrite or not os.path.exists(file_name): + pickle.dump(self.tokenized_sequences_ids, open(file_name, "wb")) + logger.info("Saved tokenized sequences to %s" % dir_name) + file_name = os.path.join(dir_name, "seqs_per_sample_%s.pickle" % self.dset) + if overwrite or not os.path.exists(file_name): + pickle.dump(self.seqs_per_sample, open(file_name, "wb")) + + def load_tokenized_sequences(self, dir_name=""): + result = False + file_name = os.path.join(dir_name, "tokenized_sequences_ids_%s.pickle" % self.dset) + if os.path.exists(file_name): + result = True + self.tokenized_sequences_ids = pickle.load(open(file_name, "rb")) + self.tokenized_sequences_list = list(self.tokenized_sequences_ids) + logger.info("Loaded tokenized sequences from %s" % dir_name) + file_name = os.path.join(dir_name, "seqs_per_sample_%s.pickle" % self.dset) + if os.path.exists(file_name): + self.seqs_per_sample = pickle.load(open(file_name, "rb")) + return result + + def update_samples_for_proto(self, max_len=1): + def list_in_list(a, lst): + for i in range(len(lst) + 1 - len(a)): + if lst[i:i + len(a)] == a: + return True + return False + + if self.tokenized_sequences_ids is None: + logger.warn("Updating negative samples, but values not encoded yet. Encoding now.") + self.tokenize_sequences(max_len=max_len) + result = {} + self.positive_samples_for_proto_pos = {} + self.positive_samples_for_proto_input_ids = {} + self.positive_samples_for_proto_input_mask = {} + self.negative_samples_for_proto_pos = {} + self.negative_samples_for_proto_input_ids = {} + self.negative_samples_for_proto_input_mask = {} + for index in tqdm(range(self.size), desc="Update negative samples for proto training"): + b_index = index % self.args.per_gpu_train_batch_size + if b_index == 0: + offset = min(self.args.per_gpu_train_batch_size, self.size - index + 1) + for key, element in self.data.items(): + if isinstance(element, dict): + result[key] = {k: v[index:index + offset] for k, v in element.items()} + elif isinstance(element, list): + result[key] = [v[index:index + offset] for v in element] + else: + result[key] = element[index:index + offset] + + input_ids = result['input_ids'][b_index].tolist() + input_ids = input_ids[0:input_ids.index(1) if self.tokenizer.pad_token_id in input_ids else -1] + + # Pick a random sequence in the (entire) input as pos example + guid = self.features[result['example_id'][b_index]].guid + seq_list = self.seqs_per_sample[guid] + seq = random.choice(seq_list) + seq_len = len(seq) + seq_tuple = tuple(seq) + + self.positive_samples_for_proto_input_ids[index] = self.tokenized_sequences_ids[seq_tuple][0] + self.positive_samples_for_proto_input_mask[index] = self.tokenized_sequences_ids[seq_tuple][1] + # Find all occurrences in input. + self.positive_samples_for_proto_pos[index] = torch.zeros(self.args.max_seq_length, dtype=torch.long) + for i in range(len(input_ids) + 1 - seq_len): + if input_ids[i:i + seq_len] == seq: + self.positive_samples_for_proto_pos[index][i:i + seq_len] = 1 + + # Pick a random sequence not in any location in the input as neg example + subseq = random.choice(self.tokenized_sequences_list) + while list_in_list(list(subseq), input_ids): + subseq = random.choice(self.tokenized_sequences_list) + self.negative_samples_for_proto_input_ids[index] = self.tokenized_sequences_ids[subseq][0] + self.negative_samples_for_proto_input_mask[index] = self.tokenized_sequences_ids[subseq][1] + self.negative_samples_for_proto_pos[index] = torch.zeros(self.args.max_seq_length, dtype=torch.long) + if self.args.tag_none_target: + self.negative_samples_for_proto_pos[index][1] = 1 + assert len(self.positive_samples_for_proto_pos) == self.size + assert len(self.negative_samples_for_proto_pos) == self.size + + def dropout_input(self): + if self.evaluate or self.mode != "default" or self.args.svd == 0.0: + return + + # Preparation. + self.data['input_ids_dropout'] = copy.deepcopy(self.data['input_ids']) + joint_text_label = sum(list(self.data['start_pos'].values())) + joint_text_label_noncat = torch.zeros(joint_text_label.size()) + joint_text_label_cat = torch.zeros(joint_text_label.size()) + for slot in self.slot_list: + if slot in self.noncategorical: + joint_text_label_noncat += self.data['start_pos'][slot] + else: + joint_text_label_cat += self.data['start_pos'][slot] + rn = np.random.random_sample(joint_text_label.size()) + if self.args.use_td: + assert self.args.td_ratio >= 0.0 and self.args.td_ratio <= 1.0 + top_n = int(self.tokenizer.vocab_size * self.args.td_ratio) + svd_mask = (joint_text_label > 0) * (rn <= self.args.svd) + svd_mask_noncat = (joint_text_label_noncat > 0) * (rn <= self.args.svd) + svd_mask_cat = (joint_text_label_cat > 0) * (rn <= self.args.svd) + + self.dropout_value_seq = {slot: {} for slot in self.slot_list} + self.dropout_value_list = {slot: {} for slot in self.slot_list} + for i, input_ids_dropout in tqdm(enumerate(self.data['input_ids_dropout']), desc="Dropout inputs"): + # Slot value dropout. + if self.args.svd > 0.0: + indices_to_drop_out = (svd_mask[i] == 1).nonzero(as_tuple=True)[0] + indices_to_drop_out_noncat = (svd_mask_noncat[i] == 1).nonzero(as_tuple=True)[0] + indices_to_drop_out_cat = (svd_mask_cat[i] == 1).nonzero(as_tuple=True)[0] + if self.args.use_td: + random_token_id = random.sample(range(top_n), len(indices_to_drop_out_noncat)) + while self.tokenizer.sep_token_id in random_token_id or \ + self.tokenizer.pad_token_id in random_token_id or \ + self.tokenizer.cls_token_id in random_token_id: + random_token_id = random.sample(range(top_n), len(indices_to_drop_out_noncat)) + for k in range(len(indices_to_drop_out_noncat)): + input_ids_dropout[indices_to_drop_out_noncat[k]] = random_token_id[k] + input_ids_dropout[indices_to_drop_out_cat] = self.tokenizer.unk_token_id + else: + input_ids_dropout[indices_to_drop_out] = self.tokenizer.unk_token_id + + # Remember dropped-out values. + example_id = self.data['example_id'][i] + for slot in self.slot_list: + orig_value = self.features[example_id].values[slot] + if orig_value not in self.dropout_value_list[slot]: + self.dropout_value_list[slot][orig_value] = [] + if not self.args.svd_for_all_slots and slot not in self.noncategorical: + continue + value_indices = (self.data['start_pos'][slot][i] > 0).nonzero(as_tuple=True)[0] + prev_si = None + spans = [] + for si in value_indices: + if prev_si is None or si - prev_si > 1: + spans.append([]) + spans[-1].append(si) + prev_si = si + for s_itr in range(len(spans)): + spans[s_itr] = torch.stack(spans[s_itr]) + # In case of length variations, revert dropout (this might however never happen). + # Else, make sure that all mentions of the same value are identically dropped out. + if len(spans) > 1: + is_ambiguous = False + for span in spans[1:]: + if len(span) != len(spans[0]): + is_ambiguous = True + break + if is_ambiguous: + self.data['input_ids_dropout'][i][value_indices] = self.data['input_ids'][i][value_indices] + else: + for span in spans[1:]: + self.data['input_ids_dropout'][i][span] = self.data['input_ids_dropout'][i][spans[0]] + # We only need to check if spans[0] differs from the original seqs, since all s in spans are identical now. + if len(spans) > 0 and not torch.equal(self.data['input_ids'][i][spans[0]], self.data['input_ids_dropout'][i][spans[0]]): + self.dropout_value_seq[slot][i] = self.data['input_ids_dropout'][i][spans[0]].tolist() + if self.dropout_value_seq[slot][i] not in self.dropout_value_list[slot][orig_value]: + self.dropout_value_list[slot][orig_value].append(self.dropout_value_seq[slot][i]) + + def __getitem__(self, index): + result = {} + # Static elements. Copy, because they will be modified below. + for key, element in self.data.items(): + if isinstance(element, dict): + result[key] = {k: v[index].detach().clone() for k, v in element.items()} + elif isinstance(element, list): + result[key] = [v[index].detach().clone() for v in element] + else: + result[key] = element[index].detach().clone() + + # For dropout, simply use pre-processed input_ids. + if not self.evaluate and self.mode == "default" and self.args.svd > 0.0: + result['input_ids'] = result['input_ids_dropout'] + + # Dynamic elements. + result['dropout_value_feat'] = {} + result['value_labels'] = {} + + if self.mode == "proto": + assert self.positive_samples_for_proto_pos is not None + rn = random.random() + if self.args.tag_none_target and rn <= self.args.proto_neg_sample_ratio: + result['start_pos'] = self.negative_samples_for_proto_pos[index] + result['slot_ids'] = self.negative_samples_for_proto_input_ids[index] + result['slot_mask'] = self.negative_samples_for_proto_input_mask[index] + else: + result['start_pos'] = self.positive_samples_for_proto_pos[index] + result['slot_ids'] = self.positive_samples_for_proto_input_ids[index] + result['slot_mask'] = self.positive_samples_for_proto_input_mask[index] + elif self.mode == "tag": + value_reps = {} + for slot in self.slot_list: + value_name = self.features[result['example_id']].values[slot] + if value_name not in self.encoded_slot_values[slot]: + value_rep = torch.zeros((1, self.model.config.hidden_size), dtype=torch.float) + else: + value_rep = self.encoded_slot_values[slot][value_name] + value_reps[slot] = value_rep + result['value_reps'] = value_reps + else: + # History dropout + if not self.evaluate and (self.args.hd > 0.0): + hst_boundaries = self.features[result['example_id']].hst_boundaries + if len(hst_boundaries) > 0: + rn = random.random() + if rn <= self.args.hd: + hst_dropout_idx = random.randint(0, len(hst_boundaries) - 1) + hst_dropout_start = hst_boundaries[hst_dropout_idx][0] + hst_dropout_end = hst_boundaries[-1][1] + result['input_ids'][hst_dropout_start] = result['input_ids'][hst_dropout_end] + result['input_ids'][hst_dropout_start + 1:hst_dropout_end + 1] = self.tokenizer.pad_token_id + result['input_mask'][hst_dropout_start + 1:hst_dropout_end + 1] = 0 + result['segment_ids'][hst_dropout_start + 1:hst_dropout_end + 1] = 0 + result['usr_mask'][hst_dropout_start:hst_dropout_end + 1] = 0 + for slot in self.slot_list: + result['start_pos'][slot][hst_dropout_start + 1:hst_dropout_end + 1] = 0 + # Labels + for slot in self.slot_list: + token_is_pointable = result['start_pos'][slot].sum() > 0 + # If no sequence is present, attention should be on <none> + if self.args.tag_none_target and not token_is_pointable: + result['start_pos'][slot][1] = 1 + pos_value = self.features[index].values[slot] + # For value matching: Only the correct value has a weight, all (!) others automatically become negative samples. + # TODO: Test subsampling negative samples. + + # For attention based value matching + result['value_labels'][slot] = torch.zeros((len(self.encoded_slot_values[slot])), dtype=torch.float) + result['dropout_value_feat'][slot] = torch.zeros((1, self.model.config.hidden_size), dtype=torch.float) + # Only train value matching, if value is extractable + if token_is_pointable and pos_value in self.encoded_slot_values[slot]: + result['value_labels'][slot][list(self.encoded_slot_values[slot].keys()).index(pos_value)] = 1.0 + # In case of dropout, forward new representation as target for value matching instead. + if self.dropout_value_seq is not None: + if result['example_id'].item() in self.dropout_value_seq[slot]: + dropout_value_seq = tuple(self.dropout_value_seq[slot][result['example_id'].item()]) + result['dropout_value_feat'][slot] = self.encoded_dropout_slot_values[slot][dropout_value_seq] + return result + + def _encode_text(self, text, input_ids, input_mask, mode="represent", train=False): + batch = { + "input_ids": input_ids.to(self.args.device), + "input_mask": input_mask.to(self.args.device), + "encoded_slots_pooled": self.encoded_slots_pooled.copy() if self.encoded_slots_pooled is not None else None, + "encoded_slots_seq": self.encoded_slots_seq.copy() if self.encoded_slots_seq is not None else None, + } + if train: + self.model.train() + encoded_text_pooled, encoded_text_seq, weights = self.model(batch, mode=mode) + self.model.eval() + else: + self.model.eval() + with torch.no_grad(): + encoded_text_pooled, encoded_text_seq, weights = self.model(batch, mode=mode) + return encoded_text_pooled, encoded_text_seq + + def __len__(self): + return self.size + + def _build_input(self, text, is_token_ids=False): + if not is_token_ids: + if "roberta" in self.args.model_type: + tokens = self.tokenizer.tokenize(convert_to_unicode(' ' + text)) + else: + tokens = self.tokenizer.tokenize(convert_to_unicode(text)) + input_id = self.tokenizer.convert_tokens_to_ids([self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token]) + else: + input_id = [self.tokenizer.cls_token_id] + text + [self.tokenizer.sep_token_id] + input_mask = [1] * len(input_id) + while len(input_id) < self.args.max_seq_length: + input_id.append(self.tokenizer.pad_token_id) + input_mask.append(0) + assert len(input_id) == self.args.max_seq_length + return torch.tensor(input_id), torch.tensor(input_mask) + + def _convert_examples_to_features(self, examples, slot_list, class_list, model_type, + max_seq_length, automatic_labels=None): + """Loads a data file into a list of `InputBatch`s.""" + + if model_type == 'roberta': + model_specs = {'MODEL_TYPE': 'roberta', + 'TOKEN_CORRECTION': 6} + else: + model_specs = {'MODEL_TYPE': 'bert', + 'TOKEN_CORRECTION': 4} + + def _tokenize_text(text, text_label_dict, model_specs): + token_to_subtoken = [] + tokens = [] + for token in text: + token = convert_to_unicode(token) + if model_specs['MODEL_TYPE'] == 'roberta': + # It seems the behaviour of the tokenizer changed in newer versions, + # which makes this case handling necessary. + if token != self.tokenizer.unk_token: + token = ' ' + token + sub_tokens = self.tokenizer.tokenize(token) # Most time intensive step + token_to_subtoken.append([token, sub_tokens]) + tokens.extend(sub_tokens) + return tokens, token_to_subtoken + + def _label_tokenized_text(tokens, text_label_dict, slot): + token_labels = [] + for element, token_label in zip(tokens, text_label_dict[slot]): + token, sub_tokens = element + token_labels.extend([token_label for _ in sub_tokens]) + return token_labels + + def _truncate_seq_pair(tokens_a, tokens_b, history, max_length): + """Truncates a sequence pair in place to the maximum length. + Copied from bert/run_classifier.py + """ + # This is a simple heuristic which will always truncate the longer sequence + # one token at a time. This makes more sense than truncating an equal percent + # of tokens from each, since if one sequence is very short then each token + # that's truncated likely contains more information than a longer sequence. + while True: + history_len = 0 + for hst in history: + for spk in hst: + history_len += len(spk) + total_length = len(tokens_a) + len(tokens_b) + history_len + if total_length <= max_length: + break + if len(history) > 0: + history.pop() # Remove one entire turn from the history + elif len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() + + def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, model_specs, guid): + # Modifies `tokens_a` and `tokens_b` in place so that the total + # length is less than the specified length. + # Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT) + # Account for <s>, </s></s>, </s></s>, </s> with "- 6" (RoBERTa) + # Account for </s> after each history utterance (all models) + history_len = 0 + for hst in history: + for spk in hst: + history_len += len(spk) + max_len = max_seq_length - model_specs['TOKEN_CORRECTION'] - len(history) * 2 - self.args.tag_none_target * int(model_specs['TOKEN_CORRECTION'] / 2) + if len(tokens_a) + len(tokens_b) + history_len > max_len: + logger.info("Truncate Example %s. Total len=%d." % (guid, len(tokens_a) + len(tokens_b) + history_len)) + input_text_too_long = True + else: + input_text_too_long = False + _truncate_seq_pair(tokens_a, tokens_b, history, max_len) + return input_text_too_long + + def _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs): + token_label_ids = {slot: [] for slot in token_labels_a} + for slot in token_label_ids: + if self.args.tag_none_target: + if model_specs['MODEL_TYPE'] == 'roberta': + labels = [0] + [0, 0, 0] + token_labels_a[slot] + [0] # <s> <none> </s> </s> ... </s> + else: + labels = [0] + [0, 0] + token_labels_a[slot] + [0] # [CLS] [NONE] [SEP] ... [SEP] + else: + labels = [0] + token_labels_a[slot] + [0] # [CLS]/<s> ... [SEP]/</s> + if model_specs['MODEL_TYPE'] == 'roberta': + labels.append(0) # </s> + labels += token_labels_b[slot] + [0] # ... [SEP]/</s> + if model_specs['MODEL_TYPE'] == 'roberta': + labels.append(0) # </s> + token_label_ids[slot] = labels + + for hst in token_labels_history: + (utt_a, utt_b) = hst + for slot in token_label_ids: + token_label_ids[slot] += utt_a[slot] + [0] + utt_b[slot] + [0] # [SEP]/</s> + + for slot in token_label_ids: + if len(token_label_ids[slot]) < max_seq_length: + token_label_ids[slot] += (max_seq_length - len(token_label_ids[slot])) * [0] + + return token_label_ids + + def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, model_specs): + # The convention in BERT is: + # (a) For sequence pairs: + # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + # (b) For single sequences: + # tokens: [CLS] the dog is hairy . [SEP] + # type_ids: 0 0 0 0 0 0 0 + # + # Where "type_ids" are used to indicate whether this is the first + # sequence or the second sequence. The embedding vectors for `type=0` and + # `type=1` were learned during pre-training and are added to the wordpiece + # embedding vector (and position vector). This is not *strictly* necessary + # since the [SEP] token unambiguously separates the sequences, but it makes + # it easier for the model to learn the concept of sequences. + # + # For classification tasks, the first vector (corresponding to [CLS]) is + # used as the "sentence vector". Note that this only makes sense because + # the entire model is fine-tuned. + cls = self.tokenizer.cls_token + sep = self.tokenizer.sep_token + if model_specs['MODEL_TYPE'] == 'roberta': + if self.args.tag_none_target: + tokens = [cls] + ['<none>', sep, sep] + tokens_a + [sep] + [sep] + tokens_b + [sep] + [sep] + segment_ids = [0] + [0, 0, 0] + len(tokens_a) * [0] + 2 * [0] + len(tokens_b) * [0] + 2 * [0] + usr_mask = [0] + [1, 0, 0] + len(tokens_a) * [0 if self.args.swap_utterances else 1] + 2 * [0] + len(tokens_b) * [1 if self.args.swap_utterances else 0] + 2 * [0] + else: + tokens = [cls] + tokens_a + [sep] + [sep] + tokens_b + [sep] + [sep] + segment_ids = [0] + len(tokens_a) * [0] + 2 * [0] + len(tokens_b) * [0] + 2 * [0] + usr_mask = [0] + len(tokens_a) * [0 if self.args.swap_utterances else 1] + 2 * [0] + len(tokens_b) * [1 if self.args.swap_utterances else 0] + 2 * [0] + else: + if self.args.tag_none_target: + tokens = [cls] + ['[NONE]', sep] + tokens_a + [sep] + tokens_b + [sep] + segment_ids = [0] + [0, 0] + len(tokens_a) * [0] + [0] + len(tokens_b) * [1] + [1] + usr_mask = [0] + [1, 0] + len(tokens_a) * [0 if self.args.swap_utterances else 1] + [0] + len(tokens_b) * [1 if self.args.swap_utterances else 0] + [0] + else: + tokens = [cls] + tokens_a + [sep] + tokens_b + [sep] + segment_ids = [0] + len(tokens_a) * [0] + [0] + len(tokens_b) * [1] + [1] + usr_mask = [0] + len(tokens_a) * [0 if self.args.swap_utterances else 1] + [0] + len(tokens_b) * [1 if self.args.swap_utterances else 0] + [0] + hst_boundaries = [] + for hst_itr in range(len(history)): + hst_a, hst_b = history[hst_itr] + hst_start = len(tokens) + tokens += hst_a + [sep] + hst_b + [sep] + hst_end = len(tokens) + hst_boundaries.append([hst_start, hst_end]) + if model_specs['MODEL_TYPE'] == 'roberta': + segment_ids += [0] * (len(hst_a) + 1 + len(hst_b) + 1) + else: + segment_ids += [1] * (len(hst_a) + 1 + len(hst_b) + 1) + usr_mask += len(hst_a) * [0 if self.args.swap_utterances else 1] + [0] + len(hst_b) * [1 if self.args.swap_utterances else 0] + [0] + tokens.append(sep) + if model_specs['MODEL_TYPE'] == 'roberta': + segment_ids.append(0) + else: + segment_ids.append(1) + usr_mask.append(0) + input_ids = self.tokenizer.convert_tokens_to_ids(tokens) + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1] * len(input_ids) + # Zero-pad up to the sequence length. + if len(input_ids) < max_seq_length: + len_diff = max_seq_length - len(input_ids) + input_ids += len_diff * [self.tokenizer.pad_token_id] + input_mask += len_diff * [0] + segment_ids += len_diff * [0] + usr_mask += len_diff * [0] + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + assert len(usr_mask) == max_seq_length + return tokens, input_ids, input_mask, segment_ids, usr_mask, hst_boundaries + + if automatic_labels is not None: + logger.warning("USING AUTOMATIC LABELS TO REPLACE GROUNDTRUTH! BE SURE YOU KNOW WHAT YOU ARE DOING!") + + total_cnt = 0 + too_long_cnt = 0 + + refer_list = list(slot_list.keys()) + ['none'] + + # Tokenize turns + tokens_dict = {} + for (example_index, example) in enumerate(examples): + if example_index % 1000 == 0: + logger.info("Tokenizing turn %d of %d" % (example_index, len(examples))) + re_guid = re.match("(.*)-([0-9]+)", example.guid) + re_guid_diag = re_guid[1] + re_guid_turn = int(re_guid[2]) + + tokens_a, token_to_subtoken_a = _tokenize_text( + example.text_a, example.text_a_label, model_specs) + tokens_b, token_to_subtoken_b = _tokenize_text( + example.text_b, example.text_b_label, model_specs) + + token_labels_a_dict = {} + token_labels_b_dict = {} + for slot in slot_list: + token_labels_a_dict[slot] = _label_tokenized_text(token_to_subtoken_a, example.text_a_label, slot) + token_labels_b_dict[slot] = _label_tokenized_text(token_to_subtoken_b, example.text_b_label, slot) + + # Use automatic labels (if provided) + if automatic_labels is not None: + for slot in slot_list: + # Case where <none> target was used during pre-training/tagging + if self.args.tag_none_target: + if model_specs['MODEL_TYPE'] == 'roberta': + a_start = 4 + else: + a_start = 3 + else: + a_start = 1 + auto_lbl = automatic_labels[slot][example_index] + a_end = a_start + len(token_labels_a_dict[slot]) + token_labels_a_dict[slot] = auto_lbl[a_start:a_end].int().tolist() + if model_specs['MODEL_TYPE'] == 'roberta': + b_start = a_end + 2 + else: + b_start = a_end + 1 + b_end = b_start + len(token_labels_b_dict[slot]) + token_labels_b_dict[slot] = auto_lbl[b_start:b_end].int().tolist() + + tokens_dict[(re_guid_diag, re_guid_turn)] = [[tokens_a, token_labels_a_dict], [tokens_b, token_labels_b_dict]] + + # Build single example + features = [] + for (example_index, example) in enumerate(examples): + if example_index % 1000 == 0: + logger.info("Writing example %d of %d" % (example_index, len(examples))) + + total_cnt += 1 + + # Gather history + re_guid = re.match("(.*)-([0-9]+)", example.guid) + diag_id = re_guid[1] + turn_id = int(re_guid[2]) + tokens_a, token_labels_a_dict = tokens_dict[(diag_id, turn_id)][0] + tokens_b, token_labels_b_dict = tokens_dict[(diag_id, turn_id)][1] + tokens_history = [] + token_labels_history_dict = [] + if not self.args.no_append_history: + for hst_itr in range(turn_id - 1, -1, -1): + tokens_history.append([]) + token_labels_history_dict.append([]) + for spk_itr in range(len(tokens_dict[(diag_id, hst_itr)])): + tokens_h, token_labels_h_dict = tokens_dict[(diag_id, hst_itr)][spk_itr] + tokens_history[-1].append(tokens_h) + token_labels_history_dict[-1].append(token_labels_h_dict) + for slot in slot_list: + if self.args.no_use_history_labels or example.slot_update[slot]: + for h in token_labels_history_dict: + for s in h: + s[slot] = len(s[slot]) * [0] + + input_text_too_long = _truncate_length_and_warn( + tokens_a, tokens_b, tokens_history, max_seq_length, model_specs, example.guid) + + if input_text_too_long: + too_long_cnt += 1 + + tokens, input_ids, input_mask, segment_ids, usr_mask, hst_boundaries = _get_transformer_input(tokens_a, + tokens_b, + tokens_history, + max_seq_length, + model_specs) + + for slot in slot_list: + token_labels_a_dict[slot] = token_labels_a_dict[slot][:len(tokens_a)] + token_labels_b_dict[slot] = token_labels_b_dict[slot][:len(tokens_b)] + token_labels_history_dict = token_labels_history_dict[:len(tokens_history)] + + token_label_ids = _get_token_label_ids(token_labels_a_dict, + token_labels_b_dict, + token_labels_history_dict, + max_seq_length, + model_specs) + + value_dict = {} + inform_dict = {} + inform_slot_dict = {} + refer_id_dict = {} + diag_state_dict = {} + class_label_id_dict = {} + start_pos_dict = {} + for slot in slot_list: + assert len(token_label_ids[slot]) == len(input_ids) + + value_dict[slot] = example.values[slot] + inform_dict[slot] = example.inform_label[slot] + + start_pos_dict[slot] = token_label_ids[slot] + + inform_slot_dict[slot] = example.inform_slot_label[slot] + refer_id_dict[slot] = refer_list.index(example.refer_label[slot]) + diag_state_dict[slot] = class_list.index(example.diag_state[slot]) + class_label_id_dict[slot] = class_list.index(example.class_label[slot]) + + if example_index < 10: + logger.info("*** Example ***") + logger.info("guid: %s" % (example.guid)) + logger.info("tokens: %s" % " ".join(tokens)) + logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) + logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) + logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) + logger.info("usr_mask: %s" % " ".join([str(x) for x in usr_mask])) + logger.info("start_pos: %s" % str(start_pos_dict)) + logger.info("values: %s" % str(value_dict)) + logger.info("inform: %s" % str(inform_dict)) + logger.info("inform_slot: %s" % str(inform_slot_dict)) + logger.info("refer_id: %s" % str(refer_id_dict)) + logger.info("diag_state: %s" % str(diag_state_dict)) + logger.info("class_label_id: %s" % str(class_label_id_dict)) + logger.info("hst_boundaries: %s" % " ".join([str(x) for x in hst_boundaries])) + + features.append( + InputFeatures( + guid=example.guid, + input_ids=input_ids, + input_mask=input_mask, + segment_ids=segment_ids, + usr_mask=usr_mask, + start_pos=start_pos_dict, + values=value_dict, + inform=inform_dict, + inform_slot=inform_slot_dict, + refer_id=refer_id_dict, + diag_state=diag_state_dict, + class_label_id=class_label_id_dict, + hst_boundaries=hst_boundaries)) + + logger.info("========== %d out of %d examples have text too long" % (too_long_cnt, total_cnt)) + + return features + + +# From bert.tokenization (TF code) +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?") diff --git a/utils_run.py b/utils_run.py new file mode 100644 index 0000000..0a185c1 --- /dev/null +++ b/utils_run.py @@ -0,0 +1,146 @@ +# coding=utf-8 +# +# Copyright 2020-2022 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import random + +import numpy as np +import torch + +from utils_dst import (TrippyDataset) + +logger = logging.getLogger(__name__) + + +def print_header(): + logger.info(" _________ ________ ___ ________ ________ ___ ___ ________ ") + logger.info("|\___ ___\\\ __ \|\ \|\ __ \|\ __ \|\ \ / /| |\ __ \ ") + logger.info("\|___ \ \_\ \ \|\ \ \ \ \ \|\ \ \ \|\ \ \ \/ / /______\ \ \|\ \ ") + logger.info(" \ \ \ \ \ _ _\ \ \ \ ____\ \ ____\ \ / /\_______\ \ _ _\ ") + logger.info(" \ \ \ \ \ \\\ \\\ \ \ \ \___|\ \ \___|\/ / /\|_______|\ \ \\\ \| ") + logger.info(" \ \__\ \ \__\\\ _\\\ \__\ \__\ \ \__\ __/ / / \ \__\\\ _\ ") + logger.info(" \|__| \|__|\|__|\|__|\|__| \|__||\___/ / \|__|\|__|") + logger.info(" (c) 2022 Heinrich Heine University \|___|/ ") + logger.info("") + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def to_device(batch, device): + if isinstance(batch, tuple): + batch_on_device = tuple([to_device(element, device) for element in batch]) + if isinstance(batch, dict): + batch_on_device = {k: to_device(v, device) for k, v in batch.items()} + else: + batch_on_device = batch.to(device) if batch is not None else batch + return batch_on_device + + +def from_device(batch): + if isinstance(batch, tuple): + batch_on_cpu = tuple([from_device(element) for element in batch]) + elif isinstance(batch, dict): + batch_on_cpu = {k: from_device(v) for k, v in batch.items()} + else: + batch_on_cpu = batch.cpu() if batch is not None else batch + return batch_on_cpu + + +def save_checkpoint(args, global_step, model, prefix='', keep_only_last_checkpoint=False): + if len(prefix) > 0: + prefix = prefix + '_' + if keep_only_last_checkpoint: + output_dir = os.path.join(args.output_dir, prefix + 'checkpoint') + else: + output_dir = os.path.join(args.output_dir, prefix + 'checkpoint-{}'.format(global_step)) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training + model_to_save.save_pretrained(output_dir) + torch.save(args, os.path.join(output_dir, prefix + 'training_args.bin')) + logger.info("Saving model checkpoint after global step %d to %s" % (global_step, output_dir)) + with open(os.path.join(args.output_dir, 'last_' + prefix + 'checkpoint.txt'), 'w') as f: + f.write('{}checkpoint-{}'.format(prefix, global_step)) + + +def load_and_cache_examples(args, model, tokenizer, processor, dset="train", evaluate=False, automatic_labels=None): + assert dset in ["train", "dev", "test"] + + if args.local_rank not in [-1, 0] and not evaluate: + torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache + + dataset = TrippyDataset(args, examples=None, model=model, tokenizer=tokenizer, processor=processor, dset=dset, evaluate=evaluate, + automatic_labels=automatic_labels) + + # Load data features from cache or dataset file + cached_file = os.path.join(os.path.dirname(args.output_dir), 'cached_{}_features{}'.format( + dset, args.cache_suffix)) + if os.path.exists(cached_file) and not args.overwrite_cache and not args.no_cache: + dataset.load_features_from_file(cached_file) + else: + logger.info("Creating features from dataset file at %s", args.data_dir) + processor_args = {'no_label_value_repetitions': args.no_label_value_repetitions, + 'swap_utterances': args.swap_utterances, + 'delexicalize_sys_utts': args.delexicalize_sys_utts, + 'unk_token': tokenizer.unk_token} + if dset == "dev": + examples = processor.get_dev_examples(processor_args) + elif dset == "test": + examples = processor.get_test_examples(processor_args) + elif dset == "train": + examples = processor.get_train_examples(processor_args) + else: + logger.warning("Unknown dataset \"%s\". Aborting" % (dset)) + dataset.build_features_from_examples(examples) + if not args.no_cache: + if args.local_rank in [-1, 0]: + dataset.save_features_to_file(cached_file) + + if args.local_rank == 0 and not evaluate: + torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache + + return dataset + + +def dilate_and_erode(weights, threshold): + def dilate(seq): + result = np.clip(np.convolve(seq, [1, 1, 1], mode='same'), 0, 1) + return result + + def erode(seq): + result = (~np.clip(np.convolve(~seq.astype(bool), [1, 1, 1], mode='same'), 0, 1).astype(bool)).astype(float) + return result + + result = [] + for seq in weights: + d = dilate(seq) + dt = d > threshold + e = erode(dt) + result.append(torch.tensor(e)) + result = torch.stack(result) + return result -- GitLab