watsonx_rs/
types.rs

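//! Core types for WatsonX operations: generation configuration and results, retry
//! settings, model metadata, and batch request/result types.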
2
3use serde::{Deserialize, Serialize};
4use std::time::Duration;
5
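// Token-limit constants (`DEFAULT_MAX_TOKENS`, `MAX_TOKENS_LIMIT`, `QUICK_RESPONSE_MAX_TOKENS`)
// live in `crate::models`; they are not redefined here to avoid conflicting definitions.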
7
8/// Configuration for text generation requests
9#[derive(Clone, Debug, Serialize)]
10pub struct GenerationConfig {
11    /// Model ID to use for generation
12    pub model_id: String,
13    /// Request timeout
14    pub timeout: Duration,
15    /// Maximum number of tokens to generate
16    pub max_tokens: u32,
17    /// Top-k sampling parameter
18    pub top_k: Option<u32>,
19    /// Top-p sampling parameter
20    pub top_p: Option<f32>,
21    /// Stop sequences to halt generation
22    pub stop_sequences: Vec<String>,
23    /// Temperature for generation (not used in current API)
24    pub temperature: Option<f32>,
25    /// Repetition penalty
26    pub repetition_penalty: Option<f32>,
27}
28
29impl Default for GenerationConfig {
30    fn default() -> Self {
31        Self {
32            model_id: crate::models::DEFAULT_MODEL.to_string(),
33            timeout: Duration::from_secs(120),
34            max_tokens: crate::models::DEFAULT_MAX_TOKENS,
35            top_k: Some(50),
36            top_p: Some(1.0),
37            stop_sequences: vec![],
38            temperature: None,
39            repetition_penalty: Some(1.1),
40        }
41    }
42}
43
44impl GenerationConfig {
45    /// Create a config with maximum token support (128k)
46    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
47        self.max_tokens = max_tokens.min(crate::models::MAX_TOKENS_LIMIT);
48        self
49    }
50
51    /// Create a config optimized for long-form generation (128k tokens)
52    pub fn long_form() -> Self {
53        Self {
54            max_tokens: crate::models::MAX_TOKENS_LIMIT,
55            timeout: Duration::from_secs(300), // 5 minutes for long responses
56            ..Default::default()
57        }
58    }
59
60    /// Create a config optimized for quick responses
61    pub fn quick_response() -> Self {
62        Self {
63            max_tokens: crate::models::QUICK_RESPONSE_MAX_TOKENS,
64            timeout: Duration::from_secs(30),
65            ..Default::default()
66        }
67    }
68
69    /// Set the model ID
70    pub fn with_model(mut self, model_id: impl Into<String>) -> Self {
71        self.model_id = model_id.into();
72        self
73    }
74
75    /// Set the timeout
76    pub fn with_timeout(mut self, timeout: Duration) -> Self {
77        self.timeout = timeout;
78        self
79    }
80
81    /// Set top-k parameter
82    pub fn with_top_k(mut self, top_k: u32) -> Self {
83        self.top_k = Some(top_k);
84        self
85    }
86
87    /// Set top-p parameter
88    pub fn with_top_p(mut self, top_p: f32) -> Self {
89        self.top_p = Some(top_p);
90        self
91    }
92
93    /// Set stop sequences
94    pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
95        self.stop_sequences = stop_sequences;
96        self
97    }
98
99    /// Set repetition penalty
100    pub fn with_repetition_penalty(mut self, penalty: f32) -> Self {
101        self.repetition_penalty = Some(penalty);
102        self
103    }
104}
105
106/// Result of a text generation request
107#[derive(Clone, Debug, Serialize, Deserialize)]
108pub struct GenerationResult {
109    /// Generated text content
110    pub text: String,
111    /// Model ID used for generation
112    pub model_id: String,
113    /// Number of tokens used (if available)
114    pub tokens_used: Option<u32>,
115    /// Quality score (if calculated)
116    pub quality_score: Option<f32>,
117    /// Request ID for tracking
118    pub request_id: Option<String>,
119}
120
121impl GenerationResult {
122    /// Create a new generation result
123    pub fn new(text: String, model_id: String) -> Self {
124        Self {
125            text,
126            model_id,
127            tokens_used: None,
128            quality_score: None,
129            request_id: None,
130        }
131    }
132
133    /// Set the tokens used
134    pub fn with_tokens_used(mut self, tokens: u32) -> Self {
135        self.tokens_used = Some(tokens);
136        self
137    }
138
139    /// Set the quality score
140    pub fn with_quality_score(mut self, score: f32) -> Self {
141        self.quality_score = Some(score);
142        self
143    }
144
145    /// Set the request ID
146    pub fn with_request_id(mut self, request_id: String) -> Self {
147        self.request_id = Some(request_id);
148        self
149    }
150}
151
/// Configuration for retry attempts.
///
/// Construct with [`RetryConfig::default`] or [`RetryConfig::new`] and adjust the
/// remaining fields with the `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Maximum number of attempts (NOTE(review): whether this counts the first try or
    /// only retries is decided by the caller that consumes it — confirm there)
    pub max_attempts: u32,
    /// Base timeout for requests
    pub base_timeout: Duration,
    /// Quality threshold for accepting results
    pub quality_threshold: f32,
    /// Delay between retries
    pub retry_delay: Duration,
}

impl Default for RetryConfig {
    /// Defaults: 3 attempts, 30 s base timeout, 0.7 quality threshold, 1 s retry delay.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            base_timeout: Duration::from_secs(30),
            quality_threshold: 0.7,
            retry_delay: Duration::from_secs(1),
        }
    }
}

impl RetryConfig {
    /// Create a retry configuration with the given `max_attempts`; every other field
    /// takes its [`Default`] value.
    pub fn new(max_attempts: u32) -> Self {
        Self {
            max_attempts,
            ..Default::default()
        }
    }

    /// Set the base timeout for requests.
    ///
    /// Added so every field can be configured through the builder API, matching
    /// `GenerationConfig::with_timeout`.
    pub fn with_base_timeout(mut self, timeout: Duration) -> Self {
        self.base_timeout = timeout;
        self
    }

    /// Set the quality threshold (stored as given; no range validation is performed)
    pub fn with_quality_threshold(mut self, threshold: f32) -> Self {
        self.quality_threshold = threshold;
        self
    }

    /// Set the delay between retries
    pub fn with_retry_delay(mut self, delay: Duration) -> Self {
        self.retry_delay = delay;
        self
    }
}
197
198/// Information about an available model
199#[derive(Clone, Debug, Serialize, Deserialize)]
200pub struct ModelInfo {
201    /// Model ID
202    pub model_id: String,
203    /// Model name
204    pub name: Option<String>,
205    /// Model description
206    pub description: Option<String>,
207    /// Model provider
208    pub provider: Option<String>,
209    /// Model version
210    pub version: Option<String>,
211    /// Supported tasks
212    pub supported_tasks: Option<Vec<String>>,
213    /// Maximum context length
214    pub max_context_length: Option<u32>,
215    /// Whether the model is available
216    pub available: Option<bool>,
217}
218
219impl ModelInfo {
220    /// Create a new model info instance
221    pub fn new(model_id: String) -> Self {
222        Self {
223            model_id,
224            name: None,
225            description: None,
226            provider: None,
227            version: None,
228            supported_tasks: None,
229            max_context_length: None,
230            available: None,
231        }
232    }
233
234    /// Set the model name
235    pub fn with_name(mut self, name: String) -> Self {
236        self.name = Some(name);
237        self
238    }
239
240    /// Set the model description
241    pub fn with_description(mut self, description: String) -> Self {
242        self.description = Some(description);
243        self
244    }
245
246    /// Set the model provider
247    pub fn with_provider(mut self, provider: String) -> Self {
248        self.provider = Some(provider);
249        self
250    }
251
252    /// Set the model version
253    pub fn with_version(mut self, version: String) -> Self {
254        self.version = Some(version);
255        self
256    }
257
258    /// Set supported tasks
259    pub fn with_supported_tasks(mut self, tasks: Vec<String>) -> Self {
260        self.supported_tasks = Some(tasks);
261        self
262    }
263
264    /// Set maximum context length
265    pub fn with_max_context_length(mut self, length: u32) -> Self {
266        self.max_context_length = Some(length);
267        self
268    }
269
270    /// Set availability status
271    pub fn with_available(mut self, available: bool) -> Self {
272        self.available = Some(available);
273        self
274    }
275}
276
/// Record of one generation attempt.
///
/// Captures the prompt sent, the text produced, its quality score, its 1-based position
/// in the attempt sequence, and how long it took.
#[derive(Clone, Debug)]
pub struct GenerationAttempt {
    /// Prompt sent for this attempt
    pub prompt: String,
    /// Text generated by this attempt
    pub result: String,
    /// Quality score assigned to `result`
    pub quality_score: f32,
    /// 1-based index of this attempt
    pub attempt_number: u32,
    /// Wall-clock time the attempt took
    pub duration: Duration,
}

impl GenerationAttempt {
    /// Record an attempt with a zero quality score and a zero duration.
    ///
    /// Use [`GenerationAttempt::with_quality_score`] and
    /// [`GenerationAttempt::with_duration`] to fill those in.
    pub fn new(prompt: String, result: String, attempt_number: u32) -> Self {
        GenerationAttempt {
            attempt_number,
            duration: Duration::ZERO,
            quality_score: 0.0,
            result,
            prompt,
        }
    }

    /// Replace the quality score.
    pub fn with_quality_score(self, score: f32) -> Self {
        Self {
            quality_score: score,
            ..self
        }
    }

    /// Replace the recorded duration.
    pub fn with_duration(self, duration: Duration) -> Self {
        Self { duration, ..self }
    }
}
316
317/// A single request in a batch generation operation
318#[derive(Clone, Debug)]
319pub struct BatchRequest {
320    /// The prompt to generate text for
321    pub prompt: String,
322    /// Optional configuration (uses default if None)
323    pub config: Option<GenerationConfig>,
324    /// Optional identifier for tracking this request
325    pub id: Option<String>,
326}
327
328impl BatchRequest {
329    /// Create a new batch request with a prompt
330    pub fn new(prompt: impl Into<String>) -> Self {
331        Self {
332            prompt: prompt.into(),
333            config: None,
334            id: None,
335        }
336    }
337
338    /// Create a new batch request with prompt and config
339    pub fn with_config(prompt: impl Into<String>, config: GenerationConfig) -> Self {
340        Self {
341            prompt: prompt.into(),
342            config: Some(config),
343            id: None,
344        }
345    }
346
347    /// Set an identifier for this request
348    pub fn with_id(mut self, id: impl Into<String>) -> Self {
349        self.id = Some(id.into());
350        self
351    }
352}
353
354/// Result for a single item in a batch generation operation
355#[derive(Clone, Debug)]
356pub struct BatchItemResult {
357    /// The identifier for this request (if provided)
358    pub id: Option<String>,
359    /// The prompt that was used
360    pub prompt: String,
361    /// The generation result if successful
362    pub result: Option<GenerationResult>,
363    /// The error if the request failed
364    pub error: Option<crate::error::Error>,
365}
366
367impl BatchItemResult {
368    /// Create a successful batch item result
369    pub fn success(id: Option<String>, prompt: String, result: GenerationResult) -> Self {
370        Self {
371            id,
372            prompt,
373            result: Some(result),
374            error: None,
375        }
376    }
377
378    /// Create a failed batch item result
379    pub fn failure(id: Option<String>, prompt: String, error: crate::error::Error) -> Self {
380        Self {
381            id,
382            prompt,
383            result: None,
384            error: Some(error),
385        }
386    }
387
388    /// Check if this result is successful
389    pub fn is_success(&self) -> bool {
390        self.error.is_none()
391    }
392
393    /// Check if this result failed
394    pub fn is_failure(&self) -> bool {
395        self.error.is_some()
396    }
397}
398
399/// Result of a batch generation operation
400#[derive(Clone, Debug)]
401pub struct BatchGenerationResult {
402    /// Results for each item in the batch
403    pub results: Vec<BatchItemResult>,
404    /// Total number of requests
405    pub total: usize,
406    /// Number of successful requests
407    pub successful: usize,
408    /// Number of failed requests
409    pub failed: usize,
410    /// Total duration of the batch operation
411    pub duration: Duration,
412}
413
414impl BatchGenerationResult {
415    /// Create a new batch generation result
416    pub fn new(results: Vec<BatchItemResult>, duration: Duration) -> Self {
417        let successful = results.iter().filter(|r| r.is_success()).count();
418        let failed = results.len() - successful;
419        
420        Self {
421            total: results.len(),
422            successful,
423            failed,
424            results,
425            duration,
426        }
427    }
428
429    /// Get all successful results
430    pub fn successes(&self) -> Vec<&GenerationResult> {
431        self.results
432            .iter()
433            .filter_map(|r| r.result.as_ref())
434            .collect()
435    }
436
437    /// Get all failed results with their errors
438    pub fn failures(&self) -> Vec<(&str, &crate::error::Error)> {
439        self.results
440            .iter()
441            .filter_map(|r| r.error.as_ref().map(|e| (r.prompt.as_str(), e)))
442            .collect()
443    }
444
445    /// Check if all requests succeeded
446    pub fn all_succeeded(&self) -> bool {
447        self.failed == 0
448    }
449
450    /// Check if any request failed
451    pub fn any_failed(&self) -> bool {
452        self.failed > 0
453    }
454}