diff --git a/convlab2/nlg/scgpt/decode.py b/convlab2/nlg/scgpt/decode.py
deleted file mode 100644
index e95025afdde60d8beb41bac5d2fb038e39357f3d..0000000000000000000000000000000000000000
--- a/convlab2/nlg/scgpt/decode.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Sat Apr  4 21:34:38 2020
-
-@author: truthless
-"""
-import numpy as np
-import torch
-
-def set_seed(seed, n_gpu):
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if n_gpu > 0:
-        torch.cuda.manual_seed_all(seed)
-
-
-def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-        Args:
-            logits: logits distribution shape (batch size x vocabulary size)
-            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
-            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
-                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
-        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-    """
-    top_k = min(top_k, logits.size(-1))  # Safety check
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-
-        # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
-        logits[indices_to_remove] = filter_value
-    return logits
-
-
-def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
-                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
-    context = torch.tensor(context, dtype=torch.long, device=device)
-    context = context.unsqueeze(0).repeat(num_samples, 1)
-    generated = context
-    with torch.no_grad():
-        for _ in range(length):
-
-            inputs = {'input_ids': generated}
-            if is_xlnet: 
-                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
-                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
-                input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
-                perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
-                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
-                target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
-                target_mapping[0, 0, -1] = 1.0  # predict last token
-                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
-
-            if is_xlm_mlm and xlm_mask_token:
-                # XLM MLM models are direct models (predict same token, not next token)
-                # => need one additional dummy token in the input (will be masked and guessed)
-                input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
-                inputs = {'input_ids': input_ids}
-
-            if xlm_lang is not None:
-                inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
-
-            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
-            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
-
-            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
-            for i in range(num_samples):
-                for _ in set(generated[i].tolist()):
-                    next_token_logits[i, _] /= repetition_penalty
-                
-            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-            if temperature == 0: # greedy sampling:
-                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
-            else:
-                next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
-            generated = torch.cat((generated, next_token), dim=1)
-    return generated
diff --git a/convlab2/nlg/scgpt/main.py b/convlab2/nlg/scgpt/main.py
index 2d69ba2f5af38b29efaa7e71dcbfae851aa3a218..9f1ed5817d36d12c5f6a69b29c82c63dd585db44 100644
--- a/convlab2/nlg/scgpt/main.py
+++ b/convlab2/nlg/scgpt/main.py
@@ -19,22 +19,19 @@ from convlab2.nlg.scgpt.util import act2str
 from convlab2.nlg.scgpt.model import SCGPTDataset
 from evaluate import GentScorer
 
-# distributed training
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from util import build_mask
 from scgpt_special_tokens import *
-# from plot.attn_plot import plot_attn_encdec
 
-## Model Testing
 code_test = False
 
-## argument settings for multi-GPU parallel training
 parser = argparse.ArgumentParser()
 parser.add_argument("--local_rank", default=-1, type=int)
 parser.add_argument('--do_train', action="store_true", help="Whether to run training.")
-parser.add_argument('--dataset', default="multiwoz21", type=str, help="Whether to run training.")
+parser.add_argument('--dataset', default="multiwoz21", type=str, help="The name of the dataset to be used.")
+parser.add_argument('--model_path', default="", type=str, help="The path of the model used for testing.")
 parser.add_argument("--max_seq_len", default=256, type=int)
 FLAGS = parser.parse_args()
 local_rank = FLAGS.local_rank
@@ -52,16 +49,15 @@ tokenizer.add_special_tokens({'pad_token': PAD_TOKEN, 'eos_token': END_OF_PRED,
 model = GPT2LMHeadModel.from_pretrained('./gpt2').to(local_rank)
 model.resize_token_embeddings(len(tokenizer))
 
-## loss computation
 nll_loss = nn.NLLLoss(reduce=False).to(local_rank)
 ce_loss = nn.CrossEntropyLoss(reduce=False).to(local_rank)
 def cal_loss(input, target, seq_lens, seq_lens_input):
     """Only calculate loss on responses, not on dialog act"""
     global nll_loss
     """Input: [batch, length, vocab]; target: [batch, length]; seq_lens: [batch]"""
-    log_probs = F.log_softmax(input, dim=-1).transpose(1, 2)  # the class dimension must be at dim=1, as required by nn.NLLLoss
+    log_probs = F.log_softmax(input, dim=-1).transpose(1, 2)
     loss = nll_loss(log_probs, target)
-    # loss = ce_loss(input, target)  # equivalent
+    # loss = ce_loss(input, target)
     mask = build_mask(torch.max(seq_lens).item()-1, seq_lens-1).to(local_rank)
     input_mask = build_mask(torch.max(seq_lens).item()-1, seq_lens_input-1).to(local_rank)
     output_mask = torch.logical_xor(mask, input_mask)
@@ -82,7 +78,7 @@ def pad_collate(batch):
     START_OF_PRED_ID = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
     pad_token_id = tokenizer.pad_token_id
     batch = [item[0] + [START_OF_PRED_ID] + item[1] for item in batch]
-    batch = [item[-FLAGS.max_seq_len:] for item in batch]  # TF limits the input length
+    batch = [item[-FLAGS.max_seq_len:] for item in batch]
     max_len = max([len(item) for item in batch])
     seq_lens = [len(item) for item in batch]
     split_id = tokenizer._convert_token_to_id_with_added_voc(START_OF_PRED)
@@ -286,4 +282,4 @@ if __name__ == '__main__':
     if FLAGS.do_train:
         train(model, nlg_data)
     else:
-        test(model, nlg_data, ontology, './saved_model/epoch_0/epoch_0_step2839.pt')
+        test(model, nlg_data, ontology, FLAGS.model_path)
diff --git a/convlab2/nlg/scgpt/modeling_utils.py b/convlab2/nlg/scgpt/modeling_utils.py
deleted file mode 100644
index a8b3f6ddfc6b7347c624446bf7869c67d3064cc1..0000000000000000000000000000000000000000
--- a/convlab2/nlg/scgpt/modeling_utils.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import warnings
-from contextlib import nullcontext
-from typing import TYPE_CHECKING
-import torch.cuda.amp as amp
-import transformers
-from transformers import GPT2LMHeadModel
-
-
-# reference: https://pytorch.org/docs/master/notes/amp_examples.html
-class AmpGPT2LMHeadModel(GPT2LMHeadModel):
-    if TYPE_CHECKING:
-        # For IDE's code hinting
-        forward = GPT2LMHeadModel.forward
-    else:
-        def forward(self, *args, **kwargs):
-            with amp.autocast():
-                return super().forward(*args, **kwargs)
-
-
-def try_enable_gradient_checkpointing(model: "transformers.modeling_utils.PreTrainedModel"):
-    if model.supports_gradient_checkpointing:
-        model.gradient_checkpointing_enable()
-    else:
-        warnings.warn(f"{type(model)} doesn't support gradient_checkpointing")
-
-
-class AmpHelper:
-    """
-    References:
-        https://pytorch.org/docs/master/notes/amp_examples.html
-    """
-    def __init__(self, use_amp=True):
-        self.use_amp = use_amp
-        self.might_enable_autocast = amp.autocast() if use_amp else nullcontext()
-        self.scaler = amp.GradScaler()
-
-    def backward(self, loss):
-        if self.use_amp:
-            return self.scaler.scale(loss).backward()
-        else:
-            return loss.backward()
-
-    def step(self, optimizer):
-        if self.use_amp:
-            self.scaler.step(optimizer)
-            self.scaler.update()
-        else:
-            optimizer.step()
-
-    def might_unscale_(self, optimizer):
-        if self.use_amp:
-            # Unscales the gradients of optimizer's assigned params in-place
-            self.scaler.unscale_(optimizer)
\ No newline at end of file
diff --git a/convlab2/nlg/scgpt/utils.py b/convlab2/nlg/scgpt/utils.py
deleted file mode 100644
index 7fefc166f096c307590fd5b0478c4db1cf551e7f..0000000000000000000000000000000000000000
--- a/convlab2/nlg/scgpt/utils.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Created on Tue Mar 24 18:34:55 2020
-
-@author: truthless
-"""
-
-def tuple2dict(t):
-    '''
-    tuple: [(intent, domain, slot, value)]
-    dict: [domain: { intent: [slot, value] }]
-    '''
-    d = {}
-    for intent, domain, slot, value in t:
-        if domain not in d:
-            d[domain] = {}
-        if intent not in d[domain]:
-            d[domain][intent] = []
-        if slot == 'none' or slot is None:
-            continue
-        d[domain][intent].append([slot, value])
-    return d
-
-def dict2dict(D):
-    '''
-    dict: [domain-intent: [slot, value]]
-    dict: [domain: { intent: [slot, value] }]
-    '''
-    d = {}
-    for domint in D:
-        domain, intent = domint.split('-')
-        if domain not in d:
-            d[domain] = {}
-        if intent not in d[domain]:
-            d[domain][intent] = []
-        for slot, value in D[domint]:
-            if slot == 'none' or slot is None:
-                continue
-            d[domain][intent].append([slot, value])
-    return d
-
-def dict2seq(d):
-    '''
-    dict: [domain: { intent: [slot, value] }]
-    seq: [domain { intent ( slot = value ; ) @ } | ]
-    '''
-    s = ''
-    first_domain = True
-    first_intent = True
-    first_slot = True
-    for domain in d:
-        if not first_domain:
-            s += ' | '
-        s += domain
-        s += ' { '
-        first_domain = False
-        first_intent = True
-        for intent in d[domain]:
-            if not first_intent:
-                s += ' @ '
-            s += intent
-            s += ' ( '
-            first_intent = False
-            first_slot = True
-            for slot, value in d[domain][intent]:
-                if not first_slot:
-                    s += ' ; '
-                s += slot
-                if value:
-                    s += ' = '
-                    s += value
-                first_slot = False
-            s += ' )'
-        s += ' }'
-    return s.lower()
-
-def tuple2seq(t):
-    d = tuple2dict(t)
-    s = dict2seq(d)
-    return s
-    
-if __name__ == '__main__':
-    da_tuple = [('Inform', 'Booking', 'none', 'none'), ('Inform', 'Hotel', 'Price', 'cheap'), ('Inform', 'Hotel', 'Choice', '1'), ('Inform', 'Hotel', 'Parking', 'none')]
-    da_dict = tuple2dict(da_tuple)
-    print(da_dict)
-    da_seq = dict2seq(da_dict)
-    print(da_seq)
-
-    da_tuple = [('Request', 'Hotel', 'Address', '?'), ('Request', 'Hotel', 'Area', '?'), ('Inform', 'Attraction', 'Area', 'center'), ('Inform', 'Hotel', 'Price', 'cheap')]
-    da_dict = tuple2dict(da_tuple)
-    print(da_dict)
-    da_seq = dict2seq(da_dict)
-    print(da_seq)
-    
-    D = {'Hotel-Inform': [['Price', 'cheap'], ['Type', 'hotel']]}
-    da_dict = dict2dict(D)
-    print(da_dict)
-