// systemprompt_provider_contracts/llm.rs
use async_trait::async_trait;
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::pin::Pin;
use systemprompt_identifiers::{SessionId, TraceId};

use crate::tool::{ToolCallRequest, ToolCallResult, ToolDefinition};

/// A single message in a chat conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored the message (system, user, assistant, or tool).
    pub role: ChatRole,
    /// Plain-text body of the message.
    pub content: String,
}
15
16impl ChatMessage {
17 #[must_use]
18 pub fn user(content: impl Into<String>) -> Self {
19 Self {
20 role: ChatRole::User,
21 content: content.into(),
22 }
23 }
24
25 #[must_use]
26 pub fn assistant(content: impl Into<String>) -> Self {
27 Self {
28 role: ChatRole::Assistant,
29 content: content.into(),
30 }
31 }
32
33 #[must_use]
34 pub fn system(content: impl Into<String>) -> Self {
35 Self {
36 role: ChatRole::System,
37 content: content.into(),
38 }
39 }
40}
41
/// The author of a [`ChatMessage`], serialized in lowercase
/// (e.g. `"system"`, `"user"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    /// System-level instruction.
    System,
    /// End-user input.
    User,
    /// Model-generated output.
    Assistant,
    /// A message originating from a tool.
    Tool,
}
50
/// Optional sampling knobs for text generation.
///
/// Unset fields are omitted from the serialized form.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct SamplingParameters {
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability mass (top-p).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Top-k candidate cutoff.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
}
60
61impl SamplingParameters {
62 #[must_use]
63 pub const fn new() -> Self {
64 Self {
65 temperature: None,
66 top_p: None,
67 top_k: None,
68 }
69 }
70
71 #[must_use]
72 pub const fn with_temperature(mut self, temperature: f32) -> Self {
73 self.temperature = Some(temperature);
74 self
75 }
76
77 #[must_use]
78 pub const fn with_top_p(mut self, top_p: f32) -> Self {
79 self.top_p = Some(top_p);
80 self
81 }
82
83 #[must_use]
84 pub const fn with_top_k(mut self, top_k: u32) -> Self {
85 self.top_k = Some(top_k);
86 self
87 }
88}
89
90impl Default for SamplingParameters {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
/// A fully-specified request to an LLM chat endpoint.
#[derive(Debug, Clone)]
pub struct ChatRequest {
    /// Conversation messages sent to the model.
    pub messages: Vec<ChatMessage>,
    /// Identifier of the model to invoke.
    pub model: String,
    /// Hard cap on the number of tokens the model may generate.
    pub max_output_tokens: u32,
    /// Optional sampling overrides; `None` leaves them unspecified.
    pub sampling: Option<SamplingParameters>,
    /// Tools the model may call, if any.
    pub tools: Option<Vec<ToolDefinition>>,
    /// JSON schema for structured output, if requested.
    pub response_schema: Option<JsonValue>,
}
105
106impl ChatRequest {
107 #[must_use]
108 pub fn new(
109 messages: Vec<ChatMessage>,
110 model: impl Into<String>,
111 max_output_tokens: u32,
112 ) -> Self {
113 Self {
114 messages,
115 model: model.into(),
116 max_output_tokens,
117 sampling: None,
118 tools: None,
119 response_schema: None,
120 }
121 }
122
123 #[must_use]
124 pub const fn with_sampling(mut self, sampling: SamplingParameters) -> Self {
125 self.sampling = Some(sampling);
126 self
127 }
128
129 #[must_use]
130 pub fn with_tools(mut self, tools: Vec<ToolDefinition>) -> Self {
131 self.tools = Some(tools);
132 self
133 }
134
135 #[must_use]
136 pub fn with_response_schema(mut self, schema: JsonValue) -> Self {
137 self.response_schema = Some(schema);
138 self
139 }
140}
141
/// The outcome of a single chat completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    /// Text content produced by the model.
    pub content: String,
    /// Tool invocations the model requested; empty when none.
    pub tool_calls: Vec<ToolCallRequest>,
    /// Token accounting, when the provider reports it.
    pub usage: Option<TokenUsage>,
    /// Identifier of the model that produced the response.
    pub model: String,
    /// Wall-clock latency of the call in milliseconds.
    pub latency_ms: u64,
}
150
151impl ChatResponse {
152 #[must_use]
153 pub fn new(content: impl Into<String>, model: impl Into<String>) -> Self {
154 Self {
155 content: content.into(),
156 tool_calls: vec![],
157 usage: None,
158 model: model.into(),
159 latency_ms: 0,
160 }
161 }
162
163 #[must_use]
164 pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCallRequest>) -> Self {
165 self.tool_calls = tool_calls;
166 self
167 }
168
169 #[must_use]
170 pub const fn with_usage(mut self, usage: TokenUsage) -> Self {
171 self.usage = Some(usage);
172 self
173 }
174
175 #[must_use]
176 pub const fn with_latency(mut self, latency_ms: u64) -> Self {
177 self.latency_ms = latency_ms;
178 self
179 }
180}
181
/// Token accounting for a single model invocation.
///
/// Fields serialize with explicit `*_tokens` names.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt.
    #[serde(rename = "input_tokens")]
    pub input: u32,
    /// Tokens produced by the model.
    #[serde(rename = "output_tokens")]
    pub output: u32,
    /// Sum of input and output tokens.
    #[serde(rename = "total_tokens")]
    pub total: u32,
    /// Cache-read token count, if the provider reports one.
    #[serde(rename = "cache_read_tokens")]
    pub cache_read: Option<u32>,
    /// Cache-creation token count, if the provider reports one.
    #[serde(rename = "cache_creation_tokens")]
    pub cache_creation: Option<u32>,
}
195
196impl TokenUsage {
197 #[must_use]
198 pub const fn new(input: u32, output: u32) -> Self {
199 Self {
200 input,
201 output,
202 total: input + output,
203 cache_read: None,
204 cache_creation: None,
205 }
206 }
207
208 #[must_use]
209 pub const fn with_cache_read(mut self, cache_read: u32) -> Self {
210 self.cache_read = Some(cache_read);
211 self
212 }
213
214 #[must_use]
215 pub const fn with_cache_creation(mut self, cache_creation: u32) -> Self {
216 self.cache_creation = Some(cache_creation);
217 self
218 }
219}
220
/// Pinned, boxed stream of incremental response text chunks.
pub type ChatStream = Pin<Box<dyn Stream<Item = anyhow::Result<String>> + Send>>;
222
/// Errors surfaced by [`LlmProvider`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum LlmProviderError {
    /// The requested model is not offered by this provider.
    #[error("Model '{0}' not supported")]
    ModelNotSupported(String),

    /// The named provider is unreachable or disabled.
    #[error("Provider '{0}' not available")]
    ProviderNotAvailable(String),

    /// The provider rejected the call due to rate limiting.
    #[error("Rate limit exceeded")]
    RateLimitExceeded,

    /// Credentials were missing or rejected.
    #[error("Authentication failed: {0}")]
    AuthenticationFailed(String),

    /// The request was malformed or violated provider constraints.
    #[error("Invalid request: {0}")]
    InvalidRequest(String),

    /// The provider accepted the request but generation failed.
    #[error("Generation failed: {0}")]
    GenerationFailed(String),

    /// Catch-all wrapping the underlying error as the source.
    #[error("Internal error: {0}")]
    Internal(#[source] anyhow::Error),
}
246
247impl From<anyhow::Error> for LlmProviderError {
248 fn from(err: anyhow::Error) -> Self {
249 Self::Internal(err)
250 }
251}
252
/// Convenience alias for results produced by LLM provider calls.
pub type LlmProviderResult<T> = Result<T, LlmProviderError>;
254
/// Contract implemented by every LLM backend.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Runs a single chat completion and returns the full response.
    async fn chat(&self, request: &ChatRequest) -> LlmProviderResult<ChatResponse>;

    /// Runs a chat completion, returning a stream of text chunks.
    async fn stream_chat(&self, request: &ChatRequest) -> LlmProviderResult<ChatStream>;

    /// The model identifier used when the caller does not pick one.
    fn default_model(&self) -> &str;

    /// Whether this provider can serve the given model identifier.
    fn supports_model(&self, model: &str) -> bool;

    /// Whether streaming via [`stream_chat`](Self::stream_chat) is supported.
    fn supports_streaming(&self) -> bool;

    /// Whether tool/function calling is supported.
    fn supports_tools(&self) -> bool;
}
269
/// Executes a batch of tool calls requested by a model.
#[async_trait]
pub trait ToolExecutor: Send + Sync {
    /// Runs `tool_calls` against the definitions in `tools` using the
    /// ambient `context`, returning the requests alongside results.
    ///
    /// NOTE(review): the exact pairing/ordering contract between the
    /// two returned vectors is not visible from this file — confirm
    /// against implementations before relying on it.
    async fn execute(
        &self,
        tool_calls: Vec<ToolCallRequest>,
        tools: &[ToolDefinition],
        context: &ToolExecutionContext,
    ) -> (Vec<ToolCallRequest>, Vec<ToolCallResult>);
}
279
/// Ambient data made available to tool executions.
#[derive(Debug, Clone)]
pub struct ToolExecutionContext {
    /// Caller's authentication token.
    pub auth_token: String,
    /// Session the execution belongs to, if any.
    pub session_id: Option<SessionId>,
    /// Trace identifier for observability, if any.
    pub trace_id: Option<TraceId>,
    /// Optional model-specific override settings as raw JSON.
    pub model_overrides: Option<JsonValue>,
}
287
288impl ToolExecutionContext {
289 #[must_use]
290 pub fn new(auth_token: impl Into<String>) -> Self {
291 Self {
292 auth_token: auth_token.into(),
293 session_id: None,
294 trace_id: None,
295 model_overrides: None,
296 }
297 }
298
299 #[must_use]
300 pub fn with_session_id(mut self, session_id: SessionId) -> Self {
301 self.session_id = Some(session_id);
302 self
303 }
304
305 #[must_use]
306 pub fn with_trace_id(mut self, trace_id: TraceId) -> Self {
307 self.trace_id = Some(trace_id);
308 self
309 }
310
311 #[must_use]
312 pub fn with_model_overrides(mut self, overrides: JsonValue) -> Self {
313 self.model_overrides = Some(overrides);
314 self
315 }
316}