use async_trait::async_trait;
use sqlx::{AnyPool, Row, any::AnyPoolOptions};
use crate::sync_item::{SyncItem, ContentType};
use super::traits::{ArchiveStore, BatchWriteResult, StorageError};
use crate::resilience::retry::{retry, RetryConfig};
use std::sync::Once;
use std::time::Duration;

static INSTALL_DRIVERS: Once = Once::new();

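/// sqlx's `Any` driver requires `install_default_drivers` to be called before
/// the first connection; the `Once` guard makes repeat calls harmless.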
fn install_drivers() {
    INSTALL_DRIVERS.call_once(|| {
        sqlx::any::install_default_drivers();
    });
}

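/// An [`ArchiveStore`] backed by sqlx's `Any` driver, targeting SQLite or a
/// MySQL-compatible server depending on the connection string.
///
/// A minimal usage sketch (the path and item are illustrative):
///
/// ```ignore
/// let store = SqlStore::new("sqlite://archive.db?mode=rwc").await?;
/// let item = SyncItem::from_json("obj-1".into(), serde_json::json!({"k": "v"}));
/// store.put(&item).await?;
/// assert!(store.exists("obj-1").await?);
/// ```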
pub struct SqlStore {
    pool: AnyPool,
    is_sqlite: bool,
}

impl SqlStore {
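    /// Connect, enable WAL (SQLite only), and create the schema if missing.
    ///
    /// Connection strings starting with `sqlite:` select the SQLite dialect;
    /// everything else is treated as MySQL-compatible.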
    pub async fn new(connection_string: &str) -> Result<Self, StorageError> {
        install_drivers();

        let is_sqlite = connection_string.starts_with("sqlite:");

        let pool = retry("sql_connect", &RetryConfig::startup(), || async {
            AnyPoolOptions::new()
                .max_connections(20)
                .acquire_timeout(Duration::from_secs(10))
                .idle_timeout(Duration::from_secs(300))
                .connect(connection_string)
                .await
                .map_err(|e| StorageError::Backend(e.to_string()))
        })
        .await?;

        let store = Self { pool, is_sqlite };

        if is_sqlite {
            store.enable_wal_mode().await?;
        }

        store.init_schema().await?;
        Ok(store)
    }

    pub fn pool(&self) -> AnyPool {
        self.pool.clone()
    }

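    /// Switch SQLite to WAL journaling for better concurrent read/write behavior.
    ///
    /// Note that `journal_mode = WAL` persists in the database file, while
    /// `synchronous` is per-connection; other pooled connections keep SQLite's
    /// default synchronous level unless they set it themselves.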
    async fn enable_wal_mode(&self) -> Result<(), StorageError> {
        sqlx::query("PRAGMA journal_mode = WAL")
            .execute(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(format!("Failed to enable WAL mode: {}", e)))?;

        sqlx::query("PRAGMA synchronous = NORMAL")
            .execute(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(format!("Failed to set synchronous mode: {}", e)))?;

        Ok(())
    }

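    /// Create the `sync_items` table using dialect-specific DDL.
    ///
    /// SQLite does not accept inline `INDEX` clauses in `CREATE TABLE`, so the
    /// SQLite schema omits the secondary indexes; they would need separate
    /// `CREATE INDEX IF NOT EXISTS` statements.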
    async fn init_schema(&self) -> Result<(), StorageError> {
        let sql = if self.is_sqlite {
            r#"
            CREATE TABLE IF NOT EXISTS sync_items (
                id TEXT PRIMARY KEY,
                version INTEGER NOT NULL DEFAULT 1,
                timestamp INTEGER NOT NULL,
                payload_hash TEXT,
                payload TEXT,
                payload_blob BLOB,
                audit TEXT,
                merkle_dirty INTEGER NOT NULL DEFAULT 1,
                state TEXT NOT NULL DEFAULT 'default'
            )
            "#
        } else {
            r#"
            CREATE TABLE IF NOT EXISTS sync_items (
                id VARCHAR(255) PRIMARY KEY,
                version BIGINT NOT NULL DEFAULT 1,
                timestamp BIGINT NOT NULL,
                payload_hash VARCHAR(64),
                payload LONGTEXT,
                payload_blob MEDIUMBLOB,
                audit TEXT,
                merkle_dirty TINYINT NOT NULL DEFAULT 1,
                state VARCHAR(32) NOT NULL DEFAULT 'default',
                INDEX idx_timestamp (timestamp),
                INDEX idx_merkle_dirty (merkle_dirty),
                INDEX idx_state (state)
            )
            "#
        };

        retry("sql_init_schema", &RetryConfig::startup(), || async {
            sqlx::query(sql)
                .execute(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(e.to_string()))
        })
        .await?;

        Ok(())
    }

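    /// Pack the optional audit fields (batch id, trace parent, home instance)
    /// into a compact JSON object, or `None` if all are absent.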
    fn build_audit_json(item: &SyncItem) -> Option<String> {
        let mut audit = serde_json::Map::new();

        if let Some(ref batch_id) = item.batch_id {
            audit.insert("batch".to_string(), serde_json::Value::String(batch_id.clone()));
        }
        if let Some(ref trace_parent) = item.trace_parent {
            audit.insert("trace".to_string(), serde_json::Value::String(trace_parent.clone()));
        }
        if let Some(ref home) = item.home_instance_id {
            audit.insert("home".to_string(), serde_json::Value::String(home.clone()));
        }

        if audit.is_empty() {
            None
        } else {
            serde_json::to_string(&serde_json::Value::Object(audit)).ok()
        }
    }

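    /// Inverse of `build_audit_json`: returns `(batch_id, trace_parent,
    /// home_instance_id)`, each `None` if missing or unparseable.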
    fn parse_audit_json(audit_str: Option<String>) -> (Option<String>, Option<String>, Option<String>) {
        match audit_str {
            Some(s) => {
                if let Ok(audit) = serde_json::from_str::<serde_json::Value>(&s) {
                    let batch_id = audit.get("batch").and_then(|v| v.as_str()).map(String::from);
                    let trace_parent = audit.get("trace").and_then(|v| v.as_str()).map(String::from);
                    let home_instance_id = audit.get("home").and_then(|v| v.as_str()).map(String::from);
                    (batch_id, trace_parent, home_instance_id)
                } else {
                    (None, None, None)
                }
            }
            None => (None, None, None),
        }
    }
}

#[async_trait]
impl ArchiveStore for SqlStore {
    async fn get(&self, id: &str) -> Result<Option<SyncItem>, StorageError> {
        let id = id.to_string();

        retry("sql_get", &RetryConfig::query(), || async {
            let result = sqlx::query(
                "SELECT version, timestamp, payload_hash, payload, payload_blob, audit, state FROM sync_items WHERE id = ?"
            )
            .bind(&id)
            .fetch_optional(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

            match result {
                Some(row) => {
                    let version: i64 = row.try_get("version").unwrap_or(1);
                    let timestamp: i64 = row.try_get("timestamp").unwrap_or(0);
                    let payload_hash: Option<String> = row.try_get("payload_hash").ok();

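                    // The Any driver may surface TEXT columns as strings or as
                    // raw bytes depending on the backend; try both.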
                    let payload_json: Option<String> = row.try_get::<String, _>("payload").ok()
                        .or_else(|| {
                            row.try_get::<Vec<u8>, _>("payload").ok()
                                .and_then(|bytes| String::from_utf8(bytes).ok())
                        });

                    let payload_blob: Option<Vec<u8>> = row.try_get("payload_blob").ok();

                    let audit_json: Option<String> = row.try_get::<String, _>("audit").ok()
                        .or_else(|| {
                            row.try_get::<Vec<u8>, _>("audit").ok()
                                .and_then(|bytes| String::from_utf8(bytes).ok())
                        });

                    let state: String = row.try_get::<String, _>("state").ok()
                        .or_else(|| {
                            row.try_get::<Vec<u8>, _>("state").ok()
                                .and_then(|bytes| String::from_utf8(bytes).ok())
                        })
                        .unwrap_or_else(|| "default".to_string());

                    let (content, content_type) = if let Some(ref json_str) = payload_json {
                        let content = json_str.as_bytes().to_vec();
                        (content, ContentType::Json)
                    } else if let Some(blob) = payload_blob {
                        (blob, ContentType::Binary)
                    } else {
                        return Err(StorageError::Backend("No payload in row".to_string()));
                    };

                    let (batch_id, trace_parent, home_instance_id) = Self::parse_audit_json(audit_json);

                    let item = SyncItem::reconstruct(
                        id.clone(),
                        version as u64,
                        timestamp,
                        content_type,
                        content,
                        batch_id,
                        trace_parent,
                        payload_hash.unwrap_or_default(),
                        home_instance_id,
                        state,
                    );
                    Ok(Some(item))
                }
                None => Ok(None),
            }
        })
        .await
    }

    async fn put(&self, item: &SyncItem) -> Result<(), StorageError> {
        let id = item.object_id.clone();
        let version = item.version as i64;
        let timestamp = item.updated_at;
        let payload_hash = if item.merkle_root.is_empty() { None } else { Some(item.merkle_root.clone()) };
        let audit_json = Self::build_audit_json(item);
        let state = item.state.clone();

        let (payload_json, payload_blob): (Option<String>, Option<Vec<u8>>) = match item.content_type {
            ContentType::Json => {
                let json_str = String::from_utf8_lossy(&item.content).to_string();
                (Some(json_str), None)
            }
            ContentType::Binary => {
                (None, Some(item.content.clone()))
            }
        };

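        // Upsert syntax differs by dialect: SQLite uses ON CONFLICT ... DO
        // UPDATE, MySQL uses ON DUPLICATE KEY UPDATE. Both force merkle_dirty
        // back to 1 so the row is picked up for merkle recomputation.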
        let sql = if self.is_sqlite {
            "INSERT INTO sync_items (id, version, timestamp, payload_hash, payload, payload_blob, audit, merkle_dirty, state)
             VALUES (?, ?, ?, ?, ?, ?, ?, 1, ?)
             ON CONFLICT(id) DO UPDATE SET
                 version = excluded.version,
                 timestamp = excluded.timestamp,
                 payload_hash = excluded.payload_hash,
                 payload = excluded.payload,
                 payload_blob = excluded.payload_blob,
                 audit = excluded.audit,
                 merkle_dirty = 1,
                 state = excluded.state"
        } else {
            "INSERT INTO sync_items (id, version, timestamp, payload_hash, payload, payload_blob, audit, merkle_dirty, state)
             VALUES (?, ?, ?, ?, ?, ?, ?, 1, ?)
             ON DUPLICATE KEY UPDATE
                 version = VALUES(version),
                 timestamp = VALUES(timestamp),
                 payload_hash = VALUES(payload_hash),
                 payload = VALUES(payload),
                 payload_blob = VALUES(payload_blob),
                 audit = VALUES(audit),
                 merkle_dirty = 1,
                 state = VALUES(state)"
        };

        retry("sql_put", &RetryConfig::query(), || async {
            sqlx::query(sql)
                .bind(&id)
                .bind(version)
                .bind(timestamp)
                .bind(&payload_hash)
                .bind(&payload_json)
                .bind(&payload_blob)
                .bind(&audit_json)
                .bind(&state)
                .execute(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(e.to_string()))?;
            Ok(())
        })
        .await
    }

    async fn delete(&self, id: &str) -> Result<(), StorageError> {
        let id = id.to_string();
        retry("sql_delete", &RetryConfig::query(), || async {
            sqlx::query("DELETE FROM sync_items WHERE id = ?")
                .bind(&id)
                .execute(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(e.to_string()))?;
            Ok(())
        })
        .await
    }

    async fn exists(&self, id: &str) -> Result<bool, StorageError> {
        let id = id.to_string();
        retry("sql_exists", &RetryConfig::query(), || async {
            let result = sqlx::query("SELECT 1 FROM sync_items WHERE id = ? LIMIT 1")
                .bind(&id)
                .fetch_optional(&self.pool)
                .await
                .map_err(|e| StorageError::Backend(e.to_string()))?;
            Ok(result.is_some())
        })
        .await
    }

    async fn put_batch(&self, items: &mut [SyncItem]) -> Result<BatchWriteResult, StorageError> {
        if items.is_empty() {
            return Ok(BatchWriteResult {
                batch_id: String::new(),
                written: 0,
                verified: true,
            });
        }

        let batch_id = uuid::Uuid::new_v4().to_string();

        for item in items.iter_mut() {
            item.batch_id = Some(batch_id.clone());
        }

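        // 500 rows x 9 binds = 4500 placeholders per statement, comfortably
        // under MySQL's 65535 prepared-statement limit and the 32766 default
        // of modern SQLite builds.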
        const CHUNK_SIZE: usize = 500;
        let mut total_written = 0usize;

        for chunk in items.chunks(CHUNK_SIZE) {
            let written = self.put_batch_chunk(chunk, &batch_id).await?;
            total_written += written;
        }

        let verified_count = self.verify_batch(&batch_id).await?;
        let verified = verified_count == items.len();

        if !verified {
            tracing::warn!(
                batch_id = %batch_id,
                expected = items.len(),
                actual = verified_count,
                "Batch verification mismatch"
            );
        }

        Ok(BatchWriteResult {
            batch_id,
            written: total_written,
            verified,
        })
    }

    async fn scan_keys(&self, offset: u64, limit: usize) -> Result<Vec<String>, StorageError> {
        let rows = sqlx::query("SELECT id FROM sync_items ORDER BY id LIMIT ? OFFSET ?")
            .bind(limit as i64)
            .bind(offset as i64)
            .fetch_all(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        let mut keys = Vec::with_capacity(rows.len());
        for row in rows {
            let id: String = row.try_get("id")
                .map_err(|e| StorageError::Backend(e.to_string()))?;
            keys.push(id);
        }

        Ok(keys)
    }

    async fn count_all(&self) -> Result<u64, StorageError> {
        let result = sqlx::query("SELECT COUNT(*) as cnt FROM sync_items")
            .fetch_one(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        let count: i64 = result.try_get("cnt")
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        Ok(count as u64)
    }
}

impl SqlStore {
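    /// Write one chunk of a batch as a single multi-row upsert.
    ///
    /// Rows are serialized up front into owned `PreparedRow`s so the retry
    /// closure can clone them cheaply on each attempt. The batch id is already
    /// stamped on the items by `put_batch`, hence the unused parameter.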
    async fn put_batch_chunk(&self, chunk: &[SyncItem], _batch_id: &str) -> Result<usize, StorageError> {
        let placeholders: Vec<String> = (0..chunk.len())
            .map(|_| "(?, ?, ?, ?, ?, ?, ?, 1, ?)".to_string())
            .collect();

        let sql = if self.is_sqlite {
            format!(
                "INSERT INTO sync_items (id, version, timestamp, payload_hash, payload, payload_blob, audit, merkle_dirty, state) VALUES {} \
                 ON CONFLICT(id) DO UPDATE SET \
                 version = excluded.version, \
                 timestamp = excluded.timestamp, \
                 payload_hash = excluded.payload_hash, \
                 payload = excluded.payload, \
                 payload_blob = excluded.payload_blob, \
                 audit = excluded.audit, \
                 merkle_dirty = 1, \
                 state = excluded.state",
                placeholders.join(", ")
            )
        } else {
            format!(
                "INSERT INTO sync_items (id, version, timestamp, payload_hash, payload, payload_blob, audit, merkle_dirty, state) VALUES {} \
                 ON DUPLICATE KEY UPDATE \
                 version = VALUES(version), \
                 timestamp = VALUES(timestamp), \
                 payload_hash = VALUES(payload_hash), \
                 payload = VALUES(payload), \
                 payload_blob = VALUES(payload_blob), \
                 audit = VALUES(audit), \
                 merkle_dirty = 1, \
                 state = VALUES(state)",
                placeholders.join(", ")
            )
        };

        #[derive(Clone)]
        struct PreparedRow {
            id: String,
            version: i64,
            timestamp: i64,
            payload_hash: Option<String>,
            payload_json: Option<String>,
            payload_blob: Option<Vec<u8>>,
            audit_json: Option<String>,
            state: String,
        }

        let prepared: Vec<PreparedRow> = chunk.iter()
            .map(|item| {
                let (payload_json, payload_blob) = match item.content_type {
                    ContentType::Json => {
                        let json_str = String::from_utf8_lossy(&item.content).to_string();
                        (Some(json_str), None)
                    }
                    ContentType::Binary => {
                        (None, Some(item.content.clone()))
                    }
                };

                PreparedRow {
                    id: item.object_id.clone(),
                    version: item.version as i64,
                    timestamp: item.updated_at,
                    payload_hash: if item.merkle_root.is_empty() { None } else { Some(item.merkle_root.clone()) },
                    payload_json,
                    payload_blob,
                    audit_json: Self::build_audit_json(item),
                    state: item.state.clone(),
                }
            })
            .collect();

        retry("sql_put_batch", &RetryConfig::query(), || {
            let sql = sql.clone();
            let prepared = prepared.clone();
            async move {
                let mut query = sqlx::query(&sql);

                for row in &prepared {
                    query = query
                        .bind(&row.id)
                        .bind(row.version)
                        .bind(row.timestamp)
                        .bind(&row.payload_hash)
                        .bind(&row.payload_json)
                        .bind(&row.payload_blob)
                        .bind(&row.audit_json)
                        .bind(&row.state);
                }

                query.execute(&self.pool)
                    .await
                    .map_err(|e| StorageError::Backend(e.to_string()))?;

                Ok(())
            }
        })
        .await?;

        Ok(chunk.len())
    }

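    /// Count how many stored rows carry `batch_id` in their audit JSON.
    ///
    /// MySQL can query the JSON directly via `JSON_EXTRACT`; the SQLite path
    /// falls back to a LIKE match over the raw audit text.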
    async fn verify_batch(&self, batch_id: &str) -> Result<usize, StorageError> {
        let batch_id = batch_id.to_string();

        let sql = if self.is_sqlite {
            "SELECT COUNT(*) as cnt FROM sync_items WHERE audit LIKE ?"
        } else {
            "SELECT COUNT(*) as cnt FROM sync_items WHERE JSON_EXTRACT(audit, '$.batch') = ?"
        };

        let bind_value = if self.is_sqlite {
            // serde_json emits no whitespace, so the audit column contains a
            // literal `"batch":"<id>"` substring; include the closing quote so
            // the match cannot over-count on a shared prefix.
            format!("%\"batch\":\"{}\"%", batch_id)
        } else {
            batch_id.clone()
        };

        let result = sqlx::query(sql)
            .bind(&bind_value)
            .fetch_one(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        let count: i64 = result.try_get("cnt")
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        Ok(count as usize)
    }

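    /// Read up to `limit` items in ascending timestamp order.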
    pub async fn scan_batch(&self, limit: usize) -> Result<Vec<SyncItem>, StorageError> {
        let rows = sqlx::query(
            "SELECT id, version, timestamp, payload_hash, payload, payload_blob, audit, state FROM sync_items ORDER BY timestamp ASC LIMIT ?"
        )
        .bind(limit as i64)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| StorageError::Backend(e.to_string()))?;

        let mut items = Vec::with_capacity(rows.len());
        for row in rows {
            let id: String = row.try_get("id")
                .map_err(|e| StorageError::Backend(e.to_string()))?;
            let version: i64 = row.try_get("version").unwrap_or(1);
            let timestamp: i64 = row.try_get("timestamp").unwrap_or(0);
            let payload_hash: Option<String> = row.try_get("payload_hash").ok();

            // Use the same String-then-bytes fallback as `get`, since the Any
            // driver may surface TEXT columns either way.
            let payload_json: Option<String> = row.try_get::<String, _>("payload").ok()
                .or_else(|| {
                    row.try_get::<Vec<u8>, _>("payload").ok()
                        .and_then(|bytes| String::from_utf8(bytes).ok())
                });
            let payload_blob: Option<Vec<u8>> = row.try_get("payload_blob").ok();
            let audit_json: Option<String> = row.try_get::<String, _>("audit").ok()
                .or_else(|| {
                    row.try_get::<Vec<u8>, _>("audit").ok()
                        .and_then(|bytes| String::from_utf8(bytes).ok())
                });

            let state: String = row.try_get::<String, _>("state").ok()
                .or_else(|| {
                    row.try_get::<Vec<u8>, _>("state").ok()
                        .and_then(|bytes| String::from_utf8(bytes).ok())
                })
                .unwrap_or_else(|| "default".to_string());

            let (content, content_type) = if let Some(ref json_str) = payload_json {
                (json_str.as_bytes().to_vec(), ContentType::Json)
            } else if let Some(blob) = payload_blob {
                (blob, ContentType::Binary)
            } else {
                continue;
            };

            let (batch_id, trace_parent, home_instance_id) = Self::parse_audit_json(audit_json);

            let item = SyncItem::reconstruct(
                id,
                version as u64,
                timestamp,
                content_type,
                content,
                batch_id,
                trace_parent,
                payload_hash.unwrap_or_default(),
                home_instance_id,
                state,
            );
            items.push(item);
        }

        Ok(items)
    }

    pub async fn delete_batch(&self, ids: &[String]) -> Result<usize, StorageError> {
        if ids.is_empty() {
            return Ok(0);
        }

        let placeholders: Vec<&str> = ids.iter().map(|_| "?").collect();
        let sql = format!(
            "DELETE FROM sync_items WHERE id IN ({})",
            placeholders.join(", ")
        );

        retry("sql_delete_batch", &RetryConfig::query(), || {
            let sql = sql.clone();
            let ids = ids.to_vec();
            async move {
                let mut query = sqlx::query(&sql);
                for id in &ids {
                    query = query.bind(id);
                }

                let result = query.execute(&self.pool)
                    .await
                    .map_err(|e| StorageError::Backend(e.to_string()))?;

                Ok(result.rows_affected() as usize)
            }
        })
        .await
    }

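    /// IDs of rows whose `merkle_dirty` flag is still set, up to `limit`.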
    pub async fn get_dirty_merkle_ids(&self, limit: usize) -> Result<Vec<String>, StorageError> {
        let rows = sqlx::query(
            "SELECT id FROM sync_items WHERE merkle_dirty = 1 LIMIT ?"
        )
        .bind(limit as i64)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| StorageError::Backend(format!("Failed to get dirty merkle ids: {}", e)))?;

        let mut ids = Vec::with_capacity(rows.len());
        for row in rows {
            let id: String = row.try_get("id")
                .map_err(|e| StorageError::Backend(e.to_string()))?;
            ids.push(id);
        }

        Ok(ids)
    }

    pub async fn count_dirty_merkle(&self) -> Result<u64, StorageError> {
        let result = sqlx::query("SELECT COUNT(*) as cnt FROM sync_items WHERE merkle_dirty = 1")
            .fetch_one(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        let count: i64 = result.try_get("cnt")
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        Ok(count as u64)
    }

    pub async fn mark_merkle_clean(&self, ids: &[String]) -> Result<usize, StorageError> {
        if ids.is_empty() {
            return Ok(0);
        }

        let placeholders: Vec<&str> = ids.iter().map(|_| "?").collect();
        let sql = format!(
            "UPDATE sync_items SET merkle_dirty = 0 WHERE id IN ({})",
            placeholders.join(", ")
        );

        let mut query = sqlx::query(&sql);
        for id in ids {
            query = query.bind(id);
        }

        let result = query.execute(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        Ok(result.rows_affected() as usize)
    }

    pub async fn has_dirty_merkle(&self) -> Result<bool, StorageError> {
        let result = sqlx::query("SELECT 1 FROM sync_items WHERE merkle_dirty = 1 LIMIT 1")
            .fetch_optional(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        Ok(result.is_some())
    }

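    /// Full items (not just IDs) whose `merkle_dirty` flag is set, up to `limit`.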
    pub async fn get_dirty_merkle_items(&self, limit: usize) -> Result<Vec<SyncItem>, StorageError> {
        let rows = sqlx::query(
            "SELECT id, version, timestamp, payload_hash, payload, payload_blob, audit, state
             FROM sync_items WHERE merkle_dirty = 1 LIMIT ?"
        )
        .bind(limit as i64)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| StorageError::Backend(format!("Failed to get dirty merkle items: {}", e)))?;

        let mut items = Vec::with_capacity(rows.len());
        for row in rows {
            let id: String = row.try_get("id")
                .map_err(|e| StorageError::Backend(e.to_string()))?;
            let version: i64 = row.try_get("version").unwrap_or(1);
            let timestamp: i64 = row.try_get("timestamp").unwrap_or(0);
            let payload_hash: Option<String> = row.try_get("payload_hash").ok();

            // Use the same String-then-bytes fallback as `get`, since the Any
            // driver may surface TEXT columns either way.
            let payload_json: Option<String> = row.try_get::<String, _>("payload").ok()
                .or_else(|| {
                    row.try_get::<Vec<u8>, _>("payload").ok()
                        .and_then(|bytes| String::from_utf8(bytes).ok())
                });

            let payload_blob: Option<Vec<u8>> = row.try_get("payload_blob").ok();
            let audit_json: Option<String> = row.try_get::<String, _>("audit").ok()
                .or_else(|| {
                    row.try_get::<Vec<u8>, _>("audit").ok()
                        .and_then(|bytes| String::from_utf8(bytes).ok())
                });

            let state: String = row.try_get::<String, _>("state").ok()
                .or_else(|| {
                    row.try_get::<Vec<u8>, _>("state").ok()
                        .and_then(|bytes| String::from_utf8(bytes).ok())
                })
                .unwrap_or_else(|| "default".to_string());

            let (content, content_type) = if let Some(ref json_str) = payload_json {
                (json_str.as_bytes().to_vec(), ContentType::Json)
            } else if let Some(blob) = payload_blob {
                (blob, ContentType::Binary)
            } else {
                continue;
            };

            let (batch_id, trace_parent, home_instance_id) = Self::parse_audit_json(audit_json);

            let item = SyncItem::reconstruct(
                id,
                version as u64,
                timestamp,
                content_type,
                content,
                batch_id,
                trace_parent,
                payload_hash.unwrap_or_default(),
                home_instance_id,
                state,
            );
            items.push(item);
        }

        Ok(items)
    }

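    /// Items currently in `state`, up to `limit`.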
    pub async fn get_by_state(&self, state: &str, limit: usize) -> Result<Vec<SyncItem>, StorageError> {
        let rows = sqlx::query(
            "SELECT id, version, timestamp, payload_hash, payload, payload_blob, audit, state
             FROM sync_items WHERE state = ? LIMIT ?"
        )
        .bind(state)
        .bind(limit as i64)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| StorageError::Backend(format!("Failed to get items by state: {}", e)))?;

        let mut items = Vec::with_capacity(rows.len());
        for row in rows {
            let id: String = row.try_get("id")
                .map_err(|e| StorageError::Backend(e.to_string()))?;
            let version: i64 = row.try_get("version").unwrap_or(1);
            let timestamp: i64 = row.try_get("timestamp").unwrap_or(0);
            let payload_hash: Option<String> = row.try_get("payload_hash").ok();

            let payload_json: Option<String> = row.try_get::<String, _>("payload").ok()
                .or_else(|| {
                    row.try_get::<Vec<u8>, _>("payload").ok()
                        .and_then(|bytes| String::from_utf8(bytes).ok())
                });

            let payload_blob: Option<Vec<u8>> = row.try_get("payload_blob").ok();

            let audit_json: Option<String> = row.try_get::<String, _>("audit").ok()
                .or_else(|| {
                    row.try_get::<Vec<u8>, _>("audit").ok()
                        .and_then(|bytes| String::from_utf8(bytes).ok())
                });

            let state: String = row.try_get::<String, _>("state").ok()
                .or_else(|| {
                    row.try_get::<Vec<u8>, _>("state").ok()
                        .and_then(|bytes| String::from_utf8(bytes).ok())
                })
                .unwrap_or_else(|| "default".to_string());

            let (content, content_type) = if let Some(ref json_str) = payload_json {
                (json_str.as_bytes().to_vec(), ContentType::Json)
            } else if let Some(blob) = payload_blob {
                (blob, ContentType::Binary)
            } else {
                continue;
            };

            let (batch_id, trace_parent, home_instance_id) = Self::parse_audit_json(audit_json);

            let item = SyncItem::reconstruct(
                id,
                version as u64,
                timestamp,
                content_type,
                content,
                batch_id,
                trace_parent,
                payload_hash.unwrap_or_default(),
                home_instance_id,
                state,
            );
            items.push(item);
        }

        Ok(items)
    }

    pub async fn count_by_state(&self, state: &str) -> Result<u64, StorageError> {
        let result = sqlx::query("SELECT COUNT(*) as cnt FROM sync_items WHERE state = ?")
            .bind(state)
            .fetch_one(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        let count: i64 = result.try_get("cnt")
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        Ok(count as u64)
    }

    pub async fn list_state_ids(&self, state: &str, limit: usize) -> Result<Vec<String>, StorageError> {
        let rows = sqlx::query("SELECT id FROM sync_items WHERE state = ? LIMIT ?")
            .bind(state)
            .bind(limit as i64)
            .fetch_all(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(format!("Failed to list state IDs: {}", e)))?;

        let mut ids = Vec::with_capacity(rows.len());
        for row in rows {
            let id: String = row.try_get("id")
                .map_err(|e| StorageError::Backend(e.to_string()))?;
            ids.push(id);
        }

        Ok(ids)
    }

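    /// Move a single item to `new_state`; returns `false` if the id is unknown.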
    pub async fn set_state(&self, id: &str, new_state: &str) -> Result<bool, StorageError> {
        let result = sqlx::query("UPDATE sync_items SET state = ? WHERE id = ?")
            .bind(new_state)
            .bind(id)
            .execute(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        Ok(result.rows_affected() > 0)
    }

    pub async fn delete_by_state(&self, state: &str) -> Result<u64, StorageError> {
        let result = sqlx::query("DELETE FROM sync_items WHERE state = ?")
            .bind(state)
            .execute(&self.pool)
            .await
            .map_err(|e| StorageError::Backend(e.to_string()))?;

        Ok(result.rows_affected())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use serde_json::json;

    fn temp_db_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("sql_state_test_{}.db", name))
    }

    fn test_item(id: &str, state: &str) -> SyncItem {
        SyncItem::from_json(id.to_string(), json!({"id": id}))
            .with_state(state)
    }

    #[tokio::test]
    async fn test_state_stored_and_retrieved() {
        let db_path = temp_db_path("stored");
        let _ = std::fs::remove_file(&db_path);

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlStore::new(&url).await.unwrap();

        let item = test_item("item1", "delta");
        store.put(&item).await.unwrap();

        let retrieved = store.get("item1").await.unwrap().unwrap();
        assert_eq!(retrieved.state, "delta");

        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_state_default_value() {
        let db_path = temp_db_path("default");
        let _ = std::fs::remove_file(&db_path);

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlStore::new(&url).await.unwrap();

        let item = SyncItem::from_json("item1".into(), json!({"test": true}));
        store.put(&item).await.unwrap();

        let retrieved = store.get("item1").await.unwrap().unwrap();
        assert_eq!(retrieved.state, "default");

        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_get_by_state() {
        let db_path = temp_db_path("get_by_state");
        let _ = std::fs::remove_file(&db_path);

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlStore::new(&url).await.unwrap();

        store.put(&test_item("delta1", "delta")).await.unwrap();
        store.put(&test_item("delta2", "delta")).await.unwrap();
        store.put(&test_item("base1", "base")).await.unwrap();
        store.put(&test_item("pending1", "pending")).await.unwrap();

        let deltas = store.get_by_state("delta", 100).await.unwrap();
        assert_eq!(deltas.len(), 2);
        assert!(deltas.iter().all(|i| i.state == "delta"));

        let bases = store.get_by_state("base", 100).await.unwrap();
        assert_eq!(bases.len(), 1);
        assert_eq!(bases[0].object_id, "base1");

        let none = store.get_by_state("nonexistent", 100).await.unwrap();
        assert!(none.is_empty());

        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_get_by_state_with_limit() {
        let db_path = temp_db_path("get_by_state_limit");
        let _ = std::fs::remove_file(&db_path);

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlStore::new(&url).await.unwrap();

        for i in 0..10 {
            store.put(&test_item(&format!("item{}", i), "batch")).await.unwrap();
        }

        let limited = store.get_by_state("batch", 5).await.unwrap();
        assert_eq!(limited.len(), 5);

        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_count_by_state() {
        let db_path = temp_db_path("count_by_state");
        let _ = std::fs::remove_file(&db_path);

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlStore::new(&url).await.unwrap();

        store.put(&test_item("a1", "alpha")).await.unwrap();
        store.put(&test_item("a2", "alpha")).await.unwrap();
        store.put(&test_item("a3", "alpha")).await.unwrap();
        store.put(&test_item("b1", "beta")).await.unwrap();

        assert_eq!(store.count_by_state("alpha").await.unwrap(), 3);
        assert_eq!(store.count_by_state("beta").await.unwrap(), 1);
        assert_eq!(store.count_by_state("gamma").await.unwrap(), 0);

        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_list_state_ids() {
        let db_path = temp_db_path("list_state_ids");
        let _ = std::fs::remove_file(&db_path);

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlStore::new(&url).await.unwrap();

        store.put(&test_item("id1", "pending")).await.unwrap();
        store.put(&test_item("id2", "pending")).await.unwrap();
        store.put(&test_item("id3", "done")).await.unwrap();

        let pending_ids = store.list_state_ids("pending", 100).await.unwrap();
        assert_eq!(pending_ids.len(), 2);
        assert!(pending_ids.contains(&"id1".to_string()));
        assert!(pending_ids.contains(&"id2".to_string()));

        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_set_state() {
        let db_path = temp_db_path("set_state");
        let _ = std::fs::remove_file(&db_path);

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlStore::new(&url).await.unwrap();

        store.put(&test_item("item1", "pending")).await.unwrap();

        let before = store.get("item1").await.unwrap().unwrap();
        assert_eq!(before.state, "pending");

        let updated = store.set_state("item1", "approved").await.unwrap();
        assert!(updated);

        let after = store.get("item1").await.unwrap().unwrap();
        assert_eq!(after.state, "approved");

        let not_found = store.set_state("nonexistent", "x").await.unwrap();
        assert!(!not_found);

        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_delete_by_state() {
        let db_path = temp_db_path("delete_by_state");
        let _ = std::fs::remove_file(&db_path);

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlStore::new(&url).await.unwrap();

        store.put(&test_item("keep1", "keep")).await.unwrap();
        store.put(&test_item("keep2", "keep")).await.unwrap();
        store.put(&test_item("del1", "delete_me")).await.unwrap();
        store.put(&test_item("del2", "delete_me")).await.unwrap();
        store.put(&test_item("del3", "delete_me")).await.unwrap();

        let deleted = store.delete_by_state("delete_me").await.unwrap();
        assert_eq!(deleted, 3);

        assert!(store.get("del1").await.unwrap().is_none());
        assert!(store.get("del2").await.unwrap().is_none());

        assert!(store.get("keep1").await.unwrap().is_some());
        assert!(store.get("keep2").await.unwrap().is_some());

        let zero = store.delete_by_state("nonexistent").await.unwrap();
        assert_eq!(zero, 0);

        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_multiple_puts_preserve_state() {
        let db_path = temp_db_path("multi_put_state");
        let _ = std::fs::remove_file(&db_path);

        let url = format!("sqlite://{}?mode=rwc", db_path.display());
        let store = SqlStore::new(&url).await.unwrap();

        store.put(&test_item("a", "state_a")).await.unwrap();
        store.put(&test_item("b", "state_b")).await.unwrap();
        store.put(&test_item("c", "state_c")).await.unwrap();

        assert_eq!(store.get("a").await.unwrap().unwrap().state, "state_a");
        assert_eq!(store.get("b").await.unwrap().unwrap().state, "state_b");
        assert_eq!(store.get("c").await.unwrap().unwrap().state, "state_c");

        let _ = std::fs::remove_file(&db_path);
    }
}