diff --git a/convlab/base_models/t5/README.md b/convlab/base_models/t5/README.md index 754a1db40474b485c78d5a7857ffef70b4658c24..0c420f6ab76a9c7e78c9853c56419061b0b9056a 100644 --- a/convlab/base_models/t5/README.md +++ b/convlab/base_models/t5/README.md @@ -43,21 +43,28 @@ Trained models and their performance are available in [Hugging Face Hub](https:/ | ------------------------------------------------------------ | ------------- | ---------------------------- | | [t5-small-goal2dialogue-multiwoz21](https://huggingface.co/ConvLab/t5-small-goal2dialogue-multiwoz21) | Goal2Dialogue | MultiWOZ 2.1 | | [t5-small-nlu-multiwoz21](https://huggingface.co/ConvLab/t5-small-nlu-multiwoz21) | NLU | MultiWOZ 2.1 | +| [t5-small-nlu-all-multiwoz21](https://huggingface.co/ConvLab/t5-small-nlu-all-multiwoz21) | NLU | MultiWOZ 2.1 all utterances | | [t5-small-nlu-sgd](https://huggingface.co/ConvLab/t5-small-nlu-sgd) | NLU | SGD | | [t5-small-nlu-tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-nlu-tm1_tm2_tm3) | NLU | TM1+TM2+TM3 | | [t5-small-nlu-multiwoz21_sgd_tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-nlu-multiwoz21_sgd_tm1_tm2_tm3) | NLU | MultiWOZ 2.1+SGD+TM1+TM2+TM3 | +| [mt5-small-nlu-all-crosswoz](https://huggingface.co/ConvLab/mt5-small-nlu-all-crosswoz) | NLU | CrossWOZ all utterances | | [t5-small-nlu-multiwoz21-context3](https://huggingface.co/ConvLab/t5-small-nlu-multiwoz21-context3) | NLU (context=3) | MultiWOZ 2.1 | +| [t5-small-nlu-all-multiwoz21-context3](https://huggingface.co/ConvLab/t5-small-nlu-all-multiwoz21-context3) | NLU (context=3) | MultiWOZ 2.1 all utterances | | [t5-small-nlu-tm1-context3](https://huggingface.co/ConvLab/t5-small-nlu-tm1-context3) | NLU (context=3) | TM1 | | [t5-small-nlu-tm2-context3](https://huggingface.co/ConvLab/t5-small-nlu-tm2-context3) | NLU (context=3) | TM2 | | [t5-small-nlu-tm3-context3](https://huggingface.co/ConvLab/t5-small-nlu-tm3-context3) | NLU (context=3) | TM3 | | [t5-small-dst-multiwoz21](https://huggingface.co/ConvLab/t5-small-dst-multiwoz21) | DST | MultiWOZ 2.1 | | [t5-small-dst-sgd](https://huggingface.co/ConvLab/t5-small-dst-sgd) | DST | SGD | | [t5-small-dst-tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-dst-tm1_tm2_tm3) | DST | TM1+TM2+TM3 | +| [mt5-small-dst-crosswoz](https://huggingface.co/ConvLab/mt5-small-dst-crosswoz) | DST | CrossWOZ | | [t5-small-dst-multiwoz21_sgd_tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-dst-multiwoz21_sgd_tm1_tm2_tm3) | DST | MultiWOZ 2.1+SGD+TM1+TM2+TM3 | | [t5-small-nlg-multiwoz21](https://huggingface.co/ConvLab/t5-small-nlg-multiwoz21) | NLG | MultiWOZ 2.1 | +| [t5-small-nlg-user-multiwoz21](https://huggingface.co/ConvLab/t5-small-nlg-user-multiwoz21) | NLG | MultiWOZ 2.1 user utterances | +| [t5-small-nlg-all-multiwoz21](https://huggingface.co/ConvLab/t5-small-nlg-all-multiwoz21) | NLG | MultiWOZ 2.1 all utterances | | [t5-small-nlg-sgd](https://huggingface.co/ConvLab/t5-small-nlg-sgd) | NLG | SGD | | [t5-small-nlg-tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-nlg-tm1_tm2_tm3) | NLG | TM1+TM2+TM3 | | [t5-small-nlg-multiwoz21_sgd_tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-nlg-multiwoz21_sgd_tm1_tm2_tm3) | NLG | MultiWOZ 2.1+SGD+TM1+TM2+TM3 | +| [mt5-small-nlg-all-crosswoz](https://huggingface.co/ConvLab/mt5-small-nlg-all-crosswoz) | NLG | CrossWOZ all utterances | ## Interface diff --git a/convlab/dst/setsumbt/tracker.py b/convlab/dst/setsumbt/tracker.py index eca7f1749369f9569d6b923312a93cd317e0701c..e40332048188b7f5a1f43e397896a8b6b201553d 100644 --- a/convlab/dst/setsumbt/tracker.py +++ b/convlab/dst/setsumbt/tracker.py @@ -346,17 +346,26 @@ class SetSUMBTTracker(DST): state_entropy = None # Construct request action prediction - request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5] - request_acts = [slot.split('-', 1) for slot in request_acts] - request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts] + if request_probs is not None: + request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5] + request_acts = [slot.split('-', 1) for slot in request_acts] + request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts] + else: + request_acts = list() # Construct active domain set - active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()} + if active_domain_probs is not None: + active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()} + else: + active_domains = dict() # Construct general domain action - general_acts = general_act_probs[0, 0, :].argmax(-1).item() - general_acts = [[], ['bye'], ['thank']][general_acts] - general_acts = [[act, 'general', 'none', 'none'] for act in general_acts] + if general_act_probs is not None: + general_acts = general_act_probs[0, 0, :].argmax(-1).item() + general_acts = [[], ['bye'], ['thank']][general_acts] + general_acts = [[act, 'general', 'none', 'none'] for act in general_acts] + else: + general_acts = list() user_acts = request_acts + general_acts diff --git a/convlab/evaluator/multiwoz_eval.py b/convlab/evaluator/multiwoz_eval.py index cb6c8feb73aea6e4481a51fa0eb2466a6a07d1c6..f300914eb9382de570fbc08fae55e048f0f07ffa 100755 --- a/convlab/evaluator/multiwoz_eval.py +++ b/convlab/evaluator/multiwoz_eval.py @@ -6,10 +6,10 @@ import numpy as np import pdb from copy import deepcopy -from data.unified_datasets.multiwoz21.preprocess import reverse_da, reverse_da_slot_name_map +# from data.unified_datasets.multiwoz21.preprocess import reverse_da, reverse_da_slot_name_map from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA from convlab.evaluator.evaluator import Evaluator -from data.unified_datasets.multiwoz21.preprocess import reverse_da_slot_name_map +# from data.unified_datasets.multiwoz21.preprocess import reverse_da_slot_name_map from convlab.policy.rule.multiwoz.policy_agenda_multiwoz import unified_format, act_dict_to_flat_tuple from convlab.util.multiwoz.dbquery import Database from convlab.util import relative_import_module_from_unified_datasets @@ -28,6 +28,7 @@ REF_SYS_DA_M['taxi']['phone'] = 'phone' REF_SYS_DA_M['taxi']['car'] = 'car type' reverse_da = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da') +reverse_da_slot_name_map = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da_slot_name_map') requestable = \ diff --git a/convlab/nlu/jointBERT/README.md b/convlab/nlu/jointBERT/README.md index 647007548f18ad5c3adf97565acb7897ab025c5f..80f18a837e1056a7326b9b49d486f9c2e676e55b 100755 --- a/convlab/nlu/jointBERT/README.md +++ b/convlab/nlu/jointBERT/README.md @@ -40,6 +40,7 @@ To illustrate that it is easy to use the model for any dataset that in our unifi <tr> <th></th> <th colspan=2>MultiWOZ 2.1</th> + <th colspan=2>MultiWOZ 2.1 all utterances</th> <th colspan=2>Taskmaster-1</th> <th colspan=2>Taskmaster-2</th> <th colspan=2>Taskmaster-3</th> @@ -52,12 +53,14 @@ To illustrate that it is easy to use the model for any dataset that in our unifi <th>Acc</th><th>F1</th> <th>Acc</th><th>F1</th> <th>Acc</th><th>F1</th> + <th>Acc</th><th>F1</th> </tr> </thead> <tbody> <tr> <td>BERTNLU</td> <td>74.5</td><td>85.9</td> + <td>59.5</td><td>80.0</td> <td>72.8</td><td>50.6</td> <td>79.2</td><td>70.6</td> <td>86.1</td><td>81.9</td> @@ -65,6 +68,7 @@ To illustrate that it is easy to use the model for any dataset that in our unifi <tr> <td>BERTNLU (context=3)</td> <td>80.6</td><td>90.3</td> + <td>58.1</td><td>79.6</td> <td>74.2</td><td>52.7</td> <td>80.9</td><td>73.3</td> <td>87.8</td><td>83.8</td> diff --git a/convlab/nlu/jointBERT/unified_datasets/configs/multiwoz21_all.json b/convlab/nlu/jointBERT/unified_datasets/configs/multiwoz21_all.json new file mode 100644 index 0000000000000000000000000000000000000000..b996324cf5d6195290b7aa59cd7745ad6851abb8 --- /dev/null +++ b/convlab/nlu/jointBERT/unified_datasets/configs/multiwoz21_all.json @@ -0,0 +1,27 @@ +{ + "dataset_name": "multiwoz21", + "data_dir": "unified_datasets/data/multiwoz21/all/context_window_size_0", + "output_dir": "unified_datasets/output/multiwoz21/all/context_window_size_0", + "zipped_model_path": "unified_datasets/output/multiwoz21/all/context_window_size_0/bertnlu_unified_multiwoz21_all_context0.zip", + "log_dir": "unified_datasets/output/multiwoz21/all/context_window_size_0/log", + "DEVICE": "cuda:0", + "seed": 2019, + "cut_sen_len": 40, + "use_bert_tokenizer": true, + "context_window_size": 0, + "model": { + "finetune": true, + "context": false, + "context_grad": false, + "pretrained_weights": "bert-base-uncased", + "check_step": 1000, + "max_step": 10000, + "batch_size": 128, + "learning_rate": 1e-4, + "adam_epsilon": 1e-8, + "warmup_steps": 0, + "weight_decay": 0.0, + "dropout": 0.1, + "hidden_units": 768 + } + } \ No newline at end of file diff --git a/convlab/nlu/jointBERT/unified_datasets/configs/multiwoz21_all_context3.json b/convlab/nlu/jointBERT/unified_datasets/configs/multiwoz21_all_context3.json new file mode 100644 index 0000000000000000000000000000000000000000..b23ef9d63900caf39b9dea548754e310f7577a93 --- /dev/null +++ b/convlab/nlu/jointBERT/unified_datasets/configs/multiwoz21_all_context3.json @@ -0,0 +1,27 @@ +{ + "dataset_name": "multiwoz21", + "data_dir": "unified_datasets/data/multiwoz21/all/context_window_size_3", + "output_dir": "unified_datasets/output/multiwoz21/all/context_window_size_3", + "zipped_model_path": "unified_datasets/output/multiwoz21/all/context_window_size_3/bertnlu_unified_multiwoz21_all_context3.zip", + "log_dir": "unified_datasets/output/multiwoz21/all/context_window_size_3/log", + "DEVICE": "cuda:0", + "seed": 2019, + "cut_sen_len": 40, + "use_bert_tokenizer": true, + "context_window_size": 3, + "model": { + "finetune": true, + "context": true, + "context_grad": true, + "pretrained_weights": "bert-base-uncased", + "check_step": 1000, + "max_step": 10000, + "batch_size": 128, + "learning_rate": 1e-4, + "adam_epsilon": 1e-8, + "warmup_steps": 0, + "weight_decay": 0.0, + "dropout": 0.1, + "hidden_units": 1536 + } + } \ No newline at end of file diff --git a/convlab/policy/lava/multiwoz/lava.py b/convlab/policy/lava/multiwoz/lava.py index 76d177396c4d9a9d4c7e20f18ce652f6f8852ad5..ec07d36eda6f467706a6e6cb0c30a2ba3b23d37c 100755 --- a/convlab/policy/lava/multiwoz/lava.py +++ b/convlab/policy/lava/multiwoz/lava.py @@ -11,7 +11,9 @@ from convlab.policy import Policy from convlab.util.file_util import cached_path from convlab.util.multiwoz.state import default_state # from convlab.util.multiwoz.dbquery import Database -from data.unified_datasets.multiwoz21.database import Database +# from data.unified_datasets.multiwoz21.database import Database +from convlab.util import load_database +Database = load_database('multiwoz21') from copy import deepcopy import json import os diff --git a/examples/agent_examples/test_SUMBT-LaRL.py b/examples/agent_examples/test_SUMBT-LaRL.py index 3f2e9a28912114ba931493b99c78e6300b80f6b4..91004b206d78fbb907d4dc09ab1fe230b51a400c 100755 --- a/examples/agent_examples/test_SUMBT-LaRL.py +++ b/examples/agent_examples/test_SUMBT-LaRL.py @@ -1,11 +1,11 @@ # available NLU models # from convlab.nlu.svm.multiwoz import SVMNLU -from convlab.nlu.jointBERT.multiwoz import BERTNLU +from convlab.nlu.jointBERT.unified_datasets import BERTNLU # from convlab.nlu.milu.multiwoz import MILU # available DST models # from convlab.dst.rule.multiwoz import RuleDST # from convlab.dst.mdbt.multiwoz import MDBT -from convlab.dst.sumbt.multiwoz import SUMBT +from convlab.dst.setsumbt import SetSUMBTTracker # from convlab.dst.trade.multiwoz import TRADE # from convlab.dst.comer.multiwoz import COMER # available Policy models @@ -44,7 +44,7 @@ def test_end2end(): # BERT nlu sys_nlu = None # simple rule DST - sys_dst = SUMBT() + sys_dst = SetSUMBTTracker(model_type='bert', model_path="path/to/sumbt/checkpoint") # rule policy sys_policy = LaRL() # template NLG @@ -53,7 +53,7 @@ def test_end2end(): sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys') # BERT nlu trained on sys utterance - user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json', + user_nlu = BERTNLU(mode='sys', config_file='multiwoz21_sys_context.json', model_file='https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_sys_context.zip') # not use dst user_dst = None diff --git a/tutorials/Getting_Started.ipynb b/tutorials/Getting_Started.ipynb index 5d32d4e1f810ef30a30cd74f77d20c6b6ba3fb4c..9eb956a6d846bf55b40b9f3cede9ddca72a07a4e 100644 --- a/tutorials/Getting_Started.ipynb +++ b/tutorials/Getting_Started.ipynb @@ -11,9 +11,6 @@ "\n", "In this tutorial, you will know how to\n", "- use the models in **ConvLab-3** to build a dialog agent.\n", - "- build a simulator to chat with the agent and evaluate the performance.\n", - "- try different module combinations.\n", - "- use analysis tool to diagnose your system.\n", "\n", "Let's get started!" ] @@ -43,16 +40,6 @@ "! git clone --depth 1 https://github.com/ConvLab/ConvLab-3.git && cd ConvLab-3 && pip install -e ." ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# installing en_core_web_sm for spacy to resolve error in BERTNLU\n", - "!python -m spacy download en_core_web_sm" - ] - }, { "cell_type": "markdown", "metadata": { @@ -78,11 +65,11 @@ "outputs": [], "source": [ "# common import: convlab.$module.$model.$dataset\n", - "from convlab.nlu.jointBERT.multiwoz import BERTNLU\n", - "from convlab.nlu.milu.multiwoz import MILU\n", - "from convlab.dst.rule.multiwoz import RuleDST\n", - "from convlab.policy.rule.multiwoz import RulePolicy\n", - "from convlab.nlg.template.multiwoz import TemplateNLG\n", + "from convlab.base_models.t5.nlu import T5NLU\n", + "from convlab.base_models.t5.dst import T5DST\n", + "from convlab.base_models.t5.nlg import T5NLG\n", + "from convlab.policy.vector.vector_nodes import VectorNodes\n", + "from convlab.policy.vtrace_DPT import VTRACE\n", "from convlab.dialog_agent import PipelineAgent, BiSession\n", "from convlab.evaluator.multiwoz_eval import MultiWozEvaluator\n", "from pprint import pprint\n", @@ -98,7 +85,7 @@ "id": "N-18Q6YKGEzY" }, "source": [ - "Then, create the models and build an agent:" + "Then, create the models and build an agent on Multiwoz 2.1 dataset:" ] }, { @@ -112,14 +99,20 @@ "outputs": [], "source": [ "# go to README.md of each model for more information\n", - "# BERT nlu\n", - "sys_nlu = BERTNLU()\n", - "# simple rule DST\n", - "sys_dst = RuleDST()\n", - "# rule policy\n", - "sys_policy = RulePolicy()\n", - "# template NLG\n", - "sys_nlg = TemplateNLG(is_user=False)\n", + "sys_nlu = T5NLU(speaker='user', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlu-multiwoz21')\n", + "sys_dst = T5DST(dataset_name='multiwoz21', speaker='user', context_window_size=100, model_name_or_path='ConvLab/t5-small-dst-multiwoz21')\n", + "# Download pre-trained DDPT model\n", + "! wget https://huggingface.co/ConvLab/ddpt-policy-multiwoz21/resolve/main/supervised.pol.mdl --directory-prefix=\"convlab/policy/vtrace_DPT\"\n", + "vectorizer = VectorNodes(dataset_name='multiwoz21',\n", + " use_masking=True,\n", + " manually_add_entity_names=True,\n", + " seed=0,\n", + " filter_state=True)\n", + "sys_policy = VTRACE(is_train=False,\n", + " seed=0,\n", + " vectorizer=vectorizer,\n", + " load_path=\"convlab/policy/vtrace_DPT/supervised\")\n", + "sys_nlg = T5NLG(speaker='system', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlg-multiwoz21')\n", "# assemble\n", "sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')" ] @@ -137,14 +130,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "9LYnDLysH1nX" - }, + "metadata": {}, "outputs": [], "source": [ - "sys_agent.response(\"I want to find a moderate hotel\")" + "sys_agent.init_session()\n", + "sys_agent.response(\"I want to find a hotel in the expensive pricerange\")" ] }, { @@ -224,378 +214,6 @@ "source": [ "sys_agent.response(\"Book a table for 5 , this Sunday .\")" ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "n6uuuRonIHvW" - }, - "source": [ - "## Build a simulator to chat with the agent and evaluate\n", - "\n", - "In many one-to-one task-oriented dialog system, a simulator is essential to train an RL agent. In our framework, we doesn't distinguish user or system. All speakers are **agents**. The simulator is also an agent, with specific policy inside for accomplishing the user goal.\n", - "\n", - "We use `Agenda` policy for the simulator, this policy requires dialog act input, which means we should set DST argument of `PipelineAgent` to None. Then the `PipelineAgent` will pass dialog act to policy directly. Refer to `PipelineAgent` doc for more details." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "pAMAJZSF7D5w" - }, - "outputs": [], - "source": [ - "# MILU\n", - "user_nlu = MILU()\n", - "# not use dst\n", - "user_dst = None\n", - "# rule policy\n", - "user_policy = RulePolicy(character='usr')\n", - "# template NLG\n", - "user_nlg = TemplateNLG(is_user=True)\n", - "# assemble\n", - "user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Rghl-V2AJhRY" - }, - "source": [ - "\n", - "Now we have a simulator and an agent. we will use an existed simple one-to-one conversation controller BiSession, you can also define your own Session class for your special need.\n", - "\n", - "We add `MultiWozEvaluator` to evaluate the performance. It uses the parsed dialog act input and policy output dialog act to calculate **inform f1**, **book rate**, and whether the task is **success**." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "saUoLUUGJqDa" - }, - "outputs": [], - "source": [ - "evaluator = MultiWozEvaluator()\n", - "sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "kevGJZhFJzTU" - }, - "source": [ - "Let's make this two agents chat! The key is `next_turn` method of `BiSession` class." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "FIV_qkE49LzE" - }, - "outputs": [], - "source": [ - "def set_seed(r_seed):\n", - " random.seed(r_seed)\n", - " np.random.seed(r_seed)\n", - " torch.manual_seed(r_seed)\n", - "\n", - "set_seed(20200131)\n", - "\n", - "sys_response = ''\n", - "sess.init_session()\n", - "print('init goal:')\n", - "pprint(sess.evaluator.goal)\n", - "print('-'*50)\n", - "for i in range(20):\n", - " sys_response, user_response, session_over, reward = sess.next_turn(sys_response)\n", - " print('user:', user_response)\n", - " print('sys:', sys_response)\n", - " print()\n", - " if session_over is True:\n", - " break\n", - "print('task success:', sess.evaluator.task_success())\n", - "print('book rate:', sess.evaluator.book_rate())\n", - "print('inform precision/recall/f1:', sess.evaluator.inform_F1())\n", - "print('-'*50)\n", - "print('final goal:')\n", - "pprint(sess.evaluator.goal)\n", - "print('='*100)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "CKOQs1l8LpTR" - }, - "source": [ - "## Try different module combinations\n", - "\n", - "The combination modes of pipeline agent modules are flexible. We support joint models such as TRADE, SUMBT for word-DST and MDRG, HDSA, LaRL for word-Policy, once the input and output are matched with previous and next module. We also support End2End models such as Sequicity.\n", - "\n", - "Available models:\n", - "\n", - "- NLU: BERTNLU, MILU, SVMNLU\n", - "- DST: RuleDST\n", - "- Word-DST: SUMBT, TRADE (set `sys_nlu` to `None`)\n", - "- Policy: RulePolicy, Imitation, REINFORCE, PPO, GDPL\n", - "- Word-Policy: MDRG, HDSA, LaRL (set `sys_nlg` to `None`)\n", - "- NLG: Template, SCLSTM\n", - "- End2End: Sequicity, DAMD, RNN_rollout (directly used as `sys_agent`)\n", - "- Simulator policy: Agenda, VHUS (for `user_policy`)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "G-9G0VYUNYOI" - }, - "outputs": [], - "source": [ - "# available NLU models\n", - "from convlab.nlu.svm.multiwoz import SVMNLU\n", - "from convlab.nlu.jointBERT.multiwoz import BERTNLU\n", - "from convlab.nlu.milu.multiwoz import MILU\n", - "# available DST models\n", - "from convlab.dst.rule.multiwoz import RuleDST\n", - "from convlab.dst.sumbt.multiwoz import SUMBT\n", - "from convlab.dst.trade.multiwoz import TRADE\n", - "# available Policy models\n", - "from convlab.policy.rule.multiwoz import RulePolicy\n", - "from convlab.policy.ppo.multiwoz import PPOPolicy\n", - "from convlab.policy.pg.multiwoz import PGPolicy\n", - "from convlab.policy.mle.multiwoz import MLEPolicy\n", - "from convlab.policy.gdpl.multiwoz import GDPLPolicy\n", - "from convlab.policy.vhus.multiwoz import UserPolicyVHUS\n", - "from convlab.policy.mdrg.multiwoz import MDRGWordPolicy\n", - "from convlab.policy.hdsa.multiwoz import HDSA\n", - "from convlab.policy.larl.multiwoz import LaRL\n", - "# available NLG models\n", - "from convlab.nlg.template.multiwoz import TemplateNLG\n", - "from convlab.nlg.sclstm.multiwoz import SCLSTM\n", - "# available E2E models\n", - "from convlab.e2e.sequicity.multiwoz import Sequicity\n", - "from convlab.e2e.damd.multiwoz import Damd" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "6TS2_Tp1Nzvq" - }, - "source": [ - "NLU+RuleDST or Word-DST:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "DZMk9wAlONrz" - }, - "outputs": [], - "source": [ - "# NLU+RuleDST:\n", - "sys_nlu = BERTNLU()\n", - "# sys_nlu = MILU()\n", - "# sys_nlu = SVMNLU()\n", - "sys_dst = RuleDST()\n", - "\n", - "# or Word-DST:\n", - "# sys_nlu = None\n", - "# sys_dst = SUMBT()\n", - "# sys_dst = TRADE()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "gUUYsDMJPJRl" - }, - "source": [ - "Policy+NLG or Word-Policy:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "PTJ95x9UPHt4" - }, - "outputs": [], - "source": [ - "# Policy+NLG:\n", - "sys_policy = RulePolicy()\n", - "# sys_policy = PPOPolicy()\n", - "# sys_policy = PGPolicy()\n", - "# sys_policy = MLEPolicy()\n", - "# sys_policy = GDPLPolicy()\n", - "sys_nlg = TemplateNLG(is_user=False)\n", - "# sys_nlg = SCLSTM(is_user=False)\n", - "\n", - "# or Word-Policy:\n", - "# sys_policy = LaRL()\n", - "# sys_policy = HDSA()\n", - "# sys_policy = MDRGWordPolicy()\n", - "# sys_nlg = None" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "s9lGIv0oPupn" - }, - "source": [ - "Assemble the Pipeline system agent:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "IvLx7HUkPyZ5" - }, - "outputs": [], - "source": [ - "sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, 'sys')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "hR4A8WbZP2lc" - }, - "source": [ - "Or Directly use an end-to-end model:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "8VdUmcxoP6ej" - }, - "outputs": [], - "source": [ - "# sys_agent = Sequicity()\n", - "# sys_agent = Damd()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "_v-eoBtnP9J9" - }, - "source": [ - "Config an user agent similarly:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "UkHpjvR5QezN" - }, - "outputs": [], - "source": [ - "user_nlu = BERTNLU()\n", - "# user_nlu = MILU()\n", - "# user_nlu = SVMNLU()\n", - "user_dst = None\n", - "user_policy = RulePolicy(character='usr')\n", - "# user_policy = UserPolicyVHUS(load_from_zip=True)\n", - "user_nlg = TemplateNLG(is_user=True)\n", - "# user_nlg = SCLSTM(is_user=True)\n", - "user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "VJTBxEDhSAqc" - }, - "source": [ - "## Use analysis tool to diagnose the system\n", - "We provide an analysis tool presents rich statistics and summarizes common mistakes from simulated dialogues, which facilitates error analysis and\n", - "system improvement. The analyzer will generate an HTML report which contains\n", - "rich statistics of simulated dialogues. For more information, please refer to `convlab/util/analysis_tool`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "Hu2q3lQiSMDy" - }, - "outputs": [], - "source": [ - "from convlab.util.analysis_tool.analyzer import Analyzer\n", - "\n", - "# if sys_nlu!=None, set use_nlu=True to collect more information\n", - "analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')\n", - "\n", - "set_seed(20200131)\n", - "analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='sys_agent', total_dialog=100)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "AB-mDm0plQWd" - }, - "source": [ - "To compare several models:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "GKe_DNJUlWzh" - }, - "outputs": [], - "source": [ - "set_seed(20200131)\n", - "analyzer.compare_models(agent_list=[sys_agent1, sys_agent2], model_name=['sys_agent1', 'sys_agent2'], total_dialog=100)" - ] } ], "metadata": { @@ -606,17 +224,25 @@ "toc_visible": true }, "kernelspec": { - "display_name": "Python 3.6.9 64-bit", + "display_name": "Python 3.8.13 ('convlab')", "language": "python", "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "version": "3.6.9" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" }, "vscode": { "interpreter": { - "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + "hash": "4a33698a9a325011d7646f7f090905d1bb6057c0d4ab1946e074e5e84aab8508" } } },