Skip to content
Snippets Groups Projects
Commit a15bedea authored by zqwerty's avatar zqwerty
Browse files

add keyword_extract for gpt

parent 42f8dec8
No related branches found
No related tags found
No related merge requests found
...@@ -16,7 +16,7 @@ def create_lm_data(dataset, data_dir, args): ...@@ -16,7 +16,7 @@ def create_lm_data(dataset, data_dir, args):
if args.model_type == 'dialogpt': if args.model_type == 'dialogpt':
dialogue = ' <|endoftext|> '.join([turn['utterance'] for turn in sample['turns']]) + ' <|endoftext|>' dialogue = ' <|endoftext|> '.join([turn['utterance'] for turn in sample['turns']]) + ' <|endoftext|>'
else: else:
dialogue = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['turns']]) dialogue = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['turns']])
data.append(json.dumps({'dialogue': dialogue}, ensure_ascii=False)+'\n') data.append(json.dumps({'dialogue': dialogue}, ensure_ascii=False)+'\n')
file_name = os.path.join(data_dir, f"{data_split}.json") file_name = os.path.join(data_dir, f"{data_split}.json")
...@@ -35,5 +35,5 @@ if __name__ == '__main__': ...@@ -35,5 +35,5 @@ if __name__ == '__main__':
for dataset_name in tqdm(args.datasets, desc='datasets'): for dataset_name in tqdm(args.datasets, desc='datasets'):
dataset = load_dataset(dataset_name) dataset = load_dataset(dataset_name)
for task_name in tqdm(args.tasks, desc='tasks', leave=False): for task_name in tqdm(args.tasks, desc='tasks', leave=False):
data_dir = os.path.join('data', task_name, dataset_name) data_dir = os.path.join('data', task_name, dataset_name, args.model_type)
eval(f"create_{task_name}_data")(dataset, data_dir, args) eval(f"create_{task_name}_data")(dataset, data_dir, args)
# Extract per-utterance keywords from precomputed word-level LM losses.
# Expects ${word_loss_file} to have been produced by a previous run of
# lmloss2keywords.py --token_loss_file (see the eval script in this commit).
model_type=dialogpt
dataset_name=multiwoz21
model_name=dialogpt-large
data_dir="data/lm/${dataset_name}/${model_type}"
word_loss_file="${data_dir}/${model_name}_${dataset_name}_word_loss.json"
# keyword selection hyper-parameters: top-k per turn, ratio cap, loss threshold
keywords_num=5
keywords_ratio=1
keywords_th=0
# True => filter out NLTK English stopwords before ranking
stopwords=True
output_file="${data_dir}/${dataset_name}_keywords_${model_name}_topk_${keywords_num}_ratio_${keywords_ratio}_th_${keywords_th}_stopwords_${stopwords}.json"
python lmloss2keywords.py \
    --model_type ${model_type} \
    --word_loss_file ${word_loss_file} \
    --keywords_num ${keywords_num} \
    --keywords_ratio ${keywords_ratio} \
    --keywords_th ${keywords_th} \
    --stopwords ${stopwords} \
    --output_file ${output_file}
\ No newline at end of file
# Dump per-token LM losses on the validation split for three language models:
#   1. off-the-shelf DialoGPT-large
#   2. DialoGPT fine-tuned on this dataset (loaded from ${output_dir})
#   3. off-the-shelf GPT2-large
# Each run writes a token-loss jsonlines file, then converts it to word losses.
set -e
n_gpus=1
task_name="lm"
dataset_name="multiwoz21"
model_type="dialogpt"
data_dir="data/${task_name}/${dataset_name}/${model_type}"
output_dir="output/${task_name}/${dataset_name}/${model_type}"
cache_dir="../cache"
validation_file="${data_dir}/validation.json"
source_column="dialogue"
max_length=512
model_name_or_path="microsoft/DialoGPT-large"
per_device_eval_batch_size=4

# --- 1. pretrained DialoGPT-large ---
dump_eval_loss_to="${data_dir}/dialogpt-large_${dataset_name}_token_loss.json"
python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type dialogpt
python ../run_clm.py \
    --dump_eval_loss_to ${dump_eval_loss_to}\
    --model_name_or_path ${model_name_or_path} \
    --output_dir ${data_dir} \
    --validation_file ${validation_file} \
    --source_column ${source_column} \
    --max_length ${max_length} \
    --do_eval \
    --prediction_loss_only \
    --cache_dir ${cache_dir} \
    --preprocessing_num_workers 4 \
    --per_device_eval_batch_size ${per_device_eval_batch_size}
python lmloss2keywords.py --token_loss_file ${dump_eval_loss_to} --model_type ${model_type}

# --- 2. DialoGPT fine-tuned on ${dataset_name} (reads checkpoint from ${output_dir}) ---
dump_eval_loss_to="${data_dir}/dialogpt-large-mwoz_${dataset_name}_token_loss.json"
# NOTE(review): this create_data call is identical to the one above and looks
# redundant — confirm whether it can be dropped.
python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type dialogpt
python ../run_clm.py \
    --dump_eval_loss_to ${dump_eval_loss_to}\
    --model_name_or_path ${output_dir} \
    --output_dir ${data_dir} \
    --validation_file ${validation_file} \
    --source_column ${source_column} \
    --max_length ${max_length} \
    --do_eval \
    --prediction_loss_only \
    --cache_dir ${cache_dir} \
    --preprocessing_num_workers 4 \
    --per_device_eval_batch_size ${per_device_eval_batch_size}
python lmloss2keywords.py --token_loss_file ${dump_eval_loss_to} --model_type ${model_type}

# --- 3. pretrained GPT2-large (data regenerated with plain-text formatting) ---
model_type="gpt"
data_dir="data/${task_name}/${dataset_name}/${model_type}"
validation_file="${data_dir}/validation.json"
model_name_or_path="gpt2-large"
dump_eval_loss_to="${data_dir}/gpt2-large_${dataset_name}_token_loss.json"
python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type gpt
python ../run_clm.py \
    --dump_eval_loss_to ${dump_eval_loss_to}\
    --model_name_or_path ${model_name_or_path} \
    --output_dir ${data_dir} \
    --validation_file ${validation_file} \
    --source_column ${source_column} \
    --max_length ${max_length} \
    --do_eval \
    --prediction_loss_only \
    --cache_dir ${cache_dir} \
    --preprocessing_num_workers 4 \
    --per_device_eval_batch_size ${per_device_eval_batch_size}
python lmloss2keywords.py --token_loss_file ${dump_eval_loss_to} --model_type ${model_type}
import json
import json_lines
from pprint import pprint
import os
from tqdm import tqdm
import numpy as np
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
def merge_tokens(tokens, losses, loss_merge_func=np.mean):
res = []
i = 0
while i < len(tokens):
token = tokens[i]
loss = losses[i]
if token in ['Ġ', 'Ċ']:
if token == 'Ċ' and i < len(tokens) - 1:
tokens[i+1] = 'Ġ'+tokens[i+1]
i += 1
continue
if token in ['user', 'system'] and i < len(tokens)-1 and tokens[i+1] == ':':
if i > 0:
tokens[i+1] = '<|endoftext|>'
i += 1
else:
i += 2
continue
if token.startswith('Ġ'):
# Ġ means space
token = token.replace("Ġ", "")
res.append([token, loss])
elif token == '<|endoftext|>':
res.append([token, loss])
else:
assert 'Ġ' not in token
if len(res) > 0:
res[-1][0] += token
res[-1].append(loss)
else:
res.append([token, loss])
i += 1
if loss_merge_func:
for i in range(len(res)):
res[i] = [res[i][0], loss_merge_func(res[i][1:])]
return res
def convert_token_loss2word_loss(token_loss_file, loss_merge_func=np.mean):
    """Convert a token-loss jsonlines file into a word-loss jsonlines file.

    Reads {"tokens": [...], "losses": [...]} records, merges subword tokens
    into words via merge_tokens, and writes {"words": [...], "losses": [...]}
    records next to the input (filename with 'token' replaced by 'word').

    Returns the list of written records.
    """
    # fix: use basename instead of split('/') and close both files via `with`
    # (the original leaked the handles if an exception occurred mid-stream)
    word_loss_file = os.path.join(os.path.dirname(token_loss_file),
                                  os.path.basename(token_loss_file).replace('token', 'word'))
    lines = []
    with open(token_loss_file, 'rb') as fin, \
         open(word_loss_file, 'w', encoding='utf-8') as fout:
        for item in tqdm(json_lines.reader(fin)):
            tokens, losses = item['tokens'], item['losses']
            assert len(tokens) == len(losses)
            word2losses = merge_tokens(tokens, losses, loss_merge_func)
            lines.append({"words": [x[0] for x in word2losses],
                          "losses": [x[1] for x in word2losses]})
            fout.write(json.dumps(lines[-1], ensure_ascii=False) + '\n')
    return lines
def main(args):
    """Rank words by LM loss and dump per-turn keywords.

    Loads word-level losses (computing them from ``args.token_loss_file`` when
    ``args.word_loss_file`` is not given), then for every dialogue turn keeps
    the highest-loss words as keywords, filtered by stopwords / threshold /
    ratio / top-k, and writes the result to ``args.output_file`` as JSON.
    """
    if not args.word_loss_file:
        word_loss_list = convert_token_loss2word_loss(args.token_loss_file)
    else:
        # fix: close the input file deterministically via `with`
        word_loss_list = []
        with open(args.word_loss_file, 'rb') as fin:
            for item in json_lines.reader(fin):
                words, losses = item['words'], item['losses']
                word_loss_list.append({"words": words, "losses": losses})

    if not args.output_file:
        return

    stop_words = set(stopwords.words('english'))
    dialogs = []
    for item in word_loss_list:
        words = item['words']
        losses = item['losses']
        turns = []
        turn = {'words': [], 'losses': []}
        for word, loss in zip(words, losses):
            if word == '<|endoftext|>':
                # turn boundary: finalize the accumulated turn
                turn['utterance'] = ' '.join(turn['words'])
                turn['keywords'] = list(zip(turn['words'], turn['losses']))
                if args.stopwords:
                    # drop words any of whose sub-tokens is an English stopword
                    turn['keywords'] = [x for x in turn['keywords'] if not any([w.lower() in stop_words for w in word_tokenize(x[0])])]
                turn['keywords'] = sorted(turn['keywords'], key=lambda x: x[1], reverse=True)
                # threshold filter, then cap at min(ratio * candidate count, top-k);
                # note the ratio is applied to the pre-threshold candidate count
                turn['keywords'] = [x for x in turn['keywords'] if x[1] > args.keywords_th][:min(round(args.keywords_ratio*len(turn['keywords'])), args.keywords_num)]
                turn.pop('words')
                turn.pop('losses')
                turns.append(turn)
                turn = {'words': [], 'losses': []}
            else:
                turn['words'].append(word)
                turn['losses'].append(loss)
        # NOTE(review): any words after the last '<|endoftext|>' are discarded —
        # assumes every dialogue ends with the separator; confirm data format.
        dialogs.append(turns)

    # fix: the original passed an anonymous open(...) to json.dump and never
    # closed it; use `with` so the output is flushed and closed reliably
    with open(args.output_file, "w", encoding='utf-8') as fout:
        json.dump(dialogs, fout, indent=2, ensure_ascii=False)
def str2bool(value):
    """Parse a boolean CLI flag ('True'/'False', '1'/'0', 'yes'/'no', ...).

    Replaces the original ``lambda x: bool(eval(x))`` — eval() on untrusted
    command-line input is unsafe; this accepts the same documented spellings.
    """
    return value.strip().lower() in ('true', '1', 'yes', 'y', 't')


if __name__ == '__main__':
    from argparse import ArgumentParser
    # fix: description was copy-pasted from an NLU metrics script
    parser = ArgumentParser(description="extract per-turn keywords from LM token/word losses")
    parser.add_argument('--model_type', '-m', type=str, help='gpt or dialogpt')
    parser.add_argument('--token_loss_file', '-t', type=str, help='path to the token loss file that contains two columns: [tokens, losses]')
    # fix: help text wrongly described the word loss file as a token loss file
    parser.add_argument('--word_loss_file', '-w', type=str, help='path to the word loss file that contains two columns: [words, losses]')
    parser.add_argument('--output_file', '-o', type=str, help='path to the output file')
    parser.add_argument('--keywords_num', '-n', type=int, default=100, help='how many words in an utterance serve as keywords')
    parser.add_argument('--keywords_ratio', '-r', type=float, default=1.0, help='how many words (in ratio) in an utterance serve as keywords')
    parser.add_argument('--keywords_th', '-th', type=float, default=0., help='loss threshold for the keywords')
    parser.add_argument('--stopwords', '-s', type=str2bool, default=True, help='filter out stopwords')
    args = parser.parse_args()
    print(args)
    main(args)
