Skip to main content

smcp_computer/mcp_clients/
subscription_manager.rs

1/**
2* 文件名: subscription_manager
3* 作者: Claude Code
4* 创建日期: 2025-12-26
5* 最后修改日期: 2025-12-26
6* 版权: 2023 JQQ. All rights reserved.
7* 依赖: tokio, smcp-computer
8* 描述: 资源订阅状态管理器 / Resource subscription state manager
9*
10* ================================================================================
11* 功能说明 / Functionality
12* ================================================================================
13*
14* 本模块实现了资源订阅的本地状态管理,包括:
15* - 记录当前订阅的资源列表
16* - 提供订阅状态查询接口
17* - 支持订阅的添加、删除和查询
18* - 线程安全的状态管理
19*
20* ================================================================================
21*/
22use crate::mcp_clients::model::Resource;
23use std::collections::HashSet;
24use std::sync::Arc;
25use tokio::sync::RwLock;
26
27/// 订阅记录
28#[derive(Debug, Clone)]
29pub struct Subscription {
30    /// 资源 URI
31    pub uri: String,
32    /// 订阅时间戳
33    pub subscribed_at: std::time::Instant,
34    /// 资源元数据
35    pub resource: Resource,
36}
37
38impl Subscription {
39    /// 创建新的订阅记录
40    pub fn new(resource: Resource) -> Self {
41        Self {
42            uri: resource.uri.clone(),
43            subscribed_at: std::time::Instant::now(),
44            resource,
45        }
46    }
47
48    /// 检查订阅是否已过期(基于 TTL)
49    pub fn is_expired(&self, ttl: std::time::Duration) -> bool {
50        self.subscribed_at.elapsed() > ttl
51    }
52}
53
54/// 订阅状态管理器
55///
56/// 负责管理资源订阅的本地状态,提供线程安全的订阅管理接口。
57#[derive(Debug, Clone)]
58pub struct SubscriptionManager {
59    /// 订阅列表(使用 Arc<RwLock> 保证线程安全)
60    subscriptions: Arc<RwLock<HashSet<String>>>,
61}
62
63impl SubscriptionManager {
64    /// 创建新的订阅管理器
65    pub fn new() -> Self {
66        Self {
67            subscriptions: Arc::new(RwLock::new(HashSet::new())),
68        }
69    }
70
71    /// 添加订阅
72    ///
73    /// # 参数
74    /// - `uri`: 资源 URI
75    ///
76    /// # 返回
77    /// - `Ok(true)`: 新增订阅
78    /// - `Ok(false)`: 已存在,无需重复订阅
79    pub async fn add_subscription(&self, uri: String) -> Result<bool, String> {
80        let mut subs = self.subscriptions.write().await;
81        let is_new = subs.insert(uri.clone());
82        Ok(is_new)
83    }
84
85    /// 移除订阅
86    ///
87    /// # 参数
88    /// - `uri`: 资源 URI
89    ///
90    /// # 返回
91    /// - `Ok(true)`: 找到并移除
92    /// - `Ok(false)`: 未找到
93    pub async fn remove_subscription(&self, uri: &str) -> Result<bool, String> {
94        let mut subs = self.subscriptions.write().await;
95        let removed = subs.remove(uri);
96        Ok(removed)
97    }
98
99    /// 检查是否已订阅
100    ///
101    /// # 参数
102    /// - `uri`: 资源 URI
103    ///
104    /// # 返回
105    /// - `true`: 已订阅
106    /// - `false`: 未订阅
107    pub async fn is_subscribed(&self, uri: &str) -> bool {
108        let subs = self.subscriptions.read().await;
109        subs.contains(uri)
110    }
111
112    /// 获取所有订阅的 URI 列表
113    ///
114    /// # 返回
115    /// - 所有订阅 URI 的向量
116    pub async fn get_subscriptions(&self) -> Vec<String> {
117        let subs = self.subscriptions.read().await;
118        subs.iter().cloned().collect()
119    }
120
121    /// 获取订阅数量
122    ///
123    /// # 返回
124    /// - 当前订阅总数
125    pub async fn subscription_count(&self) -> usize {
126        let subs = self.subscriptions.read().await;
127        subs.len()
128    }
129
130    /// 清空所有订阅
131    pub async fn clear(&self) {
132        let mut subs = self.subscriptions.write().await;
133        subs.clear();
134    }
135
136    /// 批量添加订阅
137    ///
138    /// # 参数
139    /// - `uris`: 资源 URI 列表
140    ///
141    /// # 返回
142    /// - 成功添加的数量
143    pub async fn add_subscriptions_batch(&self, uris: Vec<String>) -> usize {
144        let mut subs = self.subscriptions.write().await;
145        let mut added = 0;
146        for uri in uris {
147            if subs.insert(uri) {
148                added += 1;
149            }
150        }
151        added
152    }
153}
154
155impl Default for SubscriptionManager {
156    fn default() -> Self {
157        Self::new()
158    }
159}
160
161#[cfg(test)]
162mod tests {
163    use super::*;
164
165    #[tokio::test]
166    async fn test_add_and_check_subscription() {
167        let manager = SubscriptionManager::new();
168
169        // 添加订阅
170        let result = manager.add_subscription("window://test".to_string()).await;
171        assert!(result.is_ok());
172        assert!(result.unwrap());
173
174        // 检查订阅
175        assert!(manager.is_subscribed("window://test").await);
176
177        // 重复添加应该返回 false
178        let result = manager.add_subscription("window://test".to_string()).await;
179        assert!(result.is_ok());
180        assert!(!result.unwrap());
181    }
182
183    #[tokio::test]
184    async fn test_remove_subscription() {
185        let manager = SubscriptionManager::new();
186
187        manager
188            .add_subscription("window://test".to_string())
189            .await
190            .unwrap();
191        assert!(manager.is_subscribed("window://test").await);
192
193        // 移除订阅
194        let removed = manager.remove_subscription("window://test").await.unwrap();
195        assert!(removed);
196        assert!(!manager.is_subscribed("window://test").await);
197
198        // 再次移除应该返回 false
199        let removed = manager.remove_subscription("window://test").await.unwrap();
200        assert!(!removed);
201    }
202
203    #[tokio::test]
204    async fn test_get_subscriptions() {
205        let manager = SubscriptionManager::new();
206
207        manager
208            .add_subscription("window://test1".to_string())
209            .await
210            .unwrap();
211        manager
212            .add_subscription("window://test2".to_string())
213            .await
214            .unwrap();
215
216        let subs = manager.get_subscriptions().await;
217        assert_eq!(subs.len(), 2);
218        assert!(subs.contains(&"window://test1".to_string()));
219        assert!(subs.contains(&"window://test2".to_string()));
220    }
221
222    #[tokio::test]
223    async fn test_clear_subscriptions() {
224        let manager = SubscriptionManager::new();
225
226        manager
227            .add_subscription("window://test1".to_string())
228            .await
229            .unwrap();
230        manager
231            .add_subscription("window://test2".to_string())
232            .await
233            .unwrap();
234
235        assert_eq!(manager.subscription_count().await, 2);
236
237        manager.clear().await;
238        assert_eq!(manager.subscription_count().await, 0);
239    }
240}