Skip to content
Snippets Groups Projects
Commit e19624c3 authored by Carel van Niekerk's avatar Carel van Niekerk :computer:
Browse files

Refactoring

parent c4517f41
No related branches found
No related tags found
No related merge requests found
......@@ -19,16 +19,15 @@ from convlab.util import load_dataset, load_ontology, load_dst_data, load_nlu_da
from convlab.dst.setsumbt.dataset.value_maps import *
# Load slots, descriptions and categorical slot values from dataset ontology
def get_ontology_slots(dataset_name: str) -> dict:
'''
"""
Function to extract slots, slot descriptions and categorical slot values from the dataset ontology.
Args:
dataset_name (str): Dataset name
Returns:
ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values
'''
"""
dataset_names = dataset_name.split('+') if '+' in dataset_name else [dataset_name]
ontology_slots = dict()
for dataset_name in dataset_names:
......@@ -50,16 +49,15 @@ def get_ontology_slots(dataset_name: str) -> dict:
return ontology_slots
# Load possible slot values from the oracle states in the dataset
def get_values_from_data(dataset: dict) -> dict:
'''
"""
Function to extract the possible slot values from the oracle states in the dataset.
Args:
dataset (dict): Dataset dictionary obtained using the load_dataset function
Returns:
value_sets (dict): Dictionary containing possible values obtained from dataset
'''
"""
data = load_dst_data(dataset, data_split='all', speaker='user')
value_sets = {}
for set_type, dataset in data.items():
......@@ -77,7 +75,15 @@ def get_values_from_data(dataset: dict) -> dict:
return clean_values(value_sets)
def combine_value_sets(value_sets):
def combine_value_sets(value_sets: list) -> dict:
"""
Function to combine value sets extracted from different datasets
Args:
value_sets (list): List of value sets extracted using the get_values_from_data function
Returns:
value_set (dict): Dictionary containing possible values obtained from datasets
"""
value_set = value_sets[0]
for _value_set in value_sets[1:]:
for domain, domain_info in _value_set.items():
......@@ -92,17 +98,16 @@ def combine_value_sets(value_sets):
return value_set
# Clean the possible values for the ontology
def clean_values(value_sets: dict, value_map: dict = VALUE_MAP) -> dict:
'''
"""
Function to clean up the possible value sets extracted from the states in the dataset
Args:
value_sets (dict): Dictionary containing possible values obtained from dataset
value_map (dict): MultiWOZ specific label map to avoid duplication and typos in values
value_map (dict): Label map to avoid duplication and typos in values
Returns:
clean_vals (dict): Cleaned Dictionary containing possible values obtained from dataset
'''
"""
clean_vals = {}
for domain, subset in value_sets.items():
clean_vals[domain] = {}
......@@ -130,9 +135,8 @@ def clean_values(value_sets: dict, value_map: dict=VALUE_MAP) -> dict:
return clean_vals
# Add value sets obtained from the dataset to the ontology
def ontology_add_values(ontology_slots: dict, value_sets: dict) -> dict:
'''
"""
Add value sets obtained from the dataset to the ontology
Args:
ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values
......@@ -140,7 +144,7 @@ def ontology_add_values(ontology_slots: dict, value_sets: dict) -> dict:
Returns:
ontology (dict): Ontology dictionary containing slots, slot descriptions and possible value sets
'''
"""
ontology = {}
for domain in sorted(ontology_slots):
ontology[domain] = {}
......@@ -158,16 +162,15 @@ def ontology_add_values(ontology_slots: dict, value_sets: dict) -> dict:
return ontology
# Get set of requestable slots from the dataset action labels
def get_requestable_slots(datasets: list) -> dict:
'''
"""
Function to get set of requestable slots from the dataset action labels.
Args:
datasets (list): List of dataset dictionaries obtained using the load_dataset function
Returns:
slots (dict): Dictionary containing requestable domain-slot pairs
'''
"""
datasets = [load_nlu_data(dataset, data_split='all', speaker='user') for dataset in datasets]
slots = {}
......@@ -189,9 +192,8 @@ def get_requestable_slots(datasets: list) -> dict:
return slots
# Add requestable slots obtained from the dataset to the ontology
def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict) -> dict:
'''
"""
Add requestable slots obtained from the dataset to the ontology
Args:
ontology_slots (dict): Ontology dictionary containing slots, descriptions and categorical slot values
......@@ -200,7 +202,7 @@ def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict
Returns:
ontology_slots (dict): Ontology dictionary containing slots, slot descriptions and
possible value sets including requests
'''
"""
for domain in ontology_slots:
for slot in ontology_slots[domain]:
if domain in requestable_slots:
......@@ -210,26 +212,28 @@ def ontology_add_requestable_slots(ontology_slots: dict, requestable_slots: dict
return ontology_slots
# Extract the required information from the data provided by unified loader
def extract_turns(dialogue: list) -> list:
'''
"""
Extract the required information from the data provided by unified loader
Args:
dialogue (list): List of turns within a dialogue
Returns:
turns (list): List of turns within a dialogue
'''
"""
turns = []
turn_info = {}
for turn in dialogue:
if turn['speaker'] == 'system':
turn_info['system_utterance'] = turn['utterance']
if turn['utt_idx'] == 1:
# System utterance in the first turn is always empty as conversation is initiated by the user
if turn['utt_idx'] == 1:
turn_info['system_utterance'] = ''
if turn['speaker'] == 'user':
turn_info['user_utterance'] = turn['utterance']
# Inform acts not required by model
turn_info['dialogue_acts'] = [act for act in turn['dialogue_acts']['categorical']
if act['intent'] not in ['inform']]
......@@ -237,6 +241,7 @@ def extract_turns(dialogue: list) -> list:
if act['intent'] not in ['inform']]
turn_info['dialogue_acts'] += [act for act in turn['dialogue_acts']['binary']
if act['intent'] not in ['inform']]
turn_info['state'] = turn['state']
if 'system_utterance' in turn_info and 'user_utterance' in turn_info:
......@@ -246,16 +251,15 @@ def extract_turns(dialogue: list) -> list:
return turns
# Clean the state within each turn of a dialogue (cleaning values and mapping to options used in ontology)
def clean_states(turns: list) -> list:
'''
"""
Clean the state within each turn of a dialogue (cleaning values and mapping to options used in ontology)
Args:
turns (list): List of turns within a dialogue
Returns:
clean_turns (list): List of turns within a dialogue
'''
"""
clean_turns = []
for turn in turns:
clean_state = {}
......@@ -339,16 +343,15 @@ def clean_states(turns: list) -> list:
return clean_turns
# Get active domains at each turn in a dialogue
def get_active_domains(turns: list) -> list:
'''
"""
Get active domains at each turn in a dialogue
Args:
turns (list): List of turns within a dialogue
Returns:
turns (list): List of turns within a dialogue
'''
"""
for turn_id in range(len(turns)):
# At first turn all domains with not none values in the state are active
if turn_id == 0:
......@@ -371,22 +374,22 @@ def get_active_domains(turns: list) -> list:
if val != 'none':
domains.append(domain_name)
# Add all domains activated by a user action
domains += [act['domain'] for act in turns[turn_id]['dialogue_acts'] if act['domain'] in turns[turn_id]['state']]
domains += [act['domain'] for act in turns[turn_id]['dialogue_acts']
if act['domain'] in turns[turn_id]['state']]
turns[turn_id]['active_domains'] = list(set(domains))
return turns
# Extract all dialogues from dataset
def extract_dialogues(data: list) -> list:
'''
"""
Extract all dialogues from dataset
Args:
data (list): List of all dialogues in a subset of the data
Returns:
dialogues (list): List of all extracted dialogues
'''
"""
dialogues = []
for dial in data:
turns = extract_turns(dial['turns'])
......
# -*- coding: utf-8 -*-
# Copyright 2022 DSML Group, Heinrich Heine University, Düsseldorf
# Authors: Carel van Niekerk (niekerk@hhu.de)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convlab3 Unified dataset value maps"""
# MultiWOZ specific label map to avoid duplication and typos in values
VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'and', 'b and b': 'bed and breakfast',
'cityroomz': 'city roomz', ' ': ' ', 'acorn house': 'acorn guest house', 'marriot': 'marriott',
......@@ -7,7 +25,7 @@ VALUE_MAP = {'guesthouse': 'guest house', 'belfry': 'belfray', '-': ' ', '&': 'a
'north americanindian': 'north american', 'caribbeanindian': 'indian', 'sheeps': "sheep's"}
# Domain map for SGD Data
# Domain map for SGD and TM Data
DOMAINS_MAP = {'Alarm_1': 'alarm', 'Banks_1': 'banks', 'Banks_2': 'banks', 'Buses_1': 'bus', 'Buses_2': 'bus',
'Buses_3': 'bus', 'Calendar_1': 'calendar', 'Events_1': 'events', 'Events_2': 'events',
'Events_3': 'events', 'Flights_1': 'flights', 'Flights_2': 'flights', 'Flights_3': 'flights',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment