Commit 2cd7eeea authored by zhuqi, committed by GitHub

Merge pull request #87 from ConvLab/e2e-soloist

E2e soloist
parents 459d1e2e 66f92372
......@@ -65,15 +65,10 @@ docker exec -it CONTAINER_ID bash
## Tutorials
| Section | Description |
| ------------------------------------------------------------ | ----------- |
| [Getting Started](https://github.com/thu-coai/ConvLab-2/blob/master/tutorials/Getting_Started.ipynb) (Have a try on [Colab](https://colab.research.google.com/github/thu-coai/ConvLab-2/blob/master/tutorials/Getting_Started.ipynb)!) | |
| [Unified Data Format](https://github.com/ConvLab/ConvLab-3/tree/master/data/unified_datasets) | |
| [Utility functions for unified datasets](https://github.com/ConvLab/ConvLab-3/blob/master/convlab/util/unified_datasets_util.py) | |
| [RL Toolkit](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/policy) | |
| [How to add a new dataset](https://github.com/thu-coai/ConvLab-2/blob/master/tutorials/Add_New_Model.md) | |
| How to add a new model | |
| [Interactive Tool](https://github.com/ConvLab/ConvLab-3/blob/master/deploy) [[demo video]](https://youtu.be/00VWzbcx26E) | |
- [Introduction to Unified Data Format](https://github.com/ConvLab/ConvLab-3/tree/master/data/unified_datasets)
- [Utility functions for unified datasets](https://github.com/ConvLab/ConvLab-3/blob/master/convlab/util/unified_datasets_util.py)
- [RL Toolkit](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/policy)
- [Interactive Tool](https://github.com/ConvLab/ConvLab-3/blob/master/deploy) [[demo video]](https://youtu.be/00VWzbcx26E)
## Unified Datasets
......@@ -112,10 +107,6 @@ We list newly integrated models in ConvLab-3 that support unified data format an
Trained models are available on [Hugging Face Hub](https://huggingface.co/ConvLab).
## Code structure
## Contributing
We welcome contributions from the community. Please see the issues to find what we need.
......@@ -131,15 +122,6 @@ We would like to thank all contributors of ConvLab:
Yan Fang, Zhuoer Feng, Jianfeng Gao, Qihan Guo, Kaili Huang, Minlie Huang, Sungjin Lee, Bing Li, Jinchao Li, Xiang Li, Xiujun Li, Jiexi Liu, Lingxiao Luo, Wenchang Ma, Mehrad Moradshahi, Baolin Peng, Runze Liang, Ryuichi Takanobu, Dazhen Wan, Hongru Wang, Jiaxin Wen, Yaoqin Zhang, Zheng Zhang, Qi Zhu, Xiaoyan Zhu, Carel van Niekerk, Christian Geishauser, Hsien-chin Lin, Nurul Lubis, Xiaochen Zhu, Michael Heck, Shutong Feng, Milica Gašić.
## Citing
If you use ConvLab-3 in your research, please cite:
```
```
## License
Apache License 2.0
# SOLOIST
On top of pre-trained LMs, SOLOIST subsumes the different components of task-oriented dialog into a single model and employs a pre-training then fine-tuning scheme to build task bots.
## Usage
Follow the instructions under each dataset's directory to prepare data for training and evaluation.
#### Dataset Creation
Create datasets of three settings.
```sh
$ cd multiwoz
$ python script/create_dataset.py joint
$ python script/create_dataset.py transfer
$ python script/create_dataset.py single
```
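Each setting is written as JSON Lines with the fields `Context`, `Knowledge`, `Response` and `Dataset`. The sketch below mirrors the string-building helpers in `script/create_dataset.py` to show what a single record looks like. The dialogue turns and the delexicalized response are made-up illustrations, not real data, and the real script obtains them from `load_e2e_data` rather than from a hand-written list.

```python
import json


def state_to_string(state):
    """Flatten a {domain: {slot: value}} belief state, skipping empty slots."""
    domain_parts = []
    for domain, slots in state.items():
        filled = [f"{slot} is {value}" for slot, value in slots.items() if value != ""]
        if filled:
            domain_parts.append(f"{domain} {' ; '.join(filled)}")
    return " | ".join(domain_parts)


def context_to_string(context):
    """Join the utterances of the dialogue history with the ' EOS ' separator."""
    return " EOS ".join(turn["utterance"].strip() for turn in context)


# Hypothetical dialogue history; the last turn carries the belief state.
context = [
    {"utterance": "I need a cheap hotel in the north."},
    {"utterance": "Sure, for how many nights?"},
    {
        "utterance": "Two nights, please.",
        "state": {"hotel": {"price range": "cheap", "area": "north", "name": ""}},
    },
]

record = {
    "Context": context_to_string(context),
    "Knowledge": state_to_string(context[-1]["state"]),
    # Hypothetical delexicalized system response.
    "Response": "Agent: " + "[hotel_name] is available .",
    "Dataset": "multiwoz21",
}
print(json.dumps(record, indent=2))
```

Empty slot values are dropped from `Knowledge`, so only slots the user has actually constrained reach the model input.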
#### Train a model
```sh
$ python train.py --model_name_or_path t5-base --dataset_name e2e_dataloader.py --output_dir ./model --per_device_train_batch_size=2 --per_device_eval_batch_size=2 --max_target_length 128 --max_length 512 --num_train_epochs 50 --save_steps 10000 --preprocessing_num_workers 1 --num_beams 5 --learning_rate 5e-5 --dataset_config_name SINGLE --logging_steps 100
```
The model (`pytorch_model.bin`) will be saved under the `output_dir` of the config file. The script will save predictions for validation/test every epoch.
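The `--dataset_config_name` flag selects one of the three builder configs (`JOINT`, `TRANSFER`, `SINGLE`). Assuming the loader script in this commit is the `e2e_dataloader.py` passed to `--dataset_name`, the mapping from config to files can be sketched as below. `split_paths` and `DATA_DIR` are illustrative names, not part of the repository.

```python
DATA_DIR = "./multiwoz/data"  # relative path used by the loader script


def split_paths(config_name):
    """Return the JSONL file used for each split under a given config name."""
    data_name = config_name.lower()
    if data_name not in {"joint", "transfer", "single"}:
        raise ValueError(f"Unknown dataset config: {config_name!r}")
    return {
        # Only the training file depends on the setting.
        "train": f"{DATA_DIR}/{data_name}_train.jsonl",
        # Validation and test always come from the single (MultiWOZ 2.1) setting.
        "validation": f"{DATA_DIR}/single_validation.jsonl",
        "test": f"{DATA_DIR}/single_test.jsonl",
    }


print(split_paths("TRANSFER"))
```

One consequence is that the `single` setting has to be created before `joint` or `transfer` can be trained, because it is the only one that writes the validation and test files.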
#### Test a model
The result will be saved under the `output_dir` of the config file. For evaluation, a 3rd party package is used. Please follow the instructions at https://github.com/Tomiinek/MultiWOZ_Evaluation
## Performance on unified format datasets of different settings
Note that we use almost the same hyper-parameters for different settings, which may not be optimal.
<table>
<thead>
<tr>
<th></th>
<th colspan=2>MultiWOZ 2.1</th>
<th colspan=2>SGD</th>
<th colspan=2>Taskmaster-1</th>
</tr>
</thead>
<thead>
<tr>
<th>Model</th>
<th>Combined</th><th>BLEU</th>
<th>Slot F1</th><th>BLEU</th>
<th>Slot F1</th><th>BLEU</th>
</tr>
</thead>
<tbody>
<tr>
<td>SOLOIST w/o pre-training</td>
<td>67.0</td><td>16.8</td>
<td>56.9</td><td>11.2</td>
<td>8.5</td><td>28.0</td>
</tr>
<tr>
<td>SOLOIST </td>
<td>71.4</td><td>17.1</td>
<td>69.7</td><td>23.1</td>
<td>9.2</td><td>29.2</td>
</tr>
</tbody>
</table>
- Slot F1: F1 measure of the delexicalized slot predictions over the corpus.
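As a rough illustration of the Slot F1 metric, the sketch below computes a micro-averaged F1 over bracketed placeholders such as `[hotel_name]` in delexicalized responses. This is an assumption about how the metric could be computed, not the evaluation script used for the table; the placeholder pattern and the `slot_f1` function are hypothetical.

```python
import re
from collections import Counter

# Assumed placeholder format: a bracketed token without spaces, e.g. [hotel_name].
PLACEHOLDER = re.compile(r"\[[^\[\]\s]+\]")


def slot_f1(predictions, references):
    """Micro-averaged F1 of placeholder occurrences over a corpus of responses."""
    tp = fp = fn = 0
    for pred, ref in zip(predictions, references):
        pred_slots = Counter(PLACEHOLDER.findall(pred))
        ref_slots = Counter(PLACEHOLDER.findall(ref))
        overlap = sum((pred_slots & ref_slots).values())
        tp += overlap
        fp += sum(pred_slots.values()) - overlap
        fn += sum(ref_slots.values()) - overlap
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


score = slot_f1(
    ["the [hotel_name] is in the [hotel_area]"],
    ["[hotel_name] costs [hotel_price]"],
)
print(f"{score:.2f}")
```

Counts are accumulated over the whole corpus before precision and recall are computed, which matches the "over the corpus" wording above.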
## References
```
@article{peng2021soloist,
title={Soloist: Building task bots at scale with transfer learning and machine teaching},
author={Peng, Baolin and Li, Chunyuan and Li, Jinchao and Shayandeh, Shahin and Liden, Lars and Gao, Jianfeng},
journal={Transactions of the Association for Computational Linguistics},
volume={9},
pages={807--824},
year={2021},
publisher={MIT Press}
}
@article{nekvinda2021shades,
title={Shades of BLEU, flavours of success: The case of MultiWOZ},
author={Nekvinda, Tom{\'a}{\v{s}} and Du{\v{s}}ek, Ond{\v{r}}ej},
journal={arXiv preprint arXiv:2106.05555},
year={2021}
}
@article{peng2022godel,
title={GODEL: Large-Scale Pre-Training for Goal-Directed Dialog},
author={Peng, Baolin and Galley, Michel and He, Pengcheng and Brockett, Chris and Liden, Lars and Nouri, Elnaz and Yu, Zhou and Dolan, Bill and Gao, Jianfeng},
journal={arXiv preprint arXiv:2206.11309},
year={2022}
}
```
\ No newline at end of file
import datasets
import jsonlines
import random
# coding=utf-8
# Copyright 2020 HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Corpus for E2E Dialog Modeling"""
import csv
import datasets
_DESCRIPTION = """\
E2E Dialog Modeling
"""
_CITATION = """\
E2E Dialog Modeling
"""
_DOWNLOAD_URL = ""
_WEBPAGE = ""
class UnifiedDialogConfig(datasets.BuilderConfig):
    """BuilderConfig for the unified dialog datasets."""

    def __init__(self, data_name, **kwargs):
        """BuilderConfig for the unified dialog datasets.

        Args:
          data_name: `string`, the setting to load: 'joint', 'transfer' or 'single'.
          **kwargs: keyword arguments forwarded to super.
        """
        super(UnifiedDialogConfig, self).__init__(version=datasets.Version("1.0.2"), **kwargs)
        self.data_name = data_name
class Summarization(datasets.GeneratorBasedBuilder):
    """Summarization"""

    BUILDER_CONFIGS = [
        UnifiedDialogConfig(name='JOINT', data_name='joint'),
        UnifiedDialogConfig(name='TRANSFER', data_name='transfer'),
        UnifiedDialogConfig(name='SINGLE', data_name='single'),
    ]
    random.seed(2022)

    def _info(self):
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "Context": datasets.Value("string"),
                    "Knowledge": datasets.Value("string"),
                    "Response": datasets.Value("string"),
                    "Dataset": datasets.Value("string"),
                }
            ),
            homepage=_WEBPAGE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        data_name = self.config.data_name
        # Validation and test always use the single-dataset files.
        if data_name == 'joint':
            train_path = f'./multiwoz/data/joint_train.jsonl'
            validation_path = f'./multiwoz/data/single_validation.jsonl'
            test_path = f'./multiwoz/data/single_test.jsonl'
        elif data_name == 'transfer':
            train_path = f'./multiwoz/data/transfer_train.jsonl'
            validation_path = f'./multiwoz/data/single_validation.jsonl'
            test_path = f'./multiwoz/data/single_test.jsonl'
        elif data_name == 'single':
            train_path = f'./multiwoz/data/single_train.jsonl'
            validation_path = f'./multiwoz/data/single_validation.jsonl'
            test_path = f'./multiwoz/data/single_test.jsonl'
        else:
            raise ValueError('Please specify a dataset config.')
        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": train_path}),
            datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": validation_path}),
            datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": test_path}),
        ]

    def _generate_examples(self, filepath):
        with open(filepath, "r", encoding="utf-8") as reader:
            key = 0
            for item in jsonlines.Reader(reader):
                yield key, item
                key += 1
\ No newline at end of file
import jsonlines
import copy
import fire
from convlab.util.unified_datasets_util import create_delex_data, load_dataset
from convlab.util import load_e2e_data
def state_to_string(state):
    domain_str = []
    for domain,svs in state.items():
        svs_str = []
        for s,v in svs.items():
            if v != '':
                svs_str.append(f'{s} is {v}')
        svs_str = ' ; '.join(svs_str)
        if svs_str != '':
            domain_str.append(f'{domain} {svs_str}')
    domain_str = ' | '.join(domain_str)
    return domain_str

def context_to_string(context):
    response = ' EOS '.join(i['utterance'].strip() for i in context)
    return response

def delex_function(d,s,v):
    s = s.replace(' ','')
    str_ = f'[{d}_{s}]'
    return str_
def create_dataset(mode='joint'):
    dataset_list = {
        'joint': ['tm1','sgd','multiwoz21'],
        'transfer': ['tm1','sgd'],
        'single': ['multiwoz21']
    }

    examples = []
    for _data in dataset_list[mode]:
        dataset = load_dataset(_data)
        dataset, delex_vocab = create_delex_data(dataset, delex_func=delex_function)
        e2e_data = load_e2e_data(dataset, delex_utterance = True)

        split_list = ['train','validation','test'] if mode == 'single' else ['train']
        for split in split_list:
            data = e2e_data[split]
            for i in data:
                response = i['delex_utterance'].strip()
                context = i['context']
                context = context_to_string(context)

                example = {}
                example['Context'] = context
                try:
                    knowledge = state_to_string(i['context'][-1]['state'])
                except Exception:
                    knowledge = ''
                example['Knowledge'] = knowledge
                example['Response'] = 'Agent: ' + response.strip()
                example['Dataset'] = f'{_data}'
                examples.append(copy.copy(example))
            if mode == 'single':
                with jsonlines.open(f'./data/{mode}_{split}.jsonl', "w") as writer:
                    for item in examples:
                        writer.write(item)
                examples = []

    if mode != 'single':
        with jsonlines.open(f'./data/{mode}_train.jsonl', "w") as writer:
            for item in examples:
                writer.write(item)

if __name__ == '__main__':
    fire.Fire(create_dataset)
\ No newline at end of file
import copy
import json

import jsonlines


def convert_split(in_path, out_path, fidx=None):
    """Convert one split to E2E jsonl; if `fidx` is given, log the dialogue name of every example."""
    with open(in_path) as f:
        data = json.load(f)
    examples = []
    for i in data:
        name = i['file'].lower()
        history = []
        for turn in i['info']:
            history.append(turn['user_orig'])
            bs = turn['BS']
            bs_str = []
            for domain, states in bs.items():
                domain_str = []
                for state in states:
                    domain_str.append(f'{state[0]} = {state[1]}')
                domain_str = ' ; '.join(domain_str)
                bs_str.append(domain + ' ' + domain_str)
            bs_str = ' | '.join(bs_str)

            # db_str and act_seq are computed but not used in the example below.
            db_str = 'kb '
            db = turn['KB']
            if db == 0:
                db_str += 'zero'
            elif db == 1:
                db_str += 'one'
            elif db == 2:
                db_str += 'two'
            else:
                db_str += 'more than two'
            act_seq = ' '.join(turn['act'].keys())

            example = {}
            example['Context'] = ' EOS '.join(history[:])
            example['Knowledge'] = ''
            example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip()
            history.append(turn['sys'].strip())
            examples.append(copy.copy(example))
            if fidx is not None:
                fidx.write(name + '\n')

    with jsonlines.open(out_path, mode='w') as writer:
        for example in examples:
            writer.write(example)


# Only the test split records dialogue names, one line per example.
with open('test.idx.txt', 'w') as fidx:
    convert_split('data/test.json', 'multiwoz_test_e2e.jsonl', fidx)
convert_split('data/val.json', 'multiwoz_valid_e2e.jsonl')
convert_split('data/train.json', 'multiwoz_train_e2e.jsonl')
import argparse
import copy
import logging
import math
import operator
import os
import random
from queue import PriorityQueue

import datasets
import nltk
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from accelerate import Accelerator
from datasets import load_dataset, load_metric
from filelock import FileLock
from torch import nn
from torch.autograd import Variable
from torch.distributions import Categorical
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AdamW,
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    SchedulerType,
    get_scheduler,
    set_seed,
)
from transformers.file_utils import is_offline_mode
from transformers.utils.versions import require_version

from convlab.e2e.soloist.multiwoz.config import global_config as cfg

logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

def cuda_(var):
    return var.cuda() if cfg.cuda and torch.cuda.is_available() else var

def tensor(var):
    return cuda_(torch.tensor(var))
class SOLOIST:

    def __init__(self) -> None:
        self.config = AutoConfig.from_pretrained(cfg.model_name_or_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_name_or_path,config=self.config)
        self.tokenizer = AutoTokenizer.from_pretrained('t5-base')
        print('model loaded!')
        self.model = self.model.cuda() if torch.cuda.is_available() else self.model

    def generate(self, inputs):
        self.model.eval()
        inputs = self.tokenizer([inputs])
        input_ids = tensor(inputs['input_ids'])
        # generated_tokens = self.model.generate(input_ids = input_ids, max_length = cfg.max_length, num_beams = cfg.num_beams)
        generated_tokens = self.model.generate(input_ids = input_ids, max_length = cfg.max_length, top_p=cfg.top_p)
        decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return decoded_preds[0]
    def train_loop(self):

        def preprocess_function(examples):
            contextes = examples['Context']
            responses = examples['Response']
            belief = examples['Belief']
            responses_labels = []
            inputs = []
            for context, response, kb in zip(contextes, responses, belief):
                if cfg.no_kb:
                    inputs.append(context + ' => ')
                else:
                    if cfg.format_version == 'e2e':
                        # Keep only the last 10 turns of the history.
                        context = ' EOS '.join(context.split(' EOS ')[-10:])
                        _input = context
                    if cfg.format_version == 'e2e+lm':
                        # Emit two examples per turn: belief + response, and response only.
                        context = ' EOS '.join(context.split(' EOS ')[-10:])
                        inputs.append('[E2E] ' + context)
                        responses_labels.append(response)
                        inputs.append('[LM] ' + context )
                        responses_labels.append(response.split(' EOS ')[1])
                        continue
                    if cfg.format_version == 'v2':
                        _input = kb + context
                    if cfg.format_version == 'v3':
                        _input = ''
                        context = context.split(' EOS ')
                        for idx, turn in enumerate(context):
                            if idx % 2 == 0:
                                _input += 'user : ' + turn.strip()
                            else:
                                _input += ' system : ' + turn.strip()
                        _input = _input + ' <|Knowledge|> ' + kb
                    if cfg.format_version == 'v4':
                        _input = ''
                        context = context.split(' EOS ')
                        for idx, turn in enumerate(context):
                            if idx % 2 == 0:
                                _input += 'user : ' + turn.strip()
                            else:
                                _input += ' system : ' + turn.strip()
                        _input = kb + _input
                    inputs.append(_input)
                responses_labels.append(response)
            model_inputs = self.tokenizer(inputs, max_length=cfg.max_length, padding="max_length", truncation=True)
            with self.tokenizer.as_target_tokenizer():
                labels = self.tokenizer(responses_labels, max_length=cfg.max_target_length, padding="max_length", truncation=True)
            label_ids = labels["input_ids"]
            if cfg.ignore_pad_token_for_loss:
                # Replace pad tokens with -100 so the loss ignores them.
                label_ids = [
                    [(l if l != self.tokenizer.pad_token_id else -100) for l in label] for label in label_ids
                ]
            model_inputs["labels"] = label_ids
            return model_inputs
        raw_datasets = load_dataset(cfg.dataset_name)
        column_names = ['Context','Response','Belief']
        lm_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            remove_columns=column_names,
            num_proc=cfg.preprocessing_num_workers,
            load_from_cache_file=False,
            desc=f"Processing dataset",
        )
        # NOTE: as committed, training, evaluation and testing all read the "test" split.
        train_dataset = lm_datasets["test"]
        # train_dataset = lm_datasets["validation"]
        eval_dataset = lm_datasets["test"]
        test_dataset = lm_datasets["test"]
        for index in random.sample(range(len(train_dataset)), 1):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
        label_pad_token_id = -100 if cfg.ignore_pad_token_for_loss else self.tokenizer.pad_token_id
        accelerator = Accelerator()
        logger.info(accelerator.state)
        data_collator = DataCollatorForSeq2Seq(
            self.tokenizer,
            model=self.model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if accelerator.use_fp16 else None,
        )
        train_dataloader = DataLoader(
            train_dataset, shuffle=True, collate_fn=data_collator, batch_size=cfg.per_device_train_batch_size
        )
        eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=cfg.per_device_eval_batch_size)
        test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=cfg.per_device_eval_batch_size)
        # Optimizer
        # Split weights in two groups, one with weight decay and the other not.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": cfg.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate)
        # Prepare everything with our `accelerator`.
        self.model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
            self.model, optimizer, train_dataloader, eval_dataloader, test_dataloader
        )
        # Note: the training dataloader must be prepared before its length is read below,
        # because it is shorter when running with multiple processes.
        # Scheduler and math around the number of training steps.
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
        if cfg.max_train_steps is None:
            cfg.max_train_steps = cfg.num_train_epochs * num_update_steps_per_epoch
        else:
            cfg.num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)
        lr_scheduler = get_scheduler(
            name=cfg.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=cfg.num_warmup_steps,
            num_training_steps=cfg.max_train_steps,
        )
        # Metric
        # Train!
        total_batch_size = cfg.per_device_train_batch_size * accelerator.num_processes * cfg.gradient_accumulation_steps
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(train_dataset)}")
        logger.info(f" Num Epochs = {cfg.num_train_epochs}")
        logger.info(f" Instantaneous batch size per device = {cfg.per_device_train_batch_size}")
        logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f" Gradient Accumulation steps = {cfg.gradient_accumulation_steps}")
        logger.info(f" Total optimization steps = {cfg.max_train_steps}")
        # Only show the progress bar once on each machine.
        progress_bar = tqdm(range(cfg.max_train_steps), disable=not accelerator.is_local_main_process)
        completed_steps = 0
        global_steps = 0
        tr_loss, logging_loss = 0.0, 0.0
        for epoch in range(cfg.num_train_epochs):
            self.model.train()
            # for step, batch in enumerate(train_dataloader):
            for step, batch in enumerate(train_dataloader):
                global_steps += 1
                outputs = self.model(**batch)
                loss = outputs.loss
                loss = loss / cfg.gradient_accumulation_steps
                tr_loss += loss.item()
                accelerator.backward(loss)
                if step % cfg.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    completed_steps += 1
                if completed_steps >= cfg.max_train_steps:
                    break
                if step % cfg.logging_steps == 0:
                    logger.info(f" EVALERR: {(tr_loss - logging_loss)/float(cfg.logging_steps)}")
                    logging_loss = tr_loss
                    progress_bar.update(cfg.logging_steps)
                if cfg.output_dir is not None and global_steps % cfg.save_steps == 0 and global_steps > 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_local_main_process:
                        checkpoint_prefix = 'checkpoint'
                        output_dir = os.path.join(cfg.output_dir, '{}-{}'.format(checkpoint_prefix, global_steps))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        unwrapped_model = accelerator.unwrap_model(self.model)
                        unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
                        self.tokenizer.save_pretrained(output_dir)
                        torch.save(cfg, os.path.join(output_dir, 'training_args.bin'))
                        logger.info("Saving model checkpoint to %s", output_dir)
\ No newline at end of file
This diff is collapsed.
......@@ -166,6 +166,7 @@ def load_unified_data(
dialogue_acts=False,
state=False,
db_results=False,
delex_utterance=False,
use_context=False,
context_window_size=0,
terminated=False,
......@@ -200,8 +201,7 @@ def load_unified_data(
data_splits = dataset.keys() if data_split == 'all' else [data_split]
assert speaker in ['user', 'system', 'all']
assert not use_context or context_window_size > 0
info_list = list(
filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results']))
info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results', 'delex_utterance']))
info_list += ['utt_idx']
data_by_split = {}
for data_split in data_splits:
......@@ -452,10 +452,13 @@ def create_delex_data(dataset, delex_func=lambda d, s, v: f'[({d})-({s})]', igno
# has value
for value in values.split('|'):
if value.lower() not in ignore_values:
placeholder = delex_func(
domain, slot, value)
pattern = re.compile(
r'\b({})\b'.format(value), flags=re.I)
placeholder = delex_func(domain, slot, value)
#TODO: value = ?
value = '\?' if value == '?' else value
try:
pattern = re.compile(r'\b({})\b'.format(value), flags=re.I)
except Exception:
print(value)
if delex_inplace(delex_utt, pattern):
delex_vocab.add(placeholder)
......