Skip to main content

tandem_wire/
convert.rs

use std::collections::HashMap;
use std::convert::TryFrom;
use std::sync::atomic::{AtomicU64, Ordering};

use serde_json::{json, Value};
use tandem_types::{Message, MessagePart, ModelSpec, ProviderInfo, Session};

use crate::{
    WireMessageInfo, WireMessagePart, WireMessageTime, WireModelSpec, WireProviderCatalog,
    WireProviderEntry, WireProviderModel, WireProviderModelLimit, WireSession, WireSessionMessage,
    WireSessionTime,
};
12
/// Monotonic counter backing [`next_part_id`]; starts at 1, so the first id handed out is `part-1`.
static PART_SEQ: AtomicU64 = AtomicU64::new(1);

/// Allocates a fresh message-part identifier of the form `part-<n>`.
///
/// `Relaxed` ordering is enough here: `fetch_add` is a single atomic
/// read-modify-write, so concurrent callers still each receive a distinct number.
/// Within this file the counter is not used to publish any other data.
fn next_part_id() -> String {
    let seq = PART_SEQ.fetch_add(1, Ordering::Relaxed);
    format!("part-{seq}")
}
18
19fn to_epoch_seconds(dt: chrono::DateTime<chrono::Utc>) -> u64 {
20    dt.timestamp().max(0) as u64
21}
22
23impl From<ModelSpec> for WireModelSpec {
24    fn from(value: ModelSpec) -> Self {
25        Self {
26            provider_id: value.provider_id,
27            model_id: value.model_id,
28        }
29    }
30}
31
32impl From<Session> for WireSession {
33    fn from(value: Session) -> Self {
34        let session_id = value.id.clone();
35        Self {
36            id: value.id,
37            slug: value.slug,
38            version: value.version,
39            project_id: value.project_id,
40            directory: Some(value.directory),
41            workspace_root: value.workspace_root,
42            origin_workspace_root: value.origin_workspace_root,
43            attached_from_workspace: value.attached_from_workspace,
44            attached_to_workspace: value.attached_to_workspace,
45            attach_timestamp_ms: value.attach_timestamp_ms,
46            attach_reason: value.attach_reason,
47            title: value.title,
48            time: Some(WireSessionTime {
49                created: to_epoch_seconds(value.time.created),
50                updated: to_epoch_seconds(value.time.updated),
51            }),
52            model: value.model.map(Into::into),
53            provider: value.provider,
54            environment: value.environment,
55            messages: value
56                .messages
57                .into_iter()
58                .map(|m| WireSessionMessage::from_message(&m, &session_id))
59                .collect(),
60        }
61    }
62}
63
64impl WireSessionMessage {
65    pub fn from_message(msg: &Message, session_id: &str) -> Self {
66        let info = WireMessageInfo {
67            id: msg.id.clone(),
68            session_id: session_id.to_string(),
69            role: format!("{:?}", msg.role).to_lowercase(),
70            time: WireMessageTime {
71                created: to_epoch_seconds(msg.created_at),
72                completed: None,
73            },
74            summary: None,
75            agent: None,
76            model: None,
77            deleted: None,
78            reverted: None,
79        };
80
81        let parts = msg.parts.iter().map(message_part_to_value).collect();
82        Self { info, parts }
83    }
84}
85
86fn message_part_to_value(part: &MessagePart) -> Value {
87    match part {
88        MessagePart::Text { text } => json!({"type":"text","text":text}),
89        MessagePart::Reasoning { text } => json!({"type":"reasoning","text":text}),
90        MessagePart::ToolInvocation {
91            tool,
92            args,
93            result,
94            error,
95        } => json!({
96            "type":"tool",
97            "tool": tool,
98            "args": args,
99            "result": result,
100            "error": error
101        }),
102    }
103}
104
105impl WireProviderCatalog {
106    pub fn from_providers(providers: Vec<ProviderInfo>, connected: Vec<String>) -> Self {
107        let all = providers
108            .into_iter()
109            .map(|provider| {
110                let models = provider
111                    .models
112                    .into_iter()
113                    .map(|model| {
114                        (
115                            model.id,
116                            WireProviderModel {
117                                name: Some(model.display_name),
118                                limit: Some(WireProviderModelLimit {
119                                    context: Some(model.context_window as u32),
120                                }),
121                            },
122                        )
123                    })
124                    .collect::<HashMap<_, _>>();
125
126                WireProviderEntry {
127                    id: provider.id,
128                    name: Some(provider.name),
129                    models,
130                }
131            })
132            .collect();
133
134        Self { all, connected }
135    }
136}
137
138impl WireMessagePart {
139    pub fn text(session_id: &str, message_id: &str, text: impl Into<String>) -> Self {
140        Self {
141            id: Some(next_part_id()),
142            session_id: Some(session_id.to_string()),
143            message_id: Some(message_id.to_string()),
144            part_type: Some("text".to_string()),
145            text: Some(text.into()),
146            tool: None,
147            args: None,
148            state: None,
149            result: None,
150            error: None,
151        }
152    }
153
154    pub fn tool_invocation(
155        session_id: &str,
156        message_id: &str,
157        tool: impl Into<String>,
158        args: Value,
159    ) -> Self {
160        Self {
161            id: Some(next_part_id()),
162            session_id: Some(session_id.to_string()),
163            message_id: Some(message_id.to_string()),
164            part_type: Some("tool".to_string()),
165            text: None,
166            tool: Some(tool.into()),
167            args: Some(args),
168            state: Some("running".to_string()),
169            result: None,
170            error: None,
171        }
172    }
173
174    pub fn tool_result(
175        session_id: &str,
176        message_id: &str,
177        tool: impl Into<String>,
178        result: Value,
179    ) -> Self {
180        Self {
181            id: Some(next_part_id()),
182            session_id: Some(session_id.to_string()),
183            message_id: Some(message_id.to_string()),
184            part_type: Some("tool".to_string()),
185            text: None,
186            tool: Some(tool.into()),
187            args: None,
188            state: Some("completed".to_string()),
189            result: Some(result),
190            error: None,
191        }
192    }
193}