Skip to main content

tensorlogic_infer/
async_exec.rs

1//! Asynchronous execution traits for concurrent tensor operations.
2//!
3//! This module provides async/await-based execution interfaces for
4//! non-blocking tensor computations and streaming operations.
5//!
6//! Note: Async support requires the "async" feature flag.
7
8#[cfg(feature = "async")]
9use std::collections::HashMap;
10#[cfg(feature = "async")]
11use std::future::Future;
12#[cfg(feature = "async")]
13use std::pin::Pin;
14
15#[cfg(feature = "async")]
16use tensorlogic_ir::EinsumGraph;
17
18#[cfg(feature = "async")]
19use crate::batch::BatchResult;
20#[cfg(feature = "async")]
21use crate::streaming::{StreamResult, StreamingConfig};
22
/// Pinned, heap-allocated, type-erased future that is `Send` and may borrow
/// data for the lifetime `'a`.
///
/// This is the return type of the async trait methods declared in this module.
#[cfg(feature = "async")]
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
26
/// Non-blocking graph executor.
///
/// Implementors hand back boxed futures instead of blocking the caller, so
/// several graph executions can be in flight concurrently (for example in an
/// inference server or a streaming pipeline).
#[cfg(feature = "async")]
pub trait TlAsyncExecutor {
    /// Tensor type consumed and produced by this executor (required to be `Send`).
    type Tensor: Send;
    /// Error type reported by this executor (required to be `Send`).
    type Error: Send;

    /// Start executing `graph` with `inputs` and return a future resolving to
    /// the output tensors or an executor error.
    ///
    /// `inputs` is keyed by `String` (presumably graph input names — confirm
    /// against implementors).
    fn execute_async<'ex>(
        &'ex mut self,
        graph: &'ex EinsumGraph,
        inputs: &'ex HashMap<String, Self::Tensor>,
    ) -> BoxFuture<'ex, Result<Vec<Self::Tensor>, Self::Error>>;

    /// Non-blocking readiness probe; the default implementation always
    /// reports ready.
    fn is_ready(&self) -> bool {
        true
    }

    /// Resolve once [`is_ready`](Self::is_ready) returns `true`, re-checking
    /// every 10 ms.
    ///
    /// NOTE(review): the returned future keeps `&mut self` borrowed for the
    /// whole wait, so readiness can only change if `is_ready` observes state
    /// outside that exclusive borrow (interior mutability, shared counters,
    /// external resources). Otherwise an executor that starts out not ready
    /// makes this future never resolve.
    fn wait_ready(&mut self) -> BoxFuture<'_, ()>
    where
        Self: Send,
    {
        let poll_interval = std::time::Duration::from_millis(10);
        Box::pin(async move {
            loop {
                if self.is_ready() {
                    break;
                }
                tokio::time::sleep(poll_interval).await;
            }
        })
    }
}
60
/// Extension of [`TlAsyncExecutor`] that runs a single graph over a batch of
/// input maps.
#[cfg(feature = "async")]
pub trait TlAsyncBatchExecutor: TlAsyncExecutor {
    /// Asynchronously execute `graph` for `batch_inputs`, where each element
    /// has the same form as the `inputs` of
    /// [`TlAsyncExecutor::execute_async`], and collect the outcome into a
    /// [`BatchResult`].
    fn execute_batch_async<'ex>(
        &'ex mut self,
        graph: &'ex EinsumGraph,
        batch_inputs: Vec<HashMap<String, Self::Tensor>>,
    ) -> BoxFuture<'ex, Result<BatchResult<Self::Tensor>, Self::Error>>;
}
71
/// Output of [`TlAsyncStreamExecutor::execute_stream_async`]: a list of
/// individual results, each either a [`StreamResult`] or an executor error.
#[cfg(feature = "async")]
pub type AsyncStreamResults<T, E> = Vec<Result<StreamResult<T>, E>>;

/// Extension of [`TlAsyncExecutor`] for chunked, streaming execution.
#[cfg(feature = "async")]
pub trait TlAsyncStreamExecutor: TlAsyncExecutor {
    /// Asynchronously execute `graph` over `input_stream`, chunked according
    /// to `config`.
    ///
    /// Every entry of the returned vector is its own `Result`, so failures can
    /// be reported per entry.
    ///
    /// NOTE(review): the meaning of the three nesting levels of
    /// `input_stream` is not defined here — confirm against `crate::streaming`.
    fn execute_stream_async<'ex>(
        &'ex mut self,
        graph: &'ex EinsumGraph,
        input_stream: Vec<Vec<Vec<Self::Tensor>>>,
        config: &'ex StreamingConfig,
    ) -> BoxFuture<'ex, AsyncStreamResults<Self::Tensor, Self::Error>>;
}
87
/// Ways an asynchronous execution can fail.
///
/// `E` is the error type of the wrapped executor and is carried unchanged by
/// [`AsyncExecutionError::ExecutorError`].
#[derive(Debug, Clone)]
pub enum AsyncExecutionError<E> {
    /// The execution exceeded its time budget.
    Timeout { elapsed_ms: u64 },
    /// The executor refused work because it is overloaded.
    ExecutorBusy { queue_size: usize },
    /// Cancellation was requested.
    Cancelled,
    /// The wrapped executor reported an error.
    ExecutorError(E),
    /// The future was dropped before it completed.
    Dropped,
}

impl<E> std::fmt::Display for AsyncExecutionError<E>
where
    E: std::fmt::Display,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Timeout { elapsed_ms } => {
                write!(f, "Execution timed out after {elapsed_ms}ms")
            }
            Self::ExecutorBusy { queue_size } => write!(
                f,
                "Executor is busy (queue size: {queue_size}), try again later"
            ),
            Self::Cancelled => f.write_str("Execution was cancelled"),
            Self::ExecutorError(inner) => write!(f, "Executor error: {inner}"),
            Self::Dropped => f.write_str("Future was dropped before completion"),
        }
    }
}

impl<E> std::error::Error for AsyncExecutionError<E> where E: std::error::Error {}
124
/// Handle to an in-flight async execution: identifies it, reports how long it
/// has been running, and can request its cancellation.
///
/// It is created together with the receiving half of a capacity-1
/// cancellation channel (see [`AsyncExecutionHandle::new`]); the executing
/// side is expected to watch that receiver.
#[cfg(feature = "async")]
pub struct AsyncExecutionHandle {
    /// Caller-supplied identifier of the execution.
    execution_id: String,
    /// Instant at which the handle was created.
    started_at: std::time::Instant,
    /// Sending half of the cancellation channel.
    cancel_token: tokio::sync::mpsc::Sender<()>,
}

#[cfg(feature = "async")]
impl AsyncExecutionHandle {
    /// Create a handle stamped with the current time.
    ///
    /// Returns the handle plus the receiver of its cancellation channel, which
    /// has capacity 1.
    pub fn new(execution_id: String) -> (Self, tokio::sync::mpsc::Receiver<()>) {
        let (tx, rx) = tokio::sync::mpsc::channel(1);
        (
            AsyncExecutionHandle {
                execution_id,
                started_at: std::time::Instant::now(),
                cancel_token: tx,
            },
            rx,
        )
    }

    /// Identifier passed to [`AsyncExecutionHandle::new`].
    pub fn execution_id(&self) -> &str {
        &self.execution_id
    }

    /// Time elapsed since the handle was created.
    pub fn elapsed(&self) -> std::time::Duration {
        self.started_at.elapsed()
    }

    /// Request cancellation.
    ///
    /// Idempotent and non-blocking. If a cancellation signal is already queued
    /// and not yet consumed, this returns `Ok(())` immediately. The previous
    /// `send().await` would instead wait for channel capacity, which could hang
    /// forever if the receiver never reads again.
    ///
    /// # Errors
    ///
    /// Returns [`AsyncExecutionError::Cancelled`] if the receiving half has
    /// been dropped, so the signal can no longer be delivered.
    /// NOTE(review): `Dropped` may describe this case better; `Cancelled` is
    /// kept for compatibility with existing callers.
    pub async fn cancel(&self) -> Result<(), AsyncExecutionError<std::io::Error>> {
        use tokio::sync::mpsc::error::TrySendError;

        match self.cancel_token.try_send(()) {
            // A full buffer means a cancellation signal is already pending;
            // one queued signal is enough.
            Ok(()) | Err(TrySendError::Full(())) => Ok(()),
            Err(TrySendError::Closed(())) => Err(AsyncExecutionError::Cancelled),
        }
    }
}
166
/// Fixed collection of async executors with simple index-selection helpers
/// for spreading work across them.
///
/// The pool owns its executors and does not itself track in-flight work.
#[cfg(feature = "async")]
pub struct AsyncExecutorPool<E: TlAsyncExecutor> {
    /// Executors owned by the pool, addressed by index.
    executors: Vec<E>,
    /// Round-robin cursor advanced by
    /// [`AsyncExecutorPool::get_next_index`]; wraps on overflow.
    next_index: std::sync::atomic::AtomicUsize,
}

#[cfg(feature = "async")]
impl<E: TlAsyncExecutor> AsyncExecutorPool<E> {
    /// Create a pool owning `executors`, with the round-robin cursor at 0.
    pub fn new(executors: Vec<E>) -> Self {
        AsyncExecutorPool {
            executors,
            next_index: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Number of executors in the pool.
    pub fn size(&self) -> usize {
        self.executors.len()
    }

    /// Next executor index in round-robin order.
    ///
    /// The counter is an `AtomicUsize`, so this works through `&self`. On
    /// overflow the counter wraps, which only causes a one-off skip in the
    /// rotation.
    ///
    /// # Panics
    ///
    /// Panics if the pool is empty.
    pub fn get_next_index(&self) -> usize {
        let len = self.executors.len();
        // Fail with a descriptive message instead of a bare remainder-by-zero panic.
        assert!(
            len > 0,
            "AsyncExecutorPool::get_next_index called on an empty pool"
        );
        self.next_index
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
            % len
    }

    /// Index of the first executor whose [`TlAsyncExecutor::is_ready`]
    /// returns `true`, or `0` if none does (including when the pool is empty).
    ///
    /// Despite the name, this does not measure load; per-executor load
    /// tracking would be needed for a real least-loaded policy.
    pub fn get_least_loaded_index(&self) -> usize {
        self.executors
            .iter()
            .position(|executor| executor.is_ready())
            .unwrap_or(0)
    }

    /// Run `graph` on the executor chosen by
    /// [`AsyncExecutorPool::get_least_loaded_index`].
    ///
    /// # Panics
    ///
    /// Panics if the pool is empty.
    pub async fn execute_any<'a>(
        &'a mut self,
        graph: &'a EinsumGraph,
        inputs: &'a HashMap<String, E::Tensor>,
    ) -> Result<Vec<E::Tensor>, E::Error> {
        let index = self.get_least_loaded_index();
        let executor = self
            .executors
            .get_mut(index)
            .expect("AsyncExecutorPool::execute_any called on an empty pool");
        executor.execute_async(graph, inputs).await
    }
}
218
/// Tuning settings for asynchronous execution.
///
/// Build one with [`AsyncConfig::new`] (same as [`Default`]) and chain the
/// `with_*` methods.
///
/// NOTE(review): nothing in this module enforces these settings; confirm
/// which executor consumes them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AsyncConfig {
    /// Maximum number of concurrent executions.
    pub max_concurrent: usize,
    /// Per-execution timeout in milliseconds; `None` means no timeout is set.
    pub timeout_ms: Option<u64>,
    /// Whether failed executions should be retried.
    pub enable_retry: bool,
    /// Maximum number of retries (relevant when `enable_retry` is set).
    pub max_retries: usize,
    /// Backoff delay between retries, in milliseconds.
    /// NOTE(review): whether this is a fixed delay or the base of a growing
    /// backoff is up to the consumer.
    pub backoff_ms: u64,
}

impl Default for AsyncConfig {
    /// 4 concurrent executions, no timeout, retry disabled
    /// (3 retries / 100 ms backoff used once enabled via the fields).
    fn default() -> Self {
        AsyncConfig {
            max_concurrent: 4,
            timeout_ms: None,
            enable_retry: false,
            max_retries: 3,
            backoff_ms: 100,
        }
    }
}

impl AsyncConfig {
    /// Create a configuration with the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the maximum number of concurrent executions.
    #[must_use]
    pub fn with_max_concurrent(mut self, max: usize) -> Self {
        self.max_concurrent = max;
        self
    }

