1use std::collections::HashMap;
2
3use serde_json::{json, Value};
4use tandem_types::{Message, MessagePart, ModelSpec, ProviderInfo, Session};
5
6use crate::{
7 WireMessageInfo, WireMessagePart, WireMessageTime, WireModelSpec, WireProviderCatalog,
8 WireProviderEntry, WireProviderModel, WireProviderModelLimit, WireSession, WireSessionMessage,
9 WireSessionTime,
10};
11use std::sync::atomic::{AtomicU64, Ordering};
12
/// Process-wide sequence backing [`next_part_id`]; the first id handed out is `part-1`.
static PART_SEQ: AtomicU64 = AtomicU64::new(1);

/// Allocates a new part identifier of the form `part-<n>`.
///
/// `n` comes from an atomic post-increment of [`PART_SEQ`], so concurrent callers never
/// observe the same number within one process. The counter is not persisted, so the
/// numbering restarts when the process does.
fn next_part_id() -> String {
    // Relaxed is enough: only the atomicity of the increment matters, not ordering
    // relative to other memory operations.
    let seq = PART_SEQ.fetch_add(1, Ordering::Relaxed);
    format!("part-{}", seq)
}
18
19fn to_epoch_seconds(dt: chrono::DateTime<chrono::Utc>) -> u64 {
20 dt.timestamp().max(0) as u64
21}
22
23impl From<ModelSpec> for WireModelSpec {
24 fn from(value: ModelSpec) -> Self {
25 Self {
26 provider_id: value.provider_id,
27 model_id: value.model_id,
28 }
29 }
30}
31
32impl From<Session> for WireSession {
33 fn from(value: Session) -> Self {
34 let session_id = value.id.clone();
35 Self {
36 id: value.id,
37 slug: value.slug,
38 version: value.version,
39 project_id: value.project_id,
40 directory: Some(value.directory),
41 workspace_root: value.workspace_root,
42 origin_workspace_root: value.origin_workspace_root,
43 attached_from_workspace: value.attached_from_workspace,
44 attached_to_workspace: value.attached_to_workspace,
45 attach_timestamp_ms: value.attach_timestamp_ms,
46 attach_reason: value.attach_reason,
47 title: value.title,
48 time: Some(WireSessionTime {
49 created: to_epoch_seconds(value.time.created),
50 updated: to_epoch_seconds(value.time.updated),
51 }),
52 model: value.model.map(Into::into),
53 provider: value.provider,
54 environment: value.environment,
55 messages: value
56 .messages
57 .into_iter()
58 .map(|m| WireSessionMessage::from_message(&m, &session_id))
59 .collect(),
60 }
61 }
62}
63
64impl WireSessionMessage {
65 pub fn from_message(msg: &Message, session_id: &str) -> Self {
66 let info = WireMessageInfo {
67 id: msg.id.clone(),
68 session_id: session_id.to_string(),
69 role: format!("{:?}", msg.role).to_lowercase(),
70 time: WireMessageTime {
71 created: to_epoch_seconds(msg.created_at),
72 completed: None,
73 },
74 summary: None,
75 agent: None,
76 model: None,
77 deleted: None,
78 reverted: None,
79 };
80
81 let parts = msg.parts.iter().map(message_part_to_value).collect();
82 Self { info, parts }
83 }
84}
85
86fn message_part_to_value(part: &MessagePart) -> Value {
87 match part {
88 MessagePart::Text { text } => json!({"type":"text","text":text}),
89 MessagePart::Reasoning { text } => json!({"type":"reasoning","text":text}),
90 MessagePart::ToolInvocation {
91 tool,
92 args,
93 result,
94 error,
95 } => json!({
96 "type":"tool",
97 "tool": tool,
98 "args": args,
99 "result": result,
100 "error": error
101 }),
102 }
103}
104
105impl WireProviderCatalog {
106 pub fn from_providers(providers: Vec<ProviderInfo>, connected: Vec<String>) -> Self {
107 let all = providers
108 .into_iter()
109 .map(|provider| {
110 let models = provider
111 .models
112 .into_iter()
113 .map(|model| {
114 (
115 model.id,
116 WireProviderModel {
117 name: Some(model.display_name),
118 limit: Some(WireProviderModelLimit {
119 context: Some(model.context_window as u32),
120 }),
121 },
122 )
123 })
124 .collect::<HashMap<_, _>>();
125
126 WireProviderEntry {
127 id: provider.id,
128 name: Some(provider.name),
129 models,
130 catalog_source: None,
131 catalog_status: None,
132 catalog_message: None,
133 }
134 })
135 .collect();
136
137 Self { all, connected }
138 }
139}
140
141impl WireMessagePart {
142 pub fn text(session_id: &str, message_id: &str, text: impl Into<String>) -> Self {
143 Self {
144 id: Some(next_part_id()),
145 session_id: Some(session_id.to_string()),
146 message_id: Some(message_id.to_string()),
147 part_type: Some("text".to_string()),
148 text: Some(text.into()),
149 tool: None,
150 args: None,
151 state: None,
152 result: None,
153 error: None,
154 }
155 }
156
157 pub fn tool_invocation(
158 session_id: &str,
159 message_id: &str,
160 tool: impl Into<String>,
161 args: Value,
162 ) -> Self {
163 Self {
164 id: Some(next_part_id()),
165 session_id: Some(session_id.to_string()),
166 message_id: Some(message_id.to_string()),
167 part_type: Some("tool".to_string()),
168 text: None,
169 tool: Some(tool.into()),
170 args: Some(args),
171 state: Some("running".to_string()),
172 result: None,
173 error: None,
174 }
175 }
176
177 pub fn tool_result(
178 session_id: &str,
179 message_id: &str,
180 tool: impl Into<String>,
181 args: Option<Value>,
182 result: Value,
183 ) -> Self {
184 Self {
185 id: Some(next_part_id()),
186 session_id: Some(session_id.to_string()),
187 message_id: Some(message_id.to_string()),
188 part_type: Some("tool".to_string()),
189 text: None,
190 tool: Some(tool.into()),
191 args,
192 state: Some("completed".to_string()),
193 result: Some(result),
194 error: None,
195 }
196 }
197}