diff --git a/convlab2/nlu/jointBERT/README.md b/convlab2/nlu/jointBERT/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..c9756d3c1ebdd42e975bb86d32a532b066a29048
--- /dev/null
+++ b/convlab2/nlu/jointBERT/README.md
@@ -0,0 +1,57 @@
+# BERTNLU
+
+On top of the pre-trained BERT, BERTNLU use an MLP for slot tagging and another MLP for intent classification. All parameters are fine-tuned to learn these two tasks jointly.
+
+Dialog acts are split into two groups, depending on whether the values are in the utterances:
+
+- For dialogue acts whose values are in the utterances, we use **slot tagging** to extract the values. For example, `"Find me a cheap hotel"`, its dialog act is `{intent=Inform, domain=hotel, slot=price, value=cheap}`, and the corresponding BIO tag sequence is `["O", "O", "O", "B-inform-hotel-price", "O"]`. An MLP classifier takes a token's representation from BERT and outputs its tag.
+- For dialogue acts whose values may not be presented in the utterances, we treat them as **intents** of the utterances. Another MLP takes embeddings of `[CLS]` of a utterance as input and does the binary classification for each intent independently. Since some intents are rare, we set the weight of positive samples as $\lg(\frac{\# \ negative\ samples}{\# \ positive\ samples})$ empirically for each intent.
+
+The model can also incorporate context information by setting the `context=true` in the config file. The context utterances will be concatenated (separated by `[SEP]`) and fed into BERT. Then the `[CLS]` embedding serves as context representaion and is concatenated to all token representations in the target utterance right before the slot and intent classifiers.
+
+
+## Usage
+
+Follow the instruction under each dataset's directory to prepare data and model config file for training and evaluation.
+
+#### Train a model
+
+```sh
+$ python train.py --config_path path_to_a_config_file
+```
+
+The model (`pytorch_model.bin`) will be saved under the `output_dir` of the config file.
+
+#### Test a model
+
+```sh
+$ python test.py --config_path path_to_a_config_file
+```
+
+The result (`output.json`) will be saved under the `output_dir` of the config file. Also, it will be zipped as `zipped_model_path` in the config file.
+
+
+## References
+
+```
+@inproceedings{devlin2019bert,
+  title={BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding},
+  author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina},
+  booktitle={Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)},
+  pages={4171--4186},
+  year={2019}
+}
+
+@inproceedings{zhu-etal-2020-convlab,
+    title = "{C}onv{L}ab-2: An Open-Source Toolkit for Building, Evaluating, and Diagnosing Dialogue Systems",
+    author = "Zhu, Qi and Zhang, Zheng and Fang, Yan and Li, Xiang and Takanobu, Ryuichi and Li, Jinchao and Peng, Baolin and Gao, Jianfeng and Zhu, Xiaoyan and Huang, Minlie",
+    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
+    month = jul,
+    year = "2020",
+    address = "Online",
+    publisher = "Association for Computational Linguistics",
+    url = "https://aclanthology.org/2020.acl-demos.19",
+    doi = "10.18653/v1/2020.acl-demos.19",
+    pages = "142--149"
+}
+```
\ No newline at end of file
diff --git a/convlab2/nlu/jointBERT/dataloader.py b/convlab2/nlu/jointBERT/dataloader.py
index 38fc24ea0fdc288410a65146716996432ebd896c..d1fcbc7a4864211a9956cedacc7c3479c195733a 100755
--- a/convlab2/nlu/jointBERT/dataloader.py
+++ b/convlab2/nlu/jointBERT/dataloader.py
@@ -39,13 +39,13 @@ class Dataloader:
         for d in self.data[data_key]:
             max_sen_len = max(max_sen_len, len(d[0]))
             sen_len.append(len(d[0]))
-            # d = (tokens, tags, intents, da2triples(turn["dialog_act"], context(list of str))
+            # d = (tokens, tags, intents, original dialog acts, context(list of str))
             if cut_sen_len > 0:
                 d[0] = d[0][:cut_sen_len]
                 d[1] = d[1][:cut_sen_len]
                 d[4] = [' '.join(s.split()[:cut_sen_len]) for s in d[4]]
 
-            d[4] = self.tokenizer.encode('[CLS] ' + ' [SEP] '.join(d[4]))
+            d[4] = self.tokenizer.encode(' [SEP] '.join(d[4]))
             max_context_len = max(max_context_len, len(d[4]))
             context_len.append(len(d[4]))
 
diff --git a/convlab2/nlu/jointBERT/test.py b/convlab2/nlu/jointBERT/test.py
index 7856e5ecc0c1be7471f9497339a7f9fbc2f3f9ec..2e1e1b51940c5d899a833fc9fba3a7f6aa257e7b 100755
--- a/convlab2/nlu/jointBERT/test.py
+++ b/convlab2/nlu/jointBERT/test.py
@@ -29,7 +29,11 @@ if __name__ == '__main__':
 
     set_seed(config['seed'])
 
-    if 'multiwoz' in data_dir:
+    if 'unified_datasets' in data_dir:
+        dataset_name = config['dataset_name']
+        print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20)
+        from convlab2.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent
+    elif 'multiwoz' in data_dir:
         print('-'*20 + 'dataset:multiwoz' + '-'*20)
         from convlab2.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent
     elif 'camrest' in data_dir:
@@ -90,14 +94,25 @@ if __name__ == '__main__':
                 'predict': predicts,
                 'golden': labels
             })
