Rust bindings for XLA (Accelerated Linear Algebra).
XLA is a compiler library for Machine Learning. It can be used to run models efficiently on GPUs, TPUs, and CPUs.
XlaOp values are used to build a computation graph. This graph can be built into an XlaComputation. This computation can then be compiled into a PjRtLoadedExecutable, and this executable can then be run on a PjRtClient. Literal values are used to represent tensors in host memory, and PjRtBuffer values represent views of tensors/memory on the targeted device.
The following example illustrates how to build and run a simple computation.
// Create a CPU client.
let client = xla::PjRtClient::cpu()?;
// A builder object is used to store the graph of XlaOp.
let builder = xla::XlaBuilder::new("test-builder");
// Build a simple graph summing two constants.
let cst20 = builder.constant_r0(20f32);
let cst22 = builder.constant_r0(22f32);
let sum = (cst20 + cst22)?;
// Create a computation from the final node.
let sum = sum.build()?;
// Compile this computation for the target device and then execute it.
let result = client.compile(&sum)?;
let result = &result.execute::<xla::Literal>(&[])?;
// Retrieve the resulting value.
let result = result[0][0].to_literal_sync()?.to_vec::<f32>()?;
Structs
- A literal represents a value, typically a multi-dimensional array, stored on the host device.
- A buffer represents a view on a memory slice hosted on a device.
- A client represents a device that can be used to run some computations. A computation graph is compiled in a way that is specific to a device before it can be run.
- A device attached to a PjRtClient.
- A computation is built from a root XlaOp. Computations are device independent and can be specialized to a given device through a compilation step.
Enums
- Main library error type.
- The primitive types supported by XLA. S8 is a signed 1-byte integer, U32 is an unsigned 4-byte integer, etc.
- A shape specifies a primitive type as well as some array dimensions.
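To make the naming scheme and the idea of a shape concrete, here is a minimal standalone sketch in plain Rust. The Prim enum and SketchShape struct are hypothetical stand-ins written for illustration, not the crate's own definitions. Only the S8 and U32 names and sizes come from the descriptions above; F32 is assumed to be a 4-byte float following the same pattern.

```rust
/// Hypothetical stand-in for a primitive element type (not the crate's enum).
#[derive(Clone, Copy, Debug, PartialEq)]
enum Prim {
    S8,  // signed integer, 1 byte
    U32, // unsigned integer, 4 bytes
    F32, // assumption: 32-bit float, 4 bytes
}

impl Prim {
    /// Size of a single element in bytes.
    fn size_in_bytes(self) -> usize {
        match self {
            Prim::S8 => 1,
            Prim::U32 | Prim::F32 => 4,
        }
    }
}

/// Hypothetical array shape: an element type plus array dimensions.
struct SketchShape {
    ty: Prim,
    dims: Vec<usize>,
}

impl SketchShape {
    /// Number of elements; an empty `dims` describes a scalar (one element).
    fn element_count(&self) -> usize {
        self.dims.iter().product()
    }

    /// Bytes needed to store every element contiguously.
    fn byte_size(&self) -> usize {
        self.element_count() * self.ty.size_in_bytes()
    }
}
```

Under these assumptions, a 2 by 3 array of U32 elements holds 6 elements and occupies 24 bytes, while a scalar F32 (no dimensions) occupies 4 bytes.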
Traits
- A type implementing the NativeType trait can be directly converted to constant ops or literals.
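The general pattern of a trait that ties Rust scalar types to element types can be sketched in plain Rust. The SketchNative trait, the ElemType enum, and the to_host_bytes helper below are hypothetical illustrations of that pattern. They are not the crate's NativeType definition, whose actual methods are not listed on this page.

```rust
/// Hypothetical element-type tag (illustrative only).
#[derive(Debug, PartialEq)]
enum ElemType {
    F32,
    I32,
}

/// Hypothetical trait implemented by Rust types that map onto an element type.
trait SketchNative: Copy {
    /// The element type this Rust type corresponds to.
    const TY: ElemType;
    /// Append the little-endian bytes of `self` to `out`.
    fn write_le(self, out: &mut Vec<u8>);
}

impl SketchNative for f32 {
    const TY: ElemType = ElemType::F32;
    fn write_le(self, out: &mut Vec<u8>) {
        out.extend_from_slice(&self.to_le_bytes());
    }
}

impl SketchNative for i32 {
    const TY: ElemType = ElemType::I32;
    fn write_le(self, out: &mut Vec<u8>) {
        out.extend_from_slice(&self.to_le_bytes());
    }
}

/// Pack a slice of native values into a tagged byte buffer,
/// loosely mimicking how a host-side value could be assembled.
fn to_host_bytes<T: SketchNative>(values: &[T]) -> (ElemType, Vec<u8>) {
    let mut out = Vec::with_capacity(values.len() * std::mem::size_of::<T>());
    for &v in values {
        v.write_le(&mut out);
    }
    (T::TY, out)
}
```

Because the element type is an associated constant, generic code such as to_host_bytes can pick the right tag at compile time without the caller naming it.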