Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,25 @@ public void FromDataJsonIncludesDimensionsWhenProvided()
// Assert
Assert.Contains($"{DimensionalityJsonPropertyName}:{Dimensions}", json);
}

[Fact]
public void FromDataShouldIncludeTaskTypeWhenProvided()
{
// Arrange
var input = new[] { "This is a retrieval document." };
var modelId = "embedding-001";
var dimensions = 1024;
var taskType = "RETRIEVAL_DOCUMENT";

// Act
var request = GoogleAIEmbeddingRequest.FromData(input, modelId, dimensions, taskType);

// Serialize to JSON (this is what would be sent in the HTTP request)
var json = System.Text.Json.JsonSerializer.Serialize(request);

// Assert
Assert.Contains("\"taskType\":\"RETRIEVAL_DOCUMENT\"", json);
Assert.Contains("\"model\":\"models/embedding-001\"", json);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;

namespace Microsoft.SemanticKernel.Connectors.Google.Core;
Expand Down Expand Up @@ -54,15 +55,25 @@ public GoogleAIEmbeddingClient(
/// Generates embeddings for the given data asynchronously.
/// </summary>
/// <param name="data">The list of strings to generate embeddings for.</param>
/// <param name="options">The embedding generation options.</param>
/// <param name="cancellationToken">The cancellation token to cancel the operation.</param>
/// <returns>Result contains a list of read-only memories of floats representing the generated embeddings.</returns>
public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(
IList<string> data,
EmbeddingGenerationOptions? options = null,
CancellationToken cancellationToken = default)
{
Verify.NotNullOrEmpty(data);

var geminiRequest = this.GetEmbeddingRequest(data);
// var geminiRequest = this.GetEmbeddingRequest(data);
string? taskType = null;
if (options?.AdditionalProperties?.TryGetValue("task_type", out var taskTypeValue) == true)
{
taskType = taskTypeValue?.ToString();
}

var geminiRequest = this.GetEmbeddingRequest(data, taskType);

using var httpRequestMessage = await this.CreateHttpRequestAsync(geminiRequest, this._embeddingEndpoint).ConfigureAwait(false);

string body = await this.SendRequestAndGetStringBodyAsync(httpRequestMessage, cancellationToken)
Expand All @@ -71,8 +82,8 @@ public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(
return DeserializeAndProcessEmbeddingsResponse(body);
}

private GoogleAIEmbeddingRequest GetEmbeddingRequest(IEnumerable<string> data)
=> GoogleAIEmbeddingRequest.FromData(data, this._embeddingModelId, this._dimensions);
private GoogleAIEmbeddingRequest GetEmbeddingRequest(IEnumerable<string> data, string? taskType = null)
=> GoogleAIEmbeddingRequest.FromData(data, this._embeddingModelId, this._dimensions, taskType);

private static List<ReadOnlyMemory<float>> DeserializeAndProcessEmbeddingsResponse(string body)
=> ProcessEmbeddingsResponse(DeserializeResponse<GoogleAIEmbeddingResponse>(body));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ internal sealed class GoogleAIEmbeddingRequest
[JsonPropertyName("requests")]
public IList<RequestEmbeddingContent> Requests { get; set; } = null!;

public static GoogleAIEmbeddingRequest FromData(IEnumerable<string> data, string modelId, int? dimensions = null) => new()
public static GoogleAIEmbeddingRequest FromData(IEnumerable<string> data, string modelId, int? dimensions = null, string? taskType = null) => new()
{
Requests = data.Select(text => new RequestEmbeddingContent
{
Expand All @@ -26,7 +26,8 @@ internal sealed class GoogleAIEmbeddingRequest
}
]
},
Dimensions = dimensions
Dimensions = dimensions,
TaskType = taskType
}).ToList()
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.Connectors.Google.Core;
using Microsoft.SemanticKernel.Embeddings;
Expand Down Expand Up @@ -68,6 +69,25 @@ public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(
Kernel? kernel = null,
CancellationToken cancellationToken = default)
{
return this._embeddingClient.GenerateEmbeddingsAsync(data, cancellationToken);
return this._embeddingClient.GenerateEmbeddingsAsync(data, null, cancellationToken);
}

/// <summary>
/// Generates embeddings for the specified input text, allowing additional configuration
/// via <see cref="EmbeddingGenerationOptions"/> (e.g., specifying the Google task type).
/// </summary>
/// <param name="data">The input text collection to generate embeddings for.</param>
/// <param name="options">Embedding generation options (e.g., task_type).</param>
/// <param name="kernel">Optional Kernel instance.</param>
/// <param name="cancellationToken">Token for cancelling the request.</param>
/// <returns>A list of generated embeddings.</returns>
public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(
IList<string> data,
EmbeddingGenerationOptions? options,
Kernel? kernel = null,
CancellationToken cancellationToken = default)
{
return this._embeddingClient.GenerateEmbeddingsAsync(data, options, cancellationToken);
}

}