From d18cf12198f572697f874b2d2e498d9713b4425c Mon Sep 17 00:00:00 2001 From: Christian <christian.geishauser@hhu.de> Date: Mon, 23 Jan 2023 12:55:31 +0100 Subject: [PATCH] using chinese BERT if crosswoz is used --- convlab/policy/vector/vector_base.py | 2 +- .../vtrace_DPT/transformer_model/node_embedder.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index 566fd718..39d378d3 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -28,7 +28,7 @@ class VectorBase(Vector): self.ontology = load_ontology(dataset_name) try: # execute to make sure that the database exists or is downloaded otherwise - if dataset_name == "multiwoz21": + if dataset_name == "multiwoz21" or dataset_name == "crosswoz": load_database(dataset_name) # the following two lines are needed for pickling correctly during multi-processing exec(f'from data.unified_datasets.{dataset_name}.database import Database') diff --git a/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py b/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py index ab57df12..fdbf3b50 100644 --- a/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py +++ b/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py @@ -2,7 +2,7 @@ import os, json, logging import torch import torch.nn as nn -from transformers import RobertaTokenizer, RobertaModel +from transformers import RobertaTokenizer, RobertaModel, BertTokenizer, BertModel from convlab.policy.vtrace_DPT.transformer_model.noisy_linear import NoisyLinear from convlab.policy.vtrace_DPT.create_descriptions import create_description_dicts @@ -52,8 +52,13 @@ class NodeEmbedderRoberta(nn.Module): if os.path.exists(embedded_descriptions_path): self.embedded_descriptions = torch.load(embedded_descriptions_path).to(DEVICE) else: - self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base") - self.roberta_model = RobertaModel.from_pretrained("roberta-base").to(DEVICE) + if dataset_name == "crosswoz": + self.max_length = 40 + self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext") + self.roberta_model = BertModel.from_pretrained("hfl/chinese-roberta-wwm-ext").to(DEVICE) + else: + self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + self.roberta_model = RobertaModel.from_pretrained("roberta-base").to(DEVICE) if self.embedded_descriptions is None: if freeze_roberta: -- GitLab