Source code for langgraph.store.mongodb.base

import logging
from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union

from bson import SON
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import run_in_executor
from langchain_mongodb.embeddings import AutoEmbeddings
from langgraph.store.base import (
    BaseStore,
    GetOp,
    IndexConfig,
    Item,
    ListNamespacesOp,
    NamespacePath,
    Op,
    PutOp,
    Result,
    SearchItem,
    SearchOp,
    TTLConfig,
)
from langgraph.store.base.embed import (
    AEmbeddingsFunc,
    EmbeddingsFunc,
    ensure_embeddings,
    get_text_at_path,
)
from pymongo import (
    DeleteOne,
    MongoClient,
    UpdateOne,
)
from pymongo.collection import Collection, ReturnDocument
from pymongo_search_utils import (
    append_client_metadata,
    autoembedding_vector_search_stage,
    create_vector_search_index,
    vector_search_stage,
)

from langgraph.store.mongodb.utils import DRIVER_METADATA

logger = logging.getLogger(__name__)


def _validate_filter(filter_dict: dict[str, Any]) -> None:
    for key, value in filter_dict.items():
        if not isinstance(key, str) or key.startswith("$"):
            raise ValueError(
                f"Invalid filter key '{key}': MongoDB operator keys are not allowed."
            )
        if isinstance(value, dict):
            _validate_filter(value)


[docs] class VectorIndexConfig(IndexConfig, total=False): """Configuration for a MongoDB Atlas Vector Index providing semantic search. Use the factory function, ~langgraph.store.mongodb.create_vector_index_config for convenient creation and sensible defaults. Unlike PostgreSQL, MongoDB does not require a separate package or vector store. Embeddings are stored and indexed within the collection that holds the data. For more: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-overview Vector Search can be Approximate Nearest Neighbor (ANN) or Exact (ENN). For ANN, Atlas uses HNSW (Hierarchical Navigable Small World). """ name: str """Name of the index attached to the Collection in the Atlas Database.""" relevance_score_fn: Literal["euclidean", "cosine", "dotProduct", None] """Similarity scoring function used to compare vectors.""" embedding_key: str """This key will contain the embedding vector for the value in fields[0]. MongoDB does not require a separate Vector Store. It is designed to have one vector per document. NOTE: If using AutoEmbeddings, the vectors are not explicitly stored in the Collection. Set dims to -1, relevance_score_fn to None. The embedding_key will not store vectors. Instead, it will be the texts to be embedded. """ filters: list[str] """List of fields to that can filtered during search. Fields must be included in the index if they are to be filtered upon. The `namespace` field will always be included. These will be added to the vector search index automatically if not present. Fields can refer to top-level or embedded values (e.g. metadata.available) NOTE: The `value` key at the front of every field is implicit. It need not be included, although it is contained in the collection and index. """
[docs] def create_vector_index_config( dims: int | None, embed: Union[Embeddings, EmbeddingsFunc, AEmbeddingsFunc, str], fields: Optional[list[str]] = None, name: str = "vector_index", relevance_score_fn: Literal["euclidean", "cosine", "dotProduct", None] = "cosine", embedding_key: str | None = "embedding", filters: Optional[list[str]] = None, ) -> VectorIndexConfig: """Factory function creates a VectorIndexConfig instance with sensible defaults. Args: dims: Dimensions of the embedding vectors. embed: Embedding model. fields: Field to extract text from for embedding generation (list of length 1). name: Arbitrary name to give to the index in Atlas. relevance_score_fn: Function used to establish similarity of vectors. embedding_key: Name of the field used in the collection to store vectors. filters: List of (possibly nested) fields to index allowing filtering. Returns: VectorIndexConfig to be passed to MongoDBStore constructor. """ MongoDBStore.ensure_index_filters(filters) if filters and "namespace_prefix" not in filters: filters.append("namespace_prefix") return VectorIndexConfig( dims=dims, embed=embed, fields=fields, name=name, relevance_score_fn=relevance_score_fn, embedding_key=embedding_key, filters=filters, )
[docs] class MongoDBStore(BaseStore): """MongoDB's persistent key-value stores for long-term memory. Stores enable persistence and memory that can be shared across threads, scoped to user IDs, assistant IDs, or other arbitrary namespaces. Supports semantic search capabilities through an optional `index` configuration. Only a single embedding is permitted per item, although embedded fields can be indexed via ["parent.child"]. """ supports_ttl: bool = True
[docs] def __init__( self, collection: Collection, ttl_config: Optional[TTLConfig] = None, index_config: Optional[VectorIndexConfig] = None, auto_index_timeout: int = 15, query_model: str | None = None, **kwargs: Any, ): """Construct store and its indexes. Semantic search is supported by an Atlas vector search index TTL (time to live) is supported by a TTL index of field: updated_at. Args: collection: Collection of Items backing the store. ttl_config: Optionally define a TTL and whether to update on reads(get/search). index_config: Optionally define a VectorIndexConfig for semantic search. auto_index_timeout: Optional timeout for creation of indexes. query_model: For semantic search, optionally provide a different model for search than indexing. Returns: Instance of MongoDBStore. """ self.collection = collection self.ttl_config = {} if ttl_config is None else ttl_config self.index_config = {} if index_config is None else index_config self.sep = kwargs.get("sep", "/") append_client_metadata( client=self.collection.database.client, driver_info=DRIVER_METADATA ) # Create indexes if not present # Unique compound index on (namespace_str, key) acts as the primary key. # namespace_str is the namespace tuple joined into a single string (e.g. # "users/alice/preferences"), which avoids the multikey-index collision # that occurs when indexing the namespace array directly. indexes = list(self.collection.list_indexes()) idx_keys = [idx["key"] for idx in indexes] # Always backfill namespace_str on legacy documents. Reads and deletes # now filter on namespace_str, so any document missing it is unreachable. # This also runs if the index already exists but an old client wrote new # documents without namespace_str after the migration. backfill_batch: list[UpdateOne] = [] for doc in self.collection.find( {"namespace_str": {"$exists": False}}, {"_id": 1, "namespace": 1}, ): namespace = doc.get("namespace") or () self._validate_namespace(namespace) backfill_batch.append( UpdateOne( {"_id": doc["_id"], "namespace_str": {"$exists": False}}, {"$set": {"namespace_str": self.sep.join(namespace)}}, ) ) if len(backfill_batch) >= 1000: self.collection.bulk_write(backfill_batch, ordered=False) backfill_batch = [] if backfill_batch: self.collection.bulk_write(backfill_batch, ordered=False) # Check for the unique index by both key pattern and unique option, so # a non-unique index with the same key pattern does not suppress creation. # If a conflicting non-unique index exists, drop it first — otherwise # create_index(..., unique=True) raises IndexOptionsConflict. ns_str_key = SON([("namespace_str", 1), ("key", 1)]) ns_str_indexes = [idx for idx in indexes if idx["key"] == ns_str_key] has_unique_ns_idx = any(idx.get("unique", False) for idx in ns_str_indexes) if not has_unique_ns_idx: conflicting = next( (idx for idx in ns_str_indexes if not idx.get("unique", False)), None ) if conflicting is not None: self.collection.drop_index(conflicting["name"]) self.collection.create_index(keys=["namespace_str", "key"], unique=True) # Drop the legacy multikey index only after the new unique index exists, # so there is never a window with no uniqueness enforcement. # Re-query to avoid acting on a stale snapshot (e.g. concurrent init). current_idx_keys = [idx["key"] for idx in self.collection.list_indexes()] if SON([("namespace", 1), ("key", 1)]) in current_idx_keys: self.collection.drop_index([("namespace", 1), ("key", 1)]) # Optionally, expire values using [TTL Index](https://www.mongodb.com/docs/manual/core/index-ttl/) if ( self.ttl_config and SON([("updated_at", 1)]) not in idx_keys and self.ttl_config["default_ttl"] ): self.collection.create_index( "updated_at", expireAfterSeconds=self.ttl_config["default_ttl"] ) # If details provided, prepare vector index for semantic queries if self.index_config: self.index_field = self._ensure_index_fields(self.index_config["fields"]) self.index_filters = self.__class__.ensure_index_filters( self.index_config["filters"] ) self.embeddings: Embeddings = ensure_embeddings( self.index_config.get("embed"), ) self._index_name = self.index_config.get("name", "vector_index") self._relevance_score_fn = self.index_config.get( "relevance_score_fn", "cosine" ) self._embedding_key = self.index_config.get("embedding_key", "embedding") auto_embedding_model = None self._is_autoembedding = False if isinstance(self.embeddings, AutoEmbeddings): self._is_autoembedding = True auto_embedding_model = self.embeddings.model self.query_model = ( self.embeddings.model if query_model is None else query_model ) # Create the vector index if it does not yet exist if not any( [ ix["name"] == self._index_name for ix in collection.list_search_indexes() ] ): create_vector_search_index( collection=collection, index_name=self._index_name, dimensions=self.index_config["dims"], path=self._embedding_key, similarity=self._relevance_score_fn, filters=self.index_filters, wait_until_complete=auto_index_timeout, auto_embedding_model=auto_embedding_model, )
[docs] @classmethod @contextmanager def from_conn_string( cls, conn_string: Optional[str] = None, db_name: str = "checkpointing_db", collection_name: str = "persistent-store", ttl_config: Optional[TTLConfig] = None, index_config: Optional[VectorIndexConfig] = None, **kwargs: Any, ) -> Iterator["MongoDBStore"]: """Context manager to create a persistent MongoDB key-value store. A unique compound index will be added to the collections backing the store on (namespace_str, key). On first initialization of an existing collection, legacy documents are backfilled with a namespace_str field and the old (namespace, key) index is replaced. On every initialization, documents missing namespace_str are backfilled; all other steps are no-ops if the collection is already up to date. If the `ttl` argument is provided, TTL functionality will be employed. This is done automatically via MongoDB's TTL Indexes, based on the `updated_at` field of the collection. The index will be created if it does not already exist. Args: conn_string: MongoDB connection string. See [class:~pymongo.MongoClient]. db_name: Database name. It will be created if it doesn't exist. collection_name: Collection name backing the store. Created if it doesn't exist. ttl_config: Defines a TTL (in seconds) and whether to update on reads(get/search). index_config: Defines a VectorIndexConfig for semantic search queries. Yields: A new MongoDBStore. """ client: Optional[MongoClient] = None try: client = MongoClient( conn_string, driver=DRIVER_METADATA, ) db = client[db_name] if collection_name not in db.list_collection_names( authorizedCollections=True ): db.create_collection(collection_name) collection = client[db_name][collection_name] yield MongoDBStore( collection=collection, ttl_config=ttl_config, index_config=index_config, **kwargs, ) finally: if client: client.close()
[docs] def get( self, namespace: tuple[str, ...], key: str, *, refresh_ttl: Optional[bool] = None, ) -> Optional[Item]: """Retrieve a single item. Args: namespace: Hierarchical path for the item. key: Unique identifier within the namespace. refresh_ttl: Whether to refresh TTLs for the returned item. If None (default), uses the store's default refresh_ttl setting. If no TTL is specified, this argument is ignored. Returns: The retrieved item or None if not found. """ self._validate_namespace(namespace) ns_str = self.sep.join(namespace) if refresh_ttl is False or ( self.ttl_config and not self.ttl_config["refresh_on_read"] ): res = self.collection.find_one( filter={"namespace_str": ns_str, "key": key}, ) else: res = self.collection.find_one_and_update( filter={"namespace_str": ns_str, "key": key}, update={"$set": {"updated_at": datetime.now(tz=timezone.utc)}}, return_document=ReturnDocument.AFTER, ) if res: return Item( value=res["value"], key=res["key"], namespace=tuple(res["namespace"]), created_at=res["created_at"], updated_at=res["updated_at"], )
[docs] def delete(self, namespace: tuple[str, ...], key: str) -> None: """Delete an item. Args: namespace: Hierarchical path for the item. key: Unique identifier within the namespace. """ self._validate_namespace(namespace) self.collection.delete_one( {"namespace_str": self.sep.join(namespace), "key": key} )
def _validate_namespace(self, namespace: Sequence[str]) -> None: """Raise ValueError if any namespace part is empty or contains the separator.""" for part in namespace: if not part: raise ValueError( f"Namespace parts must not be empty. Got namespace: {namespace}" ) if self.sep in part: raise ValueError( f"Namespace parts must not contain the separator {self.sep!r}. " f"Got namespace: {namespace}" ) @staticmethod def _match_prefix(prefix: NamespacePath) -> dict[str, Any]: """Helper for list_namespaces.""" if not prefix or prefix == "*": return {} if "*" not in prefix: return {"$eq": [{"$slice": ["$namespace", len(prefix)]}, list(prefix)]} matches = [] for i, p in enumerate(prefix): if p != "*": matches.append({"$eq": [{"$arrayElemAt": ["$namespace", i]}, p]}) return {"$and": matches} @staticmethod def _match_suffix(suffix: NamespacePath) -> dict[str, Any]: """Helper for list_namespaces.""" if not suffix or suffix == "*": return {} if "*" not in suffix: return {"$eq": [{"$slice": ["$namespace", -1 * len(suffix)]}, list(suffix)]} matches = [] for i, p in enumerate(suffix): if p != "*": matches.append( { "$eq": [ { "$arrayElemAt": [ "$namespace", {"$subtract": [{"$size": "$namespace"}, i]}, ] }, p, ] } ) return {"$and": matches}
[docs] def list_namespaces( self, *, prefix: Optional[NamespacePath] = None, suffix: Optional[NamespacePath] = None, max_depth: Optional[int] = None, limit: int = 100, offset: int = 0, ) -> list[tuple[str, ...]]: """List and filter namespaces in the store. Args: prefix: Filter namespaces that start with this path. suffix: Filter namespaces that end with this path. max_depth: Return namespaces up to this depth in the hierarchy. limit: Maximum number of namespaces to return (default 100). offset: Number of namespaces to skip for pagination. [Not implemented.] Returns: A list of namespace tuples that match the criteria. """ pipeline: list[dict[str, Any]] = [] expr = {} if prefix: precond = self._match_prefix(prefix) expr = {"$expr": precond} if suffix: sufcond = self._match_suffix(suffix) expr = {"$expr": sufcond} if prefix and suffix: expr = {"$expr": {"$and": [precond, sufcond]}} pipeline.append({"$match": expr}) if max_depth: pipeline.append( { "$project": { "namespace": {"$slice": ["$namespace", max_depth]}, "_id": 0, } } ) else: pipeline.append({"$project": {"namespace": 1, "_id": 0}}) if limit: pipeline.append({"$limit": limit}) # Deduplicate pipeline.extend( [ {"$group": {"_id": "$namespace"}}, {"$project": {"_id": 0, "namespace": "$_id"}}, ] ) if offset: raise NotImplementedError("offset is not implemented") results = self.collection.aggregate(pipeline) return [tuple(res["namespace"]) for res in results]
[docs] def batch(self, ops: Iterable[Op]) -> list[Result]: """Execute multiple operations synchronously in a single batch. Get, Search, and List operations are performed on state before batch. Put and Delete change state. They are deduplicated and applied in order, but only after the read operations have completed. Args: ops: An iterable of operations to execute. Returns: A list of results, where each result corresponds to an operation in the input. The order of results matches the order of input operations. The length of output may not match the input as PutOp returns None. """ results: list[Result] = [] dedupped_putops: dict[tuple[tuple[str, ...], str], PutOp] = {} writes: list[Union[DeleteOne, UpdateOne]] = [] for op in ops: if isinstance(op, PutOp): self._validate_namespace(op.namespace) dedupped_putops[(op.namespace, op.key)] = op results.append(None) elif isinstance(op, GetOp): results.append( self.get( namespace=op.namespace, key=op.key, refresh_ttl=op.refresh_ttl, ) ) elif isinstance(op, SearchOp): results.append( self.search( op.namespace_prefix, query=op.query, filter=op.filter, limit=op.limit, offset=op.offset, refresh_ttl=op.refresh_ttl, ) ) elif isinstance(op, ListNamespacesOp): prefix = None suffix = None if op.match_conditions: for cond in op.match_conditions: if cond.match_type == "prefix": prefix = cond.path elif cond.match_type == "suffix": suffix = cond.path else: raise ValueError( f"Match type {cond.match_type} must be prefix or suffix." ) results.append( self.list_namespaces( prefix=prefix, suffix=suffix, max_depth=op.max_depth, limit=op.limit, offset=op.offset, ) ) # Apply puts and deletes in bulk # Extract texts to embed for each op if self.index_config: texts = self._extract_texts(list(dedupped_putops.values())) if not self._is_autoembedding: vectors = self.embeddings.embed_documents(texts) v = 0 for op in dedupped_putops.values(): ns_str = self.sep.join(op.namespace) if op.value is None: # mark the item for deletion. writes.append( DeleteOne( filter={ "namespace_str": ns_str, "key": op.key, } ) ) else: # Add or Upsert the value to_set = { "namespace_str": ns_str, "value": op.value, "updated_at": datetime.now(tz=timezone.utc), } if self.index_config: embed = texts[v] if self._is_autoembedding else vectors[v] to_set[self._embedding_key] = embed to_set["namespace_prefix"] = self._denormalize_path(op.namespace) v += 1 writes.append( UpdateOne( filter={ "namespace_str": ns_str, "key": op.key, }, update={ "$set": to_set, "$setOnInsert": { "namespace": list(op.namespace), "created_at": datetime.now(tz=timezone.utc), }, }, upsert=True, ) ) if writes: self.collection.bulk_write(writes) return results
[docs] async def abatch(self, ops: Iterable[Op]) -> list[Result]: """Execute multiple operations asynchronously in a single batch. Args: ops: An iterable of operations to execute. Returns: A list of results, where each result corresponds to an operation in the input. The order of results matches the order of input operations. """ return await run_in_executor(None, self.batch, ops)
[docs] def search( self, namespace_prefix: tuple[str, ...], /, *, query: Optional[str] = None, filter: Optional[dict[str, Any]] = None, limit: int = 10, offset: int = 0, refresh_ttl: Optional[bool] = None, **kwargs: Any, ) -> list[SearchItem]: """Search for items within a namespace prefix. Values are stored in the collection as a document of name 'value'. One uses dot notation to access embedded fields. For example, `value.text`, `value.address.city` and for arrays `value.titles.3`. Note that when search will not update the `last_updated` field and thus not affect TTL, unlike `get`. Args: namespace_prefix: Hierarchical path prefix to search within. query: Optional query for natural language search. filter: Key-value pairs to filter results. limit: Maximum number of items to return. offset: Number of items to skip before returning results. refresh_ttl: TTL is not supported for search. Use get if needed. Returns: List of items matching the search criteria. ???+ example "Examples" Basic filtering: ```python # Search for documents with specific metadata results = store.search( ("docs",), filter={"value.type": "article", "value.status": "published"} ) ``` Natural language search (requires vector store implementation): ```python # Initialize store with embedding configuration store = YourStore( # e.g., InMemoryStore, AsyncPostgresStore index={ "dims": 1536, # embedding dimensions "embed": your_embedding_function, # function to create embeddings "fields": ["text"] # fields to embed. Defaults to ["$"] } ) """ if not isinstance(namespace_prefix, tuple): raise TypeError("namespace_prefix must be a non-empty tuple of strings") if offset: raise NotImplementedError("offset is not implemented in MongoDBStore") if filter: if any(f.startswith("value") for f in filter): raise ValueError("filters should be specified without `value`") _validate_filter(filter) if query is None: # Case 1. $match namespace and filter pipeline: list[dict[str, Any]] = [] match_cond: dict[str, Any] = {} if namespace_prefix: match_cond = {"$expr": self._match_prefix(namespace_prefix)} if filter: filter_cond = [{f"value.{k}": v} for k, v in filter.items()] match_cond = {"$and": [match_cond] + filter_cond} pipeline.append({"$match": match_cond}) if limit: pipeline.append({"$limit": limit}) elif query: # Case 2. $vectorSearch on query filtered on namespace and optional filters # Form filter condition for namespace_prefix filter_vec = {"namespace_prefix": self.sep.join(namespace_prefix)} if filter: # and add any specified filter_cond = [{f"value.{k}": v} for k, v in filter.items()] filter_vec = {"$and": [filter_vec] + filter_cond} if not self._is_autoembedding: query_vector = self.embeddings.embed_query(query) pipeline = [ vector_search_stage( query_vector=query_vector, search_field=self._embedding_key, index_name=self._index_name, top_k=limit, filter=filter_vec, ), {"$set": {"score": {"$meta": "vectorSearchScore"}}}, {"$project": {self._embedding_key: 0}}, ] else: # Case 2b. $vectorSearch uses autoEmbed index. pipeline = [ autoembedding_vector_search_stage( query=query, search_field=self._embedding_key, index_name=self._index_name, model=self.query_model, top_k=limit, filter=filter_vec, ), {"$set": {"score": {"$meta": "vectorSearchScore"}}}, ] results = self.collection.aggregate(pipeline) return [ SearchItem( namespace=tuple(res["namespace"]), key=res["key"], value=res["value"], created_at=res["created_at"], updated_at=res["updated_at"], score=res.get("score"), ) for res in results ]
def _denormalize_path(self, paths: Union[tuple[str, ...], list[str]]) -> list[str]: """Create list of path parents, for use in $vectorSearch filter. ???+ example "Example" ```python namespace = ('parent', 'child', 'pet') prefixes=store_mdb.denormalize_path(namespace) assert prefixes == ['parent', 'parent/child', 'parent/child/pet'] ``` """ return [self.sep.join(paths[:i]) for i in range(1, len(paths) + 1)] def _extract_texts(self, put_ops: Optional[list[PutOp]]) -> list[str]: """Extract text to embed according to index config.""" if put_ops and self.index_config and self.embeddings: to_embed = [] for op in put_ops: if op.value is not None and op.index is not False: if op.index is None: field = self.index_field else: field = self._ensure_index_fields(list(op.index)) texts = get_text_at_path(op.value, field) if texts: if len(texts) > 1: raise ValueError("Got multiple texts. Report as bug.") else: to_embed.append(texts[0]) return to_embed else: return [] @staticmethod def _ensure_index_fields(fields: Optional[list[str]]) -> str: """Ensure that requested fields to be indexed result in a single vector. We require that one document may only have one embedding vector. """ if fields and (len(fields) > 1 or "*" in fields[0]): raise ValueError("Only one field can be indexed for queries.") if isinstance(fields, list): return fields[0] else: return fields
[docs] @classmethod def ensure_index_filters(cls, filters: Optional[list[str]] = None) -> list[str]: """Prepare filters for Atlas indexing. We must ensure that `namespace_prefix` is included in the filter. We also must ensure that the implicit `value` field is added. """ filters = [] if filters is None else filters if not isinstance(filters, list): raise ValueError( "Index filters must be a list. Found: ", type(filters), ) filters = [ f"value.{field}" for field in filters if not field.startswith("values") and field != "namespace_prefix" ] if "namespace_prefix" not in filters: filters.append("namespace_prefix") return filters