1use crate::config::LayerStorageConfig;
6use crate::error::{LayerStorageError, Result};
7use crate::snapshot::{calculate_directory_digest, create_snapshot, extract_snapshot};
8use crate::types::{ContainerLayerId, LayerSnapshot, PendingUpload, SyncState};
9use aws_sdk_s3::primitives::ByteStream;
10use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart as S3CompletedPart};
11use aws_sdk_s3::Client as S3Client;
12use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
13use sqlx::SqlitePool;
14use std::collections::HashMap;
15use std::path::Path;
16use std::str::FromStr;
17use std::sync::Arc;
18use tokio::fs::File;
19use tokio::io::AsyncReadExt;
20use tokio::sync::RwLock;
21use tracing::{debug, info, instrument, warn};
22
/// Synchronises container upper layers with S3 and tracks per-container sync state.
///
/// Sync state is held in memory (`states`) and mirrored to a SQLite table (`pool`);
/// `new` reloads every persisted row at startup, so state survives restarts.
pub struct LayerSyncManager {
    /// Storage configuration (bucket, staging directory, part size, key layout, ...).
    config: LayerStorageConfig,
    /// Client used for all S3 object and multipart-upload requests.
    s3_client: S3Client,
    /// Connection pool for the `sync_state` SQLite table.
    pool: SqlitePool,
    /// In-memory sync state keyed by `ContainerLayerId::to_key()`.
    states: Arc<RwLock<HashMap<String, SyncState>>>,
}
31
32impl LayerSyncManager {
33 pub async fn new(config: LayerStorageConfig) -> Result<Self> {
40 tokio::fs::create_dir_all(&config.staging_dir).await?;
42 if let Some(parent) = config.state_db_path.parent() {
43 tokio::fs::create_dir_all(parent).await?;
44 }
45
46 let mut aws_config_builder = aws_config::from_env();
48
49 if let Some(region) = &config.region {
50 aws_config_builder =
51 aws_config_builder.region(aws_sdk_s3::config::Region::new(region.clone()));
52 } else if config.endpoint_url.is_some() {
53 aws_config_builder =
56 aws_config_builder.region(aws_sdk_s3::config::Region::new("us-east-1"));
57 }
58
59 let aws_config = aws_config_builder.load().await;
60
61 let s3_config = if let Some(endpoint) = &config.endpoint_url {
62 aws_sdk_s3::config::Builder::from(&aws_config)
63 .endpoint_url(endpoint)
64 .force_path_style(true)
65 .build()
66 } else {
67 aws_sdk_s3::config::Builder::from(&aws_config).build()
68 };
69
70 let s3_client = S3Client::from_conf(s3_config);
71
72 let db_url = format!("sqlite:{}?mode=rwc", config.state_db_path.display());
74 let connect_options = SqliteConnectOptions::from_str(&db_url)
75 .map_err(|e| LayerStorageError::Database(e.to_string()))?
76 .create_if_missing(true);
77
78 let pool = SqlitePoolOptions::new()
79 .max_connections(5)
80 .connect_with(connect_options)
81 .await
82 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
83
84 sqlx::query("PRAGMA journal_mode=WAL")
86 .execute(&pool)
87 .await
88 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
89
90 sqlx::query(
92 r"
93 CREATE TABLE IF NOT EXISTS sync_state (
94 container_key TEXT PRIMARY KEY NOT NULL,
95 state_json TEXT NOT NULL,
96 updated_at TEXT DEFAULT CURRENT_TIMESTAMP
97 )
98 ",
99 )
100 .execute(&pool)
101 .await
102 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
103
104 let states = Arc::new(RwLock::new(Self::load_all_states(&pool).await?));
106
107 Ok(Self {
108 config,
109 s3_client,
110 pool,
111 states,
112 })
113 }
114
115 async fn load_all_states(pool: &SqlitePool) -> Result<HashMap<String, SyncState>> {
116 let rows: Vec<(String, String)> =
117 sqlx::query_as("SELECT container_key, state_json FROM sync_state")
118 .fetch_all(pool)
119 .await
120 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
121
122 let mut states = HashMap::new();
123 for (key, json) in rows {
124 let state: SyncState = serde_json::from_str(&json)?;
125 states.insert(key, state);
126 }
127
128 Ok(states)
129 }
130
131 async fn save_state(&self, state: &SyncState) -> Result<()> {
132 let key = state.container_id.to_key();
133 let value = serde_json::to_string(state)?;
134
135 sqlx::query(
136 r"
137 INSERT OR REPLACE INTO sync_state (container_key, state_json, updated_at)
138 VALUES (?, ?, CURRENT_TIMESTAMP)
139 ",
140 )
141 .bind(&key)
142 .bind(&value)
143 .execute(&self.pool)
144 .await
145 .map_err(|e| LayerStorageError::Database(e.to_string()))?;
146
147 Ok(())
148 }
149
150 #[instrument(skip(self))]
156 pub async fn register_container(&self, container_id: ContainerLayerId) -> Result<()> {
157 let key = container_id.to_key();
158 let mut states = self.states.write().await;
159
160 if let std::collections::hash_map::Entry::Vacant(e) = states.entry(key) {
161 let state = SyncState::new(container_id);
162 self.save_state(&state).await?;
163 e.insert(state);
164 info!("Registered new container for layer sync");
165 }
166
167 Ok(())
168 }
169
170 #[instrument(skip(self, upper_layer_path))]
177 pub async fn check_for_changes(
178 &self,
179 container_id: &ContainerLayerId,
180 upper_layer_path: impl AsRef<Path>,
181 ) -> Result<bool> {
182 let key = container_id.to_key();
183 let states = self.states.read().await;
184
185 let state = states
186 .get(&key)
187 .ok_or_else(|| LayerStorageError::NotFound(key.clone()))?;
188
189 let current_digest = calculate_directory_digest(upper_layer_path)?;
191
192 Ok(state.local_digest.as_ref() != Some(¤t_digest))
194 }
195
196 #[instrument(skip(self, upper_layer_path), fields(container = %container_id))]
202 pub async fn sync_layer(
203 &self,
204 container_id: &ContainerLayerId,
205 upper_layer_path: impl AsRef<Path>,
206 ) -> Result<Option<LayerSnapshot>> {
207 let upper_layer_path = upper_layer_path.as_ref();
208 let key = container_id.to_key();
209
210 {
212 let states = self.states.read().await;
213 if let Some(state) = states.get(&key) {
214 if let Some(pending) = &state.pending_upload {
215 info!("Found pending upload, attempting to resume");
216 return self.resume_upload(container_id, pending.clone()).await;
217 }
218 }
219 }
220
221 let current_digest = calculate_directory_digest(upper_layer_path)?;
223
224 {
226 let states = self.states.read().await;
227 if let Some(state) = states.get(&key) {
228 if state.remote_digest.as_ref() == Some(¤t_digest) {
229 debug!("Layer already synced, no changes");
230 return Ok(None);
231 }
232 }
233 }
234
235 let tarball_path = self
237 .config
238 .staging_dir
239 .join(format!("{current_digest}.tar.zst"));
240
241 let snapshot = tokio::task::spawn_blocking({
242 let source = upper_layer_path.to_path_buf();
243 let output = tarball_path.clone();
244 let level = self.config.compression_level;
245 move || create_snapshot(source, output, level)
246 })
247 .await
248 .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
249
250 self.upload_snapshot(container_id, &tarball_path, &snapshot)
252 .await?;
253
254 {
256 let mut states = self.states.write().await;
257 if let Some(state) = states.get_mut(&key) {
258 state.local_digest = Some(snapshot.digest.clone());
259 state.remote_digest = Some(snapshot.digest.clone());
260 state.last_sync = Some(chrono::Utc::now());
261 state.pending_upload = None;
262 self.save_state(state).await?;
263 }
264 }
265
266 let _ = tokio::fs::remove_file(&tarball_path).await;
268
269 Ok(Some(snapshot))
270 }
271
272 #[allow(clippy::cast_possible_wrap)]
274 #[instrument(skip(self, tarball_path, snapshot))]
275 async fn upload_snapshot(
276 &self,
277 container_id: &ContainerLayerId,
278 tarball_path: &Path,
279 snapshot: &LayerSnapshot,
280 ) -> Result<()> {
281 let object_key = self.config.object_key(&snapshot.digest);
282 let file_size = tokio::fs::metadata(tarball_path).await?.len();
283 let part_size = self.config.part_size_bytes;
284 #[allow(clippy::cast_possible_truncation)]
285 let total_parts = file_size.div_ceil(part_size) as u32;
286
287 info!(
288 "Uploading {} ({} bytes) in {} parts",
289 object_key, file_size, total_parts
290 );
291
292 let create_response = self
294 .s3_client
295 .create_multipart_upload()
296 .bucket(&self.config.bucket)
297 .key(&object_key)
298 .content_type("application/zstd")
299 .send()
300 .await
301 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
302
303 let upload_id = create_response
304 .upload_id()
305 .ok_or_else(|| LayerStorageError::S3("No upload ID returned".to_string()))?
306 .to_string();
307
308 let pending = PendingUpload {
310 upload_id: upload_id.clone(),
311 object_key: object_key.clone(),
312 total_parts,
313 completed_parts: HashMap::new(),
314 part_size,
315 local_tarball_path: tarball_path.to_path_buf(),
316 started_at: chrono::Utc::now(),
317 digest: snapshot.digest.clone(),
318 };
319
320 {
321 let key = container_id.to_key();
322 let mut states = self.states.write().await;
323 if let Some(state) = states.get_mut(&key) {
324 state.pending_upload = Some(pending.clone());
325 self.save_state(state).await?;
326 }
327 }
328
329 let completed_parts = self
331 .upload_parts(
332 tarball_path,
333 &upload_id,
334 &object_key,
335 total_parts,
336 part_size,
337 )
338 .await?;
339
340 let completed_upload = CompletedMultipartUpload::builder()
342 .set_parts(Some(
343 completed_parts
344 .into_iter()
345 .map(|(num, etag)| {
346 S3CompletedPart::builder()
347 .part_number(num as i32)
348 .e_tag(etag)
349 .build()
350 })
351 .collect(),
352 ))
353 .build();
354
355 self.s3_client
356 .complete_multipart_upload()
357 .bucket(&self.config.bucket)
358 .key(&object_key)
359 .upload_id(&upload_id)
360 .multipart_upload(completed_upload)
361 .send()
362 .await
363 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
364
365 let metadata_key = self.config.metadata_key(&snapshot.digest);
367 let metadata_json = serde_json::to_vec(snapshot)?;
368
369 self.s3_client
370 .put_object()
371 .bucket(&self.config.bucket)
372 .key(&metadata_key)
373 .body(ByteStream::from(metadata_json))
374 .content_type("application/json")
375 .send()
376 .await
377 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
378
379 info!("Upload complete: {}", object_key);
380 Ok(())
381 }
382
383 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
385 async fn upload_parts(
386 &self,
387 tarball_path: &Path,
388 upload_id: &str,
389 object_key: &str,
390 total_parts: u32,
391 part_size: u64,
392 ) -> Result<Vec<(u32, String)>> {
393 let mut completed = Vec::new();
394
395 for part_number in 1..=total_parts {
396 let offset = (u64::from(part_number) - 1) * part_size;
397
398 let mut file = File::open(tarball_path).await?;
400 file.seek(std::io::SeekFrom::Start(offset)).await?;
401
402 let mut buffer = vec![0u8; part_size as usize];
403 let bytes_read = file.read(&mut buffer).await?;
404 buffer.truncate(bytes_read);
405
406 let response = self
408 .s3_client
409 .upload_part()
410 .bucket(&self.config.bucket)
411 .key(object_key)
412 .upload_id(upload_id)
413 .part_number(part_number as i32)
414 .body(ByteStream::from(buffer))
415 .send()
416 .await
417 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
418
419 let etag = response
420 .e_tag()
421 .ok_or_else(|| LayerStorageError::S3("No ETag returned for part".to_string()))?
422 .to_string();
423
424 debug!("Uploaded part {}/{}: {}", part_number, total_parts, etag);
425 completed.push((part_number, etag));
426 }
427
428 Ok(completed)
429 }
430
431 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
433 #[instrument(skip(self, pending))]
434 async fn resume_upload(
435 &self,
436 container_id: &ContainerLayerId,
437 pending: PendingUpload,
438 ) -> Result<Option<LayerSnapshot>> {
439 let missing = pending.missing_parts();
440
441 if missing.is_empty() {
442 info!("All parts uploaded, completing multipart upload");
444 } else {
445 info!("Resuming upload, {} parts remaining", missing.len());
446
447 if !pending.local_tarball_path.exists() {
449 warn!("Local tarball missing, aborting upload and starting fresh");
450 self.abort_upload(&pending).await?;
451
452 let key = container_id.to_key();
453 let mut states = self.states.write().await;
454 if let Some(state) = states.get_mut(&key) {
455 state.pending_upload = None;
456 self.save_state(state).await?;
457 }
458
459 return Err(LayerStorageError::UploadInterrupted(
460 "Local tarball missing".to_string(),
461 ));
462 }
463
464 for part_number in missing {
466 let offset = (u64::from(part_number) - 1) * pending.part_size;
467
468 let mut file = File::open(&pending.local_tarball_path).await?;
469 file.seek(std::io::SeekFrom::Start(offset)).await?;
470
471 let mut buffer = vec![0u8; pending.part_size as usize];
472 let bytes_read = file.read(&mut buffer).await?;
473 buffer.truncate(bytes_read);
474
475 let response = self
476 .s3_client
477 .upload_part()
478 .bucket(&self.config.bucket)
479 .key(&pending.object_key)
480 .upload_id(&pending.upload_id)
481 .part_number(part_number as i32)
482 .body(ByteStream::from(buffer))
483 .send()
484 .await
485 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
486
487 let etag = response
488 .e_tag()
489 .ok_or_else(|| LayerStorageError::S3("No ETag returned".to_string()))?
490 .to_string();
491
492 debug!("Uploaded part {}: {}", part_number, etag);
493 }
494 }
495
496 let parts_response = self
498 .s3_client
499 .list_parts()
500 .bucket(&self.config.bucket)
501 .key(&pending.object_key)
502 .upload_id(&pending.upload_id)
503 .send()
504 .await
505 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
506
507 let completed_parts: Vec<S3CompletedPart> = parts_response
508 .parts()
509 .iter()
510 .map(|p| {
511 S3CompletedPart::builder()
512 .part_number(p.part_number().unwrap_or(0))
513 .e_tag(p.e_tag().unwrap_or_default())
514 .build()
515 })
516 .collect();
517
518 let completed_upload = CompletedMultipartUpload::builder()
520 .set_parts(Some(completed_parts))
521 .build();
522
523 self.s3_client
524 .complete_multipart_upload()
525 .bucket(&self.config.bucket)
526 .key(&pending.object_key)
527 .upload_id(&pending.upload_id)
528 .multipart_upload(completed_upload)
529 .send()
530 .await
531 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
532
533 let key = container_id.to_key();
535 {
536 let mut states = self.states.write().await;
537 if let Some(state) = states.get_mut(&key) {
538 state.local_digest = Some(pending.digest.clone());
539 state.remote_digest = Some(pending.digest.clone());
540 state.last_sync = Some(chrono::Utc::now());
541 state.pending_upload = None;
542 self.save_state(state).await?;
543 }
544 }
545
546 let _ = tokio::fs::remove_file(&pending.local_tarball_path).await;
548
549 info!("Upload resumed and completed successfully");
550
551 self.get_snapshot_metadata(&pending.digest).await.map(Some)
553 }
554
555 async fn abort_upload(&self, pending: &PendingUpload) -> Result<()> {
557 self.s3_client
558 .abort_multipart_upload()
559 .bucket(&self.config.bucket)
560 .key(&pending.object_key)
561 .upload_id(&pending.upload_id)
562 .send()
563 .await
564 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
565
566 Ok(())
567 }
568
569 #[instrument(skip(self, target_path))]
576 pub async fn restore_layer(
577 &self,
578 container_id: &ContainerLayerId,
579 target_path: impl AsRef<Path>,
580 ) -> Result<LayerSnapshot> {
581 let target_path = target_path.as_ref();
582 let key = container_id.to_key();
583
584 let remote_digest = {
586 let states = self.states.read().await;
587 states
588 .get(&key)
589 .and_then(|s| s.remote_digest.clone())
590 .ok_or_else(|| LayerStorageError::NotFound(format!("No remote layer for {key}")))?
591 };
592
593 info!("Restoring layer {} from S3", remote_digest);
594
595 let tarball_path = self
597 .config
598 .staging_dir
599 .join(format!("{remote_digest}.tar.zst"));
600
601 let object_key = self.config.object_key(&remote_digest);
602 let response = self
603 .s3_client
604 .get_object()
605 .bucket(&self.config.bucket)
606 .key(&object_key)
607 .send()
608 .await
609 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
610
611 let mut file = tokio::fs::File::create(&tarball_path).await?;
613 let mut stream = response.body.into_async_read();
614 tokio::io::copy(&mut stream, &mut file).await?;
615
616 let snapshot = self.get_snapshot_metadata(&remote_digest).await?;
618
619 tokio::task::spawn_blocking({
621 let tarball = tarball_path.clone();
622 let target = target_path.to_path_buf();
623 let digest = remote_digest.clone();
624 move || extract_snapshot(tarball, target, Some(&digest))
625 })
626 .await
627 .map_err(|e| LayerStorageError::Io(std::io::Error::other(e)))??;
628
629 {
631 let mut states = self.states.write().await;
632 if let Some(state) = states.get_mut(&key) {
633 state.local_digest = Some(remote_digest);
634 self.save_state(state).await?;
635 }
636 }
637
638 let _ = tokio::fs::remove_file(&tarball_path).await;
640
641 info!("Layer restored successfully");
642 Ok(snapshot)
643 }
644
645 async fn get_snapshot_metadata(&self, digest: &str) -> Result<LayerSnapshot> {
647 let metadata_key = self.config.metadata_key(digest);
648
649 let response = self
650 .s3_client
651 .get_object()
652 .bucket(&self.config.bucket)
653 .key(&metadata_key)
654 .send()
655 .await
656 .map_err(|e| LayerStorageError::S3(e.to_string()))?;
657
658 let bytes = response
659 .body
660 .collect()
661 .await
662 .map_err(|e| LayerStorageError::S3(e.to_string()))?
663 .into_bytes();
664
665 serde_json::from_slice(&bytes).map_err(Into::into)
666 }
667
668 pub async fn list_containers(&self) -> Vec<ContainerLayerId> {
670 let states = self.states.read().await;
671 states.values().map(|s| s.container_id.clone()).collect()
672 }
673
674 pub async fn get_sync_state(&self, container_id: &ContainerLayerId) -> Option<SyncState> {
676 let states = self.states.read().await;
677 states.get(&container_id.to_key()).cloned()
678 }
679}
680
681use tokio::io::AsyncSeekExt;