create_data.py
import os
import json
from tqdm import tqdm
from convlab2.util import load_dataset
    
    
def create_lm_data(dataset, data_dir, args):
    """Serialize each dialogue in `dataset` to one JSON object per line."""
    data_by_split = dataset
    os.makedirs(data_dir, exist_ok=True)

    data_splits = data_by_split.keys()
    for data_split in data_splits:
        data = []
        for sample in tqdm(data_by_split[data_split], desc=f'{data_split} sample', leave=False):
            if args.model_type == 'dialogpt':
                # DialoGPT-style input: turns joined by its end-of-text token.
                dialogue = ' <|endoftext|> '.join([turn['utterance'] for turn in sample['turns']]) + ' <|endoftext|>'
            else:
                # Generic LM input: speaker-prefixed turns separated by spaces.
                dialogue = ' '.join([f"{turn['speaker']}: {turn['utterance']}" for turn in sample['turns']])
            data.append(json.dumps({'dialogue': dialogue}, ensure_ascii=False) + '\n')

        # One JSON object per line (JSON Lines), despite the .json extension.
        file_name = os.path.join(data_dir, f"{data_split}.json")
        with open(file_name, "w", encoding='utf-8') as f:
            f.writelines(data)
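
# Illustrative output lines (the utterance text below is hypothetical, not
# taken from any real dataset):
#   with --model_type dialogpt:
#     {"dialogue": "I need a hotel. <|endoftext|> Sure, which area? <|endoftext|>"}
#   with any other model type:
#     {"dialogue": "user: I need a hotel. system: Sure, which area?"}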
    
    
if __name__ == '__main__':
    from argparse import ArgumentParser
    parser = ArgumentParser(description="create data for seq2seq training")
    parser.add_argument('--tasks', '-t', metavar='task_name', nargs='*', choices=['lm'], help='names of tasks')
    parser.add_argument('--datasets', '-d', metavar='dataset_name', nargs='*', help='names of unified datasets')
    parser.add_argument('--model_type', '-m', metavar='model_type', help='type of the language model: gpt, dialogpt, ...')
    args = parser.parse_args()
    print(args)
    for dataset_name in tqdm(args.datasets, desc='datasets'):
        dataset = load_dataset(dataset_name)
        for task_name in tqdm(args.tasks, desc='tasks', leave=False):
            data_dir = os.path.join('data', task_name, dataset_name)
            # Look up the task-specific builder by name, e.g. create_lm_data.
            globals()[f"create_{task_name}_data"](dataset, data_dir, args)