1use crate::graph::{ComputationGraph, NodeId};
4use crate::{CompiledKernel, ExecutionStats, JitError, JitResult, TensorRef};
5use std::collections::HashMap;
6use std::sync::{Arc, Mutex};
7use std::time::Instant;
8
/// JIT execution runtime. Cloning is cheap: the cache and statistics are
/// shared behind `Arc<Mutex<_>>`, so all clones observe the same state.
#[derive(Clone)]
pub struct JitRuntime {
    // Compiled-kernel cache keyed by kernel id, shared across clones/threads.
    cache: Arc<Mutex<KernelCache>>,

    // Aggregate execution statistics (timing, launches, cache hit rate).
    stats: Arc<Mutex<ExecutionStats>>,

    // Snapshot of the relevant `JitConfig` switches taken at construction.
    config: RuntimeConfig,
}
21
22impl JitRuntime {
23 pub fn new(config: crate::JitConfig) -> Self {
25 Self {
26 cache: Arc::new(Mutex::new(KernelCache::new())),
27 stats: Arc::new(Mutex::new(ExecutionStats::default())),
28 config: RuntimeConfig::from_jit_config(config),
29 }
30 }
31
32 pub fn execute(
34 &self,
35 graph: &ComputationGraph,
36 kernels: &[CompiledKernel],
37 inputs: &[TensorRef],
38 ) -> JitResult<Vec<TensorRef>> {
39 let start_time = Instant::now();
40
41 let mut context = ExecutionContext::new(graph, inputs)?;
43
44 for kernel in kernels {
46 self.execute_kernel(&mut context, kernel)?;
47 }
48
49 self.update_stats(start_time.elapsed().as_micros() as u64, kernels.len());
51
52 context.get_outputs()
54 }
55
56 fn execute_kernel(
58 &self,
59 context: &mut ExecutionContext,
60 kernel: &CompiledKernel,
61 ) -> JitResult<()> {
62 let cache_hit = if self.config.enable_caching {
64 self.cache
65 .lock()
66 .expect("lock should not be poisoned")
67 .get(&kernel.id)
68 .is_some()
69 } else {
70 false
71 };
72
73 if cache_hit {
74 let mut cache = self.cache.lock().expect("lock should not be poisoned");
76 if let Some(exec_fn) = cache.get(&kernel.id) {
77 exec_fn(context, kernel)?;
78 }
79 } else {
80 let exec_fn = self.compile_kernel(kernel)?;
82
83 exec_fn(context, kernel)?;
85
86 if self.config.enable_caching {
88 self.cache
89 .lock()
90 .expect("cache lock should not be poisoned")
91 .insert(kernel.id.clone(), exec_fn);
92 }
93 }
94
95 Ok(())
96 }
97
98 fn compile_kernel(&self, _kernel: &CompiledKernel) -> JitResult<ExecutableFn> {
100 Ok(Box::new(move |context, kernel| {
107 interpreter_execute(context, kernel)
108 }))
109 }
110
111 fn update_stats(&self, elapsed_us: u64, kernel_count: usize) {
113 let mut stats = self.stats.lock().expect("lock should not be poisoned");
114 stats.total_time_us += elapsed_us;
115 stats.kernel_launches += kernel_count;
116
117 let cache = self.cache.lock().expect("lock should not be poisoned");
119 stats.cache_hit_rate = cache.hit_rate();
120 }
121
122 pub fn stats(&self) -> ExecutionStats {
124 self.stats
125 .lock()
126 .expect("lock should not be poisoned")
127 .clone()
128 }
129
130 pub fn clear_cache(&self) {
132 self.cache
133 .lock()
134 .expect("lock should not be poisoned")
135 .clear();
136 }
137}
138
/// Runtime-facing subset of the crate-level `JitConfig`.
#[derive(Debug, Clone)]
struct RuntimeConfig {
    // Whether compiled kernels are memoized in the shared cache.
    enable_caching: bool,
    // Currently unused (hence `allow(dead_code)`) — copied for future use.
    #[allow(dead_code)]
    enable_profiling: bool,
    // Currently unused; `KernelCache` carries its own independent limit.
    #[allow(dead_code)]
    max_cache_size: usize,
}
148
149impl RuntimeConfig {
150 fn from_jit_config(config: crate::JitConfig) -> Self {
151 Self {
152 enable_caching: config.enable_caching,
153 enable_profiling: config.enable_profiling,
154 max_cache_size: 1000, }
156 }
157}
158
/// In-memory cache of compiled kernel entry points with hit/miss accounting.
struct KernelCache {
    // Kernel id -> executable entry point.
    cache: HashMap<String, ExecutableFn>,
    // Successful lookups since construction or the last `clear`.
    hits: usize,
    // Failed lookups since construction or the last `clear`.
    misses: usize,
    // Entry cap; inserting beyond it evicts one (arbitrary) entry first.
    max_size: usize,
}
166
167impl KernelCache {
168 fn new() -> Self {
169 Self {
170 cache: HashMap::new(),
171 hits: 0,
172 misses: 0,
173 max_size: 1000,
174 }
175 }
176
177 fn get(&mut self, key: &str) -> Option<&ExecutableFn> {
178 if self.cache.contains_key(key) {
179 self.hits += 1;
180 self.cache.get(key)
181 } else {
182 self.misses += 1;
183 None
184 }
185 }
186
187 fn insert(&mut self, key: String, value: ExecutableFn) {
188 if self.cache.len() >= self.max_size {
190 if let Some(first_key) = self.cache.keys().next().cloned() {
192 self.cache.remove(&first_key);
193 }
194 }
195
196 self.cache.insert(key, value);
197 }
198
199 fn clear(&mut self) {
200 self.cache.clear();
201 self.hits = 0;
202 self.misses = 0;
203 }
204
205 fn hit_rate(&self) -> f32 {
206 let total = self.hits + self.misses;
207 if total > 0 {
208 self.hits as f32 / total as f32
209 } else {
210 0.0
211 }
212 }
213}
214
/// Callable produced by `compile_kernel` and stored in the kernel cache.
/// `Send + Sync` is required because the cache lives behind `Arc<Mutex<_>>`
/// and may be touched from multiple threads.
type ExecutableFn =
    Box<dyn Fn(&mut ExecutionContext, &CompiledKernel) -> JitResult<()> + Send + Sync>;
