// tandem_wire/convert.rs

1use std::collections::HashMap;
2
3use serde_json::{json, Value};
4use tandem_types::{Message, MessagePart, ModelSpec, ProviderInfo, Session};
5
6use crate::{
7    WireMessageInfo, WireMessagePart, WireMessageTime, WireModelSpec, WireProviderCatalog,
8    WireProviderEntry, WireProviderModel, WireProviderModelLimit, WireSession, WireSessionMessage,
9    WireSessionTime,
10};
11use std::sync::atomic::{AtomicU64, Ordering};
12
/// Process-wide sequence backing [`next_part_id`]; the first id handed out is `part-1`.
static PART_SEQ: AtomicU64 = AtomicU64::new(1);

/// Returns a fresh part identifier of the form `part-<n>`.
///
/// Each call atomically takes the current counter value and advances it by one, so
/// concurrent callers never receive the same number. `Relaxed` ordering is enough
/// here because only the counter value itself matters, not ordering with other memory.
fn next_part_id() -> String {
    let seq = PART_SEQ.fetch_add(1, Ordering::Relaxed);
    format!("part-{seq}")
}
18
19fn to_epoch_seconds(dt: chrono::DateTime<chrono::Utc>) -> u64 {
20    dt.timestamp().max(0) as u64
21}
22
23impl From<ModelSpec> for WireModelSpec {
24    fn from(value: ModelSpec) -> Self {
25        Self {
26            provider_id: value.provider_id,
27            model_id: value.model_id,
28        }
29    }
30}
31
32impl From<Session> for WireSession {
33    fn from(value: Session) -> Self {
34        let session_id = value.id.clone();
35        Self {
36            id: value.id,
37            slug: value.slug,
38            version: value.version,
39            project_id: value.project_id,
40            directory: Some(value.directory),
41            workspace_root: value.workspace_root,
42            origin_workspace_root: value.origin_workspace_root,
43            attached_from_workspace: value.attached_from_workspace,
44            attached_to_workspace: value.attached_to_workspace,
45            attach_timestamp_ms: value.attach_timestamp_ms,
46            attach_reason: value.attach_reason,
47            title: value.title,
48            time: Some(WireSessionTime {
49                created: to_epoch_seconds(value.time.created),
50                updated: to_epoch_seconds(value.time.updated),
51            }),
52            model: value.model.map(Into::into),
53            provider: value.provider,
54            environment: value.environment,
55            messages: value
56                .messages
57                .into_iter()
58                .map(|m| WireSessionMessage::from_message(&m, &session_id))
59                .collect(),
60        }
61    }
62}
63
64impl WireSessionMessage {
65    pub fn from_message(msg: &Message, session_id: &str) -> Self {
66        let info = WireMessageInfo {
67            id: msg.id.clone(),
68            session_id: session_id.to_string(),
69            role: format!("{:?}", msg.role).to_lowercase(),
70            time: WireMessageTime {
71                created: to_epoch_seconds(msg.created_at),
72                completed: None,
73            },
74            summary: None,
75            agent: None,
76            model: None,
77            deleted: None,
78            reverted: None,
79        };
80
81        let parts = msg.parts.iter().map(message_part_to_value).collect();
82        Self { info, parts }
83    }
84}
85
86fn message_part_to_value(part: &MessagePart) -> Value {
87    match part {
88        MessagePart::Text { text } => json!({"type":"text","text":text}),
89        MessagePart::Reasoning { text } => json!({"type":"reasoning","text":text}),
90        MessagePart::ToolInvocation {
91            tool,
92            args,
93            result,
94            error,
95        } => json!({
96            "type":"tool",
97            "tool": tool,
98            "args": args,
99            "result": result,
100            "error": error
101        }),
102    }
103}
104
105impl WireProviderCatalog {
106    pub fn from_providers(providers: Vec<ProviderInfo>, connected: Vec<String>) -> Self {
107        let all = providers
108            .into_iter()
109            .map(|provider| {
110                let models = provider
111                    .models
112                    .into_iter()
113                    .map(|model| {
114                        (
115                            model.id,
116                            WireProviderModel {
117                                name: Some(model.display_name),
118                                limit: Some(WireProviderModelLimit {
119                                    context: Some(model.context_window as u32),
120                                }),
121                            },
122                        )
123                    })
124                    .collect::<HashMap<_, _>>();
125
126                WireProviderEntry {
127                    id: provider.id,
128                    name: Some(provider.name),
129                    models,
130                    catalog_source: None,
131                    catalog_status: None,
132                    catalog_message: None,
133                }
134            })
135            .collect();
136
137        Self { all, connected }
138    }
139}
140
141impl WireMessagePart {
142    pub fn text(session_id: &str, message_id: &str, text: impl Into<String>) -> Self {
143        Self {
144            id: Some(next_part_id()),
145            session_id: Some(session_id.to_string()),
146            message_id: Some(message_id.to_string()),
147            part_type: Some("text".to_string()),
148            text: Some(text.into()),
149            tool: None,
150            args: None,
151            state: None,
152            result: None,
153            error: None,
154        }
155    }
156
157    pub fn tool_invocation(
158        session_id: &str,
159        message_id: &str,
160        tool: impl Into<String>,
161        args: Value,
162    ) -> Self {
163        Self {
164            id: Some(next_part_id()),
165            session_id: Some(session_id.to_string()),
166            message_id: Some(message_id.to_string()),
167            part_type: Some("tool".to_string()),
168            text: None,
169            tool: Some(tool.into()),
170            args: Some(args),
171            state: Some("running".to_string()),
172            result: None,
173            error: None,
174        }
175    }
176
177    pub fn tool_result(
178        session_id: &str,
179        message_id: &str,
180        tool: impl Into<String>,
181        args: Option<Value>,
182        result: Value,
183    ) -> Self {
184        Self {
185            id: Some(next_part_id()),
186            session_id: Some(session_id.to_string()),
187            message_id: Some(message_id.to_string()),
188            part_type: Some("tool".to_string()),
189            text: None,
190            tool: Some(tool.into()),
191            args,
192            state: Some("completed".to_string()),
193            result: Some(result),
194            error: None,
195        }
196    }
197}