1use crate::agent::core::AgentCore;
6use crate::agent::factory::{create_provider, resolve_api_key};
7use crate::agent::model::{ModelProvider, ProviderKind};
8#[cfg(feature = "openai")]
9use crate::agent::providers::openai::OpenAIProvider;
10#[cfg(feature = "lmstudio")]
11use crate::agent::providers::LMStudioProvider;
12#[cfg(feature = "mlx")]
13use crate::agent::providers::MLXProvider;
14use crate::config::{AgentProfile, AgentRegistry, AppConfig, ModelConfig};
15use crate::embeddings::EmbeddingsClient;
16use crate::persistence::Persistence;
17use crate::policy::PolicyEngine;
18use crate::tools::ToolRegistry;
19use anyhow::{anyhow, Context, Result};
20#[cfg(any(feature = "mlx", feature = "lmstudio"))]
21use async_openai::config::OpenAIConfig;
22use std::sync::Arc;
23use tracing::{info, warn};
24
25pub struct AgentBuilder {
27 profile: Option<AgentProfile>,
28 provider: Option<Arc<dyn ModelProvider>>,
29 embeddings_client: Option<EmbeddingsClient>,
30 persistence: Option<Persistence>,
31 session_id: Option<String>,
32 config: Option<AppConfig>,
33 tool_registry: Option<Arc<ToolRegistry>>,
34 policy_engine: Option<Arc<PolicyEngine>>,
35 agent_name: Option<String>,
36 speak_responses: bool,
37}
38
39impl AgentBuilder {
40 pub fn new() -> Self {
42 Self {
43 profile: None,
44 provider: None,
45 embeddings_client: None,
46 persistence: None,
47 session_id: None,
48 config: None,
49 tool_registry: None,
50 policy_engine: None,
51 agent_name: None,
52 speak_responses: false,
53 }
54 }
55
56 pub fn new_with_registry(
59 registry: &AgentRegistry,
60 config: &AppConfig,
61 session_id: Option<String>,
62 ) -> Result<AgentCore> {
63 create_agent_from_registry(registry, config, session_id)
64 }
65
66 pub fn with_profile(mut self, profile: AgentProfile) -> Self {
68 self.profile = Some(profile);
69 self
70 }
71
72 pub fn with_provider(mut self, provider: Arc<dyn ModelProvider>) -> Self {
74 self.provider = Some(provider);
75 self
76 }
77
78 pub fn with_embeddings_client(mut self, embeddings_client: EmbeddingsClient) -> Self {
80 self.embeddings_client = Some(embeddings_client);
81 self
82 }
83
84 pub fn with_persistence(mut self, persistence: Persistence) -> Self {
86 self.persistence = Some(persistence);
87 self
88 }
89
90 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
92 self.session_id = Some(session_id.into());
93 self
94 }
95
96 pub fn with_config(mut self, config: AppConfig) -> Self {
98 self.config = Some(config);
99 self
100 }
101
102 pub fn with_tool_registry(mut self, tool_registry: Arc<ToolRegistry>) -> Self {
104 self.tool_registry = Some(tool_registry);
105 self
106 }
107
108 pub fn with_policy_engine(mut self, policy_engine: Arc<PolicyEngine>) -> Self {
110 self.policy_engine = Some(policy_engine);
111 self
112 }
113
114 pub fn with_agent_name(mut self, agent_name: impl Into<String>) -> Self {
116 self.agent_name = Some(agent_name.into());
117 self
118 }
119
120 pub fn with_speak_responses(mut self, enabled: bool) -> Self {
122 self.speak_responses = enabled;
123 self
124 }
125
126 pub fn build(self) -> Result<AgentCore> {
128 let profile = self
130 .profile
131 .as_ref()
132 .cloned()
133 .ok_or_else(|| anyhow!("Agent profile is required"))?;
134
135 let session_id = self
137 .session_id
138 .clone()
139 .unwrap_or_else(|| format!("session-{}", chrono::Utc::now().timestamp_millis()));
140 let speak_preference = self.resolve_speech_preference();
141 let agent_name = self.agent_name.clone();
142
143 let persistence = if let Some(persistence) = self.persistence {
145 persistence
146 } else if let Some(ref config) = self.config {
147 Persistence::new(&config.database.path).context("Failed to create persistence layer")?
148 } else {
149 return Err(anyhow!(
150 "Either persistence or config must be provided to build agent"
151 ));
152 };
153
154 let embeddings_client = if let Some(client) = self.embeddings_client {
156 Some(client)
157 } else if let Some(ref config) = self.config {
158 create_embeddings_client_from_config(config)?
159 } else {
160 None
161 };
162
163 let tool_registry = if let Some(registry) = self.tool_registry {
166 registry
167 } else {
168 let persistence_arc = Arc::new(persistence.clone());
169 let mut registry =
170 ToolRegistry::with_builtin_tools(Some(persistence_arc), embeddings_client.clone());
171 info!(
172 "Created tool registry with {} builtin tools",
173 registry.len()
174 );
175 for tool_name in registry.list() {
176 tracing::debug!(" - Registered tool: {}", tool_name);
177 }
178
179 if let Some(ref config) = self.config {
181 if config.plugins.enabled {
182 match registry.load_plugins(
183 &config.plugins.custom_tools_dir,
184 config.plugins.allow_override_builtin,
185 ) {
186 Ok(stats) => {
187 if stats.loaded > 0 {
188 info!(
189 "Loaded {} plugins with {} tools from {}",
190 stats.loaded,
191 stats.tools_loaded,
192 config.plugins.custom_tools_dir.display()
193 );
194 }
195 if stats.failed > 0 {
196 warn!("{} plugins failed to load", stats.failed);
197 }
198 }
199 Err(e) => {
200 if config.plugins.continue_on_error {
201 warn!("Plugin loading failed (continuing): {}", e);
202 } else {
203 return Err(anyhow::anyhow!("Plugin loading failed: {}", e));
205 }
206 }
207 }
208 }
209 }
210
211 Arc::new(registry)
212 };
213
214 let provider = if let Some(provider) = self.provider {
216 provider
217 } else if let Some(ref config) = self.config {
218 let mut base_provider =
219 create_provider(&config.model).context("Failed to create provider from config")?;
220
221 #[cfg(feature = "openai")]
223 {
224 if base_provider.kind() == ProviderKind::OpenAI {
225 let tools = tool_registry.to_openai_tools();
226 if !tools.is_empty() {
227 info!(
228 "Configuring OpenAI provider with {} tools for native function calling",
229 tools.len()
230 );
231
232 let api_key = if let Some(source) = &config.model.api_key_source {
234 resolve_api_key(source)?
235 } else {
236 std::env::var("OPENAI_API_KEY")
238 .context("OPENAI_API_KEY environment variable not set")?
239 };
240
241 let mut openai_provider = OpenAIProvider::with_api_key(api_key);
242
243 if let Some(model_name) = &config.model.model_name {
245 openai_provider = openai_provider.with_model(model_name.clone());
246 }
247
248 base_provider = Arc::new(openai_provider.with_tools(tools));
250 }
251 }
252 }
253
254 #[cfg(feature = "mlx")]
256 {
257 if base_provider.kind() == ProviderKind::MLX {
258 let tools = tool_registry.to_openai_tools();
259 if !tools.is_empty() {
260 info!(
261 "Configuring MLX provider with {} tools for native function calling",
262 tools.len()
263 );
264
265 let model_name = config
267 .model
268 .model_name
269 .as_ref()
270 .ok_or_else(|| {
271 anyhow!("MLX provider requires a model_name to be specified")
272 })?
273 .clone();
274
275 let mlx_provider = if let Ok(endpoint) = std::env::var("MLX_ENDPOINT") {
276 MLXProvider::with_endpoint(endpoint, model_name)
277 } else {
278 MLXProvider::new(model_name)
279 };
280
281 base_provider = Arc::new(mlx_provider.with_tools(tools));
282 }
283 }
284 }
285
286 #[cfg(feature = "lmstudio")]
287 {
288 if base_provider.kind() == ProviderKind::LMStudio {
289 let tools = tool_registry.to_openai_tools();
290 if !tools.is_empty() {
291 info!(
292 "Configuring LM Studio provider with {} tools for native function calling",
293 tools.len()
294 );
295
296 let model_name = config
297 .model
298 .model_name
299 .as_ref()
300 .ok_or_else(|| {
301 anyhow!("LM Studio provider requires a model_name to be specified")
302 })?
303 .clone();
304
305 let lmstudio_provider =
306 if let Ok(endpoint) = std::env::var("LMSTUDIO_ENDPOINT") {
307 LMStudioProvider::with_endpoint(endpoint, model_name)
308 } else {
309 LMStudioProvider::new(model_name)
310 };
311
312 base_provider = Arc::new(lmstudio_provider.with_tools(tools));
313 }
314 }
315 }
316
317 base_provider
318 } else {
319 return Err(anyhow!(
320 "Either provider or config must be provided to build agent"
321 ));
322 };
323
324 let policy_engine = if let Some(engine) = self.policy_engine {
326 engine
327 } else {
328 let mut engine = PolicyEngine::load_from_persistence(&persistence)
330 .unwrap_or_else(|_| PolicyEngine::new());
331
332 if engine.rule_count() == 0 {
334 tracing::debug!(
335 "Empty policy engine detected, adding default allow-all rule for tools"
336 );
337 engine.add_rule(crate::policy::PolicyRule {
338 agent: "*".to_string(),
339 action: "tool_call".to_string(),
340 resource: "*".to_string(),
341 effect: crate::policy::PolicyEffect::Allow,
342 });
343 }
344
345 Arc::new(engine)
346 };
347
348 let fast_provider = if profile.fast_reasoning {
349 match (&profile.fast_model_provider, &profile.fast_model_name) {
350 (Some(provider_name), Some(model_name)) => {
351 let fast_config = ModelConfig {
352 provider: provider_name.clone(),
353 model_name: Some(model_name.clone()),
354 embeddings_model: None,
355 api_key_source: None,
356 temperature: profile.fast_model_temperature,
357 };
358 match create_provider(&fast_config) {
359 Ok(provider) => Some(provider),
360 Err(err) => {
361 warn!(
362 "Failed to create fast provider {}:{} - {}",
363 provider_name, model_name, err
364 );
365 None
366 }
367 }
368 }
369 _ => None,
370 }
371 } else {
372 None
373 };
374
375 let mut agent = AgentCore::new(
376 profile,
377 provider,
378 embeddings_client,
379 persistence,
380 session_id,
381 agent_name,
382 tool_registry,
383 policy_engine,
384 speak_preference,
385 );
386
387 if let Some(fast_provider) = fast_provider {
388 agent = agent.with_fast_provider(fast_provider);
389 }
390
391 Ok(agent)
392 }
393
394 fn resolve_speech_preference(&self) -> bool {
395 let config_pref = self
396 .config
397 .as_ref()
398 .map(|c| c.audio.speak_responses)
399 .unwrap_or(false);
400 let requested = self.speak_responses || config_pref;
401
402 #[cfg(target_os = "macos")]
403 {
404 requested
405 }
406
407 #[cfg(not(target_os = "macos"))]
408 {
409 let _ = requested; false
411 }
412 }
413}
414
415impl Default for AgentBuilder {
416 fn default() -> Self {
417 Self::new()
418 }
419}
420
421pub fn create_agent_from_registry(
423 registry: &AgentRegistry,
424 config: &AppConfig,
425 session_id: Option<String>,
426) -> Result<AgentCore> {
427 let (agent_name, profile) = registry
428 .active()
429 .context("No active agent profile in registry")?
430 .ok_or_else(|| anyhow!("No active agent set in registry"))?;
431
432 let mut builder = AgentBuilder::new()
433 .with_profile(profile)
434 .with_config(config.clone())
435 .with_persistence(registry.persistence().clone())
436 .with_agent_name(agent_name.clone());
437
438 if let Some(sid) = session_id {
439 builder = builder.with_session_id(sid);
440 }
441
442 builder.build()
443}
444
445fn create_embeddings_client_from_config(config: &AppConfig) -> Result<Option<EmbeddingsClient>> {
446 let model = &config.model;
447 let Some(model_name) = &model.embeddings_model else {
448 return Ok(None);
449 };
450
451 #[cfg(feature = "mlx")]
452 {
453 if ProviderKind::from_str(&model.provider) == Some(ProviderKind::MLX) {
454 return Ok(Some(build_mlx_embeddings_client(model_name)));
455 }
456 }
457
458 #[cfg(feature = "lmstudio")]
459 {
460 if ProviderKind::from_str(&model.provider) == Some(ProviderKind::LMStudio) {
461 return Ok(Some(build_lmstudio_embeddings_client(model_name)));
462 }
463 }
464
465 let client = if let Some(source) = &model.api_key_source {
466 let api_key = resolve_api_key(source)?;
467 EmbeddingsClient::with_api_key(model_name.clone(), api_key)
468 } else {
469 EmbeddingsClient::new(model_name.clone())
470 };
471
472 Ok(Some(client))
473}
474
/// Builds an OpenAI-compatible client config for a locally hosted server.
///
/// The base URL comes from the `env_var` environment variable, falling back to
/// `default_endpoint`. Trailing slashes are stripped and `/v1` is appended unless the
/// URL already ends with it, so `http://host:1234`, `http://host:1234/` and
/// `http://host:1234/v1/` all become `http://host:1234/v1` (previously a trailing
/// slash produced `//v1` or `/v1//v1`).
#[cfg(any(feature = "mlx", feature = "lmstudio"))]
fn local_server_openai_config(env_var: &str, default_endpoint: &str, api_key: &str) -> OpenAIConfig {
    let endpoint = std::env::var(env_var).unwrap_or_else(|_| default_endpoint.to_string());
    let endpoint = endpoint.trim_end_matches('/');
    let api_base = if endpoint.ends_with("/v1") {
        endpoint.to_string()
    } else {
        format!("{}/v1", endpoint)
    };

    OpenAIConfig::new()
        .with_api_base(api_base)
        .with_api_key(api_key)
}

