diff --git a/convlab2/nlg/evaluate.py b/convlab2/nlg/evaluate.py
index e2cffc4553060c9c6c5e812c0ccd317981274fd6..1a4747b7f19a47f2c069e4ba286c9ad16763043b 100755
--- a/convlab2/nlg/evaluate.py
+++ b/convlab2/nlg/evaluate.py
@@ -8,6 +8,7 @@ import json
 import os
 import random
 import sys
+import itertools
 import zipfile
 import numpy
 from numpy.lib.shape_base import _put_along_axis_dispatcher
@@ -211,16 +212,18 @@ if __name__ == '__main__':
     numpy.random.seed(seed)
     torch.manual_seed(seed)
 
-    if len(sys.argv) != 4:
+    if len(sys.argv) < 4:
         print("usage:")
         print("\t python evaluate.py dataset model role")
         print("\t dataset=MultiWOZ, CrossWOZ, or Camrest")
         print("\t model=SCLSTM, SCLSTM_NoUNK, SCGPT or TemplateNLG")
         print("\t role=usr/sys")
+        print("\t [Optional] model_file")
         sys.exit()
     dataset_name = sys.argv[1]
     model_name = sys.argv[2]
     role = sys.argv[3]
+    model_file = sys.argv[4] if len(sys.argv) >= 5 else None
     if dataset_name == 'MultiWOZ':
         if model_name == 'SCLSTM':
             from convlab2.nlg.sclstm.multiwoz import SCLSTM
@@ -242,17 +245,19 @@ if __name__ == '__main__':
                 model = TemplateNLG(is_user=False)
         elif model_name == 'SCGPT':
             from convlab2.nlg.scgpt.multiwoz import SCGPT
+            if model_file is not None:
+                print(f"load model at {model_file}")
             if role == 'usr':
-                model = SCGPT(is_user=True)
+                model = SCGPT(model_file, is_user=True)
             elif role == 'sys':
-                model  = SCGPT(is_user=False, model_file='scgpt/trained_output/multiwoz/')
+                model  = SCGPT(model_file, is_user=False)
         else:
             raise Exception("Available models: SCLSTM, SCGPT, TEMPLATE")
 
         from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader
         from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
         dataloader = SingleTurnNLGDataloader(dataset_dataloader=MultiWOZDataloader())
-        data = dataloader.load_data(data_key='all', role=role)['test']
+        data = dataloader.load_data(data_key='all', role=role, session_id=True)['test']
 
         dialog_acts = []
         golden_utts = []
@@ -262,7 +267,19 @@ if __name__ == '__main__':
         sen_num = 0
 
         # sys.stdout = open(sys.argv[2] + '-' + sys.argv[3] + '-' + 'evaluate_logs_neo.txt','w')
+        assert 'utterance' in data and 'dialog_act' in data and 'session_id' in data
+        assert len(data['utterance']) == len(data['dialog_act']) == len(data['session_id'])
+
+        # Turns during the same session should be contiguous, so we can call init_session at the first turn of a new session.
+        # This is necessary for SCGPT, but unnecessary for SCLSTM and TemplateNLG.
+        is_first_turn = []
+        for _, iterator in itertools.groupby(data['session_id']):
+            is_first_turn.append(True)
+            next(iterator)
+            is_first_turn.extend(False for _ in iterator)
         for i in tqdm(range(len(data['utterance']))):
+            if is_first_turn[i]:
+                model.init_session()
             dialog_acts.append(data['dialog_act'][i])
             golden_utts.append(data['utterance'][i])
             gen_utts.append(model.generate(data['dialog_act'][i]))
diff --git a/convlab2/nlg/scgpt/README.md b/convlab2/nlg/scgpt/README.md
index b8630eeb2bcccbf454539883f512a42a4bebd4f3..5eed2c0fc167cd9ee79d66e3252be25060bb294d 100644
--- a/convlab2/nlg/scgpt/README.md
+++ b/convlab2/nlg/scgpt/README.md
@@ -21,9 +21,22 @@ tar -xvf scgpt.tar.gz
 Then
 
 ``` python
-python train.py --output_dir=trained_output --model_type=gpt2 --model_name_or_path=scgpt --do_train --do_eval --eval_data_file=multiwoz/data/test_sys.txt --overwrite_cache --use_tokenize --train_data_file=multiwoz/data/train_sys.txt --overwrite_output_dir
+python train.py --output_dir=trained_output --model_type=gpt2 --model_name_or_path=scgpt --do_train --do_eval --eval_data_file=multiwoz/data/test_sys.txt --use_tokenize --train_data_file=multiwoz/data/train_sys.txt --overwrite_output_dir
 ```
 
