Skip to main content

shuru_store/
nbd.rs

1use std::io::{Read, Write};
2use std::sync::Arc;
3
4use tracing::{debug, warn};
5
/// Storage abstraction that the NBD server in this module exposes to clients.
///
/// Implementors must be `Send + Sync`. Every method takes `&self`, so any
/// mutation has to happen through interior mutability on the implementor's
/// side.
pub trait NbdBackend: Send + Sync {
    /// Size of the export in bytes, as advertised to the client during the
    /// handshake.
    fn size(&self) -> u64;

    /// Reads into `buf` starting at byte `offset` and returns a byte count.
    ///
    /// The request loop treats a count smaller than `buf.len()` as a short
    /// read and serves zeroes for the remainder.
    fn read(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<usize>;

    /// Writes `buf` starting at byte `offset` and returns a byte count
    /// (presumably the number of bytes written; confirm with implementors).
    fn write(&self, offset: u64, buf: &[u8]) -> std::io::Result<usize>;

    /// Flushes the backend; invoked when the client sends the flush command.
    fn flush(&self) -> std::io::Result<()>;
}
13
14impl NbdBackend for crate::backend::FlatFileBackend {
15    fn size(&self) -> u64 { self.size() }
16    fn read(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<usize> { self.read(offset, buf) }
17    fn write(&self, offset: u64, buf: &[u8]) -> std::io::Result<usize> { self.write(offset, buf) }
18    fn flush(&self) -> std::io::Result<()> { self.flush() }
19}
20
21impl NbdBackend for crate::cas::CasBackend {
22    fn size(&self) -> u64 { self.size() }
23    fn read(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<usize> { self.read(offset, buf) }
24    fn write(&self, offset: u64, buf: &[u8]) -> std::io::Result<usize> { self.write(offset, buf) }
25    fn flush(&self) -> std::io::Result<()> { self.flush() }
26}
27
// ---- Magic numbers --------------------------------------------------------

/// ASCII `NBDMAGIC`; the first eight bytes the server writes in the handshake.
const NBDMAGIC: u64 = u64::from_be_bytes(*b"NBDMAGIC");
/// ASCII `IHAVEOPT`; follows `NBDMAGIC` in the server greeting and must prefix
/// every option header the client sends.
const IHAVEOPT: u64 = u64::from_be_bytes(*b"IHAVEOPT");
/// Magic that starts every option reply written by `send_option_reply`.
const REPLY_MAGIC: u64 = 0x0003_e889_0455_65a9;

// ---- Handshake flags (server greeting, 16 bit) ----------------------------

/// Advertised by the server in its greeting.
const NBD_FLAG_FIXED_NEWSTYLE: u16 = 0x0001;
/// Advertised by the server in its greeting, together with the flag above.
const NBD_FLAG_NO_ZEROES: u16 = 0x0002;

// ---- Client flags (32 bit) ------------------------------------------------

/// When the client sets this bit, the 124 zero bytes that otherwise end the
/// `NBD_OPT_EXPORT_NAME` response are not sent.
const NBD_FLAG_C_NO_ZEROES: u32 = 0x0000_0002;

// ---- Transmission flags (sent together with the export size) --------------

/// Always set in the transmission flags this server reports.
const NBD_FLAG_HAS_FLAGS: u16 = 0x0001;
/// Always set as well; `NBD_CMD_FLUSH` is handled by the request loop.
const NBD_FLAG_SEND_FLUSH: u16 = 0x0004;

// ---- Option types (client -> server during negotiation) -------------------

/// Answered with the export size and flags directly, without a reply header.
const NBD_OPT_EXPORT_NAME: u32 = 1;
/// Acknowledged, after which the handshake fails with "client aborted".
const NBD_OPT_ABORT: u32 = 2;
/// Answered with the export info; negotiation continues afterwards.
const NBD_OPT_INFO: u32 = 6;
/// Answered with the export info; negotiation ends afterwards.
const NBD_OPT_GO: u32 = 7;

// ---- Option reply types ---------------------------------------------------

/// Plain acknowledgement; also terminates the export info reply sequence.
const NBD_REP_ACK: u32 = 1;
/// Reply carrying an information payload.
const NBD_REP_INFO: u32 = 3;
/// Error reply used for every option this server does not handle
/// (bit 31 set, low value 1).
const NBD_REP_ERR_UNSUP: u32 = 0x8000_0001;

// ---- Info types -----------------------------------------------------------

/// Leading field of the info payload that holds export size and flags.
const NBD_INFO_EXPORT: u16 = 0;

// ---- Command types (transmission phase) -----------------------------------

const NBD_CMD_READ: u16 = 0;
const NBD_CMD_WRITE: u16 = 1;
const NBD_CMD_DISC: u16 = 2;
const NBD_CMD_FLUSH: u16 = 3;

// ---- Simple reply magic ---------------------------------------------------

/// Magic that starts every reply written by `send_reply`.
const NBD_SIMPLE_REPLY_MAGIC: u32 = 0x6744_6698;

// ---- Error codes placed in replies ----------------------------------------

/// Request succeeded.
const NBD_OK: u32 = 0;
/// The backend reported an I/O error.
const NBD_EIO: u32 = 5;
/// The request used a command this server does not support.
const NBD_EINVAL: u32 = 22;
71
72/// Handle one NBD client connection (blocking I/O on the stream).
73pub fn handle_client(
74    mut stream: std::os::unix::net::UnixStream,
75    backend: Arc<dyn NbdBackend>,
76) -> anyhow::Result<()> {
77    handshake(&mut stream, backend.as_ref())?;
78    transmission(&mut stream, backend.as_ref())?;
79    Ok(())
80}
81
82fn handshake(
83    stream: &mut std::os::unix::net::UnixStream,
84    backend: &dyn NbdBackend,
85) -> anyhow::Result<()> {
86    // Server sends: NBDMAGIC + IHAVEOPT + handshake flags
87    stream.write_all(&NBDMAGIC.to_be_bytes())?;
88    stream.write_all(&IHAVEOPT.to_be_bytes())?;
89    let server_flags = NBD_FLAG_FIXED_NEWSTYLE | NBD_FLAG_NO_ZEROES;
90    stream.write_all(&server_flags.to_be_bytes())?;
91    stream.flush()?;
92
93    // Client sends: client flags
94    let mut buf = [0u8; 4];
95    stream.read_exact(&mut buf)?;
96    let client_flags = u32::from_be_bytes(buf);
97    let no_zeroes = (client_flags & NBD_FLAG_C_NO_ZEROES) != 0;
98
99    // Option haggling loop
100    loop {
101        // Client sends: IHAVEOPT + option + length
102        let mut opt_header = [0u8; 16];
103        stream.read_exact(&mut opt_header)?;
104        let magic = u64::from_be_bytes(opt_header[0..8].try_into().unwrap());
105        if magic != IHAVEOPT {
106            anyhow::bail!("bad option magic: {:#x}", magic);
107        }
108        let option = u32::from_be_bytes(opt_header[8..12].try_into().unwrap());
109        let data_len = u32::from_be_bytes(opt_header[12..16].try_into().unwrap());
110
111        // Read option data (export name, info requests, etc.)
112        let mut opt_data = vec![0u8; data_len as usize];
113        if data_len > 0 {
114            stream.read_exact(&mut opt_data)?;
115        }
116
117        match option {
118            NBD_OPT_EXPORT_NAME => {
119                // Legacy negotiation: send export info directly, no reply header
120                let trans_flags = NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_FLUSH;
121                stream.write_all(&backend.size().to_be_bytes())?;
122                stream.write_all(&trans_flags.to_be_bytes())?;
123                if !no_zeroes {
124                    stream.write_all(&[0u8; 124])?;
125                }
126                stream.flush()?;
127                debug!("NBD handshake complete (EXPORT_NAME), size={}", backend.size());
128                return Ok(());
129            }
130            NBD_OPT_INFO | NBD_OPT_GO => {
131                send_export_info(stream, option, backend)?;
132                if option == NBD_OPT_GO {
133                    debug!("NBD handshake complete (GO), size={}", backend.size());
134                    return Ok(());
135                }
136                debug!("NBD INFO reply sent, size={}", backend.size());
137            }
138            NBD_OPT_ABORT => {
139                send_option_reply(stream, option, NBD_REP_ACK, &[])?;
140                stream.flush()?;
141                anyhow::bail!("client aborted");
142            }
143            _ => {
144                debug!("unsupported NBD option: {}", option);
145                send_option_reply(stream, option, NBD_REP_ERR_UNSUP, &[])?;
146                stream.flush()?;
147            }
148        }
149    }
150}
151
152fn send_export_info(
153    stream: &mut std::os::unix::net::UnixStream,
154    option: u32,
155    backend: &dyn NbdBackend,
156) -> std::io::Result<()> {
157    let trans_flags = NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_FLUSH;
158    let mut info = [0u8; 12];
159    info[0..2].copy_from_slice(&NBD_INFO_EXPORT.to_be_bytes());
160    info[2..10].copy_from_slice(&backend.size().to_be_bytes());
161    info[10..12].copy_from_slice(&trans_flags.to_be_bytes());
162    send_option_reply(stream, option, NBD_REP_INFO, &info)?;
163    send_option_reply(stream, option, NBD_REP_ACK, &[])?;
164    stream.flush()
165}
166
167fn send_option_reply(
168    stream: &mut std::os::unix::net::UnixStream,
169    option: u32,
170    reply_type: u32,
171    data: &[u8],
172) -> std::io::Result<()> {
173    stream.write_all(&REPLY_MAGIC.to_be_bytes())?;
174    stream.write_all(&option.to_be_bytes())?;
175    stream.write_all(&reply_type.to_be_bytes())?;
176    stream.write_all(&(data.len() as u32).to_be_bytes())?;
177    if !data.is_empty() {
178        stream.write_all(data)?;
179    }
180    Ok(())
181}
182
183fn transmission(
184    stream: &mut std::os::unix::net::UnixStream,
185    backend: &dyn NbdBackend,
186) -> anyhow::Result<()> {
187    let mut req_header = [0u8; 28];
188
189    loop {
190        if let Err(e) = stream.read_exact(&mut req_header) {
191            debug!("NBD session ended while reading command: {}", e);
192            return Ok(());
193        }
194
195        let magic = u32::from_be_bytes(req_header[0..4].try_into().unwrap());
196        if magic != 0x25609513 {
197            anyhow::bail!("bad request magic: {:#x}", magic);
198        }
199
200        // flags at [4..6] (ignored for now)
201        let cmd_type = u16::from_be_bytes(req_header[6..8].try_into().unwrap());
202        let handle = &req_header[8..16];
203        let offset = u64::from_be_bytes(req_header[16..24].try_into().unwrap());
204        let length = u32::from_be_bytes(req_header[24..28].try_into().unwrap());
205
206        match cmd_type {
207            NBD_CMD_READ => {
208                let mut buf = vec![0u8; length as usize];
209                let error = match backend.read(offset, &mut buf) {
210                    Ok(n) => {
211                        if n < length as usize {
212                            buf[n..].fill(0);
213                        }
214                        NBD_OK
215                    }
216                    Err(e) => {
217                        warn!("NBD read error at offset {}: {}", offset, e);
218                        NBD_EIO
219                    }
220                };
221                send_reply(stream, error, handle, if error == NBD_OK { Some(&buf) } else { None })?;
222            }
223            NBD_CMD_WRITE => {
224                let mut data = vec![0u8; length as usize];
225                stream.read_exact(&mut data)?;
226                let error = match backend.write(offset, &data) {
227                    Ok(_) => NBD_OK,
228                    Err(e) => {
229                        warn!("NBD write error at offset {}: {}", offset, e);
230                        NBD_EIO
231                    }
232                };
233                send_reply(stream, error, handle, None)?;
234            }
235            NBD_CMD_FLUSH => {
236                let error = match backend.flush() {
237                    Ok(()) => NBD_OK,
238                    Err(e) => {
239                        warn!("NBD flush error: {}", e);
240                        NBD_EIO
241                    }
242                };
243                send_reply(stream, error, handle, None)?;
244            }
245            NBD_CMD_DISC => {
246                debug!("NBD client sent disconnect");
247                return Ok(());
248            }
249            _ => {
250                warn!("unsupported NBD command: {}", cmd_type);
251                send_reply(stream, NBD_EINVAL, handle, None)?;
252            }
253        }
254    }
255}
256
257fn send_reply(
258    stream: &mut std::os::unix::net::UnixStream,
259    error: u32,
260    handle: &[u8],
261    data: Option<&[u8]>,
262) -> std::io::Result<()> {
263    stream.write_all(&NBD_SIMPLE_REPLY_MAGIC.to_be_bytes())?;
264    stream.write_all(&error.to_be_bytes())?;
265    stream.write_all(handle)?;
266    if let Some(data) = data {
267        stream.write_all(data)?;
268    }
269    stream.flush()?;
270    Ok(())
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::FlatFileBackend;
    use std::io::{Read, Write};
    use std::os::unix::net::UnixStream;

    /// Creates a 1 MiB temp file filled with `0xAB` and opens it as a backend.
    ///
    /// The temp file handle is returned as well so the file outlives the test.
    fn create_test_backend() -> (tempfile::NamedTempFile, Arc<FlatFileBackend>) {
        let mut tmp = tempfile::NamedTempFile::new().unwrap();
        let data = vec![0xABu8; 1024 * 1024]; // 1MB
        tmp.write_all(&data).unwrap();
        tmp.flush().unwrap();
        let backend = Arc::new(FlatFileBackend::open(tmp.path().to_str().unwrap()).unwrap());
        (tmp, backend)
    }

    /// Builds a 28-byte NBD request header with zeroed command flags.
    fn nbd_request(cmd: u16, handle: &[u8; 8], offset: u64, length: u32) -> [u8; 28] {
        let mut req = [0u8; 28];
        req[0..4].copy_from_slice(&0x25609513u32.to_be_bytes());
        req[6..8].copy_from_slice(&cmd.to_be_bytes());
        req[8..16].copy_from_slice(handle);
        req[16..24].copy_from_slice(&offset.to_be_bytes());
        req[24..28].copy_from_slice(&length.to_be_bytes());
        req
    }

    #[test]
    fn test_backend_read_write() {
        let (_tmp, backend) = create_test_backend();
        let mut buf = [0u8; 4];
        backend.read(0, &mut buf).unwrap();
        assert_eq!(buf, [0xAB; 4]);

        backend.write(0, &[1, 2, 3, 4]).unwrap();
        backend.read(0, &mut buf).unwrap();
        assert_eq!(buf, [1, 2, 3, 4]);

        // Bytes just past the written range still hold the fill pattern.
        let mut buf2 = [0u8; 4];
        backend.read(4, &mut buf2).unwrap();
        assert_eq!(buf2, [0xAB; 4]);
    }

    #[test]
    fn test_send_reply_wire_format() {
        let (mut tx, mut rx) = UnixStream::pair().unwrap();
        let handle = [1u8, 2, 3, 4, 5, 6, 7, 8];
        send_reply(&mut tx, NBD_EIO, &handle, Some(&[0xDEu8, 0xAD][..])).unwrap();
        drop(tx);

        let mut got = Vec::new();
        rx.read_to_end(&mut got).unwrap();

        let mut want = Vec::new();
        want.extend_from_slice(&NBD_SIMPLE_REPLY_MAGIC.to_be_bytes());
        want.extend_from_slice(&NBD_EIO.to_be_bytes());
        want.extend_from_slice(&handle);
        want.extend_from_slice(&[0xDE, 0xAD]);
        assert_eq!(got, want);
    }

    #[test]
    fn test_transmission_read_round_trip() {
        let (_tmp, backend) = create_test_backend();
        let (mut server, mut client) = UnixStream::pair().unwrap();
        let worker = std::thread::spawn(move || transmission(&mut server, &*backend));

        let handle = *b"handle01";
        client
            .write_all(&nbd_request(NBD_CMD_READ, &handle, 0, 4))
            .unwrap();

        // Reply: magic (4) + error (4) + handle (8) + 4 payload bytes.
        let mut reply = [0u8; 20];
        client.read_exact(&mut reply).unwrap();
        assert_eq!(&reply[0..4], &NBD_SIMPLE_REPLY_MAGIC.to_be_bytes()[..]);
        assert_eq!(&reply[4..8], &NBD_OK.to_be_bytes()[..]);
        assert_eq!(&reply[8..16], &handle[..]);
        assert_eq!(&reply[16..20], &[0xABu8; 4][..]);

        // A disconnect request ends the loop cleanly.
        client
            .write_all(&nbd_request(NBD_CMD_DISC, &handle, 0, 0))
            .unwrap();
        worker.join().unwrap().unwrap();
    }
}