diff --git a/convlab2/nlg/scgpt/multiwoz/preprocess.py b/convlab2/nlg/scgpt/multiwoz/preprocess.py index bcd4d1f9fa918dec6f08cbf80c86cc348d9ab23e..d7a47bd2bf7a9dc1772bc861afc63beb6c3ccd8f 100644 --- a/convlab2/nlg/scgpt/multiwoz/preprocess.py +++ b/convlab2/nlg/scgpt/multiwoz/preprocess.py @@ -106,6 +106,9 @@ if __name__ == '__main__': domain = key.split('-')[0] if domain not in ['general', 'Booking']: current_domain = domain + else: + if args.role == 'sys': + turns.append(turn) title = title if title in val_list: current = results_val diff --git a/convlab2/nlg/scgpt/multiwoz/scgpt.py b/convlab2/nlg/scgpt/multiwoz/scgpt.py index 5c933cad0d3ef7aa71335cdf1cab65ea7b4cd795..78f16f6e0b8562c7118a2a6118f0eb5b3287c828 100644 --- a/convlab2/nlg/scgpt/multiwoz/scgpt.py +++ b/convlab2/nlg/scgpt/multiwoz/scgpt.py @@ -2,6 +2,7 @@ import torch import numpy as np import os import zipfile +from copy import deepcopy from transformers import GPT2LMHeadModel, GPT2Tokenizer from convlab2.nlg.scgpt.utils import tuple2seq @@ -71,8 +72,9 @@ class SCGPT(NLG): 'Restaurant':False, 'Taxi':False, 'Train':False,} - if not self.is_user: - self.sess_domains['Booking'] = False + self.cur_domain = None + # if not self.is_user: + # self.sess_domains['Booking'] = False def generate(self, meta): @@ -80,10 +82,23 @@ class SCGPT(NLG): if not meta: return 'No user action' + meta = deepcopy(meta) + for list_ in meta: + domain = list_[1] + if domain not in ('general', 'Booking'): + self.cur_domain = domain + for i, list_ in enumerate(meta): + list_ = list(list_) + if list_[1] == 'Booking': + if self.cur_domain is not None: + list_[1] = self.cur_domain + meta[i] = list_ + else: + print('`cur_domain` is None, but there is `Booking` in dialog action.') raw_text = tuple2seq(meta) domains = set([item[1] for item in meta]) for domain in domains: - if domain != 'general' and not self.sess_domains[domain]: + if domain not in ('general', 'Booking') and not self.sess_domains[domain]: raw_text = raw_text.replace(domain.lower(), domain.lower()+ ' *', 1) self.sess_domains[domain] = True context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False)