1use crate::agent::core::AgentCore;
6use crate::agent::factory::{create_provider, resolve_api_key};
7use crate::agent::model::{ModelProvider, ProviderKind};
8#[cfg(feature = "openai")]
9use crate::agent::providers::openai::OpenAIProvider;
10#[cfg(feature = "lmstudio")]
11use crate::agent::providers::LMStudioProvider;
12#[cfg(feature = "mlx")]
13use crate::agent::providers::MLXProvider;
14use crate::config::{AgentProfile, AgentRegistry, AppConfig, ModelConfig};
15use crate::embeddings::EmbeddingsClient;
16use crate::persistence::Persistence;
17use crate::policy::PolicyEngine;
18use crate::tools::ToolRegistry;
19use anyhow::{anyhow, Context, Result};
20#[cfg(any(feature = "mlx", feature = "lmstudio"))]
21use async_openai::config::OpenAIConfig;
22use std::sync::Arc;
23use tracing::{info, warn};
24
25pub struct AgentBuilder {
27 profile: Option<AgentProfile>,
28 provider: Option<Arc<dyn ModelProvider>>,
29 embeddings_client: Option<EmbeddingsClient>,
30 persistence: Option<Persistence>,
31 session_id: Option<String>,
32 config: Option<AppConfig>,
33 tool_registry: Option<Arc<ToolRegistry>>,
34 policy_engine: Option<Arc<PolicyEngine>>,
35 agent_name: Option<String>,
36}
37
38impl AgentBuilder {
39 pub fn new() -> Self {
41 Self {
42 profile: None,
43 provider: None,
44 embeddings_client: None,
45 persistence: None,
46 session_id: None,
47 config: None,
48 tool_registry: None,
49 policy_engine: None,
50 agent_name: None,
51 }
52 }
53
54 pub fn new_with_registry(
57 registry: &AgentRegistry,
58 config: &AppConfig,
59 session_id: Option<String>,
60 ) -> Result<AgentCore> {
61 create_agent_from_registry(registry, config, session_id)
62 }
63
64 pub fn with_profile(mut self, profile: AgentProfile) -> Self {
66 self.profile = Some(profile);
67 self
68 }
69
70 pub fn with_provider(mut self, provider: Arc<dyn ModelProvider>) -> Self {
72 self.provider = Some(provider);
73 self
74 }
75
76 pub fn with_embeddings_client(mut self, embeddings_client: EmbeddingsClient) -> Self {
78 self.embeddings_client = Some(embeddings_client);
79 self
80 }
81
82 pub fn with_persistence(mut self, persistence: Persistence) -> Self {
84 self.persistence = Some(persistence);
85 self
86 }
87
88 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
90 self.session_id = Some(session_id.into());
91 self
92 }
93
94 pub fn with_config(mut self, config: AppConfig) -> Self {
96 self.config = Some(config);
97 self
98 }
99
100 pub fn with_tool_registry(mut self, tool_registry: Arc<ToolRegistry>) -> Self {
102 self.tool_registry = Some(tool_registry);
103 self
104 }
105
106 pub fn with_policy_engine(mut self, policy_engine: Arc<PolicyEngine>) -> Self {
108 self.policy_engine = Some(policy_engine);
109 self
110 }
111
112 pub fn with_agent_name(mut self, agent_name: impl Into<String>) -> Self {
114 self.agent_name = Some(agent_name.into());
115 self
116 }
117
118 pub fn build(self) -> Result<AgentCore> {
120 let profile = self
122 .profile
123 .ok_or_else(|| anyhow!("Agent profile is required"))?;
124
125 let persistence = if let Some(persistence) = self.persistence {
127 persistence
128 } else if let Some(ref config) = self.config {
129 Persistence::new(&config.database.path).context("Failed to create persistence layer")?
130 } else {
131 return Err(anyhow!(
132 "Either persistence or config must be provided to build agent"
133 ));
134 };
135
136 let embeddings_client = if let Some(client) = self.embeddings_client {
138 Some(client)
139 } else if let Some(ref config) = self.config {
140 create_embeddings_client_from_config(config)?
141 } else {
142 None
143 };
144
145 let tool_registry = self.tool_registry.unwrap_or_else(|| {
148 let persistence_arc = Arc::new(persistence.clone());
149 let registry =
150 ToolRegistry::with_builtin_tools(Some(persistence_arc), embeddings_client.clone());
151 info!(
152 "Created tool registry with {} builtin tools",
153 registry.len()
154 );
155 for tool_name in registry.list() {
156 tracing::debug!(" - Registered tool: {}", tool_name);
157 }
158 Arc::new(registry)
159 });
160
161 let provider = if let Some(provider) = self.provider {
163 provider
164 } else if let Some(ref config) = self.config {
165 let mut base_provider =
166 create_provider(&config.model).context("Failed to create provider from config")?;
167
168 #[cfg(feature = "openai")]
170 {
171 if base_provider.kind() == ProviderKind::OpenAI {
172 let tools = tool_registry.to_openai_tools();
173 if !tools.is_empty() {
174 info!(
175 "Configuring OpenAI provider with {} tools for native function calling",
176 tools.len()
177 );
178
179 let api_key = if let Some(source) = &config.model.api_key_source {
181 resolve_api_key(source)?
182 } else {
183 std::env::var("OPENAI_API_KEY")
185 .context("OPENAI_API_KEY environment variable not set")?
186 };
187
188 let mut openai_provider = OpenAIProvider::with_api_key(api_key);
189
190 if let Some(model_name) = &config.model.model_name {
192 openai_provider = openai_provider.with_model(model_name.clone());
193 }
194
195 base_provider = Arc::new(openai_provider.with_tools(tools));
197 }
198 }
199 }
200
201 #[cfg(feature = "mlx")]
203 {
204 if base_provider.kind() == ProviderKind::MLX {
205 let tools = tool_registry.to_openai_tools();
206 if !tools.is_empty() {
207 info!(
208 "Configuring MLX provider with {} tools for native function calling",
209 tools.len()
210 );
211
212 let model_name = config
214 .model
215 .model_name
216 .as_ref()
217 .ok_or_else(|| {
218 anyhow!("MLX provider requires a model_name to be specified")
219 })?
220 .clone();
221
222 let mlx_provider = if let Ok(endpoint) = std::env::var("MLX_ENDPOINT") {
223 MLXProvider::with_endpoint(endpoint, model_name)
224 } else {
225 MLXProvider::new(model_name)
226 };
227
228 base_provider = Arc::new(mlx_provider.with_tools(tools));
229 }
230 }
231 }
232
233 #[cfg(feature = "lmstudio")]
234 {
235 if base_provider.kind() == ProviderKind::LMStudio {
236 let tools = tool_registry.to_openai_tools();
237 if !tools.is_empty() {
238 info!(
239 "Configuring LM Studio provider with {} tools for native function calling",
240 tools.len()
241 );
242
243 let model_name = config
244 .model
245 .model_name
246 .as_ref()
247 .ok_or_else(|| {
248 anyhow!("LM Studio provider requires a model_name to be specified")
249 })?
250 .clone();
251
252 let lmstudio_provider =
253 if let Ok(endpoint) = std::env::var("LMSTUDIO_ENDPOINT") {
254 LMStudioProvider::with_endpoint(endpoint, model_name)
255 } else {
256 LMStudioProvider::new(model_name)
257 };
258
259 base_provider = Arc::new(lmstudio_provider.with_tools(tools));
260 }
261 }
262 }
263
264 base_provider
265 } else {
266 return Err(anyhow!(
267 "Either provider or config must be provided to build agent"
268 ));
269 };
270
271 let session_id = self
273 .session_id
274 .unwrap_or_else(|| format!("session-{}", chrono::Utc::now().timestamp_millis()));
275
276 let policy_engine = if let Some(engine) = self.policy_engine {
278 engine
279 } else {
280 let mut engine = PolicyEngine::load_from_persistence(&persistence)
282 .unwrap_or_else(|_| PolicyEngine::new());
283
284 if engine.rule_count() == 0 {
286 tracing::debug!(
287 "Empty policy engine detected, adding default allow-all rule for tools"
288 );
289 engine.add_rule(crate::policy::PolicyRule {
290 agent: "*".to_string(),
291 action: "tool_call".to_string(),
292 resource: "*".to_string(),
293 effect: crate::policy::PolicyEffect::Allow,
294 });
295 }
296
297 Arc::new(engine)
298 };
299
300 let fast_provider = if profile.fast_reasoning {
301 match (&profile.fast_model_provider, &profile.fast_model_name) {
302 (Some(provider_name), Some(model_name)) => {
303 let fast_config = ModelConfig {
304 provider: provider_name.clone(),
305 model_name: Some(model_name.clone()),
306 embeddings_model: None,
307 api_key_source: None,
308 temperature: profile.fast_model_temperature,
309 };
310 match create_provider(&fast_config) {
311 Ok(provider) => Some(provider),
312 Err(err) => {
313 warn!(
314 "Failed to create fast provider {}:{} - {}",
315 provider_name, model_name, err
316 );
317 None
318 }
319 }
320 }
321 _ => None,
322 }
323 } else {
324 None
325 };
326
327 let mut agent = AgentCore::new(
328 profile,
329 provider,
330 embeddings_client,
331 persistence,
332 session_id,
333 self.agent_name,
334 tool_registry,
335 policy_engine,
336 );
337
338 if let Some(fast_provider) = fast_provider {
339 agent = agent.with_fast_provider(fast_provider);
340 }
341
342 Ok(agent)
343 }
344}
345
346impl Default for AgentBuilder {
347 fn default() -> Self {
348 Self::new()
349 }
350}
351
352pub fn create_agent_from_registry(
354 registry: &AgentRegistry,
355 config: &AppConfig,
356 session_id: Option<String>,
357) -> Result<AgentCore> {
358 let (agent_name, profile) = registry
359 .active()
360 .context("No active agent profile in registry")?
361 .ok_or_else(|| anyhow!("No active agent set in registry"))?;
362
363 let mut builder = AgentBuilder::new()
364 .with_profile(profile)
365 .with_config(config.clone())
366 .with_persistence(registry.persistence().clone())
367 .with_agent_name(agent_name.clone());
368
369 if let Some(sid) = session_id {
370 builder = builder.with_session_id(sid);
371 }
372
373 builder.build()
374}
375
376fn create_embeddings_client_from_config(config: &AppConfig) -> Result<Option<EmbeddingsClient>> {
377 let model = &config.model;
378 let Some(model_name) = &model.embeddings_model else {
379 return Ok(None);
380 };
381
382 #[cfg(feature = "mlx")]
383 {
384 if ProviderKind::from_str(&model.provider) == Some(ProviderKind::MLX) {
385 return Ok(Some(build_mlx_embeddings_client(model_name)));
386 }
387 }
388
389 #[cfg(feature = "lmstudio")]
390 {
391 if ProviderKind::from_str(&model.provider) == Some(ProviderKind::LMStudio) {
392 return Ok(Some(build_lmstudio_embeddings_client(model_name)));
393 }
394 }
395
396 let client = if let Some(source) = &model.api_key_source {
397 let api_key = resolve_api_key(source)?;
398 EmbeddingsClient::with_api_key(model_name.clone(), api_key)
399 } else {
400 EmbeddingsClient::new(model_name.clone())
401 };
402
403 Ok(Some(client))
404}
405
/// Creates an embeddings client for the MLX server's OpenAI-compatible API.
///
/// The endpoint comes from `MLX_ENDPOINT` (default `http://localhost:10240`)
/// and is normalised to end in exactly one `/v1`. Trailing slashes are
/// tolerated, so `http://host:10240/` and `http://host:10240/v1/` both become
/// `http://host:10240/v1`. MLX does not check the API key, so a placeholder is
/// sent.
#[cfg(feature = "mlx")]
fn build_mlx_embeddings_client(model_name: &str) -> EmbeddingsClient {
    let endpoint =
        std::env::var("MLX_ENDPOINT").unwrap_or_else(|_| "http://localhost:10240".to_string());

    // Strip trailing slashes first; otherwise "…/" became "…//v1" and
    // "…/v1/" became "…/v1//v1".
    let base = endpoint.trim_end_matches('/');
    let api_base = if base.ends_with("/v1") {
        base.to_string()
    } else {
        format!("{base}/v1")
    };

    let config = OpenAIConfig::new()
        .with_api_base(api_base)
        .with_api_key("mlx-key");

    EmbeddingsClient::with_config(model_name.to_string(), config)
}
422
/// Creates an embeddings client for LM Studio's OpenAI-compatible API.
///
/// The endpoint comes from `LMSTUDIO_ENDPOINT` (default
/// `http://localhost:1234`) and is normalised to end in exactly one `/v1`.
/// Trailing slashes are tolerated, so `http://host:1234/` and
/// `http://host:1234/v1/` both become `http://host:1234/v1`. A placeholder API
/// key is sent.
#[cfg(feature = "lmstudio")]
fn build_lmstudio_embeddings_client(model_name: &str) -> EmbeddingsClient {
    let endpoint =
        std::env::var("LMSTUDIO_ENDPOINT").unwrap_or_else(|_| "http://localhost:1234".to_string());

    // Strip trailing slashes first; otherwise "…/" became "…//v1" and
    // "…/v1/" became "…/v1//v1".
    let base = endpoint.trim_end_matches('/');
    let api_base = if base.ends_with("/v1") {
        base.to_string()
    } else {
        format!("{base}/v1")
    };

    let config = OpenAIConfig::new()
        .with_api_base(api_base)
        .with_api_key("lmstudio-key");

    EmbeddingsClient::with_config(model_name.to_string(), config)
}
439
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use crate::config::{
        AgentProfile, AudioConfig, DatabaseConfig, LoggingConfig, ModelConfig, UiConfig,
    };
    use std::collections::HashMap;
    use tempfile::{tempdir, TempDir};

    /// Builds a mock-provider config whose database lives in a fresh temporary
    /// directory.
    ///
    /// The `TempDir` guard is returned alongside the config. Dropping it deletes
    /// the directory, so callers must keep it alive (bind it to a named
    /// variable, not `_`) while the database path is in use.
    fn create_test_config() -> (AppConfig, TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");

        let config = AppConfig {
            database: DatabaseConfig { path: db_path },
            model: ModelConfig {
                provider: "mock".to_string(),
                model_name: Some("test-model".to_string()),
                embeddings_model: None,
                api_key_source: None,
                temperature: 0.7,
            },
            ui: UiConfig {
                prompt: "> ".to_string(),
                theme: "default".to_string(),
            },
            logging: LoggingConfig {
                level: "info".to_string(),
            },
            audio: AudioConfig::default(),
            mesh: crate::config::MeshConfig::default(),
            agents: HashMap::new(),
            default_agent: None,
        };

        (config, dir)
    }

    /// A minimal profile with a recognisable prompt and fast reasoning disabled.
    fn create_test_profile() -> AgentProfile {
        AgentProfile {
            prompt: Some("Test system prompt".to_string()),
            style: None,
            temperature: Some(0.8),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 10,
            top_p: 0.95,
            max_context_tokens: Some(4096),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        }
    }

    #[test]
    fn test_builder_with_all_fields() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();
        let provider = Arc::new(MockProvider::default());

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_provider(provider)
            .with_persistence(persistence)
            .with_session_id("test-session")
            .build()
            .unwrap();

        assert_eq!(agent.session_id(), "test-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_builder_with_config() {
        let (config, _db_dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        assert!(agent.session_id().starts_with("session-"));
    }

    #[test]
    fn test_builder_missing_profile() {
        let (config, _db_dir) = create_test_config();

        let result = AgentBuilder::new().with_config(config).build();

        assert!(result.is_err());
        assert!(result.err().unwrap().to_string().contains("profile"));
    }

    #[test]
    fn test_builder_missing_provider_and_config() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();

        let result = AgentBuilder::new()
            .with_profile(profile)
            .with_persistence(persistence)
            .build();

        assert!(result.is_err());
        assert!(result
            .err()
            .unwrap()
            .to_string()
            .contains("provider or config"));
    }

    #[test]
    fn test_builder_auto_session_id() {
        let (config, _db_dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        assert!(!agent.session_id().is_empty());
    }

    #[test]
    fn test_create_agent_from_registry() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (config, _db_dir) = create_test_config();
        let profile = create_test_profile();

        let mut agents = HashMap::new();
        agents.insert("test-agent".to_string(), profile.clone());

        let registry = AgentRegistry::new(agents, persistence.clone());
        registry.set_active("test-agent").unwrap();

        let agent =
            create_agent_from_registry(&registry, &config, Some("custom-session".to_string()))
                .unwrap();

        assert_eq!(agent.session_id(), "custom-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_create_agent_from_registry_no_active() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (config, _db_dir) = create_test_config();
        let registry = AgentRegistry::new(HashMap::new(), persistence);

        let result = create_agent_from_registry(&registry, &config, None);

        assert!(result.is_err());
        let err_msg = result.err().unwrap().to_string();
        assert!(err_msg.contains("No active") || err_msg.contains("active agent"));
    }
}