From b980495717469a353c842217136eba29f4ecba53 Mon Sep 17 00:00:00 2001
From: Christian <christian.geishauser@hhu.de>
Date: Mon, 4 Jul 2022 09:53:49 +0200
Subject: [PATCH] bugfix, i need to load database manually in vectorbase,
 otherwise distributed training fails due to not pickle serializable. Also
 made environment such that it can expect a goal passed to it when resetting

---
 convlab/dialog_agent/env.py                            | 4 ++--
 convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py | 6 +++---
 convlab/policy/vector/vector_base.py                   | 6 +++---
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/convlab/dialog_agent/env.py b/convlab/dialog_agent/env.py
index bee3e943..c1f15dfa 100755
--- a/convlab/dialog_agent/env.py
+++ b/convlab/dialog_agent/env.py
@@ -19,8 +19,8 @@ class Environment():
         self.evaluator = evaluator
         self.use_semantic_acts = use_semantic_acts
 
-    def reset(self):
-        self.usr.init_session()
+    def reset(self, goal=None):
+        self.usr.init_session(goal=goal)
         self.sys_dst.init_session()
         if self.evaluator:
             self.evaluator.add_goal(self.usr.policy.get_goal())
diff --git a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py
index 0ffc5b38..a1b372f5 100755
--- a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py
+++ b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py
@@ -86,13 +86,13 @@ class UserPolicyAgendaMultiWoz(Policy):
     def reset_turn(self):
         self.__turn = 0
 
-    def init_session(self, ini_goal=None):
+    def init_session(self, goal=None):
         """ Build new Goal and Agenda for next session """
         self.reset_turn()
-        if not ini_goal:
+        if not goal:
             self.goal = Goal(self.goal_generator)
         else:
-            self.goal = ini_goal
+            self.goal = goal
         self.domain_goals = self.goal.domain_goals
         self.agenda = Agenda(self.goal)
 
diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py
index a5d1b382..62245a32 100644
--- a/convlab/policy/vector/vector_base.py
+++ b/convlab/policy/vector/vector_base.py
@@ -2,8 +2,8 @@
 import os
 import sys
 import numpy as np
-import logging
 
+from data.unified_datasets.multiwoz21.database import Database
 from copy import deepcopy
 from convlab.policy.vec import Vector
 from convlab.util.custom_util import flatten_acts
@@ -26,8 +26,8 @@ class VectorBase(Vector):
         self.set_seed(seed)
         self.ontology = load_ontology(dataset_name)
         try:
-            self.db = load_database(dataset_name)
-            # self.db = Database()
+            #self.db = load_database(dataset_name)
+            self.db = Database()
             self.db_domains = self.db.domains
         except Exception as e:
             self.db = None
-- 
GitLab