Commit 2e2a8b87 authored by heck

initial

parent eea0be40
#!/bin/bash
# Parameters ------------------------------------------------------
#TASK="sim-m"
#DATA_DIR="data/simulated-dialogue/sim-M"
#TASK="sim-r"
#DATA_DIR="data/simulated-dialogue/sim-R"
TASK="woz2"
DATA_DIR="data/woz2"
#TASK="multiwoz21"
#DATA_DIR="data/MULTIWOZ2.1"
# Project paths etc. ----------------------------------------------
OUT_DIR=results
mkdir -p ${OUT_DIR}
# Main ------------------------------------------------------------
for step in train dev test; do
args_add=""
if [ "$step" = "train" ]; then
args_add="--do_train --predict_type=dummy"
elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
args_add="--do_eval --predict_type=${step}"
fi
python3 run_dst.py \
--task_name=${TASK} \
--data_dir=${DATA_DIR} \
--dataset_config=dataset_config/${TASK}.json \
--model_type="bert" \
--model_name_or_path="bert-base-uncased" \
--do_lower_case \
--learning_rate=1e-4 \
--num_train_epochs=10 \
--max_seq_length=180 \
--per_gpu_train_batch_size=48 \
--per_gpu_eval_batch_size=1 \
--output_dir=${OUT_DIR} \
--save_epochs=2 \
--logging_steps=10 \
--warmup_proportion=0.1 \
--eval_all_checkpoints \
--adam_epsilon=1e-6 \
--label_value_repetitions \
--swap_utterances \
--append_history \
--use_history_labels \
--delexicalize_sys_utts \
--class_aux_feats_inform \
--class_aux_feats_ds \
${args_add} \
2>&1 | tee ${OUT_DIR}/${step}.log
if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
python3 ~/tools/trippy_clean/metric_bert_dst.py \
${TASK} \
dataset_config/${TASK}.json \
"${OUT_DIR}/pred_res.${step}*json" \
2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
fi
done
#!/bin/bash
# Parameters ------------------------------------------------------
#TASK="sim-m"
#DATA_DIR="data/simulated-dialogue/sim-M"
#TASK="sim-r"
#DATA_DIR="data/simulated-dialogue/sim-R"
TASK="woz2"
DATA_DIR="data/woz2"
#TASK="multiwoz21"
#DATA_DIR="data/MULTIWOZ2.1"
# Project paths etc. ----------------------------------------------
OUT_DIR=results
mkdir -p ${OUT_DIR}
# Main ------------------------------------------------------------
for step in train dev test; do
args_add=""
if [ "$step" = "train" ]; then
args_add="--do_train --predict_type=dummy"
elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
args_add="--do_eval --predict_type=${step}"
fi
python3 run_dst.py \
--task_name=${TASK} \
--data_dir=${DATA_DIR} \
--dataset_config=dataset_config/${TASK}.json \
--model_type="bert" \
--model_name_or_path="bert-base-uncased" \
--do_lower_case \
--learning_rate=1e-4 \
--num_train_epochs=10 \
--max_seq_length=180 \
--per_gpu_train_batch_size=48 \
--per_gpu_eval_batch_size=1 \
--output_dir=${OUT_DIR} \
--save_epochs=2 \
--logging_steps=10 \
--warmup_proportion=0.1 \
--eval_all_checkpoints \
--adam_epsilon=1e-6 \
--label_value_repetitions \
${args_add} \
2>&1 | tee ${OUT_DIR}/${step}.log
if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
python3 ~/tools/trippy_clean/metric_bert_dst.py \
${TASK} \
dataset_config/${TASK}.json \
"${OUT_DIR}/pred_res.${step}*json" \
2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
fi
done
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import dataset_woz2
import dataset_sim
import dataset_multiwoz21
class DataProcessor(object):
def __init__(self, dataset_config):
with open(dataset_config, "r", encoding='utf-8') as f:
raw_config = json.load(f)
self.class_types = raw_config['class_types']
self.slot_list = raw_config['slots']
self.label_maps = raw_config['label_maps']
def get_train_examples(self, data_dir, **args):
raise NotImplementedError()
def get_dev_examples(self, data_dir, **args):
raise NotImplementedError()
def get_test_examples(self, data_dir, **args):
raise NotImplementedError()
class Woz2Processor(DataProcessor):
def get_train_examples(self, data_dir, args):
return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_train_en.json'),
'train', self.slot_list, self.label_maps, **args)
def get_dev_examples(self, data_dir, args):
return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_validate_en.json'),
'dev', self.slot_list, self.label_maps, **args)
def get_test_examples(self, data_dir, args):
return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_test_en.json'),
'test', self.slot_list, self.label_maps, **args)
class Multiwoz21Processor(DataProcessor):
def get_train_examples(self, data_dir, args):
return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'train_dials.json'),
os.path.join(data_dir, 'dialogue_acts.json'),
'train', self.slot_list, self.label_maps, **args)
def get_dev_examples(self, data_dir, args):
return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'val_dials.json'),
os.path.join(data_dir, 'dialogue_acts.json'),
'dev', self.slot_list, self.label_maps, **args)
def get_test_examples(self, data_dir, args):
return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'test_dials.json'),
os.path.join(data_dir, 'dialogue_acts.json'),
'test', self.slot_list, self.label_maps, **args)
class SimProcessor(DataProcessor):
def get_train_examples(self, data_dir, args):
return dataset_sim.create_examples(os.path.join(data_dir, 'train.json'),
'train', self.slot_list, **args)
def get_dev_examples(self, data_dir, args):
return dataset_sim.create_examples(os.path.join(data_dir, 'dev.json'),
'dev', self.slot_list, **args)
def get_test_examples(self, data_dir, args):
return dataset_sim.create_examples(os.path.join(data_dir, 'test.json'),
'test', self.slot_list, **args)
PROCESSORS = {"woz2": Woz2Processor,
"sim-m": SimProcessor,
"sim-r": SimProcessor,
"multiwoz21": Multiwoz21Processor}
{
"class_types": [
"none",
"dontcare",
"copy_value",
"inform"
],
"slots": [
"date",
"movie",
"time",
"num_tickets",
"theatre_name"
],
"label_maps": {}
}
{
"class_types": [
"none",
"dontcare",
"copy_value",
"inform"
],
"slots": [
"category",
"rating",
"num_people",
"location",
"restaurant_name",
"time",
"date",
"price_range",
"meal"
],
"label_maps": {}
}
{
"class_types": [
"none",
"dontcare",
"copy_value",
"inform"
],
"slots": [
"area",
"food",
"price_range"
],
"label_maps": {
"center": [
"centre",
"downtown",
"central",
"down town",
"middle"
],
"south": [
"southern",
"southside"
],
"north": [
"northern",
"uptown",
"northside"
],
"west": [
"western",
"westside"
],
"east": [
"eastern",
"eastside"
],
"east side": [
"eastern",
"eastside"
],
"cheap": [
"low price",
"inexpensive",
"cheaper",
"low priced",
"affordable",
"nothing too expensive",
"without costing a fortune",
"cheapest",
"good deals",
"low prices",
"afford",
"on a budget",
"fair prices",
"less expensive",
"cheapeast",
"not cost an arm and a leg"
],
"moderate": [
"moderately",
"medium priced",
"medium price",
"fair price",
"fair prices",
"reasonable",
"reasonably priced",
"mid price",
"fairly priced",
"not outrageous",
"not too expensive",
"on a budget",
"mid range",
"reasonable priced",
"less expensive",
"not too pricey",
"nothing too expensive",
"nothing cheap",
"not overpriced",
"medium",
"inexpensive"
],
"expensive": [
"high priced",
"high end",
"high class",
"high quality",
"fancy",
"upscale",
"nice",
"fine dining",
"expensively priced"
],
"afghan": [
"afghanistan"
],
"african": [
"africa"
],
"asian oriental": [
"asian",
"oriental"
],
"australasian": [
"australian asian",
"austral asian"
],
"australian": [
"aussie"
],
"barbeque": [
"barbecue",
"bbq"
],
"basque": [
"bask"
],
"belgian": [
"belgium"
],
"british": [
"cotto"
],
"canapes": [
"canopy",
"canape",
"canap"
],
"catalan": [
"catalonian"
],
"corsican": [
"corsica"
],
"crossover": [
"cross over",
"over"
],
"gastropub": [
"gastro pub",
"gastro",
"gastropubs"
],
"hungarian": [
"goulash"
],
"indian": [
"india",
"indians",
"nirala"
],
"international": [
"all types of food"
],
"italian": [
"prezzo"
],
"jamaican": [
"jamaica"
],
"japanese": [
"sushi",
"beni hana"
],
"korean": [
"korea"
],
"lebanese": [
"lebanse"
],
"north american": [
"american",
"hamburger"
],
"portuguese": [
"portugese"
],
"seafood": [
"sea food",
"shellfish",
"fish"
],
"singaporean": [
"singapore"
],
"steakhouse": [
"steak house",
"steak"
],
"thai": [
"thailand",
"bangkok"
],
"traditional": [
"old fashioned",
"plain"
],
"turkish": [
"turkey"
],
"unusual": [
"unique and strange"
],
"venetian": [
"vanessa"
],
"vietnamese": [
"vietnam",
"thanh binh"
]
}
}
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from utils_dst import (DSTExample)
def dialogue_state_to_sv_dict(sv_list):
sv_dict = {}
for d in sv_list:
sv_dict[d['slot']] = d['value']
return sv_dict
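# e.g. dialogue_state_to_sv_dict([{'slot': 'food', 'value': 'thai'}])
# returns {'food': 'thai'}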
def get_token_and_slot_label(turn):
if 'system_utterance' in turn:
sys_utt_tok = turn['system_utterance']['tokens']
sys_slot_label = turn['system_utterance']['slots']
else:
sys_utt_tok = []
sys_slot_label = []
usr_utt_tok = turn['user_utterance']['tokens']
usr_slot_label = turn['user_utterance']['slots']
return sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label
def get_tok_label(prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok,
sys_slot_label, usr_utt_tok, usr_slot_label, dial_id,
turn_id, slot_last_occurrence=True):
"""The position of the last occurrence of the slot value will be used."""
sys_utt_tok_label = [0 for _ in sys_utt_tok]
usr_utt_tok_label = [0 for _ in usr_utt_tok]
if slot_type not in cur_ds_dict:
class_type = 'none'
else:
value = cur_ds_dict[slot_type]
if value == 'dontcare' and (slot_type not in prev_ds_dict or prev_ds_dict[slot_type] != 'dontcare'):
# Only label dontcare at its first occurrence in the dialog
class_type = 'dontcare'
else: # If not none or dontcare, we have to identify whether
# there is a span, or if the value is informed
in_usr = False
in_sys = False
for label_d in usr_slot_label:
if label_d['slot'] == slot_type and value == ' '.join(
usr_utt_tok[label_d['start']:label_d['exclusive_end']]):
for idx in range(label_d['start'], label_d['exclusive_end']):
usr_utt_tok_label[idx] = 1
in_usr = True
class_type = 'copy_value'
if slot_last_occurrence:
break
if not in_usr or not slot_last_occurrence:
for label_d in sys_slot_label:
if label_d['slot'] == slot_type and value == ' '.join(
sys_utt_tok[label_d['start']:label_d['exclusive_end']]):
for idx in range(label_d['start'], label_d['exclusive_end']):
sys_utt_tok_label[idx] = 1
in_sys = True
class_type = 'inform'
if slot_last_occurrence:
break
if not in_usr and not in_sys:
assert sum(usr_utt_tok_label + sys_utt_tok_label) == 0
if (slot_type not in prev_ds_dict or value != prev_ds_dict[slot_type]):
raise ValueError('Copy value cannot be found in dialog %s turn %s' % (str(dial_id), str(turn_id)))
else:
class_type = 'none'
else:
assert sum(usr_utt_tok_label + sys_utt_tok_label) > 0
return sys_utt_tok_label, usr_utt_tok_label, class_type
def delex_utt(utt, values):
utt_delex = utt.copy()
for v in values:
utt_delex[v['start']:v['exclusive_end']] = ['[UNK]'] * (v['exclusive_end'] - v['start'])
return utt_delex
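# A small illustration with made-up tokens: the annotated span is masked out.
# delex_utt(['table', 'for', '2', 'people'], [{'start': 2, 'exclusive_end': 3}])
# returns ['table', 'for', '[UNK]', 'people']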
def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id,
delexicalize_sys_utts=False, slot_last_occurrence=True):
"""Make turn_label a dictionary of slot with value positions or being dontcare / none:
Turn label contains:
(1) the updates from previous to current dialogue state,
(2) values in current dialogue state explicitly mentioned in system or user utterance."""
prev_ds_dict = dialogue_state_to_sv_dict(prev_dialogue_state)
cur_ds_dict = dialogue_state_to_sv_dict(turn['dialogue_state'])
(sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label) = get_token_and_slot_label(turn)
sys_utt_tok_label_dict = {}
usr_utt_tok_label_dict = {}
inform_label_dict = {}
inform_slot_label_dict = {}
referral_label_dict = {}
class_type_dict = {}
for slot_type in slot_list:
inform_label_dict[slot_type] = 'none'
inform_slot_label_dict[slot_type] = 0
referral_label_dict[slot_type] = 'none' # Referral is not present in sim data
sys_utt_tok_label, usr_utt_tok_label, class_type = get_tok_label(
prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok, sys_slot_label,
usr_utt_tok, usr_slot_label, dial_id, turn_id,
slot_last_occurrence=slot_last_occurrence)
if sum(sys_utt_tok_label) > 0:
inform_label_dict[slot_type] = cur_ds_dict[slot_type]
inform_slot_label_dict[slot_type] = 1
sys_utt_tok_label = [0 for _ in sys_utt_tok_label] # Don't use token labels for sys utt
sys_utt_tok_label_dict[slot_type] = sys_utt_tok_label
usr_utt_tok_label_dict[slot_type] = usr_utt_tok_label
class_type_dict[slot_type] = class_type
if delexicalize_sys_utts:
sys_utt_tok = delex_utt(sys_utt_tok, sys_slot_label)
return (sys_utt_tok, sys_utt_tok_label_dict,
usr_utt_tok, usr_utt_tok_label_dict,
inform_label_dict, inform_slot_label_dict,
referral_label_dict, cur_ds_dict, class_type_dict)
def create_examples(input_file, set_type, slot_list,
label_maps={},
append_history=False,
use_history_labels=False,
swap_utterances=False,
label_value_repetitions=False,
delexicalize_sys_utts=False,
analyze=False):
"""Read a DST json file into a list of DSTExample."""
with open(input_file, "r", encoding='utf-8') as reader:
input_data = json.load(reader)
examples = []
for entry in input_data:
dial_id = entry['dialogue_id']
prev_ds = []
hst = []
prev_hst_lbl_dict = {slot: [] for slot in slot_list}
prev_ds_lbl_dict = {slot: 'none' for slot in slot_list}
for turn_id, turn in enumerate(entry['turns']):
guid = '%s-%s-%s' % (set_type, dial_id, str(turn_id))
ds_lbl_dict = prev_ds_lbl_dict.copy()
hst_lbl_dict = prev_hst_lbl_dict.copy()
(text_a,
text_a_label,
text_b,
text_b_label,
inform_label,
inform_slot_label,
referral_label,
cur_ds_dict,
class_label) = get_turn_label(turn,
prev_ds,
slot_list,
dial_id,
turn_id,
delexicalize_sys_utts=delexicalize_sys_utts,
slot_last_occurrence=True)
if swap_utterances:
txt_a = text_b
txt_b = text_a
txt_a_lbl = text_b_label
txt_b_lbl = text_a_label
else:
txt_a = text_a
txt_b = text_b
txt_a_lbl = text_a_label
txt_b_lbl = text_b_label
value_dict = {}
for slot in slot_list:
if slot in cur_ds_dict:
value_dict[slot] = cur_ds_dict[slot]
else:
value_dict[slot] = 'none'
if class_label[slot] != 'none':
ds_lbl_dict[slot] = class_label[slot]
if append_history:
if use_history_labels:
hst_lbl_dict[slot] = txt_a_lbl[slot] + txt_b_lbl[slot] + hst_lbl_dict[slot]
else:
hst_lbl_dict[slot] = [0 for _ in txt_a_lbl[slot] + txt_b_lbl[slot] + hst_lbl_dict[slot]]
examples.append(DSTExample(
guid=guid,
text_a=txt_a,
text_b=txt_b,
history=hst,
text_a_label=txt_a_lbl,
text_b_label=txt_b_lbl,
history_label=prev_hst_lbl_dict,
values=value_dict,
inform_label=inform_label,
inform_slot_label=inform_slot_label,
refer_label=referral_label,
diag_state=prev_ds_lbl_dict,
class_label=class_label))
prev_ds = turn['dialogue_state']
prev_ds_lbl_dict = ds_lbl_dict.copy()
prev_hst_lbl_dict = hst_lbl_dict.copy()
if append_history:
hst = txt_a + txt_b + hst
return examples
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from utils_dst import (DSTExample, convert_to_unicode)
LABEL_MAPS = {} # Loaded from file
LABEL_FIX = {'centre': 'center', 'areas': 'area', 'phone number': 'number', 'price range': 'price_range'}
def delex_utt(utt, values):
utt_norm = utt.copy()
for s, v in values.items():
if v != 'none':
v_norm = tokenize(v)
v_len = len(v_norm)
for i in range(len(utt_norm) + 1 - v_len):
if utt_norm[i:i + v_len] == v_norm:
utt_norm[i:i + v_len] = ['[UNK]'] * v_len
return utt_norm
def get_token_pos(tok_list, label):
find_pos = []
found = False
label_list = [item for item in map(str.strip, re.split(r"(\W+)", label)) if len(item) > 0]
len_label = len(label_list)
for i in range(len(tok_list) + 1 - len_label):
if tok_list[i:i + len_label] == label_list:
find_pos.append((i, i + len_label)) # start, exclusive_end
found = True
return found, find_pos
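# e.g. get_token_pos(['in', 'the', 'north', 'part'], 'north')
# returns (True, [(2, 3)])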
def check_label_existence(label, usr_utt_tok, sys_utt_tok):
in_usr, usr_pos = get_token_pos(usr_utt_tok, label)
if not in_usr and label in LABEL_MAPS:
for tmp_label in LABEL_MAPS[label]:
in_usr, usr_pos = get_token_pos(usr_utt_tok, tmp_label)
if in_usr:
break
in_sys, sys_pos = get_token_pos(sys_utt_tok, label)
if not in_sys and label in LABEL_MAPS:
for tmp_label in LABEL_MAPS[label]:
in_sys, sys_pos = get_token_pos(sys_utt_tok, tmp_label)
if in_sys:
break
return in_usr, usr_pos, in_sys, sys_pos
def get_turn_label(label, sys_utt_tok, usr_utt_tok, slot_last_occurrence):
usr_utt_tok_label = [0 for _ in usr_utt_tok]
if label == 'none' or label == 'dontcare':
class_type = label
else:
in_usr, usr_pos, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok)
if in_usr:
class_type = 'copy_value'
if slot_last_occurrence:
(s, e) = usr_pos[-1]
for i in range(s, e):
usr_utt_tok_label[i] = 1
else:
for (s, e) in usr_pos:
for i in range(s, e):
usr_utt_tok_label[i] = 1
elif in_sys:
class_type = 'inform'
else:
class_type = 'unpointable'
return usr_utt_tok_label, class_type
def tokenize(utt):
utt_lower = convert_to_unicode(utt).lower()
utt_tok = [tok for tok in map(str.strip, re.split(r"(\W+)", utt_lower)) if len(tok) > 0]
return utt_tok
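# e.g. tokenize("I'd like Thai food!")
# returns ['i', "'", 'd', 'like', 'thai', 'food', '!']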
def create_examples(input_file, set_type, slot_list,
label_maps={},
append_history=False,
use_history_labels=False,
swap_utterances=False,
label_value_repetitions=False,
delexicalize_sys_utts=False,
analyze=False):
"""Read a DST json file into a list of DSTExample."""
with open(input_file, "r", encoding='utf-8') as reader:
input_data = json.load(reader)
global LABEL_MAPS
LABEL_MAPS = label_maps
examples = []
for entry in input_data:
diag_seen_slots_dict = {}
diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
diag_state = {slot: 'none' for slot in slot_list}
sys_utt_tok = []
sys_utt_tok_delex = []
usr_utt_tok = []
hst_utt_tok = []
hst_utt_tok_label_dict = {slot: [] for slot in slot_list}
for turn in entry['dialogue']:
sys_utt_tok_label_dict = {}
usr_utt_tok_label_dict = {}
inform_dict = {slot: 'none' for slot in slot_list}
inform_slot_dict = {slot: 0 for slot in slot_list}
referral_dict = {}
class_type_dict = {}
# Collect turn data
if append_history:
if swap_utterances:
if delexicalize_sys_utts:
hst_utt_tok = usr_utt_tok + sys_utt_tok_delex + hst_utt_tok
else:
hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
else:
if delexicalize_sys_utts:
hst_utt_tok = sys_utt_tok_delex + usr_utt_tok + hst_utt_tok
else:
hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok
sys_utt_tok = tokenize(turn['system_transcript'])
usr_utt_tok = tokenize(turn['transcript'])
turn_label = {LABEL_FIX.get(s.strip(), s.strip()): LABEL_FIX.get(v.strip(), v.strip()) for s, v in turn['turn_label']}
guid = '%s-%s-%s' % (set_type, str(entry['dialogue_idx']), str(turn['turn_idx']))
# Create delexicalized sys utterances.
if delexicalize_sys_utts:
delex_dict = {}
for slot in slot_list:
delex_dict[slot] = 'none'
label = 'none'
if slot in turn_label:
label = turn_label[slot]
elif label_value_repetitions and slot in diag_seen_slots_dict:
label = diag_seen_slots_value_dict[slot]
if label != 'none' and label != 'dontcare':
_, _, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok)
if in_sys:
delex_dict[slot] = label
sys_utt_tok_delex = delex_utt(sys_utt_tok, delex_dict)
new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy()
new_diag_state = diag_state.copy()
for slot in slot_list:
label = 'none'
if slot in turn_label:
label = turn_label[slot]
elif label_value_repetitions and slot in diag_seen_slots_dict:
label = diag_seen_slots_value_dict[slot]
(usr_utt_tok_label,
class_type) = get_turn_label(label,
sys_utt_tok,
usr_utt_tok,
slot_last_occurrence=True)
if class_type == 'inform':
inform_dict[slot] = label
if label != 'none':
inform_slot_dict[slot] = 1
referral_dict[slot] = 'none' # Referral is not present in woz2 data
# Generally don't use span prediction on sys utterance (but inform prediction instead).
if delexicalize_sys_utts:
sys_utt_tok_label = [0 for _ in sys_utt_tok_delex]
else:
sys_utt_tok_label = [0 for _ in sys_utt_tok]
# Determine what to do with value repetitions.
# If the value is unique among the seen slots, tag it; otherwise don't,
# since a correct slot assignment cannot be guaranteed anymore.
if label_value_repetitions and slot in diag_seen_slots_dict:
if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(label) > 1:
class_type = 'none'
usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
sys_utt_tok_label_dict[slot] = sys_utt_tok_label
usr_utt_tok_label_dict[slot] = usr_utt_tok_label
if append_history:
if use_history_labels:
if swap_utterances:
new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot]
else:
new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]
else:
new_hst_utt_tok_label_dict[slot] = [0 for _ in sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]]
# For now, we map all occurrences of unpointable slot values
# to none. However, since the labels will still suggest
# the presence of unpointable slot values, the task of the
# DST is still to find those values. It is just not
# possible to do that via span prediction on the current input.
if class_type == 'unpointable':
class_type_dict[slot] = 'none'
elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
# If the slot has been seen before and its class type did not change, label this slot as not present,
# assuming that the slot has not actually been mentioned in this turn.
# Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform,
# this must mean there is evidence in the original labels; therefore, consider
# it as mentioned again.
class_type_dict[slot] = 'none'
referral_dict[slot] = 'none'
else:
class_type_dict[slot] = class_type
# Remember that this slot was mentioned during this dialog already.
if class_type != 'none':
diag_seen_slots_dict[slot] = class_type
diag_seen_slots_value_dict[slot] = label
new_diag_state[slot] = class_type
# Unpointable is not a valid class, therefore replace with
# some valid class for now...
if class_type == 'unpointable':
new_diag_state[slot] = 'copy_value'
if swap_utterances:
txt_a = usr_utt_tok
if delexicalize_sys_utts:
txt_b = sys_utt_tok_delex
else:
txt_b = sys_utt_tok
txt_a_lbl = usr_utt_tok_label_dict
txt_b_lbl = sys_utt_tok_label_dict
else:
if delexicalize_sys_utts:
txt_a = sys_utt_tok_delex
else:
txt_a = sys_utt_tok
txt_b = usr_utt_tok
txt_a_lbl = sys_utt_tok_label_dict
txt_b_lbl = usr_utt_tok_label_dict
examples.append(DSTExample(
guid=guid,
text_a=txt_a,
text_b=txt_b,
history=hst_utt_tok,
text_a_label=txt_a_lbl,
text_b_label=txt_b_lbl,
history_label=hst_utt_tok_label_dict,
values=diag_seen_slots_value_dict.copy(),
inform_label=inform_dict,
inform_slot_label=inform_slot_dict,
refer_label=referral_dict,
diag_state=diag_state,
class_label=class_type_dict))
# Update some variables.
hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy()
diag_state = new_diag_state.copy()
return examples
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import json
import sys
import numpy as np
import re
def load_dataset_config(dataset_config):
with open(dataset_config, "r", encoding='utf-8') as f:
raw_config = json.load(f)
return raw_config['class_types'], raw_config['slots'], raw_config['label_maps']
def tokenize(text):
if "\u0120" in text:
text = re.sub(" ", "", text)
text = re.sub("\u0120", " ", text)
text = text.strip()
return ' '.join([tok for tok in map(str.strip, re.split(r"(\W+)", text)) if len(tok) > 0])
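# e.g. tokenize('\u0120new\u0120york') returns 'new york'; \u0120 is the marker
# that byte-level BPE tokenizers (e.g. GPT-2/RoBERTa) use for word starts.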
def is_in_list(tok, value):
found = False
tok_list = [item for item in map(str.strip, re.split(r"(\W+)", tok)) if len(item) > 0]
value_list = [item for item in map(str.strip, re.split(r"(\W+)", value)) if len(item) > 0]
tok_len = len(tok_list)
value_len = len(value_list)
for i in range(tok_len + 1 - value_len):
if tok_list[i:i + value_len] == value_list:
found = True
break
return found
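# e.g. is_in_list('moderately priced food', 'moderately priced') returns True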
def check_slot_inform(value_label, inform_label, label_maps):
value = inform_label
if value_label == inform_label:
value = value_label
elif is_in_list(inform_label, value_label):
value = value_label
elif is_in_list(value_label, inform_label):
value = value_label
elif inform_label in label_maps:
for inform_label_variant in label_maps[inform_label]:
if value_label == inform_label_variant:
value = value_label
break
elif is_in_list(inform_label_variant, value_label):
value = value_label
break
elif is_in_list(value_label, inform_label_variant):
value = value_label
break
elif value_label in label_maps:
for value_label_variant in label_maps[value_label]:
if value_label_variant == inform_label:
value = value_label
break
elif is_in_list(inform_label, value_label_variant):
value = value_label
break
elif is_in_list(value_label_variant, inform_label):
value = value_label
break
return value
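# e.g. with label_maps == {'center': ['centre']},
# check_slot_inform('center', 'centre', label_maps) returns 'center':
# the informed variant is mapped back to the ground-truth spelling.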
def get_joint_slot_correctness(fp, class_types, label_maps,
key_class_label_id='class_label_id',
key_class_prediction='class_prediction',
key_start_pos='start_pos',
key_start_prediction='start_prediction',
key_end_pos='end_pos',
key_end_prediction='end_prediction',
key_refer_id='refer_id',
key_refer_prediction='refer_prediction',
key_slot_groundtruth='slot_groundtruth',
key_slot_prediction='slot_prediction'):
with open(fp) as f:
preds = json.load(f)
class_correctness = [[] for cl in range(len(class_types) + 1)]
confusion_matrix = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))]
pos_correctness = []
refer_correctness = []
val_correctness = []
total_correctness = []
c_tp = {ct: 0 for ct in range(len(class_types))}
c_tn = {ct: 0 for ct in range(len(class_types))}
c_fp = {ct: 0 for ct in range(len(class_types))}
c_fn = {ct: 0 for ct in range(len(class_types))}
for pred in preds:
guid = pred['guid'] # List: set_type, dialogue_idx, turn_idx
turn_gt_class = pred[key_class_label_id]
turn_pd_class = pred[key_class_prediction]
gt_start_pos = pred[key_start_pos]
pd_start_pos = pred[key_start_prediction]
gt_end_pos = pred[key_end_pos]
pd_end_pos = pred[key_end_prediction]
gt_refer = pred[key_refer_id]
pd_refer = pred[key_refer_prediction]
gt_slot = pred[key_slot_groundtruth]
pd_slot = pred[key_slot_prediction]
gt_slot = tokenize(gt_slot)
pd_slot = tokenize(pd_slot)
# Make sure the true turn labels are contained in the prediction json file!
joint_gt_slot = gt_slot
if guid[-1] == '0': # First turn, reset the slots
joint_pd_slot = 'none'
# If turn_pd_class or a value to be copied is "none", do not update the dialog state.
if turn_pd_class == class_types.index('none'):
pass
elif turn_pd_class == class_types.index('dontcare'):
joint_pd_slot = 'dontcare'
elif turn_pd_class == class_types.index('copy_value'):
joint_pd_slot = pd_slot
elif 'true' in class_types and turn_pd_class == class_types.index('true'):
joint_pd_slot = 'true'
elif 'false' in class_types and turn_pd_class == class_types.index('false'):
joint_pd_slot = 'false'
elif 'refer' in class_types and turn_pd_class == class_types.index('refer'):
if pd_slot[0:3] == "§§ ":
if pd_slot[3:] != 'none':
joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[3:], label_maps)
elif pd_slot[0:2] == "§§":
if pd_slot[2:] != 'none':
joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:], label_maps)
elif pd_slot != 'none':
joint_pd_slot = pd_slot
elif 'inform' in class_types and turn_pd_class == class_types.index('inform'):
if pd_slot[0:3] == "§§ ":
if pd_slot[3:] != 'none':
joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[3:], label_maps)
elif pd_slot[0:2] == "§§":
if pd_slot[2:] != 'none':
joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:], label_maps)
else:
print("ERROR: Unexpected slot value format. Aborting.")
exit()
else:
print("ERROR: Unexpected class_type. Aborting.")
exit()
total_correct = True
# Check the per turn correctness of the class_type prediction
if turn_gt_class == turn_pd_class:
class_correctness[turn_gt_class].append(1.0)
class_correctness[-1].append(1.0)
c_tp[turn_gt_class] += 1
for cc in range(len(class_types)):
if cc != turn_gt_class:
c_tn[cc] += 1
# Only where there is a span do we check its per-turn correctness
if turn_gt_class == class_types.index('copy_value'):
if gt_start_pos == pd_start_pos and gt_end_pos == pd_end_pos:
pos_correctness.append(1.0)
else:
pos_correctness.append(0.0)
# Only where there is a referral do we check its per-turn correctness
if 'refer' in class_types and turn_gt_class == class_types.index('refer'):
if gt_refer == pd_refer:
refer_correctness.append(1.0)
print(" [%s] Correct referral: %s | %s" % (guid, gt_refer, pd_refer))
else:
refer_correctness.append(0.0)
print(" [%s] Incorrect referral: %s | %s" % (guid, gt_refer, pd_refer))
else:
if turn_gt_class == class_types.index('copy_value'):
pos_correctness.append(0.0)
if 'refer' in class_types and turn_gt_class == class_types.index('refer'):
refer_correctness.append(0.0)
class_correctness[turn_gt_class].append(0.0)
class_correctness[-1].append(0.0)
confusion_matrix[turn_gt_class][turn_pd_class].append(1.0)
c_fn[turn_gt_class] += 1
c_fp[turn_pd_class] += 1
# Check the joint slot correctness.
# If the value label is not none, then we need to have a value prediction.
# Even if the class_type is 'none', there can still be a value label,
# it might just not be pointable in the current turn. It might however
# be referrable and thus predicted correctly.
if joint_gt_slot == joint_pd_slot:
val_correctness.append(1.0)
elif joint_gt_slot != 'none' and joint_gt_slot != 'dontcare' and joint_gt_slot != 'true' and joint_gt_slot != 'false' and joint_gt_slot in label_maps:
no_match = True
for variant in label_maps[joint_gt_slot]:
if variant == joint_pd_slot:
no_match = False
break
if no_match:
val_correctness.append(0.0)
total_correct = False
print(" [%s] Incorrect value (variant): %s (turn class: %s) | %s (turn class: %s)" % (guid, joint_gt_slot, turn_gt_class, joint_pd_slot, turn_pd_class))
else:
val_correctness.append(1.0)
else:
val_correctness.append(0.0)
total_correct = False
print(" [%s] Incorrect value: %s (turn class: %s) | %s (turn class: %s)" % (guid, joint_gt_slot, turn_gt_class, joint_pd_slot, turn_pd_class))
total_correctness.append(1.0 if total_correct else 0.0)
# Account for empty lists (due to no instances of spans or referrals being seen)
if pos_correctness == []:
pos_correctness.append(1.0)
if refer_correctness == []:
refer_correctness.append(1.0)
for ct in range(len(class_types)):
if c_tp[ct] + c_fp[ct] > 0:
precision = c_tp[ct] / (c_tp[ct] + c_fp[ct])
else:
precision = 1.0
if c_tp[ct] + c_fn[ct] > 0:
recall = c_tp[ct] / (c_tp[ct] + c_fn[ct])
else:
recall = 1.0
if precision + recall > 0:
f1 = 2 * ((precision * recall) / (precision + recall))
else:
f1 = 1.0
if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0:
acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct])
else:
acc = 1.0
print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" %
(class_types[ct], ct, recall, np.sum(class_correctness[ct]), len(class_correctness[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct]))
print("Confusion matrix:")
for cl in range(len(class_types)):
print(" %s" % (cl), end="")
print("")
for cl_a in range(len(class_types)):
print("%s " % (cl_a), end="")
for cl_b in range(len(class_types)):
if len(class_correctness[cl_a]) > 0:
print("%.2f " % (np.sum(confusion_matrix[cl_a][cl_b]) / len(class_correctness[cl_a])), end="")
else:
print("---- ", end="")
print("")
return np.asarray(total_correctness), np.asarray(val_correctness), np.asarray(class_correctness), np.asarray(pos_correctness), np.asarray(refer_correctness), np.asarray(confusion_matrix), c_tp, c_tn, c_fp, c_fn
if __name__ == "__main__":
acc_list = []
acc_list_v = []
key_class_label_id = 'class_label_id_%s'
key_class_prediction = 'class_prediction_%s'
key_start_pos = 'start_pos_%s'
key_start_prediction = 'start_prediction_%s'
key_end_pos = 'end_pos_%s'
key_end_prediction = 'end_prediction_%s'
key_refer_id = 'refer_id_%s'
key_refer_prediction = 'refer_prediction_%s'
key_slot_groundtruth = 'slot_groundtruth_%s'
key_slot_prediction = 'slot_prediction_%s'
dataset = sys.argv[1].lower()
dataset_config = sys.argv[2].lower()
if dataset not in ['woz2', 'sim-m', 'sim-r', 'multiwoz21']:
raise ValueError("Task not found: %s" % (dataset))
class_types, slots, label_maps = load_dataset_config(dataset_config)
# Prepare label_maps
label_maps_tmp = {}
for v in label_maps:
label_maps_tmp[tokenize(v)] = [tokenize(nv) for nv in label_maps[v]]
label_maps = label_maps_tmp
for fp in sorted(glob.glob(sys.argv[3])):
print(fp)
goal_correctness = 1.0
cls_acc = [[] for cl in range(len(class_types))]
cls_conf = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))]
c_tp = {ct: 0 for ct in range(len(class_types))}
c_tn = {ct: 0 for ct in range(len(class_types))}
c_fp = {ct: 0 for ct in range(len(class_types))}
c_fn = {ct: 0 for ct in range(len(class_types))}
for slot in slots:
tot_cor, joint_val_cor, cls_cor, pos_cor, ref_cor, conf_mat, ctp, ctn, cfp, cfn = get_joint_slot_correctness(fp, class_types, label_maps,
key_class_label_id=(key_class_label_id % slot),
key_class_prediction=(key_class_prediction % slot),
key_start_pos=(key_start_pos % slot),
key_start_prediction=(key_start_prediction % slot),
key_end_pos=(key_end_pos % slot),
key_end_prediction=(key_end_prediction % slot),
key_refer_id=(key_refer_id % slot),
key_refer_prediction=(key_refer_prediction % slot),
key_slot_groundtruth=(key_slot_groundtruth % slot),
key_slot_prediction=(key_slot_prediction % slot)
)
print('%s: joint slot acc: %g, joint value acc: %g, turn class acc: %g, turn position acc: %g, turn referral acc: %g' %
(slot, np.mean(tot_cor), np.mean(joint_val_cor), np.mean(cls_cor[-1]), np.mean(pos_cor), np.mean(ref_cor)))
goal_correctness *= tot_cor
for cl_a in range(len(class_types)):
cls_acc[cl_a] += cls_cor[cl_a]
for cl_b in range(len(class_types)):
cls_conf[cl_a][cl_b] += list(conf_mat[cl_a][cl_b])
c_tp[cl_a] += ctp[cl_a]
c_tn[cl_a] += ctn[cl_a]
c_fp[cl_a] += cfp[cl_a]
c_fn[cl_a] += cfn[cl_a]
for ct in range(len(class_types)):
if c_tp[ct] + c_fp[ct] > 0:
precision = c_tp[ct] / (c_tp[ct] + c_fp[ct])
else:
precision = 1.0
if c_tp[ct] + c_fn[ct] > 0:
recall = c_tp[ct] / (c_tp[ct] + c_fn[ct])
else:
recall = 1.0
if precision + recall > 0:
f1 = 2 * ((precision * recall) / (precision + recall))
else:
f1 = 1.0
if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0:
acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct])
else:
acc = 1.0
print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" %
(class_types[ct], ct, recall, np.sum(cls_acc[ct]), len(cls_acc[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct]))
print("Confusion matrix:")
for cl in range(len(class_types)):
print(" %s" % (cl), end="")
print("")
for cl_a in range(len(class_types)):
print("%s " % (cl_a), end="")
for cl_b in range(len(class_types)):
if len(cls_acc[cl_a]) > 0:
print("%.2f " % (np.sum(cls_conf[cl_a][cl_b]) / len(cls_acc[cl_a])), end="")
else:
print("---- ", end="")
print("")
acc = np.mean(goal_correctness)
acc_list.append((fp, acc))
acc_list_s = sorted(acc_list, key=lambda tup: tup[1], reverse=True)
for (fp, acc) in acc_list_s:
print('Joint goal acc: %g, %s' % (acc, fp))
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
# Part of this code is based on the source code of Transformers
# (arXiv:1910.03771)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable)
from transformers.modeling_bert import (BertModel, BertPreTrainedModel, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
@add_start_docstrings(
"""BERT Model with a classification heads for the DST task. """,
BERT_START_DOCSTRING,
)
class BertForDST(BertPreTrainedModel):
def __init__(self, config):
super(BertForDST, self).__init__(config)
self.slot_list = config.dst_slot_list
self.class_types = config.dst_class_types
self.class_labels = config.dst_class_labels
self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
self.class_aux_feats_inform = config.dst_class_aux_feats_inform
self.class_aux_feats_ds = config.dst_class_aux_feats_ds
self.class_loss_ratio = config.dst_class_loss_ratio
# Only use refer loss if refer class is present in dataset.
if 'refer' in self.class_types:
self.refer_index = self.class_types.index('refer')
else:
self.refer_index = -1
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.dst_dropout_rate)
self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
if self.class_aux_feats_inform:
self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
if self.class_aux_feats_ds:
self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
for slot in self.slot_list:
self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
self.init_weights()
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def forward(self,
input_ids,
input_mask=None,
segment_ids=None,
position_ids=None,
head_mask=None,
start_pos=None,
end_pos=None,
inform_slot_id=None,
refer_id=None,
class_label_id=None,
diag_state=None):
outputs = self.bert(
input_ids,
attention_mask=input_mask,
token_type_ids=segment_ids,
position_ids=position_ids,
head_mask=head_mask
)
sequence_output = outputs[0]
pooled_output = outputs[1]
sequence_output = self.dropout(sequence_output)
pooled_output = self.dropout(pooled_output)
# TODO: establish proper format in labels already?
if inform_slot_id is not None:
inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
if diag_state is not None:
diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
total_loss = 0
per_slot_per_example_loss = {}
per_slot_class_logits = {}
per_slot_start_logits = {}
per_slot_end_logits = {}
per_slot_refer_logits = {}
for slot in self.slot_list:
if self.class_aux_feats_inform and self.class_aux_feats_ds:
pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
elif self.class_aux_feats_inform:
pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
elif self.class_aux_feats_ds:
pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
else:
pooled_output_aux = pooled_output
class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
start_logits, end_logits = token_logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
per_slot_class_logits[slot] = class_logits
per_slot_start_logits[slot] = start_logits
per_slot_end_logits[slot] = end_logits
per_slot_refer_logits[slot] = refer_logits
# If there are no labels, don't compute loss
if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
# If we are on multi-GPU, splitting adds a dimension
if len(start_pos[slot].size()) > 1:
start_pos[slot] = start_pos[slot].squeeze(-1)
if len(end_pos[slot].size()) > 1:
end_pos[slot] = end_pos[slot].squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) # This is a single index
start_pos[slot].clamp_(0, ignored_index)
end_pos[slot].clamp_(0, ignored_index)
class_loss_fct = CrossEntropyLoss(reduction='none')
token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
refer_loss_fct = CrossEntropyLoss(reduction='none')
start_loss = token_loss_fct(start_logits, start_pos[slot])
end_loss = token_loss_fct(end_logits, end_pos[slot])
token_loss = (start_loss + end_loss) / 2.0
token_is_pointable = (start_pos[slot] > 0).float()
if not self.token_loss_for_nonpointable:
token_loss *= token_is_pointable
refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
if not self.refer_loss_for_nonpointable:
refer_loss *= token_is_referrable
class_loss = class_loss_fct(class_logits, class_label_id[slot])
if self.refer_index > -1:
per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
else:
per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
total_loss += per_example_loss.sum()
per_slot_per_example_loss[slot] = per_example_loss
# add hidden states and attention if they are here
outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:]
return outputs
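# A rough usage sketch (hypothetical values; the dst_* attributes must be set
# on the config before instantiation, presumably by the training script):
#
#     from transformers import BertConfig
#     config = BertConfig.from_pretrained('bert-base-uncased')
#     config.dst_slot_list = ['area', 'food', 'price_range']
#     config.dst_class_types = ['none', 'dontcare', 'copy_value', 'inform']
#     config.dst_class_labels = 4
#     config.dst_class_loss_ratio = 0.8
#     config.dst_dropout_rate = 0.3
#     config.dst_heads_dropout_rate = 0.3
#     config.dst_token_loss_for_nonpointable = False
#     config.dst_refer_loss_for_nonpointable = False
#     config.dst_class_aux_feats_inform = True
#     config.dst_class_aux_feats_ds = True
#     model = BertForDST.from_pretrained('bert-base-uncased', config=config)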
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.utils.data import Dataset
class TensorListDataset(Dataset):
r"""Dataset wrapping tensors, tensor dicts and tensor lists.
Arguments:
*data (Tensor or dict or list of Tensors): tensors that have the same size
of the first dimension.
"""
def __init__(self, *data):
if isinstance(data[0], dict):
size = list(data[0].values())[0].size(0)
elif isinstance(data[0], list):
size = data[0][0].size(0)
else:
size = data[0].size(0)
for element in data:
if isinstance(element, dict):
assert all(size == tensor.size(0) for name, tensor in element.items()) # dict of tensors
elif isinstance(element, list):
assert all(size == tensor.size(0) for tensor in element) # list of tensors
else:
assert size == element.size(0) # tensor
self.size = size
self.data = data
def __getitem__(self, index):
result = []
for element in self.data:
if isinstance(element, dict):
result.append({k: v[index] for k, v in element.items()})
elif isinstance(element, list):
result.append([v[index] for v in element])
else:
result.append(element[index])
return tuple(result)
def __len__(self):
return self.size
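# Example usage (a minimal sketch with made-up tensors; every element must
# share the same first dimension):
#
#     import torch
#     input_ids = torch.zeros(8, 180, dtype=torch.long)
#     start_pos = {'food': torch.zeros(8, dtype=torch.long)}
#     dataset = TensorListDataset(input_ids, start_pos)
#     features, positions = dataset[0]  # tensor of shape (180,), {'food': tensor(0)}
#     len(dataset)  # 8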
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
# Part of this code is based on the source code of Transformers
# (arXiv:1910.03771)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import six
import numpy as np
import json
logger = logging.getLogger(__name__)
class DSTExample(object):
"""
A single training/test example for the DST dataset.
"""
def __init__(self,
guid,
text_a,
text_b,
history,
text_a_label=None,
text_b_label=None,
history_label=None,
values=None,
inform_label=None,
inform_slot_label=None,
refer_label=None,
diag_state=None,
class_label=None):
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.history = history
self.text_a_label = text_a_label
self.text_b_label = text_b_label
self.history_label = history_label
self.values = values
self.inform_label = inform_label
self.inform_slot_label = inform_slot_label
self.refer_label = refer_label
self.diag_state = diag_state
self.class_label = class_label
def __str__(self):
return self.__repr__()
def __repr__(self):
s = ""
s += "guid: %s" % (self.guid)
s += ", text_a: %s" % (self.text_a)
s += ", text_b: %s" % (self.text_b)
s += ", history: %s" % (self.history)
if self.text_a_label:
s += ", text_a_label: %s" % (self.text_a_label)
if self.text_b_label:
s += ", text_b_label: %s" % (self.text_b_label)
if self.history_label:
s += ", history_label: %s" % (self.history_label)
if self.values:
s += ", values: %s" % (self.values)
if self.inform_label:
s += ", inform_label: %s" % (self.inform_label)
if self.inform_slot_label:
s += ", inform_slot_label: %s" % (self.inform_slot_label)
if self.refer_label:
s += ", refer_label: %s" % (self.refer_label)
if self.diag_state:
s += ", diag_state: %s" % (self.diag_state)
if self.class_label:
s += ", class_label: %s" % (self.class_label)
return s
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
input_ids,
input_ids_unmasked,
input_mask,
segment_ids,
start_pos=None,
end_pos=None,
values=None,
inform=None,
inform_slot=None,
refer_id=None,
diag_state=None,
class_label_id=None,
guid="NONE"):
self.guid = guid
self.input_ids = input_ids
self.input_ids_unmasked = input_ids_unmasked
self.input_mask = input_mask
self.segment_ids = segment_ids
self.start_pos = start_pos
self.end_pos = end_pos
self.values = values
self.inform = inform
self.inform_slot = inform_slot
self.refer_id = refer_id
self.diag_state = diag_state
self.class_label_id = class_label_id
def convert_examples_to_features(examples, slot_list, class_types, model_type, tokenizer, max_seq_length, slot_value_dropout=0.0):
"""Loads a data file into a list of `InputBatch`s."""
if model_type == 'bert':
model_specs = {'MODEL_TYPE': 'bert',
'CLS_TOKEN': '[CLS]',
'UNK_TOKEN': '[UNK]',
'SEP_TOKEN': '[SEP]',
'TOKEN_CORRECTION': 4}
else:
logger.error("Unknown model type (%s). Aborting." % (model_type))
exit(1)
def _tokenize_text_and_label(text, text_label_dict, slot, tokenizer, model_specs, slot_value_dropout):
joint_text_label = [0 for _ in text_label_dict[slot]] # joint label over all slots
for slot_text_label in text_label_dict.values():
for idx, label in enumerate(slot_text_label):
if label == 1:
joint_text_label[idx] = 1
text_label = text_label_dict[slot]
tokens = []
tokens_unmasked = []
token_labels = []
for token, token_label, joint_label in zip(text, text_label, joint_text_label):
token = convert_to_unicode(token)
sub_tokens = tokenizer.tokenize(token) # Most time intensive step
tokens_unmasked.extend(sub_tokens)
if slot_value_dropout == 0.0 or joint_label == 0:
tokens.extend(sub_tokens)
else:
rn_list = np.random.random_sample((len(sub_tokens),))
for rn, sub_token in zip(rn_list, sub_tokens):
if rn > slot_value_dropout:
tokens.append(sub_token)
else:
tokens.append(model_specs['UNK_TOKEN'])
token_labels.extend([token_label for _ in sub_tokens])
assert len(tokens) == len(token_labels)
assert len(tokens_unmasked) == len(token_labels)
return tokens, tokens_unmasked, token_labels
def _truncate_seq_pair(tokens_a, tokens_b, history, max_length):
"""Truncates a sequence pair in place to the maximum length.
Copied from bert/run_classifier.py
"""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b) + len(history)
if total_length <= max_length:
break
if len(history) > 0:
history.pop()
elif len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, model_specs, guid):
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT)
if len(tokens_a) + len(tokens_b) + len(history) > max_seq_length - model_specs['TOKEN_CORRECTION']:
logger.info("Truncate Example %s. Total len=%d." % (guid, len(tokens_a) + len(tokens_b) + len(history)))
input_text_too_long = True
else:
input_text_too_long = False
_truncate_seq_pair(tokens_a, tokens_b, history, max_seq_length - model_specs['TOKEN_CORRECTION'])
return input_text_too_long
def _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs):
token_label_ids = []
token_label_ids.append(0) # [CLS]
for token_label in token_labels_a:
token_label_ids.append(token_label)
token_label_ids.append(0) # [SEP]
for token_label in token_labels_b:
token_label_ids.append(token_label)
token_label_ids.append(0) # [SEP]
for token_label in token_labels_history:
token_label_ids.append(token_label)
token_label_ids.append(0) # [SEP]
while len(token_label_ids) < max_seq_length:
token_label_ids.append(0) # padding
assert len(token_label_ids) == max_seq_length
return token_label_ids
def _get_start_end_pos(class_type, token_label_ids, max_seq_length):
if class_type == 'copy_value' and 1 not in token_label_ids:
#logger.warn("copy_value label, but token_label not detected. Setting label to 'none'.")
class_type = 'none'
start_pos = 0
end_pos = 0
if 1 in token_label_ids:
start_pos = token_label_ids.index(1)
# Parsing is supposed to find only the first location of the wanted value
if 0 not in token_label_ids[start_pos:]:
end_pos = len(token_label_ids[start_pos:]) + start_pos - 1
else:
end_pos = token_label_ids[start_pos:].index(0) + start_pos - 1
for i in range(max_seq_length):
if i >= start_pos and i <= end_pos:
assert token_label_ids[i] == 1
return class_type, start_pos, end_pos
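# e.g. token_label_ids == [0, 1, 1, 0, 0] yields start_pos=1, end_pos=2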
def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, tokenizer, model_specs):
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = []
segment_ids = []
tokens.append(model_specs['CLS_TOKEN'])
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append(model_specs['SEP_TOKEN'])
segment_ids.append(0)
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append(model_specs['SEP_TOKEN'])
segment_ids.append(1)
for token in history:
tokens.append(token)
segment_ids.append(1)
tokens.append(model_specs['SEP_TOKEN'])
segment_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
return tokens, input_ids, input_mask, segment_ids
total_cnt = 0
too_long_cnt = 0
refer_list = ['none'] + slot_list
features = []
# Convert each example
for (example_index, example) in enumerate(examples):
if example_index % 1000 == 0:
logger.info("Writing example %d of %d" % (example_index, len(examples)))
total_cnt += 1
value_dict = {}
inform_dict = {}
inform_slot_dict = {}
refer_id_dict = {}
diag_state_dict = {}
class_label_id_dict = {}
start_pos_dict = {}
end_pos_dict = {}
for slot in slot_list:
tokens_a, tokens_a_unmasked, token_labels_a = _tokenize_text_and_label(
example.text_a, example.text_a_label, slot, tokenizer, model_specs, slot_value_dropout)
tokens_b, tokens_b_unmasked, token_labels_b = _tokenize_text_and_label(
example.text_b, example.text_b_label, slot, tokenizer, model_specs, slot_value_dropout)
tokens_history, tokens_history_unmasked, token_labels_history = _tokenize_text_and_label(
example.history, example.history_label, slot, tokenizer, model_specs, slot_value_dropout)
input_text_too_long = _truncate_length_and_warn(
tokens_a, tokens_b, tokens_history, max_seq_length, model_specs, example.guid)
if input_text_too_long:
if example_index < 10:
if len(token_labels_a) > len(tokens_a):
logger.info(' tokens_a truncated labels: %s' % str(token_labels_a[len(tokens_a):]))
if len(token_labels_b) > len(tokens_b):
logger.info(' tokens_b truncated labels: %s' % str(token_labels_b[len(tokens_b):]))
if len(token_labels_history) > len(tokens_history):
logger.info(' tokens_history truncated labels: %s' % str(token_labels_history[len(tokens_history):]))
token_labels_a = token_labels_a[:len(tokens_a)]
token_labels_b = token_labels_b[:len(tokens_b)]
token_labels_history = token_labels_history[:len(tokens_history)]
tokens_a_unmasked = tokens_a_unmasked[:len(tokens_a)]
tokens_b_unmasked = tokens_b_unmasked[:len(tokens_b)]
tokens_history_unmasked = tokens_history_unmasked[:len(tokens_history)]
assert len(token_labels_a) == len(tokens_a)
assert len(token_labels_b) == len(tokens_b)
assert len(token_labels_history) == len(tokens_history)
assert len(token_labels_a) == len(tokens_a_unmasked)
assert len(token_labels_b) == len(tokens_b_unmasked)
assert len(token_labels_history) == len(tokens_history_unmasked)
token_label_ids = _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs)
value_dict[slot] = example.values[slot]
inform_dict[slot] = example.inform_label[slot]
class_label_mod, start_pos_dict[slot], end_pos_dict[slot] = _get_start_end_pos(
example.class_label[slot], token_label_ids, max_seq_length)
if class_label_mod != example.class_label[slot]:
example.class_label[slot] = class_label_mod
inform_slot_dict[slot] = example.inform_slot_label[slot]
refer_id_dict[slot] = refer_list.index(example.refer_label[slot])
diag_state_dict[slot] = class_types.index(example.diag_state[slot])
class_label_id_dict[slot] = class_types.index(example.class_label[slot])
if input_text_too_long:
too_long_cnt += 1
tokens, input_ids, input_mask, segment_ids = _get_transformer_input(tokens_a,
tokens_b,
tokens_history,
max_seq_length,
tokenizer,
model_specs)
if slot_value_dropout > 0.0:
_, input_ids_unmasked, _, _ = _get_transformer_input(tokens_a_unmasked,
tokens_b_unmasked,
tokens_history_unmasked,
max_seq_length,
tokenizer,
model_specs)
else:
input_ids_unmasked = input_ids
assert(len(input_ids) == len(input_ids_unmasked))
if example_index < 10:
logger.info("*** Example ***")
logger.info("guid: %s" % (example.guid))
logger.info("tokens: %s" % " ".join(tokens))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
logger.info("start_pos: %s" % str(start_pos_dict))
logger.info("end_pos: %s" % str(end_pos_dict))
logger.info("values: %s" % str(value_dict))
logger.info("inform: %s" % str(inform_dict))
logger.info("inform_slot: %s" % str(inform_slot_dict))
logger.info("refer_id: %s" % str(refer_id_dict))
logger.info("diag_state: %s" % str(diag_state_dict))
logger.info("class_label_id: %s" % str(class_label_id_dict))
features.append(
InputFeatures(
guid=example.guid,
input_ids=input_ids,
input_ids_unmasked=input_ids_unmasked,
input_mask=input_mask,
segment_ids=segment_ids,
start_pos=start_pos_dict,
end_pos=end_pos_dict,
values=value_dict,
inform=inform_dict,
inform_slot=inform_slot_dict,
refer_id=refer_id_dict,
diag_state=diag_state_dict,
class_label_id=class_label_id_dict))
logger.info("========== %d out of %d examples have text too long" % (too_long_cnt, total_cnt))
return features
# From bert.tokenization (TF code)
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")