Commit 8013a6fa authored by zqwerty

add T5 readme

parent 7c2911ff
Showing 88 additions and 1141 deletions
@@ -46,6 +46,7 @@ To use ConvLab-3 as an off-the-shelf tool, you can install via:
 ```bash
 pip install convlab
 ```
+Note that the `data` directory will not be included due to the package size limitation.
 
 ### Using Docker
 
@@ -99,7 +100,7 @@ We list newly integrated models in ConvLab-3 that support unified data format an
 | Task | Models | Input | Output |
 | ------------------------------ | ------------------------------------------------------------ | --------------- | ---------------- |
 | Response Generation | [T5](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/base_models/t5) | Context | Response |
-| Goal-to-Dialog | [T5](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/base_models/t5) | Goal | Dialog |
+| Goal-to-Dialogue | [T5](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/base_models/t5) | Goal | Dialog |
 | Natural Language Understanding | [T5](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/base_models/t5), [BERTNLU](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/nlu/jointBERT), [MILU](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/nlu/milu) | Context | DA-U |
 | Dialog State Tracking | [T5](https://github.com/ConvLab/ConvLab-3/tree/master/convlab/base_models/t5), SUMBT, SetSUMBT, TripPy | Context | State |
 | RL Policy | DDPT, PPO, PG | State, DA-U, DB | DA-S |
# T5 models
By converting NLP tasks into a text-to-text format, we can use one single model to solve various tasks. Here we use T5 as the backbone model and provide a unified training script, `run_seq2seq.py`, for many tasks. **See `*.sh` under each task directory for usage.**
## Create Data
Currently we support natural language understanding (**NLU**), dialog state tracking (**DST**), natural language generation (**NLG**), response generation (**RG**), and generating a dialog from a user goal (**Goal2Dialogue**). We provide serialization and deserialization methods for dialog acts and the state in the unified data format (user goals are already natural language instructions). An example of serialized dialog acts and state:
```
User: I am looking for a cheap restaurant.
System: Is there a particular area of town you prefer?
User: In the centre of town.
User dialog acts: [inform][restaurant]([area][centre])
State: [restaurant]([area][centre],[price range][cheap])
System dialog acts: [recommend][restaurant]([name][Zizzi Cambridge])
System: I would recommend Zizzi Cambridge.
```
Dialog acts are in the form of `[intent][domain]([slot][value],...);...`. The state is in the form of `[domain]([slot][value],...);...`. Multiple items are concatenated by a semicolon `;`.
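As an illustration, the following sketch deserializes such a dialog act string back into structured form. The helper name and output structure are assumptions for illustration only; the real (de)serialization code lives in `create_data.py`:
```python
import re

def deserialize_dialogue_acts(das_seq):
    """Illustrative parser for '[intent][domain]([slot][value],...);...' strings."""
    dialogue_acts = []
    for act in das_seq.split(';'):
        match = re.match(r'\[(.*?)\]\[(.*?)\]\((.*)\)', act.strip())
        if not match:
            continue
        intent, domain, slot_values = match.groups()
        pairs = re.findall(r'\[(.*?)\]\[(.*?)\]', slot_values)
        if not pairs:
            # acts like '[bye][general]()' carry no slot-value pairs
            dialogue_acts.append({'intent': intent, 'domain': domain, 'slot': '', 'value': ''})
        for slot, value in pairs:
            dialogue_acts.append({'intent': intent, 'domain': domain, 'slot': slot, 'value': value})
    return dialogue_acts

# deserialize_dialogue_acts('[inform][restaurant]([area][centre])')
# -> [{'intent': 'inform', 'domain': 'restaurant', 'slot': 'area', 'value': 'centre'}]
```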
To create data for a specific task, run `create_data.py` with the corresponding arguments. For example, to create data for single-turn NLU on MultiWOZ 2.1:
```bash
python create_data.py --tasks nlu --datasets multiwoz21 --speaker user
```
Note that the script only supports **datasets in the unified format**.
## Training
To train a model, specify arguments such as the data path, learning rate, and number of epochs, and then run `run_seq2seq.py`. See `nlu/run_nlu.sh` for an example.
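As a rough sketch (single process, no distributed launch), an invocation could look like the following. The data paths and column names here are illustrative assumptions; `nlu/run_nlu.sh` contains the arguments actually used:
```bash
python run_seq2seq.py \
    --task_name nlu \
    --train_file data/nlu/multiwoz21/user/context_0/train.json \
    --validation_file data/nlu/multiwoz21/user/context_0/validation.json \
    --source_column context \
    --target_column dialogue_acts_seq \
    --model_name_or_path t5-small \
    --do_train \
    --do_eval \
    --learning_rate 1e-3 \
    --num_train_epochs 10 \
    --per_device_train_batch_size 64 \
    --output_dir output/nlu/multiwoz21
```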
## Evaluation
The standard evaluation scripts for the NLU, DST, and NLG tasks are located at `../../$task/evaluate_unified_datasets.py`. See `nlu/run_nlu.sh` for an example.
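For instance, a hedged sketch of scoring NLU predictions after generation; the `--predict_result` flag name is an assumption, so check the script's argparse options:
```bash
python ../../nlu/evaluate_unified_datasets.py \
    --predict_result output/nlu/multiwoz21/generated_predictions.json
```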
## Trained Models
Trained models and their performance are available on the [Hugging Face Hub](https://huggingface.co/ConvLab). You can try some examples with the hosted inference API.
| Name | Task | Training Dataset |
| ------------------------------------------------------------ | ------------- | ---------------------------- |
| [t5-small-goal2dialogue-multiwoz21](https://huggingface.co/ConvLab/t5-small-goal2dialogue-multiwoz21) | Goal2Dialogue | MultiWOZ 2.1 |
| [t5-small-nlu-multiwoz21](https://huggingface.co/ConvLab/t5-small-nlu-multiwoz21) | NLU | MultiWOZ 2.1 |
| [t5-small-nlu-sgd](https://huggingface.co/ConvLab/t5-small-nlu-sgd) | NLU | SGD |
| [t5-small-nlu-tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-nlu-tm1_tm2_tm3) | NLU | TM1+TM2+TM3 |
| [t5-small-nlu-multiwoz21_sgd_tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-nlu-multiwoz21_sgd_tm1_tm2_tm3) | NLU | MultiWOZ 2.1+SGD+TM1+TM2+TM3 |
| [t5-small-dst-multiwoz21](https://huggingface.co/ConvLab/t5-small-dst-multiwoz21) | DST | MultiWOZ 2.1 |
| [t5-small-dst-sgd](https://huggingface.co/ConvLab/t5-small-dst-sgd) | DST | SGD |
| [t5-small-dst-tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-dst-tm1_tm2_tm3) | DST | TM1+TM2+TM3 |
| [t5-small-dst-multiwoz21_sgd_tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-dst-multiwoz21_sgd_tm1_tm2_tm3) | DST | MultiWOZ 2.1+SGD+TM1+TM2+TM3 |
| [t5-small-nlg-multiwoz21](https://huggingface.co/ConvLab/t5-small-nlg-multiwoz21) | NLG | MultiWOZ 2.1 |
| [t5-small-nlg-sgd](https://huggingface.co/ConvLab/t5-small-nlg-sgd) | NLG | SGD |
| [t5-small-nlg-tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-nlg-tm1_tm2_tm3) | NLG | TM1+TM2+TM3 |
| [t5-small-nlg-multiwoz21_sgd_tm1_tm2_tm3](https://huggingface.co/ConvLab/t5-small-nlg-multiwoz21_sgd_tm1_tm2_tm3) | NLG | MultiWOZ 2.1+SGD+TM1+TM2+TM3 |
## Interface
To use trained models in a dialog system, import them as follows:
```python
from convlab.base_models.t5.nlu import T5NLU
from convlab.base_models.t5.dst import T5DST
from convlab.base_models.t5.nlg import T5NLG
# example instantiation
# model_name_or_path can be a model on the Hugging Face Hub or a local path
nlu = T5NLU(speaker='user', context_window_size=0, model_name_or_path='ConvLab/t5-small-nlu-multiwoz21')
```
See `nlu/nlu.py`, `dst/dst.py`, `nlg/nlg.py` for example usage.
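A quick usage sketch follows, assuming the standard ConvLab `predict(utterance, context)` interface; the sample utterance is made up:
```python
# hedged sketch: ConvLab NLU modules expose predict(utterance, context)
utterance = "I would like a cheap restaurant in the centre of town."
dialogue_acts = nlu.predict(utterance, context=[])
print(dialogue_acts)  # dialog acts in the unified data format
```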
## Support a New Task
To support a new task, first serialize the model input and output as in `create_data.py`, then train the model with `run_seq2seq.py`. Finally, write an evaluation script for the task, or pass `metric_name_or_path` for an existing metric to `run_seq2seq.py`.
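For instance, a minimal data-creation function for a hypothetical new task could mirror the pattern of `create_data.py`; the task and field names below are placeholders:
```python
import json
import os

def create_mytask_data(dataset, data_dir, args):
    """Hypothetical example: write one JSON line per sample with serialized input/output."""
    os.makedirs(data_dir, exist_ok=True)
    for data_split, samples in dataset.items():
        lines = []
        for sample in samples:
            source = sample['utterance']            # plain-text model input
            target = str(sample['my_annotation'])   # serialized model output (placeholder field)
            lines.append(json.dumps({'source': source, 'target': target}, ensure_ascii=False) + '\n')
        with open(os.path.join(data_dir, f'{data_split}.json'), 'w', encoding='utf-8') as f:
            f.writelines(lines)
```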
## Author
Qi Zhu (zhuq96 at gmail dot com)
from convlab.base_models.t5.dst.dst import T5DST
@@ -8,17 +8,13 @@ from convlab.util.custom_util import model_downloader
 class T5DST(DST):
-    def __init__(self, speaker, context_window_size, model_name_or_path, model_file=None, device='cuda'):
+    def __init__(self, speaker, context_window_size, model_name_or_path, device='cuda'):
         assert speaker in ['user', 'system']
         assert context_window_size > 0
         self.speaker = speaker
         self.opponent = 'system' if speaker == 'user' else 'user'
         self.context_window_size = context_window_size
-        model_dir = os.path.dirname(os.path.abspath(__file__))
-        if not os.path.exists(model_name_or_path):
-            model_downloader(model_dir, model_file)
         self.config = AutoConfig.from_pretrained(model_name_or_path)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config)
import os
import json
from tqdm import tqdm
from convlab.util import load_dataset, load_unified_data, load_nlu_data
def create_nlg_data(dataset, data_dir, args):
data_by_split = load_nlu_data(dataset, speaker='system', use_context=True, context_window_size=3)
os.makedirs(data_dir, exist_ok=True)
data_splits = data_by_split.keys()
for data_split in data_splits:
data = []
for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
context = [(turn['speaker'], turn['utterance']) for turn in sample['context']]
response = sample['utterance']
if len(context) > 0 and len(response) > 0:
knowledge = sample['dialogue_acts']
data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
if 'test' in data_split:
file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json")
else:
file_name = os.path.join(data_dir, f"{data_split}.json")
with open(file_name, "w", encoding='utf-8') as f:
f.writelines(data)
data_by_split[data_split] = data
return data_by_split
def create_kvret_data(dataset, data_dir, args):
data_by_split = load_unified_data(dataset, speaker='system', utterance=True, db_results=True, use_context=True, context_window_size=100)
os.makedirs(data_dir, exist_ok=True)
    domain2entity_col = {'schedule': 'event', 'navigate': 'poi', 'weather': 'location'}
data_splits = data_by_split.keys()
for data_split in data_splits:
data = []
for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
context = [(turn['speaker'], turn['utterance']) for turn in sample['context']]
response = sample['utterance']
if len(context) > 0 and len(response) > 0:
knowledge = sample['db_results']
for domain, db_items in knowledge.items():
entity_col = domain2entity_col[domain]
for db_item in db_items:
db_item['entity'] = db_item.pop(entity_col)
data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
if 'test' in data_split:
file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json")
else:
file_name = os.path.join(data_dir, f"{data_split}.json")
with open(file_name, "w", encoding='utf-8') as f:
f.writelines(data)
data_by_split[data_split] = data
return data_by_split
def create_personachat_data(dataset, data_dir, args):
data_by_split = dataset
os.makedirs(data_dir, exist_ok=True)
data_splits = data_by_split.keys()
for data_split in data_splits:
data = []
for dial in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
knowledge = dial['persona']['system']
context = []
for turn in dial['turns']:
response = turn['utterance']
if turn['speaker'] == 'system' and len(context) > 0 and len(response) > 0:
data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
context.append((turn['speaker'], turn['utterance']))
if 'test' in data_split:
file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json")
else:
file_name = os.path.join(data_dir, f"{data_split}.json")
with open(file_name, "w", encoding='utf-8') as f:
f.writelines(data)
data_by_split[data_split] = data
return data_by_split
def create_wow_data(dataset, data_dir, args):
data_by_split = dataset
os.makedirs(data_dir, exist_ok=True)
data_by_split['test'] = data_by_split['test_seen'] + data_by_split['test_unseen']
data_by_split.pop('test_seen')
data_by_split.pop('test_unseen')
data_splits = data_by_split.keys()
for data_split in data_splits:
data = []
for dial in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
context = []
for turn in dial['turns']:
response = turn['utterance']
if turn['speaker'] == 'system' and len(context) > 0 and len(response) > 0:
knowledge = turn['checked_passage']
if knowledge is None:
knowledge = []
elif isinstance(knowledge, str):
knowledge = [knowledge]
data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
context.append((turn['speaker'], turn['utterance']))
if 'test' in data_split:
file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json")
else:
file_name = os.path.join(data_dir, f"{data_split}.json")
with open(file_name, "w", encoding='utf-8') as f:
f.writelines(data)
data_by_split[data_split] = data
return data_by_split
def create_opendialkg_data(dataset, data_dir, args):
data_by_split = dataset
os.makedirs(data_dir, exist_ok=True)
data_splits = data_by_split.keys()
for data_split in data_splits:
data = []
for dial in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
context = []
for turn in dial['turns']:
response = turn['utterance']
if turn['speaker'] == 'system' and 'kg_path' in turn and len(context) > 0 and len(response) > 0:
knowledge = turn['kg_path']['triples']
data.append(json.dumps({'context': context, 'knowledge': knowledge, 'response': response}, ensure_ascii=False)+'\n')
context.append((turn['speaker'], turn['utterance']))
if 'test' in data_split:
file_name = os.path.join(os.path.dirname(data_dir), f"{data_split}.json")
else:
file_name = os.path.join(data_dir, f"{data_split}.json")
with open(file_name, "w", encoding='utf-8') as f:
f.writelines(data)
data_by_split[data_split] = data
return data_by_split
if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser(description="create data for seq2seq training")
parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['nlg', 'kvret', 'opendialkg', 'personachat', 'wow'], help='names of tasks')
parser.add_argument('--datasets', '-d', metavar='dataset_name', nargs='*', help='names of unified datasets')
    parser.add_argument('--shot', '-s', type=float, default=None, help='how much data to use for training and evaluation: a ratio if < 1, else an absolute number')
parser.add_argument('--dial_ids_order', '-o', type=int, default=None, help='which data order is used for experiments')
args = parser.parse_args()
print(args)
for dataset_name in tqdm(args.datasets, desc='datasets'):
dataset = load_dataset(dataset_name, dial_ids_order=args.dial_ids_order)
if args.shot:
if args.shot < 1:
dataset['train'] = dataset['train'][:round(len(dataset['train'])*args.shot)]
dataset['validation'] = dataset['validation'][:round(len(dataset['validation'])*args.shot)]
else:
args.shot = int(args.shot)
dataset['train'] = dataset['train'][:args.shot]
dataset['validation'] = dataset['validation'][:args.shot]
for task_name in tqdm(args.tasks, desc='tasks', leave=False):
data_dir = os.path.join('data', task_name, (dataset_name if not args.shot else f'{dataset_name}_{args.shot}shot_order{args.dial_ids_order}'))
data_by_split = eval(f"create_{task_name}_data")(dataset, data_dir, args)
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data processing for vanilla generator"""
import json
import datasets
from convlab.base_models.t5.key2gen.features import FEATURES
from copy import deepcopy
class GodelDataset(datasets.GeneratorBasedBuilder):
"""Dataset for vanilla generator (e.g., t5)"""
VERSION = datasets.Version("1.18.0")
BUILDER_CONFIGS = [
datasets.BuilderConfig(name="nlg", version=VERSION, description="DA grounded generation task"),
datasets.BuilderConfig(name="kvret", version=VERSION, description="KB grounded generation task"),
datasets.BuilderConfig(name="opendialkg", version=VERSION, description="KG grounded generation task"),
datasets.BuilderConfig(name="wow", version=VERSION, description="Passage grounded generation task"),
datasets.BuilderConfig(name="personachat", version=VERSION, description="Persona grounded generation task"),
]
def _info(self):
return datasets.DatasetInfo(
description=f"Vanilla Dataset for {self.config.description}",
features=datasets.Features(deepcopy(FEATURES[self.config.name]))
)
def _split_generators(self, dl_manager):
generators = []
if "train" in self.config.data_files:
generators.append(datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": self.config.data_files["train"][0],
"split": "train",
},
))
if "validation" in self.config.data_files:
generators.append(datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"filepath": self.config.data_files["validation"][0],
"split": "validation",
},
))
if "test" in self.config.data_files:
generators.append(datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepath": self.config.data_files["test"][0],
"split": "test",
},
))
return generators
def _generate_examples(self, filepath, split):
with open(filepath, encoding="utf-8") as f:
for key, row in enumerate(f):
item = json.loads(row)
if self.config.name == "nlg":
knowledge = item["knowledge"]
triples = []
for da_type in knowledge:
for da in knowledge[da_type]:
intent, domain, slot, value = da["intent"], da["domain"], da["slot"], da.get("value", "")
if 'start' in da:
da.pop('start')
da.pop('end')
intent_domain = f"{intent}-{domain}"
triples.append([intent_domain])
if len(slot) > 0:
triples[-1].append(slot)
if len(value) > 0:
triples[-1].append(value)
knowledge_seq = "| {} |".format(" | ".join([" : ".join(da_keywords) for da_keywords in triples]))
elif self.config.name == "kvret":
knowledge = {"schedule": [], "weather": [], "navigate": []}
triples = []
for domain, db_items in item["knowledge"].items():
knowledge[domain] = db_items
for db_item in db_items:
entity = db_item["entity"]
for db_key, db_value in db_item.items():
if db_key == "entity":
continue
triples.append([entity, db_key, db_value])
knowledge_seq = "| {} |".format(" | ".join([" : ".join(triple) for triple in triples]))
elif self.config.name == "opendialkg":
knowledge = item["knowledge"]
knowledge_seq = "| {} |".format(" | ".join([" : ".join(triple) for triple in item["knowledge"]]))
elif self.config.name in ["wow", "personachat"]:
knowledge = item["knowledge"]
try:
knowledge_seq = "| {} |".format(" | ".join(item["knowledge"]))
except:
print([knowledge])
raise
context = " EOS ".join([turn[1] for turn in item["context"]])
context_knowledge = context + ' <|Knowledge|> \n\n' + knowledge_seq + ' => '
yield key, {
"context+knowledge": context_knowledge,
"response": item["response"],
"knowledge": knowledge,
}
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data processing for vanilla generator"""
import json
import datasets
from convlab.base_models.t5.key2gen.features import FEATURES
from copy import deepcopy
class VanillaDataset(datasets.GeneratorBasedBuilder):
"""Dataset for vanilla generator (e.g., t5)"""
VERSION = datasets.Version("1.18.0")
BUILDER_CONFIGS = [
datasets.BuilderConfig(name="nlg", version=VERSION, description="DA grounded generation task"),
datasets.BuilderConfig(name="kvret", version=VERSION, description="KB grounded generation task"),
datasets.BuilderConfig(name="opendialkg", version=VERSION, description="KG grounded generation task"),
datasets.BuilderConfig(name="wow", version=VERSION, description="Passage grounded generation task"),
datasets.BuilderConfig(name="personachat", version=VERSION, description="Persona grounded generation task"),
]
def _info(self):
return datasets.DatasetInfo(
description=f"Vanilla Dataset for {self.config.description}",
features=datasets.Features(deepcopy(FEATURES[self.config.name]))
)
def _split_generators(self, dl_manager):
generators = []
if "train" in self.config.data_files:
generators.append(datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"filepath": self.config.data_files["train"][0],
"split": "train",
},
))
if "validation" in self.config.data_files:
generators.append(datasets.SplitGenerator(
name=datasets.Split.VALIDATION,
gen_kwargs={
"filepath": self.config.data_files["validation"][0],
"split": "validation",
},
))
if "test" in self.config.data_files:
generators.append(datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepath": self.config.data_files["test"][0],
"split": "test",
},
))
return generators
def _generate_examples(self, filepath, split):
with open(filepath, encoding="utf-8") as f:
for key, row in enumerate(f):
item = json.loads(row)
if self.config.name == "nlg":
knowledge = item["knowledge"]
triples = []
for da_type in knowledge:
for da in knowledge[da_type]:
intent, domain, slot, value = da["intent"], da["domain"], da["slot"], da.get("value", "")
if 'start' in da:
da.pop('start')
da.pop('end')
intent_domain = f"{intent}-{domain}"
triples.append([intent_domain])
if len(slot) > 0:
triples[-1].append(slot)
if len(value) > 0:
triples[-1].append(value)
knowledge_seq = "| {} |".format(" | ".join([" : ".join(da_keywords) for da_keywords in triples]))
elif self.config.name == "kvret":
knowledge = {"schedule": [], "weather": [], "navigate": []}
triples = []
for domain, db_items in item["knowledge"].items():
knowledge[domain] = db_items
for db_item in db_items:
entity = db_item["entity"]
for db_key, db_value in db_item.items():
if db_key == "entity":
continue
triples.append([entity, db_key, db_value])
knowledge_seq = "| {} |".format(" | ".join([" : ".join(triple) for triple in triples]))
elif self.config.name == "opendialkg":
knowledge = item["knowledge"]
knowledge_seq = "| {} |".format(" | ".join([" : ".join(triple) for triple in item["knowledge"]]))
elif self.config.name in ["wow", "personachat"]:
knowledge = item["knowledge"]
try:
knowledge_seq = "| {} |".format(" | ".join(item["knowledge"]))
except:
print([knowledge])
raise
context = "\n".join([f"{turn[0]}: {turn[1]}" for turn in item["context"]]+["system: "])
if self.config.name in ["kvret", "wow", "personachat"]:
context_knowledge = f"generate a response: all knowledge: \n\n{knowledge_seq} context:\n\n{context}"
else:
context_knowledge = f"generate a response: grounded knowledge: \n\n{knowledge_seq} context:\n\n{context}"
yield key, {
"context+knowledge": context_knowledge,
"response": item["response"],
"knowledge": knowledge,
}
``` python
import json
import re
```
``` python
def read_jsonline(path):
return [json.loads(line) for line in open(path)]
```
``` python
origin = read_jsonline('output/wow/wow/test_unseen.json')
```
``` python
key2gen = read_jsonline('output/wow/key2gen_wow/test_unseen.json')
```
``` python
with open('tmp_wow.txt', 'w') as f:
for d1, d2 in zip(origin, key2gen):
print(re.split('context:|grounded knowledge:', d1['context+knowledge'])[1].strip(), file=f)
print(re.split('context:|grounded knowledge:', d2['context+knowledge'])[1].strip(), file=f)
print(d1['context+knowledge'].split('context:')[1].replace('\n\n', '\n'), file=f)
print(file=f)
print('target', d1['response'], file=f)
print('origin', d1['predictions'], file=f)
print('key2gen', d2['predictions'], file=f)
print('='*100, file=f)
```
``` python
for ratio in [0.1, 0.01]:
for order in [0, 1, 2]:
origin = read_jsonline(f'output/personachat/key2gen_personachat_{ratio}_order{order}/generated_predictions.json')
score = metric.compute(predictions=[d['predictions'] for d in origin], references=[d['response'] for d in origin])
print(ratio, order)
print(score)
```
``` python
for ratio in [0.01]:
for order in [1]:
origin = read_jsonline(f'output/personachat/personachat/generated_predictions.json')
score = metric.compute(predictions=[d['predictions'] for d in origin], references=[d['response'] for d in origin])
print(ratio, order)
print(score)
```
Output:
0.01 1
{'bleu-1': 24.322560358946276, 'bleu-2': 13.03630111937752, 'bleu-3': 7.43647978674912, 'bleu-4': 4.450365738541082, 'unigram f1': 0.20101056184593705, 'unigram f1 (non-stop words)': 0.09881569367818614, 'rouge1': 21.359332522961864, 'rouge2': 6.532120354812852, 'rougeL': 19.76437990594138}
``` python
from datasets import load_metric
```
``` python
metric = load_metric('metric.py')
```
``` python
metric.compute(predictions=[d['predictions'] for d in key2gen], references=[d['response'] for d in key2gen])
```
Output:
{'bleu-1': 47.9848465486215,
'bleu-2': 37.18000679532912,
'bleu-3': 29.346646172092814,
'bleu-4': 23.410526740211363,
'unigram f1': 0.4999850046010773,
'unigram f1 (non-stop words)': 0.5150265227462978,
'rouge1': 50.536642578692195,
'rouge2': 33.10681789367832,
'rougeL': 46.84702913163778,
'meteor': 0.4641962079490068}
``` python
metric.compute(predictions=[d['predictions'] for d in origin], references=[d['response'] for d in origin])
```
Output:
{'bleu-1': 37.570099942714585,
'bleu-2': 26.77393964962893,
'bleu-3': 21.115954644820572,
'bleu-4': 17.513316671216046,
'unigram f1': 0.3656930567072274,
'unigram f1 (non-stop words)': 0.36456219281235724,
'rouge1': 39.1982724920493,
'rouge2': 20.825159884632743,
'rougeL': 34.98278542180112,
'meteor': 0.3405671227693821,
'distinct-1': 0.07838670580160921,
'distinct-2': 0.29689084413659694}
``` python
metric.compute(predictions=[d['predictions'] for d in key2gen], references=[d['response'] for d in key2gen])
```
Output:
{'bleu-1': 47.9848465486215,
'bleu-2': 37.18000679532912,
'bleu-3': 29.346646172092814,
'bleu-4': 23.410526740211363,
'unigram f1': 0.4999850046010773,
'unigram f1 (non-stop words)': 0.5150265227462978,
'rouge1': AggregateScore(low=Score(precision=0.5301926525013549, recall=0.4821419251082986, fmeasure=0.48565655175230005), mid=Score(precision=0.5513392693168799, recall=0.50235850981064, fmeasure=0.5053664257869219), high=Score(precision=0.5760132731228504, recall=0.5268580272115051, fmeasure=0.5279111393835526)),
'rouge2': AggregateScore(low=Score(precision=0.34772127155901306, recall=0.30411953889228, fmeasure=0.31029658993105447), mid=Score(precision=0.3696898381097765, recall=0.32612705034192035, fmeasure=0.3310681789367832), high=Score(precision=0.3947745596965405, recall=0.34880792116864995, fmeasure=0.35356317521641434)),
'rougeL': AggregateScore(low=Score(precision=0.4874189522136045, recall=0.4413343070361347, fmeasure=0.4464463084888409), mid=Score(precision=0.5108530997712726, recall=0.4642203560120527, fmeasure=0.46847029131637785), high=Score(precision=0.5350154077389535, recall=0.4855131911095939, fmeasure=0.4899950876629784)),
'rougeLsum': AggregateScore(low=Score(precision=0.4871840444049138, recall=0.44081531444183386, fmeasure=0.44514075751478493), mid=Score(precision=0.5105975305923949, recall=0.4639265647317744, fmeasure=0.46779186414456864), high=Score(precision=0.5348015149575474, recall=0.48693312722760357, fmeasure=0.4918651382986408))}
from tabulate import tabulate
import os
import json
from tqdm import tqdm
from datasets import load_metric
import numpy as np
import csv
def evaluate(filename, metric):
"""
It reads the predictions, references, and knowledge from a file, and then computes the metric
:param filename: the path to the file containing the predictions
:param metric: the metric to use for evaluation
:return: The result of the evaluation.
"""
predictions, references, knowledge = [], [], []
with open(filename, 'r') as f:
for line in f:
item = json.loads(line)
predictions.append(item['predictions'])
references.append(item['response'])
knowledge.append(item['knowledge'])
result = metric.compute(predictions=predictions, references=references, knowledge=knowledge)
return result
def avg_result(results):
"""
It takes a list of dictionaries, and returns a dictionary with the same keys, but the values are the
mean and standard deviation of the values in the input dictionaries
:param results: a list of dictionaries, each dictionary is the result of a single run of the model
:return: The average and standard deviation of the results.
"""
ret = {}
for k in results[0]:
m = round(np.mean([result[k] for result in results]), 2)
v = round(np.std([result[k] for result in results], ddof=1), 2) if len(results) > 1 else None
ret[k] = f"{m}({v})"
return ret
if __name__ == '__main__':
from argparse import ArgumentParser
    parser = ArgumentParser(description="evaluate seq2seq generation results")
parser.add_argument("--output_dirs", type=str, nargs='*', required=True)
parser.add_argument('--tasks', '-t', type=str, nargs='*', choices=['nlg', 'kvret', 'opendialkg', 'personachat', 'wow'], help='names of tasks')
    parser.add_argument('--shots', '-s', type=int, nargs='*', help='numbers of training examples (shots) used in the experiments')
parser.add_argument('--dial_ids_orders', '-o', type=int, nargs='*', help='which data order is used for experiments')
args = parser.parse_args()
print(args)
table = []
fieldnames = []
for task_name in tqdm(args.tasks, desc='tasks'):
metric = load_metric("metric.py", task_name)
dataset_name = task_name if task_name != "nlg" else "multiwoz21"
for shot in tqdm(args.shots, desc='shots', leave=False):
for output_dir in tqdm(args.output_dirs, desc='models', leave=False):
model_name = output_dir.split('/')[-1]
results = []
for dial_ids_order in tqdm(args.dial_ids_orders, desc='dial_ids_orders', leave=False):
result_dir = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen")
result_file = os.path.join(result_dir, "result.json")
if not os.path.exists(result_file):
filename = os.path.join(output_dir, task_name, f"{dataset_name}_{shot}shot_order{dial_ids_order}/gen/generated_predictions.json")
result = evaluate(filename, metric)
json.dump(result, open(result_file, 'w', encoding='utf-8'), indent=2, ensure_ascii=False)
else:
result = json.load(open(result_file))
results.append(result)
res = {
"dataset": f"{task_name}-{shot}shot",
"model": f"{model_name}",
**avg_result(results)
}
table.append(res)
for k in res:
if k not in fieldnames:
fieldnames.append(k)
res = tabulate(table, headers='keys', tablefmt='github')
    with open('eval_results.txt', 'w', encoding='utf-8') as f:
print(res, file=f)
with open('eval_results.csv', 'w', newline='') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for res in table:
writer.writerow(res)
import datasets
FEATURES = {
"nlg": {
"context+knowledge": datasets.Value("string"),
"response": datasets.Value("string"),
"knowledge": {
"categorical": datasets.Sequence({
"intent": datasets.Value("string"),
"domain": datasets.Value("string"),
"slot": datasets.Value("string"),
"value": datasets.Value("string"),
}),
"non-categorical": datasets.Sequence({
"intent": datasets.Value("string"),
"domain": datasets.Value("string"),
"slot": datasets.Value("string"),
"value": datasets.Value("string"),
}),
"binary": datasets.Sequence({
"intent": datasets.Value("string"),
"domain": datasets.Value("string"),
"slot": datasets.Value("string"),
})
}},
"kvret": {
"context+knowledge": datasets.Value("string"),
"response": datasets.Value("string"),
"knowledge": {
"schedule": datasets.Sequence({
"entity": datasets.Value("string"),
"time": datasets.Value("string"),
"date": datasets.Value("string"),
"party": datasets.Value("string"),
"room": datasets.Value("string"),
"agenda": datasets.Value("string")
}),
"weather": datasets.Sequence({
"entity": datasets.Value("string"),
"today": datasets.Value("string"),
"monday": datasets.Value("string"),
"tuesday": datasets.Value("string"),
"wednesday": datasets.Value("string"),
"thursday": datasets.Value("string"),
"friday": datasets.Value("string"),
"saturday": datasets.Value("string"),
"sunday": datasets.Value("string"),
}),
"navigate": datasets.Sequence({
"entity": datasets.Value("string"),
"traffic_info": datasets.Value("string"),
"poi_type": datasets.Value("string"),
"address": datasets.Value("string"),
"distance": datasets.Value("string")
})
}},
"opendialkg": {
"context+knowledge": datasets.Value("string"),
"response": datasets.Value("string"),
"knowledge": datasets.Sequence(datasets.Sequence(datasets.Value("string"))),
},
"wow": {
"context+knowledge": datasets.Value("string"),
"response": datasets.Value("string"),
"knowledge": datasets.Sequence(datasets.Value("string")),
},
"personachat": {
"context+knowledge": datasets.Value("string"),
"response": datasets.Value("string"),
"knowledge": datasets.Sequence(datasets.Value("string")),
}
}
set -e
dataset_path=$1
model_name=$2
model_name_or_path=$3
dataset_name=$4
if [ "${dataset_name}" == "multiwoz21" ]
then
task_name="nlg"
else
task_name=${dataset_name}
fi
master_port=$5
n_gpus=2
cache_dir="../cache"
metric_name_or_path="metric.py"
source_column="context+knowledge"
target_column="response"
truncation_side="left"
max_source_length=512
max_target_length=512
per_device_train_batch_size=64
per_device_eval_batch_size=64
gradient_accumulation_steps=1
num_workers=16
lr=1e-3
num_train_epochs=100
for shot in 50 100 200
do
for dial_ids_order in 0 1 2 3 4
do
python create_data.py -t ${task_name} -d ${dataset_name} -o ${dial_ids_order} -s ${shot}
data_dir="data/${task_name}/${dataset_name}_${shot}shot_order${dial_ids_order}"
output_dir="output/${model_name}/${task_name}/${dataset_name}_${shot}shot_order${dial_ids_order}"
logging_dir="${output_dir}/runs"
train_file="${data_dir}/train.json"
validation_file="${data_dir}/validation.json"
# training
python -m torch.distributed.launch --master_port ${master_port} \
--nproc_per_node ${n_gpus} ../run_seq2seq.py \
--task_name ${task_name} \
--dataset_name ${dataset_path} \
--dataset_config_name ${task_name} \
--train_file ${train_file} \
--validation_file ${validation_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${model_name_or_path} \
--do_train \
--do_eval \
--save_strategy epoch \
--evaluation_strategy epoch \
--save_total_limit 1 \
--prediction_loss_only \
--load_best_model_at_end \
--overwrite_output_dir \
--cache_dir ${cache_dir} \
--output_dir ${output_dir} \
--logging_dir ${logging_dir} \
--preprocessing_num_workers ${num_workers} \
--dataloader_num_workers ${num_workers} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--optim adafactor \
--lr_scheduler_type constant \
--gradient_checkpointing
# inference
test_file="data/${task_name}/test.json"
gen_output_dir="${output_dir}/gen"
python -m torch.distributed.launch --master_port ${master_port} \
--nproc_per_node ${n_gpus} ../run_seq2seq.py \
--task_name ${task_name} \
--dataset_name ${dataset_path} \
--dataset_config_name ${task_name} \
--metric_name_or_path ${metric_name_or_path} \
--metric_config_name ${task_name} \
--test_file ${test_file} \
--source_column ${source_column} \
--target_column ${target_column} \
--max_source_length ${max_source_length} \
--max_target_length ${max_target_length} \
--truncation_side ${truncation_side} \
--model_name_or_path ${output_dir} \
--do_predict \
--predict_with_generate \
--cache_dir ${cache_dir} \
--output_dir ${gen_output_dir} \
--logging_dir ${logging_dir} \
--overwrite_output_dir \
--preprocessing_num_workers ${num_workers} \
--dataloader_num_workers ${num_workers} \
--per_device_train_batch_size ${per_device_train_batch_size} \
--per_device_eval_batch_size ${per_device_eval_batch_size} \
--gradient_accumulation_steps ${gradient_accumulation_steps} \
--learning_rate ${lr} \
--num_train_epochs ${num_train_epochs} \
--optim adafactor \
--lr_scheduler_type constant \
--gradient_checkpointing
done
done
# evaluation
python evaluate.py --output_dirs output/${model_name} -t ${task_name} -s 50 100 200 -o 0 1 2 3 4
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Grounded Dialog Generation Metric"""
import datasets
from sacrebleu.metrics import BLEU
from sacrebleu.utils import sum_of_lists
import re
from collections import Counter
import numpy as np
from nltk.corpus import stopwords
from rouge_score import rouge_scorer, scoring
from nltk.translate import meteor_score
from datasets.config import importlib_metadata, version
from convlab.base_models.t5.key2gen.features import FEATURES
from convlab.util import load_ontology
from copy import deepcopy
NLTK_VERSION = version.parse(importlib_metadata.version("nltk"))
if NLTK_VERSION >= version.Version("3.6.5"):
from nltk import word_tokenize
# Uncomment to download nltk_data for the first time running.
# import nltk
# nltk.download("wordnet")
# if NLTK_VERSION >= version.Version("3.6.5"):
# nltk.download("punkt")
# if NLTK_VERSION >= version.Version("3.6.6"):
# nltk.download("omw-1.4")
_CITATION = """
"""
_DESCRIPTION = """\
Metric to evaluate text generation models on the grounded dialog generation task.
"""
# TODO
_KWARGS_DESCRIPTION = """
Args:
predictions: list of predictions to score. Each predictions
should be a string.
references: list of reference for each prediction. Each
reference should be a string.
knowledge: task-specific grounded knowledge
Returns:
bleu-1/2/3/4: corpus-bleu score, from sacrebleu
rouge-1/2/L: ROUGE-F1, from rouge_score
meteor: METEOR, from nltk
unigram f1: unigram overlap, from parlai
distinct-1/2: from parlai
other knowledge utility score: task-specific knowledge utility metrics
"""
re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
stop_words = set(stopwords.words("english"))
def utt2words(s):
"""Lower text and remove punctuation, articles and extra whitespace.
from parlai https://github.com/facebookresearch/ParlAI/blob/9daae69320c07104493486e022c0e46a7871b253/parlai/core/metrics.py#L810"""
s = s.lower()
s = re_punc.sub(' ', s)
s = re_art.sub(' ', s)
return s.split()
def get_bleu(predictions, references):
"""bleu-1/2/3/4 from sacrebleu"""
references = [" " if ref=="" else ref for ref in references]
metrics = {}
bleu = BLEU(lowercase=True, force=False, tokenize=BLEU.TOKENIZER_DEFAULT, smooth_method="exp", smooth_value=None, effective_order=False)
stats = sum_of_lists(bleu._extract_corpus_statistics(predictions, [references]))
for n in range(1,5):
metrics[f"bleu-{n}"] = bleu.compute_bleu(
correct=stats[2: 2 + bleu.max_ngram_order],
total=stats[2 + bleu.max_ngram_order:],
sys_len=int(stats[0]), ref_len=int(stats[1]),
smooth_method=bleu.smooth_method, smooth_value=bleu.smooth_value,
effective_order=bleu.effective_order,
max_ngram_order=n).score
return metrics
def get_unigram_f1(predictions, references):
"""unigram f1 between prediction and reference, from parlai"""
metrics = {}
metrics["unigram f1"] = []
metrics["unigram f1 (non-stop words)"] = []
for prediction, reference in zip(predictions, references):
pred_items = utt2words(prediction)
gold_items = utt2words(reference)
for remove_stopwords in [False, True]:
if remove_stopwords:
pred_items = [w for w in pred_items if w not in stop_words]
gold_items = [w for w in gold_items if w not in stop_words]
common = Counter(pred_items) & Counter(gold_items)
num_same = sum(common.values())
if num_same == 0:
f1 = 0
else:
precision = 1.0 * num_same / len(pred_items)
recall = 1.0 * num_same / len(gold_items)
f1 = (2 * precision * recall) / (precision + recall)
if not remove_stopwords:
metrics["unigram f1"].append(f1)
else:
metrics["unigram f1 (non-stop words)"].append(f1)
metrics["unigram f1"] = np.mean(metrics["unigram f1"]) * 100
metrics["unigram f1 (non-stop words)"] = np.mean(metrics["unigram f1 (non-stop words)"]) * 100
return metrics
def get_rouge(predictions, references):
"""rouge-1/2/L from rouge-score"""
rouge_types=["rouge1", "rouge2", "rougeL"]
scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True)
aggregator = scoring.BootstrapAggregator()
for prediction, reference in zip(predictions, references):
score = scorer.score(reference, prediction)
aggregator.add_scores(score)
    # report F1 (f-measure) for all ROUGE variants, consistent with the docstring above
    return {key: value.mid.fmeasure * 100 for key, value in aggregator.aggregate().items()}
def get_meteor(predictions, references):
"""meteor from nltk"""
alpha=0.9
beta=3
gamma=0.5
if NLTK_VERSION >= version.Version("3.6.5"):
scores = [
meteor_score.single_meteor_score(
word_tokenize(ref), word_tokenize(pred), alpha=alpha, beta=beta, gamma=gamma
)
for ref, pred in zip(references, predictions)
]
else:
scores = [
meteor_score.single_meteor_score(ref, pred, alpha=alpha, beta=beta, gamma=gamma)
for ref, pred in zip(references, predictions)
]
return {"meteor": np.mean(scores) * 100}
def get_distinct(predictions):
"""distinct-1/2
from parlai https://github.com/facebookresearch/ParlAI/blob/9daae69320c07104493486e022c0e46a7871b253/parlai/core/metrics.py#L781"""
def _ngram(seq, n):
for i in range(len(seq) - n + 1):
yield tuple(seq[i : i + n])
metrics = {}
for k in [1, 2]:
inter_cnt = Counter()
for prediction in predictions:
ngram = Counter(_ngram(utt2words(prediction), k))
inter_cnt += ngram
metrics[f"distinct-{k}"] = max(len(inter_cnt), 1e-12) / max(sum(inter_cnt.values()), 1e-5) * 100
return metrics
def get_nlg_slot_err(predictions, knowledge):
"""slot error rate: (missing_count + redundant_count) / all_count for value in dialog acts"""
val2ds_dict = {}
ontology = load_ontology("multiwoz21")
for domain_name in ontology["domains"]:
domain = ontology["domains"][domain_name]
for slot_name in domain["slots"]:
slot = domain["slots"][slot_name]
if "possible_values" not in slot:
continue
possible_vals = slot["possible_values"]
if len(possible_vals) > 0:
for val in possible_vals:
val2ds_dict[val] = f"{domain_name}-{slot_name}"
score_list = []
for utterance, da in zip(predictions, knowledge):
missing_count = 0
redundant_count = 0
all_count = 0
all_values = set()
## missing values
# print(da)
# print(utterance)
for key in ['categorical', 'non-categorical']:
for value in da[key]['value']:
if len(value) > 0:
# print(value)
all_values.add(value)
if value.strip().lower() not in utterance.lower():
missing_count += 1
# print(f"\tmissing: {value}")
all_count += 1
if all_count == 0:
continue
## redundant values
for val in val2ds_dict:
if f" {val.strip().lower()} " in f" {utterance.strip().lower()} " and val.strip().lower() not in all_values:
wlist = val2ds_dict[val].split("-")
domain, slot = wlist[0], wlist[1]
if f" {slot.strip().lower()}" in f" {utterance.strip().lower()} ":
redundant_count += 1
# print(f"redundant: {val}/{val2ds_dict[val]}")
item_score = float(missing_count + redundant_count) / all_count
# print(f"\tredundant: {redundant_count} | missing_count: {missing_count} |all_count: {all_count}")
# print('-'*100)
score_list.append(item_score)
return {"err": np.mean(score_list) * 100}
def load_entities():
"""modified (load from unified ontology) from UnifiedSKG
https://github.com/HKUNLP/UnifiedSKG/blob/49a2ff950bb312b980c22ad72b11520db72ab6a3/metrics/kvret/evaluator.py#L8"""
ontology = load_ontology("kvret")
all_entities = set()
for domain in ontology["domains"]:
for slot in ontology["domains"][domain]["slots"]:
all_entities |= set(ontology["domains"][domain]["slots"][slot]["possible_values"])
missed_entities = ["yoga", "tennis", "swimming", "football", " lab ", "doctor", "optometrist", "dentist", "1st",
"2nd", "3rd", "4th", "5th", "6th", "7th", "8th", "9th", "10th",
"11th", "12th", "13th", "14th", "15th", "16th", "17th", "18th", "19th", "20th", "Jill",
"Jack"]
all_entities |= set(missed_entities)
all_entities.remove("HR")
all_entities.add(" HR ")
all_entities = sorted(list(all_entities), key=lambda i: len(i), reverse=True)
return all_entities
def check_sub_str(str_list: list, sub_str: str):
"""
It takes a list of strings and a substring as input, and returns True if the substring is found
in any of the strings in the list, and False otherwise
"""
for str_item in str_list:
if sub_str in str_item or sub_str.lower() in str_item.lower():
return True
return False
def extract_entities_from_utterance(utterance, sorted_entities):
"""modified (remove underscore) from UnifiedSKG
https://github.com/HKUNLP/UnifiedSKG/blob/49a2ff950bb312b980c22ad72b11520db72ab6a3/metrics/kvret/response_entity_hit.py#L45"""
utterance = " {} ".format(utterance) # for entity matching
for h in range(0, 13): # for formulating am & pm
utterance = utterance.replace("{} am".format(h), "{}am".format(h))
utterance = utterance.replace("{} pm".format(h), "{}pm".format(h))
for entity_item_a in [20, 30, 40, 50, 60, 70, 80, 90, 100]:
for entity_item_b in [20, 30, 40, 50, 60, 70, 80, 90, 100]:
utterance = utterance.replace("{}-{}f".format(str(entity_item_a), str(entity_item_b)), "{}f-{}f".format(str(entity_item_a), str(entity_item_b)))
entities_in_this_utterance = []
for entity in sorted_entities:
# len(entity) decreases
if (entity in utterance) or (entity.lower() in utterance.lower()):
if not check_sub_str(entities_in_this_utterance, entity):
# in case of "week & weekend", "week & next_week" etc
entities_in_this_utterance.append(entity)
return entities_in_this_utterance
def f1_score(y_pred, y_true, average="micro"):
"""micro/marco-F1 score, modified from UnifiedSKG
https://github.com/HKUNLP/UnifiedSKG/blob/49a2ff950bb312b980c22ad72b11520db72ab6a3/metrics/kvret/response_entity_hit.py#L76"""
assert len(y_pred) == len(y_true)
def _compute_F1(precision, recall):
return 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0
def _compute_prf(gold, pred):
TP, FP, FN = 0, 0, 0
if len(gold) != 0:
count = 1
for g in gold:
if g in pred:
TP += 1
else:
FN += 1
for p in set(pred):
if p not in gold:
FP += 1
precision = TP / float(TP + FP) if (TP + FP) != 0 else 0
recall = TP / float(TP + FN) if (TP + FN) != 0 else 0
F1 = _compute_F1(precision, recall)
else:
precision, recall, F1, count = 0, 0, 0, 0
return TP, FP, FN, F1, count
F1_pred, F1_count, TP_all, FP_all, FN_all = 0, 0, 0, 0, 0
for y_true_item, y_pred_item in zip(y_true, y_pred):
single_tp, single_fp, single_fn, single_f1, count = _compute_prf(y_true_item, y_pred_item)
F1_pred += single_f1
F1_count += count
TP_all += single_tp
FP_all += single_fp
FN_all += single_fn
if average == "macro":
F1_macro_score = F1_pred / float(F1_count) if F1_count != 0 else 0
return F1_macro_score * 100
elif average == "micro":
P_score = TP_all / float(TP_all + FP_all) if (TP_all + FP_all) != 0 else 0
R_score = TP_all / float(TP_all + FN_all) if (TP_all + FN_all) != 0 else 0
F1_micro_score = _compute_F1(P_score, R_score)
return F1_micro_score * 100
else:
raise ValueError("Options other than micro/macro are not supported.")
def get_kvret_entity_f1(predictions, references, knowledge):
"""entity f1 for kvret, modified from
https://github.com/HKUNLP/UnifiedSKG/blob/49a2ff950bb312b980c22ad72b11520db72ab6a3/metrics/kvret/response_entity_hit.py#L178"""
global_entities = load_entities()
F1_scores = {}
entities_from_predictions_and_references = {
d: {"predictions_entities": [], "references_entities": []} for d in ["all", "schedule", "weather", "navigate"]
}
for prediction, reference, kb in zip(predictions, references, knowledge):
prediction_entities = extract_entities_from_utterance(utterance=prediction, sorted_entities=global_entities)
reference_entities = extract_entities_from_utterance(utterance=reference, sorted_entities=global_entities)
entities_from_predictions_and_references["all"]["predictions_entities"].append(prediction_entities)
entities_from_predictions_and_references["all"]["references_entities"].append(reference_entities)
domain = "schedule"
for d in kb:
if len(kb[d]["entity"]) > 0:
domain = d
break
entities_from_predictions_and_references[domain]["predictions_entities"].append(prediction_entities)
entities_from_predictions_and_references[domain]["references_entities"].append(reference_entities)
for category in entities_from_predictions_and_references.keys():
predictions_entities = entities_from_predictions_and_references[category]["predictions_entities"]
references_entities = entities_from_predictions_and_references[category]["references_entities"]
F1_scores["{} micro entity F1".format(category)] = f1_score(y_pred=predictions_entities, y_true=references_entities, average="micro")
F1_scores["{} macro entity F1".format(category)] = f1_score(y_pred=predictions_entities, y_true=references_entities, average="macro")
return {**F1_scores}
def get_opendialkg_entity_f1(predictions, references, knowledge):
predictions_entities, references_entities = [], []
for prediction, reference, kg_path in zip(predictions, references, knowledge):
kg_entities = set()
for kg_triple in kg_path:
# add head and tail entities
kg_entities.add(kg_triple[0])
kg_entities.add(kg_triple[-1])
kg_entities = sorted(list(kg_entities), key=lambda i: len(i), reverse=True)
for utterance, entities in zip([prediction, reference], [predictions_entities, references_entities]):
entities_in_this_utterance = []
for entity in kg_entities:
if (entity in utterance) or (entity.lower() in utterance.lower()):
if not check_sub_str(entities_in_this_utterance, entity):
# in case of "week & weekend", "week & next_week" etc
entities_in_this_utterance.append(entity)
entities.append(entities_in_this_utterance)
return {
"micro entity f1": f1_score(y_pred=predictions_entities, y_true=references_entities, average="micro"),
"macro entity f1": f1_score(y_pred=predictions_entities, y_true=references_entities, average="macro")
}
def get_knowledge_sentences_f1(predictions, knowledge):
knowledge_reference = [' '.join(k_sens) for k_sens in knowledge]
f1_score = get_unigram_f1(predictions, knowledge_reference)
return {f"knowledge {k}": v for k, v in f1_score.items()}
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class GroundedDialogGenerationMetrics(datasets.Metric):
"""Metric to evaluate text generation models on the grounded dialog generation task."""
def _info(self):
return datasets.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features({
"predictions": datasets.Value("string"),
"references": datasets.Value("string"),
"knowledge": deepcopy(FEATURES[self.config_name]["knowledge"])
})
)
def compute(self, predictions, references, knowledge=None):
"""Returns the scores: bleu"""
metrics = {}
# bleu
metrics.update(get_bleu(predictions, references))
# unigram f1
metrics.update(get_unigram_f1(predictions, references))
# rouge-1/2/L-fmeasure
metrics.update(get_rouge(predictions, references))
# meteor
metrics.update(get_meteor(predictions, references))
# inter-distinct-1/2
metrics.update(get_distinct(predictions))
if knowledge is not None:
if self.config_name == "nlg":
metrics.update(get_nlg_slot_err(predictions, knowledge))
elif self.config_name == "kvret":
metrics.update(get_kvret_entity_f1(predictions, references, knowledge))
elif self.config_name == "opendialkg":
metrics.update(get_opendialkg_entity_f1(predictions, references, knowledge))
elif self.config_name in ["wow", "personachat"]:
metrics.update(get_knowledge_sentences_f1(predictions, knowledge))
return metrics
from convlab.base_models.t5.nlg.nlg import T5NLG
@@ -8,17 +8,13 @@ from convlab.util.custom_util import model_downloader
 class T5NLG(NLG):
-    def __init__(self, speaker, context_window_size, model_name_or_path, model_file=None, device='cuda'):
+    def __init__(self, speaker, context_window_size, model_name_or_path, device='cuda'):
         assert speaker in ['user', 'system']
         self.speaker = speaker
         self.opponent = 'system' if speaker == 'user' else 'user'
         self.context_window_size = context_window_size
         self.use_context = context_window_size > 0
-        model_dir = os.path.dirname(os.path.abspath(__file__))
-        if not os.path.exists(model_name_or_path):
-            model_downloader(model_dir, model_file)
         self.config = AutoConfig.from_pretrained(model_name_or_path)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config)
from convlab.base_models.t5.nlu.nlu import T5NLU
@@ -8,17 +8,13 @@ from convlab.util.custom_util import model_downloader
 class T5NLU(NLU):
-    def __init__(self, speaker, context_window_size, model_name_or_path, model_file=None, device='cuda'):
+    def __init__(self, speaker, context_window_size, model_name_or_path, device='cuda'):
         assert speaker in ['user', 'system']
         self.speaker = speaker
         self.opponent = 'system' if speaker == 'user' else 'user'
         self.context_window_size = context_window_size
         self.use_context = context_window_size > 0
-        model_dir = os.path.dirname(os.path.abspath(__file__))
-        if not os.path.exists(model_name_or_path):
-            model_downloader(model_dir, model_file)
         self.config = AutoConfig.from_pretrained(model_name_or_path)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config)