diff --git a/convlab/nlg/evaluate_unified_datasets.py b/convlab/nlg/evaluate_unified_datasets.py
index 7a19a49267f1c526bbca29201893a41f948afc67..23c937ed732ae547b0ca48dd81e6cce6f71c3c62 100644
--- a/convlab/nlg/evaluate_unified_datasets.py
+++ b/convlab/nlg/evaluate_unified_datasets.py
@@ -24,7 +24,7 @@ class Logging:
             f.write('\n')
             f.close()
 
-def evaluate(predict_result, ontology):
+def evaluate(predict_result, ontology, filter_empty_acts=True):
     predict_result = json.load(open(predict_result))
     metrics = {}
 
@@ -33,8 +33,16 @@ def evaluate(predict_result, ontology):
     references = []
     candidates = []
     for i in range(len(predict_result)):
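+        # skip turns without any dialogue act, mirroring the filtering applied to SCGPT training data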
+        if filter_empty_acts:
+            acts = predict_result[i]['dialogue_acts']
+            acts_size = len(acts['binary']) + len(acts['categorical']) + len(acts['non-categorical'])
+            if acts_size == 0:
+                continue
         references.append(predict_result[i]['utterance'])
-        candidates.append(predict_result[i]['predictions']['utterance'])
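+        # accept both output formats: a flat 'prediction' string or nested 'predictions': {'utterance': ...}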
+        if 'prediction' in predict_result[i]:
+            candidates.append(predict_result[i]['prediction'])
+        else:
+            candidates.append(predict_result[i]['predictions']['utterance'])
     # metrics['bleu'] = corpus_bleu(references, candidates)
     references = [" " if ref=="" else ref for ref in references]
     metrics['bleu'] = sacrebleu.corpus_bleu(candidates, [references], lowercase=True).score
@@ -55,7 +63,7 @@ def evaluate(predict_result, ontology):
     score_list = []
     for item in predict_result:
         da = item['dialogue_acts']
-        utterance = item['predictions']['utterance']
+        utterance = item['predictions']['utterance'] if 'predictions' in item else item['prediction']
         missing_count = 0
         redundant_count = 0
         all_count = 0
diff --git a/convlab/nlg/scgpt/evaluate.sh b/convlab/nlg/scgpt/evaluate.sh
old mode 100644
new mode 100755
index 9bf8f52ac6c351d71b76dd442ce0257001bf07b3..c2ef94d35c0fdc48ef5de405bfce0d5c8a3828a2
--- a/convlab/nlg/scgpt/evaluate.sh
+++ b/convlab/nlg/scgpt/evaluate.sh
@@ -1,4 +1,8 @@
-CUDA_VISIBLE_DEVICES="5" python -m torch.distributed.launch --nproc_per_node 1 --master_port 3046 main.py \
---dataset multiwoz21 \
---scgpt_model_ckpt_path /data/zhangzheng/scgpt \
---model_path /data/zhangzheng/ConvLab-3/convlab/nlg/scgpt/saved_model/epoch_4/epoch_4_step8875.pt
\ No newline at end of file
+CUDA_VISIBLE_DEVICES="0" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2052 main.py \
+--batch_size 1 \
+--base_model_name_path gpt2-medium \
+--dataset tm3 \
+--exp_name tm3_mst_test \
+--model_path saved_models/mwoz_sgd_tm_train/epoch_5/epoch_5_step19206.pt
+# --model_path saved_models/gpt2_tm_direct/epoch_19/epoch_19_step65540.pt \
+# --model_path saved_models/gpt2_tm_direct/epoch_6/epoch_6_step22939.pt \
\ No newline at end of file
diff --git a/convlab/nlg/scgpt/main.py b/convlab/nlg/scgpt/main.py
index f3d48fbe7876c0a3bc31351dcb316327f3700da5..6c2fd505ed7da4ddac9971267bd6a6ba50bc402f 100644
--- a/convlab/nlg/scgpt/main.py
+++ b/convlab/nlg/scgpt/main.py
@@ -4,11 +4,13 @@ sys.path.append('../../..')
 import argparse
 import json
 from tqdm import tqdm
+import time
 import torch
+from functools import reduce
 import numpy as np
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 import os
@@ -29,11 +31,21 @@ code_test = False
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--local_rank", default=-1, type=int)
+parser.add_argument("--lr", default=1e-5, type=float, help="learning rate")
+parser.add_argument("--batch_size", default=32, type=int)
+parser.add_argument("--train_ratio", default=1.0, type=float)
+parser.add_argument("--accumulation_step", default=4, type=int)
+parser.add_argument("--epoch_num", default=20, type=int)
+parser.add_argument("--val_step", default=100, type=int)
 parser.add_argument('--do_train', action="store_true", help="Whether to run training.")
 parser.add_argument('--dataset', default="multiwoz21", type=str, help="The name of the dataset to be used.")
 parser.add_argument('--model_path', default="", type=str, help="The path of model for testing.")
-parser.add_argument('--scgpt_model_ckpt_path', default="", type=str, help="The path of model for testing.")
+parser.add_argument('--base_model_name_path', default="gpt2", type=str, help="Name or path of the base GPT-2 model.")
+parser.add_argument('--scgpt_model_ckpt_path', default=None, type=str, help="Path of a trained SCGPT checkpoint to initialize the model from.")
+parser.add_argument('--save_path', default="saved_models", type=str, help="Model save path.")
+parser.add_argument('--exp_name', default="default_name", type=str, help="Current experiment name.")
 parser.add_argument("--max_seq_len", default=128, type=int)
+parser.add_argument("--save_epoch_interval", default=1, type=int)
 FLAGS = parser.parse_args()
 local_rank = FLAGS.local_rank
 
@@ -41,25 +53,20 @@ torch.cuda.set_device(local_rank)
 dist.init_process_group(backend='nccl')
 
 # TensorBoard
-tb_dir = './runs'
-if not os.path.exists(tb_dir):
-    os.mkdir(tb_dir)
-tb_writer = SummaryWriter(tb_dir)
+tb_dir = os.path.join('runs', FLAGS.exp_name)
+os.makedirs(tb_dir, exist_ok=True)  # nested path ('runs/<exp_name>'), where bare os.mkdir would fail if 'runs' is missing
+tb_writer = SummaryWriter(tb_dir, flush_secs=5)
 
-special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
 ## load model
-if FLAGS.scgpt_model_ckpt_path == '':
-    tokenizer = GPT2Tokenizer.from_pretrained('./gpt2')
-    tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED, 'additional_special_tokens': special_tokens})
-    model = GPT2LMHeadModel.from_pretrained('./gpt2').to(local_rank)
-    model.resize_token_embeddings(len(tokenizer))
+if FLAGS.scgpt_model_ckpt_path is None:
+    tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.base_model_name_path)
+    model = GPT2LMHeadModel.from_pretrained(FLAGS.base_model_name_path).to(local_rank)
 else:
-    tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.scgpt_model_ckpt_path)
-    tokenizer.add_special_tokens(
-        {'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED, 'additional_special_tokens': special_tokens})
-    model = GPT2LMHeadModel.from_pretrained(FLAGS.scgpt_model_ckpt_path).to(local_rank)
+    tokenizer = GPT2Tokenizer.from_pretrained(FLAGS.base_model_name_path)
+    model = GPT2LMHeadModel(config=GPT2Config.from_pretrained(FLAGS.base_model_name_path)).to(local_rank)
+    model.load_state_dict(torch.load(FLAGS.scgpt_model_ckpt_path))
     print('model load from ' + FLAGS.scgpt_model_ckpt_path)
-    model.resize_token_embeddings(len(tokenizer))
 
 nll_loss = nn.NLLLoss(reduce=False).to(local_rank)
 ce_loss = nn.CrossEntropyLoss(reduce=False).to(local_rank)
@@ -74,108 +81,105 @@ def cal_loss(input, target, seq_lens, seq_lens_input):
     input_mask = build_mask(torch.max(seq_lens).item()-1, seq_lens_input-1).to(local_rank)
     output_mask = torch.logical_xor(mask, input_mask)
     pad_mask = torch.logical_not(mask)
-    # masked_loss = loss * output_mask
-    masked_loss = loss * (output_mask + pad_mask)
+    masked_loss = loss * output_mask
+    # masked_loss = loss * (output_mask + pad_mask)
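+    # note: the denominator counts output and pad positions, so this is smaller than a strict per-response-token mean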
     mean_loss = torch.sum(masked_loss) / torch.sum(output_mask + pad_mask)
     return mean_loss
 
 
-def pad_collate(batch):
+def pad_collate(ori_batch):
     """
     Returns:
     batch: batch * max_len
-    seq_lens: the length of len(da)+1+len(response)
-    seq_lens_input: the length of len(da)
+    seq_lens: len(da) + 1 (separator) + len(response) + 1 (EOS)
+    seq_lens_input: len(da) + 1 (separator)
     """
-    START_OF_PRED_ID = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
-    pad_token_id = tokenizer.pad_token_id
-    batch = [item[0] + [START_OF_PRED_ID] + item[1] for item in batch]
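+    # each sample: DA tokens + '&' separator + response tokens + EOS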
+    SEP_TOKEN_ID = tokenizer._convert_token_to_id_with_added_voc('&')
+    batch = [item[0] + [SEP_TOKEN_ID] + item[1] + [tokenizer.eos_token_id] for item in ori_batch]
+    output_lens = [len(item[1])+1 for item in ori_batch]  # response tokens + EOS
     batch = [item[-FLAGS.max_seq_len:] for item in batch]
     max_len = max([len(item) for item in batch])
     # print('max_len', max_len)
     seq_lens = [len(item) for item in batch]
-    split_id = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
-    def get_x_len(tokens):
-        """Get the length of dialogue act tokens"""
-        split_idx = len(tokens)
-        try:
-            split_idx = tokens.index(split_id)+1
-        except:
-            pass
-        return split_idx
-    seq_lens_input = [get_x_len(item) for item in batch]
-    batch = [item + [pad_token_id]*(max_len-len(item)) for item in batch]
-    # print(batch)
-    # print(seq_lens)
-    # print(seq_lens_input)
+    seq_lens_input = []
+    for idx in range(len(batch)):
+        curr_ipt_len = seq_lens[idx] - output_lens[idx]
+        if curr_ipt_len < 0:
+            curr_ipt_len = 0
+        seq_lens_input.append(curr_ipt_len)
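+    # right-pad with token id 0; cal_loss restricts the loss to response positions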
+    batch = [item + [0]*(max_len-len(item)) for item in batch]
     return torch.LongTensor(batch), torch.LongTensor(seq_lens), torch.LongTensor(seq_lens_input)
 
 ## Training Hyper-params
-EPOCH_NUM = 20
-BATCH_SIZE = 32   # real_batch_size = BATCH_SIZE * num_gpu
-VAL_STEP = 500
-WARM_STEPS = 250
-if code_test:
-    EPOCH_NUM = 2
-    BATCH_SIZE = 4
-    VAL_STEP = 2
-    WARM_STEPS = 3
-LR = 5e-5
-SAVE_PATH = f'./saved_model'
 def train(model, nlg_data, global_step=0):
-    train_dataset = SCGPTDataset(nlg_data['train'], tokenizer)
+    train_dataset = SCGPTDataset(filter_empty_nlg_data(nlg_data['train']), tokenizer)
     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
-    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=2, sampler=train_sampler, collate_fn=pad_collate)
+    train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, num_workers=2, sampler=train_sampler, collate_fn=pad_collate)
 
-    val_dataset = SCGPTDataset(nlg_data['validation'], tokenizer)
+    val_dataset = SCGPTDataset(filter_empty_nlg_data(nlg_data['validation']), tokenizer)
     val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
-    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2, sampler=val_sampler, collate_fn=pad_collate)
+    val_dataloader = DataLoader(val_dataset, batch_size=FLAGS.batch_size, num_workers=2, sampler=val_sampler, collate_fn=pad_collate)
 
     model = DDP(model, device_ids=[local_rank], output_device=local_rank)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
-    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARM_STEPS,
-                                                num_training_steps=len(train_dataloader) * EPOCH_NUM)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=FLAGS.lr)
+    t_total = len(train_dataloader) * FLAGS.epoch_num // FLAGS.accumulation_step
+    warm_steps = int(0.1 * t_total)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warm_steps,
+                                                num_training_steps=t_total)
     model.train()
-    for epoch in range(EPOCH_NUM):
+    for epoch in range(FLAGS.epoch_num):
         train_dataloader.sampler.set_epoch(epoch)
-        for batch_id, (inputs, seq_lens, seq_lens_input) in enumerate(tqdm(train_dataloader, desc=f'EPOCH:[{epoch+1}/{EPOCH_NUM}]')):
+        for batch_id, (inputs, seq_lens, seq_lens_input) in enumerate(tqdm(train_dataloader, desc=f'EPOCH:[{epoch+1}/{FLAGS.epoch_num}]')):
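+            # one global step per optimizer update, i.e. per accumulation_step micro-batches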
+            if (batch_id+1) % FLAGS.accumulation_step == 0:
+                global_step += 1
             inputs = inputs.to(local_rank)
             seq_lens = seq_lens.to(local_rank)
             seq_lens_input = seq_lens_input.to(local_rank)
-
-            outputs = model(inputs)
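+            # token id 0 doubles as padding here, so mask it out of attention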
+            outputs = model(inputs, attention_mask=(inputs!=0).float())
             preds = outputs[0]
             loss = cal_loss(preds[:, :-1, :], inputs[:, 1:], seq_lens, seq_lens_input)
-
-            optimizer.zero_grad()
+            loss /= FLAGS.accumulation_step
+            loss /= dist.get_world_size()
             loss.backward()
-            optimizer.step()
-            scheduler.step()
-            tb_writer.add_scalar(f'Train/loss', loss.item(), global_step)
-            tb_writer.add_scalar(f'Train/PPL', torch.exp(loss).item(), global_step)
-            tb_writer.add_scalar(f'Train/Learning Rate', scheduler.get_last_lr()[0], global_step)
-
-            global_step += 1
+            # update params only at accumulation boundaries
+            if (batch_id+1) % FLAGS.accumulation_step == 0:
+                optimizer.step()
+                scheduler.step()
+                model.zero_grad()
+                # tensorboard
+                if dist.get_rank() == 0:
+                    tb_writer.add_scalar(f'Train/loss', loss.item(), global_step)
+                    tb_writer.add_scalar(f'Train/PPL', torch.exp(loss).item(), global_step)
+                    tb_writer.add_scalar(f'Train/Learning Rate', scheduler.get_last_lr()[0], global_step)
+                if global_step % FLAGS.val_step == 0:
+                    model.eval()
+                    val_loss = eval(model, val_dataloader)
+                    ppl = np.exp(val_loss)
+                    if dist.get_rank() == 0:
+                        tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
+                        tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
+                    model.train()
+
         # save the model when each epoch ends
         if dist.get_rank() == 0:
-
-            # vaidation
-            model.eval()
-            val_loss = eval(model, val_dataloader)
-            ppl = np.exp(val_loss)
-            tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
-            tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
-            model.train()
-
-            # save model
-            save_dir = os.path.join(SAVE_PATH, f'epoch_{epoch}')
-            os.makedirs(save_dir, exist_ok=True)
-            torch.save(model.module.state_dict(), os.path.join(save_dir, f'epoch_{epoch}_step{global_step}.pt'))
-            tokenizer.save_pretrained(save_dir)
-            torch.save(optimizer.state_dict(), os.path.join(save_dir, 'optimizer.pt'))
-            torch.save(scheduler.state_dict(), os.path.join(save_dir, 'scheduler.pt'))
-            print(f'Save model checkpoint to [{save_dir}]')
-
+            if (epoch+1) % FLAGS.save_epoch_interval == 0:
+                # validation
+                model.eval()
+                val_loss = eval(model, val_dataloader)
+                ppl = np.exp(val_loss)
+                tb_writer.add_scalar(f'Val/Loss', val_loss, global_step)
+                tb_writer.add_scalar(f'Val/PPL', ppl, global_step)
+                model.train()
+                # save model
+                save_dir = os.path.join(FLAGS.save_path, FLAGS.exp_name, f'epoch_{epoch}')
+                os.makedirs(save_dir, exist_ok=True)
+                torch.save(model.module.state_dict(), os.path.join(save_dir, f'epoch_{epoch}_step{global_step}.pt'))
+                tokenizer.save_pretrained(save_dir)
+                torch.save(optimizer.state_dict(), os.path.join(save_dir, 'optimizer.pt'))
+                torch.save(scheduler.state_dict(), os.path.join(save_dir, 'scheduler.pt'))
+                print(f'Save model checkpoint to [{save_dir}]')
     tb_writer.flush()
 
 
@@ -198,17 +202,19 @@ def eval(model, loader, use_tqdm=False):
 def inference_batch(model, sents):
     """Inference model given a batch of sents."""
     with torch.no_grad():
-        sents = [sent + ' ' + START_OF_PRED for sent in sents]
+        sents = [sent + ' &' for sent in sents]
         sent_ids = [tokenizer.encode(sent) for sent in sents]
         max_len = max([len(sent) for sent in sent_ids])
-        sent_ids = [sent + [tokenizer.pad_token_id]*(max_len-len(sent)) for sent in sent_ids]
+        # left-pad with token id 0 so every prompt ends at the same position before generation
+        sent_ids = [[0]*(max_len-len(sent)) + sent for sent in sent_ids]
         inputs = torch.LongTensor(sent_ids).to(local_rank)
         model_to_run = model.module if type(model) is DDP else model
-        outputs = model_to_run.generate(inputs, max_length=FLAGS.max_seq_len, eos_token_id=tokenizer.pad_token_id,
-                                        pad_token_id=tokenizer.pad_token_id)  # greedy
+        outputs = model_to_run.generate(inputs, attention_mask=(inputs != 0).float(), max_length=FLAGS.max_seq_len, eos_token_id=tokenizer.eos_token_id)  # greedy
         # outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
         #                                 pad_token_id=gpt2_tokenizer.pad_token_id)  # beam search
-        output_strs = [tokenizer.decode(item) for item in outputs]
+        # output_strs = [tokenizer.decode(item) for item in outputs]
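+        # prompts are left-padded to a common length, so the generated continuation starts at a fixed offset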
+        outputs = outputs[:, len(inputs[0]):]
+        output_strs = tokenizer.batch_decode(outputs)
         return output_strs
 
 
@@ -226,26 +232,42 @@ def inference_sents(model, sents):
     return outputs
 
 
+def inference_sents_by_batch(model, sents):
+    """Get the outputs of multiple sentences."""
+    start_idx = 0
+    ret = []
+    start = time.time()
+    while start_idx < len(sents):
+        end_idx = start_idx + FLAGS.batch_size
+        curr_sents = sents[start_idx:end_idx]
+        outputs = inference_batch(model, curr_sents)
+        ret += outputs
+        start_idx += FLAGS.batch_size
+        time_remain = (time.time()-start) / start_idx * (len(sents) - start_idx)
+        print('{}/{}, time remaining: {:.2f}s'.format(start_idx, len(sents), time_remain))
+    return ret
+
+
 def test(model, nlg_data, ontology, model_path):
     """将sheel中的GPU个数设为1运行"""
     model.load_state_dict(torch.load(model_path))
     model.eval()
     print(f'model loaded from [{model_path}]')
     # Load test nlg data
-    test_data = nlg_data['test']
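+    # drop empty-DA turns so test inputs match the filtered training distribution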
+    test_data = filter_empty_nlg_data(nlg_data['test'])
     dialog_acts = [act2str(item['dialogue_acts']).strip() for item in test_data]
     golden_responses = [item['utterance'].strip() for item in test_data]
     # dialog_acts = dialog_acts[:10]
     # golden_responses = golden_responses[:10]
-    outputs = inference_sents(model, dialog_acts)
+    outputs = inference_sents_by_batch(model, dialog_acts)
     def get_real_output(ipt):
-        if '[start_of_pred]' in ipt:
-            ipt = ipt[ipt.index('[start_of_pred]')+15:].strip()
-        if '[_pad_token_]' in ipt:
-            ipt = ipt[:ipt.index('[_pad_token_]')].strip()
+        if tokenizer.eos_token in ipt:
+            ipt = ipt[:ipt.index(tokenizer.eos_token)].strip()
         return ipt
     outputs = [get_real_output(item) for item in outputs]
-    output_file = './test_output.json'
+    os.makedirs('./test_outputs', exist_ok=True)
+    output_file = f'./test_outputs/{FLAGS.exp_name}.json'
     if dist.get_rank() == 0:
         with open(output_file, 'w+') as f:
             result = []
@@ -253,7 +275,9 @@ def test(model, nlg_data, ontology, model_path):
                 result.append({
                     'dialogue_acts': test_data[i]['dialogue_acts'],
                     'utterance': test_data[i]['utterance'],
-                    'prediction': outputs[i]
+                    'predictions': {
+                        'utterance': outputs[i]
+                    }
                 })
             json.dump(result, f, indent=2, ensure_ascii=False)
     evaluator = GentScorer()
@@ -307,12 +331,45 @@ def test(model, nlg_data, ontology, model_path):
     #     f.write(f'BLEU: {BLEU_Score}\nERR_Score: {ERR_Score}')
     #     f.close()
 
+def filter_empty_nlg_data(data):
+    """Drop samples whose dialogue acts are empty across all three act types."""
+    ret = []
+    empty_number = 0
+    for item in data:
+        acts = item['dialogue_acts']
+        acts_size = len(acts['binary']) + len(acts['categorical']) + len(acts['non-categorical'])
+        if acts_size == 0:
+            empty_number += 1
+            continue
+        ret.append(item)
+    print('empty count:', empty_number)
+    return ret
+
 
 if __name__ == '__main__':
-    dataset = load_dataset(FLAGS.dataset)
-    ontology = load_ontology(FLAGS.dataset)
-    nlg_data = load_nlg_data(dataset)
-    if FLAGS.do_train:
-        train(model, nlg_data)
+    if '_' in FLAGS.dataset:
+        # e.g. "multiwoz21_sgd_tm1_tm2_tm3": train on the concatenation of several unified datasets
+        data_list = FLAGS.dataset.split('_')
+        datasets = [load_dataset(item) for item in data_list]
+        nlg_datas = [load_nlg_data(item) for item in datasets]
+        ret = {}
+        def aggregate(nlg_datas, split):
+            ret = []
+            for item in nlg_datas:
+                ret += item[split]
+            return ret
+        ret['train'] = aggregate(nlg_datas, 'train')
+        ret['validation'] = aggregate(nlg_datas, 'validation')
+        ret['test'] = aggregate(nlg_datas, 'test')
+        if FLAGS.do_train:
+            train(model, ret)
+        else:
+            print('testing on concatenated datasets is not supported')
     else:
-        test(model, nlg_data, ontology, FLAGS.model_path)
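+        # subsample only the training split via --train_ratio; validation and test stay full-size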
+        dataset = load_dataset(FLAGS.dataset, dial_ids_order=0, split2ratio={'train': FLAGS.train_ratio})
+        ontology = load_ontology(FLAGS.dataset)
+        nlg_data = load_nlg_data(dataset)
+        if FLAGS.do_train:
+            train(model, nlg_data)
+        else:
+            test(model, nlg_data, ontology, FLAGS.model_path)
diff --git a/convlab/nlg/scgpt/model.py b/convlab/nlg/scgpt/model.py
index 9a41cfa53e92938d781abc27bc7082c3948b7a0d..6d41319ec4d7254a27620dfe6d20fcdb684e8773 100644
--- a/convlab/nlg/scgpt/model.py
+++ b/convlab/nlg/scgpt/model.py
@@ -24,5 +24,29 @@ class SCGPTDataset(Dataset):
     def __len__(self):
         return len(self.data)
 
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+
+class SGD_TMDataset(Dataset):
+    def __init__(self, data, tokenizer):
+        """
+        Args:
+            data: [[da_str, response], [da_str, response], ...]
+            tokenizer: GPT2 Tokenizer
+        """
+        self.data = []
+        length_list = []
+        for item in data:
+            da, response = item['dialogue_acts'], item['utterance']
+            da_tokens = tokenizer.encode(act2str(da))
+            response_tokens = tokenizer.encode(response)
+            length_list.append(len(da_tokens) + len(response_tokens) + 1)
+            self.data.append([da_tokens, response_tokens])
+        print(f'max: {np.max(length_list)}, min: {np.min(length_list)}, median: {np.quantile(length_list, 0.5)}, 0.99: {np.quantile(length_list, 0.99)}')
+
+    def __len__(self):
+        return len(self.data)
+
     def __getitem__(self, idx):
         return self.data[idx]
\ No newline at end of file
diff --git a/convlab/nlg/scgpt/multiwoz/__init__.py b/convlab/nlg/scgpt/multiwoz/__init__.py
deleted file mode 100644
index 88c7ca2e9735ded913e007644fc8b46fd78535f6..0000000000000000000000000000000000000000
--- a/convlab/nlg/scgpt/multiwoz/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Sat Apr  4 21:43:42 2020
-
-@author: truthless
-"""
-
-from convlab.nlg.scgpt.multiwoz.scgpt import SCGPT
\ No newline at end of file
diff --git a/convlab/nlg/scgpt/multiwoz/preprocess.py b/convlab/nlg/scgpt/multiwoz/preprocess.py
deleted file mode 100644
index 3f5cf70f664895cd7f1c7d67f6c31903d157809a..0000000000000000000000000000000000000000
--- a/convlab/nlg/scgpt/multiwoz/preprocess.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Mon Sep 14 11:38:53 2020
-@author: truthless
-"""
-
-import os
-import json
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-from convlab.nlg.scgpt.utils import dict2dict, dict2seq
-import zipfile
-
-def read_zipped_json(filepath, filename):
-    print("zip file path = ", filepath)
-    archive = zipfile.ZipFile(filepath, 'r')
-    return json.load(archive.open(filename))
-
-def init_domain():
-    return {'Attraction':False,
-            'Hospital':False,
-            'Hotel':False,
-            'Police':False,
-            'Restaurant':False,
-            'Taxi':False,
-            'Train':False}
-
-def write_file(name, data, role='usr'):
-    with open(f'{name}.txt', 'w', encoding='utf-8') as f:
-        for ID in data:
-            sess = data[ID]
-            sess_domains = init_domain()
-            for turn in sess:
-                if role == 'usr':
-                    if not turn['usr_da']:
-                        continue
-                    turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
-                    da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
-                    domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
-                elif role == 'sys':
-                    if not turn['sys_da']:
-                        continue
-                    turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
-                    da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
-                    domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
-                else:
-                    raise NameError('Invalid Role: Select usr/sys.')
-                for domain in domains:
-                    if domain not in ['general', 'Booking'] and not sess_domains[domain]:
-                        da_seq = da_seq.replace(domain.lower(), domain.lower()+' *', 1)
-                        sess_domains[domain] = True
-
-                if role == 'usr':
-                    da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
-                elif role == 'sys':
-                    da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
-                f.write(f'{da_seq} & {da_uttr}\n')
-
-
-if __name__ == '__main__':
-    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--role', type=str, default='usr')
-    args = parser.parse_args()
-
-    cur_dir = os.path.dirname(os.path.abspath(__file__))
-    data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
-            cur_dir)))), 'data/multiwoz/')
-
-    keys = ['train', 'val', 'test']
-    data = {}
-    for key in keys:
-        data_key = read_zipped_json(os.path.join(data_dir, key + '.json.zip'), key + '.json')
-        print('load {}, size {}'.format(key, len(data_key)))
-        data = dict(data, **data_key)
-
-    with open(os.path.join(data_dir, 'valListFile'), 'r') as f:
-        val_list = f.read().splitlines()
-    with open(os.path.join(data_dir, 'testListFile'), 'r') as f:
-        test_list = f.read().splitlines()
-
-    results = {}
-    results_val = {}
-    results_test = {}
-
-    for title, sess in data.items():
-        logs = sess['log']
-        turns = []
-        turn = {'turn': 0, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
-        current_domain = None
-        for i, diag in enumerate(logs):
-            text = diag['text']
-            da = diag['dialog_act']
-            span = diag['span_info']
-            if current_domain:
-                da = eval(str(da).replace('Booking', current_domain))
-                span = eval(str(span).replace('Booking', current_domain))
-            if i % 2 == 0:
-                turn['usr'] = text
-                turn['usr_da'] = da
-                turn['usr_span'] = span
-                turns.append(turn)
-            else:
-                turn = {'turn': i//2 + 1, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
-                turn['sys'] = text
-                turn['sys_da'] = da
-                turn['sys_span'] = span
-            for key in da:
-                domain = key.split('-')[0]
-                if domain not in ['general', 'Booking']:
-                    current_domain = domain
-        else:
-            if args.role == 'sys':
-                turns.append(turn)
-        title = title
-        if title in val_list:
-            current = results_val
-        elif title in test_list:
-            current = results_test
-        else:
-            current = results
-        current[title] = turns
-
-    results = eval(str(results).replace(" n't", " not"))
-    results_val = eval(str(results_val).replace(" n't", " not"))
-    results_test = eval(str(results_test).replace(" n't", " not"))
-
-    if not os.path.exists(os.path.join(cur_dir,'data')):
-        os.makedirs(os.path.join(cur_dir, 'data'))
-    write_file(os.path.join(cur_dir, f'data/train_{args.role}'), dict(results, **results_val), role=args.role)
-    write_file(os.path.join(cur_dir, f'data/test_{args.role}'), results_test, role=args.role)
diff --git a/convlab/nlg/scgpt/multiwoz/run.py b/convlab/nlg/scgpt/multiwoz/run.py
deleted file mode 100644
index e583fe72fb26cd4262a6c4aae7776aabee49293b..0000000000000000000000000000000000000000
--- a/convlab/nlg/scgpt/multiwoz/run.py
+++ /dev/null
@@ -1,171 +0,0 @@
-from __future__ import absolute_import, division, print_function, unicode_literals
-
-import argparse
-import logging
-from tqdm import trange
-
-import torch
-import torch.nn.functional as F
-import numpy as np
-
-import sys
-
-from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
-
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
-from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
-from transformers import XLNetLMHeadModel, XLNetTokenizer
-from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
-from transformers import CTRLLMHeadModel, CTRLTokenizer
-from transformers import XLMWithLMHeadModel, XLMTokenizer
-
-
-logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
-                    datefmt = '%m/%d/%Y %H:%M:%S',
-                    level = logging.INFO)
-logger = logging.getLogger(__name__)
-
-MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
-
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
-
-MODEL_CLASSES = {
-    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
-    'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
-    'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
-    'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
-    'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
-    'xlm': (XLMWithLMHeadModel, XLMTokenizer),
-}
-
-# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
-# in https://github.com/rusiaaman/XLNet-gen#methodology
-# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
-PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
-(except for Alexei and Maria) are discovered.
-The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
-remainder of the story. 1883 Western Siberia,
-a young Grigori Rasputin is asked by his father and a group of men to perform magic.
-Rasputin has a vision and denounces one of the men as a horse thief. Although his
-father initially slaps him for making such an accusation, Rasputin watches as the
-man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
-the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
-with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
-
-
-def set_seed(args):
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_type", default=None, type=str, required=True,
-                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
-    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
-                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
-    parser.add_argument("--prompt", type=str, default="")
-    parser.add_argument("--padding_text", type=str, default="")
-    parser.add_argument("--length", type=int, default=40)
-    parser.add_argument("--num_samples", type=int, default=1)
-    parser.add_argument("--temperature", type=float, default=1.0,
-                        help="temperature of 0 implies greedy sampling")
-    parser.add_argument("--repetition_penalty", type=float, default=1.0,
-                        help="primarily useful for CTRL model; in that case, use 1.2")
-    parser.add_argument("--top_k", type=int, default=50)
-    parser.add_argument("--top_p", type=float, default=0.9)
-    parser.add_argument("--no_cuda", action='store_true',
-                        help="Avoid using CUDA when available")
-    parser.add_argument('--seed', type=int, default=42,
-                        help="random seed for initialization")
-    parser.add_argument('--stop_token', type=str, default=None,
-                        help="Token at which text generation is stopped")
-    parser.add_argument("--batch_size", default=1, type=int)
-    parser.add_argument('--input_file', type=str, default=None,
-                        help="file")
-    parser.add_argument('--output_file', type=str, default=None,
-                        help="file")
-
-    args = parser.parse_args()
-
-    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
-    args.n_gpu = torch.cuda.device_count()
-
-    set_seed(args)
-
-    args.model_type = args.model_type.lower()
-    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, pad_token='<PAD>', padding_side='left')
-    model = model_class.from_pretrained(args.model_name_or_path)
-    model.to(args.device)
-    model.eval()
-
-    if args.length < 0 and model.config.max_position_embeddings > 0:
-        args.length = model.config.max_position_embeddings
-    elif 0 < model.config.max_position_embeddings < args.length:
-        args.length = model.config.max_position_embeddings  # No generation bigger than model size 
-    elif args.length < 0:
-        args.length = MAX_LENGTH  # avoid infinite loop
-
-    logger.info(args)
-    if args.model_type in ["ctrl"]:
-        if args.temperature > 0.7:
-            logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
-
-    fin = open(args.input_file)
-    inputs = [i.strip() for i in fin]
-    output_tests = []
-    for idx in range(0, len(inputs), args.batch_size):
-        logger.info(f"PROGRESS: {int(idx/len(inputs)*100)}%")
-
-        # raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
-        raw_inputs = []
-        for i in range(idx, min(idx+args.batch_size, len(inputs))):
-            lines = inputs[i]
-            raw_text = lines.split(' & ')[0] + ' & '
-            if args.model_type in ["transfo-xl", "xlnet"]:
-                # Models with memory likes to have a long prompt for short inputs.
-                raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
-            raw_inputs.append(raw_text)
-        
-        encoding_inputs = tokenizer.batch_encode_plus(raw_inputs, pad_to_max_length=True, add_special_tokens=False)
-        context_tokens = torch.LongTensor(encoding_inputs['input_ids']).to(args.device)
-        max_length = len(context_tokens[0])
-        attention_mask = torch.LongTensor(encoding_inputs['attention_mask']).to(args.device)
-        position_ids = (attention_mask.cumsum(-1) - 1)
-        position_ids.masked_fill_(attention_mask==0, 0)
-
-        if args.model_type == "ctrl":
-            if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
-                logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
-        out_ids = model.generate(
-            input_ids=context_tokens,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            num_beams=args.num_samples,
-            num_return_sequences=args.num_samples,
-            max_length=args.length,
-            temperature=args.temperature,
-            do_sample=True,
-            top_k=args.top_k,
-            top_p=args.top_p,
-            repetition_penalty=args.repetition_penalty
-        )
-        out_ids = out_ids.reshape(len(raw_inputs), args.num_samples, -1)[:, :, max_length:].tolist()
-        for j, out in enumerate(out_ids):
-            examples = [inputs[j]]
-            for o in out:
-                text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
-                text = text[: text.find(args.stop_token) if args.stop_token else None]
-                examples.append(text)
-            output_tests.append(examples)
-        # break
-        # if args.prompt:
-            # break
-    import json
-    json.dump(output_tests, open(args.output_file,'w'), indent=2)
-    return text
-
-if __name__ == '__main__':
-    main()
diff --git a/convlab/nlg/scgpt/scgpt.py b/convlab/nlg/scgpt/scgpt.py
index 22763df86987755fbbdd7c5edc618e4673a90004..4977b4f3778229293bb60c2e98b7e715ebe227dd 100644
--- a/convlab/nlg/scgpt/scgpt.py
+++ b/convlab/nlg/scgpt/scgpt.py
@@ -2,27 +2,22 @@ import sys
 sys.path.append('../../..')
 
 import torch
-from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from convlab.nlg.nlg import NLG
 from util import act2str
 from scgpt_special_tokens import *
 
-special_tokens = [START_OF_PRED, END_OF_PRED, SYS_SPEAK, USR_SPEAK]
 
 class SCGPT(NLG):
     def __init__(self, dataset_name, model_path, device='cpu'):
         super(SCGPT, self).__init__()
         self.device = device
-        self.model = GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
-        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-        self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED,
-                                           'additional_special_tokens': special_tokens})
-        self.model.resize_token_embeddings(len(self.tokenizer))
+        self.model = GPT2LMHeadModel(config=GPT2Config.from_pretrained('gpt2-medium')).to(self.device)
+        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
         self.model.load_state_dict(torch.load(model_path))
 
-
     def generate(self, action):
         action_str = act2str(action)
         output = self._inference_batch([action_str])[0]
@@ -30,16 +25,19 @@ class SCGPT(NLG):
 
     def _inference_batch(self, sents):
         with torch.no_grad():
-            sents = [sent + ' ' + START_OF_PRED for sent in sents]
-            sent_ids = [self.tokenizer.encode(sent) for sent in sents]
+            # append the '&' separator so generation continues with the response
+            sent_ids = [self.tokenizer.encode(sent) + [self.tokenizer._convert_token_to_id_with_added_voc('&')] for sent in sents]
             max_len = max([len(sent) for sent in sent_ids])
-            sent_ids = [sent + [self.tokenizer.pad_token_id] * (max_len - len(sent)) for sent in sent_ids]
+            sent_ids = [[0] * (max_len - len(sent)) + sent for sent in sent_ids]  # left-pad, matching inference_batch in main.py
             inputs = torch.LongTensor(sent_ids).to(self.device)
             model_to_run = self.model.module if type(self.model) is DDP else self.model
-            outputs = model_to_run.generate(inputs, max_length=256,
-                                            eos_token_id=self.tokenizer.pad_token_id,
-                                            pad_token_id=self.tokenizer.pad_token_id)  # greedy
-            # outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
-            #                                 pad_token_id=gpt2_tokenizer.pad_token_id)  # beam search
-            output_strs = [self.tokenizer.decode(item) for item in outputs]
+            outputs = model_to_run.generate(inputs, max_length=256, attention_mask=(inputs!=0).float(),
+                                            eos_token_id=self.tokenizer.eos_token_id)  # greedy; pad_token_id is unset on the plain GPT-2 tokenizer
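+            # strip the prompt tokens; every prompt shares the same padded length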
+            outputs = outputs[:, len(inputs[0]):]
+            def clean_sentence(sent):
+                sent = sent.strip()
+                if self.tokenizer.eos_token in sent:
+                    sent = sent[:sent.index(self.tokenizer.eos_token)]
+                return sent
+            output_strs = [clean_sentence(self.tokenizer.decode(item)) for item in outputs]
             return output_strs
\ No newline at end of file
diff --git a/convlab/nlg/scgpt/scgpt_special_tokens.py b/convlab/nlg/scgpt/scgpt_special_tokens.py
index 643820dd04e26bde83edddcd4581784577ad3853..4610be5ff74e6322d58a37a6eaf04b0dce7c7216 100644
--- a/convlab/nlg/scgpt/scgpt_special_tokens.py
+++ b/convlab/nlg/scgpt/scgpt_special_tokens.py
@@ -3,7 +3,7 @@ SYS_SPEAK = '[sys_speak]'
 USR_SPEAK = '[usr_speak]'
 START_OF_PRED = '[start_of_pred]'
 END_OF_PRED = '[end_of_pred]'
-PAD_TOKEN = '[_pad_token_]'
+PAD_TOKEN = '<|pad_token|>'
 START_OF_INTENT = '[start_of_intent]'
 END_OF_INTENT = '[end_of_intent]'
 START_OF_SLOT = ''
diff --git a/convlab/nlg/scgpt/train.sh b/convlab/nlg/scgpt/train.sh
old mode 100644
new mode 100755
index d36d1066abec984ca89c203435a5cf7111209c98..fbfed6a496c387cb8f5c090ac20aee33f79d0325
--- a/convlab/nlg/scgpt/train.sh
+++ b/convlab/nlg/scgpt/train.sh
@@ -1 +1,13 @@
-CUDA_VISIBLE_DEVICES="5" python -m torch.distributed.launch --nproc_per_node 1 main.py --do_train --dataset multiwoz21 --scgpt_model_ckpt_path /data/zhangzheng/scgpt
\ No newline at end of file
+CUDA_VISIBLE_DEVICES="0" python -m torch.distributed.launch --nproc_per_node 1 --master_port 2040 main.py \
+--batch_size 64 \
+--accumulation_step 2 \
+--epoch_num 20 \
+--lr 5e-5 \
+--base_model_name_path gpt2-medium \
+--val_step 500 \
+--exp_name mwoz_sgd_tm_train \
+--do_train \
+--dataset multiwoz21_sgd_tm1_tm2_tm3 \
+--train_ratio 1.0
+# --scgpt_model_ckpt_path saved_models/gpt2_sgd_tm/epoch_2/epoch_2_step13698.pt
+# --base_model_name_path /root/autodl-tmp/ConvLab-3/convlab/nlg/scgpt/resource/scgpt \