Skip to content
Snippets Groups Projects
Commit 92501f68 authored by Michael Heck's avatar Michael Heck
Browse files

initial commit

parent ef736409
No related branches found
No related tags found
No related merge requests found
Showing with 7665 additions and 4 deletions
# Store binaries in LFS
## Custom paths
# NOTE: a trailing-slash pattern ("results/") matches only the directory
# entry itself and is not applied to the files inside it, so the attributes
# would never take effect. "dir/**" matches everything below the directory.
results/** filter=lfs diff=lfs merge=lfs -text
data/** filter=lfs diff=lfs merge=lfs -text
## Archive/Compressed
*.7z filter=lfs diff=lfs merge=lfs -text
*.cpio filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.iso filter=lfs diff=lfs merge=lfs -text
*.bz filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.bzip filter=lfs diff=lfs merge=lfs -text
*.bzip2 filter=lfs diff=lfs merge=lfs -text
*.cab filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.gzip filter=lfs diff=lfs merge=lfs -text
*.lz filter=lfs diff=lfs merge=lfs -text
*.lzma filter=lfs diff=lfs merge=lfs -text
*.lzo filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.z filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.ace filter=lfs diff=lfs merge=lfs -text
*.dmg filter=lfs diff=lfs merge=lfs -text
*.dd filter=lfs diff=lfs merge=lfs -text
*.apk filter=lfs diff=lfs merge=lfs -text
*.ear filter=lfs diff=lfs merge=lfs -text
*.jar filter=lfs diff=lfs merge=lfs -text
*.deb filter=lfs diff=lfs merge=lfs -text
*.cue filter=lfs diff=lfs merge=lfs -text
*.dump filter=lfs diff=lfs merge=lfs -text
## Image
*.jpg filter=lfs diff=lfs merge=lfs -text
*.jpeg filter=lfs diff=lfs merge=lfs -text
*.gif filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.psd filter=lfs diff=lfs merge=lfs -text
*.bmp filter=lfs diff=lfs merge=lfs -text
*.dng filter=lfs diff=lfs merge=lfs -text
*.cdr filter=lfs diff=lfs merge=lfs -text
*.indd filter=lfs diff=lfs merge=lfs -text
*.tiff filter=lfs diff=lfs merge=lfs -text
*.tif filter=lfs diff=lfs merge=lfs -text
*.psp filter=lfs diff=lfs merge=lfs -text
*.tga filter=lfs diff=lfs merge=lfs -text
*.eps filter=lfs diff=lfs merge=lfs -text
*.svg filter=lfs diff=lfs merge=lfs -text
## Documents
*.pdf filter=lfs diff=lfs merge=lfs -text
*.doc filter=lfs diff=lfs merge=lfs -text
*.docx filter=lfs diff=lfs merge=lfs -text
*.xls filter=lfs diff=lfs merge=lfs -text
*.xlsx filter=lfs diff=lfs merge=lfs -text
*.ppt filter=lfs diff=lfs merge=lfs -text
*.pptx filter=lfs diff=lfs merge=lfs -text
*.ppz filter=lfs diff=lfs merge=lfs -text
*.dot filter=lfs diff=lfs merge=lfs -text
*.dotx filter=lfs diff=lfs merge=lfs -text
*.lwp filter=lfs diff=lfs merge=lfs -text
*.odm filter=lfs diff=lfs merge=lfs -text
*.odt filter=lfs diff=lfs merge=lfs -text
*.ott filter=lfs diff=lfs merge=lfs -text
*.ods filter=lfs diff=lfs merge=lfs -text
*.ots filter=lfs diff=lfs merge=lfs -text
*.odp filter=lfs diff=lfs merge=lfs -text
*.otp filter=lfs diff=lfs merge=lfs -text
*.odg filter=lfs diff=lfs merge=lfs -text
*.otg filter=lfs diff=lfs merge=lfs -text
*.wps filter=lfs diff=lfs merge=lfs -text
*.wpd filter=lfs diff=lfs merge=lfs -text
*.wpt filter=lfs diff=lfs merge=lfs -text
*.xps filter=lfs diff=lfs merge=lfs -text
*.ttf filter=lfs diff=lfs merge=lfs -text
*.otf filter=lfs diff=lfs merge=lfs -text
*.dvi filter=lfs diff=lfs merge=lfs -text
*.pages filter=lfs diff=lfs merge=lfs -text
*.key filter=lfs diff=lfs merge=lfs -text
## Audio/Video
*.mpg filter=lfs diff=lfs merge=lfs -text
*.mpeg filter=lfs diff=lfs merge=lfs -text
*.mp3 filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.avi filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
*.mkv filter=lfs diff=lfs merge=lfs -text
*.3gp filter=lfs diff=lfs merge=lfs -text
*.flv filter=lfs diff=lfs merge=lfs -text
*.m4v filter=lfs diff=lfs merge=lfs -text
*.ogg filter=lfs diff=lfs merge=lfs -text
*.mov filter=lfs diff=lfs merge=lfs -text
*.wmv filter=lfs diff=lfs merge=lfs -text
*.webm filter=lfs diff=lfs merge=lfs -text
## VM
*.vfd filter=lfs diff=lfs merge=lfs -text
*.vhd filter=lfs diff=lfs merge=lfs -text
*.vmdk filter=lfs diff=lfs merge=lfs -text
*.vmsd filter=lfs diff=lfs merge=lfs -text
*.vmsn filter=lfs diff=lfs merge=lfs -text
*.vmss filter=lfs diff=lfs merge=lfs -text
*.dsk filter=lfs diff=lfs merge=lfs -text
*.vdi filter=lfs diff=lfs merge=lfs -text
*.cow filter=lfs diff=lfs merge=lfs -text
*.qcow filter=lfs diff=lfs merge=lfs -text
*.qcow2 filter=lfs diff=lfs merge=lfs -text
*.qed filter=lfs diff=lfs merge=lfs -text
## Other
*.exe filter=lfs diff=lfs merge=lfs -text
*.sxi filter=lfs diff=lfs merge=lfs -text
*.dat filter=lfs diff=lfs merge=lfs -text
*.data filter=lfs diff=lfs merge=lfs -text
#!/bin/bash
# Driver script: for every seed, runs run_dst.py for the steps "train", "dev"
# and "test", then scores the dev/test predictions with metric_dst.py.
# All output of each call is mirrored into a log file in the seed's output
# directory.

# Parameters ------------------------------------------------------

# --- Sim-M dataset
#TASK="sim-m"
#DATA_DIR="data/simulated-dialogue/sim-M"
#DATASET_CONFIG="dataset_config/sim-m.json"
# --- Sim-R dataset
#TASK="sim-r"
#DATA_DIR="data/simulated-dialogue/sim-R"
#DATASET_CONFIG="dataset_config/sim-r.json"
# --- WOZ 2.0 dataset
#TASK="woz2"
#DATA_DIR="data/woz2"
#DATASET_CONFIG="dataset_config/woz2.json"
# --- MultiWOZ 2.1 legacy version dataset
#TASK="multiwoz21_legacy"
#DATA_DIR="data/MULTIWOZ2.1"
#DATASET_CONFIG="dataset_config/multiwoz21.json"
# --- MultiWOZ 2.1 dataset
TASK="multiwoz21"
DATA_DIR="data/multiwoz/data/MultiWOZ_2.1"
DATASET_CONFIG="dataset_config/multiwoz21.json"
# --- MultiWOZ 2.1 in ConvLab3's unified data format
#TASK="unified"
#DATA_DIR=""
#DATASET_CONFIG="dataset_config/unified_multiwoz21.json"

SEEDS="42"
TRAIN_PHASES="-1" # -1: regular training, 0: proto training, 1: tagging, 2: spanless training
VALUE_MATCHING_WEIGHT=0.1 # When 0.0, value matching is not used

# Helpers ---------------------------------------------------------

# Succeeds (exit status 0) when $1 is numerically greater than zero.
# Bash's [[ a > b ]] compares strings lexicographically, which misjudges
# values such as "0.00", so the comparison is delegated to awk.
is_positive() {
    awk -v v="$1" 'BEGIN { exit !(v + 0 > 0) }'
}

# Project paths etc. ----------------------------------------------

OUT_DIR=results
# SEEDS is intentionally unquoted: it is a whitespace-separated list.
for x in ${SEEDS}; do
    mkdir -p "${OUT_DIR}.${x}"
done

# Main ------------------------------------------------------------

for x in ${SEEDS}; do
    for step in train dev test; do
        args_add=""
        phases="-1"
        if [ "$step" = "train" ]; then
            args_add="--do_train --predict_type=dev --svd=0.1 --hd=0.1"
            phases=${TRAIN_PHASES}
        elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
            args_add="--do_eval --predict_type=${step}"
        fi

        for phase in ${phases}; do
            # Per-phase extra arguments; all empty in this script, kept as
            # hooks for phase-specific options.
            args_add_0=""
            if [ "$phase" = 0 ]; then
                args_add_0=""
            fi
            args_add_1=""
            if [ "$phase" = 1 ]; then
                args_add_1=""
            fi
            args_add_2=""
            if [ "$phase" = 2 ]; then
                args_add_2=""
            fi

            # args_add* stay unquoted on purpose: they hold several
            # space-separated options (or nothing) and must be word-split.
            python3 run_dst.py \
                --task_name="${TASK}" \
                --data_dir="${DATA_DIR}" \
                --dataset_config="${DATASET_CONFIG}" \
                --model_type="roberta" \
                --model_name_or_path="roberta-base" \
                --seed="${x}" \
                --do_lower_case \
                --learning_rate=5e-5 \
                --num_train_epochs=20 \
                --max_seq_length=180 \
                --per_gpu_train_batch_size=32 \
                --per_gpu_eval_batch_size=32 \
                --output_dir="${OUT_DIR}.${x}" \
                --patience=10 \
                --evaluate_during_training \
                --eval_all_checkpoints \
                --warmup_proportion=0.05 \
                --adam_epsilon=1e-6 \
                --weight_decay=0.01 \
                --fp16 \
                --value_matching_weight="${VALUE_MATCHING_WEIGHT}" \
                --none_weight=0.1 \
                --use_td \
                --td_ratio=0.2 \
                --training_phase="${phase}" \
                ${args_add} \
                ${args_add_0} \
                ${args_add_1} \
                ${args_add_2} \
                2>&1 | tee "${OUT_DIR}.${x}/${step}.${phase}.log"
        done

        if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
            # Sweep several confidence thresholds only when value matching
            # is enabled; otherwise evaluate once with threshold 1.0.
            confidence=1.0
            if is_positive "${VALUE_MATCHING_WEIGHT}"; then
                confidence="1.0 0.9 0.8 0.7 0.6 0.5"
            fi
            for dist_conf_threshold in ${confidence}; do
                # The glob in --file_list is quoted so that it is passed
                # unexpanded to metric_dst.py.
                python3 metric_dst.py \
                    --dataset_config="${DATASET_CONFIG}" \
                    --confidence_threshold="${dist_conf_threshold}" \
                    --file_list="${OUT_DIR}.${x}/pred_res.${step}*json" \
                    2>&1 | tee "${OUT_DIR}.${x}/eval_pred_${step}.${dist_conf_threshold}.log"
            done
        fi
    done
done
#!/bin/bash
# Driver script with phase-specific settings: for every seed, runs run_dst.py
# for the steps "train", "dev" and "test", then scores the dev/test
# predictions with metric_dst.py. Training phases 0, 1 and 2 each receive
# their own extra options (see the per-phase section below).
# All output of each call is mirrored into a log file in the seed's output
# directory.

# Parameters ------------------------------------------------------

# --- Sim-M dataset
#TASK="sim-m"
#DATA_DIR="data/simulated-dialogue/sim-M"
#DATASET_CONFIG="dataset_config/sim-m.json"
# --- Sim-R dataset
#TASK="sim-r"
#DATA_DIR="data/simulated-dialogue/sim-R"
#DATASET_CONFIG="dataset_config/sim-r.json"
# --- WOZ 2.0 dataset
#TASK="woz2"
#DATA_DIR="data/woz2"
#DATASET_CONFIG="dataset_config/woz2.json"
# --- MultiWOZ 2.1 legacy version dataset
#TASK="multiwoz21_legacy"
#DATA_DIR="data/MULTIWOZ2.1"
#DATASET_CONFIG="dataset_config/multiwoz21.json"
# --- MultiWOZ 2.1 dataset
TASK="multiwoz21"
DATA_DIR="data/multiwoz/data/MultiWOZ_2.1"
DATASET_CONFIG="dataset_config/multiwoz21.json"
# --- MultiWOZ 2.1 in ConvLab3's unified data format
#TASK="unified"
#DATA_DIR=""
#DATASET_CONFIG="dataset_config/unified_multiwoz21.json"

SEEDS="42"
TRAIN_PHASES="-1" # -1: regular training, 0: proto training, 1: tagging, 2: spanless training
VALUE_MATCHING_WEIGHT=0.1 # When 0.0, value matching is not used

# Helpers ---------------------------------------------------------

# Succeeds (exit status 0) when $1 is numerically greater than zero.
# Bash's [[ a > b ]] compares strings lexicographically, which misjudges
# values such as "0.00", so the comparison is delegated to awk.
is_positive() {
    awk -v v="$1" 'BEGIN { exit !(v + 0 > 0) }'
}

# Project paths etc. ----------------------------------------------

OUT_DIR=results
# SEEDS is intentionally unquoted: it is a whitespace-separated list.
for x in ${SEEDS}; do
    mkdir -p "${OUT_DIR}.${x}"
done

# Main ------------------------------------------------------------

for x in ${SEEDS}; do
    for step in train dev test; do
        args_add=""
        phases="-1"
        if [ "$step" = "train" ]; then
            args_add="--do_train --predict_type=dev --svd=0.1 --hd=0.1"
            phases=${TRAIN_PHASES}
        elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
            args_add="--do_eval --predict_type=${step} --no_cache"
        fi

        for phase in ${phases}; do
            # Defaults, overridden below for phase 0.
            ep=20
            warmup=0.05

            # Phase 0: more epochs, longer warm-up, no history appended.
            args_add_0=""
            if [ "$phase" = 0 ]; then
                ep=50
                warmup=0.1
                args_add_0="--no_append_history"
            fi
            # Phase 1: no history appended.
            args_add_1=""
            if [ "$phase" = 1 ]; then
                args_add_1="--no_append_history"
            fi
            # Phase 2: use a seed-specific cache suffix.
            args_add_2=""
            if [ "$phase" = 2 ]; then
                args_add_2="--cache_suffix=_auto_${x}"
            fi

            # args_add* stay unquoted on purpose: they hold several
            # space-separated options (or nothing) and must be word-split.
            python3 run_dst.py \
                --task_name="${TASK}" \
                --data_dir="${DATA_DIR}" \
                --dataset_config="${DATASET_CONFIG}" \
                --model_type="roberta" \
                --model_name_or_path="roberta-base" \
                --seed="${x}" \
                --do_lower_case \
                --learning_rate=5e-5 \
                --num_train_epochs="${ep}" \
                --max_seq_length=180 \
                --per_gpu_train_batch_size=32 \
                --per_gpu_eval_batch_size=32 \
                --output_dir="${OUT_DIR}.${x}" \
                --patience=10 \
                --evaluate_during_training \
                --eval_all_checkpoints \
                --warmup_proportion="${warmup}" \
                --adam_epsilon=1e-6 \
                --weight_decay=0.01 \
                --fp16 \
                --value_matching_weight="${VALUE_MATCHING_WEIGHT}" \
                --none_weight=0.1 \
                --tag_none_target \
                --use_td \
                --td_ratio=0.2 \
                --training_phase="${phase}" \
                ${args_add} \
                ${args_add_0} \
                ${args_add_1} \
                ${args_add_2} \
                2>&1 | tee "${OUT_DIR}.${x}/${step}.${phase}.log"
        done

        if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
            # Sweep several confidence thresholds only when value matching
            # is enabled; otherwise evaluate once with threshold 1.0.
            confidence=1.0
            if is_positive "${VALUE_MATCHING_WEIGHT}"; then
                confidence="1.0 0.9 0.8 0.7 0.6 0.5"
            fi
            for dist_conf_threshold in ${confidence}; do
                # The glob in --file_list is quoted so that it is passed
                # unexpanded to metric_dst.py.
                python3 metric_dst.py \
                    --dataset_config="${DATASET_CONFIG}" \
                    --confidence_threshold="${dist_conf_threshold}" \
                    --file_list="${OUT_DIR}.${x}/pred_res.${step}*json" \
                    2>&1 | tee "${OUT_DIR}.${x}/eval_pred_${step}.${dist_conf_threshold}.log"
            done
        fi
    done
done
LICENSE 0 → 100644
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2020 Heinrich Heine University Duesseldorf
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# TripPy-R - Public
## Introduction
This is the code repository for our journal paper titled "Robust Dialogue State Tracking with Weak Supervision and Sparse Data", which is accepted for publication in TACL.
The pre-print is available at [arXiv](https://doi.org/10.48550/arXiv.2202.03354).
Generalising dialogue state tracking (DST) to new data is especially challenging due to the strong reliance on abundant and fine-grained supervision during training. Sample sparsity, distributional shift and the occurrence of new concepts and topics frequently lead to severe performance degradation during inference. TripPy-R (pronounced "Trippier"), robust triple copy strategy DST, can use a training strategy to build extractive DST models without the need for fine-grained manual span labels ("spanless training"). Further, two novel input-level dropout methods mitigate the negative impact of sample sparsity. TripPy-R uses a new model architecture with a unified encoder that supports value as well as slot independence by leveraging the attention mechanism, making it zero-shot capable. The framework combines the strengths of triple copy strategy DST and value matching to benefit from complementary predictions without violating the principle of ontology independence. In our paper we demonstrate that an extractive DST model can be trained without manual span labels. Our architecture and training strategies improve robustness towards sample sparsity, new concepts and topics, leading to state-of-the-art performance on a range of benchmarks.
We will publish the code here after our journal paper is published in TACL.
## Recent updates
- 2023.08.08: Initial commit
## How to run
Two example scripts are provided for how to use TripPy-R.
`DO.example` will train and evaluate a model with recommended settings with the default supervised training strategy.
`DO.example.spanless` will train and evaluate a model with recommended settings with the novel spanless training strategy. The training consists of three steps: 1) Training a proto-DST that learns to tag the positions of queried subsequences in an input sequence. 2) Applying the proto-DST to tag the positions of slot-value occurrences in the training data. 3) Training the DST using the automatic labels produced by the previous step.
See the table below for the expected performance per dataset and training strategy. Our scripts use the parameters that were used for the experiments in our paper "Robust Dialogue State Tracking with Weak Supervision and Sparse Data". Thus, performance should be similar to the reported results. For more challenging datasets with longer dialogues, better performance may be achieved by using the maximum sequence length of 512.
## Trouble-shooting
When conducting spanless training, the training of the proto-DST (Step 1 of 3, see above) is rather sensitive to the training hyperparameters such as learning rate, warm-up ratio and max. number of epochs, as well as the random model initialization. We recommend the hyperparameters as listed in the example script above. If the proto-DST's tagging performance (Step 2 of 3) remains below expectations for one or more slots, try running the training with a different random initialization, i.e. pick a different random seed, while using the recommended hyperparameters.
## Datasets
Supported datasets are:
- sim-M (https://github.com/google-research-datasets/simulated-dialogue.git)
- sim-R (https://github.com/google-research-datasets/simulated-dialogue.git)
- WOZ 2.0 (see https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public.git)
- MultiWOZ 2.0 (https://github.com/budzianowski/multiwoz.git)
- MultiWOZ 2.1 (https://github.com/budzianowski/multiwoz.git)
- MultiWOZ 2.1 legacy version (see https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public.git)
- MultiWOZ 2.2 (https://github.com/budzianowski/multiwoz.git)
- MultiWOZ 2.3 (https://github.com/lexmen318/MultiWOZ-coref.git)
- MultiWOZ 2.4 (https://github.com/smartyfh/MultiWOZ2.4.git)
- Unified data format (currently supported: MultiWOZ 2.1) (see https://github.com/ConvLab/ConvLab-3)
See the [README file](https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public/-/blob/master/data/README.md) in 'data/' in the original [TripPy repo](https://gitlab.cs.uni-duesseldorf.de/general/dsml/trippy-public) for more details on how to obtain and prepare the datasets for use in TripPy-R.
The ```--task_name``` is
- 'sim-m', for sim-M
- 'sim-r', for sim-R
- 'woz2', for WOZ 2.0
- 'multiwoz21', for MultiWOZ 2.0-2.4
- 'multiwoz21_legacy', for MultiWOZ 2.1 legacy version
- 'unified', for ConvLab-3's unified data format
With a sequence length of 180, you should expect the following average JGA:
| Dataset | Normal training | Spanless training |
| ------ | ------ | ------ |
| MultiWOZ 2.0 | 51% | tbd |
| MultiWOZ 2.1 | 56% | 55% |
| MultiWOZ 2.1 legacy | 56% | 55% |
| MultiWOZ 2.2 | 56% | tbd |
| MultiWOZ 2.3 | 62% | tbd |
| MultiWOZ 2.4 | 69% | tbd |
| sim-M | 95% | 95% |
| sim-R | 92% | 92% |
| WOZ 2.0 | 92% | 91% |
## Requirements
- torch (tested: 1.12.1)
- transformers (tested: 4.18.0)
- tensorboardX (tested: 2.5.1)
## Citation
This work is published as [Robust Dialogue State Tracking with Weak Supervision and Sparse Data](https://doi.org/10.1162/tacl_a_00513).
If you use TripPy-R in your own work, please cite our work as follows:
```
@article{heck-etal-2022-robust,
title = "Robust Dialogue State Tracking with Weak Supervision and Sparse Data",
author = "Heck, Michael and Lubis, Nurul and van Niekerk, Carel and
Feng, Shutong and Geishauser, Christian and Lin, Hsien-Chin and Ga{\v{s}}i{\'c}, Milica",
journal = "Transactions of the Association for Computational Linguistics",
volume = "10",
year = "2022",
address = "Cambridge, MA",
publisher = "MIT Press",
url = "https://aclanthology.org/2022.tacl-1.68",
doi = "10.1162/tacl_a_00513",
pages = "1175--1192",
}
```
# coding=utf-8
#
# Copyright 2020-2022 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import dataset_woz2
import dataset_sim
import dataset_multiwoz21
import dataset_multiwoz21_legacy
import dataset_unified
class DataProcessor(object):
    """Base class for dataset processors.

    Reads a JSON dataset configuration and keeps, per data split, a
    value list (``value_list[split][slot]`` is a dict keyed by value).
    Subclasses provide the dataset-specific example loaders.
    """

    # Class-level defaults. Every instance replaces them in __init__, so
    # they only serve as documentation of the available attributes.
    data_dir = ""
    dataset_name = ""
    class_types = []
    slot_list = {}
    noncategorical = []
    boolean = []
    label_maps = {}
    value_list = {'train': {}, 'dev': {}, 'test': {}}

    def __init__(self, dataset_config, data_dir):
        """Load the dataset configuration.

        Args:
            dataset_config: Path to a JSON file. The key 'class_types' is
                required; 'dataset_name', 'slots', 'noncategorical',
                'boolean' and 'label_maps' are optional.
            data_dir: Directory containing the dataset files.

        Raises:
            KeyError: If 'class_types' is missing from the configuration.
            NotImplementedError: If no slots are configured and the
                subclass does not implement ``_get_slot_list``.
        """
        self.data_dir = data_dir
        # Give each instance its own container. Without this, subclasses
        # writing self.value_list['train'] = ... would mutate the dict
        # stored on the class and thereby leak values between instances.
        self.value_list = {'train': {}, 'dev': {}, 'test': {}}

        # Load dataset config file.
        with open(dataset_config, "r", encoding='utf-8') as f:
            raw_config = json.load(f)
        self.dataset_name = raw_config.get('dataset_name', "")
        self.class_types = raw_config['class_types']  # Required
        self.slot_list = raw_config.get('slots', {})
        self.noncategorical = raw_config.get('noncategorical', [])
        self.boolean = raw_config.get('boolean', [])
        self.label_maps = raw_config.get('label_maps', {})

        # If no slot list is provided, generate from data.
        if not self.slot_list:
            self.slot_list = self._get_slot_list()

    def _add_dummy_value_to_value_list(self):
        """Replace every empty per-slot value dict with ``{'dummy': 1}``."""
        for dset in self.value_list:
            for s in self.value_list[dset]:
                if len(self.value_list[dset][s]) == 0:
                    self.value_list[dset][s] = {'dummy': 1}

    def _remove_dummy_value_from_value_list(self):
        """Undo ``_add_dummy_value_to_value_list``.

        Only slots whose value dict is exactly ``{'dummy': 1}`` are reset
        to an empty dict.
        """
        for dset in self.value_list:
            for s in self.value_list[dset]:
                if self.value_list[dset][s] == {'dummy': 1}:
                    self.value_list[dset][s] = {}

    def _merge_with_train_value_list(self, new_value_list):
        """Merge ``new_value_list`` into the 'train' value list.

        Args:
            new_value_list: Dict mapping slot -> dict mapping value -> number
                (presumably an occurrence count; numbers of values known to
                both sides are added up).
        """
        # Dummy entries must not take part in the merge; they are restored
        # afterwards for slots that are still empty.
        self._remove_dummy_value_from_value_list()
        train_values = self.value_list['train']
        for s in new_value_list:
            if s not in train_values:
                train_values[s] = new_value_list[s]
            else:
                for v in new_value_list[s]:
                    if v not in train_values[s]:
                        train_values[s][v] = new_value_list[s][v]
                    else:
                        train_values[s][v] += new_value_list[s][v]
        self._add_dummy_value_to_value_list()

    def _get_slot_list(self):
        """Derive the slot list from data; to be implemented by subclasses."""
        raise NotImplementedError()

    def prediction_normalization(self, slot, value):
        """Return ``value`` unchanged; subclasses may override."""
        return value

    # The loaders accept an optional ``args`` so that the base signature
    # agrees with the subclasses, which all take ``args``.
    def get_train_examples(self, args=None):
        """Return the training examples; to be implemented by subclasses."""
        raise NotImplementedError()

    def get_dev_examples(self, args=None):
        """Return the development examples; to be implemented by subclasses."""
        raise NotImplementedError()

    def get_test_examples(self, args=None):
        """Return the test examples; to be implemented by subclasses."""
        raise NotImplementedError()
class Woz2Processor(DataProcessor):
    """Processor that reads the WOZ 2.0 JSON files via ``dataset_woz2``."""

    def __init__(self, dataset_config, data_dir):
        """Load the configuration, then the value list of every split."""
        super().__init__(dataset_config, data_dir)
        split_files = (
            ('train', 'woz_train_en.json'),
            ('dev', 'woz_validate_en.json'),
            ('test', 'woz_test_en.json'),
        )
        for split, file_name in split_files:
            path = os.path.join(self.data_dir, file_name)
            self.value_list[split] = dataset_woz2.get_value_list(path, self.slot_list)

    def get_train_examples(self, args):
        """Create the training examples; ``args`` is expanded as keyword arguments."""
        path = os.path.join(self.data_dir, 'woz_train_en.json')
        return dataset_woz2.create_examples(path, 'train', self.slot_list, self.label_maps, **args)

    def get_dev_examples(self, args):
        """Create the development examples; ``args`` is expanded as keyword arguments."""
        path = os.path.join(self.data_dir, 'woz_validate_en.json')
        return dataset_woz2.create_examples(path, 'dev', self.slot_list, self.label_maps, **args)

    def get_test_examples(self, args):
        """Create the test examples; ``args`` is expanded as keyword arguments."""
        path = os.path.join(self.data_dir, 'woz_test_en.json')
        return dataset_woz2.create_examples(path, 'test', self.slot_list, self.label_maps, **args)
class Multiwoz21Processor(DataProcessor):
    """Processor that reads MultiWOZ dialogue files via ``dataset_multiwoz21``."""

    def __init__(self, dataset_config, data_dir):
        """Load the configuration and the value list of every split.

        Slots that end up without any value receive a dummy entry.
        """
        super().__init__(dataset_config, data_dir)
        split_files = (
            ('train', 'train_dials.json'),
            ('dev', 'val_dials.json'),
            ('test', 'test_dials.json'),
        )
        for split, file_name in split_files:
            path = os.path.join(self.data_dir, file_name)
            self.value_list[split] = dataset_multiwoz21.get_value_list(path, self.slot_list)
        self._add_dummy_value_to_value_list()

    def prediction_normalization(self, slot, value):
        """Delegate normalization to ``dataset_multiwoz21``."""
        normalized = dataset_multiwoz21.prediction_normalization(slot, value)
        return normalized

    def get_train_examples(self, args):
        """Create the training examples; ``args`` is expanded as keyword arguments."""
        path = os.path.join(self.data_dir, 'train_dials.json')
        return dataset_multiwoz21.create_examples(
            path, 'train', self.class_types, self.slot_list, self.label_maps, **args)

    def get_dev_examples(self, args):
        """Create the development examples; ``args`` is expanded as keyword arguments."""
        path = os.path.join(self.data_dir, 'val_dials.json')
        return dataset_multiwoz21.create_examples(
            path, 'dev', self.class_types, self.slot_list, self.label_maps, **args)

    def get_test_examples(self, args):
        """Create the test examples; ``args`` is expanded as keyword arguments."""
        path = os.path.join(self.data_dir, 'test_dials.json')
        return dataset_multiwoz21.create_examples(
            path, 'test', self.class_types, self.slot_list, self.label_maps, **args)
class Multiwoz21LegacyProcessor(DataProcessor):
    """Processor backed by the ``dataset_multiwoz21_legacy`` module.

    Reads ``train_dials.json``, ``val_dials.json`` and ``test_dials.json``
    from ``data_dir``; example creation additionally receives the path of
    ``dialogue_acts.json`` in the same directory.
    """

    def __init__(self, dataset_config, data_dir):
        """Initialise the base processor and collect the value list of every split."""
        super().__init__(dataset_config, data_dir)
        split_files = (('train', 'train_dials.json'),
                       ('dev', 'val_dials.json'),
                       ('test', 'test_dials.json'))
        for split, file_name in split_files:
            dials_path = os.path.join(self.data_dir, file_name)
            self.value_list[split] = dataset_multiwoz21_legacy.get_value_list(
                dials_path, self.slot_list)
        self._add_dummy_value_to_value_list()

    def prediction_normalization(self, slot, value):
        """Normalize a predicted value.

        Delegates to ``dataset_multiwoz21`` (the non-legacy module).
        """
        return dataset_multiwoz21.prediction_normalization(slot, value)

    def get_train_examples(self, args):
        """Create the training examples; ``args`` is expanded into keyword arguments."""
        dials_path = os.path.join(self.data_dir, 'train_dials.json')
        acts_path = os.path.join(self.data_dir, 'dialogue_acts.json')
        return dataset_multiwoz21_legacy.create_examples(
            dials_path, acts_path, 'train', self.slot_list,
            self.label_maps, **args)

    def get_dev_examples(self, args):
        """Create the development examples; ``args`` is expanded into keyword arguments."""
        dials_path = os.path.join(self.data_dir, 'val_dials.json')
        acts_path = os.path.join(self.data_dir, 'dialogue_acts.json')
        return dataset_multiwoz21_legacy.create_examples(
            dials_path, acts_path, 'dev', self.slot_list,
            self.label_maps, **args)

    def get_test_examples(self, args):
        """Create the test examples; ``args`` is expanded into keyword arguments."""
        dials_path = os.path.join(self.data_dir, 'test_dials.json')
        acts_path = os.path.join(self.data_dir, 'dialogue_acts.json')
        return dataset_multiwoz21_legacy.create_examples(
            dials_path, acts_path, 'test', self.slot_list,
            self.label_maps, **args)
class SimProcessor(DataProcessor):
    """Processor backed by the ``dataset_sim`` module.

    Reads ``train.json``, ``dev.json`` and ``test.json`` from ``data_dir``.
    """

    def __init__(self, dataset_config, data_dir):
        """Initialise the base processor and collect the value list of every split."""
        super().__init__(dataset_config, data_dir)
        # The data files are named after their split.
        for split in ('train', 'dev', 'test'):
            split_path = os.path.join(self.data_dir, '%s.json' % split)
            self.value_list[split] = dataset_sim.get_value_list(
                split_path, self.slot_list)

    def get_train_examples(self, args):
        """Create the training examples; ``args`` is expanded into keyword arguments."""
        split_path = os.path.join(self.data_dir, 'train.json')
        return dataset_sim.create_examples(
            split_path, 'train', self.slot_list, **args)

    def get_dev_examples(self, args):
        """Create the development examples; ``args`` is expanded into keyword arguments."""
        split_path = os.path.join(self.data_dir, 'dev.json')
        return dataset_sim.create_examples(
            split_path, 'dev', self.slot_list, **args)

    def get_test_examples(self, args):
        """Create the test examples; ``args`` is expanded into keyword arguments."""
        split_path = os.path.join(self.data_dir, 'test.json')
        return dataset_sim.create_examples(
            split_path, 'test', self.slot_list, **args)
class UnifiedDatasetProcessor(DataProcessor):
    """Processor backed by the ``dataset_unified`` module.

    Data is addressed through ``self.dataset_name`` rather than through
    file paths below ``data_dir``.
    """

    def __init__(self, dataset_config, data_dir):
        """Initialise the base processor and collect the value list of every split."""
        super().__init__(dataset_config, data_dir)
        # Every split is filled by its own call (same arguments each time),
        # so no value-list object is shared between splits.
        for split in ('train', 'dev', 'test'):
            self.value_list[split] = dataset_unified.get_value_list(
                self.dataset_name, self.slot_list)
        self._add_dummy_value_to_value_list()

    def prediction_normalization(self, slot, value):
        """Delegate normalization of a predicted value to ``dataset_unified``."""
        return dataset_unified.prediction_normalization(
            self.dataset_name, slot, value)

    def _get_slot_list(self):
        """Return the slot list that ``dataset_unified`` reports for this dataset."""
        return dataset_unified.get_slot_list(self.dataset_name)

    def get_train_examples(self, args):
        """Create the examples of the 'train' split; ``args`` is expanded into keyword arguments."""
        return dataset_unified.create_examples(
            'train', self.dataset_name, self.class_types,
            self.slot_list, self.label_maps, **args)

    def get_dev_examples(self, args):
        """Create the examples of the 'validation' split; ``args`` is expanded into keyword arguments."""
        return dataset_unified.create_examples(
            'validation', self.dataset_name, self.class_types,
            self.slot_list, self.label_maps, **args)

    def get_test_examples(self, args):
        """Create the examples of the 'test' split; ``args`` is expanded into keyword arguments."""
        return dataset_unified.create_examples(
            'test', self.dataset_name, self.class_types,
            self.slot_list, self.label_maps, **args)
# Registry mapping a dataset identifier to the processor class handling it.
# "sim-m" and "sim-r" both resolve to SimProcessor.
PROCESSORS = {
    "woz2": Woz2Processor,
    "sim-m": SimProcessor,
    "sim-r": SimProcessor,
    "multiwoz21": Multiwoz21Processor,
    "multiwoz21_legacy": Multiwoz21LegacyProcessor,
    "unified": UnifiedDatasetProcessor,
}
This diff is collapsed.
{
"class_types": [
"none",
"dontcare",
"copy_value",
"inform"
],
"slots": {
"date": "the date for which to book the movie",
"movie": "the name of the movie",
"time": "the time for which to book the movie",
"num_tickets": "the amount of tickets for the movie",
"theatre_name": "the name of the movie theatre"
},
"noncategorical": [
"movie"
],
"boolean": [],
"label_maps": {}
}
{
"class_types": [
"none",
"dontcare",
"copy_value",
"inform"
],
"slots": {
"category": "the category of the restaurant",
"rating": "the rating of the restaurant",
"num_people": "the amount of people to book the restaurant for",
"location": "the location of the restaurant",
"restaurant_name": "the name of the restaurant",
"time": "the time for which to book the restaurant",
"date": "the date for which to book the restaurant",
"price_range": "the price range of the restaurant",
"meal": "the food type of the restaurant"
},
"noncategorical": [
"restaurant_name"
],
"boolean": [],
"label_maps": {}
}
This diff is collapsed.
{
"class_types": [
"none",
"dontcare",
"copy_value",
"inform"
],
"slots": {
"area": "the area of the restaurant",
"food": "the food type of the restaurant",
"price_range": "the price range of the restaurant"
},
"noncategorical": [
"food"
],
"categorical": [
"area",
"price_range"
],
"boolean": [],
"label_maps": {
"center": [
"centre",
"downtown",
"central",
"down town",
"middle"
],
"centre": [
"center",
"downtown",
"central",
"down town",
"middle"
],
"south": [
"southern",
"southside"
],
"north": [
"northern",
"uptown",
"northside"
],
"west": [
"western",
"westside"
],
"east": [
"eastern",
"eastside"
],
"east side": [
"eastern",
"eastside"
],
"cheap": [
"low price",
"inexpensive",
"cheaper",
"low priced",
"affordable",
"nothing too expensive",
"without costing a fortune",
"cheapest",
"good deals",
"low prices",
"afford",
"on a budget",
"fair prices",
"less expensive",
"cheapeast",
"not cost an arm and a leg"
],
"moderate": [
"moderately",
"medium priced",
"medium price",
"fair price",
"fair prices",
"reasonable",
"reasonably priced",
"mid price",
"fairly priced",
"not outrageous",
"not too expensive",
"on a budget",
"mid range",
"reasonable priced",
"less expensive",
"not too pricey",
"nothing too expensive",
"nothing cheap",
"not overpriced",
"medium",
"inexpensive"
],
"expensive": [
"high priced",
"high end",
"high class",
"high quality",
"fancy",
"upscale",
"nice",
"fine dining",
"expensively priced",
"not some cheapie"
],
"afghan": [
"afghanistan"
],
"african": [
"africa"
],
"asian oriental": [
"asian",
"oriental"
],
"australasian": [
"australian asian",
"austral asian"
],
"australian": [
"aussie"
],
"barbeque": [
"barbecue",
"bbq"
],
"basque": [
"bask"
],
"belgian": [
"belgium"
],
"british": [
"cotto"
],
"canapes": [
"canopy",
"canape",
"canap"
],
"catalan": [
"catalonian"
],
"corsican": [
"corsica"
],
"crossover": [
"cross over",
"over"
],
"gastropub": [
"gastro pub",
"gastro",
"gastropubs"
],
"hungarian": [
"goulash"
],
"indian": [
"india",
"indians",
"nirala"
],
"international": [
"all types of food"
],
"italian": [
"prezzo"
],
"jamaican": [
"jamaica"
],
"japanese": [
"sushi",
"beni hana"
],
"korean": [
"korea"
],
"lebanese": [
"lebanse"
],
"north american": [
"american",
"hamburger"
],
"portuguese": [
"portugese"
],
"seafood": [
"sea food",
"shellfish",
"fish"
],
"singaporean": [
"singapore"
],
"steakhouse": [
"steak house",
"steak"
],
"thai": [
"thailand",
"bangkok"
],
"traditional": [
"old fashioned",
"plain"
],
"turkish": [
"turkey"
],
"unusual": [
"unique and strange"
],
"venetian": [
"vanessa"
],
"vietnamese": [
"vietnam",
"thanh binh"
]
}
}
This diff is collapsed.
# coding=utf-8
#
# Copyright 2020-2022 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from tqdm import tqdm
from utils_dst import (DSTExample)
from dataset_multiwoz21 import (ACTS_DICT,
tokenize, normalize_label,
get_turn_label, delex_utt)
def load_acts(input_file, boolean_slots=True):
    """Load the dialogue act annotation and index the informed slot values.

    Only acts whose type is ``inform``, ``recommend``, ``select`` or ``book``
    contribute. Two turn layouts are handled:

    * The turn annotation contains a ``dialog_act`` key: the acts are read
      from that entry, turns with an even index are skipped, the turn key
      becomes ``int(t) / 2 + 1`` (truncated) and the dialogue id is used
      unchanged.
    * Otherwise the turn annotation itself is the act mapping, the turn key
      is kept and ``'.json'`` is appended to the dialogue id.

    Turn annotations that are not dicts are ignored.

    Args:
        input_file: Path of the JSON file holding the act annotation
            (read as UTF-8, like the other loaders in this file).
        boolean_slots: If true, a ``parking``/``internet`` act whose value
            is ``'none'`` is stored as ``'true'``; otherwise the slot name
            itself is stored as the value.

    Returns:
        A dict mapping ``(dialogue key, turn key, slot)`` tuples to a
        one-element list holding the first value informed for that key.
    """
    informing_acts = ('inform', 'recommend', 'select', 'book')

    with open(input_file, encoding='utf-8') as f:
        acts = json.load(f)

    s_dict = {}
    for dialog_id, turns in acts.items():
        for t, annotation in turns.items():
            # Only process, if turn has annotation
            if not isinstance(annotation, dict):
                continue
            is_22_format = 'dialog_act' in annotation
            if is_22_format:
                acts_list = annotation['dialog_act']
                if int(t) % 2 == 0:
                    continue
            else:
                acts_list = annotation
            for a in acts_list:
                # Act names have the form "<domain>-<act type>".
                aa = a.lower().split('-')
                if aa[1] not in informing_acts:
                    continue
                for i in acts_list[a]:
                    s = i[0].lower()
                    v = i[1].lower().strip()
                    if s == 'none' or v == '?':
                        continue
                    if v == 'none':
                        # A value-less act is only kept for these two slots.
                        if s not in ['parking', 'internet']:
                            continue
                        v = 'true' if boolean_slots else s
                    # NOTE(review): this compares the dialogue id (the key
                    # that also feeds d_key below) with 'hotel', so the
                    # branch looks unreachable for real dialogue ids. The
                    # act's domain (aa[0]) was presumably intended. Left
                    # unchanged because fixing it alters the produced labels.
                    if dialog_id == 'hotel' and s == 'type':
                        if v in ['hotel', 'hotels']:
                            v = 'true' if boolean_slots else 'hotel'
                        else:
                            v = 'false' if boolean_slots else 'guest house'
                    slot = aa[0] + '-' + s
                    if slot in ACTS_DICT:
                        slot = ACTS_DICT[slot]
                    if is_22_format:
                        t_key = str(int(int(t) / 2 + 1))
                        d_key = dialog_id
                    else:
                        t_key = t
                        d_key = dialog_id + '.json'
                    key = d_key, t_key, slot
                    # INFO: Since the model has no mechanism to predict
                    # one among several informed value candidates, we
                    # keep only one informed value. For fairness, we
                    # apply a global rule: keep the first informed value.
                    # (Keeping the last one instead would mean assigning
                    # unconditionally.)
                    if key not in s_dict:
                        s_dict[key] = [v]
    return s_dict
def create_examples(input_file, acts_file, set_type, slot_list,
                    label_maps=None,
                    no_label_value_repetitions=False,
                    swap_utterances=False,
                    delexicalize_sys_utts=False,
                    unk_token="[UNK]",
                    boolean_slots=True,
                    analyze=False):
    """Read a DST json file into a list of DSTExample.

    Every dialogue is processed in two passes. The first pass walks the
    utterance log, tokenizes each utterance and records which slots changed
    according to the metadata of each system utterance. The second pass
    pairs each user utterance with the preceding system utterance and
    derives the per-slot labels via ``get_turn_label``.

    Args:
        input_file: Path of the dialogue JSON file. Each dialogue entry is
            expected to have a ``'log'`` list whose items carry ``'text'``
            and ``'metadata'``.
        acts_file: Path of the dialogue act file, passed to ``load_acts``.
        set_type: Name of the data split; used as prefix of every guid.
        slot_list: The slots to create labels for.
        label_maps: Forwarded to ``get_turn_label``. ``None`` (the default)
            stands for an empty mapping.
        no_label_value_repetitions: If false, a slot that was seen earlier
            in the dialogue but is not modified in the current turn keeps
            its previously seen value as label, and a ``copy_value`` label
            is dropped when the value is shared by more than one seen slot.
        swap_utterances: If true, the system utterance becomes ``text_a``
            and the user utterance ``text_b``; otherwise the other way
            round.
        delexicalize_sys_utts: If true, system utterances are processed by
            ``delex_utt`` using the informed values instead of ``tokenize``.
        unk_token: Passed to ``delex_utt``.
        boolean_slots: Forwarded to ``load_acts`` and ``normalize_label``.
        analyze: If true, print per-turn debugging output.

    Returns:
        A list of ``DSTExample``, one per (system, user) utterance pair.
    """
    # Use a fresh dict per call instead of a shared mutable default argument.
    if label_maps is None:
        label_maps = {}

    sys_inform_dict = load_acts(acts_file, boolean_slots)

    with open(input_file, "r", encoding='utf-8') as reader:
        input_data = json.load(reader)

    examples = []
    for dialog_id in tqdm(input_data):
        entry = input_data[dialog_id]
        utterances = entry['log']

        # Collects all slot changes throughout the dialog
        cumulative_labels = {slot: 'none' for slot in slot_list}

        # First system utterance is empty, since multiwoz starts with user input
        utt_tok_list = [[]]
        mod_slots_list = [{}]

        # Collect all utterances and their metadata
        usr_sys_switch = True
        turn_itr = 0
        for utt in utterances:
            # Assert that system and user utterances alternate
            is_sys_utt = utt['metadata'] != {}
            if usr_sys_switch == is_sys_utt:
                print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id))
                break
            usr_sys_switch = is_sys_utt

            if is_sys_utt:
                turn_itr += 1

            # Delexicalize sys utterance
            if delexicalize_sys_utts and is_sys_utt:
                inform_dict = {slot: 'none' for slot in slot_list}
                for slot in slot_list:
                    if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
                        inform_dict[slot] = sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]
                utt_tok_list.append(delex_utt(utt['text'], inform_dict, unk_token))  # normalizes utterances
            else:
                utt_tok_list.append(tokenize(utt['text']))  # normalizes utterances

            modified_slots = {}

            # If sys utt, extract metadata (identify and collect modified slots)
            if is_sys_utt:
                for d in utt['metadata']:
                    booked = utt['metadata'][d]['book']['booked']
                    booked_slots = {}
                    # Check the booked section
                    if booked != []:
                        for s in booked[0]:
                            booked_slots[s] = normalize_label('%s-%s' % (d, s), booked[0][s], boolean_slots)  # normalize labels
                    # Check the semi and the inform slots
                    for category in ['book', 'semi']:
                        for s in utt['metadata'][d][category]:
                            cs = '%s-book_%s' % (d, s) if category == 'book' else '%s-%s' % (d, s)
                            value_label = normalize_label(cs, utt['metadata'][d][category][s], boolean_slots)  # normalize labels
                            # Prefer the slot value as stored in the booked section
                            if s in booked_slots:
                                value_label = booked_slots[s]
                            # Remember modified slots and entire dialog state
                            if cs in slot_list and cumulative_labels[cs] != value_label:
                                modified_slots[cs] = value_label
                                cumulative_labels[cs] = value_label

            # One entry per utterance, aligned with utt_tok_list.
            mod_slots_list.append(modified_slots.copy())

        # Form proper (usr, sys) turns
        turn_itr = 0
        diag_seen_slots_dict = {}
        diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
        diag_state = {slot: 'none' for slot in slot_list}
        sys_utt_tok = []
        usr_utt_tok = []
        # Odd indices are stepped through; each is paired with the entry
        # before it (system side) and the slot changes recorded right after it.
        for i in range(1, len(utt_tok_list) - 1, 2):
            sys_utt_tok_label_dict = {}
            usr_utt_tok_label_dict = {}
            # NOTE: value_dict is filled below but not read again in this function.
            value_dict = {}
            inform_dict = {}
            inform_slot_dict = {}
            referral_dict = {}
            class_type_dict = {}
            updated_slots = {slot: 0 for slot in slot_list}

            # Collect turn data
            sys_utt_tok = utt_tok_list[i - 1]
            usr_utt_tok = utt_tok_list[i]
            turn_slots = mod_slots_list[i + 1]

            guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr))

            if analyze:
                print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, ' '.join(sys_utt_tok), ' '.join(usr_utt_tok)))
                print("%15s %2s [" % (dialog_id, turn_itr), end='')

            new_diag_state = diag_state.copy()

            for slot in slot_list:
                value_label = 'none'
                if slot in turn_slots:
                    value_label = turn_slots[slot]
                    # We keep the original labels so as to not
                    # overlook unpointable values, as well as to not
                    # modify any of the original labels for test sets,
                    # since this would make comparison difficult.
                    value_dict[slot] = value_label
                elif not no_label_value_repetitions and slot in diag_seen_slots_dict:
                    value_label = diag_seen_slots_value_dict[slot]

                # Get dialog act annotations
                inform_label = list(['none'])
                inform_slot_dict[slot] = 0
                if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
                    inform_label = list([normalize_label(slot, i, boolean_slots) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]])
                    inform_slot_dict[slot] = 1
                elif (str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1]) in sys_inform_dict:
                    inform_label = list([normalize_label(slot, i, boolean_slots) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1])]])
                    inform_slot_dict[slot] = 1

                (informed_value,
                 referred_slot,
                 usr_utt_tok_label,
                 class_type) = get_turn_label(value_label,
                                              inform_label,
                                              sys_utt_tok,
                                              usr_utt_tok,
                                              slot,
                                              diag_seen_slots_value_dict,
                                              slot_last_occurrence=True,
                                              label_maps=label_maps)

                inform_dict[slot] = informed_value

                # Generally don't use span prediction on sys utterance (but inform prediction instead).
                sys_utt_tok_label = [0 for _ in sys_utt_tok]

                # Determine what to do with value repetitions.
                # If value is unique in seen slots, then tag it, otherwise not,
                # since correct slot assignment can not be guaranteed anymore.
                if not no_label_value_repetitions and slot in diag_seen_slots_dict:
                    if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1:
                        class_type = 'none'
                        usr_utt_tok_label = [0 for _ in usr_utt_tok_label]

                sys_utt_tok_label_dict[slot] = sys_utt_tok_label
                usr_utt_tok_label_dict[slot] = usr_utt_tok_label

                # Flag the slot when its label differs from the value seen so far.
                if diag_seen_slots_value_dict[slot] != value_label:
                    updated_slots[slot] = 1

                # For now, we map all occurrences of unpointable slot values
                # to none. However, since the labels will still suggest
                # a presence of unpointable slot values, the task of the
                # DST is still to find those values. It is just not
                # possible to do that via span prediction on the current input.
                if class_type == 'unpointable':
                    class_type_dict[slot] = 'none'
                    referral_dict[slot] = 'none'
                    if analyze:
                        if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]:
                            print("(%s): %s, " % (slot, value_label), end='')
                elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
                    # If slot has seen before and its class type did not change, label this slot a not present,
                    # assuming that the slot has not actually been mentioned in this turn.
                    # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform,
                    # this must mean there is evidence in the original labels, therefore consider
                    # them as mentioned again.
                    class_type_dict[slot] = 'none'
                    referral_dict[slot] = 'none'
                else:
                    class_type_dict[slot] = class_type
                    referral_dict[slot] = referred_slot
                # Remember that this slot was mentioned during this dialog already.
                if class_type != 'none':
                    diag_seen_slots_dict[slot] = class_type
                    diag_seen_slots_value_dict[slot] = value_label
                    new_diag_state[slot] = class_type
                    # Unpointable is not a valid class, therefore replace with
                    # some valid class for now...
                    if class_type == 'unpointable':
                        new_diag_state[slot] = 'copy_value'

            if analyze:
                print("]")

            if not swap_utterances:
                txt_a = usr_utt_tok
                txt_b = sys_utt_tok
                txt_a_lbl = usr_utt_tok_label_dict
                txt_b_lbl = sys_utt_tok_label_dict
            else:
                txt_a = sys_utt_tok
                txt_b = usr_utt_tok
                txt_a_lbl = sys_utt_tok_label_dict
                txt_b_lbl = usr_utt_tok_label_dict
            # diag_state still holds the state from before this turn; it is
            # replaced (not mutated) right after the example is stored.
            examples.append(DSTExample(
                guid=guid,
                text_a=txt_a,
                text_b=txt_b,
                text_a_label=txt_a_lbl,
                text_b_label=txt_b_lbl,
                values=diag_seen_slots_value_dict.copy(),
                inform_label=inform_dict,
                inform_slot_label=inform_slot_dict,
                refer_label=referral_dict,
                diag_state=diag_state,
                slot_update=updated_slots,
                class_label=class_type_dict))

            # Update some variables.
            diag_state = new_diag_state.copy()

            turn_itr += 1

        if analyze:
            print("----------------------------------------------------------------------")

    return examples
def get_value_list(input_file, slot_list, boolean_slots=True):
    """Count the values occurring per slot in a dialogue JSON file.

    Walks the metadata of every utterance that has non-empty metadata and
    counts each normalized value of the slots listed in ``slot_list``.
    ``'none'`` and ``'dontcare'`` are never counted; if ``boolean_slots`` is
    false, ``'true'`` and ``'false'`` are left out as well. Values containing
    ``|``, ``<`` or ``>`` are skipped.

    Returns:
        A dict mapping each slot in ``slot_list`` to a dict of
        ``value -> number of occurrences``.
    """
    ignored_values = ['none', 'dontcare']
    if not boolean_slots:
        ignored_values += ['true', 'false']

    with open(input_file, "r", encoding='utf-8') as reader:
        dialogs = json.load(reader)

    value_list = {slot: {} for slot in slot_list}
    for dialog_id in dialogs:
        for utt in dialogs[dialog_id]['log']:
            metadata = utt['metadata']
            if metadata == {}:
                continue
            for domain in metadata:
                domain_state = metadata[domain]
                booked = domain_state['book']['booked']
                # Values from the booked section, keyed by plain slot name.
                if booked != []:
                    booked_slots = {
                        name: normalize_label('%s-%s' % (domain, name), booked[0][name], boolean_slots)
                        for name in booked[0]
                    }
                else:
                    booked_slots = {}
                # Check the semi and the inform slots
                for category in ['book', 'semi']:
                    for name in domain_state[category]:
                        if category == 'book':
                            full_slot = '%s-book_%s' % (domain, name)
                        else:
                            full_slot = '%s-%s' % (domain, name)
                        value_label = normalize_label(full_slot, domain_state[category][name], boolean_slots)
                        # Prefer the slot value as stored in the booked section
                        if name in booked_slots:
                            value_label = booked_slots[name]
                        if full_slot not in slot_list or value_label in ignored_values:
                            continue
                        if "|" in value_label or "<" in value_label or ">" in value_label:
                            continue
                        slot_counts = value_list[full_slot]
                        slot_counts[value_label] = slot_counts.get(value_label, 0) + 1
    return value_list
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment