# Sparse Attention

**tl;dr:** Reducing the number of calculations needed to compute attention.

*Sparse Attention* introduces sparse factorizations of the attention matrix. To implement this we introduce a *connectivity pattern* $S = \{S_1,\dots,S_n\}$. Here, $S_i$ denotes the set of indices of the input vectors to which the $i$th output vector attends. For instance, in regular $n^2$ (causal) attention every output vector attends to every input vector at or before it in the sequence, so $S_i = \{j : j \leq i\}$. Remember that $d_k$ is the inner dimension of our queries and keys. Sparse Attention is given as follows:

\begin{equation*} \text{attention}(Q,K,V, S_i) = \text{softmax}\Big( \frac{q_i K^T_{S_i}}{\sqrt{d_k}} \Big) V_{S_i}. \end{equation*}

Here, $q_i$ is the query for the $i$th output position, and the keys and values are restricted to the indices in $S_i$:

$$ q_i = W_q x_i, \quad K_{S_i} = (W_k x_j)_{j \in S_i}, \quad V_{S_i} = (W_v x_j)_{j \in S_i}. $$
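The formula above can be sketched for a single output position in plain Python. This is a minimal illustration, not the authors' implementation: it assumes the query is the single vector for position $i$, that the projections $W_q x$, $W_k x$, $W_v x$ have already been applied, and that vectors are plain lists. The function name `sparse_attention_row` is hypothetical.

```python
import math


def sparse_attention_row(q_i, keys, values, S_i):
    """Attention output for one query, restricted to the indices in S_i.

    q_i    : projected query vector for position i (list of floats)
    keys   : projected key vectors, one per input position
    values : projected value vectors, one per input position
    S_i    : set of input indices that position i may attend to
    """
    d_k = len(q_i)
    idx = sorted(S_i)
    # Scaled dot-product scores, computed only for the allowed indices.
    scores = [
        sum(a * b for a, b in zip(q_i, keys[j])) / math.sqrt(d_k) for j in idx
    ]
    # Numerically stable softmax over the selected scores.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the selected value vectors.
    d_v = len(values[0])
    return [sum(w * values[j][c] for w, j in zip(weights, idx)) for c in range(d_v)]
```

The work per output position is proportional to $|S_i|$ rather than to $n$, which is where the savings come from once $S_i$ is small.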

So how do we define the set of connectivity patterns $S$? Formally, we let $S_i = A_i^{(h)}$ for head $h$, where $A_i^{(h)} \subset \{j : j \leq i\}$. This still does not tell us which indices to pick for a given $S_i$. The original authors initially consider two key criteria:

**Criterion 1**
We should pick $|A_i^{(h)}| \propto n^{1/H}$ where $H$ is our total number of heads. This choice is efficient because each head attends to far fewer than $n$ positions: for $H = 2$, on the order of $\sqrt{n}$.

**Criterion 2**
All input positions are connected to all later output positions across $p$ steps of attention. For instance, for a pair $j \leq i$ we would like $i$ to be able to attend to $j$ through a path of locations with maximum length $p+1$. This lets signals propagate from any input to any output in a constant number of steps. In the two-head patterns below, $p = 2$.
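To make the first criterion concrete, here is a rough worked example. It assumes $H = 2$, a proportionality constant of 1, and a sequence length of $n = 1024$ (an illustrative value, not one taken from the text):

$$ |A_i^{(h)}| \approx n^{1/2} = 32, \qquad \underbrace{2 \cdot n \cdot \sqrt{n}}_{\text{two sparse heads}} = 65{,}536 \quad \text{versus} \quad \underbrace{\tfrac{n(n+1)}{2}}_{\text{dense causal}} = 524{,}800 \text{ query-key pairs.} $$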

We now investigate two different approaches that satisfy these criteria and allow us to implement sparse attention.

#### Strided Attention.

We will define a factorized attention pattern in two heads. One head will attend to the previous $l$ locations, while the other head will attend to every $l$th location. We call $l$ the stride and it is chosen to be close to $\sqrt{n}$.

\begin{align} A_i^{(1)} &= \{t, t+1, \dots, i\} \text{ for } t = \max(0, i-l), \\ A_i^{(2)} &= \{j : (i-j) \bmod l = 0\}. \end{align}

Here, $A_i^{(1)}$ simply takes the previous $l$ locations. $A_i^{(2)}$ then takes every $l$th location counting back from $i$, that is, each $j \leq i$ for which $i-j$ is divisible by $l$ without remainder. This is particularly useful where you can align the structure of your input with the stride, for instance with a piece of music. Where our input does not have a well-defined structure, we use something different. In the image below, you can see $A_i^{(1)}$ responsible for the dark blue shading and $A_i^{(2)}$ responsible for the light blue.
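As a small sketch of the strided pattern, the two index sets can be built directly from their definitions. The code assumes 0-indexed positions and restricts both heads to $j \leq i$; the function names `strided_pattern` and `reachable_in_two_steps` are hypothetical helpers for illustration. The second function checks the path requirement from the second criterion by hopping through the strided head and then the local head.

```python
def strided_pattern(i, l):
    """Return the two strided index sets for output position i (0-indexed)."""
    t = max(0, i - l)
    # Head 1: the local window from t up to and including i.
    head_1 = set(range(t, i + 1))
    # Head 2: every position j <= i whose distance to i is a multiple of l.
    head_2 = {j for j in range(i + 1) if (i - j) % l == 0}
    return head_1, head_2


def reachable_in_two_steps(i, j, l):
    """True if i reaches j via some k in head 2 of i with j in head 1 of k."""
    _, strided = strided_pattern(i, l)
    return any(j in strided_pattern(k, l)[0] for k in strided)


# Example: with n = 16 and stride l = 4, check every pair j <= i.
n, l = 16, 4
all_connected = all(
    reachable_in_two_steps(i, j, l) for i in range(n) for j in range(i + 1)
)
```

The check succeeds because the strided head can jump back in multiples of $l$ to a position within $l$ of $j$, and the local head covers the remaining gap.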

#### Fixed Attention.

Our goal with this approach is to allow specific cells to summarize the previous locations, and to propagate this information on to future cells.

$$ A^{(1)}_i = \{ j : \lfloor j/l \rfloor = \lfloor i/l \rfloor \}, $$

$$ A^{(2)}_i = \{ j : j \bmod l \in \{ t, t + 1, \dots, l \} \}, \text{ where } t = l - c \text{ and } c \text{ is a hyperparameter.} $$

These are best understood visually in my opinion. In the image below, $A_i^{(1)}$ is responsible for the dark blue shading and $A_i^{(2)}$ for the light blue shading. If we take stride $l = 128$ and $c = 8$, then all positions greater than 128 can attend to positions 120 to 128. The authors found that choosing $c \in \{8,16,32\}$ worked well.
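The fixed pattern can be sketched the same way. This assumes 0-indexed positions and restricts both sets to $j \leq i$, as required of every $A_i^{(h)}$; `fixed_pattern` is a hypothetical name. Since $j \bmod l$ never reaches $l$ itself, the condition `j % l >= l - c` selects the same positions as the set $\{t, \dots, l\}$ in the definition.

```python
def fixed_pattern(i, l, c):
    """Return the two fixed index sets for output position i (0-indexed)."""
    # Head 1: positions in the same length-l block as i, up to i.
    head_1 = {j for j in range(i + 1) if j // l == i // l}
    # Head 2: the last c positions of every block (the "summary" cells).
    head_2 = {j for j in range(i + 1) if j % l >= l - c}
    return head_1, head_2


# Example with the values from the text: stride l = 128 and c = 8.
block, summary = fixed_pattern(200, 128, 8)
```

Under 0-indexing, position 200 sees its own block from 128 onward through the first head and the summary cells 120 through 127 of the previous block through the second.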