Skip to content
Snippets Groups Projects
Commit 27f2a2b3 authored by aaa123git's avatar aaa123git
Browse files

* append the last sys turn

* fix Booking in scgpt
parent 726a6326
No related branches found
No related tags found
No related merge requests found
......@@ -106,6 +106,9 @@ if __name__ == '__main__':
domain = key.split('-')[0]
if domain not in ['general', 'Booking']:
current_domain = domain
else:
if args.role == 'sys':
turns.append(turn)
title = title
if title in val_list:
current = results_val
......
......@@ -2,6 +2,7 @@ import torch
import numpy as np
import os
import zipfile
from copy import deepcopy
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from convlab2.nlg.scgpt.utils import tuple2seq
......@@ -71,8 +72,9 @@ class SCGPT(NLG):
'Restaurant':False,
'Taxi':False,
'Train':False,}
if not self.is_user:
self.sess_domains['Booking'] = False
self.cur_domain = None
# if not self.is_user:
# self.sess_domains['Booking'] = False
def generate(self, meta):
......@@ -80,10 +82,23 @@ class SCGPT(NLG):
if not meta:
return 'No user action'
meta = deepcopy(meta)
for list_ in meta:
domain = list_[1]
if domain not in ('general', 'Booking'):
self.cur_domain = domain
for i, list_ in enumerate(meta):
list_ = list(list_)
if list_[1] == 'Booking':
if self.cur_domain is not None:
list_[1] = self.cur_domain
meta[i] = list_
else:
print('`cur_domain` is None, but there is `Booking` in dialog action.')
raw_text = tuple2seq(meta)
domains = set([item[1] for item in meta])
for domain in domains:
if domain != 'general' and not self.sess_domains[domain]:
if domain not in ('general', 'Booking') and not self.sess_domains[domain]:
raw_text = raw_text.replace(domain.lower(), domain.lower()+ ' *', 1)
self.sess_domains[domain] = True
context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment