1use std::io::{Read, Write};
2use std::sync::Arc;
3
4use tracing::{debug, warn};
5
/// Storage that can be exported to an NBD client.
///
/// [`handle_client`] receives the backend as `Arc<dyn NbdBackend>`; the
/// `Send + Sync` supertraits are what allow such a handle to be shared
/// between threads.
pub trait NbdBackend: Send + Sync {
    /// Size of the export. This value is advertised to the client during the
    /// handshake (presumably in bytes — confirm against the implementations).
    fn size(&self) -> u64;
    /// Reads into `buf` starting at `offset`.
    ///
    /// Returns the count reported by the backend — presumably the number of
    /// bytes actually read, which may be smaller than `buf.len()`.
    fn read(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<usize>;
    /// Writes `buf` starting at `offset`.
    ///
    /// Returns the count reported by the backend — presumably the number of
    /// bytes actually written.
    fn write(&self, offset: u64, buf: &[u8]) -> std::io::Result<usize>;
    /// Flushes the backend; invoked when the client sends `NBD_CMD_FLUSH`.
    fn flush(&self) -> std::io::Result<()>;
}
13
/// Exposes `FlatFileBackend` through the [`NbdBackend`] trait by forwarding
/// each call to the method of the same name on the concrete type.
///
/// NOTE(review): every body below is a `self.<name>(..)` call with the same
/// name as the trait method being implemented. This only forwards if
/// `FlatFileBackend` has inherent methods with these names (inherent methods
/// win over trait methods during method resolution); without them the calls
/// would recurse into this impl. Confirm against `crate::backend`.
impl NbdBackend for crate::backend::FlatFileBackend {
    fn size(&self) -> u64 { self.size() }
    fn read(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<usize> { self.read(offset, buf) }
    fn write(&self, offset: u64, buf: &[u8]) -> std::io::Result<usize> { self.write(offset, buf) }
    fn flush(&self) -> std::io::Result<()> { self.flush() }
}
20
/// Exposes `CasBackend` through the [`NbdBackend`] trait by forwarding each
/// call to the method of the same name on the concrete type.
///
/// NOTE(review): every body below is a `self.<name>(..)` call with the same
/// name as the trait method being implemented. This only forwards if
/// `CasBackend` has inherent methods with these names (inherent methods win
/// over trait methods during method resolution); without them the calls would
/// recurse into this impl. Confirm against `crate::cas`.
impl NbdBackend for crate::cas::CasBackend {
    fn size(&self) -> u64 { self.size() }
    fn read(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<usize> { self.read(offset, buf) }
    fn write(&self, offset: u64, buf: &[u8]) -> std::io::Result<usize> { self.write(offset, buf) }
    fn flush(&self) -> std::io::Result<()> { self.flush() }
}
27
// Handshake magic numbers. `NBDMAGIC` and `IHAVEOPT` are the ASCII bytes of
// their own names; the server sends both at the start of the handshake and
// every client option must begin with `IHAVEOPT`.
const NBDMAGIC: u64 = 0x4e42444d41474943;
const IHAVEOPT: u64 = 0x49484156454F5054;
// First eight bytes of every option reply written by `send_option_reply`.
const REPLY_MAGIC: u64 = 0x3e889045565a9;

// Server handshake flags (16-bit field sent right after the two magics).
const NBD_FLAG_FIXED_NEWSTYLE: u16 = 1 << 0;
const NBD_FLAG_NO_ZEROES: u16 = 1 << 1;

// Client flag (32-bit field read back from the client). When set, the
// 124 zero bytes after the `NBD_OPT_EXPORT_NAME` reply are omitted.
const NBD_FLAG_C_NO_ZEROES: u32 = 1 << 1;

// Transmission flags advertised together with the export size.
const NBD_FLAG_HAS_FLAGS: u16 = 1 << 0;
const NBD_FLAG_SEND_FLUSH: u16 = 1 << 2;

// Option codes the handshake loop recognises; anything else is answered
// with `NBD_REP_ERR_UNSUP`.
const NBD_OPT_EXPORT_NAME: u32 = 1;
const NBD_OPT_ABORT: u32 = 2;
const NBD_OPT_INFO: u32 = 6;
const NBD_OPT_GO: u32 = 7;

// Option reply types. The top bit marks the reply as an error.
const NBD_REP_ACK: u32 = 1;
const NBD_REP_INFO: u32 = 3;
const NBD_REP_ERR_UNSUP: u32 = (1 << 31) | 1;

// Information type placed in the first two bytes of an `NBD_REP_INFO` payload.
const NBD_INFO_EXPORT: u16 = 0;

// Command types dispatched by the transmission loop.
const NBD_CMD_READ: u16 = 0;
const NBD_CMD_WRITE: u16 = 1;
const NBD_CMD_DISC: u16 = 2;
const NBD_CMD_FLUSH: u16 = 3;

// First four bytes of every reply written by `send_reply`.
const NBD_SIMPLE_REPLY_MAGIC: u32 = 0x67446698;

// Error codes carried in the reply's error field.
const NBD_OK: u32 = 0;
const NBD_EIO: u32 = 5;
const NBD_EINVAL: u32 = 22;
71
72pub fn handle_client(
74 mut stream: std::os::unix::net::UnixStream,
75 backend: Arc<dyn NbdBackend>,
76) -> anyhow::Result<()> {
77 handshake(&mut stream, backend.as_ref())?;
78 transmission(&mut stream, backend.as_ref())?;
79 Ok(())
80}
81
82fn handshake(
83 stream: &mut std::os::unix::net::UnixStream,
84 backend: &dyn NbdBackend,
85) -> anyhow::Result<()> {
86 stream.write_all(&NBDMAGIC.to_be_bytes())?;
88 stream.write_all(&IHAVEOPT.to_be_bytes())?;
89 let server_flags = NBD_FLAG_FIXED_NEWSTYLE | NBD_FLAG_NO_ZEROES;
90 stream.write_all(&server_flags.to_be_bytes())?;
91 stream.flush()?;
92
93 let mut buf = [0u8; 4];
95 stream.read_exact(&mut buf)?;
96 let client_flags = u32::from_be_bytes(buf);
97 let no_zeroes = (client_flags & NBD_FLAG_C_NO_ZEROES) != 0;
98
99 loop {
101 let mut opt_header = [0u8; 16];
103 stream.read_exact(&mut opt_header)?;
104 let magic = u64::from_be_bytes(opt_header[0..8].try_into().unwrap());
105 if magic != IHAVEOPT {
106 anyhow::bail!("bad option magic: {:#x}", magic);
107 }
108 let option = u32::from_be_bytes(opt_header[8..12].try_into().unwrap());
109 let data_len = u32::from_be_bytes(opt_header[12..16].try_into().unwrap());
110
111 let mut opt_data = vec![0u8; data_len as usize];
113 if data_len > 0 {
114 stream.read_exact(&mut opt_data)?;
115 }
116
117 match option {
118 NBD_OPT_EXPORT_NAME => {
119 let trans_flags = NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_FLUSH;
121 stream.write_all(&backend.size().to_be_bytes())?;
122 stream.write_all(&trans_flags.to_be_bytes())?;
123 if !no_zeroes {
124 stream.write_all(&[0u8; 124])?;
125 }
126 stream.flush()?;
127 debug!("NBD handshake complete (EXPORT_NAME), size={}", backend.size());
128 return Ok(());
129 }
130 NBD_OPT_INFO | NBD_OPT_GO => {
131 send_export_info(stream, option, backend)?;
132 if option == NBD_OPT_GO {
133 debug!("NBD handshake complete (GO), size={}", backend.size());
134 return Ok(());
135 }
136 debug!("NBD INFO reply sent, size={}", backend.size());
137 }
138 NBD_OPT_ABORT => {
139 send_option_reply(stream, option, NBD_REP_ACK, &[])?;
140 stream.flush()?;
141 anyhow::bail!("client aborted");
142 }
143 _ => {
144 debug!("unsupported NBD option: {}", option);
145 send_option_reply(stream, option, NBD_REP_ERR_UNSUP, &[])?;
146 stream.flush()?;
147 }
148 }
149 }
150}
151
152fn send_export_info(
153 stream: &mut std::os::unix::net::UnixStream,
154 option: u32,
155 backend: &dyn NbdBackend,
156) -> std::io::Result<()> {
157 let trans_flags = NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_FLUSH;
158 let mut info = [0u8; 12];
159 info[0..2].copy_from_slice(&NBD_INFO_EXPORT.to_be_bytes());
160 info[2..10].copy_from_slice(&backend.size().to_be_bytes());
161 info[10..12].copy_from_slice(&trans_flags.to_be_bytes());
162 send_option_reply(stream, option, NBD_REP_INFO, &info)?;
163 send_option_reply(stream, option, NBD_REP_ACK, &[])?;
164 stream.flush()
165}
166
167fn send_option_reply(
168 stream: &mut std::os::unix::net::UnixStream,
169 option: u32,
170 reply_type: u32,
171 data: &[u8],
172) -> std::io::Result<()> {
173 stream.write_all(&REPLY_MAGIC.to_be_bytes())?;
174 stream.write_all(&option.to_be_bytes())?;
175 stream.write_all(&reply_type.to_be_bytes())?;
176 stream.write_all(&(data.len() as u32).to_be_bytes())?;
177 if !data.is_empty() {
178 stream.write_all(data)?;
179 }
180 Ok(())
181}
182
183fn transmission(
184 stream: &mut std::os::unix::net::UnixStream,
185 backend: &dyn NbdBackend,
186) -> anyhow::Result<()> {
187 let mut req_header = [0u8; 28];
188
189 loop {
190 if let Err(e) = stream.read_exact(&mut req_header) {
191 debug!("NBD session ended while reading command: {}", e);
192 return Ok(());
193 }
194
195 let magic = u32::from_be_bytes(req_header[0..4].try_into().unwrap());
196 if magic != 0x25609513 {
197 anyhow::bail!("bad request magic: {:#x}", magic);
198 }
199
200 let cmd_type = u16::from_be_bytes(req_header[6..8].try_into().unwrap());
202 let handle = &req_header[8..16];
203 let offset = u64::from_be_bytes(req_header[16..24].try_into().unwrap());
204 let length = u32::from_be_bytes(req_header[24..28].try_into().unwrap());
205
206 match cmd_type {
207 NBD_CMD_READ => {
208 let mut buf = vec![0u8; length as usize];
209 let error = match backend.read(offset, &mut buf) {
210 Ok(n) => {
211 if n < length as usize {
212 buf[n..].fill(0);
213 }
214 NBD_OK
215 }
216 Err(e) => {
217 warn!("NBD read error at offset {}: {}", offset, e);
218 NBD_EIO
219 }
220 };
221 send_reply(stream, error, handle, if error == NBD_OK { Some(&buf) } else { None })?;
222 }
223 NBD_CMD_WRITE => {
224 let mut data = vec![0u8; length as usize];
225 stream.read_exact(&mut data)?;
226 let error = match backend.write(offset, &data) {
227 Ok(_) => NBD_OK,
228 Err(e) => {
229 warn!("NBD write error at offset {}: {}", offset, e);
230 NBD_EIO
231 }
232 };
233 send_reply(stream, error, handle, None)?;
234 }
235 NBD_CMD_FLUSH => {
236 let error = match backend.flush() {
237 Ok(()) => NBD_OK,
238 Err(e) => {
239 warn!("NBD flush error: {}", e);
240 NBD_EIO
241 }
242 };
243 send_reply(stream, error, handle, None)?;
244 }
245 NBD_CMD_DISC => {
246 debug!("NBD client sent disconnect");
247 return Ok(());
248 }
249 _ => {
250 warn!("unsupported NBD command: {}", cmd_type);
251 send_reply(stream, NBD_EINVAL, handle, None)?;
252 }
253 }
254 }
255}
256
257fn send_reply(
258 stream: &mut std::os::unix::net::UnixStream,
259 error: u32,
260 handle: &[u8],
261 data: Option<&[u8]>,
262) -> std::io::Result<()> {
263 stream.write_all(&NBD_SIMPLE_REPLY_MAGIC.to_be_bytes())?;
264 stream.write_all(&error.to_be_bytes())?;
265 stream.write_all(handle)?;
266 if let Some(data) = data {
267 stream.write_all(data)?;
268 }
269 stream.flush()?;
270 Ok(())
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FlatFileBackend;
    use std::io::Write;

    /// Builds a 1 MiB temporary file filled with `0xAB` and opens it as a
    /// `FlatFileBackend`. The temp file handle is returned alongside the
    /// backend so the caller can keep it alive for the test's duration.
    fn create_test_backend() -> (tempfile::NamedTempFile, Arc<FlatFileBackend>) {
        let mut image = tempfile::NamedTempFile::new().unwrap();
        let fill = vec![0xABu8; 1024 * 1024];
        image.write_all(&fill).unwrap();
        image.flush().unwrap();

        let opened = FlatFileBackend::open(image.path().to_str().unwrap()).unwrap();
        (image, Arc::new(opened))
    }

    /// Data written at offset 0 reads back, and the neighbouring bytes keep
    /// their original fill pattern.
    #[test]
    fn test_backend_read_write() {
        let (_tmp, backend) = create_test_backend();

        // Untouched image: the fill pattern.
        let mut head = [0u8; 4];
        backend.read(0, &mut head).unwrap();
        assert_eq!(head, [0xAB; 4]);

        // Overwrite the first four bytes and read them back.
        backend.write(0, &[1, 2, 3, 4]).unwrap();
        backend.read(0, &mut head).unwrap();
        assert_eq!(head, [1, 2, 3, 4]);

        // The next four bytes must be unaffected by the write.
        let mut neighbour = [0u8; 4];
        backend.read(4, &mut neighbour).unwrap();
        assert_eq!(neighbour, [0xAB; 4]);
    }
}