1use crate::database::{Database, TxnHandle};
87use parking_lot::Mutex;
88use std::collections::HashMap;
89use std::io::{BufReader, BufWriter, Read, Write};
90use std::os::unix::fs::PermissionsExt;
91use std::os::unix::net::{UnixListener, UnixStream};
92use std::path::{Path, PathBuf};
93use std::sync::Arc;
94use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
95use std::thread::{self, JoinHandle};
96use std::time::{Duration, Instant};
97use thiserror::Error;
98
/// Single-byte opcodes that start every frame on the wire.
///
/// Requests occupy `0x01..=0x0F`; responses occupy `0x80..=0x87`.
mod opcode {
    // ----- requests (client -> server) -----
    /// Store a key/value pair.
    pub const PUT: u8 = 0x01;
    /// Read the value for a key.
    pub const GET: u8 = 0x02;
    /// Remove a key.
    pub const DELETE: u8 = 0x03;
    /// Open a client-held transaction.
    pub const BEGIN_TXN: u8 = 0x04;
    /// Commit a client-held transaction.
    pub const COMMIT_TXN: u8 = 0x05;
    /// Abort a client-held transaction.
    pub const ABORT_TXN: u8 = 0x06;
    /// Run a path query with limit/offset/column projection.
    pub const QUERY: u8 = 0x07;
    /// Register a table schema.
    pub const CREATE_TABLE: u8 = 0x08;
    /// Store a value under a segmented path.
    pub const PUT_PATH: u8 = 0x09;
    /// Read the value under a segmented path.
    pub const GET_PATH: u8 = 0x0A;
    /// List entries under a path prefix.
    pub const SCAN: u8 = 0x0B;
    /// Ask the database to checkpoint.
    pub const CHECKPOINT: u8 = 0x0C;
    /// Request server statistics.
    pub const STATS: u8 = 0x0D;
    /// Liveness probe.
    pub const PING: u8 = 0x0E;
    /// Submit a SQL string.
    pub const EXECUTE_SQL: u8 = 0x0F;

    // ----- responses (server -> client) -----
    /// Success with no payload.
    pub const OK: u8 = 0x80;
    /// Failure; payload is a UTF-8 message.
    pub const ERROR: u8 = 0x81;
    /// Value payload (empty when absent).
    pub const VALUE: u8 = 0x82;
    /// Little-endian `u64` identifier.
    pub const TXN_ID: u8 = 0x83;
    // Reserved for streaming responses; not emitted yet.
    #[allow(dead_code)]
    pub const ROW: u8 = 0x84;
    // Reserved for streaming responses; not emitted yet.
    #[allow(dead_code)]
    pub const END_STREAM: u8 = 0x85;
    /// JSON statistics payload.
    pub const STATS_RESP: u8 = 0x86;
    /// Reply to `PING`.
    pub const PONG: u8 = 0x87;
}
133
/// Largest payload (in bytes) a single frame may carry: 16 MiB.
const MAX_MESSAGE_SIZE: usize = 16 << 20;
136
137#[derive(Debug, Error)]
142pub enum IpcError {
143 #[error("I/O error: {0}")]
144 Io(#[from] std::io::Error),
145
146 #[error("Database error: {0}")]
147 Database(String),
148
149 #[error("Protocol error: {0}")]
150 Protocol(String),
151
152 #[error("Server already running")]
153 AlreadyRunning,
154
155 #[error("Server not running")]
156 NotRunning,
157
158 #[error("Connection closed")]
159 ConnectionClosed,
160
161 #[error("Message too large: {0} bytes (max: {1})")]
162 MessageTooLarge(usize, usize),
163
164 #[error("Invalid opcode: {0:#x}")]
165 InvalidOpcode(u8),
166
167 #[error("Transaction not found: {0}")]
168 TxnNotFound(u64),
169}
170
171pub type Result<T> = std::result::Result<T, IpcError>;
172
/// One protocol frame.
///
/// Wire form: `opcode: u8 | payload_len: u32 (little-endian) | payload`.
#[derive(Debug, Clone)]
pub struct Message {
    /// Request or response code (see the `opcode` module).
    pub opcode: u8,
    /// Opcode-specific body; may be empty.
    pub payload: Vec<u8>,
}
184impl Message {
185 pub fn new(opcode: u8, payload: Vec<u8>) -> Self {
186 Self { opcode, payload }
187 }
188
189 pub fn ok() -> Self {
190 Self::new(opcode::OK, vec![])
191 }
192
193 pub fn error(msg: &str) -> Self {
194 Self::new(opcode::ERROR, msg.as_bytes().to_vec())
195 }
196
197 pub fn value(data: Vec<u8>) -> Self {
198 Self::new(opcode::VALUE, data)
199 }
200
201 pub fn txn_id(id: u64) -> Self {
202 Self::new(opcode::TXN_ID, id.to_le_bytes().to_vec())
203 }
204
205 pub fn read_from<R: Read>(reader: &mut R) -> Result<Self> {
207 let mut opcode_buf = [0u8; 1];
209 match reader.read_exact(&mut opcode_buf) {
210 Ok(_) => {}
211 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
212 return Err(IpcError::ConnectionClosed);
213 }
214 Err(e) => return Err(e.into()),
215 }
216 let opcode = opcode_buf[0];
217
218 let mut len_buf = [0u8; 4];
220 reader.read_exact(&mut len_buf)?;
221 let len = u32::from_le_bytes(len_buf) as usize;
222
223 if len > MAX_MESSAGE_SIZE {
225 return Err(IpcError::MessageTooLarge(len, MAX_MESSAGE_SIZE));
226 }
227
228 let mut payload = vec![0u8; len];
230 if len > 0 {
231 reader.read_exact(&mut payload)?;
232 }
233
234 Ok(Self { opcode, payload })
235 }
236
237 pub fn write_to<W: Write>(&self, writer: &mut W) -> Result<()> {
239 writer.write_all(&[self.opcode])?;
241
242 let len = self.payload.len() as u32;
244 writer.write_all(&len.to_le_bytes())?;
245
246 if !self.payload.is_empty() {
248 writer.write_all(&self.payload)?;
249 }
250
251 writer.flush()?;
252 Ok(())
253 }
254}
255
/// Encodes a PUT payload: `key_len: u32 LE | key | value`.
///
/// The value is not length-prefixed; it runs to the end of the payload.
fn encode_put(key: &[u8], value: &[u8]) -> Vec<u8> {
    let key_len_le = (key.len() as u32).to_le_bytes();
    [&key_len_le[..], key, value].concat()
}
268
269fn decode_put(payload: &[u8]) -> Result<(&[u8], &[u8])> {
271 if payload.len() < 4 {
272 return Err(IpcError::Protocol("PUT payload too short".into()));
273 }
274 let key_len = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
275 if payload.len() < 4 + key_len {
276 return Err(IpcError::Protocol("PUT payload key truncated".into()));
277 }
278 let key = &payload[4..4 + key_len];
279 let value = &payload[4 + key_len..];
280 Ok((key, value))
281}
282
/// Encodes a segmented path plus value:
/// `segment_count: u16 LE | (segment_len: u16 LE | segment)* | value`.
///
/// NOTE(review): the count and each segment length are cast to `u16` unchecked.
/// Callers must ensure they fit, otherwise the frame decodes to a different path.
fn encode_put_path(path: &[&str], value: &[u8]) -> Vec<u8> {
    let mut out = (path.len() as u16).to_le_bytes().to_vec();
    path.iter().map(|segment| segment.as_bytes()).for_each(|bytes| {
        out.extend_from_slice(&(bytes.len() as u16).to_le_bytes());
        out.extend_from_slice(bytes);
    });
    out.extend_from_slice(value);
    out
}
295
296fn decode_path(payload: &[u8]) -> Result<(Vec<String>, &[u8])> {
298 if payload.len() < 2 {
299 return Err(IpcError::Protocol("Path payload too short".into()));
300 }
301 let count = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
302 let mut offset = 2;
303 let mut path = Vec::with_capacity(count);
304
305 for _ in 0..count {
306 if offset + 2 > payload.len() {
307 return Err(IpcError::Protocol("Path segment length truncated".into()));
308 }
309 let seg_len = u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
310 offset += 2;
311 if offset + seg_len > payload.len() {
312 return Err(IpcError::Protocol("Path segment truncated".into()));
313 }
314 let segment = std::str::from_utf8(&payload[offset..offset + seg_len])
315 .map_err(|_| IpcError::Protocol("Invalid UTF-8 in path".into()))?;
316 path.push(segment.to_string());
317 offset += seg_len;
318 }
319
320 Ok((path, &payload[offset..]))
321}
322
323#[derive(Debug, Default)]
328pub struct ServerStats {
329 pub connections_total: AtomicU64,
330 pub connections_active: AtomicU64,
331 pub requests_total: AtomicU64,
332 pub requests_success: AtomicU64,
333 pub requests_error: AtomicU64,
334 pub bytes_received: AtomicU64,
335 pub bytes_sent: AtomicU64,
336 pub start_time: Mutex<Option<Instant>>,
337}
338
339impl ServerStats {
340 pub fn new() -> Self {
341 Self::default()
342 }
343
344 pub fn snapshot(&self) -> ServerStatsSnapshot {
345 ServerStatsSnapshot {
346 connections_total: self.connections_total.load(Ordering::Relaxed),
347 connections_active: self.connections_active.load(Ordering::Relaxed),
348 requests_total: self.requests_total.load(Ordering::Relaxed),
349 requests_success: self.requests_success.load(Ordering::Relaxed),
350 requests_error: self.requests_error.load(Ordering::Relaxed),
351 bytes_received: self.bytes_received.load(Ordering::Relaxed),
352 bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
353 uptime_secs: self
354 .start_time
355 .lock()
356 .map(|t| t.elapsed().as_secs())
357 .unwrap_or(0),
358 }
359 }
360}
361
/// Plain-integer copy of [`ServerStats`] at one moment, with uptime in seconds.
#[derive(Debug, Clone)]
pub struct ServerStatsSnapshot {
    /// See `ServerStats::connections_total`.
    pub connections_total: u64,
    /// See `ServerStats::connections_active`.
    pub connections_active: u64,
    /// See `ServerStats::requests_total`.
    pub requests_total: u64,
    /// See `ServerStats::requests_success`.
    pub requests_success: u64,
    /// See `ServerStats::requests_error`.
    pub requests_error: u64,
    /// See `ServerStats::bytes_received`.
    pub bytes_received: u64,
    /// See `ServerStats::bytes_sent`.
    pub bytes_sent: u64,
    /// Whole seconds since the start time; 0 if the server never started.
    pub uptime_secs: u64,
}
373
374struct ClientHandler {
379 db: Arc<Database>,
380 stream: UnixStream,
381 stats: Arc<ServerStats>,
382 active_txns: HashMap<u64, TxnHandle>, next_client_txn_id: u64,
384}
385
386impl ClientHandler {
387 fn new(db: Arc<Database>, stream: UnixStream, stats: Arc<ServerStats>) -> Self {
388 Self {
389 db,
390 stream,
391 stats,
392 active_txns: HashMap::new(),
393 next_client_txn_id: 1,
394 }
395 }
396
397 fn handle(&mut self) -> Result<()> {
398 self.stream
400 .set_read_timeout(Some(Duration::from_secs(30)))?;
401
402 let mut reader = BufReader::new(self.stream.try_clone()?);
403 let mut writer = BufWriter::new(self.stream.try_clone()?);
404
405 loop {
406 let request = match Message::read_from(&mut reader) {
408 Ok(msg) => msg,
409 Err(IpcError::ConnectionClosed) => {
410 self.cleanup_transactions();
412 return Ok(());
413 }
414 Err(e) => return Err(e),
415 };
416
417 self.stats.requests_total.fetch_add(1, Ordering::Relaxed);
418 self.stats
419 .bytes_received
420 .fetch_add((5 + request.payload.len()) as u64, Ordering::Relaxed);
421
422 let response = self.process_request(&request);
424
425 if response.opcode == opcode::ERROR {
427 self.stats.requests_error.fetch_add(1, Ordering::Relaxed);
428 } else {
429 self.stats.requests_success.fetch_add(1, Ordering::Relaxed);
430 }
431
432 self.stats
434 .bytes_sent
435 .fetch_add((5 + response.payload.len()) as u64, Ordering::Relaxed);
436 response.write_to(&mut writer)?;
437 }
438 }
439
440 fn process_request(&mut self, request: &Message) -> Message {
441 match request.opcode {
442 opcode::PING => Message::new(opcode::PONG, vec![]),
443
444 opcode::PUT => self.handle_put(&request.payload),
445 opcode::GET => self.handle_get(&request.payload),
446 opcode::DELETE => self.handle_delete(&request.payload),
447
448 opcode::BEGIN_TXN => self.handle_begin_txn(),
449 opcode::COMMIT_TXN => self.handle_commit_txn(&request.payload),
450 opcode::ABORT_TXN => self.handle_abort_txn(&request.payload),
451
452 opcode::PUT_PATH => self.handle_put_path(&request.payload),
453 opcode::GET_PATH => self.handle_get_path(&request.payload),
454
455 opcode::QUERY => self.handle_query(&request.payload),
456 opcode::CREATE_TABLE => self.handle_create_table(&request.payload),
457 opcode::SCAN => self.handle_scan(&request.payload),
458 opcode::EXECUTE_SQL => self.handle_execute_sql(&request.payload),
459
460 opcode::CHECKPOINT => self.handle_checkpoint(),
461 opcode::STATS => self.handle_stats(),
462
463 _ => Message::error(&format!("Unknown opcode: {:#x}", request.opcode)),
464 }
465 }
466
467 fn handle_execute_sql(&self, payload: &[u8]) -> Message {
468 let sql = match std::str::from_utf8(payload) {
470 Ok(s) => s,
471 Err(_) => return Message::error("Invalid UTF-8 in SQL query"),
472 };
473
474 let result = serde_json::json!({
477 "error": "SQL execution must be implemented client-side. Use Python SDK for full SQL support.",
478 "sql": sql
479 });
480
481 match serde_json::to_vec(&result) {
482 Ok(json) => Message::value(json),
483 Err(e) => Message::error(&format!("Failed to serialize error: {}", e)),
484 }
485 }
486
487 fn handle_query(&self, payload: &[u8]) -> Message {
488 let mut offset = 0;
490
491 if payload.len() < 2 {
492 return Message::error("Query payload too short");
493 }
494
495 let path_len = u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
497 offset += 2;
498 if offset + path_len > payload.len() {
499 return Message::error("Query path truncated");
500 }
501 let path = match std::str::from_utf8(&payload[offset..offset + path_len]) {
502 Ok(s) => s,
503 Err(_) => return Message::error("Invalid UTF-8 in query path"),
504 };
505 offset += path_len;
506
507 if offset + 8 > payload.len() {
509 return Message::error("Query limit/offset truncated");
510 }
511 let limit_val =
512 u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
513 offset += 4;
514 let offset_val =
515 u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
516 offset += 4;
517
518 if offset + 2 > payload.len() {
520 return Message::error("Query columns count truncated");
521 }
522 let cols_count =
523 u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
524 offset += 2;
525
526 let mut columns = Vec::with_capacity(cols_count);
527 for _ in 0..cols_count {
528 if offset + 2 > payload.len() {
529 return Message::error("Query column length truncated");
530 }
531 let col_len =
532 u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
533 offset += 2;
534
535 if offset + col_len > payload.len() {
536 return Message::error("Query column truncated");
537 }
538 let col = match std::str::from_utf8(&payload[offset..offset + col_len]) {
539 Ok(s) => s.to_string(),
540 Err(_) => return Message::error("Invalid UTF-8 in query column"),
541 };
542 columns.push(col);
543 offset += col_len;
544 }
545
546 let txn = match self.db.begin_transaction() {
549 Ok(t) => t,
550 Err(e) => return Message::error(&e.to_string()),
551 };
552
553 let mut builder = self.db.query(txn, path);
554
555 if limit_val > 0 {
556 builder = builder.limit(limit_val);
557 }
558 if offset_val > 0 {
559 builder = builder.offset(offset_val);
560 }
561
562 if !columns.is_empty() {
563 let cols_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
564 builder = builder.columns(&cols_refs);
565 }
566
567 let result = builder.to_toon();
568 let _ = self.db.abort(txn); match result {
571 Ok(soch_str) => Message::new(opcode::VALUE, soch_str.into_bytes()),
572 Err(e) => Message::error(&e.to_string()),
573 }
574 }
575
576 fn handle_scan(&self, payload: &[u8]) -> Message {
577 let prefix = match std::str::from_utf8(payload) {
578 Ok(s) => s,
579 Err(_) => return Message::error("Invalid UTF-8 in scan prefix"),
580 };
581
582 let txn = match self.db.begin_transaction() {
583 Ok(t) => t,
584 Err(e) => return Message::error(&e.to_string()),
585 };
586
587 let result = self.db.scan_path(txn, prefix);
588 let _ = self.db.abort(txn);
589
590 match result {
591 Ok(items) => {
592 let mut buf = Vec::new();
596 buf.extend_from_slice(&(items.len() as u32).to_le_bytes());
597
598 for (key, val) in items {
599 let key_bytes = key.as_bytes();
600 buf.extend_from_slice(&(key_bytes.len() as u16).to_le_bytes());
601 buf.extend_from_slice(key_bytes);
602 buf.extend_from_slice(&(val.len() as u32).to_le_bytes());
603 buf.extend_from_slice(&val);
604 }
605
606 Message::new(opcode::VALUE, buf)
607 }
608 Err(e) => Message::error(&e.to_string()),
609 }
610 }
611
612 fn handle_create_table(&self, payload: &[u8]) -> Message {
613 let _schema_json = match std::str::from_utf8(payload) {
615 Ok(s) => s,
616 Err(_) => return Message::error("Invalid UTF-8 in schema"),
617 };
618
619 let mut offset = 0;
633 if payload.len() < 2 {
634 return Message::error("Schema payload too short");
635 }
636
637 let name_len = u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
638 offset += 2;
639 if offset + name_len > payload.len() {
640 return Message::error("Schema name truncated");
641 }
642 let name = match std::str::from_utf8(&payload[offset..offset + name_len]) {
643 Ok(s) => s.to_string(),
644 Err(_) => return Message::error("Invalid UTF-8 in schema name"),
645 };
646 offset += name_len;
647
648 if offset + 2 > payload.len() {
649 return Message::error("Schema column count truncated");
650 }
651 let col_count =
652 u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
653 offset += 2;
654
655 let mut columns = Vec::with_capacity(col_count);
656 for _ in 0..col_count {
657 if offset + 2 > payload.len() {
658 return Message::error("Column name length truncated");
659 }
660 let col_name_len =
661 u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
662 offset += 2;
663
664 if offset + col_name_len > payload.len() {
665 return Message::error("Column name truncated");
666 }
667 let col_name = match std::str::from_utf8(&payload[offset..offset + col_name_len]) {
668 Ok(s) => s.to_string(),
669 Err(_) => return Message::error("Invalid UTF-8 in column name"),
670 };
671 offset += col_name_len;
672
673 if offset + 2 > payload.len() {
674 return Message::error("Column type/nullable truncated");
675 }
676 let type_byte = payload[offset];
677 offset += 1;
678 let nullable_byte = payload[offset];
679 offset += 1;
680
681 let col_type = match type_byte {
682 0 => crate::database::ColumnType::Int64,
683 1 => crate::database::ColumnType::UInt64,
684 2 => crate::database::ColumnType::Float64,
685 3 => crate::database::ColumnType::Text,
686 4 => crate::database::ColumnType::Binary,
687 5 => crate::database::ColumnType::Bool,
688 _ => return Message::error("Invalid column type"),
689 };
690
691 columns.push(crate::database::ColumnDef {
692 name: col_name,
693 col_type,
694 nullable: nullable_byte != 0,
695 });
696 }
697
698 let schema = crate::database::TableSchema { name, columns };
699
700 match self.db.register_table(schema) {
701 Ok(_) => Message::ok(),
702 Err(e) => Message::error(&e.to_string()),
703 }
704 }
705
706 fn handle_put(&self, payload: &[u8]) -> Message {
708 match decode_put(payload) {
709 Ok((key, value)) => {
710 let txn = match self.db.begin_transaction() {
712 Ok(t) => t,
713 Err(e) => return Message::error(&e.to_string()),
714 };
715
716 if let Err(e) = self.db.put(txn, key, value) {
717 let _ = self.db.abort(txn);
718 return Message::error(&e.to_string());
719 }
720
721 match self.db.commit(txn) {
722 Ok(_) => Message::ok(),
723 Err(e) => Message::error(&e.to_string()),
724 }
725 }
726 Err(e) => Message::error(&e.to_string()),
727 }
728 }
729
730 fn handle_get(&self, payload: &[u8]) -> Message {
732 let txn = match self.db.begin_transaction() {
734 Ok(t) => t,
735 Err(e) => return Message::error(&e.to_string()),
736 };
737
738 let result = self.db.get(txn, payload);
739 let _ = self.db.abort(txn); match result {
742 Ok(Some(value)) => Message::value(value),
743 Ok(None) => Message::new(opcode::VALUE, vec![]),
744 Err(e) => Message::error(&e.to_string()),
745 }
746 }
747
748 fn handle_delete(&self, payload: &[u8]) -> Message {
750 let txn = match self.db.begin_transaction() {
751 Ok(t) => t,
752 Err(e) => return Message::error(&e.to_string()),
753 };
754
755 if let Err(e) = self.db.delete(txn, payload) {
756 let _ = self.db.abort(txn);
757 return Message::error(&e.to_string());
758 }
759
760 match self.db.commit(txn) {
761 Ok(_) => Message::ok(),
762 Err(e) => Message::error(&e.to_string()),
763 }
764 }
765
766 fn handle_begin_txn(&mut self) -> Message {
767 match self.db.begin_transaction() {
768 Ok(txn) => {
769 let client_txn_id = self.next_client_txn_id;
770 self.next_client_txn_id += 1;
771 self.active_txns.insert(client_txn_id, txn);
772 Message::txn_id(client_txn_id)
773 }
774 Err(e) => Message::error(&e.to_string()),
775 }
776 }
777
778 fn handle_commit_txn(&mut self, payload: &[u8]) -> Message {
779 if payload.len() < 8 {
780 return Message::error("COMMIT_TXN requires txn_id");
781 }
782 let client_txn_id = u64::from_le_bytes(payload[0..8].try_into().unwrap());
783
784 match self.active_txns.remove(&client_txn_id) {
785 Some(txn) => match self.db.commit(txn) {
786 Ok(commit_ts) => Message::txn_id(commit_ts),
787 Err(e) => Message::error(&e.to_string()),
788 },
789 None => Message::error(&format!("Transaction not found: {}", client_txn_id)),
790 }
791 }
792
793 fn handle_abort_txn(&mut self, payload: &[u8]) -> Message {
794 if payload.len() < 8 {
795 return Message::error("ABORT_TXN requires txn_id");
796 }
797 let client_txn_id = u64::from_le_bytes(payload[0..8].try_into().unwrap());
798
799 match self.active_txns.remove(&client_txn_id) {
800 Some(txn) => match self.db.abort(txn) {
801 Ok(_) => Message::ok(),
802 Err(e) => Message::error(&e.to_string()),
803 },
804 None => Message::error(&format!("Transaction not found: {}", client_txn_id)),
805 }
806 }
807
808 fn handle_put_path(&self, payload: &[u8]) -> Message {
809 match decode_path(payload) {
810 Ok((path, value)) => {
811 let txn = match self.db.begin_transaction() {
812 Ok(t) => t,
813 Err(e) => return Message::error(&e.to_string()),
814 };
815
816 let path_str = path.join("/");
817 if let Err(e) = self.db.put_path(txn, &path_str, value) {
818 let _ = self.db.abort(txn);
819 return Message::error(&e.to_string());
820 }
821
822 match self.db.commit(txn) {
823 Ok(_) => Message::ok(),
824 Err(e) => Message::error(&e.to_string()),
825 }
826 }
827 Err(e) => Message::error(&e.to_string()),
828 }
829 }
830
831 fn handle_get_path(&self, payload: &[u8]) -> Message {
832 match decode_path(payload) {
833 Ok((path, _)) => {
834 let txn = match self.db.begin_transaction() {
835 Ok(t) => t,
836 Err(e) => return Message::error(&e.to_string()),
837 };
838
839 let path_str = path.join("/");
840 let result = self.db.get_path(txn, &path_str);
841 let _ = self.db.abort(txn);
842
843 match result {
844 Ok(Some(value)) => Message::value(value),
845 Ok(None) => Message::new(opcode::VALUE, vec![]),
846 Err(e) => Message::error(&e.to_string()),
847 }
848 }
849 Err(e) => Message::error(&e.to_string()),
850 }
851 }
852
853 fn handle_checkpoint(&self) -> Message {
854 match self.db.checkpoint() {
855 Ok(_) => Message::ok(),
856 Err(e) => Message::error(&e.to_string()),
857 }
858 }
859
860 fn handle_stats(&self) -> Message {
861 let stats = self.stats.snapshot();
862 let stats_json = format!(
864 r#"{{"connections_total":{},"connections_active":{},"requests_total":{},"requests_success":{},"requests_error":{},"bytes_received":{},"bytes_sent":{},"uptime_secs":{},"memtable_size_bytes":0,"wal_size_bytes":0,"active_transactions":{}}}"#,
865 stats.connections_total,
866 stats.connections_active,
867 stats.requests_total,
868 stats.requests_success,
869 stats.requests_error,
870 stats.bytes_received,
871 stats.bytes_sent,
872 stats.uptime_secs,
873 self.active_txns.len()
874 );
875 Message::new(opcode::STATS_RESP, stats_json.into_bytes())
876 }
877
878 fn cleanup_transactions(&mut self) {
879 for (_client_id, txn) in self.active_txns.drain() {
881 let _ = self.db.abort(txn);
882 }
883 }
884}
885
/// Settings for [`IpcServer`].
#[derive(Debug, Clone)]
pub struct IpcServerConfig {
    /// Path of the Unix socket; an existing file at this path is removed before binding.
    pub socket_path: PathBuf,

    /// Once this many clients are active, new connections are closed immediately.
    pub max_connections: usize,

    // NOTE(review): not read by the server code visible in this module.
    pub thread_pool_size: usize,

    // NOTE(review): not read by the server code visible in this module; the
    // per-client read timeout is currently a fixed value in the handler.
    pub connection_timeout_secs: u64,
}
905
906impl Default for IpcServerConfig {
907 fn default() -> Self {
908 Self {
909 socket_path: PathBuf::from("/tmp/sochdb.sock"),
910 max_connections: 100,
911 thread_pool_size: 4,
912 connection_timeout_secs: 300, }
914 }
915}
916
/// Restricts the file at `path` to owner read/write (mode `0o600`).
///
/// This is best effort: a failure is logged to stderr and otherwise ignored.
// NOTE(review): the callers in this file run this after `bind` has already created
// the socket, so briefly the socket carries umask-derived permissions.
fn secure_socket_path(path: &Path) {
    let owner_only = std::fs::Permissions::from_mode(0o600);
    match std::fs::set_permissions(path, owner_only) {
        Ok(()) => {}
        Err(e) => eprintln!(
            "[IpcServer] WARNING: failed to restrict socket permissions on {:?}: {}",
            path, e
        ),
    }
}
931
932impl IpcServerConfig {
933 pub fn with_socket_path<P: AsRef<Path>>(mut self, path: P) -> Self {
934 self.socket_path = path.as_ref().to_path_buf();
935 self
936 }
937
938 pub fn with_max_connections(mut self, max: usize) -> Self {
939 self.max_connections = max;
940 self
941 }
942}
943
944pub struct IpcServer {
946 db: Arc<Database>,
947 config: IpcServerConfig,
948 stats: Arc<ServerStats>,
949 running: Arc<AtomicBool>,
950 listener_handle: Mutex<Option<JoinHandle<()>>>,
951}
952
953impl IpcServer {
954 pub fn new(db: Arc<Database>, config: IpcServerConfig) -> Self {
956 Self {
957 db,
958 config,
959 stats: Arc::new(ServerStats::new()),
960 running: Arc::new(AtomicBool::new(false)),
961 listener_handle: Mutex::new(None),
962 }
963 }
964
965 pub fn with_defaults(db: Arc<Database>) -> Self {
967 Self::new(db, IpcServerConfig::default())
968 }
969
970 pub fn run(&self) -> Result<()> {
972 if self.running.swap(true, Ordering::SeqCst) {
973 return Err(IpcError::AlreadyRunning);
974 }
975
976 if self.config.socket_path.exists() {
978 std::fs::remove_file(&self.config.socket_path)?;
979 }
980
981 let listener = UnixListener::bind(&self.config.socket_path)?;
983 listener.set_nonblocking(false)?;
984
985 secure_socket_path(&self.config.socket_path);
987
988 *self.stats.start_time.lock() = Some(Instant::now());
990
991 eprintln!("[IpcServer] Listening on {:?}", self.config.socket_path);
992
993 while self.running.load(Ordering::SeqCst) {
995 match listener.accept() {
996 Ok((stream, _addr)) => {
997 let active = self.stats.connections_active.load(Ordering::Relaxed);
999 if active >= self.config.max_connections as u64 {
1000 eprintln!("[IpcServer] Connection limit reached, rejecting");
1001 continue;
1002 }
1003
1004 self.stats.connections_total.fetch_add(1, Ordering::Relaxed);
1005 self.stats
1006 .connections_active
1007 .fetch_add(1, Ordering::Relaxed);
1008
1009 let db = Arc::clone(&self.db);
1010 let stats = Arc::clone(&self.stats);
1011
1012 thread::spawn(move || {
1014 let mut handler = ClientHandler::new(db, stream, Arc::clone(&stats));
1015 if let Err(e) = handler.handle() {
1016 eprintln!("[IpcServer] Client error: {}", e);
1017 }
1018 stats.connections_active.fetch_sub(1, Ordering::Relaxed);
1019 });
1020 }
1021 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1022 thread::sleep(Duration::from_millis(100));
1024 }
1025 Err(e) => {
1026 eprintln!("[IpcServer] Accept error: {}", e);
1027 }
1028 }
1029 }
1030
1031 let _ = std::fs::remove_file(&self.config.socket_path);
1033
1034 Ok(())
1035 }
1036
1037 pub fn start(&self) -> Result<()> {
1039 if self.running.swap(true, Ordering::SeqCst) {
1040 return Err(IpcError::AlreadyRunning);
1041 }
1042
1043 let db = Arc::clone(&self.db);
1044 let config = self.config.clone();
1045 let stats = Arc::clone(&self.stats);
1046 let running = Arc::clone(&self.running);
1047
1048 let handle = thread::spawn(move || {
1049 if config.socket_path.exists() {
1052 let _ = std::fs::remove_file(&config.socket_path);
1053 }
1054
1055 let listener = match UnixListener::bind(&config.socket_path) {
1057 Ok(l) => l,
1058 Err(e) => {
1059 eprintln!("[IpcServer] Failed to bind: {}", e);
1060 running.store(false, Ordering::SeqCst);
1061 return;
1062 }
1063 };
1064 let _ = listener.set_nonblocking(false);
1065
1066 secure_socket_path(&config.socket_path);
1068
1069 *stats.start_time.lock() = Some(Instant::now());
1071
1072 eprintln!("[IpcServer] Listening on {:?}", config.socket_path);
1073
1074 while running.load(Ordering::SeqCst) {
1076 match listener.accept() {
1077 Ok((stream, _addr)) => {
1078 let active = stats.connections_active.load(Ordering::Relaxed);
1080 if active >= config.max_connections as u64 {
1081 eprintln!("[IpcServer] Connection limit reached, rejecting");
1082 continue;
1083 }
1084
1085 stats.connections_total.fetch_add(1, Ordering::Relaxed);
1086 stats.connections_active.fetch_add(1, Ordering::Relaxed);
1087
1088 let db_clone = Arc::clone(&db);
1089 let stats_clone = Arc::clone(&stats);
1090
1091 thread::spawn(move || {
1093 let mut handler =
1094 ClientHandler::new(db_clone, stream, Arc::clone(&stats_clone));
1095 if let Err(e) = handler.handle() {
1096 eprintln!("[IpcServer] Client error: {}", e);
1097 }
1098 stats_clone
1099 .connections_active
1100 .fetch_sub(1, Ordering::Relaxed);
1101 });
1102 }
1103 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1104 thread::sleep(Duration::from_millis(100));
1106 }
1107 Err(e) => {
1108 eprintln!("[IpcServer] Accept error: {}", e);
1109 break;
1110 }
1111 }
1112 }
1113
1114 let _ = std::fs::remove_file(&config.socket_path);
1116 });
1117
1118 *self.listener_handle.lock() = Some(handle);
1119 Ok(())
1120 }
1121
1122 pub fn stop(&self) {
1124 self.running.store(false, Ordering::SeqCst);
1125
1126 let _ = UnixStream::connect(&self.config.socket_path);
1128
1129 if let Some(handle) = self.listener_handle.lock().take() {
1131 let _ = handle.join();
1132 }
1133 }
1134
1135 pub fn is_running(&self) -> bool {
1137 self.running.load(Ordering::SeqCst)
1138 }
1139
1140 pub fn stats(&self) -> ServerStatsSnapshot {
1142 self.stats.snapshot()
1143 }
1144
1145 pub fn socket_path(&self) -> &Path {
1147 &self.config.socket_path
1148 }
1149}
1150
1151impl Drop for IpcServer {
1152 fn drop(&mut self) {
1153 self.stop();
1154 }
1155}
1156
1157pub struct IpcClient {
1163 stream: UnixStream,
1164}
1165
1166impl IpcClient {
1167 pub fn connect<P: AsRef<Path>>(socket_path: P) -> Result<Self> {
1169 let stream = UnixStream::connect(socket_path)?;
1170 Ok(Self { stream })
1171 }
1172
1173 pub fn connect_with_timeout<P: AsRef<Path>>(socket_path: P, timeout: Duration) -> Result<Self> {
1175 let stream = UnixStream::connect(socket_path)?;
1176 stream.set_read_timeout(Some(timeout))?;
1177 stream.set_write_timeout(Some(timeout))?;
1178 Ok(Self { stream })
1179 }
1180
1181 fn request(&mut self, msg: Message) -> Result<Message> {
1183 msg.write_to(&mut self.stream)?;
1184 Message::read_from(&mut self.stream)
1185 }
1186
1187 pub fn ping(&mut self) -> Result<Duration> {
1189 let start = Instant::now();
1190 let resp = self.request(Message::new(opcode::PING, vec![]))?;
1191 if resp.opcode != opcode::PONG {
1192 return Err(IpcError::Protocol("Expected PONG".into()));
1193 }
1194 Ok(start.elapsed())
1195 }
1196
1197 pub fn put(&mut self, key: &[u8], value: &[u8]) -> Result<()> {
1199 let payload = encode_put(key, value);
1200 let resp = self.request(Message::new(opcode::PUT, payload))?;
1201 self.check_ok(resp)
1202 }
1203
1204 pub fn get(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
1206 let resp = self.request(Message::new(opcode::GET, key.to_vec()))?;
1207 match resp.opcode {
1208 opcode::VALUE if resp.payload.is_empty() => Ok(None),
1209 opcode::VALUE => Ok(Some(resp.payload)),
1210 opcode::ERROR => Err(IpcError::Database(
1211 String::from_utf8_lossy(&resp.payload).to_string(),
1212 )),
1213 _ => Err(IpcError::Protocol(format!(
1214 "Unexpected opcode: {:#x}",
1215 resp.opcode
1216 ))),
1217 }
1218 }
1219
1220 pub fn delete(&mut self, key: &[u8]) -> Result<()> {
1222 let resp = self.request(Message::new(opcode::DELETE, key.to_vec()))?;
1223 self.check_ok(resp)
1224 }
1225
1226 pub fn begin_txn(&mut self) -> Result<u64> {
1228 let resp = self.request(Message::new(opcode::BEGIN_TXN, vec![]))?;
1229 match resp.opcode {
1230 opcode::TXN_ID => {
1231 if resp.payload.len() >= 8 {
1232 Ok(u64::from_le_bytes(resp.payload[0..8].try_into().unwrap()))
1233 } else {
1234 Err(IpcError::Protocol("TXN_ID response too short".into()))
1235 }
1236 }
1237 opcode::ERROR => Err(IpcError::Database(
1238 String::from_utf8_lossy(&resp.payload).to_string(),
1239 )),
1240 _ => Err(IpcError::Protocol(format!(
1241 "Unexpected opcode: {:#x}",
1242 resp.opcode
1243 ))),
1244 }
1245 }
1246
1247 pub fn commit_txn(&mut self, txn_id: u64) -> Result<u64> {
1249 let resp = self.request(Message::new(
1250 opcode::COMMIT_TXN,
1251 txn_id.to_le_bytes().to_vec(),
1252 ))?;
1253 match resp.opcode {
1254 opcode::TXN_ID => {
1255 if resp.payload.len() >= 8 {
1256 Ok(u64::from_le_bytes(resp.payload[0..8].try_into().unwrap()))
1257 } else {
1258 Err(IpcError::Protocol("TXN_ID response too short".into()))
1259 }
1260 }
1261 opcode::ERROR => Err(IpcError::Database(
1262 String::from_utf8_lossy(&resp.payload).to_string(),
1263 )),
1264 _ => Err(IpcError::Protocol(format!(
1265 "Unexpected opcode: {:#x}",
1266 resp.opcode
1267 ))),
1268 }
1269 }
1270
1271 pub fn abort_txn(&mut self, txn_id: u64) -> Result<()> {
1273 let resp = self.request(Message::new(
1274 opcode::ABORT_TXN,
1275 txn_id.to_le_bytes().to_vec(),
1276 ))?;
1277 self.check_ok(resp)
1278 }
1279
1280 pub fn put_path(&mut self, path: &[&str], value: &[u8]) -> Result<()> {
1282 let payload = encode_put_path(path, value);
1283 let resp = self.request(Message::new(opcode::PUT_PATH, payload))?;
1284 self.check_ok(resp)
1285 }
1286
1287 pub fn get_path(&mut self, path: &[&str]) -> Result<Option<Vec<u8>>> {
1289 let payload = encode_put_path(path, &[]);
1290 let resp = self.request(Message::new(opcode::GET_PATH, payload))?;
1291 match resp.opcode {
1292 opcode::VALUE if resp.payload.is_empty() => Ok(None),
1293 opcode::VALUE => Ok(Some(resp.payload)),
1294 opcode::ERROR => Err(IpcError::Database(
1295 String::from_utf8_lossy(&resp.payload).to_string(),
1296 )),
1297 _ => Err(IpcError::Protocol(format!(
1298 "Unexpected opcode: {:#x}",
1299 resp.opcode
1300 ))),
1301 }
1302 }
1303
1304 pub fn checkpoint(&mut self) -> Result<()> {
1306 let resp = self.request(Message::new(opcode::CHECKPOINT, vec![]))?;
1307 self.check_ok(resp)
1308 }
1309
1310 pub fn stats(&mut self) -> Result<HashMap<String, u64>> {
1312 let resp = self.request(Message::new(opcode::STATS, vec![]))?;
1313 match resp.opcode {
1314 opcode::STATS_RESP => {
1315 let stats_str = String::from_utf8_lossy(&resp.payload);
1316
1317 let stats: HashMap<String, u64> =
1319 serde_json::from_str(&stats_str).map_err(|e| {
1320 IpcError::Protocol(format!("Failed to parse stats JSON: {}", e))
1321 })?;
1322
1323 Ok(stats)
1324 }
1325 opcode::ERROR => Err(IpcError::Database(
1326 String::from_utf8_lossy(&resp.payload).to_string(),
1327 )),
1328 _ => Err(IpcError::Protocol(format!(
1329 "Unexpected opcode: {:#x}",
1330 resp.opcode
1331 ))),
1332 }
1333 }
1334
1335 fn check_ok(&self, resp: Message) -> Result<()> {
1336 match resp.opcode {
1337 opcode::OK => Ok(()),
1338 opcode::ERROR => Err(IpcError::Database(
1339 String::from_utf8_lossy(&resp.payload).to_string(),
1340 )),
1341 _ => Err(IpcError::Protocol(format!(
1342 "Unexpected opcode: {:#x}",
1343 resp.opcode
1344 ))),
1345 }
1346 }
1347}
1348
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};
    use tempfile::TempDir;

    /// Upper bound on how long a test waits for the server to become reachable.
    const STARTUP_TIMEOUT: Duration = Duration::from_secs(5);

    /// Delay between readiness probes while waiting for the server.
    const POLL_INTERVAL: Duration = Duration::from_millis(5);

    /// Opens a fresh database inside a temporary directory.
    ///
    /// Returns the database, the directory guard, and the socket path to serve on.
    /// Keep the `TempDir` alive for the whole test, because dropping it removes
    /// the directory.
    fn setup_test_server() -> (Arc<Database>, TempDir, PathBuf) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let socket_path = temp_dir.path().join("test.sock");

        let db = Database::open(&db_path).unwrap();
        (db, temp_dir, socket_path)
    }

    /// Connects to the server at `path`, retrying until it is reachable or
    /// `STARTUP_TIMEOUT` elapses (then the test panics).
    ///
    /// This replaces a fixed sleep after `start()`, which is flaky on slow
    /// machines. A connect refused because the socket is not bound yet never
    /// reaches the server's accept loop, so connection counters only see the
    /// final successful attempt.
    fn connect_when_ready(path: &Path) -> IpcClient {
        let deadline = Instant::now() + STARTUP_TIMEOUT;
        loop {
            match IpcClient::connect(path) {
                Ok(client) => return client,
                Err(err) => {
                    assert!(
                        Instant::now() < deadline,
                        "server at {} not reachable: {:?}",
                        path.display(),
                        err
                    );
                    thread::sleep(POLL_INTERVAL);
                }
            }
        }
    }

    #[test]
    fn test_message_roundtrip() {
        let original = Message::new(0x01, b"hello world".to_vec());

        let mut buffer = Vec::new();
        original.write_to(&mut buffer).unwrap();

        let mut cursor = std::io::Cursor::new(buffer);
        let decoded = Message::read_from(&mut cursor).unwrap();

        assert_eq!(decoded.opcode, original.opcode);
        assert_eq!(decoded.payload, original.payload);
    }

    #[test]
    fn test_encode_decode_put() {
        let key = b"test-key";
        let value = b"test-value";

        let encoded = encode_put(key, value);
        let (decoded_key, decoded_value) = decode_put(&encoded).unwrap();

        assert_eq!(decoded_key, key);
        assert_eq!(decoded_value, value);
    }

    #[test]
    fn test_encode_decode_path() {
        let path = vec!["users", "alice", "settings"];
        let value = b"preferences";

        let encoded = encode_put_path(&path, value);
        let (decoded_path, decoded_value) = decode_path(&encoded).unwrap();

        let expected_path: Vec<String> = path.iter().map(|s| s.to_string()).collect();
        assert_eq!(decoded_path, expected_path);
        assert_eq!(decoded_value, value);
    }

    #[test]
    fn test_server_client_basic() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut client = connect_when_ready(&socket_path);

        let latency = client.ping().unwrap();
        assert!(latency < Duration::from_secs(1));

        client.put(b"key1", b"value1").unwrap();
        let value = client.get(b"key1").unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));

        let value = client.get(b"nonexistent").unwrap();
        assert_eq!(value, None);

        client.delete(b"key1").unwrap();
        let value = client.get(b"key1").unwrap();
        assert_eq!(value, None);

        server.stop();
    }

    #[test]
    fn test_socket_permissions_are_owner_only() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        // Poll rather than sleep a fixed amount, in case the socket is created
        // (or its mode tightened) shortly after `start()` returns.
        let deadline = Instant::now() + STARTUP_TIMEOUT;
        let mode = loop {
            let observed = std::fs::metadata(&socket_path)
                .ok()
                .map(|meta| meta.permissions().mode() & 0o777);
            if observed == Some(0o600) || Instant::now() >= deadline {
                break observed.expect("server socket was never created");
            }
            thread::sleep(POLL_INTERVAL);
        };
        assert_eq!(
            mode,
            0o600,
            "socket permissions must be owner-only, got {:o}",
            mode
        );

        server.stop();
    }

    #[test]
    fn test_server_client_transactions() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut client = connect_when_ready(&socket_path);

        let txn_id = client.begin_txn().unwrap();
        assert!(txn_id > 0);

        let commit_ts = client.commit_txn(txn_id).unwrap();
        assert!(commit_ts > 0);

        let txn_id2 = client.begin_txn().unwrap();
        client.abort_txn(txn_id2).unwrap();

        server.stop();
    }

    #[test]
    fn test_server_client_paths() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut client = connect_when_ready(&socket_path);

        client
            .put_path(&["users", "alice", "email"], b"alice@example.com")
            .unwrap();

        let value = client.get_path(&["users", "alice", "email"]).unwrap();
        assert_eq!(value, Some(b"alice@example.com".to_vec()));

        let value = client.get_path(&["users", "bob", "email"]).unwrap();
        assert_eq!(value, None);

        server.stop();
    }

    #[test]
    fn test_server_stats() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut client = connect_when_ready(&socket_path);

        client.ping().unwrap();
        client.put(b"k", b"v").unwrap();
        client.get(b"k").unwrap();

        let stats = client.stats().unwrap();
        assert!(stats.contains_key("requests_total"));
        assert!(*stats.get("requests_total").unwrap() >= 4);

        let server_stats = server.stats();
        assert!(server_stats.requests_total >= 4);
        assert!(server_stats.connections_active >= 1);

        server.stop();
    }

    #[test]
    fn test_multiple_clients() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default()
            .with_socket_path(&socket_path)
            .with_max_connections(10);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut handles = Vec::new();
        for i in 0..5 {
            let path = socket_path.clone();
            handles.push(thread::spawn(move || {
                // Each worker retries until the server is reachable; only the
                // successful attempt is accepted, keeping `connections_total` exact.
                let mut client = connect_when_ready(&path);
                let key = format!("key-{}", i);
                let value = format!("value-{}", i);

                client.put(key.as_bytes(), value.as_bytes()).unwrap();
                let result = client.get(key.as_bytes()).unwrap();
                assert_eq!(result, Some(value.into_bytes()));
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        let stats = server.stats();
        assert_eq!(stats.connections_total, 5);

        server.stop();
    }
}