1use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, VecDeque};
8use std::sync::{Arc, Mutex};
9use std::time::{Duration, Instant};
10use tokio::sync::RwLock;
11use trustformers_core::errors::{Result, TrustformersError};
12use trustformers_core::{traits::Model, Tensor};
13use uuid::Uuid;
14
/// Configuration for the model serving layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServingConfig {
    /// Maximum number of requests processed concurrently.
    pub max_concurrent_requests: usize,
    /// Per-request timeout in seconds; older queued requests are rejected.
    pub request_timeout_seconds: u64,
    /// Maximum number of requests allowed to wait in the queue.
    pub max_queue_size: usize,
    /// Interval between periodic health checks, in seconds.
    pub health_check_interval_seconds: u64,
    /// Whether serving metrics are collected.
    pub enable_metrics: bool,
    /// Strategy used to spread requests across model instances.
    pub load_balancing_strategy: LoadBalancingStrategy,
}
31
32impl Default for ServingConfig {
33 fn default() -> Self {
34 Self {
35 max_concurrent_requests: 10,
36 request_timeout_seconds: 30,
37 max_queue_size: 100,
38 health_check_interval_seconds: 60,
39 enable_metrics: true,
40 load_balancing_strategy: LoadBalancingStrategy::RoundRobin,
41 }
42 }
43}
44
/// How requests are distributed across registered model instances.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
    /// Rotate through instances in registration order.
    RoundRobin,
    /// Pick the healthy instance with the fewest in-flight requests.
    LeastConnections,
    /// Weighted selection; weights are indexed by instance position.
    WeightedRoundRobin(Vec<f64>),
    /// Pick the healthy instance with the lowest average response time.
    ResponseTime,
}
57
/// Request priority; higher values are served first (Ord follows the
/// discriminant values).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum RequestPriority {
    Low = 1,
    Normal = 2,
    High = 3,
    Critical = 4,
}
66
/// A single inference request queued for processing.
#[derive(Debug, Clone)]
pub struct InferenceRequest {
    /// Unique request identifier (random UUID v4).
    pub id: Uuid,
    /// Input tensor handed to the model.
    pub input: Tensor,
    /// Scheduling priority within the queue.
    pub priority: RequestPriority,
    /// Creation time; used for timeout expiry checks.
    pub timestamp: Instant,
    /// Free-form key/value metadata attached by the caller.
    pub metadata: HashMap<String, String>,
}
76
77impl InferenceRequest {
78 pub fn new(input: Tensor, priority: RequestPriority) -> Self {
80 Self {
81 id: Uuid::new_v4(),
82 input,
83 priority,
84 timestamp: Instant::now(),
85 metadata: HashMap::new(),
86 }
87 }
88
89 pub fn with_metadata(mut self, key: String, value: String) -> Self {
91 self.metadata.insert(key, value);
92 self
93 }
94
95 pub fn elapsed(&self) -> Duration {
97 self.timestamp.elapsed()
98 }
99}
100
/// Result of processing one [`InferenceRequest`].
#[derive(Debug)]
pub struct InferenceResponse {
    /// Identifier of the request this response answers.
    pub request_id: Uuid,
    /// The model output, or the error that occurred (including timeouts).
    pub output: Result<Tensor>,
    /// Wall-clock time spent processing (or waiting, for timeouts).
    pub processing_time: Duration,
    /// Response metadata (currently always empty at construction sites).
    pub metadata: HashMap<String, String>,
}
109
/// A single servable model replica tracked by the load balancer.
#[derive(Debug)]
pub struct ModelInstance {
    /// Unique identifier of this instance.
    pub id: String,
    /// Relative weight used by weighted balancing strategies.
    pub weight: f64,
    /// Number of requests currently in flight on this instance.
    pub active_requests: usize,
    /// Total number of completed requests.
    pub total_requests: u64,
    /// Cumulative processing time over all completed requests.
    pub total_processing_time: Duration,
    /// When this instance was last health-checked.
    pub last_health_check: Instant,
    /// Whether the most recent health check passed (starts healthy).
    pub is_healthy: bool,
}

impl ModelInstance {
    /// Creates a healthy instance with zeroed statistics.
    pub fn new(id: String, weight: f64) -> Self {
        Self {
            id,
            weight,
            active_requests: 0,
            total_requests: 0,
            total_processing_time: Duration::new(0, 0),
            last_health_check: Instant::now(),
            is_healthy: true,
        }
    }

    /// Marks one in-flight request as completed and folds its latency into
    /// the totals. Saturates at zero so a stray call cannot underflow.
    pub fn update_stats(&mut self, processing_time: Duration) {
        self.active_requests = self.active_requests.saturating_sub(1);
        self.total_requests += 1;
        self.total_processing_time += processing_time;
    }

    /// Mean processing time per completed request (zero before any traffic).
    ///
    /// Fix: the previous `self.total_requests as u32` cast silently truncated
    /// the u64 count; for counts that are exact multiples of 2^32 the divisor
    /// became 0 and `Duration`'s division panicked. Counts beyond u32 now fall
    /// back to float math instead of truncating.
    pub fn average_response_time(&self) -> Duration {
        match u32::try_from(self.total_requests) {
            Ok(0) => Duration::new(0, 0),
            Ok(n) => self.total_processing_time / n,
            // More than u32::MAX completed requests: divide in f64 space.
            Err(_) => Duration::from_secs_f64(
                self.total_processing_time.as_secs_f64() / self.total_requests as f64,
            ),
        }
    }

    /// Records one more request in flight on this instance.
    pub fn start_request(&mut self) {
        self.active_requests += 1;
    }
}
157
/// Distributes requests over a set of [`ModelInstance`]s according to a
/// [`LoadBalancingStrategy`].
#[derive(Debug)]
pub struct LoadBalancer {
    // Registered instances, in registration order.
    instances: Vec<ModelInstance>,
    // Selection strategy, fixed at construction.
    strategy: LoadBalancingStrategy,
    // Rotation cursor used by the round-robin strategy.
    current_index: usize,
}
165
166impl LoadBalancer {
167 pub fn new(strategy: LoadBalancingStrategy) -> Self {
169 Self {
170 instances: Vec::new(),
171 strategy,
172 current_index: 0,
173 }
174 }
175
176 pub fn add_instance(&mut self, instance: ModelInstance) {
178 self.instances.push(instance);
179 }
180
181 pub fn select_instance(&mut self) -> Option<&mut ModelInstance> {
183 if self.instances.is_empty() {
184 return None;
185 }
186
187 let selected_index = match &self.strategy {
188 LoadBalancingStrategy::RoundRobin => {
189 let index = self.current_index;
190 self.current_index = (self.current_index + 1) % self.instances.len();
191 index
192 },
193 LoadBalancingStrategy::LeastConnections => self
194 .instances
195 .iter()
196 .enumerate()
197 .filter(|(_, instance)| instance.is_healthy)
198 .min_by_key(|(_, instance)| instance.active_requests)
199 .map(|(index, _)| index)
200 .unwrap_or(0),
201 LoadBalancingStrategy::WeightedRoundRobin(weights) => {
202 self.instances
204 .iter()
205 .enumerate()
206 .filter(|(_, instance)| instance.is_healthy)
207 .max_by(|(i, _), (j, _)| {
208 let weight_i = weights.get(*i).unwrap_or(&1.0);
209 let weight_j = weights.get(*j).unwrap_or(&1.0);
210 weight_i.partial_cmp(weight_j).unwrap_or(std::cmp::Ordering::Equal)
211 })
212 .map(|(index, _)| index)
213 .unwrap_or(0)
214 },
215 LoadBalancingStrategy::ResponseTime => self
216 .instances
217 .iter()
218 .enumerate()
219 .filter(|(_, instance)| instance.is_healthy)
220 .min_by_key(|(_, instance)| instance.average_response_time())
221 .map(|(index, _)| index)
222 .unwrap_or(0),
223 };
224
225 self.instances.get_mut(selected_index)
226 }
227
228 pub fn healthy_instances_count(&self) -> usize {
230 self.instances.iter().filter(|i| i.is_healthy).count()
231 }
232
233 pub fn update_instance_health(&mut self, instance_id: &str, is_healthy: bool) {
235 if let Some(instance) = self.instances.iter_mut().find(|i| i.id == instance_id) {
236 instance.is_healthy = is_healthy;
237 instance.last_health_check = Instant::now();
238 }
239 }
240}
241
242#[derive(Debug)]
244pub struct RequestQueue {
245 queue: VecDeque<InferenceRequest>,
246 max_size: usize,
247}
248
249impl RequestQueue {
250 pub fn new(max_size: usize) -> Self {
252 Self {
253 queue: VecDeque::new(),
254 max_size,
255 }
256 }
257
258 pub fn enqueue(&mut self, request: InferenceRequest) -> Result<()> {
260 if self.queue.len() >= self.max_size {
261 return Err(TrustformersError::resource_exhausted(
262 "Request queue is full".to_string(),
263 ));
264 }
265
266 let insert_index = self
268 .queue
269 .iter()
270 .position(|r| r.priority < request.priority)
271 .unwrap_or(self.queue.len());
272
273 self.queue.insert(insert_index, request);
274 Ok(())
275 }
276
277 pub fn dequeue(&mut self) -> Option<InferenceRequest> {
279 self.queue.pop_front()
280 }
281
282 pub fn size(&self) -> usize {
284 self.queue.len()
285 }
286
287 pub fn is_empty(&self) -> bool {
289 self.queue.is_empty()
290 }
291
292 pub fn remove_expired(&mut self, timeout: Duration) -> usize {
294 let initial_size = self.queue.len();
295 self.queue.retain(|req| req.elapsed() < timeout);
296 initial_size - self.queue.len()
297 }
298}
299
/// Aggregate serving statistics, all counters starting at zero.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct ServingMetrics {
    /// Total requests observed (successes + failures, including timeouts).
    pub total_requests: u64,
    /// Requests that completed successfully.
    pub successful_requests: u64,
    /// Requests that failed (timeouts are counted here too).
    pub failed_requests: u64,
    /// Requests dropped because they exceeded the timeout.
    pub timeout_requests: u64,
    /// Running mean latency over completed requests, in milliseconds.
    pub average_response_time_ms: f64,
    /// Queue depth at the last update.
    pub current_queue_size: usize,
    /// Highest queue depth ever observed.
    pub peak_queue_size: usize,
    /// Currently active connections (not updated in this file's visible code).
    pub active_connections: usize,
}
312
313impl ServingMetrics {
314 pub fn update_request(&mut self, success: bool, response_time: Duration) {
316 self.total_requests += 1;
317 if success {
318 self.successful_requests += 1;
319 } else {
320 self.failed_requests += 1;
321 }
322
323 let new_time_ms = response_time.as_millis() as f64;
325 if self.total_requests == 1 {
326 self.average_response_time_ms = new_time_ms;
327 } else {
328 self.average_response_time_ms =
329 (self.average_response_time_ms * (self.total_requests - 1) as f64 + new_time_ms)
330 / self.total_requests as f64;
331 }
332 }
333
334 pub fn update_queue_size(&mut self, current_size: usize) {
336 self.current_queue_size = current_size;
337 if current_size > self.peak_queue_size {
338 self.peak_queue_size = current_size;
339 }
340 }
341
342 pub fn record_timeout(&mut self) {
344 self.timeout_requests += 1;
345 self.failed_requests += 1;
346 self.total_requests += 1;
347 }
348
349 pub fn success_rate(&self) -> f64 {
351 if self.total_requests > 0 {
352 self.successful_requests as f64 / self.total_requests as f64
353 } else {
354 0.0
355 }
356 }
357}
358
/// States of the circuit-breaker pattern.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitBreakerState {
    /// Normal operation; requests flow through.
    Closed,
    /// Tripped; requests are refused until the recovery timeout elapses.
    Open,
    /// Probing; a limited number of requests test whether recovery holds.
    HalfOpen,
}

/// Per-instance circuit breaker implementing the closed/open/half-open
/// failure-isolation pattern.
#[derive(Debug)]
pub struct CircuitBreaker {
    // Current state of the breaker.
    state: CircuitBreakerState,
    // Consecutive failures observed while closed.
    failure_count: usize,
    // Failures needed to trip from Closed to Open.
    failure_threshold: usize,
    // How long to stay Open before probing again.
    recovery_timeout: Duration,
    // Timestamp of the most recent failure, if any.
    last_failure_time: Option<Instant>,
    // Half-open successes needed to close again.
    success_threshold: usize,
    // Successes accumulated during the current half-open probe.
    half_open_successes: usize,
}

impl CircuitBreaker {
    /// Builds a closed breaker that opens after `failure_threshold`
    /// failures, probes again after `recovery_timeout`, and closes after
    /// `success_threshold` consecutive half-open successes.
    pub fn new(
        failure_threshold: usize,
        recovery_timeout: Duration,
        success_threshold: usize,
    ) -> Self {
        CircuitBreaker {
            state: CircuitBreakerState::Closed,
            failure_count: 0,
            failure_threshold,
            recovery_timeout,
            last_failure_time: None,
            success_threshold,
            half_open_successes: 0,
        }
    }

    /// Returns whether a request may proceed, transitioning Open → HalfOpen
    /// once the recovery timeout has elapsed since the last failure.
    pub fn allow_request(&mut self) -> bool {
        match self.state {
            CircuitBreakerState::Closed | CircuitBreakerState::HalfOpen => true,
            CircuitBreakerState::Open => {
                let recovered = self
                    .last_failure_time
                    .map(|failed_at| failed_at.elapsed() >= self.recovery_timeout)
                    .unwrap_or(false);
                if recovered {
                    self.state = CircuitBreakerState::HalfOpen;
                    self.half_open_successes = 0;
                }
                recovered
            },
        }
    }

    /// Records a successful call. Enough half-open successes close the
    /// breaker; a success while closed resets the failure counter.
    pub fn record_success(&mut self) {
        match self.state {
            CircuitBreakerState::HalfOpen => {
                self.half_open_successes += 1;
                if self.half_open_successes >= self.success_threshold {
                    self.state = CircuitBreakerState::Closed;
                    self.failure_count = 0;
                    self.last_failure_time = None;
                }
            },
            CircuitBreakerState::Closed => self.failure_count = 0,
            CircuitBreakerState::Open => {},
        }
    }

    /// Records a failed call: trips the breaker when the closed-state
    /// threshold is reached, or immediately when the failure interrupts a
    /// half-open probe.
    pub fn record_failure(&mut self) {
        self.failure_count += 1;
        self.last_failure_time = Some(Instant::now());

        let trip = match self.state {
            CircuitBreakerState::Closed => self.failure_count >= self.failure_threshold,
            CircuitBreakerState::HalfOpen => {
                self.half_open_successes = 0;
                true
            },
            CircuitBreakerState::Open => false,
        };
        if trip {
            self.state = CircuitBreakerState::Open;
        }
    }

    /// Current breaker state.
    pub fn state(&self) -> CircuitBreakerState {
        self.state
    }
}
458
/// Tracks per-instance circuit breakers and the health-check schedule.
#[derive(Debug)]
pub struct HealthMonitor {
    // One breaker per registered instance id.
    circuit_breakers: HashMap<String, CircuitBreaker>,
    // Minimum interval between health-check runs.
    health_check_interval: Duration,
    // When a health check was last recorded.
    last_health_check: Instant,
}
466
467impl HealthMonitor {
468 pub fn new(health_check_interval: Duration) -> Self {
470 Self {
471 circuit_breakers: HashMap::new(),
472 health_check_interval,
473 last_health_check: Instant::now(),
474 }
475 }
476
477 pub fn add_instance(&mut self, instance_id: String) {
479 let circuit_breaker = CircuitBreaker::new(
480 3, Duration::from_secs(30), 2, );
484 self.circuit_breakers.insert(instance_id, circuit_breaker);
485 }
486
487 pub fn can_handle_request(&mut self, instance_id: &str) -> bool {
489 if let Some(circuit_breaker) = self.circuit_breakers.get_mut(instance_id) {
490 circuit_breaker.allow_request()
491 } else {
492 false
493 }
494 }
495
496 pub fn record_success(&mut self, instance_id: &str) {
498 if let Some(circuit_breaker) = self.circuit_breakers.get_mut(instance_id) {
499 circuit_breaker.record_success();
500 }
501 }
502
503 pub fn record_failure(&mut self, instance_id: &str) {
505 if let Some(circuit_breaker) = self.circuit_breakers.get_mut(instance_id) {
506 circuit_breaker.record_failure();
507 }
508 }
509
510 pub fn get_health_status(&self) -> HashMap<String, CircuitBreakerState> {
512 self.circuit_breakers.iter().map(|(id, cb)| (id.clone(), cb.state())).collect()
513 }
514
515 pub fn should_run_health_check(&self) -> bool {
517 self.last_health_check.elapsed() >= self.health_check_interval
518 }
519}
520
/// Type-erased inference callback mapping an input tensor to an output
/// tensor; `Send + Sync` so it can be shared across worker tasks.
pub type ModelInferenceFn = Arc<dyn Fn(Tensor) -> Result<Tensor> + Send + Sync>;
523
/// Orchestrates queuing, load balancing, health monitoring, and metrics for
/// model inference.
pub struct ModelServingManager {
    // Serving parameters fixed at construction.
    config: ServingConfig,
    // Instance selection; shared behind a sync mutex.
    load_balancer: Arc<Mutex<LoadBalancer>>,
    // Pending requests; shared behind a sync mutex.
    request_queue: Arc<Mutex<RequestQueue>>,
    // Metrics behind an async RwLock (written on the request path).
    metrics: Arc<RwLock<ServingMetrics>>,
    // Circuit breakers + health-check schedule.
    health_monitor: Arc<Mutex<HealthMonitor>>,
    // Optional inference callback; when absent, processing is simulated.
    model_fn: Option<ModelInferenceFn>,
}
533
534impl std::fmt::Debug for ModelServingManager {
535 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536 f.debug_struct("ModelServingManager")
537 .field("config", &self.config)
538 .field("load_balancer", &"Arc<Mutex<LoadBalancer>>")
539 .field("request_queue", &"Arc<Mutex<RequestQueue>>")
540 .field("metrics", &"Arc<RwLock<ServingMetrics>>")
541 .field("health_monitor", &"Arc<Mutex<HealthMonitor>>")
542 .field("model_fn", &self.model_fn.is_some())
543 .finish()
544 }
545}
546
impl ModelServingManager {
    /// Builds a manager from `config` with no inference function attached;
    /// call `set_inference_fn` (or use `with_model`) before serving real
    /// traffic.
    pub fn new(config: ServingConfig) -> Self {
        let interval = Duration::from_secs(config.health_check_interval_seconds);
        Self {
            load_balancer: Arc::new(Mutex::new(LoadBalancer::new(
                config.load_balancing_strategy.clone(),
            ))),
            request_queue: Arc::new(Mutex::new(RequestQueue::new(config.max_queue_size))),
            metrics: Arc::new(RwLock::new(ServingMetrics::default())),
            health_monitor: Arc::new(Mutex::new(HealthMonitor::new(interval))),
            model_fn: None,
            config,
        }
    }
564
565 pub fn with_model<M: Model<Input = Tensor, Output = Tensor> + 'static>(
567 config: ServingConfig,
568 model: M,
569 ) -> Self {
570 let load_balancer = LoadBalancer::new(config.load_balancing_strategy.clone());
571 let request_queue = RequestQueue::new(config.max_queue_size);
572 let health_monitor =
573 HealthMonitor::new(Duration::from_secs(config.health_check_interval_seconds));
574
575 let model = Arc::new(model);
576 let model_fn: ModelInferenceFn = Arc::new(move |input| model.forward(input));
577
578 Self {
579 config,
580 load_balancer: Arc::new(Mutex::new(load_balancer)),
581 request_queue: Arc::new(Mutex::new(request_queue)),
582 metrics: Arc::new(RwLock::new(ServingMetrics::default())),
583 health_monitor: Arc::new(Mutex::new(health_monitor)),
584 model_fn: Some(model_fn),
585 }
586 }
587
    /// Installs (or replaces) the inference callback used to execute
    /// requests.
    pub fn set_inference_fn(&mut self, inference_fn: ModelInferenceFn) {
        self.model_fn = Some(inference_fn);
    }
592
593 pub fn add_instance(&self, instance: ModelInstance) -> Result<()> {
595 let instance_id = instance.id.clone();
596
597 let mut balancer = self.load_balancer.lock().map_err(|_| {
598 TrustformersError::runtime_error("Failed to acquire load balancer lock".to_string())
599 })?;
600 balancer.add_instance(instance);
601
602 let mut health_monitor = self.health_monitor.lock().map_err(|_| {
604 TrustformersError::runtime_error("Failed to acquire health monitor lock".to_string())
605 })?;
606 health_monitor.add_instance(instance_id);
607
608 Ok(())
609 }
610
611 pub fn get_health_status(&self) -> Result<HashMap<String, CircuitBreakerState>> {
613 let health_monitor = self.health_monitor.lock().map_err(|_| {
614 TrustformersError::runtime_error("Failed to acquire health monitor lock".to_string())
615 })?;
616 Ok(health_monitor.get_health_status())
617 }
618
619 pub async fn perform_health_check(&self) -> Result<()> {
621 let should_check = {
622 let health_monitor = self.health_monitor.lock().map_err(|_| {
623 TrustformersError::runtime_error(
624 "Failed to acquire health monitor lock".to_string(),
625 )
626 })?;
627 health_monitor.should_run_health_check()
628 };
629
630 if should_check {
631 let mut _health_monitor = self.health_monitor.lock().map_err(|_| {
634 TrustformersError::runtime_error(
635 "Failed to acquire health monitor lock".to_string(),
636 )
637 })?;
638 }
640
641 Ok(())
642 }
643
644 pub async fn submit_request(&self, request: InferenceRequest) -> Result<()> {
646 let mut queue = self.request_queue.lock().map_err(|_| {
647 TrustformersError::runtime_error("Failed to acquire queue lock".to_string())
648 })?;
649
650 queue.enqueue(request)?;
651
652 if self.config.enable_metrics {
654 let mut metrics = self.metrics.write().await;
655 metrics.update_queue_size(queue.size());
656 }
657
658 Ok(())
659 }
660
661 pub async fn process_next_request(&self) -> Result<Option<InferenceResponse>> {
663 let request = {
665 let mut queue = self.request_queue.lock().map_err(|_| {
666 TrustformersError::runtime_error("Failed to acquire queue lock".to_string())
667 })?;
668 queue.dequeue()
669 };
670
671 let request = match request {
672 Some(req) => req,
673 None => return Ok(None),
674 };
675
676 let timeout_duration = Duration::from_secs(self.config.request_timeout_seconds);
678 if request.elapsed() > timeout_duration {
679 if self.config.enable_metrics {
680 let mut metrics = self.metrics.write().await;
681 metrics.record_timeout();
682 }
683 return Ok(Some(InferenceResponse {
684 request_id: request.id,
685 output: Err(TrustformersError::runtime_error(
686 "Request timed out".to_string(),
687 )),
688 processing_time: request.elapsed(),
689 metadata: HashMap::new(),
690 }));
691 }
692
693 let instance_id = {
695 let mut balancer = self.load_balancer.lock().map_err(|_| {
696 TrustformersError::runtime_error("Failed to acquire load balancer lock".to_string())
697 })?;
698
699 match balancer.select_instance() {
700 Some(instance) => {
701 instance.start_request();
702 instance.id.clone()
703 },
704 None => {
705 return Err(TrustformersError::resource_exhausted(
706 "No healthy instances available".to_string(),
707 ));
708 },
709 }
710 };
711
712 let start_time = Instant::now();
714 let output = self.process_inference(&request).await;
715 let processing_time = start_time.elapsed();
716
717 {
719 let mut balancer = self.load_balancer.lock().map_err(|_| {
720 TrustformersError::runtime_error("Failed to acquire load balancer lock".to_string())
721 })?;
722
723 if let Some(instance) = balancer.instances.iter_mut().find(|i| i.id == instance_id) {
724 instance.update_stats(processing_time);
725 }
726 }
727
728 if self.config.enable_metrics {
730 let mut metrics = self.metrics.write().await;
731 metrics.update_request(output.is_ok(), processing_time);
732
733 let queue_size = {
734 let queue = self.request_queue.lock().map_err(|_| {
735 TrustformersError::runtime_error("Failed to acquire queue lock".to_string())
736 })?;
737 queue.size()
738 };
739 metrics.update_queue_size(queue_size);
740 }
741
742 Ok(Some(InferenceResponse {
743 request_id: request.id,
744 output,
745 processing_time,
746 metadata: HashMap::new(),
747 }))
748 }
749
750 async fn process_inference(&self, request: &InferenceRequest) -> Result<Tensor> {
752 match &self.model_fn {
753 Some(model_fn) => {
754 let model_fn = Arc::clone(model_fn);
756 let input_tensor = request.input.clone();
757
758 let output = tokio::task::spawn_blocking(move || (model_fn)(input_tensor))
760 .await
761 .map_err(|e| {
762 TrustformersError::runtime_error(format!("Inference task failed: {}", e))
763 })??;
764
765 Ok(output)
766 },
767 None => {
768 let input = &request.input;
770
771 let tensor_size = match input {
773 Tensor::F32(arr) => arr.len(),
774 Tensor::I64(arr) => arr.len(),
775 _ => 1000, };
777 let processing_time = std::cmp::min(100, tensor_size / 1000); tokio::time::sleep(Duration::from_millis(processing_time as u64)).await;
779
780 Ok(request.input.clone())
782 },
783 }
784 }
785
786 pub async fn get_metrics(&self) -> ServingMetrics {
788 let metrics = self.metrics.read().await;
789 (*metrics).clone()
790 }
791
792 pub async fn cleanup_expired_requests(&self) -> Result<usize> {
794 let timeout_duration = Duration::from_secs(self.config.request_timeout_seconds);
795 let mut queue = self.request_queue.lock().map_err(|_| {
796 TrustformersError::runtime_error("Failed to acquire queue lock".to_string())
797 })?;
798
799 let removed_count = queue.remove_expired(timeout_duration);
800
801 if self.config.enable_metrics && removed_count > 0 {
802 let mut metrics = self.metrics.write().await;
803 for _ in 0..removed_count {
804 metrics.record_timeout();
805 }
806 metrics.update_queue_size(queue.size());
807 }
808
809 Ok(removed_count)
810 }
811
812 pub fn healthy_instances_count(&self) -> Result<usize> {
814 let balancer = self.load_balancer.lock().map_err(|_| {
815 TrustformersError::runtime_error("Failed to acquire load balancer lock".to_string())
816 })?;
817 Ok(balancer.healthy_instances_count())
818 }
819}
820
/// Token-bucket rate limiter: holds up to `max_tokens`, refilled continuously
/// at `refill_rate` tokens per second.
#[derive(Debug)]
pub struct RateLimiter {
    // Bucket capacity.
    max_tokens: u64,
    // Tokens currently available.
    tokens: u64,
    // Refill speed, tokens per second.
    refill_rate: u64,
    // When whole tokens were last credited.
    last_refill: Instant,
}

