1pub mod framing;
46pub mod handler;
47pub mod service;
48pub mod streaming;
49
50pub use framing::parse_grpc_client_stream;
52pub use handler::{GrpcHandler, GrpcHandlerResult, GrpcRequestData, GrpcResponseData, RpcMode};
53pub use service::{GenericGrpcService, copy_metadata, is_grpc_request, parse_grpc_path};
54pub use streaming::{MessageStream, StreamingRequest, StreamingResponse};
55
56use serde::{Deserialize, Serialize};
57use std::collections::HashMap;
58use std::sync::Arc;
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct GrpcConfig {
95 #[serde(default = "default_true")]
97 pub enabled: bool,
98
99 #[serde(default = "default_max_message_size")]
112 pub max_message_size: usize,
113
114 #[serde(default = "default_true")]
116 pub enable_compression: bool,
117
118 #[serde(default)]
120 pub request_timeout: Option<u64>,
121
122 #[serde(default = "default_max_concurrent_streams")]
142 pub max_concurrent_streams: u32,
143
144 #[serde(default = "default_true")]
146 pub enable_keepalive: bool,
147
148 #[serde(default = "default_keepalive_interval")]
150 pub keepalive_interval: u64,
151
152 #[serde(default = "default_keepalive_timeout")]
154 pub keepalive_timeout: u64,
155 }
158
159impl Default for GrpcConfig {
160 fn default() -> Self {
161 Self {
162 enabled: true,
163 max_message_size: default_max_message_size(),
164 enable_compression: true,
165 request_timeout: None,
166 max_concurrent_streams: default_max_concurrent_streams(),
167 enable_keepalive: true,
168 keepalive_interval: default_keepalive_interval(),
169 keepalive_timeout: default_keepalive_timeout(),
170 }
171 }
172}
173
// Fallback values for `GrpcConfig`. They are exposed as `const fn`s because
// `#[serde(default = "...")]` names a function to call rather than a constant.

/// Default for `GrpcConfig::max_message_size`: `4 << 20` (== `4 * 1024 * 1024`),
/// i.e. 4 MiB when counted in bytes.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 4 << 20;

/// Default for `GrpcConfig::max_concurrent_streams`.
const DEFAULT_MAX_CONCURRENT_STREAMS: u32 = 100;

/// Default for `GrpcConfig::keepalive_interval` (unit not stated here; presumably seconds).
const DEFAULT_KEEPALIVE_INTERVAL: u64 = 75;

/// Default for `GrpcConfig::keepalive_timeout` (unit not stated here; presumably seconds).
const DEFAULT_KEEPALIVE_TIMEOUT: u64 = 20;

/// Serde default for boolean switches that stay on unless explicitly disabled.
const fn default_true() -> bool {
    true
}

/// Serde default for `max_message_size`.
const fn default_max_message_size() -> usize {
    DEFAULT_MAX_MESSAGE_SIZE
}

/// Serde default for `max_concurrent_streams`.
const fn default_max_concurrent_streams() -> u32 {
    DEFAULT_MAX_CONCURRENT_STREAMS
}

/// Serde default for `keepalive_interval`.
const fn default_keepalive_interval() -> u64 {
    DEFAULT_KEEPALIVE_INTERVAL
}

/// Serde default for `keepalive_timeout`.
const fn default_keepalive_timeout() -> u64 {
    DEFAULT_KEEPALIVE_TIMEOUT
}
193
/// A registered handler paired with the [`RpcMode`] it was registered under.
type GrpcHandlerEntry = (Arc<dyn GrpcHandler>, RpcMode);

