diff --git a/convlab2/policy/vector/vector_multiwoz.py b/convlab2/policy/vector/vector_multiwoz.py index 64841ef5e7a1fd54a5f60b3cce11c646475b8e51..7ef0c8a2e3a32f181ea98f10834b886c014dc72a 100755 --- a/convlab2/policy/vector/vector_multiwoz.py +++ b/convlab2/policy/vector/vector_multiwoz.py @@ -5,21 +5,8 @@ import numpy as np from convlab2.policy.vec import Vector from convlab2.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da from convlab2.util.multiwoz.state import default_state -from convlab2.util.multiwoz.multiwoz_slot_trans import REF_USR_DA from convlab2.util.multiwoz.dbquery import Database -mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'name': 'name', 'phone': 'phone', - 'post': 'postcode', 'price': 'pricerange'}, - 'hotel': {'addr': 'address', 'area': 'area', 'internet': 'internet', 'parking': 'parking', 'name': 'name', - 'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 'type': 'type'}, - 'attraction': {'addr': 'address', 'area': 'area', 'fee': 'entrance fee', 'name': 'name', 'phone': 'phone', - 'post': 'postcode', 'type': 'type'}, - 'train': {'id': 'trainID', 'arrive': 'arriveBy', 'day': 'day', 'depart': 'departure', 'dest': 'destination', - 'time': 'duration', 'leave': 'leaveAt', 'ticket': 'price'}, - 'taxi': {'car': 'car type', 'phone': 'phone'}, - 'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department'}, - 'police': {'post': 'postcode', 'phone': 'phone', 'addr': 'address'}} - DEFAULT_INTENT_FILEPATH = os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))), 'data/multiwoz/trackable_intent.json' @@ -93,10 +80,7 @@ class MultiWozVector(Vector): def pointer(self, turn): pointer_vector = np.zeros(6 * len(self.db_domains)) for domain in self.db_domains: - constraint = [] - for k, v in turn[domain.lower()]['semi'].items(): - if k in mapping[domain.lower()]: - constraint.append((mapping[domain.lower()][k], v)) + constraint = turn[domain.lower()]['semi'].items() entities = self.db.query(domain.lower(), constraint) pointer_vector = self.one_hot_vector(len(entities), domain, pointer_vector) @@ -205,10 +189,7 @@ class MultiWozVector(Vector): entities list: list of entities of the specified domain """ - constraint = [] - for k, v in self.state[domain.lower()]['semi'].items(): - if k in mapping[domain.lower()]: - constraint.append((mapping[domain.lower()][k], v)) + constraint = self.state[domain.lower()]['semi'].items() return self.db.query(domain.lower(), constraint) def action_devectorize(self, action_vec):