Skip to main content

systemprompt_models/ai/provider_trait.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use futures::Stream;
4use std::collections::HashMap;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use super::execution_plan::PlanningResult;
9use super::request::{AiMessage, AiRequest};
10use super::response::{AiResponse, SearchGroundedResponse};
11use super::sampling::SamplingParams;
12use super::tools::{CallToolResult, McpTool, ToolCall};
13use crate::execution::context::RequestContext;
14use systemprompt_identifiers::AgentName;
15
16#[derive(Debug)]
17pub struct GenerateResponseParams<'a> {
18    pub messages: Vec<AiMessage>,
19    pub execution_summary: &'a str,
20    pub context: &'a RequestContext,
21    pub provider: Option<&'a str>,
22    pub model: Option<&'a str>,
23    pub max_output_tokens: Option<u32>,
24}
25
26#[derive(Debug)]
27pub struct GoogleSearchParams<'a> {
28    pub messages: Vec<AiMessage>,
29    pub sampling: Option<SamplingParams>,
30    pub max_output_tokens: u32,
31    pub model: Option<&'a str>,
32    pub urls: Option<Vec<String>>,
33    pub response_schema: Option<serde_json::Value>,
34}
35
36#[async_trait]
37pub trait AiProvider: Send + Sync {
38    fn default_provider(&self) -> &str;
39
40    fn default_model(&self) -> &str;
41
42    fn default_max_output_tokens(&self) -> u32;
43
44    async fn generate(&self, request: &AiRequest) -> Result<AiResponse>;
45
46    async fn generate_stream(
47        &self,
48        request: &AiRequest,
49    ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>>;
50
51    async fn generate_with_tools(&self, request: &AiRequest) -> Result<AiResponse>;
52
53    async fn generate_with_tools_stream(
54        &self,
55        request: &AiRequest,
56    ) -> Result<Pin<Box<dyn Stream<Item = Result<String>> + Send>>>;
57
58    async fn generate_single_turn(
59        &self,
60        request: &AiRequest,
61    ) -> Result<(AiResponse, Vec<ToolCall>)>;
62
63    async fn execute_tools(
64        &self,
65        tool_calls: Vec<ToolCall>,
66        tools: &[McpTool],
67        context: &RequestContext,
68        agent_overrides: Option<&super::ToolModelOverrides>,
69    ) -> (Vec<ToolCall>, Vec<CallToolResult>);
70
71    async fn list_available_tools_for_agent(
72        &self,
73        agent_name: &AgentName,
74        context: &RequestContext,
75    ) -> Result<Vec<McpTool>>;
76
77    async fn generate_with_google_search(
78        &self,
79        params: GoogleSearchParams<'_>,
80    ) -> Result<SearchGroundedResponse>;
81
82    async fn health_check(&self) -> Result<HashMap<String, bool>>;
83
84    async fn generate_plan(
85        &self,
86        request: &AiRequest,
87        available_tools: &[McpTool],
88    ) -> Result<PlanningResult>;
89
90    async fn generate_response(&self, params: GenerateResponseParams<'_>) -> Result<String>;
91}
92
93pub type DynAiProvider = Arc<dyn AiProvider>;