#![warn(clippy::all)]
use raw_sync::locks::{LockInit, Mutex};
use serde::{de, de::DeserializeOwned, ser, Deserialize, Deserializer, Serialize, Serializer};
use std::cell::UnsafeCell;
use std::convert::{From, Into, TryFrom, TryInto};
use std::ops::{Deref, DerefMut};
#[allow(dead_code)]
mod shared_memory;
use shared_memory::{Shmem, ShmemConf, ShmemError};
#[allow(dead_code)]
mod memory;
use memory::{is_aligned, ALIGNMENT};
/// Errors produced by the shared-memory wrappers in this crate.
#[derive(Debug)]
pub enum Error {
    /// The data region of the mapping was not aligned to `ALIGNMENT`.
    UnalignedMemory,
    /// The inter-process mutex could not be created, opened, or locked.
    Mutex(String),
    /// The underlying OS shared-memory mapping failed.
    Shmem(ShmemError),
    /// Serializing a `Shared`/`SharedMut` handle failed.
    Serialization(String),
    /// Deserializing a `Shared`/`SharedMut` handle failed.
    Deserialization(String),
    /// A `SharedMut` was used after `serde::Serialize::serialize` invalidated it.
    InvalidSharedMut,
}
impl std::fmt::Display for Error {
    /// Human-readable rendering; `Shmem` delegates to the wrapped error, the
    /// string-carrying variants print their payload verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Shmem(inner) => inner.fmt(f),
            Self::Mutex(msg) | Self::Serialization(msg) | Self::Deserialization(msg) => {
                f.write_str(msg)
            }
            Self::UnalignedMemory => f.write_str("Encountered unaligned memory."),
            Self::InvalidSharedMut => f.write_str("Trying to use a `SharedMut` previously invalidated with a call to `serde::Serialize::serialize`."),
        }
    }
}
// Marker impl: `Error` already provides `Debug` and `Display`, so no methods are needed.
impl std::error::Error for Error {}
/// Types whose contents can live directly in a raw shared-memory byte region.
///
/// # Safety
/// Implementors must write only plain-old-data into `data` and must return
/// metadata that, paired with a byte region of the same size, lets the view
/// traits reinterpret the bytes without undefined behavior.
pub unsafe trait ShmemBacked {
    /// Argument describing a fresh region (e.g. `(fill_value, len)`).
    type NewArg: ?Sized;
    /// Serializable description needed to rebuild a typed view (e.g. a length).
    type MetaData: Serialize + DeserializeOwned + Clone;
    /// Bytes required for a region created from `arg`.
    fn required_memory_arg(arg: &Self::NewArg) -> usize;
    /// Bytes required for a region copied from an existing value.
    fn required_memory_src(src: &Self) -> usize;
    /// Initialize `data` from `arg`; returns the metadata describing the region.
    fn new(data: &mut [u8], arg: &Self::NewArg) -> Self::MetaData;
    /// Initialize `data` by copying `src`; returns the metadata describing the region.
    fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData;
}
/// Read-only view into a shared-memory region previously initialized for `Self`.
pub trait ShmemView<'a>: ShmemBacked {
    /// Borrowed view type, e.g. `&'a str` or `&'a [T]`.
    type View;
    /// Reinterpret `data` (described by `metadata`) as a typed view.
    fn view(data: &'a [u8], metadata: &'a <Self as ShmemBacked>::MetaData) -> Self::View;
}
/// Mutable view into a shared-memory region previously initialized for `Self`.
pub trait ShmemViewMut<'a>: ShmemBacked {
    /// Mutable view type, e.g. `&'a mut str` or `&'a mut [T]`.
    type View;
    /// Reinterpret `data` (described by `metadata`) as a mutable typed view.
    fn view_mut(
        data: &'a mut [u8],
        metadata: &'a mut <Self as ShmemBacked>::MetaData,
    ) -> Self::View;
}
/// SAFETY: `str` is plain UTF-8 bytes; the metadata is the byte length.
unsafe impl ShmemBacked for str {
    type NewArg = str;
    type MetaData = usize;
    fn required_memory_arg(arg: &Self::NewArg) -> usize {
        arg.len()
    }
    fn required_memory_src(src: &Self) -> usize {
        src.len()
    }
    /// Copy `arg` into the region and return its byte length as metadata.
    fn new(data: &mut [u8], arg: &Self::NewArg) -> Self::MetaData {
        let bytes = arg.as_bytes();
        assert_eq!(data.len(), bytes.len());
        data.copy_from_slice(bytes);
        bytes.len()
    }
    /// Copy `src` into the region and return its byte length as metadata.
    fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData {
        let bytes = src.as_bytes();
        assert_eq!(data.len(), bytes.len());
        data.copy_from_slice(bytes);
        bytes.len()
    }
}
impl<'a> ShmemView<'a> for str {
    type View = &'a str;
    /// Reinterpret the bytes as `&str`; `metadata` is the expected byte length.
    fn view(data: &'a [u8], metadata: &'a <Self as ShmemBacked>::MetaData) -> Self::View {
        assert_eq!(data.len(), *metadata);
        // SAFETY: the region was initialized from a valid `&str` by
        // `ShmemBacked::new`/`new_from_src`, so the bytes are valid UTF-8
        // (assuming no other process corrupted the shared region).
        unsafe { std::str::from_utf8_unchecked(data) }
    }
}
impl<'a> ShmemViewMut<'a> for str {
    type View = &'a mut str;
    /// Reinterpret the bytes as `&mut str`; `metadata` is the expected byte length.
    fn view_mut(
        data: &'a mut [u8],
        metadata: &'a mut <Self as ShmemBacked>::MetaData,
    ) -> Self::View {
        assert_eq!(data.len(), *metadata);
        // SAFETY: see `view` above — bytes were written from a valid `&str`.
        // Callers must keep the contents valid UTF-8 through this view.
        unsafe { std::str::from_utf8_unchecked_mut(data) }
    }
}
/// SAFETY: `T: Copy` guarantees the elements are plain-old-data; the metadata
/// is the element count. The `data` region handed in by `ShmemBase` starts at
/// an `ALIGNMENT`-aligned address, which must be suitable for `T`.
unsafe impl<T> ShmemBacked for [T]
where
    T: Copy,
{
    /// `(fill_value, element_count)`.
    type NewArg = (T, usize);
    type MetaData = usize;
    fn required_memory_arg((_, len): &Self::NewArg) -> usize {
        *len * std::mem::size_of::<T>()
    }
    fn required_memory_src(src: &Self) -> usize {
        src.len() * std::mem::size_of::<T>()
    }
    /// Fill the region with `len` copies of `init`; returns the element count.
    fn new(data: &mut [u8], &(init, len): &Self::NewArg) -> Self::MetaData {
        assert_eq!(data.len(), Self::required_memory_arg(&(init, len)));
        // SAFETY: length checked above; alignment is guaranteed by `ShmemBase`.
        let data_typed =
            unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, len) };
        // `slice::fill` replaces the previous manual element-by-element loop.
        data_typed.fill(init);
        len
    }
    /// Copy `src` into the region; returns the element count.
    fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData {
        assert_eq!(data.len(), Self::required_memory_src(src));
        // SAFETY: length checked above; alignment is guaranteed by `ShmemBase`.
        let data_typed =
            unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, src.len()) };
        data_typed.copy_from_slice(src);
        data_typed.len()
    }
}
impl<'a, T> ShmemView<'a> for [T]
where
    T: Copy + 'a,
{
    type View = &'a [T];
    /// Reinterpret the bytes as `&[T]`; `metadata` is the element count.
    fn view(data: &'a [u8], metadata: &'a <Self as ShmemBacked>::MetaData) -> Self::View {
        // NOTE(review): unlike the `str` impl there is no assertion tying
        // `*metadata` to `data.len()` here — consider adding one.
        // SAFETY: the region was initialized with `*metadata` elements of `T`;
        // alignment is guaranteed by `ShmemBase`.
        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, *metadata) }
    }
}
impl<'a, T> ShmemViewMut<'a> for [T]
where
    T: Copy + 'a,
{
    type View = &'a mut [T];
    /// Reinterpret the bytes as `&mut [T]`; `metadata` is the element count.
    fn view_mut(
        data: &'a mut [u8],
        metadata: &'a mut <Self as ShmemBacked>::MetaData,
    ) -> Self::View {
        // NOTE(review): no length assertion here either — see `view` above.
        // SAFETY: region holds `*metadata` elements of `T`; alignment is
        // guaranteed by `ShmemBase`.
        unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, *metadata) }
    }
}
use num_traits::{ops::checked::CheckedAdd, sign::Unsigned, NumOps};
/// Unsigned integer type used for the per-region access/serialization counters.
trait AccessCounter: Copy + PartialOrd + Eq + NumOps + Unsigned + CheckedAdd {}
// Blanket impl: any type satisfying the numeric bounds is a valid counter.
impl<T: Copy + PartialOrd + Eq + NumOps + Unsigned + CheckedAdd> AccessCounter for T {}
/// Low-level handle to one shared-memory mapping.
///
/// Mapping layout: `[mutex | N counters of type A | padding | data]`, where
/// the data region starts at `data_offset` (a multiple of `ALIGNMENT`).
struct ShmemBase<A: AccessCounter, P: DropBehaviour, const N: usize> {
    // Zero-sized tag selecting drop semantics (`SharedTag` / `SharedMutTag`).
    tag: P,
    // OS shared-memory mapping.
    shmem: Shmem,
    // Records the counter element type without storing a value.
    access_counter_type: std::marker::PhantomData<A>,
    // Byte offset of the counters (immediately after the mutex).
    counter_offset: usize,
    // Byte offset of the aligned data region.
    data_offset: usize,
    // Size of the data region in bytes.
    data_size: usize,
    // Whether dropping this handle may free/unmap the region.
    free_shmem: bool,
}
impl<A: AccessCounter, P: DropBehaviour, const N: usize> ShmemBase<A, P, N> {
    /// Create a fresh mapping laid out as `[mutex | N counters | padding | data]`.
    ///
    /// Fails if the OS mapping cannot be created, the mutex cannot be
    /// initialized, or the resulting data region is not `ALIGNMENT`-aligned.
    fn new(access_counter: &[A; N], data_size: usize, tag: P) -> Result<Self, Error> {
        let lock_size = Mutex::size_of(None);
        // Round the mutex + counter area up to the next multiple of
        // `ALIGNMENT` so the data region starts aligned.
        let reserved_size =
            ((lock_size + std::mem::size_of::<A>() * N) / ALIGNMENT + 1) * ALIGNMENT;
        let shmem = ShmemConf::new()
            .size(reserved_size + data_size)
            .create()
            .map_err(Error::Shmem)?;
        // Initialize the inter-process mutex at the start of the mapping; its
        // guarded "data" pointer is the counter area right after it.
        let (mutex, _) = unsafe {
            Mutex::new(shmem.as_ptr(), shmem.as_ptr().add(lock_size))
                .map_err(|e| Error::Mutex(format!("{}", e)))?
        };
        {
            // Write the initial counter values while holding the lock.
            let lock = mutex.lock().map_err(|e| Error::Mutex(format!("{}", e)))?;
            unsafe {
                let counter_ptr = std::slice::from_raw_parts_mut(*lock as *mut A, N);
                counter_ptr.copy_from_slice(access_counter);
            }
        }
        // Leak the local mutex handle: dropping it would presumably tear down
        // the OS lock object that other handles/processes still use.
        std::mem::forget(mutex);
        let aligned = unsafe { is_aligned(shmem.as_ptr().add(reserved_size), ALIGNMENT) };
        if !aligned {
            Err(Error::UnalignedMemory)
        } else {
            Ok(Self {
                tag,
                shmem,
                access_counter_type: std::marker::PhantomData,
                counter_offset: lock_size,
                data_offset: reserved_size,
                data_size,
                free_shmem: true,
            })
        }
    }
    /// Run `write` on the counters while holding the inter-process mutex.
    ///
    /// `write` receives the current counter values; returning `Some(new)`
    /// stores `new`, returning `None` leaves the counters untouched. Either
    /// way, the values observed *before* any write are returned.
    fn mutex_write(&self, write: fn(&[A; N]) -> Option<[A; N]>) -> Result<[A; N], Error> {
        let (mutex, _) = unsafe {
            let counter_ptr = self.shmem.as_ptr().add(self.counter_offset);
            Mutex::from_existing(self.shmem.as_ptr(), counter_ptr)
                .map_err(|e| Error::Mutex(format!("{}", e)))?
        };
        let counter_values = {
            let lock = mutex.lock().map_err(|e| Error::Mutex(format!("{}", e)))?;
            let counter_ptr = unsafe { std::slice::from_raw_parts_mut(*lock as *mut A, N) };
            let mut old = [A::zero(); N];
            old.copy_from_slice(counter_ptr);
            if let Some(new) = write(&old) {
                counter_ptr.copy_from_slice(&new);
            }
            old
        };
        // As in `new`: never destroy the shared OS mutex from here.
        std::mem::forget(mutex);
        Ok(counter_values)
    }
    /// Set whether dropping this handle may free the underlying region.
    fn free_on_drop(&mut self, free: bool) {
        self.free_shmem = free;
    }
    /// Pointer to the start of the aligned data region.
    fn data_ptr(&self) -> *const u8 {
        unsafe { self.shmem.as_ptr().add(self.data_offset) }
    }
    /// Mutable pointer to the start of the aligned data region.
    fn data_ptr_mut(&mut self) -> *mut u8 {
        unsafe { self.shmem.as_ptr().add(self.data_offset) }
    }
    /// Describe this mapping for serialization without consuming the handle.
    fn to_wire_format<T>(&self, metadata: T) -> WireFormat<T> {
        WireFormat {
            tag: self.tag.clone().into(),
            os_id: String::from(self.shmem.get_os_id()),
            mem_size: self.shmem.len(),
            counter_offset: self.counter_offset,
            data_offset: self.data_offset,
            data_size: self.data_size,
            meta: metadata,
        }
    }
    /// Describe this mapping for serialization, consuming the handle. The
    /// region is deliberately NOT freed when `self` drops, so the serialized
    /// form can re-open it later.
    fn into_wire_format<T>(mut self, metadata: T) -> WireFormat<T> {
        self.free_shmem = false;
        WireFormat {
            tag: self.tag.clone().into(),
            os_id: String::from(self.shmem.get_os_id()),
            mem_size: self.shmem.len(),
            counter_offset: self.counter_offset,
            data_offset: self.data_offset,
            data_size: self.data_size,
            meta: metadata,
        }
    }
    /// Re-open an existing mapping described by `wire_format`.
    ///
    /// `free_shmem` starts out `false`; callers enable it only after the
    /// counters have been adjusted successfully.
    fn from_wire_format<T>(wire_format: WireFormat<T>) -> Result<(Self, T), Error> {
        let WireFormat {
            tag,
            os_id,
            mem_size,
            counter_offset,
            data_offset,
            data_size,
            meta,
        } = wire_format;
        let shmem = ShmemConf::new()
            .os_id(os_id)
            .size(mem_size)
            .open()
            .map_err(Error::Shmem)?;
        Ok((
            Self {
                // Rejects wire formats produced by the other handle flavor.
                tag: tag.try_into()?,
                shmem,
                access_counter_type: std::marker::PhantomData,
                counter_offset,
                data_offset,
                data_size,
                free_shmem: false,
            },
            meta,
        ))
    }
}
impl<A: AccessCounter, P: DropBehaviour, const N: usize> Drop for ShmemBase<A, P, N> {
    fn drop(&mut self) {
        // Delegate to the tag type: it decides, based on the shared counters,
        // whether this handle owns (and therefore frees) the region.
        P::called_on_drop(self);
    }
}
/// Strategy deciding what happens to the mapping when a handle is dropped.
trait DropBehaviour: Clone + TryFrom<Tag, Error = Error> + Into<Tag> {
    /// Invoked from `ShmemBase::drop`.
    fn called_on_drop<A: AccessCounter, P: DropBehaviour, const N: usize>(
        base: &mut ShmemBase<A, P, N>,
    );
}
impl<A: AccessCounter> Clone for ShmemBase<A, SharedTag, 2> {
    /// Open an additional handle to the same OS mapping and bump the shared
    /// access counter (index 0) under the mutex.
    ///
    /// # Panics
    /// Panics if the mapping cannot be re-opened, the mutex fails, or the
    /// access counter would overflow.
    fn clone(&self) -> Self {
        let shmem = ShmemConf::new()
            .os_id(self.shmem.get_os_id())
            .size(self.shmem.len())
            .open()
            .unwrap();
        let new = Self {
            tag: SharedTag(),
            shmem,
            access_counter_type: std::marker::PhantomData,
            counter_offset: self.counter_offset,
            data_offset: self.data_offset,
            data_size: self.data_size,
            free_shmem: true,
        };
        // access += 1 (checked); serialization count (index 1) unchanged.
        let write: fn(&[A; 2]) -> Option<[A; 2]> = |old| {
            let mut new = [A::zero(); 2];
            if let Some(new_acc_count) = old[0].checked_add(&A::one()) {
                new[0] = new_acc_count;
                new[1] = old[1];
                Some(new)
            } else {
                panic!("Can't have more than A::MAX `Shared`s with simultaneous access.")
            }
        };
        new.mutex_write(write).unwrap();
        new
    }
}
impl<A: AccessCounter> From<ShmemBase<A, SharedMutTag, 2>> for ShmemBase<A, SharedTag, 2> {
    /// Downgrade an exclusive (`SharedMut`) base into a shared (`Shared`) one
    /// by re-opening the same OS mapping under the new tag.
    fn from(mut shared_mut: ShmemBase<A, SharedMutTag, 2>) -> Self {
        // The old handle must not free the region when it drops below.
        shared_mut.free_on_drop(false);
        let shmem = ShmemConf::new()
            .os_id(shared_mut.shmem.get_os_id())
            .size(shared_mut.shmem.len())
            .open()
            .unwrap();
        let new = Self {
            tag: SharedTag(),
            shmem,
            access_counter_type: std::marker::PhantomData,
            counter_offset: shared_mut.counter_offset,
            data_offset: shared_mut.data_offset,
            data_size: shared_mut.data_size,
            free_shmem: true,
        };
        // Reset counters to `Shared` semantics: one accessor, zero serialized.
        new.mutex_write(|_| Some([A::one(), A::zero()])).unwrap();
        new
    }
}
/// Tag for read-only `Shared` handles: the region is freed only when both the
/// access count and the serialization count drop to zero (see `called_on_drop`).
#[derive(Serialize, Deserialize, Clone)]
struct SharedTag();
impl TryFrom<Tag> for SharedTag {
type Error = Error;
fn try_from(value: Tag) -> Result<Self, Self::Error> {
match value {
Tag::Shared(tag) => Ok(tag),
Tag::SharedMut(_) => Err(Error::Deserialization(String::from(
"Can't deserialize a `Shared` from a `SharedMut` serialization.",
))),
}
}
}
impl DropBehaviour for SharedTag {
    /// Decrement the shared access count; free the region only when this was
    /// the last accessor AND no serialized copies are outstanding.
    fn called_on_drop<A: AccessCounter, P: DropBehaviour, const N: usize>(
        base: &mut ShmemBase<A, P, N>,
    ) {
        if base.free_shmem {
            // access -= 1; serialization count unchanged. (A live handle
            // implies old[0] >= 1, so the unchecked subtraction is expected
            // not to underflow.)
            let write: fn(&[A; N]) -> Option<[A; N]> = |old| {
                let mut new = [A::zero(); N];
                new[0] = old[0] - A::one();
                new[1] = old[1];
                Some(new)
            };
            // `mutex_write` returns the values *before* the decrement.
            let counter_value = base.mutex_write(write).unwrap();
            if (counter_value[0] == A::one()) && (counter_value[1] == A::zero()) {
                // Last accessor, nothing serialized: this handle owns the
                // region and `Shmem` will free it on drop.
                // NOTE(review): the re-opened mutex handle is dropped
                // immediately; presumably this releases the OS lock object
                // before the mapping is freed — confirm raw_sync semantics.
                unsafe {
                    let counter_ptr = base.shmem.as_ptr().add(base.counter_offset);
                    Mutex::from_existing(base.shmem.as_ptr(), counter_ptr).unwrap();
                }
                base.shmem.set_owner(true);
            } else {
                base.shmem.set_owner(false);
            }
        } else {
            base.shmem.set_owner(false);
        }
    }
}
/// Read-only shared-memory handle to a `T` (e.g. `str` or `[T]`).
///
/// Clones and serialized copies refer to one region; counters stored inside
/// the region decide when it is freed.
pub struct Shared<T>
where
    T: ShmemBacked + ?Sized,
{
    // Metadata needed to rebuild a typed view (e.g. a length).
    metadata: <T as ShmemBacked>::MetaData,
    // Region whose counters are `[access_count, serialized_count]`.
    shmem: ShmemBase<u64, SharedTag, 2>,
}
impl<T> Shared<T>
where
    T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
{
    /// Allocate a new shared region initialized from `arg` with an access
    /// count of 1.
    pub fn new(arg: &<T as ShmemBacked>::NewArg) -> Result<Self, Error> {
        // `arg` is already a reference; the previous `&arg` was a needless borrow.
        let size = T::required_memory_arg(arg);
        let mut shmem = ShmemBase::new(&[1_u64, 0], size, SharedTag())?;
        // SAFETY: `data_ptr_mut` points at `data_size` writable bytes.
        let metadata = unsafe {
            let data = std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size);
            T::new(data, arg)
        };
        Ok(Shared { metadata, shmem })
    }
    /// Allocate a new shared region initialized as a copy of `arg`.
    pub fn new_from_inner(arg: &T) -> Result<Self, Error> {
        let size = T::required_memory_src(arg);
        let mut shmem = ShmemBase::new(&[1_u64, 0], size, SharedTag())?;
        // SAFETY: `data_ptr_mut` points at `data_size` writable bytes.
        let metadata = unsafe {
            let data = std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size);
            T::new_from_src(data, arg)
        };
        Ok(Shared { metadata, shmem })
    }
    /// Borrow the shared data as a read-only view (`&str`, `&[T]`, ...).
    // Fixed: the attribute previously used the doubled path
    // `clippy::clippy::needless_lifetimes`, which is not a valid lint name.
    #[allow(clippy::needless_lifetimes)]
    pub fn as_view<'a>(&'a self) -> <T as ShmemView<'a>>::View {
        let data =
            unsafe { std::slice::from_raw_parts(self.shmem.data_ptr(), self.shmem.data_size) };
        T::view(data, &self.metadata)
    }
    /// Current `[access_count, serialized_count]` (test helper).
    #[cfg(test)]
    pub fn counts(&self) -> Result<[u64; 2], Error> {
        self.shmem.mutex_write(|_| None)
    }
    /// Serialize and consume this handle: the access count is handed over to
    /// the serialized form (access -= 1, serialized += 1), so the region stays
    /// alive until the bytes are deserialized.
    pub fn into_serialized<S: Serializer>(self, serializer: S) -> Result<S::Ok, S::Error> {
        let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
            let mut new = [0_u64, 0];
            if let Some(new_ser_count) = old[1].checked_add(1) {
                new[0] = old[0] - 1;
                new[1] = new_ser_count;
                Some(new)
            } else {
                panic!("Can't have more than A::MAX serialized `Shared`s.")
            }
        };
        self.shmem.mutex_write(write).map_err(ser::Error::custom)?;
        let wire_format = self.shmem.into_wire_format(self.metadata);
        wire_format.serialize(serializer)
    }
}
impl<T> Serialize for Shared<T>
where
    T: ShmemBacked + ?Sized,
{
    /// Serialize a still-usable handle: bumps the serialization count
    /// (checked) while leaving the access count unchanged, because this
    /// handle stays alive alongside the serialized copy.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
            let mut new = [0_u64, 0];
            if let Some(new_ser_count) = old[1].checked_add(1) {
                new[0] = old[0];
                new[1] = new_ser_count;
                Some(new)
            } else {
                panic!("Can't have more than A::MAX serialized `Shared`s.")
            }
        };
        self.shmem.mutex_write(write).map_err(ser::Error::custom)?;
        let wire_format = self.shmem.to_wire_format(&self.metadata);
        wire_format.serialize(serializer)
    }
}
impl<'de, T> Deserialize<'de> for Shared<T>
where
    T: ShmemBacked + ?Sized,
{
    /// Re-open the region described by the wire format and claim one access:
    /// access += 1 (checked) and serialized -= 1 (saturating).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let wire_format = WireFormat::deserialize(deserializer)?;
        let (mut shmem, metadata) = ShmemBase::<u64, SharedTag, 2>::from_wire_format(wire_format)
            .map_err(de::Error::custom)?;
        let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
            let mut new = [0_u64, 0];
            if let Some(new_acc_count) = old[0].checked_add(1) {
                new[0] = new_acc_count;
                new[1] = old[1].saturating_sub(1);
                Some(new)
            } else {
                // Counter saturated: leave the region untouched and fail below.
                None
            }
        };
        // `mutex_write` returns pre-update values: a pre-update access count
        // of `u64::MAX` means `checked_add` failed and nothing was written.
        if shmem.mutex_write(write).map_err(de::Error::custom)?[0] < u64::MAX {
            // Only a successfully registered handle may free the region.
            shmem.free_on_drop(true);
            Ok(Shared { metadata, shmem })
        } else {
            Err(de::Error::custom(
                "Can't have more than u64::MAX `Shared`s with simultaneous access.",
            ))
        }
    }
}
// SAFETY: access to the shared region is coordinated through the in-region
// mutex and counters, so moving a handle across threads is intended to be
// sound whenever the metadata itself is `Send`.
// NOTE(review): soundness also relies on `ShmemBase`/`Shmem` internals being
// safe to move between threads — confirm against the `shared_memory` crate.
unsafe impl<T> Send for Shared<T>
where
    T: ShmemBacked + ?Sized,
    T::MetaData: Send,
{
}
impl<'a, T> From<&'a T> for Shared<T>
where
    T: ShmemBacked + for<'b> ShmemView<'b> + ?Sized,
{
    /// Convenience wrapper around `Shared::new_from_inner`.
    ///
    /// # Panics
    /// Panics if the shared-memory region cannot be created.
    fn from(src: &'a T) -> Self {
        Self::new_from_inner(src).unwrap()
    }
}
impl<T> Clone for Shared<T>
where
    T: ShmemBacked + ?Sized,
{
    /// Open a second handle onto the same region; `ShmemBase::clone` bumps
    /// the shared access counter before the metadata is duplicated.
    fn clone(&self) -> Self {
        let shmem = self.shmem.clone();
        let metadata = self.metadata.clone();
        Self { metadata, shmem }
    }
}
impl<T> TryFrom<SharedMut<T>> for Shared<T>
where
T: ShmemBacked + ?Sized,
{
type Error = Error;
fn try_from(shared_mut: SharedMut<T>) -> Result<Self, Self::Error> {
let SharedMut { shmem, metadata } = shared_mut;
if let Some(shmem) = shmem.into_inner() {
let shmem: ShmemBase<_, SharedTag, 2> = shmem.into();
Ok(Self { metadata, shmem })
} else {
Err(Error::InvalidSharedMut)
}
}
}
/// Tag for exclusive `SharedMut` handles: a valid handle is the sole accessor,
/// so dropping it frees the region outright (see `called_on_drop`).
#[derive(Serialize, Deserialize, Clone)]
struct SharedMutTag();
impl TryFrom<Tag> for SharedMutTag {
type Error = Error;
fn try_from(value: Tag) -> Result<Self, Self::Error> {
match value {
Tag::SharedMut(tag) => Ok(tag),
Tag::Shared(_) => Err(Error::Deserialization(String::from(
"Can't deserialize a `SharedMut` from a `Shared` serialization.",
))),
}
}
}
impl DropBehaviour for SharedMutTag {
    /// A valid exclusive handle is the sole accessor, so dropping it frees
    /// the region; invalidated handles (`free_shmem == false`) leave it alone.
    fn called_on_drop<A: AccessCounter, P: DropBehaviour, const N: usize>(
        base: &mut ShmemBase<A, P, N>,
    ) {
        if base.free_shmem {
            // NOTE(review): the re-opened mutex handle is dropped immediately;
            // presumably this releases the OS lock object before the mapping
            // is freed — confirm raw_sync semantics.
            unsafe {
                let counter_ptr = base.shmem.as_ptr().add(base.counter_offset);
                Mutex::from_existing(base.shmem.as_ptr(), counter_ptr).unwrap();
            }
            base.shmem.set_owner(true);
        } else {
            base.shmem.set_owner(false);
        }
    }
}
/// Exclusive (read-write) shared-memory handle to a `T`.
///
/// Only one `SharedMut` may access a region at a time. `serde` serialization
/// steals the inner base out of the `UnsafeCell` (leaving `None`), which is
/// why the handle stores an `Option`.
pub struct SharedMut<T>
where
    T: ShmemBacked + ?Sized,
{
    // Metadata needed to rebuild a typed view.
    metadata: <T as ShmemBacked>::MetaData,
    // `None` after `Serialize::serialize` has invalidated this handle.
    shmem: UnsafeCell<Option<ShmemBase<u64, SharedMutTag, 2>>>,
}
impl<T> SharedMut<T>
where
    T: ShmemBacked + for<'a> ShmemView<'a> + for<'a> ShmemViewMut<'a> + ?Sized,
{
    /// Allocate a new exclusive shared region initialized from `arg`.
    pub fn new(arg: &<T as ShmemBacked>::NewArg) -> Result<Self, Error> {
        // `arg` is already a reference; the previous `&arg` was a needless borrow.
        let size = T::required_memory_arg(arg);
        let mut shmem = ShmemBase::new(&[1_u64, 0], size, SharedMutTag())?;
        // SAFETY: `data_ptr_mut` points at `data_size` writable bytes.
        let metadata = unsafe {
            let data = std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size);
            T::new(data, arg)
        };
        Ok(SharedMut {
            metadata,
            shmem: UnsafeCell::new(Some(shmem)),
        })
    }
    /// Allocate a new exclusive shared region initialized as a copy of `arg`.
    pub fn new_from_inner(arg: &T) -> Result<Self, Error> {
        let size = T::required_memory_src(arg);
        let mut shmem = ShmemBase::new(&[1_u64, 0], size, SharedMutTag())?;
        // SAFETY: `data_ptr_mut` points at `data_size` writable bytes.
        let metadata = unsafe {
            let data = std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size);
            T::new_from_src(data, arg)
        };
        Ok(SharedMut {
            metadata,
            shmem: UnsafeCell::new(Some(shmem)),
        })
    }
    /// Borrow the shared data as a mutable view.
    ///
    /// # Panics
    /// Panics if the handle was invalidated by `serde::Serialize::serialize`.
    // Fixed: the attribute previously used the doubled path
    // `clippy::clippy::needless_lifetimes`, which is not a valid lint name.
    #[allow(clippy::needless_lifetimes)]
    pub fn as_view_mut<'a>(&'a mut self) -> <T as ShmemViewMut<'a>>::View {
        let shmem = self.shmem.get_mut().as_mut().expect(
            "`SharedMut` must not be used after a call to `serde::Serialize::serialize`.",
        );
        let data = unsafe { std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size) };
        T::view_mut(data, &mut self.metadata)
    }
}
impl<T> SharedMut<T>
where
    T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
{
    /// Borrow the shared data as a read-only view.
    ///
    /// # Panics
    /// Panics if the handle was invalidated by `serde::Serialize::serialize`.
    // Fixed: the attribute previously used the doubled path
    // `clippy::clippy::needless_lifetimes`, which is not a valid lint name.
    #[allow(clippy::needless_lifetimes)]
    pub fn as_view<'a>(&'a self) -> <T as ShmemView<'a>>::View {
        let data = unsafe {
            let shmem: &mut _ = &mut *self.shmem.get();
            let shmem = shmem.as_mut().expect(
                "`SharedMut` must not be used after a call to `serde::Serialize::serialize`.",
            );
            std::slice::from_raw_parts(shmem.data_ptr(), shmem.data_size)
        };
        T::view(data, &self.metadata)
    }
    /// Mutable access to the view metadata.
    ///
    /// # Safety
    /// The caller must keep the metadata consistent with the bytes actually
    /// stored in the region (e.g. matching element counts/strides); otherwise
    /// later views misinterpret the memory.
    pub unsafe fn metadata_mut(&mut self) -> &mut <T as ShmemBacked>::MetaData {
        &mut self.metadata
    }
    /// Current `[access_count]` (test helper).
    #[cfg(test)]
    pub fn counts(&mut self) -> Result<[u64; 1], Error> {
        let shmem = self.shmem.get_mut().as_mut().expect(
            "`SharedMut` must not be used after a call to `serde::Serialize::serialize`.",
        );
        let mut access_count = [0];
        let counts = shmem.mutex_write(|_| None)?;
        access_count[0] = counts[0];
        Ok(access_count)
    }
    /// Serialize and consume this handle, handing exclusive access over to the
    /// serialized form: counters are reset to `[0, 0]` so exactly one
    /// deserialization can claim the region.
    pub fn into_serialized<S: Serializer>(self, serializer: S) -> Result<S::Ok, S::Error> {
        let shmem = self
            .shmem
            .into_inner()
            .expect("`SharedMut` must not be used after a call to `serde::Serialize::serialize`.");
        let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
            debug_assert_eq!(old[0], 1);
            Some([0_u64, 0])
        };
        shmem.mutex_write(write).map_err(ser::Error::custom)?;
        let wire_format = shmem.into_wire_format(self.metadata);
        wire_format.serialize(serializer)
    }
}
impl<T> Serialize for SharedMut<T>
where
    T: ShmemBacked + ?Sized,
{
    /// Serialize from `&self` by *stealing* the inner base out of the
    /// `UnsafeCell` (leaving `None`): exclusive access moves into the
    /// serialized form and any later use of this handle panics or yields
    /// `InvalidSharedMut`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // NOTE(review): this creates a `&mut` from a shared `&self` via
        // `UnsafeCell`; soundness presumably relies on no other reference
        // touching the cell concurrently (`UnsafeCell` makes the type
        // `!Sync`) — confirm no same-thread aliasing is possible.
        let shmem: &mut _ = unsafe { &mut *self.shmem.get() };
        let shmem = shmem
            .take()
            .expect("`SharedMut` must not be used after a call to `serde::Serialize::serialize`.");
        // Reset counters so exactly one deserialization can claim the region.
        let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
            debug_assert_eq!(old[0], 1);
            Some([0_u64, 0])
        };
        shmem.mutex_write(write).map_err(ser::Error::custom)?;
        let wire_format = shmem.into_wire_format(self.metadata.clone());
        wire_format.serialize(serializer)
    }
}
impl<'de, T> Deserialize<'de> for SharedMut<T>
where
    T: ShmemBacked + ?Sized,
{
    /// Re-open the region and claim exclusive access: succeeds only if the
    /// access count is currently 0 (left there by the serializing side),
    /// flipping it to 1 under the mutex.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let wire_format = WireFormat::deserialize(deserializer)?;
        let (mut shmem, metadata) =
            ShmemBase::<u64, SharedMutTag, 2>::from_wire_format(wire_format)
                .map_err(de::Error::custom)?;
        let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
            debug_assert!(old[0] <= 1);
            if old[0] == 0 {
                Some([1, 0])
            } else {
                // Someone else already holds exclusive access: write nothing.
                None
            }
        };
        // `mutex_write` returns pre-update values: 0 means we claimed it.
        if shmem.mutex_write(write).map_err(de::Error::custom)?[0] == 0 {
            shmem.free_on_drop(true);
            Ok(SharedMut {
                metadata,
                shmem: UnsafeCell::new(Some(shmem)),
            })
        } else {
            Err(de::Error::custom("A shared memory region can only be accessed by one `SharedMut` instance at any time. Note that the existing instance may live in a different process."))
        }
    }
}
// SAFETY: a `SharedMut` is the region's sole accessor, so moving the handle
// to another thread is intended to be sound when the metadata is `Send`.
// NOTE(review): as with `Shared`, this also relies on the `shared_memory`
// crate's mapping type being movable between threads — confirm.
unsafe impl<T> Send for SharedMut<T>
where
    T: ShmemBacked + ?Sized,
    T::MetaData: Send,
{
}
impl<'a, T> From<&'a T> for SharedMut<T>
where
    T: ShmemBacked + for<'b> ShmemView<'b> + for<'b> ShmemViewMut<'b> + ?Sized,
{
    /// Convenience wrapper around `SharedMut::new_from_inner`.
    ///
    /// # Panics
    /// Panics if the shared-memory region cannot be created.
    fn from(src: &'a T) -> Self {
        Self::new_from_inner(src).unwrap()
    }
}
/// On-the-wire description of a shared-memory region: everything needed to
/// re-open the mapping, plus the user metadata `T`.
#[derive(Serialize, Deserialize)]
struct WireFormat<T> {
    // Distinguishes `Shared` from `SharedMut` serializations.
    tag: Tag,
    // OS identifier used to re-open the mapping.
    os_id: String,
    // Total mapping size in bytes.
    mem_size: usize,
    // Byte offset of the counters inside the mapping.
    counter_offset: usize,
    // Byte offset of the aligned data region.
    data_offset: usize,
    // Size of the data region in bytes.
    data_size: usize,
    // `ShmemBacked::MetaData` of the stored value.
    meta: T,
}
/// Serialized marker telling deserializers which handle flavor produced the
/// wire format; `TryFrom<Tag>` on the tag types rejects mismatches.
#[derive(Serialize, Deserialize)]
enum Tag {
    Shared(SharedTag),
    SharedMut(SharedMutTag),
}
impl From<SharedTag> for Tag {
fn from(shared: SharedTag) -> Self {
Tag::Shared(shared)
}
}
impl From<SharedMutTag> for Tag {
fn from(shared: SharedMutTag) -> Self {
Tag::SharedMut(shared)
}
}
/// Read-only shared-memory string.
pub type SharedStr = Shared<str>;
impl Deref for SharedStr {
    type Target = str;
    // Delegates to `as_view`, which reinterprets the shared bytes as `&str`.
    fn deref(&self) -> &Self::Target {
        self.as_view()
    }
}
/// Exclusive (read-write) shared-memory string.
pub type SharedStrMut = SharedMut<str>;
impl Deref for SharedStrMut {
    type Target = str;
    // Delegates to `as_view`; panics if the handle was invalidated.
    fn deref(&self) -> &Self::Target {
        self.as_view()
    }
}
impl DerefMut for SharedStrMut {
    // Delegates to `as_view_mut`; panics if the handle was invalidated.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_view_mut()
    }
}
/// Read-only shared-memory slice of `Copy` elements.
pub type SharedSlice<T> = Shared<[T]>;
impl<T: Copy + 'static> Deref for SharedSlice<T> {
    type Target = [T];
    // Delegates to `as_view`, which reinterprets the shared bytes as `&[T]`.
    fn deref(&self) -> &Self::Target {
        self.as_view()
    }
}
/// Exclusive (read-write) shared-memory slice of `Copy` elements.
pub type SharedSliceMut<T> = SharedMut<[T]>;
impl<T: Copy + 'static> Deref for SharedSliceMut<T> {
    type Target = [T];
    // Delegates to `as_view`; panics if the handle was invalidated.
    fn deref(&self) -> &Self::Target {
        self.as_view()
    }
}
impl<T: Copy + 'static> DerefMut for SharedSliceMut<T> {
    // Delegates to `as_view_mut`; panics if the handle was invalidated.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_view_mut()
    }
}
#[cfg(feature = "shared_ndarray")]
pub use sharify_ndarray::{SharedArray, SharedArrayMut};
#[cfg(feature = "shared_ndarray")]
pub mod sharify_ndarray {
    //! Shared-memory backing for `ndarray` arrays.
    use super::*;
    use ndarray::{Array, ArrayView, ArrayViewMut, Dimension};
    pub type SharedArray<T, D> = Shared<Array<T, D>>;
    pub type SharedArrayMut<T, D> = SharedMut<Array<T, D>>;
    /// SAFETY: `T: Copy` keeps the buffer plain-old-data; the metadata stores
    /// the `(shape, strides)` pair needed to rebuild a view.
    // Fix: the impl previously declared a lifetime parameter `'a` that was
    // never used; it has been removed.
    unsafe impl<T, D> ShmemBacked for Array<T, D>
    where
        T: Copy,
        D: Dimension + Serialize + DeserializeOwned,
    {
        /// `(fill_value, dimension)`.
        type NewArg = (T, D);
        /// `(shape, strides)` as flat vectors; strides may be negative.
        type MetaData = (Vec<usize>, Vec<isize>);
        fn required_memory_arg((_, dim): &Self::NewArg) -> usize {
            dim.size() * std::mem::size_of::<T>()
        }
        fn required_memory_src(src: &Self) -> usize {
            src.len() * std::mem::size_of::<T>()
        }
        /// Fill the region with `init` and record the standard layout.
        fn new(data: &mut [u8], (init, dim): &Self::NewArg) -> Self::MetaData {
            // SAFETY: the caller provides exactly `dim.size() * size_of::<T>()`
            // aligned bytes (see `required_memory_arg` / `ShmemBase`).
            let data =
                unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, dim.size()) };
            // `slice::fill` replaces the previous manual element loop.
            data.fill(*init);
            let view = ArrayView::from_shape(dim.clone(), data).unwrap();
            let shape = Vec::from(view.shape());
            let strides = Vec::from(view.strides());
            (shape, strides)
        }
        /// Copy `src` element-by-element (handles non-contiguous sources).
        fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData {
            // SAFETY: the caller provides exactly `src.len() * size_of::<T>()`
            // aligned bytes (see `required_memory_src` / `ShmemBase`).
            let data =
                unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, src.len()) };
            let mut view = ArrayViewMut::from_shape(src.raw_dim(), data).unwrap();
            for (src, dst) in src.iter().zip(view.iter_mut()) {
                *dst = *src;
            }
            let shape = Vec::from(view.shape());
            let strides = Vec::from(view.strides());
            (shape, strides)
        }
    }
    impl<'a, T, D> ShmemView<'a> for Array<T, D>
    where
        T: Copy + Default + 'a,
        D: Dimension + Serialize + DeserializeOwned,
    {
        type View = ArrayView<'a, T, D>;
        /// Rebuild a read-only array view from raw bytes plus `(shape, strides)`.
        fn view(
            data: &'a [u8],
            (shape, strides): &'a <Self as ShmemBacked>::MetaData,
        ) -> Self::View {
            use ndarray::ShapeBuilder;
            // Fix: compare bytes with bytes. The previous check compared the
            // *element* count against the *byte* length, which let oversized
            // shapes slip past the assertion for any `T` wider than one byte.
            debug_assert!(shape.iter().product::<usize>() * std::mem::size_of::<T>() <= data.len());
            // SAFETY: the region holds at least `shape.product()` elements of
            // `T` (checked above in debug builds); alignment via `ShmemBase`.
            let data = unsafe {
                std::slice::from_raw_parts(data.as_ptr() as *const T, shape.iter().product())
            };
            let mut shape_dim = D::zeros(shape.len());
            for (src, dst) in shape.iter().zip(shape_dim.as_array_view_mut().iter_mut()) {
                *dst = *src;
            }
            let mut strides_dim = D::zeros(strides.len());
            for (src, dst) in strides
                .iter()
                .zip(strides_dim.as_array_view_mut().iter_mut())
            {
                // Negative strides survive the round-trip through `usize`
                // because the bit pattern is preserved (two's complement).
                *dst = *src as usize;
            }
            ArrayView::from_shape(shape_dim.strides(strides_dim), data).unwrap()
        }
    }
    impl<'a, T, D> ShmemViewMut<'a> for Array<T, D>
    where
        T: Copy + Default + 'a,
        D: Dimension + Serialize + DeserializeOwned,
    {
        type View = ArrayViewMut<'a, T, D>;
        /// Rebuild a mutable array view from raw bytes plus `(shape, strides)`.
        fn view_mut(
            data: &'a mut [u8],
            (shape, strides): &'a mut <Self as ShmemBacked>::MetaData,
        ) -> Self::View {
            use ndarray::ShapeBuilder;
            // Fix: compare bytes with bytes (see `view`).
            debug_assert!(shape.iter().product::<usize>() * std::mem::size_of::<T>() <= data.len());
            // SAFETY: as in `view`, and the `&mut` borrow is exclusive.
            let data = unsafe {
                std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, shape.iter().product())
            };
            let mut shape_dim = D::zeros(shape.len());
            for (src, dst) in shape.iter().zip(shape_dim.as_array_view_mut().iter_mut()) {
                *dst = *src;
            }
            let mut strides_dim = D::zeros(strides.len());
            for (src, dst) in strides
                .iter()
                .zip(strides_dim.as_array_view_mut().iter_mut())
            {
                *dst = *src as usize;
            }
            ArrayViewMut::from_shape(shape_dim.strides(strides_dim), data).unwrap()
        }
    }
}
#[cfg(test)]
mod tests {
    //! Round-trip and counter-protocol tests. The arrays asserted from
    //! `counts()` are `[access_count, serialized_count]` for `Shared` and
    //! `[access_count]` for `SharedMut`.
    use super::*;
    use bincode::{self, de, options};
    use rand::prelude::*;
    use std::thread;
    // Serialize a `Shared`, consuming it (its access count moves into the bytes).
    fn serialize_shared<T>(shared: Shared<T>) -> Vec<u8>
    where
        T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
    {
        let mut bytes = Vec::new();
        let mut serializer = bincode::Serializer::new(&mut bytes, options());
        shared.into_serialized(&mut serializer).unwrap();
        bytes
    }
    // Serialize a `SharedMut`, consuming it (exclusive access moves into the bytes).
    fn serialize_shared_mut<T>(shared: SharedMut<T>) -> Vec<u8>
    where
        T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
    {
        let mut bytes = Vec::new();
        let mut serializer = bincode::Serializer::new(&mut bytes, options());
        shared.into_serialized(&mut serializer).unwrap();
        bytes
    }
    // Deserialize any handle from bincode bytes, stringifying errors.
    fn deserialize<S: DeserializeOwned>(bytes: &[u8]) -> Result<S, String> {
        let mut deserializer = de::Deserializer::from_slice(bytes, options());
        S::deserialize(&mut deserializer).map_err(|e| format!("{}", e))
    }
    // Serialize + deserialize a `Shared`, returning the reconstructed handle.
    fn serialization_roundtrip_shared<T>(shared: Shared<T>) -> Shared<T>
    where
        T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
    {
        let mut bytes = Vec::new();
        let mut serializer = bincode::Serializer::new(&mut bytes, options());
        shared.into_serialized(&mut serializer).unwrap();
        let mut deserializer = de::Deserializer::from_slice(bytes.as_slice(), options());
        Shared::<T>::deserialize(&mut deserializer).unwrap()
    }
    // Serialize + deserialize a `SharedMut`, returning the reconstructed handle.
    fn serialization_roundtrip_shared_mut<T>(shared: SharedMut<T>) -> SharedMut<T>
    where
        T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
    {
        let mut bytes = Vec::new();
        let mut serializer = bincode::Serializer::new(&mut bytes, options());
        shared.into_serialized(&mut serializer).unwrap();
        let mut deserializer = de::Deserializer::from_slice(bytes.as_slice(), options());
        SharedMut::<T>::deserialize(&mut deserializer).unwrap()
    }
    // Round-trip a slice through `Shared` and compare contents.
    fn slice_check_src_shared<T>(slice: &[T])
    where
        T: Copy + PartialEq + std::fmt::Debug + ?Sized + 'static,
    {
        let shared = Shared::<[T]>::from(slice);
        let roundtrip = serialization_roundtrip_shared(shared);
        assert_eq!(roundtrip.as_view(), slice);
    }
    // Round-trip a slice through `SharedMut` and compare contents.
    fn slice_check_src_shared_mut<T>(slice: &[T])
    where
        T: Copy + PartialEq + std::fmt::Debug + ?Sized + 'static,
    {
        let shared = SharedMut::<[T]>::from(slice);
        let roundtrip = serialization_roundtrip_shared_mut(shared);
        assert_eq!(roundtrip.as_view(), slice);
    }
    #[test]
    fn shared_str() {
        let s = "sharify_test";
        let shared: Shared<str> = Shared::new(s).unwrap();
        let roundtrip = serialization_roundtrip_shared(shared);
        assert_eq!(roundtrip.as_view(), s);
    }
    #[test]
    fn shared_mut_str() {
        let s = "sharify_test";
        let shared: SharedMut<str> = SharedMut::new(s).unwrap();
        let roundtrip = serialization_roundtrip_shared_mut(shared);
        assert_eq!(roundtrip.as_view(), s);
    }
    // Element types to exercise in the slice round-trip tests.
    // NOTE(review): the `U64` variant actually holds `&[u16]` — consider
    // renaming it `U16`.
    enum Slice {
        Usize(&'static [usize]),
        U8(&'static [u8]),
        U64(&'static [u16]),
        I16(&'static [i16]),
        F64(&'static [f64]),
    }
    impl Slice {
        fn create_slices() -> Vec<Slice> {
            vec![
                Slice::Usize(&[1, 2, 3, 4, 5]),
                Slice::U8(&[1, 2, 3, 4, 5]),
                Slice::U64(&[1, 2, 3, 4, 5]),
                Slice::I16(&[1, 2, 3, 4, 5]),
                Slice::F64(&[1.0, 2.0, 3.0, 4.0, 5.0]),
            ]
        }
    }
    #[test]
    fn shared_slice() {
        let slice: &[usize] = &[0_usize, 0, 0, 0, 0];
        let shared: Shared<[usize]> = Shared::new(&(0_usize, 5)).unwrap();
        let roundtrip = serialization_roundtrip_shared(shared);
        assert_eq!(roundtrip.as_view(), slice);
    }
    #[test]
    fn shared_mut_slice() {
        let slice: &mut [usize] = &mut [1_usize, 2, 3, 4, 5];
        let mut shared: SharedMut<[usize]> = SharedMut::new(&(0_usize, 5)).unwrap();
        shared.deref_mut().copy_from_slice(slice);
        let roundtrip = serialization_roundtrip_shared_mut(shared);
        assert_eq!(roundtrip.as_view(), slice);
    }
    #[test]
    fn shared_slice_from_src() {
        let slices = Slice::create_slices();
        for s in slices {
            match s {
                Slice::Usize(s) => slice_check_src_shared(s),
                Slice::U8(s) => slice_check_src_shared(s),
                Slice::U64(s) => slice_check_src_shared(s),
                Slice::I16(s) => slice_check_src_shared(s),
                Slice::F64(s) => slice_check_src_shared(s),
            }
        }
    }
    #[test]
    fn shared_mut_slice_from_src() {
        let slices = Slice::create_slices();
        for s in slices {
            match s {
                Slice::Usize(s) => slice_check_src_shared_mut(s),
                Slice::U8(s) => slice_check_src_shared_mut(s),
                Slice::U64(s) => slice_check_src_shared_mut(s),
                Slice::I16(s) => slice_check_src_shared_mut(s),
                Slice::F64(s) => slice_check_src_shared_mut(s),
            }
        }
    }
    #[test]
    fn shared_memory() {
        // Each deserialization bumps the access count; dropping handles
        // decrements it; the final drop invalidates the serialized bytes.
        let shared: Shared<[usize]> = Shared::new(&(0_usize, 5)).unwrap();
        assert_eq!(shared.counts().unwrap(), [1, 0]);
        let bytes = serialize_shared(shared);
        let deser: Shared<[usize]> = deserialize(bytes.as_slice()).unwrap();
        assert_eq!(deser.counts().unwrap(), [1, 0]);
        let mut instances = Vec::new();
        for i in 1..=10 {
            let inst: Shared<[usize]> = deserialize(bytes.as_slice()).unwrap();
            assert_eq!(deser.counts().unwrap(), [1 + i, 0]);
            instances.push(inst);
        }
        assert_eq!(&[0_usize, 0, 0, 0, 0], deser.deref());
        std::mem::drop(instances);
        assert_eq!(deser.counts().unwrap(), [1, 0]);
        assert_eq!(&[0_usize, 0, 0, 0, 0], deser.deref());
        std::mem::drop(deser);
        assert!(deserialize::<Shared::<[usize]>>(bytes.as_slice()).is_err());
    }
    #[test]
    fn shared_mut_memory() {
        // Only one `SharedMut` may claim the serialized region at a time.
        let mut shared: SharedMut<[usize]> = SharedMut::new(&(0_usize, 5)).unwrap();
        assert_eq!(shared.counts().unwrap(), [1]);
        let bytes = serialize_shared_mut(shared);
        let mut deser: SharedMut<[usize]> = deserialize(bytes.as_slice()).unwrap();
        assert_eq!(deser.counts().unwrap(), [1]);
        assert!(deserialize::<SharedMut::<[usize]>>(bytes.as_slice()).is_err());
        assert_eq!(&[0_usize, 0, 0, 0, 0], deser.deref());
        assert_eq!(deser.counts().unwrap(), [1]);
        std::mem::drop(deser);
        assert!(deserialize::<SharedMut::<[usize]>>(bytes.as_slice()).is_err());
    }
    #[test]
    fn cross_serialization_from_shared() {
        // `Shared` bytes must not deserialize into a `SharedMut`.
        let shared: Shared<[usize]> = Shared::new(&(0_usize, 5)).unwrap();
        assert_eq!(&[0_usize, 0, 0, 0, 0], shared.deref());
        let bytes = serialize_shared(shared);
        assert!(deserialize::<SharedMut::<[usize]>>(bytes.as_slice()).is_err());
        let shared: Shared<[usize]> = deserialize(bytes.as_slice()).unwrap();
        assert_eq!(shared.counts().unwrap(), [1, 0]);
        assert_eq!(&[0_usize, 0, 0, 0, 0], shared.deref());
    }
    #[test]
    fn cross_serialization_from_shared_mut() {
        // `SharedMut` bytes must not deserialize into a `Shared`.
        let shared: SharedMut<[usize]> = SharedMut::new(&(0_usize, 5)).unwrap();
        assert_eq!(&[0_usize, 0, 0, 0, 0], shared.deref());
        let bytes = serialize_shared_mut(shared);
        assert!(deserialize::<Shared::<[usize]>>(bytes.as_slice()).is_err());
        let mut shared: SharedMut<[usize]> = deserialize(bytes.as_slice()).unwrap();
        assert_eq!(shared.counts().unwrap(), [1]);
        assert_eq!(&[0_usize, 0, 0, 0, 0], shared.deref());
    }
    #[test]
    fn shared_mut_into_shared() {
        // Downgrading keeps the data and resets counters to `[1, 0]`.
        let mut shared_mut: SharedMut<[usize]> = SharedMut::new(&(0_usize, 5)).unwrap();
        shared_mut.deref_mut().copy_from_slice(&[1, 2, 3, 4, 5]);
        let shared: Shared<[usize]> = shared_mut.try_into().unwrap();
        assert_eq!(shared.counts().unwrap(), [1, 0]);
        assert_eq!(&[1, 2, 3, 4, 5], shared.deref());
    }
    #[test]
    fn shared_clone() {
        // Clones bump the access count; serializations bump the serialized
        // count; dropping/deserializing winds both back down to `[1, 0]`.
        let shared = Shared::from(&[1_usize, 2, 3, 4, 5] as &[_]);
        let mut container = Vec::new();
        for i in 0..100 {
            let bytes = serialize_shared(shared.clone());
            container.push((bytes, shared.clone()));
            assert_eq!(shared.counts().unwrap(), [2 + i, 1 + i]);
        }
        assert_eq!(shared.counts().unwrap(), [101, 100]);
        for (i, (bytes, cl)) in container.into_iter().enumerate() {
            let mut _deser: Shared<[usize]> = deserialize(bytes.as_slice()).unwrap();
            assert_eq!(&[1_usize, 2, 3, 4, 5], cl.deref());
            assert_eq!(shared.counts().unwrap(), [102 - i as u64, 99 - i as u64]);
        }
        assert_eq!(shared.counts().unwrap(), [1, 0]);
        assert_eq!(&[1_usize, 2, 3, 4, 5], shared.deref());
    }
    #[test]
    fn races() {
        // Hammer serialize/deserialize from many threads with random sleeps to
        // shake out counter races; contents must stay intact throughout.
        let shared = Shared::from(&[1_usize, 2, 3, 4, 5] as &[_]);
        let mut handles = Vec::new();
        for _ in 0..50 {
            let (send, recv) = std::sync::mpsc::sync_channel(0);
            let bytes_send = serialize_shared(shared.clone());
            let handle = thread::spawn(move || {
                let mut rng = rand::thread_rng();
                // Wait for the start signal so all threads race together.
                recv.recv().unwrap();
                let mut shared: Shared<[usize]> = deserialize(bytes_send.as_slice()).unwrap();
                assert_eq!(&[1_usize, 2, 3, 4, 5], shared.deref());
                for _ in 0..1000 {
                    thread::sleep(std::time::Duration::from_millis(rng.gen_range(0..=5)));
                    let tmp = serialize_shared(shared);
                    thread::sleep(std::time::Duration::from_millis(rng.gen_range(0..=5)));
                    shared = deserialize(tmp.as_slice()).unwrap();
                    assert_eq!(&[1_usize, 2, 3, 4, 5], shared.deref());
                }
            });
            handles.push((handle, send));
        }
        thread::sleep(std::time::Duration::from_millis(100));
        for (_, send) in handles.iter() {
            send.send(()).unwrap();
        }
        for (handle, _) in handles {
            handle.join().unwrap();
        }
        assert_eq!(&[1_usize, 2, 3, 4, 5], shared.deref());
        assert_eq!(shared.counts().unwrap(), [1, 0]);
    }
    #[cfg(feature = "shared_ndarray")]
    mod ndarray_tests {
        use super::*;
        use ndarray::{Array, Axis, IxDyn};
        #[test]
        fn shared_ndarray() {
            let shared: SharedArray<u64, IxDyn> = Shared::new(&(0, IxDyn(&[3, 2]))).unwrap();
            let shared = serialization_roundtrip_shared(shared);
            assert_eq!(&[0; 6], shared.as_view().as_slice().unwrap());
        }
        #[test]
        fn shared_mut_ndarray() {
            let mut shared: SharedArrayMut<f64, IxDyn> =
                SharedMut::new(&(0.0, IxDyn(&[3, 2]))).unwrap();
            let slice: &[f64] = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
            for (&x, element) in slice.iter().zip(shared.as_view_mut().iter_mut()) {
                *element = x;
            }
            let roundtrip = serialization_roundtrip_shared_mut(shared);
            assert_eq!(slice, roundtrip.as_view().as_slice().unwrap());
        }
        #[test]
        fn shared_mut_array_layout() {
            // Swapping axes via `metadata_mut` must survive serialization.
            let mut array: SharedArrayMut<f64, ndarray::IxDyn> =
                SharedArrayMut::new(&(0.0, ndarray::IxDyn(&[100, 200, 300]))).unwrap();
            assert_eq!(array.as_view().strides(), &[200 * 300, 300, 1]);
            {
                let mut view = array.as_view_mut();
                assert!(view.is_standard_layout());
                view.swap_axes(0, 1);
                assert_eq!(view.strides(), &[300, 200 * 300, 1]);
                unsafe {
                    *array.metadata_mut() = (Vec::from(view.shape()), Vec::from(view.strides()));
                }
            }
            assert_eq!(array.as_view().shape(), &[200, 100, 300]);
            assert_eq!(array.as_view().strides(), &[300, 200 * 300, 1]);
            assert!(!array.as_view().is_standard_layout());
            let bytes = serialize_shared_mut(array);
            let deser: SharedArrayMut<f64, ndarray::IxDyn> = deserialize(bytes.as_slice()).unwrap();
            assert_eq!(deser.as_view().shape(), &[200, 100, 300]);
            assert_eq!(deser.as_view().strides(), &[300, 200 * 300, 1]);
            assert!(!deser.as_view().is_standard_layout());
        }
        #[test]
        fn shared_ndarray_from_src() {
            // A non-standard-layout source (inverted axis) must copy correctly.
            let mut src = Array::from_elem((100, 200, 300), 0_u64);
            src.invert_axis(Axis(1));
            let src_shape = Vec::from(src.shape());
            let array = Shared::new_from_inner(&src).unwrap();
            assert_eq!(src_shape, array.as_view().shape());
            assert!(src
                .iter()
                .zip(array.as_view().iter())
                .all(|(src, dst)| src == dst))
        }
    }
}