From a15bedeaac509334e5fa30b532f7765dfd79fb3f Mon Sep 17 00:00:00 2001 From: zqwerty <zhuq96@hotmail.com> Date: Wed, 13 Apr 2022 16:41:59 +0800 Subject: [PATCH] add keyword_extract for gpt --- convlab2/base_models/gpt/create_data.py | 4 +- .../gpt/keyword_extraction/get_keywords.sh | 20 +++ .../gpt/keyword_extraction/get_word_loss.sh | 65 +++++++++ .../gpt/keyword_extraction/lmloss2keywords.py | 123 ++++++++++++++++++ .../{run.sh => train_lm.sh} | 5 +- convlab2/base_models/gpt/run_clm.py | 94 ++++++++----- 6 files changed, 273 insertions(+), 38 deletions(-) create mode 100644 convlab2/base_models/gpt/keyword_extraction/get_keywords.sh create mode 100644 convlab2/base_models/gpt/keyword_extraction/get_word_loss.sh create mode 100644 convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py rename convlab2/base_models/gpt/keyword_extraction/{run.sh => train_lm.sh} (90%) diff --git a/convlab2/base_models/gpt/create_data.py b/convlab2/base_models/gpt/create_data.py index 94e88f3e..dd616b59 100644 --- a/convlab2/base_models/gpt/create_data.py +++ b/convlab2/base_models/gpt/create_data.py @@ -16,7 +16,7 @@ def create_lm_data(dataset, data_dir, args): if args.model_type == 'dialogpt': dialogue = ' <|endoftext|> '.join([turn['utterance'] for turn in sample['turns']]) + ' <|endoftext|>' else: - dialogue = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['turns']]) + dialogue = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['turns']]) data.append(json.dumps({'dialogue': dialogue}, ensure_ascii=False)+'\n') file_name = os.path.join(data_dir, f"{data_split}.json") @@ -35,5 +35,5 @@ if __name__ == '__main__': for dataset_name in tqdm(args.datasets, desc='datasets'): dataset = load_dataset(dataset_name) for task_name in tqdm(args.tasks, desc='tasks', leave=False): - data_dir = os.path.join('data', task_name, dataset_name) + data_dir = os.path.join('data', task_name, dataset_name, args.model_type) 
# --- convlab2/base_models/gpt/keyword_extraction/get_keywords.sh ---
# Turn a precomputed per-word LM-loss file into per-turn keywords via
# lmloss2keywords.py. Fail fast on any error (consistent with get_word_loss.sh).
set -e

model_type=dialogpt
dataset_name=multiwoz21
model_name=dialogpt-large
data_dir="data/lm/${dataset_name}/${model_type}"
word_loss_file="${data_dir}/${model_name}_${dataset_name}_word_loss.json"
keywords_num=5
keywords_ratio=1
keywords_th=0
stopwords=True
output_file="${data_dir}/${dataset_name}_keywords_${model_name}_topk_${keywords_num}_ratio_${keywords_ratio}_th_${keywords_th}_stopwords_${stopwords}.json"

python lmloss2keywords.py \
    --model_type ${model_type} \
    --word_loss_file ${word_loss_file} \
    --keywords_num ${keywords_num} \
    --keywords_ratio ${keywords_ratio} \
    --keywords_th ${keywords_th} \
    --stopwords ${stopwords} \
    --output_file ${output_file}

# --- convlab2/base_models/gpt/keyword_extraction/get_word_loss.sh ---
# Dump per-token eval losses on the validation split for three language models
# (pretrained DialoGPT, MultiWOZ-finetuned DialoGPT, pretrained GPT-2), then
# merge the token losses into word losses with lmloss2keywords.py.
set -e
n_gpus=1
task_name="lm"
dataset_name="multiwoz21"
cache_dir="../cache"
source_column="dialogue"
max_length=512
per_device_eval_batch_size=4