...@@ -2,8 +2,9 @@ set -e ...@@ -2,8 +2,9 @@ set -e
n_gpus=1 n_gpus=1
task_name="lm" task_name="lm"
dataset_name="multiwoz21" dataset_name="multiwoz21"
data_dir="data/${task_name}/${dataset_name}" model_type="dialogpt"
output_dir="output/${task_name}/${dataset_name}" data_dir="data/${task_name}/${dataset_name}/${model_type}"
output_dir="output/${task_name}/${dataset_name}/${model_type}"
cache_dir="../cache" cache_dir="../cache"
logging_dir="${output_dir}/runs" logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json" train_file="${data_dir}/train.json"
......
...@@ -30,7 +30,11 @@ from itertools import chain ...@@ -30,7 +30,11 @@ from itertools import chain
from typing import Optional from typing import Optional
import datasets import datasets
from datasets import load_dataset, load_metric from datasets import load_dataset
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch
import json
import transformers import transformers
from transformers import ( from transformers import (
...@@ -156,6 +160,9 @@ class DataTrainingArguments: ...@@ -156,6 +160,9 @@ class DataTrainingArguments:
"help": "An optional input evaluation data file to evaluate the metrics on (a text, jsonlines or csv file)." "help": "An optional input evaluation data file to evaluate the metrics on (a text, jsonlines or csv file)."
}, },
) )
dump_eval_loss_to: Optional[str] = field(
default=None, metadata={"help": "Where to dump the tokens' losses in the evaluation data, default not to"}
)
overwrite_cache: bool = field( overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
) )
...@@ -484,23 +491,6 @@ def main(): ...@@ -484,23 +491,6 @@ def main():
pad_to_multiple_of=8 if training_args.fp16 else None, pad_to_multiple_of=8 if training_args.fp16 else None,
) )
def preprocess_logits_for_metrics(logits, labels):
if isinstance(logits, tuple):
# Depending on the model and config, logits may contain extra tensors,
# like past_key_values, but logits always come first
logits = logits[0]
return logits.argmax(dim=-1)
metric = load_metric("accuracy")
def compute_metrics(eval_preds):
preds, labels = eval_preds
# preds have the same shape as the labels, after the argmax(-1) has been calculated
# by preprocess_logits_for_metrics but we need to shift the labels
labels = labels[:, 1:].reshape(-1)
preds = preds[:, :-1].reshape(-1)
return metric.compute(predictions=preds, references=labels)
# Initialize our Trainer # Initialize our Trainer
trainer = Trainer( trainer = Trainer(
model=model, model=model,
...@@ -509,11 +499,7 @@ def main(): ...@@ -509,11 +499,7 @@ def main():
eval_dataset=eval_dataset if training_args.do_eval else None, eval_dataset=eval_dataset if training_args.do_eval else None,
tokenizer=tokenizer, tokenizer=tokenizer,
# Data collator will default to DataCollatorWithPadding, so we change it. # Data collator will default to DataCollatorWithPadding, so we change it.
data_collator=data_collator, data_collator=data_collator
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
else None,
) )
# Training # Training
...@@ -539,6 +525,7 @@ def main(): ...@@ -539,6 +525,7 @@ def main():
# Evaluation # Evaluation
if training_args.do_eval: if training_args.do_eval:
logger.info("*** Evaluate ***") logger.info("*** Evaluate ***")
if not data_args.dump_eval_loss_to:
metrics = trainer.evaluate(metric_key_prefix="eval") metrics = trainer.evaluate(metric_key_prefix="eval")
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
...@@ -547,9 +534,48 @@ def main(): ...@@ -547,9 +534,48 @@ def main():
except OverflowError: except OverflowError:
perplexity = float("inf") perplexity = float("inf")
metrics["eval_perplexity"] = perplexity metrics["eval_perplexity"] = perplexity
logger.info(f"eval_perplexity: {perplexity}")
trainer.log_metrics("eval", metrics) trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("eval", metrics)
else:
if trainer.is_world_process_zero():
output_prediction_file = data_args.dump_eval_loss_to
writer = open(output_prediction_file, "w", encoding='utf-8')
eval_dataloader = DataLoader(
eval_dataset, collate_fn=lambda x: {k: v.to(model.device) for k, v in data_collator(x).items()}, batch_size=training_args.per_device_eval_batch_size
)
model.eval()
losses = []
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
for batch in tqdm(eval_dataloader):
with torch.no_grad():
outputs = model(**batch)
loss = outputs.loss
losses.append(loss.repeat(training_args.per_device_eval_batch_size))
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = batch['labels'][..., 1:].contiguous()
batch_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
batch_token_loss = batch_token_loss.view(shift_labels.size()).tolist()
labels = batch['labels'].tolist()
for i in range(len(labels)):
token_ids = [x for x in labels[i] if x != -100]
tokens = tokenizer.convert_ids_to_tokens(token_ids)
token_losses = [0] + batch_token_loss[i][:len(token_ids)-1]
writer.write(json.dumps({"tokens": tokens, "losses": token_losses}, ensure_ascii=False)+'\n')
losses = torch.cat(losses)
losses = losses[: len(eval_dataset)]
try:
perplexity = math.exp(torch.mean(losses))
except OverflowError:
perplexity = float("inf")
logger.info(f"perplexity: {perplexity}")
writer.close()
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
if data_args.dataset_name is not None: if data_args.dataset_name is not None:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment