//! Configuration types for the bidirectional WebSocket transport
//! (`turbomcp_websocket/config.rs`).

use std::time::Duration;

8#[derive(Clone, Debug)]
10pub struct WebSocketBidirectionalConfig {
11 pub url: Option<String>,
13
14 pub bind_addr: Option<String>,
16
17 pub max_message_size: usize,
19
20 pub keep_alive_interval: Duration,
22
23 pub reconnect: ReconnectConfig,
25
26 pub elicitation_timeout: Duration,
28
29 pub max_concurrent_elicitations: usize,
31
32 pub enable_compression: bool,
34
35 pub tls_config: Option<TlsConfig>,
37}
38
39impl Default for WebSocketBidirectionalConfig {
40 fn default() -> Self {
41 Self {
42 url: None,
43 bind_addr: None,
44 max_message_size: 16 * 1024 * 1024, keep_alive_interval: Duration::from_secs(30),
46 reconnect: ReconnectConfig::default(),
47 elicitation_timeout: Duration::from_secs(30),
48 max_concurrent_elicitations: 10,
49 enable_compression: false,
50 tls_config: None,
51 }
52 }
53}
54
55impl WebSocketBidirectionalConfig {
56 pub fn new() -> Self {
58 Self::default()
59 }
60
61 pub fn client(url: String) -> Self {
63 Self {
64 url: Some(url),
65 ..Self::default()
66 }
67 }
68
69 pub fn server(bind_addr: String) -> Self {
71 Self {
72 bind_addr: Some(bind_addr),
73 ..Self::default()
74 }
75 }
76
77 pub fn with_max_message_size(mut self, size: usize) -> Self {
79 self.max_message_size = size;
80 self
81 }
82
83 pub fn with_keep_alive_interval(mut self, interval: Duration) -> Self {
85 self.keep_alive_interval = interval;
86 self
87 }
88
89 pub fn with_reconnect_config(mut self, config: ReconnectConfig) -> Self {
91 self.reconnect = config;
92 self
93 }
94
95 pub fn with_elicitation_timeout(mut self, timeout: Duration) -> Self {
97 self.elicitation_timeout = timeout;
98 self
99 }
100
101 pub fn with_max_concurrent_elicitations(mut self, max: usize) -> Self {
103 self.max_concurrent_elicitations = max;
104 self
105 }
106
107 pub fn with_compression(mut self, enable: bool) -> Self {
109 self.enable_compression = enable;
110 self
111 }
112
113 pub fn with_tls_config(mut self, tls_config: TlsConfig) -> Self {
115 self.tls_config = Some(tls_config);
116 self
117 }
118}
119
/// Reconnection settings: an on/off switch, delay bounds, a backoff factor and a
/// retry limit.
///
/// Build one with [`ReconnectConfig::new`] (all defaults) and refine it with the
/// `with_*` methods. Each of those consumes `self` and returns the updated value.
#[derive(Clone, Debug)]
pub struct ReconnectConfig {
    /// Whether reconnection is enabled. Default: `true`.
    pub enabled: bool,

    /// Initial reconnection delay. Default: 500 ms.
    pub initial_delay: Duration,

    /// Maximum reconnection delay. Default: 30 seconds.
    pub max_delay: Duration,

    /// Backoff multiplier. Default: `2.0`.
    ///
    /// NOTE(review): presumably multiplies the delay after each failed attempt, capped
    /// at `max_delay`. [`with_backoff_factor`](Self::with_backoff_factor) does not
    /// validate it (NaN or values below `1.0` are accepted). Confirm against the
    /// reconnect loop.
    pub backoff_factor: f64,

    /// Maximum number of retries. Default: 10.
    pub max_retries: u32,
}

impl Default for ReconnectConfig {
    /// Returns the defaults documented on each field: enabled, 500 ms initial delay,
    /// 30 s maximum delay, backoff factor `2.0`, and 10 retries.
    fn default() -> Self {
        Self {
            enabled: true,
            initial_delay: Duration::from_millis(500),
            max_delay: Duration::from_secs(30),
            backoff_factor: 2.0,
            max_retries: 10,
        }
    }
}

impl ReconnectConfig {
    /// Creates a configuration with all default values (same as `Default::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the configuration with reconnection enabled or disabled.
    #[must_use]
    pub fn with_enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Returns the configuration with `initial_delay` set to `delay`.
    #[must_use]
    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Returns the configuration with `max_delay` set to `delay`.
    #[must_use]
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Returns the configuration with `backoff_factor` set to `factor`.
    ///
    /// The value is stored as given, without any validation.
    #[must_use]
    pub fn with_backoff_factor(mut self, factor: f64) -> Self {
        self.backoff_factor = factor;
        self
    }

    /// Returns the configuration with `max_retries` set to `retries`.
    #[must_use]
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }
}

/// TLS settings for the WebSocket transport.
///
/// By default every path is `None` and `skip_verify` is `false`. Use the preset
/// constructors ([`with_client_cert`](Self::with_client_cert),
/// [`with_ca_cert`](Self::with_ca_cert), [`insecure`](Self::insecure)) or
/// [`new`](Self::new) followed by the `with_*_path` / `with_skip_verify` builders.
#[derive(Clone, Debug, Default)]
pub struct TlsConfig {
    /// Path to the certificate file.
    pub cert_path: Option<String>,

    /// Path to the private key that goes with `cert_path`.
    pub key_path: Option<String>,

    /// Path to a CA certificate.
    pub ca_path: Option<String>,

    /// Skip verification when `true`. Set by [`insecure`](Self::insecure); avoid
    /// outside of testing.
    pub skip_verify: bool,
}

impl TlsConfig {
    /// Creates an empty TLS configuration (same as `Default::default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a configuration with only `cert_path` and `key_path` set.
    ///
    /// Despite the `with_` prefix this is a constructor, not a builder method.
    #[must_use]
    pub fn with_client_cert(cert_path: String, key_path: String) -> Self {
        Self {
            cert_path: Some(cert_path),
            key_path: Some(key_path),
            ..Self::default()
        }
    }

    /// Creates a configuration with only `ca_path` set.
    ///
    /// Despite the `with_` prefix this is a constructor, not a builder method.
    #[must_use]
    pub fn with_ca_cert(ca_path: String) -> Self {
        Self {
            ca_path: Some(ca_path),
            ..Self::default()
        }
    }

    /// Creates a configuration with `skip_verify` set to `true` and no paths.
    ///
    /// As the name says, this is insecure; intended for testing only.
    #[must_use]
    pub fn insecure() -> Self {
        Self {
            skip_verify: true,
            ..Self::default()
        }
    }

    /// Returns the configuration with `cert_path` set to `Some(path)`.
    #[must_use]
    pub fn with_cert_path(mut self, path: String) -> Self {
        self.cert_path = Some(path);
        self
    }

    /// Returns the configuration with `key_path` set to `Some(path)`.
    #[must_use]
    pub fn with_key_path(mut self, path: String) -> Self {
        self.key_path = Some(path);
        self
    }

    /// Returns the configuration with `ca_path` set to `Some(path)`.
    #[must_use]
    pub fn with_ca_path(mut self, path: String) -> Self {
        self.ca_path = Some(path);
        self
    }

    /// Returns the configuration with `skip_verify` set to `skip`.
    #[must_use]
    pub fn with_skip_verify(mut self, skip: bool) -> Self {
        self.skip_verify = skip;
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_websocket_config_default() {
        let config = WebSocketBidirectionalConfig::default();
        assert_eq!(config.url, None);
        assert_eq!(config.bind_addr, None);
        assert_eq!(config.max_message_size, 16 * 1024 * 1024);
        assert_eq!(config.keep_alive_interval, Duration::from_secs(30));
        assert_eq!(config.elicitation_timeout, Duration::from_secs(30));
        assert_eq!(config.max_concurrent_elicitations, 10);
        assert!(!config.enable_compression);
        assert!(config.tls_config.is_none());
        assert!(config.reconnect.enabled);
    }

    #[test]
    fn test_websocket_config_client() {
        let config = WebSocketBidirectionalConfig::client("ws://example.com".to_string());
        assert_eq!(config.url, Some("ws://example.com".to_string()));
        assert_eq!(config.bind_addr, None);
    }

    #[test]
    fn test_websocket_config_server() {
        let config = WebSocketBidirectionalConfig::server("0.0.0.0:8080".to_string());
        assert_eq!(config.bind_addr, Some("0.0.0.0:8080".to_string()));
        assert_eq!(config.url, None);
    }

    #[test]
    fn test_websocket_config_builder() {
        let config = WebSocketBidirectionalConfig::new()
            .with_max_message_size(1024)
            .with_keep_alive_interval(Duration::from_secs(60))
            .with_compression(true)
            .with_max_concurrent_elicitations(5);

        assert_eq!(config.max_message_size, 1024);
        assert_eq!(config.keep_alive_interval, Duration::from_secs(60));
        assert!(config.enable_compression);
        assert_eq!(config.max_concurrent_elicitations, 5);
    }

    #[test]
    fn test_websocket_config_nested_builders() {
        let config = WebSocketBidirectionalConfig::new()
            .with_elicitation_timeout(Duration::from_secs(5))
            .with_reconnect_config(ReconnectConfig::new().with_enabled(false))
            .with_tls_config(TlsConfig::insecure());

        assert_eq!(config.elicitation_timeout, Duration::from_secs(5));
        assert!(!config.reconnect.enabled);
        assert!(config.tls_config.map_or(false, |tls| tls.skip_verify));
    }

    #[test]
    fn test_reconnect_config_default() {
        let config = ReconnectConfig::new();
        assert!(config.enabled);
        assert_eq!(config.initial_delay, Duration::from_millis(500));
        assert_eq!(config.max_delay, Duration::from_secs(30));
        assert_eq!(config.backoff_factor, 2.0);
        assert_eq!(config.max_retries, 10);
    }

    #[test]
    fn test_reconnect_config_builder() {
        let config = ReconnectConfig::new()
            .with_enabled(false)
            .with_initial_delay(Duration::from_millis(100))
            .with_max_delay(Duration::from_secs(5))
            .with_backoff_factor(1.5)
            .with_max_retries(3);

        assert!(!config.enabled);
        assert_eq!(config.initial_delay, Duration::from_millis(100));
        assert_eq!(config.max_delay, Duration::from_secs(5));
        assert_eq!(config.backoff_factor, 1.5);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_tls_config_presets() {
        let client_cert =
            TlsConfig::with_client_cert("cert.pem".to_string(), "key.pem".to_string());
        assert_eq!(client_cert.cert_path, Some("cert.pem".to_string()));
        assert_eq!(client_cert.key_path, Some("key.pem".to_string()));

        let ca_cert = TlsConfig::with_ca_cert("ca.pem".to_string());
        assert_eq!(ca_cert.ca_path, Some("ca.pem".to_string()));

        let insecure = TlsConfig::insecure();
        assert!(insecure.skip_verify);
    }

    #[test]
    fn test_tls_config_builder() {
        let config = TlsConfig::new()
            .with_cert_path("c.pem".to_string())
            .with_key_path("k.pem".to_string())
            .with_ca_path("ca.pem".to_string())
            .with_skip_verify(true);

        assert_eq!(config.cert_path, Some("c.pem".to_string()));
        assert_eq!(config.key_path, Some("k.pem".to_string()));
        assert_eq!(config.ca_path, Some("ca.pem".to_string()));
        assert!(config.skip_verify);
    }
}