diff --git a/convlab/deploy/README.md b/convlab/deploy/README.md index c02f097cbbac8c6258bea3fa4c7dacc147d91e75..d445429b38bacd729ff99f373db188ba21b92079 100755 --- a/convlab/deploy/README.md +++ b/convlab/deploy/README.md @@ -59,22 +59,17 @@ ```json "nlu": { - "svm-cam": { - "class_path": "convlab.nlu.svm.camrest.nlu.SVMNLU", - "data_set": "camrest", - "ini_params": {"mode": "usr"}, - "model_name": "svm-cam", - "max_core": 1, - "preload": true, - "enable": true - }, - "svm-mul": { - "class_path": "convlab.nlu.svm.multiwoz.nlu.SVMNLU", + "t5nlu-mul": { + "class_path": "convlab.base_models.t5.nlu.T5NLU", "data_set": "multiwoz", - "ini_params": {"mode": "usr"}, - "model_name": "svm-mul", + "ini_params": { + "speaker": "user", + "context_window_size": 0, + "model_name_or_path": "ConvLab/t5-small-nlu-multiwoz21" + }, + "model_name": "t5nlu-mul", "max_core": 1, - "preload": false, + "preload": true, "enable": true } } diff --git a/convlab/deploy/dep_config.json b/convlab/deploy/dep_config.json index b9cb0b676c872da16606c551fd4696aa446693d3..a6fb20b913ec1732a3f53a1ac34d6971901b73ee 100755 --- a/convlab/deploy/dep_config.json +++ b/convlab/deploy/dep_config.json @@ -5,185 +5,62 @@ "session_time_out": 300 }, "nlu": { - "svm-cam": { - "class_path": "convlab.nlu.svm.camrest.nlu.SVMNLU", - "data_set": "camrest", - "ini_params": { - "mode": "usr" - }, - "model_name": "svm-cam", - "max_core": 1, - "preload": true, - "enable": true - }, - "svm-mul": { - "class_path": "convlab.nlu.svm.multiwoz.nlu.SVMNLU", - "data_set": "multiwoz", - "ini_params": { - "mode": "usr" - }, - "model_name": "svm-mul", - "max_core": 1, - "preload": false, - "enable": true - }, - "bert-cro": { - "class_path": "convlab.nlu.jointBERT.crosswoz.nlu.BERTNLU", - "data_set": "crosswoz", - "ini_params": { - "mode": "all", - "config_file": "crosswoz_all.json", - "model_file": "https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_crosswoz_all.zip" - }, - "model_name": "bert-cro", - "max_core": 1, - "preload": false, - "enable": true - }, - "bert-mul": { - "class_path": "convlab.nlu.jointBERT.multiwoz.nlu.BERTNLU", + "t5nlu-mul": { + "class_path": "convlab.base_models.t5.nlu.T5NLU", "data_set": "multiwoz", "ini_params": { - "mode": "all", - "config_file": "multiwoz_all.json", - "model_file": "https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_all.zip" + "speaker": "user", + "context_window_size": 0, + "model_name_or_path": "ConvLab/t5-small-nlu-multiwoz21" }, - "model_name": "bert-mul", + "model_name": "t5nlu-mul", "max_core": 1, - "preload": false, + "preload": true, "enable": true } }, "dst": { - "rule-cam": { - "class_path": "convlab.dst.rule.camrest.dst.RuleDST", - "data_set": "camrest", - "ini_params": {}, - "model_name": "rule-cam", - "max_core": 1, - "preload": true, - "enable": true - }, - "rule-mul": { - "class_path": "convlab.dst.rule.multiwoz.dst.RuleDST", + "t5dst-mul": { + "class_path": "convlab.base_models.t5.dst.T5DST", "data_set": "multiwoz", - "ini_params": {}, - "model_name": "rule-mul", - "max_core": 1, - "preload": true, - "enable": true - }, - "rule-cro": { - "class_path": "convlab.dst.rule.crosswoz.dst.RuleDST", - "data_set": "crosswoz", - "ini_params": {}, - "model_name": "rule-cro", - "max_core": 1, - "preload": true, - "enable": true - }, - "trade-mul": { - "class_path": "convlab.dst.trade.multiwoz.trade.MultiWOZTRADE", - "data_set": "multiwoz", - "ini_params": {}, - "model_name": "trade-mul", + "ini_params": { + "dataset_name": "multiwoz21", + "speaker": "user", + "context_window_size": 100, + "model_name_or_path": "ConvLab/t5-small-dst-multiwoz21" + }, + "model_name": "t5dst-mul", "max_core": 1, "preload": true, "enable": true } }, "policy": { - "mle-cam": { - "class_path": "convlab.policy.mle.camrest.mle.MLE", - "data_set": "camrest", - "ini_params": {}, - "model_name": "mle-cam", - "max_core": 1, - "preload": false, - "enable": true - }, - "mle-mul": { - "class_path": "convlab.policy.mle.multiwoz.mle.MLE", - "data_set": "multiwoz", - "ini_params": {}, - "model_name": "mle-mul", - "max_core": 1, - "preload": false, - "enable": true - }, - "rule-cam": { - "class_path": "convlab.policy.rule.camrest.rule_based_camrest_bot.RuleBasedCamrestBot", - "data_set": "camrest", - "ini_params": {}, - "model_name": "rule-cam", - "max_core": 1, - "preload": true, - "enable": true - }, - "rule-mul": { - "class_path": "convlab.policy.rule.multiwoz.rule_based_multiwoz_bot.RuleBasedMultiwozBot", + "ddpt-mul": { + "class_path": "convlab.policy.vtrace_DPT.VTRACE", "data_set": "multiwoz", - "ini_params": {}, - "model_name": "rule-mul", + "ini_params": { + "is_train": false, + "seed": 0, + "load_path": "supervised", + "dataset_name": "multiwoz21" + }, + "model_name": "ddpt-mul", "max_core": 1, "preload": true, "enable": true - }, - "mle-cro": { - "class_path": "convlab.policy.mle.crosswoz.mle.MLE", - "data_set": "crosswoz", - "ini_params": {}, - "model_name": "mle-cro", - "max_core": 1, - "preload": false, - "enable": true } }, "nlg": { - "tmp-manual-cam": { - "class_path": "convlab.nlg.template.camrest.nlg.TemplateNLG", - "data_set": "camrest", - "ini_params": { - "is_user": false - }, - "model_name": "tmp-manual-cam", - "max_core": 1, - "preload": true, - "enable": true - }, - "tmp-auto_manual-cam": { - "class_path": "convlab.nlg.template.camrest.nlg.TemplateNLG", - "data_set": "camrest", - "ini_params": { - "is_user": false, - "mode": "auto_manual" - }, - "model_name": "tmp-auto_manual-cam", - "max_core": 1, - "preload": true, - "enable": true - }, - "tmp-auto_manual-mul": { - "class_path": "convlab.nlg.template.multiwoz.nlg.TemplateNLG", + "t5nlg-mul": { + "class_path": "convlab.base_models.t5.nlg.T5NLG", "data_set": "multiwoz", "ini_params": { - "is_user": false, - "mode": "auto_manual" - }, - "model_name": "tmp-auto_manual-mul", - "max_core": 1, - "preload": true, - "enable": true - - }, - "tmp-auto_manual-cro": { - "class_path": "convlab.nlg.template.crosswoz.nlg.TemplateNLG", - "data_set": "crosswoz", - "ini_params": { - "is_user": false, - "mode": "auto_manual" + "speaker": "system", + "context_window_size": 0, + "model_name_or_path": "ConvLab/t5-small-nlg-multiwoz21" }, - "model_name": "tmp-auto_manual-cro", + "model_name": "t5nlg-mul", "max_core": 1, "preload": true, "enable": true diff --git a/convlab/policy/vtrace_DPT/vtrace.py b/convlab/policy/vtrace_DPT/vtrace.py index 2918f4dfb019cef3cfe9e0d3981b70c9502700c8..5b031c4c297d075d3de66b1f776101a3d8a2e614 100644 --- a/convlab/policy/vtrace_DPT/vtrace.py +++ b/convlab/policy/vtrace_DPT/vtrace.py @@ -8,6 +8,7 @@ import torch.nn as nn import urllib.request from torch import optim +from convlab.policy.vector.vector_nodes import VectorNodes from convlab.policy.vtrace_DPT.transformer_model.EncoderDecoder import EncoderDecoder from convlab.policy.vtrace_DPT.transformer_model.EncoderCritic import EncoderCritic from ... import Policy @@ -21,7 +22,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") class VTRACE(nn.Module, Policy): - def __init__(self, is_train=True, seed=0, vectorizer=None, load_path=""): + def __init__(self, is_train=True, seed=0, vectorizer=None, load_path="", **kwargs): super(VTRACE, self).__init__() @@ -59,6 +60,13 @@ class VTRACE(nn.Module, Policy): self.last_action = None + if vectorizer is None: + vectorizer = VectorNodes(dataset_name=kwargs['dataset_name'], + use_masking=kwargs.get('use_masking', True), + manually_add_entity_names=kwargs.get('manually_add_entity_names', True), + seed=seed, + filter_state=kwargs.get('filter_state', True)) + self.vector = vectorizer self.cfg['dataset_name'] = self.vector.dataset_name self.policy = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE) @@ -104,7 +112,6 @@ class VTRACE(nn.Module, Policy): Returns: action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...}) """ - if not self.is_train: for param in self.policy.parameters(): param.requires_grad = False