From a15bedeaac509334e5fa30b532f7765dfd79fb3f Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Wed, 13 Apr 2022 16:41:59 +0800
Subject: [PATCH] Add keyword extraction based on GPT language model loss

---
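This patch extracts the keywords of each utterance as the words with the
highest language model loss: train_lm.sh (renamed from run.sh) fine-tunes a
(Dialo)GPT model on a dialogue dataset, get_word_loss.sh dumps per-token
losses on the validation set ({"tokens": [...], "losses": [...]} per line)
and merges them into per-word losses, and get_keywords.sh turns the word
losses into per-turn keyword lists. A rough usage sketch, assuming the
scripts are run from convlab2/base_models/gpt/keyword_extraction:

    bash train_lm.sh       # fine-tune DialoGPT-large on multiwoz21
    bash get_word_loss.sh  # dump token losses and merge them into word losses
    bash get_keywords.sh   # select the highest-loss words as keywords
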
 convlab2/base_models/gpt/create_data.py       |   4 +-
 .../gpt/keyword_extraction/get_keywords.sh    |  20 +++
 .../gpt/keyword_extraction/get_word_loss.sh   |  65 +++++++++
 .../gpt/keyword_extraction/lmloss2keywords.py | 123 ++++++++++++++++++
 .../{run.sh => train_lm.sh}                   |   5 +-
 convlab2/base_models/gpt/run_clm.py           |  94 ++++++++-----
 6 files changed, 273 insertions(+), 38 deletions(-)
 create mode 100644 convlab2/base_models/gpt/keyword_extraction/get_keywords.sh
 create mode 100644 convlab2/base_models/gpt/keyword_extraction/get_word_loss.sh
 create mode 100644 convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
 rename convlab2/base_models/gpt/keyword_extraction/{run.sh => train_lm.sh} (90%)

diff --git a/convlab2/base_models/gpt/create_data.py b/convlab2/base_models/gpt/create_data.py
index 94e88f3e..dd616b59 100644
--- a/convlab2/base_models/gpt/create_data.py
+++ b/convlab2/base_models/gpt/create_data.py
@@ -16,7 +16,7 @@ def create_lm_data(dataset, data_dir, args):
             if args.model_type == 'dialogpt':
                 dialogue = ' <|endoftext|> '.join([turn['utterance'] for turn in sample['turns']]) + ' <|endoftext|>'
             else:
-                dialogue = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['turns']])
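+                # one "speaker: utterance" line per turn, e.g. "user: I need a hotel"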
+                dialogue = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['turns']])
             data.append(json.dumps({'dialogue': dialogue}, ensure_ascii=False)+'\n')
 
         file_name = os.path.join(data_dir, f"{data_split}.json")
@@ -35,5 +35,5 @@ if __name__ == '__main__':
     for dataset_name in tqdm(args.datasets, desc='datasets'):
         dataset = load_dataset(dataset_name)
         for task_name in tqdm(args.tasks, desc='tasks', leave=False):
-            data_dir = os.path.join('data', task_name, dataset_name)
+            data_dir = os.path.join('data', task_name, dataset_name, args.model_type)
             eval(f"create_{task_name}_data")(dataset, data_dir, args)
diff --git a/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh b/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh
new file mode 100644
index 00000000..6dd2680b
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/get_keywords.sh
@@ -0,0 +1,20 @@
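+# Rank the words of each utterance by LM loss (see lmloss2keywords.py) and
+# keep the top ones as keywords, bounded by keywords_num, keywords_ratio and
+# the loss threshold keywords_th.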
+model_type=dialogpt
+dataset_name=multiwoz21
+model_name=dialogpt-large
+data_dir="data/lm/${dataset_name}/${model_type}"
+word_loss_file="${data_dir}/${model_name}_${dataset_name}_word_loss.json"
+keywords_num=5
+keywords_ratio=1
+keywords_th=0
+stopwords=True
+output_file="${data_dir}/${dataset_name}_keywords_${model_name}_topk_${keywords_num}_ratio_${keywords_ratio}_th_${keywords_th}_stopwords_${stopwords}.json"
+
+python lmloss2keywords.py \
+    --model_type ${model_type} \
+    --word_loss_file ${word_loss_file} \
+    --keywords_num ${keywords_num} \
+    --keywords_ratio ${keywords_ratio} \
+    --keywords_th ${keywords_th} \
+    --stopwords ${stopwords} \
+    --output_file ${output_file}
diff --git a/convlab2/base_models/gpt/keyword_extraction/get_word_loss.sh b/convlab2/base_models/gpt/keyword_extraction/get_word_loss.sh
new file mode 100644
index 00000000..2aad467c
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/get_word_loss.sh
@@ -0,0 +1,65 @@
+set -e
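+# Dump per-token LM losses on the validation set for three models (pretrained
+# DialoGPT-large, the DialoGPT-large fine-tuned by train_lm.sh, and pretrained
+# GPT2-large), then merge them into per-word losses with lmloss2keywords.py.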
+n_gpus=1
+task_name="lm"
+dataset_name="multiwoz21"
+model_type="dialogpt"
+data_dir="data/${task_name}/${dataset_name}/${model_type}"
+output_dir="output/${task_name}/${dataset_name}/${model_type}"
+cache_dir="../cache"
+validation_file="${data_dir}/validation.json"
+source_column="dialogue"
+max_length=512
+model_name_or_path="microsoft/DialoGPT-large"
+per_device_eval_batch_size=4
+
+dump_eval_loss_to="${data_dir}/dialogpt-large_${dataset_name}_token_loss.json"
+python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type dialogpt
+python ../run_clm.py \
+    --dump_eval_loss_to ${dump_eval_loss_to} \
+    --model_name_or_path ${model_name_or_path} \
+    --output_dir ${data_dir} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --max_length ${max_length} \
+    --do_eval \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --preprocessing_num_workers 4 \
+    --per_device_eval_batch_size ${per_device_eval_batch_size}
+python lmloss2keywords.py --token_loss_file ${dump_eval_loss_to} --model_type ${model_type}
+
+dump_eval_loss_to="${data_dir}/dialogpt-large-mwoz_${dataset_name}_token_loss.json"
+# the dialogpt data was already created above, so it is not recreated here
+python ../run_clm.py \
+    --dump_eval_loss_to ${dump_eval_loss_to} \
+    --model_name_or_path ${output_dir} \
+    --output_dir ${data_dir} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --max_length ${max_length} \
+    --do_eval \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --preprocessing_num_workers 4 \
+    --per_device_eval_batch_size ${per_device_eval_batch_size}
+python lmloss2keywords.py --token_loss_file ${dump_eval_loss_to} --model_type ${model_type}
+
+model_type="gpt"
+data_dir="data/${task_name}/${dataset_name}/${model_type}"
+validation_file="${data_dir}/validation.json"
+model_name_or_path="gpt2-large"
+dump_eval_loss_to="${data_dir}/gpt2-large_${dataset_name}_token_loss.json"
+python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type gpt
+python ../run_clm.py \
+    --dump_eval_loss_to ${dump_eval_loss_to} \
+    --model_name_or_path ${model_name_or_path} \
+    --output_dir ${data_dir} \
+    --validation_file ${validation_file} \
+    --source_column ${source_column} \
+    --max_length ${max_length} \
+    --do_eval \
+    --prediction_loss_only \
+    --cache_dir ${cache_dir} \
+    --preprocessing_num_workers 4 \
+    --per_device_eval_batch_size ${per_device_eval_batch_size}
+python lmloss2keywords.py --token_loss_file ${dump_eval_loss_to} --model_type ${model_type}
diff --git a/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py b/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
new file mode 100644
index 00000000..ab9126ba
--- /dev/null
+++ b/convlab2/base_models/gpt/keyword_extraction/lmloss2keywords.py
@@ -0,0 +1,123 @@
+import json
+import json_lines
+from pprint import pprint
+import os
+from tqdm import tqdm
+import numpy as np
+from nltk.corpus import stopwords
+from nltk.tokenize import word_tokenize
+
+
+def merge_tokens(tokens, losses, loss_merge_func=np.mean):
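+    """Merge GPT-2 BPE tokens back into words and merge their losses.
+
+    A leading 'Ġ' (space marker) starts a new word and 'Ċ' is a newline;
+    tokens without a marker are appended to the previous word. "user:" and
+    "system:" speaker prefixes are dropped at the start of the dialogue and
+    replaced by <|endoftext|> turn separators elsewhere. Each word's token
+    losses are finally reduced with loss_merge_func.
+    """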
+    res = []
+    i = 0
+    while i < len(tokens):
+        token = tokens[i]
+        loss = losses[i]
+        if token in ['Ġ', 'Ċ']:
+            if token == 'Ċ' and i < len(tokens) - 1:
+                tokens[i+1] = 'Ġ'+tokens[i+1]
+            i += 1
+            continue
+        if token.lstrip('Ġ') in ['user', 'system'] and i < len(tokens)-1 and tokens[i+1] == ':':
+            if i > 0:
+                tokens[i+1] = '<|endoftext|>'
+                i += 1
+            else:
+                i += 2
+            continue
+        if token.startswith('Ġ'):
+            # Ġ means space
+            token = token.replace("Ġ", "")
+            res.append([token, loss])
+        elif token == '<|endoftext|>':
+            res.append([token, loss])
+        else:
+            assert 'Ġ' not in token
+            if len(res) > 0:
+                res[-1][0] += token
+                res[-1].append(loss)
+            else:
+                res.append([token, loss])
+        i += 1
+    if loss_merge_func:
+        for i in range(len(res)):
+            res[i] = [res[i][0], loss_merge_func(res[i][1:])]
+    return res
+
+
+def convert_token_loss2word_loss(token_loss_file, loss_merge_func=np.mean):
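+    """Convert a jsonlines token-loss file ({"tokens": [...], "losses": [...]})
+    into a word-loss file ({"words": [...], "losses": [...]}) in the same
+    directory and return the merged lines."""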
+    word_loss_file = os.path.join(os.path.dirname(token_loss_file), os.path.basename(token_loss_file).replace('token', 'word'))
+    fin = open(token_loss_file, 'rb')
+    fout = open(word_loss_file, 'w', encoding='utf-8')
+    lines = []
+
+    for item in tqdm(json_lines.reader(fin)):
+        tokens, losses = item['tokens'], item['losses']
+        assert len(tokens) == len(losses)
+        word2losses = merge_tokens(tokens, losses, loss_merge_func)
+        lines.append({"words": [x[0] for x in word2losses], "losses": [x[1] for x in word2losses]})
+        fout.write(json.dumps(lines[-1], ensure_ascii=False)+'\n')
+
+    fin.close()
+    fout.close()
+    return lines
+
+def main(args):
+    if not args.word_loss_file:
+        word_loss_list = convert_token_loss2word_loss(args.token_loss_file)
+    else:
+        fin = open(args.word_loss_file, 'rb')
+        word_loss_list = []
+        for item in json_lines.reader(fin):
+            words, losses = item['words'], item['losses']
+            word_loss_list.append({"words": words, "losses": losses})
+        fin.close()
+
+    if not args.output_file:
+        return
+
+    stop_words = set(stopwords.words('english'))
+
+    dialogs = []
+    for item in word_loss_list:
+        words = item['words']
+        losses = item['losses']
+        turns = []
+        turn = {'words': [], 'losses': []}
+        for word, loss in zip(words, losses):
+            if word == '<|endoftext|>':
+                # switch turn
+                turn['utterance'] = ' '.join(turn['words'])
+                turn['keywords'] = list(zip(turn['words'], turn['losses']))
+                if args.stopwords:
+                    turn['keywords'] = [x for x in turn['keywords'] if not any([w.lower() in stop_words for w in word_tokenize(x[0])])]
+                turn['keywords'] = sorted(turn['keywords'], key=lambda x: x[1], reverse=True)
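+                # keep only words whose loss exceeds keywords_th, capped both by
+                # keywords_num and by keywords_ratio of the utterance length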
+                turn['keywords'] = [x for x in turn['keywords'] if x[1] > args.keywords_th][:min(round(args.keywords_ratio*len(turn['keywords'])), args.keywords_num)]
+                turn.pop('words')
+                turn.pop('losses')
+                turns.append(turn)
+                turn = {'words': [], 'losses': []}
+            else:
+                turn['words'].append(word)
+                turn['losses'].append(loss)
+        dialogs.append(turns)
+    json.dump(dialogs, open(args.output_file, "w", encoding='utf-8'), indent=2, ensure_ascii=False)
+
+
+if __name__ == '__main__':
+    from argparse import ArgumentParser
+    parser = ArgumentParser(description="extract keywords from dialogues according to the language model loss of each word")
+    parser.add_argument('--model_type', '-m', type=str, help='gpt or dialogpt')
+    parser.add_argument('--token_loss_file', '-t', type=str, help='path to the token loss file that contains two columns: [tokens, losses]')
+    parser.add_argument('--word_loss_file', '-w', type=str, help='path to the word loss file that contains two columns: [words, losses]')
+    parser.add_argument('--output_file', '-o', type=str, help='path to the output file')
+    parser.add_argument('--keywords_num', '-n', type=int, default=100, help='maximum number of keywords in an utterance')
+    parser.add_argument('--keywords_ratio', '-r', type=float, default=1.0, help='maximum ratio of words in an utterance that serve as keywords')
+    parser.add_argument('--keywords_th', '-th', type=float, default=0., help='loss threshold for the keywords')
+    parser.add_argument('--stopwords', '-s', type=lambda x: x.lower() == 'true', default=True, help='whether to filter out stopwords (True/False)')
+    args = parser.parse_args()
+    print(args)
+    main(args)
diff --git a/convlab2/base_models/gpt/keyword_extraction/run.sh b/convlab2/base_models/gpt/keyword_extraction/train_lm.sh
similarity index 90%
rename from convlab2/base_models/gpt/keyword_extraction/run.sh
rename to convlab2/base_models/gpt/keyword_extraction/train_lm.sh
index 85d87296..4ae47c32 100644
--- a/convlab2/base_models/gpt/keyword_extraction/run.sh
+++ b/convlab2/base_models/gpt/keyword_extraction/train_lm.sh
@@ -2,8 +2,9 @@ set -e
 n_gpus=1
 task_name="lm"
 dataset_name="multiwoz21"
-data_dir="data/${task_name}/${dataset_name}"
-output_dir="output/${task_name}/${dataset_name}"
+model_type="dialogpt"
+data_dir="data/${task_name}/${dataset_name}/${model_type}"
+output_dir="output/${task_name}/${dataset_name}/${model_type}"
 cache_dir="../cache"
 logging_dir="${output_dir}/runs"
 train_file="${data_dir}/train.json"
diff --git a/convlab2/base_models/gpt/run_clm.py b/convlab2/base_models/gpt/run_clm.py
index 95e020d7..9dff4a0a 100644
--- a/convlab2/base_models/gpt/run_clm.py
+++ b/convlab2/base_models/gpt/run_clm.py
@@ -30,7 +30,11 @@ from itertools import chain
 from typing import Optional
 
 import datasets
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+import torch
+import json
 
 import transformers
 from transformers import (
@@ -156,6 +160,9 @@ class DataTrainingArguments:
             "help": "An optional input evaluation data file to evaluate the metrics on (a text, jsonlines or csv file)."
         },
     )
+    dump_eval_loss_to: Optional[str] = field(
+        default=None, metadata={"help": "Where to dump the per-token losses on the evaluation data; by default no dump is written"}
+    )
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
@@ -484,23 +491,6 @@ def main():
         pad_to_multiple_of=8 if training_args.fp16 else None,
     )
 
-    def preprocess_logits_for_metrics(logits, labels):
-        if isinstance(logits, tuple):
-            # Depending on the model and config, logits may contain extra tensors,
-            # like past_key_values, but logits always come first
-            logits = logits[0]
-        return logits.argmax(dim=-1)
-
-    metric = load_metric("accuracy")
-
-    def compute_metrics(eval_preds):
-        preds, labels = eval_preds
-        # preds have the same shape as the labels, after the argmax(-1) has been calculated
-        # by preprocess_logits_for_metrics but we need to shift the labels
-        labels = labels[:, 1:].reshape(-1)
-        preds = preds[:, :-1].reshape(-1)
-        return metric.compute(predictions=preds, references=labels)
-
     # Initialize our Trainer
     trainer = Trainer(
         model=model,
@@ -509,11 +499,7 @@ def main():
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         # Data collator will default to DataCollatorWithPadding, so we change it.
-        data_collator=data_collator,
-        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
-        preprocess_logits_for_metrics=preprocess_logits_for_metrics
-        if training_args.do_eval and not is_torch_tpu_available()
-        else None,
+        data_collator=data_collator
     )
 
     # Training
@@ -539,17 +525,57 @@ def main():
     # Evaluation
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        metrics = trainer.evaluate(metric_key_prefix="eval")
-        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
-        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
-        try:
-            perplexity = math.exp(metrics["eval_loss"])
-        except OverflowError:
-            perplexity = float("inf")
-        metrics["eval_perplexity"] = perplexity
-
-        trainer.log_metrics("eval", metrics)
-        trainer.save_metrics("eval", metrics)
+        if not data_args.dump_eval_loss_to:
+            metrics = trainer.evaluate(metric_key_prefix="eval")
+            max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+            try:
+                perplexity = math.exp(metrics["eval_loss"])
+            except OverflowError:
+                perplexity = float("inf")
+            metrics["eval_perplexity"] = perplexity
+            logger.info(f"eval_perplexity: {perplexity}")
+
+            trainer.log_metrics("eval", metrics)
+            trainer.save_metrics("eval", metrics)
+        else:
+            if trainer.is_world_process_zero():
+                output_prediction_file = data_args.dump_eval_loss_to
+                writer = open(output_prediction_file, "w", encoding='utf-8')
+
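+                # iterate over the eval set manually so that per-token losses can
+                # be dumped; the collate_fn also moves each batch to the model device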
+                eval_dataloader = DataLoader(
+                    eval_dataset, collate_fn=lambda x: {k: v.to(model.device) for k, v in data_collator(x).items()}, batch_size=training_args.per_device_eval_batch_size
+                )
+                model.eval()
+                losses = []
+                loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
+                for batch in tqdm(eval_dataloader):
+                    with torch.no_grad():
+                        outputs = model(**batch)
+
+                    loss = outputs.loss
+                    losses.append(loss.repeat(training_args.per_device_eval_batch_size))
+
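+                    # logits at position t predict the token at t+1, so shift
+                    # logits and labels before the unreduced cross-entropy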
+                    shift_logits = outputs.logits[..., :-1, :].contiguous()
+                    shift_labels = batch['labels'][..., 1:].contiguous()
+                    batch_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+                    batch_token_loss = batch_token_loss.view(shift_labels.size()).tolist()
+                    labels = batch['labels'].tolist()
+                    for i in range(len(labels)):
+                        token_ids = [x for x in labels[i] if x != -100]
+                        tokens = tokenizer.convert_ids_to_tokens(token_ids)
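+                        # the first token is never predicted, so it is logged with loss 0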
+                        token_losses = [0] + batch_token_loss[i][:len(token_ids)-1]
+                        writer.write(json.dumps({"tokens": tokens, "losses": token_losses}, ensure_ascii=False)+'\n')
+
+                losses = torch.cat(losses)
+                losses = losses[: len(eval_dataset)]
+                try:
+                    perplexity = math.exp(torch.mean(losses))
+                except OverflowError:
+                    perplexity = float("inf")
+                logger.info(f"perplexity: {perplexity}")
+
+                writer.close()
 
     kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
     if data_args.dataset_name is not None:
-- 
GitLab