diff --git a/run_dst.py b/run_dst.py
index c52cd64d2bcf9a6e4c3d5bdbde678aae92b0f4a9..40ec53dc70130416bb332ef64af14b757cd575a9 100644
--- a/run_dst.py
+++ b/run_dst.py
@@ -107,19 +107,18 @@ def train(args, train_dataset, features, model, tokenizer, processor, continue_f
     ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)
-    if args.fp16:
-        try:
-            from apex import amp
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
-
-    # multi-gpu training (should be after apex fp16 initialization)
+    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+    if 'cuda' in args.device.type:
+        autocast = torch.cuda.amp.autocast(enabled=args.fp16)
+    else:
+        autocast = torch.cpu.amp.autocast(enabled=args.fp16)
+
+    # multi-gpu training
     model_single_gpu = model
     if args.n_gpu > 1:
         model = torch.nn.DataParallel(model_single_gpu)
 
-    # Distributed training (should be after apex fp16 initialization)
+    # Distributed training
     if args.local_rank != -1:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                           output_device=args.local_rank,
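
Reviewer note on the hunk above: a minimal sketch of how the device-dependent autocast and GradScaler setup is meant to behave, assuming `args.device` is a `torch.device` and `args.fp16` is the existing command-line flag; the helper name `build_amp` is illustrative only and does not exist in either script.

    import torch

    def build_amp(device: torch.device, fp16: bool):
        # The scaler only does real work for CUDA fp16; with enabled=False it
        # passes losses and optimizer steps through unchanged, so fp32 runs
        # behave as before.
        scaler = torch.cuda.amp.GradScaler(enabled=fp16 and device.type == "cuda")
        if device.type == "cuda":
            autocast = torch.cuda.amp.autocast(enabled=fp16)
        else:
            # CPU autocast is available in recent PyTorch releases and
            # defaults to bfloat16 rather than float16.
            autocast = torch.cpu.amp.autocast(enabled=fp16)
        return autocast, scaler
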
@@ -170,7 +169,8 @@ def train(args, train_dataset, features, model, tokenizer, processor, continue_f
                       'refer_id':        batch[6],
                       'diag_state':      batch[7],
                       'class_label_id':  batch[8]}
-            outputs = model(**inputs)
+            with autocast:
+                outputs = model(**inputs)
             loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
 
             if args.n_gpu > 1:
@@ -178,17 +178,13 @@ def train(args, train_dataset, features, model, tokenizer, processor, continue_f
             if args.gradient_accumulation_steps > 1:
                 loss = loss / args.gradient_accumulation_steps
 
-            if args.fp16:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
-                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
-            else:
-                loss.backward()
-                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
-
+            scaler.scale(loss).backward()
             tr_loss += loss.item()
             if (step + 1) % args.gradient_accumulation_steps == 0:
-                optimizer.step()
+                scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+                scaler.step(optimizer)
+                scaler.update()
                 scheduler.step()  # Update learning rate schedule
                 model.zero_grad()
                 global_step += 1
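
Reviewer note on the hunk above: a minimal sketch of the native AMP training step with gradient accumulation that these lines follow; `model`, `optimizer`, `scheduler`, `loader`, and `args` are stand-ins for the script's objects, so this is an illustration rather than a drop-in replacement.

    import torch

    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
    for step, batch in enumerate(loader):
        with torch.cuda.amp.autocast(enabled=args.fp16):
            loss = model(**batch)[0]
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps
        # backward on every micro-batch so accumulated gradients are kept
        scaler.scale(loss).backward()
        if (step + 1) % args.gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)   # unscale so clipping sees true gradient norms
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            scaler.step(optimizer)       # skipped automatically on inf/nan gradients
            scaler.update()
            scheduler.step()
            model.zero_grad()
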
@@ -627,10 +623,7 @@ def main():
     parser.add_argument("--local_rank", type=int, default=-1,
                         help="local_rank for distributed training on gpus")
     parser.add_argument('--fp16', action='store_true',
-                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
-    parser.add_argument('--fp16_opt_level', type=str, default='O1',
-                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
-                             "See details at https://nvidia.github.io/apex/amp.html")
+                        help="Whether to use 16-bit (mixed) precision instead of 32-bit")
     parser.add_argument('--local_files_only', action='store_true',
                         help="Whether to only load local model files (useful when working offline).")
 
diff --git a/run_dst_mtl.py b/run_dst_mtl.py
index 125d7dbbc53b6aeef5c5d16dbb9ba771b0f2cd6b..6063a6a6c56ac9dfeef884665a3668d1e0256ce7 100644
--- a/run_dst_mtl.py
+++ b/run_dst_mtl.py
@@ -97,19 +97,18 @@ def train_mtl(args, train_dataset, aux_dataset, aux_task_def, features, model, t
     ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)
-    if args.fp16:
-        try:
-            from apex import amp
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
-
-    # multi-gpu training (should be after apex fp16 initialization)
+    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+    if 'cuda' in args.device.type:
+        autocast = torch.cuda.amp.autocast(enabled=args.fp16)
+    else:
+        autocast = torch.cpu.amp.autocast(enabled=args.fp16)
+
+    # multi-gpu training
     model_single_gpu = model
     if args.n_gpu > 1:
         model = torch.nn.DataParallel(model_single_gpu)
 
-    # Distributed training (should be after apex fp16 initialization)
+    # Distributed training
     if args.local_rank != -1:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                           output_device=args.local_rank,
@@ -191,24 +190,21 @@ def train_mtl(args, train_dataset, aux_dataset, aux_task_def, features, model, t
                               'class_label_id':  aux_batch[5],
                               'aux_task_def':    aux_task_def}
                 model.train()
-                aux_outputs = model(**aux_inputs)
+                with autocast:
+                    aux_outputs = model(**aux_inputs)
                 aux_loss = aux_outputs[0]
 
                 if args.n_gpu > 1:
                     aux_loss = aux_loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
 
-                if args.fp16:
-                    with amp.scale_loss(aux_loss, optimizer) as scaled_loss:
-                        scaled_loss.backward()
-                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
-                else:
-                    aux_loss.backward()
-                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
-
                 tr_aux_loss += aux_loss.item()
                 aux_logged_steps += 1
 
-                optimizer.step()
+                scaler.scale(aux_loss).backward()
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+                scaler.step(optimizer)
+                scaler.update()
 
                 model.zero_grad()
                 if args.mtl_print_loss_diff:
@@ -226,7 +222,8 @@ def train_mtl(args, train_dataset, aux_dataset, aux_task_def, features, model, t
             # Normal training
             model.train()
 
-            outputs = model(**inputs)
+            with autocast:
+                outputs = model(**inputs)
             loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
 
             if args.n_gpu > 1:
@@ -234,17 +231,13 @@ def train_mtl(args, train_dataset, aux_dataset, aux_task_def, features, model, t
             if args.gradient_accumulation_steps > 1:
                 loss = loss / args.gradient_accumulation_steps
 
-            if args.fp16:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
-                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
-            else:
-                loss.backward()
-                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
-
+            scaler.scale(loss).backward()
             tr_loss += loss.item()
             if (step + 1) % args.gradient_accumulation_steps == 0:
-                optimizer.step()
+                scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
+                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+                scaler.step(optimizer)
+                scaler.update()
                 scheduler.step()  # Update learning rate schedule
                 model.zero_grad()
                 global_step += 1
@@ -448,10 +441,7 @@ def main():
     parser.add_argument("--local_rank", type=int, default=-1,
                         help="local_rank for distributed training on gpus")
     parser.add_argument('--fp16', action='store_true',
-                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
-    parser.add_argument('--fp16_opt_level', type=str, default='O1',
-                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
-                             "See details at https://nvidia.github.io/apex/amp.html")
+                        help="Whether to use 16-bit (mixed) precision instead of 32-bit")
     parser.add_argument('--local_files_only', action='store_true',
                         help="Whether to only load local model files (useful when working offline).")