Skip to main content

tandem_wire/
convert.rs

1use std::collections::HashMap;
2
3use serde_json::{json, Value};
4use tandem_types::{Message, MessagePart, ModelSpec, ProviderInfo, Session};
5
6use crate::{
7    WireMessageInfo, WireMessagePart, WireMessageTime, WireModelSpec, WireProviderCatalog,
8    WireProviderEntry, WireProviderModel, WireProviderModelLimit, WireSession, WireSessionMessage,
9    WireSessionTime,
10};
11use std::sync::atomic::{AtomicU64, Ordering};
12
/// Process-wide counter used to number message parts; the first id handed out is 1.
static PART_SEQ: AtomicU64 = AtomicU64::new(1);

/// Returns a fresh part identifier of the form `part-<n>`.
///
/// Each call atomically takes the current counter value and advances it by
/// one, so no two calls within the same process receive the same id.
fn next_part_id() -> String {
    // Relaxed is enough here: only the uniqueness of the fetched value matters,
    // no other memory is published through this counter.
    let seq = PART_SEQ.fetch_add(1, Ordering::Relaxed);
    format!("part-{}", seq)
}
18
19fn to_epoch_seconds(dt: chrono::DateTime<chrono::Utc>) -> u64 {
20    dt.timestamp().max(0) as u64
21}
22
23impl From<ModelSpec> for WireModelSpec {
24    fn from(value: ModelSpec) -> Self {
25        Self {
26            provider_id: value.provider_id,
27            model_id: value.model_id,
28        }
29    }
30}
31
32impl From<Session> for WireSession {
33    fn from(value: Session) -> Self {
34        let session_id = value.id.clone();
35        Self {
36            id: value.id,
37            slug: value.slug,
38            version: value.version,
39            project_id: value.project_id,
40            directory: Some(value.directory),
41            workspace_root: value.workspace_root,
42            origin_workspace_root: value.origin_workspace_root,
43            attached_from_workspace: value.attached_from_workspace,
44            attached_to_workspace: value.attached_to_workspace,
45            attach_timestamp_ms: value.attach_timestamp_ms,
46            attach_reason: value.attach_reason,
47            title: value.title,
48            time: Some(WireSessionTime {
49                created: to_epoch_seconds(value.time.created),
50                updated: to_epoch_seconds(value.time.updated),
51            }),
52            model: value.model.map(Into::into),
53            provider: value.provider,
54            messages: value
55                .messages
56                .into_iter()
57                .map(|m| WireSessionMessage::from_message(&m, &session_id))
58                .collect(),
59        }
60    }
61}
62
63impl WireSessionMessage {
64    pub fn from_message(msg: &Message, session_id: &str) -> Self {
65        let info = WireMessageInfo {
66            id: msg.id.clone(),
67            session_id: session_id.to_string(),
68            role: format!("{:?}", msg.role).to_lowercase(),
69            time: WireMessageTime {
70                created: to_epoch_seconds(msg.created_at),
71                completed: None,
72            },
73            summary: None,
74            agent: None,
75            model: None,
76            deleted: None,
77            reverted: None,
78        };
79
80        let parts = msg.parts.iter().map(message_part_to_value).collect();
81        Self { info, parts }
82    }
83}
84
85fn message_part_to_value(part: &MessagePart) -> Value {
86    match part {
87        MessagePart::Text { text } => json!({"type":"text","text":text}),
88        MessagePart::Reasoning { text } => json!({"type":"reasoning","text":text}),
89        MessagePart::ToolInvocation {
90            tool,
91            args,
92            result,
93            error,
94        } => json!({
95            "type":"tool",
96            "tool": tool,
97            "args": args,
98            "result": result,
99            "error": error
100        }),
101    }
102}
103
104impl WireProviderCatalog {
105    pub fn from_providers(providers: Vec<ProviderInfo>, connected: Vec<String>) -> Self {
106        let all = providers
107            .into_iter()
108            .map(|provider| {
109                let models = provider
110                    .models
111                    .into_iter()
112                    .map(|model| {
113                        (
114                            model.id,
115                            WireProviderModel {
116                                name: Some(model.display_name),
117                                limit: Some(WireProviderModelLimit {
118                                    context: Some(model.context_window as u32),
119                                }),
120                            },
121                        )
122                    })
123                    .collect::<HashMap<_, _>>();
124
125                WireProviderEntry {
126                    id: provider.id,
127                    name: Some(provider.name),
128                    models,
129                }
130            })
131            .collect();
132
133        Self { all, connected }
134    }
135}
136
137impl WireMessagePart {
138    pub fn text(session_id: &str, message_id: &str, text: impl Into<String>) -> Self {
139        Self {
140            id: Some(next_part_id()),
141            session_id: Some(session_id.to_string()),
142            message_id: Some(message_id.to_string()),
143            part_type: Some("text".to_string()),
144            text: Some(text.into()),
145            tool: None,
146            args: None,
147            state: None,
148            result: None,
149            error: None,
150        }
151    }
152
153    pub fn tool_invocation(
154        session_id: &str,
155        message_id: &str,
156        tool: impl Into<String>,
157        args: Value,
158    ) -> Self {
159        Self {
160            id: Some(next_part_id()),
161            session_id: Some(session_id.to_string()),
162            message_id: Some(message_id.to_string()),
163            part_type: Some("tool".to_string()),
164            text: None,
165            tool: Some(tool.into()),
166            args: Some(args),
167            state: Some("running".to_string()),
168            result: None,
169            error: None,
170        }
171    }
172
173    pub fn tool_result(
174        session_id: &str,
175        message_id: &str,
176        tool: impl Into<String>,
177        result: Value,
178    ) -> Self {
179        Self {
180            id: Some(next_part_id()),
181            session_id: Some(session_id.to_string()),
182            message_id: Some(message_id.to_string()),
183            part_type: Some("tool".to_string()),
184            text: None,
185            tool: Some(tool.into()),
186            args: None,
187            state: Some("completed".to_string()),
188            result: Some(result),
189            error: None,
190        }
191    }
192}