1use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39
40use std::collections::HashMap;
41
42use swarm_engine_core::actions::ActionDef;
43use swarm_engine_core::agent::{BatchDecisionRequest, DecisionResponse, WorkerDecisionRequest};
44use swarm_engine_core::exploration::DependencyGraph;
45use swarm_engine_core::types::{LoraConfig, WorkerId};
46
47use crate::decider::{LlmDecider, LlmError};
48
/// Outcome of one batch: a per-worker decision result for every request
/// in the batch, tagged with the originating [`WorkerId`].
pub type BatchProcessResult = Vec<(WorkerId, Result<DecisionResponse, BatchProcessError>)>;
55
/// Error raised while processing a batch decision request, split by
/// whether a retry could plausibly succeed.
#[derive(Debug, Clone, thiserror::Error)]
pub enum BatchProcessError {
    /// Retryable failure (the underlying cause may clear on its own).
    #[error("Batch process error (transient): {0}")]
    Transient(String),

    /// Non-retryable failure; retrying with the same input will not help.
    #[error("Batch process error: {0}")]
    Permanent(String),
}
67
68impl BatchProcessError {
69 pub fn transient(message: impl Into<String>) -> Self {
70 Self::Transient(message.into())
71 }
72
73 pub fn permanent(message: impl Into<String>) -> Self {
74 Self::Permanent(message.into())
75 }
76
77 pub fn is_transient(&self) -> bool {
78 matches!(self, Self::Transient(_))
79 }
80
81 pub fn message(&self) -> &str {
82 match self {
83 Self::Transient(msg) => msg,
84 Self::Permanent(msg) => msg,
85 }
86 }
87}
88
89impl From<LlmError> for BatchProcessError {
90 fn from(e: LlmError) -> Self {
91 if e.is_transient() {
92 Self::Transient(e.message().to_string())
93 } else {
94 Self::Permanent(e.message().to_string())
95 }
96 }
97}
98
99impl From<swarm_engine_core::error::SwarmError> for BatchProcessError {
100 fn from(err: swarm_engine_core::error::SwarmError) -> Self {
101 if err.is_transient() {
102 Self::Transient(err.message())
103 } else {
104 Self::Permanent(err.message())
105 }
106 }
107}
108
109impl From<BatchProcessError> for swarm_engine_core::error::SwarmError {
110 fn from(err: BatchProcessError) -> Self {
111 match err {
112 BatchProcessError::Transient(message) => {
113 swarm_engine_core::error::SwarmError::LlmTransient { message }
114 }
115 BatchProcessError::Permanent(message) => {
116 swarm_engine_core::error::SwarmError::LlmPermanent { message }
117 }
118 }
119 }
120}
121
/// Backend capable of deciding a whole batch of worker requests.
pub trait BatchProcessor: Send + Sync {
    /// Processes every request in the batch, yielding one result per worker.
    fn process(
        &self,
        request: BatchDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = BatchProcessResult> + Send + '_>>;

    /// Optionally plans an action dependency graph for `task`.
    ///
    /// The default implementation opts out by resolving to `None`.
    fn plan_dependencies(
        &self,
        _task: &str,
        _actions: &[ActionDef],
    ) -> Pin<Box<dyn Future<Output = Option<DependencyGraph>> + Send + '_>> {
        Box::pin(async { None })
    }

    /// Reports whether the underlying backend is currently usable.
    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>>;

    /// Human-readable identifier for this processor (the LLM-backed
    /// implementation returns its model name).
    fn name(&self) -> &str;
}
160
/// Tuning knobs for [`LlmBatchProcessor`].
#[derive(Debug, Clone)]
pub struct LlmBatchProcessorConfig {
    /// When `true`, requests are dispatched concurrently (grouped by LoRA
    /// config); when `false`, they run strictly one after another.
    pub parallel: bool,
    /// Fallback cap on in-flight requests per group, used only when the
    /// decider does not report its own `max_concurrency`.
    pub max_concurrency: usize,
    /// Maximum retry attempts; not read within this module —
    /// NOTE(review): presumably enforced by callers, confirm consumer.
    pub max_retries: Option<usize>,
}
175
176impl Default for LlmBatchProcessorConfig {
177 fn default() -> Self {
178 Self {
179 parallel: true,
180 max_concurrency: 4,
181 max_retries: Some(5),
182 }
183 }
184}
185
/// [`BatchProcessor`] implementation backed by an [`LlmDecider`].
pub struct LlmBatchProcessor<D: LlmDecider> {
    // Shared so per-request futures and returned futures can each hold a
    // cheap clone of the decider handle.
    decider: Arc<D>,
    // Dispatch tuning (parallelism, concurrency fallback, retries).
    config: LlmBatchProcessorConfig,
}
194
195impl<D: LlmDecider> LlmBatchProcessor<D> {
196 pub fn new(decider: D) -> Self {
198 Self {
199 decider: Arc::new(decider),
200 config: LlmBatchProcessorConfig::default(),
201 }
202 }
203
204 pub fn from_arc(decider: Arc<D>) -> Self {
206 Self {
207 decider,
208 config: LlmBatchProcessorConfig::default(),
209 }
210 }
211
212 pub fn with_config(mut self, config: LlmBatchProcessorConfig) -> Self {
214 self.config = config;
215 self
216 }
217}
218
impl<D: LlmDecider + 'static> BatchProcessor for LlmBatchProcessor<D> {
    /// Processes a batch of worker decision requests.
    ///
    /// Empty batches short-circuit to an empty result. Otherwise each
    /// request is keyed by its worker id and dispatched either in
    /// parallel (grouped by LoRA config) or strictly sequentially,
    /// depending on `config.parallel`.
    fn process(
        &self,
        request: BatchDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = BatchProcessResult> + Send + '_>> {
        Box::pin(async move {
            if request.requests.is_empty() {
                return vec![];
            }

            // Pair each request with its worker id so results can be
            // attributed back after processing.
            let requests: Vec<(WorkerId, WorkerDecisionRequest)> = request
                .requests
                .into_iter()
                .map(|r| (r.worker_id, r))
                .collect();

            if self.config.parallel {
                self.process_parallel(requests).await
            } else {
                self.process_sequential(requests).await
            }
        })
    }

    /// Builds a linear [`DependencyGraph`] over `actions` for `task`.
    ///
    /// Actions are split into a "discover" phase (`NodeExpand`) and a
    /// "not-discover" phase (`NodeStateChange`); each phase is ordered via
    /// pairwise LLM comparison, then the two ordered chains are joined
    /// discover-first with fixed edge weight 0.9. Actions of any other
    /// category are excluded from the chains but still listed in the
    /// graph's available actions. Emits a learning event as a side effect.
    fn plan_dependencies(
        &self,
        task: &str,
        actions: &[ActionDef],
    ) -> Pin<Box<dyn Future<Output = Option<DependencyGraph>> + Send + '_>> {
        // Own the inputs so the returned future does not borrow the
        // caller's arguments.
        let task = task.to_string();
        let actions: Vec<ActionDef> = actions.to_vec();
        let decider = Arc::clone(&self.decider);

        Box::pin(async move {
            use std::time::Instant;
            use swarm_engine_core::actions::ActionCategory;
            use swarm_engine_core::exploration::DependencyGraphBuilder;

            let start_time = Instant::now();
            let action_names: Vec<String> = actions.iter().map(|a| a.name.clone()).collect();

            // Split actions into the two phases that the graph chains up.
            let discover: Vec<&ActionDef> = actions
                .iter()
                .filter(|a| a.category == ActionCategory::NodeExpand)
                .collect();
            let not_discover: Vec<&ActionDef> = actions
                .iter()
                .filter(|a| a.category == ActionCategory::NodeStateChange)
                .collect();

            tracing::debug!(
                discover = ?discover.iter().map(|a| &a.name).collect::<Vec<_>>(),
                not_discover = ?not_discover.iter().map(|a| &a.name).collect::<Vec<_>>(),
                "Separated actions by category"
            );

            // Order the discover phase; trivial phases (0-1 actions) skip
            // the LLM round-trips entirely.
            let discover_sort_start = Instant::now();
            let sorted_discover = if discover.len() <= 1 {
                discover.iter().map(|a| a.name.clone()).collect()
            } else {
                binary_sort_actions(&task, &discover, decider.as_ref()).await
            };
            let discover_sort_ms = discover_sort_start.elapsed().as_millis();

            tracing::debug!(
                sorted = ?sorted_discover,
                elapsed_ms = discover_sort_ms,
                "Sorted Discover actions via binary comparison"
            );

            // Same treatment for the not-discover phase.
            let not_discover_sort_start = Instant::now();
            let sorted_not_discover = if not_discover.len() <= 1 {
                not_discover.iter().map(|a| a.name.clone()).collect()
            } else {
                binary_sort_actions(&task, &not_discover, decider.as_ref()).await
            };
            let not_discover_sort_ms = not_discover_sort_start.elapsed().as_millis();

            tracing::debug!(
                sorted = ?sorted_not_discover,
                elapsed_ms = not_discover_sort_ms,
                "Sorted NotDiscover actions via binary comparison"
            );

            let mut builder = DependencyGraphBuilder::new()
                .task(&task)
                .available_actions(action_names.clone());

            // Start node: first discover action, falling back to the
            // first not-discover action when the discover phase is empty.
            if !sorted_discover.is_empty() {
                builder = builder.start_node(&sorted_discover[0]);
            } else if !sorted_not_discover.is_empty() {
                builder = builder.start_node(&sorted_not_discover[0]);
            }

            // Terminal node: last not-discover action, falling back to
            // the last discover action.
            if let Some(last) = sorted_not_discover.last() {
                builder = builder.terminal_node(last);
            } else if !sorted_discover.is_empty() {
                builder = builder.terminal_node(sorted_discover.last().unwrap());
            }

            // Chain consecutive discover actions.
            for window in sorted_discover.windows(2) {
                builder = builder.edge(&window[0], &window[1], 0.9);
            }

            // Bridge the two phases: last discover -> first not-discover.
            if !sorted_discover.is_empty() && !sorted_not_discover.is_empty() {
                builder = builder.edge(
                    sorted_discover.last().unwrap(),
                    &sorted_not_discover[0],
                    0.9,
                );
            }

            // Chain consecutive not-discover actions.
            for window in sorted_not_discover.windows(2) {
                builder = builder.edge(&window[0], &window[1], 0.9);
            }

            let mut graph = builder.build();
            let total_ms = start_time.elapsed().as_millis();

            graph.set_action_order(sorted_discover.clone(), sorted_not_discover.clone());

            // Emit a learning event describing this inference, and attach
            // the resulting record to the graph.
            {
                use swarm_engine_core::events::{LearningEvent, LearningEventChannel};
                use swarm_engine_core::learn::DependencyGraphRecord;

                // Reconstruct a prompt-like summary for the learning log;
                // this is not the literal prompt sent per comparison.
                let prompt = format!(
                    "Task: {}\n\nAvailable Actions:\n{}",
                    task,
                    action_names
                        .iter()
                        .map(|n| format!("- {}", n))
                        .collect::<Vec<_>>()
                        .join("\n")
                );

                let response = format!(
                    "discover_order: {:?}\nnot_discover_order: {:?}",
                    sorted_discover, sorted_not_discover
                );

                let event = LearningEvent::dependency_graph_inference(decider.model_name())
                    .prompt(&prompt)
                    .response(&response)
                    .available_actions(action_names)
                    .discover_order(sorted_discover.clone())
                    .not_discover_order(sorted_not_discover.clone())
                    .endpoint(decider.endpoint())
                    .latency_ms(total_ms as u64)
                    .success()
                    .build();

                LearningEventChannel::global().emit(event.clone());

                let record = DependencyGraphRecord::from(&event);
                graph.set_learn_record(record);
            }

            tracing::info!(
                discover_order = ?sorted_discover,
                not_discover_order = ?sorted_not_discover,
                edges = graph.edges().len(),
                discover_sort_ms = discover_sort_ms,
                not_discover_sort_ms = not_discover_sort_ms,
                total_ms = total_ms,
                "DependencyGraph generated via LLM binary sort"
            );

            Some(graph)
        })
    }

    /// Delegates the health probe to the underlying decider.
    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
        let decider = Arc::clone(&self.decider);
        Box::pin(async move { decider.is_healthy().await })
    }

    /// Identifies this processor by the decider's model name.
    fn name(&self) -> &str {
        self.decider.model_name()
    }
}
418
419impl<D: LlmDecider + 'static> LlmBatchProcessor<D> {
420 async fn process_parallel(
437 &self,
438 requests: Vec<(WorkerId, WorkerDecisionRequest)>,
439 ) -> BatchProcessResult {
440 let grouped = group_by_lora(requests);
442
443 let group_count = grouped.len();
444 if group_count > 1 {
445 tracing::debug!(
446 groups = group_count,
447 "Processing requests in {} LoRA groups",
448 group_count
449 );
450 }
451
452 let mut all_results = Vec::new();
454 for (lora_config, group_requests) in grouped {
455 if group_count > 1 {
456 tracing::trace!(
457 lora = ?lora_config,
458 count = group_requests.len(),
459 "Processing LoRA group"
460 );
461 }
462 let results = self.process_group(group_requests).await;
463 all_results.extend(results);
464 }
465
466 all_results
467 }
468
469 async fn process_group(
471 &self,
472 requests: Vec<(WorkerId, WorkerDecisionRequest)>,
473 ) -> BatchProcessResult {
474 use futures::future::join_all;
475 use tokio::sync::Semaphore;
476
477 let max_concurrency = self
479 .decider
480 .max_concurrency()
481 .await
482 .unwrap_or(self.config.max_concurrency);
483
484 let semaphore = Arc::new(Semaphore::new(max_concurrency));
485
486 let futures: Vec<_> = requests
487 .into_iter()
488 .map(|(worker_id, req)| {
489 let decider = Arc::clone(&self.decider);
490 let sem = Arc::clone(&semaphore);
491 async move {
492 let _permit = sem.acquire().await.expect("Semaphore closed");
494 let result = decider.decide(req).await;
495 (worker_id, result)
496 }
497 })
498 .collect();
499
500 let results = join_all(futures).await;
501
502 results
503 .into_iter()
504 .map(|(worker_id, result)| {
505 let mapped = result.map_err(BatchProcessError::from);
506 (worker_id, mapped)
507 })
508 .collect()
509 }
510
511 async fn process_sequential(
513 &self,
514 requests: Vec<(WorkerId, WorkerDecisionRequest)>,
515 ) -> BatchProcessResult {
516 let mut results = Vec::with_capacity(requests.len());
517
518 for (worker_id, req) in requests {
519 let result = self.decider.decide(req).await;
520 let mapped = result.map_err(BatchProcessError::from);
521 results.push((worker_id, mapped));
522 }
523
524 results
525 }
526}
527
528fn group_by_lora(
533 requests: Vec<(WorkerId, WorkerDecisionRequest)>,
534) -> HashMap<Option<LoraConfig>, Vec<(WorkerId, WorkerDecisionRequest)>> {
535 let mut groups: HashMap<Option<LoraConfig>, Vec<(WorkerId, WorkerDecisionRequest)>> =
536 HashMap::new();
537
538 for (worker_id, req) in requests {
539 let lora_key = req.lora.clone();
540 groups.entry(lora_key).or_default().push((worker_id, req));
541 }
542
543 groups
544}
545
546async fn binary_sort_actions<D: LlmDecider>(
555 task: &str,
556 actions: &[&ActionDef],
557 decider: &D,
558) -> Vec<String> {
559 use futures::future::join_all;
560 use std::collections::HashMap;
561
562 if actions.len() <= 1 {
563 return actions.iter().map(|a| a.name.clone()).collect();
564 }
565
566 let mut requests: Vec<(usize, usize, String, String, String)> = Vec::new();
569 let mut pair_index = 0;
570
571 for i in 0..actions.len() {
572 for j in (i + 1)..actions.len() {
573 let a = actions[i];
574 let b = actions[j];
575 let prompt = format!(
576 "Goal: {}\n- {}: {}\n- {}: {}\nWhich comes first: {} or {}?\nAnswer (one word):",
577 task, a.name, a.description, b.name, b.description, a.name, b.name
578 );
579
580 for vote_idx in 0..3 {
582 requests.push((
583 pair_index,
584 vote_idx,
585 prompt.clone(),
586 a.name.clone(),
587 b.name.clone(),
588 ));
589 }
590 pair_index += 1;
591 }
592 }
593
594 let total_requests = requests.len();
595 tracing::debug!(
596 pairs = pair_index,
597 total_requests = total_requests,
598 "Binary sort: sending batch requests"
599 );
600
601 let futures: Vec<_> = requests
604 .into_iter()
605 .map(|(pair_idx, vote_idx, prompt, a_name, b_name)| {
606 let decider_ref = decider;
607 async move {
608 let result = decider_ref.call_raw(&prompt, None).await;
609 (pair_idx, vote_idx, result, a_name, b_name)
610 }
611 })
612 .collect();
613
614 let results = join_all(futures).await;
615
616 let mut pair_votes: HashMap<usize, (usize, usize, String, String)> = HashMap::new();
619
620 for (pair_idx, _vote_idx, result, a_name, b_name) in results {
621 let entry = pair_votes
622 .entry(pair_idx)
623 .or_insert((0, 0, a_name.clone(), b_name.clone()));
624
625 if let Ok(response) = result {
626 let response_upper = response.to_uppercase();
627 let a_upper = a_name.to_uppercase();
628 let b_upper = b_name.to_uppercase();
629
630 if response_upper.contains(&a_upper) {
631 entry.0 += 1;
632 } else if response_upper.contains(&b_upper) {
633 entry.1 += 1;
634 }
635 }
636 }
637
638 let mut wins: HashMap<String, usize> = HashMap::new();
640 for a in actions {
641 wins.insert(a.name.clone(), 0);
642 }
643
644 for (_pair_idx, (a_count, b_count, a_name, b_name)) in pair_votes {
645 if a_count >= b_count {
647 *wins.get_mut(&b_name).unwrap() += 1;
649 } else {
650 *wins.get_mut(&a_name).unwrap() += 1;
652 }
653 }
654
655 let mut sorted: Vec<_> = wins.into_iter().collect();
657 sorted.sort_by_key(|(_, count)| *count);
658
659 tracing::debug!(
660 sorted = ?sorted.iter().map(|(n, c)| format!("{}:{}", n, c)).collect::<Vec<_>>(),
661 "Binary sort completed"
662 );
663
664 sorted.into_iter().map(|(name, _)| name).collect()
665}
666
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_process_error_transient() {
        let err = BatchProcessError::transient("connection timeout");
        assert!(err.is_transient());
        assert_eq!(err.message(), "connection timeout");
    }

    #[test]
    fn test_batch_process_error_permanent() {
        let err = BatchProcessError::permanent("invalid model");
        assert!(!err.is_transient());
        assert_eq!(err.message(), "invalid model");
    }

    #[test]
    fn test_batch_process_error_from_llm_error() {
        let llm_err = LlmError::transient("timeout");
        let batch_err: BatchProcessError = llm_err.into();
        assert!(batch_err.is_transient());
        assert_eq!(batch_err.message(), "timeout");
    }

    // Renamed from `test_ollama_batch_processor_config_default`: the type
    // under test is `LlmBatchProcessorConfig`, not Ollama-specific.
    #[test]
    fn test_llm_batch_processor_config_default() {
        let config = LlmBatchProcessorConfig::default();
        assert!(config.parallel);
        assert_eq!(config.max_concurrency, 4);
        assert_eq!(config.max_retries, Some(5));
    }

    use std::collections::HashMap;

    /// Synchronous mirror of the pairwise-sort tally used by
    /// `binary_sort_actions`: the loser of each comparison accrues a
    /// point; ascending point count yields the final order.
    fn binary_sort_sync(
        actions: &[&str],
        comparator: impl Fn(&str, &str) -> String,
    ) -> Vec<String> {
        if actions.len() <= 1 {
            return actions.iter().map(|s| s.to_string()).collect();
        }

        let mut wins: HashMap<String, usize> = HashMap::new();
        for &a in actions {
            wins.insert(a.to_string(), 0);
        }

        for i in 0..actions.len() {
            for j in (i + 1)..actions.len() {
                let a = actions[i];
                let b = actions[j];
                let winner = comparator(a, b);

                // The loser of the pair gets a point; any non-`a` answer
                // counts as a win for `b`.
                if winner == a {
                    *wins.get_mut(b).unwrap() += 1;
                } else {
                    *wins.get_mut(a).unwrap() += 1;
                }
            }
        }

        let mut sorted: Vec<_> = wins.into_iter().collect();
        sorted.sort_by_key(|(_, count)| *count);
        sorted.into_iter().map(|(name, _)| name).collect()
    }

    #[test]
    fn test_binary_sort_two_actions() {
        // Comparator always picks its first argument => input order kept.
        let result = binary_sort_sync(&["Fetch", "Summarize"], |a, _b| a.to_string());
        assert_eq!(result, vec!["Fetch", "Summarize"]);

        // Comparator always picks its second argument => order reversed.
        let result = binary_sort_sync(&["Fetch", "Summarize"], |_a, b| b.to_string());
        assert_eq!(result, vec!["Summarize", "Fetch"]);
    }

    #[test]
    fn test_binary_sort_three_actions() {
        // A consistent total order should be fully recovered regardless
        // of the (shuffled) input order.
        let result = binary_sort_sync(&["Test", "Deploy", "Build"], |a, b| {
            let order = ["Build", "Test", "Deploy"];
            let a_idx = order.iter().position(|&x| x == a).unwrap();
            let b_idx = order.iter().position(|&x| x == b).unwrap();
            if a_idx < b_idx {
                a.to_string()
            } else {
                b.to_string()
            }
        });
        assert_eq!(result, vec!["Build", "Test", "Deploy"]);
    }

    #[test]
    fn test_binary_sort_wins_calculation() {
        // Hand-rolled tally mirroring the loss-count scheme: B loses one
        // pair, C loses two, A loses none.
        let mut wins: HashMap<String, usize> = HashMap::new();
        wins.insert("A".to_string(), 0);
        wins.insert("B".to_string(), 0);
        wins.insert("C".to_string(), 0);

        *wins.get_mut("B").unwrap() += 1;
        *wins.get_mut("C").unwrap() += 1;
        *wins.get_mut("C").unwrap() += 1;

        assert_eq!(wins["A"], 0);
        assert_eq!(wins["B"], 1);
        assert_eq!(wins["C"], 2);

        let mut sorted: Vec<_> = wins.into_iter().collect();
        sorted.sort_by_key(|(_, count)| *count);
        let result: Vec<_> = sorted.into_iter().map(|(name, _)| name).collect();

        assert_eq!(result, vec!["A", "B", "C"]);
    }

    /// Mirror of the response-parsing rule in `binary_sort_actions`:
    /// case-insensitive containment, with `a` checked before `b`.
    fn extract_winner(response: &str, a: &str, b: &str) -> Option<String> {
        let response_upper = response.to_uppercase();
        let a_upper = a.to_uppercase();
        let b_upper = b.to_uppercase();

        if response_upper.contains(&a_upper) {
            Some(a.to_string())
        } else if response_upper.contains(&b_upper) {
            Some(b.to_string())
        } else {
            None
        }
    }

    #[test]
    fn test_extract_winner() {
        assert_eq!(
            extract_winner("Fetch", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );
        assert_eq!(
            extract_winner("Summarize", "Fetch", "Summarize"),
            Some("Summarize".to_string())
        );

        // Surrounding whitespace is tolerated (containment check).
        assert_eq!(
            extract_winner(" Fetch", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );

        // Matching is case-insensitive in both directions.
        assert_eq!(
            extract_winner("fetch", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );
        assert_eq!(
            extract_winner("FETCH", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );

        // The name may be embedded in a longer sentence.
        assert_eq!(
            extract_winner("The answer is Fetch.", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );

        assert_eq!(extract_winner("Unknown", "Fetch", "Summarize"), None);

        // When both names appear, `a` wins because it is checked first.
        assert_eq!(
            extract_winner("Fetch then Summarize", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );
    }

    #[test]
    fn test_vote_majority() {
        // Majority vote over several responses; ties favor `a`, matching
        // the tie rule in `binary_sort_actions`.
        fn vote_majority(responses: &[&str], a: &str, b: &str) -> String {
            let mut a_count = 0;
            let mut b_count = 0;

            for response in responses {
                if let Some(winner) = extract_winner(response, a, b) {
                    if winner == a {
                        a_count += 1;
                    } else {
                        b_count += 1;
                    }
                }
            }

            if a_count >= b_count {
                a.to_string()
            } else {
                b.to_string()
            }
        }

        assert_eq!(
            vote_majority(&["Fetch", "Fetch", "Fetch"], "Fetch", "Summarize"),
            "Fetch"
        );

        assert_eq!(
            vote_majority(&["Fetch", "Summarize", "Fetch"], "Fetch", "Summarize"),
            "Fetch"
        );

        assert_eq!(
            vote_majority(&["Summarize", "Summarize", "Fetch"], "Fetch", "Summarize"),
            "Summarize"
        );

        // Unparseable votes are dropped; the remaining 1-1 tie favors `a`.
        assert_eq!(
            vote_majority(&["Fetch", "Summarize", "Unknown"], "Fetch", "Summarize"),
            "Fetch"
        );
    }

    use swarm_engine_core::context::{ContextTarget, GlobalContext, ResolvedContext};

    /// Builds a minimal worker decision request for grouping tests.
    fn create_test_request(
        worker_id: usize,
        lora: Option<LoraConfig>,
    ) -> (WorkerId, WorkerDecisionRequest) {
        let global = GlobalContext {
            tick: 0,
            max_ticks: 100,
            progress: 0.0,
            success_rate: 0.0,
            task_description: Some("test".to_string()),
            hint: None,
        };
        let context = ResolvedContext::new(global, ContextTarget::Worker(WorkerId(worker_id)));

        (
            WorkerId(worker_id),
            WorkerDecisionRequest {
                worker_id: WorkerId(worker_id),
                query: format!("query_{}", worker_id),
                context,
                lora,
            },
        )
    }

    #[test]
    fn test_group_by_lora_single_group_no_lora() {
        let requests = vec![
            create_test_request(0, None),
            create_test_request(1, None),
            create_test_request(2, None),
        ];

        let groups = group_by_lora(requests);

        assert_eq!(groups.len(), 1);
        assert!(groups.contains_key(&None));
        assert_eq!(groups[&None].len(), 3);
    }

    #[test]
    fn test_group_by_lora_single_group_with_lora() {
        let lora = LoraConfig::with_id(0);
        let requests = vec![
            create_test_request(0, Some(lora.clone())),
            create_test_request(1, Some(lora.clone())),
        ];

        let groups = group_by_lora(requests);

        assert_eq!(groups.len(), 1);
        assert!(groups.contains_key(&Some(lora)));
    }

    #[test]
    fn test_group_by_lora_multiple_groups() {
        let lora_a = LoraConfig::with_id(0);
        let lora_b = LoraConfig::with_id(1);

        let requests = vec![
            create_test_request(0, Some(lora_a.clone())),
            create_test_request(1, Some(lora_b.clone())),
            create_test_request(2, Some(lora_a.clone())),
            create_test_request(3, None),
            create_test_request(4, Some(lora_b.clone())),
        ];

        let groups = group_by_lora(requests);

        assert_eq!(groups.len(), 3);
        assert_eq!(groups[&Some(lora_a)].len(), 2);
        assert_eq!(groups[&Some(lora_b)].len(), 2);
        assert_eq!(groups[&None].len(), 1);
    }

    #[test]
    fn test_group_by_lora_preserves_order_within_group() {
        let lora = LoraConfig::with_id(0);
        let requests = vec![
            create_test_request(5, Some(lora.clone())),
            create_test_request(3, Some(lora.clone())),
            create_test_request(7, Some(lora.clone())),
        ];

        let groups = group_by_lora(requests);
        let group = &groups[&Some(lora)];

        // Arrival order, not worker-id order, must be preserved.
        assert_eq!(group[0].0, WorkerId(5));
        assert_eq!(group[1].0, WorkerId(3));
        assert_eq!(group[2].0, WorkerId(7));
    }

    #[test]
    fn test_group_by_lora_different_scales() {
        // Same adapter id with different scales must land in different
        // groups: the whole LoraConfig is the grouping key.
        let lora_full = LoraConfig::new(0, 1.0);
        let lora_half = LoraConfig::new(0, 0.5);

        let requests = vec![
            create_test_request(0, Some(lora_full.clone())),
            create_test_request(1, Some(lora_half.clone())),
            create_test_request(2, Some(lora_full.clone())),
        ];

        let groups = group_by_lora(requests);

        assert_eq!(groups.len(), 2);
        assert_eq!(groups[&Some(lora_full)].len(), 2);
        assert_eq!(groups[&Some(lora_half)].len(), 1);
    }

    #[test]
    fn test_group_by_lora_empty() {
        let requests: Vec<(WorkerId, WorkerDecisionRequest)> = vec![];
        let groups = group_by_lora(requests);
        assert!(groups.is_empty());
    }
}