Visualizing Banded Sparse Matrices
Banded Sparse Matrices for PyTorch
Working in deep learning for NLP, a common claim I often hear from students is that "dense operations are fast / sparse operations are slow". With the exception of convolutions, most of the operations we do in NLP take the form of dense matrix-matrix products (e.g. attention) or matrix-vector products (e.g. RNNs, output embeddings).
Part of this, though, is just that these are the operations people have taken the time to optimize in CUDA and that come built into the standard libraries. There are plenty of sparse matrix strategies that run very fast on GPUs.
In this notebook I show an implementation of one class of sparse matrix, the band matrix: https://en.wikipedia.org/wiki/Band_matrix. The short sketch after the imports below illustrates the basic storage idea.
import torch
import matplotlib.pyplot as plt
!pip install -qU git+https://github.com/harvardnlp/genbmm
import genbmm
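Before looking at genbmm itself, here is a minimal sketch of the storage idea in plain PyTorch. The banded_to_dense helper and its particular layout convention are just one illustrative choice, not necessarily the layout genbmm uses internally.
# Illustration only: pack an N x N band matrix with `lower` diagonals below and
# `upper` diagonals above the main diagonal into an N x (lower + upper + 1) tensor.
def banded_to_dense(band, lower, upper):
    # band[i, j] is taken to hold dense entry (i, i - lower + j); positions that
    # would fall outside the matrix are ignored.
    n, width = band.shape
    assert width == lower + upper + 1
    dense = torch.zeros(n, n)
    for i in range(n):
        for j in range(width):
            col = i - lower + j
            if 0 <= col < n:
                dense[i, col] = band[i, j]
    return dense

toy = torch.rand(6, 4)  # N = 6, lower = 2, upper = 1
print(banded_to_dense(toy, 2, 1))
A 6 x 6 matrix with this band structure only needs 6 x 4 numbers, and that ratio improves as the matrix grows relative to the band width.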
def show(bm):
    # Plot the compact banded storage (bm.data, left) next to the expanded dense matrix (bm.to_dense(), right).
    fig = plt.figure()
    fig.set_figheight(5)
    fig.set_figwidth(10)
    fig.add_subplot(1, 2, 1)
    plt.imshow(bm.data.cpu().detach()[0], vmin=0, vmax=1)
    plt.axis('off')
    fig.add_subplot(1, 2, 2)
    plt.imshow(bm.to_dense().cpu().detach()[0], vmin=0, vmax=1)
    plt.axis('off')
    plt.tight_layout()
# A compact banded representation: 20 rows, each storing 4 band entries.
base = torch.arange(1, 21).view(1, 20, 1).expand(1, 20, 4).cuda() / 21.0
# Wrap it as a BandedMatrix; the two width arguments (2 and 1) split the band on either
# side of the main diagonal, so 2 + 1 + 1 = 4 diagonals match the last dimension of base.
bm = genbmm.BandedMatrix(base, 2, 1)
show(bm)
show(bm.transpose())
show(bm.op(bm, lambda a,b : (a + b) / 2))
show(bm.op(bm.transpose(), lambda a,b : (a + b) / 2))
show(bm.multiply(bm))
show(bm.band_shift(1))
show(bm.band_shift(-1))
show(bm.col_shift(1))
show(bm.col_shift(-1))
show(bm.multiply(bm.transpose()).multiply(bm.transpose()).multiply(bm.transpose()))
show(bm.multiply(bm.band_shift(1).band_shift(1)))
There are a lot of fun applications of this style of banded matrix. In my library PyTorch-Struct we use them for very fast computation of sequence alignments.
!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
import torch_struct
# Random alignment potentials for a 200 x 200 problem (three scores per cell).
vals = torch.rand(1, 200, 200, 3).cuda()
# Give the last 20 rows and columns very large scores.
vals[0, 180:, :] = 1e6
vals[0, :, 180:] = 1e6
# Alignment marginals under the log semiring; summing over the last dimension gives per-cell marginals.
q = torch_struct.Alignment(torch_struct.LogSemiring).marginals(vals)
plt.imshow(q[0].cpu().detach().sum(-1), vmin=0, vmax=1.0)