diff --git a/convlab2/nlg/evaluate.py b/convlab2/nlg/evaluate.py
index e2cffc4553060c9c6c5e812c0ccd317981274fd6..1a4747b7f19a47f2c069e4ba286c9ad16763043b 100755
--- a/convlab2/nlg/evaluate.py
+++ b/convlab2/nlg/evaluate.py
@@ -8,6 +8,7 @@ import json
 import os
 import random
 import sys
+import itertools
 import zipfile
 import numpy
 from numpy.lib.shape_base import _put_along_axis_dispatcher
@@ -211,16 +212,18 @@ if __name__ == '__main__':
     numpy.random.seed(seed)
     torch.manual_seed(seed)
 
-    if len(sys.argv) != 4:
+    if len(sys.argv) < 4:
         print("usage:")
         print("\t python evaluate.py dataset model role")
         print("\t dataset=MultiWOZ, CrossWOZ, or Camrest")
         print("\t model=SCLSTM, SCLSTM_NoUNK, SCGPT or TemplateNLG")
         print("\t role=usr/sys")
+        print("\t [Optional] model_file")
         sys.exit()
     dataset_name = sys.argv[1]
     model_name = sys.argv[2]
     role = sys.argv[3]
+    model_file = sys.argv[4] if len(sys.argv) >= 5 else None
     if dataset_name == 'MultiWOZ':
         if model_name == 'SCLSTM':
             from convlab2.nlg.sclstm.multiwoz import SCLSTM
@@ -242,17 +245,19 @@ if __name__ == '__main__':
                 model = TemplateNLG(is_user=False)
         elif model_name == 'SCGPT':
             from convlab2.nlg.scgpt.multiwoz import SCGPT
+            if model_file is not None:
+                print(f"load model at {model_file}")
             if role == 'usr':
-                model = SCGPT(is_user=True)
+                model = SCGPT(model_file, is_user=True)
             elif role == 'sys':
-                model = SCGPT(is_user=False, model_file='scgpt/trained_output/multiwoz/')
+                model = SCGPT(model_file, is_user=False)
         else:
             raise Exception("Available models: SCLSTM, SCGPT, TEMPLATE")
         from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader
         from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
         dataloader = SingleTurnNLGDataloader(dataset_dataloader=MultiWOZDataloader())
-        data = dataloader.load_data(data_key='all', role=role)['test']
+        data = dataloader.load_data(data_key='all', role=role, session_id=True)['test']
 
         dialog_acts = []
         golden_utts = []
@@ -262,7 +267,19 @@ if __name__ == '__main__':
         sen_num = 0
         # sys.stdout = open(sys.argv[2] + '-' + sys.argv[3] + '-' + 'evaluate_logs_neo.txt','w')
 
+        assert 'utterance' in data and 'dialog_act' in data and 'session_id' in data
+        assert len(data['utterance']) == len(data['dialog_act']) == len(data['session_id'])
+
+        # Turns during the same session should be contiguous, so we can call init_session at the first turn of a new session.
+        # This is necessary for SCGPT, but unnecessary for SCLSTM and TemplateNLG.
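+        # e.g. session_id ['PMUL1234', 'PMUL1234', 'SNG0456'] (illustrative ids) yields
+        # is_first_turn [True, False, True]: one True per session, False for the rest.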
+        is_first_turn = []
+        for _, iterator in itertools.groupby(data['session_id']):
+            is_first_turn.append(True)
+            next(iterator)
+            is_first_turn.extend(False for _ in iterator)
         for i in tqdm(range(len(data['utterance']))):
+            if is_first_turn[i]:
+                model.init_session()
             dialog_acts.append(data['dialog_act'][i])
             golden_utts.append(data['utterance'][i])
             gen_utts.append(model.generate(data['dialog_act'][i]))
diff --git a/convlab2/nlg/scgpt/README.md b/convlab2/nlg/scgpt/README.md
index b8630eeb2bcccbf454539883f512a42a4bebd4f3..5eed2c0fc167cd9ee79d66e3252be25060bb294d 100644
--- a/convlab2/nlg/scgpt/README.md
+++ b/convlab2/nlg/scgpt/README.md
@@ -21,9 +21,22 @@ tar -xvf scgpt.tar.gz
 
 Then
 
 ``` python
-python train.py --output_dir=trained_output --model_type=gpt2 --model_name_or_path=scgpt --do_train --do_eval --eval_data_file=multiwoz/data/test_sys.txt --overwrite_cache --use_tokenize --train_data_file=multiwoz/data/train_sys.txt --overwrite_output_dir
+python train.py --output_dir=trained_output --model_type=gpt2 --model_name_or_path=scgpt --do_train --do_eval --eval_data_file=multiwoz/data/test_sys.txt --use_tokenize --train_data_file=multiwoz/data/train_sys.txt --overwrite_output_dir
 ```
 
+Some optional training arguments:
+* `--gradient_accumulation_steps xxx`
+* `--fp16`: if set, `--per_gpu_train_batch_size` should preferably be a multiple of 8
+* `--max_seq xxx`: should be at least the length of the longest training sequence, e.g. `--max_seq 1024`; the script uses a dynamic sequence length at each training step
+* `--gradient_checkpointing`: allows a larger `--per_gpu_train_batch_size`
+* `--use_multi_tensor_adamw`: use `torch.optim._multi_tensor.AdamW`, reportedly a faster optimizer implementation
+
+Distributed data parallel:
+
+  If multiple GPUs are available, you can run `python -m torch.distributed.launch --nproc_per_node CUDA_COUNT train.py ......`
+
+  `CUDA_COUNT` is the number of GPUs and `......` stands for the arguments of `train.py`.
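+
+  For example, a two-GPU launch combining some of the optional arguments above (illustrative values):
+
+  ``` bash
+  python -m torch.distributed.launch --nproc_per_node 2 train.py --output_dir=trained_output --model_type=gpt2 --model_name_or_path=scgpt --do_train --train_data_file=multiwoz/data/train_sys.txt --overwrite_output_dir --fp16 --per_gpu_train_batch_size 8 --max_seq 1024 --gradient_checkpointing
+  ```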
+
 ## Use
 
 ```python
diff --git a/convlab2/nlg/scgpt/modeling_utils.py b/convlab2/nlg/scgpt/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8b3f6ddfc6b7347c624446bf7869c67d3064cc1
--- /dev/null
+++ b/convlab2/nlg/scgpt/modeling_utils.py
@@ -0,0 +1,53 @@
+import warnings
+from contextlib import nullcontext
+from typing import TYPE_CHECKING
+
+import torch.cuda.amp as amp
+import transformers
+from transformers import GPT2LMHeadModel
+
+
+# reference: https://pytorch.org/docs/master/notes/amp_examples.html
+class AmpGPT2LMHeadModel(GPT2LMHeadModel):
+    if TYPE_CHECKING:
+        # For IDE's code hinting
+        forward = GPT2LMHeadModel.forward
+    else:
+        def forward(self, *args, **kwargs):
+            with amp.autocast():
+                return super().forward(*args, **kwargs)
+
+
+def try_enable_gradient_checkpointing(model: "transformers.modeling_utils.PreTrainedModel"):
+    if model.supports_gradient_checkpointing:
+        model.gradient_checkpointing_enable()
+    else:
+        warnings.warn(f"{type(model)} doesn't support gradient_checkpointing")
+
+
+class AmpHelper:
+    """
+    References:
+        https://pytorch.org/docs/master/notes/amp_examples.html
+    """
+    def __init__(self, use_amp=True):
+        self.use_amp = use_amp
+        self.might_enable_autocast = amp.autocast() if use_amp else nullcontext()
+        self.scaler = amp.GradScaler()
+
+    def backward(self, loss):
+        if self.use_amp:
+            return self.scaler.scale(loss).backward()
+        else:
+            return loss.backward()
+
+    def step(self, optimizer):
+        if self.use_amp:
+            self.scaler.step(optimizer)
+            self.scaler.update()
+        else:
+            optimizer.step()
+
+    def might_unscale_(self, optimizer):
+        if self.use_amp:
+            # Unscales the gradients of optimizer's assigned params in-place
+            self.scaler.unscale_(optimizer)
\ No newline at end of file
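Note: a minimal usage sketch for `AmpHelper`; the toy model, optimizer, and input tensor below are stand-ins (a CUDA device is assumed), while `train.py` below drives the helper the same way with GPT-2 and AdamW:

```python
import torch

from convlab2.nlg.scgpt.modeling_utils import AmpHelper

# Toy stand-ins so the sketch is self-contained; not the real GPT-2 training setup.
model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
inputs = torch.randn(8, 16, device='cuda')

amp_helper = AmpHelper(use_amp=True)
with amp_helper.might_enable_autocast:       # amp.autocast(), or nullcontext() when use_amp=False
    loss = model(inputs).pow(2).mean()
amp_helper.backward(loss)                    # scales the loss before backward under AMP
amp_helper.might_unscale_(optimizer)         # unscale so clip_grad_norm_ sees true gradient norms
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
amp_helper.step(optimizer)                   # GradScaler.step + update, or a plain optimizer.step
optimizer.zero_grad()
```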
diff --git a/convlab2/nlg/scgpt/multiwoz/preprocess.py b/convlab2/nlg/scgpt/multiwoz/preprocess.py
index 10e588886f4e316fbade8eba04059b09f86030ff..ae7b08566842435247b09625b5410bc741c58db8 100644
--- a/convlab2/nlg/scgpt/multiwoz/preprocess.py
+++ b/convlab2/nlg/scgpt/multiwoz/preprocess.py
@@ -6,6 +6,7 @@ Created on Mon Sep 14 11:38:53 2020
 
 import os
 import json
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
 from convlab2.nlg.scgpt.utils import dict2dict, dict2seq
 import zipfile
 
@@ -14,65 +15,6 @@ def read_zipped_json(filepath, filename):
     archive = zipfile.ZipFile(filepath, 'r')
     return json.load(archive.open(filename))
 
-cur_dir = os.path.dirname(os.path.abspath(__file__))
-data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
-    cur_dir)))), 'data/multiwoz/')
-
-keys = ['train', 'val', 'test']
-data = {}
-for key in keys:
-    data_key = read_zipped_json(os.path.join(data_dir, key + '.json.zip'), key + '.json')
-    print('load {}, size {}'.format(key, len(data_key)))
-    data = dict(data, **data_key)
-
-with open(os.path.join(data_dir, 'valListFile'), 'r') as f:
-    val_list = f.read().splitlines()
-with open(os.path.join(data_dir, 'testListFile'), 'r') as f:
-    test_list = f.read().splitlines()
-
-results = {}
-results_val = {}
-results_test = {}
-
-for title, sess in data.items():
-    logs = sess['log']
-    turns = []
-    turn = {'turn':0, 'sys':'', 'sys_da':''}
-    current_domain = None
-    for i, diag in enumerate(logs):
-        text = diag['text']
-        da = diag['dialog_act']
-        span = diag['span_info']
-        if i % 2 == 0:
-            turn['usr'] = text
-            if current_domain:
-                da = eval(str(da).replace('Booking', current_domain))
-                span = eval(str(span).replace('Booking', current_domain))
-            turn['usr_da'] = da
-            turn['usr_span'] = span
-            turns.append(turn)
-        else:
-            turn = {'turn': i//2 +1}
-            turn['sys'] = text
-            turn['sys_da'] = da
-            turn['sys_span'] = span
-            for key in da:
-                domain = key.split('-')[0]
-                if domain not in ['general', 'Booking']:
-                    current_domain = domain
-    title = title
-    if title in val_list:
-        current = results_val
-    elif title in test_list:
-        current = results_test
-    else:
-        current = results
-    current[title] = turns
-
-results = eval(str(results).replace(" n't", " not"))
-results_val = eval(str(results_val).replace(" n't", " not"))
-results_test = eval(str(results_test).replace(" n't", " not"))
-
 def init_domain():
     return {'Attraction':False,
             'Hospital':False,
@@ -82,32 +24,106 @@ def init_domain():
             'Taxi':False,
             'Train':False}
 
-def write_file(name, data):
+def write_file(name, data, role='usr'):
     with open(f'{name}.txt', 'w', encoding='utf-8') as f:
         for ID in data:
             sess = data[ID]
             sess_domains = init_domain()
             for turn in sess:
-                # TODO: set option to process usr/sys
-                if not turn['usr_da']:
-                    continue
-                turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
-                da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
-                domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
-                if not turn['sys_da']:
-                    continue
-                turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
-                da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
-                domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
+                if role == 'usr':
+                    if not turn['usr_da']:
+                        continue
+                    turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
+                    da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
+                    domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
+                elif role == 'sys':
+                    if not turn['sys_da']:
+                        continue
+                    turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
+                    da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
+                    domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
+                else:
+                    raise NameError('Invalid Role: Select usr/sys.')
                 for domain in domains:
                     if domain not in ['general', 'Booking'] and not sess_domains[domain]:
                         da_seq = da_seq.replace(domain.lower(), domain.lower()+' *', 1)
                         sess_domains[domain] = True
-                da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
-                da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
+
+                if role == 'usr':
+                    da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
+                elif role == 'sys':
+                    da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
                 f.write(f'{da_seq} & {da_uttr}\n')
 
-if not os.path.exists(os.path.join(cur_dir,'data')):
-    os.makedirs(os.path.join(cur_dir, 'data'))
-write_file(os.path.join(cur_dir, 'data/train'), dict(results, **results_val))
-write_file(os.path.join(cur_dir, 'data/test'), results_test)
+
+if __name__ == '__main__':
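+    # e.g. `python preprocess.py --role sys` writes data/train_sys.txt and data/test_sys.txt,
+    # the files consumed by train.py via --train_data_file / --eval_data_file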
+    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--role', type=str, default='usr')
+    args = parser.parse_args()
+
+    cur_dir = os.path.dirname(os.path.abspath(__file__))
+    data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
+        cur_dir)))), 'data/multiwoz/')
+
+    keys = ['train', 'val', 'test']
+    data = {}
+    for key in keys:
+        data_key = read_zipped_json(os.path.join(data_dir, key + '.json.zip'), key + '.json')
+        print('load {}, size {}'.format(key, len(data_key)))
+        data = dict(data, **data_key)
+
+    with open(os.path.join(data_dir, 'valListFile'), 'r') as f:
+        val_list = f.read().splitlines()
+    with open(os.path.join(data_dir, 'testListFile'), 'r') as f:
+        test_list = f.read().splitlines()
+
+    results = {}
+    results_val = {}
+    results_test = {}
+
+    for title, sess in data.items():
+        logs = sess['log']
+        turns = []
+        turn = {'turn': 0, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
+        current_domain = None
+        for i, diag in enumerate(logs):
+            text = diag['text']
+            da = diag['dialog_act']
+            span = diag['span_info']
+            if current_domain:
+                da = eval(str(da).replace('Booking', current_domain))
+                span = eval(str(span).replace('Booking', current_domain))
+            if i % 2 == 0:
+                turn['usr'] = text
+                turn['usr_da'] = da
+                turn['usr_span'] = span
+                turns.append(turn)
+            else:
+                turn = {'turn': i//2 + 1, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
+                turn['sys'] = text
+                turn['sys_da'] = da
+                turn['sys_span'] = span
+                for key in da:
+                    domain = key.split('-')[0]
+                    if domain not in ['general', 'Booking']:
+                        current_domain = domain
+        else:
+            if args.role == 'sys':
+                turns.append(turn)
+        title = title
+        if title in val_list:
+            current = results_val
+        elif title in test_list:
+            current = results_test
+        else:
+            current = results
+        current[title] = turns
+
+    results = eval(str(results).replace(" n't", " not"))
+    results_val = eval(str(results_val).replace(" n't", " not"))
+    results_test = eval(str(results_test).replace(" n't", " not"))
+
+    if not os.path.exists(os.path.join(cur_dir,'data')):
+        os.makedirs(os.path.join(cur_dir, 'data'))
+    write_file(os.path.join(cur_dir, f'data/train_{args.role}'), dict(results, **results_val), role=args.role)
+    write_file(os.path.join(cur_dir, f'data/test_{args.role}'), results_test, role=args.role)
diff --git a/convlab2/nlg/scgpt/multiwoz/scgpt.py b/convlab2/nlg/scgpt/multiwoz/scgpt.py
index b4b957c0d1b17a3805f388641dcd14bfe0b32fa2..78f16f6e0b8562c7118a2a6118f0eb5b3287c828 100644
--- a/convlab2/nlg/scgpt/multiwoz/scgpt.py
+++ b/convlab2/nlg/scgpt/multiwoz/scgpt.py
@@ -2,6 +2,7 @@ import torch
 import numpy as np
 import os
 import zipfile
+from copy import deepcopy
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
 from convlab2.nlg.scgpt.utils import tuple2seq
@@ -10,23 +11,31 @@ from convlab2.nlg.nlg import NLG
 from convlab2.util.file_util import cached_path
 
 MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
-DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
-DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "nlg-gpt-multiwoz.zip")
 
 
 class SCGPT(NLG):
-    def __init__(self,
-                 archive_file=DEFAULT_ARCHIVE_FILE,
-                 use_cuda=True,
-                 is_user=False,
-                 model_file='https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'):
+    def __init__(self, model_file=None,
+                 use_cuda=True, is_user=False):
+        # If no filename is mentioned then set to default
+        if not model_file:
+            if is_user:
+                model_file = 'https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'
+            else:
+                model_file = 'https://zenodo.org/record/5767426/files/neo_scgpt_system.zip'
+
+        # Load from file/url
         model_dir = os.path.dirname(os.path.abspath(__file__))
-        if not os.path.isfile(archive_file):
-            archive_file = cached_path(model_file)
-        archive = zipfile.ZipFile(archive_file, 'r')
+        if not os.path.isfile(model_file):
+            model_file = cached_path(model_file)
+        if not os.path.isdir(model_file):
+            archive = zipfile.ZipFile(model_file, 'r')
             archive.extractall(model_dir)
-
-        self.model_name_or_path = os.path.join(model_dir, 'multiwoz')
+            # Get model directory
+            model_file = archive.filelist[0].filename.replace('/', '')
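+            # the zip's first entry is assumed to be the top-level model folder, e.g. 'multiwoz/' -> 'multiwoz'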
+            self.model_name_or_path = os.path.join(model_dir, model_file)
+        else:
+            self.model_name_or_path = model_file
+
         self.length = 50
         self.num_samples = 5
         self.temperature = 1.0
@@ -63,8 +72,9 @@ class SCGPT(NLG):
                              'Restaurant':False,
                              'Taxi':False,
                              'Train':False,}
-        if not self.is_user:
-            self.sess_domains['Booking'] = False
+        self.cur_domain = None
+        # if not self.is_user:
+        #     self.sess_domains['Booking'] = False
 
     def generate(self, meta):
@@ -72,10 +82,23 @@ class SCGPT(NLG):
 
         if not meta:
             return 'No user action'
+        meta = deepcopy(meta)
+        for list_ in meta:
+            domain = list_[1]
+            if domain not in ('general', 'Booking'):
+                self.cur_domain = domain
+        for i, list_ in enumerate(meta):
+            list_ = list(list_)
+            if list_[1] == 'Booking':
+                if self.cur_domain is not None:
+                    list_[1] = self.cur_domain
+                    meta[i] = list_
+                else:
+                    print('`cur_domain` is None, but there is `Booking` in dialog action.')
         raw_text = tuple2seq(meta)
         domains = set([item[1] for item in meta])
         for domain in domains:
-            if domain != 'general' and not self.sess_domains[domain]:
+            if domain not in ('general', 'Booking') and not self.sess_domains[domain]:
                 raw_text = raw_text.replace(domain.lower(), domain.lower()+ ' *', 1)
                 self.sess_domains[domain] = True
         context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False)
@@ -97,4 +120,4 @@ class SCGPT(NLG):
         text = text.split('& ')[-1]
         text = text[: text.find(self.stop_token) if self.stop_token else None]
 
-        return text
\ No newline at end of file
+        return text
diff --git a/convlab2/nlg/scgpt/train.py b/convlab2/nlg/scgpt/train.py
index 775688bbd63e116da42d5f02ecb78930c823a229..0878f31353735ede8b2036ec1f46ef56ce129bed 100644
--- a/convlab2/nlg/scgpt/train.py
+++ b/convlab2/nlg/scgpt/train.py
@@ -9,33 +9,28 @@ import random
 import re
 import shutil
 
-import sys
-
 import numpy as np
 import torch
+from tqdm import tqdm, trange
 from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 
 try:
     from torch.utils.tensorboard import SummaryWriter
-except:
+except ImportError:
     from tensorboardX import SummaryWriter
 
-from tqdm import tqdm, trange
-
 from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
-                          BertConfig, BertForMaskedLM, BertTokenizer,
-                          GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
-                          OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
-                          RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
-                          DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer, BertTokenizer)
-
+                          BertConfig, BertForMaskedLM, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
+                          OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2TokenizerFast,
+                          RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
+                          DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer, BertTokenizer)
+from convlab2.nlg.scgpt.modeling_utils import AmpGPT2LMHeadModel, try_enable_gradient_checkpointing, AmpHelper
 
 logger = logging.getLogger(__name__)
 
-
 MODEL_CLASSES = {
-    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
+    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast),
     'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
     'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
     'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
@@ -43,11 +38,20 @@ MODEL_CLASSES = {
 }
 
 
+def closest_multiple_of_8(n):
+    """
+    Returns:
+        the smallest multiple of 8 that is >= n
+    """
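+    # e.g. closest_multiple_of_8(80) == 80, closest_multiple_of_8(81) == 88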
+    return ((n + 7) >> 3) << 3
+
+
 class TextDataset(Dataset):
     def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80):
         assert os.path.isfile(file_path)
         directory, filename = os.path.split(file_path)
-        cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
+        cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(
+            block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
 
         if os.path.exists(cached_features_file) and not args.overwrite_cache:
             logger.info("Loading features from cached file %s", cached_features_file)
@@ -68,12 +72,11 @@ class TextDataset(Dataset):
                     self.examples.append(tokenized_text)
 
                 if args.text_chunk:
-                    for i in range(0, len(tokenized_text)-block_size+1, block_size): # Truncate in block of block_size
-                        self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i+block_size]))
+                    for i in range(0, len(tokenized_text) - block_size + 1, block_size):  # Truncate in block of block_size
+                        self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i + block_size]))
 
-            # Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
-            # If your dataset is small, first you should loook for a bigger one :-) and second you
+            # If your dataset is small, first you should look for a bigger one :-) and second you
             # can change this behavior by adding (model specific) padding.
 
             logger.info("Saving features into cached file %s", cached_features_file)
@@ -86,26 +89,30 @@ class TextDataset(Dataset):
     def __getitem__(self, item):
         return torch.tensor(self.examples[item])
 
+
 class TextSeqDataset(Dataset):
-    def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80, seperator=' & '):
+    def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80, separator=' & '):
+        max_seq = closest_multiple_of_8(max_seq)
         assert os.path.isfile(file_path)
         directory, filename = os.path.split(file_path)
-        cached_features_file = os.path.join(directory, args.output_dir.replace(os.sep, '_') + '_cached_lm_' + str(block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
+        cached_features_file = os.path.join(directory, args.output_dir.replace(os.sep, '_') + '_cached_lm_' + str(
+            block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
 
         if os.path.exists(cached_features_file) and not args.overwrite_cache:
             logger.info("Loading features from cached file %s", cached_features_file)
             with open(cached_features_file, 'rb') as handle:
-                self.examples = pickle.load(handle)
+                self.examples, self.masks, self.labels, self.seq_lengths = pickle.load(handle)
         else:
             logger.info("Creating features from dataset file at %s", directory)
 
             self.examples = []
             self.labels = []
             self.masks = []
+            self.seq_lengths = []
             with open(file_path, encoding="utf-8") as f:
-                for line in f:
-                    line = line.strip()
-                    raw_str = line.lower()
-                    code_str = line.lower().split(seperator)[0] + seperator
+                for line in tqdm(f):
+                    line = line.strip()
+                    raw_str = line.lower()  # do we need lowercase?
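+                    # each line has the form '<dialog act sequence> & <utterance>' (written by preprocess.py);
+                    # code_str keeps the dialog-act part plus the separator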
+                    code_str = line.lower().split(separator)[0] + separator
                     code_str = code_str.strip()
                     if len(raw_str.split()) > max_seq -1:
                         raw_str = ' '.join(raw_str.split()[:max_seq -1])
@@ -118,40 +125,44 @@ class TextSeqDataset(Dataset):
 
                     code_str_len = len(tokenizer.convert_tokens_to_ids(code_str.split()))
                     label = [-1] * max_seq
-                    label[:len(tokenized_text)] = tokenized_text 
+                    label[:len(tokenized_text)] = tokenized_text
                     mask = [1] * max_seq
-                    if len(tokenized_text) < max_seq: 
+                    if len(tokenized_text) < max_seq:
+                        self.seq_lengths.append(len(tokenized_text))
                         mask[-(max_seq - len(tokenized_text)):] = [0] * (max_seq - len(tokenized_text))
                         # label[code_str_len:len(tokenized_text)] = tokenized_text[code_str_len:]
-                        tokenized_text = tokenized_text + [0] * (max_seq - len(tokenized_text))
+                        tokenized_text = tokenized_text + [tokenizer.eos_token_id] * (max_seq - len(tokenized_text))
                     else:
+                        self.seq_lengths.append(max_seq)
                         tokenized_text = tokenized_text[:max_seq]
-                        # label[code_str_len:] = tokenized_text[code_str_len:] 
-                    
+                        # label[code_str_len:] = tokenized_text[code_str_len:]
+
                     self.examples.append(tokenized_text)
                     self.masks.append(mask)
                     self.labels.append(label)
 
             # Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
-            # If your dataset is small, first you should loook for a bigger one :-) and second you
+            # If your dataset is small, first you should look for a bigger one :-) and second you
             # can change this behavior by adding (model specific) padding.
 
             if args.with_code_loss:
                 self.labels = self.examples
             logger.info("Saving features into cached file %s", cached_features_file)
             with open(cached_features_file, 'wb') as handle:
-                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
+                pickle.dump((self.examples, self.masks, self.labels, self.seq_lengths), handle,
+                            protocol=pickle.HIGHEST_PROTOCOL)
 
     def __len__(self):
         return len(self.examples)
 
     def __getitem__(self, item):
-        return torch.tensor(self.examples[item]), torch.tensor(self.masks[item]), torch.tensor(self.labels[item])
+        return torch.tensor(self.examples[item]), torch.tensor(self.masks[item]), torch.tensor(
+            self.labels[item]), torch.tensor(self.seq_lengths[item])
 
 
 def load_and_cache_examples(args, tokenizer, evaluate=False):
-    dataset = TextSeqDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size, max_seq=args.max_seq)
+    dataset = TextSeqDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file,
+                             block_size=args.block_size, max_seq=args.max_seq)
     return dataset
 
 
@@ -197,7 +208,8 @@ def mask_tokens(inputs, tokenizer, args):
     labels = inputs.clone()
     # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
     probability_matrix = torch.full(labels.shape, args.mlm_probability)
-    special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
+    special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in
+                           labels.tolist()]
    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
     masked_indices = torch.bernoulli(probability_matrix).bool()
     labels[~masked_indices] = -1  # We only compute loss on masked tokens
@@ -215,6 +227,23 @@ def mask_tokens(inputs, tokenizer, args):
     return inputs, labels
 
 
+def preprocess_batch(inputs, masks, labels, seq_lengths):
+    """
+    The real sequence length of a batch may be shorter than max_seq of the whole dataset.
+    Remove some padding tokens to accelerate the training process.
+    And make sure that the sequence length is a multiple of 8.
+
+    References:
+        https://huggingface.co/transformers/performance.html#fp16
+    """
+    # The gain for FP16 training is that in each of those cases, the training with the flag --fp16 is twice as fast,
+    # which does require every tensor to have every dimension be a multiple of 8
+    # (examples pad the tensors to a sequence length that is a multiple of 8).
+    max_seq_len = seq_lengths.max()
+    max_seq_len = closest_multiple_of_8(max_seq_len)
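+    # e.g. if the longest sequence in the batch has 37 tokens, all tensors are cut to [:, :40]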
+    return inputs[:, :max_seq_len], masks[:, :max_seq_len], labels[:, :max_seq_len]
+
+
 def train(args, train_dataset, model, tokenizer):
     """ Train the model """
     if args.local_rank in [-1, 0]:
@@ -233,27 +262,23 @@ def train(args, train_dataset, model, tokenizer):
 
     # Prepare optimizer and schedule (linear warmup and decay)
     no_decay = ['bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [
-        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
+        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+         'weight_decay': args.weight_decay},
         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
-        ]
+    ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
-    if args.fp16:
-        try:
-            from apex import amp
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
-    model.resize_token_embeddings(len(tokenizer))
-    # multi-gpu training (should be after apex fp16 initialization)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
+                                                num_training_steps=t_total)
+    # https://pytorch.org/docs/master/notes/amp_examples.html
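+    # AmpHelper wraps torch.cuda.amp autocast/GradScaler, replacing the removed apex-based fp16 path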
+    amp_helper = AmpHelper(use_amp=args.fp16)
     if args.n_gpu > 1:
         model = torch.nn.DataParallel(model)
 
-    # Distributed training (should be after apex fp16 initialization)
+    # Distributed training
     if args.local_rank != -1:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                           output_device=args.local_rank,
-                                                          find_unused_parameters=True)
+                                                          find_unused_parameters=False)
 
     # Train!
     logger.info("***** Running training *****")
     logger.info("  Num examples = %d", len(train_dataset))
     logger.info("  Num Epochs = %d", args.num_train_epochs)
     logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
     logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
-                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
+                args.train_batch_size * args.gradient_accumulation_steps * (
+                    torch.distributed.get_world_size() if args.local_rank != -1 else 1))
     logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
     logger.info("  Total optimization steps = %d", t_total)
 
@@ -271,12 +297,13 @@ def train(args, train_dataset, model, tokenizer):
     train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
     for e in train_iterator:
-
+        # epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(train_dataloader):
             # inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
-            logger.info(f"  PROGRESS: {float(global_step)/t_total*100}%")
-            inputs, masks, labels = batch
+            logger.info(f"  PROGRESS: {float(global_step) / t_total * 100}%")
+            inputs, masks, labels, seq_lengths = batch
+            inputs, masks, labels = preprocess_batch(inputs, masks, labels, seq_lengths)  # cut seq
             # import pdb
             # pdb.set_trace()
             inputs = inputs.to(args.device)
             masks = masks.to(args.device)
             labels = labels.to(args.device)
 
             model.train()
-            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
-            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
-
-            if args.n_gpu > 1:
-                loss = loss.mean()  # mean() to average on multi-gpu parallel training
-            if args.gradient_accumulation_steps > 1:
-                loss = loss / args.gradient_accumulation_steps
-
-            if args.fp16:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
-            else:
-                loss.backward()
+            try:
+                with amp_helper.might_enable_autocast:
+                    outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
+                    loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
+
+                if args.n_gpu > 1:
+                    loss = loss.mean()  # mean() to average on multi-gpu parallel training
+                if args.gradient_accumulation_steps > 1:
+                    loss = loss / args.gradient_accumulation_steps
+
+                amp_helper.backward(loss)
+            except RuntimeError as e:
+                if 'CUDA out of memory' in str(e):
+                    # if out of memory, we must choose smaller batch_size
+                    print(f'inputs.shape = {inputs.shape}, labels.shape = {labels.shape}')
+                raise
 
             tr_loss += loss.item()
             if (step + 1) % args.gradient_accumulation_steps == 0:
-                if args.fp16:
-                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
-                else:
-                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
-                optimizer.step()
+                amp_helper.might_unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+                # optimizer.step()
+                amp_helper.step(optimizer)
                 scheduler.step()  # Update learning rate schedule
                 model.zero_grad()
                 global_step += 1
@@ -317,7 +346,7 @@ def train(args, train_dataset, model, tokenizer):
                         tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                     tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                     tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
-                    logger.info(f"  EVALERR: {(tr_loss - logging_loss)/float(args.logging_steps)}")
+                    logger.info(f"  EVALERR: {(tr_loss - logging_loss) / float(args.logging_steps)}")
                     logging_loss = tr_loss
 
                 if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
@@ -326,7 +355,8 @@ def train(args, train_dataset, model, tokenizer):
                     output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
                     if not os.path.exists(output_dir):
                         os.makedirs(output_dir)
-                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+                    model_to_save = model.module if hasattr(model,
+                                                            'module') else model  # Take care of distributed/parallel training
                     model_to_save.save_pretrained(output_dir)
                     tokenizer.save_pretrained(output_dir)
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
@@ -334,12 +364,9 @@ def train(args, train_dataset, model, tokenizer):
 
                     _rotate_checkpoints(args, checkpoint_prefix)
 
-            # if args.max_steps > 0 and global_step > args.max_steps:
-            #     epoch_iterator.close()
-            #     break
-        if args.max_steps > 0 and global_step > args.max_steps:
-            train_iterator.close()
-            break
+        if global_step > args.max_steps > 0:
+            train_iterator.close()
+            break
 
     if args.local_rank in [-1, 0]:
         tb_writer.close()
@@ -362,7 +389,9 @@ def evaluate(args, model, tokenizer, prefix=""):
     eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
 
     # multi-gpu evaluate
-    if args.n_gpu > 1:
+    if args.n_gpu > 1 and not (isinstance(model, torch.nn.DataParallel) or
+                               isinstance(model, torch.nn.parallel.DistributedDataParallel)):
+        # if args.evaluate_during_training, DataParallel is already used
         model = torch.nn.DataParallel(model)
 
     # Eval!
@@ -376,9 +405,10 @@ def evaluate(args, model, tokenizer, prefix=""):
 
     for batch in tqdm(eval_dataloader, desc="Evaluating"):
         # inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
-        inputs, masks, labels = batch
-        # import pdb
-        # pdb.set_trace()
+        inputs, masks, labels, seq_lengths = batch
+        inputs, masks, labels = preprocess_batch(inputs, masks, labels, seq_lengths)  # cut seq
+        # import pdb
+        # pdb.set_trace()
         inputs = inputs.to(args.device)
         masks = masks.to(args.device)
         labels = labels.to(args.device)
 
         with torch.no_grad():
             outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
-            lm_loss = outputs[0]
-            eval_loss += lm_loss.mean().item()
+            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
+            eval_loss += loss.mean().item()
         nb_eval_steps += 1
 
     eval_loss = eval_loss / nb_eval_steps
-    perplexity = torch.exp(torch.tensor(eval_loss))
+    perplexity = float(np.exp(eval_loss))
 
     result = {
         "perplexity": perplexity
@@ -409,6 +439,7 @@ def evaluate(args, model, tokenizer, prefix=""):
 
 
 def main():
+    global AdamW
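+    # `global` lets the optional import below rebind the module-level AdamW that train() uses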
- "See details at https://nvidia.github.io/apex/amp.html") + help="Whether to use 16-bit (mixed) precision (through torch.cuda.amp) instead of 32-bit") parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.") @@ -504,18 +532,32 @@ def main(): parser.add_argument("--max_seq", default=80, type=int, help="") + parser.add_argument('--gradient_checkpointing', action='store_true', help='enable gradient checkpointing') + parser.add_argument('--use_multi_tensor_adamw', action='store_true', + help='use torch.optim._multi_tensor.AdamW instead of transformers.AdamW') args = parser.parse_args() + if args.use_multi_tensor_adamw: + try: + # overwrite the previous imported AdamW + # https://huggingface.co/transformers/performance.html#faster-optimizer + from torch.optim._multi_tensor import AdamW + except ImportError as e: + print(e) if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm: raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm " "flag (masked language modeling).") if args.eval_data_file is None and args.do_eval: - raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file " - "or remove the --do_eval argument.") + raise ValueError( + "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file " + "or remove the --do_eval argument.") - if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: - raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)) + if os.path.exists(args.output_dir) and os.listdir( + args.output_dir) and args.do_train and not args.overwrite_output_dir: + raise ValueError( + "Output directory ({}) already exists and is not empty. 
+                args.output_dir))
 
     # Setup distant debugging if needed
     if args.server_ip and args.server_port:
@@ -525,6 +567,11 @@ def main():
         ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
         ptvsd.wait_for_attach()
 
+    # Setup logging before `torch.distributed.init_process_group` is called
+    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                        datefmt='%m/%d/%Y %H:%M:%S',
+                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+
     # Setup CUDA, GPU & distributed training
     if args.local_rank == -1 or args.no_cuda:
         device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
@@ -535,13 +582,8 @@ def main():
         torch.distributed.init_process_group(backend='nccl')
         args.n_gpu = 1
     args.device = device
-
-    # Setup logging
-    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                        datefmt = '%m/%d/%Y %H:%M:%S',
-                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
     logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
-                    args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
+                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
 
     # Set seed
     set_seed(args)
 
@@ -550,14 +592,16 @@ def main():
     # Load pretrained model and tokenizer
     if args.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab
 
+    if args.fp16:
+        MODEL_CLASSES['gpt2'] = (GPT2Config, AmpGPT2LMHeadModel, GPT2TokenizerFast)
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
     config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
                                           cache_dir=args.cache_dir if args.cache_dir else None)
     tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
-    #tokenizer = BertTokenizer(vocab_file='../GPT2-chitchat/vocabulary/vocab_small.txt', eos_token='<T>',
+    # tokenizer = BertTokenizer(vocab_file='../GPT2-chitchat/vocabulary/vocab_small.txt', eos_token='<T>',
                                                 do_lower_case=args.do_lower_case,
                                                 cache_dir=args.cache_dir if args.cache_dir else None)
-    
+
     if args.block_size <= 0:
         args.block_size = tokenizer.max_len_single_sentence  # Our input block size will be the max possible for the model
     args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
     model = model_class.from_pretrained(args.model_name_or_path,
@@ -565,7 +609,13 @@ def main():
                                         from_tf=bool('.ckpt' in args.model_name_or_path),
                                         config=config,
                                         cache_dir=args.cache_dir if args.cache_dir else None)
+    if model.config.vocab_size != len(tokenizer):
+        logger.info('resize token embeddings, since there may be added tokens.')
+        model.resize_token_embeddings(len(tokenizer))
     model.to(args.device)
+    if args.gradient_checkpointing:
+        # https://huggingface.co/transformers/performance.html#gradient-checkpointing
+        try_enable_gradient_checkpointing(model)
 
     if args.local_rank == 0:
         torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training download model & vocab
@@ -585,7 +635,6 @@ def main():
         global_step, tr_loss = train(args, train_dataset, model, tokenizer)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
 
-
     # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Create output directory if needed
@@ -595,7 +644,8 @@ def main():
 
         logger.info("Saving model checkpoint to %s", args.output_dir)
         # Save a trained model, configuration and tokenizer using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
-        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+        model_to_save = model.module if hasattr(model,
+                                                'module') else model  # Take care of distributed/parallel training
         model_to_save.save_pretrained(args.output_dir)
         tokenizer.save_pretrained(args.output_dir)
 
@@ -607,25 +657,24 @@ def main():
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         model.to(args.device)
 
-    # Evaluation
     results = {}
     if args.do_eval and args.local_rank in [-1, 0]:
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
+            checkpoints = list(
+                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
             logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
             global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
             prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
-            
+
             model = model_class.from_pretrained(checkpoint)
             model.to(args.device)
             result = evaluate(args, model, tokenizer, prefix=prefix)
             result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
             results.update(result)
-
     return results
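Note: with the `evaluate.py` change above, a custom SCGPT checkpoint can be passed as an optional fourth CLI argument (the checkpoint path below is illustrative):

```
python evaluate.py MultiWOZ SCGPT sys
python evaluate.py MultiWOZ SCGPT sys trained_output/checkpoint-20000
```

Without the fourth argument, `SCGPT` falls back to the defaults hard-coded in `scgpt.py`: the convlab blob archive for `usr` and the Zenodo archive for `sys`.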