From 27f2a2b319191e4c04397367aea47410edca6395 Mon Sep 17 00:00:00 2001 From: aaa123git <wandz19@mails.tsinghua.edu.cn> Date: Tue, 14 Dec 2021 15:19:35 +0800 Subject: [PATCH] * append the last sys turn * fix Booking in scgpt --- convlab2/nlg/scgpt/multiwoz/preprocess.py | 3 +++ convlab2/nlg/scgpt/multiwoz/scgpt.py | 21 ++++++++++++++++++--- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/convlab2/nlg/scgpt/multiwoz/preprocess.py b/convlab2/nlg/scgpt/multiwoz/preprocess.py index bcd4d1f9..d7a47bd2 100644 --- a/convlab2/nlg/scgpt/multiwoz/preprocess.py +++ b/convlab2/nlg/scgpt/multiwoz/preprocess.py @@ -106,6 +106,9 @@ if __name__ == '__main__': domain = key.split('-')[0] if domain not in ['general', 'Booking']: current_domain = domain + else: + if args.role == 'sys': + turns.append(turn) title = title if title in val_list: current = results_val diff --git a/convlab2/nlg/scgpt/multiwoz/scgpt.py b/convlab2/nlg/scgpt/multiwoz/scgpt.py index 5c933cad..78f16f6e 100644 --- a/convlab2/nlg/scgpt/multiwoz/scgpt.py +++ b/convlab2/nlg/scgpt/multiwoz/scgpt.py @@ -2,6 +2,7 @@ import torch import numpy as np import os import zipfile +from copy import deepcopy from transformers import GPT2LMHeadModel, GPT2Tokenizer from convlab2.nlg.scgpt.utils import tuple2seq @@ -71,8 +72,9 @@ class SCGPT(NLG): 'Restaurant':False, 'Taxi':False, 'Train':False,} - if not self.is_user: - self.sess_domains['Booking'] = False + self.cur_domain = None + # if not self.is_user: + # self.sess_domains['Booking'] = False def generate(self, meta): @@ -80,10 +82,23 @@ class SCGPT(NLG): if not meta: return 'No user action' + meta = deepcopy(meta) + for list_ in meta: + domain = list_[1] + if domain not in ('general', 'Booking'): + self.cur_domain = domain + for i, list_ in enumerate(meta): + list_ = list(list_) + if list_[1] == 'Booking': + if self.cur_domain is not None: + list_[1] = self.cur_domain + meta[i] = list_ + else: + print('`cur_domain` is None, but there is `Booking` in dialog action.') raw_text = tuple2seq(meta) domains = set([item[1] for item in meta]) for domain in domains: - if domain != 'general' and not self.sess_domains[domain]: + if domain not in ('general', 'Booking') and not self.sess_domains[domain]: raw_text = raw_text.replace(domain.lower(), domain.lower()+ ' *', 1) self.sess_domains[domain] = True context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False) -- GitLab