1use crate::config::LayerStorageConfig;
6use crate::error::{LayerStorageError, Result};
7use crate::snapshot::{calculate_directory_digest, create_snapshot, extract_snapshot};
8use crate::types::{ContainerLayerId, LayerSnapshot, PendingUpload, SyncState};
9use aws_sdk_s3::primitives::ByteStream;
10use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart as S3CompletedPart};
11use aws_sdk_s3::Client as S3Client;
12use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
13use sqlx::SqlitePool;
14use std::collections::HashMap;
15use std::path::Path;
16use std::str::FromStr;
17use std::sync::Arc;
18use tokio::fs::File;
19use tokio::io::AsyncReadExt;
20use tokio::sync::RwLock;
21use tracing::{debug, info, instrument, warn};
22
/// Synchronizes container overlay ("upper") layers with S3-compatible object
/// storage, persisting per-container sync state in a local SQLite database so
/// interrupted multipart uploads can be resumed after a restart.
pub struct LayerSyncManager {
    /// Storage settings: bucket, staging directory, part size, compression, etc.
    config: LayerStorageConfig,
    /// Client for the configured S3-compatible endpoint.
    s3_client: S3Client,
    /// SQLite pool backing the persisted `sync_state` table.
    pool: SqlitePool,
    /// In-memory cache of sync state keyed by `ContainerLayerId::to_key()`;
    /// mirrors the `sync_state` table.
    states: Arc<RwLock<HashMap<String, SyncState>>>,
}
31
32impl LayerSyncManager {
    /// Creates a new manager: prepares local directories, builds the S3 client
    /// from the environment plus `config`, opens (creating if needed) the
    /// SQLite state database, and loads any persisted sync state into memory.
    ///
    /// # Errors
    /// Returns an error if directories cannot be created, the database cannot
    /// be opened or migrated, or persisted state fails to deserialize.
    pub async fn new(config: LayerStorageConfig) -> Result<Self> {
        // Ensure the staging area and the state DB's parent directory exist
        // before anything tries to write into them.
        tokio::fs::create_dir_all(&config.staging_dir).await?;
        if let Some(parent) = config.state_db_path.parent() {
            tokio::fs::create_dir_all(parent).await?;
        }

        let mut aws_config_builder = aws_config::from_env();

        // An explicit region in config overrides whatever the environment set.
        if let Some(region) = &config.region {
            aws_config_builder =
                aws_config_builder.region(aws_sdk_s3::config::Region::new(region.clone()));
        }

        let aws_config = aws_config_builder.load().await;

        // Custom endpoints (e.g. MinIO) typically require path-style addressing.
        let s3_config = if let Some(endpoint) = &config.endpoint_url {
            aws_sdk_s3::config::Builder::from(&aws_config)
                .endpoint_url(endpoint)
                .force_path_style(true)
                .build()
        } else {
            aws_sdk_s3::config::Builder::from(&aws_config).build()
        };

        let s3_client = S3Client::from_conf(s3_config);

        // `mode=rwc` plus `create_if_missing` let a fresh deployment start
        // without a pre-created database file.
        let db_url = format!("sqlite:{}?mode=rwc", config.state_db_path.display());
        let connect_options = SqliteConnectOptions::from_str(&db_url)
            .map_err(|e| LayerStorageError::Database(e.to_string()))?
            .create_if_missing(true);

        let pool = SqlitePoolOptions::new()
            .max_connections(5)
            .connect_with(connect_options)
            .await
            .map_err(|e| LayerStorageError::Database(e.to_string()))?;

        // WAL improves concurrent read/write behavior for the small state DB.
        sqlx::query("PRAGMA journal_mode=WAL")
            .execute(&pool)
            .await
            .map_err(|e| LayerStorageError::Database(e.to_string()))?;

        // Idempotent schema setup: one JSON blob per container key.
        sqlx::query(
            r"
            CREATE TABLE IF NOT EXISTS sync_state (
                container_key TEXT PRIMARY KEY NOT NULL,
                state_json TEXT NOT NULL,
                updated_at TEXT DEFAULT CURRENT_TIMESTAMP
            )
            ",
        )
        .execute(&pool)
        .await
        .map_err(|e| LayerStorageError::Database(e.to_string()))?;

        // Warm the in-memory cache from whatever state survived a restart.
        let states = Arc::new(RwLock::new(Self::load_all_states(&pool).await?));

        Ok(Self {
            config,
            s3_client,
            pool,
            states,
        })
    }
109
110 async fn load_all_states(pool: &SqlitePool) -> Result<HashMap<String, SyncState>> {
111 let rows: Vec<(String, String)> =
112 sqlx::query_as("SELECT container_key, state_json FROM sync_state")
113 .fetch_all(pool)
114 .await
115 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
116
117 let mut states = HashMap::new();
118 for (key, json) in rows {
119 let state: SyncState = serde_json::from_str(&json)?;
120 states.insert(key, state);
121 }
122
123 Ok(states)
124 }
125
126 async fn save_state(&self, state: &SyncState) -> Result<()> {
127 let key = state.container_id.to_key();
128 let value = serde_json::to_string(state)?;
129
130 sqlx::query(
131 r"
132 INSERT OR REPLACE INTO sync_state (container_key, state_json, updated_at)
133 VALUES (?, ?, CURRENT_TIMESTAMP)
134 ",
135 )
136 .bind(&key)
137 .bind(&value)
138 .execute(&self.pool)
139 .await
140 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
141
142 Ok(())
143 }
144
145 #[instrument(skip(self))]
151 pub async fn register_container(&self, container_id: ContainerLayerId) -> Result<()> {
152 let key = container_id.to_key();
153 let mut states = self.states.write().await;
154
155 if let std::collections::hash_map::Entry::Vacant(e) = states.entry(key) {
156 let state = SyncState::new(container_id);
157 self.save_state(&state).await?;
158 e.insert(state);
159 info!("Registered new container for layer sync");
160 }
161
162 Ok(())
163 }
164
165 #[instrument(skip(self, upper_layer_path))]
172 pub async fn check_for_changes(
173 &self,
174 container_id: &ContainerLayerId,
175 upper_layer_path: impl AsRef<Path>,
176 ) -> Result<bool> {
177 let key = container_id.to_key();
178 let states = self.states.read().await;
179
180 let state = states
181 .get(&key)
182 .ok_or_else(|| LayerStorageError::NotFound(key.clone()))?;
183
184 let current_digest = calculate_directory_digest(upper_layer_path)?;
186
187 Ok(state.local_digest.as_ref() != Some(¤t_digest))
189 }
190
191 #[instrument(skip(self, upper_layer_path), fields(container = %container_id))]
197 pub async fn sync_layer(
198 &self,
199 container_id: &ContainerLayerId,
200 upper_layer_path: impl AsRef<Path>,
201 ) -> Result<Option<LayerSnapshot>> {
202 let upper_layer_path = upper_layer_path.as_ref();
203 let key = container_id.to_key();
204
205 {
207 let states = self.states.read().await;
208 if let Some(state) = states.get(&key) {
209 if let Some(pending) = &state.pending_upload {
210 info!("Found pending upload, attempting to resume");
211 return self.resume_upload(container_id, pending.clone()).await;
212 }
213 }
214 }
215
216 let current_digest = calculate_directory_digest(upper_layer_path)?;
218
219 {
221 let states = self.states.read().await;
222 if let Some(state) = states.get(&key) {
223 if state.remote_digest.as_ref() == Some(¤t_digest) {
224 debug!("Layer already synced, no changes");
225 return Ok(None);
226 }
227 }
228 }
229
230 let tarball_path = self
232 .config
233 .staging_dir
234 .join(format!("{current_digest}.tar.zst"));
235
236 let snapshot = tokio::task::spawn_blocking({
237 let source = upper_layer_path.to_path_buf();
238 let output = tarball_path.clone();
239 let level = self.config.compression_level;
240 move || create_snapshot(source, output, level)
241 })
242 .await
243 .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
244
245 self.upload_snapshot(container_id, &tarball_path, &snapshot)
247 .await?;
248
249 {
251 let mut states = self.states.write().await;
252 if let Some(state) = states.get_mut(&key) {
253 state.local_digest = Some(snapshot.digest.clone());
254 state.remote_digest = Some(snapshot.digest.clone());
255 state.last_sync = Some(chrono::Utc::now());
256 state.pending_upload = None;
257 self.save_state(state).await?;
258 }
259 }
260
261 let _ = tokio::fs::remove_file(&tarball_path).await;
263
264 Ok(Some(snapshot))
265 }
266
    /// Uploads a snapshot tarball to S3 via multipart upload, recording a
    /// `PendingUpload` in the sync state *before* transferring parts so an
    /// interrupted upload can be resumed, then writes the snapshot metadata
    /// JSON as a sibling object.
    ///
    /// # Errors
    /// Fails on S3 errors (create/complete multipart upload, metadata put),
    /// I/O errors reading the tarball, or state persistence errors.
    #[allow(clippy::cast_possible_wrap)]
    #[instrument(skip(self, tarball_path, snapshot))]
    async fn upload_snapshot(
        &self,
        container_id: &ContainerLayerId,
        tarball_path: &Path,
        snapshot: &LayerSnapshot,
    ) -> Result<()> {
        let object_key = self.config.object_key(&snapshot.digest);
        let file_size = tokio::fs::metadata(tarball_path).await?.len();
        let part_size = self.config.part_size_bytes;
        // Round up so a trailing partial part still gets uploaded.
        #[allow(clippy::cast_possible_truncation)]
        let total_parts = file_size.div_ceil(part_size) as u32;

        info!(
            "Uploading {} ({} bytes) in {} parts",
            object_key, file_size, total_parts
        );

        let create_response = self
            .s3_client
            .create_multipart_upload()
            .bucket(&self.config.bucket)
            .key(&object_key)
            .content_type("application/zstd")
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        let upload_id = create_response
            .upload_id()
            .ok_or_else(|| LayerStorageError::S3("No upload ID returned".to_string()))?
            .to_string();

        // Persist the pending upload BEFORE transferring parts so a crash
        // mid-upload leaves enough state behind to resume later.
        let pending = PendingUpload {
            upload_id: upload_id.clone(),
            object_key: object_key.clone(),
            total_parts,
            completed_parts: HashMap::new(),
            part_size,
            local_tarball_path: tarball_path.to_path_buf(),
            started_at: chrono::Utc::now(),
            digest: snapshot.digest.clone(),
        };

        {
            let key = container_id.to_key();
            let mut states = self.states.write().await;
            if let Some(state) = states.get_mut(&key) {
                state.pending_upload = Some(pending.clone());
                self.save_state(state).await?;
            }
        }

        let completed_parts = self
            .upload_parts(
                tarball_path,
                &upload_id,
                &object_key,
                total_parts,
                part_size,
            )
            .await?;

        // S3 part numbers are 1-based i32s; the ETag identifies each
        // uploaded part when completing the upload.
        let completed_upload = CompletedMultipartUpload::builder()
            .set_parts(Some(
                completed_parts
                    .into_iter()
                    .map(|(num, etag)| {
                        S3CompletedPart::builder()
                            .part_number(num as i32)
                            .e_tag(etag)
                            .build()
                    })
                    .collect(),
            ))
            .build();

        self.s3_client
            .complete_multipart_upload()
            .bucket(&self.config.bucket)
            .key(&object_key)
            .upload_id(&upload_id)
            .multipart_upload(completed_upload)
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        // Store snapshot metadata alongside the tarball for later restores.
        let metadata_key = self.config.metadata_key(&snapshot.digest);
        let metadata_json = serde_json::to_vec(snapshot)?;

        self.s3_client
            .put_object()
            .bucket(&self.config.bucket)
            .key(&metadata_key)
            .body(ByteStream::from(metadata_json))
            .content_type("application/json")
            .send()
            .await
            .map_err(|e| LayerStorageError::S3(e.to_string()))?;

        info!("Upload complete: {}", object_key);
        Ok(())
    }
377
378 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
380 async fn upload_parts(
381 &self,
382 tarball_path: &Path,
383 upload_id: &str,
384 object_key: &str,
385 total_parts: u32,
386 part_size: u64,
387 ) -> Result<Vec<(u32, String)>> {
388 let mut completed = Vec::new();
389
390 for part_number in 1..=total_parts {
391 let offset = (u64::from(part_number) - 1) * part_size;
392
393 let mut file = File::open(tarball_path).await?;
395 file.seek(std::io::SeekFrom::Start(offset)).await?;
396
397 let mut buffer = vec![0u8; part_size as usize];
398 let bytes_read = file.read(&mut buffer).await?;
399 buffer.truncate(bytes_read);
400
401 let response = self
403 .s3_client
404 .upload_part()
405 .bucket(&self.config.bucket)
406 .key(object_key)
407 .upload_id(upload_id)
408 .part_number(part_number as i32)
409 .body(ByteStream::from(buffer))
410 .send()
411 .await
412 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
413
414 let etag = response
415 .e_tag()
416 .ok_or_else(|| LayerStorageError::S3("No ETag returned for part".to_string()))?
417 .to_string();
418
419 debug!("Uploaded part {}/{}: {}", part_number, total_parts, etag);
420 completed.push((part_number, etag));
421 }
422
423 Ok(completed)
424 }
425
426 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
428 #[instrument(skip(self, pending))]
429 async fn resume_upload(
430 &self,
431 container_id: &ContainerLayerId,
432 pending: PendingUpload,
433 ) -> Result<Option<LayerSnapshot>> {
434 let missing = pending.missing_parts();
435
436 if missing.is_empty() {
437 info!("All parts uploaded, completing multipart upload");
439 } else {
440 info!("Resuming upload, {} parts remaining", missing.len());
441
442 if !pending.local_tarball_path.exists() {
444 warn!("Local tarball missing, aborting upload and starting fresh");
445 self.abort_upload(&pending).await?;
446
447 let key = container_id.to_key();
448 let mut states = self.states.write().await;
449 if let Some(state) = states.get_mut(&key) {
450 state.pending_upload = None;
451 self.save_state(state).await?;
452 }
453
454 return Err(LayerStorageError::UploadInterrupted(
455 "Local tarball missing".to_string(),
456 ));
457 }
458
459 for part_number in missing {
461 let offset = (u64::from(part_number) - 1) * pending.part_size;
462
463 let mut file = File::open(&pending.local_tarball_path).await?;
464 file.seek(std::io::SeekFrom::Start(offset)).await?;
465
466 let mut buffer = vec![0u8; pending.part_size as usize];
467 let bytes_read = file.read(&mut buffer).await?;
468 buffer.truncate(bytes_read);
469
470 let response = self
471 .s3_client
472 .upload_part()
473 .bucket(&self.config.bucket)
474 .key(&pending.object_key)
475 .upload_id(&pending.upload_id)
476 .part_number(part_number as i32)
477 .body(ByteStream::from(buffer))
478 .send()
479 .await
480 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
481
482 let etag = response
483 .e_tag()
484 .ok_or_else(|| LayerStorageError::S3("No ETag returned".to_string()))?
485 .to_string();
486
487 debug!("Uploaded part {}: {}", part_number, etag);
488 }
489 }
490
491 let parts_response = self
493 .s3_client
494 .list_parts()
495 .bucket(&self.config.bucket)
496 .key(&pending.object_key)
497 .upload_id(&pending.upload_id)
498 .send()
499 .await
500 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
501
502 let completed_parts: Vec<S3CompletedPart> = parts_response
503 .parts()
504 .iter()
505 .map(|p| {
506 S3CompletedPart::builder()
507 .part_number(p.part_number().unwrap_or(0))
508 .e_tag(p.e_tag().unwrap_or_default())
509 .build()
510 })
511 .collect();
512
513 let completed_upload = CompletedMultipartUpload::builder()
515 .set_parts(Some(completed_parts))
516 .build();
517
518 self.s3_client
519 .complete_multipart_upload()
520 .bucket(&self.config.bucket)
521 .key(&pending.object_key)
522 .upload_id(&pending.upload_id)
523 .multipart_upload(completed_upload)
524 .send()
525 .await
526 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
527
528 let key = container_id.to_key();
530 {
531 let mut states = self.states.write().await;
532 if let Some(state) = states.get_mut(&key) {
533 state.local_digest = Some(pending.digest.clone());
534 state.remote_digest = Some(pending.digest.clone());
535 state.last_sync = Some(chrono::Utc::now());
536 state.pending_upload = None;
537 self.save_state(state).await?;
538 }
539 }
540
541 let _ = tokio::fs::remove_file(&pending.local_tarball_path).await;
543
544 info!("Upload resumed and completed successfully");
545
546 self.get_snapshot_metadata(&pending.digest).await.map(Some)
548 }
549
550 async fn abort_upload(&self, pending: &PendingUpload) -> Result<()> {
552 self.s3_client
553 .abort_multipart_upload()
554 .bucket(&self.config.bucket)
555 .key(&pending.object_key)
556 .upload_id(&pending.upload_id)
557 .send()
558 .await
559 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
560
561 Ok(())
562 }
563
564 #[instrument(skip(self, target_path))]
571 pub async fn restore_layer(
572 &self,
573 container_id: &ContainerLayerId,
574 target_path: impl AsRef<Path>,
575 ) -> Result<LayerSnapshot> {
576 let target_path = target_path.as_ref();
577 let key = container_id.to_key();
578
579 let remote_digest = {
581 let states = self.states.read().await;
582 states
583 .get(&key)
584 .and_then(|s| s.remote_digest.clone())
585 .ok_or_else(|| LayerStorageError::NotFound(format!("No remote layer for {key}")))?
586 };
587
588 info!("Restoring layer {} from S3", remote_digest);
589
590 let tarball_path = self
592 .config
593 .staging_dir
594 .join(format!("{remote_digest}.tar.zst"));
595
596 let object_key = self.config.object_key(&remote_digest);
597 let response = self
598 .s3_client
599 .get_object()
600 .bucket(&self.config.bucket)
601 .key(&object_key)
602 .send()
603 .await
604 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
605
606 let mut file = tokio::fs::File::create(&tarball_path).await?;
608 let mut stream = response.body.into_async_read();
609 tokio::io::copy(&mut stream, &mut file).await?;
610
611 let snapshot = self.get_snapshot_metadata(&remote_digest).await?;
613
614 tokio::task::spawn_blocking({
616 let tarball = tarball_path.clone();
617 let target = target_path.to_path_buf();
618 let digest = remote_digest.clone();
619 move || extract_snapshot(tarball, target, Some(&digest))
620 })
621 .await
622 .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
623
624 {
626 let mut states = self.states.write().await;
627 if let Some(state) = states.get_mut(&key) {
628 state.local_digest = Some(remote_digest);
629 self.save_state(state).await?;
630 }
631 }
632
633 let _ = tokio::fs::remove_file(&tarball_path).await;
635
636 info!("Layer restored successfully");
637 Ok(snapshot)
638 }
639
640 async fn get_snapshot_metadata(&self, digest: &str) -> Result<LayerSnapshot> {
642 let metadata_key = self.config.metadata_key(digest);
643
644 let response = self
645 .s3_client
646 .get_object()
647 .bucket(&self.config.bucket)
648 .key(&metadata_key)
649 .send()
650 .await
651 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
652
653 let bytes = response
654 .body
655 .collect()
656 .await
657 .map_err(|e| LayerStorageError::S3(e.to_string()))?
658 .into_bytes();
659
660 serde_json::from_slice(&bytes).map_err(Into::into)
661 }
662
663 pub async fn list_containers(&self) -> Vec<ContainerLayerId> {
665 let states = self.states.read().await;
666 states.values().map(|s| s.container_id.clone()).collect()
667 }
668
669 pub async fn get_sync_state(&self, container_id: &ContainerLayerId) -> Option<SyncState> {
671 let states = self.states.read().await;
672 states.get(&container_id.to_key()).cloned()
673 }
674}
675
676use tokio::io::AsyncSeekExt;