1use crate::database::{Database, TxnHandle};
84use parking_lot::Mutex;
85use std::collections::HashMap;
86use std::io::{BufReader, BufWriter, Read, Write};
87use std::os::unix::net::{UnixListener, UnixStream};
88use std::path::{Path, PathBuf};
89use std::sync::Arc;
90use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
91use std::thread::{self, JoinHandle};
92use std::time::{Duration, Instant};
93use thiserror::Error;
94
/// Single-byte opcodes that tag every frame of the IPC protocol.
///
/// Values `0x01..=0x0F` are requests dispatched by the server; values with
/// the high bit set (`0x80..`) are responses.
mod opcode {
    // ---- Requests -------------------------------------------------------

    /// Store a value under a raw key.
    pub const PUT: u8 = 0x01;
    /// Fetch the value stored under a raw key.
    pub const GET: u8 = 0x02;
    /// Remove a raw key.
    pub const DELETE: u8 = 0x03;
    /// Open an explicit transaction.
    pub const BEGIN_TXN: u8 = 0x04;
    /// Commit a transaction previously opened with `BEGIN_TXN`.
    pub const COMMIT_TXN: u8 = 0x05;
    /// Abort a transaction previously opened with `BEGIN_TXN`.
    pub const ABORT_TXN: u8 = 0x06;
    /// Run a path query with optional limit, offset and column selection.
    pub const QUERY: u8 = 0x07;
    /// Register a table schema.
    pub const CREATE_TABLE: u8 = 0x08;
    /// Store a value under a segmented path.
    pub const PUT_PATH: u8 = 0x09;
    /// Fetch the value stored under a segmented path.
    pub const GET_PATH: u8 = 0x0A;
    /// List the entries under a path prefix.
    pub const SCAN: u8 = 0x0B;
    /// Ask the database to checkpoint.
    pub const CHECKPOINT: u8 = 0x0C;
    /// Request the server statistics.
    pub const STATS: u8 = 0x0D;
    /// Liveness probe, answered with `PONG`.
    pub const PING: u8 = 0x0E;
    /// Submit a SQL statement.
    pub const EXECUTE_SQL: u8 = 0x0F;

    // ---- Responses ------------------------------------------------------

    /// Request succeeded; no payload.
    pub const OK: u8 = 0x80;
    /// Request failed; the payload is the error text.
    pub const ERROR: u8 = 0x81;
    /// Request succeeded; the payload is the resulting bytes.
    pub const VALUE: u8 = 0x82;
    /// Payload is a little-endian `u64` identifier.
    pub const TXN_ID: u8 = 0x83;
    /// Not used by the code in this file yet.
    #[allow(dead_code)]
    pub const ROW: u8 = 0x84;
    /// Not used by the code in this file yet.
    #[allow(dead_code)]
    pub const END_STREAM: u8 = 0x85;
    /// Payload is the statistics document answering `STATS`.
    pub const STATS_RESP: u8 = 0x86;
    /// Answer to `PING`.
    pub const PONG: u8 = 0x87;
}
129
/// Largest payload, in bytes, that will be accepted when reading a frame (16 MiB).
const MAX_MESSAGE_SIZE: usize = 16 << 20;
132
133#[derive(Debug, Error)]
138pub enum IpcError {
139 #[error("I/O error: {0}")]
140 Io(#[from] std::io::Error),
141
142 #[error("Database error: {0}")]
143 Database(String),
144
145 #[error("Protocol error: {0}")]
146 Protocol(String),
147
148 #[error("Server already running")]
149 AlreadyRunning,
150
151 #[error("Server not running")]
152 NotRunning,
153
154 #[error("Connection closed")]
155 ConnectionClosed,
156
157 #[error("Message too large: {0} bytes (max: {1})")]
158 MessageTooLarge(usize, usize),
159
160 #[error("Invalid opcode: {0:#x}")]
161 InvalidOpcode(u8),
162
163 #[error("Transaction not found: {0}")]
164 TxnNotFound(u64),
165}
166
167pub type Result<T> = std::result::Result<T, IpcError>;
168
169#[derive(Debug, Clone)]
175pub struct Message {
176 pub opcode: u8,
177 pub payload: Vec<u8>,
178}
179
180impl Message {
181 pub fn new(opcode: u8, payload: Vec<u8>) -> Self {
182 Self { opcode, payload }
183 }
184
185 pub fn ok() -> Self {
186 Self::new(opcode::OK, vec![])
187 }
188
189 pub fn error(msg: &str) -> Self {
190 Self::new(opcode::ERROR, msg.as_bytes().to_vec())
191 }
192
193 pub fn value(data: Vec<u8>) -> Self {
194 Self::new(opcode::VALUE, data)
195 }
196
197 pub fn txn_id(id: u64) -> Self {
198 Self::new(opcode::TXN_ID, id.to_le_bytes().to_vec())
199 }
200
201 pub fn read_from<R: Read>(reader: &mut R) -> Result<Self> {
203 let mut opcode_buf = [0u8; 1];
205 match reader.read_exact(&mut opcode_buf) {
206 Ok(_) => {}
207 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
208 return Err(IpcError::ConnectionClosed);
209 }
210 Err(e) => return Err(e.into()),
211 }
212 let opcode = opcode_buf[0];
213
214 let mut len_buf = [0u8; 4];
216 reader.read_exact(&mut len_buf)?;
217 let len = u32::from_le_bytes(len_buf) as usize;
218
219 if len > MAX_MESSAGE_SIZE {
221 return Err(IpcError::MessageTooLarge(len, MAX_MESSAGE_SIZE));
222 }
223
224 let mut payload = vec![0u8; len];
226 if len > 0 {
227 reader.read_exact(&mut payload)?;
228 }
229
230 Ok(Self { opcode, payload })
231 }
232
233 pub fn write_to<W: Write>(&self, writer: &mut W) -> Result<()> {
235 writer.write_all(&[self.opcode])?;
237
238 let len = self.payload.len() as u32;
240 writer.write_all(&len.to_le_bytes())?;
241
242 if !self.payload.is_empty() {
244 writer.write_all(&self.payload)?;
245 }
246
247 writer.flush()?;
248 Ok(())
249 }
250}
251
/// Builds a PUT payload: `key length (u32, little endian) | key | value`.
///
/// The value takes no length prefix; it is whatever follows the key.
fn encode_put(key: &[u8], value: &[u8]) -> Vec<u8> {
    let key_len_prefix = (key.len() as u32).to_le_bytes();
    [&key_len_prefix[..], key, value].concat()
}
264
265fn decode_put(payload: &[u8]) -> Result<(&[u8], &[u8])> {
267 if payload.len() < 4 {
268 return Err(IpcError::Protocol("PUT payload too short".into()));
269 }
270 let key_len = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
271 if payload.len() < 4 + key_len {
272 return Err(IpcError::Protocol("PUT payload key truncated".into()));
273 }
274 let key = &payload[4..4 + key_len];
275 let value = &payload[4 + key_len..];
276 Ok((key, value))
277}
278
/// Builds a path payload:
/// `segment count (u16 LE) | { segment length (u16 LE) | segment bytes }* | value`.
///
/// The value takes no length prefix; it is whatever follows the last segment.
fn encode_put_path(path: &[&str], value: &[u8]) -> Vec<u8> {
    let mut out = (path.len() as u16).to_le_bytes().to_vec();
    path.iter().map(|segment| segment.as_bytes()).for_each(|raw| {
        out.extend_from_slice(&(raw.len() as u16).to_le_bytes());
        out.extend_from_slice(raw);
    });
    out.extend_from_slice(value);
    out
}
291
292fn decode_path(payload: &[u8]) -> Result<(Vec<String>, &[u8])> {
294 if payload.len() < 2 {
295 return Err(IpcError::Protocol("Path payload too short".into()));
296 }
297 let count = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
298 let mut offset = 2;
299 let mut path = Vec::with_capacity(count);
300
301 for _ in 0..count {
302 if offset + 2 > payload.len() {
303 return Err(IpcError::Protocol("Path segment length truncated".into()));
304 }
305 let seg_len = u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
306 offset += 2;
307 if offset + seg_len > payload.len() {
308 return Err(IpcError::Protocol("Path segment truncated".into()));
309 }
310 let segment = std::str::from_utf8(&payload[offset..offset + seg_len])
311 .map_err(|_| IpcError::Protocol("Invalid UTF-8 in path".into()))?;
312 path.push(segment.to_string());
313 offset += seg_len;
314 }
315
316 Ok((path, &payload[offset..]))
317}
318
319#[derive(Debug, Default)]
324pub struct ServerStats {
325 pub connections_total: AtomicU64,
326 pub connections_active: AtomicU64,
327 pub requests_total: AtomicU64,
328 pub requests_success: AtomicU64,
329 pub requests_error: AtomicU64,
330 pub bytes_received: AtomicU64,
331 pub bytes_sent: AtomicU64,
332 pub start_time: Mutex<Option<Instant>>,
333}
334
335impl ServerStats {
336 pub fn new() -> Self {
337 Self::default()
338 }
339
340 pub fn snapshot(&self) -> ServerStatsSnapshot {
341 ServerStatsSnapshot {
342 connections_total: self.connections_total.load(Ordering::Relaxed),
343 connections_active: self.connections_active.load(Ordering::Relaxed),
344 requests_total: self.requests_total.load(Ordering::Relaxed),
345 requests_success: self.requests_success.load(Ordering::Relaxed),
346 requests_error: self.requests_error.load(Ordering::Relaxed),
347 bytes_received: self.bytes_received.load(Ordering::Relaxed),
348 bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
349 uptime_secs: self
350 .start_time
351 .lock()
352 .map(|t| t.elapsed().as_secs())
353 .unwrap_or(0),
354 }
355 }
356}
357
/// Point-in-time copy of the server counters as plain integers.
#[derive(Debug, Clone)]
pub struct ServerStatsSnapshot {
    /// Connections accepted so far.
    pub connections_total: u64,
    /// Connections being served when the snapshot was taken.
    pub connections_active: u64,
    /// Requests read so far.
    pub requests_total: u64,
    /// Requests that did not end in an error response.
    pub requests_success: u64,
    /// Requests that ended in an error response.
    pub requests_error: u64,
    /// Request bytes read so far.
    pub bytes_received: u64,
    /// Response bytes written so far.
    pub bytes_sent: u64,
    /// Whole seconds since the server began listening (0 if it has not).
    pub uptime_secs: u64,
}
369
370struct ClientHandler {
375 db: Arc<Database>,
376 stream: UnixStream,
377 stats: Arc<ServerStats>,
378 active_txns: HashMap<u64, TxnHandle>, next_client_txn_id: u64,
380}
381
382impl ClientHandler {
383 fn new(db: Arc<Database>, stream: UnixStream, stats: Arc<ServerStats>) -> Self {
384 Self {
385 db,
386 stream,
387 stats,
388 active_txns: HashMap::new(),
389 next_client_txn_id: 1,
390 }
391 }
392
393 fn handle(&mut self) -> Result<()> {
394 self.stream
396 .set_read_timeout(Some(Duration::from_secs(30)))?;
397
398 let mut reader = BufReader::new(self.stream.try_clone()?);
399 let mut writer = BufWriter::new(self.stream.try_clone()?);
400
401 loop {
402 let request = match Message::read_from(&mut reader) {
404 Ok(msg) => msg,
405 Err(IpcError::ConnectionClosed) => {
406 self.cleanup_transactions();
408 return Ok(());
409 }
410 Err(e) => return Err(e),
411 };
412
413 self.stats.requests_total.fetch_add(1, Ordering::Relaxed);
414 self.stats
415 .bytes_received
416 .fetch_add((5 + request.payload.len()) as u64, Ordering::Relaxed);
417
418 let response = self.process_request(&request);
420
421 if response.opcode == opcode::ERROR {
423 self.stats.requests_error.fetch_add(1, Ordering::Relaxed);
424 } else {
425 self.stats.requests_success.fetch_add(1, Ordering::Relaxed);
426 }
427
428 self.stats
430 .bytes_sent
431 .fetch_add((5 + response.payload.len()) as u64, Ordering::Relaxed);
432 response.write_to(&mut writer)?;
433 }
434 }
435
436 fn process_request(&mut self, request: &Message) -> Message {
437 match request.opcode {
438 opcode::PING => Message::new(opcode::PONG, vec![]),
439
440 opcode::PUT => self.handle_put(&request.payload),
441 opcode::GET => self.handle_get(&request.payload),
442 opcode::DELETE => self.handle_delete(&request.payload),
443
444 opcode::BEGIN_TXN => self.handle_begin_txn(),
445 opcode::COMMIT_TXN => self.handle_commit_txn(&request.payload),
446 opcode::ABORT_TXN => self.handle_abort_txn(&request.payload),
447
448 opcode::PUT_PATH => self.handle_put_path(&request.payload),
449 opcode::GET_PATH => self.handle_get_path(&request.payload),
450
451 opcode::QUERY => self.handle_query(&request.payload),
452 opcode::CREATE_TABLE => self.handle_create_table(&request.payload),
453 opcode::SCAN => self.handle_scan(&request.payload),
454 opcode::EXECUTE_SQL => self.handle_execute_sql(&request.payload),
455
456 opcode::CHECKPOINT => self.handle_checkpoint(),
457 opcode::STATS => self.handle_stats(),
458
459 _ => Message::error(&format!("Unknown opcode: {:#x}", request.opcode)),
460 }
461 }
462
463 fn handle_execute_sql(&self, payload: &[u8]) -> Message {
464 let sql = match std::str::from_utf8(payload) {
466 Ok(s) => s,
467 Err(_) => return Message::error("Invalid UTF-8 in SQL query"),
468 };
469
470 let result = serde_json::json!({
473 "error": "SQL execution must be implemented client-side. Use Python SDK for full SQL support.",
474 "sql": sql
475 });
476
477 match serde_json::to_vec(&result) {
478 Ok(json) => Message::value(json),
479 Err(e) => Message::error(&format!("Failed to serialize error: {}", e)),
480 }
481 }
482
483 fn handle_query(&self, payload: &[u8]) -> Message {
484 let mut offset = 0;
486
487 if payload.len() < 2 {
488 return Message::error("Query payload too short");
489 }
490
491 let path_len = u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
493 offset += 2;
494 if offset + path_len > payload.len() {
495 return Message::error("Query path truncated");
496 }
497 let path = match std::str::from_utf8(&payload[offset..offset + path_len]) {
498 Ok(s) => s,
499 Err(_) => return Message::error("Invalid UTF-8 in query path"),
500 };
501 offset += path_len;
502
503 if offset + 8 > payload.len() {
505 return Message::error("Query limit/offset truncated");
506 }
507 let limit_val =
508 u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
509 offset += 4;
510 let offset_val =
511 u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
512 offset += 4;
513
514 if offset + 2 > payload.len() {
516 return Message::error("Query columns count truncated");
517 }
518 let cols_count =
519 u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
520 offset += 2;
521
522 let mut columns = Vec::with_capacity(cols_count);
523 for _ in 0..cols_count {
524 if offset + 2 > payload.len() {
525 return Message::error("Query column length truncated");
526 }
527 let col_len =
528 u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
529 offset += 2;
530
531 if offset + col_len > payload.len() {
532 return Message::error("Query column truncated");
533 }
534 let col = match std::str::from_utf8(&payload[offset..offset + col_len]) {
535 Ok(s) => s.to_string(),
536 Err(_) => return Message::error("Invalid UTF-8 in query column"),
537 };
538 columns.push(col);
539 offset += col_len;
540 }
541
542 let txn = match self.db.begin_transaction() {
545 Ok(t) => t,
546 Err(e) => return Message::error(&e.to_string()),
547 };
548
549 let mut builder = self.db.query(txn, path);
550
551 if limit_val > 0 {
552 builder = builder.limit(limit_val);
553 }
554 if offset_val > 0 {
555 builder = builder.offset(offset_val);
556 }
557
558 if !columns.is_empty() {
559 let cols_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
560 builder = builder.columns(&cols_refs);
561 }
562
563 let result = builder.to_toon();
564 let _ = self.db.abort(txn); match result {
567 Ok(soch_str) => Message::new(opcode::VALUE, soch_str.into_bytes()),
568 Err(e) => Message::error(&e.to_string()),
569 }
570 }
571
572 fn handle_scan(&self, payload: &[u8]) -> Message {
573 let prefix = match std::str::from_utf8(payload) {
574 Ok(s) => s,
575 Err(_) => return Message::error("Invalid UTF-8 in scan prefix"),
576 };
577
578 let txn = match self.db.begin_transaction() {
579 Ok(t) => t,
580 Err(e) => return Message::error(&e.to_string()),
581 };
582
583 let result = self.db.scan_path(txn, prefix);
584 let _ = self.db.abort(txn);
585
586 match result {
587 Ok(items) => {
588 let mut buf = Vec::new();
592 buf.extend_from_slice(&(items.len() as u32).to_le_bytes());
593
594 for (key, val) in items {
595 let key_bytes = key.as_bytes();
596 buf.extend_from_slice(&(key_bytes.len() as u16).to_le_bytes());
597 buf.extend_from_slice(key_bytes);
598 buf.extend_from_slice(&(val.len() as u32).to_le_bytes());
599 buf.extend_from_slice(&val);
600 }
601
602 Message::new(opcode::VALUE, buf)
603 }
604 Err(e) => Message::error(&e.to_string()),
605 }
606 }
607
608 fn handle_create_table(&self, payload: &[u8]) -> Message {
609 let _schema_json = match std::str::from_utf8(payload) {
611 Ok(s) => s,
612 Err(_) => return Message::error("Invalid UTF-8 in schema"),
613 };
614
615 let mut offset = 0;
629 if payload.len() < 2 {
630 return Message::error("Schema payload too short");
631 }
632
633 let name_len = u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
634 offset += 2;
635 if offset + name_len > payload.len() {
636 return Message::error("Schema name truncated");
637 }
638 let name = match std::str::from_utf8(&payload[offset..offset + name_len]) {
639 Ok(s) => s.to_string(),
640 Err(_) => return Message::error("Invalid UTF-8 in schema name"),
641 };
642 offset += name_len;
643
644 if offset + 2 > payload.len() {
645 return Message::error("Schema column count truncated");
646 }
647 let col_count =
648 u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
649 offset += 2;
650
651 let mut columns = Vec::with_capacity(col_count);
652 for _ in 0..col_count {
653 if offset + 2 > payload.len() {
654 return Message::error("Column name length truncated");
655 }
656 let col_name_len =
657 u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
658 offset += 2;
659
660 if offset + col_name_len > payload.len() {
661 return Message::error("Column name truncated");
662 }
663 let col_name = match std::str::from_utf8(&payload[offset..offset + col_name_len]) {
664 Ok(s) => s.to_string(),
665 Err(_) => return Message::error("Invalid UTF-8 in column name"),
666 };
667 offset += col_name_len;
668
669 if offset + 2 > payload.len() {
670 return Message::error("Column type/nullable truncated");
671 }
672 let type_byte = payload[offset];
673 offset += 1;
674 let nullable_byte = payload[offset];
675 offset += 1;
676
677 let col_type = match type_byte {
678 0 => crate::database::ColumnType::Int64,
679 1 => crate::database::ColumnType::UInt64,
680 2 => crate::database::ColumnType::Float64,
681 3 => crate::database::ColumnType::Text,
682 4 => crate::database::ColumnType::Binary,
683 5 => crate::database::ColumnType::Bool,
684 _ => return Message::error("Invalid column type"),
685 };
686
687 columns.push(crate::database::ColumnDef {
688 name: col_name,
689 col_type,
690 nullable: nullable_byte != 0,
691 });
692 }
693
694 let schema = crate::database::TableSchema { name, columns };
695
696 match self.db.register_table(schema) {
697 Ok(_) => Message::ok(),
698 Err(e) => Message::error(&e.to_string()),
699 }
700 }
701
702 fn handle_put(&self, payload: &[u8]) -> Message {
704 match decode_put(payload) {
705 Ok((key, value)) => {
706 let txn = match self.db.begin_transaction() {
708 Ok(t) => t,
709 Err(e) => return Message::error(&e.to_string()),
710 };
711
712 if let Err(e) = self.db.put(txn, key, value) {
713 let _ = self.db.abort(txn);
714 return Message::error(&e.to_string());
715 }
716
717 match self.db.commit(txn) {
718 Ok(_) => Message::ok(),
719 Err(e) => Message::error(&e.to_string()),
720 }
721 }
722 Err(e) => Message::error(&e.to_string()),
723 }
724 }
725
726 fn handle_get(&self, payload: &[u8]) -> Message {
728 let txn = match self.db.begin_transaction() {
730 Ok(t) => t,
731 Err(e) => return Message::error(&e.to_string()),
732 };
733
734 let result = self.db.get(txn, payload);
735 let _ = self.db.abort(txn); match result {
738 Ok(Some(value)) => Message::value(value),
739 Ok(None) => Message::new(opcode::VALUE, vec![]),
740 Err(e) => Message::error(&e.to_string()),
741 }
742 }
743
744 fn handle_delete(&self, payload: &[u8]) -> Message {
746 let txn = match self.db.begin_transaction() {
747 Ok(t) => t,
748 Err(e) => return Message::error(&e.to_string()),
749 };
750
751 if let Err(e) = self.db.delete(txn, payload) {
752 let _ = self.db.abort(txn);
753 return Message::error(&e.to_string());
754 }
755
756 match self.db.commit(txn) {
757 Ok(_) => Message::ok(),
758 Err(e) => Message::error(&e.to_string()),
759 }
760 }
761
762 fn handle_begin_txn(&mut self) -> Message {
763 match self.db.begin_transaction() {
764 Ok(txn) => {
765 let client_txn_id = self.next_client_txn_id;
766 self.next_client_txn_id += 1;
767 self.active_txns.insert(client_txn_id, txn);
768 Message::txn_id(client_txn_id)
769 }
770 Err(e) => Message::error(&e.to_string()),
771 }
772 }
773
774 fn handle_commit_txn(&mut self, payload: &[u8]) -> Message {
775 if payload.len() < 8 {
776 return Message::error("COMMIT_TXN requires txn_id");
777 }
778 let client_txn_id = u64::from_le_bytes(payload[0..8].try_into().unwrap());
779
780 match self.active_txns.remove(&client_txn_id) {
781 Some(txn) => match self.db.commit(txn) {
782 Ok(commit_ts) => Message::txn_id(commit_ts),
783 Err(e) => Message::error(&e.to_string()),
784 },
785 None => Message::error(&format!("Transaction not found: {}", client_txn_id)),
786 }
787 }
788
789 fn handle_abort_txn(&mut self, payload: &[u8]) -> Message {
790 if payload.len() < 8 {
791 return Message::error("ABORT_TXN requires txn_id");
792 }
793 let client_txn_id = u64::from_le_bytes(payload[0..8].try_into().unwrap());
794
795 match self.active_txns.remove(&client_txn_id) {
796 Some(txn) => match self.db.abort(txn) {
797 Ok(_) => Message::ok(),
798 Err(e) => Message::error(&e.to_string()),
799 },
800 None => Message::error(&format!("Transaction not found: {}", client_txn_id)),
801 }
802 }
803
804 fn handle_put_path(&self, payload: &[u8]) -> Message {
805 match decode_path(payload) {
806 Ok((path, value)) => {
807 let txn = match self.db.begin_transaction() {
808 Ok(t) => t,
809 Err(e) => return Message::error(&e.to_string()),
810 };
811
812 let path_str = path.join("/");
813 if let Err(e) = self.db.put_path(txn, &path_str, value) {
814 let _ = self.db.abort(txn);
815 return Message::error(&e.to_string());
816 }
817
818 match self.db.commit(txn) {
819 Ok(_) => Message::ok(),
820 Err(e) => Message::error(&e.to_string()),
821 }
822 }
823 Err(e) => Message::error(&e.to_string()),
824 }
825 }
826
827 fn handle_get_path(&self, payload: &[u8]) -> Message {
828 match decode_path(payload) {
829 Ok((path, _)) => {
830 let txn = match self.db.begin_transaction() {
831 Ok(t) => t,
832 Err(e) => return Message::error(&e.to_string()),
833 };
834
835 let path_str = path.join("/");
836 let result = self.db.get_path(txn, &path_str);
837 let _ = self.db.abort(txn);
838
839 match result {
840 Ok(Some(value)) => Message::value(value),
841 Ok(None) => Message::new(opcode::VALUE, vec![]),
842 Err(e) => Message::error(&e.to_string()),
843 }
844 }
845 Err(e) => Message::error(&e.to_string()),
846 }
847 }
848
849 fn handle_checkpoint(&self) -> Message {
850 match self.db.checkpoint() {
851 Ok(_) => Message::ok(),
852 Err(e) => Message::error(&e.to_string()),
853 }
854 }
855
856 fn handle_stats(&self) -> Message {
857 let stats = self.stats.snapshot();
858 let stats_json = format!(
860 r#"{{"connections_total":{},"connections_active":{},"requests_total":{},"requests_success":{},"requests_error":{},"bytes_received":{},"bytes_sent":{},"uptime_secs":{},"memtable_size_bytes":0,"wal_size_bytes":0,"active_transactions":{}}}"#,
861 stats.connections_total,
862 stats.connections_active,
863 stats.requests_total,
864 stats.requests_success,
865 stats.requests_error,
866 stats.bytes_received,
867 stats.bytes_sent,
868 stats.uptime_secs,
869 self.active_txns.len()
870 );
871 Message::new(opcode::STATS_RESP, stats_json.into_bytes())
872 }
873
874 fn cleanup_transactions(&mut self) {
875 for (_client_id, txn) in self.active_txns.drain() {
877 let _ = self.db.abort(txn);
878 }
879 }
880}
881
/// Settings for the IPC server.
#[derive(Debug, Clone)]
pub struct IpcServerConfig {
    /// Filesystem path of the Unix socket to listen on.
    pub socket_path: PathBuf,

    /// Number of simultaneously active connections above which new ones are rejected.
    pub max_connections: usize,

    /// Configured worker-pool size.
    pub thread_pool_size: usize,

    /// Configured connection timeout, in seconds.
    pub connection_timeout_secs: u64,
}

impl Default for IpcServerConfig {
    /// Listens on `/tmp/sochdb.sock` with 100 connections, a pool size of 4
    /// and a 300-second connection timeout.
    fn default() -> Self {
        let socket_path = PathBuf::from("/tmp/sochdb.sock");
        Self {
            socket_path,
            max_connections: 100,
            thread_pool_size: 4,
            connection_timeout_secs: 300,
        }
    }
}

impl IpcServerConfig {
    /// Returns the configuration with the socket path replaced.
    pub fn with_socket_path<P: AsRef<Path>>(self, path: P) -> Self {
        Self {
            socket_path: path.as_ref().to_path_buf(),
            ..self
        }
    }

    /// Returns the configuration with the connection limit replaced.
    pub fn with_max_connections(self, max: usize) -> Self {
        Self {
            max_connections: max,
            ..self
        }
    }
}
924
925pub struct IpcServer {
927 db: Arc<Database>,
928 config: IpcServerConfig,
929 stats: Arc<ServerStats>,
930 running: Arc<AtomicBool>,
931 listener_handle: Mutex<Option<JoinHandle<()>>>,
932}
933
934impl IpcServer {
935 pub fn new(db: Arc<Database>, config: IpcServerConfig) -> Self {
937 Self {
938 db,
939 config,
940 stats: Arc::new(ServerStats::new()),
941 running: Arc::new(AtomicBool::new(false)),
942 listener_handle: Mutex::new(None),
943 }
944 }
945
946 pub fn with_defaults(db: Arc<Database>) -> Self {
948 Self::new(db, IpcServerConfig::default())
949 }
950
951 pub fn run(&self) -> Result<()> {
953 if self.running.swap(true, Ordering::SeqCst) {
954 return Err(IpcError::AlreadyRunning);
955 }
956
957 if self.config.socket_path.exists() {
959 std::fs::remove_file(&self.config.socket_path)?;
960 }
961
962 let listener = UnixListener::bind(&self.config.socket_path)?;
964 listener.set_nonblocking(false)?;
965
966 *self.stats.start_time.lock() = Some(Instant::now());
968
969 eprintln!("[IpcServer] Listening on {:?}", self.config.socket_path);
970
971 while self.running.load(Ordering::SeqCst) {
973 match listener.accept() {
974 Ok((stream, _addr)) => {
975 let active = self.stats.connections_active.load(Ordering::Relaxed);
977 if active >= self.config.max_connections as u64 {
978 eprintln!("[IpcServer] Connection limit reached, rejecting");
979 continue;
980 }
981
982 self.stats.connections_total.fetch_add(1, Ordering::Relaxed);
983 self.stats
984 .connections_active
985 .fetch_add(1, Ordering::Relaxed);
986
987 let db = Arc::clone(&self.db);
988 let stats = Arc::clone(&self.stats);
989
990 thread::spawn(move || {
992 let mut handler = ClientHandler::new(db, stream, Arc::clone(&stats));
993 if let Err(e) = handler.handle() {
994 eprintln!("[IpcServer] Client error: {}", e);
995 }
996 stats.connections_active.fetch_sub(1, Ordering::Relaxed);
997 });
998 }
999 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1000 thread::sleep(Duration::from_millis(100));
1002 }
1003 Err(e) => {
1004 eprintln!("[IpcServer] Accept error: {}", e);
1005 }
1006 }
1007 }
1008
1009 let _ = std::fs::remove_file(&self.config.socket_path);
1011
1012 Ok(())
1013 }
1014
1015 pub fn start(&self) -> Result<()> {
1017 if self.running.swap(true, Ordering::SeqCst) {
1018 return Err(IpcError::AlreadyRunning);
1019 }
1020
1021 let db = Arc::clone(&self.db);
1022 let config = self.config.clone();
1023 let stats = Arc::clone(&self.stats);
1024 let running = Arc::clone(&self.running);
1025
1026 let handle = thread::spawn(move || {
1027 if config.socket_path.exists() {
1030 let _ = std::fs::remove_file(&config.socket_path);
1031 }
1032
1033 let listener = match UnixListener::bind(&config.socket_path) {
1035 Ok(l) => l,
1036 Err(e) => {
1037 eprintln!("[IpcServer] Failed to bind: {}", e);
1038 running.store(false, Ordering::SeqCst);
1039 return;
1040 }
1041 };
1042 let _ = listener.set_nonblocking(false);
1043
1044 *stats.start_time.lock() = Some(Instant::now());
1046
1047 eprintln!("[IpcServer] Listening on {:?}", config.socket_path);
1048
1049 while running.load(Ordering::SeqCst) {
1051 match listener.accept() {
1052 Ok((stream, _addr)) => {
1053 let active = stats.connections_active.load(Ordering::Relaxed);
1055 if active >= config.max_connections as u64 {
1056 eprintln!("[IpcServer] Connection limit reached, rejecting");
1057 continue;
1058 }
1059
1060 stats.connections_total.fetch_add(1, Ordering::Relaxed);
1061 stats.connections_active.fetch_add(1, Ordering::Relaxed);
1062
1063 let db_clone = Arc::clone(&db);
1064 let stats_clone = Arc::clone(&stats);
1065
1066 thread::spawn(move || {
1068 let mut handler =
1069 ClientHandler::new(db_clone, stream, Arc::clone(&stats_clone));
1070 if let Err(e) = handler.handle() {
1071 eprintln!("[IpcServer] Client error: {}", e);
1072 }
1073 stats_clone
1074 .connections_active
1075 .fetch_sub(1, Ordering::Relaxed);
1076 });
1077 }
1078 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1079 thread::sleep(Duration::from_millis(100));
1081 }
1082 Err(e) => {
1083 eprintln!("[IpcServer] Accept error: {}", e);
1084 break;
1085 }
1086 }
1087 }
1088
1089 let _ = std::fs::remove_file(&config.socket_path);
1091 });
1092
1093 *self.listener_handle.lock() = Some(handle);
1094 Ok(())
1095 }
1096
1097 pub fn stop(&self) {
1099 self.running.store(false, Ordering::SeqCst);
1100
1101 let _ = UnixStream::connect(&self.config.socket_path);
1103
1104 if let Some(handle) = self.listener_handle.lock().take() {
1106 let _ = handle.join();
1107 }
1108 }
1109
1110 pub fn is_running(&self) -> bool {
1112 self.running.load(Ordering::SeqCst)
1113 }
1114
1115 pub fn stats(&self) -> ServerStatsSnapshot {
1117 self.stats.snapshot()
1118 }
1119
1120 pub fn socket_path(&self) -> &Path {
1122 &self.config.socket_path
1123 }
1124}
1125
1126impl Drop for IpcServer {
1127 fn drop(&mut self) {
1128 self.stop();
1129 }
1130}
1131
1132pub struct IpcClient {
1138 stream: UnixStream,
1139}
1140
1141impl IpcClient {
1142 pub fn connect<P: AsRef<Path>>(socket_path: P) -> Result<Self> {
1144 let stream = UnixStream::connect(socket_path)?;
1145 Ok(Self { stream })
1146 }
1147
1148 pub fn connect_with_timeout<P: AsRef<Path>>(socket_path: P, timeout: Duration) -> Result<Self> {
1150 let stream = UnixStream::connect(socket_path)?;
1151 stream.set_read_timeout(Some(timeout))?;
1152 stream.set_write_timeout(Some(timeout))?;
1153 Ok(Self { stream })
1154 }
1155
1156 fn request(&mut self, msg: Message) -> Result<Message> {
1158 msg.write_to(&mut self.stream)?;
1159 Message::read_from(&mut self.stream)
1160 }
1161
1162 pub fn ping(&mut self) -> Result<Duration> {
1164 let start = Instant::now();
1165 let resp = self.request(Message::new(opcode::PING, vec![]))?;
1166 if resp.opcode != opcode::PONG {
1167 return Err(IpcError::Protocol("Expected PONG".into()));
1168 }
1169 Ok(start.elapsed())
1170 }
1171
1172 pub fn put(&mut self, key: &[u8], value: &[u8]) -> Result<()> {
1174 let payload = encode_put(key, value);
1175 let resp = self.request(Message::new(opcode::PUT, payload))?;
1176 self.check_ok(resp)
1177 }
1178
1179 pub fn get(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
1181 let resp = self.request(Message::new(opcode::GET, key.to_vec()))?;
1182 match resp.opcode {
1183 opcode::VALUE if resp.payload.is_empty() => Ok(None),
1184 opcode::VALUE => Ok(Some(resp.payload)),
1185 opcode::ERROR => Err(IpcError::Database(
1186 String::from_utf8_lossy(&resp.payload).to_string(),
1187 )),
1188 _ => Err(IpcError::Protocol(format!(
1189 "Unexpected opcode: {:#x}",
1190 resp.opcode
1191 ))),
1192 }
1193 }
1194
1195 pub fn delete(&mut self, key: &[u8]) -> Result<()> {
1197 let resp = self.request(Message::new(opcode::DELETE, key.to_vec()))?;
1198 self.check_ok(resp)
1199 }
1200
1201 pub fn begin_txn(&mut self) -> Result<u64> {
1203 let resp = self.request(Message::new(opcode::BEGIN_TXN, vec![]))?;
1204 match resp.opcode {
1205 opcode::TXN_ID => {
1206 if resp.payload.len() >= 8 {
1207 Ok(u64::from_le_bytes(resp.payload[0..8].try_into().unwrap()))
1208 } else {
1209 Err(IpcError::Protocol("TXN_ID response too short".into()))
1210 }
1211 }
1212 opcode::ERROR => Err(IpcError::Database(
1213 String::from_utf8_lossy(&resp.payload).to_string(),
1214 )),
1215 _ => Err(IpcError::Protocol(format!(
1216 "Unexpected opcode: {:#x}",
1217 resp.opcode
1218 ))),
1219 }
1220 }
1221
1222 pub fn commit_txn(&mut self, txn_id: u64) -> Result<u64> {
1224 let resp = self.request(Message::new(
1225 opcode::COMMIT_TXN,
1226 txn_id.to_le_bytes().to_vec(),
1227 ))?;
1228 match resp.opcode {
1229 opcode::TXN_ID => {
1230 if resp.payload.len() >= 8 {
1231 Ok(u64::from_le_bytes(resp.payload[0..8].try_into().unwrap()))
1232 } else {
1233 Err(IpcError::Protocol("TXN_ID response too short".into()))
1234 }
1235 }
1236 opcode::ERROR => Err(IpcError::Database(
1237 String::from_utf8_lossy(&resp.payload).to_string(),
1238 )),
1239 _ => Err(IpcError::Protocol(format!(
1240 "Unexpected opcode: {:#x}",
1241 resp.opcode
1242 ))),
1243 }
1244 }
1245
1246 pub fn abort_txn(&mut self, txn_id: u64) -> Result<()> {
1248 let resp = self.request(Message::new(
1249 opcode::ABORT_TXN,
1250 txn_id.to_le_bytes().to_vec(),
1251 ))?;
1252 self.check_ok(resp)
1253 }
1254
1255 pub fn put_path(&mut self, path: &[&str], value: &[u8]) -> Result<()> {
1257 let payload = encode_put_path(path, value);
1258 let resp = self.request(Message::new(opcode::PUT_PATH, payload))?;
1259 self.check_ok(resp)
1260 }
1261
1262 pub fn get_path(&mut self, path: &[&str]) -> Result<Option<Vec<u8>>> {
1264 let payload = encode_put_path(path, &[]);
1265 let resp = self.request(Message::new(opcode::GET_PATH, payload))?;
1266 match resp.opcode {
1267 opcode::VALUE if resp.payload.is_empty() => Ok(None),
1268 opcode::VALUE => Ok(Some(resp.payload)),
1269 opcode::ERROR => Err(IpcError::Database(
1270 String::from_utf8_lossy(&resp.payload).to_string(),
1271 )),
1272 _ => Err(IpcError::Protocol(format!(
1273 "Unexpected opcode: {:#x}",
1274 resp.opcode
1275 ))),
1276 }
1277 }
1278
1279 pub fn checkpoint(&mut self) -> Result<()> {
1281 let resp = self.request(Message::new(opcode::CHECKPOINT, vec![]))?;
1282 self.check_ok(resp)
1283 }
1284
1285 pub fn stats(&mut self) -> Result<HashMap<String, u64>> {
1287 let resp = self.request(Message::new(opcode::STATS, vec![]))?;
1288 match resp.opcode {
1289 opcode::STATS_RESP => {
1290 let stats_str = String::from_utf8_lossy(&resp.payload);
1291
1292 let stats: HashMap<String, u64> = serde_json::from_str(&stats_str)
1294 .map_err(|e| IpcError::Protocol(format!("Failed to parse stats JSON: {}", e)))?;
1295
1296 Ok(stats)
1297 }
1298 opcode::ERROR => Err(IpcError::Database(
1299 String::from_utf8_lossy(&resp.payload).to_string(),
1300 )),
1301 _ => Err(IpcError::Protocol(format!(
1302 "Unexpected opcode: {:#x}",
1303 resp.opcode
1304 ))),
1305 }
1306 }
1307
1308 fn check_ok(&self, resp: Message) -> Result<()> {
1309 match resp.opcode {
1310 opcode::OK => Ok(()),
1311 opcode::ERROR => Err(IpcError::Database(
1312 String::from_utf8_lossy(&resp.payload).to_string(),
1313 )),
1314 _ => Err(IpcError::Protocol(format!(
1315 "Unexpected opcode: {:#x}",
1316 resp.opcode
1317 ))),
1318 }
1319 }
1320}
1321
1322#[cfg(test)]
1327mod tests {
1328 use super::*;
1329 use std::time::Duration;
1330 use tempfile::TempDir;
1331
1332 fn setup_test_server() -> (Arc<Database>, TempDir, PathBuf) {
1333 let temp_dir = TempDir::new().unwrap();
1334 let db_path = temp_dir.path().join("test.db");
1335 let socket_path = temp_dir.path().join("test.sock");
1336
1337 let db = Database::open(&db_path).unwrap();
1338 (db, temp_dir, socket_path)
1339 }
1340
/// A message written with `write_to` must be read back by `read_from`
/// with the same opcode and payload.
#[test]
fn test_message_roundtrip() {
    let sent = Message::new(0x01, b"hello world".to_vec());

    // Serialize into an in-memory buffer.
    let mut wire = Vec::new();
    sent.write_to(&mut wire).unwrap();

    // Deserialize from the same bytes.
    let mut reader = std::io::Cursor::new(wire);
    let received = Message::read_from(&mut reader).unwrap();

    assert_eq!(received.opcode, sent.opcode);
    assert_eq!(received.payload, sent.payload);
}
1354
/// `decode_put` must recover exactly the key and value given to `encode_put`.
#[test]
fn test_encode_decode_put() {
    let key = b"test-key";
    let value = b"test-value";

    let wire = encode_put(key, value);
    let (key_out, value_out) = decode_put(&wire).unwrap();

    assert_eq!(key_out, key);
    assert_eq!(value_out, value);
}
1366
/// `decode_path` must recover the path segments (as owned strings) and the
/// value that were given to `encode_put_path`.
#[test]
fn test_encode_decode_path() {
    let path: &[&str] = &["users", "alice", "settings"];
    let value = b"preferences";

    let wire = encode_put_path(path, value);
    let (decoded_path, decoded_value) = decode_path(&wire).unwrap();

    let expected_path: Vec<String> = path
        .iter()
        .map(|segment| String::from(*segment))
        .collect();
    assert_eq!(decoded_path, expected_path);
    assert_eq!(decoded_value, value);
}
1379
/// End-to-end smoke test: ping, put, get (hit and miss) and delete through
/// a client connected to a freshly started server.
#[test]
fn test_server_client_basic() {
    let (db, _temp_dir, socket_path) = setup_test_server();

    let server = IpcServer::new(
        Arc::clone(&db),
        IpcServerConfig::default().with_socket_path(&socket_path),
    );
    server.start().unwrap();

    // Pause between start and connect (presumably to let the server begin
    // accepting connections).
    thread::sleep(Duration::from_millis(100));

    let mut client = IpcClient::connect(&socket_path).unwrap();

    // Ping reports a latency; it should be well under a second locally.
    let round_trip = client.ping().unwrap();
    assert!(round_trip < Duration::from_secs(1));

    // A stored key reads back with its value.
    client.put(b"key1", b"value1").unwrap();
    assert_eq!(client.get(b"key1").unwrap(), Some(b"value1".to_vec()));

    // A key that was never written reads back as `None`.
    assert_eq!(client.get(b"nonexistent").unwrap(), None);

    // After deletion the key is gone.
    client.delete(b"key1").unwrap();
    assert_eq!(client.get(b"key1").unwrap(), None);

    server.stop();
}
1416
/// Exercises the transaction opcodes: begin + commit, then begin + abort.
#[test]
fn test_server_client_transactions() {
    let (db, _temp_dir, socket_path) = setup_test_server();

    let server = IpcServer::new(
        Arc::clone(&db),
        IpcServerConfig::default().with_socket_path(&socket_path),
    );
    server.start().unwrap();

    // Pause between start and connect (presumably to let the server begin
    // accepting connections).
    thread::sleep(Duration::from_millis(100));

    let mut client = IpcClient::connect(&socket_path).unwrap();

    // First transaction: begin, then commit. Both the id and the commit
    // timestamp are expected to be positive.
    let committed_txn = client.begin_txn().unwrap();
    assert!(committed_txn > 0);

    let commit_ts = client.commit_txn(committed_txn).unwrap();
    assert!(commit_ts > 0);

    // Second transaction: begin, then abort.
    let aborted_txn = client.begin_txn().unwrap();
    client.abort_txn(aborted_txn).unwrap();

    server.stop();
}
1443
/// Path-addressed put/get: a written path reads back, an unwritten one is `None`.
#[test]
fn test_server_client_paths() {
    let (db, _temp_dir, socket_path) = setup_test_server();

    let server = IpcServer::new(
        Arc::clone(&db),
        IpcServerConfig::default().with_socket_path(&socket_path),
    );
    server.start().unwrap();

    // Pause between start and connect (presumably to let the server begin
    // accepting connections).
    thread::sleep(Duration::from_millis(100));

    let mut client = IpcClient::connect(&socket_path).unwrap();

    let alice_email: &[&str] = &["users", "alice", "email"];
    client.put_path(alice_email, b"alice@example.com").unwrap();

    // The path that was written returns the stored bytes.
    assert_eq!(
        client.get_path(alice_email).unwrap(),
        Some(b"alice@example.com".to_vec())
    );

    // A sibling path that was never written returns nothing.
    assert_eq!(client.get_path(&["users", "bob", "email"]).unwrap(), None);

    server.stop();
}
1471
/// Request counters must be visible both over the wire (`client.stats()`)
/// and directly on the server object (`server.stats()`).
#[test]
fn test_server_stats() {
    let (db, _temp_dir, socket_path) = setup_test_server();

    let server = IpcServer::new(
        Arc::clone(&db),
        IpcServerConfig::default().with_socket_path(&socket_path),
    );
    server.start().unwrap();

    // Pause between start and connect (presumably to let the server begin
    // accepting connections).
    thread::sleep(Duration::from_millis(100));

    let mut client = IpcClient::connect(&socket_path).unwrap();

    // Issue three requests before asking for statistics.
    client.ping().unwrap();
    client.put(b"k", b"v").unwrap();
    client.get(b"k").unwrap();

    // Statistics as reported through the protocol.
    let reported = client.stats().unwrap();
    assert!(reported.contains_key("requests_total"));
    assert!(reported["requests_total"] >= 4);

    // Statistics as seen on the server side; the client is still connected.
    let observed = server.stats();
    assert!(observed.requests_total >= 4);
    assert!(observed.connections_active >= 1);

    server.stop();
}
1501
/// Five clients on separate threads each write and read back their own key;
/// afterwards the server must have counted exactly five connections.
#[test]
fn test_multiple_clients() {
    let (db, _temp_dir, socket_path) = setup_test_server();

    let server = IpcServer::new(
        Arc::clone(&db),
        IpcServerConfig::default()
            .with_socket_path(&socket_path)
            .with_max_connections(10),
    );
    server.start().unwrap();

    // Pause between start and connect (presumably to let the server begin
    // accepting connections).
    thread::sleep(Duration::from_millis(100));

    // Spawn every worker first (collect is eager), then join them all.
    let workers: Vec<_> = (0..5)
        .map(|i| {
            let path = socket_path.clone();
            thread::spawn(move || {
                let mut client = IpcClient::connect(&path).unwrap();
                let key = format!("key-{}", i);
                let value = format!("value-{}", i);

                client.put(key.as_bytes(), value.as_bytes()).unwrap();
                let result = client.get(key.as_bytes()).unwrap();
                assert_eq!(result, Some(value.into_bytes()));
            })
        })
        .collect();

    for worker in workers {
        // A panic inside a worker (failed assertion) surfaces here.
        worker.join().unwrap();
    }

    assert_eq!(server.stats().connections_total, 5);

    server.stop();
}
1541}