diff --git a/Dockerfile b/Dockerfile
index 49bcaad8bb0420937c43c06f3ab1427093b33faf..b4cbcdf97a280df63fc83b173193635a25f28bfc 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -43,7 +43,9 @@ RUN pip install quadprog
 RUN pip install pyyaml
 RUN pip install fuzzywuzzy
 RUN pip install python-Levenshtein
-
+RUN pip install gtts
+RUN pip install DeepSpeech
+RUN pip install pydub
 RUN [ "python", "-c", "import nltk; nltk.download('stopwords')" ]
diff --git a/convlab2/laug/README.md b/convlab2/laug/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4c94d4821b9acdf5e9fd55c1f4cb32745a7bbad1
--- /dev/null
+++ b/convlab2/laug/README.md
@@ -0,0 +1,17 @@
+# LAUG
+**LAUG**[[repo]](https://github.com/thu-coai/LAUG/) is an open-source toolkit for Language understanding AUGmentation. It provides automatic methods to approximate the natural perturbations to existing data. Augmented data can be used for black-box robustness testing or to enhance training. [[paper]](https://arxiv.org/abs/2012.15262)
+
+Here are the 4 augmentation methods described in our paper:
+- Word Perturbation, at `Word_Perturbation/` dir.
+- Text Paraphrasing, at `Text_Paraphrasing/` dir.
+- Speech Recognition, at `Speech_Recognition/` dir.
+- Speech Disfluency, at `Speech_Disfluency/` dir.
+
+Please see our paper and the README.md in each augmentation method's dir for detailed information.
+
+See `demo.py` for the usage of these augmentation methods.
+> python demo.py
+
+
+Note that our augmentation methods contain several neural models, so pre-trained parameters need to be downloaded before use. Parameters pre-trained by us are available at [Link](http://115.182.62.174:9876/). For parameters released by others, please follow the instructions of each method.
+
diff --git a/convlab2/laug/Speech_Disfluency/LSTMCRF.py b/convlab2/laug/Speech_Disfluency/LSTMCRF.py
new file mode 100644
index 0000000000000000000000000000000000000000..0efcc3343312e819ecd08394542aaea12248743a
--- /dev/null
+++ b/convlab2/laug/Speech_Disfluency/LSTMCRF.py
@@ -0,0 +1,188 @@
+# -*- coding: utf-8 -*-
+
+# Adapted from the official PyTorch tutorials
+
+import torch
+import torch.autograd as autograd
+import torch.nn as nn
+import torch.optim as optim
+import json
+torch.manual_seed(1)
+
+#####################################################################
+# Helper functions to make the code more readable.
+
+
+def argmax(vec):
+    # return the argmax as a python int
+    _, idx = torch.max(vec, 1)
+    return idx.item()
+
+
+
+
+# Compute log sum exp in a numerically stable way for the forward algorithm
+def log_sum_exp(vec):
+    max_score = vec[0, argmax(vec)]
+    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
+    return max_score + \
+        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
+
+#####################################################################
+# Create model
+
+
+class BiLSTM_CRF(nn.Module):
+
+    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim, emb_weights):
+        super(BiLSTM_CRF, self).__init__()
+        self.embedding_dim = embedding_dim
+        self.hidden_dim = hidden_dim
+        self.vocab_size = vocab_size
+        self.tag_to_ix = tag_to_ix
+        self.tagset_size = len(tag_to_ix)
+
+        #self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
+        self.word_embeds = nn.Embedding.from_pretrained(emb_weights)
+        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
+                            num_layers=1, bidirectional=True)
+
+        # Maps the output of the LSTM into tag space.
+ self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size) + + # Matrix of transition parameters. Entry i,j is the score of + # transitioning *to* i *from* j. + self.transitions = nn.Parameter( + torch.randn(self.tagset_size, self.tagset_size)) + + # These two statements enforce the constraint that we never transfer + # to the start tag and we never transfer from the stop tag + self.transitions.data[tag_to_ix[START_TAG], :] = -10000 + self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000 + + self.hidden = self.init_hidden() + + def init_hidden(self): + return (torch.randn(2, 1, self.hidden_dim // 2), + torch.randn(2, 1, self.hidden_dim // 2)) + + def _forward_alg(self, feats): + # Do the forward algorithm to compute the partition function + init_alphas = torch.full((1, self.tagset_size), -10000.) + # START_TAG has all of the score. + init_alphas[0][self.tag_to_ix[START_TAG]] = 0. + + # Wrap in a variable so that we will get automatic backprop + forward_var = init_alphas + + # Iterate through the sentence + for feat in feats: + alphas_t = [] # The forward tensors at this timestep + for next_tag in range(self.tagset_size): + # broadcast the emission score: it is the same regardless of + # the previous tag + emit_score = feat[next_tag].view( + 1, -1).expand(1, self.tagset_size) + # the ith entry of trans_score is the score of transitioning to + # next_tag from i + trans_score = self.transitions[next_tag].view(1, -1) + # The ith entry of next_tag_var is the value for the + # edge (i -> next_tag) before we do log-sum-exp + next_tag_var = forward_var + trans_score + emit_score + # The forward variable for this tag is log-sum-exp of all the + # scores. + alphas_t.append(log_sum_exp(next_tag_var).view(1)) + forward_var = torch.cat(alphas_t).view(1, -1) + terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]] + alpha = log_sum_exp(terminal_var) + return alpha + + def _get_lstm_features(self, sentence): + self.hidden = self.init_hidden() + embeds = self.word_embeds(sentence).view(len(sentence), 1, -1) + lstm_out, self.hidden = self.lstm(embeds, self.hidden) + lstm_out = lstm_out.view(len(sentence), self.hidden_dim) + lstm_feats = self.hidden2tag(lstm_out) + return lstm_feats + + def _score_sentence(self, feats, tags): + # Gives the score of a provided tag sequence + score = torch.zeros(1) + tags = torch.cat([torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long), tags]) + for i, feat in enumerate(feats): + score = score + \ + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]] + score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]] + return score + + def _viterbi_decode(self, feats): + backpointers = [] + + # Initialize the viterbi variables in log space + init_vvars = torch.full((1, self.tagset_size), -10000.) + init_vvars[0][self.tag_to_ix[START_TAG]] = 0 + + # forward_var at step i holds the viterbi variables for step i-1 + forward_var = init_vvars + for feat in feats: + bptrs_t = [] # holds the backpointers for this step + viterbivars_t = [] # holds the viterbi variables for this step + + for next_tag in range(self.tagset_size): + # next_tag_var[i] holds the viterbi variable for tag i at the + # previous step, plus the score of transitioning + # from tag i to next_tag. 
+                # We don't include the emission scores here because the max
+                # does not depend on them (we add them in below)
+                next_tag_var = forward_var + self.transitions[next_tag]
+                best_tag_id = argmax(next_tag_var)
+                bptrs_t.append(best_tag_id)
+                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
+            # Now add in the emission scores, and assign forward_var to the set
+            # of viterbi variables we just computed
+            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
+            backpointers.append(bptrs_t)
+
+        # Transition to STOP_TAG
+        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
+        best_tag_id = argmax(terminal_var)
+        path_score = terminal_var[0][best_tag_id]
+
+        # Follow the back pointers to decode the best path.
+        best_path = [best_tag_id]
+        for bptrs_t in reversed(backpointers):
+            best_tag_id = bptrs_t[best_tag_id]
+            best_path.append(best_tag_id)
+        # Pop off the start tag (we don't want to return that to the caller)
+        start = best_path.pop()
+        assert start == self.tag_to_ix[START_TAG]  # Sanity check
+        best_path.reverse()
+        return path_score, best_path
+
+    def neg_log_likelihood(self, sentence, tags):
+        feats = self._get_lstm_features(sentence)
+        forward_score = self._forward_alg(feats)
+        gold_score = self._score_sentence(feats, tags)
+        return forward_score - gold_score
+
+    def forward(self, sentence):  # don't confuse this with _forward_alg above.
+        # Get the emission scores from the BiLSTM
+        lstm_feats = self._get_lstm_features(sentence)
+
+        # Find the best path, given the features.
+        score, tag_seq = self._viterbi_decode(lstm_feats)
+        return score, tag_seq
+
+#####################################################################
+# Run training
+
+START_TAG = "<START>"
+STOP_TAG = "<STOP>"
+EMBEDDING_DIM = 100
+HIDDEN_DIM = 100
+
+
+
+
+
+
diff --git a/convlab2/laug/Speech_Disfluency/README.md b/convlab2/laug/Speech_Disfluency/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0f7fcd2c5be8f86e1773d52e1a17bb489e7d9372
--- /dev/null
+++ b/convlab2/laug/Speech_Disfluency/README.md
@@ -0,0 +1,14 @@
+## Speech_Disfluency
+
+The interruption points are predicted by a Bi-LSTM+CRF model.
+
+The filler words, restart terms, and edit terms, together with their occurrence frequencies, are all sampled from their distributions in Switchboard.
+
+
+## Bi-LSTM+CRF model
+
+The Bi-LSTM+CRF model is trained on Switchboard data.
+
+Please download the pre-trained parameters and disfluency resources at [Link](http://115.182.62.174:9876/).
+
+The model requires the glove.6B.100d word vectors; please set the GloVe path on line 22 of inference.py.
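For orientation, here is a minimal sketch of how the predicted interruption-point tags drive the augmentation (the tag semantics follow `add_fillers`/`add_repeats` in `Speech_Disfluency.py` below; the helper name, filler word, and sentence are illustrative only):

```python
# Hypothetical helper: tag 1 marks a filler insertion point,
# tag 2 marks a word to be repeated.
def apply_ip_tags(words, ip_tags, filler="uh"):
    out = []
    for word, tag in zip(words, ip_tags):
        if tag == 1:
            out.append(filler)  # insert a filler before this word
        out.append(word)
        if tag == 2:
            out.append(word)    # repeat this word
    return out

print(" ".join(apply_ip_tags("i want a train".split(), [1, 0, 0, 2])))
# -> uh i want a train train
```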
diff --git a/convlab2/laug/Speech_Disfluency/Speech_Disfluency.py b/convlab2/laug/Speech_Disfluency/Speech_Disfluency.py new file mode 100644 index 0000000000000000000000000000000000000000..988e74af48e77a797c6edc4f86ec6bbe8a582f66 --- /dev/null +++ b/convlab2/laug/Speech_Disfluency/Speech_Disfluency.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- +import json +import random +from fuzzywuzzy import fuzz +from convlab2.laug.Speech_Disfluency.inference import IP_model +import os + +current_path=os.path.dirname(os.path.abspath(__file__)) +def random_01(possibility): + x=random.random() + if x>=possibility: + return 0 + else: + return 1 + +def random_pick_from_list(random_list): + return random_list[int(len(random_list)*random.random())] + +def process_distribution_dict(distribution_dict): + processed_distribution=[] + sum=0 + for key in distribution_dict: + sum+=distribution_dict[key] + processed_distribution.append((key,sum)) + return processed_distribution + +def random_pick_from_distribution(distribution_dict): + processed_distribution=process_distribution_dict(distribution_dict) + x=random.random()*processed_distribution[-1][1] + for item in processed_distribution: + if x>item[1]: + continue + else: + picked_item=item[0] + break + return picked_item + +def preprocess(sentence): + word_list=sentence.lower().strip().split() + return word_list + +class Speech_Disfluency: + def __init__(self,dataset='multiwoz',edit_frequency=0.3): + self.resources=json.load(open(os.path.join(current_path,'resources/resources_'+dataset+'.json'),'r')) + self.edit_frequency=edit_frequency + + + def protect_slots(self,word_list,spans,IP_tags): + sentence=' '.join(word_list)+' ' + for span in spans: + value=span[2] + start=sentence.count(' ',0,sentence.find(' '+value+' ')) + lenth=len(value.split()) + for i in range(start+1,start+lenth): + IP_tags[i]=0 + IP_tags[start]=1 + if IP_tags[start]==2: + IP_tags[start]=1 + return IP_tags + + + def add_repairs(self,word_list,spans): + sentence=' '+' '.join(word_list)+' ' + if len(spans)==0: + return word_list + else: + edit_possibility=self.edit_frequency/len(spans) + for span in spans: + if random_01(edit_possibility)==0: + continue + value=span[2] + start=sentence.count(' ',0,sentence.find(' '+value+' '))-1 + + max_ratio,max_entity=0,'' + for e in self.resources["knowledge_base"]["entity"]: + ratio=fuzz.ratio(e,value) + if ratio>max_ratio: + max_ratio=ratio + max_entity=e + if max_entity!='' and max_ratio>60: + candidate=[] + if max_entity in self.resources["knowledge_base"]["entity"]: + candidate=self.resources["knowledge_base"]["category"][random_pick_from_list(self.resources["knowledge_base"]["entity"][max_entity])][0:] + if span in candidate: + candidate.remove(span) + if len(candidate)!=0: + word_list[start]=random_pick_from_list(candidate)+' '+random_pick_from_list(self.resources["edit_terms"])+' '+word_list[start] + return word_list + + def add_repeats(self,word_list,IP_tags): + for i in range(len(IP_tags)): + if IP_tags[i]==2: + word_list[i]=word_list[i]+random_pick_from_list([' ',' , '])+word_list[i] + return word_list + + + def add_fillers(self,word_list,IP_tags): + for i in range(len(IP_tags)): + if IP_tags[i]==1: + word_list[i]=random_pick_from_distribution(self.resources["filler_terms"])+' '+word_list[i] + return word_list + + def add_restart(self,word_list): + word_list[0]=random_pick_from_distribution(self.resources["restart_terms"])+' '+word_list[0] + return word_list + + + def find_spans(self,disfluent_sentence,spans): + checked=1 + sentence=' 
'+disfluent_sentence+' ' + for i in range(len(spans)): + value=spans[i][2] + start=sentence.count(' ',0,sentence.find(' '+value+' ')) + lenth=len(value.split()) + spans[i][3]=start + spans[i][4]=start+lenth-1 + if ' '.join(sentence.split()[spans[i][3]:spans[i][4]+1])!=spans[i][2]: + checked=0 + return spans,checked + + def aug(self,sentence,spans): + word_list=preprocess(sentence) + IP_tags=IP_model(word_list) + IP_tags=self.protect_slots(word_list,spans,IP_tags) + word_list=self.add_repairs(word_list,spans) + word_list=self.add_repeats(word_list,IP_tags) + word_list=self.add_fillers(word_list,IP_tags) + word_list=self.add_restart(word_list) + disfluent_sentence=' '.join(word_list) + new_spans,checked=self.find_spans(disfluent_sentence,spans) + return disfluent_sentence,new_spans + # input sentence and span_info ; output the disfluent sentence and new_span_info + +if __name__=="__main__": + text = "I want a train to Cambridge" + span_info = [["Train-Inform","Dest","Cambridge",5,5]] + SR = Speech_Disfluency() + new_text,new_span_info = SR.aug(text,span_info) + print(new_text) + print(new_span_info) diff --git a/convlab2/laug/Speech_Disfluency/__init__.py b/convlab2/laug/Speech_Disfluency/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/convlab2/laug/Speech_Disfluency/__init__.py @@ -0,0 +1 @@ + diff --git a/convlab2/laug/Speech_Disfluency/inference.py b/convlab2/laug/Speech_Disfluency/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..278d7b38c3ca8712630cae7a658c912dbf539b32 --- /dev/null +++ b/convlab2/laug/Speech_Disfluency/inference.py @@ -0,0 +1,55 @@ +from .LSTMCRF import BiLSTM_CRF +import json +import numpy as np +import torch +import os +START_TAG = "<START>" +STOP_TAG = "<STOP>" +EMBEDDING_DIM = 100 +HIDDEN_DIM = 100 + +# Make up some training data +def prepare_sequence(seq, to_ix): + idxs=[] + for w in seq: + if w in to_ix: + idxs.append(to_ix[w]) + else: + idxs.append(0) + return torch.tensor(idxs, dtype=torch.long) + +# Put your dir to glove here +glove_file='[dir_to]/glove.6B.100d.txt' + +word_to_ix={} +max=20000 +ifs=open(glove_file, 'r') +word_to_ix['<unk>'] = 0 +weights=[] +weights.append(torch.from_numpy(np.array([0.]*100))) +for i,line in enumerate(ifs.readlines()): + if i>=max: + break + line_list = line.split() + word = line_list[0] + embed = line_list[1:] + embed = torch.from_numpy(np.array([float(num) for num in embed])) + word_to_ix[word] = i+1 + weights.append(embed) + +weights = torch.stack(weights, 0).float() + +tag_to_ix = {"O": 0, "F": 1, "R": 2, START_TAG: 3, STOP_TAG: 4} + +model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM,weights) +model_path=os.path.dirname(os.path.abspath(__file__)) +model.load_state_dict(torch.load(os.path.join(model_path,'model/LSTMCRF.bin'))) + +def IP_model(word_list): + with torch.no_grad(): + precheck_sent = prepare_sequence(word_list, word_to_ix) + return model(precheck_sent)[1] + +if __name__=="__main__": + sent="okay , i like to do weight training and cycling ." 
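+    # Note: IP_model returns one interruption-point tag id per input token,
+    # following tag_to_ix above (0 = "O", 1 = "F", 2 = "R"); Speech_Disfluency
+    # turns tag 1 into a filler insertion and tag 2 into a word repeat.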
+    print(IP_model(sent.split()))
diff --git a/convlab2/laug/Speech_Disfluency/train.py b/convlab2/laug/Speech_Disfluency/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e304c7bef497afaf0654db0e309d4323b36d7b6e
--- /dev/null
+++ b/convlab2/laug/Speech_Disfluency/train.py
@@ -0,0 +1,91 @@
+import torch
+import torch.autograd as autograd
+import torch.nn as nn
+import torch.optim as optim
+import json
+from LSTMCRF import BiLSTM_CRF
+import numpy as np
+from progressbar import progressbar
+
+
+
+
+def prepare_sequence(seq, to_ix):
+    idxs=[]
+    for w in seq:
+        if w in to_ix:
+            idxs.append(to_ix[w])
+        else:
+            idxs.append(0)
+    return torch.tensor(idxs, dtype=torch.long)
+
+START_TAG = "<START>"
+STOP_TAG = "<STOP>"
+EMBEDDING_DIM = 100
+HIDDEN_DIM = 100
+
+# Load the Switchboard training data
+
+
+
+data=json.load(open('SWBD/data.json','r'))
+training_data=[]
+for d in data:
+    training_data.append((d['text'],d['tags']))
+print(len(training_data))
+# Put your dir to glove here
+glove_file=''
+
+word_to_ix={}
+max=20000
+ifs=open(glove_file, 'r')
+word_to_ix['<unk>'] = 0
+weights=[]
+weights.append(torch.from_numpy(np.array([0.]*100)))
+for i,line in enumerate(ifs.readlines()):
+    if i>=max:
+        break
+    line_list = line.split()
+    word = line_list[0]
+    embed = line_list[1:]
+    embed = torch.from_numpy(np.array([float(num) for num in embed]))
+    word_to_ix[word] = i+1
+    weights.append(embed)
+
+weights = torch.stack(weights, 0).float()
+
+
+tag_to_ix = {"O": 0, "F": 1, "R": 2, START_TAG: 3, STOP_TAG: 4}
+
+model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM, weights)
+
+optimizer = optim.Adam(model.parameters(), lr=0.001)
+
+
+with torch.no_grad():
+    precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)
+    precheck_tags = torch.tensor([tag_to_ix[t] for t in training_data[0][1]], dtype=torch.long)
+    print(model(precheck_sent))
+
+ep=0
+for epoch in range(30):
+    n,losses=0,0.
+    ep+=1
+    for sentence, tags in progressbar(training_data):
+        model.zero_grad()
+        sentence_in = prepare_sequence(sentence, word_to_ix)
+        targets = torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long)
+        loss = model.neg_log_likelihood(sentence_in, targets)
+        losses+=loss.item()
+        n+=1
+        loss.backward()
+        optimizer.step()
+    torch.save(model.state_dict(), 'model/LSTMCRF_'+str(ep)+'.bin')
+    print('loss:'+str(losses/n))
+    with torch.no_grad():
+        precheck_sent = prepare_sequence("okay , i like to do , weight training and cycling .".split(), word_to_ix)
+        print(model(precheck_sent))
+        precheck_sent = prepare_sequence(training_data[1][0], word_to_ix)
+        print(model(precheck_sent))
+        precheck_sent = prepare_sequence('i want to go to cambridge .'.split(), word_to_ix)
+        print(model(precheck_sent))
\ No newline at end of file
diff --git a/convlab2/laug/Speech_Recognition/ASR.py b/convlab2/laug/Speech_Recognition/ASR.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b9beddb75906ec82533de301c1f809428016cd9
--- /dev/null
+++ b/convlab2/laug/Speech_Recognition/ASR.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import, division, print_function
+import os
+import argparse
+import numpy as np
+import shlex
+import subprocess
+import sys
+import wave
+import json
+
+from deepspeech import Model, version
+from timeit import default_timer as timer
+
+try:
+    from shlex import quote
+except ImportError:
+    from pipes import quote
+
+
+def convert_samplerate(audio_path, desired_sample_rate):
+    sox_cmd = 'sox {} --type raw --bits 16 --channels 1 --rate {} --encoding signed-integer --endian little --compression 0.0 --no-dither - '.format(quote(audio_path), desired_sample_rate)
+    try:
+        output = subprocess.check_output(shlex.split(sox_cmd), stderr=subprocess.PIPE)
+    except subprocess.CalledProcessError as e:
+        raise RuntimeError('SoX returned non-zero status: {}'.format(e.stderr))
+    except OSError as e:
+        raise OSError(e.errno, 'SoX not found, use {}hz files or install it: {}'.format(desired_sample_rate, e.strerror))
+
+    return desired_sample_rate, np.frombuffer(output, np.int16)
+
+
+def metadata_to_string(metadata):
+    return ''.join(token.text for token in metadata.tokens)
+
+
+def words_from_candidate_transcript(metadata):
+    word = ""
+    word_list = []
+    word_start_time = 0
+    # Loop through each character
+    for i, token in enumerate(metadata.tokens):
+        # Append character to word if it's not a space
+        if token.text != " ":
+            if len(word) == 0:
+                # Log the start time of the new word
+                word_start_time = token.start_time
+
+            word = word + token.text
+        # Word boundary is either a space or the last character in the array
+        if token.text == " " or i == len(metadata.tokens) - 1:
+            word_duration = token.start_time - word_start_time
+
+            if word_duration < 0:
+                word_duration = 0
+
+            each_word = dict()
+            each_word["word"] = word
+            each_word["start_time"] = round(word_start_time, 4)
+            each_word["duration"] = round(word_duration, 4)
+
+            word_list.append(each_word)
+            # Reset
+            word = ""
+            word_start_time = 0
+
+    return word_list
+
+
+def metadata_json_output(metadata):
+    json_result = dict()
+    json_result["transcripts"] = [{
+        "confidence": transcript.confidence,
+        "words": words_from_candidate_transcript(transcript),
+    } for transcript in metadata.transcripts]
+    return json.dumps(json_result, indent=2)
+
+
+
+class VersionAction(argparse.Action):
+    def __init__(self, *args, **kwargs):
+        super(VersionAction, self).__init__(nargs=0, *args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        print('DeepSpeech ', version())
+        exit(0)
+
+
+class wav2text():
+    def __init__(self,):
+
+        print('Loading model from file {}'.format(args.model), file=sys.stderr)
+        model_load_start = timer()
+        # sphinx-doc: python_ref_model_start
+        model_path=os.path.dirname(os.path.abspath(__file__))
+
+        ds = Model(os.path.join(model_path,args.model))
+        # sphinx-doc: python_ref_model_stop
+        model_load_end = timer() - model_load_start
+        print('Loaded model in {:.3}s.'.format(model_load_end), file=sys.stderr)
+
+        if args.beam_width:
+            ds.setBeamWidth(args.beam_width)
+
+        self.desired_sample_rate = ds.sampleRate()
+
+
+
+        if args.scorer:
+            print('Loading scorer from files {}'.format(args.scorer), file=sys.stderr)
+            scorer_load_start = timer()
+            ds.enableExternalScorer(os.path.join(model_path,args.scorer))
+            scorer_load_end = timer() - scorer_load_start
+            print('Loaded scorer in {:.3}s.'.format(scorer_load_end), file=sys.stderr)
+
+            if args.lm_alpha and args.lm_beta:
+                ds.setScorerAlphaBeta(args.lm_alpha, args.lm_beta)
+
+        if args.hot_words:
+            print('Adding hot-words', file=sys.stderr)
+            for word_boost in args.hot_words.split(','):
+                word,boost = word_boost.split(':')
+                ds.addHotWord(word,float(boost))
+        self.ds=ds
+
+    def run(self,audio):
+        fin = wave.open(audio, 'rb')
+        fs_orig = fin.getframerate()
+        if fs_orig != self.desired_sample_rate:
+            print('Warning: original sample rate ({}) is different than {}hz. Resampling might produce erratic speech recognition.'.format(fs_orig, self.desired_sample_rate), file=sys.stderr)
+            fs_new, audio = convert_samplerate(audio, self.desired_sample_rate)
+        else:
+            audio = np.frombuffer(fin.readframes(fin.getnframes()), np.int16)
+
+        audio_length = fin.getnframes() * (1/fs_orig)
+        fin.close()
+
+        inference_start = timer()
+        # sphinx-doc: python_ref_inference_start
+        text=self.ds.stt(audio)
+        #print(text)
+        # sphinx-doc: python_ref_inference_stop
+        inference_end = timer() - inference_start
+        #print('Inference took %0.3fs for %0.3fs audio file.' % (inference_end, audio_length), file=sys.stderr)
+        return text
+
+
+
+parser = argparse.ArgumentParser(description='Running DeepSpeech inference.')
+parser.add_argument('--model', required=False, default='deepspeech-0.9.3-models.pbmm',
+                    help='Path to the model (protocol buffer binary file)')
+parser.add_argument('--scorer', required=False, default='deepspeech-0.9.3-models.scorer',
+                    help='Path to the external scorer file')
+parser.add_argument('--audio', required=False,
+                    help='Path to the audio file to run (WAV format)')
+parser.add_argument('--beam_width', type=int,
+                    help='Beam width for the CTC decoder')
+parser.add_argument('--lm_alpha', type=float,
+                    help='Language model weight (lm_alpha). If not specified, use default from the scorer package.')
+parser.add_argument('--lm_beta', type=float,
+                    help='Word insertion bonus (lm_beta). If not specified, use default from the scorer package.')
+parser.add_argument('--version', action=VersionAction,
+                    help='Print version and exits')
+parser.add_argument('--extended', required=False, action='store_true',
+                    help='Output string from extended metadata')
+parser.add_argument('--json', required=False, action='store_true',
+                    help='Output json from metadata with timestamp of each word')
+parser.add_argument('--candidate_transcripts', type=int, default=3,
+                    help='Number of candidate transcripts to include in JSON output')
+parser.add_argument('--hot_words', type=str,
+                    help='Hot-words and their boosts.')
+args = parser.parse_args()
+
+
diff --git a/convlab2/laug/Speech_Recognition/README.md b/convlab2/laug/Speech_Recognition/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..77a465277d38786af90ac0142f098c6092dca769
--- /dev/null
+++ b/convlab2/laug/Speech_Recognition/README.md
@@ -0,0 +1,18 @@
+# Speech Recognition
+
+A TTS+ASR pipeline to simulate speech characteristics and recognition errors.
+
+## TTS
+
+We use gTTS as the TTS module.
+Please install ffmpeg before use:
+```bash
+conda install ffmpeg
+```
+
+## ASR
+
+We use DeepSpeech as the ASR module. Note that we used DeepSpeech2 to conduct the experiments in our paper, but in this released toolkit we choose DeepSpeech instead for higher efficiency.
+
+Please download the [released models](https://github.com/mozilla/DeepSpeech/releases/tag/v0.9.3) before use.
+Please download deepspeech-0.9.3-models.pbmm and deepspeech-0.9.3-models.scorer and place them under the `Speech_Recognition/` dir.
diff --git a/convlab2/laug/Speech_Recognition/Speech_Recognition.py b/convlab2/laug/Speech_Recognition/Speech_Recognition.py
new file mode 100644
index 0000000000000000000000000000000000000000..15bf06e48a8aef506f6daeb2fdba975b95c0a98b
--- /dev/null
+++ b/convlab2/laug/Speech_Recognition/Speech_Recognition.py
@@ -0,0 +1,37 @@
+#coding: UTF-8
+from convlab2.laug.Speech_Recognition.ASR import wav2text
+from convlab2.laug.Speech_Recognition.TTS import text2wav
+from convlab2.laug.Speech_Recognition.multiwoz.span_detection import span_detect
+
+import os
+import time
+
+class Speech_Recognition:
+    def __init__(self,dataset='multiwoz',temp_file='temp',tld='com'):
+
+        self.wav2text = wav2text()
+        self.temp_file = temp_file
+        self.tld = tld
+    def aug(self,text,span_info):
+        ok=0
+        while ok==0:
+            try:
+                text2wav(text,tld=self.tld,filename=self.temp_file)
+            except ValueError:
+                ok=0
+                print("gTTS error occurred!")
+            else:
+                ok=1
+        new_text = self.wav2text.run(self.temp_file+".wav")
+        new_span_info=[]
+        for span in span_info:
+            new_span_info.append(span_detect(text,new_text,span))
+        return new_text,new_span_info
+
+if __name__=="__main__":
+    text = "I want a train to Cambridge"
+    span_info = [["Train-Inform","Dest","Cambridge",5,5]]
+    SR = Speech_Recognition()
+    new_text,new_span_info = SR.aug(text,span_info)
+    print(new_text)
+    print(new_span_info)
diff --git a/convlab2/laug/Speech_Recognition/TTS.py b/convlab2/laug/Speech_Recognition/TTS.py
new file mode 100644
index 0000000000000000000000000000000000000000..b570076124f2743023b7174aa4fb22221b3579f4
--- /dev/null
+++ b/convlab2/laug/Speech_Recognition/TTS.py
@@ -0,0 +1,10 @@
+#coding: UTF-8
+from gtts import gTTS
+from pydub.audio_segment import AudioSegment
+import os
+
+
+def text2wav(text,language='en',filename='temp',tld='cn'):
+    gTTS(text=text, tld=tld,lang=language).save(filename+".mp3")
+    AudioSegment.from_mp3(filename+".mp3").set_frame_rate(16000).export(filename+".wav",
format="wav") + diff --git a/convlab2/laug/Speech_Recognition/__init__.py b/convlab2/laug/Speech_Recognition/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c68785e9d0b64e1fe46403c4316a9fe1ea36eeb --- /dev/null +++ b/convlab2/laug/Speech_Recognition/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- \ No newline at end of file diff --git a/convlab2/laug/Speech_Recognition/multiwoz/detection_utils.py b/convlab2/laug/Speech_Recognition/multiwoz/detection_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5151dcb617315df07abbd47885397daea54fa8 --- /dev/null +++ b/convlab2/laug/Speech_Recognition/multiwoz/detection_utils.py @@ -0,0 +1,78 @@ +import locale; + +NUMBER_CONSTANT = {0:"zero ", 1:"one", 2:"two", 3:"three", 4:"four", 5:"five", 6:"six", 7:"seven", + 8:"eight", 9:"nine", 10:"ten", 11:"eleven", 12:"twelve", 13:"thirteen", + 14:"fourteen", 15:"fifteen", 16:"sixteen", 17:"seventeen", 18:"eighteen", 19:"nineteen" }; +IN_HUNDRED_CONSTANT = {2:"twenty", 3:"thirty", 4:"forty", 5:"fifty", 6:"sixty", 7:"seventy", 8:"eighty", 9:"ninety"} +BASE_CONSTANT = {0:" ", 1:"hundred", 2:"thousand", 3:"million", 4:"billion"}; + +#supported number range is 1-n billion; +def translateNumberToEnglish(number): + if str(number).isnumeric(): + if str(number)[0] == '0' and len(str(number)) > 1: + return translateNumberToEnglish(int(number[1:])); + if int(number) < 20: + return NUMBER_CONSTANT[int(number)]; + elif int(number) < 100: + if str(number)[1] == '0': + return IN_HUNDRED_CONSTANT[int(str(number)[0])]; + else: + return IN_HUNDRED_CONSTANT[int(str(number)[0])] + " " + NUMBER_CONSTANT[int(str(number)[1])]; + else: + #locale.setlocale(locale.LC_ALL, "English_United States.1252"); + #strNumber = locale.format("%d" , number, grouping=True); + strNumber=str(number) + numberArray = str(strNumber).split(","); + stringResult = ""; + groupCount = len(numberArray) + 1; + for groupNumber in numberArray: + if groupCount > 1 and groupNumber[0:] != "000": + stringResult += str(getUnderThreeNumberString(str(groupNumber))) + " "; + else: + break; + groupCount -= 1; + if groupCount > 1: + stringResult += BASE_CONSTANT[groupCount] + " "; + endPoint = len(stringResult) - len(" hundred,"); + #return stringResult[0:endPoint]; + return stringResult; + + else: + print("please input a number!"); + +#between 0-999 +def getUnderThreeNumberString(number): + if str(number).isnumeric() and len(number) < 4: + if len(number) < 3: + return translateNumberToEnglish(int(number)); + elif len(number) == 3 and number[0:] == "000": + return " "; + elif len(number) == 3 and number[1:] == "00": + return NUMBER_CONSTANT[int(number[0])] + " " + BASE_CONSTANT[1]; + else: + return NUMBER_CONSTANT[int(number[0])] + " " + BASE_CONSTANT[1] + " and " + translateNumberToEnglish((number[1:])); + +def translateTimeToEnglish(t): + t=t.split(':') + if t[1]!='00': + return translateNumberToEnglish(t[0])+' '+translateNumberToEnglish(t[1]) + else: + return translateNumberToEnglish(t[0])+' '+'o\'clock' + +def span_typer(s): + if s.isnumeric(): + return "number" + if s.find(':')>=0: + s=s.split(':') + if len(s)==2: + if s[0].isnumeric() and s[1].isnumeric(): + return "time" + return "none" + +def replacer(s): + s=s.replace(' n\'t','n\'t') + s=s.replace(' \'ll','\'ll') + s=s.replace('centre','center') + s=s.replace('-star',' star') + s=s.replace('guesthouse','guest house') + return s \ No newline at end of file diff --git a/convlab2/laug/Speech_Recognition/multiwoz/paraphrase_span_detection.py 
b/convlab2/laug/Speech_Recognition/multiwoz/paraphrase_span_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..af76bc58e771409b58fcb41f5ec665d6744c053e --- /dev/null +++ b/convlab2/laug/Speech_Recognition/multiwoz/paraphrase_span_detection.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Aug 11 17:49:53 2020 + +@author: truthless +""" + +import spacy +from fuzzywuzzy import fuzz + +digit2word = { + '0': 'zero', '1': 'one', '2': 'two', '3': 'three', '4': 'four', '5': 'five', + '6': 'six', '7': 'seven', '8': 'eight', '9': 'nine', '10': 'ten', '11': 'eleven', + '12': 'twelve' +} +word2digit = {v:k for k,v in digit2word.items()} + +#nlp = spacy.load('en_core_web_sm') +threshold = 55 + +def digit_normalize(utt_list): + for i, text in enumerate(utt_list): + if text in word2digit: + utt_list[i] = word2digit[text] + return utt_list + +def phrase_idx_utt(value_list, utt_list): + utt_list = digit_normalize(utt_list) + candidates = [] + l = len(value_list) + for i in [l, l-1, l+1]: + if i == 0: + continue + for j in range(len(utt_list)-i+1): + score = fuzz.ratio(' '.join(utt_list[j:j+i]), ' '.join(value_list)) + if score > threshold: + candidates.append((score, j, j+i-1)) + return sorted(candidates, key=lambda x:x[0], reverse=True)[0][1:] if candidates else None + +def preprocess(utt, da): + ''' + utt: str + da: dict {'domain-intent': [slot, value]} + ''' + with nlp.disable_pipes('tagger', 'parser'): + tokens = [token.text for token in nlp(utt)] + labels = dict() + for key, pair in da.items(): + tags = ["O"] * len(tokens) + slots = [] + labels[key] = {'tags':tags, 'slots':slots} + for slot, value in pair: + intent = key.split('-')[1].lower() + if intent in ["request"]: + slots.append(slot) + elif intent in ['inform']: + value_tokens = [token.text for token in nlp(value)] + span = phrase_idx_utt(value_tokens, tokens) + if span is not None: + if slot.lower() in ['name', 'dest', 'depart']: + tokens[span[0]:span[1]+1] = value_tokens + tags[span[0]:span[1]+1] = ["O"] * len(value_tokens) + tags[span[0]] = "B-" + slot + for i in range(span[0]+1, span[0]+len(value_tokens)): + tags[i] = "I-" + slot + else: + #tags[span[0]] = "B-" + da[1] + '-' + da[0] + "+" + da[2] + tags[span[0]] = "B-" + slot + for i in range(span[0]+1, span[1]+1): + #tags[i] = "I-" + da[1] + '-' + da[0] + "+" + da[2] + tags[i] = "I-" + slot + return tokens, labels diff --git a/convlab2/laug/Speech_Recognition/multiwoz/span_detection.py b/convlab2/laug/Speech_Recognition/multiwoz/span_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..c765e96ae758ca129418965606797f287cc1ad2c --- /dev/null +++ b/convlab2/laug/Speech_Recognition/multiwoz/span_detection.py @@ -0,0 +1,85 @@ +from .detection_utils import translateNumberToEnglish,translateTimeToEnglish,span_typer,replacer +import json +from .paraphrase_span_detection import phrase_idx_utt + +def span_detect(original_text,new_text,span_list): +#input:original_text,new_text,one span_info [slot,slot,span,start,end] +#output:is_span_found? , is_span_changed? 
, new span_info [slot,slot,new span,new start,new end] + span=span_list[2].lower() + span=replacer(span) + span_type=span_typer(span) + new_words=new_text.split() + if span_type=="time": + span2=translateTimeToEnglish(span) + if span_type=="number": + span2=translateNumberToEnglish(span) + if span_type=="none": + span2=span + span_changed,span_found=0,0 + if new_text.find(span)>=0: + span_changed,span_found=0,1 + span_start=new_text.count(' ',0,new_text.find(span)) + span_end=span_start+len(span.split())-1 + new_span_list=[span_list[0],span_list[1],' '.join(new_words[span_start:span_end+1]),span_start,span_end] + elif new_text.find(span2)>=0: + span_changed,span_found=1,1 + span=span2 + span_start=new_text.count(' ',0,new_text.find(span)) + span_end=span_start+len(span.split())-1 + new_span_list=[span_list[0],span_list[1],' '.join(new_words[span_start:span_end+1]),span_start,span_end] + else: + span=span2 + span_words=span.split() + + result=phrase_idx_utt(span_words,new_words) + if result is not None: + max_start,max_end=result + span_changed,span_found=1,1 + new_span_list=[span_list[0],span_list[1],' '.join(new_words[max_start:max_end+1]),max_start,max_end] + else: + origin_split=original_text.split() + new_split=new_words + ok=0 + origin_start=span_list[3]-1 + if origin_start>=0: + if origin_start-1>=0 and origin_split[origin_start] in ['.',',','?']: + origin_start-=1 + start_word=origin_split[origin_start] + for start in range(len(new_split)): + if new_split[start]==start_word: + break + start+=1 + else: + start=0 + if span_list[4]+1<len(origin_split) and start<len(new_split): + end_word=origin_split[span_list[4]+1] + if end_word not in ['.',',','?']: + if span_list[4]+1<len(origin_split): + end_word=origin_split[span_list[4]+1] + for end in range(start,len(new_split)): + if new_split[end]==end_word: + ok=1 + break + end-=1 + + else: + if span_list[4]+2<len(origin_split): + end_word=origin_split[span_list[4]+2] + for end in range(start,len(new_split)): + if new_split[end]==end_word: + ok=1 + break + end-=1 + else: + ok=1 + end=len(new_split)-1 + else: + ok=1 + end=len(new_split)-1 + if start<=end and ok==1: + span_changed,span_found=1,1 + new_span_list=[span_list[0],span_list[1],' '.join(new_words[start:end+1]),start,end] + + if span_found==0: + new_span_list=[span_list[0],span_list[1],span_list[2],0,0] + return new_span_list diff --git a/convlab2/laug/Text_Paraphrasing/README.md b/convlab2/laug/Text_Paraphrasing/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0e5682b30a58440ecfba5bf28197651237c255e0 --- /dev/null +++ b/convlab2/laug/Text_Paraphrasing/README.md @@ -0,0 +1,3 @@ +# Text Paraphrasing + +We applied SC-GPT to paraphrase the sentences. Code of SC-GPT is under `LAUG/nlg/` dir. 
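As background for the span re-location used above and in `Text_Paraphrasing` below, here is a self-contained sketch of the fuzzy window matching that `phrase_idx_utt` performs (simplified; the function name, sample sentence, and misspelling are illustrative, and the real module also normalizes digits):

```python
from fuzzywuzzy import fuzz

def find_span(value_words, utt_words, threshold=55):
    # Slide windows of length n-1, n, n+1 over the utterance and keep the
    # best fuzzy match against the slot value (cf. phrase_idx_utt).
    best = None
    n = len(value_words)
    for size in (n, n - 1, n + 1):
        if size == 0:
            continue
        for j in range(len(utt_words) - size + 1):
            score = fuzz.ratio(" ".join(utt_words[j:j + size]), " ".join(value_words))
            if score > threshold and (best is None or score > best[0]):
                best = (score, j, j + size - 1)
    return best[1:] if best else None

print(find_span("cambridge".split(), "i need a ride to cambrige please".split()))
# -> (5, 5): the misrecognized value is still located
```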
diff --git a/convlab2/laug/Text_Paraphrasing/Text_Paraphrasing.py b/convlab2/laug/Text_Paraphrasing/Text_Paraphrasing.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cd15c8d873d5c6577bba0517ac8d9302a4296dd
--- /dev/null
+++ b/convlab2/laug/Text_Paraphrasing/Text_Paraphrasing.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+from convlab2.nlg.scgpt.multiwoz.scgpt import SCGPT
+from convlab2.laug.Text_Paraphrasing.utils import span2tuple,paraphrase_span_detection
+class Text_Paraphrasing:
+    def __init__(self,dataset='multiwoz'):
+        if dataset=='multiwoz':
+            self.model=SCGPT()
+        if dataset=='frames':
+            self.model=SCGPT(model_file='https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-frames.zip')
+        self.model.init_session()
+    def aug(self,text,span_info):
+        t=span2tuple(span_info)
+        new_text = self.model.generate(t)
+        new_span_info = paraphrase_span_detection(new_text,span_info)
+        return new_text, new_span_info
+
+
+if __name__=="__main__":
+    text = "I want a train to Cambridge"
+    span_info = [["Train-Inform","Dest","Cambridge",5,5]]
+    TP = Text_Paraphrasing()
+    new_text,new_span_info = TP.aug(text,span_info)
+    print(new_text)
+    print(new_span_info)
diff --git a/convlab2/laug/Text_Paraphrasing/__init__.py b/convlab2/laug/Text_Paraphrasing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c68785e9d0b64e1fe46403c4316a9fe1ea36eeb
--- /dev/null
+++ b/convlab2/laug/Text_Paraphrasing/__init__.py
@@ -0,0 +1 @@
+# -*- coding: utf-8 -*-
\ No newline at end of file
diff --git a/convlab2/laug/Text_Paraphrasing/utils.py b/convlab2/laug/Text_Paraphrasing/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7c3d0446a0cec3bf9bc06e8c5e5b64e4f49cbff
--- /dev/null
+++ b/convlab2/laug/Text_Paraphrasing/utils.py
@@ -0,0 +1,20 @@
+# -*- coding: utf-8 -*-
+from convlab2.util.multiwoz.paraphrase_span_detection import phrase_idx_utt
+
+def paraphrase_span_detection(new_text,span_info):
+    new_words=new_text.split()
+    new_span_info=[]
+    for span in span_info:
+        span_words=span[2].split()
+        result=phrase_idx_utt(span_words,new_words)
+        if result is not None:
+            max_start,max_end=result
+            new_span_info.append([span[0],span[1],' '.join(new_words[max_start:max_end+1]),max_start,max_end])
+    return new_span_info
+
+
+def span2tuple(span_info):
+    t=[]
+    for span in span_info:
+        t.append((span[0].split('-')[1],span[0].split('-')[0],span[1],span[2]))
+    return t
\ No newline at end of file
diff --git a/convlab2/laug/Word_Perturbation/README.md b/convlab2/laug/Word_Perturbation/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bc7f2dd12f15b35266ff58e10476a4a2e3210ae3
--- /dev/null
+++ b/convlab2/laug/Word_Perturbation/README.md
@@ -0,0 +1,28 @@
+## EDA
+Query the database and randomly replace some dialog acts.
+
+Randomly replace, delete, swap or insert some words in the text.
+
+## Requirements
+It requires `nltk`.
+```shell script
+pip install nltk
+```
+
+And you must download `wordnet` first.
+
+```python
+import nltk
+nltk.download('wordnet')
+```
+
+Databases for sgd and frames are available at [Link](http://115.182.62.174:9876/). Please place the ```db/``` folder under the ```Word_Perturbation/``` dir.
+
+
+## Run
+```shell script
+python run.py --multiwoz MULTIWOZ_FILEPATH --output AUGMENTED_MULTIWOZ_FILEPATH
+```
+
+Run ```python run.py --help``` for more information about arguments.
+
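As a concrete illustration of the "replace" operation listed above, here is a minimal WordNet synonym-replacement sketch (my own simplification; the toolkit's `MultiwozEDA` additionally protects slot values and combines all four operations):

```python
import random
from nltk.corpus import wordnet  # requires nltk.download('wordnet')

def synonym_replace(words, n=1):
    # Replace up to n words that have WordNet synonyms with a random synonym.
    words = words[:]
    candidates = [i for i, w in enumerate(words) if wordnet.synsets(w)]
    random.shuffle(candidates)
    for i in candidates[:n]:
        synonyms = {lemma.name().replace("_", " ")
                    for synset in wordnet.synsets(words[i])
                    for lemma in synset.lemmas()}
        synonyms.discard(words[i])
        if synonyms:
            words[i] = random.choice(sorted(synonyms))
    return words

print(" ".join(synonym_replace("i want a cheap hotel in the north".split())))
```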
diff --git a/convlab2/laug/Word_Perturbation/Word_Perturbation.py b/convlab2/laug/Word_Perturbation/Word_Perturbation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a5947cfa7febb7d23802eee237358966ffbc05e
--- /dev/null
+++ b/convlab2/laug/Word_Perturbation/Word_Perturbation.py
@@ -0,0 +1,24 @@
+#coding: UTF-8
+from convlab2.laug.Word_Perturbation.multiwoz.multiwoz_eda import MultiwozEDA
+from convlab2.laug.Word_Perturbation.multiwoz.aug_with_sgd_db import multiwoz_eda_config
+from convlab2.laug.Word_Perturbation.frames.aug_with_sgd_db import frames_eda_config
+class Word_Perturbation:
+    def __init__(self,dataset='multiwoz'):
+        self.dataset=dataset
+        if dataset=='multiwoz':
+            multiwoz_config=multiwoz_eda_config()
+            self.EDA=MultiwozEDA(multiwoz_config.multiwoz,multiwoz_config.db_loader)
+        elif dataset=='frames':
+            frames_config=frames_eda_config()
+            self.EDA=MultiwozEDA(frames_config.frames,frames_config.db_loader)
+    def aug(self,text,span_info):
+        (new_text,new_span_info,_),_=self.EDA.augment_sentence_only(text, span_info, {})
+        return new_text,new_span_info
+
+if __name__=="__main__":
+    text = "I want a train to Cambridge"
+    span_info = [["Train-Inform","Dest","Cambridge",5,5]]
+    WP = Word_Perturbation()
+    new_text,new_span_info = WP.aug(text,span_info)
+    print(new_text)
+    print(new_span_info)
diff --git a/convlab2/laug/Word_Perturbation/__init__.py b/convlab2/laug/Word_Perturbation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/convlab2/laug/Word_Perturbation/__init__.py
@@ -0,0 +1 @@
+
diff --git a/convlab2/laug/Word_Perturbation/frames/__init__.py b/convlab2/laug/Word_Perturbation/frames/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c68785e9d0b64e1fe46403c4316a9fe1ea36eeb
--- /dev/null
+++ b/convlab2/laug/Word_Perturbation/frames/__init__.py
@@ -0,0 +1 @@
+# -*- coding: utf-8 -*-
\ No newline at end of file
diff --git a/convlab2/laug/Word_Perturbation/frames/aug_with_sgd_db.py b/convlab2/laug/Word_Perturbation/frames/aug_with_sgd_db.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc33ec3c4d98f8897eceaaa33bb965434648dfc2
--- /dev/null
+++ b/convlab2/laug/Word_Perturbation/frames/aug_with_sgd_db.py
@@ -0,0 +1,99 @@
+import os
+import json
+import zipfile
+from convlab2.laug.Word_Perturbation.multiwoz.multiwoz_eda import MultiwozEDA
+from convlab2.laug.Word_Perturbation.multiwoz.db.slot_value_replace import MultiSourceDBLoader, MultiSourceDBLoaderArgs
+from convlab2.laug.Word_Perturbation.multiwoz.util import load_json, dump_json
+from convlab2 import DATA_ROOT,get_root_path
+
+def read_zipped_json(filepath, filename):
+    print("zip file path = ", filepath)
+    archive = zipfile.ZipFile(filepath, 'r')
+    return json.load(archive.open(filename))
+
+
+class frames_eda_config:
+    def __init__(self,):
+        self.frames=read_zipped_json(os.path.join(DATA_ROOT, 'frames/Ori','train.json.zip'),'train.json')
+
+        frames_frames_domain_slot_map = {
+            # ('frame', 'category'): ('hotel', 'category'),
+            ('frame', 'dst_city'): ('hotel', 'location'),
+            # ('frame', 'gst_rating'): ('hotel', 'gst_rating'),
+            ('frame', 'name'): ('hotel', 'name'),
+
+            ('frame', 'or_city'): ('trip', 'or_city'),
+            # ('frame', 'seat'): ('trip', 'seat'),
+        }
+
+        frames_sgd_domain_slot_map = {
+            ('frame', 'dst_city'): ('hotels', 'dst_city'),
+            ('frame', 'name'): ('hotels', 'hotel_name'),
+
+            ('frame', 'or_city'): ('travel', 'location'),
+        }
+        
frames_db_dir=os.path.join(get_root_path(),"convlab2/laug/Word_Perturbation/db/frames-db/") + sgd_db_dir=os.path.join(get_root_path(),"convlab2/laug/Word_Perturbation/db/sgd-db/") + + loader_args = [ + MultiSourceDBLoaderArgs(frames_db_dir, frames_frames_domain_slot_map), + MultiSourceDBLoaderArgs(sgd_db_dir, frames_sgd_domain_slot_map) + ] + self.db_loader = MultiSourceDBLoader(loader_args) + + + + +def main(frames_filepath, output_filepath, + frames_db_dir, + sgd_db_dir, + alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=2, + p_slot_value_replacement=0.25): + frames = load_json(frames_filepath) + + frames_frames_domain_slot_map = { + # ('frame', 'category'): ('hotel', 'category'), + ('frame', 'dst_city'): ('hotel', 'location'), + # ('frame', 'gst_rating'): ('hotel', 'gst_rating'), + ('frame', 'name'): ('hotel', 'name'), + + ('frame', 'or_city'): ('trip', 'or_city'), + # ('frame', 'seat'): ('trip', 'seat'), + } + + frames_sgd_domain_slot_map = { + ('frame', 'dst_city'): ('hotels', 'dst_city'), + ('frame', 'name'): ('hotels', 'hotel_name'), + + ('frame', 'or_city'): ('travel', 'location'), + } + loader_args = [ + MultiSourceDBLoaderArgs(frames_db_dir, frames_frames_domain_slot_map), + MultiSourceDBLoaderArgs(sgd_db_dir, frames_sgd_domain_slot_map) + ] + db_loader = MultiSourceDBLoader(loader_args) + + eda = MultiwozEDA(frames, db_loader, + inform_intents=('inform', 'switch_frame', 'confirm'), + slot_value_replacement_probability=p_slot_value_replacement, + alpha_sr=alpha_sr, alpha_ri=alpha_ri, alpha_rs=alpha_rs, p_rd=p_rd, num_aug=num_aug) + result = eda.augment_multiwoz_dataset('usr') + + dump_json(result, output_filepath, indent=4) + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--frames_filepath", default='multiwoz.json') + parser.add_argument('--output_filepath', '--output', '-o', default='augmented_multiwoz.json') + parser.add_argument('--alpha_sr', type=float, default=0.1, help='probability of replacement') + parser.add_argument('--alpha_ri', type=float, default=0.1, help='probability of insertion') + parser.add_argument('--alpha_rs', type=float, default=0.1, help='probability of swap') + parser.add_argument('--p_rd', type=float, default=0.1, help="probability of deletion") + parser.add_argument('--num_aug', type=int, default=2, + help="generate `num_aug` candidates with EDA and randomly choose one dialog as augmented dialog.") + parser.add_argument('--p_slot_value_replacement', '-p_svr', type=float, default=0.25, + help='probability to replace a slot value.') + parser.add_argument('--sgd_db_dir', '--sgd', help='dir of sgd db.') + parser.add_argument('--frames_db_dir', help='dir of frames db') + opts = parser.parse_args() + main(**vars(opts)) diff --git a/convlab2/laug/Word_Perturbation/multiwoz/__init__.py b/convlab2/laug/Word_Perturbation/multiwoz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c68785e9d0b64e1fe46403c4316a9fe1ea36eeb --- /dev/null +++ b/convlab2/laug/Word_Perturbation/multiwoz/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- \ No newline at end of file diff --git a/convlab2/laug/Word_Perturbation/multiwoz/aug_with_sgd_db.py b/convlab2/laug/Word_Perturbation/multiwoz/aug_with_sgd_db.py new file mode 100644 index 0000000000000000000000000000000000000000..b008a95ffbbd5f39c59ad9a771934ec40e622b1c --- /dev/null +++ b/convlab2/laug/Word_Perturbation/multiwoz/aug_with_sgd_db.py @@ -0,0 +1,121 @@ +import os +from 
convlab2.laug.Word_Perturbation.multiwoz.multiwoz_eda import MultiwozEDA +from convlab2.laug.Word_Perturbation.multiwoz.db.slot_value_replace import MultiSourceDBLoader, MultiSourceDBLoaderArgs +from convlab2.laug.Word_Perturbation.multiwoz.util import load_json, dump_json +from convlab2 import DATA_ROOT,get_root_path +import json +import zipfile +def read_zipped_json(filepath, filename): + print("zip file path = ", filepath) + archive = zipfile.ZipFile(filepath, 'r') + return json.load(archive.open(filename)) + + +class multiwoz_eda_config: + def __init__(self,): + self.multiwoz=read_zipped_json(os.path.join(DATA_ROOT, 'multiwoz','train.json.zip'),'train.json') + multiwoz_db_dir = os.path.join(DATA_ROOT, 'multiwoz', 'db') + multiwoz_multiwoz_domain_slot_map = { + ('attraction', 'area'): ('attraction', 'Area'), + ('attraction', 'type'): ('attraction', 'Type'), + ('attraction', 'name'): ('attraction', 'Name'), + ('attraction', 'address'): ('attraction', 'Addr'), + ('hospital', 'department'): ('hospital', 'Department'), + ('hospital', 'address'): ('hospital', 'Addr'), + ('hotel', 'type'): ('hotel', 'Type'), + ('hotel', 'area'): ('hotel', 'Area'), + ('hotel', 'name'): ('hotel', 'Name'), + ('hotel', 'address'): ('hotel', 'Addr'), + ('restaurant', 'food'): ('restaurant', 'Food'), + ('restaurant', 'area'): ('restaurant', 'Area'), + ('restaurant', 'name'): ('restaurant', 'Name'), + ('restaurant', 'address'): ('restaurant', 'Addr'), + ('train', 'destination'): ('train', 'Dest'), + ('train', 'departure'): ('train', 'Depart') + } + + multiwoz_sgd_domain_slot_map = { + ('train', 'dest'): ('train', 'to'), + ('train', 'depart'): ('train', 'from'), + ('hotel', 'name'): ('hotels', 'hotel_name'), + ('hotel', 'addr'): ('hotels', 'address'), + ('attraction', 'name'): ('travel', 'attraction_name'), + ('restaurant', 'name'): ('restaurants', 'restaurant_name'), + ('restaurant', 'addr'): ('restaurants', 'street_address') + } + loader_args = [MultiSourceDBLoaderArgs(multiwoz_db_dir, multiwoz_multiwoz_domain_slot_map)] + sgd_db_dir=os.path.join(get_root_path(),"convlab2/laug/Word_Perturbation/db/sgd-db/") + loader_args.append(MultiSourceDBLoaderArgs( + sgd_db_dir, + multiwoz_sgd_domain_slot_map + )) + self.db_loader = MultiSourceDBLoader(loader_args) + +def main(multiwoz_filepath, output_filepath, + sgd_db_dir=None, + alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=2, + p_slot_value_replacement=0.25): + multiwoz = load_json(multiwoz_filepath) + + multiwoz_db_dir = os.path.join(DATA_ROOT, 'multiwoz', 'db') + multiwoz_multiwoz_domain_slot_map = { + ('attraction', 'area'): ('attraction', 'Area'), + ('attraction', 'type'): ('attraction', 'Type'), + ('attraction', 'name'): ('attraction', 'Name'), + ('attraction', 'address'): ('attraction', 'Addr'), + ('hospital', 'department'): ('hospital', 'Department'), + ('hospital', 'address'): ('hospital', 'Addr'), + ('hotel', 'type'): ('hotel', 'Type'), + ('hotel', 'area'): ('hotel', 'Area'), + ('hotel', 'name'): ('hotel', 'Name'), + ('hotel', 'address'): ('hotel', 'Addr'), + ('restaurant', 'food'): ('restaurant', 'Food'), + ('restaurant', 'area'): ('restaurant', 'Area'), + ('restaurant', 'name'): ('restaurant', 'Name'), + ('restaurant', 'address'): ('restaurant', 'Addr'), + ('train', 'destination'): ('train', 'Dest'), + ('train', 'departure'): ('train', 'Depart') + } + + multiwoz_sgd_domain_slot_map = { + ('train', 'dest'): ('train', 'to'), + ('train', 'depart'): ('train', 'from'), + ('hotel', 'name'): ('hotels', 'hotel_name'), + ('hotel', 'addr'): ('hotels', 
'address'), + ('attraction', 'name'): ('travel', 'attraction_name'), + ('restaurant', 'name'): ('restaurants', 'restaurant_name'), + ('restaurant', 'addr'): ('restaurants', 'street_address') + } + loader_args = [MultiSourceDBLoaderArgs(multiwoz_db_dir, multiwoz_multiwoz_domain_slot_map)] + assert sgd_db_dir is not None + loader_args.append(MultiSourceDBLoaderArgs( + sgd_db_dir, + multiwoz_sgd_domain_slot_map + )) + db_loader = MultiSourceDBLoader(loader_args) + + eda = MultiwozEDA(multiwoz, db_loader, + slot_value_replacement_probability=p_slot_value_replacement, + alpha_sr=alpha_sr, alpha_ri=alpha_ri, alpha_rs=alpha_rs, p_rd=p_rd, num_aug=num_aug) + result = eda.augment_multiwoz_dataset('usr') + + dump_json(result, output_filepath, indent=4) + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--multiwoz_filepath", '--multiwoz', default='multiwoz.json') + parser.add_argument('--output_filepath', '--output', '-o', default='augmented_multiwoz.json') + parser.add_argument('--alpha_sr', type=float, default=0.1, help='probability of replacement') + parser.add_argument('--alpha_ri', type=float, default=0.1, help='probability of insertion') + parser.add_argument('--alpha_rs', type=float, default=0.1, help='probability of swap') + parser.add_argument('--p_rd', type=float, default=0.1, help="probability of deletion") + parser.add_argument('--num_aug', type=int, default=2, + help="generate `num_aug` candidates with EDA and randomly choose one dialog as augmented dialog.") + parser.add_argument('--p_slot_value_replacement', '-p_svr', type=float, default=0.25, + help='probability to replace a slot value.') + parser.add_argument('--sgd_db_dir', '--sgd', help='dir of sgd db.') + opts = parser.parse_args() + main(**vars(opts)) diff --git a/convlab2/laug/Word_Perturbation/multiwoz/db/__init__.py b/convlab2/laug/Word_Perturbation/multiwoz/db/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c68785e9d0b64e1fe46403c4316a9fe1ea36eeb --- /dev/null +++ b/convlab2/laug/Word_Perturbation/multiwoz/db/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- \ No newline at end of file diff --git a/convlab2/laug/Word_Perturbation/multiwoz/db/db.py b/convlab2/laug/Word_Perturbation/multiwoz/db/db.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ad50843b00f145fc99afdd35dfcf7abcd82784 --- /dev/null +++ b/convlab2/laug/Word_Perturbation/multiwoz/db/db.py @@ -0,0 +1,31 @@ +from typing import Union, Callable, List, Optional +from ..util import choice + + +class BaseDB: + ... + + +class DB(list, BaseDB): + """ + DB is a list of dicts. 
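+
+    Illustrative usage (hypothetical data), based on query() and sample() below:
+
+        db = DB([{"name": "alpha"}, {"name": "beta"}])
+        db.query({"name": "alpha"})             # -> [{"name": "alpha"}]
+        db.sample(lambda item: "name" in item)  # a random matching row, or None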
+ """ + + def query(self, conditions: Union[dict, Callable[[dict], bool], None]) -> List[dict]: + if conditions is None: + return self + assert callable(conditions) or isinstance(conditions, dict) + if isinstance(conditions, dict): + fn = lambda item: all(item[k] == conditions[k] for k in conditions if k in item) + else: + fn = conditions + return [item for item in self if fn(item)] + + def sample(self, conditions=None) -> Optional[dict]: + list_ = self.query(conditions) + if list_: + try: + return choice(list_) + except (IndexError, ValueError): + pass + return None diff --git a/convlab2/laug/Word_Perturbation/multiwoz/db/db_loader.py b/convlab2/laug/Word_Perturbation/multiwoz/db/db_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..73d755aa1a69ef47e5863dbab5a6570f375238d5 --- /dev/null +++ b/convlab2/laug/Word_Perturbation/multiwoz/db/db_loader.py @@ -0,0 +1,52 @@ +import re +import os +from functools import lru_cache +from typing import Optional +from abc import ABC +from .db import DB, BaseDB +from ..util import load_json + + +def list_db_filename(db_dir): + filenames = os.listdir(db_dir) + db_filenames = {} + for filename in filenames: + match = re.match(r'^(\w+)_db\.json$', filename) + if match is not None and os.path.isfile(os.path.join(db_dir, filename)): + domain = match.group(1) + db_filenames[domain] = filename + return db_filenames + + +def list_db_filepath(db_dir): + return {domain: os.path.join(db_dir, filename) for domain, filename in list_db_filename(db_dir).items()} + + +class BaseDBLoader(ABC): + def load_db(self, domain: str, slot: Optional[str] = None) -> Optional[BaseDB]: + """given a domain and a slot, load corresponding db.""" + raise NotImplementedError + + +class DBLoader(BaseDBLoader): + def __init__(self, db_dir): + assert db_dir and os.path.isdir(db_dir) + self.db_dir = db_dir + self.db_files = list_db_filepath(db_dir) # domain --> filepath + self.db_cache = {} # domain --> List[dict] + + @lru_cache(maxsize=25) + def _get_db_file(self, domain): + if domain in self.db_files: + return self.db_files[domain] + for dom, filename in self.db_files.items(): + if domain.lower() in dom.lower(): + return filename + + def load_db(self, domain: str, slot: Optional[str] = None) -> Optional[DB]: + filepath = self._get_db_file(domain) + if filepath is None: + return None + if domain not in self.db_cache: + self.db_cache[domain] = DB(load_json(filepath)) + return self.db_cache[domain] diff --git a/convlab2/laug/Word_Perturbation/multiwoz/db/slot_value_replace.py b/convlab2/laug/Word_Perturbation/multiwoz/db/slot_value_replace.py new file mode 100644 index 0000000000000000000000000000000000000000..32f0a259a530003a286584b2cf535ad07193835a --- /dev/null +++ b/convlab2/laug/Word_Perturbation/multiwoz/db/slot_value_replace.py @@ -0,0 +1,291 @@ +import random +import re +from typing import List, Dict, Optional, Union, Iterable +from copy import deepcopy +from collections import defaultdict, namedtuple + +from .db_loader import BaseDBLoader, DBLoader +from .db import BaseDB, DB, choice +from ..tokenize_util import tokenize +from ..util import p_str + +MultiSourceDBLoaderArgs = namedtuple('MultiSourceDBLoaderArgs', 'db_dir domain_slot_map') + + +class MultiSourceDBLoader(BaseDBLoader): + @staticmethod + def _parse_init_args(args) -> List[MultiSourceDBLoaderArgs]: + assert isinstance(args, (list, tuple)) + if isinstance(args, MultiSourceDBLoaderArgs): + return [args] + + def toMultiSourceDBLoaderArgs(arg): + if isinstance(arg, MultiSourceDBLoaderArgs): + 
return arg + assert isinstance(arg, (list, tuple)) + assert len(arg) == len(MultiSourceDBLoaderArgs._fields) + return MultiSourceDBLoaderArgs(*arg) + + args = [toMultiSourceDBLoaderArgs(arg) for arg in args] + return args + + def __init__(self, args: Union[List[MultiSourceDBLoaderArgs], List[tuple], MultiSourceDBLoaderArgs]): + self.loaders_and_maps = [] + args = self._parse_init_args(args) + for db_dir, domain_slot_map in args: + loader = DBLoader(db_dir) + self.loaders_and_maps.append((loader, domain_slot_map)) + + def load_db(self, domain, slot: Optional[str] = None) -> Optional["MultiSourceDB"]: + dbs = [] + for loader, domain_slot_map in self.loaders_and_maps: + if slot is not None: + if (domain.lower(), slot.lower()) in domain_slot_map: + db_domain, db_slot = domain_slot_map[(domain.lower(), slot.lower())] + db = loader.load_db(db_domain, db_slot) + if db is not None: + dbs.append((db, db_domain, domain_slot_map)) + else: + domain_to_db = {} + for domain_slot_tuple, db_domain_slot_tuple in domain_slot_map.items(): + if domain.lower() == domain_slot_tuple[0].lower(): + db_domain = db_domain_slot_tuple[0] + if db_domain not in domain_to_db: + db = loader.load_db(db_domain) + if db is not None: + domain_to_db[db_domain] = db + dbs.extend((db, db_domain, domain_slot_map) for db_domain, db in domain_to_db.items()) + + if not dbs: + return None + return MultiSourceDB(dbs) + + +MultiSourceDBArgs = namedtuple('MultiSourceDBArgs', 'db db_domain domain_slot_map') + + +class MultiSourceDB(BaseDB): + @staticmethod + def _parse_init_args(args) -> List[MultiSourceDBArgs]: + if isinstance(args, MultiSourceDBArgs): + return [args] + assert isinstance(args, (list, tuple)) + + def toMultiSourceDBArgs(arg): + if isinstance(arg, MultiSourceDBArgs): + return arg + assert isinstance(arg, (list, tuple)) + assert len(arg) == len(MultiSourceDBArgs._fields) + return MultiSourceDBArgs(*arg) + + args = [toMultiSourceDBArgs(arg) for arg in args] + return args + + def __init__(self, args: Union[MultiSourceDBArgs, List[MultiSourceDBArgs], List[tuple]]): + self.args = self._parse_init_args(args) + + def find_different_values(self, domain, slot, excluding_values=()) -> Iterable: + """find different values, which belong to the same domain and slot.""" + for db, db_domain, domain_slot_map in self.args: + k = (domain.lower(), slot.lower()) + if k not in domain_slot_map: + continue + if domain_slot_map[k][0] != db_domain: + continue + db_domain, db_slot = domain_slot_map[k] + r = db.query( + lambda item: db_slot in item and item[db_slot] not in excluding_values) + yield from (dict_[db_slot] for dict_ in r) + + def sample_value(self, domain, slot, excluding_values=()): + values = self.find_different_values(domain, slot, excluding_values=excluding_values) + try: + return choice(values) + except ValueError: + return None + + +def _get_word2indexes(words, to_lower=False): + word2indexes = defaultdict(list) + for i, word in enumerate(words): + if to_lower: + word2indexes[word.lower()].append(i) + else: + word2indexes[word].append(i) + return word2indexes + + +def _get_positions(words: List[str], word_to_indexes: Dict[str, List[int]], value: List[str]): + first_word = value[0] + N = len(value) + for first_index in word_to_indexes.get(first_word, ()): + if words[first_index: first_index + N] == value: + return [first_index, first_index + N - 1] + + +def fix_text(text): + # strip; split punctuation and word + text = re.sub(r"(?:^|(?<=\s))([\w$']+)([,.?*/!;<=>\]\"]+)(?:$|(?=[A-Z\s]))", r'\1 \2', + text) # split word and 
punctuation + return text + + +def fix_turn(turn: dict): + # fix error in a turn + # turn = { + # 'text': ..., + # 'span_info': ..., + # 'dialog_act': ... + # } + text = turn['text'] + words = tokenize(text) + word2indexes = _get_word2indexes(words, to_lower=False) + span_info = turn['span_info'] + dialog_act = turn['dialog_act'] + for i, item in enumerate(span_info): + domain_intent, slot, value, *positions = item + assert len(positions) == 2 + domain, intent = domain_intent.split('-') + if ' '.join(words[positions[0]: positions[1] + 1]) != value: + positions = None + if positions is None: + positions = _get_positions(words, word2indexes, tokenize(value)) + if positions is None: + slot_value_list = dialog_act[domain_intent] + for i in range(len(slot_value_list)): + if slot_value_list[i][0] == slot: + value = slot_value_list[i][1] + break + positions = _get_positions(words, word2indexes, tokenize(value)) + if positions is None: + raise ValueError(f"turn: {p_str(turn)}\nitem: {p_str(item)}\nwords: {p_str(words)}\n" + f"word2indexes {p_str(word2indexes)}\nvalue: {tokenize(value)}") + value = ' '.join(words[positions[0]:1 + positions[1]]) + item[2] = value + if item[-2:] != positions: + item[-2:] = positions + if domain_intent not in dialog_act: + continue + slot_value_list = dialog_act[domain_intent] + for i in range(len(slot_value_list)): + if slot_value_list[i][0] == slot: + slot_value_list[i][1] = value + span_info.sort(key=lambda item: item[-2:]) + + +def assert_correct_turn(turn: dict): + text = turn['text'] + words = tokenize(text) + span_info = turn['span_info'] + dialog_act = turn['dialog_act'] + new_dialog_act = {} + for item in span_info: + domain_intent, slot, value, begin, end = item + assert words[begin: 1 + end] == tokenize(value), f"turn: {p_str(turn)}\nitem: {item}" + new_dialog_act.setdefault(domain_intent, []) + new_dialog_act[domain_intent].append([slot, value]) + for domain_intent, new_slot_value_list in new_dialog_act.items(): + assert domain_intent in dialog_act + new_slot_value_set = {tuple(slot_value) for slot_value in new_slot_value_list} + slot_value_list = dialog_act[domain_intent] + slot_value_set = {tuple(slot_value) for slot_value in slot_value_list} + assert new_slot_value_set <= slot_value_set, p_str([turn, new_dialog_act]) + diff = slot_value_set - new_slot_value_set + assert all(slot == 'none' or value == '?' 
for slot, value in diff), f"Error, {p_str(turn)}\n{p_str(diff)}"
+
+
+def replace_slot_values_in_turn(turn: dict, db_loader: MultiSourceDBLoader,
+                                p=0.25,
+                                inform_intents=('inform',)):
+    orig_turn = turn
+    turn = deepcopy(orig_turn)
+    try:
+        fix_turn(turn)
+        assert_correct_turn(turn)
+    except:
+        return orig_turn
+    text = turn['text']
+    words = tokenize(text)
+    span_info = turn['span_info']
+    span_info.sort(key=lambda item: item[-2:])
+    dialog_act = turn['dialog_act']
+    if any(span_info[i][-2] <= span_info[i - 1][-1] for i in range(1, len(span_info))):
+        return turn
+
+    new_turn = deepcopy(turn)
+    new_words = words.copy()
+    new_span_info = new_turn['span_info']
+    new_dialog_act = new_turn['dialog_act']
+    updated_span = []
+
+    for i, item in enumerate(span_info):
+        domain_intent = item[0]
+        domain, intent = domain_intent.split('-')
+        slot = item[1]
+        value = item[2]
+        if intent.lower() not in inform_intents:
+            continue
+        if updated_span:
+            j = updated_span[-1]
+            last_item = span_info[j]
+            if item[-2] <= last_item[-1]:
+                continue
+        db = db_loader.load_db(domain, slot)
+        if db is None:
+            continue
+        new_value = db.sample_value(domain, slot, excluding_values=(value, 'none', '?'))
+        if new_value is None:
+            continue
+        if random.random() > p:
+            continue
+        new_value = fix_text(new_value)
+        new_span_info[i][2] = new_value
+        new_slot_value_list = new_dialog_act[domain_intent]
+        for j in range(len(new_slot_value_list)):
+            if new_slot_value_list[j][0] == slot:
+                new_slot_value_list[j][1] = new_value
+        updated_span.append(i)
+        # print(f'replace {item[2]} with {new_value}')
+
+    # update new_words and span in new_span_info
+    if updated_span:
+        offset = 0
+        for i in range(len(span_info)):
+            begin, end = span_info[i][-2:]
+            new_value = new_span_info[i][2]
+            new_value = tokenize(new_value)
+            num_words = len(new_value)
+            new_words[offset + begin: offset + end + 1] = new_value
+            new_span_info[i][-2:] = [begin + offset, begin + offset + num_words - 1]
+            offset += num_words - (end - begin + 1)
+        new_turn['text'] = ' '.join(new_words)
+        assert_correct_turn(new_turn)
+    return new_turn
+
+
+def replace_slot_values(sample, db_loader: MultiSourceDBLoader,
+                        p=0.25,
+                        inform_intents=('inform',),
+                        mode='usr'):
+    """
+    replace slot values in a sample
+
+    Args:
+        sample: a dialogue
+        db_loader: it can load a db
+        p: probability to replace if conditions are satisfied
+        inform_intents: only inform intents may be replaced.
+        mode: 'usr' or 'user': only replace on user turns;
+            'sys': only replace on sys turns;
+            'all': replace on all turns
+    """
+    new_sample = deepcopy(sample)
+    for turn_index, turn in enumerate(sample['log']):
+        is_user = turn_index % 2 == 0
+        if is_user and mode not in ('usr', 'user', 'all'):
+            continue
+        if not is_user and mode not in ('sys', 'system', 'all'):
+            continue
+        new_turn = replace_slot_values_in_turn(turn, db_loader, p=p, inform_intents=inform_intents)
+        new_sample['log'][turn_index] = new_turn
+    return new_sample
diff --git a/convlab2/laug/Word_Perturbation/multiwoz/multiwoz_eda.py b/convlab2/laug/Word_Perturbation/multiwoz/multiwoz_eda.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2ee027b38df9fe63070db3b3393c7a8cae5111c
--- /dev/null
+++ b/convlab2/laug/Word_Perturbation/multiwoz/multiwoz_eda.py
@@ -0,0 +1,103 @@
+import tqdm
+from copy import deepcopy
+
+from .task_oriented_eda import eda
+from .types import MultiwozSampleType, SentenceType, MultiwozDatasetType
+from .util import AugmentationRecorder, iter_dialogues, Helper, choice, is_span_info_consistent_with_text, p_str
+from .tokenize_util import tokenize, convert_tokens_to_string, convert_sentence_to_tokens
+from .db.slot_value_replace import replace_slot_values_in_turn, MultiSourceDBLoader, assert_correct_turn
+
+
+class MultiwozEDA:
+    def __init__(self, multiwoz: MultiwozDatasetType,
+                 db_loader: MultiSourceDBLoader,
+                 inform_intents=('inform',),
+                 slot_value_replacement_probability=0.25,
+                 alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=2):
+        # attributes for slot value replacement
+        self.db_loader = db_loader
+        self.inform_intents = inform_intents
+        self.slot_value_replacement_probability = slot_value_replacement_probability
+
+        # attributes for EDA.
+        self.eda_config = dict(alpha_sr=alpha_sr, alpha_ri=alpha_ri, alpha_rs=alpha_rs, p_rd=p_rd, num_aug=num_aug)
+        self.multiwoz = multiwoz
+        self.helper = Helper(multiwoz)
+
+    def _get_excluding_indexes(self, words, span_info, dialog_act):
+        return self.helper._get_excluding_indexes(words, span_info, dialog_act)
+
+    def _augment_sentence_only(self, sentence: SentenceType, span_info, dialog_act):
+        """don't change DA (span indexes may change)"""
+        words = convert_sentence_to_tokens(sentence)
+        excluding_indexes = self._get_excluding_indexes(words, span_info, dialog_act)
+
+        for new_words, index_map in eda(words, **self.eda_config, excluding_indexes=excluding_indexes):
+            new_span_info = []
+            for x in span_info:
+                new_span_info.append([*x[:3], index_map[x[3]], index_map[x[4]]])
+            yield convert_tokens_to_string(new_words), new_span_info, dialog_act
+
+    def augment_sentence_only(self, sentence: SentenceType, span_info, dialog_act):
+        return list(self._augment_sentence_only(sentence, span_info, dialog_act))
+
+    def _augment_sample(self, sample: MultiwozSampleType, mode='usr') -> AugmentationRecorder:
+        recorder = AugmentationRecorder(sample)
+
+        for turn_index, turn in iter_dialogues(sample, mode=mode):
+            if not is_span_info_consistent_with_text(turn['text'], turn['span_info']):
+                continue
+            try:
+                assert_correct_turn(turn)
+            except:
+                continue
+            orig_turn = deepcopy(turn)
+            new_turn = replace_slot_values_in_turn(
+                turn,
+                self.db_loader,
+                p=self.slot_value_replacement_probability,
+                inform_intents=self.inform_intents
+            )
+            augmented = new_turn != turn
+            turn = new_turn
+
+            try:
+                text = turn['text']
+                span_info = turn['span_info']
+                dialog_act = turn['dialog_act']
+                tokens = tokenize(text)
+                augmented_sentence, augmented_span_info,
augmented_dialog_act = choice( + self._augment_sentence_only(tokens, span_info, dialog_act) + ) + except (ValueError, IndexError): + pass + else: + assert is_span_info_consistent_with_text(augmented_sentence, augmented_span_info), p_str( + [orig_turn, turn]) + augmented = True + turn = { + 'text': augmented_sentence, + 'span_info': augmented_span_info, + 'dialog_act': augmented_dialog_act, + **{k: v for k, v in turn.items() if k not in ('text', 'span_info', 'dialog_act')} + } + + if augmented: + recorder.add_augmented_dialog(turn_index, turn) + return recorder + + def augment_sample(self, sample: MultiwozSampleType, mode='usr') -> MultiwozSampleType: + return self._augment_sample(sample, mode=mode).get_augmented_sample() + + __call__ = augment_sample + + def augment_multiwoz_dataset(self, mode='usr', progress_bar=True): + assert mode in ('usr', 'user', 'sys', 'all') + res = {} + if progress_bar: + items = tqdm.tqdm(self.multiwoz.items(), total=len(self.multiwoz)) + else: + items = self.multiwoz.items() + for sample_id, sample in items: + res[sample_id] = self.augment_sample(sample, mode=mode) + return res diff --git a/convlab2/laug/Word_Perturbation/multiwoz/run.py b/convlab2/laug/Word_Perturbation/multiwoz/run.py new file mode 100644 index 0000000000000000000000000000000000000000..47a09416b892c3a5ad4c05c0ee34c87a4be86c1e --- /dev/null +++ b/convlab2/laug/Word_Perturbation/multiwoz/run.py @@ -0,0 +1,55 @@ +import os, json +from convlab2.laug.Word_Perturbation.multiwoz.multiwoz_eda import MultiwozEDA +from convlab2.laug.Word_Perturbation.multiwoz.db.slot_value_replace import MultiSourceDBLoader, MultiSourceDBLoaderArgs +from convlab2.laug.Word_Perturbation.multiwoz.util import load_json +from convlab2 import DATA_ROOT + + +def main(multiwoz_filepath, output_filepath, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=2, + p_slot_value_replacement=0.25): + multiwoz = load_json(multiwoz_filepath) + + db_dir = os.path.join(DATA_ROOT, 'multiwoz', 'db') + multiwoz_multiwoz_domain_slot_map = { + ('attraction', 'area'): ('attraction', 'Area'), + ('attraction', 'type'): ('attraction', 'Type'), + ('attraction', 'name'): ('attraction', 'Name'), + ('attraction', 'address'): ('attraction', 'Addr'), + ('hospital', 'department'): ('hospital', 'Department'), + ('hospital', 'address'): ('hospital', 'Addr'), + ('hotel', 'type'): ('hotel', 'Type'), + ('hotel', 'area'): ('hotel', 'Area'), + ('hotel', 'name'): ('hotel', 'Name'), + ('hotel', 'address'): ('hotel', 'Addr'), + ('restaurant', 'food'): ('restaurant', 'Food'), + ('restaurant', 'area'): ('restaurant', 'Area'), + ('restaurant', 'name'): ('restaurant', 'Name'), + ('restaurant', 'address'): ('restaurant', 'Addr'), + ('train', 'destination'): ('train', 'Dest'), + ('train', 'departure'): ('train', 'Depart') + } + loader_args = MultiSourceDBLoaderArgs(db_dir, multiwoz_multiwoz_domain_slot_map) + db_loader = MultiSourceDBLoader(loader_args) + + eda = MultiwozEDA(multiwoz, db_loader, + slot_value_replacement_probability=p_slot_value_replacement, + alpha_sr=alpha_sr, alpha_ri=alpha_ri, alpha_rs=alpha_rs, p_rd=p_rd, num_aug=num_aug) + result = eda.augment_multiwoz_dataset('usr') + + os.makedirs(os.path.dirname(os.path.abspath(output_filepath)), exist_ok=True) + with open(output_filepath, 'w', encoding='utf-8') as out: + json.dump(result, out, indent=4) + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--multiwoz_filepath", '--multiwoz', default='multiwoz.json') + 
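+    # example invocation (hypothetical paths):
+    #   python run.py --multiwoz data/multiwoz/train.json -o augmented_multiwoz.json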
parser.add_argument('--output_filepath', '--output', '-o', default='augmented_multiwoz.json') + parser.add_argument('--alpha_sr', type=float, default=0.1, help='probability of replacement') + parser.add_argument('--alpha_ri', type=float, default=0.1, help='probability of insertion') + parser.add_argument('--alpha_rs', type=float, default=0.1, help='probability of swap') + parser.add_argument('--p_rd', type=float, default=0.1, help="probability of deletion") + parser.add_argument('--num_aug', type=int, default=2, help="generate `num_aug` candidates with EDA and randomly choose one dialog as augmented dialog.") + parser.add_argument('--p_slot_value_replacement', '-p_svr', type=float, default=0.25, help='probability to replace a slot value.') + opts = parser.parse_args() + main(**vars(opts)) diff --git a/convlab2/laug/Word_Perturbation/multiwoz/task_oriented_eda.py b/convlab2/laug/Word_Perturbation/multiwoz/task_oriented_eda.py new file mode 100644 index 0000000000000000000000000000000000000000..dfa2d308c2cbdbb81b8db8945d02baef4844c488 --- /dev/null +++ b/convlab2/laug/Word_Perturbation/multiwoz/task_oriented_eda.py @@ -0,0 +1,344 @@ + +import random +import string +import re +from functools import lru_cache +from typing import List, Optional, Tuple, Sequence +from collections import defaultdict +from random import shuffle + +random.seed(1) + +# stop words list +stop_words = ['i', 'me', 'my', 'myself', 'we', 'our', + 'ours', 'ourselves', 'you', 'your', 'yours', + 'yourself', 'yourselves', 'he', 'him', 'his', + 'himself', 'she', 'her', 'hers', 'herself', + 'it', 'its', 'itself', 'they', 'them', 'their', + 'theirs', 'themselves', 'what', 'which', 'who', + 'whom', 'this', 'that', 'these', 'those', 'am', + 'is', 'are', 'was', 'were', 'be', 'been', 'being', + 'have', 'has', 'had', 'having', 'do', 'does', 'did', + 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', + 'because', 'as', 'until', 'while', 'of', 'at', + 'by', 'for', 'with', 'about', 'against', 'between', + 'into', 'through', 'during', 'before', 'after', + 'above', 'below', 'to', 'from', 'up', 'down', 'in', + 'out', 'on', 'off', 'over', 'under', 'again', + 'further', 'then', 'once', 'here', 'there', 'when', + 'where', 'why', 'how', 'all', 'any', 'both', 'each', + 'few', 'more', 'most', 'other', 'some', 'such', 'no', + 'nor', 'not', 'only', 'own', 'same', 'so', 'than', 'too', + 'very', 's', 't', 'can', 'will', 'just', 'don', + 'should', 'now', ''] +stop_words = set(stop_words) + +ascii_lowercase_and_space = string.ascii_lowercase + ' ' + + + +def get_only_chars(line): + line = line.lower() + line = re.sub(r"[’']", '', line) + line = re.sub(r'[\t\n\-]', " ", line) # replace hyphens with spaces + line = re.sub(r'[^a-z ]', ' ', line) + line = re.sub(' +', ' ', line) + return line.lstrip(' ') + + +######################################################################## +# Synonym replacement +# Replace n words in the sentence with synonyms from wordnet +######################################################################## +from nltk.corpus import wordnet + + +def random_replacement(words, n, excluding_indexes: Optional[Sequence[int]]=None): + """ + randomly replace n words with synonyms + + Args: + words: input words + n: num of replaced words + excluding_indexes: these words won't be replaced + + Returns: + new_words (List[str]) + index_map (Dict[int, int]) map an index in words to an index in new_words + """ + + new_words = words.copy() + indexes = list(range(len(new_words))) + forbidden = [False for _ in range(len(new_words))] + if 
excluding_indexes is not None: + for i in excluding_indexes: + forbidden[i] = True + + word2index = defaultdict(list) + for i, word in enumerate(words): + if word not in stop_words and not forbidden[i]: + word2index[word].append(i) + random_words = list(word2index) + random.shuffle(random_words) + + num_replaced = 0 + changes = [] + for random_word in random_words: + synonyms = get_synonyms(random_word) + if len(synonyms) >= 1: + synonym = random.choice(synonyms) + synonym_tokens = [token for token in synonym.split() if token.strip()] + if len(synonym_tokens) == 1: + for i in word2index[random_word]: + new_words[i] = synonym_tokens[0] + indexes[i] = None + else: + # if synonym has more than 1 words and simply insert synonym, index map will be wrong. + for i in word2index[random_word]: + changes.append((i, synonym_tokens)) + num_replaced += 1 + if num_replaced >= n: # only replace up to n words + break + + if changes: + changes.sort(key=lambda x: x[0]) + offset = 0 + for i, synonym_tokens in changes: + i += offset + new_words[i:i+1] = synonym_tokens + indexes[i:i+1] = [None for _ in range(len(synonym_tokens))] + offset += len(synonym_tokens) - 1 + return new_words, {v: i for i, v in enumerate(indexes) if v is not None} + + +def replacement(words, index: int): + # returns: new_words, start, end, synonym_tokens + # new_words[start: end+1] == synonym_tokens + new_words = words.copy() + word = words[index] + synonyms = get_synonyms(word) + if len(synonyms) > 0: + synonym = random.choice(synonyms) + synonym_tokens = [token for token in synonym.split() if token.strip()] + if len(synonym_tokens) == 1: + new_words[index] = synonym_tokens[0] + return new_words, index, index, synonym_tokens + else: + new_words[index: index+1] = synonym_tokens + return new_words, index, index + len(synonym_tokens) - 1, synonym_tokens + else: + return None + + +@lru_cache(maxsize=1000) +def get_synonyms(word): + synonyms = set() + for syn in wordnet.synsets(word): + for l in syn.lemmas(): + synonym = l.name().replace("_", " ").replace("-", " ").lower() + synonym = "".join(char for char in synonym if char in ascii_lowercase_and_space).strip() + if synonym: + synonyms.add(synonym) + if word in synonyms: + synonyms.remove(word) + return list(synonyms) + + +######################################################################## +# Random deletion +# Randomly delete words from the sentence with probability p +######################################################################## + +def random_deletion(words, p, excluding_indexes: Optional[Sequence[int]]=None): + """ + remove each word with probability p. + + Args: + words: input words + p: delete probability + excluding_indexes: these words won't be removed. 
+ + Returns: + + """ + # obviously, if there's only one word, don't delete it + if len(words) == 1: + return words, {0: 0} + + # randomly delete words with probability p + new_words = [] + indexes = [] + forbidden = [False for _ in range(len(words))] + if excluding_indexes is not None: + for i in excluding_indexes: + forbidden[i] = True + for i, word in enumerate(words): + if forbidden[i]: + remained = True + else: + remained = random.uniform(0, 1) > p + if remained: + new_words.append(word) + indexes.append(i) + + # if you end up deleting all words, just return a random word + if len(new_words) == 0: + rand_int = random.randint(0, len(words) - 1) + return [words[rand_int]], {rand_int: 0} + + return new_words, {v: i for i, v in enumerate(indexes)} + + +######################################################################## +# Random swap +# Randomly swap two words in the sentence n times +######################################################################## + +def random_swap(words, n, excluding_indexes: Optional[Sequence[int]]=None): + """ + randomly swap n pairs of words + + Args: + words: input words + n: num of pairs + excluding_indexes: these words won't be swapped + + Returns: + + """ + new_words = words.copy() + indexes = list(range(len(words))) + if excluding_indexes is not None: + allow_indexes = set(range(len(words))) - set(excluding_indexes) + allow_indexes = list(allow_indexes) + else: + allow_indexes = indexes.copy() + + for _ in range(n): + new_words = swap_word(new_words, indexes, allow_indexes) + return new_words, {v: i for i, v in enumerate(indexes)} + + +def swap_word(new_words, indexes, allow_indexes): + if len(allow_indexes) <= 1: + return new_words + for _ in range(4): + i = random.choice(allow_indexes) + j = random.choice(allow_indexes) + if i != j: + new_words[i], new_words[j] = new_words[j], new_words[i] + indexes[i], indexes[j] = indexes[j], indexes[i] + break + return new_words + + +######################################################################## +# Random insertion +# Randomly insert n words into the sentence +######################################################################## + +def random_insertion(words, n, excluding_indexes: Optional[Sequence[int]]=None): + """ + randomly insert n words. 
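+    Returns new_words and an index map from original word indexes to new ones
+    (inserted synonym tokens have no source index and are left unmapped).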
+ """ + new_words = words.copy() + indexes = list(range(len(new_words))) + forbidden = [False for _ in range(len(new_words))] + if excluding_indexes is not None: + for i in excluding_indexes: + forbidden[i] = True + + for _ in range(n): + add_word(new_words, indexes, forbidden) + return new_words, {v: i for i, v in enumerate(indexes) if v is not None} + + +def add_word(new_words, indexes, forbidden): + if sum(forbidden) == len(new_words): + return + synonyms = [] + counter = 0 + + while len(synonyms) < 1: + counter += 1 + if counter >= 15: + return + + idx = random.randint(0, len(new_words) - 1) + old_idx = indexes[idx] + if old_idx is None or forbidden[old_idx]: + continue + random_word = new_words[idx] + synonyms = get_synonyms(random_word) + + random_synonym = synonyms[0] + for _ in range(5): + idx = random.randint(0, len(new_words) - 1) + old_idx = indexes[idx] + if old_idx is None or not forbidden[old_idx]: + random_synonym_tokens = [token for token in random_synonym.split() if token.strip()] + # new_words.insert(idx, random_synonym) + # indexes.insert(idx, None) + new_words[idx:idx] = random_synonym_tokens + indexes[idx:idx] = [None for _ in range(len(random_synonym_tokens))] + return + + +######################################################################## +# main data augmentation function +######################################################################## + +def eda(words, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9, excluding_indexes: Optional[Sequence[int]]=None) -> List[Tuple[list, dict]]: + # sentence = get_only_chars(sentence) + # words = sentence.split(' ') + words = [word for word in words if word is not ''] + num_words = len(words) + + augmented_sentences: List[Tuple[list, dict]] = [] + num_new_per_technique = int(num_aug / 4) + 1 + n_sr = max(1, int(alpha_sr * num_words)) + n_ri = max(1, int(alpha_ri * num_words)) + n_rs = max(1, int(alpha_rs * num_words)) + + seen = set() + seen.add(tuple(words)) + + # sr + for _ in range(num_new_per_technique): + a_words, index_map = random_replacement(words, n_sr, excluding_indexes) + if tuple(a_words) not in seen: + seen.add(tuple(a_words)) + augmented_sentences.append((a_words, index_map)) + + # ri + for _ in range(num_new_per_technique): + a_words, index_map = random_insertion(words, n_ri, excluding_indexes) + if tuple(a_words) not in seen: + seen.add(tuple(a_words)) + augmented_sentences.append((a_words, index_map)) + + # rs + for _ in range(num_new_per_technique): + a_words, index_map = random_swap(words, n_rs, excluding_indexes) + if tuple(a_words) not in seen: + seen.add(tuple(a_words)) + augmented_sentences.append((a_words, index_map)) + + # rd + for _ in range(num_new_per_technique): + a_words, index_map = random_deletion(words, p_rd, excluding_indexes) + if tuple(a_words) not in seen: + seen.add(tuple(a_words)) + augmented_sentences.append((a_words, index_map)) + + # augmented_sentences = [get_only_chars(sentence) for sentence in augmented_sentences] + shuffle(augmented_sentences) + + # trim so that we have the desired number of augmented sentences + if num_aug >= 1: + augmented_sentences = augmented_sentences[:num_aug] + else: + keep_prob = num_aug + augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob] + + return augmented_sentences diff --git a/convlab2/laug/Word_Perturbation/multiwoz/tokenize_util.py b/convlab2/laug/Word_Perturbation/multiwoz/tokenize_util.py new file mode 100644 index 
0000000000000000000000000000000000000000..b80d5828c3bdd758fa076f7e546a1c1e454fb387 --- /dev/null +++ b/convlab2/laug/Word_Perturbation/multiwoz/tokenize_util.py @@ -0,0 +1,15 @@ +from .types import TokenListType, SentenceType + + +def tokenize(sentence: str) -> TokenListType: + return [token for token in sentence.split() if token.strip()] + +def convert_sentence_to_tokens(sentence: SentenceType) -> TokenListType: + if isinstance(sentence, str): + return tokenize(sentence) + else: + assert isinstance(sentence, list) + return sentence + +def convert_tokens_to_string(tokens: TokenListType) -> str: + return ' '.join(tokens) diff --git a/convlab2/laug/Word_Perturbation/multiwoz/types.py b/convlab2/laug/Word_Perturbation/multiwoz/types.py new file mode 100644 index 0000000000000000000000000000000000000000..089094f829108021e6902fb60d1b08164578dd2b --- /dev/null +++ b/convlab2/laug/Word_Perturbation/multiwoz/types.py @@ -0,0 +1,10 @@ +from typing import TypeVar, NewType, Union, List, Dict + +SampleType = TypeVar("SampleType") + +StringType = str +WordType = TokenType = NewType("TokenType", str) +TokenListType = WordListType = List[TokenType] +SentenceType = Union[StringType, TokenListType] +MultiwozSampleType = Dict[str, Union[None, list, dict]] +MultiwozDatasetType = Dict[str, MultiwozSampleType] diff --git a/convlab2/laug/Word_Perturbation/multiwoz/util.py b/convlab2/laug/Word_Perturbation/multiwoz/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a3cb5329160fc8a3ed617da948286446f9a0b853 --- /dev/null +++ b/convlab2/laug/Word_Perturbation/multiwoz/util.py @@ -0,0 +1,305 @@ +import re +import json +import os +import random +import string +from copy import deepcopy +from typing import List, Iterable +from collections import defaultdict, Counter +from functools import reduce, lru_cache +from io import StringIO +from pprint import pprint + +from .types import MultiwozSampleType, MultiwozDatasetType, SentenceType +from .tokenize_util import tokenize, convert_sentence_to_tokens + + +## load json and dump json ############ +def load_json(filepath): + with open(filepath, 'r', encoding='utf-8') as f: + return json.load(f) + + +def dump_json(obj, filepath, **json_kwargs): + os.makedirs(os.path.dirname(os.path.abspath(filepath)), exist_ok=True) + with open(filepath, 'w', encoding='utf-8') as f: + json_kwargs.setdefault('indent', 4) + json.dump(obj, f, **json_kwargs) + + +## better str ################################## +def p_str(obj): + sio = StringIO() + pprint(obj, sio) + return sio.getvalue() + + +punctuation_pattern = re.compile(r'^[{}]+$'.format('\\'.join(string.punctuation))) +RePatternType = type(re.compile('')) + + +def is_punctuation(word): + return punctuation_pattern.match(word) is not None + + +## Helper class ################### +def _get_all_slots(multiwoz: MultiwozDatasetType): + slots = defaultdict(dict) # Dict[str, Dict[str, Set[str]]]; Dict[Sample_ID, Dict[intent, Set[Slot]]] + for sample in multiwoz.values(): + logs = sample['log'] + if isinstance(logs, dict): + logs = logs.values() + for turn in logs: + dialog_act = turn['dialog_act'] + for domain_intent, slot_value_list in dialog_act.items(): + domain, intent = domain_intent.lower().split('-') + for slot, value in slot_value_list: + slots[domain].setdefault(intent, set()).add(slot.lower() if isinstance(slot, str) else slot) + for domain in slots: + s = reduce((lambda s1, s2: s1 | s2), slots[domain].values(), set()) + slots[domain]['all'] = s + for domain in slots: + unique_slots = slots[domain]['all'].copy() 
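+        # a slot is "unique" to a domain if it appears in no other domain's slot set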
+ for other_domain in slots: + if other_domain != domain: + unique_slots -= slots[other_domain]['all'] + slots[domain]['unique'] = unique_slots + return slots + + +class Patterns: + time_pattern = re.compile(r"^\d{1,2}:\d{1,2}$") + integer_pattern = re.compile(r"^[+-]?(\d+|\d{1,3}(,\d{3})*)$") + ref_pattern = re.compile(r"(?=.*[A-Z].*[0-9].*|.*[0-9].*[A-Z].*)[A-Z0-9]{8}") + + +class Helper: + def __init__(self, multiwoz): + self.multiwoz = multiwoz + self.slots = _get_all_slots(multiwoz) + + def get_unique_slots(self, domain): + return self.slots[domain]['unique'] + + @staticmethod + @lru_cache(1000) + def split_str(s): + return s.split() + + _words_about_slot = { + 'attraction': { + 'area': 'area', 'type': 'type', 'name': 'name', + 'fee': ['entrance fee'], 'addr': 'address', + 'post': ['postcode', 'post code'], 'phone': 'phone' + }, + 'hospital': { + 'department': 'department', 'addr': 'address', 'post': ['postcode', 'post code'], + 'phone': 'phone' + }, + 'hotel': { + 'type': 'type', 'parking': 'parking', 'price': ['pricerange', 'price range'], + 'internet': ['internet', 'wifi'], 'area': 'area', 'stars': 'stars', + 'name': 'name', 'stay': ['stay', Patterns.integer_pattern], 'day': 'day', + 'people': ['people', Patterns.integer_pattern], + 'addr': 'address', 'post': ['postcode', 'post code'], 'phone': 'phone' + }, + 'police': { + 'addr': 'address', 'post': ['postcode', 'post code'], 'phone': 'phone', 'name': 'name' + }, + 'restaurant': { + 'food': 'food', 'price': ['pricerange', 'price range'], 'area': 'area', + 'name': 'name', 'time': 'time', 'day': 'day', 'people': 'people', + 'phone': 'phone', 'post': ['postcode', 'post code'], 'addr': 'address' + }, + 'taxi': { + 'leave': ['leaveat', 'leave at', Patterns.time_pattern], "dest": ['destination', "cambridge", 'belfry'], + 'depart': 'departure', + 'arrive': ['arriveby', 'arrive by', Patterns.time_pattern], + 'car': 'car type', 'phone': 'phone' + }, + 'train': { + 'dest': ['destination', "cambridge"], 'day': 'day', + 'arrive': ['arriveby', 'arrive by', Patterns.time_pattern], + 'depart': 'departure', 'leave': ['leaveat', 'leave at', Patterns.time_pattern], 'people': 'people', + 'time': 'duration', 'id': 'trainid' + } + } + + def relevant_words_of_slot(self, domain, slot): + if domain not in self._words_about_slot or slot not in self._words_about_slot[domain]: + return [slot] + if isinstance(self._words_about_slot[domain][slot], str): + res = [self._words_about_slot[domain][slot]] + if slot != res[-1]: + res.append(slot) + return res + else: + return self._words_about_slot[domain][slot] + [slot] + + _words_about_domain = { + 'police': ['assistance'], + 'attraction': ['attractions', 'trip', 'gallery', 'museum', 'theatre', 'visit', 'entertainment', 'cinema', + 'park', 'cambridge', 'college', 'architecture'], + 'hotel': ['place to stay', 'hotels', 'guesthouse'], + 'restaurant': ['place to eat', 'place to dine', 'food', 'gastropub', 'restaurants'], + 'booking': ['book'], + } + + def relevant_words_of_domain(self, domain): + if domain not in self._words_about_domain: + return [domain] + return self._words_about_domain[domain] + [domain] + + @staticmethod + def contain_word(sentence: str, word): + if isinstance(word, str): + if not word.startswith(r'\b'): + word = r'\b' + word + if not word.endswith(r'\b'): + word += r'\b' + word = re.compile(word) + else: + assert isinstance(word, RePatternType) + return word.search(sentence) is not None + + def _get_excluding_indexes(self, words, span_info, dialog_act): + """exclude some words, so that the label 
stays the same after augmentation."""
+        excluding_indexes = set()
+        domains = set()
+        slots = set()
+        for domain_intent, slot, value, start, end in span_info:
+            excluding_indexes.update(range(start, end + 1))
+            domain = domain_intent.split('-')[0].lower()
+            domains.add(domain)
+            slots.add((domain, slot.lower()))
+        for domain_intent, slot_value_list in dialog_act.items():
+            domain = domain_intent.split('-')[0].lower()
+            domains.add(domain)
+            for slot, value in slot_value_list:
+                slots.add((domain, slot.lower() if isinstance(slot, str) else slot))
+
+        word2index = {v.lower(): i for i, v in enumerate(words)}
+        for domain in domains:
+            for word in self.relevant_words_of_domain(domain):
+                if isinstance(word, str):
+                    ws = tokenize(word)
+                    if len(ws) == 1:
+                        if ws[0] in word2index:
+                            excluding_indexes.add(word2index[ws[0]])
+                    else:
+                        n = len(ws)
+                        N = len(words)
+                        for i in range(N):
+                            if i + n <= N and all(ws[j] == words[i + j] for j in range(n)):
+                                excluding_indexes.update(range(i, i + n))
+
+                if isinstance(word, RePatternType):
+                    for i in range(len(words)):
+                        if word.match(words[i]):
+                            excluding_indexes.add(i)
+        for domain, slot in slots:
+            for word in self.relevant_words_of_slot(domain, slot):
+                if isinstance(word, str):
+                    ws = tokenize(word)
+                    if len(ws) == 1:
+                        if ws[0] in word2index:
+                            excluding_indexes.add(word2index[ws[0]])
+                    else:
+                        n = len(ws)
+                        N = len(words)
+                        for i in range(N):
+                            if i + n <= N and all(ws[j] == words[i + j] for j in range(n)):
+                                excluding_indexes.update(range(i, i + n))
+
+                if isinstance(word, RePatternType):
+                    for i in range(len(words)):
+                        if word.match(words[i]):
+                            excluding_indexes.add(i)
+
+        for i, word in enumerate(words):
+            if is_punctuation(word):
+                excluding_indexes.add(i)
+            elif word == 'reference' and i + 1 < len(words) and words[i + 1] == 'number':
+                # exclude "reference number"
+                excluding_indexes.update((i, i + 1))
+        return excluding_indexes
+
+
+## iter dialogues ############
+## the data format of the augmented multiwoz may be different from the original multiwoz
+def _iter_dialogues(sample: MultiwozSampleType):
+    dialogues = sample['log']
+    if isinstance(dialogues, list):
+        for i, dialog in enumerate(dialogues):
+            turn = dialog.get('turn', i)
+            yield turn, dialog
+    elif isinstance(dialogues, dict):
+        # assume key is `turn`
+        yield from dialogues.items()
+    else:
+        raise RuntimeError("unknown format.")
+
+
+def iter_dialogues(sample: MultiwozSampleType, mode='usr'):
+    assert mode in ('usr', 'user', 'all', 'sys')
+    for turn, dialog in _iter_dialogues(sample):
+        if mode in ("usr", 'user') and turn % 2 == 1:
+            continue
+        if mode == 'sys' and turn % 2 == 0:
+            continue
+        yield turn, dialog
+
+
+## random choice ####################
+_EmptySequence = object()
+
+
+def choice(seq: Iterable):
+    if hasattr(seq, '__len__') and hasattr(seq, '__getitem__'):
+        return random.choice(seq)
+
+    r = _EmptySequence
+    for i, x in enumerate(seq, 1):
+        if random.random() * i <= 1:
+            r = x
+    if r is _EmptySequence:
+        raise ValueError("empty sequence")
+    return r
+
+
+## record augmented text and span info, then return an augmented sample
+class AugmentationRecorder:
+    def __init__(self, original_sample: MultiwozSampleType):
+        self.original_sample = original_sample
+        self.augmented_turns = []
+
+    def add_augmented_dialog(self, turn_index, turn):
+        self.augmented_turns.append((turn_index, turn))
+
+    def get_augmented_sample(self) -> MultiwozSampleType:
+        sample = deepcopy(self.original_sample)
+        turns = sample['log']
+        counter = Counter()
+        for turn_index, turn in self.augmented_turns:
+            # if there is more than one augmented text,
+            # randomly choose one
+            counter[turn_index] += 1
+            if random.random() * counter[turn_index] <= 1:
+                turns[turn_index] = {'turn_index': turn_index, 'augmented': True, **turn}
+        return sample
+
+
+## check whether span info is consistent with text
+def _equal_words(words1, words2, ignore_case):
+    if not ignore_case:
+        return words1 == words2
+    else:
+        return len(words1) == len(words2) and all(w1.lower() == w2.lower() for w1, w2 in zip(words1, words2))
+
+
+def is_span_info_consistent_with_text(sentence: SentenceType, span_info: List[list], ignore_case=True) -> bool:
+    """check whether the span info is consistent with text."""
+    words = convert_sentence_to_tokens(sentence)
+    return all(
+        _equal_words(words[start:end + 1], tokenize(span), ignore_case) for domain_intent, slot, span, start, end in
+        span_info) and len({tuple(x[-2:]) for x in span_info}) == len(span_info)
diff --git a/convlab2/laug/__init__.py b/convlab2/laug/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1664a8e01837a207228f0c21ec88d35338d4fb0
--- /dev/null
+++ b/convlab2/laug/__init__.py
@@ -0,0 +1,4 @@
+from convlab2.laug.Word_Perturbation.Word_Perturbation import Word_Perturbation
+from convlab2.laug.Text_Paraphrasing.Text_Paraphrasing import Text_Paraphrasing
+from convlab2.laug.Speech_Recognition.Speech_Recognition import Speech_Recognition
+from convlab2.laug.Speech_Disfluency.Speech_Disfluency import Speech_Disfluency
\ No newline at end of file
diff --git a/convlab2/laug/demo.py b/convlab2/laug/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..59b2ba4a3f2e015f6120e342389204ec621b7def
--- /dev/null
+++ b/convlab2/laug/demo.py
@@ -0,0 +1,29 @@
+from convlab2.laug import Word_Perturbation
+from convlab2.laug import Text_Paraphrasing
+from convlab2.laug import Speech_Recognition
+from convlab2.laug import Speech_Disfluency
+
+if __name__=="__main__":
+    text = "I want a train to Cambridge"
+    span_info = [["Train-Inform","Dest","Cambridge",5,5]]
+    WP = Word_Perturbation('multiwoz')
+    TP = Text_Paraphrasing('multiwoz')
+    SR = Speech_Recognition('multiwoz')
+    SD = Speech_Disfluency('multiwoz')
+    WP_text,WP_span_info = WP.aug(text,span_info)
+    print('Word Perturbation:')
+    print(WP_text)
+    print(WP_span_info)
+    TP_text,TP_span_info = TP.aug(text,span_info)
+    print('Text Paraphrasing:')
+    print(TP_text)
+    print(TP_span_info)
+    SR_text,SR_span_info = SR.aug(text,span_info)
+    print('Speech Recognition:')
+    print(SR_text)
+    print(SR_span_info)
+    SD_text,SD_span_info = SD.aug(text,span_info)
+    print('Speech Disfluency:')
+    print(SD_text)
+    print(SD_span_info)
+
diff --git a/convlab2/nlg/scgpt/README.md b/convlab2/nlg/scgpt/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8a6a47c9c85ca0633dc8001db09b401a53996d13
--- /dev/null
+++ b/convlab2/nlg/scgpt/README.md
@@ -0,0 +1,38 @@
+# GPT
+
+The code derives from [HuggingFace/Transformers](https://github.com/huggingface/transformers).
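+
+Besides the command-line scripts below, the trained MultiWOZ model can be used
+programmatically through the `SCGPT` NLG class in `multiwoz/scgpt.py`. A minimal
+sketch (assuming the default checkpoint download succeeds and that `meta` is a
+list of ConvLab-2 `(intent, domain, slot, value)` dialog-act tuples):
+
+```python
+from convlab2.nlg.scgpt.multiwoz import SCGPT
+
+nlg = SCGPT(use_cuda=False)  # fetches nlg-gpt-multiwoz.zip on first use
+nlg.init_session()           # reset the per-dialog domain markers
+# generate an utterance from dialog-act tuples
+print(nlg.generate([('Inform', 'Train', 'Dest', 'Cambridge')]))
+```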
+
+## Preprocess
+
+```
+cd $dataset$
+python preprocess.py
+```
+
+## Train
+
+Fetch and unzip the checkpoint
+
+```
+wget https://bapengstorage.blob.core.windows.net/fileshare/scgpt.tar.gz
+tar -xvf scgpt.tar.gz
+```
+
+Then
+
+```
+python train.py --output_dir=$output_dir$ --model_type=scgpt --model_name_or_path=gpt2 --do_train --do_eval --eval_data_file=$test_file$ --overwrite_cache --use_tokenize --train_data_file=$train_file$ --overwrite_output_dir
+```
+
+## Use
+
+```
+python run.py --model_type=gpt2 --model_name_or_path=$save_dir$ --num_samples 5 --input_file=$test_file$ --output_file=$output_file$ --length 100 --stop_token '<|endoftext|>' --batch_size 16
+```
+
+## Data Format
+
+```
+dialog act seq & user utterance
+```
+
diff --git a/convlab2/nlg/scgpt/__init__.py b/convlab2/nlg/scgpt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c68785e9d0b64e1fe46403c4316a9fe1ea36eeb
--- /dev/null
+++ b/convlab2/nlg/scgpt/__init__.py
@@ -0,0 +1 @@
+# -*- coding: utf-8 -*-
\ No newline at end of file
diff --git a/convlab2/nlg/scgpt/decode.py b/convlab2/nlg/scgpt/decode.py
new file mode 100644
index 0000000000000000000000000000000000000000..e95025afdde60d8beb41bac5d2fb038e39357f3d
--- /dev/null
+++ b/convlab2/nlg/scgpt/decode.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Sat Apr 4 21:34:38 2020
+
+@author: truthless
+"""
+import numpy as np
+import torch
+
+def set_seed(seed, n_gpu):
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if n_gpu > 0:
+        torch.cuda.manual_seed_all(seed)
+
+
+def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
+    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+        Args:
+            logits: logits distribution shape (batch size x vocabulary size)
+            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+                Nucleus filtering is described in Holtzman et al.
(http://arxiv.org/abs/1904.09751) + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + top_k = min(top_k, logits.size(-1)) # Safety check + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p > 0.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + return logits + + +def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, + is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'): + context = torch.tensor(context, dtype=torch.long, device=device) + context = context.unsqueeze(0).repeat(num_samples, 1) + generated = context + with torch.no_grad(): + for _ in range(length): + + inputs = {'input_ids': generated} + if is_xlnet: + # XLNet is a direct (predict same token, not next token) and bi-directional model by default + # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring) + input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1) + perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device) + perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token + target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device) + target_mapping[0, 0, -1] = 1.0 # predict last token + inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping} + + if is_xlm_mlm and xlm_mask_token: + # XLM MLM models are direct models (predict same token, not next token) + # => need one additional dummy token in the input (will be masked and guessed) + input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1) + inputs = {'input_ids': input_ids} + + if xlm_lang is not None: + inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1) + + outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states) + next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.) 
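+            # temperature < 1 sharpens the next-token distribution; temperature == 0
+            # leaves the logits unscaled here and is handled below as greedy decoding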
+ + # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858) + for i in range(num_samples): + for _ in set(generated[i].tolist()): + next_token_logits[i, _] /= repetition_penalty + + filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) + if temperature == 0: # greedy sampling: + next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1) + else: + next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1) + generated = torch.cat((generated, next_token), dim=1) + return generated diff --git a/convlab2/nlg/scgpt/multiwoz/__init__.py b/convlab2/nlg/scgpt/multiwoz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..47c0e157e9f35dd3f0df26eebb4be33c183238b2 --- /dev/null +++ b/convlab2/nlg/scgpt/multiwoz/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- +""" +Created on Sat Apr 4 21:43:42 2020 + +@author: truthless +""" + +from convlab2.nlg.scgpt.multiwoz.scgpt import SCGPT \ No newline at end of file diff --git a/convlab2/nlg/scgpt/multiwoz/preprocess.py b/convlab2/nlg/scgpt/multiwoz/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..3dcda2eb246fc919cf4843fe0c8075d0a9071138 --- /dev/null +++ b/convlab2/nlg/scgpt/multiwoz/preprocess.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +""" +Created on Mon Sep 14 11:38:53 2020 + +@author: truthless +""" + +import os +import json +from convlab2.nlg.scgpt.utils import dict2dict, dict2seq + +cur_dir = os.path.dirname(os.path.abspath(__file__)) +data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname( + cur_dir)))), 'data/multiwoz/') + +with open(os.path.join(data_dir, '0807_final.json'),'r', encoding='utf8') as f: + data = json.load(f) + +with open(os.path.join(data_dir, 'valListFile'), 'r') as f: + val_list = f.read().splitlines() +with open(os.path.join(data_dir, 'testListFile'), 'r') as f: + test_list = f.read().splitlines() + +results = {} +results_val = {} +results_test = {} + +for title, sess in data.items(): + logs = sess['log'] + turns = [] + turn = {'turn':0, 'sys':'', 'sys_da':''} + current_domain = None + for i, diag in enumerate(logs): + text = diag['text'] + da = diag['dialog_act'] + span = diag['span_info'] + if i % 2 == 0: + turn['usr'] = text + if current_domain: + da = eval(str(da).replace('Booking', current_domain)) + span = eval(str(span).replace('Booking', current_domain)) + turn['usr_da'] = da + turn['usr_span'] = span + turns.append(turn) + else: + turn = {'turn': i//2 +1} + turn['sys'] = text + turn['sys_da'] = da + turn['sys_span'] = span + for key in da: + domain = key.split('-')[0] + if domain not in ['general', 'Booking']: + current_domain = domain + title = title + if title in val_list: + current = results_val + elif title in test_list: + current = results_test + else: + current = results + current[title] = turns + +results = eval(str(results).replace(" n't", " not")) +results_val = eval(str(results_val).replace(" n't", " not")) +results_test = eval(str(results_test).replace(" n't", " not")) + +def init_domain(): + return {'Attraction':False, + 'Hospital':False, + 'Hotel':False, + 'Police':False, + 'Restaurant':False, + 'Taxi':False, + 'Train':False} + +def write_file(name, data): + with open(f'{name}.txt', 'w', encoding='utf-8') as f: + for ID in data: + sess = data[ID] + sess_domains = init_domain() + for turn in sess: + if not turn['usr_da']: + continue + turn['usr_da'] = eval(str(turn['usr_da']).replace('Bus','Train')) + da_seq = 
dict2seq(dict2dict(turn['usr_da'])).replace('&', 'and') + domains = set([key.split('-')[0] for key in turn['usr_da'].keys()]) + for domain in domains: + if domain not in ['general', 'Booking'] and not sess_domains[domain]: + da_seq = da_seq.replace(domain.lower(), domain.lower()+' *', 1) + sess_domains[domain] = True + da_uttr = turn['usr'].replace(' bus ', ' train ').replace('&', 'and') + f.write(f'{da_seq} & {da_uttr}\n') + +if not os.path.exists(os.path.join(cur_dir,'data')): + os.makedirs(os.path.join(cur_dir, 'data')) +write_file(os.path.join(cur_dir, 'data/train'), dict(results, **results_val)) +write_file(os.path.join(cur_dir, 'data/test'), results_test) diff --git a/convlab2/nlg/scgpt/multiwoz/run.py b/convlab2/nlg/scgpt/multiwoz/run.py new file mode 100644 index 0000000000000000000000000000000000000000..e583fe72fb26cd4262a6c4aae7776aabee49293b --- /dev/null +++ b/convlab2/nlg/scgpt/multiwoz/run.py @@ -0,0 +1,171 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import logging +from tqdm import trange + +import torch +import torch.nn.functional as F +import numpy as np + +import sys + +from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig + +from transformers import GPT2LMHeadModel, GPT2Tokenizer +from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer +from transformers import XLNetLMHeadModel, XLNetTokenizer +from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer +from transformers import CTRLLMHeadModel, CTRLTokenizer +from transformers import XLMWithLMHeadModel, XLMTokenizer + + +logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt = '%m/%d/%Y %H:%M:%S', + level = logging.INFO) +logger = logging.getLogger(__name__) + +MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop + +ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ()) + +MODEL_CLASSES = { + 'gpt2': (GPT2LMHeadModel, GPT2Tokenizer), + 'ctrl': (CTRLLMHeadModel, CTRLTokenizer), + 'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), + 'xlnet': (XLNetLMHeadModel, XLNetTokenizer), + 'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer), + 'xlm': (XLMWithLMHeadModel, XLMTokenizer), +} + +# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia +# in https://github.com/rusiaaman/XLNet-gen#methodology +# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e +PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family +(except for Alexei and Maria) are discovered. +The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the +remainder of the story. 1883 Western Siberia, +a young Grigori Rasputin is asked by his father and a group of men to perform magic. +Rasputin has a vision and denounces one of the men as a horse thief. Although his +father initially slaps him for making such an accusation, Rasputin watches as the +man is chased outside and beaten. Twenty years later, Rasputin sees a vision of +the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, +with people, even a bishop, begging for his blessing. 
<eod> </s> <eos>""" + + +def set_seed(args): + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_type", default=None, type=str, required=True, + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) + parser.add_argument("--model_name_or_path", default=None, type=str, required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) + parser.add_argument("--prompt", type=str, default="") + parser.add_argument("--padding_text", type=str, default="") + parser.add_argument("--length", type=int, default=40) + parser.add_argument("--num_samples", type=int, default=1) + parser.add_argument("--temperature", type=float, default=1.0, + help="temperature of 0 implies greedy sampling") + parser.add_argument("--repetition_penalty", type=float, default=1.0, + help="primarily useful for CTRL model; in that case, use 1.2") + parser.add_argument("--top_k", type=int, default=50) + parser.add_argument("--top_p", type=float, default=0.9) + parser.add_argument("--no_cuda", action='store_true', + help="Avoid using CUDA when available") + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + parser.add_argument('--stop_token', type=str, default=None, + help="Token at which text generation is stopped") + parser.add_argument("--batch_size", default=1, type=int) + parser.add_argument('--input_file', type=str, default=None, + help="file") + parser.add_argument('--output_file', type=str, default=None, + help="file") + + args = parser.parse_args() + + args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = torch.cuda.device_count() + + set_seed(args) + + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, pad_token='<PAD>', padding_side='left') + model = model_class.from_pretrained(args.model_name_or_path) + model.to(args.device) + model.eval() + + if args.length < 0 and model.config.max_position_embeddings > 0: + args.length = model.config.max_position_embeddings + elif 0 < model.config.max_position_embeddings < args.length: + args.length = model.config.max_position_embeddings # No generation bigger than model size + elif args.length < 0: + args.length = MAX_LENGTH # avoid infinite loop + + logger.info(args) + if args.model_type in ["ctrl"]: + if args.temperature > 0.7: + logger.info('CTRL typically works better with lower temperatures (and lower top_k).') + + fin = open(args.input_file) + inputs = [i.strip() for i in fin] + output_tests = [] + for idx in range(0, len(inputs), args.batch_size): + logger.info(f"PROGRESS: {int(idx/len(inputs)*100)}%") + + # raw_text = args.prompt if args.prompt else input("Model prompt >>> ") + raw_inputs = [] + for i in range(idx, min(idx+args.batch_size, len(inputs))): + lines = inputs[i] + raw_text = lines.split(' & ')[0] + ' & ' + if args.model_type in ["transfo-xl", "xlnet"]: + # Models with memory likes to have a long prompt for short inputs. 
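+                # prepend the dummy PADDING_TEXT (or a user-supplied --padding_text) as context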
+ raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text + raw_inputs.append(raw_text) + + encoding_inputs = tokenizer.batch_encode_plus(raw_inputs, pad_to_max_length=True, add_special_tokens=False) + context_tokens = torch.LongTensor(encoding_inputs['input_ids']).to(args.device) + max_length = len(context_tokens[0]) + attention_mask = torch.LongTensor(encoding_inputs['attention_mask']).to(args.device) + position_ids = (attention_mask.cumsum(-1) - 1) + position_ids.masked_fill_(attention_mask==0, 0) + + if args.model_type == "ctrl": + if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()): + logger.info("WARNING! You are not starting your generation from a control code so you won't get good results") + out_ids = model.generate( + input_ids=context_tokens, + attention_mask=attention_mask, + position_ids=position_ids, + num_beams=args.num_samples, + num_return_sequences=args.num_samples, + max_length=args.length, + temperature=args.temperature, + do_sample=True, + top_k=args.top_k, + top_p=args.top_p, + repetition_penalty=args.repetition_penalty + ) + out_ids = out_ids.reshape(len(raw_inputs), args.num_samples, -1)[:, :, max_length:].tolist() + for j, out in enumerate(out_ids): + examples = [inputs[j]] + for o in out: + text = tokenizer.decode(o, clean_up_tokenization_spaces=True) + text = text[: text.find(args.stop_token) if args.stop_token else None] + examples.append(text) + output_tests.append(examples) + # break + # if args.prompt: + # break + import json + json.dump(output_tests, open(args.output_file,'w'), indent=2) + return text + +if __name__ == '__main__': + main() diff --git a/convlab2/nlg/scgpt/multiwoz/scgpt.py b/convlab2/nlg/scgpt/multiwoz/scgpt.py new file mode 100644 index 0000000000000000000000000000000000000000..18f599a18bdd463c72e0c308764988b877165949 --- /dev/null +++ b/convlab2/nlg/scgpt/multiwoz/scgpt.py @@ -0,0 +1,91 @@ +import torch +import numpy as np +import os +import zipfile + +from transformers import GPT2LMHeadModel, GPT2Tokenizer +from convlab2.nlg.scgpt.utils import tuple2seq +from convlab2.nlg.scgpt.decode import set_seed, sample_sequence +from convlab2.nlg.nlg import NLG +from convlab2.util.file_util import cached_path + +MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop +DEFAULT_DIRECTORY = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models") +DEFAULT_ARCHIVE_FILE = os.path.join(DEFAULT_DIRECTORY, "nlg-gpt-multiwoz.zip") + +class SCGPT(NLG): + + def __init__(self, + archive_file=DEFAULT_ARCHIVE_FILE, + use_cuda=True, + is_user=False, + model_file='https://convlab.blob.core.windows.net/convlab-2/nlg-gpt-multiwoz.zip'): + model_dir = os.path.dirname(os.path.abspath(__file__)) + if not os.path.isfile(archive_file): + archive_file = cached_path(model_file) + archive = zipfile.ZipFile(archive_file, 'r') + archive.extractall(model_dir) + + self.model_name_or_path = os.path.join(model_dir, 'multiwoz') + self.length = 50 + self.num_samples = 5 + self.temperature = 1.0 + self.repetition_penalty = 1.0 + self.top_k = 50 + self.top_p = 0.9 + self.seed = 42 + self.stop_token = '<|endoftext|>' + + self.device = torch.device("cuda" if torch.cuda.is_available() and use_cuda else "cpu") + set_seed(self.seed, torch.cuda.device_count()) + + model_class, tokenizer_class = GPT2LMHeadModel, GPT2Tokenizer + self.tokenizer = tokenizer_class.from_pretrained(self.model_name_or_path) + self.model = model_class.from_pretrained(self.model_name_or_path) + self.model.to(self.device) + self.model.eval() + + 
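+        # clamp the generation length to the model's maximum context size
+        # (mirrors the logic in multiwoz/run.py)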
+        if self.length < 0 and self.model.config.max_position_embeddings > 0:
+            self.length = self.model.config.max_position_embeddings
+        elif 0 < self.model.config.max_position_embeddings < self.length:
+            self.length = self.model.config.max_position_embeddings  # No generation bigger than model size
+        elif self.length < 0:
+            self.length = MAX_LENGTH  # module-level cap; avoids an infinite loop
+
+    def init_session(self):
+        self.sess_domains = {'Attraction': False,
+                             'Hospital': False,
+                             'Hotel': False,
+                             'Police': False,
+                             'Restaurant': False,
+                             'Taxi': False,
+                             'Train': False}
+
+    def generate(self, meta):
+        raw_text = tuple2seq(meta)
+        domains = set([item[1] for item in meta])
+        for domain in domains:
+            if domain != 'general' and not self.sess_domains[domain]:
+                raw_text = raw_text.replace(domain.lower(), domain.lower() + ' *', 1)
+                self.sess_domains[domain] = True
+        context_tokens = self.tokenizer.encode(raw_text, add_special_tokens=False)
+        out = sample_sequence(
+            model=self.model,
+            context=context_tokens,
+            num_samples=self.num_samples,
+            length=self.length,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            repetition_penalty=self.repetition_penalty,
+            device=self.device,
+        )
+        out = out[:, len(context_tokens):].tolist()
+        # Pick one of the first four samples at random, weighted towards earlier ones.
+        index = np.random.choice([0, 1, 2, 3], p=[0.4, 0.3, 0.2, 0.1])
+        o = out[index]
+        text = self.tokenizer.decode(o, clean_up_tokenization_spaces=True)
+        text = text.split('& ')[-1]  # keep only the part after the dialog-act prefix
+        text = text[: text.find(self.stop_token) if self.stop_token else None]
+        return text
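+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustrative only; the constructor downloads and
+    # unpacks the pretrained archive on first use):
+    nlg = SCGPT(use_cuda=False)
+    nlg.init_session()
+    print(nlg.generate([('Inform', 'Hotel', 'Price', 'cheap')]))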
diff --git a/convlab2/nlg/scgpt/train.py b/convlab2/nlg/scgpt/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..775688bbd63e116da42d5f02ecb78930c823a229
--- /dev/null
+++ b/convlab2/nlg/scgpt/train.py
@@ -0,0 +1,633 @@
+from __future__ import absolute_import, division, print_function
+
+import argparse
+import glob
+import logging
+import os
+import pickle
+import random
+import re
+import shutil
+
+import sys
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler
+from torch.utils.data.distributed import DistributedSampler
+
+try:
+    from torch.utils.tensorboard import SummaryWriter
+except ImportError:
+    from tensorboardX import SummaryWriter
+
+from tqdm import tqdm, trange
+
+from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
+                          BertConfig, BertForMaskedLM, BertTokenizer,
+                          GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
+                          OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
+                          RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
+                          DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
+
+
+logger = logging.getLogger(__name__)
+
+
+MODEL_CLASSES = {
+    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
+    'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
+    'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
+    'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
+    'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
+}
+
+
+class TextDataset(Dataset):
+    def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80):
+        assert os.path.isfile(file_path)
+        directory, filename = os.path.split(file_path)
+        cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
+
+        if os.path.exists(cached_features_file) and not args.overwrite_cache:
+            logger.info("Loading features from cached file %s", cached_features_file)
+            with open(cached_features_file, 'rb') as handle:
+                self.examples = pickle.load(handle)
+        else:
+            logger.info("Creating features from dataset file at %s", directory)
+
+            self.examples = []
+
+            with open(file_path, encoding="utf-8") as f:
+                if args.text_chunk:
+                    text = f.read()
+                    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
+                else:
+                    for line in f:
+                        tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line.strip() + ' eos'))
+                        self.examples.append(tokenized_text)
+
+            if args.text_chunk:
+                for i in range(0, len(tokenized_text)-block_size+1, block_size):  # Truncate in blocks of block_size
+                    self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i+block_size]))
+
+            # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
+            # If your dataset is small, first you should look for a bigger one :-) and second you
+            # can change this behavior by adding (model specific) padding.
+
+            logger.info("Saving features into cached file %s", cached_features_file)
+            with open(cached_features_file, 'wb') as handle:
+                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+    def __len__(self):
+        return len(self.examples)
+
+    def __getitem__(self, item):
+        return torch.tensor(self.examples[item])
+
+class TextSeqDataset(Dataset):
+    def __init__(self, tokenizer, args, file_path='train', block_size=512, max_seq=80, separator=' & '):
+        assert os.path.isfile(file_path)
+        directory, filename = os.path.split(file_path)
+        cached_features_file = os.path.join(directory, args.output_dir.replace(os.sep, '_') + '_cached_lm_' + str(block_size) + '_seqlen_' + str(max_seq) + '_' + filename)
+
+        if os.path.exists(cached_features_file) and not args.overwrite_cache:
+            logger.info("Loading features from cached file %s", cached_features_file)
+            with open(cached_features_file, 'rb') as handle:
+                self.examples, self.masks, self.labels = pickle.load(handle)
+        else:
+            logger.info("Creating features from dataset file at %s", directory)
+            self.examples = []
+            self.labels = []
+            self.masks = []
+            with open(file_path, encoding="utf-8") as f:
+                for line in f:
+                    line = line.strip()
+                    raw_str = line.lower()
+                    code_str = line.lower().split(separator)[0] + separator
+                    code_str = code_str.strip()
+                    if len(raw_str.split()) > max_seq - 1:
+                        raw_str = ' '.join(raw_str.split()[:max_seq - 1])
+                    raw_str += ' ' + tokenizer.eos_token
+                    # code_str_len marks where the dialog-act prefix ends; kept in case
+                    # the prefix should be excluded from the loss.
+                    if args.use_tokenize:
+                        tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_str))
+                        code_str_len = len(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(code_str)))
+                    else:
+                        tokenized_text = tokenizer.convert_tokens_to_ids(raw_str.split())
+                        code_str_len = len(tokenizer.convert_tokens_to_ids(code_str.split()))
+
+                    if len(tokenized_text) > max_seq:
+                        tokenized_text = tokenized_text[:max_seq]  # BPE can overshoot the word-level cap above
+                    label = [-1] * max_seq
+                    label[:len(tokenized_text)] = tokenized_text
+                    mask = [1] * len(tokenized_text) + [0] * (max_seq - len(tokenized_text))
+                    tokenized_text = tokenized_text + [0] * (max_seq - len(tokenized_text))
+
+                    self.examples.append(tokenized_text)
+                    self.masks.append(mask)
+                    self.labels.append(label)
+
+            # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
+            # If your dataset is small, first you should look for a bigger one :-) and second you
+            # can change this behavior by adding (model specific) padding.
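+            # With with_code_loss (the default) the labels below are replaced by the
+            # padded inputs, so the LM loss also covers the dialog-act prefix before
+            # ' & ' (and, as written, the 0-padding as well).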
+            if args.with_code_loss:
+                self.labels = self.examples
+            logger.info("Saving features into cached file %s", cached_features_file)
+            with open(cached_features_file, 'wb') as handle:
+                pickle.dump((self.examples, self.masks, self.labels), handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+    def __len__(self):
+        return len(self.examples)
+
+    def __getitem__(self, item):
+        return torch.tensor(self.examples[item]), torch.tensor(self.masks[item]), torch.tensor(self.labels[item])
+
+
+def load_and_cache_examples(args, tokenizer, evaluate=False):
+    dataset = TextSeqDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size, max_seq=args.max_seq)
+    return dataset
+
+
+def set_seed(args):
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
+
+
+def _rotate_checkpoints(args, checkpoint_prefix, use_mtime=False):
+    if not args.save_total_limit:
+        return
+    if args.save_total_limit <= 0:
+        return
+
+    # Check if we should delete older checkpoint(s)
+    glob_checkpoints = glob.glob(os.path.join(args.output_dir, '{}-*'.format(checkpoint_prefix)))
+    if len(glob_checkpoints) <= args.save_total_limit:
+        return
+
+    ordering_and_checkpoint_path = []
+    for path in glob_checkpoints:
+        if use_mtime:
+            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
+        else:
+            regex_match = re.match('.*{}-([0-9]+)'.format(checkpoint_prefix), path)
+            if regex_match and regex_match.groups():
+                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
+
+    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
+    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
+    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
+    for checkpoint in checkpoints_to_be_deleted:
+        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
+        shutil.rmtree(checkpoint)
+
+
+def mask_tokens(inputs, tokenizer, args):
+    """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
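+    # Selection math: args.mlm_probability (0.15 by default) of the tokens are
+    # chosen; of those, 80% become [MASK], half of the remaining 20% (10% overall)
+    # become a random token, and the final 10% are left unchanged.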
""" + labels = inputs.clone() + # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) + probability_matrix = torch.full(labels.shape, args.mlm_probability) + special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()] + probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) + masked_indices = torch.bernoulli(probability_matrix).bool() + labels[~masked_indices] = -1 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices + inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced + random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long) + inputs[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input tokens unchanged + return inputs, labels + + +def train(args, train_dataset, model, tokenizer): + """ Train the model """ + if args.local_rank in [-1, 0]: + tb_writer = SummaryWriter() + + args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) + train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) + + if args.max_steps > 0: + t_total = args.max_steps + args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 + else: + t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs + + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ['bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, + {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) + if args.fp16: + try: + from apex import amp + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) + model.resize_token_embeddings(len(tokenizer)) + # multi-gpu training (should be after apex fp16 initialization) + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + + # Distributed training (should be after apex fp16 initialization) + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True) + + # Train! 
+ logger.info("***** Running training *****") + logger.info(" Num examples = %d", len(train_dataset)) + logger.info(" Num Epochs = %d", args.num_train_epochs) + logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) + logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", t_total) + + global_step = 0 + tr_loss, logging_loss = 0.0, 0.0 + model.zero_grad() + train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) + set_seed(args) # Added here for reproducibility (even between python 2 and 3) + for e in train_iterator: + + # epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) + for step, batch in enumerate(train_dataloader): + # inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) + logger.info(f" PROGRESS: {float(global_step)/t_total*100}%") + inputs, masks, labels = batch + # import pdb + # pdb.set_trace() + inputs = inputs.to(args.device) + # masks = masks.to(args.device) + labels = labels.to(args.device) + + model.train() + outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels) + loss = outputs[0] # model outputs are always tuple in transformers (see doc) + + if args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + if args.fp16: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + tr_loss += loss.item() + if (step + 1) % args.gradient_accumulation_steps == 0: + if args.fp16: + torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + optimizer.step() + scheduler.step() # Update learning rate schedule + model.zero_grad() + global_step += 1 + + if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: + # Log metrics + if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well + results = evaluate(args, model, tokenizer) + for key, value in results.items(): + tb_writer.add_scalar('eval_{}'.format(key), value, global_step) + tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) + tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step) + logger.info(f" EVALERR: {(tr_loss - logging_loss)/float(args.logging_steps)}") + logging_loss = tr_loss + + if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: + checkpoint_prefix = 'checkpoint' + # Save model checkpoint + output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step)) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + torch.save(args, os.path.join(output_dir, 'training_args.bin')) + logger.info("Saving model checkpoint to %s", output_dir) + + 
+                    _rotate_checkpoints(args, checkpoint_prefix)
+
+        if args.max_steps > 0 and global_step > args.max_steps:
+            train_iterator.close()
+            break
+
+    if args.local_rank in [-1, 0]:
+        tb_writer.close()
+
+    return global_step, tr_loss / global_step
+
+
+def evaluate(args, model, tokenizer, prefix=""):
+    # Evaluate LM perplexity on the held-out set.
+    eval_output_dir = args.output_dir
+
+    eval_dataset = load_and_cache_examples(args, tokenizer, evaluate=True)
+
+    if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
+        os.makedirs(eval_output_dir)
+
+    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
+    # Note that DistributedSampler samples randomly
+    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
+    eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+
+    # multi-gpu evaluate
+    if args.n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
+    # Eval!
+    logger.info("***** Running evaluation {} *****".format(prefix))
+    logger.info("  Num examples = %d", len(eval_dataset))
+    logger.info("  Batch size = %d", args.eval_batch_size)
+    eval_loss = 0.0
+    nb_eval_steps = 0
+    model.eval()
+
+    for batch in tqdm(eval_dataloader, desc="Evaluating"):
+        inputs, masks, labels = batch
+        inputs = inputs.to(args.device)
+        masks = masks.to(args.device)
+        labels = labels.to(args.device)
+
+        with torch.no_grad():
+            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
+            lm_loss = outputs[0]
+            eval_loss += lm_loss.mean().item()
+        nb_eval_steps += 1
+
+    eval_loss = eval_loss / nb_eval_steps
+    perplexity = torch.exp(torch.tensor(eval_loss))
+
+    result = {
+        "perplexity": perplexity
+    }
+
+    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
+    with open(output_eval_file, "w") as writer:
+        logger.info("***** Eval results {} *****".format(prefix))
+        for key in sorted(result.keys()):
+            logger.info("  %s = %s", key, str(result[key]))
+            writer.write("%s = %s\n" % (key, str(result[key])))
+
+    return result
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    ## Required parameters
+    parser.add_argument("--train_data_file", default=None, type=str, required=True,
+                        help="The input training data file (a text file).")
+    parser.add_argument("--output_dir", default=None, type=str, required=True,
+                        help="The output directory where the model predictions and checkpoints will be written.")
+
+    ## Other parameters
+    parser.add_argument("--eval_data_file", default=None, type=str,
+                        help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
+
+    parser.add_argument("--model_type", default="gpt2", type=str,
+                        help="The model architecture to be fine-tuned.")
+    parser.add_argument("--model_name_or_path", default="gpt2", type=str,
+                        help="The model checkpoint for weights initialization.")
+
+    parser.add_argument("--mlm", action='store_true',
+                        help="Train with masked-language modeling loss instead of language modeling.")
+    parser.add_argument("--mlm_probability", type=float, default=0.15,
+                        help="Ratio of tokens to mask for masked language modeling loss")
+
default="", type=str, + help="Optional pretrained config name or path if not the same as model_name_or_path") + parser.add_argument("--tokenizer_name", default="", type=str, + help="Optional pretrained tokenizer name or path if not the same as model_name_or_path") + parser.add_argument("--cache_dir", default="", type=str, + help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)") + parser.add_argument("--block_size", default=80, type=int, + help="Optional input sequence length after tokenization." + "The training dataset will be truncated in block of this size for training." + "Default to the model max input length for single sentence inputs (take into account special tokens).") + parser.add_argument("--do_train", action='store_true', + help="Whether to run training.") + parser.add_argument("--do_eval", action='store_true', + help="Whether to run eval on the dev set.") + parser.add_argument("--evaluate_during_training", action='store_true', + help="Run evaluation during training at each logging step.") + parser.add_argument("--do_lower_case", action='store_true', + help="Set this flag if you are using an uncased model.") + + parser.add_argument("--per_gpu_train_batch_size", default=1, type=int, + help="Batch size per GPU/CPU for training.") + parser.add_argument("--per_gpu_eval_batch_size", default=1, type=int, + help="Batch size per GPU/CPU for evaluation.") + parser.add_argument('--gradient_accumulation_steps', type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument("--learning_rate", default=1e-5, type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--weight_decay", default=0.0, type=float, + help="Weight deay if we apply some.") + parser.add_argument("--adam_epsilon", default=1e-8, type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument("--max_grad_norm", default=1.0, type=float, + help="Max gradient norm.") + parser.add_argument("--num_train_epochs", default=5.0, type=float, + help="Total number of training epochs to perform.") + parser.add_argument("--max_steps", default=-1, type=int, + help="If > 0: set total number of training steps to perform. 
+    parser.add_argument("--warmup_steps", default=0, type=int,
+                        help="Linear warmup over warmup_steps.")
+
+    parser.add_argument('--logging_steps', type=int, default=100,
+                        help="Log every X updates steps.")
+    parser.add_argument('--save_steps', type=int, default=5000,
+                        help="Save checkpoint every X updates steps.")
+    parser.add_argument('--save_total_limit', type=int, default=None,
+                        help='Limit the total number of checkpoints; the oldest checkpoints in output_dir are deleted. No deletion by default.')
+    parser.add_argument("--eval_all_checkpoints", action='store_true',
+                        help="Evaluate all checkpoints starting with the same prefix as model_name_or_path and ending with their step number")
+    parser.add_argument("--no_cuda", action='store_true',
+                        help="Avoid using CUDA when available")
+    parser.add_argument('--overwrite_output_dir', action='store_true',
+                        help="Overwrite the content of the output directory")
+    parser.add_argument('--overwrite_cache', action='store_true',
+                        help="Overwrite the cached training and evaluation sets")
+    parser.add_argument('--seed', type=int, default=42,
+                        help="random seed for initialization")
+
+    parser.add_argument('--fp16', action='store_true',
+                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
+    parser.add_argument('--fp16_opt_level', type=str, default='O1',
+                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
+                             "See details at https://nvidia.github.io/apex/amp.html")
+    parser.add_argument("--local_rank", type=int, default=-1,
+                        help="For distributed training: local_rank")
+    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
+    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
+    parser.add_argument('--text_chunk', action='store_true',
+                        help="Treat the corpus as one stream and split it into fixed-size blocks instead of one example per line.")
+    parser.add_argument('--use_reverse', action='store_true', help="")
+    # Note: argparse's type=bool treats any non-empty string as True; omit the flag to keep the default.
+    parser.add_argument('--with_code_loss', type=bool, default=True,
+                        help="Include the dialog-act prefix (before ' & ') in the LM loss.")
+    parser.add_argument('--use_tokenize', action='store_true',
+                        help="Tokenize with the model tokenizer instead of whitespace splitting.")
+
+    parser.add_argument("--max_seq", default=80, type=int,
+                        help="Maximum sequence length (in tokens) for TextSeqDataset examples.")
+
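+    # Illustrative invocation (file names are placeholders, not shipped defaults):
+    #   python train.py --model_type gpt2 --model_name_or_path gpt2 \
+    #       --train_data_file data/train.txt --eval_data_file data/val.txt \
+    #       --output_dir output/scgpt --do_train --do_eval --use_tokenize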
+    args = parser.parse_args()
+
+    if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm:
+        raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
+                         "flag (masked language modeling).")
+    if args.eval_data_file is None and args.do_eval:
+        raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
+                         "or remove the --do_eval argument.")
+
+    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
+        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
+
+    # Setup distant debugging if needed
+    if args.server_ip and args.server_port:
+        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
+        import ptvsd
+        print("Waiting for debugger attach")
+        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
+        ptvsd.wait_for_attach()
+
+    # Setup CUDA, GPU & distributed training
+    if args.local_rank == -1 or args.no_cuda:
+        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+        args.n_gpu = torch.cuda.device_count()
+    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+        torch.cuda.set_device(args.local_rank)
+        device = torch.device("cuda", args.local_rank)
+        torch.distributed.init_process_group(backend='nccl')
+        args.n_gpu = 1
+    args.device = device
+
+    # Setup logging
+    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                        datefmt='%m/%d/%Y %H:%M:%S',
+                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
+    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
+                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
+
+    # Set seed
+    set_seed(args)
+
+    # Load pretrained model and tokenizer
+    if args.local_rank not in [-1, 0]:
+        torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training downloads model & vocab
+
+    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
+                                          cache_dir=args.cache_dir if args.cache_dir else None)
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
+                                                do_lower_case=args.do_lower_case,
+                                                cache_dir=args.cache_dir if args.cache_dir else None)
+
+    if args.block_size <= 0:
+        args.block_size = tokenizer.max_len_single_sentence  # Our input block size will be the max possible for the model
+    args.block_size = min(args.block_size, tokenizer.max_len_single_sentence)
+    model = model_class.from_pretrained(args.model_name_or_path,
+                                        from_tf=bool('.ckpt' in args.model_name_or_path),
+                                        config=config,
+                                        cache_dir=args.cache_dir if args.cache_dir else None)
+    model.to(args.device)
+
+    if args.local_rank == 0:
+        torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training downloads model & vocab
+
+    logger.info("Training/evaluation parameters %s", args)
+
+    # Training
+    if args.do_train:
+        if args.local_rank not in [-1, 0]:
+            torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training processes the dataset; the others will use the cache
+
+        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
+
+        if args.local_rank == 0:
+            torch.distributed.barrier()
+
+        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
+        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
+
+
+    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
+    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
+        # Create output directory if needed
+        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
+            os.makedirs(args.output_dir)
+        logger.info("Saving model checkpoint to %s", args.output_dir)
+        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+        model_to_save.save_pretrained(args.output_dir)
+        tokenizer.save_pretrained(args.output_dir)
+
+        # Good practice: save your training arguments together with the trained model
+        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
+
+        # Load a trained model and vocabulary that you have fine-tuned
+        model = model_class.from_pretrained(args.output_dir)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        model.to(args.device)
+
+
+    # Evaluation
+    results = {}
+    if args.do_eval and args.local_rank in [-1, 0]:
+        checkpoints = [args.output_dir]
+        if args.eval_all_checkpoints:
+            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
+            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
+        logger.info("Evaluate the following checkpoints: %s", checkpoints)
+        for checkpoint in checkpoints:
+            global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
+            prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
+
+            model = model_class.from_pretrained(checkpoint)
+            model.to(args.device)
+            result = evaluate(args, model, tokenizer, prefix=prefix)
+            result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
+            results.update(result)
+
+    return results
+
+
+if __name__ == "__main__":
+    main()
diff --git a/convlab2/nlg/scgpt/utils.py b/convlab2/nlg/scgpt/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fefc166f096c307590fd5b0478c4db1cf551e7f
--- /dev/null
+++ b/convlab2/nlg/scgpt/utils.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Mar 24 18:34:55 2020
+
+@author: truthless
+"""
+
+def tuple2dict(t):
+    '''
+    tuple: [(intent, domain, slot, value)]
+    dict: [domain: { intent: [slot, value] }]
+    '''
+    d = {}
+    for intent, domain, slot, value in t:
+        if domain not in d:
+            d[domain] = {}
+        if intent not in d[domain]:
+            d[domain][intent] = []
+        if slot == 'none' or slot is None:
+            continue
+        d[domain][intent].append([slot, value])
+    return d
+
+def dict2dict(D):
+    '''
+    dict: [domain-intent: [slot, value]]
+    dict: [domain: { intent: [slot, value] }]
+    '''
+    d = {}
+    for domint in D:
+        domain, intent = domint.split('-')
+        if domain not in d:
+            d[domain] = {}
+        if intent not in d[domain]:
+            d[domain][intent] = []
+        for slot, value in D[domint]:
+            if slot == 'none' or slot is None:
+                continue
+            d[domain][intent].append([slot, value])
+    return d
+
+def dict2seq(d):
+    '''
+    dict: [domain: { intent: [slot, value] }]
+    seq: [domain { intent ( slot = value ; ) @ } | ]
+    '''
+    s = ''
+    first_domain = True
+    first_intent = True
+    first_slot = True
+    for domain in d:
+        if not first_domain:
+            s += ' | '
+        s += domain
+        s += ' { '
+        first_domain = False
+        first_intent = True
+        for intent in d[domain]:
+            if not first_intent:
+                s += ' @ '
+            s += intent
+            s += ' ( '
+            first_intent = False
+            first_slot = True
+            for slot, value in d[domain][intent]:
+                if not first_slot:
+                    s += ' ; '
+                s += slot
+                if value:
+                    s += ' = '
+                    s += value
+                first_slot = False
+            s += ' )'
+        s += ' }'
+    return s.lower()
+
+def tuple2seq(t):
+    d = tuple2dict(t)
+    s = dict2seq(d)
+    return s
+
+if __name__ == '__main__':
+    da_tuple = [('Inform', 'Booking', 'none', 'none'), ('Inform', 'Hotel', 'Price', 'cheap'), ('Inform', 'Hotel', 'Choice', '1'), ('Inform', 'Hotel', 'Parking', 'none')]
+    da_dict = tuple2dict(da_tuple)
+    print(da_dict)
+    da_seq = dict2seq(da_dict)
+    print(da_seq)
+
+    da_tuple = [('Request', 'Hotel', 'Address', '?'), ('Request', 'Hotel', 'Area', '?'), ('Inform', 'Attraction', 'Area', 'center'), ('Inform', 'Hotel', 'Price', 'cheap')]
+    da_dict = tuple2dict(da_tuple)
+    print(da_dict)
+    da_seq = dict2seq(da_dict)
+    print(da_seq)
+
+    D = {'Hotel-Inform': [['Price', 'cheap'], ['Type', 'hotel']]}
+    da_dict = dict2dict(D)
+    print(da_dict)
+
diff --git a/convlab2/util/multiwoz/paraphrase_span_detection.py b/convlab2/util/multiwoz/paraphrase_span_detection.py
new file mode 100644
index 0000000000000000000000000000000000000000..38b9bde8df86c040eec3a429faccabd69676fb14
--- /dev/null
+++ b/convlab2/util/multiwoz/paraphrase_span_detection.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Aug 11 17:49:53 2020
+
+@author: truthless
+"""
+
+import spacy
+from fuzzywuzzy import fuzz
+
+digit2word = {
+    '0': 'zero', '1': 'one', '2': 'two', '3': 'three', '4': 'four', '5': 'five',
+    '6': 'six', '7': 'seven', '8': 'eight', '9': 'nine', '10': 'ten', '11': 'eleven',
+    '12': 'twelve'
+}
+word2digit = {v: k for k, v in digit2word.items()}
+
+nlp = spacy.load('en_core_web_sm')
+threshold = 60  # minimum fuzz.ratio score (0-100) for a span to count as a match
+
+def digit_normalize(utt_list):
+    # Map number words back to digits in place ('one' -> '1') so values and
+    # utterances compare consistently.
+    for i, text in enumerate(utt_list):
+        if text in word2digit:
+            utt_list[i] = word2digit[text]
+    return utt_list
+
+def phrase_idx_utt(value_list, utt_list, constraint=None):
+    # `constraint` collects token indices already claimed by earlier values;
+    # a mutable default argument would leak that state across calls, hence None.
+    if constraint is None:
+        constraint = []
+    value_list = digit_normalize(value_list)
+    utt_list = digit_normalize(utt_list)
+    candidates = []
+    l = len(value_list)
+    for i in [l, l-1, l+1]:  # try windows of the value length and one token off
+        if i == 0:
+            continue
+        for j in range(len(utt_list)-i+1):
+            if constraint and j <= constraint[0] and constraint[0] < j+i:
+                # Skip windows overlapping an already-claimed position, rotating
+                # the constraint list once the window start reaches it.
+                if j == constraint[0]:
+                    constraint.append(constraint.pop(0))
+                continue
+            score = fuzz.ratio(' '.join(utt_list[j:j+i]), ' '.join(value_list))
+            if score > threshold:
+                candidates.append((score, j, j+i-1))
+    return sorted(candidates, key=lambda x: x[0], reverse=True)[0][1:] if candidates else None
+
+def preprocess(utt, da):
+    '''
+    utt: str
+    da: dict {'domain-intent': [slot, value]}
+    '''
+    with nlp.disable_pipes('tagger', 'parser'):
+        '''
+        get tokens, recover the paraphrased entity
+        '''
+        tokens = [token.text for token in nlp(utt)]
+        for key, pair in da.items():
+            constraint = []
+            intent = key.split('-')[1].lower()
+            if intent not in ['inform']:
+                continue
+            for slot, value in pair:
+                if slot.lower() in ['name', 'dest', 'depart']:
+                    value_tokens = [token.text for token in nlp(value)]
+                    span = phrase_idx_utt(value_tokens, tokens, constraint)
+                    if span is not None:
+                        for i in range(span[0], span[1]+1):
+                            constraint.append(i)
+                        tokens[span[0]:span[1]+1] = value_tokens
+
+        '''
+        get labels, tag or slot or none
+        '''
+        labels = dict()
+        for key, pair in da.items():
+            constraint = []
+            intent = key.split('-')[1].lower()
+            if intent in ["request"]:
+                labels[key] = []
+            elif intent in ['inform']:
+                labels[key] = ["O"] * len(tokens)
+            else:
+                labels[key] = None
+            for slot, value in pair:
+                if intent in ["request"]:
+                    labels[key].append(slot)
+                elif intent in ['inform']:
+                    value_tokens = [token.text for token in nlp(value)]
+                    span = phrase_idx_utt(value_tokens, tokens, constraint)
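+                    # When the value is found, claim its token span and tag it
+                    # BIO-style: B-<slot> on the first token, I-<slot> on the rest.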
+                    if span is not None:
+                        for i in range(span[0], span[1]+1):
+                            constraint.append(i)
+                        labels[key][span[0]] = "B-" + slot
+                        for i in range(span[0]+1, span[1]+1):
+                            labels[key][i] = "I-" + slot
+    return tokens, labels
diff --git a/setup.py b/setup.py
index 4744b317df5a7511952b00c8e923613bf607b318..321a111f8971eef619c14642320e145d108dbcea 100755
--- a/setup.py
+++ b/setup.py
@@ -59,7 +59,10 @@ setup(
         'pyyaml',
         'fuzzywuzzy',
         'python-Levenshtein',
-        'json_lines'
+        'json_lines',
+        'gtts',
+        'DeepSpeech',
+        'pydub'
     ],
     extras_require={
         'develop': [