Unverified commit 3618f9c9, authored by Christian Geishauser, committed by GitHub

Merge pull request #61 from ConvLab/refactor_RL

Refactor RL code
Parents: 2cfe5c08 6ace023e
@@ -11,7 +11,8 @@ from convlab2.dialog_agent.agent import PipelineAgent
 from convlab2.dialog_agent.session import BiSession
 from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
 from convlab2.policy.rule.multiwoz import RulePolicy
-from convlab2.util.custom_util import set_seed, get_config, env_config
+from convlab2.task.multiwoz.goal_generator import GoalGenerator
+from convlab2.util.custom_util import set_seed, get_config, env_config, create_goals


 def init_logging(log_dir_path, path_suffix=None):
@@ -66,9 +67,14 @@ def evaluate(config_path, model_name, verbose=False):
     task_success = {'Complete': [], 'Success': [],
                     'Success strict': [], 'total_return': [], 'turns': []}

-    for seed in range(1000, 1400):
+    dialogues = 500
+    goal_generator = GoalGenerator()
+    goals = create_goals(goal_generator, num_goals=dialogues, single_domains=False, allowed_domains=None)
+
+    for seed in range(1000, 1000 + dialogues):
         set_seed(seed)
-        sess.init_session()
+        sess.init_session(goal=goals[seed-1000])
         sys_response = []
         actions = 0.0
         total_return = 0.0
...

@@ -5,14 +5,9 @@
         "pretrained_load_path": "",
         "batchsz": 1000,
         "seed": 0,
-        "epoch": 200,
+        "epoch": 50,
         "eval_frequency": 5,
         "process_num": 4,
-        "use_masking": false,
-        "use_state_entropy": false,
-        "manually_add_entity_names": false,
-        "use_state_mutual_info": false,
-        "use_confidence_scores": false,
         "sys_semantic_to_usr": false,
         "num_eval_dialogues": 500
     },
...

@@ -10,6 +10,7 @@ import logging
 import time
 import numpy as np
 import torch
+import random

 from convlab2.policy.gdpl import GDPL
 from convlab2.policy.gdpl import RewardEstimator
@@ -47,7 +48,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
     :return:
     """
-    buff = Memory(seed=train_seed)
+    buff = Memory()

     # we need to sample batchsz of (state, action, next_state, reward, mask)
     # each trajectory contains `trajectory_len` num of items, so we only need to sample
     # `batchsz//trajectory_len` num of trajectory totally
@@ -58,6 +59,8 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
     traj_len = 50
     real_traj_len = 0

+    set_seed(train_seed)
+
     while sampled_num < batchsz:
         # for each trajectory, we reset the env and get initial state
         s = env.reset()
@@ -121,6 +124,7 @@ def sample(env, policy, batchsz, process_num, seed):
     # batchsz will be splitted into each process,
     # final batchsz maybe larger than batchsz parameters
     process_batchsz = np.ceil(batchsz / process_num).astype(np.int32)
+    train_seeds = random.sample(range(0, 1000), process_num)
     # buffer to save all data
     queue = mp.Queue()
@@ -134,7 +138,7 @@ def sample(env, policy, batchsz, process_num, seed):
     evt = mp.Event()
     processes = []
     for i in range(process_num):
-        process_args = (i, queue, evt, env, policy, process_batchsz, seed)
+        process_args = (i, queue, evt, env, policy, process_batchsz, train_seeds[i])
         processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stoped.
...

 {
     "batchsz": 32,
     "gamma": 0.99,
-    "lr": 0.00001,
+    "lr": 0.0000001,
     "save_dir": "save",
     "log_dir": "log",
     "save_per_epoch": 5,
...

@@ -5,14 +5,9 @@
         "pretrained_load_path": "",
         "batchsz": 1000,
         "seed": 0,
-        "epoch": 200,
+        "epoch": 50,
         "eval_frequency": 5,
         "process_num": 4,
-        "use_masking": false,
-        "use_state_entropy": false,
-        "manually_add_entity_names": false,
-        "use_state_mutual_info": false,
-        "use_confidence_scores": false,
         "sys_semantic_to_usr": false,
         "num_eval_dialogues": 500
     },
...

@@ -10,6 +10,7 @@ import logging
 import time
 import numpy as np
 import torch
+import random

 from convlab2.policy.pg import PG
 from convlab2.policy.rlmodule import Memory
@@ -46,7 +47,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
     :return:
     """
-    buff = Memory(seed=train_seed)
+    buff = Memory()

     # we need to sample batchsz of (state, action, next_state, reward, mask)
     # each trajectory contains `trajectory_len` num of items, so we only need to sample
     # `batchsz//trajectory_len` num of trajectory totally
@@ -57,6 +58,8 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
     traj_len = 50
     real_traj_len = 0

+    set_seed(train_seed)
+
     while sampled_num < batchsz:
         # for each trajectory, we reset the env and get initial state
         s = env.reset()
@@ -120,6 +123,7 @@ def sample(env, policy, batchsz, process_num, seed):
     # batchsz will be splitted into each process,
     # final batchsz maybe larger than batchsz parameters
     process_batchsz = np.ceil(batchsz / process_num).astype(np.int32)
+    train_seeds = random.sample(range(0, 1000), process_num)
     # buffer to save all data
     queue = mp.Queue()
@@ -133,7 +137,7 @@ def sample(env, policy, batchsz, process_num, seed):
     evt = mp.Event()
     processes = []
     for i in range(process_num):
-        process_args = (i, queue, evt, env, policy, process_batchsz, seed)
+        process_args = (i, queue, evt, env, policy, process_batchsz, train_seeds[i])
         processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stoped.
...

@@ -5,14 +5,9 @@
         "pretrained_load_path": "",
         "batchsz": 1000,
         "seed": 0,
-        "epoch": 200,
+        "epoch": 50,
         "eval_frequency": 5,
         "process_num": 4,
-        "use_masking": false,
-        "use_state_entropy": false,
-        "manually_add_entity_names": false,
-        "use_state_mutual_info": false,
-        "use_confidence_scores": false,
         "sys_semantic_to_usr": false,
         "num_eval_dialogues": 500
     },
...

@@ -10,6 +10,7 @@ import logging
 import time
 import numpy as np
 import torch
+import random

 from convlab2.policy.ppo import PPO
 from convlab2.policy.rlmodule import Memory
@@ -46,7 +47,7 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
     :return:
     """
-    buff = Memory(seed=train_seed)
+    buff = Memory()

     # we need to sample batchsz of (state, action, next_state, reward, mask)
     # each trajectory contains `trajectory_len` num of items, so we only need to sample
     # `batchsz//trajectory_len` num of trajectory totally
@@ -57,6 +58,8 @@ def sampler(pid, queue, evt, env, policy, batchsz, train_seed=0):
     traj_len = 50
     real_traj_len = 0

+    set_seed(train_seed)
+
     while sampled_num < batchsz:
         # for each trajectory, we reset the env and get initial state
         s = env.reset()
@@ -120,6 +123,7 @@ def sample(env, policy, batchsz, process_num, seed):
     # batchsz will be splitted into each process,
     # final batchsz maybe larger than batchsz parameters
     process_batchsz = np.ceil(batchsz / process_num).astype(np.int32)
+    train_seeds = random.sample(range(0, 1000), process_num)
     # buffer to save all data
     queue = mp.Queue()
@@ -133,7 +137,7 @@ def sample(env, policy, batchsz, process_num, seed):
     evt = mp.Event()
     processes = []
     for i in range(process_num):
-        process_args = (i, queue, evt, env, policy, process_batchsz, seed)
+        process_args = (i, queue, evt, env, policy, process_batchsz, train_seeds[i])
         processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stoped.
...

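Note: the per-process seeding pattern is identical in the GDPL, PG, and PPO training scripts above. The parent's sample() draws one distinct seed per worker with random.sample, and every sampler process calls set_seed(train_seed) before collecting trajectories, instead of the old scheme of seeding inside Memory (removed in the next diff). Below is a minimal, self-contained sketch of that scheme; the simplified set_seed and the toy worker body are illustrative stand-ins, not the ConvLab-2 implementation.

    import random
    import multiprocessing as mp

    import numpy as np
    import torch


    def set_seed(seed):
        # simplified stand-in for convlab2.util.custom_util.set_seed:
        # seed every RNG a worker relies on
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)


    def sampler(pid, queue, train_seed):
        set_seed(train_seed)  # each worker process gets its own seed
        # placeholder for trajectory collection; here we just emit a few numbers
        queue.put((pid, [random.random() for _ in range(3)]))


    if __name__ == "__main__":
        process_num = 4
        # distinct per-process seeds, as in the diffs above
        train_seeds = random.sample(range(0, 1000), process_num)
        queue = mp.Queue()
        processes = [mp.Process(target=sampler, args=(i, queue, train_seeds[i]))
                     for i in range(process_num)]
        for p in processes:
            p.daemon = True
            p.start()
        results = [queue.get() for _ in range(process_num)]
        for p in processes:
            p.join()
        print(results)

Drawing the seeds once in the parent keeps the workers' RNG streams distinct; reproducibility of the whole run then depends on seeding the parent process before random.sample is called.
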
@@ -319,17 +319,8 @@ Transition = namedtuple('Transition', ('state', 'action',

 class Memory(object):

-    def __init__(self, seed=0):
+    def __init__(self):
         self.memory = []
-        self.set_seed(seed)
-
-    def set_seed(self, seed):
-        np.random.seed(seed)
-        torch.random.manual_seed(seed)
-        random.seed(seed)
-        torch.manual_seed(seed)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(seed)

     def push(self, *args):
         """Saves a transition."""
...

@@ -135,8 +135,8 @@ class UserPolicyAgendaMultiWoz(Policy):
         action = {}
         while len(action) == 0:
             # A -> A' + user_action
-            # action = self.agenda.get_action(random.randint(2, self.max_initiative))
-            action = self.agenda.get_action(self.max_initiative)
+            action = self.agenda.get_action(random.randint(1, self.max_initiative))
+            #action = self.agenda.get_action(self.max_initiative)

         # transform to DA
         action = self._transform_usract_out(action)
...

@@ -18,6 +18,8 @@ from convlab2.dst.rule.multiwoz import RuleDST
 from convlab2.policy.rule.multiwoz import RulePolicy
 from convlab2.evaluator.multiwoz_eval import MultiWozEvaluator
 from convlab2.util import load_dataset
+from convlab2.policy.rule.multiwoz.policy_agenda_multiwoz import Goal
+
 import shutil
@@ -435,6 +437,19 @@ def act_dict_to_flat_tuple(acts):
         tuples.append([intent, domain, slot, value])

+
+def create_goals(goal_generator, num_goals, single_domains=False, allowed_domains=None):
+    collected_goals = []
+    while len(collected_goals) != num_goals:
+        goal = Goal(goal_generator)
+        if single_domains and len(goal.domain_goals) > 1:
+            continue
+        if allowed_domains is not None and not set(goal.domain_goals).issubset(set(allowed_domains)):
+            continue
+        collected_goals.append(goal)
+
+    return collected_goals
+
+
 def map_class(cls_path: str):
     """
     Map to class via package text path
...

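For reference, a minimal sketch of how the new create_goals helper pairs with GoalGenerator to build a fixed goal list, mirroring the evaluate.py hunk at the top of this diff; the single-domain restaurant filter is only an illustration of the single_domains/allowed_domains parameters.

    from convlab2.task.multiwoz.goal_generator import GoalGenerator
    from convlab2.util.custom_util import create_goals, set_seed

    dialogues = 500
    goal_generator = GoalGenerator()

    # example filter: keep only goals that involve the restaurant domain alone
    goals = create_goals(goal_generator, num_goals=dialogues,
                         single_domains=True, allowed_domains=["restaurant"])

    for seed in range(1000, 1000 + dialogues):
        set_seed(seed)
        goal = goals[seed - 1000]   # fixed goal per evaluation dialogue,
        print(goal.domain_goals)    # as passed to sess.init_session(goal=...) above

Fixing the goal list up front means every evaluated policy sees the same sequence of user goals, so evaluation differences come from the policy rather than from goal sampling.
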