1use crate::error::{KernelError, KernelResult};
61use crate::kernel_api::RowId;
62use crate::wasm_runtime::WasmPluginCapabilities;
63use std::collections::HashMap;
64use std::sync::Arc;
65use std::sync::atomic::{AtomicU64, Ordering};
66
/// Outcome of a single host-function invocation made on behalf of a plugin.
///
/// Each variant maps to a numeric status via [`HostCallResult::status_code`];
/// only [`HostCallResult::Success`] carries a payload.
#[derive(Debug, Clone, PartialEq)]
pub enum HostCallResult {
    /// The call succeeded and produced a byte payload.
    Success(Vec<u8>),
    /// The call succeeded and produced no payload.
    Ok,
    /// The plugin lacks the capability required for the call.
    PermissionDenied(String),
    /// The requested target does not exist.
    NotFound(String),
    /// The argument buffer could not be interpreted.
    InvalidArgs(String),
    /// Any other failure.
    Error(String),
}

impl HostCallResult {
    /// Returns the numeric status for this result.
    ///
    /// Both success variants map to `0`; the failure variants map to
    /// `-1` (permission denied), `-2` (not found), `-3` (invalid arguments)
    /// and `-4` (other error).
    pub fn status_code(&self) -> i32 {
        match *self {
            Self::Success(_) | Self::Ok => 0,
            Self::PermissionDenied(_) => -1,
            Self::NotFound(_) => -2,
            Self::InvalidArgs(_) => -3,
            Self::Error(_) => -4,
        }
    }

    /// Returns the payload of a [`HostCallResult::Success`], or `None` for
    /// every other variant.
    pub fn data(&self) -> Option<&[u8]> {
        if let Self::Success(payload) = self {
            Some(payload.as_slice())
        } else {
            None
        }
    }
}
109
110pub struct HostFunctionContext {
118 pub plugin_name: String,
120 pub capabilities: WasmPluginCapabilities,
122 pub audit_log: Vec<AuditEntry>,
124 pub transaction_id: Option<u64>,
126 pub session_vars: HashMap<String, Vec<u8>>,
128}
129
130#[derive(Debug, Clone)]
132pub struct AuditEntry {
133 pub timestamp_us: u64,
135 pub function: String,
137 pub table: Option<String>,
139 pub status: i32,
141 pub rows_affected: u64,
143}
144
145impl HostFunctionContext {
146 pub fn new(plugin_name: &str, capabilities: WasmPluginCapabilities) -> Self {
148 Self {
149 plugin_name: plugin_name.to_string(),
150 capabilities,
151 audit_log: Vec::new(),
152 transaction_id: None,
153 session_vars: HashMap::new(),
154 }
155 }
156
157 pub fn check_read(&self, table: &str) -> KernelResult<()> {
159 if !self.capabilities.can_read(table) {
160 return Err(KernelError::Plugin {
161 message: format!(
162 "plugin '{}' not authorized to read table '{}'",
163 self.plugin_name, table
164 ),
165 });
166 }
167 Ok(())
168 }
169
170 pub fn check_write(&self, table: &str) -> KernelResult<()> {
172 if !self.capabilities.can_write(table) {
173 return Err(KernelError::Plugin {
174 message: format!(
175 "plugin '{}' not authorized to write table '{}'",
176 self.plugin_name, table
177 ),
178 });
179 }
180 Ok(())
181 }
182
183 pub fn check_vector_search(&self) -> KernelResult<()> {
185 if !self.capabilities.can_vector_search {
186 return Err(KernelError::Plugin {
187 message: format!(
188 "plugin '{}' not authorized for vector search",
189 self.plugin_name
190 ),
191 });
192 }
193 Ok(())
194 }
195
196 pub fn check_index_search(&self) -> KernelResult<()> {
198 if !self.capabilities.can_index_search {
199 return Err(KernelError::Plugin {
200 message: format!(
201 "plugin '{}' not authorized for index search",
202 self.plugin_name
203 ),
204 });
205 }
206 Ok(())
207 }
208
209 pub fn audit(&mut self, function: &str, table: Option<&str>, status: i32, rows: u64) {
211 self.audit_log.push(AuditEntry {
212 timestamp_us: std::time::SystemTime::now()
213 .duration_since(std::time::UNIX_EPOCH)
214 .unwrap_or_default()
215 .as_micros() as u64,
216 function: function.to_string(),
217 table: table.map(|s| s.to_string()),
218 status,
219 rows_affected: rows,
220 });
221 }
222}
223
224pub trait HostFunction: Send + Sync {
232 fn name(&self) -> &str;
234
235 fn execute(&self, ctx: &mut HostFunctionContext, args: &[u8]) -> HostCallResult;
237
238 fn description(&self) -> &str;
240}
241
/// Host function registered as `soch_read`.
///
/// Holds no state of its own; the only field is a zero-sized marker.
#[derive(Default)]
pub struct SochRead {
    _marker: std::marker::PhantomData<()>,
}

impl SochRead {
    /// Creates the (stateless) host function.
    pub fn new() -> Self {
        Self::default()
    }
}
273
274impl HostFunction for SochRead {
275 fn name(&self) -> &str {
276 "soch_read"
277 }
278
279 fn description(&self) -> &str {
280 "Read rows from a table with optional key filter"
281 }
282
283 fn execute(&self, ctx: &mut HostFunctionContext, args: &[u8]) -> HostCallResult {
284 let args_str = match std::str::from_utf8(args) {
286 Ok(s) => s,
287 Err(_) => {
288 ctx.audit("soch_read", None, -3, 0);
289 return HostCallResult::InvalidArgs("invalid UTF-8 in arguments".to_string());
290 }
291 };
292
293 let table = args_str.lines().next().unwrap_or("");
295
296 if let Err(e) = ctx.check_read(table) {
298 ctx.audit("soch_read", Some(table), -1, 0);
299 return HostCallResult::PermissionDenied(e.to_string());
300 }
301
302 let mock_data = "table[1]{id,name}:\n(1,\"mock_row\")"
305 .to_string()
306 .into_bytes();
307
308 ctx.audit("soch_read", Some(table), 0, 1);
309 HostCallResult::Success(mock_data)
310 }
311}
312
/// Host function registered as `soch_write`.
///
/// Holds no state of its own; the only field is a zero-sized marker.
#[derive(Default)]
pub struct SochWrite {
    _marker: std::marker::PhantomData<()>,
}

impl SochWrite {
    /// Creates the (stateless) host function.
    pub fn new() -> Self {
        Self::default()
    }
}
342
343impl HostFunction for SochWrite {
344 fn name(&self) -> &str {
345 "soch_write"
346 }
347
348 fn description(&self) -> &str {
349 "Write rows to a table"
350 }
351
352 fn execute(&self, ctx: &mut HostFunctionContext, args: &[u8]) -> HostCallResult {
353 let args_str = match std::str::from_utf8(args) {
354 Ok(s) => s,
355 Err(_) => {
356 ctx.audit("soch_write", None, -3, 0);
357 return HostCallResult::InvalidArgs("invalid UTF-8 in arguments".to_string());
358 }
359 };
360
361 let table = args_str.lines().next().unwrap_or("");
362
363 if let Err(e) = ctx.check_write(table) {
364 ctx.audit("soch_write", Some(table), -1, 0);
365 return HostCallResult::PermissionDenied(e.to_string());
366 }
367
368 let row_count = args_str.lines().skip(1).count() as u64;
370
371 ctx.audit("soch_write", Some(table), 0, row_count);
372 HostCallResult::Success(row_count.to_le_bytes().to_vec())
373 }
374}
375
/// Host function registered as `vector_search`.
///
/// Holds no state of its own; the only field is a zero-sized marker.
#[derive(Default)]
pub struct VectorSearch {
    _marker: std::marker::PhantomData<()>,
}

impl VectorSearch {
    /// Creates the (stateless) host function.
    pub fn new() -> Self {
        Self::default()
    }
}
407
408impl HostFunction for VectorSearch {
409 fn name(&self) -> &str {
410 "vector_search"
411 }
412
413 fn description(&self) -> &str {
414 "Perform vector similarity search"
415 }
416
417 fn execute(&self, ctx: &mut HostFunctionContext, args: &[u8]) -> HostCallResult {
418 if let Err(e) = ctx.check_vector_search() {
419 ctx.audit("vector_search", None, -1, 0);
420 return HostCallResult::PermissionDenied(e.to_string());
421 }
422
423 let args_str = std::str::from_utf8(args).unwrap_or("");
425 let collection = args_str.lines().next().unwrap_or("default");
426
427 let mock_results: Vec<(RowId, f32)> = vec![(1, 0.1), (2, 0.2), (3, 0.3)];
430
431 let mut result = Vec::new();
433 for (row_id, distance) in mock_results {
434 result.extend_from_slice(&row_id.to_le_bytes());
435 result.extend_from_slice(&distance.to_le_bytes());
436 }
437
438 ctx.audit("vector_search", Some(collection), 0, 3);
439 HostCallResult::Success(result)
440 }
441}
442
/// Host function registered as `emit_metric`.
///
/// Keeps a running count of the metrics it has accepted.
#[derive(Default)]
pub struct EmitMetric {
    /// Number of accepted metrics; starts at zero.
    metrics_emitted: AtomicU64,
}

impl EmitMetric {
    /// Creates the host function with its counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns how many metrics have been accepted so far (relaxed load).
    pub fn total_emitted(&self) -> u64 {
        let counter = &self.metrics_emitted;
        counter.load(Ordering::Relaxed)
    }
}
479
480impl HostFunction for EmitMetric {
481 fn name(&self) -> &str {
482 "emit_metric"
483 }
484
485 fn description(&self) -> &str {
486 "Emit an observability metric (counter, gauge, or histogram)"
487 }
488
489 fn execute(&self, ctx: &mut HostFunctionContext, args: &[u8]) -> HostCallResult {
490 if args.is_empty() {
493 ctx.audit("emit_metric", None, -3, 0);
494 return HostCallResult::InvalidArgs("empty metric data".to_string());
495 }
496
497 self.metrics_emitted.fetch_add(1, Ordering::Relaxed);
499
500 ctx.audit("emit_metric", None, 0, 1);
501 HostCallResult::Ok
502 }
503}
504
505pub struct LogMessage {
518 logs: parking_lot::RwLock<Vec<(u8, String)>>,
520}
521
522impl LogMessage {
523 pub fn new() -> Self {
524 Self {
525 logs: parking_lot::RwLock::new(Vec::new()),
526 }
527 }
528
529 pub fn captured_logs(&self) -> Vec<(u8, String)> {
531 self.logs.read().clone()
532 }
533
534 pub fn clear_logs(&self) {
536 self.logs.write().clear();
537 }
538}
539
540impl Default for LogMessage {
541 fn default() -> Self {
542 Self::new()
543 }
544}
545
546impl HostFunction for LogMessage {
547 fn name(&self) -> &str {
548 "log_message"
549 }
550
551 fn description(&self) -> &str {
552 "Log a message at specified level"
553 }
554
555 fn execute(&self, ctx: &mut HostFunctionContext, args: &[u8]) -> HostCallResult {
556 if args.is_empty() {
559 return HostCallResult::InvalidArgs("empty log data".to_string());
560 }
561
562 let level = args[0];
563 let message = std::str::from_utf8(&args[1..]).unwrap_or("<invalid UTF-8>");
564
565 self.logs.write().push((level, message.to_string()));
567
568 ctx.audit("log_message", None, 0, 0);
570 HostCallResult::Ok
571 }
572}
573
574pub struct HostFunctionRegistry {
580 functions: HashMap<String, Arc<dyn HostFunction>>,
582}
583
584impl Default for HostFunctionRegistry {
585 fn default() -> Self {
586 Self::new()
587 }
588}
589
590impl HostFunctionRegistry {
591 pub fn new() -> Self {
593 let mut registry = Self {
594 functions: HashMap::new(),
595 };
596
597 registry.register(Arc::new(SochRead::new()));
599 registry.register(Arc::new(SochWrite::new()));
600 registry.register(Arc::new(VectorSearch::new()));
601 registry.register(Arc::new(EmitMetric::new()));
602 registry.register(Arc::new(LogMessage::new()));
603
604 registry
605 }
606
607 pub fn register(&mut self, func: Arc<dyn HostFunction>) {
609 self.functions.insert(func.name().to_string(), func);
610 }
611
612 pub fn get(&self, name: &str) -> Option<Arc<dyn HostFunction>> {
614 self.functions.get(name).cloned()
615 }
616
617 pub fn list(&self) -> Vec<(&str, &str)> {
619 self.functions
620 .values()
621 .map(|f| (f.name(), f.description()))
622 .collect()
623 }
624
625 pub fn execute(
627 &self,
628 name: &str,
629 ctx: &mut HostFunctionContext,
630 args: &[u8],
631 ) -> HostCallResult {
632 match self.functions.get(name) {
633 Some(func) => func.execute(ctx, args),
634 None => HostCallResult::NotFound(format!("host function '{}' not found", name)),
635 }
636 }
637}
638
/// Little-endian encoding helpers for values exchanged with plugins.
///
/// Strings and `f32` vectors are framed with a 4-byte little-endian `u32`
/// length prefix; row ids are a bare 8-byte little-endian `u64`. Every
/// decoder returns the decoded value together with the unread remainder of
/// the input, or `None` when the input is too short or malformed.
pub mod wire {
    /// Width in bytes of the `u32` length prefix.
    const LEN_PREFIX: usize = 4;

    /// Splits `data` into the length declared by its prefix and the bytes
    /// that follow the prefix. Returns `None` if the prefix is incomplete.
    fn split_len_prefix(data: &[u8]) -> Option<(usize, &[u8])> {
        if data.len() < LEN_PREFIX {
            return None;
        }
        let (prefix, body) = data.split_at(LEN_PREFIX);
        let declared = u32::from_le_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]);
        Some((declared as usize, body))
    }

    /// Encodes `s` as a `u32` byte length followed by its UTF-8 bytes.
    ///
    /// The length is narrowed to `u32` by a cast, so callers must keep
    /// strings below 4 GiB for the frame to decode correctly.
    pub fn encode_string(s: &str) -> Vec<u8> {
        let mut buf = Vec::with_capacity(LEN_PREFIX + s.len());
        buf.extend_from_slice(&(s.len() as u32).to_le_bytes());
        buf.extend_from_slice(s.as_bytes());
        buf
    }

    /// Decodes a string written by [`encode_string`].
    ///
    /// Returns `None` when the prefix is incomplete, when fewer bytes follow
    /// than the prefix declares, or when those bytes are not valid UTF-8.
    pub fn decode_string(data: &[u8]) -> Option<(&str, &[u8])> {
        let (len, body) = split_len_prefix(data)?;
        // Compare against what is actually left instead of computing
        // `4 + len`, which can overflow `usize` on 32-bit targets.
        if body.len() < len {
            return None;
        }
        let (text, rest) = body.split_at(len);
        let s = std::str::from_utf8(text).ok()?;
        Some((s, rest))
    }

    /// Encodes a row id as 8 little-endian bytes.
    pub fn encode_row_id(id: u64) -> [u8; 8] {
        id.to_le_bytes()
    }

    /// Decodes a row id written by [`encode_row_id`].
    ///
    /// Returns `None` when fewer than 8 bytes are available.
    pub fn decode_row_id(data: &[u8]) -> Option<(u64, &[u8])> {
        if data.len() < 8 {
            return None;
        }
        let (head, rest) = data.split_at(8);
        let mut raw = [0u8; 8];
        raw.copy_from_slice(head);
        Some((u64::from_le_bytes(raw), rest))
    }

    /// Encodes `v` as a `u32` element count followed by each element's
    /// little-endian bytes.
    ///
    /// The count is narrowed to `u32` by a cast, so callers must keep the
    /// element count within `u32::MAX` for the frame to decode correctly.
    pub fn encode_f32_vec(v: &[f32]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(LEN_PREFIX + v.len() * 4);
        buf.extend_from_slice(&(v.len() as u32).to_le_bytes());
        for f in v {
            buf.extend_from_slice(&f.to_le_bytes());
        }
        buf
    }

    /// Decodes a vector written by [`encode_f32_vec`].
    ///
    /// Returns `None` when the prefix is incomplete or when fewer than
    /// `count * 4` bytes follow it. Nothing is allocated until the declared
    /// count has been validated against the available input.
    pub fn decode_f32_vec(data: &[u8]) -> Option<(Vec<f32>, &[u8])> {
        let (count, body) = split_len_prefix(data)?;
        // `count * 4` can overflow `usize` on 32-bit targets; treat that as
        // a malformed frame rather than wrapping.
        let byte_len = count.checked_mul(4)?;
        if body.len() < byte_len {
            return None;
        }
        let (raw, rest) = body.split_at(byte_len);
        let values = raw
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();
        Some((values, rest))
    }
}
715
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant maps to its documented status code.
    #[test]
    fn test_host_call_result_status() {
        let cases = vec![
            (HostCallResult::Ok, 0),
            (HostCallResult::Success(vec![]), 0),
            (HostCallResult::PermissionDenied("".to_string()), -1),
            (HostCallResult::NotFound("".to_string()), -2),
            (HostCallResult::InvalidArgs("".to_string()), -3),
            (HostCallResult::Error("".to_string()), -4),
        ];

        for (result, expected) in cases {
            assert_eq!(result.status_code(), expected);
        }
    }

    /// Each `check_*` method honours exactly the capability it names.
    #[test]
    fn test_host_context_permission_checks() {
        let ctx = HostFunctionContext::new(
            "test_plugin",
            WasmPluginCapabilities {
                can_read_table: vec!["users".to_string()],
                can_write_table: vec!["logs".to_string()],
                can_vector_search: true,
                can_index_search: false,
                ..Default::default()
            },
        );

        assert!(ctx.check_read("users").is_ok());
        assert!(ctx.check_read("other").is_err());
        assert!(ctx.check_write("logs").is_ok());
        assert!(ctx.check_write("users").is_err());
        assert!(ctx.check_vector_search().is_ok());
        assert!(ctx.check_index_search().is_err());
    }

    /// `soch_read` succeeds for a readable table and is denied otherwise.
    #[test]
    fn test_soch_read_permission() {
        let mut ctx = HostFunctionContext::new(
            "test",
            WasmPluginCapabilities {
                can_read_table: vec!["allowed_table".to_string()],
                ..Default::default()
            },
        );
        let host_fn = SochRead::new();

        let granted = host_fn.execute(&mut ctx, b"allowed_table\n");
        assert_eq!(granted.status_code(), 0);

        let refused = host_fn.execute(&mut ctx, b"denied_table\n");
        assert_eq!(refused.status_code(), -1);
    }

    /// `soch_write` reports the row count and is denied for other tables.
    #[test]
    fn test_soch_write_permission() {
        let mut ctx = HostFunctionContext::new(
            "test",
            WasmPluginCapabilities {
                can_write_table: vec!["writable".to_string()],
                ..Default::default()
            },
        );
        let host_fn = SochWrite::new();

        let written = host_fn.execute(&mut ctx, b"writable\nrow1\nrow2\n");
        assert_eq!(written.status_code(), 0);
        assert_eq!(written.data().unwrap(), &2u64.to_le_bytes());

        let refused = host_fn.execute(&mut ctx, b"readonly\nrow1\n");
        assert_eq!(refused.status_code(), -1);
    }

    /// `vector_search` returns three (row id, distance) pairs.
    #[test]
    fn test_vector_search() {
        let mut ctx = HostFunctionContext::new(
            "test",
            WasmPluginCapabilities {
                can_vector_search: true,
                ..Default::default()
            },
        );
        let host_fn = VectorSearch::new();

        let found = host_fn.execute(&mut ctx, b"collection\n");
        assert_eq!(found.status_code(), 0);

        // Each pair is an 8-byte row id followed by a 4-byte distance.
        let payload = found.data().unwrap();
        assert_eq!(payload.len(), 3 * (8 + 4));
    }

    /// A non-empty buffer is accepted and counted once.
    #[test]
    fn test_emit_metric() {
        let mut ctx = HostFunctionContext::new("test", WasmPluginCapabilities::default());
        let host_fn = EmitMetric::new();

        let outcome = host_fn.execute(&mut ctx, b"\x01metric_name\x00\x00\x00\x00");
        assert_eq!(outcome.status_code(), 0);
        assert_eq!(host_fn.total_emitted(), 1);
    }

    /// The first byte is captured as the level, the rest as the text.
    #[test]
    fn test_log_message() {
        let mut ctx = HostFunctionContext::new("test", WasmPluginCapabilities::default());
        let host_fn = LogMessage::new();

        let outcome = host_fn.execute(&mut ctx, b"\x01hello world");
        assert_eq!(outcome.status_code(), 0);

        let captured = host_fn.captured_logs();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].0, 1);
        assert_eq!(captured[0].1, "hello world");
    }

    /// A fresh registry exposes every built-in and nothing unknown.
    #[test]
    fn test_host_function_registry() {
        let registry = HostFunctionRegistry::new();

        let builtins = [
            "soch_read",
            "soch_write",
            "vector_search",
            "emit_metric",
            "log_message",
        ];
        for name in builtins.iter() {
            assert!(registry.get(name).is_some());
        }

        assert!(registry.get("unknown").is_none());

        let listed = registry.list();
        assert!(listed.len() >= 5);
    }

    /// Dispatch by name reaches the function; unknown names are `NotFound`.
    #[test]
    fn test_registry_execute() {
        let registry = HostFunctionRegistry::new();
        let mut ctx = HostFunctionContext::new(
            "plugin",
            WasmPluginCapabilities {
                can_read_table: vec!["test".to_string()],
                ..Default::default()
            },
        );

        let served = registry.execute("soch_read", &mut ctx, b"test\n");
        assert_eq!(served.status_code(), 0);

        let missing = registry.execute("nonexistent", &mut ctx, b"");
        assert_eq!(missing.status_code(), -2);
    }

    /// A successful read leaves exactly one matching audit entry.
    #[test]
    fn test_audit_log() {
        let mut ctx = HostFunctionContext::new(
            "test",
            WasmPluginCapabilities {
                can_read_table: vec!["audit_test".to_string()],
                ..Default::default()
            },
        );
        let host_fn = SochRead::new();

        let _ = host_fn.execute(&mut ctx, b"audit_test\n");

        assert_eq!(ctx.audit_log.len(), 1);
        let entry = &ctx.audit_log[0];
        assert_eq!(entry.function, "soch_read");
        assert_eq!(entry.table, Some("audit_test".to_string()));
        assert_eq!(entry.status, 0);
    }

    mod wire_tests {
        use super::super::wire::*;

        /// A string survives an encode/decode round trip with nothing left over.
        #[test]
        fn test_encode_decode_string() {
            let original = "hello world";
            let bytes = encode_string(original);
            let (decoded, rest) = decode_string(&bytes).unwrap();
            assert_eq!(decoded, original);
            assert!(rest.is_empty());
        }

        /// A row id survives an encode/decode round trip with nothing left over.
        #[test]
        fn test_encode_decode_row_id() {
            let original = 0x123456789ABCDEF0u64;
            let bytes = encode_row_id(original);
            let (decoded, rest) = decode_row_id(&bytes).unwrap();
            assert_eq!(decoded, original);
            assert!(rest.is_empty());
        }

        /// An `f32` vector survives an encode/decode round trip with nothing left over.
        #[test]
        fn test_encode_decode_f32_vec() {
            let original = vec![1.0, 2.0, 3.0, 4.0];
            let bytes = encode_f32_vec(&original);
            let (decoded, rest) = decode_f32_vec(&bytes).unwrap();
            assert_eq!(decoded, original);
            assert!(rest.is_empty());
        }

        /// Every decoder rejects empty input.
        #[test]
        fn test_decode_empty() {
            assert!(decode_string(&[]).is_none());
            assert!(decode_row_id(&[]).is_none());
            assert!(decode_f32_vec(&[]).is_none());
        }
    }
}