diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py
index 566fd718c9726065680fb28ef81443a0af10e7cf..39d378d3cd8d77d140cb8947d00c314b78f72b66 100644
--- a/convlab/policy/vector/vector_base.py
+++ b/convlab/policy/vector/vector_base.py
@@ -28,7 +28,7 @@ class VectorBase(Vector):
         self.ontology = load_ontology(dataset_name)
         try:
             # execute to make sure that the database exists or is downloaded otherwise
-            if dataset_name == "multiwoz21":
+            if dataset_name == "multiwoz21" or dataset_name == "crosswoz":
                 load_database(dataset_name)
             # the following two lines are needed for pickling correctly during multi-processing
             exec(f'from data.unified_datasets.{dataset_name}.database import Database')
diff --git a/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py b/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py
index ab57df122369e9457c78e04fe2e64eab389d5abd..fdbf3b50890ee5231d0b42de5bd4f0c528a49961 100644
--- a/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py
+++ b/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py
@@ -2,7 +2,7 @@ import os, json, logging
 
 import torch
 import torch.nn as nn
-from transformers import RobertaTokenizer, RobertaModel
+from transformers import RobertaTokenizer, RobertaModel, BertTokenizer, BertModel
 from convlab.policy.vtrace_DPT.transformer_model.noisy_linear import NoisyLinear
 from convlab.policy.vtrace_DPT.create_descriptions import create_description_dicts
 
@@ -52,8 +52,13 @@ class NodeEmbedderRoberta(nn.Module):
         if os.path.exists(embedded_descriptions_path):
             self.embedded_descriptions = torch.load(embedded_descriptions_path).to(DEVICE)
         else:
-            self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
-            self.roberta_model = RobertaModel.from_pretrained("roberta-base").to(DEVICE)
+            if dataset_name == "crosswoz":
+                self.max_length = 40
+                self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
+                self.roberta_model = BertModel.from_pretrained("hfl/chinese-roberta-wwm-ext").to(DEVICE)
+            else:
+                self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+                self.roberta_model = RobertaModel.from_pretrained("roberta-base").to(DEVICE)
 
         if self.embedded_descriptions is None:
             if freeze_roberta:
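
The second hunk swaps in `hfl/chinese-roberta-wwm-ext` for the Chinese CrossWOZ dataset; that checkpoint uses the BERT architecture and vocabulary, which is why it must be loaded through `BertTokenizer`/`BertModel` rather than the RoBERTa classes. Below is a minimal standalone sketch (not part of the patch) of how the selected tokenizer/model pair turns a node description into a fixed-size embedding. The model name and `max_length = 40` come from the diff; the example description and the `[CLS]`-position pooling are illustrative assumptions, not necessarily the exact pooling `NodeEmbedderRoberta` applies.

```python
import torch
from transformers import BertTokenizer, BertModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# chinese-roberta-wwm-ext is a BERT-architecture checkpoint, so it is
# loaded with the BERT classes, mirroring the crosswoz branch in the diff.
tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
model = BertModel.from_pretrained("hfl/chinese-roberta-wwm-ext").to(device)

description = "酒店-价格"  # hypothetical CrossWOZ domain-slot description
inputs = tokenizer(description, return_tensors="pt", padding="max_length",
                   truncation=True, max_length=40)  # max_length as in the diff
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model(**inputs)

# One vector per description, taken here from the [CLS] position (assumption).
embedding = outputs.last_hidden_state[:, 0, :]
print(embedding.shape)  # torch.Size([1, 768])
```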