From 1771c52b6d33717fafbc692202c0c68a9121aef4 Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Sun, 17 Jul 2022 12:06:26 +0800
Subject: [PATCH] rm keywords_th_ratio, add keywords_loss_th args for keyword
 extraction

---
 .../gpt/keyword_extraction/get_keywords.sh            |  4 ++--
 .../gpt/keyword_extraction/lmloss2keywords.py         | 11 ++---------
 2 files changed, 4 insertions(+), 11 deletions(-)

diff --git a/convlab/base_models/gpt/keyword_extraction/get_keywords.sh b/convlab/base_models/gpt/keyword_extraction/get_keywords.sh
index 0533d3cb..d3051ba6 100644
--- a/convlab/base_models/gpt/keyword_extraction/get_keywords.sh
+++ b/convlab/base_models/gpt/keyword_extraction/get_keywords.sh
@@ -3,7 +3,7 @@ model_type="gpt"
 model_name_or_path="gpt2-large"
 keywords_num=100
 keywords_ratio=0.3
-keywords_th_ratio=0
+keywords_loss_th=0
 stopwords=True
 for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd reddit wikidialog
 do
@@ -18,7 +18,7 @@ do
             --token_loss_file ${token_loss_file} \
             --keywords_num ${keywords_num} \
             --keywords_ratio ${keywords_ratio} \
-            --keywords_th_ratio ${keywords_th_ratio} \
+            --keywords_loss_th ${keywords_loss_th} \
             --stopwords ${stopwords} \
             --output_file ${output_file}
     done
diff --git a/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py b/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
index cbc96798..bdd0f99e 100644
--- a/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
+++ b/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
@@ -75,13 +75,6 @@ def main(args):
     tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
     sent_tokenizer = PunktSentenceTokenizer()
 
-    if args.keywords_th_ratio > 0:
-        losses = [loss for x in word_loss_list for word, loss in zip(x['words'], x['losses']) if not any([w.lower() in stop_words for w in word_tokenize(word)])]
-        loss_th = sorted(losses, reverse=True)[round(args.keywords_th_ratio*len(losses))]
-        print(f'loss th for top {args.keywords_th_ratio*100}%: {loss_th}')
-    else:
-        loss_th = 0
-
     def keywords_filter(words, losses):
         word_loss_pairs = list(zip(words, losses))
         index2keyword = {}
@@ -99,7 +92,7 @@ def main(args):
             if args.stopwords and any([w.lower() in stop_words for w in words]):
                 # skip stopwords
                 continue
-            if word_loss_pair[1] <= loss_th:
+            if word_loss_pair[1] <= args.keywords_loss_th:
                 # skip if loss is too small
                 continue
             # strip punctuation
@@ -174,7 +167,7 @@ if __name__ == '__main__':
     parser.add_argument('--output_file', '-o', type=str, help='path to the output file')
     parser.add_argument('--keywords_num', '-n', type=int, default=100, help='how many words in an utterance serve as keywords')
     parser.add_argument('--keywords_ratio', '-r', type=float, default=1.0, help='how many words (in ratio) in an utterance serve as keywords')
-    parser.add_argument('--keywords_th_ratio', '-th', type=float, default=0., help='loss threshold for the keywords, ratio of all word losses')
+    parser.add_argument('--keywords_loss_th', '-th', type=float, default=0., help='loss threshold for the keywords')
     parser.add_argument('--stopwords', '-s', type=lambda x: bool(eval(x)), default=True, help='filter out stopwords')
     args = parser.parse_args()
     print(args)
-- 
GitLab