Skip to content
Snippets Groups Projects
Commit 46939fb3 authored by Christian's avatar Christian
Browse files

added model links to readme

parent a719c81e
No related branches found
No related tags found
No related merge requests found
......@@ -43,7 +43,7 @@ The necessary step before starting a training is to set up the environment and p
```
{
"model": {
"load_path": "", # specify a loading path to load a pre-trained model
"load_path": "", # specify a loading path to load a pre-trained model, omit the ending .pol.mdl
"use_pretrained_initialisation": false, # will download a provided ConvLab-3 model
"pretrained_load_path": "",
"seed": 0, # the seed for the experiment
......
......@@ -14,6 +14,9 @@ The dataset name can be "multiwoz21" or "sgd" for instance. The first time you r
Other hyperparameters such as learning rate or number of epochs can be set in the config.json file.
We provide a model trained on multiwoz21 on hugging-face: https://huggingface.co/ConvLab/mle-policy-multiwoz21
## Evaluation
Evaluation on the validation data set takes place during training.
\ No newline at end of file
# -*- coding: utf-8 -*-
import sys
import numpy as np
import logging
from convlab.util.multiwoz.lexicalize import delexicalize_da, flat_da
from .vector_base import VectorBase
......@@ -8,9 +10,11 @@ from .vector_base import VectorBase
class VectorNodes(VectorBase):
def __init__(self, dataset_name='multiwoz21', character='sys', use_masking=False, manually_add_entity_names=True,
seed=0):
seed=0, filter_state=True):
super().__init__(dataset_name, character, use_masking, manually_add_entity_names, seed)
self.filter_state = filter_state
logging.info(f"We filter state by active domains: {self.filter_state}")
def get_state_dim(self):
self.belief_state_dim = 0
......@@ -56,9 +60,16 @@ class VectorNodes(VectorBase):
self.get_user_act_feature(state)
self.get_sys_act_feature(state)
domain_active_dict = self.get_user_goal_feature(state, domain_active_dict)
number_entities_dict = self.get_db_features()
self.get_general_features(state, domain_active_dict)
if self.db is not None:
number_entities_dict = self.get_db_features()
else:
number_entities_dict = None
if self.filter_state:
self.kg_info = self.filter_inactive_domains(domain_active_dict)
if self.use_mask:
mask = self.get_mask(domain_active_dict, number_entities_dict)
for i in range(self.da_dim):
......@@ -89,6 +100,8 @@ class VectorNodes(VectorBase):
feature_type = 'user goal'
for domain in self.belief_domains:
# the if case is needed because SGD only saves the dialogue state info for active domains
if domain in state['belief_state']:
for slot, value in state['belief_state'][domain].items():
description = f"user goal-{domain}-{slot}".lower()
value = 1.0 if (value and value != "not mentioned") else 0.0
......@@ -128,6 +141,7 @@ class VectorNodes(VectorBase):
def get_general_features(self, state, domain_active_dict):
feature_type = 'general'
if 'booked' in state:
for i, domain in enumerate(self.db_domains):
if domain in state['booked']:
description = f"general-{domain}-booked".lower()
......@@ -140,3 +154,17 @@ class VectorNodes(VectorBase):
value = 1.0 if domain_active_dict[domain] else 0
description = f"general-{domain}".lower()
self.add_graph_node(domain, feature_type, description, value)
def filter_inactive_domains(self, domain_active_dict):
    """Return the nodes of ``self.kg_info`` that do not belong to an inactive domain.

    Each node is expected to be a mapping with a ``'domain'`` key. A node is
    kept when its domain is flagged truthy in ``domain_active_dict`` or when
    its domain does not appear in ``domain_active_dict`` at all; it is
    dropped only when its domain is present with a falsy flag.

    Args:
        domain_active_dict: mapping from domain name to an "is active" flag.

    Returns:
        A new list with the retained nodes, in their original order.
        ``self.kg_info`` itself is not modified by this method.

    Raises:
        KeyError: if a node has no ``'domain'`` entry.
    """
    # Domains missing from the mapping default to True so that their nodes
    # are retained rather than filtered out.
    return [
        node for node in self.kg_info
        if domain_active_dict.get(node['domain'], True)
    ]
......@@ -20,7 +20,11 @@ You can specify the dataset that you would like to use, e.g. "multiwoz21" or "sg
You can specify hyperparameters such as epoch, supervised_lr and data_percentage (how much of the data you want to use) in the config.json file.
We provide several models trained with supervised learning on Hugging Face to reproduce the results:
- pre-trained on SGD: https://huggingface.co/ConvLab/ddpt-policy-sgd
- pre-trained on 1% multiwoz21: https://huggingface.co/ConvLab/ddpt-policy-0.01multiwoz21
- pre-trained on SGD and afterwards on 1% multiwoz21: https://huggingface.co/ConvLab/ddpt-policy-sgd_0.01multiwoz21
## RL training
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment