Skip to content
Snippets Groups Projects
Commit 5d22541b authored by Carel van Niekerk's avatar Carel van Niekerk :computer:
Browse files

Refactor uncertainty vectoriser

parent 2d5519ce
Branches
No related tags found
No related merge requests found
import os
from convlab.nlu import NLU
from convlab.dst import DST
......
from convlab.dst.dst import DST
from convlab.dst.setsumbt import SetSUMBTTracker
......@@ -223,10 +223,9 @@ class SetSUMBTTracker(DST):
user_acts = _output[2]
for domain in new_domains:
user_acts.append({'intent': 'inform', 'domain': domain, 'slot': '', 'value': ''})
user_acts.append(['inform', domain, 'none', 'none'])
new_belief_state = copy.deepcopy(prev_state['belief_state'])
# user_acts = []
for domain, substate in _output[0].items():
for slot, value in substate.items():
value = '' if value == 'none' else value
......@@ -247,7 +246,7 @@ class SetSUMBTTracker(DST):
new_belief_state[domain][slot] = value
if prev_state['belief_state'][domain][slot] != value:
user_acts.append({'intent': 'inform', 'domain': domain, 'slot': slot, 'value': value})
user_acts.append(['inform', domain, slot, value])
else:
bug = f'Unknown slot name <{slot}> with value <{value}> of domain <{domain}>'
logging.debug(bug)
......@@ -345,10 +344,7 @@ class SetSUMBTTracker(DST):
# Construct request action prediction
request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5]
request_acts = [slot.split('-', 1) for slot in request_acts]
request_acts = [{'intent': 'request',
'domain': domain,
'slot': slot,
'value': '?'} for domain, slot in request_acts]
request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts]
# Construct active domain set
active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()}
......@@ -356,7 +352,7 @@ class SetSUMBTTracker(DST):
# Construct general domain action
general_acts = general_act_probs[0, 0, :].argmax(-1).item()
general_acts = [[], ['bye'], ['thank']][general_acts]
general_acts = [{'intent': act, 'domain': 'general', 'slot': '', 'value': ''} for act in general_acts]
general_acts = [[act, 'general', 'none', 'none'] for act in general_acts]
user_acts = request_acts + general_acts
......@@ -417,22 +413,30 @@ class SetSUMBTTracker(DST):
return features
if __name__ == "__main__":
    # Manual smoke test: run the tracker over a short three-turn restaurant
    # dialogue and print the resulting state. Requires the checkpoint below
    # to be available on the local file system.
    tracker = SetSUMBTTracker(model_path='/gpfs/project/niekerk/src/SetSUMBT/models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-labelsmoothing-Seed0-10-08-22-12-42',
                              return_turn_pooled_representation=True, return_confidence_scores=True,
                              confidence_threshold='auto', return_belief_state_entropy=True,
                              return_belief_state_mutual_info=True, store_full_belief_state=True)
    tracker.init_session()

    # Turn 1
    state = tracker.update('hey. I need a cheap restaurant.')
    tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
    tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])

    # Turn 2
    state = tracker.update('If you have something Asian that would be great.')
    tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
    tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
    # Dialogue acts use the [intent, domain, slot, value] list layout, the
    # same layout the tracker emits for user acts.
    tracker.state['system_action'] = [['inform', 'restaurant', 'food', 'chinese'],
                                      ['inform', 'restaurant', 'name', 'the golden wok']]

    # Turn 3
    state = tracker.update('Great. Where are they located?')
    tracker.state['history'].append(['usr', 'Great. Where are they located?'])

    print(tracker.state)
    print(tracker.full_belief_state)
# if __name__ == "__main__":
# from convlab.policy.vector.vector_uncertainty import VectorUncertainty
# # from convlab.policy.vector.vector_binary import VectorBinary
# tracker = SetSUMBTTracker(model_path='/gpfs/project/niekerk/src/SetSUMBT/models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-labelsmoothing-Seed0-10-08-22-12-42',
# return_confidence_scores=True, confidence_threshold='auto',
# return_belief_state_entropy=True)
# vector = VectorUncertainty(use_state_total_uncertainty=True, confidence_thresholds=tracker.confidence_thresholds,
# use_masking=True)
# # vector = VectorBinary()
# tracker.init_session()
#
# state = tracker.update('hey. I need a cheap restaurant.')
# tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
# tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
# state = tracker.update('If you have something Asian that would be great.')
# tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
# tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
# tracker.state['system_action'] = [['inform', 'restaurant', 'food', 'chinese'],
# ['inform', 'restaurant', 'name', 'the golden wok']]
# state = tracker.update('Great. Where are they located?')
# tracker.state['history'].append(['usr', 'Great. Where are they located?'])
# state = tracker.state
# state['terminated'] = False
# state['booked'] = {}
#
# print(state)
# print(vector.state_vectorize(state))
# -*- coding: utf-8 -*-
import sys
import os
import numpy as np
import logging
from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da
from convlab.util.multiwoz.state import default_state
from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
from .vector_binary import VectorBinary as VectorBase
DEFAULT_INTENT_FILEPATH = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(
os.path.dirname(os.path.abspath(__file__))))),
'data/multiwoz/trackable_intent.json'
)
SLOT_MAP = {'taxi_types': 'car type'}
class MultiWozVector(VectorBase):
def __init__(self, voc_file=None, voc_opp_file=None, character='sys',
intent_file=DEFAULT_INTENT_FILEPATH,
use_confidence_scores=False,
use_entropy=False,
use_mutual_info=False,
use_masking=False,
manually_add_entity_names=False,
seed=0,
shrink=False):
from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da
from convlab.policy.vector.vector_binary import VectorBinary
class VectorUncertainty(VectorBinary):
"""Vectorise state and state uncertainty predictions"""
def __init__(self,
             dataset_name: str = 'multiwoz21',
             character: str = 'sys',
             use_masking: bool = False,
             manually_add_entity_names: bool = True,
             seed: int = 0,
             use_confidence_scores: bool = True,
             confidence_thresholds: dict = None,
             use_state_total_uncertainty: bool = False,
             use_state_knowledge_uncertainty: bool = False):
    """
    Args:
        dataset_name: Name of environment dataset
        character: Character of the agent (sys/usr)
        use_masking: If true certain actions are masked during devectorisation
        manually_add_entity_names: If true inform entity name actions are manually added
        seed: Seed
        use_confidence_scores: If true confidence scores are used in state vectorisation
        confidence_thresholds: Optional thresholds used to drop low-confidence
            constraints in database querying; when given, confidence scores
            are switched on regardless of use_confidence_scores
        use_state_total_uncertainty: If true state entropy is added to the state vector
        use_state_knowledge_uncertainty: If true state mutual information is added to the state vector
    """
    # These flags are assigned before the base initialiser runs, as in the
    # original ordering (presumably the base class sizes the state vector
    # via get_state_dim() during construction - confirm against VectorBinary).
    self.use_confidence_scores = use_confidence_scores
    self.use_state_total_uncertainty = use_state_total_uncertainty
    self.use_state_knowledge_uncertainty = use_state_knowledge_uncertainty

    # dbquery_domain tests both attributes with "is not None", so they must
    # exist even when no thresholds are supplied and nothing was vectorised yet.
    self.confidence_thresholds = None
    self.confidence_scores = None
    if confidence_thresholds is not None:
        self.setup_uncertain_query(confidence_thresholds)

    super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed)
def get_state_dim(self):
    """Compute and store the belief-state and full state vector sizes.

    Sets ``self.belief_state_dim`` and ``self.state_dim``; returns nothing.
    """
    # Per-slot layout:
    #   1 dim  - indicator / confidence score
    #   +1 dim - entropy (total uncertainty), when enabled
    #   +1 dim - mutual information (knowledge uncertainty), when enabled
    slot_dim = 1
    if self.use_state_total_uncertainty:
        slot_dim += 1
    if self.use_state_knowledge_uncertainty:
        slot_dim += 1

    state_ontology = self.ontology['state']
    num_slots = sum(len(state_ontology[domain]) for domain in state_ontology)
    self.belief_state_dim = num_slots * slot_dim

    # NOTE(review): the trailing terms are presumably one booking flag and
    # six DB-pointer entries per DB domain plus one terminal flag - confirm
    # against the base vectoriser.
    self.state_dim = self.da_opp_dim + self.da_dim + self.belief_state_dim + \
        len(self.db_domains) + 6 * len(self.db_domains) + 1
# Add thresholds for db_queries
def setup_uncertain_query(self, confidence_thresholds):
    """Enable confidence-based filtering of database query constraints.

    Args:
        confidence_thresholds: Thresholds kept on the instance for use when
            querying the database; looked up per domain and then per slot.
    """
    # Thresholds are only meaningful together with confidence scores, so
    # turning them on also forces confidence scores on.
    self.use_confidence_scores = True
    self.confidence_thresholds = confidence_thresholds
    logging.info('DB Search uncertainty activated.')
def dbquery_domain(self, domain):
    """
    query entities of specified domain

    Constraints are taken from the current belief state of the domain.
    Empty and "don't care" values are ignored, and constraints whose inform
    confidence is below the configured threshold are dropped.

    Args:
        domain: Name of the domain to query

    Returns:
        The result of ``self.db.query`` for the remaining constraints
        (limited to the top 10 entries): list of entities of the specified
        domain
    """
    dont_care_values = ('dontcare', "do n't care", "do not care")

    # Get all user constraints; a constraint needs an actual value, so
    # unfilled slots are skipped rather than passed on to the database.
    domain_state = self.state[domain] if domain in self.state else dict()
    constraints = {slot: value for slot, value in domain_state.items()
                   if slot and value and value not in dont_care_values}

    # Remove constraints for which the uncertainty is high. getattr keeps the
    # query usable before any scores or thresholds were stored on the instance.
    confidence_scores = getattr(self, 'confidence_scores', None)
    confidence_thresholds = getattr(self, 'confidence_thresholds', None)
    if confidence_scores is not None and self.use_confidence_scores and confidence_thresholds is not None:
        # Threshold per slot, defaulting to 0.05 for unlisted slots
        slot_thresholds = confidence_thresholds.get(domain, dict())
        # Inform confidence per slot, defaulting to 1.0 when unavailable
        slot_probs = confidence_scores.get(domain, dict())

        # Filter out constraints for which confidence is lower than threshold
        constraints = {slot: value for slot, value in constraints.items()
                       if slot_probs.get(slot, {}).get('inform', 1.0) >= slot_thresholds.get(slot, 0.05)}

    return self.db.query(domain, constraints.items(), topk=10)
# Add thresholds for db_queries
def setup_uncertain_query(self, thresholds):
    """Enable confidence-based filtering of database query constraints.

    Args:
        thresholds: Thresholds kept on the instance for use when querying
            the database.
    """
    self.use_confidence_scores = True
    # Store under the legacy attribute name and under the name read by
    # dbquery_domain, so either spelling sees the same thresholds.
    self.thresholds = thresholds
    self.confidence_thresholds = thresholds
    logging.info('DB Search uncertainty activated.')
def vectorize_user_act(self, state):
    """Return confidence scores for the user actions

    Args:
        state: Dialogue state holding 'user_action' / 'system_action' and,
            optionally, 'belief_state_probs'.

    Returns:
        Vector of length ``self.da_opp_dim``. Entries of acts present in the
        opposing party's last action hold the act confidence (1.0 when no
        confidence is available); all other entries are 0.
    """
    # Remember the scores so that database queries can use them later on.
    self.confidence_scores = state['belief_state_probs'] if 'belief_state_probs' in state else None

    action = state['user_action'] if self.character == 'sys' else state['system_action']
    opp_action = flat_da(delexicalize_da(action, self.requestable))

    use_probs = 'belief_state_probs' in state and self.use_confidence_scores
    opp_act_vec = np.zeros(self.da_opp_dim)
    for da in opp_action:
        if da not in self.opp2vec:
            continue

        # An observed act counts as certain unless a confidence is found.
        prob = 1.0
        if use_probs:
            domain, intent, slot, value = da.split('-')
            if domain in state['belief_state_probs']:
                domain_probs = state['belief_state_probs'][domain]
                slot = slot if slot else 'none'
                if slot in domain_probs:
                    slot_probs = domain_probs[slot]
                elif slot.lower() in domain_probs:
                    slot_probs = domain_probs[slot.lower()]
                else:
                    slot_probs = dict()

                if intent in slot_probs:
                    prob = float(slot_probs[intent])
        opp_act_vec[self.opp2vec[da]] = prob

    return opp_act_vec


