use crate::error::{IoError, Result};
use crate::ml_framework::{DataType, MLModel, MLTensor, TensorMetadata};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::sync::RwLock as StdRwLock;
use std::time::{Duration, Instant};

#[cfg(feature = "async")]
use tokio::sync::{Mutex, RwLock};
#[cfg(feature = "async")]
use tokio::time::{sleep, timeout};

/// Model serving runtime (async build): wraps a shared model behind Tokio
/// synchronization primitives, with request queueing, health status, and metrics.
#[cfg(feature = "async")]
pub struct ModelServer {
    model: Arc<RwLock<MLModel>>,
    config: ServerConfig,
    metrics: Arc<Mutex<ServerMetrics>>,
    request_queue: Arc<Mutex<VecDeque<InferenceRequest>>>,
    health_status: Arc<RwLock<HealthStatus>>,
}

/// Model serving runtime (sync build): same shape as the async server, but backed
/// by `std::sync::RwLock` and without a request queue.
#[cfg(not(feature = "async"))]
pub struct ModelServer {
    model: Arc<StdRwLock<MLModel>>,
    config: ServerConfig,
    metrics: Arc<StdRwLock<ServerMetrics>>,
    health_status: Arc<StdRwLock<HealthStatus>>,
}

/// Runtime configuration for a `ModelServer`.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    pub max_batch_size: usize,
    pub timeout_ms: u64,
    pub num_workers: usize,
    pub enable_batching: bool,
    pub batch_timeout_ms: u64,
    pub max_queue_size: usize,
    pub enable_streaming: bool,
    pub api_config: ApiConfig,
}

/// Configuration for the REST and gRPC front ends.
#[derive(Debug, Clone)]
pub struct ApiConfig {
    pub rest_enabled: bool,
    pub grpc_enabled: bool,
    pub rest_port: u16,
    pub grpc_port: u16,
    pub enable_cors: bool,
    pub enable_auth: bool,
    pub auth_token: Option<String>,
    pub rate_limit: Option<RateLimit>,
}

/// Request rate limits applied by the API layer.
#[derive(Debug, Clone)]
pub struct RateLimit {
    pub requests_per_minute: u32,
    pub requests_per_hour: u32,
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            max_batch_size: 32,
            timeout_ms: 5000,
            num_workers: 4,
            enable_batching: true,
            batch_timeout_ms: 100,
            max_queue_size: 1000,
            enable_streaming: false,
            api_config: ApiConfig::default(),
        }
    }
}

impl Default for ApiConfig {
    fn default() -> Self {
        Self {
            rest_enabled: true,
            grpc_enabled: false,
            rest_port: 8080,
            grpc_port: 9090,
            enable_cors: true,
            enable_auth: false,
            auth_token: None,
            rate_limit: None,
        }
    }
}
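
// A minimal configuration sketch (test-only): overrides a few fields via struct
// update syntax. The port, token, and limit values below are illustrative
// assumptions, not recommended settings.
#[cfg(test)]
mod config_example {
    use super::*;

    #[test]
    fn builds_a_customized_config() {
        let config = ServerConfig {
            max_batch_size: 64,
            enable_streaming: true,
            api_config: ApiConfig {
                rest_port: 8000,
                enable_auth: true,
                auth_token: Some("example-token".to_string()),
                rate_limit: Some(RateLimit {
                    requests_per_minute: 600,
                    requests_per_hour: 10_000,
                }),
                ..ApiConfig::default()
            },
            ..ServerConfig::default()
        };

        assert_eq!(config.max_batch_size, 64);
        assert!(config.api_config.enable_auth);
    }
}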

/// A single inference request submitted to the server.
#[derive(Debug, Clone)]
pub struct InferenceRequest {
    pub id: String,
    pub inputs: HashMap<String, MLTensor>,
    pub metadata: HashMap<String, serde_json::Value>,
    pub timestamp: Instant,
    pub timeout: Duration,
}

/// The server's reply to an `InferenceRequest`.
#[derive(Debug, Clone)]
pub struct InferenceResponse {
    pub request_id: String,
    pub outputs: HashMap<String, MLTensor>,
    pub metadata: HashMap<String, serde_json::Value>,
    pub processing_time_ms: u64,
    pub status: ResponseStatus,
}

/// Outcome of an inference request.
#[derive(Debug, Clone)]
pub enum ResponseStatus {
    Success,
    Error { code: u16, message: String },
    Timeout,
    QueueFull,
}
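
// A sketch of building a request by hand (test-only). The tensor shape, request id,
// and timeout are arbitrary illustration values; `MLTensor::new` is used exactly as
// it is elsewhere in this module.
#[cfg(test)]
mod request_example {
    use super::*;
    use crate::ml_framework::MLTensor;
    use scirs2_core::ndarray::{ArrayD, IxDyn};
    use std::collections::HashMap;
    use std::time::{Duration, Instant};

    #[test]
    fn builds_an_inference_request() {
        let data = ArrayD::from_shape_vec(IxDyn(&[1, 4]), vec![0.0_f32; 4]).unwrap();
        let mut inputs = HashMap::new();
        inputs.insert(
            "input".to_string(),
            MLTensor::new(data, Some("input".to_string())),
        );

        let request = InferenceRequest {
            id: "example-request-1".to_string(),
            inputs,
            metadata: HashMap::new(),
            timestamp: Instant::now(),
            timeout: Duration::from_millis(5000),
        };

        assert_eq!(request.inputs.len(), 1);
    }
}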

/// Aggregate serving metrics, updated as requests are processed.
#[derive(Debug, Clone, Default)]
pub struct ServerMetrics {
    pub total_requests: u64,
    pub successful_requests: u64,
    pub failed_requests: u64,
    pub average_latency_ms: f64,
    pub requests_per_second: f64,
    pub current_queue_size: usize,
    pub max_queue_size_reached: usize,
    pub model_load_time_ms: u64,
    pub uptime_seconds: u64,
    pub batch_stats: BatchStats,
}

/// Metrics specific to batched execution.
#[derive(Debug, Clone, Default)]
pub struct BatchStats {
    pub total_batches: u64,
    pub average_batch_size: f64,
    pub batch_processing_time_ms: f64,
}

/// Coarse health state exposed by the server.
#[derive(Debug, Clone)]
pub enum HealthStatus {
    Healthy,
    Degraded { reason: String },
    Unhealthy { reason: String },
    Starting,
    Stopping,
}

#[cfg(feature = "async")]
impl ModelServer {
    /// Creates a new server around an already-loaded model.
    pub async fn new(model: MLModel, config: ServerConfig) -> Self {
        Self {
            model: Arc::new(RwLock::new(model)),
            config,
            metrics: Arc::new(Mutex::new(ServerMetrics::default())),
            request_queue: Arc::new(Mutex::new(VecDeque::new())),
            health_status: Arc::new(RwLock::new(HealthStatus::Starting)),
        }
    }

    /// Marks the server healthy and starts background tasks and API front ends.
    pub async fn start(&self) -> Result<()> {
        // Transition from Starting to Healthy.
        {
            let mut status = self.health_status.write().await;
            *status = HealthStatus::Healthy;
        }

        // Background metrics collection.
        self.start_metrics_collection().await;

        // Worker tasks that drain the request queue.
        self.start_workers().await?;

        // Optional API front ends.
        if self.config.api_config.rest_enabled {
            self.start_rest_api().await?;
        }

        if self.config.api_config.grpc_enabled {
            self.start_grpc_api().await?;
        }

        Ok(())
    }

    /// Handles a single inference request, returning a response even on failure.
    pub async fn infer(&self, request: InferenceRequest) -> Result<InferenceResponse> {
        let start_time = Instant::now();

        // Reject the request early if the queue is already full.
        {
            let queue = self.request_queue.lock().await;
            if queue.len() >= self.config.max_queue_size {
                return Ok(InferenceResponse {
                    request_id: request.id,
                    outputs: HashMap::new(),
                    metadata: HashMap::new(),
                    processing_time_ms: start_time.elapsed().as_millis() as u64,
                    status: ResponseStatus::QueueFull,
                });
            }
        }

        // When batching is enabled, a copy of the request is enqueued for the
        // background workers and this call waits out the batch window before
        // running the inference itself (placeholder batching behaviour).
        if self.config.enable_batching {
            {
                let mut queue = self.request_queue.lock().await;
                queue.push_back(request.clone());
            }

            sleep(Duration::from_millis(self.config.batch_timeout_ms)).await;
        }

        let result = self.process_inference(&request.inputs).await;

        self.update_metrics(start_time, result.is_ok()).await;

        match result {
            Ok(outputs) => Ok(InferenceResponse {
                request_id: request.id,
                outputs,
                metadata: HashMap::new(),
                processing_time_ms: start_time.elapsed().as_millis() as u64,
                status: ResponseStatus::Success,
            }),
            Err(e) => Ok(InferenceResponse {
                request_id: request.id,
                outputs: HashMap::new(),
                metadata: HashMap::new(),
                processing_time_ms: start_time.elapsed().as_millis() as u64,
                status: ResponseStatus::Error {
                    code: 500,
                    message: e.to_string(),
                },
            }),
        }
    }

    /// Handles a set of requests in chunks of at most `max_batch_size`.
    pub async fn batch_infer(
        &self,
        requests: Vec<InferenceRequest>,
    ) -> Result<Vec<InferenceResponse>> {
        let start_time = Instant::now();
        let mut responses = Vec::new();

        for batch in requests.chunks(self.config.max_batch_size) {
            // Flatten the batch into a single input map, disambiguating tensor
            // names with the request's index within the batch.
            let mut batch_inputs = HashMap::new();

            for (i, request) in batch.iter().enumerate() {
                for (name, tensor) in &request.inputs {
                    let batch_name = format!("{}_{}", name, i);
                    batch_inputs.insert(batch_name, tensor.clone());
                }
            }

            let batch_outputs = self.process_inference(&batch_inputs).await?;

            // Split the batched outputs back into one response per request. The
            // placeholder `process_inference` names each output `output_<input name>`,
            // so the same convention is used when matching results to requests.
            for (i, request) in batch.iter().enumerate() {
                let mut outputs = HashMap::new();
                for name in request.inputs.keys() {
                    let batch_name = format!("output_{}_{}", name, i);
                    if let Some(output) = batch_outputs.get(&batch_name) {
                        outputs.insert(name.clone(), output.clone());
                    }
                }

                responses.push(InferenceResponse {
                    request_id: request.id.clone(),
                    outputs,
                    metadata: HashMap::new(),
                    processing_time_ms: start_time.elapsed().as_millis() as u64,
                    status: ResponseStatus::Success,
                });
            }
        }

        self.update_batch_metrics(requests.len(), start_time).await;

        Ok(responses)
    }

    /// Runs the model on the given inputs.
    ///
    /// Placeholder implementation: the model lock is acquired, but the inputs are
    /// currently echoed back as `output_<name>` tensors instead of running a real
    /// forward pass.
    async fn process_inference(
        &self,
        inputs: &HashMap<String, MLTensor>,
    ) -> Result<HashMap<String, MLTensor>> {
        let _model = self.model.read().await;

        let mut outputs = HashMap::new();
        for (name, tensor) in inputs {
            outputs.insert(format!("output_{}", name), tensor.clone());
        }

        Ok(outputs)
    }

    /// Starts the REST front end (placeholder: only logs the configured port;
    /// the HTTP server wiring is not implemented here).
    async fn start_rest_api(&self) -> Result<()> {
        println!(
            "Starting REST API server on port {}",
            self.config.api_config.rest_port
        );

        Ok(())
    }

    /// Starts the gRPC front end (placeholder: only logs the configured port;
    /// the gRPC server wiring is not implemented here).
    async fn start_grpc_api(&self) -> Result<()> {
        println!(
            "Starting gRPC API server on port {}",
            self.config.api_config.grpc_port
        );

        Ok(())
    }

    /// Spawns `num_workers` background tasks that drain the request queue.
    /// Placeholder: dequeued requests are only simulated with a short sleep
    /// rather than being processed.
    async fn start_workers(&self) -> Result<()> {
        for _worker_id in 0..self.config.num_workers {
            let queue = self.request_queue.clone();
            let _config = self.config.clone();

            tokio::spawn(async move {
                loop {
                    // Pop one request, releasing the lock before doing any work.
                    let request = {
                        let mut queue_guard = queue.lock().await;
                        queue_guard.pop_front()
                    };

                    if let Some(_request) = request {
                        // Simulate processing the dequeued request.
                        sleep(Duration::from_millis(10)).await;
                    } else {
                        // Queue is empty; back off briefly before polling again.
                        sleep(Duration::from_millis(1)).await;
                    }
                }
            });
        }

        Ok(())
    }

    /// Spawns a background task that refreshes time-based metrics once a second.
    async fn start_metrics_collection(&self) {
        let metrics = self.metrics.clone();
        let start_time = Instant::now();

        tokio::spawn(async move {
            loop {
                sleep(Duration::from_secs(1)).await;

                {
                    let mut m = metrics.lock().await;
                    m.uptime_seconds = start_time.elapsed().as_secs();
                }
            }
        });
    }

    /// Records the outcome of a single request and folds its latency into the
    /// running average.
    async fn update_metrics(&self, start_time: Instant, success: bool) {
        let mut metrics = self.metrics.lock().await;
        metrics.total_requests += 1;

        if success {
            metrics.successful_requests += 1;
        } else {
            metrics.failed_requests += 1;
        }

        let latency = start_time.elapsed().as_millis() as f64;
        metrics.average_latency_ms =
            (metrics.average_latency_ms * (metrics.total_requests - 1) as f64 + latency)
                / metrics.total_requests as f64;
    }

    async fn update_batch_metrics(&self, batch_size: usize, start_time: Instant) {
        let mut metrics = self.metrics.lock().await;
        metrics.batch_stats.total_batches += 1;

        let current_avg = metrics.batch_stats.average_batch_size;
        let total_batches = metrics.batch_stats.total_batches as f64;
        metrics.batch_stats.average_batch_size =
            (current_avg * (total_batches - 1.0) + batch_size as f64) / total_batches;

        let processing_time = start_time.elapsed().as_millis() as f64;
        let current_time_avg = metrics.batch_stats.batch_processing_time_ms;
        metrics.batch_stats.batch_processing_time_ms =
            (current_time_avg * (total_batches - 1.0) + processing_time) / total_batches;
    }

    /// Returns the current health status.
    pub async fn get_health(&self) -> HealthStatus {
        self.health_status.read().await.clone()
    }

    /// Returns a snapshot of the current metrics.
    pub async fn get_metrics(&self) -> ServerMetrics {
        self.metrics.lock().await.clone()
    }

    /// Hot-swaps the served model and records how long the swap took.
    pub async fn update_model(&self, newmodel: MLModel) -> Result<()> {
        let start_time = Instant::now();

        {
            let mut model = self.model.write().await;
            *model = newmodel;
        }

        {
            let mut metrics = self.metrics.lock().await;
            metrics.model_load_time_ms = start_time.elapsed().as_millis() as u64;
        }

        Ok(())
    }

    /// Summarizes the currently loaded model.
    pub async fn get_model_info(&self) -> ModelInfo {
        let model = self.model.read().await;
        ModelInfo {
            name: model
                .metadata
                .model_name
                .clone()
                .unwrap_or_else(|| "Unknown".to_string()),
            framework: model.metadata.framework.clone(),
            version: model.metadata.model_version.clone(),
            inputshapes: model.metadata.inputshapes.clone(),
            outputshapes: model.metadata.outputshapes.clone(),
            parameters: model.weights.len(),
            // Placeholder: the actual load time is not tracked, so report "now".
            loaded_at: Instant::now(),
        }
    }

    /// Initiates a graceful shutdown: marks the server as stopping, waits for the
    /// configured timeout so in-flight work can finish, then clears the queue.
    pub async fn shutdown(&self) -> Result<()> {
        {
            let mut status = self.health_status.write().await;
            *status = HealthStatus::Stopping;
        }

        // Give in-flight requests a chance to finish.
        sleep(Duration::from_millis(self.config.timeout_ms)).await;

        {
            let mut queue = self.request_queue.lock().await;
            queue.clear();
        }

        Ok(())
    }
}
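
// A lifecycle sketch (compiled under tests with the "async" feature, never executed):
// the caller supplies an `MLModel` loaded elsewhere, since this module does not define
// how models are constructed. Configuration and error handling use the defaults above.
#[cfg(all(test, feature = "async"))]
#[allow(dead_code)]
async fn server_lifecycle_sketch(model: MLModel) -> Result<ServerMetrics> {
    // Build and start the server with the default configuration.
    let server = ModelServer::new(model, ServerConfig::default()).await;
    server.start().await?;

    // Health and metrics can be polled at any time.
    let _health = server.get_health().await;
    let metrics = server.get_metrics().await;

    // Graceful shutdown: marks the server as stopping and drains the queue.
    server.shutdown().await?;
    Ok(metrics)
}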

/// Descriptive information about the model currently being served.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub name: String,
    pub framework: String,
    pub version: Option<String>,
    pub inputshapes: HashMap<String, Vec<usize>>,
    pub outputshapes: HashMap<String, Vec<usize>>,
    pub parameters: usize,
    pub loaded_at: Instant,
}

/// JSON-serializable request/response types and tensor conversions for the
/// REST front end.
pub mod rest {
    use super::*;

    #[derive(Debug, Serialize, Deserialize)]
    pub struct PredictRequest {
        pub inputs: HashMap<String, Vec<f32>>,
        pub metadata: Option<HashMap<String, serde_json::Value>>,
    }

    #[derive(Debug, Serialize, Deserialize)]
    pub struct PredictResponse {
        pub outputs: HashMap<String, Vec<f32>>,
        pub metadata: HashMap<String, serde_json::Value>,
        pub processing_time_ms: u64,
    }

    #[derive(Debug, Serialize, Deserialize)]
    pub struct BatchPredictRequest {
        pub inputs: Vec<HashMap<String, Vec<f32>>>,
        pub metadata: Option<HashMap<String, serde_json::Value>>,
    }

    #[derive(Debug, Serialize, Deserialize)]
    pub struct BatchPredictResponse {
        pub outputs: Vec<HashMap<String, Vec<f32>>>,
        pub metadata: HashMap<String, serde_json::Value>,
        pub processing_time_ms: u64,
    }

    #[derive(Debug, Serialize, Deserialize)]
    pub struct HealthResponse {
        pub status: String,
        pub uptime_seconds: u64,
        pub version: String,
    }

    #[derive(Debug, Serialize, Deserialize)]
    pub struct MetricsResponse {
        pub total_requests: u64,
        pub successful_requests: u64,
        pub failed_requests: u64,
        pub average_latency_ms: f64,
        pub requests_per_second: f64,
        pub queue_size: usize,
        pub uptime_seconds: u64,
    }

    /// Flattens a tensor into a plain `Vec<f32>` for JSON transport.
    pub fn tensor_to_rest(tensor: &MLTensor) -> Vec<f32> {
        // Iterate rather than relying on `as_slice()`, which would panic for
        // non-contiguous layouts.
        tensor.data.iter().copied().collect()
    }

    /// Rebuilds a tensor from flat data and an explicit shape.
    pub fn rest_to_tensor(
        data: Vec<f32>,
        shape: Vec<usize>,
        name: Option<String>,
    ) -> Result<MLTensor> {
        let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
            .map_err(|e| IoError::Other(e.to_string()))?;
        Ok(MLTensor::new(array, name))
    }
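
    // A round-trip sketch (test-only): the tensor values, shape, and the JSON payload
    // below are illustrative assumptions chosen to exercise the conversion helpers
    // defined above.
    #[cfg(test)]
    mod conversion_tests {
        use super::*;

        #[test]
        fn rest_round_trip_preserves_data() {
            let tensor =
                rest_to_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], Some("x".to_string()))
                    .expect("shape matches data length");
            assert_eq!(tensor.data.shape(), &[2, 2]);
            assert_eq!(tensor_to_rest(&tensor), vec![1.0, 2.0, 3.0, 4.0]);
        }

        #[test]
        fn predict_request_parses_from_json() {
            let request: PredictRequest =
                serde_json::from_str(r#"{"inputs": {"x": [1.0, 2.0]}, "metadata": null}"#)
                    .expect("valid payload");
            assert_eq!(request.inputs["x"], vec![1.0, 2.0]);
        }
    }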
}

/// Plain Rust mirrors of gRPC message types, plus tensor conversions for the
/// gRPC front end.
pub mod grpc {
    use super::*;

    #[derive(Debug, Clone)]
    pub struct GrpcTensor {
        pub name: String,
        pub shape: Vec<i64>,
        pub dtype: String,
        pub data: Vec<u8>,
    }

    #[derive(Debug, Clone)]
    pub struct GrpcPredictRequest {
        pub model_name: String,
        pub inputs: Vec<GrpcTensor>,
        pub metadata: HashMap<String, String>,
    }

    #[derive(Debug, Clone)]
    pub struct GrpcPredictResponse {
        pub outputs: Vec<GrpcTensor>,
        pub metadata: HashMap<String, String>,
        pub status: GrpcStatus,
    }

    #[derive(Debug, Clone)]
    pub struct GrpcStatus {
        pub code: i32,
        pub message: String,
    }

    /// Converts a tensor into a `GrpcTensor`, encoding values as little-endian f32 bytes.
    pub fn tensor_to_grpc(tensor: &MLTensor) -> GrpcTensor {
        GrpcTensor {
            name: tensor.metadata.name.clone().unwrap_or_default(),
            shape: tensor.metadata.shape.iter().map(|&s| s as i64).collect(),
            dtype: format!("{:?}", tensor.metadata.dtype),
            // Iterate rather than relying on `as_slice()`, which would panic for
            // non-contiguous layouts.
            data: tensor
                .data
                .iter()
                .flat_map(|f| f.to_le_bytes())
                .collect(),
        }
    }

    /// Reconstructs a tensor from a `GrpcTensor`, interpreting the payload as
    /// little-endian f32 values (any trailing partial chunk is ignored).
    pub fn grpc_to_tensor(grpctensor: &GrpcTensor) -> Result<MLTensor> {
        let shape: Vec<usize> = grpctensor.shape.iter().map(|&s| s as usize).collect();

        let float_data: Vec<f32> = grpctensor
            .data
            .chunks_exact(4)
            .map(|chunk| {
                // `chunks_exact(4)` guarantees this conversion cannot fail.
                let bytes: [u8; 4] = chunk.try_into().unwrap();
                f32::from_le_bytes(bytes)
            })
            .collect();

        let array = ArrayD::from_shape_vec(IxDyn(&shape), float_data)
            .map_err(|e| IoError::Other(e.to_string()))?;

        Ok(MLTensor::new(array, Some(grpctensor.name.clone())))
    }
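
    // A decoding sketch (test-only): the payload below is built by hand with an
    // illustrative dtype string, then decoded with `grpc_to_tensor`.
    #[cfg(test)]
    mod conversion_tests {
        use super::*;

        #[test]
        fn decodes_little_endian_f32_payloads() {
            let grpc = GrpcTensor {
                name: "input".to_string(),
                shape: vec![2, 2],
                dtype: "Float32".to_string(),
                data: [1.0_f32, 2.0, 3.0, 4.0]
                    .iter()
                    .flat_map(|f| f.to_le_bytes())
                    .collect(),
            };

            let tensor = grpc_to_tensor(&grpc).expect("payload matches the declared shape");
            assert_eq!(tensor.data.shape(), &[2, 2]);
            assert_eq!(
                tensor.data.iter().copied().collect::<Vec<f32>>(),
                vec![1.0, 2.0, 3.0, 4.0]
            );
        }
    }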
}

/// Distributes inference requests across multiple `ModelServer` instances.
pub struct LoadBalancer {
    servers: Vec<ModelServer>,
    strategy: LoadBalancingStrategy,
    health_checker: HealthChecker,
}

/// Strategy used to pick a server for each request.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    RoundRobin,
    LeastConnections,
    WeightedRoundRobin { weights: Vec<f32> },
    Random,
    HealthBased,
}

/// Periodic health-probing configuration for the load balancer.
pub struct HealthChecker {
    check_interval: Duration,
    timeout: Duration,
}

#[cfg(feature = "async")]
impl LoadBalancer {
    /// Creates a load balancer with a 30s health-check interval and a 5s probe timeout.
    pub fn new(servers: Vec<ModelServer>, strategy: LoadBalancingStrategy) -> Self {
        Self {
            servers,
            strategy,
            health_checker: HealthChecker {
                check_interval: Duration::from_secs(30),
                timeout: Duration::from_secs(5),
            },
        }
    }

    /// Picks a server according to the configured strategy and forwards the request.
    pub async fn route_request(&self, request: InferenceRequest) -> Result<InferenceResponse> {
        let server = self.select_server().await?;
        server.infer(request).await
    }

    /// Selects a server for the next request. Only the health-based strategy is
    /// fully implemented; the remaining strategies currently fall back to the
    /// first server.
    async fn select_server(&self) -> Result<&ModelServer> {
        if self.servers.is_empty() {
            return Err(IoError::Other("No servers configured".to_string()));
        }

        match self.strategy {
            LoadBalancingStrategy::RoundRobin => {
                // Placeholder: a real round-robin would rotate through the servers.
                Ok(&self.servers[0])
            }
            LoadBalancingStrategy::HealthBased => {
                for server in &self.servers {
                    if matches!(server.get_health().await, HealthStatus::Healthy) {
                        return Ok(server);
                    }
                }
                Err(IoError::Other("No healthy servers available".to_string()))
            }
            _ => Ok(&self.servers[0]),
        }
    }

    /// Spawns the periodic health-check loop (placeholder: the probe itself is not
    /// implemented yet, so the task only sleeps on the configured interval).
    pub async fn start_health_checking(&self) {
        let interval = self.health_checker.check_interval;

        tokio::spawn(async move {
            loop {
                sleep(interval).await;
            }
        });
    }
}
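
// A routing sketch (compiled under tests with the "async" feature, never executed):
// the servers and request are supplied by the caller; the health-based strategy
// simply forwards to the first server reporting `HealthStatus::Healthy`.
#[cfg(all(test, feature = "async"))]
#[allow(dead_code)]
async fn load_balancer_sketch(
    servers: Vec<ModelServer>,
    request: InferenceRequest,
) -> Result<InferenceResponse> {
    let balancer = LoadBalancer::new(servers, LoadBalancingStrategy::HealthBased);
    balancer.route_request(request).await
}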
713}