# -*- coding: utf-8 -*-
"""
Created on Wed Jul 17 14:27:34 2019

@author: truthless
"""
from copy import deepcopy


class Environment:
    """Dialogue environment connecting a system-side pipeline to a user simulator.

    Exposes a ``reset()`` / ``step(action)`` interface in the style of RL
    environments. On each step the system ``action`` is optionally rendered by
    an NLG module, answered by the user simulator, optionally parsed by an NLU
    module, and folded into the system dialogue state tracker (DST).

    Args:
        sys_nlg: System NLG providing ``generate(action)``, or ``None`` to
            pass actions to the user simulator unchanged.
        usr: User simulator; this class calls its ``init_session``,
            ``response``, ``get_in_da``, ``get_out_da``, ``get_reward``,
            ``is_terminated`` and ``policy.get_goal()``.
        sys_nlu: System NLU providing ``predict(observation)``, or ``None`` to
            feed the user's reply to the DST unchanged.
        sys_dst: System DST providing ``init_session``, ``update`` and a
            ``state`` attribute that supports item assignment.
        evaluator: Optional evaluator. When given, it is sent the session goal
            and each turn's dialog acts, and it drives the reward instead of
            ``usr.get_reward()``.
    """

    # Reward schedule used only when an evaluator is attached (see step()).
    # Subclasses may override these to reshape the reward.
    SUCCESS_REWARD = 40        # evaluator reports the whole task succeeded
    DOMAIN_SUCCESS_REWARD = 5  # task not yet done, but current domain succeeded
    TURN_PENALTY = -1          # any other turn

    def __init__(self, sys_nlg, usr, sys_nlu, sys_dst, evaluator=None):
        self.sys_nlg = sys_nlg
        self.usr = usr
        self.sys_nlu = sys_nlu
        self.sys_dst = sys_dst
        self.evaluator = evaluator

    def reset(self):
        """Start a new dialogue session and return the initial DST state.

        Re-initializes the user simulator and the system DST. If an evaluator
        is attached, it is given the user policy's goal.

        Returns:
            A deep copy of ``sys_dst.state`` after initialization.
        """
        self.usr.init_session()
        self.sys_dst.init_session()
        if self.evaluator is not None:
            self.evaluator.add_goal(self.usr.policy.get_goal())
        # step() writes into sys_dst.state in place, so hand out a snapshot;
        # otherwise the caller's copy would change on the next turn.
        # NOTE(review): assumes the DST state is deep-copyable (plain
        # dicts/lists) -- confirm for custom trackers.
        return deepcopy(self.sys_dst.state)

    def step(self, action):
        """Advance the dialogue by one system turn.

        Args:
            action: System output. It is rendered with ``sys_nlg.generate``
                when an NLG is configured; otherwise it is sent to the user
                simulator as-is.

        Returns:
            tuple: ``(state, reward, terminated)``.
                ``state`` is a deep copy of the DST's ``update`` result after
                incorporating the user's reply. ``reward`` comes from
                ``_compute_reward()``. ``terminated`` is
                ``usr.is_terminated()``.
        """
        if self.sys_nlg is not None:
            model_response = self.sys_nlg.generate(action)
        else:
            model_response = action
        observation = self.usr.response(model_response)

        if self.evaluator is not None:
            # Dialog acts as reported by the user simulator for this exchange.
            self.evaluator.add_sys_da(self.usr.get_in_da())
            self.evaluator.add_usr_da(self.usr.get_out_da())

        if self.sys_nlu is not None:
            dialog_act = self.sys_nlu.predict(observation)
        else:
            dialog_act = observation
        self.sys_dst.state['user_action'] = dialog_act
        # Snapshot the result: the DST keeps mutating its own state object on
        # later turns, which would corrupt states the caller already stored.
        state = deepcopy(self.sys_dst.update(dialog_act))

        reward = self._compute_reward()
        terminated = self.usr.is_terminated()
        return state, reward, terminated

    def _compute_reward(self):
        """Return this turn's reward.

        Without an evaluator, this is ``usr.get_reward()``. With an evaluator:
        ``SUCCESS_REWARD`` if ``task_success()``, else
        ``DOMAIN_SUCCESS_REWARD`` if ``cur_domain`` is set and
        ``domain_success(cur_domain)`` holds, else ``TURN_PENALTY``.
        """
        evaluator = self.evaluator
        if evaluator is None:
            return self.usr.get_reward()
        if evaluator.task_success():
            return self.SUCCESS_REWARD
        if evaluator.cur_domain and evaluator.domain_success(evaluator.cur_domain):
            return self.DOMAIN_SUCCESS_REWARD
        return self.TURN_PENALTY