// train_station/lib.rs
//! # Train Station
//!
//! PyTorch-inspired, zero-dependency Rust ML library focused on performance, research, and ergonomics.
//!
//! - **PyTorch-inspired API**: familiar semantics make switching effortless.
//! - **SIMD-aligned memory pool**: per-thread pools with 64/32/16-byte alignment.
//! - **Safe zero-copy views**: reshape/transpose/slice/as_strided with capacity checks and copy-on-write.
//! - **Iterator-first API**: idiomatic Rust iterators that preserve gradients.
//! - **Thread-safe GradTrack**: local/shared graphs with sharded storage; cross-thread capable.
//! - **Broadcasting**: NumPy-style element-wise and batched matmul broadcasting.
//!
//! Mission: enable low-level control and simple composition, providing building blocks for larger components and next-generation architectures.
//!
//! ## Quick Start
//!
//! ```rust
//! use train_station::Tensor;
//!
//! // Parameters
//! let x = Tensor::ones(vec![2, 3]);
//! let w = Tensor::ones(vec![3, 2]).with_requires_grad();
//! let b = Tensor::zeros(vec![2]).with_requires_grad();
//!
//! // Forward and backward
//! let y = x.matmul(&w).add_tensor(&b).relu();
//! let mut loss = y.sum();
//! loss.backward(None);
//!
//! // Access gradients
//! let gw = w.grad_owned().unwrap();
//! assert_eq!(gw.shape().dims(), vec![3, 2]);
//! ```
//!
//! ## Views (zero-copy) with gradients
//!
//! ```rust
//! use train_station::Tensor;
//!
//! let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap().with_requires_grad();
//! let v = x.slice_view(1, 1, 2); // [2.0, 3.0]
//! let mut s = v.sum();
//! s.backward(None);
//! let gx = x.grad_owned().unwrap();
//! assert_eq!(gx.data(), &[0.0, 1.0, 1.0, 0.0]);
//! ```
//!
//! ## Broadcasting
//!
//! ```rust
//! use train_station::Tensor;
//!
//! let a = Tensor::ones(vec![2, 3]);
//! let b = Tensor::ones(vec![1, 3]);
//! let c = a.add_tensor(&b); // [2,3] + [1,3] -> [2,3]
//! assert_eq!(c.shape().dims(), vec![2, 3]);
//! ```
//!
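//! The same rules extend to matmul: leading batch dimensions broadcast, so a stack of
//! matrices can be multiplied by one shared matrix. A minimal sketch of the batched
//! matmul broadcasting mentioned above:
//!
//! ```rust
//! use train_station::Tensor;
//!
//! let a = Tensor::ones(vec![2, 3, 4]); // batch of two [3, 4] matrices
//! let w = Tensor::ones(vec![4, 5]);    // one shared [4, 5] matrix
//! let c = a.matmul(&w);                // [2, 3, 4] x [4, 5] -> [2, 3, 5]
//! assert_eq!(c.shape().dims(), vec![2, 3, 5]);
//! ```
//!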
//! ## Iterators
//!
//! ```rust
//! use train_station::Tensor;
//!
//! let t = Tensor::from_slice(&(0..12).map(|x| x as f32).collect::<Vec<_>>(), vec![12]).unwrap();
//! let chunks: Vec<Tensor> = t.iter_chunks(4).collect();
//! assert_eq!(chunks.len(), 3);
//! ```
//!
//! ## Memory pool controls
//!
//! ```rust
//! use train_station::Tensor;
//! use train_station::tensor::with_no_mem_pool;
//!
//! // Default: pooled allocation
//! let a = Tensor::new(vec![64]);
//!
//! // Force system allocator within a scope (useful for diagnostics or memory studies)
//! let b = with_no_mem_pool(|| Tensor::new(vec![64]));
//! assert_eq!(a.size(), b.size());
//! ```
//!
//! Threading note: the memory pool is thread-local. If you create tensors in a worker
//! thread and return them to another thread (e.g., main), prefer wrapping the creation in
//! `with_no_mem_pool(|| ...)` so those allocations use the system allocator instead of a
//! thread-local pool.
//!
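//! A minimal sketch of that pattern, using only the `with_no_mem_pool` scope shown
//! above: the tensor is created in a worker thread and consumed on the joining thread.
//!
//! ```rust
//! use train_station::Tensor;
//! use train_station::tensor::with_no_mem_pool;
//!
//! let handle = std::thread::spawn(|| {
//!     // System allocator, so the buffer does not belong to this thread's pool
//!     with_no_mem_pool(|| Tensor::ones(vec![8]))
//! });
//! let t = handle.join().unwrap();
//! assert_eq!(t.size(), 8);
//! ```
//!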
//! ## Modules
//!
//! - `tensor`: core tensor, ops (SIMD), broadcasting, views, iterators
//! - `gradtrack`: automatic differentiation (local/shared graph)
//! - `optimizers`: Adam and related utilities
//! - `serialization`: minimal JSON/binary I/O
//! - `device`: CPU today; CUDA scaffolding is feature-gated
//!
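//! These modules compose into a familiar training-loop shape. Note that the optimizer
//! constructor and method names below (`Adam::new`, `step`, `zero_grad`) are
//! illustrative assumptions, not a confirmed signature; see the `optimizers` module
//! docs for the actual API:
//!
//! ```rust,ignore
//! use train_station::Tensor;
//! use train_station::optimizers::Adam;
//!
//! let w = Tensor::ones(vec![3, 2]).with_requires_grad();
//! let mut opt = Adam::new(vec![&w], 1e-3); // hypothetical constructor
//! let mut loss = Tensor::ones(vec![2, 3]).matmul(&w).sum();
//! loss.backward(None);
//! opt.step();      // hypothetical: apply the Adam update
//! opt.zero_grad(); // hypothetical: clear accumulated gradients
//! ```
//!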
//! ## Feature Flags
//!
//! - `cuda`: compile CUDA scaffolding (experimental; CPU is the default path)
//!
//! ## Data types
//!
//! Train Station currently targets `f32` tensors. We will be expanding to additional
//! data types over time; see the Roadmap for direction.
//!
//! ### CUDA status
//!
//! The `cuda` feature is experimental and not ready for general use: it provides
//! scaffolding only, and its API surface may change without notice. Stick to the
//! CPU path for production work.

#[cfg(feature = "cuda")]
pub(crate) mod cuda;
pub(crate) mod device;
pub mod gradtrack;
pub mod optimizers;
pub mod serialization;
pub mod tensor;

pub use device::{
    cuda_device_count, cuda_is_available, current_device, get_default_device, set_default_device,
    with_device, Device, DeviceType,
};

pub use tensor::Tensor;