.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/07-extern-functions.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_getting-started_tutorials_07-extern-functions.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_07-extern-functions.py:


Libdevice (`tl.extra.libdevice`) function
=========================================
Triton can invoke a custom function from an external library.
In this example, we will use the `libdevice` library to apply `asin` on a tensor.

Please refer to the `CUDA libdevice-users-guide <https://docs.nvidia.com/cuda/libdevice-users-guide/index.html>`_
and/or the `HIP device-lib source code <https://github.com/ROCm/llvm-project/tree/amd-staging/amd/device-libs>`_
regarding the semantics of all available libdevice functions.

In `libdevice.py`, we try to aggregate functions with the same computation but different data types together.
For example, both `__nv_asin` and `__nv_asinf` calculate the principal value of the arc sine of the input, but
`__nv_asin` operates on `double` and `__nv_asinf` operates on `float`.
Triton automatically selects the correct underlying device function to invoke based on the input and output types.

.. GENERATED FROM PYTHON SOURCE LINES 15-17

asin Kernel
-----------

.. GENERATED FROM PYTHON SOURCE LINES 17-47

.. code-block:: Python

    import inspect
    import os
    from pathlib import Path

    import torch

    import triton
    import triton.language as tl
    from triton.language.extra import libdevice

    DEVICE = triton.runtime.driver.active.get_active_torch_device()


    @triton.jit
    def asin_kernel(
        x_ptr,
        y_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        x = libdevice.asin(x)
        tl.store(y_ptr + offsets, x, mask=mask)

.. GENERATED FROM PYTHON SOURCE LINES 48-51

Using the default libdevice library path
----------------------------------------
We can use the default libdevice library path encoded in `triton/language/math.py`.

.. GENERATED FROM PYTHON SOURCE LINES 51-67

.. code-block:: Python

    torch.manual_seed(0)
    size = 98432
    x = torch.rand(size, device=DEVICE)
    output_triton = torch.zeros(size, device=DEVICE)
    output_torch = torch.asin(x)
    assert x.is_cuda and output_triton.is_cuda
    n_elements = output_torch.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024)
    print(output_torch)
    print(output_triton)
    print(f'The maximum difference between torch and triton is '
          f'{torch.max(torch.abs(output_torch - output_triton))}')

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([0.4105, 0.5430, 0.0249,  ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
    tensor([0.4105, 0.5430, 0.0249,  ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
    The maximum difference between torch and triton is 2.384185791015625e-07
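Because the underlying device function is chosen from the element type, the same kernel also
handles `float64` inputs without modification. The following check is not part of the original
script; it is a minimal sketch reusing the `asin_kernel`, `grid`, `size`, and `n_elements`
defined above, and on CUDA it should dispatch to the double-precision `__nv_asin` rather than
`__nv_asinf`.

.. code-block:: Python

    # Not part of the original tutorial: re-run the same kernel on float64 data.
    # Triton infers the element type from the pointers, so libdevice.asin should
    # lower to the double-precision __nv_asin on CUDA backends.
    x_fp64 = torch.rand(size, device=DEVICE, dtype=torch.float64)
    output_fp64 = torch.empty_like(x_fp64)
    asin_kernel[grid](x_fp64, output_fp64, n_elements, BLOCK_SIZE=1024)
    print(f'fp64 maximum difference vs torch: '
          f'{torch.max(torch.abs(torch.asin(x_fp64) - output_fp64))}')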
.. GENERATED FROM PYTHON SOURCE LINES 68-71

Customize the libdevice library path
------------------------------------
We can also customize the libdevice library path by passing an `extern_libs` dictionary,
which maps library names to the paths of their bitcode files, when launching the `asin` kernel.

.. GENERATED FROM PYTHON SOURCE LINES 71-100

.. code-block:: Python

    def is_cuda():
        return triton.runtime.driver.active.get_current_target().backend == "cuda"


    def is_hip():
        return triton.runtime.driver.active.get_current_target().backend == "hip"


    current_file = inspect.getfile(inspect.currentframe())
    current_dir = Path(os.path.dirname(os.path.abspath(current_file)))

    if is_cuda():
        libdir = current_dir.parent.parent / 'third_party/nvidia/backend/lib'
        extern_libs = {'libdevice': str(libdir / 'libdevice.10.bc')}
    elif is_hip():
        libdir = current_dir.parent.parent / 'third_party/amd/backend/lib'
        extern_libs = {}
        libs = ["ocml", "ockl"]
        for lib in libs:
            extern_libs[lib] = str(libdir / f'{lib}.bc')
    else:
        raise RuntimeError('unknown backend')

    output_triton = torch.empty_like(x)
    asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024, extern_libs=extern_libs)
    print(output_torch)
    print(output_triton)
    print(f'The maximum difference between torch and triton is '
          f'{torch.max(torch.abs(output_torch - output_triton))}')

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([0.4105, 0.5430, 0.0249,  ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
    tensor([0.4105, 0.5430, 0.0249,  ..., 0.0424, 0.5351, 0.8149], device='cuda:0')
    The maximum difference between torch and triton is 2.384185791015625e-07

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.233 seconds)


.. _sphx_glr_download_getting-started_tutorials_07-extern-functions.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 07-extern-functions.ipynb <07-extern-functions.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 07-extern-functions.py <07-extern-functions.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 07-extern-functions.zip <07-extern-functions.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
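The same pattern extends to other libdevice functions. The following is a minimal sketch,
not part of the original tutorial, assuming `libdevice.pow` is exposed by
`triton.language.extra.libdevice` (wrapping `__nv_powf`/`__nv_pow` on CUDA); an element-wise
power kernel then looks just like `asin_kernel` above and reuses `x`, `grid`, `size`, and
`n_elements` from earlier in this example.

.. code-block:: Python

    @triton.jit
    def pow_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        # As with asin, the float/double variant is selected from the input type.
        tl.store(out_ptr + offsets, libdevice.pow(x, y), mask=mask)


    y = torch.rand(size, device=DEVICE)
    out = torch.empty_like(x)
    pow_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    print(f'pow maximum difference vs torch: '
          f'{torch.max(torch.abs(out - torch.pow(x, y)))}')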