Extending Tensor Contractions for Deep Neural Networks


During a conversation with OpenAI CTO Greg Brockman [1], Geoffrey Hinton recounted how his lab’s seminal work on AlexNet [2] wouldn’t have been possible without Alex Krizhevsky’s strong CUDA programming skills. This observation demonstrates the importance of custom compute kernels for neural network research, which since then have also enabled the use of faster algorithms for convolutions [3] and sparse attention mechanisms in Transformers [4].

Unfortunately, specialized accelerators for Deep Learning (e.g., Tensor Cores, TPUs) have become harder to program, and properly implementing novel primitives for these architectures can now take multiple months even for experts. This is worrisome, and I can’t help but wonder how many potential breakthroughs unwittingly wait out there — conceptualized by some but never tried at scale for lack of an efficient implementation.

Part of the mission of the Triton project is to give researchers better tools for exploring novel ideas. Certainly I'm not the only one with this goal, and I hold in high regard everyone else who takes part in this community effort (e.g., TVM [5], Diesel [6], PlaidML [7], MLIR [8]). But over the past year or so, I have grown disappointed with the state of Domain-Specific Languages (DSLs) for neural networks: I believe that too much research in the area focuses on re-implementing known primitives rather than exploring new ones. What researchers need is not an umpteenth way to re-implement dense matrix multiplication, but rather something flexible enough to drive innovation in the field.

The purpose of this blog post is to share some of my humble thoughts on this difficult problem, and show how tensor contractions may be extended to help in this endeavor.

Tensor Contractions

Tensor contractions are higher-dimensional generalizations of matrix multiplications. A rough example of how they work is shown in Figure 1, where two 3-dimensional tensors (A and B) are contracted over 2 reduction axes k and l to create a third tensor $C = A \otimes B$.

Figure 1: A tensor contraction such that $C_{i,j} = \sum_k\sum_l A_{i,k,l}B_{k,l,j}$

More formally, the tensor contraction between two input tensors $A$ and $B$ can be defined as the operation $C = A \otimes B$ such that:

$$C^{a_1, \cdots, a_M}_{b_1,\cdots,b_N} = \sum_{r_1}\cdots\sum_{r_K} A^{a_1, \cdots, a_M}_{r_1,\cdots,r_K}B^{r_1,\cdots,r_K}_{b_1, \cdots, b_N}$$

Here, $a_m$’s and $b_n$’s denote indices that appear in both the result tensor and the A/B operand respectively (sometimes called external indices in the literature [9]), and $r_k$’s denote indices that are reduced over (sometimes called contraction indices). This notation can be simplified by assuming implicit summation over all contraction indices, leading to the following Einstein summation:

$$C^{a_1, \cdots, a_M}_{b_1,\cdots,b_N} = A^{a_1, \cdots, a_M}_{r_1,\cdots,r_K}B^{r_1,\cdots,r_K}_{b_1, \cdots, b_N}$$

The implementation of this operator in the aforementioned DSLs is straightforward:

# PlaidML
C[i, j: I, J] = +(A[i, k, l] * B[k, j, l])
# Tensor Comprehensions
C(i, j) +=! A(i, k, l) * B(k, j, l)
# TVM-style reduction (schematic)
C = sum(A[i, k, l] * B[k, j, l], axis=[k, l])
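To make the definition concrete, here is a minimal NumPy sketch (not part of any of the DSLs above) that evaluates the contraction of Figure 1 with `numpy.einsum` and checks it against an ordinary matrix product obtained by flattening the two reduction axes $k$ and $l$ into one. The shapes are arbitrary assumptions chosen for illustration.

```python
import numpy as np

def contract(A, B):
    # C[i, j] = sum_k sum_l A[i, k, l] * B[k, l, j]   (Figure 1)
    return np.einsum('ikl,klj->ij', A, B)

def contract_as_gemm(A, B):
    # Flatten the reduction axes (k, l) into a single axis of size K*L;
    # the contraction then becomes a plain matrix multiplication.
    I, K, L = A.shape
    J = B.shape[2]
    return A.reshape(I, K * L) @ B.reshape(K * L, J)

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3, 5))
B = rng.standard_normal((3, 5, 6))
C = contract(A, B)
print(C.shape)                                 # (4, 6)
print(np.allclose(C, contract_as_gemm(A, B)))  # True
```

Flattening works here because $k$ and $l$ are adjacent and in the same order in both operands; when they are not, a transposition is needed first, which is exactly the cost discussed later in this post.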

Affine Index Transformations

Over the past several years, tensor contractions have become an important part of DSLs for DNNs. But despite their natural ability to represent batched matrix multiplication operations useful for RNNs and Transformers, their standard formulation cannot capture some other common primitives, such as convolutions.

Figure 2: Convolution $D = I * F$ such that $D^{h,w}_k = \sum_c\sum_r\sum_s I^{h+r,w+s}_cF^{c,r,s}_k$ cannot be expressed as a standard tensor contraction

This is because it would require tensors to be indexed by a combination of external and contraction indices (e.g., $h + r$). For this reason, existing DSLs have extended tensor contractions to support expressions of the form:

$$C^{a_1, \cdots, a_M}_{b_1,\cdots,b_N} = A^{f^a_1(a_1, r_1, \cdots, r_K), \cdots, f^a_M(a_M, r_1, \cdots, r_K)}_{r_1,\cdots,r_K}B^{r_1,\cdots,r_K}_{f^b_1(b_1, r_1, \cdots, r_K), \cdots, f^b_N(b_N, r_1, \cdots, r_K)}$$

Where $f^a_m$’s and $f^b_n$’s are affine functions of their inputs. The affine nature of these functions is crucial, as it allows these compilers to rely on polyhedral machinery (e.g., the Affine dialect in MLIR) or on specific pre-defined scheduling primitives (e.g., TVM, Halide) for code generation.

This framework works surprisingly well for CNNs, since convolutional layers can be expressed as:

$$D^{n, h, w}_{k} = I^{n, h + r, w + s}_{c}F^{c, r, s}_{k}$$
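As a sanity check of this formula, here is a minimal NumPy sketch that evaluates it directly. It assumes stride 1, no padding, an NCHW input and a filter laid out as $(C, R, S, K)$; these layout choices are ours, not the post's. The affine offsets $(r, s)$ turn into slices of the input, and what remains for each $(r, s)$ is an ordinary contraction over $c$.

```python
import numpy as np

def conv2d_extended(I, F):
    # D[n, k, h, w] = sum_{c, r, s} I[n, c, h + r, w + s] * F[c, r, s, k]
    # "Valid" convolution: stride 1, no padding.
    N, C, H, W = I.shape
    _, R, S, K = F.shape
    P, Q = H - R + 1, W - S + 1          # output height and width
    D = np.zeros((N, K, P, Q), dtype=np.result_type(I, F))
    for r in range(R):
        for s in range(S):
            # Affine index h + r (resp. w + s) becomes a shifted slice.
            D += np.einsum('nchw,ck->nkhw', I[:, :, r:r + P, s:s + Q], F[:, r, s, :])
    return D

def conv2d_reference(I, F):
    # Direct loop over every output element, used only for checking.
    N, C, H, W = I.shape
    _, R, S, K = F.shape
    P, Q = H - R + 1, W - S + 1
    D = np.zeros((N, K, P, Q))
    for n in range(N):
        for k in range(K):
            for h in range(P):
                for w in range(Q):
                    D[n, k, h, w] = np.sum(I[n, :, h:h + R, w:w + S] * F[:, :, :, k])
    return D

rng = np.random.default_rng(0)
I = rng.standard_normal((2, 3, 6, 7))
F = rng.standard_normal((3, 3, 2, 4))
D = conv2d_extended(I, F)
print(D.shape)                                 # (2, 4, 4, 6)
print(np.allclose(D, conv2d_reference(I, F)))  # True
```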

But unfortunately, affine index transformations fail to capture important convolution variants (e.g., shift-convolutions) as well as some structured sparsity patterns.

Beyond Affine Index Transformations

In the remainder of this blog post, I will outline a more general class of index transformations that offers more flexibility while remaining amenable to efficient compilation using Triton. Specifically, I will focus on expressions of the form:

$$C^{a_1, \cdots, a_M}_{b_1,\cdots,b_N} = A^{f^a_1(a_1) + g^a_1(r_1, \cdots, r_K), \cdots, f^a_M(a_M) + g^a_M(r_1, \cdots, r_K)}_{r_1,\cdots,r_K}B^{r_1,\cdots,r_K}_{f^b_1(b_1) + g^b_1(r_1, \cdots, r_K), \cdots, f^b_N(b_N) + g^b_N(r_1, \cdots, r_K)}$$

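Where $f^a_m$’s, $g^a_m$’s, $f^b_n$’s and $g^b_n$’s are arbitrary arithmetic functions of their input(s). Crucially, each index is now the sum of a term that depends only on one external index and a term that depends only on the contraction indices.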

But before explaining why and how such expressions can be efficiently compiled, I would like to say a few words about their ability to represent several important neural network operations that most existing DSLs cannot.

Shift Convolutions

Depthwise-separable convolutions aim to factorize standard convolution operations into a depthwise component, which applies a separate spatial filter to each input channel, and a pointwise component, which combines each resulting pixel across all channels (effectively performing a $1 \times 1$ convolution).

Figure 3: Depthwise-Separable Convolutions and Shift Convolutions

Shift-convolutions [10] are a special case of depthwise-separable convolutions whereby each depthwise filter has a single element of value $1$ and $R \cdot S - 1$ elements of value $0$. In other words, every channel of the data tensor is merely shifted prior to the pointwise convolution. This paradigm can be concisely expressed as the extended tensor contraction:

$$D^{n, h, w}_{k} = I^{n, h + \text{shift}_h[c], w + \text{shift}_w[c]}_{c}F^{c}_{k}$$

Note that affine index computations would not suffice here since, e.g., $h + \text{shift}_h[c]$ is not affine with respect to $c$: it involves an arbitrary per-channel table lookup.
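A minimal NumPy sketch of this contraction, assuming an NCHW input `I`, a pointwise filter `F` of shape $(C, K)$, integer shift tables `shift_h`/`shift_w` of length $C$, and zero padding for reads that fall outside the image (the padding policy is our assumption; the post does not specify one). The table lookup is applied once per channel, after which only a pointwise contraction over $c$ remains.

```python
import numpy as np

def shift_conv2d(I, F, shift_h, shift_w):
    # D[n, k, h, w] = sum_c I[n, c, h + shift_h[c], w + shift_w[c]] * F[c, k]
    # Reads outside the image are treated as zeros (assumed zero padding).
    N, C, H, W = I.shape
    shifted = np.zeros_like(I)
    for c in range(C):
        dh, dw = int(shift_h[c]), int(shift_w[c])
        # Output rows h with 0 <= h + dh < H (and likewise for columns).
        h0, h1 = max(0, -dh), min(H, H - dh)
        w0, w1 = max(0, -dw), min(W, W - dw)
        if h0 < h1 and w0 < w1:
            shifted[:, c, h0:h1, w0:w1] = I[:, c, h0 + dh:h1 + dh, w0 + dw:w1 + dw]
    # The per-channel shift is followed by a pointwise (1x1) convolution.
    return np.einsum('nchw,ck->nkhw', shifted, F)

rng = np.random.default_rng(0)
I = rng.standard_normal((2, 3, 5, 5))
F = rng.standard_normal((3, 4))
shift_h, shift_w = np.array([1, 0, -1]), np.array([0, 1, -1])
D = shift_conv2d(I, F, shift_h, shift_w)
print(D.shape)  # (2, 4, 5, 5)
```

With all shifts set to zero the function reduces to a plain pointwise convolution, which is a convenient way to check it.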

Sparse Convolutions

Let us now look at another case where this form of extended tensor contraction could be useful. Consider a sparse 5×5 convolution with $X$ non-zero elements in the filter’s kernel:

Figure 4: Sparse 5×5 convolution with 7 non-zero elements

The resulting sparse convolution can be represented by the extended tensor contraction:

$$D^{n, h, w}_{k} = I^{n, h + \text{off}_h[c, x], w + \text{off}_w[c, x]}_{c}F^{c, x}_{k}$$

Note that, for a dense $R \times S$ filter whose non-zero index $x$ enumerates all pairs $(r, s)$, setting

$$\text{off}_h[c, x] = \text{dil}_h \cdot r \quad \quad \text{off}_w[c, x] = \text{dil}_w \cdot s$$

would represent a dilated convolution. Similarly, keeping a single element per channel and setting

$$\text{off}_h[c, x] = \text{shift}_h[c] \quad \quad \text{off}_w[c, x] = \text{shift}_w[c]$$

would represent a shift-convolution.

Additionally, using different sparsity patterns for different input channels corresponds to the case where $\text{off}_h$ and/or $\text{off}_w$ depend on $c$. Unfortunately, all output feature maps are constrained to share the same sparsity structure, since $\text{off}_h$ and $\text{off}_w$ cannot depend on the external index $k$ under the proposed extensions, for reasons that will become clearer as I show how the above class of operations can be efficiently compiled for GPUs.
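The sparse contraction can be sketched in NumPy as follows. This is a reference sketch under our own assumptions: the filter is stored as `F` with shape $(C, X, K)$, `off_h`/`off_w` are integer tables of shape $(C, X)$, and every read is assumed to stay in bounds for the requested $P \times Q$ output (no padding). The helper `dilated_offsets` is hypothetical and builds the tables for the dilated special case above, with $x = r \cdot S + s$.

```python
import numpy as np

def sparse_conv2d(I, F, off_h, off_w, P, Q):
    # D[n, k, h, w] = sum_{c, x} I[n, c, h + off_h[c, x], w + off_w[c, x]] * F[c, x, k]
    # I: (N, C, H, W); F: (C, X, K); off_h, off_w: integer tables of shape (C, X).
    N = I.shape[0]
    C, X, K = F.shape
    D = np.zeros((N, K, P, Q), dtype=np.result_type(I, F))
    for c in range(C):
        for x in range(X):
            dh, dw = int(off_h[c, x]), int(off_w[c, x])
            window = I[:, c, dh:dh + P, dw:dw + Q]        # shape (N, P, Q)
            D += np.einsum('nhw,k->nkhw', window, F[c, x])
    return D

def dilated_offsets(C, R, S, dil_h, dil_w):
    # Dense R x S filter: x = r * S + s enumerates all (r, s) pairs, and the
    # same offsets are shared by every input channel c.
    r = np.repeat(np.arange(R), S)    # r for each x
    s = np.tile(np.arange(S), R)      # s for each x
    off_h = np.broadcast_to(dil_h * r, (C, R * S))
    off_w = np.broadcast_to(dil_w * s, (C, R * S))
    return off_h, off_w

rng = np.random.default_rng(0)
I = rng.standard_normal((1, 2, 9, 9))
F_dense = rng.standard_normal((2, 3, 3, 2))           # (C, R, S, K)
off_h, off_w = dilated_offsets(2, 3, 3, 2, 2)
D = sparse_conv2d(I, F_dense.reshape(2, 9, 2), off_h, off_w, P=5, Q=5)
print(D.shape)  # (1, 2, 5, 5)
```

Because the offset tables are indexed by $(c, x)$ only, there is no way to give each output channel $k$ its own pattern, which is the limitation described above.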

Efficient Compilation

It is well known that standard tensor contractions can be scheduled and implemented as batched matrix multiplications. In fact, this is essentially how TensorFlow and PyTorch handle this operator: both input tensors are transposed, a library call to cuBLAS computes the product, and the result is itself transposed before being returned.
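To make the data movement of this scheme concrete, here is a minimal NumPy sketch for a small contraction of our own choosing, $C_{i,j,l} = \sum_k A_{i,k,l} B_{k,j}$. The explicit `np.ascontiguousarray` copies stand in for the transposition kernels and the `@` stands in for the cuBLAS call; this is an illustration, not TensorFlow's or PyTorch's actual code path.

```python
import numpy as np

def ttgt(A, B):
    # C[i, j, l] = sum_k A[i, k, l] * B[k, j] via transpose / GEMM / transpose.
    I, K, L = A.shape
    J = B.shape[1]
    # Transpose A: move the contraction axis k last and materialize the copy.
    A_t = np.ascontiguousarray(A.transpose(0, 2, 1)).reshape(I * L, K)
    # Transpose B: here B already has k first, so only a (no-op) copy remains;
    # in general B would also be permuted to (k, external axes...).
    B_t = np.ascontiguousarray(B).reshape(K, J)
    # GEMM: a single matrix multiplication.
    C_t = A_t @ B_t                                  # shape (I*L, J)
    # Transpose the result back to the requested layout (i, j, l).
    return np.ascontiguousarray(C_t.reshape(I, L, J).transpose(0, 2, 1))

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4, 5))
B = rng.standard_normal((4, 6))
C = ttgt(A, B)
print(C.shape)                                          # (3, 6, 5)
print(np.allclose(C, np.einsum('ikl,kj->ijl', A, B)))  # True
```

Each `ascontiguousarray` on a permuted view copies the whole tensor through memory, which is the overhead that fused implementations avoid.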

This paradigm, usually known by the acronym TTGT (Transpose-Transpose-GEMM-Transpose), is easy to implement but suffers from significant data-movement overhead induced by unnecessary transpositions. For this reason, various fused implementations [9] have emerged over the past few years, whose behavior can be summarized in the following pseudo-code snippet:

# Aggregate batch axis
for z in range(0, Z):
  z0, z1, ..., zz = unpack(z, [Z0, Z1, ..., ZZ])
  # Aggregate row axis
  for m in range(0, M):
    m0, m1, ..., mm = unpack(m, [M0, M1, ..., MM])
    # Aggregate column axis
    for n in range(0, N):
      n0, n1, ..., nn = unpack(n, [N0, N1, ..., NN])
      acc  = 0
      # Aggregate reduction axis
      for k in range(0, K):
        k0, k1, ..., kk = unpack(k, [K0, K1, ..., KK])
        acc += A[z0,...,zz,m0,...,mm,k0,...,kk]
             * B[z0,...,zz,k0,...,kk,n0,...,nn]
      C[z0,...,zz,m0,...,mm,n0,...,nn] = acc

Where the unpack function would look like this:

def unpack(idx, sizes):
  # Split a flat index into one index per axis; sizes[0] varies fastest
  result = []
  for size in sizes:
    result.append(idx % size)
    idx = idx // size
  return result
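A runnable instance of the fused loop nest above, for the concrete case $C_{m_0,m_1,n} = \sum_{k_0}\sum_{k_1} A_{m_0,m_1,k_0,k_1} B_{k_0,k_1,n}$ (our choice of example; no batch axis, to keep it short). Row axes and reduction axes are each aggregated into a single loop and unpacked on the fly, so neither operand is ever transposed. The innermost axes are unpacked first because they have the smallest memory stride.

```python
import numpy as np

def unpack(idx, sizes):
    # Split a flat index into one index per axis; sizes[0] varies fastest.
    result = []
    for size in sizes:
        result.append(idx % size)
        idx //= size
    return result

def fused_contraction(A, B):
    # C[m0, m1, n] = sum_{k0, k1} A[m0, m1, k0, k1] * B[k0, k1, n]
    M0, M1, K0, K1 = A.shape
    N = B.shape[2]
    C = np.zeros((M0, M1, N), dtype=np.result_type(A, B))
    for m in range(M0 * M1):                    # aggregated row axis
        m1, m0 = unpack(m, [M1, M0])
        for n in range(N):                      # column axis
            acc = 0.0
            for k in range(K0 * K1):            # aggregated reduction axis
                # k1 has the smallest stride in A and B, so it varies fastest.
                k1, k0 = unpack(k, [K1, K0])
                acc += A[m0, m1, k0, k1] * B[k0, k1, n]
            C[m0, m1, n] = acc
    return C

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 4, 5))
B = rng.standard_normal((4, 5, 6))
print(np.allclose(fused_contraction(A, B), np.einsum('abkl,kln->abn', A, B)))  # True
```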

Note that, in order to maximize data locality, it is also important to order sizes such that aggregated axes are unpacked in order of increasing memory stride. In other words, we want the fastest-varying reduction variables to match the fastest-varying dimensions in memory. For example, convolutions on data in NCHW format should unpack the reduction axes as s, r, c = unpack(k, [S, R, C]), since W, H and C are the 1st, 2nd and 3rd fastest-varying axes in memory.
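The effect of the unpacking order can be checked with a few lines of NumPy. This sketch (sizes are arbitrary assumptions) walks the aggregated reduction axis of a convolution on a contiguous NCHW tensor and records how far the input address moves between consecutive iterations, for the output pixel $(h, w) = (0, 0)$. Unpacking $s$ first yields mostly unit-stride steps, whereas unpacking $c$ first yields none.

```python
import numpy as np

N, C, H, W = 1, 4, 8, 8
R, S = 3, 3
img = np.zeros((N, C, H, W), dtype=np.float32)
elem_strides = [st // img.itemsize for st in img.strides]   # [C*H*W, H*W, W, 1]

def unpack(idx, sizes):
    # sizes[0] varies fastest
    out = []
    for size in sizes:
        out.append(idx % size)
        idx //= size
    return out

def address_steps(order):
    # order lists the reduction axes from fastest- to slowest-varying,
    # e.g. ['s', 'r', 'c']. Returns the element offset change between
    # consecutive iterations k -> k + 1.
    sizes = {'c': C, 'r': R, 's': S}
    addrs = []
    for k in range(C * R * S):
        idx = dict(zip(order, unpack(k, [sizes[a] for a in order])))
        # Element offset of img[0, c, 0 + r, 0 + s]
        addrs.append(idx['c'] * elem_strides[1] + idx['r'] * elem_strides[2]
                     + idx['s'] * elem_strides[3])
    return np.diff(addrs)

s_first = address_steps(['s', 'r', 'c'])
c_first = address_steps(['c', 'r', 's'])
print(sorted(set(s_first.tolist())))                          # [1, 6, 46]
print(int((s_first == 1).sum()), int((c_first == 1).sum()))   # 24 0
```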

Problem: Index Computation Overhead

Now, it may seem easy at first sight to modify the above pseudo-code to support extended tensor contractions:

# Aggregate reduction axis
for k in range(0, K):
  k0, k1, ..., kk = unpack(k, [K0, K1, ..., KK])
  # f-terms depend only on external indices, g-terms only on reduction indices
  zz0 = fz0(z0) + gz0(k0, ..., kk)
  zzz = fzz(zz) + gzz(k0, ..., kk)
  mm0 = fm0(m0) + gm0(k0, ..., kk)
  mmm = fmm(mm) + gmm(k0, ..., kk)
  nn0 = fn0(n0) + gn0(k0, ..., kk)
  nnn = fnn(nn) + gnn(k0, ..., kk)
  acc += A[zz0,...,zzz,mm0,...,mmm,k0,...,kk]
       * B[zz0,...,zzz,k0,...,kk,nn0,...,nnn]

But this is not satisfactory, as it would incur significant index-computation overhead in the innermost loop, especially in highly compute-bound regimes (e.g., Tensor Cores) where any superfluous instruction can have major consequences. Even worse, this overhead grows linearly with the number of reduction axes, as shown below:

Figure 5: Performance of Extended Tensor Contraction decreases dramatically with increasing indexing overhead.

Solution: Pointer Increments Precomputation

Fortunately, it is possible to drastically reduce this indexing overhead thanks to the following observation: pointer increments for A and B can be pre-computed and stored in a look-up table of $K$ elements (one per value of the aggregated reduction index), which is usually small enough to fit in a GPU’s L1 cache.

Figure 6: Pointer Increments Precomputation

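This is because, as shown in Figure 6, the difference between the memory addresses of A at two consecutive iterations of the innermost loop (i.e., the pointer increment) depends only on the reduction variables: each index is the sum of a term $f$ that depends only on external indices and a term $g$ that depends only on reduction indices, so the external terms cancel out in the difference. This means that the increments can be pre-computed, resulting in the following algorithm for the inner loop: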

acc  = 0
# Initialize pointers
ptr_a = &A[zz0,...,zzz,mm0,...,mmm,0,...,0 ]
ptr_b = &B[zz0,...,zzz,0,...,0 ,nn0,...,nnn]
for k in range(0, K):
  acc += (*ptr_a) * (*ptr_b)
  ptr_a += lookup_a[k] #pre-computed ptr_a increment
  ptr_b += lookup_b[k] #pre-computed ptr_b increment
C[z0,...,zz,m0,...,mm,n0,...,nn] = acc

Now you can see that the instruction cost of index computation has become constant: it amounts to two L1-cache loads regardless of the number of reduction axes involved or the nature of index transformations.
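The sketch below emulates this algorithm in NumPy for the sparse convolution of the previous section; it is not Triton code. Integer offsets into flattened arrays stand in for pointers, the function names `build_lookup_a` and `sparse_conv2d_lut` are ours, the output channel is called `o` to avoid a clash with the reduction index `k`, and all offsets are assumed to keep reads in bounds (no padding). The table depends only on `off_h`, `off_w`, `H` and `W`, never on `n`, `o`, `h` or `w`, so it is built once and reused by every output element.

```python
import numpy as np

def build_lookup_a(off_h, off_w, H, W):
    # Part of the address of I[n, c, h + off_h, w + off_w] (contiguous NCHW)
    # that depends only on the reduction index k = (c, x):
    #   g[k] = c*H*W + off_h[c, x]*W + off_w[c, x]
    C, X = off_h.shape
    c = np.repeat(np.arange(C), X)
    g = c * H * W + np.asarray(off_h).reshape(-1) * W + np.asarray(off_w).reshape(-1)
    # lookup[k] = g[k + 1] - g[k]; the final entry (0) is never needed.
    return int(g[0]), np.diff(g, append=g[-1])

def sparse_conv2d_lut(I, F, off_h, off_w, P, Q):
    # D[n, o, h, w] = sum_{c, x} I[n, c, h + off_h[c, x], w + off_w[c, x]] * F[c, x, o]
    N, C, H, W = I.shape
    _, X, O = F.shape
    I_flat, F_flat = I.reshape(-1), F.reshape(-1)      # flat "memory"
    g0, lookup_a = build_lookup_a(off_h, off_w, H, W)
    lookup_b = np.full(C * X, O)                        # F[c, x, o] is at (c*X + x)*O + o
    D = np.zeros((N, O, P, Q), dtype=np.result_type(I, F))
    for n in range(N):
        for o in range(O):
            for h in range(P):
                for w in range(Q):
                    # Initial pointers: external part f plus g[0].
                    ptr_a = n * C * H * W + h * W + w + g0
                    ptr_b = o
                    acc = 0.0
                    for k in range(C * X):
                        acc += I_flat[ptr_a] * F_flat[ptr_b]
                        ptr_a += lookup_a[k]            # pre-computed increment
                        ptr_b += lookup_b[k]            # pre-computed increment
                    D[n, o, h, w] = acc
    return D

rng = np.random.default_rng(0)
N, C, H, W, X, O = 1, 2, 6, 6, 3, 2
I = rng.standard_normal((N, C, H, W))
F = rng.standard_normal((C, X, O))
off_h = rng.integers(0, 3, size=(C, X))
off_w = rng.integers(0, 3, size=(C, X))
D = sparse_conv2d_lut(I, F, off_h, off_w, P=4, Q=4)
```

Whatever the number of reduction axes or the form of the $g$ functions, the inner loop only performs two table loads and two additions for its index arithmetic.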

Figure 7: Performance of optimized tensor contractions does not depend on the number of reduction axes.

Note that there may be extra overhead when index transformations intrinsically exhibit poor data locality; I leave this for future work. For common Deep Learning primitives, however, this does not seem to matter much: the proposed compilation technique achieves 90-98% of GEMM performance on a variety of common tasks. We can also see that, using pointer-increment precomputation, standard and extended tensor contractions perform similarly.

Figure 8: Performance of the extended tensor contractions in Triton

Details of input shapes and a script to reproduce these results are available here.

Relationship to the Indirect Convolution Algorithm

The indirect convolution algorithm [11] has recently emerged as an efficient way to implement dense convolutions on hardware optimized for matrix multiplication. It relies on the construction of a so-called indirection buffer, whose purpose is to simplify pointer arithmetic in the compute kernel’s inner loop. This is fairly similar to what I’ve presented above, except for the following differences:

  • Extended tensor contractions are more general in scope: they define a whole class of operations (of which dense convolutions are only one member), all of which can be implemented with this kind of indirection-based approach.
  • The indirection buffer proposed in [11] contains $\mathcal{O}(NHWCRS)$ elements, meaning that it is typically too large to fit entirely in hardware L1 caches. By contrast, the memory footprint of the pre-computed pointer increments above is only $\mathcal{O}(CRS)$, leading to fewer cache misses when computing indirections.

Conclusions and Code Availability

In this blog post, we have seen how tensor contractions may be extended to support a set of linear transformations currently not available in existing DSLs. We have shown how this class of operations can be efficiently compiled using a technique dubbed pointer-increment precomputation, achieving 90-95% of GEMM performance (assuming Tensor Cores) for common DNN operations.

All of the above is open-source and has been implemented in triton.ops.einsum:

# standard conv2d
act = einsum('nc(h+r)(w+s), ckrs -> nkhw', data, weight)
# depthwise conv2d
act = einsum('nc(h+r)(w+s), crs -> nchw', data, weight)
# pointwise conv2d
act = einsum('nchw, ck -> nkhw', data, weight)
# shift conv2d
act = einsum('nc(h+sh[c])(w+sw[c]), ck -> nkhw', data, weight, sh=shift_h, sw=shift_w)
# sparse filters
act = einsum('nc(h+off_h[x])(w+off_w[x]), kcx -> nkhw', data, weight, off_h=off_h, off_w=off_w)
# batched matrix multiplication for self-attention
act = einsum('nths, hes -> nhte', a, b)
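The calls above rely on Triton's own `einsum` parser, whose full signature we do not reproduce here. The strings that contain no index arithmetic are ordinary Einstein summations, though, so their meaning can be cross-checked on the CPU with `numpy.einsum` (a reference check only, not Triton's API; the shapes below are arbitrary assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)

# pointwise conv2d: same subscripts as the Triton call above
data = rng.standard_normal((2, 8, 5, 5))      # n, c, h, w
weight = rng.standard_normal((8, 16))         # c, k
act = np.einsum('nchw,ck->nkhw', data, weight)
print(act.shape)  # (2, 16, 5, 5)

# batched matrix multiplication for self-attention
a = rng.standard_normal((2, 10, 4, 32))       # n, t, h, s
b = rng.standard_normal((4, 6, 32))           # h, e, s
out = np.einsum('nths,hes->nhte', a, b)
print(out.shape)  # (2, 4, 10, 6)
```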


[1] Greg Brockman – #define CTO OpenAI
[2] Alex Krizhevsky, Ilya Sutskever, Geoff Hinton – ImageNet Classification with Deep Convolutional Neural Networks
[3] Andrew Lavin, Scott Gray – Fast Algorithms for Convolutional Neural Networks
[4] Rewon Child, Scott Gray, Alec Radford, Ilya Sutskever – Generating Long Sequences with Sparse Transformers
[5] TVM – https://tvm.apache.org/
[6] Venmugil Elango, Norman Rubin, Mahesh Ravishankar, Hariharan Sandanagobalane, Vinod Grover – Diesel: DSL for linear algebra and neural net computations on GPUs
[7] PlaidML – https://www.intel.ai/plaidml/
[8] MLIR – https://www.tensorflow.org/mlir
[9] Jinsung Kim et al. – A Code Generator for High-Performance Tensor Contractions on GPUs
[10] Bichen Wu et al. – Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions
[11] Marat Dukhan – The Indirect Convolution Algorithm