diff --git a/.gitignore b/.gitignore index b8268b152da605f2187579dec5ec4f664edceefd..cd356208b42f43779d592bd7dd043a4bcbf306b6 100644 --- a/.gitignore +++ b/.gitignore @@ -33,12 +33,21 @@ convlab2/nlg/sclstm/**/sclstm.log convlab2/nlg/sclstm/**/sclstm_usr.pt convlab2/nlg/sclstm/**/sclstm_usr.res convlab2/nlg/sclstm/**/sclstm_usr.log -convlab2/nlu/jointBERT/**/output/ convlab2/dst/sumbt/multiwoz/output/ convlab2/nlg/sclstm/**/generated_sens_sys.json convlab2/nlg/template/**/generated_sens_sys.json convlab2/nlu/jointBERT/crosswoz/**/data convlab2/nlu/jointBERT/multiwoz/**/data +convlab2/nlu/jointBERT/**/output/ +convlab2/nlu/jointBERT_new/crosswoz/**/data +convlab2/nlu/jointBERT_new/multiwoz/**/data +convlab2/nlu/jointBERT_new/crosswoz/**/log +convlab2/nlu/jointBERT_new/multiwoz/**/log +convlab2/nlu/jointBERT_new/**/output/ +convlab2/nlu/milu/09* +convlab2/nlu/jointBERT/multiwoz/configs/multiwoz_new_usr_context.json +convlab2/nlu/milu/multiwoz/configs/system_without_context.jsonnet +convlab2/nlu/milu/multiwoz/configs/user_without_context.jsonnet # test script *_test.py diff --git a/convlab2/human_eval/worlds.py b/convlab2/human_eval/worlds.py index 70d86a1c4db42d4b960d4133e9e96c461b8c180a..7aaac4ce7e6c27dc550c7a44f116a996868e40d5 100755 --- a/convlab2/human_eval/worlds.py +++ b/convlab2/human_eval/worlds.py @@ -86,6 +86,7 @@ APPROPRIATENESS_MSG = 'Now please evaluate the \ be asked to give a reason for the score you choose.</b></span>' APPROPRIATENESS_REASON_MSG = 'Please give a <b>reason for the appropriateness \ score</b> you gave above. Please try to give concrete examples.' +REQT_MSG = 'Please type the values you obtained: ' import requests @@ -201,10 +202,10 @@ class MultiWozEvalWorld(MTurkTaskWorld): def __init__(self, opt, agent, num_extra_trial=2, max_turn=50, - max_resp_time=120, + max_resp_time=300, model_agent_opt=None, world_tag='', - agent_timeout_shutdown=120): + agent_timeout_shutdown=300): self.opt = opt self.agent = agent self.turn_idx = 1 @@ -261,6 +262,8 @@ class MultiWozEvalWorld(MTurkTaskWorld): self.goal_text += '</ul>' print(self.goal_text) + print(self.goal) + self.final_goal = deepcopy(self.goal) self.state = deepcopy(self.goal) def _track_state(self, inp): @@ -403,6 +406,23 @@ class MultiWozEvalWorld(MTurkTaskWorld): if 'text' in acts[idx] and \ acts[idx]['text'] != '': self.fail_reason = acts[idx]['text'] + else: + # reqt message + for domain in self.goal: + if 'reqt' in self.goal[domain]: + self.final_goal[domain]['reqt'] = dict() + for slot in self.goal[domain]['reqt']: + control_msg['text'] = REQT_MSG + '<b>' + domain + '-' + slot + '</b>' + agent.observe(validate(control_msg)) + acts[idx] = agent.act(timeout=self.max_resp_time) + while acts[idx]['text'] == '': + control_msg['text'] = 'Please try again.' + agent.observe(validate(control_msg)) + acts[idx] = agent.act(timeout=self.max_resp_time) + if 'text' in acts[idx] and \ + acts[idx]['text'] != '': + self.final_goal[domain]['reqt'][slot] = acts[idx]['text'] + # print(self.final_goal) # Language Understanding Check control_msg['text'] = UNDERSTANDING_MSG @@ -416,7 +436,7 @@ class MultiWozEvalWorld(MTurkTaskWorld): acts[idx]['text'] in self.ratings: self.understanding_score = int(acts[idx]['text']) - # Language Understanding reason + # Language Understanding reason control_msg['text'] = UNDERSTANDING_REASON_MSG agent.observe(validate(control_msg)) acts[idx] = agent.act(timeout=self.max_resp_time) @@ -440,7 +460,7 @@ class MultiWozEvalWorld(MTurkTaskWorld): acts[idx]['text'] in self.ratings: self.appropriateness_score = int(acts[idx]['text']) - # Response Appropriateness reason + # Response Appropriateness reason control_msg['text'] = APPROPRIATENESS_REASON_MSG agent.observe(validate(control_msg)) acts[idx] = agent.act(timeout=self.max_resp_time) @@ -553,6 +573,7 @@ class MultiWozEvalWorld(MTurkTaskWorld): ) ) result = {'goal': self.goal, + 'final_goal': self.final_goal, 'goal_text': self.goal_text, 'dialog': self.dialog, 'workers': self.agent.worker_id, diff --git a/convlab2/nlg/template/multiwoz/manual_system_template_nlg.json b/convlab2/nlg/template/multiwoz/manual_system_template_nlg.json index ee1c7aa898e0f9507f5545af4d20ff27e515a162..5d3bb11c79c2f1ee476e303c155e0282fc4ff8a0 100755 --- a/convlab2/nlg/template/multiwoz/manual_system_template_nlg.json +++ b/convlab2/nlg/template/multiwoz/manual_system_template_nlg.json @@ -58,8 +58,8 @@ "It is in the #ATTRACTION-INFORM-AREA# ." ], "Phone": [ - "The phone number is #ATTRACTION-INFORM-PHONE# .", - "Here is the phone number , #ATTRACTION-INFORM-PHONE# ." + "The attraction phone number is #ATTRACTION-INFORM-PHONE# .", + "Here is the attraction phone number , #ATTRACTION-INFORM-PHONE# ." ], "Type": [ "It is listed as a #ATTRACTION-INFORM-TYPE# attraction .", @@ -165,8 +165,8 @@ "Their postcode is #ATTRACTION-RECOMMEND-POST# ." ], "Phone": [ - "The phone number is #ATTRACTION-RECOMMEND-PHONE# .", - "Here is the phone number , #ATTRACTION-RECOMMEND-PHONE# ." + "The attraction phone number is #ATTRACTION-RECOMMEND-PHONE# .", + "Here is the attraction phone number , #ATTRACTION-RECOMMEND-PHONE# ." ], "Area": [ "That one is located in the #ATTRACTION-RECOMMEND-AREA# .", @@ -454,8 +454,8 @@ "the parking is free ." ], "Phone": [ - "Their phone number is #HOTEL-INFORM-PHONE# .", - "The phone number is #HOTEL-INFORM-PHONE# ." + "The hotel phone number is #HOTEL-INFORM-PHONE# .", + "The phone number of the hotel is #HOTEL-INFORM-PHONE# ." ], "Choice": [ "i have #HOTEL-INFORM-CHOICE# options for you", @@ -629,8 +629,8 @@ "would you like a recommendation ?" ], "Phone": [ - "Their phone number is #HOTEL-RECOMMEND-PHONE# .", - "The phone number is #HOTEL-RECOMMEND-PHONE# ." + "The hotel phone number is #HOTEL-RECOMMEND-PHONE# .", + "The phone number of the hotel is #HOTEL-RECOMMEND-PHONE# ." ], "Choice": [ "i have #HOTEL-RECOMMEND-CHOICE# options for you", @@ -758,15 +758,10 @@ "The post code is #RESTAURANT-INFORM-POST# ." ], "Phone": [ - "The number there is #RESTAURANT-INFORM-PHONE# .", - "their phone number is #RESTAURANT-INFORM-PHONE#", - "The phone number is #RESTAURANT-INFORM-PHONE# .", - "#RESTAURANT-INFORM-PHONE# is the phone number", - "Their phone number is #RESTAURANT-INFORM-PHONE#", - "Their number is #RESTAURANT-INFORM-PHONE# .", - "It is #RESTAURANT-INFORM-PHONE# .", - "Their phone number is #RESTAURANT-INFORM-PHONE# .", - "The phone number is #RESTAURANT-INFORM-PHONE# ." + "The number of the restaurant is #RESTAURANT-INFORM-PHONE# .", + "The restaurant's phone number is #RESTAURANT-INFORM-PHONE# .", + "The phone number of the restaurant is #RESTAURANT-INFORM-PHONE# .", + "#RESTAURANT-INFORM-PHONE# is the restaurant phone number" ], "Area": [ "it is in the #RESTAURANT-INFORM-AREA# area .", @@ -876,15 +871,10 @@ "The post code is #RESTAURANT-RECOMMEND-POST# ." ], "Phone": [ - "The number there is #RESTAURANT-RECOMMEND-PHONE# .", - "their phone number is #RESTAURANT-RECOMMEND-PHONE#", - "The phone number is #RESTAURANT-RECOMMEND-PHONE# .", - "#RESTAURANT-RECOMMEND-PHONE# is the phone number", - "Their phone number is #RESTAURANT-RECOMMEND-PHONE#", - "Their number is #RESTAURANT-RECOMMEND-PHONE# .", - "It is #RESTAURANT-RECOMMEND-PHONE# .", - "Their phone number is #RESTAURANT-RECOMMEND-PHONE# .", - "The phone number is #RESTAURANT-RECOMMEND-PHONE# ." + "The number of the restaurant is #RESTAURANT-RECOMMEND-PHONE# .", + "The restaurant's phone number is #RESTAURANT-RECOMMEND-PHONE#", + "The phone number of the restaurant is #RESTAURANT-RECOMMEND-PHONE# .", + "#RESTAURANT-RECOMMEND-PHONE# is the restaurant phone number" ], "none": [ "Is there anything else I can help you with ?", diff --git a/convlab2/nlu/jointBERT/dataloader.py b/convlab2/nlu/jointBERT/dataloader.py index fba4ebf125d71bc6a40d5f3a438faf63a7b0db3d..38fc24ea0fdc288410a65146716996432ebd896c 100755 --- a/convlab2/nlu/jointBERT/dataloader.py +++ b/convlab2/nlu/jointBERT/dataloader.py @@ -57,6 +57,7 @@ class Dataloader: new2ori = None d.append(new2ori) d.append(word_seq) + d.append(self.seq_tag2id(tag_seq)) d.append(self.seq_intent2id(d[2])) # d = (tokens, tags, intents, da2triples(turn["dialog_act"]), context(token id), new2ori, new_word_seq, tag2id_seq, intent2id_seq) @@ -95,7 +96,7 @@ class Dataloader: return split_tokens, new_tag_seq, new2ori def seq_tag2id(self, tags): - return [self.tag2id[x] for x in tags if x in self.tag2id] + return [self.tag2id[x] if x in self.tag2id else self.tag2id['O'] for x in tags] def seq_id2tag(self, ids): return [self.id2tag[x] for x in ids] diff --git a/convlab2/nlu/milu/dataset_reader.py b/convlab2/nlu/milu/dataset_reader.py index 3a8cf77818d2a8c8431337593eca532d55ec6cc8..5e00af04e7fe6c13ddbb21d60f22c51d5cbbc106 100755 --- a/convlab2/nlu/milu/dataset_reader.py +++ b/convlab2/nlu/milu/dataset_reader.py @@ -75,9 +75,11 @@ class MILUDatasetReader(DatasetReader): dialog = dialogs[dial_name]["log"] context_tokens_list = [] for i, turn in enumerate(dialog): - if self._agent and self._agent == "user" and i % 2 != 1: + if self._agent and self._agent == "user" and i % 2 == 1: + context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"]) continue - if self._agent and self._agent == "system" and i % 2 != 0: + if self._agent and self._agent == "system" and i % 2 == 0: + context_tokens_list.append(turn["text"].lower().split()+ ["SENT_END"]) continue tokens = turn["text"].split() diff --git a/convlab2/nlu/milu/multiwoz/nlu.py b/convlab2/nlu/milu/multiwoz/nlu.py index 5417c6d958954bf1005d399895bf7e2972379861..002a7dc86f6f0bcce81d9410193d069f442195ef 100755 --- a/convlab2/nlu/milu/multiwoz/nlu.py +++ b/convlab2/nlu/milu/multiwoz/nlu.py @@ -28,7 +28,7 @@ class MILU(NLU): def __init__(self, archive_file=DEFAULT_ARCHIVE_FILE, cuda_device=DEFAULT_CUDA_DEVICE, - model_file="https://convlab.blob.core.windows.net/convlab-2/milu_multiwoz_all_context.tar.gz", + model_file="https://convlab.blob.core.windows.net/convlab-2/new_milu(20200922)_multiwoz_all_context.tar.gz", context_size=3): """ Constructor for NLU class. """ diff --git a/convlab2/nlu/milu/train.py b/convlab2/nlu/milu/train.py index 99db49f91c3ae90f5c97287835f483c28fe1a832..9507a3a8ba920622ce84b26fa24513cdd2f2bb53 100755 --- a/convlab2/nlu/milu/train.py +++ b/convlab2/nlu/milu/train.py @@ -16,7 +16,8 @@ from allennlp.common.checks import check_for_gpu from allennlp.common.util import prepare_environment, prepare_global_logging, cleanup_global_logging, dump_metrics from allennlp.models.archival import archive_model, CONFIG_NAME from allennlp.models.model import Model, _DEFAULT_WEIGHTS -from allennlp.training.trainer import Trainer, TrainerPieces +from allennlp.training.trainer import Trainer +from allennlp.training.trainer_pieces import TrainerPieces from allennlp.training.trainer_base import TrainerBase from allennlp.training.util import create_serialization_dir, evaluate