From 1771c52b6d33717fafbc692202c0c68a9121aef4 Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Sun, 17 Jul 2022 12:06:26 +0800
Subject: [PATCH] rm keywords_th_ratio, add keywords_loss_th args for keyword
 extraction

---
 .../gpt/keyword_extraction/get_keywords.sh    |  4 ++--
 .../gpt/keyword_extraction/lmloss2keywords.py | 11 ++---------
 2 files changed, 4 insertions(+), 11 deletions(-)

diff --git a/convlab/base_models/gpt/keyword_extraction/get_keywords.sh b/convlab/base_models/gpt/keyword_extraction/get_keywords.sh
index 0533d3cb..d3051ba6 100644
--- a/convlab/base_models/gpt/keyword_extraction/get_keywords.sh
+++ b/convlab/base_models/gpt/keyword_extraction/get_keywords.sh
@@ -3,7 +3,7 @@ model_type="gpt"
 model_name_or_path="gpt2-large"
 keywords_num=100
 keywords_ratio=0.3
-keywords_th_ratio=0
+keywords_loss_th=0
 stopwords=True
 for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd reddit wikidialog
 do
@@ -18,7 +18,7 @@ do
         --token_loss_file ${token_loss_file} \
         --keywords_num ${keywords_num} \
         --keywords_ratio ${keywords_ratio} \
-        --keywords_th_ratio ${keywords_th_ratio} \
+        --keywords_loss_th ${keywords_loss_th} \
         --stopwords ${stopwords} \
         --output_file ${output_file}
 done
diff --git a/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py b/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
index cbc96798..bdd0f99e 100644
--- a/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
+++ b/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
@@ -75,13 +75,6 @@ def main(args):
     tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
     sent_tokenizer = PunktSentenceTokenizer()
 
-    if args.keywords_th_ratio > 0:
-        losses = [loss for x in word_loss_list for word, loss in zip(x['words'], x['losses']) if not any([w.lower() in stop_words for w in word_tokenize(word)])]
-        loss_th = sorted(losses, reverse=True)[round(args.keywords_th_ratio*len(losses))]
-        print(f'loss th for top {args.keywords_th_ratio*100}%: {loss_th}')
-    else:
-        loss_th = 0
-
     def keywords_filter(words, losses):
         word_loss_pairs = list(zip(words, losses))
         index2keyword = {}
@@ -99,7 +92,7 @@ def main(args):
             if args.stopwords and any([w.lower() in stop_words for w in words]):
                 # skip stopwords
                 continue
-            if word_loss_pair[1] <= loss_th:
+            if word_loss_pair[1] <= args.keywords_loss_th:
                 # skip if loss is too small
                 continue
             # strip punctuation
@@ -174,7 +167,7 @@ if __name__ == '__main__':
     parser.add_argument('--output_file', '-o', type=str, help='path to the output file')
     parser.add_argument('--keywords_num', '-n', type=int, default=100, help='how many words in an utterance serve as keywords')
    parser.add_argument('--keywords_ratio', '-r', type=float, default=1.0, help='how many words (in ratio) in an utterance serve as keywords')
-    parser.add_argument('--keywords_th_ratio', '-th', type=float, default=0., help='loss threshold for the keywords, ratio of all word losses')
+    parser.add_argument('--keywords_loss_th', '-th', type=float, default=0., help='loss threshold for the keywords')
     parser.add_argument('--stopwords', '-s', type=lambda x: bool(eval(x)), default=True, help='filter out stopwords')
     args = parser.parse_args()
     print(args)
-- 
GitLab
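
After this change the threshold is an absolute loss value passed straight through as args.keywords_loss_th, rather than being derived from a ratio over all word losses. For reference, a minimal direct invocation of lmloss2keywords.py with the renamed flag might look like the sketch below; it only shows the arguments visible in this patch, the file paths are placeholders, and other required arguments (e.g. model settings) are omitted:

    # illustrative sketch, not part of the patch; paths are placeholders
    python lmloss2keywords.py \
        --token_loss_file data/lm/gpt/dailydialog/token_loss.json \
        --keywords_num 100 \
        --keywords_ratio 0.3 \
        --keywords_loss_th 0 \
        --stopwords True \
        --output_file data/lm/gpt/dailydialog/keywords.json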