Skip to content
Snippets Groups Projects
Unverified commit 1895814d, authored by Jinchao Li, committed by GitHub
Browse files

add input reqt vals in human eval (#128)

parent d70c0fc6
No related branches found
No related tags found
No related merge requests found
...@@ -86,6 +86,7 @@ APPROPRIATENESS_MSG = 'Now please evaluate the \ ...@@ -86,6 +86,7 @@ APPROPRIATENESS_MSG = 'Now please evaluate the \
be asked to give a reason for the score you choose.</b></span>' be asked to give a reason for the score you choose.</b></span>'
APPROPRIATENESS_REASON_MSG = 'Please give a <b>reason for the appropriateness \ APPROPRIATENESS_REASON_MSG = 'Please give a <b>reason for the appropriateness \
score</b> you gave above. Please try to give concrete examples.' score</b> you gave above. Please try to give concrete examples.'
REQT_MSG = 'Please type the values you obtained: '
import requests import requests
...@@ -201,10 +202,10 @@ class MultiWozEvalWorld(MTurkTaskWorld): ...@@ -201,10 +202,10 @@ class MultiWozEvalWorld(MTurkTaskWorld):
def __init__(self, opt, agent, def __init__(self, opt, agent,
num_extra_trial=2, num_extra_trial=2,
max_turn=50, max_turn=50,
max_resp_time=120, max_resp_time=300,
model_agent_opt=None, model_agent_opt=None,
world_tag='', world_tag='',
agent_timeout_shutdown=120): agent_timeout_shutdown=300):
self.opt = opt self.opt = opt
self.agent = agent self.agent = agent
self.turn_idx = 1 self.turn_idx = 1
...@@ -261,6 +262,8 @@ class MultiWozEvalWorld(MTurkTaskWorld): ...@@ -261,6 +262,8 @@ class MultiWozEvalWorld(MTurkTaskWorld):
self.goal_text += '</ul>' self.goal_text += '</ul>'
print(self.goal_text) print(self.goal_text)
print(self.goal)
self.final_goal = deepcopy(self.goal)
self.state = deepcopy(self.goal) self.state = deepcopy(self.goal)
def _track_state(self, inp): def _track_state(self, inp):
...@@ -403,6 +406,23 @@ class MultiWozEvalWorld(MTurkTaskWorld): ...@@ -403,6 +406,23 @@ class MultiWozEvalWorld(MTurkTaskWorld):
if 'text' in acts[idx] and \ if 'text' in acts[idx] and \
acts[idx]['text'] != '': acts[idx]['text'] != '':
self.fail_reason = acts[idx]['text'] self.fail_reason = acts[idx]['text']
else:
# reqt message
for domain in self.goal:
if 'reqt' in self.goal[domain]:
self.final_goal[domain]['reqt'] = dict()
for slot in self.goal[domain]['reqt']:
control_msg['text'] = REQT_MSG + '<b>' + domain + '-' + slot + '</b>'
agent.observe(validate(control_msg))
acts[idx] = agent.act(timeout=self.max_resp_time)
while acts[idx]['text'] == '':
control_msg['text'] = 'Please try again.'
agent.observe(validate(control_msg))
acts[idx] = agent.act(timeout=self.max_resp_time)
if 'text' in acts[idx] and \
acts[idx]['text'] != '':
self.final_goal[domain]['reqt'][slot] = acts[idx]['text']
# print(self.final_goal)
# Language Understanding Check # Language Understanding Check
control_msg['text'] = UNDERSTANDING_MSG control_msg['text'] = UNDERSTANDING_MSG
...@@ -553,6 +573,7 @@ class MultiWozEvalWorld(MTurkTaskWorld): ...@@ -553,6 +573,7 @@ class MultiWozEvalWorld(MTurkTaskWorld):
) )
) )
result = {'goal': self.goal, result = {'goal': self.goal,
'final_goal': self.final_goal,
'goal_text': self.goal_text, 'goal_text': self.goal_text,
'dialog': self.dialog, 'dialog': self.dialog,
'workers': self.agent.worker_id, 'workers': self.agent.worker_id,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment