@@ -576,7 +576,14 @@ def forward(self,
576576 ) -> ALL_NET_OUTPUT :
577577
578578 if isinstance (past_targets , dict ):
579- past_targets , past_features , future_features , past_observed_targets = self ._unwrap_past_targets (past_targets )
579+ (
580+ past_targets ,
581+ past_features ,
582+ future_features ,
583+ past_observed_targets ,
584+ future_targets ,
585+ decoder_observed_values
586+ ) = self ._unwrap_past_targets (past_targets )
580587
581588 x_past , x_future , x_static , loc , scale , static_context_initial_hidden , _ = self .pre_processing (
582589 past_targets = past_targets ,
@@ -610,13 +617,12 @@ def forward(self,
610617 def _unwrap_past_targets (
611618 self ,
612619 past_targets : dict
613- ) -> Tuple [
614- torch .Tensor ,
615- Optional [torch .Tensor ],
616- Optional [torch .Tensor ],
617- Optional [torch .Tensor ],
618- Optional [torch .BoolTensor ],
619- Optional [torch .Tensor ]]:
620+ ) -> Tuple [torch .Tensor ,
621+ Optional [torch .Tensor ],
622+ Optional [torch .Tensor ],
623+ Optional [torch .Tensor ],
624+ Optional [torch .BoolTensor ],
625+ Optional [torch .Tensor ]]:
620626 """
621627 Time series forecasting network requires multiple inputs for the forward pass which is different to how pytorch
622628 networks usually work. SWA's update_bn in line #452 of trainer choice, does not unwrap the dictionary of the
@@ -637,7 +643,14 @@ def _unwrap_past_targets(
637643 future_features = past_targets_copy .pop ('future_features' , None )
638644 past_observed_targets = past_targets_copy .pop ('past_observed_targets' , None )
639645 decoder_observed_values = past_targets_copy .pop ('decoder_observed_values' , None )
640- return past_targets ,past_features ,future_features ,past_observed_targets
646+ return (
647+ past_targets ,
648+ past_features ,
649+ future_features ,
650+ past_observed_targets ,
651+ future_targets ,
652+ decoder_observed_values
653+ )
641654
642655 def pred_from_net_output (self , net_output : ALL_NET_OUTPUT ) -> torch .Tensor :
643656 if self .output_type == 'regression' :
@@ -730,9 +743,16 @@ def forward(self,
730743 future_features : Optional [torch .Tensor ] = None ,
731744 past_observed_targets : Optional [torch .BoolTensor ] = None ,
732745 decoder_observed_values : Optional [torch .Tensor ] = None , ) -> ALL_NET_OUTPUT :
733-
746+
734747 if isinstance (past_targets , dict ):
735- past_targets , past_features , future_features , past_observed_targets = self ._unwrap_past_targets (past_targets )
748+ (
749+ past_targets ,
750+ past_features ,
751+ future_features ,
752+ past_observed_targets ,
753+ future_targets ,
754+ decoder_observed_values
755+ ) = self ._unwrap_past_targets (past_targets )
736756
737757 x_past , _ , x_static , loc , scale , static_context_initial_hidden , past_targets = self .pre_processing (
738758 past_targets = past_targets ,
@@ -1025,7 +1045,14 @@ def forward(self,
10251045 decoder_observed_values : Optional [torch .Tensor ] = None , ) -> ALL_NET_OUTPUT :
10261046
10271047 if isinstance (past_targets , dict ):
1028- past_targets , past_features , future_features , past_observed_targets = self ._unwrap_past_targets (past_targets )
1048+ (
1049+ past_targets ,
1050+ past_features ,
1051+ future_features ,
1052+ past_observed_targets ,
1053+ future_targets ,
1054+ decoder_observed_values
1055+ ) = self ._unwrap_past_targets (past_targets )
10291056
10301057 encode_length = min (self .window_size , past_targets .shape [1 ])
10311058
@@ -1295,7 +1322,14 @@ def forward(self, # type: ignore[override]
12951322 Tuple [torch .Tensor , torch .Tensor ]]:
12961323
12971324 if isinstance (past_targets , dict ):
1298- past_targets , past_features , future_features , past_observed_targets = self ._unwrap_past_targets (past_targets )
1325+ (
1326+ past_targets ,
1327+ past_features ,
1328+ future_features ,
1329+ past_observed_targets ,
1330+ future_targets ,
1331+ decoder_observed_values
1332+ ) = self ._unwrap_past_targets (past_targets )
12991333
13001334 # Unlike other networks, NBEATS network is required to predict both past and future targets.
13011335 # Thereby, we return two tensors for backcast and forecast
0 commit comments