1use std::sync::Arc;
12use tokio::sync::{mpsc, Mutex, RwLock};
13use tokio_stream::StreamExt;
14use yrs::updates::decoder::Decode;
15use yrs::updates::encoder::Encode;
16use yrs::{Doc, ReadTxn, StateVector, Subscription, Transact, Update, Uuid};
17
18use crate::error::{Result, VoltError};
19use crate::proto::volt::{self, SyncState};
20use crate::VoltClient;
21
/// Maximum size, in bytes, of each chunk when an outbound update is split across messages
/// (1 MiB).
const MAX_CHUNK_SIZE: usize = 1 << 20;

/// Transaction origin used when applying updates received from the server, so the local
/// update observer can recognise them and avoid echoing them back.
const SERVER_ORIGIN: &str = "volt-sync-provider";
27
28#[derive(Debug, Clone)]
30pub enum SyncEvent {
31 Syncing,
33 Synced,
35 Updated,
37 Disconnected,
39 Error(String),
41 SubdocLoaded { guid: Uuid, doc: Doc },
43 SubdocRemoved { guid: Uuid },
45}
46
/// Connection lifecycle of a `SyncProvider`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderState {
    /// No session is active: the initial state, and the state after a stop or stream error.
    Disconnected,
    /// A session was started and the handshake with the server has not completed yet.
    Syncing,
    /// The handshake completed; local document updates are forwarded to the server.
    Synced,
}
57
58pub struct SyncProvider {
60 doc: Doc,
62 database_id: String,
64 document_id: String,
66 state: Arc<RwLock<ProviderState>>,
68 event_tx: mpsc::Sender<SyncEvent>,
70 request_tx: Option<mpsc::Sender<volt::SyncDocumentRequest>>,
72 chunk_buffer: Arc<Mutex<Vec<u8>>>,
74 is_read_only: Arc<RwLock<bool>>,
76 #[allow(dead_code)]
78 update_subscription: Option<Subscription>,
79 #[allow(dead_code)]
81 subdocs_subscription: Option<Subscription>,
82 local_update_tx: Option<mpsc::Sender<Vec<u8>>>,
84}
85
86impl SyncProvider {
87 pub fn new(
97 doc: Doc,
98 database_id: impl Into<String>,
99 document_id: impl Into<String>,
100 ) -> (Self, mpsc::Receiver<SyncEvent>) {
101 let (event_tx, event_rx) = mpsc::channel(100);
102
103 let provider = Self {
104 doc,
105 database_id: database_id.into(),
106 document_id: document_id.into(),
107 state: Arc::new(RwLock::new(ProviderState::Disconnected)),
108 event_tx,
109 request_tx: None,
110 chunk_buffer: Arc::new(Mutex::new(Vec::new())),
111 is_read_only: Arc::new(RwLock::new(false)),
112 update_subscription: None,
113 subdocs_subscription: None,
114 local_update_tx: None,
115 };
116
117 (provider, event_rx)
118 }
119
120 pub async fn start(
127 &mut self,
128 client: &VoltClient,
129 read_only: bool,
130 read_only_fallback: bool,
131 ) -> Result<()> {
132 *self.state.write().await = ProviderState::Syncing;
134 self.event_tx
135 .send(SyncEvent::Syncing)
136 .await
137 .map_err(|_| VoltError::ConnectionError("Event channel closed".into()))?;
138
139 let state_vector = {
141 let txn = self.doc.transact();
142 txn.state_vector().encode_v1()
143 };
144
145 tracing::debug!("Local state vector: {} bytes", state_vector.len());
146
147 let sync_start = volt::SyncDocumentRequest {
149 payload: Some(volt::sync_document_request::Payload::SyncStart(
150 volt::SyncDocumentStart {
151 database_id: self.database_id.clone(),
152 document_id: self.document_id.clone(),
153 state_vector,
154 read_only,
155 read_only_fallback,
156 },
157 )),
158 sync_state: SyncState::Syncing as i32,
159 };
160
161 tracing::debug!(
162 "Sending sync start request to database: {}, document: {}",
163 self.database_id,
164 self.document_id
165 );
166
167 let (tx, mut rx) = client.sync_document(sync_start).await?;
169 self.request_tx = Some(tx.clone());
170
171 tracing::debug!("Sync stream established, waiting for responses...");
172
173 let (local_update_tx, mut local_update_rx) = mpsc::channel::<Vec<u8>>(100);
176 self.local_update_tx = Some(local_update_tx.clone());
177
178 let update_state = self.state.clone();
180 let subscription = self
181 .doc
182 .observe_update_v2(move |txn, event| {
183 if let Some(origin) = txn.origin() {
185 if origin.as_ref() == SERVER_ORIGIN.as_bytes() {
186 tracing::debug!("Skipping server-originated update");
187 return;
188 }
189 }
190
191 let state_guard = update_state.try_read();
194 if let Ok(state) = state_guard {
195 if *state == ProviderState::Synced {
196 let update = event.update.clone();
197 tracing::debug!(
198 "Local document updated, queueing {} bytes to send",
199 update.len()
200 );
201 if let Err(e) = local_update_tx.try_send(update) {
202 tracing::error!("Failed to queue local update: {}", e);
203 }
204 } else {
205 tracing::debug!(
206 "Ignoring local update, not synced yet (state: {:?})",
207 *state
208 );
209 }
210 }
211 })
212 .map_err(|e| {
213 VoltError::ConnectionError(format!("Failed to set up document observer: {:?}", e))
214 })?;
215
216 self.update_subscription = Some(subscription);
217
218 let subdoc_event_tx = self.event_tx.clone();
221 let subdocs_sub = self
222 .doc
223 .observe_subdocs(move |_txn, event| {
224 for subdoc in event.loaded() {
226 let guid = subdoc.guid();
227 let subdoc_clone = subdoc.clone();
228 tracing::debug!("Subdoc loaded: {}", guid);
229 if let Err(e) = subdoc_event_tx.try_send(SyncEvent::SubdocLoaded {
230 guid,
231 doc: subdoc_clone,
232 }) {
233 tracing::error!("Failed to send subdoc loaded event: {}", e);
234 }
235 }
236
237 for subdoc in event.removed() {
239 let guid = subdoc.guid();
240 tracing::debug!("Subdoc removed: {}", guid);
241 if let Err(e) = subdoc_event_tx.try_send(SyncEvent::SubdocRemoved { guid }) {
242 tracing::error!("Failed to send subdoc removed event: {}", e);
243 }
244 }
245 })
246 .map_err(|e| {
247 VoltError::ConnectionError(format!("Failed to set up subdocs observer: {:?}", e))
248 })?;
249
250 self.subdocs_subscription = Some(subdocs_sub);
251
252 let update_request_tx = tx.clone();
254 let update_is_read_only = self.is_read_only.clone();
255 tokio::spawn(async move {
256 while let Some(update) = local_update_rx.recv().await {
257 if *update_is_read_only.read().await {
259 tracing::debug!("Ignoring local update, document is read-only");
260 continue;
261 }
262
263 tracing::debug!("Sending local update to server: {} bytes", update.len());
265 if let Err(e) = Self::send_update_chunks(&update_request_tx, &update).await {
266 tracing::error!("Failed to send local update: {}", e);
267 }
268 }
269 tracing::debug!("Local update sender task finished");
270 });
271
272 let doc = self.doc.clone();
274 let event_tx = self.event_tx.clone();
275 let state = self.state.clone();
276 let chunk_buffer = self.chunk_buffer.clone();
277 let is_read_only = self.is_read_only.clone();
278 let request_tx = tx.clone();
279
280 tokio::spawn(async move {
281 tracing::debug!("Response handler task started");
282 while let Some(result) = rx.next().await {
283 tracing::debug!("Received response from server");
284 match result {
285 Ok(response) => {
286 tracing::debug!("Response sync_state: {}", response.sync_state);
287 if let Err(e) = Self::handle_response(
288 &doc,
289 response,
290 &event_tx,
291 &state,
292 &chunk_buffer,
293 &is_read_only,
294 &request_tx,
295 )
296 .await
297 {
298 let _ = event_tx.send(SyncEvent::Error(e.to_string())).await;
299 }
300 }
301 Err(e) => {
302 let _ = event_tx
303 .send(SyncEvent::Error(format!("Stream error: {}", e)))
304 .await;
305 *state.write().await = ProviderState::Disconnected;
306 let _ = event_tx.send(SyncEvent::Disconnected).await;
307 break;
308 }
309 }
310 }
311 });
312
313 Ok(())
314 }
315
316 async fn handle_response(
318 doc: &Doc,
319 response: volt::SyncDocumentResponse,
320 event_tx: &mpsc::Sender<SyncEvent>,
321 state: &Arc<RwLock<ProviderState>>,
322 chunk_buffer: &Arc<Mutex<Vec<u8>>>,
323 is_read_only: &Arc<RwLock<bool>>,
324 request_tx: &mpsc::Sender<volt::SyncDocumentRequest>,
325 ) -> Result<()> {
326 if let Some(status) = response.status {
328 if status.code != 0 {
329 return Err(VoltError::ServerError(format!(
330 "Server error: {}",
331 status.message
332 )));
333 }
334 }
335
336 if response.is_read_only {
338 *is_read_only.write().await = true;
339 }
340
341 let sync_state = SyncState::try_from(response.sync_state).unwrap_or(SyncState::Syncing);
343
344 match sync_state {
345 SyncState::Syncing => {
346 tracing::debug!("Received SYNC_STATE_SYNCING response");
349
350 if let Some(update) = &response.update {
352 let mut buffer = chunk_buffer.lock().await;
353 buffer.extend_from_slice(&update.chunk);
354
355 if update.complete {
356 if !buffer.is_empty() {
359 tracing::debug!(
360 "Applying {} bytes of updates from server",
361 buffer.len()
362 );
363 match Update::decode_v2(&buffer) {
364 Ok(update) => {
365 let mut txn = doc.transact_mut_with(SERVER_ORIGIN);
366 if let Err(e) = txn.apply_update(update) {
367 tracing::error!("Failed to apply update: {:?}", e);
368 }
369 }
370 Err(e) => {
371 tracing::error!("Failed to decode update: {:?}", e);
372 }
373 }
374 }
375 buffer.clear();
376
377 let server_state_vector = if !response.state_vector.is_empty() {
380 match StateVector::decode_v1(&response.state_vector) {
381 Ok(sv) => sv,
382 Err(e) => {
383 tracing::error!(
384 "Failed to decode server state vector: {:?}",
385 e
386 );
387 StateVector::default()
388 }
389 }
390 } else {
391 StateVector::default()
392 };
393
394 let local_update = {
396 let txn = doc.transact();
397 txn.encode_state_as_update_v2(&server_state_vector)
398 };
399
400 tracing::debug!(
401 "Sending {} bytes of local updates to server",
402 local_update.len()
403 );
404
405 let done_request = volt::SyncDocumentRequest {
407 payload: Some(volt::sync_document_request::Payload::Update(
408 volt::SyncDocumentUpdate {
409 chunk: local_update,
410 complete: true,
411 },
412 )),
413 sync_state: SyncState::Done as i32,
414 };
415
416 request_tx.send(done_request).await.map_err(|_| {
417 VoltError::ConnectionError("Failed to send DONE message".into())
418 })?;
419 }
420 } else {
421 let server_state_vector = if !response.state_vector.is_empty() {
423 match StateVector::decode_v1(&response.state_vector) {
424 Ok(sv) => sv,
425 Err(e) => {
426 tracing::error!("Failed to decode server state vector: {:?}", e);
427 StateVector::default()
428 }
429 }
430 } else {
431 StateVector::default()
432 };
433
434 let local_update = {
435 let txn = doc.transact();
436 txn.encode_state_as_update_v2(&server_state_vector)
437 };
438
439 tracing::debug!(
440 "Sending {} bytes of local updates (no server update)",
441 local_update.len()
442 );
443
444 let done_request = volt::SyncDocumentRequest {
445 payload: Some(volt::sync_document_request::Payload::Update(
446 volt::SyncDocumentUpdate {
447 chunk: local_update,
448 complete: true,
449 },
450 )),
451 sync_state: SyncState::Done as i32,
452 };
453
454 request_tx.send(done_request).await.map_err(|_| {
455 VoltError::ConnectionError("Failed to send DONE message".into())
456 })?;
457 }
458 }
459 SyncState::Done => {
460 tracing::debug!("Received SYNC_STATE_DONE response");
462 *state.write().await = ProviderState::Synced;
463 let _ = event_tx.send(SyncEvent::Synced).await;
464 }
465 SyncState::Update => {
466 if let Some(update) = &response.update {
469 let mut buffer = chunk_buffer.lock().await;
470 buffer.extend_from_slice(&update.chunk);
471
472 if update.complete {
473 if !buffer.is_empty() {
474 tracing::debug!(
475 "Applying {} bytes of live update from server",
476 buffer.len()
477 );
478 match Update::decode_v2(&buffer) {
479 Ok(update) => {
480 let mut txn = doc.transact_mut_with(SERVER_ORIGIN);
481 if let Err(e) = txn.apply_update(update) {
482 tracing::error!("Failed to apply live update: {:?}", e);
483 }
484 }
485 Err(e) => {
486 tracing::error!("Failed to decode live update: {:?}", e);
487 }
488 }
489 }
490 let _ = event_tx.send(SyncEvent::Updated).await;
491 buffer.clear();
492 }
493 }
494 }
495 SyncState::Awareness => {
496 tracing::debug!("Received SYNC_STATE_AWARENESS response");
498 }
499 _ => {
500 tracing::warn!("Unknown sync state: {:?}", sync_state);
501 }
502 }
503
504 Ok(())
505 }
506
507 pub async fn send_update(&self, update: &[u8]) -> Result<()> {
511 let tx = self
512 .request_tx
513 .as_ref()
514 .ok_or_else(|| VoltError::ConnectionError("Not connected".into()))?;
515
516 if *self.is_read_only.read().await {
518 return Err(VoltError::PermissionDenied("Document is read-only".into()));
519 }
520
521 Self::send_update_chunks(tx, update).await
522 }
523
524 async fn send_update_chunks(
526 tx: &mpsc::Sender<volt::SyncDocumentRequest>,
527 update: &[u8],
528 ) -> Result<()> {
529 let chunks = Self::chunk_update_static(update);
531
532 for (i, chunk) in chunks.iter().enumerate() {
533 let is_last = i == chunks.len() - 1;
534
535 let request = volt::SyncDocumentRequest {
536 payload: Some(volt::sync_document_request::Payload::Update(
537 volt::SyncDocumentUpdate {
538 chunk: chunk.clone(),
539 complete: is_last,
540 },
541 )),
542 sync_state: SyncState::Update as i32,
543 };
544
545 tx.send(request)
546 .await
547 .map_err(|_| VoltError::ConnectionError("Failed to send update".into()))?;
548 }
549
550 Ok(())
551 }
552
553 fn chunk_update_static(update: &[u8]) -> Vec<Vec<u8>> {
555 if update.len() <= MAX_CHUNK_SIZE {
556 return vec![update.to_vec()];
557 }
558
559 update
560 .chunks(MAX_CHUNK_SIZE)
561 .map(|chunk| chunk.to_vec())
562 .collect()
563 }
564
565 pub async fn state(&self) -> ProviderState {
567 *self.state.read().await
568 }
569
570 pub async fn is_read_only(&self) -> bool {
572 *self.is_read_only.read().await
573 }
574
575 pub async fn stop(&mut self) {
577 *self.state.write().await = ProviderState::Disconnected;
578 self.request_tx = None;
579 self.local_update_tx = None;
580 self.update_subscription = None;
581 let _ = self.event_tx.send(SyncEvent::Disconnected).await;
582 }
583}