diff --git a/tda/README.md b/tda/README.md index b6db66cf442b4f7a87a601ca3bd5f3eac1b7c552..f4020eb0db8e7cda367d00702a9276e061068b2c 100644 --- a/tda/README.md +++ b/tda/README.md @@ -1,17 +1,39 @@ # TDA code for Dialogue Term Extraction using Transfer Learning and Topological Data Analysis This is the Topological Data Analysis portion of the code for the paper -Dialogue Term Extraction using Transfer Learning and Topological Data Analysis. +'Dialogue Term Extraction using Transfer Learning and Topological Data Analysis'. + +The scripts in this folder should be executed in the `tda` working directory. ## Create embeddings -TODO +Precomputed sbert embeddings are contained in the `/data` folder +for the ambient fastText vocabulary, and the joint multiwoz and sgd vocabulary. +These embeddings are the basis for computing neighborhoods. +It is not necessary to recompute these embeddings; +for the neighborhood extraction and TDA features, skip ahead to the next section. + +The following command loads the precomputed embeddings +of the fastText vocabulary into an interactive python session: +```bash +python -i sbert_create_static_embeddings.py \ + --embeddings_config_path ./sbert_static_embeddings_config_50_0.yaml \ + --vocab_desc pretrained_cc_en \ + --load_embeddings +``` +To compute and save embeddings of the multiwoz and sgd vocabulary: +```bash +python sbert_create_static_embeddings.py \ + --embeddings_config_path ./sbert_static_embeddings_config_50_0.yaml \ + --vocab_desc multiwoz_and_sgd \ + --save_embeddings +``` ## Build neighbourhoods and extract persistence features TODO -## 8. License +## License This project is licensed under the Apache License, Version 2.0 (the "License"); you may not use the files except in compliance with the License. 
You may obtain a copy of the License at diff --git a/tda/sbert_create_static_embeddings.py b/tda/sbert_create_static_embeddings.py index 2cb2873844f3d769dbcaac46b87a3e9314c5d055..1995e0c128c04961d3f2e292387f8c554224f1f5 100644 --- a/tda/sbert_create_static_embeddings.py +++ b/tda/sbert_create_static_embeddings.py @@ -49,6 +49,13 @@ parser = argparse.ArgumentParser() # Required parameters parser.add_argument("--embeddings_config_path", default=None, type=str, required=True, help="Dataset configuration file.") +parser.add_argument("--vocab_desc", choices=["multiwoz_and_sgd", "pretrained_cc_en"], + default=None, type=str, required=True, + help="String describing which embeddings should be created.") +parser.add_argument("--save_embeddings", default=False, action="store_true", + help="Flag to save embeddings to disk.") +parser.add_argument("--load_embeddings", default=False, action="store_true", + help="Flag to load embeddings from disk.") # Optional parameters # None @@ -59,22 +66,20 @@ args = parser.parse_args() embeddings_config_path = args.embeddings_config_path embeddings_config = yaml.safe_load(open(embeddings_config_path)) -# # # # # # # # # # # # # # # # # -# MODIFY: Program parameters -DEVELOP_MODE = False +vocab_desc = args.vocab_desc +# the fields in the config file are named +# 'pretrained_cc_en_vocabulary_path' and 'multiwoz_and_sgd_vocabulary_path' +vocab_path = embeddings_config['data'][vocab_desc + '_vocabulary_path'] -SAVE_EMBEDDINGS = True -LOAD_EMBEDDINGS = False +SAVE_EMBEDDINGS = args.save_embeddings +LOAD_EMBEDDINGS = args.load_embeddings data_desc_list = ['paraphrase-MiniLM-L6-v2'] +# # # # # # # # # # # # # # # # # +# Debug parameters DEBUG_DATA_DESC = 'paraphrase-MiniLM-L6-v2' -VOCAB_PATH = embeddings_config['data']['multiwoz_and_sgd_joint_vocabulary_path'] -VOCAB_DESC = 'multiwoz_and_sgd' -# VOCAB_PATH = embeddings_config['data']['pretrained_cc_en_vocabulary_path'] -# VOCAB_DESC = 'pretrained_cc_en' - # Set up logging logger = 
logging.getLogger() logger.setLevel(logging.INFO) @@ -108,12 +113,14 @@ models = {} # Pooling: Mean Pooling models['paraphrase-MiniLM-L6-v2'] = SentenceTransformer('paraphrase-MiniLM-L6-v2') +# more models could be added here as additional keys, +# the embeddings will be computed for each model # Load data logging.info('Loading data ...') -with open(VOCAB_PATH, "r") as file: +with open(vocab_path, "r") as file: vocabulary = json.load(file) # Testing tokenization of certain sentences @@ -157,7 +164,7 @@ if SAVE_EMBEDDINGS is True and LOAD_EMBEDDINGS is False: # save or load embeddings try: embeddings_path = os.path.join(embeddings_config['embeddings']['embeddings_dict_path'], - f"{VOCAB_DESC}_vocab_embeddings_sbert.pkl") + f"{vocab_desc}_vocab_embeddings_sbert.pkl") if SAVE_EMBEDDINGS is True: logging.info("Saving embeddings ...") logging.info(embeddings_path) @@ -179,12 +186,16 @@ test_vocab = ['.', ':', 'cheap', 'expensive', 'xg1tx6q8', '07591624763'] -# Print the embeddings -for word in test_vocab: - embedding = embeddings['paraphrase-MiniLM-L6-v2'][word] - print("Sentence:", word) - print("Embedding:", embedding) - print("") +try: + # Print the embeddings + for word in test_vocab: + embedding = embeddings['paraphrase-MiniLM-L6-v2'][word] + print("Sentence:", word) + print("Embedding:", embedding) + print("") +except Exception as e: + logging.error(traceback.format_exc()) + pass embedding_vectors = {} diff --git a/tda/sbert_create_static_embeddings_50_0.log b/tda/sbert_create_static_embeddings_50_0.log new file mode 100644 index 0000000000000000000000000000000000000000..e1cf5e2acc11b083f484cb0be170779f243b5720 --- /dev/null +++ b/tda/sbert_create_static_embeddings_50_0.log @@ -0,0 +1,49 @@ +2022-06-15 12:27:38,561 - root - INFO - Loading config file: ./sbert_static_embeddings_config_50_0.yaml +2022-06-15 12:27:38,561 - root - INFO - {'data': {'data_folder_path': '../data', 'multiwoz_and_sgd_joint_vocabulary_path': '../data/multiwoz_and_sgd_joint_vocabulary.json', 
'pretrained_cc_en_vocabulary_path': '../data/pretrained_cc_en_vocabulary.json'}, 'embeddings': {'embeddings_dict_path': '../data', 'embeddings_dataframes_path': '../data', 'context': 'word', 'pooling_method': 'mean', 'special_tokens': 'ignore'}, 'neighborhoods': {'nbhd_size': 50, 'nbhd_remove': 0, 'neighborhoods_path': '../data/neighborhoods', 'persistence_features_path': '../data', 'normalize': False}} +2022-06-15 12:27:38,561 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2 +2022-06-15 12:27:45,634 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu +2022-06-15 12:27:45,634 - root - INFO - Loading data ... +2022-06-15 12:27:45,636 - root - INFO - Creating embeddings... +2022-06-15 14:50:27,346 - root - INFO - Loading config file: ./sbert_static_embeddings_config_50_0.yaml +2022-06-15 14:50:27,347 - root - INFO - {'data': {'data_folder_path': '../data', 'multiwoz_and_sgd_vocabulary_path': '../data/multiwoz_and_sgd_joint_vocabulary.json', 'pretrained_cc_en_vocabulary_path': '../data/pretrained_cc_en_vocabulary.json'}, 'embeddings': {'embeddings_dict_path': '../data', 'embeddings_dataframes_path': '../data', 'context': 'word', 'pooling_method': 'mean', 'special_tokens': 'ignore'}, 'neighborhoods': {'nbhd_size': 50, 'nbhd_remove': 0, 'neighborhoods_path': '../data/neighborhoods', 'persistence_features_path': '../data', 'normalize': False}} +2022-06-15 14:50:27,347 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2 +2022-06-15 14:50:34,329 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu +2022-06-15 14:50:34,329 - root - INFO - Loading data ... +2022-06-15 14:50:34,469 - root - INFO - Creating embeddings... 
+2022-06-15 15:12:27,980 - root - INFO - Loading config file: ./sbert_static_embeddings_config_50_0.yaml +2022-06-15 15:12:27,980 - root - INFO - {'data': {'data_folder_path': '../data', 'multiwoz_and_sgd_vocabulary_path': '../data/multiwoz_and_sgd_joint_vocabulary.json', 'pretrained_cc_en_vocabulary_path': '../data/pretrained_cc_en_vocabulary.json'}, 'embeddings': {'embeddings_dict_path': '../data', 'embeddings_dataframes_path': '../data', 'context': 'word', 'pooling_method': 'mean', 'special_tokens': 'ignore'}, 'neighborhoods': {'nbhd_size': 50, 'nbhd_remove': 0, 'neighborhoods_path': '../data/neighborhoods', 'persistence_features_path': '../data', 'normalize': False}} +2022-06-15 15:12:27,980 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2 +2022-06-15 15:12:34,965 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu +2022-06-15 15:12:34,965 - root - INFO - Loading data ... +2022-06-15 15:12:35,103 - root - INFO - Loading embeddings ... +2022-06-15 15:12:43,604 - root - INFO - Loading embeddings done. 
+2022-06-15 15:13:47,810 - root - INFO - Loading config file: ./sbert_static_embeddings_config_50_0.yaml +2022-06-15 15:13:47,810 - root - INFO - {'data': {'data_folder_path': '../data', 'multiwoz_and_sgd_vocabulary_path': '../data/multiwoz_and_sgd_joint_vocabulary.json', 'pretrained_cc_en_vocabulary_path': '../data/pretrained_cc_en_vocabulary.json'}, 'embeddings': {'embeddings_dict_path': '../data', 'embeddings_dataframes_path': '../data', 'context': 'word', 'pooling_method': 'mean', 'special_tokens': 'ignore'}, 'neighborhoods': {'nbhd_size': 50, 'nbhd_remove': 0, 'neighborhoods_path': '../data/neighborhoods', 'persistence_features_path': '../data', 'normalize': False}} +2022-06-15 15:13:47,810 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2 +2022-06-15 15:13:54,786 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu +2022-06-15 15:13:54,786 - root - INFO - Loading data ... +2022-06-15 15:13:54,911 - root - INFO - Loading embeddings ... +2022-06-15 15:14:02,418 - root - INFO - Loading embeddings done. 
+2022-06-15 15:14:02,438 - root - ERROR - Traceback (most recent call last): + File "sbert_create_static_embeddings.py", line 190, in <module> + embedding = embeddings['paraphrase-MiniLM-L6-v2'][word] +KeyError: 'xg1tx6q8' + +2022-06-15 15:15:33,524 - root - INFO - Loading config file: ./sbert_static_embeddings_config_50_0.yaml +2022-06-15 15:15:33,524 - root - INFO - {'data': {'data_folder_path': '../data', 'multiwoz_and_sgd_vocabulary_path': '../data/multiwoz_and_sgd_joint_vocabulary.json', 'pretrained_cc_en_vocabulary_path': '../data/pretrained_cc_en_vocabulary.json'}, 'embeddings': {'embeddings_dict_path': '../data', 'embeddings_dataframes_path': '../data', 'context': 'word', 'pooling_method': 'mean', 'special_tokens': 'ignore'}, 'neighborhoods': {'nbhd_size': 50, 'nbhd_remove': 0, 'neighborhoods_path': '../data/neighborhoods', 'persistence_features_path': '../data', 'normalize': False}} +2022-06-15 15:15:33,524 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2 +2022-06-15 15:15:40,440 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu +2022-06-15 15:15:40,440 - root - INFO - Loading data ... +2022-06-15 15:15:40,565 - root - INFO - Loading embeddings ... +2022-06-15 15:15:48,250 - root - INFO - Loading embeddings done. 
+2022-06-15 15:15:48,273 - root - ERROR - Traceback (most recent call last): + File "sbert_create_static_embeddings.py", line 190, in <module> + embedding = embeddings['paraphrase-MiniLM-L6-v2'][word] +KeyError: 'xg1tx6q8' + +2022-06-15 15:56:59,026 - root - INFO - Loading config file: ./sbert_static_embeddings_config_50_0.yaml +2022-06-15 15:56:59,026 - root - INFO - {'data': {'data_folder_path': '../data', 'multiwoz_and_sgd_vocabulary_path': '../data/multiwoz_and_sgd_joint_vocabulary.json', 'pretrained_cc_en_vocabulary_path': '../data/pretrained_cc_en_vocabulary.json'}, 'embeddings': {'embeddings_dict_path': '../data', 'embeddings_dataframes_path': '../data', 'context': 'word', 'pooling_method': 'mean', 'special_tokens': 'ignore'}, 'neighborhoods': {'nbhd_size': 50, 'nbhd_remove': 0, 'neighborhoods_path': '../data/neighborhoods', 'persistence_features_path': '../data', 'normalize': False}} +2022-06-15 15:56:59,026 - sentence_transformers.SentenceTransformer - INFO - Load pretrained SentenceTransformer: paraphrase-MiniLM-L6-v2 +2022-06-15 15:57:05,865 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device: cpu +2022-06-15 15:57:05,865 - root - INFO - Loading data ... +2022-06-15 15:57:05,868 - root - INFO - Creating embeddings... diff --git a/tda/sbert_static_embeddings_config_50_0.yaml b/tda/sbert_static_embeddings_config_50_0.yaml index ba3965618a9967aa77f95b76e651630be7d82e21..612999d5cc38a6d03b71783d1c6a0a02a334c32c 100644 --- a/tda/sbert_static_embeddings_config_50_0.yaml +++ b/tda/sbert_static_embeddings_config_50_0.yaml @@ -1,6 +1,6 @@ data: data_folder_path: '../data' - multiwoz_and_sgd_joint_vocabulary_path: '../data/multiwoz_and_sgd_joint_vocabulary.json' + multiwoz_and_sgd_vocabulary_path: '../data/multiwoz_and_sgd_joint_vocabulary.json' pretrained_cc_en_vocabulary_path: '../data/pretrained_cc_en_vocabulary.json' embeddings: embeddings_dict_path: '../data'