Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
E
emoUS-public
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
general
dsml
emoUS-public
Commits
e0c4d01b
Commit
e0c4d01b
authored
2 years ago
by
zz-jacob
Browse files
Options
Downloads
Patches
Plain Diff
fix scgpt.py
parent
54c6842d
No related branches found
No related tags found
No related merge requests found
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
convlab/nlg/scgpt/evaluate.sh
+5
-5
5 additions, 5 deletions
convlab/nlg/scgpt/evaluate.sh
convlab/nlg/scgpt/scgpt.py
+9
-17
9 additions, 17 deletions
convlab/nlg/scgpt/scgpt.py
convlab/nlg/scgpt/train.sh
+8
-8
8 additions, 8 deletions
convlab/nlg/scgpt/train.sh
with
22 additions
and
30 deletions
convlab/nlg/scgpt/evaluate.sh
+
5
−
5
View file @
e0c4d01b
CUDA_VISIBLE_DEVICES
=
"
1
"
python
-m
torch.distributed.launch
--nproc_per_node
1
--master_port
205
1
main.py
\
--batch_size
64
\
CUDA_VISIBLE_DEVICES
=
"
0
"
python
-m
torch.distributed.launch
--nproc_per_node
1
--master_port
205
0
main.py
\
--batch_size
128
\
--base_model_name_path
gpt2-medium
\
--dataset
multiwoz21
\
--exp_name
gpt2_mwoz2
\
--model_path
saved_models/gpt2_mwoz/epoch_2/epoch_2_step1329.pt
\
\ No newline at end of file
--dataset
sgd
\
--exp_name
gpt2_sgd_test
\
--model_path
saved_models/exp_name/epoch_x/epoch_7_step10312.pt
\
\ No newline at end of file
This diff is collapsed.
Click to expand it.
convlab/nlg/scgpt/scgpt.py
+
9
−
17
View file @
e0c4d01b
...
...
@@ -2,27 +2,22 @@ import sys
sys
.
path
.
append
(
'
../../..
'
)
import
torch
from
transformers
import
GPT2Tokenizer
,
GPT2LMHeadModel
from
transformers
import
GPT2Tokenizer
,
GPT2LMHeadModel
,
GPT2Config
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
convlab.nlg.nlg
import
NLG
from
util
import
act2str
from
scgpt_special_tokens
import
*
special_tokens
=
[
START_OF_PRED
,
END_OF_PRED
,
SYS_SPEAK
,
USR_SPEAK
]
class
SCGPT
(
NLG
):
def
__init__
(
self
,
dataset_name
,
model_path
,
device
=
'
cpu
'
):
super
(
SCGPT
,
self
).
__init__
()
self
.
device
=
device
self
.
model
=
GPT2LMHeadModel
.
from_pretrained
(
'
gpt2
'
).
to
(
self
.
device
)
self
.
model
=
GPT2LMHeadModel
(
config
=
GPT2Config
.
from_pretrained
(
'
gpt2
'
)
)
.
to
(
self
.
device
)
self
.
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
'
gpt2
'
)
self
.
tokenizer
.
add_special_tokens
({
'
pad_token
'
:
PAD_TOKEN
,
'
eos_token
'
:
END_OF_PRED
,
'
additional_special_tokens
'
:
special_tokens
})
self
.
model
.
resize_token_embeddings
(
len
(
self
.
tokenizer
))
self
.
model
.
load_state_dict
(
torch
.
load
(
model_path
))
def
generate
(
self
,
action
):
action_str
=
act2str
(
action
)
output
=
self
.
_inference_batch
([
action_str
])[
0
]
...
...
@@ -30,17 +25,14 @@ class SCGPT(NLG):
def
_inference_batch
(
self
,
sents
):
with
torch
.
no_grad
():
sents
=
[
sent
+
'
'
+
START_OF_PRED
for
sent
in
sents
]
sent_ids
=
[
self
.
tokenizer
.
encode
(
sent
)
for
sent
in
sents
]
sents
=
[
sent
for
sent
in
sents
]
sent_ids
=
[
self
.
tokenizer
.
encode
(
sent
)
+
[
self
.
tokenizer
.
_convert_token_to_id_with_added_voc
(
'
&
'
)]
for
sent
in
sents
]
max_len
=
max
([
len
(
sent
)
for
sent
in
sent_ids
])
sent_ids
=
[
[
se
lf
.
tokenizer
.
pad_token_id
]
*
(
max_len
-
len
(
sent
))
+
sent
for
sent
in
sent_ids
]
sent_ids
=
[
se
nt
+
[
0
]
*
(
max_len
-
len
(
sent
))
for
sent
in
sent_ids
]
inputs
=
torch
.
LongTensor
(
sent_ids
).
to
(
self
.
device
)
model_to_run
=
self
.
model
.
module
if
type
(
self
.
model
)
is
DDP
else
self
.
model
outputs
=
model_to_run
.
generate
(
inputs
,
max_length
=
256
,
eos_token_id
=
self
.
tokenizer
.
pad_token_id
,
pad_token_id
=
self
.
tokenizer
.
pad_token_id
)
# greedy
# outputs = model_to_run.generate(inputs, num_beams=4, max_length=513, eos_token_id=gpt2_tokenizer.eos_token_id,
# pad_token_id=gpt2_tokenizer.pad_token_id) # beam search
output_strs
=
[
self
.
tokenizer
.
decode
(
item
)
for
item
in
outputs
]
outputs
=
model_to_run
.
generate
(
inputs
,
max_length
=
256
,
attention_mask
=
(
inputs
==
0
).
float
(),
eos_token_id
=
self
.
tokenizer
.
pad_token_id
)
# greedy
outputs
=
outputs
[:,
len
(
inputs
[
0
]):]
output_strs
=
[
self
.
tokenizer
.
decode
(
item
).
strip
()
for
item
in
outputs
]
return
output_strs
\ No newline at end of file
This diff is collapsed.
Click to expand it.
convlab/nlg/scgpt/train.sh
+
8
−
8
View file @
e0c4d01b
CUDA_VISIBLE_DEVICES
=
"
3
"
python
-m
torch.distributed.launch
--nproc_per_node
1
--master_port
204
3
main.py
\
CUDA_VISIBLE_DEVICES
=
"
2
"
python
-m
torch.distributed.launch
--nproc_per_node
1
--master_port
204
2
main.py
\
--batch_size
32
\
--accumulation_step
4
\
--epoch_num
2
0
\
--epoch_num
10
0
\
--lr
5e-5
\
--base_model_name_path
/root/autodl-tmp/ConvLab-3/convlab/nlg/scgpt/resource/scgpt
\
--val_step
100
0
\
--exp_name
sc
gpt_mwoz
\
--base_model_name_path
gpt2-medium
\
--val_step
100
\
--exp_name
gpt
2
_mwoz
001_direct
\
--do_train
\
--dataset
sgd
\
--train_ratio
1
.0
\
# --scgpt_model_ckpt_path saved_models/sgd_tm
_1e4
/epoch_
8
/epoch_
8
_step
41094
.pt
--dataset
multiwoz21
\
--train_ratio
0
.0
1
\
# --scgpt_model_ckpt_path saved_models/
gpt2_
sgd_tm/epoch_
2
/epoch_
2
_step
13698
.pt
# --base_model_name_path /root/autodl-tmp/ConvLab-3/convlab/nlg/scgpt/resource/scgpt \
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment