nlu_metric.py

    # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    """NLU Metric"""
    
    import datasets
    from convlab2.base_models.t5.nlu.serialization import deserialize_dialogue_acts
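    # Assumption, inferred from the docstring example and the field accesses in
    # _compute below: deserialize_dialogue_acts parses a serialized act string such as
    #     "[non-categorical][inform][taxi][leave at][17:15]"
    # into a dict keyed by dialogue-act type, e.g.
    #     {'binary': [], 'categorical': [],
    #      'non-categorical': [{'intent': 'inform', 'domain': 'taxi',
    #                           'slot': 'leave at', 'value': '17:15'}]}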
    
    
    # TODO: Add BibTeX citation
    _CITATION = """\
    """
    
    _DESCRIPTION = """\
    Metric to evaluate text-to-text models on the natural language understanding task.
    """
    
    _KWARGS_DESCRIPTION = """
    Calculates sequence exact match, dialog acts accuracy and F1
    Args:
        predictions: list of predictions to score. Each prediction
            should be a string.
        references: list of references, one for each prediction. Each
            reference should be a string.
    Returns:
        seq_em: sequence exact match
        accuracy: dialog acts accuracy
        overall_f1: dialog acts overall f1
        binary_f1: binary dialog acts f1
        categorical_f1: categorical dialog acts f1
        non-categorical_f1: non-categorical dialog acts f1
    Examples:
    
        >>> nlu_metric = datasets.load_metric("nlu_metric.py")
        >>> predictions = ["[binary][thank][general][]", "[non-categorical][inform][taxi][leave at][17:15]"]
        >>> references = ["[binary][thank][general][]", "[non-categorical][inform][train][leave at][17:15]"]
        >>> results = nlu_metric.compute(predictions=predictions, references=references)
        >>> print(results)
        {'seq_em': 0.5, 'accuracy': 0.5, 
        'overall_f1': 0.5, 'overall_precision': 0.5, 'overall_recall': 0.5, 
        'binary_f1': 1.0, 'binary_precision': 1.0, 'binary_recall': 1.0, 
        'categorical_f1': 0.0, 'categorical_precision': 0.0, 'categorical_recall': 0.0, 
        'non-categorical_f1': 0.0, 'non-categorical_precision': 0.0, 'non-categorical_recall': 0.0}
    """
    
    
    @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
    class NLUMetrics(datasets.Metric):
        """Metric to evaluate text-to-text models on the natural language understanding task."""
    
        def _info(self):
            return datasets.MetricInfo(
                description=_DESCRIPTION,
                citation=_CITATION,
                inputs_description=_KWARGS_DESCRIPTION,
                # This defines the format of each prediction and reference
                features=datasets.Features({
                    'predictions': datasets.Value('string'),
                    'references': datasets.Value('string'),
                })
            )
    
        def _compute(self, predictions, references):
            """Returns the scores: sequence exact match, dialog acts accuracy and f1"""
            seq_em = []
            acc = []
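            # True positive / false positive / false negative counts per DA type;
            # 'overall' aggregates the counts of all three types.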
            f1_metrics = {x: {'TP':0, 'FP':0, 'FN':0} for x in ['overall', 'binary', 'categorical', 'non-categorical']}
    
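            # Compare each prediction with its reference at the dialogue-act level.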
            for prediction, reference in zip(predictions, references):
                seq_em.append(prediction.strip()==reference.strip())
                pred_da = deserialize_dialogue_acts(prediction)
                gold_da = deserialize_dialogue_acts(reference)
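                # 'flag' stays True only if every DA type matches the reference
                # exactly; it feeds the per-example accuracy.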
                flag = True
                for da_type in ['binary', 'categorical', 'non-categorical']:
                    if da_type == 'binary':
                        predicts = sorted(list({(x['intent'], x['domain'], x['slot']) for x in pred_da[da_type]}))
                        labels = sorted(list({(x['intent'], x['domain'], x['slot']) for x in gold_da[da_type]}))
                    else:
                        predicts = sorted(list({(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in pred_da[da_type]}))
                        labels = sorted(list({(x['intent'], x['domain'], x['slot'], ''.join(x['value'].split()).lower()) for x in gold_da[da_type]}))
                    for ele in predicts:
                        if ele in labels:
                            f1_metrics['overall']['TP'] += 1
                            f1_metrics[da_type]['TP'] += 1
                        else:
                            f1_metrics['overall']['FP'] += 1
                            f1_metrics[da_type]['FP'] += 1
                    for ele in labels:
                        if ele not in predicts:
                            f1_metrics['overall']['FN'] += 1
                            f1_metrics[da_type]['FN'] += 1
                    flag &= (predicts==labels)
                acc.append(flag)
    
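            # Turn the accumulated counts into precision, recall and F1 per category.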
            for metric in list(f1_metrics.keys()):
                TP = f1_metrics[metric].pop('TP')
                FP = f1_metrics[metric].pop('FP')
                FN = f1_metrics[metric].pop('FN')
                precision = 1.0 * TP / (TP + FP) if TP + FP else 0.
                recall = 1.0 * TP / (TP + FN) if TP + FN else 0.
                f1 = 2.0 * precision * recall / (precision + recall) if precision + recall else 0.
                f1_metrics.pop(metric)
                f1_metrics[f'{metric}_f1'] = f1
                f1_metrics[f'{metric}_precision'] = precision
                f1_metrics[f'{metric}_recall'] = recall
    
            return {
                "seq_em": sum(seq_em)/len(seq_em),
                "accuracy": sum(acc)/len(acc),
                **f1_metrics
            }
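
A minimal usage sketch, mirroring the docstring example above; it assumes this script is saved locally as nlu_metric.py and that the installed `datasets` version still provides `load_metric` (newer releases may require the separate `evaluate` package instead):

    # hypothetical demo script, not part of the repository
    import datasets

    nlu_metric = datasets.load_metric("nlu_metric.py")
    results = nlu_metric.compute(
        predictions=["[binary][thank][general][]"],
        references=["[binary][thank][general][]"],
    )
    # A single exactly matching prediction yields perfect scores.
    print(results["seq_em"], results["accuracy"], results["overall_f1"])  # 1.0 1.0 1.0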