1use anyhow::{Context, Result, anyhow};
41use hashbrown::HashMap;
42use libloading::{Library, Symbol};
43use serde::{Deserialize, Serialize};
44use std::ffi::{CStr, CString};
45use std::path::{Path, PathBuf};
46use std::ptr::NonNull;
47use std::sync::Mutex;
48use tracing::{debug, info, warn};
49
/// ABI version the host expects. A plugin's `vtcode_plugin_version` export must return
/// exactly this value, otherwise loading fails with an ABI mismatch error.
pub const PLUGIN_ABI_VERSION: u32 = 1;
53
54use std::os::raw::c_char;
55
/// Signature of the required `vtcode_plugin_version` export: returns the plugin's ABI version.
type PluginVersionFn = unsafe extern "C" fn() -> u32;
/// Signature of the required `vtcode_plugin_metadata` export: returns a C string holding
/// the plugin metadata as JSON.
type PluginMetadataFn = unsafe extern "C" fn() -> *const c_char;
/// Signature of the required `vtcode_plugin_execute` export: receives the serialized
/// `PluginContext` as a C string and returns a C string holding `PluginResult` JSON.
type PluginExecuteFn = unsafe extern "C" fn(*const c_char) -> *const c_char;
/// Signature of the optional `vtcode_plugin_free_string` export: releases a string that
/// the plugin previously returned.
type PluginFreeStringFn = unsafe extern "C" fn(*const c_char);
60
/// Input for a single plugin invocation.
///
/// `NativePlugin::execute` serializes this to JSON and passes it to the plugin's
/// `vtcode_plugin_execute` export as a NUL-terminated C string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginContext {
    /// Invocation input as arbitrary JSON values keyed by string.
    pub input: HashMap<String, serde_json::Value>,
    /// Workspace root path, if any (presumably the caller's project root — confirm).
    pub workspace_root: Option<String>,
    /// Plugin configuration as arbitrary JSON values keyed by string.
    pub config: HashMap<String, serde_json::Value>,
}
71
/// Outcome of a plugin invocation, parsed from the JSON string returned by the plugin's
/// `vtcode_plugin_execute` export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginResult {
    /// Whether the plugin reports success.
    pub success: bool,
    /// Plugin output as arbitrary JSON values keyed by string.
    pub output: HashMap<String, serde_json::Value>,
    /// Error message reported by the plugin, if any.
    pub error: Option<String>,
    /// File paths reported by the plugin (presumably files it created or changed — confirm).
    pub files: Vec<String>,
}
84
/// Descriptive metadata a plugin reports through its `vtcode_plugin_metadata` export
/// (a JSON document deserialized into this struct).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginMetadata {
    /// Plugin name.
    pub name: String,
    /// Human-readable description of the plugin.
    pub description: String,
    /// Plugin version string.
    pub version: String,
    /// Optional author.
    pub author: Option<String>,
    /// ABI version declared in the metadata JSON. The loader's hard ABI check uses the
    /// value returned by `vtcode_plugin_version`, not this field.
    pub abi_version: u32,
    /// Optional guidance on when the plugin should be used.
    pub when_to_use: Option<String>,
    /// Optional guidance on when the plugin should not be used.
    pub when_not_to_use: Option<String>,
    /// Optional list of tool names (presumably tools the plugin is allowed to use — confirm).
    pub allowed_tools: Option<Vec<String>>,
    /// Whether the plugin may be executed from several threads at once. Defaults to
    /// `false` when absent, in which case `NativePlugin::execute` serializes calls.
    #[serde(default)]
    pub thread_safe: bool,
}
109
/// C-layout wrapper around a pointer to plugin metadata JSON.
///
/// NOTE(review): the loading path in this file reads metadata as a bare `*const c_char`
/// and does not use this type; presumably it exists for plugin authors — confirm.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct PluginMetadataFFI {
    /// Pointer to a C string holding metadata JSON (assumed NUL-terminated).
    pub json_ptr: *const c_char,
}
117
/// C-layout wrapper around a pointer to plugin result JSON.
///
/// NOTE(review): the execution path in this file reads results as a bare `*const c_char`
/// and does not use this type; presumably it exists for plugin authors — confirm.
#[repr(C)]
pub struct PluginResultFFI {
    /// Pointer to a C string holding result JSON (assumed NUL-terminated).
    pub json_ptr: *const c_char,
}
124
/// Interface to a loaded native plugin, used as `Box<dyn NativePluginTrait>`.
///
/// The `Send + Sync` bound allows loaded plugins to be shared across threads.
pub trait NativePluginTrait: Send + Sync + std::fmt::Debug {
    /// Metadata reported by the plugin.
    fn metadata(&self) -> &PluginMetadata;

