triton.language.tensor

class triton.language.tensor(self, handle, type: dtype)

Represents an N-dimensional array of values or pointers.

tensor is the fundamental data structure in Triton programs. Most functions in triton.language operate on and return tensors.

Most of the named member functions here are duplicates of the free functions in triton.language. For example, triton.language.sqrt(x) is equivalent to x.sqrt(). An exception is to(), which has no equivalent free function.

tensor also defines most of the magic (dunder) methods, so you can write x + y, x << 2, and so on.
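
The method/free-function relationship described above can be modeled with a toy class in plain Python. This is a sketch, not how Triton implements tensor: ToyTensor, member_fn, and this sqrt are hypothetical names chosen for illustration, and ToyTensor only holds a 1D list of numbers on the host. It shows a free function also exposed as a forwarding method, plus dunder methods that make x + y and x << 2 work elementwise.

```python
import math
from typing import Callable, List


class ToyTensor:
    """Hypothetical 1D stand-in for a tensor: just a list of numbers."""

    def __init__(self, values: List[float]):
        self.values = list(values)

    # Dunder methods let Python operators work elementwise.
    def __add__(self, other: "ToyTensor") -> "ToyTensor":
        return ToyTensor([a + b for a, b in zip(self.values, other.values)])

    def __lshift__(self, shift: int) -> "ToyTensor":
        return ToyTensor([int(a) << shift for a in self.values])


def member_fn(fn: Callable) -> Callable:
    """Hypothetical decorator: attach a free function to ToyTensor as a forwarding method."""
    setattr(ToyTensor, fn.__name__, lambda self, *args, **kwargs: fn(self, *args, **kwargs))
    return fn


@member_fn
def sqrt(x: ToyTensor) -> ToyTensor:
    """Free function; x.sqrt() forwards here."""
    return ToyTensor([math.sqrt(v) for v in x.values])
```

With x = ToyTensor([1.0, 4.0, 9.0]), both sqrt(x) and x.sqrt() produce the values [1.0, 2.0, 3.0], because the method is only a thin wrapper around the free function.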

Nontrivial methods

to(self, dtype: dtype, fp_downcast_rounding: str | None = None, bitcast: bool = False)

Casts the tensor to the given dtype.

Parameters:
  • dtype – The target data type.

  • fp_downcast_rounding – The rounding mode for downcasting floating-point values. This parameter is only used when self is a floating-point tensor and dtype is a floating-point type with a smaller bitwidth. Supported values are "rtne" (round to nearest, ties to even) and "rtz" (round towards zero).

  • bitcast – If True, the tensor's bits are reinterpreted as the given dtype (a bitcast) instead of its values being converted.
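
The difference between a value cast, a bitcast, and the two fp_downcast_rounding modes can be modeled on the host with Python's struct module. This is a plain-Python sketch, not Triton code. The helper names are hypothetical, float-to-int value casts are assumed to truncate toward zero, and the rounding demo uses fp32 to bfloat16 (which keeps the upper 16 bits of an fp32 value) only because it is easy to emulate. Only finite inputs are handled; NaN is not treated specially.

```python
import struct


def f32_to_bits(x: float) -> int:
    """Raw IEEE-754 binary32 bit pattern of x, as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(bits: int) -> float:
    """Interpret the low 32 bits of an int as an IEEE-754 binary32 value."""
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def value_cast_to_i32(x: float) -> int:
    """Model of a value cast to int32: convert the number (assumed truncation toward zero)."""
    return int(x)


def bitcast_to_i32(x: float) -> int:
    """Model of a bitcast to int32: reinterpret the same 32 bits as a signed int."""
    return struct.unpack("<i", struct.pack("<f", x))[0]


def downcast_to_bf16(x: float, rounding: str = "rtne") -> float:
    """Round an fp32 value to bfloat16 precision with the given rounding mode."""
    bits = f32_to_bits(x)
    if rounding == "rtz":
        # Round toward zero: simply drop the low 16 mantissa bits.
        upper = bits >> 16
    elif rounding == "rtne":
        # Round to nearest, ties to even: add 0x7FFF plus the lowest kept bit, then drop.
        lsb = (bits >> 16) & 1
        upper = (bits + 0x7FFF + lsb) >> 16
    else:
        raise ValueError(f"unsupported rounding mode: {rounding!r}")
    return bits_to_f32(upper << 16)
```

For example, a value cast of 1.0 gives 1 while a bitcast gives 0x3F800000 (1065353216). The fp32 value 1.005859375 lies above the midpoint between the bfloat16 neighbours 1.0 and 1.0078125, so "rtne" rounds it up to 1.0078125 while "rtz" truncates it to 1.0.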

Constructors

__init__(self, handle, type: dtype)

Not called by user code.

Methods

abs(self)

Forwards to abs() free function

advance(self, offsets)

Forwards to advance() free function

argmax(*self, **kwargs)

Forwards to argmax() free function

argmin(*self, **kwargs)

Forwards to argmin() free function

associative_scan(self, axis, combine_fn[, ...])

Forwards to associative_scan() free function

atomic_add(self, val[, mask, sem, scope])

Forwards to atomic_add() free function

atomic_and(self, val[, mask, sem, scope])

Forwards to atomic_and() free function

atomic_cas(self, cmp, val[, sem, scope])

Forwards to atomic_cas() free function

atomic_max(self, val[, mask, sem, scope])

Forwards to atomic_max() free function

atomic_min(self, val[, mask, sem, scope])

Forwards to atomic_min() free function

atomic_or(self, val[, mask, sem, scope])

Forwards to atomic_or() free function

atomic_xchg(self, val[, mask, sem, scope])

Forwards to atomic_xchg() free function

atomic_xor(self, val[, mask, sem, scope])

Forwards to atomic_xor() free function

broadcast_to(self, *shape)

Forwards to broadcast_to() free function

cdiv(*self, **kwargs)

Forwards to cdiv() free function

ceil(self)

Forwards to ceil() free function

cos(self)

Forwards to cos() free function

cumprod(*self, **kwargs)

Forwards to cumprod() free function

cumsum(*self, **kwargs)

Forwards to cumsum() free function

erf(self)

Forwards to erf() free function

exp(self)

Forwards to exp() free function

exp2(self)

Forwards to exp2() free function

expand_dims(self, axis)

Forwards to expand_dims() free function

flip(*self, **kwargs)

Forwards to flip() free function

floor(self)

Forwards to floor() free function

histogram(self, num_bins)

Forwards to histogram() free function

log(self)

Forwards to log() free function

log2(self)

Forwards to log2() free function

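logical_and(self, other)

Computes the elementwise logical AND of self and other.

logical_or(self, other)

Computes the elementwise logical OR of self and other.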

max(*self, **kwargs)

Forwards to max() free function

min(*self, **kwargs)

Forwards to min() free function

permute(self, *dims)

Forwards to permute() free function

ravel(*self, **kwargs)

Forwards to ravel() free function

reduce(self, axis, combine_fn[, keep_dims])

Forwards to reduce() free function

reshape(self, *shape[, can_reorder])

Forwards to reshape() free function

rsqrt(self)

Forwards to rsqrt() free function

sigmoid(*self, **kwargs)

Forwards to sigmoid() free function

sin(self)

Forwards to sin() free function

softmax(*self, **kwargs)

Forwards to softmax() free function

sort(*self, **kwargs)

Forwards to sort() free function

split(self)

Forwards to split() free function

sqrt(self)

Forwards to sqrt() free function

sqrt_rn(self)

Forwards to sqrt_rn() free function

store(self, value[, mask, boundary_check, ...])

Forwards to store() free function

sum(*self, **kwargs)

Forwards to sum() free function

to(self, dtype[, fp_downcast_rounding, bitcast])

Casts the tensor to the given dtype.

trans(self, *dims)

Forwards to trans() free function

view(self, *shape)

Forwards to view() free function

xor_sum(self[, axis, keep_dims])

Forwards to xor_sum() free function
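
Several of the methods listed above (reduce, associative_scan, and the cumsum, cumprod, and sum wrappers) revolve around a binary combine function. The plain-Python sketch below models what an inclusive scan and a reduction compute over a 1D sequence; it is not Triton code, it ignores axis, keep_dims, and the other options, and the helper names are hypothetical. The inclusive form is an assumption based on the cumsum naming. Both helpers combine elements in a tree order rather than strictly left to right, which is why combine_fn must be associative for the result to match a sequential loop.

```python
from typing import Callable, List, TypeVar

V = TypeVar("V")


def inclusive_scan(xs: List[V], combine_fn: Callable[[V, V], V]) -> List[V]:
    """Hillis-Steele inclusive scan: about log2(n) rounds of pairwise combines."""
    out = list(xs)
    offset = 1
    while offset < len(out):
        # Each round combines element i - offset (left) with element i (right),
        # reading only values from the previous round, as parallel lanes would.
        prev = list(out)
        for i in range(offset, len(out)):
            out[i] = combine_fn(prev[i - offset], prev[i])
        offset *= 2
    return out


def tree_reduce(xs: List[V], combine_fn: Callable[[V, V], V]) -> V:
    """Pairwise tree reduction; matches a left-to-right fold only if combine_fn is associative."""
    if not xs:
        raise ValueError("cannot reduce an empty sequence")
    level = list(xs)
    while len(level) > 1:
        nxt = [combine_fn(level[i], level[i + 1]) for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:
            # An odd element out is carried unchanged to the next level.
            nxt.append(level[-1])
        level = nxt
    return level[0]
```

With addition as combine_fn, inclusive_scan([1, 2, 3, 4], ...) yields the running sums [1, 3, 6, 10] (cumsum-style); with multiplication it yields [1, 2, 6, 24] (cumprod-style). Subtraction is not associative, so tree_reduce([1, 2, 3, 4], ...) returns 0 rather than the left-fold result -8.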

Attributes

T

Transposes a 2D tensor.
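
As a host-side illustration of T, the sketch below stores a tensor as a flat row-major buffer plus a shape and shows that transposing a 2D tensor is the (1, 0) axis permutation, the same kind of reordering permute and trans describe. It is plain Python, not Triton code; permute and transpose_2d here are hypothetical helpers, and the equivalence between T and a (1, 0) permutation is an assumption for 2D inputs.

```python
from itertools import product
from typing import List, Sequence, Tuple


def permute(flat: List[float], shape: Sequence[int],
            dims: Sequence[int]) -> Tuple[List[float], Tuple[int, ...]]:
    """Reorder the axes of a row-major buffer: output axis k is input axis dims[k]."""
    if sorted(dims) != list(range(len(shape))):
        raise ValueError("dims must be a permutation of the axes")
    new_shape = tuple(shape[d] for d in dims)
    # Row-major strides of the input (last axis is contiguous).
    strides = [1] * len(shape)
    for k in range(len(shape) - 2, -1, -1):
        strides[k] = strides[k + 1] * shape[k + 1]
    out = []
    # product() walks output indices in row-major order (last index fastest).
    for out_idx in product(*(range(n) for n in new_shape)):
        # out_idx[k] is the coordinate along input axis dims[k].
        src = sum(out_idx[k] * strides[dims[k]] for k in range(len(dims)))
        out.append(flat[src])
    return out, new_shape


def transpose_2d(flat: List[float], shape: Sequence[int]) -> Tuple[List[float], Tuple[int, ...]]:
    """Model of x.T for a 2D tensor: the (1, 0) permutation."""
    if len(shape) != 2:
        raise ValueError("T is documented for 2D tensors only")
    return permute(flat, shape, (1, 0))
```

For a 2x3 input [[0, 1, 2], [3, 4, 5]] stored as [0, 1, 2, 3, 4, 5], transpose_2d returns the buffer [0, 3, 1, 4, 2, 5] with shape (3, 2), i.e. [[0, 3], [1, 4], [2, 5]].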