rust_genai/
operations.rs

//! Operations API: look up, list, paginate, and poll long-running operations.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use reqwest::header::{HeaderName, HeaderValue};
7use rust_genai_types::operations::{
8    GetOperationConfig, ListOperationsConfig, ListOperationsResponse, Operation,
9};
10
11use crate::client::{Backend, ClientInner};
12use crate::error::{Error, Result};
13
14#[derive(Clone)]
15pub struct Operations {
16    pub(crate) inner: Arc<ClientInner>,
17}
18
19impl Operations {
20    pub(crate) const fn new(inner: Arc<ClientInner>) -> Self {
21        Self { inner }
22    }
23
24    /// 获取操作状态。
25    ///
26    /// # Errors
27    /// 当请求失败、服务端返回错误或响应解析失败时返回错误。
28    pub async fn get(&self, name: impl AsRef<str>) -> Result<Operation> {
29        self.get_with_config(name, GetOperationConfig::default())
30            .await
31    }
32
33    /// 获取操作状态(带配置)。
34    ///
35    /// # Errors
36    /// 当请求失败、服务端返回错误或响应解析失败时返回错误。
37    pub async fn get_with_config(
38        &self,
39        name: impl AsRef<str>,
40        mut config: GetOperationConfig,
41    ) -> Result<Operation> {
42        let http_options = config.http_options.take();
43        let name = normalize_operation_name(&self.inner, name.as_ref())?;
44        let url = build_operation_url(&self.inner, &name, http_options.as_ref());
45        let mut request = self.inner.http.get(url);
46        request = apply_http_options(request, http_options.as_ref())?;
47
48        let response = self.inner.send(request).await?;
49        if !response.status().is_success() {
50            return Err(Error::ApiError {
51                status: response.status().as_u16(),
52                message: response.text().await.unwrap_or_default(),
53            });
54        }
55        Ok(response.json::<Operation>().await?)
56    }
57
58    /// 列出操作。
59    ///
60    /// # Errors
61    /// 当请求失败或响应解析失败时返回错误。
62    pub async fn list(&self) -> Result<ListOperationsResponse> {
63        self.list_with_config(ListOperationsConfig::default()).await
64    }
65
66    /// 列出操作(带配置)。
67    ///
68    /// # Errors
69    /// 当请求失败或响应解析失败时返回错误。
70    pub async fn list_with_config(
71        &self,
72        mut config: ListOperationsConfig,
73    ) -> Result<ListOperationsResponse> {
74        let http_options = config.http_options.take();
75        let url = build_operations_list_url(&self.inner, http_options.as_ref())?;
76        let url = add_list_query_params(&url, &config)?;
77        let mut request = self.inner.http.get(url);
78        request = apply_http_options(request, http_options.as_ref())?;
79
80        let response = self.inner.send(request).await?;
81        if !response.status().is_success() {
82            return Err(Error::ApiError {
83                status: response.status().as_u16(),
84                message: response.text().await.unwrap_or_default(),
85            });
86        }
87        Ok(response.json::<ListOperationsResponse>().await?)
88    }
89
90    /// 列出所有操作(自动翻页)。
91    ///
92    /// # Errors
93    /// 当请求失败或响应解析失败时返回错误。
94    pub async fn all(&self) -> Result<Vec<Operation>> {
95        self.all_with_config(ListOperationsConfig::default()).await
96    }
97
98    /// 列出所有操作(带配置,自动翻页)。
99    ///
100    /// # Errors
101    /// 当请求失败或响应解析失败时返回错误。
102    pub async fn all_with_config(
103        &self,
104        mut config: ListOperationsConfig,
105    ) -> Result<Vec<Operation>> {
106        let mut ops = Vec::new();
107        let http_options = config.http_options.clone();
108        loop {
109            let mut page_config = config.clone();
110            page_config.http_options.clone_from(&http_options);
111            let response = self.list_with_config(page_config).await?;
112            if let Some(items) = response.operations {
113                ops.extend(items);
114            }
115            match response.next_page_token {
116                Some(token) if !token.is_empty() => {
117                    config.page_token = Some(token);
118                }
119                _ => break,
120            }
121        }
122        Ok(ops)
123    }
124
125    /// 等待操作完成(轮询)。
126    ///
127    /// # Errors
128    /// 当请求失败、操作缺少名称或轮询过程中响应解析失败时返回错误。
129    pub async fn wait(&self, mut operation: Operation) -> Result<Operation> {
130        let name = operation.name.clone().ok_or_else(|| Error::InvalidConfig {
131            message: "Operation name is empty".into(),
132        })?;
133        while !operation.done.unwrap_or(false) {
134            tokio::time::sleep(Duration::from_secs(5)).await;
135            operation = self.get(&name).await?;
136        }
137        Ok(operation)
138    }
139}
140
141fn normalize_operation_name(inner: &ClientInner, name: &str) -> Result<String> {
142    match inner.config.backend {
143        Backend::GeminiApi => {
144            if name.starts_with("operations/") || name.starts_with("models/") {
145                Ok(name.to_string())
146            } else {
147                Ok(format!("operations/{name}"))
148            }
149        }
150        Backend::VertexAi => {
151            let vertex =
152                inner
153                    .config
154                    .vertex_config
155                    .as_ref()
156                    .ok_or_else(|| Error::InvalidConfig {
157                        message: "Vertex config missing".into(),
158                    })?;
159            if name.starts_with("projects/") {
160                Ok(name.to_string())
161            } else if name.starts_with("locations/") {
162                Ok(format!("projects/{}/{}", vertex.project, name))
163            } else if name.starts_with("operations/") {
164                Ok(format!(
165                    "projects/{}/locations/{}/{}",
166                    vertex.project, vertex.location, name
167                ))
168            } else {
169                Ok(format!(
170                    "projects/{}/locations/{}/operations/{}",
171                    vertex.project, vertex.location, name
172                ))
173            }
174        }
175    }
176}
177
178fn build_operation_url(
179    inner: &ClientInner,
180    name: &str,
181    http_options: Option<&rust_genai_types::http::HttpOptions>,
182) -> String {
183    let base = http_options
184        .and_then(|opts| opts.base_url.as_deref())
185        .unwrap_or(&inner.api_client.base_url);
186    let version = http_options
187        .and_then(|opts| opts.api_version.as_deref())
188        .unwrap_or(&inner.api_client.api_version);
189    format!("{base}{version}/{name}")
190}
191
192fn build_operations_list_url(
193    inner: &ClientInner,
194    http_options: Option<&rust_genai_types::http::HttpOptions>,
195) -> Result<String> {
196    let base = http_options
197        .and_then(|opts| opts.base_url.as_deref())
198        .unwrap_or(&inner.api_client.base_url);
199    let version = http_options
200        .and_then(|opts| opts.api_version.as_deref())
201        .unwrap_or(&inner.api_client.api_version);
202    let url = match inner.config.backend {
203        Backend::GeminiApi => format!("{base}{version}/operations"),
204        Backend::VertexAi => {
205            let vertex =
206                inner
207                    .config
208                    .vertex_config
209                    .as_ref()
210                    .ok_or_else(|| Error::InvalidConfig {
211                        message: "Vertex config missing".into(),
212                    })?;
213            format!(
214                "{base}{version}/projects/{}/locations/{}/operations",
215                vertex.project, vertex.location
216            )
217        }
218    };
219    Ok(url)
220}
221
222fn add_list_query_params(url: &str, config: &ListOperationsConfig) -> Result<String> {
223    let mut url = reqwest::Url::parse(url).map_err(|err| Error::InvalidConfig {
224        message: err.to_string(),
225    })?;
226    {
227        let mut pairs = url.query_pairs_mut();
228        if let Some(page_size) = config.page_size {
229            pairs.append_pair("pageSize", &page_size.to_string());
230        }
231        if let Some(page_token) = &config.page_token {
232            pairs.append_pair("pageToken", page_token);
233        }
234        if let Some(filter) = &config.filter {
235            pairs.append_pair("filter", filter);
236        }
237    }
238    Ok(url.to_string())
239}
240
241fn apply_http_options(
242    mut request: reqwest::RequestBuilder,
243    http_options: Option<&rust_genai_types::http::HttpOptions>,
244) -> Result<reqwest::RequestBuilder> {
245    if let Some(options) = http_options {
246        if let Some(timeout) = options.timeout {
247            request = request.timeout(Duration::from_millis(timeout));
248        }
249        if let Some(headers) = &options.headers {
250            for (key, value) in headers {
251                let name =
252                    HeaderName::from_bytes(key.as_bytes()).map_err(|_| Error::InvalidConfig {
253                        message: format!("Invalid header name: {key}"),
254                    })?;
255                let value = HeaderValue::from_str(value).map_err(|_| Error::InvalidConfig {
256                    message: format!("Invalid header value for {key}"),
257                })?;
258                request = request.header(name, value);
259            }
260        }
261    }
262    Ok(request)
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{test_client_inner, test_vertex_inner_missing_config};
    use std::collections::HashMap;

    #[test]
    fn test_normalize_operation_name() {
        let gemini = test_client_inner(Backend::GeminiApi);
        assert_eq!(
            normalize_operation_name(&gemini, "operations/123").unwrap(),
            "operations/123"
        );
        assert_eq!(
            normalize_operation_name(&gemini, "models/abc").unwrap(),
            "models/abc"
        );
        assert_eq!(
            normalize_operation_name(&gemini, "op-1").unwrap(),
            "operations/op-1"
        );

        let vertex = test_client_inner(Backend::VertexAi);
        assert_eq!(
            normalize_operation_name(&vertex, "projects/x/locations/y/operations/z").unwrap(),
            "projects/x/locations/y/operations/z"
        );
        assert_eq!(
            normalize_operation_name(&vertex, "locations/us/operations/1").unwrap(),
            "projects/proj/locations/us/operations/1"
        );
        assert_eq!(
            normalize_operation_name(&vertex, "operations/2").unwrap(),
            "projects/proj/locations/loc/operations/2"
        );
        assert_eq!(
            normalize_operation_name(&vertex, "op-3").unwrap(),
            "projects/proj/locations/loc/operations/op-3"
        );
    }

    #[test]
    fn test_build_operation_url() {
        let gemini = test_client_inner(Backend::GeminiApi);
        let url = build_operation_url(&gemini, "operations/123", None);
        assert!(url.ends_with("/v1beta/operations/123"));
    }

    #[test]
    fn test_build_operations_list_url_and_params() {
        let gemini = test_client_inner(Backend::GeminiApi);
        let url = build_operations_list_url(&gemini, None).unwrap();
        assert!(url.ends_with("/v1beta/operations"));
        let url = add_list_query_params(
            &url,
            &ListOperationsConfig {
                page_size: Some(10),
                page_token: Some("token".to_string()),
                filter: Some("done=true".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert!(url.contains("pageSize=10"));
        assert!(url.contains("pageToken=token"));
        // `=` inside the filter value must be form-urlencoded.
        assert!(url.contains("filter=done%3Dtrue"));

        let vertex = test_client_inner(Backend::VertexAi);
        let url = build_operations_list_url(&vertex, None).unwrap();
        assert!(url.contains("/projects/proj/locations/loc/operations"));
    }

    #[test]
    fn test_build_operations_list_url_vertex_missing_config_errors() {
        let inner = test_vertex_inner_missing_config();
        assert!(build_operations_list_url(&inner, None).is_err());
    }

    #[test]
    fn test_add_list_query_params_invalid_url() {
        let err = add_list_query_params("::bad", &ListOperationsConfig::default()).unwrap_err();
        assert!(matches!(err, Error::InvalidConfig { .. }));
    }

    #[test]
    fn test_apply_http_options_invalid_header() {
        let client = reqwest::Client::new();
        let request = client.get("https://example.com");
        let options = rust_genai_types::http::HttpOptions {
            headers: Some([("bad header".to_string(), "value".to_string())].into()),
            ..Default::default()
        };
        let err = apply_http_options(request, Some(&options)).unwrap_err();
        assert!(matches!(err, Error::InvalidConfig { .. }));
    }

    #[test]
    fn test_apply_http_options_with_valid_header() {
        let client = reqwest::Client::new();
        let request = client.get("https://example.com");
        let mut headers = HashMap::new();
        headers.insert("x-test".to_string(), "ok".to_string());
        let options = rust_genai_types::http::HttpOptions {
            headers: Some(headers),
            ..Default::default()
        };
        let request = apply_http_options(request, Some(&options)).unwrap();
        let built = request.build().unwrap();
        assert!(built.headers().contains_key("x-test"));
    }

    #[test]
    fn test_apply_http_options_invalid_header_value() {
        let client = reqwest::Client::new();
        let request = client.get("https://example.com");
        let mut headers = HashMap::new();
        headers.insert("x-test".to_string(), "bad\nvalue".to_string());
        let options = rust_genai_types::http::HttpOptions {
            headers: Some(headers),
            ..Default::default()
        };
        let err = apply_http_options(request, Some(&options)).unwrap_err();
        assert!(matches!(err, Error::InvalidConfig { .. }));
    }

    #[test]
    fn test_apply_http_options_sets_timeout() {
        let client = reqwest::Client::new();
        let options = rust_genai_types::http::HttpOptions {
            timeout: Some(1500),
            ..Default::default()
        };
        let built = apply_http_options(client.get("https://example.com"), Some(&options))
            .unwrap()
            .build()
            .unwrap();
        assert_eq!(built.timeout(), Some(&Duration::from_millis(1500)));
    }

    #[test]
    fn test_apply_http_options_none_leaves_request_untouched() {
        let client = reqwest::Client::new();
        let built = apply_http_options(client.get("https://example.com"), None)
            .unwrap()
            .build()
            .unwrap();
        assert!(built.timeout().is_none());
    }

    #[tokio::test]
    async fn test_wait_missing_name_errors() {
        let client = crate::Client::new("test-key").unwrap();
        let ops = client.operations();
        let result = ops
            .wait(Operation {
                name: None,
                done: Some(false),
                ..Default::default()
            })
            .await;
        assert!(matches!(result.unwrap_err(), Error::InvalidConfig { .. }));
    }
}