Skip to main content

synaptic_mongodb/
checkpointer.rs

1use async_trait::async_trait;
2use bson::{doc, DateTime as BsonDateTime};
3use futures::TryStreamExt;
4use mongodb::{Collection, Database, IndexModel};
5use synaptic_core::SynapticError;
6use synaptic_graph::{Checkpoint, CheckpointConfig, Checkpointer};
7
8/// MongoDB-backed graph checkpointer.
9///
10/// Stores graph state checkpoints in a MongoDB collection, suitable for
11/// distributed deployments where multiple processes share checkpointed state.
12///
13/// # Example
14///
15/// ```rust,no_run
16/// use synaptic_mongodb::MongoCheckpointer;
17///
18/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
19/// let client = mongodb::Client::with_uri_str("mongodb://localhost:27017").await?;
20/// let db = client.database("myapp");
21/// let checkpointer = MongoCheckpointer::new(&db, "graph_checkpoints").await?;
22/// # Ok(())
23/// # }
24/// ```
25pub struct MongoCheckpointer {
26    collection: Collection<bson::Document>,
27}
28
29impl MongoCheckpointer {
30    /// Create a new `MongoCheckpointer` backed by the given MongoDB database and collection.
31    ///
32    /// Creates a compound index on `(thread_id, checkpoint_id)` and a secondary
33    /// index on `(thread_id, seq)` for efficient ordered retrieval.
34    pub async fn new(db: &Database, collection_name: &str) -> Result<Self, SynapticError> {
35        let collection: Collection<bson::Document> = db.collection(collection_name);
36
37        // Unique index on (thread_id, checkpoint_id) — deduplicates puts
38        let unique_idx = IndexModel::builder()
39            .keys(doc! { "thread_id": 1, "checkpoint_id": 1 })
40            .options(
41                mongodb::options::IndexOptions::builder()
42                    .unique(true)
43                    .build(),
44            )
45            .build();
46
47        // Index on (thread_id, seq) for ordered listing and latest retrieval
48        let seq_idx = IndexModel::builder()
49            .keys(doc! { "thread_id": 1, "seq": 1 })
50            .build();
51
52        collection
53            .create_index(unique_idx)
54            .await
55            .map_err(|e| SynapticError::Store(format!("MongoDB create unique index: {e}")))?;
56
57        collection
58            .create_index(seq_idx)
59            .await
60            .map_err(|e| SynapticError::Store(format!("MongoDB create seq index: {e}")))?;
61
62        Ok(Self { collection })
63    }
64}
65
66#[async_trait]
67impl Checkpointer for MongoCheckpointer {
68    async fn put(
69        &self,
70        config: &CheckpointConfig,
71        checkpoint: &Checkpoint,
72    ) -> Result<(), SynapticError> {
73        // Serialize the checkpoint to JSON string for storage
74        let state_json = serde_json::to_string(checkpoint)
75            .map_err(|e| SynapticError::Store(format!("Serialize: {e}")))?;
76
77        // Determine next seq number for this thread
78        let count = self
79            .collection
80            .count_documents(doc! { "thread_id": &config.thread_id })
81            .await
82            .map_err(|e| SynapticError::Store(format!("MongoDB count: {e}")))?;
83
84        let document = doc! {
85            "thread_id": &config.thread_id,
86            "checkpoint_id": &checkpoint.id,
87            "seq": count as i64,
88            "state": &state_json,
89            "created_at": BsonDateTime::now(),
90        };
91
92        // Use upsert to be idempotent — same (thread_id, checkpoint_id) replaces
93        self.collection
94            .update_one(
95                doc! {
96                    "thread_id": &config.thread_id,
97                    "checkpoint_id": &checkpoint.id
98                },
99                doc! { "$setOnInsert": document },
100            )
101            .with_options(
102                mongodb::options::UpdateOptions::builder()
103                    .upsert(true)
104                    .build(),
105            )
106            .await
107            .map_err(|e| SynapticError::Store(format!("MongoDB upsert: {e}")))?;
108
109        Ok(())
110    }
111
112    async fn get(&self, config: &CheckpointConfig) -> Result<Option<Checkpoint>, SynapticError> {
113        let filter = if let Some(ref id) = config.checkpoint_id {
114            doc! { "thread_id": &config.thread_id, "checkpoint_id": id }
115        } else {
116            doc! { "thread_id": &config.thread_id }
117        };
118
119        let opts = mongodb::options::FindOneOptions::builder()
120            .sort(doc! { "seq": -1 })
121            .build();
122
123        let result = self
124            .collection
125            .find_one(filter)
126            .with_options(opts)
127            .await
128            .map_err(|e| SynapticError::Store(format!("MongoDB find_one: {e}")))?;
129
130        match result {
131            None => Ok(None),
132            Some(doc) => {
133                let state_str = doc
134                    .get_str("state")
135                    .map_err(|e| SynapticError::Store(format!("MongoDB get state field: {e}")))?;
136                let cp: Checkpoint = serde_json::from_str(state_str)
137                    .map_err(|e| SynapticError::Store(format!("Deserialize: {e}")))?;
138                Ok(Some(cp))
139            }
140        }
141    }
142
143    async fn list(&self, config: &CheckpointConfig) -> Result<Vec<Checkpoint>, SynapticError> {
144        let filter = doc! { "thread_id": &config.thread_id };
145        let opts = mongodb::options::FindOptions::builder()
146            .sort(doc! { "seq": 1 })
147            .build();
148
149        let mut cursor = self
150            .collection
151            .find(filter)
152            .with_options(opts)
153            .await
154            .map_err(|e| SynapticError::Store(format!("MongoDB find: {e}")))?;
155
156        let mut checkpoints = Vec::new();
157        while let Some(doc) = cursor
158            .try_next()
159            .await
160            .map_err(|e| SynapticError::Store(format!("MongoDB cursor: {e}")))?
161        {
162            let state_str = doc
163                .get_str("state")
164                .map_err(|e| SynapticError::Store(format!("MongoDB get state field: {e}")))?;
165            let cp: Checkpoint = serde_json::from_str(state_str)
166                .map_err(|e| SynapticError::Store(format!("Deserialize: {e}")))?;
167            checkpoints.push(cp);
168        }
169
170        Ok(checkpoints)
171    }
172}