From dafb50a27ee7fc369f8445c6e3343381763e789c Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Sun, 17 Jul 2022 11:46:23 +0800
Subject: [PATCH] update lmloss2keywords to reduce memory cost

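convert_token_loss2word_loss() and main() used to collect every dialog in a
Python list and dump it with json.dump() at the end, so peak memory grew with
the dataset size. They now stream instead: each record is read with
json_lines.reader() and written back out as a single JSON line as soon as it
is processed, so only one dialog is held in memory at a time. As a side
effect, the keyword output file is now JSON Lines (one dialog per line)
rather than a single indented JSON array. get_keywords.sh is reworked to loop
over datasets and splits with per-dataset data directories and simplified
file names, and the tokenizer is loaded from the new --model_name_or_path
argument.

A minimal sketch of the streaming pattern (illustrative only, not part of the
patch; process_stream and transform are hypothetical names):

    import json
    import json_lines

    def process_stream(in_file, out_file, transform):
        # Read one JSON line at a time, transform it, and write it out
        # immediately, so memory use stays constant in the dataset size.
        with open(in_file, 'rb') as fin, open(out_file, 'w', encoding='utf-8') as fout:
            for item in json_lines.reader(fin):
                fout.write(json.dumps(transform(item), ensure_ascii=False) + '\n')
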
---
 .../gpt/keyword_extraction/get_keywords.sh    | 34 ++++++++++---------
 .../gpt/keyword_extraction/get_token_loss.sh  |  2 +-
 .../gpt/keyword_extraction/lmloss2keywords.py | 31 ++++++++---------
 3 files changed, 33 insertions(+), 34 deletions(-)

diff --git a/convlab/base_models/gpt/keyword_extraction/get_keywords.sh b/convlab/base_models/gpt/keyword_extraction/get_keywords.sh
index c060c9a0..8f5d2809 100644
--- a/convlab/base_models/gpt/keyword_extraction/get_keywords.sh
+++ b/convlab/base_models/gpt/keyword_extraction/get_keywords.sh
@@ -1,23 +1,25 @@
 task_name="lm"
-dataset_name=$1
 model_type="gpt"
-data_dir="data/${task_name}/${dataset_name}/${model_type}"
-model_name_or_path="gpt2-large"
+model_name_or_path="/data/zhuqi/pre-trained-models/gpt2-large"
 keywords_num=100
-keywords_ratio=0.4
+keywords_ratio=0.3
 keywords_th_ratio=0
 stopwords=True
-for data_split in validation test train
+for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd reddit wikidialog
 do
-    word_loss_file="${data_dir}/${model_name_or_path}_${dataset_name}_${data_split}_word_loss.json"
-    output_file="${data_dir}/${dataset_name}_${data_split}_keywords_${model_name_or_path}_topk_${keywords_num}_ratio_${keywords_ratio}_th_${keywords_th_ratio}_stopwords_${stopwords}.json"
-
-    python lmloss2keywords.py \
-        --model_type ${model_type} \
-        --word_loss_file ${word_loss_file} \
-        --keywords_num ${keywords_num} \
-        --keywords_ratio ${keywords_ratio} \
-        --keywords_th_ratio ${keywords_th_ratio} \
-        --stopwords ${stopwords} \
-        --output_file ${output_file}
+    data_dir="data/${task_name}/${model_type}/${dataset_name}"
+    for data_split in validation train
+    do
+        token_loss_file="${data_dir}/token_loss_${data_split}.json"
+        output_file="${data_dir}/keywords_${data_split}.json"
+        python lmloss2keywords.py \
+            --model_type ${model_type} \
+            --model_name_or_path ${model_name_or_path} \
+            --token_loss_file ${token_loss_file} \
+            --keywords_num ${keywords_num} \
+            --keywords_ratio ${keywords_ratio} \
+            --keywords_th_ratio ${keywords_th_ratio} \
+            --stopwords ${stopwords} \
+            --output_file ${output_file}
+    done
 done
\ No newline at end of file
diff --git a/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh b/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh
index 4c73a3b0..7c2b57da 100644
--- a/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh
+++ b/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh
@@ -5,7 +5,7 @@ model_type="gpt"
 cache_dir="../cache"
 source_column="dialogue"
 max_length=512
-model_name_or_path="/data/zhuqi/pre-trained-models/gpt2-large"
+model_name_or_path="gpt2-large"
 per_device_eval_batch_size=16
 
 for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd reddit wikidialog
diff --git a/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py b/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
index 01581d2d..cbc96798 100644
--- a/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
+++ b/convlab/base_models/gpt/keyword_extraction/lmloss2keywords.py
@@ -51,35 +51,28 @@ def convert_token_loss2word_loss(token_loss_file):
     word_loss_file = os.path.join(os.path.dirname(token_loss_file), token_loss_file.split('/')[-1].replace('token', 'word'))
     fin = open(token_loss_file, 'rb')
     fout = open(word_loss_file, 'w', encoding='utf-8')
-    lines = []
 
     for item in tqdm(json_lines.reader(fin)):
         tokens, losses = item['tokens'], item['losses']
         assert len(tokens) == len(losses)
         word2losses = merge_tokens(tokens, losses)
-        lines.append({"words": [x[0] for x in word2losses], "losses": [x[1] for x in word2losses]})
-        fout.write(json.dumps(lines[-1], ensure_ascii=False)+'\n')
+        fout.write(json.dumps({"words": [x[0] for x in word2losses], "losses": [x[1] for x in word2losses]}, ensure_ascii=False)+'\n')
 
     fin.close()
     fout.close()
-    return lines
+    return word_loss_file
 
 def main(args):
     if not args.word_loss_file:
-        word_loss_list = convert_token_loss2word_loss(args.token_loss_file)
+        word_loss_file = convert_token_loss2word_loss(args.token_loss_file)
     else:
-        fin = open(args.word_loss_file, 'rb')
-        word_loss_list = []
-        for item in json_lines.reader(fin):
-            words, losses = item['words'], item['losses']
-            word_loss_list.append({"words": words, "losses": losses})
-        fin.close()
+        word_loss_file = args.word_loss_file
 
     if not args.output_file:
         return
 
     stop_words = set(stopwords.words('english'))
-    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
+    tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
     sent_tokenizer = PunktSentenceTokenizer()
 
     if args.keywords_th_ratio > 0:
@@ -138,8 +131,10 @@ def main(args):
 
         return keywords, keywords_turn_sent2idx
 
-    dialogs = []
-    for item in tqdm(word_loss_list):
+    fin = open(word_loss_file, 'rb')
+    fout = open(args.output_file, 'w', encoding='utf-8')
+
+    for item in json_lines.reader(fin):
         words = [tokenizer.convert_tokens_to_string(tokens) for tokens in item['words']]
         losses = [np.mean(loss) for loss in item['losses']]
         dialog_keywords, keywords_turn_sent2idx = keywords_filter(words, losses)
@@ -163,15 +158,17 @@ def main(args):
                 turns.append(turn)
                 turn = {'words': [], 'losses': []}
                 
-        dialogs.append(turns)
-    json.dump(dialogs, open(args.output_file, "w", encoding='utf-8'), indent=2, ensure_ascii=False)
-
+        fout.write(json.dumps(turns, ensure_ascii=False)+'\n')
+    
+    fin.close()
+    fout.close()
 
 
 if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser(description="extract keywords according to lm loss")
     parser.add_argument('--model_type', '-m', type=str, help='gpt or dialogpt')
+    parser.add_argument('--model_name_or_path', type=str, help='model name or path')
     parser.add_argument('--token_loss_file', '-t', type=str, help='path to the token loss file that contains two columns: [tokens, losses]')
     parser.add_argument('--word_loss_file', '-w', type=str, help='path to the word loss file that contains two columns: [words, losses]')
     parser.add_argument('--output_file', '-o', type=str, help='path to the output file')
-- 
GitLab