218
/// Per-execution tensor environment threaded through each kernel launch.
pub struct ExecutionContext {
    // Original input tensors, retained for the duration of the run
    // (currently unused after seeding, hence `allow(dead_code)`).
    #[allow(dead_code)]
    inputs: Vec<TensorRef>,

    // Node id -> tensor bindings; seeded from the graph inputs and grown
    // as kernels produce results.
    intermediates: HashMap<NodeId, TensorRef>,

    // Graph output node ids, in the order results are returned.
    output_ids: Vec<NodeId>,
}
231
232impl ExecutionContext {
233 fn new(graph: &ComputationGraph, inputs: &[TensorRef]) -> JitResult<Self> {
235 if inputs.len() != graph.inputs.len() {
236 return Err(JitError::RuntimeError(format!(
237 "Expected {} inputs, got {}",
238 graph.inputs.len(),
239 inputs.len()
240 )));
241 }
242
243 let mut intermediates = HashMap::new();
244
245 for (i, &node_id) in graph.inputs.iter().enumerate() {
247 intermediates.insert(node_id, inputs[i].clone());
248 }
249
250 Ok(Self {
251 inputs: inputs.to_vec(),
252 intermediates,
253 output_ids: graph.outputs.clone(),
254 })
255 }
256
257 pub fn get_tensor(&self, node_id: NodeId) -> Option<&TensorRef> {
259 self.intermediates.get(&node_id)
260 }
261
262 pub fn set_tensor(&mut self, node_id: NodeId, tensor: TensorRef) {
264 self.intermediates.insert(node_id, tensor);
265 }
266
267 fn get_outputs(&self) -> JitResult<Vec<TensorRef>> {
269 let mut outputs = Vec::new();
270
271 for &output_id in &self.output_ids {
272 let tensor = self.intermediates.get(&output_id).ok_or_else(|| {
273 JitError::RuntimeError(format!("Output node {:?} not computed", output_id))
274 })?;
275 outputs.push(tensor.clone());
276 }
277
278 Ok(outputs)
279 }
280}
281
282fn interpreter_execute(context: &mut ExecutionContext, kernel: &CompiledKernel) -> JitResult<()> {
284 if kernel.source_nodes.is_empty() {
288 let missing_outputs: Vec<_> = context
292 .output_ids
293 .iter()
294 .filter(|&&id| !context.intermediates.contains_key(&id))
295 .copied()
296 .collect();
297
298 for &output_id in &missing_outputs {
299 let input_data = if let Some(input_tensor) = context.intermediates.values().next() {
301 input_tensor.data.clone()
302 } else {
303 vec![1.0; 10] };
305
306 let output_data: Vec<f32> = input_data
308 .iter()
309 .map(|&x| if x > 0.0 { x } else { 0.0 })
310 .collect();
311
312 let output_tensor = crate::TensorRef { data: output_data };
313 context.set_tensor(output_id, output_tensor);
314 }
315 } else {
316 for &node_id in &kernel.source_nodes {
318 let input_data = if let Some(input_tensor) = context.intermediates.values().next() {
321 input_tensor.data.clone()
322 } else {
323 vec![1.0; 10] };
325
326 let output_data: Vec<f32> = input_data
328 .iter()
329 .map(|&x| if x > 0.0 { x } else { 0.0 })
330 .collect();
331
332 let output_tensor = crate::TensorRef { data: output_data };
333 context.set_tensor(node_id, output_tensor);
334 }
335 }
336
337 Ok(())
338}
339
/// Size-bucketed free list of reusable byte buffers.
///
/// Buffers are bucketed by the next power of two of their size, so a
/// released buffer can satisfy later requests of a similar size class.
pub struct MemoryPool {
    // Bucket size (power of two) -> buffers available for reuse.
    pools: HashMap<usize, Vec<Vec<u8>>>,
}

impl MemoryPool {
    /// Creates a pool holding no cached buffers.
    pub fn new() -> Self {
        MemoryPool {
            pools: HashMap::new(),
        }
    }

    /// Hands out a buffer of exactly `size` zeroed bytes, recycling a pooled
    /// buffer from the matching bucket when one is available.
    pub fn allocate(&mut self, size: usize) -> Vec<u8> {
        let bucket = size.next_power_of_two();

        let recycled = self
            .pools
            .get_mut(&bucket)
            .and_then(|free_list| free_list.pop());

        match recycled {
            Some(mut buffer) => {
                // Pooled buffers were cleared on release, so this zero-fills.
                buffer.resize(size, 0);
                buffer
            }
            None => vec![0u8; size],
        }
    }

    /// Returns `buffer` to the bucket matching its capacity for later reuse.
    pub fn release(&mut self, mut buffer: Vec<u8>) {
        let bucket = buffer.capacity().next_power_of_two();
        buffer.clear();

        self.pools.entry(bucket).or_default().push(buffer);
    }
}

impl Default for MemoryPool {
    fn default() -> Self {
        Self::new()
    }
}
379
380#[cfg(test)]
381mod tests {
382 use super::*;
383 use crate::graph::{ComputationGraph, Node};
384
    /// `KernelCache::get` both returns the entry and mutates the hit/miss
    /// counters, so the lookup order below is significant.
    #[test]
    fn test_kernel_cache() {
        let mut cache = KernelCache::new();
        cache.max_size = 2;

        let fn1: ExecutableFn = Box::new(|_, _| Ok(()));
        cache.insert("kernel1".to_string(), fn1);

        // One hit, then one miss — each recorded by `get` itself.
        assert!(cache.get("kernel1").is_some());
        assert!(cache.get("kernel2").is_none());

        assert_eq!(cache.hits, 1);
        assert_eq!(cache.misses, 1);
        assert_eq!(cache.hit_rate(), 0.5);
    }
401
    /// Allocate → release → allocate round-trip; the second allocation is
    /// the same size class, so the pool may recycle the released buffer.
    #[test]
    fn test_memory_pool() {
        let mut pool = MemoryPool::new();

        let buf1 = pool.allocate(100);
        assert_eq!(buf1.len(), 100);

        pool.release(buf1);

        let buf2 = pool.allocate(100);
        assert_eq!(buf2.len(), 100);
    }
416
    /// Context construction succeeds when the number of provided tensors
    /// matches the graph's declared input count.
    #[test]
    fn test_execution_context() {
        let mut graph = ComputationGraph::new();

        // Single F32 CPU input node with shape [10].
        let input_node = graph.add_node(
            Node::new(crate::graph::Operation::Input, "input".to_string())
                .with_output_shapes(vec![Some(crate::graph::shape_from_slice(&[10]))])
                .with_dtypes(vec![torsh_core::DType::F32])
                .with_device(torsh_core::DeviceType::Cpu),
        );
        graph.add_input(input_node);

        // One tensor for the one declared input.
        let inputs = vec![crate::TensorRef {
            data: vec![1.0; 10],
        }];

        let context = ExecutionContext::new(&graph, &inputs);
        assert!(context.is_ok());
    }
437}