1use crate::agent::core::AgentCore;
6use crate::agent::factory::{create_provider, resolve_api_key};
7use crate::agent::model::{ModelProvider, ProviderKind};
8#[cfg(feature = "openai")]
9use crate::agent::providers::openai::OpenAIProvider;
10#[cfg(feature = "lmstudio")]
11use crate::agent::providers::LMStudioProvider;
12#[cfg(feature = "mlx")]
13use crate::agent::providers::MLXProvider;
14use crate::config::{AgentProfile, AgentRegistry, AppConfig, ModelConfig};
15use crate::embeddings::EmbeddingsClient;
16use crate::persistence::Persistence;
17use crate::policy::PolicyEngine;
18use crate::tools::ToolRegistry;
19use anyhow::{anyhow, Context, Result};
20#[cfg(any(feature = "mlx", feature = "lmstudio"))]
21use async_openai::config::OpenAIConfig;
22use std::sync::Arc;
23use tracing::{info, warn};
24
25pub struct AgentBuilder {
27 profile: Option<AgentProfile>,
28 provider: Option<Arc<dyn ModelProvider>>,
29 embeddings_client: Option<EmbeddingsClient>,
30 persistence: Option<Persistence>,
31 session_id: Option<String>,
32 config: Option<AppConfig>,
33 tool_registry: Option<Arc<ToolRegistry>>,
34 policy_engine: Option<Arc<PolicyEngine>>,
35 agent_name: Option<String>,
36}
37
38impl AgentBuilder {
39 pub fn new() -> Self {
41 Self {
42 profile: None,
43 provider: None,
44 embeddings_client: None,
45 persistence: None,
46 session_id: None,
47 config: None,
48 tool_registry: None,
49 policy_engine: None,
50 agent_name: None,
51 }
52 }
53
54 pub fn new_with_registry(
57 registry: &AgentRegistry,
58 config: &AppConfig,
59 session_id: Option<String>,
60 ) -> Result<AgentCore> {
61 create_agent_from_registry(registry, config, session_id)
62 }
63
64 pub fn with_profile(mut self, profile: AgentProfile) -> Self {
66 self.profile = Some(profile);
67 self
68 }
69
70 pub fn with_provider(mut self, provider: Arc<dyn ModelProvider>) -> Self {
72 self.provider = Some(provider);
73 self
74 }
75
76 pub fn with_embeddings_client(mut self, embeddings_client: EmbeddingsClient) -> Self {
78 self.embeddings_client = Some(embeddings_client);
79 self
80 }
81
82 pub fn with_persistence(mut self, persistence: Persistence) -> Self {
84 self.persistence = Some(persistence);
85 self
86 }
87
88 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
90 self.session_id = Some(session_id.into());
91 self
92 }
93
94 pub fn with_config(mut self, config: AppConfig) -> Self {
96 self.config = Some(config);
97 self
98 }
99
100 pub fn with_tool_registry(mut self, tool_registry: Arc<ToolRegistry>) -> Self {
102 self.tool_registry = Some(tool_registry);
103 self
104 }
105
106 pub fn with_policy_engine(mut self, policy_engine: Arc<PolicyEngine>) -> Self {
108 self.policy_engine = Some(policy_engine);
109 self
110 }
111
112 pub fn with_agent_name(mut self, agent_name: impl Into<String>) -> Self {
114 self.agent_name = Some(agent_name.into());
115 self
116 }
117
118 pub fn build(self) -> Result<AgentCore> {
120 let profile = self
122 .profile
123 .ok_or_else(|| anyhow!("Agent profile is required"))?;
124
125 let persistence = if let Some(persistence) = self.persistence {
127 persistence
128 } else if let Some(ref config) = self.config {
129 Persistence::new(&config.database.path).context("Failed to create persistence layer")?
130 } else {
131 return Err(anyhow!(
132 "Either persistence or config must be provided to build agent"
133 ));
134 };
135
136 let embeddings_client = if let Some(client) = self.embeddings_client {
138 Some(client)
139 } else if let Some(ref config) = self.config {
140 create_embeddings_client_from_config(config)?
141 } else {
142 None
143 };
144
145 let tool_registry = if let Some(registry) = self.tool_registry {
148 registry
149 } else {
150 let persistence_arc = Arc::new(persistence.clone());
151 let mut registry =
152 ToolRegistry::with_builtin_tools(Some(persistence_arc), embeddings_client.clone());
153 info!(
154 "Created tool registry with {} builtin tools",
155 registry.len()
156 );
157 for tool_name in registry.list() {
158 tracing::debug!(" - Registered tool: {}", tool_name);
159 }
160
161 if let Some(ref config) = self.config {
163 if config.plugins.enabled {
164 match registry.load_plugins(
165 &config.plugins.custom_tools_dir,
166 config.plugins.allow_override_builtin,
167 ) {
168 Ok(stats) => {
169 if stats.loaded > 0 {
170 info!(
171 "Loaded {} plugins with {} tools from {}",
172 stats.loaded,
173 stats.tools_loaded,
174 config.plugins.custom_tools_dir.display()
175 );
176 }
177 if stats.failed > 0 {
178 warn!(
179 "{} plugins failed to load",
180 stats.failed
181 );
182 }
183 }
184 Err(e) => {
185 if config.plugins.continue_on_error {
186 warn!("Plugin loading failed (continuing): {}", e);
187 } else {
188 return Err(anyhow::anyhow!("Plugin loading failed: {}", e));
190 }
191 }
192 }
193 }
194 }
195
196 Arc::new(registry)
197 };
198
199 let provider = if let Some(provider) = self.provider {
201 provider
202 } else if let Some(ref config) = self.config {
203 let mut base_provider =
204 create_provider(&config.model).context("Failed to create provider from config")?;
205
206 #[cfg(feature = "openai")]
208 {
209 if base_provider.kind() == ProviderKind::OpenAI {
210 let tools = tool_registry.to_openai_tools();
211 if !tools.is_empty() {
212 info!(
213 "Configuring OpenAI provider with {} tools for native function calling",
214 tools.len()
215 );
216
217 let api_key = if let Some(source) = &config.model.api_key_source {
219 resolve_api_key(source)?
220 } else {
221 std::env::var("OPENAI_API_KEY")
223 .context("OPENAI_API_KEY environment variable not set")?
224 };
225
226 let mut openai_provider = OpenAIProvider::with_api_key(api_key);
227
228 if let Some(model_name) = &config.model.model_name {
230 openai_provider = openai_provider.with_model(model_name.clone());
231 }
232
233 base_provider = Arc::new(openai_provider.with_tools(tools));
235 }
236 }
237 }
238
239 #[cfg(feature = "mlx")]
241 {
242 if base_provider.kind() == ProviderKind::MLX {
243 let tools = tool_registry.to_openai_tools();
244 if !tools.is_empty() {
245 info!(
246 "Configuring MLX provider with {} tools for native function calling",
247 tools.len()
248 );
249
250 let model_name = config
252 .model
253 .model_name
254 .as_ref()
255 .ok_or_else(|| {
256 anyhow!("MLX provider requires a model_name to be specified")
257 })?
258 .clone();
259
260 let mlx_provider = if let Ok(endpoint) = std::env::var("MLX_ENDPOINT") {
261 MLXProvider::with_endpoint(endpoint, model_name)
262 } else {
263 MLXProvider::new(model_name)
264 };
265
266 base_provider = Arc::new(mlx_provider.with_tools(tools));
267 }
268 }
269 }
270
271 #[cfg(feature = "lmstudio")]
272 {
273 if base_provider.kind() == ProviderKind::LMStudio {
274 let tools = tool_registry.to_openai_tools();
275 if !tools.is_empty() {
276 info!(
277 "Configuring LM Studio provider with {} tools for native function calling",
278 tools.len()
279 );
280
281 let model_name = config
282 .model
283 .model_name
284 .as_ref()
285 .ok_or_else(|| {
286 anyhow!("LM Studio provider requires a model_name to be specified")
287 })?
288 .clone();
289
290 let lmstudio_provider =
291 if let Ok(endpoint) = std::env::var("LMSTUDIO_ENDPOINT") {
292 LMStudioProvider::with_endpoint(endpoint, model_name)
293 } else {
294 LMStudioProvider::new(model_name)
295 };
296
297 base_provider = Arc::new(lmstudio_provider.with_tools(tools));
298 }
299 }
300 }
301
302 base_provider
303 } else {
304 return Err(anyhow!(
305 "Either provider or config must be provided to build agent"
306 ));
307 };
308
309 let session_id = self
311 .session_id
312 .unwrap_or_else(|| format!("session-{}", chrono::Utc::now().timestamp_millis()));
313
314 let policy_engine = if let Some(engine) = self.policy_engine {
316 engine
317 } else {
318 let mut engine = PolicyEngine::load_from_persistence(&persistence)
320 .unwrap_or_else(|_| PolicyEngine::new());
321
322 if engine.rule_count() == 0 {
324 tracing::debug!(
325 "Empty policy engine detected, adding default allow-all rule for tools"
326 );
327 engine.add_rule(crate::policy::PolicyRule {
328 agent: "*".to_string(),
329 action: "tool_call".to_string(),
330 resource: "*".to_string(),
331 effect: crate::policy::PolicyEffect::Allow,
332 });
333 }
334
335 Arc::new(engine)
336 };
337
338 let fast_provider = if profile.fast_reasoning {
339 match (&profile.fast_model_provider, &profile.fast_model_name) {
340 (Some(provider_name), Some(model_name)) => {
341 let fast_config = ModelConfig {
342 provider: provider_name.clone(),
343 model_name: Some(model_name.clone()),
344 embeddings_model: None,
345 api_key_source: None,
346 temperature: profile.fast_model_temperature,
347 };
348 match create_provider(&fast_config) {
349 Ok(provider) => Some(provider),
350 Err(err) => {
351 warn!(
352 "Failed to create fast provider {}:{} - {}",
353 provider_name, model_name, err
354 );
355 None
356 }
357 }
358 }
359 _ => None,
360 }
361 } else {
362 None
363 };
364
365 let mut agent = AgentCore::new(
366 profile,
367 provider,
368 embeddings_client,
369 persistence,
370 session_id,
371 self.agent_name,
372 tool_registry,
373 policy_engine,
374 );
375
376 if let Some(fast_provider) = fast_provider {
377 agent = agent.with_fast_provider(fast_provider);
378 }
379
380 Ok(agent)
381 }
382}
383
384impl Default for AgentBuilder {
385 fn default() -> Self {
386 Self::new()
387 }
388}
389
390pub fn create_agent_from_registry(
392 registry: &AgentRegistry,
393 config: &AppConfig,
394 session_id: Option<String>,
395) -> Result<AgentCore> {
396 let (agent_name, profile) = registry
397 .active()
398 .context("No active agent profile in registry")?
399 .ok_or_else(|| anyhow!("No active agent set in registry"))?;
400
401 let mut builder = AgentBuilder::new()
402 .with_profile(profile)
403 .with_config(config.clone())
404 .with_persistence(registry.persistence().clone())
405 .with_agent_name(agent_name.clone());
406
407 if let Some(sid) = session_id {
408 builder = builder.with_session_id(sid);
409 }
410
411 builder.build()
412}
413
414fn create_embeddings_client_from_config(config: &AppConfig) -> Result<Option<EmbeddingsClient>> {
415 let model = &config.model;
416 let Some(model_name) = &model.embeddings_model else {
417 return Ok(None);
418 };
419
420 #[cfg(feature = "mlx")]
421 {
422 if ProviderKind::from_str(&model.provider) == Some(ProviderKind::MLX) {
423 return Ok(Some(build_mlx_embeddings_client(model_name)));
424 }
425 }
426
427 #[cfg(feature = "lmstudio")]
428 {
429 if ProviderKind::from_str(&model.provider) == Some(ProviderKind::LMStudio) {
430 return Ok(Some(build_lmstudio_embeddings_client(model_name)));
431 }
432 }
433
434 let client = if let Some(source) = &model.api_key_source {
435 let api_key = resolve_api_key(source)?;
436 EmbeddingsClient::with_api_key(model_name.clone(), api_key)
437 } else {
438 EmbeddingsClient::new(model_name.clone())
439 };
440
441 Ok(Some(client))
442}
443
/// Creates an embeddings client that talks to a local MLX server through its
/// OpenAI-compatible API.
///
/// The base URL comes from `MLX_ENDPOINT` (default `http://localhost:10240`)
/// and is normalized to end in exactly one `/v1` segment. The API key is a
/// fixed placeholder string.
#[cfg(feature = "mlx")]
fn build_mlx_embeddings_client(model_name: &str) -> EmbeddingsClient {
    let endpoint =
        std::env::var("MLX_ENDPOINT").unwrap_or_else(|_| "http://localhost:10240".to_string());
    // Drop trailing slashes first so "http://host:port/" and ".../v1/" do not
    // turn into "//v1" or "/v1//v1".
    let endpoint = endpoint.trim_end_matches('/');
    let api_base = if endpoint.ends_with("/v1") {
        endpoint.to_string()
    } else {
        format!("{}/v1", endpoint)
    };

    let config = OpenAIConfig::new()
        .with_api_base(api_base)
        .with_api_key("mlx-key");

    EmbeddingsClient::with_config(model_name.to_string(), config)
}
460
/// Creates an embeddings client that talks to a local LM Studio server through
/// its OpenAI-compatible API.
///
/// The base URL comes from `LMSTUDIO_ENDPOINT` (default `http://localhost:1234`)
/// and is normalized to end in exactly one `/v1` segment. The API key is a
/// fixed placeholder string.
#[cfg(feature = "lmstudio")]
fn build_lmstudio_embeddings_client(model_name: &str) -> EmbeddingsClient {
    let endpoint =
        std::env::var("LMSTUDIO_ENDPOINT").unwrap_or_else(|_| "http://localhost:1234".to_string());
    // Drop trailing slashes first so "http://host:port/" and ".../v1/" do not
    // turn into "//v1" or "/v1//v1".
    let endpoint = endpoint.trim_end_matches('/');
    let api_base = if endpoint.ends_with("/v1") {
        endpoint.to_string()
    } else {
        format!("{}/v1", endpoint)
    };

    let config = OpenAIConfig::new()
        .with_api_base(api_base)
        .with_api_key("lmstudio-key");

    EmbeddingsClient::with_config(model_name.to_string(), config)
}
477
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use crate::config::{
        AgentProfile, AudioConfig, DatabaseConfig, LoggingConfig, ModelConfig, PluginConfig,
        UiConfig,
    };
    use std::collections::HashMap;
    use tempfile::{tempdir, TempDir};

    /// Builds a mock-provider config whose database file lives in a fresh
    /// temporary directory.
    ///
    /// The `TempDir` guard is returned with the config because dropping it
    /// deletes the directory. Callers must keep it alive (bind it to a named
    /// variable such as `_dir`, not to `_`) for as long as the database path is
    /// used.
    fn create_test_config() -> (AppConfig, TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");

        let config = AppConfig {
            database: DatabaseConfig { path: db_path },
            model: ModelConfig {
                provider: "mock".to_string(),
                model_name: Some("test-model".to_string()),
                embeddings_model: None,
                api_key_source: None,
                temperature: 0.7,
            },
            ui: UiConfig {
                prompt: "> ".to_string(),
                theme: "default".to_string(),
            },
            logging: LoggingConfig {
                level: "info".to_string(),
            },
            audio: AudioConfig::default(),
            mesh: crate::config::MeshConfig::default(),
            plugins: PluginConfig::default(),
            agents: HashMap::new(),
            default_agent: None,
        };

        (config, dir)
    }

    /// Minimal profile with a known prompt and fast reasoning disabled.
    fn create_test_profile() -> AgentProfile {
        AgentProfile {
            prompt: Some("Test system prompt".to_string()),
            style: None,
            temperature: Some(0.8),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 10,
            top_p: 0.95,
            max_context_tokens: Some(4096),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        }
    }

    #[test]
    fn test_builder_with_all_fields() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();
        let provider = Arc::new(MockProvider::default());

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_provider(provider)
            .with_persistence(persistence)
            .with_session_id("test-session")
            .build()
            .unwrap();

        assert_eq!(agent.session_id(), "test-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_builder_with_config() {
        let (config, _dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        assert!(agent.session_id().starts_with("session-"));
    }

    #[test]
    fn test_builder_missing_profile() {
        let (config, _dir) = create_test_config();

        let result = AgentBuilder::new().with_config(config).build();

        assert!(result.is_err());
        assert!(result.err().unwrap().to_string().contains("profile"));
    }

    #[test]
    fn test_builder_missing_provider_and_config() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();

        let result = AgentBuilder::new()
            .with_profile(profile)
            .with_persistence(persistence)
            .build();

        assert!(result.is_err());
        assert!(result
            .err()
            .unwrap()
            .to_string()
            .contains("provider or config"));
    }

    #[test]
    fn test_builder_auto_session_id() {
        let (config, _dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        assert!(!agent.session_id().is_empty());
    }

    #[test]
    fn test_create_agent_from_registry() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (config, _config_dir) = create_test_config();
        let profile = create_test_profile();

        let mut agents = HashMap::new();
        agents.insert("test-agent".to_string(), profile.clone());

        let registry = AgentRegistry::new(agents, persistence.clone());
        registry.set_active("test-agent").unwrap();

        let agent =
            create_agent_from_registry(&registry, &config, Some("custom-session".to_string()))
                .unwrap();

        assert_eq!(agent.session_id(), "custom-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_create_agent_from_registry_no_active() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (config, _config_dir) = create_test_config();
        let registry = AgentRegistry::new(HashMap::new(), persistence);

        let result = create_agent_from_registry(&registry, &config, None);

        assert!(result.is_err());
        let err_msg = result.err().unwrap().to_string();
        assert!(err_msg.contains("No active") || err_msg.contains("active agent"));
    }
}