From 4d7d80c5a00acf5de7f46d698f4ccc5d004c7ba2 Mon Sep 17 00:00:00 2001
From: pengbaolin <bapeng@bapengvmcpu.boshzpqespbehhxtd0qmc0nvoa.xx.internal.cloudapp.net>
Date: Tue, 29 Nov 2022 22:39:45 +0000
Subject: [PATCH] e2e-soloist

---
 convlab/e2e/soloist/README.md                 |  95 ++
 convlab/e2e/soloist/e2e_dataloader.py         | 124 +++
 .../soloist/multiwoz/script/create_dataset.py |  74 ++
 .../multiwoz/script/create_mwoz_e2e_json.py   | 136 ---
 convlab/e2e/soloist/multiwoz/soloist_net.py   | 277 ------
 convlab/e2e/soloist/train.py                  | 836 ++++++++++++++++++
 convlab/util/unified_datasets_util.py         |  10 +-
 7 files changed, 1137 insertions(+), 415 deletions(-)
 create mode 100644 convlab/e2e/soloist/README.md
 create mode 100644 convlab/e2e/soloist/e2e_dataloader.py
 create mode 100644 convlab/e2e/soloist/multiwoz/script/create_dataset.py
 delete mode 100644 convlab/e2e/soloist/multiwoz/script/create_mwoz_e2e_json.py
 delete mode 100644 convlab/e2e/soloist/multiwoz/soloist_net.py
 create mode 100644 convlab/e2e/soloist/train.py

diff --git a/convlab/e2e/soloist/README.md b/convlab/e2e/soloist/README.md
new file mode 100644
index 00000000..12367222
--- /dev/null
+++ b/convlab/e2e/soloist/README.md
@@ -0,0 +1,95 @@
+# SOLOIST
+
+On top of pre-trained LMs, SOLOIST subsumes the different components of a task-oriented dialog system into a single model and employs a pre-training-then-fine-tuning schema to build task bots.
+
+## Usage
+
+Follow the instructions under each dataset's directory to prepare the data for training and evaluation.
+
+#### Dataset Creation
+Create the datasets for the three settings:
+```sh
+$ cd multiwoz
+$ python script/create_dataset.py joint
+$ python script/create_dataset.py transfer
+$ python script/create_dataset.py single
+```
+
+#### Train a model
+
+```sh
+$ python train.py --model_name_or_path t5-base --dataset_name e2e_dataloader.py --output_dir ./model --per_device_train_batch_size=2 --per_device_eval_batch_size=2 --max_target_length 128 --max_length 512 --num_train_epochs 50 --save_steps 10000 --preprocessing_num_workers 1 --num_beams 5 --learning_rate 5e-5 --dataset_config_name SINGLE --logging_steps 100
+```
+
+The model (`pytorch_model.bin`) will be saved under the directory given by `--output_dir`. The script also saves validation/test predictions every epoch.
+
+#### Test a model
+
+Results are saved under the directory given by `--output_dir`. Evaluation relies on a third-party package; please follow the instructions at https://github.com/Tomiinek/MultiWOZ_Evaluation
+
+## Performance on unified-format datasets under different settings
+
+Note that we use almost the same hyper-parameters for all settings, which may not be optimal.
+
+<table>
+<thead>
+  <tr>
+    <th></th>
+    <th colspan=2>MultiWOZ 2.1</th>
+    <th colspan=2>SGD</th>
+    <th colspan=2>Taskmaster-1</th>
+  </tr>
+</thead>
+<thead>
+  <tr>
+    <th>Model</th>
+    <th>Combined</th><th>BLEU</th>
+    <th>Slot F1</th><th>BLEU</th>
+    <th>Slot F1</th><th>BLEU</th>
+  </tr>
+</thead>
+<tbody>
+  <tr>
+    <td>SOLOIST w/o pre-training</td>
+    <td>67.0</td><td>16.8</td>
+    <td>56.9</td><td>11.2</td>
+    <td>8.5</td><td>28.0</td>
+  </tr>
+  <tr>
+    <td>SOLOIST</td>
+    <td>71.4</td><td>17.1</td>
+    <td>69.7</td><td>23.1</td>
+    <td>9.2</td><td>29.2</td>
+  </tr>
+</tbody>
+</table>
+
+- Slot F1: F1 measure of the delexicalized slot predictions over the corpus.
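+
+As a reference for the **Test a model** step, scoring with the third-party evaluator could look like the sketch below. This assumes the `mwzeval` package from the repository linked above; the `predictions` dict is a hypothetical stand-in for the generations saved by `train.py`:
+
+```python
+from mwzeval.metrics import Evaluator
+
+# predictions maps dialogue ids to system turns, e.g.
+predictions = {"pmul4462": [{"response": "There are 5 hotels in the north. Do you have a price range?"}]}
+
+evaluator = Evaluator(bleu=True, success=True, richness=True)
+print(evaluator.evaluate(predictions))
+```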
+
+## References
+
+```
+@article{peng2021soloist,
+  title={Soloist: Building task bots at scale with transfer learning and machine teaching},
+  author={Peng, Baolin and Li, Chunyuan and Li, Jinchao and Shayandeh, Shahin and Liden, Lars and Gao, Jianfeng},
+  journal={Transactions of the Association for Computational Linguistics},
+  volume={9},
+  pages={807--824},
+  year={2021},
+  publisher={MIT Press}
+}
+@article{nekvinda2021shades,
+  title={Shades of BLEU, flavours of success: The case of MultiWOZ},
+  author={Nekvinda, Tom{\'a}{\v{s}} and Du{\v{s}}ek, Ond{\v{r}}ej},
+  journal={arXiv preprint arXiv:2106.05555},
+  year={2021}
+}
+@article{peng2022godel,
+  title={GODEL: Large-Scale Pre-Training for Goal-Directed Dialog},
+  author={Peng, Baolin and Galley, Michel and He, Pengcheng and Brockett, Chris and Liden, Lars and Nouri, Elnaz and Yu, Zhou and Dolan, Bill and Gao, Jianfeng},
+  journal={arXiv preprint arXiv:2206.11309},
+  year={2022}
+}
+```
\ No newline at end of file
diff --git a/convlab/e2e/soloist/e2e_dataloader.py b/convlab/e2e/soloist/e2e_dataloader.py
new file mode 100644
index 00000000..ac7be4d2
--- /dev/null
+++ b/convlab/e2e/soloist/e2e_dataloader.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2020 HuggingFace Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+# Lint as: python3
+"""Corpus for E2E Dialog Modeling"""
+
+import random
+
+import datasets
+import jsonlines
+
+_DESCRIPTION = """\
+E2E Dialog Modeling
+"""
+
+_CITATION = """\
+E2E Dialog Modeling
+"""
+
+_DOWNLOAD_URL = ""
+_WEBPAGE = ""
+
+class UnifiedDialogConfig(datasets.BuilderConfig):
+    """BuilderConfig for the unified E2E dialog corpora."""
+
+    def __init__(self, data_name, **kwargs):
+        """BuilderConfig for the unified E2E dialog corpora.
+        Args:
+          data_name: `string`, which group of jsonl files to load
+            ('joint', 'transfer' or 'single').
+          **kwargs: keyword arguments forwarded to super.
+        """
+        super(UnifiedDialogConfig, self).__init__(version=datasets.Version("1.0.2"), **kwargs)
+        self.data_name = data_name
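+
+# Usage sketch (illustration only, not part of the patch): the builder-config
+# *name* passed to `datasets.load_dataset` selects one of the configs above, e.g.
+#
+#   from datasets import load_dataset
+#   raw = load_dataset('e2e_dataloader.py', 'SINGLE')
+#   print(raw['train'][0]['Context'])
+#
+# 'JOINT' / 'TRANSFER' / 'SINGLE' select the jsonl files read in
+# `_split_generators` below.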
+
+class Summarization(datasets.GeneratorBasedBuilder):
+    """E2E dialog modeling datasets in the unified format."""
+
+    BUILDER_CONFIGS = [
+        UnifiedDialogConfig(name='JOINT', data_name='joint'),
+        UnifiedDialogConfig(name='TRANSFER', data_name='transfer'),
+        UnifiedDialogConfig(name='SINGLE', data_name='single'),
+    ]
+
+    random.seed(2022)
+
+    def _info(self):
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=datasets.Features(
+                {
+                    "Context": datasets.Value("string"),
+                    "Knowledge": datasets.Value("string"),
+                    "Response": datasets.Value("string"),
+                    "Dataset": datasets.Value("string"),
+                }
+            ),
+            homepage=_WEBPAGE,
+            citation=_CITATION,
+        )
+
+    def _split_generators(self, dl_manager):
+
+        data_name = self.config.data_name
+
+        # The settings differ only in their training files; validation and test
+        # always come from the single-dataset (MultiWOZ) splits.
+        if data_name == 'joint':
+            train_path = './multiwoz/data/joint_train.jsonl'
+        elif data_name == 'transfer':
+            train_path = './multiwoz/data/transfer_train.jsonl'
+        elif data_name == 'single':
+            train_path = './multiwoz/data/single_train.jsonl'
+        else:
+            raise ValueError('Please specify a dataset config (JOINT, TRANSFER or SINGLE).')
+        validation_path = './multiwoz/data/single_validation.jsonl'
+        test_path = './multiwoz/data/single_test.jsonl'
+
+        return [
+            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": train_path}),
+            datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": validation_path}),
+            datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": test_path}),
+        ]
+
+    def _generate_examples(self, filepath):
+
+        with open(filepath, "r", encoding="utf-8") as reader:
+            key = 0
+            for item in jsonlines.Reader(reader):
+                yield key, item
+                key += 1
\ No newline at end of file
diff --git a/convlab/e2e/soloist/multiwoz/script/create_dataset.py b/convlab/e2e/soloist/multiwoz/script/create_dataset.py
new file mode 100644
index 00000000..2426af9f
--- /dev/null
+++ b/convlab/e2e/soloist/multiwoz/script/create_dataset.py
@@ -0,0 +1,74 @@
+import jsonlines
+import copy
+import fire
+
+from convlab.util.unified_datasets_util import create_delex_data, load_dataset
+from convlab.util import load_e2e_data
+
+def state_to_string(state):
+    # Serialize a belief state as 'domain slot is value ; ... | domain ...'
+    domain_str = []
+    for domain, svs in state.items():
+        svs_str = []
+        for s, v in svs.items():
+            if v != '':
+                svs_str.append(f'{s} is {v}')
+        svs_str = ' ; '.join(svs_str)
+        if svs_str != '':
+            domain_str.append(f'{domain} {svs_str}')
+    domain_str = ' | '.join(domain_str)
+    return domain_str
+
+def context_to_string(context):
+    response = ' EOS '.join(i['utterance'].strip() for i in context)
+    return response
+
+def delex_function(d, s, v):
+    # Placeholder format for delexicalized values, e.g. '[hotel_pricerange]'
+    s = s.replace(' ', '')
+    str_ = f'[{d}_{s}]'
+    return str_
+
+def create_dataset(mode='joint'):
+    dataset_list = {
+        'joint': ['tm1', 'sgd', 'multiwoz21'],
+        'transfer': ['tm1', 'sgd'],
+        'single': ['multiwoz21']
+    }
+
+    examples = []
+    for _data in dataset_list[mode]:
+
+        dataset = load_dataset(_data)
+        dataset, delex_vocab = create_delex_data(dataset, delex_func=delex_function)
+        e2e_data = load_e2e_data(dataset, delex_utterance=True)
+
+        split_list = ['train', 'validation', 'test'] if mode == 'single' else ['train']
+
+        for split in split_list:
+            data = e2e_data[split]
+            for i in data:
+                response = i['delex_utterance'].strip()
+                context = i['context']
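+                # context_to_string (defined above) flattens the turn list with
+                # ' EOS ' separators, e.g.
+                #   [{'utterance': 'i need a hotel'}, {'utterance': 'Sure, which area?'}]
+                #   -> 'i need a hotel EOS Sure, which area?'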
context = context_to_string(context) + + example = {} + example['Context'] = context + try: + knowledge = state_to_string(i['context'][-1]['state']) + except Exception: + knowledge = '' + example['Knowledge'] = knowledge + example['Response'] = 'Agent: ' + response.strip() + example['Dataset'] = f'{_data}' + examples.append(copy.copy(example)) + if mode == 'single': + with jsonlines.open(f'./data/{mode}_{split}.jsonl', "w") as writer: + for item in examples: + writer.write(item) + examples = [] + if mode != 'single': + with jsonlines.open(f'./data/{mode}_train.jsonl', "w") as writer: + for item in examples: + writer.write(item) + +if __name__ == '__main__': + fire.Fire(create_dataset) \ No newline at end of file diff --git a/convlab/e2e/soloist/multiwoz/script/create_mwoz_e2e_json.py b/convlab/e2e/soloist/multiwoz/script/create_mwoz_e2e_json.py deleted file mode 100644 index 5953c19f..00000000 --- a/convlab/e2e/soloist/multiwoz/script/create_mwoz_e2e_json.py +++ /dev/null @@ -1,136 +0,0 @@ -import jsonlines -import json,copy -fidx = open('test.idx.txt','w') - -data = json.load(open('data/test.json')) -examples = [] -for i in data: - name = i['file'].lower() - history = [] - for turn in i['info']: - history.append(turn['user_orig']) - - bs = turn['BS'] - bs_str = [] - for domain, states in bs.items(): - domain_str = [] - for state in states: - domain_str.append(f'{state[0]} = {state[1]}') - domain_str = ' ; '.join(domain_str) - bs_str.append(domain + ' ' + domain_str) - bs_str = ' | '.join(bs_str) - - db_str = 'kb ' - db = turn['KB'] - if db == 0: - db_str += 'zero' - elif db_str == 1: - db_str += 'one' - elif db_str == 2: - db_str += 'two' - else: - db_str += 'more than two' - - act_seq = ' '.join(turn['act'].keys()) - example = {} - example['Context'] = ' EOS '.join(history[:]) - example['Knowledge'] = '' - example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip() - - history.append(turn['sys'].strip()) - examples.append(copy.copy(example)) - fidx.write(name + '\n') - -writer = jsonlines.open('multiwoz_test_e2e.jsonl', mode='w') -for i in examples: - writer.write(i) - - -data = json.load(open('data/val.json')) -examples = [] -for i in data: - name = i['file'].lower() - history = [] - for turn in i['info']: - history.append(turn['user_orig']) - - - bs = turn['BS'] - bs_str = [] - for domain, states in bs.items(): - domain_str = [] - for state in states: - domain_str.append(f'{state[0]} = {state[1]}') - domain_str = ' ; '.join(domain_str) - bs_str.append(domain + ' ' + domain_str) - bs_str = ' | '.join(bs_str) - - db_str = 'kb ' - db = turn['KB'] - if db == 0: - db_str += 'zero' - elif db_str == 1: - db_str += 'one' - elif db_str == 2: - db_str += 'two' - else: - db_str += 'more than two' - - act_seq = ' '.join(turn['act'].keys()) - example = {} - example['Context'] = ' EOS '.join(history[:]) - example['Knowledge'] = '' - example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip() - - history.append(turn['sys'].strip()) - examples.append(copy.copy(example)) - # fidx.write(name + '\n') - -writer = jsonlines.open('multiwoz_valid_e2e.jsonl', mode='w') -for i in examples: - writer.write(i) - - -data = json.load(open('data/train.json')) -examples = [] -for i in data: - name = i['file'].lower() - history = [] - for turn in i['info']: - history.append(turn['user_orig']) - - - bs = turn['BS'] - bs_str = [] - for domain, states in bs.items(): - domain_str = [] - for state in states: - domain_str.append(f'{state[0]} = {state[1]}') - domain_str = ' ; '.join(domain_str) 
- bs_str.append(domain + ' ' + domain_str) - bs_str = ' | '.join(bs_str) - - db_str = 'kb ' - db = turn['KB'] - if db == 0: - db_str += 'zero' - elif db_str == 1: - db_str += 'one' - elif db_str == 2: - db_str += 'two' - else: - db_str += 'more than two' - - act_seq = ' '.join(turn['act'].keys()) - example = {} - example['Context'] = ' EOS '.join(history[:]) - example['Knowledge'] = '' - example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip() - - history.append(turn['sys'].strip()) - examples.append(copy.copy(example)) - # fidx.write(name + '\n') - -writer = jsonlines.open('multiwoz_train_e2e.jsonl', mode='w') -for i in examples: - writer.write(i) diff --git a/convlab/e2e/soloist/multiwoz/soloist_net.py b/convlab/e2e/soloist/multiwoz/soloist_net.py deleted file mode 100644 index 45f98200..00000000 --- a/convlab/e2e/soloist/multiwoz/soloist_net.py +++ /dev/null @@ -1,277 +0,0 @@ -import argparse -import logging -import math -import os -import random - -import datasets -import nltk -import numpy as np -import torch -from datasets import load_dataset, load_metric -from torch.utils.data.dataloader import DataLoader -from tqdm.auto import tqdm - -import transformers -from accelerate import Accelerator -from filelock import FileLock -from transformers import ( - CONFIG_MAPPING, - MODEL_MAPPING, - AdamW, - AutoConfig, - AutoModelForSeq2SeqLM, - AutoTokenizer, - DataCollatorForSeq2Seq, - SchedulerType, - get_scheduler, - set_seed, -) -from transformers.file_utils import is_offline_mode -from transformers.utils.versions import require_version - -import copy, operator -from queue import PriorityQueue -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn -from torch.autograd import Variable -from torch.distributions import Categorical -from convlab.e2e.soloist.multiwoz.config import global_config as cfg - -logger = logging.getLogger(__name__) -logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - -def cuda_(var): - return var.cuda() if cfg.cuda and torch.cuda.is_available() else var - - -def tensor(var): - return cuda_(torch.tensor(var)) - -class SOLOIST: - - def __init__(self) -> None: - - self.config = AutoConfig.from_pretrained(cfg.model_name_or_path) - self.model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_name_or_path,config=self.config) - self.tokenizer = AutoTokenizer.from_pretrained('t5-base') - print('model loaded!') - - self.model = self.model.cuda() if torch.cuda.is_available() else self.model - - def generate(self, inputs): - - self.model.eval() - inputs = self.tokenizer([inputs]) - input_ids = tensor(inputs['input_ids']) - # generated_tokens = self.model.generate(input_ids = input_ids, max_length = cfg.max_length, num_beams = cfg.num_beams) - generated_tokens = self.model.generate(input_ids = input_ids, max_length = cfg.max_length, top_p=cfg.top_p) - decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) - - return decoded_preds[0] - - - def train_loop(self): - - def preprocess_function(examples): - contextes = examples['Context'] - responses = examples['Response'] - belief = examples['Belief'] - responses_labels = [] - inputs = [] - - for context, response, kb in zip(contextes, responses, belief): - if cfg.no_kb: - inputs.append(context + ' => ') - else: - - if cfg.format_version == 'e2e': - context = ' EOS '.join(context.split(' EOS ')[-10:]) - _input = context - - if cfg.format_version == 'e2e+lm': - 
context = ' EOS '.join(context.split(' EOS ')[-10:]) - inputs.append('[E2E] ' + context) - responses_labels.append(response) - inputs.append('[LM] ' + context ) - responses_labels.append(response.split(' EOS ')[1]) - continue - - if cfg.format_version == 'v2': - _input = kb + context - - if cfg.format_version == 'v3': - _input = '' - context = context.split(' EOS ') - for idx, turn in enumerate(context): - if idx % 2 == 0: - _input += 'user : ' + turn.strip() - else: - _input += ' system : ' + turn.strip() - _input = _input + ' <|Knowledge|> ' + kb - - if cfg.format_version == 'v4': - _input = '' - context = context.split(' EOS ') - for idx, turn in enumerate(context): - if idx % 2 == 0: - _input += 'user : ' + turn.strip() - else: - _input += ' system : ' + turn.strip() - _input = kb + _input - - inputs.append(_input) - responses_labels.append(response) - model_inputs = self.tokenizer(inputs, max_length=cfg.max_length, padding="max_length", truncation=True) - - - with self.tokenizer.as_target_tokenizer(): - labels = self.tokenizer(responses_labels, max_length=cfg.max_target_length, padding="max_length", truncation=True) - - - if cfg.ignore_pad_token_for_loss: - labels["labels"] = [ - [(l if l != self.tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] - ] - - model_inputs["labels"] = labels["labels"] - return model_inputs - - raw_datasets = load_dataset(cfg.dataset_name) - column_names = ['Context','Response','Belief'] - lm_datasets = raw_datasets.map( - preprocess_function, - batched=True, - remove_columns=column_names, - num_proc=cfg.preprocessing_num_workers, - load_from_cache_file=False, - desc=f"Processing dataset", - ) - - train_dataset = lm_datasets["test"] - # train_dataset = lm_datasets["validation"] - eval_dataset = lm_datasets["test"] - test_dataset = lm_datasets["test"] - for index in random.sample(range(len(train_dataset)), 1): - logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") - - label_pad_token_id = -100 if cfg.ignore_pad_token_for_loss else self.tokenizer.pad_token_id - - accelerator = Accelerator() - logger.info(accelerator.state) - data_collator = DataCollatorForSeq2Seq( - self.tokenizer, - model=self.model, - label_pad_token_id=label_pad_token_id, - pad_to_multiple_of=8 if accelerator.use_fp16 else None, - ) - - - train_dataloader = DataLoader( - train_dataset, shuffle=True, collate_fn=data_collator, batch_size=cfg.per_device_train_batch_size - ) - eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=cfg.per_device_eval_batch_size) - test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=cfg.per_device_eval_batch_size) - - # Optimizer - # Split weights in two groups, one with weight decay and the other not. - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], - "weight_decay": cfg.weight_decay, - }, - { - "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], - "weight_decay": 0.0, - }, - ] - optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate) - - # Prepare everything with our `accelerator`. 
- self.model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare( - self.model, optimizer, train_dataloader, eval_dataloader, test_dataloader - ) - - # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be - # shorter in multiprocess) - - # Scheduler and math around the number of training steps. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps) - if cfg.max_train_steps is None: - cfg.max_train_steps = cfg.num_train_epochs * num_update_steps_per_epoch - else: - cfg.num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch) - - lr_scheduler = get_scheduler( - name=cfg.lr_scheduler_type, - optimizer=optimizer, - num_warmup_steps=cfg.num_warmup_steps, - num_training_steps=cfg.max_train_steps, - ) - - # Metric - - # Train! - total_batch_size = cfg.per_device_train_batch_size * accelerator.num_processes * cfg.gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {cfg.num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {cfg.per_device_train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {cfg.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {cfg.max_train_steps}") - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(cfg.max_train_steps), disable=not accelerator.is_local_main_process) - completed_steps = 0 - global_steps = 0 - tr_loss, logging_loss = 0.0, 0.0 - for epoch in range(cfg.num_train_epochs): - self.model.train() - # for step, batch in enumerate(train_dataloader): - for step, batch in enumerate(train_dataloader): - global_steps += 1 - outputs = self.model(**batch) - loss = outputs.loss - loss = loss / cfg.gradient_accumulation_steps - tr_loss += loss.item() - accelerator.backward(loss) - - if step % cfg.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - completed_steps += 1 - - if completed_steps >= cfg.max_train_steps: - break - - if step % cfg.logging_steps == 0: - logger.info(f" EVALERR: {(tr_loss - logging_loss)/float(cfg.logging_steps)}") - logging_loss = tr_loss - progress_bar.update(cfg.logging_steps) - - if cfg.output_dir is not None and global_steps % cfg.save_steps == 0 and global_steps > 0: - - accelerator.wait_for_everyone() - if accelerator.is_local_main_process: - checkpoint_prefix = 'checkpoint' - output_dir = os.path.join(cfg.output_dir, '{}-{}'.format(checkpoint_prefix, global_steps)) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - unwrapped_model = accelerator.unwrap_model(self.model) - unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save) - - self.tokenizer.save_pretrained(output_dir) - torch.save(cfg, os.path.join(output_dir, 'training_args.bin')) - logger.info("Saving model checkpoint to %s", output_dir) - - - \ No newline at end of file diff --git a/convlab/e2e/soloist/train.py b/convlab/e2e/soloist/train.py new file mode 100644 index 00000000..1c207895 --- /dev/null +++ b/convlab/e2e/soloist/train.py @@ -0,0 +1,836 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fine-tuning a 🤗 Transformers model on end-to-end task-oriented dialog
+(adapted from the 🤗 summarization example).
+"""
+# Pointers for adapting this script to your own dataset are left as comments.
+
+import argparse
+import logging
+import math
+import os
+import random
+import json
+
+import datasets
+import nltk
+import numpy as np
+import torch
+from datasets import load_dataset, load_metric
+from torch.utils.data.dataloader import DataLoader
+from tqdm.auto import tqdm
+
+import transformers
+from accelerate import Accelerator
+from filelock import FileLock
+from transformers import (
+    CONFIG_MAPPING,
+    MODEL_MAPPING,
+    AdamW,
+    AutoConfig,
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    DataCollatorForSeq2Seq,
+    SchedulerType,
+    get_scheduler,
+    set_seed,
+)
+from transformers.file_utils import is_offline_mode
+from transformers.utils.versions import require_version
+
+from nltk.tokenize import TweetTokenizer
+import re
+
+re_art = re.compile(r'\b(a|an|the)\b')
+re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
+
+
+def normalize_answer(s):
+    # Pass-through normalization hook; `re_art`/`re_punc` above are kept for
+    # optional article/punctuation stripping.
+    return s
+
+def clean_str(txt):
+    txt = txt.lower()
+    txt = re.sub('^', ' ', txt)
+    txt = re.sub('$', ' ', txt)
+
+    # url and tag
+    words = []
+    for word in txt.split():
+        i = word.find('http')
+        if i >= 0:
+            word = word[:i] + ' ' + '__url__'
+        words.append(word.strip())
+    txt = ' '.join(words)
+
+    # remove markdown URL
+    txt = re.sub(r'\[([^\]]*)\] \( *__url__ *\)', r'\1', txt)
+
+    # remove illegal char
+    txt = re.sub('__url__', 'URL', txt)
+    txt = re.sub(r"[^A-Za-z0-9():,.!?\"\']", " ", txt)
+    txt = re.sub('URL', '__url__', txt)
+
+    # contraction
+    add_space = ["'s", "'m", "'re", "n't", "'ll", "'ve", "'d", "'em"]
+    tokenizer = TweetTokenizer(preserve_case=False)
+    txt = ' ' + ' '.join(tokenizer.tokenize(txt)) + ' '
+    txt = txt.replace(" won't ", " will n't ")
+    txt = txt.replace(" can't ", " can n't ")
+    for a in add_space:
+        txt = txt.replace(a + ' ', ' ' + a + ' ')
+
+    txt = re.sub(r'^\s+', '', txt)
+    txt = re.sub(r'\s+$', '', txt)
+    txt = re.sub(r'\s+', ' ', txt)  # remove extra spaces
+
+    return txt
+
+logger = logging.getLogger(__name__)
+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
+
+
+from dotenv import load_dotenv
+load_dotenv()
+if os.getenv('WANDB_API_KEY') is None:
+    USE_WANDB = False
+else:
+    USE_WANDB = True
+    wandb_key = os.getenv('WANDB_API_KEY')
+
+# You should update this to your particular problem to have better documentation of `model_type`
+MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+try:
+    nltk.data.find("tokenizers/punkt")
+except (LookupError, OSError):
+    if is_offline_mode():
+        raise LookupError(
+            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
+        )
+    with FileLock(".lock") as lock:
+        nltk.download("punkt", quiet=True)
+
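+# Example invocation (an illustrative subset of the full README command; all
+# flags are defined in parse_args below):
+#   python train.py --model_name_or_path t5-base --dataset_name e2e_dataloader.py \
+#       --output_dir ./model --dataset_config_name SINGLE --num_beams 5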
+def parse_args():
+    parser = argparse.ArgumentParser(description="Fine-tune a transformers model on end-to-end task-oriented dialog")
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default=None,
+        help="The name of the dataset to use (via the datasets library).",
+    )
+    parser.add_argument(
+        "--dataset_config_name",
+        type=str,
+        default=None,
+        help="The configuration name of the dataset to use (via the datasets library).",
+    )
+    parser.add_argument(
+        "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
+    )
+    parser.add_argument(
+        "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
+    )
+
+    parser.add_argument(
+        "--max_source_length",
+        type=int,
+        default=1024,
+        help="The maximum total input sequence length after tokenization. Sequences longer than this will be "
+        "truncated, sequences shorter will be padded.",
+    )
+    parser.add_argument(
+        "--source_prefix",
+        type=str,
+        default=None,
+        help="A prefix to add before every source text (useful for T5 models).",
+    )
+    parser.add_argument(
+        "--preprocessing_num_workers",
+        type=int,
+        default=None,
+        help="The number of processes to use for the preprocessing.",
+    )
+
+    parser.add_argument(
+        "--max_target_length",
+        type=int,
+        default=64,
+        help="The maximum total sequence length for target text after tokenization. Sequences longer than this "
+        "will be truncated, sequences shorter will be padded.",
+    )
+    parser.add_argument(
+        "--val_max_target_length",
+        type=int,
+        default=None,
+        help="The maximum total sequence length for validation target text after tokenization. Sequences longer "
+        "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
+        "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
+        "during ``evaluate`` and ``predict``.",
+    )
+    parser.add_argument(
+        "--max_length",
+        type=int,
+        default=128,
+        help=(
+            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
+            " sequences shorter will be padded if `--pad_to_max_length` is passed."
+        ),
+    )
+    parser.add_argument(
+        "--num_beams",
+        type=int,
+        default=None,
+        help="Number of beams to use for evaluation. This argument will be "
+        "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+    )
+    parser.add_argument(
+        "--model_name_or_path",
+        type=str,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+        required=True,
+    )
+    parser.add_argument(
+        "--config_name",
+        type=str,
+        default=None,
+        help="Pretrained config name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--tokenizer_name",
+        type=str,
+        default=None,
+        help="Pretrained tokenizer name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--text_column",
+        type=str,
+        default=None,
+        help="The name of the column in the datasets containing the full texts (for summarization).",
+    )
+    parser.add_argument(
+        "--summary_column",
+        type=str,
+        default=None,
+        help="The name of the column in the datasets containing the summaries (for summarization).",
+    )
+    parser.add_argument(
+        "--use_slow_tokenizer",
+        action="store_true",
+        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
+    )
+    parser.add_argument(
+        "--per_device_train_batch_size",
+        type=int,
+        default=8,
+        help="Batch size (per device) for the training dataloader.",
+    )
+    parser.add_argument(
+        "--per_device_eval_batch_size",
+        type=int,
+        default=8,
+        help="Batch size (per device) for the evaluation dataloader.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=5e-5,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--lr_scheduler_type",
+        type=SchedulerType,
+        default="linear",
+        help="The scheduler type to use.",
+        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
+    )
+    parser.add_argument(
+        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument(
+        "--model_type",
+        type=str,
+        default=None,
+        help="Model type to use if training from scratch.",
+        choices=MODEL_TYPES,
+    )
+
+    parser.add_argument(
+        "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets."
+    )
+    parser.add_argument(
+        "--pad_to_max_length", type=bool, default=True, help="Pad all samples to `--max_length`."
+    )
+    parser.add_argument(
+        "--ignore_pad_token_for_loss", type=bool, default=True,
+        help="Replace pad tokens in the labels by -100 so they are ignored by the loss."
+    )
+    parser.add_argument(
+        "--logging_steps", type=int, default=500, help="Log the training loss every N steps."
+    )
+    parser.add_argument(
+        "--save_steps", type=int, default=5000, help="Save a checkpoint every N steps."
+    )
+    parser.add_argument(
+        "--save_every_checkpoint", action="store_true"
+    )
+    parser.add_argument(
+        "--max_grad_norm", type=float, default=1.0, help="Maximum gradient norm for clipping."
+    )
+    parser.add_argument(
+        "--exp_name",
+        type=str,
+        help="Description of the experiment.",
+        default='multiwoz',
+    )
+    parser.add_argument(
+        "--use_special_token",
+        action="store_true",
+        help="Whether to add the special tokens listed in special_tokens.txt."
+    )
+    parser.add_argument(
+        "--format_version",
+        type=str, default='v1',
+        help="Input/target serialization format version."
+    )
+    parser.add_argument(
+        "--wandb_exp_name",
+        type=str,
+        default='multiwoz',
+        help="Name of the wandb project to log to.",
+    )
+
+    args = parser.parse_args()
+
+    # Sanity checks
+    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
+        raise ValueError("Need either a dataset name or a training/validation file.")
+    else:
+        if args.train_file is not None:
+            extension = args.train_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+        if args.validation_file is not None:
+            extension = args.validation_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+
+    if args.output_dir is not None:
+        os.makedirs(args.output_dir, exist_ok=True)
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    if args.source_prefix is None and args.model_name_or_path in [
+        "t5-small",
+        "t5-base",
+        "t5-large",
+        "t5-3b",
+        "t5-11b",
+    ]:
+        logger.warning(
+            "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
+            "`--source_prefix 'summarize: ' `"
+        )
+    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
+    accelerator = Accelerator()
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state)
+
+    # Setup logging, we only want one process per machine to log things on the screen.
+    # accelerator.is_local_main_process is only True for one process per machine.
+    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+
+    # If passed along, set the training seed now.
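+    # (set_seed below seeds Python's `random`, NumPy and torch, so repeated
+    # runs with the same --seed are reproducible.)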
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    if accelerator.is_local_main_process and USE_WANDB:
+        config = dict(
+            dataset_id="",
+            infra="",
+        )
+        import wandb
+        wandb.init(
+            project=args.wandb_exp_name,
+            notes="Finetuning",
+            tags=["multiwoz"],
+            config=config,
+            entity='Convlab3')
+
+        wandb.run.name = args.exp_name
+
+    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+    # (the dataset will be downloaded automatically from the datasets Hub).
+    #
+    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
+    # 'text' is found. You can easily tweak this behavior (see below).
+    #
+    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+    # download the dataset.
+    if args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
+    else:
+        data_files = {}
+        if args.train_file is not None:
+            data_files["train"] = args.train_file
+        if args.validation_file is not None:
+            data_files["validation"] = args.validation_file
+        extension = args.train_file.split(".")[-1]
+        raw_datasets = load_dataset(extension, data_files=data_files)
+    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+    # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+    # Load pretrained model and tokenizer
+    #
+    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
+    if args.config_name:
+        config = AutoConfig.from_pretrained(args.config_name)
+    elif args.model_name_or_path:
+        config = AutoConfig.from_pretrained(args.model_name_or_path)
+    else:
+        config = CONFIG_MAPPING[args.model_type]()
+        logger.warning("You are instantiating a new config instance from scratch.")
+
+    if args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
+    elif args.model_name_or_path:
+        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
+    else:
+        raise ValueError(
+            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
+            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+        )
+
+    if args.model_name_or_path:
+        model = AutoModelForSeq2SeqLM.from_pretrained(
+            args.model_name_or_path,
+            from_tf=bool(".ckpt" in args.model_name_or_path),
+            config=config,
+        )
+    else:
+        logger.info("Training new model from scratch")
+        model = AutoModelForSeq2SeqLM.from_config(config)
+
+    if 'blender' in args.model_name_or_path.lower():
+        # Hack: tile Blenderbot's positional embeddings 4x so longer inputs fit.
+        model.model.encoder.embed_positions.weight = torch.nn.Parameter(model.model.encoder.embed_positions.weight.repeat(4, 1))
+        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+    if args.use_special_token:
+        special_tokens = [i.strip() for i in open('special_tokens.txt')]
+        tokenizer.add_tokens(special_tokens)
+
+    model.resize_token_embeddings(len(tokenizer))
+    if model.config.decoder_start_token_id is None:
+        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
+
+    prefix = args.source_prefix if args.source_prefix is not None else ""
+    max_length = args.max_length
+    padding = "max_length" if args.pad_to_max_length else False
+    max_target_length = args.max_target_length
+
+    def preprocess_function(examples):
+        contextes = examples['Context']
+        responses = examples['Response']
+        kbs = examples['Knowledge']
+
+        responses_labels = []
+        inputs = []
+
+        for context, response, kb in zip(contextes, responses, kbs):
+            if args.format_version == 'v1':
+                # Keep the last 10 turns; the target is the belief string followed by the response.
+                _input = ' EOS '.join(context.split(' EOS ')[-10:])
+                _response = 'Belief: ' + kb + response
+                inputs.append(_input)
+                responses_labels.append(_response)
+
+        model_inputs = tokenizer(inputs, max_length=args.max_length, padding=padding, truncation=True)
+
+        # Setup the tokenizer for targets
+        with tokenizer.as_target_tokenizer():
+            labels = tokenizer(responses_labels, max_length=max_target_length, padding=padding, truncation=True)
+
+        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
+        # padding in the loss.
+        if padding == "max_length" and args.ignore_pad_token_for_loss:
+            labels["labels"] = [
+                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
+            ]
+
+        model_inputs["labels"] = labels["labels"]
+        return model_inputs
+
+    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
+    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
+    # to preprocess.
+    #
+    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+
+    column_names = ['Context', 'Response', 'Knowledge', 'Dataset']
+    lm_datasets = raw_datasets.map(
+        preprocess_function,
+        batched=True,
+        remove_columns=column_names,
+        num_proc=args.preprocessing_num_workers,
+        load_from_cache_file=False,
+        desc="Processing dataset",
+    )
+
+    train_dataset = lm_datasets["train"]
+    eval_dataset = lm_datasets["validation"]
+    test_dataset = lm_datasets["test"]
+
+    # Log a few random samples from the training set:
+    for index in random.sample(range(len(train_dataset)), 1):
+        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+
+    label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer,
+        model=model,
+        label_pad_token_id=label_pad_token_id,
+        pad_to_multiple_of=8 if accelerator.use_fp16 else None,
+    )
+
+    def postprocess_text(preds, labels):
+        # Strip the 'Agent :' speaker prefix before scoring.
+        preds = [normalize_answer(pred.strip().replace('Agent :', '')) for pred in preds]
+        labels = [normalize_answer(label.strip().replace('Agent :', '')) for label in labels]
+        return preds, labels
+
+    train_dataloader = DataLoader(
+        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
+    )
+    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
+    test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
+
+    # Optimizer
+    # Split weights in two groups, one with weight decay and the other not.
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+
+    # Prepare everything with our `accelerator`.
+    model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
+        model, optimizer, train_dataloader, eval_dataloader, test_dataloader
+    )
+
+    # Note -> the training dataloader needs to be prepared before we grab its length below
+    # (because its length will be shorter in a multi-process setup).
+
+    # Scheduler and math around the number of training steps.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    else:
+        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    lr_scheduler = get_scheduler(
+        name=args.lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=args.num_warmup_steps,
+        num_training_steps=args.max_train_steps,
+    )
+
+    # Train!
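+    # Effective batch size example: with the README's suggested flags
+    # (--per_device_train_batch_size 2, default gradient_accumulation_steps 1,
+    # a single process), total_batch_size below evaluates to 2 * 1 * 1 = 2.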
+    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    completed_steps = 0
+    global_steps = 0
+    tr_loss, logging_loss = 0.0, 0.0
+    for epoch in range(args.num_train_epochs):
+        model.train()
+
+        for step, batch in enumerate(train_dataloader):
+            global_steps += 1
+            outputs = model(**batch)
+            loss = outputs.loss
+            loss = loss / args.gradient_accumulation_steps
+            tr_loss += loss.item()
+            accelerator.backward(loss)
+
+            if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+                completed_steps += 1
+
+            if completed_steps >= args.max_train_steps:
+                break
+
+            if step % args.logging_steps == 0:
+                logger.info(f"  EVALERR: {(tr_loss - logging_loss) / float(args.logging_steps)}")
+                if accelerator.is_local_main_process and USE_WANDB:
+                    wandb.log({'loss': tr_loss - logging_loss})
+                logging_loss = tr_loss
+                progress_bar.update(args.logging_steps)
+
+            if args.output_dir is not None and global_steps % args.save_steps == 0 and global_steps > 0:
+                accelerator.wait_for_everyone()
+                if accelerator.is_local_main_process:
+                    checkpoint_prefix = 'checkpoint'
+                    output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_steps))
+                    if not os.path.exists(output_dir):
+                        os.makedirs(output_dir)
+                    unwrapped_model = accelerator.unwrap_model(model)
+                    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
+
+                    tokenizer.save_pretrained(output_dir)
+                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
+                    logger.info("Saving model checkpoint to %s", output_dir)
+
+        model.eval()
+        if args.val_max_target_length is None:
+            args.val_max_target_length = args.max_target_length
+
+        gen_kwargs = {
+            "max_length": args.val_max_target_length if args is not None else config.max_length,
+            "num_beams": args.num_beams,
+        }
+
+        def chunks(lst, n):
+            # Yield successive n-sized chunks from lst.
+            for i in range(0, len(lst), n):
+                yield lst[i:i + n]
+
+        metric = load_metric("./rouge_metric.py")
+        metric_bleu = load_metric("./bleu_metric.py")
+        decoded_preds_all = []
+        for step, batch in enumerate(eval_dataloader):
+            with torch.no_grad():
+                generated_tokens = accelerator.unwrap_model(model).generate(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    **gen_kwargs,
+                )
+
+                generated_tokens = accelerator.pad_across_processes(
+                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
+                )
+                labels = batch["labels"]
+                if not args.pad_to_max_length:
+                    # If we did not pad to max length, we need to pad the labels too
+                    labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
+
+                generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
+                labels = accelerator.gather(labels).cpu().numpy()
+
+                if args.ignore_pad_token_for_loss:
+                    # Replace -100 in the labels as we can't decode them.
+                    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+                if isinstance(generated_tokens, tuple):
+                    generated_tokens = generated_tokens[0]
+                decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+                metric.add_batch(predictions=decoded_preds, references=decoded_labels)
+                _decoded_preds = [i.split() for i in decoded_preds]
+                _decoded_labels = [[i.split()] for i in decoded_labels]
+                decoded_preds_all.extend(decoded_preds)
+                metric_bleu.add_batch(predictions=_decoded_preds, references=_decoded_labels)
+
+        result = metric.compute(use_stemmer=True)
+        # Extract a few results from ROUGE
+        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
+        result = {k: round(v, 4) for k, v in result.items()}
+
+        logger.info(result)
+
+        result_bleu = metric_bleu.compute()
+        logger.info(result_bleu)
+
+        accelerator.wait_for_everyone()
+        if accelerator.is_local_main_process and USE_WANDB:
+            wandb.log({'valid_bleu': result_bleu['bleu']})
+            wandb.log({'valid_rouge': result['rougeL']})
+
+        if args.output_dir is not None:
+            accelerator.wait_for_everyone()
+            if accelerator.is_local_main_process:
+                if not os.path.exists(args.output_dir):
+                    os.makedirs(args.output_dir)
+                output_dir_file_name = os.path.join(args.output_dir, 'valid-step-{}'.format(completed_steps))
+                json.dump(decoded_preds_all, open(output_dir_file_name, 'w'), indent=2)
+                logger.info("Saving model outputs to %s", output_dir_file_name)
+
+        # Test evaluation uses the bundled HF metrics rather than the local scripts.
+        metric = load_metric("rouge")
+        metric_bleu = load_metric("bleu")
+
+        gen_kwargs = {
+            "max_length": args.val_max_target_length if args is not None else config.max_length,
+            "num_beams": args.num_beams,
+        }
+        decoded_preds_all = []
+        for step, batch in enumerate(test_dataloader):
+            with torch.no_grad():
+                generated_tokens = accelerator.unwrap_model(model).generate(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    **gen_kwargs,
+                )
+
+                generated_tokens = accelerator.pad_across_processes(
+                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
+                )
+                labels = batch["labels"]
+                if not args.pad_to_max_length:
+                    # If we did not pad to max length, we need to pad the labels too
+                    labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
+
+                generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
+                labels = accelerator.gather(labels).cpu().numpy()
+
+                if args.ignore_pad_token_for_loss:
+                    # Replace -100 in the labels as we can't decode them.
+                    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+                if isinstance(generated_tokens, tuple):
+                    generated_tokens = generated_tokens[0]
+                decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+                decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
+                metric.add_batch(predictions=decoded_preds, references=decoded_labels)
+                _decoded_preds = [i.split() for i in decoded_preds]
+                _decoded_labels = [[i.split()] for i in decoded_labels]
+                decoded_preds_all.extend(_decoded_preds)
+                metric_bleu.add_batch(predictions=_decoded_preds, references=_decoded_labels)
+
+        result = metric.compute(use_stemmer=True)
+        # Extract a few results from ROUGE
+        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
+        result = {k: round(v, 4) for k, v in result.items()}
+
+        logger.info(result)
+
+        result_bleu = metric_bleu.compute()
+        logger.info(result_bleu)
+
+        accelerator.wait_for_everyone()
+        if accelerator.is_local_main_process and USE_WANDB:
+            wandb.log({'test_bleu': result_bleu['bleu']})
+            wandb.log({'test_rouge': result['rougeL']})
+
+        if args.output_dir is not None:
+            accelerator.wait_for_everyone()
+            if accelerator.is_local_main_process:
+                if not os.path.exists(args.output_dir):
+                    os.makedirs(args.output_dir)
+                output_dir_file_name = os.path.join(args.output_dir, 'test-step-{}'.format(completed_steps))
+                json.dump(decoded_preds_all, open(output_dir_file_name, 'w'), indent=2)
+                logger.info("Saving model outputs to %s", output_dir_file_name)
+
+        if args.output_dir is not None and args.save_every_checkpoint:
+            accelerator.wait_for_everyone()
+            if accelerator.is_local_main_process:
+                checkpoint_prefix = 'checkpoint'
+                output_dir = os.path.join(args.output_dir, '{}-epoch-{}'.format(checkpoint_prefix, epoch))
+                if not os.path.exists(output_dir):
+                    os.makedirs(output_dir)
+                unwrapped_model = accelerator.unwrap_model(model)
+                unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
+
+                tokenizer.save_pretrained(output_dir)
+                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
+                logger.info("Saving model checkpoint to %s", output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index aff31be6..32e98234 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -148,6 +148,7 @@ def load_unified_data(
     dialogue_acts=False,
     state=False,
     db_results=False,
+    delex_utterance=False,
     use_context=False,
     context_window_size=0,
     terminated=False,
@@ -182,7 +183,7 @@ def load_unified_data(
     data_splits = dataset.keys() if data_split == 'all' else [data_split]
     assert speaker in ['user', 'system', 'all']
     assert not use_context or context_window_size > 0
-    info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results']))
+    info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results', 'delex_utterance']))
     info_list += ['utt_idx']
     data_by_split = {}
     for data_split in data_splits:
@@ -426,7 +427,14 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore
                 for value in values.split('|'):
                     if value.lower() not in ignore_values:
                         placeholder = delex_func(domain, slot, value)
-                        pattern = re.compile(r'\b({})\b'.format(value), flags=re.I)
+                        # TODO: escape regex metacharacters in values generally; only '?' is handled here
+                        value = '\?' if value == '?' else value
+                        try:
+                            pattern = re.compile(r'\b({})\b'.format(value), flags=re.I)
+                        except Exception:
+                            # skip this value instead of falling through with a
+                            # stale or undefined `pattern`
+                            print(value)
+                            continue
                         if delex_inplace(delex_utt, pattern):
                             delex_vocab.add(placeholder)
 
-- 
GitLab
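For context, the data pipeline this patch enables composes the unified-dataset utilities as sketched below. Only APIs visible in the diff are used; treat it as an illustration rather than part of the patch:

```python
from convlab.util.unified_datasets_util import create_delex_data, load_dataset
from convlab.util import load_e2e_data

dataset = load_dataset('multiwoz21')
# custom placeholder style, e.g. "[hotel_pricerange]"
dataset, delex_vocab = create_delex_data(
    dataset, delex_func=lambda d, s, v: f"[{d}_{s.replace(' ', '')}]")
# delex_utterance=True relies on the new flag added to load_unified_data above
e2e_data = load_e2e_data(dataset, delex_utterance=True)
print(e2e_data['train'][0]['delex_utterance'])
```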