Commit 68d96eda, authored 4 years ago by function2

Merge branch 'master' of github.com:thu-coai/ConvLab-2

Parents: dcf8c504 e368deeb

Showing 3 changed files with 103 additions and 10 deletions:

convlab2/policy/dqn/dqn.py (+26, −7)
convlab2/policy/dqn/train.py (+76, −2)
convlab2/policy/hdsa/multiwoz/transformer/Beam.py (+1, −1)
convlab2/policy/dqn/dqn.py (+26, −7)

@@ -11,6 +11,7 @@ from convlab2.policy.policy import Policy
 from convlab2.policy.rlmodule import EpsilonGreedyPolicy, MemoryReplay
 from convlab2.util.train_util import init_logging_handler
 from convlab2.policy.vector.vector_multiwoz import MultiWozVector
+from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
 from convlab2.util.file_util import cached_path
 import zipfile
 import sys
@@ -32,6 +33,8 @@ class DQN(Policy):
         self.training_iter = cfg['training_iter']
         self.training_batch_iter = cfg['training_batch_iter']
         self.batch_size = cfg['batch_size']
+        self.epsilon = cfg['epsilon_spec']['start']
+        self.rule_bot = RuleBasedMultiwozBot()
         self.gamma = cfg['gamma']
         self.is_train = is_train
         if is_train:
@@ -58,9 +61,10 @@ class DQN(Policy):
             self.loss_fn = nn.MSELoss()
 
     def update_memory(self, sample):
+        self.memory.reset()
         self.memory.append(sample)
 
-    def predict(self, state):
+    def predict(self, state, warm_up=False):
         """
         Predict an system action given state.
         Args:
@@ -68,14 +72,29 @@ class DQN(Policy):
         Returns:
             action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...})
         """
-        s_vec = torch.Tensor(self.vector.state_vectorize(state))
-        a = self.net.select_action(s_vec.to(device=DEVICE))
-
-        action = self.vector.action_devectorize(a.numpy())
-        state['system_action'] = action
+        if warm_up:
+            action = self.rule_action(state)
+            state['system_action'] = action
+        else:
+            s_vec = torch.Tensor(self.vector.state_vectorize(state))
+            a = self.net.select_action(s_vec.to(device=DEVICE), is_train=self.is_train)
+            action = self.vector.action_devectorize(a.numpy())
+            state['system_action'] = action
 
         return action
 
+    def rule_action(self, state):
+        if self.epsilon > np.random.rand():
+            a = torch.randint(self.vector.da_dim, (1, ))
+            # transforms action index to a vector action (one-hot encoding)
+            a_vec = torch.zeros(self.vector.da_dim)
+            a_vec[a] = 1.
+            action = self.vector.action_devectorize(a_vec.numpy())
+        else:
+            # rule-based warm up
+            action = self.rule_bot.predict(state)
+
+        return action
+
     def init_session(self):
         """
         Restore after one session
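(Aside, not part of the commit.) The warm-up path added to DQN.predict is an epsilon-greedy mix: with probability epsilon it emits a random one-hot dialog act, otherwise it defers to the rule-based MultiWOZ bot. A minimal, self-contained sketch of that decision rule, with the action dimension and the rule policy stubbed out as hypothetical placeholders:

import numpy as np
import torch

def warmup_action(epsilon, da_dim, rule_policy, state):
    # With probability epsilon, explore: sample a random action index and
    # expand it into a one-hot dialog-act vector of size da_dim.
    if epsilon > np.random.rand():
        a = torch.randint(da_dim, (1,))
        a_vec = torch.zeros(da_dim)
        a_vec[a] = 1.0
        return a_vec.numpy()
    # Otherwise exploit the hand-crafted policy (a stand-in for
    # RuleBasedMultiwozBot.predict in the real code).
    return rule_policy(state)

# Toy usage: a dummy rule policy that always returns the zero vector.
print(warmup_action(0.5, 10, lambda s: np.zeros(10), state={}))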
convlab2/policy/dqn/train.py (+76, −2)

@@ -90,8 +90,71 @@ def sampler(pid, queue, evt, env, policy, batchsz):
     queue.put([pid, buff])
     evt.wait()
 
 
+def warmupsampler(pid, queue, evt, env, policy, batchsz):
+    """
+    This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
+    processes.
+    :param pid: process id
+    :param queue: multiprocessing.Queue, to collect sampled data
+    :param evt: multiprocessing.Event, to keep the process alive
+    :param env: environment instance
+    :param policy: policy network, to generate action from current policy
+    :param batchsz: total sampled items
+    :return:
+    """
+    buff = Memory()
+
+    # we need to sample batchsz of (state, action, next_state, reward, mask)
+    # each trajectory contains `trajectory_len` num of items, so we only need to sample
+    # `batchsz//trajectory_len` num of trajectory totally
+    # the final sampled number may be larger than batchsz.
+
+    sampled_num = 0
+    sampled_traj_num = 0
+    traj_len = 50
+    real_traj_len = 0
+
+    while sampled_num < batchsz:
+        # for each trajectory, we reset the env and get initial state
+        s = env.reset()
+        for t in range(traj_len):
+            # [s_dim] => [a_dim]
+            s_vec = torch.Tensor(policy.vector.state_vectorize(s))
+            a = policy.predict(s, warm_up=True)
+
+            # interact with env
+            next_s, r, done = env.step(a)
+
+            # a flag indicates ending or not
+            mask = 0 if done else 1
+
+            # get reward compared to demostrations
+            next_s_vec = torch.Tensor(policy.vector.state_vectorize(next_s))
+
+            # save to queue
+            buff.push(s_vec.numpy(), policy.vector.action_vectorize(a), r, next_s_vec.numpy(), mask)
+
+            # update per step
+            s = next_s
+            real_traj_len = t
+
+            if done:
+                break
+
+        # this is end of one trajectory
+        sampled_num += real_traj_len
+        sampled_traj_num += 1
+        # t indicates the valid trajectory length
+
+    # this is end of sampling all batchsz of items.
+    # when sampling is over, push all buff data into queue
+    queue.put([pid, buff])
+    evt.wait()
+
+
-def sample(env, policy, batchsz, process_num):
+def sample(env, policy, batchsz, process_num, warm_up=False):
     """
     Given batchsz number of task, the batchsz will be splited equally to each processes
     and when processes return, it merge all data and return
@@ -119,6 +182,9 @@ def sample(env, policy, batchsz, process_num):
     processes = []
     for i in range(process_num):
         process_args = (i, queue, evt, env, policy, process_batchsz)
-        processes.append(mp.Process(target=sampler, args=process_args))
+        if warm_up:
+            processes.append(mp.Process(target=warmupsampler, args=process_args))
+        else:
+            processes.append(mp.Process(target=sampler, args=process_args))
 
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stoped.
@@ -146,6 +212,13 @@ def update(env, policy, batchsz, epoch, process_num):
     policy.update(epoch)
 
 
+def warm_start(env, policy, batchsz, epoch, process_num):
+    # sample data asynchronously
+    buff = sample(env, policy, batchsz, process_num, warm_up=True)
+    policy.update_memory(buff)
+    policy.update(epoch)
+
+
 if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--load_path", type=str, default="", help="path of model to load")
@@ -170,6 +243,7 @@ if __name__ == '__main__':
     evaluator = MultiWozEvaluator()
     env = Environment(None, simulator, None, dst_sys, evaluator)
 
+    warm_start(env, policy_sys, args.batchsz, 0, args.process_num)
     for i in range(args.epoch):
         update(env, policy_sys, args.batchsz, i, args.process_num)
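(Aside, not part of the commit.) warmupsampler reuses the queue/event handshake of the existing sampler: each worker pushes its buffer onto a shared queue and then blocks on an event so the buffer stays alive until the parent has merged everything. A minimal, self-contained sketch of that pattern with toy data in place of the Memory buffers:

import multiprocessing as mp

def worker(pid, queue, evt, n):
    buff = list(range(pid * n, pid * n + n))  # stand-in for a sampled Memory buffer
    queue.put([pid, buff])
    evt.wait()  # keep the process (and its data) alive until the parent releases it

if __name__ == '__main__':
    queue, evt = mp.Queue(), mp.Event()
    procs = [mp.Process(target=worker, args=(i, queue, evt, 3), daemon=True)
             for i in range(2)]
    for p in procs:
        p.start()
    merged = []
    for _ in procs:
        _, buff = queue.get()  # collect each worker's buffer
        merged.extend(buff)
    evt.set()                  # release the workers
    for p in procs:
        p.join()
    print(sorted(merged))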
convlab2/policy/hdsa/multiwoz/transformer/Beam.py (+1, −1)

@@ -66,7 +66,7 @@ class Beam(object):
         # bestScoresId is flattened as a (beam x word) array,
         # so we need to calculate which word and beam each score came from
-        prev_k = best_scores_id / num_words
+        prev_k = best_scores_id // num_words
         self.prev_ks.append(prev_k)
 
         self.next_ys.append(best_scores_id - prev_k * num_words)
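(Aside, not part of the commit; the motivation is an assumption.) Recovering the beam index from a flattened (beam x word) score index needs integer floor division, and on recent PyTorch releases `/` on integer tensors performs true division and returns floats, which then cannot be used for indexing. A tiny worked illustration:

import torch

num_words = 5
best_scores_id = torch.tensor([7, 12])       # flattened indices into a 3 x 5 score matrix

prev_k = best_scores_id // num_words         # beam index: tensor([1, 2])
word = best_scores_id - prev_k * num_words   # word index: tensor([2, 2])
print(prev_k, word)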