    /// Path the plugin was loaded from.
    fn path(&self) -> &Path;

    /// Runs the plugin with `ctx` and returns its parsed result.
    fn execute(&self, ctx: &PluginContext) -> Result<PluginResult>;
}
136
/// A plugin backed by a dynamically loaded library.
pub struct NativePlugin {
    /// Keeps the dynamic library loaded. The function pointers below were copied out of
    /// it and are only valid while it stays loaded.
    _library: Library,
    /// Metadata reported by the plugin's `vtcode_plugin_metadata` export.
    metadata: PluginMetadata,
    /// Path supplied at construction (the canonical plugin directory when created by
    /// `PluginLoader::load_plugin`).
    path: PathBuf,
    /// The plugin's `vtcode_plugin_execute` entry point.
    execute_fn: PluginExecuteFn,
    /// The plugin's optional `vtcode_plugin_free_string` deallocator.
    free_string_fn: Option<PluginFreeStringFn>,
    /// Serializes calls into plugins that are not `thread_safe`.
    execution_lock: Mutex<()>,
    /// When false, `execute` holds `execution_lock` for the whole call.
    /// `NativePlugin::new` initializes it from `metadata.thread_safe`.
    thread_safe: bool,
}
154
155fn ensure_non_null_c_string_ptr(
156 ptr: *const c_char,
157 context: &'static str,
158) -> Result<NonNull<c_char>> {
159 NonNull::new(ptr.cast_mut()).ok_or_else(|| anyhow!("{context} returned null pointer"))
160}
161
162fn decode_plugin_c_string(
163 ptr: NonNull<c_char>,
164 free_string_fn: Option<PluginFreeStringFn>,
165 utf8_error_context: &'static str,
166) -> Result<String> {
167 let raw_ptr = ptr.as_ptr() as *const c_char;
168 let decoded = unsafe { CStr::from_ptr(raw_ptr) }
175 .to_str()
176 .context(utf8_error_context)
177 .map(str::to_owned);
178
179 if let Some(free_fn) = free_string_fn {
180 unsafe { free_fn(raw_ptr) };
183 }
184
185 decoded
186}
187
188impl std::fmt::Debug for NativePlugin {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 f.debug_struct("NativePlugin")
191 .field("metadata", &self.metadata)
192 .field("path", &self.path)
193 .finish()
194 }
195}
196
197fn canonicalize_existing_path(path: &Path, label: &str) -> Result<PathBuf> {
198 path.canonicalize()
199 .with_context(|| format!("Failed to resolve {label} '{}'", path.display()))
200}
201
202fn normalize_trusted_dir(path: PathBuf) -> PathBuf {
203 canonicalize_existing_path(&path, "trusted plugin directory").unwrap_or_else(|_| {
204 if path.is_absolute() {
205 path
206 } else {
207 std::env::current_dir()
208 .map(|cwd| cwd.join(&path))
209 .unwrap_or(path)
210 }
211 })
212}
213
214impl NativePlugin {
215 pub fn new(library: Library, path: PathBuf) -> Result<Self> {
217 let version_fn: Symbol<PluginVersionFn> = unsafe {
221 library
222 .get(b"vtcode_plugin_version\0")
223 .context("Failed to load vtcode_plugin_version symbol")?
224 };
225
226 let abi_version = unsafe { version_fn() };
228 if abi_version != PLUGIN_ABI_VERSION {
229 return Err(anyhow!(
230 "Plugin ABI version mismatch: expected {}, got {}",
231 PLUGIN_ABI_VERSION,
232 abi_version
233 ));
234 }
235
236 let free_string_fn = unsafe {
239 library
240 .get::<PluginFreeStringFn>(b"vtcode_plugin_free_string\0")
241 .map(|symbol| *symbol)
242 .ok()
243 };
244
245 let metadata_fn: Symbol<PluginMetadataFn> = unsafe {
248 library
249 .get(b"vtcode_plugin_metadata\0")
250 .context("Failed to load vtcode_plugin_metadata symbol")?
251 };
252
253 let metadata_ptr =
255 ensure_non_null_c_string_ptr(unsafe { metadata_fn() }, "Plugin metadata function")?;
256 let metadata_json = decode_plugin_c_string(
257 metadata_ptr,
258 free_string_fn,
259 "Plugin metadata is not valid UTF-8",
260 )?;
261
262 let metadata: PluginMetadata =
263 serde_json::from_str(&metadata_json).context("Failed to parse plugin metadata JSON")?;
264
265 let execute_fn: Symbol<PluginExecuteFn> = unsafe {
268 library
269 .get(b"vtcode_plugin_execute\0")
270 .context("Failed to load vtcode_plugin_execute symbol")?
271 };
272
273 let execute_fn_ptr = *execute_fn;
274
275 Ok(Self {
276 _library: library,
277 metadata: metadata.clone(), path,
279 execute_fn: execute_fn_ptr,
280 free_string_fn,
281 execution_lock: Mutex::new(()),
282 thread_safe: metadata.thread_safe, })
284 }
285
286 pub fn execute(&self, ctx: &PluginContext) -> Result<PluginResult> {
288 let input_json =
289 serde_json::to_string(ctx).context("Failed to serialize plugin context")?;
290 let input_cstr =
291 CString::new(input_json).context("Plugin context contains internal null bytes")?;
292
293 if !self.thread_safe {
297 let _execution_guard = self
298 .execution_lock
299 .lock()
300 .map_err(|_| anyhow!("native plugin execution lock poisoned"))?;
301
302 self.execute_ffi(input_cstr)
303 } else {
304 self.execute_ffi(input_cstr)
306 }
307 }
308
309 fn execute_ffi(&self, input_cstr: CString) -> Result<PluginResult> {
316 let result_ptr = ensure_non_null_c_string_ptr(
321 unsafe { (self.execute_fn)(input_cstr.as_ptr()) },
322 "Plugin execute function",
323 )?;
324 let result_json = decode_plugin_c_string(
325 result_ptr,
326 self.free_string_fn,
327 "Plugin result is not valid UTF-8",
328 )?;
329
330 let result: PluginResult =
331 serde_json::from_str(&result_json).context("Failed to parse plugin result JSON")?;
332
333 Ok(result)
334 }
335}
336
337impl NativePluginTrait for NativePlugin {
338 fn metadata(&self) -> &PluginMetadata {
339 &self.metadata
340 }
341
342 fn path(&self) -> &Path {
343 &self.path
344 }
345
346 fn execute(&self, ctx: &PluginContext) -> Result<PluginResult> {
347 self.execute(ctx)
348 }
349}
350
/// Discovers and loads native plugins, restricted to explicitly trusted directories.
pub struct PluginLoader {
    /// Trusted plugin roots.
    ///
    /// Entries are canonicalized when the directory existed at registration time. Plugin
    /// and library paths must resolve to somewhere under one of them before loading.
    trusted_dirs: Vec<PathBuf>,
}
356
357impl PluginLoader {
358 pub fn new() -> Self {
360 Self {
361 trusted_dirs: Vec::new(),
362 }
363 }
364
365 pub fn add_trusted_dir(&mut self, path: PathBuf) -> &mut Self {
367 let path = normalize_trusted_dir(path);
368 if !self.trusted_dirs.contains(&path) {
369 self.trusted_dirs.push(path);
370 }
371 self
372 }
373
374 pub fn trusted_dirs(&self) -> &[PathBuf] {
376 &self.trusted_dirs
377 }
378
379 pub fn load_plugin(&self, plugin_path: &Path) -> Result<Box<dyn NativePluginTrait>> {
381 debug!("Loading native plugin from {:?}", plugin_path);
382
383 let plugin_path = self.ensure_trusted_path(plugin_path, "Plugin path")?;
384
385 let lib_path = self.find_library_file(&plugin_path)?;
387 let lib_path = self.ensure_trusted_path(&lib_path, "Plugin library path")?;
388
389 let library = unsafe { Library::new(&lib_path) }
400 .with_context(|| format!("Failed to load dynamic library at {:?}", lib_path))?;
401
402 let plugin = NativePlugin::new(library, plugin_path.clone())?;
403
404 info!(
405 "Loaded native plugin '{}' v{} from {:?}",
406 plugin.metadata.name, plugin.metadata.version, plugin_path
407 );
408
409 Ok(Box::new(plugin))
410 }
411
412 pub fn discover_plugins(&self) -> Result<Vec<Box<dyn NativePluginTrait>>> {
414 let mut plugins = Vec::new();
415
416 for dir in &self.trusted_dirs {
417 if !dir.exists() {
418 continue;
419 }
420
421 match self.discover_plugins_in_dir(dir) {
422 Ok(mut dir_plugins) => plugins.append(&mut dir_plugins),
423 Err(e) => {
424 warn!("Failed to discover plugins in {:?}: {}", dir, e);
425 }
426 }
427 }
428
429 Ok(plugins)
430 }
431
432 fn is_in_trusted_dir(&self, path: &Path) -> bool {
434 self.trusted_dirs.iter().any(|dir| path.starts_with(dir))
435 }
436
437 fn ensure_trusted_path(&self, path: &Path, label: &str) -> Result<PathBuf> {
438 let path = canonicalize_existing_path(path, label)?;
439 if self.is_in_trusted_dir(&path) {
440 Ok(path)
441 } else {
442 Err(anyhow!("{label} {:?} is not in a trusted directory", path))
443 }
444 }
445
446 fn find_library_file(&self, plugin_dir: &Path) -> Result<PathBuf> {
448 if !plugin_dir.is_dir() {
449 return Err(anyhow!("Plugin path is not a directory"));
450 }
451
452 let metadata_path = plugin_dir.join("plugin.json");
454 if !metadata_path.exists() {
455 return Err(anyhow!("No plugin.json found in {:?}", plugin_dir));
456 }
457
458 let lib_name = self.get_library_name_from_metadata(&metadata_path)?;
460
461 let lib_path = plugin_dir.join(&lib_name);
462 if lib_path.exists() {
463 return Ok(lib_path);
464 }
465
466 let alternatives = self.get_alternative_library_names(&lib_name);
468 for alt in alternatives {
469 let alt_path = plugin_dir.join(alt);
470 if alt_path.exists() {
471 return Ok(alt_path);
472 }
473 }
474
475 Err(anyhow!(
476 "No dynamic library found in {:?}. Expected one of: {}, or alternatives",
477 plugin_dir,
478 lib_name
479 ))
480 }
481
482 fn get_library_name_from_metadata(&self, metadata_path: &Path) -> Result<String> {
484 let metadata_content =
485 std::fs::read_to_string(metadata_path).context("Failed to read plugin metadata")?;
486 let metadata: serde_json::Value =
487 serde_json::from_str(&metadata_content).context("Invalid plugin metadata JSON")?;
488
489 let name = metadata["name"]
490 .as_str()
491 .ok_or_else(|| anyhow!("Plugin metadata missing 'name' field"))?;
492
493 Ok(self.library_filename(name))
494 }
495
496 fn get_alternative_library_names(&self, base_name: &str) -> Vec<String> {
498 let mut alternatives = Vec::new();
499
500 if let Some(stripped) = base_name.strip_prefix("lib") {
502 alternatives.push(stripped.to_string());
503 } else {
504 alternatives.push(format!("lib{}", base_name));
505 }
506
507 let base = base_name.strip_prefix("lib").unwrap_or(base_name);
509 #[cfg(target_os = "macos")]
510 {
511 alternatives.push(format!("{}.dylib", base));
512 alternatives.push(format!("lib{}.dylib", base));
513 }
514 #[cfg(target_os = "linux")]
515 {
516 alternatives.push(format!("{}.so", base));
517 alternatives.push(format!("lib{}.so", base));
518 }
519 #[cfg(target_os = "windows")]
520 {
521 alternatives.push(format!("{}.dll", base));
522 alternatives.push(format!("lib{}.dll", base));
523 }
524
525 alternatives
526 }
527
528 fn discover_plugins_in_dir(&self, dir: &Path) -> Result<Vec<Box<dyn NativePluginTrait>>> {
530 let mut plugins = Vec::new();
531
532 for entry in std::fs::read_dir(dir)? {
533 let entry = entry?;
534 let path = entry.path();
535
536 if path.is_dir() && path.join("plugin.json").exists() {
537 match self.load_plugin(&path) {
538 Ok(plugin) => plugins.push(plugin),
539 Err(e) => {
540 warn!("Failed to load plugin at {:?}: {}", path, e);
541 }
542 }
543 }
544 }
545
546 Ok(plugins)
547 }
548
549 pub fn library_filename(&self, name: &str) -> String {
551 #[cfg(target_os = "macos")]
552 {
553 format!("lib{}.dylib", name)
554 }
555 #[cfg(target_os = "linux")]
556 {
557 format!("lib{}.so", name)
558 }
559 #[cfg(target_os = "windows")]
560 {
561 format!("{}.dll", name)
562 }
563 #[cfg(not(any(target_os = "macos", target_os = "linux", target_os = "windows")))]
564 {
565 format!("lib{}", name)
566 }
567 }
568}
569
570impl Default for PluginLoader {
571 fn default() -> Self {
572 Self::new()
573 }
574}
575
576pub fn validate_plugin_structure(plugin_dir: &Path) -> Result<Vec<String>> {
578 let mut errors = Vec::new();
579
580 if !plugin_dir.join("plugin.json").exists() {
582 errors.push("Missing plugin.json".to_string());
583 }
584
585 let has_lib = std::fs::read_dir(plugin_dir)
587 .map(|entries| {
588 entries.filter_map(|e| e.ok()).any(|entry| {
589 let path = entry.path();
590 let ext = path.extension().and_then(|e| e.to_str());
591 matches!(ext, Some("dylib") | Some("so") | Some("dll"))
592 })
593 })
594 .unwrap_or(false);
595
596 if !has_lib {
597 errors.push("No dynamic library found (.dylib, .so, or .dll)".to_string());
598 }
599
600 if let Ok(content) = std::fs::read_to_string(plugin_dir.join("plugin.json")) {
602 if let Ok(metadata) = serde_json::from_str::<serde_json::Value>(&content) {
603 if metadata["name"].as_str().is_none() {
604 errors.push("plugin.json missing required 'name' field".to_string());
605 }
606 if metadata["description"].as_str().is_none() {
607 errors.push("plugin.json missing required 'description' field".to_string());
608 }
609 if metadata["version"].as_str().is_none() {
610 errors.push("plugin.json missing required 'version' field".to_string());
611 }
612 } else {
613 errors.push("Invalid JSON in plugin.json".to_string());
614 }
615 }
616
617 Ok(errors)
618}
619
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::cell::Cell;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tempfile::TempDir;

    // Per-thread flag set by `test_free_string`, so the decode tests can observe that the
    // plugin-provided deallocator ran on their own thread.
    thread_local! {
        static TEST_FREE_WAS_CALLED: Cell<bool> = const { Cell::new(false) };
    }

    // Process-global counters for the concurrency tests.
    // - `TEST_EXECUTE_ACTIVE_CALLS`: calls to `test_execute_with_delay` currently running.
    // - `TEST_EXECUTE_MAX_CONCURRENCY`: the highest value the active count has reached.
    // The tests that reset them are marked `#[serial]` so they do not interfere.
    static TEST_EXECUTE_ACTIVE_CALLS: AtomicUsize = AtomicUsize::new(0);
    static TEST_EXECUTE_MAX_CONCURRENCY: AtomicUsize = AtomicUsize::new(0);

    /// Stand-in for a plugin's `vtcode_plugin_free_string`.
    ///
    /// Records that it was called, then frees strings allocated with `CString::into_raw`.
    unsafe extern "C" fn test_free_string(ptr: *const c_char) {
        TEST_FREE_WAS_CALLED.with(|was_called| was_called.set(true));
        if !ptr.is_null() {
            // SAFETY: every pointer passed here in these tests comes from `CString::into_raw`.
            let _ = unsafe { CString::from_raw(ptr as *mut c_char) };
        }
    }

    /// Creates a temp dir containing an empty `test-plugin` subdirectory.
    ///
    /// The returned `TempDir` must be kept alive, or the directory is deleted.
    fn create_test_plugin_dir() -> (TempDir, PathBuf) {
        let temp_dir = TempDir::new().unwrap();
        let plugin_dir = temp_dir.path().join("test-plugin");
        std::fs::create_dir(&plugin_dir).unwrap();
        (temp_dir, plugin_dir)
    }

    /// Writes a minimal valid `plugin.json` with the given `name`.
    fn write_plugin_metadata(plugin_dir: &Path, name: &str) {
        std::fs::write(
            plugin_dir.join("plugin.json"),
            format!(r#"{{"name":"{name}","description":"test","version":"1.0.0"}}"#),
        )
        .unwrap();
    }

    /// Writes a placeholder (not a real library) under the platform library file name
    /// for `name`. The tests using it expect rejection before any library load happens.
    fn write_fake_library(plugin_dir: &Path, name: &str) -> PathBuf {
        let loader = PluginLoader::new();
        let library_path = plugin_dir.join(loader.library_filename(name));
        std::fs::write(&library_path, b"fake-library").unwrap();
        library_path
    }

    /// Returns a handle to the current process image.
    ///
    /// Test `NativePlugin`s can then own a real `Library` without loading an external file.
    fn current_process_library() -> Library {
        #[cfg(unix)]
        {
            libloading::os::unix::Library::this().into()
        }
        #[cfg(windows)]
        {
            libloading::os::windows::Library::this()
                .expect("current process library")
                .into()
        }
    }

    /// Raises `TEST_EXECUTE_MAX_CONCURRENCY` to `active_calls` if that is larger.
    ///
    /// Uses a compare-exchange loop so concurrent updates are not lost.
    fn update_max_concurrency(active_calls: usize) {
        let mut current_max = TEST_EXECUTE_MAX_CONCURRENCY.load(Ordering::SeqCst);
        while active_calls > current_max {
            match TEST_EXECUTE_MAX_CONCURRENCY.compare_exchange(
                current_max,
                active_calls,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => break,
                Err(observed) => current_max = observed,
            }
        }
    }

    /// Fake `vtcode_plugin_execute`.
    ///
    /// Tracks how many calls overlap and sleeps briefly to widen the overlap window. It
    /// returns a successful result allocated with `CString::into_raw`, which
    /// `test_free_string` later releases.
    unsafe extern "C" fn test_execute_with_delay(_input: *const c_char) -> *const c_char {
        let active_calls = TEST_EXECUTE_ACTIVE_CALLS.fetch_add(1, Ordering::SeqCst) + 1;
        update_max_concurrency(active_calls);
        std::thread::sleep(Duration::from_millis(25));
        TEST_EXECUTE_ACTIVE_CALLS.fetch_sub(1, Ordering::SeqCst);

        CString::new(r#"{"success":true,"output":{},"error":null,"files":[]}"#)
            .unwrap()
            .into_raw()
    }

    /// An empty plugin directory is reported as missing `plugin.json`.
    #[test]
    fn test_validate_plugin_structure_missing_metadata() {
        let (_temp_dir, plugin_dir) = create_test_plugin_dir();
        let errors = validate_plugin_structure(&plugin_dir).unwrap();
        assert!(errors.iter().any(|e| e.contains("plugin.json")));
    }

    /// A valid manifest without any library file is reported.
    #[test]
    fn test_validate_plugin_structure_missing_library() {
        let (_temp_dir, plugin_dir) = create_test_plugin_dir();
        std::fs::write(
            plugin_dir.join("plugin.json"),
            r#"{"name": "test", "description": "test", "version": "1.0.0"}"#,
        )
        .unwrap();

        let errors = validate_plugin_structure(&plugin_dir).unwrap();
        assert!(errors.iter().any(|e| e.contains("dynamic library")));
    }

    /// A valid manifest plus a platform-named library file yields no errors.
    #[test]
    fn test_validate_plugin_structure_complete() {
        let (_temp_dir, plugin_dir) = create_test_plugin_dir();

        std::fs::write(
            plugin_dir.join("plugin.json"),
            r#"{"name": "test", "description": "test", "version": "1.0.0"}"#,
        )
        .unwrap();

        // Only the extension matters to the validator; the name mirrors the platform style.
        let lib_name = if cfg!(target_os = "macos") {
            "libtest.dylib"
        } else if cfg!(target_os = "linux") {
            "libtest.so"
        } else {
            "test.dll"
        };
        std::fs::write(plugin_dir.join(lib_name), b"fake").unwrap();

        let errors = validate_plugin_structure(&plugin_dir).unwrap();
        assert!(errors.is_empty());
    }

    /// `library_filename` follows each platform's naming convention.
    #[test]
    fn test_library_filename_platform() {
        let loader = PluginLoader::new();
        let filename = loader.library_filename("my-plugin");

        #[cfg(target_os = "macos")]
        assert_eq!(filename, "libmy-plugin.dylib");

        #[cfg(target_os = "linux")]
        assert_eq!(filename, "libmy-plugin.so");

        #[cfg(target_os = "windows")]
        assert_eq!(filename, "my-plugin.dll");
    }

    /// Null plugin pointers become errors that name the calling context.
    #[test]
    fn test_ensure_non_null_c_string_ptr_rejects_null() {
        let err = ensure_non_null_c_string_ptr(std::ptr::null::<c_char>(), "Test pointer")
            .expect_err("null pointer should be rejected");
        assert!(
            err.to_string()
                .contains("Test pointer returned null pointer")
        );
    }

    /// Successful decoding copies the string and releases the plugin's buffer.
    #[test]
    fn test_decode_plugin_c_string_frees_plugin_buffer() {
        TEST_FREE_WAS_CALLED.with(|was_called| was_called.set(false));

        let raw = CString::new("{\"ok\":true}")
            .expect("valid C string")
            .into_raw();
        let ptr = NonNull::new(raw).expect("non-null raw pointer");

        let decoded = decode_plugin_c_string(
            ptr,
            Some(test_free_string),
            "Plugin result is not valid UTF-8",
        )
        .expect("valid UTF-8 payload");

        assert_eq!(decoded, "{\"ok\":true}");
        TEST_FREE_WAS_CALLED.with(|was_called| assert!(was_called.get()));
    }

    /// A decoding failure still releases the plugin's buffer, so errors do not leak.
    #[test]
    fn test_decode_plugin_c_string_invalid_utf8_still_frees_buffer() {
        TEST_FREE_WAS_CALLED.with(|was_called| was_called.set(false));

        let raw = CString::from_vec_with_nul(vec![0xFF, 0x00])
            .expect("valid nul-terminated C string")
            .into_raw();
        let ptr = NonNull::new(raw).expect("non-null raw pointer");

        let err = decode_plugin_c_string(
            ptr,
            Some(test_free_string),
            "Plugin payload is not valid UTF-8",
        )
        .expect_err("invalid UTF-8 should fail decoding");

        assert!(
            err.to_string()
                .contains("Plugin payload is not valid UTF-8")
        );
        TEST_FREE_WAS_CALLED.with(|was_called| assert!(was_called.get()));
    }

    /// A `..` path that lexically starts in a trusted root but resolves outside it is
    /// rejected.
    #[test]
    fn test_load_plugin_rejects_dotdot_escape_from_trusted_root() {
        let temp_dir = TempDir::new().unwrap();
        let trusted_root = temp_dir.path().join("trusted");
        let escaped_plugin_dir = temp_dir.path().join("escaped-plugin");
        std::fs::create_dir(&trusted_root).unwrap();
        std::fs::create_dir(&escaped_plugin_dir).unwrap();
        write_plugin_metadata(&escaped_plugin_dir, "escaped");
        write_fake_library(&escaped_plugin_dir, "escaped");

        let escaped_path = trusted_root.join("..").join("escaped-plugin");

        let mut loader = PluginLoader::new();
        loader.add_trusted_dir(trusted_root);

        let err = loader
            .load_plugin(&escaped_path)
            .expect_err("path traversal should be rejected");

        assert!(err.to_string().contains("trusted directory"));
    }

    /// A plugin directory symlinked from inside the trusted root to an external location
    /// is rejected.
    #[cfg(unix)]
    #[test]
    fn test_load_plugin_rejects_symlinked_plugin_dir_escape() {
        use std::os::unix::fs::symlink;

        let temp_dir = TempDir::new().unwrap();
        let trusted_root = temp_dir.path().join("trusted");
        let real_plugin_dir = temp_dir.path().join("external-plugin");
        let symlinked_plugin_dir = trusted_root.join("linked-plugin");
        std::fs::create_dir(&trusted_root).unwrap();
        std::fs::create_dir(&real_plugin_dir).unwrap();
        write_plugin_metadata(&real_plugin_dir, "linked");
        write_fake_library(&real_plugin_dir, "linked");
        symlink(&real_plugin_dir, &symlinked_plugin_dir).unwrap();

        let mut loader = PluginLoader::new();
        loader.add_trusted_dir(trusted_root);

        let err = loader
            .load_plugin(&symlinked_plugin_dir)
            .expect_err("symlink escape should be rejected");

        assert!(err.to_string().contains("trusted directory"));
    }

    /// A library file symlinked to a location outside the trusted root is rejected, even
    /// though the plugin directory itself is trusted.
    #[cfg(unix)]
    #[test]
    fn test_load_plugin_rejects_symlinked_library_escape() {
        use std::os::unix::fs::symlink;

        let temp_dir = TempDir::new().unwrap();
        let trusted_root = temp_dir.path().join("trusted");
        let plugin_dir = trusted_root.join("plugin");
        let external_dir = temp_dir.path().join("external");
        std::fs::create_dir(&trusted_root).unwrap();
        std::fs::create_dir(&plugin_dir).unwrap();
        std::fs::create_dir(&external_dir).unwrap();
        write_plugin_metadata(&plugin_dir, "escaped-lib");

        let external_library = write_fake_library(&external_dir, "escaped-lib");
        let linked_library = plugin_dir.join(PluginLoader::new().library_filename("escaped-lib"));
        symlink(&external_library, &linked_library).unwrap();

        let mut loader = PluginLoader::new();
        loader.add_trusted_dir(trusted_root);

        let err = loader
            .load_plugin(&plugin_dir)
            .expect_err("library symlink escape should be rejected");

        assert!(err.to_string().contains("trusted directory"));
    }

    /// A plugin without `thread_safe` never has more than one call in flight.
    #[test]
    #[serial]
    fn test_native_plugin_serializes_concurrent_execution() {
        TEST_EXECUTE_ACTIVE_CALLS.store(0, Ordering::SeqCst);
        TEST_EXECUTE_MAX_CONCURRENCY.store(0, Ordering::SeqCst);

        let plugin = Arc::new(NativePlugin {
            _library: current_process_library(),
            metadata: PluginMetadata {
                name: "serialized".to_string(),
                description: "test plugin".to_string(),
                version: "1.0.0".to_string(),
                author: None,
                abi_version: PLUGIN_ABI_VERSION,
                when_to_use: None,
                when_not_to_use: None,
                allowed_tools: None,
                thread_safe: false,
            },
            path: PathBuf::from("/tmp/serialized-plugin"),
            execute_fn: test_execute_with_delay,
            free_string_fn: Some(test_free_string),
            execution_lock: Mutex::new(()),
            thread_safe: false,
        });
        let ctx = PluginContext {
            input: HashMap::new(),
            workspace_root: None,
            config: HashMap::new(),
        };

        let handles = (0..4)
            .map(|_| {
                let plugin = Arc::clone(&plugin);
                let ctx = ctx.clone();
                std::thread::spawn(move || plugin.execute(&ctx).expect("plugin execution"))
            })
            .collect::<Vec<_>>();

        for handle in handles {
            let result = handle.join().expect("thread should complete");
            assert!(result.success);
        }

        assert_eq!(TEST_EXECUTE_MAX_CONCURRENCY.load(Ordering::SeqCst), 1);
    }

    /// A `thread_safe` plugin is not serialized, so calls overlap.
    ///
    /// This test depends on timing: the 25 ms sleep in the fake execute makes overlap
    /// across four threads very likely.
    #[test]
    #[serial]
    fn test_native_plugin_allows_parallel_execution() {
        TEST_EXECUTE_ACTIVE_CALLS.store(0, Ordering::SeqCst);
        TEST_EXECUTE_MAX_CONCURRENCY.store(0, Ordering::SeqCst);

        let plugin = Arc::new(NativePlugin {
            _library: current_process_library(),
            metadata: PluginMetadata {
                name: "parallel".to_string(),
                description: "reentrant test plugin".to_string(),
                version: "1.0.0".to_string(),
                author: None,
                abi_version: PLUGIN_ABI_VERSION,
                when_to_use: None,
                when_not_to_use: None,
                allowed_tools: None,
                thread_safe: true,
            },
            path: PathBuf::from("/tmp/parallel-plugin"),
            execute_fn: test_execute_with_delay,
            free_string_fn: Some(test_free_string),
            execution_lock: Mutex::new(()),
            thread_safe: true,
        });

        let ctx = PluginContext {
            input: HashMap::new(),
            workspace_root: None,
            config: HashMap::new(),
        };

        let num_threads = 4;
        let handles = (0..num_threads)
            .map(|_| {
                let plugin = Arc::clone(&plugin);
                let ctx = ctx.clone();
                std::thread::spawn(move || plugin.execute(&ctx).expect("parallel plugin execution"))
            })
            .collect::<Vec<_>>();

        for handle in handles {
            let result = handle.join().expect("thread join");
            assert!(result.success);
        }

        assert!(TEST_EXECUTE_MAX_CONCURRENCY.load(Ordering::SeqCst) > 1);
    }
}