evaluate.py

    # -*- coding: utf-8 -*-
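    """Evaluate a trained dialogue policy against a simulated user on MultiWOZ.

    Runs ``--num_dialogues`` user-simulator dialogues with the selected system
    policy and logs Complete / Success / Success strict rates, the average
    return, the average number of turns and the average actions per turn.
    """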
    
    import argparse
    import datetime
    import logging
    import os
    
    import numpy as np
    import torch
    from convlab.dialog_agent.agent import PipelineAgent
    from convlab.dialog_agent.session import BiSession
    from convlab.evaluator.multiwoz_eval import MultiWozEvaluator
    from convlab.policy.rule.multiwoz import RulePolicy
    from convlab.task.multiwoz.goal_generator import GoalGenerator
    from convlab.util.custom_util import set_seed, get_config, env_config, create_goals, data_goals
    from tqdm import tqdm
    
    
    def init_logging(log_dir_path, path_suffix=None):
        if not os.path.exists(log_dir_path):
            os.makedirs(log_dir_path)
        # Filesystem-safe timestamp (colons are not valid in Windows file names).
        current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        if path_suffix:
            log_file_path = os.path.join(
                log_dir_path, f"{current_time}_{path_suffix}.log")
        else:
            log_file_path = os.path.join(log_dir_path, f"{current_time}.log")
        print("LOG DIR PATH:", log_dir_path)
        stderr_handler = logging.StreamHandler()
        file_handler = logging.FileHandler(log_file_path)
        format_str = "%(message)s"
        logging.basicConfig(level=logging.DEBUG, handlers=[
                            stderr_handler, file_handler], format=format_str)
    
    
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    
    def evaluate(config_path, model_name, verbose=False, model_path="", goals_from_data=False, dialogues=500):
        seed = 0
        set_seed(seed)
    
        conf = get_config(config_path, [])
    
        # Instantiate the requested system policy; imports are deferred so only
        # the selected model's dependencies need to be installed.
        if model_name == "PPO":
            from convlab.policy.ppo import PPO
            policy_sys = PPO(vectorizer=conf['vectorizer_sys_activated'])
        elif model_name == "RULE":
            policy_sys = RulePolicy()
        elif model_name == "PG":
            from convlab.policy.pg import PG
            policy_sys = PG(vectorizer=conf['vectorizer_sys_activated'])
        elif model_name == "MLE":
            from convlab.policy.mle import MLE
            policy_sys = MLE()
        elif model_name == "GDPL":
            from convlab.policy.gdpl import GDPL
            policy_sys = GDPL(vectorizer=conf['vectorizer_sys_activated'])
        elif model_name == "DDPT":
            from convlab.policy.vtrace_DPT import VTRACE
            policy_sys = VTRACE(is_train=False,
                                vectorizer=conf['vectorizer_sys_activated'])
        else:
            raise ValueError(f"Unknown model_name: {model_name}")
    
        try:
            if model_path:
                policy_sys.load(model_path)
            else:
                policy_sys.load(conf['model']['load_path'])
        except Exception as e:
            logging.warning(f"Could not load a policy: {e}")
    
        env, sess = env_config(conf, policy_sys)
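        # Track each distinct system action set seen, keyed by the number of
        # actions produced in a single turn.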
        action_dict = {}
    
        # Per-dialogue metrics; the running totals feed the final action-rate report.
        total_actions, total_turns = 0.0, 0
        task_success = {'Complete': [], 'Success': [],
                        'Success strict': [], 'total_return': [], 'turns': []}
    
        goal_generator = GoalGenerator()
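        # Either replay goals recorded in the multiwoz21 dataset or sample
        # fresh ones from the goal generator.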
        if goals_from_data:
            logging.info("read goals from dataset...")
            goals = data_goals(dialogues, dataset="multiwoz21", dial_ids_order=0)
        else:
            logging.info("create goals from goal_generator...")
            goals = create_goals(goal_generator, num_goals=dialogues,
                                 single_domains=False, allowed_domains=None)
    
        # One fixed seed per dialogue keeps the evaluation reproducible.
        for seed in tqdm(range(1000, 1000 + dialogues)):
            set_seed(seed)
            sess.init_session(goal=goals[seed - 1000])
            sys_response = []
            actions = 0.0
            total_return = 0.0
            turns = 0
            task_succ = 0
            task_succ_strict = 0
            complete = 0
    
            if verbose:
                logging.info("NEW EPISODE!!!!" + "-" * 80)
                logging.info(f"\n Seed: {seed}")
                logging.info(f"GOAL: {sess.evaluator.goal}")
                logging.info("\n")
            # Interact for at most 40 turns or until the user ends the session.
            for _ in range(40):
                sys_response, user_response, session_over, reward = sess.next_turn(
                    sys_response)
    
                if verbose:
                    logging.info(f"USER RESPONSE: {user_response}")
                    logging.info(f"SYS RESPONSE: {sys_response}")
    
                actions += len(sys_response)
                length = len(sys_response)
                if sys_response not in action_dict.setdefault(length, []):
                    action_dict[length].append(sys_response)
    
                # logging.info(f"Actions in turn: {len(sys_response)}")
                turns += 1
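                # Accumulate the evaluator's per-turn reward into the episode return.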
                total_return += sess.evaluator.get_reward(session_over)
    
                if session_over:
                    # task_success() computes and caches success / success_strict.
                    sess.evaluator.task_success()
                    task_succ = sess.evaluator.success
                    task_succ_strict = sess.evaluator.success_strict
                    if goals_from_data:
                        complete = sess.user_agent.policy.policy.goal.task_complete()
                    else:
                        complete = sess.evaluator.complete
                    break
    
            if verbose:
                logging.info(f"Complete: {complete}")
                logging.info(f"Success: {task_succ}")
                logging.info(f"Success strict: {task_succ_strict}")
                logging.info(f"Return: {total_return}")
                logging.info(f"Average actions: {actions / turns}")
    
            total_actions += actions
            total_turns += turns
            task_success['Complete'].append(complete)
            task_success['Success'].append(task_succ)
            task_success['Success strict'].append(task_succ_strict)
            task_success['total_return'].append(total_return)
            task_success['turns'].append(turns)
    
        for key in task_success:
            values = task_success[key]
            logging.info(
                f'{key} {len(values)} {np.average(values) if values else 0}')
        # Report the action rate over all dialogues, not just the last one.
        logging.info(f"Average actions: {total_actions / max(total_turns, 1)}")
    
    
    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("--model_name", type=str,
                            default="PPO", help="name of model")
        parser.add_argument("-C", "--config_path", type=str,
                            default='', help="config path defining the environment for simulation and system pipeline")
        parser.add_argument("--model_path", type=str,
                            default='', help="if this is set, tries to load the model weights from this path"
                                             ", otherwise from config")
        parser.add_argument("-N", "--num_dialogues", type=int,
                            default=500, help="# of evaluation dialogue")
        parser.add_argument("-V", "--verbose", action='store_true',
                            help="whether to output utterances")
        parser.add_argument("--log_path_suffix", type=str,
                            default="", help="suffix of path of log file")
        parser.add_argument("--log_dir_path", type=str,
                            default="log", help="path of log directory")
        parser.add_argument("-D", "--goals_from_data", action='store_true',
                            help="load goal from the dataset")
    
        args = parser.parse_args()
    
        init_logging(log_dir_path=args.log_dir_path,
                     path_suffix=args.log_path_suffix)
        evaluate(config_path=args.config_path,
                 model_name=args.model_name,
                 verbose=args.verbose,
                 model_path=args.model_path,
                 goals_from_data=args.goals_from_data,
                 dialogues=args.num_dialogues)
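
    # Example invocations (config and checkpoint paths are illustrative):
    #   python evaluate.py --model_name DDPT -C configs/ddpt_env.json -N 100 -V
    #   python evaluate.py --model_name RULE -C configs/rule_env.json -D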