tracker.py
    import os
    import json
    import copy
    import logging
    
    import torch
    import transformers
    from transformers import BertModel, BertConfig, BertTokenizer, RobertaModel, RobertaConfig, RobertaTokenizer
    
    from convlab.dst.setsumbt.modeling import RobertaSetSUMBT, BertSetSUMBT
    from convlab.dst.setsumbt.modeling.training import set_ontology_embeddings
    from convlab.dst.dst import DST
    from convlab.util.custom_util import model_downloader
    
    USE_CUDA = torch.cuda.is_available()
    transformers.logging.set_verbosity_error()
    
    
    class SetSUMBTTracker(DST):
        """SetSUMBT Tracker object for Convlab dialogue system"""
    
        def __init__(self,
                     model_path: str = "",
                     model_type: str = "roberta",
                     return_turn_pooled_representation: bool = False,
                     return_confidence_scores: bool = False,
                     confidence_threshold='auto',
                     return_belief_state_entropy: bool = False,
                     return_belief_state_mutual_info: bool = False,
                     store_full_belief_state: bool = True):
            """
            Args:
                model_path: Model path or download URL
                model_type: Transformer type (roberta/bert)
                return_turn_pooled_representation: If true, a turn-level pooled representation is returned
                return_confidence_scores: If true, act confidence scores are included in the state
                confidence_threshold: Confidence threshold value for constraints, or 'auto' to derive per-slot thresholds
                return_belief_state_entropy: If true, belief state distribution entropies are included in the state
                return_belief_state_mutual_info: If true, belief state distribution mutual information scores are included in the state
                store_full_belief_state: If true, the full belief state is stored within the tracker object
            """
            super(SetSUMBTTracker, self).__init__()
    
            self.model_type = model_type
            self.model_path = model_path
            self.return_turn_pooled_representation = return_turn_pooled_representation
            self.return_confidence_scores = return_confidence_scores
            self.confidence_threshold = confidence_threshold
            self.return_belief_state_entropy = return_belief_state_entropy
            self.return_belief_state_mutual_info = return_belief_state_mutual_info
            self.store_full_belief_state = store_full_belief_state
            if self.store_full_belief_state:
                self.full_belief_state = {}
            self.info_dict = {}
    
            # Download model if needed
            if not os.path.exists(self.model_path):
                # Get path /.../convlab/dst/setsumbt/multiwoz/models
                download_path = os.path.dirname(os.path.abspath(__file__))
                download_path = os.path.join(download_path, 'models')
                if not os.path.exists(download_path):
                    os.mkdir(download_path)
                model_downloader(download_path, self.model_path)
                # Downloadable model path format http://.../model_name.zip
                self.model_path = self.model_path.split('/')[-1].replace('.zip', '')
                self.model_path = os.path.join(download_path, self.model_path)
    
            # Select model type based on the encoder
            if model_type == "roberta":
                self.config = RobertaConfig.from_pretrained(self.model_path)
                self.tokenizer = RobertaTokenizer
                self.model = RobertaSetSUMBT
            elif model_type == "bert":
                self.config = BertConfig.from_pretrained(self.model_path)
                self.tokenizer = BertTokenizer
                self.model = BertSetSUMBT
            else:
                raise NameError(f"Model type '{model_type}' is not implemented. Use 'roberta' or 'bert'.")
    
            self.device = torch.device('cuda') if USE_CUDA else torch.device('cpu')
    
            self.load_weights()
    
        def load_weights(self):
            """Load model weights and model ontology"""
            logging.info('Loading SetSUMBT pretrained model.')
            self.tokenizer = self.tokenizer.from_pretrained(self.config.tokenizer_name)
            logging.info(f'Model tokenizer loaded from {self.config.tokenizer_name}.')
            self.model = self.model.from_pretrained(self.model_path, config=self.config)
            logging.info(f'Model loaded from {self.model_path}.')
    
            # Transfer model to compute device and setup eval environment
            self.model = self.model.to(self.device)
            self.model.eval()
            logging.info(f'Model transferred to device: {self.device}')
    
            logging.info('Loading model ontology')
            with open(os.path.join(self.model_path, 'database', 'test.json'), 'r') as f:
                self.ontology = json.load(f)
    
            db = torch.load(os.path.join(self.model_path, 'database', 'test.db'))
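            # test.db stores the precomputed ontology (slot and value candidate) embeddings;
            # attach them to the model via set_ontology_embeddings so value candidates can be
            # scored at inference time.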
            set_ontology_embeddings(self.model, db)
    
            if self.return_confidence_scores:
                logging.info('Model returns user action and belief state confidence scores.')
                self.get_thresholds(self.confidence_threshold)
                logging.info('Uncertain querying set up with confidence thresholds:')
                logging.info(self.confidence_thresholds)
            if self.return_belief_state_entropy:
                logging.info('Model returns belief state distribution entropy scores (Total uncertainty).')
            if self.return_belief_state_mutual_info:
                logging.info('Model returns belief state distribution mutual information scores (Knowledge uncertainty).')
            logging.info('Ontology loaded successfully.')
    
        def get_thresholds(self, threshold='auto') -> dict:
            """
            Setup dictionary of domain specific confidence thresholds
    
            Args:
                threshold: Threshold value or option auto
    
            Returns:
                confidence_thresholds: Domain specific confidence thresholds
            """
            self.confidence_thresholds = dict()
            for domain, substate in self.ontology.items():
                for slot, slot_info in substate.items():
                    # Auto thresholds are set based on the number of value candidates per slot
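                    # e.g. a slot with 12 candidate values gets an auto threshold of
                    # 1 / (12 - 2.1) ≈ 0.101, so slots with more candidates need a lower
                    # confidence for a predicted value to be kept.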
                    if domain not in self.confidence_thresholds:
                        self.confidence_thresholds[domain] = dict()
                    if threshold == 'auto':
                        thres = 1.0 / (float(len(slot_info['possible_values'])) - 2.1)
                        self.confidence_thresholds[domain][slot] = max(0.05, thres)
                    else:
                        self.confidence_thresholds[domain][slot] = max(0.05, threshold)
    
            return self.confidence_thresholds
    
        def init_session(self):
            """Initialise a new dialogue session with an empty Convlab state dictionary."""
            self.state = dict()
            self.state['belief_state'] = dict()
            self.state['booked'] = dict()
            for domain, substate in self.ontology.items():
                self.state['belief_state'][domain] = dict()
                for slot, slot_info in substate.items():
                    if slot_info['possible_values'] and slot_info['possible_values'] != ['?']:
                        self.state['belief_state'][domain][slot] = ''
                self.state['booked'][domain] = list()
            self.state['history'] = []
            self.state['system_action'] = []
            self.state['user_action'] = []
            self.state['terminated'] = False
            self.active_domains = {}
            self.hidden_states = None
            self.info_dict = {}
    
        def update(self, user_act: str = '') -> dict:
            """
            Update the user actions, dialogue state and belief state for the current turn.

            Args:
                user_act: User utterance for the current turn

            Returns:
                self.state: The updated dialogue state dictionary
            """
            prev_state = self.state
            _output = self.predict(self.get_features(user_act))
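            # predict() returns: (0) dialogue state, (1) active domains, (2) user acts,
            # (3) turn pooled representation, (4) confidence scores, (5) state entropy,
            # (6) state mutual information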
    
            # Format state entropy
            if _output[5] is not None:
                state_entropy = dict()
                for slot, e in _output[5].items():
                    domain, slot = slot.split('-', 1)
                    if domain not in state_entropy:
                        state_entropy[domain] = dict()
                    state_entropy[domain][slot] = e
            else:
                state_entropy = None
    
            # Format state mutual information
            if _output[6] is not None:
                state_mutual_info = dict()
                for slot, mi in _output[6].items():
                    domain, slot = slot.split('-', 1)
                    if domain not in state_mutual_info:
                        state_mutual_info[domain] = dict()
                    state_mutual_info[domain][slot] = mi[0, 0]
            else:
                state_mutual_info = None
    
            # Format all confidence scores
            belief_state_confidence = None
            if _output[4] is not None:
                belief_state_confidence = dict()
                belief_state_conf, request_probs, active_domain_probs, general_act_probs = _output[4]
                for slot, p in belief_state_conf.items():
                    domain, slot = slot.split('-', 1)
                    if domain not in belief_state_confidence:
                        belief_state_confidence[domain] = dict()
                    if slot not in belief_state_confidence[domain]:
                        belief_state_confidence[domain][slot] = dict()
                    belief_state_confidence[domain][slot]['inform'] = p
    
                for slot, p in request_probs.items():
                    domain, slot = slot.split('-', 1)
                    if domain not in belief_state_confidence:
                        belief_state_confidence[domain] = dict()
                    if slot not in belief_state_confidence[domain]:
                        belief_state_confidence[domain][slot] = dict()
                    belief_state_confidence[domain][slot]['request'] = p
    
                for domain, p in active_domain_probs.items():
                    if domain not in belief_state_confidence:
                        belief_state_confidence[domain] = dict()
                    belief_state_confidence[domain]['none'] = {'inform': p}
    
                if 'general' not in belief_state_confidence:
                    belief_state_confidence['general'] = dict()
                belief_state_confidence['general']['none'] = general_act_probs
    
            # Get new domain activation actions
            new_domains = [d for d, active in _output[1].items() if active]
            new_domains = [d for d in new_domains if not self.active_domains.get(d, False)]
            self.active_domains = _output[1]
    
            user_acts = _output[2]
            for domain in new_domains:
                user_acts.append(['inform', domain, 'none', 'none'])
    
            new_belief_state = copy.deepcopy(prev_state['belief_state'])
            for domain, substate in _output[0].items():
                for slot, value in substate.items():
                    value = '' if value == 'none' else value
                    value = 'dontcare' if value == 'do not care' else value
                    value = 'guesthouse' if value == 'guest house' else value
    
                    if domain not in new_belief_state:
                        if domain == 'bus':
                            continue
                        else:
                            logging.debug('Error: domain <{}> not in belief state'.format(domain))
    
                    # Uncertainty clipping of state
                    if belief_state_confidence is not None:
                        threshold = self.confidence_thresholds[domain][slot]
                        if belief_state_confidence[domain][slot].get('inform', 1.0) < threshold:
                            value = ''
    
                    try:
                        new_belief_state[domain][slot] = value
                        if prev_state['belief_state'][domain][slot] != value:
                            user_acts.append(['inform', domain, slot, value])
                    except KeyError:
                        bug = f'Unknown slot name <{slot}> with value <{value}> of domain <{domain}>'
                        logging.debug(bug)
    
            new_state = copy.deepcopy(dict(prev_state))
            new_state['belief_state'] = new_belief_state
            new_state['active_domains'] = self.active_domains
            if belief_state_confidence is not None:
                new_state['belief_state_probs'] = belief_state_confidence
            if state_entropy is not None:
                new_state['entropy'] = state_entropy
            if state_mutual_info is not None:
                new_state['mutual_information'] = state_mutual_info
    
            user_acts = [act for act in user_acts if act not in new_state['system_action']]
            new_state['user_action'] = user_acts
    
            if _output[3] is not None:
                new_state['turn_pooled_representation'] = _output[3]
    
            self.state = new_state
            self.info_dict = copy.deepcopy(dict(new_state))
    
            return self.state
    
        def predict(self, features: dict) -> tuple:
            """
            Model forward pass and prediction post processing.
    
            Args:
                features: Dictionary of model input features
    
            Returns:
                out: Model predictions and uncertainty features
            """
            state_mutual_info = None
            with torch.no_grad():
                turn_pooled_representation = None
                if self.return_turn_pooled_representation:
                    _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
                                          attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
                                          get_turn_pooled_representation=True)
                    belief_state = _outputs[0]
                    request_probs = _outputs[1]
                    active_domain_probs = _outputs[2]
                    general_act_probs = _outputs[3]
                    self.hidden_states = _outputs[4]
                    turn_pooled_representation = _outputs[5]
                elif self.return_belief_state_mutual_info:
                    _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
                                          attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
                                          get_turn_pooled_representation=True, calculate_state_mutual_info=True)
                    belief_state = _outputs[0]
                    request_probs = _outputs[1]
                    active_domain_probs = _outputs[2]
                    general_act_probs = _outputs[3]
                    self.hidden_states = _outputs[4]
                    state_mutual_info = _outputs[5]
                else:
                    _outputs = self.model(input_ids=features['input_ids'], token_type_ids=features['token_type_ids'],
                                          attention_mask=features['attention_mask'], hidden_state=self.hidden_states,
                                          get_turn_pooled_representation=False)
                    belief_state, request_probs, active_domain_probs, general_act_probs, self.hidden_states = _outputs
    
            # Convert belief state into dialog state
            dialogue_state = dict()
            for slot, probs in belief_state.items():
                dom, slot = slot.split('-', 1)
                if dom not in dialogue_state:
                    dialogue_state[dom] = dict()
                val = self.ontology[dom][slot]['possible_values'][probs[0, 0, :].argmax().item()]
                if val != 'none':
                    dialogue_state[dom][slot] = val
    
            if self.store_full_belief_state:
                self.info_dict['belief_state_distributions'] = belief_state
                if state_mutual_info is not None:
                    self.info_dict['belief_state_knowledge_uncertainty'] = state_mutual_info
    
            # Obtain model output probabilities
            if self.return_confidence_scores:
                state_entropy = None
                if self.return_belief_state_entropy:
                    state_entropy = {slot: probs[0, 0, :] for slot, probs in belief_state.items()}
                    state_entropy = {slot: self.relative_entropy(p).item() for slot, p in state_entropy.items()}
    
                # Confidence score is the maximum probability across all non-'none' value candidates.
                belief_state_conf = {slot: probs[0, 0, 1:].max().item() for slot, probs in belief_state.items()}
                _request_probs = {slot: p[0, 0].item() for slot, p in request_probs.items()}
                _active_domain_probs = {domain: p[0, 0].item() for domain, p in active_domain_probs.items()}
                _general_act_probs = {'bye': general_act_probs[0, 0, 1].item(), 'thank': general_act_probs[0, 0, 2].item()}
                confidence_scores = (belief_state_conf, _request_probs, _active_domain_probs, _general_act_probs)
            else:
                confidence_scores = None
                state_entropy = None
    
            # Construct request action prediction
            if request_probs is not None:
                request_acts = [slot for slot, p in request_probs.items() if p[0, 0].item() > 0.5]
                request_acts = [slot.split('-', 1) for slot in request_acts]
                request_acts = [['request', domain, slot, '?'] for domain, slot in request_acts]
            else:
                request_acts = list()
    
            # Construct active domain set
            if active_domain_probs is not None:
                active_domains = {domain: p[0, 0].item() > 0.5 for domain, p in active_domain_probs.items()}
            else:
                active_domains = dict()
    
            # Construct general domain action
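            # The general act head is a 3-way classifier over {none, bye, thank};
            # the argmax index selects which general act, if any, the user performed.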
            if general_act_probs is not None:
                general_acts = general_act_probs[0, 0, :].argmax(-1).item()
                general_acts = [[], ['bye'], ['thank']][general_acts]
                general_acts = [[act, 'general', 'none', 'none'] for act in general_acts]
            else:
                general_acts = list()
    
            user_acts = request_acts + general_acts
    
            out = (dialogue_state, active_domains, user_acts, turn_pooled_representation, confidence_scores)
            out += (state_entropy, state_mutual_info)
            return out
    
        def relative_entropy(self, probs: torch.Tensor) -> torch.Tensor:
            """
            Compute the relative entropy (entropy normalised by its maximum) of a probability distribution
    
            Args:
                probs: Probability distributions
    
            Returns:
                entropy: Relative entropy
            """
            entropy = probs * torch.log(probs + 1e-8)
            entropy = -entropy.sum()
            # Maximum entropy of a K-dimensional distribution is ln(K)
            entropy /= torch.log(torch.tensor(probs.size(-1)).float())
    
            return entropy
    
        def get_features(self, user_act: str) -> dict:
            """
            Tokenize utterances and construct model input features
    
            Args:
                user_act: Current turn user utterance
    
            Returns:
                features: Model input features
            """
            # Extract system utterance from dialog history
            context = self.state['history']
            if context:
                if context[-1][0] != 'sys':
                    system_act = ''
                else:
                    system_act = context[-1][-1]
            else:
                system_act = ''
    
            # Tokenize dialog
            features = self.tokenizer.encode_plus(user_act, system_act, add_special_tokens=True,
                                                  max_length=self.config.max_turn_len, padding='max_length',
                                                  truncation='longest_first')
    
            input_ids = torch.tensor(features['input_ids']).reshape(1, 1, -1).to(self.device) \
                if 'input_ids' in features else None
            token_type_ids = torch.tensor(features['token_type_ids']).reshape(1, 1, -1).to(self.device) \
                if 'token_type_ids' in features else None
            attention_mask = torch.tensor(features['attention_mask']).reshape(1, 1, -1).to(self.device) \
                if 'attention_mask' in features else None
            features = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': attention_mask}
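            # Each returned tensor has shape (1, 1, max_turn_len): batch and turn dimensions
            # of 1 for the single-turn forward pass performed in predict().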
    
            return features
    
    
    # if __name__ == "__main__":
    #     from convlab.policy.vector.vector_uncertainty import VectorUncertainty
    #     # from convlab.policy.vector.vector_binary import VectorBinary
    #     tracker = SetSUMBTTracker(model_path='/gpfs/project/niekerk/src/SetSUMBT/models/SetSUMBT+ActPrediction-multiwoz21-roberta-gru-cosine-labelsmoothing-Seed0-10-08-22-12-42',
    #                               return_confidence_scores=True, confidence_threshold='auto',
    #                               return_belief_state_entropy=True)
    #     vector = VectorUncertainty(use_state_total_uncertainty=True, confidence_thresholds=tracker.confidence_thresholds,
    #                                use_masking=True)
    #     # vector = VectorBinary()
    #     tracker.init_session()
    #
    #     state = tracker.update('hey. I need a cheap restaurant.')
    #     tracker.state['history'].append(['usr', 'hey. I need a cheap restaurant.'])
    #     tracker.state['history'].append(['sys', 'There are many cheap places, which food do you like?'])
    #     state = tracker.update('If you have something Asian that would be great.')
    #     tracker.state['history'].append(['usr', 'If you have something Asian that would be great.'])
    #     tracker.state['history'].append(['sys', 'The Golden Wok is a nice cheap chinese restaurant.'])
    #     tracker.state['system_action'] = [['inform', 'restaurant', 'food', 'chinese'],
    #                                       ['inform', 'restaurant', 'name', 'the golden wok']]
    #     state = tracker.update('Great. Where are they located?')
    #     tracker.state['history'].append(['usr', 'Great. Where are they located?'])
    #     state = tracker.state
    #     state['terminated'] = False
    #     state['booked'] = {}
    #
    #     print(state)
    #     print(vector.state_vectorize(state))