From 76c1d32e6d75cb73de20f60c101b9a223a45dfbc Mon Sep 17 00:00:00 2001
From: zqwerty <zhuq96@hotmail.com>
Date: Fri, 15 Jul 2022 15:52:28 +0800
Subject: [PATCH] update get token loss method for multi-gpu inference

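Move the per-token loss dump from run_clm.py into a custom
DumpTokenLossTrainer so that evaluation runs on all GPUs and only the
main process writes the gathered token losses, instead of re-running the
whole evaluation on a single GPU. keyword_extraction/get_token_loss.sh
launches this evaluation with torch.distributed.launch for every dataset
and data split, and create_data.py now writes to
data/<task>/<model_type>/<dataset> to match the paths used by the script.
Each line of the dump file is a JSON object
{"tokens": [...], "losses": [...]}; the first token of an example gets a
loss of 0 because it has no left context.

A minimal sketch of reading a dump back (the path below just follows the
naming scheme used in get_token_loss.sh):

    import json

    path = "data/lm/gpt/dailydialog/token_loss_validation.json"
    with open(path, encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)  # {"tokens": [...], "losses": [...]}
            assert len(record["tokens"]) == len(record["losses"])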
---
 convlab/base_models/gpt/create_data.py        |   2 +-
 .../gpt/keyword_extraction/get_token_loss.sh  |  35 +++
 convlab/base_models/gpt/run_clm.py            |  73 ++----
 convlab/base_models/gpt/trainer.py            | 243 ++++++++++++++++++
 4 files changed, 297 insertions(+), 56 deletions(-)
 create mode 100644 convlab/base_models/gpt/keyword_extraction/get_token_loss.sh
 create mode 100644 convlab/base_models/gpt/trainer.py

diff --git a/convlab/base_models/gpt/create_data.py b/convlab/base_models/gpt/create_data.py
index c9e6e9f0..e6c4d67b 100644
--- a/convlab/base_models/gpt/create_data.py
+++ b/convlab/base_models/gpt/create_data.py
@@ -35,5 +35,5 @@ if __name__ == '__main__':
     for dataset_name in tqdm(args.datasets, desc='datasets'):
         dataset = load_dataset(dataset_name)
         for task_name in tqdm(args.tasks, desc='tasks', leave=False):
-            data_dir = os.path.join('data', task_name, dataset_name, args.model_type)
+            data_dir = os.path.join('data', task_name, args.model_type, dataset_name)
             eval(f"create_{task_name}_data")(dataset, data_dir, args)
diff --git a/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh b/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh
new file mode 100644
index 00000000..4c73a3b0
--- /dev/null
+++ b/convlab/base_models/gpt/keyword_extraction/get_token_loss.sh
@@ -0,0 +1,35 @@
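+# Dump per-token LM losses for each dialogue dataset and data split using
+# multi-GPU inference (torch.distributed.launch + DumpTokenLossTrainer).
+# create_data.py prepares the LM data once per dataset, then run_clm.py
+# appends one JSON object per example to ${data_dir}/token_loss_${data_split}.json.
+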
+n_gpus=4
+master_port=23456
+task_name="lm"
+model_type="gpt"
+cache_dir="../cache"
+source_column="dialogue"
+max_length=512
+model_name_or_path="/data/zhuqi/pre-trained-models/gpt2-large"
+per_device_eval_batch_size=16
+
+for dataset_name in dailydialog metalwoz tm1 tm2 tm3 sgd reddit wikidialog
+do
+    data_dir="data/${task_name}/${model_type}/${dataset_name}"
+    output_dir="output/${task_name}/${model_type}/${dataset_name}"
+
+    python ../create_data.py --tasks ${task_name} --datasets ${dataset_name} --model_type ${model_type}
+    for data_split in validation train
+    do
+        validation_file="${data_dir}/${data_split}.json"
+        dump_eval_loss_to="${data_dir}/token_loss_${data_split}.json"
+        rm -f ${dump_eval_loss_to}
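+        # multi-GPU evaluation; only the main process writes the gathered token losses to ${dump_eval_loss_to}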
+        python -m torch.distributed.launch --master_port ${master_port} \
+            --nproc_per_node ${n_gpus} ../run_clm.py \
+            --dump_eval_loss_to ${dump_eval_loss_to} \
+            --model_name_or_path ${model_name_or_path} \
+            --output_dir ${data_dir} \
+            --validation_file ${validation_file} \
+            --source_column ${source_column} \
+            --max_length ${max_length} \
+            --do_eval \
+            --cache_dir ${cache_dir} \
+            --preprocessing_num_workers 4 \
+            --per_device_eval_batch_size ${per_device_eval_batch_size}
+    done
+done
diff --git a/convlab/base_models/gpt/run_clm.py b/convlab/base_models/gpt/run_clm.py
index 9dff4a0a..ace68609 100644
--- a/convlab/base_models/gpt/run_clm.py
+++ b/convlab/base_models/gpt/run_clm.py
@@ -44,7 +44,6 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     HfArgumentParser,
-    Trainer,
     TrainingArguments,
     DataCollatorForTokenClassification,
     is_torch_tpu_available,
@@ -53,6 +52,7 @@ from transformers import (
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
+from convlab.base_models.gpt.trainer import DumpTokenLossTrainer
 
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
@@ -491,15 +491,17 @@ def main():
         pad_to_multiple_of=8 if training_args.fp16 else None,
     )
 
+    training_args.dump_eval_loss_to = data_args.dump_eval_loss_to
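+    # (read by DumpTokenLossTrainer.evaluation_loop)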
+
     # Initialize our Trainer
-    trainer = Trainer(
+    trainer = DumpTokenLossTrainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset if training_args.do_train else None,
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         # Data collator will default to DataCollatorWithPadding, so we change it.
-        data_collator=data_collator
+        data_collator=data_collator,
     )
 
     # Training
@@ -525,58 +527,19 @@ def main():
     # Evaluation
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        if not data_args.dump_eval_loss_to:
-            metrics = trainer.evaluate(metric_key_prefix="eval")
-            max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
-            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
-            try:
-                perplexity = math.exp(metrics["eval_loss"])
-            except OverflowError:
-                perplexity = float("inf")
-            metrics["eval_perplexity"] = perplexity
-            logger.info(f"eval_perplexity: {perplexity}")
-
-            trainer.log_metrics("eval", metrics)
-            trainer.save_metrics("eval", metrics)
-        else:
-            if trainer.is_world_process_zero():
-                output_prediction_file = data_args.dump_eval_loss_to
-                writer = open(output_prediction_file, "w", encoding='utf-8')
-
-                eval_dataloader = DataLoader(
-                    eval_dataset, collate_fn=lambda x: {k: v.to(model.device) for k, v in data_collator(x).items()}, batch_size=training_args.per_device_eval_batch_size
-                )
-                model.eval()
-                losses = []
-                loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
-                for batch in tqdm(eval_dataloader):
-                    with torch.no_grad():
-                        outputs = model(**batch)
-
-                    loss = outputs.loss
-                    losses.append(loss.repeat(training_args.per_device_eval_batch_size))
-                    
-                    shift_logits = outputs.logits[..., :-1, :].contiguous()
-                    shift_labels = batch['labels'][..., 1:].contiguous()
-                    batch_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-                    batch_token_loss = batch_token_loss.view(shift_labels.size()).tolist()
-                    labels = batch['labels'].tolist()
-                    for i in range(len(labels)):
-                        token_ids = [x for x in labels[i] if x != -100]
-                        tokens = tokenizer.convert_ids_to_tokens(token_ids)
-                        token_losses = [0] + batch_token_loss[i][:len(token_ids)-1]
-                        writer.write(json.dumps({"tokens": tokens, "losses": token_losses}, ensure_ascii=False)+'\n')
-
-                losses = torch.cat(losses)
-                losses = losses[: len(eval_dataset)]
-                try:
-                    perplexity = math.exp(torch.mean(losses))
-                except OverflowError:
-                    perplexity = float("inf")
-                logger.info(f"perplexity: {perplexity}")
-
-                writer.close()
-
+        metrics = trainer.evaluate(metric_key_prefix="eval")
+        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
+        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
+        try:
+            perplexity = math.exp(metrics["eval_loss"])
+        except OverflowError:
+            perplexity = float("inf")
+        metrics["eval_perplexity"] = perplexity
+        logger.info(f"eval_perplexity: {perplexity}")
+
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+
     kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
     if data_args.dataset_name is not None:
         kwargs["dataset_tags"] = data_args.dataset_name
diff --git a/convlab/base_models/gpt/trainer.py b/convlab/base_models/gpt/trainer.py
new file mode 100644
index 00000000..5a8ed11c
--- /dev/null
+++ b/convlab/base_models/gpt/trainer.py
@@ -0,0 +1,243 @@
+from transformers import Trainer
+from transformers.trainer_utils import EvalLoopOutput, has_length
+from transformers.deepspeed import deepspeed_init
+from transformers.utils import logging
+from transformers.trainer_pt_utils import find_batch_size, nested_concat, nested_numpify, IterableDatasetShard, nested_truncate
+from transformers.trainer_utils import EvalPrediction, denumpify_detensorize
+import torch
+from torch.utils.data import DataLoader
+import numpy as np
+from typing import List, Optional
+import json
+
+
+logger = logging.get_logger(__name__)
+
+class DumpTokenLossTrainer(Trainer):
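+    """Trainer whose evaluation loop can additionally dump per-token LM losses.
+
+    When `args.dump_eval_loss_to` is set, the main process writes one JSON line
+    per evaluation example, {"tokens": [...], "losses": [...]}, while the usual
+    gathered eval loss (and hence the perplexity in run_clm.py) is still computed.
+    """
+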
+    def evaluation_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[List[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> EvalLoopOutput:
+        """
+        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+        Works both with or without labels.
+        """
+        args = self.args
+
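+        # use the TrainingArguments value (default False) instead of the argument passed in by
+        # `evaluate()` (which is True when no compute_metrics is set), so that prediction_step
+        # returns logits and labels for the token-loss dump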
+        prediction_loss_only = args.prediction_loss_only
+
+        # if eval is called w/o train init deepspeed here
+        if args.deepspeed and not self.deepspeed:
+
+            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
+            # from the checkpoint eventually
+            deepspeed_engine, _, _ = deepspeed_init(
+                self, num_training_steps=0, resume_from_checkpoint=None, inference=True
+            )
+            self.model = deepspeed_engine.module
+            self.model_wrapped = deepspeed_engine
+            self.deepspeed = deepspeed_engine
+
+        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
+        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+        # while ``train`` is running, cast it to the right dtype first and then put on device
+        if not self.is_in_train:
+            if args.fp16_full_eval:
+                model = model.to(dtype=torch.float16, device=args.device)
+            elif args.bf16_full_eval:
+                model = model.to(dtype=torch.bfloat16, device=args.device)
+
+        batch_size = self.args.eval_batch_size
+
+        logger.info(f"***** Running {description} *****")
+        if has_length(dataloader):
+            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
+        else:
+            logger.info("  Num examples: Unknown")
+        logger.info(f"  Batch size = {batch_size}")
+
+        model.eval()
+
+        self.callback_handler.eval_dataloader = dataloader
+        # Do this before wrapping.
+        eval_dataset = getattr(dataloader, "dataset", None)
+
+        if args.past_index >= 0:
+            self._past = None
+
+        # Initialize containers
+        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
+        losses_host = None
+        preds_host = None
+        labels_host = None
+        inputs_host = None
+
+        # losses/preds/labels on CPU (final containers)
+        all_losses = None
+        all_preds = None
+        all_labels = None
+        all_inputs = None
+        # Will be useful when we have an iterable dataset so don't know its length.
+
+        if args.dump_eval_loss_to:
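+            # append mode: get_token_loss.sh removes stale dumps before launching, and appending
+            # preserves earlier lines if evaluate() is called more than once in the same run;
+            # reduction='none' keeps one cross-entropy value per token instead of averaging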
+            writer = open(args.dump_eval_loss_to, "a", encoding='utf-8')
+            loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
+            num_sample_to_write = len(eval_dataset)
+
+        observed_num_examples = 0
+        # Main evaluation loop
+        for step, inputs in enumerate(dataloader):
+            # Update the observed num examples
+            observed_batch_size = find_batch_size(inputs)
+            if observed_batch_size is not None:
+                observed_num_examples += observed_batch_size
+                # For batch samplers, batch_size is not known by the dataloader in advance.
+                if batch_size is None:
+                    batch_size = observed_batch_size
+
+            # Prediction step
+            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
+
+            # Update containers on host
+            if loss is not None:
+                losses = self._nested_gather(loss.repeat(batch_size))
+                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+            if labels is not None:
+                labels = self._pad_across_processes(labels)
+                labels = self._nested_gather(labels)
+                # labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
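+                # labels (and logits below) are gathered only for the per-batch token-loss dump;
+                # accumulating them across batches is disabled to keep memory bounded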
+            if inputs_decode is not None:
+                inputs_decode = self._pad_across_processes(inputs_decode)
+                inputs_decode = self._nested_gather(inputs_decode)
+                inputs_host = (
+                    inputs_decode
+                    if inputs_host is None
+                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
+                )
+            if logits is not None:
+                logits = self._pad_across_processes(logits)
+                logits = self._nested_gather(logits)
+                if self.preprocess_logits_for_metrics is not None:
+                    logits = self.preprocess_logits_for_metrics(logits, labels)
+                # preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+
+            if args.dump_eval_loss_to:
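+                # only the main process writes, and only the first len(eval_dataset) examples,
+                # so duplicates introduced by the distributed sampler's padding are skipped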
+                if self.is_world_process_zero() and num_sample_to_write > 0:
+                    assert logits is not None and labels is not None, \
+                        f"token-loss dump requires logits and labels (prediction_loss_only={prediction_loss_only})"
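+                    # causal LM shift: the logits at position i predict the label at position i+1,
+                    # so the first token of each example has no loss (written as 0 below)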
+                    shift_logits = logits[..., :-1, :].contiguous()
+                    shift_labels = labels[..., 1:].contiguous()
+                    batch_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+                    batch_token_loss = batch_token_loss.view(shift_labels.size()).tolist()
+                    labels = labels.tolist()
+                    for i in range(len(labels)):
+                        if num_sample_to_write > 0:
+                            num_sample_to_write -= 1
+                        else:
+                            break
+                        token_ids = [x for x in labels[i] if x != -100]
+                        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
+                        token_losses = [0] + batch_token_loss[i][:len(token_ids)-1]
+                        writer.write(json.dumps({"tokens": tokens, "losses": token_losses}, ensure_ascii=False)+'\n')
+
+            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+                if losses_host is not None:
+                    losses = nested_numpify(losses_host)
+                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+                if preds_host is not None:
+                    logits = nested_numpify(preds_host)
+                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+                if inputs_host is not None:
+                    inputs_decode = nested_numpify(inputs_host)
+                    all_inputs = (
+                        inputs_decode
+                        if all_inputs is None
+                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+                    )
+                if labels_host is not None:
+                    labels = nested_numpify(labels_host)
+                    all_labels = (
+                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+                    )
+
+                # Set back to None to begin a new accumulation
+                losses_host, preds_host, inputs_host, labels_host = None, None, None, None
+
+        if args.dump_eval_loss_to:
+            writer.close()
+
+        if args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of the evaluation loop
+            delattr(self, "_past")
+
+        # Gather all remaining tensors and put them back on the CPU
+        if losses_host is not None:
+            losses = nested_numpify(losses_host)
+            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+        if preds_host is not None:
+            logits = nested_numpify(preds_host)
+            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+        if inputs_host is not None:
+            inputs_decode = nested_numpify(inputs_host)
+            all_inputs = (
+                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+            )
+        if labels_host is not None:
+            labels = nested_numpify(labels_host)
+            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+
+        # Number of samples
+        if has_length(eval_dataset):
+            num_samples = len(eval_dataset)
+        # The instance check is weird and does not actually check for the type, but whether the dataset has the right
+        # methods. Therefore we need to make sure it also has the attribute.
+        elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
+            num_samples = eval_dataset.num_examples
+        else:
+            if has_length(dataloader):
+                num_samples = self.num_examples(dataloader)
+            else:  # both len(dataloader.dataset) and len(dataloader) fail
+                num_samples = observed_num_examples
+
+        # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of
+        # samplers has been rounded to a multiple of batch_size, so we truncate.
+        if all_losses is not None:
+            all_losses = all_losses[:num_samples]
+        if all_preds is not None:
+            all_preds = nested_truncate(all_preds, num_samples)
+        if all_labels is not None:
+            all_labels = nested_truncate(all_labels, num_samples)
+        if all_inputs is not None:
+            all_inputs = nested_truncate(all_inputs, num_samples)
+
+        # Metrics!
+        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
+            if args.include_inputs_for_metrics:
+                metrics = self.compute_metrics(
+                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
+                )
+            else:
+                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
+        else:
+            metrics = {}
+
+        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+        metrics = denumpify_detensorize(metrics)
+
+        if all_losses is not None:
+            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
+
+        # Prefix all keys with metric_key_prefix + '_'
+        for key in list(metrics.keys()):
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
-- 
GitLab