1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
//! [`Extensions`] used to store extra data in each socket instance.
//!
//! It is heavily inspired by the [`http::Extensions`] type from the `http` crate.
//!
//! The main difference is that the inner [`HashMap`] is wrapped with an [`RwLock`]
//! to allow concurrent access. Moreover, any value extracted from the map is cloned before being returned.
//!
//! This is necessary because [`Extensions`] are shared between all the threads that handle the same socket.
//!
//! You can use the [`Extension`](crate::extract::Extension) or
//! [`MaybeExtension`](crate::extract::MaybeExtension) extractor to extract an extension of the given type.
use std::collections::HashMap;
use std::fmt;
use std::sync::RwLock;
use std::{
any::{Any, TypeId},
hash::{BuildHasherDefault, Hasher},
};
/// Type-erased storage for one extension value.
///
/// The box can be moved and shared across threads (`Send + Sync`) and is turned back
/// into its concrete type with a downcast.
type AnyVal = Box<dyn Any + Send + Sync>;
/// The backing store of [`Extensions`]: an [`RwLock`]-guarded [`HashMap`] keyed by
/// [`TypeId`], whose values are type-erased [`AnyVal`] boxes.
///
/// Keys are hashed with [`IdHasher`], which reuses the hash already carried by the
/// `TypeId` instead of rehashing it.
type AnyHashMap = RwLock<HashMap<TypeId, AnyVal, BuildHasherDefault<IdHasher>>>;
// `TypeId`s are already hashes produced by the compiler, so hashing them again is
// wasted work: `IdHasher` keeps the `u64` that `TypeId`'s `Hash` impl feeds it through
// `write_u64` and returns it unchanged from `finish`.
//
// How `TypeId` hashes itself is a standard-library implementation detail. Rather than
// panicking if it ever stops going through `write_u64` (e.g. by writing raw bytes), the
// byte-oriented `write` folds its input into the state. Such a change would then cost a
// little speed instead of crashing every map access.
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    #[inline]
    fn finish(&self) -> u64 {
        self.0
    }

    /// Fallback for input not delivered through `write_u64`.
    ///
    /// Mixes each byte into the state with an FxHash-style rotate/xor/multiply step, so
    /// equal byte sequences always produce equal hashes.
    fn write(&mut self, bytes: &[u8]) {
        const MULTIPLIER: u64 = 0x517c_c1b7_2722_0a95;
        for &byte in bytes {
            self.0 = (self.0.rotate_left(5) ^ u64::from(byte)).wrapping_mul(MULTIPLIER);
        }
    }

    /// Fast path used when hashing a `TypeId`: the id is used as the hash as-is.
    #[inline]
    fn write_u64(&mut self, id: u64) {
        self.0 = id;
    }
}
/// A type map of protocol extensions: at most one value is stored per Rust type.
///
/// Modeled on the `Extensions` type from the `http` crate. The key difference is that
/// the inner map sits behind an `RwLock`, so it can be read and modified through a
/// shared reference.
///
/// That matters because `Extensions` are shared between all the threads that handle the
/// same socket.
///
/// Stored types must implement `Clone`: readers receive a clone of the stored value,
/// never a reference into the map.
///
/// You can use the [`Extension`](crate::extract::Extension) or
/// [`MaybeExtension`](crate::extract::MaybeExtension) extractor to extract an extension of the given type.
#[derive(Default)]
pub struct Extensions {
    /// Type-keyed storage, guarded by a read/write lock.
    map: AnyHashMap,
}
impl Extensions {
    // Lock poisoning: a poisoned lock is recovered with `PoisonError::into_inner`
    // rather than propagated as a panic. Entries are independent of one another, so a
    // writer that panicked cannot have broken any cross-entry invariant. Propagating
    // the poison would permanently disable every later access to this socket's
    // extensions.

    /// Create an empty `Extensions`.
    #[inline]
    pub fn new() -> Extensions {
        Extensions {
            map: AnyHashMap::default(),
        }
    }

