1use crate::database::{Database, TxnHandle};
87use parking_lot::Mutex;
88use std::collections::HashMap;
89use std::io::{BufReader, BufWriter, Read, Write};
90use std::os::unix::net::{UnixListener, UnixStream};
91use std::path::{Path, PathBuf};
92use std::sync::Arc;
93use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
94use std::thread::{self, JoinHandle};
95use std::time::{Duration, Instant};
96use thiserror::Error;
97
/// Opcode bytes that tag every [`Message`] frame.
///
/// Client requests use `0x01..=0x0f`; server responses set the high bit (`0x80..=0x87`).
mod opcode {
    // ---- Requests (client -> server) ----------------------------------------

    /// Store a key/value pair (payload: `u32 key_len | key | value`).
    pub const PUT: u8 = 0x01;
    /// Read a key (payload: raw key bytes).
    pub const GET: u8 = 0x02;
    /// Delete a key (payload: raw key bytes).
    pub const DELETE: u8 = 0x03;
    /// Open an explicit transaction; answered with `TXN_ID`.
    pub const BEGIN_TXN: u8 = 0x04;
    /// Commit an explicit transaction (payload: `u64` client txn id).
    pub const COMMIT_TXN: u8 = 0x05;
    /// Abort an explicit transaction (payload: `u64` client txn id).
    pub const ABORT_TXN: u8 = 0x06;
    /// Run a path query.
    pub const QUERY: u8 = 0x07;
    /// Register a table schema.
    pub const CREATE_TABLE: u8 = 0x08;
    /// Store a value under a segmented path.
    pub const PUT_PATH: u8 = 0x09;
    /// Read a value under a segmented path.
    pub const GET_PATH: u8 = 0x0a;
    /// Scan all entries under a path prefix.
    pub const SCAN: u8 = 0x0b;
    /// Ask the database to checkpoint.
    pub const CHECKPOINT: u8 = 0x0c;
    /// Request server statistics; answered with `STATS_RESP`.
    pub const STATS: u8 = 0x0d;
    /// Liveness probe; answered with `PONG`.
    pub const PING: u8 = 0x0e;
    /// Submit a SQL string.
    pub const EXECUTE_SQL: u8 = 0x0f;

    // ---- Responses (server -> client) ---------------------------------------

    /// Success with no payload.
    pub const OK: u8 = 0x80;
    /// Failure; payload is a UTF-8 error message.
    pub const ERROR: u8 = 0x81;
    /// Success carrying a value payload.
    pub const VALUE: u8 = 0x82;
    /// Success carrying a little-endian `u64` (transaction id or commit timestamp).
    pub const TXN_ID: u8 = 0x83;
    /// Presumably reserved for streamed result rows; unused, hence the allow.
    #[allow(dead_code)]
    pub const ROW: u8 = 0x84;
    /// Presumably reserved to terminate a row stream; unused, hence the allow.
    #[allow(dead_code)]
    pub const END_STREAM: u8 = 0x85;
    /// Statistics as a JSON object.
    pub const STATS_RESP: u8 = 0x86;
    /// Reply to `PING`.
    pub const PONG: u8 = 0x87;
}
132
/// Largest payload (16 MiB) a received frame may declare; `Message::read_from`
/// rejects anything bigger before allocating.
const MAX_MESSAGE_SIZE: usize = 16 << 20;
135
136#[derive(Debug, Error)]
141pub enum IpcError {
142 #[error("I/O error: {0}")]
143 Io(#[from] std::io::Error),
144
145 #[error("Database error: {0}")]
146 Database(String),
147
148 #[error("Protocol error: {0}")]
149 Protocol(String),
150
151 #[error("Server already running")]
152 AlreadyRunning,
153
154 #[error("Server not running")]
155 NotRunning,
156
157 #[error("Connection closed")]
158 ConnectionClosed,
159
160 #[error("Message too large: {0} bytes (max: {1})")]
161 MessageTooLarge(usize, usize),
162
163 #[error("Invalid opcode: {0:#x}")]
164 InvalidOpcode(u8),
165
166 #[error("Transaction not found: {0}")]
167 TxnNotFound(u64),
168}
169
170pub type Result<T> = std::result::Result<T, IpcError>;
171
172#[derive(Debug, Clone)]
178pub struct Message {
179 pub opcode: u8,
180 pub payload: Vec<u8>,
181}
182
183impl Message {
184 pub fn new(opcode: u8, payload: Vec<u8>) -> Self {
185 Self { opcode, payload }
186 }
187
188 pub fn ok() -> Self {
189 Self::new(opcode::OK, vec![])
190 }
191
192 pub fn error(msg: &str) -> Self {
193 Self::new(opcode::ERROR, msg.as_bytes().to_vec())
194 }
195
196 pub fn value(data: Vec<u8>) -> Self {
197 Self::new(opcode::VALUE, data)
198 }
199
200 pub fn txn_id(id: u64) -> Self {
201 Self::new(opcode::TXN_ID, id.to_le_bytes().to_vec())
202 }
203
204 pub fn read_from<R: Read>(reader: &mut R) -> Result<Self> {
206 let mut opcode_buf = [0u8; 1];
208 match reader.read_exact(&mut opcode_buf) {
209 Ok(_) => {}
210 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
211 return Err(IpcError::ConnectionClosed);
212 }
213 Err(e) => return Err(e.into()),
214 }
215 let opcode = opcode_buf[0];
216
217 let mut len_buf = [0u8; 4];
219 reader.read_exact(&mut len_buf)?;
220 let len = u32::from_le_bytes(len_buf) as usize;
221
222 if len > MAX_MESSAGE_SIZE {
224 return Err(IpcError::MessageTooLarge(len, MAX_MESSAGE_SIZE));
225 }
226
227 let mut payload = vec![0u8; len];
229 if len > 0 {
230 reader.read_exact(&mut payload)?;
231 }
232
233 Ok(Self { opcode, payload })
234 }
235
236 pub fn write_to<W: Write>(&self, writer: &mut W) -> Result<()> {
238 writer.write_all(&[self.opcode])?;
240
241 let len = self.payload.len() as u32;
243 writer.write_all(&len.to_le_bytes())?;
244
245 if !self.payload.is_empty() {
247 writer.write_all(&self.payload)?;
248 }
249
250 writer.flush()?;
251 Ok(())
252 }
253}
254
/// Encodes a `PUT` payload as `u32 key_len (LE) | key | value`.
///
/// The value carries no length prefix; it runs to the end of the payload.
fn encode_put(key: &[u8], value: &[u8]) -> Vec<u8> {
    let key_len_le = (key.len() as u32).to_le_bytes();
    [&key_len_le[..], key, value].concat()
}
267
268fn decode_put(payload: &[u8]) -> Result<(&[u8], &[u8])> {
270 if payload.len() < 4 {
271 return Err(IpcError::Protocol("PUT payload too short".into()));
272 }
273 let key_len = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
274 if payload.len() < 4 + key_len {
275 return Err(IpcError::Protocol("PUT payload key truncated".into()));
276 }
277 let key = &payload[4..4 + key_len];
278 let value = &payload[4 + key_len..];
279 Ok((key, value))
280}
281
/// Encodes a segmented path plus value.
///
/// Layout (little-endian): `u16 segment_count | segment_count × (u16 len | UTF-8 bytes) | value`.
///
/// NOTE(review): the count and segment lengths are cast with `as u16`, so more than 65535
/// segments, or a segment longer than 65535 bytes, is silently truncated and produces a
/// payload the server will mis-parse. Callers must keep paths within those limits.
fn encode_put_path(path: &[&str], value: &[u8]) -> Vec<u8> {
    let mut out: Vec<u8> = (path.len() as u16).to_le_bytes().to_vec();
    for seg in path.iter().map(|s| s.as_bytes()) {
        out.extend_from_slice(&(seg.len() as u16).to_le_bytes());
        out.extend_from_slice(seg);
    }
    out.extend_from_slice(value);
    out
}
294
295fn decode_path(payload: &[u8]) -> Result<(Vec<String>, &[u8])> {
297 if payload.len() < 2 {
298 return Err(IpcError::Protocol("Path payload too short".into()));
299 }
300 let count = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
301 let mut offset = 2;
302 let mut path = Vec::with_capacity(count);
303
304 for _ in 0..count {
305 if offset + 2 > payload.len() {
306 return Err(IpcError::Protocol("Path segment length truncated".into()));
307 }
308 let seg_len = u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
309 offset += 2;
310 if offset + seg_len > payload.len() {
311 return Err(IpcError::Protocol("Path segment truncated".into()));
312 }
313 let segment = std::str::from_utf8(&payload[offset..offset + seg_len])
314 .map_err(|_| IpcError::Protocol("Invalid UTF-8 in path".into()))?;
315 path.push(segment.to_string());
316 offset += seg_len;
317 }
318
319 Ok((path, &payload[offset..]))
320}
321
322#[derive(Debug, Default)]
327pub struct ServerStats {
328 pub connections_total: AtomicU64,
329 pub connections_active: AtomicU64,
330 pub requests_total: AtomicU64,
331 pub requests_success: AtomicU64,
332 pub requests_error: AtomicU64,
333 pub bytes_received: AtomicU64,
334 pub bytes_sent: AtomicU64,
335 pub start_time: Mutex<Option<Instant>>,
336}
337
338impl ServerStats {
339 pub fn new() -> Self {
340 Self::default()
341 }
342
343 pub fn snapshot(&self) -> ServerStatsSnapshot {
344 ServerStatsSnapshot {
345 connections_total: self.connections_total.load(Ordering::Relaxed),
346 connections_active: self.connections_active.load(Ordering::Relaxed),
347 requests_total: self.requests_total.load(Ordering::Relaxed),
348 requests_success: self.requests_success.load(Ordering::Relaxed),
349 requests_error: self.requests_error.load(Ordering::Relaxed),
350 bytes_received: self.bytes_received.load(Ordering::Relaxed),
351 bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
352 uptime_secs: self
353 .start_time
354 .lock()
355 .map(|t| t.elapsed().as_secs())
356 .unwrap_or(0),
357 }
358 }
359}
360
/// Point-in-time copy of the server counters plus uptime, for reporting.
#[derive(Debug, Clone)]
pub struct ServerStatsSnapshot {
    /// Connections accepted, excluding those rejected over the connection limit.
    pub connections_total: u64,
    /// Connections being served when the snapshot was taken.
    pub connections_active: u64,
    /// Requests read from clients.
    pub requests_total: u64,
    /// Requests answered with a non-`ERROR` response.
    pub requests_success: u64,
    /// Requests answered with an `ERROR` response.
    pub requests_error: u64,
    /// Request bytes received, including the 5-byte frame headers.
    pub bytes_received: u64,
    /// Response bytes sent, including the 5-byte frame headers.
    pub bytes_sent: u64,
    /// Whole seconds since the listener started; 0 if it never started.
    pub uptime_secs: u64,
}
372
373struct ClientHandler {
378 db: Arc<Database>,
379 stream: UnixStream,
380 stats: Arc<ServerStats>,
381 active_txns: HashMap<u64, TxnHandle>, next_client_txn_id: u64,
383}
384
385impl ClientHandler {
386 fn new(db: Arc<Database>, stream: UnixStream, stats: Arc<ServerStats>) -> Self {
387 Self {
388 db,
389 stream,
390 stats,
391 active_txns: HashMap::new(),
392 next_client_txn_id: 1,
393 }
394 }
395
396 fn handle(&mut self) -> Result<()> {
397 self.stream
399 .set_read_timeout(Some(Duration::from_secs(30)))?;
400
401 let mut reader = BufReader::new(self.stream.try_clone()?);
402 let mut writer = BufWriter::new(self.stream.try_clone()?);
403
404 loop {
405 let request = match Message::read_from(&mut reader) {
407 Ok(msg) => msg,
408 Err(IpcError::ConnectionClosed) => {
409 self.cleanup_transactions();
411 return Ok(());
412 }
413 Err(e) => return Err(e),
414 };
415
416 self.stats.requests_total.fetch_add(1, Ordering::Relaxed);
417 self.stats
418 .bytes_received
419 .fetch_add((5 + request.payload.len()) as u64, Ordering::Relaxed);
420
421 let response = self.process_request(&request);
423
424 if response.opcode == opcode::ERROR {
426 self.stats.requests_error.fetch_add(1, Ordering::Relaxed);
427 } else {
428 self.stats.requests_success.fetch_add(1, Ordering::Relaxed);
429 }
430
431 self.stats
433 .bytes_sent
434 .fetch_add((5 + response.payload.len()) as u64, Ordering::Relaxed);
435 response.write_to(&mut writer)?;
436 }
437 }
438
439 fn process_request(&mut self, request: &Message) -> Message {
440 match request.opcode {
441 opcode::PING => Message::new(opcode::PONG, vec![]),
442
443 opcode::PUT => self.handle_put(&request.payload),
444 opcode::GET => self.handle_get(&request.payload),
445 opcode::DELETE => self.handle_delete(&request.payload),
446
447 opcode::BEGIN_TXN => self.handle_begin_txn(),
448 opcode::COMMIT_TXN => self.handle_commit_txn(&request.payload),
449 opcode::ABORT_TXN => self.handle_abort_txn(&request.payload),
450
451 opcode::PUT_PATH => self.handle_put_path(&request.payload),
452 opcode::GET_PATH => self.handle_get_path(&request.payload),
453
454 opcode::QUERY => self.handle_query(&request.payload),
455 opcode::CREATE_TABLE => self.handle_create_table(&request.payload),
456 opcode::SCAN => self.handle_scan(&request.payload),
457 opcode::EXECUTE_SQL => self.handle_execute_sql(&request.payload),
458
459 opcode::CHECKPOINT => self.handle_checkpoint(),
460 opcode::STATS => self.handle_stats(),
461
462 _ => Message::error(&format!("Unknown opcode: {:#x}", request.opcode)),
463 }
464 }
465
466 fn handle_execute_sql(&self, payload: &[u8]) -> Message {
467 let sql = match std::str::from_utf8(payload) {
469 Ok(s) => s,
470 Err(_) => return Message::error("Invalid UTF-8 in SQL query"),
471 };
472
473 let result = serde_json::json!({
476 "error": "SQL execution must be implemented client-side. Use Python SDK for full SQL support.",
477 "sql": sql
478 });
479
480 match serde_json::to_vec(&result) {
481 Ok(json) => Message::value(json),
482 Err(e) => Message::error(&format!("Failed to serialize error: {}", e)),
483 }
484 }
485
486 fn handle_query(&self, payload: &[u8]) -> Message {
487 let mut offset = 0;
489
490 if payload.len() < 2 {
491 return Message::error("Query payload too short");
492 }
493
494 let path_len = u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
496 offset += 2;
497 if offset + path_len > payload.len() {
498 return Message::error("Query path truncated");
499 }
500 let path = match std::str::from_utf8(&payload[offset..offset + path_len]) {
501 Ok(s) => s,
502 Err(_) => return Message::error("Invalid UTF-8 in query path"),
503 };
504 offset += path_len;
505
506 if offset + 8 > payload.len() {
508 return Message::error("Query limit/offset truncated");
509 }
510 let limit_val =
511 u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
512 offset += 4;
513 let offset_val =
514 u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
515 offset += 4;
516
517 if offset + 2 > payload.len() {
519 return Message::error("Query columns count truncated");
520 }
521 let cols_count =
522 u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
523 offset += 2;
524
525 let mut columns = Vec::with_capacity(cols_count);
526 for _ in 0..cols_count {
527 if offset + 2 > payload.len() {
528 return Message::error("Query column length truncated");
529 }
530 let col_len =
531 u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
532 offset += 2;
533
534 if offset + col_len > payload.len() {
535 return Message::error("Query column truncated");
536 }
537 let col = match std::str::from_utf8(&payload[offset..offset + col_len]) {
538 Ok(s) => s.to_string(),
539 Err(_) => return Message::error("Invalid UTF-8 in query column"),
540 };
541 columns.push(col);
542 offset += col_len;
543 }
544
545 let txn = match self.db.begin_transaction() {
548 Ok(t) => t,
549 Err(e) => return Message::error(&e.to_string()),
550 };
551
552 let mut builder = self.db.query(txn, path);
553
554 if limit_val > 0 {
555 builder = builder.limit(limit_val);
556 }
557 if offset_val > 0 {
558 builder = builder.offset(offset_val);
559 }
560
561 if !columns.is_empty() {
562 let cols_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
563 builder = builder.columns(&cols_refs);
564 }
565
566 let result = builder.to_toon();
567 let _ = self.db.abort(txn); match result {
570 Ok(soch_str) => Message::new(opcode::VALUE, soch_str.into_bytes()),
571 Err(e) => Message::error(&e.to_string()),
572 }
573 }
574
575 fn handle_scan(&self, payload: &[u8]) -> Message {
576 let prefix = match std::str::from_utf8(payload) {
577 Ok(s) => s,
578 Err(_) => return Message::error("Invalid UTF-8 in scan prefix"),
579 };
580
581 let txn = match self.db.begin_transaction() {
582 Ok(t) => t,
583 Err(e) => return Message::error(&e.to_string()),
584 };
585
586 let result = self.db.scan_path(txn, prefix);
587 let _ = self.db.abort(txn);
588
589 match result {
590 Ok(items) => {
591 let mut buf = Vec::new();
595 buf.extend_from_slice(&(items.len() as u32).to_le_bytes());
596
597 for (key, val) in items {
598 let key_bytes = key.as_bytes();
599 buf.extend_from_slice(&(key_bytes.len() as u16).to_le_bytes());
600 buf.extend_from_slice(key_bytes);
601 buf.extend_from_slice(&(val.len() as u32).to_le_bytes());
602 buf.extend_from_slice(&val);
603 }
604
605 Message::new(opcode::VALUE, buf)
606 }
607 Err(e) => Message::error(&e.to_string()),
608 }
609 }
610
611 fn handle_create_table(&self, payload: &[u8]) -> Message {
612 let _schema_json = match std::str::from_utf8(payload) {
614 Ok(s) => s,
615 Err(_) => return Message::error("Invalid UTF-8 in schema"),
616 };
617
618 let mut offset = 0;
632 if payload.len() < 2 {
633 return Message::error("Schema payload too short");
634 }
635
636 let name_len = u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
637 offset += 2;
638 if offset + name_len > payload.len() {
639 return Message::error("Schema name truncated");
640 }
641 let name = match std::str::from_utf8(&payload[offset..offset + name_len]) {
642 Ok(s) => s.to_string(),
643 Err(_) => return Message::error("Invalid UTF-8 in schema name"),
644 };
645 offset += name_len;
646
647 if offset + 2 > payload.len() {
648 return Message::error("Schema column count truncated");
649 }
650 let col_count =
651 u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
652 offset += 2;
653
654 let mut columns = Vec::with_capacity(col_count);
655 for _ in 0..col_count {
656 if offset + 2 > payload.len() {
657 return Message::error("Column name length truncated");
658 }
659 let col_name_len =
660 u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
661 offset += 2;
662
663 if offset + col_name_len > payload.len() {
664 return Message::error("Column name truncated");
665 }
666 let col_name = match std::str::from_utf8(&payload[offset..offset + col_name_len]) {
667 Ok(s) => s.to_string(),
668 Err(_) => return Message::error("Invalid UTF-8 in column name"),
669 };
670 offset += col_name_len;
671
672 if offset + 2 > payload.len() {
673 return Message::error("Column type/nullable truncated");
674 }
675 let type_byte = payload[offset];
676 offset += 1;
677 let nullable_byte = payload[offset];
678 offset += 1;
679
680 let col_type = match type_byte {
681 0 => crate::database::ColumnType::Int64,
682 1 => crate::database::ColumnType::UInt64,
683 2 => crate::database::ColumnType::Float64,
684 3 => crate::database::ColumnType::Text,
685 4 => crate::database::ColumnType::Binary,
686 5 => crate::database::ColumnType::Bool,
687 _ => return Message::error("Invalid column type"),
688 };
689
690 columns.push(crate::database::ColumnDef {
691 name: col_name,
692 col_type,
693 nullable: nullable_byte != 0,
694 });
695 }
696
697 let schema = crate::database::TableSchema { name, columns };
698
699 match self.db.register_table(schema) {
700 Ok(_) => Message::ok(),
701 Err(e) => Message::error(&e.to_string()),
702 }
703 }
704
705 fn handle_put(&self, payload: &[u8]) -> Message {
707 match decode_put(payload) {
708 Ok((key, value)) => {
709 let txn = match self.db.begin_transaction() {
711 Ok(t) => t,
712 Err(e) => return Message::error(&e.to_string()),
713 };
714
715 if let Err(e) = self.db.put(txn, key, value) {
716 let _ = self.db.abort(txn);
717 return Message::error(&e.to_string());
718 }
719
720 match self.db.commit(txn) {
721 Ok(_) => Message::ok(),
722 Err(e) => Message::error(&e.to_string()),
723 }
724 }
725 Err(e) => Message::error(&e.to_string()),
726 }
727 }
728
729 fn handle_get(&self, payload: &[u8]) -> Message {
731 let txn = match self.db.begin_transaction() {
733 Ok(t) => t,
734 Err(e) => return Message::error(&e.to_string()),
735 };
736
737 let result = self.db.get(txn, payload);
738 let _ = self.db.abort(txn); match result {
741 Ok(Some(value)) => Message::value(value),
742 Ok(None) => Message::new(opcode::VALUE, vec![]),
743 Err(e) => Message::error(&e.to_string()),
744 }
745 }
746
747 fn handle_delete(&self, payload: &[u8]) -> Message {
749 let txn = match self.db.begin_transaction() {
750 Ok(t) => t,
751 Err(e) => return Message::error(&e.to_string()),
752 };
753
754 if let Err(e) = self.db.delete(txn, payload) {
755 let _ = self.db.abort(txn);
756 return Message::error(&e.to_string());
757 }
758
759 match self.db.commit(txn) {
760 Ok(_) => Message::ok(),
761 Err(e) => Message::error(&e.to_string()),
762 }
763 }
764
765 fn handle_begin_txn(&mut self) -> Message {
766 match self.db.begin_transaction() {
767 Ok(txn) => {
768 let client_txn_id = self.next_client_txn_id;
769 self.next_client_txn_id += 1;
770 self.active_txns.insert(client_txn_id, txn);
771 Message::txn_id(client_txn_id)
772 }
773 Err(e) => Message::error(&e.to_string()),
774 }
775 }
776
777 fn handle_commit_txn(&mut self, payload: &[u8]) -> Message {
778 if payload.len() < 8 {
779 return Message::error("COMMIT_TXN requires txn_id");
780 }
781 let client_txn_id = u64::from_le_bytes(payload[0..8].try_into().unwrap());
782
783 match self.active_txns.remove(&client_txn_id) {
784 Some(txn) => match self.db.commit(txn) {
785 Ok(commit_ts) => Message::txn_id(commit_ts),
786 Err(e) => Message::error(&e.to_string()),
787 },
788 None => Message::error(&format!("Transaction not found: {}", client_txn_id)),
789 }
790 }
791
792 fn handle_abort_txn(&mut self, payload: &[u8]) -> Message {
793 if payload.len() < 8 {
794 return Message::error("ABORT_TXN requires txn_id");
795 }
796 let client_txn_id = u64::from_le_bytes(payload[0..8].try_into().unwrap());
797
798 match self.active_txns.remove(&client_txn_id) {
799 Some(txn) => match self.db.abort(txn) {
800 Ok(_) => Message::ok(),
801 Err(e) => Message::error(&e.to_string()),
802 },
803 None => Message::error(&format!("Transaction not found: {}", client_txn_id)),
804 }
805 }
806
807 fn handle_put_path(&self, payload: &[u8]) -> Message {
808 match decode_path(payload) {
809 Ok((path, value)) => {
810 let txn = match self.db.begin_transaction() {
811 Ok(t) => t,
812 Err(e) => return Message::error(&e.to_string()),
813 };
814
815 let path_str = path.join("/");
816 if let Err(e) = self.db.put_path(txn, &path_str, value) {
817 let _ = self.db.abort(txn);
818 return Message::error(&e.to_string());
819 }
820
821 match self.db.commit(txn) {
822 Ok(_) => Message::ok(),
823 Err(e) => Message::error(&e.to_string()),
824 }
825 }
826 Err(e) => Message::error(&e.to_string()),
827 }
828 }
829
830 fn handle_get_path(&self, payload: &[u8]) -> Message {
831 match decode_path(payload) {
832 Ok((path, _)) => {
833 let txn = match self.db.begin_transaction() {
834 Ok(t) => t,
835 Err(e) => return Message::error(&e.to_string()),
836 };
837
838 let path_str = path.join("/");
839 let result = self.db.get_path(txn, &path_str);
840 let _ = self.db.abort(txn);
841
842 match result {
843 Ok(Some(value)) => Message::value(value),
844 Ok(None) => Message::new(opcode::VALUE, vec![]),
845 Err(e) => Message::error(&e.to_string()),
846 }
847 }
848 Err(e) => Message::error(&e.to_string()),
849 }
850 }
851
852 fn handle_checkpoint(&self) -> Message {
853 match self.db.checkpoint() {
854 Ok(_) => Message::ok(),
855 Err(e) => Message::error(&e.to_string()),
856 }
857 }
858
859 fn handle_stats(&self) -> Message {
860 let stats = self.stats.snapshot();
861 let stats_json = format!(
863 r#"{{"connections_total":{},"connections_active":{},"requests_total":{},"requests_success":{},"requests_error":{},"bytes_received":{},"bytes_sent":{},"uptime_secs":{},"memtable_size_bytes":0,"wal_size_bytes":0,"active_transactions":{}}}"#,
864 stats.connections_total,
865 stats.connections_active,
866 stats.requests_total,
867 stats.requests_success,
868 stats.requests_error,
869 stats.bytes_received,
870 stats.bytes_sent,
871 stats.uptime_secs,
872 self.active_txns.len()
873 );
874 Message::new(opcode::STATS_RESP, stats_json.into_bytes())
875 }
876
877 fn cleanup_transactions(&mut self) {
878 for (_client_id, txn) in self.active_txns.drain() {
880 let _ = self.db.abort(txn);
881 }
882 }
883}
884
/// Settings for an `IpcServer`.
#[derive(Debug, Clone)]
pub struct IpcServerConfig {
    /// Filesystem path of the Unix domain socket to listen on.
    pub socket_path: PathBuf,
    /// Maximum number of concurrently served connections. Connections beyond the limit
    /// are accepted and immediately dropped.
    pub max_connections: usize,
    /// NOTE(review): not read by the visible server code, which spawns one thread per
    /// connection. Confirm whether anything else uses it.
    pub thread_pool_size: usize,
    /// NOTE(review): not read by the visible server code; client handlers use a fixed
    /// 30 s read timeout. Confirm whether anything else uses it.
    pub connection_timeout_secs: u64,
}

impl Default for IpcServerConfig {
    /// `/tmp/sochdb.sock`, 100 connections, pool size 4, 300 s timeout.
    fn default() -> Self {
        let socket_path = PathBuf::from("/tmp/sochdb.sock");
        IpcServerConfig {
            socket_path,
            max_connections: 100,
            thread_pool_size: 4,
            connection_timeout_secs: 300,
        }
    }
}

impl IpcServerConfig {
    /// Returns this config with `socket_path` replaced by `path`.
    pub fn with_socket_path<P: AsRef<Path>>(self, path: P) -> Self {
        IpcServerConfig {
            socket_path: PathBuf::from(path.as_ref()),
            ..self
        }
    }

    /// Returns this config with `max_connections` replaced by `max`.
    pub fn with_max_connections(self, max: usize) -> Self {
        IpcServerConfig {
            max_connections: max,
            ..self
        }
    }
}
927
928pub struct IpcServer {
930 db: Arc<Database>,
931 config: IpcServerConfig,
932 stats: Arc<ServerStats>,
933 running: Arc<AtomicBool>,
934 listener_handle: Mutex<Option<JoinHandle<()>>>,
935}
936
937impl IpcServer {
938 pub fn new(db: Arc<Database>, config: IpcServerConfig) -> Self {
940 Self {
941 db,
942 config,
943 stats: Arc::new(ServerStats::new()),
944 running: Arc::new(AtomicBool::new(false)),
945 listener_handle: Mutex::new(None),
946 }
947 }
948
949 pub fn with_defaults(db: Arc<Database>) -> Self {
951 Self::new(db, IpcServerConfig::default())
952 }
953
954 pub fn run(&self) -> Result<()> {
956 if self.running.swap(true, Ordering::SeqCst) {
957 return Err(IpcError::AlreadyRunning);
958 }
959
960 if self.config.socket_path.exists() {
962 std::fs::remove_file(&self.config.socket_path)?;
963 }
964
965 let listener = UnixListener::bind(&self.config.socket_path)?;
967 listener.set_nonblocking(false)?;
968
969 *self.stats.start_time.lock() = Some(Instant::now());
971
972 eprintln!("[IpcServer] Listening on {:?}", self.config.socket_path);
973
974 while self.running.load(Ordering::SeqCst) {
976 match listener.accept() {
977 Ok((stream, _addr)) => {
978 let active = self.stats.connections_active.load(Ordering::Relaxed);
980 if active >= self.config.max_connections as u64 {
981 eprintln!("[IpcServer] Connection limit reached, rejecting");
982 continue;
983 }
984
985 self.stats.connections_total.fetch_add(1, Ordering::Relaxed);
986 self.stats
987 .connections_active
988 .fetch_add(1, Ordering::Relaxed);
989
990 let db = Arc::clone(&self.db);
991 let stats = Arc::clone(&self.stats);
992
993 thread::spawn(move || {
995 let mut handler = ClientHandler::new(db, stream, Arc::clone(&stats));
996 if let Err(e) = handler.handle() {
997 eprintln!("[IpcServer] Client error: {}", e);
998 }
999 stats.connections_active.fetch_sub(1, Ordering::Relaxed);
1000 });
1001 }
1002 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1003 thread::sleep(Duration::from_millis(100));
1005 }
1006 Err(e) => {
1007 eprintln!("[IpcServer] Accept error: {}", e);
1008 }
1009 }
1010 }
1011
1012 let _ = std::fs::remove_file(&self.config.socket_path);
1014
1015 Ok(())
1016 }
1017
1018 pub fn start(&self) -> Result<()> {
1020 if self.running.swap(true, Ordering::SeqCst) {
1021 return Err(IpcError::AlreadyRunning);
1022 }
1023
1024 let db = Arc::clone(&self.db);
1025 let config = self.config.clone();
1026 let stats = Arc::clone(&self.stats);
1027 let running = Arc::clone(&self.running);
1028
1029 let handle = thread::spawn(move || {
1030 if config.socket_path.exists() {
1033 let _ = std::fs::remove_file(&config.socket_path);
1034 }
1035
1036 let listener = match UnixListener::bind(&config.socket_path) {
1038 Ok(l) => l,
1039 Err(e) => {
1040 eprintln!("[IpcServer] Failed to bind: {}", e);
1041 running.store(false, Ordering::SeqCst);
1042 return;
1043 }
1044 };
1045 let _ = listener.set_nonblocking(false);
1046
1047 *stats.start_time.lock() = Some(Instant::now());
1049
1050 eprintln!("[IpcServer] Listening on {:?}", config.socket_path);
1051
1052 while running.load(Ordering::SeqCst) {
1054 match listener.accept() {
1055 Ok((stream, _addr)) => {
1056 let active = stats.connections_active.load(Ordering::Relaxed);
1058 if active >= config.max_connections as u64 {
1059 eprintln!("[IpcServer] Connection limit reached, rejecting");
1060 continue;
1061 }
1062
1063 stats.connections_total.fetch_add(1, Ordering::Relaxed);
1064 stats.connections_active.fetch_add(1, Ordering::Relaxed);
1065
1066 let db_clone = Arc::clone(&db);
1067 let stats_clone = Arc::clone(&stats);
1068
1069 thread::spawn(move || {
1071 let mut handler =
1072 ClientHandler::new(db_clone, stream, Arc::clone(&stats_clone));
1073 if let Err(e) = handler.handle() {
1074 eprintln!("[IpcServer] Client error: {}", e);
1075 }
1076 stats_clone
1077 .connections_active
1078 .fetch_sub(1, Ordering::Relaxed);
1079 });
1080 }
1081 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1082 thread::sleep(Duration::from_millis(100));
1084 }
1085 Err(e) => {
1086 eprintln!("[IpcServer] Accept error: {}", e);
1087 break;
1088 }
1089 }
1090 }
1091
1092 let _ = std::fs::remove_file(&config.socket_path);
1094 });
1095
1096 *self.listener_handle.lock() = Some(handle);
1097 Ok(())
1098 }
1099
1100 pub fn stop(&self) {
1102 self.running.store(false, Ordering::SeqCst);
1103
1104 let _ = UnixStream::connect(&self.config.socket_path);
1106
1107 if let Some(handle) = self.listener_handle.lock().take() {
1109 let _ = handle.join();
1110 }
1111 }
1112
1113 pub fn is_running(&self) -> bool {
1115 self.running.load(Ordering::SeqCst)
1116 }
1117
1118 pub fn stats(&self) -> ServerStatsSnapshot {
1120 self.stats.snapshot()
1121 }
1122
1123 pub fn socket_path(&self) -> &Path {
1125 &self.config.socket_path
1126 }
1127}
1128
1129impl Drop for IpcServer {
1130 fn drop(&mut self) {
1131 self.stop();
1132 }
1133}
1134
/// Blocking client for the IPC server. Each call sends one request and waits for its
/// response on the same connection.
pub struct IpcClient {
    /// Connection to the server's Unix socket.
    stream: UnixStream,
}
1143
1144impl IpcClient {
1145 pub fn connect<P: AsRef<Path>>(socket_path: P) -> Result<Self> {
1147 let stream = UnixStream::connect(socket_path)?;
1148 Ok(Self { stream })
1149 }
1150
1151 pub fn connect_with_timeout<P: AsRef<Path>>(socket_path: P, timeout: Duration) -> Result<Self> {
1153 let stream = UnixStream::connect(socket_path)?;
1154 stream.set_read_timeout(Some(timeout))?;
1155 stream.set_write_timeout(Some(timeout))?;
1156 Ok(Self { stream })
1157 }
1158
1159 fn request(&mut self, msg: Message) -> Result<Message> {
1161 msg.write_to(&mut self.stream)?;
1162 Message::read_from(&mut self.stream)
1163 }
1164
1165 pub fn ping(&mut self) -> Result<Duration> {
1167 let start = Instant::now();
1168 let resp = self.request(Message::new(opcode::PING, vec![]))?;
1169 if resp.opcode != opcode::PONG {
1170 return Err(IpcError::Protocol("Expected PONG".into()));
1171 }
1172 Ok(start.elapsed())
1173 }
1174
1175 pub fn put(&mut self, key: &[u8], value: &[u8]) -> Result<()> {
1177 let payload = encode_put(key, value);
1178 let resp = self.request(Message::new(opcode::PUT, payload))?;
1179 self.check_ok(resp)
1180 }
1181
1182 pub fn get(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
1184 let resp = self.request(Message::new(opcode::GET, key.to_vec()))?;
1185 match resp.opcode {
1186 opcode::VALUE if resp.payload.is_empty() => Ok(None),
1187 opcode::VALUE => Ok(Some(resp.payload)),
1188 opcode::ERROR => Err(IpcError::Database(
1189 String::from_utf8_lossy(&resp.payload).to_string(),
1190 )),
1191 _ => Err(IpcError::Protocol(format!(
1192 "Unexpected opcode: {:#x}",
1193 resp.opcode
1194 ))),
1195 }
1196 }
1197
1198 pub fn delete(&mut self, key: &[u8]) -> Result<()> {
1200 let resp = self.request(Message::new(opcode::DELETE, key.to_vec()))?;
1201 self.check_ok(resp)
1202 }
1203
1204 pub fn begin_txn(&mut self) -> Result<u64> {
1206 let resp = self.request(Message::new(opcode::BEGIN_TXN, vec![]))?;
1207 match resp.opcode {
1208 opcode::TXN_ID => {
1209 if resp.payload.len() >= 8 {
1210 Ok(u64::from_le_bytes(resp.payload[0..8].try_into().unwrap()))
1211 } else {
1212 Err(IpcError::Protocol("TXN_ID response too short".into()))
1213 }
1214 }
1215 opcode::ERROR => Err(IpcError::Database(
1216 String::from_utf8_lossy(&resp.payload).to_string(),
1217 )),
1218 _ => Err(IpcError::Protocol(format!(
1219 "Unexpected opcode: {:#x}",
1220 resp.opcode
1221 ))),
1222 }
1223 }
1224
1225 pub fn commit_txn(&mut self, txn_id: u64) -> Result<u64> {
1227 let resp = self.request(Message::new(
1228 opcode::COMMIT_TXN,
1229 txn_id.to_le_bytes().to_vec(),
1230 ))?;
1231 match resp.opcode {
1232 opcode::TXN_ID => {
1233 if resp.payload.len() >= 8 {
1234 Ok(u64::from_le_bytes(resp.payload[0..8].try_into().unwrap()))
1235 } else {
1236 Err(IpcError::Protocol("TXN_ID response too short".into()))
1237 }
1238 }
1239 opcode::ERROR => Err(IpcError::Database(
1240 String::from_utf8_lossy(&resp.payload).to_string(),
1241 )),
1242 _ => Err(IpcError::Protocol(format!(
1243 "Unexpected opcode: {:#x}",
1244 resp.opcode
1245 ))),
1246 }
1247 }
1248
1249 pub fn abort_txn(&mut self, txn_id: u64) -> Result<()> {
1251 let resp = self.request(Message::new(
1252 opcode::ABORT_TXN,
1253 txn_id.to_le_bytes().to_vec(),
1254 ))?;
1255 self.check_ok(resp)
1256 }
1257
1258 pub fn put_path(&mut self, path: &[&str], value: &[u8]) -> Result<()> {
1260 let payload = encode_put_path(path, value);
1261 let resp = self.request(Message::new(opcode::PUT_PATH, payload))?;
1262 self.check_ok(resp)
1263 }
1264
1265 pub fn get_path(&mut self, path: &[&str]) -> Result<Option<Vec<u8>>> {
1267 let payload = encode_put_path(path, &[]);
1268 let resp = self.request(Message::new(opcode::GET_PATH, payload))?;
1269 match resp.opcode {
1270 opcode::VALUE if resp.payload.is_empty() => Ok(None),
1271 opcode::VALUE => Ok(Some(resp.payload)),
1272 opcode::ERROR => Err(IpcError::Database(
1273 String::from_utf8_lossy(&resp.payload).to_string(),
1274 )),
1275 _ => Err(IpcError::Protocol(format!(
1276 "Unexpected opcode: {:#x}",
1277 resp.opcode
1278 ))),
1279 }
1280 }
1281
1282 pub fn checkpoint(&mut self) -> Result<()> {
1284 let resp = self.request(Message::new(opcode::CHECKPOINT, vec![]))?;
1285 self.check_ok(resp)
1286 }
1287
1288 pub fn stats(&mut self) -> Result<HashMap<String, u64>> {
1290 let resp = self.request(Message::new(opcode::STATS, vec![]))?;
1291 match resp.opcode {
1292 opcode::STATS_RESP => {
1293 let stats_str = String::from_utf8_lossy(&resp.payload);
1294
1295 let stats: HashMap<String, u64> = serde_json::from_str(&stats_str)
1297 .map_err(|e| IpcError::Protocol(format!("Failed to parse stats JSON: {}", e)))?;
1298
1299 Ok(stats)
1300 }
1301 opcode::ERROR => Err(IpcError::Database(
1302 String::from_utf8_lossy(&resp.payload).to_string(),
1303 )),
1304 _ => Err(IpcError::Protocol(format!(
1305 "Unexpected opcode: {:#x}",
1306 resp.opcode
1307 ))),
1308 }
1309 }
1310
1311 fn check_ok(&self, resp: Message) -> Result<()> {
1312 match resp.opcode {
1313 opcode::OK => Ok(()),
1314 opcode::ERROR => Err(IpcError::Database(
1315 String::from_utf8_lossy(&resp.payload).to_string(),
1316 )),
1317 _ => Err(IpcError::Protocol(format!(
1318 "Unexpected opcode: {:#x}",
1319 resp.opcode
1320 ))),
1321 }
1322 }
1323}
1324
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};
    use tempfile::TempDir;

    /// Opens a fresh database inside a new temporary directory.
    ///
    /// Returns the database, the `TempDir` guard, and a socket path inside that
    /// directory. Dropping the guard deletes the directory, so callers must
    /// keep it alive for the whole test.
    fn setup_test_server() -> (Arc<Database>, TempDir, PathBuf) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let socket_path = temp_dir.path().join("test.sock");

        let db = Database::open(&db_path).unwrap();
        (db, temp_dir, socket_path)
    }

    /// Connects to the server at `path`, retrying until it accepts or a
    /// deadline passes.
    ///
    /// The tests previously slept a fixed 100ms after `start()`. That is racy
    /// if the listener may not yet be bound (slow CI), and it wastes time when
    /// it is. Failed attempts never reach the server's `accept`, so they do
    /// not inflate its connection counters.
    fn connect_with_retry(path: &Path) -> IpcClient {
        const TIMEOUT: Duration = Duration::from_secs(5);
        const BACKOFF: Duration = Duration::from_millis(10);

        let deadline = Instant::now() + TIMEOUT;
        loop {
            match IpcClient::connect(path) {
                Ok(client) => return client,
                Err(_) if Instant::now() < deadline => thread::sleep(BACKOFF),
                Err(e) => panic!("could not connect to {}: {:?}", path.display(), e),
            }
        }
    }

    #[test]
    fn test_message_roundtrip() {
        let original = Message::new(0x01, b"hello world".to_vec());

        let mut buffer = Vec::new();
        original.write_to(&mut buffer).unwrap();

        let mut cursor = std::io::Cursor::new(buffer);
        let decoded = Message::read_from(&mut cursor).unwrap();

        assert_eq!(decoded.opcode, original.opcode);
        assert_eq!(decoded.payload, original.payload);
    }

    #[test]
    fn test_encode_decode_put() {
        let key = b"test-key";
        let value = b"test-value";

        let encoded = encode_put(key, value);
        let (decoded_key, decoded_value) = decode_put(&encoded).unwrap();

        assert_eq!(decoded_key, key);
        assert_eq!(decoded_value, value);
    }

    #[test]
    fn test_encode_decode_path() {
        let path = vec!["users", "alice", "settings"];
        let value = b"preferences";

        let encoded = encode_put_path(&path, value);
        let (decoded_path, decoded_value) = decode_path(&encoded).unwrap();

        let expected_path: Vec<String> = path.iter().map(|s| s.to_string()).collect();
        assert_eq!(decoded_path, expected_path);
        assert_eq!(decoded_value, value);
    }

    #[test]
    fn test_server_client_basic() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut client = connect_with_retry(&socket_path);

        let latency = client.ping().unwrap();
        assert!(latency < Duration::from_secs(1));

        client.put(b"key1", b"value1").unwrap();
        let value = client.get(b"key1").unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));

        let value = client.get(b"nonexistent").unwrap();
        assert_eq!(value, None);

        client.delete(b"key1").unwrap();
        let value = client.get(b"key1").unwrap();
        assert_eq!(value, None);

        server.stop();
    }

    #[test]
    fn test_server_client_transactions() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut client = connect_with_retry(&socket_path);

        let txn_id = client.begin_txn().unwrap();
        assert!(txn_id > 0);

        let commit_ts = client.commit_txn(txn_id).unwrap();
        assert!(commit_ts > 0);

        let txn_id2 = client.begin_txn().unwrap();
        client.abort_txn(txn_id2).unwrap();

        server.stop();
    }

    #[test]
    fn test_server_client_paths() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut client = connect_with_retry(&socket_path);

        client
            .put_path(&["users", "alice", "email"], b"alice@example.com")
            .unwrap();

        let value = client.get_path(&["users", "alice", "email"]).unwrap();
        assert_eq!(value, Some(b"alice@example.com".to_vec()));

        let value = client.get_path(&["users", "bob", "email"]).unwrap();
        assert_eq!(value, None);

        server.stop();
    }

    #[test]
    fn test_server_stats() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut client = connect_with_retry(&socket_path);

        client.ping().unwrap();
        client.put(b"k", b"v").unwrap();
        client.get(b"k").unwrap();

        // ping + put + get + stats = 4 requests. NOTE(review): the client-side
        // check assumes the server counts the STATS request before answering
        // it — confirm against the request handler.
        let stats = client.stats().unwrap();
        assert!(stats.contains_key("requests_total"));
        assert!(*stats.get("requests_total").unwrap() >= 4);

        let server_stats = server.stats();
        assert!(server_stats.requests_total >= 4);
        assert!(server_stats.connections_active >= 1);

        server.stop();
    }

    #[test]
    fn test_multiple_clients() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default()
            .with_socket_path(&socket_path)
            .with_max_connections(10);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut handles = Vec::new();
        for i in 0..5 {
            let path = socket_path.clone();
            let handle = thread::spawn(move || {
                let mut client = connect_with_retry(&path);
                let key = format!("key-{}", i);
                let value = format!("value-{}", i);

                client.put(key.as_bytes(), value.as_bytes()).unwrap();
                let result = client.get(key.as_bytes()).unwrap();
                assert_eq!(result, Some(value.into_bytes()));
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        // Exactly one successful connection per worker; failed retry attempts
        // are never accepted and therefore not counted.
        let stats = server.stats();
        assert_eq!(stats.connections_total, 5);

        server.stop();
    }
}