/// Table mapping service names to their gRPC handler and RPC mode.
///
/// The map sits behind an [`Arc`], so cloning a registry only bumps a
/// reference count and clones share storage. `register` mutates through
/// [`Arc::make_mut`], which first copies the map when it is shared. As a
/// result, changes made through one clone are not visible through the others.
#[derive(Clone)]
pub struct GrpcRegistry {
    handlers: Arc<HashMap<String, GrpcHandlerEntry>>,
}
215
216impl GrpcRegistry {
217 pub fn new() -> Self {
219 Self {
220 handlers: Arc::new(HashMap::new()),
221 }
222 }
223
224 pub fn register(&mut self, service_name: impl Into<String>, handler: Arc<dyn GrpcHandler>, rpc_mode: RpcMode) {
232 let handlers = Arc::make_mut(&mut self.handlers);
233 handlers.insert(service_name.into(), (handler, rpc_mode));
234 }
235
236 pub fn get(&self, service_name: &str) -> Option<(Arc<dyn GrpcHandler>, RpcMode)> {
241 self.handlers.get(service_name).cloned()
242 }
243
244 pub fn service_names(&self) -> Vec<String> {
246 self.handlers.keys().cloned().collect()
247 }
248
249 pub fn contains(&self, service_name: &str) -> bool {
251 self.handlers.contains_key(service_name)
252 }
253
254 pub fn len(&self) -> usize {
256 self.handlers.len()
257 }
258
259 pub fn is_empty(&self) -> bool {
261 self.handlers.is_empty()
262 }
263}
264
265impl Default for GrpcRegistry {
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
#[cfg(test)]
mod tests {
    // `super::*` already brings in the handler types re-exported by this module
    // (GrpcHandler, GrpcHandlerResult, GrpcRequestData, GrpcResponseData, RpcMode).
    use super::*;
    use std::future::Future;
    use std::pin::Pin;

    /// Handler stub: every call resolves to an empty payload with empty metadata.
    struct TestHandler;

    impl GrpcHandler for TestHandler {
        fn call(&self, _request: GrpcRequestData) -> Pin<Box<dyn Future<Output = GrpcHandlerResult> + Send>> {
            Box::pin(async {
                Ok(GrpcResponseData {
                    payload: bytes::Bytes::new(),
                    metadata: tonic::metadata::MetadataMap::new(),
                })
            })
        }

        fn service_name(&self) -> &'static str {
            "test.Service"
        }
    }

    /// A config in which every field differs from its default.
    ///
    /// With this config, a serde round-trip cannot pass just because missing
    /// fields fall back to their defaults.
    fn non_default_config() -> GrpcConfig {
        GrpcConfig {
            enabled: false,
            max_message_size: 1024,
            enable_compression: false,
            request_timeout: Some(30),
            max_concurrent_streams: 7,
            enable_keepalive: false,
            keepalive_interval: 11,
            keepalive_timeout: 3,
        }
    }

    /// Asserts that two configs agree on every field.
    fn assert_config_eq(actual: &GrpcConfig, expected: &GrpcConfig) {
        assert_eq!(actual.enabled, expected.enabled);
        assert_eq!(actual.max_message_size, expected.max_message_size);
        assert_eq!(actual.enable_compression, expected.enable_compression);
        assert_eq!(actual.request_timeout, expected.request_timeout);
        assert_eq!(actual.max_concurrent_streams, expected.max_concurrent_streams);
        assert_eq!(actual.enable_keepalive, expected.enable_keepalive);
        assert_eq!(actual.keepalive_interval, expected.keepalive_interval);
        assert_eq!(actual.keepalive_timeout, expected.keepalive_timeout);
    }

    #[test]
    fn test_grpc_config_default() {
        let config = GrpcConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_message_size, 4 * 1024 * 1024);
        assert!(config.enable_compression);
        assert!(config.request_timeout.is_none());
        assert_eq!(config.max_concurrent_streams, 100);
        assert!(config.enable_keepalive);
        assert_eq!(config.keepalive_interval, 75);
        assert_eq!(config.keepalive_timeout, 20);
    }

    #[test]
    fn test_grpc_config_serialization() {
        let config = non_default_config();
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: GrpcConfig = serde_json::from_str(&json).unwrap();

        assert_config_eq(&deserialized, &config);
    }

    #[test]
    fn test_grpc_config_empty_object_uses_defaults() {
        let config: GrpcConfig = serde_json::from_str("{}").unwrap();
        assert_config_eq(&config, &GrpcConfig::default());
    }

    #[test]
    fn test_grpc_config_partial_object_keeps_other_defaults() {
        let config: GrpcConfig = serde_json::from_str(r#"{"enabled": false, "request_timeout": 30}"#).unwrap();
        let expected = GrpcConfig {
            enabled: false,
            request_timeout: Some(30),
            ..GrpcConfig::default()
        };
        assert_config_eq(&config, &expected);
    }

    #[test]
    fn test_grpc_registry_new() {
        let registry = GrpcRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_grpc_registry_register() {
        let mut registry = GrpcRegistry::new();
        let handler = Arc::new(TestHandler);

        registry.register("test.Service", handler, RpcMode::Unary);

        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains("test.Service"));
    }

    #[test]
    fn test_grpc_registry_register_replaces_existing() {
        let mut registry = GrpcRegistry::new();
        registry.register("test.Service", Arc::new(TestHandler), RpcMode::Unary);
        registry.register("test.Service", Arc::new(TestHandler), RpcMode::ServerStreaming);

        assert_eq!(registry.len(), 1);
        let (_, mode) = registry.get("test.Service").unwrap();
        assert_eq!(mode, RpcMode::ServerStreaming);
    }

    #[test]
    fn test_grpc_registry_get() {
        let mut registry = GrpcRegistry::new();
        let handler = Arc::new(TestHandler);

        registry.register("test.Service", handler, RpcMode::Unary);

        let retrieved = registry.get("test.Service");
        assert!(retrieved.is_some());
        let (handler, rpc_mode) = retrieved.unwrap();
        assert_eq!(handler.service_name(), "test.Service");
        assert_eq!(rpc_mode, RpcMode::Unary);
    }

    #[test]
    fn test_grpc_registry_get_nonexistent() {
        let registry = GrpcRegistry::new();
        let result = registry.get("nonexistent.Service");
        assert!(result.is_none());
    }

    #[test]
    fn test_grpc_registry_service_names() {
        let mut registry = GrpcRegistry::new();

        registry.register("service1", Arc::new(TestHandler), RpcMode::Unary);
        registry.register("service2", Arc::new(TestHandler), RpcMode::ServerStreaming);
        registry.register("service3", Arc::new(TestHandler), RpcMode::Unary);

        let mut names = registry.service_names();
        names.sort();

        assert_eq!(names, vec!["service1", "service2", "service3"]);
    }

    #[test]
    fn test_grpc_registry_contains() {
        let mut registry = GrpcRegistry::new();
        registry.register("test.Service", Arc::new(TestHandler), RpcMode::Unary);

        assert!(registry.contains("test.Service"));
        assert!(!registry.contains("other.Service"));
    }

    #[test]
    fn test_grpc_registry_multiple_services() {
        let mut registry = GrpcRegistry::new();

        registry.register("user.Service", Arc::new(TestHandler), RpcMode::Unary);
        registry.register("post.Service", Arc::new(TestHandler), RpcMode::ServerStreaming);

        assert_eq!(registry.len(), 2);
        assert!(registry.contains("user.Service"));
        assert!(registry.contains("post.Service"));
    }

    #[test]
    fn test_grpc_registry_clone() {
        let mut registry = GrpcRegistry::new();
        registry.register("test.Service", Arc::new(TestHandler), RpcMode::Unary);

        let cloned = registry.clone();

        assert_eq!(cloned.len(), 1);
        assert!(cloned.contains("test.Service"));
    }

    #[test]
    fn test_grpc_registry_clone_is_independent() {
        let mut registry = GrpcRegistry::new();
        registry.register("test.Service", Arc::new(TestHandler), RpcMode::Unary);
        let snapshot = registry.clone();

        // Mutating one clone must not leak into the other (copy-on-write).
        registry.register("other.Service", Arc::new(TestHandler), RpcMode::Unary);

        assert_eq!(registry.len(), 2);
        assert_eq!(snapshot.len(), 1);
        assert!(!snapshot.contains("other.Service"));
    }

    #[test]
    fn test_grpc_registry_default() {
        let registry = GrpcRegistry::default();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_grpc_registry_rpc_mode_storage() {
        let mut registry = GrpcRegistry::new();

        registry.register("unary.Service", Arc::new(TestHandler), RpcMode::Unary);
        registry.register("server_stream.Service", Arc::new(TestHandler), RpcMode::ServerStreaming);
        registry.register("client_stream.Service", Arc::new(TestHandler), RpcMode::ClientStreaming);
        registry.register("bidi.Service", Arc::new(TestHandler), RpcMode::BidirectionalStreaming);

        let (_, mode) = registry.get("unary.Service").unwrap();
        assert_eq!(mode, RpcMode::Unary);

        let (_, mode) = registry.get("server_stream.Service").unwrap();
        assert_eq!(mode, RpcMode::ServerStreaming);

        let (_, mode) = registry.get("client_stream.Service").unwrap();
        assert_eq!(mode, RpcMode::ClientStreaming);

        let (_, mode) = registry.get("bidi.Service").unwrap();
        assert_eq!(mode, RpcMode::BidirectionalStreaming);
    }
}