1use std::sync::Arc;
4use std::time::Duration;
5
6use reqwest::header::{HeaderName, HeaderValue};
7use rust_genai_types::operations::{
8 GetOperationConfig, ListOperationsConfig, ListOperationsResponse, Operation,
9};
10
11use crate::client::{Backend, ClientInner};
12use crate::error::{Error, Result};
13
14#[derive(Clone)]
15pub struct Operations {
16 pub(crate) inner: Arc<ClientInner>,
17}
18
19impl Operations {
20 pub(crate) const fn new(inner: Arc<ClientInner>) -> Self {
21 Self { inner }
22 }
23
24 pub async fn get(&self, name: impl AsRef<str>) -> Result<Operation> {
29 self.get_with_config(name, GetOperationConfig::default())
30 .await
31 }
32
33 pub async fn get_with_config(
38 &self,
39 name: impl AsRef<str>,
40 mut config: GetOperationConfig,
41 ) -> Result<Operation> {
42 let http_options = config.http_options.take();
43 let name = normalize_operation_name(&self.inner, name.as_ref())?;
44 let url = build_operation_url(&self.inner, &name, http_options.as_ref());
45 let mut request = self.inner.http.get(url);
46 request = apply_http_options(request, http_options.as_ref())?;
47
48 let response = self.inner.send(request).await?;
49 if !response.status().is_success() {
50 return Err(Error::ApiError {
51 status: response.status().as_u16(),
52 message: response.text().await.unwrap_or_default(),
53 });
54 }
55 Ok(response.json::<Operation>().await?)
56 }
57
58 pub async fn list(&self) -> Result<ListOperationsResponse> {
63 self.list_with_config(ListOperationsConfig::default()).await
64 }
65
66 pub async fn list_with_config(
71 &self,
72 mut config: ListOperationsConfig,
73 ) -> Result<ListOperationsResponse> {
74 let http_options = config.http_options.take();
75 let url = build_operations_list_url(&self.inner, http_options.as_ref())?;
76 let url = add_list_query_params(&url, &config)?;
77 let mut request = self.inner.http.get(url);
78 request = apply_http_options(request, http_options.as_ref())?;
79
80 let response = self.inner.send(request).await?;
81 if !response.status().is_success() {
82 return Err(Error::ApiError {
83 status: response.status().as_u16(),
84 message: response.text().await.unwrap_or_default(),
85 });
86 }
87 Ok(response.json::<ListOperationsResponse>().await?)
88 }
89
90 pub async fn all(&self) -> Result<Vec<Operation>> {
95 self.all_with_config(ListOperationsConfig::default()).await
96 }
97
98 pub async fn all_with_config(
103 &self,
104 mut config: ListOperationsConfig,
105 ) -> Result<Vec<Operation>> {
106 let mut ops = Vec::new();
107 let http_options = config.http_options.clone();
108 loop {
109 let mut page_config = config.clone();
110 page_config.http_options.clone_from(&http_options);
111 let response = self.list_with_config(page_config).await?;
112 if let Some(items) = response.operations {
113 ops.extend(items);
114 }
115 match response.next_page_token {
116 Some(token) if !token.is_empty() => {
117 config.page_token = Some(token);
118 }
119 _ => break,
120 }
121 }
122 Ok(ops)
123 }
124
125 pub async fn wait(&self, mut operation: Operation) -> Result<Operation> {
130 let name = operation.name.clone().ok_or_else(|| Error::InvalidConfig {
131 message: "Operation name is empty".into(),
132 })?;
133 while !operation.done.unwrap_or(false) {
134 tokio::time::sleep(Duration::from_secs(5)).await;
135 operation = self.get(&name).await?;
136 }
137 Ok(operation)
138 }
139}
140
141fn normalize_operation_name(inner: &ClientInner, name: &str) -> Result<String> {
142 match inner.config.backend {
143 Backend::GeminiApi => {
144 if name.starts_with("operations/") || name.starts_with("models/") {
145 Ok(name.to_string())
146 } else {
147 Ok(format!("operations/{name}"))
148 }
149 }
150 Backend::VertexAi => {
151 let vertex =
152 inner
153 .config
154 .vertex_config
155 .as_ref()
156 .ok_or_else(|| Error::InvalidConfig {
157 message: "Vertex config missing".into(),
158 })?;
159 if name.starts_with("projects/") {
160 Ok(name.to_string())
161 } else if name.starts_with("locations/") {
162 Ok(format!("projects/{}/{}", vertex.project, name))
163 } else if name.starts_with("operations/") {
164 Ok(format!(
165 "projects/{}/locations/{}/{}",
166 vertex.project, vertex.location, name
167 ))
168 } else {
169 Ok(format!(
170 "projects/{}/locations/{}/operations/{}",
171 vertex.project, vertex.location, name
172 ))
173 }
174 }
175 }
176}
177
178fn build_operation_url(
179 inner: &ClientInner,
180 name: &str,
181 http_options: Option<&rust_genai_types::http::HttpOptions>,
182) -> String {
183 let base = http_options
184 .and_then(|opts| opts.base_url.as_deref())
185 .unwrap_or(&inner.api_client.base_url);
186 let version = http_options
187 .and_then(|opts| opts.api_version.as_deref())
188 .unwrap_or(&inner.api_client.api_version);
189 format!("{base}{version}/{name}")
190}
191
192fn build_operations_list_url(
193 inner: &ClientInner,
194 http_options: Option<&rust_genai_types::http::HttpOptions>,
195) -> Result<String> {
196 let base = http_options
197 .and_then(|opts| opts.base_url.as_deref())
198 .unwrap_or(&inner.api_client.base_url);
199 let version = http_options
200 .and_then(|opts| opts.api_version.as_deref())
201 .unwrap_or(&inner.api_client.api_version);
202 let url = match inner.config.backend {
203 Backend::GeminiApi => format!("{base}{version}/operations"),
204 Backend::VertexAi => {
205 let vertex =
206 inner
207 .config
208 .vertex_config
209 .as_ref()
210 .ok_or_else(|| Error::InvalidConfig {
211 message: "Vertex config missing".into(),
212 })?;
213 format!(
214 "{base}{version}/projects/{}/locations/{}/operations",
215 vertex.project, vertex.location
216 )
217 }
218 };
219 Ok(url)
220}
221
222fn add_list_query_params(url: &str, config: &ListOperationsConfig) -> Result<String> {
223 let mut url = reqwest::Url::parse(url).map_err(|err| Error::InvalidConfig {
224 message: err.to_string(),
225 })?;
226 {
227 let mut pairs = url.query_pairs_mut();
228 if let Some(page_size) = config.page_size {
229 pairs.append_pair("pageSize", &page_size.to_string());
230 }
231 if let Some(page_token) = &config.page_token {
232 pairs.append_pair("pageToken", page_token);
233 }
234 if let Some(filter) = &config.filter {
235 pairs.append_pair("filter", filter);
236 }
237 }
238 Ok(url.to_string())
239}
240
241fn apply_http_options(
242 mut request: reqwest::RequestBuilder,
243 http_options: Option<&rust_genai_types::http::HttpOptions>,
244) -> Result<reqwest::RequestBuilder> {
245 if let Some(options) = http_options {
246 if let Some(timeout) = options.timeout {
247 request = request.timeout(Duration::from_millis(timeout));
248 }
249 if let Some(headers) = &options.headers {
250 for (key, value) in headers {
251 let name =
252 HeaderName::from_bytes(key.as_bytes()).map_err(|_| Error::InvalidConfig {
253 message: format!("Invalid header name: {key}"),
254 })?;
255 let value = HeaderValue::from_str(value).map_err(|_| Error::InvalidConfig {
256 message: format!("Invalid header value for {key}"),
257 })?;
258 request = request.header(name, value);
259 }
260 }
261 }
262 Ok(request)
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::{test_client_inner, test_vertex_inner_missing_config};
    use std::collections::HashMap;

    #[test]
    fn test_normalize_operation_name() {
        let gemini = test_client_inner(Backend::GeminiApi);
        assert_eq!(
            normalize_operation_name(&gemini, "operations/123").unwrap(),
            "operations/123"
        );
        assert_eq!(
            normalize_operation_name(&gemini, "models/abc").unwrap(),
            "models/abc"
        );
        assert_eq!(
            normalize_operation_name(&gemini, "op-1").unwrap(),
            "operations/op-1"
        );

        let vertex = test_client_inner(Backend::VertexAi);
        assert_eq!(
            normalize_operation_name(&vertex, "projects/x/locations/y/operations/z").unwrap(),
            "projects/x/locations/y/operations/z"
        );
        assert_eq!(
            normalize_operation_name(&vertex, "locations/us/operations/1").unwrap(),
            "projects/proj/locations/us/operations/1"
        );
        assert_eq!(
            normalize_operation_name(&vertex, "operations/2").unwrap(),
            "projects/proj/locations/loc/operations/2"
        );
        assert_eq!(
            normalize_operation_name(&vertex, "op-3").unwrap(),
            "projects/proj/locations/loc/operations/op-3"
        );
    }

    #[test]
    fn test_build_operation_url_default_and_override() {
        let gemini = test_client_inner(Backend::GeminiApi);
        let url = build_operation_url(&gemini, "operations/123", None);
        assert!(url.ends_with("/v1beta/operations/123"));

        // Per-request base URL and API version take precedence over client defaults.
        let options = rust_genai_types::http::HttpOptions {
            base_url: Some("https://example.com/".to_string()),
            api_version: Some("v1".to_string()),
            ..Default::default()
        };
        assert_eq!(
            build_operation_url(&gemini, "operations/123", Some(&options)),
            "https://example.com/v1/operations/123"
        );
    }

    #[test]
    fn test_build_operations_list_url_and_params() {
        let gemini = test_client_inner(Backend::GeminiApi);
        let url = build_operations_list_url(&gemini, None).unwrap();
        assert!(url.ends_with("/v1beta/operations"));
        let url = add_list_query_params(
            &url,
            &ListOperationsConfig {
                page_size: Some(10),
                page_token: Some("token".to_string()),
                filter: Some("done=true".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert!(url.contains("pageSize=10"));
        assert!(url.contains("pageToken=token"));
        // `=` inside a query value is form-urlencoded as `%3D`.
        assert!(url.contains("filter=done%3Dtrue"));

        let vertex = test_client_inner(Backend::VertexAi);
        let url = build_operations_list_url(&vertex, None).unwrap();
        assert!(url.contains("/projects/proj/locations/loc/operations"));
    }

    #[test]
    fn test_build_operations_list_url_vertex_missing_config_errors() {
        let inner = test_vertex_inner_missing_config();
        assert!(build_operations_list_url(&inner, None).is_err());
    }

    #[test]
    fn test_add_list_query_params_invalid_url() {
        let err = add_list_query_params("::bad", &ListOperationsConfig::default()).unwrap_err();
        assert!(matches!(err, Error::InvalidConfig { .. }));
    }

    #[test]
    fn test_apply_http_options_invalid_header() {
        let client = reqwest::Client::new();
        let request = client.get("https://example.com");
        let options = rust_genai_types::http::HttpOptions {
            headers: Some([("bad header".to_string(), "value".to_string())].into()),
            ..Default::default()
        };
        let err = apply_http_options(request, Some(&options)).unwrap_err();
        assert!(matches!(err, Error::InvalidConfig { .. }));
    }

    #[test]
    fn test_apply_http_options_with_valid_header() {
        let client = reqwest::Client::new();
        let request = client.get("https://example.com");
        let mut headers = HashMap::new();
        headers.insert("x-test".to_string(), "ok".to_string());
        let options = rust_genai_types::http::HttpOptions {
            headers: Some(headers),
            ..Default::default()
        };
        let request = apply_http_options(request, Some(&options)).unwrap();
        let built = request.build().unwrap();
        assert!(built.headers().contains_key("x-test"));
    }

    #[test]
    fn test_apply_http_options_invalid_header_value() {
        let client = reqwest::Client::new();
        let request = client.get("https://example.com");
        let mut headers = HashMap::new();
        headers.insert("x-test".to_string(), "bad\nvalue".to_string());
        let options = rust_genai_types::http::HttpOptions {
            headers: Some(headers),
            ..Default::default()
        };
        let err = apply_http_options(request, Some(&options)).unwrap_err();
        assert!(matches!(err, Error::InvalidConfig { .. }));
    }

    #[tokio::test]
    async fn test_wait_missing_name_errors() {
        let client = crate::Client::new("test-key").unwrap();
        let ops = client.operations();
        let result = ops
            .wait(Operation {
                name: None,
                done: Some(false),
                ..Default::default()
            })
            .await;
        assert!(matches!(result.unwrap_err(), Error::InvalidConfig { .. }));
    }
}