From d18cf12198f572697f874b2d2e498d9713b4425c Mon Sep 17 00:00:00 2001
From: Christian <christian.geishauser@hhu.de>
Date: Mon, 23 Jan 2023 12:55:31 +0100
Subject: [PATCH] Use Chinese BERT when the CrossWOZ dataset is used

---
 convlab/policy/vector/vector_base.py                  |  2 +-
 .../vtrace_DPT/transformer_model/node_embedder.py     | 11 ++++++++---
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py
index 566fd718..39d378d3 100644
--- a/convlab/policy/vector/vector_base.py
+++ b/convlab/policy/vector/vector_base.py
@@ -28,7 +28,7 @@ class VectorBase(Vector):
         self.ontology = load_ontology(dataset_name)
         try:
             # execute to make sure that the database exists or is downloaded otherwise
-            if dataset_name == "multiwoz21":
+            if dataset_name == "multiwoz21" or dataset_name == "crosswoz":
                 load_database(dataset_name)
             # the following two lines are needed for pickling correctly during multi-processing
             exec(f'from data.unified_datasets.{dataset_name}.database import Database')
diff --git a/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py b/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py
index ab57df12..fdbf3b50 100644
--- a/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py
+++ b/convlab/policy/vtrace_DPT/transformer_model/node_embedder.py
@@ -2,7 +2,7 @@ import os, json, logging
 import torch
 import torch.nn as nn
 
-from transformers import RobertaTokenizer, RobertaModel
+from transformers import RobertaTokenizer, RobertaModel, BertTokenizer, BertModel
 from convlab.policy.vtrace_DPT.transformer_model.noisy_linear import NoisyLinear
 from convlab.policy.vtrace_DPT.create_descriptions import create_description_dicts
 
@@ -52,8 +52,13 @@ class NodeEmbedderRoberta(nn.Module):
             if os.path.exists(embedded_descriptions_path):
                 self.embedded_descriptions = torch.load(embedded_descriptions_path).to(DEVICE)
             else:
-                self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
-                self.roberta_model = RobertaModel.from_pretrained("roberta-base").to(DEVICE)
+                if dataset_name == "crosswoz":
+                    self.max_length = 40
+                    self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
+                    self.roberta_model = BertModel.from_pretrained("hfl/chinese-roberta-wwm-ext").to(DEVICE)
+                else:
+                    self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+                    self.roberta_model = RobertaModel.from_pretrained("roberta-base").to(DEVICE)
 
         if self.embedded_descriptions is None:
             if freeze_roberta:
-- 
GitLab