// zlayer_storage/sync.rs

//! Sync manager for coordinating local/remote layer state
//!
//! Handles crash-tolerant uploads with resume capability using S3 multipart uploads.
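//!
//! # Example
//!
//! A minimal end-to-end sketch (illustrative only: how `LayerStorageConfig`
//! and `ContainerLayerId` are constructed depends on the rest of the crate,
//! and the path below is a placeholder):
//!
//! ```ignore
//! # async fn demo(config: LayerStorageConfig, id: ContainerLayerId) -> Result<()> {
//! let manager = LayerSyncManager::new(config).await?;
//! manager.register_container(id.clone()).await?;
//!
//! // Upload the upper layer if it changed since the last sync; a crashed
//! // upload is resumed automatically on the next call.
//! if let Some(snapshot) = manager.sync_layer(&id, "/var/lib/containers/upper").await? {
//!     println!("synced layer {}", snapshot.digest);
//! }
//! # Ok(())
//! # }
//! ```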

use crate::config::LayerStorageConfig;
use crate::error::{LayerStorageError, Result};
use crate::snapshot::{calculate_directory_digest, create_snapshot, extract_snapshot};
use crate::types::{ContainerLayerId, LayerSnapshot, PendingUpload, SyncState};
use aws_sdk_s3::primitives::ByteStream;
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart as S3CompletedPart};
use aws_sdk_s3::Client as S3Client;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use sqlx::SqlitePool;
use std::collections::HashMap;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncSeekExt};
use tokio::sync::RwLock;
use tracing::{debug, info, instrument, warn};

/// Manages layer synchronization between local storage and S3
pub struct LayerSyncManager {
    config: LayerStorageConfig,
    s3_client: S3Client,
    pool: SqlitePool,
    /// In-memory cache of sync states
    states: Arc<RwLock<HashMap<String, SyncState>>>,
}

impl LayerSyncManager {
    /// Create a new sync manager
    ///
    /// # Errors
    ///
    /// Returns an error if directories cannot be created, the database cannot
    /// be opened, or the AWS SDK fails to initialize.
    pub async fn new(config: LayerStorageConfig) -> Result<Self> {
        // Ensure directories exist
        tokio::fs::create_dir_all(&config.staging_dir).await?;
        if let Some(parent) = config.state_db_path.parent() {
            tokio::fs::create_dir_all(parent).await?;
        }

        // Initialize AWS SDK
        let mut aws_config_builder = aws_config::from_env();

        if let Some(region) = &config.region {
            aws_config_builder =
                aws_config_builder.region(aws_sdk_s3::config::Region::new(region.clone()));
        }

        let aws_config = aws_config_builder.load().await;

        let s3_config = if let Some(endpoint) = &config.endpoint_url {
            // Custom endpoints (e.g. MinIO) generally need path-style addressing.
            aws_sdk_s3::config::Builder::from(&aws_config)
                .endpoint_url(endpoint)
                .force_path_style(true)
                .build()
        } else {
            aws_sdk_s3::config::Builder::from(&aws_config).build()
        };

        let s3_client = S3Client::from_conf(s3_config);

        // Open/create SQLite database
        let db_url = format!("sqlite:{}?mode=rwc", config.state_db_path.display());
        let connect_options = SqliteConnectOptions::from_str(&db_url)
            .map_err(|e| LayerStorageError::Database(e.to_string()))?
            .create_if_missing(true);

        let pool = SqlitePoolOptions::new()
            .max_connections(5)
            .connect_with(connect_options)
            .await
            .map_err(|e| LayerStorageError::Database(e.to_string()))?;

        // Enable WAL mode for better concurrent access
        sqlx::query("PRAGMA journal_mode=WAL")
            .execute(&pool)
            .await
            .map_err(|e| LayerStorageError::Database(e.to_string()))?;

        // Create sync_state table
        sqlx::query(
            r"
            CREATE TABLE IF NOT EXISTS sync_state (
                container_key TEXT PRIMARY KEY NOT NULL,
                state_json TEXT NOT NULL,
                updated_at TEXT DEFAULT CURRENT_TIMESTAMP
            )
            ",
        )
        .execute(&pool)
        .await
        .map_err(|e| LayerStorageError::Database(e.to_string()))?;

        // Load existing states into memory
        let states = Arc::new(RwLock::new(Self::load_all_states(&pool).await?));

        Ok(Self {
            config,
            s3_client,
            pool,
            states,
        })
    }

    async fn load_all_states(pool: &SqlitePool) -> Result<HashMap<String, SyncState>> {
        let rows: Vec<(String, String)> =
            sqlx::query_as("SELECT container_key, state_json FROM sync_state")
                .fetch_all(pool)
                .await
                .map_err(|e| LayerStorageError::Database(e.to_string()))?;

        let mut states = HashMap::new();
        for (key, json) in rows {
            let state: SyncState = serde_json::from_str(&json)?;
            states.insert(key, state);
        }

        Ok(states)
    }

    async fn save_state(&self, state: &SyncState) -> Result<()> {
        let key = state.container_id.to_key();
        let value = serde_json::to_string(state)?;

        sqlx::query(
            r"
            INSERT OR REPLACE INTO sync_state (container_key, state_json, updated_at)
            VALUES (?, ?, CURRENT_TIMESTAMP)
            ",
        )
        .bind(&key)
        .bind(&value)
        .execute(&self.pool)
        .await
        .map_err(|e| LayerStorageError::Database(e.to_string()))?;

        Ok(())
    }

    /// Register a container for layer sync tracking
    ///
    /// Registration is idempotent: if the container is already tracked, this
    /// is a no-op.
    ///
    /// # Errors
    ///
    /// Returns an error if persisting the sync state to the database fails.
    #[instrument(skip(self))]
    pub async fn register_container(&self, container_id: ContainerLayerId) -> Result<()> {
        let key = container_id.to_key();
        let mut states = self.states.write().await;

        if let std::collections::hash_map::Entry::Vacant(e) = states.entry(key) {
            let state = SyncState::new(container_id);
            self.save_state(&state).await?;
            e.insert(state);
            info!("Registered new container for layer sync");
        }

        Ok(())
    }

    /// Check if a container's layer has changed and needs sync
    ///
    /// # Errors
    ///
    /// Returns an error if the container is not registered or if calculating the
    /// directory digest fails.
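    ///
    /// # Example
    ///
    /// A minimal sketch of pairing this check with [`Self::sync_layer`]
    /// (the path is a placeholder):
    ///
    /// ```ignore
    /// # async fn demo(mgr: &LayerSyncManager, id: &ContainerLayerId) -> Result<()> {
    /// let upper = "/var/lib/containers/upper";
    /// if mgr.check_for_changes(id, upper).await? {
    ///     mgr.sync_layer(id, upper).await?;
    /// }
    /// # Ok(())
    /// # }
    /// ```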
    #[instrument(skip(self, upper_layer_path))]
    pub async fn check_for_changes(
        &self,
        container_id: &ContainerLayerId,
        upper_layer_path: impl AsRef<Path>,
    ) -> Result<bool> {
        let key = container_id.to_key();
        let states = self.states.read().await;

        let state = states
            .get(&key)
            .ok_or_else(|| LayerStorageError::NotFound(key.clone()))?;

        // Calculate current digest
        let current_digest = calculate_directory_digest(upper_layer_path)?;

        // Compare with stored digest
        Ok(state.local_digest.as_ref() != Some(&current_digest))
    }

    /// Create a snapshot and upload it to S3
    ///
    /// # Errors
    ///
    /// Returns an error if snapshot creation, S3 upload, or state persistence fails.
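    ///
    /// # Example
    ///
    /// Sketch of the return contract: `Ok(None)` means the remote copy was
    /// already up to date (the path is a placeholder):
    ///
    /// ```ignore
    /// # async fn demo(mgr: &LayerSyncManager, id: &ContainerLayerId) -> Result<()> {
    /// match mgr.sync_layer(id, "/var/lib/containers/upper").await? {
    ///     Some(snapshot) => println!("uploaded layer {}", snapshot.digest),
    ///     None => println!("already in sync"),
    /// }
    /// # Ok(())
    /// # }
    /// ```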
    #[instrument(skip(self, upper_layer_path), fields(container = %container_id))]
    pub async fn sync_layer(
        &self,
        container_id: &ContainerLayerId,
        upper_layer_path: impl AsRef<Path>,
    ) -> Result<Option<LayerSnapshot>> {
        let upper_layer_path = upper_layer_path.as_ref();
        let key = container_id.to_key();

        // Check for pending upload to resume
        {
            let states = self.states.read().await;
            if let Some(state) = states.get(&key) {
                if let Some(pending) = &state.pending_upload {
                    info!("Found pending upload, attempting to resume");
                    return self.resume_upload(container_id, pending.clone()).await;
                }
            }
        }

        // Calculate current digest
        let current_digest = calculate_directory_digest(upper_layer_path)?;

        // Check if sync needed
        {
            let states = self.states.read().await;
            if let Some(state) = states.get(&key) {
                if state.remote_digest.as_ref() == Some(&current_digest) {
                    debug!("Layer already synced, no changes");
                    return Ok(None);
                }
            }
        }

        // Create snapshot (compression is CPU-bound, so run it off the async runtime)
        let tarball_path = self
            .config
            .staging_dir
            .join(format!("{current_digest}.tar.zst"));

        let snapshot = tokio::task::spawn_blocking({
            let source = upper_layer_path.to_path_buf();
            let output = tarball_path.clone();
            let level = self.config.compression_level;
            move || create_snapshot(source, output, level)
        })
        .await
        .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;

        // Upload to S3
        self.upload_snapshot(container_id, &tarball_path, &snapshot)
            .await?;

        // Update state
        {
            let mut states = self.states.write().await;
            if let Some(state) = states.get_mut(&key) {
                state.local_digest = Some(snapshot.digest.clone());
                state.remote_digest = Some(snapshot.digest.clone());
                state.last_sync = Some(chrono::Utc::now());
                state.pending_upload = None;
                self.save_state(state).await?;
            }
        }

        // Clean up staging file
        let _ = tokio::fs::remove_file(&tarball_path).await;

        Ok(Some(snapshot))
    }

    /// Upload a snapshot to S3 using multipart upload
    #[allow(clippy::cast_possible_wrap)]
    #[instrument(skip(self, tarball_path, snapshot))]
    async fn upload_snapshot(
        &self,
        container_id: &ContainerLayerId,
        tarball_path: &Path,
        snapshot: &LayerSnapshot,
    ) -> Result<()> {
        let object_key = self.config.object_key(&snapshot.digest);
        let file_size = tokio::fs::metadata(tarball_path).await?.len();
        // S3 requires every part except the last to be at least 5 MiB;
        // `part_size_bytes` is assumed to satisfy that.
        let part_size = self.config.part_size_bytes;
        #[allow(clippy::cast_possible_truncation)]
        let total_parts = file_size.div_ceil(part_size) as u32;

        info!(
            "Uploading {} ({} bytes) in {} parts",
            object_key, file_size, total_parts
        );

        // Initiate multipart upload
        let create_response = self
            .s3_client
            .create_multipart_upload()
            .bucket(&self.config.bucket)
            .key(&object_key)
            .content_type("application/zstd")
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        let upload_id = create_response
            .upload_id()
            .ok_or_else(|| LayerStorageError::S3("No upload ID returned".to_string()))?
            .to_string();

        // Record pending upload for crash recovery
        let pending = PendingUpload {
            upload_id: upload_id.clone(),
            object_key: object_key.clone(),
            total_parts,
            completed_parts: HashMap::new(),
            part_size,
            local_tarball_path: tarball_path.to_path_buf(),
            started_at: chrono::Utc::now(),
            digest: snapshot.digest.clone(),
        };

        {
            let key = container_id.to_key();
            let mut states = self.states.write().await;
            if let Some(state) = states.get_mut(&key) {
                state.pending_upload = Some(pending.clone());
                self.save_state(state).await?;
            }
        }

        // Upload parts
        let completed_parts = self
            .upload_parts(
                tarball_path,
                &upload_id,
                &object_key,
                total_parts,
                part_size,
            )
            .await?;

        // Complete multipart upload
        let completed_upload = CompletedMultipartUpload::builder()
            .set_parts(Some(
                completed_parts
                    .into_iter()
                    .map(|(num, etag)| {
                        S3CompletedPart::builder()
                            .part_number(num as i32)
                            .e_tag(etag)
                            .build()
                    })
                    .collect(),
            ))
            .build();

        self.s3_client
            .complete_multipart_upload()
            .bucket(&self.config.bucket)
            .key(&object_key)
            .upload_id(&upload_id)
            .multipart_upload(completed_upload)
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        // Upload metadata alongside the tarball
        let metadata_key = self.config.metadata_key(&snapshot.digest);
        let metadata_json = serde_json::to_vec(snapshot)?;

        self.s3_client
            .put_object()
            .bucket(&self.config.bucket)
            .key(&metadata_key)
            .body(ByteStream::from(metadata_json))
            .content_type("application/json")
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        info!("Upload complete: {}", object_key);
        Ok(())
    }

    /// Upload individual parts with progress tracking
    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
    async fn upload_parts(
        &self,
        tarball_path: &Path,
        upload_id: &str,
        object_key: &str,
        total_parts: u32,
        part_size: u64,
    ) -> Result<Vec<(u32, String)>> {
        let mut completed = Vec::new();

        for part_number in 1..=total_parts {
            let offset = (u64::from(part_number) - 1) * part_size;

            // Read this part's bytes. `read_to_end` on a `take` adapter keeps
            // reading until `part_size` bytes (or EOF, for the final part), so
            // a single short `read` cannot produce an undersized part.
            let mut file = File::open(tarball_path).await?;
            file.seek(std::io::SeekFrom::Start(offset)).await?;
            let mut buffer = Vec::with_capacity(part_size as usize);
            file.take(part_size).read_to_end(&mut buffer).await?;

            // Upload part
            let response = self
                .s3_client
                .upload_part()
                .bucket(&self.config.bucket)
                .key(object_key)
                .upload_id(upload_id)
                .part_number(part_number as i32)
                .body(ByteStream::from(buffer))
                .send()
                .await
                .map_err(|e| LayerStorageError::S3(e.to_string()))?;

            let etag = response
                .e_tag()
                .ok_or_else(|| LayerStorageError::S3("No ETag returned for part".to_string()))?
                .to_string();

            debug!("Uploaded part {}/{}: {}", part_number, total_parts, etag);
            completed.push((part_number, etag));
        }

        Ok(completed)
    }

    /// Resume an interrupted upload
    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
    #[instrument(skip(self, pending))]
    async fn resume_upload(
        &self,
        container_id: &ContainerLayerId,
        pending: PendingUpload,
    ) -> Result<Option<LayerSnapshot>> {
        let missing = pending.missing_parts();

        if missing.is_empty() {
            // All parts uploaded, just need to complete
            info!("All parts uploaded, completing multipart upload");
        } else {
            info!("Resuming upload, {} parts remaining", missing.len());

            // Verify local file still exists
            if !pending.local_tarball_path.exists() {
                warn!("Local tarball missing, aborting upload and starting fresh");
                self.abort_upload(&pending).await?;

                let key = container_id.to_key();
                let mut states = self.states.write().await;
                if let Some(state) = states.get_mut(&key) {
                    state.pending_upload = None;
                    self.save_state(state).await?;
                }

                return Err(LayerStorageError::UploadInterrupted(
                    "Local tarball missing".to_string(),
                ));
            }

            // Upload missing parts, using the same short-read-safe pattern
            // as `upload_parts`
            for part_number in missing {
                let offset = (u64::from(part_number) - 1) * pending.part_size;

                let mut file = File::open(&pending.local_tarball_path).await?;
                file.seek(std::io::SeekFrom::Start(offset)).await?;
                let mut buffer = Vec::with_capacity(pending.part_size as usize);
                file.take(pending.part_size).read_to_end(&mut buffer).await?;

                let response = self
                    .s3_client
                    .upload_part()
                    .bucket(&self.config.bucket)
                    .key(&pending.object_key)
                    .upload_id(&pending.upload_id)
                    .part_number(part_number as i32)
                    .body(ByteStream::from(buffer))
                    .send()
                    .await
                    .map_err(|e| LayerStorageError::S3(e.to_string()))?;

                let etag = response
                    .e_tag()
                    .ok_or_else(|| LayerStorageError::S3("No ETag returned".to_string()))?
                    .to_string();

                debug!("Uploaded part {}: {}", part_number, etag);
            }
        }

        // List parts to recover all ETags, including those uploaded before the
        // crash. NOTE: `list_parts` returns at most 1,000 parts per page;
        // uploads with more parts would need pagination via `part_number_marker`.
        let parts_response = self
            .s3_client
            .list_parts()
            .bucket(&self.config.bucket)
            .key(&pending.object_key)
            .upload_id(&pending.upload_id)
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        let completed_parts: Vec<S3CompletedPart> = parts_response
            .parts()
            .iter()
            .map(|p| {
                S3CompletedPart::builder()
                    .part_number(p.part_number().unwrap_or(0))
                    .e_tag(p.e_tag().unwrap_or_default())
                    .build()
            })
            .collect();

        // Complete multipart upload
        let completed_upload = CompletedMultipartUpload::builder()
            .set_parts(Some(completed_parts))
            .build();

        self.s3_client
            .complete_multipart_upload()
            .bucket(&self.config.bucket)
            .key(&pending.object_key)
            .upload_id(&pending.upload_id)
            .multipart_upload(completed_upload)
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        // Update state
        let key = container_id.to_key();
        {
            let mut states = self.states.write().await;
            if let Some(state) = states.get_mut(&key) {
                state.local_digest = Some(pending.digest.clone());
                state.remote_digest = Some(pending.digest.clone());
                state.last_sync = Some(chrono::Utc::now());
                state.pending_upload = None;
                self.save_state(state).await?;
            }
        }

        // Clean up staging file
        let _ = tokio::fs::remove_file(&pending.local_tarball_path).await;

        info!("Upload resumed and completed successfully");

        // Return snapshot metadata (fetch from S3)
        self.get_snapshot_metadata(&pending.digest).await.map(Some)
    }

    /// Abort a multipart upload
    async fn abort_upload(&self, pending: &PendingUpload) -> Result<()> {
        self.s3_client
            .abort_multipart_upload()
            .bucket(&self.config.bucket)
            .key(&pending.object_key)
            .upload_id(&pending.upload_id)
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        Ok(())
    }

    /// Download and restore a layer from S3
    ///
    /// # Errors
    ///
    /// Returns an error if the remote layer is not found, download fails,
    /// digest verification fails, or extraction fails.
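    ///
    /// # Example
    ///
    /// Sketch (the target path is a placeholder):
    ///
    /// ```ignore
    /// # async fn demo(mgr: &LayerSyncManager, id: &ContainerLayerId) -> Result<()> {
    /// let snapshot = mgr.restore_layer(id, "/var/lib/containers/upper").await?;
    /// println!("restored layer {}", snapshot.digest);
    /// # Ok(())
    /// # }
    /// ```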
    #[instrument(skip(self, target_path))]
    pub async fn restore_layer(
        &self,
        container_id: &ContainerLayerId,
        target_path: impl AsRef<Path>,
    ) -> Result<LayerSnapshot> {
        let target_path = target_path.as_ref();
        let key = container_id.to_key();

        // Get remote digest
        let remote_digest = {
            let states = self.states.read().await;
            states
                .get(&key)
                .and_then(|s| s.remote_digest.clone())
                .ok_or_else(|| LayerStorageError::NotFound(format!("No remote layer for {key}")))?
        };

        info!("Restoring layer {} from S3", remote_digest);

        // Download tarball
        let tarball_path = self
            .config
            .staging_dir
            .join(format!("{remote_digest}.tar.zst"));

        let object_key = self.config.object_key(&remote_digest);
        let response = self
            .s3_client
            .get_object()
            .bucket(&self.config.bucket)
            .key(&object_key)
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        // Stream to file
        let mut file = tokio::fs::File::create(&tarball_path).await?;
        let mut stream = response.body.into_async_read();
        tokio::io::copy(&mut stream, &mut file).await?;

        // Get snapshot metadata
        let snapshot = self.get_snapshot_metadata(&remote_digest).await?;

        // Extract off the async runtime; `extract_snapshot` verifies the
        // expected digest before restoring
        tokio::task::spawn_blocking({
            let tarball = tarball_path.clone();
            let target = target_path.to_path_buf();
            let digest = remote_digest.clone();
            move || extract_snapshot(tarball, target, Some(&digest))
        })
        .await
        .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;

        // Update local digest
        {
            let mut states = self.states.write().await;
            if let Some(state) = states.get_mut(&key) {
                state.local_digest = Some(remote_digest);
                self.save_state(state).await?;
            }
        }

        // Clean up
        let _ = tokio::fs::remove_file(&tarball_path).await;

        info!("Layer restored successfully");
        Ok(snapshot)
    }

    /// Get snapshot metadata from S3
    async fn get_snapshot_metadata(&self, digest: &str) -> Result<LayerSnapshot> {
        let metadata_key = self.config.metadata_key(digest);

        let response = self
            .s3_client
            .get_object()
            .bucket(&self.config.bucket)
            .key(&metadata_key)
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        let bytes = response
            .body
            .collect()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?
            .into_bytes();

        serde_json::from_slice(&bytes).map_err(Into::into)
    }

    /// List all containers with sync state
    pub async fn list_containers(&self) -> Vec<ContainerLayerId> {
        let states = self.states.read().await;
        states.values().map(|s| s.container_id.clone()).collect()
    }

    /// Get sync state for a container
    pub async fn get_sync_state(&self, container_id: &ContainerLayerId) -> Option<SyncState> {
        let states = self.states.read().await;
        states.get(&container_id.to_key()).cloned()
    }
}