From b980495717469a353c842217136eba29f4ecba53 Mon Sep 17 00:00:00 2001 From: Christian <christian.geishauser@hhu.de> Date: Mon, 4 Jul 2022 09:53:49 +0200 Subject: [PATCH] bugfix, i need to load database manually in vectorbase, otherwise distributed training fails due to not pickle serializable. Also made environment such that it can expect a goal passed to it when resetting --- convlab/dialog_agent/env.py | 4 ++-- convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py | 6 +++--- convlab/policy/vector/vector_base.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/convlab/dialog_agent/env.py b/convlab/dialog_agent/env.py index bee3e943..c1f15dfa 100755 --- a/convlab/dialog_agent/env.py +++ b/convlab/dialog_agent/env.py @@ -19,8 +19,8 @@ class Environment(): self.evaluator = evaluator self.use_semantic_acts = use_semantic_acts - def reset(self): - self.usr.init_session() + def reset(self, goal=None): + self.usr.init_session(goal=goal) self.sys_dst.init_session() if self.evaluator: self.evaluator.add_goal(self.usr.policy.get_goal()) diff --git a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py index 0ffc5b38..a1b372f5 100755 --- a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -86,13 +86,13 @@ class UserPolicyAgendaMultiWoz(Policy): def reset_turn(self): self.__turn = 0 - def init_session(self, ini_goal=None): + def init_session(self, goal=None): """ Build new Goal and Agenda for next session """ self.reset_turn() - if not ini_goal: + if not goal: self.goal = Goal(self.goal_generator) else: - self.goal = ini_goal + self.goal = goal self.domain_goals = self.goal.domain_goals self.agenda = Agenda(self.goal) diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index a5d1b382..62245a32 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -2,8 +2,8 @@ import os import sys import numpy as np -import logging +from data.unified_datasets.multiwoz21.database import Database from copy import deepcopy from convlab.policy.vec import Vector from convlab.util.custom_util import flatten_acts @@ -26,8 +26,8 @@ class VectorBase(Vector): self.set_seed(seed) self.ontology = load_ontology(dataset_name) try: - self.db = load_database(dataset_name) - # self.db = Database() + #self.db = load_database(dataset_name) + self.db = Database() self.db_domains = self.db.domains except Exception as e: self.db = None -- GitLab