Banded Sparse Matrices for PyTorch

Working in deep learning for NLP, a common claim I hear from students is that "dense operations are fast, sparse operations are slow." With the exception of convolutions, most of the operations we do in NLP roughly take the form of dense matrix-matrix multiplies (e.g. attention) or matrix-vector multiplies (e.g. RNNs, output embeddings).

Part of this, though, is just that these are the operations people have taken the time to optimize in CUDA and ship built into libraries. There are plenty of sparse matrix strategies that also run really fast on GPU.

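In this notebook I show the implementation of one class of sparse matrix, the band matrix (https://en.wikipedia.org/wiki/Band_matrix), whose nonzero entries are confined to a fixed number of diagonals around the main diagonal.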
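A band matrix with `lu` nonzero diagonals above the main diagonal and `ld` below it only needs about n·(lu + ld + 1) stored numbers instead of n². The minimal pure-Python sketch below measures the band of a dense matrix and compares the two storage costs. The helper names `bandwidths` and `storage` are hypothetical, not part of genbmm or any other library.

```python
def bandwidths(dense):
    """Return (lu, ld) for a square matrix given as a list of lists.

    lu is the farthest nonzero diagonal above the main one, ld the farthest
    below it; everything outside that band is zero.
    """
    lu = ld = 0
    for i, row in enumerate(dense):
        for j, x in enumerate(row):
            if x != 0:
                lu = max(lu, j - i)  # distance above the diagonal
                ld = max(ld, i - j)  # distance below the diagonal
    return lu, ld


def storage(n, lu, ld):
    """Numbers kept by dense storage vs. diagonal-band storage (n x n matrix)."""
    return n * n, n * (lu + ld + 1)


# A 6 x 6 matrix with one nonzero diagonal above and two below the main diagonal.
n = 6
A = [[1.0 if -2 <= j - i <= 1 else 0.0 for j in range(n)] for i in range(n)]
print(bandwidths(A))               # (1, 2)
print(storage(n, *bandwidths(A)))  # (36, 24)
```

With the notebook's sizes (n = 20 and a band of width 4) the counts would be 400 vs. 80, and the gap widens as n grows for a fixed band.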

In :
import torch
import matplotlib.pyplot as plt
!pip install -qU git+https://github.com/harvardnlp/genbmm
import genbmm
In :
def show(bm):
    # Figure 1: the full n x n dense view of the banded matrix
    # ([0] selects the first, and here only, batch element).
    plt.figure(figsize=(5, 5))
    plt.imshow(bm.to_dense()[0].cpu().detach(), vmin=0, vmax=1)
    plt.axis('off')

    # Figure 2: the compact n x (lu + ld + 1) band storage actually kept in memory.
    plt.figure(figsize=(10, 5))
    plt.imshow(bm.data[0].cpu().detach(), vmin=0, vmax=1)
    plt.axis('off')
    plt.show()
In :
base = torch.arange(1, 21).view(1, 20, 1).expand(1, 20, 4).cuda() / 21.0
bm = genbmm.BandedMatrix(base, 2, 1)
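The cell above builds a batch of one 20 x 20 matrix from compact storage of shape (1, 20, 4), where every slot of row i holds (i + 1) / 21. Assuming the two integer arguments are the upper and lower bandwidths (lu = 2, ld = 1), the width 4 is lu + ld + 1. The plain-Python sketch below mirrors that expansion under an assumed row-wise layout in which slot k of row i stores dense entry (i, i + k - ld) and slots that fall off the matrix are ignored. genbmm's actual internal layout may differ; `band_to_dense` is a hypothetical helper for illustration only.

```python
def band_to_dense(band, lu, ld):
    """Expand row-wise band storage (n rows of lu + ld + 1 slots) to n x n.

    Assumed layout: band[i][k] holds dense entry (i, i + k - ld).
    Slots whose column would fall outside 0..n-1 are skipped.
    """
    n = len(band)
    dense = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for k in range(lu + ld + 1):
            j = i + k - ld
            if 0 <= j < n:
                dense[i][j] = band[i][k]
    return dense


# Plain-Python mirror of the notebook's `base` (batch dimension dropped):
# 20 rows, lu + ld + 1 = 4 slots per row, row i filled with (i + 1) / 21.
n, lu, ld = 20, 2, 1
base = [[(i + 1) / 21.0] * (lu + ld + 1) for i in range(n)]
dense = band_to_dense(base, lu, ld)

# Row 5 is nonzero only in columns 4..7 (one below to two above the diagonal).
print([j for j in range(n) if dense[5][j] != 0])  # [4, 5, 6, 7]
```

The first and last rows keep fewer entries because part of their band falls off the edge of the matrix.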
In :
show(bm)
In :
show(bm.transpose())
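Transposing a band matrix swaps its upper and lower bandwidths, so it can be done entirely in compact storage without materializing the n x n matrix. Below is a minimal plain-Python sketch under an assumed row-wise layout (band[i][k] holds dense entry (i, i + k - ld)). The function names are hypothetical, and this is not genbmm's kernel.

```python
def band_to_dense(band, lu, ld):
    """Expand row-wise band storage (band[i][k] = dense[i][i + k - ld]) to n x n."""
    n = len(band)
    dense = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for k in range(lu + ld + 1):
            j = i + k - ld
            if 0 <= j < n:
                dense[i][j] = band[i][k]
    return dense


def band_transpose(band, lu, ld):
    """Transpose in band storage; returns (new_band, ld, lu).

    Entry (q, p) of the input lands at (p, q). Working through the layout,
    output slot kt of row p reads input row q = p + kt - lu, slot lu + ld - kt.
    """
    n = len(band)
    width = lu + ld + 1
    out = [[0.0] * width for _ in range(n)]
    for p in range(n):
        for kt in range(width):
            q = p + kt - lu  # source row in the original matrix
            if 0 <= q < n:
                out[p][kt] = band[q][width - 1 - kt]
    return out, ld, lu


# A 6 x 6 band matrix with lu = 2, ld = 1 and a distinct value in every slot.
n, lu, ld = 6, 2, 1
band = [[10 * i + k + 1.0 for k in range(lu + ld + 1)] for i in range(n)]
t_band, t_lu, t_ld = band_transpose(band, lu, ld)

dense = band_to_dense(band, lu, ld)
print((t_lu, t_ld))  # (1, 2)
print(band_to_dense(t_band, t_lu, t_ld) == [list(col) for col in zip(*dense)])  # True
```

The work is O(n · (lu + ld + 1)), the same as reading the band once.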
In :
show(bm.op(bm, lambda a,b : (a + b) / 2))
In :
show(bm.op(bm.transpose(), lambda a,b : (a + b) / 2))
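Here `bm` and `bm.transpose()` have different band shapes. Assuming the constructor's integer arguments are (lu, ld), they are (2, 1) and (1, 2), so an elementwise op must first line both operands up on a common band. The plain-Python sketch below shows one way to do that. It pads each operand with zero diagonals out to the union band and then applies the function slot by slot. It uses hypothetical helper names and an assumed row-wise layout (band[i][k] holds dense entry (i, i + k - ld)). The result matches the dense elementwise op only when fn(0, 0) == 0, as with the averaging lambda above; how genbmm treats entries outside the band is not covered here.

```python
def band_to_dense(band, lu, ld):
    """Expand row-wise band storage (band[i][k] = dense[i][i + k - ld]) to n x n."""
    n = len(band)
    dense = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for k in range(lu + ld + 1):
            j = i + k - ld
            if 0 <= j < n:
                dense[i][j] = band[i][k]
    return dense


def widen(band, lu, ld, new_lu, new_ld):
    """Re-express the same matrix with a wider band by adding zero diagonals.

    Slot k means diagonal offset k - ld, so growing ld shifts existing slots
    right by (new_ld - ld); growing lu appends slots on the right.
    """
    return [[0.0] * (new_ld - ld) + list(row) + [0.0] * (new_lu - lu) for row in band]


def band_op(a, a_lu, a_ld, b, b_lu, b_ld, fn):
    """Apply fn elementwise to two band matrices over their union band."""
    lu, ld = max(a_lu, b_lu), max(a_ld, b_ld)
    wa = widen(a, a_lu, a_ld, lu, ld)
    wb = widen(b, b_lu, b_ld, lu, ld)
    return [[fn(x, y) for x, y in zip(ra, rb)] for ra, rb in zip(wa, wb)], lu, ld


n = 6
a = [[1.0] * 4 for _ in range(n)]  # lu = 2, ld = 1: ones inside the band
b = [[2.0] * 4 for _ in range(n)]  # lu = 1, ld = 2: twos inside the band
c, c_lu, c_ld = band_op(a, 2, 1, b, 1, 2, lambda x, y: (x + y) / 2)
print((c_lu, c_ld))                    # (2, 2)
print(band_to_dense(c, c_lu, c_ld)[2])  # [1.0, 1.5, 1.5, 1.5, 0.5, 0.0]
```

The union of a (2, 1) band and a (1, 2) band is (2, 2), which is why averaging a band matrix with its own transpose yields a symmetric matrix with a wider band.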
In :
show(bm.multiply(bm))
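Assume `multiply` is an ordinary matrix product (the notebook does not spell this out). Then the product of a band matrix with bandwidths (lu_a, ld_a) and one with (lu_b, ld_b) is again banded, with bandwidths (lu_a + lu_b, ld_a + ld_b). Entry (i, j) sums A[i, k] · B[k, j], and each term can only be nonzero when k - i lies in A's band and j - k lies in B's band. The plain-Python sketch below only touches stored slots, costing O(n · width_a · width_b) instead of O(n³). It assumes a row-wise layout (band[i][k] holds dense entry (i, i + k - ld)) and uses hypothetical function names; it is not genbmm's CUDA kernel.

```python
def band_to_dense(band, lu, ld):
    """Expand row-wise band storage (band[i][k] = dense[i][i + k - ld]) to n x n."""
    n = len(band)
    dense = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for k in range(lu + ld + 1):
            j = i + k - ld
            if 0 <= j < n:
                dense[i][j] = band[i][k]
    return dense


def band_matmul(a, a_lu, a_ld, b, b_lu, b_ld):
    """Multiply two n x n band matrices stored row-wise; returns (band, lu, ld)."""
    n = len(a)
    lu, ld = a_lu + b_lu, a_ld + b_ld
    out = [[0.0] * (lu + ld + 1) for _ in range(n)]
    for i in range(n):
        for ka in range(a_lu + a_ld + 1):
            k = i + ka - a_ld  # column of A, i.e. row of B
            if not 0 <= k < n:
                continue
            for kb in range(b_lu + b_ld + 1):
                j = k + kb - b_ld  # column of B and of the result
                if 0 <= j < n:
                    out[i][j - i + ld] += a[i][ka] * b[k][kb]
    return out, lu, ld


n = 7
a = [[float(i + k + 1) for k in range(4)] for i in range(n)]  # lu = 2, ld = 1
b = [[float(2 * i - k) for k in range(4)] for i in range(n)]  # lu = 1, ld = 2
c, c_lu, c_ld = band_matmul(a, 2, 1, b, 1, 2)
print((c_lu, c_ld))  # (3, 3)
```

Under the same assumption, the chain `bm.multiply(bm.transpose())...` in a later cell grows the band from (2, 1) to (3, 3), (4, 5) and then (5, 7). That is 13 of the 20 columns per row, so repeated products steadily eat into the sparse advantage.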
In :
show(bm.band_shift(1))
In :
show(bm.band_shift(-1))
In :
show(bm.col_shift(1))
In :
show(bm.col_shift(-1))
In :
show(bm.multiply(bm.transpose()).multiply(bm.transpose()).multiply(bm.transpose()))
In :
show(bm.multiply(bm.band_shift(1).band_shift(1)))
There are a lot of fun applications of this style of banded matrix. In my library PyTorch-Struct we use them for very fast computation of sequence alignments.
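Bands fit alignment because an alignment dynamic program over two sequences fills an n x m grid. If good alignments are known to stay near the diagonal, only a band of that grid needs computing. The sketch below shows the trick on plain Levenshtein edit distance. This is a deliberately simpler, swapped-in example, not PyTorch-Struct's alignment code. Cells with |i - j| > w are treated as unreachable, which cuts the work to O(n · (2w + 1)). The answer is exact whenever the true distance is at most w, since an optimal path can then never drift more than w cells off the diagonal; otherwise it is only an upper bound. `banded_edit_distance` is a hypothetical name.

```python
import math


def banded_edit_distance(s, t, w):
    """Levenshtein distance computed only on grid cells with |i - j| <= w.

    A full table is allocated for readability; a memory-lean version would
    keep only the 2w + 1 diagonals. Cells outside the band stay at infinity.
    """
    n, m = len(s), len(t)
    INF = math.inf
    # dp[i][j] = cost of aligning s[:i] with t[:j]
    dp = [[INF] * (m + 1) for _ in range(n + 1)]
    dp[0][0] = 0
    for i in range(n + 1):
        for j in range(max(0, i - w), min(m, i + w) + 1):
            if i == 0 and j == 0:
                continue
            best = INF
            if i > 0 and j > 0:
                best = dp[i - 1][j - 1] + (s[i - 1] != t[j - 1])  # match / substitute
            if i > 0:
                best = min(best, dp[i - 1][j] + 1)  # delete s[i - 1]
            if j > 0:
                best = min(best, dp[i][j - 1] + 1)  # insert t[j - 1]
            dp[i][j] = best
    return dp[n][m]


print(banded_edit_distance("kitten", "sitting", 3))  # 3
```

If the two lengths differ by more than w, the final cell lies outside the band and the function returns infinity.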

In :
!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
In :
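import torch_struct
# Random scores of shape (batch, N, M, 3) for a 200 x 200 alignment grid.
vals = torch.rand(1, 200, 200, 3).cuda()
# Give the last 20 rows and the last 20 columns huge scores.
vals[0, 180:, :] = 1e6
vals[0, :, 180:] = 1e6
q = torch_struct.Alignment(torch_struct.LogSemiring).marginals(vals)
# Sum the marginals over the last dimension for the first batch element and plot.
plt.imshow(q[0].sum(-1).cpu().detach(), vmin=0, vmax=1.0)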