// zlayer_storage/sync.rs

//! Sync manager for coordinating local/remote layer state.
//!
//! Handles crash-tolerant uploads with resume capability using S3 multipart uploads.

5use crate::config::LayerStorageConfig;
6use crate::error::{LayerStorageError, Result};
7use crate::snapshot::{calculate_directory_digest, create_snapshot, extract_snapshot};
8use crate::types::{ContainerLayerId, LayerSnapshot, PendingUpload, SyncState};
9use aws_sdk_s3::primitives::ByteStream;
10use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart as S3CompletedPart};
11use aws_sdk_s3::Client as S3Client;
12use redb::{Database, ReadableTable, TableDefinition};
13use std::collections::HashMap;
14use std::path::Path;
15use std::sync::Arc;
16use tokio::fs::File;
17use tokio::io::AsyncReadExt;
18use tokio::sync::RwLock;
19use tracing::{debug, info, instrument, warn};
20
/// redb table mapping container key -> JSON-serialized `SyncState`.
const SYNC_STATE_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("sync_state");
/// redb table for snapshot metadata. Opened at startup; not otherwise
/// read or written in this file — presumably used elsewhere (TODO confirm).
const SNAPSHOT_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("snapshot_meta");
23
/// Manages layer synchronization between local storage and S3
pub struct LayerSyncManager {
    /// Storage configuration (bucket, staging dir, part size, redb path, ...).
    config: LayerStorageConfig,
    /// S3 client used for multipart uploads, downloads, and metadata objects.
    s3_client: S3Client,
    /// Persistent state database; survives process restarts for crash recovery.
    db: Database,
    /// In-memory cache of sync states, keyed by container key
    /// (`ContainerLayerId::to_key()`); kept in step with `db` by `save_state`.
    states: Arc<RwLock<HashMap<String, SyncState>>>,
}
32
33impl LayerSyncManager {
34    /// Create a new sync manager
35    pub async fn new(config: LayerStorageConfig) -> Result<Self> {
36        // Ensure directories exist
37        tokio::fs::create_dir_all(&config.staging_dir).await?;
38        if let Some(parent) = config.state_db_path.parent() {
39            tokio::fs::create_dir_all(parent).await?;
40        }
41
42        // Initialize AWS SDK
43        let mut aws_config_builder = aws_config::from_env();
44
45        if let Some(region) = &config.region {
46            aws_config_builder =
47                aws_config_builder.region(aws_sdk_s3::config::Region::new(region.clone()));
48        }
49
50        let aws_config = aws_config_builder.load().await;
51
52        let s3_config = if let Some(endpoint) = &config.endpoint_url {
53            aws_sdk_s3::config::Builder::from(&aws_config)
54                .endpoint_url(endpoint)
55                .force_path_style(true)
56                .build()
57        } else {
58            aws_sdk_s3::config::Builder::from(&aws_config).build()
59        };
60
61        let s3_client = S3Client::from_conf(s3_config);
62
63        // Open/create database
64        let db = Database::create(&config.state_db_path)
65            .map_err(|e| LayerStorageError::Database(e.to_string()))?;
66
67        // Initialize tables
68        {
69            let write_txn = db
70                .begin_write()
71                .map_err(|e| LayerStorageError::Database(e.to_string()))?;
72            {
73                let _ = write_txn.open_table(SYNC_STATE_TABLE);
74                let _ = write_txn.open_table(SNAPSHOT_META_TABLE);
75            }
76            write_txn
77                .commit()
78                .map_err(|e| LayerStorageError::Database(e.to_string()))?;
79        }
80
81        // Load existing states into memory
82        let states = Arc::new(RwLock::new(Self::load_all_states(&db)?));
83
84        Ok(Self {
85            config,
86            s3_client,
87            db,
88            states,
89        })
90    }
91
92    fn load_all_states(db: &Database) -> Result<HashMap<String, SyncState>> {
93        let read_txn = db
94            .begin_read()
95            .map_err(|e| LayerStorageError::Database(e.to_string()))?;
96        let table = read_txn
97            .open_table(SYNC_STATE_TABLE)
98            .map_err(|e| LayerStorageError::Database(e.to_string()))?;
99
100        let mut states = HashMap::new();
101        let iter = table
102            .iter()
103            .map_err(|e| LayerStorageError::Database(e.to_string()))?;
104
105        for entry in iter {
106            let (key, value) = entry.map_err(|e| LayerStorageError::Database(e.to_string()))?;
107            let state: SyncState = serde_json::from_slice(value.value())?;
108            states.insert(key.value().to_string(), state);
109        }
110
111        Ok(states)
112    }
113
114    fn save_state(&self, state: &SyncState) -> Result<()> {
115        let key = state.container_id.to_key();
116        let value = serde_json::to_vec(state)?;
117
118        let write_txn = self
119            .db
120            .begin_write()
121            .map_err(|e| LayerStorageError::Database(e.to_string()))?;
122        {
123            let mut table = write_txn
124                .open_table(SYNC_STATE_TABLE)
125                .map_err(|e| LayerStorageError::Database(e.to_string()))?;
126            table
127                .insert(key.as_str(), value.as_slice())
128                .map_err(|e| LayerStorageError::Database(e.to_string()))?;
129        }
130        write_txn
131            .commit()
132            .map_err(|e| LayerStorageError::Database(e.to_string()))?;
133
134        Ok(())
135    }
136
137    /// Register a container for layer sync tracking
138    #[instrument(skip(self))]
139    pub async fn register_container(&self, container_id: ContainerLayerId) -> Result<()> {
140        let key = container_id.to_key();
141        let mut states = self.states.write().await;
142
143        if let std::collections::hash_map::Entry::Vacant(e) = states.entry(key) {
144            let state = SyncState::new(container_id);
145            self.save_state(&state)?;
146            e.insert(state);
147            info!("Registered new container for layer sync");
148        }
149
150        Ok(())
151    }
152
153    /// Check if a container's layer has changed and needs sync
154    #[instrument(skip(self, upper_layer_path))]
155    pub async fn check_for_changes(
156        &self,
157        container_id: &ContainerLayerId,
158        upper_layer_path: impl AsRef<Path>,
159    ) -> Result<bool> {
160        let key = container_id.to_key();
161        let states = self.states.read().await;
162
163        let state = states
164            .get(&key)
165            .ok_or_else(|| LayerStorageError::NotFound(key.clone()))?;
166
167        // Calculate current digest
168        let current_digest = calculate_directory_digest(upper_layer_path)?;
169
170        // Compare with stored digest
171        Ok(state.local_digest.as_ref() != Some(&current_digest))
172    }
173
174    /// Create a snapshot and upload it to S3
175    #[instrument(skip(self, upper_layer_path), fields(container = %container_id))]
176    pub async fn sync_layer(
177        &self,
178        container_id: &ContainerLayerId,
179        upper_layer_path: impl AsRef<Path>,
180    ) -> Result<Option<LayerSnapshot>> {
181        let upper_layer_path = upper_layer_path.as_ref();
182        let key = container_id.to_key();
183
184        // Check for pending upload to resume
185        {
186            let states = self.states.read().await;
187            if let Some(state) = states.get(&key) {
188                if let Some(pending) = &state.pending_upload {
189                    info!("Found pending upload, attempting to resume");
190                    return self.resume_upload(container_id, pending.clone()).await;
191                }
192            }
193        }
194
195        // Calculate current digest
196        let current_digest = calculate_directory_digest(upper_layer_path)?;
197
198        // Check if sync needed
199        {
200            let states = self.states.read().await;
201            if let Some(state) = states.get(&key) {
202                if state.remote_digest.as_ref() == Some(&current_digest) {
203                    debug!("Layer already synced, no changes");
204                    return Ok(None);
205                }
206            }
207        }
208
209        // Create snapshot
210        let tarball_path = self
211            .config
212            .staging_dir
213            .join(format!("{}.tar.zst", current_digest));
214
215        let snapshot = tokio::task::spawn_blocking({
216            let source = upper_layer_path.to_path_buf();
217            let output = tarball_path.clone();
218            let level = self.config.compression_level;
219            move || create_snapshot(source, output, level)
220        })
221        .await
222        .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
223
224        // Upload to S3
225        self.upload_snapshot(container_id, &tarball_path, &snapshot)
226            .await?;
227
228        // Update state
229        {
230            let mut states = self.states.write().await;
231            if let Some(state) = states.get_mut(&key) {
232                state.local_digest = Some(snapshot.digest.clone());
233                state.remote_digest = Some(snapshot.digest.clone());
234                state.last_sync = Some(chrono::Utc::now());
235                state.pending_upload = None;
236                self.save_state(state)?;
237            }
238        }
239
240        // Clean up staging file
241        let _ = tokio::fs::remove_file(&tarball_path).await;
242
243        Ok(Some(snapshot))
244    }
245
246    /// Upload a snapshot to S3 using multipart upload
247    #[instrument(skip(self, tarball_path, snapshot))]
248    async fn upload_snapshot(
249        &self,
250        container_id: &ContainerLayerId,
251        tarball_path: &Path,
252        snapshot: &LayerSnapshot,
253    ) -> Result<()> {
254        let object_key = self.config.object_key(&snapshot.digest);
255        let file_size = tokio::fs::metadata(tarball_path).await?.len();
256        let part_size = self.config.part_size_bytes;
257        let total_parts = file_size.div_ceil(part_size) as u32;
258
259        info!(
260            "Uploading {} ({} bytes) in {} parts",
261            object_key, file_size, total_parts
262        );
263
264        // Initiate multipart upload
265        let create_response = self
266            .s3_client
267            .create_multipart_upload()
268            .bucket(&self.config.bucket)
269            .key(&object_key)
270            .content_type("application/zstd")
271            .send()
272            .await
273            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
274
275        let upload_id = create_response
276            .upload_id()
277            .ok_or_else(|| LayerStorageError::S3("No upload ID returned".to_string()))?
278            .to_string();
279
280        // Record pending upload for crash recovery
281        let pending = PendingUpload {
282            upload_id: upload_id.clone(),
283            object_key: object_key.clone(),
284            total_parts,
285            completed_parts: HashMap::new(),
286            part_size,
287            local_tarball_path: tarball_path.to_path_buf(),
288            started_at: chrono::Utc::now(),
289            digest: snapshot.digest.clone(),
290        };
291
292        {
293            let key = container_id.to_key();
294            let mut states = self.states.write().await;
295            if let Some(state) = states.get_mut(&key) {
296                state.pending_upload = Some(pending.clone());
297                self.save_state(state)?;
298            }
299        }
300
301        // Upload parts
302        let completed_parts = self
303            .upload_parts(
304                tarball_path,
305                &upload_id,
306                &object_key,
307                total_parts,
308                part_size,
309            )
310            .await?;
311
312        // Complete multipart upload
313        let completed_upload = CompletedMultipartUpload::builder()
314            .set_parts(Some(
315                completed_parts
316                    .into_iter()
317                    .map(|(num, etag)| {
318                        S3CompletedPart::builder()
319                            .part_number(num as i32)
320                            .e_tag(etag)
321                            .build()
322                    })
323                    .collect(),
324            ))
325            .build();
326
327        self.s3_client
328            .complete_multipart_upload()
329            .bucket(&self.config.bucket)
330            .key(&object_key)
331            .upload_id(&upload_id)
332            .multipart_upload(completed_upload)
333            .send()
334            .await
335            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
336
337        // Upload metadata
338        let metadata_key = self.config.metadata_key(&snapshot.digest);
339        let metadata_json = serde_json::to_vec(snapshot)?;
340
341        self.s3_client
342            .put_object()
343            .bucket(&self.config.bucket)
344            .key(&metadata_key)
345            .body(ByteStream::from(metadata_json))
346            .content_type("application/json")
347            .send()
348            .await
349            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
350
351        info!("Upload complete: {}", object_key);
352        Ok(())
353    }
354
355    /// Upload individual parts with progress tracking
356    async fn upload_parts(
357        &self,
358        tarball_path: &Path,
359        upload_id: &str,
360        object_key: &str,
361        total_parts: u32,
362        part_size: u64,
363    ) -> Result<Vec<(u32, String)>> {
364        let mut completed = Vec::new();
365
366        for part_number in 1..=total_parts {
367            let offset = (part_number as u64 - 1) * part_size;
368
369            // Read part data
370            let mut file = File::open(tarball_path).await?;
371            file.seek(std::io::SeekFrom::Start(offset)).await?;
372
373            let mut buffer = vec![0u8; part_size as usize];
374            let bytes_read = file.read(&mut buffer).await?;
375            buffer.truncate(bytes_read);
376
377            // Upload part
378            let response = self
379                .s3_client
380                .upload_part()
381                .bucket(&self.config.bucket)
382                .key(object_key)
383                .upload_id(upload_id)
384                .part_number(part_number as i32)
385                .body(ByteStream::from(buffer))
386                .send()
387                .await
388                .map_err(|e| LayerStorageError::S3(e.to_string()))?;
389
390            let etag = response
391                .e_tag()
392                .ok_or_else(|| LayerStorageError::S3("No ETag returned for part".to_string()))?
393                .to_string();
394
395            debug!("Uploaded part {}/{}: {}", part_number, total_parts, etag);
396            completed.push((part_number, etag));
397        }
398
399        Ok(completed)
400    }
401
402    /// Resume an interrupted upload
403    #[instrument(skip(self, pending))]
404    async fn resume_upload(
405        &self,
406        container_id: &ContainerLayerId,
407        pending: PendingUpload,
408    ) -> Result<Option<LayerSnapshot>> {
409        let missing = pending.missing_parts();
410
411        if missing.is_empty() {
412            // All parts uploaded, just need to complete
413            info!("All parts uploaded, completing multipart upload");
414        } else {
415            info!("Resuming upload, {} parts remaining", missing.len());
416
417            // Verify local file still exists
418            if !pending.local_tarball_path.exists() {
419                warn!("Local tarball missing, aborting upload and starting fresh");
420                self.abort_upload(&pending).await?;
421
422                let key = container_id.to_key();
423                let mut states = self.states.write().await;
424                if let Some(state) = states.get_mut(&key) {
425                    state.pending_upload = None;
426                    self.save_state(state)?;
427                }
428
429                return Err(LayerStorageError::UploadInterrupted(
430                    "Local tarball missing".to_string(),
431                ));
432            }
433
434            // Upload missing parts
435            for part_number in missing {
436                let offset = (part_number as u64 - 1) * pending.part_size;
437
438                let mut file = File::open(&pending.local_tarball_path).await?;
439                file.seek(std::io::SeekFrom::Start(offset)).await?;
440
441                let mut buffer = vec![0u8; pending.part_size as usize];
442                let bytes_read = file.read(&mut buffer).await?;
443                buffer.truncate(bytes_read);
444
445                let response = self
446                    .s3_client
447                    .upload_part()
448                    .bucket(&self.config.bucket)
449                    .key(&pending.object_key)
450                    .upload_id(&pending.upload_id)
451                    .part_number(part_number as i32)
452                    .body(ByteStream::from(buffer))
453                    .send()
454                    .await
455                    .map_err(|e| LayerStorageError::S3(e.to_string()))?;
456
457                let etag = response
458                    .e_tag()
459                    .ok_or_else(|| LayerStorageError::S3("No ETag returned".to_string()))?
460                    .to_string();
461
462                debug!("Uploaded part {}: {}", part_number, etag);
463            }
464        }
465
466        // List parts to get all ETags
467        let parts_response = self
468            .s3_client
469            .list_parts()
470            .bucket(&self.config.bucket)
471            .key(&pending.object_key)
472            .upload_id(&pending.upload_id)
473            .send()
474            .await
475            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
476
477        let completed_parts: Vec<S3CompletedPart> = parts_response
478            .parts()
479            .iter()
480            .map(|p| {
481                S3CompletedPart::builder()
482                    .part_number(p.part_number().unwrap_or(0))
483                    .e_tag(p.e_tag().unwrap_or_default())
484                    .build()
485            })
486            .collect();
487
488        // Complete multipart upload
489        let completed_upload = CompletedMultipartUpload::builder()
490            .set_parts(Some(completed_parts))
491            .build();
492
493        self.s3_client
494            .complete_multipart_upload()
495            .bucket(&self.config.bucket)
496            .key(&pending.object_key)
497            .upload_id(&pending.upload_id)
498            .multipart_upload(completed_upload)
499            .send()
500            .await
501            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
502
503        // Update state
504        let key = container_id.to_key();
505        {
506            let mut states = self.states.write().await;
507            if let Some(state) = states.get_mut(&key) {
508                state.local_digest = Some(pending.digest.clone());
509                state.remote_digest = Some(pending.digest.clone());
510                state.last_sync = Some(chrono::Utc::now());
511                state.pending_upload = None;
512                self.save_state(state)?;
513            }
514        }
515
516        // Clean up staging file
517        let _ = tokio::fs::remove_file(&pending.local_tarball_path).await;
518
519        info!("Upload resumed and completed successfully");
520
521        // Return snapshot metadata (fetch from S3)
522        self.get_snapshot_metadata(&pending.digest).await.map(Some)
523    }
524
525    /// Abort a multipart upload
526    async fn abort_upload(&self, pending: &PendingUpload) -> Result<()> {
527        self.s3_client
528            .abort_multipart_upload()
529            .bucket(&self.config.bucket)
530            .key(&pending.object_key)
531            .upload_id(&pending.upload_id)
532            .send()
533            .await
534            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
535
536        Ok(())
537    }
538
539    /// Download and restore a layer from S3
540    #[instrument(skip(self, target_path))]
541    pub async fn restore_layer(
542        &self,
543        container_id: &ContainerLayerId,
544        target_path: impl AsRef<Path>,
545    ) -> Result<LayerSnapshot> {
546        let target_path = target_path.as_ref();
547        let key = container_id.to_key();
548
549        // Get remote digest
550        let remote_digest = {
551            let states = self.states.read().await;
552            states
553                .get(&key)
554                .and_then(|s| s.remote_digest.clone())
555                .ok_or_else(|| {
556                    LayerStorageError::NotFound(format!("No remote layer for {}", key))
557                })?
558        };
559
560        info!("Restoring layer {} from S3", remote_digest);
561
562        // Download tarball
563        let tarball_path = self
564            .config
565            .staging_dir
566            .join(format!("{}.tar.zst", remote_digest));
567
568        let object_key = self.config.object_key(&remote_digest);
569        let response = self
570            .s3_client
571            .get_object()
572            .bucket(&self.config.bucket)
573            .key(&object_key)
574            .send()
575            .await
576            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
577
578        // Stream to file
579        let mut file = tokio::fs::File::create(&tarball_path).await?;
580        let mut stream = response.body.into_async_read();
581        tokio::io::copy(&mut stream, &mut file).await?;
582
583        // Get snapshot metadata
584        let snapshot = self.get_snapshot_metadata(&remote_digest).await?;
585
586        // Extract
587        tokio::task::spawn_blocking({
588            let tarball = tarball_path.clone();
589            let target = target_path.to_path_buf();
590            let digest = remote_digest.clone();
591            move || extract_snapshot(tarball, target, Some(&digest))
592        })
593        .await
594        .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
595
596        // Update local digest
597        {
598            let mut states = self.states.write().await;
599            if let Some(state) = states.get_mut(&key) {
600                state.local_digest = Some(remote_digest);
601                self.save_state(state)?;
602            }
603        }
604
605        // Clean up
606        let _ = tokio::fs::remove_file(&tarball_path).await;
607
608        info!("Layer restored successfully");
609        Ok(snapshot)
610    }
611
612    /// Get snapshot metadata from S3
613    async fn get_snapshot_metadata(&self, digest: &str) -> Result<LayerSnapshot> {
614        let metadata_key = self.config.metadata_key(digest);
615
616        let response = self
617            .s3_client
618            .get_object()
619            .bucket(&self.config.bucket)
620            .key(&metadata_key)
621            .send()
622            .await
623            .map_err(|e| LayerStorageError::S3(e.to_string()))?;
624
625        let bytes = response
626            .body
627            .collect()
628            .await
629            .map_err(|e| LayerStorageError::S3(e.to_string()))?
630            .into_bytes();
631
632        serde_json::from_slice(&bytes).map_err(Into::into)
633    }
634
635    /// List all containers with sync state
636    pub async fn list_containers(&self) -> Vec<ContainerLayerId> {
637        let states = self.states.read().await;
638        states.values().map(|s| s.container_id.clone()).collect()
639    }
640
641    /// Get sync state for a container
642    pub async fn get_sync_state(&self, container_id: &ContainerLayerId) -> Option<SyncState> {
643        let states = self.states.read().await;
644        states.get(&container_id.to_key()).cloned()
645    }
646}
647
// `AsyncSeekExt` brings `File::seek` into scope for the part-offset reads in
// the upload paths above.
// NOTE(review): conventionally this `use` belongs with the import block at
// the top of the file.
use tokio::io::AsyncSeekExt;