From dafb50a27ee7fc369f8445c6e3343381763e789c Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Sun, 17 Jul 2022 11:46:23 +0800
Subject: [PATCH] update lmloss2keywords to reduce memory cost

---
 .../gpt/keyword_extraction/get_keywords.sh    | 34 ++++++++++---------
 .../gpt/keyword_extraction/get_token_loss.sh  |  2 +-
 .../gpt/keyword_extraction/lmloss2keywords.py | 31 ++++++++---------
 3 files changed, 33 insertions(+), 34 deletions(-)

diff --git a/convlab/base_models/gpt/keyword_extraction/get_keywords.sh b/convlab/base_models/gpt/keyword_extraction/get_keywords.sh
index c060c9a0..8f5d2809 100644
--- a/convlab/base_models/gpt/keyword_extraction/get_keywords.sh
+++ b/convlab/base_models/gpt/keyword_extraction/get_keywords.sh
@@ -1,23 +1,25 @@
 task_name="lm"
-dataset_name=$1
 model_type="gpt"
-data_dir="data/${task_name}/${dataset_name}/${model_type}"
-model_name_or_path="gpt2-large"
+model_name_or_path="/data/zhuqi/pre-trained-models/gpt2-large"
 keywords_num=100
-keywords_ratio=0.4
+keywords_ratio=0.3
 keywords_th_ratio=0
 stopwords=True
-for data_split in validation test train
+for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd reddit wikidialog
 do
-    word_loss_file="${data_dir}/${model_name_or_path}_${dataset_name}_${data_split}_word_loss.json"
-    output_file="${data_dir}/${dataset_name}_${data_split}_keywords_${model_name_or_path}_topk_${keywords_num}_ratio_${keywords_ratio}_th_${keywords_th_ratio}_stopwords_${stopwords}.json"
-
-    python lmloss2keywords.py \
-    --model_type ${model_type} \
-    --word_loss_file ${word_loss_file} \
-    --keywords_num ${keywords_num} \
-    --keywords_ratio ${keywords_ratio} \
-    --keywords_th_ratio ${keywords_th_ratio} \
-    --stopwords ${stopwords} \
-    --output_file ${output_file}
+    data_dir="data/${task_name}/${model_type}/${dataset_name}"
+    for data_split in validation train
+    do
+        token_loss_file="${data_dir}/token_loss_${data_split}.json"
+        output_file="${data_dir}/keywords_${data_split}.json"
+        python lmloss2keywords.py \
+        --model_type ${model_type} \
+        --model_name_or_path ${model_name_or_path} \
+        --token_loss_file ${token_loss_file} \
+        --keywords_num ${keywords_num} \
+        --keywords_ratio ${keywords_ratio} \
+        --keywords_th_ratio ${keywords_th_ratio} \
+        --stopwords ${stopwords} \
+        --output_file ${output_file}
+    done
 done
\ No newline at end of file
diff --git a/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh b/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh
index 4c73a3b0..7c2b57da 100644
--- a/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh
+++ b/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh
@@ -5,7 +5,7 @@ model_type="gpt"
 cache_dir="../cache"
 source_column="dialogue"
 max_length=512
-model_name_or_path="/data/zhuqi/pre-trained-models/gpt2-large"
+model_name_or_path="gpt2-large"
 per_device_eval_batch_size=16
 
 for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd reddit wikidialog
diff --git a/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py b/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
index 01581d2d..cbc96798 100644
--- a/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
+++ b/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
@@ -51,35 +51,28 @@ def convert_token_loss2word_loss(token_loss_file):
     word_loss_file = os.path.join(os.path.dirname(token_loss_file), token_loss_file.split('/')[-1].replace('token', 'word'))
     fin = open(token_loss_file, 'rb')
     fout = open(word_loss_file, 'w', encoding='utf-8')
-    lines = []
 
     for item in tqdm(json_lines.reader(fin)):
         tokens, losses = item['tokens'], item['losses']
         assert len(tokens) == len(losses)
         word2losses = merge_tokens(tokens, losses)
-        lines.append({"words": [x[0] for x in word2losses], "losses": [x[1] for x in word2losses]})
-        fout.write(json.dumps(lines[-1], ensure_ascii=False)+'\n')
+        fout.write(json.dumps({"words": [x[0] for x in word2losses], "losses": [x[1] for x in word2losses]}, ensure_ascii=False)+'\n')
 
     fin.close()
     fout.close()
-    return lines
+    return word_loss_file
 
 def main(args):
     if not args.word_loss_file:
-        word_loss_list = convert_token_loss2word_loss(args.token_loss_file)
+        word_loss_file = convert_token_loss2word_loss(args.token_loss_file)
     else:
-        fin = open(args.word_loss_file, 'rb')
-        word_loss_list = []
-        for item in json_lines.reader(fin):
-            words, losses = item['words'], item['losses']
-            word_loss_list.append({"words": words, "losses": losses})
-        fin.close()
+        word_loss_file = args.word_loss_file
 
     if not args.output_file:
         return
 
     stop_words = set(stopwords.words('english'))
-    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
+    tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
     sent_tokenizer = PunktSentenceTokenizer()
 
     if args.keywords_th_ratio > 0:
@@ -138,8 +131,10 @@ def main(args):
 
         return keywords, keywords_turn_sent2idx
 
-    dialogs = []
-    for item in tqdm(word_loss_list):
+    fin = open(word_loss_file, 'rb')
+    fout = open(args.output_file, 'w', encoding='utf-8')
+
+    for item in json_lines.reader(fin):
         words = [tokenizer.convert_tokens_to_string(tokens) for tokens in item['words']]
         losses = [np.mean(loss) for loss in item['losses']]
         dialog_keywords, keywords_turn_sent2idx = keywords_filter(words, losses)
@@ -163,15 +158,17 @@ def main(args):
                 turns.append(turn)
                 turn = {'words': [], 'losses': []}
 
-        dialogs.append(turns)
-    
-    json.dump(dialogs, open(args.output_file, "w", encoding='utf-8'), indent=2, ensure_ascii=False)
-    
+        fout.write(json.dumps(turns, ensure_ascii=False)+'\n')
+    
+    fin.close()
+    fout.close()
+
 
 if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser(description="extract keywords according to lm loss")
     parser.add_argument('--model_type', '-m', type=str, help='gpt or dialogpt')
+    parser.add_argument('--model_name_or_path', type=str, help='model name or path')
     parser.add_argument('--token_loss_file', '-t', type=str, help='path to the token loss file that contains two columns: [tokens, losses]')
     parser.add_argument('--word_loss_file', '-w', type=str, help='path to the token loss file that contains two columns: [tokens, losses]')
     parser.add_argument('--output_file', '-o', type=str, help='path to the output file')
-- 
GitLab
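
Editor's note: the memory saving in this patch comes from switching every
stage to a streaming JSON Lines pattern. Each dialog is read, converted,
written, and dropped, so peak memory is bounded by the largest single
dialog rather than the whole corpus, and functions hand back file paths
instead of parsed lists. Below is a minimal self-contained sketch of that
pattern, not the repository's actual code: the function and file names are
hypothetical, and where the real script reads with json_lines.reader and
pools sub-word losses via merge_tokens, the sketch uses plain json.loads
and a toy pass-through conversion.

    import json

    def stream_token2word_loss(token_loss_file, word_loss_file):
        """Rewrite a token-loss JSON Lines file as a word-loss one,
        one dialog per line, without buffering the corpus in memory."""
        with open(token_loss_file, encoding='utf-8') as fin, \
             open(word_loss_file, 'w', encoding='utf-8') as fout:
            for line in fin:  # one JSON object (one dialog) per line
                item = json.loads(line)
                # Toy conversion: keep tokens and losses as-is. The real
                # script merges sub-word token losses into word losses here.
                record = {"words": item["tokens"], "losses": item["losses"]}
                fout.write(json.dumps(record, ensure_ascii=False) + '\n')
        return word_loss_file  # return the path, not the data

Returning the path rather than an in-memory list is what lets main()
re-open the file and stream it a second time for keyword filtering; the
trade-off is one extra pass over disk in exchange for roughly constant
memory, which is why the old "lines" list and the final json.dump of all
dialogs are removed in this patch.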