1use crate::agent::core::AgentCore;
6use crate::agent::factory::{create_provider, resolve_api_key};
7use crate::agent::model::{ModelProvider, ProviderKind};
8#[cfg(feature = "openai")]
9use crate::agent::providers::openai::OpenAIProvider;
10#[cfg(feature = "lmstudio")]
11use crate::agent::providers::LMStudioProvider;
12#[cfg(feature = "mlx")]
13use crate::agent::providers::MLXProvider;
14use crate::config::{AgentProfile, AgentRegistry, AppConfig, ModelConfig};
15use crate::embeddings::EmbeddingsClient;
16use crate::persistence::Persistence;
17use crate::policy::PolicyEngine;
18use crate::tools::ToolRegistry;
19use anyhow::{anyhow, Context, Result};
20#[cfg(any(feature = "mlx", feature = "lmstudio"))]
21use async_openai::config::OpenAIConfig;
22use std::sync::Arc;
23use tracing::{info, warn};
24
25pub struct AgentBuilder {
27 profile: Option<AgentProfile>,
28 provider: Option<Arc<dyn ModelProvider>>,
29 embeddings_client: Option<EmbeddingsClient>,
30 persistence: Option<Persistence>,
31 session_id: Option<String>,
32 config: Option<AppConfig>,
33 tool_registry: Option<Arc<ToolRegistry>>,
34 policy_engine: Option<Arc<PolicyEngine>>,
35 agent_name: Option<String>,
36}
37
38impl AgentBuilder {
39 pub fn new() -> Self {
41 Self {
42 profile: None,
43 provider: None,
44 embeddings_client: None,
45 persistence: None,
46 session_id: None,
47 config: None,
48 tool_registry: None,
49 policy_engine: None,
50 agent_name: None,
51 }
52 }
53
54 pub fn new_with_registry(
57 registry: &AgentRegistry,
58 config: &AppConfig,
59 session_id: Option<String>,
60 ) -> Result<AgentCore> {
61 create_agent_from_registry(registry, config, session_id)
62 }
63
64 pub fn with_profile(mut self, profile: AgentProfile) -> Self {
66 self.profile = Some(profile);
67 self
68 }
69
70 pub fn with_provider(mut self, provider: Arc<dyn ModelProvider>) -> Self {
72 self.provider = Some(provider);
73 self
74 }
75
76 pub fn with_embeddings_client(mut self, embeddings_client: EmbeddingsClient) -> Self {
78 self.embeddings_client = Some(embeddings_client);
79 self
80 }
81
82 pub fn with_persistence(mut self, persistence: Persistence) -> Self {
84 self.persistence = Some(persistence);
85 self
86 }
87
88 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
90 self.session_id = Some(session_id.into());
91 self
92 }
93
94 pub fn with_config(mut self, config: AppConfig) -> Self {
96 self.config = Some(config);
97 self
98 }
99
100 pub fn with_tool_registry(mut self, tool_registry: Arc<ToolRegistry>) -> Self {
102 self.tool_registry = Some(tool_registry);
103 self
104 }
105
106 pub fn with_policy_engine(mut self, policy_engine: Arc<PolicyEngine>) -> Self {
108 self.policy_engine = Some(policy_engine);
109 self
110 }
111
112 pub fn with_agent_name(mut self, agent_name: impl Into<String>) -> Self {
114 self.agent_name = Some(agent_name.into());
115 self
116 }
117
118 pub fn build(self) -> Result<AgentCore> {
120 let profile = self
122 .profile
123 .ok_or_else(|| anyhow!("Agent profile is required"))?;
124
125 let persistence = if let Some(persistence) = self.persistence {
127 persistence
128 } else if let Some(ref config) = self.config {
129 Persistence::new(&config.database.path).context("Failed to create persistence layer")?
130 } else {
131 return Err(anyhow!(
132 "Either persistence or config must be provided to build agent"
133 ));
134 };
135
136 let embeddings_client = if let Some(client) = self.embeddings_client {
138 Some(client)
139 } else if let Some(ref config) = self.config {
140 create_embeddings_client_from_config(config)?
141 } else {
142 None
143 };
144
145 let tool_registry = if let Some(registry) = self.tool_registry {
148 registry
149 } else {
150 let persistence_arc = Arc::new(persistence.clone());
151 let mut registry =
152 ToolRegistry::with_builtin_tools(Some(persistence_arc), embeddings_client.clone());
153 info!(
154 "Created tool registry with {} builtin tools",
155 registry.len()
156 );
157 for tool_name in registry.list() {
158 tracing::debug!(" - Registered tool: {}", tool_name);
159 }
160
161 if let Some(ref config) = self.config {
163 if config.plugins.enabled {
164 match registry.load_plugins(
165 &config.plugins.custom_tools_dir,
166 config.plugins.allow_override_builtin,
167 ) {
168 Ok(stats) => {
169 if stats.loaded > 0 {
170 info!(
171 "Loaded {} plugins with {} tools from {}",
172 stats.loaded,
173 stats.tools_loaded,
174 config.plugins.custom_tools_dir.display()
175 );
176 }
177 if stats.failed > 0 {
178 warn!("{} plugins failed to load", stats.failed);
179 }
180 }
181 Err(e) => {
182 if config.plugins.continue_on_error {
183 warn!("Plugin loading failed (continuing): {}", e);
184 } else {
185 return Err(anyhow::anyhow!("Plugin loading failed: {}", e));
187 }
188 }
189 }
190 }
191 }
192
193 Arc::new(registry)
194 };
195
196 let provider = if let Some(provider) = self.provider {
198 provider
199 } else if let Some(ref config) = self.config {
200 let mut base_provider =
201 create_provider(&config.model).context("Failed to create provider from config")?;
202
203 #[cfg(feature = "openai")]
205 {
206 if base_provider.kind() == ProviderKind::OpenAI {
207 let tools = tool_registry.to_openai_tools();
208 if !tools.is_empty() {
209 info!(
210 "Configuring OpenAI provider with {} tools for native function calling",
211 tools.len()
212 );
213
214 let api_key = if let Some(source) = &config.model.api_key_source {
216 resolve_api_key(source)?
217 } else {
218 std::env::var("OPENAI_API_KEY")
220 .context("OPENAI_API_KEY environment variable not set")?
221 };
222
223 let mut openai_provider = OpenAIProvider::with_api_key(api_key);
224
225 if let Some(model_name) = &config.model.model_name {
227 openai_provider = openai_provider.with_model(model_name.clone());
228 }
229
230 base_provider = Arc::new(openai_provider.with_tools(tools));
232 }
233 }
234 }
235
236 #[cfg(feature = "mlx")]
238 {
239 if base_provider.kind() == ProviderKind::MLX {
240 let tools = tool_registry.to_openai_tools();
241 if !tools.is_empty() {
242 info!(
243 "Configuring MLX provider with {} tools for native function calling",
244 tools.len()
245 );
246
247 let model_name = config
249 .model
250 .model_name
251 .as_ref()
252 .ok_or_else(|| {
253 anyhow!("MLX provider requires a model_name to be specified")
254 })?
255 .clone();
256
257 let mlx_provider = if let Ok(endpoint) = std::env::var("MLX_ENDPOINT") {
258 MLXProvider::with_endpoint(endpoint, model_name)
259 } else {
260 MLXProvider::new(model_name)
261 };
262
263 base_provider = Arc::new(mlx_provider.with_tools(tools));
264 }
265 }
266 }
267
268 #[cfg(feature = "lmstudio")]
269 {
270 if base_provider.kind() == ProviderKind::LMStudio {
271 let tools = tool_registry.to_openai_tools();
272 if !tools.is_empty() {
273 info!(
274 "Configuring LM Studio provider with {} tools for native function calling",
275 tools.len()
276 );
277
278 let model_name = config
279 .model
280 .model_name
281 .as_ref()
282 .ok_or_else(|| {
283 anyhow!("LM Studio provider requires a model_name to be specified")
284 })?
285 .clone();
286
287 let lmstudio_provider =
288 if let Ok(endpoint) = std::env::var("LMSTUDIO_ENDPOINT") {
289 LMStudioProvider::with_endpoint(endpoint, model_name)
290 } else {
291 LMStudioProvider::new(model_name)
292 };
293
294 base_provider = Arc::new(lmstudio_provider.with_tools(tools));
295 }
296 }
297 }
298
299 base_provider
300 } else {
301 return Err(anyhow!(
302 "Either provider or config must be provided to build agent"
303 ));
304 };
305
306 let session_id = self
308 .session_id
309 .unwrap_or_else(|| format!("session-{}", chrono::Utc::now().timestamp_millis()));
310
311 let policy_engine = if let Some(engine) = self.policy_engine {
313 engine
314 } else {
315 let mut engine = PolicyEngine::load_from_persistence(&persistence)
317 .unwrap_or_else(|_| PolicyEngine::new());
318
319 if engine.rule_count() == 0 {
321 tracing::debug!(
322 "Empty policy engine detected, adding default allow-all rule for tools"
323 );
324 engine.add_rule(crate::policy::PolicyRule {
325 agent: "*".to_string(),
326 action: "tool_call".to_string(),
327 resource: "*".to_string(),
328 effect: crate::policy::PolicyEffect::Allow,
329 });
330 }
331
332 Arc::new(engine)
333 };
334
335 let fast_provider = if profile.fast_reasoning {
336 match (&profile.fast_model_provider, &profile.fast_model_name) {
337 (Some(provider_name), Some(model_name)) => {
338 let fast_config = ModelConfig {
339 provider: provider_name.clone(),
340 model_name: Some(model_name.clone()),
341 embeddings_model: None,
342 api_key_source: None,
343 temperature: profile.fast_model_temperature,
344 };
345 match create_provider(&fast_config) {
346 Ok(provider) => Some(provider),
347 Err(err) => {
348 warn!(
349 "Failed to create fast provider {}:{} - {}",
350 provider_name, model_name, err
351 );
352 None
353 }
354 }
355 }
356 _ => None,
357 }
358 } else {
359 None
360 };
361
362 let mut agent = AgentCore::new(
363 profile,
364 provider,
365 embeddings_client,
366 persistence,
367 session_id,
368 self.agent_name,
369 tool_registry,
370 policy_engine,
371 );
372
373 if let Some(fast_provider) = fast_provider {
374 agent = agent.with_fast_provider(fast_provider);
375 }
376
377 Ok(agent)
378 }
379}
380
381impl Default for AgentBuilder {
382 fn default() -> Self {
383 Self::new()
384 }
385}
386
387pub fn create_agent_from_registry(
389 registry: &AgentRegistry,
390 config: &AppConfig,
391 session_id: Option<String>,
392) -> Result<AgentCore> {
393 let (agent_name, profile) = registry
394 .active()
395 .context("No active agent profile in registry")?
396 .ok_or_else(|| anyhow!("No active agent set in registry"))?;
397
398 let mut builder = AgentBuilder::new()
399 .with_profile(profile)
400 .with_config(config.clone())
401 .with_persistence(registry.persistence().clone())
402 .with_agent_name(agent_name.clone());
403
404 if let Some(sid) = session_id {
405 builder = builder.with_session_id(sid);
406 }
407
408 builder.build()
409}
410
411fn create_embeddings_client_from_config(config: &AppConfig) -> Result<Option<EmbeddingsClient>> {
412 let model = &config.model;
413 let Some(model_name) = &model.embeddings_model else {
414 return Ok(None);
415 };
416
417 #[cfg(feature = "mlx")]
418 {
419 if ProviderKind::from_str(&model.provider) == Some(ProviderKind::MLX) {
420 return Ok(Some(build_mlx_embeddings_client(model_name)));
421 }
422 }
423
424 #[cfg(feature = "lmstudio")]
425 {
426 if ProviderKind::from_str(&model.provider) == Some(ProviderKind::LMStudio) {
427 return Ok(Some(build_lmstudio_embeddings_client(model_name)));
428 }
429 }
430
431 let client = if let Some(source) = &model.api_key_source {
432 let api_key = resolve_api_key(source)?;
433 EmbeddingsClient::with_api_key(model_name.clone(), api_key)
434 } else {
435 EmbeddingsClient::new(model_name.clone())
436 };
437
438 Ok(Some(client))
439}
440
/// Normalizes a local server endpoint into an OpenAI-compatible API base URL.
///
/// Surrounding whitespace and trailing slashes are stripped first, and `/v1` is appended
/// unless the endpoint already ends with it. As a result, `http://host:1234`,
/// `http://host:1234/` and `http://host:1234/v1/` all yield `http://host:1234/v1` instead
/// of a doubled-slash URL.
// Compiled unconditionally so it stays unit-testable; only the `mlx`/`lmstudio` builders
// below call it.
#[cfg_attr(not(any(feature = "mlx", feature = "lmstudio")), allow(dead_code))]
fn openai_compatible_api_base(endpoint: &str) -> String {
    let base = endpoint.trim().trim_end_matches('/');
    if base.ends_with("/v1") {
        base.to_string()
    } else {
        format!("{}/v1", base)
    }
}

/// Builds an embeddings client that talks to a local MLX server.
///
/// The server URL comes from `MLX_ENDPOINT`, defaulting to `http://localhost:10240`.
#[cfg(feature = "mlx")]
fn build_mlx_embeddings_client(model_name: &str) -> EmbeddingsClient {
    let endpoint =
        std::env::var("MLX_ENDPOINT").unwrap_or_else(|_| "http://localhost:10240".to_string());

    let config = OpenAIConfig::new()
        .with_api_base(openai_compatible_api_base(&endpoint))
        .with_api_key("mlx-key");

    EmbeddingsClient::with_config(model_name.to_string(), config)
}

/// Builds an embeddings client that talks to a local LM Studio server.
///
/// The server URL comes from `LMSTUDIO_ENDPOINT`, defaulting to `http://localhost:1234`.
#[cfg(feature = "lmstudio")]
fn build_lmstudio_embeddings_client(model_name: &str) -> EmbeddingsClient {
    let endpoint =
        std::env::var("LMSTUDIO_ENDPOINT").unwrap_or_else(|_| "http://localhost:1234".to_string());

    let config = OpenAIConfig::new()
        .with_api_base(openai_compatible_api_base(&endpoint))
        .with_api_key("lmstudio-key");

    EmbeddingsClient::with_config(model_name.to_string(), config)
}
474
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use crate::config::{
        AgentProfile, AudioConfig, DatabaseConfig, LoggingConfig, ModelConfig, PluginConfig,
        UiConfig,
    };
    use std::collections::HashMap;
    use tempfile::{tempdir, TempDir};

    /// Builds a test config whose database lives in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the config because dropping it deletes
    /// the directory. Callers must keep it alive while the config's database path is in use.
    fn create_test_config() -> (TempDir, AppConfig) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");

        let config = AppConfig {
            database: DatabaseConfig { path: db_path },
            model: ModelConfig {
                provider: "mock".to_string(),
                model_name: Some("test-model".to_string()),
                embeddings_model: None,
                api_key_source: None,
                temperature: 0.7,
            },
            ui: UiConfig {
                prompt: "> ".to_string(),
                theme: "default".to_string(),
            },
            logging: LoggingConfig {
                level: "info".to_string(),
            },
            audio: AudioConfig::default(),
            mesh: crate::config::MeshConfig::default(),
            plugins: PluginConfig::default(),
            agents: HashMap::new(),
            default_agent: None,
        };

        (dir, config)
    }

    fn create_test_profile() -> AgentProfile {
        AgentProfile {
            prompt: Some("Test system prompt".to_string()),
            style: None,
            temperature: Some(0.8),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 10,
            top_p: 0.95,
            max_context_tokens: Some(4096),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        }
    }

    #[test]
    fn test_builder_with_all_fields() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();
        let provider = Arc::new(MockProvider::default());

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_provider(provider)
            .with_persistence(persistence)
            .with_session_id("test-session")
            .build()
            .unwrap();

        assert_eq!(agent.session_id(), "test-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_builder_with_config() {
        let (_config_dir, config) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        assert!(agent.session_id().starts_with("session-"));
    }

    #[test]
    fn test_builder_missing_profile() {
        let (_config_dir, config) = create_test_config();

        let result = AgentBuilder::new().with_config(config).build();

        assert!(result.is_err());
        assert!(result.err().unwrap().to_string().contains("profile"));
    }

    #[test]
    fn test_builder_missing_provider_and_config() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();

        let result = AgentBuilder::new()
            .with_profile(profile)
            .with_persistence(persistence)
            .build();

        assert!(result.is_err());
        assert!(result
            .err()
            .unwrap()
            .to_string()
            .contains("provider or config"));
    }

    #[test]
    fn test_builder_auto_session_id() {
        let (_config_dir, config) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        assert!(!agent.session_id().is_empty());
    }

    #[test]
    fn test_create_agent_from_registry() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (_config_dir, config) = create_test_config();

        let mut agents = HashMap::new();
        agents.insert("test-agent".to_string(), create_test_profile());

        let registry = AgentRegistry::new(agents, persistence.clone());
        registry.set_active("test-agent").unwrap();

        let agent =
            create_agent_from_registry(&registry, &config, Some("custom-session".to_string()))
                .unwrap();

        assert_eq!(agent.session_id(), "custom-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_create_agent_from_registry_no_active() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (_config_dir, config) = create_test_config();
        let registry = AgentRegistry::new(HashMap::new(), persistence);

        let result = create_agent_from_registry(&registry, &config, None);

        assert!(result.is_err());
        let err_msg = result.err().unwrap().to_string();
        assert!(err_msg.contains("No active") || err_msg.contains("active agent"));
    }
}