diff --git a/convlab2/base_models/t5/create_data.py b/convlab2/base_models/t5/create_data.py
index d4b99d9a7ebfc294433f06e795aff2620f0c1fb5..71fea81e73969f74c5e962445d9143cf38e722d0 100644
--- a/convlab2/base_models/t5/create_data.py
+++ b/convlab2/base_models/t5/create_data.py
@@ -15,9 +15,10 @@ def create_rg_data(dataset, data_dir, args):
     for data_split in data_splits:
         data = []
         for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
-            context = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']])
-            response = f"{sample['speaker']}: {sample['utterance']}"
-            data.append(json.dumps({'context': context, 'response': response}, ensure_ascii=False)+'\n')
+            if len(sample['context']) == 0:
+                continue
+            context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[f'{sample["speaker"]}: '])
+            data.append(json.dumps({'context': context, 'response': sample['utterance']}, ensure_ascii=False)+'\n')
 
         file_name = os.path.join(data_dir, f"{data_split}.json")
         with open(file_name, "w", encoding='utf-8') as f:
@@ -34,7 +35,7 @@ def create_nlu_data(dataset, data_dir, args):
         for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
             response = f"{sample['speaker']}: {sample['utterance']}"
             if args.context_window_size>0:
-                context = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[response])
+                context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[response])
             else:
                 context = response
             dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts'])
@@ -56,7 +57,7 @@ def create_dst_data(dataset, data_dir, args):
         for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
             response = f"{sample['speaker']}: {sample['utterance']}"
             if args.context_window_size>0:
-                context = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[response])
+                context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[response])
             else:
                 context = response
             state_seq = serialize_dialogue_state(sample['state'])
@@ -67,6 +68,28 @@ def create_dst_data(dataset, data_dir, args):
         with open(file_name, "w", encoding='utf-8') as f:
             f.writelines(data)
 
+def create_nlg_data(dataset, data_dir, args):
+    data_by_split = load_nlu_data(dataset, speaker=args.speaker, use_context=args.context_window_size>0, context_window_size=args.context_window_size)
+    data_dir = os.path.join(data_dir, args.speaker, f'context_{args.context_window_size}')
+    os.makedirs(data_dir, exist_ok=True)
+
+    data_splits = data_by_split.keys()
+    for data_split in data_splits:
+        data = []
+        for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
+            dialogue_acts_seq = serialize_dialogue_acts(sample['dialogue_acts'])
+            if args.context_window_size>0:
+                context = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]+[f'{sample["speaker"]}: '])
+                context = f'{dialogue_acts_seq}\n\n{context}'
+            else:
+                context = f'{dialogue_acts_seq}\n\n{sample["speaker"]}: '
+            assert equal_da_seq(sample['dialogue_acts'], dialogue_acts_seq), print(sample['dialogue_acts'], dialogue_acts_seq, deserialize_dialogue_acts(dialogue_acts_seq))
+            data.append(json.dumps({'context+da': context, 'response': sample['utterance']}, ensure_ascii=False)+'\n')
+
+        file_name = os.path.join(data_dir, f"{data_split}.json")
+        with open(file_name, "w", encoding='utf-8') as f:
+            f.writelines(data)
+
 def create_goal2dialogue_data(dataset, data_dir, args):
     data_by_split = dataset
     os.makedirs(data_dir, exist_ok=True)
@@ -76,7 +99,7 @@ def create_goal2dialogue_data(dataset, data_dir, args):
         data = []
         for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
             goal = re.sub(r'<.*?>', '', sample['goal']['description'])
-            dialogue = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['turns']])
+            dialogue = '\n'.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['turns']])
             data.append(json.dumps({'goal': goal, 'dialogue': dialogue}, ensure_ascii=False)+'\n')
 
         file_name = os.path.join(data_dir, f"{data_split}.json")
@@ -87,7 +110,7 @@ def create_goal2dialogue_data(dataset, data_dir, args):
 if __name__ == '__main__':
     from argparse import ArgumentParser
     parser = ArgumentParser(description="create data for seq2seq training")
-    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['rg', 'nlu', 'dst', 'goal2dialogue'], help='names of tasks')
+    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['rg', 'nlu', 'dst', 'nlg', 'goal2dialogue'], help='names of tasks')
     parser.add_argument('--datasets', '-d', metavar='dataset_name', nargs='*', help='names of unified datasets')
     parser.add_argument('--speaker', '-s', type=str, choices=['user', 'system', 'all'], help='speaker(s)')
     parser.add_argument('--context_window_size', '-c', type=int, default=0, help='how many contextual utterances are considered')
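For reference, a minimal sketch (not part of the patch) of the JSON line the new create_nlg_data emits; the utterances and the serialized dialogue-act string below are hypothetical stand-ins for what load_nlu_data and serialize_dialogue_acts produce in the real script:

import json

# Hypothetical sample in the shape create_nlg_data consumes (context_window_size > 0).
sample = {
    'speaker': 'system',
    'utterance': 'There are several options in the centre. Any price range?',
    'context': [
        {'speaker': 'user', 'utterance': 'I need a restaurant in the centre.'},
    ],
}
dialogue_acts_seq = '<serialized dialogue acts>'  # stand-in for serialize_dialogue_acts(sample['dialogue_acts'])

# Dialogue acts, a blank line, the history joined by '\n', and an empty "system: "
# prompt that the model is trained to complete with the response.
context = '\n'.join(
    [f"{turn['speaker']}: {turn['utterance']}" for turn in sample['context']]
    + [f'{sample["speaker"]}: ']
)
context = f'{dialogue_acts_seq}\n\n{context}'

print(json.dumps({'context+da': context, 'response': sample['utterance']}, ensure_ascii=False))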
diff --git a/convlab2/base_models/t5/run_seq2seq.py b/convlab2/base_models/t5/run_seq2seq.py
index dace9713d540b7fe2aa1c552132cc4c54d698989..2f0f5481243c2f78eac4d352786482508f70e617 100644
--- a/convlab2/base_models/t5/run_seq2seq.py
+++ b/convlab2/base_models/t5/run_seq2seq.py
@@ -212,8 +212,8 @@ class DataTrainingArguments:
             "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
         },
     )
-    source_prefix_filepath: Optional[str] = field(
-        default=None, metadata={"help": "A file whose first line is the prefix to add before every source text (useful for T5 models)."}
+    source_prefix: Optional[str] = field(
+        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
     )
 
     def __post_init__(self):
@@ -271,7 +271,7 @@ def main():
     )
     logger.info(f"Training/evaluation parameters {training_args}")
 
-    if data_args.source_prefix_filepath is None and model_args.model_name_or_path in [
+    if data_args.source_prefix is None and model_args.model_name_or_path in [
         "t5-small",
         "t5-base",
         "t5-large",
@@ -280,7 +280,7 @@ def main():
     ]:
         logger.warning(
             "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
-            "`--source_prefix_filepath 'path_to_prefix_file' ` whose first line is the source prefix"
+            "`--source_prefix 'summarize: ' `"
         )
 
     # Detecting last checkpoint.
@@ -386,10 +386,7 @@ def main():
                "resize the model's position encodings by passing `--resize_position_embeddings`."
             )
 
-    if data_args.source_prefix_filepath is not None:
-        prefix = open(data_args.source_prefix_filepath, 'r', encoding='utf-8').readline().strip('\n')
-    else:
-        prefix = ""
+    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
     logger.info(f'source prefix: "{prefix}"')
 
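A small sketch of the flag change above (values are hypothetical): the prefix is now taken verbatim from --source_prefix rather than read from the first line of a file, and, assuming the usual Hugging Face seq2seq preprocessing, it is prepended to every source text:

from dataclasses import dataclass
from typing import Optional

@dataclass
class DataTrainingArguments:  # trimmed to the single field touched by this patch
    source_prefix: Optional[str] = None

data_args = DataTrainingArguments(source_prefix='nlg: ')  # e.g. passed as --source_prefix 'nlg: '
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

# Assumed preprocessing step: prepend the prefix to each source sequence.
sources = ['user: I need a cheap hotel.\nsystem: ']
inputs = [prefix + s for s in sources]
print(inputs[0])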
diff --git a/convlab2/util/unified_datasets_util.py b/convlab2/util/unified_datasets_util.py
index 46ba13d79c2bde1693454caad7c95596e7e63d81..014f5306a4f41637d9dab92d1fcb00a4d41a3463 100644
--- a/convlab2/util/unified_datasets_util.py
+++ b/convlab2/util/unified_datasets_util.py
@@ -124,7 +124,7 @@ def load_unified_data(
                 sample['domains'] = dialogue['domains']
                 if terminated:
                     sample['terminated'] = turn['utt_idx'] == len(dialogue['turns']) - 1
-                if speaker == 'system':
+                if speaker == 'system' and 'booked' in turn:
                     sample['booked'] = turn['booked']
                 data_by_split[data_split].append(sample)
         if not split_to_turn:
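A toy illustration of the guard added to load_unified_data; attach_booked is a hypothetical helper that mirrors the patched lines, and the booking payload is made up. System turns that carry no 'booked' field (as in some unified datasets) no longer raise a KeyError:

def attach_booked(sample, turn, speaker):
    # Mirrors the patched condition: only copy 'booked' when the turn actually has it.
    if speaker == 'system' and 'booked' in turn:
        sample['booked'] = turn['booked']
    return sample

turn_with_booking = {'speaker': 'system', 'utterance': 'Booked!', 'booked': {'hotel': 'abc123'}}
turn_without_booking = {'speaker': 'system', 'utterance': 'Anything else?'}

print(attach_booked({}, turn_with_booking, 'system'))    # {'booked': {'hotel': 'abc123'}}
print(attach_booked({}, turn_without_booking, 'system'))  # {} -- previously a KeyError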