diff --git a/DO.example.mtl b/DO.example.mtl new file mode 100755 index 0000000000000000000000000000000000000000..7963eb61ad23d47e3a1a9edfabbfee4d59dc8e63 --- /dev/null +++ b/DO.example.mtl @@ -0,0 +1,74 @@ +#!/bin/bash + +# Parameters ------------------------------------------------------ + +#TASK="sim-m" +#DATA_DIR="data/simulated-dialogue/sim-M" +#TASK="sim-r" +#DATA_DIR="data/simulated-dialogue/sim-R" +#TASK="woz2" +#DATA_DIR="data/woz2" +TASK="multiwoz21" +DATA_DIR="data/MULTIWOZ2.1" + +AUX_TASK="cola" +AUX_DATA_DIR="data/aux/roberta_base_cased_lower" + +# Project paths etc. ---------------------------------------------- + +OUT_DIR=results +mkdir -p ${OUT_DIR} + +# Main ------------------------------------------------------------ + +for step in train dev test; do + args_add="" + if [ "$step" = "train" ]; then + args_add="--do_train --predict_type=dummy" + elif [ "$step" = "dev" ] || [ "$step" = "test" ]; then + args_add="--do_eval --predict_type=${step}" + fi + + python3 run_dst_mtl.py \ + --task_name=${TASK} \ + --data_dir=${DATA_DIR} \ + --dataset_config=dataset_config/${TASK}.json \ + --model_type="roberta" \ + --model_name_or_path="roberta-base" \ + --do_lower_case \ + --learning_rate=1e-4 \ + --num_train_epochs=10 \ + --max_seq_length=180 \ + --per_gpu_train_batch_size=48 \ + --per_gpu_eval_batch_size=1 \ + --output_dir=${OUT_DIR} \ + --save_epochs=2 \ + --logging_steps=10 \ + --warmup_proportion=0.1 \ + --eval_all_checkpoints \ + --adam_epsilon=1e-6 \ + --weight_decay=0.01 \ + --heads_dropout=0.1 \ + --label_value_repetitions \ + --swap_utterances \ + --append_history \ + --use_history_labels \ + --delexicalize_sys_utts \ + --class_aux_feats_inform \ + --class_aux_feats_ds \ + --mtl_use \ + --mtl_task_def=dataset_config/aux_task_def.json \ + --mtl_train_dataset=${AUX_TASK} \ + --mtl_data_dir=${AUX_DATA_DIR} \ + --mtl_ratio=0.7 \ + ${args_add} \ + 2>&1 | tee ${OUT_DIR}/${step}.log + + if [ "$step" = "dev" ] || [ "$step" = "test" ]; then + python3 metric_bert_dst.py \ + ${TASK} \ + dataset_config/${TASK}.json \ + "${OUT_DIR}/pred_res.${step}*json" \ + 2>&1 | tee ${OUT_DIR}/eval_pred_${step}.log + fi +done diff --git a/README.md b/README.md index 04489ca48fcd2a68113dbca98b0ca86fe0c1797f..b90e4deda21bdced28dc5f97fe6c32340a92e2b2 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,10 @@ ## *** For readers of "Out-of-Task Training for Dialog State Tracking Models" *** -The code update to cover the content of the paper (https://arxiv.org/abs/2011.09379) will be available here after our presentation at COLING 2020 +The first version of the MTL code is available now. `DO.example.mtl` will train a model with MTL using an auxiliary task. As of now, pre-tokenized data is loaded for the auxiliary tasks. The next update will also include tokenization of the original data. 
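+
+For example, a run with the provided RoBERTa auxiliary data might look like this (assuming the archive unpacks to the `data/aux/roberta_base_cased_lower` directory that `AUX_DATA_DIR` in `DO.example.mtl` points to):
+
+```bash
+# Unpack the pre-tokenized auxiliary task data shipped in data/aux/
+# (assumption: the archive contains the roberta_base_cased_lower directory used by the script)
+tar -xzf data/aux/roberta_base_cased_lower.tar.gz -C data/aux/
+
+# Train with the script defaults (TASK=multiwoz21, AUX_TASK=cola), then evaluate on dev and test
+bash DO.example.mtl
+```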
+ +The paper is available here: +https://www.aclweb.org/anthology/2020.coling-main.596 +https://arxiv.org/abs/2011.09379 ## Introduction diff --git a/data/aux/bert_base_uncased_lower.tar.gz b/data/aux/bert_base_uncased_lower.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..4f91ff72f72ca534f644e0dfda03f4fda09157d4 Binary files /dev/null and b/data/aux/bert_base_uncased_lower.tar.gz differ diff --git a/data/aux/roberta_base_cased_lower.tar.gz b/data/aux/roberta_base_cased_lower.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..c9270c264070c66e4e4153a2be30ba8ad55f39d3 Binary files /dev/null and b/data/aux/roberta_base_cased_lower.tar.gz differ diff --git a/data_processors.py b/data_processors.py index 6d7670df0b5a0734f05499fab24bf4a137edb652..88d94ae3c3b8687080571533e5f2b40f3386b3cd 100644 --- a/data_processors.py +++ b/data_processors.py @@ -23,6 +23,7 @@ import json import dataset_woz2 import dataset_sim import dataset_multiwoz21 +import dataset_aux_task class DataProcessor(object): @@ -88,7 +89,14 @@ class SimProcessor(DataProcessor): 'test', self.slot_list, **args) +class AuxTaskProcessor(object): + def get_aux_task_examples(self, data_dir, data_name, max_seq_length): + file_path = os.path.join(data_dir, '{}_train.json'.format(data_name)) + return dataset_aux_task.create_examples(file_path, max_seq_length) + + PROCESSORS = {"woz2": Woz2Processor, "sim-m": SimProcessor, "sim-r": SimProcessor, - "multiwoz21": Multiwoz21Processor} + "multiwoz21": Multiwoz21Processor, + "aux_task": AuxTaskProcessor} diff --git a/dataset_aux_task.py b/dataset_aux_task.py new file mode 100644 index 0000000000000000000000000000000000000000..40a8667640a30342c0498243730d8c8f000d5774 --- /dev/null +++ b/dataset_aux_task.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
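+# Note: this module loads pre-tokenized auxiliary task data. AuxTaskProcessor in
+# data_processors.py reads "<task>_train.json" from --mtl_data_dir and calls
+# create_examples() below, which expects one JSON-encoded sample per line, drops
+# samples whose 'token_id' list exceeds the maximum sequence length, and otherwise
+# passes the sample through unchanged to convert_aux_examples_to_features().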
+ +import json + + +def create_examples(path, maxlen=512): + examples = [] + with open(path, 'r', encoding='utf-8') as reader: + cnt = 0 + for line in reader: + sample = json.loads(line) + if len(sample['token_id']) > maxlen: + continue + cnt += 1 + examples.append(sample) + print('Loaded {} samples out of {}'.format(len(examples), cnt)) + return examples diff --git a/dataset_config/aux_task_def.json b/dataset_config/aux_task_def.json new file mode 100644 index 0000000000000000000000000000000000000000..dd92bce417d71a105bfbcece370d0cf45dfb6f42 --- /dev/null +++ b/dataset_config/aux_task_def.json @@ -0,0 +1,46 @@ +{ + "cola": { + "n_class": 2, + "task_type": "classification" + }, + "mnli": { + "n_class": 3, + "task_type": "classification" + }, + "mrpc": { + "n_class": 2, + "task_type": "classification" + }, + "qnli": { + "n_class": 2, + "task_type": "classification" + }, + "qqp": { + "n_class": 2, + "task_type": "classification" + }, + "rte": { + "n_class": 2, + "task_type": "classification" + }, + "sst": { + "n_class": 2, + "task_type": "classification" + }, + "stsb": { + "n_class": 1, + "task_type": "regression" + }, + "wnli": { + "n_class": 2, + "task_type": "classification" + }, + "squad": { + "n_class": 2, + "task_type": "span" + }, + "squad-v2": { + "n_class": 2, + "task_type": "span" + } +} diff --git a/dataset_multiwoz21.py b/dataset_multiwoz21.py index e514638415db99af164190fa555e0fe993c69cab..92bdd2c57ad778134af66476dfac1d963df9c81b 100644 --- a/dataset_multiwoz21.py +++ b/dataset_multiwoz21.py @@ -203,7 +203,7 @@ def is_in_list(tok, value): return found -def delex_utt(utt, values): +def delex_utt(utt, values, unk_token="[UNK]"): utt_norm = tokenize(utt) for s, vals in values.items(): for v in vals: @@ -212,7 +212,7 @@ def delex_utt(utt, values): v_len = len(v_norm) for i in range(len(utt_norm) + 1 - v_len): if utt_norm[i:i + v_len] == v_norm: - utt_norm[i:i + v_len] = ['[UNK]'] * v_len + utt_norm[i:i + v_len] = [unk_token] * v_len return utt_norm @@ -300,6 +300,7 @@ def create_examples(input_file, acts_file, set_type, slot_list, swap_utterances=False, label_value_repetitions=False, delexicalize_sys_utts=False, + unk_token="[UNK]", analyze=False): """Read a DST json file into a list of DSTExample.""" @@ -343,7 +344,7 @@ def create_examples(input_file, acts_file, set_type, slot_list, for slot in slot_list: if (str(dialog_id), str(turn_itr), slot) in sys_inform_dict: inform_dict[slot] = sys_inform_dict[(str(dialog_id), str(turn_itr), slot)] - utt_tok_list.append(delex_utt(utt['text'], inform_dict)) # normalize utterances + utt_tok_list.append(delex_utt(utt['text'], inform_dict, unk_token)) # normalize utterances else: utt_tok_list.append(tokenize(utt['text'])) # normalize utterances diff --git a/dataset_sim.py b/dataset_sim.py index 575a29f0da6746efb963283de479e8119de39cba..e8caa3b1e486c581e3bfcfc205be4b384d71a305 100644 --- a/dataset_sim.py +++ b/dataset_sim.py @@ -114,15 +114,15 @@ def get_tok_label(prev_ds_dict, cur_ds_dict, slot_type, sys_utt_tok, return sys_utt_tok_label, usr_utt_tok_label, class_type -def delex_utt(utt, values): +def delex_utt(utt, values, unk_token="[UNK]"): utt_delex = utt.copy() for v in values: - utt_delex[v['start']:v['exclusive_end']] = ['[UNK]'] * (v['exclusive_end'] - v['start']) + utt_delex[v['start']:v['exclusive_end']] = [unk_token] * (v['exclusive_end'] - v['start']) return utt_delex def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, sys_inform_dict, - delexicalize_sys_utts=False, slot_last_occurrence=True): + 
delexicalize_sys_utts=False, unk_token="[UNK]", slot_last_occurrence=True): """Make turn_label a dictionary of slot with value positions or being dontcare / none: Turn label contains: (1) the updates from previous to current dialogue state, @@ -157,7 +157,7 @@ def get_turn_label(turn, prev_dialogue_state, slot_list, dial_id, turn_id, sys_i class_type_dict[slot_type] = class_type if delexicalize_sys_utts: - sys_utt_tok = delex_utt(sys_utt_tok, sys_slot_label) + sys_utt_tok = delex_utt(sys_utt_tok, sys_slot_label, unk_token) return (sys_utt_tok, sys_utt_tok_label_dict, usr_utt_tok, usr_utt_tok_label_dict, @@ -172,6 +172,7 @@ def create_examples(input_file, set_type, slot_list, swap_utterances=False, label_value_repetitions=False, delexicalize_sys_utts=False, + unk_token="[UNK]", analyze=False): """Read a DST json file into a list of DSTExample.""" @@ -207,6 +208,7 @@ def create_examples(input_file, set_type, slot_list, turn_id, sys_inform_dict, delexicalize_sys_utts=delexicalize_sys_utts, + unk_token=unk_token, slot_last_occurrence=True) if swap_utterances: diff --git a/dataset_woz2.py b/dataset_woz2.py index ef0a1efbaf98888f3103602e9210148ab835b03c..3c605d8bd7d4c3a95db1965c846e5ea11bbdfc78 100644 --- a/dataset_woz2.py +++ b/dataset_woz2.py @@ -27,7 +27,7 @@ LABEL_MAPS = {} # Loaded from file LABEL_FIX = {'centre': 'center', 'areas': 'area', 'phone number': 'number', 'price range': 'price_range'} -def delex_utt(utt, values): +def delex_utt(utt, values, unk_token="[UNK]"): utt_norm = utt.copy() for s, v in values.items(): if v != 'none': @@ -35,7 +35,7 @@ def delex_utt(utt, values): v_len = len(v_norm) for i in range(len(utt_norm) + 1 - v_len): if utt_norm[i:i + v_len] == v_norm: - utt_norm[i:i + v_len] = ['[UNK]'] * v_len + utt_norm[i:i + v_len] = [unk_token] * v_len return utt_norm @@ -104,6 +104,7 @@ def create_examples(input_file, set_type, slot_list, swap_utterances=False, label_value_repetitions=False, delexicalize_sys_utts=False, + unk_token="[UNK]", analyze=False): """Read a DST json file into a list of DSTExample.""" with open(input_file, "r", encoding='utf-8') as reader: @@ -163,7 +164,7 @@ def create_examples(input_file, set_type, slot_list, _, _, in_sys, _ = check_label_existence(label, usr_utt_tok, sys_utt_tok) if in_sys: delex_dict[slot] = label - sys_utt_tok_delex = delex_utt(sys_utt_tok, delex_dict) + sys_utt_tok_delex = delex_utt(sys_utt_tok, delex_dict, unk_token) new_hst_utt_tok_label_dict = hst_utt_tok_label_dict.copy() new_diag_state = diag_state.copy() @@ -177,9 +178,9 @@ def create_examples(input_file, set_type, slot_list, (usr_utt_tok_label, class_type, is_informed) = get_turn_label(label, - sys_utt_tok, - usr_utt_tok, - slot_last_occurrence=True) + sys_utt_tok, + usr_utt_tok, + slot_last_occurrence=True) if class_type == 'inform': inform_dict[slot] = label diff --git a/modeling_bert_dst.py b/modeling_bert_dst.py index dc7206c3f0380417292289f2fe1909777c6cb418..e0ccc7d0539e63fe6397ad2c5fa41ac7401a4332 100644 --- a/modeling_bert_dst.py +++ b/modeling_bert_dst.py @@ -65,6 +65,10 @@ class BertForDST(BertPreTrainedModel): self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2)) self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1)) + # Head for aux task + if hasattr(config, "aux_task_def"): + self.add_module("aux_out_projection", nn.Linear(config.hidden_size, int(config.aux_task_def['n_class']))) + self.init_weights() @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING) @@ -79,7 +83,8 @@ class 
BertForDST(BertPreTrainedModel): inform_slot_id=None, refer_id=None, class_label_id=None, - diag_state=None): + diag_state=None, + aux_task_def=None): outputs = self.bert( input_ids, attention_mask=input_mask, @@ -94,6 +99,40 @@ class BertForDST(BertPreTrainedModel): sequence_output = self.dropout(sequence_output) pooled_output = self.dropout(pooled_output) + if aux_task_def is not None: + if aux_task_def['task_type'] == "classification": + aux_logits = getattr(self, 'aux_out_projection')(pooled_output) + aux_logits = self.dropout_heads(aux_logits) + aux_loss_fct = CrossEntropyLoss() + aux_loss = aux_loss_fct(aux_logits, class_label_id) + # add hidden states and attention if they are here + return (aux_loss,) + outputs[2:] + elif aux_task_def['task_type'] == "span": + aux_logits = getattr(self, 'aux_out_projection')(sequence_output) + aux_start_logits, aux_end_logits = aux_logits.split(1, dim=-1) + aux_start_logits = self.dropout_heads(aux_start_logits) + aux_end_logits = self.dropout_heads(aux_end_logits) + aux_start_logits = aux_start_logits.squeeze(-1) + aux_end_logits = aux_end_logits.squeeze(-1) + + # If we are on multi-GPU, split add a dimension + if len(start_pos.size()) > 1: + start_pos = start_pos.squeeze(-1) + if len(end_pos.size()) > 1: + end_pos = end_pos.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = aux_start_logits.size(1) # This is a single index + start_pos.clamp_(0, ignored_index) + end_pos.clamp_(0, ignored_index) + + aux_token_loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + aux_start_loss = aux_token_loss_fct(torch.cat((aux_start_logits, aux_end_logits), 1), start_pos) + aux_end_loss = aux_token_loss_fct(torch.cat((aux_end_logits, aux_start_logits), 1), end_pos) + aux_loss = (aux_start_loss + aux_end_loss) / 2.0 + return (aux_loss,) + outputs[2:] + else: + raise Exception("Unknown task_type") + # TODO: establish proper format in labels already? if inform_slot_id is not None: inform_labels = torch.stack(list(inform_slot_id.values()), 1).float() diff --git a/modeling_roberta_dst.py b/modeling_roberta_dst.py new file mode 100644 index 0000000000000000000000000000000000000000..062fccb97831f65f67a4c3a4cdeb3685688fcd51 --- /dev/null +++ b/modeling_roberta_dst.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
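+# Note: RobertaForDST mirrors BertForDST in modeling_bert_dst.py: the same per-slot
+# class/token/refer heads, and the same optional "aux_out_projection" head that is
+# only created when the config carries an aux_task_def (i.e., when MTL is used).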
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from transformers.file_utils import (add_start_docstrings, add_start_docstrings_to_callable)
+from transformers.modeling_utils import (PreTrainedModel)
+from transformers.modeling_roberta import (RobertaModel, RobertaConfig, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+                                           ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING, BertLayerNorm)
+
+
+class RobertaPreTrainedModel(PreTrainedModel):
+    """ An abstract class to handle weights initialization and
+    a simple interface for downloading and loading pretrained models.
+    """
+    config_class = RobertaConfig
+    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
+    base_model_prefix = "roberta"
+
+    def _init_weights(self, module):
+        """ Initialize the weights """
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, BertLayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+@add_start_docstrings(
+    """RoBERTa Model with classification heads for the DST task. """,
+    ROBERTA_START_DOCSTRING,
+)
+class RobertaForDST(RobertaPreTrainedModel):
+    def __init__(self, config):
+        super(RobertaForDST, self).__init__(config)
+        self.slot_list = config.dst_slot_list
+        self.class_types = config.dst_class_types
+        self.class_labels = config.dst_class_labels
+        self.token_loss_for_nonpointable = config.dst_token_loss_for_nonpointable
+        self.refer_loss_for_nonpointable = config.dst_refer_loss_for_nonpointable
+        self.class_aux_feats_inform = config.dst_class_aux_feats_inform
+        self.class_aux_feats_ds = config.dst_class_aux_feats_ds
+        self.class_loss_ratio = config.dst_class_loss_ratio
+
+        # Only use refer loss if refer class is present in dataset.
+ if 'refer' in self.class_types: + self.refer_index = self.class_types.index('refer') + else: + self.refer_index = -1 + + self.roberta = RobertaModel(config) + self.dropout = nn.Dropout(config.dst_dropout_rate) + self.dropout_heads = nn.Dropout(config.dst_heads_dropout_rate) + + if self.class_aux_feats_inform: + self.add_module("inform_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) + if self.class_aux_feats_ds: + self.add_module("ds_projection", nn.Linear(len(self.slot_list), len(self.slot_list))) + + aux_dims = len(self.slot_list) * (self.class_aux_feats_inform + self.class_aux_feats_ds) # second term is 0, 1 or 2 + + for slot in self.slot_list: + self.add_module("class_" + slot, nn.Linear(config.hidden_size + aux_dims, self.class_labels)) + self.add_module("token_" + slot, nn.Linear(config.hidden_size, 2)) + self.add_module("refer_" + slot, nn.Linear(config.hidden_size + aux_dims, len(self.slot_list) + 1)) + + # Head for aux task + if hasattr(config, "aux_task_def"): + self.add_module("aux_out_projection", nn.Linear(config.hidden_size, int(config.aux_task_def['n_class']))) + + self.init_weights() + + @add_start_docstrings_to_callable(ROBERTA_INPUTS_DOCSTRING) + def forward(self, + input_ids, + input_mask=None, + segment_ids=None, + position_ids=None, + head_mask=None, + start_pos=None, + end_pos=None, + inform_slot_id=None, + refer_id=None, + class_label_id=None, + diag_state=None, + aux_task_def=None): + outputs = self.roberta( + input_ids, + attention_mask=input_mask, + token_type_ids=segment_ids, + position_ids=position_ids, + head_mask=head_mask + ) + + sequence_output = outputs[0] + pooled_output = outputs[1] + + sequence_output = self.dropout(sequence_output) + pooled_output = self.dropout(pooled_output) + + if aux_task_def is not None: + if aux_task_def['task_type'] == "classification": + aux_logits = getattr(self, 'aux_out_projection')(pooled_output) + aux_logits = self.dropout_heads(aux_logits) + aux_loss_fct = CrossEntropyLoss() + aux_loss = aux_loss_fct(aux_logits, class_label_id) + # add hidden states and attention if they are here + return (aux_loss,) + outputs[2:] + elif aux_task_def['task_type'] == "span": + aux_logits = getattr(self, 'aux_out_projection')(sequence_output) + aux_start_logits, aux_end_logits = aux_logits.split(1, dim=-1) + aux_start_logits = self.dropout_heads(aux_start_logits) + aux_end_logits = self.dropout_heads(aux_end_logits) + aux_start_logits = aux_start_logits.squeeze(-1) + aux_end_logits = aux_end_logits.squeeze(-1) + + # If we are on multi-GPU, split add a dimension + if len(start_pos.size()) > 1: + start_pos = start_pos.squeeze(-1) + if len(end_pos.size()) > 1: + end_pos = end_pos.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = aux_start_logits.size(1) # This is a single index + start_pos.clamp_(0, ignored_index) + end_pos.clamp_(0, ignored_index) + + aux_token_loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + aux_start_loss = aux_token_loss_fct(torch.cat((aux_start_logits, aux_end_logits), 1), start_pos) + aux_end_loss = aux_token_loss_fct(torch.cat((aux_end_logits, aux_start_logits), 1), end_pos) + aux_loss = (aux_start_loss + aux_end_loss) / 2.0 + return (aux_loss,) + outputs[2:] + else: + raise Exception("Unknown task_type") + + # TODO: establish proper format in labels already? 
+ if inform_slot_id is not None: + inform_labels = torch.stack(list(inform_slot_id.values()), 1).float() + if diag_state is not None: + diag_state_labels = torch.clamp(torch.stack(list(diag_state.values()), 1).float(), 0.0, 1.0) + + total_loss = 0 + per_slot_per_example_loss = {} + per_slot_class_logits = {} + per_slot_start_logits = {} + per_slot_end_logits = {} + per_slot_refer_logits = {} + for slot in self.slot_list: + if self.class_aux_feats_inform and self.class_aux_feats_ds: + pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels), self.ds_projection(diag_state_labels)), 1) + elif self.class_aux_feats_inform: + pooled_output_aux = torch.cat((pooled_output, self.inform_projection(inform_labels)), 1) + elif self.class_aux_feats_ds: + pooled_output_aux = torch.cat((pooled_output, self.ds_projection(diag_state_labels)), 1) + else: + pooled_output_aux = pooled_output + class_logits = self.dropout_heads(getattr(self, 'class_' + slot)(pooled_output_aux)) + + token_logits = self.dropout_heads(getattr(self, 'token_' + slot)(sequence_output)) + start_logits, end_logits = token_logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + refer_logits = self.dropout_heads(getattr(self, 'refer_' + slot)(pooled_output_aux)) + + per_slot_class_logits[slot] = class_logits + per_slot_start_logits[slot] = start_logits + per_slot_end_logits[slot] = end_logits + per_slot_refer_logits[slot] = refer_logits + + # If there are no labels, don't compute loss + if class_label_id is not None and start_pos is not None and end_pos is not None and refer_id is not None: + # If we are on multi-GPU, split add a dimension + if len(start_pos[slot].size()) > 1: + start_pos[slot] = start_pos[slot].squeeze(-1) + if len(end_pos[slot].size()) > 1: + end_pos[slot] = end_pos[slot].squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) # This is a single index + start_pos[slot].clamp_(0, ignored_index) + end_pos[slot].clamp_(0, ignored_index) + + class_loss_fct = CrossEntropyLoss(reduction='none') + token_loss_fct = CrossEntropyLoss(reduction='none', ignore_index=ignored_index) + refer_loss_fct = CrossEntropyLoss(reduction='none') + + start_loss = token_loss_fct(start_logits, start_pos[slot]) + end_loss = token_loss_fct(end_logits, end_pos[slot]) + token_loss = (start_loss + end_loss) / 2.0 + + token_is_pointable = (start_pos[slot] > 0).float() + if not self.token_loss_for_nonpointable: + token_loss *= token_is_pointable + + refer_loss = refer_loss_fct(refer_logits, refer_id[slot]) + token_is_referrable = torch.eq(class_label_id[slot], self.refer_index).float() + if not self.refer_loss_for_nonpointable: + refer_loss *= token_is_referrable + + class_loss = class_loss_fct(class_logits, class_label_id[slot]) + + if self.refer_index > -1: + per_example_loss = (self.class_loss_ratio) * class_loss + ((1 - self.class_loss_ratio) / 2) * token_loss + ((1 - self.class_loss_ratio) / 2) * refer_loss + else: + per_example_loss = self.class_loss_ratio * class_loss + (1 - self.class_loss_ratio) * token_loss + + total_loss += per_example_loss.sum() + per_slot_per_example_loss[slot] = per_example_loss + + # add hidden states and attention if they are here + outputs = (total_loss,) + (per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits,) + outputs[2:] + + return outputs diff --git a/run_dst.py b/run_dst.py index 
bde263007dfaa6bf071a3c9547f7732f5c38183b..7809bb1192db14b9cc25f04bd13b028a61e878f7 100644 --- a/run_dst.py +++ b/run_dst.py @@ -36,10 +36,12 @@ from tqdm import tqdm, trange from tensorboardX import SummaryWriter -from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer) +from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer, + RobertaConfig, RobertaTokenizer) from transformers import (AdamW, get_linear_schedule_with_warmup) from modeling_bert_dst import (BertForDST) +from modeling_roberta_dst import (RobertaForDST) from data_processors import PROCESSORS from utils_dst import (convert_examples_to_features) from tensorlistdataset import (TensorListDataset) @@ -47,9 +49,11 @@ from tensorlistdataset import (TensorListDataset) logger = logging.getLogger(__name__) ALL_MODELS = tuple(BertConfig.pretrained_config_archive_map.keys()) +ALL_MODELS += tuple(RobertaConfig.pretrained_config_archive_map.keys()) MODEL_CLASSES = { 'bert': (BertConfig, BertForDST, BertTokenizer), + 'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer), } @@ -450,7 +454,8 @@ def load_and_cache_examples(args, model, tokenizer, processor, evaluate=False): 'use_history_labels': args.use_history_labels, 'swap_utterances': args.swap_utterances, 'label_value_repetitions': args.label_value_repetitions, - 'delexicalize_sys_utts': args.delexicalize_sys_utts} + 'delexicalize_sys_utts': args.delexicalize_sys_utts, + 'unk_token': '<unk>' if args.model_type == 'roberta' else '[UNK]'} if evaluate and args.predict_type == "dev": examples = processor.get_dev_examples(args.data_dir, processor_args) elif evaluate and args.predict_type == "test": diff --git a/run_dst_mtl.py b/run_dst_mtl.py new file mode 100644 index 0000000000000000000000000000000000000000..c077c7bd253c7de5f36117e10fe36dca6f064b5e --- /dev/null +++ b/run_dst_mtl.py @@ -0,0 +1,920 @@ +# coding=utf-8 +# +# Copyright 2020 Heinrich Heine University Duesseldorf +# +# Part of this code is based on the source code of BERT-DST +# (arXiv:1907.03040) +# Part of this code is based on the source code of Transformers +# (arXiv:1910.03771) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
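+# Note: run_dst_mtl.py extends run_dst.py with simple multi-task learning (MTL).
+# For the first --mtl_ratio fraction of the total optimization steps, every DST
+# training step is preceded by one batch of the auxiliary task (--mtl_train_dataset),
+# whose loss gets its own backward pass and optimizer step before the regular DST
+# update. With --mtl_print_loss_diff, the change in the DST loss caused by the
+# auxiliary update is logged to TensorBoard ('loss_diff').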
+ +import argparse +import logging +import os +import random +import glob +import json +import math +import re +import gc + +import numpy as np +import torch +from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler) +from torch.utils.data.distributed import DistributedSampler +from tqdm import tqdm, trange +from collections import deque + +from tensorboardX import SummaryWriter + +from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer, + RobertaConfig, RobertaTokenizer) +from transformers import (AdamW, get_linear_schedule_with_warmup) + +from modeling_bert_dst import (BertForDST) +from modeling_roberta_dst import (RobertaForDST) +from data_processors import PROCESSORS +from utils_dst import (convert_examples_to_features, convert_aux_examples_to_features) +from tensorlistdataset import (TensorListDataset) + +logger = logging.getLogger(__name__) + +ALL_MODELS = tuple(BertConfig.pretrained_config_archive_map.keys()) +ALL_MODELS += tuple(RobertaConfig.pretrained_config_archive_map.keys()) + +MODEL_CLASSES = { + 'bert': (BertConfig, BertForDST, BertTokenizer), + 'roberta': (RobertaConfig, RobertaForDST, RobertaTokenizer), +} + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def to_list(tensor): + return tensor.detach().cpu().tolist() + + +def batch_to_device(batch, device): + batch_on_device = [] + for element in batch: + if isinstance(element, dict): + batch_on_device.append({k: v.to(device) for k, v in element.items()}) + else: + batch_on_device.append(element.to(device)) + return tuple(batch_on_device) + + +def train(args, train_dataset, aux_dataset, aux_task_def, features, model, tokenizer, processor, continue_from_global_step=0): + assert(not args.mtl_use or args.gradient_accumulation_steps == 1) + + """ Train the model """ + if args.local_rank in [-1, 0]: + tb_writer = SummaryWriter() + + args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) + train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, pin_memory=True) + aux_sampler = RandomSampler(aux_dataset) if args.local_rank == -1 else DistributedSampler(aux_dataset) + aux_dataloader = DataLoader(aux_dataset, sampler=aux_sampler, batch_size=args.train_batch_size, pin_memory=True) + + if args.max_steps > 0: + t_total = args.max_steps + args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 + else: + t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs + + if args.save_epochs > 0: + args.save_steps = t_total // args.num_train_epochs * args.save_epochs + + num_warmup_steps = int(t_total * args.warmup_proportion) + mtl_last_step = int(t_total * args.mtl_ratio) + + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ['bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, + {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, 
num_training_steps=t_total) + if args.fp16: + try: + from apex import amp + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) + + # multi-gpu training (should be after apex fp16 initialization) + model_single_gpu = model + if args.n_gpu > 1: + model = torch.nn.DataParallel(model_single_gpu) + + # Distributed training (should be after apex fp16 initialization) + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True) + + # Train! + logger.info("***** Running training *****") + logger.info(" Num examples = %d", len(train_dataset)) + logger.info(" Num Epochs = %d", args.num_train_epochs) + logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) + logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", t_total) + logger.info(" Warmup steps = %d", num_warmup_steps) + + if continue_from_global_step > 0: + logger.info("Fast forwarding to global step %d to resume training from latest checkpoint...", continue_from_global_step) + + global_step = 0 + tr_loss, logging_loss = 0.0, 0.0 + model.zero_grad() + train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) + set_seed(args) # Added here for reproductibility (even between python 2 and 3) + + # Iterators for aux tasks + aux_loss_diff = 0.0 + aux_logged_steps = 0 + tr_aux_loss = 0.0 + logging_aux_loss = 0.0 + loss_diff_queue = deque([], args.mtl_diff_window) + aux_iterator_dict = iter(aux_dataloader) + + for _ in train_iterator: + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) + + for step, batch in enumerate(epoch_iterator): + # If training is continued from a checkpoint, fast forward + # to the state of that checkpoint. + if global_step < continue_from_global_step: + if (step + 1) % args.gradient_accumulation_steps == 0: + scheduler.step() # Update learning rate schedule + global_step += 1 + continue + + #model.train() + batch = batch_to_device(batch, args.device) + + # This is what is forwarded to the "forward" def. 
+ inputs = {'input_ids': batch[0], + 'input_mask': batch[1], + 'segment_ids': batch[2], + 'start_pos': batch[3], + 'end_pos': batch[4], + 'inform_slot_id': batch[5], + 'refer_id': batch[6], + 'diag_state': batch[7], + 'class_label_id': batch[8]} + + # MTL (optional) + if args.mtl_use and global_step < mtl_last_step: + if args.mtl_print_loss_diff: + model.eval() + with torch.no_grad(): + pre_aux_loss = model(**inputs)[0] + + try: + aux_batch = batch_to_device(next(aux_iterator_dict), args.device) + except StopIteration: + logger.info("Resetting iterator for aux task") + aux_iterator_dict = iter(aux_dataloader) + aux_batch = batch_to_device(next(aux_iterator_dict), args.device) + + aux_inputs = {'input_ids': aux_batch[0], + 'input_mask': aux_batch[1], + 'segment_ids': aux_batch[2], + 'start_pos': aux_batch[3], + 'end_pos': aux_batch[4], + 'class_label_id': aux_batch[5], + 'aux_task_def': aux_task_def} + model.train() + aux_outputs = model(**aux_inputs) + aux_loss = aux_outputs[0] + + if args.n_gpu > 1: + aux_loss = aux_loss.mean() # mean() to average on multi-gpu parallel (not distributed) training + + if args.fp16: + with amp.scale_loss(aux_loss, optimizer) as scaled_loss: + scaled_loss.backward() + torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) + else: + aux_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + tr_aux_loss += aux_loss.item() + aux_logged_steps += 1 + + optimizer.step() + + model.zero_grad() + if args.mtl_print_loss_diff: + model.eval() + with torch.no_grad(): + post_aux_loss = model(**inputs)[0] + aux_loss_diff = pre_aux_loss - post_aux_loss + else: + post_aux_loss = 0.0 # TODO: move somewhere else... + + pre_aux_loss = post_aux_loss + + loss_diff_queue.append(aux_loss_diff) + + # Normal training + model.train() + + outputs = model(**inputs) + loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc) + + if args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + if args.fp16: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) + else: + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + tr_loss += loss.item() + if (step + 1) % args.gradient_accumulation_steps == 0: + optimizer.step() + scheduler.step() # Update learning rate schedule + model.zero_grad() + global_step += 1 + + # Aux losses + if args.mtl_use and args.local_rank in [-1, 0] and args.logging_steps > 0: + tb_writer.add_scalar('loss_diff', aux_loss_diff, global_step) + # TODO: make nicer + if len(loss_diff_queue) > 0: + tb_writer.add_scalar('loss_diff_mean', sum(loss_diff_queue) / len(loss_diff_queue), global_step) + if aux_logged_steps > 0 and tr_aux_loss != logging_aux_loss: + tb_writer.add_scalar('aux_loss', tr_aux_loss - logging_aux_loss, global_step) + logging_aux_loss = tr_aux_loss + + # Log metrics + if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: + tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) + tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step) + logging_loss = tr_loss + + # Save model checkpoint + if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: + output_dir = 
os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training + model_to_save.save_pretrained(output_dir) + torch.save(args, os.path.join(output_dir, 'training_args.bin')) + logger.info("Saving model checkpoint to %s", output_dir) + + if args.max_steps > 0 and global_step > args.max_steps: + epoch_iterator.close() + break + + if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well + results = evaluate(args, model_single_gpu, tokenizer, processor, prefix=global_step) + for key, value in results.items(): + tb_writer.add_scalar('eval_{}'.format(key), value, global_step) + + # To prevent GPU memory to overflow + gc.collect() + torch.cuda.empty_cache() + + if args.max_steps > 0 and global_step > args.max_steps: + train_iterator.close() + break + + if args.local_rank in [-1, 0]: + tb_writer.close() + + tr_aux_loss /= aux_logged_steps + + return global_step, tr_loss / global_step, tr_aux_loss + + +def evaluate(args, model, tokenizer, processor, prefix=""): + dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=True) + + if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: + os.makedirs(args.output_dir) + + args.eval_batch_size = args.per_gpu_eval_batch_size + eval_sampler = SequentialSampler(dataset) # Note that DistributedSampler samples randomly + eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + + # Eval! + logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(" Num examples = %d", len(dataset)) + logger.info(" Batch size = %d", args.eval_batch_size) + all_results = [] + all_preds = [] + ds = {slot: 'none' for slot in model.slot_list} + with torch.no_grad(): + diag_state = {slot: torch.tensor([0 for _ in range(args.eval_batch_size)]).to(args.device) for slot in model.slot_list} + for batch in tqdm(eval_dataloader, desc="Evaluating"): + model.eval() + batch = batch_to_device(batch, args.device) + + # Reset dialog state if turn is first in the dialog. + turn_itrs = [features[i.item()].guid.split('-')[2] for i in batch[9]] + reset_diag_state = np.where(np.array(turn_itrs) == '0')[0] + for slot in model.slot_list: + for i in reset_diag_state: + diag_state[slot][i] = 0 + + with torch.no_grad(): + inputs = {'input_ids': batch[0], + 'input_mask': batch[1], + 'segment_ids': batch[2], + 'start_pos': batch[3], + 'end_pos': batch[4], + 'inform_slot_id': batch[5], + 'refer_id': batch[6], + 'diag_state': diag_state, + 'class_label_id': batch[8]} + unique_ids = [features[i.item()].guid for i in batch[9]] + values = [features[i.item()].values for i in batch[9]] + input_ids_unmasked = [features[i.item()].input_ids_unmasked for i in batch[9]] + inform = [features[i.item()].inform for i in batch[9]] + outputs = model(**inputs) + + # Update dialog state for next turn. 
+ for slot in model.slot_list: + updates = outputs[2][slot].max(1)[1] + for i, u in enumerate(updates): + if u != 0: + diag_state[slot][i] = u + + results = eval_metric(model, inputs, outputs[0], outputs[1], outputs[2], outputs[3], outputs[4], outputs[5]) + preds, ds = predict_and_format(model, tokenizer, inputs, outputs[2], outputs[3], outputs[4], outputs[5], unique_ids, input_ids_unmasked, values, inform, prefix, ds) + all_results.append(results) + all_preds.append(preds) + + all_preds = [item for sublist in all_preds for item in sublist] # Flatten list + + # Generate final results + final_results = {} + for k in all_results[0].keys(): + final_results[k] = torch.stack([r[k] for r in all_results]).mean() + + # Write final predictions (for evaluation with external tool) + output_prediction_file = os.path.join(args.output_dir, "pred_res.%s.%s.json" % (args.predict_type, prefix)) + with open(output_prediction_file, "w") as f: + json.dump(all_preds, f, indent=2) + + return final_results + + +def eval_metric(model, features, total_loss, per_slot_per_example_loss, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits): + metric_dict = {} + per_slot_correctness = {} + for slot in model.slot_list: + per_example_loss = per_slot_per_example_loss[slot] + class_logits = per_slot_class_logits[slot] + start_logits = per_slot_start_logits[slot] + end_logits = per_slot_end_logits[slot] + refer_logits = per_slot_refer_logits[slot] + + class_label_id = features['class_label_id'][slot] + start_pos = features['start_pos'][slot] + end_pos = features['end_pos'][slot] + refer_id = features['refer_id'][slot] + + _, class_prediction = class_logits.max(1) + class_correctness = torch.eq(class_prediction, class_label_id).float() + class_accuracy = class_correctness.mean() + + # "is pointable" means whether class label is "copy_value", + # i.e., that there is a span to be detected. + token_is_pointable = torch.eq(class_label_id, model.class_types.index('copy_value')).float() + _, start_prediction = start_logits.max(1) + start_correctness = torch.eq(start_prediction, start_pos).float() + _, end_prediction = end_logits.max(1) + end_correctness = torch.eq(end_prediction, end_pos).float() + token_correctness = start_correctness * end_correctness + token_accuracy = (token_correctness * token_is_pointable).sum() / token_is_pointable.sum() + # NaNs mean that none of the examples in this batch contain spans. -> division by 0 + # The accuracy therefore is 1 by default. -> replace NaNs + if math.isnan(token_accuracy): + token_accuracy = torch.tensor(1.0, device=token_accuracy.device) + + token_is_referrable = torch.eq(class_label_id, model.class_types.index('refer') if 'refer' in model.class_types else -1).float() + _, refer_prediction = refer_logits.max(1) + refer_correctness = torch.eq(refer_prediction, refer_id).float() + refer_accuracy = refer_correctness.sum() / token_is_referrable.sum() + # NaNs mean that none of the examples in this batch contain referrals. -> division by 0 + # The accuracy therefore is 1 by default. 
-> replace NaNs
+        if math.isnan(refer_accuracy) or math.isinf(refer_accuracy):
+            refer_accuracy = torch.tensor(1.0, device=refer_accuracy.device)
+
+        total_correctness = class_correctness * (token_is_pointable * token_correctness + (1 - token_is_pointable)) * (token_is_referrable * refer_correctness + (1 - token_is_referrable))
+        total_accuracy = total_correctness.mean()
+
+        loss = per_example_loss.mean()
+        metric_dict['eval_accuracy_class_%s' % slot] = class_accuracy
+        metric_dict['eval_accuracy_token_%s' % slot] = token_accuracy
+        metric_dict['eval_accuracy_refer_%s' % slot] = refer_accuracy
+        metric_dict['eval_accuracy_%s' % slot] = total_accuracy
+        metric_dict['eval_loss_%s' % slot] = loss
+        per_slot_correctness[slot] = total_correctness
+
+    goal_correctness = torch.stack([c for c in per_slot_correctness.values()], 1).prod(1)
+    goal_accuracy = goal_correctness.mean()
+    metric_dict['eval_accuracy_goal'] = goal_accuracy
+    metric_dict['loss'] = total_loss
+    return metric_dict
+
+
+def predict_and_format(model, tokenizer, features, per_slot_class_logits, per_slot_start_logits, per_slot_end_logits, per_slot_refer_logits, ids, input_ids_unmasked, values, inform, prefix, ds):
+    prediction_list = []
+    dialog_state = ds
+    for i in range(len(ids)):
+        if int(ids[i].split("-")[2]) == 0:
+            dialog_state = {slot: 'none' for slot in model.slot_list}
+
+        prediction = {}
+        prediction_addendum = {}
+        for slot in model.slot_list:
+            class_logits = per_slot_class_logits[slot][i]
+            start_logits = per_slot_start_logits[slot][i]
+            end_logits = per_slot_end_logits[slot][i]
+            refer_logits = per_slot_refer_logits[slot][i]
+
+            input_ids = features['input_ids'][i].tolist()
+            class_label_id = int(features['class_label_id'][slot][i])
+            start_pos = int(features['start_pos'][slot][i])
+            end_pos = int(features['end_pos'][slot][i])
+            refer_id = int(features['refer_id'][slot][i])
+
+            class_prediction = int(class_logits.argmax())
+            start_prediction = int(start_logits.argmax())
+            end_prediction = int(end_logits.argmax())
+            refer_prediction = int(refer_logits.argmax())
+
+            prediction['guid'] = ids[i].split("-")
+            prediction['class_prediction_%s' % slot] = class_prediction
+            prediction['class_label_id_%s' % slot] = class_label_id
+            prediction['start_prediction_%s' % slot] = start_prediction
+            prediction['start_pos_%s' % slot] = start_pos
+            prediction['end_prediction_%s' % slot] = end_prediction
+            prediction['end_pos_%s' % slot] = end_pos
+            prediction['refer_prediction_%s' % slot] = refer_prediction
+            prediction['refer_id_%s' % slot] = refer_id
+            prediction['input_ids_%s' % slot] = input_ids
+
+            if class_prediction == model.class_types.index('dontcare'):
+                dialog_state[slot] = 'dontcare'
+            elif class_prediction == model.class_types.index('copy_value'):
+                input_tokens = tokenizer.convert_ids_to_tokens(input_ids_unmasked[i])
+                dialog_state[slot] = ' '.join(input_tokens[start_prediction:end_prediction + 1])
+                dialog_state[slot] = re.sub("(^| )##", "", dialog_state[slot])
+            elif 'true' in model.class_types and class_prediction == model.class_types.index('true'):
+                dialog_state[slot] = 'true'
+            elif 'false' in model.class_types and class_prediction == model.class_types.index('false'):
+                dialog_state[slot] = 'false'
+            elif class_prediction == model.class_types.index('inform'):
+                dialog_state[slot] = '§§' + inform[i][slot]
+            # Referral case is handled below
+
+            prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot]
+            prediction_addendum['slot_groundtruth_%s' % slot] = values[i][slot]
+
+        # Referral case.
All other slot values need to be seen first in order + # to be able to do this correctly. + for slot in model.slot_list: + class_logits = per_slot_class_logits[slot][i] + refer_logits = per_slot_refer_logits[slot][i] + + class_prediction = int(class_logits.argmax()) + refer_prediction = int(refer_logits.argmax()) + + if 'refer' in model.class_types and class_prediction == model.class_types.index('refer'): + # Only slots that have been mentioned before can be referred to. + # One can think of a situation where one slot is referred to in the same utterance. + # This phenomenon is however currently not properly covered in the training data + # label generation process. + dialog_state[slot] = dialog_state[model.slot_list[refer_prediction - 1]] + prediction_addendum['slot_prediction_%s' % slot] = dialog_state[slot] # Value update + + prediction.update(prediction_addendum) + prediction_list.append(prediction) + + return prediction_list, dialog_state + + +def load_and_cache_examples(args, model, tokenizer, processor, evaluate=False): + if args.local_rank not in [-1, 0] and not evaluate: + torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache + + # Load data features from cache or dataset file + cached_file = os.path.join(os.path.dirname(args.output_dir), 'cached_{}_features'.format( + args.predict_type if evaluate else 'train')) + if os.path.exists(cached_file) and not args.overwrite_cache: # and not output_examples: + logger.info("Loading features from cached file %s", cached_file) + features = torch.load(cached_file) + else: + logger.info("Creating features from dataset file at %s", args.data_dir) + processor_args = {'append_history': args.append_history, + 'use_history_labels': args.use_history_labels, + 'swap_utterances': args.swap_utterances, + 'label_value_repetitions': args.label_value_repetitions, + 'delexicalize_sys_utts': args.delexicalize_sys_utts, + 'unk_token': '<unk>' if args.model_type == 'roberta' else '[UNK]'} + if evaluate and args.predict_type == "dev": + examples = processor.get_dev_examples(args.data_dir, processor_args) + elif evaluate and args.predict_type == "test": + examples = processor.get_test_examples(args.data_dir, processor_args) + else: + examples = processor.get_train_examples(args.data_dir, processor_args) + features = convert_examples_to_features(examples=examples, + slot_list=model.slot_list, + class_types=model.class_types, + model_type=args.model_type, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length, + slot_value_dropout=(0.0 if evaluate else args.svd)) + if args.local_rank in [-1, 0]: + logger.info("Saving features into cached file %s", cached_file) + torch.save(features, cached_file) + + if args.local_rank == 0 and not evaluate: + torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache + + # Convert to Tensors and build dataset + all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) + all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) + all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) + all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) + f_start_pos = [f.start_pos for f in features] + f_end_pos = [f.end_pos for f in features] + f_inform_slot_ids = [f.inform_slot for f in features] + f_refer_ids = [f.refer_id for f in features] + f_diag_state = [f.diag_state for 
f in features] + f_class_label_ids = [f.class_label_id for f in features] + all_start_positions = {} + all_end_positions = {} + all_inform_slot_ids = {} + all_refer_ids = {} + all_diag_state = {} + all_class_label_ids = {} + for s in model.slot_list: + all_start_positions[s] = torch.tensor([f[s] for f in f_start_pos], dtype=torch.long) + all_end_positions[s] = torch.tensor([f[s] for f in f_end_pos], dtype=torch.long) + all_inform_slot_ids[s] = torch.tensor([f[s] for f in f_inform_slot_ids], dtype=torch.long) + all_refer_ids[s] = torch.tensor([f[s] for f in f_refer_ids], dtype=torch.long) + all_diag_state[s] = torch.tensor([f[s] for f in f_diag_state], dtype=torch.long) + all_class_label_ids[s] = torch.tensor([f[s] for f in f_class_label_ids], dtype=torch.long) + dataset = TensorListDataset(all_input_ids, all_input_mask, all_segment_ids, + all_start_positions, all_end_positions, + all_inform_slot_ids, + all_refer_ids, + all_diag_state, + all_class_label_ids, all_example_index) + + return dataset, features + + +def load_and_cache_aux_examples(args, model, tokenizer, aux_task_def=None): + if aux_task_def is None: + return None, None + + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache + + # Load data features from cache or dataset file + cached_file = os.path.join(os.path.dirname(args.output_dir), 'cached_aux_features') + if os.path.exists(cached_file) and not args.overwrite_cache: # and not output_examples: + logger.info("Loading features from cached file %s", cached_file) + features = torch.load(cached_file) + else: + processor = PROCESSORS['aux_task']() + logger.info("Creating features from aux task dataset file at %s", args.mtl_data_dir) + examples = processor.get_aux_task_examples(data_dir=args.mtl_data_dir, + data_name=args.mtl_train_dataset, + max_seq_length=args.max_seq_length) + + features = convert_aux_examples_to_features(examples=examples, aux_task_def=aux_task_def, max_seq_length=args.max_seq_length) + + if args.local_rank in [-1, 0]: + logger.info("Saving features into cached file %s", cached_file) + torch.save(features, cached_file) + + if args.local_rank == 0: + torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache + + # Convert to Tensors and build dataset + all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) + all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) + all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) + all_start_positions = torch.tensor([f.start_pos for f in features], dtype=torch.long) + all_end_positions = torch.tensor([f.end_pos for f in features], dtype=torch.long) + all_label = torch.tensor([f.label for f in features], dtype=torch.long) + dataset = TensorListDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions, all_label) + + return dataset, features + + +def main(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument("--task_name", default=None, type=str, required=True, + help="Name of the task (e.g., multiwoz21).") + parser.add_argument("--data_dir", default=None, type=str, required=True, + help="Task database.") + parser.add_argument("--dataset_config", default=None, type=str, required=True, + help="Dataset configuration file.") + parser.add_argument("--predict_type", 
default=None, type=str, required=True,
+                        help="Portion of the data to perform prediction on (e.g., dev, test).")
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
+    parser.add_argument("--output_dir", default=None, type=str, required=True,
+                        help="The output directory where the model checkpoints and predictions will be written.")
+
+    # Other parameters
+    parser.add_argument("--config_name", default="", type=str,
+                        help="Pretrained config name or path if not the same as model_name")
+    parser.add_argument("--tokenizer_name", default="", type=str,
+                        help="Pretrained tokenizer name or path if not the same as model_name")
+
+    parser.add_argument("--max_seq_length", default=384, type=int,
+                        help="Maximum input length after tokenization. Longer sequences will be truncated, shorter ones padded.")
+    parser.add_argument("--do_train", action='store_true',
+                        help="Whether to run training.")
+    parser.add_argument("--do_eval", action='store_true',
+                        help="Whether to run eval on the <predict_type> set.")
+    parser.add_argument("--evaluate_during_training", action='store_true',
+                        help="Run evaluation during training at each logging step.")
+    parser.add_argument("--do_lower_case", action='store_true',
+                        help="Set this flag if you are using an uncased model.")
+
+    parser.add_argument("--dropout_rate", default=0.3, type=float,
+                        help="Dropout rate for BERT representations.")
+    parser.add_argument("--heads_dropout", default=0.0, type=float,
+                        help="Dropout rate for classification heads.")
+    parser.add_argument("--class_loss_ratio", default=0.8, type=float,
+                        help="The ratio applied on class loss in total loss calculation. " + "Should be a value in [0.0, 1.0]. " + "The ratio applied on token loss is (1-class_loss_ratio)/2. " + "The ratio applied on refer loss is (1-class_loss_ratio)/2.")
+    parser.add_argument("--token_loss_for_nonpointable", action='store_true',
+                        help="Whether the token loss for classes other than copy_value contributes towards the total loss.")
+    parser.add_argument("--refer_loss_for_nonpointable", action='store_true',
+                        help="Whether the refer loss for classes other than refer contributes towards the total loss.")
+
+    parser.add_argument("--append_history", action='store_true',
+                        help="Whether or not to append the dialog history to each turn.")
+    parser.add_argument("--use_history_labels", action='store_true',
+                        help="Whether or not to label the history as well.")
+    parser.add_argument("--swap_utterances", action='store_true',
+                        help="Whether or not to swap the turn utterances (default: sys|usr, swapped: usr|sys).")
+    parser.add_argument("--label_value_repetitions", action='store_true',
+                        help="Whether or not to label values that have been mentioned before.")
+    parser.add_argument("--delexicalize_sys_utts", action='store_true',
+                        help="Whether or not to delexicalize the system utterances.")
+    parser.add_argument("--class_aux_feats_inform", action='store_true',
+                        help="Whether or not to use the identity of informed slots as auxiliary features for class prediction.")
+    parser.add_argument("--class_aux_feats_ds", action='store_true',
+                        help="Whether or not to use the identity of slots in the current dialog state as auxiliary features for class prediction.")
+
+    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
+                        help="Batch size per GPU/CPU for training.")
+    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
+                        help="Batch size per GPU/CPU for evaluation.")
+    parser.add_argument("--learning_rate", default=5e-5, type=float,
+                        help="The initial learning rate for Adam.")
+    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
+                        help="Number of update steps to accumulate before performing a backward/update pass.")
+    parser.add_argument("--weight_decay", default=0.0, type=float,
+                        help="Weight decay if we apply some.")
+    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
+                        help="Epsilon for Adam optimizer.")
+    parser.add_argument("--max_grad_norm", default=1.0, type=float,
+                        help="Max gradient norm.")
+    parser.add_argument("--num_train_epochs", default=3.0, type=float,
+                        help="Total number of training epochs to perform.")
+    parser.add_argument("--max_steps", default=-1, type=int,
+                        help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.")
+    parser.add_argument("--warmup_proportion", default=0.0, type=float,
+                        help="Linear warmup over warmup_proportion * steps.")
+    parser.add_argument("--svd", default=0.0, type=float,
+                        help="Slot value dropout ratio (default: 0.0)")
+
+    parser.add_argument('--logging_steps', type=int, default=50,
+                        help="Log every X update steps.")
+    parser.add_argument('--save_steps', type=int, default=0,
+                        help="Save checkpoint every X update steps. Overwritten by --save_epochs.")
+    parser.add_argument('--save_epochs', type=int, default=0,
+                        help="Save checkpoint every X epochs.
+    parser.add_argument("--eval_all_checkpoints", action='store_true',
+                        help="Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number")
+    parser.add_argument("--no_cuda", action='store_true',
+                        help="Whether to avoid using CUDA when available")
+    parser.add_argument('--overwrite_output_dir', action='store_true',
+                        help="Overwrite the content of the output directory")
+    parser.add_argument('--overwrite_cache', action='store_true',
+                        help="Overwrite the cached training and evaluation sets")
+    parser.add_argument('--seed', type=int, default=42,
+                        help="Random seed for initialization")
+
+    parser.add_argument("--local_rank", type=int, default=-1,
+                        help="local_rank for distributed training on GPUs")
+    parser.add_argument('--fp16', action='store_true',
+                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
+    parser.add_argument('--fp16_opt_level', type=str, default='O1',
+                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
+                             "See details at https://nvidia.github.io/apex/amp.html")
+
+    # TODO
+    parser.add_argument('--mtl_use', action='store_true',
+                        help="Whether to use multi-task learning with an auxiliary task.")
+    parser.add_argument('--mtl_task_def', type=str, default="/home/heckmi_hhu/tools/trippy/aux_task_def.json",
+                        help="Path to the JSON file with the auxiliary task definitions.") # TODO
+    parser.add_argument('--mtl_train_dataset', type=str, default="",
+                        help="Auxiliary task to train on: cola|mnli|mrpc|qnli|qqp|rte|sst|wnli|squad|squad-v2")
+    parser.add_argument("--mtl_data_dir", type=str, default="/home/heckmi_hhu/data/glue/canonical_data/bert_base_uncased_lower",
+                        help="Path to the directory with the pre-tokenized auxiliary task data.") # TODO
+    parser.add_argument("--mtl_ratio", type=float, default=1.0,
+                        help="Mixing ratio for the auxiliary task during training; must be in [0.0, 1.0].") # TODO
+    parser.add_argument("--mtl_diff_window", type=int, default=10)
+    parser.add_argument('--mtl_print_loss_diff', action='store_true', help="")
+
+    args = parser.parse_args()
+
+    assert(args.warmup_proportion >= 0.0 and args.warmup_proportion <= 1.0)
+    assert(args.svd >= 0.0 and args.svd <= 1.0)
+    assert(not args.class_aux_feats_inform or args.per_gpu_eval_batch_size == 1)
+    assert(not args.class_aux_feats_ds or args.per_gpu_eval_batch_size == 1)
+
+    assert(not args.mtl_use or args.gradient_accumulation_steps == 1)
+    assert(args.mtl_ratio >= 0.0 and args.mtl_ratio <= 1.0)
+
+    task_name = args.task_name.lower()
+    if task_name not in PROCESSORS:
+        raise ValueError("Task not found: %s" % (task_name))
+
+    # Load the MTL task definitions
+    if args.mtl_use:
+        with open(args.mtl_task_def, "r", encoding='utf-8') as reader:
+            aux_task_defs = json.load(reader)
+        aux_task_def = aux_task_defs[args.mtl_train_dataset]
+    else:
+        aux_task_def = None
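# For illustration only: the file passed via --mtl_task_def is expected to hold a JSON object
# keyed by auxiliary task name (the value of --mtl_train_dataset), with each entry providing at
# least a 'task_type' field; "span" is the only value handled specially later on (it enables
# start/end positions in convert_aux_examples_to_features). The concrete keys and values below
# are hypothetical and not taken from the shipped task definition file:
#
#     {
#         "cola":  {"task_type": "classification"},
#         "squad": {"task_type": "span"}
#     }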
+
+    processor = PROCESSORS[task_name](args.dataset_config)
+    dst_slot_list = processor.slot_list
+    dst_class_types = processor.class_types
+    dst_class_labels = len(dst_class_types)
+
+    # Setup CUDA, GPU & distributed training
+    if args.local_rank == -1 or args.no_cuda:
+        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+        args.n_gpu = torch.cuda.device_count()
+    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+        torch.cuda.set_device(args.local_rank)
+        device = torch.device("cuda", args.local_rank)
+        torch.distributed.init_process_group(backend='nccl')
+        args.n_gpu = 1
+    args.device = device
+
+    # Setup logging
+    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                        datefmt='%m/%d/%Y %H:%M:%S',
+                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
+
+    # Set seed
+    set_seed(args)
+
+    # Load pretrained model and tokenizer
+    if args.local_rank not in [-1, 0]:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
+
+    args.model_type = args.model_type.lower()
+    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+
+    # Add DST specific parameters to config
+    config.dst_dropout_rate = args.dropout_rate
+    config.dst_heads_dropout_rate = args.heads_dropout
+    config.dst_class_loss_ratio = args.class_loss_ratio
+    config.dst_token_loss_for_nonpointable = args.token_loss_for_nonpointable
+    config.dst_refer_loss_for_nonpointable = args.refer_loss_for_nonpointable
+    config.dst_class_aux_feats_inform = args.class_aux_feats_inform
+    config.dst_class_aux_feats_ds = args.class_aux_feats_ds
+    config.dst_slot_list = dst_slot_list
+    config.dst_class_types = dst_class_types
+    config.dst_class_labels = dst_class_labels
+    config.aux_task_def = aux_task_def
+
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
+
+    logger.info("Updated model config: %s" % config)
+
+    if args.local_rank == 0:
+        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
+
+    model.to(args.device)
+
+    logger.info("Training/evaluation parameters %s", args)
+
+    # Training
+    if args.do_train:
+        # If output files already exist, continue training from the latest checkpoint (unless overwrite_output_dir is set)
+        continue_from_global_step = 0  # If set to 0, start training from the beginning
+        if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
+            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/*/' + WEIGHTS_NAME, recursive=True)))
+            if len(checkpoints) > 0:
+                checkpoint = checkpoints[-1]
+                logger.info("Resuming training from the latest checkpoint: %s", checkpoint)
+                continue_from_global_step = int(checkpoint.split('-')[-1])
+                model = model_class.from_pretrained(checkpoint)
+                model.to(args.device)
+
+        train_dataset, features = load_and_cache_examples(args, model, tokenizer, processor, evaluate=False)
+        aux_dataset, _ = load_and_cache_aux_examples(args, model, tokenizer, aux_task_def)
+        global_step, tr_loss, aux_loss = train(args, train_dataset, aux_dataset, aux_task_def, features, model, tokenizer, processor, continue_from_global_step)
+        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
+
+    # Save the trained model and the tokenizer
+    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
+        # Create output directory if needed
+        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
+            os.makedirs(args.output_dir)
+
+        logger.info("Saving model checkpoint to %s", args.output_dir)
+        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+        model_to_save.save_pretrained(args.output_dir)
+        tokenizer.save_pretrained(args.output_dir)
+
+        # Good practice: save your training arguments together with the trained model
+        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
+
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = model_class.from_pretrained(args.output_dir)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        model.to(args.device)
+
+    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
+    results = []
+    if args.do_eval and args.local_rank in [-1, 0]:
+        output_eval_file = os.path.join(args.output_dir, "eval_res.%s.json" % (args.predict_type))
+        checkpoints = [args.output_dir]
+        if args.eval_all_checkpoints:
+            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
+            logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs
+
+        logger.info("Evaluate the following checkpoints: %s", checkpoints)
+
+        for cItr, checkpoint in enumerate(checkpoints):
+            # Reload the model
+            global_step = checkpoint.split('-')[-1]
+            if cItr == len(checkpoints) - 1:
+                global_step = "final"
+            model = model_class.from_pretrained(checkpoint)
+            model.to(args.device)
+
+            # Evaluate
+            result = evaluate(args, model, tokenizer, processor, prefix=global_step)
+            result_dict = {k: float(v) for k, v in result.items()}
+            result_dict["global_step"] = global_step
+            results.append(result_dict)
+
+            for key in sorted(result_dict.keys()):
+                logger.info("%s = %s", key, str(result_dict[key]))
+
+        with open(output_eval_file, "w") as f:
+            json.dump(results, f, indent=2)
+
+    return results
+
+
+if __name__ == "__main__":
+    main()
diff --git a/utils_dst.py b/utils_dst.py
index e91971da86f4f9185f3428e52a09f48b5247343d..8f91df8c8a824b13cadf87bb4fe551f1b7ec026d 100644
--- a/utils_dst.py
+++ b/utils_dst.py
@@ -122,6 +122,26 @@ class InputFeatures(object):
         self.class_label_id = class_label_id
 
 
+class AuxInputFeatures(object):
+    """A single set of features of data."""
+
+    def __init__(self,
+                 input_ids,
+                 input_mask,
+                 segment_ids,
+                 start_pos=None,
+                 end_pos=None,
+                 label=None,
+                 uid="NONE"):
+        self.uid = uid
+        self.input_ids = input_ids
+        self.input_mask = input_mask
+        self.segment_ids = segment_ids
+        self.start_pos = start_pos
+        self.end_pos = end_pos
+        self.label = label
+
+
 def convert_examples_to_features(examples, slot_list, class_types, model_type, tokenizer, max_seq_length, slot_value_dropout=0.0):
     """Loads a data file into a list of `InputBatch`s."""
 
@@ -131,6 +151,12 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
                        'UNK_TOKEN': '[UNK]',
                        'SEP_TOKEN': '[SEP]',
                        'TOKEN_CORRECTION': 4}
+    elif model_type == 'roberta':
+        model_specs = {'MODEL_TYPE': 'roberta',
+                       'CLS_TOKEN': '<s>',
+                       'UNK_TOKEN': '<unk>',
+                       'SEP_TOKEN': '</s>',
+                       'TOKEN_CORRECTION': 6}
     else:
         logger.error("Unknown model type (%s). Aborting." % (model_type))
         exit(1)
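The TOKEN_CORRECTION values account for the model-specific special tokens that the hunks below
place around the two turn utterances (tokens_a, tokens_b) and the dialog history; schematically:

    BERT:    [CLS] tokens_a [SEP] tokens_b [SEP] history [SEP]        (4 special tokens)
    RoBERTa: <s> tokens_a </s></s> tokens_b </s></s> history </s>     (6 special tokens)

For RoBERTa all segment ids are kept at 0, since it does not use BERT-style token type ids.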
@@ -148,6 +174,8 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
             token_labels = []
             for token, token_label, joint_label in zip(text, text_label, joint_text_label):
                 token = convert_to_unicode(token)
+                if model_specs['MODEL_TYPE'] == 'roberta':
+                    token = ' ' + token
                 sub_tokens = tokenizer.tokenize(token) # Most time intensive step
                 tokens_unmasked.extend(sub_tokens)
                 if slot_value_dropout == 0.0 or joint_label == 0:
@@ -187,6 +215,7 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
         # Modifies `tokens_a` and `tokens_b` in place so that the total
         # length is less than the specified length.
         # Account for [CLS], [SEP], [SEP], [SEP] with "- 4" (BERT)
+        # Account for <s>, </s></s>, </s></s>, </s> with "- 6" (RoBERTa)
         if len(tokens_a) + len(tokens_b) + len(history) > max_seq_length - model_specs['TOKEN_CORRECTION']:
             logger.info("Truncate Example %s. Total len=%d." % (guid, len(tokens_a) + len(tokens_b) + len(history)))
             input_text_too_long = True
@@ -197,16 +226,20 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
     def _get_token_label_ids(token_labels_a, token_labels_b, token_labels_history, max_seq_length, model_specs):
         token_label_ids = []
-        token_label_ids.append(0) # [CLS]
+        token_label_ids.append(0) # [CLS]/<s>
         for token_label in token_labels_a:
             token_label_ids.append(token_label)
-        token_label_ids.append(0) # [SEP]
+        token_label_ids.append(0) # [SEP]/</s></s>
+        if model_specs['MODEL_TYPE'] == 'roberta':
+            token_label_ids.append(0)
         for token_label in token_labels_b:
             token_label_ids.append(token_label)
-        token_label_ids.append(0) # [SEP]
+        token_label_ids.append(0) # [SEP]/</s></s>
+        if model_specs['MODEL_TYPE'] == 'roberta':
+            token_label_ids.append(0)
         for token_label in token_labels_history:
             token_label_ids.append(token_label)
-        token_label_ids.append(0) # [SEP]
+        token_label_ids.append(0) # [SEP]/</s>
         while len(token_label_ids) < max_seq_length:
             token_label_ids.append(0) # padding
         assert len(token_label_ids) == max_seq_length
@@ -258,23 +291,45 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
             segment_ids.append(0)
         tokens.append(model_specs['SEP_TOKEN'])
         segment_ids.append(0)
-        for token in tokens_b:
-            tokens.append(token)
+        if model_specs['MODEL_TYPE'] == 'roberta':
+            tokens.append(model_specs['SEP_TOKEN'])
+            segment_ids.append(0)
+        if model_specs['MODEL_TYPE'] != 'roberta':
+            for token in tokens_b:
+                tokens.append(token)
+                segment_ids.append(1)
+            tokens.append(model_specs['SEP_TOKEN'])
             segment_ids.append(1)
-        tokens.append(model_specs['SEP_TOKEN'])
-        segment_ids.append(1)
+        else:
+            for token in tokens_b:
+                tokens.append(token)
+                segment_ids.append(0)
+            tokens.append(model_specs['SEP_TOKEN'])
+            segment_ids.append(0)
+            if model_specs['MODEL_TYPE'] == 'roberta':
+                tokens.append(model_specs['SEP_TOKEN'])
+                segment_ids.append(0)
         for token in history:
             tokens.append(token)
-            segment_ids.append(1)
+            if model_specs['MODEL_TYPE'] == 'roberta':
+                segment_ids.append(0)
+            else:
+                segment_ids.append(1)
         tokens.append(model_specs['SEP_TOKEN'])
-        segment_ids.append(1)
+        if model_specs['MODEL_TYPE'] == 'roberta':
+            segment_ids.append(0)
+        else:
+            segment_ids.append(1)
         input_ids = tokenizer.convert_tokens_to_ids(tokens)
 
         # The mask has 1 for real tokens and 0 for padding tokens. Only real
         # tokens are attended to.
         input_mask = [1] * len(input_ids)
 
         # Zero-pad up to the sequence length.
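# Note: RoBERTa's padding token id is 1, while BERT's is 0, hence the model-dependent padding value below.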
         while len(input_ids) < max_seq_length:
-            input_ids.append(0)
+            if model_specs['MODEL_TYPE'] == 'roberta':
+                input_ids.append(1)
+            else:
+                input_ids.append(0)
             input_mask.append(0)
             segment_ids.append(0)
         assert len(input_ids) == max_seq_length
@@ -408,6 +463,71 @@ def convert_examples_to_features(examples, slot_list, class_types, model_type, t
     return features
 
 
+# TODO: Don't start with pre-tokenized data, instead do tokenization here
+def convert_aux_examples_to_features(examples, aux_task_def, max_seq_length):
+    """Loads a data file into a list of AuxInputFeatures."""
+
+    def _get_transformer_input(tokens, type_id, max_seq_length):
+        assert len(tokens) == len(type_id)
+        input_ids = tokens
+        segment_ids = type_id
+        # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
+        input_mask = [1] * len(type_id)
+        # Zero-pad up to the sequence length.
+        while len(input_ids) < max_seq_length:
+            input_ids.append(0)
+            input_mask.append(0)
+            segment_ids.append(0)
+        assert len(input_ids) == max_seq_length
+        assert len(input_mask) == max_seq_length
+        assert len(segment_ids) == max_seq_length
+        return input_ids, input_mask, segment_ids
+
+    features = []
+    # Convert single example
+    for (example_index, example) in enumerate(examples):
+        if example_index % 1000 == 0:
+            logger.info("Writing example %d of %d" % (example_index, len(examples)))
+
+        uid = example['uid']
+        label = example['label']
+        tokens = example['token_id']
+        type_id = example['type_id']
+
+        start = 0
+        end = 0
+        if aux_task_def['task_type'] == "span":
+            start = example['start_position']
+            end = example['end_position']
+
+        # TODO: implement truncation
+        assert len(tokens) <= max_seq_length
+
+        input_ids, input_mask, segment_ids = _get_transformer_input(tokens, type_id, max_seq_length)
+
+        if example_index < 10:
+            logger.info("*** Example ***")
+            logger.info("uid: %s" % (uid))
+            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
+            logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
+            logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
+            logger.info("start_pos: %s" % str(start))
+            logger.info("end_pos: %s" % str(end))
+            logger.info("label: %s" % str(label))
+
+        features.append(
+            AuxInputFeatures(
+                uid=uid,
+                input_ids=input_ids,
+                input_mask=input_mask,
+                segment_ids=segment_ids,
+                start_pos=start,
+                end_pos=end,
+                label=label))
+
+    return features
+
+
 # From bert.tokenization (TF code)
 def convert_to_unicode(text):
     """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""