Commit dafb50a2 authored by zqwerty

update lmloss2keywords to reduce memory cost

parent ed0dc474
 task_name="lm"
-dataset_name=$1
 model_type="gpt"
-data_dir="data/${task_name}/${dataset_name}/${model_type}"
-model_name_or_path="gpt2-large"
+model_name_or_path="/data/zhuqi/pre-trained-models/gpt2-large"
 keywords_num=100
-keywords_ratio=0.4
+keywords_ratio=0.3
 keywords_th_ratio=0
 stopwords=True
-for data_split in validation test train
+for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd reddit wikidialog
 do
-word_loss_file="${data_dir}/${model_name_or_path}_${dataset_name}_${data_split}_word_loss.json"
-output_file="${data_dir}/${dataset_name}_${data_split}_keywords_${model_name_or_path}_topk_${keywords_num}_ratio_${keywords_ratio}_th_${keywords_th_ratio}_stopwords_${stopwords}.json"
+data_dir="data/${task_name}/${model_type}/${dataset_name}"
+for data_split in validation train
+do
+token_loss_file="${data_dir}/token_loss_${data_split}.json"
+output_file="${data_dir}/keywords_${data_split}.json"
 python lmloss2keywords.py \
     --model_type ${model_type} \
-    --word_loss_file ${word_loss_file} \
+    --model_name_or_path ${model_name_or_path} \
+    --token_loss_file ${token_loss_file} \
     --keywords_num ${keywords_num} \
     --keywords_ratio ${keywords_ratio} \
     --keywords_th_ratio ${keywords_th_ratio} \
     --stopwords ${stopwords} \
     --output_file ${output_file}
 done
+done
\ No newline at end of file
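
For reference, the token-loss files consumed by the updated loop are JSON Lines: one dialogue per line with parallel "tokens" and "losses" arrays, matching the item['tokens'] / item['losses'] fields read in lmloss2keywords.py below. A minimal sketch of producing one such line; the token strings and loss values here are invented for illustration, only the field names and file layout come from the code:

    import json

    # Toy example of one line of token_loss_{split}.json as read by lmloss2keywords.py:
    # parallel "tokens" and "losses" arrays of equal length (the script asserts this).
    # Token strings and loss values below are invented.
    example = {
        "tokens": ["Hello", ",", "Ġhow", "Ġare", "Ġyou", "?"],
        "losses": [3.2, 0.9, 2.1, 0.4, 0.3, 0.7],
    }
    with open("data/lm/gpt/dailydialog/token_loss_validation.json", "a", encoding="utf-8") as fout:
        fout.write(json.dumps(example, ensure_ascii=False) + "\n")
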
@@ -5,7 +5,7 @@ model_type="gpt"
 cache_dir="../cache"
 source_column="dialogue"
 max_length=512
-model_name_or_path="/data/zhuqi/pre-trained-models/gpt2-large"
+model_name_or_path="gpt2-large"
 per_device_eval_batch_size=16
 for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd reddit wikidialog
......
@@ -51,35 +51,28 @@ def convert_token_loss2word_loss(token_loss_file):
     word_loss_file = os.path.join(os.path.dirname(token_loss_file), token_loss_file.split('/')[-1].replace('token', 'word'))
     fin = open(token_loss_file, 'rb')
     fout = open(word_loss_file, 'w', encoding='utf-8')
-    lines = []
     for item in tqdm(json_lines.reader(fin)):
         tokens, losses = item['tokens'], item['losses']
         assert len(tokens) == len(losses)
         word2losses = merge_tokens(tokens, losses)
-        lines.append({"words": [x[0] for x in word2losses], "losses": [x[1] for x in word2losses]})
-        fout.write(json.dumps(lines[-1], ensure_ascii=False)+'\n')
+        fout.write(json.dumps({"words": [x[0] for x in word2losses], "losses": [x[1] for x in word2losses]}, ensure_ascii=False)+'\n')
     fin.close()
     fout.close()
-    return lines
+    return word_loss_file
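
The memory saving in this function comes from the pattern above: each converted record is written to disk as soon as it is produced and only the output path is returned, instead of accumulating every record in a list. A generic sketch of that pattern under invented names (convert_record and stream_convert are placeholders, not functions from this repository):

    import json

    def convert_record(record):
        # Placeholder for the per-record work (merge_tokens etc. in the real script).
        return {"words": record["tokens"], "losses": record["losses"]}

    def stream_convert(in_path, out_path):
        # Read one JSON line at a time and immediately write the converted line back out,
        # so only a single record is held in memory at any point.
        with open(in_path, encoding="utf-8") as fin, open(out_path, "w", encoding="utf-8") as fout:
            for line in fin:
                fout.write(json.dumps(convert_record(json.loads(line)), ensure_ascii=False) + "\n")
        return out_path  # callers re-open the file instead of receiving a list
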

 def main(args):
     if not args.word_loss_file:
-        word_loss_list = convert_token_loss2word_loss(args.token_loss_file)
+        word_loss_file = convert_token_loss2word_loss(args.token_loss_file)
     else:
-        fin = open(args.word_loss_file, 'rb')
-        word_loss_list = []
-        for item in json_lines.reader(fin):
-            words, losses = item['words'], item['losses']
-            word_loss_list.append({"words": words, "losses": losses})
-        fin.close()
+        word_loss_file = args.word_loss_file

     if not args.output_file:
         return

     stop_words = set(stopwords.words('english'))
-    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
+    tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
     sent_tokenizer = PunktSentenceTokenizer()

     if args.keywords_th_ratio > 0:
@@ -138,8 +131,10 @@ def main(args):
         return keywords, keywords_turn_sent2idx

-    dialogs = []
-    for item in tqdm(word_loss_list):
+    fin = open(word_loss_file, 'rb')
+    fout = open(args.output_file, 'w', encoding='utf-8')
+    for item in json_lines.reader(fin):
         words = [tokenizer.convert_tokens_to_string(tokens) for tokens in item['words']]
         losses = [np.mean(loss) for loss in item['losses']]
         dialog_keywords, keywords_turn_sent2idx = keywords_filter(words, losses)
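
In the loop above, each entry of item['words'] is the list of GPT-2 BPE tokens that make up one word, and the matching entry of item['losses'] holds those tokens' losses, so the word-level loss is the mean over its tokens. A small illustration of that aggregation, assuming the transformers GPT2Tokenizer; the tokens and loss values are invented:

    import numpy as np
    from transformers import GPT2Tokenizer

    # One (tokens, losses) pair per word; values are invented for illustration.
    word_tokens = [["Hello"], [","], ["Ġextract"], ["Ġkey", "words"], ["?"]]
    word_token_losses = [[3.2], [0.9], [2.1], [1.5, 2.5], [0.7]]

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
    words = [tokenizer.convert_tokens_to_string(tokens) for tokens in word_tokens]
    losses = [float(np.mean(loss)) for loss in word_token_losses]
    print(list(zip(words, losses)))
    # e.g. [('Hello', 3.2), (',', 0.9), (' extract', 2.1), (' keywords', 2.0), ('?', 0.7)]
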
@@ -163,15 +158,17 @@ def main(args):
                 turns.append(turn)
                 turn = {'words': [], 'losses': []}

-        dialogs.append(turns)
-    json.dump(dialogs, open(args.output_file, "w", encoding='utf-8'), indent=2, ensure_ascii=False)
+        fout.write(json.dumps(turns, ensure_ascii=False)+'\n')
+    fin.close()
+    fout.close()
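
With this change the keywords output file also becomes JSON Lines, one dialogue (a list of turn dicts) per line, rather than a single indent-formatted JSON array of all dialogues, so downstream readers should iterate line by line instead of calling json.load on the whole file. A hedged sketch of reading the new format; the per-turn fields beyond 'words' and 'losses' are not shown in this hunk:

    import json

    def iter_dialogs(keywords_file):
        # Each line holds one dialogue: the list of turn dicts written by
        # fout.write(json.dumps(turns, ensure_ascii=False) + '\n') above.
        with open(keywords_file, encoding="utf-8") as fin:
            for line in fin:
                yield json.loads(line)

    # Example usage, following the naming convention from the shell script above:
    # for turns in iter_dialogs("data/lm/gpt/dailydialog/keywords_validation.json"):
    #     print(len(turns), "turns in this dialogue")
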

 if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser(description="extract keywords according to lm loss")
     parser.add_argument('--model_type', '-m', type=str, help='gpt or dialogpt')
+    parser.add_argument('--model_name_or_path', type=str, help='model name or path')
     parser.add_argument('--token_loss_file', '-t', type=str, help='path to the token loss file that contains two columns: [tokens, losses]')
     parser.add_argument('--word_loss_file', '-w', type=str, help='path to the word loss file that contains two columns: [words, losses]')
     parser.add_argument('--output_file', '-o', type=str, help='path to the output file')
......