Skip to content

Commit 21b54cd

Browse files
committed
CrateDB: Vector Store -- make it work using CrateDB's vector_similarity
Before, the adapter used CrateDB's built-in `_score` field for ranking. Now, it uses the dedicated `vector_similarity()` function to compute the similarity between two vectors.
1 parent 0ca8ad5 commit 21b54cd

File tree

2 files changed

+29
-26
lines changed
  • libs/community

2 files changed

+29
-26
lines changed

libs/community/langchain_community/vectorstores/cratedb/base.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, floa
260260
page_content=result.EmbeddingStore.document,
261261
metadata=result.EmbeddingStore.cmetadata,
262262
),
263-
result._score if self.embedding_function is not None else None,
263+
result.similarity if self.embedding_function is not None else None,
264264
)
265265
for result in results
266266
]
@@ -324,15 +324,22 @@ def _query_collection_multi(
324324
results: List[Any] = (
325325
session.query( # type: ignore[attr-defined]
326326
self.EmbeddingStore,
327-
# FIXME: Using `_score` is definitively the wrong choice.
328-
# - https://github.com/crate-workbench/langchain/issues/19
329-
# - https://github.com/crate/crate/issues/15835
330327
# TODO: Original pgvector code uses `self.distance_strategy`.
331328
# CrateDB currently only supports EUCLIDEAN.
332329
# self.distance_strategy(embedding).label("distance")
333-
sqlalchemy.literal_column(
334-
f"{self.EmbeddingStore.__tablename__}._score"
335-
).label("_score"),
330+
sqlalchemy.func.vector_similarity(
331+
self.EmbeddingStore.embedding,
332+
# TODO: Just reference the `embedding` symbol here, don't
333+
# serialize its value prematurely.
334+
# https://github.com/crate/crate/issues/16912
335+
#
336+
# Until that got fixed, marshal the arguments to
337+
# `vector_similarity()` manually, in order to work around
338+
# this edge case bug. We don't need to use JSON marshalling,
339+
# because Python's string representation of a list is just
340+
# right.
341+
sqlalchemy.text(str(embedding)),
342+
).label("similarity"),
336343
)
337344
.filter(filter_by)
338345
# CrateDB applies `KNN_MATCH` within the `WHERE` clause.
@@ -341,7 +348,7 @@ def _query_collection_multi(
341348
self.EmbeddingStore.embedding, embedding, k
342349
)
343350
)
344-
.order_by(sqlalchemy.desc("_score"))
351+
.order_by(sqlalchemy.desc("similarity"))
345352
.join(
346353
self.CollectionStore,
347354
self.EmbeddingStore.collection_id == self.CollectionStore.uuid,
@@ -450,7 +457,7 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]:
450457
)
451458

452459
@staticmethod
453-
def _euclidean_relevance_score_fn(score: float) -> float:
460+
def _euclidean_relevance_score_fn(similarity: float) -> float:
454461
"""Return a similarity score on a scale [0, 1]."""
455462
# The 'correct' relevance function
456463
# may differ depending on a few things, including:
@@ -465,4 +472,4 @@ def _euclidean_relevance_score_fn(score: float) -> float:
465472

466473
# Original:
467474
# return 1.0 - distance / math.sqrt(2)
468-
return score / math.sqrt(2)
475+
return similarity / math.sqrt(2)

libs/community/tests/integration_tests/vectorstores/test_cratedb.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def test_cratedb_with_metadatas_with_scores() -> None:
232232
pre_delete_collection=True,
233233
)
234234
output = docsearch.similarity_search_with_score("foo", k=1)
235-
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 2.0)]
235+
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)]
236236

237237

238238
def test_cratedb_with_filter_match() -> None:
@@ -250,9 +250,7 @@ def test_cratedb_with_filter_match() -> None:
250250
# TODO: Original:
251251
# assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] # noqa: E501
252252
output = docsearch.similarity_search_with_score("foo", k=1, filter={"page": "0"})
253-
assert output == [
254-
(Document(page_content="foo", metadata={"page": "0"}), pytest.approx(2.2, 0.3))
255-
]
253+
assert output == [(Document(page_content="foo", metadata={"page": "0"}), 1.0)]
256254

257255

258256
def test_cratedb_with_filter_distant_match() -> None:
@@ -269,9 +267,7 @@ def test_cratedb_with_filter_distant_match() -> None:
269267
)
270268
output = docsearch.similarity_search_with_score("foo", k=2, filter={"page": "2"})
271269
# Original score value: 0.0013003906671379406
272-
assert output == [
273-
(Document(page_content="baz", metadata={"page": "2"}), pytest.approx(1.5, 0.2))
274-
]
270+
assert output == [(Document(page_content="baz", metadata={"page": "2"}), 0.2)]
275271

276272

277273
def test_cratedb_with_filter_no_match() -> None:
@@ -425,8 +421,8 @@ def test_cratedb_with_filter_in_set() -> None:
425421
)
426422
# Original score values: 0.0, 0.0013003906671379406
427423
assert output == [
428-
(Document(page_content="foo", metadata={"page": "0"}), pytest.approx(3.0, 0.1)),
429-
(Document(page_content="baz", metadata={"page": "2"}), pytest.approx(2.2, 0.1)),
424+
(Document(page_content="foo", metadata={"page": "0"}), 1.0),
425+
(Document(page_content="baz", metadata={"page": "2"}), 0.2),
430426
]
431427

432428

@@ -474,9 +470,9 @@ def test_cratedb_relevance_score() -> None:
474470
output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
475471
# Original score values: 1.0, 0.9996744261675065, 0.9986996093328621
476472
assert output == [
477-
(Document(page_content="foo", metadata={"page": "0"}), pytest.approx(1.4, 0.1)),
478-
(Document(page_content="bar", metadata={"page": "1"}), pytest.approx(1.1, 0.1)),
479-
(Document(page_content="baz", metadata={"page": "2"}), pytest.approx(0.8, 0.1)),
473+
(Document(page_content="foo", metadata={"page": "0"}), 0.7071067811865475),
474+
(Document(page_content="bar", metadata={"page": "1"}), 0.35355339059327373),
475+
(Document(page_content="baz", metadata={"page": "2"}), 0.1414213562373095),
480476
]
481477

482478

@@ -495,9 +491,9 @@ def test_cratedb_retriever_search_threshold() -> None:
495491

496492
retriever = docsearch.as_retriever(
497493
search_type="similarity_score_threshold",
498-
search_kwargs={"k": 3, "score_threshold": 0.999},
494+
search_kwargs={"k": 3, "score_threshold": 0.35}, # Original value: 0.999
499495
)
500-
output = retriever.get_relevant_documents("summer")
496+
output = retriever.invoke("summer")
501497
assert output == [
502498
Document(page_content="foo", metadata={"page": "0"}),
503499
Document(page_content="bar", metadata={"page": "1"}),
@@ -522,7 +518,7 @@ def test_cratedb_retriever_search_threshold_custom_normalization_fn() -> None:
522518
search_type="similarity_score_threshold",
523519
search_kwargs={"k": 3, "score_threshold": 0.5},
524520
)
525-
output = retriever.get_relevant_documents("foo")
521+
output = retriever.invoke("foo")
526522
assert output == []
527523

528524

@@ -551,7 +547,7 @@ def test_cratedb_max_marginal_relevance_search_with_score() -> None:
551547
pre_delete_collection=True,
552548
)
553549
output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3)
554-
assert output == [(Document(page_content="foo"), 2.0)]
550+
assert output == [(Document(page_content="foo"), 1.0)]
555551

556552

557553
def test_cratedb_multicollection_search_success() -> None:

0 commit comments

Comments
 (0)