// spec_ai_core/agent/builder.rs

1//! Agent Builder
2//!
3//! Provides a fluent API for constructing agent instances.
4
5use crate::agent::core::AgentCore;
6use crate::agent::factory::{create_provider, resolve_api_key};
7use crate::agent::model::{ModelProvider, ProviderKind};
8#[cfg(feature = "openai")]
9use crate::agent::providers::openai::OpenAIProvider;
10#[cfg(feature = "lmstudio")]
11use crate::agent::providers::LMStudioProvider;
12#[cfg(feature = "mlx")]
13use crate::agent::providers::MLXProvider;
14use crate::config::{AgentProfile, AgentRegistry, AppConfig, ModelConfig};
15use crate::embeddings::EmbeddingsClient;
16use crate::persistence::Persistence;
17use crate::policy::PolicyEngine;
18use crate::tools::ToolRegistry;
19use anyhow::{anyhow, Context, Result};
20#[cfg(any(feature = "mlx", feature = "lmstudio"))]
21use async_openai::config::OpenAIConfig;
22use std::sync::Arc;
23use tracing::{info, warn};
24
25/// Builder for constructing AgentCore instances
26pub struct AgentBuilder {
27    profile: Option<AgentProfile>,
28    provider: Option<Arc<dyn ModelProvider>>,
29    embeddings_client: Option<EmbeddingsClient>,
30    persistence: Option<Persistence>,
31    session_id: Option<String>,
32    config: Option<AppConfig>,
33    tool_registry: Option<Arc<ToolRegistry>>,
34    policy_engine: Option<Arc<PolicyEngine>>,
35    agent_name: Option<String>,
36}
37
38impl AgentBuilder {
39    /// Create a new agent builder
40    pub fn new() -> Self {
41        Self {
42            profile: None,
43            provider: None,
44            embeddings_client: None,
45            persistence: None,
46            session_id: None,
47            config: None,
48            tool_registry: None,
49            policy_engine: None,
50            agent_name: None,
51        }
52    }
53
54    /// Create an agent from the registry with the active profile
55    /// This is a convenience method for CLI use
56    pub fn new_with_registry(
57        registry: &AgentRegistry,
58        config: &AppConfig,
59        session_id: Option<String>,
60    ) -> Result<AgentCore> {
61        create_agent_from_registry(registry, config, session_id)
62    }
63
64    /// Set the agent profile
65    pub fn with_profile(mut self, profile: AgentProfile) -> Self {
66        self.profile = Some(profile);
67        self
68    }
69
70    /// Set the model provider
71    pub fn with_provider(mut self, provider: Arc<dyn ModelProvider>) -> Self {
72        self.provider = Some(provider);
73        self
74    }
75
76    /// Set a custom embeddings client
77    pub fn with_embeddings_client(mut self, embeddings_client: EmbeddingsClient) -> Self {
78        self.embeddings_client = Some(embeddings_client);
79        self
80    }
81
82    /// Set the persistence layer
83    pub fn with_persistence(mut self, persistence: Persistence) -> Self {
84        self.persistence = Some(persistence);
85        self
86    }
87
88    /// Set the session ID
89    pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
90        self.session_id = Some(session_id.into());
91        self
92    }
93
94    /// Set the application configuration (used to derive defaults)
95    pub fn with_config(mut self, config: AppConfig) -> Self {
96        self.config = Some(config);
97        self
98    }
99
100    /// Set the tool registry
101    pub fn with_tool_registry(mut self, tool_registry: Arc<ToolRegistry>) -> Self {
102        self.tool_registry = Some(tool_registry);
103        self
104    }
105
106    /// Set the policy engine
107    pub fn with_policy_engine(mut self, policy_engine: Arc<PolicyEngine>) -> Self {
108        self.policy_engine = Some(policy_engine);
109        self
110    }
111
112    /// Set the logical agent name (used for telemetry/logging)
113    pub fn with_agent_name(mut self, agent_name: impl Into<String>) -> Self {
114        self.agent_name = Some(agent_name.into());
115        self
116    }
117
118    /// Build the agent, validating all required fields
119    pub fn build(self) -> Result<AgentCore> {
120        // Get profile (required)
121        let profile = self
122            .profile
123            .ok_or_else(|| anyhow!("Agent profile is required"))?;
124
125        // Get or create persistence (needed for tool registry)
126        let persistence = if let Some(persistence) = self.persistence {
127            persistence
128        } else if let Some(ref config) = self.config {
129            Persistence::new(&config.database.path).context("Failed to create persistence layer")?
130        } else {
131            return Err(anyhow!(
132                "Either persistence or config must be provided to build agent"
133            ));
134        };
135
136        // Get or create embeddings client
137        let embeddings_client = if let Some(client) = self.embeddings_client {
138            Some(client)
139        } else if let Some(ref config) = self.config {
140            create_embeddings_client_from_config(config)?
141        } else {
142            None
143        };
144
145        // Get or create tool registry (defaults to built-in tools)
146        // Create this before the provider so OpenAI can be configured with tools
147        let tool_registry = self.tool_registry.unwrap_or_else(|| {
148            let persistence_arc = Arc::new(persistence.clone());
149            let registry =
150                ToolRegistry::with_builtin_tools(Some(persistence_arc), embeddings_client.clone());
151            info!(
152                "Created tool registry with {} builtin tools",
153                registry.len()
154            );
155            for tool_name in registry.list() {
156                tracing::debug!("  - Registered tool: {}", tool_name);
157            }
158            Arc::new(registry)
159        });
160
161        // Get or create provider with tools configured (for OpenAI-compatible providers)
162        let provider = if let Some(provider) = self.provider {
163            provider
164        } else if let Some(ref config) = self.config {
165            let mut base_provider =
166                create_provider(&config.model).context("Failed to create provider from config")?;
167
168            // Configure OpenAI provider with tools for native function calling
169            #[cfg(feature = "openai")]
170            {
171                if base_provider.kind() == ProviderKind::OpenAI {
172                    let tools = tool_registry.to_openai_tools();
173                    if !tools.is_empty() {
174                        info!(
175                            "Configuring OpenAI provider with {} tools for native function calling",
176                            tools.len()
177                        );
178
179                        // Recreate OpenAI provider with tools
180                        let api_key = if let Some(source) = &config.model.api_key_source {
181                            resolve_api_key(source)?
182                        } else {
183                            // Default to OPENAI_API_KEY environment variable
184                            std::env::var("OPENAI_API_KEY")
185                                .context("OPENAI_API_KEY environment variable not set")?
186                        };
187
188                        let mut openai_provider = OpenAIProvider::with_api_key(api_key);
189
190                        // Set model if specified in config
191                        if let Some(model_name) = &config.model.model_name {
192                            openai_provider = openai_provider.with_model(model_name.clone());
193                        }
194
195                        // Configure with tools and cast to trait object
196                        base_provider = Arc::new(openai_provider.with_tools(tools));
197                    }
198                }
199            }
200
201            // Configure MLX provider with tools for native function calling (OpenAI-compatible API)
202            #[cfg(feature = "mlx")]
203            {
204                if base_provider.kind() == ProviderKind::MLX {
205                    let tools = tool_registry.to_openai_tools();
206                    if !tools.is_empty() {
207                        info!(
208                            "Configuring MLX provider with {} tools for native function calling",
209                            tools.len()
210                        );
211
212                        // MLX requires a model name; mirror create_provider's behavior
213                        let model_name = config
214                            .model
215                            .model_name
216                            .as_ref()
217                            .ok_or_else(|| {
218                                anyhow!("MLX provider requires a model_name to be specified")
219                            })?
220                            .clone();
221
222                        let mlx_provider = if let Ok(endpoint) = std::env::var("MLX_ENDPOINT") {
223                            MLXProvider::with_endpoint(endpoint, model_name)
224                        } else {
225                            MLXProvider::new(model_name)
226                        };
227
228                        base_provider = Arc::new(mlx_provider.with_tools(tools));
229                    }
230                }
231            }
232
233            #[cfg(feature = "lmstudio")]
234            {
235                if base_provider.kind() == ProviderKind::LMStudio {
236                    let tools = tool_registry.to_openai_tools();
237                    if !tools.is_empty() {
238                        info!(
239                            "Configuring LM Studio provider with {} tools for native function calling",
240                            tools.len()
241                        );
242
243                        let model_name = config
244                            .model
245                            .model_name
246                            .as_ref()
247                            .ok_or_else(|| {
248                                anyhow!("LM Studio provider requires a model_name to be specified")
249                            })?
250                            .clone();
251
252                        let lmstudio_provider =
253                            if let Ok(endpoint) = std::env::var("LMSTUDIO_ENDPOINT") {
254                                LMStudioProvider::with_endpoint(endpoint, model_name)
255                            } else {
256                                LMStudioProvider::new(model_name)
257                            };
258
259                        base_provider = Arc::new(lmstudio_provider.with_tools(tools));
260                    }
261                }
262            }
263
264            base_provider
265        } else {
266            return Err(anyhow!(
267                "Either provider or config must be provided to build agent"
268            ));
269        };
270
271        // Get or generate session ID
272        let session_id = self
273            .session_id
274            .unwrap_or_else(|| format!("session-{}", chrono::Utc::now().timestamp_millis()));
275
276        // Get or create policy engine (defaults to empty policy engine, or load from persistence)
277        let policy_engine = if let Some(engine) = self.policy_engine {
278            engine
279        } else {
280            // Try to load from persistence, or create empty engine with default allow rule
281            let mut engine = PolicyEngine::load_from_persistence(&persistence)
282                .unwrap_or_else(|_| PolicyEngine::new());
283
284            // If the policy engine has no rules at all, add a default allow-all for tools
285            if engine.rule_count() == 0 {
286                tracing::debug!(
287                    "Empty policy engine detected, adding default allow-all rule for tools"
288                );
289                engine.add_rule(crate::policy::PolicyRule {
290                    agent: "*".to_string(),
291                    action: "tool_call".to_string(),
292                    resource: "*".to_string(),
293                    effect: crate::policy::PolicyEffect::Allow,
294                });
295            }
296
297            Arc::new(engine)
298        };
299
300        let fast_provider = if profile.fast_reasoning {
301            match (&profile.fast_model_provider, &profile.fast_model_name) {
302                (Some(provider_name), Some(model_name)) => {
303                    let fast_config = ModelConfig {
304                        provider: provider_name.clone(),
305                        model_name: Some(model_name.clone()),
306                        embeddings_model: None,
307                        api_key_source: None,
308                        temperature: profile.fast_model_temperature,
309                    };
310                    match create_provider(&fast_config) {
311                        Ok(provider) => Some(provider),
312                        Err(err) => {
313                            warn!(
314                                "Failed to create fast provider {}:{} - {}",
315                                provider_name, model_name, err
316                            );
317                            None
318                        }
319                    }
320                }
321                _ => None,
322            }
323        } else {
324            None
325        };
326
327        let mut agent = AgentCore::new(
328            profile,
329            provider,
330            embeddings_client,
331            persistence,
332            session_id,
333            self.agent_name,
334            tool_registry,
335            policy_engine,
336        );
337
338        if let Some(fast_provider) = fast_provider {
339            agent = agent.with_fast_provider(fast_provider);
340        }
341
342        Ok(agent)
343    }
344}
345
346impl Default for AgentBuilder {
347    fn default() -> Self {
348        Self::new()
349    }
350}
351
352/// Create an agent from the active profile in the registry
353pub fn create_agent_from_registry(
354    registry: &AgentRegistry,
355    config: &AppConfig,
356    session_id: Option<String>,
357) -> Result<AgentCore> {
358    let (agent_name, profile) = registry
359        .active()
360        .context("No active agent profile in registry")?
361        .ok_or_else(|| anyhow!("No active agent set in registry"))?;
362
363    let mut builder = AgentBuilder::new()
364        .with_profile(profile)
365        .with_config(config.clone())
366        .with_persistence(registry.persistence().clone())
367        .with_agent_name(agent_name.clone());
368
369    if let Some(sid) = session_id {
370        builder = builder.with_session_id(sid);
371    }
372
373    builder.build()
374}
375
376fn create_embeddings_client_from_config(config: &AppConfig) -> Result<Option<EmbeddingsClient>> {
377    let model = &config.model;
378    let Some(model_name) = &model.embeddings_model else {
379        return Ok(None);
380    };
381
382    #[cfg(feature = "mlx")]
383    {
384        if ProviderKind::from_str(&model.provider) == Some(ProviderKind::MLX) {
385            return Ok(Some(build_mlx_embeddings_client(model_name)));
386        }
387    }
388
389    #[cfg(feature = "lmstudio")]
390    {
391        if ProviderKind::from_str(&model.provider) == Some(ProviderKind::LMStudio) {
392            return Ok(Some(build_lmstudio_embeddings_client(model_name)));
393        }
394    }
395
396    let client = if let Some(source) = &model.api_key_source {
397        let api_key = resolve_api_key(source)?;
398        EmbeddingsClient::with_api_key(model_name.clone(), api_key)
399    } else {
400        EmbeddingsClient::new(model_name.clone())
401    };
402
403    Ok(Some(client))
404}
405
/// Turn a server endpoint into an OpenAI-compatible API base ending in `/v1`.
///
/// Trailing slashes are stripped first, so `http://host:1234/` and
/// `http://host:1234/v1/` both become `http://host:1234/v1` instead of
/// `...//v1` or `.../v1//v1`.
#[cfg(any(feature = "mlx", feature = "lmstudio"))]
fn openai_compatible_api_base(endpoint: &str) -> String {
    let trimmed = endpoint.trim_end_matches('/');
    if trimmed.ends_with("/v1") {
        trimmed.to_string()
    } else {
        format!("{}/v1", trimmed)
    }
}

/// Build an embeddings client for a local OpenAI-compatible server.
///
/// The endpoint is read from the `endpoint_env_var` environment variable, or is
/// `default_endpoint` when that variable is unset. `placeholder_api_key` is passed
/// as the API key.
#[cfg(any(feature = "mlx", feature = "lmstudio"))]
fn local_embeddings_client(
    model_name: &str,
    endpoint_env_var: &str,
    default_endpoint: &str,
    placeholder_api_key: &'static str,
) -> EmbeddingsClient {
    let endpoint =
        std::env::var(endpoint_env_var).unwrap_or_else(|_| default_endpoint.to_string());

    let config = OpenAIConfig::new()
        .with_api_base(openai_compatible_api_base(&endpoint))
        .with_api_key(placeholder_api_key);

    EmbeddingsClient::with_config(model_name.to_string(), config)
}

/// Embeddings client for an MLX server (`MLX_ENDPOINT`, default `http://localhost:10240`).
#[cfg(feature = "mlx")]
fn build_mlx_embeddings_client(model_name: &str) -> EmbeddingsClient {
    local_embeddings_client(model_name, "MLX_ENDPOINT", "http://localhost:10240", "mlx-key")
}

/// Embeddings client for an LM Studio server (`LMSTUDIO_ENDPOINT`, default `http://localhost:1234`).
#[cfg(feature = "lmstudio")]
fn build_lmstudio_embeddings_client(model_name: &str) -> EmbeddingsClient {
    local_embeddings_client(
        model_name,
        "LMSTUDIO_ENDPOINT",
        "http://localhost:1234",
        "lmstudio-key",
    )
}
439
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::providers::MockProvider;
    use crate::config::{
        AgentProfile, AudioConfig, DatabaseConfig, LoggingConfig, ModelConfig, UiConfig,
    };
    use std::collections::HashMap;
    use tempfile::{tempdir, TempDir};

    /// Build a test `AppConfig` whose database lives in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the config. Dropping the guard deletes
    /// the directory, so callers must keep it alive for as long as the database path is
    /// used: bind it to a named variable such as `_dir`, not to `_`.
    fn create_test_config() -> (AppConfig, TempDir) {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");

        let config = AppConfig {
            database: DatabaseConfig { path: db_path },
            model: ModelConfig {
                provider: "mock".to_string(),
                model_name: Some("test-model".to_string()),
                embeddings_model: None,
                api_key_source: None,
                temperature: 0.7,
            },
            ui: UiConfig {
                prompt: "> ".to_string(),
                theme: "default".to_string(),
            },
            logging: LoggingConfig {
                level: "info".to_string(),
            },
            audio: AudioConfig::default(),
            mesh: crate::config::MeshConfig::default(),
            agents: HashMap::new(),
            default_agent: None,
        };

        (config, dir)
    }

    fn create_test_profile() -> AgentProfile {
        AgentProfile {
            prompt: Some("Test system prompt".to_string()),
            style: None,
            temperature: Some(0.8),
            model_provider: None,
            model_name: None,
            allowed_tools: None,
            denied_tools: None,
            memory_k: 10,
            top_p: 0.95,
            max_context_tokens: Some(4096),
            enable_graph: false,
            graph_memory: false,
            auto_graph: false,
            graph_steering: false,
            graph_depth: 3,
            graph_weight: 0.5,
            graph_threshold: 0.7,
            fast_reasoning: false,
            fast_model_provider: None,
            fast_model_name: None,
            fast_model_temperature: 0.3,
            fast_model_tasks: vec![],
            escalation_threshold: 0.6,
            show_reasoning: false,
            enable_audio_transcription: false,
            audio_response_mode: "immediate".to_string(),
            audio_scenario: None,
        }
    }

    #[test]
    fn test_builder_with_all_fields() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();
        let provider = Arc::new(MockProvider::default());

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_provider(provider)
            .with_persistence(persistence)
            .with_session_id("test-session")
            .build()
            .unwrap();

        assert_eq!(agent.session_id(), "test-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_builder_with_config() {
        let (config, _dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        // Should auto-generate session ID with timestamp
        assert!(agent.session_id().starts_with("session-"));
    }

    #[test]
    fn test_builder_missing_profile() {
        let (config, _dir) = create_test_config();

        let result = AgentBuilder::new().with_config(config).build();

        assert!(result.is_err());
        assert!(result.err().unwrap().to_string().contains("profile"));
    }

    #[test]
    fn test_builder_missing_provider_and_config() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let profile = create_test_profile();

        let result = AgentBuilder::new()
            .with_profile(profile)
            .with_persistence(persistence)
            .build();

        assert!(result.is_err());
        assert!(result
            .err()
            .unwrap()
            .to_string()
            .contains("provider or config"));
    }

    #[test]
    fn test_builder_auto_session_id() {
        let (config, _dir) = create_test_config();
        let profile = create_test_profile();

        let agent = AgentBuilder::new()
            .with_profile(profile)
            .with_config(config)
            .build()
            .unwrap();

        // The generated ID has the form `session-<unix millis>`.
        let session_id = agent.session_id().to_string();
        let millis = session_id
            .strip_prefix("session-")
            .expect("auto-generated session id should start with `session-`");
        assert!(
            millis.parse::<i64>().is_ok(),
            "expected a numeric timestamp suffix, got {:?}",
            session_id
        );
    }

    #[test]
    fn test_create_agent_from_registry() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (config, _config_dir) = create_test_config();
        let profile = create_test_profile();

        let mut agents = HashMap::new();
        agents.insert("test-agent".to_string(), profile.clone());

        let registry = AgentRegistry::new(agents, persistence.clone());
        registry.set_active("test-agent").unwrap();

        let agent =
            create_agent_from_registry(&registry, &config, Some("custom-session".to_string()))
                .unwrap();

        assert_eq!(agent.session_id(), "custom-session");
        assert_eq!(
            agent.profile().prompt,
            Some("Test system prompt".to_string())
        );
    }

    #[test]
    fn test_create_agent_from_registry_no_active() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test.duckdb");
        let persistence = Persistence::new(&db_path).unwrap();

        let (config, _config_dir) = create_test_config();
        let registry = AgentRegistry::new(HashMap::new(), persistence);

        let result = create_agent_from_registry(&registry, &config, None);

        assert!(result.is_err());
        let err_msg = result.err().unwrap().to_string();
        assert!(err_msg.contains("No active") || err_msg.contains("active agent"));
    }
}
635}