Skip to main content

silent/server/connection.rs

1use std::any::Any;
2use tokio::io::{AsyncRead, AsyncWrite};
3
4pub trait Connection: Any + AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static {
5    fn as_any(&self) -> &dyn Any;
6    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
7}
8
9impl<T> Connection for T
10where
11    T: Any + AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
12{
13    fn as_any(&self) -> &dyn Any {
14        self
15    }
16    fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
17        self
18    }
19}
20
21pub type BoxedConnection = Box<dyn Connection + Send + Sync>;
22
23impl dyn Connection + Send + Sync {
24    pub fn downcast<T: Any + Send + Sync + 'static>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
25        // 仅在类型匹配时才转换为 Any 并 downcast;否则直接返回原始 Box<Self>
26        if (*self).as_any().is::<T>() {
27            let boxed_any = Connection::into_any(self);
28            // SAFETY: 上面已经通过 is::<T>() 检查确保类型匹配
29            Ok(boxed_any.downcast::<T>().unwrap())
30        } else {
31            Err(self)
32        }
33    }
34}
35
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt, DuplexStream};

    #[tokio::test]
    async fn test_downcast_success_and_failure() {
        // `tokio::io::duplex` yields a connected pair of in-memory streams that
        // implement `AsyncRead`/`AsyncWrite`, so each end is a `Connection`.
        let (mut a, b) = tokio::io::duplex(64);
        let boxed: BoxedConnection = Box::new(b);

        // Success path: downcasting to the real type yields the same stream.
        // `match` is used instead of `expect`, because the `Err` payload
        // (`BoxedConnection`) does not implement `Debug`.
        let mut peer: Box<DuplexStream> = match boxed.downcast::<DuplexStream>() {
            Ok(stream) => stream,
            Err(_) => panic!("expected downcast to DuplexStream to succeed"),
        };
        peer.write_all(b"ping").await.unwrap();
        let mut buf = [0u8; 4];
        a.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"ping");

        // Failure path: a mismatched target type must hand back the original box.
        let (mut a2, b2) = tokio::io::duplex(32);
        let boxed2: BoxedConnection = Box::new(b2);
        let original: BoxedConnection = match boxed2.downcast::<tokio::net::TcpStream>() {
            Ok(_) => panic!("expected Err on mismatch"),
            Err(original) => original,
        };

        // Prove the returned value really is the original stream and is still
        // usable: it downcasts back to `DuplexStream` and delivers data to its peer.
        let mut recovered: Box<DuplexStream> = match original.downcast::<DuplexStream>() {
            Ok(stream) => stream,
            Err(_) => panic!("Err payload should still be the original DuplexStream"),
        };
        recovered.write_all(b"pong").await.unwrap();
        let mut buf2 = [0u8; 4];
        a2.read_exact(&mut buf2).await.unwrap();
        assert_eq!(&buf2, b"pong");
    }
}