From f196038ad6a2f8d9347f2bd4fade7295b9942bc9 Mon Sep 17 00:00:00 2001
From: Enmin Zhou
Date: Sun, 14 Aug 2022 15:23:38 -0700
Subject: [PATCH] Add constant, linear, cosine, and cosine-restart learning
 rate warmup schedulers

---
 .../nsp_transformer_model/optimizer_warmup.py | 31 +++++++++++++++++++----
 .../nsp_transformer_model/train_model.py      |  8 +++++++-
 2 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/droidlet/perception/semantic_parsing/nsp_transformer_model/optimizer_warmup.py b/droidlet/perception/semantic_parsing/nsp_transformer_model/optimizer_warmup.py
index afcee721f8..d6cf3d7c78 100644
--- a/droidlet/perception/semantic_parsing/nsp_transformer_model/optimizer_warmup.py
+++ b/droidlet/perception/semantic_parsing/nsp_transformer_model/optimizer_warmup.py
@@ -1 +1,2 @@
 from bisect import bisect_right
+import math
@@ -52,7 +53,9 @@ def __init__(self, model, args):
             }
         else:
             raise NotImplementedError
-
+        self.lr_scheduler_method = args.lr_scheduler_method
+        # Optimizer steps per epoch, rounded up (ceiling division).
+        self.num_training_steps = args.dataset_size // args.batch_size + (args.dataset_size % args.batch_size > 0)
         self._step = 0
 
     def _update_rate(self, stack):
@@ -60,9 +63,29 @@ def _update_rate(self, stack):
         if self._step <= self.warmup_steps[stack]:
             alpha = self._step / self.warmup_steps[stack]
             return self.lr[stack] * (self.warmup_factor * (1.0 - alpha) + alpha)
         else:
-            return self.lr[stack] * self.lr_ratio ** bisect_right(
-                self.lr_schedules[stack], self._step
-            )
+            if self.lr_scheduler_method == "constant":
+                return self.lr[stack]
+            elif self.lr_scheduler_method == "linear":
+                # Decay linearly from the base rate to zero over the remaining steps.
+                return self.lr[stack] * max(0.0, float(self.num_training_steps - self._step) / float(max(1, self.num_training_steps - self.warmup_steps[stack])))
+            elif self.lr_scheduler_method == "cosine":
+                # Half a cosine period (num_cycles = 0.5): decays smoothly to zero.
+                num_cycles = 0.5
+                progress = float(self._step - self.warmup_steps[stack]) / float(max(1, self.num_training_steps - self.warmup_steps[stack]))
+                return self.lr[stack] * max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))
+            elif self.lr_scheduler_method == "cosine_hard":
+                # Cosine decay with hard restarts: num_cycles full cycles, each
+                # restarting from the base rate via (num_cycles * progress) % 1.0.
+                num_cycles = 1.0
+                progress = float(self._step - self.warmup_steps[stack]) / float(max(1, self.num_training_steps - self.warmup_steps[stack]))
+                if progress >= 1.0:
+                    return 0.0
+                return self.lr[stack] * max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((num_cycles * progress) % 1.0))))
+            else:
+                # Default: step decay by lr_ratio at each milestone in lr_schedules.
+                return self.lr[stack] * self.lr_ratio ** bisect_right(
+                    self.lr_schedules[stack], self._step
+                )
 
     def zero_grad(self):
         self.optimizer_decoder.zero_grad()
diff --git a/droidlet/perception/semantic_parsing/nsp_transformer_model/train_model.py b/droidlet/perception/semantic_parsing/nsp_transformer_model/train_model.py
index fa600e02cd..9736809819 100644
--- a/droidlet/perception/semantic_parsing/nsp_transformer_model/train_model.py
+++ b/droidlet/perception/semantic_parsing/nsp_transformer_model/train_model.py
@@ -515,6 +515,12 @@ def build_grammar(args):
         type=float,
         help="Factor for learning rate in warmup stage",
     )
+    parser.add_argument(
+        "--lr_scheduler_method",
+        default="default",
+        type=str,
+        help="Post-warmup learning rate schedule: constant, linear, cosine, cosine_hard, or default (step decay)",
+    )
     parser.add_argument(
         "--node_label_smoothing",
         default=0.0,
@@ -641,7 +647,7 @@ def build_grammar(args):
         word_noise=args.word_dropout,
         full_tree_voc=full_tree_voc,
     )
-
+    args.dataset_size = len(train_dataset)
     logging.info("====== Loading Validation Datasets ======")
     val_datasets = {}
     for dtype, _ in args.dtype_samples.items():
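
Reviewer note (not part of the patch): the decay formulas above follow the shapes of HuggingFace's get_cosine_schedule_with_warmup and get_cosine_with_hard_restarts_schedule_with_warmup, with num_cycles pinned inline to their usual defaults (0.5 for cosine, i.e. a single half-period; 1.0 for hard restarts) rather than exposed as a flag. The hunk adding "import math" assumes the module does not already import it. For sanity-checking the curves, here is a minimal standalone sketch of the schedule math; lr_at_step and all of its parameter names are illustrative stand-ins, not the repo's API:

    # Minimal sketch (assumed names, not the repo's API): reproduces the
    # warmup + decay curves that _update_rate computes per optimizer stack.
    import math

    def lr_at_step(step, base_lr, warmup_steps, total_steps, method, warmup_factor=0.1):
        if step <= warmup_steps:
            # Linear warmup from warmup_factor * base_lr up to base_lr.
            alpha = step / warmup_steps
            return base_lr * (warmup_factor * (1.0 - alpha) + alpha)
        if method == "constant":
            return base_lr
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        if method == "linear":
            return base_lr * max(0.0, 1.0 - progress)
        if method == "cosine":
            # num_cycles = 0.5, so pi * num_cycles * 2.0 * progress = pi * progress.
            return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
        if method == "cosine_hard":
            # num_cycles = 1.0; larger values restart the cosine mid-training.
            num_cycles = 1.0
            if progress >= 1.0:
                return 0.0
            return base_lr * 0.5 * (1.0 + math.cos(math.pi * ((num_cycles * progress) % 1.0)))
        raise ValueError(f"unknown method: {method}")

    # Rates peak at base_lr right after warmup; linear/cosine reach zero at total_steps.
    for method in ("constant", "linear", "cosine", "cosine_hard"):
        print(method, [round(lr_at_step(s, 1e-4, 100, 1000, method), 6) for s in (0, 100, 550, 1000)])

One design point worth flagging: num_training_steps is derived from dataset_size and batch_size, i.e. the steps in a single epoch, so the linear and cosine schedules decay to zero by the end of the first epoch; multi-epoch runs would need the step count scaled by the number of epochs.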