diff --git a/tutorials/Getting_Started.ipynb b/tutorials/Getting_Started.ipynb index c672d92353a4da6729814ead1d1c8494b6f005f4..60cac56ab33e09a439a7136a642045c753fd8888 100644 --- a/tutorials/Getting_Started.ipynb +++ b/tutorials/Getting_Started.ipynb @@ -11,9 +11,6 @@ "\n", "In this tutorial, you will know how to\n", "- use the models in **ConvLab-3** to build a dialog agent.\n", - "- build a simulator to chat with the agent and evaluate the performance.\n", - "- try different module combinations.\n", - "- use analysis tool to diagnose your system.\n", "\n", "Let's get started!" ] @@ -43,16 +40,6 @@ "! git clone --depth 1 https://github.com/ConvLab/ConvLab-3.git && cd ConvLab-3 && pip install -e ." ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# installing en_core_web_sm for spacy to resolve error in BERTNLU\n", - "!python -m spacy download en_core_web_sm" - ] - }, { "cell_type": "markdown", "metadata": { @@ -78,11 +65,11 @@ "outputs": [], "source": [ "# common import: convlab.$module.$model.$dataset\n", - "from convlab.nlu.jointBERT.multiwoz import BERTNLU\n", - "from convlab.nlu.milu.multiwoz import MILU\n", - "from convlab.dst.rule.multiwoz import RuleDST\n", - "from convlab.policy.rule.multiwoz import RulePolicy\n", - "from convlab.nlg.template.multiwoz import TemplateNLG\n", + "from convlab.base_models.t5.nlu import T5NLU\n", + "from convlab.base_models.t5.dst import T5DST\n", + "from convlab.base_models.t5.nlg import T5NLG\n", + "from convlab.policy.vector.vector_nodes import VectorNodes\n", + "from convlab.policy.vtrace_DPT import VTRACE\n", "from convlab.dialog_agent import PipelineAgent, BiSession\n", "from convlab.evaluator.multiwoz_eval import MultiWozEvaluator\n", "from pprint import pprint\n", @@ -98,7 +85,7 @@ "id": "N-18Q6YKGEzY" }, "source": [ - "Then, create the models and build an agent:" + "Then, create the models and build an agent on Multiwoz 2.1 dataset:" ] }, { @@ -112,14 +99,20 @@ "outputs": [], "source": [ "# go to README.md of each model for more information\n", - "# BERT nlu\n", - "sys_nlu = BERTNLU()\n", - "# simple rule DST\n", - "sys_dst = RuleDST()\n", - "# rule policy\n", - "sys_policy = RulePolicy()\n", - "# template NLG\n", - "sys_nlg = TemplateNLG(is_user=False)\n", + "sys_nlu = T5NLU(speaker='user', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlu-multiwoz21')\n", + "sys_dst = T5DST(dataset_name='multiwoz21', speaker='user', context_window_size=100, model_name_or_path='ConvLab/t5-small-dst-multiwoz21')\n", + "# Download pre-trained DDPT model\n", + "! wget https://huggingface.co/ConvLab/ddpt-policy-multiwoz21/resolve/main/supervised.pol.mdl --directory-prefix=\"convlab/policy/vtrace_DPT\"\n", + "vectorizer = VectorNodes(dataset_name='multiwoz21',\n", + " use_masking=True,\n", + " manually_add_entity_names=True,\n", + " seed=0,\n", + " filter_state=True)\n", + "sys_policy = VTRACE(is_train=False,\n", + " seed=0,\n", + " vectorizer=vectorizer,\n", + " load_path=\"convlab/policy/vtrace_DPT/supervised\")\n", + "sys_nlg = T5NLG(speaker='system', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlg-multiwoz21')\n", "# assemble\n", "sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')" ] @@ -137,14 +130,11 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "9LYnDLysH1nX" - }, + "metadata": {}, "outputs": [], "source": [ - "sys_agent.response(\"I want to find a moderate hotel\")" + "sys_agent.init_session()\n", + "sys_agent.response(\"I want to find a hotel in the expensive pricerange\")" ] }, { @@ -224,378 +214,6 @@ "source": [ "sys_agent.response(\"Book a table for 5 , this Sunday .\")" ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "n6uuuRonIHvW" - }, - "source": [ - "## Build a simulator to chat with the agent and evaluate\n", - "\n", - "In many one-to-one task-oriented dialog system, a simulator is essential to train an RL agent. In our framework, we doesn't distinguish user or system. All speakers are **agents**. The simulator is also an agent, with specific policy inside for accomplishing the user goal.\n", - "\n", - "We use `Agenda` policy for the simulator, this policy requires dialog act input, which means we should set DST argument of `PipelineAgent` to None. Then the `PipelineAgent` will pass dialog act to policy directly. Refer to `PipelineAgent` doc for more details." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "pAMAJZSF7D5w" - }, - "outputs": [], - "source": [ - "# MILU\n", - "user_nlu = MILU()\n", - "# not use dst\n", - "user_dst = None\n", - "# rule policy\n", - "user_policy = RulePolicy(character='usr')\n", - "# template NLG\n", - "user_nlg = TemplateNLG(is_user=True)\n", - "# assemble\n", - "user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "Rghl-V2AJhRY" - }, - "source": [ - "\n", - "Now we have a simulator and an agent. we will use an existed simple one-to-one conversation controller BiSession, you can also define your own Session class for your special need.\n", - "\n", - "We add `MultiWozEvaluator` to evaluate the performance. It uses the parsed dialog act input and policy output dialog act to calculate **inform f1**, **book rate**, and whether the task is **success**." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "saUoLUUGJqDa" - }, - "outputs": [], - "source": [ - "evaluator = MultiWozEvaluator()\n", - "sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "kevGJZhFJzTU" - }, - "source": [ - "Let's make this two agents chat! The key is `next_turn` method of `BiSession` class." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "FIV_qkE49LzE" - }, - "outputs": [], - "source": [ - "def set_seed(r_seed):\n", - " random.seed(r_seed)\n", - " np.random.seed(r_seed)\n", - " torch.manual_seed(r_seed)\n", - "\n", - "set_seed(20200131)\n", - "\n", - "sys_response = ''\n", - "sess.init_session()\n", - "print('init goal:')\n", - "pprint(sess.evaluator.goal)\n", - "print('-'*50)\n", - "for i in range(20):\n", - " sys_response, user_response, session_over, reward = sess.next_turn(sys_response)\n", - " print('user:', user_response)\n", - " print('sys:', sys_response)\n", - " print()\n", - " if session_over is True:\n", - " break\n", - "print('task success:', sess.evaluator.task_success())\n", - "print('book rate:', sess.evaluator.book_rate())\n", - "print('inform precision/recall/f1:', sess.evaluator.inform_F1())\n", - "print('-'*50)\n", - "print('final goal:')\n", - "pprint(sess.evaluator.goal)\n", - "print('='*100)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "CKOQs1l8LpTR" - }, - "source": [ - "## Try different module combinations\n", - "\n", - "The combination modes of pipeline agent modules are flexible. We support joint models such as TRADE, SUMBT for word-DST and MDRG, HDSA, LaRL for word-Policy, once the input and output are matched with previous and next module. We also support End2End models such as Sequicity.\n", - "\n", - "Available models:\n", - "\n", - "- NLU: BERTNLU, MILU, SVMNLU\n", - "- DST: RuleDST\n", - "- Word-DST: SUMBT, TRADE (set `sys_nlu` to `None`)\n", - "- Policy: RulePolicy, Imitation, REINFORCE, PPO, GDPL\n", - "- Word-Policy: MDRG, HDSA, LaRL (set `sys_nlg` to `None`)\n", - "- NLG: Template, SCLSTM\n", - "- End2End: Sequicity, DAMD, RNN_rollout (directly used as `sys_agent`)\n", - "- Simulator policy: Agenda, VHUS (for `user_policy`)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "G-9G0VYUNYOI" - }, - "outputs": [], - "source": [ - "# available NLU models\n", - "from convlab.nlu.svm.multiwoz import SVMNLU\n", - "from convlab.nlu.jointBERT.multiwoz import BERTNLU\n", - "from convlab.nlu.milu.multiwoz import MILU\n", - "# available DST models\n", - "from convlab.dst.rule.multiwoz import RuleDST\n", - "from convlab.dst.sumbt.multiwoz import SUMBT\n", - "from convlab.dst.trade.multiwoz import TRADE\n", - "# available Policy models\n", - "from convlab.policy.rule.multiwoz import RulePolicy\n", - "from convlab.policy.ppo.multiwoz import PPOPolicy\n", - "from convlab.policy.pg.multiwoz import PGPolicy\n", - "from convlab.policy.mle.multiwoz import MLEPolicy\n", - "from convlab.policy.gdpl.multiwoz import GDPLPolicy\n", - "from convlab.policy.vhus.multiwoz import UserPolicyVHUS\n", - "from convlab.policy.mdrg.multiwoz import MDRGWordPolicy\n", - "from convlab.policy.hdsa.multiwoz import HDSA\n", - "from convlab.policy.larl.multiwoz import LaRL\n", - "# available NLG models\n", - "from convlab.nlg.template.multiwoz import TemplateNLG\n", - "from convlab.nlg.sclstm.multiwoz import SCLSTM\n", - "# available E2E models\n", - "from convlab.e2e.sequicity.multiwoz import Sequicity\n", - "from convlab.e2e.damd.multiwoz import Damd" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "6TS2_Tp1Nzvq" - }, - "source": [ - "NLU+RuleDST or Word-DST:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "DZMk9wAlONrz" - }, - "outputs": [], - "source": [ - "# NLU+RuleDST:\n", - "sys_nlu = BERTNLU()\n", - "# sys_nlu = MILU()\n", - "# sys_nlu = SVMNLU()\n", - "sys_dst = RuleDST()\n", - "\n", - "# or Word-DST:\n", - "# sys_nlu = None\n", - "# sys_dst = SUMBT()\n", - "# sys_dst = TRADE()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "gUUYsDMJPJRl" - }, - "source": [ - "Policy+NLG or Word-Policy:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "PTJ95x9UPHt4" - }, - "outputs": [], - "source": [ - "# Policy+NLG:\n", - "sys_policy = RulePolicy()\n", - "# sys_policy = PPOPolicy()\n", - "# sys_policy = PGPolicy()\n", - "# sys_policy = MLEPolicy()\n", - "# sys_policy = GDPLPolicy()\n", - "sys_nlg = TemplateNLG(is_user=False)\n", - "# sys_nlg = SCLSTM(is_user=False)\n", - "\n", - "# or Word-Policy:\n", - "# sys_policy = LaRL()\n", - "# sys_policy = HDSA()\n", - "# sys_policy = MDRGWordPolicy()\n", - "# sys_nlg = None" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "s9lGIv0oPupn" - }, - "source": [ - "Assemble the Pipeline system agent:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "IvLx7HUkPyZ5" - }, - "outputs": [], - "source": [ - "sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, 'sys')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "hR4A8WbZP2lc" - }, - "source": [ - "Or Directly use an end-to-end model:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "8VdUmcxoP6ej" - }, - "outputs": [], - "source": [ - "# sys_agent = Sequicity()\n", - "# sys_agent = Damd()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "_v-eoBtnP9J9" - }, - "source": [ - "Config an user agent similarly:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "UkHpjvR5QezN" - }, - "outputs": [], - "source": [ - "user_nlu = BERTNLU()\n", - "# user_nlu = MILU()\n", - "# user_nlu = SVMNLU()\n", - "user_dst = None\n", - "user_policy = RulePolicy(character='usr')\n", - "# user_policy = UserPolicyVHUS(load_from_zip=True)\n", - "user_nlg = TemplateNLG(is_user=True)\n", - "# user_nlg = SCLSTM(is_user=True)\n", - "user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "VJTBxEDhSAqc" - }, - "source": [ - "## Use analysis tool to diagnose the system\n", - "We provide an analysis tool presents rich statistics and summarizes common mistakes from simulated dialogues, which facilitates error analysis and\n", - "system improvement. The analyzer will generate an HTML report which contains\n", - "rich statistics of simulated dialogues. For more information, please refer to `convlab/util/analysis_tool`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "Hu2q3lQiSMDy" - }, - "outputs": [], - "source": [ - "from convlab.util.analysis_tool.analyzer import Analyzer\n", - "\n", - "# if sys_nlu!=None, set use_nlu=True to collect more information\n", - "analyzer = Analyzer(user_agent=user_agent, dataset='multiwoz')\n", - "\n", - "set_seed(20200131)\n", - "analyzer.comprehensive_analyze(sys_agent=sys_agent, model_name='sys_agent', total_dialog=100)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "AB-mDm0plQWd" - }, - "source": [ - "To compare several models:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "GKe_DNJUlWzh" - }, - "outputs": [], - "source": [ - "set_seed(20200131)\n", - "analyzer.compare_models(agent_list=[sys_agent1, sys_agent2], model_name=['sys_agent1', 'sys_agent2'], total_dialog=100)" - ] } ], "metadata": { @@ -606,20 +224,28 @@ "toc_visible": true }, "kernelspec": { - "display_name": "Python 3.8.12 ('py38')", + "display_name": "Python 3.8.13 ('convlab')", "language": "python", "name": "python3" }, "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", "name": "python", - "version": "3.8.12" + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" }, "vscode": { "interpreter": { - "hash": "0f9333403d680bc010aa5ce5a2f27ba398c9e47e92ba3724506306aa234cd07d" + "hash": "4a33698a9a325011d7646f7f090905d1bb6057c0d4ab1946e074e5e84aab8508" } } }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file