impl RateLimiter {
    /// Creates a full bucket.
    pub fn new(max_tokens: u64, refill_rate: u64) -> Self {
        RateLimiter {
            max_tokens,
            tokens: max_tokens,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Attempts to take `tokens` from the bucket; on failure nothing is
    /// consumed.
    pub fn try_acquire(&mut self, tokens: u64) -> bool {
        self.refill_tokens();

        let granted = self.tokens >= tokens;
        if granted {
            self.tokens -= tokens;
        }
        granted
    }

    /// Credits tokens earned since the last refill, capped at capacity.
    /// `last_refill` only advances when at least one whole token accrued, so
    /// sub-token intervals keep accumulating instead of being discarded.
    fn refill_tokens(&mut self) {
        let now = Instant::now();
        let earned = (now.duration_since(self.last_refill).as_secs_f64()
            * self.refill_rate as f64) as u64;

        if earned > 0 {
            self.tokens = self.max_tokens.min(self.tokens + earned);
            self.last_refill = now;
        }
    }

    /// Refreshes the bucket and reports the tokens available right now.
    pub fn available_tokens(&mut self) -> u64 {
        self.refill_tokens();
        self.tokens
    }
}
871
/// Configuration for queue/CPU-driven auto-scaling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoScalingConfig {
    /// Master switch; when false, `should_scale` always returns `None`.
    pub enabled: bool,
    /// Lower bound on instance count when scaling down.
    pub min_instances: usize,
    /// Upper bound on instance count when scaling up.
    pub max_instances: usize,
    /// CPU percentage above which scale-up is considered.
    pub target_cpu_utilization: f64,
    /// Queue depth above which scale-up is considered.
    pub scale_up_queue_threshold: usize,
    /// Queue depth below which scale-down is considered.
    pub scale_down_queue_threshold: usize,
    /// Minimum seconds between consecutive scaling actions.
    pub cooldown_period_seconds: u64,
}
890
891impl Default for AutoScalingConfig {
892 fn default() -> Self {
893 Self {
894 enabled: false,
895 min_instances: 1,
896 max_instances: 10,
897 target_cpu_utilization: 70.0,
898 scale_up_queue_threshold: 20,
899 scale_down_queue_threshold: 5,
900 cooldown_period_seconds: 300, }
902 }
903}
904
/// Decides when to add or remove model instances based on queue depth and
/// CPU utilization, with a cooldown between actions.
#[derive(Debug)]
pub struct AutoScaler {
    // Scaling policy parameters.
    config: AutoScalingConfig,
    // When the last scaling action fired (starts the cooldown window).
    last_scaling_action: Option<Instant>,
    // Instance count the scaler currently believes is running.
    current_instance_count: usize,
}
912
913impl AutoScaler {
914 pub fn new(config: AutoScalingConfig, initial_instance_count: usize) -> Self {
916 Self {
917 config,
918 last_scaling_action: None,
919 current_instance_count: initial_instance_count,
920 }
921 }
922
923 pub fn should_scale(
925 &self,
926 queue_size: usize,
927 avg_cpu_utilization: f64,
928 ) -> Option<ScalingAction> {
929 if !self.config.enabled {
930 return None;
931 }
932
933 if let Some(last_action) = self.last_scaling_action {
935 if last_action.elapsed().as_secs() < self.config.cooldown_period_seconds {
936 return None;
937 }
938 }
939
940 if (queue_size > self.config.scale_up_queue_threshold
942 || avg_cpu_utilization > self.config.target_cpu_utilization)
943 && self.current_instance_count < self.config.max_instances
944 {
945 return Some(ScalingAction::ScaleUp);
946 }
947
948 if queue_size < self.config.scale_down_queue_threshold
950 && avg_cpu_utilization < self.config.target_cpu_utilization * 0.5
951 && self.current_instance_count > self.config.min_instances
952 {
953 return Some(ScalingAction::ScaleDown);
954 }
955
956 None
957 }
958
959 pub fn record_scaling_action(&mut self, action: ScalingAction) {
961 self.last_scaling_action = Some(Instant::now());
962
963 match action {
964 ScalingAction::ScaleUp => {
965 self.current_instance_count =
966 (self.current_instance_count + 1).min(self.config.max_instances);
967 },
968 ScalingAction::ScaleDown => {
969 self.current_instance_count =
970 (self.current_instance_count.saturating_sub(1)).max(self.config.min_instances);
971 },
972 }
973 }
974
975 pub fn current_instance_count(&self) -> usize {
977 self.current_instance_count
978 }
979
980 pub fn get_scaling_recommendations(&self, metrics: &ServingMetrics) -> Vec<String> {
982 let mut recommendations = Vec::new();
983
984 if !self.config.enabled {
985 recommendations.push("Auto-scaling is disabled".to_string());
986 return recommendations;
987 }
988
989 let queue_ratio =
990 metrics.current_queue_size as f64 / self.config.scale_up_queue_threshold as f64;
991
992 if queue_ratio > 1.0 {
993 recommendations.push(format!(
994 "Queue size ({}) exceeds scale-up threshold ({}). Consider scaling up.",
995 metrics.current_queue_size, self.config.scale_up_queue_threshold
996 ));
997 } else if queue_ratio < 0.25 {
998 recommendations.push(format!(
999 "Queue size ({}) is very low. Consider scaling down to save resources.",
1000 metrics.current_queue_size
1001 ));
1002 }
1003
1004 if metrics.average_response_time_ms > 1000.0 {
1005 recommendations.push("High response times detected. Consider scaling up.".to_string());
1006 }
1007
1008 if metrics.success_rate() < 0.95 {
1009 recommendations.push(
1010 "Low success rate detected. Check instance health and consider scaling."
1011 .to_string(),
1012 );
1013 }
1014
1015 recommendations
1016 }
1017}
1018
/// A scaling decision produced by [`AutoScaler::should_scale`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScalingAction {
    /// Add one instance (up to the configured maximum).
    ScaleUp,
    /// Remove one instance (down to the configured minimum).
    ScaleDown,
}
1025
/// [`ModelServingManager`] augmented with rate limiting and auto-scaling.
#[derive(Debug)]
pub struct EnhancedServingManager {
    // Underlying queueing/balancing/metrics engine.
    base_manager: Arc<ModelServingManager>,
    // Token bucket applied before submission when limiting is enabled.
    rate_limiter: Arc<Mutex<RateLimiter>>,
    // Scale-up/scale-down decision maker.
    auto_scaler: Arc<Mutex<AutoScaler>>,
    // Rate-limiting policy (enabled flag, rates).
    rate_limit_config: RateLimitConfig,
}
1034
/// Rate-limiting policy for request submission.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Master switch; when false, no tokens are charged.
    pub enabled: bool,
    /// Sustained refill rate of the token bucket, per second.
    pub max_requests_per_second: u64,
    /// Bucket capacity, i.e. the largest tolerated burst.
    pub burst_capacity: u64,
}
1045
1046impl Default for RateLimitConfig {
1047 fn default() -> Self {
1048 Self {
1049 enabled: true,
1050 max_requests_per_second: 100,
1051 burst_capacity: 200,
1052 }
1053 }
1054}
1055
impl EnhancedServingManager {
    /// Composes a base serving manager with rate limiting and auto-scaling
    /// (starting from a single instance).
    pub fn new(
        serving_config: ServingConfig,
        rate_limit_config: RateLimitConfig,
        auto_scaling_config: AutoScalingConfig,
    ) -> Self {
        let limiter = RateLimiter::new(
            rate_limit_config.burst_capacity,
            rate_limit_config.max_requests_per_second,
        );

        Self {
            base_manager: Arc::new(ModelServingManager::new(serving_config)),
            rate_limiter: Arc::new(Mutex::new(limiter)),
            auto_scaler: Arc::new(Mutex::new(AutoScaler::new(auto_scaling_config, 1))),
            rate_limit_config,
        }
    }
1077
1078 pub async fn submit_request_with_rate_limiting(&self, request: InferenceRequest) -> Result<()> {
1080 if self.rate_limit_config.enabled {
1082 let mut limiter = self.rate_limiter.lock().map_err(|_| {
1083 TrustformersError::runtime_error("Failed to acquire rate limiter lock".to_string())
1084 })?;
1085
1086 if !limiter.try_acquire(1) {
1087 return Err(TrustformersError::resource_exhausted(
1088 "Rate limit exceeded".to_string(),
1089 ));
1090 }
1091 }
1092
1093 self.base_manager.submit_request(request).await
1095 }
1096
1097 pub async fn check_auto_scaling(&self) -> Result<Option<ScalingAction>> {
1099 let metrics = self.base_manager.get_metrics().await;
1100
1101 let mut scaler = self.auto_scaler.lock().map_err(|_| {
1102 TrustformersError::runtime_error("Failed to acquire auto-scaler lock".to_string())
1103 })?;
1104
1105 let avg_cpu_utilization = self.get_approximate_cpu_utilization();
1107
1108 if let Some(action) = scaler.should_scale(metrics.current_queue_size, avg_cpu_utilization) {
1109 scaler.record_scaling_action(action);
1110 Ok(Some(action))
1111 } else {
1112 Ok(None)
1113 }
1114 }
1115
1116 pub async fn get_enhanced_metrics(&self) -> Result<EnhancedMetrics> {
1118 let base_metrics = self.base_manager.get_metrics().await;
1119
1120 let available_tokens = {
1121 let mut limiter = self.rate_limiter.lock().map_err(|_| {
1122 TrustformersError::runtime_error("Failed to acquire rate limiter lock".to_string())
1123 })?;
1124 limiter.available_tokens()
1125 };
1126
1127 let (current_instance_count, scaling_recommendations) = {
1128 let scaler = self.auto_scaler.lock().map_err(|_| {
1129 TrustformersError::runtime_error("Failed to acquire auto-scaler lock".to_string())
1130 })?;
1131 (
1132 scaler.current_instance_count(),
1133 scaler.get_scaling_recommendations(&base_metrics),
1134 )
1135 };
1136
1137 Ok(EnhancedMetrics {
1138 base_metrics,
1139 available_rate_limit_tokens: available_tokens,
1140 current_instance_count,
1141 scaling_recommendations,
1142 })
1143 }
1144
1145 fn get_approximate_cpu_utilization(&self) -> f64 {
1147 use std::fs;
1148 use std::io::Read;
1149
1150 #[cfg(unix)]
1152 {
1153 if let Ok(mut file) = fs::File::open("/proc/loadavg") {
1154 let mut contents = String::new();
1155 if file.read_to_string(&mut contents).is_ok() {
1156 let parts: Vec<&str> = contents.split_whitespace().collect();
1157 if let Some(load_1min) = parts.first() {
1158 if let Ok(load) = load_1min.parse::<f64>() {
1159 let num_cores = num_cpus::get() as f64;
1160 let utilization = (load / num_cores * 100.0).min(100.0);
1162 return utilization;
1163 }
1164 }
1165 }
1166 }
1167 }
1168
1169 let queue_size = if let Ok(queue) = self.base_manager.request_queue.lock() {
1171 queue.size() as f64
1172 } else {
1173 0.0
1174 };
1175
1176 let base_utilization = 30.0; let queue_factor = (queue_size * 5.0).min(50.0); (base_utilization + queue_factor).min(95.0) }
1182
    /// Shared handle to the underlying serving manager.
    pub fn base_manager(&self) -> &Arc<ModelServingManager> {
        &self.base_manager
    }
}
1188
/// Combined snapshot returned by `get_enhanced_metrics`.
#[derive(Debug, Clone)]
pub struct EnhancedMetrics {
    /// Metrics from the underlying serving manager.
    pub base_metrics: ServingMetrics,
    /// Tokens currently available in the rate-limiter bucket.
    pub available_rate_limit_tokens: u64,
    /// Instance count the auto-scaler currently tracks.
    pub current_instance_count: usize,
    /// Human-readable scaling advice.
    pub scaling_recommendations: Vec<String>,
}
1197
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults should match the values documented on ServingConfig::default.
    #[test]
    fn test_serving_config_default() {
        let config = ServingConfig::default();
        assert_eq!(config.max_concurrent_requests, 10);
        assert_eq!(config.request_timeout_seconds, 30);
        assert_eq!(config.max_queue_size, 100);
    }
