Skip to main content

systemprompt_provider_contracts/
llm.rs

1use async_trait::async_trait;
2use futures::stream::Stream;
3use serde::{Deserialize, Serialize};
4use serde_json::Value as JsonValue;
5use std::pin::Pin;
6use systemprompt_identifiers::{SessionId, TraceId};
7
8use crate::tool::{ToolCallRequest, ToolCallResult, ToolDefinition};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ChatMessage {
12    pub role: ChatRole,
13    pub content: String,
14}
15
16impl ChatMessage {
17    #[must_use]
18    pub fn user(content: impl Into<String>) -> Self {
19        Self {
20            role: ChatRole::User,
21            content: content.into(),
22        }
23    }
24
25    #[must_use]
26    pub fn assistant(content: impl Into<String>) -> Self {
27        Self {
28            role: ChatRole::Assistant,
29            content: content.into(),
30        }
31    }
32
33    #[must_use]
34    pub fn system(content: impl Into<String>) -> Self {
35        Self {
36            role: ChatRole::System,
37            content: content.into(),
38        }
39    }
40}
41
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
43#[serde(rename_all = "lowercase")]
44pub enum ChatRole {
45    System,
46    User,
47    Assistant,
48    Tool,
49}
50
51#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
52pub struct SamplingParameters {
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub temperature: Option<f32>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub top_p: Option<f32>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub top_k: Option<u32>,
59}
60
61impl SamplingParameters {
62    #[must_use]
63    pub const fn new() -> Self {
64        Self {
65            temperature: None,
66            top_p: None,
67            top_k: None,
68        }
69    }
70
71    #[must_use]
72    pub const fn with_temperature(mut self, temperature: f32) -> Self {
73        self.temperature = Some(temperature);
74        self
75    }
76
77    #[must_use]
78    pub const fn with_top_p(mut self, top_p: f32) -> Self {
79        self.top_p = Some(top_p);
80        self
81    }
82
83    #[must_use]
84    pub const fn with_top_k(mut self, top_k: u32) -> Self {
85        self.top_k = Some(top_k);
86        self
87    }
88}
89
90impl Default for SamplingParameters {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96#[derive(Debug, Clone)]
97pub struct ChatRequest {
98    pub messages: Vec<ChatMessage>,
99    pub model: String,
100    pub max_output_tokens: u32,
101    pub sampling: Option<SamplingParameters>,
102    pub tools: Option<Vec<ToolDefinition>>,
103    pub response_schema: Option<JsonValue>,
104}
105
106impl ChatRequest {
107    #[must_use]
108    pub fn new(
109        messages: Vec<ChatMessage>,
110        model: impl Into<String>,
111        max_output_tokens: u32,
112    ) -> Self {
113        Self {
114            messages,
115            model: model.into(),
116            max_output_tokens,
117            sampling: None,
118            tools: None,
119            response_schema: None,
120        }
121    }
122
123    #[must_use]
124    pub const fn with_sampling(mut self, sampling: SamplingParameters) -> Self {
125        self.sampling = Some(sampling);
126        self
127    }
128
129    #[must_use]
130    pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
131        self.tools = Some(tools);
132        self
133    }
134
135    #[must_use]
136    pub fn with_response_schema(mut self, schema: JsonValue) -> Self {
137        self.response_schema = Some(schema);
138        self
139    }
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct ChatResponse {
144    pub content: String,
145    pub tool_calls: Vec<ToolCallRequest>,
146    pub usage: Option<TokenUsage>,
147    pub model: String,
148    pub latency_ms: u64,
149}
150
151impl ChatResponse {
152    #[must_use]
153    pub fn new(content: impl Into<String>, model: impl Into<String>) -> Self {
154        Self {
155            content: content.into(),
156            tool_calls: vec![],
157            usage: None,
158            model: model.into(),
159            latency_ms: 0,
160        }
161    }
162
163    #[must_use]
164    pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCallRequest>) -> Self {
165        self.tool_calls = tool_calls;
166        self
167    }
168
169    #[must_use]
170    pub const fn with_usage(mut self, usage: TokenUsage) -> Self {
171        self.usage = Some(usage);
172        self
173    }
174
175    #[must_use]
176    pub const fn with_latency(mut self, latency_ms: u64) -> Self {
177        self.latency_ms = latency_ms;
178        self
179    }
180}
181
182#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
183pub struct TokenUsage {
184    #[serde(rename = "input_tokens")]
185    pub input: u32,
186    #[serde(rename = "output_tokens")]
187    pub output: u32,
188    #[serde(rename = "total_tokens")]
189    pub total: u32,
190    #[serde(rename = "cache_read_tokens")]
191    pub cache_read: Option<u32>,
192    #[serde(rename = "cache_creation_tokens")]
193    pub cache_creation: Option<u32>,
194}
195
196impl TokenUsage {
197    #[must_use]
198    pub const fn new(input: u32, output: u32) -> Self {
199        Self {
200            input,
201            output,
202            total: input + output,
203            cache_read: None,
204            cache_creation: None,
205        }
206    }
207
208    #[must_use]
209    pub const fn with_cache_read(mut self, cache_read: u32) -> Self {
210        self.cache_read = Some(cache_read);
211        self
212    }
213
214    #[must_use]
215    pub const fn with_cache_creation(mut self, cache_creation: u32) -> Self {
216        self.cache_creation = Some(cache_creation);
217        self
218    }
219}
220
221pub type ChatStream = Pin<Box<dyn Stream<Item = anyhow::Result<String>> + Send>>;
222
223#[derive(Debug, thiserror::Error)]
224pub enum LlmProviderError {
225    #[error("Model '{0}' not supported")]
226    ModelNotSupported(String),
227
228    #[error("Provider '{0}' not available")]
229    ProviderNotAvailable(String),
230
231    #[error("Rate limit exceeded")]
232    RateLimitExceeded,
233
234    #[error("Authentication failed: {0}")]
235    AuthenticationFailed(String),
236
237    #[error("Invalid request: {0}")]
238    InvalidRequest(String),
239
240    #[error("Generation failed: {0}")]
241    GenerationFailed(String),
242
243    #[error("Internal error: {0}")]
244    Internal(#[source] anyhow::Error),
245}
246
247impl From<anyhow::Error> for LlmProviderError {
248    fn from(err: anyhow::Error) -> Self {
249        Self::Internal(err)
250    }
251}
252
253pub type LlmProviderResult<T> = Result<T, LlmProviderError>;
254
255#[async_trait]
256pub trait LlmProvider: Send + Sync {
257    async fn chat(&self, request: &ChatRequest) -> LlmProviderResult<ChatResponse>;
258
259    async fn stream_chat(&self, request: &ChatRequest) -> LlmProviderResult<ChatStream>;
260
261    fn default_model(&self) -> &str;
262
263    fn supports_model(&self, model: &str) -> bool;
264
265    fn supports_streaming(&self) -> bool;
266
267    fn supports_tools(&self) -> bool;
268}
269
270#[async_trait]
271pub trait ToolExecutor: Send + Sync {
272    async fn execute(
273        &self,
274        tool_calls: Vec<ToolCallRequest>,
275        tools: &[ToolDefinition],
276        context: &ToolExecutionContext,
277    ) -> (Vec<ToolCallRequest>, Vec<ToolCallResult>);
278}
279
280#[derive(Debug, Clone)]
281pub struct ToolExecutionContext {
282    pub auth_token: String,
283    pub session_id: Option<SessionId>,
284    pub trace_id: Option<TraceId>,
285    pub model_overrides: Option<JsonValue>,
286}
287
288impl ToolExecutionContext {
289    #[must_use]
290    pub fn new(auth_token: impl Into<String>) -> Self {
291        Self {
292            auth_token: auth_token.into(),
293            session_id: None,
294            trace_id: None,
295            model_overrides: None,
296        }
297    }
298
299    #[must_use]
300    pub fn with_session_id(mut self, session_id: SessionId) -> Self {
301        self.session_id = Some(session_id);
302        self
303    }
304
305    #[must_use]
306    pub fn with_trace_id(mut self, trace_id: TraceId) -> Self {
307        self.trace_id = Some(trace_id);
308        self
309    }
310
311    #[must_use]
312    pub fn with_model_overrides(mut self, overrides: JsonValue) -> Self {
313        self.model_overrides = Some(overrides);
314        self
315    }
316}