From 4f9d5759e5de5c5698321d161564f3f0fc237839 Mon Sep 17 00:00:00 2001 From: Ryuichi Takanobu <truthless11@gmail.com> Date: Thu, 15 Oct 2020 21:24:04 +0800 Subject: [PATCH] remove unnecessary mapping (#147) --- convlab2/policy/vector/vector_multiwoz.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/convlab2/policy/vector/vector_multiwoz.py b/convlab2/policy/vector/vector_multiwoz.py index 64841ef..7ef0c8a 100755 --- a/convlab2/policy/vector/vector_multiwoz.py +++ b/convlab2/policy/vector/vector_multiwoz.py @@ -5,21 +5,8 @@ import numpy as np from convlab2.policy.vec import Vector from convlab2.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da from convlab2.util.multiwoz.state import default_state -from convlab2.util.multiwoz.multiwoz_slot_trans import REF_USR_DA from convlab2.util.multiwoz.dbquery import Database -mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'name': 'name', 'phone': 'phone', - 'post': 'postcode', 'price': 'pricerange'}, - 'hotel': {'addr': 'address', 'area': 'area', 'internet': 'internet', 'parking': 'parking', 'name': 'name', - 'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 'type': 'type'}, - 'attraction': {'addr': 'address', 'area': 'area', 'fee': 'entrance fee', 'name': 'name', 'phone': 'phone', - 'post': 'postcode', 'type': 'type'}, - 'train': {'id': 'trainID', 'arrive': 'arriveBy', 'day': 'day', 'depart': 'departure', 'dest': 'destination', - 'time': 'duration', 'leave': 'leaveAt', 'ticket': 'price'}, - 'taxi': {'car': 'car type', 'phone': 'phone'}, - 'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department'}, - 'police': {'post': 'postcode', 'phone': 'phone', 'addr': 'address'}} - DEFAULT_INTENT_FILEPATH = os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))), 'data/multiwoz/trackable_intent.json' @@ -93,10 +80,7 @@ class MultiWozVector(Vector): def pointer(self, turn): pointer_vector = np.zeros(6 * len(self.db_domains)) for domain in self.db_domains: - constraint = [] - for k, v in turn[domain.lower()]['semi'].items(): - if k in mapping[domain.lower()]: - constraint.append((mapping[domain.lower()][k], v)) + constraint = turn[domain.lower()]['semi'].items() entities = self.db.query(domain.lower(), constraint) pointer_vector = self.one_hot_vector(len(entities), domain, pointer_vector) @@ -205,10 +189,7 @@ class MultiWozVector(Vector): entities list: list of entities of the specified domain """ - constraint = [] - for k, v in self.state[domain.lower()]['semi'].items(): - if k in mapping[domain.lower()]: - constraint.append((mapping[domain.lower()][k], v)) + constraint = self.state[domain.lower()]['semi'].items() return self.db.query(domain.lower(), constraint) def action_devectorize(self, action_vec): -- GitLab