use crate::endpoints::Empty;
use crate::error::WarpgateError;
use crate::helpers::{from_virtual_path, to_virtual_path};
use crate::id::Id;
use extism::{Function, Manifest, Plugin};
use once_map::OnceMap;
use serde::de::DeserializeOwned;
use serde::Serialize;
use starbase_styles::color;
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use tracing::trace;
use warpgate_api::VirtualPath;
/// A loaded Extism WASM plugin, the manifest it was created from, and a
/// memoization cache used by [`PluginContainer::cache_func_with`].
pub struct PluginContainer<'plugin> {
/// Identifier of this plugin; used as the `plugin` log field and in panic messages.
pub id: Id,
/// Manifest the plugin was instantiated from. Its `config` is pushed into the
/// instance by [`PluginContainer::reload_config`], and it is passed to the
/// virtual path conversion helpers.
pub manifest: Manifest,
/// Raw output bytes of cached calls, keyed by `"{func}-{json_input}"`.
func_cache: OnceMap<String, Vec<u8>>,
/// The plugin instance. Function calls take the write lock; `has_func` only reads.
plugin: Arc<RwLock<Plugin<'plugin>>>,
}
// SAFETY: NOTE(review): these impls are asserted, not proven by the compiler.
// `Plugin<'plugin>` is presumably `!Send`/`!Sync` (otherwise the auto traits would
// already apply through `Arc<RwLock<..>>`), so soundness rests on two assumptions:
// every access to `plugin` goes through the `RwLock`, and the extism instance
// tolerates being used from threads other than the one that created it. Note that
// `has_func` takes the *read* lock, so several threads may hold `&Plugin` at once;
// that additionally requires extism's `&self` methods to be thread-safe. Confirm
// both against the pinned extism version's threading guarantees.
unsafe impl<'plugin> Send for PluginContainer<'plugin> {}
unsafe impl<'plugin> Sync for PluginContainer<'plugin> {}
impl<'plugin> PluginContainer<'plugin> {
pub fn new<'new>(
id: Id,
manifest: Manifest,
functions: impl IntoIterator<Item = Function>,
) -> miette::Result<PluginContainer<'new>> {
let plugin = Plugin::create_with_manifest(&manifest, functions, true)
.map_err(|error| WarpgateError::PluginCreateFailed { error })?;
Ok(PluginContainer {
manifest,
plugin: Arc::new(RwLock::new(plugin)),
id,
func_cache: OnceMap::new(),
})
}
pub fn new_without_functions<'new>(
id: Id,
manifest: Manifest,
) -> miette::Result<PluginContainer<'new>> {
Self::new(id, manifest, [])
}
pub fn reload_config(&mut self) -> miette::Result<()> {
let config = self
.manifest
.config
.iter()
.map(|(k, v)| (k.to_owned(), Some(v.to_owned())))
.collect::<BTreeMap<_, _>>();
self.plugin
.write()
.expect("Failed to acquire write access!")
.set_config(&config)
.unwrap();
Ok(())
}
pub fn cache_func<O>(&self, func: &str) -> miette::Result<O>
where
O: Debug + DeserializeOwned,
{
self.cache_func_with(func, Empty::default())
}
pub fn cache_func_with<I, O>(&self, func: &str, input: I) -> miette::Result<O>
where
I: Debug + Serialize,
O: Debug + DeserializeOwned,
{
let input = self.format_input(func, input)?;
let cache_key = format!("{func}-{input}");
{
if let Some(data) = self.func_cache.get(&cache_key) {
return self.parse_output(func, data);
}
}
let data = self.call(func, input)?;
let output: O = self.parse_output(func, &data)?;
self.func_cache.insert(cache_key, |_| data);
Ok(output)
}
pub fn call_func<O>(&self, func: &str) -> miette::Result<O>
where
O: Debug + DeserializeOwned,
{
self.call_func_with(func, Empty::default())
}
pub fn call_func_with<I, O>(&self, func: &str, input: I) -> miette::Result<O>
where
I: Debug + Serialize,
O: Debug + DeserializeOwned,
{
self.parse_output(func, &self.call(func, self.format_input(func, input)?)?)
}
pub fn call_func_without_output<I>(&self, func: &str, input: I) -> miette::Result<()>
where
I: Debug + Serialize,
{
self.call(func, self.format_input(func, input)?)?;
Ok(())
}
pub fn has_func(&self, func: &str) -> bool {
self.plugin
.read()
.unwrap_or_else(|_| {
panic!(
"Unable to acquire read access to `{}` WASM plugin.",
self.id
)
})
.has_function(func)
}
pub fn from_virtual_path(&self, path: &Path) -> PathBuf {
from_virtual_path(&self.manifest, path)
}
pub fn to_virtual_path(&self, path: &Path) -> VirtualPath {
to_virtual_path(&self.manifest, path)
}
pub fn call(&self, func: &str, input: impl AsRef<[u8]>) -> miette::Result<Vec<u8>> {
let input = input.as_ref();
trace!(
plugin = self.id.as_str(),
input = %String::from_utf8_lossy(input),
"Calling plugin function {}",
color::label(func),
);
let mut instance = self.plugin.write().unwrap_or_else(|_| {
panic!(
"Unable to acquire write access to `{}` WASM plugin.",
self.id
)
});
let output = instance.call(func, input).map_err(|error| {
#[cfg(debug_assertions)]
{
WarpgateError::PluginCallFailed {
func: func.to_owned(),
error,
}
}
#[cfg(not(debug_assertions))]
{
WarpgateError::PluginCallFailedRelease {
error: error
.to_string()
.replace("\\\\n", "\n")
.replace("\\n", "\n"),
}
}
})?;
trace!(
plugin = self.id.as_str(),
output = %String::from_utf8_lossy(output),
"Called plugin function {}",
color::label(func),
);
Ok(output.to_vec())
}
fn format_input<I: Serialize>(&self, func: &str, input: I) -> miette::Result<String> {
Ok(
serde_json::to_string(&input).map_err(|error| WarpgateError::FormatInputFailed {
func: func.to_owned(),
error,
})?,
)
}
fn parse_output<O: DeserializeOwned>(&self, func: &str, data: &[u8]) -> miette::Result<O> {
Ok(
serde_json::from_slice(data).map_err(|error| WarpgateError::ParseOutputFailed {
func: func.to_owned(),
error,
})?,
)
}
}