66
77use Codewithkyrian \Transformers \Exceptions \MissingModelInputException ;
88use Codewithkyrian \Transformers \Exceptions \ModelExecutionException ;
9- use Codewithkyrian \Transformers \Models \Pretrained \PreTrainedModel ;
9+ use Codewithkyrian \Transformers \Models \Pretrained \PretrainedModel ;
1010use Codewithkyrian \Transformers \Utils \GenerationConfig ;
1111use Codewithkyrian \Transformers \Utils \Tensor ;
1212
@@ -30,7 +30,7 @@ public function canGenerate(): bool
3030 };
3131 }
3232
33- public function runBeam (PreTrainedModel $ model , array &$ beam ): array
33+ public function runBeam (PretrainedModel $ model , array &$ beam ): array
3434 {
3535 return match ($ this ) {
3636 self ::DecoderOnly => $ this ->decoderRunBeam ($ model , $ beam ),
@@ -40,7 +40,7 @@ public function runBeam(PreTrainedModel $model, array &$beam): array
4040 }
4141
4242 public function startBeams (
43- PreTrainedModel $ model ,
43+ PretrainedModel $ model ,
4444 Tensor $ inputTokenIds ,
4545 GenerationConfig $ generationConfig ,
4646 int $ numOutputTokens ,
@@ -63,7 +63,7 @@ public function updateBeam(array &$beam, int $newTokenId): void
6363 };
6464 }
6565
66- public function forward (PreTrainedModel $ model , array $ modelInputs ): array
66+ public function forward (PretrainedModel $ model , array $ modelInputs ): array
6767 {
6868 return match ($ this ) {
6969 self ::EncoderOnly => $ this ->encoderForward ($ model , $ modelInputs ),
@@ -77,7 +77,7 @@ public function forward(PreTrainedModel $model, array $modelInputs): array
7777
7878 //<editor-fold desc="Encoder methods">
7979
80- protected function encoderForward (PreTrainedModel $ model , array $ modelInputs ): array
80+ protected function encoderForward (PretrainedModel $ model , array $ modelInputs ): array
8181 {
8282 $ encoderFeeds = [];
8383
@@ -102,11 +102,11 @@ protected function encoderForward(PreTrainedModel $model, array $modelInputs): a
102102
103103 /**
104104 * Runs a single step of the text generation process for a given beam.
105- * @param PreTrainedModel $model The text generation model object.
105+ * @param PretrainedModel $model The text generation model object.
106106 * @param array $beam The beam to run the generation process for.
107107 * @return array The output of the generation process for the given beam.
108108 */
109- protected function decoderRunBeam (PreTrainedModel $ model , array &$ beam ): array
109+ protected function decoderRunBeam (PretrainedModel $ model , array &$ beam ): array
110110 {
111111 $ attnMaskLength = count ($ beam ['output_token_ids ' ]);
112112 $ attnMaskData = array_fill (0 , $ attnMaskLength , 1 );
@@ -128,15 +128,15 @@ protected function decoderRunBeam(PreTrainedModel $model, array &$beam): array
128128 }
129129
130130 /** Starts the generation of text by initializing the beams for the given input token IDs.
131- * @param PreTrainedModel $model The text generation model object.
131+ * @param PretrainedModel $model The text generation model object.
132132 * @param Tensor $inputTokenIds A tensor of input token IDs to generate text from.
133133 * @param GenerationConfig $generationConfig The generation config.
134134 * @param int $numOutputTokens The maximum number of tokens to generate for each beam.
135135 * @param Tensor|null $inputsAttentionMask The attention mask tensor for the input token IDs.
136136 * @return array An array of beams initialized with the given inputs and parameters.
137137 */
138138 protected function decoderStartBeams (
139- PreTrainedModel $ model ,
139+ PretrainedModel $ model ,
140140 Tensor $ inputTokenIds ,
141141 GenerationConfig $ generationConfig ,
142142 int $ numOutputTokens ,
@@ -195,12 +195,12 @@ protected function decoderUpdatebeam(array &$beam, int $newTokenId): void
195195
196196 /**
197197 * Forward pass for the decoder model.
198- * @param PreTrainedModel $model The model to use for the forward pass.
198+ * @param PretrainedModel $model The model to use for the forward pass.
199199 * @param array $modelInputs The inputs to the model.
200200 * @return array The output of the forward pass.
201201 * @throws MissingModelInputException|ModelExecutionException
202202 */
203- protected function decoderForward (PreTrainedModel $ model , array $ modelInputs ): array
203+ protected function decoderForward (PretrainedModel $ model , array $ modelInputs ): array
204204 {
205205 ['input_ids ' => $ inputIds , 'past_key_values ' => $ pastKeyValues , 'attention_mask ' => $ attentionMask ]
206206 = $ modelInputs ;
@@ -234,7 +234,7 @@ protected function decoderForward(PreTrainedModel $model, array $modelInputs): a
234234
235235 //<editor-fold desc="Seq2Seq methods">
236236
237- protected function seq2seqRunBeam (PreTrainedModel $ model , array &$ beam ): array
237+ protected function seq2seqRunBeam (PretrainedModel $ model , array &$ beam ): array
238238 {
239239 $ inputName = $ model ->mainInputName ;
240240
@@ -270,14 +270,14 @@ protected function seq2seqRunBeam(PreTrainedModel $model, array &$beam): array
270270 }
271271
272272 /** Start the beam search process for the seq2seq model.
273- * @param PreTrainedModel $model The model to use for the beam search.
273+ * @param PretrainedModel $model The model to use for the beam search.
274274 * @param Tensor $inputTokenIds Array of input token ids for each input sequence.
275275 * @param GenerationConfig $generationConfig The generation configuration.
276276 * @param int $numOutputTokens The maximum number of output tokens for the model.
277277 * @return array Array of beam search objects.
278278 */
279279 protected function seq2seqStartBeams (
280- PreTrainedModel $ model ,
280+ PretrainedModel $ model ,
281281 Tensor $ inputTokenIds ,
282282 GenerationConfig $ generationConfig ,
283283 int $ numOutputTokens ,
@@ -330,7 +330,7 @@ protected function seq2seqUpdatebeam(array &$beam, int $newTokenId): void
330330 $ beam ['output_token_ids ' ][] = $ newTokenId ;
331331 }
332332
333- protected function seq2seqForward (PreTrainedModel $ model , array $ modelInputs ): array
333+ protected function seq2seqForward (PretrainedModel $ model , array $ modelInputs ): array
334334 {
335335
336336 ['encoder_outputs ' => $ encoderOutputs , 'past_key_values ' => $ pastKeyValues ] = $ modelInputs ;
0 commit comments