Unverified commit fc978dcb authored by Carel van Niekerk, committed by GitHub

Merge branch 'master' into dsml_convlab

parents 04024645 c6170181
@@ -8,6 +8,7 @@ import json
import os
import random
import sys
import itertools
import zipfile
import numpy
from numpy.lib.shape_base import _put_along_axis_dispatcher
@@ -211,16 +212,18 @@ if __name__ == '__main__':
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    if len(sys.argv) != 4:
    if len(sys.argv) < 4:
        print("usage:")
        print("\t python evaluate.py dataset model role")
        print("\t dataset=MultiWOZ, CrossWOZ, or Camrest")
        print("\t model=SCLSTM, SCLSTM_NoUNK, SCGPT or TemplateNLG")
        print("\t role=usr/sys")
        print("\t [Optional] model_file")
        sys.exit()
    dataset_name = sys.argv[1]
    model_name = sys.argv[2]
    role = sys.argv[3]
    model_file = sys.argv[4] if len(sys.argv) >= 5 else None
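    # Example invocation (the checkpoint directory is hypothetical; omit the
    # last argument to fall back to the default model):
    #     python evaluate.py MultiWOZ SCGPT sys trained_output/multiwoz/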
    if dataset_name == 'MultiWOZ':
        if model_name == 'SCLSTM':
            from convlab2.nlg.sclstm.multiwoz import SCLSTM
@@ -242,17 +245,19 @@ if __name__ == '__main__':
                model = TemplateNLG(is_user=False)
        elif model_name == 'SCGPT':
            from convlab2.nlg.scgpt.multiwoz import SCGPT
            if model_file is not None:
                print(f"load model at {model_file}")
            if role == 'usr':
                model = SCGPT(is_user=True)
                model = SCGPT(model_file, is_user=True)
            elif role == 'sys':
                model = SCGPT(is_user=False, model_file='scgpt/trained_output/multiwoz/')
                model = SCGPT(model_file, is_user=False)
        else:
            raise Exception("Available models: SCLSTM, SCGPT, TEMPLATE")
        from convlab2.util.dataloader.module_dataloader import SingleTurnNLGDataloader
        from convlab2.util.dataloader.dataset_dataloader import MultiWOZDataloader
        dataloader = SingleTurnNLGDataloader(dataset_dataloader=MultiWOZDataloader())
        data = dataloader.load_data(data_key='all', role=role)['test']
        data = dataloader.load_data(data_key='all', role=role, session_id=True)['test']
        dialog_acts = []
        golden_utts = []
@@ -262,7 +267,19 @@ if __name__ == '__main__':
        sen_num = 0
        # sys.stdout = open(sys.argv[2] + '-' + sys.argv[3] + '-' + 'evaluate_logs_neo.txt','w')
        assert 'utterance' in data and 'dialog_act' in data and 'session_id' in data
        assert len(data['utterance']) == len(data['dialog_act']) == len(data['session_id'])
        # Turns during the same session should be contiguous, so we can call init_session at the first turn of a new session.
        # This is necessary for SCGPT, but unnecessary for SCLSTM and TemplateNLG.
        is_first_turn = []
        for _, iterator in itertools.groupby(data['session_id']):
            is_first_turn.append(True)
            next(iterator)
            is_first_turn.extend(False for _ in iterator)
        for i in tqdm(range(len(data['utterance']))):
            if is_first_turn[i]:
                model.init_session()
            dialog_acts.append(data['dialog_act'][i])
            golden_utts.append(data['utterance'][i])
            gen_utts.append(model.generate(data['dialog_act'][i]))
......
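For reference, a minimal standalone sketch of the `itertools.groupby` trick used above, with a hypothetical `session_ids` list (turns are assumed contiguous per session, as the comment in the diff notes):

```python
import itertools

session_ids = ['s1', 's1', 's1', 's2', 's2', 's3']  # hypothetical, contiguous per session
is_first_turn = []
for _, group in itertools.groupby(session_ids):
    is_first_turn.append(True)                   # first turn of a new session
    next(group)                                  # consume that first element
    is_first_turn.extend(False for _ in group)   # remaining turns of the session
print(is_first_turn)  # [True, False, False, True, False, True]
```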
@@ -21,9 +21,22 @@ tar -xvf scgpt.tar.gz
Then run:
``` python
python train.py --output_dir=trained_output --model_type=gpt2 --model_name_or_path=scgpt --do_train --do_eval --eval_data_file=multiwoz/data/test_sys.txt --overwrite_cache --use_tokenize --train_data_file=multiwoz/data/train_sys.txt --overwrite_output_dir
python train.py --output_dir=trained_output --model_type=gpt2 --model_name_or_path=scgpt --do_train --do_eval --eval_data_file=multiwoz/data/test_sys.txt --use_tokenize --train_data_file=multiwoz/data/train_sys.txt --overwrite_output_dir
```
Some tricks (optional training arguments):
* `--gradient_accumulation_steps xxx`
* `--fp16`; if set, `--per_gpu_train_batch_size` should preferably be a multiple of 8
* `--max_seq xxx`; it should be larger than the length of the longest sequence, e.g. `--max_seq 1024`. The script uses a dynamic sequence length at each training step.
* `--gradient_checkpointing`; it allows a larger `per_gpu_train_batch_size`
* `--use_multi_tensor_adamw`; reportedly a faster optimizer

Distributed data parallel:

If multiple GPUs are available, you can run `python -m torch.distributed.launch --nproc_per_node CUDA_COUNT train.py ......`, where `CUDA_COUNT` is the number of GPUs and `......` stands for the remaining arguments of `train.py`, as in the example below.
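For example, a hypothetical two-GPU run of the system-side training command above (the GPU count is an assumption; adjust `--nproc_per_node` to your machine):

``` python
python -m torch.distributed.launch --nproc_per_node 2 train.py --output_dir=trained_output --model_type=gpt2 --model_name_or_path=scgpt --do_train --do_eval --eval_data_file=multiwoz/data/test_sys.txt --use_tokenize --train_data_file=multiwoz/data/train_sys.txt --overwrite_output_dir
```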
## Use
```python
......
import warnings
from contextlib import nullcontext
from typing import TYPE_CHECKING

import torch.cuda.amp as amp
import transformers
from transformers import GPT2LMHeadModel


# reference: https://pytorch.org/docs/master/notes/amp_examples.html
class AmpGPT2LMHeadModel(GPT2LMHeadModel):
    if TYPE_CHECKING:
        # For IDE's code hinting
        forward = GPT2LMHeadModel.forward
    else:
        def forward(self, *args, **kwargs):
            with amp.autocast():
                return super().forward(*args, **kwargs)


def try_enable_gradient_checkpointing(model: "transformers.modeling_utils.PreTrainedModel"):
    if model.supports_gradient_checkpointing:
        model.gradient_checkpointing_enable()
    else:
        warnings.warn(f"{type(model)} doesn't support gradient_checkpointing")


class AmpHelper:
    """
    References:
        https://pytorch.org/docs/master/notes/amp_examples.html
    """
    def __init__(self, use_amp=True):
        self.use_amp = use_amp
        self.might_enable_autocast = amp.autocast() if use_amp else nullcontext()
        self.scaler = amp.GradScaler()

    def backward(self, loss):
        if self.use_amp:
            return self.scaler.scale(loss).backward()
        else:
            return loss.backward()

    def step(self, optimizer):
        if self.use_amp:
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            optimizer.step()

    def might_unscale_(self, optimizer):
        if self.use_amp:
            # Unscales the gradients of optimizer's assigned params in-place
            self.scaler.unscale_(optimizer)
\ No newline at end of file
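A minimal sketch of how these helpers might be combined in a training step, assuming the utilities above are in scope, a CUDA device is available, and a recent `transformers` version; the checkpoint name and dummy batch are illustrative only:

```python
import torch

model = AmpGPT2LMHeadModel.from_pretrained('gpt2').cuda()  # any GPT-2 checkpoint
try_enable_gradient_checkpointing(model)  # warns and continues if unsupported
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
amp_helper = AmpHelper(use_amp=True)

input_ids = torch.randint(0, model.config.vocab_size, (2, 16)).cuda()  # dummy batch
optimizer.zero_grad()
with amp_helper.might_enable_autocast:  # the model also autocasts internally
    loss = model(input_ids=input_ids, labels=input_ids, return_dict=True).loss
amp_helper.backward(loss)                # scales the loss when AMP is on
amp_helper.might_unscale_(optimizer)     # unscale before gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
amp_helper.step(optimizer)               # optimizer step plus scaler update
```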
@@ -6,6 +6,7 @@ Created on Mon Sep 14 11:38:53 2020
import os
import json
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from convlab2.nlg.scgpt.utils import dict2dict, dict2seq
import zipfile
@@ -14,6 +15,52 @@ def read_zipped_json(filepath, filename):
    archive = zipfile.ZipFile(filepath, 'r')
    return json.load(archive.open(filename))
def init_domain():
    return {'Attraction':False,
            'Hospital':False,
            'Hotel':False,
            'Police':False,
            'Restaurant':False,
            'Taxi':False,
            'Train':False}
def write_file(name, data, role='usr'):
    with open(f'{name}.txt', 'w', encoding='utf-8') as f:
        for ID in data:
            sess = data[ID]
            sess_domains = init_domain()
            for turn in sess:
                if role == 'usr':
                    if not turn['usr_da']:
                        continue
                    turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
                    da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
                    domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
                elif role == 'sys':
                    if not turn['sys_da']:
                        continue
                    turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
                    da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
                    domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
                else:
                    raise NameError('Invalid Role: Select usr/sys.')
                for domain in domains:
                    if domain not in ['general', 'Booking'] and not sess_domains[domain]:
                        da_seq = da_seq.replace(domain.lower(), domain.lower()+' *', 1)
                        sess_domains[domain] = True
                if role == 'usr':
                    da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
                elif role == 'sys':
                    da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
                f.write(f'{da_seq} & {da_uttr}\n')
if __name__ == '__main__':
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--role', type=str, default='usr')
    args = parser.parse_args()
cur_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(
    cur_dir)))), 'data/multiwoz/')
@@ -37,22 +84,22 @@ results_test = {}
for title, sess in data.items():
    logs = sess['log']
    turns = []
    turn = {'turn':0, 'sys':'', 'sys_da':''}
    turn = {'turn': 0, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
    current_domain = None
    for i, diag in enumerate(logs):
        text = diag['text']
        da = diag['dialog_act']
        span = diag['span_info']
        if i % 2 == 0:
            turn['usr'] = text
        if current_domain:
            da = eval(str(da).replace('Booking', current_domain))
            span = eval(str(span).replace('Booking', current_domain))
        if i % 2 == 0:
            turn['usr'] = text
            turn['usr_da'] = da
            turn['usr_span'] = span
            turns.append(turn)
        else:
            turn = {'turn': i//2 +1}
            turn = {'turn': i//2 + 1, 'sys': '', 'sys_da': '', 'usr': '', 'usr_da': ''}
            turn['sys'] = text
            turn['sys_da'] = da
            turn['sys_span'] = span
@@ -60,6 +107,9 @@ for title, sess in data.items():
            domain = key.split('-')[0]
            if domain not in ['general', 'Booking']:
                current_domain = domain
    else:
        if args.role == 'sys':
            turns.append(turn)
    title = title
    if title in val_list:
        current = results_val
@@ -73,41 +123,7 @@ results = eval(str(results).replace(" n't", " not"))
results_val = eval(str(results_val).replace(" n't", " not"))
results_test = eval(str(results_test).replace(" n't", " not"))
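# For illustration, the `eval(str(...))` round-trip above rewrites tokenized
# contractions inside every nested string, e.g. (hypothetical utterance):
#     eval(str({'usr': "i do n't care"}).replace(" n't", " not"))
#     # -> {'usr': 'i do not care'}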
def init_domain():
    return {'Attraction':False,
            'Hospital':False,
            'Hotel':False,
            'Police':False,
            'Restaurant':False,
            'Taxi':False,
            'Train':False}

def write_file(name, data):
    with open(f'{name}.txt', 'w', encoding='utf-8') as f:
        for ID in data:
            sess = data[ID]
            sess_domains = init_domain()
            for turn in sess:
                # TODO: set option to process usr/sys
                if not turn['usr_da']:
                    continue
                turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train'))
                da_seq = dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and')
                domains = set([key.split('-')[0] for key in turn['usr_da'].keys()])
                if not turn['sys_da']:
                    continue
                turn['sys_da'] = eval(str(turn['sys_da']).replace('Bus','Train'))
                da_seq = dict2seq(dict2dict(turn['sys_da'])).replace('&', 'and')
                domains = set([key.split('-')[0] for key in turn['sys_da'].keys()])
                for domain in domains:
                    if domain not in ['general', 'Booking'] and not sess_domains[domain]:
                        da_seq = da_seq.replace(domain.lower(), domain.lower()+' *', 1)
                        sess_domains[domain] = True
                da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and')
                da_uttr = turn['sys'].replace(' bus ', ' train ').replace('&', 'and')
                f.write(f'{da_seq} & {da_uttr}\n')
if not os.path.exists(os.path.join(cur_dir,'data')):
    os.makedirs(os.path.join(cur_dir, 'data'))
write_file(os.path.join(cur_dir, 'data/train'), dict(results, **results_val))
write_file(os.path.join(cur_dir, 'data/test'), results_test)
write_file(os.path.join(cur_dir, f'data/train_{args.role}'), dict(results, **results_val), role=args.role)
write_file(os.path.join(cur_dir, f'data/test_{args.role}'), results_test, role=args.role)
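As a standalone illustration of the first-mention domain marker in `write_file`, with a hypothetical dialog-act sequence (real sequences come from `dict2seq(dict2dict(...))`):

```python
sess_domains = {'Restaurant': False}         # as returned by init_domain()
da_seq = 'restaurant inform food = italian'  # hypothetical sequence string
for domain in {'Restaurant'}:
    if domain not in ['general', 'Booking'] and not sess_domains[domain]:
        da_seq = da_seq.replace(domain.lower(), domain.lower() + ' *', 1)
        sess_domains[domain] = True
print(da_seq)  # restaurant * inform food = italian
```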
@@ -2,6 +2,7 @@ import torch
import numpy as np
import os
import zipfile
from copy import deepcopy
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from convlab2.nlg.scgpt.utils import tuple2seq
@@ -10,23 +11,31 @@ from convlab2.nlg.nlg import NLG
from convlab2.util.file_util import cached_path
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "nlg-gpt-multiwoz.zip")
class SCGPT(NLG):
    def __init__(self,
                 archive_file=DEFAULT_ARCHIVE_FILE,
                 use_cuda=True,
                 is_user=False,
                 model_file='https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'):
    def __init__(self, model_file=None,
                 use_cuda=True, is_user=False):
        # If no model file is given, fall back to the role-specific default checkpoint.
        if not model_file:
            if is_user:
                model_file = 'https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'
            else:
                model_file = 'https://zenodo.org/record/5767426/files/neo_scgpt_system.zip'
        # Load from file/url
        model_dir = os.path.dirname(os.path.abspath(__file__))
        if not os.path.isfile(archive_file):
            archive_file = cached_path(model_file)
        archive = zipfile.ZipFile(archive_file, 'r')
        if not os.path.isfile(model_file):
            model_file = cached_path(model_file)
        if not os.path.isdir(model_file):
            archive = zipfile.ZipFile(model_file, 'r')
            archive.extractall(model_dir)
            # Get model directory
            model_file = archive.filelist[0].filename.replace('/', '')
            self.model_name_or_path = os.path.join(model_dir, model_file)
        else:
            self.model_name_or_path = model_file
        self.model_name_or_path = os.path.join(model_dir, 'multiwoz')
        self.length = 50
        self.num_samples = 5
        self.temperature = 1.0
@@ -63,8 +72,9 @@ class SCGPT(NLG):
                             'Restaurant':False,
                             'Taxi':False,
                             'Train':False,}
        if not self.is_user:
            self.sess_domains['Booking'] = False
        self.cur_domain = None
        # if not self.is_user:
        #     self.sess_domains['Booking'] = False

    def generate(self, meta):
@@ -72,10 +82,23 @@ class SCGPT(NLG):
        if not meta:
            return 'No user action'
        meta = deepcopy(meta)
        for list_ in meta:
            domain = list_[1]
            if domain not in ('general', 'Booking'):
                self.cur_domain = domain
        for i, list_ in enumerate(meta):
            list_ = list(list_)
            if list_[1] == 'Booking':
                if self.cur_domain is not None:
                    list_[1] = self.cur_domain
                    meta[i] = list_
                else:
                    print('`cur_domain` is None, but the dialog act contains `Booking`.')
        raw_text = tuple2seq(meta)
        domains = set([item[1] for item in meta])
        for domain in domains:
            if domain != 'general' and not self.sess_domains[domain]:
            if domain not in ('general', 'Booking') and not self.sess_domains[domain]:
                raw_text = raw_text.replace(domain.lower(), domain.lower()+ ' *', 1)
                self.sess_domains[domain] = True
        context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False)
......
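For reference, a hypothetical end-to-end use of the updated class (the dialog act is illustrative; when `model_file` is omitted, the role-specific default checkpoint listed above is downloaded on first use):

```python
from convlab2.nlg.scgpt.multiwoz import SCGPT

nlg = SCGPT(is_user=False)  # default system-side model
nlg.init_session()
# ConvLab-2 dialog acts are [intent, domain, slot, value] tuples; 'Booking'
# is rewritten to the currently tracked domain ('Hotel' here).
acts = [['Inform', 'Hotel', 'Area', 'east'], ['Inform', 'Booking', 'Ref', '7GAWK763']]
print(nlg.generate(acts))
```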
This diff is collapsed.