1use crate::config::LayerStorageConfig;
6use crate::error::{LayerStorageError, Result};
7use crate::snapshot::{calculate_directory_digest, create_snapshot, extract_snapshot};
8use crate::types::{ContainerLayerId, LayerSnapshot, PendingUpload, SyncState};
9use aws_sdk_s3::primitives::ByteStream;
10use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart as S3CompletedPart};
11use aws_sdk_s3::Client as S3Client;
12use redb::{Database, ReadableTable, TableDefinition};
13use std::collections::HashMap;
14use std::path::Path;
15use std::sync::Arc;
16use tokio::fs::File;
17use tokio::io::AsyncReadExt;
18use tokio::sync::RwLock;
19use tracing::{debug, info, instrument, warn};
20
// redb table mapping container key -> JSON-serialized `SyncState`.
const SYNC_STATE_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("sync_state");
// redb table for snapshot metadata. NOTE(review): it is opened at startup in
// `new` but never read or written in this file — confirm it is used elsewhere.
const SNAPSHOT_META_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("snapshot_meta");
23
/// Synchronizes container upper-layer snapshots between local staging storage
/// and an S3 bucket, persisting sync progress so interrupted multipart
/// uploads can be resumed across restarts.
pub struct LayerSyncManager {
    // Storage/S3 configuration (bucket, staging dir, part size, keys, ...).
    config: LayerStorageConfig,
    // S3 client used for multipart uploads, downloads, and metadata objects.
    s3_client: S3Client,
    // Embedded redb database that persists `SyncState` across restarts.
    db: Database,
    // In-memory cache of per-container state, keyed by `ContainerLayerId::to_key()`;
    // kept in lockstep with the database via `save_state`.
    states: Arc<RwLock<HashMap<String, SyncState>>>,
}
32
33impl LayerSyncManager {
34 pub async fn new(config: LayerStorageConfig) -> Result<Self> {
36 tokio::fs::create_dir_all(&config.staging_dir).await?;
38 if let Some(parent) = config.state_db_path.parent() {
39 tokio::fs::create_dir_all(parent).await?;
40 }
41
42 let mut aws_config_builder = aws_config::from_env();
44
45 if let Some(region) = &config.region {
46 aws_config_builder =
47 aws_config_builder.region(aws_sdk_s3::config::Region::new(region.clone()));
48 }
49
50 let aws_config = aws_config_builder.load().await;
51
52 let s3_config = if let Some(endpoint) = &config.endpoint_url {
53 aws_sdk_s3::config::Builder::from(&aws_config)
54 .endpoint_url(endpoint)
55 .force_path_style(true)
56 .build()
57 } else {
58 aws_sdk_s3::config::Builder::from(&aws_config).build()
59 };
60
61 let s3_client = S3Client::from_conf(s3_config);
62
63 let db = Database::create(&config.state_db_path)
65 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
66
67 {
69 let write_txn = db
70 .begin_write()
71 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
72 {
73 let _ = write_txn.open_table(SYNC_STATE_TABLE);
74 let _ = write_txn.open_table(SNAPSHOT_META_TABLE);
75 }
76 write_txn
77 .commit()
78 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
79 }
80
81 let states = Arc::new(RwLock::new(Self::load_all_states(&db)?));
83
84 Ok(Self {
85 config,
86 s3_client,
87 db,
88 states,
89 })
90 }
91
92 fn load_all_states(db: &Database) -> Result<HashMap<String, SyncState>> {
93 let read_txn = db
94 .begin_read()
95 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
96 let table = read_txn
97 .open_table(SYNC_STATE_TABLE)
98 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
99
100 let mut states = HashMap::new();
101 let iter = table
102 .iter()
103 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
104
105 for entry in iter {
106 let (key, value) = entry.map_err(|e| LayerStorageError::Database(e.to_string()))?;
107 let state: SyncState = serde_json::from_slice(value.value())?;
108 states.insert(key.value().to_string(), state);
109 }
110
111 Ok(states)
112 }
113
114 fn save_state(&self, state: &SyncState) -> Result<()> {
115 let key = state.container_id.to_key();
116 let value = serde_json::to_vec(state)?;
117
118 let write_txn = self
119 .db
120 .begin_write()
121 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
122 {
123 let mut table = write_txn
124 .open_table(SYNC_STATE_TABLE)
125 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
126 table
127 .insert(key.as_str(), value.as_slice())
128 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
129 }
130 write_txn
131 .commit()
132 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
133
134 Ok(())
135 }
136
137 #[instrument(skip(self))]
139 pub async fn register_container(&self, container_id: ContainerLayerId) -> Result<()> {
140 let key = container_id.to_key();
141 let mut states = self.states.write().await;
142
143 if let std::collections::hash_map::Entry::Vacant(e) = states.entry(key) {
144 let state = SyncState::new(container_id);
145 self.save_state(&state)?;
146 e.insert(state);
147 info!("Registered new container for layer sync");
148 }
149
150 Ok(())
151 }
152
153 #[instrument(skip(self, upper_layer_path))]
155 pub async fn check_for_changes(
156 &self,
157 container_id: &ContainerLayerId,
158 upper_layer_path: impl AsRef<Path>,
159 ) -> Result<bool> {
160 let key = container_id.to_key();
161 let states = self.states.read().await;
162
163 let state = states
164 .get(&key)
165 .ok_or_else(|| LayerStorageError::NotFound(key.clone()))?;
166
167 let current_digest = calculate_directory_digest(upper_layer_path)?;
169
170 Ok(state.local_digest.as_ref() != Some(¤t_digest))
172 }
173
174 #[instrument(skip(self, upper_layer_path), fields(container = %container_id))]
176 pub async fn sync_layer(
177 &self,
178 container_id: &ContainerLayerId,
179 upper_layer_path: impl AsRef<Path>,
180 ) -> Result<Option<LayerSnapshot>> {
181 let upper_layer_path = upper_layer_path.as_ref();
182 let key = container_id.to_key();
183
184 {
186 let states = self.states.read().await;
187 if let Some(state) = states.get(&key) {
188 if let Some(pending) = &state.pending_upload {
189 info!("Found pending upload, attempting to resume");
190 return self.resume_upload(container_id, pending.clone()).await;
191 }
192 }
193 }
194
195 let current_digest = calculate_directory_digest(upper_layer_path)?;
197
198 {
200 let states = self.states.read().await;
201 if let Some(state) = states.get(&key) {
202 if state.remote_digest.as_ref() == Some(¤t_digest) {
203 debug!("Layer already synced, no changes");
204 return Ok(None);
205 }
206 }
207 }
208
209 let tarball_path = self
211 .config
212 .staging_dir
213 .join(format!("{}.tar.zst", current_digest));
214
215 let snapshot = tokio::task::spawn_blocking({
216 let source = upper_layer_path.to_path_buf();
217 let output = tarball_path.clone();
218 let level = self.config.compression_level;
219 move || create_snapshot(source, output, level)
220 })
221 .await
222 .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
223
224 self.upload_snapshot(container_id, &tarball_path, &snapshot)
226 .await?;
227
228 {
230 let mut states = self.states.write().await;
231 if let Some(state) = states.get_mut(&key) {
232 state.local_digest = Some(snapshot.digest.clone());
233 state.remote_digest = Some(snapshot.digest.clone());
234 state.last_sync = Some(chrono::Utc::now());
235 state.pending_upload = None;
236 self.save_state(state)?;
237 }
238 }
239
240 let _ = tokio::fs::remove_file(&tarball_path).await;
242
243 Ok(Some(snapshot))
244 }
245
246 #[instrument(skip(self, tarball_path, snapshot))]
248 async fn upload_snapshot(
249 &self,
250 container_id: &ContainerLayerId,
251 tarball_path: &Path,
252 snapshot: &LayerSnapshot,
253 ) -> Result<()> {
254 let object_key = self.config.object_key(&snapshot.digest);
255 let file_size = tokio::fs::metadata(tarball_path).await?.len();
256 let part_size = self.config.part_size_bytes;
257 let total_parts = file_size.div_ceil(part_size) as u32;
258
259 info!(
260 "Uploading {} ({} bytes) in {} parts",
261 object_key, file_size, total_parts
262 );
263
264 let create_response = self
266 .s3_client
267 .create_multipart_upload()
268 .bucket(&self.config.bucket)
269 .key(&object_key)
270 .content_type("application/zstd")
271 .send()
272 .await
273 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
274
275 let upload_id = create_response
276 .upload_id()
277 .ok_or_else(|| LayerStorageError::S3("No upload ID returned".to_string()))?
278 .to_string();
279
280 let pending = PendingUpload {
282 upload_id: upload_id.clone(),
283 object_key: object_key.clone(),
284 total_parts,
285 completed_parts: HashMap::new(),
286 part_size,
287 local_tarball_path: tarball_path.to_path_buf(),
288 started_at: chrono::Utc::now(),
289 digest: snapshot.digest.clone(),
290 };
291
292 {
293 let key = container_id.to_key();
294 let mut states = self.states.write().await;
295 if let Some(state) = states.get_mut(&key) {
296 state.pending_upload = Some(pending.clone());
297 self.save_state(state)?;
298 }
299 }
300
301 let completed_parts = self
303 .upload_parts(
304 tarball_path,
305 &upload_id,
306 &object_key,
307 total_parts,
308 part_size,
309 )
310 .await?;
311
312 let completed_upload = CompletedMultipartUpload::builder()
314 .set_parts(Some(
315 completed_parts
316 .into_iter()
317 .map(|(num, etag)| {
318 S3CompletedPart::builder()
319 .part_number(num as i32)
320 .e_tag(etag)
321 .build()
322 })
323 .collect(),
324 ))
325 .build();
326
327 self.s3_client
328 .complete_multipart_upload()
329 .bucket(&self.config.bucket)
330 .key(&object_key)
331 .upload_id(&upload_id)
332 .multipart_upload(completed_upload)
333 .send()
334 .await
335 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
336
337 let metadata_key = self.config.metadata_key(&snapshot.digest);
339 let metadata_json = serde_json::to_vec(snapshot)?;
340
341 self.s3_client
342 .put_object()
343 .bucket(&self.config.bucket)
344 .key(&metadata_key)
345 .body(ByteStream::from(metadata_json))
346 .content_type("application/json")
347 .send()
348 .await
349 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
350
351 info!("Upload complete: {}", object_key);
352 Ok(())
353 }
354
355 async fn upload_parts(
357 &self,
358 tarball_path: &Path,
359 upload_id: &str,
360 object_key: &str,
361 total_parts: u32,
362 part_size: u64,
363 ) -> Result<Vec<(u32, String)>> {
364 let mut completed = Vec::new();
365
366 for part_number in 1..=total_parts {
367 let offset = (part_number as u64 - 1) * part_size;
368
369 let mut file = File::open(tarball_path).await?;
371 file.seek(std::io::SeekFrom::Start(offset)).await?;
372
373 let mut buffer = vec![0u8; part_size as usize];
374 let bytes_read = file.read(&mut buffer).await?;
375 buffer.truncate(bytes_read);
376
377 let response = self
379 .s3_client
380 .upload_part()
381 .bucket(&self.config.bucket)
382 .key(object_key)
383 .upload_id(upload_id)
384 .part_number(part_number as i32)
385 .body(ByteStream::from(buffer))
386 .send()
387 .await
388 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
389
390 let etag = response
391 .e_tag()
392 .ok_or_else(|| LayerStorageError::S3("No ETag returned for part".to_string()))?
393 .to_string();
394
395 debug!("Uploaded part {}/{}: {}", part_number, total_parts, etag);
396 completed.push((part_number, etag));
397 }
398
399 Ok(completed)
400 }
401
    /// Resumes a previously interrupted multipart upload described by `pending`.
    ///
    /// Uploads only the parts `pending` records as missing, then reconciles
    /// the full part list from S3 via ListParts (the source of truth for what
    /// actually arrived), completes the upload, updates the sync state, and
    /// removes the staging tarball.
    ///
    /// # Errors
    /// `UploadInterrupted` when the local tarball has vanished (the upload is
    /// aborted server-side and the pending record cleared so the next sync
    /// starts fresh); otherwise S3, I/O, or state-persistence errors.
    #[instrument(skip(self, pending))]
    async fn resume_upload(
        &self,
        container_id: &ContainerLayerId,
        pending: PendingUpload,
    ) -> Result<Option<LayerSnapshot>> {
        // Parts already recorded in `pending` do not need re-uploading.
        let missing = pending.missing_parts();

        if missing.is_empty() {
            info!("All parts uploaded, completing multipart upload");
        } else {
            info!("Resuming upload, {} parts remaining", missing.len());

            // Without the staged tarball we cannot produce the remaining
            // parts: abort server-side (freeing S3's stored parts) and clear
            // the pending record so the caller can start a fresh sync.
            if !pending.local_tarball_path.exists() {
                warn!("Local tarball missing, aborting upload and starting fresh");
                self.abort_upload(&pending).await?;

                let key = container_id.to_key();
                let mut states = self.states.write().await;
                if let Some(state) = states.get_mut(&key) {
                    state.pending_upload = None;
                    self.save_state(state)?;
                }

                return Err(LayerStorageError::UploadInterrupted(
                    "Local tarball missing".to_string(),
                ));
            }

            for part_number in missing {
                // Part numbers are 1-based; derive the byte offset of the part.
                let offset = (part_number as u64 - 1) * pending.part_size;

                let mut file = File::open(&pending.local_tarball_path).await?;
                file.seek(std::io::SeekFrom::Start(offset)).await?;

                // NOTE(review): a single `read()` is not guaranteed to fill
                // the buffer before EOF; a short read here would truncate the
                // part. Confirm whether this needs a read-until-full loop.
                let mut buffer = vec![0u8; pending.part_size as usize];
                let bytes_read = file.read(&mut buffer).await?;
                buffer.truncate(bytes_read);

                let response = self
                    .s3_client
                    .upload_part()
                    .bucket(&self.config.bucket)
                    .key(&pending.object_key)
                    .upload_id(&pending.upload_id)
                    .part_number(part_number as i32)
                    .body(ByteStream::from(buffer))
                    .send()
                    .await
                    .map_err(|e| LayerStorageError::S3(e.to_string()))?;

                let etag = response
                    .e_tag()
                    .ok_or_else(|| LayerStorageError::S3("No ETag returned".to_string()))?
                    .to_string();

                debug!("Uploaded part {}: {}", part_number, etag);
            }
        }

        // Ask S3 for the authoritative part list rather than trusting local
        // bookkeeping. NOTE(review): ListParts is paginated (1000 parts per
        // page by default); uploads with more parts would need pagination —
        // confirm expected part counts.
        let parts_response = self
            .s3_client
            .list_parts()
            .bucket(&self.config.bucket)
            .key(&pending.object_key)
            .upload_id(&pending.upload_id)
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        let completed_parts: Vec<S3CompletedPart> = parts_response
            .parts()
            .iter()
            .map(|p| {
                S3CompletedPart::builder()
                    .part_number(p.part_number().unwrap_or(0))
                    .e_tag(p.e_tag().unwrap_or_default())
                    .build()
            })
            .collect();

        let completed_upload = CompletedMultipartUpload::builder()
            .set_parts(Some(completed_parts))
            .build();

        self.s3_client
            .complete_multipart_upload()
            .bucket(&self.config.bucket)
            .key(&pending.object_key)
            .upload_id(&pending.upload_id)
            .multipart_upload(completed_upload)
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        // Record the successful sync and clear the pending-upload marker.
        let key = container_id.to_key();
        {
            let mut states = self.states.write().await;
            if let Some(state) = states.get_mut(&key) {
                state.local_digest = Some(pending.digest.clone());
                state.remote_digest = Some(pending.digest.clone());
                state.last_sync = Some(chrono::Utc::now());
                state.pending_upload = None;
                self.save_state(state)?;
            }
        }

        // Staging tarball no longer needed; removal failure is harmless.
        let _ = tokio::fs::remove_file(&pending.local_tarball_path).await;

        info!("Upload resumed and completed successfully");

        self.get_snapshot_metadata(&pending.digest).await.map(Some)
    }
524
525 async fn abort_upload(&self, pending: &PendingUpload) -> Result<()> {
527 self.s3_client
528 .abort_multipart_upload()
529 .bucket(&self.config.bucket)
530 .key(&pending.object_key)
531 .upload_id(&pending.upload_id)
532 .send()
533 .await
534 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
535
536 Ok(())
537 }
538
539 #[instrument(skip(self, target_path))]
541 pub async fn restore_layer(
542 &self,
543 container_id: &ContainerLayerId,
544 target_path: impl AsRef<Path>,
545 ) -> Result<LayerSnapshot> {
546 let target_path = target_path.as_ref();
547 let key = container_id.to_key();
548
549 let remote_digest = {
551 let states = self.states.read().await;
552 states
553 .get(&key)
554 .and_then(|s| s.remote_digest.clone())
555 .ok_or_else(|| {
556 LayerStorageError::NotFound(format!("No remote layer for {}", key))
557 })?
558 };
559
560 info!("Restoring layer {} from S3", remote_digest);
561
562 let tarball_path = self
564 .config
565 .staging_dir
566 .join(format!("{}.tar.zst", remote_digest));
567
568 let object_key = self.config.object_key(&remote_digest);
569 let response = self
570 .s3_client
571 .get_object()
572 .bucket(&self.config.bucket)
573 .key(&object_key)
574 .send()
575 .await
576 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
577
578 let mut file = tokio::fs::File::create(&tarball_path).await?;
580 let mut stream = response.body.into_async_read();
581 tokio::io::copy(&mut stream, &mut file).await?;
582
583 let snapshot = self.get_snapshot_metadata(&remote_digest).await?;
585
586 tokio::task::spawn_blocking({
588 let tarball = tarball_path.clone();
589 let target = target_path.to_path_buf();
590 let digest = remote_digest.clone();
591 move || extract_snapshot(tarball, target, Some(&digest))
592 })
593 .await
594 .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
595
596 {
598 let mut states = self.states.write().await;
599 if let Some(state) = states.get_mut(&key) {
600 state.local_digest = Some(remote_digest);
601 self.save_state(state)?;
602 }
603 }
604
605 let _ = tokio::fs::remove_file(&tarball_path).await;
607
608 info!("Layer restored successfully");
609 Ok(snapshot)
610 }
611
612 async fn get_snapshot_metadata(&self, digest: &str) -> Result<LayerSnapshot> {
614 let metadata_key = self.config.metadata_key(digest);
615
616 let response = self
617 .s3_client
618 .get_object()
619 .bucket(&self.config.bucket)
620 .key(&metadata_key)
621 .send()
622 .await
623 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
624
625 let bytes = response
626 .body
627 .collect()
628 .await
629 .map_err(|e| LayerStorageError::S3(e.to_string()))?
630 .into_bytes();
631
632 serde_json::from_slice(&bytes).map_err(Into::into)
633 }
634
635 pub async fn list_containers(&self) -> Vec<ContainerLayerId> {
637 let states = self.states.read().await;
638 states.values().map(|s| s.container_id.clone()).collect()
639 }
640
641 pub async fn get_sync_state(&self, container_id: &ContainerLayerId) -> Option<SyncState> {
643 let states = self.states.read().await;
644 states.get(&container_id.to_key()).cloned()
645 }
646}
647
648use tokio::io::AsyncSeekExt;