Skip to main content

volt_client_grpc/
sync_provider.rs

1//! Sync Provider for collaborative document editing with Y.js/yrs CRDT
2//!
3//! This module provides a SyncProvider that manages bidirectional synchronization
4//! of Y.js/yrs documents with a Volt server. It handles:
5//! - Initial document synchronization
6//! - Real-time updates from other clients
7//! - Chunked payload handling for large updates
8//! - State machine management (SYNCING → DONE → UPDATE → AWARENESS)
9//! - Subdocument synchronization events
10
11use std::sync::Arc;
12use tokio::sync::{mpsc, Mutex, RwLock};
13use tokio_stream::StreamExt;
14use yrs::updates::decoder::Decode;
15use yrs::updates::encoder::Encode;
16use yrs::{Doc, ReadTxn, StateVector, Subscription, Transact, Update, Uuid};
17
18use crate::error::{Result, VoltError};
19use crate::proto::volt::{self, SyncState};
20use crate::VoltClient;
21
/// Upper bound, in bytes, for a single update chunk sent to the server (1 MiB).
const MAX_CHUNK_SIZE: usize = 1024 * 1024;

/// Transaction origin attached to every update applied on behalf of the server.
///
/// The local update observer compares a transaction's origin against this
/// marker so that server-originated changes are not echoed back to the server.
const SERVER_ORIGIN: &str = "volt-sync-provider";
27
28/// Events emitted by the SyncProvider
29#[derive(Debug, Clone)]
30pub enum SyncEvent {
31    /// Synchronization started
32    Syncing,
33    /// Initial synchronization complete
34    Synced,
35    /// Update received and applied
36    Updated,
37    /// Connection lost
38    Disconnected,
39    /// Error occurred
40    Error(String),
41    /// A subdocument was loaded and needs its own sync provider
42    SubdocLoaded { guid: Uuid, doc: Doc },
43    /// A subdocument was removed
44    SubdocRemoved { guid: Uuid },
45}
46
/// Connection lifecycle of a `SyncProvider`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderState {
    /// No active session with the server.
    Disconnected,
    /// The initial state exchange with the server is in progress.
    Syncing,
    /// The initial exchange finished; local changes are forwarded to the server.
    Synced,
}
57
58/// SyncProvider manages document synchronization with Volt server
59pub struct SyncProvider {
60    /// The Y.js document (Doc is internally reference-counted)
61    doc: Doc,
62    /// Database ID
63    database_id: String,
64    /// Document ID
65    document_id: String,
66    /// Current provider state
67    state: Arc<RwLock<ProviderState>>,
68    /// Event sender
69    event_tx: mpsc::Sender<SyncEvent>,
70    /// Request sender to Volt server
71    request_tx: Option<mpsc::Sender<volt::SyncDocumentRequest>>,
72    /// Buffer for chunked payloads
73    chunk_buffer: Arc<Mutex<Vec<u8>>>,
74    /// Read-only mode
75    is_read_only: Arc<RwLock<bool>>,
76    /// Document update observer subscription (keeps observer alive)
77    #[allow(dead_code)]
78    update_subscription: Option<Subscription>,
79    /// Subdocs observer subscription (keeps observer alive)
80    #[allow(dead_code)]
81    subdocs_subscription: Option<Subscription>,
82    /// Channel for local updates that need to be sent to server
83    local_update_tx: Option<mpsc::Sender<Vec<u8>>>,
84}
85
86impl SyncProvider {
87    /// Create a new SyncProvider
88    ///
89    /// # Arguments
90    /// * `doc` - The Y.js document to sync
91    /// * `database_id` - ID of the database containing the document
92    /// * `document_id` - ID of the document to sync
93    ///
94    /// # Returns
95    /// A tuple of (SyncProvider, event_receiver)
96    pub fn new(
97        doc: Doc,
98        database_id: impl Into<String>,
99        document_id: impl Into<String>,
100    ) -> (Self, mpsc::Receiver<SyncEvent>) {
101        let (event_tx, event_rx) = mpsc::channel(100);
102
103        let provider = Self {
104            doc,
105            database_id: database_id.into(),
106            document_id: document_id.into(),
107            state: Arc::new(RwLock::new(ProviderState::Disconnected)),
108            event_tx,
109            request_tx: None,
110            chunk_buffer: Arc::new(Mutex::new(Vec::new())),
111            is_read_only: Arc::new(RwLock::new(false)),
112            update_subscription: None,
113            subdocs_subscription: None,
114            local_update_tx: None,
115        };
116
117        (provider, event_rx)
118    }
119
120    /// Start synchronization with the Volt server
121    ///
122    /// # Arguments
123    /// * `client` - The VoltClient to use for synchronization
124    /// * `read_only` - If true, only receive updates (no sending)
125    /// * `read_only_fallback` - If true, downgrade to read-only if write access denied
126    pub async fn start(
127        &mut self,
128        client: &VoltClient,
129        read_only: bool,
130        read_only_fallback: bool,
131    ) -> Result<()> {
132        // Set initial state
133        *self.state.write().await = ProviderState::Syncing;
134        self.event_tx
135            .send(SyncEvent::Syncing)
136            .await
137            .map_err(|_| VoltError::ConnectionError("Event channel closed".into()))?;
138
139        // Get the current state vector from our document (v1 encoding like JavaScript)
140        let state_vector = {
141            let txn = self.doc.transact();
142            txn.state_vector().encode_v1()
143        };
144
145        tracing::debug!("Local state vector: {} bytes", state_vector.len());
146
147        // Create the initial sync request
148        let sync_start = volt::SyncDocumentRequest {
149            payload: Some(volt::sync_document_request::Payload::SyncStart(
150                volt::SyncDocumentStart {
151                    database_id: self.database_id.clone(),
152                    document_id: self.document_id.clone(),
153                    state_vector,
154                    read_only,
155                    read_only_fallback,
156                },
157            )),
158            sync_state: SyncState::Syncing as i32,
159        };
160
161        tracing::debug!(
162            "Sending sync start request to database: {}, document: {}",
163            self.database_id,
164            self.document_id
165        );
166
167        // Create bidirectional stream with initial request
168        let (tx, mut rx) = client.sync_document(sync_start).await?;
169        self.request_tx = Some(tx.clone());
170
171        tracing::debug!("Sync stream established, waiting for responses...");
172
173        // Set up local update observer to send changes to server
174        // Create a channel for local updates (observer is sync, sending is async)
175        let (local_update_tx, mut local_update_rx) = mpsc::channel::<Vec<u8>>(100);
176        self.local_update_tx = Some(local_update_tx.clone());
177
178        // Set up the document observer
179        let update_state = self.state.clone();
180        let subscription = self
181            .doc
182            .observe_update_v2(move |txn, event| {
183                // Check if this update came from the server (has our origin marker)
184                if let Some(origin) = txn.origin() {
185                    if origin.as_ref() == SERVER_ORIGIN.as_bytes() {
186                        tracing::debug!("Skipping server-originated update");
187                        return;
188                    }
189                }
190
191                // Only send if we're synced (not during initial sync or disconnected)
192                // Note: We can't await here, so we use try_send
193                let state_guard = update_state.try_read();
194                if let Ok(state) = state_guard {
195                    if *state == ProviderState::Synced {
196                        let update = event.update.clone();
197                        tracing::debug!(
198                            "Local document updated, queueing {} bytes to send",
199                            update.len()
200                        );
201                        if let Err(e) = local_update_tx.try_send(update) {
202                            tracing::error!("Failed to queue local update: {}", e);
203                        }
204                    } else {
205                        tracing::debug!(
206                            "Ignoring local update, not synced yet (state: {:?})",
207                            *state
208                        );
209                    }
210                }
211            })
212            .map_err(|e| {
213                VoltError::ConnectionError(format!("Failed to set up document observer: {:?}", e))
214            })?;
215
216        self.update_subscription = Some(subscription);
217
218        // Set up subdocs observer to emit events when subdocs are loaded/removed
219        // The user is responsible for creating SyncProviders for subdocs
220        let subdoc_event_tx = self.event_tx.clone();
221        let subdocs_sub = self
222            .doc
223            .observe_subdocs(move |_txn, event| {
224                // Emit events for loaded subdocs
225                for subdoc in event.loaded() {
226                    let guid = subdoc.guid();
227                    let subdoc_clone = subdoc.clone();
228                    tracing::debug!("Subdoc loaded: {}", guid);
229                    if let Err(e) = subdoc_event_tx.try_send(SyncEvent::SubdocLoaded {
230                        guid,
231                        doc: subdoc_clone,
232                    }) {
233                        tracing::error!("Failed to send subdoc loaded event: {}", e);
234                    }
235                }
236
237                // Emit events for removed subdocs
238                for subdoc in event.removed() {
239                    let guid = subdoc.guid();
240                    tracing::debug!("Subdoc removed: {}", guid);
241                    if let Err(e) = subdoc_event_tx.try_send(SyncEvent::SubdocRemoved { guid }) {
242                        tracing::error!("Failed to send subdoc removed event: {}", e);
243                    }
244                }
245            })
246            .map_err(|e| {
247                VoltError::ConnectionError(format!("Failed to set up subdocs observer: {:?}", e))
248            })?;
249
250        self.subdocs_subscription = Some(subdocs_sub);
251
252        // Spawn task to send local updates to server
253        let update_request_tx = tx.clone();
254        let update_is_read_only = self.is_read_only.clone();
255        tokio::spawn(async move {
256            while let Some(update) = local_update_rx.recv().await {
257                // Check if read-only
258                if *update_is_read_only.read().await {
259                    tracing::debug!("Ignoring local update, document is read-only");
260                    continue;
261                }
262
263                // Send the update
264                tracing::debug!("Sending local update to server: {} bytes", update.len());
265                if let Err(e) = Self::send_update_chunks(&update_request_tx, &update).await {
266                    tracing::error!("Failed to send local update: {}", e);
267                }
268            }
269            tracing::debug!("Local update sender task finished");
270        });
271
272        // Spawn task to handle responses
273        let doc = self.doc.clone();
274        let event_tx = self.event_tx.clone();
275        let state = self.state.clone();
276        let chunk_buffer = self.chunk_buffer.clone();
277        let is_read_only = self.is_read_only.clone();
278        let request_tx = tx.clone();
279
280        tokio::spawn(async move {
281            tracing::debug!("Response handler task started");
282            while let Some(result) = rx.next().await {
283                tracing::debug!("Received response from server");
284                match result {
285                    Ok(response) => {
286                        tracing::debug!("Response sync_state: {}", response.sync_state);
287                        if let Err(e) = Self::handle_response(
288                            &doc,
289                            response,
290                            &event_tx,
291                            &state,
292                            &chunk_buffer,
293                            &is_read_only,
294                            &request_tx,
295                        )
296                        .await
297                        {
298                            let _ = event_tx.send(SyncEvent::Error(e.to_string())).await;
299                        }
300                    }
301                    Err(e) => {
302                        let _ = event_tx
303                            .send(SyncEvent::Error(format!("Stream error: {}", e)))
304                            .await;
305                        *state.write().await = ProviderState::Disconnected;
306                        let _ = event_tx.send(SyncEvent::Disconnected).await;
307                        break;
308                    }
309                }
310            }
311        });
312
313        Ok(())
314    }
315
316    /// Handle a response from the server
317    async fn handle_response(
318        doc: &Doc,
319        response: volt::SyncDocumentResponse,
320        event_tx: &mpsc::Sender<SyncEvent>,
321        state: &Arc<RwLock<ProviderState>>,
322        chunk_buffer: &Arc<Mutex<Vec<u8>>>,
323        is_read_only: &Arc<RwLock<bool>>,
324        request_tx: &mpsc::Sender<volt::SyncDocumentRequest>,
325    ) -> Result<()> {
326        // Check for errors
327        if let Some(status) = response.status {
328            if status.code != 0 {
329                return Err(VoltError::ServerError(format!(
330                    "Server error: {}",
331                    status.message
332                )));
333            }
334        }
335
336        // Handle read-only mode (only in initial response)
337        if response.is_read_only {
338            *is_read_only.write().await = true;
339        }
340
341        // Get sync state from response
342        let sync_state = SyncState::try_from(response.sync_state).unwrap_or(SyncState::Syncing);
343
344        match sync_state {
345            SyncState::Syncing => {
346                // Server response to our initial sync request
347                // Contains state_vector and any updates since last sync
348                tracing::debug!("Received SYNC_STATE_SYNCING response");
349
350                // Buffer the update chunk
351                if let Some(update) = &response.update {
352                    let mut buffer = chunk_buffer.lock().await;
353                    buffer.extend_from_slice(&update.chunk);
354
355                    if update.complete {
356                        // Apply the server's updates to our document (V2 encoding)
357                        // Use SERVER_ORIGIN so the observer knows to skip these updates
358                        if !buffer.is_empty() {
359                            tracing::debug!(
360                                "Applying {} bytes of updates from server",
361                                buffer.len()
362                            );
363                            match Update::decode_v2(&buffer) {
364                                Ok(update) => {
365                                    let mut txn = doc.transact_mut_with(SERVER_ORIGIN);
366                                    if let Err(e) = txn.apply_update(update) {
367                                        tracing::error!("Failed to apply update: {:?}", e);
368                                    }
369                                }
370                                Err(e) => {
371                                    tracing::error!("Failed to decode update: {:?}", e);
372                                }
373                            }
374                        }
375                        buffer.clear();
376
377                        // Now send our local updates based on server's state vector
378                        // The server's state_vector tells us what it already has
379                        let server_state_vector = if !response.state_vector.is_empty() {
380                            match StateVector::decode_v1(&response.state_vector) {
381                                Ok(sv) => sv,
382                                Err(e) => {
383                                    tracing::error!(
384                                        "Failed to decode server state vector: {:?}",
385                                        e
386                                    );
387                                    StateVector::default()
388                                }
389                            }
390                        } else {
391                            StateVector::default()
392                        };
393
394                        // Encode our state as an update relative to the server's state vector (V2)
395                        let local_update = {
396                            let txn = doc.transact();
397                            txn.encode_state_as_update_v2(&server_state_vector)
398                        };
399
400                        tracing::debug!(
401                            "Sending {} bytes of local updates to server",
402                            local_update.len()
403                        );
404
405                        // Send DONE message with our local updates
406                        let done_request = volt::SyncDocumentRequest {
407                            payload: Some(volt::sync_document_request::Payload::Update(
408                                volt::SyncDocumentUpdate {
409                                    chunk: local_update,
410                                    complete: true,
411                                },
412                            )),
413                            sync_state: SyncState::Done as i32,
414                        };
415
416                        request_tx.send(done_request).await.map_err(|_| {
417                            VoltError::ConnectionError("Failed to send DONE message".into())
418                        })?;
419                    }
420                } else {
421                    // No update in response, still need to send DONE with our state
422                    let server_state_vector = if !response.state_vector.is_empty() {
423                        match StateVector::decode_v1(&response.state_vector) {
424                            Ok(sv) => sv,
425                            Err(e) => {
426                                tracing::error!("Failed to decode server state vector: {:?}", e);
427                                StateVector::default()
428                            }
429                        }
430                    } else {
431                        StateVector::default()
432                    };
433
434                    let local_update = {
435                        let txn = doc.transact();
436                        txn.encode_state_as_update_v2(&server_state_vector)
437                    };
438
439                    tracing::debug!(
440                        "Sending {} bytes of local updates (no server update)",
441                        local_update.len()
442                    );
443
444                    let done_request = volt::SyncDocumentRequest {
445                        payload: Some(volt::sync_document_request::Payload::Update(
446                            volt::SyncDocumentUpdate {
447                                chunk: local_update,
448                                complete: true,
449                            },
450                        )),
451                        sync_state: SyncState::Done as i32,
452                    };
453
454                    request_tx.send(done_request).await.map_err(|_| {
455                        VoltError::ConnectionError("Failed to send DONE message".into())
456                    })?;
457                }
458            }
459            SyncState::Done => {
460                // Server acknowledges sync complete - we're now synced!
461                tracing::debug!("Received SYNC_STATE_DONE response");
462                *state.write().await = ProviderState::Synced;
463                let _ = event_tx.send(SyncEvent::Synced).await;
464            }
465            SyncState::Update => {
466                // Server sent an update - apply it to our document
467                // Use SERVER_ORIGIN so the observer knows to skip these updates
468                if let Some(update) = &response.update {
469                    let mut buffer = chunk_buffer.lock().await;
470                    buffer.extend_from_slice(&update.chunk);
471
472                    if update.complete {
473                        if !buffer.is_empty() {
474                            tracing::debug!(
475                                "Applying {} bytes of live update from server",
476                                buffer.len()
477                            );
478                            match Update::decode_v2(&buffer) {
479                                Ok(update) => {
480                                    let mut txn = doc.transact_mut_with(SERVER_ORIGIN);
481                                    if let Err(e) = txn.apply_update(update) {
482                                        tracing::error!("Failed to apply live update: {:?}", e);
483                                    }
484                                }
485                                Err(e) => {
486                                    tracing::error!("Failed to decode live update: {:?}", e);
487                                }
488                            }
489                        }
490                        let _ = event_tx.send(SyncEvent::Updated).await;
491                        buffer.clear();
492                    }
493                }
494            }
495            SyncState::Awareness => {
496                // Awareness update - ignore for now
497                tracing::debug!("Received SYNC_STATE_AWARENESS response");
498            }
499            _ => {
500                tracing::warn!("Unknown sync state: {:?}", sync_state);
501            }
502        }
503
504        Ok(())
505    }
506
507    /// Send an update to the server
508    ///
509    /// This should be called when the local document changes
510    pub async fn send_update(&self, update: &[u8]) -> Result<()> {
511        let tx = self
512            .request_tx
513            .as_ref()
514            .ok_or_else(|| VoltError::ConnectionError("Not connected".into()))?;
515
516        // Check if we're read-only
517        if *self.is_read_only.read().await {
518            return Err(VoltError::PermissionDenied("Document is read-only".into()));
519        }
520
521        Self::send_update_chunks(tx, update).await
522    }
523
524    /// Send update chunks to the server (static helper for spawned tasks)
525    async fn send_update_chunks(
526        tx: &mpsc::Sender<volt::SyncDocumentRequest>,
527        update: &[u8],
528    ) -> Result<()> {
529        // Split into chunks if needed
530        let chunks = Self::chunk_update_static(update);
531
532        for (i, chunk) in chunks.iter().enumerate() {
533            let is_last = i == chunks.len() - 1;
534
535            let request = volt::SyncDocumentRequest {
536                payload: Some(volt::sync_document_request::Payload::Update(
537                    volt::SyncDocumentUpdate {
538                        chunk: chunk.clone(),
539                        complete: is_last,
540                    },
541                )),
542                sync_state: SyncState::Update as i32,
543            };
544
545            tx.send(request)
546                .await
547                .map_err(|_| VoltError::ConnectionError("Failed to send update".into()))?;
548        }
549
550        Ok(())
551    }
552
553    /// Split an update into chunks (static version)
554    fn chunk_update_static(update: &[u8]) -> Vec<Vec<u8>> {
555        if update.len() <= MAX_CHUNK_SIZE {
556            return vec![update.to_vec()];
557        }
558
559        update
560            .chunks(MAX_CHUNK_SIZE)
561            .map(|chunk| chunk.to_vec())
562            .collect()
563    }
564
565    /// Get the current provider state
566    pub async fn state(&self) -> ProviderState {
567        *self.state.read().await
568    }
569
570    /// Check if the provider is read-only
571    pub async fn is_read_only(&self) -> bool {
572        *self.is_read_only.read().await
573    }
574
575    /// Stop synchronization
576    pub async fn stop(&mut self) {
577        *self.state.write().await = ProviderState::Disconnected;
578        self.request_tx = None;
579        self.local_update_tx = None;
580        self.update_subscription = None;
581        let _ = self.event_tx.send(SyncEvent::Disconnected).await;
582    }
583}