@@ -223,6 +223,12 @@ def beam_search(
         encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
         encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]
 
+        num_sequences = input_ids.shape[0]
+
+        # Pre-allocate the per-sequence token and beam index buffers up front
+        token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx).to(dtype=torch.long, device=device)
+        beam_idxs = torch.zeros((num_sequences, num_beams, 1)).to(dtype=torch.long, device=device)
+
         def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
             # `emissions` and `N` are unused in this current implementation
 
@@ -231,16 +237,8 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
             # For first timestep, create previous step token_idxs and model_states
             if timestep == 0:
                 prev_step_token_idxs = [-1]
-                prev_step_model_states = [
-                    create_emitting_model_state(
-                        Seq2SeqModelState(timestep=0, sequence=input_ids[i].unsqueeze(0), lm_scores=None)
-                    )
-                ]
 
             encoder_output_for_curr_seq = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None
-            prev_model_state_sequences = [
-                get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states
-            ]
             out_probs, model_states = [], []
 
             start = 0
@@ -256,66 +254,32 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 if end > curr_beam_size:
                     end = curr_beam_size
 
-                num_samples = end - start
-
                 if prev_step_token_idxs != [-1]:
-                    state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
-                    token_indices = (
-                        torch.Tensor(prev_step_token_idxs[start:end])
-                        .to(dtype=torch.long, device=device)
-                        .reshape(num_samples, 1)
-                    )
-
-                    state_and_tokens = torch.cat(
-                        [state_sequences, token_indices], dim=-1
-                    )  # [batch_size x (timestep + 1)]
-                    assert state_and_tokens.shape == (
-                        num_samples,
-                        timestep + 1,
-                    ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
+                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=device)
+                    token_idxs[i, : len(token_indices), 0] = token_indices
+                    curr_token_idxs = token_idxs[i, :, 0].reshape(num_beams, 1)
                 else:
-                    assert len(prev_model_state_sequences) == 1
-                    state_and_tokens = token_indices = prev_model_state_sequences[0].expand(
-                        num_beams, -1
-                    )  # TODO: Make this more robust
-
-                # Cleanup -- combine this with the above
-                if self.is_encoder_decoder:
-                    # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
-                    # This is a view-only operation and doesn't copy
-                    model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
-                        num_samples if timestep > 0 else num_beams, -1, -1
-                    )
+                    if self.is_encoder_decoder:
+                        # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
+                        # This is a view-only operation and doesn't copy
+                        model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
+                            num_beams, -1, -1
+                        )
+                    curr_token_idxs = torch.zeros((num_beams, 1)).to(dtype=torch.long, device=device)
+
 
                 # Preprocess inputs for generation
                 model_inputs = self.model.prepare_inputs_for_generation(
-                    token_indices, **model_kwargs
+                    curr_token_idxs, **model_kwargs
                 )  # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does)
                 if self.is_huggingface_model:
                     model_inputs.update(self._huggingface_model_input_values)
                 if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
-                    beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
-
-                    # We could store this in model_kwargs
-                    num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0]
-
-                    num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs)
-                    if num_finished_hyps_in_step > 0:
-                        beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0)
-
-                    beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1)
-
-                    reordered_cached = self.model._reorder_cache(model_kwargs["past"], beam_idxs)
-
-                    if num_finished_hyps_in_step > 0:
-                        sliced_cache = ()
-                        for states in reordered_cached:
-                            sliced_state = ()
-                            for state in states:
-                                sliced_state = sliced_state + (state[: len(prev_step_hyp_idxs)],)
-                            sliced_cache = sliced_cache + (sliced_state,)
-                        reordered_cached = sliced_cache
+                    beam_indices = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
+                    beam_idxs[i, : len(prev_step_hyp_idxs), 0] = beam_indices
+                    curr_beam_idxs = beam_idxs[i, :, 0]
 
+                    reordered_cached = self.model._reorder_cache(model_kwargs["past"], curr_beam_idxs)
                     model_inputs["past_key_values"] = reordered_cached
 
                 # Forward pass
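
The cache handling above relies on the model's `_reorder_cache` hook, which gathers each layer's cached key/value tensors along the batch dimension so that the cache rows track the surviving hypotheses. A minimal standalone sketch of that gather, assuming a conventional per-layer (key, value) layout with shape [num_beams, num_heads, seq_len, head_dim]; the `reorder_cache` helper and the toy shapes are illustrative only, not part of this change:

import torch

def reorder_cache(past_key_values, beam_idx):
    # For each layer's (key, value) pair, pick the batch rows that correspond
    # to the beam hypotheses kept at this step.
    return tuple(
        tuple(state.index_select(0, beam_idx) for state in layer)
        for layer in past_key_values
    )

# Toy cache: 2 layers, tensors shaped [num_beams, num_heads, seq_len, head_dim]
num_beams, num_heads, seq_len, head_dim = 4, 2, 3, 8
past = tuple(
    (torch.randn(num_beams, num_heads, seq_len, head_dim),
     torch.randn(num_beams, num_heads, seq_len, head_dim))
    for _ in range(2)
)
beam_idx = torch.tensor([2, 2, 0, 1])  # hypotheses that survived this step
reordered = reorder_cache(past, beam_idx)
assert reordered[0][0].shape == past[0][0].shape

What the diff changes is only where the beam indices come from: they are written into the pre-allocated beam_idxs buffer instead of being padded and clamped on every step.
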
@@ -329,18 +293,21 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 if self.is_huggingface_model:
                     self._update_model_kwargs_for_generation(outputs, model_kwargs)
 
+                # Reset the pre-allocated buffers so the next step starts from clean padding
+                token_idxs[i, :, 0] = eos_idx
+                beam_idxs[i, :, 0] = 0
+
                 # Keep track of probabilities over vocab for this pairing
-                # TODO: fix how we track the number here?
-                for i in range(lm_scores.shape[0]):
+                for i in range(num_beams):
                     sample_lm_scores = lm_scores[i, -1]
                     out_probs.append(sample_lm_scores.tolist())
                     # Keep track of sequence and decoder hidden states
                     model_states.append(
                         create_emitting_model_state(
                             Seq2SeqModelState(
                                 timestep=timestep,
-                                sequence=state_and_tokens[i].unsqueeze(0),
-                                lm_scores=sample_lm_scores,
+                                sequence=[],
+                                lm_scores=0,
                             )
                         )
                     )
@@ -386,10 +353,6 @@ def is_not_neg_one(elem: int) -> bool:
             if not self.is_encoder_decoder:
                 final_tokens = input_ids[timestep].tolist() + final_tokens
 
-            # Makeshift padding so that we can stack the tensors
-            while len(final_tokens) < max_len:
-                final_tokens += [0]
-
             # Convert from list to tensors
             final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long)
 
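
Overall, the change swaps per-step reconstruction of hypothesis sequences for buffers that are allocated once per beam_search call, filled inside the step callback, and reset to padding afterwards. A small self-contained sketch of that fill-then-reset pattern, assuming the same buffer layout as the diff; the decoder forward pass is elided and `step` is a hypothetical stand-in for update_func:

import torch

num_sequences, num_beams = 2, 4
eos_idx = 1
device = "cpu"

# Allocated once, outside the per-step callback (mirrors token_idxs in the diff).
token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx, dtype=torch.long, device=device)

def step(i, prev_step_token_idxs):
    # Fill only the still-active hypotheses; finished slots keep eos_idx as padding.
    live = torch.tensor(prev_step_token_idxs, dtype=torch.long, device=device)
    token_idxs[i, : len(live), 0] = live
    curr_token_idxs = token_idxs[i, :, 0].reshape(num_beams, 1)  # decoder input for this step
    # ... forward pass with curr_token_idxs would go here ...
    token_idxs[i, :, 0] = eos_idx  # reset so the next timestep starts from clean padding
    return curr_token_idxs

print(step(0, [5, 7]))  # -> tensor([[5], [7], [1], [1]])

Because finished slots stay at eos_idx, the buffer always exposes a fixed num_beams rows, so downstream shapes stay constant no matter how many hypotheses remain live.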