diff --git a/README.md b/README.md
index 4e4a7abff996d08aa72359151c12490d03725e12..4f20e80a64732f56211872f43bf31653ce761de5 100755
--- a/README.md
+++ b/README.md
@@ -65,15 +65,10 @@ docker exec -it CONTAINER_ID bash
 
 ## Tutorials
 
-| Section                                                      | Description |
-| ------------------------------------------------------------ | ----------- |
-| [Getting Started](https://github.com/thu-coai/ConvLab-2/blob/master/tutorials/Getting_Started.ipynb) (Have a try on [Colab](https://colab.research.google.com/github/thu-coai/ConvLab-2/blob/master/tutorials/Getting_Started.ipynb)!) |             |
-| [Unified Data Format](https://github.com/ConvLab/ConvLab-3/tree/master/data/unified_datasets) |             |
-| [Utility functions for unified datasets](https://github.com/ConvLab/ConvLab-3/blob/master/convlab/util/unified_datasets_util.py) |             |
-| [RL Toolkit](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/policy) |             |
-| [How to add a new dataset](https://github.com/thu-coai/ConvLab-2/blob/master/tutorials/Add_New_Model.md) |             |
-| How to add a new model                                       |             |
-| [Interactive Tool](https://github.com/ConvLab/ConvLab-3/blob/master/deploy) [[demo video]](https://youtu.be/00VWzbcx26E) |             |
+- [Introduction to Unified Data Format](https://github.com/ConvLab/ConvLab-3/tree/master/data/unified_datasets)
+- [Utility functions for unified datasets](https://github.com/ConvLab/ConvLab-3/blob/master/convlab/util/unified_datasets_util.py)
+- [RL Toolkit](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/policy)
+- [Interactive Tool](https://github.com/ConvLab/ConvLab-3/blob/master/deploy) [[demo video]](https://youtu.be/00VWzbcx26E)
 
 ## Unified Datasets
 
@@ -112,10 +107,6 @@ We list newly integrated models in ConvLab-3 that support unified data format an
 
 Trained models are available on [Hugging Face Hub](https://huggingface.co/ConvLab).
 
-## Code structure
-
-
-
 ## Contributing
 
 We welcome contributions from community. Please see issues to find what we need.
@@ -131,15 +122,6 @@ We would like to thank all contributors of ConvLab:
 
 Yan Fang, Zhuoer Feng, Jianfeng Gao, Qihan Guo, Kaili Huang, Minlie Huang, Sungjin Lee, Bing Li, Jinchao Li, Xiang Li, Xiujun Li, Jiexi Liu, Lingxiao Luo, Wenchang Ma, Mehrad Moradshahi, Baolin Peng, Runze Liang, Ryuichi Takanobu, Dazhen Wan, Hongru Wang, Jiaxin Wen, Yaoqin Zhang, Zheng Zhang, Qi Zhu, Xiaoyan Zhu, Carel van Niekerk, Christian Geishauser, Hsien-chin Lin, Nurul Lubis, Xiaochen Zhu, Michael Heck, Shutong Feng, Milica Gašić.
 
-
-## Citing
-
-If you use ConvLab-3 in your research, please cite:
-
-```
-
-```
-
 ## License
 
 Apache License 2.0
diff --git a/convlab/e2e/soloist/README.md b/convlab/e2e/soloist/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..12367222d48670054855f5592f369fad391f1be8
--- /dev/null
+++ b/convlab/e2e/soloist/README.md
@@ -0,0 +1,95 @@
+# SOLOIST
+
+On top of pre-trained LMs, SOLOIST subsumes the different components of a task-oriented dialog system into a single model and employs a pre-training-then-fine-tuning scheme to build task bots.
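+
+Concretely, each training example is a text-to-text pair: the source is the dialog history (turns joined by ` EOS `) and the target is the belief state followed by the delexicalized response (see `train.py`). An illustrative pair, with made-up values:
+
+```
+Source: i am looking for a cheap hotel . EOS ...
+Target: Belief: hotel pricerange is cheap Agent: i found [hotel_choice] hotels for you .
+```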
+
+## Usage
+
+Follow the instructions under each dataset's directory to prepare the data for training and evaluation.
+
+#### Dataset Creation
+Create datasets for the three settings (`joint`, `transfer`, `single`):
+```sh
+$ cd multiwoz
+$ python script/create_dataset.py joint
+$ python script/create_dataset.py transfer
+$ python script/create_dataset.py single
+```
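+
+Each line of the generated `.jsonl` files is one example with `Context`, `Knowledge`, `Response` and `Dataset` fields. A minimal inspection sketch (assuming the `single` setting has been created):
+
+```python
+import jsonlines
+
+# Read the first example from the generated training file.
+with jsonlines.open('./multiwoz/data/single_train.jsonl') as reader:
+    example = next(iter(reader))
+
+print(example['Context'])    # dialog history, turns joined by ' EOS '
+print(example['Knowledge'])  # belief state string, e.g. 'hotel area is centre ; ...'
+print(example['Response'])   # delexicalized response, prefixed with 'Agent: '
+print(example['Dataset'])    # source dataset name, e.g. 'multiwoz21'
+```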
+
+#### Train a model
+
+```sh
+$ python train.py --model_name_or_path t5-base --dataset_name e2e_dataloader.py --output_dir ./model --per_device_train_batch_size=2 --per_device_eval_batch_size=2 --max_target_length 128 --max_length 512 --num_train_epochs 50 --save_steps 10000 --preprocessing_num_workers 1 --num_beams 5 --learning_rate 5e-5 --dataset_config_name SINGLE --logging_steps 100
+```
+
+The model (`pytorch_model.bin`) will be saved under `output_dir`. The script also saves predictions on the validation/test sets every epoch.
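+
+A minimal generation sketch with a saved checkpoint (the checkpoint path and decoding settings below are illustrative, not prescribed by this repo):
+
+```python
+import torch
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+checkpoint = './model/checkpoint-10000'  # train.py saves checkpoint-{step} directories
+tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
+
+context = 'i am looking for a cheap hotel in the centre .'
+inputs = tokenizer(context, return_tensors='pt', truncation=True, max_length=512)
+with torch.no_grad():
+    output = model.generate(**inputs, max_length=128, num_beams=5)
+print(tokenizer.batch_decode(output, skip_special_tokens=True)[0])
+```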
+
+#### Test a model
+
+Results are saved under `output_dir`. Evaluation relies on a third-party package; please follow the instructions at https://github.com/Tomiinek/MultiWOZ_Evaluation
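+
+For MultiWOZ, the `mwzeval` package from that repository scores predictions given as a mapping from dialog id to a list of response turns. A sketch following its README (the API belongs to that package, not this repo; double-check there):
+
+```python
+from mwzeval.metrics import Evaluator
+
+# Keys are lowercased MultiWOZ dialog ids without '.json'; values hold one
+# dict per system turn. Shown with a single made-up prediction.
+my_predictions = {'pmul4462': [{'response': 'there are [choice] trains , is there a time you prefer ?'}]}
+
+e = Evaluator(bleu=True, success=True, richness=True)
+print(e.evaluate(my_predictions))
+```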
+
+
+## Performance on unified format datasets of different settings
+
+Note that we use almost the same hyper-parameters for the different settings, which may not be optimal.
+
+<table>
+<thead>
+  <tr>
+    <th></th>
+    <th colspan=2>MultiWOZ 2.1</th>
+    <th colspan=2>SGD</th>
+    <th colspan=2>Taskmaster-1</th>
+  </tr>
+  <tr>
+    <th>Model</th>
+    <th>Combined</th><th>BLEU</th>
+    <th>Slot F1</th><th>BLEU</th>
+    <th>Slot F1</th><th>BLEU</th>
+  </tr>
+</thead>
+<tbody>
+  <tr>
+    <td>SOLOIST w/o pre-training</td>
+    <td>67.0</td><td>16.8</td>
+    <td>56.9</td><td>11.2</td>
+    <td>8.5</td><td>28.0</td>    
+  </tr>
+  <tr>
+    <td>SOLOIST</td>
+    <td>71.4</td><td>17.1</td>
+    <td>69.7</td><td>23.1</td>
+    <td>9.2</td><td>29.2</td>
+  </tr>
+</tbody>
+</table>
+
+- Combined: MultiWOZ combined score, (Inform + Success) / 2 + BLEU, as computed by the evaluation package above.
+- Slot F1: F1 measure of the delexicalized slot predictions over the corpus.
+
+## References
+
+```
+@article{peng2021soloist,
+  title={Soloist: Building task bots at scale with transfer learning and machine teaching},
+  author={Peng, Baolin and Li, Chunyuan and Li, Jinchao and Shayandeh, Shahin and Liden, Lars and Gao, Jianfeng},
+  journal={Transactions of the Association for Computational Linguistics},
+  volume={9},
+  pages={807--824},
+  year={2021},
+  publisher={MIT Press}
+}
+@article{nekvinda2021shades,
+  title={Shades of BLEU, flavours of success: The case of MultiWOZ},
+  author={Nekvinda, Tom{\'a}{\v{s}} and Du{\v{s}}ek, Ond{\v{r}}ej},
+  journal={arXiv preprint arXiv:2106.05555},
+  year={2021}
+}
+@article{peng2022godel,
+  title={GODEL: Large-Scale Pre-Training for Goal-Directed Dialog},
+  author={Peng, Baolin and Galley, Michel and He, Pengcheng and Brockett, Chris and Liden, Lars and Nouri, Elnaz and Yu, Zhou and Dolan, Bill and Gao, Jianfeng},
+  journal={arXiv preprint arXiv:2206.11309},
+  year={2022}
+}
+```
\ No newline at end of file
diff --git a/convlab/e2e/soloist/e2e_dataloader.py b/convlab/e2e/soloist/e2e_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac7be4d2e3d13f79e10f824505c8e1ab33ff4f35
--- /dev/null
+++ b/convlab/e2e/soloist/e2e_dataloader.py
@@ -0,0 +1,124 @@
+# coding=utf-8
+# Copyright 2020 HuggingFace Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Corpus for E2E Dialog Modeling"""
+
+import random
+
+import datasets
+import jsonlines
+
+
+_DESCRIPTION = """\
+E2E Dialog Modeling
+"""
+
+_CITATION = """\
+E2E Dialog Modeling
+"""
+
+_DOWNLOAD_URL = ""
+_WEBPAGE = ""
+
+class UnifiedDialogConfig(datasets.BuilderConfig):
+    """BuilderConfig for SuperGLUE."""
+
+    def __init__(self, data_name, **kwargs):
+        """BuilderConfig for SuperGLUE.
+        Args:
+          features: `list[string]`, list of the features that will appear in the
+            feature dict. Should not include "label".
+          data_url: `string`, url to download the zip file from.
+          citation: `string`, citation for the data set.
+          url: `string`, url for information about the data set.
+          label_classes: `list[string]`, the list of classes for the label if the
+            label is present as a string. Non-string labels will be cast to either
+            'False' or 'True'.
+          **kwargs: keyword arguments forwarded to super.
+        """
+        # Version history:
+        # 1.0.2: Fixed non-nondeterminism in ReCoRD.
+        # 1.0.1: Change from the pre-release trial version of SuperGLUE (v1.9) to
+        #        the full release (v2.0).
+        # 1.0.0: S3 (new shuffling, sharding and slicing mechanism).
+        # 0.0.2: Initial version.
+        super(UnifiedDialogConfig, self).__init__(version=datasets.Version("1.0.2"), **kwargs)
+        self.data_name = data_name
+        
+
+
+class Summarization(datasets.GeneratorBasedBuilder):
+    """Dataset builder for E2E dialog modeling."""
+
+    BUILDER_CONFIGS = [
+        UnifiedDialogConfig(name='JOINT', data_name='joint'),
+        UnifiedDialogConfig(name='TRANSFER', data_name='transfer'),
+        UnifiedDialogConfig(name='SINGLE', data_name='single'),
+    ]
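+    # Usage sketch (after creating the data per the README):
+    #   from datasets import load_dataset
+    #   raw = load_dataset('e2e_dataloader.py', 'SINGLE')
+    #   raw['train'][0]  # {'Context': ..., 'Knowledge': ..., 'Response': ..., 'Dataset': ...}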
+    
+    
+    random.seed(2022)
+
+    def _info(self):
+        return datasets.DatasetInfo(
+            description=_DESCRIPTION,
+            features=datasets.Features(
+                {
+                    "Context": datasets.Value("string"),
+                    "Knowledge": datasets.Value("string"),
+                    "Response": datasets.Value("string"),
+                    "Dataset": datasets.Value("string"),
+                }
+            ),
+            homepage=_WEBPAGE,
+            citation=_CITATION,
+        )
+
+    def _split_generators(self, dl_manager):
+
+        data_name = self.config.data_name
+
+        # All settings share the MultiWOZ validation/test splits; only the
+        # training data differs (see multiwoz/script/create_dataset.py).
+        if data_name == 'joint':
+            train_path = './multiwoz/data/joint_train.jsonl'
+        elif data_name == 'transfer':
+            train_path = './multiwoz/data/transfer_train.jsonl'
+        elif data_name == 'single':
+            train_path = './multiwoz/data/single_train.jsonl'
+        else:
+            raise ValueError('Please specify a dataset config (JOINT, TRANSFER or SINGLE).')
+        validation_path = './multiwoz/data/single_validation.jsonl'
+        test_path = './multiwoz/data/single_test.jsonl'
+
+        return [
+            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": train_path}),
+            datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": validation_path}),
+            datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": test_path}),
+        ]
+
+    def _generate_examples(self, filepath):
+        # Each line of the .jsonl file is one example; the line index is the key.
+        with open(filepath, "r", encoding="utf-8") as reader:
+            for key, item in enumerate(jsonlines.Reader(reader)):
+                yield key, item
\ No newline at end of file
diff --git a/convlab/e2e/soloist/multiwoz/script/create_dataset.py b/convlab/e2e/soloist/multiwoz/script/create_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2426af9f6100c07caee4a67d89231cb70481bf8a
--- /dev/null
+++ b/convlab/e2e/soloist/multiwoz/script/create_dataset.py
@@ -0,0 +1,74 @@
+import jsonlines
+import copy
+import fire
+
+from convlab.util.unified_datasets_util import create_delex_data, load_dataset
+from convlab.util import load_e2e_data
+
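+# Render the belief state as 'slot is value' pairs joined by ' ; ' within a
+# domain and by ' | ' across domains, e.g. (illustrative values)
+# 'hotel area is centre ; pricerange is cheap | restaurant food is italian'.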
+def state_to_string(state):
+    domain_str = []
+    for domain,svs in state.items():
+        svs_str = []
+        for s,v in svs.items():
+            if v != '':
+                svs_str.append(f'{s} is {v}')
+        svs_str = ' ; '.join(svs_str)
+        if svs_str != '':
+            domain_str.append(f'{domain} {svs_str}')
+    domain_str = ' | '.join(domain_str)
+    return domain_str
+
+def context_to_string(context):
+    # Flatten the dialog history into one string, turns joined by ' EOS '.
+    return ' EOS '.join(i['utterance'].strip() for i in context)
+
+def delex_function(d, s, v):
+    # Replace the value of slot `s` in domain `d` with a '[domain_slot]'
+    # placeholder (used by create_delex_data).
+    s = s.replace(' ', '')
+    return f'[{d}_{s}]'
+
+def create_dataset(mode='joint'):
+    dataset_list = {
+        'joint': ['tm1','sgd','multiwoz21'],
+        'transfer': ['tm1','sgd'],
+        'single': ['multiwoz21']
+    }
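+    # 'joint' trains on all three corpora, 'transfer' holds MultiWOZ out of
+    # training, and 'single' uses MultiWOZ 2.1 alone; validation/test always
+    # come from MultiWOZ (see e2e_dataloader.py).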
+
+    examples = []
+    for _data in dataset_list[mode]:
+        
+        dataset = load_dataset(_data)
+        dataset, delex_vocab = create_delex_data(dataset, delex_func=delex_function)
+        e2e_data = load_e2e_data(dataset, delex_utterance = True)
+        
+        split_list = ['train','validation','test'] if mode == 'single' else ['train']
+        
+        for split in split_list:
+            data = e2e_data[split]
+            for i in data:
+                response = i['delex_utterance'].strip()
+                context = i['context']
+                context = context_to_string(context)
+                
+                example = {}
+                example['Context'] = context
+                try:
+                    knowledge = state_to_string(i['context'][-1]['state'])
+                except Exception:
+                    knowledge = ''
+                example['Knowledge'] = knowledge
+                example['Response'] = 'Agent: ' + response.strip()
+                example['Dataset'] = f'{_data}'
+                examples.append(copy.copy(example))
+            if mode == 'single':
+                with jsonlines.open(f'./data/{mode}_{split}.jsonl', "w") as writer:
+                    for item in examples:
+                        writer.write(item)
+                examples = []
+    if mode != 'single':  
+        with jsonlines.open(f'./data/{mode}_train.jsonl', "w") as writer:
+            for item in examples:
+                writer.write(item)
+            
+if __name__ == '__main__':
+    fire.Fire(create_dataset)
\ No newline at end of file
diff --git a/convlab/e2e/soloist/multiwoz/script/create_mwoz_e2e_json.py b/convlab/e2e/soloist/multiwoz/script/create_mwoz_e2e_json.py
deleted file mode 100644
index 5953c19f50e410c8c37002309904de264ae01602..0000000000000000000000000000000000000000
--- a/convlab/e2e/soloist/multiwoz/script/create_mwoz_e2e_json.py
+++ /dev/null
@@ -1,136 +0,0 @@
-import jsonlines
-import json,copy
-fidx = open('test.idx.txt','w')
-
-data = json.load(open('data/test.json'))
-examples = []
-for i in data:
-    name = i['file'].lower()   
-    history = [] 
-    for turn in i['info']:
-        history.append(turn['user_orig'])
-
-        bs = turn['BS']
-        bs_str = []
-        for domain, states in bs.items():
-            domain_str = []
-            for state in states:
-                domain_str.append(f'{state[0]} = {state[1]}')
-            domain_str = ' ; '.join(domain_str)
-            bs_str.append(domain + ' ' + domain_str)
-        bs_str = ' | '.join(bs_str)
-
-        db_str = 'kb '
-        db = turn['KB']
-        if db == 0:
-            db_str += 'zero'
-        elif db_str == 1:
-            db_str += 'one'
-        elif db_str == 2:
-            db_str += 'two'
-        else:
-            db_str += 'more than two'
-
-        act_seq = ' '.join(turn['act'].keys())
-        example = {}
-        example['Context'] = ' EOS '.join(history[:])
-        example['Knowledge'] = ''
-        example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip()
-
-        history.append(turn['sys'].strip())
-        examples.append(copy.copy(example))
-        fidx.write(name + '\n')
-
-writer =  jsonlines.open('multiwoz_test_e2e.jsonl', mode='w')
-for i in examples:    
-    writer.write(i)
-
-
-data = json.load(open('data/val.json'))
-examples = []
-for i in data:
-    name = i['file'].lower()   
-    history = [] 
-    for turn in i['info']:
-        history.append(turn['user_orig'])
-
-
-        bs = turn['BS']
-        bs_str = []
-        for domain, states in bs.items():
-            domain_str = []
-            for state in states:
-                domain_str.append(f'{state[0]} = {state[1]}')
-            domain_str = ' ; '.join(domain_str)
-            bs_str.append(domain + ' ' + domain_str)
-        bs_str = ' | '.join(bs_str)
-
-        db_str = 'kb '
-        db = turn['KB']
-        if db == 0:
-            db_str += 'zero'
-        elif db_str == 1:
-            db_str += 'one'
-        elif db_str == 2:
-            db_str += 'two'
-        else:
-            db_str += 'more than two'
-
-        act_seq = ' '.join(turn['act'].keys())
-        example = {}
-        example['Context'] = ' EOS '.join(history[:])
-        example['Knowledge'] = ''
-        example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip()
-
-        history.append(turn['sys'].strip())
-        examples.append(copy.copy(example))
-        # fidx.write(name + '\n')
-
-writer =  jsonlines.open('multiwoz_valid_e2e.jsonl', mode='w')
-for i in examples:    
-    writer.write(i)
-
-
-data = json.load(open('data/train.json'))
-examples = []
-for i in data:
-    name = i['file'].lower()   
-    history = [] 
-    for turn in i['info']:
-        history.append(turn['user_orig'])
-
-
-        bs = turn['BS']
-        bs_str = []
-        for domain, states in bs.items():
-            domain_str = []
-            for state in states:
-                domain_str.append(f'{state[0]} = {state[1]}')
-            domain_str = ' ; '.join(domain_str)
-            bs_str.append(domain + ' ' + domain_str)
-        bs_str = ' | '.join(bs_str)
-
-        db_str = 'kb '
-        db = turn['KB']
-        if db == 0:
-            db_str += 'zero'
-        elif db_str == 1:
-            db_str += 'one'
-        elif db_str == 2:
-            db_str += 'two'
-        else:
-            db_str += 'more than two'
-
-        act_seq = ' '.join(turn['act'].keys())
-        example = {}
-        example['Context'] = ' EOS '.join(history[:])
-        example['Knowledge'] = ''
-        example['Response'] = 'belief : ' + bs_str + ' EOS ' + turn['sys'].strip()
-
-        history.append(turn['sys'].strip())
-        examples.append(copy.copy(example))
-        # fidx.write(name + '\n')
-
-writer =  jsonlines.open('multiwoz_train_e2e.jsonl', mode='w')
-for i in examples:    
-    writer.write(i)
diff --git a/convlab/e2e/soloist/multiwoz/soloist_net.py b/convlab/e2e/soloist/multiwoz/soloist_net.py
deleted file mode 100644
index 45f98200bc7eaf5387f796c8d29d3bb0555ac9e0..0000000000000000000000000000000000000000
--- a/convlab/e2e/soloist/multiwoz/soloist_net.py
+++ /dev/null
@@ -1,277 +0,0 @@
-import argparse
-import logging
-import math
-import os
-import random
-
-import datasets
-import nltk
-import numpy as np
-import torch
-from datasets import load_dataset, load_metric
-from torch.utils.data.dataloader import DataLoader
-from tqdm.auto import tqdm
-
-import transformers
-from accelerate import Accelerator
-from filelock import FileLock
-from transformers import (
-    CONFIG_MAPPING,
-    MODEL_MAPPING,
-    AdamW,
-    AutoConfig,
-    AutoModelForSeq2SeqLM,
-    AutoTokenizer,
-    DataCollatorForSeq2Seq,
-    SchedulerType,
-    get_scheduler,
-    set_seed,
-)
-from transformers.file_utils import is_offline_mode
-from transformers.utils.versions import require_version
-
-import copy, operator
-from queue import PriorityQueue
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import nn
-from torch.autograd import Variable
-from torch.distributions import Categorical
-from convlab.e2e.soloist.multiwoz.config import global_config as cfg
-
-logger = logging.getLogger(__name__)
-logging.basicConfig(
-        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-        datefmt="%m/%d/%Y %H:%M:%S",
-        level=logging.INFO,
-    )
-
-def cuda_(var):
-    return var.cuda() if cfg.cuda and torch.cuda.is_available() else var
-
-
-def tensor(var):
-    return cuda_(torch.tensor(var))
-
-class SOLOIST:
-
-    def __init__(self) -> None:
-        
-        self.config = AutoConfig.from_pretrained(cfg.model_name_or_path)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_name_or_path,config=self.config)
-        self.tokenizer = AutoTokenizer.from_pretrained('t5-base')
-        print('model loaded!')
-
-        self.model = self.model.cuda() if torch.cuda.is_available() else self.model
-
-    def generate(self, inputs):
-
-        self.model.eval()
-        inputs = self.tokenizer([inputs])
-        input_ids = tensor(inputs['input_ids'])
-        # generated_tokens = self.model.generate(input_ids = input_ids, max_length = cfg.max_length, num_beams = cfg.num_beams)
-        generated_tokens = self.model.generate(input_ids = input_ids, max_length = cfg.max_length, top_p=cfg.top_p)
-        decoded_preds = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-
-        return decoded_preds[0]
-
-    
-    def train_loop(self):
-
-        def preprocess_function(examples):
-            contextes = examples['Context']
-            responses = examples['Response']
-            belief = examples['Belief']
-            responses_labels = []
-            inputs = []
-
-            for context, response, kb in zip(contextes, responses, belief):
-                if cfg.no_kb:
-                    inputs.append(context + ' => ')
-                else:
-
-                    if cfg.format_version == 'e2e':
-                        context = ' EOS '.join(context.split(' EOS ')[-10:])
-                        _input = context
-                    
-                    if cfg.format_version == 'e2e+lm':
-                        context = ' EOS '.join(context.split(' EOS ')[-10:])
-                        inputs.append('[E2E] ' + context)
-                        responses_labels.append(response)
-                        inputs.append('[LM] ' + context )
-                        responses_labels.append(response.split(' EOS ')[1])
-                        continue
-                    
-                    if cfg.format_version == 'v2':
-                        _input = kb + context
-                    
-                    if cfg.format_version == 'v3':
-                        _input = ''
-                        context = context.split(' EOS ')
-                        for idx, turn in enumerate(context):
-                            if idx % 2 == 0:
-                                _input += 'user : ' + turn.strip()
-                            else:
-                                _input += ' system : ' + turn.strip()
-                        _input = _input + ' <|Knowledge|> ' + kb
-
-                    if cfg.format_version == 'v4':
-                        _input = ''
-                        context = context.split(' EOS ')
-                        for idx, turn in enumerate(context):
-                            if idx % 2 == 0:
-                                _input += 'user : ' + turn.strip()
-                            else:
-                                _input += ' system : ' + turn.strip()
-                        _input = kb + _input
-                    
-                    inputs.append(_input)
-                    responses_labels.append(response)
-            model_inputs = self.tokenizer(inputs, max_length=cfg.max_length, padding="max_length", truncation=True)
-
-            
-            with self.tokenizer.as_target_tokenizer():
-                labels = self.tokenizer(responses_labels, max_length=cfg.max_target_length, padding="max_length", truncation=True)
-
-            
-            if cfg.ignore_pad_token_for_loss:
-                labels["labels"] = [
-                    [(l if l != self.tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
-                ]
-
-            model_inputs["labels"] = labels["labels"]
-            return model_inputs
-
-        raw_datasets = load_dataset(cfg.dataset_name)
-        column_names = ['Context','Response','Belief']
-        lm_datasets = raw_datasets.map(
-            preprocess_function,
-            batched=True,
-            remove_columns=column_names,
-            num_proc=cfg.preprocessing_num_workers,
-            load_from_cache_file=False,
-            desc=f"Processing dataset",
-        )
-
-        train_dataset = lm_datasets["test"]
-        # train_dataset = lm_datasets["validation"]
-        eval_dataset = lm_datasets["test"]
-        test_dataset = lm_datasets["test"]
-        for index in random.sample(range(len(train_dataset)), 1):
-            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 
-
-        label_pad_token_id = -100 if cfg.ignore_pad_token_for_loss else self.tokenizer.pad_token_id
-
-        accelerator = Accelerator()
-        logger.info(accelerator.state)
-        data_collator = DataCollatorForSeq2Seq(
-            self.tokenizer,
-            model=self.model,
-            label_pad_token_id=label_pad_token_id,
-            pad_to_multiple_of=8 if accelerator.use_fp16 else None,
-        )
-
-
-        train_dataloader = DataLoader(
-        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=cfg.per_device_train_batch_size
-        )
-        eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=cfg.per_device_eval_batch_size)
-        test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=cfg.per_device_eval_batch_size)
-
-        # Optimizer
-        # Split weights in two groups, one with weight decay and the other not.
-        no_decay = ["bias", "LayerNorm.weight"]
-        optimizer_grouped_parameters = [
-            {
-                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
-                "weight_decay": cfg.weight_decay,
-            },
-            {
-                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
-                "weight_decay": 0.0,
-            },
-        ]
-        optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate)
-
-        # Prepare everything with our `accelerator`.
-        self.model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
-            self.model, optimizer, train_dataloader, eval_dataloader, test_dataloader
-        )
-
-        # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be
-        # shorter in multiprocess)
-
-        # Scheduler and math around the number of training steps.
-        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
-        if cfg.max_train_steps is None:
-            cfg.max_train_steps = cfg.num_train_epochs * num_update_steps_per_epoch
-        else:
-            cfg.num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)
-
-        lr_scheduler = get_scheduler(
-            name=cfg.lr_scheduler_type,
-            optimizer=optimizer,
-            num_warmup_steps=cfg.num_warmup_steps,
-            num_training_steps=cfg.max_train_steps,
-        )
-
-        # Metric
-
-        # Train!
-        total_batch_size = cfg.per_device_train_batch_size * accelerator.num_processes * cfg.gradient_accumulation_steps
-
-        logger.info("***** Running training *****")
-        logger.info(f"  Num examples = {len(train_dataset)}")
-        logger.info(f"  Num Epochs = {cfg.num_train_epochs}")
-        logger.info(f"  Instantaneous batch size per device = {cfg.per_device_train_batch_size}")
-        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-        logger.info(f"  Gradient Accumulation steps = {cfg.gradient_accumulation_steps}")
-        logger.info(f"  Total optimization steps = {cfg.max_train_steps}")
-        # Only show the progress bar once on each machine.
-        progress_bar = tqdm(range(cfg.max_train_steps), disable=not accelerator.is_local_main_process)
-        completed_steps = 0
-        global_steps = 0
-        tr_loss, logging_loss = 0.0, 0.0
-        for epoch in range(cfg.num_train_epochs):
-            self.model.train()
-            # for step, batch in enumerate(train_dataloader):
-            for step, batch in enumerate(train_dataloader):
-                global_steps += 1            
-                outputs = self.model(**batch)
-                loss = outputs.loss
-                loss = loss / cfg.gradient_accumulation_steps
-                tr_loss += loss.item()
-                accelerator.backward(loss)
-                
-                if step % cfg.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
-                    optimizer.step()
-                    lr_scheduler.step()
-                    optimizer.zero_grad()
-                    completed_steps += 1
-
-                if completed_steps >= cfg.max_train_steps:
-                    break
-
-                if step % cfg.logging_steps == 0:
-                    logger.info(f"  EVALERR:  {(tr_loss - logging_loss)/float(cfg.logging_steps)}")
-                    logging_loss = tr_loss
-                    progress_bar.update(cfg.logging_steps)
-
-                if cfg.output_dir is not None and global_steps % cfg.save_steps == 0 and global_steps > 0:
-                    
-                    accelerator.wait_for_everyone()
-                    if accelerator.is_local_main_process:               
-                        checkpoint_prefix = 'checkpoint'
-                        output_dir = os.path.join(cfg.output_dir, '{}-{}'.format(checkpoint_prefix, global_steps))
-                        if not os.path.exists(output_dir):
-                            os.makedirs(output_dir)
-                        unwrapped_model = accelerator.unwrap_model(self.model)
-                        unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
-
-                        self.tokenizer.save_pretrained(output_dir)
-                        torch.save(cfg, os.path.join(output_dir, 'training_args.bin'))
-                        logger.info("Saving model checkpoint to %s", output_dir)
-
-
-    
\ No newline at end of file
diff --git a/convlab/e2e/soloist/train.py b/convlab/e2e/soloist/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c2078954173bcf20fe2f21691b720eebf4d74c9
--- /dev/null
+++ b/convlab/e2e/soloist/train.py
@@ -0,0 +1,836 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Fine-tuning a 🤗 Transformers model on summarization.
+"""
+# You can also adapt this script on your own summarization task. Pointers for this are left as comments.
+
+import argparse
+import logging
+import math
+import os
+import random
+import json
+
+import datasets
+import nltk
+import numpy as np
+import torch
+from datasets import load_dataset, load_metric
+from torch.utils.data.dataloader import DataLoader
+from tqdm.auto import tqdm
+
+import transformers
+from accelerate import Accelerator
+from filelock import FileLock
+from transformers import (
+    CONFIG_MAPPING,
+    MODEL_MAPPING,
+    AdamW,
+    AutoConfig,
+    AutoModelForSeq2SeqLM,
+    AutoTokenizer,
+    DataCollatorForSeq2Seq,
+    SchedulerType,
+    get_scheduler,
+    set_seed,
+)
+from transformers.file_utils import is_offline_mode
+from transformers.utils.versions import require_version
+
+from nltk.tokenize import TweetTokenizer
+import re
+re_art = re.compile(r'\b(a|an|the)\b')
+re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
+
+
+def normalize_answer(s):
+    # Identity normalization, kept as a hook for answer post-processing.
+    return s
+
+def clean_str(txt):
+	#print("in=[%s]" % txt)
+	txt = txt.lower()
+	txt = re.sub('^',' ', txt)
+	txt = re.sub('$',' ', txt)
+
+	# url and tag
+	words = []
+	for word in txt.split():
+		i = word.find('http') 
+		if i >= 0:
+			word = word[:i] + ' ' + '__url__'
+		words.append(word.strip())
+	txt = ' '.join(words)
+
+	# remove markdown URL
+	txt = re.sub(r'\[([^\]]*)\] \( *__url__ *\)', r'\1', txt)
+
+	# remove illegal char
+	txt = re.sub('__url__','URL',txt)
+	txt = re.sub(r"[^A-Za-z0-9():,.!?\"\']", " ", txt)
+	txt = re.sub('URL','__url__',txt)	
+
+	# contraction
+	add_space = ["'s", "'m", "'re", "n't", "'ll","'ve","'d","'em"]
+	tokenizer = TweetTokenizer(preserve_case=False)
+	txt = ' ' + ' '.join(tokenizer.tokenize(txt)) + ' '
+	txt = txt.replace(" won't ", " will n't ")
+	txt = txt.replace(" can't ", " can n't ")
+	for a in add_space:
+		txt = txt.replace(a+' ', ' '+a+' ')
+
+	txt = re.sub(r'^\s+', '', txt)
+	txt = re.sub(r'\s+$', '', txt)
+	txt = re.sub(r'\s+', ' ', txt) # remove extra spaces
+	
+	#print("out=[%s]" % txt)
+	return txt
+
+logger = logging.getLogger(__name__)
+require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
+
+
+from dotenv import load_dotenv
+load_dotenv()
+if os.getenv('WANDB_API_KEY') is None:
+    USE_WANDB = False
+else:
+    USE_WANDB = True
+    wandb_key = os.getenv('WANDB_API_KEY')
+
+# You should update this to your particular problem to have better documentation of `model_type`
+MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
+
+try:
+    nltk.data.find("tokenizers/punkt")
+except (LookupError, OSError):
+    if is_offline_mode():
+        raise LookupError(
+            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
+        )
+    with FileLock(".lock") as lock:
+        nltk.download("punkt", quiet=True)
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Finetune a transformers seq2seq model on end-to-end dialog data")
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default=None,
+        help="The name of the dataset to use (via the datasets library).",
+    )
+    parser.add_argument(
+        "--dataset_config_name",
+        type=str,
+        default=None,
+        help="The configuration name of the dataset to use (via the datasets library).",
+    )
+    parser.add_argument(
+        "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
+    )
+    parser.add_argument(
+        "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
+    )
+    
+    parser.add_argument(
+        "--max_source_length",
+        type=int,
+        default=1024,
+        help="The maximum total input sequence length after "
+        "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
+    )
+    parser.add_argument(
+        "--source_prefix",
+        type=str,
+        default=None,
+        help="A prefix to add before every source text " "(useful for T5 models).",
+    )
+    parser.add_argument(
+        "--preprocessing_num_workers",
+        type=int,
+        default=None,
+        help="The number of processes to use for the preprocessing.",
+    )
+
+    parser.add_argument(
+        "--max_target_length",
+        type=int,
+        default=64,
+        help="The maximum total sequence length for target text after "
+        "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
+        "during ``evaluate`` and ``predict``.",
+    )
+    parser.add_argument(
+        "--val_max_target_length",
+        type=int,
+        default=None,
+        help="The maximum total sequence length for validation "
+        "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
+        "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
+        "param of ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+    )
+    parser.add_argument(
+        "--max_length",
+        type=int,
+        default=128,
+        help=(
+            "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
+            " sequences shorter will be padded if `--pad_to_max_lengh` is passed."
+        ),
+    )
+    parser.add_argument(
+        "--num_beams",
+        type=int,
+        default=None,
+        help="Number of beams to use for evaluation. This argument will be "
+        "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
+    )
+    parser.add_argument(
+        "--model_name_or_path",
+        type=str,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+        required=True,
+    )
+    parser.add_argument(
+        "--config_name",
+        type=str,
+        default=None,
+        help="Pretrained config name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--tokenizer_name",
+        type=str,
+        default=None,
+        help="Pretrained tokenizer name or path if not the same as model_name",
+    )
+    parser.add_argument(
+        "--text_column",
+        type=str,
+        default=None,
+        help="The name of the column in the datasets containing the full texts (for summarization).",
+    )
+    parser.add_argument(
+        "--summary_column",
+        type=str,
+        default=None,
+        help="The name of the column in the datasets containing the summaries (for summarization).",
+    )
+    parser.add_argument(
+        "--use_slow_tokenizer",
+        action="store_true",
+        help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
+    )
+    parser.add_argument(
+        "--per_device_train_batch_size",
+        type=int,
+        default=8,
+        help="Batch size (per device) for the training dataloader.",
+    )
+    parser.add_argument(
+        "--per_device_eval_batch_size",
+        type=int,
+        default=8,
+        help="Batch size (per device) for the evaluation dataloader.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=5e-5,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--lr_scheduler_type",
+        type=SchedulerType,
+        default="linear",
+        help="The scheduler type to use.",
+        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
+    )
+    parser.add_argument(
+        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+    parser.add_argument(
+        "--model_type",
+        type=str,
+        default=None,
+        help="Model type to use if training from scratch.",
+        choices=MODEL_TYPES,
+    )
+
+
+    parser.add_argument(
+        "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
+    )
+
+    parser.add_argument(
+        "--pad_to_max_length", type=bool, default=True, help="Whether to pad all samples to `max_length`."
+    )
+
+    parser.add_argument(
+        "--ignore_pad_token_for_loss", type=bool, default=True,
+        help="Whether to ignore padded label tokens when computing the loss."
+    )
+
+    parser.add_argument(
+        "--logging_steps", type=int, default=500, help="Log training loss every X update steps."
+    )
+
+    parser.add_argument(
+        "--save_steps", type=int, default=5000, help="Save a checkpoint every X update steps."
+    )
+
+    parser.add_argument(
+        "--save_every_checkpoint", action="store_true"
+    )
+
+    parser.add_argument(
+        "--max_grad_norm", type=float, default=1.0, help="max_grad_norm"
+    )
+
+    parser.add_argument(
+        "--exp_name",
+        type=str,
+        help="Description to the experiment",
+        default='multiwoz',
+    )
+
+    parser.add_argument(
+        "--use_special_token",
+        action="store_true",
+        help="add special token or not"
+    )
+
+    parser.add_argument(
+        "--format_version",
+        type=str, default='v1',
+        help="format version"
+    )
+
+    parser.add_argument(
+        "--wandb_exp_name",
+        type=str,
+        default='multiwoz',
+        help="Description to the experiment worksheet name",
+    )
+
+
+    args = parser.parse_args()
+
+    # Sanity checks
+    if args.dataset_name is None and args.train_file is None and args.validation_file is None:
+        raise ValueError("Need either a dataset name or a training/validation file.")
+    else:
+        if args.train_file is not None:
+            extension = args.train_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
+        if args.validation_file is not None:
+            extension = args.validation_file.split(".")[-1]
+            assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
+
+    if args.output_dir is not None:
+        os.makedirs(args.output_dir, exist_ok=True)
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    if args.source_prefix is None and args.model_name_or_path in [
+        "t5-small",
+        "t5-base",
+        "t5-large",
+        "t5-3b",
+        "t5-11b",
+    ]:
+        logger.warning(
+            "You're running a t5 model but didn't provide a source prefix, which is expected for t5, e.g. "
+            "`--source_prefix 'summarize: ' `"
+        )
+    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
+    accelerator = Accelerator()
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state)
+
+    # Setup logging, we only want one process per machine to log things on the screen.
+    # accelerator.is_local_main_process is only True for one process per machine.
+    logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+
+    # If passed along, set the training seed now.
+    if args.seed is not None:
+        set_seed(args.seed)
+
+    if accelerator.is_local_main_process and USE_WANDB:
+        import wandb
+        config = dict(dataset_id="", infra="")
+        wandb.init(
+            project=args.wandb_exp_name,
+            notes="Finetuning",
+            tags=["multiwoz"],
+            config=config,
+            entity='Convlab3',
+        )
+        wandb.run.name = args.exp_name
+
+
+    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+    # (the dataset will be downloaded automatically from the datasets Hub).
+    #
+    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
+    # 'text' is found. You can easily tweak this behavior (see below).
+    #
+    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+    # download the dataset.
+    if args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
+    else:
+        data_files = {}
+        if args.train_file is not None:
+            data_files["train"] = args.train_file
+        if args.validation_file is not None:
+            data_files["validation"] = args.validation_file
+        extension = args.train_file.split(".")[-1]
+        raw_datasets = load_dataset(extension, data_files=data_files)
+    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+    # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+    # Load pretrained model and tokenizer
+    #
+    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
+    if args.config_name:
+        config = AutoConfig.from_pretrained(args.config_name)
+    elif args.model_name_or_path:
+        config = AutoConfig.from_pretrained(args.model_name_or_path)
+    else:
+        config = CONFIG_MAPPING[args.model_type]()
+        logger.warning("You are instantiating a new config instance from scratch.")
+
+    if args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=not args.use_slow_tokenizer)
+    elif args.model_name_or_path:
+        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=not args.use_slow_tokenizer)
+    else:
+        raise ValueError(
+            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
+            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+        )
+
+    if args.model_name_or_path:
+        model = AutoModelForSeq2SeqLM.from_pretrained(
+            args.model_name_or_path,
+            from_tf=bool(".ckpt" in args.model_name_or_path),
+            config=config,
+        )
+    else:
+        logger.info("Training new model from scratch")
+        model = AutoModelForSeq2SeqLM.from_config(config)
+
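+    # For BlenderBot checkpoints, tile the encoder position embeddings 4x so
+    # the model accepts longer input contexts than it was pre-trained with.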
+    if 'blender' in args.model_name_or_path.lower():
+        model.model.encoder.embed_positions.weight = torch.nn.Parameter(model.model.encoder.embed_positions.weight.repeat(4,1))
+    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+    if args.use_special_token:
+        special_tokens = [i.strip() for i in open('special_tokens.txt')]
+        tokenizer.add_tokens(special_tokens)
+
+    model.resize_token_embeddings(len(tokenizer))
+    if model.config.decoder_start_token_id is None:
+        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
+
+    prefix = args.source_prefix if args.source_prefix is not None else ""
+    max_length = args.max_length
+    padding = "max_length" if args.pad_to_max_length else False
+    max_target_length = args.max_target_length
+    def preprocess_function(examples):
+        contextes = examples['Context']
+        responses = examples['Response']
+        kbs = examples['Knowledge']
+
+        responses_labels = []
+        inputs = []
+
+        for context, response, kb in zip(contextes, responses, kbs):
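+            # v1 format: source = the last 10 turns of the history joined by
+            # ' EOS '; target = 'Belief: ' + belief-state string + response.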
+            if args.format_version == 'v1':
+                _input = ' EOS '.join(context.split(' EOS ')[-10:])
+                _response = 'Belief: ' + kb + response
+                inputs.append(_input)
+                responses_labels.append(_response)
+
+        model_inputs = tokenizer(inputs, max_length=args.max_length, padding=padding, truncation=True)
+
+        # Setup the tokenizer for targets
+        with tokenizer.as_target_tokenizer():
+            labels = tokenizer(responses_labels, max_length=max_target_length, padding=padding, truncation=True)
+
+        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
+        # padding in the loss.
+        if padding == "max_length" and args.ignore_pad_token_for_loss:
+            labels["labels"] = [
+                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
+            ]
+
+        model_inputs["labels"] = labels["labels"]
+        return model_inputs
+
+    # To speed up preprocessing we use multiprocessing. See the documentation of the map method for more information:
+    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+    column_names = ['Context', 'Response', 'Knowledge', 'Dataset']
+    lm_datasets = raw_datasets.map(
+        preprocess_function,
+        batched=True,
+        remove_columns=column_names,
+        num_proc=args.preprocessing_num_workers,
+        load_from_cache_file=False,
+        desc=f"Processing dataset",
+    )
+
+    train_dataset = lm_datasets["test"]
+    eval_dataset = lm_datasets["validation"]
+    test_dataset = lm_datasets["test"]
+
+    # Log a few random samples from the training set:
+    for index in random.sample(range(len(train_dataset)), 1):
+        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+
+
+    label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer,
+        model=model,
+        label_pad_token_id=label_pad_token_id,
+        pad_to_multiple_of=8 if accelerator.use_fp16 else None,
+    )
+
+    def postprocess_text(preds, labels):
+        preds = [normalize_answer(pred.strip().replace('Agent :','')) for pred in preds]
+        labels = [normalize_answer(label.strip().replace('Agent :','')) for label in labels]
+
+        # rougeLSum expects newline after each sentence
+        # preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+        # labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
+
+        return preds, labels
+
+    train_dataloader = DataLoader(
+        train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
+    )
+    eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
+    test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
+
+    # Optimizer
+    # Split weights in two groups, one with weight decay and the other not.
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": args.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+
+    # Prepare everything with our `accelerator`.
+    model, optimizer, train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
+        model, optimizer, train_dataloader, eval_dataloader, test_dataloader
+    )
+
+    # Note -> the training dataloader needs to be prepared before we grab its length below (because its length will be
+    # shorter in multiprocess)
+
+    # Scheduler and math around the number of training steps.
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.max_train_steps is None:
+        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+    else:
+        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+    lr_scheduler = get_scheduler(
+        name=args.lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=args.num_warmup_steps,
+        num_training_steps=args.max_train_steps,
+    )
+
+    # Metric
+
+    # Train!
+    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.per_device_train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
+    completed_steps = 0
+    global_steps = 0
+    tr_loss, logging_loss = 0.0, 0.0
+    for epoch in range(args.num_train_epochs):
+        model.train()
+        
+        for step, batch in enumerate(train_dataloader):
+            
+            global_steps += 1            
+            outputs = model(**batch)
+            loss = outputs.loss
+            loss = loss / args.gradient_accumulation_steps
+            tr_loss += loss.item()
+            accelerator.backward(loss)            
+            
+            if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+                completed_steps += 1
+
+            if completed_steps >= args.max_train_steps:
+                break
+
+            if step % args.logging_steps == 0:
+                logger.info(f"  EVALERR:  {(tr_loss - logging_loss)/float(args.logging_steps)}")
+                if accelerator.is_local_main_process and USE_WANDB:
+                    wandb.log({'loss': tr_loss - logging_loss})
+                logging_loss = tr_loss
+                progress_bar.update(args.logging_steps)
+
+            if args.output_dir is not None and global_steps % args.save_steps == 0 and global_steps > 0:
+                accelerator.wait_for_everyone()
+                if accelerator.is_local_main_process:               
+                    checkpoint_prefix = 'checkpoint'
+                    output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_steps))
+                    if not os.path.exists(output_dir):
+                        os.makedirs(output_dir)
+                    unwrapped_model = accelerator.unwrap_model(model)
+                    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
+
+                    tokenizer.save_pretrained(output_dir)
+                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
+                    logger.info("Saving model checkpoint to %s", output_dir)
+
+        model.eval()
+        if args.val_max_target_length is None:
+            args.val_max_target_length = args.max_target_length
+
+        gen_kwargs = {
+            "max_length": args.val_max_target_length if args is not None else config.max_length,
+            "num_beams": args.num_beams,
+        }
+
+        def chunks(lst, n):
+            for i in range(0, len(lst), n):
+                yield lst[i:i + n]
+
+        metric = load_metric("./rouge_metric.py")
+        metric_bleu = load_metric("./bleu_metric.py")
+        decoded_preds_all = []
+        for step, batch in enumerate(eval_dataloader):
+            with torch.no_grad():
+                generated_tokens = accelerator.unwrap_model(model).generate(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    **gen_kwargs,
+                )
+
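+                # Sequences generated on different processes can differ in
+                # length; pad them to a common length first, since `gather`
+                # requires identically-shaped tensors on every process.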
+                generated_tokens = accelerator.pad_across_processes(
+                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
+                )
+                labels = batch["labels"]
+                if not args.pad_to_max_length:
+                    # If we did not pad to max length, we need to pad the labels too
+                    labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
+
+                generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
+                labels = accelerator.gather(labels).cpu().numpy()
+
+                if args.ignore_pad_token_for_loss:
+                    # Replace -100 in the labels as we can't decode them.
+                    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+                if isinstance(generated_tokens, tuple):
+                    generated_tokens = generated_tokens[0]
+                decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+                metric.add_batch(predictions=decoded_preds, references=decoded_labels)
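+                # The bleu metric expects whitespace-tokenized hypotheses and a
+                # list of token-list references per example, hence the split
+                # and the extra nesting below.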
+                _decoded_preds = [i.split() for i in decoded_preds]
+                _decoded_labels = [[i.split()] for i in decoded_labels]
+                decoded_preds_all.extend(decoded_preds)
+                metric_bleu.add_batch(predictions=_decoded_preds, references=_decoded_labels)
+
+        result = metric.compute(use_stemmer=True)
+        # Extract a few results from ROUGE
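+        # (`mid` is the point estimate of the bootstrap-aggregated score)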
+        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
+
+        result = {k: round(v, 4) for k, v in result.items()}
+
+        logger.info(result)
+
+        result_bleu = metric_bleu.compute()
+        logger.info(result_bleu)
+
+        accelerator.wait_for_everyone()
+        if accelerator.is_local_main_process and USE_WANDB:
+            wandb.log({'valid_bleu': result_bleu['bleu']})
+            wandb.log({'valid_rouge': result['rougeL']})
+
+        
+        if args.output_dir is not None:
+            accelerator.wait_for_everyone()
+            if accelerator.is_local_main_process:               
+                if not os.path.exists(args.output_dir):
+                    os.makedirs(args.output_dir)
+                output_dir_file_name = os.path.join(args.output_dir, 'valid-step-{}'.format(completed_steps))
+                with open(output_dir_file_name, 'w') as f:
+                    json.dump(decoded_preds_all, f, indent=2)
+                logger.info("Saving model outputs to %s", output_dir_file_name)
+        
+        metric = load_metric("rouge")
+        metric_bleu = load_metric("bleu")
+
+        # gen_kwargs from the validation pass above is reused for the test pass
+        decoded_preds_all = []
+        for step, batch in enumerate(test_dataloader):
+            with torch.no_grad():
+                generated_tokens = accelerator.unwrap_model(model).generate(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    **gen_kwargs,
+                )
+
+                generated_tokens = accelerator.pad_across_processes(
+                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
+                )
+                labels = batch["labels"]
+                if not args.pad_to_max_length:
+                    # If we did not pad to max length, we need to pad the labels too
+                    labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
+
+                generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
+                labels = accelerator.gather(labels).cpu().numpy()
+
+                if args.ignore_pad_token_for_loss:
+                    # Replace -100 in the labels as we can't decode them.
+                    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+                if isinstance(generated_tokens, tuple):
+                    generated_tokens = generated_tokens[0]
+                decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+                decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
+                metric.add_batch(predictions=decoded_preds, references=decoded_labels)
+                _decoded_preds = [i.split() for i in decoded_preds]
+                _decoded_labels = [[i.split()] for i in decoded_labels]
+                # collect detokenized strings (not token lists) for the JSON dump
+                decoded_preds_all.extend(decoded_preds)
+                metric_bleu.add_batch(predictions=_decoded_preds, references=_decoded_labels)
+                
+        result = metric.compute(use_stemmer=True)
+        # Extract a few results from ROUGE
+        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
+
+        result = {k: round(v, 4) for k, v in result.items()}
+
+        logger.info(result)
+
+        result_bleu = metric_bleu.compute()
+        logger.info(result_bleu)
+
+        accelerator.wait_for_everyone()
+        if accelerator.is_local_main_process and USE_WANDB:
+            wandb.log({'test_bleu': result_bleu['bleu']})
+            wandb.log({'test_rouge': result['rougeL']})
+
+        if args.output_dir is not None:
+            accelerator.wait_for_everyone()
+            if accelerator.is_local_main_process:               
+                if not os.path.exists(args.output_dir):
+                    os.makedirs(args.output_dir)
+                output_dir_file_name = os.path.join(args.output_dir, 'test-step-{}'.format(completed_steps))
+                with open(output_dir_file_name, 'w') as f:
+                    json.dump(decoded_preds_all, f, indent=2)
+                logger.info("Saving model outputs to %s", output_dir_file_name)
+
+        
+        if args.output_dir is not None and args.save_every_checkpoint:
+            accelerator.wait_for_everyone()
+            if accelerator.is_local_main_process:               
+                checkpoint_prefix = 'checkpoint'
+                output_dir = os.path.join(args.output_dir, '{}-epoch-{}'.format(checkpoint_prefix, epoch))
+                if not os.path.exists(output_dir):
+                    os.makedirs(output_dir)
+                unwrapped_model = accelerator.unwrap_model(model)
+                unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
+
+                tokenizer.save_pretrained(output_dir)
+                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
+                logger.info("Saving model checkpoint to %s", output_dir)
+
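+# A minimal launch sketch (the script filename is assumed; the flags mirror
+# the argparse options referenced above):
+#
+#   accelerate launch train_soloist.py \
+#       --per_device_train_batch_size 8 \
+#       --gradient_accumulation_steps 4 \
+#       --num_train_epochs 10 \
+#       --logging_steps 100 \
+#       --save_steps 1000 \
+#       --output_dir ./checkpoints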
+
+if __name__ == "__main__":
+    main()
diff --git a/convlab/util/unified_datasets_util.py b/convlab/util/unified_datasets_util.py
index 726079d1c2b04c304bfd7055d48b1b4ae4905856..2c16bc5cc6d99bc6a00005fc3652ee556321c8df 100644
--- a/convlab/util/unified_datasets_util.py
+++ b/convlab/util/unified_datasets_util.py
@@ -159,20 +159,21 @@ def load_database(dataset_name: str):
 
 
 def load_unified_data(
-    dataset,
-    data_split='all',
-    speaker='all',
-    utterance=False,
-    dialogue_acts=False,
-    state=False,
-    db_results=False,
-    use_context=False,
-    context_window_size=0,
-    terminated=False,
-    goal=False,
-    active_domains=False,
-    split_to_turn=True
-):
+    dataset,
+    data_split='all',
+    speaker='all',
+    utterance=False,
+    dialogue_acts=False,
+    state=False,
+    db_results=False,
+    delex_utterance=False,
+    use_context=False,
+    context_window_size=0,
+    terminated=False,
+    goal=False,
+    active_domains=False,
+    split_to_turn=True
+):
     """
     > This function takes in a dataset, and returns a dictionary of data splits, where each data split
     is a list of samples
@@ -200,8 +201,7 @@ def load_unified_data(
     data_splits = dataset.keys() if data_split == 'all' else [data_split]
     assert speaker in ['user', 'system', 'all']
     assert not use_context or context_window_size > 0
-    info_list = list(
-        filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results']))
+    info_list = list(filter(eval, ['utterance', 'dialogue_acts', 'state', 'db_results', 'delex_utterance']))
     info_list += ['utt_idx']
     data_by_split = {}
     for data_split in data_splits:
@@ -452,10 +452,13 @@ def create_delex_data(dataset, delex_func=lambda d, s, v: f'[({d})-({s})]', igno
                             # has value
                             for value in values.split('|'):
                                 if value.lower() not in ignore_values:
-                                    placeholder = delex_func(
-                                        domain, slot, value)
-                                    pattern = re.compile(
-                                        r'\b({})\b'.format(value), flags=re.I)
+                                    placeholder = delex_func(domain, slot, value)
+                                    # escape regex metacharacters in the value (e.g. '?')
+                                    # so that pattern compilation cannot fail
+                                    pattern = re.compile(r'\b({})\b'.format(re.escape(value)), flags=re.I)
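+                                    # e.g. value "cheap?" compiles to the pattern \b(cheap\?)\b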
                                     if delex_inplace(delex_utt, pattern):
                                         delex_vocab.add(placeholder)