from __future__ import annotations
from typing import Any, List, Optional
import pymongo
from langchain_classic.retrievers.parent_document_retriever import (
ParentDocumentRetriever,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor
from langchain_text_splitters import TextSplitter
from pydantic import Field
from pymongo import MongoClient
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.docstores import MongoDBDocStore
from langchain_mongodb.index import create_fulltext_search_index
from langchain_mongodb.pipelines import (
autoembedding_vector_search_stage,
vector_search_stage,
)
from langchain_mongodb.utils import (
DRIVER_METADATA,
make_serializable,
prepare_query_for_vector_search,
)
class MongoDBAtlasParentDocumentRetriever(ParentDocumentRetriever):
    """MongoDB Atlas's ParentDocumentRetriever.

    "Parent Document Retrieval" is a common approach to enhance the performance of
    retrieval methods in RAG by providing the LLM with a broader context to consider.
    In essence, we divide the original documents into relatively small chunks,
    embed each one, and store them in a vector database.
    Using such small chunks (a sentence or a couple of sentences)
    helps the embedding models to better reflect their meaning.

    If two high-scoring chunks are contained in the same document,
    the query response will include the parent document just once.
    Parent documents are returned in descending order of the best score
    among their matching chunks.

    One can control the number of chunks found in the vector_search_stage by setting
    ``search_kwargs={'top_k': n}``. The number of query responses will be <= top_k.

    In this implementation, we can store both parent and child documents in a single
    collection while only having to compute and index embedding vectors for the chunks!
    This is achieved by backing both the
    vectorstore, :class:`~langchain_mongodb.vectorstores.MongoDBAtlasVectorSearch`
    and the docstore :class:`~langchain_mongodb.docstores.MongoDBDocStore`
    by the same MongoDB Collection.

    For more details, see superclasses
    :class:`~langchain_classic.retrievers.parent_document_retriever.ParentDocumentRetriever`
    and :class:`~langchain_classic.retrievers.MultiVectorRetriever`.

    Examples:
        >>> from langchain_mongodb.retrievers.parent_document import (
        ...     MongoDBAtlasParentDocumentRetriever
        ... )
        >>> from langchain_text_splitters import RecursiveCharacterTextSplitter
        >>> from langchain_openai import OpenAIEmbeddings
        >>>
        >>> retriever = MongoDBAtlasParentDocumentRetriever.from_connection_string(
        ...     "mongodb+srv://<user>:<password>@<clustername>.mongodb.net",
        ...     OpenAIEmbeddings(model="text-embedding-3-large"),
        ...     RecursiveCharacterTextSplitter(chunk_size=400),
        ...     "example_database",
        ... )
        >>> retriever.add_documents(technical_report_pages)  # a list of Documents
        >>> resp = retriever.invoke("Langchain MongoDB Partnership Ecosystem")
        >>> print(resp)
        [Document(...), ...]
    """

    vectorstore: MongoDBAtlasVectorSearch
    """Vectorstore API to add, embed, and search through child documents"""
    docstore: MongoDBDocStore
    """Provides an API around the Collection to add the parent documents"""
    id_key: str = "doc_id"
    """Key stored in metadata pointing to parent document"""
    search_kwargs: dict = Field(default_factory=dict)
    """Kwargs to be passed to vector_search_stage. e.g. {'top_k': 5}. """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the retriever and optionally create a full-text search index.

        The following keyword arguments are consumed here (not forwarded to the
        parent classes); all other arguments go to the parent constructor.

        Args:
            auto_create_index: If true (default), create a full-text search index
                named ``search_index_name`` on the vector store's collection when
                no search index with that name is listed.
            auto_index_timeout: Passed to ``create_fulltext_search_index`` as
                ``wait_until_complete``. Defaults to 15.
            search_index_name: Name of the full-text search index. Defaults to
                "search_index" here; note ``from_connection_string`` defaults to
                "text_index".
            search_field: Field to index. Defaults to the vector store's text key.
        """
        auto_create_index = kwargs.pop("auto_create_index", True)
        auto_index_timeout = kwargs.pop("auto_index_timeout", 15)
        search_index_name = kwargs.pop("search_index_name", "search_index")
        search_field = kwargs.pop("search_field", None)
        super().__init__(*args, **kwargs)
        if not auto_create_index:
            return
        collection = self.vectorstore._collection
        index_exists = any(
            ix["name"] == search_index_name for ix in collection.list_search_indexes()
        )
        if not index_exists:
            create_fulltext_search_index(
                collection=collection,
                index_name=search_index_name,
                field=search_field or self.vectorstore._text_key,
                wait_until_complete=auto_index_timeout,
            )

    def _vector_search_stage(self, query: str) -> dict:
        """Build the vector search stage for ``query`` (auto-embedding or vector)."""
        # Prepare query for vector search (handles auto embeddings check)
        query_input, is_autoembedding = prepare_query_for_vector_search(
            query, self.vectorstore._embedding
        )
        if is_autoembedding:
            assert isinstance(query_input, str)
            auto_embedding = self.vectorstore._embedding
            return autoembedding_vector_search_stage(
                query=query_input,
                search_field=self.vectorstore._text_key,
                index_name=self.vectorstore._index_name,
                model=auto_embedding.model,  # type: ignore[attr-defined]
                **self.search_kwargs,  # See MongoDBAtlasVectorSearch
            )
        assert self.vectorstore._embedding_key is not None
        assert isinstance(query_input, list)
        return vector_search_stage(
            query_vector=query_input,
            search_field=self.vectorstore._embedding_key,
            index_name=self.vectorstore._index_name,
            **self.search_kwargs,  # See MongoDBAtlasVectorSearch
        )

    def _parent_document_pipeline(self, vector_stage: dict) -> List[dict]:
        """Build the aggregation pipeline mapping matched chunks to unique parents."""
        # Child chunks keep their vectors under the vector store's embedding key;
        # drop them early so they are not carried through $lookup/$unwind.
        embedding_key = self.vectorstore._embedding_key or "embedding"
        return [
            vector_stage,
            {"$set": {"score": {"$meta": "vectorSearchScore"}}},
            {"$project": {embedding_key: 0}},
            {  # Find corresponding parent doc
                "$lookup": {
                    "from": self.vectorstore.collection.name,
                    "localField": self.id_key,
                    "foreignField": "_id",
                    "as": "parent_context",
                    "pipeline": [
                        # Discard sub-documents
                        {"$match": {f"metadata.{self.id_key}": {"$exists": False}}},
                    ],
                }
            },
            # Chunks whose parent was not found have an empty array and drop out.
            {"$unwind": {"path": "$parent_context"}},
            # Remove duplicate parent docs, remembering each parent's best score.
            {
                "$group": {
                    "_id": "$parent_context._id",
                    "uniqueDocument": {"$first": "$parent_context"},
                    "score": {"$max": "$score"},
                }
            },
            # $group does not guarantee output order: restore relevance order,
            # breaking ties by parent _id so results are deterministic.
            {"$sort": {"score": -1, "_id": 1}},
            {"$replaceRoot": {"newRoot": "$uniqueDocument"}},
        ]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
    ) -> List[Document]:
        """Return the unique parent documents of the chunks most similar to ``query``.

        Args:
            query: Query text.
            run_manager: Callback manager (unused here).

        Returns:
            Parent documents ordered by their best-matching chunk's score. Each
            document's text is taken from the vector store's text key; all other
            fields become (serialized) metadata.
        """
        pipeline = self._parent_document_pipeline(self._vector_search_stage(query))
        cursor = self.vectorstore._collection.aggregate(pipeline)  # type: ignore[arg-type]
        docs: List[Document] = []
        for res in cursor:
            text = res.pop(self.vectorstore._text_key)
            make_serializable(res)
            docs.append(Document(page_content=text, metadata=res))
        return docs

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Asynchronous version of get_relevant_documents.

        Runs the synchronous implementation in the default executor.
        """
        return await run_in_executor(
            None,
            self._get_relevant_documents,
            query,
            run_manager=run_manager.get_sync(),
        )

    @classmethod
    def from_connection_string(
        cls,
        connection_string: str,
        embedding_model: Embeddings,
        child_splitter: TextSplitter,
        database_name: str,
        collection_name: str = "document_with_chunks",
        id_key: str = "doc_id",
        auto_create_index: bool = True,
        auto_index_timeout: int = 15,
        search_index_name: str = "text_index",
        search_field: Optional[str] = None,
        **kwargs: Any,
    ) -> MongoDBAtlasParentDocumentRetriever:
        """Construct a Retriever whose VectorStore and DocStore share one Collection.

        See parent classes
        :class:`~langchain_classic.retrievers.parent_document_retriever.ParentDocumentRetriever`
        and :class:`~langchain_classic.retrievers.MultiVectorRetriever` for further
        details.

        Args:
            connection_string: A valid MongoDB Atlas connection URI.
            embedding_model: The text embedding model to use for the vector store.
            child_splitter: Splits documents into chunks.
                If parent_splitter is given, the documents will have already been split.
            database_name: Name of database to connect to. Created if it does not exist.
            collection_name: Name of collection to use.
                It includes parent documents, sub-documents and their embeddings.
            id_key: Key used to identify parent documents. An ascending index is
                created on it.
            auto_create_index: Whether to automatically create the full-text search
                index if it does not exist. Defaults to True.
            auto_index_timeout: How long to wait for the automatic index creation
                to complete, in seconds.
            search_index_name: Name of the full-text search index to create when
                auto_create_index is True. Defaults to "text_index".
            search_field: Field to index for full-text search. Defaults to the
                vectorstore text key.
            **kwargs: Additional keyword arguments, forwarded both to
                :class:`~langchain_mongodb.MongoDBAtlasVectorSearch` and to the
                retriever constructor. See parent classes for more.

        Returns: A new MongoDBAtlasParentDocumentRetriever

        Raises:
            Any exception from building the vector store, the docstore, the index
            or the retriever. The MongoClient created here is closed first.
        """
        client: MongoClient = MongoClient(
            connection_string,
            driver=DRIVER_METADATA,
        )
        try:
            collection = client[database_name][collection_name]
            vectorstore = MongoDBAtlasVectorSearch(
                collection=collection, embedding=embedding_model, **kwargs
            )
            docstore = MongoDBDocStore(collection=collection)
            # Index the key that child chunks use to reference their parent.
            docstore.collection.create_index([(id_key, pymongo.ASCENDING)])
            return cls(
                vectorstore=vectorstore,
                docstore=docstore,
                child_splitter=child_splitter,
                id_key=id_key,
                auto_create_index=auto_create_index,
                auto_index_timeout=auto_index_timeout,
                search_index_name=search_index_name,
                search_field=search_field,
                **kwargs,
            )
        except Exception:
            # Construction failed part-way: don't leak the client's connection
            # pool; propagate the original error unchanged.
            client.close()
            raise

    def close(self) -> None:
        """Close the resources used by the MongoDBAtlasParentDocumentRetriever."""
        self.vectorstore.close()
        self.docstore.close()