svod_tensor/tensor_registry.rs
1//! Global tensor registry for atomic graph substitution.
2//!
3//! This module implements Tinygrad's `all_tensors` pattern using papaya's lock-free HashMap.
4//! When rangeify transforms a UOp (e.g., NEG → BUFFERIZE(NEG)), the `becomes_map` must be
5//! applied to ALL tensors that reference it - not just the one being realized.
6//!
7//! Without this, diamond patterns (like argmin's NEG feeding both MAX and EQ) fail because
8//! different consumers see different versions of the same producer.
9//!
10//! # Thread Safety
11//!
//! All operations are thread-safe. Registry map operations are lock-free, using papaya's
//! epoch-based reclamation for concurrent access; each entry's UOp sits behind a
//! parking_lot::RwLock so it can be swapped in place during global substitution.
14//!
15//! # Memory Management (Tinygrad-aligned)
16//!
17//! Tensors are stored as `Weak<TensorEntry>` in the registry. When all strong references
18//! (held by `Tensor` structs) are dropped, the entry becomes eligible for cleanup.
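//! Dead weak refs are not removed automatically: a lookup of a dropped tensor simply fails
//! to upgrade, and the stale slots are reclaimed only by `gc_dead_refs()`.
//!
//! This matches Tinygrad's `weakref.WeakKeyDictionary` pattern - tensor lifetimes need no
//! manual management; only the registry's own bookkeeping is reclaimed via `gc_dead_refs()`.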
22//!
23//! # Buffer Storage
24//!
25//! Buffers are stored in a separate map (`BUFFERS`) indexed by UOp ID.
26//! This is a lookup index for `collect_input_buffers()` during schedule creation.
27//! - Key is UOp ID (unique per buffer via `Op::Unique` monotonic counter)
//! - Value is `Arc<Buffer>` (strong ref — kept alive until removed by `remove_buffer()` or `gc_dead_refs()`)
29//! - Stale entries cleaned via `gc_dead_refs()` when the UOp is no longer alive
30//!
31//! Unlike Tinygrad (which stores buffers inline in UOps), Svod uses a separate
//! index because UOps are immutable and hash-consed. Unique buffer IDs guarantee
//! entries never collide, so stale entries never cause wrong lookups — but each one
//! still holds its `Arc<Buffer>`, so they are a memory concern until collected.
34//!
35//! TensorEntry also caches the buffer for direct access via tensor.buffer().
36
37use std::collections::HashMap;
38use std::sync::atomic::{AtomicU64, Ordering};
39use std::sync::{Arc, OnceLock, Weak};
40
41use papaya::HashMap as PapayaMap;
42use parking_lot::RwLock;
43use svod_device::Buffer;
44use svod_ir::{Op, UOp, UOpKey};
45
/// Source of unique tensor IDs; starts at 0 and is bumped by one per allocation.
static TENSOR_ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Reserve and return the next unused tensor ID.
///
/// `Relaxed` ordering is enough here: callers only rely on each ID being
/// distinct, not on any ordering relative to other memory operations.
fn next_tensor_id() -> u64 {
    AtomicU64::fetch_add(&TENSOR_ID_COUNTER, 1, Ordering::Relaxed)
}
52
53/// Entry in the global tensor registry.
54///
55/// Uses RwLock for interior mutability of the UOp during global substitution.
56/// The RwLock allows concurrent reads (typical tensor operations) with exclusive
57/// writes only during `apply_map_to_tensors`.
58///
59/// Buffer can be set at construction (input tensors) or later (realized tensors).
60/// Uses OnceLock for thread-safe one-time initialization.
61pub struct TensorEntry {
62 /// Unique tensor ID (stable across UOp updates).
63 pub id: u64,
64 /// The computation graph (mutable for global substitution).
65 pub uop: RwLock<Arc<UOp>>,
66 /// The materialized buffer (can be set once via OnceLock).
67 buffer: OnceLock<Arc<Buffer>>,
68}
69
70impl TensorEntry {
71 /// Get the buffer if materialized.
72 pub fn buffer(&self) -> Option<&Arc<Buffer>> {
73 self.buffer.get()
74 }
75
76 /// Set the buffer (can only be called once, subsequent calls are no-ops).
77 /// Returns true if buffer was set, false if already set.
78 pub fn set_buffer(&self, buffer: Arc<Buffer>) -> bool {
79 self.buffer.set(buffer).is_ok()
80 }
81}
82
83// Global tensor registry using lock-free concurrent HashMap.
84//
85// Design: Stores Weak<TensorEntry> for automatic memory management (Tinygrad-aligned).
86// - Tensor structs hold Arc<TensorEntry> (strong refs)
87// - Registry holds Weak<TensorEntry> (weak refs)
88// - When Tensor is dropped, TensorEntry can be cleaned up
89// - Dead weak refs cleaned lazily on access or via gc_dead_refs()
90static TENSORS: OnceLock<PapayaMap<u64, Weak<TensorEntry>>> = OnceLock::new();
91
92// Direct buffer storage: UOp ID → Arc<Buffer>.
93//
94// Lookup index for collect_input_buffers() during schedule creation.
95// Buffer UOp IDs are unique (Op::Unique monotonic counter), so entries never
96// collide across tests. Stale entries cleaned via gc_dead_refs().
97static BUFFERS: OnceLock<PapayaMap<u64, Arc<Buffer>>> = OnceLock::new();
98
99fn tensors() -> &'static PapayaMap<u64, Weak<TensorEntry>> {
100 TENSORS.get_or_init(PapayaMap::new)
101}
102
103fn buffers() -> &'static PapayaMap<u64, Arc<Buffer>> {
104 BUFFERS.get_or_init(PapayaMap::new)
105}
106
107/// Register a new tensor without buffer (for lazy computation graphs).
108///
109/// Thread-safe: each call creates a unique tensor ID.
110/// The registry stores a weak reference; the caller holds the strong reference.
111///
112/// # Arguments
113///
114/// * `uop` - The tensor's computation graph
115///
116/// # Returns
117///
118/// Arc to the registered TensorEntry (caller owns the strong reference)
119pub fn register_tensor(uop: Arc<UOp>) -> Arc<TensorEntry> {
120 let id = next_tensor_id();
121 let entry = Arc::new(TensorEntry { id, uop: RwLock::new(uop), buffer: OnceLock::new() });
122
123 // Store weak ref in registry - entry stays alive as long as caller holds Arc
124 let guard = tensors().guard();
125 tensors().insert(id, Arc::downgrade(&entry), &guard);
126
127 entry
128}
129
130/// Register a new tensor with buffer (for input tensors and realized tensors).
131///
132/// Stores buffer in both:
133/// 1. BUFFERS map (indexed by UOp ID) - for schedule buffer lookups
134/// 2. TensorEntry.buffer - for direct tensor access
135///
136/// The registry stores a weak reference; the caller holds the strong reference.
137///
138/// # Arguments
139///
140/// * `uop` - The tensor's computation graph
141/// * `buffer` - The materialized buffer
142/// * `buffer_uop_id` - The UOp ID to index under (for lookups)
143///
144/// # Returns
145///
146/// Arc to the registered TensorEntry (caller owns the strong reference)
147pub fn register_tensor_with_buffer(uop: Arc<UOp>, buffer: Arc<Buffer>, buffer_uop_id: u64) -> Arc<TensorEntry> {
148 let id = next_tensor_id();
149 let entry = Arc::new(TensorEntry { id, uop: RwLock::new(uop), buffer: OnceLock::from(buffer.clone()) });
150
151 // Store weak ref in tensor registry
152 let guard = tensors().guard();
153 tensors().insert(id, Arc::downgrade(&entry), &guard);
154
155 // Store buffer indexed by UOp ID (for collect_input_buffers lookups)
156 let buf_guard = buffers().guard();
157 buffers().insert(buffer_uop_id, buffer, &buf_guard);
158
159 entry
160}
161
162/// Get buffer by UOp ID (cloned).
163///
164/// Direct lookup from BUFFERS map.
165/// Used by collect_input_buffers() during schedule creation.
166pub fn get_buffer(uop_id: u64) -> Option<Buffer> {
167 let guard = buffers().guard();
168 buffers().get(&uop_id, &guard).map(|arc_buf| (**arc_buf).clone())
169}
170
171/// Get buffer by UOp ID as Arc (shared reference, no clone).
172///
173/// Used by ensure_buffer() to attach a buffer without cloning.
174pub fn get_buffer_arc(uop_id: u64) -> Option<Arc<Buffer>> {
175 let guard = buffers().guard();
176 buffers().get(&uop_id, &guard).cloned()
177}
178
179/// Remove buffer entry from the BUFFERS map.
180///
181/// Called during cleanup to eagerly remove stale entries.
182pub fn remove_buffer(uop_id: u64) {
183 let buf_guard = buffers().guard();
184 buffers().remove(&uop_id, &buf_guard);
185}
186
187/// Get count of buffers in the registry (for testing/diagnostics).
188pub fn buffer_count() -> usize {
189 buffers().len()
190}
191
192/// Register a buffer for an existing tensor.
193///
194/// Used by realize() to associate output buffers with tensors for schedule lookups.
195/// Stores buffer in both BUFFERS map and TensorEntry.
196///
197/// # Arguments
198///
199/// * `uop_id` - The UOp ID to index under (for lookups)
200/// * `tensor_id` - The tensor ID that owns this buffer
201/// * `buffer` - The materialized buffer
202pub fn register_buffer(uop_id: u64, tensor_id: u64, buffer: Arc<Buffer>) {
203 // Store buffer indexed by UOp ID (for collect_input_buffers lookups)
204 let buf_guard = buffers().guard();
205 buffers().insert(uop_id, buffer.clone(), &buf_guard);
206
207 // Also set buffer on the TensorEntry for direct tensor access
208 if let Some(entry) = get_tensor(tensor_id) {
209 entry.set_buffer(buffer);
210 }
211}
212
213/// Register a buffer by UOp ID only (no TensorEntry association).
214///
215/// Used for pending assign side-realization where the buffer belongs
216/// to the computation graph, not a specific tensor.
217pub fn register_buffer_by_uop_id(uop_id: u64, buffer: Arc<Buffer>) {
218 let guard = buffers().guard();
219 buffers().insert(uop_id, buffer, &guard);
220}
221
222/// Get a tensor entry by ID.
223///
224/// Thread-safe read operation. Returns None if tensor was dropped.
225pub fn get_tensor(id: u64) -> Option<Arc<TensorEntry>> {
226 let guard = tensors().guard();
227 tensors().get(&id, &guard)?.upgrade()
228}
229
230/// Remove dead weak references and stale buffer entries from the registry.
231///
232/// Tensors: removes entries whose `Weak<TensorEntry>` can no longer be upgraded.
233/// Buffers: removes entries whose UOp is no longer alive in the UOp cache.
234///
235/// This is optional — stale entries don't affect correctness (unique buffer IDs
236/// prevent collisions). Call this to reclaim registry memory in long-running programs.
237pub fn gc_dead_refs() {
238 // Clean dead tensor weak refs
239 let map = tensors();
240 let guard = map.guard();
241 let to_remove: Vec<u64> = map.iter(&guard).filter(|(_, weak)| weak.upgrade().is_none()).map(|(k, _)| *k).collect();
242 for id in to_remove {
243 map.remove(&id, &guard);
244 }
245
246 // Clean stale buffer entries (UOp no longer alive in the cache)
247 let live_uop_ids = svod_ir::uop::live_uop_ids();
248 let buf_map = buffers();
249 let buf_guard = buf_map.guard();
250 let stale_bufs: Vec<u64> =
251 buf_map.iter(&buf_guard).filter(|(uop_id, _)| !live_uop_ids.contains(uop_id)).map(|(id, _)| *id).collect();
252 for uop_id in stale_bufs {
253 buf_map.remove(&uop_id, &buf_guard);
254 }
255}
256
257/// Legacy alias for gc_dead_refs (for compatibility).
258///
259/// With weak references, tensors are automatically cleaned up when no longer
260/// referenced. This function now just cleans up dead weak refs in the registry.
261#[deprecated(note = "Tensor registry now uses weak refs - cleanup is automatic. Use gc_dead_refs() to clean registry.")]
262pub fn gc_unused_tensors() {
263 gc_dead_refs();
264}
265
266/// Apply a transformation map to ALL live tensors globally.
267///
268/// This is Svod's equivalent of Tinygrad's `_apply_map_to_tensors`.
269/// When rangeify creates a becomes_map (old UOp → new UOp), this function
270/// ensures ALL tensors see the same transformed versions.
271///
272/// # Arguments
273///
274/// * `becomes_map` - Mapping from original UOps to their transformed versions
275///
276/// # Thread Safety
277///
278/// This function acquires write locks on affected tensors during the update phase.
279/// Other tensors can still be read/written concurrently.
280#[allow(clippy::mutable_key_type)]
281pub fn apply_map_to_tensors(becomes_map: &HashMap<UOpKey, Arc<UOp>>) {
282 apply_map_to_tensors_inner(becomes_map, false);
283}
284
285/// Walk variant: replacements are NOT re-traversed.
286///
287/// Use when a replacement may contain the original key, such as the
288/// view-assign case `Buffer → After(Buffer, [Store(...)])`.
289#[allow(clippy::mutable_key_type)]
290pub fn apply_map_to_tensors_walk(becomes_map: &HashMap<UOpKey, Arc<UOp>>) {
291 apply_map_to_tensors_inner(becomes_map, true);
292}
293
294#[allow(clippy::mutable_key_type)]
295fn apply_map_to_tensors_inner(becomes_map: &HashMap<UOpKey, Arc<UOp>>, walk: bool) {
296 if becomes_map.is_empty() {
297 return;
298 }
299
300 let map = tensors();
301 let guard = map.guard();
302
303 // Phase 1: Find affected tensors (read-only scan, skip dead weak refs)
304 let affected: Vec<Arc<TensorEntry>> = map
305 .iter(&guard)
306 .filter_map(|(_, weak)| {
307 let entry = weak.upgrade()?; // Skip dead entries
308 let is_affected = {
309 let uop = entry.uop.read();
310 // Check if tensor's root UOp is in map
311 if becomes_map.contains_key(&UOpKey(uop.clone())) {
312 true
313 } else {
314 // Check if any node in the graph is in map
315 uop.toposort().iter().any(|n| becomes_map.contains_key(&UOpKey(n.clone())))
316 }
317 }; // uop lock dropped here
318 if is_affected { Some(entry) } else { None }
319 })
320 .collect();
321
322 if affected.is_empty() {
323 return;
324 }
325
326 // Phase 2: Create SINK of affected tensor UOps
327 let sources: Vec<Arc<UOp>> = affected.iter().map(|e| e.uop.read().clone()).collect();
328 let sink = UOp::sink(sources.clone());
329
330 // Phase 3: Atomic substitution across all affected UOps
331 let new_sink = if walk { sink.substitute_walk(becomes_map) } else { sink.substitute(becomes_map) };
332
333 // Phase 4: Update each tensor's UOp (acquires write locks)
334 if let Op::Sink { sources: new_sources, .. } = new_sink.op() {
335 for (entry, (old, new)) in affected.iter().zip(sources.iter().zip(new_sources.iter())) {
336 if !Arc::ptr_eq(old, new) {
337 *entry.uop.write() = new.clone();
338 }
339 }
340 }
341}
342
343#[cfg(test)]
344#[path = "test/unit/tensor_registry.rs"]
345mod tests;