// swarm_engine_llm/batch_processor.rs
//! Batch Processor - batch LLM processing for the ManagerAgent
//!
//! Abstraction layer for handling the ManagerAgent's `BatchDecisionRequest`.
//!
//! # Design
//!
//! ```text
//! Core Layer
//! ├── ManagerAgent trait (prepare / finalize)
//! ├── DefaultBatchManagerAgent   ← default Core-layer implementation
//! ├── ContextStore / ContextView ← normalized context
//! └── ContextResolver            ← scope resolution
//!
//! LLM Layer
//! ├── PromptBuilder              ← ResolvedContext → prompt
//! ├── BatchProcessor trait       ← batch-processing abstraction
//! │   └── LlmBatchProcessor   ← Ollama implementation (virtual batch)
//! │
//! └── BatchInvoker impls         ← LLM batch invocation
//!     └── OllamaBatchInvoker
//! ```
//!
//! # Virtual batch vs. true batch
//!
//! Ollama has no true batch API, so `LlmBatchProcessor` is implemented as a
//! "virtual batch" that internally runs requests in parallel or sequentially.
//! Switching to a backend with a real batch API (e.g. vLLM) later only
//! requires providing another implementation of the `BatchProcessor` trait.
//!
//! # Unified types
//!
//! The LLM layer uses the Core layer's types directly, so no conversion
//! logic is needed:
//! - `WorkerDecisionRequest` - request
//! - `DecisionResponse` - response
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39
40use std::collections::HashMap;
41
42use swarm_engine_core::actions::ActionDef;
43use swarm_engine_core::agent::{BatchDecisionRequest, DecisionResponse, WorkerDecisionRequest};
44use swarm_engine_core::exploration::DependencyGraph;
45use swarm_engine_core::types::{LoraConfig, WorkerId};
46
47use crate::decider::{LlmDecider, LlmError};
48
49// ============================================================================
50// BatchProcessor Trait
51// ============================================================================
52
/// Batch processing result: one `(WorkerId, Result)` pair per worker request.
pub type BatchProcessResult = Vec<(WorkerId, Result<DecisionResponse, BatchProcessError>)>;
55
/// Error produced for a single worker request within a batch.
#[derive(Debug, Clone, thiserror::Error)]
pub enum BatchProcessError {
    /// Transient error (retry may succeed, e.g. timeout / connection reset).
    #[error("Batch process error (transient): {0}")]
    Transient(String),

    /// Permanent error (retrying will not help, e.g. invalid model).
    #[error("Batch process error: {0}")]
    Permanent(String),
}
67
68impl BatchProcessError {
69    pub fn transient(message: impl Into<String>) -> Self {
70        Self::Transient(message.into())
71    }
72
73    pub fn permanent(message: impl Into<String>) -> Self {
74        Self::Permanent(message.into())
75    }
76
77    pub fn is_transient(&self) -> bool {
78        matches!(self, Self::Transient(_))
79    }
80
81    pub fn message(&self) -> &str {
82        match self {
83            Self::Transient(msg) => msg,
84            Self::Permanent(msg) => msg,
85        }
86    }
87}
88
89impl From<LlmError> for BatchProcessError {
90    fn from(e: LlmError) -> Self {
91        if e.is_transient() {
92            Self::Transient(e.message().to_string())
93        } else {
94            Self::Permanent(e.message().to_string())
95        }
96    }
97}
98
99impl From<swarm_engine_core::error::SwarmError> for BatchProcessError {
100    fn from(err: swarm_engine_core::error::SwarmError) -> Self {
101        if err.is_transient() {
102            Self::Transient(err.message())
103        } else {
104            Self::Permanent(err.message())
105        }
106    }
107}
108
109impl From<BatchProcessError> for swarm_engine_core::error::SwarmError {
110    fn from(err: BatchProcessError) -> Self {
111        match err {
112            BatchProcessError::Transient(message) => {
113                swarm_engine_core::error::SwarmError::LlmTransient { message }
114            }
115            BatchProcessError::Permanent(message) => {
116                swarm_engine_core::error::SwarmError::LlmPermanent { message }
117            }
118        }
119    }
120}
121
/// Abstraction over batch decision processing.
///
/// Receives a `BatchDecisionRequest` and returns one decision per worker.
/// Implementations are backend-specific (Ollama, vLLM, ...).
pub trait BatchProcessor: Send + Sync {
    /// Process a batch request.
    ///
    /// # Arguments
    /// * `request` - Core `BatchDecisionRequest`
    ///
    /// # Returns
    /// Decision results for each worker, paired with their `WorkerId`.
    fn process(
        &self,
        request: BatchDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = BatchProcessResult> + Send + '_>>;

    /// Generate an action dependency graph from a task and the action list.
    ///
    /// Called before the swarm's ticks begin, to plan dependencies between
    /// actions. Implementations may use an LLM to generate the graph
    /// dynamically.
    ///
    /// # Default
    /// Returns `None` (dependency-graph generation is skipped).
    fn plan_dependencies(
        &self,
        _task: &str,
        _actions: &[ActionDef],
    ) -> Pin<Box<dyn Future<Output = Option<DependencyGraph>> + Send + '_>> {
        Box::pin(async { None })
    }

    /// Backend health check.
    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>>;

    /// Human-readable processor name.
    fn name(&self) -> &str;
}
160
161// ============================================================================
162// LlmBatchProcessor
163// ============================================================================
164
/// Configuration for `LlmBatchProcessor`.
#[derive(Debug, Clone)]
pub struct LlmBatchProcessorConfig {
    /// Run requests in parallel (`false` = strictly sequential).
    pub parallel: bool,
    /// Upper bound on concurrent requests when running in parallel.
    pub max_concurrency: usize,
    /// Maximum retry count when generating the DependencyGraph.
    // NOTE(review): not read anywhere in this file — presumably consumed
    // elsewhere in the crate; confirm before removing.
    pub max_retries: Option<usize>,
}
175
176impl Default for LlmBatchProcessorConfig {
177    fn default() -> Self {
178        Self {
179            parallel: true,
180            max_concurrency: 4,
181            max_retries: Some(5),
182        }
183    }
184}
185
/// Ollama batch processor ("virtual batch").
///
/// Ollama has no true batch API, so batching is emulated: requests are
/// fanned out to an inner `LlmDecider`, in parallel or sequentially.
pub struct LlmBatchProcessor<D: LlmDecider> {
    // Shared decider; cloned (Arc) into the spawned per-request futures.
    decider: Arc<D>,
    // Behavior knobs (parallelism, concurrency cap, retries).
    config: LlmBatchProcessorConfig,
}
194
195impl<D: LlmDecider> LlmBatchProcessor<D> {
196    /// 新しい LlmBatchProcessor を作成
197    pub fn new(decider: D) -> Self {
198        Self {
199            decider: Arc::new(decider),
200            config: LlmBatchProcessorConfig::default(),
201        }
202    }
203
204    /// Arc でラップされた Decider から作成
205    pub fn from_arc(decider: Arc<D>) -> Self {
206        Self {
207            decider,
208            config: LlmBatchProcessorConfig::default(),
209        }
210    }
211
212    /// 設定を指定して作成
213    pub fn with_config(mut self, config: LlmBatchProcessorConfig) -> Self {
214        self.config = config;
215        self
216    }
217}
218
impl<D: LlmDecider + 'static> BatchProcessor for LlmBatchProcessor<D> {
    /// Dispatch the batch either to the parallel or the sequential path,
    /// depending on `config.parallel`. Empty batches short-circuit to `[]`.
    fn process(
        &self,
        request: BatchDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = BatchProcessResult> + Send + '_>> {
        Box::pin(async move {
            if request.requests.is_empty() {
                return vec![];
            }

            // Use Core's WorkerDecisionRequest as-is (no conversion needed);
            // just pair each request with its worker id.
            let requests: Vec<(WorkerId, WorkerDecisionRequest)> = request
                .requests
                .into_iter()
                .map(|r| (r.worker_id, r))
                .collect();

            if self.config.parallel {
                self.process_parallel(requests).await
            } else {
                self.process_sequential(requests).await
            }
        })
    }

    /// Build a linear action dependency graph by LLM-driven pairwise sorting:
    /// actions are split by category, each partition is ordered via binary
    /// comparisons with majority voting, then chained Discover → NotDiscover.
    fn plan_dependencies(
        &self,
        task: &str,
        actions: &[ActionDef],
    ) -> Pin<Box<dyn Future<Output = Option<DependencyGraph>> + Send + '_>> {
        // Own the inputs so the returned future does not borrow from `task`/`actions`.
        let task = task.to_string();
        let actions: Vec<ActionDef> = actions.to_vec();
        let decider = Arc::clone(&self.decider);

        Box::pin(async move {
            use std::time::Instant;
            use swarm_engine_core::actions::ActionCategory;
            use swarm_engine_core::exploration::DependencyGraphBuilder;

            let start_time = Instant::now();
            let action_names: Vec<String> = actions.iter().map(|a| a.name.clone()).collect();

            // 1. Split into Discover (NodeExpand) and NotDiscover (NodeStateChange).
            let discover: Vec<&ActionDef> = actions
                .iter()
                .filter(|a| a.category == ActionCategory::NodeExpand)
                .collect();
            let not_discover: Vec<&ActionDef> = actions
                .iter()
                .filter(|a| a.category == ActionCategory::NodeStateChange)
                .collect();

            tracing::debug!(
                discover = ?discover.iter().map(|a| &a.name).collect::<Vec<_>>(),
                not_discover = ?not_discover.iter().map(|a| &a.name).collect::<Vec<_>>(),
                "Separated actions by category"
            );

            // 2. Sort Discover via binary comparison + vote (preserves ordering
            //    relations); a 0/1-element partition needs no LLM calls.
            let discover_sort_start = Instant::now();
            let sorted_discover = if discover.len() <= 1 {
                discover.iter().map(|a| a.name.clone()).collect()
            } else {
                binary_sort_actions(&task, &discover, decider.as_ref()).await
            };
            let discover_sort_ms = discover_sort_start.elapsed().as_millis();

            tracing::debug!(
                sorted = ?sorted_discover,
                elapsed_ms = discover_sort_ms,
                "Sorted Discover actions via binary comparison"
            );

            // 3. Sort NotDiscover via binary comparison + vote.
            let not_discover_sort_start = Instant::now();
            let sorted_not_discover = if not_discover.len() <= 1 {
                not_discover.iter().map(|a| a.name.clone()).collect()
            } else {
                binary_sort_actions(&task, &not_discover, decider.as_ref()).await
            };
            let not_discover_sort_ms = not_discover_sort_start.elapsed().as_millis();

            tracing::debug!(
                sorted = ?sorted_not_discover,
                elapsed_ms = not_discover_sort_ms,
                "Sorted NotDiscover actions via binary comparison"
            );

            // 4. Build the graph: Discover (linear chain) → NotDiscover (linear chain).
            let mut builder = DependencyGraphBuilder::new()
                .task(&task)
                .available_actions(action_names.clone());

            // First Discover action becomes the start node.
            if !sorted_discover.is_empty() {
                builder = builder.start_node(&sorted_discover[0]);
            } else if !sorted_not_discover.is_empty() {
                // No Discover actions: the first NotDiscover starts the graph.
                builder = builder.start_node(&sorted_not_discover[0]);
            }

            // Last NotDiscover action becomes the terminal node.
            if let Some(last) = sorted_not_discover.last() {
                builder = builder.terminal_node(last);
            } else if !sorted_discover.is_empty() {
                // No NotDiscover actions: the last Discover terminates the graph.
                builder = builder.terminal_node(sorted_discover.last().unwrap());
            }

            // Linear edges between consecutive Discover actions.
            for window in sorted_discover.windows(2) {
                builder = builder.edge(&window[0], &window[1], 0.9);
            }

            // Bridge edge: last Discover → first NotDiscover.
            if !sorted_discover.is_empty() && !sorted_not_discover.is_empty() {
                builder = builder.edge(
                    sorted_discover.last().unwrap(),
                    &sorted_not_discover[0],
                    0.9,
                );
            }

            // Linear edges between consecutive NotDiscover actions.
            for window in sorted_not_discover.windows(2) {
                builder = builder.edge(&window[0], &window[1], 0.9);
            }

            let mut graph = builder.build();
            let total_ms = start_time.elapsed().as_millis();

            // Store action order for caching
            graph.set_action_order(sorted_discover.clone(), sorted_not_discover.clone());

            // Create learning record for DependencyGraph inference
            {
                use swarm_engine_core::events::{LearningEvent, LearningEventChannel};
                use swarm_engine_core::learn::DependencyGraphRecord;

                // Build a summary prompt representing the inference input
                let prompt = format!(
                    "Task: {}\n\nAvailable Actions:\n{}",
                    task,
                    action_names
                        .iter()
                        .map(|n| format!("- {}", n))
                        .collect::<Vec<_>>()
                        .join("\n")
                );

                // Build a summary response representing the inference output
                let response = format!(
                    "discover_order: {:?}\nnot_discover_order: {:?}",
                    sorted_discover, sorted_not_discover
                );

                // Create Event and emit to LearningEventChannel
                let event = LearningEvent::dependency_graph_inference(decider.model_name())
                    .prompt(&prompt)
                    .response(&response)
                    .available_actions(action_names)
                    .discover_order(sorted_discover.clone())
                    .not_discover_order(sorted_not_discover.clone())
                    .endpoint(decider.endpoint())
                    .latency_ms(total_ms as u64)
                    .success()
                    .build();

                // Emit event for learning pipeline
                LearningEventChannel::global().emit(event.clone());

                // Convert Event to Record for graph storage
                let record = DependencyGraphRecord::from(&event);
                graph.set_learn_record(record);
            }

            tracing::info!(
                discover_order = ?sorted_discover,
                not_discover_order = ?sorted_not_discover,
                edges = graph.edges().len(),
                discover_sort_ms = discover_sort_ms,
                not_discover_sort_ms = not_discover_sort_ms,
                total_ms = total_ms,
                "DependencyGraph generated via LLM binary sort"
            );

            Some(graph)
        })
    }

    /// Delegate the health check to the inner decider.
    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
        let decider = Arc::clone(&self.decider);
        Box::pin(async move { decider.is_healthy().await })
    }

    /// Report the inner decider's model name as the processor name.
    fn name(&self) -> &str {
        self.decider.model_name()
    }
}
418
419impl<D: LlmDecider + 'static> LlmBatchProcessor<D> {
420    /// 並列実行(LoRA グルーピング + Semaphore で同時実行数を制限)
421    ///
422    /// # LoRA グルーピング
423    ///
424    /// llama.cpp の continuous batching では、同じ LoRA 設定のリクエストは
425    /// 効率的にバッチ処理される。異なる LoRA を混ぜると効率が落ちるため、
426    /// リクエストを LoRA 設定でグルーピングして処理する。
427    ///
428    /// ```text
429    /// リクエスト群
430    /// ├── LoRA A のリクエスト群 → 並列実行(グループ内)
431    /// ├── LoRA B のリクエスト群 → 並列実行(グループ内)
432    /// └── LoRA なしのリクエスト群 → 並列実行(グループ内)
433    /// ```
434    ///
435    /// グループ間は順次処理(同じ LoRA を連続して処理することで効率化)
436    async fn process_parallel(
437        &self,
438        requests: Vec<(WorkerId, WorkerDecisionRequest)>,
439    ) -> BatchProcessResult {
440        // リクエストを LoRA 設定でグルーピング
441        let grouped = group_by_lora(requests);
442
443        let group_count = grouped.len();
444        if group_count > 1 {
445            tracing::debug!(
446                groups = group_count,
447                "Processing requests in {} LoRA groups",
448                group_count
449            );
450        }
451
452        // 各グループを順次処理(グループ内は並列)
453        let mut all_results = Vec::new();
454        for (lora_config, group_requests) in grouped {
455            if group_count > 1 {
456                tracing::trace!(
457                    lora = ?lora_config,
458                    count = group_requests.len(),
459                    "Processing LoRA group"
460                );
461            }
462            let results = self.process_group(group_requests).await;
463            all_results.extend(results);
464        }
465
466        all_results
467    }
468
469    /// 単一グループの並列処理(Semaphore で同時実行数を制限)
470    async fn process_group(
471        &self,
472        requests: Vec<(WorkerId, WorkerDecisionRequest)>,
473    ) -> BatchProcessResult {
474        use futures::future::join_all;
475        use tokio::sync::Semaphore;
476
477        // サーバーからスロット数を取得、取得できなければconfig値を使用
478        let max_concurrency = self
479            .decider
480            .max_concurrency()
481            .await
482            .unwrap_or(self.config.max_concurrency);
483
484        let semaphore = Arc::new(Semaphore::new(max_concurrency));
485
486        let futures: Vec<_> = requests
487            .into_iter()
488            .map(|(worker_id, req)| {
489                let decider = Arc::clone(&self.decider);
490                let sem = Arc::clone(&semaphore);
491                async move {
492                    // スロットを取得してから実行
493                    let _permit = sem.acquire().await.expect("Semaphore closed");
494                    let result = decider.decide(req).await;
495                    (worker_id, result)
496                }
497            })
498            .collect();
499
500        let results = join_all(futures).await;
501
502        results
503            .into_iter()
504            .map(|(worker_id, result)| {
505                let mapped = result.map_err(BatchProcessError::from);
506                (worker_id, mapped)
507            })
508            .collect()
509    }
510
511    /// 順次実行
512    async fn process_sequential(
513        &self,
514        requests: Vec<(WorkerId, WorkerDecisionRequest)>,
515    ) -> BatchProcessResult {
516        let mut results = Vec::with_capacity(requests.len());
517
518        for (worker_id, req) in requests {
519            let result = self.decider.decide(req).await;
520            let mapped = result.map_err(BatchProcessError::from);
521            results.push((worker_id, mapped));
522        }
523
524        results
525    }
526}
527
528/// リクエストを LoRA 設定でグルーピング
529///
530/// 同じ LoRA 設定(または LoRA なし)のリクエストをまとめる。
531/// HashMap の順序は不定だが、グループ内の順序は保持される。
532fn group_by_lora(
533    requests: Vec<(WorkerId, WorkerDecisionRequest)>,
534) -> HashMap<Option<LoraConfig>, Vec<(WorkerId, WorkerDecisionRequest)>> {
535    let mut groups: HashMap<Option<LoraConfig>, Vec<(WorkerId, WorkerDecisionRequest)>> =
536        HashMap::new();
537
538    for (worker_id, req) in requests {
539        let lora_key = req.lora.clone();
540        groups.entry(lora_key).or_default().push((worker_id, req));
541    }
542
543    groups
544}
545
546// ============================================================================
547// Helper Functions
548// ============================================================================
549
/// Sort actions via binary comparison + voting (batched).
///
/// Sends prompts for every pair × 3 votes in a single parallel burst, then
/// tallies the results. Actions are sorted by win count ascending
/// (fewer wins = comes earlier).
async fn binary_sort_actions<D: LlmDecider>(
    task: &str,
    actions: &[&ActionDef],
    decider: &D,
) -> Vec<String> {
    use futures::future::join_all;
    use std::collections::HashMap;

    // 0 or 1 actions: already sorted, no LLM calls needed.
    if actions.len() <= 1 {
        return actions.iter().map(|a| a.name.clone()).collect();
    }

    // Build requests for every pair × 3 votes:
    // (pair_index, vote_index, prompt, a_name, b_name)
    let mut requests: Vec<(usize, usize, String, String, String)> = Vec::new();
    let mut pair_index = 0;

    for i in 0..actions.len() {
        for j in (i + 1)..actions.len() {
            let a = actions[i];
            let b = actions[j];
            let prompt = format!(
                "Goal: {}\n- {}: {}\n- {}: {}\nWhich comes first: {} or {}?\nAnswer (one word):",
                task, a.name, a.description, b.name, b.description, a.name, b.name
            );

            // Ask the same pair three times (majority vote).
            for vote_idx in 0..3 {
                requests.push((
                    pair_index,
                    vote_idx,
                    prompt.clone(),
                    a.name.clone(),
                    b.name.clone(),
                ));
            }
            pair_index += 1;
        }
    }

    let total_requests = requests.len();
    tracing::debug!(
        pairs = pair_index,
        total_requests = total_requests,
        "Binary sort: sending batch requests"
    );

    // Send all requests in parallel.
    // Note: Binary sort does not use LoRA (base model only)
    let futures: Vec<_> = requests
        .into_iter()
        .map(|(pair_idx, vote_idx, prompt, a_name, b_name)| {
            let decider_ref = decider;
            async move {
                let result = decider_ref.call_raw(&prompt, None).await;
                (pair_idx, vote_idx, result, a_name, b_name)
            }
        })
        .collect();

    let results = join_all(futures).await;

    // Tally votes per pair:
    // pair_index -> (a_count, b_count, a_name, b_name)
    let mut pair_votes: HashMap<usize, (usize, usize, String, String)> = HashMap::new();

    for (pair_idx, _vote_idx, result, a_name, b_name) in results {
        let entry = pair_votes
            .entry(pair_idx)
            .or_insert((0, 0, a_name.clone(), b_name.clone()));

        // Failed calls are simply not counted (neither side gets a vote).
        if let Ok(response) = result {
            // Case-insensitive containment match.
            // NOTE(review): if one action name is a substring of another
            // (e.g. "Fetch" / "FetchAll"), the `a` check wins regardless of
            // intent — confirm action names are substring-free.
            let response_upper = response.to_uppercase();
            let a_upper = a_name.to_uppercase();
            let b_upper = b_name.to_uppercase();

            if response_upper.contains(&a_upper) {
                entry.0 += 1;
            } else if response_upper.contains(&b_upper) {
                entry.1 += 1;
            }
        }
    }

    // Count "wins" per action (ties, including 0-0, favor `a` coming first).
    let mut wins: HashMap<String, usize> = HashMap::new();
    for a in actions {
        wins.insert(a.name.clone(), 0);
    }

    for (_pair_idx, (a_count, b_count, a_name, b_name)) in pair_votes {
        // winner = the action that comes first, so the other one ("later") gains a win
        if a_count >= b_count {
            // a comes first → +1 win for b
            *wins.get_mut(&b_name).unwrap() += 1;
        } else {
            // b comes first → +1 win for a
            *wins.get_mut(&a_name).unwrap() += 1;
        }
    }

    // Sort by ascending win count (fewest wins comes first).
    let mut sorted: Vec<_> = wins.into_iter().collect();
    sorted.sort_by_key(|(_, count)| *count);

    tracing::debug!(
        sorted = ?sorted.iter().map(|(n, c)| format!("{}:{}", n, c)).collect::<Vec<_>>(),
        "Binary sort completed"
    );

    sorted.into_iter().map(|(name, _)| name).collect()
}
666
667// ============================================================================
668// Tests
669// ============================================================================
670
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_process_error_transient() {
        let err = BatchProcessError::transient("connection timeout");
        assert!(err.is_transient());
        assert_eq!(err.message(), "connection timeout");
    }

    #[test]
    fn test_batch_process_error_permanent() {
        let err = BatchProcessError::permanent("invalid model");
        assert!(!err.is_transient());
        assert_eq!(err.message(), "invalid model");
    }

    #[test]
    fn test_batch_process_error_from_llm_error() {
        let llm_err = LlmError::transient("timeout");
        let batch_err: BatchProcessError = llm_err.into();
        assert!(batch_err.is_transient());
        assert_eq!(batch_err.message(), "timeout");
    }

    #[test]
    fn test_ollama_batch_processor_config_default() {
        let config = LlmBatchProcessorConfig::default();
        assert!(config.parallel);
        assert_eq!(config.max_concurrency, 4);
    }

    // =========================================================================
    // Binary Sort Tests
    // =========================================================================

    use std::collections::HashMap;

    /// Synchronous version of binary_sort (test-only).
    /// Exercises the win-count computation logic.
    fn binary_sort_sync(
        actions: &[&str],
        // (a, b) -> winner (the action that comes first)
        comparator: impl Fn(&str, &str) -> String,
    ) -> Vec<String> {
        if actions.len() <= 1 {
            return actions.iter().map(|s| s.to_string()).collect();
        }

        let mut wins: HashMap<String, usize> = HashMap::new();
        for &a in actions {
            wins.insert(a.to_string(), 0);
        }

        for i in 0..actions.len() {
            for j in (i + 1)..actions.len() {
                let a = actions[i];
                let b = actions[j];
                let winner = comparator(a, b);

                // winner = comes first → the other action is later = gains a win
                if winner == a {
                    *wins.get_mut(b).unwrap() += 1;
                } else {
                    *wins.get_mut(a).unwrap() += 1;
                }
            }
        }

        let mut sorted: Vec<_> = wins.into_iter().collect();
        sorted.sort_by_key(|(_, count)| *count);
        sorted.into_iter().map(|(name, _)| name).collect()
    }

    #[test]
    fn test_binary_sort_two_actions() {
        // Fetch first, Summarize second
        let result = binary_sort_sync(
            &["Fetch", "Summarize"],
            |a, _b| a.to_string(), // a always comes first
        );
        assert_eq!(result, vec!["Fetch", "Summarize"]);

        // Summarize first, Fetch second
        let result = binary_sort_sync(
            &["Fetch", "Summarize"],
            |_a, b| b.to_string(), // b always comes first
        );
        assert_eq!(result, vec!["Summarize", "Fetch"]);
    }

    #[test]
    fn test_binary_sort_three_actions() {
        // Expected order: Build -> Test -> Deploy
        // comparator: always returns the correct ordering
        let result = binary_sort_sync(&["Test", "Deploy", "Build"], |a, b| {
            let order = ["Build", "Test", "Deploy"];
            let a_idx = order.iter().position(|&x| x == a).unwrap();
            let b_idx = order.iter().position(|&x| x == b).unwrap();
            if a_idx < b_idx {
                a.to_string()
            } else {
                b.to_string()
            }
        });
        assert_eq!(result, vec!["Build", "Test", "Deploy"]);
    }

    #[test]
    fn test_binary_sort_wins_calculation() {
        // Three actions: A, B, C
        // Correct order: A -> B -> C
        // Comparison results:
        //   A vs B -> A first -> +1 for B
        //   A vs C -> A first -> +1 for C
        //   B vs C -> B first -> +1 for C
        // wins = {A: 0, B: 1, C: 2}
        // After sorting: A(0), B(1), C(2)

        let mut wins: HashMap<String, usize> = HashMap::new();
        wins.insert("A".to_string(), 0);
        wins.insert("B".to_string(), 0);
        wins.insert("C".to_string(), 0);

        // A vs B: A first → +1 for B
        *wins.get_mut("B").unwrap() += 1;
        // A vs C: A first → +1 for C
        *wins.get_mut("C").unwrap() += 1;
        // B vs C: B first → +1 for C
        *wins.get_mut("C").unwrap() += 1;

        assert_eq!(wins["A"], 0);
        assert_eq!(wins["B"], 1);
        assert_eq!(wins["C"], 2);

        let mut sorted: Vec<_> = wins.into_iter().collect();
        sorted.sort_by_key(|(_, count)| *count);
        let result: Vec<_> = sorted.into_iter().map(|(name, _)| name).collect();

        assert_eq!(result, vec!["A", "B", "C"]);
    }

    /// Mirrors the winner-extraction logic applied to a raw LLM response.
    fn extract_winner(response: &str, a: &str, b: &str) -> Option<String> {
        let response_upper = response.to_uppercase();
        let a_upper = a.to_uppercase();
        let b_upper = b.to_uppercase();

        if response_upper.contains(&a_upper) {
            Some(a.to_string())
        } else if response_upper.contains(&b_upper) {
            Some(b.to_string())
        } else {
            None
        }
    }

    #[test]
    fn test_extract_winner() {
        // Happy path
        assert_eq!(
            extract_winner("Fetch", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );
        assert_eq!(
            extract_winner("Summarize", "Fetch", "Summarize"),
            Some("Summarize".to_string())
        );

        // Leading whitespace
        assert_eq!(
            extract_winner(" Fetch", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );

        // Case-insensitive
        assert_eq!(
            extract_winner("fetch", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );
        assert_eq!(
            extract_winner("FETCH", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );

        // Name embedded in a sentence
        assert_eq!(
            extract_winner("The answer is Fetch.", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );

        // Neither name present
        assert_eq!(extract_winner("Unknown", "Fetch", "Summarize"), None);

        // When both names are present, the first containment check wins
        assert_eq!(
            extract_winner("Fetch then Summarize", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );
    }

    #[test]
    fn test_vote_majority() {
        // Majority decision over three votes
        fn vote_majority(responses: &[&str], a: &str, b: &str) -> String {
            let mut a_count = 0;
            let mut b_count = 0;

            for response in responses {
                if let Some(winner) = extract_winner(response, a, b) {
                    if winner == a {
                        a_count += 1;
                    } else {
                        b_count += 1;
                    }
                }
            }

            if a_count >= b_count {
                a.to_string()
            } else {
                b.to_string()
            }
        }

        // Fetch all three times
        assert_eq!(
            vote_majority(&["Fetch", "Fetch", "Fetch"], "Fetch", "Summarize"),
            "Fetch"
        );

        // Fetch twice, Summarize once
        assert_eq!(
            vote_majority(&["Fetch", "Summarize", "Fetch"], "Fetch", "Summarize"),
            "Fetch"
        );

        // Summarize twice, Fetch once
        assert_eq!(
            vote_majority(&["Summarize", "Summarize", "Fetch"], "Fetch", "Summarize"),
            "Summarize"
        );

        // On a tie, a (Fetch) is returned
        assert_eq!(
            vote_majority(&["Fetch", "Summarize", "Unknown"], "Fetch", "Summarize"),
            "Fetch"
        );
    }

    // =========================================================================
    // LoRA Grouping Tests
    // =========================================================================

    use swarm_engine_core::context::{ContextTarget, GlobalContext, ResolvedContext};

    /// Build a minimal `(WorkerId, WorkerDecisionRequest)` pair for tests.
    fn create_test_request(
        worker_id: usize,
        lora: Option<LoraConfig>,
    ) -> (WorkerId, WorkerDecisionRequest) {
        let global = GlobalContext {
            tick: 0,
            max_ticks: 100,
            progress: 0.0,
            success_rate: 0.0,
            task_description: Some("test".to_string()),
            hint: None,
        };
        let context = ResolvedContext::new(global, ContextTarget::Worker(WorkerId(worker_id)));

        (
            WorkerId(worker_id),
            WorkerDecisionRequest {
                worker_id: WorkerId(worker_id),
                query: format!("query_{}", worker_id),
                context,
                lora,
            },
        )
    }

    #[test]
    fn test_group_by_lora_single_group_no_lora() {
        let requests = vec![
            create_test_request(0, None),
            create_test_request(1, None),
            create_test_request(2, None),
        ];

        let groups = group_by_lora(requests);

        assert_eq!(groups.len(), 1);
        assert!(groups.contains_key(&None));
        assert_eq!(groups[&None].len(), 3);
    }

    #[test]
    fn test_group_by_lora_single_group_with_lora() {
        let lora = LoraConfig::with_id(0);
        let requests = vec![
            create_test_request(0, Some(lora.clone())),
            create_test_request(1, Some(lora.clone())),
        ];

        let groups = group_by_lora(requests);

        assert_eq!(groups.len(), 1);
        assert!(groups.contains_key(&Some(lora)));
    }

    #[test]
    fn test_group_by_lora_multiple_groups() {
        let lora_a = LoraConfig::with_id(0);
        let lora_b = LoraConfig::with_id(1);

        let requests = vec![
            create_test_request(0, Some(lora_a.clone())),
            create_test_request(1, Some(lora_b.clone())),
            create_test_request(2, Some(lora_a.clone())),
            create_test_request(3, None),
            create_test_request(4, Some(lora_b.clone())),
        ];

        let groups = group_by_lora(requests);

        assert_eq!(groups.len(), 3);
        assert_eq!(groups[&Some(lora_a)].len(), 2);
        assert_eq!(groups[&Some(lora_b)].len(), 2);
        assert_eq!(groups[&None].len(), 1);
    }

    #[test]
    fn test_group_by_lora_preserves_order_within_group() {
        let lora = LoraConfig::with_id(0);
        let requests = vec![
            create_test_request(5, Some(lora.clone())),
            create_test_request(3, Some(lora.clone())),
            create_test_request(7, Some(lora.clone())),
        ];

        let groups = group_by_lora(requests);
        let group = &groups[&Some(lora)];

        // Order within a group is preserved
        assert_eq!(group[0].0, WorkerId(5));
        assert_eq!(group[1].0, WorkerId(3));
        assert_eq!(group[2].0, WorkerId(7));
    }

    #[test]
    fn test_group_by_lora_different_scales() {
        // Same ID but different scale → separate groups
        let lora_full = LoraConfig::new(0, 1.0);
        let lora_half = LoraConfig::new(0, 0.5);

        let requests = vec![
            create_test_request(0, Some(lora_full.clone())),
            create_test_request(1, Some(lora_half.clone())),
            create_test_request(2, Some(lora_full.clone())),
        ];

        let groups = group_by_lora(requests);

        assert_eq!(groups.len(), 2);
        assert_eq!(groups[&Some(lora_full)].len(), 2);
        assert_eq!(groups[&Some(lora_half)].len(), 1);
    }

    #[test]
    fn test_group_by_lora_empty() {
        let requests: Vec<(WorkerId, WorkerDecisionRequest)> = vec![];
        let groups = group_by_lora(requests);
        assert!(groups.is_empty());
    }
}