/// Creates an embeddings client for the local MLX server.
///
/// The server URL is taken from `MLX_ENDPOINT` (default `http://localhost:10240`), and
/// the placeholder key `mlx-key` is used as the API key.
#[cfg(feature = "mlx")]
fn build_mlx_embeddings_client(model_name: &str) -> EmbeddingsClient {
    let config = local_server_openai_config("MLX_ENDPOINT", "http://localhost:10240", "mlx-key");
    EmbeddingsClient::with_config(model_name.to_string(), config)
}

/// Creates an embeddings client for the local LM Studio server.
///
/// The server URL is taken from `LMSTUDIO_ENDPOINT` (default `http://localhost:1234`),
/// and the placeholder key `lmstudio-key` is used as the API key.
#[cfg(feature = "lmstudio")]
fn build_lmstudio_embeddings_client(model_name: &str) -> EmbeddingsClient {
    let config =
        local_server_openai_config("LMSTUDIO_ENDPOINT", "http://localhost:1234", "lmstudio-key");
    EmbeddingsClient::with_config(model_name.to_string(), config)
}
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use crate::config::{
        AgentProfile, AudioConfig, DatabaseConfig, LoggingConfig, ModelConfig, PluginConfig,
        SyncConfig, UiConfig,
    };
    use std::collections::HashMap;
    use tempfile::{tempdir, TempDir};

    /// Builds a config whose database path lives in a fresh temporary directory.
    ///
    /// The returned [`TempDir`] owns that directory and must be kept alive (bind it to
    /// a named variable such as `_dir`, not `_`) for as long as the path is used:
    /// dropping it deletes the directory.
    fn create_test_config() -> (AppConfig, TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");

        let config = AppConfig {
            database: DatabaseConfig { path: db_path },
            model: ModelConfig {
                provider: "mock".to_string(),
                model_name: Some("test-model".to_string()),
                embeddings_model: None,
                api_key_source: None,
                temperature: 0.7,
            },
            ui: UiConfig {
                prompt: "> ".to_string(),
                theme: "default".to_string(),
            },
            logging: LoggingConfig {
                level: "info".to_string(),
            },
            audio: AudioConfig::default(),
            mesh: crate::config::MeshConfig::default(),
            plugins: PluginConfig::default(),
            sync: SyncConfig::default(),
            agents: HashMap::new(),
            default_agent: None,
        };

        (config, dir)
    }

    /// Minimal profile with fast reasoning, graph memory and audio all disabled.
    fn create_test_profile() -> AgentProfile {
        AgentProfile {
            prompt: Some("Test system prompt".to_string()),
            style: None,
            temperature: Some(0.8),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 10,
            top_p: 0.95,
            max_context_tokens: Some(4096),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        }
    }

    #[test]
    fn test_builder_with_all_fields() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();
        let provider = Arc::new(MockProvider::default());

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_provider(provider)
            .with_persistence(persistence)
            .with_session_id("test-session")
            .build()
            .unwrap();

        assert_eq!(agent.session_id(), "test-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_builder_with_config() {
        let (config, _dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        assert!(agent.session_id().starts_with("session-"));
    }

    #[test]
    fn test_builder_missing_profile() {
        let (config, _dir) = create_test_config();

        let result = AgentBuilder::new().with_config(config).build();

        assert!(result.is_err());
        assert!(result.err().unwrap().to_string().contains("profile"));
    }

    #[test]
    fn test_builder_missing_provider_and_config() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();

        let result = AgentBuilder::new()
            .with_profile(profile)
            .with_persistence(persistence)
            .build();

        assert!(result.is_err());
        assert!(result
            .err()
            .unwrap()
            .to_string()
            .contains("provider or config"));
    }

    #[test]
    fn test_builder_auto_session_id() {
        let (config, _dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        assert!(!agent.session_id().is_empty());
    }

    #[test]
    fn test_create_agent_from_registry() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (config, _config_dir) = create_test_config();
        let profile = create_test_profile();

        let mut agents = HashMap::new();
        agents.insert("test-agent".to_string(), profile.clone());

        let registry = AgentRegistry::new(agents, persistence.clone());
        registry.set_active("test-agent").unwrap();

        let agent =
            create_agent_from_registry(&registry, &config, Some("custom-session".to_string()))
                .unwrap();

        assert_eq!(agent.session_id(), "custom-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_create_agent_from_registry_no_active() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (config, _config_dir) = create_test_config();
        let registry = AgentRegistry::new(HashMap::new(), persistence);

        let result = create_agent_from_registry(&registry, &config, None);

        assert!(result.is_err());
        let err_msg = result.err().unwrap().to_string();
        assert!(err_msg.contains("No active") || err_msg.contains("active agent"));
    }
}