diff --git a/README.md b/README.md
index 150ada1a54348439f5c8cf93476c98eb2f1cb620..a468a073a114d22f6343fe2022988dd87e144006 100755
--- a/README.md
+++ b/README.md
@@ -46,6 +46,7 @@ To use ConvLab-3 as an off-the-shelf tool, you can install via:
 ```bash
 pip install convlab
 ```
+Note that the `data` directory is not included due to the package size limitation.
 
 ### Using Docker
 
@@ -99,7 +100,7 @@ We list newly integrated models in ConvLab-3 that support unified data format an
 | Task | Models | Input | Output |
 | ------------------------------ | ------------------------------------------------------------ | --------------- | ---------------- |
 | Response Generation | [T5](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/base_models/t5) | Context | Response |
-| Goal-to-Dialog | [T5](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/base_models/t5) | Goal | Dialog |
+| Goal-to-Dialogue | [T5](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/base_models/t5) | Goal | Dialog |
 | Natural Language Understanding | [T5](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/base_models/t5), [BERTNLU](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/nlu/jointBERT), [MILU](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/nlu/milu) | Context | DA-U |
 | Dialog State Tracking | [T5](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/base_models/t5), SUMBT, SetSUMBT, TripPy | Context | State |
 | RL Policy | DDPT, PPO, PG | State, DA-U, DB | DA-S |
diff --git a/convlab/base_models/t5/README.md b/convlab/base_models/t5/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c0894e6c5cd38b5738ef85f7d900ecdf304d3fa6
--- /dev/null
+++ b/convlab/base_models/t5/README.md
@@ -0,0 +1,80 @@
+# T5 models
+
+By converting NLP tasks into a text-to-text format, we can use a single model to solve various tasks. Here we use T5 as the backbone model and provide a unified training script, `run_seq2seq.py`, for many tasks. **See `*.sh` under each task directory for usage.**
+
+## Create Data
+Currently we support natural language understanding (**NLU**), dialog state tracking (**DST**), natural language generation (**NLG**), response generation (**RG**), and generating a dialog from a user goal (**Goal2Dialogue**). We provide serialization and deserialization methods for dialog acts and the state in the unified data format (user goals are already natural language instructions). An example of serialized dialog acts and state:
+
+```
+User: I am looking for a cheap restaurant.
+System: Is there a particular area of town you prefer?
+User: In the centre of town.
+
+User dialog acts: [inform][restaurant]([area][centre])
+State: [restaurant]([area][centre],[price range][cheap])
+System dialog acts: [recommend][restaurant]([name][Zizzi Cambridge])
+
+System: I would recommend Zizzi Cambridge.
+```
+
+Dialog acts are in the form of `[intent][domain]([slot][value],...);...`. The state is in the form of `[domain]([slot][value],...);...`. Multiple items are concatenated by a semicolon `;`.
+
+To create data for a specific task, run `create_data.py` with the corresponding arguments. For example, to create data for single-turn NLU on MultiWOZ 2.1:
+
+```bash
+python create_data.py --tasks nlu --datasets multiwoz21 --speaker user
+```
+
+Note that the script only supports **datasets in the unified format**.
+
+## Training
+
+To train a model, specify arguments such as the data path, learning rate, and number of epochs, then run `run_seq2seq.py`. See `nlu/run_nlu.sh` for an example; a rough sketch of such a run is shown below.
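+
+A minimal single-GPU invocation might look as follows. This is an illustrative sketch, not the exact recipe: the file paths, column names, and hyperparameter values are placeholders, and `nlu/run_nlu.sh` remains the authoritative reference (the flags themselves are ones accepted by `run_seq2seq.py`).
+
+```bash
+python run_seq2seq.py \
+    --task_name nlu \
+    --train_file data/nlu/multiwoz21/train.json \
+    --validation_file data/nlu/multiwoz21/validation.json \
+    --source_column context \
+    --target_column dialogue_acts_seq \
+    --model_name_or_path t5-small \
+    --do_train \
+    --do_eval \
+    --per_device_train_batch_size 64 \
+    --learning_rate 1e-3 \
+    --num_train_epochs 10 \
+    --output_dir output/nlu/multiwoz21
+```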
+
+## Evaluation
+
+The standard evaluation scripts for the NLU, DST, and NLG tasks are located at `../../$task/evaluate_unified_datasets.py`. See `nlu/run_nlu.sh` for an example.
+
+## Trained Models
+
+Trained models and their performance are available on the [Hugging Face Hub](https://huggingface.co/ConvLab). You can try some examples with the hosted inference API.
+
+| Name | Task | Training Dataset |
+| ------------------------------------------------------------ | ------------- | ---------------------------- |
+| [t5-small-goal2dialogue-multiwoz21](https://huggingface.co/ConvLab/t5-small-goal2dialogue-multiwoz21) | Goal2Dialogue | MultiWOZ 2.1 |
+| [t5-small-nlu-multiwoz21](https://huggingface.co/ConvLab/t5-small-nlu-multiwoz21) | NLU | MultiWOZ 2.1 |
+| [t5-small-nlu-sgd](https://huggingface.co/ConvLab/t5-small-nlu-sgd) | NLU | SGD |
+| [t5-small-nlu-tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-nlu-tm1_tm2_tm3) | NLU | TM1+TM2+TM3 |
+| [t5-small-nlu-multiwoz21_sgd_tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-nlu-multiwoz21_sgd_tm1_tm2_tm3) | NLU | MultiWOZ 2.1+SGD+TM1+TM2+TM3 |
+| [t5-small-dst-multiwoz21](https://huggingface.co/ConvLab/t5-small-dst-multiwoz21) | DST | MultiWOZ 2.1 |
+| [t5-small-dst-sgd](https://huggingface.co/ConvLab/t5-small-dst-sgd) | DST | SGD |
+| [t5-small-dst-tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-dst-tm1_tm2_tm3) | DST | TM1+TM2+TM3 |
+| [t5-small-dst-multiwoz21_sgd_tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-dst-multiwoz21_sgd_tm1_tm2_tm3) | DST | MultiWOZ 2.1+SGD+TM1+TM2+TM3 |
+| [t5-small-nlg-multiwoz21](https://huggingface.co/ConvLab/t5-small-nlg-multiwoz21) | NLG | MultiWOZ 2.1 |
+| [t5-small-nlg-sgd](https://huggingface.co/ConvLab/t5-small-nlg-sgd) | NLG | SGD |
+| [t5-small-nlg-tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-nlg-tm1_tm2_tm3) | NLG | TM1+TM2+TM3 |
+| [t5-small-nlg-multiwoz21_sgd_tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-nlg-multiwoz21_sgd_tm1_tm2_tm3) | NLG | MultiWOZ 2.1+SGD+TM1+TM2+TM3 |
+
+## Interface
+
+To use trained models in a dialog system, import them through:
+
+```python
+from convlab.base_models.t5.nlu import T5NLU
+from convlab.base_models.t5.dst import T5DST
+from convlab.base_models.t5.nlg import T5NLG
+
+# example instantiation
+# model_name_or_path can be a model on the Hugging Face Hub or a local path
+nlu = T5NLU(speaker='user', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlu-multiwoz21')
+```
+
+See `nlu/nlu.py`, `dst/dst.py`, and `nlg/nlg.py` for example usage.
+
+## Support a New Task
+
+To support a new task, first serialize the model input and output as in `create_data.py`, then train the model with `run_seq2seq.py`. Finally, write an evaluation script for the task, or pass the `metric_name_or_path` of an existing metric to `run_seq2seq.py`. An illustrative serialization sketch follows.
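+
+The helper below is a hypothetical sketch of such a serialization step, assuming dialog acts given as (intent, domain, slot, value) tuples; it is not part of the codebase (the real logic lives in `create_data.py` and the unified-format utilities), but its output follows the `[intent][domain]([slot][value],...);...` format documented above.
+
+```python
+import json
+
+def serialize_dialogue_acts(dialogue_acts):
+    """Illustrative only: render (intent, domain, slot, value) tuples
+    in the `[intent][domain]([slot][value],...);...` format."""
+    grouped = {}
+    for intent, domain, slot, value in dialogue_acts:
+        grouped.setdefault((intent, domain), []).append(f"[{slot}][{value}]")
+    return ";".join(f"[{intent}][{domain}]({','.join(items)})"
+                    for (intent, domain), items in grouped.items())
+
+# One JSON line per training sample; the column names here are placeholders.
+sample = {
+    "context": "user: I am looking for a cheap restaurant.",
+    "dialogue_acts_seq": serialize_dialogue_acts(
+        [("inform", "restaurant", "area", "centre")]),
+}
+print(json.dumps(sample, ensure_ascii=False))
+```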
+ +## Author + +Qi Zhu(zhuq96 at gmail dot com) \ No newline at end of file diff --git a/convlab/base_models/t5/dst/__init__.py b/convlab/base_models/t5/dst/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4e6903b1093e645eed5ae6ba6b9b88d6150fe39 --- /dev/null +++ b/convlab/base_models/t5/dst/__init__.py @@ -0,0 +1 @@ +from convlab.base_models.t5.dst.dst import T5DST \ No newline at end of file diff --git a/convlab/base_models/t5/dst/dst.py b/convlab/base_models/t5/dst/dst.py index c34395c02fd188568b9f4c4bc1956240fbbc88a9..3c5ec2525091cfdcb423460ed1b01871087deb21 100755 --- a/convlab/base_models/t5/dst/dst.py +++ b/convlab/base_models/t5/dst/dst.py @@ -8,16 +8,12 @@ from convlab.util.custom_util import model_downloader class T5DST(DST): - def __init__(self, speaker, context_window_size, model_name_or_path, model_file=None, device='cuda'): + def __init__(self, speaker, context_window_size, model_name_or_path, device='cuda'): assert speaker in ['user', 'system'] assert context_window_size > 0 self.speaker = speaker self.opponent = 'system' if speaker == 'user' else 'user' self.context_window_size = context_window_size - - model_dir = os.path.dirname(os.path.abspath(__file__)) - if not os.path.exists(model_name_or_path): - model_downloader(model_dir, model_file) self.config = AutoConfig.from_pretrained(model_name_or_path) self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) diff --git a/convlab/base_models/t5/key2gen/create_data.py b/convlab/base_models/t5/key2gen/create_data.py deleted file mode 100644 index 138808f2a13d4b3fad71b57a2aa7977917f8143c..0000000000000000000000000000000000000000 --- a/convlab/base_models/t5/key2gen/create_data.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -import json -from tqdm import tqdm -from convlab.util import load_dataset, load_unified_data, load_nlu_data - -def create_nlg_data(dataset, data_dir, args): - data_by_split = load_nlu_data(dataset, speaker='system', use_context=True, context_window_size=3) - os.makedirs(data_dir, exist_ok=True) - - data_splits = data_by_split.keys() - for data_split in data_splits: - data = [] - for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False): - context = [(turn['speaker'], turn['utterance']) for turn in sample['context']] - response = sample['utterance'] - if len(context) > 0 and len(response) > 0: - knowledge = sample['dialogue_acts'] - data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n') - - if 'test' in data_split: - file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json") - else: - file_name = os.path.join(data_dir, f"{data_split}.json") - with open(file_name, "w", encoding='utf-8') as f: - f.writelines(data) - data_by_split[data_split] = data - return data_by_split - -def create_kvret_data(dataset, data_dir, args): - data_by_split = load_unified_data(dataset, speaker='system', utterance=True, db_results=True, use_context=True, context_window_size=100) - os.makedirs(data_dir, exist_ok=True) - - domain2entity_col = {'schedule': 'event' ,'navigate': 'poi', 'weather': 'location'} - data_splits = data_by_split.keys() - for data_split in data_splits: - data = [] - for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False): - context = [(turn['speaker'], turn['utterance']) for turn in sample['context']] - response = sample['utterance'] - if len(context) > 0 and len(response) > 0: - knowledge = sample['db_results'] - for domain, db_items 
in knowledge.items(): - entity_col = domain2entity_col[domain] - for db_item in db_items: - db_item['entity'] = db_item.pop(entity_col) - - data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n') - - if 'test' in data_split: - file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json") - else: - file_name = os.path.join(data_dir, f"{data_split}.json") - with open(file_name, "w", encoding='utf-8') as f: - f.writelines(data) - data_by_split[data_split] = data - return data_by_split - -def create_personachat_data(dataset, data_dir, args): - data_by_split = dataset - os.makedirs(data_dir, exist_ok=True) - - data_splits = data_by_split.keys() - for data_split in data_splits: - data = [] - for dial in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False): - knowledge = dial['persona']['system'] - context = [] - for turn in dial['turns']: - response = turn['utterance'] - if turn['speaker'] == 'system' and len(context) > 0 and len(response) > 0: - data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n') - context.append((turn['speaker'], turn['utterance'])) - - if 'test' in data_split: - file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json") - else: - file_name = os.path.join(data_dir, f"{data_split}.json") - with open(file_name, "w", encoding='utf-8') as f: - f.writelines(data) - data_by_split[data_split] = data - return data_by_split - -def create_wow_data(dataset, data_dir, args): - data_by_split = dataset - os.makedirs(data_dir, exist_ok=True) - data_by_split['test'] = data_by_split['test_seen'] + data_by_split['test_unseen'] - data_by_split.pop('test_seen') - data_by_split.pop('test_unseen') - - data_splits = data_by_split.keys() - for data_split in data_splits: - data = [] - for dial in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False): - context = [] - for turn in dial['turns']: - response = turn['utterance'] - if turn['speaker'] == 'system' and len(context) > 0 and len(response) > 0: - knowledge = turn['checked_passage'] - if knowledge is None: - knowledge = [] - elif isinstance(knowledge, str): - knowledge = [knowledge] - data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n') - context.append((turn['speaker'], turn['utterance'])) - - if 'test' in data_split: - file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json") - else: - file_name = os.path.join(data_dir, f"{data_split}.json") - with open(file_name, "w", encoding='utf-8') as f: - f.writelines(data) - data_by_split[data_split] = data - return data_by_split - -def create_opendialkg_data(dataset, data_dir, args): - data_by_split = dataset - os.makedirs(data_dir, exist_ok=True) - - data_splits = data_by_split.keys() - for data_split in data_splits: - data = [] - for dial in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False): - context = [] - for turn in dial['turns']: - response = turn['utterance'] - if turn['speaker'] == 'system' and 'kg_path' in turn and len(context) > 0 and len(response) > 0: - knowledge = turn['kg_path']['triples'] - data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n') - context.append((turn['speaker'], turn['utterance'])) - - if 'test' in data_split: - file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json") - else: - file_name = 
os.path.join(data_dir, f"{data_split}.json") - with open(file_name, "w", encoding='utf-8') as f: - f.writelines(data) - data_by_split[data_split] = data - return data_by_split - - -if __name__ == '__main__': - from argparse import ArgumentParser - parser = ArgumentParser(description="create data for seq2seq training") - parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['nlg', 'kvret', 'opendialkg', 'personachat', 'wow'], help='names of tasks') - parser.add_argument('--datasets', '-d', metavar='dataset_name', nargs='*', help='names of unified datasets') - parser.add_argument('--shot', '-s', type=float, default=None, help='how many data is used for training and evaluation, ratio if < 1 else absolute number') - parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments') - args = parser.parse_args() - print(args) - for dataset_name in tqdm(args.datasets, desc='datasets'): - dataset = load_dataset(dataset_name, dial_ids_order=args.dial_ids_order) - if args.shot: - if args.shot < 1: - dataset['train'] = dataset['train'][:round(len(dataset['train'])*args.shot)] - dataset['validation'] = dataset['validation'][:round(len(dataset['validation'])*args.shot)] - else: - args.shot = int(args.shot) - dataset['train'] = dataset['train'][:args.shot] - dataset['validation'] = dataset['validation'][:args.shot] - for task_name in tqdm(args.tasks, desc='tasks', leave=False): - data_dir = os.path.join('data', task_name, (dataset_name if not args.shot else f'{dataset_name}_{args.shot}shot_order{args.dial_ids_order}')) - data_by_split = eval(f"create_{task_name}_data")(dataset, data_dir, args) diff --git a/convlab/base_models/t5/key2gen/dataset_godel.py b/convlab/base_models/t5/key2gen/dataset_godel.py deleted file mode 100644 index caf7b8ab7b1fb10b8de03c01dac9a147f5540af1..0000000000000000000000000000000000000000 --- a/convlab/base_models/t5/key2gen/dataset_godel.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Data processing for vanilla generator""" - -import json -import datasets -from convlab.base_models.t5.key2gen.features import FEATURES -from copy import deepcopy - - -class GodelDataset(datasets.GeneratorBasedBuilder): - """Dataset for vanilla generator (e.g., t5)""" - - VERSION = datasets.Version("1.18.0") - - BUILDER_CONFIGS = [ - datasets.BuilderConfig(name="nlg", version=VERSION, description="DA grounded generation task"), - datasets.BuilderConfig(name="kvret", version=VERSION, description="KB grounded generation task"), - datasets.BuilderConfig(name="opendialkg", version=VERSION, description="KG grounded generation task"), - datasets.BuilderConfig(name="wow", version=VERSION, description="Passage grounded generation task"), - datasets.BuilderConfig(name="personachat", version=VERSION, description="Persona grounded generation task"), - ] - - def _info(self): - return datasets.DatasetInfo( - description=f"Vanilla Dataset for {self.config.description}", - features=datasets.Features(deepcopy(FEATURES[self.config.name])) - ) - - def _split_generators(self, dl_manager): - generators = [] - if "train" in self.config.data_files: - generators.append(datasets.SplitGenerator( - name=datasets.Split.TRAIN, - gen_kwargs={ - "filepath": self.config.data_files["train"][0], - "split": "train", - }, - )) - if "validation" in self.config.data_files: - generators.append(datasets.SplitGenerator( - name=datasets.Split.VALIDATION, - gen_kwargs={ - "filepath": self.config.data_files["validation"][0], - "split": "validation", - }, - )) - if "test" in self.config.data_files: - generators.append(datasets.SplitGenerator( - name=datasets.Split.TEST, - gen_kwargs={ - "filepath": self.config.data_files["test"][0], - "split": "test", - }, - )) - - return generators - - def _generate_examples(self, filepath, split): - with open(filepath, encoding="utf-8") as f: - for key, row in enumerate(f): - item = json.loads(row) - if self.config.name == "nlg": - knowledge = item["knowledge"] - triples = [] - for da_type in knowledge: - for da in knowledge[da_type]: - intent, domain, slot, value = da["intent"], da["domain"], da["slot"], da.get("value", "") - if 'start' in da: - da.pop('start') - da.pop('end') - intent_domain = f"{intent}-{domain}" - triples.append([intent_domain]) - if len(slot) > 0: - triples[-1].append(slot) - if len(value) > 0: - triples[-1].append(value) - knowledge_seq = "| {} |".format(" | ".join([" : ".join(da_keywords) for da_keywords in triples])) - - elif self.config.name == "kvret": - knowledge = {"schedule": [], "weather": [], "navigate": []} - triples = [] - for domain, db_items in item["knowledge"].items(): - knowledge[domain] = db_items - for db_item in db_items: - entity = db_item["entity"] - for db_key, db_value in db_item.items(): - if db_key == "entity": - continue - triples.append([entity, db_key, db_value]) - knowledge_seq = "| {} |".format(" | ".join([" : ".join(triple) for triple in triples])) - - elif self.config.name == "opendialkg": - knowledge = item["knowledge"] - knowledge_seq = "| {} |".format(" | ".join([" : ".join(triple) for triple in item["knowledge"]])) - - elif self.config.name in ["wow", "personachat"]: - knowledge = item["knowledge"] - try: - knowledge_seq = "| {} |".format(" | ".join(item["knowledge"])) - except: - print([knowledge]) - raise - - context = " EOS ".join([turn[1] for turn in item["context"]]) - context_knowledge = context + ' <|Knowledge|> \n\n' + knowledge_seq + ' => ' - - yield key, { - "context+knowledge": context_knowledge, - "response": 
item["response"], - "knowledge": knowledge, - } diff --git a/convlab/base_models/t5/key2gen/dataset_vanilla.py b/convlab/base_models/t5/key2gen/dataset_vanilla.py deleted file mode 100644 index 15a8c7b4ac8cfbf1057e090f675a3fc7a4051f2c..0000000000000000000000000000000000000000 --- a/convlab/base_models/t5/key2gen/dataset_vanilla.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Data processing for vanilla generator""" - -import json -import datasets -from convlab.base_models.t5.key2gen.features import FEATURES -from copy import deepcopy - - -class VanillaDataset(datasets.GeneratorBasedBuilder): - """Dataset for vanilla generator (e.g., t5)""" - - VERSION = datasets.Version("1.18.0") - - BUILDER_CONFIGS = [ - datasets.BuilderConfig(name="nlg", version=VERSION, description="DA grounded generation task"), - datasets.BuilderConfig(name="kvret", version=VERSION, description="KB grounded generation task"), - datasets.BuilderConfig(name="opendialkg", version=VERSION, description="KG grounded generation task"), - datasets.BuilderConfig(name="wow", version=VERSION, description="Passage grounded generation task"), - datasets.BuilderConfig(name="personachat", version=VERSION, description="Persona grounded generation task"), - ] - - def _info(self): - return datasets.DatasetInfo( - description=f"Vanilla Dataset for {self.config.description}", - features=datasets.Features(deepcopy(FEATURES[self.config.name])) - ) - - def _split_generators(self, dl_manager): - generators = [] - if "train" in self.config.data_files: - generators.append(datasets.SplitGenerator( - name=datasets.Split.TRAIN, - gen_kwargs={ - "filepath": self.config.data_files["train"][0], - "split": "train", - }, - )) - if "validation" in self.config.data_files: - generators.append(datasets.SplitGenerator( - name=datasets.Split.VALIDATION, - gen_kwargs={ - "filepath": self.config.data_files["validation"][0], - "split": "validation", - }, - )) - if "test" in self.config.data_files: - generators.append(datasets.SplitGenerator( - name=datasets.Split.TEST, - gen_kwargs={ - "filepath": self.config.data_files["test"][0], - "split": "test", - }, - )) - - return generators - - def _generate_examples(self, filepath, split): - with open(filepath, encoding="utf-8") as f: - for key, row in enumerate(f): - item = json.loads(row) - if self.config.name == "nlg": - knowledge = item["knowledge"] - triples = [] - for da_type in knowledge: - for da in knowledge[da_type]: - intent, domain, slot, value = da["intent"], da["domain"], da["slot"], da.get("value", "") - if 'start' in da: - da.pop('start') - da.pop('end') - intent_domain = f"{intent}-{domain}" - triples.append([intent_domain]) - if len(slot) > 0: - triples[-1].append(slot) - if len(value) > 0: - triples[-1].append(value) - knowledge_seq = "| {} |".format(" | ".join([" : ".join(da_keywords) for da_keywords in triples])) - - elif self.config.name == "kvret": - knowledge = 
{"schedule": [], "weather": [], "navigate": []} - triples = [] - for domain, db_items in item["knowledge"].items(): - knowledge[domain] = db_items - for db_item in db_items: - entity = db_item["entity"] - for db_key, db_value in db_item.items(): - if db_key == "entity": - continue - triples.append([entity, db_key, db_value]) - knowledge_seq = "| {} |".format(" | ".join([" : ".join(triple) for triple in triples])) - - elif self.config.name == "opendialkg": - knowledge = item["knowledge"] - knowledge_seq = "| {} |".format(" | ".join([" : ".join(triple) for triple in item["knowledge"]])) - - elif self.config.name in ["wow", "personachat"]: - knowledge = item["knowledge"] - try: - knowledge_seq = "| {} |".format(" | ".join(item["knowledge"])) - except: - print([knowledge]) - raise - - context = "\n".join([f"{turn[0]}: {turn[1]}" for turn in item["context"]]+["system: "]) - if self.config.name in ["kvret", "wow", "personachat"]: - context_knowledge = f"generate a response: all knowledge: \n\n{knowledge_seq} context:\n\n{context}" - else: - context_knowledge = f"generate a response: grounded knowledge: \n\n{knowledge_seq} context:\n\n{context}" - - yield key, { - "context+knowledge": context_knowledge, - "response": item["response"], - "knowledge": knowledge, - } diff --git a/convlab/base_models/t5/key2gen/eval.ipynb b/convlab/base_models/t5/key2gen/eval.ipynb deleted file mode 100644 index 51fcc5e0da1321ef740084d0a8b0241b5721a2fc..0000000000000000000000000000000000000000 --- a/convlab/base_models/t5/key2gen/eval.ipynb +++ /dev/null @@ -1 +0,0 @@ -{"cells":[{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[],"source":["import json\n","import re"]},{"cell_type":"code","execution_count":2,"metadata":{},"outputs":[],"source":["def read_jsonline(path):\n"," return [json.loads(line) for line in open(path)]"]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[],"source":["origin = read_jsonline('output/wow/wow/test_unseen.json')"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[],"source":["key2gen = read_jsonline('output/wow/key2gen_wow/test_unseen.json')"]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[],"source":["with open('tmp_wow.txt', 'w') as f:\n"," for d1, d2 in zip(origin, key2gen):\n"," print(re.split('context:|grounded knowledge:', d1['context+knowledge'])[1].strip(), file=f)\n"," print(re.split('context:|grounded knowledge:', d2['context+knowledge'])[1].strip(), file=f)\n"," print(d1['context+knowledge'].split('context:')[1].replace('\\n\\n', '\\n'), file=f)\n"," print(file=f)\n"," print('target', d1['response'], file=f)\n"," print('origin', d1['predictions'], file=f)\n"," print('key2gen', d2['predictions'], file=f)\n"," print('='*100, file=f)"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["for ratio in [0.1, 0.01]:\n"," for order in [0, 1, 2]:\n"," origin = read_jsonline(f'output/personachat/key2gen_personachat_{ratio}_order{order}/generated_predictions.json')\n"," score = metric.compute(predictions=[d['predictions'] for d in origin], references=[d['response'] for d in origin])\n"," print(ratio, order)\n"," print(score)\n"," "]},{"cell_type":"code","execution_count":51,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["0.01 1\n","{'bleu-1': 24.322560358946276, 'bleu-2': 13.03630111937752, 'bleu-3': 7.43647978674912, 'bleu-4': 4.450365738541082, 'unigram f1': 0.20101056184593705, 'unigram f1 (non-stop words)': 0.09881569367818614, 'rouge1': 21.359332522961864, 
'rouge2': 6.532120354812852, 'rougeL': 19.76437990594138}\n"]}],"source":["for ratio in [0.01]:\n"," for order in [1]:\n"," origin = read_jsonline(f'output/personachat/personachat/generated_predictions.json')\n"," score = metric.compute(predictions=[d['predictions'] for d in origin], references=[d['response'] for d in origin])\n"," print(ratio, order)\n"," print(score)\n"," "]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[],"source":["from datasets import load_metric"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["metric = load_metric('metric.py')"]},{"cell_type":"code","execution_count":58,"metadata":{},"outputs":[{"data":{"text/plain":["{'bleu-1': 47.9848465486215,\n"," 'bleu-2': 37.18000679532912,\n"," 'bleu-3': 29.346646172092814,\n"," 'bleu-4': 23.410526740211363,\n"," 'unigram f1': 0.4999850046010773,\n"," 'unigram f1 (non-stop words)': 0.5150265227462978,\n"," 'rouge1': 50.536642578692195,\n"," 'rouge2': 33.10681789367832,\n"," 'rougeL': 46.84702913163778,\n"," 'meteor': 0.4641962079490068}"]},"execution_count":58,"metadata":{},"output_type":"execute_result"}],"source":["metric.compute(predictions=[d['predictions'] for d in key2gen], references=[d['response'] for d in key2gen])"]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"data":{"text/plain":["{'bleu-1': 37.570099942714585,\n"," 'bleu-2': 26.77393964962893,\n"," 'bleu-3': 21.115954644820572,\n"," 'bleu-4': 17.513316671216046,\n"," 'unigram f1': 0.3656930567072274,\n"," 'unigram f1 (non-stop words)': 0.36456219281235724,\n"," 'rouge1': 39.1982724920493,\n"," 'rouge2': 20.825159884632743,\n"," 'rougeL': 34.98278542180112,\n"," 'meteor': 0.3405671227693821,\n"," 'distinct-1': 0.07838670580160921,\n"," 'distinct-2': 0.29689084413659694}"]},"execution_count":8,"metadata":{},"output_type":"execute_result"}],"source":["metric.compute(predictions=[d['predictions'] for d in origin], references=[d['response'] for d in origin])"]},{"cell_type":"code","execution_count":34,"metadata":{},"outputs":[{"data":{"text/plain":["{'bleu-1': 47.9848465486215,\n"," 'bleu-2': 37.18000679532912,\n"," 'bleu-3': 29.346646172092814,\n"," 'bleu-4': 23.410526740211363,\n"," 'unigram f1': 0.4999850046010773,\n"," 'unigram f1 (non-stop words)': 0.5150265227462978,\n"," 'rouge1': AggregateScore(low=Score(precision=0.5301926525013549, recall=0.4821419251082986, fmeasure=0.48565655175230005), mid=Score(precision=0.5513392693168799, recall=0.50235850981064, fmeasure=0.5053664257869219), high=Score(precision=0.5760132731228504, recall=0.5268580272115051, fmeasure=0.5279111393835526)),\n"," 'rouge2': AggregateScore(low=Score(precision=0.34772127155901306, recall=0.30411953889228, fmeasure=0.31029658993105447), mid=Score(precision=0.3696898381097765, recall=0.32612705034192035, fmeasure=0.3310681789367832), high=Score(precision=0.3947745596965405, recall=0.34880792116864995, fmeasure=0.35356317521641434)),\n"," 'rougeL': AggregateScore(low=Score(precision=0.4874189522136045, recall=0.4413343070361347, fmeasure=0.4464463084888409), mid=Score(precision=0.5108530997712726, recall=0.4642203560120527, fmeasure=0.46847029131637785), high=Score(precision=0.5350154077389535, recall=0.4855131911095939, fmeasure=0.4899950876629784)),\n"," 'rougeLsum': AggregateScore(low=Score(precision=0.4871840444049138, recall=0.44081531444183386, fmeasure=0.44514075751478493), mid=Score(precision=0.5105975305923949, recall=0.4639265647317744, fmeasure=0.46779186414456864), high=Score(precision=0.5348015149575474, 
recall=0.48693312722760357, fmeasure=0.4918651382986408))}"]},"execution_count":34,"metadata":{},"output_type":"execute_result"}],"source":["metric.compute(predictions=[d['predictions'] for d in key2gen], references=[d['response'] for d in key2gen])"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":[]}],"metadata":{"interpreter":{"hash":"0f9333403d680bc010aa5ce5a2f27ba398c9e47e92ba3724506306aa234cd07d"},"kernelspec":{"display_name":"Python 3.8.12 ('py38')","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.12"},"orig_nbformat":4},"nbformat":4,"nbformat_minor":2} diff --git a/convlab/base_models/t5/key2gen/evaluate.py b/convlab/base_models/t5/key2gen/evaluate.py deleted file mode 100644 index 769fdfcf3d1c899aad1b5389dad2c8d9465c05c6..0000000000000000000000000000000000000000 --- a/convlab/base_models/t5/key2gen/evaluate.py +++ /dev/null @@ -1,91 +0,0 @@ -from tabulate import tabulate -import os -import json -from tqdm import tqdm -from datasets import load_metric -import numpy as np -import csv - -def evaluate(filename, metric): - """ - It reads the predictions, references, and knowledge from a file, and then computes the metric - - :param filename: the path to the file containing the predictions - :param metric: the metric to use for evaluation - :return: The result of the evaluation. - """ - predictions, references, knowledge = [], [], [] - with open(filename, 'r') as f: - for line in f: - item = json.loads(line) - predictions.append(item['predictions']) - references.append(item['response']) - knowledge.append(item['knowledge']) - result = metric.compute(predictions=predictions, references=references, knowledge=knowledge) - return result - - -def avg_result(results): - """ - It takes a list of dictionaries, and returns a dictionary with the same keys, but the values are the - mean and standard deviation of the values in the input dictionaries - - :param results: a list of dictionaries, each dictionary is the result of a single run of the model - :return: The average and standard deviation of the results. 
- """ - ret = {} - for k in results[0]: - m = round(np.mean([result[k] for result in results]), 2) - v = round(np.std([result[k] for result in results], ddof=1), 2) if len(results) > 1 else None - ret[k] = f"{m}({v})" - return ret - - -if __name__ == '__main__': - from argparse import ArgumentParser - parser = ArgumentParser(description="create data for seq2seq training") - parser.add_argument("--output_dirs", type=str, nargs='*', required=True) - parser.add_argument('--tasks', '-t', type=str, nargs='*', choices=['nlg', 'kvret', 'opendialkg', 'personachat', 'wow'], help='names of tasks') - parser.add_argument('--shots', '-s', type=int, nargs='*', help='how many data is used for training and evaluation, ratio if < 1 else absolute number') - parser.add_argument('--dial_ids_orders', '-o', type=int, nargs='*', help='which data order is used for experiments') - args = parser.parse_args() - print(args) - - table = [] - fieldnames = [] - for task_name in tqdm(args.tasks, desc='tasks'): - metric = load_metric("metric.py", task_name) - dataset_name = task_name if task_name != "nlg" else "multiwoz21" - for shot in tqdm(args.shots, desc='shots', leave=False): - for output_dir in tqdm(args.output_dirs, desc='models', leave=False): - model_name = output_dir.split('/')[-1] - results = [] - for dial_ids_order in tqdm(args.dial_ids_orders, desc='dial_ids_orders', leave=False): - result_dir = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen") - result_file = os.path.join(result_dir, "result.json") - if not os.path.exists(result_file): - filename = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen/generated_predictions.json") - result = evaluate(filename, metric) - json.dump(result, open(result_file, 'w', encoding='utf-8'), indent=2, ensure_ascii=False) - else: - result = json.load(open(result_file)) - results.append(result) - res = { - "dataset": f"{task_name}-{shot}shot", - "model": f"{model_name}", - **avg_result(results) - } - table.append(res) - for k in res: - if k not in fieldnames: - fieldnames.append(k) - - res = tabulate(table, headers='keys', tablefmt='github') - with open(f'eval_results.txt', 'w', encoding='utf-8') as f: - print(res, file=f) - with open('eval_results.csv', 'w', newline='') as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - - writer.writeheader() - for res in table: - writer.writerow(res) diff --git a/convlab/base_models/t5/key2gen/features.py b/convlab/base_models/t5/key2gen/features.py deleted file mode 100644 index 0ac768b5cbe61d46e430580b025182e515db93ef..0000000000000000000000000000000000000000 --- a/convlab/base_models/t5/key2gen/features.py +++ /dev/null @@ -1,72 +0,0 @@ -import datasets - -FEATURES = { - "nlg": { - "context+knowledge": datasets.Value("string"), - "response": datasets.Value("string"), - "knowledge": { - "categorical": datasets.Sequence({ - "intent": datasets.Value("string"), - "domain": datasets.Value("string"), - "slot": datasets.Value("string"), - "value": datasets.Value("string"), - }), - "non-categorical": datasets.Sequence({ - "intent": datasets.Value("string"), - "domain": datasets.Value("string"), - "slot": datasets.Value("string"), - "value": datasets.Value("string"), - }), - "binary": datasets.Sequence({ - "intent": datasets.Value("string"), - "domain": datasets.Value("string"), - "slot": datasets.Value("string"), - }) - }}, - "kvret": { - "context+knowledge": datasets.Value("string"), - "response": datasets.Value("string"), - "knowledge": { - 
"schedule": datasets.Sequence({ - "entity": datasets.Value("string"), - "time": datasets.Value("string"), - "date": datasets.Value("string"), - "party": datasets.Value("string"), - "room": datasets.Value("string"), - "agenda": datasets.Value("string") - }), - "weather": datasets.Sequence({ - "entity": datasets.Value("string"), - "today": datasets.Value("string"), - "monday": datasets.Value("string"), - "tuesday": datasets.Value("string"), - "wednesday": datasets.Value("string"), - "thursday": datasets.Value("string"), - "friday": datasets.Value("string"), - "saturday": datasets.Value("string"), - "sunday": datasets.Value("string"), - }), - "navigate": datasets.Sequence({ - "entity": datasets.Value("string"), - "traffic_info": datasets.Value("string"), - "poi_type": datasets.Value("string"), - "address": datasets.Value("string"), - "distance": datasets.Value("string") - }) - }}, - "opendialkg": { - "context+knowledge": datasets.Value("string"), - "response": datasets.Value("string"), - "knowledge": datasets.Sequence(datasets.Sequence(datasets.Value("string"))), - }, - "wow": { - "context+knowledge": datasets.Value("string"), - "response": datasets.Value("string"), - "knowledge": datasets.Sequence(datasets.Value("string")), - }, - "personachat": { - "context+knowledge": datasets.Value("string"), - "response": datasets.Value("string"), - "knowledge": datasets.Sequence(datasets.Value("string")), - } -} \ No newline at end of file diff --git a/convlab/base_models/t5/key2gen/finetune.sh b/convlab/base_models/t5/key2gen/finetune.sh deleted file mode 100644 index 8b2eb8d208966ed8f8056f01ece1b1a373033014..0000000000000000000000000000000000000000 --- a/convlab/base_models/t5/key2gen/finetune.sh +++ /dev/null @@ -1,116 +0,0 @@ -set -e -dataset_path=$1 -model_name=$2 -model_name_or_path=$3 -dataset_name=$4 -if [ "${dataset_name}" == "multiwoz21" ] -then - task_name="nlg" -else - task_name=${dataset_name} -fi -master_port=$5 - -n_gpus=2 -cache_dir="../cache" -metric_name_or_path="metric.py" -source_column="context+knowledge" -target_column="response" -truncation_side="left" -max_source_length=512 -max_target_length=512 -per_device_train_batch_size=64 -per_device_eval_batch_size=64 -gradient_accumulation_steps=1 -num_workers=16 -lr=1e-3 -num_train_epochs=100 - -for shot in 50 100 200 -do - for dial_ids_order in 0 1 2 3 4 - do - python create_data.py -t ${task_name} -d ${dataset_name} -o ${dial_ids_order} -s ${shot} - - data_dir="data/${task_name}/${dataset_name}_${shot}shot_order${dial_ids_order}" - output_dir="output/${model_name}/${task_name}/${dataset_name}_${shot}shot_order${dial_ids_order}" - logging_dir="${output_dir}/runs" - train_file="${data_dir}/train.json" - validation_file="${data_dir}/validation.json" - - # training - python -m torch.distributed.launch --master_port ${master_port} \ - --nproc_per_node ${n_gpus} ../run_seq2seq.py \ - --task_name ${task_name} \ - --dataset_name ${dataset_path} \ - --dataset_config_name ${task_name} \ - --train_file ${train_file} \ - --validation_file ${validation_file} \ - --source_column ${source_column} \ - --target_column ${target_column} \ - --max_source_length ${max_source_length} \ - --max_target_length ${max_target_length} \ - --truncation_side ${truncation_side} \ - --model_name_or_path ${model_name_or_path} \ - --do_train \ - --do_eval \ - --save_strategy epoch \ - --evaluation_strategy epoch \ - --save_total_limit 1 \ - --prediction_loss_only \ - --load_best_model_at_end \ - --overwrite_output_dir \ - --cache_dir ${cache_dir} \ - --output_dir 
${output_dir} \ - --logging_dir ${logging_dir} \ - --preprocessing_num_workers ${num_workers} \ - --dataloader_num_workers ${num_workers} \ - --per_device_train_batch_size ${per_device_train_batch_size} \ - --per_device_eval_batch_size ${per_device_eval_batch_size} \ - --gradient_accumulation_steps ${gradient_accumulation_steps} \ - --learning_rate ${lr} \ - --num_train_epochs ${num_train_epochs} \ - --optim adafactor \ - --lr_scheduler_type constant \ - --gradient_checkpointing - - # inference - test_file="data/${task_name}/test.json" - gen_output_dir="${output_dir}/gen" - - python -m torch.distributed.launch --master_port ${master_port} \ - --nproc_per_node ${n_gpus} ../run_seq2seq.py \ - --task_name ${task_name} \ - --dataset_name ${dataset_path} \ - --dataset_config_name ${task_name} \ - --metric_name_or_path ${metric_name_or_path} \ - --metric_config_name ${task_name} \ - --test_file ${test_file} \ - --source_column ${source_column} \ - --target_column ${target_column} \ - --max_source_length ${max_source_length} \ - --max_target_length ${max_target_length} \ - --truncation_side ${truncation_side} \ - --model_name_or_path ${output_dir} \ - --do_predict \ - --predict_with_generate \ - --cache_dir ${cache_dir} \ - --output_dir ${gen_output_dir} \ - --logging_dir ${logging_dir} \ - --overwrite_output_dir \ - --preprocessing_num_workers ${num_workers} \ - --dataloader_num_workers ${num_workers} \ - --per_device_train_batch_size ${per_device_train_batch_size} \ - --per_device_eval_batch_size ${per_device_eval_batch_size} \ - --gradient_accumulation_steps ${gradient_accumulation_steps} \ - --learning_rate ${lr} \ - --num_train_epochs ${num_train_epochs} \ - --optim adafactor \ - --lr_scheduler_type constant \ - --gradient_checkpointing - - done -done - -# evaluation -python evaluate.py --output_dirs output/${model_name} -t ${task_name} -s 50 100 200 -o 0 1 2 3 4 diff --git a/convlab/base_models/t5/key2gen/metric.py b/convlab/base_models/t5/key2gen/metric.py deleted file mode 100644 index 808934b65268ab2ae4180b9bbe64457fb5ca1b68..0000000000000000000000000000000000000000 --- a/convlab/base_models/t5/key2gen/metric.py +++ /dev/null @@ -1,434 +0,0 @@ -# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Grounded Dialog Generation Metric""" - -from weakref import ref -import datasets -from sacrebleu.metrics import BLEU -from sacrebleu.utils import sum_of_lists -import re -from collections import Counter -import numpy as np -from nltk.corpus import stopwords -from rouge_score import rouge_scorer, scoring -from nltk.translate import meteor_score -from datasets.config import importlib_metadata, version -from convlab.base_models.t5.key2gen.features import FEATURES -from convlab.util import load_ontology -from copy import deepcopy - - -NLTK_VERSION = version.parse(importlib_metadata.version("nltk")) -if NLTK_VERSION >= version.Version("3.6.5"): - from nltk import word_tokenize - -# Uncomment to download nltk_data for the first time running. -# import nltk -# nltk.download("wordnet") -# if NLTK_VERSION >= version.Version("3.6.5"): -# nltk.download("punkt") -# if NLTK_VERSION >= version.Version("3.6.6"): -# nltk.download("omw-1.4") - - -_CITATION = """ -""" - -_DESCRIPTION = """\ -Metric to evaluate text generation models on the grounded dialog generation task. -""" - -# TODO -_KWARGS_DESCRIPTION = """ -Args: - predictions: list of predictions to score. Each predictions - should be a string. - references: list of reference for each prediction. Each - reference should be a string. - knowledge: task-specific grounded knowledge - -Returns: - bleu-1/2/3/4: corpus-bleu score, from sacrebleu - rouge-1/2/L: ROUGE-F1, from rouge_score - meteor: METEOR, from nltk - unigram f1: unigram overlap, from parlai - distinct-1/2: from parlai - other knowledge utility score: task-specific knowledge utility metrics -""" - -re_art = re.compile(r'\b(a|an|the)\b') -re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']') -stop_words = set(stopwords.words("english")) -def utt2words(s): - """Lower text and remove punctuation, articles and extra whitespace. 
- from parlai https://github.com/facebookresearch/ParlAI/blob/9daae69320c07104493486e022c0e46a7871b253/parlai/core/metrics.py#L810""" - s = s.lower() - s = re_punc.sub(' ', s) - s = re_art.sub(' ', s) - return s.split() - - -def get_bleu(predictions, references): - """bleu-1/2/3/4 from sacrebleu""" - references = [" " if ref=="" else ref for ref in references] - metrics = {} - bleu = BLEU(lowercase=True, force=False, tokenize=BLEU.TOKENIZER_DEFAULT, smooth_method="exp", smooth_value=None, effective_order=False) - stats = sum_of_lists(bleu._extract_corpus_statistics(predictions, [references])) - for n in range(1,5): - metrics[f"bleu-{n}"] = bleu.compute_bleu( - correct=stats[2: 2 + bleu.max_ngram_order], - total=stats[2 + bleu.max_ngram_order:], - sys_len=int(stats[0]), ref_len=int(stats[1]), - smooth_method=bleu.smooth_method, smooth_value=bleu.smooth_value, - effective_order=bleu.effective_order, - max_ngram_order=n).score - return metrics - - -def get_unigram_f1(predictions, references): - """unigram f1 between prediction and reference, from parlai""" - metrics = {} - metrics["unigram f1"] = [] - metrics["unigram f1 (non-stop words)"] = [] - for prediction, reference in zip(predictions, references): - pred_items = utt2words(prediction) - gold_items = utt2words(reference) - for remove_stopwords in [False, True]: - if remove_stopwords: - pred_items = [w for w in pred_items if w not in stop_words] - gold_items = [w for w in gold_items if w not in stop_words] - common = Counter(pred_items) & Counter(gold_items) - num_same = sum(common.values()) - if num_same == 0: - f1 = 0 - else: - precision = 1.0 * num_same / len(pred_items) - recall = 1.0 * num_same / len(gold_items) - f1 = (2 * precision * recall) / (precision + recall) - if not remove_stopwords: - metrics["unigram f1"].append(f1) - else: - metrics["unigram f1 (non-stop words)"].append(f1) - metrics["unigram f1"] = np.mean(metrics["unigram f1"]) * 100 - metrics["unigram f1 (non-stop words)"] = np.mean(metrics["unigram f1 (non-stop words)"]) * 100 - return metrics - - -def get_rouge(predictions, references): - """rouge-1/2/L from rouge-score""" - rouge_types=["rouge1", "rouge2", "rougeL"] - scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True) - aggregator = scoring.BootstrapAggregator() - - for prediction, reference in zip(predictions, references): - score = scorer.score(reference, prediction) - aggregator.add_scores(score) - - return {key: 100 * (value.mid.fmeasure if key == "rougeL" else value.mid.recall) for key, value in aggregator.aggregate().items()} - - -def get_meteor(predictions, references): - """meteor from nltk""" - alpha=0.9 - beta=3 - gamma=0.5 - if NLTK_VERSION >= version.Version("3.6.5"): - scores = [ - meteor_score.single_meteor_score( - word_tokenize(ref), word_tokenize(pred), alpha=alpha, beta=beta, gamma=gamma - ) - for ref, pred in zip(references, predictions) - ] - else: - scores = [ - meteor_score.single_meteor_score(ref, pred, alpha=alpha, beta=beta, gamma=gamma) - for ref, pred in zip(references, predictions) - ] - return {"meteor": np.mean(scores) * 100} - - -def get_distinct(predictions): - """distinct-1/2 - from parlai https://github.com/facebookresearch/ParlAI/blob/9daae69320c07104493486e022c0e46a7871b253/parlai/core/metrics.py#L781""" - def _ngram(seq, n): - for i in range(len(seq) - n + 1): - yield tuple(seq[i : i + n]) - - metrics = {} - for k in [1, 2]: - inter_cnt = Counter() - for prediction in predictions: - ngram = Counter(_ngram(utt2words(prediction), k)) - inter_cnt += ngram - 
metrics[f"distinct-{k}"] = max(len(inter_cnt), 1e-12) / max(sum(inter_cnt.values()), 1e-5) * 100 - return metrics - - -def get_nlg_slot_err(predictions, knowledge): - """slot error rate: (missing_count + redundant_count) / all_count for value in dialog acts""" - val2ds_dict = {} - ontology = load_ontology("multiwoz21") - for domain_name in ontology["domains"]: - domain = ontology["domains"][domain_name] - for slot_name in domain["slots"]: - slot = domain["slots"][slot_name] - if "possible_values" not in slot: - continue - possible_vals = slot["possible_values"] - if len(possible_vals) > 0: - for val in possible_vals: - val2ds_dict[val] = f"{domain_name}-{slot_name}" - score_list = [] - for utterance, da in zip(predictions, knowledge): - missing_count = 0 - redundant_count = 0 - all_count = 0 - all_values = set() - ## missing values - # print(da) - # print(utterance) - for key in ['categorical', 'non-categorical']: - for value in da[key]['value']: - if len(value) > 0: - # print(value) - all_values.add(value) - if value.strip().lower() not in utterance.lower(): - missing_count += 1 - # print(f"\tmissing: {value}") - all_count += 1 - if all_count == 0: - continue - ## redundant values - for val in val2ds_dict: - if f" {val.strip().lower()} " in f" {utterance.strip().lower()} " and val.strip().lower() not in all_values: - wlist = val2ds_dict[val].split("-") - domain, slot = wlist[0], wlist[1] - if f" {slot.strip().lower()}" in f" {utterance.strip().lower()} ": - redundant_count += 1 - # print(f"redundant: {val}/{val2ds_dict[val]}") - item_score = float(missing_count + redundant_count) / all_count - # print(f"\tredundant: {redundant_count} | missing_count: {missing_count} |all_count: {all_count}") - # print('-'*100) - score_list.append(item_score) - return {"err": np.mean(score_list) * 100} - - -def load_entities(): - """modified (load from unified ontology) from UnifiedSKG - https://github.com/HKUNLP/UnifiedSKG/blob/49a2ff950bb312b980c22ad72b11520db72ab6a3/metrics/kvret/evaluator.py#L8""" - - ontology = load_ontology("kvret") - all_entities = set() - for domain in ontology["domains"]: - for slot in ontology["domains"][domain]["slots"]: - all_entities |= set(ontology["domains"][domain]["slots"][slot]["possible_values"]) - missed_entities = ["yoga", "tennis", "swimming", "football", " lab ", "doctor", "optometrist", "dentist", "1st", - "2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th", - "11th", "12th", "13th", "14th", "15th", "16th", "17th", "18th", "19th", "20th", "Jill", - "Jack"] - all_entities |= set(missed_entities) - all_entities.remove("HR") - all_entities.add(" HR ") - all_entities = sorted(list(all_entities), key=lambda i: len(i), reverse=True) - return all_entities - - -def check_sub_str(str_list: list, sub_str: str): - """ - It takes a list of strings and a substring as input, and returns True if the substring is found - in any of the strings in the list, and False otherwise - """ - for str_item in str_list: - if sub_str in str_item or sub_str.lower() in str_item.lower(): - return True - return False - - -def extract_entities_from_utterance(utterance, sorted_entities): - """modified (remove underscore) from UnifiedSKG - https://github.com/HKUNLP/UnifiedSKG/blob/49a2ff950bb312b980c22ad72b11520db72ab6a3/metrics/kvret/response_entity_hit.py#L45""" - - utterance = " {} ".format(utterance) # for entity matching - for h in range(0, 13): # for formulating am & pm - utterance = utterance.replace("{} am".format(h), "{}am".format(h)) - utterance = utterance.replace("{} 
pm".format(h), "{}pm".format(h)) - for entity_item_a in [20, 30, 40, 50, 60, 70, 80, 90, 100]: - for entity_item_b in [20, 30, 40, 50, 60, 70, 80, 90, 100]: - utterance = utterance.replace("{}-{}f".format(str(entity_item_a), str(entity_item_b)), "{}f-{}f".format(str(entity_item_a), str(entity_item_b))) - entities_in_this_utterance = [] - for entity in sorted_entities: - # len(entity) decreases - if (entity in utterance) or (entity.lower() in utterance.lower()): - if not check_sub_str(entities_in_this_utterance, entity): - # in case of "week & weekend", "week & next_week" etc - entities_in_this_utterance.append(entity) - return entities_in_this_utterance - - -def f1_score(y_pred, y_true, average="micro"): - """micro/marco-F1 score, modified from UnifiedSKG - https://github.com/HKUNLP/UnifiedSKG/blob/49a2ff950bb312b980c22ad72b11520db72ab6a3/metrics/kvret/response_entity_hit.py#L76""" - - assert len(y_pred) == len(y_true) - - def _compute_F1(precision, recall): - return 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0 - - def _compute_prf(gold, pred): - TP, FP, FN = 0, 0, 0 - if len(gold) != 0: - count = 1 - for g in gold: - if g in pred: - TP += 1 - else: - FN += 1 - for p in set(pred): - if p not in gold: - FP += 1 - precision = TP / float(TP + FP) if (TP + FP) != 0 else 0 - recall = TP / float(TP + FN) if (TP + FN) != 0 else 0 - F1 = _compute_F1(precision, recall) - else: - precision, recall, F1, count = 0, 0, 0, 0 - return TP, FP, FN, F1, count - - F1_pred, F1_count, TP_all, FP_all, FN_all = 0, 0, 0, 0, 0 - - for y_true_item, y_pred_item in zip(y_true, y_pred): - single_tp, single_fp, single_fn, single_f1, count = _compute_prf(y_true_item, y_pred_item) - F1_pred += single_f1 - F1_count += count - TP_all += single_tp - FP_all += single_fp - FN_all += single_fn - - if average == "macro": - F1_macro_score = F1_pred / float(F1_count) if F1_count != 0 else 0 - return F1_macro_score * 100 - elif average == "micro": - P_score = TP_all / float(TP_all + FP_all) if (TP_all + FP_all) != 0 else 0 - R_score = TP_all / float(TP_all + FN_all) if (TP_all + FN_all) != 0 else 0 - F1_micro_score = _compute_F1(P_score, R_score) - return F1_micro_score * 100 - else: - raise ValueError("Options other than micro/macro are not supported.") - - -def get_kvret_entity_f1(predictions, references, knowledge): - """entity f1 for kvret, modified from - https://github.com/HKUNLP/UnifiedSKG/blob/49a2ff950bb312b980c22ad72b11520db72ab6a3/metrics/kvret/response_entity_hit.py#L178""" - - global_entities = load_entities() - F1_scores = {} - entities_from_predictions_and_references = { - d: {"predictions_entities": [], "references_entities": []} for d in ["all", "schedule", "weather", "navigate"] - } - for prediction, reference, kb in zip(predictions, references, knowledge): - prediction_entities = extract_entities_from_utterance(utterance=prediction, sorted_entities=global_entities) - reference_entities = extract_entities_from_utterance(utterance=reference, sorted_entities=global_entities) - entities_from_predictions_and_references["all"]["predictions_entities"].append(prediction_entities) - entities_from_predictions_and_references["all"]["references_entities"].append(reference_entities) - domain = "schedule" - for d in kb: - if len(kb[d]["entity"]) > 0: - domain = d - break - entities_from_predictions_and_references[domain]["predictions_entities"].append(prediction_entities) - entities_from_predictions_and_references[domain]["references_entities"].append(reference_entities) - - for 
category in entities_from_predictions_and_references.keys(): - predictions_entities = entities_from_predictions_and_references[category]["predictions_entities"] - references_entities = entities_from_predictions_and_references[category]["references_entities"] - F1_scores["{} micro entity F1".format(category)] = f1_score(y_pred=predictions_entities, y_true=references_entities, average="micro") - F1_scores["{} macro entity F1".format(category)] = f1_score(y_pred=predictions_entities, y_true=references_entities, average="macro") - - return {**F1_scores} - - -def get_opendialkg_entity_f1(predictions, references, knowledge): - predictions_entities, references_entities = [], [] - for prediction, reference, kg_path in zip(predictions, references, knowledge): - kg_entities = set() - for kg_triple in kg_path: - # add head and tail entities - kg_entities.add(kg_triple[0]) - kg_entities.add(kg_triple[-1]) - kg_entities = sorted(list(kg_entities), key=lambda i: len(i), reverse=True) - - for utterance, entities in zip([prediction, reference], [predictions_entities, references_entities]): - entities_in_this_utterance = [] - for entity in kg_entities: - if (entity in utterance) or (entity.lower() in utterance.lower()): - if not check_sub_str(entities_in_this_utterance, entity): - # in case of "week & weekend", "week & next_week" etc - entities_in_this_utterance.append(entity) - entities.append(entities_in_this_utterance) - - return { - "micro entity f1": f1_score(y_pred=predictions_entities, y_true=references_entities, average="micro"), - "macro entity f1": f1_score(y_pred=predictions_entities, y_true=references_entities, average="macro") - } - -def get_knowledge_sentences_f1(predictions, knowledge): - knowledge_reference = [' '.join(k_sens) for k_sens in knowledge] - f1_score = get_unigram_f1(predictions, knowledge_reference) - return {f"knowledge {k}": v for k, v in f1_score.items()} - - -@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) -class GroundedDialogGenerationMetrics(datasets.Metric): - """Metric to evaluate text generation models on the grounded dialog generation task.""" - def _info(self): - return datasets.MetricInfo( - description=_DESCRIPTION, - citation=_CITATION, - inputs_description=_KWARGS_DESCRIPTION, - features=datasets.Features({ - "predictions": datasets.Value("string"), - "references": datasets.Value("string"), - "knowledge": deepcopy(FEATURES[self.config_name]["knowledge"]) - }) - ) - - def compute(self, predictions, references, knowledge=None): - """Returns the scores: bleu""" - metrics = {} - - # bleu - metrics.update(get_bleu(predictions, references)) - - # unigram f1 - metrics.update(get_unigram_f1(predictions, references)) - - # rouge-1/2/L-fmeasure - metrics.update(get_rouge(predictions, references)) - - # meteor - metrics.update(get_meteor(predictions, references)) - - # inter-distinct-1/2 - metrics.update(get_distinct(predictions)) - - if knowledge is not None: - if self.config_name == "nlg": - metrics.update(get_nlg_slot_err(predictions, knowledge)) - elif self.config_name == "kvret": - metrics.update(get_kvret_entity_f1(predictions, references, knowledge)) - elif self.config_name == "opendialkg": - metrics.update(get_opendialkg_entity_f1(predictions, references, knowledge)) - elif self.config_name in ["wow", "personachat"]: - metrics.update(get_knowledge_sentences_f1(predictions, knowledge)) - - return metrics diff --git a/convlab/base_models/t5/nlg/__init__.py b/convlab/base_models/t5/nlg/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..32275a5755248f5cfb67192b0a0f530cee0a276e --- /dev/null +++ b/convlab/base_models/t5/nlg/__init__.py @@ -0,0 +1 @@ +from convlab.base_models.t5.nlg.nlg import T5NLG \ No newline at end of file diff --git a/convlab/base_models/t5/nlg/nlg.py b/convlab/base_models/t5/nlg/nlg.py index 2781fded74c3b02a9c46a6c289d6f0e5fb850f2b..214dc01eed75cfbb85a740e8f5fee8a759d813b0 100755 --- a/convlab/base_models/t5/nlg/nlg.py +++ b/convlab/base_models/t5/nlg/nlg.py @@ -8,17 +8,13 @@ from convlab.util.custom_util import model_downloader class T5NLG(NLG): - def __init__(self, speaker, context_window_size, model_name_or_path, model_file=None, device='cuda'): + def __init__(self, speaker, context_window_size, model_name_or_path, device='cuda'): assert speaker in ['user', 'system'] self.speaker = speaker self.opponent = 'system' if speaker == 'user' else 'user' self.context_window_size = context_window_size self.use_context = context_window_size > 0 - model_dir = os.path.dirname(os.path.abspath(__file__)) - if not os.path.exists(model_name_or_path): - model_downloader(model_dir, model_file) - self.config = AutoConfig.from_pretrained(model_name_or_path) self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config) diff --git a/convlab/base_models/t5/nlu/__init__.py b/convlab/base_models/t5/nlu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed4bbd0306fe68dfbbaa69f7e93cdb2b68c23e6d --- /dev/null +++ b/convlab/base_models/t5/nlu/__init__.py @@ -0,0 +1 @@ +from convlab.base_models.t5.nlu.nlu import T5NLU \ No newline at end of file diff --git a/convlab/base_models/t5/nlu/nlu.py b/convlab/base_models/t5/nlu/nlu.py index 2862cea7aa74c8c365a047dc74aa41dc79ead405..a5a6e6a23ec184b15fc073d88fa1a6b3fece34d8 100755 --- a/convlab/base_models/t5/nlu/nlu.py +++ b/convlab/base_models/t5/nlu/nlu.py @@ -8,16 +8,12 @@ from convlab.util.custom_util import model_downloader class T5NLU(NLU): - def __init__(self, speaker, context_window_size, model_name_or_path, model_file=None, device='cuda'): + def __init__(self, speaker, context_window_size, model_name_or_path, device='cuda'): assert speaker in ['user', 'system'] self.speaker = speaker self.opponent = 'system' if speaker == 'user' else 'user' self.context_window_size = context_window_size self.use_context = context_window_size > 0 - - model_dir = os.path.dirname(os.path.abspath(__file__)) - if not os.path.exists(model_name_or_path): - model_downloader(model_dir, model_file) self.config = AutoConfig.from_pretrained(model_name_or_path) self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)