Commit 1771c52b authored by zqwerty

rm keywords_th_ratio, add keywords_loss_th args for keyword extraction

parent 7570883d
@@ -3,7 +3,7 @@ model_type="gpt"
 model_name_or_path="gpt2-large"
 keywords_num=100
 keywords_ratio=0.3
-keywords_th_ratio=0
+keywords_loss_th=0
 stopwords=True
 for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd reddit wikidialog
 do
@@ -18,7 +18,7 @@ do
         --token_loss_file ${token_loss_file} \
         --keywords_num ${keywords_num} \
         --keywords_ratio ${keywords_ratio} \
-        --keywords_th_ratio ${keywords_th_ratio} \
+        --keywords_loss_th ${keywords_loss_th} \
         --stopwords ${stopwords} \
         --output_file ${output_file}
 done
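Since keywords_th_ratio is gone, a caller that relied on the old "keep the top X% of word losses" behavior now has to derive an absolute value for --keywords_loss_th up front. A hypothetical helper for doing so (loss_th_for_top_ratio is not part of this commit; the JSON-lines token-loss format with 'words'/'losses' fields is assumed from the removed Python block, and the stopword pre-filtering that block applied is omitted here):

import json

def loss_th_for_top_ratio(token_loss_file, ratio):
    """Return an absolute loss threshold such that roughly `ratio` of all
    word losses lie above it, mimicking the removed percentile behavior."""
    losses = []
    with open(token_loss_file) as f:
        for line in f:
            x = json.loads(line)  # assumed format: {'words': [...], 'losses': [...]}
            losses.extend(x['losses'])
    losses.sort(reverse=True)
    # clamp the index so ratio=1.0 does not run past the end of the list
    return losses[min(round(ratio * len(losses)), len(losses) - 1)]

# e.g. pass the result to the script above as: --keywords_loss_th ${loss_th}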
@@ -75,13 +75,6 @@ def main(args):
     tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
     sent_tokenizer = PunktSentenceTokenizer()
 
-    if args.keywords_th_ratio > 0:
-        losses = [loss for x in word_loss_list for word, loss in zip(x['words'], x['losses']) if not any([w.lower() in stop_words for w in word_tokenize(word)])]
-        loss_th = sorted(losses, reverse=True)[round(args.keywords_th_ratio*len(losses))]
-        print(f'loss th for top {args.keywords_th_ratio*100}%: {loss_th}')
-    else:
-        loss_th = 0
-
     def keywords_filter(words, losses):
         word_loss_pairs = list(zip(words, losses))
         index2keyword = {}
@@ -99,7 +92,7 @@ def main(args):
             if args.stopwords and any([w.lower() in stop_words for w in words]):
                 # skip stopwords
                 continue
-            if word_loss_pair[1] <= loss_th:
+            if word_loss_pair[1] <= args.keywords_loss_th:
                 # skip if loss is too small
                 continue
             # strip punctuation
@@ -174,7 +167,7 @@ if __name__ == '__main__':
     parser.add_argument('--output_file', '-o', type=str, help='path to the output file')
     parser.add_argument('--keywords_num', '-n', type=int, default=100, help='how many words in an utterance serve as keywords')
     parser.add_argument('--keywords_ratio', '-r', type=float, default=1.0, help='how many words (in ratio) in an utterance serve as keywords')
-    parser.add_argument('--keywords_th_ratio', '-th', type=float, default=0., help='loss threshold for the keywords, ratio of all word losses')
+    parser.add_argument('--keywords_loss_th', '-th', type=float, default=0., help='loss threshold for the keywords')
     parser.add_argument('--stopwords', '-s', type=lambda x: bool(eval(x)), default=True, help='filter out stopwords')
     args = parser.parse_args()
     print(args)
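Net effect of the filtering change: a data-dependent cutoff is swapped for a fixed one. Previously loss_th was computed from the distribution of word losses; now it is whatever absolute value the caller supplies, and the default of 0 skips only words with non-positive loss. A minimal sketch of the new check (keep_word is an illustrative name, not from the code):

def keep_word(loss, keywords_loss_th=0.0):
    # mirrors `if word_loss_pair[1] <= args.keywords_loss_th: continue`
    return loss > keywords_loss_th

print(keep_word(2.5, keywords_loss_th=1.0))  # True: high-loss word kept as keyword
print(keep_word(0.4, keywords_loss_th=1.0))  # False: low-loss word skipped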