Skip to content
Snippets Groups Projects
Commit f88cdbf8 authored by Carel van Niekerk's avatar Carel van Niekerk
Browse files

Bug fixes for SetSUMBT

parent 307750cd
Branches
No related tags found
No related merge requests found
from convlab.dst.setsumbt.tracker import SetSUMBTTracker
...@@ -22,9 +22,8 @@ import torch ...@@ -22,9 +22,8 @@ import torch
from transformers.utils import ModelOutput from transformers.utils import ModelOutput
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from convlab.util import load_dataset from convlab.util import load_dataset, load_dst_data
from convlab.util import load_dst_data from convlab.dst.setsumbt.datasets.utils import clean_states
from convlab.dst.setsumbt.datasets.value_maps import VALUE_MAP, QUANTITIES
class Metrics(ModelOutput): class Metrics(ModelOutput):
...@@ -106,76 +105,11 @@ class JointGoalAccuracy: ...@@ -106,76 +105,11 @@ class JointGoalAccuracy:
Returns: Returns:
dict: The cleaned state. dict: The cleaned state.
""" """
clean_state = dict()
for domain, subset in state.items(): turns = [{'dialogue_acts': list(),
clean_state[domain] = {} 'state': state}]
for slot, value in subset.items(): turns = clean_states(turns)
value = value.split('|') clean_state = turns[0]['state']
# Map values using value_map
for old, new in VALUE_MAP.items():
value = [val.replace(old, new) for val in value]
value = '|'.join(value)
# Map dontcare to "do not care" and empty to 'none'
value = value.replace('dontcare', 'do not care')
value = value if value else 'none'
# Map quantity values to the integer quantity value
if 'people' in slot or 'duration' in slot or 'stay' in slot:
try:
if value not in ['do not care', 'none']:
value = int(value)
value = str(value) if value < 10 else QUANTITIES[-1]
except:
value = value
# Map time values to the most appropriate value in the standard time set
elif 'time' in slot or 'leave' in slot or 'arrive' in slot:
try:
if value not in ['do not care', 'none']:
# Strip after/before from time value
value = value.replace('after ', '').replace('before ', '')
# Extract hours and minutes from different possible formats
if ':' not in value and len(value) == 4:
h, m = value[:2], value[2:]
elif len(value) == 1:
h = int(value)
m = 0
elif 'pm' in value:
h = int(value.replace('pm', '')) + 12
m = 0
elif 'am' in value:
h = int(value.replace('pm', ''))
m = 0
elif ':' in value:
h, m = value.split(':')
elif ';' in value:
h, m = value.split(';')
# Map to closest 5 minutes
if int(m) % 5 != 0:
m = round(int(m) / 5) * 5
h = int(h)
if m == 60:
m = 0
h += 1
if h >= 24:
h -= 24
# Set in standard 24 hour format
h, m = int(h), int(m)
value = '%02i:%02i' % (h, m)
except:
value = value
# Map boolean slots to yes/no value
elif 'parking' in slot or 'internet' in slot:
if value not in ['do not care', 'none']:
if value == 'free':
value = 'yes'
elif True in [v in value.lower() for v in ['yes', 'no']]:
value = [v for v in ['yes', 'no'] if v in value][0]
value = value if value != 'none' else ''
clean_state[domain][slot] = value
return clean_state return clean_state
......
...@@ -41,15 +41,17 @@ def get_ontology_slots(dataset_name: str) -> dict: ...@@ -41,15 +41,17 @@ def get_ontology_slots(dataset_name: str) -> dict:
if domain_name not in ontology_slots: if domain_name not in ontology_slots:
ontology_slots[domain_name] = dict() ontology_slots[domain_name] = dict()
for slot, slot_info in ontology['domains'][domain]['slots'].items(): for slot, slot_info in ontology['domains'][domain]['slots'].items():
if slot not in ontology_slots[domain_name]: slot_name = slot.replace('.', '_')
ontology_slots[domain_name][slot] = {'description': slot_info['description'], if slot_name not in ontology_slots[domain_name]:
ontology_slots[domain_name][slot_name] = {'description': slot_info['description'],
'possible_values': list(), 'possible_values': list(),
'dataset_names': list()} 'dataset_names': list()}
if slot_info['is_categorical']: if slot_info['is_categorical']:
ontology_slots[domain_name][slot]['possible_values'] += slot_info['possible_values'] ontology_slots[domain_name][slot_name]['possible_values'] += slot_info['possible_values']
ontology_slots[domain_name][slot]['possible_values'] = list(set(ontology_slots[domain_name][slot]['possible_values'])) unique_vals = list(set(ontology_slots[domain_name][slot_name]['possible_values']))
ontology_slots[domain_name][slot]['dataset_names'].append(dataset_name) ontology_slots[domain_name][slot_name]['possible_values'] = unique_vals
ontology_slots[domain_name][slot_name]['dataset_names'].append(dataset_name)
return ontology_slots return ontology_slots
...@@ -81,11 +83,11 @@ def get_values_from_data(dataset: dict, data_split: str = "train") -> dict: ...@@ -81,11 +83,11 @@ def get_values_from_data(dataset: dict, data_split: str = "train") -> dict:
if domain_name not in value_sets: if domain_name not in value_sets:
value_sets[domain_name] = {} value_sets[domain_name] = {}
for slot, value in substate.items(): for slot, value in substate.items():
if slot not in value_sets[domain_name]: slot_name = slot.replace('.', '_')
value_sets[domain_name][slot] = [] if slot_name not in value_sets[domain_name]:
if value and value not in value_sets[domain_name][slot]: value_sets[domain_name][slot_name] = []
value_sets[domain_name][slot].append(value) if value and value not in value_sets[domain_name][slot_name]:
# pdb.set_trace() value_sets[domain_name][slot_name].append(value)
return clean_values(value_sets) return clean_values(value_sets)
...@@ -209,7 +211,7 @@ def get_requestable_slots(datasets: list) -> dict: ...@@ -209,7 +211,7 @@ def get_requestable_slots(datasets: list) -> dict:
domain_name = DOMAINS_MAP.get(domain, domain.lower()) domain_name = DOMAINS_MAP.get(domain, domain.lower())
if domain_name not in slots: if domain_name not in slots:
slots[domain_name] = [] slots[domain_name] = []
slots[domain_name].append(slot) slots[domain_name].append(slot.replace('.', '_'))
slots = {domain: list(set(slot_list)) for domain, slot_list in slots.items()} slots = {domain: list(set(slot_list)) for domain, slot_list in slots.items()}
...@@ -295,6 +297,7 @@ def clean_states(turns: list) -> list: ...@@ -295,6 +297,7 @@ def clean_states(turns: list) -> list:
for act in turn['dialogue_acts']: for act in turn['dialogue_acts']:
domain = act['domain'] domain = act['domain']
act['domain'] = DOMAINS_MAP.get(domain, domain.lower()) act['domain'] = DOMAINS_MAP.get(domain, domain.lower())
act['slot'] = act['slot'].replace('.', '_')
clean_acts.append(act) clean_acts.append(act)
for domain, subset in turn['state'].items(): for domain, subset in turn['state'].items():
domain_name = DOMAINS_MAP.get(domain, domain.lower()) domain_name = DOMAINS_MAP.get(domain, domain.lower())
...@@ -363,7 +366,7 @@ def clean_states(turns: list) -> list: ...@@ -363,7 +366,7 @@ def clean_states(turns: list) -> list:
elif True in [v in value.lower() for v in ['yes', 'no']]: elif True in [v in value.lower() for v in ['yes', 'no']]:
value = [v for v in ['yes', 'no'] if v in value][0] value = [v for v in ['yes', 'no'] if v in value][0]
clean_state[domain_name][slot] = value clean_state[domain_name][slot.replace('.', '_')] = value
turn['state'] = clean_state turn['state'] = clean_state
turn['dialogue_acts'] = clean_acts turn['dialogue_acts'] = clean_acts
clean_turns.append(turn) clean_turns.append(turn)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment