triton.experimental.gluon.language.map_elementwise
- triton.experimental.gluon.language.map_elementwise(scalar_fn: Callable[[...], Tuple[tensor, ...]], *args: tensor, pack=1, _semantic=None, _generator=None)
Map a scalar function over a tensor.
The input tensors args are implicitly broadcast to the same shape.

This is useful for applying control flow to individual elements of a tensor, for example a multi-branch function where one branch is more expensive than the other. With tl.where, both sides of the branch are computed for every element; with an if statement inside the scalar function, only the taken side is executed.

    import triton
    import triton.language as tl

    @triton.jit
    def selu_scalar(x, alpha):
        # Called per element: the exp branch runs only when x <= 0.
        if x > 0:
            return x
        else:
            return alpha * (tl.exp(x) - 1)

    @triton.jit
    def selu(x, alpha):
        # Broadcasts x and alpha, then applies selu_scalar to each element.
        return tl.map_elementwise(selu_scalar, x, alpha)
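The NumPy sketch below is a CPU model of these semantics, not Triton code. The hypothetical helper map_elementwise_reference broadcasts its inputs and calls a scalar function once per element, mirroring the selu_scalar example. It counts exp evaluations to contrast the per-element branch with np.where, which, like tl.where, evaluates both sides for every element before selecting.

```python
import math

import numpy as np


def map_elementwise_reference(scalar_fn, *args):
    """Hypothetical CPU model: broadcast the inputs, then call scalar_fn once per element."""
    arrays = np.broadcast_arrays(*[np.asarray(a, dtype=np.float64) for a in args])
    out = np.empty(arrays[0].shape, dtype=np.float64)
    for idx in np.ndindex(out.shape):
        out[idx] = scalar_fn(*(float(a[idx]) for a in arrays))
    return out


exp_calls = 0  # counts how often the expensive branch actually runs


def selu_scalar(x, alpha):
    # Scalar branch, as in the Triton example: exp is evaluated only when x <= 0.
    global exp_calls
    if x > 0:
        return x
    exp_calls += 1
    return alpha * (math.exp(x) - 1)


x = np.array([[-1.0, 0.5, 2.0], [3.0, -0.25, 1.0]])
alpha = 1.5  # a scalar input broadcasts against the 2x3 tensor
mapped = map_elementwise_reference(selu_scalar, x, alpha)

# np.where computes np.exp for all six elements, then selects per element.
where_result = np.where(x > 0, x, alpha * (np.exp(x) - 1))
print(np.allclose(mapped, where_result), exp_calls)  # True 2
```

Both approaches give the same values; the difference is that the scalar version evaluates exp only for the two non-positive elements.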
- Parameters:
scalar_fn – the function to map over.
*args – the input tensors; they are implicitly broadcast to a common shape.
pack – the number of elements to be processed by one function call.
- Returns:
one tensor or a tuple of tensors, depending on whether the mapped function returns one value or several.
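To illustrate the pack parameter and tuple returns, here is another hypothetical NumPy model, map_elementwise_model. Its calling convention for pack > 1 is an assumption made for illustration and may not match how Triton orders packed arguments or results: each call receives pack consecutive elements of every input, input by input, and returns pack values per output, output by output.

```python
import numpy as np


def map_elementwise_model(scalar_fn, *args, pack=1):
    """Hypothetical CPU model: broadcast inputs, call scalar_fn on groups of `pack` elements.

    Assumed calling convention (a modeling choice, not taken from Triton):
    scalar_fn receives `pack` consecutive elements of each input (input by input)
    and returns `pack` values per output (output by output).
    """
    arrays = np.broadcast_arrays(*[np.asarray(a) for a in args])
    shape = arrays[0].shape
    flat = [a.ravel() for a in arrays]
    n = flat[0].size
    if pack < 1 or n % pack != 0:
        raise ValueError("pack must be >= 1 and divide the number of elements")
    outputs = None  # one flat list of results per output
    for start in range(0, n, pack):
        call_args = [a[i] for a in flat for i in range(start, start + pack)]
        result = scalar_fn(*call_args)
        values = result if isinstance(result, tuple) else (result,)
        if len(values) % pack != 0:
            raise ValueError("scalar_fn must return pack values per output")
        n_out = len(values) // pack
        if outputs is None:
            outputs = [[] for _ in range(n_out)]
        for k in range(n_out):
            outputs[k].extend(values[k * pack:(k + 1) * pack])
    results = tuple(np.array(o).reshape(shape) for o in outputs)
    # Mirror the documented return: one tensor, or a tuple for several outputs.
    return results[0] if len(results) == 1 else results


def divmod_pair(a0, a1, b0, b1):
    # pack=2: two elements of `a`, then two of `b`; return quotients, then remainders.
    return (a0 // b0, a1 // b1, a0 % b0, a1 % b1)


a = np.array([[7, 8], [9, 10]])
q, r = map_elementwise_model(divmod_pair, a, 3, pack=2)
print(q.tolist(), r.tolist())  # [[2, 2], [3, 3]] [[1, 2], [0, 1]]
```

With pack=2 the four elements are handled in two calls instead of four, and because divmod_pair produces two outputs, the model returns a tuple of two tensors.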