Skip to main content

stakpak_api/
models.rs

1use chrono::{DateTime, Utc};
2use rmcp::model::Content;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use stakai::Model;
6use stakpak_shared::models::{
7    integrations::openai::{ChatMessage, FunctionCall, MessageContent, Role, Tool, ToolCall},
8    llm::{LLMInput, LLMMessage, LLMMessageContent, LLMMessageTypedContent, LLMTokenUsage},
9};
10use uuid::Uuid;
11
12#[derive(Debug, Clone, Deserialize, Serialize)]
13pub enum ApiStreamError {
14    AgentInputInvalid(String),
15    AgentStateInvalid,
16    AgentNotSupported,
17    AgentExecutionLimitExceeded,
18    AgentInvalidResponseStream,
19    InvalidGeneratedCode,
20    CopilotError,
21    SaveError,
22    Unknown(String),
23}
24
25impl From<&str> for ApiStreamError {
26    fn from(error_str: &str) -> Self {
27        match error_str {
28            s if s.contains("Agent not supported") => ApiStreamError::AgentNotSupported,
29            s if s.contains("Agent state is not valid") => ApiStreamError::AgentStateInvalid,
30            s if s.contains("Agent thinking limit exceeded") => {
31                ApiStreamError::AgentExecutionLimitExceeded
32            }
33            s if s.contains("Invalid response stream") => {
34                ApiStreamError::AgentInvalidResponseStream
35            }
36            s if s.contains("Invalid generated code") => ApiStreamError::InvalidGeneratedCode,
37            s if s.contains(
38                "Our copilot is handling too many requests at this time, please try again later.",
39            ) =>
40            {
41                ApiStreamError::CopilotError
42            }
43            s if s
44                .contains("An error occurred while saving your data. Please try again later.") =>
45            {
46                ApiStreamError::SaveError
47            }
48            s if s.contains("Agent input is not valid: ") => {
49                ApiStreamError::AgentInputInvalid(s.replace("Agent input is not valid: ", ""))
50            }
51            _ => ApiStreamError::Unknown(error_str.to_string()),
52        }
53    }
54}
55
56impl From<String> for ApiStreamError {
57    fn from(error_str: String) -> Self {
58        ApiStreamError::from(error_str.as_str())
59    }
60}
61
62#[derive(Deserialize, Serialize, Debug)]
63pub struct Document {
64    pub content: String,
65    pub uri: String,
66    pub provisioner: ProvisionerType,
67}
68
69#[derive(Deserialize, Serialize, Debug)]
70pub struct SimpleDocument {
71    pub uri: String,
72    pub content: String,
73}
74
75#[derive(Deserialize, Serialize, Debug, Clone)]
76pub struct Block {
77    pub id: Uuid,
78    pub provider: String,
79    pub provisioner: ProvisionerType,
80    pub language: String,
81    pub key: String,
82    pub digest: u64,
83    pub references: Vec<Vec<Segment>>,
84    pub kind: String,
85    pub r#type: Option<String>,
86    pub name: Option<String>,
87    pub config: serde_json::Value,
88    pub document_uri: String,
89    pub code: String,
90    pub start_byte: usize,
91    pub end_byte: usize,
92    pub start_point: Point,
93    pub end_point: Point,
94    pub state: Option<serde_json::Value>,
95    pub updated_at: Option<DateTime<Utc>>,
96    pub created_at: Option<DateTime<Utc>>,
97    pub dependents: Vec<DependentBlock>,
98    pub dependencies: Vec<Dependency>,
99    pub api_group_version: Option<ApiGroupVersion>,
100
101    pub generated_summary: Option<String>,
102}
103
104impl Block {
105    pub fn get_uri(&self) -> String {
106        format!(
107            "{}#L{}-L{}",
108            self.document_uri, self.start_point.row, self.end_point.row
109        )
110    }
111}
112
113#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)]
114pub enum ProvisionerType {
115    #[serde(rename = "Terraform")]
116    Terraform,
117    #[serde(rename = "Kubernetes")]
118    Kubernetes,
119    #[serde(rename = "Dockerfile")]
120    Dockerfile,
121    #[serde(rename = "GithubActions")]
122    GithubActions,
123    #[serde(rename = "None")]
124    None,
125}
126impl std::str::FromStr for ProvisionerType {
127    type Err = String;
128
129    fn from_str(s: &str) -> Result<Self, Self::Err> {
130        match s.to_lowercase().as_str() {
131            "terraform" => Ok(Self::Terraform),
132            "kubernetes" => Ok(Self::Kubernetes),
133            "dockerfile" => Ok(Self::Dockerfile),
134            "github-actions" => Ok(Self::GithubActions),
135            _ => Ok(Self::None),
136        }
137    }
138}
139impl std::fmt::Display for ProvisionerType {
140    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
141        match self {
142            ProvisionerType::Terraform => write!(f, "terraform"),
143            ProvisionerType::Kubernetes => write!(f, "kubernetes"),
144            ProvisionerType::Dockerfile => write!(f, "dockerfile"),
145            ProvisionerType::GithubActions => write!(f, "github-actions"),
146            ProvisionerType::None => write!(f, "none"),
147        }
148    }
149}
150
151#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
152#[serde(untagged)]
153pub enum Segment {
154    Key(String),
155    Index(usize),
156}
157
158impl std::fmt::Display for Segment {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        match self {
161            Segment::Key(key) => write!(f, "{}", key),
162            Segment::Index(index) => write!(f, "{}", index),
163        }
164    }
165}
166impl std::fmt::Debug for Segment {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        match self {
169            Segment::Key(key) => write!(f, "{}", key),
170            Segment::Index(index) => write!(f, "{}", index),
171        }
172    }
173}
174
175#[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq)]
176pub struct Point {
177    pub row: usize,
178    pub column: usize,
179}
180
181#[derive(Deserialize, Serialize, Debug, Clone)]
182pub struct DependentBlock {
183    pub key: String,
184}
185
186#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
187pub struct Dependency {
188    pub id: Option<Uuid>,
189    pub expression: Option<String>,
190    pub from_path: Option<Vec<Segment>>,
191    pub to_path: Option<Vec<Segment>>,
192    #[serde(default = "Vec::new")]
193    pub selectors: Vec<DependencySelector>,
194    #[serde(skip_serializing)]
195    pub key: Option<String>,
196    pub digest: Option<u64>,
197    #[serde(default = "Vec::new")]
198    pub from: Vec<Segment>,
199    pub from_field: Option<Vec<Segment>>,
200    pub to_field: Option<Vec<Segment>>,
201    pub start_byte: Option<usize>,
202    pub end_byte: Option<usize>,
203    pub start_point: Option<Point>,
204    pub end_point: Option<Point>,
205    pub satisfied: bool,
206}
207
208#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
209pub struct DependencySelector {
210    pub references: Vec<Vec<Segment>>,
211    pub operator: DependencySelectorOperator,
212}
213
214#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
215pub enum DependencySelectorOperator {
216    Equals,
217    NotEquals,
218    In,
219    NotIn,
220    Exists,
221    DoesNotExist,
222}
223
224#[derive(Serialize, Deserialize, Debug, Clone)]
225pub struct ApiGroupVersion {
226    pub alias: String,
227    pub group: String,
228    pub version: String,
229    pub provisioner: ProvisionerType,
230    pub status: APIGroupVersionStatus,
231}
232
233#[derive(Serialize, Deserialize, Debug, Clone)]
234pub enum APIGroupVersionStatus {
235    #[serde(rename = "UNAVAILABLE")]
236    Unavailable,
237    #[serde(rename = "PENDING")]
238    Pending,
239    #[serde(rename = "AVAILABLE")]
240    Available,
241}
242
243#[derive(Serialize, Deserialize, Debug)]
244pub struct BuildCodeIndexInput {
245    pub documents: Vec<SimpleDocument>,
246}
247
248#[derive(Serialize, Deserialize, Debug, Clone)]
249pub struct IndexError {
250    pub uri: String,
251    pub message: String,
252    pub details: Option<serde_json::Value>,
253}
254
255#[derive(Serialize, Deserialize, Debug, Clone)]
256pub struct BuildCodeIndexOutput {
257    pub blocks: Vec<Block>,
258    pub errors: Vec<IndexError>,
259    pub warnings: Vec<IndexError>,
260}
261
262#[derive(Serialize, Deserialize, Debug, Clone)]
263pub struct CodeIndex {
264    pub last_updated: DateTime<Utc>,
265    pub index: BuildCodeIndexOutput,
266}
267
268#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
269#[serde(rename_all = "UPPERCASE")]
270pub enum RuleBookVisibility {
271    #[default]
272    Public,
273    Private,
274}
275
276#[derive(Serialize, Deserialize, Debug, Clone)]
277pub struct RuleBook {
278    pub id: String,
279    pub uri: String,
280    pub description: String,
281    pub content: String,
282    pub visibility: RuleBookVisibility,
283    pub tags: Vec<String>,
284    pub created_at: Option<DateTime<Utc>>,
285    pub updated_at: Option<DateTime<Utc>>,
286}
287
288#[derive(Serialize, Deserialize, Debug)]
289pub struct ToolsCallParams {
290    pub name: String,
291    pub arguments: Value,
292}
293
294#[derive(Serialize, Deserialize, Debug)]
295pub struct ToolsCallResponse {
296    pub content: Vec<Content>,
297}
298
299#[derive(Serialize, Deserialize, Debug, Clone)]
300pub struct APIKeyScope {
301    pub r#type: String,
302    pub name: String,
303}
304
305impl std::fmt::Display for APIKeyScope {
306    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307        write!(f, "{} ({})", self.name, self.r#type)
308    }
309}
310
311#[derive(Serialize, Deserialize, Clone, Debug)]
312pub struct GetMyAccountResponse {
313    pub username: String,
314    pub id: String,
315    pub first_name: String,
316    pub last_name: String,
317    pub email: String,
318    pub scope: Option<APIKeyScope>,
319}
320
321impl GetMyAccountResponse {
322    pub fn to_text(&self) -> String {
323        format!(
324            "ID: {}\nUsername: {}\nName: {} {}\nEmail: {}",
325            self.id, self.username, self.first_name, self.last_name, self.email
326        )
327    }
328}
329
330#[derive(Serialize, Deserialize, Debug, Clone)]
331pub struct ListRuleBook {
332    pub id: String,
333    pub uri: String,
334    pub description: String,
335    pub visibility: RuleBookVisibility,
336    pub tags: Vec<String>,
337    pub created_at: Option<DateTime<Utc>>,
338    pub updated_at: Option<DateTime<Utc>>,
339}
340
341#[derive(Serialize, Deserialize, Debug)]
342pub struct ListRulebooksResponse {
343    pub results: Vec<ListRuleBook>,
344}
345
346#[derive(Serialize, Deserialize, Debug)]
347pub struct CreateRuleBookInput {
348    pub uri: String,
349    pub description: String,
350    pub content: String,
351    pub tags: Vec<String>,
352    #[serde(skip_serializing_if = "Option::is_none")]
353    pub visibility: Option<RuleBookVisibility>,
354}
355
356#[derive(Serialize, Deserialize, Debug)]
357pub struct CreateRuleBookResponse {
358    pub id: String,
359}
360
361impl ListRuleBook {
362    pub fn to_text(&self) -> String {
363        format!(
364            "URI: {}\nDescription: {}\nTags: {}\n",
365            self.uri,
366            self.description,
367            self.tags.join(", ")
368        )
369    }
370}
371
372#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
373pub struct SimpleLLMMessage {
374    #[serde(rename = "role")]
375    pub role: SimpleLLMRole,
376    pub content: String,
377}
378
379#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
380#[serde(rename_all = "lowercase")]
381pub enum SimpleLLMRole {
382    User,
383    Assistant,
384}
385
386impl std::fmt::Display for SimpleLLMRole {
387    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388        match self {
389            SimpleLLMRole::User => write!(f, "user"),
390            SimpleLLMRole::Assistant => write!(f, "assistant"),
391        }
392    }
393}
394
395#[derive(Debug, Deserialize, Serialize)]
396pub struct SearchDocsRequest {
397    pub keywords: String,
398    pub exclude_keywords: Option<String>,
399    pub limit: Option<u32>,
400}
401
402#[derive(Debug, Deserialize, Serialize)]
403pub struct SearchMemoryRequest {
404    pub keywords: Vec<String>,
405    pub start_time: Option<DateTime<Utc>>,
406    pub end_time: Option<DateTime<Utc>>,
407}
408
409#[derive(Debug, Deserialize, Serialize)]
410pub struct SlackReadMessagesRequest {
411    pub channel: String,
412    pub limit: Option<u32>,
413}
414
415#[derive(Debug, Deserialize, Serialize)]
416pub struct SlackReadRepliesRequest {
417    pub channel: String,
418    pub ts: String,
419}
420
421#[derive(Debug, Deserialize, Serialize)]
422pub struct SlackSendMessageRequest {
423    pub channel: String,
424    pub markdown_text: String,
425    pub thread_ts: Option<String>,
426}
427
428#[derive(Debug, Clone, Default, Serialize)]
429pub struct AgentState {
430    /// The active model to use for inference
431    pub active_model: Model,
432    pub messages: Vec<ChatMessage>,
433    pub tools: Option<Vec<Tool>>,
434
435    pub llm_input: Option<LLMInput>,
436    pub llm_output: Option<LLMOutput>,
437
438    /// Metadata for checkpoint persistence (context trimming state, etc.)
439    /// Loaded from checkpoint on session resume and saved back after inference
440    pub metadata: Option<Value>,
441}
442
443#[derive(Debug, Clone, Default, Serialize)]
444pub struct LLMOutput {
445    pub new_message: LLMMessage,
446    pub usage: LLMTokenUsage,
447}
448
449impl From<&LLMOutput> for ChatMessage {
450    fn from(value: &LLMOutput) -> Self {
451        let message_content = match &value.new_message.content {
452            LLMMessageContent::String(s) => s.clone(),
453            LLMMessageContent::List(l) => l
454                .iter()
455                .map(|c| match c {
456                    LLMMessageTypedContent::Text { text } => text.clone(),
457                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
458                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
459                    LLMMessageTypedContent::Image { .. } => String::new(),
460                })
461                .collect::<Vec<_>>()
462                .join("\n"),
463        };
464        let tool_calls = if let LLMMessageContent::List(items) = &value.new_message.content {
465            let calls: Vec<ToolCall> = items
466                .iter()
467                .filter_map(|item| {
468                    if let LLMMessageTypedContent::ToolCall {
469                        id,
470                        name,
471                        args,
472                        metadata,
473                    } = item
474                    {
475                        Some(ToolCall {
476                            id: id.clone(),
477                            r#type: "function".to_string(),
478                            function: FunctionCall {
479                                name: name.clone(),
480                                arguments: args.to_string(),
481                            },
482                            metadata: metadata.clone(),
483                        })
484                    } else {
485                        None
486                    }
487                })
488                .collect();
489
490            if calls.is_empty() { None } else { Some(calls) }
491        } else {
492            None
493        };
494        ChatMessage {
495            role: Role::Assistant,
496            content: Some(MessageContent::String(message_content)),
497            name: None,
498            tool_calls,
499            tool_call_id: None,
500            usage: Some(value.usage.clone()),
501            ..Default::default()
502        }
503    }
504}
505
506impl AgentState {
507    pub fn new(
508        active_model: Model,
509        messages: Vec<ChatMessage>,
510        tools: Option<Vec<Tool>>,
511        metadata: Option<Value>,
512    ) -> Self {
513        Self {
514            active_model,
515            messages,
516            tools,
517            metadata,
518            llm_input: None,
519            llm_output: None,
520        }
521    }
522
523    pub fn set_messages(&mut self, messages: Vec<ChatMessage>) {
524        self.messages = messages;
525    }
526
527    pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) {
528        self.tools = tools;
529    }
530
531    pub fn set_active_model(&mut self, model: Model) {
532        self.active_model = model;
533    }
534
535    pub fn set_llm_input(&mut self, llm_input: Option<LLMInput>) {
536        self.llm_input = llm_input;
537    }
538
539    pub fn set_llm_output(&mut self, new_message: LLMMessage, new_usage: Option<LLMTokenUsage>) {
540        self.llm_output = Some(LLMOutput {
541            new_message,
542            usage: new_usage.unwrap_or_default(),
543        });
544    }
545
546    pub fn append_new_message(&mut self, new_message: ChatMessage) {
547        self.messages.push(new_message);
548    }
549}