.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/04-low-memory-dropout.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_04-low-memory-dropout.py:

Low-Memory Dropout
==================

In this tutorial, you will write a memory-efficient implementation of dropout whose state is
composed of a single int32 seed. This differs from more traditional implementations of dropout,
whose state is generally a bit mask tensor of the same shape as the input.

In doing so, you will learn about:

* The limitations of naive implementations of Dropout with PyTorch.

* Parallel pseudo-random number generation in Triton.

.. GENERATED FROM PYTHON SOURCE LINES 18-34

Baseline
--------

The *dropout* operator was first introduced in [SRIVASTAVA2014]_ as a way to improve the
performance of deep neural networks in low-data regimes (i.e., as a form of regularization).
It takes a vector as input and produces a vector of the same shape as output. Each scalar in
the output has a probability :math:`p` of being set to zero, and otherwise it is copied from
the input. This forces the network to perform well even when only :math:`1 - p` scalars from
the input are available.

At evaluation time we want to use the full power of the network, so we set :math:`p=0`.
Naively, this would increase the norm of the output (which can be a bad thing, e.g. it can
lead to an artificial decrease in the output softmax temperature). To prevent this we
multiply the output by :math:`\frac{1}{1 - p}`, which keeps the expected norm of the output
consistent regardless of the dropout probability.

Let's first take a look at the baseline implementation.

.. GENERATED FROM PYTHON SOURCE LINES 34-86

.. code-block:: Python

    import tabulate
    import torch

    import triton
    import triton.language as tl


    @triton.jit
    def _dropout(
        x_ptr,  # pointer to the input
        x_keep_ptr,  # pointer to a mask of 0s and 1s
        output_ptr,  # pointer to the output
        n_elements,  # number of elements in the `x` tensor
        p,  # probability that an element of `x` is changed to zero
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        # Load data
        x = tl.load(x_ptr + offsets, mask=mask)
        x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
        # The line below is the crucial part, described in the paragraph above!
        output = tl.where(x_keep, x / (1 - p), 0.0)
        # Write-back output
        tl.store(output_ptr + offsets, output, mask=mask)


    def dropout(x, x_keep, p):
        output = torch.empty_like(x)
        assert x.is_contiguous()
        n_elements = x.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
        _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
        return output


    # Input tensor
    x = torch.randn(size=(10, )).cuda()
    # Dropout mask
    p = 0.5
    x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
    # Apply the baseline dropout kernel
    output = dropout(x, x_keep=x_keep, p=p)
    print(tabulate.tabulate([
        ["input"] + x.tolist(),
        ["keep mask"] + x_keep.tolist(),
        ["output"] + output.tolist(),
    ]))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ---------  -------  ---------  --------  --------  --------  --------  --------  ---------  ---------  ---------
    input      1.541    -0.293429  -2.17879  0.568431  -1.08452  -1.3986   0.403347  0.838026   -0.719258  -0.403344
    keep mask  1         1          0         1         0         1        1         0           0          0
    output     3.08199  -0.586858   0         1.13686   0        -2.79719  0.806694  0           0          0
    ---------  -------  ---------  --------  --------  --------  --------  --------  ---------  ---------  ---------
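As a quick sanity check (this snippet is not part of the generated tutorial code), the same
keep-and-rescale rule can be written directly in PyTorch and compared against the kernel's
output, reusing the ``x``, ``x_keep``, ``p`` and ``dropout`` defined above:

.. code-block:: Python

    # Reference computation in plain PyTorch: scale kept elements by 1 / (1 - p)
    # and zero out the rest, mirroring tl.where(x_keep, x / (1 - p), 0.0).
    expected = torch.where(x_keep.bool(), x / (1 - p), torch.zeros_like(x))
    # The Triton kernel performs the same arithmetic, so the results should match.
    assert torch.allclose(dropout(x, x_keep=x_keep, p=p), expected)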
.. GENERATED FROM PYTHON SOURCE LINES 87-106

Seeded dropout
--------------

The above implementation of dropout works fine, but it can be a bit awkward to deal with.
Firstly, we need to store the dropout mask for backpropagation. Secondly, dropout state
management can get very tricky when using recompute/checkpointing (e.g. see all the notes
about `preserve_rng_state` in https://pytorch.org/docs/1.9.0/checkpoint.html). In this
tutorial we'll describe an alternative implementation that (1) has a smaller memory
footprint; (2) requires less data movement; and (3) simplifies the management of persisting
randomness across multiple invocations of the kernel.

Pseudo-random number generation in Triton is simple! In this tutorial we will use the
:code:`triton.language.rand` function, which generates a block of uniformly distributed
:code:`float32` values in [0, 1), given a seed and a block of :code:`int32` offsets. But if
you need it, Triton also provides other :ref:`random number generation strategies`.

.. note::
    Triton's implementation of PRNG is based on the Philox algorithm (described in
    [SALMON2011]_).

Let's put it all together.

.. GENERATED FROM PYTHON SOURCE LINES 106-155

.. code-block:: Python

    @triton.jit
    def _seeded_dropout(
        x_ptr,
        output_ptr,
        n_elements,
        p,
        seed,
        BLOCK_SIZE: tl.constexpr,
    ):
        # compute memory offsets of elements handled by this instance
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        # load data from x
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        # randomly prune it
        random = tl.rand(seed, offsets)
        x_keep = random > p
        # write-back
        output = tl.where(x_keep, x / (1 - p), 0.0)
        tl.store(output_ptr + offsets, output, mask=mask)


    def seeded_dropout(x, p, seed):
        output = torch.empty_like(x)
        assert x.is_contiguous()
        n_elements = x.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
        _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
        return output


    x = torch.randn(size=(10, )).cuda()
    # Compare this to the baseline - dropout mask is never instantiated!
    output = seeded_dropout(x, p=0.5, seed=123)
    output2 = seeded_dropout(x, p=0.5, seed=123)
    output3 = seeded_dropout(x, p=0.5, seed=512)

    print(
        tabulate.tabulate([
            ["input"] + x.tolist(),
            ["output (seed = 123)"] + output.tolist(),
            ["output (seed = 123)"] + output2.tolist(),
            ["output (seed = 512)"] + output3.tolist(),
        ]))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    -------------------  ---------  --------  --------  -------  --------  --------  ---------  ---------  ---------  ---------
    input                -0.952835  0.371721  0.408716  1.42142  0.149397  -0.67086  -0.214186  -0.431969  -0.707878  -0.106434
    output (seed = 123)   0         0.743443  0         0        0         -1.34172   0          0         -1.41576   -0.212868
    output (seed = 123)   0         0.743443  0         0        0         -1.34172   0          0         -1.41576   -0.212868
    output (seed = 512)   0         0         0.817432  2.84284  0         -1.34172  -0.428372   0          0          0
    -------------------  ---------  --------  --------  -------  --------  --------  ---------  ---------  ---------  ---------
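Because the mask is a pure function of the seed and the element offsets, a backward pass can
simply regenerate it instead of reading a saved mask tensor: the gradient of
``x * keep / (1 - p)`` with respect to ``x`` is ``keep / (1 - p)``, i.e. the same seeded
dropout applied to the incoming gradient. Below is a minimal sketch of how this could be
wrapped in an autograd function; the ``SeededDropout`` class is a hypothetical helper, not
part of this tutorial:

.. code-block:: Python

    class SeededDropout(torch.autograd.Function):
        # Hypothetical wrapper: only the dropout probability and a single int32 seed
        # are kept for the backward pass; no mask tensor is ever stored.

        @staticmethod
        def forward(ctx, x, p, seed):
            ctx.p = p
            ctx.seed = seed
            return seeded_dropout(x, p, seed)

        @staticmethod
        def backward(ctx, grad_output):
            # Re-running the same seeded dropout on the gradient re-creates exactly the
            # mask (and the 1 / (1 - p) scaling) that was applied in the forward pass.
            grad_x = seeded_dropout(grad_output.contiguous(), ctx.p, ctx.seed)
            # forward() took three inputs (x, p, seed); only x needs a gradient.
            return grad_x, None, None

Calling ``SeededDropout.apply(x, 0.5, seed)`` then behaves like ``seeded_dropout`` in the
forward pass while remaining differentiable.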
.. GENERATED FROM PYTHON SOURCE LINES 156-159

Et Voilà! We have a Triton kernel that applies the same dropout mask provided the seed is the
same! If you'd like to explore further applications of pseudo-randomness in GPU programming,
we encourage you to look at the `triton/language/random` folder!

.. GENERATED FROM PYTHON SOURCE LINES 161-167

Exercises
---------

1. Extend the kernel to operate over a matrix and use a vector of seeds - one per row.
2. Add support for striding.
3. (challenge) Implement a kernel for a sparse Johnson-Lindenstrauss transform that generates
   the projection matrix on the fly each time, using a seed.

.. GENERATED FROM PYTHON SOURCE LINES 169-174

References
----------

.. [SALMON2011] John K. Salmon, Mark A. Moraes, Ron O. Dror, and David E. Shaw, "Parallel
   Random Numbers: As Easy as 1, 2, 3", 2011
.. [SRIVASTAVA2014] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and
   Ruslan Salakhutdinov, "Dropout: A Simple Way to Prevent Neural Networks from Overfitting",
   JMLR 2014

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.680 seconds)

.. _sphx_glr_download_getting-started_tutorials_04-low-memory-dropout.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 04-low-memory-dropout.ipynb <04-low-memory-dropout.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 04-low-memory-dropout.py <04-low-memory-dropout.py>`

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_