    /// Set the per-execution timeout in milliseconds.
    #[must_use]
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = Some(timeout_ms);
        self
    }

    /// Enable retries with the given retry limit and backoff (milliseconds).
    #[must_use]
    pub fn with_retry(mut self, max_retries: usize, backoff_ms: u64) -> Self {
        self.enable_retry = true;
        self.max_retries = max_retries;
        self.backoff_ms = backoff_ms;
        self
    }
}
272
/// Counters describing asynchronous execution activity.
///
/// All fields are public and start at zero via [`Default`].
#[derive(Debug, Clone, Default)]
pub struct AsyncStats {
    /// Number of executions started.
    pub total_executions: usize,
    /// Number of executions that completed successfully.
    pub successful: usize,
    /// Number of executions that failed.
    pub failed: usize,
    /// Number of executions that timed out.
    pub timeouts: usize,
    /// Number of executions that were cancelled.
    pub cancelled: usize,
    /// Average execution time, in milliseconds.
    pub avg_execution_time_ms: f64,
    /// Highest number of executions observed running at once.
    pub peak_concurrent: usize,
}

impl AsyncStats {
    /// Create a zeroed set of statistics.
    pub fn new() -> Self {
        Default::default()
    }

    /// Fraction of started executions that succeeded, in `[0, 1]` when
    /// counts are consistent; `0.0` when nothing has been started.
    pub fn success_rate(&self) -> f64 {
        match self.total_executions {
            0 => 0.0,
            total => self.successful as f64 / total as f64,
        }
    }

    /// Multi-line, human-readable report of all counters.
    pub fn summary(&self) -> String {
        let lines = [
            String::from("Async Execution Stats:"),
            format!("- Total: {}", self.total_executions),
            format!(
                "- Successful: {} ({:.1}%)",
                self.successful,
                self.success_rate() * 100.0
            ),
            format!("- Failed: {}", self.failed),
            format!("- Timeouts: {}", self.timeouts),
            format!("- Cancelled: {}", self.cancelled),
            format!("- Avg time: {:.2}ms", self.avg_execution_time_ms),
            format!("- Peak concurrent: {}", self.peak_concurrent),
        ];
        lines.join("\n")
    }
}
329
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;

    /// Builder methods set exactly the fields they name.
    #[test]
    fn test_async_config() {
        let AsyncConfig {
            max_concurrent,
            timeout_ms,
            enable_retry,
            max_retries,
            backoff_ms,
        } = AsyncConfig::new()
            .with_max_concurrent(8)
            .with_timeout(5000)
            .with_retry(3, 200);

        assert_eq!(max_concurrent, 8);
        assert_eq!(timeout_ms, Some(5000));
        assert!(enable_retry);
        assert_eq!(max_retries, 3);
        assert_eq!(backoff_ms, 200);
    }

    /// Success rate and summary reflect the counters.
    #[test]
    fn test_async_stats() {
        let stats = AsyncStats {
            total_executions: 100,
            successful: 95,
            failed: 3,
            timeouts: 2,
            ..AsyncStats::new()
        };

        assert_eq!(stats.success_rate(), 0.95);
        assert!(stats.summary().contains("95.0%"));
    }

    /// Display output for a couple of error variants.
    #[test]
    fn test_async_error_display() {
        let timeout: AsyncExecutionError<String> =
            AsyncExecutionError::Timeout { elapsed_ms: 5000 };
        assert_eq!(timeout.to_string(), "Execution timed out after 5000ms");

        let busy: AsyncExecutionError<String> =
            AsyncExecutionError::ExecutorBusy { queue_size: 10 };
        assert!(busy.to_string().contains("queue size: 10"));
    }

    /// A fresh handle reports its id, a small elapsed time, and delivers a
    /// cancellation signal to its receiver.
    #[tokio::test]
    async fn test_execution_handle() {
        let (handle, mut cancel_rx) = AsyncExecutionHandle::new(String::from("test-123"));
        assert_eq!(handle.execution_id(), "test-123");
        assert!(handle.elapsed() < std::time::Duration::from_millis(100));

        handle
            .cancel()
            .await
            .expect("cancellation signal should be delivered");
        assert!(cancel_rx.recv().await.is_some());
    }
}