Skip to content
Snippets Groups Projects
Commit d18cf121 authored by Christian's avatar Christian
Browse files

using chinese BERT if crosswoz is used

parent d4ffd925
No related branches found
No related tags found
No related merge requests found
...@@ -28,7 +28,7 @@ class VectorBase(Vector): ...@@ -28,7 +28,7 @@ class VectorBase(Vector):
self.ontology = load_ontology(dataset_name) self.ontology = load_ontology(dataset_name)
try: try:
# execute to make sure that the database exists or is downloaded otherwise # execute to make sure that the database exists or is downloaded otherwise
if dataset_name == "multiwoz21": if dataset_name == "multiwoz21" or dataset_name == "crosswoz":
load_database(dataset_name) load_database(dataset_name)
# the following two lines are needed for pickling correctly during multi-processing # the following two lines are needed for pickling correctly during multi-processing
exec(f'from data.unified_datasets.{dataset_name}.database import Database') exec(f'from data.unified_datasets.{dataset_name}.database import Database')
......
...@@ -2,7 +2,7 @@ import os, json, logging ...@@ -2,7 +2,7 @@ import os, json, logging
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel from transformers import RobertaTokenizer, RobertaModel, BertTokenizer, BertModel
from convlab.policy.vtrace_DPT.transformer_model.noisy_linear import NoisyLinear from convlab.policy.vtrace_DPT.transformer_model.noisy_linear import NoisyLinear
from convlab.policy.vtrace_DPT.create_descriptions import create_description_dicts from convlab.policy.vtrace_DPT.create_descriptions import create_description_dicts
...@@ -51,6 +51,11 @@ class NodeEmbedderRoberta(nn.Module): ...@@ -51,6 +51,11 @@ class NodeEmbedderRoberta(nn.Module):
f'embedded_descriptions_base_{self.dataset_name}.pt') f'embedded_descriptions_base_{self.dataset_name}.pt')
if os.path.exists(embedded_descriptions_path): if os.path.exists(embedded_descriptions_path):
self.embedded_descriptions = torch.load(embedded_descriptions_path).to(DEVICE) self.embedded_descriptions = torch.load(embedded_descriptions_path).to(DEVICE)
else:
if dataset_name == "crosswoz":
self.max_length = 40
self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
self.roberta_model = BertModel.from_pretrained("hfl/chinese-roberta-wwm-ext").to(DEVICE)
else: else:
self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base") self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
self.roberta_model = RobertaModel.from_pretrained("roberta-base").to(DEVICE) self.roberta_model = RobertaModel.from_pretrained("roberta-base").to(DEVICE)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment