1use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39
40use std::collections::HashMap;
41
42use swarm_engine_core::actions::ActionDef;
43use swarm_engine_core::agent::{BatchDecisionRequest, DecisionResponse, WorkerDecisionRequest};
44use swarm_engine_core::exploration::{DependencyGraph, SelectResult};
45use swarm_engine_core::types::{LoraConfig, WorkerId};
46
47use crate::decider::{LlmDecider, LlmError};
48
/// Per-worker outcome of one batch: each entry pairs a worker id with either
/// its decision or the error that prevented one.
pub type BatchProcessResult = Vec<(WorkerId, Result<DecisionResponse, BatchProcessError>)>;
55
/// Error raised while processing a batch of decision requests.
///
/// The transient/permanent split tells callers whether a retry is worthwhile:
/// `Transient` failures may succeed on retry, `Permanent` ones will not.
#[derive(Debug, Clone, thiserror::Error)]
pub enum BatchProcessError {
    /// Retryable failure (e.g. a timeout or temporary connectivity loss).
    #[error("Batch process error (transient): {0}")]
    Transient(String),

    /// Non-retryable failure; the same input will fail again.
    #[error("Batch process error: {0}")]
    Permanent(String),
}
67
68impl BatchProcessError {
69 pub fn transient(message: impl Into<String>) -> Self {
70 Self::Transient(message.into())
71 }
72
73 pub fn permanent(message: impl Into<String>) -> Self {
74 Self::Permanent(message.into())
75 }
76
77 pub fn is_transient(&self) -> bool {
78 matches!(self, Self::Transient(_))
79 }
80
81 pub fn message(&self) -> &str {
82 match self {
83 Self::Transient(msg) => msg,
84 Self::Permanent(msg) => msg,
85 }
86 }
87}
88
89impl From<LlmError> for BatchProcessError {
90 fn from(e: LlmError) -> Self {
91 if e.is_transient() {
92 Self::Transient(e.message().to_string())
93 } else {
94 Self::Permanent(e.message().to_string())
95 }
96 }
97}
98
99impl From<swarm_engine_core::error::SwarmError> for BatchProcessError {
100 fn from(err: swarm_engine_core::error::SwarmError) -> Self {
101 if err.is_transient() {
102 Self::Transient(err.message())
103 } else {
104 Self::Permanent(err.message())
105 }
106 }
107}
108
109impl From<BatchProcessError> for swarm_engine_core::error::SwarmError {
110 fn from(err: BatchProcessError) -> Self {
111 match err {
112 BatchProcessError::Transient(message) => {
113 swarm_engine_core::error::SwarmError::LlmTransient { message }
114 }
115 BatchProcessError::Permanent(message) => {
116 swarm_engine_core::error::SwarmError::LlmPermanent { message }
117 }
118 }
119 }
120}
121
/// Processes batches of worker decision requests, typically by delegating to
/// an LLM backend.
pub trait BatchProcessor: Send + Sync {
    /// Process every request in the batch, returning one `(worker, result)`
    /// entry per input request.
    fn process(
        &self,
        request: BatchDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = BatchProcessResult> + Send + '_>>;

    /// Optionally plan an action dependency graph for `_task`.
    ///
    /// The default implementation opts out by resolving to `None`;
    /// implementors that can infer an ordering (such as `LlmBatchProcessor`
    /// in this module) override it.
    fn plan_dependencies(
        &self,
        _task: &str,
        _actions: &[ActionDef],
        _hint: Option<&SelectResult>,
    ) -> Pin<Box<dyn Future<Output = Option<DependencyGraph>> + Send + '_>> {
        Box::pin(async { None })
    }

    /// Resolves to `true` when the underlying backend is reachable/usable.
    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>>;

    /// Human-readable identifier for this processor (e.g. a model name).
    fn name(&self) -> &str;
}
167
/// Tuning knobs for `LlmBatchProcessor`.
#[derive(Debug, Clone)]
pub struct LlmBatchProcessorConfig {
    /// When `true`, requests are grouped by LoRA and dispatched concurrently;
    /// when `false`, they are processed one at a time in input order.
    pub parallel: bool,
    /// Fallback cap on concurrent in-flight requests, used when the decider
    /// does not report its own limit.
    pub max_concurrency: usize,
    /// Maximum retry attempts; `None` means unlimited. NOTE(review): not
    /// consulted anywhere in this module — confirm where retries happen.
    pub max_retries: Option<usize>,
}
182
183impl Default for LlmBatchProcessorConfig {
184 fn default() -> Self {
185 Self {
186 parallel: true,
187 max_concurrency: 4,
188 max_retries: Some(5),
189 }
190 }
191}
192
/// `BatchProcessor` implementation backed by an `LlmDecider`.
pub struct LlmBatchProcessor<D: LlmDecider> {
    // Shared decider; Arc so per-request futures can each hold a clone.
    decider: Arc<D>,
    // Dispatch configuration (parallelism and concurrency fallback).
    config: LlmBatchProcessorConfig,
}
201
202impl<D: LlmDecider> LlmBatchProcessor<D> {
203 pub fn new(decider: D) -> Self {
205 Self {
206 decider: Arc::new(decider),
207 config: LlmBatchProcessorConfig::default(),
208 }
209 }
210
211 pub fn from_arc(decider: Arc<D>) -> Self {
213 Self {
214 decider,
215 config: LlmBatchProcessorConfig::default(),
216 }
217 }
218
219 pub fn with_config(mut self, config: LlmBatchProcessorConfig) -> Self {
221 self.config = config;
222 self
223 }
224}
225
impl<D: LlmDecider + 'static> BatchProcessor for LlmBatchProcessor<D> {
    /// Process a batch of decision requests.
    ///
    /// Empty batches short-circuit to an empty result. Otherwise requests
    /// are tagged with their worker id and dispatched either concurrently
    /// (grouped by LoRA) or strictly in order, per `config.parallel`.
    fn process(
        &self,
        request: BatchDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = BatchProcessResult> + Send + '_>> {
        Box::pin(async move {
            if request.requests.is_empty() {
                return vec![];
            }

            // Pair each request with its worker id so results stay
            // attributable even when processed out of order.
            let requests: Vec<(WorkerId, WorkerDecisionRequest)> = request
                .requests
                .into_iter()
                .map(|r| (r.worker_id, r))
                .collect();

            if self.config.parallel {
                self.process_parallel(requests).await
            } else {
                self.process_sequential(requests).await
            }
        })
    }

    /// Build a `DependencyGraph` for `task` by ordering the actions with
    /// pairwise LLM comparisons.
    ///
    /// Actions are split into a "discover" (NodeExpand) group and a
    /// "not discover" (NodeStateChange) group; each group is sorted
    /// independently, then chained discover-first into a linear graph.
    /// A `UseLlm` hint supplies the LoRA adapter and per-pair vote count;
    /// otherwise no LoRA and 3 votes per pair are used.
    fn plan_dependencies(
        &self,
        task: &str,
        actions: &[ActionDef],
        hint: Option<&SelectResult>,
    ) -> Pin<Box<dyn Future<Output = Option<DependencyGraph>> + Send + '_>> {
        // Own copies up front: the returned future must be self-contained.
        let task = task.to_string();
        let actions: Vec<ActionDef> = actions.to_vec();
        let decider = Arc::clone(&self.decider);

        let (lora, vote_count) = match hint {
            Some(SelectResult::UseLlm {
                lora,
                vote_count,
                match_rate,
                ..
            }) => {
                tracing::debug!(
                    match_rate = match_rate,
                    vote_count = vote_count,
                    has_lora = lora.is_some(),
                    "Using SelectResult hint for plan_dependencies"
                );
                (lora.clone(), *vote_count)
            }
            _ => {
                tracing::debug!("No SelectResult hint, using defaults (lora=None, vote_count=3)");
                (None, 3)
            }
        };

        Box::pin(async move {
            use std::time::Instant;
            use swarm_engine_core::actions::ActionCategory;
            use swarm_engine_core::exploration::DependencyGraphBuilder;

            let start_time = Instant::now();
            let action_names: Vec<String> = actions.iter().map(|a| a.name.clone()).collect();

            // Partition by category. NOTE(review): actions whose category is
            // neither NodeExpand nor NodeStateChange are dropped from the
            // graph entirely — confirm that is intended.
            let discover: Vec<&ActionDef> = actions
                .iter()
                .filter(|a| a.category == ActionCategory::NodeExpand)
                .collect();
            let not_discover: Vec<&ActionDef> = actions
                .iter()
                .filter(|a| a.category == ActionCategory::NodeStateChange)
                .collect();

            tracing::debug!(
                discover = ?discover.iter().map(|a| &a.name).collect::<Vec<_>>(),
                not_discover = ?not_discover.iter().map(|a| &a.name).collect::<Vec<_>>(),
                "Separated actions by category"
            );

            // Groups of 0 or 1 need no LLM calls; larger groups are ordered
            // by majority-vote pairwise comparison.
            let discover_sort_start = Instant::now();
            let sorted_discover = if discover.len() <= 1 {
                discover.iter().map(|a| a.name.clone()).collect()
            } else {
                binary_sort_actions(&task, &discover, decider.as_ref(), lora.as_ref(), vote_count).await
            };
            let discover_sort_ms = discover_sort_start.elapsed().as_millis();

            tracing::debug!(
                sorted = ?sorted_discover,
                elapsed_ms = discover_sort_ms,
                vote_count = vote_count,
                has_lora = lora.is_some(),
                "Sorted Discover actions via binary comparison"
            );

            let not_discover_sort_start = Instant::now();
            let sorted_not_discover = if not_discover.len() <= 1 {
                not_discover.iter().map(|a| a.name.clone()).collect()
            } else {
                binary_sort_actions(&task, &not_discover, decider.as_ref(), lora.as_ref(), vote_count).await
            };
            let not_discover_sort_ms = not_discover_sort_start.elapsed().as_millis();

            tracing::debug!(
                sorted = ?sorted_not_discover,
                elapsed_ms = not_discover_sort_ms,
                "Sorted NotDiscover actions via binary comparison"
            );

            let mut builder = DependencyGraphBuilder::new()
                .task(&task)
                .available_actions(action_names.clone());

            // Start node: first discover action, falling back to the first
            // state-change action when there are no discover actions.
            if !sorted_discover.is_empty() {
                builder = builder.start_node(&sorted_discover[0]);
            } else if !sorted_not_discover.is_empty() {
                builder = builder.start_node(&sorted_not_discover[0]);
            }

            // Terminal node: last state-change action, falling back to the
            // last discover action.
            if let Some(last) = sorted_not_discover.last() {
                builder = builder.terminal_node(last);
            } else if !sorted_discover.is_empty() {
                builder = builder.terminal_node(sorted_discover.last().unwrap());
            }

            // Chain consecutive discover actions with fixed-weight edges.
            for window in sorted_discover.windows(2) {
                builder = builder.edge(&window[0], &window[1], 0.9);
            }

            // Bridge from the end of the discover chain into the
            // state-change chain.
            if !sorted_discover.is_empty() && !sorted_not_discover.is_empty() {
                builder = builder.edge(
                    sorted_discover.last().unwrap(),
                    &sorted_not_discover[0],
                    0.9,
                );
            }

            // Chain consecutive state-change actions.
            for window in sorted_not_discover.windows(2) {
                builder = builder.edge(&window[0], &window[1], 0.9);
            }

            let mut graph = builder.build();
            let total_ms = start_time.elapsed().as_millis();

            graph.set_action_order(sorted_discover.clone(), sorted_not_discover.clone());

            // Emit a learning event describing this inference and attach the
            // derived record to the graph for later training use.
            {
                use swarm_engine_core::events::{LearningEvent, LearningEventChannel};
                use swarm_engine_core::learn::DependencyGraphRecord;

                let prompt = format!(
                    "Task: {}\n\nAvailable Actions:\n{}",
                    task,
                    action_names
                        .iter()
                        .map(|n| format!("- {}", n))
                        .collect::<Vec<_>>()
                        .join("\n")
                );

                let response = format!(
                    "discover_order: {:?}\nnot_discover_order: {:?}",
                    sorted_discover, sorted_not_discover
                );

                let event = LearningEvent::dependency_graph_inference(decider.model_name())
                    .prompt(&prompt)
                    .response(&response)
                    .available_actions(action_names)
                    .discover_order(sorted_discover.clone())
                    .not_discover_order(sorted_not_discover.clone())
                    .endpoint(decider.endpoint())
                    .latency_ms(total_ms as u64)
                    .success()
                    .build();

                LearningEventChannel::global().emit(event.clone());

                let record = DependencyGraphRecord::from(&event);
                graph.set_learn_record(record);
            }

            tracing::info!(
                discover_order = ?sorted_discover,
                not_discover_order = ?sorted_not_discover,
                edges = graph.edges().len(),
                discover_sort_ms = discover_sort_ms,
                not_discover_sort_ms = not_discover_sort_ms,
                total_ms = total_ms,
                "DependencyGraph generated via LLM binary sort"
            );

            Some(graph)
        })
    }

    /// Delegate the health probe to the underlying decider.
    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
        let decider = Arc::clone(&self.decider);
        Box::pin(async move { decider.is_healthy().await })
    }

    /// The processor is named after the decider's model.
    fn name(&self) -> &str {
        self.decider.model_name()
    }
}
451
452impl<D: LlmDecider + 'static> LlmBatchProcessor<D> {
453 async fn process_parallel(
470 &self,
471 requests: Vec<(WorkerId, WorkerDecisionRequest)>,
472 ) -> BatchProcessResult {
473 let grouped = group_by_lora(requests);
475
476 let group_count = grouped.len();
477 if group_count > 1 {
478 tracing::debug!(
479 groups = group_count,
480 "Processing requests in {} LoRA groups",
481 group_count
482 );
483 }
484
485 let mut all_results = Vec::new();
487 for (lora_config, group_requests) in grouped {
488 if group_count > 1 {
489 tracing::trace!(
490 lora = ?lora_config,
491 count = group_requests.len(),
492 "Processing LoRA group"
493 );
494 }
495 let results = self.process_group(group_requests).await;
496 all_results.extend(results);
497 }
498
499 all_results
500 }
501
502 async fn process_group(
504 &self,
505 requests: Vec<(WorkerId, WorkerDecisionRequest)>,
506 ) -> BatchProcessResult {
507 use futures::future::join_all;
508 use tokio::sync::Semaphore;
509
510 let max_concurrency = self
512 .decider
513 .max_concurrency()
514 .await
515 .unwrap_or(self.config.max_concurrency);
516
517 let semaphore = Arc::new(Semaphore::new(max_concurrency));
518
519 let futures: Vec<_> = requests
520 .into_iter()
521 .map(|(worker_id, req)| {
522 let decider = Arc::clone(&self.decider);
523 let sem = Arc::clone(&semaphore);
524 async move {
525 let _permit = sem.acquire().await.expect("Semaphore closed");
527 let result = decider.decide(req).await;
528 (worker_id, result)
529 }
530 })
531 .collect();
532
533 let results = join_all(futures).await;
534
535 results
536 .into_iter()
537 .map(|(worker_id, result)| {
538 let mapped = result.map_err(BatchProcessError::from);
539 (worker_id, mapped)
540 })
541 .collect()
542 }
543
544 async fn process_sequential(
546 &self,
547 requests: Vec<(WorkerId, WorkerDecisionRequest)>,
548 ) -> BatchProcessResult {
549 let mut results = Vec::with_capacity(requests.len());
550
551 for (worker_id, req) in requests {
552 let result = self.decider.decide(req).await;
553 let mapped = result.map_err(BatchProcessError::from);
554 results.push((worker_id, mapped));
555 }
556
557 results
558 }
559}
560
561fn group_by_lora(
566 requests: Vec<(WorkerId, WorkerDecisionRequest)>,
567) -> HashMap<Option<LoraConfig>, Vec<(WorkerId, WorkerDecisionRequest)>> {
568 let mut groups: HashMap<Option<LoraConfig>, Vec<(WorkerId, WorkerDecisionRequest)>> =
569 HashMap::new();
570
571 for (worker_id, req) in requests {
572 let lora_key = req.lora.clone();
573 groups.entry(lora_key).or_default().push((worker_id, req));
574 }
575
576 groups
577}
578
579async fn binary_sort_actions<D: LlmDecider>(
596 task: &str,
597 actions: &[&ActionDef],
598 decider: &D,
599 lora: Option<&LoraConfig>,
600 vote_count: u8,
601) -> Vec<String> {
602 use futures::future::join_all;
603 use std::collections::HashMap;
604
605 if actions.len() <= 1 {
606 return actions.iter().map(|a| a.name.clone()).collect();
607 }
608
609 let mut requests: Vec<(usize, usize, String, String, String)> = Vec::new();
612 let mut pair_index = 0;
613
614 for i in 0..actions.len() {
615 for j in (i + 1)..actions.len() {
616 let a = actions[i];
617 let b = actions[j];
618 let prompt = format!(
619 "Goal: {}\n- {}: {}\n- {}: {}\nWhich comes first: {} or {}?\nAnswer (one word):",
620 task, a.name, a.description, b.name, b.description, a.name, b.name
621 );
622
623 for vote_idx in 0..vote_count as usize {
625 requests.push((
626 pair_index,
627 vote_idx,
628 prompt.clone(),
629 a.name.clone(),
630 b.name.clone(),
631 ));
632 }
633 pair_index += 1;
634 }
635 }
636
637 let total_requests = requests.len();
638 tracing::debug!(
639 pairs = pair_index,
640 total_requests = total_requests,
641 "Binary sort: sending batch requests"
642 );
643
644 let futures: Vec<_> = requests
647 .into_iter()
648 .map(|(pair_idx, vote_idx, prompt, a_name, b_name)| {
649 let decider_ref = decider;
650 async move {
651 let result = decider_ref.call_raw(&prompt, lora).await;
652 (pair_idx, vote_idx, result, a_name, b_name)
653 }
654 })
655 .collect();
656
657 let results = join_all(futures).await;
658
659 let mut pair_votes: HashMap<usize, (usize, usize, String, String)> = HashMap::new();
662
663 for (pair_idx, _vote_idx, result, a_name, b_name) in results {
664 let entry = pair_votes
665 .entry(pair_idx)
666 .or_insert((0, 0, a_name.clone(), b_name.clone()));
667
668 if let Ok(response) = result {
669 let response_upper = response.to_uppercase();
670 let a_upper = a_name.to_uppercase();
671 let b_upper = b_name.to_uppercase();
672
673 if response_upper.contains(&a_upper) {
674 entry.0 += 1;
675 } else if response_upper.contains(&b_upper) {
676 entry.1 += 1;
677 }
678 }
679 }
680
681 let mut wins: HashMap<String, usize> = HashMap::new();
683 for a in actions {
684 wins.insert(a.name.clone(), 0);
685 }
686
687 for (_pair_idx, (a_count, b_count, a_name, b_name)) in pair_votes {
688 if a_count >= b_count {
690 *wins.get_mut(&b_name).unwrap() += 1;
692 } else {
693 *wins.get_mut(&a_name).unwrap() += 1;
695 }
696 }
697
698 let mut sorted: Vec<_> = wins.into_iter().collect();
700 sorted.sort_by_key(|(_, count)| *count);
701
702 tracing::debug!(
703 sorted = ?sorted.iter().map(|(n, c)| format!("{}:{}", n, c)).collect::<Vec<_>>(),
704 "Binary sort completed"
705 );
706
707 sorted.into_iter().map(|(name, _)| name).collect()
708}
709
// Unit tests: error classification, the binary-sort voting scheme (via a
// synchronous mirror of the async logic), response parsing, and LoRA
// grouping.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_process_error_transient() {
        let err = BatchProcessError::transient("connection timeout");
        assert!(err.is_transient());
        assert_eq!(err.message(), "connection timeout");
    }

    #[test]
    fn test_batch_process_error_permanent() {
        let err = BatchProcessError::permanent("invalid model");
        assert!(!err.is_transient());
        assert_eq!(err.message(), "invalid model");
    }

    #[test]
    fn test_batch_process_error_from_llm_error() {
        // Conversion must preserve both classification and message.
        let llm_err = LlmError::transient("timeout");
        let batch_err: BatchProcessError = llm_err.into();
        assert!(batch_err.is_transient());
        assert_eq!(batch_err.message(), "timeout");
    }

    #[test]
    fn test_ollama_batch_processor_config_default() {
        let config = LlmBatchProcessorConfig::default();
        assert!(config.parallel);
        assert_eq!(config.max_concurrency, 4);
    }

    use std::collections::HashMap;

    // Synchronous mirror of `binary_sort_actions`'s ordering scheme, so the
    // logic can be exercised without an async LLM decider. Each pair's loser
    // accumulates a count; the final order is ascending by that count.
    fn binary_sort_sync(
        actions: &[&str],
        comparator: impl Fn(&str, &str) -> String,
    ) -> Vec<String> {
        if actions.len() <= 1 {
            return actions.iter().map(|s| s.to_string()).collect();
        }

        let mut wins: HashMap<String, usize> = HashMap::new();
        for &a in actions {
            wins.insert(a.to_string(), 0);
        }

        for i in 0..actions.len() {
            for j in (i + 1)..actions.len() {
                let a = actions[i];
                let b = actions[j];
                let winner = comparator(a, b);

                // The loser of each pair gets +1; fewest marks sorts first.
                if winner == a {
                    *wins.get_mut(b).unwrap() += 1;
                } else {
                    *wins.get_mut(a).unwrap() += 1;
                }
            }
        }

        let mut sorted: Vec<_> = wins.into_iter().collect();
        sorted.sort_by_key(|(_, count)| *count);
        sorted.into_iter().map(|(name, _)| name).collect()
    }

    #[test]
    fn test_binary_sort_two_actions() {
        let result = binary_sort_sync(
            &["Fetch", "Summarize"],
            |a, _b| a.to_string(), // always prefer the first argument
        );
        assert_eq!(result, vec!["Fetch", "Summarize"]);

        let result = binary_sort_sync(
            &["Fetch", "Summarize"],
            |_a, b| b.to_string(), // always prefer the second argument
        );
        assert_eq!(result, vec!["Summarize", "Fetch"]);
    }

    #[test]
    fn test_binary_sort_three_actions() {
        // Comparator encodes a fixed Build -> Test -> Deploy ordering.
        let result = binary_sort_sync(&["Test", "Deploy", "Build"], |a, b| {
            let order = ["Build", "Test", "Deploy"];
            let a_idx = order.iter().position(|&x| x == a).unwrap();
            let b_idx = order.iter().position(|&x| x == b).unwrap();
            if a_idx < b_idx {
                a.to_string()
            } else {
                b.to_string()
            }
        });
        assert_eq!(result, vec!["Build", "Test", "Deploy"]);
    }

    #[test]
    fn test_binary_sort_wins_calculation() {
        // Replay the loss-tally bookkeeping by hand for pairs
        // (A,B), (A,C), (B,C) where the earlier action always wins.
        let mut wins: HashMap<String, usize> = HashMap::new();
        wins.insert("A".to_string(), 0);
        wins.insert("B".to_string(), 0);
        wins.insert("C".to_string(), 0);

        *wins.get_mut("B").unwrap() += 1;
        *wins.get_mut("C").unwrap() += 1;
        *wins.get_mut("C").unwrap() += 1;

        assert_eq!(wins["A"], 0);
        assert_eq!(wins["B"], 1);
        assert_eq!(wins["C"], 2);

        let mut sorted: Vec<_> = wins.into_iter().collect();
        sorted.sort_by_key(|(_, count)| *count);
        let result: Vec<_> = sorted.into_iter().map(|(name, _)| name).collect();

        assert_eq!(result, vec!["A", "B", "C"]);
    }

    // Mirror of the response parsing in `binary_sort_actions`:
    // case-insensitive substring match, with `a` taking precedence.
    fn extract_winner(response: &str, a: &str, b: &str) -> Option<String> {
        let response_upper = response.to_uppercase();
        let a_upper = a.to_uppercase();
        let b_upper = b.to_uppercase();

        if response_upper.contains(&a_upper) {
            Some(a.to_string())
        } else if response_upper.contains(&b_upper) {
            Some(b.to_string())
        } else {
            None
        }
    }

    #[test]
    fn test_extract_winner() {
        // Exact matches.
        assert_eq!(
            extract_winner("Fetch", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );
        assert_eq!(
            extract_winner("Summarize", "Fetch", "Summarize"),
            Some("Summarize".to_string())
        );

        // Leading whitespace is tolerated via substring matching.
        assert_eq!(
            extract_winner(" Fetch", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );

        // Matching is case-insensitive.
        assert_eq!(
            extract_winner("fetch", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );
        assert_eq!(
            extract_winner("FETCH", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );

        // The name may be embedded in a longer sentence.
        assert_eq!(
            extract_winner("The answer is Fetch.", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );

        // Responses mentioning neither candidate yield no vote.
        assert_eq!(extract_winner("Unknown", "Fetch", "Summarize"), None);

        // When both names appear, `a` wins by precedence.
        assert_eq!(
            extract_winner("Fetch then Summarize", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );
    }

    #[test]
    fn test_vote_majority() {
        // Majority voting over several responses; `a` wins exact ties.
        fn vote_majority(responses: &[&str], a: &str, b: &str) -> String {
            let mut a_count = 0;
            let mut b_count = 0;

            for response in responses {
                if let Some(winner) = extract_winner(response, a, b) {
                    if winner == a {
                        a_count += 1;
                    } else {
                        b_count += 1;
                    }
                }
            }

            if a_count >= b_count {
                a.to_string()
            } else {
                b.to_string()
            }
        }

        // Unanimous.
        assert_eq!(
            vote_majority(&["Fetch", "Fetch", "Fetch"], "Fetch", "Summarize"),
            "Fetch"
        );

        // Simple majority.
        assert_eq!(
            vote_majority(&["Fetch", "Summarize", "Fetch"], "Fetch", "Summarize"),
            "Fetch"
        );

        assert_eq!(
            vote_majority(&["Summarize", "Summarize", "Fetch"], "Fetch", "Summarize"),
            "Summarize"
        );

        // Unparseable responses don't count; the 1-1 tie goes to `a`.
        assert_eq!(
            vote_majority(&["Fetch", "Summarize", "Unknown"], "Fetch", "Summarize"),
            "Fetch"
        );
    }

    use swarm_engine_core::context::{ContextTarget, GlobalContext, ResolvedContext};

    // Build a minimal (worker, request) pair with the given LoRA config.
    fn create_test_request(
        worker_id: usize,
        lora: Option<LoraConfig>,
    ) -> (WorkerId, WorkerDecisionRequest) {
        let global = GlobalContext {
            tick: 0,
            max_ticks: 100,
            progress: 0.0,
            success_rate: 0.0,
            task_description: Some("test".to_string()),
            hint: None,
        };
        let context = ResolvedContext::new(global, ContextTarget::Worker(WorkerId(worker_id)));

        (
            WorkerId(worker_id),
            WorkerDecisionRequest {
                worker_id: WorkerId(worker_id),
                query: format!("query_{}", worker_id),
                context,
                lora,
            },
        )
    }

    #[test]
    fn test_group_by_lora_single_group_no_lora() {
        let requests = vec![
            create_test_request(0, None),
            create_test_request(1, None),
            create_test_request(2, None),
        ];

        let groups = group_by_lora(requests);

        assert_eq!(groups.len(), 1);
        assert!(groups.contains_key(&None));
        assert_eq!(groups[&None].len(), 3);
    }

    #[test]
    fn test_group_by_lora_single_group_with_lora() {
        let lora = LoraConfig::with_id(0);
        let requests = vec![
            create_test_request(0, Some(lora.clone())),
            create_test_request(1, Some(lora.clone())),
        ];

        let groups = group_by_lora(requests);

        assert_eq!(groups.len(), 1);
        assert!(groups.contains_key(&Some(lora)));
    }

    #[test]
    fn test_group_by_lora_multiple_groups() {
        let lora_a = LoraConfig::with_id(0);
        let lora_b = LoraConfig::with_id(1);

        let requests = vec![
            create_test_request(0, Some(lora_a.clone())),
            create_test_request(1, Some(lora_b.clone())),
            create_test_request(2, Some(lora_a.clone())),
            create_test_request(3, None),
            create_test_request(4, Some(lora_b.clone())),
        ];

        let groups = group_by_lora(requests);

        assert_eq!(groups.len(), 3);
        assert_eq!(groups[&Some(lora_a)].len(), 2);
        assert_eq!(groups[&Some(lora_b)].len(), 2);
        assert_eq!(groups[&None].len(), 1);
    }

    #[test]
    fn test_group_by_lora_preserves_order_within_group() {
        let lora = LoraConfig::with_id(0);
        let requests = vec![
            create_test_request(5, Some(lora.clone())),
            create_test_request(3, Some(lora.clone())),
            create_test_request(7, Some(lora.clone())),
        ];

        let groups = group_by_lora(requests);
        let group = &groups[&Some(lora)];

        // Input order must be preserved within the bucket.
        assert_eq!(group[0].0, WorkerId(5));
        assert_eq!(group[1].0, WorkerId(3));
        assert_eq!(group[2].0, WorkerId(7));
    }

    #[test]
    fn test_group_by_lora_different_scales() {
        // Same adapter id but different scales must land in separate groups.
        let lora_full = LoraConfig::new(0, 1.0);
        let lora_half = LoraConfig::new(0, 0.5);

        let requests = vec![
            create_test_request(0, Some(lora_full.clone())),
            create_test_request(1, Some(lora_half.clone())),
            create_test_request(2, Some(lora_full.clone())),
        ];

        let groups = group_by_lora(requests);

        assert_eq!(groups.len(), 2);
        assert_eq!(groups[&Some(lora_full)].len(), 2);
        assert_eq!(groups[&Some(lora_half)].len(), 1);
    }

    #[test]
    fn test_group_by_lora_empty() {
        let requests: Vec<(WorkerId, WorkerDecisionRequest)> = vec![];
        let groups = group_by_lora(requests);
        assert!(groups.is_empty());
    }
}