1mod backend;
2pub mod cas;
3mod nbd;
4
5pub use backend::FlatFileBackend;
6pub use cas::{CasBackend, ChunkIndex, ChunkStore, LocalChunkStore};
7pub use nbd::NbdBackend;
8
9use std::os::unix::net::UnixListener;
10use std::path::Path;
11use std::sync::atomic::{AtomicI32, Ordering};
12use std::sync::Arc;
13
14use anyhow::{Context, Result};
15use tracing::{debug, info, warn};
16
17pub struct NbdHandle {
19 socket_path: String,
20 shutdown: Option<std::sync::mpsc::Sender<()>>,
21 thread: Option<std::thread::JoinHandle<()>>,
22 cas_backend: Option<Arc<CasBackend>>,
24 active_fd: Arc<AtomicI32>,
27}
28
29impl NbdHandle {
30 pub fn uri(&self) -> String {
32 format!("nbd+unix:///export?socket={}", self.socket_path)
33 }
34
35 pub fn save_checkpoint(&self, index_path: &str) -> Result<()> {
37 let backend = self.cas_backend.as_ref()
38 .ok_or_else(|| anyhow::anyhow!("save_checkpoint requires CAS backend"))?;
39 backend.save_index(index_path)
40 }
41}
42
43impl Drop for NbdHandle {
44 fn drop(&mut self) {
45 if let Some(ref backend) = self.cas_backend {
47 let _ = backend.flush();
48 }
49 if let Some(tx) = self.shutdown.take() {
51 let _ = tx.send(());
52 }
53 let fd = self.active_fd.load(Ordering::Acquire);
55 if fd >= 0 {
56 unsafe { libc::shutdown(fd, libc::SHUT_RDWR); }
57 }
58 let _ = std::os::unix::net::UnixStream::connect(&self.socket_path);
60 if let Some(thread) = self.thread.take() {
61 let _ = thread.join();
62 }
63 let _ = std::fs::remove_file(&self.socket_path);
64 }
65}
66
67fn start_nbd_with_backend(
68 backend: Arc<dyn NbdBackend>,
69 socket_path: &str,
70 cas_backend: Option<Arc<CasBackend>>,
71) -> Result<NbdHandle> {
72 let _ = std::fs::remove_file(socket_path);
73 let listener = UnixListener::bind(socket_path)
74 .with_context(|| format!("failed to bind NBD socket: {}", socket_path))?;
75 let (shutdown_tx, shutdown_rx) = std::sync::mpsc::channel::<()>();
77 let socket_path_owned = socket_path.to_string();
78 let active_fd = Arc::new(AtomicI32::new(-1));
79 let active_fd_thread = active_fd.clone();
80
81 let thread = std::thread::Builder::new()
82 .name("shuru-nbd".into())
83 .spawn(move || {
84 info!("NBD server listening on {}", socket_path_owned);
85 loop {
86 match listener.accept() {
87 Ok((stream, _)) => {
88 if shutdown_rx.try_recv().is_ok() {
89 debug!("NBD server shutting down");
90 break;
91 }
92 use std::os::unix::io::AsRawFd;
95 let fd = stream.as_raw_fd();
96 active_fd_thread.store(fd, Ordering::Release);
97 info!("NBD client connected");
98 if let Err(e) = nbd::handle_client(stream, backend.clone()) {
99 warn!("NBD client session ended: {}", e);
100 }
101 active_fd_thread.store(-1, Ordering::Release);
102 debug!("NBD client disconnected, waiting for reconnect...");
103 }
104 Err(e) => {
105 if shutdown_rx.try_recv().is_ok() {
106 break;
107 }
108 warn!("NBD accept error: {}", e);
109 }
110 }
111 }
112 info!("NBD server stopped");
113 })?;
114
115 Ok(NbdHandle {
116 socket_path: socket_path.to_string(),
117 shutdown: Some(shutdown_tx),
118 thread: Some(thread),
119 cas_backend,
120 active_fd,
121 })
122}
123
124pub fn start_cas_nbd_server(
127 rootfs_path: &str,
128 cas_dir: &str,
129 index_path: &str,
130 socket_path: &str,
131 disk_size: u64,
132) -> Result<NbdHandle> {
133 let store: Box<dyn ChunkStore> = Box::new(LocalChunkStore::open(cas_dir)?);
134
135 let (index, fallback, source_idx) = if Path::new(index_path).exists() {
136 info!("loading CAS index from {}", index_path);
137 let idx = ChunkIndex::load(index_path)?;
138 let fb = idx.fallback_path.as_ref().and_then(|p| {
140 FlatFileBackend::open(p).ok()
141 });
142 (idx, fb, Some(index_path.to_string()))
143 } else {
144 let fb = FlatFileBackend::open(rootfs_path)
146 .with_context(|| format!("failed to open rootfs for lazy ingestion: {}", rootfs_path))?;
147 let disk_size = fb.size();
148 info!("CAS: lazy mode, {} MB rootfs", disk_size / (1024 * 1024));
149 (ChunkIndex::new(disk_size), Some(fb), None)
150 };
151
152 let mut backend = if let Some(fb) = fallback {
153 CasBackend::with_fallback(store, index, fb)
154 } else {
155 CasBackend::new(store, index)
156 };
157 backend.source_index_path = source_idx;
158 if disk_size > 0 {
159 backend.set_disk_size(disk_size);
160 }
161 let cas = Arc::new(backend);
162 start_nbd_with_backend(cas.clone(), socket_path, Some(cas))
163}