Unverified commit fc978dcb authored by Carel van Niekerk, committed by GitHub

Merge branch 'master' into dsml_convlab

parents 04024645 c6170181
......@@ -8,6 +8,7 @@ import json
import os
import random
import sys
import itertools
import zipfile
import numpy
from numpy.lib.shape_base import _put_along_axis_dispatcher
......@@ -211,16 +212,18 @@ if __name__ == '__main__':
numpy.random.seed(seed)
torch.manual_seed(seed)
if len(sys.argv) != 4:
if len(sys.argv) < 4:
print("usage:")
print("\t python evaluate.py dataset model role")
print("\t dataset=MultiWOZ, CrossWOZ, or Camrest")
print("\t model=SCLSTM, SCLSTM_NoUNK, SCGPT or TemplateNLG")
print("\t role=usr/sys")
print("\t [Optional] model_file")
sys.exit()
dataset_name = sys.argv[1]
model_name = sys.argv[2]
role = sys.argv[3]
model_file = sys.argv[4] if len(sys.argv) >= 5 else None
if dataset_name == 'MultiWOZ':
if model_name == 'SCLSTM':
from convlab2.nlg.sclstm.multiwoz import SCLSTM
......@@ -242,17 +245,19 @@ if __name__ == '__main__':
model = TemplateNLG(is_user=False)
elif model_name == 'SCGPT':
from convlab2.nlg.scgpt.multiwoz import SCGPT
if model_file is not None:
print(f"load model at {model_file}")
if role == 'usr':
model = SCGPT(is_user=True)
model = SCGPT(model_file, is_user=True)
elif role == 'sys':
model = SCGPT(is_user=False, model_file='scgpt/trained_output/multiwoz/')
model = SCGPT(model_file, is_user=False)
else:
raise Exception("Available models: SCLSTM, SCGPT, TEMPLATE")
from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader
from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
dataloader = SingleTurnNLGDataloader(dataset_dataloader=MultiWOZDataloader())
data = dataloader.load_data(data_key='all', role=role)['test']
data = dataloader.load_data(data_key='all', role=role, session_id=True)['test']
dialog_acts = []
golden_utts = []
......@@ -262,7 +267,19 @@ if __name__ == '__main__':
sen_num = 0
# sys.stdout = open(sys.argv[2] + '-' + sys.argv[3] + '-' + 'evaluate_logs_neo.txt','w')
assert 'utterance' in data and 'dialog_act' in data and 'session_id' in data
assert len(data['utterance']) == len(data['dialog_act']) == len(data['session_id'])
# Turns during the same session should be contiguous, so we can call init_session at the first turn of a new session.
# This is necessary for SCGPT, but unnecessary for SCLSTM and TemplateNLG.
is_first_turn = []
for _, iterator in itertools.groupby(data['session_id']):
is_first_turn.append(True)
next(iterator)
is_first_turn.extend(False for _ in iterator)
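# e.g. session_id = ['s1', 's1', 's1', 's2', 's2'] -> is_first_turn = [True, False, False, True, False]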
for i in tqdm(range(len(data['utterance']))):
if is_first_turn[i]:
model.init_session()
dialog_acts.append(data['dialog_act'][i])
golden_utts.append(data['utterance'][i])
gen_utts.append(model.generate(data['dialog_act'][i]))
......
......@@ -21,9 +21,22 @@ tar -xvf scgpt.tar.gz
Then
```bash
python train.py --output_dir=trained_output --model_type=gpt2 --model_name_or_path=scgpt --do_train --do_eval --eval_data_file=multiwoz/data/test_sys.txt --overwrite_cache --use_tokenize --train_data_file=multiwoz/data/train_sys.txt --overwrite_output_dir
python train.py --output_dir=trained_output --model_type=gpt2 --model_name_or_path=scgpt --do_train --do_eval --eval_data_file=multiwoz/data/test_sys.txt --use_tokenize --train_data_file=multiwoz/data/train_sys.txt --overwrite_output_dir
```
Some optional training arguments (tricks) that may help:
* `--gradient_accumulation_steps xxx`: accumulate gradients over `xxx` steps to simulate a larger batch size.
* `--fp16`: enables mixed-precision training; if set, `--per_gpu_train_batch_size` should be a multiple of 8.
* `--max_seq xxx`: should be larger than the length of the longest sequence; `--max_seq 1024` is a safe choice, since the script uses a dynamic sequence length at each training step anyway.
* `--gradient_checkpointing`: trades extra computation for memory, allowing a larger `--per_gpu_train_batch_size`.
* `--use_multi_tensor_adamw`: uses `torch.optim._multi_tensor.AdamW`, which is reported to be a faster optimizer implementation.
Distributed data parallel:
If multiple GPUs are available, you can run `python -m torch.distributed.launch --nproc_per_node CUDA_COUNT train.py ......`, where `CUDA_COUNT` is the number of GPUs and `......` stands for the usual arguments of `train.py`. A combined example is sketched below.
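Putting several of these options together, a single-GPU run might look like this (a sketch only; batch size, accumulation steps and data paths are illustrative and should be adapted to your data and hardware):
```bash
python train.py --output_dir=trained_output --model_type=gpt2 --model_name_or_path=scgpt \
    --do_train --do_eval --train_data_file=multiwoz/data/train_sys.txt --eval_data_file=multiwoz/data/test_sys.txt \
    --use_tokenize --overwrite_output_dir \
    --fp16 --per_gpu_train_batch_size 8 --gradient_accumulation_steps 4 \
    --gradient_checkpointing --max_seq 1024
```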
## Use
```python
......
import warnings
from contextlib import nullcontext
from typing import TYPE_CHECKING
import torch.cuda.amp as amp
import transformers
from transformers import GPT2LMHeadModel
# reference: https://pytorch.org/docs/master/notes/amp_examples.html
class AmpGPT2LMHeadModel(GPT2LMHeadModel):
if TYPE_CHECKING:
# For IDE's code hinting
forward = GPT2LMHeadModel.forward
else:
def forward(self, *args, **kwargs):
with amp.autocast():
return super().forward(*args, **kwargs)
def try_enable_gradient_checkpointing(model: "transformers.modeling_utils.PreTrainedModel"):
if model.supports_gradient_checkpointing:
model.gradient_checkpointing_enable()
else:
warnings.warn(f"{type(model)} doesn't support gradient_checkpointing")
class AmpHelper:
"""
References:
https://pytorch.org/docs/master/notes/amp_examples.html
"""
def __init__(self, use_amp=True):
self.use_amp = use_amp
self.might_enable_autocast = amp.autocast() if use_amp else nullcontext()
self.scaler = amp.GradScaler()
def backward(self, loss):
if self.use_amp:
return self.scaler.scale(loss).backward()
else:
return loss.backward()
def step(self, optimizer):
if self.use_amp:
self.scaler.step(optimizer)
self.scaler.update()
else:
optimizer.step()
def might_unscale_(self, optimizer):
if self.use_amp:
# Unscales the gradients of optimizer's assigned params in-place
self.scaler.unscale_(optimizer)
\ No newline at end of file
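# A minimal usage sketch of AmpHelper (mirroring the training loop in train.py below);
# `model`, `optimizer`, `dataloader` and `max_grad_norm` are placeholders assumed for illustration.
import torch

def _amp_training_loop_sketch(model, optimizer, dataloader, max_grad_norm=1.0, use_amp=True):
    amp_helper = AmpHelper(use_amp=use_amp)
    for inputs, labels in dataloader:
        with amp_helper.might_enable_autocast:  # autocast is active only when use_amp is True
            loss = model(inputs, labels=labels)[0]
        amp_helper.backward(loss)               # scales the loss before backward when use_amp is True
        amp_helper.might_unscale_(optimizer)    # unscale gradients in-place before clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        amp_helper.step(optimizer)              # scaler.step + scaler.update, or a plain optimizer.step
        optimizer.zero_grad()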
......@@ -6,6 +6,7 @@ Created on Mon Sep 14 11:38:53 2020
import os
import json
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from convlab2.nlg.scgpt.utils import dict2dict, dict2seq
import zipfile
......@@ -14,6 +15,52 @@ def read_zipped_json(filepath, filename):
archive = zipfile.ZipFile(filepath, 'r')
return json.load(archive.open(filename))
def init_domain():
return {'Attraction':False,
'Hospital':False,
'Hotel':False,
'Police':False,
'Restaurant':False,
'Taxi':False,
'Train':False}
def write_file(name, data, role='usr'):
with open(f'{name}.txt', 'w', encoding='utf-8') as f:
for ID in data:
sess = data[ID]
sess_domains = init_domain()
for turn in sess:
if role == 'usr':
if not turn['usr_da']:
continue
turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
elif role == 'sys':
if not turn['sys_da']:
continue
turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
else:
raise NameError('Invalid Role: Select usr/sys.')
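# Mark the first turn in which each domain appears in this session by appending ' *' to the domain token in da_seq.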
for domain in domains:
if domain not in ['general', 'Booking'] and not sess_domains[domain]:
da_seq = da_seq.replace(domain.lower(), domain.lower()+' *', 1)
sess_domains[domain] = True
if role == 'usr':
da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
elif role == 'sys':
da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
f.write(f'{da_seq} & {da_uttr}\n')
if __name__ == '__main__':
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--role', type=str, default='usr')
args = parser.parse_args()
cur_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
cur_dir)))), 'data/multiwoz/')
......@@ -37,22 +84,22 @@ results_test = {}
for title, sess in data.items():
logs = sess['log']
turns = []
turn = {'turn':0, 'sys':'', 'sys_da':''}
turn = {'turn': 0, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
current_domain = None
for i, diag in enumerate(logs):
text = diag['text']
da = diag['dialog_act']
span = diag['span_info']
if i % 2 == 0:
turn['usr'] = text
if current_domain:
da = eval(str(da).replace('Booking', current_domain))
span = eval(str(span).replace('Booking', current_domain))
if i % 2 == 0:
turn['usr'] = text
turn['usr_da'] = da
turn['usr_span'] = span
turns.append(turn)
else:
turn = {'turn': i//2 +1}
turn = {'turn': i//2 + 1, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
turn['sys'] = text
turn['sys_da'] = da
turn['sys_span'] = span
......@@ -60,6 +107,9 @@ for title, sess in data.items():
domain = key.split('-')[0]
if domain not in ['general', 'Booking']:
current_domain = domain
else:
if args.role == 'sys':
turns.append(turn)
title = title
if title in val_list:
current = results_val
......@@ -73,41 +123,7 @@ results = eval(str(results).replace(" n't", " not"))
results_val = eval(str(results_val).replace(" n't", " not"))
results_test = eval(str(results_test).replace(" n't", " not"))
def init_domain():
return {'Attraction':False,
'Hospital':False,
'Hotel':False,
'Police':False,
'Restaurant':False,
'Taxi':False,
'Train':False}
def write_file(name, data):
with open(f'{name}.txt', 'w', encoding='utf-8') as f:
for ID in data:
sess = data[ID]
sess_domains = init_domain()
for turn in sess:
# TODO: set option to process usr/sys
if not turn['usr_da']:
continue
turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
if not turn['sys_da']:
continue
turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
for domain in domains:
if domain not in ['general', 'Booking'] and not sess_domains[domain]:
da_seq = da_seq.replace(domain.lower(), domain.lower()+' *', 1)
sess_domains[domain] = True
da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
f.write(f'{da_seq} & {da_uttr}\n')
if not os.path.exists(os.path.join(cur_dir,'data')):
os.makedirs(os.path.join(cur_dir, 'data'))
write_file(os.path.join(cur_dir, 'data/train'), dict(results, **results_val))
write_file(os.path.join(cur_dir, 'data/test'), results_test)
write_file(os.path.join(cur_dir, f'data/train_{args.role}'), dict(results, **results_val), role=args.role)
write_file(os.path.join(cur_dir, f'data/test_{args.role}'), results_test, role=args.role)
......@@ -2,6 +2,7 @@ import torch
import numpy as np
import os
import zipfile
from copy import deepcopy
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from convlab2.nlg.scgpt.utils import tuple2seq
......@@ -10,23 +11,31 @@ from convlab2.nlg.nlg import NLG
from convlab2.util.file_util import cached_path
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "nlg-gpt-multiwoz.zip")
class SCGPT(NLG):
def __init__(self,
archive_file=DEFAULT_ARCHIVE_FILE,
use_cuda=True,
is_user=False,
model_file='https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'):
def __init__(self, model_file=None,
use_cuda=True, is_user=False):
# If no filename is mentioned then set to default
if not model_file:
if is_user:
model_file = 'https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'
else:
model_file = 'https://zenodo.org/record/5767426/files/neo_scgpt_system.zip'
# Load from file/url
model_dir = os.path.dirname(os.path.abspath(__file__))
if not os.path.isfile(archive_file):
archive_file = cached_path(model_file)
archive = zipfile.ZipFile(archive_file, 'r')
if not os.path.isfile(model_file):
model_file = cached_path(model_file)
if not os.path.isdir(model_file):
archive = zipfile.ZipFile(model_file, 'r')
archive.extractall(model_dir)
# Get model directory
model_file = archive.filelist[0].filename.replace('/', '')
self.model_name_or_path = os.path.join(model_dir, model_file)
else:
self.model_name_or_path = model_file
self.model_name_or_path = os.path.join(model_dir, 'multiwoz')
self.length = 50
self.num_samples = 5
self.temperature = 1.0
......@@ -63,8 +72,9 @@ class SCGPT(NLG):
'Restaurant':False,
'Taxi':False,
'Train':False,}
if not self.is_user:
self.sess_domains['Booking'] = False
self.cur_domain = None
# if not self.is_user:
# self.sess_domains['Booking'] = False
def generate(self, meta):
......@@ -72,10 +82,23 @@ class SCGPT(NLG):
if not meta:
return 'No user action'
meta = deepcopy(meta)
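# Track the most recent concrete domain and substitute it for the generic 'Booking' domain below.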
for list_ in meta:
domain = list_[1]
if domain not in ('general', 'Booking'):
self.cur_domain = domain
for i, list_ in enumerate(meta):
list_ = list(list_)
if list_[1] == 'Booking':
if self.cur_domain is not None:
list_[1] = self.cur_domain
meta[i] = list_
else:
print('`cur_domain` is None, but there is `Booking` in dialog action.')
raw_text = tuple2seq(meta)
domains = set([item[1] for item in meta])
for domain in domains:
if domain != 'general' and not self.sess_domains[domain]:
if domain not in ('general', 'Booking') and not self.sess_domains[domain]:
raw_text = raw_text.replace(domain.lower(), domain.lower()+ ' *', 1)
self.sess_domains[domain] = True
context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False)
......
......@@ -9,33 +9,28 @@ import random
import re
import shutil
import sys
import numpy as np
import torch
from tqdm import tqdm, trange
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
from torch.utils.data.distributed import DistributedSampler
try:
from torch.utils.tensorboard import SummaryWriter
except:
except ImportError:
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
BertConfig, BertForMaskedLM, BertTokenizer,
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
BertConfig, BertForMaskedLM, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2TokenizerFast,
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer, BertTokenizer)
from convlab2.nlg.scgpt.modeling_utils import AmpGPT2LMHeadModel, try_enable_gradient_checkpointing, AmpHelper
logger = logging.getLogger(__name__)
MODEL_CLASSES = {
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast),
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
......@@ -43,11 +38,20 @@ MODEL_CLASSES = {
}
def closest_multiple_of_8(n):
"""
Returns:
the smallest multiple of 8 that is >= n
"""
return ((n + 7) >> 3) << 3
class TextDataset(Dataset):
def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80):
assert os.path.isfile(file_path)
directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(
block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
if os.path.exists(cached_features_file) and not args.overwrite_cache:
logger.info("Loading features from cached file %s", cached_features_file)
......@@ -71,9 +75,8 @@ class TextDataset(Dataset):
for i in range(0, len(tokenized_text) - block_size + 1, block_size): # Truncate in block of block_size
self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i + block_size]))
# Note that we are losing the last truncated example here for the sake of simplicity (no padding)
# If your dataset is small, first you should loook for a bigger one :-) and second you
# If your dataset is small, first you should look for a bigger one :-) and second you
# can change this behavior by adding (model specific) padding.
logger.info("Saving features into cached file %s", cached_features_file)
......@@ -86,26 +89,30 @@ class TextDataset(Dataset):
def __getitem__(self, item):
return torch.tensor(self.examples[item])
class TextSeqDataset(Dataset):
def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80, seperator=' & '):
def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80, separator=' & '):
max_seq = closest_multiple_of_8(max_seq)
assert os.path.isfile(file_path)
directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(directory, args.output_dir.replace(os.sep, '_') + '_cached_lm_' + str(block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
cached_features_file = os.path.join(directory, args.output_dir.replace(os.sep, '_') + '_cached_lm_' + str(
block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
if os.path.exists(cached_features_file) and not args.overwrite_cache:
logger.info("Loading features from cached file %s", cached_features_file)
with open(cached_features_file, 'rb') as handle:
self.examples = pickle.load(handle)
self.examples, self.masks, self.labels, self.seq_lengths = pickle.load(handle)
else:
logger.info("Creating features from dataset file at %s", directory)
self.examples = []
self.labels = []
self.masks = []
self.seq_lengths = []
with open(file_path, encoding="utf-8") as f:
for line in f:
for line in tqdm(f):
line = line.strip()
raw_str = line.lower()
code_str = line.lower().split(seperator)[0] + seperator
raw_str = line.lower() # do we need lowercase?
code_str = line.lower().split(separator)[0] + separator
code_str = code_str.strip()
if len(raw_str.split()) > max_seq -1:
raw_str = ' '.join(raw_str.split()[:max_seq -1])
......@@ -121,12 +128,13 @@ class TextSeqDataset(Dataset):
label[:len(tokenized_text)] = tokenized_text
mask = [1] * max_seq
if len(tokenized_text) < max_seq:
self.seq_lengths.append(len(tokenized_text))
mask[-(max_seq - len(tokenized_text)):] = [0] * (max_seq - len(tokenized_text))
# label[code_str_len:len(tokenized_text)] = tokenized_text[code_str_len:]
tokenized_text = tokenized_text + [0] * (max_seq - len(tokenized_text))
tokenized_text = tokenized_text + [tokenizer.eos_token_id] * (max_seq - len(tokenized_text))
else:
self.seq_lengths.append(max_seq)
tokenized_text = tokenized_text[:max_seq]
# label[code_str_len:] = tokenized_text[code_str_len:]
......@@ -135,23 +143,26 @@ class TextSeqDataset(Dataset):
self.labels.append(label)
# Note that we are losing the last truncated example here for the sake of simplicity (no padding)
# If your dataset is small, first you should loook for a bigger one :-) and second you
# If your dataset is small, first you should look for a bigger one :-) and second you
# can change this behavior by adding (model specific) padding.
if args.with_code_loss:
self.labels = self.examples
logger.info("Saving features into cached file %s", cached_features_file)
with open(cached_features_file, 'wb') as handle:
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump((self.examples, self.masks, self.labels, self.seq_lengths), handle,
protocol=pickle.HIGHEST_PROTOCOL)
def __len__(self):
return len(self.examples)
def __getitem__(self, item):
return torch.tensor(self.examples[item]), torch.tensor(self.masks[item]), torch.tensor(self.labels[item])
return torch.tensor(self.examples[item]), torch.tensor(self.masks[item]), torch.tensor(
self.labels[item]), torch.tensor(self.seq_lengths[item])
def load_and_cache_examples(args, tokenizer, evaluate=False):
dataset = TextSeqDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size, max_seq=args.max_seq)
dataset = TextSeqDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file,
block_size=args.block_size, max_seq=args.max_seq)
return dataset
......@@ -197,7 +208,8 @@ def mask_tokens(inputs, tokenizer, args):
labels = inputs.clone()
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
probability_matrix = torch.full(labels.shape, args.mlm_probability)
special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in
labels.tolist()]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -1 # We only compute loss on masked tokens
......@@ -215,6 +227,23 @@ def mask_tokens(inputs, tokenizer, args):
return inputs, labels
def preprocess_batch(inputs, masks, labels, seq_lengths):
"""
The longest real sequence in a batch may be shorter than max_seq of the whole dataset.
Trim the trailing padding tokens to accelerate training,
while keeping the sequence length a multiple of 8.
References:
https://huggingface.co/transformers/performance.html#fp16
"""
# The gain for FP16 training is that in each of those cases, the training with the flag --fp16 is twice as fast,
# which does require every tensor to have every dimension be a multiple of 8
# (examples pad the tensors to a sequence length that is a multiple of 8).
max_seq_len = seq_lengths.max()
max_seq_len = closest_multiple_of_8(max_seq_len)
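# e.g. if the longest sequence in this batch has 45 tokens, the tensors are trimmed to 48 columns (45 rounded up to a multiple of 8)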
return inputs[:, :max_seq_len], masks[:, :max_seq_len], labels[:, :max_seq_len]
def train(args, train_dataset, model, tokenizer):
""" Train the model """
if args.local_rank in [-1, 0]:
......@@ -233,27 +262,23 @@ def train(args, train_dataset, model, tokenizer):
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
model.resize_token_embeddings(len(tokenizer))
# multi-gpu training (should be after apex fp16 initialization)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
num_training_steps=t_total)
# https://pytorch.org/docs/master/notes/amp_examples.html
amp_helper = AmpHelper(use_amp=args.fp16)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
# Distributed training
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True)
find_unused_parameters=False)
# Train!
logger.info("***** Running training *****")
......@@ -261,7 +286,8 @@ def train(args, train_dataset, model, tokenizer):
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
args.train_batch_size * args.gradient_accumulation_steps * (
torch.distributed.get_world_size() if args.local_rank != -1 else 1))
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
......@@ -276,7 +302,8 @@ def train(args, train_dataset, model, tokenizer):
for step, batch in enumerate(train_dataloader):
# inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
logger.info(f" PROGRESS: {float(global_step) / t_total * 100}%")
inputs, masks, labels = batch
inputs, masks, labels, seq_lengths = batch
inputs, masks, labels = preprocess_batch(inputs, masks, labels, seq_lengths) # cut seq
# import pdb
# pdb.set_trace()
inputs = inputs.to(args.device)
......@@ -284,6 +311,8 @@ def train(args, train_dataset, model, tokenizer):
labels = labels.to(args.device)
model.train()
try:
with amp_helper.might_enable_autocast:
outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
......@@ -292,19 +321,19 @@ def train(args, train_dataset, model, tokenizer):
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
amp_helper.backward(loss)
except RuntimeError as e:
if 'CUDA out of memory' in str(e):
# if out of memory, we must choose smaller batch_size
print(f'inputs.shape = {inputs.shape}, labels.shape = {labels.shape}')
raise
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
amp_helper.might_unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
# optimizer.step()
amp_helper.step(optimizer)
scheduler.step() # Update learning rate schedule
model.zero_grad()
global_step += 1
......@@ -326,7 +355,8 @@ def train(args, train_dataset, model, tokenizer):
output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save = model.module if hasattr(model,
'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
......@@ -334,10 +364,7 @@ def train(args, train_dataset, model, tokenizer):
_rotate_checkpoints(args, checkpoint_prefix)
# if args.max_steps > 0 and global_step > args.max_steps:
# epoch_iterator.close()
# break
if args.max_steps > 0 and global_step > args.max_steps:
if global_step > args.max_steps > 0:
train_iterator.close()
break
......@@ -362,7 +389,9 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu evaluate
if args.n_gpu > 1:
if args.n_gpu > 1 and not (isinstance(model, torch.nn.DataParallel) or
isinstance(model, torch.nn.parallel.DistributedDataParallel)):
# if args.evaluate_during_training, DataParallel is already used
model = torch.nn.DataParallel(model)
# Eval!
......@@ -376,7 +405,8 @@ def evaluate(args, model, tokenizer, prefix=""):
for batch in tqdm(eval_dataloader, desc="Evaluating"):
# inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
inputs, masks, labels = batch
inputs, masks, labels, seq_lengths = batch
inputs, masks, labels = preprocess_batch(inputs, masks, labels, seq_lengths) # cut seq
# import pdb
# pdb.set_trace()
inputs = inputs.to(args.device)
......@@ -387,12 +417,12 @@ def evaluate(args, model, tokenizer, prefix=""):
with torch.no_grad():
outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
lm_loss = outputs[0]
eval_loss += lm_loss.mean().item()
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
eval_loss += loss.mean().item()
nb_eval_steps += 1
eval_loss = eval_loss / nb_eval_steps
perplexity = torch.exp(torch.tensor(eval_loss))
perplexity = float(np.exp(eval_loss))
result = {
"perplexity": perplexity
......@@ -409,6 +439,7 @@ def evaluate(args, model, tokenizer, prefix=""):
def main():
global AdamW
parser = argparse.ArgumentParser()
## Required parameters
......@@ -489,10 +520,7 @@ def main():
help="random seed for initialization")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
help="Whether to use 16-bit (mixed) precision (through torch.cuda.amp) instead of 32-bit")
parser.add_argument("--local_rank", type=int, default=-1,
help="For distributed training: local_rank")
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
......@@ -504,18 +532,32 @@ def main():
parser.add_argument("--max_seq", default=80, type=int,
help="")
parser.add_argument('--gradient_checkpointing', action='store_true', help='enable gradient checkpointing')
parser.add_argument('--use_multi_tensor_adamw', action='store_true',
help='use torch.optim._multi_tensor.AdamW instead of transformers.AdamW')
args = parser.parse_args()
if args.use_multi_tensor_adamw:
try:
# overwrite the previous imported AdamW
# https://huggingface.co/transformers/performance.html#faster-optimizer
from torch.optim._multi_tensor import AdamW
except ImportError as e:
print(e)
if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm:
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
"flag (masked language modeling).")
if args.eval_data_file is None and args.do_eval:
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
raise ValueError(
"Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
"or remove the --do_eval argument.")
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
if os.path.exists(args.output_dir) and os.listdir(
args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
args.output_dir))
# Setup distant debugging if needed
if args.server_ip and args.server_port:
......@@ -525,6 +567,11 @@ def main():
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
# Setup logging before `torch.distributed.init_process_group` is called
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
# Setup CUDA, GPU & distributed training
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
......@@ -535,11 +582,6 @@ def main():
torch.distributed.init_process_group(backend='nccl')
args.n_gpu = 1
args.device = device
# Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
......@@ -550,6 +592,8 @@ def main():
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
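# With --fp16, swap in AmpGPT2LMHeadModel so that forward() runs under torch.cuda.amp.autocast()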
if args.fp16:
MODEL_CLASSES['gpt2'] = (GPT2Config, AmpGPT2LMHeadModel, GPT2TokenizerFast)
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
cache_dir=args.cache_dir if args.cache_dir else None)
......@@ -565,7 +609,13 @@ def main():
from_tf=bool('.ckpt' in args.model_name_or_path),
config=config,
cache_dir=args.cache_dir if args.cache_dir else None)
if model.config.vocab_size != len(tokenizer):
logger.info('resize token embeddings, since there may be added tokens.')
model.resize_token_embeddings(len(tokenizer))
model.to(args.device)
if args.gradient_checkpointing:
# https://huggingface.co/transformers/performance.html#gradient-checkpointing
try_enable_gradient_checkpointing(model)
if args.local_rank == 0:
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training download model & vocab
......@@ -585,7 +635,6 @@ def main():
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Create output directory if needed
......@@ -595,7 +644,8 @@ def main():
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save = model.module if hasattr(model,
'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
......@@ -607,13 +657,13 @@ def main():
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
model.to(args.device)
# Evaluation
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
checkpoints = list(
os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
......@@ -625,7 +675,6 @@ def main():
result = evaluate(args, model, tokenizer, prefix=prefix)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
results.update(result)
return results
......