diff --git a/convlab/dialog_agent/env.py b/convlab/dialog_agent/env.py index bee3e943db9d7363a672a4627b35ec23dde5c37d..c1f15dfa031c4c72091cf9418e008c25bd04d804 100755 --- a/convlab/dialog_agent/env.py +++ b/convlab/dialog_agent/env.py @@ -19,8 +19,8 @@ class Environment(): self.evaluator = evaluator self.use_semantic_acts = use_semantic_acts - def reset(self): - self.usr.init_session() + def reset(self, goal=None): + self.usr.init_session(goal=goal) self.sys_dst.init_session() if self.evaluator: self.evaluator.add_goal(self.usr.policy.get_goal()) diff --git a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py index 0ffc5b38e318826e3de8370484d19845801db4c1..a1b372f5531e4c350a62b5db1d408743b45b8a92 100755 --- a/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py +++ b/convlab/policy/rule/multiwoz/policy_agenda_multiwoz.py @@ -86,13 +86,13 @@ class UserPolicyAgendaMultiWoz(Policy): def reset_turn(self): self.__turn = 0 - def init_session(self, ini_goal=None): + def init_session(self, goal=None): """ Build new Goal and Agenda for next session """ self.reset_turn() - if not ini_goal: + if not goal: self.goal = Goal(self.goal_generator) else: - self.goal = ini_goal + self.goal = goal self.domain_goals = self.goal.domain_goals self.agenda = Agenda(self.goal) diff --git a/convlab/policy/vector/vector_base.py b/convlab/policy/vector/vector_base.py index a5d1b382262757e508f7e0e44b93543d07d846b5..62245a32c638cfe600957f908a826a0c10955162 100644 --- a/convlab/policy/vector/vector_base.py +++ b/convlab/policy/vector/vector_base.py @@ -2,8 +2,8 @@ import os import sys import numpy as np -import logging +from data.unified_datasets.multiwoz21.database import Database from copy import deepcopy from convlab.policy.vec import Vector from convlab.util.custom_util import flatten_acts @@ -26,8 +26,8 @@ class VectorBase(Vector): self.set_seed(seed) self.ontology = load_ontology(dataset_name) try: - self.db = load_database(dataset_name) - # self.db = Database() + #self.db = load_database(dataset_name) + self.db = Database() self.db_domains = self.db.domains except Exception as e: self.db = None