1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use bytes::Bytes;

use std::sync::Arc;

use crate::adapter::Adapter;
use crate::handler::{FromConnectParts, FromDisconnectParts, FromMessageParts};
use crate::socket::{DisconnectReason, Socket};

/// An Extractor that contains a [`Clone`] of a state previously set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder).
/// It implements [`std::ops::Deref`] to access the inner type so you can use it as a normal reference.
///
/// The specified state type must be the same as the one set with [`SocketIoBuilder::with_state`](crate::io::SocketIoBuilder).
/// If it is not the case, the handler won't be called and an error log will be print if the `tracing` feature is enabled.
///
/// The state is shared between the entire socket.io app context.
///
/// ### Example
/// ```
/// # use socketioxide::{SocketIo, extract::{SocketRef, State}};
/// # use serde::{Serialize, Deserialize};
/// # use std::sync::{Arc, atomic::{Ordering, AtomicUsize}};
/// #[derive(Default, Clone)]
/// struct MyAppData {
///     user_cnt: Arc<AtomicUsize>,
/// }
/// impl MyAppData {
///     fn add_user(&self) {
///         self.user_cnt.fetch_add(1, Ordering::SeqCst);
///     }
///     fn rm_user(&self) {
///         self.user_cnt.fetch_sub(1, Ordering::SeqCst);
///     }
/// }
/// let (_, io) = SocketIo::builder().with_state(MyAppData::default()).build_svc();
/// io.ns("/", |socket: SocketRef, state: State<MyAppData>| {
///     state.add_user();
///     println!("User count: {}", state.user_cnt.load(Ordering::SeqCst));
/// });
pub struct State<T>(pub T);

/// It was impossible to find the given state and therefore the handler won't be called.
pub struct StateNotFound<T>(std::marker::PhantomData<T>);

impl<T> std::fmt::Display for StateNotFound<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "State of type {} not found, maybe you forgot to insert it in the state map?",
            std::any::type_name::<T>()
        )
    }
}
impl<T> std::fmt::Debug for StateNotFound<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "StateNotFound {}", std::any::type_name::<T>())
    }
}
impl<T> std::error::Error for StateNotFound<T> {}

impl<A: Adapter, T: Clone + Send + Sync + 'static> FromConnectParts<A> for State<T> {
    type Error = StateNotFound<T>;
    fn from_connect_parts(
        s: &Arc<Socket<A>>,
        _: &Option<String>,
    ) -> Result<Self, StateNotFound<T>> {
        s.get_io()
            .get_state::<T>()
            .map(State)
            .ok_or(StateNotFound(std::marker::PhantomData))
    }
}
impl<A: Adapter, T: Clone + Send + Sync + 'static> FromDisconnectParts<A> for State<T> {
    type Error = StateNotFound<T>;
    fn from_disconnect_parts(
        s: &Arc<Socket<A>>,
        _: DisconnectReason,
    ) -> Result<Self, StateNotFound<T>> {
        s.get_io()
            .get_state::<T>()
            .map(State)
            .ok_or(StateNotFound(std::marker::PhantomData))
    }
}
impl<A: Adapter, T: Clone + Send + Sync + 'static> FromMessageParts<A> for State<T> {
    type Error = StateNotFound<T>;
    fn from_message_parts(
        s: &Arc<Socket<A>>,
        _: &mut serde_json::Value,
        _: &mut Vec<Bytes>,
        _: &Option<i64>,
    ) -> Result<Self, StateNotFound<T>> {
        s.get_io()
            .get_state::<T>()
            .map(State)
            .ok_or(StateNotFound(std::marker::PhantomData))
    }
}

super::__impl_deref!(State);