-            predict_golden['slot'].append({
-                'predict': [x for x in predicts if is_slot_da(x)],
-                'golden': [x for x in labels if is_slot_da(x)]
-            })
-            predict_golden['intent'].append({
-                'predict': [x for x in predicts if not is_slot_da(x)],
-                'golden': [x for x in labels if not is_slot_da(x)]
-            })
+            if isinstance(predicts, dict):
+                predict_golden['slot'].append({
+                    'predict': {k:v for k, v in predicts.items() if is_slot_da(k)},
+                    'golden': {k:v for k, v in labels.items() if is_slot_da(k)}
+                })
+                predict_golden['intent'].append({
+                    'predict': {k:v for k, v in predicts.items() if not is_slot_da(k)},
+                    'golden': {k:v for k, v in labels.items() if not is_slot_da(k)}
+                })
+            else:
+                assert isinstance(predicts, list)
+                predict_golden['slot'].append({
+                    'predict': [x for x in predicts if is_slot_da(x)],
+                    'golden': [x for x in labels if is_slot_da(x)]
+                })
+                predict_golden['intent'].append({
+                    'predict': [x for x in predicts if not is_slot_da(x)],
+                    'golden': [x for x in labels if not is_slot_da(x)]
+                })
         print('[%d|%d] samples' % (len(predict_golden['overall']), len(dataloader.data[data_key])))
 
     total = len(dataloader.data[data_key])
diff --git a/convlab2/nlu/jointBERT/train.py b/convlab2/nlu/jointBERT/train.py
index a6267b9403dd805ea7537763f1063cbcc03965d1..fad50eda9c3b6676d9ce5b9e00dcc961e14ae4e7 100755
--- a/convlab2/nlu/jointBERT/train.py
+++ b/convlab2/nlu/jointBERT/train.py
@@ -32,7 +32,11 @@ if __name__ == '__main__':
 
     set_seed(config['seed'])
 
-    if 'multiwoz' in data_dir:
+    if 'unified_datasets' in data_dir:
+        dataset_name = config['dataset_name']
+        print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20)
+        from convlab2.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent
+    elif 'multiwoz' in data_dir:
         print('-'*20 + 'dataset:multiwoz' + '-'*20)
         from convlab2.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent
     elif 'camrest' in data_dir:
@@ -149,14 +153,25 @@ if __name__ == '__main__':
                         'predict': predicts,
                         'golden': labels
                     })
-                    predict_golden['slot'].append({
-                        'predict': [x for x in predicts if is_slot_da(x)],
-                        'golden': [x for x in labels if is_slot_da(x)]
-                    })
-                    predict_golden['intent'].append({
-                        'predict': [x for x in predicts if not is_slot_da(x)],
-                        'golden': [x for x in labels if not is_slot_da(x)]
-                    })
+                    if isinstance(predicts, dict):
+                        predict_golden['slot'].append({
+                            'predict': {k:v for k, v in predicts.items() if is_slot_da(k)},
+                            'golden': {k:v for k, v in labels.items() if is_slot_da(k)}
+                        })
+                        predict_golden['intent'].append({
+                            'predict': {k:v for k, v in predicts.items() if not is_slot_da(k)},
+                            'golden': {k:v for k, v in labels.items() if not is_slot_da(k)}
+                        })
+                    else:
+                        assert isinstance(predicts, list)
+                        predict_golden['slot'].append({
+                            'predict': [x for x in predicts if is_slot_da(x)],
+                            'golden': [x for x in labels if is_slot_da(x)]
+                        })
+                        predict_golden['intent'].append({
+                            'predict': [x for x in predicts if not is_slot_da(x)],
+                            'golden': [x for x in labels if not is_slot_da(x)]
+                        })
 
             for j in range(10):
                 writer.add_text('val_sample_{}'.format(j),
diff --git a/convlab2/nlu/jointBERT/unified_datasets/README.md b/convlab2/nlu/jointBERT/unified_datasets/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..6ced41a7f793da71dcf07edfdc58c29cde4996d5
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/README.md
@@ -0,0 +1,41 @@
+# BERTNLU on datasets in unified format
+
+We support training BERTNLU on datasets that are in our unified format.
+
+- For **non-categorical** dialogue acts whose values are in the utterances, we use **slot tagging** to extract the values.
+- For **categorical** and **binary** dialogue acts whose values may not be presented in the utterances, we treat them as **intents** of the utterances.
+
+## Usage
+
+#### Preprocess data
+
+```sh
+$ python preprocess.py --dataset dataset_name --speaker {user,system,all} --context_window_size CONTEXT_WINDOW_SIZE --save_dir save_directory
+```
+
+Note that the dataset will be loaded by `convlab2.util.load_dataset(dataset_name)`. If you want to use custom datasets, make sure they follow the unified format and can be loaded using this function.
+output processed data on `${save_dir}/${dataset_name}/${speaker}/context_window_size_${context_window_size}` dir.
+
+#### Train a model
+
+Prepare a config file and run the training script in the parent directory:
+
+```sh
+$ python train.py --config_path path_to_a_config_file
+```
+
+The model (`pytorch_model.bin`) will be saved under the `output_dir` of the config file. Also, it will be zipped as `zipped_model_path` in the config file.
+
+#### Test a model
+
+Run the inference script in the parent directory:
+
+```sh
+$ python test.py --config_path path_to_a_config_file
+```
+
+The result (`output.json`) will be saved under the `output_dir` of the config file.
+
+#### Predict
+
+See `nlu.py` for usage.
diff --git a/convlab2/nlu/jointBERT/unified_datasets/__init__.py b/convlab2/nlu/jointBERT/unified_datasets/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..74b08065ca97132195ab56c65cb7fe87ec7b4fca
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/__init__.py
@@ -0,0 +1 @@
+from convlab2.nlu.jointBERT.unified_datasets.nlu import BERTNLU
\ No newline at end of file
diff --git a/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py b/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py
new file mode 100755
index 0000000000000000000000000000000000000000..6de31fbea9825b54b2e29bcf51e035a571de1c6b
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py
@@ -0,0 +1,33 @@
+import json
+import os
+from convlab2.util import load_dataset, load_nlu_data
+
+
+def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
+    assert os.path.exists(predict_result)
+    dataset = load_dataset(dataset_name)
+    data = load_nlu_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
+    
+    if save_dir is None:
+        save_dir = os.path.dirname(predict_result)
+    else:
+        os.makedirs(save_dir, exist_ok=True)
+    predict_result = json.load(open(predict_result))
+
+    for sample, prediction in zip(data, predict_result):
+        sample['dialogue_acts_prediction'] = prediction['predict']
+
+    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="merge predict results with original data for unified NLU evaluation")
+    parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset')
+    parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
+    parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: on the same directory as predict_result')
+    parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
+    parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated by ../test.py')
+    args = parser.parse_args()
+    print(args)
+    merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result)
diff --git a/convlab2/nlu/jointBERT/unified_datasets/nlu.py b/convlab2/nlu/jointBERT/unified_datasets/nlu.py
new file mode 100755
index 0000000000000000000000000000000000000000..063ea036e4999626ca02452b1c1dc9f38ddb913f
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/nlu.py
@@ -0,0 +1,106 @@
+import logging
+import os
+import json
+import torch
+from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
+import transformers
+from convlab2.nlu.nlu import NLU
+from convlab2.nlu.jointBERT.dataloader import Dataloader
+from convlab2.nlu.jointBERT.jointBERT import JointBERT
+from convlab2.nlu.jointBERT.unified_datasets.postprocess import recover_intent
+from convlab2.util.custom_util import model_downloader
+
+
+class BERTNLU(NLU):
+    def __init__(self, mode, config_file, model_file=None):
+        assert mode == 'user' or mode == 'sys' or mode == 'all'
+        self.mode = mode
+        config_file = os.path.join(os.path.dirname(
+            os.path.abspath(__file__)), 'configs/{}'.format(config_file))
+        config = json.load(open(config_file))
+        # print(config['DEVICE'])
+        # DEVICE = config['DEVICE']
+        DEVICE = 'cpu' if not torch.cuda.is_available() else config['DEVICE']
+        root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        data_dir = os.path.join(root_dir, config['data_dir'])
+        output_dir = os.path.join(root_dir, config['output_dir'])
+
+        assert os.path.exists(os.path.join(data_dir, 'intent_vocab.json')), print('Please run preprocess first')
+
+        intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json')))
+        tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json')))
+        dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab,
+                                pretrained_weights=config['model']['pretrained_weights'])
+
+        logging.info('intent num:' +  str(len(intent_vocab)))
+        logging.info('tag num:' + str(len(tag_vocab)))
+
+        if not os.path.exists(output_dir):
+            model_downloader(root_dir, model_file)
+        model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim)
+
+        state_dict = torch.load(os.path.join(output_dir, 'pytorch_model.bin'), DEVICE)
+        if int(transformers.__version__.split('.')[0]) >= 3 and 'bert.embeddings.position_ids' not in state_dict:
+            state_dict['bert.embeddings.position_ids'] = torch.tensor(range(512)).reshape(1, -1).to(DEVICE)
+
+        model.load_state_dict(state_dict)
+        model.to(DEVICE)
+        model.eval()
+
+        self.model = model
+        self.use_context = config['model']['context']
+        self.context_window_size = config['context_window_size']
+        self.dataloader = dataloader
+        self.sent_tokenizer = PunktSentenceTokenizer()
+        self.word_tokenizer = TreebankWordTokenizer()
+        logging.info("BERTNLU loaded")
+
+    def predict(self, utterance, context=list()):
+        sentences = self.sent_tokenizer.tokenize(utterance)
+        ori_word_seq = [token for sent in sentences for token in self.word_tokenizer.tokenize(sent)]
+        ori_tag_seq = [str(('O',))] * len(ori_word_seq)
+        if self.use_context:
+            if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
+                context = [item[1] for item in context]
+            context_seq = self.dataloader.tokenizer.encode(' [SEP] '.join(context[-self.context_window_size:]))
+            context_seq = context_seq[:510]
+        else:
+            context_seq = self.dataloader.tokenizer.encode('')
+        intents = []
+        da = {}
+
+        word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize(ori_word_seq, ori_tag_seq)
+        word_seq = word_seq[:510]
+        tag_seq = tag_seq[:510]
+        batch_data = [[ori_word_seq, ori_tag_seq, intents, da, context_seq,
+                       new2ori, word_seq, self.dataloader.seq_tag2id(tag_seq), self.dataloader.seq_intent2id(intents)]]
+
+        pad_batch = self.dataloader.pad_batch(batch_data)
+        pad_batch = tuple(t.to(self.model.device) for t in pad_batch)
+        word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = pad_batch
+        slot_logits, intent_logits = self.model.forward(word_seq_tensor, word_mask_tensor,
+                                                        context_seq_tensor=context_seq_tensor,
+                                                        context_mask_tensor=context_mask_tensor)
+        das = recover_intent(self.dataloader, intent_logits[0], slot_logits[0], tag_mask_tensor[0],
+                             batch_data[0][0], batch_data[0][-4])
+        dialog_act = []
+        for da_type in das:
+            for da in das[da_type]:
+                dialog_act.append([da['intent'], da['domain'], da['slot'], da.get('value','')])
+        return dialog_act
+
+
+if __name__ == '__main__':
+    texts = [
+        "I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
+        "I want to leave after 17:15.",
+        "Thank you for all the help! I appreciate it.",
+        "Please find a restaurant called Nusha.",
+        "What is the train id, please? ",
+        "I don't care about the price and it doesn't need to have free parking."
+    ]
+    nlu = BERTNLU(mode='user', config_file='multiwoz21_user.json')
+    for text in texts:
+        print(text)
+        print(nlu.predict(text))
+        print()
diff --git a/convlab2/nlu/jointBERT/unified_datasets/postprocess.py b/convlab2/nlu/jointBERT/unified_datasets/postprocess.py
new file mode 100755
index 0000000000000000000000000000000000000000..982b4a92a6df10832ac313acd5984af63567dd1d
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/postprocess.py
@@ -0,0 +1,111 @@
+import re
+import torch
+
+
+def is_slot_da(da_type):
+    return da_type == 'non-categorical'
+
+
+def calculateF1(predict_golden):
+    # F1 of all three types of dialogue acts
+    TP, FP, FN = 0, 0, 0
+    for item in predict_golden:
+        for da_type in ['non-categorical', 'categorical', 'binary']:
+            if da_type not in item['predict']:
+                assert da_type not in item['golden']
+                continue
+            if da_type == 'binary':
+                predicts = [(x['intent'], x['domain'], x['slot']) for x in item['predict'][da_type]]
+                labels = [(x['intent'], x['domain'], x['slot']) for x in item['golden'][da_type]]
+            else:
+                predicts = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in item['predict'][da_type]]
+                labels = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in item['golden'][da_type]]
+            
+            for ele in predicts:
+                if ele in labels:
+                    TP += 1
+                else:
+                    FP += 1
+            for ele in labels:
+                if ele not in predicts:
+                    FN += 1
+    # print(TP, FP, FN)
+    precision = 1.0 * TP / (TP + FP) if TP + FP else 0.
+    recall = 1.0 * TP / (TP + FN) if TP + FN else 0.
+    F1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0.
+    return precision, recall, F1
+
+
+def tag2triples(word_seq, tag_seq):
+    word_seq = word_seq[:len(tag_seq)]
+    assert len(word_seq)==len(tag_seq)
+    triples = []
+    i = 0
+    while i < len(tag_seq):
+        tag = eval(tag_seq[i])
+        if tag[-1] == 'B':
+            intent, domain, slot = tag[0], tag[1], tag[2]
+            value = word_seq[i]
+            j = i + 1
+            while j < len(tag_seq):
+                next_tag = eval(tag_seq[j])
+                if next_tag[-1] == 'I' and next_tag[:-1] == tag[:-1]:
+                    value += ' ' + word_seq[j]
+                    i += 1
+                    j += 1
+                else:
+                    break
+            triples.append([intent, domain, slot, value])
+        i += 1
+    return triples
+
+
+def recover_intent(dataloader, intent_logits, tag_logits, tag_mask_tensor, ori_word_seq, new2ori):
+    # tag_logits = [sequence_length, tag_dim]
+    # intent_logits = [intent_dim]
+    # tag_mask_tensor = [sequence_length]
+    # new2ori = {(new_idx:old_idx),...} (after removing [CLS] and [SEP]
+    max_seq_len = tag_logits.size(0)
+    dialogue_acts = {
+        "categorical": [],
+        "non-categorical": [],
+        "binary": []
+    }
+    # for categorical & binary dialogue acts
+    for j in range(dataloader.intent_dim):
+        if intent_logits[j] > 0:
+            intent = eval(dataloader.id2intent[j])
+            if len(intent) == 3:
+                dialogue_acts['binary'].append({
+                    'intent': intent[0],
+                    'domain': intent[1],
+                    'slot': intent[2]
+                })
+            else:
+                assert len(intent) == 4
+                dialogue_acts['categorical'].append({
+                    'intent': intent[0],
+                    'domain': intent[1],
+                    'slot': intent[2],
+                    'value': intent[3]
+                })
+    # for non-categorical dialogues acts
+    tags = []
+    for j in range(1, max_seq_len-1):
+        if tag_mask_tensor[j] == 1:
+            value, tag_id = torch.max(tag_logits[j], dim=-1)
+            tags.append(dataloader.id2tag[tag_id.item()])
+    recover_tags = []
+    for i, tag in enumerate(tags):
+        if new2ori[i] >= len(recover_tags):
+            recover_tags.append(tag)
+    ori_word_seq = ori_word_seq[:len(recover_tags)]
+    tag_intent = tag2triples(ori_word_seq, recover_tags)
+    for intent in tag_intent:
+        dialogue_acts['non-categorical'].append({
+            'intent': intent[0],
+            'domain': intent[1],
+            'slot': intent[2],
+            'value': intent[3]
+        })
+    return dialogue_acts
diff --git a/convlab2/nlu/jointBERT/unified_datasets/preprocess.py b/convlab2/nlu/jointBERT/unified_datasets/preprocess.py
new file mode 100755
index 0000000000000000000000000000000000000000..055e0724ab4c8d5d8f467a6d11354426974da3d8
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/preprocess.py
@@ -0,0 +1,93 @@
+import json
+import os
+from collections import Counter
+from convlab2.util import load_dataset, load_ontology, load_nlu_data
+from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
+from tqdm import tqdm
+
+
+def preprocess(dataset_name, speaker, save_dir, context_window_size):
+    dataset = load_dataset(dataset_name)
+    data_by_split = load_nlu_data(dataset, speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)
+    data_dir = os.path.join(save_dir, dataset_name, speaker, f'context_window_size_{context_window_size}')
+    os.makedirs(data_dir, exist_ok=True)
+
+    sent_tokenizer = PunktSentenceTokenizer()
+    word_tokenizer = TreebankWordTokenizer()
+    
+    processed_data = {}
+    all_tags = set([str(('O',))])
+    all_intents = Counter()
+    for data_split, data in data_by_split.items():
+        if data_split == 'validation':
+            data_split = 'val'
+        processed_data[data_split] = []
+        for sample in tqdm(data, desc=f'{data_split} samples'):
+
+            utterance = sample['utterance']
+
+            sentences = sent_tokenizer.tokenize(utterance)
+            sent_spans = sent_tokenizer.span_tokenize(utterance)
+            tokens = [token for sent in sentences for token in word_tokenizer.tokenize(sent)]
+            token_spans = [(sent_span[0]+token_span[0], sent_span[0]+token_span[1]) for sent, sent_span in zip(sentences, sent_spans) for token_span in word_tokenizer.span_tokenize(sent)]
+            tags = [str(('O',))] * len(tokens)
+            for da in sample['dialogue_acts']['non-categorical']:
+                if 'start' not in da:
+                    # skip da that doesn't have span annotation
+                    continue
+                char_start = da['start']
+                char_end = da['end']
+                word_start, word_end = -1, -1
+                for i, token_span in enumerate(token_spans):
+                    if char_start == token_span[0]:
+                        word_start = i
+                    if char_end == token_span[1]:
+                        word_end = i + 1
+                if word_start == -1 and word_end == -1:
+                    # char span does not match word, maybe there is an error in the annotation, skip
+                    print('char span does not match word, skipping')
+                    print('\t', 'utteance:', utterance)
+                    print('\t', 'value:', utterance[char_start: char_end])
+                    print('\t', 'da:', da, '\n')
+                    continue
+                intent, domain, slot = da['intent'], da['domain'], da['slot']
+                all_tags.add(str((intent, domain, slot, 'B')))
+                all_tags.add(str((intent, domain, slot, 'I')))
+                tags[word_start] = str((intent, domain, slot, 'B'))
+                for i in range(word_start+1, word_end):
+                    tags[i] = str((intent, domain, slot, 'I'))
+
+            intents = []
+            for da in sample['dialogue_acts']['categorical']:
+                intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da['value'].strip().lower()
+                intent = str((intent, domain, slot, value))
+                intents.append(intent)
+                all_intents[intent] += 1
+            for da in sample['dialogue_acts']['binary']:
+                intent, domain, slot = da['intent'], da['domain'], da['slot']
+                intent = str((intent, domain, slot))
+                intents.append(intent)
+                all_intents[intent] += 1
+            context = []
+            if context_window_size > 0:
+                context = [s['utterance'] for s in sample['context']]
+            processed_data[data_split].append([tokens, tags, intents, sample['dialogue_acts'], context])
+        json.dump(processed_data[data_split], open(os.path.join(data_dir, '{}_data.json'.format(data_split)), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+
+    # filter out intents that occur only once to get intent vocabulary. however, these intents are still in the data
+    all_intents = {x: count for x, count in all_intents.items() if count > 1}
+    print('sentence label num:', len(all_intents))
+    print('tag num:', len(all_tags))
+    json.dump(sorted(all_intents), open(os.path.join(data_dir, 'intent_vocab.json'), 'w'), indent=2)
+    json.dump(sorted(all_tags), open(os.path.join(data_dir, 'tag_vocab.json'), 'w'), indent=2)
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="create nlu data for bertnlu training")
+    parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset')
+    parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
+    parser.add_argument('--save_dir', metavar='save_directory', type=str, default='data', help='directory to save the data, save_dir/$dataset_name/$speaker')
+    parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
+    args = parser.parse_args()
+    print(args)
+    preprocess(args.dataset, args.speaker, args.save_dir, args.context_window_size)
diff --git a/setup.py b/setup.py
index 2e426a70fc9a9e6c74bbd1f5a5ae0e4f7cffc8d4..62cbf1bd880e55e5dc13b26322dcf67ce6b1b73b 100755
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,7 @@ setup(
                 'Topic :: Scientific/Engineering :: Artificial Intelligence',
     ],
     install_requires=[
+        'matplotlib',
         'tabulate',
         'python-Levenshtein',
         'requests',