diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/VectorConfig.java b/src/main/java/io/weaviate/client6/v1/api/collections/VectorConfig.java index ae13f70c..dffa7450 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/VectorConfig.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/VectorConfig.java @@ -285,7 +285,7 @@ public void write(JsonWriter out, VectorConfig value) throws IOException { vectorizer.add(value._kind().jsonValue(), config); vectorIndex.getAsJsonObject().add("vectorizer", vectorizer); - if (value.quantization() != null) { + if (value.quantization() != null && !config.getAsJsonObject().get("quantization").isJsonNull()) { vectorIndex.getAsJsonObject() .get("vectorIndexConfig").getAsJsonObject() .add(value.quantization()._kind().jsonValue(), config.getAsJsonObject().remove("quantization")); diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/VectorIndex.java b/src/main/java/io/weaviate/client6/v1/api/collections/VectorIndex.java index 51096f63..12d2e2a0 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/VectorIndex.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/VectorIndex.java @@ -13,15 +13,17 @@ import com.google.gson.stream.JsonReader; import com.google.gson.stream.JsonWriter; +import io.weaviate.client6.v1.api.collections.vectorindex.Dynamic; import io.weaviate.client6.v1.api.collections.vectorindex.Flat; import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; +import io.weaviate.client6.v1.internal.TaggedUnion; import io.weaviate.client6.v1.internal.json.JsonEnum; -public interface VectorIndex { +public interface VectorIndex extends TaggedUnion { static final String DEFAULT_VECTOR_NAME = "default"; static final VectorIndex DEFAULT_VECTOR_INDEX = Hnsw.of(); - public enum Kind implements JsonEnum { + enum Kind implements JsonEnum { HNSW("hnsw"), FLAT("flat"), DYNAMIC("dynamic"); @@ -43,17 +45,37 @@ public static Kind valueOfJson(String jsonValue) { } } - VectorIndex.Kind _kind(); + /** Is this vector index of type HNSW? */ + default Hnsw isHnsw() { + return _as(VectorIndex.Kind.HNSW); + } + + /** Get as {@link Hnsw} instance. */ + default Hnsw asHnsw() { + return _as(VectorIndex.Kind.HNSW); + } + + /** Is this vector index of type FLAT? */ + default Flat isFlat() { + return _as(VectorIndex.Kind.FLAT); + } + + /** Get as {@link Flat} instance. */ + default Flat asFlat() { + return _as(VectorIndex.Kind.FLAT); + } - /** Returns the on-the-wire name of the vector index type. */ - default String type() { - return _kind().jsonValue(); + /** Is this vector index of type DYNAMIC? */ + default Dynamic isDynamic() { + return _as(VectorIndex.Kind.DYNAMIC); } - /** Get the concrete vector index configuration object. */ - Object config(); + /** Get as {@link Dynamic} instance. */ + default Dynamic asDynamic() { + return _as(VectorIndex.Kind.DYNAMIC); + } - public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { + static enum CustomTypeAdapterFactory implements TypeAdapterFactory { INSTANCE; private static final EnumMap> readAdapters = new EnumMap<>( @@ -66,6 +88,7 @@ private final void addAdapter(Gson gson, VectorIndex.Kind kind, Class> fn) { + return fn.apply(new Builder()).build(); + } + + public Dynamic(Builder builder) { + this( + builder.hnsw, + builder.flat, + builder.threshold); + } + + public static class Builder implements ObjectBuilder { + + private Hnsw hnsw; + private Flat flat; + private Long threshold; + + public Builder hnsw(Hnsw hnsw) { + this.hnsw = hnsw; + return this; + } + + public Builder flat(Flat flat) { + this.flat = flat; + return this; + } + + public Builder threshold(long threshold) { + this.threshold = threshold; + return this; + } + + @Override + public Dynamic build() { + return new Dynamic(this); + } + } + + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { + INSTANCE; + + @SuppressWarnings("unchecked") + @Override + public TypeAdapter create(Gson gson, TypeToken type) { + var rawType = type.getRawType(); + if (!Dynamic.class.isAssignableFrom(rawType)) { + return null; + } + + final var hnswAdapter = gson.getDelegateAdapter(VectorIndex.CustomTypeAdapterFactory.INSTANCE, + TypeToken.get(Hnsw.class)); + final var flatAdapter = gson.getDelegateAdapter(VectorIndex.CustomTypeAdapterFactory.INSTANCE, + TypeToken.get(Flat.class)); + + return (TypeAdapter) new TypeAdapter() { + + @Override + public void write(JsonWriter out, Dynamic value) throws IOException { + + var dynamic = new JsonObject(); + + dynamic.addProperty("threshold", value.threshold); + dynamic.add("hnsw", hnswAdapter.toJsonTree(value.hnsw)); + dynamic.add("flat", flatAdapter.toJsonTree(value.flat)); + + Streams.write(dynamic, out); + } + + @Override + public Dynamic read(JsonReader in) throws IOException { + var jsonObject = JsonParser.parseReader(in).getAsJsonObject(); + + var hnsw = hnswAdapter.fromJsonTree(jsonObject.get("hnsw")); + var flat = flatAdapter.fromJsonTree(jsonObject.get("flat")); + var threshold = jsonObject.get("threshold").getAsLong(); + return new Dynamic(hnsw, flat, threshold); + } + }.nullSafe(); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Flat.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Flat.java index 92069553..dfa1da18 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Flat.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Flat.java @@ -16,7 +16,7 @@ public VectorIndex.Kind _kind() { } @Override - public Object config() { + public Object _self() { return this; } diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Hnsw.java b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Hnsw.java index 92e68424..a6e28ee7 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Hnsw.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/vectorindex/Hnsw.java @@ -29,7 +29,7 @@ public VectorIndex.Kind _kind() { } @Override - public Object config() { + public Object _self() { return this; } diff --git a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java index 4bee47c6..d1dc1db5 100644 --- a/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java +++ b/src/main/java/io/weaviate/client6/v1/internal/json/JSON.java @@ -28,8 +28,17 @@ public final class JSON { io.weaviate.client6.v1.api.collections.Vectors.CustomTypeAdapterFactory.INSTANCE); gsonBuilder.registerTypeAdapterFactory( io.weaviate.client6.v1.api.collections.VectorConfig.CustomTypeAdapterFactory.INSTANCE); + + // These 2 adapters need to be registered in this exact order: Dynamic + // (narrower), VectorIndex (broader). + // When searching for an adapter, Gson will pick the first adapter factory that + // can process the class, and it's important that Dynamic.class is processed by + // this factory. + gsonBuilder.registerTypeAdapterFactory( + io.weaviate.client6.v1.api.collections.vectorindex.Dynamic.CustomTypeAdapterFactory.INSTANCE); gsonBuilder.registerTypeAdapterFactory( io.weaviate.client6.v1.api.collections.VectorIndex.CustomTypeAdapterFactory.INSTANCE); + gsonBuilder.registerTypeAdapterFactory( io.weaviate.client6.v1.api.collections.Reranker.CustomTypeAdapterFactory.INSTANCE); gsonBuilder.registerTypeAdapterFactory( diff --git a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java index f6f15e0d..f6e8d322 100644 --- a/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java +++ b/src/test/java/io/weaviate/client6/v1/internal/json/JSONTest.java @@ -32,6 +32,7 @@ import io.weaviate.client6.v1.api.collections.quantizers.PQ; import io.weaviate.client6.v1.api.collections.rerankers.CohereReranker; import io.weaviate.client6.v1.api.collections.vectorindex.Distance; +import io.weaviate.client6.v1.api.collections.vectorindex.Dynamic; import io.weaviate.client6.v1.api.collections.vectorindex.Flat; import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw; import io.weaviate.client6.v1.api.collections.vectorindex.MultiVector; @@ -166,6 +167,28 @@ public static Object[][] testCases() { } """, }, + { + VectorConfig.class, + SelfProvidedVectorizer.of(none -> none + .vectorIndex(Dynamic.of(idx -> idx + .hnsw(Hnsw.of(hnsw -> hnsw + .ef(1) + .efConstruction(2))) + .flat(Flat.of(flat -> flat + .vectorCacheMaxObjects(100))) + .threshold(5)))), + """ + { + "vectorIndexType": "dynamic", + "vectorizer": {"none": {}}, + "vectorIndexConfig": { + "flat": {"vectorCacheMaxObjects": 100}, + "hnsw": {"ef": 1, "efConstruction": 2}, + "threshold": 5 + } + } + """, + }, { VectorConfig.class, SelfProvidedVectorizer.of(none -> none