# $1: model_type (dialogpt|gpt)   $2: model name or checkpoint dir   $3: dump-file prefix
compute_word_loss() {
    local model_type=$1
    local model_name_or_path=$2
    local prefix=$3
    local data_dir="data/${task_name}/${dataset_name}/${model_type}"
    local validation_file="${data_dir}/validation.json"
    local dump_eval_loss_to="${data_dir}/${prefix}_${dataset_name}_token_loss.json"

    python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type ${model_type}
    python ../run_clm.py \
        --dump_eval_loss_to ${dump_eval_loss_to} \
        --model_name_or_path ${model_name_or_path} \
        --output_dir ${data_dir} \
        --validation_file ${validation_file} \
        --source_column ${source_column} \
        --max_length ${max_length} \
        --do_eval \
        --prediction_loss_only \
        --cache_dir ${cache_dir} \
        --preprocessing_num_workers 4 \
        --per_device_eval_batch_size ${per_device_eval_batch_size}
    python lmloss2keywords.py --token_loss_file ${dump_eval_loss_to} --model_type ${model_type}
}

compute_word_loss dialogpt "microsoft/DialoGPT-large" dialogpt-large
compute_word_loss dialogpt "output/${task_name}/${dataset_name}/dialogpt" dialogpt-large-mwoz
compute_word_loss gpt "gpt2-large" gpt2-large
# --- convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py ---
"""Convert per-token LM losses (dumped by run_clm.py --dump_eval_loss_to) into
per-word losses, and optionally extract the highest-loss words of each dialogue
turn as keywords."""
import json
import os
from argparse import ArgumentParser

import numpy as np

try:
    from tqdm import tqdm
except ImportError:  # tqdm is purely cosmetic; degrade to a plain iterator
    def tqdm(iterable, **kwargs):
        return iterable


def merge_tokens(tokens, losses, loss_merge_func=np.mean):
    """Merge GPT-2 BPE subword tokens back into words, merging their losses.

    `tokens` are token strings as produced by `tokenizer.convert_ids_to_tokens`
    ('Ġ' marks a leading space, 'Ċ' a newline). `losses` is the parallel list of
    per-token losses. Returns a list of [word, merged_loss] pairs; with
    loss_merge_func=None, each entry is [word, loss1, loss2, ...] instead.
    NOTE: mutates `tokens` in place while scanning.
    """
    res = []
    i = 0
    while i < len(tokens):
        token = tokens[i]
        loss = losses[i]
        if token in ['Ġ', 'Ċ']:
            # bare space / newline: drop it, but a newline ('Ċ') means the next
            # token starts a new word, so mark it with 'Ġ'
            if token == 'Ċ' and i < len(tokens) - 1:
                tokens[i+1] = 'Ġ' + tokens[i+1]
            i += 1
            continue
        if token in ['user', 'system'] and i < len(tokens)-1 and tokens[i+1] == ':':
            # "user:"/"system:" speaker prefix marks a turn boundary: rewrite the
            # ':' into the '<|endoftext|>' separator (skip both at dialogue start)
            if i > 0:
                tokens[i+1] = '<|endoftext|>'
                i += 1
            else:
                i += 2
            continue
        if token.startswith('Ġ'):
            # 'Ġ' means a leading space, i.e. the token starts a new word
            token = token.replace("Ġ", "")
            res.append([token, loss])
        elif token == '<|endoftext|>':
            res.append([token, loss])
        else:
            # continuation subword: glue onto the previous word
            assert 'Ġ' not in token
            if len(res) > 0:
                res[-1][0] += token
                res[-1].append(loss)
            else:
                res.append([token, loss])
        i += 1
    if loss_merge_func:
        for i in range(len(res)):
            res[i] = [res[i][0], loss_merge_func(res[i][1:])]
    return res


