trainer.py

    from transformers import Trainer
    from transformers.trainer_utils import EvalLoopOutput, has_length
    from transformers.deepspeed import deepspeed_init
    from transformers.utils import logging
    from transformers.trainer_pt_utils import find_batch_size, nested_concat, nested_numpify, IterableDatasetShard, nested_truncate
    from transformers.trainer_utils import EvalPrediction, denumpify_detensorize
    import torch
    from torch.utils.data import DataLoader
    import numpy as np
    from typing import List, Optional
    import json
    
    
    logger = logging.get_logger(__name__)
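
    # This subclass follows the structure of `transformers.Trainer.evaluation_loop` from an
    # older transformers 4.x release (note the `transformers.deepspeed` import above) and adds
    # a per-token loss dump driven by a custom `dump_eval_loss_to` training argument, which is
    # expected to be defined on the `TrainingArguments` instance passed in.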
    
    class DumpTokenLossTrainer(Trainer):
        def evaluation_loop(
            self,
            dataloader: DataLoader,
            description: str,
            prediction_loss_only: Optional[bool] = None,
            ignore_keys: Optional[List[str]] = None,
            metric_key_prefix: str = "eval",
        ) -> EvalLoopOutput:
            """
            Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
            Works both with or without labels.
            """
            args = self.args
    
            # Unlike the base Trainer, take `prediction_loss_only` from the args unconditionally:
            # the token-loss dump below needs logits and labels, so the caller-supplied value is ignored.
            prediction_loss_only = args.prediction_loss_only
    
            # If evaluation is called without a prior call to train(), DeepSpeed is initialized here.
            if args.deepspeed and not self.deepspeed:
    
                # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
                # from the checkpoint eventually
                deepspeed_engine, _, _ = deepspeed_init(
                    self, num_training_steps=0, resume_from_checkpoint=None, inference=True
                )
                self.model = deepspeed_engine.module
                self.model_wrapped = deepspeed_engine
                self.deepspeed = deepspeed_engine
    
            model = self._wrap_model(self.model, training=False, dataloader=dataloader)
    
            # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
            # while ``train`` is running, cast it to the right dtype first and then put on device
            if not self.is_in_train:
                if args.fp16_full_eval:
                    model = model.to(dtype=torch.float16, device=args.device)
                elif args.bf16_full_eval:
                    model = model.to(dtype=torch.bfloat16, device=args.device)
    
            batch_size = self.args.eval_batch_size
    
            logger.info(f"***** Running {description} *****")
            if has_length(dataloader):
                logger.info(f"  Num examples = {self.num_examples(dataloader)}")
            else:
                logger.info("  Num examples: Unknown")
            logger.info(f"  Batch size = {batch_size}")
    
            model.eval()
    
            self.callback_handler.eval_dataloader = dataloader
            # Do this before wrapping.
            eval_dataset = getattr(dataloader, "dataset", None)
    
            if args.past_index >= 0:
                self._past = None
    
            # Initialize containers
            # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
            losses_host = None
            preds_host = None
            labels_host = None
            inputs_host = None
    
            # losses/preds/labels on CPU (final containers)
            all_losses = None
            all_preds = None
            all_labels = None
            all_inputs = None
            # Useful when the dataset is iterable, so its length is not known in advance.
    
            if args.dump_eval_loss_to:
                # Opened in append mode: repeated evaluations extend the file rather than overwrite it.
                writer = open(args.dump_eval_loss_to, "a", encoding='utf-8')
                # No reduction, so a loss value is kept for every individual token.
                loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
                # Guard against datasets without a usable length (e.g. iterable datasets).
                num_sample_to_write = len(eval_dataset) if has_length(eval_dataset) else float("inf")
    
            observed_num_examples = 0
            # Main evaluation loop
            for step, inputs in enumerate(dataloader):
                # Update the observed num examples
                observed_batch_size = find_batch_size(inputs)
                if observed_batch_size is not None:
                    observed_num_examples += observed_batch_size
                    # For batch samplers, batch_size is not known by the dataloader in advance.
                    if batch_size is None:
                        batch_size = observed_batch_size
    
                # Prediction step
                loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
                inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
    
                # Update containers on host
                if loss is not None:
                    losses = self._nested_gather(loss.repeat(batch_size))
                    losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
                if labels is not None:
                    labels = self._pad_across_processes(labels)
                    labels = self._nested_gather(labels)
                    # Unlike the base Trainer, labels are not accumulated across batches; they are
                    # only needed per batch for the token-loss dump, which keeps memory bounded.
                    # labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
                if inputs_decode is not None:
                    inputs_decode = self._pad_across_processes(inputs_decode)
                    inputs_decode = self._nested_gather(inputs_decode)
                    inputs_host = (
                        inputs_decode
                        if inputs_host is None
                        else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                    )
                if logits is not None:
                    logits = self._pad_across_processes(logits)
                    logits = self._nested_gather(logits)
                    if self.preprocess_logits_for_metrics is not None:
                        logits = self.preprocess_logits_for_metrics(logits, labels)
                    # Likewise, logits are not accumulated across batches; as a consequence,
                    # `all_preds`/`all_labels` stay None and `compute_metrics` is effectively skipped.
                    # preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
    
                if args.dump_eval_loss_to:
                    if self.is_world_process_zero() and num_sample_to_write > 0:
                        assert logits is not None and labels is not None, \
                            f"token-loss dump needs logits and labels (prediction_loss_only={prediction_loss_only})"
                        # Standard causal-LM shift: the logits at position t predict the token at t + 1.
                        shift_logits = logits[..., :-1, :].contiguous()
                        shift_labels = labels[..., 1:].contiguous()
                        batch_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                        batch_token_loss = batch_token_loss.view(shift_labels.size()).tolist()
                        labels = labels.tolist()
                        for i in range(len(labels)):
                            if num_sample_to_write > 0:
                                num_sample_to_write -= 1
                            else:
                                break
                            # Drop ignored positions (-100) and pair each remaining token with the loss
                            # of predicting it; the first token has no prediction, so it gets loss 0.
                            # This pairing assumes the non-masked label tokens form a contiguous prefix
                            # of the sequence.
                            token_ids = [x for x in labels[i] if x != -100]
                            tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
                            token_losses = [0] + batch_token_loss[i][:len(token_ids) - 1]
                            writer.write(json.dumps({"tokens": tokens, "losses": token_losses}, ensure_ascii=False) + '\n')
    
                self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
    
                # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
                if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                    if losses_host is not None:
                        losses = nested_numpify(losses_host)
                        all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
                    if preds_host is not None:
                        logits = nested_numpify(preds_host)
                        all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
                    if inputs_host is not None:
                        inputs_decode = nested_numpify(inputs_host)
                        all_inputs = (
                            inputs_decode
                            if all_inputs is None
                            else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                        )
                    if labels_host is not None:
                        labels = nested_numpify(labels_host)
                        all_labels = (
                            labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
                        )
    
                    # Set back to None to begin a new accumulation
                    losses_host, preds_host, inputs_host, labels_host = None, None, None, None
    
            if args.dump_eval_loss_to:
                writer.close()
            
            if args.past_index and hasattr(self, "_past"):
                # Clean the state at the end of the evaluation loop
                delattr(self, "_past")
    
            # Gather all remaining tensors and put them back on the CPU
            if losses_host is not None:
                losses = nested_numpify(losses_host)
                all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
            if preds_host is not None:
                logits = nested_numpify(preds_host)
                all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
            if inputs_host is not None:
                inputs_decode = nested_numpify(inputs_host)
                all_inputs = (
                    inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
                )
            if labels_host is not None:
                labels = nested_numpify(labels_host)
                all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
    
            # Number of samples
            if has_length(eval_dataset):
                num_samples = len(eval_dataset)
            # The instance check is weird and does not actually check for the type, but whether the dataset has the right
            # methods. Therefore we need to make sure it also has the attribute.
            elif isinstance(eval_dataset, IterableDatasetShard) and hasattr(eval_dataset, "num_examples"):
                num_samples = eval_dataset.num_examples
            else:
                if has_length(dataloader):
                    num_samples = self.num_examples(dataloader)
                else:  # both len(dataloader.dataset) and len(dataloader) fail
                    num_samples = observed_num_examples
    
            # The number of losses has been rounded to a multiple of batch_size and, in distributed
            # training, the number of samples has been rounded to a multiple of batch_size, so we truncate.
            if all_losses is not None:
                all_losses = all_losses[:num_samples]
            if all_preds is not None:
                all_preds = nested_truncate(all_preds, num_samples)
            if all_labels is not None:
                all_labels = nested_truncate(all_labels, num_samples)
            if all_inputs is not None:
                all_inputs = nested_truncate(all_inputs, num_samples)
    
            # Metrics!
            if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
                if args.include_inputs_for_metrics:
                    metrics = self.compute_metrics(
                        EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                    )
                else:
                    metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
            else:
                metrics = {}
    
            # To be JSON-serializable, we need to remove numpy types or zero-d tensors
            metrics = denumpify_detensorize(metrics)
    
            if all_losses is not None:
                metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
    
            # Prefix all keys with metric_key_prefix + '_'
            for key in list(metrics.keys()):
                if not key.startswith(f"{metric_key_prefix}_"):
                    metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
    
            return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
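

    # Usage sketch (illustrative, not from this repository): `dump_eval_loss_to` is a custom
    # field, so it has to be added to the training arguments, e.g. via a small dataclass
    # subclass. The model name and dataset below are placeholders.
    #
    #     from dataclasses import dataclass, field
    #     from typing import Optional
    #     from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
    #
    #     @dataclass
    #     class DumpTokenLossArguments(TrainingArguments):
    #         dump_eval_loss_to: Optional[str] = field(default=None)
    #
    #     tokenizer = AutoTokenizer.from_pretrained("gpt2")
    #     model = AutoModelForCausalLM.from_pretrained("gpt2")
    #     args = DumpTokenLossArguments(output_dir="out", dump_eval_loss_to="token_loss.jsonl")
    #     trainer = DumpTokenLossTrainer(model=model, args=args, tokenizer=tokenizer,
    #                                    eval_dataset=my_eval_dataset)  # my_eval_dataset: placeholder
    #     trainer.evaluate()
    #
    # Each line of token_loss.jsonl can then be parsed back with `json.loads` to inspect
    # per-token losses, e.g. to locate high-loss tokens in the evaluation set.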