Skip to content
Snippets Groups Projects
Select Git revision
  • c796811b2abe6b5c901b3886118bed3f09251731
  • master default protected
  • towards_1.8.0
  • updateTLC
  • 1.1.0-stups
  • 1.0.2-stups
  • 1.0.1-stups
  • 1.0.0-stups
8 results

pom.xml

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    lexicalize.py 3.94 KiB
    from copy import deepcopy
    
    from convlab.util.multiwoz.multiwoz_slot_trans import REF_SYS_DA
    from convlab.util import relative_import_module_from_unified_datasets
    
    # Slot-name mapping loaded from the multiwoz21 unified-dataset preprocess
    # script. lexicalize_da uses it as `map.get(slot, slot)` to translate a
    # dialog-act slot name, and indexes a nested `map['taxi']` sub-mapping for
    # the taxi domain. NOTE(review): presumably maps unified slot names back to
    # the original MultiWOZ abbreviations — confirm against preprocess.py.
    reverse_da_slot_name_map = relative_import_module_from_unified_datasets('multiwoz21', 'preprocess.py', 'reverse_da_slot_name_map')
    
    def delexicalize_da(da, requestable):
        """Replace the concrete values of a dialog act with placeholders.

        Every ``(intent, domain, slot, value)`` entry of ``da`` becomes a
        ``[domain, intent, slot, placeholder]`` list (note the swapped order of
        intent and domain). An empty slot name is normalised to ``'none'``.
        The placeholder is:

        * ``'?'`` when ``intent`` is in ``requestable``;
        * ``'none'`` when the slot is ``'none'``;
        * otherwise a 1-based running count, as a string, of how many times
          this intent/domain/slot combination has been seen so far in ``da``.

        The original ``value`` is discarded.
        """
        occurrences = {}
        result = []
        for intent, domain, slot, _value in da:
            slot = 'none' if slot == "" else slot
            if intent in requestable:
                placeholder = '?'
            elif slot == 'none':
                placeholder = 'none'
            else:
                # Number repeated intent/domain/slot combinations 1, 2, 3, ...
                key = '_'.join([intent, domain, slot])
                occurrences[key] = occurrences.get(key, 0) + 1
                placeholder = str(occurrences[key])
            result.append([domain, intent, slot, placeholder])
        return result
    
    
    def flat_da(delexicalized_da):
        """Flatten delexicalized acts into ``'domain_intent_slot_value'`` strings.

        Each element of ``delexicalized_da`` (a sequence of strings, e.g. the
        output of ``delexicalize_da``) is joined with ``'_'``; the joined
        strings are returned as a new list in input order.
        """
        return list(map('_'.join, delexicalized_da))
    
    
    def deflat_da(meta):
        """Group flat ``'domain_intent_slot_value'`` strings by domain and intent.

        This is the inverse of ``flat_da``.

        Args:
            meta: iterable of strings, each containing exactly three ``'_'``
                separators. It is only read, never modified.

        Returns:
            dict mapping ``'domain_intent'`` to a list of ``[slot, value]``
            lists. Keys and pairs appear in first-seen order.

        Raises:
            ValueError: if a string does not split into exactly four parts on
                ``'_'``.
        """
        # No defensive deepcopy: the elements are immutable strings and the
        # input is never mutated. Dropping the copy also lets callers pass a
        # generator, which deepcopy cannot copy.
        dialog_act = {}
        for da in meta:
            d, i, s, v = da.split('_')
            dialog_act.setdefault('_'.join((d, i)), []).append([s, v])
        return dialog_act
    
    
    def _lexicalize_slot_value(domain, slot, index, entities, state):
        """Resolve one indexed slot placeholder of ``domain`` to a concrete value.

        ``index`` is a 1-based entity index as a string (e.g. ``'1'``). The
        slot is looked up in ``entities[domain][index - 1]`` under several
        spellings: as-is, with spaces removed, capitalized, and the name
        derived through ``reverse_da_slot_name_map`` / ``REF_SYS_DA``. If none
        match, it falls back to ``state[domain]``. The result is
        ``'not available'`` when the entity exists but nothing matches. It is
        ``'none'`` when there is no such entity and the state has no
        (non-empty) value.
        """
        # try to retrieve value from the database entity, otherwise from the belief state
        if domain != 'taxi':
            slot_reverse = reverse_da_slot_name_map.get(slot, slot)
        else:
            slot_reverse = reverse_da_slot_name_map['taxi'].get(slot, slot)
        try:
            slot_old = REF_SYS_DA[domain.capitalize()].get(slot_reverse, slot.lower())
        except KeyError:
            # The domain has no entry in REF_SYS_DA. Only this case is caught;
            # a bare except used to swallow KeyboardInterrupt/SystemExit too.
            slot_old = ""
        n = int(index) - 1
        # Placeholders are 1-based, so n must not be negative. Without the
        # lower bound, a '0' would wrap around to the last entity.
        if 0 <= n < len(entities[domain]):
            entity = entities[domain][n]
            for key in (slot, "".join(slot.split(" ")), slot.capitalize(), slot_old):
                if key in entity:
                    return entity[key]
            if slot in state[domain]:
                return state[domain][slot]
            return 'not available'
        if slot in state[domain]:
            return state[domain][slot] if state[domain][slot] else 'none'
        return 'none'


    def lexicalize_da(meta, entities, state, requestable):
        """Replace the placeholder values of a deflated dialog act with real values.

        Args:
            meta: dict mapping ``'domain_intent'`` to a list of ``[slot, value]``
                pairs, as returned by ``deflat_da``. Values are placeholders:
                ``'?'``, ``'none'`` or a 1-based entity index such as ``'1'``.
                The argument is deep-copied and left unmodified.
            entities: dict mapping a domain to a list of entity dicts; an index
                placeholder selects one of them. (Presumably database query
                results — confirm with the caller.)
            state: dict mapping a domain to a ``{slot: value}`` dict, used as a
                fallback when no entity supplies the slot.
            requestable: container of intents whose values are set to ``'?'``.

        Returns:
            list of ``[intent, domain, slot, value]`` lists. Acts in the
            ``general`` domain are passed through unchanged. A ``book`` act
            takes the selected entity's ``'Ref'`` when available; otherwise its
            placeholder is kept. ``choice`` slots become the number of entities
            in the domain.

        Raises:
            KeyError: if a domain needing lexicalization is missing from
                ``entities`` (or, on the fallback path, from ``state``).
            ValueError: if an index placeholder is not an integer string.
        """
        meta = deepcopy(meta)
        for k, v in meta.items():
            domain, intent = k.split('_')
            if domain == 'general':
                continue
            if intent in requestable:
                for pair in v:
                    pair[1] = '?'
            elif intent == "book":
                # this means we booked something. We retrieve reference number here
                for pair in v:
                    n = int(pair[1]) - 1 if pair[1] != 'none' else 0
                    # Reject negative n: a '0' placeholder must not wrap around
                    # to the last entity or raise IndexError on an empty list.
                    if 0 <= n < len(entities[domain]) and 'Ref' in entities[domain][n]:
                        pair[1] = entities[domain][n]['Ref']
            else:
                for pair in v:
                    if pair[1] == 'none':
                        continue
                    if pair[0].lower() == 'choice':
                        pair[1] = str(len(entities[domain]))
                    else:
                        pair[1] = _lexicalize_slot_value(domain, pair[0], pair[1], entities, state)

        tuples = []
        for domain_intent, svs in meta.items():
            # Split once per key rather than once per slot/value pair.
            domain, intent = domain_intent.split('_')
            for slot, value in svs:
                tuples.append([intent, domain, slot, value])
        return tuples