diff --git a/convlab/base_models/t5/trainer.py b/convlab/base_models/t5/trainer.py index ba0bce934b9f173aecc797119d4f28c0c974251b..dd8c2dfccec649359b0c03fe6b47dcf57bb01b22 100644 --- a/convlab/base_models/t5/trainer.py +++ b/convlab/base_models/t5/trainer.py @@ -16,9 +16,11 @@ # from dataclasses import dataclass, field # import torch # from torch import nn +from torch.utils.data import Dataset # from transformers.deepspeed import is_deepspeed_zero3_enabled # from transformers.utils import logging, cached_property, torch_required +from transformers.trainer_utils import PredictionOutput from transformers.training_args import ( os, torch, @@ -161,6 +163,103 @@ class ConvLabSeq2SeqTrainingArguments(Seq2SeqTrainingArguments): class ConvLabSeq2SeqTrainer(Seq2SeqTrainer): + # modifed from Seq2SeqTrainer of 4.26.1: https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/trainer_seq2seq.py + # add generation args in `prediction_step` + def evaluate( + self, + eval_dataset: Optional[Dataset] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + **gen_kwargs + ) -> Dict[str, float]: + """ + Run evaluation and returns metrics. + The calling script will be responsible for providing a method to compute metrics, as they are task-dependent + (pass it to the init `compute_metrics` argument). + You can also subclass and override this method to inject custom behavior. + Args: + eval_dataset (`Dataset`, *optional*): + Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns + not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` + method. + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is `"eval"` (default) + max_length (`int`, *optional*): + The maximum target length to use when predicting with the generate method. + num_beams (`int`, *optional*): + Number of beams for beam search that will be used when predicting with the generate method. 1 means no + beam search. + gen_kwargs: + Additional `generate` specific kwargs. + Returns: + A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The + dictionary also contains the epoch number which comes from the training state. + """ + + gen_kwargs = gen_kwargs.copy() + if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: + gen_kwargs["max_length"] = self.args.generation_max_length + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams + ) + self._gen_kwargs = gen_kwargs + + return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + + def predict( + self, + test_dataset: Dataset, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "test", + **gen_kwargs + ) -> PredictionOutput: + """ + Run prediction and returns predictions and potential metrics. + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in `evaluate()`. + Args: + test_dataset (`Dataset`): + Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. Has to implement the method `__len__` + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is `"eval"` (default) + max_length (`int`, *optional*): + The maximum target length to use when predicting with the generate method. + num_beams (`int`, *optional*): + Number of beams for beam search that will be used when predicting with the generate method. 1 means no + beam search. + gen_kwargs: + Additional `generate` specific kwargs. + <Tip> + If your predictions or labels have different sequence lengths (for instance because you're doing dynamic + padding in a token classification task) the predictions will be padded (on the right) to allow for + concatenation into one array. The padding index is -100. + </Tip> + Returns: *NamedTuple* A namedtuple with the following keys: + - predictions (`np.ndarray`): The predictions on `test_dataset`. + - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). + - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained + labels). + """ + + gen_kwargs = gen_kwargs.copy() + if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: + gen_kwargs["max_length"] = self.args.generation_max_length + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams + ) + self._gen_kwargs = gen_kwargs + + return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) + def prediction_step( self, model: nn.Module, @@ -194,16 +293,25 @@ class ConvLabSeq2SeqTrainer(Seq2SeqTrainer): inputs = self._prepare_inputs(inputs) # XXX: adapt synced_gpus for fairscale as well - gen_kwargs = { - "max_length": self._max_length if self._max_length is not None else self.model.config.max_length, - "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams, - "synced_gpus": True if is_deepspeed_zero3_enabled() else False, + gen_kwargs = self._gen_kwargs.copy() + if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None: + gen_kwargs["max_length"] = self.model.config.max_length + gen_kwargs["num_beams"] = ( + gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams + ) + default_synced_gpus = True if is_deepspeed_zero3_enabled() else False + gen_kwargs["synced_gpus"] = ( + gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus + ) + + # DONE: add generation arguments + gen_kwargs.update({ "do_sample": self.args.do_sample, "temperature": self.args.temperature, "top_k": self.args.top_k, "top_p": self.args.top_p, "num_return_sequences": self.args.num_return_sequences - } + }) if "attention_mask" in inputs: gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) @@ -223,13 +331,17 @@ class ConvLabSeq2SeqTrainer(Seq2SeqTrainer): **gen_kwargs, ) # in case the batch is shorter than max length, the output should be padded - if generated_tokens.shape[-1] < gen_kwargs["max_length"]: + if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]: generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) + elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < ( + gen_kwargs["max_new_tokens"] + 1 + ): + generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) with torch.no_grad(): - with self.autocast_smart_context_manager(): - outputs = model(**inputs) if has_labels: + with self.compute_loss_context_manager(): + outputs = model(**inputs) if self.label_smoother is not None: loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() else: @@ -242,8 +354,12 @@ class ConvLabSeq2SeqTrainer(Seq2SeqTrainer): if has_labels: labels = inputs["labels"] - if labels.shape[-1] < gen_kwargs["max_length"]: + if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]: labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) + elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < ( + gen_kwargs["max_new_tokens"] + 1 + ): + labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1)) else: labels = None diff --git a/convlab/deploy/README.md b/convlab/deploy/README.md index c02f097cbbac8c6258bea3fa4c7dacc147d91e75..d445429b38bacd729ff99f373db188ba21b92079 100755 --- a/convlab/deploy/README.md +++ b/convlab/deploy/README.md @@ -59,22 +59,17 @@ ```json "nlu": { - "svm-cam": { - "class_path": "convlab.nlu.svm.camrest.nlu.SVMNLU", - "data_set": "camrest", - "ini_params": {"mode": "usr"}, - "model_name": "svm-cam", - "max_core": 1, - "preload": true, - "enable": true - }, - "svm-mul": { - "class_path": "convlab.nlu.svm.multiwoz.nlu.SVMNLU", + "t5nlu-mul": { + "class_path": "convlab.base_models.t5.nlu.T5NLU", "data_set": "multiwoz", - "ini_params": {"mode": "usr"}, - "model_name": "svm-mul", + "ini_params": { + "speaker": "user", + "context_window_size": 0, + "model_name_or_path": "ConvLab/t5-small-nlu-multiwoz21" + }, + "model_name": "t5nlu-mul", "max_core": 1, - "preload": false, + "preload": true, "enable": true } } diff --git a/convlab/deploy/dep_config.json b/convlab/deploy/dep_config.json index b9cb0b676c872da16606c551fd4696aa446693d3..1cd52d4b5f502e42819e141f6e7a2c14ca56bc41 100755 --- a/convlab/deploy/dep_config.json +++ b/convlab/deploy/dep_config.json @@ -5,188 +5,65 @@ "session_time_out": 300 }, "nlu": { - "svm-cam": { - "class_path": "convlab.nlu.svm.camrest.nlu.SVMNLU", - "data_set": "camrest", - "ini_params": { - "mode": "usr" - }, - "model_name": "svm-cam", - "max_core": 1, - "preload": true, - "enable": true - }, - "svm-mul": { - "class_path": "convlab.nlu.svm.multiwoz.nlu.SVMNLU", - "data_set": "multiwoz", - "ini_params": { - "mode": "usr" - }, - "model_name": "svm-mul", - "max_core": 1, - "preload": false, - "enable": true - }, - "bert-cro": { - "class_path": "convlab.nlu.jointBERT.crosswoz.nlu.BERTNLU", - "data_set": "crosswoz", - "ini_params": { - "mode": "all", - "config_file": "crosswoz_all.json", - "model_file": "https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_crosswoz_all.zip" - }, - "model_name": "bert-cro", - "max_core": 1, - "preload": false, - "enable": true - }, - "bert-mul": { - "class_path": "convlab.nlu.jointBERT.multiwoz.nlu.BERTNLU", + "t5nlu-mul": { + "class_path": "convlab.base_models.t5.nlu.T5NLU", "data_set": "multiwoz", "ini_params": { - "mode": "all", - "config_file": "multiwoz_all.json", - "model_file": "https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_all.zip" + "speaker": "user", + "context_window_size": 0, + "model_name_or_path": "ConvLab/t5-small-nlu-multiwoz21" }, - "model_name": "bert-mul", + "model_name": "t5nlu-mul", "max_core": 1, - "preload": false, + "preload": true, "enable": true } }, "dst": { - "rule-cam": { - "class_path": "convlab.dst.rule.camrest.dst.RuleDST", - "data_set": "camrest", - "ini_params": {}, - "model_name": "rule-cam", - "max_core": 1, - "preload": true, - "enable": true - }, - "rule-mul": { - "class_path": "convlab.dst.rule.multiwoz.dst.RuleDST", + "t5dst-mul": { + "class_path": "convlab.base_models.t5.dst.T5DST", "data_set": "multiwoz", - "ini_params": {}, - "model_name": "rule-mul", - "max_core": 1, - "preload": true, - "enable": true - }, - "rule-cro": { - "class_path": "convlab.dst.rule.crosswoz.dst.RuleDST", - "data_set": "crosswoz", - "ini_params": {}, - "model_name": "rule-cro", - "max_core": 1, - "preload": true, - "enable": true - }, - "trade-mul": { - "class_path": "convlab.dst.trade.multiwoz.trade.MultiWOZTRADE", - "data_set": "multiwoz", - "ini_params": {}, - "model_name": "trade-mul", + "ini_params": { + "dataset_name": "multiwoz21", + "speaker": "user", + "context_window_size": 100, + "model_name_or_path": "ConvLab/t5-small-dst-multiwoz21" + }, + "model_name": "t5dst-mul", "max_core": 1, "preload": true, "enable": true } }, "policy": { - "mle-cam": { - "class_path": "convlab.policy.mle.camrest.mle.MLE", - "data_set": "camrest", - "ini_params": {}, - "model_name": "mle-cam", - "max_core": 1, - "preload": false, - "enable": true - }, - "mle-mul": { - "class_path": "convlab.policy.mle.multiwoz.mle.MLE", - "data_set": "multiwoz", - "ini_params": {}, - "model_name": "mle-mul", - "max_core": 1, - "preload": false, - "enable": true - }, - "rule-cam": { - "class_path": "convlab.policy.rule.camrest.rule_based_camrest_bot.RuleBasedCamrestBot", - "data_set": "camrest", - "ini_params": {}, - "model_name": "rule-cam", - "max_core": 1, - "preload": true, - "enable": true - }, - "rule-mul": { - "class_path": "convlab.policy.rule.multiwoz.rule_based_multiwoz_bot.RuleBasedMultiwozBot", + "ddpt-mul": { + "class_path": "convlab.policy.vtrace_DPT.VTRACE", "data_set": "multiwoz", - "ini_params": {}, - "model_name": "rule-mul", + "ini_params": { + "is_train": false, + "seed": 0, + "load_path": "from_pretrained", + "dataset_name": "multiwoz21" + }, + "model_name": "ddpt-mul", "max_core": 1, "preload": true, "enable": true - }, - "mle-cro": { - "class_path": "convlab.policy.mle.crosswoz.mle.MLE", - "data_set": "crosswoz", - "ini_params": {}, - "model_name": "mle-cro", - "max_core": 1, - "preload": false, - "enable": true } }, "nlg": { - "tmp-manual-cam": { - "class_path": "convlab.nlg.template.camrest.nlg.TemplateNLG", - "data_set": "camrest", - "ini_params": { - "is_user": false - }, - "model_name": "tmp-manual-cam", - "max_core": 1, - "preload": true, - "enable": true - }, - "tmp-auto_manual-cam": { - "class_path": "convlab.nlg.template.camrest.nlg.TemplateNLG", - "data_set": "camrest", - "ini_params": { - "is_user": false, - "mode": "auto_manual" - }, - "model_name": "tmp-auto_manual-cam", - "max_core": 1, - "preload": true, - "enable": true - }, - "tmp-auto_manual-mul": { - "class_path": "convlab.nlg.template.multiwoz.nlg.TemplateNLG", + "t5nlg-mul": { + "class_path": "convlab.base_models.t5.nlg.T5NLG", "data_set": "multiwoz", "ini_params": { - "is_user": false, - "mode": "auto_manual" - }, - "model_name": "tmp-auto_manual-mul", - "max_core": 1, - "preload": true, - "enable": true - - }, - "tmp-auto_manual-cro": { - "class_path": "convlab.nlg.template.crosswoz.nlg.TemplateNLG", - "data_set": "crosswoz", - "ini_params": { - "is_user": false, - "mode": "auto_manual" + "speaker": "system", + "context_window_size": 0, + "model_name_or_path": "ConvLab/t5-small-nlg-multiwoz21" }, - "model_name": "tmp-auto_manual-cro", + "model_name": "t5nlg-mul", "max_core": 1, "preload": true, "enable": true } } -} \ No newline at end of file +} diff --git a/convlab/policy/pg/pg.py b/convlab/policy/pg/pg.py index 060be694d06748bae2e737ba81853b3b0de2c29d..2230ac8d51886ac3630a89f58bc493b896d99b4a 100755 --- a/convlab/policy/pg/pg.py +++ b/convlab/policy/pg/pg.py @@ -13,6 +13,7 @@ from convlab.policy.vector.vector_binary import VectorBinary from convlab.util.file_util import cached_path import zipfile import sys +import urllib.request root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) sys.path.append(root_dir) @@ -22,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") class PG(Policy): - def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None): + def __init__(self, is_train=False, seed=0, vectorizer=None, load_path="", **kwargs): with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f: cfg = json.load(f) self.cfg = cfg @@ -32,19 +33,31 @@ class PG(Policy): self.optim_batchsz = cfg['batchsz'] self.gamma = cfg['gamma'] self.is_train = is_train - self.vector = vectorizer self.info_dict = {} set_seed(seed) + self.vector = vectorizer + dir_name = os.path.dirname(os.path.abspath(__file__)) + if self.vector is None: logging.info("No vectorizer was set, using default..") - from convlab.policy.vector.vector_binary import VectorBinary - self.vector = VectorBinary() - - if dataset == 'Multiwoz': - self.vector = vectorizer - self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE) + self.vector = VectorBinary(dataset_name=kwargs['dataset_name'], + use_masking=kwargs.get('use_masking', True), + manually_add_entity_names=kwargs.get('manually_add_entity_names', True), + seed=seed) + + self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE) + + try: + if load_path == "from_pretrained": + urllib.request.urlretrieve( + f"https://huggingface.co/ConvLab/mle-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl", + f"{dir_name}/{self.vector.dataset_name}_mle.pol.mdl") + load_path = f"{dir_name}/{self.vector.dataset_name}_mle" + self.load_policy(load_path) + except Exception as e: + print(f"Could not load the policy, Exception: {e}") # self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE) if is_train: @@ -208,4 +221,18 @@ class PG(Policy): cfg = json.load(f) model = cls() model.load_from_pretrained(archive_file, model_file, cfg['load']) - return model \ No newline at end of file + return model + + def load_policy(self, filename=""): + policy_mdl_candidates = [ + filename + '.pol.mdl', + filename + '_ppo.pol.mdl', + os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.pol.mdl'), + os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_ppo.pol.mdl') + ] + for policy_mdl in policy_mdl_candidates: + if os.path.exists(policy_mdl): + print(f"Loaded policy checkpoint from file: {policy_mdl}") + self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE)) + logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl)) + break \ No newline at end of file diff --git a/convlab/policy/ppo/ppo.py b/convlab/policy/ppo/ppo.py index 28fee71c70c640319b29becc77bddfe8311f2767..ca70f2b6f81e716aee544c5d2404b296369964bf 100755 --- a/convlab/policy/ppo/ppo.py +++ b/convlab/policy/ppo/ppo.py @@ -6,10 +6,13 @@ import numpy as np import logging import os import json +from convlab.policy.vector.vector_binary import VectorBinary from convlab.policy.policy import Policy from convlab.policy.rlmodule import MultiDiscretePolicy, Value from convlab.util.custom_util import model_downloader, set_seed import sys +import urllib.request + root_dir = os.path.dirname(os.path.dirname( os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) @@ -20,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") class PPO(Policy): - def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None): + def __init__(self, is_train=False, seed=0, vectorizer=None, load_path="", **kwargs): with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs' ,'ppo_config.json'), 'r') as f: cfg = json.load(f) @@ -39,18 +42,29 @@ class PPO(Policy): logging.info('PPO seed ' + str(seed)) set_seed(seed) + dir_name = os.path.dirname(os.path.abspath(__file__)) if self.vector is None: logging.info("No vectorizer was set, using default..") - from convlab.policy.vector.vector_binary import VectorBinary - self.vector = VectorBinary() - - # construct policy and value network - if dataset == 'Multiwoz': - self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], - self.vector.da_dim, seed).to(device=DEVICE) - logging.info(f"ACTION DIM OF PPO: {self.vector.da_dim}") - logging.info(f"STATE DIM OF PPO: {self.vector.state_dim}") + self.vector = VectorBinary(dataset_name=kwargs['dataset_name'], + use_masking=kwargs.get('use_masking', True), + manually_add_entity_names=kwargs.get('manually_add_entity_names', True), + seed=seed) + + self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], + self.vector.da_dim, seed).to(device=DEVICE) + logging.info(f"ACTION DIM OF PPO: {self.vector.da_dim}") + logging.info(f"STATE DIM OF PPO: {self.vector.state_dim}") + + try: + if load_path == "from_pretrained": + urllib.request.urlretrieve( + f"https://huggingface.co/ConvLab/mle-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl", + f"{dir_name}/{self.vector.dataset_name}_mle.pol.mdl") + load_path = f"{dir_name}/{self.vector.dataset_name}_mle" + self.load_policy(load_path) + except Exception as e: + print(f"Could not load the policy, Exception: {e}") self.value = Value(self.vector.state_dim, cfg['hv_dim']).to(device=DEVICE) @@ -263,6 +277,20 @@ class PPO(Policy): logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl)) break + def load_policy(self, filename=""): + policy_mdl_candidates = [ + filename + '.pol.mdl', + filename + '_ppo.pol.mdl', + os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.pol.mdl'), + os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_ppo.pol.mdl') + ] + for policy_mdl in policy_mdl_candidates: + if os.path.exists(policy_mdl): + print(f"Loaded policy checkpoint from file: {policy_mdl}") + self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE)) + logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl)) + break + # Load model from model_path(URL) def load_from_pretrained(self, model_path=""): diff --git a/convlab/policy/vtrace_DPT/vtrace.py b/convlab/policy/vtrace_DPT/vtrace.py index 2918f4dfb019cef3cfe9e0d3981b70c9502700c8..85d239e3299c73999042b9b58b34dc039fe1057f 100644 --- a/convlab/policy/vtrace_DPT/vtrace.py +++ b/convlab/policy/vtrace_DPT/vtrace.py @@ -8,6 +8,7 @@ import torch.nn as nn import urllib.request from torch import optim +from convlab.policy.vector.vector_nodes import VectorNodes from convlab.policy.vtrace_DPT.transformer_model.EncoderDecoder import EncoderDecoder from convlab.policy.vtrace_DPT.transformer_model.EncoderCritic import EncoderCritic from ... import Policy @@ -21,7 +22,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") class VTRACE(nn.Module, Policy): - def __init__(self, is_train=True, seed=0, vectorizer=None, load_path=""): + def __init__(self, is_train=True, seed=0, vectorizer=None, load_path="", **kwargs): super(VTRACE, self).__init__() @@ -59,6 +60,13 @@ class VTRACE(nn.Module, Policy): self.last_action = None + if vectorizer is None: + vectorizer = VectorNodes(dataset_name=kwargs['dataset_name'], + use_masking=kwargs.get('use_masking', True), + manually_add_entity_names=kwargs.get('manually_add_entity_names', True), + seed=seed, + filter_state=kwargs.get('filter_state', True)) + self.vector = vectorizer self.cfg['dataset_name'] = self.vector.dataset_name self.policy = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE) @@ -67,9 +75,9 @@ class VTRACE(nn.Module, Policy): try: if load_path == "from_pretrained": urllib.request.urlretrieve( - "https://huggingface.co/ConvLab/ddpt-policy-multiwoz21/resolve/main/supervised.pol.mdl", - f"{dir_name}/ddpt.pol.mdl") - load_path = f"{dir_name}/ddpt" + f"https://huggingface.co/ConvLab/ddpt-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl", + f"{dir_name}/{self.vector.dataset_name}_ddpt.pol.mdl") + load_path = f"{dir_name}/{self.vector.dataset_name}_ddpt" self.load_policy(load_path) except Exception as e: print(f"Could not load the policy, Exception: {e}") @@ -104,7 +112,6 @@ class VTRACE(nn.Module, Policy): Returns: action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...}) """ - if not self.is_train: for param in self.policy.parameters(): param.requires_grad = False