From 7e0eac70f85123e798ec6638d6292f6216d3139e Mon Sep 17 00:00:00 2001
From: zz-jacob <zhangz.goal@gmail.com>
Date: Mon, 21 Mar 2022 10:16:59 +0800
Subject: [PATCH] add training and inference code for sc-gpt

---
 convlab2/nlg/scgpt/main.py                 | 239 ++++++++
 convlab2/nlg/scgpt/model.py                |  24 +
 convlab2/nlg/scgpt/scgpt_special_tokens.py |  14 +
 convlab2/nlg/scgpt/train.py                | 682 ---------------------
 convlab2/nlg/scgpt/train.sh                |   1 +
 convlab2/nlg/scgpt/util.py                 |  94 +++
 6 files changed, 372 insertions(+), 682 deletions(-)
 create mode 100644 convlab2/nlg/scgpt/main.py
 create mode 100644 convlab2/nlg/scgpt/model.py
 create mode 100644 convlab2/nlg/scgpt/scgpt_special_tokens.py
 delete mode 100644 convlab2/nlg/scgpt/train.py
 create mode 100644 convlab2/nlg/scgpt/train.sh
 create mode 100644 convlab2/nlg/scgpt/util.py

diff --git a/convlab2/nlg/scgpt/main.py b/convlab2/nlg/scgpt/main.py
new file mode 100644
index 00000000..a15f7612
--- /dev/null
+++ b/convlab2/nlg/scgpt/main.py
@@ -0,0 +1,239 @@
+import sys
+sys.path.append('../../..')
+
+import argparse
+from tqdm import tqdm
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+from torch.utils.tensorboard import SummaryWriter
+import os
+from transformers import get_linear_schedule_with_warmup
+
+from convlab2.util.unified_datasets_util import load_dataset, load_nlg_data
+from convlab2.nlg.scgpt.util import act2str
+from convlab2.nlg.scgpt.model import SCGPTDataset
+
+# Distributed training
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from util import build_mask
+from scgpt_special_tokens import *
+# from plot.attn_plot import plot_attn_encdec
+
+## Quick smoke-test mode: when True, use tiny hyper-parameters for a fast run
+code_test = False
+
+## Multi-GPU / distributed training arguments
+parser = argparse.ArgumentParser()
+parser.add_argument("--local_rank", default=-1, type=int)
+parser.add_argument('--do_train', action="store_true", help="Whether to run training.")
+parser.add_argument('--dataset', default="multiwoz21", type=str, help="The name of the unified dataset to train/test on.")
+FLAGS = parser.parse_args()
+local_rank = FLAGS.local_rank
+
+torch.cuda.set_device(local_rank)
+dist.init_process_group(backend='nccl')
+
+# TensorBoard
+tb_writer = SummaryWriter()
+
+special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
+## Load the GPT-2 tokenizer and model
+tokenizer = GPT2Tokenizer.from_pretrained('./gpt2')
+tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED, 'additional_special_tokens': special_tokens})
+model = GPT2LMHeadModel.from_pretrained('./gpt2').to(local_rank)
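+# Resize the embedding matrix so the newly added special tokens get embedding rows.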
+model.resize_token_embeddings(len(tokenizer))
+
+## Loss computation
+nll_loss = nn.NLLLoss(reduction='none').to(local_rank)
+ce_loss = nn.CrossEntropyLoss(reduction='none').to(local_rank)
+def cal_loss(input, target, seq_lens, seq_lens_input):
+    """Only calculate loss on responses, not on dialog act"""
+    global nll_loss
+    """Input: [batch, length, vocab]; target: [batch, length]; seq_lens: [batch]"""
+    log_probs = F.log_softmax(input, dim=-1).transpose(1, 2)  # 类别维度要放在dim=1的位置,nn.NLLLoss的要求
+    loss = nll_loss(log_probs, target)
+    # loss = ce_loss(input, target)  # equivalent
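+    # Masking: keep the loss on response tokens (and the padded tail after them)
+    # and zero it on the dialog-act prefix, so only the response is learned.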
+    mask = build_mask(torch.max(seq_lens).item()-1, seq_lens-1).to(local_rank)
+    input_mask = build_mask(torch.max(seq_lens).item()-1, seq_lens_input-1).to(local_rank)
+    output_mask = torch.logical_xor(mask, input_mask)
+    pad_mask = torch.logical_not(mask)
+    # masked_loss = loss * output_mask
+    masked_loss = loss * (output_mask + pad_mask)
+    mean_loss = torch.sum(masked_loss) / torch.sum(output_mask + pad_mask)
+    return mean_loss
+
+
+def pad_collate(batch):
+    """
+    Returns:
+    batch: batch * max_len
+    seq_lens: the length of len(da)+1+len(response)
+    seq_lens_input: the length of len(da)
+    """
+    START_OF_PRED_ID = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
+    pad_token_id = tokenizer.pad_token_id
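+    # Each training example is: dialog-act tokens + [start_of_pred] + response tokens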
+    batch = [item[0] + [START_OF_PRED_ID] + item[1] for item in batch]
+    batch = [item[-512:] for item in batch]  # truncate to GPT-2's 512-token input limit
+    max_len = max([len(item) for item in batch])
+    seq_lens = [len(item) for item in batch]
+    split_id = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
+    def get_x_len(tokens):
+        """Get the length of dialogue act tokens"""
+        split_idx = len(tokens)
+        try:
+            split_idx = tokens.index(split_id)+1
+        except:
+            pass
+        return split_idx
+    seq_lens_input = [get_x_len(item) for item in batch]
+    batch = [item + [pad_token_id]*(max_len-len(item)) for item in batch]
+    return torch.LongTensor(batch), torch.LongTensor(seq_lens), torch.LongTensor(seq_lens_input)
+
+## Training Hyper-params
+EPOCH_NUM = 20
+BATCH_SIZE = 10   # real_batch_size = BATCH_SIZE * num_gpu
+VAL_STEP = 30
+WARM_STEPS = 250
+if code_test:
+    EPOCH_NUM = 2
+    BATCH_SIZE = 4
+    VAL_STEP = 2
+    WARM_STEPS = 3
+LR = 5e-5
+TASK_TYPE = 'nlu'  # nlu or dst
+SAVE_PATH = './saved_model'
+def train(model, nlg_data, global_step=0):
+    train_dataset = SCGPTDataset(nlg_data['train'], tokenizer)
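+    # DistributedSampler shards the data across GPUs; sampler.set_epoch() below reshuffles each epoch.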
+    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=2, sampler=train_sampler, collate_fn=pad_collate)
+
+    val_dataset = SCGPTDataset(nlg_data['validation'], tokenizer)
+    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
+    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2, sampler=val_sampler, collate_fn=pad_collate)
+
+    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARM_STEPS,
+                                                num_training_steps=len(train_dataloader) * EPOCH_NUM)
+    model.train()
+    for epoch in range(EPOCH_NUM):
+        train_dataloader.sampler.set_epoch(epoch)
+        for batch_id, (inputs, seq_lens, seq_lens_input) in enumerate(tqdm(train_dataloader, desc=f'EPOCH:[{epoch+1}/{EPOCH_NUM}]')):
+            inputs = inputs.to(local_rank)
+            seq_lens = seq_lens.to(local_rank)
+            seq_lens_input = seq_lens_input.to(local_rank)
+
+            outputs = model(inputs)
+            preds = outputs[0]
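+            # Shift by one for the LM objective: predict token t+1 from tokens up to t.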
+            loss = cal_loss(preds[:, :-1, :], inputs[:, 1:], seq_lens, seq_lens_input)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            scheduler.step()
+            tb_writer.add_scalar(f'Train/loss', loss.item(), global_step)
+            tb_writer.add_scalar(f'Train/PPL', torch.exp(loss).item(), global_step)
+            tb_writer.add_scalar(f'Train/Learning Rate', scheduler.get_last_lr()[0], global_step)
+
+            if batch_id % VAL_STEP == 0:
+                model.eval()
+                val_loss = eval(model, val_dataloader)
+                ppl = np.exp(val_loss)
+                tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
+                tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
+                model.train()
+            global_step += 1
+        # save the model when each epoch ends
+        if dist.get_rank() == 0:
+            save_dir = os.path.join(SAVE_PATH, f'epoch_{epoch}')
+            os.makedirs(save_dir, exist_ok=True)
+            torch.save(model.module.state_dict(), os.path.join(save_dir, f'epoch_{epoch}_step{global_step}.pt'))
+            tokenizer.save_pretrained(save_dir)
+            torch.save(optimizer.state_dict(), os.path.join(save_dir, 'optimizer.pt'))
+            torch.save(scheduler.state_dict(), os.path.join(save_dir, 'scheduler.pt'))
+            print(f'Save model checkpoint to [{save_dir}]')
+    tb_writer.flush()
+
+
+def eval(model, loader, use_tqdm=False):
+    with torch.no_grad():
+        loss_list = []
+        data_iter = tqdm(loader, desc='Val') if use_tqdm else loader
+        for inputs, seq_lens, seq_lens_input in data_iter:
+            inputs = inputs.to(local_rank)
+            seq_lens = seq_lens.to(local_rank)
+            seq_lens_input = seq_lens_input.to(local_rank)
+            outputs = model(inputs)
+            preds = outputs[0]
+            loss = cal_loss(preds[:, :-1, :], inputs[:, 1:], seq_lens, seq_lens_input)
+            loss_list.append(loss.item())
+        mean_loss = np.mean(loss_list)
+    return mean_loss
+
+
+def inference_batch(model, sents):
+    """Inference model given a batch of sents."""
+    with torch.no_grad():
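+        # Prime the model with "<dialog act string> [start_of_pred]" so generation continues with the response.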
+        sents = [sent + ' ' + START_OF_PRED for sent in sents]
+        sent_ids = [tokenizer.encode(sent) for sent in sents]
+        max_len = max([len(sent) for sent in sent_ids])
+        sent_ids = [sent + [tokenizer.pad_token_id]*(max_len-len(sent)) for sent in sent_ids]
+        inputs = torch.LongTensor(sent_ids).to(local_rank)
+        model_to_run = model.module if type(model) is DDP else model
+        outputs = model_to_run.generate(inputs, max_length=513, eos_token_id=tokenizer.eos_token_id,
+                                        pad_token_id=tokenizer.pad_token_id)  # greedy
+        # outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
+        #                                 pad_token_id=gpt2_tokenizer.pad_token_id)  # beam search
+        output_strs = [tokenizer.decode(item) for item in outputs]
+        return output_strs
+
+
+def inference_sent(model, sent):
+    """Inference model given one single sentence."""
+    return inference_batch(model, [sent])[0]
+
+
+def inference_sents(model, sents):
+    """Get the outputs of multiple sentences."""
+    outputs = []
+    for sent in tqdm(sents, desc='Inference Sentences'):
+        output = inference_sent(model, sent)
+        outputs.append(output)
+    return outputs
+
+
+def test(model, nlg_data, model_path):
+    """将sheel中的GPU个数设为1运行"""
+    model.load_state_dict(torch.load(model_path))
+    print(f'model loaded from [{model_path}]')
+    # sample_file = os.path.join(f'../../../data/dstc2/sample50_{TASK_TYPE}_input_data.txt')
+    # Load test nlg data
+    test_data = nlg_data['test']
+    dialog_acts = [act2str(item['dialogue_acts']) for item in test_data]
+    golden_responses = [item['utterance'] for item in test_data]
+    outputs = inference_sents(model, dialog_acts)
+    if dist.get_rank() == 0:
+        output_file = './test_output.txt'
+        with open(output_file, 'w+') as f:
+            for i in range(len(dialog_acts)):
+                f.write(f'{dialog_acts[i]}\n{golden_responses[i]}\n{outputs[i]}\n\n')
+            f.close()
+
+
+if __name__ == '__main__':
+    dataset = load_dataset(FLAGS.dataset)
+    nlg_data = load_nlg_data(dataset)
+    if FLAGS.do_train:
+        train(model, nlg_data)
+    else:
+        test(model, nlg_data, f'saved_model/{TASK_TYPE}/19_save/19_step5840.pt')
+        # test_samples(f'saved_model/{TASK_TYPE}/19_save/19_step5840.pt')
+    # elif FLAGS.show_attn:
+    #     show_attention(f'saved_model/{TASK_TYPE}/19_save/19_step5840.pt')
\ No newline at end of file
diff --git a/convlab2/nlg/scgpt/model.py b/convlab2/nlg/scgpt/model.py
new file mode 100644
index 00000000..82df1464
--- /dev/null
+++ b/convlab2/nlg/scgpt/model.py
@@ -0,0 +1,24 @@
+from torch.utils.data import Dataset
+from util import act2str
+from scgpt_special_tokens import *
+import torch
+
+class SCGPTDataset(Dataset):
+    def __init__(self, data, tokenizer):
+        """
+        Args:
+            data: list of dicts from the unified NLG data, each with 'dialogue_acts' and 'utterance' keys
+            tokenizer: GPT2 tokenizer
+        """
+        self.data = []
+        for item in data:
+            da, response = item['dialogue_acts'], item['utterance']
+            da_tokens = tokenizer.encode(act2str(da))
+            response_tokens = tokenizer.encode(response)
+            self.data.append([da_tokens, response_tokens])
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
\ No newline at end of file
diff --git a/convlab2/nlg/scgpt/scgpt_special_tokens.py b/convlab2/nlg/scgpt/scgpt_special_tokens.py
new file mode 100644
index 00000000..643820dd
--- /dev/null
+++ b/convlab2/nlg/scgpt/scgpt_special_tokens.py
@@ -0,0 +1,14 @@
+# separator
+SYS_SPEAK = '[sys_speak]'
+USR_SPEAK = '[usr_speak]'
+START_OF_PRED = '[start_of_pred]'
+END_OF_PRED = '[end_of_pred]'
+PAD_TOKEN = '[_pad_token_]'
+START_OF_INTENT = '[start_of_intent]'
+END_OF_INTENT = '[end_of_intent]'
+START_OF_SLOT = ''
+
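+# Collect every non-empty bracketed special token defined above.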
+SPECIAL_TOKENS = [val for name, val in globals().items() if
+                  str.isupper(name) and isinstance(val, str) and val and val[0] == '[' and val[-1] == ']']
+
+assert all(token.islower() for token in SPECIAL_TOKENS)
\ No newline at end of file
diff --git a/convlab2/nlg/scgpt/train.py b/convlab2/nlg/scgpt/train.py
deleted file mode 100644
index 0878f313..00000000
--- a/convlab2/nlg/scgpt/train.py
+++ /dev/null
@@ -1,682 +0,0 @@
-from __future__ import absolute_import, division, print_function
-
-import argparse
-import glob
-import logging
-import os
-import pickle
-import random
-import re
-import shutil
-
-import numpy as np
-import torch
-from tqdm import tqdm, trange
-from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
-from torch.utils.data.distributed import DistributedSampler
-
-try:
-    from torch.utils.tensorboard import SummaryWriter
-except ImportError:
-    from tensorboardX import SummaryWriter
-
-from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
-                          BertConfig, BertForMaskedLM, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
-                          OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, GPT2TokenizerFast,
-                          RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
-                          DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer, BertTokenizer)
-from convlab2.nlg.scgpt.modeling_utils import AmpGPT2LMHeadModel, try_enable_gradient_checkpointing, AmpHelper
-
-logger = logging.getLogger(__name__)
-
-MODEL_CLASSES = {
-    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast),
-    'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
-    'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
-    'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
-    'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
-}
-
-
-def closest_multiple_of_8(n):
-    """
-    Returns:
-        a closest number, which is a multiple of 8 and >= n
-    """
-    return ((n + 7) >> 3) << 3
-
-
-class TextDataset(Dataset):
-    def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80):
-        assert os.path.isfile(file_path)
-        directory, filename = os.path.split(file_path)
-        cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(
-            block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
-
-        if os.path.exists(cached_features_file) and not args.overwrite_cache:
-            logger.info("Loading features from cached file %s", cached_features_file)
-            with open(cached_features_file, 'rb') as handle:
-                self.examples = pickle.load(handle)
-        else:
-            logger.info("Creating features from dataset file at %s", directory)
-
-            self.examples = []
-
-            with open(file_path, encoding="utf-8") as f:
-                if args.text_chunk:
-                    text = f.read()
-                    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
-                else:
-                    for line in f:
-                        tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line.strip() + ' eos'))
-                        self.examples.append(tokenized_text)
-
-            if args.text_chunk:
-                for i in range(0, len(tokenized_text) - block_size + 1, block_size):  # Truncate in block of block_size
-                    self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i + block_size]))
-
-            # Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
-            # If your dataset is small, first you should look for a bigger one :-) and second you
-            # can change this behavior by adding (model specific) padding.
-
-            logger.info("Saving features into cached file %s", cached_features_file)
-            with open(cached_features_file, 'wb') as handle:
-                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
-
-    def __len__(self):
-        return len(self.examples)
-
-    def __getitem__(self, item):
-        return torch.tensor(self.examples[item])
-
-
-class TextSeqDataset(Dataset):
-    def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80, separator=' & '):
-        max_seq = closest_multiple_of_8(max_seq)
-        assert os.path.isfile(file_path)
-        directory, filename = os.path.split(file_path)
-        cached_features_file = os.path.join(directory, args.output_dir.replace(os.sep, '_') + '_cached_lm_' + str(
-            block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
-
-        if os.path.exists(cached_features_file) and not args.overwrite_cache:
-            logger.info("Loading features from cached file %s", cached_features_file)
-            with open(cached_features_file, 'rb') as handle:
-                self.examples, self.masks, self.labels,  self.seq_lengths = pickle.load(handle)
-        else:
-            logger.info("Creating features from dataset file at %s", directory)
-            self.examples = []
-            self.labels = []
-            self.masks = []
-            self.seq_lengths = []
-            with open(file_path, encoding="utf-8") as f:
-                for line in tqdm(f):
-                    line = line.strip()
-                    raw_str = line.lower()  # do we need lowercase?
-                    code_str = line.lower().split(separator)[0] + separator
-                    code_str = code_str.strip()
-                    if len(raw_str.split()) > max_seq -1:
-                        raw_str = ' '.join(raw_str.split()[:max_seq -1])
-                    raw_str += ' ' + tokenizer.eos_token
-                    if args.use_tokenize:
-                        tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_str))
-                        code_str_len =  len(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(code_str)))
-                    else:
-                        tokenized_text = tokenizer.convert_tokens_to_ids(raw_str.split())
-                        code_str_len =  len(tokenizer.convert_tokens_to_ids(code_str.split()))
-
-                    label = [-1] *  max_seq
-                    label[:len(tokenized_text)] = tokenized_text
-                    mask = [1] *  max_seq
-
-                    if len(tokenized_text) < max_seq:
-                        self.seq_lengths.append(len(tokenized_text))
-                        mask[-(max_seq - len(tokenized_text)):] = [0] * (max_seq - len(tokenized_text))
-                        # label[code_str_len:len(tokenized_text)] = tokenized_text[code_str_len:]
-                        tokenized_text = tokenized_text + [tokenizer.eos_token_id] * (max_seq - len(tokenized_text))
-                    else:
-                        self.seq_lengths.append(max_seq)
-                        tokenized_text = tokenized_text[:max_seq]
-                        # label[code_str_len:] = tokenized_text[code_str_len:]
-
-                    self.examples.append(tokenized_text)
-                    self.masks.append(mask)
-                    self.labels.append(label)
-
-            # Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
-            # If your dataset is small, first you should look for a bigger one :-) and second you
-            # can change this behavior by adding (model specific) padding.
-            if args.with_code_loss:
-                self.labels = self.examples
-            logger.info("Saving features into cached file %s", cached_features_file)
-            with open(cached_features_file, 'wb') as handle:
-                pickle.dump((self.examples, self.masks, self.labels, self.seq_lengths), handle,
-                            protocol=pickle.HIGHEST_PROTOCOL)
-
-    def __len__(self):
-        return len(self.examples)
-
-    def __getitem__(self, item):
-        return torch.tensor(self.examples[item]), torch.tensor(self.masks[item]), torch.tensor(
-            self.labels[item]), torch.tensor(self.seq_lengths[item])
-
-
-def load_and_cache_examples(args, tokenizer, evaluate=False):
-    dataset = TextSeqDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file,
-                             block_size=args.block_size, max_seq=args.max_seq)
-    return dataset
-
-
-def set_seed(args):
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-
-
-def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
-    if not args.save_total_limit:
-        return
-    if args.save_total_limit <= 0:
-        return
-
-    # Check if we should delete older checkpoint(s)
-    glob_checkpoints = glob.glob(os.path.join(args.output_dir, '{}-*'.format(checkpoint_prefix)))
-    if len(glob_checkpoints) <= args.save_total_limit:
-        return
-
-    ordering_and_checkpoint_path = []
-    for path in glob_checkpoints:
-        if use_mtime:
-            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
-        else:
-            regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path)
-            if regex_match and regex_match.groups():
-                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
-
-    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
-    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
-    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
-    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
-    for checkpoint in checkpoints_to_be_deleted:
-        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
-        shutil.rmtree(checkpoint)
-
-
-def mask_tokens(inputs, tokenizer, args):
-    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
-    labels = inputs.clone()
-    # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
-    probability_matrix = torch.full(labels.shape, args.mlm_probability)
-    special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in
-                           labels.tolist()]
-    probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
-    masked_indices = torch.bernoulli(probability_matrix).bool()
-    labels[~masked_indices] = -1  # We only compute loss on masked tokens
-
-    # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
-    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
-    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
-
-    # 10% of the time, we replace masked input tokens with random word
-    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
-    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
-    inputs[indices_random] = random_words[indices_random]
-
-    # The rest of the time (10% of the time) we keep the masked input tokens unchanged
-    return inputs, labels
-
-
-def preprocess_batch(inputs, masks, labels, seq_lengths):
-    """
-    The real sequence length of a batch may be shorter than max_seq of the whole dataset.
-    Remove some padding tokens to accelerate the training process.
-    And make sure that the sequence length is multiple of 8.
-
-    References:
-        https://huggingface.co/transformers/performance.html#fp16
-    """
-    # The gain for FP16 training is that in each of those cases, the training with the flag --fp16 is twice as fast,
-    # which does require every tensor to have every dimension be a multiple of 8
-    # (examples pad the tensors to a sequence length that is a multiple of 8).
-    max_seq_len = seq_lengths.max()
-    max_seq_len = closest_multiple_of_8(max_seq_len)
-    return inputs[:, :max_seq_len], masks[:, :max_seq_len], labels[:, :max_seq_len]
-
-
-def train(args, train_dataset, model, tokenizer):
-    """ Train the model """
-    if args.local_rank in [-1, 0]:
-        tb_writer = SummaryWriter()
-
-    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
-    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
-    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
-
-    if args.max_steps > 0:
-        t_total = args.max_steps
-        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
-    else:
-        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
-
-    # Prepare optimizer and schedule (linear warmup and decay)
-    no_decay = ['bias', 'LayerNorm.weight']
-    optimizer_grouped_parameters = [
-        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-         'weight_decay': args.weight_decay},
-        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
-    ]
-    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
-                                                num_training_steps=t_total)
-    # https://pytorch.org/docs/master/notes/amp_examples.html
-    amp_helper = AmpHelper(use_amp=args.fp16)
-    if args.n_gpu > 1:
-        model = torch.nn.DataParallel(model)
-
-    # Distributed training
-    if args.local_rank != -1:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
-                                                          output_device=args.local_rank,
-                                                          find_unused_parameters=False)
-
-    # Train!
-    logger.info("***** Running training *****")
-    logger.info("  Num examples = %d", len(train_dataset))
-    logger.info("  Num Epochs = %d", args.num_train_epochs)
-    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
-    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
-                args.train_batch_size * args.gradient_accumulation_steps * (
-                    torch.distributed.get_world_size() if args.local_rank != -1 else 1))
-    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
-    logger.info("  Total optimization steps = %d", t_total)
-
-    global_step = 0
-    tr_loss, logging_loss = 0.0, 0.0
-    model.zero_grad()
-    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
-    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
-    for e in train_iterator:
-
-        # epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
-        for step, batch in enumerate(train_dataloader):
-            # inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
-            logger.info(f"  PROGRESS: {float(global_step) / t_total * 100}%")
-            inputs, masks, labels, seq_lengths = batch
-            inputs, masks, labels = preprocess_batch(inputs, masks, labels, seq_lengths)  # cut seq
-            # import pdb
-            # pdb.set_trace()
-            inputs = inputs.to(args.device)
-            # masks = masks.to(args.device)
-            labels = labels.to(args.device)
-
-            model.train()
-            try:
-                with amp_helper.might_enable_autocast:
-                    outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
-                    loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
-
-                    if args.n_gpu > 1:
-                        loss = loss.mean()  # mean() to average on multi-gpu parallel training
-                    if args.gradient_accumulation_steps > 1:
-                        loss = loss / args.gradient_accumulation_steps
-
-                amp_helper.backward(loss)
-            except RuntimeError as e:
-                if 'CUDA out of memory' in str(e):
-                    # if out of memory, we must choose smaller batch_size
-                    print(f'inputs.shape = {inputs.shape}, labels.shape = {labels.shape}')
-                raise
-
-            tr_loss += loss.item()
-            if (step + 1) % args.gradient_accumulation_steps == 0:
-                amp_helper.might_unscale_(optimizer)
-                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
-                # optimizer.step()
-                amp_helper.step(optimizer)
-                scheduler.step()  # Update learning rate schedule
-                model.zero_grad()
-                global_step += 1
-
-                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
-                    # Log metrics
-                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
-                        results = evaluate(args, model, tokenizer)
-                        for key, value in results.items():
-                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
-                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
-                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
-                    logger.info(f"  EVALERR:  {(tr_loss - logging_loss) / float(args.logging_steps)}")
-                    logging_loss = tr_loss
-
-                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
-                    checkpoint_prefix = 'checkpoint'
-                    # Save model checkpoint
-                    output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
-                    if not os.path.exists(output_dir):
-                        os.makedirs(output_dir)
-                    model_to_save = model.module if hasattr(model,
-                                                            'module') else model  # Take care of distributed/parallel training
-                    model_to_save.save_pretrained(output_dir)
-                    tokenizer.save_pretrained(output_dir)
-                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
-                    logger.info("Saving model checkpoint to %s", output_dir)
-
-                    _rotate_checkpoints(args, checkpoint_prefix)
-
-            if global_step > args.max_steps > 0:
-                train_iterator.close()
-                break
-
-    if args.local_rank in [-1, 0]:
-        tb_writer.close()
-
-    return global_step, tr_loss / global_step
-
-
-def evaluate(args, model, tokenizer, prefix=""):
-    # Loop to handle MNLI double evaluation (matched, mis-matched)
-    eval_output_dir = args.output_dir
-
-    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
-
-    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
-        os.makedirs(eval_output_dir)
-
-    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
-    # Note that DistributedSampler samples randomly
-    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
-    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
-
-    # multi-gpu evaluate
-    if args.n_gpu > 1 and not (isinstance(model, torch.nn.DataParallel) or
-                               isinstance(model, torch.nn.parallel.DistributedDataParallel)):
-        # if args.evaluate_during_training, DataParallel is already used
-        model = torch.nn.DataParallel(model)
-
-    # Eval!
-    logger.info("***** Running evaluation {} *****".format(prefix))
-    logger.info("  Num examples = %d", len(eval_dataset))
-    logger.info("  Batch size = %d", args.eval_batch_size)
-    eval_loss = 0.0
-    nb_eval_steps = 0
-    model.eval()
-
-    for batch in tqdm(eval_dataloader, desc="Evaluating"):
-        # inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
-
-        inputs, masks, labels, seq_lengths = batch
-        inputs, masks, labels = preprocess_batch(inputs, masks, labels, seq_lengths)  # cut seq
-        # import pdb
-        # pdb.set_trace()
-        inputs = inputs.to(args.device)
-        masks = masks.to(args.device)
-        labels = labels.to(args.device)
-        # inputs = inputs.to(args.device)
-        # labels = labels.to(args.device)
-
-        with torch.no_grad():
-            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
-            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
-            eval_loss += loss.mean().item()
-        nb_eval_steps += 1
-
-    eval_loss = eval_loss / nb_eval_steps
-    perplexity = float(np.exp(eval_loss))
-
-    result = {
-        "perplexity": perplexity
-    }
-
-    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
-    with open(output_eval_file, "w") as writer:
-        logger.info("***** Eval results {} *****".format(prefix))
-        for key in sorted(result.keys()):
-            logger.info("  %s = %s", key, str(result[key]))
-            writer.write("%s = %s\n" % (key, str(result[key])))
-
-    return result
-
-
-def main():
-    global AdamW
-    parser = argparse.ArgumentParser()
-
-    ## Required parameters
-    parser.add_argument("--train_data_file", default=None, type=str, required=True,
-                        help="The input training data file (a text file).")
-    parser.add_argument("--output_dir", default=None, type=str, required=True,
-                        help="The output directory where the model predictions and checkpoints will be written.")
-
-    ## Other parameters
-    parser.add_argument("--eval_data_file", default=None, type=str,
-                        help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
-
-    parser.add_argument("--model_type", default="gpt2", type=str,
-                        help="The model architecture to be fine-tuned.")
-    parser.add_argument("--model_name_or_path", default="gpt2", type=str,
-                        help="The model checkpoint for weights initialization.")
-
-    parser.add_argument("--mlm", action='store_true',
-                        help="Train with masked-language modeling loss instead of language modeling.")
-    parser.add_argument("--mlm_probability", type=float, default=0.15,
-                        help="Ratio of tokens to mask for masked language modeling loss")
-
-    parser.add_argument("--config_name", default="", type=str,
-                        help="Optional pretrained config name or path if not the same as model_name_or_path")
-    parser.add_argument("--tokenizer_name", default="", type=str,
-                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
-    parser.add_argument("--cache_dir", default="", type=str,
-                        help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
-    parser.add_argument("--block_size", default=80, type=int,
-                        help="Optional input sequence length after tokenization."
-                             "The training dataset will be truncated in block of this size for training."
-                             "Default to the model max input length for single sentence inputs (take into account special tokens).")
-    parser.add_argument("--do_train", action='store_true',
-                        help="Whether to run training.")
-    parser.add_argument("--do_eval", action='store_true',
-                        help="Whether to run eval on the dev set.")
-    parser.add_argument("--evaluate_during_training", action='store_true',
-                        help="Run evaluation during training at each logging step.")
-    parser.add_argument("--do_lower_case", action='store_true',
-                        help="Set this flag if you are using an uncased model.")
-
-    parser.add_argument("--per_gpu_train_batch_size", default=1, type=int,
-                        help="Batch size per GPU/CPU for training.")
-    parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int,
-                        help="Batch size per GPU/CPU for evaluation.")
-    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
-                        help="Number of updates steps to accumulate before performing a backward/update pass.")
-    parser.add_argument("--learning_rate", default=1e-5, type=float,
-                        help="The initial learning rate for Adam.")
-    parser.add_argument("--weight_decay", default=0.0, type=float,
-                        help="Weight deay if we apply some.")
-    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
-                        help="Epsilon for Adam optimizer.")
-    parser.add_argument("--max_grad_norm", default=1.0, type=float,
-                        help="Max gradient norm.")
-    parser.add_argument("--num_train_epochs", default=5.0, type=float,
-                        help="Total number of training epochs to perform.")
-    parser.add_argument("--max_steps", default=-1, type=int,
-                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
-    parser.add_argument("--warmup_steps", default=0, type=int,
-                        help="Linear warmup over warmup_steps.")
-
-    parser.add_argument('--logging_steps', type=int, default=100,
-                        help="Log every X updates steps.")
-    parser.add_argument('--save_steps', type=int, default=5000,
-                        help="Save checkpoint every X updates steps.")
-    parser.add_argument('--save_total_limit', type=int, default=None,
-                        help='Limit the total amount of checkpoints, delete the older checkpoints in the output_dir, does not delete by default')
-    parser.add_argument("--eval_all_checkpoints", action='store_true',
-                        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
-    parser.add_argument("--no_cuda", action='store_true',
-                        help="Avoid using CUDA when available")
-    parser.add_argument('--overwrite_output_dir', action='store_true',
-                        help="Overwrite the content of the output directory")
-    parser.add_argument('--overwrite_cache', action='store_true',
-                        help="Overwrite the cached training and evaluation sets")
-    parser.add_argument('--seed', type=int, default=42,
-                        help="random seed for initialization")
-
-    parser.add_argument('--fp16', action='store_true',
-                        help="Whether to use 16-bit (mixed) precision (through torch.cuda.amp) instead of 32-bit")
-    parser.add_argument("--local_rank", type=int, default=-1,
-                        help="For distributed training: local_rank")
-    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
-    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
-    parser.add_argument('--text_chunk', action='store_true', help="")
-    parser.add_argument('--use_reverse', action='store_true', help="")
-    parser.add_argument('--with_code_loss', type=bool, default=True, help="")
-    parser.add_argument('--use_tokenize', action='store_true', help="")
-
-    parser.add_argument("--max_seq", default=80, type=int,
-                        help="")
-    parser.add_argument('--gradient_checkpointing', action='store_true', help='enable gradient checkpointing')
-    parser.add_argument('--use_multi_tensor_adamw', action='store_true',
-                        help='use torch.optim._multi_tensor.AdamW instead of transformers.AdamW')
-
-    args = parser.parse_args()
-    if args.use_multi_tensor_adamw:
-        try:
-            # overwrite the previous imported AdamW
-            # https://huggingface.co/transformers/performance.html#faster-optimizer
-            from torch.optim._multi_tensor import AdamW
-        except ImportError as e:
-            print(e)
-
-    if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm:
-        raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
-                         "flag (masked language modeling).")
-    if args.eval_data_file is None and args.do_eval:
-        raise ValueError(
-            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
-            "or remove the --do_eval argument.")
-
-    if os.path.exists(args.output_dir) and os.listdir(
-            args.output_dir) and args.do_train and not args.overwrite_output_dir:
-        raise ValueError(
-            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
-                args.output_dir))
-
-    # Setup distant debugging if needed
-    if args.server_ip and args.server_port:
-        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
-        import ptvsd
-        print("Waiting for debugger attach")
-        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
-        ptvsd.wait_for_attach()
-
-    # Setup logging before `torch.distributed.init_process_group` is called
-    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
-                        datefmt='%m/%d/%Y %H:%M:%S',
-                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
-
-    # Setup CUDA, GPU & distributed training
-    if args.local_rank == -1 or args.no_cuda:
-        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
-        args.n_gpu = torch.cuda.device_count()
-    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
-        torch.cuda.set_device(args.local_rank)
-        device = torch.device("cuda", args.local_rank)
-        torch.distributed.init_process_group(backend='nccl')
-        args.n_gpu = 1
-    args.device = device
-    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
-                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
-
-    # Set seed
-    set_seed(args)
-
-    # Load pretrained model and tokenizer
-    if args.local_rank not in [-1, 0]:
-        torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab
-
-    if args.fp16:
-        MODEL_CLASSES['gpt2'] = (GPT2Config, AmpGPT2LMHeadModel, GPT2TokenizerFast)
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
-                                          cache_dir=args.cache_dir if args.cache_dir else None)
-    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
-                                                # tokenizer = BertTokenizer(vocab_file='../GPT2-chitchat/vocabulary/vocab_small.txt', eos_token='<T>',
-                                                do_lower_case=args.do_lower_case,
-                                                cache_dir=args.cache_dir if args.cache_dir else None)
-
-    if args.block_size <= 0:
-        args.block_size = tokenizer.max_len_single_sentence  # Our input block size will be the max possible for the model
-    args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
-    model = model_class.from_pretrained(args.model_name_or_path,
-                                        from_tf=bool('.ckpt' in args.model_name_or_path),
-                                        config=config,
-                                        cache_dir=args.cache_dir if args.cache_dir else None)
-    if model.config.vocab_size != len(tokenizer):
-        logger.info('resize token embeddings, since there may be added tokens.')
-        model.resize_token_embeddings(len(tokenizer))
-    model.to(args.device)
-    if args.gradient_checkpointing:
-        # https://huggingface.co/transformers/performance.html#gradient-checkpointing
-        try_enable_gradient_checkpointing(model)
-
-    if args.local_rank == 0:
-        torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training download model & vocab
-
-    logger.info("Training/evaluation parameters %s", args)
-
-    # Training
-    if args.do_train:
-        if args.local_rank not in [-1, 0]:
-            torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training process the dataset, and the others will use the cache
-
-        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
-
-        if args.local_rank == 0:
-            torch.distributed.barrier()
-
-        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
-        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
-
-    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
-    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
-        # Create output directory if needed
-        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
-            os.makedirs(args.output_dir)
-
-        logger.info("Saving model checkpoint to %s", args.output_dir)
-        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
-        # They can then be reloaded using `from_pretrained()`
-        model_to_save = model.module if hasattr(model,
-                                                'module') else model  # Take care of distributed/parallel training
-        model_to_save.save_pretrained(args.output_dir)
-        tokenizer.save_pretrained(args.output_dir)
-
-        # Good practice: save your training arguments together with the trained model
-        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
-
-        # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
-        model.to(args.device)
-
-    # Evaluation
-    results = {}
-    if args.do_eval and args.local_rank in [-1, 0]:
-        checkpoints = [args.output_dir]
-        if args.eval_all_checkpoints:
-            checkpoints = list(
-                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
-            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
-        logger.info("Evaluate the following checkpoints: %s", checkpoints)
-        for checkpoint in checkpoints:
-            global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
-            prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
-
-            model = model_class.from_pretrained(checkpoint)
-            model.to(args.device)
-            result = evaluate(args, model, tokenizer, prefix=prefix)
-            result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
-            results.update(result)
-    return results
-
-
-if __name__ == "__main__":
-    main()
diff --git a/convlab2/nlg/scgpt/train.sh b/convlab2/nlg/scgpt/train.sh
new file mode 100644
index 00000000..5a869e3d
--- /dev/null
+++ b/convlab2/nlg/scgpt/train.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES="0,1,2,3" python -m torch.distributed.launch --nproc_per_node 4 main.py --do_train --dataset multiwoz21
\ No newline at end of file
diff --git a/convlab2/nlg/scgpt/util.py b/convlab2/nlg/scgpt/util.py
new file mode 100644
index 00000000..b0efa589
--- /dev/null
+++ b/convlab2/nlg/scgpt/util.py
@@ -0,0 +1,94 @@
+import torch
+
+def act2str(act):
+    """Convert unified dataset dialog act dict to string.
+    act:
+        {'categorical': [{'intent': 'inform', 'domain': 'restaurant', 'slot': 'area', 'value': 'north'}],
+        'non-categorical': [{'intent': 'inform', 'domain': 'hotel', 'slot': 'area', 'value': 'north'}],
+        'binary': [{'intent': 'request', 'domain': 'hotel', 'slot': 'area'}]}
+    return:
+        restaurant { inform ( area = north ) } | hotel { inform ( area = north ) @ request ( area ) }
+    """
+    old_format_dict = convert2old_format(act)
+    return dict2seq(old_format_dict)
+
+
+def build_mask(max_len, seq_lens, use_float=False):
+    """
+    Build boolean sequence masks from sequence lengths.
+    e.g., input: max_len=4, seq_lens=[2,3], return: [[1,1,0,0], [1,1,1,0]]
+    Args:
+        max_len (int): maximum sequence length
+        seq_lens (torch.Tensor): (batch,) sequence lengths
+        use_float (bool): if True, return a float mask instead of a bool mask
+    Returns:
+        mask (torch.Tensor): (batch, max_len)
+    """
+    a = torch.arange(max_len)[None, :]
+    b = seq_lens[:, None].cpu()
+    mask = a < b
+    if use_float:
+        mask = mask.float()
+    return mask
+
+
+def convert2old_format(act):
+    """
+    Convert a unified-format dialog act dict
+    ({'categorical': [...], 'non-categorical': [...], 'binary': [...]})
+    into the old nested format {domain: {intent: [[slot, value], ...]}}.
+    """
+    new_act = {}
+    for key in act:
+        for item_dic in act[key]:
+            domain = item_dic['domain']
+            if domain not in new_act:
+                new_act[domain] = {}
+            intent = item_dic['intent']
+            if intent not in new_act[domain]:
+                new_act[domain][intent] = []
+            slot = item_dic['slot']
+            if 'value' in item_dic:
+                value = item_dic['value']
+            else:
+                value = None
+            new_act[domain][intent].append([slot, value])
+    return new_act
+
+
+def dict2seq(d):
+    '''
+    Serialize an old-format act dict {domain: {intent: [[slot, value], ...]}}
+    into the SC-GPT input string: domain { intent ( slot = value ; ... ) @ ... } | ...
+    '''
+    s = ''
+    first_domain = True
+    first_intent = True
+    first_slot = True
+    for domain in d:
+        if not first_domain:
+            s += ' | '
+        s += domain
+        s += ' { '
+        first_domain = False
+        first_intent = True
+        for intent in d[domain]:
+            if not first_intent:
+                s += ' @ '
+            s += intent
+            s += ' ( '
+            first_intent = False
+            first_slot = True
+            for slot, value in d[domain][intent]:
+                if not first_slot:
+                    s += ' ; '
+                s += slot
+                if value:
+                    s += ' = '
+                    s += value
+                first_slot = False
+            s += ' )'
+        s += ' }'
+    return s.lower()
+
+
+if __name__ == '__main__':
+    ipt = {'categorical': [{'intent': 'inform', 'domain': 'restaurant', 'slot': 'area', 'value': 'north'}], 'non-categorical': [{'intent': 'inform', 'domain': 'hotel', 'slot': 'area', 'value': 'north'}], 'binary': [{'intent': 'request', 'domain': 'hotel', 'slot': 'area'}]}
+    print(act2str(ipt))
\ No newline at end of file
-- 
GitLab