Skip to main content

stakpak_api/
models.rs

1use std::collections::HashMap;
2
3use chrono::{DateTime, Utc};
4use rmcp::model::Content;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use stakpak_shared::models::{
8    integrations::openai::{
9        AgentModel, ChatMessage, FunctionCall, MessageContent, Role, Tool, ToolCall,
10    },
11    llm::{LLMInput, LLMMessage, LLMMessageContent, LLMMessageTypedContent, LLMTokenUsage},
12};
13use uuid::Uuid;
14
15#[derive(Debug, Clone, Deserialize, Serialize)]
16pub enum ApiStreamError {
17    AgentInputInvalid(String),
18    AgentStateInvalid,
19    AgentNotSupported,
20    AgentExecutionLimitExceeded,
21    AgentInvalidResponseStream,
22    InvalidGeneratedCode,
23    CopilotError,
24    SaveError,
25    Unknown(String),
26}
27
28impl From<&str> for ApiStreamError {
29    fn from(error_str: &str) -> Self {
30        match error_str {
31            s if s.contains("Agent not supported") => ApiStreamError::AgentNotSupported,
32            s if s.contains("Agent state is not valid") => ApiStreamError::AgentStateInvalid,
33            s if s.contains("Agent thinking limit exceeded") => {
34                ApiStreamError::AgentExecutionLimitExceeded
35            }
36            s if s.contains("Invalid response stream") => {
37                ApiStreamError::AgentInvalidResponseStream
38            }
39            s if s.contains("Invalid generated code") => ApiStreamError::InvalidGeneratedCode,
40            s if s.contains(
41                "Our copilot is handling too many requests at this time, please try again later.",
42            ) =>
43            {
44                ApiStreamError::CopilotError
45            }
46            s if s
47                .contains("An error occurred while saving your data. Please try again later.") =>
48            {
49                ApiStreamError::SaveError
50            }
51            s if s.contains("Agent input is not valid: ") => {
52                ApiStreamError::AgentInputInvalid(s.replace("Agent input is not valid: ", ""))
53            }
54            _ => ApiStreamError::Unknown(error_str.to_string()),
55        }
56    }
57}
58
59impl From<String> for ApiStreamError {
60    fn from(error_str: String) -> Self {
61        ApiStreamError::from(error_str.as_str())
62    }
63}
64
65#[derive(Deserialize, Serialize, Debug)]
66pub struct Document {
67    pub content: String,
68    pub uri: String,
69    pub provisioner: ProvisionerType,
70}
71
72#[derive(Deserialize, Serialize, Debug)]
73pub struct SimpleDocument {
74    pub uri: String,
75    pub content: String,
76}
77
78#[derive(Deserialize, Serialize, Debug, Clone)]
79pub struct Block {
80    pub id: Uuid,
81    pub provider: String,
82    pub provisioner: ProvisionerType,
83    pub language: String,
84    pub key: String,
85    pub digest: u64,
86    pub references: Vec<Vec<Segment>>,
87    pub kind: String,
88    pub r#type: Option<String>,
89    pub name: Option<String>,
90    pub config: serde_json::Value,
91    pub document_uri: String,
92    pub code: String,
93    pub start_byte: usize,
94    pub end_byte: usize,
95    pub start_point: Point,
96    pub end_point: Point,
97    pub state: Option<serde_json::Value>,
98    pub updated_at: Option<DateTime<Utc>>,
99    pub created_at: Option<DateTime<Utc>>,
100    pub dependents: Vec<DependentBlock>,
101    pub dependencies: Vec<Dependency>,
102    pub api_group_version: Option<ApiGroupVersion>,
103
104    pub generated_summary: Option<String>,
105}
106
107impl Block {
108    pub fn get_uri(&self) -> String {
109        format!(
110            "{}#L{}-L{}",
111            self.document_uri, self.start_point.row, self.end_point.row
112        )
113    }
114}
115
116#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)]
117pub enum ProvisionerType {
118    #[serde(rename = "Terraform")]
119    Terraform,
120    #[serde(rename = "Kubernetes")]
121    Kubernetes,
122    #[serde(rename = "Dockerfile")]
123    Dockerfile,
124    #[serde(rename = "GithubActions")]
125    GithubActions,
126    #[serde(rename = "None")]
127    None,
128}
129impl std::str::FromStr for ProvisionerType {
130    type Err = String;
131
132    fn from_str(s: &str) -> Result<Self, Self::Err> {
133        match s.to_lowercase().as_str() {
134            "terraform" => Ok(Self::Terraform),
135            "kubernetes" => Ok(Self::Kubernetes),
136            "dockerfile" => Ok(Self::Dockerfile),
137            "github-actions" => Ok(Self::GithubActions),
138            _ => Ok(Self::None),
139        }
140    }
141}
142impl std::fmt::Display for ProvisionerType {
143    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
144        match self {
145            ProvisionerType::Terraform => write!(f, "terraform"),
146            ProvisionerType::Kubernetes => write!(f, "kubernetes"),
147            ProvisionerType::Dockerfile => write!(f, "dockerfile"),
148            ProvisionerType::GithubActions => write!(f, "github-actions"),
149            ProvisionerType::None => write!(f, "none"),
150        }
151    }
152}
153
154#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
155#[serde(untagged)]
156pub enum Segment {
157    Key(String),
158    Index(usize),
159}
160
161impl std::fmt::Display for Segment {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        match self {
164            Segment::Key(key) => write!(f, "{}", key),
165            Segment::Index(index) => write!(f, "{}", index),
166        }
167    }
168}
169impl std::fmt::Debug for Segment {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        match self {
172            Segment::Key(key) => write!(f, "{}", key),
173            Segment::Index(index) => write!(f, "{}", index),
174        }
175    }
176}
177
178#[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq)]
179pub struct Point {
180    pub row: usize,
181    pub column: usize,
182}
183
184#[derive(Deserialize, Serialize, Debug, Clone)]
185pub struct DependentBlock {
186    pub key: String,
187}
188
189#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
190pub struct Dependency {
191    pub id: Option<Uuid>,
192    pub expression: Option<String>,
193    pub from_path: Option<Vec<Segment>>,
194    pub to_path: Option<Vec<Segment>>,
195    #[serde(default = "Vec::new")]
196    pub selectors: Vec<DependencySelector>,
197    #[serde(skip_serializing)]
198    pub key: Option<String>,
199    pub digest: Option<u64>,
200    #[serde(default = "Vec::new")]
201    pub from: Vec<Segment>,
202    pub from_field: Option<Vec<Segment>>,
203    pub to_field: Option<Vec<Segment>>,
204    pub start_byte: Option<usize>,
205    pub end_byte: Option<usize>,
206    pub start_point: Option<Point>,
207    pub end_point: Option<Point>,
208    pub satisfied: bool,
209}
210
211#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
212pub struct DependencySelector {
213    pub references: Vec<Vec<Segment>>,
214    pub operator: DependencySelectorOperator,
215}
216
217#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
218pub enum DependencySelectorOperator {
219    Equals,
220    NotEquals,
221    In,
222    NotIn,
223    Exists,
224    DoesNotExist,
225}
226
227#[derive(Serialize, Deserialize, Debug, Clone)]
228pub struct ApiGroupVersion {
229    pub alias: String,
230    pub group: String,
231    pub version: String,
232    pub provisioner: ProvisionerType,
233    pub status: APIGroupVersionStatus,
234}
235
236#[derive(Serialize, Deserialize, Debug, Clone)]
237pub enum APIGroupVersionStatus {
238    #[serde(rename = "UNAVAILABLE")]
239    Unavailable,
240    #[serde(rename = "PENDING")]
241    Pending,
242    #[serde(rename = "AVAILABLE")]
243    Available,
244}
245
246#[derive(Serialize, Deserialize, Debug)]
247pub struct BuildCodeIndexInput {
248    pub documents: Vec<SimpleDocument>,
249}
250
251#[derive(Serialize, Deserialize, Debug, Clone)]
252pub struct IndexError {
253    pub uri: String,
254    pub message: String,
255    pub details: Option<serde_json::Value>,
256}
257
258#[derive(Serialize, Deserialize, Debug, Clone)]
259pub struct BuildCodeIndexOutput {
260    pub blocks: Vec<Block>,
261    pub errors: Vec<IndexError>,
262    pub warnings: Vec<IndexError>,
263}
264
265#[derive(Serialize, Deserialize, Debug, Clone)]
266pub struct CodeIndex {
267    pub last_updated: DateTime<Utc>,
268    pub index: BuildCodeIndexOutput,
269}
270
271#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
272#[serde(rename_all = "UPPERCASE")]
273pub enum RuleBookVisibility {
274    #[default]
275    Public,
276    Private,
277}
278
279#[derive(Serialize, Deserialize, Debug, Clone)]
280pub struct RuleBook {
281    pub id: String,
282    pub uri: String,
283    pub description: String,
284    pub content: String,
285    pub visibility: RuleBookVisibility,
286    pub tags: Vec<String>,
287    pub created_at: Option<DateTime<Utc>>,
288    pub updated_at: Option<DateTime<Utc>>,
289}
290
291#[derive(Serialize, Deserialize, Debug)]
292pub struct ToolsCallParams {
293    pub name: String,
294    pub arguments: Value,
295}
296
297#[derive(Serialize, Deserialize, Debug)]
298pub struct ToolsCallResponse {
299    pub content: Vec<Content>,
300}
301
302#[derive(Serialize, Deserialize, Debug, Clone)]
303pub struct APIKeyScope {
304    pub r#type: String,
305    pub name: String,
306}
307
308impl std::fmt::Display for APIKeyScope {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        write!(f, "{} ({})", self.name, self.r#type)
311    }
312}
313
314#[derive(Serialize, Deserialize, Clone, Debug)]
315pub struct GetMyAccountResponse {
316    pub username: String,
317    pub id: String,
318    pub first_name: String,
319    pub last_name: String,
320    pub email: String,
321    pub scope: Option<APIKeyScope>,
322}
323
324impl GetMyAccountResponse {
325    pub fn to_text(&self) -> String {
326        format!(
327            "ID: {}\nUsername: {}\nName: {} {}\nEmail: {}",
328            self.id, self.username, self.first_name, self.last_name, self.email
329        )
330    }
331}
332
333#[derive(Serialize, Deserialize, Debug, Clone)]
334pub struct ListRuleBook {
335    pub id: String,
336    pub uri: String,
337    pub description: String,
338    pub visibility: RuleBookVisibility,
339    pub tags: Vec<String>,
340    pub created_at: Option<DateTime<Utc>>,
341    pub updated_at: Option<DateTime<Utc>>,
342}
343
344#[derive(Serialize, Deserialize, Debug)]
345pub struct ListRulebooksResponse {
346    pub results: Vec<ListRuleBook>,
347}
348
349#[derive(Serialize, Deserialize, Debug)]
350pub struct CreateRuleBookInput {
351    pub uri: String,
352    pub description: String,
353    pub content: String,
354    pub tags: Vec<String>,
355    #[serde(skip_serializing_if = "Option::is_none")]
356    pub visibility: Option<RuleBookVisibility>,
357}
358
359#[derive(Serialize, Deserialize, Debug)]
360pub struct CreateRuleBookResponse {
361    pub id: String,
362}
363
364impl ListRuleBook {
365    pub fn to_text(&self) -> String {
366        format!(
367            "URI: {}\nDescription: {}\nTags: {}\n",
368            self.uri,
369            self.description,
370            self.tags.join(", ")
371        )
372    }
373}
374
375#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
376pub struct SimpleLLMMessage {
377    #[serde(rename = "role")]
378    pub role: SimpleLLMRole,
379    pub content: String,
380}
381
382#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
383#[serde(rename_all = "lowercase")]
384pub enum SimpleLLMRole {
385    User,
386    Assistant,
387}
388
389impl std::fmt::Display for SimpleLLMRole {
390    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
391        match self {
392            SimpleLLMRole::User => write!(f, "user"),
393            SimpleLLMRole::Assistant => write!(f, "assistant"),
394        }
395    }
396}
397
398#[derive(Debug, Deserialize, Serialize)]
399pub struct SearchDocsRequest {
400    pub keywords: String,
401    pub exclude_keywords: Option<String>,
402    pub limit: Option<u32>,
403}
404
405#[derive(Debug, Deserialize, Serialize)]
406pub struct SearchMemoryRequest {
407    pub keywords: Vec<String>,
408    pub start_time: Option<DateTime<Utc>>,
409    pub end_time: Option<DateTime<Utc>>,
410}
411
412#[derive(Debug, Deserialize, Serialize)]
413pub struct SlackReadMessagesRequest {
414    pub channel: String,
415    pub limit: Option<u32>,
416}
417
418#[derive(Debug, Deserialize, Serialize)]
419pub struct SlackReadRepliesRequest {
420    pub channel: String,
421    pub ts: String,
422}
423
424#[derive(Debug, Deserialize, Serialize)]
425pub struct SlackSendMessageRequest {
426    pub channel: String,
427    pub mrkdwn_text: String,
428    pub thread_ts: Option<String>,
429}
430
431#[derive(Debug, Clone, Default, Serialize)]
432pub struct AgentState {
433    pub agent_model: AgentModel,
434    pub messages: Vec<ChatMessage>,
435    pub tools: Option<Vec<Tool>>,
436
437    pub llm_input: Option<LLMInput>,
438    pub llm_output: Option<LLMOutput>,
439
440    pub metadata: Option<HashMap<String, Value>>,
441}
442
443#[derive(Debug, Clone, Default, Serialize)]
444pub struct LLMOutput {
445    pub new_message: LLMMessage,
446    pub usage: LLMTokenUsage,
447}
448
449impl From<&LLMOutput> for ChatMessage {
450    fn from(value: &LLMOutput) -> Self {
451        let message_content = match &value.new_message.content {
452            LLMMessageContent::String(s) => s.clone(),
453            LLMMessageContent::List(l) => l
454                .iter()
455                .map(|c| match c {
456                    LLMMessageTypedContent::Text { text } => text.clone(),
457                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
458                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
459                    LLMMessageTypedContent::Image { .. } => String::new(),
460                })
461                .collect::<Vec<_>>()
462                .join("\n"),
463        };
464        let tool_calls = if let LLMMessageContent::List(items) = &value.new_message.content {
465            let calls: Vec<ToolCall> = items
466                .iter()
467                .filter_map(|item| {
468                    if let LLMMessageTypedContent::ToolCall { id, name, args } = item {
469                        Some(ToolCall {
470                            id: id.clone(),
471                            r#type: "function".to_string(),
472                            function: FunctionCall {
473                                name: name.clone(),
474                                arguments: args.to_string(),
475                            },
476                        })
477                    } else {
478                        None
479                    }
480                })
481                .collect();
482
483            if calls.is_empty() { None } else { Some(calls) }
484        } else {
485            None
486        };
487        ChatMessage {
488            role: Role::Assistant,
489            content: Some(MessageContent::String(message_content)),
490            name: None,
491            tool_calls,
492            tool_call_id: None,
493            usage: Some(value.usage.clone()),
494            ..Default::default()
495        }
496    }
497}
498
499impl AgentState {
500    pub fn new(
501        agent_model: AgentModel,
502        messages: Vec<ChatMessage>,
503        tools: Option<Vec<Tool>>,
504    ) -> Self {
505        Self {
506            agent_model,
507            messages,
508            tools,
509            llm_input: None,
510            llm_output: None,
511            metadata: None,
512        }
513    }
514
515    pub fn set_messages(&mut self, messages: Vec<ChatMessage>) {
516        self.messages = messages;
517    }
518
519    pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) {
520        self.tools = tools;
521    }
522
523    pub fn set_agent_model(&mut self, agent_model: AgentModel) {
524        self.agent_model = agent_model;
525    }
526
527    pub fn set_llm_input(&mut self, llm_input: Option<LLMInput>) {
528        self.llm_input = llm_input;
529    }
530
531    pub fn set_llm_output(&mut self, new_message: LLMMessage, new_usage: Option<LLMTokenUsage>) {
532        self.llm_output = Some(LLMOutput {
533            new_message,
534            usage: new_usage.unwrap_or_default(),
535        });
536    }
537
538    pub fn append_new_message(&mut self, new_message: ChatMessage) {
539        self.messages.push(new_message);
540    }
541}