Skip to content
Snippets Groups Projects
Commit 27f2a2b3 authored by aaa123git's avatar aaa123git
Browse files

* append the last sys turn

* fix Booking in scgpt
parent 726a6326
No related branches found
No related tags found
No related merge requests found
...@@ -106,6 +106,9 @@ if __name__ == '__main__': ...@@ -106,6 +106,9 @@ if __name__ == '__main__':
domain = key.split('-')[0] domain = key.split('-')[0]
if domain not in ['general', 'Booking']: if domain not in ['general', 'Booking']:
current_domain = domain current_domain = domain
else:
if args.role == 'sys':
turns.append(turn)
title = title title = title
if title in val_list: if title in val_list:
current = results_val current = results_val
......
...@@ -2,6 +2,7 @@ import torch ...@@ -2,6 +2,7 @@ import torch
import numpy as np import numpy as np
import os import os
import zipfile import zipfile
from copy import deepcopy
from transformers import GPT2LMHeadModel, GPT2Tokenizer from transformers import GPT2LMHeadModel, GPT2Tokenizer
from convlab2.nlg.scgpt.utils import tuple2seq from convlab2.nlg.scgpt.utils import tuple2seq
...@@ -71,8 +72,9 @@ class SCGPT(NLG): ...@@ -71,8 +72,9 @@ class SCGPT(NLG):
'Restaurant':False, 'Restaurant':False,
'Taxi':False, 'Taxi':False,
'Train':False,} 'Train':False,}
if not self.is_user: self.cur_domain = None
self.sess_domains['Booking'] = False # if not self.is_user:
# self.sess_domains['Booking'] = False
def generate(self, meta): def generate(self, meta):
...@@ -80,10 +82,23 @@ class SCGPT(NLG): ...@@ -80,10 +82,23 @@ class SCGPT(NLG):
if not meta: if not meta:
return 'No user action' return 'No user action'
meta = deepcopy(meta)
for list_ in meta:
domain = list_[1]
if domain not in ('general', 'Booking'):
self.cur_domain = domain
for i, list_ in enumerate(meta):
list_ = list(list_)
if list_[1] == 'Booking':
if self.cur_domain is not None:
list_[1] = self.cur_domain
meta[i] = list_
else:
print('`cur_domain` is None, but there is `Booking` in dialog action.')
raw_text = tuple2seq(meta) raw_text = tuple2seq(meta)
domains = set([item[1] for item in meta]) domains = set([item[1] for item in meta])
for domain in domains: for domain in domains:
if domain != 'general' and not self.sess_domains[domain]: if domain not in ('general', 'Booking') and not self.sess_domains[domain]:
raw_text = raw_text.replace(domain.lower(), domain.lower()+ ' *', 1) raw_text = raw_text.replace(domain.lower(), domain.lower()+ ' *', 1)
self.sess_domains[domain] = True self.sess_domains[domain] = True
context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False) context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment