Skip to main content

zlayer_storage/
sync.rs

1//! Sync manager for coordinating local/remote layer state
2//!
3//! Handles crash-tolerant uploads with resume capability using S3 multipart uploads.
4
5use crate::config::LayerStorageConfig;
6use crate::error::{LayerStorageError, Result};
7use crate::snapshot::{calculate_directory_digest, create_snapshot, extract_snapshot};
8use crate::types::{ContainerLayerId, LayerSnapshot, PendingUpload, SyncState};
9use aws_sdk_s3::primitives::ByteStream;
10use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart as S3CompletedPart};
11use aws_sdk_s3::Client as S3Client;
12use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
13use sqlx::SqlitePool;
14use std::collections::HashMap;
15use std::path::Path;
16use std::str::FromStr;
17use std::sync::Arc;
18use tokio::fs::File;
19use tokio::io::AsyncReadExt;
20use tokio::sync::RwLock;
21use tracing::{debug, info, instrument, warn};
22
23/// Manages layer synchronization between local storage and S3
24pub struct LayerSyncManager {
25    config: LayerStorageConfig,
26    s3_client: S3Client,
27    pool: SqlitePool,
28    /// In-memory cache of sync states
29    states: Arc<RwLock<HashMap<String, SyncState>>>,
30}
31
32impl LayerSyncManager {
33    /// Create a new sync manager
34    ///
35    /// # Errors
36    ///
37    /// Returns an error if directories cannot be created, the database cannot
38    /// be opened, or the AWS SDK fails to initialize.
39    pub async fn new(config: LayerStorageConfig) -> Result<Self> {
40        // Ensure directories exist
41        tokio::fs::create_dir_all(&config.staging_dir).await?;
42        if let Some(parent) = config.state_db_path.parent() {
43            tokio::fs::create_dir_all(parent).await?;
44        }
45
46        // Initialize AWS SDK
47        let mut aws_config_builder = aws_config::from_env();
48
49        if let Some(region) = &config.region {
50            aws_config_builder =
51                aws_config_builder.region(aws_sdk_s3::config::Region::new(region.clone()));
52        } else if config.endpoint_url.is_some() {
53            // Custom / non-AWS endpoint: pin a static region so the SDK never consults the
54            // IMDS region provider (which blocks ~5s+ off-AWS). See replicator::mod for context.
55            aws_config_builder =
56                aws_config_builder.region(aws_sdk_s3::config::Region::new("us-east-1"));
57        }
58
59        let aws_config = aws_config_builder.load().await;
60
61        let s3_config = if let Some(endpoint) = &config.endpoint_url {
62            aws_sdk_s3::config::Builder::from(&aws_config)
63                .endpoint_url(endpoint)
64                .force_path_style(true)
65                .build()
66        } else {
67            aws_sdk_s3::config::Builder::from(&aws_config).build()
68        };
69
70        let s3_client = S3Client::from_conf(s3_config);
71
72        // Open/create SQLite database
73        let db_url = format!("sqlite:{}?mode=rwc", config.state_db_path.display());
74        let connect_options = SqliteConnectOptions::from_str(&db_url)
75            .map_err(|e| LayerStorageError::Database(e.to_string()))?
76            .create_if_missing(true);
77
78        let pool = SqlitePoolOptions::new()
79            .max_connections(5)
80            .connect_with(connect_options)
81            .await
82            .map_err(|e| LayerStorageError::Database(e.to_string()))?;
83
84        // Enable WAL mode for better concurrent access
85        sqlx::query("PRAGMA journal_mode=WAL")
86            .execute(&pool)
87            .await
88            .map_err(|e| LayerStorageError::Database(e.to_string()))?;
89
90        // Create sync_state table
91        sqlx::query(
92            r"
93            CREATE TABLE IF NOT EXISTS sync_state (
94                container_key TEXT PRIMARY KEY NOT NULL,
95                state_json TEXT NOT NULL,
96                updated_at TEXT DEFAULT CURRENT_TIMESTAMP
97            )
98            ",
99        )
100        .execute(&pool)
101        .await
102        .map_err(|e| LayerStorageError::Database(e.to_string()))?;
103
104        // Load existing states into memory
105        let states = Arc::new(RwLock::new(Self::load_all_states(&pool).await?));
106
107        Ok(Self {
108            config,
109            s3_client,
110            pool,
111            states,
112        })
113    }
114
115    async fn load_all_states(pool: &SqlitePool) -> Result<HashMap<String, SyncState>> {
116        let rows: Vec<(String, String)> =
117            sqlx::query_as("SELECT container_key, state_json FROM sync_state")
118                .fetch_all(pool)
119                .await
120                .map_err(|e| LayerStorageError::Database(e.to_string()))?;
121
122        let mut states = HashMap::new();
123        for (key, json) in rows {
124            let state: SyncState = serde_json::from_str(&json)?;
125            states.insert(key, state);
126        }
127
128        Ok(states)
129    }
130
131    async fn save_state(&self, state: &SyncState) -> Result<()> {
132        let key = state.container_id.to_key();
133        let value = serde_json::to_string(state)?;
134
135        sqlx::query(
136            r"
137            INSERT OR REPLACE INTO sync_state (container_key, state_json, updated_at)
138            VALUES (?, ?, CURRENT_TIMESTAMP)
139            ",
140        )
141        .bind(&key)
142        .bind(&value)
143        .execute(&self.pool)
144        .await
145        .map_err(|e| LayerStorageError::Database(e.to_string()))?;
146
147        Ok(())
148    }
149
150    /// Register a container for layer sync tracking
151    ///
152    /// # Errors
153    ///
154    /// Returns an error if persisting the sync state to the database fails.
155    #[instrument(skip(self))]
156    pub async fn register_container(&self, container_id: ContainerLayerId) -> Result<()> {
157        let key = container_id.to_key();
158        let mut states = self.states.write().await;
159
160        if let std::collections::hash_map::Entry::Vacant(e) = states.entry(key) {
161            let state = SyncState::new(container_id);
162            self.save_state(&state).await?;
163            e.insert(state);
164            info!("Registered new container for layer sync");
165        }
166
167        Ok(())
168    }
169
170    /// Check if a container's layer has changed and needs sync
171    ///
172    /// # Errors
173    ///
174    /// Returns an error if the container is not registered or if calculating the
175    /// directory digest fails.
176    #[instrument(skip(self, upper_layer_path))]
177    pub async fn check_for_changes(
178        &self,
179        container_id: &ContainerLayerId,
180        upper_layer_path: impl AsRef<Path>,
181    ) -> Result<bool> {
182        let key = container_id.to_key();
183        let states = self.states.read().await;
184
185        let state = states
186            .get(&key)
187            .ok_or_else(|| LayerStorageError::NotFound(key.clone()))?;
188
189        // Calculate current digest
190        let current_digest = calculate_directory_digest(upper_layer_path)?;
191
192        // Compare with stored digest
193        Ok(state.local_digest.as_ref() != Some(&current_digest))
194    }
195
196    /// Create a snapshot and upload it to S3
197    ///
198    /// # Errors
199    ///
200    /// Returns an error if snapshot creation, S3 upload, or state persistence fails.
201    #[instrument(skip(self, upper_layer_path), fields(container = %container_id))]
202    pub async fn sync_layer(
203        &self,
204        container_id: &ContainerLayerId,
205        upper_layer_path: impl AsRef<Path>,
206    ) -> Result<Option<LayerSnapshot>> {
207        let upper_layer_path = upper_layer_path.as_ref();
208        let key = container_id.to_key();
209
210        // Check for pending upload to resume
211        {
212            let states = self.states.read().await;
213            if let Some(state) = states.get(&key) {
214                if let Some(pending) = &state.pending_upload {
215                    info!("Found pending upload, attempting to resume");
216                    return self.resume_upload(container_id, pending.clone()).await;
217                }
218            }
219        }
220
221        // Calculate current digest
222        let current_digest = calculate_directory_digest(upper_layer_path)?;
223
224        // Check if sync needed
225        {
226            let states = self.states.read().await;
227            if let Some(state) = states.get(&key) {
228                if state.remote_digest.as_ref() == Some(&current_digest) {
229                    debug!("Layer already synced, no changes");
230                    return Ok(None);
231                }
232            }
233        }
234
235        // Create snapshot
236        let tarball_path = self
237            .config
238            .staging_dir
239            .join(format!("{current_digest}.tar.zst"));
240
241        let snapshot = tokio::task::spawn_blocking({
242            let source = upper_layer_path.to_path_buf();
243            let output = tarball_path.clone();
244            let level = self.config.compression_level;
245            move || create_snapshot(source, output, level)
246        })
247        .await
248        .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
249
250        // Upload to S3
251        self.upload_snapshot(container_id, &tarball_path, &snapshot)
252            .await?;
253
254        // Update state
255        {
256            let mut states = self.states.write().await;
257            if let Some(state) = states.get_mut(&key) {
258                state.local_digest = Some(snapshot.digest.clone());
259                state.remote_digest = Some(snapshot.digest.clone());
260                state.last_sync = Some(chrono::Utc::now());
261                state.pending_upload = None;
262                self.save_state(state).await?;
263            }
264        }
265
266        // Clean up staging file
267        let _ = tokio::fs::remove_file(&tarball_path).await;
268
269        Ok(Some(snapshot))
270    }
271
272    /// Upload a snapshot to S3 using multipart upload
273    #[allow(clippy::cast_possible_wrap)]
274    #[instrument(skip(self, tarball_path, snapshot))]
275    async fn upload_snapshot(
276        &self,
277        container_id: &ContainerLayerId,
278        tarball_path: &Path,
279        snapshot: &LayerSnapshot,
280    ) -> Result<()> {
281        let object_key = self.config.object_key(&snapshot.digest);
282        let file_size = tokio::fs::metadata(tarball_path).await?.len();
283        let part_size = self.config.part_size_bytes;
284        #[allow(clippy::cast_possible_truncation)]
285        let total_parts = file_size.div_ceil(part_size) as u32;
286
287        info!(
288            "Uploading {} ({} bytes) in {} parts",
289            object_key, file_size, total_parts
290        );
291
292        // Initiate multipart upload
293        let create_response = self
294            .s3_client
295            .create_multipart_upload()
296            .bucket(&self.config.bucket)
297            .key(&object_key)
298            .content_type("application/zstd")
299            .send()
300            .await
301            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
302
303        let upload_id = create_response
304            .upload_id()
305            .ok_or_else(|| LayerStorageError::S3("No upload ID returned".to_string()))?
306            .to_string();
307
308        // Record pending upload for crash recovery
309        let pending = PendingUpload {
310            upload_id: upload_id.clone(),
311            object_key: object_key.clone(),
312            total_parts,
313            completed_parts: HashMap::new(),
314            part_size,
315            local_tarball_path: tarball_path.to_path_buf(),
316            started_at: chrono::Utc::now(),
317            digest: snapshot.digest.clone(),
318        };
319
320        {
321            let key = container_id.to_key();
322            let mut states = self.states.write().await;
323            if let Some(state) = states.get_mut(&key) {
324                state.pending_upload = Some(pending.clone());
325                self.save_state(state).await?;
326            }
327        }
328
329        // Upload parts
330        let completed_parts = self
331            .upload_parts(
332                tarball_path,
333                &upload_id,
334                &object_key,
335                total_parts,
336                part_size,
337            )
338            .await?;
339
340        // Complete multipart upload
341        let completed_upload = CompletedMultipartUpload::builder()
342            .set_parts(Some(
343                completed_parts
344                    .into_iter()
345                    .map(|(num, etag)| {
346                        S3CompletedPart::builder()
347                            .part_number(num as i32)
348                            .e_tag(etag)
349                            .build()
350                    })
351                    .collect(),
352            ))
353            .build();
354
355        self.s3_client
356            .complete_multipart_upload()
357            .bucket(&self.config.bucket)
358            .key(&object_key)
359            .upload_id(&upload_id)
360            .multipart_upload(completed_upload)
361            .send()
362            .await
363            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
364
365        // Upload metadata
366        let metadata_key = self.config.metadata_key(&snapshot.digest);
367        let metadata_json = serde_json::to_vec(snapshot)?;
368
369        self.s3_client
370            .put_object()
371            .bucket(&self.config.bucket)
372            .key(&metadata_key)
373            .body(ByteStream::from(metadata_json))
374            .content_type("application/json")
375            .send()
376            .await
377            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
378
379        info!("Upload complete: {}", object_key);
380        Ok(())
381    }
382
383    /// Upload individual parts with progress tracking
384    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
385    async fn upload_parts(
386        &self,
387        tarball_path: &Path,
388        upload_id: &str,
389        object_key: &str,
390        total_parts: u32,
391        part_size: u64,
392    ) -> Result<Vec<(u32, String)>> {
393        let mut completed = Vec::new();
394
395        for part_number in 1..=total_parts {
396            let offset = (u64::from(part_number) - 1) * part_size;
397
398            // Read part data
399            let mut file = File::open(tarball_path).await?;
400            file.seek(std::io::SeekFrom::Start(offset)).await?;
401
402            let mut buffer = vec![0u8; part_size as usize];
403            let bytes_read = file.read(&mut buffer).await?;
404            buffer.truncate(bytes_read);
405
406            // Upload part
407            let response = self
408                .s3_client
409                .upload_part()
410                .bucket(&self.config.bucket)
411                .key(object_key)
412                .upload_id(upload_id)
413                .part_number(part_number as i32)
414                .body(ByteStream::from(buffer))
415                .send()
416                .await
417                .map_err(|e| LayerStorageError::S3(e.to_string()))?;
418
419            let etag = response
420                .e_tag()
421                .ok_or_else(|| LayerStorageError::S3("No ETag returned for part".to_string()))?
422                .to_string();
423
424            debug!("Uploaded part {}/{}: {}", part_number, total_parts, etag);
425            completed.push((part_number, etag));
426        }
427
428        Ok(completed)
429    }
430
431    /// Resume an interrupted upload
432    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
433    #[instrument(skip(self, pending))]
434    async fn resume_upload(
435        &self,
436        container_id: &ContainerLayerId,
437        pending: PendingUpload,
438    ) -> Result<Option<LayerSnapshot>> {
439        let missing = pending.missing_parts();
440
441        if missing.is_empty() {
442            // All parts uploaded, just need to complete
443            info!("All parts uploaded, completing multipart upload");
444        } else {
445            info!("Resuming upload, {} parts remaining", missing.len());
446
447            // Verify local file still exists
448            if !pending.local_tarball_path.exists() {
449                warn!("Local tarball missing, aborting upload and starting fresh");
450                self.abort_upload(&pending).await?;
451
452                let key = container_id.to_key();
453                let mut states = self.states.write().await;
454                if let Some(state) = states.get_mut(&key) {
455                    state.pending_upload = None;
456                    self.save_state(state).await?;
457                }
458
459                return Err(LayerStorageError::UploadInterrupted(
460                    "Local tarball missing".to_string(),
461                ));
462            }
463
464            // Upload missing parts
465            for part_number in missing {
466                let offset = (u64::from(part_number) - 1) * pending.part_size;
467
468                let mut file = File::open(&pending.local_tarball_path).await?;
469                file.seek(std::io::SeekFrom::Start(offset)).await?;
470
471                let mut buffer = vec![0u8; pending.part_size as usize];
472                let bytes_read = file.read(&mut buffer).await?;
473                buffer.truncate(bytes_read);
474
475                let response = self
476                    .s3_client
477                    .upload_part()
478                    .bucket(&self.config.bucket)
479                    .key(&pending.object_key)
480                    .upload_id(&pending.upload_id)
481                    .part_number(part_number as i32)
482                    .body(ByteStream::from(buffer))
483                    .send()
484                    .await
485                    .map_err(|e| LayerStorageError::S3(e.to_string()))?;
486
487                let etag = response
488                    .e_tag()
489                    .ok_or_else(|| LayerStorageError::S3("No ETag returned".to_string()))?
490                    .to_string();
491
492                debug!("Uploaded part {}: {}", part_number, etag);
493            }
494        }
495
496        // List parts to get all ETags
497        let parts_response = self
498            .s3_client
499            .list_parts()
500            .bucket(&self.config.bucket)
501            .key(&pending.object_key)
502            .upload_id(&pending.upload_id)
503            .send()
504            .await
505            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
506
507        let completed_parts: Vec<S3CompletedPart> = parts_response
508            .parts()
509            .iter()
510            .map(|p| {
511                S3CompletedPart::builder()
512                    .part_number(p.part_number().unwrap_or(0))
513                    .e_tag(p.e_tag().unwrap_or_default())
514                    .build()
515            })
516            .collect();
517
518        // Complete multipart upload
519        let completed_upload = CompletedMultipartUpload::builder()
520            .set_parts(Some(completed_parts))
521            .build();
522
523        self.s3_client
524            .complete_multipart_upload()
525            .bucket(&self.config.bucket)
526            .key(&pending.object_key)
527            .upload_id(&pending.upload_id)
528            .multipart_upload(completed_upload)
529            .send()
530            .await
531            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
532
533        // Update state
534        let key = container_id.to_key();
535        {
536            let mut states = self.states.write().await;
537            if let Some(state) = states.get_mut(&key) {
538                state.local_digest = Some(pending.digest.clone());
539                state.remote_digest = Some(pending.digest.clone());
540                state.last_sync = Some(chrono::Utc::now());
541                state.pending_upload = None;
542                self.save_state(state).await?;
543            }
544        }
545
546        // Clean up staging file
547        let _ = tokio::fs::remove_file(&pending.local_tarball_path).await;
548
549        info!("Upload resumed and completed successfully");
550
551        // Return snapshot metadata (fetch from S3)
552        self.get_snapshot_metadata(&pending.digest).await.map(Some)
553    }
554
555    /// Abort a multipart upload
556    async fn abort_upload(&self, pending: &PendingUpload) -> Result<()> {
557        self.s3_client
558            .abort_multipart_upload()
559            .bucket(&self.config.bucket)
560            .key(&pending.object_key)
561            .upload_id(&pending.upload_id)
562            .send()
563            .await
564            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
565
566        Ok(())
567    }
568
569    /// Download and restore a layer from S3
570    ///
571    /// # Errors
572    ///
573    /// Returns an error if the remote layer is not found, download fails,
574    /// digest verification fails, or extraction fails.
575    #[instrument(skip(self, target_path))]
576    pub async fn restore_layer(
577        &self,
578        container_id: &ContainerLayerId,
579        target_path: impl AsRef<Path>,
580    ) -> Result<LayerSnapshot> {
581        let target_path = target_path.as_ref();
582        let key = container_id.to_key();
583
584        // Get remote digest
585        let remote_digest = {
586            let states = self.states.read().await;
587            states
588                .get(&key)
589                .and_then(|s| s.remote_digest.clone())
590                .ok_or_else(|| LayerStorageError::NotFound(format!("No remote layer for {key}")))?
591        };
592
593        info!("Restoring layer {} from S3", remote_digest);
594
595        // Download tarball
596        let tarball_path = self
597            .config
598            .staging_dir
599            .join(format!("{remote_digest}.tar.zst"));
600
601        let object_key = self.config.object_key(&remote_digest);
602        let response = self
603            .s3_client
604            .get_object()
605            .bucket(&self.config.bucket)
606            .key(&object_key)
607            .send()
608            .await
609            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
610
611        // Stream to file
612        let mut file = tokio::fs::File::create(&tarball_path).await?;
613        let mut stream = response.body.into_async_read();
614        tokio::io::copy(&mut stream, &mut file).await?;
615
616        // Get snapshot metadata
617        let snapshot = self.get_snapshot_metadata(&remote_digest).await?;
618
619        // Extract
620        tokio::task::spawn_blocking({
621            let tarball = tarball_path.clone();
622            let target = target_path.to_path_buf();
623            let digest = remote_digest.clone();
624            move || extract_snapshot(tarball, target, Some(&digest))
625        })
626        .await
627        .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
628
629        // Update local digest
630        {
631            let mut states = self.states.write().await;
632            if let Some(state) = states.get_mut(&key) {
633                state.local_digest = Some(remote_digest);
634                self.save_state(state).await?;
635            }
636        }
637
638        // Clean up
639        let _ = tokio::fs::remove_file(&tarball_path).await;
640
641        info!("Layer restored successfully");
642        Ok(snapshot)
643    }
644
645    /// Get snapshot metadata from S3
646    async fn get_snapshot_metadata(&self, digest: &str) -> Result<LayerSnapshot> {
647        let metadata_key = self.config.metadata_key(digest);
648
649        let response = self
650            .s3_client
651            .get_object()
652            .bucket(&self.config.bucket)
653            .key(&metadata_key)
654            .send()
655            .await
656            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
657
658        let bytes = response
659            .body
660            .collect()
661            .await
662            .map_err(|e| LayerStorageError::S3(e.to_string()))?
663            .into_bytes();
664
665        serde_json::from_slice(&bytes).map_err(Into::into)
666    }
667
668    /// List all containers with sync state
669    pub async fn list_containers(&self) -> Vec<ContainerLayerId> {
670        let states = self.states.read().await;
671        states.values().map(|s| s.container_id.clone()).collect()
672    }
673
674    /// Get sync state for a container
675    pub async fn get_sync_state(&self, container_id: &ContainerLayerId) -> Option<SyncState> {
676        let states = self.states.read().await;
677        states.get(&container_id.to_key()).cloned()
678    }
679}
680
681// Need to use tokio seek
682use tokio::io::AsyncSeekExt;