1209
1210 #[test]
1211 fn test_inference_request_creation() {
1212 let tensor = Tensor::zeros(&[1, 2]).expect("operation failed");
1213 let request = InferenceRequest::new(tensor, RequestPriority::Normal);
1214
1215 assert_eq!(request.priority, RequestPriority::Normal);
1216 assert!(!request.metadata.is_empty() || request.metadata.is_empty()); }
1218
    // Instance stats: start_request raises the in-flight count; update_stats
    // lowers it and bumps the completed-request total.
    #[test]
    fn test_model_instance() {
        let mut instance = ModelInstance::new("test-instance".to_string(), 1.0);
        assert_eq!(instance.id, "test-instance");
        assert_eq!(instance.weight, 1.0);
        assert_eq!(instance.active_requests, 0);

        instance.start_request();
        assert_eq!(instance.active_requests, 1);

        instance.update_stats(Duration::from_millis(100));
        assert_eq!(instance.active_requests, 0);
        assert_eq!(instance.total_requests, 1);
    }

    // Round-robin should rotate through instances in registration order.
    #[test]
    fn test_load_balancer() {
        let mut balancer = LoadBalancer::new(LoadBalancingStrategy::RoundRobin);

        let instance1 = ModelInstance::new("instance1".to_string(), 1.0);
        let instance2 = ModelInstance::new("instance2".to_string(), 1.0);

        balancer.add_instance(instance1);
        balancer.add_instance(instance2);

        assert_eq!(balancer.healthy_instances_count(), 2);

        let selected1 = balancer.select_instance().expect("operation failed");
        assert_eq!(selected1.id, "instance1");

        let selected2 = balancer.select_instance().expect("operation failed");
        assert_eq!(selected2.id, "instance2");
    }

    // Capacity is enforced on enqueue and higher priorities dequeue first.
    #[test]
    fn test_request_queue() {
        let mut queue = RequestQueue::new(2);

        let tensor1 = Tensor::zeros(&[1, 2]).expect("operation failed");
        let tensor2 = Tensor::zeros(&[1, 2]).expect("operation failed");
        let tensor3 = Tensor::zeros(&[1, 2]).expect("operation failed");

        let req1 = InferenceRequest::new(tensor1, RequestPriority::Normal);
        let req2 = InferenceRequest::new(tensor2, RequestPriority::High);
        let req3 = InferenceRequest::new(tensor3, RequestPriority::Low);

        assert!(queue.enqueue(req1).is_ok());
        assert!(queue.enqueue(req2).is_ok());
        // Third request exceeds max_size=2 and must be rejected.
        assert!(queue.enqueue(req3).is_err()); assert_eq!(queue.size(), 2);

        let dequeued = queue.dequeue().expect("operation failed");
        assert_eq!(dequeued.priority, RequestPriority::High);
    }

    // One success + one failure: counters split 1/1, average is the mean.
    #[test]
    fn test_serving_metrics() {
        let mut metrics = ServingMetrics::default();

        metrics.update_request(true, Duration::from_millis(100));
        metrics.update_request(false, Duration::from_millis(200));

        assert_eq!(metrics.total_requests, 2);
        assert_eq!(metrics.successful_requests, 1);
        assert_eq!(metrics.failed_requests, 1);
        assert_eq!(metrics.success_rate(), 0.5);
        assert_eq!(metrics.average_response_time_ms, 150.0);
    }

    // End-to-end: submit one request, process it (simulated inference since
    // no model_fn is set), and confirm metrics recorded it.
    #[tokio::test]
    async fn test_model_serving_manager() {
        let config = ServingConfig::default();
        let manager = ModelServingManager::new(config);

        let instance = ModelInstance::new("test-instance".to_string(), 1.0);
        manager.add_instance(instance).expect("operation failed");

        let tensor = Tensor::zeros(&[1, 2]).expect("operation failed");
        let request = InferenceRequest::new(tensor, RequestPriority::Normal);

        manager.submit_request(request).await.expect("operation failed");

        let response = manager.process_next_request().await.expect("operation failed");
        assert!(response.is_some());

        let metrics = manager.get_metrics().await;
        assert_eq!(metrics.total_requests, 1);
    }

    // Token bucket: capacity 10, refill 5/s (too slow to matter mid-test).
    #[test]
    fn test_rate_limiter() {
        let mut limiter = RateLimiter::new(10, 5); assert!(limiter.try_acquire(5));
        assert_eq!(limiter.available_tokens(), 5);

        // Requesting more than remains must fail without consuming tokens.
        assert!(!limiter.try_acquire(10));

        assert!(limiter.try_acquire(5));
        assert_eq!(limiter.available_tokens(), 0);

        assert!(!limiter.try_acquire(1));
    }
1328
1329 #[test]
1330 fn test_auto_scaler() {
1331 let config = AutoScalingConfig {
1332 enabled: true,
1333 min_instances: 1,
1334 max_instances: 5,
1335 target_cpu_utilization: 70.0,
1336 scale_up_queue_threshold: 10,
1337 scale_down_queue_threshold: 2,
1338 cooldown_period_seconds: 60,
1339 };
1340
1341 let mut scaler = AutoScaler::new(config, 2);
1342
1343 let action = scaler.should_scale(15, 50.0);
1345 assert_eq!(action, Some(ScalingAction::ScaleUp));
1346
1347 scaler.record_scaling_action(ScalingAction::ScaleUp);
1349 assert_eq!(scaler.current_instance_count(), 3);
1350
1351 let action = scaler.should_scale(15, 50.0);
1353 assert_eq!(action, None);
1354 }
1355
1356 #[test]
1357 fn test_auto_scaling_recommendations() {
1358 let config = AutoScalingConfig {
1359 enabled: true,
1360 scale_up_queue_threshold: 20,
1361 ..Default::default()
1362 };
1363 let scaler = AutoScaler::new(config, 2);
1364
1365 let mut metrics = ServingMetrics {
1366 current_queue_size: 25, ..ServingMetrics::default()
1368 };
1369 metrics.update_request(true, Duration::from_millis(1500)); let recommendations = scaler.get_scaling_recommendations(&metrics);
1372 assert!(!recommendations.is_empty());
1373 assert!(recommendations.iter().any(|r| r.contains("scale-up threshold")));
1374 assert!(recommendations.iter().any(|r| r.contains("High response times")));
1375 }
1376
1377 #[tokio::test]
1378 async fn test_enhanced_serving_manager() {
1379 let serving_config = ServingConfig::default();
1380 let rate_limit_config = RateLimitConfig {
1381 enabled: true,
1382 max_requests_per_second: 2,
1383 burst_capacity: 5,
1384 };
1385 let auto_scaling_config = AutoScalingConfig::default();
1386
1387 let manager =
1388 EnhancedServingManager::new(serving_config, rate_limit_config, auto_scaling_config);
1389
1390 let instance = ModelInstance::new("test-instance".to_string(), 1.0);
1392 manager.base_manager().add_instance(instance).expect("operation failed");
1393
1394 let tensor = Tensor::zeros(&[1, 2]).expect("operation failed");
1396
1397 for _ in 0..5 {
1399 let request = InferenceRequest::new(tensor.clone(), RequestPriority::Normal);
1400 let result = manager.submit_request_with_rate_limiting(request).await;
1401 assert!(result.is_ok());
1402 }
1403
1404 let request = InferenceRequest::new(tensor, RequestPriority::Normal);
1406 let result = manager.submit_request_with_rate_limiting(request).await;
1407 assert!(result.is_err());
1408
1409 let enhanced_metrics = manager.get_enhanced_metrics().await.expect("operation failed");
1411 assert_eq!(enhanced_metrics.current_instance_count, 1);
1412 assert!(enhanced_metrics.available_rate_limit_tokens < 5);
1413 }
1414
1415 #[tokio::test]
1416 async fn test_enhanced_serving_auto_scaling() {
1417 let serving_config = ServingConfig::default();
1418 let rate_limit_config = RateLimitConfig::default();
1419 let auto_scaling_config = AutoScalingConfig {
1420 enabled: true,
1421 min_instances: 1,
1422 max_instances: 3,
1423 scale_up_queue_threshold: 5,
1424 scale_down_queue_threshold: 1,
1425 cooldown_period_seconds: 0, ..Default::default()
1427 };
1428
1429 let manager =
1430 EnhancedServingManager::new(serving_config, rate_limit_config, auto_scaling_config);
1431
1432 let tensor = Tensor::zeros(&[1, 2]).expect("operation failed");
1434 for _ in 0..10 {
1435 let request = InferenceRequest::new(tensor.clone(), RequestPriority::Normal);
1436 manager.base_manager().submit_request(request).await.expect("operation failed");
1437 }
1438
1439 let scaling_action = manager.check_auto_scaling().await.expect("operation failed");
1441 assert_eq!(scaling_action, Some(ScalingAction::ScaleUp));
1442
1443 let enhanced_metrics = manager.get_enhanced_metrics().await.expect("operation failed");
1444 assert_eq!(enhanced_metrics.current_instance_count, 2);
1445 }
1446}