#[cfg(feature = "async")]
use std::collections::HashMap;
#[cfg(feature = "async")]
use std::future::Future;
#[cfg(feature = "async")]
use std::pin::Pin;

#[cfg(feature = "async")]
use tensorlogic_ir::EinsumGraph;

#[cfg(feature = "async")]
use crate::batch::BatchResult;
#[cfg(feature = "async")]
use crate::streaming::{StreamResult, StreamingConfig};

/// A boxed, pinned future that is `Send` and lives for `'a`.
#[cfg(feature = "async")]
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

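/// Core trait for executors that run an [`EinsumGraph`] asynchronously.
/// Methods return [`BoxFuture`]s so the trait stays object-safe.
///
/// A minimal sketch of an implementation, assuming a hypothetical no-op
/// `NullExecutor` backend (`ignore`d because it is illustrative):
///
/// ```ignore
/// struct NullExecutor;
///
/// impl TlAsyncExecutor for NullExecutor {
///     type Tensor = Vec<f32>;
///     type Error = String;
///
///     fn execute_async<'a>(
///         &'a mut self,
///         _graph: &'a EinsumGraph,
///         _inputs: &'a HashMap<String, Self::Tensor>,
///     ) -> BoxFuture<'a, Result<Vec<Self::Tensor>, Self::Error>> {
///         // A real backend would dispatch the graph here.
///         Box::pin(async { Ok(Vec::new()) })
///     }
/// }
/// ```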
#[cfg(feature = "async")]
pub trait TlAsyncExecutor {
    type Tensor: Send;
    type Error: Send;

    /// Execute the graph asynchronously, resolving to the output tensors.
    fn execute_async<'a>(
        &'a mut self,
        graph: &'a EinsumGraph,
        inputs: &'a HashMap<String, Self::Tensor>,
    ) -> BoxFuture<'a, Result<Vec<Self::Tensor>, Self::Error>>;

    /// Whether the executor can currently accept new work. Defaults to `true`.
    fn is_ready(&self) -> bool {
        true
    }

    /// Wait until the executor reports ready, polling every 10ms.
    fn wait_ready(&mut self) -> BoxFuture<'_, ()>
    where
        Self: Send,
    {
        Box::pin(async move {
            while !self.is_ready() {
                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
            }
        })
    }
}

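/// Extension trait for executors that can run a whole batch of input maps
/// in one call, amortizing per-call overhead.
///
/// A hedged usage sketch (assumes `executor`, `graph`, and `batch_inputs`
/// already exist; `ignore`d because it is illustrative):
///
/// ```ignore
/// let batch_result = executor.execute_batch_async(&graph, batch_inputs).await?;
/// ```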
#[cfg(feature = "async")]
pub trait TlAsyncBatchExecutor: TlAsyncExecutor {
    /// Execute the graph once per input map, returning a single batch result.
    fn execute_batch_async<'a>(
        &'a mut self,
        graph: &'a EinsumGraph,
        batch_inputs: Vec<HashMap<String, Self::Tensor>>,
    ) -> BoxFuture<'a, Result<BatchResult<Self::Tensor>, Self::Error>>;
}

/// Results collected from a streaming execution: one `Result` per stream item.
#[cfg(feature = "async")]
pub type AsyncStreamResults<T, E> = Vec<Result<StreamResult<T>, E>>;

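/// Extension trait for executors that consume a pre-chunked input stream,
/// producing one [`StreamResult`] per chunk.
///
/// A hedged usage sketch (assumes `executor`, `graph`, `chunks`, and `config`
/// already exist; `ignore`d because it is illustrative):
///
/// ```ignore
/// let results = executor.execute_stream_async(&graph, chunks, &config).await;
/// for item in results {
///     match item {
///         Ok(stream_result) => { /* consume this chunk's outputs */ }
///         Err(e) => { /* handle a per-chunk failure */ }
///     }
/// }
/// ```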
#[cfg(feature = "async")]
pub trait TlAsyncStreamExecutor: TlAsyncExecutor {
    /// Execute the graph over pre-chunked stream input, yielding one result per chunk.
    fn execute_stream_async<'a>(
        &'a mut self,
        graph: &'a EinsumGraph,
        input_stream: Vec<Vec<Vec<Self::Tensor>>>,
        config: &'a StreamingConfig,
    ) -> BoxFuture<'a, AsyncStreamResults<Self::Tensor, Self::Error>>;
}

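/// Errors surfaced by asynchronous execution, wrapping the executor's own
/// error type `E`.
///
/// A hedged handling sketch (`run_job` is a hypothetical stand-in for any
/// call returning this error; `ignore`d because it is illustrative):
///
/// ```ignore
/// match run_job().await {
///     Ok(outputs) => { /* use outputs */ }
///     Err(AsyncExecutionError::Timeout { elapsed_ms }) => {
///         eprintln!("gave up after {elapsed_ms}ms");
///     }
///     Err(e) => eprintln!("{e}"),
/// }
/// ```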
#[derive(Debug, Clone)]
pub enum AsyncExecutionError<E> {
    /// Execution exceeded the configured timeout.
    Timeout { elapsed_ms: u64 },
    /// The executor's queue is full; the request was rejected.
    ExecutorBusy { queue_size: usize },
    /// Execution was cancelled via its handle.
    Cancelled,
    /// The underlying executor reported an error.
    ExecutorError(E),
    /// The execution future was dropped before it completed.
    Dropped,
}

impl<E: std::fmt::Display> std::fmt::Display for AsyncExecutionError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Timeout { elapsed_ms } => {
                write!(f, "Execution timed out after {}ms", elapsed_ms)
            }
            Self::ExecutorBusy { queue_size } => {
                write!(
                    f,
                    "Executor is busy (queue size: {}), try again later",
                    queue_size
                )
            }
            Self::Cancelled => write!(f, "Execution was cancelled"),
            Self::ExecutorError(e) => write!(f, "Executor error: {}", e),
            Self::Dropped => write!(f, "Future was dropped before completion"),
        }
    }
}

impl<E: std::error::Error> std::error::Error for AsyncExecutionError<E> {}

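/// Handle for observing and cancelling an in-flight execution.
///
/// A hedged sketch of the cancel flow (`run_execution` is a hypothetical
/// stand-in for the real work; `ignore`d because it is illustrative):
///
/// ```ignore
/// let (handle, mut cancel_rx) = AsyncExecutionHandle::new("job-1".to_string());
/// tokio::spawn(async move {
///     tokio::select! {
///         _ = cancel_rx.recv() => { /* stop early */ }
///         _ = run_execution() => { /* finished normally */ }
///     }
/// });
/// handle.cancel().await.ok();
/// ```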
#[cfg(feature = "async")]
pub struct AsyncExecutionHandle {
    execution_id: String,
    started_at: std::time::Instant,
    cancel_token: tokio::sync::mpsc::Sender<()>,
}

#[cfg(feature = "async")]
impl AsyncExecutionHandle {
    /// Create a handle plus the receiver that the execution task should
    /// watch for cancellation signals.
    pub fn new(execution_id: String) -> (Self, tokio::sync::mpsc::Receiver<()>) {
        let (tx, rx) = tokio::sync::mpsc::channel(1);
        (
            AsyncExecutionHandle {
                execution_id,
                started_at: std::time::Instant::now(),
                cancel_token: tx,
            },
            rx,
        )
    }

    /// The caller-supplied identifier for this execution.
    pub fn execution_id(&self) -> &str {
        &self.execution_id
    }

    /// Time elapsed since the handle was created.
    pub fn elapsed(&self) -> std::time::Duration {
        self.started_at.elapsed()
    }

    /// Request cancellation of the execution.
    pub async fn cancel(&self) -> Result<(), AsyncExecutionError<std::io::Error>> {
        self.cancel_token
            .send(())
            .await
            // A send error means the receiver is gone: the execution already
            // completed or its future was dropped.
            .map_err(|_| AsyncExecutionError::Dropped)
    }
}

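/// A pool of executors with round-robin and readiness-based selection.
///
/// A hedged usage sketch, assuming some `MyExecutor: TlAsyncExecutor` type
/// plus an existing `graph` and `inputs` (`ignore`d because it is illustrative):
///
/// ```ignore
/// let mut pool = AsyncExecutorPool::new(vec![MyExecutor::new(), MyExecutor::new()]);
/// let outputs = pool.execute_any(&graph, &inputs).await?;
/// ```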
#[cfg(feature = "async")]
pub struct AsyncExecutorPool<E: TlAsyncExecutor> {
    executors: Vec<E>,
    next_index: std::sync::atomic::AtomicUsize,
}

#[cfg(feature = "async")]
impl<E: TlAsyncExecutor> AsyncExecutorPool<E> {
    /// Build a pool from a non-empty set of executors.
    pub fn new(executors: Vec<E>) -> Self {
        // Guard here so the index selection below can never divide by zero
        // or index out of bounds.
        assert!(
            !executors.is_empty(),
            "AsyncExecutorPool requires at least one executor"
        );
        AsyncExecutorPool {
            executors,
            next_index: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Number of executors in the pool.
    pub fn size(&self) -> usize {
        self.executors.len()
    }

    /// Next executor index under round-robin scheduling.
    pub fn get_next_index(&self) -> usize {
        self.next_index
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            % self.executors.len()
    }

    /// Index of the first executor reporting ready, falling back to 0.
    ///
    /// This is a readiness heuristic rather than true load tracking.
    pub fn get_least_loaded_index(&self) -> usize {
        for (idx, executor) in self.executors.iter().enumerate() {
            if executor.is_ready() {
                return idx;
            }
        }
        // No executor is ready; fall back to the first and let it queue.
        0
    }

    /// Execute on whichever executor is ready first.
    pub async fn execute_any<'a>(
        &'a mut self,
        graph: &'a EinsumGraph,
        inputs: &'a HashMap<String, E::Tensor>,
    ) -> Result<Vec<E::Tensor>, E::Error> {
        let index = self.get_least_loaded_index();
        self.executors[index].execute_async(graph, inputs).await
    }
}

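/// Builder-style configuration for asynchronous execution.
///
/// A short usage sketch (`ignore`d only because the crate path is not shown
/// here):
///
/// ```ignore
/// let config = AsyncConfig::new()
///     .with_max_concurrent(8)
///     .with_timeout(5_000)
///     .with_retry(3, 200);
/// assert_eq!(config.timeout_ms, Some(5_000));
/// assert!(config.enable_retry);
/// ```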
#[derive(Debug, Clone)]
pub struct AsyncConfig {
    /// Maximum number of concurrent executions.
    pub max_concurrent: usize,
    /// Per-execution timeout in milliseconds, if any.
    pub timeout_ms: Option<u64>,
    /// Whether failed executions are retried.
    pub enable_retry: bool,
    /// Maximum number of retry attempts.
    pub max_retries: usize,
    /// Delay between retries in milliseconds.
    pub backoff_ms: u64,
}

impl Default for AsyncConfig {
    fn default() -> Self {
        AsyncConfig {
            max_concurrent: 4,
            timeout_ms: None,
            enable_retry: false,
            max_retries: 3,
            backoff_ms: 100,
        }
    }
}

impl AsyncConfig {
    /// Create a config with default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the maximum number of concurrent executions.
    pub fn with_max_concurrent(mut self, max: usize) -> Self {
        self.max_concurrent = max;
        self
    }

    /// Set a per-execution timeout in milliseconds.
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = Some(timeout_ms);
        self
    }

    /// Enable retries with the given attempt limit and backoff.
    pub fn with_retry(mut self, max_retries: usize, backoff_ms: u64) -> Self {
        self.enable_retry = true;
        self.max_retries = max_retries;
        self.backoff_ms = backoff_ms;
        self
    }
}

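/// Aggregate statistics for asynchronous executions.
///
/// A short usage sketch (`ignore`d only because the crate path is not shown
/// here):
///
/// ```ignore
/// let mut stats = AsyncStats::new();
/// stats.total_executions = 10;
/// stats.successful = 9;
/// assert_eq!(stats.success_rate(), 0.9);
/// println!("{}", stats.summary());
/// ```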
#[derive(Debug, Clone, Default)]
pub struct AsyncStats {
    /// Total number of executions attempted.
    pub total_executions: usize,
    /// Executions that completed successfully.
    pub successful: usize,
    /// Executions that failed with an executor error.
    pub failed: usize,
    /// Executions that timed out.
    pub timeouts: usize,
    /// Executions that were cancelled.
    pub cancelled: usize,
    /// Mean execution time in milliseconds.
    pub avg_execution_time_ms: f64,
    /// Highest number of executions observed in flight at once.
    pub peak_concurrent: usize,
}

impl AsyncStats {
    /// Create an empty stats record.
    pub fn new() -> Self {
        Self::default()
    }

    /// Fraction of executions that succeeded, or 0.0 if none were attempted.
    pub fn success_rate(&self) -> f64 {
        if self.total_executions == 0 {
            0.0
        } else {
            self.successful as f64 / self.total_executions as f64
        }
    }

    /// Human-readable multi-line summary of the stats.
    pub fn summary(&self) -> String {
        format!(
            "Async Execution Stats:\n\
             - Total: {}\n\
             - Successful: {} ({:.1}%)\n\
             - Failed: {}\n\
             - Timeouts: {}\n\
             - Cancelled: {}\n\
             - Avg time: {:.2}ms\n\
             - Peak concurrent: {}",
            self.total_executions,
            self.successful,
            self.success_rate() * 100.0,
            self.failed,
            self.timeouts,
            self.cancelled,
            self.avg_execution_time_ms,
            self.peak_concurrent
        )
    }
}

#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;

    #[test]
    fn test_async_config() {
        let config = AsyncConfig::new()
            .with_max_concurrent(8)
            .with_timeout(5000)
            .with_retry(3, 200);

        assert_eq!(config.max_concurrent, 8);
        assert_eq!(config.timeout_ms, Some(5000));
        assert!(config.enable_retry);
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.backoff_ms, 200);
    }

    #[test]
    fn test_async_stats() {
        let mut stats = AsyncStats::new();
        stats.total_executions = 100;
        stats.successful = 95;
        stats.failed = 3;
        stats.timeouts = 2;

        assert_eq!(stats.success_rate(), 0.95);
        assert!(stats.summary().contains("95.0%"));
    }

    #[test]
    fn test_async_error_display() {
        let err = AsyncExecutionError::<String>::Timeout { elapsed_ms: 5000 };
        assert_eq!(err.to_string(), "Execution timed out after 5000ms");

        let err2 = AsyncExecutionError::<String>::ExecutorBusy { queue_size: 10 };
        assert!(err2.to_string().contains("queue size: 10"));
    }

    #[tokio::test]
    async fn test_execution_handle() {
        let (handle, mut rx) = AsyncExecutionHandle::new("test-123".to_string());
        assert_eq!(handle.execution_id(), "test-123");
        assert!(handle.elapsed().as_millis() < 100);

        // Cancelling should deliver a message on the receiver side.
        handle.cancel().await.unwrap();
        assert!(rx.recv().await.is_some());
    }
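
    /// A minimal no-op executor, added as a sketch to exercise the pool's
    /// index-selection logic; `execute_async` never inspects the graph, so
    /// no real `EinsumGraph` instance is needed in the test below.
    struct NoopExecutor;

    impl TlAsyncExecutor for NoopExecutor {
        type Tensor = f32;
        type Error = String;

        fn execute_async<'a>(
            &'a mut self,
            _graph: &'a tensorlogic_ir::EinsumGraph,
            _inputs: &'a std::collections::HashMap<String, Self::Tensor>,
        ) -> BoxFuture<'a, Result<Vec<Self::Tensor>, Self::Error>> {
            Box::pin(async { Ok(Vec::new()) })
        }
    }

    #[test]
    fn test_pool_index_selection() {
        let pool = AsyncExecutorPool::new(vec![NoopExecutor, NoopExecutor, NoopExecutor]);
        assert_eq!(pool.size(), 3);

        // Round-robin indices wrap around the pool size.
        assert_eq!(pool.get_next_index(), 0);
        assert_eq!(pool.get_next_index(), 1);
        assert_eq!(pool.get_next_index(), 2);
        assert_eq!(pool.get_next_index(), 0);

        // Every no-op executor reports ready, so the first one is chosen.
        assert_eq!(pool.get_least_loaded_index(), 0);
    }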
}