// swarm_engine_llm/batch_processor.rs
1//! Batch Processor - ManagerAgent 向け Batch LLM 処理
2//!
3//! ManagerAgent の `BatchDecisionRequest` を処理するための抽象化レイヤー。
4//!
5//! # 設計
6//!
7//! ```text
8//! Core Layer
9//! ├── ManagerAgent trait (prepare / finalize)
10//! ├── DefaultBatchManagerAgent   ← Core層のデフォルト実装
11//! ├── ContextStore / ContextView ← 正規化されたコンテキスト
12//! └── ContextResolver            ← スコープ解決
13//!
14//! LLM Layer
15//! ├── PromptBuilder              ← ResolvedContext → プロンプト
16//! ├── BatchProcessor trait       ← Batch 処理の抽象
17//! │   └── LlmBatchProcessor   ← Ollama 実装(仮想バッチ)
18//! │
19//! └── BatchInvoker 実装          ← LLM Batch 呼び出し
20//!     └── OllamaBatchInvoker
21//! ```
22//!
23//! # 仮想バッチ vs 真のバッチ
24//!
25//! Ollama は真の Batch API を持たないため、`LlmBatchProcessor` は
26//! 内部で並列/順次処理を行う「仮想バッチ」として実装されます。
27//! 将来 vLLM 等の真の Batch API を持つバックエンドに切り替える際は、
28//! `BatchProcessor` trait の別実装を提供するだけで対応可能です。
29//!
30//! # 型の統一
31//!
32//! LLM層はCore層の型を直接使用するため、変換ロジックは不要です:
33//! - `WorkerDecisionRequest` - リクエスト
34//! - `DecisionResponse` - レスポンス
35
36use std::future::Future;
37use std::pin::Pin;
38use std::sync::Arc;
39
40use std::collections::HashMap;
41
42use swarm_engine_core::actions::ActionDef;
43use swarm_engine_core::agent::{BatchDecisionRequest, DecisionResponse, WorkerDecisionRequest};
44use swarm_engine_core::exploration::{DependencyGraph, SelectResult};
45use swarm_engine_core::types::{LoraConfig, WorkerId};
46
47use crate::decider::{LlmDecider, LlmError};
48
49// ============================================================================
50// BatchProcessor Trait
51// ============================================================================
52
/// Result of one batch run: a `(WorkerId, Result)` entry per worker request.
pub type BatchProcessResult = Vec<(WorkerId, Result<DecisionResponse, BatchProcessError>)>;
55
/// Error produced while processing a batch request.
#[derive(Debug, Clone, thiserror::Error)]
pub enum BatchProcessError {
    /// Transient error — retrying may succeed (e.g. a connection timeout).
    #[error("Batch process error (transient): {0}")]
    Transient(String),

    /// Permanent error — retrying will not help (e.g. an invalid model).
    #[error("Batch process error: {0}")]
    Permanent(String),
}
67
68impl BatchProcessError {
69    pub fn transient(message: impl Into<String>) -> Self {
70        Self::Transient(message.into())
71    }
72
73    pub fn permanent(message: impl Into<String>) -> Self {
74        Self::Permanent(message.into())
75    }
76
77    pub fn is_transient(&self) -> bool {
78        matches!(self, Self::Transient(_))
79    }
80
81    pub fn message(&self) -> &str {
82        match self {
83            Self::Transient(msg) => msg,
84            Self::Permanent(msg) => msg,
85        }
86    }
87}
88
89impl From<LlmError> for BatchProcessError {
90    fn from(e: LlmError) -> Self {
91        if e.is_transient() {
92            Self::Transient(e.message().to_string())
93        } else {
94            Self::Permanent(e.message().to_string())
95        }
96    }
97}
98
99impl From<swarm_engine_core::error::SwarmError> for BatchProcessError {
100    fn from(err: swarm_engine_core::error::SwarmError) -> Self {
101        if err.is_transient() {
102            Self::Transient(err.message())
103        } else {
104            Self::Permanent(err.message())
105        }
106    }
107}
108
109impl From<BatchProcessError> for swarm_engine_core::error::SwarmError {
110    fn from(err: BatchProcessError) -> Self {
111        match err {
112            BatchProcessError::Transient(message) => {
113                swarm_engine_core::error::SwarmError::LlmTransient { message }
114            }
115            BatchProcessError::Permanent(message) => {
116                swarm_engine_core::error::SwarmError::LlmPermanent { message }
117            }
118        }
119    }
120}
121
/// Abstraction over batch decision processing.
///
/// Takes a `BatchDecisionRequest` and returns one decision per worker.
/// Implementations are backend-specific (Ollama, vLLM, ...).
pub trait BatchProcessor: Send + Sync {
    /// Process a batch request.
    ///
    /// # Arguments
    /// * `request` - The core `BatchDecisionRequest`
    ///
    /// # Returns
    /// One decision result per worker, each paired with its `WorkerId`.
    fn process(
        &self,
        request: BatchDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = BatchProcessResult> + Send + '_>>;

    /// Generate an action dependency graph from the task and action list.
    ///
    /// Called before the swarm's ticks start, to plan dependencies between
    /// actions. Implementations may use an LLM to build the graph dynamically.
    ///
    /// # Arguments
    ///
    /// * `task` - Task description
    /// * `actions` - Available actions
    /// * `hint` - Result of `LearnedDependencyProvider.select()` (carries the
    ///   LoRA config, vote_count, etc.)
    ///
    /// # Default
    /// The default implementation returns `None` (skips graph generation).
    fn plan_dependencies(
        &self,
        _task: &str,
        _actions: &[ActionDef],
        _hint: Option<&SelectResult>,
    ) -> Pin<Box<dyn Future<Output = Option<DependencyGraph>> + Send + '_>> {
        Box::pin(async { None })
    }

    /// Backend health check.
    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>>;

    /// Human-readable processor name.
    fn name(&self) -> &str;
}
167
168// ============================================================================
169// LlmBatchProcessor
170// ============================================================================
171
/// Configuration for `LlmBatchProcessor`.
#[derive(Debug, Clone)]
pub struct LlmBatchProcessorConfig {
    /// Dispatch requests in parallel (`false` = strictly sequential).
    pub parallel: bool,
    /// Upper bound on concurrent requests when running in parallel.
    /// Used as a fallback when the server does not report a slot count.
    pub max_concurrency: usize,
    /// Maximum retry count for dependency-graph generation.
    /// NOTE(review): not read anywhere in this file — confirm it is consumed
    /// elsewhere before relying on it.
    pub max_retries: Option<usize>,
}
182
183impl Default for LlmBatchProcessorConfig {
184    fn default() -> Self {
185        Self {
186            parallel: true,
187            max_concurrency: 4,
188            max_retries: Some(5),
189        }
190    }
191}
192
/// Virtual-batch processor backed by an `LlmDecider`.
///
/// Ollama has no true Batch API, so batching is emulated: requests are
/// dispatched individually through the inner `LlmDecider`, either in
/// parallel or sequentially (see `LlmBatchProcessorConfig`).
pub struct LlmBatchProcessor<D: LlmDecider> {
    // Arc so each per-request future can hold its own handle to the decider.
    decider: Arc<D>,
    config: LlmBatchProcessorConfig,
}
201
202impl<D: LlmDecider> LlmBatchProcessor<D> {
203    /// 新しい LlmBatchProcessor を作成
204    pub fn new(decider: D) -> Self {
205        Self {
206            decider: Arc::new(decider),
207            config: LlmBatchProcessorConfig::default(),
208        }
209    }
210
211    /// Arc でラップされた Decider から作成
212    pub fn from_arc(decider: Arc<D>) -> Self {
213        Self {
214            decider,
215            config: LlmBatchProcessorConfig::default(),
216        }
217    }
218
219    /// 設定を指定して作成
220    pub fn with_config(mut self, config: LlmBatchProcessorConfig) -> Self {
221        self.config = config;
222        self
223    }
224}
225
impl<D: LlmDecider + 'static> BatchProcessor for LlmBatchProcessor<D> {
    /// Process a batch of worker decision requests as a "virtual batch":
    /// each request is dispatched individually through the inner decider,
    /// in parallel or sequentially depending on `config.parallel`.
    fn process(
        &self,
        request: BatchDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = BatchProcessResult> + Send + '_>> {
        Box::pin(async move {
            if request.requests.is_empty() {
                return vec![];
            }

            // Use Core's WorkerDecisionRequest directly (no conversion
            // needed); just key each request by its worker id.
            let requests: Vec<(WorkerId, WorkerDecisionRequest)> = request
                .requests
                .into_iter()
                .map(|r| (r.worker_id, r))
                .collect();

            if self.config.parallel {
                self.process_parallel(requests).await
            } else {
                self.process_sequential(requests).await
            }
        })
    }

    /// Build a linear dependency graph over the available actions via LLM
    /// pairwise comparison:
    ///
    /// 1. Split actions into Discover (`NodeExpand`) and NotDiscover
    ///    (`NodeStateChange`); other categories are left out of the graph.
    /// 2. Order each group via `binary_sort_actions` (pairwise LLM votes),
    ///    using the LoRA/vote_count carried by the `SelectResult` hint.
    /// 3. Chain them linearly: Discover... -> NotDiscover...; the first node
    ///    is Start, the last NotDiscover (or last Discover) is Terminal.
    /// 4. Emit a `LearningEvent` describing the inference and attach the
    ///    corresponding record to the graph.
    fn plan_dependencies(
        &self,
        task: &str,
        actions: &[ActionDef],
        hint: Option<&SelectResult>,
    ) -> Pin<Box<dyn Future<Output = Option<DependencyGraph>> + Send + '_>> {
        let task = task.to_string();
        let actions: Vec<ActionDef> = actions.to_vec();
        let decider = Arc::clone(&self.decider);

        // Extract lora and vote_count from the SelectResult hint.
        let (lora, vote_count) = match hint {
            Some(SelectResult::UseLlm {
                lora,
                vote_count,
                match_rate,
                ..
            }) => {
                tracing::debug!(
                    match_rate = match_rate,
                    vote_count = vote_count,
                    has_lora = lora.is_some(),
                    "Using SelectResult hint for plan_dependencies"
                );
                (lora.clone(), *vote_count)
            }
            _ => {
                // No hint, or UseLearnedGraph (which should not reach here).
                tracing::debug!("No SelectResult hint, using defaults (lora=None, vote_count=3)");
                (None, 3)
            }
        };

        Box::pin(async move {
            use std::time::Instant;
            use swarm_engine_core::actions::ActionCategory;
            use swarm_engine_core::exploration::DependencyGraphBuilder;

            let start_time = Instant::now();
            let action_names: Vec<String> = actions.iter().map(|a| a.name.clone()).collect();

            // 1. Split Discover (NodeExpand) from NotDiscover (NodeStateChange).
            let discover: Vec<&ActionDef> = actions
                .iter()
                .filter(|a| a.category == ActionCategory::NodeExpand)
                .collect();
            let not_discover: Vec<&ActionDef> = actions
                .iter()
                .filter(|a| a.category == ActionCategory::NodeStateChange)
                .collect();

            tracing::debug!(
                discover = ?discover.iter().map(|a| &a.name).collect::<Vec<_>>(),
                not_discover = ?not_discover.iter().map(|a| &a.name).collect::<Vec<_>>(),
                "Separated actions by category"
            );

            // 2. Sort Discover via binary comparison + voting (uses the
            //    vote_count and lora taken from the SelectResult hint).
            //    A group of 0 or 1 actions needs no LLM calls.
            let discover_sort_start = Instant::now();
            let sorted_discover = if discover.len() <= 1 {
                discover.iter().map(|a| a.name.clone()).collect()
            } else {
                binary_sort_actions(&task, &discover, decider.as_ref(), lora.as_ref(), vote_count).await
            };
            let discover_sort_ms = discover_sort_start.elapsed().as_millis();

            tracing::debug!(
                sorted = ?sorted_discover,
                elapsed_ms = discover_sort_ms,
                vote_count = vote_count,
                has_lora = lora.is_some(),
                "Sorted Discover actions via binary comparison"
            );

            // 3. Sort NotDiscover the same way (same SelectResult parameters).
            let not_discover_sort_start = Instant::now();
            let sorted_not_discover = if not_discover.len() <= 1 {
                not_discover.iter().map(|a| a.name.clone()).collect()
            } else {
                binary_sort_actions(&task, &not_discover, decider.as_ref(), lora.as_ref(), vote_count).await
            };
            let not_discover_sort_ms = not_discover_sort_start.elapsed().as_millis();

            tracing::debug!(
                sorted = ?sorted_not_discover,
                elapsed_ms = not_discover_sort_ms,
                "Sorted NotDiscover actions via binary comparison"
            );

            // 4. Build the graph: Discover (linear) -> NotDiscover (linear).
            let mut builder = DependencyGraphBuilder::new()
                .task(&task)
                .available_actions(action_names.clone());

            // First Discover becomes the Start node.
            if !sorted_discover.is_empty() {
                builder = builder.start_node(&sorted_discover[0]);
            } else if !sorted_not_discover.is_empty() {
                // No Discover actions: first NotDiscover becomes Start.
                builder = builder.start_node(&sorted_not_discover[0]);
            }

            // Last NotDiscover becomes the Terminal node.
            if let Some(last) = sorted_not_discover.last() {
                builder = builder.terminal_node(last);
            } else if !sorted_discover.is_empty() {
                // No NotDiscover actions: last Discover becomes Terminal.
                builder = builder.terminal_node(sorted_discover.last().unwrap());
            }

            // Linear edges between consecutive Discover actions.
            for window in sorted_discover.windows(2) {
                builder = builder.edge(&window[0], &window[1], 0.9);
            }

            // Bridge edge: last Discover -> first NotDiscover.
            if !sorted_discover.is_empty() && !sorted_not_discover.is_empty() {
                builder = builder.edge(
                    sorted_discover.last().unwrap(),
                    &sorted_not_discover[0],
                    0.9,
                );
            }

            // Linear edges between consecutive NotDiscover actions.
            for window in sorted_not_discover.windows(2) {
                builder = builder.edge(&window[0], &window[1], 0.9);
            }

            let mut graph = builder.build();
            let total_ms = start_time.elapsed().as_millis();

            // Store action order for caching
            graph.set_action_order(sorted_discover.clone(), sorted_not_discover.clone());

            // Create learning record for DependencyGraph inference
            {
                use swarm_engine_core::events::{LearningEvent, LearningEventChannel};
                use swarm_engine_core::learn::DependencyGraphRecord;

                // Build a summary prompt representing the inference input
                let prompt = format!(
                    "Task: {}\n\nAvailable Actions:\n{}",
                    task,
                    action_names
                        .iter()
                        .map(|n| format!("- {}", n))
                        .collect::<Vec<_>>()
                        .join("\n")
                );

                // Build a summary response representing the inference output
                let response = format!(
                    "discover_order: {:?}\nnot_discover_order: {:?}",
                    sorted_discover, sorted_not_discover
                );

                // Create Event and emit to LearningEventChannel
                let event = LearningEvent::dependency_graph_inference(decider.model_name())
                    .prompt(&prompt)
                    .response(&response)
                    .available_actions(action_names)
                    .discover_order(sorted_discover.clone())
                    .not_discover_order(sorted_not_discover.clone())
                    .endpoint(decider.endpoint())
                    .latency_ms(total_ms as u64)
                    .success()
                    .build();

                // Emit event for learning pipeline
                LearningEventChannel::global().emit(event.clone());

                // Convert Event to Record for graph storage
                let record = DependencyGraphRecord::from(&event);
                graph.set_learn_record(record);
            }

            tracing::info!(
                discover_order = ?sorted_discover,
                not_discover_order = ?sorted_not_discover,
                edges = graph.edges().len(),
                discover_sort_ms = discover_sort_ms,
                not_discover_sort_ms = not_discover_sort_ms,
                total_ms = total_ms,
                "DependencyGraph generated via LLM binary sort"
            );

            Some(graph)
        })
    }

    /// Delegate the health probe to the inner decider.
    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
        let decider = Arc::clone(&self.decider);
        Box::pin(async move { decider.is_healthy().await })
    }

    /// Reports the inner decider's model name as the processor name.
    fn name(&self) -> &str {
        self.decider.model_name()
    }
}
451
impl<D: LlmDecider + 'static> LlmBatchProcessor<D> {
    /// Parallel execution (LoRA grouping + semaphore-bounded concurrency).
    ///
    /// # LoRA grouping
    ///
    /// With llama.cpp continuous batching, requests sharing the same LoRA
    /// configuration batch efficiently, while mixing different LoRAs hurts
    /// throughput — so requests are grouped by LoRA config first.
    ///
    /// ```text
    /// requests
    /// ├── LoRA A group  → run in parallel (within the group)
    /// ├── LoRA B group  → run in parallel (within the group)
    /// └── no-LoRA group → run in parallel (within the group)
    /// ```
    ///
    /// Groups run sequentially relative to each other, so requests with the
    /// same LoRA are processed back-to-back.
    async fn process_parallel(
        &self,
        requests: Vec<(WorkerId, WorkerDecisionRequest)>,
    ) -> BatchProcessResult {
        // Group requests by LoRA configuration.
        let grouped = group_by_lora(requests);

        let group_count = grouped.len();
        if group_count > 1 {
            tracing::debug!(
                groups = group_count,
                "Processing requests in {} LoRA groups",
                group_count
            );
        }

        // Process groups one at a time (parallelism only within a group).
        // NOTE(review): group iteration order comes from a HashMap and is
        // nondeterministic; results are keyed by WorkerId, so this only
        // affects the ordering of the returned Vec.
        let mut all_results = Vec::new();
        for (lora_config, group_requests) in grouped {
            if group_count > 1 {
                tracing::trace!(
                    lora = ?lora_config,
                    count = group_requests.len(),
                    "Processing LoRA group"
                );
            }
            let results = self.process_group(group_requests).await;
            all_results.extend(results);
        }

        all_results
    }

    /// Run one group of requests concurrently, bounded by a semaphore.
    async fn process_group(
        &self,
        requests: Vec<(WorkerId, WorkerDecisionRequest)>,
    ) -> BatchProcessResult {
        use futures::future::join_all;
        use tokio::sync::Semaphore;

        // Prefer the server-reported slot count; fall back to the configured
        // max_concurrency when it is unavailable.
        let max_concurrency = self
            .decider
            .max_concurrency()
            .await
            .unwrap_or(self.config.max_concurrency);

        let semaphore = Arc::new(Semaphore::new(max_concurrency));

        let futures: Vec<_> = requests
            .into_iter()
            .map(|(worker_id, req)| {
                let decider = Arc::clone(&self.decider);
                let sem = Arc::clone(&semaphore);
                async move {
                    // Hold a permit for the duration of the call. The
                    // semaphore is never closed here, so acquire() failing
                    // would be a broken invariant, not a runtime condition.
                    let _permit = sem.acquire().await.expect("Semaphore closed");
                    let result = decider.decide(req).await;
                    (worker_id, result)
                }
            })
            .collect();

        let results = join_all(futures).await;

        // Map decider errors into BatchProcessError, keeping the WorkerId key.
        results
            .into_iter()
            .map(|(worker_id, result)| {
                let mapped = result.map_err(BatchProcessError::from);
                (worker_id, mapped)
            })
            .collect()
    }

    /// Sequential execution: one request at a time, in input order.
    async fn process_sequential(
        &self,
        requests: Vec<(WorkerId, WorkerDecisionRequest)>,
    ) -> BatchProcessResult {
        let mut results = Vec::with_capacity(requests.len());

        for (worker_id, req) in requests {
            let result = self.decider.decide(req).await;
            let mapped = result.map_err(BatchProcessError::from);
            results.push((worker_id, mapped));
        }

        results
    }
}
560
561/// リクエストを LoRA 設定でグルーピング
562///
563/// 同じ LoRA 設定(または LoRA なし)のリクエストをまとめる。
564/// HashMap の順序は不定だが、グループ内の順序は保持される。
565fn group_by_lora(
566    requests: Vec<(WorkerId, WorkerDecisionRequest)>,
567) -> HashMap<Option<LoraConfig>, Vec<(WorkerId, WorkerDecisionRequest)>> {
568    let mut groups: HashMap<Option<LoraConfig>, Vec<(WorkerId, WorkerDecisionRequest)>> =
569        HashMap::new();
570
571    for (worker_id, req) in requests {
572        let lora_key = req.lora.clone();
573        groups.entry(lora_key).or_default().push((worker_id, req));
574    }
575
576    groups
577}
578
579// ============================================================================
580// Helper Functions
581// ============================================================================
582
583/// Binary + Vote でアクションをソート(バッチ版)
584///
585/// 全ペア × N回分のプロンプトを一括でバッチ送信し、結果を集計。
586/// 勝ち数でソート(勝ち数が少ない = 先に来る)。
587///
588/// # Arguments
589///
590/// * `task` - タスクの説明
591/// * `actions` - ソート対象のアクション
592/// * `decider` - LLM 決定器
593/// * `lora` - LoRA 設定(None = Base Model)
594/// * `vote_count` - 投票回数(1 or 3、0 は呼び出し元で処理)
595async fn binary_sort_actions<D: LlmDecider>(
596    task: &str,
597    actions: &[&ActionDef],
598    decider: &D,
599    lora: Option<&LoraConfig>,
600    vote_count: u8,
601) -> Vec<String> {
602    use futures::future::join_all;
603    use std::collections::HashMap;
604
605    if actions.len() <= 1 {
606        return actions.iter().map(|a| a.name.clone()).collect();
607    }
608
609    // 全ペア × 3回分のリクエストを作成
610    // (pair_index, vote_index, prompt, a_name, b_name)
611    let mut requests: Vec<(usize, usize, String, String, String)> = Vec::new();
612    let mut pair_index = 0;
613
614    for i in 0..actions.len() {
615        for j in (i + 1)..actions.len() {
616            let a = actions[i];
617            let b = actions[j];
618            let prompt = format!(
619                "Goal: {}\n- {}: {}\n- {}: {}\nWhich comes first: {} or {}?\nAnswer (one word):",
620                task, a.name, a.description, b.name, b.description, a.name, b.name
621            );
622
623            // 同じペアを vote_count 回投げる
624            for vote_idx in 0..vote_count as usize {
625                requests.push((
626                    pair_index,
627                    vote_idx,
628                    prompt.clone(),
629                    a.name.clone(),
630                    b.name.clone(),
631                ));
632            }
633            pair_index += 1;
634        }
635    }
636
637    let total_requests = requests.len();
638    tracing::debug!(
639        pairs = pair_index,
640        total_requests = total_requests,
641        "Binary sort: sending batch requests"
642    );
643
644    // 全リクエストを並列で送信
645    // LoRA が指定されていれば適用
646    let futures: Vec<_> = requests
647        .into_iter()
648        .map(|(pair_idx, vote_idx, prompt, a_name, b_name)| {
649            let decider_ref = decider;
650            async move {
651                let result = decider_ref.call_raw(&prompt, lora).await;
652                (pair_idx, vote_idx, result, a_name, b_name)
653            }
654        })
655        .collect();
656
657    let results = join_all(futures).await;
658
659    // ペアごとに投票結果を集計
660    // pair_index -> (a_count, b_count, a_name, b_name)
661    let mut pair_votes: HashMap<usize, (usize, usize, String, String)> = HashMap::new();
662
663    for (pair_idx, _vote_idx, result, a_name, b_name) in results {
664        let entry = pair_votes
665            .entry(pair_idx)
666            .or_insert((0, 0, a_name.clone(), b_name.clone()));
667
668        if let Ok(response) = result {
669            let response_upper = response.to_uppercase();
670            let a_upper = a_name.to_uppercase();
671            let b_upper = b_name.to_uppercase();
672
673            if response_upper.contains(&a_upper) {
674                entry.0 += 1;
675            } else if response_upper.contains(&b_upper) {
676                entry.1 += 1;
677            }
678        }
679    }
680
681    // 各アクションの「勝ち数」をカウント
682    let mut wins: HashMap<String, usize> = HashMap::new();
683    for a in actions {
684        wins.insert(a.name.clone(), 0);
685    }
686
687    for (_pair_idx, (a_count, b_count, a_name, b_name)) in pair_votes {
688        // winner = 「先に来る方」なので、もう一方が「後」= 勝ち
689        if a_count >= b_count {
690            // a が先 → b に勝ち+1
691            *wins.get_mut(&b_name).unwrap() += 1;
692        } else {
693            // b が先 → a に勝ち+1
694            *wins.get_mut(&a_name).unwrap() += 1;
695        }
696    }
697
698    // 勝ち数が少ない順にソート(先に来るものが少ない)
699    let mut sorted: Vec<_> = wins.into_iter().collect();
700    sorted.sort_by_key(|(_, count)| *count);
701
702    tracing::debug!(
703        sorted = ?sorted.iter().map(|(n, c)| format!("{}:{}", n, c)).collect::<Vec<_>>(),
704        "Binary sort completed"
705    );
706
707    sorted.into_iter().map(|(name, _)| name).collect()
708}
709
710// ============================================================================
711// Tests
712// ============================================================================
713
714#[cfg(test)]
715mod tests {
716    use super::*;
717
    // A transient error reports itself as retryable and keeps its message.
    #[test]
    fn test_batch_process_error_transient() {
        let err = BatchProcessError::transient("connection timeout");
        assert!(err.is_transient());
        assert_eq!(err.message(), "connection timeout");
    }

    // A permanent error is not retryable and keeps its message.
    #[test]
    fn test_batch_process_error_permanent() {
        let err = BatchProcessError::permanent("invalid model");
        assert!(!err.is_transient());
        assert_eq!(err.message(), "invalid model");
    }

    // LlmError -> BatchProcessError conversion preserves transience and message.
    #[test]
    fn test_batch_process_error_from_llm_error() {
        let llm_err = LlmError::transient("timeout");
        let batch_err: BatchProcessError = llm_err.into();
        assert!(batch_err.is_transient());
        assert_eq!(batch_err.message(), "timeout");
    }

    // Default config: parallel execution with a concurrency cap of 4.
    #[test]
    fn test_ollama_batch_processor_config_default() {
        let config = LlmBatchProcessorConfig::default();
        assert!(config.parallel);
        assert_eq!(config.max_concurrency, 4);
    }
746
747    // =========================================================================
748    // Binary Sort Tests
749    // =========================================================================
750
751    use std::collections::HashMap;
752
753    /// 同期版の binary_sort (テスト用)
754    /// wins の計算ロジックをテスト
755    fn binary_sort_sync(
756        actions: &[&str],
757        // (a, b) -> winner (先に来る方)
758        comparator: impl Fn(&str, &str) -> String,
759    ) -> Vec<String> {
760        if actions.len() <= 1 {
761            return actions.iter().map(|s| s.to_string()).collect();
762        }
763
764        let mut wins: HashMap<String, usize> = HashMap::new();
765        for &a in actions {
766            wins.insert(a.to_string(), 0);
767        }
768
769        for i in 0..actions.len() {
770            for j in (i + 1)..actions.len() {
771                let a = actions[i];
772                let b = actions[j];
773                let winner = comparator(a, b);
774
775                // winner = 先に来る方 → もう一方が後 = 勝ち
776                if winner == a {
777                    *wins.get_mut(b).unwrap() += 1;
778                } else {
779                    *wins.get_mut(a).unwrap() += 1;
780                }
781            }
782        }
783
784        let mut sorted: Vec<_> = wins.into_iter().collect();
785        sorted.sort_by_key(|(_, count)| *count);
786        sorted.into_iter().map(|(name, _)| name).collect()
787    }
788
    #[test]
    fn test_binary_sort_two_actions() {
        // Fetch first, Summarize second.
        let result = binary_sort_sync(
            &["Fetch", "Summarize"],
            |a, _b| a.to_string(), // always pick a as first
        );
        assert_eq!(result, vec!["Fetch", "Summarize"]);

        // Summarize first, Fetch second.
        let result = binary_sort_sync(
            &["Fetch", "Summarize"],
            |_a, b| b.to_string(), // always pick b as first
        );
        assert_eq!(result, vec!["Summarize", "Fetch"]);
    }

    #[test]
    fn test_binary_sort_three_actions() {
        // Expected order: Build -> Test -> Deploy.
        // The comparator always answers according to that order.
        let result = binary_sort_sync(&["Test", "Deploy", "Build"], |a, b| {
            let order = ["Build", "Test", "Deploy"];
            let a_idx = order.iter().position(|&x| x == a).unwrap();
            let b_idx = order.iter().position(|&x| x == b).unwrap();
            if a_idx < b_idx {
                a.to_string()
            } else {
                b.to_string()
            }
        });
        assert_eq!(result, vec!["Build", "Test", "Deploy"]);
    }

    #[test]
    fn test_binary_sort_wins_calculation() {
        // Three actions A, B, C with true order A -> B -> C.
        // Pairwise results:
        //   A vs B -> A first -> B +1
        //   A vs C -> A first -> C +1
        //   B vs C -> B first -> C +1
        // wins = {A: 0, B: 1, C: 2}
        // Sorted ascending: A(0), B(1), C(2)

        let mut wins: HashMap<String, usize> = HashMap::new();
        wins.insert("A".to_string(), 0);
        wins.insert("B".to_string(), 0);
        wins.insert("C".to_string(), 0);

        // A vs B: A first -> B +1
        *wins.get_mut("B").unwrap() += 1;
        // A vs C: A first -> C +1
        *wins.get_mut("C").unwrap() += 1;
        // B vs C: B first -> C +1
        *wins.get_mut("C").unwrap() += 1;

        assert_eq!(wins["A"], 0);
        assert_eq!(wins["B"], 1);
        assert_eq!(wins["C"], 2);

        let mut sorted: Vec<_> = wins.into_iter().collect();
        sorted.sort_by_key(|(_, count)| *count);
        let result: Vec<_> = sorted.into_iter().map(|(name, _)| name).collect();

        assert_eq!(result, vec!["A", "B", "C"]);
    }
856
857    /// response から winner を抽出するロジックのテスト
858    fn extract_winner(response: &str, a: &str, b: &str) -> Option<String> {
859        let response_upper = response.to_uppercase();
860        let a_upper = a.to_uppercase();
861        let b_upper = b.to_uppercase();
862
863        if response_upper.contains(&a_upper) {
864            Some(a.to_string())
865        } else if response_upper.contains(&b_upper) {
866            Some(b.to_string())
867        } else {
868            None
869        }
870    }
871
    #[test]
    fn test_extract_winner() {
        // Exact matches.
        assert_eq!(
            extract_winner("Fetch", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );
        assert_eq!(
            extract_winner("Summarize", "Fetch", "Summarize"),
            Some("Summarize".to_string())
        );

        // Leading whitespace is tolerated.
        assert_eq!(
            extract_winner(" Fetch", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );

        // Matching is case-insensitive.
        assert_eq!(
            extract_winner("fetch", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );
        assert_eq!(
            extract_winner("FETCH", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );

        // Name embedded in a sentence.
        assert_eq!(
            extract_winner("The answer is Fetch.", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );

        // Neither name present.
        assert_eq!(extract_winner("Unknown", "Fetch", "Summarize"), None);

        // Both present: the first candidate checked (a) wins.
        assert_eq!(
            extract_winner("Fetch then Summarize", "Fetch", "Summarize"),
            Some("Fetch".to_string())
        );
    }
915
    #[test]
    fn test_vote_majority() {
        // Majority decision across several votes; ties resolve to `a`.
        fn vote_majority(responses: &[&str], a: &str, b: &str) -> String {
            let mut a_count = 0;
            let mut b_count = 0;

            for response in responses {
                if let Some(winner) = extract_winner(response, a, b) {
                    if winner == a {
                        a_count += 1;
                    } else {
                        b_count += 1;
                    }
                }
            }

            if a_count >= b_count {
                a.to_string()
            } else {
                b.to_string()
            }
        }

        // All three votes for Fetch.
        assert_eq!(
            vote_majority(&["Fetch", "Fetch", "Fetch"], "Fetch", "Summarize"),
            "Fetch"
        );

        // Two votes Fetch, one Summarize.
        assert_eq!(
            vote_majority(&["Fetch", "Summarize", "Fetch"], "Fetch", "Summarize"),
            "Fetch"
        );

        // Two votes Summarize, one Fetch.
        assert_eq!(
            vote_majority(&["Summarize", "Summarize", "Fetch"], "Fetch", "Summarize"),
            "Summarize"
        );

        // Equal counts (one invalid vote) resolve to a (Fetch).
        assert_eq!(
            vote_majority(&["Fetch", "Summarize", "Unknown"], "Fetch", "Summarize"),
            "Fetch"
        );
    }
964
965    // =========================================================================
966    // LoRA Grouping Tests
967    // =========================================================================
968
969    use swarm_engine_core::context::{ContextTarget, GlobalContext, ResolvedContext};
970
971    fn create_test_request(
972        worker_id: usize,
973        lora: Option<LoraConfig>,
974    ) -> (WorkerId, WorkerDecisionRequest) {
975        let global = GlobalContext {
976            tick: 0,
977            max_ticks: 100,
978            progress: 0.0,
979            success_rate: 0.0,
980            task_description: Some("test".to_string()),
981            hint: None,
982        };
983        let context = ResolvedContext::new(global, ContextTarget::Worker(WorkerId(worker_id)));
984
985        (
986            WorkerId(worker_id),
987            WorkerDecisionRequest {
988                worker_id: WorkerId(worker_id),
989                query: format!("query_{}", worker_id),
990                context,
991                lora,
992            },
993        )
994    }
995
996    #[test]
997    fn test_group_by_lora_single_group_no_lora() {
998        let requests = vec![
999            create_test_request(0, None),
1000            create_test_request(1, None),
1001            create_test_request(2, None),
1002        ];
1003
1004        let groups = group_by_lora(requests);
1005
1006        assert_eq!(groups.len(), 1);
1007        assert!(groups.contains_key(&None));
1008        assert_eq!(groups[&None].len(), 3);
1009    }
1010
1011    #[test]
1012    fn test_group_by_lora_single_group_with_lora() {
1013        let lora = LoraConfig::with_id(0);
1014        let requests = vec![
1015            create_test_request(0, Some(lora.clone())),
1016            create_test_request(1, Some(lora.clone())),
1017        ];
1018
1019        let groups = group_by_lora(requests);
1020
1021        assert_eq!(groups.len(), 1);
1022        assert!(groups.contains_key(&Some(lora)));
1023    }
1024
1025    #[test]
1026    fn test_group_by_lora_multiple_groups() {
1027        let lora_a = LoraConfig::with_id(0);
1028        let lora_b = LoraConfig::with_id(1);
1029
1030        let requests = vec![
1031            create_test_request(0, Some(lora_a.clone())),
1032            create_test_request(1, Some(lora_b.clone())),
1033            create_test_request(2, Some(lora_a.clone())),
1034            create_test_request(3, None),
1035            create_test_request(4, Some(lora_b.clone())),
1036        ];
1037
1038        let groups = group_by_lora(requests);
1039
1040        assert_eq!(groups.len(), 3);
1041        assert_eq!(groups[&Some(lora_a)].len(), 2);
1042        assert_eq!(groups[&Some(lora_b)].len(), 2);
1043        assert_eq!(groups[&None].len(), 1);
1044    }
1045
1046    #[test]
1047    fn test_group_by_lora_preserves_order_within_group() {
1048        let lora = LoraConfig::with_id(0);
1049        let requests = vec![
1050            create_test_request(5, Some(lora.clone())),
1051            create_test_request(3, Some(lora.clone())),
1052            create_test_request(7, Some(lora.clone())),
1053        ];
1054
1055        let groups = group_by_lora(requests);
1056        let group = &groups[&Some(lora)];
1057
1058        // グループ内の順序は保持される
1059        assert_eq!(group[0].0, WorkerId(5));
1060        assert_eq!(group[1].0, WorkerId(3));
1061        assert_eq!(group[2].0, WorkerId(7));
1062    }
1063
1064    #[test]
1065    fn test_group_by_lora_different_scales() {
1066        // 同じ ID でも scale が違えば別グループ
1067        let lora_full = LoraConfig::new(0, 1.0);
1068        let lora_half = LoraConfig::new(0, 0.5);
1069
1070        let requests = vec![
1071            create_test_request(0, Some(lora_full.clone())),
1072            create_test_request(1, Some(lora_half.clone())),
1073            create_test_request(2, Some(lora_full.clone())),
1074        ];
1075
1076        let groups = group_by_lora(requests);
1077
1078        assert_eq!(groups.len(), 2);
1079        assert_eq!(groups[&Some(lora_full)].len(), 2);
1080        assert_eq!(groups[&Some(lora_half)].len(), 1);
1081    }
1082
1083    #[test]
1084    fn test_group_by_lora_empty() {
1085        let requests: Vec<(WorkerId, WorkerDecisionRequest)> = vec![];
1086        let groups = group_by_lora(requests);
1087        assert!(groups.is_empty());
1088    }
1089}