def vectorize_user_act_confidence_scores(self, state, opp_action):
    """Alias of vectorize_user_act kept for callers of the previous name.

    ``opp_action`` is ignored; the opposing action is recomputed from ``state``.
    """
    return self.vectorize_user_act(state)
def state_vectorize(self, state):
"""vectorize a state
Args:
state (dict):
Dialog state
action (tuple):
Dialog act
Returns:
state_vec (np.array):
Dialog state vector
"""
self.state = state['belief_state']
self.confidence_scores = state['belief_state_probs'] if 'belief_state_probs' in state else None
domain_active_dict = {}
for domain in self.belief_domains:
domain_active_dict[domain] = False
# when character is sys, to help query database when da is booking-book
# update current domain according to user action
if self.character == 'sys':
action = state['user_action']
for intent, domain, slot, value in action:
domain_active_dict[domain] = True
action = state['user_action'] if self.character == 'sys' else state['system_action']
opp_action = delexicalize_da(action, self.requestable)
opp_action = flat_da(opp_action)
if 'belief_state_probs' in state and self.use_confidence_scores:
opp_act_vec = self.vectorize_user_act_confidence_scores(
state, opp_action)
prob = 1.0
else:
opp_act_vec = np.zeros(self.da_opp_dim)
for da in opp_action:
if da in self.opp2vec:
prob = 1.0
opp_act_vec[self.opp2vec[da]] = prob
action = state['system_action'] if self.character == 'sys' else state['user_action']
action = delexicalize_da(action, self.requestable)
action = flat_da(action)
last_act_vec = np.zeros(self.da_dim)
for da in action:
if da in self.act2vec:
last_act_vec[self.act2vec[da]] = 1.
return opp_act_vec
def vectorize_belief_state(self, state, domain_active_dict):
    """Vectorise the belief state and, optionally, its uncertainty.

    The vector holds one entry per slot (inform confidence, or a 0/1
    indicator when confidence scores are unused), followed by one block of
    per-slot entropies and one block of per-slot mutual information when
    the respective option is enabled and present in ``state``.

    Args:
        state: Dialogue state with 'belief_state' and optionally
            'belief_state_probs', 'active_domains', 'entropy' and
            'mutual_information'.
        domain_active_dict: Mapping domain -> active flag; updated in place.

    Returns:
        Tuple ``(belief_state, domain_active_dict)`` where ``belief_state``
        is an array of length ``self.belief_state_dim``.
    """
    belief_state = np.zeros(self.belief_state_dim)
    use_probs = self.use_confidence_scores and 'belief_state_probs' in state

    i = 0
    for domain in self.belief_domains:
        domain_state = state['belief_state'][domain]
        if use_probs:
            for slot in domain_state:
                prob = None
                if slot in state['belief_state_probs'][domain]:
                    prob = state['belief_state_probs'][domain][slot]
                    prob = prob['inform'] if 'inform' in prob else None
                if prob:
                    belief_state[i] = float(prob)
                # Advance for every slot so positions stay aligned
                i += 1
        else:
            for slot, value in domain_state.items():
                if value and value != 'not mentioned':
                    belief_state[i] = 1.
                i += 1

        if 'active_domains' in state:
            domain_active_dict[domain] = state['active_domains'][domain]
        elif any(domain_state.values()):
            # Without a prediction, a domain is active once any slot is filled
            domain_active_dict[domain] = True

    # Add knowledge and/or total uncertainty to the belief state. Both blocks
    # walk the slots exactly as above so that each slot keeps its position.
    uncertainty_blocks = ((self.use_state_total_uncertainty, 'entropy'),
                          (self.use_state_knowledge_uncertainty, 'mutual_information'))
    for enabled, key in uncertainty_blocks:
        if not (enabled and key in state):
            continue
        for domain in self.belief_domains:
            for slot in state['belief_state'][domain]:
                if slot in state[key][domain]:
                    belief_state[i] = float(state[key][domain][slot])
                i += 1

    return belief_state, domain_active_dict
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment