Skip to content
Snippets Groups Projects
Unverified Commit 4f9d5759 authored by Ryuichi Takanobu's avatar Ryuichi Takanobu Committed by GitHub
Browse files

remove unnecessary mapping (#147)

parent 222280e5
Branches
No related tags found
No related merge requests found
......@@ -5,21 +5,8 @@ import numpy as np
from convlab2.policy.vec import Vector
from convlab2.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da
from convlab2.util.multiwoz.state import default_state
from convlab2.util.multiwoz.multiwoz_slot_trans import REF_USR_DA
from convlab2.util.multiwoz.dbquery import Database
mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'name': 'name', 'phone': 'phone',
'post': 'postcode', 'price': 'pricerange'},
'hotel': {'addr': 'address', 'area': 'area', 'internet': 'internet', 'parking': 'parking', 'name': 'name',
'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 'type': 'type'},
'attraction': {'addr': 'address', 'area': 'area', 'fee': 'entrance fee', 'name': 'name', 'phone': 'phone',
'post': 'postcode', 'type': 'type'},
'train': {'id': 'trainID', 'arrive': 'arriveBy', 'day': 'day', 'depart': 'departure', 'dest': 'destination',
'time': 'duration', 'leave': 'leaveAt', 'ticket': 'price'},
'taxi': {'car': 'car type', 'phone': 'phone'},
'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department'},
'police': {'post': 'postcode', 'phone': 'phone', 'addr': 'address'}}
DEFAULT_INTENT_FILEPATH = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))),
'data/multiwoz/trackable_intent.json'
......@@ -93,10 +80,7 @@ class MultiWozVector(Vector):
def pointer(self, turn):
pointer_vector = np.zeros(6 * len(self.db_domains))
for domain in self.db_domains:
constraint = []
for k, v in turn[domain.lower()]['semi'].items():
if k in mapping[domain.lower()]:
constraint.append((mapping[domain.lower()][k], v))
constraint = turn[domain.lower()]['semi'].items()
entities = self.db.query(domain.lower(), constraint)
pointer_vector = self.one_hot_vector(len(entities), domain, pointer_vector)
......@@ -205,10 +189,7 @@ class MultiWozVector(Vector):
entities list:
list of entities of the specified domain
"""
constraint = []
for k, v in self.state[domain.lower()]['semi'].items():
if k in mapping[domain.lower()]:
constraint.append((mapping[domain.lower()][k], v))
constraint = self.state[domain.lower()]['semi'].items()
return self.db.query(domain.lower(), constraint)
def action_devectorize(self, action_vec):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment