diff --git a/convlab/e2e/soloist/READEME.md b/convlab/e2e/soloist/READEME.md
new file mode 100644
index 0000000000000000000000000000000000000000..12367222d48670054855f5592f369fad391f1be8
--- /dev/null
+++ b/convlab/e2e/soloist/READEME.md
@@ -0,0 +1,95 @@
+# SOLOIST
+
+On top of pre-trained language models, SOLOIST subsumes the different components of a task-oriented dialog system into a single model and employs a pre-training-then-fine-tuning schema to build task bots.
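+
+Each dialog turn is serialized into a single training example with three text fields: `Context` (the dialog history joined with ` EOS `), `Knowledge` (the flattened belief state), and `Response` (the delexicalized system reply), as produced by `script/create_dataset.py`. An illustrative, made-up example:
+
+```
+Context:   i need a cheap hotel in the east . EOS okay , the allenbell is available . EOS please book it for 2 nights .
+Knowledge: hotel area is east ; pricerange is cheap
+Response:  Agent: booking was successful . your reference number is [hotel_ref] .
+```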
+
+## Usage
+
+Follow the instructions under each dataset's directory to prepare the data for training and evaluation.
+
+#### Dataset Creation
+Create datasets for the three settings: `single` (MultiWOZ 2.1 only), `joint` (Taskmaster-1 + SGD + MultiWOZ 2.1), and `transfer` (Taskmaster-1 + SGD, for transfer to MultiWOZ).
+```sh
+$ cd multiwoz
+$ python script/create_dataset.py joint
+$ python script/create_dataset.py transfer
+$ python script/create_dataset.py single
+```
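+
+The resulting `{setting}_{split}.jsonl` files are written to `multiwoz/data/`, which is where `e2e_dataloader.py` expects to find them.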
+
+#### Train a model
+
+```sh
+$ python train.py \
+    --model_name_or_path t5-base \
+    --dataset_name e2e_dataloader.py \
+    --output_dir ./model \
+    --per_device_train_batch_size=2 \
+    --per_device_eval_batch_size=2 \
+    --max_target_length 128 \
+    --max_length 512 \
+    --num_train_epochs 50 \
+    --save_steps 10000 \
+    --preprocessing_num_workers 1 \
+    --num_beams 5 \
+    --learning_rate 5e-5 \
+    --dataset_config_name SINGLE \
+    --logging_steps 100
+```
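+
+`--dataset_config_name` selects the data setting prepared above: `SINGLE`, `JOINT`, or `TRANSFER` (see `e2e_dataloader.py`).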
+
+The trained model (`pytorch_model.bin`) will be saved under the directory given by `--output_dir`. The script also saves predictions for the validation and test sets every epoch.
+
+#### Test a model
+
+Results are saved under the directory given by `--output_dir`. Evaluation relies on a third-party package; please follow the instructions at https://github.com/Tomiinek/MultiWOZ_Evaluation.
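+
+A minimal sketch of how that package can be invoked (assuming `pip install mwzeval`; see its README for the exact prediction format, and note the dialog id below is only a placeholder):
+
+```python
+from mwzeval.metrics import Evaluator
+
+# predictions: lowercase dialog id -> list of turns, each with a delexicalized "response"
+my_predictions = {"pmul4462": [{"response": "there are [choice] hotels in the [area] ."}]}
+
+e = Evaluator(bleu=True, success=True, richness=False)
+results = e.evaluate(my_predictions)
+print(results)
+```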
+
+
+## Performance on unified format datasets of different settings
+
+Note that we use almost the same hyper-parameters across the different settings, which may not be optimal.
+
+<table>
+<thead>
+  <tr>
+    <th></th>
+    <th colspan=2>MultiWOZ 2.1</th>
+    <th colspan=2>SGD</th>
+    <th colspan=2>Taskmaster-1</th>
+  </tr>
+  <tr>
+    <th>Model</th>
+    <th>Combined</th><th>BLEU</th>
+    <th>Slot F1</th><th>BLEU</th>
+    <th>Slot F1</th><th>BLEU</th>
+  </tr>
+</thead>
+<tbody>
+  <tr>
+    <td>SOLOIST w/o pre-training</td>
+    <td>67.0</td><td>16.8</td>
+    <td>56.9</td><td>11.2</td>
+    <td>8.5</td><td>28.0</td>    
+  </tr>
+  <tr>
+    <td>SOLOIST</td>
+    <td>71.4</td><td>17.1</td>
+    <td>69.7</td><td>23.1</td>
+    <td>9.2</td><td>29.2</td>
+  </tr>
+</tbody>
+</table>
+
+- Slot F1: F1 measure of the delexicalized slot predictions over the corpus.
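+
+A minimal sketch of how such a corpus-level slot F1 can be computed (an assumption-laden illustration: slots are taken to be the `[domain_slot]` placeholders in the delexicalized text; the actual evaluation code may differ in details):
+
+```python
+import re
+from collections import Counter
+
+def corpus_slot_f1(predictions, references):
+    """Micro-averaged F1 over [domain_slot] placeholders."""
+    tp = fp = fn = 0
+    for pred, ref in zip(predictions, references):
+        p = Counter(re.findall(r'\[\w+\]', pred))
+        r = Counter(re.findall(r'\[\w+\]', ref))
+        tp += sum((p & r).values())  # slots predicted and present in the reference
+        fp += sum((p - r).values())  # slots predicted but absent
+        fn += sum((r - p).values())  # slots missed
+    precision = tp / (tp + fp) if tp + fp else 0.0
+    recall = tp / (tp + fn) if tp + fn else 0.0
+    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0
+```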
+
+## References
+
+```
+@article{peng2021soloist,
+  title={Soloist: Building task bots at scale with transfer learning and machine teaching},
+  author={Peng, Baolin and Li, Chunyuan and Li, Jinchao and Shayandeh, Shahin and Liden, Lars and Gao, Jianfeng},
+  journal={Transactions of the Association for Computational Linguistics},
+  volume={9},
+  pages={807--824},
+  year={2021},
+  publisher={MIT Press}
+}
+@article{nekvinda2021shades,
+  title={Shades of BLEU, flavours of success: The case of MultiWOZ},
+  author={Nekvinda, Tom{\'a}{\v{s}} and Du{\v{s}}ek, Ond{\v{r}}ej},
+  journal={arXiv preprint arXiv:2106.05555},
+  year={2021}
+}
+@article{peng2022godel,
+  title={GODEL: Large-Scale Pre-Training for Goal-Directed Dialog},
+  author={Peng, Baolin and Galley, Michel and He, Pengcheng and Brockett, Chris and Liden, Lars and Nouri, Elnaz and Yu, Zhou and Dolan, Bill and Gao, Jianfeng},
+  journal={arXiv preprint arXiv:2206.11309},
+  year={2022}
+}
+```
\ No newline at end of file
diff --git a/convlab/e2e/soloist/e2e_dataloader.py b/convlab/e2e/soloist/e2e_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac7be4d2e3d13f79e10f824505c8e1ab33ff4f35
--- /dev/null
+++ b/convlab/e2e/soloist/e2e_dataloader.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2020 HuggingFace Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Corpus for E2E Dialog Modeling"""
+
+import random
+
+import datasets
+import jsonlines
+
+
+_DESCRIPTION = """\
+E2E Dialog Modeling
+"""
+
+_CITATION = """\
+E2E Dialog Modeling
+"""
+
+_DOWNLOAD_URL = ""
+_WEBPAGE = ""
+
+class UnifiedDialogConfig(datasets.BuilderConfig):
+    """BuilderConfig for the unified E2E dialog datasets.
+
+    Args:
+      data_name: `string`, one of 'joint', 'transfer' or 'single'; selects
+        which jsonl files (created by multiwoz/script/create_dataset.py)
+        are loaded.
+      **kwargs: keyword arguments forwarded to super.
+    """
+
+    def __init__(self, data_name, **kwargs):
+        super(UnifiedDialogConfig, self).__init__(version=datasets.Version("1.0.2"), **kwargs)
+        self.data_name = data_name
+
+
+class Summarization(datasets.GeneratorBasedBuilder):
+    """Dataset builder for E2E dialog modeling."""
+
+    BUILDER_CONFIGS = [
+        UnifiedDialogConfig(name='JOINT', data_name='joint'),
+        UnifiedDialogConfig(name='TRANSFER', data_name='transfer'),
+        UnifiedDialogConfig(name='SINGLE', data_name='single'),
+    ]
+
+    # fix the random seed for reproducibility
+    random.seed(2022)
+
+    def _info(self):
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=datasets.Features(
+                {
+                    "Context": datasets.Value("string"),
+                    "Knowledge": datasets.Value("string"),
+                    "Response": datasets.Value("string"),
+                    "Dataset": datasets.Value("string"),
+                }
+            ),
+            homepage=_WEBPAGE,
+            citation=_CITATION,
+        )
+
+    def _split_generators(self, dl_manager):
+
+        data_name = self.config.data_name
+
+        if data_name == 'joint':
+            train_path = './multiwoz/data/joint_train.jsonl'
+        elif data_name == 'transfer':
+            train_path = './multiwoz/data/transfer_train.jsonl'
+        elif data_name == 'single':
+            train_path = './multiwoz/data/single_train.jsonl'
+        else:
+            raise ValueError('Please specify a dataset config (JOINT, TRANSFER or SINGLE).')
+        # validation/test always come from the MultiWOZ (single) split
+        validation_path = './multiwoz/data/single_validation.jsonl'
+        test_path = './multiwoz/data/single_test.jsonl'
+
+        return [
+            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": train_path}),
+            datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": validation_path}),
+            datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": test_path}),
+        ]
+
+    def _generate_examples(self, filepath):
+        with open(filepath, "r", encoding="utf-8") as reader:
+            for key, item in enumerate(jsonlines.Reader(reader)):
+                yield key, item
\ No newline at end of file
diff --git a/convlab/e2e/soloist/multiwoz/script/create_dataset.py b/convlab/e2e/soloist/multiwoz/script/create_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2426af9f6100c07caee4a67d89231cb70481bf8a
--- /dev/null
+++ b/convlab/e2e/soloist/multiwoz/script/create_dataset.py
@@ -0,0 +1,74 @@
+import jsonlines
+import copy
+import fire
+
+from convlab.util.unified_datasets_util import create_delex_data, load_dataset
+from convlab.util import load_e2e_data
+
+def state_to_string(state):
+    """Flatten a dialog state into a string, e.g.
+    'hotel area is east ; pricerange is cheap | restaurant food is italian'."""
+    domain_str = []
+    for domain, svs in state.items():
+        svs_str = []
+        for s, v in svs.items():
+            if v != '':
+                svs_str.append(f'{s} is {v}')
+        svs_str = ' ; '.join(svs_str)
+        if svs_str != '':
+            domain_str.append(f'{domain} {svs_str}')
+    domain_str = ' | '.join(domain_str)
+    return domain_str
+
+def context_to_string(context):
+    """Join the utterances of a dialog context with ' EOS ' separators."""
+    response = ' EOS '.join(i['utterance'].strip() for i in context)
+    return response
+
+def delex_function(d, s, v):
+    """Build a delex placeholder, e.g. ('hotel', 'price range', 'cheap') -> '[hotel_pricerange]'."""
+    s = s.replace(' ', '')
+    str_ = f'[{d}_{s}]'
+    return str_
+
+def create_dataset(mode='joint'):
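+    # unified datasets used for each training setting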
+    dataset_list = {
+        'joint': ['tm1','sgd','multiwoz21'],
+        'transfer': ['tm1','sgd'],
+        'single': ['multiwoz21']
+    }
+
+    examples = []
+    for _data in dataset_list[mode]:
+        
+        dataset = load_dataset(_data)
+        dataset, delex_vocab = create_delex_data(dataset, delex_func=delex_function)
+        e2e_data = load_e2e_data(dataset, delex_utterance = True)
+        
+        split_list = ['train','validation','test'] if mode == 'single' else ['train']
+        
+        for split in split_list:
+            data = e2e_data[split]
+            for i in data:
+                response = i['delex_utterance'].strip()
+                context = i['context']
+                context = context_to_string(context)
+                
+                example = {}
+                example['Context'] = context
+                try:
+                    knowledge = state_to_string(i['context'][-1]['state'])
+                except Exception:
+                    knowledge = ''
+                example['Knowledge'] = knowledge
+                example['Response'] = 'Agent: ' + response.strip()
+                example['Dataset'] = f'{_data}'
+                examples.append(copy.copy(example))
+            if mode == 'single':
+                with jsonlines.open(f'./data/{mode}_{split}.jsonl', "w") as writer:
+                    for item in examples:
+                        writer.write(item)
+                examples = []
+    if mode != 'single':
+        # joint/transfer only produce a training file; validation/test reuse the single split
+        with jsonlines.open(f'./data/{mode}_train.jsonl', "w") as writer:
+            for item in examples:
+                writer.write(item)
+            
+if __name__ == '__main__':
+    fire.Fire(create_dataset)
\ No newline at end of file
diff --git a/convlab/e2e/soloist/multiwoz/script/create_mwoz_e2e_json.py b/convlab/e2e/soloist/multiwoz/script/create_mwoz_e2e_json.py
deleted file mode 100644
index 5953c19f50e410c8c37002309904de264ae01602..0000000000000000000000000000000000000000
--- a/convlab/e2e/soloist/multiwoz/script/create_mwoz_e2e_json.py
+++ /dev/null
@@ -1,136 +0,0 @@
-import jsonlines
-import json,copy
-fidx = open('test.idx.txt','w')
-
-data = json.load(open('data/test.json'))
-examples = []
-for i in data:
-    name = i['file'].lower()   
-    history = [] 
-    for turn in i['info']:
-        history.append(turn['user_orig'])
-
-        bs = turn['BS']
-        bs_str = []
-        for domain, states in bs.items():
-            domain_str = []
-            for state in states:
-                domain_str.append(f'{state[0]} = {state[1]}')
-            domain_str = ' ; '.join(domain_str)
-            bs_str.append(domain + ' ' + domain_str)
-        bs_str = ' | '.join(bs_str)
-
-        db_str = 'kb '
-        db = turn['KB']
-        if db == 0:
-            db_str += 'zero'
-        elif db_str == 1:
-            db_str += 'one'
-        elif db_str == 2:
-            db_str += 'two'
-        else:
-            db_str += 'more than two'
-
-        act_seq = ' '.join(turn['act'].keys())
-        example = {}
-        example['Context'] = ' EOS '.join(history[:])
-        example['Knowledge'] = ''
-        example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip()
-
-        history.append(turn['sys'].strip())
-        examples.append(copy.copy(example))
-        fidx.write(name + '\n')
-
-writer =  jsonlines.open('multiwoz_test_e2e.jsonl', mode='w')
-for i in examples:    
-    writer.write(i)
-
-
-data = json.load(open('data/val.json'))
-examples = []
-for i in data:
-    name = i['file'].lower()   
-    history = [] 
-    for turn in i['info']:
-        history.append(turn['user_orig'])
-
-
-        bs = turn['BS']
-        bs_str = []
-        for domain, states in bs.items():
-            domain_str = []
-            for state in states:
-                domain_str.append(f'{state[0]} = {state[1]}')
-            domain_str = ' ; '.join(domain_str)
-            bs_str.append(domain + ' ' + domain_str)
-        bs_str = ' | '.join(bs_str)
-
-        db_str = 'kb '
-        db = turn['KB']
-        if db == 0:
-            db_str += 'zero'
-        elif db_str == 1:
-            db_str += 'one'
-        elif db_str == 2:
-            db_str += 'two'
-        else:
-            db_str += 'more than two'
-
-        act_seq = ' '.join(turn['act'].keys())
-        example = {}
-        example['Context'] = ' EOS '.join(history[:])
-        example['Knowledge'] = ''
-        example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip()
-
-        history.append(turn['sys'].strip())
-        examples.append(copy.copy(example))
-        # fidx.write(name + '\n')
-
-writer =  jsonlines.open('multiwoz_valid_e2e.jsonl', mode='w')
-for i in examples:    
-    writer.write(i)
-
-
-data = json.load(open('data/train.json'))
-examples = []
-for i in data:
-    name = i['file'].lower()   
-    history = [] 
-    for turn in i['info']:
-        history.append(turn['user_orig'])
-
-
-        bs = turn['BS']
-        bs_str = []
-        for domain, states in bs.items():
-            domain_str = []
-            for state in states:
-                domain_str.append(f'{state[0]} = {state[1]}')
-            domain_str = ' ; '.join(domain_str)
-            bs_str.append(domain + ' ' + domain_str)
-        bs_str = ' | '.join(bs_str)
-
-        db_str = 'kb '
-        db = turn['KB']
-        if db == 0:
-            db_str += 'zero'
-        elif db_str == 1:
-            db_str += 'one'
-        elif db_str == 2:
-            db_str += 'two'
-        else:
-            db_str += 'more than two'
-
-        act_seq = ' '.join(turn['act'].keys())
-        example = {}
-        example['Context'] = ' EOS '.join(history[:])
-        example['Knowledge'] = ''
-        example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip()
-
-        history.append(turn['sys'].strip())
-        examples.append(copy.copy(example))
-        # fidx.write(name + '\n')
-
-writer =  jsonlines.open('multiwoz_train_e2e.jsonl', mode='w')
-for i in examples:    
-    writer.write(i)
diff --git a/convlab/e2e/soloist/multiwoz/soloist_net.py b/convlab/e2e/soloist/multiwoz/soloist_net.py
deleted file mode 100644
index 45f98200bc7eaf5387f796c8d29d3bb0555ac9e0..0000000000000000000000000000000000000000
--- a/convlab/e2e/soloist/multiwoz/soloist_net.py
+++ /dev/null
@@ -1,277 +0,0 @@
-import argparse
-import logging
-import math
-import os
-import random
-
-import datasets
-import nltk
-import numpy as np
-import torch
-from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
-from tqdm.auto import tqdm
-
-import transformers
-from accelerate import Accelerator
-from filelock import FileLock
-from transformers import (
-    CONFIG_MAPPING,
-    MODEL_MAPPING,
-    AdamW,
-    AutoConfig,
-    AutoModelForSeq2SeqLM,
-    AutoTokenizer,
-    DataCollatorForSeq2Seq,
-    SchedulerType,
-    get_scheduler,
-    set_seed,
-)
-from transformers.file_utils import is_offline_mode
-from transformers.utils.versions import require_version
-
-import copy, operator
-from queue import PriorityQueue
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import nn
-from torch.autograd import Variable
-from torch.distributions import Categorical
-from convlab.e2e.soloist.multiwoz.config import global_config as cfg
-
-logger = logging.getLogger(__name__)
-logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO,
-    )
-
-def cuda_(var):
-    return var.cuda() if cfg.cuda and torch.cuda.is_available() else var
-
-
-def tensor(var):
-    return cuda_(torch.tensor(var))
-
-class SOLOIST:
-
-    def __init__(self) -> None:
-        
-        self.config = AutoConfig.from_pretrained(cfg.model_name_or_path)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_name_or_path,config=self.config)
-        self.tokenizer = AutoTokenizer.from_pretrained('t5-base')
-        print('model loaded!')
-
-        self.model = self.model.cuda() if torch.cuda.is_available() else self.model
-
-    def generate(self, inputs):
-
-        self.model.eval()
-        inputs = self.tokenizer([inputs])
-        input_ids = tensor(inputs['input_ids'])
-        # generated_tokens = self.model.generate(input_ids = input_ids, max_length = cfg.max_length, num_beams = cfg.num_beams)
-        generated_tokens = self.model.generate(input_ids = input_ids, max_length = cfg.max_length, top_p=cfg.top_p)
-        decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-
-        return decoded_preds[0]
-
-    
-    def train_loop(self):
-
-        def preprocess_function(examples):
-            contextes = examples['Context']
-            responses = examples['Response']
-            belief = examples['Belief']
-            responses_labels = []
-            inputs = []
-
-            for context, response, kb in zip(contextes, responses, belief):
-                if cfg.no_kb:
-                    inputs.append(context + ' => ')
-                else:
-
-                    if cfg.format_version == 'e2e':
-                        context = ' EOS '.join(context.split(' EOS ')[-10:])
-                        _input = context
-                    
-                    if cfg.format_version == 'e2e+lm':
-                        context = ' EOS '.join(context.split(' EOS ')[-10:])
-                        inputs.append('[E2E] ' + context)
-                        responses_labels.append(response)
-                        inputs.append('[LM] ' + context )
-                        responses_labels.append(response.split(' EOS ')[1])
-                        continue
-                    
-                    if cfg.format_version == 'v2':
-                        _input = kb + context
-                    
-                    if cfg.format_version == 'v3':
-                        _input = ''
-                        context = context.split(' EOS ')
-                        for idx, turn in enumerate(context):
-                            if idx % 2 == 0:
-                                _input += 'user : ' + turn.strip()
-                            else:
-                                _input += ' system : ' + turn.strip()
-                        _input = _input + ' <|Knowledge|> ' + kb
-
-                    if cfg.format_version == 'v4':
-                        _input = ''
-                        context = context.split(' EOS ')
-                        for idx, turn in enumerate(context):
-                            if idx % 2 == 0:
-                                _input += 'user : ' + turn.strip()
-                            else:
-                                _input += ' system : ' + turn.strip()
-                        _input = kb + _input
-                    
-                    inputs.append(_input)
-                    responses_labels.append(response)
-            model_inputs = self.tokenizer(inputs, max_length=cfg.max_length, padding="max_length", truncation=True)
-
-            
-            with self.tokenizer.as_target_tokenizer():
-                labels = self.tokenizer(responses_labels, max_length=cfg.max_target_length, padding="max_length", truncation=True)
-
-            
-            if cfg.ignore_pad_token_for_loss:
-                labels["labels"] = [
-                    [(l if l != self.tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
-                ]
-
-            model_inputs["labels"] = labels["labels"]
-            return model_inputs
-
-        raw_datasets = load_dataset(cfg.dataset_name)
-        column_names = ['Context','Response','Belief']
-        lm_datasets = raw_datasets.map(
-            preprocess_function,
-            batched=True,
-            remove_columns=column_names,
-            num_proc=cfg.preprocessing_num_workers,
-            load_from_cache_file=False,
-            desc=f"Processing dataset",
-        )
-
-        train_dataset = lm_datasets["test"]
-        # train_dataset = lm_datasets["validation"]
-        eval_dataset = lm_datasets["test"]
-        test_dataset = lm_datasets["test"]
-        for index in random.sample(range(len(train_dataset)), 1):
-            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 
-
-        label_pad_token_id = -100 if cfg.ignore_pad_token_for_loss else self.tokenizer.pad_token_id
-
-        accelerator = Accelerator()
-        logger.info(accelerator.state)
-        data_collator = DataCollatorForSeq2Seq(
-            self.tokenizer,
-            model=self.model,
-            label_pad_token_id=label_pad_token_id,
-            pad_to_multiple_of=8 if accelerator.use_fp16 else None,
-        )
-
-
-        train_dataloader = DataLoader(
-        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=cfg.per_device_train_batch_size
-        )
-        eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=cfg.per_device_eval_batch_size)
-        test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=cfg.per_device_eval_batch_size)
-
-        # Optimizer
-        # Split weights in two groups, one with weight decay and the other not.
-        no_decay = ["bias", "LayerNorm.weight"]
-        optimizer_grouped_parameters = [
-            {
-                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
-                "weight_decay": cfg.weight_decay,
-            },
-            {
-                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
-                "weight_decay": 0.0,
-            },
-        ]
-        optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate)
-
-        # Prepare everything with our `accelerator`.
-        self.model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
-            self.model, optimizer, train_dataloader, eval_dataloader, test_dataloader
-        )
-
-        # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
-        # shorter in multiprocess)
-
-        # Scheduler and math around the number of training steps.
-        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
-        if cfg.max_train_steps is None:
-            cfg.max_train_steps = cfg.num_train_epochs * num_update_steps_per_epoch
-        else:
-            cfg.num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)
-
-        lr_scheduler = get_scheduler(
-            name=cfg.lr_scheduler_type,
-            optimizer=optimizer,
-            num_warmup_steps=cfg.num_warmup_steps,
-            num_training_steps=cfg.max_train_steps,
-        )
-
-        # Metric
-
-        # Train!
-        total_batch_size = cfg.per_device_train_batch_size * accelerator.num_processes * cfg.gradient_accumulation_steps
-
-        logger.info("***** Running training *****")
-        logger.info(f"  Num examples = {len(train_dataset)}")
-        logger.info(f"  Num Epochs = {cfg.num_train_epochs}")
-        logger.info(f"  Instantaneous batch size per device = {cfg.per_device_train_batch_size}")
-        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-        logger.info(f"  Gradient Accumulation steps = {cfg.gradient_accumulation_steps}")
-        logger.info(f"  Total optimization steps = {cfg.max_train_steps}")
-        # Only show the progress bar once on each machine.
-        progress_bar = tqdm(range(cfg.max_train_steps), disable=not accelerator.is_local_main_process)
-        completed_steps = 0
-        global_steps = 0
-        tr_loss, logging_loss = 0.0, 0.0
-        for epoch in range(cfg.num_train_epochs):
-            self.model.train()
-            # for step, batch in enumerate(train_dataloader):
-            for step, batch in enumerate(train_dataloader):
-                global_steps += 1            
-                outputs = self.model(**batch)
-                loss = outputs.loss
-                loss = loss / cfg.gradient_accumulation_steps
-                tr_loss += loss.item()
-                accelerator.backward(loss)
-                
-                if step % cfg.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                    optimizer.step()
-                    lr_scheduler.step()
-                    optimizer.zero_grad()
-                    completed_steps += 1
-
-                if completed_steps >= cfg.max_train_steps:
-                    break
-
-                if step % cfg.logging_steps == 0:
-                    logger.info(f"  EVALERR:  {(tr_loss - logging_loss)/float(cfg.logging_steps)}")
-                    logging_loss = tr_loss
-                    progress_bar.update(cfg.logging_steps)
-
-                if cfg.output_dir is not None and global_steps % cfg.save_steps == 0 and global_steps > 0:
-                    
-                    accelerator.wait_for_everyone()
-                    if accelerator.is_local_main_process:               
-                        checkpoint_prefix = 'checkpoint'
-                        output_dir = os.path.join(cfg.output_dir, '{}-{}'.format(checkpoint_prefix, global_steps))
-                        if not os.path.exists(output_dir):
-                            os.makedirs(output_dir)
-                        unwrapped_model = accelerator.unwrap_model(self.model)
-                        unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
-
-                        self.tokenizer.save_pretrained(output_dir)
-                        torch.save(cfg, os.path.join(output_dir, 'training_args.bin'))
-                        logger.info("Saving model checkpoint to %s", output_dir)
-
-
-    
\ No newline at end of file
diff --git a/convlab/e2e/soloist/train.py b/convlab/e2e/soloist/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c2078954173bcf20fe2f21691b720eebf4d74c9
--- /dev/null
+++ b/convlab/e2e/soloist/train.py
@@ -0,0 +1,836 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fine-tuning a 🤗 Transformers model on summarization.
+"""
+# You can also adapt this script on your own summarization task. Pointers for this are left as comments.
+
+import argparse
+import logging
+import math
+import os
+import random
+import json
+
+import datasets
+import nltk
+import numpy as np
+import torch
+from datasets import load_dataset, load_metric
+from torch.utils.data.dataloader import DataLoader
+from tqdm.auto import tqdm
+
+import transformers
+from accelerate import Accelerator
+from filelock import FileLock
+from transformers import (
+    CONFIG_MAPPING,
+    MODEL_MAPPING,
+    AdamW,
+    AutoConfig,
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    DataCollatorForSeq2Seq,
+    SchedulerType,
+    get_scheduler,
+    set_seed,
+)
+from transformers.file_utils import is_offline_mode
+from transformers.utils.versions import require_version
+
+from nltk.tokenize import TweetTokenizer
+import re
+re_art = re.compile(r'\b(a|an|the)\b')
+re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
+
+
+def normalize_answer(s):
+    # intentionally a no-op; kept as a hook for plugging in answer normalization
+    return s
+
+def clean_str(txt):
+    txt = txt.lower()
+    txt = re.sub('^', ' ', txt)
+    txt = re.sub('$', ' ', txt)
+
+    # url and tag
+    words = []
+    for word in txt.split():
+        i = word.find('http')
+        if i >= 0:
+            word = word[:i] + ' ' + '__url__'
+        words.append(word.strip())
+    txt = ' '.join(words)
+
+    # remove markdown URL
+    txt = re.sub(r'\[([^\]]*)\] \( *__url__ *\)', r'\1', txt)
+
+    # remove illegal chars, protecting the __url__ token
+    txt = re.sub('__url__', 'URL', txt)
+    txt = re.sub(r"[^A-Za-z0-9():,.!?\"\']", " ", txt)
+    txt = re.sub('URL', '__url__', txt)
+
+    # contractions
+    add_space = ["'s", "'m", "'re", "n't", "'ll", "'ve", "'d", "'em"]
+    tokenizer = TweetTokenizer(preserve_case=False)
+    txt = ' ' + ' '.join(tokenizer.tokenize(txt)) + ' '
+    txt = txt.replace(" won't ", " will n't ")
+    txt = txt.replace(" can't ", " can n't ")
+    for a in add_space:
+        txt = txt.replace(a + ' ', ' ' + a + ' ')
+
+    txt = re.sub(r'^\s+', '', txt)
+    txt = re.sub(r'\s+$', '', txt)
+    txt = re.sub(r'\s+', ' ', txt)  # collapse extra spaces
+
+    return txt
+
+logger = logging.getLogger(__name__)
+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
+
+
+import os
+from dotenv import load_dotenv
+load_dotenv()
+if os.getenv('WANDB_API_KEY') is None:
+    USE_WANDB = False
+else:
+    USE_WANDB = True
+    wandb_key = os.getenv('WANDB_API_KEY')
+
+# You should update this to your particular problem to have better documentation of `model_type`
+MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+try:
+    nltk.data.find("tokenizers/punkt")
+except (LookupError, OSError):
+    if is_offline_mode():
+        raise LookupError(
+            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
+        )
+    with FileLock(".lock") as lock:
+        nltk.download("punkt", quiet=True)
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Finetune a transformers model on an end-to-end dialog generation task")
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default=None,
+        help="The name of the dataset to use (via the datasets library).",
+    )
+    parser.add_argument(
+        "--dataset_config_name",
+        type=str,
+        default=None,
+        help="The configuration name of the dataset to use (via the datasets library).",
+    )
+    parser.add_argument(
+        "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
+    )
+    parser.add_argument(
+        "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
+    )
+    
+    parser.add_argument(
+        "--max_source_length",
+        type=int,
+        default=1024,
+        help="The maximum total input sequence length after "
+        "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
+    )
+    parser.add_argument(
+        "--source_prefix",
+        type=str,
+        default=None,
+        help="A prefix to add before every source text " "(useful for T5 models).",
+    )
+    parser.add_argument(
+        "--preprocessing_num_workers",
+        type=int,
+        default=None,
+        help="The number of processes to use for the preprocessing.",
+    )
+
+    parser.add_argument(
+        "--max_target_length",
+        type=int,
+        default=64,
+        help="The maximum total sequence length for target text after "
+        "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
+        "during ``evaluate`` and ``predict``.",
+    )
+    parser.add_argument(
+        "--val_max_target_length",
+        type=int,
+        default=None,
+        help="The maximum total sequence length for validation "
+        "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
+        "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
+        "param of ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+    )
+    parser.add_argument(
+        "--max_length",
+        type=int,
+        default=128,
+        help=(
+            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
+            " sequences shorter will be padded if `--pad_to_max_lengh` is passed."
+        ),
+    )
+    parser.add_argument(
+        "--num_beams",
+        type=int,
+        default=None,
+        help="Number of beams to use for evaluation. This argument will be "
+        "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+    )
+    parser.add_argument(
+        "--model_name_or_path",
+        type=str,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+        required=True,
+    )
+    parser.add_argument(
+        "--config_name",
+        type=str,
+        default=None,
+        help="Pretrained config name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--tokenizer_name",
+        type=str,
+        default=None,
+        help="Pretrained tokenizer name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--text_column",
+        type=str,
+        default=None,
+        help="The name of the column in the datasets containing the full texts (for summarization).",
+    )
+    parser.add_argument(
+        "--summary_column",
+        type=str,
+        default=None,
+        help="The name of the column in the datasets containing the summaries (for summarization).",
+    )
+    parser.add_argument(
+        "--use_slow_tokenizer",
+        action="store_true",
+        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
+    )
+    parser.add_argument(
+        "--per_device_train_batch_size",
+        type=int,
+        default=8,
+        help="Batch size (per device) for the training dataloader.",
+    )
+    parser.add_argument(
+        "--per_device_eval_batch_size",
+        type=int,
+        default=8,
+        help="Batch size (per device) for the evaluation dataloader.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=5e-5,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--lr_scheduler_type",
+        type=SchedulerType,
+        default="linear",
+        help="The scheduler type to use.",
+        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
+    )
+    parser.add_argument(
+        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument(
+        "--model_type",
+        type=str,
+        default=None,
+        help="Model type to use if training from scratch.",
+        choices=MODEL_TYPES,
+    )
+
+
+    parser.add_argument(
+        "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
+    )
+
+    parser.add_argument(
+        "--pad_to_max_length", type=bool, default=True, help="do pading"
+    )
+
+    parser.add_argument(
+        "--ignore_pad_token_for_loss", type=bool, default=True, help="do pading"
+    )
+
+    parser.add_argument(
+        "--logging_steps", type=int, default=500, help="do pading"
+    )
+
+    parser.add_argument(
+        "--save_steps", type=int, default=5000, help="do pading"
+    )
+
+    parser.add_argument(
+        "--save_every_checkpoint", action="store_true"
+    )
+
+    parser.add_argument(
+        "--max_grad_norm", type=float, default=1.0, help="max_grad_norm"
+    )
+
+    parser.add_argument(
+        "--exp_name",
+        type=str,
+        help="Description to the experiment",
+        default='multiwoz',
+    )
+
+    parser.add_argument(
+        "--use_special_token",
+        action="store_true",
+        help="add special token or not"
+    )
+
+    parser.add_argument(
+        "--format_version",
+        type=str, default='v1',
+        help="format version"
+    )
+
+    parser.add_argument(
+        "--wandb_exp_name",
+        type=str,
+        default='multiwoz',
+        help="Description to the experiment worksheet name",
+    )
+
+
+    args = parser.parse_args()
+
+    # Sanity checks
+    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
+        raise ValueError("Need either a dataset name or a training/validation file.")
+    else:
+        if args.train_file is not None:
+            extension = args.train_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+        if args.validation_file is not None:
+            extension = args.validation_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+
+    if args.output_dir is not None:
+        os.makedirs(args.output_dir, exist_ok=True)
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    if args.source_prefix is None and args.model_name_or_path in [
+        "t5-small",
+        "t5-base",
+        "t5-large",
+        "t5-3b",
+        "t5-11b",
+    ]:
+        logger.warning(
+            "You're running a T5 model but didn't provide a source prefix, which is expected for T5, "
+            "e.g. `--source_prefix 'summarize: '`"
+        )
+    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
+    accelerator = Accelerator()
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state)
+
+    # Setup logging, we only want one process per machine to log things on the screen.
+    # accelerator.is_local_main_process is only True for one process per machine.
+    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    if accelerator.is_local_main_process and USE_WANDB:
+        config = dict(
+            dataset_id="",
+            infra="",
+        )
+        import wandb
+        wandb.init(
+            project=args.wandb_exp_name,
+            notes="Finetuning",
+            tags=["multiwoz"],
+            config=config,
+            entity='Convlab3')
+
+        wandb.run.name = args.exp_name
+
+
+    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+    # (the dataset will be downloaded automatically from the datasets Hub).
+    #
+    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
+    # 'text' is found. You can easily tweak this behavior (see below).
+    #
+    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+    # download the dataset.
+    if args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
+    else:
+        data_files = {}
+        if args.train_file is not None:
+            data_files["train"] = args.train_file
+        if args.validation_file is not None:
+            data_files["validation"] = args.validation_file
+        extension = args.train_file.split(".")[-1]
+        raw_datasets = load_dataset(extension, data_files=data_files)
+    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+    # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+    # Load pretrained model and tokenizer
+    #
+    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
+    if args.config_name:
+        config = AutoConfig.from_pretrained(args.config_name)
+    elif args.model_name_or_path:
+        config = AutoConfig.from_pretrained(args.model_name_or_path)
+    else:
+        config = CONFIG_MAPPING[args.model_type]()
+        logger.warning("You are instantiating a new config instance from scratch.")
+
+    if args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
+    elif args.model_name_or_path:
+        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
+    else:
+        raise ValueError(
+            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+        )
+
+    if args.model_name_or_path:
+        model = AutoModelForSeq2SeqLM.from_pretrained(
+            args.model_name_or_path,
+            from_tf=bool(".ckpt" in args.model_name_or_path),
+            config=config,
+        )
+    else:
+        logger.info("Training new model from scratch")
+        model = AutoModelForSeq2SeqLM.from_config(config)
+
+    if 'blender' in args.model_name_or_path.lower():
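+        # extend BlenderBot's learned positional embeddings (tiled 4x) to support longer inputs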
+        model.model.encoder.embed_positions.weight = torch.nn.Parameter(model.model.encoder.embed_positions.weight.repeat(4,1))
+    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+    if args.use_special_token:
+        special_tokens = [i.strip() for i in open('special_tokens.txt')]
+        tokenizer.add_tokens(special_tokens)
+
+    model.resize_token_embeddings(len(tokenizer))
+    if model.config.decoder_start_token_id is None:
+        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
+
+    prefix = args.source_prefix if args.source_prefix is not None else ""
+    max_length = args.max_length
+    padding = "max_length" if args.pad_to_max_length else False
+    max_target_length = args.max_target_length
+    def preprocess_function(examples):
+        contextes = examples['Context']
+        responses = examples['Response']
+        kbs = examples['Knowledge']
+
+        responses_labels = []
+        inputs = []
+
+        for context, response, kb in zip(contextes, responses, kbs):
+            if args.format_version == 'v1':
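+                # keep only the last 10 turns of context; the target prepends the flattened
+                # belief state ('Belief: <kb>') to the delexicalized response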
+                _input = ' EOS '.join(context.split(' EOS ')[-10:])
+                _response = 'Belief: ' + kb + response
+                inputs.append(_input)
+                responses_labels.append(_response)
+
+        model_inputs = tokenizer(inputs, max_length=args.max_length, padding=padding, truncation=True)
+
+        # Setup the tokenizer for targets
+        with tokenizer.as_target_tokenizer():
+            labels = tokenizer(responses_labels, max_length=max_target_length, padding=padding, truncation=True)
+
+        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
+        # padding in the loss.
+        if padding == "max_length" and args.ignore_pad_token_for_loss:
+            labels["labels"] = [
+                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
+            ]
+
+        model_inputs["labels"] = labels["labels"]
+        return model_inputs
+
+    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
+    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
+    # to preprocess.
+    #
+    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+    
+    column_names = ['Context', 'Response', 'Knowledge', 'Dataset']
+    lm_datasets = raw_datasets.map(
+        preprocess_function,
+        batched=True,
+        remove_columns=column_names,
+        num_proc=args.preprocessing_num_workers,
+        load_from_cache_file=False,
+        desc=f"Processing dataset",
+    )
+
+    train_dataset = lm_datasets["test"]
+    eval_dataset = lm_datasets["validation"]
+    test_dataset = lm_datasets["test"]
+
+    # Log a few random samples from the training set:
+    for index in random.sample(range(len(train_dataset)), 1):
+        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+
+
+    label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer,
+        model=model,
+        label_pad_token_id=label_pad_token_id,
+        pad_to_multiple_of=8 if accelerator.use_fp16 else None,
+    )
+
+    def postprocess_text(preds, labels):
+        preds = [normalize_answer(pred.strip().replace('Agent :','')) for pred in preds]
+        labels = [normalize_answer(label.strip().replace('Agent :','')) for label in labels]
+
+        # rougeLSum expects newline after each sentence
+        # preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+        # labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
+
+        return preds, labels
+
+    train_dataloader = DataLoader(
+        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
+    )
+    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
+    test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
+
+    # Optimizer
+    # Split weights in two groups, one with weight decay and the other not.
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+
+    # Prepare everything with our `accelerator`.
+    model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
+        model, optimizer, train_dataloader, eval_dataloader, test_dataloader
+    )
+
+    # Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
+    # shorter in multiprocess)
+
+    # Scheduler and math around the number of training steps.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    else:
+        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    lr_scheduler = get_scheduler(
+        name=args.lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=args.num_warmup_steps,
+        num_training_steps=args.max_train_steps,
+    )
+
+    # Metric
+
+    # Train!
+    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    completed_steps = 0
+    global_steps = 0
+    tr_loss, logging_loss = 0.0, 0.0
+    for epoch in range(args.num_train_epochs):
+        model.train()
+        
+        for step, batch in enumerate(train_dataloader):
+            
+            global_steps += 1            
+            outputs = model(**batch)
+            loss = outputs.loss
+            loss = loss / args.gradient_accumulation_steps
+            tr_loss += loss.item()
+            accelerator.backward(loss)            
+            
+            if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+                completed_steps += 1
+
+            if completed_steps >= args.max_train_steps:
+                break
+
+            if step % args.logging_steps == 0:
+                logger.info(f"  EVALERR:  {(tr_loss - logging_loss)/float(args.logging_steps)}")
+                if accelerator.is_local_main_process and USE_WANDB:
+                    wandb.log({'loss': tr_loss - logging_loss})
+                logging_loss = tr_loss
+                progress_bar.update(args.logging_steps)
+
+            if args.output_dir is not None and global_steps % args.save_steps == 0 and global_steps > 0:
+                accelerator.wait_for_everyone()
+                if accelerator.is_local_main_process:               
+                    checkpoint_prefix = 'checkpoint'
+                    output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_steps))
+                    if not os.path.exists(output_dir):
+                        os.makedirs(output_dir)
+                    unwrapped_model = accelerator.unwrap_model(model)
+                    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
+
+                    tokenizer.save_pretrained(output_dir)
+                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
+                    logger.info("Saving model checkpoint to %s", output_dir)
+
+        model.eval()
+        if args.val_max_target_length is None:
+            args.val_max_target_length = args.max_target_length
+
+        gen_kwargs = {
+            "max_length": args.val_max_target_length if args is not None else config.max_length,
+            "num_beams": args.num_beams,
+        }
+
+        def chunks(lst, n):
+            for i in range(0, len(lst), n):
+                yield lst[i:i + n]
+
+        metric = load_metric("./rouge_metric.py")
+        metric_bleu = load_metric("./bleu_metric.py")
+        decoded_preds_all = []
+        for step, batch in enumerate(eval_dataloader):
+            with torch.no_grad():
+                generated_tokens = accelerator.unwrap_model(model).generate(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    **gen_kwargs,
+                )
+
+                generated_tokens = accelerator.pad_across_processes(
+                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
+                )
+                labels = batch["labels"]
+                if not args.pad_to_max_length:
+                    # If we did not pad to max length, we need to pad the labels too
+                    labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
+
+                generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
+                labels = accelerator.gather(labels).cpu().numpy()
+
+                if args.ignore_pad_token_for_loss:
+                    # Replace -100 in the labels as we can't decode them.
+                    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+                if isinstance(generated_tokens, tuple):
+                    generated_tokens = generated_tokens[0]
+                decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+                metric.add_batch(predictions=decoded_preds, references=decoded_labels)
+                _decoded_preds = [i.split() for i in decoded_preds]
+                _decoded_labels = [[i.split()] for i in decoded_labels]
+                decoded_preds_all.extend(decoded_preds)
+                metric_bleu.add_batch(predictions=_decoded_preds, references=_decoded_labels)
+            
+                
+        result = metric.compute(use_stemmer=True)
+        # Extract a few results from ROUGE
+        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
+
+        result = {k: round(v, 4) for k, v in result.items()}
+
+        logger.info(result)
+
+        result_bleu = metric_bleu.compute()
+        logger.info(result_bleu)
+
+        accelerator.wait_for_everyone()
+        if accelerator.is_local_main_process and USE_WANDB:
+            wandb.log({'valid_bleu': result_bleu['bleu']})
+            wandb.log({'valid_rouge': result['rougeL']})
+
+        
+        if args.output_dir is not None:
+            accelerator.wait_for_everyone()
+            if accelerator.is_local_main_process:               
+                if not os.path.exists(args.output_dir):
+                    os.makedirs(args.output_dir)
+                output_dir_file_name = os.path.join(args.output_dir, 'valid-step-{}'.format(completed_steps))
+                print(output_dir_file_name)
+                with open(output_dir_file_name, 'w') as f:
+                    json.dump(decoded_preds_all, f, indent=2)
+                logger.info("Saving model outputs to %s", output_dir_file_name)
+        
+        metric = load_metric("rouge")
+        metric_bleu = load_metric("bleu")
+
+        gen_kwargs = {
+            "max_length": args.val_max_target_length if args is not None else config.max_length,
+            "num_beams": args.num_beams,
+        }
+        decoded_preds_all = []
+        for step, batch in enumerate(test_dataloader):
+            with torch.no_grad():
+                generated_tokens = accelerator.unwrap_model(model).generate(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    **gen_kwargs,
+                )
+
+                generated_tokens = accelerator.pad_across_processes(
+                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
+                )
+                labels = batch["labels"]
+                if not args.pad_to_max_length:
+                    # If we did not pad to max length, we need to pad the labels too
+                    labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
+
+                generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
+                labels = accelerator.gather(labels).cpu().numpy()
+
+                if args.ignore_pad_token_for_loss:
+                    # Replace -100 in the labels as we can't decode them.
+                    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+                if isinstance(generated_tokens, tuple):
+                    generated_tokens = generated_tokens[0]
+                decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+                decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
+                metric.add_batch(predictions=decoded_preds, references=decoded_labels)
+                _decoded_preds = [i.split() for i in decoded_preds]
+                _decoded_labels = [[i.split()] for i in decoded_labels]
+                decoded_preds_all.extend(decoded_preds)
+                metric_bleu.add_batch(predictions=_decoded_preds, references=_decoded_labels)
+                
+        result = metric.compute(use_stemmer=True)
+        # Extract a few results from ROUGE
+        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
+
+        result = {k: round(v, 4) for k, v in result.items()}
+
+        logger.info(result)
+
+        result_bleu = metric_bleu.compute()
+        logger.info(result_bleu)
+
+        accelerator.wait_for_everyone()
+        if accelerator.is_local_main_process and USE_WANDB:
+            wandb.log({'test_bleu': result_bleu['bleu']})
+            wandb.log({'test_rouge': result['rougeL']})
+
+        if args.output_dir is not None:
+            accelerator.wait_for_everyone()
+            if accelerator.is_local_main_process:               
+                if not os.path.exists(args.output_dir):
+                    os.makedirs(args.output_dir)
+                output_dir_file_name = os.path.join(args.output_dir, 'test-step-{}'.format(completed_steps))
+                print(output_dir_file_name)
+                with open(output_dir_file_name, 'w') as f:
+                    json.dump(decoded_preds_all, f, indent=2)
+                logger.info("Saving model outputs to %s", output_dir_file_name)
+
+        
+        if args.output_dir is not None and args.save_every_checkpoint:
+            accelerator.wait_for_everyone()
+            if accelerator.is_local_main_process:               
+                checkpoint_prefix = 'checkpoint'
+                output_dir = os.path.join(args.output_dir, '{}-epoch-{}'.format(checkpoint_prefix, epoch))
+                if not os.path.exists(output_dir):
+                    os.makedirs(output_dir)
+                unwrapped_model = accelerator.unwrap_model(model)
+                unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
+
+                tokenizer.save_pretrained(output_dir)
+                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
+                logger.info("Saving model checkpoint to %s", output_dir)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index aff31be6742be3908d5ff4ab65e141b3427471d9..32e98234fe0b652d5783237aec4fce16ce7da9c1 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -148,6 +148,7 @@ def load_unified_data(
         dialogue_acts=False, 
         state=False, 
         db_results=False,
+        delex_utterance=False,
         use_context=False, 
         context_window_size=0, 
         terminated=False, 
@@ -182,7 +183,7 @@ def load_unified_data(
     data_splits = dataset.keys() if data_split == 'all' else [data_split]
     assert speaker in ['user', 'system', 'all']
     assert not use_context or context_window_size > 0
-    info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results']))
+    info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results', 'delex_utterance']))
     info_list += ['utt_idx']
     data_by_split = {}
     for data_split in data_splits:
@@ -426,7 +427,12 @@ def create_delex_data(dataset, delex_func=lambda d,s,v: f'[({d})-({s})]', ignore
                             for value in values.split('|'):
                                 if value.lower() not in ignore_values:
                                     placeholder = delex_func(domain, slot, value)
-                                    pattern = re.compile(r'\b({})\b'.format(value), flags=re.I)
+                                    # escape regex metacharacters (e.g. '?') in the value so
+                                    # the pattern always compiles
+                                    pattern = re.compile(r'\b({})\b'.format(re.escape(value)), flags=re.I)
                                     if delex_inplace(delex_utt, pattern):
                                         delex_vocab.add(placeholder)