diff --git a/convlab/dst/setsumbt/__init__.py b/convlab/dst/setsumbt/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..e20840c55ab89d9de54bb0018ce27af3f2e5bff4 100644 --- a/convlab/dst/setsumbt/__init__.py +++ b/convlab/dst/setsumbt/__init__.py @@ -0,0 +1 @@ +from convlab.dst.setsumbt.tracker import SetSUMBTTracker diff --git a/convlab/dst/setsumbt/datasets/metrics.py b/convlab/dst/setsumbt/datasets/metrics.py index 690d75baeaee4ef6ec5c1d963e8238db0044c57b..3dfdc9f187929f374cfb7ea13a7c45ca46a10195 100644 --- a/convlab/dst/setsumbt/datasets/metrics.py +++ b/convlab/dst/setsumbt/datasets/metrics.py @@ -22,9 +22,8 @@ import torch from transformers.utils import ModelOutput from matplotlib import pyplot as plt -from convlab.util import load_dataset -from convlab.util import load_dst_data -from convlab.dst.setsumbt.datasets.value_maps import VALUE_MAP, QUANTITIES +from convlab.util import load_dataset, load_dst_data +from convlab.dst.setsumbt.datasets.utils import clean_states class Metrics(ModelOutput): @@ -106,76 +105,11 @@ class JointGoalAccuracy: Returns: dict: The cleaned state. """ - clean_state = dict() - for domain, subset in state.items(): - clean_state[domain] = {} - for slot, value in subset.items(): - value = value.split('|') - - # Map values using value_map - for old, new in VALUE_MAP.items(): - value = [val.replace(old, new) for val in value] - value = '|'.join(value) - - # Map dontcare to "do not care" and empty to 'none' - value = value.replace('dontcare', 'do not care') - value = value if value else 'none' - - # Map quantity values to the integer quantity value - if 'people' in slot or 'duration' in slot or 'stay' in slot: - try: - if value not in ['do not care', 'none']: - value = int(value) - value = str(value) if value < 10 else QUANTITIES[-1] - except: - value = value - # Map time values to the most appropriate value in the standard time set - elif 'time' in slot or 'leave' in slot or 'arrive' in slot: - try: - if value not in ['do not care', 'none']: - # Strip after/before from time value - value = value.replace('after ', '').replace('before ', '') - # Extract hours and minutes from different possible formats - if ':' not in value and len(value) == 4: - h, m = value[:2], value[2:] - elif len(value) == 1: - h = int(value) - m = 0 - elif 'pm' in value: - h = int(value.replace('pm', '')) + 12 - m = 0 - elif 'am' in value: - h = int(value.replace('pm', '')) - m = 0 - elif ':' in value: - h, m = value.split(':') - elif ';' in value: - h, m = value.split(';') - # Map to closest 5 minutes - if int(m) % 5 != 0: - m = round(int(m) / 5) * 5 - h = int(h) - if m == 60: - m = 0 - h += 1 - if h >= 24: - h -= 24 - # Set in standard 24 hour format - h, m = int(h), int(m) - value = '%02i:%02i' % (h, m) - except: - value = value - # Map boolean slots to yes/no value - elif 'parking' in slot or 'internet' in slot: - if value not in ['do not care', 'none']: - if value == 'free': - value = 'yes' - elif True in [v in value.lower() for v in ['yes', 'no']]: - value = [v for v in ['yes', 'no'] if v in value][0] - - value = value if value != 'none' else '' - - clean_state[domain][slot] = value + + turns = [{'dialogue_acts': list(), + 'state': state}] + turns = clean_states(turns) + clean_state = turns[0]['state'] return clean_state diff --git a/convlab/dst/setsumbt/datasets/utils.py b/convlab/dst/setsumbt/datasets/utils.py index f227a569c5cfc782dda1fbedb61b5afbffa17cf5..39ce3fd37e58322d1b130eb1a4f4571c876f059d 100644 --- a/convlab/dst/setsumbt/datasets/utils.py +++ b/convlab/dst/setsumbt/datasets/utils.py @@ -41,15 +41,17 @@ def get_ontology_slots(dataset_name: str) -> dict: if domain_name not in ontology_slots: ontology_slots[domain_name] = dict() for slot, slot_info in ontology['domains'][domain]['slots'].items(): - if slot not in ontology_slots[domain_name]: - ontology_slots[domain_name][slot] = {'description': slot_info['description'], - 'possible_values': list(), - 'dataset_names': list()} + slot_name = slot.replace('.', '_') + if slot_name not in ontology_slots[domain_name]: + ontology_slots[domain_name][slot_name] = {'description': slot_info['description'], + 'possible_values': list(), + 'dataset_names': list()} if slot_info['is_categorical']: - ontology_slots[domain_name][slot]['possible_values'] += slot_info['possible_values'] + ontology_slots[domain_name][slot_name]['possible_values'] += slot_info['possible_values'] - ontology_slots[domain_name][slot]['possible_values'] = list(set(ontology_slots[domain_name][slot]['possible_values'])) - ontology_slots[domain_name][slot]['dataset_names'].append(dataset_name) + unique_vals = list(set(ontology_slots[domain_name][slot_name]['possible_values'])) + ontology_slots[domain_name][slot_name]['possible_values'] = unique_vals + ontology_slots[domain_name][slot_name]['dataset_names'].append(dataset_name) return ontology_slots @@ -81,11 +83,11 @@ def get_values_from_data(dataset: dict, data_split: str = "train") -> dict: if domain_name not in value_sets: value_sets[domain_name] = {} for slot, value in substate.items(): - if slot not in value_sets[domain_name]: - value_sets[domain_name][slot] = [] - if value and value not in value_sets[domain_name][slot]: - value_sets[domain_name][slot].append(value) - # pdb.set_trace() + slot_name = slot.replace('.', '_') + if slot_name not in value_sets[domain_name]: + value_sets[domain_name][slot_name] = [] + if value and value not in value_sets[domain_name][slot_name]: + value_sets[domain_name][slot_name].append(value) return clean_values(value_sets) @@ -209,7 +211,7 @@ def get_requestable_slots(datasets: list) -> dict: domain_name = DOMAINS_MAP.get(domain, domain.lower()) if domain_name not in slots: slots[domain_name] = [] - slots[domain_name].append(slot) + slots[domain_name].append(slot.replace('.', '_')) slots = {domain: list(set(slot_list)) for domain, slot_list in slots.items()} @@ -295,6 +297,7 @@ def clean_states(turns: list) -> list: for act in turn['dialogue_acts']: domain = act['domain'] act['domain'] = DOMAINS_MAP.get(domain, domain.lower()) + act['slot'] = act['slot'].replace('.', '_') clean_acts.append(act) for domain, subset in turn['state'].items(): domain_name = DOMAINS_MAP.get(domain, domain.lower()) @@ -363,7 +366,7 @@ def clean_states(turns: list) -> list: elif True in [v in value.lower() for v in ['yes', 'no']]: value = [v for v in ['yes', 'no'] if v in value][0] - clean_state[domain_name][slot] = value + clean_state[domain_name][slot.replace('.', '_')] = value turn['state'] = clean_state turn['dialogue_acts'] = clean_acts clean_turns.append(turn)