shared_lib/
lib.rs

1//! A small wrapper around the libloading crate that aims to improve the system path and error handling.
2//!
3//! # Usage
4//!
5//! In your code, run the following:
6//!
7//! ```no_run
8//! use shared_lib::*;
9//! use std::path::PathBuf;
10//! 
11//! let lib_path = LibPath::new(PathBuf::from("path/to/dir"), "library_name_no_ext".into());
12//! unsafe {
13//!    let lib = SharedLib::new(lib_path).unwrap();
14//!    let func = lib.get_fn::<fn(usize, usize) -> usize>("foo").unwrap();
15//!    let result = func.run(1, 2);
16//! }
17//! ```
18
19use thiserror::Error;
20use libloading::{library_filename, Library, Symbol};
21use std::{ffi::OsString, path::PathBuf};
22
23/// Enum representing the possible errors that can occur when working with shared libraries.
24#[derive(Debug, Error)]
25pub enum SharedLibError {
26    #[error("Path is empty.")]
27    PathEmpty,
28    #[error("Failed to convert path '{0}' to {1}.")]
29    PathConversion(PathBuf, String),
30    #[error("Failed to load library from path '{path}'. {msg}")]
31    LoadFailure { path: String, msg: String },
32    #[error("Failed to find symbol '{symbol_name}' in library '{lib_name}'. {msg}")]
33    SymbolNotFound { symbol_name: String, lib_name: String, msg: String }
34}
35
/// Location of a shared library, split into a directory and a bare library name.
///
/// The platform-specific prefix and extension are not stored; they are added
/// when the filename or full path is requested.
#[derive(Debug, Clone)]
pub struct LibPath {
    /// Directory in which the library file lives.
    pub dir_path: PathBuf,
    /// Library name without the platform-specific prefix and extension.
    pub lib_name: String,
}
46impl ToString for LibPath {
47    fn to_string(&self) -> String {
48        let binding = self.path().unwrap();
49        binding.to_str().unwrap().to_string()
50    }
51}
52impl TryInto<OsString> for LibPath {
53    type Error = SharedLibError;
54    fn try_into(self) -> Result<OsString, Self::Error> {
55        let path = self.path()?;
56        path.clone().try_into().map_err(|_| {
57            SharedLibError::PathConversion(path, "OsString".into())
58        })
59    }
60}
61impl LibPath {
62    /// Create a new shared library path.
63    ///
64    /// `dir_path` is the directory path where the library is located.
65    ///
66    /// `lib_name` is the library name without the platform specific extension and prefix.
67    pub fn new(dir_path: PathBuf, lib_name: String) -> LibPath {
68        LibPath { dir_path, lib_name }
69    }
70    /// Create a new shared library path without a directory path.
71    /// Using this function will mean that the library is located in the current directory.
72    ///
73    /// `lib_name` is the library name without the platform specific extension and prefix.
74    pub fn new_no_path(lib_name: String) -> LibPath {
75        LibPath {
76            dir_path: PathBuf::new(),
77            lib_name,
78        }
79    }
80    /// Get the platform specific library filename.
81    ///
82    /// For Windows, it will return the library name with `.dll` extension.
83    ///
84    /// For MacOS, it will return the library name with `lib` prefix and `.dylib` extension.
85    ///
86    /// For Linux, it will return the library name with `lib` prefix and `.so` extension.
87    /// # Example
88    /// ```no_run
89    /// use std::ffi::OsString;
90    /// use shared_lib::*;
91    ///
92    /// let lib_path: LibPath = LibPath::new_no_path("test_name".into());
93    /// let lib_name: OsString = lib_path.filename().expect("Failed to get library name");
94    /// ```
95    pub fn filename(&self) -> Result<OsString, SharedLibError> {
96        if self.lib_name.is_empty() {
97            return Err(SharedLibError::PathEmpty);
98        }
99        Ok(library_filename(self.lib_name.clone()))
100    }
101    /// Get the platform specific library filepath.
102    ///
103    /// `dir_path` is the directory path where the library is located.
104    ///
105    /// `lib_name` is the library name without the platform specific extension.
106    /// # Example
107    /// ```no_run
108    /// use std::path::PathBuf;
109    /// use shared_lib::*;
110    ///
111    /// let lib_path: LibPath = LibPath::new(PathBuf::from("path/to/shared/library"), "shared_library".into());
112    /// let lib_path: PathBuf = lib_path.path().expect("Failed to get library path");
113    /// ```
114    pub fn path(&self) -> Result<PathBuf, SharedLibError> {
115        Ok(self.dir_path.join(self.filename()?))
116    }
117}
118
119/// Structure representing a shared library function.
120#[derive(Clone)]
121pub struct SharedLibFn<'a, Fn> {
122    symbol: Symbol<'a, Fn>,
123}
124impl<'a, Fn> SharedLibFn<'a, Fn> {
125    pub unsafe fn new(symbol: Symbol<'a, Fn>) -> SharedLibFn<'a, Fn> {
126        SharedLibFn { symbol }
127    }
128}
129impl<'a, Ret> SharedLibFn<'a, fn() -> Ret> {
130    pub unsafe fn run(&self) -> Ret {
131        (self.symbol)()
132    }
133}
134// === Implementations for functions with arguments (Rust does not support variadic functions yet)
135impl<'a, Ret, A1> SharedLibFn<'a, fn(A1) -> Ret> {
136    pub unsafe fn run(&self, a1: A1) -> Ret {
137        (self.symbol)(a1)
138    }
139}
140impl<'a, Ret, A1, A2> SharedLibFn<'a, fn(A1, A2) -> Ret> {
141    pub unsafe fn run(&self, a1: A1, a2: A2) -> Ret {
142        (self.symbol)(a1, a2)
143    }
144}
145impl<'a, Ret, A1, A2, A3> SharedLibFn<'a, fn(A1, A2, A3) -> Ret> {
146    pub unsafe fn run(&self, a1: A1, a2: A2, a3: A3) -> Ret {
147        (self.symbol)(a1, a2, a3)
148    }
149}
150impl<'a, Ret, A1, A2, A3, A4> SharedLibFn<'a, fn(A1, A2, A3, A4) -> Ret> {
151    pub unsafe fn run(&self, a1: A1, a2: A2, a3: A3, a4: A4) -> Ret {
152        (self.symbol)(a1, a2, a3, a4)
153    }
154}
155impl<'a, Ret, A1, A2, A3, A4, A5> SharedLibFn<'a, fn(A1, A2, A3, A4, A5) -> Ret> {
156    pub unsafe fn run(&self, a1: A1, a2: A2, a3: A3, a4: A4, a5: A5) -> Ret {
157        (self.symbol)(a1, a2, a3, a4, a5)
158    }
159}
160// ===
161
162/// Structure representing a shared library.
163pub struct SharedLib {
164    lib: Library,
165    lib_path: LibPath
166}
167impl SharedLib {
168    /// Create a new shared library from the given path.
169    /// # Safety
170    /// This function is unsafe because it loads a shared library, which is generally unsafe as it is a foregin code.
171    pub unsafe fn new(lib_path: LibPath) -> Result<SharedLib, SharedLibError> {
172        let os_str: OsString = lib_path.clone().try_into()?;
173        let lib = match Library::new(os_str) {
174            Ok(lib) => lib,
175            Err(e) => {
176                let path_str: OsString = lib_path.try_into()?;
177                let path_str: String = path_str.to_string_lossy().to_string();
178                return Err(SharedLibError::LoadFailure {
179                    path: path_str, 
180                    msg: e.to_string()
181                });
182            }
183        };
184        Ok(SharedLib { lib, lib_path })
185    }
186    /// Get a function by name from the shared library.
187    /// # Safety
188    /// This function is unsafe because it loads a function from the shared library, which is generally unsafe as it is a foregin code.
189    /// # Example
190    /// ```no_run
191    /// use std::path::PathBuf;
192    /// use shared_lib::*;
193    /// unsafe {
194    ///     let lib_path = LibPath::new(PathBuf::from("path/to/shared/library"), "shared_library".into());
195    ///     let lib = SharedLib::new(lib_path).expect("Failed to load shared library");
196    ///     let add_fn = lib.get_fn::<fn(usize, usize) -> usize>("add").expect("Failed to get 'add' function from shared library");
197    ///     let result = add_fn.run(1, 2);
198    /// }
199    /// ```
200    pub unsafe fn get_fn<T>(&self, fn_name: &str) -> Result<SharedLibFn<T>, SharedLibError> {
201        let symbol = match self.lib.get(fn_name.as_bytes()) {
202            Ok(symbol) => symbol,
203            Err(e) => {
204                return Err(SharedLibError::SymbolNotFound { 
205                    symbol_name: fn_name.to_owned(), 
206                    lib_name: self.lib_path.path()?.to_string_lossy().to_string(),
207                    msg: e.to_string(), 
208                });
209            }
210        };
211        Ok(SharedLibFn::new(symbol))
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::consts::{DLL_PREFIX, DLL_SUFFIX};

    /// Expected platform specific filename for `name`, built from the standard
    /// library's prefix/suffix constants so the tests are not tied to a fixed
    /// list of operating systems.
    fn expected_filename(name: &str) -> OsString {
        OsString::from(format!("{}{}{}", DLL_PREFIX, name, DLL_SUFFIX))
    }

    #[test]
    fn create_lib_name() {
        let lib_path = LibPath::new_no_path("test_name".into());
        let lib_os_string: OsString = lib_path.try_into().unwrap();
        assert_eq!(lib_os_string, expected_filename("test_name"));
    }

    #[test]
    fn create_lib_name_empty() {
        let lib_path = LibPath::new_no_path("".into());
        let result: Result<OsString, SharedLibError> = lib_path.try_into();
        // Assert the specific error rather than accepting any panic.
        assert!(matches!(result, Err(SharedLibError::PathEmpty)));
    }

    #[test]
    fn create_lib_path() {
        let lib_path = LibPath::new(PathBuf::from("test_dir"), "test_name".into());
        let lib_os_string: OsString = lib_path.try_into().unwrap();
        // `join` inserts the platform's own separator.
        let expected = PathBuf::from("test_dir").join(expected_filename("test_name"));
        assert_eq!(lib_os_string, expected.into_os_string());
    }

    #[test]
    fn create_lib_path_empty() {
        let lib_path = LibPath::new(PathBuf::from("test_dir"), "".into());
        let result: Result<OsString, SharedLibError> = lib_path.try_into();
        // Assert the specific error rather than accepting any panic.
        assert!(matches!(result, Err(SharedLibError::PathEmpty)));
    }
}