Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
E
emoUS-public
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
general
dsml
emoUS-public
Commits
d47f79e5
Commit
d47f79e5
authored
3 years ago
by
zqwerty
Browse files
Options
Downloads
Patches
Plain Diff
Support loading datasets from the Hugging Face Hub
parent
f4695da9
No related branches found
No related tags found
No related merge requests found
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
convlab/util/unified_datasets_util.py
+30
-5
30 additions, 5 deletions
convlab/util/unified_datasets_util.py
data/unified_datasets/multiwoz21/preprocess.py
+1
-1
1 addition, 1 deletion
data/unified_datasets/multiwoz21/preprocess.py
with
31 additions
and
6 deletions
convlab/util/unified_datasets_util.py
+
30
−
5
View file @
d47f79e5
...
@@ -7,6 +7,8 @@ import re
...
@@ -7,6 +7,8 @@ import re
import
importlib
import
importlib
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
pprint
import
pprint
from
pprint
import
pprint
from
convlab.util.file_util
import
cached_path
import
shutil
class
BaseDatabase
(
ABC
):
class
BaseDatabase
(
ABC
):
...
@@ -18,6 +20,23 @@ class BaseDatabase(ABC):
...
@@ -18,6 +20,23 @@ class BaseDatabase(ABC):
def
query
(
self
,
domain
:
str
,
state
:
dict
,
topk
:
int
,
**
kwargs
)
->
list
:
def
query
(
self
,
domain
:
str
,
state
:
dict
,
topk
:
int
,
**
kwargs
)
->
list
:
"""
return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state.
"""
"""
return a list of topk entities (dict containing slot-value pairs) for a given domain based on the dialogue state.
"""
def load_from_hf_datasets(dataset_name, filename, data_dir):
    """Download ``filename`` from the Hugging Face Hub if it is not already cached locally.

    The file is fetched from the ConvLab organisation repository at
    ``https://huggingface.co/datasets/ConvLab/{dataset_name}`` and moved into
    ``data_dir``, so subsequent calls reuse the local copy.

    :param dataset_name: name of the dataset repository on the Hub
    :param filename: name of the file to download (e.g. ``'data.zip'``)
    :param data_dir: local directory where the file will be stored
    :return: the local path of the (possibly freshly downloaded) file
    """
    data_path = os.path.join(data_dir, filename)
    if not os.path.exists(data_path):
        # exist_ok=True already tolerates a pre-existing directory, so no
        # separate os.path.exists(data_dir) check is needed.
        os.makedirs(data_dir, exist_ok=True)
        data_url = f'https://huggingface.co/datasets/ConvLab/{dataset_name}/resolve/main/{filename}'
        # cached_path downloads to a temporary cache; move the result into
        # the canonical data directory so later lookups find it directly.
        cache_path = cached_path(data_url)
        shutil.move(cache_path, data_path)
    return data_path
def
load_dataset
(
dataset_name
:
str
,
dial_ids_order
=
None
)
->
Dict
:
def
load_dataset
(
dataset_name
:
str
,
dial_ids_order
=
None
)
->
Dict
:
"""
load unified dataset from `data/unified_datasets/$dataset_name`
"""
load unified dataset from `data/unified_datasets/$dataset_name`
...
@@ -30,12 +49,15 @@ def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
...
@@ -30,12 +49,15 @@ def load_dataset(dataset_name:str, dial_ids_order=None) -> Dict:
dataset (dict): keys are data splits and the values are lists of dialogues
dataset (dict): keys are data splits and the values are lists of dialogues
"""
"""
data_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
abspath
(
__file__
),
f
'
../../../data/unified_datasets/
{
dataset_name
}
'
))
data_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
abspath
(
__file__
),
f
'
../../../data/unified_datasets/
{
dataset_name
}
'
))
archive
=
ZipFile
(
os
.
path
.
join
(
data_dir
,
'
data.zip
'
))
data_path
=
load_from_hf_datasets
(
dataset_name
,
'
data.zip
'
,
data_dir
)
archive
=
ZipFile
(
data_path
)
with
archive
.
open
(
'
data/dialogues.json
'
)
as
f
:
with
archive
.
open
(
'
data/dialogues.json
'
)
as
f
:
dialogues
=
json
.
loads
(
f
.
read
())
dialogues
=
json
.
loads
(
f
.
read
())
dataset
=
{}
dataset
=
{}
if
dial_ids_order
is
not
None
:
if
dial_ids_order
is
not
None
:
dial_ids
=
json
.
load
(
open
(
os
.
path
.
join
(
data_dir
,
'
shuffled_dial_ids.json
'
)))[
dial_ids_order
]
data_path
=
load_from_hf_datasets
(
dataset_name
,
'
shuffled_dial_ids.json
'
,
data_dir
)
dial_ids
=
json
.
load
(
open
(
data_path
))[
dial_ids_order
]
for
data_split
in
dial_ids
:
for
data_split
in
dial_ids
:
dataset
[
data_split
]
=
[
dialogues
[
i
]
for
i
in
dial_ids
[
data_split
]]
dataset
[
data_split
]
=
[
dialogues
[
i
]
for
i
in
dial_ids
[
data_split
]]
else
:
else
:
...
@@ -56,7 +78,9 @@ def load_ontology(dataset_name:str) -> Dict:
...
@@ -56,7 +78,9 @@ def load_ontology(dataset_name:str) -> Dict:
ontology (dict): dataset ontology
ontology (dict): dataset ontology
"""
"""
data_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
abspath
(
__file__
),
f
'
../../../data/unified_datasets/
{
dataset_name
}
'
))
data_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
abspath
(
__file__
),
f
'
../../../data/unified_datasets/
{
dataset_name
}
'
))
archive
=
ZipFile
(
os
.
path
.
join
(
data_dir
,
'
data.zip
'
))
data_path
=
load_from_hf_datasets
(
dataset_name
,
'
data.zip
'
,
data_dir
)
archive
=
ZipFile
(
data_path
)
with
archive
.
open
(
'
data/ontology.json
'
)
as
f
:
with
archive
.
open
(
'
data/ontology.json
'
)
as
f
:
ontology
=
json
.
loads
(
f
.
read
())
ontology
=
json
.
loads
(
f
.
read
())
return
ontology
return
ontology
...
@@ -70,8 +94,9 @@ def load_database(dataset_name:str):
...
@@ -70,8 +94,9 @@ def load_database(dataset_name:str):
Returns:
Returns:
database: an instance of BaseDatabase
database: an instance of BaseDatabase
"""
"""
data_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
abspath
(
__file__
),
f
'
../../../data/unified_datasets/
{
dataset_name
}
/database.py
'
))
data_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
abspath
(
__file__
),
f
'
../../../data/unified_datasets/
{
dataset_name
}
'
))
module_spec
=
importlib
.
util
.
spec_from_file_location
(
'
database
'
,
data_dir
)
data_path
=
load_from_hf_datasets
(
dataset_name
,
'
database.py
'
,
data_dir
)
module_spec
=
importlib
.
util
.
spec_from_file_location
(
'
database
'
,
data_path
)
module
=
importlib
.
util
.
module_from_spec
(
module_spec
)
module
=
importlib
.
util
.
module_from_spec
(
module_spec
)
module_spec
.
loader
.
exec_module
(
module
)
module_spec
.
loader
.
exec_module
(
module
)
Database
=
module
.
Database
Database
=
module
.
Database
...
...
This diff is collapsed.
Click to expand it.
data/unified_datasets/multiwoz21/preprocess.py
+
1
−
1
View file @
d47f79e5
...
@@ -8,7 +8,7 @@ from tqdm import tqdm
...
@@ -8,7 +8,7 @@ from tqdm import tqdm
from
collections
import
Counter
from
collections
import
Counter
from
pprint
import
pprint
from
pprint
import
pprint
from
nltk.tokenize
import
TreebankWordTokenizer
,
PunktSentenceTokenizer
from
nltk.tokenize
import
TreebankWordTokenizer
,
PunktSentenceTokenizer
from
data.unified_datasets.multiwoz21
.booking_remapper
import
BookingActRemapper
from
.booking_remapper
import
BookingActRemapper
ontology
=
{
ontology
=
{
"
domains
"
:
{
# descriptions are adapted from multiwoz22, but is_categorical may be different
"
domains
"
:
{
# descriptions are adapted from multiwoz22, but is_categorical may be different
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment