From a0f8ec10ab95300a39b00d49ee3f5ba6b465b281 Mon Sep 17 00:00:00 2001 From: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com> Date: Sat, 8 Nov 2025 20:52:31 +0100 Subject: [PATCH] Optimize MistralAiEmbeddingModel dimensions method - Calculate and cache values for unknown models only if necessary - Make known embedding dimensions a mutable map attribute - Verify the cache mechanism with MistralAiEmbeddingModelTests Signed-off-by: Nicolas Krier <7557886+nicolaskrier@users.noreply.github.com> --- .../ai/mistralai/MistralAiEmbeddingModel.java | 20 +++++++++----- .../MistralAiEmbeddingModelTests.java | 27 +++++++++++-------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java index 76b71e93030..8f3f09c591c 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java @@ -16,6 +16,7 @@ package org.springframework.ai.mistralai; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -56,16 +57,14 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel { private static final Logger logger = LoggerFactory.getLogger(MistralAiEmbeddingModel.class); + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + /** * Known embedding dimensions for Mistral AI models. Maps model names to their * respective embedding vector dimensions. This allows the dimensions() method to * return the correct value without making an API call. */ - private static final Map KNOWN_EMBEDDING_DIMENSIONS = Map.of( - MistralAiApi.EmbeddingModel.EMBED.getValue(), 1024, MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue(), - 1536); - - private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + private final Map knownEmbeddingDimensions = createKnownEmbeddingDimensions(); private final MistralAiEmbeddingOptions defaultOptions; @@ -85,6 +84,14 @@ public class MistralAiEmbeddingModel extends AbstractEmbeddingModel { */ private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + private static Map createKnownEmbeddingDimensions() { + Map knownEmbeddingDimensions = new HashMap<>(); + knownEmbeddingDimensions.put(MistralAiApi.EmbeddingModel.EMBED.getValue(), 1024); + knownEmbeddingDimensions.put(MistralAiApi.EmbeddingModel.CODESTRAL_EMBED.getValue(), 1536); + + return knownEmbeddingDimensions; + } + public MistralAiEmbeddingModel(MistralAiApi mistralAiApi, MetadataMode metadataMode, MistralAiEmbeddingOptions options, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) { Assert.notNull(mistralAiApi, "mistralAiApi must not be null"); @@ -174,7 +181,8 @@ public float[] embed(Document document) { @Override public int dimensions() { - return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), super.dimensions()); + return this.knownEmbeddingDimensions.computeIfAbsent(this.defaultOptions.getModel(), + model -> super.dimensions()); } /** diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java index f771e59a89b..9be04dd190f 100644 --- a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiEmbeddingModelTests.java @@ -16,6 +16,7 @@ package org.springframework.ai.mistralai; +import java.util.Arrays; import java.util.List; import org.junit.jupiter.api.Test; @@ -28,6 +29,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** @@ -77,7 +79,7 @@ void testDimensionsForCodestralEmbedModel() { void testDimensionsFallbackForUnknownModel() { MistralAiApi mockApi = createMockApiWithEmbeddingResponse(512); - // Use a model name that doesn't exist in KNOWN_EMBEDDING_DIMENSIONS + // Use a model name that doesn't exist in knownEmbeddingDimensions. MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions.builder().withModel("unknown-model").build(); MistralAiEmbeddingModel model = MistralAiEmbeddingModel.builder() @@ -87,17 +89,23 @@ void testDimensionsFallbackForUnknownModel() { .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) .build(); - // Should fall back to super.dimensions() which detects dimensions from the API - // response + // For the first call, it should fall back to super.dimensions() which detects + // dimensions from the API response. assertThat(model.dimensions()).isEqualTo(512); + + // For the second call, it should use the cache mechanism. + assertThat(model.dimensions()).isEqualTo(512); + + // Verify that super.dimensions() has been called once. + verify(mockApi).embeddings(any()); } @Test void testAllEmbeddingModelsHaveDimensionMapping() { - // This test ensures that KNOWN_EMBEDDING_DIMENSIONS map stays in sync with the - // EmbeddingModel enum + // This test ensures that knownEmbeddingDimensions map stays in sync with the + // EmbeddingModel enum. // If a new model is added to the enum but not to the dimensions map, this test - // will help catch it + // will help catch it. for (MistralAiApi.EmbeddingModel embeddingModel : MistralAiApi.EmbeddingModel.values()) { MistralAiApi mockApi = createMockApiWithEmbeddingResponse(1024); @@ -138,16 +146,13 @@ private MistralAiApi createMockApiWithEmbeddingResponse(int dimensions) { // Create a mock embedding response with the specified dimensions float[] embedding = new float[dimensions]; - for (int i = 0; i < dimensions; i++) { - embedding[i] = 0.1f; - } + Arrays.fill(embedding, 0.1f); MistralAiApi.Embedding embeddingData = new MistralAiApi.Embedding(0, embedding, "embedding"); MistralAiApi.Usage usage = new MistralAiApi.Usage(10, 0, 10); - MistralAiApi.EmbeddingList embeddingList = new MistralAiApi.EmbeddingList("object", List.of(embeddingData), - "model", usage); + var embeddingList = new MistralAiApi.EmbeddingList<>("object", List.of(embeddingData), "model", usage); when(mockApi.embeddings(any())).thenReturn(ResponseEntity.ok(embeddingList));