Skip to content
Snippets Groups Projects
Commit 09efd18a authored by Christian's avatar Christian
Browse files

first version where SGD data set is running, now have to test whether it learns properly

parent 2540630d
No related branches found
No related tags found
No related merge requests found
Showing
with 47 additions and 19 deletions
......@@ -58,7 +58,7 @@ def evaluate(config_path, model_name, verbose=False, model_path="", goals_from_d
from convlab.policy.gdpl import GDPL
policy_sys = GDPL(vectorizer=conf['vectorizer_sys_activated'])
elif model_name == "DDPT":
from convlab2.policy.vtrace_DPT import VTRACE
from convlab.policy.vtrace_DPT import VTRACE
policy_sys = VTRACE(is_train=False, vectorizer=conf['vectorizer_sys_activated'])
try:
......
......@@ -100,6 +100,8 @@ class VectorNodes(VectorBase):
feature_type = 'user goal'
for domain in self.belief_domains:
# the if case is needed because SGD only saves the dialogue state info for active domains
if domain in state['belief_state']:
for slot, value in state['belief_state'][domain].items():
description = f"user goal-{domain}-{slot}".lower()
value = 1.0 if (value and value != "not mentioned") else 0.0
......
......@@ -11,6 +11,7 @@ def create_description_dicts(name='multiwoz21'):
ontology = load_ontology(name)
default_state = ontology['state']
domains = list(ontology['domains'].keys())
try:
db = load_database(name)
db_domains = db.domains
......@@ -31,22 +32,27 @@ def create_description_dicts(name='multiwoz21'):
for domain in default_state:
for slot in default_state[domain]:
domain = domain.lower()
description_dict_semantic[f"user goal-{domain}-{slot.lower()}"] = f"user goal {domain} {slot}"
if db_domains:
for domain in db_domains:
domain = domain.lower()
description_dict_semantic[f"db-{domain}-entities"] = f"data base {domain} number of entities"
description_dict_semantic[f"general-{domain}-booked"] = f"general {domain} booked"
for domain in domains:
domain = domain.lower()
description_dict_semantic[f"general-{domain}"] = f"domain {domain}"
for act in da_voc:
domain, intent, slot, value = act.split("-")
domain = domain.lower()
description_dict_semantic["system-"+act.lower()] = f"last system act {domain} {intent} {slot} {value}"
for act in da_voc_opp:
domain, intent, slot, value = [item.lower() for item in act.split("-")]
domain = domain.lower()
description_dict_semantic["user-"+act.lower()] = f"user act {domain} {intent} {slot} {value}"
root_dir = os.path.dirname(os.path.abspath(__file__))
......
Source diff could not be displayed: it is too large. Options to address this: view the blob.
......@@ -43,6 +43,8 @@ class PolicyDataVectorizer:
self.data[split] = []
raw_data = data_split[split]
num = 0
for data_point in raw_data:
state = default_state()
......@@ -52,6 +54,7 @@ class PolicyDataVectorizer:
if len(data_point['context']) > 1 else {}
state['system_action'] = flatten_acts(last_system_act)
state['terminated'] = data_point['terminated']
if 'booked' in data_point:
state['booked'] = data_point['booked']
dialogue_act = flatten_acts(data_point['dialogue_acts'])
......@@ -59,6 +62,9 @@ class PolicyDataVectorizer:
vectorized_action = self.vector.action_vectorize(dialogue_act)
self.data[split].append({"state": self.vector.kg_info, "action": vectorized_action, "mask": mask,
"terminated": state['terminated']})
num += 1
if num > 500:
break
with open(os.path.join(processed_dir, '{}.pkl'.format(split)), 'wb') as f:
pickle.dump(self.data[split], f)
......
......@@ -61,6 +61,9 @@ class MLE_Trainer:
loss_a = self.policy_loop(data)
a_loss += loss_a.item()
loss_a.backward()
for p in self.policy.parameters():
if p.grad is not None:
p.grad[p.grad != p.grad] = 0.0
self.policy_optim.step()
self.policy.eval()
......@@ -162,12 +165,17 @@ if __name__ == '__main__':
with open(os.path.join(root_directory, 'config.json'), 'r') as f:
cfg = json.load(f)
cfg['dataset_name'] = args.dataset_name
logger, tb_writer, current_time, save_path, config_save_path, dir_path, log_save_path = \
init_logging(os.path.dirname(os.path.abspath(__file__)), "info")
save_config(vars(args), cfg, config_save_path)
set_seed(args.seed)
logging.info(f"Seed used: {args.seed}")
logging.info(f"Batch size: {cfg['batchsz']}")
logging.info(f"Epochs: {cfg['epoch']}")
logging.info(f"Learning rate: {cfg['supervised_lr']}")
vector = VectorNodes(dataset_name=args.dataset_name, use_masking=False, filter_state=True)
manager = PolicyDataVectorizer(dataset_name=args.dataset_name, vector=vector)
......
......@@ -343,11 +343,11 @@ class EncoderDecoder(nn.Module):
# Map the actions to action embeddings that are fed as input to decoder model
# pad input and remove "eos" token
padded_decoder_input = torch.stack(
[torch.cat([act[:-1], torch.zeros(max_length - len(act))], dim=-1) for act in action_targets], dim=0) \
[torch.cat([act[:-1], torch.zeros(max_length - len(act)).to(DEVICE)], dim=-1) for act in action_targets], dim=0) \
.to(DEVICE).long()
padded_action_targets = torch.stack(
[torch.cat([act, torch.zeros(max_length - len(act))], dim=-1) for act in action_targets], dim=0) \
[torch.cat([act, torch.zeros(max_length - len(act)).to(DEVICE)], dim=-1) for act in action_targets], dim=0) \
.to(DEVICE)
decoder_input = self.action_embedder.action_embeddings[padded_decoder_input]
......
......@@ -40,6 +40,7 @@ class ActionEmbedder(nn.Module):
self.action_embeddings = nn.Parameter(self.action_embeddings)
else:
logging.info("We use Roberta to embed actions.")
self.dataset_name = node_embedder.dataset_name
self.create_action_embeddings_roberta(node_embedder)
self.action_embeddings.requires_grad = False
embedding_dim = 768
......@@ -102,15 +103,15 @@ class ActionEmbedder(nn.Module):
return action_mask.to(DEVICE)
def get_action_mask(self, domain="", intent="", start=False):
def get_action_mask(self, domain=None, intent="", start=False):
action_mask = torch.ones(len(self.small_action_dict))
# This is for predicting end of sequence token <eos>
if not start and not domain:
if not start and domain is None:
action_mask[self.small_action_dict['eos']] = 0
if not domain:
if domain is None:
#TODO: I allow all domains now for checking supervised training
for domain in self.domain_dict:
if domain not in self.forbidden_domains:
......@@ -136,6 +137,8 @@ class ActionEmbedder(nn.Module):
if valid:
action_mask[self.small_action_dict[slot_value]] = 0
assert not torch.equal(action_mask, torch.ones(len(self.small_action_dict)))
return action_mask.to(DEVICE)
def get_current_domain_mask(self, current_domains, current=True):
......@@ -154,7 +157,7 @@ class ActionEmbedder(nn.Module):
def is_valid(self, part_action):
for act in self.action_dict:
if part_action in act:
if act.startswith(part_action):
return True
return False
......@@ -202,8 +205,10 @@ class ActionEmbedder(nn.Module):
action_embeddings.append("pad") #add the PAD token
small_action_dict['pad'] = len(small_action_dict)
action_embeddings_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'action_embeddings.pt')
small_action_dict_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'small_action_dict.json')
action_embeddings_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
f'action_embeddings_{self.dataset_name}.pt')
small_action_dict_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
f'small_action_dict_{self.dataset_name}.json')
if os.path.exists(action_embeddings_path):
self.action_embeddings = torch.load(action_embeddings_path).to(DEVICE)
......
File deleted
File added
File deleted
File added
......@@ -37,7 +37,7 @@ class NodeEmbedderRoberta(nn.Module):
if roberta_path:
embedded_descriptions_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'embedded_descriptions.pt')
f'embedded_descriptions_{self.dataset_name}.pt')
if os.path.exists(embedded_descriptions_path):
self.embedded_descriptions = torch.load(embedded_descriptions_path).to(DEVICE)
else:
......@@ -47,7 +47,7 @@ class NodeEmbedderRoberta(nn.Module):
else:
embedded_descriptions_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'embedded_descriptions_base.pt')
f'embedded_descriptions_base_{self.dataset_name}.pt')
if os.path.exists(embedded_descriptions_path):
self.embedded_descriptions = torch.load(embedded_descriptions_path).to(DEVICE)
else:
......
{"": 0, "alarm_1": 1, "banks_1": 2, "banks_2": 3, "buses_1": 4, "buses_2": 5, "buses_3": 6, "calendar_1": 7, "events_1": 8, "events_2": 9, "events_3": 10, "flights_1": 11, "flights_2": 12, "flights_3": 13, "flights_4": 14, "homes_1": 15, "homes_2": 16, "hotels_1": 17, "hotels_2": 18, "hotels_3": 19, "hotels_4": 20, "media_1": 21, "media_2": 22, "media_3": 23, "messaging_1": 24, "movies_1": 25, "movies_2": 26, "movies_3": 27, "music_1": 28, "music_2": 29, "music_3": 30, "payment_1": 31, "rentalcars_1": 32, "rentalcars_2": 33, "rentalcars_3": 34, "restaurants_1": 35, "restaurants_2": 36, "ridesharing_1": 37, "ridesharing_2": 38, "services_1": 39, "services_2": 40, "services_3": 41, "services_4": 42, "trains_1": 43, "travel_1": 44, "weather_1": 45, "eos": 46, "goodbye": 47, "req_more": 48, "confirm": 49, "inform_count": 50, "notify_success": 51, "offer": 52, "offer_intent": 53, "request": 54, "inform": 55, "notify_failure": 56, "none-none": 57, "new_alarm_name-1": 58, "new_alarm_time-1": 59, "count-1": 60, "alarm_name-1": 61, "alarm_time-1": 62, "addalarm-1": 63, "new_alarm_time-?": 64, "account_type-1": 65, "amount-1": 66, "recipient_account_name-1": 67, "recipient_account_type-1": 68, "balance-1": 69, "transfermoney-1": 70, "account_type-?": 71, "amount-?": 72, "recipient_account_name-?": 73, "recipient_name-1": 74, "transfer_amount-1": 75, "transfer_time-1": 76, "account_balance-1": 77, "recipient_name-?": 78, "transfer_amount-?": 79, "from_location-1": 80, "leaving_date-1": 81, "leaving_time-1": 82, "to_location-1": 83, "travelers-1": 84, "from_station-1": 85, "to_station-1": 86, "transfers-1": 87, "fare-1": 88, "buybusticket-1": 89, "from_location-?": 90, "leaving_date-?": 91, "leaving_time-?": 92, "to_location-?": 93, "travelers-?": 94, "departure_date-1": 95, "departure_time-1": 96, "destination-1": 97, "fare_type-1": 98, "group_size-1": 99, "origin-1": 100, "destination_station_name-1": 101, "origin_station_name-1": 102, "price-1": 103, "departure_date-?": 104, "departure_time-?": 105, "destination-?": 106, "group_size-?": 107, "origin-?": 108, "additional_luggage-1": 109, "from_city-1": 110, "num_passengers-1": 111, "to_city-1": 112, "category-1": 113, "from_city-?": 114, "num_passengers-?": 115, "to_city-?": 116, "event_date-1": 117, "event_location-1": 118, "event_name-1": 119, "event_time-1": 120, "available_end_time-1": 121, "available_start_time-1": 122, "addevent-1": 123, "event_date-?": 124, "event_location-?": 125, "event_name-?": 126, "event_time-?": 127, "city_of_event-1": 128, "date-1": 129, "number_of_seats-1": 130, "address_of_location-1": 131, "subcategory-1": 132, "time-1": 133, "buyeventtickets-1": 134, "category-?": 135, "city_of_event-?": 136, "date-?": 137, "number_of_seats-?": 138, "city-1": 139, "number_of_tickets-1": 140, "venue-1": 141, "venue_address-1": 142, "city-?": 143, "event_type-?": 144, "number_of_tickets-?": 145, "price_per_ticket-1": 146, "airlines-1": 147, "destination_city-1": 148, "inbound_departure_time-1": 149, "origin_city-1": 150, "outbound_departure_time-1": 151, "passengers-1": 152, "return_date-1": 153, "seating_class-1": 154, "destination_airport-1": 155, "inbound_arrival_time-1": 156, "number_stops-1": 157, "origin_airport-1": 158, "outbound_arrival_time-1": 159, "refundable-1": 160, "reserveonewayflight-1": 161, "reserveroundtripflights-1": 162, "airlines-?": 163, "destination_city-?": 164, "inbound_departure_time-?": 165, "origin_city-?": 166, "outbound_departure_time-?": 167, "return_date-?": 168, "is_redeye-1": 169, "arrives_next_day-1": 170, "destination_airport_name-1": 171, "origin_airport_name-1": 172, "is_nonstop-1": 173, "destination_airport-?": 174, "origin_airport-?": 175, "property_name-1": 176, "visit_date-1": 177, "furnished-1": 178, "pets_allowed-1": 179, "phone_number-1": 180, "address-1": 181, "number_of_baths-1": 182, "number_of_beds-1": 183, "rent-1": 184, "schedulevisit-1": 185, "area-?": 186, "number_of_beds-?": 187, "visit_date-?": 188, "has_garage-1": 189, "in_unit_laundry-1": 190, "intent-?": 191, "number_of_baths-?": 192, "check_in_date-1": 193, "hotel_name-1": 194, "number_of_days-1": 195, "number_of_rooms-1": 196, "has_wifi-1": 197, "price_per_night-1": 198, "street_address-1": 199, "star_rating-1": 200, "reservehotel-1": 201, "check_in_date-?": 202, "hotel_name-?": 203, "number_of_days-?": 204, "check_out_date-1": 205, "number_of_adults-1": 206, "where_to-1": 207, "has_laundry_service-1": 208, "total_price-1": 209, "rating-1": 210, "bookhouse-1": 211, "check_out_date-?": 212, "number_of_adults-?": 213, "where_to-?": 214, "location-1": 215, "pets_welcome-1": 216, "average_rating-1": 217, "location-?": 218, "place_name-1": 219, "stay_length-1": 220, "smoking_allowed-1": 221, "stay_length-?": 222, "subtitles-1": 223, "title-1": 224, "directed_by-1": 225, "genre-1": 226, "title-2": 227, "title-3": 228, "playmovie-1": 229, "genre-?": 230, "title-?": 231, "movie_name-1": 232, "subtitle_language-1": 233, "movie_name-2": 234, "movie_name-3": 235, "rentmovie-1": 236, "starring-1": 237, "contact_name-1": 238, "contact_name-?": 239, "show_date-1": 240, "show_time-1": 241, "show_type-1": 242, "theater_name-1": 243, "buymovietickets-1": 244, "movie_name-?": 245, "show_date-?": 246, "show_time-?": 247, "show_type-?": 248, "aggregate_rating-1": 249, "cast-1": 250, "movie_title-1": 251, "percent_rating-1": 252, "playback_device-1": 253, "song_name-1": 254, "album-1": 255, "year-1": 256, "artist-1": 257, "playsong-1": 258, "song_name-?": 259, "playmedia-1": 260, "device-1": 261, "track-1": 262, "payment_method-1": 263, "private_visibility-1": 264, "receiver-1": 265, "payment_method-?": 266, "receiver-?": 267, "dropoff_date-1": 268, "pickup_date-1": 269, "pickup_location-1": 270, "pickup_time-1": 271, "type-1": 272, "car_name-1": 273, "reservecar-1": 274, "dropoff_date-?": 275, "pickup_city-?": 276, "pickup_date-?": 277, "pickup_location-?": 278, "pickup_time-?": 279, "type-?": 280, "car_type-1": 281, "car_type-?": 282, "add_insurance-1": 283, "end_date-1": 284, "start_date-1": 285, "price_per_day-1": 286, "add_insurance-?": 287, "end_date-?": 288, "start_date-?": 289, "party_size-1": 290, "restaurant_name-1": 291, "cuisine-1": 292, "has_live_music-1": 293, "price_range-1": 294, "serves_alcohol-1": 295, "reserverestaurant-1": 296, "cuisine-?": 297, "restaurant_name-?": 298, "time-?": 299, "has_seating_outdoors-1": 300, "has_vegetarian_options-1": 301, "number_of_riders-1": 302, "shared_ride-1": 303, "approximate_ride_duration-1": 304, "ride_fare-1": 305, "number_of_riders-?": 306, "shared_ride-?": 307, "ride_type-1": 308, "wait_time-1": 309, "ride_type-?": 310, "appointment_date-1": 311, "appointment_time-1": 312, "stylist_name-1": 313, "is_unisex-1": 314, "bookappointment-1": 315, "appointment_date-?": 316, "appointment_time-?": 317, "dentist_name-1": 318, "offers_cosmetic_services-1": 319, "doctor_name-1": 320, "therapist_name-1": 321, "class-1": 322, "date_of_journey-1": 323, "from-1": 324, "journey_start_time-1": 325, "to-1": 326, "trip_protection-1": 327, "total-1": 328, "gettraintickets-1": 329, "date_of_journey-?": 330, "from-?": 331, "to-?": 332, "trip_protection-?": 333, "free_entry-1": 334, "good_for_kids-1": 335, "attraction_name-1": 336, "humidity-1": 337, "wind-1": 338, "precipitation-1": 339, "temperature-1": 340, "pad": 341}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment