import asyncio
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import Any, Optional

from langchain_core.runnables import RunnableConfig, run_in_executor
from pymongo import ASCENDING, MongoClient, UpdateOne
from pymongo.database import Database as MongoDatabase

from langgraph.checkpoint.base import (
    WRITES_IDX_MAP,
    BaseCheckpointSaver,
    ChannelVersions,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
    get_checkpoint_id,
)

from .utils import DRIVER_METADATA, dumps_metadata, loads_metadata
class MongoDBSaver(BaseCheckpointSaver):
    """A checkpointer that stores StateGraph checkpoints in a MongoDB database.

    A compound index as shown below will be added to each of the collections
    backing the saver (checkpoints, pending writes). If the collections pre-exist,
    and have indexes already, nothing will be done during initialization::

        keys=[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)],
        unique=True,

    The pending-writes index additionally includes ``task_id`` and ``idx``.
    When ``ttl`` is given, a TTL index on ``created_at`` is also created on
    each collection.

    Args:
        client (MongoClient): The MongoDB connection.
        db_name (Optional[str]): Database name.
        checkpoint_collection_name (Optional[str]): Name of Collection of Checkpoints.
        writes_collection_name (Optional[str]): Name of Collection of intermediate writes.
        ttl (Optional[int]): Time to live in seconds. See https://www.mongodb.com/docs/manual/core/index-ttl/.

    Examples:
        >>> from langgraph.checkpoint.mongodb import MongoDBSaver
        >>> from langgraph.graph import StateGraph
        >>> from pymongo import MongoClient
        >>>
        >>> builder = StateGraph(int)
        >>> builder.add_node("add_one", lambda x: x + 1)
        >>> builder.set_entry_point("add_one")
        >>> builder.set_finish_point("add_one")
        >>> client = MongoClient("mongodb://localhost:27017")
        >>> memory = MongoDBSaver(client)
        >>> graph = builder.compile(checkpointer=memory)
        >>> config = {"configurable": {"thread_id": "1"}}
        >>> graph.get_state(config)
        >>> result = graph.invoke(3, config)
        >>> graph.get_state(config)
        StateSnapshot(values=4, next=(), config={'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef8b22d-df71-6ddc-8001-7c821b5c45fd'}}, metadata={'source': 'loop', 'writes': {'add_one': 4}, 'step': 1, 'parents': {}}, created_at='2024-10-15T18:25:34.088329+00:00', parent_config={'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef8b22d-df6f-6eec-8000-20f621dcf3b7'}}, tasks=())

        Adding sharding support:

        >>> from langgraph.checkpoint.mongodb import MongoDBSaver
        >>> from pymongo import MongoClient
        >>> client = MongoClient("mongodb://localhost:27017")
        >>> memory = MongoDBSaver(client)
        >>> client.admin.command('enableSharding', memory.db.name)
        >>> shard_key = {'your_shard_key': 1}  # Specify your shard key
        >>> client.admin.command('shardCollection', memory.checkpoint_collection.full_name, key=shard_key)
        >>> client.admin.command('shardCollection', memory.writes_collection.full_name, key=shard_key)
    """

    client: MongoClient
    db: MongoDatabase

    def __init__(
        self,
        client: MongoClient,
        db_name: str = "checkpointing_db",
        checkpoint_collection_name: str = "checkpoints",
        writes_collection_name: str = "checkpoint_writes",
        ttl: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Bind the saver to ``client[db_name]`` and ensure indexes exist.

        Args:
            client: The MongoDB connection. It is stored as-is; ``close()``
                closes it.
            db_name: Database name.
            checkpoint_collection_name: Collection holding checkpoints.
            writes_collection_name: Collection holding pending writes.
            ttl: Optional time to live in seconds. When set, documents get a
                ``created_at`` timestamp and a TTL index is created.
            **kwargs: Accepted for compatibility; currently ignored.
        """
        super().__init__()
        self.client = client
        self.db = self.client[db_name]
        self.checkpoint_collection = self.db[checkpoint_collection_name]
        self.writes_collection = self.db[writes_collection_name]
        self.ttl = ttl
        self._ensure_indexes(
            self.checkpoint_collection,
            [("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)],
        )
        self._ensure_indexes(
            self.writes_collection,
            [
                ("thread_id", 1),
                ("checkpoint_ns", 1),
                ("checkpoint_id", -1),
                ("task_id", 1),
                ("idx", 1),
            ],
        )

    def _ensure_indexes(
        self, collection: Any, keys: Sequence[tuple[str, int]]
    ) -> None:
        """Create the unique compound index (and TTL index) on a bare collection.

        Nothing is done when the collection already has two or more indexes.
        A collection normally carries only the default ``_id`` index, so fewer
        than two means our indexes have not been created yet.
        """
        if len(collection.list_indexes().to_list()) >= 2:
            return
        collection.create_index(keys=list(keys), unique=True)
        if self.ttl:
            collection.create_index(
                keys=[("created_at", ASCENDING)],
                expireAfterSeconds=self.ttl,
            )

    @staticmethod
    def _now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime.

        Used for ``created_at`` (the TTL field). A naive ``datetime.now()``
        would be local time, but PyMongo stores naive datetimes as if they
        were UTC, which would skew expiry by the host's UTC offset.
        """
        return datetime.now(timezone.utc)

    @classmethod
    @contextmanager
    def from_conn_string(
        cls,
        conn_string: Optional[str] = None,
        db_name: str = "checkpointing_db",
        checkpoint_collection_name: str = "checkpoints",
        writes_collection_name: str = "checkpoint_writes",
        ttl: Optional[int] = None,
        **kwargs: Any,
    ) -> Iterator["MongoDBSaver"]:
        """Context manager to create a MongoDB checkpoint saver.

        A compound index as shown below will be added to each of the collections
        backing the saver (checkpoints, pending writes). If the collections pre-exist,
        and have indexes already, nothing will be done during initialization::

            keys=[("thread_id", 1), ("checkpoint_ns", 1), ("checkpoint_id", -1)],
            unique=True

        Args:
            conn_string: MongoDB connection string. See :class:`~pymongo.MongoClient`.
            db_name: Database name. It will be created if it doesn't exist.
            checkpoint_collection_name: Checkpoint Collection name. Created if it doesn't exist.
            writes_collection_name: Collection name of intermediate writes. Created if it doesn't exist.
            ttl (Optional[int]): Time to live in seconds.

        Yields:
            A new saver of this class (``cls``, so subclasses get their own type).
            The underlying client is closed when the context exits.
        """
        client: Optional[MongoClient] = None
        try:
            client = MongoClient(
                conn_string,
                driver=DRIVER_METADATA,
            )
            yield cls(
                client,
                db_name,
                checkpoint_collection_name,
                writes_collection_name,
                ttl,
                **kwargs,
            )
        finally:
            if client:
                client.close()

    def close(self) -> None:
        """Close the resources used by the MongoDBSaver."""
        self.client.close()

    def _load_pending_writes(
        self, thread_id: Any, checkpoint_ns: str, checkpoint_id: str
    ) -> Sequence[tuple[str, str, Any]]:
        """Fetch and deserialize the pending writes stored for one checkpoint.

        Returns:
            A list of ``(task_id, channel, value)`` tuples, in cursor order.
        """
        # NOTE: annotations in this class body avoid ``list[...]`` because the
        # ``list`` method below shadows the builtin inside the class namespace.
        cursor = self.writes_collection.find(
            {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
            }
        )
        return [
            (
                write_doc["task_id"],
                write_doc["channel"],
                self.serde.loads_typed((write_doc["type"], write_doc["value"])),
            )
            for write_doc in cursor
        ]

    def _doc_to_checkpoint_tuple(
        self, doc: dict[str, Any], thread_id: Any, checkpoint_ns: str
    ) -> CheckpointTuple:
        """Build a CheckpointTuple (with pending writes) from a checkpoint document.

        ``thread_id`` and ``checkpoint_ns`` are passed in explicitly so callers
        can choose whether the returned configs echo the request or the document.
        """
        checkpoint_id = doc["checkpoint_id"]
        parent_checkpoint_id = doc.get("parent_checkpoint_id")
        parent_config = (
            {
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": parent_checkpoint_id,
                }
            }
            if parent_checkpoint_id
            else None
        )
        return CheckpointTuple(
            config={
                "configurable": {
                    "thread_id": thread_id,
                    "checkpoint_ns": checkpoint_ns,
                    "checkpoint_id": checkpoint_id,
                }
            },
            checkpoint=self.serde.loads_typed((doc["type"], doc["checkpoint"])),
            metadata=loads_metadata(doc["metadata"]),
            parent_config=parent_config,
            pending_writes=self._load_pending_writes(
                thread_id, checkpoint_ns, checkpoint_id
            ),
        )

    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Get a checkpoint tuple from the database.

        This method retrieves a checkpoint tuple from the MongoDB database based on the
        provided config. If the config contains a checkpoint ID, the checkpoint with
        the matching thread ID, namespace and checkpoint ID is retrieved. Otherwise, the
        checkpoint with the highest checkpoint ID (the latest) for the given thread ID
        and namespace is retrieved. ``checkpoint_ns`` defaults to ``""``.

        Args:
            config (RunnableConfig): The config to use for retrieving the checkpoint.

        Returns:
            Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.

        Examples:
            Basic:

            >>> config = {"configurable": {"thread_id": "1"}}
            >>> checkpoint_tuple = memory.get_tuple(config)
            >>> print(checkpoint_tuple)
            CheckpointTuple(...)

            With checkpoint ID:

            >>> config = {
            ...     "configurable": {
            ...         "thread_id": "1",
            ...         "checkpoint_ns": "",
            ...         "checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875",
            ...     }
            ... }
            >>> checkpoint_tuple = memory.get_tuple(config)
            >>> print(checkpoint_tuple)
            CheckpointTuple(...)
        """
        configurable = config["configurable"]
        thread_id = configurable["thread_id"]
        checkpoint_ns = configurable.get("checkpoint_ns", "")
        query: dict[str, Any] = {
            "thread_id": thread_id,
            "checkpoint_ns": checkpoint_ns,
        }
        if checkpoint_id := get_checkpoint_id(config):
            query["checkpoint_id"] = checkpoint_id
        result = self.checkpoint_collection.find(
            query, sort=[("checkpoint_id", -1)], limit=1
        )
        for doc in result:
            # Echo the requested thread/namespace back in the returned configs.
            return self._doc_to_checkpoint_tuple(doc, thread_id, checkpoint_ns)
        return None

    def list(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> Iterator[CheckpointTuple]:
        """List checkpoints from the database.

        This method retrieves a list of checkpoint tuples from the MongoDB database based
        on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first).

        Args:
            config (Optional[RunnableConfig]): The config to use for listing the checkpoints.
                Only ``thread_id`` and ``checkpoint_ns`` are used as filters, when present.
            filter (Optional[dict[str, Any]]): Additional filtering criteria for metadata. Defaults to None.
            before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None.
            limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None.

        Yields:
            Iterator[CheckpointTuple]: An iterator of checkpoint tuples.

        Examples:
            >>> from langgraph.checkpoint.mongodb import MongoDBSaver
            >>> with MongoDBSaver.from_conn_string("mongodb://localhost:27017") as memory:
            ...     # Run a graph, then list the checkpoints
            ...     config = {"configurable": {"thread_id": "1"}}
            ...     checkpoints = list(memory.list(config, limit=2))
            >>> print(checkpoints)
            [CheckpointTuple(...), CheckpointTuple(...)]
        """
        query: dict[str, Any] = {}
        if config is not None:
            configurable = config["configurable"]
            if "thread_id" in configurable:
                query["thread_id"] = configurable["thread_id"]
            if "checkpoint_ns" in configurable:
                query["checkpoint_ns"] = configurable["checkpoint_ns"]
        if filter:
            for key, value in filter.items():
                query[f"metadata.{key}"] = dumps_metadata(value)
        if before is not None:
            query["checkpoint_id"] = {"$lt": before["configurable"]["checkpoint_id"]}
        result = self.checkpoint_collection.find(
            # A limit of 0 means "no limit" to MongoDB.
            query,
            limit=0 if limit is None else limit,
            sort=[("checkpoint_id", -1)],
        )
        for doc in result:
            yield self._doc_to_checkpoint_tuple(
                doc, doc["thread_id"], doc["checkpoint_ns"]
            )

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Save a checkpoint to the database.

        This method saves (upserts) a checkpoint to the MongoDB database. The checkpoint
        is keyed by thread ID, namespace and ``checkpoint["id"]``; the checkpoint ID found
        in ``config`` (if any) is recorded as its parent. ``config["metadata"]`` (if any)
        is merged over ``metadata`` before storing. ``checkpoint_ns`` defaults to ``""``.

        Args:
            config (RunnableConfig): The config to associate with the checkpoint.
            checkpoint (Checkpoint): The checkpoint to save.
            metadata (CheckpointMetadata): Additional metadata to save with the checkpoint.
            new_versions (ChannelVersions): New channel versions as of this write (unused here).

        Returns:
            RunnableConfig: Updated configuration after storing the checkpoint.

        Examples:
            >>> from langgraph.checkpoint.mongodb import MongoDBSaver
            >>> with MongoDBSaver.from_conn_string("mongodb://localhost:27017") as memory:
            ...     config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}}
            ...     checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "data": {"key": "value"}}
            ...     saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {})
            >>> print(saved_config)
            {'configurable': {'thread_id': '1', 'checkpoint_ns': '', 'checkpoint_id': '1ef4f797-8335-6428-8001-8a1503f9b875'}}
        """
        configurable = config["configurable"]
        thread_id = configurable["thread_id"]
        checkpoint_ns = configurable.get("checkpoint_ns", "")
        checkpoint_id = checkpoint["id"]
        type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
        # Copy so the caller's metadata dict is not mutated by the merge.
        metadata = metadata.copy()
        metadata.update(config.get("metadata", {}))
        doc: dict[str, Any] = {
            "parent_checkpoint_id": configurable.get("checkpoint_id"),
            "type": type_,
            "checkpoint": serialized_checkpoint,
            "metadata": dumps_metadata(metadata),
        }
        if self.ttl:
            doc["created_at"] = self._now()
        upsert_query = {
            "thread_id": thread_id,
            "checkpoint_ns": checkpoint_ns,
            "checkpoint_id": checkpoint_id,
        }
        self.checkpoint_collection.update_one(upsert_query, {"$set": doc}, upsert=True)
        return {
            "configurable": {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
            }
        }

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Store intermediate writes linked to a checkpoint.

        This method upserts intermediate writes associated with a checkpoint to the
        MongoDB database in a single bulk write. An empty ``writes`` sequence is a no-op.
        ``checkpoint_ns`` defaults to ``""``; ``checkpoint_id`` is required.

        Args:
            config (RunnableConfig): Configuration of the related checkpoint.
            writes (Sequence[tuple[str, Any]]): List of writes to store, each as (channel, value) pair.
            task_id (str): Identifier for the task creating the writes.
            task_path (str): Path of the task creating the writes.
        """
        if not writes:
            # PyMongo's bulk_write rejects an empty operation list.
            return
        configurable = config["configurable"]
        thread_id = configurable["thread_id"]
        checkpoint_ns = configurable.get("checkpoint_ns", "")
        checkpoint_id = configurable["checkpoint_id"]
        # Overwrite existing documents only when every write targets a channel in
        # WRITES_IDX_MAP (special channels, e.g. errors); otherwise the first stored
        # value wins via $setOnInsert.
        set_method = (
            "$set" if all(w[0] in WRITES_IDX_MAP for w in writes) else "$setOnInsert"
        )
        created_at = self._now() if self.ttl else None
        operations = []
        for idx, (channel, value) in enumerate(writes):
            upsert_query = {
                "thread_id": thread_id,
                "checkpoint_ns": checkpoint_ns,
                "checkpoint_id": checkpoint_id,
                "task_id": task_id,
                "task_path": task_path,
                # Special channels use a fixed index; others use their position.
                "idx": WRITES_IDX_MAP.get(channel, idx),
            }
            type_, serialized_value = self.serde.dumps_typed(value)
            update_doc: dict[str, Any] = {
                "channel": channel,
                "type": type_,
                "value": serialized_value,
            }
            if created_at is not None:
                update_doc["created_at"] = created_at
            operations.append(
                UpdateOne(
                    filter=upsert_query,
                    update={set_method: update_doc},
                    upsert=True,
                )
            )
        self.writes_collection.bulk_write(operations)

    def delete_thread(
        self,
        thread_id: str,
    ) -> None:
        """Delete all checkpoints and writes associated with a specific thread ID.

        Args:
            thread_id (str): The thread ID whose checkpoints should be deleted.
        """
        # Delete all checkpoints associated with the thread ID
        self.checkpoint_collection.delete_many({"thread_id": thread_id})
        # Delete all writes associated with the thread ID
        self.writes_collection.delete_many({"thread_id": thread_id})

    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
        """Asynchronously fetch a checkpoint tuple using the given configuration.

        Asynchronously wraps the blocking `self.get_tuple` method by running it
        in an executor.

        Args:
            config: Configuration specifying which checkpoint to retrieve.

        Returns:
            Optional[CheckpointTuple]: The requested checkpoint tuple, or None if not found.
        """
        return await run_in_executor(None, self.get_tuple, config)

    async def alist(
        self,
        config: Optional[RunnableConfig],
        *,
        filter: Optional[dict[str, Any]] = None,
        before: Optional[RunnableConfig] = None,
        limit: Optional[int] = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """Asynchronously list checkpoints that match the given criteria.

        Asynchronously wraps the blocking `self.list` generator. It runs in an
        executor thread and hands each item to the event loop through an
        asyncio.Queue.

        Because the executor call is awaited to completion before the queue is
        drained, the whole listing is materialized before the first item is
        yielded. If `self.list` raises, the exception propagates before any
        item is yielded.

        Args:
            config: Configuration object passed to `self.list`.
            filter: Optional filter dictionary.
            before: Optional parameter to limit results before a given checkpoint.
            limit: Optional maximum number of results to yield.

        Yields:
            AsyncIterator[CheckpointTuple]: An iterator of checkpoint tuples.
        """
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue[CheckpointTuple] = asyncio.Queue()
        sentinel = object()

        def run() -> None:
            try:
                for item in self.list(
                    config, filter=filter, before=before, limit=limit
                ):
                    # asyncio.Queue is not thread-safe; schedule the put on the loop.
                    loop.call_soon_threadsafe(queue.put_nowait, item)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, sentinel)  # type: ignore

        await run_in_executor(None, run)
        while True:
            item = await queue.get()
            if item is sentinel:
                break
            yield item

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Asynchronously store a checkpoint with its configuration and metadata.

        Asynchronously wraps the blocking `self.put` method by running it in an
        executor.

        Args:
            config: Configuration for the checkpoint.
            checkpoint: The checkpoint to store.
            metadata: Additional metadata for the checkpoint.
            new_versions: New channel versions as of this write.

        Returns:
            RunnableConfig: Updated configuration after storing the checkpoint.
        """
        return await run_in_executor(
            None, self.put, config, checkpoint, metadata, new_versions
        )

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Asynchronously store intermediate writes linked to a checkpoint.

        Asynchronously wraps the blocking `self.put_writes` method by running it
        in an executor.

        Args:
            config: Configuration of the related checkpoint.
            writes: List of writes to store.
            task_id: Identifier for the task creating the writes.
            task_path: Path of the task creating the writes.
        """
        return await run_in_executor(
            None, self.put_writes, config, writes, task_id, task_path
        )

    async def adelete_thread(
        self,
        thread_id: str,
    ) -> None:
        """Delete all checkpoints and writes associated with a specific thread ID.

        Asynchronously wraps the blocking `self.delete_thread` method by running
        it in an executor.

        Args:
            thread_id: The thread ID whose checkpoints should be deleted.
        """
        return await run_in_executor(None, self.delete_thread, thread_id)