Crate tensorism

Tensorism is a library built on top of ndarray that provides a domain-specific language (DSL) for expressing tensor computations using named indexes.

The main goal of tensorism is to make multi-index array expressions explicit, readable and compositional, while remaining compatible with the ndarray ecosystem.

At the moment, tensorism focuses on expressiveness and correctness. The evaluation strategy is primarily iterator-based, and no guarantee is made regarding optimal performance. The DSL and its internal translation strategy are considered experimental and may evolve in non-backward-compatible ways in future versions.

§Indexed expressions and the for construct

Tensorism introduces a special construct of the form for ⟨indexes⟩ => ⟨body⟩, where ⟨indexes⟩ is a space-separated list of Rust identifiers.

These identifiers are called Ricci indexes and may be any valid Rust identifier. Names such as i, j or k are used throughout the documentation purely as conventional examples.

A Ricci index represents a variable of type usize. Its range of values is determined by the arrays it indexes inside the associated expression. When several arrays use the same index, the corresponding dimensions must agree; otherwise, the program panics at runtime.

Example:

let x: Array2<u8> = new_ndarray! { for i j => a[i, j] + b[j] };

Here, index i ranges over the first dimension of a, while index j ranges over both the second dimension of a and the only dimension of b, so these two dimensions must have the same length.

§Generating new arrays

When a for ⟨indexes⟩ => ⟨body⟩ construct appears at the top level of a new_ndarray! invocation, it generates a new ndarray::Array whose number of dimensions matches the number of indexes.

For instance, the previous example produces an Array2<T>, where T is the type of the expression a[i, j] + b[j].

More complex expressions involving conditionals and multiple input arrays are also supported:

let y = new_ndarray! {
    for i j k =>
        if p[i, j] - 0.3 < 0.4 * q[j, k] {
            r[j] * q[j, k] + 0.2
        } else {
            0.5 * s[i, j, k]
        }
};
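As a rough illustration, the same computation can be written with explicit index loops in plain Rust. This is a sketch under assumptions: nested Vecs stand in for ndarray arrays, the element type is taken to be f64, shape checks are omitted, and the name eval_conditional is hypothetical. The comments spell out which dimensions each index is shared across.

```rust
/// Plain-Rust analogue of the conditional `for i j k => ...` expression.
/// Hypothetical helper for illustration only; shape checks are omitted.
fn eval_conditional(
    p: &[Vec<f64>],
    q: &[Vec<f64>],
    r: &[f64],
    s: &[Vec<Vec<f64>>],
) -> Vec<Vec<Vec<f64>>> {
    let n_i = p.len(); // i: first dimension of p and of s
    let n_j = r.len(); // j: shared by p (2nd), q (1st), r, and s (2nd)
    let n_k = q.first().map_or(0, |row| row.len()); // k: q (2nd), s (3rd)

    (0..n_i)
        .map(|i| {
            (0..n_j)
                .map(|j| {
                    (0..n_k)
                        .map(|k| {
                            if p[i][j] - 0.3 < 0.4 * q[j][k] {
                                r[j] * q[j][k] + 0.2
                            } else {
                                0.5 * s[i][j][k]
                            }
                        })
                        .collect()
                })
                .collect()
        })
        .collect()
}

fn main() {
    let p = vec![vec![0.0], vec![1.0]]; // shape 2 x 1
    let q = vec![vec![1.0]]; // shape 1 x 1
    let r = vec![2.0]; // shape 1
    let s = vec![vec![vec![4.0]], vec![vec![4.0]]]; // shape 2 x 1 x 1
    let y = eval_conditional(&p, &q, &r, &s);
    println!("{:?}", y);
}
```

Three indexes appear in the for construct, so the result has three dimensions.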

§Iterators and aggregation

A for ⟨indexes⟩ => ⟨body⟩ construct does not have to appear at the top level. When used as a sub-expression, it evaluates to a Rust iterator over the successive values of ⟨body⟩ as the indexes vary.

This makes it possible to express aggregations by passing such iterators to user-defined functions or standard iterator consumers such as sum, min, or fold:

let x: i64 = new_ndarray! {
    Iterator::sum(for i => Iterator::min(for j => a[i, j]).unwrap())
};

Any function or method that consumes an iterator of the appropriate item type can be used.
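For comparison, the sum-of-row-minima aggregation above can be sketched with standard iterators alone. Nested Vecs stand in for the ndarray input, and the name sum_of_row_minima is hypothetical. Note that Iterator::min returns an Option, which is why the DSL example calls unwrap on the inner result.

```rust
/// Plain-Rust analogue of
/// `Iterator::sum(for i => Iterator::min(for j => a[i, j]).unwrap())`.
/// Hypothetical helper for illustration only, not part of tensorism.
fn sum_of_row_minima(a: &[Vec<i64>]) -> i64 {
    a.iter()
        // Inner aggregation over j: the minimum of row i.
        // `min` yields None for an empty row, so this panics in that case.
        .map(|row| row.iter().copied().min().expect("row must not be empty"))
        // Outer aggregation over i: the sum of the row minima.
        .sum()
}

fn main() {
    let a: Vec<Vec<i64>> = vec![vec![3, 1, 2], vec![5, 4, 6]];
    println!("{}", sum_of_row_minima(&a));
}
```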

§Reindexing

In addition to direct indexing, tensorism provides explicit reindexing objects: Reindexing1, Reindexing2, and the higher-arity variants Reindexing3 and Reindexing4. These objects represent immutable, bounded index transformations and can be safely composed inside indexed expressions.

Reindexing allows expressing indirect access patterns while preserving runtime guarantees on index validity.
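The idea of a bounded, immutable index mapping can be illustrated with a small stand-alone type. The sketch below is an assumption-laden model, not tensorism's API: BoundedMap1, its methods, and gather are hypothetical names, and the real Reindexing1 may be constructed and used differently. It only shows how validating all targets once at construction makes later indirect accesses safe.

```rust
/// Hypothetical stand-in for a one-argument reindexing: an immutable map
/// from 0..input_bound to 0..output_bound. Not tensorism's actual API.
struct BoundedMap1 {
    table: Vec<usize>,
    output_bound: usize,
}

impl BoundedMap1 {
    /// Checks every target once, so later lookups cannot leave
    /// the interval 0..output_bound.
    fn new(table: Vec<usize>, output_bound: usize) -> Option<Self> {
        if table.iter().all(|&t| t < output_bound) {
            Some(Self { table, output_bound })
        } else {
            None
        }
    }

    fn input_bound(&self) -> usize {
        self.table.len()
    }

    fn output_bound(&self) -> usize {
        self.output_bound
    }

    /// Maps an input index to its target; panics if `i >= input_bound()`.
    fn apply(&self, i: usize) -> usize {
        self.table[i]
    }
}

/// Indirect access `data[m(i)]` for every i in 0..input_bound.
fn gather(data: &[f64], m: &BoundedMap1) -> Vec<f64> {
    // A single check against the bound covers every access below.
    assert!(m.output_bound() <= data.len(), "output bound exceeds data length");
    (0..m.input_bound()).map(|i| data[m.apply(i)]).collect()
}

fn main() {
    let m = BoundedMap1::new(vec![2, 0, 2], 3).expect("targets within bound");
    let out = gather(&[10.0, 20.0, 30.0], &m);
    println!("{:?}", out);
}
```

Because the mapping is immutable after construction, the bound check does not need to be repeated per element.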

§Experimental status

Tensorism is currently experimental. In particular, the iterator-based evaluation model and the interaction with ndarray’s high-level APIs may change in future releases in order to enable more aggressive optimizations.

Macros§

new_ndarray
new_ndarray! is the main macro of the Tensorism crate for constructing arrays from Rust expressions using the Tensorism DSL.

Structs§

Reindexing1
A type representing an immutable mapping from the integer interval 0..input0_bound to the integer interval 0..output_bound.
Reindexing2
A type representing an immutable mapping from a pair of integer intervals, 0..input0_bound and 0..input1_bound, to the integer interval 0..output_bound.
Reindexing3
A type representing an immutable mapping from a triple of integer intervals, 0..input0_bound, 0..input1_bound, and 0..input2_bound, to the integer interval 0..output_bound.
Reindexing4
A type representing an immutable mapping from a quadruple of integer intervals, 0..input0_bound, 0..input1_bound, 0..input2_bound, and 0..input3_bound, to the integer interval 0..output_bound.