Skip to main content

stakpak_api/
models.rs

use chrono::{DateTime, Utc};
use rmcp::model::Content;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use stakai::Model;
use stakpak_shared::models::{
    integrations::openai::{ChatMessage, FunctionCall, MessageContent, Role, Tool, ToolCall},
    llm::{LLMInput, LLMMessage, LLMMessageContent, LLMMessageTypedContent, LLMTokenUsage},
};
use std::collections::HashMap;
use uuid::Uuid;
12
13#[derive(Debug, Clone, Deserialize, Serialize)]
14pub enum ApiStreamError {
15    AgentInputInvalid(String),
16    AgentStateInvalid,
17    AgentNotSupported,
18    AgentExecutionLimitExceeded,
19    AgentInvalidResponseStream,
20    InvalidGeneratedCode,
21    CopilotError,
22    SaveError,
23    Unknown(String),
24}
25
26impl From<&str> for ApiStreamError {
27    fn from(error_str: &str) -> Self {
28        match error_str {
29            s if s.contains("Agent not supported") => ApiStreamError::AgentNotSupported,
30            s if s.contains("Agent state is not valid") => ApiStreamError::AgentStateInvalid,
31            s if s.contains("Agent thinking limit exceeded") => {
32                ApiStreamError::AgentExecutionLimitExceeded
33            }
34            s if s.contains("Invalid response stream") => {
35                ApiStreamError::AgentInvalidResponseStream
36            }
37            s if s.contains("Invalid generated code") => ApiStreamError::InvalidGeneratedCode,
38            s if s.contains(
39                "Our copilot is handling too many requests at this time, please try again later.",
40            ) =>
41            {
42                ApiStreamError::CopilotError
43            }
44            s if s
45                .contains("An error occurred while saving your data. Please try again later.") =>
46            {
47                ApiStreamError::SaveError
48            }
49            s if s.contains("Agent input is not valid: ") => {
50                ApiStreamError::AgentInputInvalid(s.replace("Agent input is not valid: ", ""))
51            }
52            _ => ApiStreamError::Unknown(error_str.to_string()),
53        }
54    }
55}
56
57impl From<String> for ApiStreamError {
58    fn from(error_str: String) -> Self {
59        ApiStreamError::from(error_str.as_str())
60    }
61}
62
63#[derive(Deserialize, Serialize, Debug)]
64pub struct Document {
65    pub content: String,
66    pub uri: String,
67    pub provisioner: ProvisionerType,
68}
69
70#[derive(Deserialize, Serialize, Debug)]
71pub struct SimpleDocument {
72    pub uri: String,
73    pub content: String,
74}
75
76#[derive(Deserialize, Serialize, Debug, Clone)]
77pub struct Block {
78    pub id: Uuid,
79    pub provider: String,
80    pub provisioner: ProvisionerType,
81    pub language: String,
82    pub key: String,
83    pub digest: u64,
84    pub references: Vec<Vec<Segment>>,
85    pub kind: String,
86    pub r#type: Option<String>,
87    pub name: Option<String>,
88    pub config: serde_json::Value,
89    pub document_uri: String,
90    pub code: String,
91    pub start_byte: usize,
92    pub end_byte: usize,
93    pub start_point: Point,
94    pub end_point: Point,
95    pub state: Option<serde_json::Value>,
96    pub updated_at: Option<DateTime<Utc>>,
97    pub created_at: Option<DateTime<Utc>>,
98    pub dependents: Vec<DependentBlock>,
99    pub dependencies: Vec<Dependency>,
100    pub api_group_version: Option<ApiGroupVersion>,
101
102    pub generated_summary: Option<String>,
103}
104
105impl Block {
106    pub fn get_uri(&self) -> String {
107        format!(
108            "{}#L{}-L{}",
109            self.document_uri, self.start_point.row, self.end_point.row
110        )
111    }
112}
113
114#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)]
115pub enum ProvisionerType {
116    #[serde(rename = "Terraform")]
117    Terraform,
118    #[serde(rename = "Kubernetes")]
119    Kubernetes,
120    #[serde(rename = "Dockerfile")]
121    Dockerfile,
122    #[serde(rename = "GithubActions")]
123    GithubActions,
124    #[serde(rename = "None")]
125    None,
126}
127impl std::str::FromStr for ProvisionerType {
128    type Err = String;
129
130    fn from_str(s: &str) -> Result<Self, Self::Err> {
131        match s.to_lowercase().as_str() {
132            "terraform" => Ok(Self::Terraform),
133            "kubernetes" => Ok(Self::Kubernetes),
134            "dockerfile" => Ok(Self::Dockerfile),
135            "github-actions" => Ok(Self::GithubActions),
136            _ => Ok(Self::None),
137        }
138    }
139}
140impl std::fmt::Display for ProvisionerType {
141    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
142        match self {
143            ProvisionerType::Terraform => write!(f, "terraform"),
144            ProvisionerType::Kubernetes => write!(f, "kubernetes"),
145            ProvisionerType::Dockerfile => write!(f, "dockerfile"),
146            ProvisionerType::GithubActions => write!(f, "github-actions"),
147            ProvisionerType::None => write!(f, "none"),
148        }
149    }
150}
151
152#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
153#[serde(untagged)]
154pub enum Segment {
155    Key(String),
156    Index(usize),
157}
158
159impl std::fmt::Display for Segment {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        match self {
162            Segment::Key(key) => write!(f, "{}", key),
163            Segment::Index(index) => write!(f, "{}", index),
164        }
165    }
166}
167impl std::fmt::Debug for Segment {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        match self {
170            Segment::Key(key) => write!(f, "{}", key),
171            Segment::Index(index) => write!(f, "{}", index),
172        }
173    }
174}
175
176#[derive(Serialize, Deserialize, Debug, Default, Clone, PartialEq)]
177pub struct Point {
178    pub row: usize,
179    pub column: usize,
180}
181
182#[derive(Deserialize, Serialize, Debug, Clone)]
183pub struct DependentBlock {
184    pub key: String,
185}
186
187#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
188pub struct Dependency {
189    pub id: Option<Uuid>,
190    pub expression: Option<String>,
191    pub from_path: Option<Vec<Segment>>,
192    pub to_path: Option<Vec<Segment>>,
193    #[serde(default = "Vec::new")]
194    pub selectors: Vec<DependencySelector>,
195    #[serde(skip_serializing)]
196    pub key: Option<String>,
197    pub digest: Option<u64>,
198    #[serde(default = "Vec::new")]
199    pub from: Vec<Segment>,
200    pub from_field: Option<Vec<Segment>>,
201    pub to_field: Option<Vec<Segment>>,
202    pub start_byte: Option<usize>,
203    pub end_byte: Option<usize>,
204    pub start_point: Option<Point>,
205    pub end_point: Option<Point>,
206    pub satisfied: bool,
207}
208
209#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
210pub struct DependencySelector {
211    pub references: Vec<Vec<Segment>>,
212    pub operator: DependencySelectorOperator,
213}
214
215#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
216pub enum DependencySelectorOperator {
217    Equals,
218    NotEquals,
219    In,
220    NotIn,
221    Exists,
222    DoesNotExist,
223}
224
225#[derive(Serialize, Deserialize, Debug, Clone)]
226pub struct ApiGroupVersion {
227    pub alias: String,
228    pub group: String,
229    pub version: String,
230    pub provisioner: ProvisionerType,
231    pub status: APIGroupVersionStatus,
232}
233
234#[derive(Serialize, Deserialize, Debug, Clone)]
235pub enum APIGroupVersionStatus {
236    #[serde(rename = "UNAVAILABLE")]
237    Unavailable,
238    #[serde(rename = "PENDING")]
239    Pending,
240    #[serde(rename = "AVAILABLE")]
241    Available,
242}
243
244#[derive(Serialize, Deserialize, Debug)]
245pub struct BuildCodeIndexInput {
246    pub documents: Vec<SimpleDocument>,
247}
248
249#[derive(Serialize, Deserialize, Debug, Clone)]
250pub struct IndexError {
251    pub uri: String,
252    pub message: String,
253    pub details: Option<serde_json::Value>,
254}
255
256#[derive(Serialize, Deserialize, Debug, Clone)]
257pub struct BuildCodeIndexOutput {
258    pub blocks: Vec<Block>,
259    pub errors: Vec<IndexError>,
260    pub warnings: Vec<IndexError>,
261}
262
263#[derive(Serialize, Deserialize, Debug, Clone)]
264pub struct CodeIndex {
265    pub last_updated: DateTime<Utc>,
266    pub index: BuildCodeIndexOutput,
267}
268
269/// Unified skill type representing knowledge from any source.
270#[derive(Serialize, Deserialize, Debug, Clone)]
271pub struct Skill {
272    /// Display name / identifier
273    pub name: String,
274    /// Unique identifier — API URI for remote, file path for local
275    pub uri: String,
276    /// When to use this skill
277    pub description: String,
278    /// Where this skill comes from
279    pub source: SkillSource,
280    /// None = metadata only; Some = full content loaded
281    pub content: Option<String>,
282    /// Classification tags
283    pub tags: Vec<String>,
284    /// License name or reference to a bundled license file.
285    #[serde(skip_serializing_if = "Option::is_none")]
286    pub license: Option<String>,
287    /// Environment requirements (intended product, system packages, network access, etc.).
288    #[serde(skip_serializing_if = "Option::is_none")]
289    pub compatibility: Option<String>,
290    /// Arbitrary key-value mapping for additional metadata.
291    #[serde(skip_serializing_if = "Option::is_none")]
292    pub metadata: Option<HashMap<String, String>>,
293    /// Space-delimited list of pre-approved tools the skill may use. (Experimental)
294    #[serde(skip_serializing_if = "Option::is_none")]
295    pub allowed_tools: Option<String>,
296}
297
298#[derive(Serialize, Deserialize, Debug, Clone)]
299pub enum SkillSource {
300    Local,
301    Remote { provider: RemoteProvider },
302}
303
304#[derive(Serialize, Deserialize, Debug, Clone)]
305pub enum RemoteProvider {
306    Rulebook { visibility: RuleBookVisibility },
307    Pak,
308}
309
310impl Skill {
311    pub fn is_local(&self) -> bool {
312        matches!(self.source, SkillSource::Local)
313    }
314
315    pub fn is_rulebook(&self) -> bool {
316        matches!(
317            self.source,
318            SkillSource::Remote {
319                provider: RemoteProvider::Rulebook { .. }
320            }
321        )
322    }
323
324    pub fn is_pak(&self) -> bool {
325        matches!(
326            self.source,
327            SkillSource::Remote {
328                provider: RemoteProvider::Pak
329            }
330        )
331    }
332
333    pub fn to_metadata_text(&self) -> String {
334        let source_label = match &self.source {
335            SkillSource::Local => "local",
336            SkillSource::Remote {
337                provider: RemoteProvider::Rulebook { .. },
338            } => "remote",
339            SkillSource::Remote {
340                provider: RemoteProvider::Pak,
341            } => "pak",
342        };
343        let tags_str = if self.tags.is_empty() {
344            String::new()
345        } else {
346            format!(" [{}]", self.tags.join(", "))
347        };
348        format!(
349            "<skill><label>{}</label><name>{}</name><description>{}</description><uri>{}</uri><tags>{}</tags></skill>",
350            source_label, self.name, self.description, self.uri, tags_str
351        )
352    }
353}
354
355impl From<ListRuleBook> for Skill {
356    fn from(rb: ListRuleBook) -> Self {
357        Skill {
358            name: rb.uri.clone(),
359            uri: rb.uri,
360            description: rb.description,
361            source: SkillSource::Remote {
362                provider: RemoteProvider::Rulebook {
363                    visibility: rb.visibility,
364                },
365            },
366            content: None,
367            tags: rb.tags,
368            license: None,
369            compatibility: None,
370            metadata: None,
371            allowed_tools: None,
372        }
373    }
374}
375
376impl From<RuleBook> for Skill {
377    fn from(rb: RuleBook) -> Self {
378        Skill {
379            name: rb.uri.clone(),
380            uri: rb.uri,
381            description: rb.description,
382            source: SkillSource::Remote {
383                provider: RemoteProvider::Rulebook {
384                    visibility: rb.visibility,
385                },
386            },
387            content: Some(rb.content),
388            tags: rb.tags,
389            license: None,
390            compatibility: None,
391            metadata: None,
392            allowed_tools: None,
393        }
394    }
395}
396
397#[derive(Serialize, Deserialize, Debug, Clone, Copy, Default)]
398#[serde(rename_all = "UPPERCASE")]
399pub enum RuleBookVisibility {
400    #[default]
401    Public,
402    Private,
403}
404
405#[derive(Serialize, Deserialize, Debug, Clone)]
406pub struct RuleBook {
407    pub id: String,
408    pub uri: String,
409    pub description: String,
410    pub content: String,
411    pub visibility: RuleBookVisibility,
412    pub tags: Vec<String>,
413    pub created_at: Option<DateTime<Utc>>,
414    pub updated_at: Option<DateTime<Utc>>,
415}
416
417#[derive(Serialize, Deserialize, Debug)]
418pub struct ToolsCallParams {
419    pub name: String,
420    pub arguments: Value,
421}
422
423#[derive(Serialize, Deserialize, Debug)]
424pub struct ToolsCallResponse {
425    pub content: Vec<Content>,
426}
427
428#[derive(Serialize, Deserialize, Debug, Clone)]
429pub struct APIKeyScope {
430    pub r#type: String,
431    pub name: String,
432}
433
434impl std::fmt::Display for APIKeyScope {
435    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
436        write!(f, "{} ({})", self.name, self.r#type)
437    }
438}
439
440#[derive(Serialize, Deserialize, Clone, Debug)]
441pub struct GetMyAccountResponse {
442    pub username: String,
443    pub id: String,
444    pub first_name: String,
445    pub last_name: String,
446    pub email: String,
447    pub scope: Option<APIKeyScope>,
448}
449
450impl GetMyAccountResponse {
451    pub fn to_text(&self) -> String {
452        format!(
453            "ID: {}\nUsername: {}\nName: {} {}\nEmail: {}",
454            self.id, self.username, self.first_name, self.last_name, self.email
455        )
456    }
457}
458
459#[derive(Serialize, Deserialize, Debug, Clone)]
460pub struct ListRuleBook {
461    pub id: String,
462    pub uri: String,
463    pub description: String,
464    pub visibility: RuleBookVisibility,
465    pub tags: Vec<String>,
466    pub created_at: Option<DateTime<Utc>>,
467    pub updated_at: Option<DateTime<Utc>>,
468}
469
470#[derive(Serialize, Deserialize, Debug)]
471pub struct ListRulebooksResponse {
472    pub results: Vec<ListRuleBook>,
473}
474
475#[derive(Serialize, Deserialize, Debug)]
476pub struct CreateRuleBookInput {
477    pub uri: String,
478    pub description: String,
479    pub content: String,
480    pub tags: Vec<String>,
481    #[serde(skip_serializing_if = "Option::is_none")]
482    pub visibility: Option<RuleBookVisibility>,
483}
484
485#[derive(Serialize, Deserialize, Debug)]
486pub struct CreateRuleBookResponse {
487    pub id: String,
488}
489
490impl ListRuleBook {
491    pub fn to_text(&self) -> String {
492        format!(
493            "URI: {}\nDescription: {}\nTags: {}\n",
494            self.uri,
495            self.description,
496            self.tags.join(", ")
497        )
498    }
499}
500
501#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
502pub struct SimpleLLMMessage {
503    #[serde(rename = "role")]
504    pub role: SimpleLLMRole,
505    pub content: String,
506}
507
508#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
509#[serde(rename_all = "lowercase")]
510pub enum SimpleLLMRole {
511    User,
512    Assistant,
513}
514
515impl std::fmt::Display for SimpleLLMRole {
516    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
517        match self {
518            SimpleLLMRole::User => write!(f, "user"),
519            SimpleLLMRole::Assistant => write!(f, "assistant"),
520        }
521    }
522}
523
524#[derive(Debug, Deserialize, Serialize)]
525pub struct SearchDocsRequest {
526    pub keywords: String,
527    pub exclude_keywords: Option<String>,
528    pub limit: Option<u32>,
529}
530
531#[derive(Debug, Deserialize, Serialize)]
532pub struct SearchMemoryRequest {
533    pub keywords: Vec<String>,
534    pub start_time: Option<DateTime<Utc>>,
535    pub end_time: Option<DateTime<Utc>>,
536}
537
538#[derive(Debug, Deserialize, Serialize)]
539pub struct SlackReadMessagesRequest {
540    pub channel: String,
541    pub limit: Option<u32>,
542}
543
544#[derive(Debug, Deserialize, Serialize)]
545pub struct SlackReadRepliesRequest {
546    pub channel: String,
547    pub ts: String,
548}
549
550#[derive(Debug, Deserialize, Serialize)]
551pub struct SlackSendMessageRequest {
552    pub channel: String,
553    pub markdown_text: String,
554    pub thread_ts: Option<String>,
555}
556
557#[derive(Debug, Clone, Default, Serialize)]
558pub struct AgentState {
559    /// The active model to use for inference
560    pub active_model: Model,
561    pub messages: Vec<ChatMessage>,
562    pub tools: Option<Vec<Tool>>,
563
564    pub llm_input: Option<LLMInput>,
565    pub llm_output: Option<LLMOutput>,
566
567    /// Metadata for checkpoint persistence (context trimming state, etc.)
568    /// Loaded from checkpoint on session resume and saved back after inference
569    pub metadata: Option<Value>,
570}
571
572#[derive(Debug, Clone, Default, Serialize)]
573pub struct LLMOutput {
574    pub new_message: LLMMessage,
575    pub usage: LLMTokenUsage,
576}
577
578impl From<&LLMOutput> for ChatMessage {
579    fn from(value: &LLMOutput) -> Self {
580        let message_content = match &value.new_message.content {
581            LLMMessageContent::String(s) => s.clone(),
582            LLMMessageContent::List(l) => l
583                .iter()
584                .map(|c| match c {
585                    LLMMessageTypedContent::Text { text } => text.clone(),
586                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
587                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
588                    LLMMessageTypedContent::Image { .. } => String::new(),
589                })
590                .collect::<Vec<_>>()
591                .join("\n"),
592        };
593        let tool_calls = if let LLMMessageContent::List(items) = &value.new_message.content {
594            let calls: Vec<ToolCall> = items
595                .iter()
596                .filter_map(|item| {
597                    if let LLMMessageTypedContent::ToolCall {
598                        id,
599                        name,
600                        args,
601                        metadata,
602                    } = item
603                    {
604                        Some(ToolCall {
605                            id: id.clone(),
606                            r#type: "function".to_string(),
607                            function: FunctionCall {
608                                name: name.clone(),
609                                arguments: args.to_string(),
610                            },
611                            metadata: metadata.clone(),
612                        })
613                    } else {
614                        None
615                    }
616                })
617                .collect();
618
619            if calls.is_empty() { None } else { Some(calls) }
620        } else {
621            None
622        };
623        ChatMessage {
624            role: Role::Assistant,
625            content: Some(MessageContent::String(message_content)),
626            name: None,
627            tool_calls,
628            tool_call_id: None,
629            usage: Some(value.usage.clone()),
630            ..Default::default()
631        }
632    }
633}
634
635impl AgentState {
636    pub fn new(
637        active_model: Model,
638        messages: Vec<ChatMessage>,
639        tools: Option<Vec<Tool>>,
640        metadata: Option<Value>,
641    ) -> Self {
642        Self {
643            active_model,
644            messages,
645            tools,
646            metadata,
647            llm_input: None,
648            llm_output: None,
649        }
650    }
651
652    pub fn set_messages(&mut self, messages: Vec<ChatMessage>) {
653        self.messages = messages;
654    }
655
656    pub fn set_tools(&mut self, tools: Option<Vec<Tool>>) {
657        self.tools = tools;
658    }
659
660    pub fn set_active_model(&mut self, model: Model) {
661        self.active_model = model;
662    }
663
664    pub fn set_llm_input(&mut self, llm_input: Option<LLMInput>) {
665        self.llm_input = llm_input;
666    }
667
668    pub fn set_llm_output(&mut self, new_message: LLMMessage, new_usage: Option<LLMTokenUsage>) {
669        self.llm_output = Some(LLMOutput {
670            new_message,
671            usage: new_usage.unwrap_or_default(),
672        });
673    }
674
675    pub fn append_new_message(&mut self, new_message: ChatMessage) {
676        self.messages.push(new_message);
677    }
678}