Skip to main content

trino_rust_client/models/
result.rs

1use serde::Deserialize;
2use std::collections::HashMap;
3
4use super::*;
5#[cfg(feature = "spooling")]
6use crate::error::Error;
7#[cfg(feature = "spooling")]
8use crate::spooling::{decode_inline_segment, Segment};
9use crate::Trino;
10
11/// Query result data can be either Direct (inline array) or Spooled (compressed segments)
12#[derive(Deserialize, Debug)]
13#[serde(rename_all = "camelCase")]
14#[serde(untagged)]
15pub enum QueryResultData<T: Trino> {
16    // Spooled protocol: data is an object with encoding and segments
17    Spooled(SpooledData),
18    // Direct protocol: data is a simple JSON array
19    #[serde(bound(deserialize = "Vec<T>: Deserialize<'de>"))]
20    Direct(Vec<T>),
21}
22
23impl<T> QueryResultData<T>
24where
25    T: Trino,
26    for<'de> T: serde::Deserialize<'de>,
27{
28    /// Convert into Vec for both Direct and Spooled variants
29    pub fn into_vec(self) -> Vec<T> {
30        match self {
31            QueryResultData::Direct(data) => data,
32            #[cfg(feature = "spooling")]
33            QueryResultData::Spooled(spooled) => spooled.parse_segments().unwrap_or_else(|e| {
34                log::error!("Failed to parse spooled segments: {}", e);
35                Vec::new()
36            }),
37            #[cfg(not(feature = "spooling"))]
38            QueryResultData::Spooled(_) => {
39                panic!("Spooling feature not enabled")
40            }
41        }
42    }
43}
44
45/// Spooled data contains encoding format and segment references
46#[derive(Deserialize, Debug)]
47#[serde(rename_all = "camelCase")]
48pub struct SpooledData {
49    pub encoding: String,
50    #[cfg(feature = "spooling")]
51    pub segments: Vec<Segment>,
52}
53
#[cfg(feature = "spooling")]
impl SpooledData {
    /// Decodes every segment and concatenates their rows in segment order.
    ///
    /// Each inline segment is decoded with `decode_inline_segment` using
    /// `self.encoding`, and the result is parsed as a JSON array of `T`.
    ///
    /// # Errors
    ///
    /// - any error returned by `decode_inline_segment`;
    /// - `Error::InternalError` if a decoded segment is not valid JSON for
    ///   `Vec<T>` (the message includes the segment index);
    /// - `Error::InternalError` as soon as a `Segment::Spooled` (remote)
    ///   segment is reached, since this path only handles inline data.
    fn parse_segments<T>(&self) -> Result<Vec<T>, Error>
    where
        for<'de> T: Trino + serde::Deserialize<'de>,
    {
        let mut all_rows = Vec::new();

        for (idx, segment) in self.segments.iter().enumerate() {
            let data = match segment {
                Segment::Inlined { data, .. } => data,
                // Remote segments are not handled here; the message points
                // callers at the API that is meant to handle them.
                Segment::Spooled { .. } => {
                    return Err(Error::InternalError(
                        "Remote spooled segments not supported in this code path. Use Client::get_all() instead.".to_string(),
                    ));
                }
            };

            let decompressed = decode_inline_segment(data, &self.encoding)?;
            let rows: Vec<T> = serde_json::from_str(&decompressed).map_err(|e| {
                Error::InternalError(format!("Failed to parse segment {} JSON: {}", idx, e))
            })?;
            // `extend` reserves from the length of `rows` and moves its elements
            // in, replacing the manual reserve + push loop.
            all_rows.extend(rows);
        }

        Ok(all_rows)
    }
}
86
87/// Metadata about spooled data segments
88#[derive(Deserialize, Debug)]
89#[serde(rename_all = "camelCase")]
90pub struct DataAttributes {
91    pub rows_count: Option<u64>,
92    pub segment_size: Option<u64>,
93    #[serde(flatten)]
94    pub extra: HashMap<String, serde_json::Value>,
95}
96
97/// Trino query result
98#[derive(Deserialize, Debug)]
99#[serde(rename_all = "camelCase")]
100pub struct QueryResult<T: Trino> {
101    pub id: String,
102    pub info_uri: String,
103    pub partial_cancel_uri: Option<String>,
104    pub next_uri: Option<String>,
105
106    pub columns: Option<Vec<Column>>,
107
108    #[serde(bound(deserialize = "Option<QueryResultData<T>>: Deserialize<'de>"))]
109    pub data: Option<QueryResultData<T>>,
110
111    pub error: Option<QueryError>,
112
113    pub stats: Stat,
114    pub warnings: Vec<Warning>,
115
116    pub update_type: Option<String>,
117    pub update_count: Option<u64>,
118}
119
#[cfg(test)]
#[cfg(feature = "spooling")]
mod tests {
    use super::*;
    use base64::prelude::*;

    /// Builds an inline segment whose payload is `rows_json`, base64-encoded.
    fn inline_segment(rows_json: &str) -> Segment {
        let payload = BASE64_STANDARD.encode(rows_json.as_bytes());
        let raw = format!(
            r#"{{"type":"inline","data":"{}","metadata":{{}}}}"#,
            payload
        );
        serde_json::from_str(&raw).unwrap()
    }

    #[test]
    fn test_parse_segments_multiple_inline() {
        let spooled = SpooledData {
            encoding: "json".to_string(),
            segments: vec![
                inline_segment(r#"[["alice",1],["bob",2]]"#),
                inline_segment(r#"[["charlie",3]]"#),
            ],
        };

        let rows = spooled.parse_segments::<crate::Row>().unwrap();

        // Rows from both segments must appear in segment order.
        let expected: [(&str, i32); 3] = [("alice", 1), ("bob", 2), ("charlie", 3)];
        assert_eq!(rows.len(), expected.len());
        for (row, &(name, id)) in rows.iter().zip(expected.iter()) {
            assert_eq!(row.value()[0], serde_json::Value::String(name.to_string()));
            assert_eq!(row.value()[1], serde_json::Value::Number(id.into()));
        }
    }
}