Skip to content
Snippets Groups Projects
Commit b9804957 authored by Christian's avatar Christian
Browse files

bugfix, i need to load database manually in vectorbase, otherwise distributed...

bugfix, i need to load database manually in vectorbase, otherwise distributed training fails due to not pickle serializable. Also made environment such that it can expect a goal passed to it when resetting
parent aa8d3ff9
Branches
No related tags found
No related merge requests found
...@@ -19,8 +19,8 @@ class Environment(): ...@@ -19,8 +19,8 @@ class Environment():
self.evaluator = evaluator self.evaluator = evaluator
self.use_semantic_acts = use_semantic_acts self.use_semantic_acts = use_semantic_acts
def reset(self): def reset(self, goal=None):
self.usr.init_session() self.usr.init_session(goal=goal)
self.sys_dst.init_session() self.sys_dst.init_session()
if self.evaluator: if self.evaluator:
self.evaluator.add_goal(self.usr.policy.get_goal()) self.evaluator.add_goal(self.usr.policy.get_goal())
......
...@@ -86,13 +86,13 @@ class UserPolicyAgendaMultiWoz(Policy): ...@@ -86,13 +86,13 @@ class UserPolicyAgendaMultiWoz(Policy):
def reset_turn(self): def reset_turn(self):
self.__turn = 0 self.__turn = 0
def init_session(self, ini_goal=None): def init_session(self, goal=None):
""" Build new Goal and Agenda for next session """ """ Build new Goal and Agenda for next session """
self.reset_turn() self.reset_turn()
if not ini_goal: if not goal:
self.goal = Goal(self.goal_generator) self.goal = Goal(self.goal_generator)
else: else:
self.goal = ini_goal self.goal = goal
self.domain_goals = self.goal.domain_goals self.domain_goals = self.goal.domain_goals
self.agenda = Agenda(self.goal) self.agenda = Agenda(self.goal)
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
import os import os
import sys import sys
import numpy as np import numpy as np
import logging
from data.unified_datasets.multiwoz21.database import Database
from copy import deepcopy from copy import deepcopy
from convlab.policy.vec import Vector from convlab.policy.vec import Vector
from convlab.util.custom_util import flatten_acts from convlab.util.custom_util import flatten_acts
...@@ -26,8 +26,8 @@ class VectorBase(Vector): ...@@ -26,8 +26,8 @@ class VectorBase(Vector):
self.set_seed(seed) self.set_seed(seed)
self.ontology = load_ontology(dataset_name) self.ontology = load_ontology(dataset_name)
try: try:
self.db = load_database(dataset_name) #self.db = load_database(dataset_name)
# self.db = Database() self.db = Database()
self.db_domains = self.db.domains self.db_domains = self.db.domains
except Exception as e: except Exception as e:
self.db = None self.db = None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment