1use std::collections::BTreeMap;
12use std::str::FromStr;
13
14use base64::prelude::BASE64_STANDARD;
15use base64::Engine;
16use serde::de::DeserializeOwned;
17use serde::Serialize;
18use serde_json::Value;
19use thiserror::Error;
20
21use crate::bundle::MimeBundle;
22use crate::message::{EmitBlock, Message, OpenBlock, PatchBlock};
23use crate::tier::TrustTier;
24use crate::{BlockId, OSC_NUMBER, PROTOCOL_VERSION};
25
// ---- OSC framing ----------------------------------------------------------

/// Operating System Command introducer (`ESC ]`) that opens every encoded TBP escape.
const OSC_START: &str = "\x1b]";

/// String Terminator (`ESC \`) that closes every encoded TBP escape.
const ST: &str = "\x1b\\";

// ---- Separators -----------------------------------------------------------

/// Separates the top-level fields: OSC number, verb, parameter list, payload.
const FIELD_SEP: char = ';';
/// Separates the `key=value` pairs inside the parameter field.
const PARAM_SEP: char = ',';
/// Separates a parameter key from its value (only the first `=` splits).
const KEY_VALUE_SEP: char = '=';

// ---- Verbs ----------------------------------------------------------------

/// Encoded with no parameters and no payload.
const VERB_CAPS: &str = "caps";
/// Carries only an `id` parameter.
const VERB_CLOSE: &str = "close";
/// Carries `v`, `id`, `trust` and a base64 JSON MIME bundle (or a `file` reference).
const VERB_EMIT: &str = "emit";
/// Carries `id`, `mime` and a base64 JSON spec payload.
const VERB_OPEN: &str = "open";
/// Carries `id` and a base64 JSON patch payload.
const VERB_PATCH: &str = "patch";

// ---- Parameter keys -------------------------------------------------------

/// Side-channel path whose bytes replace an `emit` frame's inline payload.
const PARAM_FILE: &str = "file";
/// Numeric block identifier; required by every verb except `caps`.
const PARAM_ID: &str = "id";
/// MIME type: required for `open`, optional for side-channel `emit`.
const PARAM_MIME: &str = "mime";
/// Trust tier of an `emit`; the decoder falls back to the default tier when absent.
const PARAM_TRUST: &str = "trust";
/// Protocol version, written on `emit`. NOTE(review): not read by the decoders here.
const PARAM_VERSION: &str = "v";
51
/// Errors reported while decoding a TBP escape body.
///
/// The decode entry points in this module return `Result<Message, ProtoError>`;
/// each variant identifies which parsing step rejected the input.
#[derive(Clone, Debug, Eq, Error, PartialEq)]
pub enum ProtoError {
    /// A parameter was present but its value failed to parse (for example a
    /// non-numeric `id` or an unrecognised `trust` value). Carries the parameter name.
    #[error("invalid value for parameter `{0}`")]
    BadParam(String),
    /// The payload field could not be base64-decoded.
    #[error("payload is not valid base64")]
    Base64,
    /// The base64-decoded payload did not deserialize from JSON into the expected type.
    #[error("payload is not valid JSON")]
    Json,
    /// The body is structurally incomplete, e.g. no verb field or no payload field
    /// where one is required. `decode` also returns this for an `emit` that references
    /// a side-channel file, because it has no file reader.
    #[error("malformed TBP frame")]
    MalformedFrame,
    /// A required parameter (`id`, or `mime` for `open`) was absent. Carries the name.
    #[error("missing required parameter `{0}`")]
    MissingParam(String),
    /// The verb field matched none of the verbs this module decodes. Carries the verb text.
    #[error("unknown TBP verb `{0}`")]
    UnknownVerb(String),
    /// The leading field was not TBP's OSC number (or was not a number at all).
    #[error("escape is not a TBP message")]
    WrongOsc,
}
74
75pub fn encode(message: &Message) -> String {
81 match message {
82 Message::Caps => frame(VERB_CAPS, &[], None),
83 Message::Close(id) => frame(VERB_CLOSE, &[(PARAM_ID, id.0.to_string())], None),
84 Message::Emit(block) => encode_emit(block),
85 Message::Open(block) => encode_open(block),
86 Message::Patch(block) => encode_patch(block),
87 }
88}
89
90fn encode_emit(block: &EmitBlock) -> String {
91 let params = [
92 (PARAM_VERSION, PROTOCOL_VERSION.0.to_string()),
93 (PARAM_ID, block.id.0.to_string()),
94 (PARAM_TRUST, block.trust.as_str().to_string()),
95 ];
96 frame(VERB_EMIT, ¶ms, Some(encode_json(&block.bundle)))
97}
98
99fn encode_open(block: &OpenBlock) -> String {
100 let params = [
101 (PARAM_ID, block.id.0.to_string()),
102 (PARAM_MIME, block.mime.clone()),
103 ];
104 frame(VERB_OPEN, ¶ms, Some(encode_json(&block.spec)))
105}
106
107fn encode_patch(block: &PatchBlock) -> String {
108 let params = [(PARAM_ID, block.id.0.to_string())];
109 frame(VERB_PATCH, ¶ms, Some(encode_json(&block.patch)))
110}
111
112fn frame(verb: &str, params: &[(&str, String)], payload: Option<String>) -> String {
113 let mut out = String::from(OSC_START);
114 out.push_str(&OSC_NUMBER.to_string());
115 out.push(FIELD_SEP);
116 out.push_str(verb);
117 if !params.is_empty() {
118 out.push(FIELD_SEP);
119 for (index, (key, value)) in params.iter().enumerate() {
120 if index > 0 {
121 out.push(PARAM_SEP);
122 }
123 out.push_str(key);
124 out.push(KEY_VALUE_SEP);
125 out.push_str(value);
126 }
127 }
128 if let Some(payload) = payload {
129 out.push(FIELD_SEP);
130 out.push_str(&payload);
131 }
132 out.push_str(ST);
133 out
134}
135
136fn encode_json(value: &impl Serialize) -> String {
137 let json = serde_json::to_vec(value).expect("proto value is always serializable");
138 BASE64_STANDARD.encode(json)
139}
140
141pub fn decode(body: &str) -> Result<Message, ProtoError> {
148 decode_with_sidechannel(body, |_| Err(ProtoError::MalformedFrame))
149}
150
151pub fn decode_with_sidechannel(
157 body: &str,
158 file_reader: impl Fn(&str) -> Result<Vec<u8>, ProtoError>,
159) -> Result<Message, ProtoError> {
160 let mut fields = body.split(FIELD_SEP);
161 let osc = fields.next().ok_or(ProtoError::MalformedFrame)?;
162 if osc.parse::<u32>().ok() != Some(OSC_NUMBER) {
163 return Err(ProtoError::WrongOsc);
164 }
165 let verb = fields.next().ok_or(ProtoError::MalformedFrame)?;
166 let rest: Vec<&str> = fields.collect();
167 match verb {
168 VERB_CAPS => Ok(Message::Caps),
169 VERB_CLOSE => Ok(Message::Close(decode_id(&rest)?)),
170 VERB_EMIT => decode_emit_with_sidechannel(&rest, &file_reader),
171 VERB_OPEN => decode_open(&rest),
172 VERB_PATCH => decode_patch(&rest),
173 other => Err(ProtoError::UnknownVerb(other.to_string())),
174 }
175}
176
177fn decode_emit_with_sidechannel(
178 rest: &[&str],
179 file_reader: &impl Fn(&str) -> Result<Vec<u8>, ProtoError>,
180) -> Result<Message, ProtoError> {
181 let params = parse_params(rest.first().copied().unwrap_or_default());
182 let id = required_id(¶ms)?;
183 let trust = match params.get(PARAM_TRUST) {
184 Some(raw) => {
185 TrustTier::from_str(raw).map_err(|_| ProtoError::BadParam(PARAM_TRUST.to_string()))?
186 }
187 None => TrustTier::default(),
188 };
189
190 let bundle = if let Some(&file_path) = params.get(PARAM_FILE) {
191 let raw_bytes = file_reader(file_path)?;
192 let b64 = BASE64_STANDARD.encode(&raw_bytes);
193 let mime = params
194 .get(PARAM_MIME)
195 .copied()
196 .unwrap_or("application/octet-stream");
197 let mut bundle = MimeBundle::new();
198 bundle.insert(mime, Value::from(b64));
199 bundle
200 } else {
201 decode_payload(rest)?
202 };
203
204 Ok(Message::Emit(EmitBlock { bundle, id, trust }))
205}
206
207fn decode_open(rest: &[&str]) -> Result<Message, ProtoError> {
208 let params = parse_params(rest.first().copied().unwrap_or_default());
209 let id = required_id(¶ms)?;
210 let mime = params
211 .get(PARAM_MIME)
212 .ok_or_else(|| ProtoError::MissingParam(PARAM_MIME.to_string()))?
213 .to_string();
214 let spec: Value = decode_payload(rest)?;
215 Ok(Message::Open(OpenBlock { id, mime, spec }))
216}
217
218fn decode_patch(rest: &[&str]) -> Result<Message, ProtoError> {
219 let params = parse_params(rest.first().copied().unwrap_or_default());
220 let id = required_id(¶ms)?;
221 let patch: Value = decode_payload(rest)?;
222 Ok(Message::Patch(PatchBlock { id, patch }))
223}
224
225fn decode_id(rest: &[&str]) -> Result<BlockId, ProtoError> {
226 required_id(&parse_params(rest.first().copied().unwrap_or_default()))
227}
228
229fn parse_params(field: &str) -> BTreeMap<&str, &str> {
230 field
231 .split(PARAM_SEP)
232 .filter(|pair| !pair.is_empty())
233 .filter_map(|pair| pair.split_once(KEY_VALUE_SEP))
234 .collect()
235}
236
237fn required_id(params: &BTreeMap<&str, &str>) -> Result<BlockId, ProtoError> {
238 let raw = params
239 .get(PARAM_ID)
240 .ok_or_else(|| ProtoError::MissingParam(PARAM_ID.to_string()))?;
241 raw.parse::<u64>()
242 .map(BlockId)
243 .map_err(|_| ProtoError::BadParam(PARAM_ID.to_string()))
244}
245
246fn decode_payload<T: DeserializeOwned>(rest: &[&str]) -> Result<T, ProtoError> {
247 let payload = rest.get(1).ok_or(ProtoError::MalformedFrame)?;
248 let bytes = BASE64_STANDARD
249 .decode(payload)
250 .map_err(|_| ProtoError::Base64)?;
251 serde_json::from_slice(&bytes).map_err(|_| ProtoError::Json)
252}
253
#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use super::*;

    /// Removes the `ESC ]` prefix and `ESC \` suffix added by `encode`, leaving the body
    /// that `decode` expects.
    fn strip_frame(escape: &str) -> &str {
        escape
            .strip_prefix(OSC_START)
            .and_then(|rest| rest.strip_suffix(ST))
            .expect("encoded escape is OSC-framed")
    }

    /// Encodes `message`, strips the framing and decodes it again.
    fn round_trip(message: &Message) -> Message {
        decode(strip_frame(&encode(message))).expect("re-decodes")
    }

    /// A trusted emit carrying two MIME representations.
    fn sample_emit() -> Message {
        let mut bundle = MimeBundle::new();
        bundle.insert("text/plain", Value::from("rows: 3"));
        bundle.insert("image/svg+xml", Value::from("<svg/>"));
        Message::Emit(EmitBlock {
            bundle,
            id: BlockId(42),
            trust: TrustTier::Trusted,
        })
    }

    #[test]
    fn test_emit_round_trips_through_the_wire() {
        let message = sample_emit();
        assert_eq!(round_trip(&message), message);
    }

    #[test]
    fn test_open_patch_close_round_trip() {
        let open = Message::Open(OpenBlock {
            id: BlockId(7),
            mime: "application/vnd.vega-lite+json".to_string(),
            spec: serde_json::json!({ "mark": "bar" }),
        });
        let patch = Message::Patch(PatchBlock {
            id: BlockId(7),
            patch: serde_json::json!([{ "op": "replace", "path": "/mark", "value": "line" }]),
        });
        let close = Message::Close(BlockId(7));
        assert_eq!(round_trip(&open), open);
        assert_eq!(round_trip(&patch), patch);
        assert_eq!(round_trip(&close), close);
    }

    #[test]
    fn test_caps_query_round_trips() {
        assert_eq!(round_trip(&Message::Caps), Message::Caps);
    }

    #[test]
    fn test_encoded_emit_is_framed_by_osc_and_st() {
        let escape = encode(&sample_emit());
        assert!(escape.starts_with(OSC_START));
        assert!(escape.ends_with(ST));
    }

    #[test]
    fn test_emit_without_trust_param_defaults_to_restricted() {
        let bundle = encode_json(&MimeBundle::new());
        let body = format!("{OSC_NUMBER};{VERB_EMIT};{PARAM_ID}=5;{bundle}");
        match decode(&body) {
            Ok(Message::Emit(block)) => assert_eq!(block.trust, TrustTier::Restricted),
            other => panic!("expected emit, got {other:?}"),
        }
    }

    #[test]
    fn test_non_tbp_osc_is_rejected() {
        let body = "8;;https://example.com";
        assert_eq!(decode(body), Err(ProtoError::WrongOsc));
    }

    #[test]
    fn test_missing_verb_is_malformed() {
        let body = OSC_NUMBER.to_string();
        assert_eq!(decode(&body), Err(ProtoError::MalformedFrame));
    }

    #[test]
    fn test_unknown_verb_is_reported() {
        let body = format!("{OSC_NUMBER};teleport;{PARAM_ID}=1");
        assert_eq!(
            decode(&body),
            Err(ProtoError::UnknownVerb("teleport".to_string()))
        );
    }

    #[test]
    fn test_emit_missing_id_is_reported() {
        let bundle = encode_json(&MimeBundle::new());
        let body = format!("{OSC_NUMBER};{VERB_EMIT};;{bundle}");
        assert_eq!(
            decode(&body),
            Err(ProtoError::MissingParam(PARAM_ID.to_string()))
        );
    }

    #[test]
    fn test_non_numeric_id_is_reported() {
        let body = format!("{OSC_NUMBER};{VERB_CLOSE};{PARAM_ID}=abc");
        assert_eq!(
            decode(&body),
            Err(ProtoError::BadParam(PARAM_ID.to_string()))
        );
    }

    #[test]
    fn test_open_without_mime_is_reported() {
        let spec = encode_json(&serde_json::json!({}));
        let body = format!("{OSC_NUMBER};{VERB_OPEN};{PARAM_ID}=3;{spec}");
        assert_eq!(
            decode(&body),
            Err(ProtoError::MissingParam(PARAM_MIME.to_string()))
        );
    }

    #[test]
    fn test_emit_with_corrupt_base64_payload_is_reported() {
        let body = format!("{OSC_NUMBER};{VERB_EMIT};{PARAM_ID}=1;not!base64!");
        assert_eq!(decode(&body), Err(ProtoError::Base64));
    }

    #[test]
    fn test_plain_decode_rejects_file_reference() {
        let body = format!("{OSC_NUMBER};{VERB_EMIT};{PARAM_ID}=1,{PARAM_FILE}=x.png");
        assert_eq!(decode(&body), Err(ProtoError::MalformedFrame));
    }

    #[test]
    fn test_side_channel_emit_reads_file() {
        let png_bytes = vec![0x89u8, b'P', b'N', b'G'];
        let png_clone = png_bytes.clone();
        let body = format!(
            "{OSC_NUMBER};{VERB_EMIT};{PARAM_ID}=10,{PARAM_FILE}=test.png,{PARAM_MIME}=image/png"
        );
        let result = decode_with_sidechannel(&body, move |_path| Ok(png_clone.clone()));
        match result {
            Ok(Message::Emit(block)) => {
                assert_eq!(block.id, BlockId(10));
                let value = block.bundle.get("image/png").expect("has image/png");
                assert!(value.is_string());
                let decoded = BASE64_STANDARD.decode(value.as_str().unwrap()).unwrap();
                assert_eq!(decoded, png_bytes);
            }
            other => panic!("expected emit, got {other:?}"),
        }
    }

    #[test]
    fn test_side_channel_without_mime_defaults_to_octet_stream() {
        let body = format!("{OSC_NUMBER};{VERB_EMIT};{PARAM_ID}=2,{PARAM_FILE}=blob.bin");
        match decode_with_sidechannel(&body, |_path| Ok(vec![1u8, 2, 3])) {
            Ok(Message::Emit(block)) => {
                assert!(block.bundle.get("application/octet-stream").is_some());
            }
            other => panic!("expected emit, got {other:?}"),
        }
    }

    #[test]
    fn test_side_channel_path_traversal_calls_reader() {
        // The decoder does not vet paths; it must pass them verbatim to the reader so the
        // reader can refuse them. The reader's error is one the decoder never produces on
        // its own, and a flag records the call, so skipping the reader fails this test.
        // With a generic error and no flag, the inline fallback would mask the skip,
        // because it also fails with MalformedFrame.
        let body = format!("{OSC_NUMBER};{VERB_EMIT};{PARAM_ID}=1,{PARAM_FILE}=../etc/passwd");
        let reader_called = Cell::new(false);
        let rejection = ProtoError::BadParam(PARAM_FILE.to_string());
        let result = decode_with_sidechannel(&body, |path| {
            reader_called.set(true);
            assert_eq!(path, "../etc/passwd");
            Err(rejection.clone())
        });
        assert!(reader_called.get(), "side-channel reader was never consulted");
        assert_eq!(result, Err(rejection));
    }

    #[test]
    fn test_side_channel_falls_back_to_inline_when_no_file_param() {
        let bundle = encode_json(&MimeBundle::new());
        let body = format!("{OSC_NUMBER};{VERB_EMIT};{PARAM_ID}=5;{bundle}");
        let result =
            decode_with_sidechannel(&body, |_path| Err(ProtoError::MalformedFrame));
        assert!(result.is_ok());
    }
}
401}