+some tricks (optional training argument):
+* `--gradient_accumulation_steps xxx` 
+* `--fp16`, if it's set, you'd better set `--per_gpu_train_batch_size` to be multiple of 8
+* `--max_seq xxx`, it should be larger than the length of the longest sequence. You can set `--max_seq 1024`. The script uses a dynamic sequence length at each training step.
+* `--gradient_checkpointing`, it allows larger `per_gpu_train_batch_size`
+* `--use_multi_tensor_adamw`, someone says it's a faster optimizer
+
+distributed data parallel:
+
+  If multiple GPUs are available, you can run `python -m torch.distributed.launch --nproc_per_node CUDA_COUNT train.py ......` 
+
+  `CUDA_COUNT` is the number of GPUs. `.....` are arguments of `train.py`.
+
 ## Use
 
 ```python
diff --git a/convlab2/nlg/scgpt/modeling_utils.py b/convlab2/nlg/scgpt/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8b3f6ddfc6b7347c624446bf7869c67d3064cc1
--- /dev/null
+++ b/convlab2/nlg/scgpt/modeling_utils.py
@@ -0,0 +1,53 @@
+import warnings
+from contextlib import nullcontext
+from typing import TYPE_CHECKING
+import torch.cuda.amp as amp
+import transformers
+from transformers import GPT2LMHeadModel
+
+
+# reference: https://pytorch.org/docs/master/notes/amp_examples.html
+class AmpGPT2LMHeadModel(GPT2LMHeadModel):
+    if TYPE_CHECKING:
+        # For IDE's code hinting
+        forward = GPT2LMHeadModel.forward
+    else:
+        def forward(self, *args, **kwargs):
+            with amp.autocast():
+                return super().forward(*args, **kwargs)
+
+
+def try_enable_gradient_checkpointing(model: "transformers.modeling_utils.PreTrainedModel"):
+    if model.supports_gradient_checkpointing:
+        model.gradient_checkpointing_enable()
+    else:
+        warnings.warn(f"{type(model)} doesn't support gradient_checkpointing")
+
+
+class AmpHelper:
+    """
+    References:
+        https://pytorch.org/docs/master/notes/amp_examples.html
+    """
+    def __init__(self, use_amp=True):
+        self.use_amp = use_amp
+        self.might_enable_autocast = amp.autocast() if use_amp else nullcontext()
+        self.scaler = amp.GradScaler()
+
+    def backward(self, loss):
+        if self.use_amp:
+            return self.scaler.scale(loss).backward()
+        else:
+            return loss.backward()
+
+    def step(self, optimizer):
+        if self.use_amp:
+            self.scaler.step(optimizer)
+            self.scaler.update()
+        else:
+            optimizer.step()
+
+    def might_unscale_(self, optimizer):
+        if self.use_amp:
+            # Unscales the gradients of optimizer's assigned params in-place
+            self.scaler.unscale_(optimizer)
\ No newline at end of file
diff --git a/convlab2/nlg/scgpt/multiwoz/preprocess.py b/convlab2/nlg/scgpt/multiwoz/preprocess.py
index 10e588886f4e316fbade8eba04059b09f86030ff..ae7b08566842435247b09625b5410bc741c58db8 100644
--- a/convlab2/nlg/scgpt/multiwoz/preprocess.py
+++ b/convlab2/nlg/scgpt/multiwoz/preprocess.py
@@ -6,6 +6,7 @@ Created on Mon Sep 14 11:38:53 2020
 
 import os
 import json
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
 from convlab2.nlg.scgpt.utils import dict2dict, dict2seq
 import zipfile
 
@@ -14,65 +15,6 @@ def read_zipped_json(filepath, filename):
     archive = zipfile.ZipFile(filepath, 'r')
     return json.load(archive.open(filename))
 
-cur_dir = os.path.dirname(os.path.abspath(__file__)) 
-data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
-        cur_dir)))), 'data/multiwoz/')
-
-keys = ['train', 'val', 'test']
-data = {}
-for key in keys:
-    data_key = read_zipped_json(os.path.join(data_dir, key + '.json.zip'), key + '.json')
-    print('load {}, size {}'.format(key, len(data_key)))
-    data = dict(data, **data_key)
-
-with open(os.path.join(data_dir, 'valListFile'), 'r') as f:
-    val_list = f.read().splitlines()
-with open(os.path.join(data_dir, 'testListFile'), 'r') as f:
-    test_list = f.read().splitlines()
-    
-results = {}
-results_val = {}
-results_test = {}
-
-for title, sess in data.items():
-    logs = sess['log']
-    turns = []
-    turn = {'turn':0, 'sys':'', 'sys_da':''}
-    current_domain = None
-    for i, diag in enumerate(logs):
-        text = diag['text']
-        da = diag['dialog_act']
-        span = diag['span_info']
-        if i % 2 == 0:
-            turn['usr'] = text
-            if current_domain:
-                da = eval(str(da).replace('Booking', current_domain))
-                span = eval(str(span).replace('Booking', current_domain))
-            turn['usr_da'] = da
-            turn['usr_span'] = span
-            turns.append(turn)
-        else:
-            turn = {'turn': i//2 +1}
-            turn['sys'] = text
-            turn['sys_da'] = da
-            turn['sys_span'] = span
-        for key in da:
-            domain = key.split('-')[0]
-            if domain not in ['general', 'Booking']:
-                current_domain = domain
-    title = title
-    if title in val_list:
-        current = results_val
-    elif title in test_list:
-        current = results_test
-    else:
-        current = results
-    current[title] = turns
-    
-results = eval(str(results).replace(" n't", " not"))
-results_val = eval(str(results_val).replace(" n't", " not"))
-results_test = eval(str(results_test).replace(" n't", " not"))
-
 def init_domain():
     return {'Attraction':False,
             'Hospital':False,
@@ -82,32 +24,106 @@ def init_domain():
             'Taxi':False,
             'Train':False}
 
-def write_file(name, data):
+def write_file(name, data, role='usr'):
     with open(f'{name}.txt', 'w', encoding='utf-8') as f:
         for ID in data:
             sess = data[ID]
             sess_domains = init_domain()
             for turn in sess:
-                # TODO: set option to process usr/sys
-                if not turn['usr_da']:
-                    continue
-                turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
-                da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
-                domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
-                if not turn['sys_da']:
-                    continue
-                turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
-                da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
-                domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
+                if role == 'usr':
+                    if not turn['usr_da']:
+                        continue
+                    turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
+                    da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
+                    domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
+                elif role == 'sys':
+                    if not turn['sys_da']:
+                        continue
+                    turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
+                    da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
+                    domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
+                else:
+                    raise NameError('Invalid Role: Select usr/sys.')
                 for domain in domains:
                     if domain not in ['general', 'Booking'] and not sess_domains[domain]:
                         da_seq = da_seq.replace(domain.lower(), domain.lower()+' *', 1)
                         sess_domains[domain] = True
-                da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
-                da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
+
+                if role == 'usr':
+                    da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
+                elif role == 'sys':
+                    da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
                 f.write(f'{da_seq} & {da_uttr}\n')
 
-if not os.path.exists(os.path.join(cur_dir,'data')):
-    os.makedirs(os.path.join(cur_dir, 'data'))
-write_file(os.path.join(cur_dir, 'data/train'), dict(results, **results_val))
-write_file(os.path.join(cur_dir, 'data/test'), results_test)
+
+if __name__ == '__main__':
+    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--role', type=str, default='usr')
+    args = parser.parse_args()
+
+    cur_dir = os.path.dirname(os.path.abspath(__file__))
+    data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
+            cur_dir)))), 'data/multiwoz/')
+
+    keys = ['train', 'val', 'test']
+    data = {}
+    for key in keys:
+        data_key = read_zipped_json(os.path.join(data_dir, key + '.json.zip'), key + '.json')
+        print('load {}, size {}'.format(key, len(data_key)))
+        data = dict(data, **data_key)
+
+    with open(os.path.join(data_dir, 'valListFile'), 'r') as f:
+        val_list = f.read().splitlines()
+    with open(os.path.join(data_dir, 'testListFile'), 'r') as f:
+        test_list = f.read().splitlines()
+
+    results = {}
+    results_val = {}
+    results_test = {}
+
+    for title, sess in data.items():
+        logs = sess['log']
+        turns = []
+        turn = {'turn': 0, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
+        current_domain = None
+        for i, diag in enumerate(logs):
+            text = diag['text']
+            da = diag['dialog_act']
+            span = diag['span_info']
+            if current_domain:
+                da = eval(str(da).replace('Booking', current_domain))
+                span = eval(str(span).replace('Booking', current_domain))
+            if i % 2 == 0:
+                turn['usr'] = text
+                turn['usr_da'] = da
+                turn['usr_span'] = span
+                turns.append(turn)
+            else:
+                turn = {'turn': i//2 + 1, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
+                turn['sys'] = text
+                turn['sys_da'] = da
+                turn['sys_span'] = span
+            for key in da:
+                domain = key.split('-')[0]
+                if domain not in ['general', 'Booking']:
+                    current_domain = domain
+        else:
+            if args.role == 'sys':
+                turns.append(turn)
+        title = title
+        if title in val_list:
+            current = results_val
+        elif title in test_list:
+            current = results_test
+        else:
+            current = results
+        current[title] = turns
+
+    results = eval(str(results).replace(" n't", " not"))
+    results_val = eval(str(results_val).replace(" n't", " not"))
+    results_test = eval(str(results_test).replace(" n't", " not"))
+
+    if not os.path.exists(os.path.join(cur_dir,'data')):
+        os.makedirs(os.path.join(cur_dir, 'data'))
+    write_file(os.path.join(cur_dir, f'data/train_{args.role}'), dict(results, **results_val), role=args.role)
+    write_file(os.path.join(cur_dir, f'data/test_{args.role}'), results_test, role=args.role)
diff --git a/convlab2/nlg/scgpt/multiwoz/scgpt.py b/convlab2/nlg/scgpt/multiwoz/scgpt.py
index b4b957c0d1b17a3805f388641dcd14bfe0b32fa2..78f16f6e0b8562c7118a2a6118f0eb5b3287c828 100644
--- a/convlab2/nlg/scgpt/multiwoz/scgpt.py
+++ b/convlab2/nlg/scgpt/multiwoz/scgpt.py
@@ -2,6 +2,7 @@ import torch
 import numpy as np
 import os
 import zipfile
+from copy import deepcopy
 
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from convlab2.nlg.scgpt.utils import tuple2seq
@@ -10,23 +11,31 @@ from convlab2.nlg.nlg import NLG
 from convlab2.util.file_util import cached_path
 
 MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
-DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
-DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "nlg-gpt-multiwoz.zip")
 
 class SCGPT(NLG):
     
-    def __init__(self,
-                 archive_file=DEFAULT_ARCHIVE_FILE,
-                 use_cuda=True,
-                 is_user=False,
-                 model_file='https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'):
+    def __init__(self, model_file=None,
+                 use_cuda=True, is_user=False):
+        # If no filename is mentioned then set to default
+        if not model_file:
+            if is_user:
+                model_file = 'https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'
+            else:
+                model_file = 'https://zenodo.org/record/5767426/files/neo_scgpt_system.zip'
+
+        # Load from file/url
         model_dir = os.path.dirname(os.path.abspath(__file__))
-        if not os.path.isfile(archive_file):
-            archive_file = cached_path(model_file)
-            archive = zipfile.ZipFile(archive_file, 'r')
+        if not os.path.isfile(model_file):
+            model_file = cached_path(model_file)
+        if not os.path.isdir(model_file):
+            archive = zipfile.ZipFile(model_file, 'r')
             archive.extractall(model_dir)
-        
-        self.model_name_or_path = os.path.join(model_dir, 'multiwoz')
+            # Get model directory
+            model_file = archive.filelist[0].filename.replace('/', '')
+            self.model_name_or_path = os.path.join(model_dir, model_file)
+        else:
+            self.model_name_or_path = model_file
+            
         self.length = 50
         self.num_samples = 5
         self.temperature = 1.0
@@ -63,8 +72,9 @@ class SCGPT(NLG):
             'Restaurant':False,
             'Taxi':False,
             'Train':False,}
-        if not self.is_user:
-            self.sess_domains['Booking'] = False
+        self.cur_domain = None
+        # if not self.is_user:
+        #     self.sess_domains['Booking'] = False
                 
     def generate(self, meta):
 
@@ -72,10 +82,23 @@ class SCGPT(NLG):
         if not meta:
             return 'No user action'
 
+        meta = deepcopy(meta)
+        for list_ in meta:
+            domain = list_[1]
+            if domain not in ('general', 'Booking'):
+                self.cur_domain = domain
+        for i, list_ in enumerate(meta):
+            list_ = list(list_)
+            if list_[1] == 'Booking':
+                if self.cur_domain is not None:
+                    list_[1] = self.cur_domain
+                    meta[i] = list_
+                else:
+                    print('`cur_domain` is None, but there is `Booking` in dialog action.')
         raw_text = tuple2seq(meta)
         domains = set([item[1] for item in meta])
         for domain in domains:
-            if domain != 'general' and not self.sess_domains[domain]:
+            if domain not in ('general', 'Booking') and not self.sess_domains[domain]:
                 raw_text = raw_text.replace(domain.lower(), domain.lower()+ ' *', 1)
                 self.sess_domains[domain] = True
         context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False)
@@ -97,4 +120,4 @@ class SCGPT(NLG):
         text = text.split('& ')[-1]
         text = text[: text.find(self.stop_token) if self.stop_token else None]
     
-        return text
\ No newline at end of file
+        return text
diff --git a/convlab2/nlg/scgpt/train.py b/convlab2/nlg/scgpt/train.py
index 775688bbd63e116da42d5f02ecb78930c823a229..0878f31353735ede8b2036ec1f46ef56ce129bed 100644
--- a/convlab2/nlg/scgpt/train.py
+++ b/convlab2/nlg/scgpt/train.py
@@ -9,33 +9,28 @@ import random
 import re
 import shutil
 
-import sys
-
 import numpy as np
 import torch
+from tqdm import tqdm, trange
 from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
 
 try:
     from torch.utils.tensorboard import SummaryWriter
-except:
+except ImportError:
     from tensorboardX import SummaryWriter
 
-from tqdm import tqdm, trange
-
 from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
-                                  BertConfig, BertForMaskedLM, BertTokenizer,
-                                  GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
-                                  OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
-                                  RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
-                                  DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer, BertTokenizer)
-
+                          BertConfig, BertForMaskedLM, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
+                          OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2TokenizerFast,
+                          RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
+                          DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer, BertTokenizer)
+from convlab2.nlg.scgpt.modeling_utils import AmpGPT2LMHeadModel, try_enable_gradient_checkpointing, AmpHelper
 
 logger = logging.getLogger(__name__)
 
-
 MODEL_CLASSES = {
-    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
+    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast),
     'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
     'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
     'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
@@ -43,11 +38,20 @@ MODEL_CLASSES = {
 }
 
 
+def closest_multiple_of_8(n):
+    """
+    Returns:
+        a closest number, which is a multiple of 8 and >= n
+    """
+    return ((n + 7) >> 3) << 3
+
+
 class TextDataset(Dataset):
     def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80):
         assert os.path.isfile(file_path)
         directory, filename = os.path.split(file_path)
-        cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
+        cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(
+            block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
 
         if os.path.exists(cached_features_file) and not args.overwrite_cache:
             logger.info("Loading features from cached file %s", cached_features_file)
@@ -68,12 +72,11 @@ class TextDataset(Dataset):
                         self.examples.append(tokenized_text)
 
             if args.text_chunk:
-                for i in range(0, len(tokenized_text)-block_size+1, block_size): # Truncate in block of block_size
-                    self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i+block_size]))
+                for i in range(0, len(tokenized_text) - block_size + 1, block_size):  # Truncate in block of block_size
+                    self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i + block_size]))
 
-            
             # Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
-            # If your dataset is small, first you should loook for a bigger one :-) and second you
+            # If your dataset is small, first you should look for a bigger one :-) and second you
             # can change this behavior by adding (model specific) padding.
 
             logger.info("Saving features into cached file %s", cached_features_file)
@@ -86,26 +89,30 @@ class TextDataset(Dataset):
     def __getitem__(self, item):
         return torch.tensor(self.examples[item])
 
+
 class TextSeqDataset(Dataset):
-    def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80, seperator=' & '):
+    def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80, separator=' & '):
+        max_seq = closest_multiple_of_8(max_seq)
         assert os.path.isfile(file_path)
         directory, filename = os.path.split(file_path)
-        cached_features_file = os.path.join(directory, args.output_dir.replace(os.sep, '_') + '_cached_lm_' + str(block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
+        cached_features_file = os.path.join(directory, args.output_dir.replace(os.sep, '_') + '_cached_lm_' + str(
+            block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
 
         if os.path.exists(cached_features_file) and not args.overwrite_cache:
             logger.info("Loading features from cached file %s", cached_features_file)
             with open(cached_features_file, 'rb') as handle:
-                self.examples = pickle.load(handle)
+                self.examples, self.masks, self.labels,  self.seq_lengths = pickle.load(handle)
         else:
             logger.info("Creating features from dataset file at %s", directory)
             self.examples = []
             self.labels = []
             self.masks = []
+            self.seq_lengths = []
             with open(file_path, encoding="utf-8") as f:
-                for line in f:
-                    line = line.strip()      
-                    raw_str = line.lower()
-                    code_str = line.lower().split(seperator)[0] + seperator
+                for line in tqdm(f):
+                    line = line.strip()
+                    raw_str = line.lower()  # do we need lowercase?
+                    code_str = line.lower().split(separator)[0] + separator
                     code_str = code_str.strip()
                     if len(raw_str.split()) > max_seq -1:
                         raw_str = ' '.join(raw_str.split()[:max_seq -1])
@@ -118,40 +125,44 @@ class TextSeqDataset(Dataset):
                         code_str_len =  len(tokenizer.convert_tokens_to_ids(code_str.split()))
 
                     label = [-1] *  max_seq
-                    label[:len(tokenized_text)] = tokenized_text 
+                    label[:len(tokenized_text)] = tokenized_text
                     mask = [1] *  max_seq
 
-
                     if len(tokenized_text) < max_seq:
+                        self.seq_lengths.append(len(tokenized_text))
                         mask[-(max_seq - len(tokenized_text)):] = [0] * (max_seq - len(tokenized_text))
                         # label[code_str_len:len(tokenized_text)] = tokenized_text[code_str_len:]
-                        tokenized_text = tokenized_text + [0] * (max_seq - len(tokenized_text)) 
+                        tokenized_text = tokenized_text + [tokenizer.eos_token_id] * (max_seq - len(tokenized_text))
                     else:
+                        self.seq_lengths.append(max_seq)
                         tokenized_text = tokenized_text[:max_seq]
-                        # label[code_str_len:] = tokenized_text[code_str_len:] 
-                    
+                        # label[code_str_len:] = tokenized_text[code_str_len:]
+
                     self.examples.append(tokenized_text)
                     self.masks.append(mask)
                     self.labels.append(label)
 
             # Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
-            # If your dataset is small, first you should loook for a bigger one :-) and second you
+            # If your dataset is small, first you should look for a bigger one :-) and second you
             # can change this behavior by adding (model specific) padding.
             if args.with_code_loss:
                 self.labels = self.examples
             logger.info("Saving features into cached file %s", cached_features_file)
             with open(cached_features_file, 'wb') as handle:
-                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
+                pickle.dump((self.examples, self.masks, self.labels, self.seq_lengths), handle,
+                            protocol=pickle.HIGHEST_PROTOCOL)
 
     def __len__(self):
         return len(self.examples)
 
     def __getitem__(self, item):
-        return torch.tensor(self.examples[item]), torch.tensor(self.masks[item]), torch.tensor(self.labels[item])
+        return torch.tensor(self.examples[item]), torch.tensor(self.masks[item]), torch.tensor(
+            self.labels[item]), torch.tensor(self.seq_lengths[item])
 
 
 def load_and_cache_examples(args, tokenizer, evaluate=False):
-    dataset = TextSeqDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size, max_seq=args.max_seq)
+    dataset = TextSeqDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file,
+                             block_size=args.block_size, max_seq=args.max_seq)
     return dataset
 
 
@@ -197,7 +208,8 @@ def mask_tokens(inputs, tokenizer, args):
     labels = inputs.clone()
     # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
     probability_matrix = torch.full(labels.shape, args.mlm_probability)
-    special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
+    special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in
+                           labels.tolist()]
     probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
     masked_indices = torch.bernoulli(probability_matrix).bool()
     labels[~masked_indices] = -1  # We only compute loss on masked tokens
@@ -215,6 +227,23 @@ def mask_tokens(inputs, tokenizer, args):
     return inputs, labels
 
 
+def preprocess_batch(inputs, masks, labels, seq_lengths):
+    """
+    The real sequence length of a batch may be shorter than max_seq of the whole dataset.
+    Remove some padding tokens to accelerate the training process.
+    And make sure that the sequence length is multiple of 8.
+
+    References:
+        https://huggingface.co/transformers/performance.html#fp16
+    """
+    # The gain for FP16 training is that in each of those cases, the training with the flag --fp16 is twice as fast,
+    # which does require every tensor to have every dimension be a multiple of 8
+    # (examples pad the tensors to a sequence length that is a multiple of 8).
+    max_seq_len = seq_lengths.max()
+    max_seq_len = closest_multiple_of_8(max_seq_len)
+    return inputs[:, :max_seq_len], masks[:, :max_seq_len], labels[:, :max_seq_len]
+
+
 def train(args, train_dataset, model, tokenizer):
     """ Train the model """
     if args.local_rank in [-1, 0]:
@@ -233,27 +262,23 @@ def train(args, train_dataset, model, tokenizer):
     # Prepare optimizer and schedule (linear warmup and decay)
     no_decay = ['bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [
-        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
+        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+         'weight_decay': args.weight_decay},
         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
-        ]
+    ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
-    if args.fp16:
-        try:
-            from apex import amp
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
-    model.resize_token_embeddings(len(tokenizer))
-    # multi-gpu training (should be after apex fp16 initialization)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
+                                                num_training_steps=t_total)
+    # https://pytorch.org/docs/master/notes/amp_examples.html
+    amp_helper = AmpHelper(use_amp=args.fp16)
     if args.n_gpu > 1:
         model = torch.nn.DataParallel(model)
 
-    # Distributed training (should be after apex fp16 initialization)
+    # Distributed training
     if args.local_rank != -1:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                           output_device=args.local_rank,
-                                                          find_unused_parameters=True)
+                                                          find_unused_parameters=False)
 
     # Train!
     logger.info("***** Running training *****")
@@ -261,7 +286,8 @@ def train(args, train_dataset, model, tokenizer):
     logger.info("  Num Epochs = %d", args.num_train_epochs)
     logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
     logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
-                   args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
+                args.train_batch_size * args.gradient_accumulation_steps * (
+                    torch.distributed.get_world_size() if args.local_rank != -1 else 1))
     logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
     logger.info("  Total optimization steps = %d", t_total)
 
@@ -271,12 +297,13 @@ def train(args, train_dataset, model, tokenizer):
     train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
     for e in train_iterator:
-        
+
         # epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(train_dataloader):
             # inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
-            logger.info(f"  PROGRESS: {float(global_step)/t_total*100}%")
-            inputs, masks, labels = batch
+            logger.info(f"  PROGRESS: {float(global_step) / t_total * 100}%")
+            inputs, masks, labels, seq_lengths = batch
+            inputs, masks, labels = preprocess_batch(inputs, masks, labels, seq_lengths)  # cut seq
             # import pdb
             # pdb.set_trace()
             inputs = inputs.to(args.device)
@@ -284,27 +311,29 @@ def train(args, train_dataset, model, tokenizer):
             labels = labels.to(args.device)
 
             model.train()
-            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
-            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
-
-            if args.n_gpu > 1:
-                loss = loss.mean()  # mean() to average on multi-gpu parallel training
-            if args.gradient_accumulation_steps > 1:
-                loss = loss / args.gradient_accumulation_steps
-
-            if args.fp16:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
-            else:
-                loss.backward()
+            try:
+                with amp_helper.might_enable_autocast:
+                    outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
+                    loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
+
+                    if args.n_gpu > 1:
+                        loss = loss.mean()  # mean() to average on multi-gpu parallel training
+                    if args.gradient_accumulation_steps > 1:
+                        loss = loss / args.gradient_accumulation_steps
+
+                amp_helper.backward(loss)
+            except RuntimeError as e:
+                if 'CUDA out of memory' in str(e):
+                    # if out of memory, we must choose smaller batch_size
+                    print(f'inputs.shape = {inputs.shape}, labels.shape = {labels.shape}')
+                raise
 
             tr_loss += loss.item()
             if (step + 1) % args.gradient_accumulation_steps == 0:
-                if args.fp16:
-                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
-                else:
-                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
-                optimizer.step()
+                amp_helper.might_unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+                # optimizer.step()
+                amp_helper.step(optimizer)
                 scheduler.step()  # Update learning rate schedule
                 model.zero_grad()
                 global_step += 1
@@ -317,7 +346,7 @@ def train(args, train_dataset, model, tokenizer):
                             tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                     tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                     tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
-                    logger.info(f"  EVALERR:  {(tr_loss - logging_loss)/float(args.logging_steps)}")
+                    logger.info(f"  EVALERR:  {(tr_loss - logging_loss) / float(args.logging_steps)}")
                     logging_loss = tr_loss
 
                 if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
@@ -326,7 +355,8 @@ def train(args, train_dataset, model, tokenizer):
                     output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
                     if not os.path.exists(output_dir):
                         os.makedirs(output_dir)
-                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+                    model_to_save = model.module if hasattr(model,
+                                                            'module') else model  # Take care of distributed/parallel training
                     model_to_save.save_pretrained(output_dir)
                     tokenizer.save_pretrained(output_dir)
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
@@ -334,12 +364,9 @@ def train(args, train_dataset, model, tokenizer):
 
                     _rotate_checkpoints(args, checkpoint_prefix)
 
-            # if args.max_steps > 0 and global_step > args.max_steps:
-                # epoch_iterator.close()
-                # break
-        if args.max_steps > 0 and global_step > args.max_steps:
-            train_iterator.close()
-            break
+            if global_step > args.max_steps > 0:
+                train_iterator.close()
+                break
 
     if args.local_rank in [-1, 0]:
         tb_writer.close()
@@ -362,7 +389,9 @@ def evaluate(args, model, tokenizer, prefix=""):
     eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
 
     # multi-gpu evaluate
-    if args.n_gpu > 1:
+    if args.n_gpu > 1 and not (isinstance(model, torch.nn.DataParallel) or
+                               isinstance(model, torch.nn.parallel.DistributedDataParallel)):
+        # if args.evaluate_during_training, DataParallel is already used
         model = torch.nn.DataParallel(model)
 
     # Eval!
@@ -376,9 +405,10 @@ def evaluate(args, model, tokenizer, prefix=""):
     for batch in tqdm(eval_dataloader, desc="Evaluating"):
         # inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
 
-        inputs, masks, labels = batch
-            # import pdb
-            # pdb.set_trace()
+        inputs, masks, labels, seq_lengths = batch
+        inputs, masks, labels = preprocess_batch(inputs, masks, labels, seq_lengths)  # cut seq
+        # import pdb
+        # pdb.set_trace()
         inputs = inputs.to(args.device)
         masks = masks.to(args.device)
         labels = labels.to(args.device)
@@ -387,12 +417,12 @@ def evaluate(args, model, tokenizer, prefix=""):
 
         with torch.no_grad():
             outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
-            lm_loss = outputs[0]
-            eval_loss += lm_loss.mean().item()
+            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
+            eval_loss += loss.mean().item()
         nb_eval_steps += 1
 
     eval_loss = eval_loss / nb_eval_steps
-    perplexity = torch.exp(torch.tensor(eval_loss))
+    perplexity = float(np.exp(eval_loss))
 
     result = {
         "perplexity": perplexity
@@ -409,6 +439,7 @@ def evaluate(args, model, tokenizer, prefix=""):
 
 
 def main():
+    global AdamW
     parser = argparse.ArgumentParser()
 
     ## Required parameters
@@ -489,10 +520,7 @@ def main():
                         help="random seed for initialization")
 
     parser.add_argument('--fp16', action='store_true',
-                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
-    parser.add_argument('--fp16_opt_level', type=str, default='O1',
-                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
-                             "See details at https://nvidia.github.io/apex/amp.html")
+                        help="Whether to use 16-bit (mixed) precision (through torch.cuda.amp) instead of 32-bit")
     parser.add_argument("--local_rank", type=int, default=-1,
                         help="For distributed training: local_rank")
     parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
@@ -504,18 +532,32 @@ def main():
 
     parser.add_argument("--max_seq", default=80, type=int,
                         help="")
+    parser.add_argument('--gradient_checkpointing', action='store_true', help='enable gradient checkpointing')
+    parser.add_argument('--use_multi_tensor_adamw', action='store_true',
+                        help='use torch.optim._multi_tensor.AdamW instead of transformers.AdamW')
 
     args = parser.parse_args()
+    if args.use_multi_tensor_adamw:
+        try:
+            # overwrite the previous imported AdamW
+            # https://huggingface.co/transformers/performance.html#faster-optimizer
+            from torch.optim._multi_tensor import AdamW
+        except ImportError as e:
+            print(e)
 
     if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm:
         raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
                          "flag (masked language modeling).")
     if args.eval_data_file is None and args.do_eval:
-        raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
-                         "or remove the --do_eval argument.")
+        raise ValueError(
+            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
+            "or remove the --do_eval argument.")
 
-    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
-        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
+    if os.path.exists(args.output_dir) and os.listdir(
+            args.output_dir) and args.do_train and not args.overwrite_output_dir:
+        raise ValueError(
+            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
+                args.output_dir))
 
     # Setup distant debugging if needed
     if args.server_ip and args.server_port:
@@ -525,6 +567,11 @@ def main():
         ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
         ptvsd.wait_for_attach()
 
+    # Setup logging before `torch.distributed.init_process_group` is called
+    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
+                        datefmt='%m/%d/%Y %H:%M:%S',
+                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+
     # Setup CUDA, GPU & distributed training
     if args.local_rank == -1 or args.no_cuda:
         device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
@@ -535,13 +582,8 @@ def main():
         torch.distributed.init_process_group(backend='nccl')
         args.n_gpu = 1
     args.device = device
-
-    # Setup logging
-    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
-                        datefmt = '%m/%d/%Y %H:%M:%S',
-                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
     logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
-                    args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
+                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
 
     # Set seed
     set_seed(args)
@@ -550,14 +592,16 @@ def main():
     if args.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab
 
+    if args.fp16:
+        MODEL_CLASSES['gpt2'] = (GPT2Config, AmpGPT2LMHeadModel, GPT2TokenizerFast)
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
     config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
                                           cache_dir=args.cache_dir if args.cache_dir else None)
     tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
-    #tokenizer = BertTokenizer(vocab_file='../GPT2-chitchat/vocabulary/vocab_small.txt', eos_token='<T>',
+                                                # tokenizer = BertTokenizer(vocab_file='../GPT2-chitchat/vocabulary/vocab_small.txt', eos_token='<T>',
                                                 do_lower_case=args.do_lower_case,
                                                 cache_dir=args.cache_dir if args.cache_dir else None)
-    
+
     if args.block_size <= 0:
         args.block_size = tokenizer.max_len_single_sentence  # Our input block size will be the max possible for the model
     args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
@@ -565,7 +609,13 @@ def main():
                                         from_tf=bool('.ckpt' in args.model_name_or_path),
                                         config=config,
                                         cache_dir=args.cache_dir if args.cache_dir else None)
+    if model.config.vocab_size != len(tokenizer):
+        logger.info('resize token embeddings, since there may be added tokens.')
+        model.resize_token_embeddings(len(tokenizer))
     model.to(args.device)
+    if args.gradient_checkpointing:
+        # https://huggingface.co/transformers/performance.html#gradient-checkpointing
+        try_enable_gradient_checkpointing(model)
 
     if args.local_rank == 0:
         torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training download model & vocab
@@ -585,7 +635,6 @@ def main():
         global_step, tr_loss = train(args, train_dataset, model, tokenizer)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
 
-
     # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Create output directory if needed
@@ -595,7 +644,8 @@ def main():
         logger.info("Saving model checkpoint to %s", args.output_dir)
         # Save a trained model, configuration and tokenizer using `save_pretrained()`.
         # They can then be reloaded using `from_pretrained()`
-        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+        model_to_save = model.module if hasattr(model,
+                                                'module') else model  # Take care of distributed/parallel training
         model_to_save.save_pretrained(args.output_dir)
         tokenizer.save_pretrained(args.output_dir)
 
@@ -607,25 +657,24 @@ def main():
         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         model.to(args.device)
 
-
     # Evaluation
     results = {}
     if args.do_eval and args.local_rank in [-1, 0]:
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
-            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
+            checkpoints = list(
+                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
             logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
             global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
             prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
-            
+
             model = model_class.from_pretrained(checkpoint)
             model.to(args.device)
             result = evaluate(args, model, tokenizer, prefix=prefix)
             result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
             results.update(result)
-
     return results