Skip to main content

sochdb_storage/
ipc_server.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Unix Domain Socket IPC Server for SochDB
19//!
20//! Provides a local IPC server that wraps the Database kernel for
21//! multi-process access to a SochDB database.
22//!
23//! # Architecture
24//!
25//! ```text
26//! ┌────────────────────────────────────────────────────────────────┐
27//! │                      IPC Server Process                        │
28//! │  ┌────────────────────────────────────────────────────────┐   │
29//! │  │              Database Kernel (Arc<Database>)           │   │
30//! │  └────────────────────────────────────────────────────────┘   │
31//! │           ▲                    ▲                    ▲         │
32//! │           │                    │                    │         │
33//! │  ┌────────┴────────┐ ┌────────┴────────┐ ┌────────┴────────┐ │
34//! │  │ ClientHandler 1 │ │ ClientHandler 2 │ │ ClientHandler N │ │
35//! │  └────────┬────────┘ └────────┬────────┘ └────────┬────────┘ │
36//! │           │                    │                    │         │
37//! │  ┌────────┴────────────────────┴────────────────────┴────────┐│
38//! │  │              Unix Domain Socket Listener                  ││
39//! │  │                  /tmp/sochdb-<id>.sock                    ││
40//! │  └───────────────────────────────────────────────────────────┘│
41//! └────────────────────────────────────────────────────────────────┘
42//!          ▲                    ▲                    ▲
43//!          │ Unix Socket        │ Unix Socket        │ Unix Socket
44//!   ┌──────┴──────┐      ┌──────┴──────┐      ┌──────┴──────┐
45//!   │  Client 1   │      │  Client 2   │      │  Client N   │
46//!   │  (Process)  │      │  (Process)  │      │  (Process)  │
47//!   └─────────────┘      └─────────────┘      └─────────────┘
48//! ```
49//!
50//! # Wire Protocol
51//!
52//! All messages use a simple length-prefixed binary format:
53//!
54//! ```text
55//! ┌──────────────────────────────────────────────────────────────┐
56//! │  OpCode (1 byte)  │  Length (4 bytes LE)  │  Payload (N)    │
57//! └──────────────────────────────────────────────────────────────┘
58//! ```
59//!
60//! ## OpCodes
61//!
62//! | Code | Name          | Direction | Description                    |
63//! |------|---------------|-----------|--------------------------------|
64//! | 0x01 | PUT           | C→S       | Put key-value pair             |
65//! | 0x02 | GET           | C→S       | Get value by key               |
66//! | 0x03 | DELETE        | C→S       | Delete key                     |
67//! | 0x04 | BEGIN_TXN     | C→S       | Start transaction              |
68//! | 0x05 | COMMIT_TXN    | C→S       | Commit transaction             |
69//! | 0x06 | ABORT_TXN     | C→S       | Abort transaction              |
70//! | 0x07 | QUERY         | C→S       | Execute query                  |
71//! | 0x08 | CREATE_TABLE  | C→S       | Create table                   |
72//! | 0x09 | PUT_PATH      | C→S       | Put hierarchical path          |
73//! | 0x0A | GET_PATH      | C→S       | Get by hierarchical path       |
74//! | 0x0B | SCAN          | C→S       | Scan key range                 |
75//! | 0x0C | CHECKPOINT    | C→S       | Force checkpoint               |
//! | 0x0D | STATS         | C→S       | Get database stats             |
//! | 0x0E | PING          | C→S       | Liveness probe                 |
//! | 0x0F | EXECUTE_SQL   | C→S       | SQL text (not run server-side) |
//! |------|---------------|-----------|--------------------------------|
//! | 0x80 | OK            | S→C       | Success response               |
//! | 0x81 | ERROR         | S→C       | Error response                 |
//! | 0x82 | VALUE         | S→C       | Value response                 |
//! | 0x83 | TXN_ID        | S→C       | Transaction ID response        |
//! | 0x84 | ROW           | S→C       | Query result row (streaming)   |
//! | 0x85 | END_STREAM    | S→C       | End of streaming results       |
//! | 0x86 | STATS_RESP    | S→C       | Stats response                 |
//! | 0x87 | PONG          | S→C       | Reply to PING                  |
85
86use crate::database::{Database, TxnHandle};
87use parking_lot::Mutex;
88use std::collections::HashMap;
89use std::io::{BufReader, BufWriter, Read, Write};
90use std::os::unix::fs::PermissionsExt;
91use std::os::unix::net::{UnixListener, UnixStream};
92use std::path::{Path, PathBuf};
93use std::sync::Arc;
94use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
95use std::thread::{self, JoinHandle};
96use std::time::{Duration, Instant};
97use thiserror::Error;
98
99// ============================================================================
100// Wire Protocol Constants
101// ============================================================================
102
/// Wire-protocol opcodes: the first byte of every frame.
///
/// Requests (client → server) use `0x01..=0x0F`; responses
/// (server → client) have the high bit set (`0x80..`).
mod opcode {
    // ---- Client → Server requests ----------------------------------------

    /// Store a key/value pair in its own auto-committed transaction.
    pub const PUT: u8 = 0x01;
    /// Read the value stored under a key.
    pub const GET: u8 = 0x02;
    /// Delete a key in its own auto-committed transaction.
    pub const DELETE: u8 = 0x03;
    /// Open an explicit transaction; answered with TXN_ID.
    pub const BEGIN_TXN: u8 = 0x04;
    /// Commit an explicit transaction by its client-side id.
    pub const COMMIT_TXN: u8 = 0x05;
    /// Abort an explicit transaction by its client-side id.
    pub const ABORT_TXN: u8 = 0x06;
    /// Run a read-only query; answered with its TOON rendering.
    pub const QUERY: u8 = 0x07;
    /// Register a table from a binary schema definition.
    pub const CREATE_TABLE: u8 = 0x08;
    /// Store a value under a hierarchical path.
    pub const PUT_PATH: u8 = 0x09;
    /// Read the value stored under a hierarchical path.
    pub const GET_PATH: u8 = 0x0A;
    /// List key/value pairs under a path prefix.
    pub const SCAN: u8 = 0x0B;
    /// Force a checkpoint.
    pub const CHECKPOINT: u8 = 0x0C;
    /// Request server/database statistics.
    pub const STATS: u8 = 0x0D;
    /// Liveness probe; answered with PONG.
    pub const PING: u8 = 0x0E;
    /// Submit SQL text. The server currently answers with a JSON notice
    /// rather than executing it.
    pub const EXECUTE_SQL: u8 = 0x0F;

    // ---- Server → Client responses ---------------------------------------

    /// Success with an empty payload.
    pub const OK: u8 = 0x80;
    /// Failure; the payload is a UTF-8 error message.
    pub const ERROR: u8 = 0x81;
    /// Raw value bytes.
    pub const VALUE: u8 = 0x82;
    /// A `u64` (little-endian) transaction id or commit result.
    pub const TXN_ID: u8 = 0x83;
    /// Reserved for streamed query rows. Not referenced by the server code,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub const ROW: u8 = 0x84;
    /// Reserved end-of-stream marker for streamed results. Not referenced by
    /// the server code, hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub const END_STREAM: u8 = 0x85;
    /// Statistics response.
    pub const STATS_RESP: u8 = 0x86;
    /// Reply to PING.
    pub const PONG: u8 = 0x87;
}
133
/// Upper bound on a frame's payload length (16 MiB). `Message::read_from`
/// rejects larger announced lengths before allocating the payload buffer.
const MAX_MESSAGE_SIZE: usize = 16 << 20;
136
137// ============================================================================
138// Error Types
139// ============================================================================
140
141#[derive(Debug, Error)]
142pub enum IpcError {
143    #[error("I/O error: {0}")]
144    Io(#[from] std::io::Error),
145
146    #[error("Database error: {0}")]
147    Database(String),
148
149    #[error("Protocol error: {0}")]
150    Protocol(String),
151
152    #[error("Server already running")]
153    AlreadyRunning,
154
155    #[error("Server not running")]
156    NotRunning,
157
158    #[error("Connection closed")]
159    ConnectionClosed,
160
161    #[error("Message too large: {0} bytes (max: {1})")]
162    MessageTooLarge(usize, usize),
163
164    #[error("Invalid opcode: {0:#x}")]
165    InvalidOpcode(u8),
166
167    #[error("Transaction not found: {0}")]
168    TxnNotFound(u64),
169}
170
171pub type Result<T> = std::result::Result<T, IpcError>;
172
173// ============================================================================
174// Wire Protocol Implementation
175// ============================================================================
176
177/// Message frame for the wire protocol
178#[derive(Debug, Clone)]
179pub struct Message {
180    pub opcode: u8,
181    pub payload: Vec<u8>,
182}
183
184impl Message {
185    pub fn new(opcode: u8, payload: Vec<u8>) -> Self {
186        Self { opcode, payload }
187    }
188
189    pub fn ok() -> Self {
190        Self::new(opcode::OK, vec![])
191    }
192
193    pub fn error(msg: &str) -> Self {
194        Self::new(opcode::ERROR, msg.as_bytes().to_vec())
195    }
196
197    pub fn value(data: Vec<u8>) -> Self {
198        Self::new(opcode::VALUE, data)
199    }
200
201    pub fn txn_id(id: u64) -> Self {
202        Self::new(opcode::TXN_ID, id.to_le_bytes().to_vec())
203    }
204
205    /// Read a message from a stream
206    pub fn read_from<R: Read>(reader: &mut R) -> Result<Self> {
207        // Read opcode (1 byte)
208        let mut opcode_buf = [0u8; 1];
209        match reader.read_exact(&mut opcode_buf) {
210            Ok(_) => {}
211            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
212                return Err(IpcError::ConnectionClosed);
213            }
214            Err(e) => return Err(e.into()),
215        }
216        let opcode = opcode_buf[0];
217
218        // Read length (4 bytes, little-endian)
219        let mut len_buf = [0u8; 4];
220        reader.read_exact(&mut len_buf)?;
221        let len = u32::from_le_bytes(len_buf) as usize;
222
223        // Validate length
224        if len > MAX_MESSAGE_SIZE {
225            return Err(IpcError::MessageTooLarge(len, MAX_MESSAGE_SIZE));
226        }
227
228        // Read payload
229        let mut payload = vec![0u8; len];
230        if len > 0 {
231            reader.read_exact(&mut payload)?;
232        }
233
234        Ok(Self { opcode, payload })
235    }
236
237    /// Write a message to a stream
238    pub fn write_to<W: Write>(&self, writer: &mut W) -> Result<()> {
239        // Write opcode
240        writer.write_all(&[self.opcode])?;
241
242        // Write length
243        let len = self.payload.len() as u32;
244        writer.write_all(&len.to_le_bytes())?;
245
246        // Write payload
247        if !self.payload.is_empty() {
248            writer.write_all(&self.payload)?;
249        }
250
251        writer.flush()?;
252        Ok(())
253    }
254}
255
256// ============================================================================
257// Request/Response Encoding
258// ============================================================================
259
/// Build a PUT request payload.
///
/// Layout: `key_len: u32 LE` ‖ `key` ‖ `value`. The value carries no length
/// prefix; it runs to the end of the payload (see `decode_put`).
/// NOTE(review): the key length is narrowed with `as u32`, so a key longer
/// than `u32::MAX` bytes would be encoded with a truncated length.
fn encode_put(key: &[u8], value: &[u8]) -> Vec<u8> {
    let key_len = (key.len() as u32).to_le_bytes();
    [&key_len[..], key, value].concat()
}
268
269/// Decode a PUT request payload
270fn decode_put(payload: &[u8]) -> Result<(&[u8], &[u8])> {
271    if payload.len() < 4 {
272        return Err(IpcError::Protocol("PUT payload too short".into()));
273    }
274    let key_len = u32::from_le_bytes(payload[0..4].try_into().unwrap()) as usize;
275    if payload.len() < 4 + key_len {
276        return Err(IpcError::Protocol("PUT payload key truncated".into()));
277    }
278    let key = &payload[4..4 + key_len];
279    let value = &payload[4 + key_len..];
280    Ok((key, value))
281}
282
/// Build a hierarchical-path request payload.
///
/// Layout: `segment_count: u16 LE`, then for each segment `len: u16 LE` ‖
/// UTF-8 bytes, then the raw value up to the end of the payload (see
/// `decode_path`).
///
/// NOTE(review): the count and lengths are narrowed with `as u16`. More than
/// 65 535 segments, or a segment longer than 65 535 bytes, would therefore be
/// encoded with a truncated length and produce a malformed payload.
fn encode_put_path(path: &[&str], value: &[u8]) -> Vec<u8> {
    let mut out: Vec<u8> = (path.len() as u16).to_le_bytes().to_vec();
    for seg in path.iter().map(|s| s.as_bytes()) {
        out.extend_from_slice(&(seg.len() as u16).to_le_bytes());
        out.extend_from_slice(seg);
    }
    out.extend_from_slice(value);
    out
}
295
296/// Decode a path request
297fn decode_path(payload: &[u8]) -> Result<(Vec<String>, &[u8])> {
298    if payload.len() < 2 {
299        return Err(IpcError::Protocol("Path payload too short".into()));
300    }
301    let count = u16::from_le_bytes(payload[0..2].try_into().unwrap()) as usize;
302    let mut offset = 2;
303    let mut path = Vec::with_capacity(count);
304
305    for _ in 0..count {
306        if offset + 2 > payload.len() {
307            return Err(IpcError::Protocol("Path segment length truncated".into()));
308        }
309        let seg_len = u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
310        offset += 2;
311        if offset + seg_len > payload.len() {
312            return Err(IpcError::Protocol("Path segment truncated".into()));
313        }
314        let segment = std::str::from_utf8(&payload[offset..offset + seg_len])
315            .map_err(|_| IpcError::Protocol("Invalid UTF-8 in path".into()))?;
316        path.push(segment.to_string());
317        offset += seg_len;
318    }
319
320    Ok((path, &payload[offset..]))
321}
322
323// ============================================================================
324// Server Statistics
325// ============================================================================
326
327#[derive(Debug, Default)]
328pub struct ServerStats {
329    pub connections_total: AtomicU64,
330    pub connections_active: AtomicU64,
331    pub requests_total: AtomicU64,
332    pub requests_success: AtomicU64,
333    pub requests_error: AtomicU64,
334    pub bytes_received: AtomicU64,
335    pub bytes_sent: AtomicU64,
336    pub start_time: Mutex<Option<Instant>>,
337}
338
339impl ServerStats {
340    pub fn new() -> Self {
341        Self::default()
342    }
343
344    pub fn snapshot(&self) -> ServerStatsSnapshot {
345        ServerStatsSnapshot {
346            connections_total: self.connections_total.load(Ordering::Relaxed),
347            connections_active: self.connections_active.load(Ordering::Relaxed),
348            requests_total: self.requests_total.load(Ordering::Relaxed),
349            requests_success: self.requests_success.load(Ordering::Relaxed),
350            requests_error: self.requests_error.load(Ordering::Relaxed),
351            bytes_received: self.bytes_received.load(Ordering::Relaxed),
352            bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
353            uptime_secs: self
354                .start_time
355                .lock()
356                .map(|t| t.elapsed().as_secs())
357                .unwrap_or(0),
358        }
359    }
360}
361
/// Plain-value copy of [`ServerStats`] taken at one point in time.
///
/// Each field mirrors the counter of the same name; `uptime_secs` is whole
/// seconds since the recorded start time, or 0 if none was recorded.
#[derive(Debug, Clone)]
pub struct ServerStatsSnapshot {
    /// Value of `ServerStats::connections_total`.
    pub connections_total: u64,
    /// Value of `ServerStats::connections_active`.
    pub connections_active: u64,
    /// Requests decoded so far.
    pub requests_total: u64,
    /// Requests answered with a non-ERROR frame.
    pub requests_success: u64,
    /// Requests answered with an ERROR frame.
    pub requests_error: u64,
    /// Request bytes, frame headers included.
    pub bytes_received: u64,
    /// Response bytes, frame headers included.
    pub bytes_sent: u64,
    /// Seconds since the server start time (0 if unset).
    pub uptime_secs: u64,
}
373
374// ============================================================================
375// Client Connection Handler
376// ============================================================================
377
378struct ClientHandler {
379    db: Arc<Database>,
380    stream: UnixStream,
381    stats: Arc<ServerStats>,
382    active_txns: HashMap<u64, TxnHandle>, // client_txn_id → TxnHandle
383    next_client_txn_id: u64,
384}
385
386impl ClientHandler {
387    fn new(db: Arc<Database>, stream: UnixStream, stats: Arc<ServerStats>) -> Self {
388        Self {
389            db,
390            stream,
391            stats,
392            active_txns: HashMap::new(),
393            next_client_txn_id: 1,
394        }
395    }
396
397    fn handle(&mut self) -> Result<()> {
398        // Set read timeout for graceful shutdown detection
399        self.stream
400            .set_read_timeout(Some(Duration::from_secs(30)))?;
401
402        let mut reader = BufReader::new(self.stream.try_clone()?);
403        let mut writer = BufWriter::new(self.stream.try_clone()?);
404
405        loop {
406            // Read request
407            let request = match Message::read_from(&mut reader) {
408                Ok(msg) => msg,
409                Err(IpcError::ConnectionClosed) => {
410                    // Clean shutdown - abort any pending transactions
411                    self.cleanup_transactions();
412                    return Ok(());
413                }
414                Err(e) => return Err(e),
415            };
416
417            self.stats.requests_total.fetch_add(1, Ordering::Relaxed);
418            self.stats
419                .bytes_received
420                .fetch_add((5 + request.payload.len()) as u64, Ordering::Relaxed);
421
422            // Process request
423            let response = self.process_request(&request);
424
425            // Track success/error
426            if response.opcode == opcode::ERROR {
427                self.stats.requests_error.fetch_add(1, Ordering::Relaxed);
428            } else {
429                self.stats.requests_success.fetch_add(1, Ordering::Relaxed);
430            }
431
432            // Send response
433            self.stats
434                .bytes_sent
435                .fetch_add((5 + response.payload.len()) as u64, Ordering::Relaxed);
436            response.write_to(&mut writer)?;
437        }
438    }
439
440    fn process_request(&mut self, request: &Message) -> Message {
441        match request.opcode {
442            opcode::PING => Message::new(opcode::PONG, vec![]),
443
444            opcode::PUT => self.handle_put(&request.payload),
445            opcode::GET => self.handle_get(&request.payload),
446            opcode::DELETE => self.handle_delete(&request.payload),
447
448            opcode::BEGIN_TXN => self.handle_begin_txn(),
449            opcode::COMMIT_TXN => self.handle_commit_txn(&request.payload),
450            opcode::ABORT_TXN => self.handle_abort_txn(&request.payload),
451
452            opcode::PUT_PATH => self.handle_put_path(&request.payload),
453            opcode::GET_PATH => self.handle_get_path(&request.payload),
454
455            opcode::QUERY => self.handle_query(&request.payload),
456            opcode::CREATE_TABLE => self.handle_create_table(&request.payload),
457            opcode::SCAN => self.handle_scan(&request.payload),
458            opcode::EXECUTE_SQL => self.handle_execute_sql(&request.payload),
459
460            opcode::CHECKPOINT => self.handle_checkpoint(),
461            opcode::STATS => self.handle_stats(),
462
463            _ => Message::error(&format!("Unknown opcode: {:#x}", request.opcode)),
464        }
465    }
466
467    fn handle_execute_sql(&self, payload: &[u8]) -> Message {
468        // Payload: SQL query string (UTF-8)
469        let sql = match std::str::from_utf8(payload) {
470            Ok(s) => s,
471            Err(_) => return Message::error("Invalid UTF-8 in SQL query"),
472        };
473
474        // For now, return error indicating SQL execution happens client-side
475        // The Go SDK will need to implement SQL-to-KV mapping like Python does
476        let result = serde_json::json!({
477            "error": "SQL execution must be implemented client-side. Use Python SDK for full SQL support.",
478            "sql": sql
479        });
480
481        match serde_json::to_vec(&result) {
482            Ok(json) => Message::value(json),
483            Err(e) => Message::error(&format!("Failed to serialize error: {}", e)),
484        }
485    }
486
487    fn handle_query(&self, payload: &[u8]) -> Message {
488        // Payload: path_len(2) + path + limit(4) + offset(4) + cols_count(2) + [col_len(2) + col]...
489        let mut offset = 0;
490
491        if payload.len() < 2 {
492            return Message::error("Query payload too short");
493        }
494
495        // Path
496        let path_len = u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
497        offset += 2;
498        if offset + path_len > payload.len() {
499            return Message::error("Query path truncated");
500        }
501        let path = match std::str::from_utf8(&payload[offset..offset + path_len]) {
502            Ok(s) => s,
503            Err(_) => return Message::error("Invalid UTF-8 in query path"),
504        };
505        offset += path_len;
506
507        // Limit & Offset
508        if offset + 8 > payload.len() {
509            return Message::error("Query limit/offset truncated");
510        }
511        let limit_val =
512            u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
513        offset += 4;
514        let offset_val =
515            u32::from_le_bytes(payload[offset..offset + 4].try_into().unwrap()) as usize;
516        offset += 4;
517
518        // Columns
519        if offset + 2 > payload.len() {
520            return Message::error("Query columns count truncated");
521        }
522        let cols_count =
523            u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
524        offset += 2;
525
526        let mut columns = Vec::with_capacity(cols_count);
527        for _ in 0..cols_count {
528            if offset + 2 > payload.len() {
529                return Message::error("Query column length truncated");
530            }
531            let col_len =
532                u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
533            offset += 2;
534
535            if offset + col_len > payload.len() {
536                return Message::error("Query column truncated");
537            }
538            let col = match std::str::from_utf8(&payload[offset..offset + col_len]) {
539                Ok(s) => s.to_string(),
540                Err(_) => return Message::error("Invalid UTF-8 in query column"),
541            };
542            columns.push(col);
543            offset += col_len;
544        }
545
546        // Execute query
547        // Note: Query is read-only, so we can use a read transaction
548        let txn = match self.db.begin_transaction() {
549            Ok(t) => t,
550            Err(e) => return Message::error(&e.to_string()),
551        };
552
553        let mut builder = self.db.query(txn, path);
554
555        if limit_val > 0 {
556            builder = builder.limit(limit_val);
557        }
558        if offset_val > 0 {
559            builder = builder.offset(offset_val);
560        }
561
562        if !columns.is_empty() {
563            let cols_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
564            builder = builder.columns(&cols_refs);
565        }
566
567        let result = builder.to_toon();
568        let _ = self.db.abort(txn); // Read-only
569
570        match result {
571            Ok(soch_str) => Message::new(opcode::VALUE, soch_str.into_bytes()),
572            Err(e) => Message::error(&e.to_string()),
573        }
574    }
575
576    fn handle_scan(&self, payload: &[u8]) -> Message {
577        let prefix = match std::str::from_utf8(payload) {
578            Ok(s) => s,
579            Err(_) => return Message::error("Invalid UTF-8 in scan prefix"),
580        };
581
582        let txn = match self.db.begin_transaction() {
583            Ok(t) => t,
584            Err(e) => return Message::error(&e.to_string()),
585        };
586
587        let result = self.db.scan_path(txn, prefix);
588        let _ = self.db.abort(txn);
589
590        match result {
591            Ok(items) => {
592                // Format as simple newline-separated key=value for now
593                // Or maybe JSON? Let's use a simple custom format:
594                // count(4) + [key_len(2) + key + val_len(4) + val]...
595                let mut buf = Vec::new();
596                buf.extend_from_slice(&(items.len() as u32).to_le_bytes());
597
598                for (key, val) in items {
599                    let key_bytes = key.as_bytes();
600                    buf.extend_from_slice(&(key_bytes.len() as u16).to_le_bytes());
601                    buf.extend_from_slice(key_bytes);
602                    buf.extend_from_slice(&(val.len() as u32).to_le_bytes());
603                    buf.extend_from_slice(&val);
604                }
605
606                Message::new(opcode::VALUE, buf)
607            }
608            Err(e) => Message::error(&e.to_string()),
609        }
610    }
611
612    fn handle_create_table(&self, payload: &[u8]) -> Message {
613        // Payload: JSON schema definition
614        let _schema_json = match std::str::from_utf8(payload) {
615            Ok(s) => s,
616            Err(_) => return Message::error("Invalid UTF-8 in schema"),
617        };
618
619        // We need a way to parse TableSchema from JSON.
620        // Since we don't have serde_json derived for TableSchema in database.rs (it's in sochdb-storage, but TableSchema is in database.rs),
621        // we might need to manually parse or assume it's passed as a specific format.
622        // Let's assume for now we can use serde_json if we add the dependency or if it's already there.
623        // Checking Cargo.toml... serde_json is there.
624        // But TableSchema struct in database.rs doesn't derive Deserialize.
625        // I'll need to define a local struct or use a helper.
626
627        // For now, let's implement a simple manual parser or just error out saying "Not implemented fully"
628        // but the plan said "Parse payload: Schema definition".
629        // Let's try to use a simple custom binary format for schema to avoid JSON dependency issues if structs aren't serializable.
630        // Format: name_len(2) + name + col_count(2) + [col_name_len(2) + col_name + type(1) + nullable(1)]...
631
632        let mut offset = 0;
633        if payload.len() < 2 {
634            return Message::error("Schema payload too short");
635        }
636
637        let name_len = u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
638        offset += 2;
639        if offset + name_len > payload.len() {
640            return Message::error("Schema name truncated");
641        }
642        let name = match std::str::from_utf8(&payload[offset..offset + name_len]) {
643            Ok(s) => s.to_string(),
644            Err(_) => return Message::error("Invalid UTF-8 in schema name"),
645        };
646        offset += name_len;
647
648        if offset + 2 > payload.len() {
649            return Message::error("Schema column count truncated");
650        }
651        let col_count =
652            u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
653        offset += 2;
654
655        let mut columns = Vec::with_capacity(col_count);
656        for _ in 0..col_count {
657            if offset + 2 > payload.len() {
658                return Message::error("Column name length truncated");
659            }
660            let col_name_len =
661                u16::from_le_bytes(payload[offset..offset + 2].try_into().unwrap()) as usize;
662            offset += 2;
663
664            if offset + col_name_len > payload.len() {
665                return Message::error("Column name truncated");
666            }
667            let col_name = match std::str::from_utf8(&payload[offset..offset + col_name_len]) {
668                Ok(s) => s.to_string(),
669                Err(_) => return Message::error("Invalid UTF-8 in column name"),
670            };
671            offset += col_name_len;
672
673            if offset + 2 > payload.len() {
674                return Message::error("Column type/nullable truncated");
675            }
676            let type_byte = payload[offset];
677            offset += 1;
678            let nullable_byte = payload[offset];
679            offset += 1;
680
681            let col_type = match type_byte {
682                0 => crate::database::ColumnType::Int64,
683                1 => crate::database::ColumnType::UInt64,
684                2 => crate::database::ColumnType::Float64,
685                3 => crate::database::ColumnType::Text,
686                4 => crate::database::ColumnType::Binary,
687                5 => crate::database::ColumnType::Bool,
688                _ => return Message::error("Invalid column type"),
689            };
690
691            columns.push(crate::database::ColumnDef {
692                name: col_name,
693                col_type,
694                nullable: nullable_byte != 0,
695            });
696        }
697
698        let schema = crate::database::TableSchema { name, columns };
699
700        match self.db.register_table(schema) {
701            Ok(_) => Message::ok(),
702            Err(e) => Message::error(&e.to_string()),
703        }
704    }
705
706    /// Auto-commit PUT - creates a transaction, writes, commits
707    fn handle_put(&self, payload: &[u8]) -> Message {
708        match decode_put(payload) {
709            Ok((key, value)) => {
710                // Auto-transaction for simple PUT
711                let txn = match self.db.begin_transaction() {
712                    Ok(t) => t,
713                    Err(e) => return Message::error(&e.to_string()),
714                };
715
716                if let Err(e) = self.db.put(txn, key, value) {
717                    let _ = self.db.abort(txn);
718                    return Message::error(&e.to_string());
719                }
720
721                match self.db.commit(txn) {
722                    Ok(_) => Message::ok(),
723                    Err(e) => Message::error(&e.to_string()),
724                }
725            }
726            Err(e) => Message::error(&e.to_string()),
727        }
728    }
729
730    /// Auto-commit GET - creates a read transaction
731    fn handle_get(&self, payload: &[u8]) -> Message {
732        // Auto-transaction for simple GET
733        let txn = match self.db.begin_transaction() {
734            Ok(t) => t,
735            Err(e) => return Message::error(&e.to_string()),
736        };
737
738        let result = self.db.get(txn, payload);
739        let _ = self.db.abort(txn); // Abort is fine for read-only
740
741        match result {
742            Ok(Some(value)) => Message::value(value),
743            Ok(None) => Message::new(opcode::VALUE, vec![]),
744            Err(e) => Message::error(&e.to_string()),
745        }
746    }
747
748    /// Auto-commit DELETE
749    fn handle_delete(&self, payload: &[u8]) -> Message {
750        let txn = match self.db.begin_transaction() {
751            Ok(t) => t,
752            Err(e) => return Message::error(&e.to_string()),
753        };
754
755        if let Err(e) = self.db.delete(txn, payload) {
756            let _ = self.db.abort(txn);
757            return Message::error(&e.to_string());
758        }
759
760        match self.db.commit(txn) {
761            Ok(_) => Message::ok(),
762            Err(e) => Message::error(&e.to_string()),
763        }
764    }
765
766    fn handle_begin_txn(&mut self) -> Message {
767        match self.db.begin_transaction() {
768            Ok(txn) => {
769                let client_txn_id = self.next_client_txn_id;
770                self.next_client_txn_id += 1;
771                self.active_txns.insert(client_txn_id, txn);
772                Message::txn_id(client_txn_id)
773            }
774            Err(e) => Message::error(&e.to_string()),
775        }
776    }
777
778    fn handle_commit_txn(&mut self, payload: &[u8]) -> Message {
779        if payload.len() < 8 {
780            return Message::error("COMMIT_TXN requires txn_id");
781        }
782        let client_txn_id = u64::from_le_bytes(payload[0..8].try_into().unwrap());
783
784        match self.active_txns.remove(&client_txn_id) {
785            Some(txn) => match self.db.commit(txn) {
786                Ok(commit_ts) => Message::txn_id(commit_ts),
787                Err(e) => Message::error(&e.to_string()),
788            },
789            None => Message::error(&format!("Transaction not found: {}", client_txn_id)),
790        }
791    }
792
793    fn handle_abort_txn(&mut self, payload: &[u8]) -> Message {
794        if payload.len() < 8 {
795            return Message::error("ABORT_TXN requires txn_id");
796        }
797        let client_txn_id = u64::from_le_bytes(payload[0..8].try_into().unwrap());
798
799        match self.active_txns.remove(&client_txn_id) {
800            Some(txn) => match self.db.abort(txn) {
801                Ok(_) => Message::ok(),
802                Err(e) => Message::error(&e.to_string()),
803            },
804            None => Message::error(&format!("Transaction not found: {}", client_txn_id)),
805        }
806    }
807
808    fn handle_put_path(&self, payload: &[u8]) -> Message {
809        match decode_path(payload) {
810            Ok((path, value)) => {
811                let txn = match self.db.begin_transaction() {
812                    Ok(t) => t,
813                    Err(e) => return Message::error(&e.to_string()),
814                };
815
816                let path_str = path.join("/");
817                if let Err(e) = self.db.put_path(txn, &path_str, value) {
818                    let _ = self.db.abort(txn);
819                    return Message::error(&e.to_string());
820                }
821
822                match self.db.commit(txn) {
823                    Ok(_) => Message::ok(),
824                    Err(e) => Message::error(&e.to_string()),
825                }
826            }
827            Err(e) => Message::error(&e.to_string()),
828        }
829    }
830
831    fn handle_get_path(&self, payload: &[u8]) -> Message {
832        match decode_path(payload) {
833            Ok((path, _)) => {
834                let txn = match self.db.begin_transaction() {
835                    Ok(t) => t,
836                    Err(e) => return Message::error(&e.to_string()),
837                };
838
839                let path_str = path.join("/");
840                let result = self.db.get_path(txn, &path_str);
841                let _ = self.db.abort(txn);
842
843                match result {
844                    Ok(Some(value)) => Message::value(value),
845                    Ok(None) => Message::new(opcode::VALUE, vec![]),
846                    Err(e) => Message::error(&e.to_string()),
847                }
848            }
849            Err(e) => Message::error(&e.to_string()),
850        }
851    }
852
853    fn handle_checkpoint(&self) -> Message {
854        match self.db.checkpoint() {
855            Ok(_) => Message::ok(),
856            Err(e) => Message::error(&e.to_string()),
857        }
858    }
859
860    fn handle_stats(&self) -> Message {
861        let stats = self.stats.snapshot();
862        // Encode stats as JSON for SDK compatibility
863        let stats_json = format!(
864            r#"{{"connections_total":{},"connections_active":{},"requests_total":{},"requests_success":{},"requests_error":{},"bytes_received":{},"bytes_sent":{},"uptime_secs":{},"memtable_size_bytes":0,"wal_size_bytes":0,"active_transactions":{}}}"#,
865            stats.connections_total,
866            stats.connections_active,
867            stats.requests_total,
868            stats.requests_success,
869            stats.requests_error,
870            stats.bytes_received,
871            stats.bytes_sent,
872            stats.uptime_secs,
873            self.active_txns.len()
874        );
875        Message::new(opcode::STATS_RESP, stats_json.into_bytes())
876    }
877
878    fn cleanup_transactions(&mut self) {
879        // Abort all pending transactions for this client
880        for (_client_id, txn) in self.active_txns.drain() {
881            let _ = self.db.abort(txn);
882        }
883    }
884}
885
886// ============================================================================
887// IPC Server
888// ============================================================================
889
/// Configuration for the IPC server
///
/// Usually built from [`IpcServerConfig::default`] and adjusted with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct IpcServerConfig {
    /// Path to the Unix socket file
    pub socket_path: PathBuf,

    /// Maximum number of concurrent connections
    ///
    /// A connection accepted while this many are already active is dropped
    /// (closed) immediately.
    pub max_connections: usize,

    /// Thread pool size for handling connections
    ///
    /// NOTE(review): the `IpcServer` accept loop spawns one thread per
    /// connection and does not read this field; confirm whether anything else
    /// does before relying on it.
    pub thread_pool_size: usize,

    /// Connection timeout in seconds
    ///
    /// NOTE(review): `IpcServer` neither passes this to `ClientHandler` nor
    /// sets socket timeouts from it; verify where, if anywhere, it is enforced.
    pub connection_timeout_secs: u64,
}
905
906impl Default for IpcServerConfig {
907    fn default() -> Self {
908        Self {
909            socket_path: PathBuf::from("/tmp/sochdb.sock"),
910            max_connections: 100,
911            thread_pool_size: 4,
912            connection_timeout_secs: 300, // 5 minutes
913        }
914    }
915}
916
/// Tighten a newly bound Unix socket file to mode 0600 (owner read/write only).
///
/// Connecting to a Unix-domain socket is subject to the socket file's
/// permissions, so clearing the group/other bits keeps other local users away
/// from this unauthenticated endpoint. This is best-effort: a failure is
/// reported on stderr and otherwise ignored, because the bind has already
/// succeeded by the time callers get here.
///
/// Callers apply this after `bind`, so for a brief window the socket carries
/// whatever mode the process umask produced.
fn secure_socket_path(path: &Path) {
    let owner_only = std::fs::Permissions::from_mode(0o600);
    std::fs::set_permissions(path, owner_only).unwrap_or_else(|err| {
        eprintln!(
            "[IpcServer] WARNING: failed to restrict socket permissions on {:?}: {}",
            path, err
        )
    });
}
931
932impl IpcServerConfig {
933    pub fn with_socket_path<P: AsRef<Path>>(mut self, path: P) -> Self {
934        self.socket_path = path.as_ref().to_path_buf();
935        self
936    }
937
938    pub fn with_max_connections(mut self, max: usize) -> Self {
939        self.max_connections = max;
940        self
941    }
942}
943
/// Unix Domain Socket IPC Server
///
/// `run()` serves on the calling thread, `start()` serves on a background
/// thread, and `stop()` (also called from `Drop`) shuts the listener down.
/// Each accepted connection is served by a `ClientHandler` on its own thread.
pub struct IpcServer {
    /// Database kernel shared with every client handler.
    db: Arc<Database>,
    /// Socket path and connection limits.
    config: IpcServerConfig,
    /// Connection/request counters, shared with the client handler threads.
    stats: Arc<ServerStats>,
    /// True while the accept loop should keep going; cleared by `stop()`.
    running: Arc<AtomicBool>,
    /// Join handle of the background listener thread spawned by `start()`.
    listener_handle: Mutex<Option<JoinHandle<()>>>,
}
952
953impl IpcServer {
954    /// Create a new IPC server for the given database
955    pub fn new(db: Arc<Database>, config: IpcServerConfig) -> Self {
956        Self {
957            db,
958            config,
959            stats: Arc::new(ServerStats::new()),
960            running: Arc::new(AtomicBool::new(false)),
961            listener_handle: Mutex::new(None),
962        }
963    }
964
965    /// Create with default configuration
966    pub fn with_defaults(db: Arc<Database>) -> Self {
967        Self::new(db, IpcServerConfig::default())
968    }
969
970    /// Start the server (blocking)
971    pub fn run(&self) -> Result<()> {
972        if self.running.swap(true, Ordering::SeqCst) {
973            return Err(IpcError::AlreadyRunning);
974        }
975
976        // Remove existing socket file if present
977        if self.config.socket_path.exists() {
978            std::fs::remove_file(&self.config.socket_path)?;
979        }
980
981        // Create listener
982        let listener = UnixListener::bind(&self.config.socket_path)?;
983        listener.set_nonblocking(false)?;
984
985        // Restrict socket to owner-only access (defense in depth for local IPC)
986        secure_socket_path(&self.config.socket_path);
987
988        // Record start time
989        *self.stats.start_time.lock() = Some(Instant::now());
990
991        eprintln!("[IpcServer] Listening on {:?}", self.config.socket_path);
992
993        // Accept connections
994        while self.running.load(Ordering::SeqCst) {
995            match listener.accept() {
996                Ok((stream, _addr)) => {
997                    // Check connection limit
998                    let active = self.stats.connections_active.load(Ordering::Relaxed);
999                    if active >= self.config.max_connections as u64 {
1000                        eprintln!("[IpcServer] Connection limit reached, rejecting");
1001                        continue;
1002                    }
1003
1004                    self.stats.connections_total.fetch_add(1, Ordering::Relaxed);
1005                    self.stats
1006                        .connections_active
1007                        .fetch_add(1, Ordering::Relaxed);
1008
1009                    let db = Arc::clone(&self.db);
1010                    let stats = Arc::clone(&self.stats);
1011
1012                    // Spawn handler thread
1013                    thread::spawn(move || {
1014                        let mut handler = ClientHandler::new(db, stream, Arc::clone(&stats));
1015                        if let Err(e) = handler.handle() {
1016                            eprintln!("[IpcServer] Client error: {}", e);
1017                        }
1018                        stats.connections_active.fetch_sub(1, Ordering::Relaxed);
1019                    });
1020                }
1021                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1022                    // Non-blocking timeout, check if we should stop
1023                    thread::sleep(Duration::from_millis(100));
1024                }
1025                Err(e) => {
1026                    eprintln!("[IpcServer] Accept error: {}", e);
1027                }
1028            }
1029        }
1030
1031        // Cleanup
1032        let _ = std::fs::remove_file(&self.config.socket_path);
1033
1034        Ok(())
1035    }
1036
1037    /// Start the server in a background thread
1038    pub fn start(&self) -> Result<()> {
1039        if self.running.swap(true, Ordering::SeqCst) {
1040            return Err(IpcError::AlreadyRunning);
1041        }
1042
1043        let db = Arc::clone(&self.db);
1044        let config = self.config.clone();
1045        let stats = Arc::clone(&self.stats);
1046        let running = Arc::clone(&self.running);
1047
1048        let handle = thread::spawn(move || {
1049            // Run the server loop directly (flag already set by start())
1050            // Remove existing socket file if present
1051            if config.socket_path.exists() {
1052                let _ = std::fs::remove_file(&config.socket_path);
1053            }
1054
1055            // Create listener
1056            let listener = match UnixListener::bind(&config.socket_path) {
1057                Ok(l) => l,
1058                Err(e) => {
1059                    eprintln!("[IpcServer] Failed to bind: {}", e);
1060                    running.store(false, Ordering::SeqCst);
1061                    return;
1062                }
1063            };
1064            let _ = listener.set_nonblocking(false);
1065
1066            // Restrict socket to owner-only access (defense in depth for local IPC)
1067            secure_socket_path(&config.socket_path);
1068
1069            // Record start time
1070            *stats.start_time.lock() = Some(Instant::now());
1071
1072            eprintln!("[IpcServer] Listening on {:?}", config.socket_path);
1073
1074            // Accept connections
1075            while running.load(Ordering::SeqCst) {
1076                match listener.accept() {
1077                    Ok((stream, _addr)) => {
1078                        // Check connection limit
1079                        let active = stats.connections_active.load(Ordering::Relaxed);
1080                        if active >= config.max_connections as u64 {
1081                            eprintln!("[IpcServer] Connection limit reached, rejecting");
1082                            continue;
1083                        }
1084
1085                        stats.connections_total.fetch_add(1, Ordering::Relaxed);
1086                        stats.connections_active.fetch_add(1, Ordering::Relaxed);
1087
1088                        let db_clone = Arc::clone(&db);
1089                        let stats_clone = Arc::clone(&stats);
1090
1091                        // Spawn handler thread
1092                        thread::spawn(move || {
1093                            let mut handler =
1094                                ClientHandler::new(db_clone, stream, Arc::clone(&stats_clone));
1095                            if let Err(e) = handler.handle() {
1096                                eprintln!("[IpcServer] Client error: {}", e);
1097                            }
1098                            stats_clone
1099                                .connections_active
1100                                .fetch_sub(1, Ordering::Relaxed);
1101                        });
1102                    }
1103                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
1104                        // Non-blocking timeout, check if we should stop
1105                        thread::sleep(Duration::from_millis(100));
1106                    }
1107                    Err(e) => {
1108                        eprintln!("[IpcServer] Accept error: {}", e);
1109                        break;
1110                    }
1111                }
1112            }
1113
1114            // Cleanup
1115            let _ = std::fs::remove_file(&config.socket_path);
1116        });
1117
1118        *self.listener_handle.lock() = Some(handle);
1119        Ok(())
1120    }
1121
1122    /// Stop the server
1123    pub fn stop(&self) {
1124        self.running.store(false, Ordering::SeqCst);
1125
1126        // Connect to socket to wake up accept() if blocking
1127        let _ = UnixStream::connect(&self.config.socket_path);
1128
1129        // Wait for listener thread
1130        if let Some(handle) = self.listener_handle.lock().take() {
1131            let _ = handle.join();
1132        }
1133    }
1134
1135    /// Check if server is running
1136    pub fn is_running(&self) -> bool {
1137        self.running.load(Ordering::SeqCst)
1138    }
1139
1140    /// Get server statistics
1141    pub fn stats(&self) -> ServerStatsSnapshot {
1142        self.stats.snapshot()
1143    }
1144
1145    /// Get socket path
1146    pub fn socket_path(&self) -> &Path {
1147        &self.config.socket_path
1148    }
1149}
1150
1151impl Drop for IpcServer {
1152    fn drop(&mut self) {
1153        self.stop();
1154    }
1155}
1156
1157// ============================================================================
1158// IPC Client (for connecting to server from another process)
1159// ============================================================================
1160
/// Client for connecting to an IPC server
///
/// Wraps a single connected stream; each request is written and its response
/// read before the next request can be issued.
pub struct IpcClient {
    /// Connected socket to the server, used for every request/response pair.
    stream: UnixStream,
}
1165
1166impl IpcClient {
1167    /// Connect to an IPC server
1168    pub fn connect<P: AsRef<Path>>(socket_path: P) -> Result<Self> {
1169        let stream = UnixStream::connect(socket_path)?;
1170        Ok(Self { stream })
1171    }
1172
1173    /// Connect with timeout
1174    pub fn connect_with_timeout<P: AsRef<Path>>(socket_path: P, timeout: Duration) -> Result<Self> {
1175        let stream = UnixStream::connect(socket_path)?;
1176        stream.set_read_timeout(Some(timeout))?;
1177        stream.set_write_timeout(Some(timeout))?;
1178        Ok(Self { stream })
1179    }
1180
1181    /// Send a request and receive response
1182    fn request(&mut self, msg: Message) -> Result<Message> {
1183        msg.write_to(&mut self.stream)?;
1184        Message::read_from(&mut self.stream)
1185    }
1186
1187    /// Ping the server
1188    pub fn ping(&mut self) -> Result<Duration> {
1189        let start = Instant::now();
1190        let resp = self.request(Message::new(opcode::PING, vec![]))?;
1191        if resp.opcode != opcode::PONG {
1192            return Err(IpcError::Protocol("Expected PONG".into()));
1193        }
1194        Ok(start.elapsed())
1195    }
1196
1197    /// Put a key-value pair
1198    pub fn put(&mut self, key: &[u8], value: &[u8]) -> Result<()> {
1199        let payload = encode_put(key, value);
1200        let resp = self.request(Message::new(opcode::PUT, payload))?;
1201        self.check_ok(resp)
1202    }
1203
1204    /// Get a value by key
1205    pub fn get(&mut self, key: &[u8]) -> Result<Option<Vec<u8>>> {
1206        let resp = self.request(Message::new(opcode::GET, key.to_vec()))?;
1207        match resp.opcode {
1208            opcode::VALUE if resp.payload.is_empty() => Ok(None),
1209            opcode::VALUE => Ok(Some(resp.payload)),
1210            opcode::ERROR => Err(IpcError::Database(
1211                String::from_utf8_lossy(&resp.payload).to_string(),
1212            )),
1213            _ => Err(IpcError::Protocol(format!(
1214                "Unexpected opcode: {:#x}",
1215                resp.opcode
1216            ))),
1217        }
1218    }
1219
1220    /// Delete a key
1221    pub fn delete(&mut self, key: &[u8]) -> Result<()> {
1222        let resp = self.request(Message::new(opcode::DELETE, key.to_vec()))?;
1223        self.check_ok(resp)
1224    }
1225
1226    /// Begin a transaction, returns transaction ID
1227    pub fn begin_txn(&mut self) -> Result<u64> {
1228        let resp = self.request(Message::new(opcode::BEGIN_TXN, vec![]))?;
1229        match resp.opcode {
1230            opcode::TXN_ID => {
1231                if resp.payload.len() >= 8 {
1232                    Ok(u64::from_le_bytes(resp.payload[0..8].try_into().unwrap()))
1233                } else {
1234                    Err(IpcError::Protocol("TXN_ID response too short".into()))
1235                }
1236            }
1237            opcode::ERROR => Err(IpcError::Database(
1238                String::from_utf8_lossy(&resp.payload).to_string(),
1239            )),
1240            _ => Err(IpcError::Protocol(format!(
1241                "Unexpected opcode: {:#x}",
1242                resp.opcode
1243            ))),
1244        }
1245    }
1246
1247    /// Commit a transaction, returns commit timestamp
1248    pub fn commit_txn(&mut self, txn_id: u64) -> Result<u64> {
1249        let resp = self.request(Message::new(
1250            opcode::COMMIT_TXN,
1251            txn_id.to_le_bytes().to_vec(),
1252        ))?;
1253        match resp.opcode {
1254            opcode::TXN_ID => {
1255                if resp.payload.len() >= 8 {
1256                    Ok(u64::from_le_bytes(resp.payload[0..8].try_into().unwrap()))
1257                } else {
1258                    Err(IpcError::Protocol("TXN_ID response too short".into()))
1259                }
1260            }
1261            opcode::ERROR => Err(IpcError::Database(
1262                String::from_utf8_lossy(&resp.payload).to_string(),
1263            )),
1264            _ => Err(IpcError::Protocol(format!(
1265                "Unexpected opcode: {:#x}",
1266                resp.opcode
1267            ))),
1268        }
1269    }
1270
1271    /// Abort a transaction
1272    pub fn abort_txn(&mut self, txn_id: u64) -> Result<()> {
1273        let resp = self.request(Message::new(
1274            opcode::ABORT_TXN,
1275            txn_id.to_le_bytes().to_vec(),
1276        ))?;
1277        self.check_ok(resp)
1278    }
1279
1280    /// Put by hierarchical path
1281    pub fn put_path(&mut self, path: &[&str], value: &[u8]) -> Result<()> {
1282        let payload = encode_put_path(path, value);
1283        let resp = self.request(Message::new(opcode::PUT_PATH, payload))?;
1284        self.check_ok(resp)
1285    }
1286
1287    /// Get by hierarchical path
1288    pub fn get_path(&mut self, path: &[&str]) -> Result<Option<Vec<u8>>> {
1289        let payload = encode_put_path(path, &[]);
1290        let resp = self.request(Message::new(opcode::GET_PATH, payload))?;
1291        match resp.opcode {
1292            opcode::VALUE if resp.payload.is_empty() => Ok(None),
1293            opcode::VALUE => Ok(Some(resp.payload)),
1294            opcode::ERROR => Err(IpcError::Database(
1295                String::from_utf8_lossy(&resp.payload).to_string(),
1296            )),
1297            _ => Err(IpcError::Protocol(format!(
1298                "Unexpected opcode: {:#x}",
1299                resp.opcode
1300            ))),
1301        }
1302    }
1303
1304    /// Force a checkpoint
1305    pub fn checkpoint(&mut self) -> Result<()> {
1306        let resp = self.request(Message::new(opcode::CHECKPOINT, vec![]))?;
1307        self.check_ok(resp)
1308    }
1309
1310    /// Get server statistics
1311    pub fn stats(&mut self) -> Result<HashMap<String, u64>> {
1312        let resp = self.request(Message::new(opcode::STATS, vec![]))?;
1313        match resp.opcode {
1314            opcode::STATS_RESP => {
1315                let stats_str = String::from_utf8_lossy(&resp.payload);
1316
1317                // Parse JSON response
1318                let stats: HashMap<String, u64> =
1319                    serde_json::from_str(&stats_str).map_err(|e| {
1320                        IpcError::Protocol(format!("Failed to parse stats JSON: {}", e))
1321                    })?;
1322
1323                Ok(stats)
1324            }
1325            opcode::ERROR => Err(IpcError::Database(
1326                String::from_utf8_lossy(&resp.payload).to_string(),
1327            )),
1328            _ => Err(IpcError::Protocol(format!(
1329                "Unexpected opcode: {:#x}",
1330                resp.opcode
1331            ))),
1332        }
1333    }
1334
1335    fn check_ok(&self, resp: Message) -> Result<()> {
1336        match resp.opcode {
1337            opcode::OK => Ok(()),
1338            opcode::ERROR => Err(IpcError::Database(
1339                String::from_utf8_lossy(&resp.payload).to_string(),
1340            )),
1341            _ => Err(IpcError::Protocol(format!(
1342                "Unexpected opcode: {:#x}",
1343                resp.opcode
1344            ))),
1345        }
1346    }
1347}
1348
1349// ============================================================================
1350// Tests
1351// ============================================================================
1352
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};
    use tempfile::TempDir;

    /// How long a test waits for the server to become reachable before failing.
    const STARTUP_TIMEOUT: Duration = Duration::from_secs(5);

    /// Interval between readiness polls.
    const POLL_INTERVAL: Duration = Duration::from_millis(5);

    fn setup_test_server() -> (Arc<Database>, TempDir, PathBuf) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let socket_path = temp_dir.path().join("test.sock");

        let db = Database::open(&db_path).unwrap();
        (db, temp_dir, socket_path)
    }

    /// Connect to the server at `socket_path`, retrying until it is listening
    /// or `STARTUP_TIMEOUT` expires.
    ///
    /// This replaces a fixed sleep after `start()`, which flakes on slow hosts.
    /// Failed attempts never reach the server's `accept()`, so they do not show
    /// up in its connection statistics.
    fn connect_client(socket_path: &Path) -> IpcClient {
        let deadline = Instant::now() + STARTUP_TIMEOUT;
        loop {
            match IpcClient::connect(socket_path) {
                Ok(client) => return client,
                Err(_) if Instant::now() < deadline => thread::sleep(POLL_INTERVAL),
                Err(e) => panic!("could not connect to {:?}: {:?}", socket_path, e),
            }
        }
    }

    #[test]
    fn test_message_roundtrip() {
        let original = Message::new(0x01, b"hello world".to_vec());

        let mut buffer = Vec::new();
        original.write_to(&mut buffer).unwrap();

        let mut cursor = std::io::Cursor::new(buffer);
        let decoded = Message::read_from(&mut cursor).unwrap();

        assert_eq!(decoded.opcode, original.opcode);
        assert_eq!(decoded.payload, original.payload);
    }

    #[test]
    fn test_encode_decode_put() {
        let key = b"test-key";
        let value = b"test-value";

        let encoded = encode_put(key, value);
        let (decoded_key, decoded_value) = decode_put(&encoded).unwrap();

        assert_eq!(decoded_key, key);
        assert_eq!(decoded_value, value);
    }

    #[test]
    fn test_encode_decode_path() {
        let path = vec!["users", "alice", "settings"];
        let value = b"preferences";

        let encoded = encode_put_path(&path, value);
        let (decoded_path, decoded_value) = decode_path(&encoded).unwrap();

        let expected_path: Vec<String> = path.iter().map(|s| s.to_string()).collect();
        assert_eq!(decoded_path, expected_path);
        assert_eq!(decoded_value, value);
    }

    #[test]
    fn test_server_client_basic() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        // Start server
        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        // Connect client (waits for the server to be ready)
        let mut client = connect_client(&socket_path);

        // Test ping
        let latency = client.ping().unwrap();
        assert!(latency < Duration::from_secs(1));

        // Test put/get
        client.put(b"key1", b"value1").unwrap();
        let value = client.get(b"key1").unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));

        // Test get non-existent
        let value = client.get(b"nonexistent").unwrap();
        assert_eq!(value, None);

        // Test delete
        client.delete(b"key1").unwrap();
        let value = client.get(b"key1").unwrap();
        assert_eq!(value, None);

        // Stop server
        server.stop();
    }

    #[test]
    fn test_socket_permissions_are_owner_only() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        // Poll rather than sleep: the socket may be created and chmod-ed
        // asynchronously after start() returns.
        let permission_bits = || {
            std::fs::metadata(&socket_path)
                .ok()
                .map(|meta| meta.permissions().mode() & 0o777)
        };
        let deadline = Instant::now() + STARTUP_TIMEOUT;
        while permission_bits() != Some(0o600) && Instant::now() < deadline {
            thread::sleep(POLL_INTERVAL);
        }

        let mode = permission_bits().expect("server never created its socket file");
        // Only the low 9 permission bits matter; the socket must be 0600
        // (owner read/write, no group/other access).
        assert_eq!(
            mode, 0o600,
            "socket permissions must be owner-only, got {:o}",
            mode
        );

        server.stop();
    }

    #[test]
    fn test_server_client_transactions() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut client = connect_client(&socket_path);

        // Begin transaction
        let txn_id = client.begin_txn().unwrap();
        assert!(txn_id > 0);

        // Commit
        let commit_ts = client.commit_txn(txn_id).unwrap();
        assert!(commit_ts > 0);

        // Begin another and abort
        let txn_id2 = client.begin_txn().unwrap();
        client.abort_txn(txn_id2).unwrap();

        server.stop();
    }

    #[test]
    fn test_server_client_paths() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut client = connect_client(&socket_path);

        // Put by path
        client
            .put_path(&["users", "alice", "email"], b"alice@example.com")
            .unwrap();

        // Get by path
        let value = client.get_path(&["users", "alice", "email"]).unwrap();
        assert_eq!(value, Some(b"alice@example.com".to_vec()));

        // Get non-existent path
        let value = client.get_path(&["users", "bob", "email"]).unwrap();
        assert_eq!(value, None);

        server.stop();
    }

    #[test]
    fn test_server_stats() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default().with_socket_path(&socket_path);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        let mut client = connect_client(&socket_path);

        // Make some requests
        client.ping().unwrap();
        client.put(b"k", b"v").unwrap();
        client.get(b"k").unwrap();

        // Get stats
        let stats = client.stats().unwrap();
        assert!(stats.contains_key("requests_total"));
        assert!(*stats.get("requests_total").unwrap() >= 4);

        // Check server-side stats
        let server_stats = server.stats();
        assert!(server_stats.requests_total >= 4);
        assert!(server_stats.connections_active >= 1);

        server.stop();
    }

    #[test]
    fn test_multiple_clients() {
        let (db, _temp_dir, socket_path) = setup_test_server();

        let config = IpcServerConfig::default()
            .with_socket_path(&socket_path)
            .with_max_connections(10);
        let server = IpcServer::new(Arc::clone(&db), config);
        server.start().unwrap();

        // Connect multiple clients
        let handles: Vec<_> = (0..5)
            .map(|i| {
                let path = socket_path.clone();
                thread::spawn(move || {
                    let mut client = connect_client(&path);
                    let key = format!("key-{}", i);
                    let value = format!("value-{}", i);

                    client.put(key.as_bytes(), value.as_bytes()).unwrap();
                    let result = client.get(key.as_bytes()).unwrap();
                    assert_eq!(result, Some(value.into_bytes()));
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let stats = server.stats();
        assert_eq!(stats.connections_total, 5);

        server.stop();
    }
}