From 27f2a2b319191e4c04397367aea47410edca6395 Mon Sep 17 00:00:00 2001
From: aaa123git <wandz19@mails.tsinghua.edu.cn>
Date: Tue, 14 Dec 2021 15:19:35 +0800
Subject: [PATCH] Append the last sys turn in preprocessing; fix Booking domain handling in scgpt

---
 convlab2/nlg/scgpt/multiwoz/preprocess.py |  3 +++
 convlab2/nlg/scgpt/multiwoz/scgpt.py      | 21 ++++++++++++++++++---
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/convlab2/nlg/scgpt/multiwoz/preprocess.py b/convlab2/nlg/scgpt/multiwoz/preprocess.py
index bcd4d1f9..d7a47bd2 100644
--- a/convlab2/nlg/scgpt/multiwoz/preprocess.py
+++ b/convlab2/nlg/scgpt/multiwoz/preprocess.py
@@ -106,6 +106,9 @@ if __name__ == '__main__':
                 domain = key.split('-')[0]
                 if domain not in ['general', 'Booking']:
                     current_domain = domain
+        else:
+            if args.role == 'sys':
+                turns.append(turn)
         title = title
         if title in val_list:
             current = results_val
diff --git a/convlab2/nlg/scgpt/multiwoz/scgpt.py b/convlab2/nlg/scgpt/multiwoz/scgpt.py
index 5c933cad..78f16f6e 100644
--- a/convlab2/nlg/scgpt/multiwoz/scgpt.py
+++ b/convlab2/nlg/scgpt/multiwoz/scgpt.py
@@ -2,6 +2,7 @@ import torch
 import numpy as np
 import os
 import zipfile
+from copy import deepcopy
 
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from convlab2.nlg.scgpt.utils import tuple2seq
@@ -71,8 +72,9 @@ class SCGPT(NLG):
             'Restaurant':False,
             'Taxi':False,
             'Train':False,}
-        if not self.is_user:
-            self.sess_domains['Booking'] = False
+        self.cur_domain = None
+        # if not self.is_user:
+        #     self.sess_domains['Booking'] = False
                 
     def generate(self, meta):
 
@@ -80,10 +82,23 @@ class SCGPT(NLG):
         if not meta:
             return 'No user action'
 
+        meta = deepcopy(meta)
+        for list_ in meta:
+            domain = list_[1]
+            if domain not in ('general', 'Booking'):
+                self.cur_domain = domain
+        for i, list_ in enumerate(meta):
+            list_ = list(list_)
+            if list_[1] == 'Booking':
+                if self.cur_domain is not None:
+                    list_[1] = self.cur_domain
+                    meta[i] = list_
+                else:
+                    print('`cur_domain` is None, but there is `Booking` in dialog action.')
         raw_text = tuple2seq(meta)
         domains = set([item[1] for item in meta])
         for domain in domains:
-            if domain != 'general' and not self.sess_domains[domain]:
+            if domain not in ('general', 'Booking') and not self.sess_domains[domain]:
                 raw_text = raw_text.replace(domain.lower(), domain.lower()+ ' *', 1)
                 self.sess_domains[domain] = True
         context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False)
-- 
GitLab