// zlayer_storage/replicator/s3_backend.rs

1//! S3 upload/download for `SQLite` backups
2//!
3//! Handles uploading snapshots and WAL segments to S3, and downloading for restore.
4
5use super::cache::CacheEntry;
6use crate::error::{LayerStorageError, Result};
7use aws_sdk_s3::primitives::ByteStream;
8use aws_sdk_s3::Client as S3Client;
9use serde::{Deserialize, Serialize};
10use std::io::{Read, Write};
11use tracing::{debug, info};
12
13/// Replication metadata stored in S3
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ReplicationMetadata {
16    /// Latest snapshot key
17    pub latest_snapshot: Option<String>,
18    /// Timestamp of latest snapshot
19    pub latest_snapshot_time: Option<chrono::DateTime<chrono::Utc>>,
20    /// Highest WAL sequence uploaded
21    pub latest_wal_sequence: Option<u64>,
22    /// Total number of snapshots stored
23    pub snapshot_count: u64,
24    /// Database identifier (for validation)
25    pub db_identifier: Option<String>,
26    /// Last modified timestamp
27    pub last_modified: chrono::DateTime<chrono::Utc>,
28}
29
30impl Default for ReplicationMetadata {
31    fn default() -> Self {
32        Self {
33            latest_snapshot: None,
34            latest_snapshot_time: None,
35            latest_wal_sequence: None,
36            snapshot_count: 0,
37            db_identifier: None,
38            last_modified: chrono::Utc::now(),
39        }
40    }
41}
42
43/// S3 backend for `SQLite` replication
44pub struct S3Backend {
45    client: S3Client,
46    bucket: String,
47    prefix: String,
48    compression_level: i32,
49}
50
51impl S3Backend {
52    /// Create a new S3 backend
53    ///
54    /// # Arguments
55    ///
56    /// * `client` - Pre-configured S3 client
57    /// * `bucket` - S3 bucket name
58    /// * `prefix` - Key prefix for all objects
59    /// * `compression_level` - Zstd compression level (1-22)
60    pub fn new(client: S3Client, bucket: String, prefix: String, compression_level: i32) -> Self {
61        Self {
62            client,
63            bucket,
64            prefix,
65            compression_level,
66        }
67    }
68
69    /// Build the S3 key for a snapshot
70    fn snapshot_key(&self, timestamp: &chrono::DateTime<chrono::Utc>) -> String {
71        format!(
72            "{}snapshots/{}.sqlite.zst",
73            self.prefix,
74            timestamp.format("%Y%m%d_%H%M%S")
75        )
76    }
77
78    /// Build the S3 key for a WAL segment
79    fn wal_key(&self, sequence: u64) -> String {
80        format!("{}wal/{:020}.wal.zst", self.prefix, sequence)
81    }
82
83    /// Build the S3 key for metadata
84    fn metadata_key(&self) -> String {
85        format!("{}metadata.json", self.prefix)
86    }
87
88    /// Upload a database snapshot
89    pub async fn upload_snapshot(&self, data: &[u8]) -> Result<()> {
90        let timestamp = chrono::Utc::now();
91        let key = self.snapshot_key(&timestamp);
92
93        info!(
94            "Uploading snapshot to s3://{}/{} ({} bytes)",
95            self.bucket,
96            key,
97            data.len()
98        );
99
100        // Compress the data
101        let compressed = self.compress(data)?;
102
103        #[allow(clippy::cast_precision_loss)]
104        let reduction_pct = (1.0 - (compressed.len() as f64 / data.len() as f64)) * 100.0;
105        info!(
106            "Compressed {} bytes to {} bytes ({:.1}% reduction)",
107            data.len(),
108            compressed.len(),
109            reduction_pct,
110        );
111
112        // Upload to S3
113        self.client
114            .put_object()
115            .bucket(&self.bucket)
116            .key(&key)
117            .body(ByteStream::from(compressed))
118            .content_type("application/zstd")
119            .send()
120            .await
121            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
122
123        debug!("Snapshot uploaded: {}", key);
124        Ok(())
125    }
126
127    /// Upload a WAL segment
128    pub async fn upload_wal_segment(&self, entry: &CacheEntry) -> Result<()> {
129        let key = self.wal_key(entry.sequence);
130
131        debug!(
132            "Uploading WAL segment {} to s3://{}/{} ({} bytes)",
133            entry.sequence,
134            self.bucket,
135            key,
136            entry.data.len()
137        );
138
139        // Compress the data
140        let compressed = self.compress(&entry.data)?;
141        let compressed_len = compressed.len();
142
143        // Upload to S3
144        self.client
145            .put_object()
146            .bucket(&self.bucket)
147            .key(&key)
148            .body(ByteStream::from(compressed))
149            .content_type("application/zstd")
150            .send()
151            .await
152            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
153
154        debug!(
155            "WAL segment {} uploaded ({} bytes compressed)",
156            entry.sequence, compressed_len
157        );
158        Ok(())
159    }
160
161    /// Download the latest snapshot
162    pub async fn download_latest_snapshot(&self) -> Result<Option<Vec<u8>>> {
163        // Get metadata to find latest snapshot
164        let metadata = self.get_metadata().await?;
165
166        let snapshot_key = if let Some(key) = &metadata.latest_snapshot {
167            key.clone()
168        } else {
169            // List snapshots to find the latest
170            let snapshots = self.list_snapshots().await?;
171            if snapshots.is_empty() {
172                return Ok(None);
173            }
174            snapshots.last().unwrap().clone()
175        };
176
177        info!("Downloading snapshot: {}", snapshot_key);
178
179        let response = self
180            .client
181            .get_object()
182            .bucket(&self.bucket)
183            .key(&snapshot_key)
184            .send()
185            .await
186            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
187
188        let compressed_bytes = response
189            .body
190            .collect()
191            .await
192            .map_err(|e| LayerStorageError::S3(e.to_string()))?
193            .into_bytes();
194
195        // Decompress
196        let decompressed = self.decompress(&compressed_bytes)?;
197
198        info!(
199            "Downloaded snapshot: {} bytes (compressed: {} bytes)",
200            decompressed.len(),
201            compressed_bytes.len()
202        );
203
204        Ok(Some(decompressed))
205    }
206
207    /// Download WAL segments since a given sequence
208    pub async fn download_wal_segments_since(&self, sequence: u64) -> Result<Vec<CacheEntry>> {
209        let prefix = format!("{}wal/", self.prefix);
210
211        let mut segments = Vec::new();
212        let mut continuation_token: Option<String> = None;
213
214        loop {
215            let mut request = self
216                .client
217                .list_objects_v2()
218                .bucket(&self.bucket)
219                .prefix(&prefix);
220
221            if let Some(token) = &continuation_token {
222                request = request.continuation_token(token);
223            }
224
225            let response = request
226                .send()
227                .await
228                .map_err(|e| LayerStorageError::S3(e.to_string()))?;
229
230            for object in response.contents() {
231                if let Some(key) = object.key() {
232                    // Parse sequence from key
233                    if let Some(seq) = Self::parse_wal_sequence(key) {
234                        if seq > sequence {
235                            // Download this segment
236                            let entry = self.download_wal_segment(key, seq).await?;
237                            segments.push(entry);
238                        }
239                    }
240                }
241            }
242
243            if response.is_truncated().unwrap_or(false) {
244                continuation_token = response.next_continuation_token().map(String::from);
245            } else {
246                break;
247            }
248        }
249
250        // Sort by sequence
251        segments.sort_by_key(|e| e.sequence);
252
253        info!(
254            "Downloaded {} WAL segments since sequence {}",
255            segments.len(),
256            sequence
257        );
258
259        Ok(segments)
260    }
261
262    /// Download a single WAL segment
263    async fn download_wal_segment(&self, key: &str, sequence: u64) -> Result<CacheEntry> {
264        let response = self
265            .client
266            .get_object()
267            .bucket(&self.bucket)
268            .key(key)
269            .send()
270            .await
271            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
272
273        let compressed_bytes = response
274            .body
275            .collect()
276            .await
277            .map_err(|e| LayerStorageError::S3(e.to_string()))?
278            .into_bytes();
279
280        let data = self.decompress(&compressed_bytes)?;
281
282        Ok(CacheEntry {
283            sequence,
284            data,
285            cached_at: chrono::Utc::now(),
286            attempts: 0,
287        })
288    }
289
290    /// Parse sequence number from WAL key
291    fn parse_wal_sequence(key: &str) -> Option<u64> {
292        // Key format: {prefix}wal/{sequence:020}.wal.zst
293        let filename = key.rsplit('/').next()?;
294        let sequence_str = filename.strip_suffix(".wal.zst")?;
295        sequence_str.parse().ok()
296    }
297
298    /// List all snapshot keys
299    pub async fn list_snapshots(&self) -> Result<Vec<String>> {
300        let prefix = format!("{}snapshots/", self.prefix);
301
302        let mut keys = Vec::new();
303        let mut continuation_token: Option<String> = None;
304
305        loop {
306            let mut request = self
307                .client
308                .list_objects_v2()
309                .bucket(&self.bucket)
310                .prefix(&prefix);
311
312            if let Some(token) = &continuation_token {
313                request = request.continuation_token(token);
314            }
315
316            let response = request
317                .send()
318                .await
319                .map_err(|e| LayerStorageError::S3(e.to_string()))?;
320
321            for object in response.contents() {
322                if let Some(key) = object.key() {
323                    if key.ends_with(".sqlite.zst") {
324                        keys.push(key.to_string());
325                    }
326                }
327            }
328
329            if response.is_truncated().unwrap_or(false) {
330                continuation_token = response.next_continuation_token().map(String::from);
331            } else {
332                break;
333            }
334        }
335
336        // Sort by timestamp (embedded in key)
337        keys.sort();
338
339        Ok(keys)
340    }
341
342    /// Get replication metadata from S3
343    pub async fn get_metadata(&self) -> Result<ReplicationMetadata> {
344        let key = self.metadata_key();
345
346        match self
347            .client
348            .get_object()
349            .bucket(&self.bucket)
350            .key(&key)
351            .send()
352            .await
353        {
354            Ok(response) => {
355                let bytes = response
356                    .body
357                    .collect()
358                    .await
359                    .map_err(|e| LayerStorageError::S3(e.to_string()))?
360                    .into_bytes();
361
362                let metadata: ReplicationMetadata = serde_json::from_slice(&bytes)?;
363                Ok(metadata)
364            }
365            Err(e) => {
366                // Check if it's a not-found error
367                if e.to_string().contains("NoSuchKey") || e.to_string().contains("404") {
368                    Ok(ReplicationMetadata::default())
369                } else {
370                    Err(LayerStorageError::S3(e.to_string()))
371                }
372            }
373        }
374    }
375
376    /// Update replication metadata in S3
377    pub async fn update_metadata(&self, wal_sequence: Option<u64>) -> Result<()> {
378        let key = self.metadata_key();
379
380        // Get current metadata
381        let mut metadata = self.get_metadata().await.unwrap_or_default();
382
383        // Get latest snapshot
384        let snapshots = self.list_snapshots().await?;
385        if let Some(latest) = snapshots.last() {
386            metadata.latest_snapshot = Some(latest.clone());
387            metadata.latest_snapshot_time = Some(chrono::Utc::now());
388        }
389        metadata.snapshot_count = snapshots.len() as u64;
390
391        // Update WAL sequence if provided
392        if let Some(seq) = wal_sequence {
393            metadata.latest_wal_sequence = Some(seq);
394        }
395
396        metadata.last_modified = chrono::Utc::now();
397
398        // Upload metadata
399        let json = serde_json::to_vec_pretty(&metadata)?;
400        self.client
401            .put_object()
402            .bucket(&self.bucket)
403            .key(&key)
404            .body(ByteStream::from(json))
405            .content_type("application/json")
406            .send()
407            .await
408            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
409
410        debug!("Metadata updated");
411        Ok(())
412    }
413
414    /// Compress data using zstd
415    fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
416        let mut encoder = zstd::stream::Encoder::new(Vec::new(), self.compression_level)?;
417        encoder.write_all(data)?;
418        Ok(encoder.finish()?)
419    }
420
421    /// Decompress zstd data
422    #[allow(clippy::unused_self)]
423    fn decompress(&self, data: &[u8]) -> Result<Vec<u8>> {
424        let mut decoder = zstd::stream::Decoder::new(data)?;
425        let mut decompressed = Vec::new();
426        decoder.read_to_end(&mut decompressed)?;
427        Ok(decompressed)
428    }
429
430    /// Delete old snapshots, keeping only the most recent N
431    #[allow(dead_code)]
432    pub async fn cleanup_old_snapshots(&self, keep_count: usize) -> Result<usize> {
433        let snapshots = self.list_snapshots().await?;
434
435        if snapshots.len() <= keep_count {
436            return Ok(0);
437        }
438
439        let to_delete = &snapshots[..snapshots.len() - keep_count];
440        let mut deleted = 0;
441
442        for key in to_delete {
443            match self
444                .client
445                .delete_object()
446                .bucket(&self.bucket)
447                .key(key)
448                .send()
449                .await
450            {
451                Ok(_) => {
452                    debug!("Deleted old snapshot: {}", key);
453                    deleted += 1;
454                }
455                Err(e) => {
456                    debug!("Failed to delete snapshot {}: {}", key, e);
457                }
458            }
459        }
460
461        info!("Cleaned up {} old snapshots", deleted);
462        Ok(deleted)
463    }
464
465    /// Delete WAL segments older than a given sequence
466    #[allow(dead_code)]
467    pub async fn cleanup_old_wal(&self, before_sequence: u64) -> Result<usize> {
468        let prefix = format!("{}wal/", self.prefix);
469
470        let mut deleted = 0;
471        let mut continuation_token: Option<String> = None;
472
473        loop {
474            let mut request = self
475                .client
476                .list_objects_v2()
477                .bucket(&self.bucket)
478                .prefix(&prefix);
479
480            if let Some(token) = &continuation_token {
481                request = request.continuation_token(token);
482            }
483
484            let response = request
485                .send()
486                .await
487                .map_err(|e| LayerStorageError::S3(e.to_string()))?;
488
489            for object in response.contents() {
490                if let Some(key) = object.key() {
491                    if let Some(seq) = Self::parse_wal_sequence(key) {
492                        if seq < before_sequence
493                            && self
494                                .client
495                                .delete_object()
496                                .bucket(&self.bucket)
497                                .key(key)
498                                .send()
499                                .await
500                                .is_ok()
501                        {
502                            deleted += 1;
503                        }
504                    }
505                }
506            }
507
508            if response.is_truncated().unwrap_or(false) {
509                continuation_token = response.next_continuation_token().map(String::from);
510            } else {
511                break;
512            }
513        }
514
515        info!(
516            "Cleaned up {} WAL segments older than sequence {}",
517            deleted, before_sequence
518        );
519        Ok(deleted)
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// WAL keys yield their numeric sequence; anything else yields `None`.
    #[test]
    fn test_parse_wal_sequence() {
        let cases: [(&str, Option<u64>); 4] = [
            ("prefix/wal/00000000000000000001.wal.zst", Some(1)),
            ("prefix/wal/00000000000000000100.wal.zst", Some(100)),
            ("invalid", None),
            ("prefix/wal/abc.wal.zst", None),
        ];

        for &(key, expected) in &cases {
            assert_eq!(S3Backend::parse_wal_sequence(key), expected);
        }
    }

    /// A default metadata record starts out empty.
    #[test]
    fn test_metadata_default() {
        let fresh = ReplicationMetadata::default();

        assert_eq!(fresh.snapshot_count, 0);
        assert!(fresh.latest_wal_sequence.is_none());
        assert!(fresh.latest_snapshot.is_none());
    }
}