    /// Insert a value into this `Extensions`, keyed by its type.
    ///
    /// The type must be cloneable and thread safe to be stored.
    ///
    /// If an extension of this type was already present, it is replaced and the
    /// previous value is returned.
    ///
    /// # Example
    ///
    /// ```
    /// # use socketioxide::extensions::Extensions;
    /// let ext = Extensions::new();
    /// assert!(ext.insert(5i32).is_none());
    /// assert!(ext.insert(4u8).is_none());
    /// assert_eq!(ext.insert(9i32), Some(5i32));
    /// ```
    pub fn insert<T: Send + Sync + Clone + 'static>(&self, val: T) -> Option<T> {
        // The write guard is a temporary of this statement, so it is released before
        // the previous value is downcast and handed back.
        let previous = self
            .map
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .insert(TypeId::of::<T>(), Box::new(val));
        // The entry was keyed by `TypeId::of::<T>()`, so the downcast always succeeds.
        previous.and_then(|v| v.downcast().ok().map(|boxed| *boxed))
    }

    /// Get a cloned value of a type previously inserted on this `Extensions`.
    ///
    /// The value is cloned while the read lock is held. A `Clone` impl must therefore
    /// not write to this same `Extensions`; doing so would deadlock or panic.
    ///
    /// # Example
    ///
    /// ```
    /// # use socketioxide::extensions::Extensions;
    /// let ext = Extensions::new();
    /// assert!(ext.get::<i32>().is_none());
    /// ext.insert(5i32);
    ///
    /// assert_eq!(ext.get::<i32>().unwrap(), 5i32);
    /// ```
    pub fn get<T: Send + Sync + Clone + 'static>(&self) -> Option<T> {
        self.map
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .get(&TypeId::of::<T>())
            .and_then(|v| v.downcast_ref::<T>())
            // Only a reference is available under the lock, so clone before releasing it.
            .cloned()
    }

    /// Remove a type from this `Extensions`.
    ///
    /// If an extension of this type existed, it is returned.
    ///
    /// # Example
    ///
    /// ```
    /// # use socketioxide::extensions::Extensions;
    /// let ext = Extensions::new();
    /// ext.insert(5i32);
    /// assert_eq!(ext.remove::<i32>(), Some(5i32));
    /// assert!(ext.get::<i32>().is_none());
    /// ```
    pub fn remove<T: Send + Sync + 'static>(&self) -> Option<T> {
        // As in `insert`, the guard is released at the end of this statement.
        let removed = self
            .map
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .remove(&TypeId::of::<T>());
        removed.and_then(|v| v.downcast().ok().map(|boxed| *boxed))
    }

    /// Clear the `Extensions` of all inserted extensions.
    ///
    /// # Example
    ///
    /// ```
    /// # use socketioxide::extensions::Extensions;
    /// let ext = Extensions::new();
    /// ext.insert(5i32);
    /// ext.clear();
    ///
    /// assert!(ext.get::<i32>().is_none());
    /// ```
    #[inline]
    pub fn clear(&self) {
        // Swap the entries out under the lock, then drop them only after the guard is
        // released. Stored values' `Drop` impls therefore never run while the write lock
        // is held: a panicking `Drop` cannot poison it, and a `Drop` that touches this
        // `Extensions` cannot deadlock on it.
        let mut guard = self.map.write().unwrap_or_else(|poisoned| poisoned.into_inner());
        let entries = std::mem::take(&mut *guard);
        drop(guard);
        drop(entries);
    }

    /// Check whether the extension set is empty or not.
    ///
    /// # Example
    ///
    /// ```
    /// # use socketioxide::extensions::Extensions;
    /// let ext = Extensions::new();
    /// assert!(ext.is_empty());
    /// ext.insert(5i32);
    /// assert!(!ext.is_empty());
    /// ```
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.map
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .is_empty()
    }

    /// Get the number of extensions available.
    ///
    /// # Example
    ///
    /// ```
    /// # use socketioxide::extensions::Extensions;
    /// let ext = Extensions::new();
    /// assert_eq!(ext.len(), 0);
    /// ext.insert(5i32);
    /// assert_eq!(ext.len(), 1);
    /// ```
    #[inline]
    pub fn len(&self) -> usize {
        self.map
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .len()
    }
}
impl fmt::Debug for Extensions {
    /// Prints only the type name. Stored values are type-erased and are not required
    /// to implement `Debug`, so their contents are never shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Extensions")
    }
}
/// Exercises the full public API of `Extensions`: insert/replace, get, remove, len,
/// is_empty and clear. It also checks that shared values are cloned shallowly (same
/// `Arc` allocation) and that `clear` actually drops what it held.
#[test]
fn test_extensions() {
    use std::sync::Arc;

    #[derive(Debug, Clone, PartialEq)]
    struct MyType(i32);

    #[derive(Debug, PartialEq)]
    struct ComplexSharedType(u64);

    let shared = Arc::new(ComplexSharedType(20));
    let extensions = Extensions::new();
    assert!(extensions.is_empty());

    assert!(extensions.insert(5i32).is_none());
    assert!(extensions.insert(MyType(10)).is_none());
    assert!(extensions.insert(shared.clone()).is_none());
    assert_eq!(extensions.len(), 3);

    assert_eq!(extensions.get(), Some(5i32));

    // `get` clones the stored `Arc`, so the caller must receive the very same allocation.
    let fetched = extensions.get::<Arc<ComplexSharedType>>().unwrap();
    assert!(Arc::ptr_eq(&fetched, &shared));
    assert_eq!(*fetched, ComplexSharedType(20));

    // Inserting a type that is already present replaces it and returns the old value.
    assert_eq!(extensions.insert(7i32), Some(5i32));
    assert_eq!(extensions.remove::<i32>(), Some(7i32));
    assert!(extensions.get::<i32>().is_none());
    assert!(extensions.remove::<i32>().is_none());
    assert!(extensions.get::<bool>().is_none());
    assert_eq!(extensions.get(), Some(MyType(10)));
    assert_eq!(extensions.len(), 2);

    extensions.clear();
    assert!(extensions.is_empty());
    assert!(extensions.get::<MyType>().is_none());

    // After `clear` the map no longer holds its `Arc`; only the local handle remains.
    drop(fetched);
    assert_eq!(Arc::strong_count(&shared), 1);
}