1use maybe_async::maybe_async;
2use serde::{de::DeserializeOwned, Deserialize};
3use serde_json::{json, Map, Value};
4use thiserror::Error;
5
6#[cfg(not(feature = "is_sync"))]
7use reqwest::Client as HttpClient;
8
9#[cfg(feature = "is_sync")]
10use reqwest::blocking::Client as HttpClient;
11
/// TAP endpoint used by `Client::default()`: the public VizieR service's synchronous
/// (`/sync`) query URL.
const DEFAULT_VIZIER_TAP_URL: &str = "http://tapvizier.u-strasbg.fr/TAPVizieR/tap/sync";
13
// Async build (the `is_sync` feature is off): `maybe_await!(x)` expands to `x.await`.
#[cfg(not(feature = "is_sync"))]
macro_rules! maybe_await {
    ($pending:expr) => ($pending.await);
}

// Blocking build (`is_sync` feature on): `maybe_await!(x)` expands to `x` unchanged,
// since the blocking reqwest client already returns the final value.
#[cfg(feature = "is_sync")]
macro_rules! maybe_await {
    ($pending:expr) => ($pending);
}
27
28#[derive(Error, Debug)]
29pub enum VizierError {
30 #[error("Request failed: {0}")]
31 RequestFailed(reqwest::Error),
32 #[error("Non-success status code: {0}")]
33 NonSuccessStatus(reqwest::StatusCode),
34 #[error("Unexpected response schema: {0}")]
35 UnexpectedSchema(String),
36 #[error("Failed to deserialize response: {0}")]
37 DeserializationFailed(serde_json::Error),
38 #[error("{0}")]
39 Other(String),
40}
41
42#[derive(Deserialize, Debug)]
43pub struct ColumnMetadata {
44 pub name: String,
45 pub description: String,
46 pub arraysize: Option<String>,
47 pub unit: Option<String>,
48 pub ucd: String,
49}
50
51#[derive(Deserialize)]
52struct ResponseSchema {
53 #[serde(rename = "metadata")]
54 meta: Vec<ColumnMetadata>,
55 data: Vec<Vec<Value>>,
56}
57
58pub struct QueryResult<T> {
59 meta: Vec<ColumnMetadata>,
60 data: Vec<T>,
61}
62
63impl<T> QueryResult<T> {
64 pub fn meta(&self) -> &[ColumnMetadata] {
65 &self.meta
66 }
67
68 pub fn data(&self) -> &[T] {
69 &self.data
70 }
71
72 pub fn len(&self) -> usize {
73 self.data.len()
74 }
75
76 pub fn is_empty(&self) -> bool {
77 self.data.is_empty()
78 }
79}
80
81pub struct Client {
82 tap_url: String,
83 http_client: HttpClient,
84}
85
86impl Client {
87 pub fn new(tap_url: &str) -> Self {
88 Self {
89 tap_url: tap_url.to_string(),
90 http_client: HttpClient::new(),
91 }
92 }
93
94 #[maybe_async]
95 pub async fn query<T: DeserializeOwned>(
96 &self,
97 adql_query: &str,
98 ) -> Result<QueryResult<T>, VizierError> {
99 let request_query = json!({
100 "request": "doQuery",
101 "lang": "ADQL",
102 "format": "json",
103 "query": adql_query
104 });
105
106 let response = maybe_await!(self
107 .http_client
108 .get(&self.tap_url)
109 .query(&request_query)
110 .send())
111 .map_err(VizierError::RequestFailed)?;
112
113 if response.status().is_success() {
114 let data =
115 maybe_await!(response.json::<Value>()).map_err(VizierError::RequestFailed)?;
116 let parsed_data = Client::parse_query_result::<T>(data)
117 .map_err(VizierError::DeserializationFailed)?;
118
119 Ok(parsed_data)
120 } else {
121 Err(VizierError::NonSuccessStatus(response.status()))
122 }
123 }
124
125 fn parse_query_result<T: DeserializeOwned>(
126 data: Value,
127 ) -> Result<QueryResult<T>, serde_json::Error> {
128 let response = serde_json::from_value::<ResponseSchema>(data)?;
129
130 let mut result = Vec::new();
131 for row in response.data {
132 let mut row_data = Map::new();
133
134 for (i, value) in row.iter().enumerate() {
135 row_data.insert(response.meta[i].name.clone(), value.clone());
136 }
137 result.push(serde_json::from_value(Value::Object(row_data))?);
138 }
139
140 Ok(QueryResult {
141 meta: response.meta,
142 data: result,
143 })
144 }
145}
146
147impl Default for Client {
148 fn default() -> Self {
149 Self::new(DEFAULT_VIZIER_TAP_URL)
150 }
151}
152
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use super::*;

    /// ADQL query used by the live-service tests: the first 100 rows of `I/261/fonac`.
    const FONAC_QUERY: &str = "SELECT TOP 100 * FROM \"I/261/fonac\"";

    /// Live async query (needs network access), deserialized as untyped JSON rows.
    #[cfg(not(feature = "is_sync"))]
    #[tokio::test]
    async fn query_test() {
        let client = Client::default();

        let result = client.query::<Value>(FONAC_QUERY).await.unwrap();

        assert_eq!(result.len(), 100);
    }

    /// Row type for `I/261/fonac`.
    // Field names must match the VizieR column names exactly, hence `non_snake_case`.
    // Fields are only checked by deserialization, hence `dead_code`.
    #[derive(Deserialize, Debug)]
    #[allow(non_snake_case, dead_code)]
    struct Object {
        AC2000: i32,
        ACT: Option<i32>,
        #[serde(rename = "B-R")]
        BR: Option<f64>,
        #[serde(rename = "B-V")]
        BV: Option<f64>,
        Bmag: f64,
        DEJ2000: f64,
        #[serde(rename = "Ep-1900")]
        Ep1900: f64,
        Qflag: Option<i32>,
        RAJ2000: f64,
        pmDE: f64,
        pmRA: f64,
        q_Bmag: Option<i32>,
        recno: i32,
    }

    /// Live async query (needs network access), deserialized into a typed row struct.
    #[cfg(not(feature = "is_sync"))]
    #[tokio::test]
    async fn query_test_typed() {
        let client = Client::default();

        let result = client.query::<Object>(FONAC_QUERY).await.unwrap();

        assert_eq!(result.len(), 100);
    }

    /// Live blocking query (needs network access), deserialized as untyped JSON rows.
    #[cfg(feature = "is_sync")]
    #[test]
    fn query_test_sync() {
        let client = Client::default();

        let result = client.query::<Value>(FONAC_QUERY).unwrap();

        assert_eq!(result.len(), 100);
    }

    /// Live blocking query (needs network access), deserialized into a typed row struct.
    #[cfg(feature = "is_sync")]
    #[test]
    fn query_test_typed_sync() {
        let client = Client::default();

        let result = client.query::<Object>(FONAC_QUERY).unwrap();

        assert_eq!(result.len(), 100);
    }

    /// Offline check that each row is paired with the column names before deserialization.
    #[test]
    fn parse_query_result_maps_columns_to_fields() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct Row {
            ra: f64,
            id: i32,
        }

        let payload = serde_json::json!({
            "metadata": [
                {"name": "ra", "description": "Right ascension", "unit": "deg", "ucd": "pos.eq.ra"},
                {"name": "id", "description": "Identifier", "arraysize": null, "ucd": "meta.id"}
            ],
            "data": [[10.5, 1], [20.25, 2]]
        });

        let result = Client::parse_query_result::<Row>(payload).unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.meta()[0].name, "ra");
        assert_eq!(result.meta()[1].unit, None);
        assert_eq!(
            result.data(),
            &[Row { ra: 10.5, id: 1 }, Row { ra: 20.25, id: 2 }]
        );
    }

    /// Offline check that a response without rows yields an empty result.
    #[test]
    fn parse_query_result_handles_no_rows() {
        let payload = serde_json::json!({
            "metadata": [
                {"name": "id", "description": "Identifier", "ucd": "meta.id"}
            ],
            "data": []
        });

        let result = Client::parse_query_result::<Value>(payload).unwrap();

        assert!(result.is_empty());
        assert_eq!(result.meta().len(), 1);
    }
}