diff --git a/DO.example.advanced b/DO.example.advanced
new file mode 100755
index 0000000000000000000000000000000000000000..d705f5088ce85227ffb0eb85e769ea473ace3139
--- /dev/null
+++ b/DO.example.advanced
@@ -0,0 +1,64 @@
+#!/bin/bash
+
+# Parameters ------------------------------------------------------
+
+#TASK="sim-m"
+#DATA_DIR="data/simulated-dialogue/sim-M"
+#TASK="sim-r"
+#DATA_DIR="data/simulated-dialogue/sim-R"
+TASK="woz2"
+DATA_DIR="data/woz2"
+#TASK="multiwoz21"
+#DATA_DIR="data/MULTIWOZ2.1"
+
+# Project paths etc. ----------------------------------------------
+
+OUT_DIR=results
+mkdir -p ${OUT_DIR}
+
+# Main ------------------------------------------------------------
+
+for step in train dev test; do
+    args_add=""
+    if [ "$step" = "train" ]; then
+        args_add="--do_train --predict_type=dummy"
+    elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
+        args_add="--do_eval --predict_type=${step}"
+    fi
+
+    python3 run_dst.py \
+        --task_name=${TASK} \
+        --data_dir=${DATA_DIR} \
+        --dataset_config=dataset_config/${TASK}.json \
+        --model_type="bert" \
+        --model_name_or_path="bert-base-uncased" \
+        --do_lower_case \
+        --learning_rate=1e-4 \
+        --num_train_epochs=10 \
+        --max_seq_length=180 \
+        --per_gpu_train_batch_size=48 \
+        --per_gpu_eval_batch_size=1 \
+        --output_dir=${OUT_DIR} \
+        --save_epochs=2 \
+        --logging_steps=10 \
+        --warmup_proportion=0.1 \
+        --eval_all_checkpoints \
+        --adam_epsilon=1e-6 \
+        --label_value_repetitions \
+        --swap_utterances \
+        --append_history \
+        --use_history_labels \
+        --delexicalize_sys_utts \
+        --class_aux_feats_inform \
+        --class_aux_feats_ds \
+        ${args_add} \
+        2>&1 | tee ${OUT_DIR}/${step}.log
+
+    if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
+        python3 ~/tools/trippy_clean/metric_bert_dst.py \
+            ${TASK} \
+            dataset_config/${TASK}.json \
+            "${OUT_DIR}/pred_res.${step}*json" \
+            2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
+    fi
+done
diff --git a/DO.example.simple b/DO.example.simple
new file mode 100755
index 0000000000000000000000000000000000000000..43bb891897e2fec9198f6fed631c50b9a1b10f72
--- /dev/null
+++ b/DO.example.simple
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+# Parameters ------------------------------------------------------
+
+#TASK="sim-m"
+#DATA_DIR="data/simulated-dialogue/sim-M"
+#TASK="sim-r"
+#DATA_DIR="data/simulated-dialogue/sim-R"
+TASK="woz2"
+DATA_DIR="data/woz2"
+#TASK="multiwoz21"
+#DATA_DIR="data/MULTIWOZ2.1"
+
+# Project paths etc. ----------------------------------------------
+
+OUT_DIR=results
+mkdir -p ${OUT_DIR}
+
+# Main ------------------------------------------------------------
+
+for step in train dev test; do
+    args_add=""
+    if [ "$step" = "train" ]; then
+        args_add="--do_train --predict_type=dummy"
+    elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
+        args_add="--do_eval --predict_type=${step}"
+    fi
+
+    python3 run_dst.py \
+        --task_name=${TASK} \
+        --data_dir=${DATA_DIR} \
+        --dataset_config=dataset_config/${TASK}.json \
+        --model_type="bert" \
+        --model_name_or_path="bert-base-uncased" \
+        --do_lower_case \
+        --learning_rate=1e-4 \
+        --num_train_epochs=10 \
+        --max_seq_length=180 \
+        --per_gpu_train_batch_size=48 \
+        --per_gpu_eval_batch_size=1 \
+        --output_dir=${OUT_DIR} \
+        --save_epochs=2 \
+        --logging_steps=10 \
+        --warmup_proportion=0.1 \
+        --eval_all_checkpoints \
+        --adam_epsilon=1e-6 \
+        --label_value_repetitions \
+        ${args_add} \
+        2>&1 | tee ${OUT_DIR}/${step}.log
+
+    if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
+        python3 ~/tools/trippy_clean/metric_bert_dst.py \
+            ${TASK} \
+            dataset_config/${TASK}.json \
+            "${OUT_DIR}/pred_res.${step}*json" \
+            2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
+    fi
+done
diff --git a/data_processors.py b/data_processors.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d7670df0b5a0734f05499fab24bf4a137edb652
--- /dev/null
+++ b/data_processors.py
@@ -0,0 +1,94 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
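+#
+# Usage sketch (added for illustration; the concrete paths and the args dict
+# are assumptions, not part of the original module): a processor is selected
+# from the PROCESSORS dict defined at the bottom of this file, e.g.
+#
+#     processor = PROCESSORS['woz2']('dataset_config/woz2.json')
+#     train_examples = processor.get_train_examples('data/woz2', args)
+#
+# where args is a dict of keyword arguments that is forwarded to the
+# dataset-specific create_examples() functions (append_history,
+# swap_utterances, label_value_repetitions, etc.).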
+
+import os
+import json
+
+import dataset_woz2
+import dataset_sim
+import dataset_multiwoz21
+
+
+class DataProcessor(object):
+    def __init__(self, dataset_config):
+        with open(dataset_config, "r", encoding='utf-8') as f:
+            raw_config = json.load(f)
+        self.class_types = raw_config['class_types']
+        self.slot_list = raw_config['slots']
+        self.label_maps = raw_config['label_maps']
+
+    def get_train_examples(self, data_dir, **args):
+        raise NotImplementedError()
+
+    def get_dev_examples(self, data_dir, **args):
+        raise NotImplementedError()
+
+    def get_test_examples(self, data_dir, **args):
+        raise NotImplementedError()
+
+
+class Woz2Processor(DataProcessor):
+    def get_train_examples(self, data_dir, args):
+        return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_train_en.json'),
+                                            'train', self.slot_list, self.label_maps, **args)
+
+    def get_dev_examples(self, data_dir, args):
+        return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_validate_en.json'),
+                                            'dev', self.slot_list, self.label_maps, **args)
+
+    def get_test_examples(self, data_dir, args):
+        return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_test_en.json'),
+                                            'test', self.slot_list, self.label_maps, **args)
+
+
+class Multiwoz21Processor(DataProcessor):
+    def get_train_examples(self, data_dir, args):
+        return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'train_dials.json'),
+                                                  os.path.join(data_dir, 'dialogue_acts.json'),
+                                                  'train', self.slot_list, self.label_maps, **args)
+
+    def get_dev_examples(self, data_dir, args):
+        return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'val_dials.json'),
+                                                  os.path.join(data_dir, 'dialogue_acts.json'),
+                                                  'dev', self.slot_list, self.label_maps, **args)
+
+    def get_test_examples(self, data_dir, args):
+        return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'test_dials.json'),
+                                                  os.path.join(data_dir, 'dialogue_acts.json'),
+                                                  'test', self.slot_list, self.label_maps, **args)
+
+
+class SimProcessor(DataProcessor):
+    def get_train_examples(self, data_dir, args):
+        return dataset_sim.create_examples(os.path.join(data_dir, 'train.json'),
+                                           'train', self.slot_list, **args)
+
+    def get_dev_examples(self, data_dir, args):
+        return dataset_sim.create_examples(os.path.join(data_dir, 'dev.json'),
+                                           'dev', self.slot_list, **args)
+
+    def get_test_examples(self, data_dir, args):
+        return dataset_sim.create_examples(os.path.join(data_dir, 'test.json'),
+                                           'test', self.slot_list, **args)
+
+
+PROCESSORS = {"woz2": Woz2Processor,
+              "sim-m": SimProcessor,
+              "sim-r": SimProcessor,
+              "multiwoz21": Multiwoz21Processor}
diff --git a/dataset_config/multiwoz21.json b/dataset_config/multiwoz21.json
new file mode 100644
index 0000000000000000000000000000000000000000..2fff3bb30b60d48d623c9792ebf7a0caf70e2aaf
--- /dev/null
+++ b/dataset_config/multiwoz21.json
@@ -0,0 +1,1297 @@
+{
+  "class_types": [
+    "none",
+    "dontcare",
+    "copy_value",
+    "true",
+    "false",
+    "refer",
+    "inform"
+  ],
+  "slots": [
+    "taxi-leaveAt",
+    "taxi-destination",
+    "taxi-departure",
+    "taxi-arriveBy",
+    "restaurant-book_people",
+    "restaurant-book_day",
+    "restaurant-book_time",
+    "restaurant-food",
+    "restaurant-pricerange",
+    "restaurant-name",
+    "restaurant-area",
+    "hotel-book_people",
+    "hotel-book_day",
+    "hotel-book_stay",
+    "hotel-name",
+    "hotel-area",
+    "hotel-parking",
+    "hotel-pricerange",
+    "hotel-stars",
+    "hotel-internet",
+    "hotel-type",
+    "attraction-type",
+    "attraction-name",
+    "attraction-area",
+    "train-book_people",
+    "train-leaveAt",
+    "train-destination",
"train-day", + "train-arriveBy", + "train-departure" + ] + "label_maps": { + "guest house": [ + "guest houses" + ], + "hotel": [ + "hotels" + ], + "centre": [ + "center", + "downtown" + ], + "north": [ + "northern", + "northside", + "northend" + ], + "east": [ + "eastern", + "eastside", + "eastend" + ], + "west": [ + "western", + "westside", + "westend" + ], + "south": [ + "southern", + "southside", + "southend" + ], + "cheap": [ + "inexpensive", + "lower price", + "lower range", + "cheaply", + "cheaper", + "cheapest", + "very affordable" + ], + "moderate": [ + "moderately", + "reasonable", + "reasonably", + "affordable", + "mid range", + "mid-range", + "priced moderately", + "decently priced", + "mid price", + "mid-price", + "mid priced", + "mid-priced", + "middle price", + "medium price", + "medium priced", + "not too expensive", + "not too cheap" + ], + "expensive": [ + "high end", + "high-end", + "high class", + "high-class", + "high scale", + "high-scale", + "high price", + "high priced", + "higher price", + "fancy", + "upscale", + "nice", + "expensively", + "luxury" + ], + "0": [ + "zero" + ], + "1": [ + "one", + "just me", + "for me", + "myself", + "alone", + "me" + ], + "2": [ + "two" + ], + "3": [ + "three" + ], + "4": [ + "four" + ], + "5": [ + "five" + ], + "6": [ + "six" + ], + "7": [ + "seven" + ], + "8": [ + "eight" + ], + "9": [ + "nine" + ], + "10": [ + "ten" + ], + "11": [ + "eleven" + ], + "12": [ + "twelve" + ], + "architecture": [ + "architectural", + "architecturally", + "architect" + ], + "boat": [ + "boating", + "boats", + "camboats" + ], + "boating": [ + "boat", + "boats", + "camboats" + ], + "camboats": [ + "boating", + "boat", + "boats" + ], + "cinema": [ + "cinemas", + "movie", + "films", + "film" + ], + "college": [ + "colleges" + ], + "concert": [ + "concert hall", + "concert halls", + "concerthall", + "concerthalls", + "concerts" + ], + "concerthall": [ + "concert hall", + "concert halls", + "concerthalls", + "concerts", + "concert" + ], + "entertainment": [ + "entertaining" + ], + "gallery": [ + "museum" + ], + "gastropubs": [ + "gastropub" + ], + "multiple sports": [ + "multiple sport", + "multi sport", + "multi sports", + "sports", + "sporting" + ], + "museum": [ + "museums", + "gallery", + "galleries" + ], + "night club": [ + "night clubs", + "nightclub", + "nightclubs", + "club", + "clubs" + ], + "park": [ + "parks" + ], + "pool": [ + "swimming pool", + "swimming", + "pools", + "swimmingpool", + "water", + "swim" + ], + "sports": [ + "multiple sport", + "multi sport", + "multi sports", + "multiple sports", + "sporting" + ], + "swimming pool": [ + "swimming", + "pool", + "pools", + "swimmingpool", + "water", + "swim" + ], + "theater": [ + "theatre", + "theatres", + "theaters" + ], + "theatre": [ + "theater", + "theatres", + "theaters" + ], + "abbey pool and astroturf pitch": [ + "abbey pool and astroturf", + "abbey pool" + ], + "abbey pool and astroturf": [ + "abbey pool and astroturf pitch", + "abbey pool" + ], + "abbey pool": [ + "abbey pool and astroturf pitch", + "abbey pool and astroturf" + ], + "adc theatre": [ + "adc theater", + "adc" + ], + "adc": [ + "adc theatre", + "adc theater" + ], + "addenbrookes hospital": [ + "addenbrooke's hospital" + ], + "cafe jello gallery": [ + "cafe jello" + ], + "cambridge and county folk museum": [ + "cambridge and country folk museum", + "county folk museum" + ], + "cambridge and country folk museum": [ + "cambridge and county folk museum", + "county folk museum" + ], + "county folk museum": [ + "cambridge and 
county folk museum", + "cambridge and country folk museum" + ], + "cambridge arts theatre": [ + "cambridge arts theater" + ], + "cambridge book and print gallery": [ + "book and print gallery" + ], + "cambridge contemporary art": [ + "cambridge contemporary art museum", + "contemporary art museum" + ], + "cambridge contemporary art museum": [ + "cambridge contemporary art", + "contemporary art museum" + ], + "cambridge corn exchange": [ + "the cambridge corn exchange" + ], + "the cambridge corn exchange": [ + "cambridge corn exchange" + ], + "cambridge museum of technology": [ + "museum of technology" + ], + "cambridge punter": [ + "the cambridge punter", + "cambridge punters" + ], + "cambridge punters": [ + "the cambridge punter", + "cambridge punter" + ], + "the cambridge punter": [ + "cambridge punter", + "cambridge punters" + ], + "cambridge university botanic gardens": [ + "cambridge university botanical gardens", + "cambridge university botanical garden", + "cambridge university botanic garden", + "cambridge botanic gardens", + "cambridge botanical gardens", + "cambridge botanic garden", + "cambridge botanical garden", + "botanic gardens", + "botanical gardens", + "botanic garden", + "botanical garden" + ], + "cambridge botanic gardens": [ + "cambridge university botanic gardens", + "cambridge university botanical gardens", + "cambridge university botanical garden", + "cambridge university botanic garden", + "cambridge botanical gardens", + "cambridge botanic garden", + "cambridge botanical garden", + "botanic gardens", + "botanical gardens", + "botanic garden", + "botanical garden" + ], + "botanic gardens": [ + "cambridge university botanic gardens", + "cambridge university botanical gardens", + "cambridge university botanical garden", + "cambridge university botanic garden", + "cambridge botanic gardens", + "cambridge botanical gardens", + "cambridge botanic garden", + "cambridge botanical garden", + "botanical gardens", + "botanic garden", + "botanical garden" + ], + "cherry hinton village centre": [ + "cherry hinton village center" + ], + "cherry hinton village center": [ + "cherry hinton village centre" + ], + "cherry hinton hall and grounds": [ + "cherry hinton hall" + ], + "cherry hinton hall": [ + "cherry hinton hall and grounds" + ], + "cherry hinton water play": [ + "cherry hinton water play park" + ], + "cherry hinton water play park": [ + "cherry hinton water play" + ], + "christ college": [ + "christ's college", + "christs college" + ], + "christs college": [ + "christ college", + "christ's college" + ], + "churchills college": [ + "churchill's college", + "churchill college" + ], + "cineworld cinema": [ + "cineworld" + ], + "clair hall": [ + "clare hall" + ], + "clare hall": [ + "clair hall" + ], + "the fez club": [ + "fez club" + ], + "great saint marys church": [ + "great saint mary's church", + "great saint mary's", + "great saint marys" + ], + "jesus green outdoor pool": [ + "jesus green" + ], + "jesus green": [ + "jesus green outdoor pool" + ], + "kettles yard": [ + "kettle's yard" + ], + "kings college": [ + "king's college" + ], + "kings hedges learner pool": [ + "king's hedges learner pool", + "king hedges learner pool" + ], + "king hedges learner pool": [ + "king's hedges learner pool", + "kings hedges learner pool" + ], + "little saint marys church": [ + "little saint mary's church", + "little saint mary's", + "little saint marys" + ], + "mumford theatre": [ + "mumford theater" + ], + "museum of archaelogy": [ + "museum of archaeology" + ], + "museum of 
archaelogy and anthropology": [ + "museum of archaeology and anthropology" + ], + "peoples portraits exhibition": [ + "people's portraits exhibition at girton college", + "peoples portraits exhibition at girton college", + "people's portraits exhibition" + ], + "peoples portraits exhibition at girton college": [ + "people's portraits exhibition at girton college", + "people's portraits exhibition", + "peoples portraits exhibition" + ], + "queens college": [ + "queens' college", + "queen's college" + ], + "riverboat georgina": [ + "riverboat" + ], + "saint barnabas": [ + "saint barbabas press gallery" + ], + "saint barnabas press gallery": [ + "saint barbabas" + ], + "saint catharines college": [ + "saint catharine's college", + "saint catharine's" + ], + "saint johns college": [ + "saint john's college", + "st john's college", + "st johns college" + ], + "scott polar": [ + "scott polar museum" + ], + "scott polar museum": [ + "scott polar" + ], + "scudamores punting co": [ + "scudamore's punting co", + "scudamores punting", + "scudamore's punting", + "scudamores", + "scudamore's", + "scudamore" + ], + "scudamore": [ + "scudamore's punting co", + "scudamores punting co", + "scudamores punting", + "scudamore's punting", + "scudamores", + "scudamore's" + ], + "sheeps green and lammas land park fen causeway": [ + "sheep's green and lammas land park fen causeway", + "sheep's green and lammas land park", + "sheeps green and lammas land park", + "lammas land park", + "sheep's green", + "sheeps green" + ], + "sheeps green and lammas land park": [ + "sheep's green and lammas land park fen causeway", + "sheeps green and lammas land park fen causeway", + "sheep's green and lammas land park", + "lammas land park", + "sheep's green", + "sheeps green" + ], + "lammas land park": [ + "sheep's green and lammas land park fen causeway", + "sheeps green and lammas land park fen causeway", + "sheep's green and lammas land park", + "sheeps green and lammas land park", + "sheep's green", + "sheeps green" + ], + "sheeps green": [ + "sheep's green and lammas land park fen causeway", + "sheeps green and lammas land park fen causeway", + "sheep's green and lammas land park", + "sheeps green and lammas land park", + "lammas land park", + "sheep's green" + ], + "soul tree nightclub": [ + "soul tree night club", + "soul tree", + "soultree" + ], + "soultree": [ + "soul tree nightclub", + "soul tree night club", + "soul tree" + ], + "the man on the moon": [ + "man on the moon" + ], + "man on the moon": [ + "the man on the moon" + ], + "the junction": [ + "junction theatre", + "junction theater" + ], + "junction theatre": [ + "the junction", + "junction theater" + ], + "old schools": [ + "old school" + ], + "vue cinema": [ + "vue" + ], + "wandlebury country park": [ + "the wandlebury" + ], + "the wandlebury": [ + "wandlebury country park" + ], + "whipple museum of the history of science": [ + "whipple museum", + "history of science museum" + ], + "history of science museum": [ + "whipple museum of the history of science", + "whipple museum" + ], + "williams art and antique": [ + "william's art and antique" + ], + "alimentum": [ + "restaurant alimentum" + ], + "restaurant alimentum": [ + "alimentum" + ], + "bedouin": [ + "the bedouin" + ], + "the bedouin": [ + "bedouin" + ], + "bloomsbury restaurant": [ + "bloomsbury" + ], + "cafe uno": [ + "caffe uno", + "caffee uno" + ], + "caffe uno": [ + "cafe uno", + "caffee uno" + ], + "caffee uno": [ + "cafe uno", + "caffe uno" + ], + "cambridge lodge restaurant": [ + "cambridge 
lodge" + ], + "chiquito": [ + "chiquito restaurant bar", + "chiquito restaurant" + ], + "chiquito restaurant bar": [ + "chiquito restaurant", + "chiquito" + ], + "city stop restaurant": [ + "city stop" + ], + "cityr": [ + "cityroomz" + ], + "citiroomz": [ + "cityroomz" + ], + "clowns cafe": [ + "clown's cafe" + ], + "cow pizza kitchen and bar": [ + "the cow pizza kitchen and bar", + "cow pizza" + ], + "the cow pizza kitchen and bar": [ + "cow pizza kitchen and bar", + "cow pizza" + ], + "darrys cookhouse and wine shop": [ + "darry's cookhouse and wine shop", + "darry's cookhouse", + "darrys cookhouse" + ], + "de luca cucina and bar": [ + "de luca cucina and bar riverside brasserie", + "luca cucina and bar", + "de luca cucina", + "luca cucina" + ], + "de luca cucina and bar riverside brasserie": [ + "de luca cucina and bar", + "luca cucina and bar", + "de luca cucina", + "luca cucina" + ], + "da vinci pizzeria": [ + "da vinci pizza", + "da vinci" + ], + "don pasquale pizzeria": [ + "don pasquale pizza", + "don pasquale", + "pasquale pizzeria", + "pasquale pizza" + ], + "efes": [ + "efes restaurant" + ], + "efes restaurant": [ + "efes" + ], + "fitzbillies restaurant": [ + "fitzbillies" + ], + "frankie and bennys": [ + "frankie and benny's" + ], + "funky": [ + "funky fun house" + ], + "funky fun house": [ + "funky" + ], + "gardenia": [ + "the gardenia" + ], + "the gardenia": [ + "gardenia" + ], + "grafton hotel restaurant": [ + "the grafton hotel", + "grafton hotel" + ], + "the grafton hotel": [ + "grafton hotel restaurant", + "grafton hotel" + ], + "grafton hotel": [ + "grafton hotel restaurant", + "the grafton hotel" + ], + "hotel du vin and bistro": [ + "hotel du vin", + "du vin" + ], + "Kohinoor": [ + "kohinoor", + "the kohinoor" + ], + "kohinoor": [ + "the kohinoor" + ], + "the kohinoor": [ + "kohinoor" + ], + "lan hong house": [ + "lan hong", + "ian hong house", + "ian hong" + ], + "ian hong": [ + "lan hong house", + "lan hong", + "ian hong house" + ], + "lovel": [ + "the lovell lodge", + "lovell lodge" + ], + "lovell lodge": [ + "lovell" + ], + "the lovell lodge": [ + "lovell lodge", + "lovell" + ], + "mahal of cambridge": [ + "mahal" + ], + "mahal": [ + "mahal of cambridge" + ], + "maharajah tandoori restaurant": [ + "maharajah tandoori" + ], + "the maharajah tandoor": [ + "maharajah tandoori restaurant", + "maharajah tandoori" + ], + "meze bar": [ + "meze bar restaurant", + "the meze bar" + ], + "meze bar restaurant": [ + "the meze bar", + "meze bar" + ], + "the meze bar": [ + "meze bar restaurant", + "meze bar" + ], + "michaelhouse cafe": [ + "michael house cafe" + ], + "midsummer house restaurant": [ + "midsummer house" + ], + "missing sock": [ + "the missing sock" + ], + "the missing sock": [ + "missing sock" + ], + "nandos": [ + "nando's city centre", + "nando's city center", + "nandos city centre", + "nandos city center", + "nando's" + ], + "nandos city centre": [ + "nando's city centre", + "nando's city center", + "nandos city center", + "nando's", + "nandos" + ], + "oak bistro": [ + "the oak bistro" + ], + "the oak bistro": [ + "oak bistro" + ], + "one seven": [ + "restaurant one seven" + ], + "restaurant one seven": [ + "one seven" + ], + "river bar steakhouse and grill": [ + "the river bar steakhouse and grill", + "the river bar steakhouse", + "river bar steakhouse" + ], + "the river bar steakhouse and grill": [ + "river bar steakhouse and grill", + "the river bar steakhouse", + "river bar steakhouse" + ], + "pipasha restaurant": [ + "pipasha" + ], + "pizza hut city centre": 
[ + "pizza hut city center" + ], + "pizza hut fenditton": [ + "pizza hut fen ditton", + "pizza express fen ditton" + ], + "restaurant two two": [ + "two two", + "restaurant 22" + ], + "saffron brasserie": [ + "saffron" + ], + "saint johns chop house": [ + "saint john's chop house", + "st john's chop house", + "st johns chop house" + ], + "sesame restaurant and bar": [ + "sesame restaurant", + "sesame" + ], + "shanghai": [ + "shanghai family restaurant" + ], + "shanghai family restaurant": [ + "shanghai" + ], + "sitar": [ + "sitar tandoori" + ], + "sitar tandoori": [ + "sitar" + ], + "slug and lettuce": [ + "the slug and lettuce" + ], + "the slug and lettuce": [ + "slug and lettuce" + ], + "st johns chop house": [ + "saint john's chop house", + "st john's chop house", + "saint johns chop house" + ], + "stazione restaurant and coffee bar": [ + "stazione restaurant", + "stazione" + ], + "thanh binh": [ + "thanh", + "binh" + ], + "thanh": [ + "thanh binh", + "binh" + ], + "binh": [ + "thanh binh", + "thanh" + ], + "the hotpot": [ + "the hotspot", + "hotpot", + "hotspot" + ], + "hotpot": [ + "the hotpot", + "the hotpot", + "hotspot" + ], + "the lucky star": [ + "lucky star" + ], + "lucky star": [ + "the lucky star" + ], + "the peking restaurant: ": [ + "peking restaurant" + ], + "the varsity restaurant": [ + "varsity restaurant", + "the varsity", + "varsity" + ], + "two two": [ + "restaurant two two", + "restaurant 22" + ], + "restaurant 22": [ + "restaurant two two", + "two two" + ], + "zizzi cambridge": [ + "zizzi" + ], + "american": [ + "americas" + ], + "asian oriental": [ + "asian", + "oriental" + ], + "australian": [ + "australasian" + ], + "barbeque": [ + "barbecue", + "bbq" + ], + "corsica": [ + "corsican" + ], + "indian": [ + "tandoori" + ], + "italian": [ + "pizza", + "pizzeria" + ], + "japanese": [ + "sushi" + ], + "latin american": [ + "latin-american", + "latin" + ], + "malaysian": [ + "malay" + ], + "middle eastern": [ + "middle-eastern" + ], + "traditional american": [ + "american" + ], + "modern american": [ + "american modern", + "american" + ], + "modern european": [ + "european modern", + "european" + ], + "north american": [ + "north-american", + "american" + ], + "portuguese": [ + "portugese" + ], + "portugese": [ + "portuguese" + ], + "seafood": [ + "sea food" + ], + "singaporean": [ + "singapore" + ], + "steakhouse": [ + "steak house", + "steak" + ], + "the americas": [ + "american", + "americas" + ], + "a and b guest house": [ + "a & b guest house", + "a and b", + "a & b" + ], + "the acorn guest house": [ + "acorn guest house", + "acorn" + ], + "acorn guest house": [ + "the acorn guest house", + "acorn" + ], + "alexander bed and breakfast": [ + "alexander" + ], + "allenbell": [ + "the allenbell" + ], + "the allenbell": [ + "allenbell" + ], + "alpha-milton guest house": [ + "the alpha-milton", + "alpha-milton" + ], + "the alpha-milton": [ + "alpha-milton guest house", + "alpha-milton" + ], + "arbury lodge guest house": [ + "arbury lodge", + "arbury" + ], + "archway house": [ + "archway" + ], + "ashley hotel": [ + "the ashley hotel", + "ashley" + ], + "the ashley hotel": [ + "ashley hotel", + "ashley" + ], + "aylesbray lodge guest house": [ + "aylesbray lodge", + "aylesbray" + ], + "aylesbray lodge guest": [ + "aylesbray lodge guest house", + "aylesbray lodge", + "aylesbray" + ], + "alesbray lodge guest house": [ + "aylesbray lodge guest house", + "aylesbray lodge", + "aylesbray" + ], + "alyesbray lodge hotel": [ + "aylesbray lodge guest house", + "aylesbray lodge", + 
"aylesbray" + ], + "bridge guest house": [ + "bridge house" + ], + "cambridge belfry": [ + "the cambridge belfry", + "belfry hotel", + "belfry" + ], + "the cambridge belfry": [ + "cambridge belfry", + "belfry hotel", + "belfry" + ], + "belfry hotel": [ + "the cambridge belfry", + "cambridge belfry", + "belfry" + ], + "carolina bed and breakfast": [ + "carolina" + ], + "city centre north": [ + "city centre north bed and breakfast" + ], + "north b and b": [ + "city centre north bed and breakfast" + ], + "city centre north b and b": [ + "city centre north bed and breakfast" + ], + "el shaddia guest house": [ + "el shaddai guest house", + "el shaddai", + "el shaddia" + ], + "el shaddai guest house": [ + "el shaddia guest house", + "el shaddai", + "el shaddia" + ], + "express by holiday inn cambridge": [ + "express by holiday inn", + "holiday inn cambridge", + "holiday inn" + ], + "holiday inn": [ + "express by holiday inn cambridge", + "express by holiday inn", + "holiday inn cambridge" + ], + "finches bed and breakfast": [ + "finches" + ], + "gonville hotel": [ + "gonville" + ], + "hamilton lodge": [ + "the hamilton lodge", + "hamilton" + ], + "the hamilton lodge": [ + "hamilton lodge", + "hamilton" + ], + "hobsons house": [ + "hobson's house", + "hobson's" + ], + "huntingdon marriott hotel": [ + "huntington marriott hotel", + "huntington marriot hotel", + "huntingdon marriot hotel", + "huntington marriott", + "huntingdon marriott", + "huntington marriot", + "huntingdon marriot", + "huntington", + "huntingdon" + ], + "kirkwood": [ + "kirkwood house" + ], + "kirkwood house": [ + "kirkwood" + ], + "lensfield hotel": [ + "the lensfield hotel", + "lensfield" + ], + "the lensfield hotel": [ + "lensfield hotel", + "lensfield" + ], + "leverton house": [ + "leverton" + ], + "marriot hotel": [ + "marriott hotel", + "marriott" + ], + "rosas bed and breakfast": [ + "rosa's bed and breakfast", + "rosa's", + "rosas" + ], + "university arms hotel": [ + "university arms" + ], + "warkworth house": [ + "warkworth hotel", + "warkworth" + ], + "warkworth hotel": [ + "warkworth house", + "warkworth" + ], + "wartworth": [ + "warkworth house", + "warkworth hotel", + "warkworth" + ], + "worth house": [ + "the worth house" + ], + "the worth house": [ + "worth house" + ], + "birmingham new street": [ + "birmingham new street train station" + ], + "birmingham new street train station": [ + "birmingham new street" + ], + "bishops stortford": [ + "bishops stortford train station" + ], + "bishops stortford train station": [ + "bishops stortford" + ], + "broxbourne": [ + "broxbourne train station" + ], + "broxbourne train station": [ + "broxbourne" + ], + "cambridge": [ + "cambridge train station" + ], + "cambridge train station": [ + "cambridge" + ], + "ely": [ + "ely train station" + ], + "ely train station": [ + "ely" + ], + "kings lynn": [ + "king's lynn", + "king's lynn train station", + "kings lynn train station" + ], + "kings lynn train station": [ + "kings lynn", + "king's lynn", + "king's lynn train station" + ], + "leicester": [ + "leicester train station" + ], + "leicester train station": [ + "leicester" + ], + "london kings cross": [ + "kings cross", + "king's cross", + "london king's cross", + "kings cross train station", + "king's cross train station", + "london king's cross train station", + "london kings cross train station" + ], + "london kings cross train station": [ + "kings cross", + "king's cross", + "london king's cross", + "london kings cross", + "kings cross train station", + "king's cross train 
station", + "london king's cross train station" + ], + "london liverpool": [ + "liverpool street", + "london liverpool street", + "london liverpool train station", + "liverpool street train station", + "london liverpool street train station" + ], + "london liverpool street": [ + "london liverpool", + "liverpool street", + "london liverpool train station", + "liverpool street train station", + "london liverpool street train station" + ], + "london liverpool street train station": [ + "london liverpool", + "liverpool street", + "london liverpool street", + "london liverpool train station", + "liverpool street train station" + ], + "norwich": [ + "norwich train station" + ], + "norwich train station": [ + "norwich" + ], + "peterborough": [ + "peterborough train station" + ], + "peterborough train station": [ + "peterborough" + ], + "stansted airport": [ + "stansted airport train station" + ], + "stansted airport train station": [ + "stansted airport" + ], + "stevenage": [ + "stevenage train station" + ], + "stevenage train station": [ + "stevenage" + ] + } +} diff --git a/dataset_config/sim-m.json b/dataset_config/sim-m.json new file mode 100644 index 0000000000000000000000000000000000000000..34793d75da2a58876eb98ebf0aa8441069d6a300 --- /dev/null +++ b/dataset_config/sim-m.json @@ -0,0 +1,16 @@ +{ + "class_types": [ + "none", + "dontcare", + "copy_value", + "inform" + ], + "slots": [ + "date", + "movie", + "time", + "num_tickets", + "theatre_name" + ], + "label_maps": {} +} diff --git a/dataset_config/sim-r.json b/dataset_config/sim-r.json new file mode 100644 index 0000000000000000000000000000000000000000..cc635e174e8ef50d33abe5b6ee12f868aca1155a --- /dev/null +++ b/dataset_config/sim-r.json @@ -0,0 +1,20 @@ +{ + "class_types": [ + "none", + "dontcare", + "copy_value", + "inform" + ], + "slots": [ + "category", + "rating", + "num_people", + "location", + "restaurant_name", + "time", + "date", + "price_range", + "meal" + ], + "label_maps": {} +} diff --git a/dataset_config/woz2.json b/dataset_config/woz2.json new file mode 100644 index 0000000000000000000000000000000000000000..361446e7c9936f99643e8e91ba449578fa26c868 --- /dev/null +++ b/dataset_config/woz2.json @@ -0,0 +1,212 @@ +{ + "class_types": [ + "none", + "dontcare", + "copy_value", + "inform" + ], + "slots": [ + "area", + "food", + "price_range" + ], + "label_maps": { + "center": [ + "centre", + "downtown", + "central", + "down town", + "middle" + ], + "south": [ + "southern", + "southside" + ], + "north": [ + "northern", + "uptown", + "northside" + ], + "west": [ + "western", + "westside" + ], + "east": [ + "eastern", + "eastside" + ], + "east side": [ + "eastern", + "eastside" + ], + "cheap": [ + "low price", + "inexpensive", + "cheaper", + "low priced", + "affordable", + "nothing too expensive", + "without costing a fortune", + "cheapest", + "good deals", + "low prices", + "afford", + "on a budget", + "fair prices", + "less expensive", + "cheapeast", + "not cost an arm and a leg" + ], + "moderate": [ + "moderately", + "medium priced", + "medium price", + "fair price", + "fair prices", + "reasonable", + "reasonably priced", + "mid price", + "fairly priced", + "not outrageous", + "not too expensive", + "on a budget", + "mid range", + "reasonable priced", + "less expensive", + "not too pricey", + "nothing too expensive", + "nothing cheap", + "not overpriced", + "medium", + "inexpensive" + ], + "expensive": [ + "high priced", + "high end", + "high class", + "high quality", + "fancy", + "upscale", + "nice", + "fine dining", + 
"expensively priced" + ], + "afghan": [ + "afghanistan" + ], + "african": [ + "africa" + ], + "asian oriental": [ + "asian", + "oriental" + ], + "australasian": [ + "australian asian", + "austral asian" + ], + "australian": [ + "aussie" + ], + "barbeque": [ + "barbecue", + "bbq" + ], + "basque": [ + "bask" + ], + "belgian": [ + "belgium" + ], + "british": [ + "cotto" + ], + "canapes": [ + "canopy", + "canape", + "canap" + ], + "catalan": [ + "catalonian" + ], + "corsican": [ + "corsica" + ], + "crossover": [ + "cross over", + "over" + ], + "gastropub": [ + "gastro pub", + "gastro", + "gastropubs" + ], + "hungarian": [ + "goulash" + ], + "indian": [ + "india", + "indians", + "nirala" + ], + "international": [ + "all types of food" + ], + "italian": [ + "prezzo" + ], + "jamaican": [ + "jamaica" + ], + "japanese": [ + "sushi", + "beni hana" + ], + "korean": [ + "korea" + ], + "lebanese": [ + "lebanse" + ], + "north american": [ + "american", + "hamburger" + ], + "portuguese": [ + "portugese" + ], + "seafood": [ + "sea food", + "shellfish", + "fish" + ], + "singaporean": [ + "singapore" + ], + "steakhouse": [ + "steak house", + "steak" + ], + "thai": [ + "thailand", + "bangkok" + ], + "traditional": [ + "old fashioned", + "plain" + ], + "turkish": [ + "turkey" + ], + "unusual": [ + "unique and strange" + ], + "venetian": [ + "vanessa" + ], + "vietnamese": [ + "vietnam", + "thanh binh" + ] + } +} diff --git a/dataset_multiwoz21.py b/dataset_multiwoz21.py new file mode 100644 index 0000000000000000000000000000000000000000..9751f5c8a945c8f84004ac2a9eea3edab9d44852 --- /dev/null +++ b/dataset_multiwoz21.py @@ -0,0 +1,541 @@ +# coding=utf-8 +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re + +from utils_dst import (DSTExample, convert_to_unicode) + + +# Required for mapping slot names in dialogue_acts.json file +# to proper designations. +ACTS_DICT = {'taxi-depart': 'taxi-departure', + 'taxi-dest': 'taxi-destination', + 'taxi-leave': 'taxi-leaveAt', + 'taxi-arrive': 'taxi-arriveBy', + 'train-depart': 'train-departure', + 'train-dest': 'train-destination', + 'train-leave': 'train-leaveAt', + 'train-arrive': 'train-arriveBy', + 'train-people': 'train-book_people', + 'restaurant-price': 'restaurant-pricerange', + 'restaurant-people': 'restaurant-book_people', + 'restaurant-day': 'restaurant-book_day', + 'restaurant-time': 'restaurant-book_time', + 'hotel-price': 'hotel-pricerange', + 'hotel-people': 'hotel-book_people', + 'hotel-day': 'hotel-book_day', + 'hotel-stay': 'hotel-book_stay', + 'booking-people': 'booking-book_people', + 'booking-day': 'booking-book_day', + 'booking-stay': 'booking-book_stay', + 'booking-time': 'booking-book_time', +} + + +LABEL_MAPS = {} # Loaded from file + + +# Loads the dialogue_acts.json and returns a list +# of slot-value pairs. 
+def load_acts(input_file): + with open(input_file) as f: + acts = json.load(f) + s_dict = {} + for d in acts: + for t in acts[d]: + # Only process, if turn has annotation + if isinstance(acts[d][t], dict): + for a in acts[d][t]: + aa = a.lower().split('-') + if aa[1] == 'inform' or aa[1] == 'recommend' or aa[1] == 'select' or aa[1] == 'book': + for i in acts[d][t][a]: + s = i[0].lower() + v = i[1].lower().strip() + if s == 'none' or v == '?' or v == 'none': + continue + slot = aa[0] + '-' + s + if slot in ACTS_DICT: + slot = ACTS_DICT[slot] + key = d + '.json', t, slot + # In case of multiple mentioned values... + # ... Option 1: Keep first informed value + if key not in s_dict: + s_dict[key] = list([v]) + # ... Option 2: Keep last informed value + #s_dict[key] = list([v]) + return s_dict + + +def normalize_time(text): + text = re.sub("(\d{1})(a\.?m\.?|p\.?m\.?)", r"\1 \2", text) # am/pm without space + text = re.sub("(^| )(\d{1,2}) (a\.?m\.?|p\.?m\.?)", r"\1\2:00 \3", text) # am/pm short to long form + text = re.sub("(^| )(at|from|by|until|after) ?(\d{1,2}) ?(\d{2})([^0-9]|$)", r"\1\2 \3:\4\5", text) # Missing separator + text = re.sub("(^| )(\d{2})[;.,](\d{2})", r"\1\2:\3", text) # Wrong separator + text = re.sub("(^| )(at|from|by|until|after) ?(\d{1,2})([;., ]|$)", r"\1\2 \3:00\4", text) # normalize simple full hour time + text = re.sub("(^| )(\d{1}:\d{2})", r"\g<1>0\2", text) # Add missing leading 0 + # Map 12 hour times to 24 hour times + text = re.sub("(\d{2})(:\d{2}) ?p\.?m\.?", lambda x: str(int(x.groups()[0]) + 12 if int(x.groups()[0]) < 12 else int(x.groups()[0])) + x.groups()[1], text) + text = re.sub("(^| )24:(\d{2})", r"\g<1>00:\2", text) # Correct times that use 24 as hour + return text + + +def normalize_text(text): + text = normalize_time(text) + text = re.sub("n't", " not", text) + text = re.sub("(^| )zero(-| )star([s.,? ]|$)", r"\g<1>0 star\3", text) + text = re.sub("(^| )one(-| )star([s.,? ]|$)", r"\g<1>1 star\3", text) + text = re.sub("(^| )two(-| )star([s.,? ]|$)", r"\g<1>2 star\3", text) + text = re.sub("(^| )three(-| )star([s.,? ]|$)", r"\g<1>3 star\3", text) + text = re.sub("(^| )four(-| )star([s.,? ]|$)", r"\g<1>4 star\3", text) + text = re.sub("(^| )five(-| )star([s.,? ]|$)", r"\g<1>5 star\3", text) + text = re.sub("archaelogy", "archaeology", text) # Systematic typo + text = re.sub("guesthouse", "guest house", text) # Normalization + text = re.sub("(^| )b ?& ?b([.,? ]|$)", r"\1bed and breakfast\2", text) # Normalization + text = re.sub("bed & breakfast", "bed and breakfast", text) # Normalization + return text + + +# This should only contain label normalizations. All other mappings should +# be defined in LABEL_MAPS. 
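+# Illustrative examples of the normalizations performed below:
+#     normalize_label('hotel-parking', 'yes')             -> 'true'
+#     normalize_label('hotel-type', 'guest house')        -> 'false'
+#     normalize_label('restaurant-food', 'not mentioned') -> 'none'
+#     normalize_label('restaurant-book_time', '5 pm')     -> '17:00'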
+def normalize_label(slot, value_label): + # Normalization of empty slots + if value_label == '' or value_label == "not mentioned": + return "none" + + # Normalization of time slots + if "leaveAt" in slot or "arriveBy" in slot or slot == 'restaurant-book_time': + return normalize_time(value_label) + + # Normalization + if "type" in slot or "name" in slot or "destination" in slot or "departure" in slot: + value_label = re.sub("guesthouse", "guest house", value_label) + + # Map to boolean slots + if slot == 'hotel-parking' or slot == 'hotel-internet': + if value_label == 'yes' or value_label == 'free': + return "true" + if value_label == "no": + return "false" + if slot == 'hotel-type': + if value_label == "hotel": + return "true" + if value_label == "guest house": + return "false" + + return value_label + + +def get_token_pos(tok_list, value_label): + find_pos = [] + found = False + label_list = [item for item in map(str.strip, re.split("(\W+)", value_label)) if len(item) > 0] + len_label = len(label_list) + for i in range(len(tok_list) + 1 - len_label): + if tok_list[i:i + len_label] == label_list: + find_pos.append((i, i + len_label)) # start, exclusive_end + found = True + return found, find_pos + + +def check_label_existence(value_label, usr_utt_tok): + in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label) + # If no hit even though there should be one, check for value label variants + if not in_usr and value_label in LABEL_MAPS: + for value_label_variant in LABEL_MAPS[value_label]: + in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label_variant) + if in_usr: + break + return in_usr, usr_pos + + +def check_slot_referral(value_label, slot, seen_slots): + referred_slot = 'none' + if slot == 'hotel-stars' or slot == 'hotel-internet' or slot == 'hotel-parking': + return referred_slot + for s in seen_slots: + # Avoid matches for slots that share values with different meaning. + # hotel-internet and -parking are handled separately as Boolean slots. 
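+        # (Example: a value like "4" could fill hotel-stars as well as
+        # hotel-book_people, so matching on it would be ambiguous.)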
+ if s == 'hotel-stars' or s == 'hotel-internet' or s == 'hotel-parking': + continue + if re.match("(hotel|restaurant)-book_people", s) and slot == 'hotel-book_stay': + continue + if re.match("(hotel|restaurant)-book_people", slot) and s == 'hotel-book_stay': + continue + if slot != s and (slot not in seen_slots or seen_slots[slot] != value_label): + if seen_slots[s] == value_label: + referred_slot = s + break + elif value_label in LABEL_MAPS: + for value_label_variant in LABEL_MAPS[value_label]: + if seen_slots[s] == value_label_variant: + referred_slot = s + break + return referred_slot + + +def is_in_list(tok, value): + found = False + tok_list = [item for item in map(str.strip, re.split("(\W+)", tok)) if len(item) > 0] + value_list = [item for item in map(str.strip, re.split("(\W+)", value)) if len(item) > 0] + tok_len = len(tok_list) + value_len = len(value_list) + for i in range(tok_len + 1 - value_len): + if tok_list[i:i + value_len] == value_list: + found = True + break + return found + + +def delex_utt(utt, values): + utt_norm = tokenize(utt) + for s, vals in values.items(): + for v in vals: + if v != 'none': + v_norm = tokenize(v) + v_len = len(v_norm) + for i in range(len(utt_norm) + 1 - v_len): + if utt_norm[i:i + v_len] == v_norm: + utt_norm[i:i + v_len] = ['[UNK]'] * v_len + return utt_norm + + +# Fuzzy matching to label informed slot values +def check_slot_inform(value_label, inform_label): + result = False + informed_value = 'none' + vl = ' '.join(tokenize(value_label)) + for il in inform_label: + if vl == il: + result = True + elif is_in_list(il, vl): + result = True + elif is_in_list(vl, il): + result = True + elif il in LABEL_MAPS: + for il_variant in LABEL_MAPS[il]: + if vl == il_variant: + result = True + break + elif is_in_list(il_variant, vl): + result = True + break + elif is_in_list(vl, il_variant): + result = True + break + elif vl in LABEL_MAPS: + for value_label_variant in LABEL_MAPS[vl]: + if value_label_variant == il: + result = True + break + elif is_in_list(il, value_label_variant): + result = True + break + elif is_in_list(value_label_variant, il): + result = True + break + if result: + informed_value = il + break + return result, informed_value + + +def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, seen_slots, slot_last_occurrence): + usr_utt_tok_label = [0 for _ in usr_utt_tok] + informed_value = 'none' + referred_slot = 'none' + if value_label == 'none' or value_label == 'dontcare' or value_label == 'true' or value_label == 'false': + class_type = value_label + else: + in_usr, usr_pos = check_label_existence(value_label, usr_utt_tok) + if in_usr: + class_type = 'copy_value' + if slot_last_occurrence: + (s, e) = usr_pos[-1] + for i in range(s, e): + usr_utt_tok_label[i] = 1 + else: + for (s, e) in usr_pos: + for i in range(s, e): + usr_utt_tok_label[i] = 1 + else: + is_informed, informed_value = check_slot_inform(value_label, inform_label) + if is_informed: + class_type = 'inform' + else: + referred_slot = check_slot_referral(value_label, slot, seen_slots) + if referred_slot != 'none': + class_type = 'refer' + else: + class_type = 'unpointable' + return informed_value, referred_slot, usr_utt_tok_label, class_type + + +def tokenize(utt): + utt_lower = convert_to_unicode(utt).lower() + utt_lower = normalize_text(utt_lower) + utt_tok = [tok for tok in map(str.strip, re.split("(\W+)", utt_lower)) if len(tok) > 0] + return utt_tok + + +def create_examples(input_file, acts_file, set_type, slot_list, + label_maps={}, + 
append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + analyze=False): + """Read a DST json file into a list of DSTExample.""" + + sys_inform_dict = load_acts(acts_file) + + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader) + + global LABEL_MAPS + LABEL_MAPS = label_maps + + examples = [] + for dialog_id in input_data: + entry = input_data[dialog_id] + utterances = entry['log'] + + # Collects all slot changes throughout the dialog + cumulative_labels = {slot: 'none' for slot in slot_list} + + # First system utterance is empty, since multiwoz starts with user input + utt_tok_list = [[]] + mod_slots_list = [{}] + + # Collect all utterances and their metadata + usr_sys_switch = True + turn_itr = 0 + for utt in utterances: + # Assert that system and user utterances alternate + is_sys_utt = utt['metadata'] != {} + if usr_sys_switch == is_sys_utt: + print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id)) + break + usr_sys_switch = is_sys_utt + + if is_sys_utt: + turn_itr += 1 + + # Delexicalize sys utterance + if delexicalize_sys_utts and is_sys_utt: + inform_dict = {slot: 'none' for slot in slot_list} + for slot in slot_list: + if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict: + inform_dict[slot] = sys_inform_dict[(str(dialog_id), str(turn_itr), slot)] + utt_tok_list.append(delex_utt(utt['text'], inform_dict)) # normalize utterances + else: + utt_tok_list.append(tokenize(utt['text'])) # normalize utterances + + modified_slots = {} + + # If sys utt, extract metadata (identify and collect modified slots) + if is_sys_utt: + for d in utt['metadata']: + booked = utt['metadata'][d]['book']['booked'] + booked_slots = {} + # Check the booked section + if booked != []: + for s in booked[0]: + booked_slots[s] = normalize_label('%s-%s' % (d, s), booked[0][s]) # normalize labels + # Check the semi and the inform slots + for category in ['book', 'semi']: + for s in utt['metadata'][d][category]: + cs = '%s-book_%s' % (d, s) if category == 'book' else '%s-%s' % (d, s) + value_label = normalize_label(cs, utt['metadata'][d][category][s]) # normalize labels + # Prefer the slot value as stored in the booked section + if s in booked_slots: + value_label = booked_slots[s] + # Remember modified slots and entire dialog state + if cs in slot_list and cumulative_labels[cs] != value_label: + modified_slots[cs] = value_label + cumulative_labels[cs] = value_label + + mod_slots_list.append(modified_slots.copy()) + + # Form proper (usr, sys) turns + turn_itr = 0 + diag_seen_slots_dict = {} + diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list} + diag_state = {slot: 'none' for slot in slot_list} + sys_utt_tok = [] + usr_utt_tok = [] + hst_utt_tok = [] + hst_utt_tok_label_dict = {slot: [] for slot in slot_list} + for i in range(1, len(utt_tok_list) - 1, 2): + sys_utt_tok_label_dict = {} + usr_utt_tok_label_dict = {} + value_dict = {} + inform_dict = {} + inform_slot_dict = {} + referral_dict = {} + class_type_dict = {} + + # Collect turn data + if append_history: + if swap_utterances: + hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok + else: + hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok + sys_utt_tok = utt_tok_list[i - 1] + usr_utt_tok = utt_tok_list[i] + turn_slots = mod_slots_list[i + 1] + + guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr)) + + if analyze: + print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, 
' '.join(sys_utt_tok), ' '.join(usr_utt_tok))) + print("%15s %2s [" % (dialog_id, turn_itr), end='') + + new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() + new_diag_state = diag_state.copy() + for slot in slot_list: + value_label = 'none' + if slot in turn_slots: + value_label = turn_slots[slot] + # We keep the original labels so as to not + # overlook unpointable values, as well as to not + # modify any of the original labels for test sets, + # since this would make comparison difficult. + value_dict[slot] = value_label + elif label_value_repetitions and slot in diag_seen_slots_dict: + value_label = diag_seen_slots_value_dict[slot] + + # Get dialog act annotations + inform_label = list(['none']) + if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict: + inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]]) + elif (str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1]) in sys_inform_dict: + inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1])]]) + + (informed_value, + referred_slot, + usr_utt_tok_label, + class_type) = get_turn_label(value_label, + inform_label, + sys_utt_tok, + usr_utt_tok, + slot, + diag_seen_slots_value_dict, + slot_last_occurrence=True) + + inform_dict[slot] = informed_value + if informed_value != 'none': + inform_slot_dict[slot] = 1 + else: + inform_slot_dict[slot] = 0 + + # Generally don't use span prediction on sys utterance (but inform prediction instead). + sys_utt_tok_label = [0 for _ in sys_utt_tok] + + # Determine what to do with value repetitions. + # If value is unique in seen slots, then tag it, otherwise not, + # since correct slot assignment can not be guaranteed anymore. + if label_value_repetitions and slot in diag_seen_slots_dict: + if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1: + class_type = 'none' + usr_utt_tok_label = [0 for _ in usr_utt_tok_label] + + sys_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label_dict[slot] = usr_utt_tok_label + + if append_history: + if use_history_labels: + if swap_utterances: + new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot] + else: + new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot] + else: + new_hst_utt_tok_label_dict[slot] = [0 for _ in sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]] + + # For now, we map all occurences of unpointable slot values + # to none. However, since the labels will still suggest + # a presence of unpointable slot values, the task of the + # DST is still to find those values. It is just not + # possible to do that via span prediction on the current input. + if class_type == 'unpointable': + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + if analyze: + if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]: + print("(%s): %s, " % (slot, value_label), end='') + elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform': + # If slot has seen before and its class type did not change, label this slot a not present, + # assuming that the slot has not actually been mentioned in this turn. + # Exceptions are copy_value and inform. 
If a seen slot has been tagged as copy_value or inform, + # this must mean there is evidence in the original labels, therefore consider + # them as mentioned again. + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + else: + class_type_dict[slot] = class_type + referral_dict[slot] = referred_slot + # Remember that this slot was mentioned during this dialog already. + if class_type != 'none': + diag_seen_slots_dict[slot] = class_type + diag_seen_slots_value_dict[slot] = value_label + new_diag_state[slot] = class_type + # Unpointable is not a valid class, therefore replace with + # some valid class for now... + if class_type == 'unpointable': + new_diag_state[slot] = 'copy_value' + + if analyze: + print("]") + + if swap_utterances: + txt_a = usr_utt_tok + txt_b = sys_utt_tok + txt_a_lbl = usr_utt_tok_label_dict + txt_b_lbl = sys_utt_tok_label_dict + else: + txt_a = sys_utt_tok + txt_b = usr_utt_tok + txt_a_lbl = sys_utt_tok_label_dict + txt_b_lbl = usr_utt_tok_label_dict + examples.append(DSTExample( + guid=guid, + text_a=txt_a, + text_b=txt_b, + history=hst_utt_tok, + text_a_label=txt_a_lbl, + text_b_label=txt_b_lbl, + history_label=hst_utt_tok_label_dict, + values=diag_seen_slots_value_dict.copy(), + inform_label=inform_dict, + inform_slot_label=inform_slot_dict, + refer_label=referral_dict, + diag_state=diag_state, + class_label=class_type_dict)) + + # Update some variables. + hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy() + diag_state = new_diag_state.copy() + + turn_itr += 1 + + if analyze: + print("----------------------------------------------------------------------") + + return examples diff --git a/dataset_sim.py b/dataset_sim.py new file mode 100644 index 0000000000000000000000000000000000000000..e30ed09cd9c289778dcd0cb78546780c7df28493 --- /dev/null +++ b/dataset_sim.py @@ -0,0 +1,230 @@ +# coding=utf-8 +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
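+#
+# Expected input format, as inferred from the accessors below (illustrative,
+# not an official schema): each dialogue carries 'dialogue_id' and 'turns';
+# every turn provides 'user_utterance' (and, when present, 'system_utterance')
+# with 'tokens' and 'slots' span annotations ('start'/'exclusive_end'), plus
+# the cumulative 'dialogue_state'.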
+ +import json + +from utils_dst import (DSTExample) + + +def dialogue_state_to_sv_dict(sv_list): + sv_dict = {} + for d in sv_list: + sv_dict[d['slot']] = d['value'] + return sv_dict + + +def get_token_and_slot_label(turn): + if 'system_utterance' in turn: + sys_utt_tok = turn['system_utterance']['tokens'] + sys_slot_label = turn['system_utterance']['slots'] + else: + sys_utt_tok = [] + sys_slot_label = [] + + usr_utt_tok = turn['user_utterance']['tokens'] + usr_slot_label = turn['user_utterance']['slots'] + return sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label + + +def get_tok_label(prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok, + sys_slot_label, usr_utt_tok, usr_slot_label, dial_id, + turn_id, slot_last_occurrence=True): + """The position of the last occurrence of the slot value will be used.""" + sys_utt_tok_label = [0 for _ in sys_utt_tok] + usr_utt_tok_label = [0 for _ in usr_utt_tok] + if slot_type not in cur_ds_dict: + class_type = 'none' + else: + value = cur_ds_dict[slot_type] + if value == 'dontcare' and (slot_type not in prev_ds_dict or prev_ds_dict[slot_type] != 'dontcare'): + # Only label dontcare at its first occurrence in the dialog + class_type = 'dontcare' + else: # If not none or dontcare, we have to identify whether + # there is a span, or if the value is informed + in_usr = False + in_sys = False + for label_d in usr_slot_label: + if label_d['slot'] == slot_type and value == ' '.join( + usr_utt_tok[label_d['start']:label_d['exclusive_end']]): + + for idx in range(label_d['start'], label_d['exclusive_end']): + usr_utt_tok_label[idx] = 1 + in_usr = True + class_type = 'copy_value' + if slot_last_occurrence: + break + if not in_usr or not slot_last_occurrence: + for label_d in sys_slot_label: + if label_d['slot'] == slot_type and value == ' '.join( + sys_utt_tok[label_d['start']:label_d['exclusive_end']]): + for idx in range(label_d['start'], label_d['exclusive_end']): + sys_utt_tok_label[idx] = 1 + in_sys = True + class_type = 'inform' + if slot_last_occurrence: + break + if not in_usr and not in_sys: + assert sum(usr_utt_tok_label + sys_utt_tok_label) == 0 + if (slot_type not in prev_ds_dict or value != prev_ds_dict[slot_type]): + raise ValueError('Copy value cannot found in Dial %s Turn %s' % (str(dial_id), str(turn_id))) + else: + class_type = 'none' + else: + assert sum(usr_utt_tok_label + sys_utt_tok_label) > 0 + return sys_utt_tok_label, usr_utt_tok_label, class_type + + +def delex_utt(utt, values): + utt_delex = utt.copy() + for v in values: + utt_delex[v['start']:v['exclusive_end']] = ['[UNK]'] * (v['exclusive_end'] - v['start']) + return utt_delex + + +def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, + delexicalize_sys_utts=False, slot_last_occurrence=True): + """Make turn_label a dictionary of slot with value positions or being dontcare / none: + Turn label contains: + (1) the updates from previous to current dialogue state, + (2) values in current dialogue state explicitly mentioned in system or user utterance.""" + prev_ds_dict = dialogue_state_to_sv_dict(prev_dialogue_state) + cur_ds_dict = dialogue_state_to_sv_dict(turn['dialogue_state']) + + (sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label) = get_token_and_slot_label(turn) + + sys_utt_tok_label_dict = {} + usr_utt_tok_label_dict = {} + inform_label_dict = {} + inform_slot_label_dict = {} + referral_label_dict = {} + class_type_dict = {} + + for slot_type in slot_list: + inform_label_dict[slot_type] = 'none' + inform_slot_label_dict[slot_type] = 0 + 
referral_label_dict[slot_type] = 'none' # Referral is not present in sim data + sys_utt_tok_label, usr_utt_tok_label, class_type = get_tok_label( + prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok, sys_slot_label, + usr_utt_tok, usr_slot_label, dial_id, turn_id, + slot_last_occurrence=slot_last_occurrence) + if sum(sys_utt_tok_label) > 0: + inform_label_dict[slot_type] = cur_ds_dict[slot_type] + inform_slot_label_dict[slot_type] = 1 + sys_utt_tok_label = [0 for _ in sys_utt_tok_label] # Don't use token labels for sys utt + sys_utt_tok_label_dict[slot_type] = sys_utt_tok_label + usr_utt_tok_label_dict[slot_type] = usr_utt_tok_label + class_type_dict[slot_type] = class_type + + if delexicalize_sys_utts: + sys_utt_tok = delex_utt(sys_utt_tok, sys_slot_label) + + return (sys_utt_tok, sys_utt_tok_label_dict, + usr_utt_tok, usr_utt_tok_label_dict, + inform_label_dict, inform_slot_label_dict, + referral_label_dict, cur_ds_dict, class_type_dict) + + +def create_examples(input_file, set_type, slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + analyze=False): + """Read a DST json file into a list of DSTExample.""" + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader) + + examples = [] + for entry in input_data: + dial_id = entry['dialogue_id'] + prev_ds = [] + hst = [] + prev_hst_lbl_dict = {slot: [] for slot in slot_list} + prev_ds_lbl_dict = {slot: 'none' for slot in slot_list} + + for turn_id, turn in enumerate(entry['turns']): + guid = '%s-%s-%s' % (set_type, dial_id, str(turn_id)) + ds_lbl_dict = prev_ds_lbl_dict.copy() + hst_lbl_dict = prev_hst_lbl_dict.copy() + (text_a, + text_a_label, + text_b, + text_b_label, + inform_label, + inform_slot_label, + referral_label, + cur_ds_dict, + class_label) = get_turn_label(turn, + prev_ds, + slot_list, + dial_id, + turn_id, + delexicalize_sys_utts=delexicalize_sys_utts, + slot_last_occurrence=True) + + if swap_utterances: + txt_a = text_b + txt_b = text_a + txt_a_lbl = text_b_label + txt_b_lbl = text_a_label + else: + txt_a = text_a + txt_b = text_b + txt_a_lbl = text_a_label + txt_b_lbl = text_b_label + + value_dict = {} + for slot in slot_list: + if slot in cur_ds_dict: + value_dict[slot] = cur_ds_dict[slot] + else: + value_dict[slot] = 'none' + if class_label[slot] != 'none': + ds_lbl_dict[slot] = class_label[slot] + if append_history: + if use_history_labels: + hst_lbl_dict[slot] = txt_a_lbl[slot] + txt_b_lbl[slot] + hst_lbl_dict[slot] + else: + hst_lbl_dict[slot] = [0 for _ in txt_a_lbl[slot] + txt_b_lbl[slot] + hst_lbl_dict[slot]] + + examples.append(DSTExample( + guid=guid, + text_a=txt_a, + text_b=txt_b, + history=hst, + text_a_label=txt_a_lbl, + text_b_label=txt_b_lbl, + history_label=prev_hst_lbl_dict, + values=value_dict, + inform_label=inform_label, + inform_slot_label=inform_slot_label, + refer_label=referral_label, + diag_state=prev_ds_lbl_dict, + class_label=class_label)) + + prev_ds = turn['dialogue_state'] + prev_ds_lbl_dict = ds_lbl_dict.copy() + prev_hst_lbl_dict = hst_lbl_dict.copy() + + if append_history: + hst = txt_a + txt_b + hst + + return examples diff --git a/dataset_woz2.py b/dataset_woz2.py new file mode 100644 index 0000000000000000000000000000000000000000..c32af05afd55bc99c4f4c84adea711991c12c10e --- /dev/null +++ b/dataset_woz2.py @@ -0,0 +1,278 @@ +# coding=utf-8 +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source 
code of BERT-DST +# (arXiv:1907.03040) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re + +from utils_dst import (DSTExample, convert_to_unicode) + + +LABEL_MAPS = {} # Loaded from file +LABEL_FIX = {'centre': 'center', 'areas': 'area', 'phone number': 'number', 'price range': 'price_range'} + + +def delex_utt(utt, values): + utt_norm = utt.copy() + for s, v in values.items(): + if v != 'none': + v_norm = tokenize(v) + v_len = len(v_norm) + for i in range(len(utt_norm) + 1 - v_len): + if utt_norm[i:i + v_len] == v_norm: + utt_norm[i:i + v_len] = ['[UNK]'] * v_len + return utt_norm + + +def get_token_pos(tok_list, label): + find_pos = [] + found = False + label_list = [item for item in map(str.strip, re.split("(\W+)", label)) if len(item) > 0] + len_label = len(label_list) + for i in range(len(tok_list) + 1 - len_label): + if tok_list[i:i + len_label] == label_list: + find_pos.append((i, i + len_label)) # start, exclusive_end + found = True + return found, find_pos + + +def check_label_existence(label, usr_utt_tok, sys_utt_tok): + in_usr, usr_pos = get_token_pos(usr_utt_tok, label) + if not in_usr and label in LABEL_MAPS: + for tmp_label in LABEL_MAPS[label]: + in_usr, usr_pos = get_token_pos(usr_utt_tok, tmp_label) + if in_usr: + break + in_sys, sys_pos = get_token_pos(sys_utt_tok, label) + if not in_sys and label in LABEL_MAPS: + for tmp_label in LABEL_MAPS[label]: + in_sys, sys_pos = get_token_pos(sys_utt_tok, tmp_label) + if in_sys: + break + return in_usr, usr_pos, in_sys, sys_pos + + +def get_turn_label(label, sys_utt_tok, usr_utt_tok, slot_last_occurrence): + usr_utt_tok_label = [0 for _ in usr_utt_tok] + if label == 'none' or label == 'dontcare': + class_type = label + else: + in_usr, usr_pos, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok) + if in_usr: + class_type = 'copy_value' + if slot_last_occurrence: + (s, e) = usr_pos[-1] + for i in range(s, e): + usr_utt_tok_label[i] = 1 + else: + for (s, e) in usr_pos: + for i in range(s, e): + usr_utt_tok_label[i] = 1 + elif in_sys: + class_type = 'inform' + else: + class_type = 'unpointable' + return usr_utt_tok_label, class_type + + +def tokenize(utt): + utt_lower = convert_to_unicode(utt).lower() + utt_tok = [tok for tok in map(str.strip, re.split("(\W+)", utt_lower)) if len(tok) > 0] + return utt_tok + + +def create_examples(input_file, set_type, slot_list, + label_maps={}, + append_history=False, + use_history_labels=False, + swap_utterances=False, + label_value_repetitions=False, + delexicalize_sys_utts=False, + analyze=False): + """Read a DST json file into a list of DSTExample.""" + with open(input_file, "r", encoding='utf-8') as reader: + input_data = json.load(reader) + + global LABEL_MAPS + LABEL_MAPS = label_maps + + examples = [] + for entry in input_data: + diag_seen_slots_dict = {} + diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list} + diag_state = {slot: 'none' for slot in slot_list} + sys_utt_tok = [] + sys_utt_tok_delex = [] + usr_utt_tok = [] + hst_utt_tok = [] + 
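# Per-dialogue bookkeeping (re-initialized for every new dialogue): diag_seen_slots_dict
# remembers the class type of every slot that has been mentioned so far,
# diag_seen_slots_value_dict keeps the most recent value label per slot, and diag_state
# is the running dialog state that is attached to each DSTExample and later serves as
# the dialog-state auxiliary feature. The *_utt_tok lists hold the current turn's
# tokens, while hst_utt_tok accumulates the token history across turns.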
hst_utt_tok_label_dict = {slot: [] for slot in slot_list} + for turn in entry['dialogue']: + sys_utt_tok_label_dict = {} + usr_utt_tok_label_dict = {} + inform_dict = {slot: 'none' for slot in slot_list} + inform_slot_dict = {slot: 0 for slot in slot_list} + referral_dict = {} + class_type_dict = {} + + # Collect turn data + if append_history: + if swap_utterances: + if delexicalize_sys_utts: + hst_utt_tok = usr_utt_tok + sys_utt_tok_delex + hst_utt_tok + else: + hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok + else: + if delexicalize_sys_utts: + hst_utt_tok = sys_utt_tok_delex + usr_utt_tok + hst_utt_tok + else: + hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok + + sys_utt_tok = tokenize(turn['system_transcript']) + usr_utt_tok = tokenize(turn['transcript']) + turn_label = {LABEL_FIX.get(s.strip(), s.strip()): LABEL_FIX.get(v.strip(), v.strip()) for s, v in turn['turn_label']} + + guid = '%s-%s-%s' % (set_type, str(entry['dialogue_idx']), str(turn['turn_idx'])) + + # Create delexicalized sys utterances. + if delexicalize_sys_utts: + delex_dict = {} + for slot in slot_list: + delex_dict[slot] = 'none' + label = 'none' + if slot in turn_label: + label = turn_label[slot] + elif label_value_repetitions and slot in diag_seen_slots_dict: + label = diag_seen_slots_value_dict[slot] + if label != 'none' and label != 'dontcare': + _, _, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok) + if in_sys: + delex_dict[slot] = label + sys_utt_tok_delex = delex_utt(sys_utt_tok, delex_dict) + + new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() + new_diag_state = diag_state.copy() + for slot in slot_list: + label = 'none' + if slot in turn_label: + label = turn_label[slot] + elif label_value_repetitions and slot in diag_seen_slots_dict: + label = diag_seen_slots_value_dict[slot] + + (usr_utt_tok_label, + class_type) = get_turn_label(label, + sys_utt_tok, + usr_utt_tok, + slot_last_occurrence=True) + + if class_type == 'inform': + inform_dict[slot] = label + if label != 'none': + inform_slot_dict[slot] = 1 + + referral_dict[slot] = 'none' # Referral is not present in woz2 data + + # Generally don't use span prediction on sys utterance (but inform prediction instead). + if delexicalize_sys_utts: + sys_utt_tok_label = [0 for _ in sys_utt_tok_delex] + else: + sys_utt_tok_label = [0 for _ in sys_utt_tok] + + # Determine what to do with value repetitions. + # If value is unique in seen slots, then tag it, otherwise not, + # since correct slot assignment can not be guaranteed anymore. + if label_value_repetitions and slot in diag_seen_slots_dict: + if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(label) > 1: + class_type = 'none' + usr_utt_tok_label = [0 for _ in usr_utt_tok_label] + + sys_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label_dict[slot] = usr_utt_tok_label + + if append_history: + if use_history_labels: + if swap_utterances: + new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot] + else: + new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot] + else: + new_hst_utt_tok_label_dict[slot] = [0 for _ in sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]] + + # For now, we map all occurences of unpointable slot values + # to none. However, since the labels will still suggest + # a presence of unpointable slot values, the task of the + # DST is still to find those values. 
It is just not + # possible to do that via span prediction on the current input. + if class_type == 'unpointable': + class_type_dict[slot] = 'none' + elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform': + # If the slot has been seen before and its class type did not change, label this slot as not present, + # assuming that the slot has not actually been mentioned in this turn. + # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform, + # this must mean there is evidence in the original labels, therefore consider + # them as mentioned again. + class_type_dict[slot] = 'none' + referral_dict[slot] = 'none' + else: + class_type_dict[slot] = class_type + # Remember that this slot was mentioned during this dialog already. + if class_type != 'none': + diag_seen_slots_dict[slot] = class_type + diag_seen_slots_value_dict[slot] = label + new_diag_state[slot] = class_type + # Unpointable is not a valid class, therefore replace with + # some valid class for now... + if class_type == 'unpointable': + new_diag_state[slot] = 'copy_value' + + if swap_utterances: + txt_a = usr_utt_tok + if delexicalize_sys_utts: + txt_b = sys_utt_tok_delex + else: + txt_b = sys_utt_tok + txt_a_lbl = usr_utt_tok_label_dict + txt_b_lbl = sys_utt_tok_label_dict + else: + if delexicalize_sys_utts: + txt_a = sys_utt_tok_delex + else: + txt_a = sys_utt_tok + txt_b = usr_utt_tok + txt_a_lbl = sys_utt_tok_label_dict + txt_b_lbl = usr_utt_tok_label_dict + examples.append(DSTExample( + guid=guid, + text_a=txt_a, + text_b=txt_b, + history=hst_utt_tok, + text_a_label=txt_a_lbl, + text_b_label=txt_b_lbl, + history_label=hst_utt_tok_label_dict, + values=diag_seen_slots_value_dict.copy(), + inform_label=inform_dict, + inform_slot_label=inform_slot_dict, + refer_label=referral_dict, + diag_state=diag_state, + class_label=class_type_dict)) + + # Update some variables. + hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy() + diag_state = new_diag_state.copy() + + return examples + diff --git a/metric_bert_dst.py b/metric_bert_dst.py new file mode 100644 index 0000000000000000000000000000000000000000..d2df63af7146e0feef28b90ee9bc11000c8fa337 --- /dev/null +++ b/metric_bert_dst.py @@ -0,0 +1,370 @@ +# coding=utf-8 +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
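# This script post-processes the pred_res.*.json files written by run_dst.py and reports
# per-class statistics as well as the joint goal accuracy: a turn only counts as correct
# if the accumulated value of *every* slot matches the ground truth (the per-slot
# correctness vectors are multiplied element-wise and then averaged). Illustrative toy
# example with made-up numbers:
#
#   slot "area": per-turn correctness [1, 1, 0, 1]
#   slot "food": per-turn correctness [1, 0, 1, 1]
#   joint goal:                       [1, 0, 0, 1]  ->  joint goal accuracy = 0.5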
+ +import glob +import json +import sys +import numpy as np +import re + + +def load_dataset_config(dataset_config): + with open(dataset_config, "r", encoding='utf-8') as f: + raw_config = json.load(f) + return raw_config['class_types'], raw_config['slots'], raw_config['label_maps'] + + +def tokenize(text): + if "\u0120" in text: + text = re.sub(" ", "", text) + text = re.sub("\u0120", " ", text) + text = text.strip() + return ' '.join([tok for tok in map(str.strip, re.split("(\W+)", text)) if len(tok) > 0]) + + +def is_in_list(tok, value): + found = False + tok_list = [item for item in map(str.strip, re.split("(\W+)", tok)) if len(item) > 0] + value_list = [item for item in map(str.strip, re.split("(\W+)", value)) if len(item) > 0] + tok_len = len(tok_list) + value_len = len(value_list) + for i in range(tok_len + 1 - value_len): + if tok_list[i:i + value_len] == value_list: + found = True + break + return found + + +def check_slot_inform(value_label, inform_label, label_maps): + value = inform_label + if value_label == inform_label: + value = value_label + elif is_in_list(inform_label, value_label): + value = value_label + elif is_in_list(value_label, inform_label): + value = value_label + elif inform_label in label_maps: + for inform_label_variant in label_maps[inform_label]: + if value_label == inform_label_variant: + value = value_label + break + elif is_in_list(inform_label_variant, value_label): + value = value_label + break + elif is_in_list(value_label, inform_label_variant): + value = value_label + break + elif value_label in label_maps: + for value_label_variant in label_maps[value_label]: + if value_label_variant == inform_label: + value = value_label + break + elif is_in_list(inform_label, value_label_variant): + value = value_label + break + elif is_in_list(value_label_variant, inform_label): + value = value_label + break + return value + + +def get_joint_slot_correctness(fp, class_types, label_maps, + key_class_label_id='class_label_id', + key_class_prediction='class_prediction', + key_start_pos='start_pos', + key_start_prediction='start_prediction', + key_end_pos='end_pos', + key_end_prediction='end_prediction', + key_refer_id='refer_id', + key_refer_prediction='refer_prediction', + key_slot_groundtruth='slot_groundtruth', + key_slot_prediction='slot_prediction'): + with open(fp) as f: + preds = json.load(f) + class_correctness = [[] for cl in range(len(class_types) + 1)] + confusion_matrix = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))] + pos_correctness = [] + refer_correctness = [] + val_correctness = [] + total_correctness = [] + c_tp = {ct: 0 for ct in range(len(class_types))} + c_tn = {ct: 0 for ct in range(len(class_types))} + c_fp = {ct: 0 for ct in range(len(class_types))} + c_fn = {ct: 0 for ct in range(len(class_types))} + + for pred in preds: + guid = pred['guid'] # List: set_type, dialogue_idx, turn_idx + turn_gt_class = pred[key_class_label_id] + turn_pd_class = pred[key_class_prediction] + gt_start_pos = pred[key_start_pos] + pd_start_pos = pred[key_start_prediction] + gt_end_pos = pred[key_end_pos] + pd_end_pos = pred[key_end_prediction] + gt_refer = pred[key_refer_id] + pd_refer = pred[key_refer_prediction] + gt_slot = pred[key_slot_groundtruth] + pd_slot = pred[key_slot_prediction] + + gt_slot = tokenize(gt_slot) + pd_slot = tokenize(pd_slot) + + # Make sure the true turn labels are contained in the prediction json file! 
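# The block below accumulates a predicted dialog state for this slot across the turns of
# a dialogue: joint_pd_slot is reset to 'none' whenever a new dialogue starts (turn index
# '0' in the guid) and is only overwritten when the predicted class type implies an
# update, so a 'none' prediction simply carries the previous value over. joint_gt_slot is
# taken directly from the ground-truth value stored in the prediction file for this turn.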
+ joint_gt_slot = gt_slot + + if guid[-1] == '0': # First turn, reset the slots + joint_pd_slot = 'none' + + # If turn_pd_class or a value to be copied is "none", do not update the dialog state. + if turn_pd_class == class_types.index('none'): + pass + elif turn_pd_class == class_types.index('dontcare'): + joint_pd_slot = 'dontcare' + elif turn_pd_class == class_types.index('copy_value'): + joint_pd_slot = pd_slot + elif 'true' in class_types and turn_pd_class == class_types.index('true'): + joint_pd_slot = 'true' + elif 'false' in class_types and turn_pd_class == class_types.index('false'): + joint_pd_slot = 'false' + elif 'refer' in class_types and turn_pd_class == class_types.index('refer'): + if pd_slot[0:3] == "§§ ": + if pd_slot[3:] != 'none': + joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[3:], label_maps) + elif pd_slot[0:2] == "§§": + if pd_slot[2:] != 'none': + joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:], label_maps) + elif pd_slot != 'none': + joint_pd_slot = pd_slot + elif 'inform' in class_types and turn_pd_class == class_types.index('inform'): + if pd_slot[0:3] == "§§ ": + if pd_slot[3:] != 'none': + joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[3:], label_maps) + elif pd_slot[0:2] == "§§": + if pd_slot[2:] != 'none': + joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:], label_maps) + else: + print("ERROR: Unexpected slot value format. Aborting.") + exit() + else: + print("ERROR: Unexpected class_type. Aborting.") + exit() + + total_correct = True + + # Check the per turn correctness of the class_type prediction + if turn_gt_class == turn_pd_class: + class_correctness[turn_gt_class].append(1.0) + class_correctness[-1].append(1.0) + c_tp[turn_gt_class] += 1 + for cc in range(len(class_types)): + if cc != turn_gt_class: + c_tn[cc] += 1 + # Only where there is a span, we check its per turn correctness + if turn_gt_class == class_types.index('copy_value'): + if gt_start_pos == pd_start_pos and gt_end_pos == pd_end_pos: + pos_correctness.append(1.0) + else: + pos_correctness.append(0.0) + # Only where there is a referral, we check its per turn correctness + if 'refer' in class_types and turn_gt_class == class_types.index('refer'): + if gt_refer == pd_refer: + refer_correctness.append(1.0) + print(" [%s] Correct referral: %s | %s" % (guid, gt_refer, pd_refer)) + else: + refer_correctness.append(0.0) + print(" [%s] Incorrect referral: %s | %s" % (guid, gt_refer, pd_refer)) + else: + if turn_gt_class == class_types.index('copy_value'): + pos_correctness.append(0.0) + if 'refer' in class_types and turn_gt_class == class_types.index('refer'): + refer_correctness.append(0.0) + class_correctness[turn_gt_class].append(0.0) + class_correctness[-1].append(0.0) + confusion_matrix[turn_gt_class][turn_pd_class].append(1.0) + c_fn[turn_gt_class] += 1 + c_fp[turn_pd_class] += 1 + + # Check the joint slot correctness. + # If the value label is not none, then we need to have a value prediction. + # Even if the class_type is 'none', there can still be a value label, + # it might just not be pointable in the current turn. It might however + # be referrable and thus predicted correctly.
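# The value comparison below is tolerant to surface variation: if the exact string match
# fails, the variant list for the ground-truth value in label_maps (loaded from the
# dataset config) is consulted. For example, with a hypothetical config entry mapping
# "center" -> ["centre"], a predicted "centre" against ground truth "center" would still
# be counted as a correct value.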
+ if joint_gt_slot == joint_pd_slot: + val_correctness.append(1.0) + elif joint_gt_slot != 'none' and joint_gt_slot != 'dontcare' and joint_gt_slot != 'true' and joint_gt_slot != 'false' and joint_gt_slot in label_maps: + no_match = True + for variant in label_maps[joint_gt_slot]: + if variant == joint_pd_slot: + no_match = False + break + if no_match: + val_correctness.append(0.0) + total_correct = False + print(" [%s] Incorrect value (variant): %s (turn class: %s) | %s (turn class: %s)" % (guid, joint_gt_slot, turn_gt_class, joint_pd_slot, turn_pd_class)) + else: + val_correctness.append(1.0) + else: + val_correctness.append(0.0) + total_correct = False + print(" [%s] Incorrect value: %s (turn class: %s) | %s (turn class: %s)" % (guid, joint_gt_slot, turn_gt_class, joint_pd_slot, turn_pd_class)) + + total_correctness.append(1.0 if total_correct else 0.0) + + # Account for empty lists (due to no instances of spans or referrals being seen) + if pos_correctness == []: + pos_correctness.append(1.0) + if refer_correctness == []: + refer_correctness.append(1.0) + + for ct in range(len(class_types)): + if c_tp[ct] + c_fp[ct] > 0: + precision = c_tp[ct] / (c_tp[ct] + c_fp[ct]) + else: + precision = 1.0 + if c_tp[ct] + c_fn[ct] > 0: + recall = c_tp[ct] / (c_tp[ct] + c_fn[ct]) + else: + recall = 1.0 + if precision + recall > 0: + f1 = 2 * ((precision * recall) / (precision + recall)) + else: + f1 = 1.0 + if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0: + acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct]) + else: + acc = 1.0 + print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" % + (class_types[ct], ct, recall, np.sum(class_correctness[ct]), len(class_correctness[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct])) + + print("Confusion matrix:") + for cl in range(len(class_types)): + print(" %s" % (cl), end="") + print("") + for cl_a in range(len(class_types)): + print("%s " % (cl_a), end="") + for cl_b in range(len(class_types)): + if len(class_correctness[cl_a]) > 0: + print("%.2f " % (np.sum(confusion_matrix[cl_a][cl_b]) / len(class_correctness[cl_a])), end="") + else: + print("---- ", end="") + print("") + + return np.asarray(total_correctness), np.asarray(val_correctness), np.asarray(class_correctness), np.asarray(pos_correctness), np.asarray(refer_correctness), np.asarray(confusion_matrix), c_tp, c_tn, c_fp, c_fn + + +if __name__ == "__main__": + acc_list = [] + acc_list_v = [] + key_class_label_id = 'class_label_id_%s' + key_class_prediction = 'class_prediction_%s' + key_start_pos = 'start_pos_%s' + key_start_prediction = 'start_prediction_%s' + key_end_pos = 'end_pos_%s' + key_end_prediction = 'end_prediction_%s' + key_refer_id = 'refer_id_%s' + key_refer_prediction = 'refer_prediction_%s' + key_slot_groundtruth = 'slot_groundtruth_%s' + key_slot_prediction = 'slot_prediction_%s' + + dataset = sys.argv[1].lower() + dataset_config = sys.argv[2].lower() + + if dataset not in ['woz2', 'sim-m', 'sim-r', 'multiwoz21']: + raise ValueError("Task not found: %s" % (dataset)) + + class_types, slots, label_maps = load_dataset_config(dataset_config) + + # Prepare label_maps + label_maps_tmp = {} + for v in label_maps: + label_maps_tmp[tokenize(v)] = [tokenize(nv) for nv in label_maps[v]] + label_maps = label_maps_tmp + + for fp in sorted(glob.glob(sys.argv[3])): + print(fp) + goal_correctness = 1.0 + cls_acc = [[] for cl in range(len(class_types))] + cls_conf = [[[] for cl_b in 
range(len(class_types))] for cl_a in range(len(class_types))] + c_tp = {ct: 0 for ct in range(len(class_types))} + c_tn = {ct: 0 for ct in range(len(class_types))} + c_fp = {ct: 0 for ct in range(len(class_types))} + c_fn = {ct: 0 for ct in range(len(class_types))} + for slot in slots: + tot_cor, joint_val_cor, cls_cor, pos_cor, ref_cor, conf_mat, ctp, ctn, cfp, cfn = get_joint_slot_correctness(fp, class_types, label_maps, + key_class_label_id=(key_class_label_id % slot), + key_class_prediction=(key_class_prediction % slot), + key_start_pos=(key_start_pos % slot), + key_start_prediction=(key_start_prediction % slot), + key_end_pos=(key_end_pos % slot), + key_end_prediction=(key_end_prediction % slot), + key_refer_id=(key_refer_id % slot), + key_refer_prediction=(key_refer_prediction % slot), + key_slot_groundtruth=(key_slot_groundtruth % slot), + key_slot_prediction=(key_slot_prediction % slot) + ) + print('%s: joint slot acc: %g, joint value acc: %g, turn class acc: %g, turn position acc: %g, turn referral acc: %g' % + (slot, np.mean(tot_cor), np.mean(joint_val_cor), np.mean(cls_cor[-1]), np.mean(pos_cor), np.mean(ref_cor))) + goal_correctness *= tot_cor + for cl_a in range(len(class_types)): + cls_acc[cl_a] += cls_cor[cl_a] + for cl_b in range(len(class_types)): + cls_conf[cl_a][cl_b] += list(conf_mat[cl_a][cl_b]) + c_tp[cl_a] += ctp[cl_a] + c_tn[cl_a] += ctn[cl_a] + c_fp[cl_a] += cfp[cl_a] + c_fn[cl_a] += cfn[cl_a] + + for ct in range(len(class_types)): + if c_tp[ct] + c_fp[ct] > 0: + precision = c_tp[ct] / (c_tp[ct] + c_fp[ct]) + else: + precision = 1.0 + if c_tp[ct] + c_fn[ct] > 0: + recall = c_tp[ct] / (c_tp[ct] + c_fn[ct]) + else: + recall = 1.0 + if precision + recall > 0: + f1 = 2 * ((precision * recall) / (precision + recall)) + else: + f1 = 1.0 + if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0: + acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct]) + else: + acc = 1.0 + print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" % + (class_types[ct], ct, recall, np.sum(cls_acc[ct]), len(cls_acc[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct])) + + print("Confusion matrix:") + for cl in range(len(class_types)): + print(" %s" % (cl), end="") + print("") + for cl_a in range(len(class_types)): + print("%s " % (cl_a), end="") + for cl_b in range(len(class_types)): + if len(cls_acc[cl_a]) > 0: + print("%.2f " % (np.sum(cls_conf[cl_a][cl_b]) / len(cls_acc[cl_a])), end="") + else: + print("---- ", end="") + print("") + + acc = np.mean(goal_correctness) + acc_list.append((fp, acc)) + + acc_list_s = sorted(acc_list, key=lambda tup: tup[1], reverse=True) + for (fp, acc) in acc_list_s: + print('Joint goal acc: %g, %s' % (acc, fp)) diff --git a/modeling_bert_dst.py b/modeling_bert_dst.py new file mode 100644 index 0000000000000000000000000000000000000000..dc7206c3f0380417292289f2fe1909777c6cb418 --- /dev/null +++ b/modeling_bert_dst.py @@ -0,0 +1,174 @@ +# coding=utf-8 +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable) +from transformers.modeling_bert import (BertModel, BertPreTrainedModel, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING) + + +@add_start_docstrings( + """BERT Model with a classification heads for the DST task. """, + BERT_START_DOCSTRING, +) +class BertForDST(BertPreTrainedModel): + def __init__(self, config): + super(BertForDST, self).__init__(config) + self.slot_list = config.dst_slot_list + self.class_types = config.dst_class_types + self.class_labels = config.dst_class_labels + self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable + self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable + self.class_aux_feats_inform = config.dst_class_aux_feats_inform + self.class_aux_feats_ds = config.dst_class_aux_feats_ds + self.class_loss_ratio = config.dst_class_loss_ratio + + # Only use refer loss if refer class is present in dataset. + if 'refer' in self.class_types: + self.refer_index = self.class_types.index('refer') + else: + self.refer_index = -1 + + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.dst_dropout_rate) + self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate) + + if self.class_aux_feats_inform: + self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) + if self.class_aux_feats_ds: + self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) + + aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2 + + for slot in self.slot_list: + self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels)) + self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2)) + self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1)) + + self.init_weights() + + @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) + def forward(self, + input_ids, + input_mask=None, + segment_ids=None, + position_ids=None, + head_mask=None, + start_pos=None, + end_pos=None, + inform_slot_id=None, + refer_id=None, + class_label_id=None, + diag_state=None): + outputs = self.bert( + input_ids, + attention_mask=input_mask, + token_type_ids=segment_ids, + position_ids=position_ids, + head_mask=head_mask + ) + + sequence_output = outputs[0] + pooled_output = outputs[1] + + sequence_output = self.dropout(sequence_output) + pooled_output = self.dropout(pooled_output) + + # TODO: establish proper format in labels already? 
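# inform_slot_id and diag_state arrive as {slot: LongTensor[batch]} dicts and are stacked
# into [batch, num_slots] float matrices here. diag_state holds class-type indices, so it
# is clamped to {0, 1} to act as a binary "slot already filled" indicator. When enabled,
# both matrices are projected and concatenated to the pooled [CLS] output as auxiliary
# features for the class and refer heads.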
+ if inform_slot_id is not None: + inform_labels = torch.stack(list(inform_slot_id.values()), 1).float() + if diag_state is not None: + diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0) + + total_loss = 0 + per_slot_per_example_loss = {} + per_slot_class_logits = {} + per_slot_start_logits = {} + per_slot_end_logits = {} + per_slot_refer_logits = {} + for slot in self.slot_list: + if self.class_aux_feats_inform and self.class_aux_feats_ds: + pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) + elif self.class_aux_feats_inform: + pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) + elif self.class_aux_feats_ds: + pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) + else: + pooled_output_aux = pooled_output + class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux)) + + token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output)) + start_logits, end_logits = token_logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux)) + + per_slot_class_logits[slot] = class_logits + per_slot_start_logits[slot] = start_logits + per_slot_end_logits[slot] = end_logits + per_slot_refer_logits[slot] = refer_logits + + # If there are no labels, don't compute loss + if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None: + # If we are on multi-GPU, split add a dimension + if len(start_pos[slot].size()) > 1: + start_pos[slot] = start_pos[slot].squeeze(-1) + if len(end_pos[slot].size()) > 1: + end_pos[slot] = end_pos[slot].squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) # This is a single index + start_pos[slot].clamp_(0, ignored_index) + end_pos[slot].clamp_(0, ignored_index) + + class_loss_fct = CrossEntropyLoss(reduction='none') + token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index) + refer_loss_fct = CrossEntropyLoss(reduction='none') + + start_loss = token_loss_fct(start_logits, start_pos[slot]) + end_loss = token_loss_fct(end_logits, end_pos[slot]) + token_loss = (start_loss + end_loss) / 2.0 + + token_is_pointable = (start_pos[slot] > 0).float() + if not self.token_loss_for_nonpointable: + token_loss *= token_is_pointable + + refer_loss = refer_loss_fct(refer_logits, refer_id[slot]) + token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float() + if not self.refer_loss_for_nonpointable: + refer_loss *= token_is_referrable + + class_loss = class_loss_fct(class_logits, class_label_id[slot]) + + if self.refer_index > -1: + per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss + else: + per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss + + total_loss += per_example_loss.sum() + per_slot_per_example_loss[slot] = per_example_loss + + # add hidden states and attention if they are here + outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:] + + return outputs diff --git a/run_dst.py b/run_dst.py new file mode 
100644 index 0000000000000000000000000000000000000000..bde263007dfaa6bf071a3c9547f7732f5c38183b --- /dev/null +++ b/run_dst.py @@ -0,0 +1,764 @@ +# coding=utf-8 +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import random +import glob +import json +import math +import re + +import numpy as np +import torch +from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler) +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm, trange + +from tensorboardX import SummaryWriter + +from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer) +from transformers import (AdamW, get_linear_schedule_with_warmup) + +from modeling_bert_dst import (BertForDST) +from data_processors import PROCESSORS +from utils_dst import (convert_examples_to_features) +from tensorlistdataset import (TensorListDataset) + +logger = logging.getLogger(__name__) + +ALL_MODELS = tuple(BertConfig.pretrained_config_archive_map.keys()) + +MODEL_CLASSES = { + 'bert': (BertConfig, BertForDST, BertTokenizer), +} + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def to_list(tensor): + return tensor.detach().cpu().tolist() + + +def batch_to_device(batch, device): + batch_on_device = [] + for element in batch: + if isinstance(element, dict): + batch_on_device.append({k: v.to(device) for k, v in element.items()}) + else: + batch_on_device.append(element.to(device)) + return tuple(batch_on_device) + + +def train(args, train_dataset, features, model, tokenizer, processor, continue_from_global_step=0): + """ Train the model """ + if args.local_rank in [-1, 0]: + tb_writer = SummaryWriter() + + args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) + train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) + + if args.max_steps > 0: + t_total = args.max_steps + args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 + else: + t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs + + if args.save_epochs > 0: + args.save_steps = t_total // args.num_train_epochs * args.save_epochs + + num_warmup_steps = int(t_total * args.warmup_proportion) + + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ['bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, + {'params': [p for n, p in 
model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total) + if args.fp16: + try: + from apex import amp + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) + + # multi-gpu training (should be after apex fp16 initialization) + model_single_gpu = model + if args.n_gpu > 1: + model = torch.nn.DataParallel(model_single_gpu) + + # Distributed training (should be after apex fp16 initialization) + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True) + + # Train! + logger.info("***** Running training *****") + logger.info(" Num examples = %d", len(train_dataset)) + logger.info(" Num Epochs = %d", args.num_train_epochs) + logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) + logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", t_total) + logger.info(" Warmup steps = %d", num_warmup_steps) + + if continue_from_global_step > 0: + logger.info("Fast forwarding to global step %d to resume training from latest checkpoint...", continue_from_global_step) + + global_step = 0 + tr_loss, logging_loss = 0.0, 0.0 + model.zero_grad() + train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) + set_seed(args) # Added here for reproductibility (even between python 2 and 3) + + for _ in train_iterator: + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) + + for step, batch in enumerate(epoch_iterator): + # If training is continued from a checkpoint, fast forward + # to the state of that checkpoint. + if global_step < continue_from_global_step: + if (step + 1) % args.gradient_accumulation_steps == 0: + scheduler.step() # Update learning rate schedule + global_step += 1 + continue + + model.train() + batch = batch_to_device(batch, args.device) + + # This is what is forwarded to the "forward" def. 
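# The batch layout mirrors the TensorListDataset built in load_and_cache_examples:
# indices 0-8 are the model inputs (the slot-wise entries are dicts of tensors), index 9
# is the example index, which is only needed at evaluation time to look up the original
# features and is therefore not passed to the model here.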
+ inputs = {'input_ids': batch[0], + 'input_mask': batch[1], + 'segment_ids': batch[2], + 'start_pos': batch[3], + 'end_pos': batch[4], + 'inform_slot_id': batch[5], + 'refer_id': batch[6], + 'diag_state': batch[7], + 'class_label_id': batch[8]} + outputs = model(**inputs) + loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc) + + if args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + if args.fp16: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) + else: + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + tr_loss += loss.item() + if (step + 1) % args.gradient_accumulation_steps == 0: + optimizer.step() + scheduler.step() # Update learning rate schedule + model.zero_grad() + global_step += 1 + + # Log metrics + if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: + tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) + tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step) + logging_loss = tr_loss + + # Save model checkpoint + if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: + output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training + model_to_save.save_pretrained(output_dir) + torch.save(args, os.path.join(output_dir, 'training_args.bin')) + logger.info("Saving model checkpoint to %s", output_dir) + + if args.max_steps > 0 and global_step > args.max_steps: + epoch_iterator.close() + break + + if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well + results = evaluate(args, model_single_gpu, tokenizer, processor, prefix=global_step) + for key, value in results.items(): + tb_writer.add_scalar('eval_{}'.format(key), value, global_step) + + if args.max_steps > 0 and global_step > args.max_steps: + train_iterator.close() + break + + if args.local_rank in [-1, 0]: + tb_writer.close() + + return global_step, tr_loss / global_step + + +def evaluate(args, model, tokenizer, processor, prefix=""): + dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=True) + + if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: + os.makedirs(args.output_dir) + + args.eval_batch_size = args.per_gpu_eval_batch_size + eval_sampler = SequentialSampler(dataset) # Note that DistributedSampler samples randomly + eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + + # Eval! 
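# During evaluation the predicted dialog state (diag_state) is carried over from turn to
# turn and is reset whenever a dialogue's first turn (turn index '0') comes up. This
# relies on the SequentialSampler preserving dialogue order; when the class auxiliary
# features are enabled, main() additionally enforces per_gpu_eval_batch_size == 1 so that
# each batch position tracks a single dialogue.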
+ logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(" Num examples = %d", len(dataset)) + logger.info(" Batch size = %d", args.eval_batch_size) + all_results = [] + all_preds = [] + ds = {slot: 'none' for slot in model.slot_list} + with torch.no_grad(): + diag_state = {slot: torch.tensor([0 for _ in range(args.eval_batch_size)]).to(args.device) for slot in model.slot_list} + for batch in tqdm(eval_dataloader, desc="Evaluating"): + model.eval() + batch = batch_to_device(batch, args.device) + + # Reset dialog state if turn is first in the dialog. + turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]] + reset_diag_state = np.where(np.array(turn_itrs) == '0')[0] + for slot in model.slot_list: + for i in reset_diag_state: + diag_state[slot][i] = 0 + + with torch.no_grad(): + inputs = {'input_ids': batch[0], + 'input_mask': batch[1], + 'segment_ids': batch[2], + 'start_pos': batch[3], + 'end_pos': batch[4], + 'inform_slot_id': batch[5], + 'refer_id': batch[6], + 'diag_state': diag_state, + 'class_label_id': batch[8]} + unique_ids = [features[i.item()].guid for i in batch[9]] + values = [features[i.item()].values for i in batch[9]] + input_ids_unmasked = [features[i.item()].input_ids_unmasked for i in batch[9]] + inform = [features[i.item()].inform for i in batch[9]] + outputs = model(**inputs) + + # Update dialog state for next turn. + for slot in model.slot_list: + updates = outputs[2][slot].max(1)[1] + for i, u in enumerate(updates): + if u != 0: + diag_state[slot][i] = u + + results = eval_metric(model, inputs, outputs[0], outputs[1], outputs[2], outputs[3], outputs[4], outputs[5]) + preds, ds = predict_and_format(model, tokenizer, inputs, outputs[2], outputs[3], outputs[4], outputs[5], unique_ids, input_ids_unmasked, values, inform, prefix, ds) + all_results.append(results) + all_preds.append(preds) + + all_preds = [item for sublist in all_preds for item in sublist] # Flatten list + + # Generate final results + final_results = {} + for k in all_results[0].keys(): + final_results[k] = torch.stack([r[k] for r in all_results]).mean() + + # Write final predictions (for evaluation with external tool) + output_prediction_file = os.path.join(args.output_dir, "pred_res.%s.%s.json" % (args.predict_type, prefix)) + with open(output_prediction_file, "w") as f: + json.dump(all_preds, f, indent=2) + + return final_results + + +def eval_metric(model, features, total_loss, per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits): + metric_dict = {} + per_slot_correctness = {} + for slot in model.slot_list: + per_example_loss = per_slot_per_example_loss[slot] + class_logits = per_slot_class_logits[slot] + start_logits = per_slot_start_logits[slot] + end_logits = per_slot_end_logits[slot] + refer_logits = per_slot_refer_logits[slot] + + class_label_id = features['class_label_id'][slot] + start_pos = features['start_pos'][slot] + end_pos = features['end_pos'][slot] + refer_id = features['refer_id'][slot] + + _, class_prediction = class_logits.max(1) + class_correctness = torch.eq(class_prediction, class_label_id).float() + class_accuracy = class_correctness.mean() + + # "is pointable" means whether class label is "copy_value", + # i.e., that there is a span to be detected. 
+ token_is_pointable = torch.eq(class_label_id, model.class_types.index('copy_value')).float() + _, start_prediction = start_logits.max(1) + start_correctness = torch.eq(start_prediction, start_pos).float() + _, end_prediction = end_logits.max(1) + end_correctness = torch.eq(end_prediction, end_pos).float() + token_correctness = start_correctness * end_correctness + token_accuracy = (token_correctness * token_is_pointable).sum() / token_is_pointable.sum() + # NaNs mean that none of the examples in this batch contain spans. -> division by 0 + # The accuracy therefore is 1 by default. -> replace NaNs + if math.isnan(token_accuracy): + token_accuracy = torch.tensor(1.0, device=token_accuracy.device) + + token_is_referrable = torch.eq(class_label_id, model.class_types.index('refer') if 'refer' in model.class_types else -1).float() + _, refer_prediction = refer_logits.max(1) + refer_correctness = torch.eq(refer_prediction, refer_id).float() + refer_accuracy = refer_correctness.sum() / token_is_referrable.sum() + # NaNs mean that none of the examples in this batch contain referrals. -> division by 0 + # The accuracy therefore is 1 by default. -> replace NaNs + if math.isnan(refer_accuracy) or math.isinf(refer_accuracy): + refer_accuracy = torch.tensor(1.0, device=refer_accuracy.device) + + total_correctness = class_correctness * (token_is_pointable * token_correctness + (1 - token_is_pointable)) * (token_is_referrable * refer_correctness + (1 - token_is_referrable)) + total_accuracy = total_correctness.mean() + + loss = per_example_loss.mean() + metric_dict['eval_accuracy_class_%s' % slot] = class_accuracy + metric_dict['eval_accuracy_token_%s' % slot] = token_accuracy + metric_dict['eval_accuracy_refer_%s' % slot] = refer_accuracy + metric_dict['eval_accuracy_%s' % slot] = total_accuracy + metric_dict['eval_loss_%s' % slot] = loss + per_slot_correctness[slot] = total_correctness + + goal_correctness = torch.stack([c for c in per_slot_correctness.values()], 1).prod(1) + goal_accuracy = goal_correctness.mean() + metric_dict['eval_accuracy_goal'] = goal_accuracy + metric_dict['loss'] = total_loss + return metric_dict + + +def predict_and_format(model, tokenizer, features, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits, ids, input_ids_unmasked, values, inform, prefix, ds): + prediction_list = [] + dialog_state = ds + for i in range(len(ids)): + if int(ids[i].split("-")[2]) == 0: + dialog_state = {slot: 'none' for slot in model.slot_list} + + prediction = {} + prediction_addendum = {} + for slot in model.slot_list: + class_logits = per_slot_class_logits[slot][i] + start_logits = per_slot_start_logits[slot][i] + end_logits = per_slot_end_logits[slot][i] + refer_logits = per_slot_refer_logits[slot][i] + + input_ids = features['input_ids'][i].tolist() + class_label_id = int(features['class_label_id'][slot][i]) + start_pos = int(features['start_pos'][slot][i]) + end_pos = int(features['end_pos'][slot][i]) + refer_id = int(features['refer_id'][slot][i]) + + class_prediction = int(class_logits.argmax()) + start_prediction = int(start_logits.argmax()) + end_prediction = int(end_logits.argmax()) + refer_prediction = int(refer_logits.argmax()) + + prediction['guid'] = ids[i].split("-") + prediction['class_prediction_%s' % slot] = class_prediction + prediction['class_label_id_%s' % slot] = class_label_id + prediction['start_prediction_%s' % slot] = start_prediction + prediction['start_pos_%s' % slot] = start_pos + prediction['end_prediction_%s' % slot] = 
end_prediction + prediction['end_pos_%s' % slot] = end_pos + prediction['refer_prediction_%s' % slot] = refer_prediction + prediction['refer_id_%s' % slot] = refer_id + prediction['input_ids_%s' % slot] = input_ids + + if class_prediction == model.class_types.index('dontcare'): + dialog_state[slot] = 'dontcare' + elif class_prediction == model.class_types.index('copy_value'): + input_tokens = tokenizer.convert_ids_to_tokens(input_ids_unmasked[i]) + dialog_state[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1]) + dialog_state[slot] = re.sub("(^| )##", "", dialog_state[slot]) + elif 'true' in model.class_types and class_prediction == model.class_types.index('true'): + dialog_state[slot] = 'true' + elif 'false' in model.class_types and class_prediction == model.class_types.index('false'): + dialog_state[slot] = 'false' + elif class_prediction == model.class_types.index('inform'): + dialog_state[slot] = '§§' + inform[i][slot] + # Referral case is handled below + + prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot] + prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot] + + # Referral case. All other slot values need to be seen first in order + # to be able to do this correctly. + for slot in model.slot_list: + class_logits = per_slot_class_logits[slot][i] + refer_logits = per_slot_refer_logits[slot][i] + + class_prediction = int(class_logits.argmax()) + refer_prediction = int(refer_logits.argmax()) + + if 'refer' in model.class_types and class_prediction == model.class_types.index('refer'): + # Only slots that have been mentioned before can be referred to. + # One can think of a situation where one slot is referred to in the same utterance. + # This phenomenon is however currently not properly covered in the training data + # label generation process.
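# The refer head has len(slot_list) + 1 outputs, with index 0 presumably reserved for
# "no referral"; refer_prediction - 1 therefore maps the remaining indices onto positions
# in model.slot_list, and the referred slot's current value is copied over.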
+ dialog_state[slot] = dialog_state[model.slot_list[refer_prediction - 1]] + prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot] # Value update + + prediction.update(prediction_addendum) + prediction_list.append(prediction) + + return prediction_list, dialog_state + + +def load_and_cache_examples(args, model, tokenizer, processor, evaluate=False): + if args.local_rank not in [-1, 0] and not evaluate: + torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache + + # Load data features from cache or dataset file + cached_file = os.path.join(os.path.dirname(args.output_dir), 'cached_{}_features'.format( + args.predict_type if evaluate else 'train')) + if os.path.exists(cached_file) and not args.overwrite_cache: # and not output_examples: + logger.info("Loading features from cached file %s", cached_file) + features = torch.load(cached_file) + else: + logger.info("Creating features from dataset file at %s", args.data_dir) + processor_args = {'append_history': args.append_history, + 'use_history_labels': args.use_history_labels, + 'swap_utterances': args.swap_utterances, + 'label_value_repetitions': args.label_value_repetitions, + 'delexicalize_sys_utts': args.delexicalize_sys_utts} + if evaluate and args.predict_type == "dev": + examples = processor.get_dev_examples(args.data_dir, processor_args) + elif evaluate and args.predict_type == "test": + examples = processor.get_test_examples(args.data_dir, processor_args) + else: + examples = processor.get_train_examples(args.data_dir, processor_args) + features = convert_examples_to_features(examples=examples, + slot_list=model.slot_list, + class_types=model.class_types, + model_type=args.model_type, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + slot_value_dropout=(0.0 if evaluate else args.svd)) + if args.local_rank in [-1, 0]: + logger.info("Saving features into cached file %s", cached_file) + torch.save(features, cached_file) + + if args.local_rank == 0 and not evaluate: + torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache + + # Convert to Tensors and build dataset + all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) + all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) + all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) + all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) + f_start_pos = [f.start_pos for f in features] + f_end_pos = [f.end_pos for f in features] + f_inform_slot_ids = [f.inform_slot for f in features] + f_refer_ids = [f.refer_id for f in features] + f_diag_state = [f.diag_state for f in features] + f_class_label_ids = [f.class_label_id for f in features] + all_start_positions = {} + all_end_positions = {} + all_inform_slot_ids = {} + all_refer_ids = {} + all_diag_state = {} + all_class_label_ids = {} + for s in model.slot_list: + all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos], dtype=torch.long) + all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos], dtype=torch.long) + all_inform_slot_ids[s] = torch.tensor([f[s] for f in f_inform_slot_ids], dtype=torch.long) + all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids], dtype=torch.long) + all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state], dtype=torch.long) + all_class_label_ids[s] = torch.tensor([f[s] for f in 
+    dataset = TensorListDataset(all_input_ids, all_input_mask, all_segment_ids,
+                                all_start_positions, all_end_positions,
+                                all_inform_slot_ids,
+                                all_refer_ids,
+                                all_diag_state,
+                                all_class_label_ids, all_example_index)
+
+    return dataset, features
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    # Required parameters
+    parser.add_argument("--task_name", default=None, type=str, required=True,
+                        help="Name of the task (e.g., multiwoz21).")
+    parser.add_argument("--data_dir", default=None, type=str, required=True,
+                        help="Path to the task's data directory.")
+    parser.add_argument("--dataset_config", default=None, type=str, required=True,
+                        help="Dataset configuration file.")
+    parser.add_argument("--predict_type", default=None, type=str, required=True,
+                        help="Portion of the data to perform prediction on (e.g., dev, test).")
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
+    parser.add_argument("--output_dir", default=None, type=str, required=True,
+                        help="The output directory where the model checkpoints and predictions will be written.")
+
+    # Other parameters
+    parser.add_argument("--config_name", default="", type=str,
+                        help="Pretrained config name or path if not the same as model_name.")
+    parser.add_argument("--tokenizer_name", default="", type=str,
+                        help="Pretrained tokenizer name or path if not the same as model_name.")
+
+    parser.add_argument("--max_seq_length", default=384, type=int,
+                        help="Maximum input length after tokenization. Longer sequences will be truncated, shorter ones padded.")
+    parser.add_argument("--do_train", action='store_true',
+                        help="Whether to run training.")
+    parser.add_argument("--do_eval", action='store_true',
+                        help="Whether to run eval on the <predict_type> set.")
+    parser.add_argument("--evaluate_during_training", action='store_true',
+                        help="Run evaluation during training at each logging step.")
+    parser.add_argument("--do_lower_case", action='store_true',
+                        help="Set this flag if you are using an uncased model.")
+
+    parser.add_argument("--dropout_rate", default=0.3, type=float,
+                        help="Dropout rate for BERT representations.")
+    parser.add_argument("--heads_dropout", default=0.0, type=float,
+                        help="Dropout rate for classification heads.")
+    parser.add_argument("--class_loss_ratio", default=0.8, type=float,
+                        help="The ratio applied on class loss in total loss calculation. "
+                             "Should be a value in [0.0, 1.0]. "
+                             "The ratio applied on token loss is (1-class_loss_ratio)/2. "
+                             "The ratio applied on refer loss is (1-class_loss_ratio)/2.")
+    parser.add_argument("--token_loss_for_nonpointable", action='store_true',
+                        help="Whether the token loss for classes other than copy_value contributes towards total loss.")
+    parser.add_argument("--refer_loss_for_nonpointable", action='store_true',
+                        help="Whether the refer loss for classes other than refer contributes towards total loss.")
+
+    parser.add_argument("--append_history", action='store_true',
+                        help="Whether or not to append the dialog history to each turn.")
+    parser.add_argument("--use_history_labels", action='store_true',
+                        help="Whether or not to label the history as well.")
+    parser.add_argument("--swap_utterances", action='store_true',
+                        help="Whether or not to swap the turn utterances (default: sys|usr, swapped: usr|sys).")
+    parser.add_argument("--label_value_repetitions", action='store_true',
+                        help="Whether or not to label values that have been mentioned before.")
+    parser.add_argument("--delexicalize_sys_utts", action='store_true',
+                        help="Whether or not to delexicalize the system utterances.")
+    parser.add_argument("--class_aux_feats_inform", action='store_true',
+                        help="Whether or not to use the identity of informed slots as auxiliary features for class prediction.")
+    parser.add_argument("--class_aux_feats_ds", action='store_true',
+                        help="Whether or not to use the identity of slots in the current dialog state as auxiliary features for class prediction.")
+
+    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
+                        help="Batch size per GPU/CPU for training.")
+    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
+                        help="Batch size per GPU/CPU for evaluation.")
+    parser.add_argument("--learning_rate", default=5e-5, type=float,
+                        help="The initial learning rate for Adam.")
+    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
+                        help="Number of update steps to accumulate before performing a backward/update pass.")
+    parser.add_argument("--weight_decay", default=0.0, type=float,
+                        help="Weight decay to apply, if any.")
+    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
+                        help="Epsilon for Adam optimizer.")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float,
+                        help="Max gradient norm.")
+    parser.add_argument("--num_train_epochs", default=3.0, type=float,
+                        help="Total number of training epochs to perform.")
+    parser.add_argument("--max_steps", default=-1, type=int,
+                        help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.")
+    parser.add_argument("--warmup_proportion", default=0.0, type=float,
+                        help="Linear warmup over warmup_proportion * steps.")
+    parser.add_argument("--svd", default=0.0, type=float,
+                        help="Slot value dropout ratio (default: 0.0).")
+
+    parser.add_argument('--logging_steps', type=int, default=50,
+                        help="Log every X update steps.")
+    parser.add_argument('--save_steps', type=int, default=0,
+                        help="Save checkpoint every X update steps. Overridden by --save_epochs.")
+    parser.add_argument('--save_epochs', type=int, default=0,
+                        help="Save checkpoint every X epochs. Overrides --save_steps.")
+    parser.add_argument("--eval_all_checkpoints", action='store_true',
+                        help="Evaluate all checkpoints starting with the same prefix as model_name and ending with a step number.")
+    parser.add_argument("--no_cuda", action='store_true',
+                        help="Do not use CUDA even when it is available.")
+    parser.add_argument('--overwrite_output_dir', action='store_true',
+                        help="Overwrite the content of the output directory.")
+    parser.add_argument('--overwrite_cache', action='store_true',
+                        help="Overwrite the cached training and evaluation sets.")
+    parser.add_argument('--seed', type=int, default=42,
+                        help="Random seed for initialization.")
+
+    parser.add_argument("--local_rank", type=int, default=-1,
+                        help="local_rank for distributed training on GPUs.")
+    parser.add_argument('--fp16', action='store_true',
+                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit.")
+    parser.add_argument('--fp16_opt_level', type=str, default='O1',
+                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']. "
+                             "See details at https://nvidia.github.io/apex/amp.html")
+
+    args = parser.parse_args()
+
+    assert(args.warmup_proportion >= 0.0 and args.warmup_proportion <= 1.0)
+    assert(args.svd >= 0.0 and args.svd <= 1.0)
+    assert(args.class_aux_feats_ds is False or args.per_gpu_eval_batch_size == 1)
+    assert(not args.class_aux_feats_inform or args.per_gpu_eval_batch_size == 1)
+    assert(not args.class_aux_feats_ds or args.per_gpu_eval_batch_size == 1)
+
+    task_name = args.task_name.lower()
+    if task_name not in PROCESSORS:
+        raise ValueError("Task not found: %s" % (task_name))
+
+    processor = PROCESSORS[task_name](args.dataset_config)
+    dst_slot_list = processor.slot_list
+    dst_class_types = processor.class_types
+    dst_class_labels = len(dst_class_types)
+
+    # Setup CUDA, GPU & distributed training
+    if args.local_rank == -1 or args.no_cuda:
+        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+        args.n_gpu = torch.cuda.device_count()
+    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+        torch.cuda.set_device(args.local_rank)
+        device = torch.device("cuda", args.local_rank)
+        torch.distributed.init_process_group(backend='nccl')
+        args.n_gpu = 1
+    args.device = device
+
+    # Setup logging
+    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                        datefmt='%m/%d/%Y %H:%M:%S',
+                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
+
+    # Set seed
+    set_seed(args)
+
+    # Load pretrained model and tokenizer
+    if args.local_rank not in [-1, 0]:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
+
+    args.model_type = args.model_type.lower()
+    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+
+    # Add DST specific parameters to config
+    config.dst_dropout_rate = args.dropout_rate
+    config.dst_heads_dropout_rate = args.heads_dropout
+    config.dst_class_loss_ratio = args.class_loss_ratio
+    config.dst_token_loss_for_nonpointable = args.token_loss_for_nonpointable
+    config.dst_refer_loss_for_nonpointable = args.refer_loss_for_nonpointable
+    config.dst_class_aux_feats_inform = args.class_aux_feats_inform
+    config.dst_class_aux_feats_ds = args.class_aux_feats_ds
+    config.dst_slot_list = dst_slot_list
+    config.dst_class_types = dst_class_types
+    config.dst_class_labels = dst_class_labels
+
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
+
+    logger.info("Updated model config: %s" % config)
+
+    if args.local_rank == 0:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
+
+    model.to(args.device)
+
+    logger.info("Training/evaluation parameters %s", args)
+
+    # Training
+    if args.do_train:
+        # If the output directory already contains checkpoints, resume training from the
+        # latest one (unless overwrite_output_dir is set).
+        continue_from_global_step = 0  # If set to 0, start training from the beginning
+        if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
+            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/*/' + WEIGHTS_NAME, recursive=True)))
+            if len(checkpoints) > 0:
+                checkpoint = checkpoints[-1]
+                logger.info("Resuming training from the latest checkpoint: %s", checkpoint)
+                continue_from_global_step = int(checkpoint.split('-')[-1])
+                model = model_class.from_pretrained(checkpoint)
+                model.to(args.device)
+
+        train_dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=False)
+        global_step, tr_loss = train(args, train_dataset, features, model, tokenizer, processor, continue_from_global_step)
+        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
+
+    # Save the trained model and the tokenizer
+    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
+        # Create output directory if needed
+        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
+            os.makedirs(args.output_dir)
+
+        logger.info("Saving model checkpoint to %s", args.output_dir)
+        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
+ # They can then be reloaded using `from_pretrained()` + model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training + model_to_save.save_pretrained(args.output_dir) + tokenizer.save_pretrained(args.output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) + + # Load a trained model and vocabulary that you have fine-tuned + model = model_class.from_pretrained(args.output_dir) + tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) + model.to(args.device) + + # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory + results = [] + if args.do_eval and args.local_rank in [-1, 0]: + output_eval_file = os.path.join(args.output_dir, "eval_res.%s.json" % (args.predict_type)) + checkpoints = [args.output_dir] + if args.eval_all_checkpoints: + checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) + logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs + + logger.info("Evaluate the following checkpoints: %s", checkpoints) + + for cItr, checkpoint in enumerate(checkpoints): + # Reload the model + global_step = checkpoint.split('-')[-1] + if cItr == len(checkpoints) - 1: + global_step = "final" + model = model_class.from_pretrained(checkpoint) + model.to(args.device) + + # Evaluate + result = evaluate(args, model, tokenizer, processor, prefix=global_step) + result_dict = {k: float(v) for k, v in result.items()} + result_dict["global_step"] = global_step + results.append(result_dict) + + for key in sorted(result_dict.keys()): + logger.info("%s = %s", key, str(result_dict[key])) + + with open(output_eval_file, "w") as f: + json.dump(results, f, indent=2) + + return results + + +if __name__ == "__main__": + main() diff --git a/tensorlistdataset.py b/tensorlistdataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5b27b3cd494ed316647ef727d4b5ee957feccc18 --- /dev/null +++ b/tensorlistdataset.py @@ -0,0 +1,57 @@ +# coding=utf-8 +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch.utils.data import Dataset + + +class TensorListDataset(Dataset): + r"""Dataset wrapping tensors, tensor dicts and tensor lists. + + Arguments: + *data (Tensor or dict or list of Tensors): tensors that have the same size + of the first dimension. 
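+
+    Example (illustrative)::
+
+        >>> import torch
+        >>> t = torch.zeros(4, 7)
+        >>> d = {'a': torch.ones(4), 'b': torch.ones(4)}
+        >>> ds = TensorListDataset(t, d)
+        >>> len(ds)
+        4
+        >>> ds[0][1]['a']
+        tensor(1.)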
+    """
+
+    def __init__(self, *data):
+        if isinstance(data[0], dict):
+            size = list(data[0].values())[0].size(0)
+        elif isinstance(data[0], list):
+            size = data[0][0].size(0)
+        else:
+            size = data[0].size(0)
+        for element in data:
+            if isinstance(element, dict):
+                assert all(size == tensor.size(0) for name, tensor in element.items())  # dict of tensors
+            elif isinstance(element, list):
+                assert all(size == tensor.size(0) for tensor in element)  # list of tensors
+            else:
+                assert size == element.size(0)  # tensor
+        self.size = size
+        self.data = data
+
+    def __getitem__(self, index):
+        result = []
+        for element in self.data:
+            if isinstance(element, dict):
+                result.append({k: v[index] for k, v in element.items()})
+            elif isinstance(element, list):
+                result.append([v[index] for v in element])  # index each tensor in the list
+            else:
+                result.append(element[index])
+        return tuple(result)
+
+    def __len__(self):
+        return self.size
diff --git a/utils_dst.py b/utils_dst.py
new file mode 100644
index 0000000000000000000000000000000000000000..e91971da86f4f9185f3428e52a09f48b5247343d
--- /dev/null
+++ b/utils_dst.py
@@ -0,0 +1,429 @@
+# coding=utf-8
+#
+# Copyright 2020 Heinrich Heine University Duesseldorf
+#
+# Part of this code is based on the source code of BERT-DST
+# (arXiv:1907.03040)
+# Part of this code is based on the source code of Transformers
+# (arXiv:1910.03771)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import six
+import numpy as np
+import json
+
+logger = logging.getLogger(__name__)
+
+
+class DSTExample(object):
+    """
+    A single training/test example for the DST dataset.
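+
+    The per-slot fields (text_a_label, text_b_label, history_label, values,
+    inform_label, inform_slot_label, refer_label, diag_state and class_label)
+    are dicts keyed by slot name.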
+ """ + + def __init__(self, + guid, + text_a, + text_b, + history, + text_a_label=None, + text_b_label=None, + history_label=None, + values=None, + inform_label=None, + inform_slot_label=None, + refer_label=None, + diag_state=None, + class_label=None): + self.guid = guid + self.text_a = text_a + self.text_b = text_b + self.history = history + self.text_a_label = text_a_label + self.text_b_label = text_b_label + self.history_label = history_label + self.values = values + self.inform_label = inform_label + self.inform_slot_label = inform_slot_label + self.refer_label = refer_label + self.diag_state = diag_state + self.class_label = class_label + + def __str__(self): + return self.__repr__() + + def __repr__(self): + s = "" + s += "guid: %s" % (self.guid) + s += ", text_a: %s" % (self.text_a) + s += ", text_b: %s" % (self.text_b) + s += ", history: %s" % (self.history) + if self.text_a_label: + s += ", text_a_label: %d" % (self.text_a_label) + if self.text_b_label: + s += ", text_b_label: %d" % (self.text_b_label) + if self.history_label: + s += ", history_label: %d" % (self.history_label) + if self.values: + s += ", values: %d" % (self.values) + if self.inform_label: + s += ", inform_label: %d" % (self.inform_label) + if self.inform_slot_label: + s += ", inform_slot_label: %d" % (self.inform_slot_label) + if self.refer_label: + s += ", refer_label: %d" % (self.refer_label) + if self.diag_state: + s += ", diag_state: %d" % (self.diag_state) + if self.class_label: + s += ", class_label: %d" % (self.class_label) + return s + + +class InputFeatures(object): + """A single set of features of data.""" + + def __init__(self, + input_ids, + input_ids_unmasked, + input_mask, + segment_ids, + start_pos=None, + end_pos=None, + values=None, + inform=None, + inform_slot=None, + refer_id=None, + diag_state=None, + class_label_id=None, + guid="NONE"): + self.guid = guid + self.input_ids = input_ids + self.input_ids_unmasked = input_ids_unmasked + self.input_mask = input_mask + self.segment_ids = segment_ids + self.start_pos = start_pos + self.end_pos = end_pos + self.values = values + self.inform = inform + self.inform_slot = inform_slot + self.refer_id = refer_id + self.diag_state = diag_state + self.class_label_id = class_label_id + + +def convert_examples_to_features(examples, slot_list, class_types, model_type, tokenizer, max_seq_length, slot_value_dropout=0.0): + """Loads a data file into a list of `InputBatch`s.""" + + if model_type == 'bert': + model_specs = {'MODEL_TYPE': 'bert', + 'CLS_TOKEN': '[CLS]', + 'UNK_TOKEN': '[UNK]', + 'SEP_TOKEN': '[SEP]', + 'TOKEN_CORRECTION': 4} + else: + logger.error("Unknown model type (%s). Aborting." 
% (model_type)) + exit(1) + + def _tokenize_text_and_label(text, text_label_dict, slot, tokenizer, model_specs, slot_value_dropout): + joint_text_label = [0 for _ in text_label_dict[slot]] # joint all slots' label + for slot_text_label in text_label_dict.values(): + for idx, label in enumerate(slot_text_label): + if label == 1: + joint_text_label[idx] = 1 + + text_label = text_label_dict[slot] + tokens = [] + tokens_unmasked = [] + token_labels = [] + for token, token_label, joint_label in zip(text, text_label, joint_text_label): + token = convert_to_unicode(token) + sub_tokens = tokenizer.tokenize(token) # Most time intensive step + tokens_unmasked.extend(sub_tokens) + if slot_value_dropout == 0.0 or joint_label == 0: + tokens.extend(sub_tokens) + else: + rn_list = np.random.random_sample((len(sub_tokens),)) + for rn, sub_token in zip(rn_list, sub_tokens): + if rn > slot_value_dropout: + tokens.append(sub_token) + else: + tokens.append(model_specs['UNK_TOKEN']) + token_labels.extend([token_label for _ in sub_tokens]) + assert len(tokens) == len(token_labels) + assert len(tokens_unmasked) == len(token_labels) + return tokens, tokens_unmasked, token_labels + + def _truncate_seq_pair(tokens_a, tokens_b, history, max_length): + """Truncates a sequence pair in place to the maximum length. + Copied from bert/run_classifier.py + """ + # This is a simple heuristic which will always truncate the longer sequence + # one token at a time. This makes more sense than truncating an equal percent + # of tokens from each, since if one sequence is very short then each token + # that's truncated likely contains more information than a longer sequence. + while True: + total_length = len(tokens_a) + len(tokens_b) + len(history) + if total_length <= max_length: + break + if len(history) > 0: + history.pop() + elif len(tokens_a) > len(tokens_b): + tokens_a.pop() + else: + tokens_b.pop() + + def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, model_specs, guid): + # Modifies `tokens_a` and `tokens_b` in place so that the total + # length is less than the specified length. + # Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT) + if len(tokens_a) + len(tokens_b) + len(history) > max_seq_length - model_specs['TOKEN_CORRECTION']: + logger.info("Truncate Example %s. Total len=%d." % (guid, len(tokens_a) + len(tokens_b) + len(history))) + input_text_too_long = True + else: + input_text_too_long = False + _truncate_seq_pair(tokens_a, tokens_b, history, max_seq_length - model_specs['TOKEN_CORRECTION']) + return input_text_too_long + + def _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs): + token_label_ids = [] + token_label_ids.append(0) # [CLS] + for token_label in token_labels_a: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + for token_label in token_labels_b: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + for token_label in token_labels_history: + token_label_ids.append(token_label) + token_label_ids.append(0) # [SEP] + while len(token_label_ids) < max_seq_length: + token_label_ids.append(0) # padding + assert len(token_label_ids) == max_seq_length + return token_label_ids + + def _get_start_end_pos(class_type, token_label_ids, max_seq_length): + if class_type == 'copy_value' and 1 not in token_label_ids: + #logger.warn("copy_value label, but token_label not detected. 
Setting label to 'none'.") + class_type = 'none' + start_pos = 0 + end_pos = 0 + if 1 in token_label_ids: + start_pos = token_label_ids.index(1) + # Parsing is supposed to find only first location of wanted value + if 0 not in token_label_ids[start_pos:]: + end_pos = len(token_label_ids[start_pos:]) + start_pos - 1 + else: + end_pos = token_label_ids[start_pos:].index(0) + start_pos - 1 + for i in range(max_seq_length): + if i >= start_pos and i <= end_pos: + assert token_label_ids[i] == 1 + return class_type, start_pos, end_pos + + def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, tokenizer, model_specs): + # The convention in BERT is: + # (a) For sequence pairs: + # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] + # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 + # (b) For single sequences: + # tokens: [CLS] the dog is hairy . [SEP] + # type_ids: 0 0 0 0 0 0 0 + # + # Where "type_ids" are used to indicate whether this is the first + # sequence or the second sequence. The embedding vectors for `type=0` and + # `type=1` were learned during pre-training and are added to the wordpiece + # embedding vector (and position vector). This is not *strictly* necessary + # since the [SEP] token unambiguously separates the sequences, but it makes + # it easier for the model to learn the concept of sequences. + # + # For classification tasks, the first vector (corresponding to [CLS]) is + # used as the "sentence vector". Note that this only makes sense because + # the entire model is fine-tuned. + tokens = [] + segment_ids = [] + tokens.append(model_specs['CLS_TOKEN']) + segment_ids.append(0) + for token in tokens_a: + tokens.append(token) + segment_ids.append(0) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(0) + for token in tokens_b: + tokens.append(token) + segment_ids.append(1) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(1) + for token in history: + tokens.append(token) + segment_ids.append(1) + tokens.append(model_specs['SEP_TOKEN']) + segment_ids.append(1) + input_ids = tokenizer.convert_tokens_to_ids(tokens) + # The mask has 1 for real tokens and 0 for padding tokens. Only real + # tokens are attended to. + input_mask = [1] * len(input_ids) + # Zero-pad up to the sequence length. 
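+        # For example, with max_seq_length=12 and tokens
+        #     [CLS] a1 a2 [SEP] b1 [SEP] h1 [SEP]
+        # the padded outputs are
+        #     input_mask:  1 1 1 1 1 1 1 1 0 0 0 0
+        #     segment_ids: 0 0 0 0 1 1 1 1 0 0 0 0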
+ while len(input_ids) < max_seq_length: + input_ids.append(0) + input_mask.append(0) + segment_ids.append(0) + assert len(input_ids) == max_seq_length + assert len(input_mask) == max_seq_length + assert len(segment_ids) == max_seq_length + return tokens, input_ids, input_mask, segment_ids + + total_cnt = 0 + too_long_cnt = 0 + + refer_list = ['none'] + slot_list + + features = [] + # Convert single example + for (example_index, example) in enumerate(examples): + if example_index % 1000 == 0: + logger.info("Writing example %d of %d" % (example_index, len(examples))) + + total_cnt += 1 + + value_dict = {} + inform_dict = {} + inform_slot_dict = {} + refer_id_dict = {} + diag_state_dict = {} + class_label_id_dict = {} + start_pos_dict = {} + end_pos_dict = {} + for slot in slot_list: + tokens_a, tokens_a_unmasked, token_labels_a = _tokenize_text_and_label( + example.text_a, example.text_a_label, slot, tokenizer, model_specs, slot_value_dropout) + tokens_b, tokens_b_unmasked, token_labels_b = _tokenize_text_and_label( + example.text_b, example.text_b_label, slot, tokenizer, model_specs, slot_value_dropout) + tokens_history, tokens_history_unmasked, token_labels_history = _tokenize_text_and_label( + example.history, example.history_label, slot, tokenizer, model_specs, slot_value_dropout) + + input_text_too_long = _truncate_length_and_warn( + tokens_a, tokens_b, tokens_history, max_seq_length, model_specs, example.guid) + + if input_text_too_long: + if example_index < 10: + if len(token_labels_a) > len(tokens_a): + logger.info(' tokens_a truncated labels: %s' % str(token_labels_a[len(tokens_a):])) + if len(token_labels_b) > len(tokens_b): + logger.info(' tokens_b truncated labels: %s' % str(token_labels_b[len(tokens_b):])) + if len(token_labels_history) > len(tokens_history): + logger.info(' tokens_history truncated labels: %s' % str(token_labels_history[len(tokens_history):])) + + token_labels_a = token_labels_a[:len(tokens_a)] + token_labels_b = token_labels_b[:len(tokens_b)] + token_labels_history = token_labels_history[:len(tokens_history)] + tokens_a_unmasked = tokens_a_unmasked[:len(tokens_a)] + tokens_b_unmasked = tokens_b_unmasked[:len(tokens_b)] + tokens_history_unmasked = tokens_history_unmasked[:len(tokens_history)] + + assert len(token_labels_a) == len(tokens_a) + assert len(token_labels_b) == len(tokens_b) + assert len(token_labels_history) == len(tokens_history) + assert len(token_labels_a) == len(tokens_a_unmasked) + assert len(token_labels_b) == len(tokens_b_unmasked) + assert len(token_labels_history) == len(tokens_history_unmasked) + token_label_ids = _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs) + + value_dict[slot] = example.values[slot] + inform_dict[slot] = example.inform_label[slot] + + class_label_mod, start_pos_dict[slot], end_pos_dict[slot] = _get_start_end_pos( + example.class_label[slot], token_label_ids, max_seq_length) + if class_label_mod != example.class_label[slot]: + example.class_label[slot] = class_label_mod + inform_slot_dict[slot] = example.inform_slot_label[slot] + refer_id_dict[slot] = refer_list.index(example.refer_label[slot]) + diag_state_dict[slot] = class_types.index(example.diag_state[slot]) + class_label_id_dict[slot] = class_types.index(example.class_label[slot]) + + if input_text_too_long: + too_long_cnt += 1 + + tokens, input_ids, input_mask, segment_ids = _get_transformer_input(tokens_a, + tokens_b, + tokens_history, + max_seq_length, + tokenizer, + model_specs) + if slot_value_dropout > 
0.0: + _, input_ids_unmasked, _, _ = _get_transformer_input(tokens_a_unmasked, + tokens_b_unmasked, + tokens_history_unmasked, + max_seq_length, + tokenizer, + model_specs) + else: + input_ids_unmasked = input_ids + + assert(len(input_ids) == len(input_ids_unmasked)) + + if example_index < 10: + logger.info("*** Example ***") + logger.info("guid: %s" % (example.guid)) + logger.info("tokens: %s" % " ".join(tokens)) + logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) + logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) + logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) + logger.info("start_pos: %s" % str(start_pos_dict)) + logger.info("end_pos: %s" % str(end_pos_dict)) + logger.info("values: %s" % str(value_dict)) + logger.info("inform: %s" % str(inform_dict)) + logger.info("inform_slot: %s" % str(inform_slot_dict)) + logger.info("refer_id: %s" % str(refer_id_dict)) + logger.info("diag_state: %s" % str(diag_state_dict)) + logger.info("class_label_id: %s" % str(class_label_id_dict)) + + features.append( + InputFeatures( + guid=example.guid, + input_ids=input_ids, + input_ids_unmasked=input_ids_unmasked, + input_mask=input_mask, + segment_ids=segment_ids, + start_pos=start_pos_dict, + end_pos=end_pos_dict, + values=value_dict, + inform=inform_dict, + inform_slot=inform_slot_dict, + refer_id=refer_id_dict, + diag_state=diag_state_dict, + class_label_id=class_label_id_dict)) + + logger.info("========== %d out of %d examples have text too long" % (too_long_cnt, total_cnt)) + + return features + + +# From bert.tokenization (TF code) +def convert_to_unicode(text): + """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" + if six.PY3: + if isinstance(text, str): + return text + elif isinstance(text, bytes): + return text.decode("utf-8", "ignore") + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + elif six.PY2: + if isinstance(text, str): + return text.decode("utf-8", "ignore") + elif isinstance(text, unicode): + return text + else: + raise ValueError("Unsupported string type: %s" % (type(text))) + else: + raise ValueError("Not running on Python2 or Python 3?")
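
A minimal usage sketch (not part of the patch itself) of how the pieces above fit together, mirroring what load_and_cache_examples() in run_dst.py does: build a DSTExample, convert it with convert_examples_to_features(), group the per-slot tensors into dicts, and wrap everything in a TensorListDataset for a standard DataLoader. The slot list, class types and the hand-labeled toy turn below are illustrative assumptions only (they do not correspond to any dataset_config in the repo), and the tokenizer is the Hugging Face BERT tokenizer that the example scripts load for bert-base-uncased (imported here from transformers; the repo itself resolves it via MODEL_CLASSES from pytorch_transformers).

import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer  # assumption: same class the repo loads through MODEL_CLASSES

from tensorlistdataset import TensorListDataset
from utils_dst import DSTExample, convert_examples_to_features

# Assumed, illustrative task definition (not a real dataset_config).
slot_list = ['area', 'food']
class_types = ['none', 'dontcare', 'copy_value', 'inform']

# One toy turn: system utterance (text_a), user utterance (text_b), empty history.
# 'north' in the user utterance is the value of slot 'area', hence its token label is 1.
text_a = ['what', 'part', 'of', 'town', '?']
text_b = ['somewhere', 'in', 'the', 'north', 'please']
example = DSTExample(
    guid='toy-0',
    text_a=text_a,
    text_b=text_b,
    history=[],
    text_a_label={s: [0] * len(text_a) for s in slot_list},
    text_b_label={'area': [0, 0, 0, 1, 0], 'food': [0] * len(text_b)},
    history_label={s: [] for s in slot_list},
    values={'area': 'north', 'food': 'none'},
    inform_label={s: 'none' for s in slot_list},
    inform_slot_label={s: 0 for s in slot_list},
    refer_label={s: 'none' for s in slot_list},
    diag_state={s: 'none' for s in slot_list},
    class_label={'area': 'copy_value', 'food': 'none'})

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
features = convert_examples_to_features(examples=[example],
                                        slot_list=slot_list,
                                        class_types=class_types,
                                        model_type='bert',
                                        tokenizer=tokenizer,
                                        max_seq_length=32,
                                        slot_value_dropout=0.0)

# Group per-slot values into dicts of tensors, as load_and_cache_examples() does.
input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
start_pos = {s: torch.tensor([f.start_pos[s] for f in features], dtype=torch.long) for s in slot_list}
end_pos = {s: torch.tensor([f.end_pos[s] for f in features], dtype=torch.long) for s in slot_list}
class_label_ids = {s: torch.tensor([f.class_label_id[s] for f in features], dtype=torch.long) for s in slot_list}

dataset = TensorListDataset(input_ids, input_mask, segment_ids,
                            start_pos, end_pos, class_label_ids)
for batch in DataLoader(dataset, batch_size=1):
    ids, mask, segments, starts, ends, classes = batch
    # starts['area'] and ends['area'] point at the WordPiece span of 'north';
    # classes['area'] is class_types.index('copy_value').
    print(ids.shape, starts['area'].item(), ends['area'].item(), classes['area'].item())

Because TensorListDataset indexes dicts element-wise, the default DataLoader collation returns one dict per per-slot group, so downstream code can look up a batch's tensors by slot name without flattening the slot dimension.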