Skip to content
Snippets Groups Projects
Commit dedad8e7 authored by Carel van Niekerk's avatar Carel van Niekerk :computer:
Browse files

Fix bug in ontology enxtraction

parent d7ba8e0f
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,6 @@ def evaluate(predict_result):
metrics = {'TP':0, 'FP':0, 'FN':0}
acc = []
for sample in predict_result:
pred_state = sample['predictions']['state']
gold_state = sample['state']
......
......@@ -16,6 +16,7 @@
"""Convlab3 Unified dataset data processing utilities"""
import numpy
import pdb
from convlab.util import load_ontology, load_dst_data, load_nlu_data
from convlab.dst.setsumbt.dataset.value_maps import VALUE_MAP, DOMAINS_MAP, QUANTITIES, TIME
......@@ -68,7 +69,9 @@ def get_values_from_data(dataset: dict, data_split: str = "train") -> dict:
data = load_dst_data(dataset, data_split='all', speaker='user')
# Remove test data from the data when building training/validation ontology
if data_split in ['train', 'validation']:
if data_split == 'train':
data = {key: itm for key, itm in data.items() if key == 'train'}
elif data_split == 'validation':
data = {key: itm for key, itm in data.items() if key in ['train', 'validation']}
value_sets = {}
......@@ -76,13 +79,14 @@ def get_values_from_data(dataset: dict, data_split: str = "train") -> dict:
for turn in dataset:
for domain, substate in turn['state'].items():
domain_name = DOMAINS_MAP.get(domain, domain.lower())
if domain not in value_sets:
if domain_name not in value_sets:
value_sets[domain_name] = {}
for slot, value in substate.items():
if slot not in value_sets[domain_name]:
value_sets[domain_name][slot] = []
if value and value not in value_sets[domain_name][slot]:
value_sets[domain_name][slot].append(value)
# pdb.set_trace()
return clean_values(value_sets)
......@@ -165,6 +169,9 @@ def ontology_add_values(ontology_slots: dict, value_sets: dict, data_split: str
if data_split in ['train', 'validation']:
if domain not in value_sets:
continue
possible_values = [v for slot, vals in value_sets[domain].items() for v in vals]
if len(possible_values) == 0:
continue
ontology[domain] = {}
for slot in sorted(ontology_slots[domain]):
if not ontology_slots[domain][slot]['possible_values']:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment