Commit 2e2a8b87 authored by heck

initial

parent eea0be40
#!/bin/bash
# Parameters ------------------------------------------------------
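# Select the task by uncommenting exactly one TASK/DATA_DIR pair below.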
#TASK="sim-m"
#DATA_DIR="data/simulated-dialogue/sim-M"
#TASK="sim-r"
#DATA_DIR="data/simulated-dialogue/sim-R"
TASK="woz2"
DATA_DIR="data/woz2"
#TASK="multiwoz21"
#DATA_DIR="data/MULTIWOZ2.1"
# Project paths etc. ----------------------------------------------
OUT_DIR=results
mkdir -p ${OUT_DIR}
# Main ------------------------------------------------------------
for step in train dev test; do
args_add=""
if [ "$step" = "train" ]; then
args_add="--do_train --predict_type=dummy"
elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
args_add="--do_eval --predict_type=${step}"
fi
python3 run_dst.py \
--task_name=${TASK} \
--data_dir=${DATA_DIR} \
--dataset_config=dataset_config/${TASK}.json \
--model_type="bert" \
--model_name_or_path="bert-base-uncased" \
--do_lower_case \
--learning_rate=1e-4 \
--num_train_epochs=10 \
--max_seq_length=180 \
--per_gpu_train_batch_size=48 \
--per_gpu_eval_batch_size=1 \
--output_dir=${OUT_DIR} \
--save_epochs=2 \
--logging_steps=10 \
--warmup_proportion=0.1 \
--eval_all_checkpoints \
--adam_epsilon=1e-6 \
--label_value_repetitions \
--swap_utterances \
--append_history \
--use_history_labels \
--delexicalize_sys_utts \
--class_aux_feats_inform \
--class_aux_feats_ds \
${args_add} \
2>&1 | tee ${OUT_DIR}/${step}.log
if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
python3 ~/tools/trippy_clean/metric_bert_dst.py \
${TASK} \
dataset_config/${TASK}.json \
"${OUT_DIR}/pred_res.${step}*json" \
2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
fi
done
#!/bin/bash
# Parameters ------------------------------------------------------
#TASK="sim-m"
#DATA_DIR="data/simulated-dialogue/sim-M"
#TASK="sim-r"
#DATA_DIR="data/simulated-dialogue/sim-R"
TASK="woz2"
DATA_DIR="data/woz2"
#TASK="multiwoz21"
#DATA_DIR="data/MULTIWOZ2.1"
# Project paths etc. ----------------------------------------------
OUT_DIR=results
mkdir -p ${OUT_DIR}
# Main ------------------------------------------------------------
for step in train dev test; do
args_add=""
if [ "$step" = "train" ]; then
args_add="--do_train --predict_type=dummy"
elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
args_add="--do_eval --predict_type=${step}"
fi
python3 run_dst.py \
--task_name=${TASK} \
--data_dir=${DATA_DIR} \
--dataset_config=dataset_config/${TASK}.json \
--model_type="bert" \
--model_name_or_path="bert-base-uncased" \
--do_lower_case \
--learning_rate=1e-4 \
--num_train_epochs=10 \
--max_seq_length=180 \
--per_gpu_train_batch_size=48 \
--per_gpu_eval_batch_size=1 \
--output_dir=${OUT_DIR} \
--save_epochs=2 \
--logging_steps=10 \
--warmup_proportion=0.1 \
--eval_all_checkpoints \
--adam_epsilon=1e-6 \
--label_value_repetitions \
${args_add} \
2>&1 | tee ${OUT_DIR}/${step}.log
if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
python3 ~/tools/trippy_clean/metric_bert_dst.py \
${TASK} \
dataset_config/${TASK}.json \
"${OUT_DIR}/pred_res.${step}*json" \
2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
fi
done
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import dataset_woz2
import dataset_sim
import dataset_multiwoz21
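# A dataset config (see the dataset_config/*.json files below) provides the keys
# 'class_types', 'slots' and 'label_maps' that DataProcessor reads.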
class DataProcessor(object):
def __init__(self, dataset_config):
with open(dataset_config, "r", encoding='utf-8') as f:
raw_config = json.load(f)
self.class_types = raw_config['class_types']
self.slot_list = raw_config['slots']
self.label_maps = raw_config['label_maps']
def get_train_examples(self, data_dir, **args):
raise NotImplementedError()
def get_dev_examples(self, data_dir, **args):
raise NotImplementedError()
def get_test_examples(self, data_dir, **args):
raise NotImplementedError()
class Woz2Processor(DataProcessor):
def get_train_examples(self, data_dir, args):
return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_train_en.json'),
'train', self.slot_list, self.label_maps, **args)
def get_dev_examples(self, data_dir, args):
return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_validate_en.json'),
'dev', self.slot_list, self.label_maps, **args)
def get_test_examples(self, data_dir, args):
return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_test_en.json'),
'test', self.slot_list, self.label_maps, **args)
class Multiwoz21Processor(DataProcessor):
def get_train_examples(self, data_dir, args):
return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'train_dials.json'),
os.path.join(data_dir, 'dialogue_acts.json'),
'train', self.slot_list, self.label_maps, **args)
def get_dev_examples(self, data_dir, args):
return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'val_dials.json'),
os.path.join(data_dir, 'dialogue_acts.json'),
'dev', self.slot_list, self.label_maps, **args)
def get_test_examples(self, data_dir, args):
return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'test_dials.json'),
os.path.join(data_dir, 'dialogue_acts.json'),
'test', self.slot_list, self.label_maps, **args)
class SimProcessor(DataProcessor):
def get_train_examples(self, data_dir, args):
return dataset_sim.create_examples(os.path.join(data_dir, 'train.json'),
'train', self.slot_list, **args)
def get_dev_examples(self, data_dir, args):
return dataset_sim.create_examples(os.path.join(data_dir, 'dev.json'),
'dev', self.slot_list, **args)
def get_test_examples(self, data_dir, args):
return dataset_sim.create_examples(os.path.join(data_dir, 'test.json'),
'test', self.slot_list, **args)
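# Maps task names (the --task_name values used in the DO.example scripts) to their processors.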
PROCESSORS = {"woz2": Woz2Processor,
"sim-m": SimProcessor,
"sim-r": SimProcessor,
"multiwoz21": Multiwoz21Processor}
{
"class_types": [
"none",
"dontcare",
"copy_value",
"true",
"false",
"refer",
"inform"
],
"slots": [
"taxi-leaveAt",
"taxi-destination",
"taxi-departure",
"taxi-arriveBy",
"restaurant-book_people",
"restaurant-book_day",
"restaurant-book_time",
"restaurant-food",
"restaurant-pricerange",
"restaurant-name",
"restaurant-area",
"hotel-book_people",
"hotel-book_day",
"hotel-book_stay",
"hotel-name",
"hotel-area",
"hotel-parking",
"hotel-pricerange",
"hotel-stars",
"hotel-internet",
"hotel-type",
"attraction-type",
"attraction-name",
"attraction-area",
"train-book_people",
"train-leaveAt",
"train-destination",
"train-day",
"train-arriveBy",
"train-departure"
],
"label_maps": {
"guest house": [
"guest houses"
],
"hotel": [
"hotels"
],
"centre": [
"center",
"downtown"
],
"north": [
"northern",
"northside",
"northend"
],
"east": [
"eastern",
"eastside",
"eastend"
],
"west": [
"western",
"westside",
"westend"
],
"south": [
"southern",
"southside",
"southend"
],
"cheap": [
"inexpensive",
"lower price",
"lower range",
"cheaply",
"cheaper",
"cheapest",
"very affordable"
],
"moderate": [
"moderately",
"reasonable",
"reasonably",
"affordable",
"mid range",
"mid-range",
"priced moderately",
"decently priced",
"mid price",
"mid-price",
"mid priced",
"mid-priced",
"middle price",
"medium price",
"medium priced",
"not too expensive",
"not too cheap"
],
"expensive": [
"high end",
"high-end",
"high class",
"high-class",
"high scale",
"high-scale",
"high price",
"high priced",
"higher price",
"fancy",
"upscale",
"nice",
"expensively",
"luxury"
],
"0": [
"zero"
],
"1": [
"one",
"just me",
"for me",
"myself",
"alone",
"me"
],
"2": [
"two"
],
"3": [
"three"
],
"4": [
"four"
],
"5": [
"five"
],
"6": [
"six"
],
"7": [
"seven"
],
"8": [
"eight"
],
"9": [
"nine"
],
"10": [
"ten"
],
"11": [
"eleven"
],
"12": [
"twelve"
],
"architecture": [
"architectural",
"architecturally",
"architect"
],
"boat": [
"boating",
"boats",
"camboats"
],
"boating": [
"boat",
"boats",
"camboats"
],
"camboats": [
"boating",
"boat",
"boats"
],
"cinema": [
"cinemas",
"movie",
"films",
"film"
],
"college": [
"colleges"
],
"concert": [
"concert hall",
"concert halls",
"concerthall",
"concerthalls",
"concerts"
],
"concerthall": [
"concert hall",
"concert halls",
"concerthalls",
"concerts",
"concert"
],
"entertainment": [
"entertaining"
],
"gallery": [
"museum"
],
"gastropubs": [
"gastropub"
],
"multiple sports": [
"multiple sport",
"multi sport",
"multi sports",
"sports",
"sporting"
],
"museum": [
"museums",
"gallery",
"galleries"
],
"night club": [
"night clubs",
"nightclub",
"nightclubs",
"club",
"clubs"
],
"park": [
"parks"
],
"pool": [
"swimming pool",
"swimming",
"pools",
"swimmingpool",
"water",
"swim"
],
"sports": [
"multiple sport",
"multi sport",
"multi sports",
"multiple sports",
"sporting"
],
"swimming pool": [
"swimming",
"pool",
"pools",
"swimmingpool",
"water",
"swim"
],
"theater": [
"theatre",
"theatres",
"theaters"
],
"theatre": [
"theater",
"theatres",
"theaters"
],
"abbey pool and astroturf pitch": [
"abbey pool and astroturf",
"abbey pool"
],
"abbey pool and astroturf": [
"abbey pool and astroturf pitch",
"abbey pool"
],
"abbey pool": [
"abbey pool and astroturf pitch",
"abbey pool and astroturf"
],
"adc theatre": [
"adc theater",
"adc"
],
"adc": [
"adc theatre",
"adc theater"
],
"addenbrookes hospital": [
"addenbrooke's hospital"
],
"cafe jello gallery": [
"cafe jello"
],
"cambridge and county folk museum": [
"cambridge and country folk museum",
"county folk museum"
],
"cambridge and country folk museum": [
"cambridge and county folk museum",
"county folk museum"
],
"county folk museum": [
"cambridge and county folk museum",
"cambridge and country folk museum"
],
"cambridge arts theatre": [
"cambridge arts theater"
],
"cambridge book and print gallery": [
"book and print gallery"
],
"cambridge contemporary art": [
"cambridge contemporary art museum",
"contemporary art museum"
],
"cambridge contemporary art museum": [
"cambridge contemporary art",
"contemporary art museum"
],
"cambridge corn exchange": [
"the cambridge corn exchange"
],
"the cambridge corn exchange": [
"cambridge corn exchange"
],
"cambridge museum of technology": [
"museum of technology"
],
"cambridge punter": [
"the cambridge punter",
"cambridge punters"
],
"cambridge punters": [
"the cambridge punter",
"cambridge punter"
],
"the cambridge punter": [
"cambridge punter",
"cambridge punters"
],
"cambridge university botanic gardens": [
"cambridge university botanical gardens",
"cambridge university botanical garden",
"cambridge university botanic garden",
"cambridge botanic gardens",
"cambridge botanical gardens",
"cambridge botanic garden",
"cambridge botanical garden",
"botanic gardens",
"botanical gardens",
"botanic garden",
"botanical garden"
],
"cambridge botanic gardens": [
"cambridge university botanic gardens",
"cambridge university botanical gardens",
"cambridge university botanical garden",
"cambridge university botanic garden",
"cambridge botanical gardens",
"cambridge botanic garden",
"cambridge botanical garden",
"botanic gardens",
"botanical gardens",
"botanic garden",
"botanical garden"
],
"botanic gardens": [
"cambridge university botanic gardens",
"cambridge university botanical gardens",
"cambridge university botanical garden",
"cambridge university botanic garden",
"cambridge botanic gardens",
"cambridge botanical gardens",
"cambridge botanic garden",
"cambridge botanical garden",
"botanical gardens",
"botanic garden",
"botanical garden"
],
"cherry hinton village centre": [
"cherry hinton village center"
],
"cherry hinton village center": [
"cherry hinton village centre"
],
"cherry hinton hall and grounds": [
"cherry hinton hall"
],
"cherry hinton hall": [
"cherry hinton hall and grounds"
],
"cherry hinton water play": [
"cherry hinton water play park"
],
"cherry hinton water play park": [
"cherry hinton water play"
],
"christ college": [
"christ's college",
"christs college"
],
"christs college": [
"christ college",
"christ's college"
],
"churchills college": [
"churchill's college",
"churchill college"
],
"cineworld cinema": [
"cineworld"
],
"clair hall": [
"clare hall"
],
"clare hall": [
"clair hall"
],
"the fez club": [
"fez club"
],
"great saint marys church": [
"great saint mary's church",
"great saint mary's",
"great saint marys"
],
"jesus green outdoor pool": [
"jesus green"
],
"jesus green": [
"jesus green outdoor pool"
],
"kettles yard": [
"kettle's yard"
],
"kings college": [
"king's college"
],
"kings hedges learner pool": [
"king's hedges learner pool",
"king hedges learner pool"
],
"king hedges learner pool": [
"king's hedges learner pool",
"kings hedges learner pool"
],
"little saint marys church": [
"little saint mary's church",
"little saint mary's",
"little saint marys"
],
"mumford theatre": [
"mumford theater"
],
"museum of archaelogy": [
"museum of archaeology"
],
"museum of archaelogy and anthropology": [
"museum of archaeology and anthropology"
],
"peoples portraits exhibition": [
"people's portraits exhibition at girton college",
"peoples portraits exhibition at girton college",
"people's portraits exhibition"
],
"peoples portraits exhibition at girton college": [
"people's portraits exhibition at girton college",
"people's portraits exhibition",
"peoples portraits exhibition"
],
"queens college": [
"queens' college",
"queen's college"
],
"riverboat georgina": [
"riverboat"
],
"saint barnabas": [
"saint barbabas press gallery"
],
"saint barnabas press gallery": [
"saint barbabas"
],
"saint catharines college": [
"saint catharine's college",
"saint catharine's"
],
"saint johns college": [
"saint john's college",
"st john's college",
"st johns college"
],
"scott polar": [
"scott polar museum"
],
"scott polar museum": [
"scott polar"
],
"scudamores punting co": [
"scudamore's punting co",
"scudamores punting",
"scudamore's punting",
"scudamores",
"scudamore's",
"scudamore"
],
"scudamore": [
"scudamore's punting co",
"scudamores punting co",
"scudamores punting",
"scudamore's punting",
"scudamores",
"scudamore's"
],
"sheeps green and lammas land park fen causeway": [
"sheep's green and lammas land park fen causeway",
"sheep's green and lammas land park",
"sheeps green and lammas land park",
"lammas land park",
"sheep's green",
"sheeps green"
],
"sheeps green and lammas land park": [
"sheep's green and lammas land park fen causeway",
"sheeps green and lammas land park fen causeway",
"sheep's green and lammas land park",
"lammas land park",
"sheep's green",
"sheeps green"
],
"lammas land park": [
"sheep's green and lammas land park fen causeway",
"sheeps green and lammas land park fen causeway",
"sheep's green and lammas land park",
"sheeps green and lammas land park",
"sheep's green",
"sheeps green"
],
"sheeps green": [
"sheep's green and lammas land park fen causeway",
"sheeps green and lammas land park fen causeway",
"sheep's green and lammas land park",
"sheeps green and lammas land park",
"lammas land park",
"sheep's green"
],
"soul tree nightclub": [
"soul tree night club",
"soul tree",
"soultree"
],
"soultree": [
"soul tree nightclub",
"soul tree night club",
"soul tree"
],
"the man on the moon": [
"man on the moon"
],
"man on the moon": [
"the man on the moon"
],
"the junction": [
"junction theatre",
"junction theater"
],
"junction theatre": [
"the junction",
"junction theater"
],
"old schools": [
"old school"
],
"vue cinema": [
"vue"
],
"wandlebury country park": [
"the wandlebury"
],
"the wandlebury": [
"wandlebury country park"
],
"whipple museum of the history of science": [
"whipple museum",
"history of science museum"
],
"history of science museum": [
"whipple museum of the history of science",
"whipple museum"
],
"williams art and antique": [
"william's art and antique"
],
"alimentum": [
"restaurant alimentum"
],
"restaurant alimentum": [
"alimentum"
],
"bedouin": [
"the bedouin"
],
"the bedouin": [
"bedouin"
],
"bloomsbury restaurant": [
"bloomsbury"
],
"cafe uno": [
"caffe uno",
"caffee uno"
],
"caffe uno": [
"cafe uno",
"caffee uno"
],
"caffee uno": [
"cafe uno",
"caffe uno"
],
"cambridge lodge restaurant": [
"cambridge lodge"
],
"chiquito": [
"chiquito restaurant bar",
"chiquito restaurant"
],
"chiquito restaurant bar": [
"chiquito restaurant",
"chiquito"
],
"city stop restaurant": [
"city stop"
],
"cityr": [
"cityroomz"
],
"citiroomz": [
"cityroomz"
],
"clowns cafe": [
"clown's cafe"
],
"cow pizza kitchen and bar": [
"the cow pizza kitchen and bar",
"cow pizza"
],
"the cow pizza kitchen and bar": [
"cow pizza kitchen and bar",
"cow pizza"
],
"darrys cookhouse and wine shop": [
"darry's cookhouse and wine shop",
"darry's cookhouse",
"darrys cookhouse"
],
"de luca cucina and bar": [
"de luca cucina and bar riverside brasserie",
"luca cucina and bar",
"de luca cucina",
"luca cucina"
],
"de luca cucina and bar riverside brasserie": [
"de luca cucina and bar",
"luca cucina and bar",
"de luca cucina",
"luca cucina"
],
"da vinci pizzeria": [
"da vinci pizza",
"da vinci"
],
"don pasquale pizzeria": [
"don pasquale pizza",
"don pasquale",
"pasquale pizzeria",
"pasquale pizza"
],
"efes": [
"efes restaurant"
],
"efes restaurant": [
"efes"
],
"fitzbillies restaurant": [
"fitzbillies"
],
"frankie and bennys": [
"frankie and benny's"
],
"funky": [
"funky fun house"
],
"funky fun house": [
"funky"
],
"gardenia": [
"the gardenia"
],
"the gardenia": [
"gardenia"
],
"grafton hotel restaurant": [
"the grafton hotel",
"grafton hotel"
],
"the grafton hotel": [
"grafton hotel restaurant",
"grafton hotel"
],
"grafton hotel": [
"grafton hotel restaurant",
"the grafton hotel"
],
"hotel du vin and bistro": [
"hotel du vin",
"du vin"
],
"Kohinoor": [
"kohinoor",
"the kohinoor"
],
"kohinoor": [
"the kohinoor"
],
"the kohinoor": [
"kohinoor"
],
"lan hong house": [
"lan hong",
"ian hong house",
"ian hong"
],
"ian hong": [
"lan hong house",
"lan hong",
"ian hong house"
],
"lovel": [
"the lovell lodge",
"lovell lodge"
],
"lovell lodge": [
"lovell"
],
"the lovell lodge": [
"lovell lodge",
"lovell"
],
"mahal of cambridge": [
"mahal"
],
"mahal": [
"mahal of cambridge"
],
"maharajah tandoori restaurant": [
"maharajah tandoori"
],
"the maharajah tandoor": [
"maharajah tandoori restaurant",
"maharajah tandoori"
],
"meze bar": [
"meze bar restaurant",
"the meze bar"
],
"meze bar restaurant": [
"the meze bar",
"meze bar"
],
"the meze bar": [
"meze bar restaurant",
"meze bar"
],
"michaelhouse cafe": [
"michael house cafe"
],
"midsummer house restaurant": [
"midsummer house"
],
"missing sock": [
"the missing sock"
],
"the missing sock": [
"missing sock"
],
"nandos": [
"nando's city centre",
"nando's city center",
"nandos city centre",
"nandos city center",
"nando's"
],
"nandos city centre": [
"nando's city centre",
"nando's city center",
"nandos city center",
"nando's",
"nandos"
],
"oak bistro": [
"the oak bistro"
],
"the oak bistro": [
"oak bistro"
],
"one seven": [
"restaurant one seven"
],
"restaurant one seven": [
"one seven"
],
"river bar steakhouse and grill": [
"the river bar steakhouse and grill",
"the river bar steakhouse",
"river bar steakhouse"
],
"the river bar steakhouse and grill": [
"river bar steakhouse and grill",
"the river bar steakhouse",
"river bar steakhouse"
],
"pipasha restaurant": [
"pipasha"
],
"pizza hut city centre": [
"pizza hut city center"
],
"pizza hut fenditton": [
"pizza hut fen ditton",
"pizza express fen ditton"
],
"restaurant two two": [
"two two",
"restaurant 22"
],
"saffron brasserie": [
"saffron"
],
"saint johns chop house": [
"saint john's chop house",
"st john's chop house",
"st johns chop house"
],
"sesame restaurant and bar": [
"sesame restaurant",
"sesame"
],
"shanghai": [
"shanghai family restaurant"
],
"shanghai family restaurant": [
"shanghai"
],
"sitar": [
"sitar tandoori"
],
"sitar tandoori": [
"sitar"
],
"slug and lettuce": [
"the slug and lettuce"
],
"the slug and lettuce": [
"slug and lettuce"
],
"st johns chop house": [
"saint john's chop house",
"st john's chop house",
"saint johns chop house"
],
"stazione restaurant and coffee bar": [
"stazione restaurant",
"stazione"
],
"thanh binh": [
"thanh",
"binh"
],
"thanh": [
"thanh binh",
"binh"
],
"binh": [
"thanh binh",
"thanh"
],
"the hotpot": [
"the hotspot",
"hotpot",
"hotspot"
],
"hotpot": [
"the hotpot",
"the hotpot",
"hotspot"
],
"the lucky star": [
"lucky star"
],
"lucky star": [
"the lucky star"
],
"the peking restaurant: ": [
"peking restaurant"
],
"the varsity restaurant": [
"varsity restaurant",
"the varsity",
"varsity"
],
"two two": [
"restaurant two two",
"restaurant 22"
],
"restaurant 22": [
"restaurant two two",
"two two"
],
"zizzi cambridge": [
"zizzi"
],
"american": [
"americas"
],
"asian oriental": [
"asian",
"oriental"
],
"australian": [
"australasian"
],
"barbeque": [
"barbecue",
"bbq"
],
"corsica": [
"corsican"
],
"indian": [
"tandoori"
],
"italian": [
"pizza",
"pizzeria"
],
"japanese": [
"sushi"
],
"latin american": [
"latin-american",
"latin"
],
"malaysian": [
"malay"
],
"middle eastern": [
"middle-eastern"
],
"traditional american": [
"american"
],
"modern american": [
"american modern",
"american"
],
"modern european": [
"european modern",
"european"
],
"north american": [
"north-american",
"american"
],
"portuguese": [
"portugese"
],
"portugese": [
"portuguese"
],
"seafood": [
"sea food"
],
"singaporean": [
"singapore"
],
"steakhouse": [
"steak house",
"steak"
],
"the americas": [
"american",
"americas"
],
"a and b guest house": [
"a & b guest house",
"a and b",
"a & b"
],
"the acorn guest house": [
"acorn guest house",
"acorn"
],
"acorn guest house": [
"the acorn guest house",
"acorn"
],
"alexander bed and breakfast": [
"alexander"
],
"allenbell": [
"the allenbell"
],
"the allenbell": [
"allenbell"
],
"alpha-milton guest house": [
"the alpha-milton",
"alpha-milton"
],
"the alpha-milton": [
"alpha-milton guest house",
"alpha-milton"
],
"arbury lodge guest house": [
"arbury lodge",
"arbury"
],
"archway house": [
"archway"
],
"ashley hotel": [
"the ashley hotel",
"ashley"
],
"the ashley hotel": [
"ashley hotel",
"ashley"
],
"aylesbray lodge guest house": [
"aylesbray lodge",
"aylesbray"
],
"aylesbray lodge guest": [
"aylesbray lodge guest house",
"aylesbray lodge",
"aylesbray"
],
"alesbray lodge guest house": [
"aylesbray lodge guest house",
"aylesbray lodge",
"aylesbray"
],
"alyesbray lodge hotel": [
"aylesbray lodge guest house",
"aylesbray lodge",
"aylesbray"
],
"bridge guest house": [
"bridge house"
],
"cambridge belfry": [
"the cambridge belfry",
"belfry hotel",
"belfry"
],
"the cambridge belfry": [
"cambridge belfry",
"belfry hotel",
"belfry"
],
"belfry hotel": [
"the cambridge belfry",
"cambridge belfry",
"belfry"
],
"carolina bed and breakfast": [
"carolina"
],
"city centre north": [
"city centre north bed and breakfast"
],
"north b and b": [
"city centre north bed and breakfast"
],
"city centre north b and b": [
"city centre north bed and breakfast"
],
"el shaddia guest house": [
"el shaddai guest house",
"el shaddai",
"el shaddia"
],
"el shaddai guest house": [
"el shaddia guest house",
"el shaddai",
"el shaddia"
],
"express by holiday inn cambridge": [
"express by holiday inn",
"holiday inn cambridge",
"holiday inn"
],
"holiday inn": [
"express by holiday inn cambridge",
"express by holiday inn",
"holiday inn cambridge"
],
"finches bed and breakfast": [
"finches"
],
"gonville hotel": [
"gonville"
],
"hamilton lodge": [
"the hamilton lodge",
"hamilton"
],
"the hamilton lodge": [
"hamilton lodge",
"hamilton"
],
"hobsons house": [
"hobson's house",
"hobson's"
],
"huntingdon marriott hotel": [
"huntington marriott hotel",
"huntington marriot hotel",
"huntingdon marriot hotel",
"huntington marriott",
"huntingdon marriott",
"huntington marriot",
"huntingdon marriot",
"huntington",
"huntingdon"
],
"kirkwood": [
"kirkwood house"
],
"kirkwood house": [
"kirkwood"
],
"lensfield hotel": [
"the lensfield hotel",
"lensfield"
],
"the lensfield hotel": [
"lensfield hotel",
"lensfield"
],
"leverton house": [
"leverton"
],
"marriot hotel": [
"marriott hotel",
"marriott"
],
"rosas bed and breakfast": [
"rosa's bed and breakfast",
"rosa's",
"rosas"
],
"university arms hotel": [
"university arms"
],
"warkworth house": [
"warkworth hotel",
"warkworth"
],
"warkworth hotel": [
"warkworth house",
"warkworth"
],
"wartworth": [
"warkworth house",
"warkworth hotel",
"warkworth"
],
"worth house": [
"the worth house"
],
"the worth house": [
"worth house"
],
"birmingham new street": [
"birmingham new street train station"
],
"birmingham new street train station": [
"birmingham new street"
],
"bishops stortford": [
"bishops stortford train station"
],
"bishops stortford train station": [
"bishops stortford"
],
"broxbourne": [
"broxbourne train station"
],
"broxbourne train station": [
"broxbourne"
],
"cambridge": [
"cambridge train station"
],
"cambridge train station": [
"cambridge"
],
"ely": [
"ely train station"
],
"ely train station": [
"ely"
],
"kings lynn": [
"king's lynn",
"king's lynn train station",
"kings lynn train station"
],
"kings lynn train station": [
"kings lynn",
"king's lynn",
"king's lynn train station"
],
"leicester": [
"leicester train station"
],
"leicester train station": [
"leicester"
],
"london kings cross": [
"kings cross",
"king's cross",
"london king's cross",
"kings cross train station",
"king's cross train station",
"london king's cross train station",
"london kings cross train station"
],
"london kings cross train station": [
"kings cross",
"king's cross",
"london king's cross",
"london kings cross",
"kings cross train station",
"king's cross train station",
"london king's cross train station"
],
"london liverpool": [
"liverpool street",
"london liverpool street",
"london liverpool train station",
"liverpool street train station",
"london liverpool street train station"
],
"london liverpool street": [
"london liverpool",
"liverpool street",
"london liverpool train station",
"liverpool street train station",
"london liverpool street train station"
],
"london liverpool street train station": [
"london liverpool",
"liverpool street",
"london liverpool street",
"london liverpool train station",
"liverpool street train station"
],
"norwich": [
"norwich train station"
],
"norwich train station": [
"norwich"
],
"peterborough": [
"peterborough train station"
],
"peterborough train station": [
"peterborough"
],
"stansted airport": [
"stansted airport train station"
],
"stansted airport train station": [
"stansted airport"
],
"stevenage": [
"stevenage train station"
],
"stevenage train station": [
"stevenage"
]
}
}
{
"class_types": [
"none",
"dontcare",
"copy_value",
"inform"
],
"slots": [
"date",
"movie",
"time",
"num_tickets",
"theatre_name"
],
"label_maps": {}
}
{
"class_types": [
"none",
"dontcare",
"copy_value",
"inform"
],
"slots": [
"category",
"rating",
"num_people",
"location",
"restaurant_name",
"time",
"date",
"price_range",
"meal"
],
"label_maps": {}
}
{
"class_types": [
"none",
"dontcare",
"copy_value",
"inform"
],
"slots": [
"area",
"food",
"price_range"
],
"label_maps": {
"center": [
"centre",
"downtown",
"central",
"down town",
"middle"
],
"south": [
"southern",
"southside"
],
"north": [
"northern",
"uptown",
"northside"
],
"west": [
"western",
"westside"
],
"east": [
"eastern",
"eastside"
],
"east side": [
"eastern",
"eastside"
],
"cheap": [
"low price",
"inexpensive",
"cheaper",
"low priced",
"affordable",
"nothing too expensive",
"without costing a fortune",
"cheapest",
"good deals",
"low prices",
"afford",
"on a budget",
"fair prices",
"less expensive",
"cheapeast",
"not cost an arm and a leg"
],
"moderate": [
"moderately",
"medium priced",
"medium price",
"fair price",
"fair prices",
"reasonable",
"reasonably priced",
"mid price",
"fairly priced",
"not outrageous",
"not too expensive",
"on a budget",
"mid range",
"reasonable priced",
"less expensive",
"not too pricey",
"nothing too expensive",
"nothing cheap",
"not overpriced",
"medium",
"inexpensive"
],
"expensive": [
"high priced",
"high end",
"high class",
"high quality",
"fancy",
"upscale",
"nice",
"fine dining",
"expensively priced"
],
"afghan": [
"afghanistan"
],
"african": [
"africa"
],
"asian oriental": [
"asian",
"oriental"
],
"australasian": [
"australian asian",
"austral asian"
],
"australian": [
"aussie"
],
"barbeque": [
"barbecue",
"bbq"
],
"basque": [
"bask"
],
"belgian": [
"belgium"
],
"british": [
"cotto"
],
"canapes": [
"canopy",
"canape",
"canap"
],
"catalan": [
"catalonian"
],
"corsican": [
"corsica"
],
"crossover": [
"cross over",
"over"
],
"gastropub": [
"gastro pub",
"gastro",
"gastropubs"
],
"hungarian": [
"goulash"
],
"indian": [
"india",
"indians",
"nirala"
],
"international": [
"all types of food"
],
"italian": [
"prezzo"
],
"jamaican": [
"jamaica"
],
"japanese": [
"sushi",
"beni hana"
],
"korean": [
"korea"
],
"lebanese": [
"lebanse"
],
"north american": [
"american",
"hamburger"
],
"portuguese": [
"portugese"
],
"seafood": [
"sea food",
"shellfish",
"fish"
],
"singaporean": [
"singapore"
],
"steakhouse": [
"steak house",
"steak"
],
"thai": [
"thailand",
"bangkok"
],
"traditional": [
"old fashioned",
"plain"
],
"turkish": [
"turkey"
],
"unusual": [
"unique and strange"
],
"venetian": [
"vanessa"
],
"vietnamese": [
"vietnam",
"thanh binh"
]
}
}
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from utils_dst import (DSTExample, convert_to_unicode)
# Required for mapping slot names in dialogue_acts.json file
# to proper designations.
ACTS_DICT = {'taxi-depart': 'taxi-departure',
'taxi-dest': 'taxi-destination',
'taxi-leave': 'taxi-leaveAt',
'taxi-arrive': 'taxi-arriveBy',
'train-depart': 'train-departure',
'train-dest': 'train-destination',
'train-leave': 'train-leaveAt',
'train-arrive': 'train-arriveBy',
'train-people': 'train-book_people',
'restaurant-price': 'restaurant-pricerange',
'restaurant-people': 'restaurant-book_people',
'restaurant-day': 'restaurant-book_day',
'restaurant-time': 'restaurant-book_time',
'hotel-price': 'hotel-pricerange',
'hotel-people': 'hotel-book_people',
'hotel-day': 'hotel-book_day',
'hotel-stay': 'hotel-book_stay',
'booking-people': 'booking-book_people',
'booking-day': 'booking-book_day',
'booking-stay': 'booking-book_stay',
'booking-time': 'booking-book_time',
}
LABEL_MAPS = {} # Loaded from file
# Loads dialogue_acts.json and returns a dict mapping (dialogue file, turn, slot)
# to the slot values mentioned in the system's dialogue acts.
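# Illustrative entry: ('MUL0001.json', '1', 'hotel-area') -> ['east'].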
def load_acts(input_file):
with open(input_file) as f:
acts = json.load(f)
s_dict = {}
for d in acts:
for t in acts[d]:
# Only process if the turn has an annotation
if isinstance(acts[d][t], dict):
for a in acts[d][t]:
aa = a.lower().split('-')
if aa[1] == 'inform' or aa[1] == 'recommend' or aa[1] == 'select' or aa[1] == 'book':
for i in acts[d][t][a]:
s = i[0].lower()
v = i[1].lower().strip()
if s == 'none' or v == '?' or v == 'none':
continue
slot = aa[0] + '-' + s
if slot in ACTS_DICT:
slot = ACTS_DICT[slot]
key = d + '.json', t, slot
# In case of multiple mentioned values...
# ... Option 1: Keep first informed value
if key not in s_dict:
s_dict[key] = list([v])
# ... Option 2: Keep last informed value
#s_dict[key] = list([v])
return s_dict
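# Examples: normalize_time("3pm") -> "15:00";
#           normalize_time("leaves at 1730") -> "leaves at 17:30".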
def normalize_time(text):
text = re.sub("(\d{1})(a\.?m\.?|p\.?m\.?)", r"\1 \2", text) # am/pm without space
text = re.sub("(^| )(\d{1,2}) (a\.?m\.?|p\.?m\.?)", r"\1\2:00 \3", text) # am/pm short to long form
text = re.sub("(^| )(at|from|by|until|after) ?(\d{1,2}) ?(\d{2})([^0-9]|$)", r"\1\2 \3:\4\5", text) # Missing separator
text = re.sub("(^| )(\d{2})[;.,](\d{2})", r"\1\2:\3", text) # Wrong separator
text = re.sub("(^| )(at|from|by|until|after) ?(\d{1,2})([;., ]|$)", r"\1\2 \3:00\4", text) # normalize simple full hour time
text = re.sub("(^| )(\d{1}:\d{2})", r"\g<1>0\2", text) # Add missing leading 0
# Map 12 hour times to 24 hour times
text = re.sub("(\d{2})(:\d{2}) ?p\.?m\.?", lambda x: str(int(x.groups()[0]) + 12 if int(x.groups()[0]) < 12 else int(x.groups()[0])) + x.groups()[1], text)
text = re.sub("(^| )24:(\d{2})", r"\g<1>00:\2", text) # Correct times that use 24 as hour
return text
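# Example: normalize_text("a four-star guesthouse") -> "a 4 star guest house".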
def normalize_text(text):
text = normalize_time(text)
text = re.sub("n't", " not", text)
text = re.sub("(^| )zero(-| )star([s.,? ]|$)", r"\g<1>0 star\3", text)
text = re.sub("(^| )one(-| )star([s.,? ]|$)", r"\g<1>1 star\3", text)
text = re.sub("(^| )two(-| )star([s.,? ]|$)", r"\g<1>2 star\3", text)
text = re.sub("(^| )three(-| )star([s.,? ]|$)", r"\g<1>3 star\3", text)
text = re.sub("(^| )four(-| )star([s.,? ]|$)", r"\g<1>4 star\3", text)
text = re.sub("(^| )five(-| )star([s.,? ]|$)", r"\g<1>5 star\3", text)
text = re.sub("archaelogy", "archaeology", text) # Systematic typo
text = re.sub("guesthouse", "guest house", text) # Normalization
text = re.sub("(^| )b ?& ?b([.,? ]|$)", r"\1bed and breakfast\2", text) # Normalization
text = re.sub("bed & breakfast", "bed and breakfast", text) # Normalization
return text
# This should only contain label normalizations. All other mappings should
# be defined in LABEL_MAPS.
def normalize_label(slot, value_label):
# Normalization of empty slots
if value_label == '' or value_label == "not mentioned":
return "none"
# Normalization of time slots
if "leaveAt" in slot or "arriveBy" in slot or slot == 'restaurant-book_time':
return normalize_time(value_label)
# Normalization
if "type" in slot or "name" in slot or "destination" in slot or "departure" in slot:
value_label = re.sub("guesthouse", "guest house", value_label)
# Map to boolean slots
if slot == 'hotel-parking' or slot == 'hotel-internet':
if value_label == 'yes' or value_label == 'free':
return "true"
if value_label == "no":
return "false"
if slot == 'hotel-type':
if value_label == "hotel":
return "true"
if value_label == "guest house":
return "false"
return value_label
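# Example: get_token_pos(['i', 'want', 'cheap', 'food'], 'cheap') -> (True, [(2, 3)]).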
def get_token_pos(tok_list, value_label):
find_pos = []
found = False
label_list = [item for item in map(str.strip, re.split("(\W+)", value_label)) if len(item) > 0]
len_label = len(label_list)
for i in range(len(tok_list) + 1 - len_label):
if tok_list[i:i + len_label] == label_list:
find_pos.append((i, i + len_label)) # start, exclusive_end
found = True
return found, find_pos
def check_label_existence(value_label, usr_utt_tok):
in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label)
# If no hit even though there should be one, check for value label variants
if not in_usr and value_label in LABEL_MAPS:
for value_label_variant in LABEL_MAPS[value_label]:
in_usr, usr_pos = get_token_pos(usr_utt_tok, value_label_variant)
if in_usr:
break
return in_usr, usr_pos
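# A value that cannot be found in the utterance may instead refer to a value already
# assigned to another slot earlier in the dialog (class type 'refer').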
def check_slot_referral(value_label, slot, seen_slots):
referred_slot = 'none'
if slot == 'hotel-stars' or slot == 'hotel-internet' or slot == 'hotel-parking':
return referred_slot
for s in seen_slots:
# Avoid matches for slots that share values with different meaning.
# hotel-internet and -parking are handled separately as Boolean slots.
if s == 'hotel-stars' or s == 'hotel-internet' or s == 'hotel-parking':
continue
if re.match("(hotel|restaurant)-book_people", s) and slot == 'hotel-book_stay':
continue
if re.match("(hotel|restaurant)-book_people", slot) and s == 'hotel-book_stay':
continue
if slot != s and (slot not in seen_slots or seen_slots[slot] != value_label):
if seen_slots[s] == value_label:
referred_slot = s
break
elif value_label in LABEL_MAPS:
for value_label_variant in LABEL_MAPS[value_label]:
if seen_slots[s] == value_label_variant:
referred_slot = s
break
return referred_slot
def is_in_list(tok, value):
found = False
tok_list = [item for item in map(str.strip, re.split("(\W+)", tok)) if len(item) > 0]
value_list = [item for item in map(str.strip, re.split("(\W+)", value)) if len(item) > 0]
tok_len = len(tok_list)
value_len = len(value_list)
for i in range(tok_len + 1 - value_len):
if tok_list[i:i + value_len] == value_list:
found = True
break
return found
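# Illustrative example: delex_utt('i want the acorn guest house', {'hotel-name': ['acorn guest house']})
# -> ['i', 'want', 'the', '[UNK]', '[UNK]', '[UNK]'].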
def delex_utt(utt, values):
utt_norm = tokenize(utt)
for s, vals in values.items():
for v in vals:
if v != 'none':
v_norm = tokenize(v)
v_len = len(v_norm)
for i in range(len(utt_norm) + 1 - v_len):
if utt_norm[i:i + v_len] == v_norm:
utt_norm[i:i + v_len] = ['[UNK]'] * v_len
return utt_norm
# Fuzzy matching to label informed slot values
def check_slot_inform(value_label, inform_label):
result = False
informed_value = 'none'
vl = ' '.join(tokenize(value_label))
for il in inform_label:
if vl == il:
result = True
elif is_in_list(il, vl):
result = True
elif is_in_list(vl, il):
result = True
elif il in LABEL_MAPS:
for il_variant in LABEL_MAPS[il]:
if vl == il_variant:
result = True
break
elif is_in_list(il_variant, vl):
result = True
break
elif is_in_list(vl, il_variant):
result = True
break
elif vl in LABEL_MAPS:
for value_label_variant in LABEL_MAPS[vl]:
if value_label_variant == il:
result = True
break
elif is_in_list(il, value_label_variant):
result = True
break
elif is_in_list(value_label_variant, il):
result = True
break
if result:
informed_value = il
break
return result, informed_value
def get_turn_label(value_label, inform_label, sys_utt_tok, usr_utt_tok, slot, seen_slots, slot_last_occurrence):
usr_utt_tok_label = [0 for _ in usr_utt_tok]
informed_value = 'none'
referred_slot = 'none'
if value_label == 'none' or value_label == 'dontcare' or value_label == 'true' or value_label == 'false':
class_type = value_label
else:
in_usr, usr_pos = check_label_existence(value_label, usr_utt_tok)
if in_usr:
class_type = 'copy_value'
if slot_last_occurrence:
(s, e) = usr_pos[-1]
for i in range(s, e):
usr_utt_tok_label[i] = 1
else:
for (s, e) in usr_pos:
for i in range(s, e):
usr_utt_tok_label[i] = 1
else:
is_informed, informed_value = check_slot_inform(value_label, inform_label)
if is_informed:
class_type = 'inform'
else:
referred_slot = check_slot_referral(value_label, slot, seen_slots)
if referred_slot != 'none':
class_type = 'refer'
else:
class_type = 'unpointable'
return informed_value, referred_slot, usr_utt_tok_label, class_type
def tokenize(utt):
utt_lower = convert_to_unicode(utt).lower()
utt_lower = normalize_text(utt_lower)
utt_tok = [tok for tok in map(str.strip, re.split("(\W+)", utt_lower)) if len(tok) > 0]
return utt_tok
def create_examples(input_file, acts_file, set_type, slot_list,
label_maps={},
append_history=False,
use_history_labels=False,
swap_utterances=False,
label_value_repetitions=False,
delexicalize_sys_utts=False,
analyze=False):
"""Read a DST json file into a list of DSTExample."""
sys_inform_dict = load_acts(acts_file)
with open(input_file, "r", encoding='utf-8') as reader:
input_data = json.load(reader)
global LABEL_MAPS
LABEL_MAPS = label_maps
examples = []
for dialog_id in input_data:
entry = input_data[dialog_id]
utterances = entry['log']
# Collects all slot changes throughout the dialog
cumulative_labels = {slot: 'none' for slot in slot_list}
# First system utterance is empty, since multiwoz starts with user input
utt_tok_list = [[]]
mod_slots_list = [{}]
# Collect all utterances and their metadata
usr_sys_switch = True
turn_itr = 0
for utt in utterances:
# Assert that system and user utterances alternate
is_sys_utt = utt['metadata'] != {}
if usr_sys_switch == is_sys_utt:
print("WARN: Wrong order of system and user utterances. Skipping rest of dialog %s" % (dialog_id))
break
usr_sys_switch = is_sys_utt
if is_sys_utt:
turn_itr += 1
# Delexicalize sys utterance
if delexicalize_sys_utts and is_sys_utt:
inform_dict = {slot: 'none' for slot in slot_list}
for slot in slot_list:
if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
inform_dict[slot] = sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]
utt_tok_list.append(delex_utt(utt['text'], inform_dict)) # normalize utterances
else:
utt_tok_list.append(tokenize(utt['text'])) # normalize utterances
modified_slots = {}
# If sys utt, extract metadata (identify and collect modified slots)
if is_sys_utt:
for d in utt['metadata']:
booked = utt['metadata'][d]['book']['booked']
booked_slots = {}
# Check the booked section
if booked != []:
for s in booked[0]:
booked_slots[s] = normalize_label('%s-%s' % (d, s), booked[0][s]) # normalize labels
# Check the semi and the inform slots
for category in ['book', 'semi']:
for s in utt['metadata'][d][category]:
cs = '%s-book_%s' % (d, s) if category == 'book' else '%s-%s' % (d, s)
value_label = normalize_label(cs, utt['metadata'][d][category][s]) # normalize labels
# Prefer the slot value as stored in the booked section
if s in booked_slots:
value_label = booked_slots[s]
# Remember modified slots and entire dialog state
if cs in slot_list and cumulative_labels[cs] != value_label:
modified_slots[cs] = value_label
cumulative_labels[cs] = value_label
mod_slots_list.append(modified_slots.copy())
# Form proper (usr, sys) turns
turn_itr = 0
diag_seen_slots_dict = {}
diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
diag_state = {slot: 'none' for slot in slot_list}
sys_utt_tok = []
usr_utt_tok = []
hst_utt_tok = []
hst_utt_tok_label_dict = {slot: [] for slot in slot_list}
for i in range(1, len(utt_tok_list) - 1, 2):
sys_utt_tok_label_dict = {}
usr_utt_tok_label_dict = {}
value_dict = {}
inform_dict = {}
inform_slot_dict = {}
referral_dict = {}
class_type_dict = {}
# Collect turn data
if append_history:
if swap_utterances:
hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
else:
hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok
sys_utt_tok = utt_tok_list[i - 1]
usr_utt_tok = utt_tok_list[i]
turn_slots = mod_slots_list[i + 1]
guid = '%s-%s-%s' % (set_type, str(dialog_id), str(turn_itr))
if analyze:
print("%15s %2s %s ||| %s" % (dialog_id, turn_itr, ' '.join(sys_utt_tok), ' '.join(usr_utt_tok)))
print("%15s %2s [" % (dialog_id, turn_itr), end='')
new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy()
new_diag_state = diag_state.copy()
for slot in slot_list:
value_label = 'none'
if slot in turn_slots:
value_label = turn_slots[slot]
# We keep the original labels so as to not
# overlook unpointable values, as well as to not
# modify any of the original labels for test sets,
# since this would make comparison difficult.
value_dict[slot] = value_label
elif label_value_repetitions and slot in diag_seen_slots_dict:
value_label = diag_seen_slots_value_dict[slot]
# Get dialog act annotations
inform_label = list(['none'])
if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict:
inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), slot)]])
elif (str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1]) in sys_inform_dict:
inform_label = list([normalize_label(slot, i) for i in sys_inform_dict[(str(dialog_id), str(turn_itr), 'booking-' + slot.split('-')[1])]])
(informed_value,
referred_slot,
usr_utt_tok_label,
class_type) = get_turn_label(value_label,
inform_label,
sys_utt_tok,
usr_utt_tok,
slot,
diag_seen_slots_value_dict,
slot_last_occurrence=True)
inform_dict[slot] = informed_value
if informed_value != 'none':
inform_slot_dict[slot] = 1
else:
inform_slot_dict[slot] = 0
# Generally don't use span prediction on sys utterance (but inform prediction instead).
sys_utt_tok_label = [0 for _ in sys_utt_tok]
# Determine what to do with value repetitions.
# If value is unique in seen slots, then tag it, otherwise not,
# since correct slot assignment cannot be guaranteed anymore.
if label_value_repetitions and slot in diag_seen_slots_dict:
if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(value_label) > 1:
class_type = 'none'
usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
sys_utt_tok_label_dict[slot] = sys_utt_tok_label
usr_utt_tok_label_dict[slot] = usr_utt_tok_label
if append_history:
if use_history_labels:
if swap_utterances:
new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot]
else:
new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]
else:
new_hst_utt_tok_label_dict[slot] = [0 for _ in sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]]
# For now, we map all occurrences of unpointable slot values
# to none. However, since the labels will still suggest
# a presence of unpointable slot values, the task of the
# DST is still to find those values. It is just not
# possible to do that via span prediction on the current input.
if class_type == 'unpointable':
class_type_dict[slot] = 'none'
referral_dict[slot] = 'none'
if analyze:
if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[slot]:
print("(%s): %s, " % (slot, value_label), end='')
elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
# If the slot has been seen before and its class type did not change, label this slot as not present,
# assuming that the slot has not actually been mentioned in this turn.
# Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform,
# this must mean there is evidence in the original labels, therefore consider
# them as mentioned again.
class_type_dict[slot] = 'none'
referral_dict[slot] = 'none'
else:
class_type_dict[slot] = class_type
referral_dict[slot] = referred_slot
# Remember that this slot was mentioned during this dialog already.
if class_type != 'none':
diag_seen_slots_dict[slot] = class_type
diag_seen_slots_value_dict[slot] = value_label
new_diag_state[slot] = class_type
# Unpointable is not a valid class, therefore replace with
# some valid class for now...
if class_type == 'unpointable':
new_diag_state[slot] = 'copy_value'
if analyze:
print("]")
if swap_utterances:
txt_a = usr_utt_tok
txt_b = sys_utt_tok
txt_a_lbl = usr_utt_tok_label_dict
txt_b_lbl = sys_utt_tok_label_dict
else:
txt_a = sys_utt_tok
txt_b = usr_utt_tok
txt_a_lbl = sys_utt_tok_label_dict
txt_b_lbl = usr_utt_tok_label_dict
examples.append(DSTExample(
guid=guid,
text_a=txt_a,
text_b=txt_b,
history=hst_utt_tok,
text_a_label=txt_a_lbl,
text_b_label=txt_b_lbl,
history_label=hst_utt_tok_label_dict,
values=diag_seen_slots_value_dict.copy(),
inform_label=inform_dict,
inform_slot_label=inform_slot_dict,
refer_label=referral_dict,
diag_state=diag_state,
class_label=class_type_dict))
# Update some variables.
hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy()
diag_state = new_diag_state.copy()
turn_itr += 1
if analyze:
print("----------------------------------------------------------------------")
return examples
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from utils_dst import (DSTExample)
def dialogue_state_to_sv_dict(sv_list):
sv_dict = {}
for d in sv_list:
sv_dict[d['slot']] = d['value']
return sv_dict
def get_token_and_slot_label(turn):
if 'system_utterance' in turn:
sys_utt_tok = turn['system_utterance']['tokens']
sys_slot_label = turn['system_utterance']['slots']
else:
sys_utt_tok = []
sys_slot_label = []
usr_utt_tok = turn['user_utterance']['tokens']
usr_slot_label = turn['user_utterance']['slots']
return sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label
def get_tok_label(prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok,
sys_slot_label, usr_utt_tok, usr_slot_label, dial_id,
turn_id, slot_last_occurrence=True):
"""The position of the last occurrence of the slot value will be used."""
sys_utt_tok_label = [0 for _ in sys_utt_tok]
usr_utt_tok_label = [0 for _ in usr_utt_tok]
if slot_type not in cur_ds_dict:
class_type = 'none'
else:
value = cur_ds_dict[slot_type]
if value == 'dontcare' and (slot_type not in prev_ds_dict or prev_ds_dict[slot_type] != 'dontcare'):
# Only label dontcare at its first occurrence in the dialog
class_type = 'dontcare'
else: # If not none or dontcare, we have to identify whether
# there is a span, or if the value is informed
in_usr = False
in_sys = False
for label_d in usr_slot_label:
if label_d['slot'] == slot_type and value == ' '.join(
usr_utt_tok[label_d['start']:label_d['exclusive_end']]):
for idx in range(label_d['start'], label_d['exclusive_end']):
usr_utt_tok_label[idx] = 1
in_usr = True
class_type = 'copy_value'
if slot_last_occurrence:
break
if not in_usr or not slot_last_occurrence:
for label_d in sys_slot_label:
if label_d['slot'] == slot_type and value == ' '.join(
sys_utt_tok[label_d['start']:label_d['exclusive_end']]):
for idx in range(label_d['start'], label_d['exclusive_end']):
sys_utt_tok_label[idx] = 1
in_sys = True
class_type = 'inform'
if slot_last_occurrence:
break
if not in_usr and not in_sys:
assert sum(usr_utt_tok_label + sys_utt_tok_label) == 0
if (slot_type not in prev_ds_dict or value != prev_ds_dict[slot_type]):
raise ValueError('Copy value cannot be found in Dial %s Turn %s' % (str(dial_id), str(turn_id)))
else:
class_type = 'none'
else:
assert sum(usr_utt_tok_label + sys_utt_tok_label) > 0
return sys_utt_tok_label, usr_utt_tok_label, class_type
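# Replaces annotated value spans (given as dicts with 'start' and 'exclusive_end') in the
# already tokenized utterance with '[UNK]' tokens.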
def delex_utt(utt, values):
utt_delex = utt.copy()
for v in values:
utt_delex[v['start']:v['exclusive_end']] = ['[UNK]'] * (v['exclusive_end'] - v['start'])
return utt_delex
def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id,
delexicalize_sys_utts=False, slot_last_occurrence=True):
"""Make turn_label a dictionary of slot with value positions or being dontcare / none:
Turn label contains:
(1) the updates from previous to current dialogue state,
(2) values in current dialogue state explicitly mentioned in system or user utterance."""
prev_ds_dict = dialogue_state_to_sv_dict(prev_dialogue_state)
cur_ds_dict = dialogue_state_to_sv_dict(turn['dialogue_state'])
(sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label) = get_token_and_slot_label(turn)
sys_utt_tok_label_dict = {}
usr_utt_tok_label_dict = {}
inform_label_dict = {}
inform_slot_label_dict = {}
referral_label_dict = {}
class_type_dict = {}
for slot_type in slot_list:
inform_label_dict[slot_type] = 'none'
inform_slot_label_dict[slot_type] = 0
referral_label_dict[slot_type] = 'none' # Referral is not present in sim data
sys_utt_tok_label, usr_utt_tok_label, class_type = get_tok_label(
prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok, sys_slot_label,
usr_utt_tok, usr_slot_label, dial_id, turn_id,
slot_last_occurrence=slot_last_occurrence)
if sum(sys_utt_tok_label) > 0:
inform_label_dict[slot_type] = cur_ds_dict[slot_type]
inform_slot_label_dict[slot_type] = 1
sys_utt_tok_label = [0 for _ in sys_utt_tok_label] # Don't use token labels for sys utt
sys_utt_tok_label_dict[slot_type] = sys_utt_tok_label
usr_utt_tok_label_dict[slot_type] = usr_utt_tok_label
class_type_dict[slot_type] = class_type
if delexicalize_sys_utts:
sys_utt_tok = delex_utt(sys_utt_tok, sys_slot_label)
return (sys_utt_tok, sys_utt_tok_label_dict,
usr_utt_tok, usr_utt_tok_label_dict,
inform_label_dict, inform_slot_label_dict,
referral_label_dict, cur_ds_dict, class_type_dict)
def create_examples(input_file, set_type, slot_list,
label_maps={},
append_history=False,
use_history_labels=False,
swap_utterances=False,
label_value_repetitions=False,
delexicalize_sys_utts=False,
analyze=False):
"""Read a DST json file into a list of DSTExample."""
with open(input_file, "r", encoding='utf-8') as reader:
input_data = json.load(reader)
examples = []
for entry in input_data:
dial_id = entry['dialogue_id']
prev_ds = []
hst = []
prev_hst_lbl_dict = {slot: [] for slot in slot_list}
prev_ds_lbl_dict = {slot: 'none' for slot in slot_list}
for turn_id, turn in enumerate(entry['turns']):
guid = '%s-%s-%s' % (set_type, dial_id, str(turn_id))
ds_lbl_dict = prev_ds_lbl_dict.copy()
hst_lbl_dict = prev_hst_lbl_dict.copy()
(text_a,
text_a_label,
text_b,
text_b_label,
inform_label,
inform_slot_label,
referral_label,
cur_ds_dict,
class_label) = get_turn_label(turn,
prev_ds,
slot_list,
dial_id,
turn_id,
delexicalize_sys_utts=delexicalize_sys_utts,
slot_last_occurrence=True)
if swap_utterances:
txt_a = text_b
txt_b = text_a
txt_a_lbl = text_b_label
txt_b_lbl = text_a_label
else:
txt_a = text_a
txt_b = text_b
txt_a_lbl = text_a_label
txt_b_lbl = text_b_label
value_dict = {}
for slot in slot_list:
if slot in cur_ds_dict:
value_dict[slot] = cur_ds_dict[slot]
else:
value_dict[slot] = 'none'
if class_label[slot] != 'none':
ds_lbl_dict[slot] = class_label[slot]
if append_history:
if use_history_labels:
hst_lbl_dict[slot] = txt_a_lbl[slot] + txt_b_lbl[slot] + hst_lbl_dict[slot]
else:
hst_lbl_dict[slot] = [0 for _ in txt_a_lbl[slot] + txt_b_lbl[slot] + hst_lbl_dict[slot]]
examples.append(DSTExample(
guid=guid,
text_a=txt_a,
text_b=txt_b,
history=hst,
text_a_label=txt_a_lbl,
text_b_label=txt_b_lbl,
history_label=prev_hst_lbl_dict,
values=value_dict,
inform_label=inform_label,
inform_slot_label=inform_slot_label,
refer_label=referral_label,
diag_state=prev_ds_lbl_dict,
class_label=class_label))
prev_ds = turn['dialogue_state']
prev_ds_lbl_dict = ds_lbl_dict.copy()
prev_hst_lbl_dict = hst_lbl_dict.copy()
if append_history:
hst = txt_a + txt_b + hst
return examples
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from utils_dst import (DSTExample, convert_to_unicode)
LABEL_MAPS = {} # Loaded from file
LABEL_FIX = {'centre': 'center', 'areas': 'area', 'phone number': 'number', 'price range': 'price_range'}
def delex_utt(utt, values):
utt_norm = utt.copy()
for s, v in values.items():
if v != 'none':
v_norm = tokenize(v)
v_len = len(v_norm)
for i in range(len(utt_norm) + 1 - v_len):
if utt_norm[i:i + v_len] == v_norm:
utt_norm[i:i + v_len] = ['[UNK]'] * v_len
return utt_norm
def get_token_pos(tok_list, label):
find_pos = []
found = False
label_list = [item for item in map(str.strip, re.split("(\W+)", label)) if len(item) > 0]
len_label = len(label_list)
for i in range(len(tok_list) + 1 - len_label):
if tok_list[i:i + len_label] == label_list:
find_pos.append((i, i + len_label)) # start, exclusive_end
found = True
return found, find_pos
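# Unlike the multiwoz21 variant, this also searches the system utterance, so values that
# only appear there can be classified as 'inform'.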
def check_label_existence(label, usr_utt_tok, sys_utt_tok):
in_usr, usr_pos = get_token_pos(usr_utt_tok, label)
if not in_usr and label in LABEL_MAPS:
for tmp_label in LABEL_MAPS[label]:
in_usr, usr_pos = get_token_pos(usr_utt_tok, tmp_label)
if in_usr:
break
in_sys, sys_pos = get_token_pos(sys_utt_tok, label)
if not in_sys and label in LABEL_MAPS:
for tmp_label in LABEL_MAPS[label]:
in_sys, sys_pos = get_token_pos(sys_utt_tok, tmp_label)
if in_sys:
break
return in_usr, usr_pos, in_sys, sys_pos
def get_turn_label(label, sys_utt_tok, usr_utt_tok, slot_last_occurrence):
usr_utt_tok_label = [0 for _ in usr_utt_tok]
if label == 'none' or label == 'dontcare':
class_type = label
else:
in_usr, usr_pos, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok)
if in_usr:
class_type = 'copy_value'
if slot_last_occurrence:
(s, e) = usr_pos[-1]
for i in range(s, e):
usr_utt_tok_label[i] = 1
else:
for (s, e) in usr_pos:
for i in range(s, e):
usr_utt_tok_label[i] = 1
elif in_sys:
class_type = 'inform'
else:
class_type = 'unpointable'
return usr_utt_tok_label, class_type
def tokenize(utt):
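# Lowercases and splits on non-word characters, keeping punctuation as separate tokens.
# Illustrative example: tokenize("I'd like Chinese food!") -> ['i', "'", 'd', 'like', 'chinese', 'food', '!']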
utt_lower = convert_to_unicode(utt).lower()
utt_tok = [tok for tok in map(str.strip, re.split(r"(\W+)", utt_lower)) if len(tok) > 0]
return utt_tok
def create_examples(input_file, set_type, slot_list,
label_maps={},
append_history=False,
use_history_labels=False,
swap_utterances=False,
label_value_repetitions=False,
delexicalize_sys_utts=False,
analyze=False):
"""Read a DST json file into a list of DSTExample."""
with open(input_file, "r", encoding='utf-8') as reader:
input_data = json.load(reader)
global LABEL_MAPS
LABEL_MAPS = label_maps
examples = []
for entry in input_data:
diag_seen_slots_dict = {}
diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
diag_state = {slot: 'none' for slot in slot_list}
sys_utt_tok = []
sys_utt_tok_delex = []
usr_utt_tok = []
hst_utt_tok = []
hst_utt_tok_label_dict = {slot: [] for slot in slot_list}
for turn in entry['dialogue']:
sys_utt_tok_label_dict = {}
usr_utt_tok_label_dict = {}
inform_dict = {slot: 'none' for slot in slot_list}
inform_slot_dict = {slot: 0 for slot in slot_list}
referral_dict = {}
class_type_dict = {}
# Collect turn data
if append_history:
if swap_utterances:
if delexicalize_sys_utts:
hst_utt_tok = usr_utt_tok + sys_utt_tok_delex + hst_utt_tok
else:
hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
else:
if delexicalize_sys_utts:
hst_utt_tok = sys_utt_tok_delex + usr_utt_tok + hst_utt_tok
else:
hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok
sys_utt_tok = tokenize(turn['system_transcript'])
usr_utt_tok = tokenize(turn['transcript'])
turn_label = {LABEL_FIX.get(s.strip(), s.strip()): LABEL_FIX.get(v.strip(), v.strip()) for s, v in turn['turn_label']}
guid = '%s-%s-%s' % (set_type, str(entry['dialogue_idx']), str(turn['turn_idx']))
# Create delexicalized sys utterances.
if delexicalize_sys_utts:
delex_dict = {}
for slot in slot_list:
delex_dict[slot] = 'none'
label = 'none'
if slot in turn_label:
label = turn_label[slot]
elif label_value_repetitions and slot in diag_seen_slots_dict:
label = diag_seen_slots_value_dict[slot]
if label != 'none' and label != 'dontcare':
_, _, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok)
if in_sys:
delex_dict[slot] = label
sys_utt_tok_delex = delex_utt(sys_utt_tok, delex_dict)
new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy()
new_diag_state = diag_state.copy()
for slot in slot_list:
label = 'none'
if slot in turn_label:
label = turn_label[slot]
elif label_value_repetitions and slot in diag_seen_slots_dict:
label = diag_seen_slots_value_dict[slot]
(usr_utt_tok_label,
class_type) = get_turn_label(label,
sys_utt_tok,
usr_utt_tok,
slot_last_occurrence=True)
if class_type == 'inform':
inform_dict[slot] = label
if label != 'none':
inform_slot_dict[slot] = 1
referral_dict[slot] = 'none' # Referral is not present in woz2 data
# Generally don't use span prediction on sys utterance (but inform prediction instead).
if delexicalize_sys_utts:
sys_utt_tok_label = [0 for _ in sys_utt_tok_delex]
else:
sys_utt_tok_label = [0 for _ in sys_utt_tok]
# Determine what to do with value repetitions.
# If the value is unique among the seen slots, tag it; otherwise do not,
# since a correct slot assignment cannot be guaranteed anymore.
if label_value_repetitions and slot in diag_seen_slots_dict:
if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(label) > 1:
class_type = 'none'
usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
sys_utt_tok_label_dict[slot] = sys_utt_tok_label
usr_utt_tok_label_dict[slot] = usr_utt_tok_label
if append_history:
if use_history_labels:
if swap_utterances:
new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot]
else:
new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]
else:
new_hst_utt_tok_label_dict[slot] = [0 for _ in sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]]
# For now, we map all occurrences of unpointable slot values
# to none. However, since the labels will still suggest
# a presence of unpointable slot values, the task of the
# DST is still to find those values. It is just not
# possible to do that via span prediction on the current input.
if class_type == 'unpointable':
class_type_dict[slot] = 'none'
elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
# If the slot has been seen before and its class type did not change, label this slot as not present,
# assuming that the slot has not actually been mentioned in this turn.
# Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform,
# this must mean there is evidence in the original labels, therefore consider
# them as mentioned again.
class_type_dict[slot] = 'none'
referral_dict[slot] = 'none'
else:
class_type_dict[slot] = class_type
# Remember that this slot was mentioned during this dialog already.
if class_type != 'none':
diag_seen_slots_dict[slot] = class_type
diag_seen_slots_value_dict[slot] = label
new_diag_state[slot] = class_type
# Unpointable is not a valid class, therefore replace with
# some valid class for now...
if class_type == 'unpointable':
new_diag_state[slot] = 'copy_value'
if swap_utterances:
txt_a = usr_utt_tok
if delexicalize_sys_utts:
txt_b = sys_utt_tok_delex
else:
txt_b = sys_utt_tok
txt_a_lbl = usr_utt_tok_label_dict
txt_b_lbl = sys_utt_tok_label_dict
else:
if delexicalize_sys_utts:
txt_a = sys_utt_tok_delex
else:
txt_a = sys_utt_tok
txt_b = usr_utt_tok
txt_a_lbl = sys_utt_tok_label_dict
txt_b_lbl = usr_utt_tok_label_dict
examples.append(DSTExample(
guid=guid,
text_a=txt_a,
text_b=txt_b,
history=hst_utt_tok,
text_a_label=txt_a_lbl,
text_b_label=txt_b_lbl,
history_label=hst_utt_tok_label_dict,
values=diag_seen_slots_value_dict.copy(),
inform_label=inform_dict,
inform_slot_label=inform_slot_dict,
refer_label=referral_dict,
diag_state=diag_state,
class_label=class_type_dict))
# Update some variables.
hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy()
diag_state = new_diag_state.copy()
return examples
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import json
import sys
import numpy as np
import re
def load_dataset_config(dataset_config):
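# The dataset config is a JSON file with the three keys read below. A minimal
# illustrative sketch (slot names and values are examples only):
#   {"class_types": ["none", "dontcare", "copy_value", "inform"],
#    "slots": ["area", "food", "price_range"],
#    "label_maps": {"center": ["centre"]}}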
with open(dataset_config, "r", encoding='utf-8') as f:
raw_config = json.load(f)
return raw_config['class_types'], raw_config['slots'], raw_config['label_maps']
def tokenize(text):
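# Normalizes a predicted value string: if it contains '\u0120' (the byte-level
# BPE space marker 'Ġ'), literal spaces are dropped and the markers become the
# word boundaries; the result is then re-tokenized on non-word characters.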
if "\u0120" in text:
text = re.sub(" ", "", text)
text = re.sub("\u0120", " ", text)
text = text.strip()
return ' '.join([tok for tok in map(str.strip, re.split(r"(\W+)", text)) if len(tok) > 0])
def is_in_list(tok, value):
found = False
tok_list = [item for item in map(str.strip, re.split(r"(\W+)", tok)) if len(item) > 0]
value_list = [item for item in map(str.strip, re.split(r"(\W+)", value)) if len(item) > 0]
tok_len = len(tok_list)
value_len = len(value_list)
for i in range(tok_len + 1 - value_len):
if tok_list[i:i + value_len] == value_list:
found = True
break
return found
def check_slot_inform(value_label, inform_label, label_maps):
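# Decides whether an informed value counts as matching the ground truth value.
# If the inform value equals the ground truth, contains it, is contained in it,
# or matches via a label_maps variant, the ground truth value is returned
# (i.e., the match is accepted); otherwise the inform value is returned unchanged.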
value = inform_label
if value_label == inform_label:
value = value_label
elif is_in_list(inform_label, value_label):
value = value_label
elif is_in_list(value_label, inform_label):
value = value_label
elif inform_label in label_maps:
for inform_label_variant in label_maps[inform_label]:
if value_label == inform_label_variant:
value = value_label
break
elif is_in_list(inform_label_variant, value_label):
value = value_label
break
elif is_in_list(value_label, inform_label_variant):
value = value_label
break
elif value_label in label_maps:
for value_label_variant in label_maps[value_label]:
if value_label_variant == inform_label:
value = value_label
break
elif is_in_list(inform_label, value_label_variant):
value = value_label
break
elif is_in_list(value_label_variant, inform_label):
value = value_label
break
return value
def get_joint_slot_correctness(fp, class_types, label_maps,
key_class_label_id='class_label_id',
key_class_prediction='class_prediction',
key_start_pos='start_pos',
key_start_prediction='start_prediction',
key_end_pos='end_pos',
key_end_prediction='end_prediction',
key_refer_id='refer_id',
key_refer_prediction='refer_prediction',
key_slot_groundtruth='slot_groundtruth',
key_slot_prediction='slot_prediction'):
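# Reads the per-turn predictions for one slot from a prediction file and scores
# them. The predicted value is carried forward across the turns of a dialog
# (reset at turn 0) so that, besides per-turn class/span/referral accuracy, the
# joint (dialog-state level) value correctness can be computed per turn.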
with open(fp) as f:
preds = json.load(f)
class_correctness = [[] for cl in range(len(class_types) + 1)]
confusion_matrix = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))]
pos_correctness = []
refer_correctness = []
val_correctness = []
total_correctness = []
c_tp = {ct: 0 for ct in range(len(class_types))}
c_tn = {ct: 0 for ct in range(len(class_types))}
c_fp = {ct: 0 for ct in range(len(class_types))}
c_fn = {ct: 0 for ct in range(len(class_types))}
for pred in preds:
guid = pred['guid'] # List: set_type, dialogue_idx, turn_idx
turn_gt_class = pred[key_class_label_id]
turn_pd_class = pred[key_class_prediction]
gt_start_pos = pred[key_start_pos]
pd_start_pos = pred[key_start_prediction]
gt_end_pos = pred[key_end_pos]
pd_end_pos = pred[key_end_prediction]
gt_refer = pred[key_refer_id]
pd_refer = pred[key_refer_prediction]
gt_slot = pred[key_slot_groundtruth]
pd_slot = pred[key_slot_prediction]
gt_slot = tokenize(gt_slot)
pd_slot = tokenize(pd_slot)
# Make sure the true turn labels are contained in the prediction json file!
joint_gt_slot = gt_slot
if guid[-1] == '0': # First turn, reset the slots
joint_pd_slot = 'none'
# If turn_pd_class or a value to be copied is "none", do not update the dialog state.
if turn_pd_class == class_types.index('none'):
pass
elif turn_pd_class == class_types.index('dontcare'):
joint_pd_slot = 'dontcare'
elif turn_pd_class == class_types.index('copy_value'):
joint_pd_slot = pd_slot
elif 'true' in class_types and turn_pd_class == class_types.index('true'):
joint_pd_slot = 'true'
elif 'false' in class_types and turn_pd_class == class_types.index('false'):
joint_pd_slot = 'false'
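# Values filled via the inform (or refer) mechanism are prefixed with '§§' by
# predict_and_format(); strip that prefix before comparing against the ground truth.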
elif 'refer' in class_types and turn_pd_class == class_types.index('refer'):
if pd_slot[0:3] == "§§ ":
if pd_slot[3:] != 'none':
joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[3:], label_maps)
elif pd_slot[0:2] == "§§":
if pd_slot[2:] != 'none':
joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:], label_maps)
elif pd_slot != 'none':
joint_pd_slot = pd_slot
elif 'inform' in class_types and turn_pd_class == class_types.index('inform'):
if pd_slot[0:3] == "§§ ":
if pd_slot[3:] != 'none':
joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[3:], label_maps)
elif pd_slot[0:2] == "§§":
if pd_slot[2:] != 'none':
joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:], label_maps)
else:
print("ERROR: Unexpected slot value format. Aborting.")
sys.exit(1)
else:
print("ERROR: Unexpected class_type. Aborting.")
sys.exit(1)
total_correct = True
# Check the per turn correctness of the class_type prediction
if turn_gt_class == turn_pd_class:
class_correctness[turn_gt_class].append(1.0)
class_correctness[-1].append(1.0)
c_tp[turn_gt_class] += 1
for cc in range(len(class_types)):
if cc != turn_gt_class:
c_tn[cc] += 1
# Only when there is a span do we check its per-turn correctness
if turn_gt_class == class_types.index('copy_value'):
if gt_start_pos == pd_start_pos and gt_end_pos == pd_end_pos:
pos_correctness.append(1.0)
else:
pos_correctness.append(0.0)
# Only when there is a referral do we check its per-turn correctness
if 'refer' in class_types and turn_gt_class == class_types.index('refer'):
if gt_refer == pd_refer:
refer_correctness.append(1.0)
print(" [%s] Correct referral: %s | %s" % (guid, gt_refer, pd_refer))
else:
refer_correctness.append(0.0)
print(" [%s] Incorrect referral: %s | %s" % (guid, gt_refer, pd_refer))
else:
if turn_gt_class == class_types.index('copy_value'):
pos_correctness.append(0.0)
if 'refer' in class_types and turn_gt_class == class_types.index('refer'):
refer_correctness.append(0.0)
class_correctness[turn_gt_class].append(0.0)
class_correctness[-1].append(0.0)
confusion_matrix[turn_gt_class][turn_pd_class].append(1.0)
c_fn[turn_gt_class] += 1
c_fp[turn_pd_class] += 1
# Check the joint slot correctness.
# If the value label is not none, then we need to have a value prediction.
# Even if the class_type is 'none', there can still be a value label,
# it might just not be pointable in the current turn. It might however
# be referrable and thus predicted correctly.
if joint_gt_slot == joint_pd_slot:
val_correctness.append(1.0)
elif joint_gt_slot != 'none' and joint_gt_slot != 'dontcare' and joint_gt_slot != 'true' and joint_gt_slot != 'false' and joint_gt_slot in label_maps:
no_match = True
for variant in label_maps[joint_gt_slot]:
if variant == joint_pd_slot:
no_match = False
break
if no_match:
val_correctness.append(0.0)
total_correct = False
print(" [%s] Incorrect value (variant): %s (turn class: %s) | %s (turn class: %s)" % (guid, joint_gt_slot, turn_gt_class, joint_pd_slot, turn_pd_class))
else:
val_correctness.append(1.0)
else:
val_correctness.append(0.0)
total_correct = False
print(" [%s] Incorrect value: %s (turn class: %s) | %s (turn class: %s)" % (guid, joint_gt_slot, turn_gt_class, joint_pd_slot, turn_pd_class))
total_correctness.append(1.0 if total_correct else 0.0)
# Account for empty lists (due to no instances of spans or referrals being seen)
if pos_correctness == []:
pos_correctness.append(1.0)
if refer_correctness == []:
refer_correctness.append(1.0)
for ct in range(len(class_types)):
if c_tp[ct] + c_fp[ct] > 0:
precision = c_tp[ct] / (c_tp[ct] + c_fp[ct])
else:
precision = 1.0
if c_tp[ct] + c_fn[ct] > 0:
recall = c_tp[ct] / (c_tp[ct] + c_fn[ct])
else:
recall = 1.0
if precision + recall > 0:
f1 = 2 * ((precision * recall) / (precision + recall))
else:
f1 = 1.0
if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0:
acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct])
else:
acc = 1.0
print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" %
(class_types[ct], ct, recall, np.sum(class_correctness[ct]), len(class_correctness[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct]))
print("Confusion matrix:")
for cl in range(len(class_types)):
print(" %s" % (cl), end="")
print("")
for cl_a in range(len(class_types)):
print("%s " % (cl_a), end="")
for cl_b in range(len(class_types)):
if len(class_correctness[cl_a]) > 0:
print("%.2f " % (np.sum(confusion_matrix[cl_a][cl_b]) / len(class_correctness[cl_a])), end="")
else:
print("---- ", end="")
print("")
return np.asarray(total_correctness), np.asarray(val_correctness), np.asarray(class_correctness), np.asarray(pos_correctness), np.asarray(refer_correctness), np.asarray(confusion_matrix), c_tp, c_tn, c_fp, c_fn
if __name__ == "__main__":
acc_list = []
acc_list_v = []
key_class_label_id = 'class_label_id_%s'
key_class_prediction = 'class_prediction_%s'
key_start_pos = 'start_pos_%s'
key_start_prediction = 'start_prediction_%s'
key_end_pos = 'end_pos_%s'
key_end_prediction = 'end_prediction_%s'
key_refer_id = 'refer_id_%s'
key_refer_prediction = 'refer_prediction_%s'
key_slot_groundtruth = 'slot_groundtruth_%s'
key_slot_prediction = 'slot_prediction_%s'
dataset = sys.argv[1].lower()
dataset_config = sys.argv[2].lower()
if dataset not in ['woz2', 'sim-m', 'sim-r', 'multiwoz21']:
raise ValueError("Task not found: %s" % (dataset))
class_types, slots, label_maps = load_dataset_config(dataset_config)
# Prepare label_maps
label_maps_tmp = {}
for v in label_maps:
label_maps_tmp[tokenize(v)] = [tokenize(nv) for nv in label_maps[v]]
label_maps = label_maps_tmp
for fp in sorted(glob.glob(sys.argv[3])):
print(fp)
goal_correctness = 1.0
cls_acc = [[] for cl in range(len(class_types))]
cls_conf = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))]
c_tp = {ct: 0 for ct in range(len(class_types))}
c_tn = {ct: 0 for ct in range(len(class_types))}
c_fp = {ct: 0 for ct in range(len(class_types))}
c_fn = {ct: 0 for ct in range(len(class_types))}
for slot in slots:
tot_cor, joint_val_cor, cls_cor, pos_cor, ref_cor, conf_mat, ctp, ctn, cfp, cfn = get_joint_slot_correctness(fp, class_types, label_maps,
key_class_label_id=(key_class_label_id % slot),
key_class_prediction=(key_class_prediction % slot),
key_start_pos=(key_start_pos % slot),
key_start_prediction=(key_start_prediction % slot),
key_end_pos=(key_end_pos % slot),
key_end_prediction=(key_end_prediction % slot),
key_refer_id=(key_refer_id % slot),
key_refer_prediction=(key_refer_prediction % slot),
key_slot_groundtruth=(key_slot_groundtruth % slot),
key_slot_prediction=(key_slot_prediction % slot)
)
print('%s: joint slot acc: %g, joint value acc: %g, turn class acc: %g, turn position acc: %g, turn referral acc: %g' %
(slot, np.mean(tot_cor), np.mean(joint_val_cor), np.mean(cls_cor[-1]), np.mean(pos_cor), np.mean(ref_cor)))
goal_correctness *= tot_cor
for cl_a in range(len(class_types)):
cls_acc[cl_a] += cls_cor[cl_a]
for cl_b in range(len(class_types)):
cls_conf[cl_a][cl_b] += list(conf_mat[cl_a][cl_b])
c_tp[cl_a] += ctp[cl_a]
c_tn[cl_a] += ctn[cl_a]
c_fp[cl_a] += cfp[cl_a]
c_fn[cl_a] += cfn[cl_a]
for ct in range(len(class_types)):
if c_tp[ct] + c_fp[ct] > 0:
precision = c_tp[ct] / (c_tp[ct] + c_fp[ct])
else:
precision = 1.0
if c_tp[ct] + c_fn[ct] > 0:
recall = c_tp[ct] / (c_tp[ct] + c_fn[ct])
else:
recall = 1.0
if precision + recall > 0:
f1 = 2 * ((precision * recall) / (precision + recall))
else:
f1 = 1.0
if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0:
acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct])
else:
acc = 1.0
print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" %
(class_types[ct], ct, recall, np.sum(cls_acc[ct]), len(cls_acc[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct]))
print("Confusion matrix:")
for cl in range(len(class_types)):
print(" %s" % (cl), end="")
print("")
for cl_a in range(len(class_types)):
print("%s " % (cl_a), end="")
for cl_b in range(len(class_types)):
if len(cls_acc[cl_a]) > 0:
print("%.2f " % (np.sum(cls_conf[cl_a][cl_b]) / len(cls_acc[cl_a])), end="")
else:
print("---- ", end="")
print("")
acc = np.mean(goal_correctness)
acc_list.append((fp, acc))
acc_list_s = sorted(acc_list, key=lambda tup: tup[1], reverse=True)
for (fp, acc) in acc_list_s:
print('Joint goal acc: %g, %s' % (acc, fp))
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
# Part of this code is based on the source code of Transformers
# (arXiv:1910.03771)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable)
from transformers.modeling_bert import (BertModel, BertPreTrainedModel, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
@add_start_docstrings(
"""BERT Model with a classification heads for the DST task. """,
BERT_START_DOCSTRING,
)
class BertForDST(BertPreTrainedModel):
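# BERT encoder with three heads per slot: a class (slot-gate) head over the class
# types, a token head predicting span start/end for 'copy_value', and a refer head
# pointing to the slot a value is referred from. Optional auxiliary inputs (informed
# slots, current dialog state) are projected and concatenated to the pooled [CLS]
# representation for the class and refer heads.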
def __init__(self, config):
super(BertForDST, self).__init__(config)
self.slot_list = config.dst_slot_list
self.class_types = config.dst_class_types
self.class_labels = config.dst_class_labels
self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
self.class_aux_feats_inform = config.dst_class_aux_feats_inform
self.class_aux_feats_ds = config.dst_class_aux_feats_ds
self.class_loss_ratio = config.dst_class_loss_ratio
# Only use refer loss if refer class is present in dataset.
if 'refer' in self.class_types:
self.refer_index = self.class_types.index('refer')
else:
self.refer_index = -1
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.dst_dropout_rate)
self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
if self.class_aux_feats_inform:
self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
if self.class_aux_feats_ds:
self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
for slot in self.slot_list:
self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
self.init_weights()
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def forward(self,
input_ids,
input_mask=None,
segment_ids=None,
position_ids=None,
head_mask=None,
start_pos=None,
end_pos=None,
inform_slot_id=None,
refer_id=None,
class_label_id=None,
diag_state=None):
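# start_pos, end_pos, inform_slot_id, refer_id, class_label_id and diag_state are
# dicts keyed by slot name, each holding a tensor over the batch (see run_dst.py).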
outputs = self.bert(
input_ids,
attention_mask=input_mask,
token_type_ids=segment_ids,
position_ids=position_ids,
head_mask=head_mask
)
sequence_output = outputs[0]
pooled_output = outputs[1]
sequence_output = self.dropout(sequence_output)
pooled_output = self.dropout(pooled_output)
# TODO: establish proper format in labels already?
if inform_slot_id is not None:
inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
if diag_state is not None:
diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
total_loss = 0
per_slot_per_example_loss = {}
per_slot_class_logits = {}
per_slot_start_logits = {}
per_slot_end_logits = {}
per_slot_refer_logits = {}
for slot in self.slot_list:
if self.class_aux_feats_inform and self.class_aux_feats_ds:
pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
elif self.class_aux_feats_inform:
pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
elif self.class_aux_feats_ds:
pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
else:
pooled_output_aux = pooled_output
class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
start_logits, end_logits = token_logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
per_slot_class_logits[slot] = class_logits
per_slot_start_logits[slot] = start_logits
per_slot_end_logits[slot] = end_logits
per_slot_refer_logits[slot] = refer_logits
# If there are no labels, don't compute loss
if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
# If we are on multi-GPU, the label tensors may carry an extra dimension; squeeze it
if len(start_pos[slot].size()) > 1:
start_pos[slot] = start_pos[slot].squeeze(-1)
if len(end_pos[slot].size()) > 1:
end_pos[slot] = end_pos[slot].squeeze(-1)
# Sometimes the start/end positions are outside our model inputs; we ignore these terms
ignored_index = start_logits.size(1) # This is a single index
start_pos[slot].clamp_(0, ignored_index)
end_pos[slot].clamp_(0, ignored_index)
class_loss_fct = CrossEntropyLoss(reduction='none')
token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
refer_loss_fct = CrossEntropyLoss(reduction='none')
start_loss = token_loss_fct(start_logits, start_pos[slot])
end_loss = token_loss_fct(end_logits, end_pos[slot])
token_loss = (start_loss + end_loss) / 2.0
token_is_pointable = (start_pos[slot] > 0).float()
if not self.token_loss_for_nonpointable:
token_loss *= token_is_pointable
refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
if not self.refer_loss_for_nonpointable:
refer_loss *= token_is_referrable
class_loss = class_loss_fct(class_logits, class_label_id[slot])
if self.refer_index > -1:
per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
else:
per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
total_loss += per_example_loss.sum()
per_slot_per_example_loss[slot] = per_example_loss
# add hidden states and attention if they are here
outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:]
return outputs
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
# Part of this code is based on the source code of Transformers
# (arXiv:1910.03771)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import random
import glob
import json
import math
import re
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from tensorboardX import SummaryWriter
from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer)
from transformers import (AdamW, get_linear_schedule_with_warmup)
from modeling_bert_dst import (BertForDST)
from data_processors import PROCESSORS
from utils_dst import (convert_examples_to_features)
from tensorlistdataset import (TensorListDataset)
logger = logging.getLogger(__name__)
ALL_MODELS = tuple(BertConfig.pretrained_config_archive_map.keys())
MODEL_CLASSES = {
'bert': (BertConfig, BertForDST, BertTokenizer),
}
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def to_list(tensor):
return tensor.detach().cpu().tolist()
def batch_to_device(batch, device):
batch_on_device = []
for element in batch:
if isinstance(element, dict):
batch_on_device.append({k: v.to(device) for k, v in element.items()})
else:
batch_on_device.append(element.to(device))
return tuple(batch_on_device)
def train(args, train_dataset, features, model, tokenizer, processor, continue_from_global_step=0):
""" Train the model """
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter()
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
if args.max_steps > 0:
t_total = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
else:
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
if args.save_epochs > 0:
args.save_steps = t_total // args.num_train_epochs * args.save_epochs
num_warmup_steps = int(t_total * args.warmup_proportion)
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
model_single_gpu = model
if args.n_gpu > 1:
model = torch.nn.DataParallel(model_single_gpu)
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
logger.info(" Warmup steps = %d", num_warmup_steps)
if continue_from_global_step > 0:
logger.info("Fast forwarding to global step %d to resume training from latest checkpoint...", continue_from_global_step)
global_step = 0
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
# If training is continued from a checkpoint, fast forward
# to the state of that checkpoint.
if global_step < continue_from_global_step:
if (step + 1) % args.gradient_accumulation_steps == 0:
scheduler.step() # Update learning rate schedule
global_step += 1
continue
model.train()
batch = batch_to_device(batch, args.device)
# This is what is forwarded to the "forward" def.
inputs = {'input_ids': batch[0],
'input_mask': batch[1],
'segment_ids': batch[2],
'start_pos': batch[3],
'end_pos': batch[4],
'inform_slot_id': batch[5],
'refer_id': batch[6],
'diag_state': batch[7],
'class_label_id': batch[8]}
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
optimizer.step()
scheduler.step() # Update learning rate schedule
model.zero_grad()
global_step += 1
# Log metrics
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
logging_loss = tr_loss
# Save model checkpoint
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(args, model_single_gpu, tokenizer, processor, prefix=global_step)
for key, value in results.items():
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
if args.max_steps > 0 and global_step > args.max_steps:
train_iterator.close()
break
if args.local_rank in [-1, 0]:
tb_writer.close()
return global_step, tr_loss / global_step
def evaluate(args, model, tokenizer, processor, prefix=""):
dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=True)
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
args.eval_batch_size = args.per_gpu_eval_batch_size
eval_sampler = SequentialSampler(dataset) # Note that DistributedSampler samples randomly
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
all_results = []
all_preds = []
ds = {slot: 'none' for slot in model.slot_list}
with torch.no_grad():
diag_state = {slot: torch.tensor([0 for _ in range(args.eval_batch_size)]).to(args.device) for slot in model.slot_list}
for batch in tqdm(eval_dataloader, desc="Evaluating"):
model.eval()
batch = batch_to_device(batch, args.device)
# Reset dialog state if turn is first in the dialog.
turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]]
reset_diag_state = np.where(np.array(turn_itrs) == '0')[0]
for slot in model.slot_list:
for i in reset_diag_state:
diag_state[slot][i] = 0
with torch.no_grad():
inputs = {'input_ids': batch[0],
'input_mask': batch[1],
'segment_ids': batch[2],
'start_pos': batch[3],
'end_pos': batch[4],
'inform_slot_id': batch[5],
'refer_id': batch[6],
'diag_state': diag_state,
'class_label_id': batch[8]}
unique_ids = [features[i.item()].guid for i in batch[9]]
values = [features[i.item()].values for i in batch[9]]
input_ids_unmasked = [features[i.item()].input_ids_unmasked for i in batch[9]]
inform = [features[i.item()].inform for i in batch[9]]
outputs = model(**inputs)
# Update dialog state for next turn.
for slot in model.slot_list:
updates = outputs[2][slot].max(1)[1]
for i, u in enumerate(updates):
if u != 0:
diag_state[slot][i] = u
results = eval_metric(model, inputs, outputs[0], outputs[1], outputs[2], outputs[3], outputs[4], outputs[5])
preds, ds = predict_and_format(model, tokenizer, inputs, outputs[2], outputs[3], outputs[4], outputs[5], unique_ids, input_ids_unmasked, values, inform, prefix, ds)
all_results.append(results)
all_preds.append(preds)
all_preds = [item for sublist in all_preds for item in sublist] # Flatten list
# Generate final results
final_results = {}
for k in all_results[0].keys():
final_results[k] = torch.stack([r[k] for r in all_results]).mean()
# Write final predictions (for evaluation with external tool)
output_prediction_file = os.path.join(args.output_dir, "pred_res.%s.%s.json" % (args.predict_type, prefix))
with open(output_prediction_file, "w") as f:
json.dump(all_preds, f, indent=2)
return final_results
def eval_metric(model, features, total_loss, per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits):
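# Computes per-slot accuracies for this batch: class type, span (only counted on
# turns whose gold class is 'copy_value'), referral (only on 'refer' turns), a
# combined per-slot accuracy, and the joint goal accuracy over all slots.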
metric_dict = {}
per_slot_correctness = {}
for slot in model.slot_list:
per_example_loss = per_slot_per_example_loss[slot]
class_logits = per_slot_class_logits[slot]
start_logits = per_slot_start_logits[slot]
end_logits = per_slot_end_logits[slot]
refer_logits = per_slot_refer_logits[slot]
class_label_id = features['class_label_id'][slot]
start_pos = features['start_pos'][slot]
end_pos = features['end_pos'][slot]
refer_id = features['refer_id'][slot]
_, class_prediction = class_logits.max(1)
class_correctness = torch.eq(class_prediction, class_label_id).float()
class_accuracy = class_correctness.mean()
# "is pointable" means whether class label is "copy_value",
# i.e., that there is a span to be detected.
token_is_pointable = torch.eq(class_label_id, model.class_types.index('copy_value')).float()
_, start_prediction = start_logits.max(1)
start_correctness = torch.eq(start_prediction, start_pos).float()
_, end_prediction = end_logits.max(1)
end_correctness = torch.eq(end_prediction, end_pos).float()
token_correctness = start_correctness * end_correctness
token_accuracy = (token_correctness * token_is_pointable).sum() / token_is_pointable.sum()
# NaNs mean that none of the examples in this batch contain spans. -> division by 0
# The accuracy therefore is 1 by default. -> replace NaNs
if math.isnan(token_accuracy):
token_accuracy = torch.tensor(1.0, device=token_accuracy.device)
token_is_referrable = torch.eq(class_label_id, model.class_types.index('refer') if 'refer' in model.class_types else -1).float()
_, refer_prediction = refer_logits.max(1)
refer_correctness = torch.eq(refer_prediction, refer_id).float()
refer_accuracy = refer_correctness.sum() / token_is_referrable.sum()
# NaNs mean that none of the examples in this batch contain referrals. -> division by 0
# The accuracy therefore is 1 by default. -> replace NaNs
if math.isnan(refer_accuracy) or math.isinf(refer_accuracy):
refer_accuracy = torch.tensor(1.0, device=refer_accuracy.device)
total_correctness = class_correctness * (token_is_pointable * token_correctness + (1 - token_is_pointable)) * (token_is_referrable * refer_correctness + (1 - token_is_referrable))
total_accuracy = total_correctness.mean()
loss = per_example_loss.mean()
metric_dict['eval_accuracy_class_%s' % slot] = class_accuracy
metric_dict['eval_accuracy_token_%s' % slot] = token_accuracy
metric_dict['eval_accuracy_refer_%s' % slot] = refer_accuracy
metric_dict['eval_accuracy_%s' % slot] = total_accuracy
metric_dict['eval_loss_%s' % slot] = loss
per_slot_correctness[slot] = total_correctness
goal_correctness = torch.stack([c for c in per_slot_correctness.values()], 1).prod(1)
goal_accuracy = goal_correctness.mean()
metric_dict['eval_accuracy_goal'] = goal_accuracy
metric_dict['loss'] = total_loss
return metric_dict
def predict_and_format(model, tokenizer, features, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits, ids, input_ids_unmasked, values, inform, prefix, ds):
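# Turns the per-slot logits into discrete predictions, updates the running dialog
# state turn by turn (reset whenever a new dialog starts), and collects one
# JSON-serializable record per example for later scoring with metric_bert_dst.py.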
prediction_list = []
dialog_state = ds
for i in range(len(ids)):
if int(ids[i].split("-")[2]) == 0:
dialog_state = {slot: 'none' for slot in model.slot_list}
prediction = {}
prediction_addendum = {}
for slot in model.slot_list:
class_logits = per_slot_class_logits[slot][i]
start_logits = per_slot_start_logits[slot][i]
end_logits = per_slot_end_logits[slot][i]
refer_logits = per_slot_refer_logits[slot][i]
input_ids = features['input_ids'][i].tolist()
class_label_id = int(features['class_label_id'][slot][i])
start_pos = int(features['start_pos'][slot][i])
end_pos = int(features['end_pos'][slot][i])
refer_id = int(features['refer_id'][slot][i])
class_prediction = int(class_logits.argmax())
start_prediction = int(start_logits.argmax())
end_prediction = int(end_logits.argmax())
refer_prediction = int(refer_logits.argmax())
prediction['guid'] = ids[i].split("-")
prediction['class_prediction_%s' % slot] = class_prediction
prediction['class_label_id_%s' % slot] = class_label_id
prediction['start_prediction_%s' % slot] = start_prediction
prediction['start_pos_%s' % slot] = start_pos
prediction['end_prediction_%s' % slot] = end_prediction
prediction['end_pos_%s' % slot] = end_pos
prediction['refer_prediction_%s' % slot] = refer_prediction
prediction['refer_id_%s' % slot] = refer_id
prediction['input_ids_%s' % slot] = input_ids
if class_prediction == model.class_types.index('dontcare'):
dialog_state[slot] = 'dontcare'
elif class_prediction == model.class_types.index('copy_value'):
input_tokens = tokenizer.convert_ids_to_tokens(input_ids_unmasked[i])
dialog_state[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1])
dialog_state[slot] = re.sub("(^| )##", "", dialog_state[slot])
elif 'true' in model.class_types and class_prediction == model.class_types.index('true'):
dialog_state[slot] = 'true'
elif 'false' in model.class_types and class_prediction == model.class_types.index('false'):
dialog_state[slot] = 'false'
elif class_prediction == model.class_types.index('inform'):
dialog_state[slot] = '§§' + inform[i][slot]
# Referral case is handled below
prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot]
prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot]
# Referral case. All other slot values need to be seen first in order
# to be able to do this correctly.
for slot in model.slot_list:
class_logits = per_slot_class_logits[slot][i]
refer_logits = per_slot_refer_logits[slot][i]
class_prediction = int(class_logits.argmax())
refer_prediction = int(refer_logits.argmax())
if 'refer' in model.class_types and class_prediction == model.class_types.index('refer'):
# Only slots that have been mentioned before can be referred to.
# One can think of a situation where one slot is referred to in the same utterance.
# This phenomenon is however currently not properly covered in the training data
# label generation process.
dialog_state[slot] = dialog_state[model.slot_list[refer_prediction - 1]]
prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot] # Value update
prediction.update(prediction_addendum)
prediction_list.append(prediction)
return prediction_list, dialog_state
def load_and_cache_examples(args, model, tokenizer, processor, evaluate=False):
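# Builds (or loads from cache) the input features and wraps them in a
# TensorListDataset; the span, inform, refer, dialog-state and class labels are
# stored as per-slot dicts of tensors, matching the dicts expected by BertForDST.forward.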
if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training processes the dataset, and the others will use the cache
# Load data features from cache or dataset file
cached_file = os.path.join(os.path.dirname(args.output_dir), 'cached_{}_features'.format(
args.predict_type if evaluate else 'train'))
if os.path.exists(cached_file) and not args.overwrite_cache: # and not output_examples:
logger.info("Loading features from cached file %s", cached_file)
features = torch.load(cached_file)
else:
logger.info("Creating features from dataset file at %s", args.data_dir)
processor_args = {'append_history': args.append_history,
'use_history_labels': args.use_history_labels,
'swap_utterances': args.swap_utterances,
'label_value_repetitions': args.label_value_repetitions,
'delexicalize_sys_utts': args.delexicalize_sys_utts}
if evaluate and args.predict_type == "dev":
examples = processor.get_dev_examples(args.data_dir, processor_args)
elif evaluate and args.predict_type == "test":
examples = processor.get_test_examples(args.data_dir, processor_args)
else:
examples = processor.get_train_examples(args.data_dir, processor_args)
features = convert_examples_to_features(examples=examples,
slot_list=model.slot_list,
class_types=model.class_types,
model_type=args.model_type,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
slot_value_dropout=(0.0 if evaluate else args.svd))
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_file)
torch.save(features, cached_file)
if args.local_rank == 0 and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training processes the dataset, and the others will use the cache
# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
f_start_pos = [f.start_pos for f in features]
f_end_pos = [f.end_pos for f in features]
f_inform_slot_ids = [f.inform_slot for f in features]
f_refer_ids = [f.refer_id for f in features]
f_diag_state = [f.diag_state for f in features]
f_class_label_ids = [f.class_label_id for f in features]
all_start_positions = {}
all_end_positions = {}
all_inform_slot_ids = {}
all_refer_ids = {}
all_diag_state = {}
all_class_label_ids = {}
for s in model.slot_list:
all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos], dtype=torch.long)
all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos], dtype=torch.long)
all_inform_slot_ids[s] = torch.tensor([f[s] for f in f_inform_slot_ids], dtype=torch.long)
all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids], dtype=torch.long)
all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state], dtype=torch.long)
all_class_label_ids[s] = torch.tensor([f[s] for f in f_class_label_ids], dtype=torch.long)
dataset = TensorListDataset(all_input_ids, all_input_mask, all_segment_ids,
all_start_positions, all_end_positions,
all_inform_slot_ids,
all_refer_ids,
all_diag_state,
all_class_label_ids, all_example_index)
return dataset, features
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--task_name", default=None, type=str, required=True,
help="Name of the task (e.g., multiwoz21).")
parser.add_argument("--data_dir", default=None, type=str, required=True,
help="Task database.")
parser.add_argument("--dataset_config", default=None, type=str, required=True,
help="Dataset configuration file.")
parser.add_argument("--predict_type", default=None, type=str, required=True,
help="Portion of the data to perform prediction on (e.g., dev, test).")
parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model checkpoints and predictions will be written.")
# Other parameters
parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name")
parser.add_argument("--tokenizer_name", default="", type=str,
help="Pretrained tokenizer name or path if not the same as model_name")
parser.add_argument("--max_seq_length", default=384, type=int,
help="Maximum input length after tokenization. Longer sequences will be truncated, shorter ones padded.")
parser.add_argument("--do_train", action='store_true',
help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true',
help="Whether to run eval on the <predict_type> set.")
parser.add_argument("--evaluate_during_training", action='store_true',
help="Rul evaluation during training at each logging step.")
parser.add_argument("--do_lower_case", action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--dropout_rate", default=0.3, type=float,
help="Dropout rate for BERT representations.")
parser.add_argument("--heads_dropout", default=0.0, type=float,
help="Dropout rate for classification heads.")
parser.add_argument("--class_loss_ratio", default=0.8, type=float,
help="The ratio applied on class loss in total loss calculation. "
"Should be a value in [0.0, 1.0]. "
"The ratio applied on token loss is (1-class_loss_ratio)/2. "
"The ratio applied on refer loss is (1-class_loss_ratio)/2.")
parser.add_argument("--token_loss_for_nonpointable", action='store_true',
help="Whether the token loss for classes other than copy_value contribute towards total loss.")
parser.add_argument("--refer_loss_for_nonpointable", action='store_true',
help="Whether the refer loss for classes other than refer contribute towards total loss.")
parser.add_argument("--append_history", action='store_true',
help="Whether or not to append the dialog history to each turn.")
parser.add_argument("--use_history_labels", action='store_true',
help="Whether or not to label the history as well.")
parser.add_argument("--swap_utterances", action='store_true',
help="Whether or not to swap the turn utterances (default: sys|usr, swapped: usr|sys).")
parser.add_argument("--label_value_repetitions", action='store_true',
help="Whether or not to label values that have been mentioned before.")
parser.add_argument("--delexicalize_sys_utts", action='store_true',
help="Whether or not to delexicalize the system utterances.")
parser.add_argument("--class_aux_feats_inform", action='store_true',
help="Whether or not to use the identity of informed slots as auxiliary featurs for class prediction.")
parser.add_argument("--class_aux_feats_ds", action='store_true',
help="Whether or not to use the identity of slots in the current dialog state as auxiliary featurs for class prediction.")
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
help="Batch size per GPU/CPU for training.")
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
help="Batch size per GPU/CPU for evaluation.")
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--weight_decay", default=0.0, type=float,
help="Weight deay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
help="Max gradient norm.")
parser.add_argument("--num_train_epochs", default=3.0, type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--max_steps", default=-1, type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
parser.add_argument("--warmup_proportion", default=0.0, type=float,
help="Linear warmup over warmup_proportion * steps.")
parser.add_argument("--svd", default=0.0, type=float,
help="Slot value dropout ratio (default: 0.0)")
parser.add_argument('--logging_steps', type=int, default=50,
help="Log every X updates steps.")
parser.add_argument('--save_steps', type=int, default=0,
help="Save checkpoint every X updates steps. Overwritten by --save_epochs.")
parser.add_argument('--save_epochs', type=int, default=0,
help="Save checkpoint every X epochs. Overrides --save_steps.")
parser.add_argument("--eval_all_checkpoints", action='store_true',
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
parser.add_argument("--no_cuda", action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument('--overwrite_output_dir', action='store_true',
help="Overwrite the content of the output directory")
parser.add_argument('--overwrite_cache', action='store_true',
help="Overwrite the cached training and evaluation sets")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument("--local_rank", type=int, default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
args = parser.parse_args()
assert(args.warmup_proportion >= 0.0 and args.warmup_proportion <= 1.0)
assert(args.svd >= 0.0 and args.svd <= 1.0)
assert(args.class_aux_feats_ds is False or args.per_gpu_eval_batch_size == 1)
assert(not args.class_aux_feats_inform or args.per_gpu_eval_batch_size == 1)
assert(not args.class_aux_feats_ds or args.per_gpu_eval_batch_size == 1)
task_name = args.task_name.lower()
if task_name not in PROCESSORS:
raise ValueError("Task not found: %s" % (task_name))
processor = PROCESSORS[task_name](args.dataset_config)
dst_slot_list = processor.slot_list
dst_class_types = processor.class_types
dst_class_labels = len(dst_class_types)
# Setup CUDA, GPU & distributed training
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = torch.cuda.device_count()
else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend='nccl')
args.n_gpu = 1
args.device = device
# Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
# Set seed
set_seed(args)
# Load pretrained model and tokenizer
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
# Add DST specific parameters to config
config.dst_dropout_rate = args.dropout_rate
config.dst_heads_dropout_rate = args.heads_dropout
config.dst_class_loss_ratio = args.class_loss_ratio
config.dst_token_loss_for_nonpointable = args.token_loss_for_nonpointable
config.dst_refer_loss_for_nonpointable = args.refer_loss_for_nonpointable
config.dst_class_aux_feats_inform = args.class_aux_feats_inform
config.dst_class_aux_feats_ds = args.class_aux_feats_ds
config.dst_slot_list = dst_slot_list
config.dst_class_types = dst_class_types
config.dst_class_labels = dst_class_labels
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
logger.info("Updated model config: %s" % config)
if args.local_rank == 0:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
model.to(args.device)
logger.info("Training/evaluation parameters %s", args)
# Training
if args.do_train:
# If output files already exists, assume to continue training from latest checkpoint (unless overwrite_output_dir is set)
continue_from_global_step = 0 # If set to 0, start training from the beginning
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/*/' + WEIGHTS_NAME, recursive=True)))
if len(checkpoints) > 0:
checkpoint = checkpoints[-1]
logger.info("Resuming training from the latest checkpoint: %s", checkpoint)
continue_from_global_step = int(checkpoint.split('-')[-1])
model = model_class.from_pretrained(checkpoint)
model.to(args.device)
train_dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=False)
global_step, tr_loss = train(args, train_dataset, features, model, tokenizer, processor, continue_from_global_step)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
# Save the trained model and the tokenizer
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Create output directory if needed
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
model.to(args.device)
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
results = []
if args.do_eval and args.local_rank in [-1, 0]:
output_eval_file = os.path.join(args.output_dir, "eval_res.%s.json" % (args.predict_type))
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for cItr, checkpoint in enumerate(checkpoints):
# Reload the model
global_step = checkpoint.split('-')[-1]
if cItr == len(checkpoints) - 1:
global_step = "final"
model = model_class.from_pretrained(checkpoint)
model.to(args.device)
# Evaluate
result = evaluate(args, model, tokenizer, processor, prefix=global_step)
result_dict = {k: float(v) for k, v in result.items()}
result_dict["global_step"] = global_step
results.append(result_dict)
for key in sorted(result_dict.keys()):
logger.info("%s = %s", key, str(result_dict[key]))
with open(output_eval_file, "w") as f:
json.dump(results, f, indent=2)
return results
if __name__ == "__main__":
main()
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.utils.data import Dataset
class TensorListDataset(Dataset):
r"""Dataset wrapping tensors, tensor dicts and tensor lists.
Arguments:
*data (Tensor or dict or list of Tensors): tensors that have the same size
of the first dimension.
"""
def __init__(self, *data):
if isinstance(data[0], dict):
size = list(data[0].values())[0].size(0)
elif isinstance(data[0], list):
size = data[0][0].size(0)
else:
size = data[0].size(0)
for element in data:
if isinstance(element, dict):
assert all(size == tensor.size(0) for name, tensor in element.items()) # dict of tensors
elif isinstance(element, list):
assert all(size == tensor.size(0) for tensor in element) # list of tensors
else:
assert size == element.size(0) # tensor
self.size = size
self.data = data
def __getitem__(self, index):
result = []
for element in self.data:
if isinstance(element, dict):
result.append({k: v[index] for k, v in element.items()})
elif isinstance(element, list):
result.append([v[index] for v in element])  # materialize as a list, not a generator
else:
result.append(element[index])
return tuple(result)
def __len__(self):
return self.size
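# Usage sketch (illustrative only; the shapes and slot names below are made up,
# not taken from the repository). It shows that plain tensors and per-slot dicts
# of tensors can be mixed, as long as all share the same first dimension.
if __name__ == "__main__":
    import torch
    # Eight examples, each a padded sequence of 180 token ids.
    input_ids = torch.zeros(8, 180, dtype=torch.long)
    # Per-slot tensors grouped in a dict; every tensor has first dimension 8.
    start_pos = {'food': torch.zeros(8, dtype=torch.long),
                 'area': torch.zeros(8, dtype=torch.long)}
    dataset = TensorListDataset(input_ids, start_pos)
    ids, positions = dataset[3]  # 'positions' is a dict with one entry per slot
    assert ids.size(0) == 180 and set(positions) == {'food', 'area'}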
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
# Part of this code is based on the source code of Transformers
# (arXiv:1910.03771)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import six
import numpy as np
import json
logger = logging.getLogger(__name__)
class DSTExample(object):
"""
A single training/test example for the DST dataset.
"""
def __init__(self,
guid,
text_a,
text_b,
history,
text_a_label=None,
text_b_label=None,
history_label=None,
values=None,
inform_label=None,
inform_slot_label=None,
refer_label=None,
diag_state=None,
class_label=None):
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.history = history
self.text_a_label = text_a_label
self.text_b_label = text_b_label
self.history_label = history_label
self.values = values
self.inform_label = inform_label
self.inform_slot_label = inform_slot_label
self.refer_label = refer_label
self.diag_state = diag_state
self.class_label = class_label
def __str__(self):
return self.__repr__()
def __repr__(self):
s = ""
s += "guid: %s" % (self.guid)
s += ", text_a: %s" % (self.text_a)
s += ", text_b: %s" % (self.text_b)
s += ", history: %s" % (self.history)
if self.text_a_label:
s += ", text_a_label: %s" % (self.text_a_label)
if self.text_b_label:
s += ", text_b_label: %s" % (self.text_b_label)
if self.history_label:
s += ", history_label: %s" % (self.history_label)
if self.values:
s += ", values: %s" % (self.values)
if self.inform_label:
s += ", inform_label: %s" % (self.inform_label)
if self.inform_slot_label:
s += ", inform_slot_label: %s" % (self.inform_slot_label)
if self.refer_label:
s += ", refer_label: %s" % (self.refer_label)
if self.diag_state:
s += ", diag_state: %s" % (self.diag_state)
if self.class_label:
s += ", class_label: %s" % (self.class_label)
return s
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
input_ids,
input_ids_unmasked,
input_mask,
segment_ids,
start_pos=None,
end_pos=None,
values=None,
inform=None,
inform_slot=None,
refer_id=None,
diag_state=None,
class_label_id=None,
guid="NONE"):
self.guid = guid
self.input_ids = input_ids
self.input_ids_unmasked = input_ids_unmasked
self.input_mask = input_mask
self.segment_ids = segment_ids
self.start_pos = start_pos
self.end_pos = end_pos
self.values = values
self.inform = inform
self.inform_slot = inform_slot
self.refer_id = refer_id
self.diag_state = diag_state
self.class_label_id = class_label_id
def convert_examples_to_features(examples, slot_list, class_types, model_type, tokenizer, max_seq_length, slot_value_dropout=0.0):
"""Loads a data file into a list of `InputBatch`s."""
if model_type == 'bert':
model_specs = {'MODEL_TYPE': 'bert',
'CLS_TOKEN': '[CLS]',
'UNK_TOKEN': '[UNK]',
'SEP_TOKEN': '[SEP]',
'TOKEN_CORRECTION': 4}
else:
logger.error("Unknown model type (%s). Aborting." % (model_type))
exit(1)
def _tokenize_text_and_label(text, text_label_dict, slot, tokenizer, model_specs, slot_value_dropout):
joint_text_label = [0 for _ in text_label_dict[slot]] # joint label over all slots
for slot_text_label in text_label_dict.values():
for idx, label in enumerate(slot_text_label):
if label == 1:
joint_text_label[idx] = 1
text_label = text_label_dict[slot]
tokens = []
tokens_unmasked = []
token_labels = []
for token, token_label, joint_label in zip(text, text_label, joint_text_label):
token = convert_to_unicode(token)
sub_tokens = tokenizer.tokenize(token) # Most time intensive step
tokens_unmasked.extend(sub_tokens)
if slot_value_dropout == 0.0 or joint_label == 0:
tokens.extend(sub_tokens)
else:
rn_list = np.random.random_sample((len(sub_tokens),))
for rn, sub_token in zip(rn_list, sub_tokens):
if rn > slot_value_dropout:
tokens.append(sub_token)
else:
tokens.append(model_specs['UNK_TOKEN'])
token_labels.extend([token_label for _ in sub_tokens])
assert len(tokens) == len(token_labels)
assert len(tokens_unmasked) == len(token_labels)
return tokens, tokens_unmasked, token_labels
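# Worked example (hypothetical values, for illustration only):
# text = ['moderately', 'priced'], text_label_dict[slot] = [1, 1].
# If 'moderately' tokenizes to ['moderate', '##ly'], the word-level label 1 is
# repeated for each sub-token, giving tokens = ['moderate', '##ly', 'priced']
# and token_labels = [1, 1, 1]. With slot_value_dropout > 0, sub-tokens whose
# joint label is 1 are replaced by [UNK] with that probability, while
# tokens_unmasked keeps the original sub-tokens for later recovery.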
def _truncate_seq_pair(tokens_a, tokens_b, history, max_length):
"""Truncates a sequence pair in place to the maximum length.
Copied from bert/run_classifier.py
"""
# This is a simple heuristic which truncates the dialog history first, then
# the longer of the two sequences, one token at a time. This makes more sense
# than truncating an equal percentage of tokens from each, since if one
# sequence is very short then each truncated token likely carries more
# information than one from a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b) + len(history)
if total_length <= max_length:
break
if len(history) > 0:
history.pop()
elif len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, model_specs, guid):
# Modifies `tokens_a`, `tokens_b` and `history` in place so that the total
# length stays below the specified maximum.
# Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT)
if len(tokens_a) + len(tokens_b) + len(history) > max_seq_length - model_specs['TOKEN_CORRECTION']:
logger.info("Truncate Example %s. Total len=%d." % (guid, len(tokens_a) + len(tokens_b) + len(history)))
input_text_too_long = True
else:
input_text_too_long = False
_truncate_seq_pair(tokens_a, tokens_b, history, max_seq_length - model_specs['TOKEN_CORRECTION'])
return input_text_too_long
def _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs):
token_label_ids = []
token_label_ids.append(0) # [CLS]
for token_label in token_labels_a:
token_label_ids.append(token_label)
token_label_ids.append(0) # [SEP]
for token_label in token_labels_b:
token_label_ids.append(token_label)
token_label_ids.append(0) # [SEP]
for token_label in token_labels_history:
token_label_ids.append(token_label)
token_label_ids.append(0) # [SEP]
while len(token_label_ids) < max_seq_length:
token_label_ids.append(0) # padding
assert len(token_label_ids) == max_seq_length
return token_label_ids
def _get_start_end_pos(class_type, token_label_ids, max_seq_length):
if class_type == 'copy_value' and 1 not in token_label_ids:
#logger.warn("copy_value label, but token_label not detected. Setting label to 'none'.")
class_type = 'none'
start_pos = 0
end_pos = 0
if 1 in token_label_ids:
start_pos = token_label_ids.index(1)
# Parsing is supposed to find only the first location of the wanted value
if 0 not in token_label_ids[start_pos:]:
end_pos = len(token_label_ids[start_pos:]) + start_pos - 1
else:
end_pos = token_label_ids[start_pos:].index(0) + start_pos - 1
for i in range(max_seq_length):
if i >= start_pos and i <= end_pos:
assert token_label_ids[i] == 1
return class_type, start_pos, end_pos
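# Worked example (hypothetical label sequence, for illustration only):
# token_label_ids = [0, 0, 1, 1, 1, 0, 0] -> start_pos = 2, end_pos = 4,
# i.e. the span covers the first contiguous run of 1s. If class_type is
# 'copy_value' but no 1 occurs at all, the class is downgraded to 'none'
# and both positions stay 0.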
def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, tokenizer, model_specs):
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = []
segment_ids = []
tokens.append(model_specs['CLS_TOKEN'])
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append(model_specs['SEP_TOKEN'])
segment_ids.append(0)
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append(model_specs['SEP_TOKEN'])
segment_ids.append(1)
for token in history:
tokens.append(token)
segment_ids.append(1)
tokens.append(model_specs['SEP_TOKEN'])
segment_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
return tokens, input_ids, input_mask, segment_ids
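# Resulting layout in this DST variant (derived from the code above; lengths
# are schematic): the dialog history is appended as a third segment that
# shares type id 1 with tokens_b.
# tokens:      [CLS] tokens_a [SEP] tokens_b [SEP] history [SEP] <pad> ...
# segment_ids:   0    0 ... 0   0    1 ... 1   1    1 ... 1   1     0  ...
# input_mask:    1    1 ... 1   1    1 ... 1   1    1 ... 1   1     0  ...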
total_cnt = 0
too_long_cnt = 0
refer_list = ['none'] + slot_list
features = []
# Convert single example
for (example_index, example) in enumerate(examples):
if example_index % 1000 == 0:
logger.info("Writing example %d of %d" % (example_index, len(examples)))
total_cnt += 1
value_dict = {}
inform_dict = {}
inform_slot_dict = {}
refer_id_dict = {}
diag_state_dict = {}
class_label_id_dict = {}
start_pos_dict = {}
end_pos_dict = {}
for slot in slot_list:
tokens_a, tokens_a_unmasked, token_labels_a = _tokenize_text_and_label(
example.text_a, example.text_a_label, slot, tokenizer, model_specs, slot_value_dropout)
tokens_b, tokens_b_unmasked, token_labels_b = _tokenize_text_and_label(
example.text_b, example.text_b_label, slot, tokenizer, model_specs, slot_value_dropout)
tokens_history, tokens_history_unmasked, token_labels_history = _tokenize_text_and_label(
example.history, example.history_label, slot, tokenizer, model_specs, slot_value_dropout)
input_text_too_long = _truncate_length_and_warn(
tokens_a, tokens_b, tokens_history, max_seq_length, model_specs, example.guid)
if input_text_too_long:
if example_index < 10:
if len(token_labels_a) > len(tokens_a):
logger.info(' tokens_a truncated labels: %s' % str(token_labels_a[len(tokens_a):]))
if len(token_labels_b) > len(tokens_b):
logger.info(' tokens_b truncated labels: %s' % str(token_labels_b[len(tokens_b):]))
if len(token_labels_history) > len(tokens_history):
logger.info(' tokens_history truncated labels: %s' % str(token_labels_history[len(tokens_history):]))
token_labels_a = token_labels_a[:len(tokens_a)]
token_labels_b = token_labels_b[:len(tokens_b)]
token_labels_history = token_labels_history[:len(tokens_history)]
tokens_a_unmasked = tokens_a_unmasked[:len(tokens_a)]
tokens_b_unmasked = tokens_b_unmasked[:len(tokens_b)]
tokens_history_unmasked = tokens_history_unmasked[:len(tokens_history)]
assert len(token_labels_a) == len(tokens_a)
assert len(token_labels_b) == len(tokens_b)
assert len(token_labels_history) == len(tokens_history)
assert len(token_labels_a) == len(tokens_a_unmasked)
assert len(token_labels_b) == len(tokens_b_unmasked)
assert len(token_labels_history) == len(tokens_history_unmasked)
token_label_ids = _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs)
value_dict[slot] = example.values[slot]
inform_dict[slot] = example.inform_label[slot]
class_label_mod, start_pos_dict[slot], end_pos_dict[slot] = _get_start_end_pos(
example.class_label[slot], token_label_ids, max_seq_length)
if class_label_mod != example.class_label[slot]:
example.class_label[slot] = class_label_mod
inform_slot_dict[slot] = example.inform_slot_label[slot]
refer_id_dict[slot] = refer_list.index(example.refer_label[slot])
diag_state_dict[slot] = class_types.index(example.diag_state[slot])
class_label_id_dict[slot] = class_types.index(example.class_label[slot])
if input_text_too_long:
too_long_cnt += 1
tokens, input_ids, input_mask, segment_ids = _get_transformer_input(tokens_a,
tokens_b,
tokens_history,
max_seq_length,
tokenizer,
model_specs)
if slot_value_dropout > 0.0:
_, input_ids_unmasked, _, _ = _get_transformer_input(tokens_a_unmasked,
tokens_b_unmasked,
tokens_history_unmasked,
max_seq_length,
tokenizer,
model_specs)
else:
input_ids_unmasked = input_ids
assert(len(input_ids) == len(input_ids_unmasked))
if example_index < 10:
logger.info("*** Example ***")
logger.info("guid: %s" % (example.guid))
logger.info("tokens: %s" % " ".join(tokens))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
logger.info("start_pos: %s" % str(start_pos_dict))
logger.info("end_pos: %s" % str(end_pos_dict))
logger.info("values: %s" % str(value_dict))
logger.info("inform: %s" % str(inform_dict))
logger.info("inform_slot: %s" % str(inform_slot_dict))
logger.info("refer_id: %s" % str(refer_id_dict))
logger.info("diag_state: %s" % str(diag_state_dict))
logger.info("class_label_id: %s" % str(class_label_id_dict))
features.append(
InputFeatures(
guid=example.guid,
input_ids=input_ids,
input_ids_unmasked=input_ids_unmasked,
input_mask=input_mask,
segment_ids=segment_ids,
start_pos=start_pos_dict,
end_pos=end_pos_dict,
values=value_dict,
inform=inform_dict,
inform_slot=inform_slot_dict,
refer_id=refer_id_dict,
diag_state=diag_state_dict,
class_label_id=class_label_id_dict))
logger.info("========== %d out of %d examples have text too long" % (too_long_cnt, total_cnt))
return features
# From bert.tokenization (TF code)
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")