Skip to content
Snippets Groups Projects
Commit 4d7d80c5 authored by pengbaolin's avatar pengbaolin
Browse files

e2e-soloist

parent 8f0ad71d
Branches
No related tags found
No related merge requests found
# SOLOIST
On top of the pre-trained LMs, SOLOIST subsumes different components of task-oriented dialogs into a single model and emplies a pre-training then fine-tuning schema to build task bots.
## Usage
Follow the instruction under each dataset's directory to prepare data training and evaluation.
#### Dataset Creation
Create datasets of three settings.
```sh
$ cd multiwoz
$ python script/create_dataset.py joint
$ python script/create_dataset.py transfer
$ python script/create_dataset.py single
```
#### Train a model
```sh
$ python train.py --model_name_or_path t5-base --dataset_name e2e_dataloader.py --output_dir ./model --per_device_train_batch_size=2 --per_device_eval_batch_size=2 --max_target_length 128 --max_length 512 --num_train_epochs 50 --save_steps 10000 --preprocessing_num_workers 1 --num_beams 5 --learning_rate 5e-5 --dataset_config_name SINGLE --logging_steps 100
```
The model (`pytorch_model.bin`) will be saved under the `output_dir` of the config file. The script will save predictions for validation/test every epoch.
#### Test a model
The result will be saved under the `output_dir` of the config file. For evaluation, a 3rd party package is used. Please follow the instructions at https://github.com/Tomiinek/MultiWOZ_Evaluation
## Performance on unified format datasets of different settings
Note that we use almost the same hyper-parameters for different settings, which may not be optimal.
<table>
<thead>
<tr>
<th></th>
<th colspan=2>MultiWOZ 2.1</th>
<th colspan=2>SGD</th>
<th colspan=2>Taskmaster-1</th>
</tr>
</thead>
<thead>
<tr>
<th>Model</th>
<th>Combined</th><th>BLEU</th>
<th>Slot F1</th><th>BLEU</th>
<th>Slot F1</th><th>BLEU</th>
</tr>
</thead>
<tbody>
<tr>
<td>SOLOIST w/o pre-training</td>
<td>67.0</td><td>16.8</td>
<td>56.9</td><td>11.2</td>
<td>8.5</td><td>28.0</td>
</tr>
<tr>
<td>SOLOIST </td>
<td>71.4</td><td>17.1</td>
<td>69.7</td><td>23.1</td>
<td>9.2</td><td>29.2</td>
</tr>
</tbody>
</table>
- Slot F1: F1 measure of the delexicalized slot predictions over the corpus.
## References
```
@article{peng2021soloist,
title={Soloist: Buildingtask bots at scale with transfer learning and machine teaching},
author={Peng, Baolin and Li, Chunyuan and Li, Jinchao and Shayandeh, Shahin and Liden, Lars and Gao, Jianfeng},
journal={Transactions of the Association for Computational Linguistics},
volume={9},
pages={807--824},
year={2021},
publisher={MIT Press}
}
@article{nekvinda2021shades,
title={Shades of BLEU, flavours of success: The case of MultiWOZ},
author={Nekvinda, Tom{\'a}{\v{s}} and Du{\v{s}}ek, Ond{\v{r}}ej},
journal={arXiv preprint arXiv:2106.05555},
year={2021}
}
@article{peng2022godel,
title={GODEL: Large-Scale Pre-Training for Goal-Directed Dialog},
author={Peng, Baolin and Galley, Michel and He, Pengcheng and Brockett, Chris and Liden, Lars and Nouri, Elnaz and Yu, Zhou and Dolan, Bill and Gao, Jianfeng},
journal={arXiv preprint arXiv:2206.11309},
year={2022}
}
```
\ No newline at end of file
import datasets
import jsonlines
import random
# coding=utf-8
# Copyright 2020 HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Corpus for E2E Dialog Modeling"""
import csv
import datasets
_DESCRIPTION = """\
E2E Dialog Modeling
"""
_CITATION = """\
E2E Dialog Modeling
"""
_DOWNLOAD_URL = ""
_WEBPAGE = ""
class UnifiedDialogConfig(datasets.BuilderConfig):
"""BuilderConfig for SuperGLUE."""
def __init__(self, data_name, **kwargs):
"""BuilderConfig for SuperGLUE.
Args:
features: `list[string]`, list of the features that will appear in the
feature dict. Should not include "label".
data_url: `string`, url to download the zip file from.
citation: `string`, citation for the data set.
url: `string`, url for information about the data set.
label_classes: `list[string]`, the list of classes for the label if the
label is present as a string. Non-string labels will be cast to either
'False' or 'True'.
**kwargs: keyword arguments forwarded to super.
"""
# Version history:
# 1.0.2: Fixed non-nondeterminism in ReCoRD.
# 1.0.1: Change from the pre-release trial version of SuperGLUE (v1.9) to
# the full release (v2.0).
# 1.0.0: S3 (new shuffling, sharding and slicing mechanism).
# 0.0.2: Initial version.
super(UnifiedDialogConfig, self).__init__(version=datasets.Version("1.0.2"), **kwargs)
self.data_name = data_name
class Summarization(datasets.GeneratorBasedBuilder):
"""Summarization"""
BUILDER_CONFIGS = [
UnifiedDialogConfig(name='JOINT',data_name='joint'),
UnifiedDialogConfig(name='TRANSFER',data_name='transfer'),
UnifiedDialogConfig(name='SINGLE',data_name='single'),
]
random.seed(2022)
def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"Context": datasets.Value("string"),
"Knowledge": datasets.Value("string"),
"Response": datasets.Value("string"),
"Dataset": datasets.Value("string"),
}
),
homepage=_WEBPAGE,
citation=_CITATION,
)
def _split_generators(self, dl_manager):
data_name = self.config.data_name
if data_name == 'joint':
train_path = f'./multiwoz/data/joint_train.jsonl'
validation_path = f'./multiwoz/data/single_validation.jsonl'
test_path = f'./multiwoz/data/single_test.jsonl'
elif data_name == 'transfer':
train_path = f'./multiwoz/data/transfer_train.jsonl'
validation_path = f'./multiwoz/data/single_validation.jsonl'
test_path = f'./multiwoz/data/single_test.jsonl'
elif data_name == 'single':
train_path = f'./multiwoz/data/single_train.jsonl'
validation_path = f'./multiwoz/data/single_validation.jsonl'
test_path = f'./multiwoz/data/single_test.jsonl'
else:
raise('Please specific dataset config.')
return [
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": train_path}),
datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": validation_path}),
datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": test_path}),
]
def _generate_examples(self, filepath):
with open(filepath, "r", encoding="utf-8") as reader:
key = 0
for item in jsonlines.Reader(reader):
yield key, item
key += 1
\ No newline at end of file
import jsonlines
import copy
import fire
from convlab.util.unified_datasets_util import create_delex_data, load_dataset
from convlab.util import load_e2e_data
def state_to_string(state):
domain_str = []
for domain,svs in state.items():
svs_str = []
for s,v in svs.items():
if v != '':
svs_str.append(f'{s} is {v}')
svs_str = ' ; '.join(svs_str)
if svs_str != '':
domain_str.append(f'{domain} {svs_str}')
domain_str = ' | '.join(domain_str)
return domain_str
def context_to_string(context):
response = ' EOS '.join(i['utterance'].strip() for i in context)
return response
def delex_function(d,s,v):
s = s.replace(' ','')
str_ = f'[{d}_{s}]'
return str_
def create_dataset(mode='joint'):
dataset_list = {
'joint': ['tm1','sgd','multiwoz21'],
'transfer': ['tm1','sgd'],
'single': ['multiwoz21']
}
examples = []
for _data in dataset_list[mode]:
dataset = load_dataset(_data)
dataset, delex_vocab = create_delex_data(dataset, delex_func=delex_function)
e2e_data = load_e2e_data(dataset, delex_utterance = True)
split_list = ['train','validation','test'] if mode == 'single' else ['train']
for split in split_list:
data = e2e_data[split]
for i in data:
response = i['delex_utterance'].strip()
context = i['context']
context = context_to_string(context)
example = {}
example['Context'] = context
try:
knowledge = state_to_string(i['context'][-1]['state'])
except Exception:
knowledge = ''
example['Knowledge'] = knowledge
example['Response'] = 'Agent: ' + response.strip()
example['Dataset'] = f'{_data}'
examples.append(copy.copy(example))
if mode == 'single':
with jsonlines.open(f'./data/{mode}_{split}.jsonl', "w") as writer:
for item in examples:
writer.write(item)
examples = []
if mode != 'single':
with jsonlines.open(f'./data/{mode}_train.jsonl', "w") as writer:
for item in examples:
writer.write(item)
if __name__ == '__main__':
fire.Fire(create_dataset)
\ No newline at end of file
import jsonlines
import json,copy
fidx = open('test.idx.txt','w')
data = json.load(open('data/test.json'))
examples = []
for i in data:
name = i['file'].lower()
history = []
for turn in i['info']:
history.append(turn['user_orig'])
bs = turn['BS']
bs_str = []
for domain, states in bs.items():
domain_str = []
for state in states:
domain_str.append(f'{state[0]} = {state[1]}')
domain_str = ' ; '.join(domain_str)
bs_str.append(domain + ' ' + domain_str)
bs_str = ' | '.join(bs_str)
db_str = 'kb '
db = turn['KB']
if db == 0:
db_str += 'zero'
elif db_str == 1:
db_str += 'one'
elif db_str == 2:
db_str += 'two'
else:
db_str += 'more than two'
act_seq = ' '.join(turn['act'].keys())
example = {}
example['Context'] = ' EOS '.join(history[:])
example['Knowledge'] = ''
example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip()
history.append(turn['sys'].strip())
examples.append(copy.copy(example))
fidx.write(name + '\n')
writer = jsonlines.open('multiwoz_test_e2e.jsonl', mode='w')
for i in examples:
writer.write(i)
data = json.load(open('data/val.json'))
examples = []
for i in data:
name = i['file'].lower()
history = []
for turn in i['info']:
history.append(turn['user_orig'])
bs = turn['BS']
bs_str = []
for domain, states in bs.items():
domain_str = []
for state in states:
domain_str.append(f'{state[0]} = {state[1]}')
domain_str = ' ; '.join(domain_str)
bs_str.append(domain + ' ' + domain_str)
bs_str = ' | '.join(bs_str)
db_str = 'kb '
db = turn['KB']
if db == 0:
db_str += 'zero'
elif db_str == 1:
db_str += 'one'
elif db_str == 2:
db_str += 'two'
else:
db_str += 'more than two'
act_seq = ' '.join(turn['act'].keys())
example = {}
example['Context'] = ' EOS '.join(history[:])
example['Knowledge'] = ''
example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip()
history.append(turn['sys'].strip())
examples.append(copy.copy(example))
# fidx.write(name + '\n')
writer = jsonlines.open('multiwoz_valid_e2e.jsonl', mode='w')
for i in examples:
writer.write(i)
data = json.load(open('data/train.json'))
examples = []
for i in data:
name = i['file'].lower()
history = []
for turn in i['info']:
history.append(turn['user_orig'])
bs = turn['BS']
bs_str = []
for domain, states in bs.items():
domain_str = []
for state in states:
domain_str.append(f'{state[0]} = {state[1]}')
domain_str = ' ; '.join(domain_str)
bs_str.append(domain + ' ' + domain_str)
bs_str = ' | '.join(bs_str)
db_str = 'kb '
db = turn['KB']
if db == 0:
db_str += 'zero'
elif db_str == 1:
db_str += 'one'
elif db_str == 2:
db_str += 'two'
else:
db_str += 'more than two'
act_seq = ' '.join(turn['act'].keys())
example = {}
example['Context'] = ' EOS '.join(history[:])
example['Knowledge'] = ''
example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip()
history.append(turn['sys'].strip())
examples.append(copy.copy(example))
# fidx.write(name + '\n')
writer = jsonlines.open('multiwoz_train_e2e.jsonl', mode='w')
for i in examples:
writer.write(i)
import argparse
import logging
import math
import os
import random
import datasets
import nltk
import numpy as np
import torch
from datasets import load_dataset, load_metric
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
from filelock import FileLock
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AdamW,
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
SchedulerType,
get_scheduler,
set_seed,
)
from transformers.file_utils import is_offline_mode
from transformers.utils.versions import require_version
import copy, operator
from queue import PriorityQueue
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
from torch.distributions import Categorical
from convlab.e2e.soloist.multiwoz.config import global_config as cfg
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
def cuda_(var):
return var.cuda() if cfg.cuda and torch.cuda.is_available() else var
def tensor(var):
return cuda_(torch.tensor(var))
class SOLOIST:
def __init__(self) -> None:
self.config = AutoConfig.from_pretrained(cfg.model_name_or_path)
self.model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_name_or_path,config=self.config)
self.tokenizer = AutoTokenizer.from_pretrained('t5-base')
print('model loaded!')
self.model = self.model.cuda() if torch.cuda.is_available() else self.model
def generate(self, inputs):
self.model.eval()
inputs = self.tokenizer([inputs])
input_ids = tensor(inputs['input_ids'])
# generated_tokens = self.model.generate(input_ids = input_ids, max_length = cfg.max_length, num_beams = cfg.num_beams)
generated_tokens = self.model.generate(input_ids = input_ids, max_length = cfg.max_length, top_p=cfg.top_p)
decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
return decoded_preds[0]
def train_loop(self):
def preprocess_function(examples):
contextes = examples['Context']
responses = examples['Response']
belief = examples['Belief']
responses_labels = []
inputs = []
for context, response, kb in zip(contextes, responses, belief):
if cfg.no_kb:
inputs.append(context + ' => ')
else:
if cfg.format_version == 'e2e':
context = ' EOS '.join(context.split(' EOS ')[-10:])
_input = context
if cfg.format_version == 'e2e+lm':
context = ' EOS '.join(context.split(' EOS ')[-10:])
inputs.append('[E2E] ' + context)
responses_labels.append(response)
inputs.append('[LM] ' + context )
responses_labels.append(response.split(' EOS ')[1])
continue
if cfg.format_version == 'v2':
_input = kb + context
if cfg.format_version == 'v3':
_input = ''
context = context.split(' EOS ')
for idx, turn in enumerate(context):
if idx % 2 == 0:
_input += 'user : ' + turn.strip()
else:
_input += ' system : ' + turn.strip()
_input = _input + ' <|Knowledge|> ' + kb
if cfg.format_version == 'v4':
_input = ''
context = context.split(' EOS ')
for idx, turn in enumerate(context):
if idx % 2 == 0:
_input += 'user : ' + turn.strip()
else:
_input += ' system : ' + turn.strip()
_input = kb + _input
inputs.append(_input)
responses_labels.append(response)
model_inputs = self.tokenizer(inputs, max_length=cfg.max_length, padding="max_length", truncation=True)
with self.tokenizer.as_target_tokenizer():
labels = self.tokenizer(responses_labels, max_length=cfg.max_target_length, padding="max_length", truncation=True)
if cfg.ignore_pad_token_for_loss:
labels["labels"] = [
[(l if l != self.tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
]
model_inputs["labels"] = labels["labels"]
return model_inputs
raw_datasets = load_dataset(cfg.dataset_name)
column_names = ['Context','Response','Belief']
lm_datasets = raw_datasets.map(
preprocess_function,
batched=True,
remove_columns=column_names,
num_proc=cfg.preprocessing_num_workers,
load_from_cache_file=False,
desc=f"Processing dataset",
)
train_dataset = lm_datasets["test"]
# train_dataset = lm_datasets["validation"]
eval_dataset = lm_datasets["test"]
test_dataset = lm_datasets["test"]
for index in random.sample(range(len(train_dataset)), 1):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
label_pad_token_id = -100 if cfg.ignore_pad_token_for_loss else self.tokenizer.pad_token_id
accelerator = Accelerator()
logger.info(accelerator.state)
data_collator = DataCollatorForSeq2Seq(
self.tokenizer,
model=self.model,
label_pad_token_id=label_pad_token_id,
pad_to_multiple_of=8 if accelerator.use_fp16 else None,
)
train_dataloader = DataLoader(
train_dataset, shuffle=True, collate_fn=data_collator, batch_size=cfg.per_device_train_batch_size
)
eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=cfg.per_device_eval_batch_size)
test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=cfg.per_device_eval_batch_size)
# Optimizer
# Split weights in two groups, one with weight decay and the other not.
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": cfg.weight_decay,
},
{
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate)
# Prepare everything with our `accelerator`.
self.model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
self.model, optimizer, train_dataloader, eval_dataloader, test_dataloader
)
# Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
# shorter in multiprocess)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
if cfg.max_train_steps is None:
cfg.max_train_steps = cfg.num_train_epochs * num_update_steps_per_epoch
else:
cfg.num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)
lr_scheduler = get_scheduler(
name=cfg.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=cfg.num_warmup_steps,
num_training_steps=cfg.max_train_steps,
)
# Metric
# Train!
total_batch_size = cfg.per_device_train_batch_size * accelerator.num_processes * cfg.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {cfg.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {cfg.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {cfg.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {cfg.max_train_steps}")
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(cfg.max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
global_steps = 0
tr_loss, logging_loss = 0.0, 0.0
for epoch in range(cfg.num_train_epochs):
self.model.train()
# for step, batch in enumerate(train_dataloader):
for step, batch in enumerate(train_dataloader):
global_steps += 1
outputs = self.model(**batch)
loss = outputs.loss
loss = loss / cfg.gradient_accumulation_steps
tr_loss += loss.item()
accelerator.backward(loss)
if step % cfg.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
completed_steps += 1
if completed_steps >= cfg.max_train_steps:
break
if step % cfg.logging_steps == 0:
logger.info(f" EVALERR: {(tr_loss - logging_loss)/float(cfg.logging_steps)}")
logging_loss = tr_loss
progress_bar.update(cfg.logging_steps)
if cfg.output_dir is not None and global_steps % cfg.save_steps == 0 and global_steps > 0:
accelerator.wait_for_everyone()
if accelerator.is_local_main_process:
checkpoint_prefix = 'checkpoint'
output_dir = os.path.join(cfg.output_dir, '{}-{}'.format(checkpoint_prefix, global_steps))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
unwrapped_model = accelerator.unwrap_model(self.model)
unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
self.tokenizer.save_pretrained(output_dir)
torch.save(cfg, os.path.join(output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir)
\ No newline at end of file
This diff is collapsed.
...@@ -148,6 +148,7 @@ def load_unified_data( ...@@ -148,6 +148,7 @@ def load_unified_data(
dialogue_acts=False, dialogue_acts=False,
state=False, state=False,
db_results=False, db_results=False,
delex_utterance=False,
use_context=False, use_context=False,
context_window_size=0, context_window_size=0,
terminated=False, terminated=False,
...@@ -182,7 +183,7 @@ def load_unified_data( ...@@ -182,7 +183,7 @@ def load_unified_data(
data_splits = dataset.keys() if data_split == 'all' else [data_split] data_splits = dataset.keys() if data_split == 'all' else [data_split]
assert speaker in ['user', 'system', 'all'] assert speaker in ['user', 'system', 'all']
assert not use_context or context_window_size > 0 assert not use_context or context_window_size > 0
info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results'])) info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results', 'delex_utterance']))
info_list += ['utt_idx'] info_list += ['utt_idx']
data_by_split = {} data_by_split = {}
for data_split in data_splits: for data_split in data_splits:
...@@ -426,7 +427,12 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore ...@@ -426,7 +427,12 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore
for value in values.split('|'): for value in values.split('|'):
if value.lower() not in ignore_values: if value.lower() not in ignore_values:
placeholder = delex_func(domain, slot, value) placeholder = delex_func(domain, slot, value)
#TODO: value = ?
value = '\?' if value == '?' else value
try:
pattern = re.compile(r'\b({})\b'.format(value), flags=re.I) pattern = re.compile(r'\b({})\b'.format(value), flags=re.I)
except Exception:
print(value)
if delex_inplace(delex_utt, pattern): if delex_inplace(delex_utt, pattern):
delex_vocab.add(placeholder) delex_vocab.add(placeholder)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment