Skip to content
Snippets Groups Projects
Commit 6d5d3479 authored by Hsien-Chin Lin
Browse files

ablation study

parent 4e28a9f0
No related branches found
No related tags found
No related merge requests found
......@@ -15,11 +15,12 @@ DEBUG = False
class UserActionPolicy(GenTUSUserActionPolicy):
def __init__(self, model_checkpoint, mode="semantic", only_action=True, max_turn=40, **kwargs):
def __init__(self, model_checkpoint, mode="language", only_action=False, max_turn=40, **kwargs):
self.use_sentiment = kwargs.get("use_sentiment", False)
print("use_sentiment", self.use_sentiment)
self.add_persona = kwargs.get("add_persona", False)
self.emotion_mid = kwargs.get("emotion_mid", False)
super().__init__(model_checkpoint, mode, only_action, max_turn, **kwargs)
print("sentiment", self.use_sentiment)
weight = kwargs.get("weight", None)
self.kg = KnowledgeGraph(
tokenizer=self.tokenizer,
......@@ -52,22 +53,17 @@ class UserActionPolicy(GenTUSUserActionPolicy):
else:
history = self.usr_acts[-1*self.max_history:]
# TODO add user info? impolite? -> check self.use_sentiment
if self.use_sentiment:
# TODO how to get event and user politeness?
input_dict = {"system": sys_act,
"goal": self.goal.get_goal_list(),
"history": history,
"turn": str(int(self.time_step/2))}
if self.add_persona:
for user, info in self.user_info.items():
input_dict[user] = info
inputs = json.dumps(input_dict)
else:
inputs = json.dumps({"system": sys_act,
"goal": self.goal.get_goal_list(),
"history": history,
"turn": str(int(self.time_step/2))})
with torch.no_grad():
if emotion == "all":
raw_output = self.generate_from_emotion(
......@@ -91,16 +87,12 @@ class UserActionPolicy(GenTUSUserActionPolicy):
raw_output = self._generate_action(
raw_inputs=inputs, mode=mode, allow_general_intent=allow_general_intent)
output = self._parse_output(raw_output)
print(output)
self.semantic_action = self._remove_illegal_action(output["action"])
if not self.only_action:
self.utterance = output["text"]
self.emotion = output["emotion"]
if self.use_sentiment:
self.sentiment = output["sentiment"]
# print("---> sentiment", self.sentiment)
# print("---> emotion", self.emotion)
# print("---> self.utterance", self.utterance)
if self.is_finish():
self.emotion, self.semantic_action, self.utterance = self._good_bye()
......@@ -113,12 +105,7 @@ class UserActionPolicy(GenTUSUserActionPolicy):
del inputs
if self.mode == "language":
# print("in", sys_act)
# print("out", self.utterance)
return self.utterance
else:
return self.semantic_action
def _parse_output(self, in_str):
in_str = str(in_str)
......@@ -135,25 +122,24 @@ class UserActionPolicy(GenTUSUserActionPolicy):
print("-"*20)
return action
def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True, emotion_mode="normal"):
self.kg.parse_input(raw_inputs)
model_input = self.vector.encode(raw_inputs, self.max_in_len)
# start token
self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
pos = self._update_seq([0], 0)
pos = self._update_seq(self.token_map.get_id('start_json'), pos)
if self.use_sentiment:
def _update_sentiment(self, pos, model_input, mode):
pos = self._update_seq(
self.token_map.get_id('start_sentiment'), pos)
sentiment = self._get_sentiment(
model_input, self.seq[:1, :pos], mode)
pos = self._update_seq(sentiment["token_id"], pos)
else:
return sentiment, pos
def _update_emotion(self, pos, model_input, mode, emotion_mode, sentiment=None):
pos = self._update_seq(
self.token_map.get_id('start_emotion'), pos)
emotion = self._get_emotion(
model_input, self.seq[:1, :pos], mode, emotion_mode)
model_input, self.seq[:1, :pos], mode, emotion_mode, sentiment)
pos = self._update_seq(emotion["token_id"], pos)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
pos = self._update_seq(self.token_map.get_id('start_act'), pos)
return pos
# get semantic actions
def _update_semantic_act(self, pos, model_input, mode, allow_general_intent):
mode = "max"
for act_len in range(self.max_action_len):
pos = self._get_semantic_action(
model_input, pos, mode, allow_general_intent)
......@@ -164,16 +150,82 @@ class UserActionPolicy(GenTUSUserActionPolicy):
if terminate:
break
return pos
if self.only_action:
return self.vector.decode(self.seq[0, :pos])
def _sent_act_emo(self, pos, model_input, mode, emotion_mode, allow_general_intent):
# sent
sentiment, pos = self._update_sentiment(pos, model_input, mode)
pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
# act
pos = self._update_seq(self.token_map.get_id('start_act'), pos)
pos = self._update_semantic_act(
pos, model_input, mode, allow_general_intent)
# emo
pos = self._update_emotion(
pos, model_input, mode, emotion_mode, sentiment["token_name"])
pos = self._update_seq(self.token_map.get_id('sep_token'), pos)
if self.use_sentiment:
pos = self._update_seq(self.token_map.get_id('start_emotion'), pos)
emotion = self._get_emotion(
model_input, self.seq[:1, :pos], mode, emotion_mode, sentiment["token_name"])
pos = self._update_seq(emotion["token_id"], pos)
return pos
def _sent_emo_act(self, pos, model_input, mode, emotion_mode, allow_general_intent):
    """Extend the output sequence in the order sentiment -> emotion -> action.

    Each stage writes into the sequence through the sibling helpers and
    threads the running write position ``pos`` from one stage to the next.

    Args:
        pos: current write position in the output sequence.
        model_input: encoded model input, forwarded unchanged to the helpers.
        mode: generation mode forwarded to the sentiment, emotion and
            semantic-action helpers.
        emotion_mode: forwarded to ``_update_emotion``.
        allow_general_intent: forwarded to ``_update_semantic_act``.

    Returns:
        The position returned by ``_update_semantic_act``.
    """
    # Stage 1: sentiment, followed by a separator token.
    predicted_sentiment, pos = self._update_sentiment(pos, model_input, mode)
    pos = self._update_seq(self.token_map.get_id('sep_token'), pos)

    # Stage 2: emotion; the sentiment's token name is handed to the emotion
    # helper. The separator and the action opener are then written back to
    # back, in that order.
    pos = self._update_emotion(
        pos, model_input, mode, emotion_mode, predicted_sentiment["token_name"])
    for token_key in ('sep_token', 'start_act'):
        pos = self._update_seq(self.token_map.get_id(token_key), pos)

    # Stage 3: semantic actions.
    return self._update_semantic_act(
        pos, model_input, mode, allow_general_intent)
def _emo_act(self, pos, model_input, mode, emotion_mode, allow_general_intent):
    """Extend the output sequence in the order emotion -> action (no sentiment).

    Args:
        pos: current write position in the output sequence.
        model_input: encoded model input, forwarded unchanged to the helpers.
        mode: generation mode forwarded to the emotion and semantic-action
            helpers.
        emotion_mode: forwarded to ``_update_emotion``.
        allow_general_intent: forwarded to ``_update_semantic_act``.

    Returns:
        The position returned by ``_update_semantic_act``.
    """
    # Emotion first; no sentiment argument is passed, so the helper's
    # default for that parameter applies.
    pos = self._update_emotion(pos, model_input, mode, emotion_mode)

    # Close the emotion field, then open the action list.
    for token_key in ('sep_token', 'start_act'):
        pos = self._update_seq(self.token_map.get_id(token_key), pos)

    # Semantic actions come last in this ordering.
    return self._update_semantic_act(
        pos, model_input, mode, allow_general_intent)
def _act_emo(self, pos, model_input, mode, emotion_mode, allow_general_intent):
    """Extend the output sequence in the order action -> emotion (no sentiment).

    Args:
        pos: current write position in the output sequence.
        model_input: encoded model input, forwarded unchanged to the helpers.
        mode: generation mode forwarded to the semantic-action and emotion
            helpers.
        emotion_mode: forwarded to ``_update_emotion``.
        allow_general_intent: forwarded to ``_update_semantic_act``.

    Returns:
        The position after the trailing separator token has been written.
    """
    # Open the action list and generate the semantic actions.
    start_act_id = self.token_map.get_id('start_act')
    pos = self._update_seq(start_act_id, pos)
    pos = self._update_semantic_act(
        pos, model_input, mode, allow_general_intent)

    # Emotion follows the actions; no sentiment argument is passed.
    pos = self._update_emotion(pos, model_input, mode, emotion_mode)

    # A separator is written after the emotion token.
    return self._update_seq(self.token_map.get_id('sep_token'), pos)
def _generate_action(self, raw_inputs, mode="max", allow_general_intent=True, emotion_mode="normal"):
self.kg.parse_input(raw_inputs)
model_input = self.vector.encode(raw_inputs, self.max_in_len)
# start token
self.seq = torch.zeros(1, self.max_out_len, device=self.device).long()
pos = self._update_seq([0], 0)
pos = self._update_seq(self.token_map.get_id('start_json'), pos)
if self.use_sentiment and self.emotion_mid:
pos = self._sent_act_emo(
pos, model_input, mode, emotion_mode, allow_general_intent)
elif self.use_sentiment and not self.emotion_mid:
pos = self._sent_emo_act(
pos, model_input, mode, emotion_mode, allow_general_intent)
elif not self.use_sentiment and self.emotion_mid:
pos = self._act_emo(
pos, model_input, mode, emotion_mode, allow_general_intent)
else:
pos = self._emo_act(
pos, model_input, mode, emotion_mode, allow_general_intent)
pos = self._update_seq(self.token_map.get_id("start_text"), pos)
text = self._get_text(model_input, pos)
......@@ -332,8 +384,6 @@ class UserActionPolicy(GenTUSUserActionPolicy):
class UserPolicy(Policy):
def __init__(self,
model_checkpoint,
mode="semantic",
only_action=True,
sample=False,
action_penalty=False,
**kwargs):
......@@ -342,7 +392,8 @@ class UserPolicy(Policy):
os.mkdir(os.path.dirname(model_checkpoint))
model_downloader(os.path.dirname(model_checkpoint),
"https://zenodo.org/record/7372442/files/multiwoz21-exp.zip")
only_action = False
mode = "language"
self.policy = UserActionPolicy(
model_checkpoint,
mode=mode,
......@@ -385,15 +436,15 @@ if __name__ == "__main__":
# from convlab.nlu.jointBERT.multiwoz import BERTNLU
from convlab.util.custom_util import set_seed
set_seed(20220220)
use_sentiment, emotion_mid = True, True
set_seed(0)
# Test semantic level behaviour
model_checkpoint = 'convlab/policy/emoTUS/unify/experiments/emowoz+dialmage_0_1/23-01-11-15-17'
usr_policy = UserPolicy(
model_checkpoint,
mode="language",
only_action=False,
use_sentiment=True,
sample=True)
sample=True,
use_sentiment=use_sentiment,
emotion_mid=emotion_mid)
# usr_policy.policy.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
usr_nlu = None # BERTNLU()
usr = PipelineAgent(usr_nlu, None, usr_policy, None, name='user')
......
......@@ -28,35 +28,34 @@ def arg_parser():
default="")
parser.add_argument("--generated-file", type=str, help="the generated results",
default="")
parser.add_argument("--only-action", action="store_true")
parser.add_argument("--dataset", default="multiwoz")
parser.add_argument("--do-semantic", action="store_true",
help="do semantic evaluation")
parser.add_argument("--do-nlg", action="store_true",
help="do nlg generation")
parser.add_argument("--do-golden-nlg", action="store_true",
help="do golden nlg generation")
parser.add_argument("--no-neutral", action="store_true",
help="skip neutral emotion")
parser.add_argument("--use-sentiment", action="store_true")
parser.add_argument("--emotion-mid", action="store_true")
parser.add_argument("--weight", type=float, default=None)
return parser.parse_args()
class Evaluator:
def __init__(self, model_checkpoint, dataset, model_weight=None, only_action=False, use_sentiment=False, weight=None):
def __init__(self, model_checkpoint, dataset, model_weight=None, **kwargs):
self.dataset = dataset
self.model_checkpoint = model_checkpoint
self.model_weight = model_weight
self.time = f"{datetime.now().strftime('%y-%m-%d-%H-%M')}"
self.use_sentiment = use_sentiment
self.use_sentiment = kwargs.get("use_sentiment", False)
self.add_persona = kwargs.get("add_persona", False)
self.emotion_mid = kwargs.get("emotion_mid", False)
weight = kwargs.get("weight", None)
self.usr = UserActionPolicy(
model_checkpoint,
only_action=only_action,
dataset=self.dataset,
use_sentiment=use_sentiment,
use_sentiment=self.use_sentiment,
add_persona=self.add_persona,
emotion_mid=self.emotion_mid,
weight=weight)
self.usr.load(os.path.join(model_checkpoint, "pytorch_model.bin"))
self.r = {"input": [],
......@@ -66,7 +65,8 @@ class Evaluator:
"gen_acts": [],
"gen_utts": [],
"gen_emotion": []}
if use_sentiment:
if self.use_sentiment:
self.r["golden_sentiment"] = []
self.r["gen_sentiment"] = []
......@@ -81,17 +81,13 @@ class Evaluator:
for x in self.r:
self.r[x].append(temp[x])
def generate_results(self, f_eval, golden=False, no_neutral=False):
def generate_results(self, f_eval, golden=False):
emotion_mode = "normal"
if no_neutral:
emotion_mode = "no_neutral"
in_file = json.load(open(f_eval))
for dialog in tqdm(in_file['dialog']):
for dialog in tqdm(in_file['dialog'][:2]):
inputs = dialog["in"]
labels = self.usr._parse_output(dialog["out"])
if no_neutral and labels["emotion"].lower() == "neutral":
continue
if golden:
usr_act = labels["action"]
......@@ -138,10 +134,10 @@ class Evaluator:
result.append(temp)
return result
def nlg_evaluation(self, input_file=None, generated_file=None, golden=False, no_neutral=False):
def nlg_evaluation(self, input_file=None, generated_file=None, golden=False):
if input_file:
print("Force generation")
self.generate_results(input_file, golden, no_neutral)
self.generate_results(input_file, golden)
elif generated_file:
self.read_generated_result(generated_file)
......@@ -240,7 +236,7 @@ class Evaluator:
for metric in scores:
result[metric] = sum(scores[metric])/len(scores[metric])
print(f"{metric}: {result[metric]}")
# TODO no neutral
emo_score = emotion_score(
golden_emotions,
gen_emotions,
......@@ -338,23 +334,19 @@ def main():
eval = Evaluator(args.model_checkpoint,
args.dataset,
args.model_weight,
args.only_action,
args.use_sentiment,
use_sentiment=args.use_sentiment,
emotion_mid=args.emotion_mid,
weight=args.weight)
print("model checkpoint", args.model_checkpoint)
print("generated_file", args.generated_file)
print("input_file", args.input_file)
with torch.no_grad():
if args.do_semantic:
eval.evaluation(args.input_file)
if args.do_nlg:
if args.generated_file:
generated_file = args.generated_file
else:
nlg_result = eval.nlg_evaluation(input_file=args.input_file,
generated_file=args.generated_file,
golden=args.do_golden_nlg,
no_neutral=args.no_neutral)
golden=args.do_golden_nlg)
generated_file = nlg_result
eval.evaluation(args.input_file,
......
......@@ -2,28 +2,26 @@ import json
class tokenMap:
def __init__(self, tokenizer, use_sentiment=False):
def __init__(self, tokenizer, **kwargs):
self.tokenizer = tokenizer
self.token_name = {}
self.hash_map = {}
self.debug = False
self.use_sentiment = use_sentiment
self.default()
def default(self, only_action=False):
self.format_tokens = {
'start_json': '{"emotion": "', # 49643, 10845, 7862, 646
'start_act': 'action": [["', # 49329
'sep_token': '", "', # 1297('",'), 22
'sep_act': '"], ["', # 49177
'end_act': '"]], "', # 42248, 7479, 22
'start_text': 'text": "', # 29015, 7862, 22
'end_json': '}', # 24303
'end_json_2': '"}' # 48805
'start_json': '{"',
'start_sentiment': 'sentiment": "',
'start_emotion': 'emotion": "',
'start_act': 'action": [["',
'sep_token': '", "',
'sep_act': '"], ["',
'end_act': '"]], "',
'start_text': 'text": "',
'end_json': '}',
'end_json_2': '"}'
}
if self.use_sentiment:
self.format_tokens['start_json'] = '{"sentiment": "'
self.format_tokens['start_emotion'] = 'emotion": "'
if only_action:
self.format_tokens['end_act'] = '"]]}'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment