diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index 62245a32c638cfe600957f908a826a0c10955162..8b7d8ff0ddafed41efc91b249003ae55c525bc93 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -3,7 +3,6 @@ import os import sys import numpy as np -from data.unified_datasets.multiwoz21.database import Database from copy import deepcopy from convlab.policy.vec import Vector from convlab.util.custom_util import flatten_acts @@ -26,8 +25,11 @@ class VectorBase(Vector): self.set_seed(seed) self.ontology = load_ontology(dataset_name) try: - #self.db = load_database(dataset_name) - self.db = Database() + # execute to make sure that the database exists or is downloaded otherwise + load_database(dataset_name) + # the following two lines are needed for pickling correctly during multi-processing + exec(f'from data.unified_datasets.{dataset_name}.database import Database') + self.db = eval('Database()') self.db_domains = self.db.domains except Exception as e: self.db = None