#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
use std::{
any::{Any, TypeId},
collections::HashMap,
sync::{OnceLock, RwLock},
};
#[cfg(feature = "nightly")]
use try_default::TryDefault;
pub(crate) static SERVICE_CONTAINER: OnceLock<RwLock<Container>> = OnceLock::new();
pub mod facade;
pub struct Container {
#[allow(clippy::type_complexity)]
bindings: HashMap<TypeId, Box<(dyn Fn(&Self) -> Box<dyn Any> + Sync + Send)>>,
instances: HashMap<TypeId, Box<(dyn Fn() -> Box<dyn Any> + Sync + Send)>>,
scoped_instances: Vec<TypeId>,
}
impl Container {
#[must_use]
pub fn new() -> Self {
Self {
bindings: HashMap::new(),
instances: HashMap::new(),
scoped_instances: Vec::new(),
}
}
pub fn get_instance() -> &'static RwLock<Self> {
SERVICE_CONTAINER.get_or_init(|| RwLock::new(Self::new()))
}
pub fn bind<T: 'static>(&mut self, factory: impl Fn(&Self) -> T + 'static + Sync + Send) {
self.instances.remove(&TypeId::of::<T>());
self.bindings.insert(
TypeId::of::<T>(),
Box::new(move |container: &Self| {
let result = factory(container);
Box::new(result) as Box<dyn Any>
}),
);
}
pub fn bind_if<T: 'static>(&mut self, factory: impl Fn(&Self) -> T + 'static + Sync + Send) {
if !self.bindings.contains_key(&TypeId::of::<T>()) {
self.bind(factory);
}
}
pub fn scoped<T: 'static + Clone + Send + Sync>(
&mut self,
factory: &(impl Fn(&Self) -> T + 'static),
) {
self.scoped_instances.push(TypeId::of::<T>());
self.singleton(factory);
}
pub fn scoped_if<T: 'static + Clone + Send + Sync>(
&mut self,
factory: &(impl Fn(&Self) -> T + 'static),
) {
if !self.scoped_instances.contains(&TypeId::of::<T>()) {
self.scoped(factory);
}
}
pub fn singleton<T: 'static + Clone + Send + Sync>(
&mut self,
factory: &(impl Fn(&Self) -> T + 'static),
) {
let result = factory(self);
self.instances.insert(
TypeId::of::<T>(),
Box::new(move || Box::new(result.clone()) as Box<dyn Any + Send + Sync>),
);
}
pub fn singleton_if<T: 'static + Clone + Send + Sync>(
&mut self,
factory: &(impl Fn(&Self) -> T + 'static),
) {
if !self.instances.contains_key(&TypeId::of::<T>()) {
self.singleton(factory);
}
}
pub fn resolve<T: 'static>(&self) -> Result<T, Error> {
let type_id = TypeId::of::<T>();
if let Some(instance) = self.instances.get(&type_id) {
return instance()
.downcast::<T>()
.map(|i| *i)
.map_err(|_| Error::CastFailed);
};
if let Some(binding) = self.bindings.get(&type_id) {
return binding(self)
.downcast::<T>()
.map(|b| *b)
.map_err(|_| Error::CastFailed);
};
try_default_if_enabled().ok_or(Error::NotFound)
}
pub fn forget_scoped_instances(&mut self) {
for instance in &self.scoped_instances {
self.instances.remove(instance);
}
}
pub fn flush(&mut self) {
self.bindings.clear();
self.instances.clear();
self.scoped_instances.clear();
}
}
impl Default for Container {
fn default() -> Self {
Self::new()
}
}
#[cfg(not(feature = "nightly"))]
const fn try_default_if_enabled<T>() -> Option<T> {
None
}
#[cfg(feature = "nightly")]
fn try_default_if_enabled<T>() -> Option<T> {
T::try_default()
}
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum Error {
#[error("Binding not found")]
NotFound,
#[error("Failed to cast binding to requested type")]
CastFailed,
}
#[cfg(test)]
mod tests {
use super::*;
use serial_test::serial;
#[derive(Debug, Clone, PartialEq)]
struct TestDependency {
value: String,
}
#[test]
fn can_register_a_binding() {
let mut container = Container::new();
container.bind(|_: &Container| TestDependency {
value: "Hello, world!".to_string(),
});
assert_eq!(container.bindings.len(), 1);
assert_eq!(container.instances.len(), 0);
}
#[test]
fn can_register_a_singleton() {
let mut container = Container::new();
container.singleton(&|_: &Container| TestDependency {
value: "Hello, world!".to_string(),
});
assert_eq!(container.bindings.len(), 0);
assert_eq!(container.instances.len(), 1);
}
#[test]
fn can_register_a_scoped_binding() {
let mut container = Container::new();
container.scoped(&|_: &Container| TestDependency {
value: "Hello, world!".to_string(),
});
assert_eq!(container.bindings.len(), 0);
assert_eq!(container.instances.len(), 1);
assert_eq!(container.scoped_instances.len(), 1);
}
#[test]
fn registering_a_binding_clears_previous_singleton() {
let mut container = Container::new();
container.singleton(&|_: &Container| TestDependency {
value: "Hello, world!".to_string(),
});
container.bind(|_: &Container| TestDependency {
value: "Goodbye, world!".to_string(),
});
assert_eq!(container.bindings.len(), 1);
assert_eq!(container.instances.len(), 0);
assert_eq!(
container.resolve::<TestDependency>(),
Ok(TestDependency {
value: "Goodbye, world!".to_string()
})
);
}
#[test]
fn can_retrieve_a_registered_binding() {
let mut container = Container::new();
container.bind(|_: &Container| TestDependency {
value: "Hello, world!".to_string(),
});
let result = container.resolve::<TestDependency>().unwrap();
assert_eq!(result.value, "Hello, world!");
}
#[test]
fn can_retrieve_a_registered_scoped_binding_until_flush() {
let mut container = Container::new();
container.scoped(&|_: &Container| TestDependency {
value: "Hello, world!".to_string(),
});
assert_eq!(
container.resolve::<TestDependency>(),
Ok(TestDependency {
value: "Hello, world!".to_string()
})
);
container.forget_scoped_instances();
assert_eq!(container.resolve::<TestDependency>(), Err(Error::NotFound));
}
#[test]
fn can_retrieve_a_registered_singleton() {
let mut container = Container::new();
container.singleton(&|_: &Container| TestDependency {
value: "Hello, world!".to_string(),
});
let result = container.resolve::<TestDependency>().unwrap();
assert_eq!(result.value, "Hello, world!");
}
#[test]
fn returns_singleton_over_binding() {
let mut container = Container::new();
container.bind(|_: &Container| TestDependency {
value: "Hello, world!".to_string(),
});
container.singleton(&|_: &Container| TestDependency {
value: "Goodbye, world!".to_string(),
});
let result = container.resolve::<TestDependency>().unwrap();
assert_eq!(result.value, "Goodbye, world!");
}
#[test]
fn returns_error_when_not_found() {
let container = Container::new();
assert_eq!(container.resolve::<TestDependency>(), Err(Error::NotFound));
}
#[test]
#[cfg(feature = "nightly")]
fn can_resolve_a_binding_for_a_type_that_implements_default() {
let container = Container::new();
assert_eq!(container.resolve::<u64>(), Ok(u64::default()));
}
#[test]
#[serial]
fn can_use_global_container() {
let container = Container::get_instance();
let mut container_w = container.write().unwrap();
container_w.bind(|_: &Container| TestDependency {
value: "Hello, world!".to_string(),
});
drop(container_w);
let container_r = container.read().unwrap();
let result = container_r.resolve::<TestDependency>().unwrap();
drop(container_r);
assert_eq!(result.value, "Hello, world!");
}
}