def convert_token_loss2word_loss(token_loss_file, loss_merge_func=np.mean):
    """Read a jsonlines token-loss file ({"tokens": [...], "losses": [...]} per
    line) and write the word-level version next to it ('token' -> 'word' in the
    file name). Returns the list of {"words": ..., "losses": ...} records."""
    word_loss_file = os.path.join(os.path.dirname(token_loss_file),
                                  os.path.basename(token_loss_file).replace('token', 'word'))
    lines = []
    # plain line-wise json.loads replaces the third-party json_lines dependency
    with open(token_loss_file, encoding='utf-8') as fin, \
         open(word_loss_file, 'w', encoding='utf-8') as fout:
        for line in tqdm(fin):
            if not line.strip():
                continue
            item = json.loads(line)
            tokens, losses = item['tokens'], item['losses']
            assert len(tokens) == len(losses)
            word2losses = merge_tokens(tokens, losses, loss_merge_func)
            lines.append({"words": [x[0] for x in word2losses],
                          "losses": [x[1] for x in word2losses]})
            fout.write(json.dumps(lines[-1], ensure_ascii=False) + '\n')
    return lines


def main(args):
    """Build (or load) per-word losses, then optionally dump per-turn keywords."""
    if not args.word_loss_file:
        word_loss_list = convert_token_loss2word_loss(args.token_loss_file)
    else:
        word_loss_list = []
        with open(args.word_loss_file, encoding='utf-8') as fin:
            for line in fin:
                if not line.strip():
                    continue
                item = json.loads(line)
                word_loss_list.append({"words": item['words'], "losses": item['losses']})

    if not args.output_file:
        return

    # nltk is only needed for keyword extraction; import lazily so the module
    # loads (and word-loss conversion works) without nltk data installed
    from nltk.corpus import stopwords
    from nltk.tokenize import word_tokenize
    stop_words = set(stopwords.words('english'))

    dialogs = []
    for item in word_loss_list:
        words = item['words']
        losses = item['losses']
        turns = []
        turn = {'words': [], 'losses': []}
        for word, loss in zip(words, losses):
            if word == '<|endoftext|>':
                # turn boundary: score, trim, and close the current turn
                turn['utterance'] = ' '.join(turn['words'])
                keywords = list(zip(turn['words'], turn['losses']))
                if args.stopwords:
                    keywords = [x for x in keywords
                                if not any(w.lower() in stop_words for w in word_tokenize(x[0]))]
                keywords = sorted(keywords, key=lambda x: x[1], reverse=True)
                # the ratio cap is computed on the pre-threshold count (as in the
                # original behavior), then the loss threshold is applied
                n_keep = min(round(args.keywords_ratio * len(keywords)), args.keywords_num)
                turn['keywords'] = [x for x in keywords if x[1] > args.keywords_th][:n_keep]
                turn.pop('words')
                turn.pop('losses')
                turns.append(turn)
                turn = {'words': [], 'losses': []}
            else:
                turn['words'].append(word)
                turn['losses'].append(loss)
        dialogs.append(turns)
    with open(args.output_file, "w", encoding='utf-8') as fout:
        json.dump(dialogs, fout, indent=2, ensure_ascii=False)


def str2bool(value):
    """Parse a CLI boolean without eval() (accepts True/False/1/0/yes/no, any case)."""
    v = str(value).strip().lower()
    if v in ('true', '1', 'yes', 'y'):
        return True
    if v in ('false', '0', 'no', 'n'):
        return False
    raise ValueError(f"invalid boolean value: {value!r}")


if __name__ == '__main__':
    parser = ArgumentParser(description="extract keywords of utterances from per-token/word LM losses")
    parser.add_argument('--model_type', '-m', type=str, help='gpt or dialogpt')
    parser.add_argument('--token_loss_file', '-t', type=str,
                        help='path to the token loss file that contains two columns: [tokens, losses]')
    parser.add_argument('--word_loss_file', '-w', type=str,
                        help='path to the word loss file that contains two columns: [words, losses]')
    parser.add_argument('--output_file', '-o', type=str, help='path to the output keywords file')
    parser.add_argument('--keywords_num', '-n', type=int, default=100,
                        help='how many words in an utterance serve as keywords')
    parser.add_argument('--keywords_ratio', '-r', type=float, default=1.0,
                        help='how many words (in ratio) in an utterance serve as keywords')
    parser.add_argument('--keywords_th', '-th', type=float, default=0.,
                        help='loss threshold for the keywords')
    parser.add_argument('--stopwords', '-s', type=str2bool, default=True,
                        help='filter out stopwords')
    args = parser.parse_args()
    print(args)
    main(args)
