1use std::collections::BTreeMap;
2use std::fmt::{Display, Formatter};
3use std::fs;
4use std::path::Path;
5
6use serde::{Deserialize, Serialize};
7
8use crate::json::{JsonError, JsonValue};
9use crate::usage::TokenUsage;
10
11#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
12#[serde(rename_all = "snake_case")]
13pub enum MessageRole {
14 System,
15 User,
16 Assistant,
17 Tool,
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
21#[serde(tag = "type", rename_all = "snake_case")]
22pub enum ContentBlock {
23 Text {
24 text: String,
25 },
26 ToolUse {
27 id: String,
28 name: String,
29 input: String,
30 },
31 ToolResult {
32 tool_use_id: String,
33 tool_name: String,
34 output: String,
35 is_error: bool,
36 },
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
40pub struct ConversationMessage {
41 pub role: MessageRole,
42 pub blocks: Vec<ContentBlock>,
43 pub usage: Option<TokenUsage>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
47pub struct Session {
48 pub version: u32,
49 pub messages: Vec<ConversationMessage>,
50}
51
52#[derive(Debug)]
53pub enum SessionError {
54 Io(std::io::Error),
55 Json(JsonError),
56 Format(String),
57}
58
59impl Display for SessionError {
60 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
61 match self {
62 Self::Io(error) => write!(f, "{error}"),
63 Self::Json(error) => write!(f, "{error}"),
64 Self::Format(error) => write!(f, "{error}"),
65 }
66 }
67}
68
69impl std::error::Error for SessionError {}
70
71impl From<std::io::Error> for SessionError {
72 fn from(value: std::io::Error) -> Self {
73 Self::Io(value)
74 }
75}
76
77impl From<JsonError> for SessionError {
78 fn from(value: JsonError) -> Self {
79 Self::Json(value)
80 }
81}
82
83impl Session {
84 #[must_use]
85 pub fn new() -> Self {
86 Self {
87 version: 1,
88 messages: Vec::new(),
89 }
90 }
91
92 pub fn save_to_path(&self, path: impl AsRef<Path>) -> Result<(), SessionError> {
93 fs::write(path, self.to_json().render())?;
94 Ok(())
95 }
96
97 pub fn load_from_path(path: impl AsRef<Path>) -> Result<Self, SessionError> {
98 let contents = fs::read_to_string(path)?;
99 Self::from_json(&JsonValue::parse(&contents)?)
100 }
101
102 #[must_use]
103 pub fn to_json(&self) -> JsonValue {
104 let mut object = BTreeMap::new();
105 object.insert(
106 "version".to_string(),
107 JsonValue::Number(i64::from(self.version)),
108 );
109 object.insert(
110 "messages".to_string(),
111 JsonValue::Array(
112 self.messages
113 .iter()
114 .map(ConversationMessage::to_json)
115 .collect(),
116 ),
117 );
118 JsonValue::Object(object)
119 }
120
121 pub fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
122 let object = value
123 .as_object()
124 .ok_or_else(|| SessionError::Format("session must be an object".to_string()))?;
125 let version = object
126 .get("version")
127 .and_then(JsonValue::as_i64)
128 .ok_or_else(|| SessionError::Format("missing version".to_string()))?;
129 let version = u32::try_from(version)
130 .map_err(|_| SessionError::Format("version out of range".to_string()))?;
131 let messages = object
132 .get("messages")
133 .and_then(JsonValue::as_array)
134 .ok_or_else(|| SessionError::Format("missing messages".to_string()))?
135 .iter()
136 .map(ConversationMessage::from_json)
137 .collect::<Result<Vec<_>, _>>()?;
138 Ok(Self { version, messages })
139 }
140}
141
142impl Default for Session {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
148impl ConversationMessage {
149 #[must_use]
150 pub fn user_text(text: impl Into<String>) -> Self {
151 Self {
152 role: MessageRole::User,
153 blocks: vec![ContentBlock::Text { text: text.into() }],
154 usage: None,
155 }
156 }
157
158 #[must_use]
159 pub fn assistant(blocks: Vec<ContentBlock>) -> Self {
160 Self {
161 role: MessageRole::Assistant,
162 blocks,
163 usage: None,
164 }
165 }
166
167 #[must_use]
168 pub fn assistant_with_usage(blocks: Vec<ContentBlock>, usage: Option<TokenUsage>) -> Self {
169 Self {
170 role: MessageRole::Assistant,
171 blocks,
172 usage,
173 }
174 }
175
176 #[must_use]
177 pub fn tool_result(
178 tool_use_id: impl Into<String>,
179 tool_name: impl Into<String>,
180 output: impl Into<String>,
181 is_error: bool,
182 ) -> Self {
183 Self {
184 role: MessageRole::Tool,
185 blocks: vec![ContentBlock::ToolResult {
186 tool_use_id: tool_use_id.into(),
187 tool_name: tool_name.into(),
188 output: output.into(),
189 is_error,
190 }],
191 usage: None,
192 }
193 }
194
195 #[must_use]
196 pub fn to_json(&self) -> JsonValue {
197 let mut object = BTreeMap::new();
198 object.insert(
199 "role".to_string(),
200 JsonValue::String(
201 match self.role {
202 MessageRole::System => "system",
203 MessageRole::User => "user",
204 MessageRole::Assistant => "assistant",
205 MessageRole::Tool => "tool",
206 }
207 .to_string(),
208 ),
209 );
210 object.insert(
211 "blocks".to_string(),
212 JsonValue::Array(self.blocks.iter().map(ContentBlock::to_json).collect()),
213 );
214 if let Some(usage) = self.usage {
215 object.insert("usage".to_string(), usage_to_json(usage));
216 }
217 JsonValue::Object(object)
218 }
219
220 fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
221 let object = value
222 .as_object()
223 .ok_or_else(|| SessionError::Format("message must be an object".to_string()))?;
224 let role = match object
225 .get("role")
226 .and_then(JsonValue::as_str)
227 .ok_or_else(|| SessionError::Format("missing role".to_string()))?
228 {
229 "system" => MessageRole::System,
230 "user" => MessageRole::User,
231 "assistant" => MessageRole::Assistant,
232 "tool" => MessageRole::Tool,
233 other => {
234 return Err(SessionError::Format(format!(
235 "unsupported message role: {other}"
236 )))
237 }
238 };
239 let blocks = object
240 .get("blocks")
241 .and_then(JsonValue::as_array)
242 .ok_or_else(|| SessionError::Format("missing blocks".to_string()))?
243 .iter()
244 .map(ContentBlock::from_json)
245 .collect::<Result<Vec<_>, _>>()?;
246 let usage = object.get("usage").map(usage_from_json).transpose()?;
247 Ok(Self {
248 role,
249 blocks,
250 usage,
251 })
252 }
253}
254
255impl ContentBlock {
256 #[must_use]
257 pub fn to_json(&self) -> JsonValue {
258 let mut object = BTreeMap::new();
259 match self {
260 Self::Text { text } => {
261 object.insert("type".to_string(), JsonValue::String("text".to_string()));
262 object.insert("text".to_string(), JsonValue::String(text.clone()));
263 }
264 Self::ToolUse { id, name, input } => {
265 object.insert(
266 "type".to_string(),
267 JsonValue::String("tool_use".to_string()),
268 );
269 object.insert("id".to_string(), JsonValue::String(id.clone()));
270 object.insert("name".to_string(), JsonValue::String(name.clone()));
271 object.insert("input".to_string(), JsonValue::String(input.clone()));
272 }
273 Self::ToolResult {
274 tool_use_id,
275 tool_name,
276 output,
277 is_error,
278 } => {
279 object.insert(
280 "type".to_string(),
281 JsonValue::String("tool_result".to_string()),
282 );
283 object.insert(
284 "tool_use_id".to_string(),
285 JsonValue::String(tool_use_id.clone()),
286 );
287 object.insert(
288 "tool_name".to_string(),
289 JsonValue::String(tool_name.clone()),
290 );
291 object.insert("output".to_string(), JsonValue::String(output.clone()));
292 object.insert("is_error".to_string(), JsonValue::Bool(*is_error));
293 }
294 }
295 JsonValue::Object(object)
296 }
297
298 fn from_json(value: &JsonValue) -> Result<Self, SessionError> {
299 let object = value
300 .as_object()
301 .ok_or_else(|| SessionError::Format("block must be an object".to_string()))?;
302 match object
303 .get("type")
304 .and_then(JsonValue::as_str)
305 .ok_or_else(|| SessionError::Format("missing block type".to_string()))?
306 {
307 "text" => Ok(Self::Text {
308 text: required_string(object, "text")?,
309 }),
310 "tool_use" => Ok(Self::ToolUse {
311 id: required_string(object, "id")?,
312 name: required_string(object, "name")?,
313 input: required_string(object, "input")?,
314 }),
315 "tool_result" => Ok(Self::ToolResult {
316 tool_use_id: required_string(object, "tool_use_id")?,
317 tool_name: required_string(object, "tool_name")?,
318 output: required_string(object, "output")?,
319 is_error: object
320 .get("is_error")
321 .and_then(JsonValue::as_bool)
322 .ok_or_else(|| SessionError::Format("missing is_error".to_string()))?,
323 }),
324 other => Err(SessionError::Format(format!(
325 "unsupported block type: {other}"
326 ))),
327 }
328 }
329}
330
331fn usage_to_json(usage: TokenUsage) -> JsonValue {
332 let mut object = BTreeMap::new();
333 object.insert(
334 "input_tokens".to_string(),
335 JsonValue::Number(i64::from(usage.input_tokens)),
336 );
337 object.insert(
338 "output_tokens".to_string(),
339 JsonValue::Number(i64::from(usage.output_tokens)),
340 );
341 object.insert(
342 "cache_creation_input_tokens".to_string(),
343 JsonValue::Number(i64::from(usage.cache_creation_input_tokens)),
344 );
345 object.insert(
346 "cache_read_input_tokens".to_string(),
347 JsonValue::Number(i64::from(usage.cache_read_input_tokens)),
348 );
349 JsonValue::Object(object)
350}
351
352fn usage_from_json(value: &JsonValue) -> Result<TokenUsage, SessionError> {
353 let object = value
354 .as_object()
355 .ok_or_else(|| SessionError::Format("usage must be an object".to_string()))?;
356 Ok(TokenUsage {
357 input_tokens: required_u32(object, "input_tokens")?,
358 output_tokens: required_u32(object, "output_tokens")?,
359 cache_creation_input_tokens: required_u32(object, "cache_creation_input_tokens")?,
360 cache_read_input_tokens: required_u32(object, "cache_read_input_tokens")?,
361 })
362}
363
364fn required_string(
365 object: &BTreeMap<String, JsonValue>,
366 key: &str,
367) -> Result<String, SessionError> {
368 object
369 .get(key)
370 .and_then(JsonValue::as_str)
371 .map(ToOwned::to_owned)
372 .ok_or_else(|| SessionError::Format(format!("missing {key}")))
373}
374
375fn required_u32(object: &BTreeMap<String, JsonValue>, key: &str) -> Result<u32, SessionError> {
376 let value = object
377 .get(key)
378 .and_then(JsonValue::as_i64)
379 .ok_or_else(|| SessionError::Format(format!("missing {key}")))?;
380 u32::try_from(value).map_err(|_| SessionError::Format(format!("{key} out of range")))
381}
382
#[cfg(test)]
mod tests {
    use super::{ContentBlock, ConversationMessage, MessageRole, Session};
    use crate::json::JsonValue;
    use crate::usage::TokenUsage;
    use std::collections::BTreeMap;
    use std::fs;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Round-trips a session covering every message kind through a real file.
    #[test]
    fn persists_and_restores_session_json() {
        let mut session = Session::new();
        session
            .messages
            .push(ConversationMessage::user_text("hello"));
        session
            .messages
            .push(ConversationMessage::assistant_with_usage(
                vec![
                    ContentBlock::Text {
                        text: "thinking".to_string(),
                    },
                    ContentBlock::ToolUse {
                        id: "tool-1".to_string(),
                        name: "bash".to_string(),
                        input: "echo hi".to_string(),
                    },
                ],
                Some(TokenUsage {
                    input_tokens: 10,
                    output_tokens: 4,
                    cache_creation_input_tokens: 1,
                    cache_read_input_tokens: 2,
                }),
            ));
        session.messages.push(ConversationMessage::tool_result(
            "tool-1", "bash", "hi", false,
        ));

        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time should be after epoch")
            .as_nanos();
        let path = std::env::temp_dir().join(format!("runtime-session-{nanos}.json"));
        session.save_to_path(&path).expect("session should save");
        // Remove the file before unwrapping the load result so a decode failure
        // does not leak the temp file.
        let loaded = Session::load_from_path(&path);
        fs::remove_file(&path).expect("temp file should be removable");
        let restored = loaded.expect("session should load");

        assert_eq!(restored, session);
        assert_eq!(restored.messages[2].role, MessageRole::Tool);
        assert_eq!(
            restored.messages[1].usage.expect("usage").total_tokens(),
            17
        );
    }

    /// Decoding must fail loudly on a role string it does not recognize.
    #[test]
    fn rejects_unknown_message_role() {
        let mut message = BTreeMap::new();
        message.insert("role".to_string(), JsonValue::String("robot".to_string()));
        message.insert("blocks".to_string(), JsonValue::Array(Vec::new()));

        let mut encoded = BTreeMap::new();
        encoded.insert("version".to_string(), JsonValue::Number(1));
        encoded.insert(
            "messages".to_string(),
            JsonValue::Array(vec![JsonValue::Object(message)]),
        );

        let error = Session::from_json(&JsonValue::Object(encoded))
            .expect_err("unknown role should be rejected");
        assert_eq!(error.to_string(), "unsupported message role: robot");
    }
}