Commit 2e2a8b87 authored by heck
initial
parent eea0be40
#!/bin/bash
# Parameters ------------------------------------------------------
#TASK="sim-m"
#DATA_DIR="data/simulated-dialogue/sim-M"
#TASK="sim-r"
#DATA_DIR="data/simulated-dialogue/sim-R"
TASK="woz2"
DATA_DIR="data/woz2"
#TASK="multiwoz21"
#DATA_DIR="data/MULTIWOZ2.1"
# Project paths etc. ----------------------------------------------
OUT_DIR=results
mkdir -p ${OUT_DIR}
# Main ------------------------------------------------------------
for step in train dev test; do
args_add=""
if [ "$step" = "train" ]; then
args_add="--do_train --predict_type=dummy"
elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
args_add="--do_eval --predict_type=${step}"
fi
python3 run_dst.py \
--task_name=${TASK} \
--data_dir=${DATA_DIR} \
--dataset_config=dataset_config/${TASK}.json \
--model_type="bert" \
--model_name_or_path="bert-base-uncased" \
--do_lower_case \
--learning_rate=1e-4 \
--num_train_epochs=10 \
--max_seq_length=180 \
--per_gpu_train_batch_size=48 \
--per_gpu_eval_batch_size=1 \
--output_dir=${OUT_DIR} \
--save_epochs=2 \
--logging_steps=10 \
--warmup_proportion=0.1 \
--eval_all_checkpoints \
--adam_epsilon=1e-6 \
--label_value_repetitions \
--swap_utterances \
--append_history \
--use_history_labels \
--delexicalize_sys_utts \
--class_aux_feats_inform \
--class_aux_feats_ds \
${args_add} \
2>&1 | tee ${OUT_DIR}/${step}.log
if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
python3 ~/tools/trippy_clean/metric_bert_dst.py \
${TASK} \
dataset_config/${TASK}.json \
"${OUT_DIR}/pred_res.${step}*json" \
2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
fi
done
#!/bin/bash
# Parameters ------------------------------------------------------
#TASK="sim-m"
#DATA_DIR="data/simulated-dialogue/sim-M"
#TASK="sim-r"
#DATA_DIR="data/simulated-dialogue/sim-R"
TASK="woz2"
DATA_DIR="data/woz2"
#TASK="multiwoz21"
#DATA_DIR="data/MULTIWOZ2.1"
# Project paths etc. ----------------------------------------------
OUT_DIR=results
mkdir -p ${OUT_DIR}
# Main ------------------------------------------------------------
for step in train dev test; do
args_add=""
if [ "$step" = "train" ]; then
args_add="--do_train --predict_type=dummy"
elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then
args_add="--do_eval --predict_type=${step}"
fi
python3 run_dst.py \
--task_name=${TASK} \
--data_dir=${DATA_DIR} \
--dataset_config=dataset_config/${TASK}.json \
--model_type="bert" \
--model_name_or_path="bert-base-uncased" \
--do_lower_case \
--learning_rate=1e-4 \
--num_train_epochs=10 \
--max_seq_length=180 \
--per_gpu_train_batch_size=48 \
--per_gpu_eval_batch_size=1 \
--output_dir=${OUT_DIR} \
--save_epochs=2 \
--logging_steps=10 \
--warmup_proportion=0.1 \
--eval_all_checkpoints \
--adam_epsilon=1e-6 \
--label_value_repetitions \
${args_add} \
2>&1 | tee ${OUT_DIR}/${step}.log
if [ "$step" = "dev" ] || [ "$step" = "test" ]; then
python3 ~/tools/trippy_clean/metric_bert_dst.py \
${TASK} \
dataset_config/${TASK}.json \
"${OUT_DIR}/pred_res.${step}*json" \
2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log
fi
done
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import json
import dataset_woz2
import dataset_sim
import dataset_multiwoz21
class DataProcessor(object):
def __init__(self, dataset_config):
with open(dataset_config, "r", encoding='utf-8') as f:
raw_config = json.load(f)
self.class_types = raw_config['class_types']
self.slot_list = raw_config['slots']
self.label_maps = raw_config['label_maps']
def get_train_examples(self, data_dir, **args):
raise NotImplementedError()
def get_dev_examples(self, data_dir, **args):
raise NotImplementedError()
def get_test_examples(self, data_dir, **args):
raise NotImplementedError()
class Woz2Processor(DataProcessor):
def get_train_examples(self, data_dir, args):
return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_train_en.json'),
'train', self.slot_list, self.label_maps, **args)
def get_dev_examples(self, data_dir, args):
return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_validate_en.json'),
'dev', self.slot_list, self.label_maps, **args)
def get_test_examples(self, data_dir, args):
return dataset_woz2.create_examples(os.path.join(data_dir, 'woz_test_en.json'),
'test', self.slot_list, self.label_maps, **args)
class Multiwoz21Processor(DataProcessor):
def get_train_examples(self, data_dir, args):
return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'train_dials.json'),
os.path.join(data_dir, 'dialogue_acts.json'),
'train', self.slot_list, self.label_maps, **args)
def get_dev_examples(self, data_dir, args):
return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'val_dials.json'),
os.path.join(data_dir, 'dialogue_acts.json'),
'dev', self.slot_list, self.label_maps, **args)
def get_test_examples(self, data_dir, args):
return dataset_multiwoz21.create_examples(os.path.join(data_dir, 'test_dials.json'),
os.path.join(data_dir, 'dialogue_acts.json'),
'test', self.slot_list, self.label_maps, **args)
class SimProcessor(DataProcessor):
def get_train_examples(self, data_dir, args):
return dataset_sim.create_examples(os.path.join(data_dir, 'train.json'),
'train', self.slot_list, **args)
def get_dev_examples(self, data_dir, args):
return dataset_sim.create_examples(os.path.join(data_dir, 'dev.json'),
'dev', self.slot_list, **args)
def get_test_examples(self, data_dir, args):
return dataset_sim.create_examples(os.path.join(data_dir, 'test.json'),
'test', self.slot_list, **args)
PROCESSORS = {"woz2": Woz2Processor,
"sim-m": SimProcessor,
"sim-r": SimProcessor,
"multiwoz21": Multiwoz21Processor}
{
"class_types": [
"none",
"dontcare",
"copy_value",
"inform"
],
"slots": [
"date",
"movie",
"time",
"num_tickets",
"theatre_name"
],
"label_maps": {}
}
{
"class_types": [
"none",
"dontcare",
"copy_value",
"inform"
],
"slots": [
"category",
"rating",
"num_people",
"location",
"restaurant_name",
"time",
"date",
"price_range",
"meal"
],
"label_maps": {}
}
{
"class_types": [
"none",
"dontcare",
"copy_value",
"inform"
],
"slots": [
"area",
"food",
"price_range"
],
"label_maps": {
"center": [
"centre",
"downtown",
"central",
"down town",
"middle"
],
"south": [
"southern",
"southside"
],
"north": [
"northern",
"uptown",
"northside"
],
"west": [
"western",
"westside"
],
"east": [
"eastern",
"eastside"
],
"east side": [
"eastern",
"eastside"
],
"cheap": [
"low price",
"inexpensive",
"cheaper",
"low priced",
"affordable",
"nothing too expensive",
"without costing a fortune",
"cheapest",
"good deals",
"low prices",
"afford",
"on a budget",
"fair prices",
"less expensive",
"cheapeast",
"not cost an arm and a leg"
],
"moderate": [
"moderately",
"medium priced",
"medium price",
"fair price",
"fair prices",
"reasonable",
"reasonably priced",
"mid price",
"fairly priced",
"not outrageous",
"not too expensive",
"on a budget",
"mid range",
"reasonable priced",
"less expensive",
"not too pricey",
"nothing too expensive",
"nothing cheap",
"not overpriced",
"medium",
"inexpensive"
],
"expensive": [
"high priced",
"high end",
"high class",
"high quality",
"fancy",
"upscale",
"nice",
"fine dining",
"expensively priced"
],
"afghan": [
"afghanistan"
],
"african": [
"africa"
],
"asian oriental": [
"asian",
"oriental"
],
"australasian": [
"australian asian",
"austral asian"
],
"australian": [
"aussie"
],
"barbeque": [
"barbecue",
"bbq"
],
"basque": [
"bask"
],
"belgian": [
"belgium"
],
"british": [
"cotto"
],
"canapes": [
"canopy",
"canape",
"canap"
],
"catalan": [
"catalonian"
],
"corsican": [
"corsica"
],
"crossover": [
"cross over",
"over"
],
"gastropub": [
"gastro pub",
"gastro",
"gastropubs"
],
"hungarian": [
"goulash"
],
"indian": [
"india",
"indians",
"nirala"
],
"international": [
"all types of food"
],
"italian": [
"prezzo"
],
"jamaican": [
"jamaica"
],
"japanese": [
"sushi",
"beni hana"
],
"korean": [
"korea"
],
"lebanese": [
"lebanse"
],
"north american": [
"american",
"hamburger"
],
"portuguese": [
"portugese"
],
"seafood": [
"sea food",
"shellfish",
"fish"
],
"singaporean": [
"singapore"
],
"steakhouse": [
"steak house",
"steak"
],
"thai": [
"thailand",
"bangkok"
],
"traditional": [
"old fashioned",
"plain"
],
"turkish": [
"turkey"
],
"unusual": [
"unique and strange"
],
"venetian": [
"vanessa"
],
"vietnamese": [
"vietnam",
"thanh binh"
]
}
}
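The label_maps above are consulted during span matching and evaluation: a slot value counts as present if either its canonical form or any listed surface variant appears in the utterance. A hedged sketch of that lookup (the names here are local to the example; the actual logic lives in check_label_existence in dataset_woz2.py below):

import json

with open('dataset_config/woz2.json', encoding='utf-8') as f:
    label_maps = json.load(f)['label_maps']

def variants(label):
    # The canonical label plus all listed surface variants.
    return [label] + label_maps.get(label, [])

print(variants('center'))  # ['center', 'centre', 'downtown', 'central', 'down town', 'middle']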
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from utils_dst import (DSTExample)
def dialogue_state_to_sv_dict(sv_list):
sv_dict = {}
for d in sv_list:
sv_dict[d['slot']] = d['value']
return sv_dict
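# Illustrative sketch (not part of the original file): the sim data's
# dialogue_state is a list of {'slot': ..., 'value': ...} dicts, which
# the helper above flattens into a plain mapping, e.g.:
_example_sv = dialogue_state_to_sv_dict([{'slot': 'date', 'value': 'tomorrow'}])
assert _example_sv == {'date': 'tomorrow'}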
def get_token_and_slot_label(turn):
if 'system_utterance' in turn:
sys_utt_tok = turn['system_utterance']['tokens']
sys_slot_label = turn['system_utterance']['slots']
else:
sys_utt_tok = []
sys_slot_label = []
usr_utt_tok = turn['user_utterance']['tokens']
usr_slot_label = turn['user_utterance']['slots']
return sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label
def get_tok_label(prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok,
sys_slot_label, usr_utt_tok, usr_slot_label, dial_id,
turn_id, slot_last_occurrence=True):
"""The position of the last occurrence of the slot value will be used."""
sys_utt_tok_label = [0 for _ in sys_utt_tok]
usr_utt_tok_label = [0 for _ in usr_utt_tok]
if slot_type not in cur_ds_dict:
class_type = 'none'
else:
value = cur_ds_dict[slot_type]
if value == 'dontcare' and (slot_type not in prev_ds_dict or prev_ds_dict[slot_type] != 'dontcare'):
# Only label dontcare at its first occurrence in the dialog
class_type = 'dontcare'
else: # If not none or dontcare, we have to identify whether
# there is a span, or if the value is informed
in_usr = False
in_sys = False
for label_d in usr_slot_label:
if label_d['slot'] == slot_type and value == ' '.join(
usr_utt_tok[label_d['start']:label_d['exclusive_end']]):
for idx in range(label_d['start'], label_d['exclusive_end']):
usr_utt_tok_label[idx] = 1
in_usr = True
class_type = 'copy_value'
if slot_last_occurrence:
break
if not in_usr or not slot_last_occurrence:
for label_d in sys_slot_label:
if label_d['slot'] == slot_type and value == ' '.join(
sys_utt_tok[label_d['start']:label_d['exclusive_end']]):
for idx in range(label_d['start'], label_d['exclusive_end']):
sys_utt_tok_label[idx] = 1
in_sys = True
class_type = 'inform'
if slot_last_occurrence:
break
if not in_usr and not in_sys:
assert sum(usr_utt_tok_label + sys_utt_tok_label) == 0
if (slot_type not in prev_ds_dict or value != prev_ds_dict[slot_type]):
raise ValueError('Copy value cannot be found in dialogue %s turn %s' % (str(dial_id), str(turn_id)))
else:
class_type = 'none'
else:
assert sum(usr_utt_tok_label + sys_utt_tok_label) > 0
return sys_utt_tok_label, usr_utt_tok_label, class_type
def delex_utt(utt, values):
utt_delex = utt.copy()
for v in values:
utt_delex[v['start']:v['exclusive_end']] = ['[UNK]'] * (v['exclusive_end'] - v['start'])
return utt_delex
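# Illustrative sketch (not part of the original file): delex_utt replaces
# span-labeled values token-by-token with [UNK]. The slot label format
# (start / exclusive_end token indices) follows the sim data's 'slots'
# entries, e.g.:
_example_delex = delex_utt(['i', 'want', '2', 'tickets'],
                           [{'slot': 'num_tickets', 'start': 2, 'exclusive_end': 3}])
assert _example_delex == ['i', 'want', '[UNK]', 'tickets']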
def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id,
delexicalize_sys_utts=False, slot_last_occurrence=True):
"""Make turn_label a dictionary of slot with value positions or being dontcare / none:
Turn label contains:
(1) the updates from previous to current dialogue state,
(2) values in current dialogue state explicitly mentioned in system or user utterance."""
prev_ds_dict = dialogue_state_to_sv_dict(prev_dialogue_state)
cur_ds_dict = dialogue_state_to_sv_dict(turn['dialogue_state'])
(sys_utt_tok, sys_slot_label, usr_utt_tok, usr_slot_label) = get_token_and_slot_label(turn)
sys_utt_tok_label_dict = {}
usr_utt_tok_label_dict = {}
inform_label_dict = {}
inform_slot_label_dict = {}
referral_label_dict = {}
class_type_dict = {}
for slot_type in slot_list:
inform_label_dict[slot_type] = 'none'
inform_slot_label_dict[slot_type] = 0
referral_label_dict[slot_type] = 'none' # Referral is not present in sim data
sys_utt_tok_label, usr_utt_tok_label, class_type = get_tok_label(
prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok, sys_slot_label,
usr_utt_tok, usr_slot_label, dial_id, turn_id,
slot_last_occurrence=slot_last_occurrence)
if sum(sys_utt_tok_label) > 0:
inform_label_dict[slot_type] = cur_ds_dict[slot_type]
inform_slot_label_dict[slot_type] = 1
sys_utt_tok_label = [0 for _ in sys_utt_tok_label] # Don't use token labels for sys utt
sys_utt_tok_label_dict[slot_type] = sys_utt_tok_label
usr_utt_tok_label_dict[slot_type] = usr_utt_tok_label
class_type_dict[slot_type] = class_type
if delexicalize_sys_utts:
sys_utt_tok = delex_utt(sys_utt_tok, sys_slot_label)
return (sys_utt_tok, sys_utt_tok_label_dict,
usr_utt_tok, usr_utt_tok_label_dict,
inform_label_dict, inform_slot_label_dict,
referral_label_dict, cur_ds_dict, class_type_dict)
def create_examples(input_file, set_type, slot_list,
label_maps={},
append_history=False,
use_history_labels=False,
swap_utterances=False,
label_value_repetitions=False,
delexicalize_sys_utts=False,
analyze=False):
"""Read a DST json file into a list of DSTExample."""
with open(input_file, "r", encoding='utf-8') as reader:
input_data = json.load(reader)
examples = []
for entry in input_data:
dial_id = entry['dialogue_id']
prev_ds = []
hst = []
prev_hst_lbl_dict = {slot: [] for slot in slot_list}
prev_ds_lbl_dict = {slot: 'none' for slot in slot_list}
for turn_id, turn in enumerate(entry['turns']):
guid = '%s-%s-%s' % (set_type, dial_id, str(turn_id))
ds_lbl_dict = prev_ds_lbl_dict.copy()
hst_lbl_dict = prev_hst_lbl_dict.copy()
(text_a,
text_a_label,
text_b,
text_b_label,
inform_label,
inform_slot_label,
referral_label,
cur_ds_dict,
class_label) = get_turn_label(turn,
prev_ds,
slot_list,
dial_id,
turn_id,
delexicalize_sys_utts=delexicalize_sys_utts,
slot_last_occurrence=True)
if swap_utterances:
txt_a = text_b
txt_b = text_a
txt_a_lbl = text_b_label
txt_b_lbl = text_a_label
else:
txt_a = text_a
txt_b = text_b
txt_a_lbl = text_a_label
txt_b_lbl = text_b_label
value_dict = {}
for slot in slot_list:
if slot in cur_ds_dict:
value_dict[slot] = cur_ds_dict[slot]
else:
value_dict[slot] = 'none'
if class_label[slot] != 'none':
ds_lbl_dict[slot] = class_label[slot]
if append_history:
if use_history_labels:
hst_lbl_dict[slot] = txt_a_lbl[slot] + txt_b_lbl[slot] + hst_lbl_dict[slot]
else:
hst_lbl_dict[slot] = [0 for _ in txt_a_lbl[slot] + txt_b_lbl[slot] + hst_lbl_dict[slot]]
examples.append(DSTExample(
guid=guid,
text_a=txt_a,
text_b=txt_b,
history=hst,
text_a_label=txt_a_lbl,
text_b_label=txt_b_lbl,
history_label=prev_hst_lbl_dict,
values=value_dict,
inform_label=inform_label,
inform_slot_label=inform_slot_label,
refer_label=referral_label,
diag_state=prev_ds_lbl_dict,
class_label=class_label))
prev_ds = turn['dialogue_state']
prev_ds_lbl_dict = ds_lbl_dict.copy()
prev_hst_lbl_dict = hst_lbl_dict.copy()
if append_history:
hst = txt_a + txt_b + hst
return examples
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from utils_dst import (DSTExample, convert_to_unicode)
LABEL_MAPS = {} # Loaded from file
LABEL_FIX = {'centre': 'center', 'areas': 'area', 'phone number': 'number', 'price range': 'price_range'}
def delex_utt(utt, values):
utt_norm = utt.copy()
for s, v in values.items():
if v != 'none':
v_norm = tokenize(v)
v_len = len(v_norm)
for i in range(len(utt_norm) + 1 - v_len):
if utt_norm[i:i + v_len] == v_norm:
utt_norm[i:i + v_len] = ['[UNK]'] * v_len
return utt_norm
def get_token_pos(tok_list, label):
find_pos = []
found = False
label_list = [item for item in map(str.strip, re.split(r"(\W+)", label)) if len(item) > 0]
len_label = len(label_list)
for i in range(len(tok_list) + 1 - len_label):
if tok_list[i:i + len_label] == label_list:
find_pos.append((i, i + len_label)) # start, exclusive_end
found = True
return found, find_pos
def check_label_existence(label, usr_utt_tok, sys_utt_tok):
in_usr, usr_pos = get_token_pos(usr_utt_tok, label)
if not in_usr and label in LABEL_MAPS:
for tmp_label in LABEL_MAPS[label]:
in_usr, usr_pos = get_token_pos(usr_utt_tok, tmp_label)
if in_usr:
break
in_sys, sys_pos = get_token_pos(sys_utt_tok, label)
if not in_sys and label in LABEL_MAPS:
for tmp_label in LABEL_MAPS[label]:
in_sys, sys_pos = get_token_pos(sys_utt_tok, tmp_label)
if in_sys:
break
return in_usr, usr_pos, in_sys, sys_pos
def get_turn_label(label, sys_utt_tok, usr_utt_tok, slot_last_occurrence):
usr_utt_tok_label = [0 for _ in usr_utt_tok]
if label == 'none' or label == 'dontcare':
class_type = label
else:
in_usr, usr_pos, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok)
if in_usr:
class_type = 'copy_value'
if slot_last_occurrence:
(s, e) = usr_pos[-1]
for i in range(s, e):
usr_utt_tok_label[i] = 1
else:
for (s, e) in usr_pos:
for i in range(s, e):
usr_utt_tok_label[i] = 1
elif in_sys:
class_type = 'inform'
else:
class_type = 'unpointable'
return usr_utt_tok_label, class_type
def tokenize(utt):
utt_lower = convert_to_unicode(utt).lower()
utt_tok = [tok for tok in map(str.strip, re.split(r"(\W+)", utt_lower)) if len(tok) > 0]
return utt_tok
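# Illustrative sketch (not part of the original file): tokenize lower-cases
# the utterance and splits on non-word characters; get_token_pos then
# locates a label phrase as a token sub-sequence, returning
# (start, exclusive_end) pairs.
_example_toks = tokenize("I'd like Thai food, please.")
# -> ['i', "'", 'd', 'like', 'thai', 'food', ',', 'please', '.']
assert get_token_pos(_example_toks, 'thai food') == (True, [(4, 6)])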
def create_examples(input_file, set_type, slot_list,
label_maps={},
append_history=False,
use_history_labels=False,
swap_utterances=False,
label_value_repetitions=False,
delexicalize_sys_utts=False,
analyze=False):
"""Read a DST json file into a list of DSTExample."""
with open(input_file, "r", encoding='utf-8') as reader:
input_data = json.load(reader)
global LABEL_MAPS
LABEL_MAPS = label_maps
examples = []
for entry in input_data:
diag_seen_slots_dict = {}
diag_seen_slots_value_dict = {slot: 'none' for slot in slot_list}
diag_state = {slot: 'none' for slot in slot_list}
sys_utt_tok = []
sys_utt_tok_delex = []
usr_utt_tok = []
hst_utt_tok = []
hst_utt_tok_label_dict = {slot: [] for slot in slot_list}
for turn in entry['dialogue']:
sys_utt_tok_label_dict = {}
usr_utt_tok_label_dict = {}
inform_dict = {slot: 'none' for slot in slot_list}
inform_slot_dict = {slot: 0 for slot in slot_list}
referral_dict = {}
class_type_dict = {}
# Collect turn data
if append_history:
if swap_utterances:
if delexicalize_sys_utts:
hst_utt_tok = usr_utt_tok + sys_utt_tok_delex + hst_utt_tok
else:
hst_utt_tok = usr_utt_tok + sys_utt_tok + hst_utt_tok
else:
if delexicalize_sys_utts:
hst_utt_tok = sys_utt_tok_delex + usr_utt_tok + hst_utt_tok
else:
hst_utt_tok = sys_utt_tok + usr_utt_tok + hst_utt_tok
sys_utt_tok = tokenize(turn['system_transcript'])
usr_utt_tok = tokenize(turn['transcript'])
turn_label = {LABEL_FIX.get(s.strip(), s.strip()): LABEL_FIX.get(v.strip(), v.strip()) for s, v in turn['turn_label']}
guid = '%s-%s-%s' % (set_type, str(entry['dialogue_idx']), str(turn['turn_idx']))
# Create delexicalized sys utterances.
if delexicalize_sys_utts:
delex_dict = {}
for slot in slot_list:
delex_dict[slot] = 'none'
label = 'none'
if slot in turn_label:
label = turn_label[slot]
elif label_value_repetitions and slot in diag_seen_slots_dict:
label = diag_seen_slots_value_dict[slot]
if label != 'none' and label != 'dontcare':
_, _, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok)
if in_sys:
delex_dict[slot] = label
sys_utt_tok_delex = delex_utt(sys_utt_tok, delex_dict)
new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy()
new_diag_state = diag_state.copy()
for slot in slot_list:
label = 'none'
if slot in turn_label:
label = turn_label[slot]
elif label_value_repetitions and slot in diag_seen_slots_dict:
label = diag_seen_slots_value_dict[slot]
(usr_utt_tok_label,
class_type) = get_turn_label(label,
sys_utt_tok,
usr_utt_tok,
slot_last_occurrence=True)
if class_type == 'inform':
inform_dict[slot] = label
if label != 'none':
inform_slot_dict[slot] = 1
referral_dict[slot] = 'none' # Referral is not present in woz2 data
# Generally, don't use span prediction on the system utterance (use inform prediction instead).
if delexicalize_sys_utts:
sys_utt_tok_label = [0 for _ in sys_utt_tok_delex]
else:
sys_utt_tok_label = [0 for _ in sys_utt_tok]
# Determine what to do with value repetitions.
# If value is unique in seen slots, then tag it, otherwise not,
# since correct slot assignment cannot be guaranteed anymore.
if label_value_repetitions and slot in diag_seen_slots_dict:
if class_type == 'copy_value' and list(diag_seen_slots_value_dict.values()).count(label) > 1:
class_type = 'none'
usr_utt_tok_label = [0 for _ in usr_utt_tok_label]
sys_utt_tok_label_dict[slot] = sys_utt_tok_label
usr_utt_tok_label_dict[slot] = usr_utt_tok_label
if append_history:
if use_history_labels:
if swap_utterances:
new_hst_utt_tok_label_dict[slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[slot]
else:
new_hst_utt_tok_label_dict[slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]
else:
new_hst_utt_tok_label_dict[slot] = [0 for _ in sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[slot]]
# For now, we map all occurrences of unpointable slot values
# to none. However, since the labels will still suggest
# the presence of unpointable slot values, the task of the
# DST is still to find those values. It is just not
# possible to do that via span prediction on the current input.
if class_type == 'unpointable':
class_type_dict[slot] = 'none'
elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[slot] and class_type != 'copy_value' and class_type != 'inform':
# If the slot has been seen before and its class type did not change, label this slot as not present,
# assuming that the slot has not actually been mentioned in this turn.
# Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform,
# this must mean there is evidence in the original labels, therefore consider
# it as mentioned again.
class_type_dict[slot] = 'none'
referral_dict[slot] = 'none'
else:
class_type_dict[slot] = class_type
# Remember that this slot was mentioned during this dialog already.
if class_type != 'none':
diag_seen_slots_dict[slot] = class_type
diag_seen_slots_value_dict[slot] = label
new_diag_state[slot] = class_type
# Unpointable is not a valid class, therefore replace with
# some valid class for now...
if class_type == 'unpointable':
new_diag_state[slot] = 'copy_value'
if swap_utterances:
txt_a = usr_utt_tok
if delexicalize_sys_utts:
txt_b = sys_utt_tok_delex
else:
txt_b = sys_utt_tok
txt_a_lbl = usr_utt_tok_label_dict
txt_b_lbl = sys_utt_tok_label_dict
else:
if delexicalize_sys_utts:
txt_a = sys_utt_tok_delex
else:
txt_a = sys_utt_tok
txt_b = usr_utt_tok
txt_a_lbl = sys_utt_tok_label_dict
txt_b_lbl = usr_utt_tok_label_dict
examples.append(DSTExample(
guid=guid,
text_a=txt_a,
text_b=txt_b,
history=hst_utt_tok,
text_a_label=txt_a_lbl,
text_b_label=txt_b_lbl,
history_label=hst_utt_tok_label_dict,
values=diag_seen_slots_value_dict.copy(),
inform_label=inform_dict,
inform_slot_label=inform_slot_dict,
refer_label=referral_dict,
diag_state=diag_state,
class_label=class_type_dict))
# Update some variables.
hst_utt_tok_label_dict = new_hst_utt_tok_label_dict.copy()
diag_state = new_diag_state.copy()
return examples
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import json
import sys
import numpy as np
import re
def load_dataset_config(dataset_config):
with open(dataset_config, "r", encoding='utf-8') as f:
raw_config = json.load(f)
return raw_config['class_types'], raw_config['slots'], raw_config['label_maps']
def tokenize(text):
if "\u0120" in text:
text = re.sub(" ", "", text)
text = re.sub("\u0120", " ", text)
text = text.strip()
return ' '.join([tok for tok in map(str.strip, re.split(r"(\W+)", text)) if len(tok) > 0])
def is_in_list(tok, value):
found = False
tok_list = [item for item in map(str.strip, re.split(r"(\W+)", tok)) if len(item) > 0]
value_list = [item for item in map(str.strip, re.split(r"(\W+)", value)) if len(item) > 0]
tok_len = len(tok_list)
value_len = len(value_list)
for i in range(tok_len + 1 - value_len):
if tok_list[i:i + value_len] == value_list:
found = True
break
return found
def check_slot_inform(value_label, inform_label, label_maps):
value = inform_label
if value_label == inform_label:
value = value_label
elif is_in_list(inform_label, value_label):
value = value_label
elif is_in_list(value_label, inform_label):
value = value_label
elif inform_label in label_maps:
for inform_label_variant in label_maps[inform_label]:
if value_label == inform_label_variant:
value = value_label
break
elif is_in_list(inform_label_variant, value_label):
value = value_label
break
elif is_in_list(value_label, inform_label_variant):
value = value_label
break
elif value_label in label_maps:
for value_label_variant in label_maps[value_label]:
if value_label_variant == inform_label:
value = value_label
break
elif is_in_list(inform_label, value_label_variant):
value = value_label
break
elif is_in_list(value_label_variant, inform_label):
value = value_label
break
return value
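# Illustrative sketch (not part of the original file): an informed value
# matches the ground truth if it is equal, contained, or a listed variant;
# on a match, check_slot_inform returns the ground-truth form, otherwise
# the informed value is kept as-is.
_example_maps = {'center': ['centre', 'downtown']}
assert check_slot_inform('center', 'centre', _example_maps) == 'center'
assert check_slot_inform('center', 'north', _example_maps) == 'north'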
def get_joint_slot_correctness(fp, class_types, label_maps,
key_class_label_id='class_label_id',
key_class_prediction='class_prediction',
key_start_pos='start_pos',
key_start_prediction='start_prediction',
key_end_pos='end_pos',
key_end_prediction='end_prediction',
key_refer_id='refer_id',
key_refer_prediction='refer_prediction',
key_slot_groundtruth='slot_groundtruth',
key_slot_prediction='slot_prediction'):
with open(fp) as f:
preds = json.load(f)
class_correctness = [[] for cl in range(len(class_types) + 1)]
confusion_matrix = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))]
pos_correctness = []
refer_correctness = []
val_correctness = []
total_correctness = []
c_tp = {ct: 0 for ct in range(len(class_types))}
c_tn = {ct: 0 for ct in range(len(class_types))}
c_fp = {ct: 0 for ct in range(len(class_types))}
c_fn = {ct: 0 for ct in range(len(class_types))}
for pred in preds:
guid = pred['guid'] # List: set_type, dialogue_idx, turn_idx
turn_gt_class = pred[key_class_label_id]
turn_pd_class = pred[key_class_prediction]
gt_start_pos = pred[key_start_pos]
pd_start_pos = pred[key_start_prediction]
gt_end_pos = pred[key_end_pos]
pd_end_pos = pred[key_end_prediction]
gt_refer = pred[key_refer_id]
pd_refer = pred[key_refer_prediction]
gt_slot = pred[key_slot_groundtruth]
pd_slot = pred[key_slot_prediction]
gt_slot = tokenize(gt_slot)
pd_slot = tokenize(pd_slot)
# Make sure the true turn labels are contained in the prediction json file!
joint_gt_slot = gt_slot
if guid[-1] == '0': # First turn, reset the slots
joint_pd_slot = 'none'
# If turn_pd_class or a value to be copied is "none", do not update the dialog state.
if turn_pd_class == class_types.index('none'):
pass
elif turn_pd_class == class_types.index('dontcare'):
joint_pd_slot = 'dontcare'
elif turn_pd_class == class_types.index('copy_value'):
joint_pd_slot = pd_slot
elif 'true' in class_types and turn_pd_class == class_types.index('true'):
joint_pd_slot = 'true'
elif 'false' in class_types and turn_pd_class == class_types.index('false'):
joint_pd_slot = 'false'
elif 'refer' in class_types and turn_pd_class == class_types.index('refer'):
if pd_slot[0:3] == "§§ ":
if pd_slot[3:] != 'none':
joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[3:], label_maps)
elif pd_slot[0:2] == "§§":
if pd_slot[2:] != 'none':
joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:], label_maps)
elif pd_slot != 'none':
joint_pd_slot = pd_slot
elif 'inform' in class_types and turn_pd_class == class_types.index('inform'):
if pd_slot[0:3] == "§§ ":
if pd_slot[3:] != 'none':
joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[3:], label_maps)
elif pd_slot[0:2] == "§§":
if pd_slot[2:] != 'none':
joint_pd_slot = check_slot_inform(joint_gt_slot, pd_slot[2:], label_maps)
else:
print("ERROR: Unexpected slot value format. Aborting.")
exit()
else:
print("ERROR: Unexpected class_type. Aborting.")
exit()
total_correct = True
# Check the per turn correctness of the class_type prediction
if turn_gt_class == turn_pd_class:
class_correctness[turn_gt_class].append(1.0)
class_correctness[-1].append(1.0)
c_tp[turn_gt_class] += 1
for cc in range(len(class_types)):
if cc != turn_gt_class:
c_tn[cc] += 1
# Only where there is a span do we check its per-turn correctness
if turn_gt_class == class_types.index('copy_value'):
if gt_start_pos == pd_start_pos and gt_end_pos == pd_end_pos:
pos_correctness.append(1.0)
else:
pos_correctness.append(0.0)
# Only where there is a referral do we check its per-turn correctness
if 'refer' in class_types and turn_gt_class == class_types.index('refer'):
if gt_refer == pd_refer:
refer_correctness.append(1.0)
print(" [%s] Correct referral: %s | %s" % (guid, gt_refer, pd_refer))
else:
refer_correctness.append(0.0)
print(" [%s] Incorrect referral: %s | %s" % (guid, gt_refer, pd_refer))
else:
if turn_gt_class == class_types.index('copy_value'):
pos_correctness.append(0.0)
if 'refer' in class_types and turn_gt_class == class_types.index('refer'):
refer_correctness.append(0.0)
class_correctness[turn_gt_class].append(0.0)
class_correctness[-1].append(0.0)
confusion_matrix[turn_gt_class][turn_pd_class].append(1.0)
c_fn[turn_gt_class] += 1
c_fp[turn_pd_class] += 1
# Check the joint slot correctness.
# If the value label is not none, then we need to have a value prediction.
# Even if the class_type is 'none', there can still be a value label;
# it might just not be pointable in the current turn. It might, however,
# be referrable and thus predicted correctly.
if joint_gt_slot == joint_pd_slot:
val_correctness.append(1.0)
elif joint_gt_slot != 'none' and joint_gt_slot != 'dontcare' and joint_gt_slot != 'true' and joint_gt_slot != 'false' and joint_gt_slot in label_maps:
no_match = True
for variant in label_maps[joint_gt_slot]:
if variant == joint_pd_slot:
no_match = False
break
if no_match:
val_correctness.append(0.0)
total_correct = False
print(" [%s] Incorrect value (variant): %s (turn class: %s) | %s (turn class: %s)" % (guid, joint_gt_slot, turn_gt_class, joint_pd_slot, turn_pd_class))
else:
val_correctness.append(1.0)
else:
val_correctness.append(0.0)
total_correct = False
print(" [%s] Incorrect value: %s (turn class: %s) | %s (turn class: %s)" % (guid, joint_gt_slot, turn_gt_class, joint_pd_slot, turn_pd_class))
total_correctness.append(1.0 if total_correct else 0.0)
# Account for empty lists (due to no instances of spans or referrals being seen)
if pos_correctness == []:
pos_correctness.append(1.0)
if refer_correctness == []:
refer_correctness.append(1.0)
for ct in range(len(class_types)):
if c_tp[ct] + c_fp[ct] > 0:
precision = c_tp[ct] / (c_tp[ct] + c_fp[ct])
else:
precision = 1.0
if c_tp[ct] + c_fn[ct] > 0:
recall = c_tp[ct] / (c_tp[ct] + c_fn[ct])
else:
recall = 1.0
if precision + recall > 0:
f1 = 2 * ((precision * recall) / (precision + recall))
else:
f1 = 1.0
if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0:
acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct])
else:
acc = 1.0
print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" %
(class_types[ct], ct, recall, np.sum(class_correctness[ct]), len(class_correctness[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct]))
print("Confusion matrix:")
for cl in range(len(class_types)):
print(" %s" % (cl), end="")
print("")
for cl_a in range(len(class_types)):
print("%s " % (cl_a), end="")
for cl_b in range(len(class_types)):
if len(class_correctness[cl_a]) > 0:
print("%.2f " % (np.sum(confusion_matrix[cl_a][cl_b]) / len(class_correctness[cl_a])), end="")
else:
print("---- ", end="")
print("")
return np.asarray(total_correctness), np.asarray(val_correctness), np.asarray(class_correctness), np.asarray(pos_correctness), np.asarray(refer_correctness), np.asarray(confusion_matrix), c_tp, c_tn, c_fp, c_fn
if __name__ == "__main__":
acc_list = []
acc_list_v = []
key_class_label_id = 'class_label_id_%s'
key_class_prediction = 'class_prediction_%s'
key_start_pos = 'start_pos_%s'
key_start_prediction = 'start_prediction_%s'
key_end_pos = 'end_pos_%s'
key_end_prediction = 'end_prediction_%s'
key_refer_id = 'refer_id_%s'
key_refer_prediction = 'refer_prediction_%s'
key_slot_groundtruth = 'slot_groundtruth_%s'
key_slot_prediction = 'slot_prediction_%s'
dataset = sys.argv[1].lower()
dataset_config = sys.argv[2].lower()
if dataset not in ['woz2', 'sim-m', 'sim-r', 'multiwoz21']:
raise ValueError("Task not found: %s" % (dataset))
class_types, slots, label_maps = load_dataset_config(dataset_config)
# Prepare label_maps
label_maps_tmp = {}
for v in label_maps:
label_maps_tmp[tokenize(v)] = [tokenize(nv) for nv in label_maps[v]]
label_maps = label_maps_tmp
for fp in sorted(glob.glob(sys.argv[3])):
print(fp)
goal_correctness = 1.0
cls_acc = [[] for cl in range(len(class_types))]
cls_conf = [[[] for cl_b in range(len(class_types))] for cl_a in range(len(class_types))]
c_tp = {ct: 0 for ct in range(len(class_types))}
c_tn = {ct: 0 for ct in range(len(class_types))}
c_fp = {ct: 0 for ct in range(len(class_types))}
c_fn = {ct: 0 for ct in range(len(class_types))}
for slot in slots:
tot_cor, joint_val_cor, cls_cor, pos_cor, ref_cor, conf_mat, ctp, ctn, cfp, cfn = get_joint_slot_correctness(fp, class_types, label_maps,
key_class_label_id=(key_class_label_id % slot),
key_class_prediction=(key_class_prediction % slot),
key_start_pos=(key_start_pos % slot),
key_start_prediction=(key_start_prediction % slot),
key_end_pos=(key_end_pos % slot),
key_end_prediction=(key_end_prediction % slot),
key_refer_id=(key_refer_id % slot),
key_refer_prediction=(key_refer_prediction % slot),
key_slot_groundtruth=(key_slot_groundtruth % slot),
key_slot_prediction=(key_slot_prediction % slot)
)
print('%s: joint slot acc: %g, joint value acc: %g, turn class acc: %g, turn position acc: %g, turn referral acc: %g' %
(slot, np.mean(tot_cor), np.mean(joint_val_cor), np.mean(cls_cor[-1]), np.mean(pos_cor), np.mean(ref_cor)))
goal_correctness *= tot_cor
for cl_a in range(len(class_types)):
cls_acc[cl_a] += cls_cor[cl_a]
for cl_b in range(len(class_types)):
cls_conf[cl_a][cl_b] += list(conf_mat[cl_a][cl_b])
c_tp[cl_a] += ctp[cl_a]
c_tn[cl_a] += ctn[cl_a]
c_fp[cl_a] += cfp[cl_a]
c_fn[cl_a] += cfn[cl_a]
for ct in range(len(class_types)):
if c_tp[ct] + c_fp[ct] > 0:
precision = c_tp[ct] / (c_tp[ct] + c_fp[ct])
else:
precision = 1.0
if c_tp[ct] + c_fn[ct] > 0:
recall = c_tp[ct] / (c_tp[ct] + c_fn[ct])
else:
recall = 1.0
if precision + recall > 0:
f1 = 2 * ((precision * recall) / (precision + recall))
else:
f1 = 1.0
if c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct] > 0:
acc = (c_tp[ct] + c_tn[ct]) / (c_tp[ct] + c_tn[ct] + c_fp[ct] + c_fn[ct])
else:
acc = 1.0
print("Performance for class '%s' (%s): Recall: %.2f (%d of %d), Precision: %.2f, F1: %.2f, Accuracy: %.2f (TP/TN/FP/FN: %d/%d/%d/%d)" %
(class_types[ct], ct, recall, np.sum(cls_acc[ct]), len(cls_acc[ct]), precision, f1, acc, c_tp[ct], c_tn[ct], c_fp[ct], c_fn[ct]))
print("Confusion matrix:")
for cl in range(len(class_types)):
print(" %s" % (cl), end="")
print("")
for cl_a in range(len(class_types)):
print("%s " % (cl_a), end="")
for cl_b in range(len(class_types)):
if len(cls_acc[cl_a]) > 0:
print("%.2f " % (np.sum(cls_conf[cl_a][cl_b]) / len(cls_acc[cl_a])), end="")
else:
print("---- ", end="")
print("")
acc = np.mean(goal_correctness)
acc_list.append((fp, acc))
acc_list_s = sorted(acc_list, key=lambda tup: tup[1], reverse=True)
for (fp, acc) in acc_list_s:
print('Joint goal acc: %g, %s' % (acc, fp))
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
# Part of this code is based on the source code of Transformers
# (arXiv:1910.03771)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable)
from transformers.modeling_bert import (BertModel, BertPreTrainedModel, BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
@add_start_docstrings(
"""BERT Model with a classification heads for the DST task. """,
BERT_START_DOCSTRING,
)
class BertForDST(BertPreTrainedModel):
def __init__(self, config):
super(BertForDST, self).__init__(config)
self.slot_list = config.dst_slot_list
self.class_types = config.dst_class_types
self.class_labels = config.dst_class_labels
self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
self.class_aux_feats_inform = config.dst_class_aux_feats_inform
self.class_aux_feats_ds = config.dst_class_aux_feats_ds
self.class_loss_ratio = config.dst_class_loss_ratio
# Only use refer loss if refer class is present in dataset.
if 'refer' in self.class_types:
self.refer_index = self.class_types.index('refer')
else:
self.refer_index = -1
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.dst_dropout_rate)
self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate)
if self.class_aux_feats_inform:
self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
if self.class_aux_feats_ds:
self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list)))
aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2
for slot in self.slot_list:
self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels))
self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2))
self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1))
self.init_weights()
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
def forward(self,
input_ids,
input_mask=None,
segment_ids=None,
position_ids=None,
head_mask=None,
start_pos=None,
end_pos=None,
inform_slot_id=None,
refer_id=None,
class_label_id=None,
diag_state=None):
outputs = self.bert(
input_ids,
attention_mask=input_mask,
token_type_ids=segment_ids,
position_ids=position_ids,
head_mask=head_mask
)
sequence_output = outputs[0]
pooled_output = outputs[1]
sequence_output = self.dropout(sequence_output)
pooled_output = self.dropout(pooled_output)
# TODO: establish proper format in labels already?
if inform_slot_id is not None:
inform_labels = torch.stack(list(inform_slot_id.values()), 1).float()
if diag_state is not None:
diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0)
total_loss = 0
per_slot_per_example_loss = {}
per_slot_class_logits = {}
per_slot_start_logits = {}
per_slot_end_logits = {}
per_slot_refer_logits = {}
for slot in self.slot_list:
if self.class_aux_feats_inform and self.class_aux_feats_ds:
pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1)
elif self.class_aux_feats_inform:
pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1)
elif self.class_aux_feats_ds:
pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1)
else:
pooled_output_aux = pooled_output
class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux))
token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output))
start_logits, end_logits = token_logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux))
per_slot_class_logits[slot] = class_logits
per_slot_start_logits[slot] = start_logits
per_slot_end_logits[slot] = end_logits
per_slot_refer_logits[slot] = refer_logits
# If there are no labels, don't compute loss
if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None:
# If we are on multi-GPU, splitting adds a dimension
if len(start_pos[slot].size()) > 1:
start_pos[slot] = start_pos[slot].squeeze(-1)
if len(end_pos[slot].size()) > 1:
end_pos[slot] = end_pos[slot].squeeze(-1)
# Sometimes the start/end positions are outside our model inputs; we ignore these terms
ignored_index = start_logits.size(1) # This is a single index
start_pos[slot].clamp_(0, ignored_index)
end_pos[slot].clamp_(0, ignored_index)
class_loss_fct = CrossEntropyLoss(reduction='none')
token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index)
refer_loss_fct = CrossEntropyLoss(reduction='none')
start_loss = token_loss_fct(start_logits, start_pos[slot])
end_loss = token_loss_fct(end_logits, end_pos[slot])
token_loss = (start_loss + end_loss) / 2.0
token_is_pointable = (start_pos[slot] > 0).float()
if not self.token_loss_for_nonpointable:
token_loss *= token_is_pointable
refer_loss = refer_loss_fct(refer_logits, refer_id[slot])
token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float()
if not self.refer_loss_for_nonpointable:
refer_loss *= token_is_referrable
class_loss = class_loss_fct(class_logits, class_label_id[slot])
if self.refer_index > -1:
per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss
else:
per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss
total_loss += per_example_loss.sum()
per_slot_per_example_loss[slot] = per_example_loss
# add hidden states and attention if they are here
outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:]
return outputs
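A hedged instantiation sketch: the dst_* attributes below are exactly those read in __init__ above, but the concrete values are illustrative assumptions, not defaults shipped with this commit, and the transformers API is assumed to be the version this file imports from:

from transformers import BertConfig

config = BertConfig.from_pretrained('bert-base-uncased')
config.dst_slot_list = ['area', 'food', 'price_range']
config.dst_class_types = ['none', 'dontcare', 'copy_value', 'inform']
config.dst_class_labels = len(config.dst_class_types)
config.dst_token_loss_for_nonpointable = False
config.dst_refer_loss_for_nonpointable = False
config.dst_class_aux_feats_inform = True
config.dst_class_aux_feats_ds = True
config.dst_class_loss_ratio = 0.8
config.dst_dropout_rate = 0.3
config.dst_heads_dropout_rate = 0.3
model = BertForDST.from_pretrained('bert-base-uncased', config=config)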
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch.utils.data import Dataset
class TensorListDataset(Dataset):
r"""Dataset wrapping tensors, tensor dicts and tensor lists.
Arguments:
*data (Tensor or dict or list of Tensors): tensors that have the same size
of the first dimension.
"""
def __init__(self, *data):
if isinstance(data[0], dict):
size = list(data[0].values())[0].size(0)
elif isinstance(data[0], list):
size = data[0][0].size(0)
else:
size = data[0].size(0)
for element in data:
if isinstance(element, dict):
assert all(size == tensor.size(0) for name, tensor in element.items()) # dict of tensors
elif isinstance(element, list):
assert all(size == tensor.size(0) for tensor in element) # list of tensors
else:
assert size == element.size(0) # tensor
self.size = size
self.data = data
def __getitem__(self, index):
result = []
for element in self.data:
if isinstance(element, dict):
result.append({k: v[index] for k, v in element.items()})
elif isinstance(element, list):
result.append([v[index] for v in element])  # materialize as a list, not a generator
else:
result.append(element[index])
return tuple(result)
def __len__(self):
return self.size
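A hedged usage sketch: plain tensors and dicts of tensors can be mixed, as long as their first dimensions agree, and the default DataLoader collation handles the resulting tuples (shapes below are illustrative):

import torch
from torch.utils.data import DataLoader

input_ids = torch.zeros(4, 180, dtype=torch.long)
start_pos = {'area': torch.zeros(4, dtype=torch.long),
             'food': torch.zeros(4, dtype=torch.long)}
dataset = TensorListDataset(input_ids, start_pos)
batch = next(iter(DataLoader(dataset, batch_size=2)))
print(batch[0].shape)          # torch.Size([2, 180])
print(batch[1]['area'].shape)  # torch.Size([2])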
# coding=utf-8
#
# Copyright 2020 Heinrich Heine University Duesseldorf
#
# Part of this code is based on the source code of BERT-DST
# (arXiv:1907.03040)
# Part of this code is based on the source code of Transformers
# (arXiv:1910.03771)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import six
import numpy as np
import json
logger = logging.getLogger(__name__)
class DSTExample(object):
"""
A single training/test example for the DST dataset.
"""
def __init__(self,
guid,
text_a,
text_b,
history,
text_a_label=None,
text_b_label=None,
history_label=None,
values=None,
inform_label=None,
inform_slot_label=None,
refer_label=None,
diag_state=None,
class_label=None):
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.history = history
self.text_a_label = text_a_label
self.text_b_label = text_b_label
self.history_label = history_label
self.values = values
self.inform_label = inform_label
self.inform_slot_label = inform_slot_label
self.refer_label = refer_label
self.diag_state = diag_state
self.class_label = class_label
def __str__(self):
return self.__repr__()
def __repr__(self):
s = ""
s += "guid: %s" % (self.guid)
s += ", text_a: %s" % (self.text_a)
s += ", text_b: %s" % (self.text_b)
s += ", history: %s" % (self.history)
if self.text_a_label:
s += ", text_a_label: %s" % (self.text_a_label)  # dict-valued fields need %s, not %d
if self.text_b_label:
s += ", text_b_label: %s" % (self.text_b_label)
if self.history_label:
s += ", history_label: %s" % (self.history_label)
if self.values:
s += ", values: %s" % (self.values)
if self.inform_label:
s += ", inform_label: %s" % (self.inform_label)
if self.inform_slot_label:
s += ", inform_slot_label: %s" % (self.inform_slot_label)
if self.refer_label:
s += ", refer_label: %s" % (self.refer_label)
if self.diag_state:
s += ", diag_state: %s" % (self.diag_state)
if self.class_label:
s += ", class_label: %s" % (self.class_label)
return s
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
input_ids,
input_ids_unmasked,
input_mask,
segment_ids,
start_pos=None,
end_pos=None,
values=None,
inform=None,
inform_slot=None,
refer_id=None,
diag_state=None,
class_label_id=None,
guid="NONE"):
self.guid = guid
self.input_ids = input_ids
self.input_ids_unmasked = input_ids_unmasked
self.input_mask = input_mask
self.segment_ids = segment_ids
self.start_pos = start_pos
self.end_pos = end_pos
self.values = values
self.inform = inform
self.inform_slot = inform_slot
self.refer_id = refer_id
self.diag_state = diag_state
self.class_label_id = class_label_id
def convert_examples_to_features(examples, slot_list, class_types, model_type, tokenizer, max_seq_length, slot_value_dropout=0.0):
"""Loads a data file into a list of `InputBatch`s."""
if model_type == 'bert':
model_specs = {'MODEL_TYPE': 'bert',
'CLS_TOKEN': '[CLS]',
'UNK_TOKEN': '[UNK]',
'SEP_TOKEN': '[SEP]',
'TOKEN_CORRECTION': 4}
else:
logger.error("Unknown model type (%s). Aborting." % (model_type))
exit(1)
def _tokenize_text_and_label(text, text_label_dict, slot, tokenizer, model_specs, slot_value_dropout):
joint_text_label = [0 for _ in text_label_dict[slot]]  # union of all slots' token labels
for slot_text_label in text_label_dict.values():
for idx, label in enumerate(slot_text_label):
if label == 1:
joint_text_label[idx] = 1
text_label = text_label_dict[slot]
tokens = []
tokens_unmasked = []
token_labels = []
for token, token_label, joint_label in zip(text, text_label, joint_text_label):
token = convert_to_unicode(token)
sub_tokens = tokenizer.tokenize(token) # Most time intensive step
tokens_unmasked.extend(sub_tokens)
if slot_value_dropout == 0.0 or joint_label == 0:
tokens.extend(sub_tokens)
else:
rn_list = np.random.random_sample((len(sub_tokens),))
for rn, sub_token in zip(rn_list, sub_tokens):
if rn > slot_value_dropout:
tokens.append(sub_token)
else:
tokens.append(model_specs['UNK_TOKEN'])
token_labels.extend([token_label for _ in sub_tokens])
assert len(tokens) == len(token_labels)
assert len(tokens_unmasked) == len(token_labels)
return tokens, tokens_unmasked, token_labels
def _truncate_seq_pair(tokens_a, tokens_b, history, max_length):
"""Truncates a sequence pair in place to the maximum length.
Copied from bert/run_classifier.py
"""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b) + len(history)
if total_length <= max_length:
break
if len(history) > 0:
history.pop()
elif len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def _truncate_length_and_warn(tokens_a, tokens_b, history, max_seq_length, model_specs, guid):
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT)
if len(tokens_a) + len(tokens_b) + len(history) > max_seq_length - model_specs['TOKEN_CORRECTION']:
logger.info("Truncate Example %s. Total len=%d." % (guid, len(tokens_a) + len(tokens_b) + len(history)))
input_text_too_long = True
else:
input_text_too_long = False
_truncate_seq_pair(tokens_a, tokens_b, history, max_seq_length - model_specs['TOKEN_CORRECTION'])
return input_text_too_long
def _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs):
token_label_ids = []
token_label_ids.append(0) # [CLS]
for token_label in token_labels_a:
token_label_ids.append(token_label)
token_label_ids.append(0) # [SEP]
for token_label in token_labels_b:
token_label_ids.append(token_label)
token_label_ids.append(0) # [SEP]
for token_label in token_labels_history:
token_label_ids.append(token_label)
token_label_ids.append(0) # [SEP]
while len(token_label_ids) < max_seq_length:
token_label_ids.append(0) # padding
assert len(token_label_ids) == max_seq_length
return token_label_ids
def _get_start_end_pos(class_type, token_label_ids, max_seq_length):
if class_type == 'copy_value' and 1 not in token_label_ids:
#logger.warn("copy_value label, but token_label not detected. Setting label to 'none'.")
class_type = 'none'
start_pos = 0
end_pos = 0
if 1 in token_label_ids:
start_pos = token_label_ids.index(1)
# Parsing is supposed to find only the first location of the wanted value
if 0 not in token_label_ids[start_pos:]:
end_pos = len(token_label_ids[start_pos:]) + start_pos - 1
else:
end_pos = token_label_ids[start_pos:].index(0) + start_pos - 1
for i in range(max_seq_length):
if i >= start_pos and i <= end_pos:
assert token_label_ids[i] == 1
return class_type, start_pos, end_pos
def _get_transformer_input(tokens_a, tokens_b, history, max_seq_length, tokenizer, model_specs):
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = []
segment_ids = []
tokens.append(model_specs['CLS_TOKEN'])
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append(model_specs['SEP_TOKEN'])
segment_ids.append(0)
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append(model_specs['SEP_TOKEN'])
segment_ids.append(1)
for token in history:
tokens.append(token)
segment_ids.append(1)
tokens.append(model_specs['SEP_TOKEN'])
segment_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
return tokens, input_ids, input_mask, segment_ids
total_cnt = 0
too_long_cnt = 0
refer_list = ['none'] + slot_list
features = []
# Convert single example
for (example_index, example) in enumerate(examples):
if example_index % 1000 == 0:
logger.info("Writing example %d of %d" % (example_index, len(examples)))
total_cnt += 1
value_dict = {}
inform_dict = {}
inform_slot_dict = {}
refer_id_dict = {}
diag_state_dict = {}
class_label_id_dict = {}
start_pos_dict = {}
end_pos_dict = {}
for slot in slot_list:
tokens_a, tokens_a_unmasked, token_labels_a = _tokenize_text_and_label(
example.text_a, example.text_a_label, slot, tokenizer, model_specs, slot_value_dropout)
tokens_b, tokens_b_unmasked, token_labels_b = _tokenize_text_and_label(
example.text_b, example.text_b_label, slot, tokenizer, model_specs, slot_value_dropout)
tokens_history, tokens_history_unmasked, token_labels_history = _tokenize_text_and_label(
example.history, example.history_label, slot, tokenizer, model_specs, slot_value_dropout)
input_text_too_long = _truncate_length_and_warn(
tokens_a, tokens_b, tokens_history, max_seq_length, model_specs, example.guid)
if input_text_too_long:
if example_index < 10:
if len(token_labels_a) > len(tokens_a):
logger.info(' tokens_a truncated labels: %s' % str(token_labels_a[len(tokens_a):]))
if len(token_labels_b) > len(tokens_b):
logger.info(' tokens_b truncated labels: %s' % str(token_labels_b[len(tokens_b):]))
if len(token_labels_history) > len(tokens_history):
logger.info(' tokens_history truncated labels: %s' % str(token_labels_history[len(tokens_history):]))
token_labels_a = token_labels_a[:len(tokens_a)]
token_labels_b = token_labels_b[:len(tokens_b)]
token_labels_history = token_labels_history[:len(tokens_history)]
tokens_a_unmasked = tokens_a_unmasked[:len(tokens_a)]
tokens_b_unmasked = tokens_b_unmasked[:len(tokens_b)]
tokens_history_unmasked = tokens_history_unmasked[:len(tokens_history)]
assert len(token_labels_a) == len(tokens_a)
assert len(token_labels_b) == len(tokens_b)
assert len(token_labels_history) == len(tokens_history)
assert len(token_labels_a) == len(tokens_a_unmasked)
assert len(token_labels_b) == len(tokens_b_unmasked)
assert len(token_labels_history) == len(tokens_history_unmasked)
token_label_ids = _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs)
value_dict[slot] = example.values[slot]
inform_dict[slot] = example.inform_label[slot]
class_label_mod, start_pos_dict[slot], end_pos_dict[slot] = _get_start_end_pos(
example.class_label[slot], token_label_ids, max_seq_length)
if class_label_mod != example.class_label[slot]:
example.class_label[slot] = class_label_mod
inform_slot_dict[slot] = example.inform_slot_label[slot]
refer_id_dict[slot] = refer_list.index(example.refer_label[slot])
diag_state_dict[slot] = class_types.index(example.diag_state[slot])
class_label_id_dict[slot] = class_types.index(example.class_label[slot])
if input_text_too_long:
too_long_cnt += 1
tokens, input_ids, input_mask, segment_ids = _get_transformer_input(tokens_a,
tokens_b,
tokens_history,
max_seq_length,
tokenizer,
model_specs)
if slot_value_dropout > 0.0:
_, input_ids_unmasked, _, _ = _get_transformer_input(tokens_a_unmasked,
tokens_b_unmasked,
tokens_history_unmasked,
max_seq_length,
tokenizer,
model_specs)
else:
input_ids_unmasked = input_ids
assert len(input_ids) == len(input_ids_unmasked)
if example_index < 10:
logger.info("*** Example ***")
logger.info("guid: %s" % (example.guid))
logger.info("tokens: %s" % " ".join(tokens))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
logger.info("start_pos: %s" % str(start_pos_dict))
logger.info("end_pos: %s" % str(end_pos_dict))
logger.info("values: %s" % str(value_dict))
logger.info("inform: %s" % str(inform_dict))
logger.info("inform_slot: %s" % str(inform_slot_dict))
logger.info("refer_id: %s" % str(refer_id_dict))
logger.info("diag_state: %s" % str(diag_state_dict))
logger.info("class_label_id: %s" % str(class_label_id_dict))
features.append(
InputFeatures(
guid=example.guid,
input_ids=input_ids,
input_ids_unmasked=input_ids_unmasked,
input_mask=input_mask,
segment_ids=segment_ids,
start_pos=start_pos_dict,
end_pos=end_pos_dict,
values=value_dict,
inform=inform_dict,
inform_slot=inform_slot_dict,
refer_id=refer_id_dict,
diag_state=diag_state_dict,
class_label_id=class_label_id_dict))
logger.info("========== %d out of %d examples have text too long" % (too_long_cnt, total_cnt))
return features
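def _example_featurize(examples, slot_list, class_types):
    # Illustrative wiring sketch (not part of the original file), assuming
    # examples, slot_list and class_types come from a DataProcessor as in
    # data_processors.py: featurize the examples with a BERT tokenizer and
    # stack one field into a tensor.
    import torch
    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
    features = convert_examples_to_features(examples, slot_list, class_types,
                                            'bert', tokenizer, max_seq_length=180)
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    return features, all_input_ids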
# From bert.tokenization (TF code)
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")