// swarm_engine_llm/decider.rs

use std::future::Future;
use std::pin::Pin;
/// Boxed future returned by [`LlmDecider::decide_batch`]: one `Result` per
/// request, in the same order as the input batch.
pub type BatchDecisionFuture<'a> =
    Pin<Box<dyn Future<Output = Vec<Result<DecisionResponse, LlmError>>> + Send + 'a>>;

pub use swarm_engine_core::agent::{
    ActionCandidate, ActionParam, DecisionResponse, ResolvedContext, WorkerDecisionRequest,
};
pub use swarm_engine_core::types::LoraConfig;

/// Error produced by LLM decider operations.
///
/// Split into transient vs. permanent variants so callers can decide whether
/// a retry is worthwhile (see [`LlmError::is_transient`]).
#[derive(Debug, Clone, thiserror::Error)]
pub enum LlmError {
    /// Retryable failure (e.g. a timeout or temporary backend hiccup —
    /// NOTE(review): exact conditions depend on the producing backend).
    #[error("LLM error (transient): {0}")]
    Transient(String),

    /// Non-retryable failure.
    #[error("LLM error: {0}")]
    Permanent(String),
}
39
40impl LlmError {
41 pub fn transient(message: impl Into<String>) -> Self {
42 Self::Transient(message.into())
43 }
44
45 pub fn permanent(message: impl Into<String>) -> Self {
46 Self::Permanent(message.into())
47 }
48
49 pub fn is_transient(&self) -> bool {
50 matches!(self, Self::Transient(_))
51 }
52
53 pub fn message(&self) -> &str {
54 match self {
55 Self::Transient(msg) => msg,
56 Self::Permanent(msg) => msg,
57 }
58 }
59}
60
61impl From<swarm_engine_core::error::SwarmError> for LlmError {
62 fn from(err: swarm_engine_core::error::SwarmError) -> Self {
63 if err.is_transient() {
64 Self::Transient(err.message())
65 } else {
66 Self::Permanent(err.message())
67 }
68 }
69}
70
71impl From<LlmError> for swarm_engine_core::error::SwarmError {
72 fn from(err: LlmError) -> Self {
73 match err {
74 LlmError::Transient(message) => {
75 swarm_engine_core::error::SwarmError::LlmTransient { message }
76 }
77 LlmError::Permanent(message) => {
78 swarm_engine_core::error::SwarmError::LlmPermanent { message }
79 }
80 }
81 }
82}
83
/// Abstraction over an LLM backend that turns worker decision requests into
/// [`DecisionResponse`]s. Implementations must be shareable across threads
/// (`Send + Sync`); async methods return boxed futures so the trait stays
/// object-safe.
pub trait LlmDecider: Send + Sync {
    /// Decides on a single request, yielding either a [`DecisionResponse`]
    /// or an [`LlmError`] classified as transient/permanent.
    fn decide(
        &self,
        request: WorkerDecisionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<DecisionResponse, LlmError>> + Send + '_>>;

    /// Sends a raw prompt (optionally with a LoRA adapter config) and returns
    /// the backend's raw string completion.
    ///
    /// Default: fails with a permanent "call_raw not implemented" error, so
    /// backends that only support structured decisions need not override it.
    fn call_raw(
        &self,
        _prompt: &str,
        _lora: Option<&LoraConfig>,
    ) -> Pin<Box<dyn Future<Output = Result<String, LlmError>> + Send + '_>> {
        Box::pin(async { Err(LlmError::permanent("call_raw not implemented")) })
    }

    /// Decides on a batch of requests, returning one result per request in
    /// input order.
    ///
    /// Default implementation awaits [`LlmDecider::decide`] sequentially;
    /// backends with true batch endpoints should override this for throughput.
    fn decide_batch(&self, requests: Vec<WorkerDecisionRequest>) -> BatchDecisionFuture<'_> {
        Box::pin(async move {
            let mut results = Vec::with_capacity(requests.len());
            for req in requests {
                results.push(self.decide(req).await);
            }
            results
        })
    }

    /// Human-readable identifier of the underlying model.
    fn model_name(&self) -> &str;

    /// Reports whether the backend is currently reachable/usable.
    fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>>;

    /// Advertised upper bound on concurrent in-flight requests.
    ///
    /// Default: `None`, meaning no limit is advertised.
    fn max_concurrency(&self) -> Pin<Box<dyn Future<Output = Option<usize>> + Send + '_>> {
        Box::pin(async { None })
    }
}
136
/// Configuration for an LLM-backed decider.
#[derive(Debug, Clone)]
pub struct LlmDeciderConfig {
    /// Model identifier; defaults to `"qwen2.5-coder:1.5b"`.
    pub model: String,
    /// Base URL of the LLM endpoint; defaults to `http://localhost:11434`
    /// (NOTE(review): the default port suggests an Ollama server — confirm).
    pub endpoint: String,
    /// Request timeout in milliseconds; defaults to 5000.
    pub timeout_ms: u64,
    /// Maximum number of requests per batch; defaults to 100.
    pub max_batch_size: usize,
    /// Sampling temperature passed to the model; defaults to 0.1.
    pub temperature: f32,
    /// Optional system prompt; `None` (the default) leaves it unset.
    pub system_prompt: Option<String>,
}
153
154impl Default for LlmDeciderConfig {
155 fn default() -> Self {
156 Self {
157 model: "qwen2.5-coder:1.5b".to_string(),
158 endpoint: "http://localhost:11434".to_string(),
159 timeout_ms: 5000,
160 max_batch_size: 100,
161 temperature: 0.1,
162 system_prompt: None,
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Transient errors report themselves as transient, expose the raw
    /// message, and carry the "(transient)" tag in their `Display` form.
    #[test]
    fn test_llm_error_transient() {
        let err = LlmError::transient("connection timeout");
        assert!(err.is_transient());
        assert_eq!(err.message(), "connection timeout");
        assert_eq!(
            format!("{}", err),
            "LLM error (transient): connection timeout"
        );
    }

    /// Permanent errors are not transient, and their `Display` form lacks
    /// the "(transient)" tag.
    #[test]
    fn test_llm_error_permanent() {
        let err = LlmError::permanent("invalid model");
        assert!(!err.is_transient());
        assert_eq!(err.message(), "invalid model");
        // Previously untested: permanent Display format.
        assert_eq!(format!("{}", err), "LLM error: invalid model");
    }

    /// Every documented default of `LlmDeciderConfig` is pinned here.
    #[test]
    fn test_llm_decider_config_default() {
        let config = LlmDeciderConfig::default();
        assert_eq!(config.model, "qwen2.5-coder:1.5b");
        assert_eq!(config.endpoint, "http://localhost:11434");
        assert_eq!(config.timeout_ms, 5000);
        assert_eq!(config.max_batch_size, 100);
        assert!((config.temperature - 0.1).abs() < 0.001);
        // Previously untested: default config has no system prompt.
        assert!(config.system_prompt.is_none());
    }
}