Skip to content
Snippets Groups Projects
Commit 4c90401d authored by truthless11's avatar truthless11
Browse files

remove unnecessary mapping

parent 8603ecbc
No related branches found
No related tags found
No related merge requests found
...@@ -5,21 +5,8 @@ import numpy as np ...@@ -5,21 +5,8 @@ import numpy as np
from convlab2.policy.vec import Vector from convlab2.policy.vec import Vector
from convlab2.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da from convlab2.util.multiwoz.lexicalize import delexicalize_da, flat_da, deflat_da, lexicalize_da
from convlab2.util.multiwoz.state import default_state from convlab2.util.multiwoz.state import default_state
from convlab2.util.multiwoz.multiwoz_slot_trans import REF_USR_DA
from convlab2.util.multiwoz.dbquery import Database from convlab2.util.multiwoz.dbquery import Database
mapping = {'restaurant': {'addr': 'address', 'area': 'area', 'food': 'food', 'name': 'name', 'phone': 'phone',
'post': 'postcode', 'price': 'pricerange'},
'hotel': {'addr': 'address', 'area': 'area', 'internet': 'internet', 'parking': 'parking', 'name': 'name',
'phone': 'phone', 'post': 'postcode', 'price': 'pricerange', 'stars': 'stars', 'type': 'type'},
'attraction': {'addr': 'address', 'area': 'area', 'fee': 'entrance fee', 'name': 'name', 'phone': 'phone',
'post': 'postcode', 'type': 'type'},
'train': {'id': 'trainID', 'arrive': 'arriveBy', 'day': 'day', 'depart': 'departure', 'dest': 'destination',
'time': 'duration', 'leave': 'leaveAt', 'ticket': 'price'},
'taxi': {'car': 'car type', 'phone': 'phone'},
'hospital': {'post': 'postcode', 'phone': 'phone', 'addr': 'address', 'department': 'department'},
'police': {'post': 'postcode', 'phone': 'phone', 'addr': 'address'}}
DEFAULT_INTENT_FILEPATH = os.path.join( DEFAULT_INTENT_FILEPATH = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))), os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))),
'data/multiwoz/trackable_intent.json' 'data/multiwoz/trackable_intent.json'
...@@ -93,10 +80,7 @@ class MultiWozVector(Vector): ...@@ -93,10 +80,7 @@ class MultiWozVector(Vector):
def pointer(self, turn): def pointer(self, turn):
pointer_vector = np.zeros(6 * len(self.db_domains)) pointer_vector = np.zeros(6 * len(self.db_domains))
for domain in self.db_domains: for domain in self.db_domains:
constraint = [] constraint = turn[domain.lower()]['semi'].items()
for k, v in turn[domain.lower()]['semi'].items():
if k in mapping[domain.lower()]:
constraint.append((mapping[domain.lower()][k], v))
entities = self.db.query(domain.lower(), constraint) entities = self.db.query(domain.lower(), constraint)
pointer_vector = self.one_hot_vector(len(entities), domain, pointer_vector) pointer_vector = self.one_hot_vector(len(entities), domain, pointer_vector)
...@@ -205,10 +189,7 @@ class MultiWozVector(Vector): ...@@ -205,10 +189,7 @@ class MultiWozVector(Vector):
entities list: entities list:
list of entities of the specified domain list of entities of the specified domain
""" """
constraint = [] constraint = self.state[domain.lower()]['semi'].items()
for k, v in self.state[domain.lower()]['semi'].items():
if k in mapping[domain.lower()]:
constraint.append((mapping[domain.lower()][k], v))
return self.db.query(domain.lower(), constraint) return self.db.query(domain.lower(), constraint)
def action_devectorize(self, action_vec): def action_devectorize(self, action_vec):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment