diff --git a/convlab2/nlu/evaluate_unified_datasets.py b/convlab2/nlu/evaluate_unified_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb244e34918c6aeaed468d6a67683c8e4e5306b4
--- /dev/null
+++ b/convlab2/nlu/evaluate_unified_datasets.py
@@ -0,0 +1,51 @@
+import json
+from pprint import pprint
+
+
+def evaluate(predict_result):
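+    """Compute precision, recall and F1 of dialogue act prediction, overall and per type (binary/categorical/non-categorical)."""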
+    predict_result = json.load(open(predict_result))
+
+    metrics = {x: {'TP':0, 'FP':0, 'FN':0} for x in ['overall', 'binary', 'categorical', 'non-categorical']}
+
+    for sample in predict_result:
+        for da_type in ['binary', 'categorical', 'non-categorical']:
+            if da_type == 'binary':
+                predicts = [(x['intent'], x['domain'], x['slot']) for x in sample['predictions']['dialogue_acts'][da_type]]
+                labels = [(x['intent'], x['domain'], x['slot']) for x in sample['dialogue_acts'][da_type]]
+            else:
+                predicts = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in sample['predictions']['dialogue_acts'][da_type]]
+                labels = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in sample['dialogue_acts'][da_type]]
+            for ele in predicts:
+                if ele in labels:
+                    metrics['overall']['TP'] += 1
+                    metrics[da_type]['TP'] += 1
+                else:
+                    metrics['overall']['FP'] += 1
+                    metrics[da_type]['FP'] += 1
+            for ele in labels:
+                if ele not in predicts:
+                    metrics['overall']['FN'] += 1
+                    metrics[da_type]['FN'] += 1
+    
+    for metric in metrics:
+        TP = metrics[metric].pop('TP')
+        FP = metrics[metric].pop('FP')
+        FN = metrics[metric].pop('FN')
+        precision = 1.0 * TP / (TP + FP) if TP + FP else 0.
+        recall = 1.0 * TP / (TP + FN) if TP + FN else 0.
+        f1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0.
+        metrics[metric]['precision'] = precision
+        metrics[metric]['recall'] = recall
+        metrics[metric]['f1'] = f1
+
+    return metrics
+
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="calculate NLU metrics for unified datasets")
+    parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the prediction file in the unified data format')
+    args = parser.parse_args()
+    print(args)
+    metrics = evaluate(args.predict_result)
+    pprint(metrics)
diff --git a/convlab2/nlu/jointBERT/README.md b/convlab2/nlu/jointBERT/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..c9756d3c1ebdd42e975bb86d32a532b066a29048
--- /dev/null
+++ b/convlab2/nlu/jointBERT/README.md
@@ -0,0 +1,57 @@
+# BERTNLU
+
+On top of pre-trained BERT, BERTNLU uses one MLP for slot tagging and another MLP for intent classification. All parameters are fine-tuned to learn the two tasks jointly.
+
+Dialog acts are split into two groups, depending on whether the values are in the utterances:
+
+- For dialogue acts whose values are in the utterances, we use **slot tagging** to extract the values. For example, for the utterance `"Find me a cheap hotel"`, the dialog act is `{intent=Inform, domain=hotel, slot=price, value=cheap}` and the corresponding BIO tag sequence is `["O", "O", "O", "B-inform-hotel-price", "O"]`. An MLP classifier takes each token's representation from BERT and outputs its tag.
+- For dialogue acts whose values may not be present in the utterances, we treat them as **intents** of the utterances. Another MLP takes the `[CLS]` embedding of an utterance as input and performs binary classification for each intent independently. Since some intents are rare, we empirically set the weight of positive samples to $\lg(\frac{\# \ negative\ samples}{\# \ positive\ samples})$ for each intent (see the sketch below).
+
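+A minimal sketch of this weighting (illustrative only; the helper below is not part of the codebase):
+
+```python
+import math
+
+def positive_sample_weight(n_positive: int, n_negative: int) -> float:
+    """Empirical positive-class weight for one intent: lg(#negative / #positive)."""
+    return math.log10(n_negative / n_positive)
+```
+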
+The model can also incorporate context information by setting `context=true` in the config file. The context utterances are concatenated (separated by `[SEP]`) and fed into BERT. The resulting `[CLS]` embedding serves as the context representation and is concatenated to all token representations in the target utterance right before the slot and intent classifiers.
+
+
+## Usage
+
+Follow the instructions under each dataset's directory to prepare data and a model config file for training and evaluation.
+
+#### Train a model
+
+```sh
+$ python train.py --config_path path_to_a_config_file
+```
+
+The model (`pytorch_model.bin`) will be saved under the `output_dir` of the config file.
+
+#### Test a model
+
+```sh
+$ python test.py --config_path path_to_a_config_file
+```
+
+The result (`output.json`) will be saved under the `output_dir` of the config file. Also, it will be zipped to the `zipped_model_path` given in the config file.
+
+
+## References
+
+```
+@inproceedings{devlin2019bert,
+  title={BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding},
+  author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina},
+  booktitle={Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)},
+  pages={4171--4186},
+  year={2019}
+}
+
+@inproceedings{zhu-etal-2020-convlab,
+    title = "{C}onv{L}ab-2: An Open-Source Toolkit for Building, Evaluating, and Diagnosing Dialogue Systems",
+    author = "Zhu, Qi and Zhang, Zheng and Fang, Yan and Li, Xiang and Takanobu, Ryuichi and Li, Jinchao and Peng, Baolin and Gao, Jianfeng and Zhu, Xiaoyan and Huang, Minlie",
+    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
+    month = jul,
+    year = "2020",
+    address = "Online",
+    publisher = "Association for Computational Linguistics",
+    url = "https://aclanthology.org/2020.acl-demos.19",
+    doi = "10.18653/v1/2020.acl-demos.19",
+    pages = "142--149"
+}
+```
\ No newline at end of file
diff --git a/convlab2/nlu/jointBERT/dataloader.py b/convlab2/nlu/jointBERT/dataloader.py
index 38fc24ea0fdc288410a65146716996432ebd896c..d1fcbc7a4864211a9956cedacc7c3479c195733a 100755
--- a/convlab2/nlu/jointBERT/dataloader.py
+++ b/convlab2/nlu/jointBERT/dataloader.py
@@ -39,13 +39,13 @@ class Dataloader:
         for d in self.data[data_key]:
             max_sen_len = max(max_sen_len, len(d[0]))
             sen_len.append(len(d[0]))
-            # d = (tokens, tags, intents, da2triples(turn["dialog_act"], context(list of str))
+            # d = (tokens, tags, intents, original dialog acts, context(list of str))
             if cut_sen_len > 0:
                 d[0] = d[0][:cut_sen_len]
                 d[1] = d[1][:cut_sen_len]
                 d[4] = [' '.join(s.split()[:cut_sen_len]) for s in d[4]]
 
-            d[4] = self.tokenizer.encode('[CLS] ' + ' [SEP] '.join(d[4]))
+            d[4] = self.tokenizer.encode(' [SEP] '.join(d[4]))
             max_context_len = max(max_context_len, len(d[4]))
             context_len.append(len(d[4]))
 
diff --git a/convlab2/nlu/jointBERT/test.py b/convlab2/nlu/jointBERT/test.py
index 7856e5ecc0c1be7471f9497339a7f9fbc2f3f9ec..2e1e1b51940c5d899a833fc9fba3a7f6aa257e7b 100755
--- a/convlab2/nlu/jointBERT/test.py
+++ b/convlab2/nlu/jointBERT/test.py
@@ -29,7 +29,11 @@ if __name__ == '__main__':
 
     set_seed(config['seed'])
 
-    if 'multiwoz' in data_dir:
+    if 'unified_datasets' in data_dir:
+        dataset_name = config['dataset_name']
+        print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20)
+        from convlab2.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent
+    elif 'multiwoz' in data_dir:
         print('-'*20 + 'dataset:multiwoz' + '-'*20)
         from convlab2.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent
     elif 'camrest' in data_dir:
@@ -90,14 +94,25 @@ if __name__ == '__main__':
                 'predict': predicts,
                 'golden': labels
             })
-            predict_golden['slot'].append({
-                'predict': [x for x in predicts if is_slot_da(x)],
-                'golden': [x for x in labels if is_slot_da(x)]
-            })
-            predict_golden['intent'].append({
-                'predict': [x for x in predicts if not is_slot_da(x)],
-                'golden': [x for x in labels if not is_slot_da(x)]
-            })
+            if isinstance(predicts, dict):
+                predict_golden['slot'].append({
+                    'predict': {k:v for k, v in predicts.items() if is_slot_da(k)},
+                    'golden': {k:v for k, v in labels.items() if is_slot_da(k)}
+                })
+                predict_golden['intent'].append({
+                    'predict': {k:v for k, v in predicts.items() if not is_slot_da(k)},
+                    'golden': {k:v for k, v in labels.items() if not is_slot_da(k)}
+                })
+            else:
+                assert isinstance(predicts, list)
+                predict_golden['slot'].append({
+                    'predict': [x for x in predicts if is_slot_da(x)],
+                    'golden': [x for x in labels if is_slot_da(x)]
+                })
+                predict_golden['intent'].append({
+                    'predict': [x for x in predicts if not is_slot_da(x)],
+                    'golden': [x for x in labels if not is_slot_da(x)]
+                })
         print('[%d|%d] samples' % (len(predict_golden['overall']), len(dataloader.data[data_key])))
 
     total = len(dataloader.data[data_key])
diff --git a/convlab2/nlu/jointBERT/train.py b/convlab2/nlu/jointBERT/train.py
index a6267b9403dd805ea7537763f1063cbcc03965d1..fad50eda9c3b6676d9ce5b9e00dcc961e14ae4e7 100755
--- a/convlab2/nlu/jointBERT/train.py
+++ b/convlab2/nlu/jointBERT/train.py
@@ -32,7 +32,11 @@ if __name__ == '__main__':
 
     set_seed(config['seed'])
 
-    if 'multiwoz' in data_dir:
+    if 'unified_datasets' in data_dir:
+        dataset_name = config['dataset_name']
+        print('-' * 20 + f'dataset:unified_datasets:{dataset_name}' + '-' * 20)
+        from convlab2.nlu.jointBERT.unified_datasets.postprocess import is_slot_da, calculateF1, recover_intent
+    elif 'multiwoz' in data_dir:
         print('-'*20 + 'dataset:multiwoz' + '-'*20)
         from convlab2.nlu.jointBERT.multiwoz.postprocess import is_slot_da, calculateF1, recover_intent
     elif 'camrest' in data_dir:
@@ -149,14 +153,25 @@ if __name__ == '__main__':
                         'predict': predicts,
                         'golden': labels
                     })
-                    predict_golden['slot'].append({
-                        'predict': [x for x in predicts if is_slot_da(x)],
-                        'golden': [x for x in labels if is_slot_da(x)]
-                    })
-                    predict_golden['intent'].append({
-                        'predict': [x for x in predicts if not is_slot_da(x)],
-                        'golden': [x for x in labels if not is_slot_da(x)]
-                    })
+                    if isinstance(predicts, dict):
+                        predict_golden['slot'].append({
+                            'predict': {k:v for k, v in predicts.items() if is_slot_da(k)},
+                            'golden': {k:v for k, v in labels.items() if is_slot_da(k)}
+                        })
+                        predict_golden['intent'].append({
+                            'predict': {k:v for k, v in predicts.items() if not is_slot_da(k)},
+                            'golden': {k:v for k, v in labels.items() if not is_slot_da(k)}
+                        })
+                    else:
+                        assert isinstance(predicts, list)
+                        predict_golden['slot'].append({
+                            'predict': [x for x in predicts if is_slot_da(x)],
+                            'golden': [x for x in labels if is_slot_da(x)]
+                        })
+                        predict_golden['intent'].append({
+                            'predict': [x for x in predicts if not is_slot_da(x)],
+                            'golden': [x for x in labels if not is_slot_da(x)]
+                        })
 
             for j in range(10):
                 writer.add_text('val_sample_{}'.format(j),
diff --git a/convlab2/nlu/jointBERT/unified_datasets/README.md b/convlab2/nlu/jointBERT/unified_datasets/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..9e9031148ff58ae6354bb8071220395b140eb599
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/README.md
@@ -0,0 +1,46 @@
+# BERTNLU on datasets in unified format
+
+We support training BERTNLU on datasets that are in our unified format.
+
+- For **non-categorical** dialogue acts whose values are in the utterances, we use **slot tagging** to extract the values.
+- For **categorical** and **binary** dialogue acts whose values may not be present in the utterances, we treat them as **intents** of the utterances.
+
+## Usage
+
+#### Preprocess data
+
+```sh
+$ python preprocess.py --dataset dataset_name --speaker {user,system,all} --context_window_size CONTEXT_WINDOW_SIZE --save_dir save_directory
+```
+
+Note that the dataset will be loaded by `convlab2.util.load_dataset(dataset_name)`. If you want to use custom datasets, make sure they follow the unified format and can be loaded using this function.
+The processed data will be saved to the `${save_dir}/${dataset_name}/${speaker}/context_window_size_${context_window_size}` directory.
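+
+For example, to preprocess MultiWOZ 2.1 user utterances without context:
+
+```sh
+$ python preprocess.py -d multiwoz21 -s user -c 0 --save_dir data
+```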
+
+#### Train a model
+
+Prepare a config file and run the training script in the parent directory:
+
+```sh
+$ python train.py --config_path path_to_a_config_file
+```
+
+The model (`pytorch_model.bin`) will be saved under the `output_dir` of the config file. Also, it will be zipped to the `zipped_model_path` given in the config file.
+
+#### Test a model
+
+Run the inference script in the parent directory:
+
+```sh
+$ python test.py --config_path path_to_a_config_file
+```
+
+The result (`output.json`) will be saved under the `output_dir` of the config file.
+
+To generate `predictions.json`, which merges the test data with model predictions, in the same directory as `output.json`:
+```sh
+$ python merge_predict_res.py -d dataset_name -s {user,system,all} -c CONTEXT_WINDOW_SIZE -p path_to_output.json
+```
+
+#### Predict
+
+See `nlu.py` for usage.
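+
+For instance, a minimal sketch mirroring the `__main__` block of `nlu.py` (choose a config file under `configs/`):
+
+```python
+from convlab2.nlu.jointBERT.unified_datasets import BERTNLU
+
+# loads the trained model from the config's output_dir
+nlu = BERTNLU(mode='user', config_file='multiwoz21_user.json')
+print(nlu.predict("I want to leave after 17:15."))
+```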
diff --git a/convlab2/nlu/jointBERT/unified_datasets/__init__.py b/convlab2/nlu/jointBERT/unified_datasets/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..74b08065ca97132195ab56c65cb7fe87ec7b4fca
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/__init__.py
@@ -0,0 +1 @@
+from convlab2.nlu.jointBERT.unified_datasets.nlu import BERTNLU
\ No newline at end of file
diff --git a/convlab2/nlu/jointBERT/unified_datasets/configs/multiwoz21_user.json b/convlab2/nlu/jointBERT/unified_datasets/configs/multiwoz21_user.json
new file mode 100755
index 0000000000000000000000000000000000000000..d6be45577a3662065d36d112ea15de938705e224
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/configs/multiwoz21_user.json
@@ -0,0 +1,27 @@
+{
+  "dataset_name": "multiwoz21",
+  "data_dir": "unified_datasets/data/multiwoz21/user/context_window_size_0",
+  "output_dir": "unified_datasets/output/multiwoz21/user/context_window_size_0",
+  "zipped_model_path": "unified_datasets/output/multiwoz21/user/context_window_size_0/bertnlu_unified_multiwoz_user_context0.zip",
+  "log_dir": "unified_datasets/output/multiwoz21/user/context_window_size_0/log",
+  "DEVICE": "cuda:0",
+  "seed": 2019,
+  "cut_sen_len": 40,
+  "use_bert_tokenizer": true,
+  "context_window_size": 0,
+  "model": {
+    "finetune": true,
+    "context": false,
+    "context_grad": false,
+    "pretrained_weights": "bert-base-uncased",
+    "check_step": 1000,
+    "max_step": 10000,
+    "batch_size": 128,
+    "learning_rate": 1e-4,
+    "adam_epsilon": 1e-8,
+    "warmup_steps": 0,
+    "weight_decay": 0.0,
+    "dropout": 0.1,
+    "hidden_units": 768
+  }
+}
\ No newline at end of file
diff --git a/convlab2/nlu/jointBERT/unified_datasets/configs/multiwoz21_user_context3.json b/convlab2/nlu/jointBERT/unified_datasets/configs/multiwoz21_user_context3.json
new file mode 100755
index 0000000000000000000000000000000000000000..d46f4db6096028e2582bea546b847be028faf184
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/configs/multiwoz21_user_context3.json
@@ -0,0 +1,27 @@
+{
+  "dataset_name": "multiwoz21",
+  "data_dir": "unified_datasets/data/multiwoz21/user/context_window_size_3",
+  "output_dir": "unified_datasets/output/multiwoz21/user/context_window_size_3",
+  "zipped_model_path": "unified_datasets/output/multiwoz21/user/context_window_size_3/bertnlu_unified_multiwoz_user_context3.zip",
+  "log_dir": "unified_datasets/output/multiwoz21/user/context_window_size_3/log",
+  "DEVICE": "cuda:0",
+  "seed": 2019,
+  "cut_sen_len": 40,
+  "use_bert_tokenizer": true,
+  "context_window_size": 3,
+  "model": {
+    "finetune": true,
+    "context": true,
+    "context_grad": true,
+    "pretrained_weights": "bert-base-uncased",
+    "check_step": 1000,
+    "max_step": 10000,
+    "batch_size": 128,
+    "learning_rate": 1e-4,
+    "adam_epsilon": 1e-8,
+    "warmup_steps": 0,
+    "weight_decay": 0.0,
+    "dropout": 0.1,
+    "hidden_units": 1536
+  }
+}
\ No newline at end of file
diff --git a/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py b/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py
new file mode 100755
index 0000000000000000000000000000000000000000..a6be2242357d9eb9b0f5a317c69df63a5013efd3
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/merge_predict_res.py
@@ -0,0 +1,33 @@
+import json
+import os
+from convlab2.util import load_dataset, load_nlu_data
+
+
+def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
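+    """Attach model predictions to the test split and save the merged data as predictions.json in save_dir."""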
+    assert os.path.exists(predict_result)
+    dataset = load_dataset(dataset_name)
+    data = load_nlu_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
+    
+    if save_dir is None:
+        save_dir = os.path.dirname(predict_result)
+    else:
+        os.makedirs(save_dir, exist_ok=True)
+    predict_result = json.load(open(predict_result))
+
+    for sample, prediction in zip(data, predict_result):
+        sample['predictions'] = {'dialogue_acts': prediction['predict']}
+
+    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="merge predict results with original data for unified NLU evaluation")
+    parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset')
+    parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
+    parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: the same directory as predict_result')
+    parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
+    parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated by ../test.py')
+    args = parser.parse_args()
+    print(args)
+    merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result)
diff --git a/convlab2/nlu/jointBERT/unified_datasets/nlu.py b/convlab2/nlu/jointBERT/unified_datasets/nlu.py
new file mode 100755
index 0000000000000000000000000000000000000000..063ea036e4999626ca02452b1c1dc9f38ddb913f
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/nlu.py
@@ -0,0 +1,106 @@
+import logging
+import os
+import json
+import torch
+from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
+import transformers
+from convlab2.nlu.nlu import NLU
+from convlab2.nlu.jointBERT.dataloader import Dataloader
+from convlab2.nlu.jointBERT.jointBERT import JointBERT
+from convlab2.nlu.jointBERT.unified_datasets.postprocess import recover_intent
+from convlab2.util.custom_util import model_downloader
+
+
+class BERTNLU(NLU):
+    def __init__(self, mode, config_file, model_file=None):
+        assert mode == 'user' or mode == 'sys' or mode == 'all'
+        self.mode = mode
+        config_file = os.path.join(os.path.dirname(
+            os.path.abspath(__file__)), 'configs/{}'.format(config_file))
+        config = json.load(open(config_file))
+        # fall back to CPU when CUDA is unavailable
+        DEVICE = 'cpu' if not torch.cuda.is_available() else config['DEVICE']
+        root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+        data_dir = os.path.join(root_dir, config['data_dir'])
+        output_dir = os.path.join(root_dir, config['output_dir'])
+
+        assert os.path.exists(os.path.join(data_dir, 'intent_vocab.json')), 'Please run preprocess first'
+
+        intent_vocab = json.load(open(os.path.join(data_dir, 'intent_vocab.json')))
+        tag_vocab = json.load(open(os.path.join(data_dir, 'tag_vocab.json')))
+        dataloader = Dataloader(intent_vocab=intent_vocab, tag_vocab=tag_vocab,
+                                pretrained_weights=config['model']['pretrained_weights'])
+
+        logging.info('intent num:' +  str(len(intent_vocab)))
+        logging.info('tag num:' + str(len(tag_vocab)))
+
+        if not os.path.exists(output_dir):
+            model_downloader(root_dir, model_file)
+        model = JointBERT(config['model'], DEVICE, dataloader.tag_dim, dataloader.intent_dim)
+
+        state_dict = torch.load(os.path.join(output_dir, 'pytorch_model.bin'), DEVICE)
+        if int(transformers.__version__.split('.')[0]) >= 3 and 'bert.embeddings.position_ids' not in state_dict:
+            state_dict['bert.embeddings.position_ids'] = torch.tensor(range(512)).reshape(1, -1).to(DEVICE)
+
+        model.load_state_dict(state_dict)
+        model.to(DEVICE)
+        model.eval()
+
+        self.model = model
+        self.use_context = config['model']['context']
+        self.context_window_size = config['context_window_size']
+        self.dataloader = dataloader
+        self.sent_tokenizer = PunktSentenceTokenizer()
+        self.word_tokenizer = TreebankWordTokenizer()
+        logging.info("BERTNLU loaded")
+
+    def predict(self, utterance, context=list()):
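+        # tokenize as in preprocessing: sentence-split first, then Treebank word tokenization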
+        sentences = self.sent_tokenizer.tokenize(utterance)
+        ori_word_seq = [token for sent in sentences for token in self.word_tokenizer.tokenize(sent)]
+        ori_tag_seq = [str(('O',))] * len(ori_word_seq)
+        if self.use_context:
+            if len(context) > 0 and type(context[0]) is list and len(context[0]) > 1:
+                context = [item[1] for item in context]
+            context_seq = self.dataloader.tokenizer.encode(' [SEP] '.join(context[-self.context_window_size:]))
+            context_seq = context_seq[:510]
+        else:
+            context_seq = self.dataloader.tokenizer.encode('')
+        intents = []
+        da = {}
+
+        word_seq, tag_seq, new2ori = self.dataloader.bert_tokenize(ori_word_seq, ori_tag_seq)
+        word_seq = word_seq[:510]
+        tag_seq = tag_seq[:510]
+        batch_data = [[ori_word_seq, ori_tag_seq, intents, da, context_seq,
+                       new2ori, word_seq, self.dataloader.seq_tag2id(tag_seq), self.dataloader.seq_intent2id(intents)]]
+
+        pad_batch = self.dataloader.pad_batch(batch_data)
+        pad_batch = tuple(t.to(self.model.device) for t in pad_batch)
+        word_seq_tensor, tag_seq_tensor, intent_tensor, word_mask_tensor, tag_mask_tensor, context_seq_tensor, context_mask_tensor = pad_batch
+        slot_logits, intent_logits = self.model.forward(word_seq_tensor, word_mask_tensor,
+                                                        context_seq_tensor=context_seq_tensor,
+                                                        context_mask_tensor=context_mask_tensor)
+        das = recover_intent(self.dataloader, intent_logits[0], slot_logits[0], tag_mask_tensor[0],
+                             batch_data[0][0], batch_data[0][-4])
+        dialog_act = []
+        for da_type in das:
+            for da in das[da_type]:
+                dialog_act.append([da['intent'], da['domain'], da['slot'], da.get('value','')])
+        return dialog_act
+
+
+if __name__ == '__main__':
+    texts = [
+        "I would like a taxi from Saint John's college to Pizza Hut Fen Ditton.",
+        "I want to leave after 17:15.",
+        "Thank you for all the help! I appreciate it.",
+        "Please find a restaurant called Nusha.",
+        "What is the train id, please? ",
+        "I don't care about the price and it doesn't need to have free parking."
+    ]
+    nlu = BERTNLU(mode='user', config_file='multiwoz21_user.json')
+    for text in texts:
+        print(text)
+        print(nlu.predict(text))
+        print()
diff --git a/convlab2/nlu/jointBERT/unified_datasets/postprocess.py b/convlab2/nlu/jointBERT/unified_datasets/postprocess.py
new file mode 100755
index 0000000000000000000000000000000000000000..982b4a92a6df10832ac313acd5984af63567dd1d
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/postprocess.py
@@ -0,0 +1,111 @@
+import re
+import torch
+
+
+def is_slot_da(da_type):
+    return da_type == 'non-categorical'
+
+
+def calculateF1(predict_golden):
+    # F1 of all three types of dialogue acts
+    TP, FP, FN = 0, 0, 0
+    for item in predict_golden:
+        for da_type in ['non-categorical', 'categorical', 'binary']:
+            if da_type not in item['predict']:
+                assert da_type not in item['golden']
+                continue
+            if da_type == 'binary':
+                predicts = [(x['intent'], x['domain'], x['slot']) for x in item['predict'][da_type]]
+                labels = [(x['intent'], x['domain'], x['slot']) for x in item['golden'][da_type]]
+            else:
+                predicts = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in item['predict'][da_type]]
+                labels = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in item['golden'][da_type]]
+            
+            for ele in predicts:
+                if ele in labels:
+                    TP += 1
+                else:
+                    FP += 1
+            for ele in labels:
+                if ele not in predicts:
+                    FN += 1
+    # print(TP, FP, FN)
+    precision = 1.0 * TP / (TP + FP) if TP + FP else 0.
+    recall = 1.0 * TP / (TP + FN) if TP + FN else 0.
+    F1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0.
+    return precision, recall, F1
+
+
+def tag2triples(word_seq, tag_seq):
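+    # decode a BIO tag sequence into [intent, domain, slot, value] triples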
+    word_seq = word_seq[:len(tag_seq)]
+    assert len(word_seq)==len(tag_seq)
+    triples = []
+    i = 0
+    while i < len(tag_seq):
+        tag = eval(tag_seq[i])
+        if tag[-1] == 'B':
+            intent, domain, slot = tag[0], tag[1], tag[2]
+            value = word_seq[i]
+            j = i + 1
+            while j < len(tag_seq):
+                next_tag = eval(tag_seq[j])
+                if next_tag[-1] == 'I' and next_tag[:-1] == tag[:-1]:
+                    value += ' ' + word_seq[j]
+                    i += 1
+                    j += 1
+                else:
+                    break
+            triples.append([intent, domain, slot, value])
+        i += 1
+    return triples
+
+
+def recover_intent(dataloader, intent_logits, tag_logits, tag_mask_tensor, ori_word_seq, new2ori):
+    # tag_logits = [sequence_length, tag_dim]
+    # intent_logits = [intent_dim]
+    # tag_mask_tensor = [sequence_length]
+    # new2ori = {new_idx: ori_idx, ...} (token index mapping after removing [CLS] and [SEP])
+    max_seq_len = tag_logits.size(0)
+    dialogue_acts = {
+        "categorical": [],
+        "non-categorical": [],
+        "binary": []
+    }
+    # for categorical & binary dialogue acts
+    for j in range(dataloader.intent_dim):
+        if intent_logits[j] > 0:
+            intent = eval(dataloader.id2intent[j])
+            if len(intent) == 3:
+                dialogue_acts['binary'].append({
+                    'intent': intent[0],
+                    'domain': intent[1],
+                    'slot': intent[2]
+                })
+            else:
+                assert len(intent) == 4
+                dialogue_acts['categorical'].append({
+                    'intent': intent[0],
+                    'domain': intent[1],
+                    'slot': intent[2],
+                    'value': intent[3]
+                })
+    # for non-categorical dialogue acts
+    tags = []
+    for j in range(1, max_seq_len-1):
+        if tag_mask_tensor[j] == 1:
+            value, tag_id = torch.max(tag_logits[j], dim=-1)
+            tags.append(dataloader.id2tag[tag_id.item()])
+    recover_tags = []
+    for i, tag in enumerate(tags):
+        if new2ori[i] >= len(recover_tags):
+            recover_tags.append(tag)
+    ori_word_seq = ori_word_seq[:len(recover_tags)]
+    tag_intent = tag2triples(ori_word_seq, recover_tags)
+    for intent in tag_intent:
+        dialogue_acts['non-categorical'].append({
+            'intent': intent[0],
+            'domain': intent[1],
+            'slot': intent[2],
+            'value': intent[3]
+        })
+    return dialogue_acts
diff --git a/convlab2/nlu/jointBERT/unified_datasets/preprocess.py b/convlab2/nlu/jointBERT/unified_datasets/preprocess.py
new file mode 100755
index 0000000000000000000000000000000000000000..ca942b38f039abc449dcc9c80ba1ab352aac2483
--- /dev/null
+++ b/convlab2/nlu/jointBERT/unified_datasets/preprocess.py
@@ -0,0 +1,93 @@
+import json
+import os
+from collections import Counter
+from convlab2.util import load_dataset, load_nlu_data
+from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
+from tqdm import tqdm
+
+
+def preprocess(dataset_name, speaker, save_dir, context_window_size):
+    dataset = load_dataset(dataset_name)
+    data_by_split = load_nlu_data(dataset, speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)
+    data_dir = os.path.join(save_dir, dataset_name, speaker, f'context_window_size_{context_window_size}')
+    os.makedirs(data_dir, exist_ok=True)
+
+    sent_tokenizer = PunktSentenceTokenizer()
+    word_tokenizer = TreebankWordTokenizer()
+    
+    processed_data = {}
+    all_tags = set([str(('O',))])
+    all_intents = Counter()
+    for data_split, data in data_by_split.items():
+        if data_split == 'validation':
+            data_split = 'val'
+        processed_data[data_split] = []
+        for sample in tqdm(data, desc=f'{data_split} samples'):
+
+            utterance = sample['utterance']
+
+            sentences = sent_tokenizer.tokenize(utterance)
+            sent_spans = sent_tokenizer.span_tokenize(utterance)
+            tokens = [token for sent in sentences for token in word_tokenizer.tokenize(sent)]
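+            # character span of each token in the original utterance, used to align span annotations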
+            token_spans = [(sent_span[0]+token_span[0], sent_span[0]+token_span[1]) for sent, sent_span in zip(sentences, sent_spans) for token_span in word_tokenizer.span_tokenize(sent)]
+            tags = [str(('O',))] * len(tokens)
+            for da in sample['dialogue_acts']['non-categorical']:
+                if 'start' not in da:
+                    # skip da that doesn't have span annotation
+                    continue
+                char_start = da['start']
+                char_end = da['end']
+                word_start, word_end = -1, -1
+                for i, token_span in enumerate(token_spans):
+                    if char_start == token_span[0]:
+                        word_start = i
+                    if char_end == token_span[1]:
+                        word_end = i + 1
+                if word_start == -1 and word_end == -1:
+                    # char span does not match word, maybe there is an error in the annotation, skip
+                    print('char span does not match word, skipping')
+                print('\t', 'utterance:', utterance)
+                    print('\t', 'value:', utterance[char_start: char_end])
+                    print('\t', 'da:', da, '\n')
+                    continue
+                intent, domain, slot = da['intent'], da['domain'], da['slot']
+                all_tags.add(str((intent, domain, slot, 'B')))
+                all_tags.add(str((intent, domain, slot, 'I')))
+                tags[word_start] = str((intent, domain, slot, 'B'))
+                for i in range(word_start+1, word_end):
+                    tags[i] = str((intent, domain, slot, 'I'))
+
+            intents = []
+            for da in sample['dialogue_acts']['categorical']:
+                intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da['value'].strip().lower()
+                intent = str((intent, domain, slot, value))
+                intents.append(intent)
+                all_intents[intent] += 1
+            for da in sample['dialogue_acts']['binary']:
+                intent, domain, slot = da['intent'], da['domain'], da['slot']
+                intent = str((intent, domain, slot))
+                intents.append(intent)
+                all_intents[intent] += 1
+            context = []
+            if context_window_size > 0:
+                context = [s['utterance'] for s in sample['context']]
+            processed_data[data_split].append([tokens, tags, intents, sample['dialogue_acts'], context])
+        json.dump(processed_data[data_split], open(os.path.join(data_dir, '{}_data.json'.format(data_split)), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+
+    # filter out intents that occur only once to build the intent vocabulary; these rare intents remain in the data
+    all_intents = {x: count for x, count in all_intents.items() if count > 1}
+    print('sentence label num:', len(all_intents))
+    print('tag num:', len(all_tags))
+    json.dump(sorted(all_intents), open(os.path.join(data_dir, 'intent_vocab.json'), 'w'), indent=2)
+    json.dump(sorted(all_tags), open(os.path.join(data_dir, 'tag_vocab.json'), 'w'), indent=2)
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="create nlu data for bertnlu training")
+    parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset')
+    parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
+    parser.add_argument('--save_dir', metavar='save_directory', type=str, default='data', help='directory to save the data, save_dir/$dataset_name/$speaker')
+    parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
+    args = parser.parse_args()
+    print(args)
+    preprocess(args.dataset, args.speaker, args.save_dir, args.context_window_size)
diff --git a/convlab2/nlu/milu/README.md b/convlab2/nlu/milu/README.md
index b1c6bf5a130215e89c4b6145a73c720fca86be18..bbd54d671900f26373ef369c3c6f049dcf74ac23 100755
--- a/convlab2/nlu/milu/README.md
+++ b/convlab2/nlu/milu/README.md
@@ -5,16 +5,44 @@ MILU is a joint neural model that allows you to simultaneously predict multiple
 ## Example usage
 We based our implementation on the [AllenNLP library](https://github.com/allenai/allennlp). For an introduction to this library, you should check [these tutorials](https://allennlp.org/tutorials).
 
+To use this model, you additionally need to install `overrides==4.1.2` and `allennlp==0.9.0`, and use Python `>=3.6,<=3.8`.
+
+### On MultiWOZ dataset
+
 ```bash
-$ PYTHONPATH=../../.. python train.py multiwoz/configs/[base|context3].jsonnet -s serialization_dir
-$ PYTHONPATH=../../.. python evaluate.py serialization_dir/model.tar.gz {test_file} --cuda-device {CUDA_DEVICE}
+$ python train.py multiwoz/configs/[base|context3].jsonnet -s serialization_dir
+$ python evaluate.py serialization_dir/model.tar.gz {test_file} --cuda-device {CUDA_DEVICE}
 ```
 
 If you want to perform end-to-end evaluation, you can include the trained model by adding the model path (serialization_dir/model.tar.gz) to your ConvLab spec file.
 
-## Data
+#### Data
 We use the multiwoz data (data/multiwoz/[train|val|test].json.zip).
 
+### MILU on datasets in unified format
+We support training MILU on datasets that are in our unified format.
+
+- For **non-categorical** dialogue acts whose values are in the utterances, we use **slot tagging** to extract the values.
+- For **categorical** and **binary** dialogue acts whose values may not be presented in the utterances, we treat them as **intents** of the utterances.
+
+Taking MultiWOZ 2.1 (unified format) as an example:
+```bash
+$ python train.py unified_datasets/configs/multiwoz21_user_context3.jsonnet -s serialization_dir
+$ python evaluate.py serialization_dir/model.tar.gz test --cuda-device {CUDA_DEVICE} --output_file output/multiwoz21_user/output.json
+
+# to generate output/multiwoz21_user/predictions.json that merges test data and model predictions.
+$ python unified_datasets/merge_predict_res.py -d multiwoz21 -s user -p output/multiwoz21_user/output.json
+```
+Note that the config file is different from the above. You should set:
+- `"use_unified_datasets": true` in `dataset_reader` and `model`
+- `"dataset_name": "multiwoz21"` in `dataset_reader`
+- `"train_data_path": "train"`
+- `"validation_data_path": "validation"`
+- `"test_data_path": "test"`
+
+## Predict
+See `nlu.py` under the `multiwoz` and `unified_datasets` directories.
+
 ## References
 ```
 @inproceedings{lee2019convlab,
diff --git a/convlab2/nlu/milu/dai_f1_measure.py b/convlab2/nlu/milu/dai_f1_measure.py
index f82d9acd627192d3239b24bc3456614aac85a5e7..4bb7ec868259dd4c09351d1edc348bb50caa8542 100755
--- a/convlab2/nlu/milu/dai_f1_measure.py
+++ b/convlab2/nlu/milu/dai_f1_measure.py
@@ -9,7 +9,7 @@ from allennlp.training.metrics.metric import Metric
 class DialogActItemF1Measure(Metric):
     """
     """
-    def __init__(self) -> None:
+    def __init__(self, use_unified_datasets) -> None:
         """
         Parameters
         ----------
@@ -18,6 +18,7 @@ class DialogActItemF1Measure(Metric):
         self._true_positives = 0 
         self._false_positives = 0 
         self._false_negatives = 0 
+        self.use_unified_datasets = use_unified_datasets
 
 
     def __call__(self,
@@ -32,17 +33,36 @@ class DialogActItemF1Measure(Metric):
             A tensor of integer class label of shape (batch_size, sequence_length). It must be the same
             shape as the ``predictions`` tensor without the ``num_classes`` dimension.
         """
-        for prediction, gold_label in zip(predictions, gold_labels): 
-            for dat in prediction:
-                for sv in prediction[dat]:
-                    if dat not in gold_label or sv not in gold_label[dat]:
-                        self._false_positives += 1
+        if self.use_unified_datasets:
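+            # unified format: compare (intent, domain, slot[, value]) tuples per dialogue act type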
+            for prediction, gold_label in zip(predictions, gold_labels): 
+                for da_type in ['non-categorical', 'categorical', 'binary']:
+                    if da_type == 'binary':
+                        predicts = [(x['intent'], x['domain'], x['slot']) for x in prediction[da_type]]
+                        labels = [(x['intent'], x['domain'], x['slot']) for x in gold_label[da_type]]
                     else:
-                        self._true_positives += 1
-            for dat in gold_label:
-                for sv in gold_label[dat]:
-                    if dat not in prediction or sv not in prediction[dat]:
-                        self._false_negatives += 1
+                        predicts = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in prediction[da_type]]
+                        labels = [(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in gold_label[da_type]]
+                    
+                    for ele in predicts:
+                        if ele in labels:
+                            self._true_positives += 1
+                        else:
+                            self._false_positives += 1
+                    for ele in labels:
+                        if ele not in predicts:
+                            self._false_negatives += 1
+        else:
+            for prediction, gold_label in zip(predictions, gold_labels): 
+                for dat in prediction:
+                    for sv in prediction[dat]:
+                        if dat not in gold_label or sv not in gold_label[dat]:
+                            self._false_positives += 1
+                        else:
+                            self._true_positives += 1
+                for dat in gold_label:
+                    for sv in gold_label[dat]:
+                        if dat not in prediction or sv not in prediction[dat]:
+                            self._false_negatives += 1
 
 
     def get_metric(self, reset: bool = False):
diff --git a/convlab2/nlu/milu/dataset_reader.py b/convlab2/nlu/milu/dataset_reader.py
index 5e00af04e7fe6c13ddbb21d60f22c51d5cbbc106..35f71903ab4a269b0f9e5d3cd208d78e48278349 100755
--- a/convlab2/nlu/milu/dataset_reader.py
+++ b/convlab2/nlu/milu/dataset_reader.py
@@ -13,6 +13,8 @@ from allennlp.data.fields import TextField, SequenceLabelField, MultiLabelField,
 from allennlp.data.instance import Instance
 from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
 from allennlp.data.tokenizers import Token
+from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
+from convlab2.util import load_dataset, load_nlu_data
 from overrides import overrides
 
 from convlab2.util.file_util import cached_path
@@ -45,6 +47,8 @@ class MILUDatasetReader(DatasetReader):
     def __init__(self,
                  context_size: int = 0,
                  agent: str = None,
+                 use_unified_datasets: bool = False,
+                 dataset_name: str = None,
                  random_context_size: bool = True,
                  token_delimiter: str = None,
                  token_indexers: Dict[str, TokenIndexer] = None,
@@ -52,81 +56,150 @@ class MILUDatasetReader(DatasetReader):
         super().__init__(lazy)
         self._context_size = context_size
         self._agent = agent 
+        self.use_unified_datasets = use_unified_datasets
+        if self.use_unified_datasets:
+            self._dataset_name = dataset_name
+            self._dataset = load_dataset(self._dataset_name)
         self._random_context_size = random_context_size
         self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
         self._token_delimiter = token_delimiter
+        self._sent_tokenizer = PunktSentenceTokenizer()
+        self._word_tokenizer = TreebankWordTokenizer()
 
     @overrides
     def _read(self, file_path):
-        # if `file_path` is a URL, redirect to the cache
-        file_path = cached_path(file_path)
-
-        if file_path.endswith("zip"):
-            archive = zipfile.ZipFile(file_path, "r")
-            data_file = archive.open(os.path.basename(file_path)[:-4])
-        else:
-            data_file = open(file_path, "r")
-
-        logger.info("Reading instances from lines in file at: %s", file_path)
-
-        dialogs = json.load(data_file)
-
-        for dial_name in dialogs:
-            dialog = dialogs[dial_name]["log"]
-            context_tokens_list = []
-            for i, turn in enumerate(dialog):
-                if self._agent and self._agent == "user" and i % 2 == 1: 
-                    context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"])
-                    continue
-                if self._agent and self._agent == "system" and i % 2 == 0: 
-                    context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"])
-                    continue
-
-                tokens = turn["text"].split()
-
-                dialog_act = {}
-                for dacts in turn["span_info"]:
-                    if dacts[0] not in dialog_act:
-                        dialog_act[dacts[0]] = []
-                    dialog_act[dacts[0]].append([dacts[1], " ".join(tokens[dacts[3]: dacts[4]+1])])
-
-                spans = turn["span_info"]
-                tags = []
-                for j in range(len(tokens)):
-                    for span in spans:
-                        if j == span[3]:
-                            tags.append("B-"+span[0]+"+"+span[1])
-                            break
-                        if j > span[3] and j <= span[4]:
-                            tags.append("I-"+span[0]+"+"+span[1])
-                            break
-                    else:
-                        tags.append("O")
+        if self.use_unified_datasets:
+            data_split = file_path
+            logger.info("Reading instances from unified dataset %s[%s]", self._dataset_name, data_split)
+
+            data = load_nlu_data(self._dataset, data_split=data_split, speaker=self._agent, use_context=self._context_size>0, context_window_size=self._context_size)[data_split]
+
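+            # each sample provides 'utterance', 'dialogue_acts' and, when context is used, 'context'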
+            for sample in data:
+                utterance = sample['utterance']
+                sentences = self._sent_tokenizer.tokenize(utterance)
+                sent_spans = self._sent_tokenizer.span_tokenize(utterance)
+                tokens = [token for sent in sentences for token in self._word_tokenizer.tokenize(sent)]
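+                # character span of each token in the original utterance, used to align span annotations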
+                token_spans = [(sent_span[0]+token_span[0], sent_span[0]+token_span[1]) for sent, sent_span in zip(sentences, sent_spans) for token_span in self._word_tokenizer.span_tokenize(sent)]
+                tags = ['O'] * len(tokens)
+
+                for da in sample['dialogue_acts']['non-categorical']:
+                    if 'start' not in da:
+                        # skip da that doesn't have span annotation
+                        continue
+                    char_start = da['start']
+                    char_end = da['end']
+                    word_start, word_end = -1, -1
+                    for i, token_span in enumerate(token_spans):
+                        if char_start == token_span[0]:
+                            word_start = i
+                        if char_end == token_span[1]:
+                            word_end = i + 1
+                    if word_start == -1 and word_end == -1:
+                        # char span does not match word, maybe there is an error in the annotation, skip
+                        print('char span does not match word, skipping')
+                        print('\t', 'utterance:', utterance)
+                        print('\t', 'value:', utterance[char_start: char_end])
+                        print('\t', 'da:', da, '\n')
+                        continue
+                    intent, domain, slot = da['intent'], da['domain'], da['slot']
+                    tags[word_start] = f"B-{intent}+{domain}+{slot}"
+                    for i in range(word_start+1, word_end):
+                        tags[i] = f"I-{intent}+{domain}+{slot}"
 
                 intents = []
-                for dacts in turn["dialog_act"]:
-                    for dact in turn["dialog_act"][dacts]:
-                        if dacts not in dialog_act or dact[0] not in [sv[0] for sv in dialog_act[dacts]]:
-                            if dact[1] in ["none", "?", "yes", "no", "dontcare", "do nt care", "do n't care"]:
-                                intents.append(dacts+"+"+dact[0]+"*"+dact[1])
-
-                for dacts in turn["dialog_act"]:
-                    for dact in turn["dialog_act"][dacts]:
-                        if dacts not in dialog_act:
-                            dialog_act[dacts] = turn["dialog_act"][dacts]
-                            break
-                        elif dact[0] not in [sv[0] for sv in dialog_act[dacts]]:
-                            dialog_act[dacts].append(dact)
+                for da in sample['dialogue_acts']['categorical']:
+                    intent, domain, slot, value = da['intent'], da['domain'], da['slot'], da['value'].strip().lower()
+                    intent = str((intent, domain, slot, value))
+                    intents.append(intent)
+                for da in sample['dialogue_acts']['binary']:
+                    intent, domain, slot = da['intent'], da['domain'], da['slot']
+                    intent = str((intent, domain, slot))
+                    intents.append(intent)
+
+                wrapped_tokens = [Token(token) for token in tokens]
 
+                wrapped_context_tokens = []
                 num_context = random.randint(0, self._context_size) if self._random_context_size else self._context_size
-                if len(context_tokens_list) > 0 and num_context > 0:
-                    wrapped_context_tokens = [Token(token) for context_tokens in context_tokens_list[-num_context:] for token in context_tokens]
+                if num_context > 0 and len(sample['context']) > 0:
+                    for utt in sample['context']:
+                        for sent in self._sent_tokenizer.tokenize(utt['utterance']):
+                            for token in self._word_tokenizer.tokenize(sent):
+                                wrapped_context_tokens.append(Token(token))
+                        wrapped_context_tokens.append(Token("SENT_END"))
                 else:
                     wrapped_context_tokens = [Token("SENT_END")]
-                wrapped_tokens = [Token(token) for token in tokens]
-                context_tokens_list.append(tokens + ["SENT_END"])
 
-                yield self.text_to_instance(wrapped_context_tokens, wrapped_tokens, tags, intents, dialog_act)
+                yield self.text_to_instance(wrapped_context_tokens, wrapped_tokens, tags, intents, sample['dialogue_acts'])
+        else:
+            # if `file_path` is a URL, redirect to the cache
+            file_path = cached_path(file_path)
+
+            if file_path.endswith("zip"):
+                archive = zipfile.ZipFile(file_path, "r")
+                data_file = archive.open(os.path.basename(file_path)[:-4])
+            else:
+                data_file = open(file_path, "r")
+
+            logger.info("Reading instances from lines in file at: %s", file_path)
+
+            dialogs = json.load(data_file)
+
+            for dial_name in dialogs:
+                dialog = dialogs[dial_name]["log"]
+                context_tokens_list = []
+                for i, turn in enumerate(dialog):
+                    if self._agent and self._agent == "user" and i % 2 == 1: 
+                        context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"])
+                        continue
+                    if self._agent and self._agent == "system" and i % 2 == 0: 
+                        context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"])
+                        continue
+
+                    tokens = turn["text"].split()
+
+                    dialog_act = {}
+                    for dacts in turn["span_info"]:
+                        if dacts[0] not in dialog_act:
+                            dialog_act[dacts[0]] = []
+                        dialog_act[dacts[0]].append([dacts[1], " ".join(tokens[dacts[3]: dacts[4]+1])])
+
+                    spans = turn["span_info"]
+                    tags = []
+                    for j in range(len(tokens)):
+                        for span in spans:
+                            if j == span[3]:
+                                tags.append("B-"+span[0]+"+"+span[1])
+                                break
+                            if j > span[3] and j <= span[4]:
+                                tags.append("I-"+span[0]+"+"+span[1])
+                                break
+                        else:
+                            tags.append("O")
+
+                    intents = []
+                    for dacts in turn["dialog_act"]:
+                        for dact in turn["dialog_act"][dacts]:
+                            if dacts not in dialog_act or dact[0] not in [sv[0] for sv in dialog_act[dacts]]:
+                                if dact[1] in ["none", "?", "yes", "no", "dontcare", "do nt care", "do n't care"]:
+                                    intents.append(dacts+"+"+dact[0]+"*"+dact[1])
+
+                    for dacts in turn["dialog_act"]:
+                        for dact in turn["dialog_act"][dacts]:
+                            if dacts not in dialog_act:
+                                dialog_act[dacts] = turn["dialog_act"][dacts]
+                                break
+                            elif dact[0] not in [sv[0] for sv in dialog_act[dacts]]:
+                                dialog_act[dacts].append(dact)
+
+                    num_context = random.randint(0, self._context_size) if self._random_context_size else self._context_size
+                    if len(context_tokens_list) > 0 and num_context > 0:
+                        wrapped_context_tokens = [Token(token) for context_tokens in context_tokens_list[-num_context:] for token in context_tokens]
+                    else:
+                        wrapped_context_tokens = [Token("SENT_END")]
+                    wrapped_tokens = [Token(token) for token in tokens]
+                    context_tokens_list.append(tokens + ["SENT_END"])
+
+                    yield self.text_to_instance(wrapped_context_tokens, wrapped_tokens, tags, intents, dialog_act)
 
 
     def text_to_instance(self, context_tokens: List[Token], tokens: List[Token], tags: List[str] = None,
diff --git a/convlab2/nlu/milu/evaluate.py b/convlab2/nlu/milu/evaluate.py
index 55bb5c8b416c6816ccf0c5114c8ca23a8461fe08..1bef413ccbd123267c91dfbed8faa971fe41dd01 100755
--- a/convlab2/nlu/milu/evaluate.py
+++ b/convlab2/nlu/milu/evaluate.py
@@ -16,7 +16,7 @@ from allennlp.common.util import prepare_environment
 from allennlp.data.dataset_readers.dataset_reader import DatasetReader
 from allennlp.data.iterators import DataIterator
 from allennlp.models.archival import load_archive
-from allennlp.training.util import evaluate
+from convlab2.nlu.milu.util import evaluate
 
 from convlab2.nlu.milu import dataset_reader, model
 
@@ -28,7 +28,7 @@ argparser.add_argument('archive_file', type=str, help='path to an archived train
 
 argparser.add_argument('input_file', type=str, help='path to the file containing the evaluation data')
 
-argparser.add_argument('--output-file', type=str, help='path to output file')
+argparser.add_argument('--output_file', type=str, help='path to output file')
 
 argparser.add_argument('--weights-file',
                         type=str,
@@ -105,7 +105,7 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
     iterator = DataIterator.from_params(iterator_params)
     iterator.index_with(model.vocab)
 
-    metrics = evaluate(model, instances, iterator, args.cuda_device, args.batch_weight_key)
+    metrics, predict_results = evaluate(model, instances, iterator, args.cuda_device, args.batch_weight_key)
 
     logger.info("Finished evaluating.")
     logger.info("Metrics:")
@@ -114,8 +114,8 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
 
     output_file = args.output_file
     if output_file:
-        with open(output_file, "w") as file:
-            json.dump(metrics, file, indent=4)
+        with open(output_file, "w", encoding='utf-8') as file:
+            json.dump(predict_results, file, indent=2, ensure_ascii=False)
     return metrics
 
 
diff --git a/convlab2/nlu/milu/model.py b/convlab2/nlu/milu/model.py
index e4831926bd08c1b2df6b924f3f65f65d097e8b87..8d5031c7891e3abb010ff0b72ff4a42489dc39a6 100755
--- a/convlab2/nlu/milu/model.py
+++ b/convlab2/nlu/milu/model.py
@@ -57,6 +57,7 @@ class MILU(Model):
                  feedforward: Optional[FeedForward] = None,
                  label_encoding: Optional[str] = None,
                  include_start_end_transitions: bool = True,
+                 use_unified_datasets: bool = False,
                  crf_decoding: bool = False,
                  constrain_crf_decoding: bool = None,
                  focal_loss_gamma: float = None,
@@ -83,6 +84,7 @@ class MILU(Model):
         self.tag_encoder = intent_encoder
         self._feedforward = feedforward
         self._verbose_metrics = verbose_metrics
+        self.use_unified_datasets = use_unified_datasets
         self.rl = False 
  
         if attention:
@@ -164,7 +166,7 @@ class MILU(Model):
             self._f1_metric = SpanBasedF1Measure(vocab,
                                                  tag_namespace=sequence_label_namespace,
                                                  label_encoding=label_encoding)
-        self._dai_f1_metric = DialogActItemF1Measure()
+        self._dai_f1_metric = DialogActItemF1Measure(self.use_unified_datasets)
 
         check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(),
                                "text field embedding dim", "encoder input dim")
@@ -355,29 +357,64 @@ class MILU(Model):
         for i, tags in enumerate(output_dict["tags"]): 
             seq_len = len(output_dict["words"][i])
             spans = bio_tags_to_spans(tags[:seq_len])
-            dialog_act = {}
-            for span in spans:
-                domain_act = span[0].split("+")[0]
-                slot = span[0].split("+")[1]
-                value = " ".join(output_dict["words"][i][span[1][0]:span[1][1]+1])
-                if domain_act not in dialog_act:
-                    dialog_act[domain_act] = [[slot, value]]
-                else:
-                    dialog_act[domain_act].append([slot, value])
-            for intent in output_dict["intents"][i]:
-                if "+" in intent: 
-                    if "*" in intent: 
-                        intent, value = intent.split("*", 1) 
+            if self.use_unified_datasets:
+                dialog_act = {
+                    'categorical': [],
+                    'non-categorical': [],
+                    'binary': []
+                }
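+                # slot-tagged spans always become non-categorical acts, since their
+                # values are recovered verbatim from the utterance tokens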
+                for span in spans:
+                    intent, domain, slot = span[0].split("+")
+                    value = " ".join(output_dict["words"][i][span[1][0]:span[1][1]+1])
+                    dialog_act['non-categorical'].append({
+                        'intent': intent,
+                        'domain': domain,
+                        'slot': slot,
+                        'value': value
+                    })
+                
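+                # intent labels are stringified tuples from the dataset reader:
+                # 3-tuples are binary acts, 4-tuples are categorical acts with a value.
+                # eval() is tolerable only because the labels come from the model's
+                # own vocabulary, never from user input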
+                for intent in output_dict["intents"][i]:
+                    intent = eval(intent)
+                    if len(intent) == 3:
+                        dialog_act['binary'].append({
+                            'intent': intent[0],
+                            'domain': intent[1],
+                            'slot': intent[2]
+                        })
                     else:
-                        value = "?"
-                    domain_act = intent.split("+")[0] 
+                        assert len(intent) == 4
+                        dialog_act['categorical'].append({
+                            'intent': intent[0],
+                            'domain': intent[1],
+                            'slot': intent[2],
+                            'value': intent[3]
+                        })
+                output_dict["dialog_act"].append(dialog_act)
+
+            else:
+                dialog_act = {}
+                for span in spans:
+                    domain_act = span[0].split("+")[0]
+                    slot = span[0].split("+")[1]
+                    value = " ".join(output_dict["words"][i][span[1][0]:span[1][1]+1])
                     if domain_act not in dialog_act:
-                        dialog_act[domain_act] = [[intent.split("+")[1], value]]
+                        dialog_act[domain_act] = [[slot, value]]
+                    else:
+                        dialog_act[domain_act].append([slot, value])
+                for intent in output_dict["intents"][i]:
+                    if "+" in intent: 
+                        if "*" in intent: 
+                            intent, value = intent.split("*", 1) 
+                        else:
+                            value = "?"
+                        domain_act = intent.split("+")[0] 
+                        if domain_act not in dialog_act:
+                            dialog_act[domain_act] = [[intent.split("+")[1], value]]
+                        else:
+                            dialog_act[domain_act].append([intent.split("+")[1], value])
                     else:
-                        dialog_act[domain_act].append([intent.split("+")[1], value])
-                else:
-                    dialog_act[intent] = [["none", "none"]]
-            output_dict["dialog_act"].append(dialog_act)
+                        dialog_act[intent] = [["none", "none"]]
+                output_dict["dialog_act"].append(dialog_act)
 
         return output_dict
 
diff --git a/convlab2/nlu/milu/unified_datasets/__init__.py b/convlab2/nlu/milu/unified_datasets/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..902c9a23af7fe1f0c8423394085a0f4ea0bb55f9
--- /dev/null
+++ b/convlab2/nlu/milu/unified_datasets/__init__.py
@@ -0,0 +1 @@
+from convlab2.nlu.milu.unified_datasets.nlu import MILU
diff --git a/convlab2/nlu/milu/unified_datasets/configs/multiwoz21_user.jsonnet b/convlab2/nlu/milu/unified_datasets/configs/multiwoz21_user.jsonnet
new file mode 100755
index 0000000000000000000000000000000000000000..858a57a19ecbe2acd59f02465d79c1d852341f60
--- /dev/null
+++ b/convlab2/nlu/milu/unified_datasets/configs/multiwoz21_user.jsonnet
@@ -0,0 +1,104 @@
+{
+  "dataset_reader": {
+    "type": "milu",
+    "token_indexers": {
+      "tokens": {
+        "type": "single_id",
+        "lowercase_tokens": true
+      },
+      "token_characters": {
+        "type": "characters",
+        "min_padding_length": 3
+      },
+    },
+    "context_size": 0,
+    "agent": "user",
+    "use_unified_datasets": true,
+    "dataset_name": "multiwoz21",
+    "random_context_size": false
+  },
+  "train_data_path": "train",
+  "validation_data_path": "validation",
+  "test_data_path": "test",
+  "model": {
+    "type": "milu",
+    "label_encoding": "BIO",
+    "use_unified_datasets": true,
+    "dropout": 0.3,
+    "include_start_end_transitions": false,
+    "text_field_embedder": {
+      "token_embedders": {
+        "tokens": {
+            "type": "embedding",
+            "embedding_dim": 50,
+            "pretrained_file": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.6B.50d.txt.gz",
+            "trainable": true
+        },
+        "token_characters": {
+            "type": "character_encoding",
+            "embedding": {
+                "embedding_dim": 16
+            },
+            "encoder": {
+                "type": "cnn",
+                "embedding_dim": 16,
+                "num_filters": 128,
+                "ngram_filter_sizes": [3],
+                "conv_layer_activation": "relu"
+            }
+        }
+      }
+    },
+    "encoder": {
+      "type": "lstm",
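+      // 178 = 50 (GloVe token embedding) + 128 (character-CNN filters)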
+      "input_size": 178,
+      "hidden_size": 200,
+      "num_layers": 1,
+      "dropout": 0.5,
+      "bidirectional": true
+    },
+    "intent_encoder": {
+      "type": "lstm",
+      "input_size": 400,
+      "hidden_size": 200,
+      "num_layers": 1,
+      "dropout": 0.5,
+      "bidirectional": true
+    },
+    "attention": {
+      "type": "bilinear",
+      "vector_dim": 400,
+      "matrix_dim": 400
+    },
+    "context_for_intent": true,
+    "context_for_tag": false,
+    "attention_for_intent": false,
+    "attention_for_tag": false,
+    "regularizer": [
+      [
+        "scalar_parameters",
+        {
+          "type": "l2",
+          "alpha": 0.1
+        }
+      ]
+    ]
+  },
+  "iterator": {
+    "type": "basic",
+    "batch_size": 64
+  },
+  "trainer": {
+    "optimizer": {
+        "type": "adam",
+        "lr": 0.001
+    },
+    "validation_metric": "+f1-measure",
+    "num_serialized_models_to_keep": 3,
+    "num_epochs": 40,
+    "grad_norm": 5.0,
+    "patience": 75,
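+    // NOTE: patience exceeds num_epochs, so early stopping never triggers here;
+    // cuda_device is machine-specific, so adjust it (or use -1 for CPU) as needed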
+    "cuda_device": 4
+  },
+  "evaluate_on_test": true
+}
diff --git a/convlab2/nlu/milu/unified_datasets/configs/multiwoz21_user_context3.jsonnet b/convlab2/nlu/milu/unified_datasets/configs/multiwoz21_user_context3.jsonnet
new file mode 100755
index 0000000000000000000000000000000000000000..1ce1d0d0b4135a2ee03d1ba7dda084f5c22da0c0
--- /dev/null
+++ b/convlab2/nlu/milu/unified_datasets/configs/multiwoz21_user_context3.jsonnet
@@ -0,0 +1,104 @@
+{
+  "dataset_reader": {
+    "type": "milu",
+    "token_indexers": {
+      "tokens": {
+        "type": "single_id",
+        "lowercase_tokens": true
+      },
+      "token_characters": {
+        "type": "characters",
+        "min_padding_length": 3
+      },
+    },
+    "context_size": 3,
+    "agent": "user",
+    "use_unified_datasets": true,
+    "dataset_name": "multiwoz21",
+    "random_context_size": false
+  },
+  "train_data_path": "train",
+  "validation_data_path": "validation",
+  "test_data_path": "test",
+  "model": {
+    "type": "milu",
+    "label_encoding": "BIO",
+    "use_unified_datasets": true,
+    "dropout": 0.3,
+    "include_start_end_transitions": false,
+    "text_field_embedder": {
+      "token_embedders": {
+        "tokens": {
+            "type": "embedding",
+            "embedding_dim": 50,
+            "pretrained_file": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.6B.50d.txt.gz",
+            "trainable": true
+        },
+        "token_characters": {
+            "type": "character_encoding",
+            "embedding": {
+                "embedding_dim": 16
+            },
+            "encoder": {
+                "type": "cnn",
+                "embedding_dim": 16,
+                "num_filters": 128,
+                "ngram_filter_sizes": [3],
+                "conv_layer_activation": "relu"
+            }
+        }
+      }
+    },
+    "encoder": {
+      "type": "lstm",
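+      // 178 = 50 (GloVe token embedding) + 128 (character-CNN filters)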
+      "input_size": 178,
+      "hidden_size": 200,
+      "num_layers": 1,
+      "dropout": 0.5,
+      "bidirectional": true
+    },
+    "intent_encoder": {
+      "type": "lstm",
+      "input_size": 400,
+      "hidden_size": 200,
+      "num_layers": 1,
+      "dropout": 0.5,
+      "bidirectional": true
+    },
+    "attention": {
+      "type": "bilinear",
+      "vector_dim": 400,
+      "matrix_dim": 400
+    },
+    "context_for_intent": true,
+    "context_for_tag": false,
+    "attention_for_intent": false,
+    "attention_for_tag": false,
+    "regularizer": [
+      [
+        "scalar_parameters",
+        {
+          "type": "l2",
+          "alpha": 0.1
+        }
+      ]
+    ]
+  },
+  "iterator": {
+    "type": "basic",
+    "batch_size": 64
+  },
+  "trainer": {
+    "optimizer": {
+        "type": "adam",
+        "lr": 0.001
+    },
+    "validation_metric": "+f1-measure",
+    "num_serialized_models_to_keep": 3,
+    "num_epochs": 40,
+    "grad_norm": 5.0,
+    "patience": 75,
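+    // patience exceeds num_epochs, so early stopping is effectively disabled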
+    "cuda_device": 0
+  },
+  "evaluate_on_test": true
+}
diff --git a/convlab2/nlu/milu/unified_datasets/merge_predict_res.py b/convlab2/nlu/milu/unified_datasets/merge_predict_res.py
new file mode 100755
index 0000000000000000000000000000000000000000..f785f832931635f1fec8875eaf5bbbe487f6d6a9
--- /dev/null
+++ b/convlab2/nlu/milu/unified_datasets/merge_predict_res.py
@@ -0,0 +1,33 @@
+import json
+import os
+from convlab2.util import load_dataset, load_nlu_data
+
+
+def merge(dataset_name, speaker, save_dir, context_window_size, predict_result):
+    assert os.path.exists(predict_result)
+    dataset = load_dataset(dataset_name)
+    data = load_nlu_data(dataset, data_split='test', speaker=speaker, use_context=context_window_size>0, context_window_size=context_window_size)['test']
+    
+    if save_dir is None:
+        save_dir = os.path.dirname(predict_result)
+    else:
+        os.makedirs(save_dir, exist_ok=True)
+    predict_result = json.load(open(predict_result))
+
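+    # assumes the prediction file preserves the test-split order, so predictions
+    # align one-to-one with the loaded samples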
+    for sample, prediction in zip(data, predict_result):
+        sample['predictions'] = {'dialogue_acts': prediction}
+
+    json.dump(data, open(os.path.join(save_dir, 'predictions.json'), 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
+
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="merge predict results with original data for unified NLU evaluation")
+    parser.add_argument('--dataset', '-d', metavar='dataset_name', type=str, help='name of the unified dataset')
+    parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s) of utterances')
+    parser.add_argument('--save_dir', type=str, help='merged data will be saved as $save_dir/predictions.json. default: the directory of predict_result')
+    parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
+    parser.add_argument('--predict_result', '-p', type=str, required=True, help='path to the output file generated by ../evaluate.py')
+    args = parser.parse_args()
+    print(args)
+    merge(args.dataset, args.speaker, args.save_dir, args.context_window_size, args.predict_result)
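+    # example invocation (hypothetical paths):
+    #   python merge_predict_res.py -d multiwoz21 -s user -c 0 -p ../output/multiwoz21_user/output.json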
diff --git a/convlab2/nlu/milu/unified_datasets/nlu.py b/convlab2/nlu/milu/unified_datasets/nlu.py
new file mode 100755
index 0000000000000000000000000000000000000000..8d5e96852c6369a4db61702c241e2eb4e4cee1a5
--- /dev/null
+++ b/convlab2/nlu/milu/unified_datasets/nlu.py
@@ -0,0 +1,104 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""
+MILU wrapper for the unified data format: loads a trained model archive and
+predicts dialogue acts as [intent, domain, slot, value] tuples.
+"""
+
+import os
+from pprint import pprint
+import torch
+from allennlp.common.checks import check_for_gpu
+from allennlp.data import DatasetReader
+from allennlp.models.archival import load_archive
+from allennlp.data.tokenizers import Token
+
+from convlab2.util.file_util import cached_path
+from convlab2.nlu.milu import dataset_reader, model
+from convlab2.nlu.nlu import NLU
+from nltk.tokenize import TreebankWordTokenizer, PunktSentenceTokenizer
+
+DEFAULT_CUDA_DEVICE = -1
+DEFAULT_DIRECTORY = "models"
+DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "milu_multiwoz_all_context.tar.gz")
+
+class MILU(NLU):
+    """Multi-intent language understanding model."""
+
+    def __init__(self,
+                archive_file,
+                cuda_device,
+                model_file,
+                context_size):
+        """Constructor for the MILU wrapper; loads the archived model."""
+
+        self.context_size = context_size
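+        # NOTE: the cuda_device argument is effectively ignored below; GPU 0 is
+        # used whenever CUDA is available, otherwise the CPU default (-1)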
+        cuda_device = 0 if torch.cuda.is_available() else DEFAULT_CUDA_DEVICE
+        check_for_gpu(cuda_device)
+
+        if not os.path.isfile(archive_file):
+            if not model_file:
+                raise Exception("No model for MILU is specified!")
+
+            archive_file = cached_path(model_file)
+
+        archive = load_archive(archive_file,
+                            cuda_device=cuda_device)
+        self.sent_tokenizer = PunktSentenceTokenizer()
+        self.word_tokenizer = TreebankWordTokenizer()
+
+        dataset_reader_params = archive.config["dataset_reader"]
+        self.dataset_reader = DatasetReader.from_params(dataset_reader_params)
+        self.model = archive.model
+        self.model.eval()
+
+
+    def predict(self, utterance, context=list()):
+        """
+        Predict the dialogue acts of a natural language utterance.
+        Args:
+            utterance (str): A natural language utterance.
+            context (list of str): Previous utterances, most recent last.
+        Returns:
+            tuples (list of list): Dialogue acts as [intent, domain, slot, value].
+        """
+        if len(utterance) == 0:
+            return []
+
+        if self.context_size > 0 and len(context) > 0:
+            context_tokens = []
+            for utt in context[-self.context_size:]:
+                for sent in self.sent_tokenizer.tokenize(utt):
+                    for token in self.word_tokenizer.tokenize(sent):
+                        context_tokens.append(Token(token))
+                context_tokens.append(Token("SENT_END"))
+        else:
+            context_tokens = [Token("SENT_END")]
+        sentences = self.sent_tokenizer.tokenize(utterance)
+        tokens = [Token(token) for sent in sentences for token in self.word_tokenizer.tokenize(sent)]
+        instance = self.dataset_reader.text_to_instance(context_tokens, tokens)
+        outputs = self.model.forward_on_instance(instance)
+
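+        # flatten the unified-format dict into [intent, domain, slot, value] lists;
+        # binary acts carry no value, so default to ''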
+        tuples = []
+        for da_type in outputs['dialog_act']:
+            for da in outputs['dialog_act'][da_type]:
+                tuples.append([da['intent'], da['domain'], da['slot'], da.get('value','')])
+        return tuples
+
+
+if __name__ == "__main__":
+    nlu = MILU(archive_file='../output/multiwoz21_user/model.tar.gz', cuda_device=3, model_file=None, context_size=3)
+    test_utterances = [
+        "What type of accommodations are they. No , i just need their address . Can you tell me if the hotel has internet available ?",
+        "What type of accommodations are they.",
+        "No , i just need their address .",
+        "Can you tell me if the hotel has internet available ?",
+        "yes. it should be moderately priced.",
+        "i want to book a table for 6 at 18:45 on thursday",
+        "i will be departing out of stevenage.",
+        "What is the name of attraction ?",
+        "Can I get the name of restaurant?",
+        "Can I get the address and phone number of the restaurant?",
+        "do you have a specific area you want to stay in?"
+    ]
+    for utt in test_utterances:
+        print(utt)
+        pprint(nlu.predict(utt))
diff --git a/convlab2/nlu/milu/util.py b/convlab2/nlu/milu/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f40f43aa2b0a708489f456fbbcaa4658e45aeae
--- /dev/null
+++ b/convlab2/nlu/milu/util.py
@@ -0,0 +1,440 @@
+"""
+Helper functions for trainers, adapted from ``allennlp.training.util``.
+The local ``evaluate`` additionally collects and returns the per-sample
+dialogue-act predictions alongside the metrics.
+"""
+from typing import Any, Union, Dict, Iterable, List, Optional, Tuple
+import datetime
+import json
+import logging
+import pathlib
+import os
+import shutil
+
+import torch
+from torch.nn.parallel import replicate, parallel_apply
+from torch.nn.parallel.scatter_gather import gather
+
+from allennlp.common.checks import ConfigurationError, check_for_gpu
+from allennlp.common.params import Params
+from allennlp.common.tqdm import Tqdm
+from allennlp.data.dataset_readers import DatasetReader
+from allennlp.data import Instance
+from allennlp.data.iterators import DataIterator
+from allennlp.data.iterators.data_iterator import TensorDict
+from allennlp.models.model import Model
+from allennlp.models.archival import CONFIG_NAME
+from allennlp.nn import util as nn_util
+
+logger = logging.getLogger(__name__)
+
+# We want to warn people that tqdm ignores metrics that start with underscores
+# exactly once. This variable keeps track of whether we have.
+class HasBeenWarned:
+    tqdm_ignores_underscores = False
+
+def sparse_clip_norm(parameters, max_norm, norm_type=2) -> float:
+    """Clips gradient norm of an iterable of parameters.
+
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector. Gradients are modified in-place.
+    Supports sparse gradients.
+
+    Parameters
+    ----------
+    parameters : ``(Iterable[torch.Tensor])``
+        An iterable of Tensors that will have gradients normalized.
+    max_norm : ``float``
+        The max norm of the gradients.
+    norm_type : ``float``
+        The type of the used p-norm. Can be ``'inf'`` for infinity norm.
+
+    Returns
+    -------
+    Total norm of the parameters (viewed as a single vector).
+    """
+    # pylint: disable=invalid-name,protected-access
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if norm_type == float('inf'):
+        total_norm = max(p.grad.data.abs().max() for p in parameters)
+    else:
+        total_norm = 0
+        for p in parameters:
+            if p.grad.is_sparse:
+                # need to coalesce the repeated indices before finding norm
+                grad = p.grad.data.coalesce()
+                param_norm = grad._values().norm(norm_type)
+            else:
+                param_norm = p.grad.data.norm(norm_type)
+            total_norm += param_norm ** norm_type
+        total_norm = total_norm ** (1. / norm_type)
+    clip_coef = max_norm / (total_norm + 1e-6)
+    if clip_coef < 1:
+        for p in parameters:
+            if p.grad.is_sparse:
+                p.grad.data._values().mul_(clip_coef)
+            else:
+                p.grad.data.mul_(clip_coef)
+    return total_norm
+
+
+def move_optimizer_to_cuda(optimizer):
+    """
+    Move the optimizer state to GPU, if necessary.
+    After calling, any parameter specific state in the optimizer
+    will be located on the same device as the parameter.
+    """
+    for param_group in optimizer.param_groups:
+        for param in param_group['params']:
+            if param.is_cuda:
+                param_state = optimizer.state[param]
+                for k in param_state.keys():
+                    if isinstance(param_state[k], torch.Tensor):
+                        param_state[k] = param_state[k].cuda(device=param.get_device())
+
+
+def get_batch_size(batch: Union[Dict, torch.Tensor]) -> int:
+    """
+    Returns the size of the batch dimension. Assumes a well-formed batch,
+    returns 0 otherwise.
+    """
+    if isinstance(batch, torch.Tensor):
+        return batch.size(0) # type: ignore
+    elif isinstance(batch, Dict):
+        return get_batch_size(next(iter(batch.values())))
+    else:
+        return 0
+
+
+def time_to_str(timestamp: int) -> str:
+    """
+    Convert seconds past Epoch to human readable string.
+    """
+    datetimestamp = datetime.datetime.fromtimestamp(timestamp)
+    return '{:04d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}'.format(
+            datetimestamp.year, datetimestamp.month, datetimestamp.day,
+            datetimestamp.hour, datetimestamp.minute, datetimestamp.second
+    )
+
+
+def str_to_time(time_str: str) -> datetime.datetime:
+    """
+    Convert human readable string to datetime.datetime.
+    """
+    pieces: Any = [int(piece) for piece in time_str.split('-')]
+    return datetime.datetime(*pieces)
+
+
+def datasets_from_params(params: Params,
+                         cache_directory: str = None,
+                         cache_prefix: str = None) -> Dict[str, Iterable[Instance]]:
+    """
+    Load all the datasets specified by the config.
+
+    Parameters
+    ----------
+    params : ``Params``
+    cache_directory : ``str``, optional
+        If given, we will instruct the ``DatasetReaders`` that we construct to cache their
+        instances in this location (or read their instances from caches in this location, if a
+        suitable cache already exists).  This is essentially a `base` directory for the cache, as
+        we will additionally add the ``cache_prefix`` to this directory, giving an actual cache
+        location of ``cache_directory + cache_prefix``.
+    cache_prefix : ``str``, optional
+        This works in conjunction with the ``cache_directory``.  The idea is that the
+        ``cache_directory`` contains caches for all different parameter settings, while the
+        ``cache_prefix`` captures a specific set of parameters that led to a particular cache file.
+        That is, if you change the tokenization settings inside your ``DatasetReader``, you don't
+        want to read cached data that used the old settings.  In order to avoid this, we compute a
+        hash of the parameters used to construct each ``DatasetReader`` and use that as a "prefix"
+        to the cache files inside the base ``cache_directory``.  So, a given ``input_file`` would
+        be cached essentially as ``cache_directory + cache_prefix + input_file``, where you specify
+        a ``cache_directory``, the ``cache_prefix`` is based on the dataset reader parameters, and
+        the ``input_file`` is whatever path you provided to ``DatasetReader.read()``.  In order to
+        allow you to give recognizable names to these prefixes if you want them, you can manually
+        specify the ``cache_prefix``.  Note that in some rare cases this can be dangerous, as we'll
+        use the `same` prefix for both train and validation dataset readers.
+    """
+    dataset_reader_params = params.pop('dataset_reader')
+    validation_dataset_reader_params = params.pop('validation_dataset_reader', None)
+    train_cache_dir, validation_cache_dir = _set_up_cache_files(dataset_reader_params,
+                                                                validation_dataset_reader_params,
+                                                                cache_directory,
+                                                                cache_prefix)
+
+    dataset_reader = DatasetReader.from_params(dataset_reader_params)
+
+    validation_and_test_dataset_reader: DatasetReader = dataset_reader
+    if validation_dataset_reader_params is not None:
+        logger.info("Using a separate dataset reader to load validation and test data.")
+        validation_and_test_dataset_reader = DatasetReader.from_params(validation_dataset_reader_params)
+
+    if train_cache_dir:
+        dataset_reader.cache_data(train_cache_dir)
+        validation_and_test_dataset_reader.cache_data(validation_cache_dir)
+
+    train_data_path = params.pop('train_data_path')
+    logger.info("Reading training data from %s", train_data_path)
+    train_data = dataset_reader.read(train_data_path)
+
+    datasets: Dict[str, Iterable[Instance]] = {"train": train_data}
+
+    validation_data_path = params.pop('validation_data_path', None)
+    if validation_data_path is not None:
+        logger.info("Reading validation data from %s", validation_data_path)
+        validation_data = validation_and_test_dataset_reader.read(validation_data_path)
+        datasets["validation"] = validation_data
+
+    test_data_path = params.pop("test_data_path", None)
+    if test_data_path is not None:
+        logger.info("Reading test data from %s", test_data_path)
+        test_data = validation_and_test_dataset_reader.read(test_data_path)
+        datasets["test"] = test_data
+
+    return datasets
+
+
+def _set_up_cache_files(train_params: Params,
+                        validation_params: Params = None,
+                        cache_directory: str = None,
+                        cache_prefix: str = None) -> Tuple[str, str]:
+    if not cache_directory:
+        return None, None
+
+    # We need to compute the parameter hash before the parameters get destroyed when they're
+    # passed to `DatasetReader.from_params`.
+    if not cache_prefix:
+        cache_prefix = _dataset_reader_param_hash(train_params)
+        if validation_params:
+            validation_cache_prefix = _dataset_reader_param_hash(validation_params)
+        else:
+            validation_cache_prefix = cache_prefix
+    else:
+        validation_cache_prefix = cache_prefix
+
+    train_cache_dir = pathlib.Path(cache_directory) / cache_prefix
+    validation_cache_dir = pathlib.Path(cache_directory) / validation_cache_prefix
+
+    # For easy human inspection of what parameters were used to create the cache.  This will
+    # overwrite old files, but they should be identical.  This could bite someone who gave
+    # their own prefix instead of letting us compute it, and then _re-used_ that name with
+    # different parameters, without clearing the cache first.  But correctly handling that case
+    # is more work than it's worth.
+    os.makedirs(train_cache_dir, exist_ok=True)
+    with open(train_cache_dir / 'params.json', 'w') as param_file:
+        json.dump(train_params.as_dict(quiet=True), param_file)
+    os.makedirs(validation_cache_dir, exist_ok=True)
+    with open(validation_cache_dir / 'params.json', 'w') as param_file:
+        if validation_params:
+            json.dump(validation_params.as_dict(quiet=True), param_file)
+        else:
+            json.dump(train_params.as_dict(quiet=True), param_file)
+    return str(train_cache_dir), str(validation_cache_dir)
+
+
+def _dataset_reader_param_hash(params: Params) -> str:
+    copied_params = params.duplicate()
+    # Laziness doesn't affect how the data is computed, so it shouldn't affect the hash.
+    copied_params.pop('lazy', default=None)
+    return copied_params.get_hash()
+
+
+def create_serialization_dir(
+        params: Params,
+        serialization_dir: str,
+        recover: bool,
+        force: bool) -> None:
+    """
+    This function creates the serialization directory if it doesn't exist.  If it already exists
+    and is non-empty, then it verifies that we're recovering from a training with an identical configuration.
+
+    Parameters
+    ----------
+    params: ``Params``
+        A parameter object specifying an AllenNLP Experiment.
+    serialization_dir: ``str``
+        The directory in which to save results and logs.
+    recover: ``bool``
+        If ``True``, we will try to recover from an existing serialization directory, and crash if
+        the directory doesn't exist, or doesn't match the configuration we're given.
+    force: ``bool``
+        If ``True``, we will overwrite the serialization directory if it already exists.
+    """
+    if recover and force:
+        raise ConfigurationError("Illegal arguments: both force and recover are true.")
+
+    if os.path.exists(serialization_dir) and force:
+        shutil.rmtree(serialization_dir)
+
+    if os.path.exists(serialization_dir) and os.listdir(serialization_dir):
+        if not recover:
+            raise ConfigurationError(f"Serialization directory ({serialization_dir}) already exists and is "
+                                     f"not empty. Specify --recover to recover training from existing output.")
+
+        logger.info(f"Recovering from prior training at {serialization_dir}.")
+
+        recovered_config_file = os.path.join(serialization_dir, CONFIG_NAME)
+        if not os.path.exists(recovered_config_file):
+            raise ConfigurationError("The serialization directory already exists but doesn't "
+                                     "contain a config.json. You probably gave the wrong directory.")
+        loaded_params = Params.from_file(recovered_config_file)
+
+        # Check whether any of the training configuration differs from the configuration we are
+        # resuming.  If so, warn the user that training may fail.
+        fail = False
+        flat_params = params.as_flat_dict()
+        flat_loaded = loaded_params.as_flat_dict()
+        for key in flat_params.keys() - flat_loaded.keys():
+            logger.error(f"Key '{key}' found in training configuration but not in the serialization "
+                         f"directory we're recovering from.")
+            fail = True
+        for key in flat_loaded.keys() - flat_params.keys():
+            logger.error(f"Key '{key}' found in the serialization directory we're recovering from "
+                         f"but not in the training config.")
+            fail = True
+        for key in flat_params.keys():
+            if flat_params.get(key, None) != flat_loaded.get(key, None):
+                logger.error(f"Value for '{key}' in training configuration does not match that the value in "
+                             f"the serialization directory we're recovering from: "
+                             f"{flat_params[key]} != {flat_loaded[key]}")
+                fail = True
+        if fail:
+            raise ConfigurationError("Training configuration does not match the configuration we're "
+                                     "recovering from.")
+    else:
+        if recover:
+            raise ConfigurationError(f"--recover specified but serialization_dir ({serialization_dir}) "
+                                     "does not exist.  There is nothing to recover from.")
+        os.makedirs(serialization_dir, exist_ok=True)
+
+def data_parallel(batch_group: List[TensorDict],
+                  model: Model,
+                  cuda_devices: List) -> Dict[str, torch.Tensor]:
+    """
+    Performs a forward pass using multiple GPUs.  This is a simplification
+    of torch.nn.parallel.data_parallel to support the allennlp model
+    interface.
+    """
+    assert len(batch_group) <= len(cuda_devices)
+
+    moved = [nn_util.move_to_device(batch, device)
+             for batch, device in zip(batch_group, cuda_devices)]
+
+    used_device_ids = cuda_devices[:len(moved)]
+    # Counterintuitively, it appears replicate expects the source device id to be the first element
+    # in the device id list. See torch.cuda.comm.broadcast_coalesced, which is called indirectly.
+    replicas = replicate(model, used_device_ids)
+
+    # We pass all our arguments as kwargs. Create a list of empty tuples of the
+    # correct shape to serve as (non-existent) positional arguments.
+    inputs = [()] * len(batch_group)
+    outputs = parallel_apply(replicas, inputs, moved, used_device_ids)
+
+    # Only the 'loss' is needed.
+    # a (num_gpu, ) tensor with loss on each GPU
+    losses = gather([output['loss'].unsqueeze(0) for output in outputs], used_device_ids[0], 0)
+    return {'loss': losses.mean()}
+
+def enable_gradient_clipping(model: Model, grad_clipping: Optional[float]) -> None:
+    if grad_clipping is not None:
+        for parameter in model.parameters():
+            if parameter.requires_grad:
+                parameter.register_hook(lambda grad: nn_util.clamp_tensor(grad,
+                                                                          minimum=-grad_clipping,
+                                                                          maximum=grad_clipping))
+
+def rescale_gradients(model: Model, grad_norm: Optional[float] = None) -> Optional[float]:
+    """
+    Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled.
+    """
+    if grad_norm:
+        parameters_to_clip = [p for p in model.parameters()
+                              if p.grad is not None]
+        return sparse_clip_norm(parameters_to_clip, grad_norm)
+    return None
+
+def get_metrics(model: Model, total_loss: float, num_batches: int, reset: bool = False) -> Dict[str, float]:
+    """
+    Gets the metrics but sets ``"loss"`` to
+    the total loss divided by the ``num_batches`` so that
+    the ``"loss"`` metric is "average loss per batch".
+    """
+    metrics = model.get_metrics(reset=reset)
+    metrics["loss"] = float(total_loss / num_batches) if num_batches > 0 else 0.0
+    return metrics
+
+
+def evaluate(model: Model,
+             instances: Iterable[Instance],
+             data_iterator: DataIterator,
+             cuda_device: int,
+             batch_weight_key: str) -> Tuple[Dict[str, Any], List[Dict]]:
+    check_for_gpu(cuda_device)
+    predict_results = []
+    with torch.no_grad():
+        model.eval()
+
+        iterator = data_iterator(instances,
+                                 num_epochs=1,
+                                 shuffle=False)
+        logger.info("Iterating over dataset")
+        generator_tqdm = Tqdm.tqdm(iterator, total=data_iterator.get_num_batches(instances))
+
+        # Number of batches in instances.
+        batch_count = 0
+        # Number of batches where the model produces a loss.
+        loss_count = 0
+        # Cumulative weighted loss
+        total_loss = 0.0
+        # Cumulative weight across all batches.
+        total_weight = 0.0
+
+        for batch in generator_tqdm:
+            batch_count += 1
+            batch = nn_util.move_to_device(batch, cuda_device)
+            output_dict = model(**batch)
+            loss = output_dict.get("loss")
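+            # accumulate per-sample dialogue acts; this assumes the model's forward
+            # pass populates output_dict["dialog_act"] (MILU does, when decoding for
+            # its dialogue-act F1 metric)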
+            predict_results += output_dict["dialog_act"]
+
+            metrics = model.get_metrics()
+
+            if loss is not None:
+                loss_count += 1
+                if batch_weight_key:
+                    weight = output_dict[batch_weight_key].item()
+                else:
+                    weight = 1.0
+
+                total_weight += weight
+                total_loss += loss.item() * weight
+                # Report the average loss so far.
+                metrics["loss"] = total_loss / total_weight
+
+            if (not HasBeenWarned.tqdm_ignores_underscores and
+                        any(metric_name.startswith("_") for metric_name in metrics)):
+                logger.warning("Metrics with names beginning with \"_\" will "
+                               "not be logged to the tqdm progress bar.")
+                HasBeenWarned.tqdm_ignores_underscores = True
+            description = ', '.join(["%s: %.2f" % (name, value) for name, value
+                                     in metrics.items() if not name.startswith("_")]) + " ||"
+            generator_tqdm.set_description(description, refresh=False)
+
+        final_metrics = model.get_metrics(reset=True)
+        if loss_count > 0:
+            # Sanity check
+            if loss_count != batch_count:
+                raise RuntimeError("The model you are trying to evaluate only sometimes " +
+                                   "produced a loss!")
+            final_metrics["loss"] = total_loss / total_weight
+
+        return final_metrics, predict_results
+
+def description_from_metrics(metrics: Dict[str, float]) -> str:
+    if (not HasBeenWarned.tqdm_ignores_underscores and
+                any(metric_name.startswith("_") for metric_name in metrics)):
+        logger.warning("Metrics with names beginning with \"_\" will "
+                       "not be logged to the tqdm progress bar.")
+        HasBeenWarned.tqdm_ignores_underscores = True
+    return ', '.join(["%s: %.4f" % (name, value)
+                      for name, value in
+                      metrics.items() if not name.startswith("_")]) + " ||"
diff --git a/setup.py b/setup.py
index 2e426a70fc9a9e6c74bbd1f5a5ae0e4f7cffc8d4..900b92f3912b72831bb027da851da556a9d8ad3d 100755
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,7 @@ setup(
                 'Topic :: Scientific/Engineering :: Artificial Intelligence',
     ],
     install_requires=[
+        'matplotlib',
         'tabulate',
         'python-Levenshtein',
         'requests',
@@ -45,7 +46,6 @@ setup(
         'datasets>=1.8',
         'seqeval',
         'spacy',
-        'allennlp',
         'simplejson',
         'unidecode',
         'jieba',