class DataTrainingArguments: "help": "An optional input evaluation data file to evaluate the metrics on (a text, jsonlines or csv file)." }, ) + dump_eval_loss_to: Optional[str] = field( + default=None, metadata={"help": "Where to dump the tokens' losses in the evaluation data, default not to"} + ) overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) @@ -484,23 +491,6 @@ def main(): pad_to_multiple_of=8 if training_args.fp16 else None, ) - def preprocess_logits_for_metrics(logits, labels): - if isinstance(logits, tuple): - # Depending on the model and config, logits may contain extra tensors, - # like past_key_values, but logits always come first - logits = logits[0] - return logits.argmax(dim=-1) - - metric = load_metric("accuracy") - - def compute_metrics(eval_preds): - preds, labels = eval_preds - # preds have the same shape as the labels, after the argmax(-1) has been calculated - # by preprocess_logits_for_metrics but we need to shift the labels - labels = labels[:, 1:].reshape(-1) - preds = preds[:, :-1].reshape(-1) - return metric.compute(predictions=preds, references=labels) - # Initialize our Trainer trainer = Trainer( model=model, @@ -509,11 +499,7 @@ def main(): eval_dataset=eval_dataset if training_args.do_eval else None, tokenizer=tokenizer, # Data collator will default to DataCollatorWithPadding, so we change it. 
- data_collator=data_collator, - compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, - preprocess_logits_for_metrics=preprocess_logits_for_metrics - if training_args.do_eval and not is_torch_tpu_available() - else None, + data_collator=data_collator ) # Training @@ -539,17 +525,57 @@ def main(): # Evaluation if training_args.do_eval: logger.info("*** Evaluate ***") - metrics = trainer.evaluate(metric_key_prefix="eval") - max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) - metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - try: - perplexity = math.exp(metrics["eval_loss"]) - except OverflowError: - perplexity = float("inf") - metrics["eval_perplexity"] = perplexity - - trainer.log_metrics("eval", metrics) - trainer.save_metrics("eval", metrics) + if not data_args.dump_eval_loss_to: + metrics = trainer.evaluate(metric_key_prefix="eval") + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + try: + perplexity = math.exp(metrics["eval_loss"]) + except OverflowError: + perplexity = float("inf") + metrics["eval_perplexity"] = perplexity + logger.info(f"eval_perplexity: {perplexity}") + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + else: + if trainer.is_world_process_zero(): + output_prediction_file = data_args.dump_eval_loss_to + writer = open(output_prediction_file, "w", encoding='utf-8') + + eval_dataloader = DataLoader( + eval_dataset, collate_fn=lambda x: {k: v.to(model.device) for k, v in data_collator(x).items()}, batch_size=training_args.per_device_eval_batch_size + ) + model.eval() + losses = [] + loss_fct = torch.nn.CrossEntropyLoss(reduction='none') + for batch in tqdm(eval_dataloader): + with torch.no_grad(): + outputs = model(**batch) + + loss = outputs.loss + 
losses.append(loss.repeat(training_args.per_device_eval_batch_size)) + + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = batch['labels'][..., 1:].contiguous() + batch_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + batch_token_loss = batch_token_loss.view(shift_labels.size()).tolist() + labels = batch['labels'].tolist() + for i in range(len(labels)): + token_ids = [x for x in labels[i] if x != -100] + tokens = tokenizer.convert_ids_to_tokens(token_ids) + token_losses = [0] + batch_token_loss[i][:len(token_ids)-1] + writer.write(json.dumps({"tokens": tokens, "losses": token_losses}, ensure_ascii=False)+'\n') + + losses = torch.cat(losses) + losses = losses[: len(eval_dataset)] + try: + perplexity = math.exp(torch.mean(losses)) + except OverflowError: + perplexity = float("inf") + logger.info(f"perplexity: {perplexity}") + + writer.close() kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} if data_args.dataset_name is not None: -- GitLab