Commit 8d845873 authored by Benjamin Ruppik

Updated sbert embedding script with command line arguments

parent 3e185b76
# TDA code for Dialogue Term Extraction using Transfer Learning and Topological Data Analysis
This is the Topological Data Analysis portion of the code for the paper
'Dialogue Term Extraction using Transfer Learning and Topological Data Analysis'.
The scripts in this folder should be executed in the `tda` working directory.
## Create embeddings
Precomputed sbert embeddings are contained in the `/data` folder,
both for the ambient fastText vocabulary and for the joint multiwoz and sgd vocabulary.
These embeddings are the basis for computing neighborhoods.
It is not necessary to recompute them; for the neighborhood extraction
and TDA features, skip ahead to the next section.
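The pickled embeddings can also be inspected directly. A minimal sketch, assuming the layout the script below works with (a dictionary keyed by model name and then by word, stored under the `{vocab_desc}_vocab_embeddings_sbert.pkl` naming pattern):
```python
import pickle

# Path follows the naming pattern used in sbert_create_static_embeddings.py;
# adjust the vocab_desc part ('multiwoz_and_sgd' or 'pretrained_cc_en') as needed.
with open("../data/multiwoz_and_sgd_vocab_embeddings_sbert.pkl", "rb") as file:
    embeddings = pickle.load(file)

# One vector per sbert model and vocabulary word.
vector = embeddings["paraphrase-MiniLM-L6-v2"]["cheap"]
print(vector.shape)  # 384-dimensional for paraphrase-MiniLM-L6-v2
```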
Alternatively, the following command loads the precomputed embeddings
of the fastText vocabulary into an interactive Python session:
```bash
python -i sbert_create_static_embeddings.py \
--embeddings_config_path ./sbert_static_embeddings_config_50_0.yaml \
--vocab_desc pretrained_cc_en \
--load_embeddings
```
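Because of the `-i` flag, the embeddings stay available in the interactive session after the script finishes, so individual vectors can be inspected with e.g. `embeddings['paraphrase-MiniLM-L6-v2']['cheap']`.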
To compute and save embeddings of the multiwoz and sgd vocabulary:
```bash
python sbert_create_static_embeddings.py \
--embeddings_config_path ./sbert_static_embeddings_config_50_0.yaml \
--vocab_desc multiwoz_and_sgd \
--save_embeddings
```
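Passing `--vocab_desc pretrained_cc_en` instead recomputes the embeddings for the ambient fastText vocabulary; these two values are exactly the `choices` accepted by the `--vocab_desc` argument in the script diff below.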
## Build neighborhoods and extract persistence features
TODO
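The scripts for this step are not included in this excerpt. As a rough illustration of the intended pipeline (the config below specifies `nbhd_size: 50` neighborhoods computed from the embeddings), here is a minimal sketch using scikit-learn for the nearest-neighbor search and `ripser` for the persistence diagrams; the repository's actual scripts, parameters, and libraries may differ:
```python
import pickle

import numpy as np
from ripser import ripser
from sklearn.neighbors import NearestNeighbors

NBHD_SIZE = 50  # 'nbhd_size' in the config file

# Load the precomputed static embeddings (dict: model name -> word -> vector).
with open("../data/multiwoz_and_sgd_vocab_embeddings_sbert.pkl", "rb") as file:
    embeddings = pickle.load(file)["paraphrase-MiniLM-L6-v2"]

words = list(embeddings)
vectors = np.stack([embeddings[w] for w in words])

# For every word, collect the point cloud of its nearest neighbors ...
neighbors = NearestNeighbors(n_neighbors=NBHD_SIZE).fit(vectors)
_, indices = neighbors.kneighbors(vectors)

# ... and summarize the shape of that neighborhood with a persistence diagram.
for word, idx in zip(words[:5], indices[:5]):
    diagrams = ripser(vectors[idx], maxdim=1)["dgms"]
    print(word, "H0 features:", len(diagrams[0]), "H1 features:", len(diagrams[1]))
```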
## License
This project is licensed under the Apache License, Version 2.0 (the "License");
you may not use these files except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
sbert_create_static_embeddings.py
@@ -49,6 +49,13 @@ parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--embeddings_config_path", default=None, type=str, required=True,
help="Dataset configuration file.")
parser.add_argument("--vocab_desc", choices=["multiwoz_and_sgd", "pretrained_cc_en"],
default=None, type=str, required=True,
help="String describing which embeddings should be created.")
parser.add_argument("--save_embeddings", default=False, action="store_true",
help="Flag to save embeddings to disk.")
parser.add_argument("--load_embeddings", default=False, action="store_true",
help="Flag to load embeddings from disk.")
# Optional parameters
# None
@@ -59,22 +66,20 @@ args = parser.parse_args()
embeddings_config_path = args.embeddings_config_path
embeddings_config = yaml.safe_load(open(embeddings_config_path))
# # # # # # # # # # # # # # # # #
# MODIFY: Program parameters
DEVELOP_MODE = False
+vocab_desc = args.vocab_desc
# the fields in the config file are named
# 'pretrained_cc_en_vocabulary_path' and 'multiwoz_and_sgd_vocabulary_path'
+vocab_path = embeddings_config['data'][vocab_desc + '_vocabulary_path']
-SAVE_EMBEDDINGS = True
-LOAD_EMBEDDINGS = False
+SAVE_EMBEDDINGS = args.save_embeddings
+LOAD_EMBEDDINGS = args.load_embeddings
data_desc_list = ['paraphrase-MiniLM-L6-v2']
# # # # # # # # # # # # # # # # #
# Debug parameters
DEBUG_DATA_DESC = 'paraphrase-MiniLM-L6-v2'
-VOCAB_PATH = embeddings_config['data']['multiwoz_and_sgd_joint_vocabulary_path']
-VOCAB_DESC = 'multiwoz_and_sgd'
-# VOCAB_PATH = embeddings_config['data']['pretrained_cc_en_vocabulary_path']
-# VOCAB_DESC = 'pretrained_cc_en'
# Set up logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -108,12 +113,14 @@ models = {}
# Pooling: Mean Pooling
models['paraphrase-MiniLM-L6-v2'] = SentenceTransformer('paraphrase-MiniLM-L6-v2')
+# more models could be added here as additional keys,
+# the embeddings will be computed for each model
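# For example (hypothetical; any sentence-transformers checkpoint works):
# models['all-mpnet-base-v2'] = SentenceTransformer('all-mpnet-base-v2')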
# Load data
logging.info('Loading data ...')
-with open(VOCAB_PATH, "r") as file:
+with open(vocab_path, "r") as file:
vocabulary = json.load(file)
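# NOTE: the vocabulary JSON is assumed to hold the list of word strings to embed.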
# Testing tokenization of certain sentences
@@ -157,7 +164,7 @@ if SAVE_EMBEDDINGS is True and LOAD_EMBEDDINGS is False:
# save or load embeddings
try:
    embeddings_path = os.path.join(embeddings_config['embeddings']['embeddings_dict_path'],
-                                  f"{VOCAB_DESC}_vocab_embeddings_sbert.pkl")
+                                  f"{vocab_desc}_vocab_embeddings_sbert.pkl")
    if SAVE_EMBEDDINGS is True:
        logging.info("Saving embeddings ...")
        logging.info(embeddings_path)
@@ -179,12 +186,16 @@
test_vocab = ['.', ':',
              'cheap', 'expensive',
              'xg1tx6q8', '07591624763']
+try:
+    # Print the embeddings
+    for word in test_vocab:
+        embedding = embeddings['paraphrase-MiniLM-L6-v2'][word]
+        print("Sentence:", word)
+        print("Embedding:", embedding)
+        print("")
+except Exception as e:
+    logging.error(traceback.format_exc())
+    pass
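# Not every probe token exists in every vocabulary: the log excerpt below shows a
# KeyError for the random string 'xg1tx6q8', which the except block logs and skips.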
embedding_vectors = {}
......
2022-06-15 12:27:38,561 - root - INFO - Loading config file: ./sbert_static_embeddings_config_50_0.yaml
2022-06-15 12:27:38,561 - root - INFO - {'data': {'data_folder_path': '../data', 'multiwoz_and_sgd_joint_vocabulary_path': '../data/multiwoz_and_sgd_joint_vocabulary.json', 'pretrained_cc_en_vocabulary_path': '../data/pretrained_cc_en_vocabulary.json'}, 'embeddings': {'embeddings_dict_path': '../data', 'embeddings_dataframes_path': '../data', 'context': 'word', 'pooling_method': 'mean', 'special_tokens': 'ignore'}, 'neighborhoods': {'nbhd_size': 50, 'nbhd_remove': 0, 'neighborhoods_path': '../data/neighborhoods', 'persistence_features_path': '../data', 'normalize': False}}
2022-06-15 12:27:38,561 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2
2022-06-15 12:27:45,634 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu
2022-06-15 12:27:45,634 - root - INFO - Loading data ...
2022-06-15 12:27:45,636 - root - INFO - Creating embeddings...
2022-06-15 14:50:27,346 - root - INFO - Loading config file: ./sbert_static_embeddings_config_50_0.yaml
2022-06-15 14:50:27,347 - root - INFO - {'data': {'data_folder_path': '../data', 'multiwoz_and_sgd_vocabulary_path': '../data/multiwoz_and_sgd_joint_vocabulary.json', 'pretrained_cc_en_vocabulary_path': '../data/pretrained_cc_en_vocabulary.json'}, 'embeddings': {'embeddings_dict_path': '../data', 'embeddings_dataframes_path': '../data', 'context': 'word', 'pooling_method': 'mean', 'special_tokens': 'ignore'}, 'neighborhoods': {'nbhd_size': 50, 'nbhd_remove': 0, 'neighborhoods_path': '../data/neighborhoods', 'persistence_features_path': '../data', 'normalize': False}}
2022-06-15 14:50:27,347 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2
2022-06-15 14:50:34,329 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu
2022-06-15 14:50:34,329 - root - INFO - Loading data ...
2022-06-15 14:50:34,469 - root - INFO - Creating embeddings...
2022-06-15 15:12:27,980 - root - INFO - Loading config file: ./sbert_static_embeddings_config_50_0.yaml
2022-06-15 15:12:27,980 - root - INFO - {'data': {'data_folder_path': '../data', 'multiwoz_and_sgd_vocabulary_path': '../data/multiwoz_and_sgd_joint_vocabulary.json', 'pretrained_cc_en_vocabulary_path': '../data/pretrained_cc_en_vocabulary.json'}, 'embeddings': {'embeddings_dict_path': '../data', 'embeddings_dataframes_path': '../data', 'context': 'word', 'pooling_method': 'mean', 'special_tokens': 'ignore'}, 'neighborhoods': {'nbhd_size': 50, 'nbhd_remove': 0, 'neighborhoods_path': '../data/neighborhoods', 'persistence_features_path': '../data', 'normalize': False}}
2022-06-15 15:12:27,980 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2
2022-06-15 15:12:34,965 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu
2022-06-15 15:12:34,965 - root - INFO - Loading data ...
2022-06-15 15:12:35,103 - root - INFO - Loading embeddings ...
2022-06-15 15:12:43,604 - root - INFO - Loading embeddings done.
2022-06-15 15:13:47,810 - root - INFO - Loading config file: ./sbert_static_embeddings_config_50_0.yaml
2022-06-15 15:13:47,810 - root - INFO - {'data': {'data_folder_path': '../data', 'multiwoz_and_sgd_vocabulary_path': '../data/multiwoz_and_sgd_joint_vocabulary.json', 'pretrained_cc_en_vocabulary_path': '../data/pretrained_cc_en_vocabulary.json'}, 'embeddings': {'embeddings_dict_path': '../data', 'embeddings_dataframes_path': '../data', 'context': 'word', 'pooling_method': 'mean', 'special_tokens': 'ignore'}, 'neighborhoods': {'nbhd_size': 50, 'nbhd_remove': 0, 'neighborhoods_path': '../data/neighborhoods', 'persistence_features_path': '../data', 'normalize': False}}
2022-06-15 15:13:47,810 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2
2022-06-15 15:13:54,786 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu
2022-06-15 15:13:54,786 - root - INFO - Loading data ...
2022-06-15 15:13:54,911 - root - INFO - Loading embeddings ...
2022-06-15 15:14:02,418 - root - INFO - Loading embeddings done.
2022-06-15 15:14:02,438 - root - ERROR - Traceback (most recent call last):
File "sbert_create_static_embeddings.py", line 190, in <module>
embedding = embeddings['paraphrase-MiniLM-L6-v2'][word]
KeyError: 'xg1tx6q8'
2022-06-15 15:15:33,524 - root - INFO - Loading config file: ./sbert_static_embeddings_config_50_0.yaml
2022-06-15 15:15:33,524 - root - INFO - {'data': {'data_folder_path': '../data', 'multiwoz_and_sgd_vocabulary_path': '../data/multiwoz_and_sgd_joint_vocabulary.json', 'pretrained_cc_en_vocabulary_path': '../data/pretrained_cc_en_vocabulary.json'}, 'embeddings': {'embeddings_dict_path': '../data', 'embeddings_dataframes_path': '../data', 'context': 'word', 'pooling_method': 'mean', 'special_tokens': 'ignore'}, 'neighborhoods': {'nbhd_size': 50, 'nbhd_remove': 0, 'neighborhoods_path': '../data/neighborhoods', 'persistence_features_path': '../data', 'normalize': False}}
2022-06-15 15:15:33,524 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2
2022-06-15 15:15:40,440 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu
2022-06-15 15:15:40,440 - root - INFO - Loading data ...
2022-06-15 15:15:40,565 - root - INFO - Loading embeddings ...
2022-06-15 15:15:48,250 - root - INFO - Loading embeddings done.
2022-06-15 15:15:48,273 - root - ERROR - Traceback (most recent call last):
File "sbert_create_static_embeddings.py", line 190, in <module>
embedding = embeddings['paraphrase-MiniLM-L6-v2'][word]
KeyError: 'xg1tx6q8'
2022-06-15 15:56:59,026 - root - INFO - Loading config file: ./sbert_static_embeddings_config_50_0.yaml
2022-06-15 15:56:59,026 - root - INFO - {'data': {'data_folder_path': '../data', 'multiwoz_and_sgd_vocabulary_path': '../data/multiwoz_and_sgd_joint_vocabulary.json', 'pretrained_cc_en_vocabulary_path': '../data/pretrained_cc_en_vocabulary.json'}, 'embeddings': {'embeddings_dict_path': '../data', 'embeddings_dataframes_path': '../data', 'context': 'word', 'pooling_method': 'mean', 'special_tokens': 'ignore'}, 'neighborhoods': {'nbhd_size': 50, 'nbhd_remove': 0, 'neighborhoods_path': '../data/neighborhoods', 'persistence_features_path': '../data', 'normalize': False}}
2022-06-15 15:56:59,026 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2
2022-06-15 15:57:05,865 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu
2022-06-15 15:57:05,865 - root - INFO - Loading data ...
2022-06-15 15:57:05,868 - root - INFO - Creating embeddings...
sbert_static_embeddings_config_50_0.yaml
data:
  data_folder_path: '../data'
-  multiwoz_and_sgd_joint_vocabulary_path: '../data/multiwoz_and_sgd_joint_vocabulary.json'
+  multiwoz_and_sgd_vocabulary_path: '../data/multiwoz_and_sgd_joint_vocabulary.json'
  pretrained_cc_en_vocabulary_path: '../data/pretrained_cc_en_vocabulary.json'
embeddings:
  embeddings_dict_path: '../data'
......