Skip to content
Snippets Groups Projects
Commit f0004533 authored by Christian's avatar Christian
Browse files

Merge branch 'master' of https://github.com/ConvLab/ConvLab-3 into github_master

parents 26a6256d 0bf91d22
Branches
No related tags found
No related merge requests found
...@@ -16,9 +16,11 @@ ...@@ -16,9 +16,11 @@
# from dataclasses import dataclass, field # from dataclasses import dataclass, field
# import torch # import torch
# from torch import nn # from torch import nn
from torch.utils.data import Dataset
# from transformers.deepspeed import is_deepspeed_zero3_enabled # from transformers.deepspeed import is_deepspeed_zero3_enabled
# from transformers.utils import logging, cached_property, torch_required # from transformers.utils import logging, cached_property, torch_required
from transformers.trainer_utils import PredictionOutput
from transformers.training_args import ( from transformers.training_args import (
os, os,
torch, torch,
...@@ -161,6 +163,103 @@ class ConvLabSeq2SeqTrainingArguments(Seq2SeqTrainingArguments): ...@@ -161,6 +163,103 @@ class ConvLabSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
class ConvLabSeq2SeqTrainer(Seq2SeqTrainer): class ConvLabSeq2SeqTrainer(Seq2SeqTrainer):
# modifed from Seq2SeqTrainer of 4.26.1: https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/trainer_seq2seq.py
# add generation args in `prediction_step`
def evaluate(
self,
eval_dataset: Optional[Dataset] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
**gen_kwargs
) -> Dict[str, float]:
"""
Run evaluation and returns metrics.
The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
(pass it to the init `compute_metrics` argument).
You can also subclass and override this method to inject custom behavior.
Args:
eval_dataset (`Dataset`, *optional*):
Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
method.
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is `"eval"` (default)
max_length (`int`, *optional*):
The maximum target length to use when predicting with the generate method.
num_beams (`int`, *optional*):
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
beam search.
gen_kwargs:
Additional `generate` specific kwargs.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
dictionary also contains the epoch number which comes from the training state.
"""
gen_kwargs = gen_kwargs.copy()
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
gen_kwargs["max_length"] = self.args.generation_max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
)
self._gen_kwargs = gen_kwargs
return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def predict(
self,
test_dataset: Dataset,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "test",
**gen_kwargs
) -> PredictionOutput:
"""
Run prediction and returns predictions and potential metrics.
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
will also return metrics, like in `evaluate()`.
Args:
test_dataset (`Dataset`):
Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. Has to implement the method `__len__`
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
"eval_bleu" if the prefix is `"eval"` (default)
max_length (`int`, *optional*):
The maximum target length to use when predicting with the generate method.
num_beams (`int`, *optional*):
Number of beams for beam search that will be used when predicting with the generate method. 1 means no
beam search.
gen_kwargs:
Additional `generate` specific kwargs.
<Tip>
If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
padding in a token classification task) the predictions will be padded (on the right) to allow for
concatenation into one array. The padding index is -100.
</Tip>
Returns: *NamedTuple* A namedtuple with the following keys:
- predictions (`np.ndarray`): The predictions on `test_dataset`.
- label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
- metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
labels).
"""
gen_kwargs = gen_kwargs.copy()
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
gen_kwargs["max_length"] = self.args.generation_max_length
gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
)
self._gen_kwargs = gen_kwargs
return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
def prediction_step( def prediction_step(
self, self,
model: nn.Module, model: nn.Module,
...@@ -194,16 +293,25 @@ class ConvLabSeq2SeqTrainer(Seq2SeqTrainer): ...@@ -194,16 +293,25 @@ class ConvLabSeq2SeqTrainer(Seq2SeqTrainer):
inputs = self._prepare_inputs(inputs) inputs = self._prepare_inputs(inputs)
# XXX: adapt synced_gpus for fairscale as well # XXX: adapt synced_gpus for fairscale as well
gen_kwargs = { gen_kwargs = self._gen_kwargs.copy()
"max_length": self._max_length if self._max_length is not None else self.model.config.max_length, if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
"num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams, gen_kwargs["max_length"] = self.model.config.max_length
"synced_gpus": True if is_deepspeed_zero3_enabled() else False, gen_kwargs["num_beams"] = (
gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
)
default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
gen_kwargs["synced_gpus"] = (
gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
)
# DONE: add generation arguments
gen_kwargs.update({
"do_sample": self.args.do_sample, "do_sample": self.args.do_sample,
"temperature": self.args.temperature, "temperature": self.args.temperature,
"top_k": self.args.top_k, "top_k": self.args.top_k,
"top_p": self.args.top_p, "top_p": self.args.top_p,
"num_return_sequences": self.args.num_return_sequences "num_return_sequences": self.args.num_return_sequences
} })
if "attention_mask" in inputs: if "attention_mask" in inputs:
gen_kwargs["attention_mask"] = inputs.get("attention_mask", None) gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
...@@ -223,13 +331,17 @@ class ConvLabSeq2SeqTrainer(Seq2SeqTrainer): ...@@ -223,13 +331,17 @@ class ConvLabSeq2SeqTrainer(Seq2SeqTrainer):
**gen_kwargs, **gen_kwargs,
) )
# in case the batch is shorter than max length, the output should be padded # in case the batch is shorter than max length, the output should be padded
if generated_tokens.shape[-1] < gen_kwargs["max_length"]: if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"]) generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
gen_kwargs["max_new_tokens"] + 1
):
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
with torch.no_grad(): with torch.no_grad():
with self.autocast_smart_context_manager():
outputs = model(**inputs)
if has_labels: if has_labels:
with self.compute_loss_context_manager():
outputs = model(**inputs)
if self.label_smoother is not None: if self.label_smoother is not None:
loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
else: else:
...@@ -242,8 +354,12 @@ class ConvLabSeq2SeqTrainer(Seq2SeqTrainer): ...@@ -242,8 +354,12 @@ class ConvLabSeq2SeqTrainer(Seq2SeqTrainer):
if has_labels: if has_labels:
labels = inputs["labels"] labels = inputs["labels"]
if labels.shape[-1] < gen_kwargs["max_length"]: if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"]) labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
gen_kwargs["max_new_tokens"] + 1
):
labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
else: else:
labels = None labels = None
......
...@@ -59,22 +59,17 @@ ...@@ -59,22 +59,17 @@
```json ```json
"nlu": "nlu":
{ {
"svm-cam": { "t5nlu-mul": {
"class_path": "convlab.nlu.svm.camrest.nlu.SVMNLU", "class_path": "convlab.base_models.t5.nlu.T5NLU",
"data_set": "camrest",
"ini_params": {"mode": "usr"},
"model_name": "svm-cam",
"max_core": 1,
"preload": true,
"enable": true
},
"svm-mul": {
"class_path": "convlab.nlu.svm.multiwoz.nlu.SVMNLU",
"data_set": "multiwoz", "data_set": "multiwoz",
"ini_params": {"mode": "usr"}, "ini_params": {
"model_name": "svm-mul", "speaker": "user",
"context_window_size": 0,
"model_name_or_path": "ConvLab/t5-small-nlu-multiwoz21"
},
"model_name": "t5nlu-mul",
"max_core": 1, "max_core": 1,
"preload": false, "preload": true,
"enable": true "enable": true
} }
} }
......
...@@ -5,185 +5,62 @@ ...@@ -5,185 +5,62 @@
"session_time_out": 300 "session_time_out": 300
}, },
"nlu": { "nlu": {
"svm-cam": { "t5nlu-mul": {
"class_path": "convlab.nlu.svm.camrest.nlu.SVMNLU", "class_path": "convlab.base_models.t5.nlu.T5NLU",
"data_set": "camrest",
"ini_params": {
"mode": "usr"
},
"model_name": "svm-cam",
"max_core": 1,
"preload": true,
"enable": true
},
"svm-mul": {
"class_path": "convlab.nlu.svm.multiwoz.nlu.SVMNLU",
"data_set": "multiwoz",
"ini_params": {
"mode": "usr"
},
"model_name": "svm-mul",
"max_core": 1,
"preload": false,
"enable": true
},
"bert-cro": {
"class_path": "convlab.nlu.jointBERT.crosswoz.nlu.BERTNLU",
"data_set": "crosswoz",
"ini_params": {
"mode": "all",
"config_file": "crosswoz_all.json",
"model_file": "https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_crosswoz_all.zip"
},
"model_name": "bert-cro",
"max_core": 1,
"preload": false,
"enable": true
},
"bert-mul": {
"class_path": "convlab.nlu.jointBERT.multiwoz.nlu.BERTNLU",
"data_set": "multiwoz", "data_set": "multiwoz",
"ini_params": { "ini_params": {
"mode": "all", "speaker": "user",
"config_file": "multiwoz_all.json", "context_window_size": 0,
"model_file": "https://huggingface.co/ConvLab/ConvLab-2_models/resolve/main/bert_multiwoz_all.zip" "model_name_or_path": "ConvLab/t5-small-nlu-multiwoz21"
}, },
"model_name": "bert-mul", "model_name": "t5nlu-mul",
"max_core": 1, "max_core": 1,
"preload": false, "preload": true,
"enable": true "enable": true
} }
}, },
"dst": { "dst": {
"rule-cam": { "t5dst-mul": {
"class_path": "convlab.dst.rule.camrest.dst.RuleDST", "class_path": "convlab.base_models.t5.dst.T5DST",
"data_set": "camrest",
"ini_params": {},
"model_name": "rule-cam",
"max_core": 1,
"preload": true,
"enable": true
},
"rule-mul": {
"class_path": "convlab.dst.rule.multiwoz.dst.RuleDST",
"data_set": "multiwoz", "data_set": "multiwoz",
"ini_params": {}, "ini_params": {
"model_name": "rule-mul", "dataset_name": "multiwoz21",
"max_core": 1, "speaker": "user",
"preload": true, "context_window_size": 100,
"enable": true "model_name_or_path": "ConvLab/t5-small-dst-multiwoz21"
},
"rule-cro": {
"class_path": "convlab.dst.rule.crosswoz.dst.RuleDST",
"data_set": "crosswoz",
"ini_params": {},
"model_name": "rule-cro",
"max_core": 1,
"preload": true,
"enable": true
}, },
"trade-mul": { "model_name": "t5dst-mul",
"class_path": "convlab.dst.trade.multiwoz.trade.MultiWOZTRADE",
"data_set": "multiwoz",
"ini_params": {},
"model_name": "trade-mul",
"max_core": 1, "max_core": 1,
"preload": true, "preload": true,
"enable": true "enable": true
} }
}, },
"policy": { "policy": {
"mle-cam": { "ddpt-mul": {
"class_path": "convlab.policy.mle.camrest.mle.MLE", "class_path": "convlab.policy.vtrace_DPT.VTRACE",
"data_set": "camrest",
"ini_params": {},
"model_name": "mle-cam",
"max_core": 1,
"preload": false,
"enable": true
},
"mle-mul": {
"class_path": "convlab.policy.mle.multiwoz.mle.MLE",
"data_set": "multiwoz", "data_set": "multiwoz",
"ini_params": {}, "ini_params": {
"model_name": "mle-mul", "is_train": false,
"max_core": 1, "seed": 0,
"preload": false, "load_path": "from_pretrained",
"enable": true "dataset_name": "multiwoz21"
},
"rule-cam": {
"class_path": "convlab.policy.rule.camrest.rule_based_camrest_bot.RuleBasedCamrestBot",
"data_set": "camrest",
"ini_params": {},
"model_name": "rule-cam",
"max_core": 1,
"preload": true,
"enable": true
}, },
"rule-mul": { "model_name": "ddpt-mul",
"class_path": "convlab.policy.rule.multiwoz.rule_based_multiwoz_bot.RuleBasedMultiwozBot",
"data_set": "multiwoz",
"ini_params": {},
"model_name": "rule-mul",
"max_core": 1, "max_core": 1,
"preload": true, "preload": true,
"enable": true "enable": true
},
"mle-cro": {
"class_path": "convlab.policy.mle.crosswoz.mle.MLE",
"data_set": "crosswoz",
"ini_params": {},
"model_name": "mle-cro",
"max_core": 1,
"preload": false,
"enable": true
} }
}, },
"nlg": { "nlg": {
"tmp-manual-cam": { "t5nlg-mul": {
"class_path": "convlab.nlg.template.camrest.nlg.TemplateNLG", "class_path": "convlab.base_models.t5.nlg.T5NLG",
"data_set": "camrest",
"ini_params": {
"is_user": false
},
"model_name": "tmp-manual-cam",
"max_core": 1,
"preload": true,
"enable": true
},
"tmp-auto_manual-cam": {
"class_path": "convlab.nlg.template.camrest.nlg.TemplateNLG",
"data_set": "camrest",
"ini_params": {
"is_user": false,
"mode": "auto_manual"
},
"model_name": "tmp-auto_manual-cam",
"max_core": 1,
"preload": true,
"enable": true
},
"tmp-auto_manual-mul": {
"class_path": "convlab.nlg.template.multiwoz.nlg.TemplateNLG",
"data_set": "multiwoz", "data_set": "multiwoz",
"ini_params": { "ini_params": {
"is_user": false, "speaker": "system",
"mode": "auto_manual" "context_window_size": 0,
}, "model_name_or_path": "ConvLab/t5-small-nlg-multiwoz21"
"model_name": "tmp-auto_manual-mul",
"max_core": 1,
"preload": true,
"enable": true
},
"tmp-auto_manual-cro": {
"class_path": "convlab.nlg.template.crosswoz.nlg.TemplateNLG",
"data_set": "crosswoz",
"ini_params": {
"is_user": false,
"mode": "auto_manual"
}, },
"model_name": "tmp-auto_manual-cro", "model_name": "t5nlg-mul",
"max_core": 1, "max_core": 1,
"preload": true, "preload": true,
"enable": true "enable": true
......
...@@ -13,6 +13,7 @@ from convlab.policy.vector.vector_binary import VectorBinary ...@@ -13,6 +13,7 @@ from convlab.policy.vector.vector_binary import VectorBinary
from convlab.util.file_util import cached_path from convlab.util.file_util import cached_path
import zipfile import zipfile
import sys import sys
import urllib.request
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(root_dir) sys.path.append(root_dir)
...@@ -22,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") ...@@ -22,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class PG(Policy): class PG(Policy):
def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None): def __init__(self, is_train=False, seed=0, vectorizer=None, load_path="", **kwargs):
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f: with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.json'), 'r') as f:
cfg = json.load(f) cfg = json.load(f)
self.cfg = cfg self.cfg = cfg
...@@ -32,20 +33,32 @@ class PG(Policy): ...@@ -32,20 +33,32 @@ class PG(Policy):
self.optim_batchsz = cfg['batchsz'] self.optim_batchsz = cfg['batchsz']
self.gamma = cfg['gamma'] self.gamma = cfg['gamma']
self.is_train = is_train self.is_train = is_train
self.vector = vectorizer
self.info_dict = {} self.info_dict = {}
set_seed(seed) set_seed(seed)
self.vector = vectorizer
dir_name = os.path.dirname(os.path.abspath(__file__))
if self.vector is None: if self.vector is None:
logging.info("No vectorizer was set, using default..") logging.info("No vectorizer was set, using default..")
from convlab.policy.vector.vector_binary import VectorBinary self.vector = VectorBinary(dataset_name=kwargs['dataset_name'],
self.vector = VectorBinary() use_masking=kwargs.get('use_masking', True),
manually_add_entity_names=kwargs.get('manually_add_entity_names', True),
seed=seed)
if dataset == 'Multiwoz':
self.vector = vectorizer
self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE) self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
try:
if load_path == "from_pretrained":
urllib.request.urlretrieve(
f"https://huggingface.co/ConvLab/mle-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl",
f"{dir_name}/{self.vector.dataset_name}_mle.pol.mdl")
load_path = f"{dir_name}/{self.vector.dataset_name}_mle"
self.load_policy(load_path)
except Exception as e:
print(f"Could not load the policy, Exception: {e}")
# self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE) # self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.vector.da_dim).to(device=DEVICE)
if is_train: if is_train:
self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=cfg['lr']) self.policy_optim = optim.RMSprop(self.policy.parameters(), lr=cfg['lr'])
...@@ -209,3 +222,17 @@ class PG(Policy): ...@@ -209,3 +222,17 @@ class PG(Policy):
model = cls() model = cls()
model.load_from_pretrained(archive_file, model_file, cfg['load']) model.load_from_pretrained(archive_file, model_file, cfg['load'])
return model return model
def load_policy(self, filename=""):
policy_mdl_candidates = [
filename + '.pol.mdl',
filename + '_ppo.pol.mdl',
os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.pol.mdl'),
os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_ppo.pol.mdl')
]
for policy_mdl in policy_mdl_candidates:
if os.path.exists(policy_mdl):
print(f"Loaded policy checkpoint from file: {policy_mdl}")
self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
break
\ No newline at end of file
...@@ -6,10 +6,13 @@ import numpy as np ...@@ -6,10 +6,13 @@ import numpy as np
import logging import logging
import os import os
import json import json
from convlab.policy.vector.vector_binary import VectorBinary
from convlab.policy.policy import Policy from convlab.policy.policy import Policy
from convlab.policy.rlmodule import MultiDiscretePolicy, Value from convlab.policy.rlmodule import MultiDiscretePolicy, Value
from convlab.util.custom_util import model_downloader, set_seed from convlab.util.custom_util import model_downloader, set_seed
import sys import sys
import urllib.request
root_dir = os.path.dirname(os.path.dirname( root_dir = os.path.dirname(os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
...@@ -20,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") ...@@ -20,7 +23,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class PPO(Policy): class PPO(Policy):
def __init__(self, is_train=False, dataset='Multiwoz', seed=0, vectorizer=None): def __init__(self, is_train=False, seed=0, vectorizer=None, load_path="", **kwargs):
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs' ,'ppo_config.json'), 'r') as f: with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs' ,'ppo_config.json'), 'r') as f:
cfg = json.load(f) cfg = json.load(f)
...@@ -39,19 +42,30 @@ class PPO(Policy): ...@@ -39,19 +42,30 @@ class PPO(Policy):
logging.info('PPO seed ' + str(seed)) logging.info('PPO seed ' + str(seed))
set_seed(seed) set_seed(seed)
dir_name = os.path.dirname(os.path.abspath(__file__))
if self.vector is None: if self.vector is None:
logging.info("No vectorizer was set, using default..") logging.info("No vectorizer was set, using default..")
from convlab.policy.vector.vector_binary import VectorBinary self.vector = VectorBinary(dataset_name=kwargs['dataset_name'],
self.vector = VectorBinary() use_masking=kwargs.get('use_masking', True),
manually_add_entity_names=kwargs.get('manually_add_entity_names', True),
seed=seed)
# construct policy and value network
if dataset == 'Multiwoz':
self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'], self.policy = MultiDiscretePolicy(self.vector.state_dim, cfg['h_dim'],
self.vector.da_dim, seed).to(device=DEVICE) self.vector.da_dim, seed).to(device=DEVICE)
logging.info(f"ACTION DIM OF PPO: {self.vector.da_dim}") logging.info(f"ACTION DIM OF PPO: {self.vector.da_dim}")
logging.info(f"STATE DIM OF PPO: {self.vector.state_dim}") logging.info(f"STATE DIM OF PPO: {self.vector.state_dim}")
try:
if load_path == "from_pretrained":
urllib.request.urlretrieve(
f"https://huggingface.co/ConvLab/mle-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl",
f"{dir_name}/{self.vector.dataset_name}_mle.pol.mdl")
load_path = f"{dir_name}/{self.vector.dataset_name}_mle"
self.load_policy(load_path)
except Exception as e:
print(f"Could not load the policy, Exception: {e}")
self.value = Value(self.vector.state_dim, self.value = Value(self.vector.state_dim,
cfg['hv_dim']).to(device=DEVICE) cfg['hv_dim']).to(device=DEVICE)
if is_train: if is_train:
...@@ -263,6 +277,20 @@ class PPO(Policy): ...@@ -263,6 +277,20 @@ class PPO(Policy):
logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl)) logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
break break
def load_policy(self, filename=""):
policy_mdl_candidates = [
filename + '.pol.mdl',
filename + '_ppo.pol.mdl',
os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '.pol.mdl'),
os.path.join(os.path.dirname(os.path.abspath(__file__)), filename + '_ppo.pol.mdl')
]
for policy_mdl in policy_mdl_candidates:
if os.path.exists(policy_mdl):
print(f"Loaded policy checkpoint from file: {policy_mdl}")
self.policy.load_state_dict(torch.load(policy_mdl, map_location=DEVICE))
logging.info('<<dialog policy>> loaded checkpoint from file: {}'.format(policy_mdl))
break
# Load model from model_path(URL) # Load model from model_path(URL)
def load_from_pretrained(self, model_path=""): def load_from_pretrained(self, model_path=""):
......
...@@ -8,6 +8,7 @@ import torch.nn as nn ...@@ -8,6 +8,7 @@ import torch.nn as nn
import urllib.request import urllib.request
from torch import optim from torch import optim
from convlab.policy.vector.vector_nodes import VectorNodes
from convlab.policy.vtrace_DPT.transformer_model.EncoderDecoder import EncoderDecoder from convlab.policy.vtrace_DPT.transformer_model.EncoderDecoder import EncoderDecoder
from convlab.policy.vtrace_DPT.transformer_model.EncoderCritic import EncoderCritic from convlab.policy.vtrace_DPT.transformer_model.EncoderCritic import EncoderCritic
from ... import Policy from ... import Policy
...@@ -21,7 +22,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") ...@@ -21,7 +22,7 @@ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class VTRACE(nn.Module, Policy): class VTRACE(nn.Module, Policy):
def __init__(self, is_train=True, seed=0, vectorizer=None, load_path=""): def __init__(self, is_train=True, seed=0, vectorizer=None, load_path="", **kwargs):
super(VTRACE, self).__init__() super(VTRACE, self).__init__()
...@@ -59,6 +60,13 @@ class VTRACE(nn.Module, Policy): ...@@ -59,6 +60,13 @@ class VTRACE(nn.Module, Policy):
self.last_action = None self.last_action = None
if vectorizer is None:
vectorizer = VectorNodes(dataset_name=kwargs['dataset_name'],
use_masking=kwargs.get('use_masking', True),
manually_add_entity_names=kwargs.get('manually_add_entity_names', True),
seed=seed,
filter_state=kwargs.get('filter_state', True))
self.vector = vectorizer self.vector = vectorizer
self.cfg['dataset_name'] = self.vector.dataset_name self.cfg['dataset_name'] = self.vector.dataset_name
self.policy = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE) self.policy = EncoderDecoder(**self.cfg, action_dict=self.vector.act2vec).to(device=DEVICE)
...@@ -67,9 +75,9 @@ class VTRACE(nn.Module, Policy): ...@@ -67,9 +75,9 @@ class VTRACE(nn.Module, Policy):
try: try:
if load_path == "from_pretrained": if load_path == "from_pretrained":
urllib.request.urlretrieve( urllib.request.urlretrieve(
"https://huggingface.co/ConvLab/ddpt-policy-multiwoz21/resolve/main/supervised.pol.mdl", f"https://huggingface.co/ConvLab/ddpt-policy-{self.vector.dataset_name}/resolve/main/supervised.pol.mdl",
f"{dir_name}/ddpt.pol.mdl") f"{dir_name}/{self.vector.dataset_name}_ddpt.pol.mdl")
load_path = f"{dir_name}/ddpt" load_path = f"{dir_name}/{self.vector.dataset_name}_ddpt"
self.load_policy(load_path) self.load_policy(load_path)
except Exception as e: except Exception as e:
print(f"Could not load the policy, Exception: {e}") print(f"Could not load the policy, Exception: {e}")
...@@ -104,7 +112,6 @@ class VTRACE(nn.Module, Policy): ...@@ -104,7 +112,6 @@ class VTRACE(nn.Module, Policy):
Returns: Returns:
action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...}) action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...})
""" """
if not self.is_train: if not self.is_train:
for param in self.policy.parameters(): for param in self.policy.parameters():
param.requires_grad = False param.requires_grad = False
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment