Skip to main content

wraith_api/
types.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
5pub struct MessageRequest {
6    pub model: String,
7    pub max_tokens: u32,
8    pub messages: Vec<InputMessage>,
9    #[serde(skip_serializing_if = "Option::is_none")]
10    pub system: Option<String>,
11    #[serde(skip_serializing_if = "Option::is_none")]
12    pub tools: Option<Vec<ToolDefinition>>,
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub tool_choice: Option<ToolChoice>,
15    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
16    pub stream: bool,
17}
18
19impl MessageRequest {
20    #[must_use]
21    pub fn with_streaming(mut self) -> Self {
22        self.stream = true;
23        self
24    }
25}
26
27#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
28pub struct InputMessage {
29    pub role: String,
30    pub content: Vec<InputContentBlock>,
31}
32
33impl InputMessage {
34    #[must_use]
35    pub fn user_text(text: impl Into<String>) -> Self {
36        Self {
37            role: "user".to_string(),
38            content: vec![InputContentBlock::Text { text: text.into() }],
39        }
40    }
41
42    #[must_use]
43    pub fn user_tool_result(
44        tool_use_id: impl Into<String>,
45        content: impl Into<String>,
46        is_error: bool,
47    ) -> Self {
48        Self {
49            role: "user".to_string(),
50            content: vec![InputContentBlock::ToolResult {
51                tool_use_id: tool_use_id.into(),
52                content: vec![ToolResultContentBlock::Text {
53                    text: content.into(),
54                }],
55                is_error,
56            }],
57        }
58    }
59}
60
61#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
62#[serde(tag = "type", rename_all = "snake_case")]
63pub enum InputContentBlock {
64    Text {
65        text: String,
66    },
67    ToolUse {
68        id: String,
69        name: String,
70        input: Value,
71    },
72    ToolResult {
73        tool_use_id: String,
74        content: Vec<ToolResultContentBlock>,
75        #[serde(default, skip_serializing_if = "std::ops::Not::not")]
76        is_error: bool,
77    },
78}
79
80#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
81#[serde(tag = "type", rename_all = "snake_case")]
82pub enum ToolResultContentBlock {
83    Text { text: String },
84    Json { value: Value },
85}
86
87#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
88pub struct ToolDefinition {
89    pub name: String,
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub description: Option<String>,
92    pub input_schema: Value,
93}
94
95#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
96#[serde(tag = "type", rename_all = "snake_case")]
97pub enum ToolChoice {
98    Auto,
99    Any,
100    Tool { name: String },
101}
102
103#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
104pub struct MessageResponse {
105    pub id: String,
106    #[serde(rename = "type")]
107    pub kind: String,
108    pub role: String,
109    pub content: Vec<OutputContentBlock>,
110    pub model: String,
111    #[serde(default)]
112    pub stop_reason: Option<String>,
113    #[serde(default)]
114    pub stop_sequence: Option<String>,
115    pub usage: Usage,
116    #[serde(default)]
117    pub request_id: Option<String>,
118}
119
120impl MessageResponse {
121    #[must_use]
122    pub fn total_tokens(&self) -> u32 {
123        self.usage.total_tokens()
124    }
125}
126
127#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
128#[serde(tag = "type", rename_all = "snake_case")]
129pub enum OutputContentBlock {
130    Text {
131        text: String,
132    },
133    ToolUse {
134        id: String,
135        name: String,
136        input: Value,
137    },
138    Thinking {
139        #[serde(default)]
140        thinking: String,
141        #[serde(default, skip_serializing_if = "Option::is_none")]
142        signature: Option<String>,
143    },
144    RedactedThinking {
145        data: Value,
146    },
147}
148
149#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
150pub struct Usage {
151    pub input_tokens: u32,
152    #[serde(default)]
153    pub cache_creation_input_tokens: u32,
154    #[serde(default)]
155    pub cache_read_input_tokens: u32,
156    pub output_tokens: u32,
157}
158
159impl Usage {
160    #[must_use]
161    pub const fn total_tokens(&self) -> u32 {
162        self.input_tokens + self.output_tokens
163    }
164}
165
166#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
167pub struct MessageStartEvent {
168    pub message: MessageResponse,
169}
170
171#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
172pub struct MessageDeltaEvent {
173    pub delta: MessageDelta,
174    pub usage: Usage,
175}
176
177#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
178pub struct MessageDelta {
179    #[serde(default)]
180    pub stop_reason: Option<String>,
181    #[serde(default)]
182    pub stop_sequence: Option<String>,
183}
184
185#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
186pub struct ContentBlockStartEvent {
187    pub index: u32,
188    pub content_block: OutputContentBlock,
189}
190
191#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
192pub struct ContentBlockDeltaEvent {
193    pub index: u32,
194    pub delta: ContentBlockDelta,
195}
196
197#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
198#[serde(tag = "type", rename_all = "snake_case")]
199pub enum ContentBlockDelta {
200    TextDelta { text: String },
201    InputJsonDelta { partial_json: String },
202    ThinkingDelta { thinking: String },
203    SignatureDelta { signature: String },
204}
205
206#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
207pub struct ContentBlockStopEvent {
208    pub index: u32,
209}
210
211#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
212pub struct MessageStopEvent {}
213
214#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
215#[serde(tag = "type", rename_all = "snake_case")]
216pub enum StreamEvent {
217    MessageStart(MessageStartEvent),
218    MessageDelta(MessageDeltaEvent),
219    ContentBlockStart(ContentBlockStartEvent),
220    ContentBlockDelta(ContentBlockDeltaEvent),
221    ContentBlockStop(ContentBlockStopEvent),
222    MessageStop(MessageStopEvent),
223}