Commit dbe83164 authored by zhuqi, committed by GitHub
Merge pull request #55 from ConvLab/pre-training

t5dst on multiwoz21: jga 53.2 slot f1 92.0
parents c66e7429 5f2f6a44
#!/bin/bash
# Fine-tune and evaluate T5-small for dialogue state tracking (DST) on a unified dataset.
n_gpus=1
task_name="dst"
dataset_name=$1  # e.g. multiwoz21
speaker="user"
context_window_size=100
data_dir="data/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
output_dir="output/${task_name}/${dataset_name}/${speaker}/context_${context_window_size}"
cache_dir="../cache"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
test_file="${data_dir}/test.json"
metric_name_or_path="dst_metric.py"
metric_for_best_model="accuracy"
source_column="context"
target_column="state_seq"
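# DST as text-to-text: the source is the dialogue context, the target the serialized belief state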
truncation_side="left"  # when the input overflows, drop the oldest dialogue history first
max_source_length=1024
max_target_length=512
model_name_or_path="t5-small"
per_device_train_batch_size=64
per_device_eval_batch_size=64
gradient_accumulation_steps=2
lr=1e-3
num_train_epochs=10
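# effective train batch size = per_device_train_batch_size * n_gpus * gradient_accumulation_steps = 64 * 1 * 2 = 128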
# 1. serialize the unified-format data for seq2seq training
python ../create_data.py -t ${task_name} -d ${dataset_name} -s ${speaker} -c ${context_window_size} -l t5-small
# 2. fine-tune T5 on the serialized DST data
python -m torch.distributed.launch \
--nproc_per_node ${n_gpus} ../run_seq2seq.py \
--task_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--do_predict \
--save_strategy epoch \
--evaluation_strategy epoch \
--prediction_loss_only \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--debug underflow_overflow \
--adafactor \
--gradient_checkpointing
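# NOTE: --prediction_loss_only skips generation during training-time evaluation,
# so a second pass below reloads the fine-tuned model and decodes the test set.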
# 3. decode the test set with the fine-tuned checkpoint in ${output_dir}
python -m torch.distributed.launch \
--nproc_per_node ${n_gpus} ../run_seq2seq.py \
--task_name ${task_name} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${output_dir} \
--do_predict \
--predict_with_generate \
--metric_name_or_path ${metric_name_or_path} \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers 4 \
--per_device_eval_batch_size ${per_device_eval_batch_size}
# 4. merge generated predictions back into the dataset and compute DST metrics
python merge_predict_res.py -d ${dataset_name} -s ${speaker} -c ${context_window_size} -p ${output_dir}/generated_predictions.json
python ../../../dst/evaluate_unified_datasets.py -p ${output_dir}/predictions.json
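The commit message reports JGA 53.2 and slot F1 92.0 for this setup. Joint goal accuracy (JGA) counts a turn as correct only when the entire predicted state equals the gold state, while slot F1 scores individual (domain, slot, value) triples. Below is a minimal sketch of both metrics, assuming each merged prediction exposes gold and predicted states as {domain: {slot: value}} dicts; the actual schema of predictions.json and the logic in evaluate_unified_datasets.py may differ.

# Sketch only: the {domain: {slot: value}} state layout and the field
# names 'gold_state'/'pred_state' are assumptions, not taken from
# evaluate_unified_datasets.py.
def flatten(state):
    # one (domain, slot, value) triple per non-empty slot
    return {(d, s, v) for d, slots in state.items() for s, v in slots.items() if v}

def dst_metrics(examples):
    joint, tp, fp, fn = 0, 0, 0, 0
    for ex in examples:
        gold = flatten(ex['gold_state'])
        pred = flatten(ex['pred_state'])
        joint += int(gold == pred)  # JGA: exact match of the full state
        tp += len(gold & pred)
        fp += len(pred - gold)
        fn += len(gold - pred)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {'joint_goal_accuracy': joint / len(examples), 'slot_f1': f1}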
import logging
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
from convlab2.nlg.nlg import NLG
from convlab2.base_models.t5.nlu.serialization import serialize_dialogue_acts
from convlab2.util.custom_util import model_downloader
class T5NLG(NLG):
    def __init__(self, speaker, context_window_size, model_name_or_path, model_file=None, device='cuda'):
        assert speaker in ['user', 'system']
        self.speaker = speaker
        self.opponent = 'system' if speaker == 'user' else 'user'
        self.context_window_size = context_window_size
        self.use_context = context_window_size > 0
        model_dir = os.path.dirname(os.path.abspath(__file__))
        # download the checkpoint if model_name_or_path is not a local directory
        if not os.path.exists(model_name_or_path):
            model_downloader(model_dir, model_file)
        self.config = AutoConfig.from_pretrained(model_name_or_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config)
        self.model.eval()
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        logging.info("T5NLG loaded")
    def generate(self, dialogue_acts, context=list()):
        if self.use_context:
            # context items may be (speaker, utterance) pairs; keep the utterance
            if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
                context = [item[1] for item in context]
            utts = context + ['']
        else:
            utts = ['']
        # the list ends with the current speaker's empty turn, so index parity
        # relative to len(utts) determines each turn's speaker label
        input_seq = '\n'.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)])
        dialogue_acts_seq = serialize_dialogue_acts(dialogue_acts)
        input_seq = dialogue_acts_seq + '\n' + input_seq
        # print(input_seq)
        input_seq = self.tokenizer(input_seq, return_tensors="pt").to(self.device)
        output_seq = self.model.generate(**input_seq, max_length=256)
        output_seq = self.tokenizer.decode(output_seq[0], skip_special_tokens=True)
        return output_seq
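# Example (for the demo below, speaker='system' and context_window_size=0):
# the model input is the serialized dialogue acts, a newline, and the prompt
# "system: ", which the model completes with the system utterance.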
if __name__ == '__main__':
    das = [
        {"categorical": [], "non-categorical": [],
         "binary": [{"intent": "request", "domain": "taxi", "slot": "leave at"},
                    {"intent": "request", "domain": "taxi", "slot": "arrive by"}]},
        {"categorical": [],
         "non-categorical": [{"intent": "inform", "domain": "taxi", "slot": "type",
                              "value": "blue honda", "start": 38, "end": 48},
                             {"intent": "inform", "domain": "taxi", "slot": "phone",
                              "value": "07218068540", "start": 67, "end": 78}],
         "binary": [{"intent": "book", "domain": "taxi", "slot": ""}]},
        {"categorical": [], "non-categorical": [],
         "binary": [{"intent": "reqmore", "domain": "general", "slot": ""}]},
        {"categorical": [], "non-categorical": [],
         "binary": [{"intent": "bye", "domain": "general", "slot": ""}]},
    ]
    contexts = [
        ["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton."],
        ["I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
         "What time do you want to leave and what time do you want to arrive by?",
         "I want to leave after 17:15."],
        ["I want to leave after 17:15.",
         "Booking completed! your taxi will be blue honda Contact number is 07218068540",
         "Thank you for all the help! I appreciate it."],
        ["Thank you for all the help! I appreciate it.",
         "You are welcome. Is there anything else I can help you with today?",
         "No, I am all set. Have a nice day. Bye."],
    ]
    nlg = T5NLG(speaker='system', context_window_size=0, model_name_or_path='output/nlg/multiwoz21/system/context_3')
    for da, context in zip(das, contexts):
        print(da)
        print(nlg.generate(da, context))
        print()
@@ -1,8 +1,6 @@
 import logging
 import os
-import json
 import torch
-from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
 from convlab2.nlu.nlu import NLU
 from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts
@@ -16,7 +14,6 @@ class T5NLU(NLU):
         self.opponent = 'system' if speaker == 'user' else 'user'
         self.context_window_size = context_window_size
         self.use_context = context_window_size > 0
-        self.prefix = "parse the dialogue action of the last utterance: "
         model_dir = os.path.dirname(os.path.abspath(__file__))
         if not os.path.exists(model_name_or_path):
@@ -38,7 +35,7 @@ class T5NLU(NLU):
             utts = context + [utterance]
         else:
             utts = [utterance]
-        input_seq = ' '.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)])
+        input_seq = '\n'.join([f"{self.opponent if (i % 2) == (len(utts) % 2) else self.speaker}: {utt}" for i, utt in enumerate(utts)])
         # print(input_seq)
         input_seq = self.tokenizer(input_seq, return_tensors="pt").to(self.device)
         # print(input_seq)
...