Flash Attention derived and coded from first principles with Triton (Python)
By Umar Jamil
Summary
Key takeaways
- **Deriving Flash Attention from First Principles**: The video aims to derive and code Flash Attention from scratch, treating the original paper as if it never existed, to foster a deep understanding of its mechanics. [08:08], [21:22]
- **Optimizing Attention by Focusing on Softmax**: Flash Attention primarily focuses on optimizing the softmax computation within the attention mechanism, as matrix multiplications are already highly optimized on GPUs. [01:10:12], [07:57:53]
- **Addressing Softmax Numerical Instability**: The video explains the challenge of numerical instability in softmax due to large input values and details the 'online softmax' technique to compute it safely in a single pass. [20:11:12], [38:38:43]
- **GPU Architecture and Parallelism**: Understanding the GPU's architecture, with its numerous cores optimized for parallel computation and limited control units, is crucial for writing efficient CUDA and Triton kernels. [01:21:22], [01:51:07]
- **Triton Kernels and Memory Management**: Triton allows writing GPU kernels using Python, abstracting away some CUDA complexities and optimizing memory access by leveraging the GPU's shared memory (SRAM) over slower global memory (HBM). [02:40:48], [02:54:26]
- **Flash Attention Forward Pass Mechanics**: The forward pass involves block-wise computations, handling potential inaccuracies from local softmax computations by using an online softmax approach and a correction factor. [01:44:01], [01:46:47]
Topics Covered
- GPU Memory Hierarchy: Why Attention is IO-Bound.
- Online Softmax: Fusing Passes for Numerical Stability.
- Block Matrix Multiplication for Efficient Parallelization.
- GPU Parallelism: Mapping Work to Threads and Blocks.
- LogSumExp Trick: Optimizing Backward Pass Memory.
Full Transcript
hello guys, welcome back to my channel. Today we are going to explore Flash Attention, and we are going to explore it from first principles, which means that we will not only code Flash Attention, we will actually derive it: we pretend that the Flash Attention paper never existed, we look at the attention computation, we look at the problems it has, and we try to solve them step by step. This will give us a deep understanding of how it works, and we will also combine theory with practice, because we will code it. Now, in order to code Flash Attention we will need to write a kernel for our GPU, and in our specific case I will be using an NVIDIA GPU, so a CUDA kernel. But instead of writing C++ code we will use Triton, which is a way of converting Python directly into kernels that can run directly on the GPU: you can think of Triton as a compiler that takes in Python and converts it into something that can run on the GPU.
So let's look at the topics for today. First of all I will give an introduction to multi-head attention, because we need to look at what attention is, how it's computed, and what the problems are in computing it. Then we will look at the most critical part of the attention computation, the softmax, and how it impacts the computation and its complexity, and we will look at what online softmax is. Then we will explore the GPU, because we are going to write a kernel that will run on it, so we need to understand, for example, what the difference is between the CPU and the GPU, what a kernel is, and how it differs from a normal program that you write for the CPU. We will look at how tensors are laid out in memory (row-major layout, column-major layout, strides, etc.), and we are going to look at block matrix multiplication and Triton's software pipelining, all the optimizations that Triton applies to our code. Finally we will be able to code the Flash Attention forward pass,
but of course we are not satisfied with only coding the forward pass, we also want to code the backward pass. In order to code the backward pass we also need to understand how autograd works and how gradient descent works in the case of custom operations, so we need to understand what derivatives are, what gradients are, what Jacobians are, and then we calculate the gradients of the common operations that we use in Flash Attention; finally we will have enough knowledge to code the backward pass. For this reason this video is going to be super long, but I hope you don't mind, because we are going to learn a lot.
Of course you may be wondering: all of this requires a lot of knowledge that you may not have. But that's not a problem, because that's my problem: in this video I will make sure that if you only have high-school calculus (so you know what derivatives are), the basics of linear algebra (you know what matrix multiplication is, or what the transpose of a matrix is), a basic knowledge of the attention mechanism (for example, you have watched my previous video on the "Attention Is All You Need" paper), and a lot of patience, that should be enough to understand all of this video, because every topic I introduce, I will introduce as if you don't know anything about it, so we try to derive everything from first principles, everything from scratch. Okay, now that we have seen the introduction, let's go see the first part of the video, which is multi-head attention. All right, let's talk about multi-head attention. I am using the slides from my previous video,
"Attention Is All You Need", so we can take a quick look at what multi-head attention is and how it works. I hope you remember the formula: softmax of the query multiplied by the transpose of the keys, divided by the square root of d_k, all multiplied by V, because we will be using it a lot throughout the video. Now, multi-head
attention starts from an input sequence, or two input sequences in case we are talking about cross-attention. In the simple case of self-attention we have one input sequence, which, in the case of a language model, is a sequence of tokens: we have seq tokens and each token is represented by an embedding, a vector with d_model dimensions. The first thing that we do is convert this input sequence into query, key and values through three linear projections, one called W_Q, one called W_K, one called W_V, which in PyTorch are represented by linear layers. These linear layers are d_model by d_model, so they do not change the shape of the input tensor. After we do this job of projecting them, they become three different sequences, one called query, one called key and one called value; here I'm calling them Q', K' and V'. Then we
divide them into smaller embeddings: each token, which is made up of d_model dimensions, we divide into smaller tokens; if we suppose we have four heads, each one will have d_model / 4 dimensions. So this one is a sequence of tokens where each token is not the entire embedding but a part of the embedding of each token, this one is another part of the embedding of the tokens, this one is another part, etc., and we do this job for the query, key and value sequences. Then we compute the attention as follows: the softmax of the query multiplied by the transpose of the keys, divided by the square root of d_k, where d_k is the dimension of each head, i.e. how many dimensions each head is working with, and then we do the multiplication with V, and this gives us the output of the attention mechanism for each head. This job is done independently for each head. This should be clear to you; if it's not, please watch my previous video on the attention mechanism, because we will be working with this scenario a lot. Then we take the output of each head, and we
the output of each head and then we
concatenate it back in order to get the
representation of each token uh as a
full embedding so before we split this
embedding into smaller embeddings this
one here it's called the q1 Q2 Q3 Q4
then after we compute the attention we
get back um the output of each head and
we concut it together to get back the
full embedding Dimension which is this
edge here we run it through another
linear projection called wo which will
be the output of the multi head
attention now flesh attention is not
concerned with all of these operations; actually, Flash Attention is only concerned with the operations that require optimization, and the operations that require optimization are these: the softmax of the query multiplied by the transpose of the keys, divided by the square root of d_k, multiplied by V. This means that the projection of the input sequence through W_Q, W_K and W_V is not something that Flash Attention is concerned about, because that's a matrix multiplication: when you use a linear layer it's just a matrix multiplication of the input with the weight matrix of the linear layer, and matrix multiplication is one of the most optimized operations that we have on the GPU, because the manufacturer of the GPU usually also releases the necessary libraries for computing matrix multiplication. So these are quite fast and they do not require any optimization. Flash Attention will pretend that the query has already been passed through W_Q, the key has already been passed through W_K and the V has already been passed through W_V. Moreover, Flash Attention will not be concerned with the projection through W_O, because that's also a matrix multiplication (W_O is always represented in PyTorch as a linear layer, so it's a matrix multiplication), and matrix multiplications, as we have seen, are very optimized, so there is nothing to optimize there. What we need to optimize in terms of speed is this operation here: softmax of the query multiplied by the transpose of the keys, divided by the square root of d_k, multiplied by V. All right guys, so now we
have rehearsed what multi-head attention is, I also want to give you a visualization, which is basically the figure in the paper for multi-head attention. We can see that we have the inputs, which are Q, K and V; each of them runs through a linear layer, which is W_Q, W_K and W_V. Then we do the scaled dot-product attention, which is done independently for each head, so each head will do query multiplied by the transpose of the keys divided by the square root of d_k, where each query and each key is not the full embedding of each token but a part of the embedding of the token, because we split them into smaller embeddings. Eventually we take the output of each of these heads, which are computed in parallel (that's why you see this h dimension in the depth), we concatenate them and then we run them through W_O. What are we concerned with? We are concerned with optimizing this particular block here, the scaled dot-product attention. So let's start our journey.
journey one thing that is very important
to understand is why do we even need a
better implementation of the attention
mechanism and if you look at the flashh
attention paper you will notice The
Following part this is the paper flashh
attention one and in the flashh
attention one paper they describe the
attention Implement imple implementation
as it's done naively when using pytorch
so first we do the multiplication of the
query multiplied by the transpose of the
keys then we apply the softmax to the
output of this operation and finally we
multiply the output of the softmax with
the V Matrix to obtain the output of the
attention the way this implementation is
done by pytorch without any optimization
is as follows so we load the first of
all, these tensors reside in the GPU. The GPU is made up of two main memories: one is called the HBM, which is the DRAM, the RAM of the GPU (the 40 GB of the A100, for example), so it's the biggest memory that we have on the GPU; and then there is the shared memory. The problem of the GPU is that accessing this HBM (it's also called the global memory) is very, very slow compared to the shared memory; however, the shared memory is much, much smaller than the HBM. What they claim in the Flash Attention paper is that the attention operation is IO-bound, meaning that if we keep accessing the global memory, the overall operation of computing the attention is slow not because computing all these operations is slow, but because we keep accessing the global memory, which is slow; we call these kinds of operations IO-bound. So the only way to improve this situation is to compute the attention inside the shared memory of the GPU, which is much smaller but much closer to the cores that actually do the computation. We will need to split the attention computation into smaller blocks that can reside in the shared memory, and we will see later how this is possible through block matrix multiplication; in the paper they call it tiling, and it's a very widely used technique when writing kernels for the GPU, which usually involve some kind of matrix multiplication. So
now we know what problem Flash Attention is trying to solve: it's trying to make sure that we do not need to access the HBM (the high-bandwidth memory) while computing the attention, but instead copy only a part of each matrix into the local memory, the shared memory of the GPU that is closer to the cores, compute a part of the output matrix there, then copy that part to the output residing in the HBM, and keep doing this for all the blocks into which we can divide the query, key and value matrices. Later we will see how this blocked computation is done, but we will also see that the biggest problem in computing it is the softmax, because the softmax needs to access an entire row of the S matrix: it needs a normalization factor, which is the sum of the exponentials of all the values to which it is applied row-wise. We will see later how we solve this problem, so let's move on.
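As a reference point, here is a minimal PyTorch sketch of the naive (non-fused) attention just described, for a single head: it materializes the full N-by-N matrices S and P in GPU global memory. The tensor names S, P and O follow the paper's description above; the rest is illustrative.

```python
import math
import torch

def naive_attention(Q, K, V):
    # Q, K, V: (N, d) tensors, already projected and already split per head.
    d = Q.shape[-1]
    S = Q @ K.transpose(-2, -1) / math.sqrt(d)  # (N, N) scores, materialized in HBM
    P = torch.softmax(S, dim=-1)                # (N, N) row-wise softmax, also materialized
    O = P @ V                                   # (N, d) attention output
    return O
```

Every intermediate here is read from and written to the GPU's global memory, which is exactly the memory traffic that makes this implementation IO-bound.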
All right guys, okay, when I say guys I mean guys and girls, because usually I just say guys, but please, girls, don't feel excluded. So, we saw that Flash Attention is only concerned with optimizing this: the softmax of the query multiplied by the transpose of the keys, divided by the square root of d_k, multiplied by V, and we need to introduce a little bit of notation so that we don't get lost in the future slides. First of all, these are the formulas I took from the Flash Attention paper, but for now let's pretend Flash Attention never existed, so we are trying to solve the problem step by step. We should treat this Q as the sequence that is the output of the input sequence that has already passed through W_Q, the K as something that has already passed through W_K, and V as something that has already passed through W_V, because we don't want to optimize the matrix multiplications; they're already fast enough. Another thing: let's talk about
what the dimensions of these matrices are, so we can then understand the dimensions of the outputs of these operations. We will treat Q as a sequence of N tokens, where each token has d dimensions (lowercase d). Why? Because usually we take the queries and then we split them into multiple heads, and here we pretend we have already done this splitting: we pretend we already took our input sequence, already ran it through W_Q, and already split it into multiple heads, and each of these heads will do the operation we already saw, the usual formula with the query multiplied by the transpose of the keys; each head works with these dimensions for the query, the key and the value sequence. Now let's look at the dimensions of the outputs. The first operation that we do is the query multiplied by the transpose of the keys, where the transpose of the keys is a matrix that is originally N by d but with the transpose becomes d by N, and the result will be a matrix that is N by N, because in a matrix multiplication the outer dimensions become the dimensions of the output matrix. What is the next operation? We take the output of this operation, the query multiplied by the transpose of the keys, and we run it through a softmax operation (we will see what the softmax operation is), which preserves the shape of the input: it doesn't change the shape of the input matrix, it just changes its values. Then we take the output of the softmax and we multiply it by V, which of course will change the shape, because the P matrix is N by N and V is N by d, so the output will be N by d, the outer dimensions of this matrix multiplication. Now let's look at the details of each of these operations.
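To keep the shapes just described in one place, here they are written compactly, including the 1/sqrt(d_k) scaling from the formula recapped earlier; the names S, P and O match the paper's notation.

```latex
S = \frac{Q K^{\top}}{\sqrt{d_k}} \in \mathbb{R}^{N \times N}, \qquad
P = \mathrm{softmax}_{\text{row-wise}}(S) \in \mathbb{R}^{N \times N}, \qquad
O = P V \in \mathbb{R}^{N \times d}
```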
When we do query multiplied by the transpose of the keys, we get a matrix that is N by N, where each value in this matrix is a dot product of a row of Q and a column of K transposed. In particular, the first element of this matrix will be the dot product of the first query with the first key vector, the second element will be the dot product of the first query with the second key vector, the third element will be the first query with the third key, etc. The last row of this matrix will be the dot product of the last query with the first key, then the last query with the second key, the last query with the third key, and so on until the last query with the last key. You may also notice that here I have written query transpose times key. What is q1, first of all? q1 is the first row of the query matrix. A little bit of background on matrix multiplication: we know that when we do matrix multiplication, each output element is one row of the first matrix with one column of the second matrix, but since we are doing the product of the first matrix with the transpose of the second, it will be the dot product of one row of the query matrix with one row of the key matrix, because we are multiplying with K transposed. When you take a vector from a matrix, the usual notation in linear algebra is to treat it as a column vector, so we cannot just write q multiplied by k, because that would mean we are doing the matrix multiplication of one column matrix with another column matrix, which is not possible because the shapes do not match. So as a notation we write the dot product as the first vector transposed (a column vector transposed becomes a row vector) with the second vector. This is just notation, guys; you just need to read it as the first query with the first key, then the first query with the second key, the first query with the third key, and so on: we are doing dot products of vectors. Then we apply
this softmax operation. What the softmax operation does is transform each of these dot products, which are scalars (the output of a dot product is a scalar), in such a way that row-wise they become a kind of probability distribution, which means that each of these numbers is between zero and one and, when we sum them up, they sum to one. This property holds for each row: this row will sum to one, this row will sum to one, and this row will sum to one, etc. Let's see what the softmax operation is. Given a vector, let's call it x, made up of N dimensions, the softmax is defined as follows: it transforms this vector into another vector with the same dimension, where the i-th element of the output vector is the exponential of the i-th input element divided by the summation of the exponentials of all the dimensions of the vector. The denominator is called the normalization factor: to make all these numbers lie between zero and one we normalize, and that's why it's called the normalization factor. We use the exponential because we want each of these numbers to be positive; we don't want the output of this operation to be negative.
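Written out, the definition just given is (using x for the input vector):

```latex
\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}, \qquad i = 1, \dots, N
```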
But there is a problem: imagine our input vector is made up of numbers that are large, for example x1 = 100, x2 = 200, x3 = 300, which can happen. If we take the exponential of these numbers, the exponential of 100 is going to be a huge number, very close to infinity at least compared to what we can store in a computer, so the output of exp(100) may not fit into a 32-bit or a 16-bit floating-point number, or even a 32-bit integer; we cannot compute it because it will overflow the variable that is storing this value. In this case we talk about numerical instability: every time you hear the term numerical instability in computer science, it means that the number cannot be represented within a fixed representation with the bits we have available, which are usually 32 or 16 bits (we also have 64 bits, but that would be too expensive to use). So let's try to find a solution to make this computable and numerically stable. In order to make this softmax
stable in order to make this soft Max
operation numerically stable which means
that we want these numbers to not
explode or to become too small that they
are not representable we need to find a
solution and luckily it's quite easy so
the softmax as we have seen before it is
the following formula so each number is
exponentiated and then we divide it by
this normalization factor which is just
the sum of the exponential of each input
dimension of the input
Vector um if we multiply the numerator
and the denominator of a fraction with a
constant uh with a number number then
the fraction will not change so that's
what we are going to do we are
multiplying the numerator and this
denominator with this factor C as long
as C is not equal to zero of course uh
then we can U take this C and by using
the the distributive property of the
product with respect to the sum we can
bring this C inside of the summation as
you can see here um then we can also
write every number as the exponential of
the log of itself because the
exponential and the log will cancel out
and um then we can by using the
properties of the exponentials we know
that the product of two exponential is
equal to the sum of the is equal to the
exponential of the sum of the arguments
of each exponential and we do it on the
numerator and in the denominator then we
just call this quantity minus log c
equal to K or k is equal to minus K is
equal to log C so we can replace this
quantity with k
we can do that because this is a
constant that we have chosen and we just
are assigning it to another
constant um so basically by doing this
derivation we can see that we can sneak
in a value inside of this exponential
that if chosen carefully can reduce the
argument of this exponential and we will
choose this k equal to the maximum
element inside of the input Vector that
we are applying the soft Max to so that
each this argument will be either zero
in case x i is equal to the maximum
element that we are processing of the
vector or it will be less than zero and
we know that the exponential when it's
equal to zero will be equal to the
output of the exponential will be one so
the argument when it's zero it will be
equal to one and when it's smaller than
zero so it's in the negative range it
will be between Z and one so which is
easily representable with floating Point
32 for example so this exponential will
not explode anymore so basically to
apply the the soft Max to a vector in a
numerically safe way we need to find a k
constant which is the maximum value of
this vector and when we apply it we need
to substract each element minus this
constant that we have chosen so let's
look at the algorithm to compute the
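The end result of this derivation, written compactly (with x_max denoting the maximum element of the vector):

```latex
\mathrm{softmax}(x)_i
  = \frac{c\, e^{x_i}}{\sum_{j=1}^{N} c\, e^{x_j}}
  = \frac{e^{x_i + \log c}}{\sum_{j=1}^{N} e^{x_j + \log c}}
  = \frac{e^{x_i - x_{\max}}}{\sum_{j=1}^{N} e^{x_j - x_{\max}}}
  \quad \text{choosing } \log c = -x_{\max}
```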
So let's look at the algorithm to compute the softmax. First of all, given a vector, or given an N by N matrix (because we want to apply the softmax to this matrix here, which is N by N), we need to go through each row, and for each row we need to find the maximum value among its elements, which takes time linear in the size of the row to which we are applying the softmax. Then we need to compute the normalization factor, which is this expression here; we cannot compute it before step number one, because we need the maximum element to compute this summation. After we have calculated the normalization factor, we can divide each element's exponential by it; we cannot do step number three before calculating the normalization factor, because we need to divide each number by it. If you like pseudocode for algorithms, this is the algorithm for computing the softmax that we have just seen: first we find the maximum of the row to which we are applying the softmax, then we compute the normalization factor, and then we apply the softmax to each element, which means we compute the exponential of each element minus the maximum value of the vector, divided by the normalization factor.
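A direct translation of that three-pass pseudocode into Python might look like this (a sketch of the algorithm as described, not an optimized implementation):

```python
import math

def safe_softmax_three_pass(x):
    # Pass 1: find the maximum of the row.
    m = -math.inf
    for xi in x:
        m = max(m, xi)
    # Pass 2: compute the normalization factor using the global maximum.
    l = 0.0
    for xi in x:
        l += math.exp(xi - m)
    # Pass 3: normalize each element.
    return [math.exp(xi - m) / l for xi in x]
```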
Now, this pseudocode describes an algorithm that is quite slow. Look at a practical example: imagine we have this vector here. First we need to do step one, find the maximum value in this vector, which is the number five, and this takes linear time. Then we need to calculate the normalization constant, which is the sum of the exponentials of each element minus the maximum value, so e^(3-5) + e^(2-5), etc., and we call it L. Then we need to go through this vector again and take the exponential of each element minus the maximum, divided by the normalization factor. So to apply the softmax to an N by N matrix we need to go through each element of this matrix three times, and these operations must be done sequentially: we cannot start operation two until we have done operation one, and we cannot start operation three until we have done one and two. This is quite slow, only to apply an operation that doesn't even change the shape of the matrix; it's just normalizing the values. So there must be a better way that does not involve three sequential passes through this matrix. Let's see.
All right guys, let's rehearse the problem that we are trying to solve. The problem statement is the following: can we find a better way to compute the softmax that does not involve going through the vector three times? Let's look at the pseudocode of the algorithm for computing the softmax that we have found so far. Imagine we have a vector made up of four elements. The first thing that we need to do is compute the maximum element in this vector, which means going through this for-loop here: we start from the left side of the vector and iteratively go to the right, comparing the previously found maximum with the current element to find the global maximum. I know this is very simple and I'm pretty sure you don't need this example, but working through it will help us understand what we will do next, so please bear with me even if what I'm doing is super simple. Okay: at the beginning m_0 is equal to minus infinity. m_1 is the value at iteration number one of the for-loop, which means m_1 will be equal to the maximum of the previous estimate, which is minus infinity, and the current element, which is 3, so it becomes 3. Then m_2 will be the maximum of the previously computed maximum (m_1, so 3) and the current element, which is 2, so it will be equal to 3. m_3 will be the maximum of the previously computed maximum, 3, and the current element, which is 5, so it will be equal to 5, and m_4 will be the maximum of the previously computed maximum and the current element, so it will be equal to 5. So at the fourth iteration we have the global maximum, independently of what the input array is. After we have computed the maximum, which we know is 5, we can compute the normalization factor. Let's start with l_0: l_0 is equal to 0. l_1 will be equal to l_0 plus the exponential of the current element, 3, minus the maximum element we found in the previous for-loop, 5. Then l_2 will be equal to l_1 plus the exponential of the current element, 2, minus the maximum; l_3 will be equal to l_2 plus the exponential of 5 - 5; and l_4 will be equal to l_3 plus the exponential of 1 - 5. If you expand this l, it is equal to e^(3-5) + e^(2-5) + e^(5-5) + e^(1-5). After we have computed this normalization factor we can use it to normalize each element of the input vector, which means that the new x_1, call it x_1', will be equal to e^(3-5) divided by the l that we computed in the previous for-loop, i.e. the l at the fourth iteration; the new x_2, x_2', will be equal to e^(2-5) / l_4; x_3' will be equal to e^(5-5) / l_4; etc. for all the elements. I know this is super simple,
but it will help us later. In this procedure we need to go through the vector three times, because first we need to run this for-loop, then this for-loop, and then another for-loop, and we cannot reorder them: in order to compute the second for-loop we need to have the maximum element, because we need it here, and we cannot compute the third for-loop until we have computed the previous one, because we need the normalization factor. However, we are stubborn, so let's try to fuse the first two operations into one for-loop, which means that we go through the array and in the same iteration compute both m_i and l_i. Of course we will not be able to compute l_i exactly, because we don't have the global maximum (we didn't go through the whole array yet); however, let's try to use whatever estimate of the maximum we have so far, so instead of m_N let's use m_i, the local maximum that we have computed so far. If we apply the softmax in this fused way to this vector, we will have the following iterations. So, this is our array, or
iterations um so this is our array or
vector and the first step is MI so M1
will be equal to the previous maximum
which is minus infinity with the current
element so the maximum minus infinity
and the current element is equal to
three and L1 will be equal to the
previous l so l0 which is starts from
zero plus e to the power of the current
element minus we should be using the
global maximum but we don't have the
global maximum so let's use the whatever
maximum we have so far so we can use
three now at the second iteration we are
at this element of the vector and we
comput the maximum so far so the maximum
so far is the previous maximum and the
current element so the maximum of the
previous maximum and the current element
which is the maximum between three and
two which is three uh and the the
normalization factor is the previous
normalization Factor plus exponential of
2 - 3 which is the current element minus
whatever maximum we have so far now if
our array were made only of these two
elements so three and two then whatever
we have computed is actually correct
because the maximum that we have found
is a three and it's actually the global
maximum and the um normalization factor
that we have computed is actually
correct because each of the exponent
itial has been computed with the global
maximum because the first element was
computed using three as the uh with the
argument minus three and also the second
element was computed with the argument
with the argument having minus three in
the in the argument which is the global
maximum of the vector however when we
arrive at the third iteration so let me
delete this Vector so we arrive here at
the third iteration the maximum will
change which will also uh cause our
normalization factor to get to to be
wrong because we arrive at the element
number three uh so the number five here
and we computed the maximum so the
maximum is the comparison of the
previous maximum and the current element
so the new maximum becomes five and the
normalization factor is the previous
normalization Factor so L2 plus the
exponential of the current element minus
the current estimate of the maximum
which is five however
if you look at this l_3, it is wrong. Why? Because if you expand this summation, l_3 is equal to e^(3-3) + e^(2-3) + e^(5-5): this last exponential is using 5 as the global maximum, but this one is using 3 as the global maximum and this one is using 3 as the global maximum. So the first two elements have been computed thinking that the global maximum is 3, but later we found a better global maximum, 5, which makes this normalization factor wrong. However, can we fix, at the third iteration, whatever normalization factor we have computed up to the second iteration? Actually we can, because if we expand it, as we have done here, what we need here is to have minus 5, because that's the global maximum we have found so far, not the minus 3 that we had at the previous iteration, and here we also need to replace this -3 with -5. How can we do that? Well, if we multiply this term and this term by a correction factor that sneaks the new maximum inside the exponential, we solve the problem, and this correction factor is very easy to calculate: at the third iteration, if we multiply l_2, the previously computed normalization factor, by this factor here, which is the exponential of the previous estimate of the maximum minus the current estimate of the maximum (5), we will see that by the properties of exponentials this term becomes e^(3 - 3 + 3 - 5), so this minus 3 cancels out with this 3, and likewise in the second term the minus 3 cancels out with this 3, and they become e^(3-5) and e^(2-5), which is correct, because at the third iteration we should be using minus 5 as the maximum of the array so far. So basically what we have found is a way to fix whatever normalization factor we have computed so far while iterating through the array, whenever we find a better maximum than the one we had; and when we don't need to fix anything, the formula still stands, because the correction factor is nothing more than the exponential of the previous estimate of the maximum minus the current estimate of the maximum at the current iteration, i.e. e^(m_{i-1} - m_i). Let me delete it, otherwise it remains forever in my slides. So basically,
when we arrive at the last element, we will see that the maximum doesn't change, because we compare the previous maximum with the current element, which is less than the previous maximum; so the maximum doesn't change and we don't need to fix anything, because the previously computed normalization factor is correct (all its terms already use minus 5). When we don't need to fix anything, we just multiply by e to the power of the previous maximum minus the current maximum, which is e to the power of zero in this case, so it doesn't change anything. So we have found a way to fix the previously computed normalization factor while going through the array, even if at the current iteration we don't have the global maximum yet: every time the maximum changes we apply the fix, and every time it doesn't change we just multiply by e^0, which is like multiplying by one. So the new algorithm that we have found for the softmax is the following: we start with m_0 equal to minus infinity and l_0 equal to zero, we go through the array, we compute the local maximum (the maximum so far, from the first element to the element at which we are in the iteration), and the previously computed l_i can be fixed by using this correction factor, e to the power of the previous maximum minus the current maximum, plus the exponential of the current element minus the current estimate of the maximum. In this way we go through the array only once and we obtain two values at the end: the global maximum and, at the same time, the normalization factor, and then we can use them to compute the softmax. So we have transformed the three passes through the array into two passes, and this is very important, and we will see how we actually use it to derive Flash Attention.
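Here is a minimal Python sketch of the two-pass online softmax just derived: one fused pass maintains the running maximum m and the running normalizer l with the correction factor e^(old maximum - new maximum), and a second pass normalizes.

```python
import math

def online_softmax(x):
    # Fused pass: running maximum and running normalization factor.
    m = -math.inf  # maximum seen so far
    l = 0.0        # normalizer, always expressed with respect to the current m
    for xi in x:
        m_new = max(m, xi)
        # Rescale the old normalizer to the new maximum, then add the new term.
        l = l * math.exp(m - m_new) + math.exp(xi - m_new)
        m = m_new
    # Second pass: normalize using the final (global) maximum and normalizer.
    return [math.exp(xi - m) / l for xi in x]
```

As a quick check, online_softmax([3, 2, 5, 1]) produces the same output as the three-pass safe_softmax_three_pass sketch above.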
The example that I have given you so far is not really a proof that our algorithm will work in every case, because we made a very simple example using a vector of four elements; but does our new algorithm work in every single case, whatever the numbers are? We need to prove that, and we will prove it by induction. First of all, what are we trying to prove? We have fused the first two for-loops into one for-loop, as you can see here, and what we expect is that at the end of this for-loop, m_N (the m at the last iteration) will actually be the global maximum of the vector, and l_N (the l at the last iteration) will be equal to the sum of the exponentials of all the elements minus the global maximum of the vector. We need to prove that, because what I did before was an example, not a rigorous proof, and we will prove it by induction, which is a typical way of proving these kinds of theorems. Proof by induction works in the following way: we need to prove that our algorithm works for a base case, for example N equal to 1, and then we assume that the algorithm works for N and we need to prove that it also works for N + 1. If this holds, then we have proven our algorithm for every possible N, because it works for the base case, N equal to 1, and then by the induction step (if it works for N then it also works for N + 1) it also works for 2; but if it works for 2 then it also works for 3, and if it works for 3 then it also works for 4, etc., up to infinity. So let's prove it for the base
case, N equal to 1. It's very simple: at N equal to 1 this for-loop has only one iteration, so we compute m_1 and l_1. m_1 will be the maximum of the previous m, which is minus infinity (because we initialize m_0 to minus infinity), and the current element, x_1, so it will be equal to x_1, whatever x_1 is (x_1 can never be equal to minus infinity, because it's a number in a fixed representation). So m_1, which, because we have only one element, is also the last m of this for-loop, is equal to the global maximum of the vector made up of only one element. And l_1 will be equal to the previous l, which starts from zero (l_0), multiplied by a correction factor, which in this case is e to the power of minus infinity (because the correction factor is the previous estimate of the max minus the current estimate of the max, and the previous estimate is minus infinity, so minus infinity minus x_1 is minus infinity), so this term is cancelled out; plus e to the power of x_1 minus the current maximum, which is x_1. So l_1 is equal to the sum of the exponentials of all the elements of the vector (which is made up of only one element) minus the maximum element of the array, x_1. So we have proven that it works for N equal to 1. Now we assume
it works for n equal to 1 now we assume
that it works for n does it also work
for an array of vect or with a vector of
size n + one so let's see what happens
at the n + one iteration at the n+ one
iteration we will be doing the maximum
of the previous estimate of M which is
the m at the end iteration and the
current element so xn of + one this by
the properties of the Max uh function it
will be actually equal to the maximum of
the global Vector up to n + one uh
because um the maximum will choose
whatever is the maximum between the
previous estimate and the current
estimate and Ln + one which is the
normalization factor at the n + one
iteration will be equal to the Ln so the
previous estimate not previous by the
previous normalization Factor at at the
end iteration multiplied by the
correction factor which is the previous
maximum minus the current maximum plus
the exponential of x uh the current
element minus the current estimate of
the
maximum but
Ln we have we assume that this um
property so this algorithm works up to n
so Ln is for sure equal to the sum of
all the exponentials of the previous of
the vector up to n minus the local
maximum of the uh Vector up to the end
element which is
MN we multiply by the correction factor
if there is something to correct which
will be the previous maximum minus the
current maximum plus the exponential of
the current element minus the current
exp estimate of the
maximum now
by the properties of exponentials we can bring this factor inside the summation, and we will see that this m_N and this m_N cancel out, because the exponent becomes x_j - m_N + m_N - m_{N+1}; so these cancel and we obtain this term, plus this factor here that remains unchanged. However, you can see that this expression is exactly the argument of the summation at iteration N + 1: it is e to the power of x_j minus m_{N+1}, where j goes from 1 to N, plus e to the power of x_{N+1} - m_{N+1}; the last term is just the j = N + 1 term, so we can increase the upper index of the summation by one and it results in the same summation. So we have proven that at the (N+1)-th iteration the l is equal to the sum of the exponentials of all the elements of the array up to the (N+1)-th element, each minus the maximum up to the (N+1)-th element. We have proven that if it works for N then it also works for N + 1, and this is enough to prove that it works for arrays of all sizes. Don't worry if you didn't get the proof by induction: if it's the first time you are seeing this kind of proof, it may take a little while to get it. If you want to learn a bit more about proof by induction, I recommend watching some other proofs; it's very simple, you just need to get into the right mindset. Anyway, let's move forward.
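Restating the induction step just argued in symbols, with m_N and l_N as defined above:

```latex
\begin{aligned}
m_{N+1} &= \max(m_N,\, x_{N+1}) = \max_{1 \le j \le N+1} x_j \\
l_{N+1} &= l_N\, e^{m_N - m_{N+1}} + e^{x_{N+1} - m_{N+1}}
         = \Big(\sum_{j=1}^{N} e^{x_j - m_N}\Big) e^{m_N - m_{N+1}} + e^{x_{N+1} - m_{N+1}}
         = \sum_{j=1}^{N+1} e^{x_j - m_{N+1}}
\end{aligned}
```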
All right, let's talk about block matrix multiplication. I know that you want to jump to the code immediately, and we will get there; we just need a little more theory. Imagine we are doing a matrix multiplication: we have a matrix A, we want to multiply it with a matrix B, and it will produce an output matrix C. Imagine the dimensions of the first matrix are M by K and the second matrix is K by N; it will produce an output matrix that is M by N. Now imagine we want to parallelize the computation of this output matrix. I know that I didn't talk about GPUs yet, so we will not talk about GPUs; we will talk about parallelization in the case of a multi-core CPU, with which you are probably familiar, because nowadays when you buy a computer you have a CPU, and usually you can buy a single-core CPU or a multi-core one, like two cores, four cores, eight cores, etc. Each of these cores is a kind of small CPU inside your CPU that can execute operations in parallel. How do we parallelize the matrix multiplication? Imagine you have this matrix multiplication to parallelize: each output element in this C matrix is the dot product of a row of the A matrix with a column of the B matrix. For
example, this element on the top left is the dot product of the first row of A and the first column of B; this element on the top right of C is the dot product of the first row of A and the last column of B; this element on the bottom left is the dot product of the last row of A and the first column of B; etc. for all the other elements. Now, to fully parallelize this computation we need as many cores as there are elements in C. If M and N are very small then maybe we have enough cores, but imagine M and N are quite big, say 100 by 100: we don't have 10,000 cores in today's CPUs. So how can we parallelize a matrix operation using fewer cores than there are elements in the matrix itself? That's where block matrix multiplication comes in: it means that you can divide the original matrix into smaller blocks of elements, and then the matrix multiplication can be computed between these blocks. For example, imagine
between these blocks for example imagine
we have a matrix that is 8x4 it means
that it has eight
rows and four columns which means that
it has 32 elements and then we are
multiplying it with another Matrix that
is 4X 8 so it has four rows and eight
columns so it also has a 3 two elements
the output Matrix will should have 64
elements we don't have 64 cores so how
can we parallelize it imagine we only
have eight cores now with eight cores we
can divide this original Matrix a into
four blocks where the first block is
this top left block of two by no 4X two
elements so um uh how to say um eight
elements on the top top left and then
eight elements on the top right of this
Matrix then eight elements on the bottom
left and eight elements in the bottom
right of this Matrix these are four
blocks then we divide also the B Matrix
into um eight blocks where each block is
made up of four elements so this b11 is
the top left four elements in the
original Matrix this B4 is the top right
four elements in the original Matrix
this B21 is the
um bottom left for elements in the
origin etc etc etc how do we do this
block matrix multiplication? We can view these matrices as made up only of their blocks: we can view this matrix here as made up only of its blocks, and this matrix here as made up only of its blocks, and the output of this multiplication will be a matrix that is computed in the same way as the original, but where the output of each dot product will not be a single element of the output matrix; it will be a block of elements of the output matrix. For example, the top-left block here is the dot product of the first block-row of this matrix with the first block-column of this matrix, and it will be computed as A11 * B11 + A12 * B21, and this output will not be a single scalar but a block of elements. Let me count how many: how do we find the dimensions of this output block?
Well, we can check what A11 is: A11 is 4 by 2, so it's a smaller matrix made up of eight elements, distributed in four rows and two columns. We are multiplying it by B11, which is a smaller matrix made up of 2 by 2 elements, so four elements. When we multiply a 4 by 2 by a 2 by 2, it will produce a 4 by 2 output block. So if we do this computation block by block, it will produce a block of output elements of the original matrix, not a single scalar but a block of outputs, which makes it very easy to parallelize: if we have only eight cores we can assign each output block to one core, and each core will not produce one output element of the original matrix but eight elements of it, as a 4 by 2 matrix. So basically blocks allow us to do the matrix multiplication either element by element, like in the original matrix (each row with each column), or block by block, in the same way as we do normal matrix multiplication, because the matrix multiplication we are doing between blocks follows the same procedure as matrix multiplication with the original matrix, and it produces not a scalar but a block. Now let's see why this is very important for us.
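A small NumPy sketch of the 8 by 4 times 4 by 8 example just described: C is computed block by block, with each 4 by 2 output block the result of one block-row of A times one block-column of B. The block sizes match the ones in the example; variable names are mine.

```python
import numpy as np

# A: 8x4, B: 4x8, C = A @ B: 8x8, as in the example above.
A = np.arange(32, dtype=np.float64).reshape(8, 4)
B = np.arange(32, dtype=np.float64).reshape(4, 8)

BLOCK_M, BLOCK_K, BLOCK_N = 4, 2, 2   # A blocks: 4x2, B blocks: 2x2, C blocks: 4x2
C = np.zeros((8, 8))

# Each (i, j) iteration produces one independent 4x2 block of C,
# so each output block could be assigned to a different core.
for i in range(0, 8, BLOCK_M):
    for j in range(0, 8, BLOCK_N):
        acc = np.zeros((BLOCK_M, BLOCK_N))
        for k in range(0, 4, BLOCK_K):
            acc += A[i:i + BLOCK_M, k:k + BLOCK_K] @ B[k:k + BLOCK_K, j:j + BLOCK_N]
        C[i:i + BLOCK_M, j:j + BLOCK_N] = acc

assert np.allclose(C, A @ B)  # the block-by-block result matches the ordinary product
```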
So why should we care about block matrix multiplication? Because we are trying to compute the following operation: the query multiplied by the transpose of the keys, then we should apply the softmax to this result, and then we should multiply the output of the softmax with V. For now let's ignore the softmax, let's pretend that we are not going to apply any softmax, so we take the output of the query multiplied by the transpose of the keys and we just multiply it by V to obtain the output of the attention. This is wrong, of course, but it simplifies keeping track of what we are going to do next. So for the moment, we just do the query multiplied by the transpose of the keys and directly multiply the result of this operation with V. This will result in a matrix that is N by d, N tokens each made up of an embedding of d dimensions (lowercase d), and we know that query, key and values are themselves N by d matrices, N tokens made up of embeddings of d dimensions. So imagine we have a query
tokens each token is made up of 128
Dimensions we can divide as we have seen
each when we compute matrix
multiplication we can divide our um
Matrix into blocks how we choose the
blocks is up to us as long as the oper
the the shapes of the blocks match when
doing the matrix multiplication so for
example in the previous case we divided
our Matrix a into blocks such that the
the the shape of the block Matrix so the
Matrix that is made up only of the
blocks is compatible with the block
Matrix B so that this operation is
possible so this is the only requirement
that we need to be aware when doing the
matrix multiplication the shapes of the
blocked Matrix so the Matrix that is
made only of the blocks should match in
the matrix multiplication for the rest
it doesn't matter how we divide it so
imagine that we choose to divide this
query Matrix into blocks of rows and we
can do that we don't have to necessarily
divide also the columns we can just
divide the rows so that each Q is not a
single row but it's a group of two rows
so q1 is a group of the first two rows
of the Q Matrix of the Q sequence Q2 is
the group of the second two rows of the
Q sequence etc etc and we do the same
also for V for K we don't do it because
we are actually going to multiply with K
transposed so we do this subdivision
directly on K transposed so so we have
the Q which has been divided into groups
of rows and then we have a k transposed
which is a matrix that is 108 by 8
because it's the transpose of the Keys
which is 8 by
108 and we decide to divide each of the
column group of columns of K into a
single block so the K1 is the first two
columns of K transposed K2 is the second
group of two columns in K transposed etc
etc until K4 which is the last two
columns in K transposed the first
operation that we do is the
multiplication query multiply by the
transpose of the keys which basically
means that we need to multiply each
query with all the keys then the second
query with the all the keys etc etc now
each query is not a single row of the Q
sequence it's a group of two rows of the
Q sequence and each K is not a single
column of K transposed it's a group of
two columns of K transposed but doesn't
matter because we have seen that the
matrix multiplication if we write the
matrixes as made up of blocks we just
compute it in the same way when we do uh
normal matrix multiplication ation so we
are multiplying this matrix by this
Matrix and for what we know this Matrix
here is made up of four rows with some
Dimensions which is uh 128 dimensions
and this one here is made up of uh how
many rows 128 rows and four uh
columns I didn't uh draw The Columns
because it's too many to draw here but
you need to pretend it's a lot of
Dimensions one for each 128 for each
vector and here you need to pretend that
this is 128 rows when we do the matrix
multiplication we apply the normal M
matrix multiplication procedure which is
each output element so this first of all
the output shape of this Matrix of this
matrix multiplication will be 4x4
because it's the outer dimensions of the
two Matrix that they are
multiplying the first element of the
output will be the dot product of this
Vector here with this Vector here the
second element so this one here will be
done dot product of this Vector here
with this Vector here however this is
not vector and this is not a vector so
it's actually a matrix multiplication in
this case this element here is not a
scalar it is a group of elements of the
output Matrix because we are doing block
matrix multiplication and how many
elements it will be well we know that
the original q1 is a 2 by
128 the K1 is 108x 2 so it will be a
group of 2x two elements of the output
Matrix so we are doing the matrix
multiplication of the q1 with K1 then q1
with K2 then q1 with K3 q1 with K4 etc
etc for the first row and then the
second row will be Q2 with all the K and
the Q3 with all the ks and Q4 with all
the ks so as you can see when we do
matrix multiplication we don't even care
if what is underlying is a block or a
vector or a scalar we just apply the
same procedure for first um row of the
Black Block matrix multiplication with
the First Column of the Matrix M of the
second Matrix uh and then the first row
with the second column the first row
with the third column etc
etc let's then multiply because the
formula says that we need to multiply
query with the transpose of the keys and
then multiply by V all of these are um
block matrices now as you can see from
my using of colors every time time I
refer to the original Matrix I use the
blue color and every time I refer to the
block Matrix I use the pink color so we
need to multiply the output of the query
multiplied by the transpose of the key
then by V because we are skipping for
now the soft Max and later we will see
why so if we want to do this
multiplication, we need to do the following. This matrix is made up of blocks, and block matrix multiplication just ignores that fact and does the matrix multiplication as if it were a normal matrix multiplication: we do the first row with the first column, then the first row with the second column, then the first row with the third column, etc. So how is the first block row of the output matrix of this matrix multiplication going to be calculated? Well, it will be the dot product (not really a dot product, it's actually a matrix multiplication of the first row, but done in a dot-product way, let's say) of the first row with the first column, which is made up of V1, V2, V3 and V4. So it will be this element with V1, plus this element with V2, plus this element with V3, plus this element with V4, and this will be the first output element. The second output block will be this row with this column, which will be this element with V1, plus this element with V2, plus this element with V3, plus this element with V4, and this will produce the second output block, and so on also for the third and the
fourth block of the output. Let's look at what each block is made up of. Each block is made up of the first element, so query 1 multiplied by key 1 (because it's the result of the query multiplied by the keys), with V1 of the second matrix, plus this element with this one, plus this element with this one, plus this element with this one. So the pseudocode for generating the output of this attention mechanism (which is not really the attention mechanism, because we skip the softmax, but I just want you to get into the habit of thinking in terms of blocks) is the following: we take each query block and we go through each query. As you can see, let's look at what this output is actually made up of: it is made up of query one multiplied by key one, with the result multiplied by V1, then query one with K2 and the result multiplied by V2, then query one with K3 and the result multiplied by V3, plus query one with K4 and the result multiplied by V4. What we are doing is basically the dot product of this row with this column, made up of blocks. So the pseudocode for
generating this first row: the query is query number one, and then we iterate through the keys and the values from one to four and we sum iteratively for each block. To generate this output matrix, for each row we will see that it's a different query with all the keys and values: this will be query number three with all the keys and values, and this will be query four with all the keys and values. So to generate this output matrix, we iterate through the queries, each giving one row of the output matrix, and then we do this iterative sum of the query I that we are iterating through, multiplied by the J-th K and V, and we keep summing them iteratively, and that will produce the output matrix that you can see here. I know that what I have done so far may not seem useful for Flash Attention yet, but it's useful for us to get into the mindset of computing this product by blocks, because later we will use it also with the softmax.
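To make the block-wise computation concrete, here is a minimal NumPy sketch (my own illustration, not code from the video) of the pseudocode just described: it computes O = (Q Kᵀ) V block by block, with the softmax deliberately skipped. The block size of 2 rows and the 128 dimensions follow the example above.

```python
import numpy as np

SEQ, DIM, BLOCK = 8, 128, 2          # sequence length, head dimension, block size
Q = np.random.randn(SEQ, DIM)
K = np.random.randn(SEQ, DIM)
V = np.random.randn(SEQ, DIM)

# Split the sequences into blocks of rows: each block has shape (BLOCK, DIM)
Q_blocks = Q.reshape(SEQ // BLOCK, BLOCK, DIM)
K_blocks = K.reshape(SEQ // BLOCK, BLOCK, DIM)
V_blocks = V.reshape(SEQ // BLOCK, BLOCK, DIM)

O = np.zeros_like(Q)
for i, Q_i in enumerate(Q_blocks):           # one block of output rows per query block
    O_i = np.zeros((BLOCK, DIM))
    for j in range(len(K_blocks)):           # iterate over all key/value blocks
        S_ij = Q_i @ K_blocks[j].T           # (BLOCK, BLOCK) block of S = Q K^T
        O_i += S_ij @ V_blocks[j]            # accumulate the block product (softmax skipped)
    O[i * BLOCK:(i + 1) * BLOCK] = O_i

# Block-wise result matches the unblocked product (Q K^T) V
assert np.allclose(O, (Q @ K.T) @ V)
```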
All right guys, I know that what we have computed so far is not really the attention mechanism, because we have skipped the softmax, so somehow we need to restore it. The following 10 or 20 minutes are going to be
really challenging, because I am going to do a lot of operations that involve a lot of different blocks, a lot of different matrix multiplications, and variants of the softmax, so it may be difficult to follow. However, don't give up: you can watch this part two or three times, and every time you will have a better understanding. I also recommend watching it until we reach the flash attention algorithm before going back to rewatch it, because reaching the flash attention algorithm will give you a better understanding of what has happened so far, and then you can rewatch it to deepen your understanding. Another thing that I recommend is to take pen and paper and write down exactly the operations that you are seeing, and the shapes of each of the blocks that take part in these matrix multiplications, so that you better understand what is happening and better remember what I mean when I refer to a particular element or a particular block. Okay, after giving this small
motivational speech, let's start. What we have done so far was the query multiplied by the transpose of the keys, however each query is not a single row of the query sequence but a block of queries, a block of rows. In our particular case this Q1 is not one row of the query sequence, it's two rows of the query sequence, because we have chosen a group of two rows as our block size, and this K transposed 1 is not one column of the K transposed matrix, it's two columns of the K transposed matrix, because we have chosen it like this. If you don't remember, let's go back to see it: here we have chosen K1 as two columns, and Q1 as two rows of the original query matrix, and every time I use the blue color I am referring to the original shape, and every time I'm using the pink or violet, whatever it is, I am referring to the block matrix, so it's a block of elements of the original
matrix. Okay, now, the first thing that we have done was the query multiplied by the transpose of the keys, and this produces a block matrix as output that we will call S, where each element S_ij is a block: the S11 element of this matrix will be query one with K transposed one, S12 will be query one with K transposed two, S13 will be query one with K transposed three, etc. for all the rows and for all the columns. Then we should be applying the softmax, because if you remember, the formula is the softmax of the query multiplied by the transpose of the keys. However, I want to restore the softmax operation but with a twist, which means that we will apply a simplified version of the softmax, and we will call it softmax star, which is just the softmax without the normalization. So let me write for you what it means, and let's do it with the same color that I chose for the softmax, which is orange. The softmax, if we remember correctly, is applied to a vector element-wise, so each element is modified according to the following formula:
the i-th element of the output vector to which we are applying the softmax is equal to the exponential of the i-th element of the input vector minus the maximum element in the input vector, divided by a normalization factor that is calculated according to a summation that goes from j = 1 up to N of the exponential of x_j minus x_max. So
basically, we are doing the exponential of each element minus this x_max. And why are we subtracting this x_max, if you remember correctly? To make this exponential numerically stable and computable, because otherwise it would explode; and because we are applying it to the numerator, we also need to apply it to the denominator. Okay, the softmax star operation is exactly like the softmax but without the normalization part, which means that it's just the numerator of the softmax, so we will modify each element of the vector to which we apply the softmax star
according to this formula. Let me move it so it's more aligned, like this. We just do an element-wise operation, which is the exponential of each element minus the maximum of the vector to which we are applying softmax star.
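Written out explicitly (reconstructing the formulas described above), the safe softmax and the unnormalized softmax star are:

$$
\mathrm{softmax}(x)_i = \frac{e^{x_i - x_{\max}}}{\sum_{j=1}^{N} e^{x_j - x_{\max}}},
\qquad
\mathrm{softmax}^*(x)_i = e^{x_i - x_{\max}},
\qquad
x_{\max} = \max_j x_j
$$

so the real softmax is just the softmax star divided by the normalization factor, the sum in the denominator.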
Okay, now why did I introduce this softmax star operation? Because we will be applying it to the matrix that we have computed so far, which is this S matrix,
so we apply softmax star to each element of this S matrix. But each element of this S matrix is itself a matrix, because it's a block matrix: each element of this S matrix, for example the element S11, is a 2x2 matrix, because it is coming from the product of two matrices, which are a group of rows and a group of columns from Q and K. So for example, what is this S11? Let's draw it: this S11 will be made up of four elements, let's call them a, b, c and d, just generic elements. When
we apply the softmax star to this S11, it will result in a matrix where each element is the exponential of that element minus the maximum of its row. Now, we don't know which is the maximum, so let's choose one: suppose the maximum for the top row is a and the maximum for the bottom row is d. The first element of the output of this softmax star applied to this block S11 will be the exponential of a minus a, because that's what we chose as the maximum for this row; the second element will be the exponential of b minus a, because a is the maximum for that row. Then in the bottom row it will be the exponential of c minus d, because that's the maximum for the bottom row, and this will be the exponential of d minus d. That's how the softmax star will modify each block in this block matrix.
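As a quick sanity check, here is a tiny NumPy sketch (values chosen arbitrarily) of this row-wise softmax star applied to one block:

```python
import numpy as np

def softmax_star(block):
    # exponential of each element minus its row maximum, with no normalization
    row_max = block.max(axis=-1, keepdims=True)
    return np.exp(block - row_max)

S11 = np.array([[1.0, 0.5],    # a, b (a is the maximum of the top row)
                [0.2, 2.0]])   # c, d (d is the maximum of the bottom row)
print(softmax_star(S11))       # [[exp(0), exp(-0.5)], [exp(-1.8), exp(0)]]
```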
Let me delete this stuff, otherwise it will remain in my slides forever, and later I want to share the slides with you guys so you can use the same slides. So delete, delete, delete. Okay, after we have applied the softmax star to
each of the elements in this S matrix, we will call the result the P matrix, and each element P11 will again be a block of 2x2 elements. So P11 will be the softmax star applied to S11, where S11 is query one times K transposed one, and P12 will be the softmax star applied to S12, where S12 is query one multiplied by K transposed two, etc. for all the elements of S. Okay, now that we have applied this
softmax star operation, the next operation that we should be doing according to the formula of the attention is the softmax of the query multiplied by the transpose of the keys, and then the result of the softmax multiplied by V. I know that we didn't apply the real softmax, we applied softmax star, which is the softmax without the normalization; later we will see how to compensate for this lack of normalization, because we will do it at the end, and it's something that we can do. Okay, so we take this P matrix, which is the result of the softmax star applied to this S matrix, and we multiply it by V. How do we do it? Well, it's a matrix made up of blocks of matrices, so P11 is actually not a scalar but a matrix of 2x2 elements, and we need to multiply it by V. But we don't multiply with the original sequence V, we multiply with the blocked sequence V, just like before, where each V is not one row of V but a group of rows of V; and how many rows is it? It is two rows of V. For now, please ignore completely whatever I have written here, because we will use it later. So we need to do the product of this matrix here, which is made up of blocks (remember), with this matrix here, which is made up of blocks: it is made up of four rows, where each row is not really a row, it is a block of rows, and this one is made up of 4x4 elements, where each element is not really a scalar but a
matrix. As you remember, in block matrix multiplication the algorithm for computing the matrix multiplication is the same as the normal matrix multiplication, except that we use blocks. So what I am doing, guys, is the following operation; let's write it somewhere: O is equal to P multiplied by V. Okay, so the first output row (because it's not really a row, it's a block row) will be computed as follows: the first row of this block matrix with the first column of this V matrix, treating it like a block matrix, so it will be P11 multiplied by V1, plus P12 multiplied by V2, plus P13 multiplied by V3, plus P14 multiplied by V4. This will produce the first output row of O, but it's not really a row, because it's made up of two rows. So this stuff here is not one row, it is two rows, and we can prove that, because what
is P11? Let's write it somewhere. P11 is a 2 by 2 matrix, and we are multiplying it with V1, which is a block of two rows of V, so it is two rows by 128 dimensions. The result is therefore 2 by 128, so this stuff here is 2 by 128, which means this block here, the output block that we are computing, is a block of two rows of the output matrix that we are computing. I know this is really difficult to follow, because we are involving blocks, so we need to visualize the matrices at the same time as blocks and as the original matrices. That's why I highly recommend you pause the video, think it through, and write down whatever you need to write down, because it's not easy to follow it just by memorizing the shapes; you actually need to write things down. Anyway, we are computing the first output
block of the output O matrix. Now, if you remember, this output here should be the output of the softmax multiplied by V, but this softmax has not been applied to the entire row of this S matrix: to compute this softmax star, what we did was to compute the softmax star on each block independently from the other blocks, which means that the maximum that we are using to compute each softmax star is not the global maximum of the row of this S matrix but the local maximum of each block. And this is actually wrong, because when we apply the softmax we should be using the maximum of the whole row. I want to give you an example without using blocks, because otherwise I think it's not easy to follow. When we do the normal attention, we have the query multiplied by the transpose of the keys, and this produces a matrix that is N by N, so sequence by sequence, with, let's say, six rows and six columns. Okay, this one here should be the
dot product of the first query with the first key; let me write it as query one transposed times key one. This is because, as I said before, when we do the product of two vectors we always treat them as column vectors, so when you want to write the dot product you cannot multiply two column vectors, you need to multiply one row vector with one column vector; that's why we transpose this one. If it confuses you, you can also write q1 K1, that's totally fine, it's just wrong from a notation point of view. Anyway, the first one will be the dot product of query one with K1, the second element will be the dot product of query one with K2, the third will be query one with K3, etc. So this is q1 with K1, q1 with K2, q1 with K3, q1 with K4, and so on.
Anyway, when we do the softmax we actually calculate the maximum over this entire row. However, what we are doing is a block matrix multiplication, and as you remember, when we work by blocks we are grouping together rows of queries and rows of keys; in this particular case we are grouping two queries together to create one group of queries, and two keys together to create one block of keys. So we need another row of this one: it should be query 2 K1, query 2 K2, query 2 K3, query 2 K4, query 2 K5 and query 2 K6. Each of these blocks here is computing 2 by 2 elements of the original matrix that we would have if we had never applied the blocks, so it is computing these four elements here. And if we apply the softmax star
to each of these blocks, we are not using the maximum element in the row, we are only using the maximum element in each block, which means that when we use it in the downstream product with the V matrix, we will be summing values that are wrong, because each of these values here will be based on a maximum that is not the global maximum for this row, it is the local maximum of this block here; and this block here will use the local maximum of this block here, and this block here will use the local maximum of this block here, etc. So what I'm trying to say is that when you sum the P11 V1 term with the others, P11 may have a local maximum that is different from the local maximum of P12, and P13 may have a local maximum different from that of P11 and P12. So we need to find a way to fix the maximum that was used to compute the exponential here with the maximum found here, in case the maximum here is higher than the one local to P11: if we have found, for example, here a maximum that is higher than the maximum used here, then we need to fix this one and this one, because the maximum in the softmax should be the maximum for the whole row, not the one belonging to each block. And this leads to our next step: how to fix this. First of all, let me introduce a
little pseudocode for computing this output matrix here, which is an output block matrix, and later we will use this pseudocode to adjust the error that we have made in some blocks in case a future block, say P13, has a better maximum than P11 or P12. So, to compute this output matrix O, for example to compute the first row: well, what is P11? Let's go back (let me delete this one, it's not needed anymore). P11 is the softmax star of q1 K1, P12 is the softmax star of q1 K2, P13 is the softmax star of q1 K3, P14 is the softmax star of q1 K4, which means that to compute
this block here, we first need to compute P11. What is P11? Well, P11 is the softmax star of a block of Q and another block of K, which in the case of the first row of the output matrix means the softmax star of query 1 with K1, the softmax star of query 1 with K2, the softmax star of query 1 with K3, the softmax star of query 1 with K4; which means that we need to make a for loop through all the keys while keeping the query fixed. So to compute the first output row: to produce P11 we need to do the softmax star of query 1 with K1, and we sum it initially to zeros, because we need to initialize our output somehow and we initialize it with zeros; then we sum the next one, P12, which is query one with K2, and then we sum the next one, P13, which is query one with K3, etc. That's why we have this inner loop here.
All right, so however, this output that we are computing is wrong, because as I told you, we have computed the softmax star using a statistic, the maximum value, that belongs to each block, and not the one of the overall row of the original matrix. How to fix that?
We have a tool: we have actually derived before an algorithm called the online softmax (I don't know if I referred to it before as the online softmax, but that's what it's called) that allows us to fix previous iterations while we are computing the current iteration. How? Well, let's review the online softmax. Imagine we are working with one single vector, a vector made up of N elements. What we do is a for loop where we iteratively compute the maximum up to the i-th element, and we fix the normalization factor computed in the previous iteration in case we found a better maximum at the current element. If this is not clear, guys, go back and watch the online softmax part, because this is very important: this is what we are going to use to fix the P11 and P12 blocks in case we find a better maximum in P13 or P14, etc.
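To recap, here is a minimal Python sketch of the online softmax just described (an illustration of the idea, not the exact pseudocode from the video): a single pass over a 1D vector that keeps a running maximum and fixes the running normalization factor whenever a better maximum is found.

```python
import math

def online_softmax(x):
    m = float("-inf")   # running maximum seen so far
    l = 0.0             # running normalization factor
    for xi in x:
        m_new = max(m, xi)
        # rescale the old normalizer to the new maximum, then add the new term
        l = l * math.exp(m - m_new) + math.exp(xi - m_new)
        m = m_new
    return [math.exp(xi - m) / l for xi in x]

print(online_softmax([1.0, 3.0, 2.0]))  # matches the standard safe softmax
```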
So let's see how to apply this online softmax to our case here. You may be wondering why we are going through all this trouble. The real reason is the following: first of all, why did we introduce block matrix multiplication? Because we want to
compute matrix multiplication in parallel. So you can think that each of these P blocks, because they are independent from each other and because each of them uses the maximum belonging to its own block, can be computed independently from the others. However, we then need to somehow aggregate their values, and to aggregate their values we need to fix the values that have been calculated independently, because when computing values independently we don't have a global view, we have a local view. So we compute local blocks P11, P12, P13, etc., and then when we aggregate these values we need to fix them. That's why we are trying to come up with this system of fixing values that have been calculated independently. So how to fix this? Let's look at the following algorithm. First of all, this O block here, as I said before, is a block of two rows, where each row is made up of 128 dimensions, and we have seen that before
by checking the dimensions of P11 and V1 and of the result of P11 V1. This means that for each output block we need to take care of two maximums and two normalization factors. So far I didn't use the normalization factor (we said that we are applying softmax star, which is the softmax without the normalization), but eventually we will need to compute this normalization. So we want to create an algorithm that fixes the maximum used to compute each of these P blocks and also simultaneously computes the normalization factor, and at the end we will apply this normalization factor. The way we will do it is as follows. We start by initializing the maximum to minus infinity, one for each row that we are computing: our output block is made up of two rows, so we need one maximum for the top row and one maximum for the bottom row. We also initialize the normalization factor to zero, because we didn't sum anything yet, and we initialize the output with all zeros, because we didn't sum anything to this output yet. To compute this output block here we need to go through all the keys to produce P11, P12, P13, P14, while the query is query block number one. So the first step is: we compute the maximum of the first block, which is the row max, so the maximum for each row of the block q1 K1 (this is not P11, it's S11, sorry guys), and we call it M1, as you can see here. Then we can calculate P11, which is the softmax star, that is the exponential of q1 K1, so S11, minus the maximum of the local group, M1, and we add it to our output. For now the output is initialized with zero, so for now ignore this part here, I will explain it later: for now O1 should be equal only to P11 V1.
Now, at step number two, in the local group S12 (so this one is S12) we may find a better maximum for the top row and the bottom row, and this maximum is M2, which may be better than the previous maximum for each of these two rows, but may also not be. So we need to find a way to fix things in case it's better, and to not fix anything in case it's not better. And the way we do it is this: we compute the new maximum of the current local block, we calculate P12, which is the softmax star of S12, that is S12 minus M2, the local maximum, and then we need to add it to the output. However, in this case we may have found a better maximum, so how do we fix O1, which only used the maximum that was local to S11? Well, we know that we can fix that by using exponentials, because each element of O1 is just an exponential without the normalization, since we are applying softmax star. So how do we fix an exponential computed with one maximum using another exponential?
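The key property of exponentials behind this fix is simply:

$$
e^{x - m_{\text{old}}} \cdot e^{m_{\text{old}} - m_{\text{new}}} = e^{x - m_{\text{new}}}
$$

so multiplying an already-accumulated softmax-star term by the factor exp(m_old - m_new) replaces the old maximum with the new one.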
Basically, we are saying that we multiply O1, which is a matrix. So let me show you what this matrix is: O1
is a matrix made up of two rows. As you can see here, the shape of O1 is 2 by 128: the top row is O[1,1], O[1,2], and so on up to O[1,128], and then O[2,1], O[2,2], up to O[2,128]. We need to fix these values. How? Basically, just by using the exponential that we have used in the online softmax that we have seen before. We multiply this matrix here by a diagonal matrix that is made as follows: it's a diagonal matrix made up of two elements, because the exponential of M1 minus M2 is a vector of two elements, and this diag notation means a diagonal matrix where on the diagonal we have the elements of the vector to which we are applying the diag operation. So this value here will be the exponential of the first element of M1 minus M2; let me show you how to write it: here is the exponential of M1 minus M2, first element; here is a zero; here is a zero; and here we write the exponential of M1 minus M2, but the second element of this vector. So basically this diag notation means: take the vector and distribute it over an n-by-n matrix, where n is the size of the vector to which it is applied, and all the other elements of this matrix are zeros. That's what it means. If we do this operation here, we will see that the output of this multiplication fixes each element of the top row using this exponential and the bottom row using this exponential, which will basically cancel out the M1 that was computed in the previous iteration and introduce the M2 that we have computed in the current iteration into each of these O elements in this O block matrix.
Okay, so in this output, this element will be multiplied by this one, so it will fix O[1,1] with this factor here, and O[2,1] will be multiplied by zero, so it will not contribute to this first output element. This element here will only depend on O[1,1], fixed by the exponential of M1 minus M2 (the first element of this vector); then O[1,2] will be fixed by this exponential here but not by this one; and all the dimensions of the first row will be fixed by this exponential, while all the dimensions of the second row here will be fixed by this exponential here, this scalar here, which is the second element of the vector exp of M1 minus M2. Okay, this one was really challenging. So what we are doing is: we compute P12, and we fix all the elements in O1 by multiplying by this matrix factor here, and when we compute step three we will fix step two, etc. Now let's talk about the
normalization factor, because so far we have been ignoring it. The normalization factor is something that we can compute while computing these maximums, because it is provided in the pseudocode of the online algorithm that we have seen before for the softmax. So while computing the maximum, we can actually compute the normalization factor by fixing the normalization factor of the previous iteration, and this is exactly what we are doing here. At the first iteration we compute the normalization factor using the local maximum; for now you can ignore this term, because we are not fixing L0 with anything: L0 will be zero, so this factor here will be zero, and we are just computing this summation here. When computing L2, the normalization step at the second iteration, we will fix L1 with an exponential which (guess what) is exactly the same exponential that fixes the maximum of P11: the exponential of the previous estimate of the maximum minus the current estimate of the maximum, plus the new normalization factor computed using the local maximum. We keep doing this, and at the end we will obtain a correct output for this block here, but without the normalization. How to apply the normalization? Well, we need to divide each element of this O by the normalization factor, and because while iterating through these for loops we also calculate the normalization factor, we keep accumulating it until we reach the end of the iteration and then we apply it: we take the last output and we just divide it by L4, which is the normalization factor calculated at the fourth iteration, and that will fix the softmax.
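Putting the pieces of the derivation together, here is a minimal NumPy sketch of this procedure for a single block of query rows (the variable names m, l and O_i mirror the derivation above; this is an illustration, not the actual flash attention kernel):

```python
import numpy as np

def attention_one_query_block(Q_i, K_blocks, V_blocks):
    """Online, block-wise softmax(Q_i K^T) V for one block of query rows."""
    BLOCK_Q, DIM = Q_i.shape
    m = np.full(BLOCK_Q, -np.inf)          # running row maxima
    l = np.zeros(BLOCK_Q)                  # running row normalization factors
    O_i = np.zeros((BLOCK_Q, DIM))         # running (unnormalized) output block
    for K_j, V_j in zip(K_blocks, V_blocks):
        S_ij = Q_i @ K_j.T                 # local scores for this K block
        m_new = np.maximum(m, S_ij.max(axis=1))
        correction = np.exp(m - m_new)     # exp(m_old - m_new), fixes previous steps
        P_ij = np.exp(S_ij - m_new[:, None])   # softmax star with the new maximum
        l = l * correction + P_ij.sum(axis=1)
        O_i = O_i * correction[:, None] + P_ij @ V_j
        m = m_new
    return O_i / l[:, None]                # apply the normalization only at the end
```

Its output matches the unblocked softmax(Q_i Kᵀ) V computed directly, which is exactly the point of the correction factors.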
All right guys, so now that we have derived the algorithm for computing this output of the attention block-wise, while also fixing the softmax that is done independently in each single block, we know that the normalization is done at the end, and I want to also prove it. When we introduced the algorithm that computes the softmax in an online way, we proved by induction that this algorithm is
correct, so at the end of this algorithm the L of the last iteration will actually be the normalization factor that we can apply to get the softmax. So we don't apply the normalization while computing this output in an online, iterative way by multiplying the query with all the blocks of keys; we apply it at the end of these four iterations, and at the end of these four iterations we will have the last output. We also know that the last L will contain the exact normalization factor that we need to apply to each row, because this O4 is a block of output rows, and if you remember from the attention mechanism, the output of the attention has the same shape as the input query, which is a sequence of tokens; so this O is a sequence of tokens that we need to apply the normalization to, and we know that the correct factor is L4. So let's prove this simple formula. L4 is a vector that contains as many elements as there are rows in O4, so in this block of output rows; suppose that it contains two rows, like in the algorithm that I have described so far, in which we pretend that we are grouping two rows of queries with two columns of keys together. So the output block O will contain two rows of the output, and we will have two normalization factors in this L4 vector. What we are doing with this formula is taking this L4 vector, creating a diagonal matrix with it, and then computing the inverse of this diagonal matrix. So L4 is a vector that contains two normalization factors, let's call them L4 element 1 and L4 element 2. This is our
L4 vector. Then we have O4: O4 is a matrix that, as you can see from the shape, is 2 by 128, so it is a matrix with two rows of 128 elements, the first row with 128 dimensions and the second row with 128 dimensions. The first thing that we are doing with this L4 is converting it into a diagonal matrix, which will be a 2 by 2 diagonal matrix because it contains two elements, so it will become something like this: L4 element 1, zero, and then zero, L4 element 2. Then we are computing the inverse of this matrix. The inverse of a diagonal matrix is just the diagonal matrix where each element on the diagonal becomes its reciprocal (this is from linear algebra, I'm not making this up), so the inverse of this matrix here is equal to the same diagonal matrix but where the diagonal holds one over L4 element 1 and one over L4 element 2. Then we are multiplying this stuff here (let me delete some stuff) by O, which is a matrix that is 2 by 128. So we are doing this multiplication now: 2 by 2 times 2 by 128 will be a matrix that is 2 by 128,
where the first dimension of the first row of the output of this operation will be the dot product of this row here with the first column, so basically we are dividing this element here by L4 element 1. The second output element here will be the dot product of this row with the second column, so we are dividing the second element of this input vector by L4 element 1, because all the elements of the second row are multiplied by zero and do not contribute to this output row. For the second output row, this element here will be the dot product of this row with the first column: the first element here is multiplied by zero, so it does not contribute to this output, so only the first element of the second row of the input matrix will be divided by L4 element 2. Basically, this factor will divide all the elements in the second row and this one will divide all the elements in the first row, producing this result here, which is exactly what we need to do when we want to normalize: we need to apply this normalization factor. This should help you better visualize why this operation normalizes the vectors of the output at the end while still obtaining the same result. Now let's proceed further.
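A tiny NumPy check of this argument (the shapes match the 2-by-128 example above; the values are arbitrary):

```python
import numpy as np

O4 = np.random.randn(2, 128)
l4 = np.random.rand(2) + 0.5           # two positive normalization factors

# Dividing each row by its own normalization factor...
normalized = O4 / l4[:, None]

# ...is the same as left-multiplying by the inverse of diag(l4)
assert np.allclose(normalized, np.linalg.inv(np.diag(l4)) @ O4)
```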
All right guys, finally we are ready to see the flash attention forward pass, also comparing it with what we have derived so far. If you look at the flash attention paper: first of all, this is the Flash Attention 2 forward pass, and later I will explain the differences between Flash Attention 1 and Flash Attention 2. I
didn't want to jump directly to this forward pass, because I believe that even if the derivation was a little difficult to follow, it gave you some intuition into what is happening. So even if you understand 50% of it, that's enough, because later we will also code it and you should reach like 90% of understanding; every time we introduce some new information, it should improve your understanding. Basically, in flash attention (Flash Attention 2 especially), as input we have our query, key and values, which are sequences of tokens, each token made up of a vector of lowercase d dimensions. We divide this query, guess what, into blocks. How many blocks? Well, depending on this parameter Br, which is the size of the query block that we want to choose, so how many rows of query we want to group together into one block; and we also do it with K and V, which we divide into blocks depending on this parameter Bc. Then we also initialize the output, which is the output that we want to produce. So what is flash attention computing? Well, flash attention is computing the following: the softmax of the query multiplied by the transpose of the keys, divided by some normalization factor, then multiplied by V. That's what it's going to compute, and it's going to compute it this way: first of all, there is
an outer loop through the queries, which corresponds to the same pseudocode that we have seen before, because we want to compute each block of the output matrix in parallel with respect to the others. Basically, we want to compute this output block and this output block independently: this output block here depends on query one and all the keys, this output block here depends on query two and all the keys, this output block here depends on query three and all the keys, where query 1 is not the first query but the first group, or block, of queries, query two is not the second row of the query matrix but the second block of the query matrix, etc. That's why we have this outer iteration among all the blocks: because we want to compute all those blocks of the output matrix in parallel. But to compute each of these output blocks we need an iteration among all the keys (that's why we have an inner loop on the keys), and we do exactly the same operations that we have done so far by hand. First we compute the S matrix, which is each block of query with the corresponding block of keys; then we compute the maximum local to the current S block (this is the local maximum) and we compare it with the maximum of the previous iteration, because that's what we do in the online softmax; then we
compute the P block, which is the softmax star of the S block minus the local maximum of the S block. Then we compute the normalization factor. What is the normalization factor? It is the summation of all the exponentials of the softmax star, but fixing the normalization factor of the previous step, and we know how to fix the normalization factor: we just multiply by an exponential of the previous maximum minus the current maximum, that's what this factor is. Then we compute the output exactly using the same correction factor that we have seen before, which is the diagonal matrix where on the diagonal you have the elements of this vector here (the exponential of the previous maximum minus the current maximum) multiplied by the output of the previous step, because we want to fix the previous step (it was based on the previous P, which used its own local maximum), plus the current P V, which is based on the current local maximum and will be fixed by the next iteration. And at the end, after we have gone through all the Ks, we have computed all the output blocks, but we didn't apply the normalization factor; it's applied at the end, because while going through each key we are calculating the L normalization factor for the softmax. Inside of this for loop we are just computing the softmax star, so we are not normalizing each value; at the end someone has to normalize it, and it will be this instruction here, which uses the normalization factor that we have computed over all the iterations and applies it to each element of O, because the difference between the softmax star and the actual softmax is just the division by the normalization factor, and this instruction here is actually dividing each O by the corresponding normalization factor, one for each row of the output block that we are computing. Later we will also see what this SRAM is and what the HBM is; for now I just want you to concentrate on the operations that we are doing (they are exactly the same operations that we have done so far), and later we will also see why we need to save this stuff here, etc. But for now you should have enough knowledge to be able to follow what is written in the flash attention paper with respect to the forward pass algorithm. What we are doing is basically just block matrix multiplication, and while computing each block we fix the previous block by using tricks of the exponential.
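As a plain-Python reference of the forward pass described above (a readable sketch, not the Triton kernel and not the paper's exact pseudocode; Br and Bc are the query and key/value block sizes, and the 1/sqrt(d) scaling is omitted here, as in the derivation):

```python
import numpy as np

def flash_attention_forward_reference(Q, K, V, Br=2, Bc=2):
    N, d = Q.shape
    O = np.zeros((N, d))
    for qs in range(0, N, Br):                  # outer loop: query blocks (parallel on a GPU)
        Q_i = Q[qs:qs + Br]
        br = Q_i.shape[0]
        m = np.full(br, -np.inf)                # running row maxima
        l = np.zeros(br)                        # running normalization factors
        O_i = np.zeros((br, d))
        for ks in range(0, N, Bc):              # inner loop: key/value blocks
            S_ij = Q_i @ K[ks:ks + Bc].T
            m_new = np.maximum(m, S_ij.max(axis=1))
            P_ij = np.exp(S_ij - m_new[:, None])    # softmax star with the current maximum
            corr = np.exp(m - m_new)                # fixes previously accumulated terms
            l = corr * l + P_ij.sum(axis=1)
            O_i = corr[:, None] * O_i + P_ij @ V[ks:ks + Bc]
            m = m_new
        O[qs:qs + Br] = O_i / l[:, None]        # normalize only at the very end
    return O

# Sanity check against naive attention
Q = np.random.randn(8, 16); K = np.random.randn(8, 16); V = np.random.randn(8, 16)
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_attention_forward_reference(Q, K, V), ref)
```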
All right, now that we have seen the forward pass of flash attention, before we can implement it we still lack a little bit of knowledge, because we don't know anything about GPUs, we don't know anything about CUDA, and we don't know anything about Triton. That's what we are going to see next. All right guys, it's time for us to
finally explore the GPU and the CUDA programming model. Well, let's start by comparing the CPU and the GPU, and this will let us understand how CUDA works. First of all, what is CUDA and what is the GPU? The GPU is the hardware unit that we buy, and CUDA is a software stack made by Nvidia to write software for the GPUs that they sell; AMD has its own software stack, and other manufacturers have their own. In this particular video we will be seeing examples of CUDA kernels, but the knowledge that you will get can also apply to other GPUs. Now, the first difference between a CPU and a GPU is their purpose. Your computer is right now running on a CPU, and your operating system is interfacing with the CPU using the so-called scheduler: right now you are probably running a browser and also some other software on your computer, and the scheduler is
tasked with switching between them very fast on your CPU, in such a way that it looks to you like the processes are running concurrently. This is actually a fake kind of parallelism, unless your CPU also has multiple cores, which nowadays CPUs do have. So a CPU usually has one or multiple cores, but not so many of them: usually you have a dual-core, quad-core or eight-core CPU, and each of these cores can execute instructions in parallel. The main purpose of the CPU is to execute many different tasks and to switch between them very fast: maybe you have a browser that is running a small game, then you have a movie player, then you have a word processor, and then maybe you have some utility to download files, etc. Most of these programs are actually not compute-intensive, they are I/O-bound, meaning that most of the time they are either waiting for the network or waiting for the disk, and they are very different from each other in their purpose: a browser is completely different from a movie player, and it's completely different from a word processor. So the job of the CPU is to reduce the latency of processing all these operations, and it's highly optimized at the level of each of these execution units called cores, which means that each core
has a part that is tasked with understanding, first of all, what the next instruction to run is, or with predicting what the next operation may be based on the conditions you are running: for example, if you have an if condition, the branch predictor can understand what the most likely next instruction is and can do some optimizations. Also, the CPU has a lot of caches to reduce the latency of loading data from all the devices it can interface with: it can interface with the RAM for sure, but it can also interface with the disk and with some peripherals like the printer, the mouse, the keyboard, etc. On the other hand, the GPU is not tasked with doing many different things at the same time; it's tasked with doing one thing, or a few things, but on a massive amount of data. The operations that we do on the GPU require a lot of computation, and for this reason most of the physical area of the GPU is dedicated to compute units, the green stuff that you can see here, and these are called cores. You can see that the part dedicated to the control area, the part that is tasked with understanding what the next instruction to run is or with doing some optimization on the program, is very little. You may be thinking: well, does this make the GPU slower compared to the CPU? Well, not really, because we have many more cores that can compensate for the higher latencies. Okay, I could give you a lot of knowledge about the GPU from a theoretical point of view, but I think the best way to understand the CUDA programming model is just to jump into the code, so we don't get bored. Okay,
imagine we have a very simple task: we have two vectors A and B and we want to calculate the sum of these two vectors and save the result into another vector C, where each item is the element-wise sum of the corresponding items of A and B. How would you proceed with this task on the CPU? Well, you would do a for loop: for example, you would make a for loop that starts from the first index, so index zero, and C[0] is equal to A[0] plus B[0], then C[1] is equal to A[1] plus B[1], etc., and you do a for loop over all the elements of this vector. On the GPU we want to do the same operation, but in parallel, because we have a lot of compute units called cores and we want all of them to work in parallel. So the first thing that we need to understand is how to divide the work that we are going to do into subunits of work and dedicate each core to one subunit. One simple subdivision would be: the first core should do this summation, the second core should do this summation, the third core should do this summation, etc. So imagine we have an eight-element vector: we need eight cores to do this element-wise summation. We will call the cores threads, because it should also remind you of the multi-threading that we already use in operating systems, where multiple threads work concurrently on the same or similar jobs. Let's look at the code now. The code that I am going to show you is a CUDA kernel and it's written in C, but you don't have to understand C and you don't have to understand this code; what I want you to understand is the intuition behind it, because later we will take this knowledge and convert it into Triton, which is Python, and you should already be familiar with Python. So let's go to the code, and I have a very simple vector addition; we
can see it here. Okay, first of all, how do we do a vector summation? Usually the GPU is interfaced with a CPU, and the CPU first of all has to tell the GPU what data it is going to work with: the CPU needs to have these vectors and it needs to transfer them to the GPU; then the GPU does the vector summation; then the CPU has to copy the output back from the GPU to the CPU and make it available to the program. This is what we are going to do here. We are going to allocate three vectors of size N, one called A, one called B, and one that is the output vector. We initialize their items randomly, so A[i] is a random number between zero and 100 (excluded). Then we allocate memory on the GPU to hold these vectors and we copy them to the GPU: we copy the A vector to the GPU and the B vector to the GPU. Of course we don't copy the output vector, because that's what we want the GPU to populate; we just allocate it on the GPU, and we don't copy it because it's made of random values.
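In PyTorch terms (just an illustrative analogy for this host-side workflow, not the C code shown in the video), the allocate / copy to GPU / compute / copy back sequence looks like this:

```python
import torch

n = 1_000
a = torch.randint(0, 100, (n,))        # allocate and initialize on the host (CPU)
b = torch.randint(0, 100, (n,))
a_gpu, b_gpu = a.to("cuda"), b.to("cuda")   # copy the inputs to GPU memory
c_gpu = a_gpu + b_gpu                       # the "kernel": the addition runs on the GPU
c = c_gpu.to("cpu")                         # copy the result back to the host
```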
Then what we do is launch the kernel. Launching the kernel means that we launch a program that the GPU should execute in parallel on multiple threads, or multiple cores; each of these threads should do a unit of work that is independent from the others (actually, they can depend on each other, but we will not be talking about synchronization). So we launch this kernel, and what we are saying in this line is: launch one block of threads (later we will see what blocks are, you can ignore this for now), and launch N threads, so N parallel operations, with the following arguments: the output where we want to save the data, the input array A, the input B, and the number of elements. Let's see what happens inside of this method. This method follows a particular syntax that is CUDA-specific: this __global__ keyword belongs to what is like a superset of the C language, with some additional keywords that belong to CUDA, so it's not really C, it's CUDA C. It's a very simple method, as
you can see, and the first thing to understand is that CUDA cannot know what each thread should do; we have to tell each thread what to do, so the mapping between the data and what each thread should do is up to us as software engineers. What CUDA does, when we ask it to launch N threads in parallel, is allocate N threads and assign a unique identifier to each of these threads. In our simple case we can see it like this: imagine we have a vector of eight elements; it will assign the first thread index zero (here I wrote one, but that's wrong, so let me write another number here): this will actually be thread zero, this will be thread one, this will be thread two, thread three, thread four, thread five, thread six and thread 7 (let me delete this one so we don't get confused). And what we are saying is that the item each thread should process is equal to its thread index: this is thread zero, so it should process the item with index zero; this is thread one, and it should process the item with index one; this is thread number two, and it should process the item with index two. This is what we are doing in this line of code: we are saying which item each thread should process, which is exactly the thread identifier, the thread ID (later we will see why we have this .x, but that's for later). The next thing that you should see is: the output at the i-th position is equal to the A vector at the i-th position plus the B vector at the i-th position, so it's a very simple element-wise summation. You may have noticed this if statement: why do we
need an if statement if we already know that we are going to launch eight threads? We already know that we are going to launch N threads, so I should of course be less than N, because each thread ID will be between zero and N minus one; so why do we need this if condition? This is needed because when CUDA launches a number of threads, this number is always a multiple of a unit, which is 32 in the case of CUDA. So if we have, say, 34 elements in a vector and we ask CUDA to launch 34 threads, CUDA will not launch exactly 34, it will launch 64 threads, a multiple of 32 (which is the warp size, by the way). What we need to do is ask only the threads that have a corresponding element to work, and ask all the others that don't have a corresponding element (because the vector is not large enough for all of them) to not do anything, so to not enter this if statement.
There is another thing that we should learn about the threads: in the CUDA programming model (but I believe also in other GPUs), a group of 32 threads is called a warp, and these 32 threads will share the same control unit. So
let's go back to the slide. As you can see here, we have this yellow unit here in the GPU, and a group of threads will share the same control unit. What is this control unit? It's the part of the hardware of the GPU that is tasked with understanding what the next instruction to run is. Now, if a group of threads shares the same control unit, it means that this group of threads will always execute the same statement at any time; they will always work in synchrony, always on the same instruction. It cannot be that this thread is working on one instruction and this one is working on another instruction. What does this mean at the programming level? It means that when we launch a group of threads (and of course CUDA will spawn more threads than we need if the number of elements of our vector is not a multiple of 32), these threads will first execute this operation, and each of them will have its own value of this thread ID, so they will execute the same instruction but the data at each instruction may be different, because each of them has its own registers. For example, they will reach this statement here, and the first thread will have I equal to zero, the second thread will have I equal to 1, etc., even if they are executing the same instruction. This programming model is called single instruction multiple data; CUDA likes to call it single instruction multiple threads. It doesn't matter for us: it just means that they will always execute the same instruction, but the values of
the variables may be different. Then, after executing this statement, they will reach this statement here, the if statement, and of course some of them will evaluate this condition to true and some of them will evaluate it to false, which also means that some of them should enter this if statement and some of them should not. However, because the control unit is the same for all of them, they will be forced to step through this if statement even if they should not. How does CUDA manage this control divergence? It basically works like this: all the threads for which this if condition is true will enter the if and execute the instructions inside of it, and all the threads that have this condition equal to false will also step through it (because they cannot avoid it, since they must always be executing the same instruction at any time), but they will just not do any operations inside of it, they will just sit idle. This is called control divergence, and it can reduce the throughput of your program, so you want to minimize it. You may be wondering: why doesn't the GPU dedicate a control unit to each core, so that they can work independently from each other? Because the control unit is expensive to add in the chip area of the GPU; it's much more efficient to add more workers instead of adding control units for each worker. This is a design choice of the GPU, and it works fine. Okay, now that we have seen how a kernel works, let's move forward to another
example. All right, the next example that we are going to see is the same as before, so we are going to do a vector addition, but imagine that we have a very large vector, say a vector with 1 million elements. Of course we could do like before and launch a kernel with 1 million threads; the problem is CUDA will reject it, because it will say: I don't have 1 million threads to run in parallel. So how can we proceed in this case? Usually we are working with very big matrices or very big vectors, so we need to process a massive amount of data: how do we manage parallel computation when we do not have enough compute cores? One way is to divide the input vector into blocks of elements. For example, imagine our GPU only has 32 cores in total: we may divide our input vector into blocks of size 32, such that the first 32 elements are the first block, the next 32 elements are the second block, the third 32 elements the third block, and the last 32 elements are the last block. In this way we can ask the GPU to work on one block at a time, so we can say: okay, work on the first block, and after it has processed the first block it can work on the second block, and then the third block and the fourth block. This also allows the GPU itself to manage subunits of work, because imagine now we have blocks of 32 elements but a GPU with 64 cores: the GPU can also schedule two blocks at the same time, because it has enough cores. So we need to increase the granularity of our data to let the GPU decide how many blocks to schedule; this is the reason we introduce blocks in CUDA. So let me make a concrete
example, but with a very simple assumption: imagine our GPU only has four cores. So we have N equal to eight elements and we have four cores in total. What we can do, for example, is divide this vector into groups of even fewer elements, let's say two elements at a time: this is block number one, this is block number two, this is block number three, and this is block number four. We can ask CUDA to launch a kernel that is made up of four blocks, where each block is made up of two threads. When we launch the CUDA kernel (we can show the code now), the first part of this instruction tells CUDA how many blocks we have, and the second part inside these symbols tells how many threads we have for each block. In our case we want N divided by the block size number of blocks, where the block size in my picture is two. So how many blocks will we have? The number of blocks is N divided by two, where two is the block size, and this will be equal to four blocks, each of size two. This is what we are doing here: we are saying that the number of blocks is the ceiling (because N may not be a multiple of the block size) of N divided by the block size, and this tells how many blocks we have; this will define our grid. The grid basically tells how many blocks we have, and each block is made up of block-size
number of threads. Then the problem is: how do we assign the work to each of these threads? When we launch a kernel with this configuration (the number of blocks and the number of threads per block), CUDA will do the following job: it will assign each block an index called the block ID, where the block ID of the first block is zero (so let me write it here: the first block will have a block ID equal to zero), and within each block it will assign a thread ID, where the first thread of each block is thread zero and the second thread is thread number one. The second block will have block ID equal to one, and the first thread of this block will be thread number zero and the second thread will be thread number one; the third block will have a block ID equal to two, and the first thread will be thread number zero and the second thread will be thread number one; and so on until the last block, which will have block ID equal to three, with thread number zero and thread number one. The problem now is: based only on the index of the block and the index of the thread, how can we map it to which element of the vector each thread should work with? One simple assignment would be
just do uh well you can see that in this
case we need the uh this Vector um This
Thread here to work with element zero
this one should work with element one
this one should work with the element
number two this one to the element
number three this one four this one
five six and seven this five is so ugly
so let me write it
again how can we find the mapping given
only the block ID and the thread ID how
can we find Which element it should
correspond to well it's very simple
formula so you can see that the
element let's call it the element ID
which in the code I call it I is equal
to the block
ID multiplied by the size of each block
which is block size let's call
it block size yeah I have it block
size plus the tread
ID because in the case of the first
thread this will be equal to 0 * 2 + 0
which is 0 in this case it will be equal
to 0 * 2 which is 0 + 1 and it will be
equal to 1 in this case it will be equal
to 1 because block ID is equal 1 1 * 2
is equal to 2+ 0 is equal to 2 etc etc
You can see that this formula works for all the threads. So when we launch a CUDA kernel we tell the GPU how many blocks we want and how many threads there are in each block, but CUDA has no notion of how to map each thread to the element it should work with: that is up to us, and that is what we are doing in this kernel. We are saying that each thread should work with the i-th element of the vector, where i is computed as the block ID this thread belongs to, multiplied by the block size (how many threads there are in each block), plus the thread ID. By choosing a block size of two and having four cores, the GPU can choose to run one block or two blocks concurrently if it has enough free cores. That is why we want to work block by block: it allows the GPU to decide how to parallelize the operations given the cores it has, and we don't need N cores for an N-element vector; we divide it into smaller blocks and let the GPU manage the scheduling. A minimal sketch of this mapping is shown below.
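As a small illustration (plain Python simulating the indexing, not actual CUDA code; the variable names are mine), this is the mapping we just derived, including the guard for threads that fall past the end of the vector:

```python
# Minimal sketch (plain Python, not real CUDA): simulate how each
# (block_id, thread_id) pair is mapped to one element of the vectors.
BLOCK_SIZE = 2
x = [1, 2, 3, 4, 5, 6, 7, 8]
y = [10, 20, 30, 40, 50, 60, 70, 80]
out = [0] * len(x)

num_blocks = (len(x) + BLOCK_SIZE - 1) // BLOCK_SIZE  # ceiling division

for block_id in range(num_blocks):        # what CUDA calls blockIdx.x
    for thread_id in range(BLOCK_SIZE):   # what CUDA calls threadIdx.x
        i = block_id * BLOCK_SIZE + thread_id
        if i < len(x):                    # idle the threads past the end
            out[i] = x[i] + y[i]

print(out)  # [11, 22, 33, 44, 55, 66, 77, 88]
```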
Let's see one last example and then we move on to Triton. Imagine now that we want to do a matrix addition instead of a vector addition. In a matrix addition the data lies along two axes, the rows and the columns; usually we represent the vertical axis as the y axis and the horizontal axis as the x axis. Using the same blocking intuition as before, dividing the input data into blocks, this is how we can divide the labor of our matrix addition: we divide the rows into blocks, calling them row block zero, one and two, and we do the same on the x axis, calling them column block zero, one and two, where x is the column axis and y is the row axis. We don't even have to choose the same block size for rows and columns: we could group together three columns and two rows instead of two and two.
As we said before, when we launch a CUDA kernel, CUDA will just assign IDs to the blocks and to the threads inside each block; it is up to us to map the block ID and the thread ID to the data element a particular thread should work with. In the case of matrix addition we could say that each thread should work with one element of the output matrix C: it computes the sum of the corresponding elements of A and B and writes it to the output matrix C. How do we do that? Imagine we have six rows and six columns. One easy way is to divide the rows into three blocks of two rows each and the columns into three blocks of two columns each. CUDA will launch as many blocks as there are combinations of row blocks and column blocks: here we have three blocks for the columns and three for the rows, so it will launch nine blocks. CUDA identifies the dimensions of the blocks based on the axes along which we divided them: we call the columns the x dimension and the rows the y dimension, so it launches as many blocks as there are combinations of x and y, nine in this case. This is block (0, 0), this is block (0, 1), this is block (0, 2), then block (1, 0), (1, 1), (1, 2), and so on. Inside each block we also divide the threads along the two dimensions, into x threads and y threads: each block has two threads along the x axis and two along the y axis, identified as thread 0 and thread 1 along each axis. So let's look at how the launch grid works in this case.
Imagine we have a matrix with num_rows rows and num_cols columns, and we want to divide the rows into blocks of rows_block_size rows and the columns into blocks of cols_block_size columns. The number of blocks we need is defined here: this is just a fancy way of writing the ceiling of num_rows divided by the rows block size, and this is just a fancy way of writing the ceiling of num_cols divided by the columns block size. These tell us how many blocks we will have along the rows and along the columns. The grid you see here, which tells us how many blocks we have, is a tuple of three values: how many blocks we want on the x dimension, on the y dimension and on the z dimension (we are not going to use the z dimension because we only have a matrix). Then, inside each block, we specify how many threads we want for the x dimension and for the y dimension. As the x dimension we have chosen the columns, so we are saying how many blocks we want for the columns and how many for the rows, and then inside each block how many threads we want for the column block and for the row block. This defines our launch grid, and CUDA will launch exactly this configuration: as many blocks as there are combinations of x and y, and inside each block it will assign thread IDs such that there are two threads on the x axis and two threads on the y axis of each block. Now let's try to understand how to map the block ID on the x axis, the block ID on the y axis, and the thread IDs on the x and y axes to one element of the output matrix.
Let's look at the code. Each element of a matrix is identified by two indices, a row index and a column index. The row index can be computed as the block ID on the y axis multiplied by the block size, plus the thread ID on the y axis. Let's see why this makes sense. This first thread works with row zero, because the block ID on the y axis is zero and the thread ID is zero, so block ID times block size plus thread ID is 0. Which column does it work with? The block ID on the x axis (zero) multiplied by the column block size (two, but multiplied by zero it is zero) plus the thread ID (zero), so column zero. The element next to it has the same row computation (zero) and, for the column, thread ID one, so column one. Let's see another one, this element here: its row is the block ID on the y axis (one) multiplied by the block size (two), plus thread ID zero, so row two, which makes sense because this is row zero, this is row one and this is row two; its column is the block ID on the x axis (one) multiplied by the block size (two), plus thread ID one, so 2 + 1 = 3. So this thread works with element (2, 3), and now the formula makes sense. This is how we use the block ID and the thread ID inside each block to decide which element a particular thread should work with: as I said before, CUDA has no notion of which element a thread should handle; that is up to us, based only on the block ID and the thread ID that CUDA assigns.
Then we make sure that the row index is less than the number of rows and the column index is less than the number of columns. Why? Because, as I said before, when we launch blocks and threads, CUDA rounds the number of threads up to a multiple of 32, which means some of these threads have no data to work with: the threads that do not have a corresponding element should just sit idle outside this if statement, while the ones that do should enter and do some work. Then we calculate the index of the matrix element this thread should work with as the row index multiplied by the number of columns, plus the column index. This is just another way of writing A[row_index][col_index], but the way we allocate arrays in C or C++ is as a flattened array where all the rows are stored one after another, so we need to locate the element inside that array from its row and column index, and this is the formula we use. If you have never worked with arrays in C or C++ it doesn't matter, because later we will see tensor layouts and this will become much clearer; if you have, then you already know how to index an element inside a multi-dimensional array. And then we compute the output as usual.
I know this has been a lot of information, so what should we remember? The first thing is that we decide how to divide the work over whatever matrix or vector we are working with: we tell CUDA how many blocks we want and how many threads we want in each block, and then, based on the block ID and the thread ID, we come up with a strategy for mapping each thread to a sub-unit of work, that is, to the part of the matrix or vector that particular thread should work with. A sketch of this two-dimensional mapping is shown below.
layouts because we are going to work
with the tensors and we need to
understand how the tensors are lay out
in the memory of the GPU or in the CPU
as well actually so we need to
understand what is the row column row
major layout and the column major layout
what is the stride Etc and convert all
the knowledge that we have about Cuda
into Triton so that we can then code
with Triton our kernel so let's
All right guys, finally it's time to explore tensor layouts. Why do we need to? Because, as we have seen in the CUDA kernel examples, when you give a matrix or a vector to a CUDA kernel, CUDA will not give you the whole tensor like in Python, where you can access each element by its index; it will just give you a pointer to the starting element of that matrix or vector, and it is up to you to calculate the memory addresses of all the remaining elements. Suppose we have a simple vector in PyTorch: a tensor with only one dimension of shape 7, meaning it has seven elements in its first dimension (for now, ignore the property called stride; I will explain it later). How will this tensor be stored in the memory of the CPU or the GPU? Suppose the address of the first element is 100, and suppose each element is a 16-bit floating point number, so each element occupies two bytes. Then the start address of the second element is 102, the third element is at 104, the fourth at 106, and so on.
This is exactly what you get in C when you allocate a vector or a matrix with malloc: the memory allocator allocates enough memory to store all the elements and gives you a pointer to the start address of that memory; then it is up to you to figure out where each element lives inside that block. To do this we introduce a property called the stride. The stride tells us how many elements we need to skip to get to the next element along a particular dimension. In the case of a vector we only have one dimension, the x dimension (or the columns dimension, if you like: this is the first column, the second, the third, and so on). To go from one element to the next we just need to skip one element: we increase our pointer by one element, then by one element again, and so on. This allows us to do a for loop over this tensor. A tiny PyTorch check of this is below.
for loop on this tensor let's look at a
more complicated case like the Matrix so
the Matrix is
two-dimensional and suppose we have the
following Matrix which is made up of six
element with two rows and three Colum
columns so the shape of this tensor will
be 2x3 because it we have two rows and
three columns uh how this Matrix will be
saved in the memory in the memory it
will be just a
flattened um Matrix it means and this is
called the row major layout but there is
also another one called column major
layout that we will not be discussing so
how it will be stored in the memory is
as follows it will be the first elements
of the first row so the elements of the
the first row followed immediately by
the elements of the second row so that
the memory address imagine with this is
the memory address of the first element
is 62 to go to the next element we need
to increase the memory address by the
number of bytes that each element
occupies which is two bytes so the the
address of the second element will be 64
the third element will be 66 and the
next row will start immediately after
the end of the first
Let's introduce the stride property for this matrix. The stride tells us, for each dimension, how many elements we need to skip to get to the next element along that dimension. For example, imagine we want all the elements of the first row. Let's call this tensor T, so we index T[0]: this says "select the first row and give me all its elements". How does this indexing work? Starting from the pointer to the first element, it selects the first row and then moves the index one element at a time: the first one, the second one, the third one. How does it know to move one element at a time? Because along this dimension the stride is one. Now imagine we want T[1], that is, all the elements of the row with index one. Starting from the pointer to the first element, it needs to know how many elements to skip to get past row zero, and that number is given by the stride of the first dimension: it says we must skip three elements to get to the next row. So it takes the pointer to the first element, skips three elements, lands at the start of the second row, and then walks through the second dimension, where the stride is one, returning only that part of memory. To rehearse: the stride is just a number that tells us how many elements to skip in each dimension to get to the next index along that dimension. Here it means that to go from one row to the next we skip three elements, and to go from one column to the next we skip one element. We can verify this quickly in PyTorch, as shown below.
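In PyTorch this is easy to check (again, a throwaway example tensor of mine):

```python
import torch

t = torch.arange(6, dtype=torch.float16).reshape(2, 3)  # shape (2, 3)
print(t.stride())             # (3, 1): skip 3 elements per row, 1 per column
print(t[1])                   # the second row...
print(t[1].storage_offset())  # ...starts 3 elements after the data pointer
```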
Why is the stride useful? Because it allows us to reshape tensors very easily, without doing any computation. Imagine the shape of this matrix is initially (2, 3), two rows by three columns, with the stride computed as before: to go from one row to the next you skip three elements, and to go from one column to the next you skip one element. If we want to reshape it into shape (3, 2), three rows and two columns, we can do it without changing its memory layout at all, just by changing the stride. Look at the physical configuration of the tensor: we can access the same storage as either shape simply by using a different stride. In the (2, 3) view, the row stride is three, so the starting element of the second row is the start pointer plus three elements, and the elements of that row follow one after another because the stride of the second dimension is one. Now suppose we want the second row of the reshaped (3, 2) view: its row stride is two, meaning that to go from one row to the next we skip two elements. So we start from the pointer to the beginning of the memory, skip the first two elements, and then select exactly two elements, one after another, because the stride in the second dimension is one. So the stride allows us to reshape the tensor without changing its physical layout in memory, as the short PyTorch check below shows.
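Continuing the PyTorch example above, the reshape really is just a new stride over the same storage:

```python
import torch

t = torch.arange(6, dtype=torch.float16).reshape(2, 3)
t2 = t.view(3, 2)                     # reshape without copying
print(t2.stride())                    # (2, 1): skip 2 elements per row now
print(t2.data_ptr() == t.data_ptr())  # True: same underlying memory
```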
Moreover, the stride also allows us to get the transpose of a matrix without changing how it is stored in memory, that is, without rearranging the elements. This is very cool, because we can view the same matrix both as itself and as its transposed version without touching the memory at all: it comes for free, just by working with the indices and the strides. To transpose a matrix along two dimensions we just need to swap the strides of those two dimensions. For example, imagine we want the second row of the transposed matrix. We always have the pointer to the first element, where the tensor starts in memory, and the swapped stride says that to go from one row to the next we need to skip one element, which is correct: the second row of the transpose starts at the second element of the memory. Then, to go from one element to the next within that row, we skip three elements: after the first element of that row we skip three elements and arrive at 8, which is exactly the second column of the second row of the transpose. So, as you can see, the stride lets us do two things: first, reshape a tensor without reallocating it in a different configuration in memory; second, transpose a matrix without rearranging its elements in memory. This is great, because moving memory around and rearranging it is expensive, so it's great that this stuff comes for free (see the short check below).
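And indeed, in PyTorch a transpose is just a stride swap over the same storage:

```python
import torch

t = torch.arange(6, dtype=torch.float16).reshape(2, 3)  # stride (3, 1)
tt = t.t()                             # transpose: shape (3, 2)
print(tt.stride())                     # (1, 3): the two strides are swapped
print(tt.data_ptr() == t.data_ptr())   # True: nothing was moved in memory
```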
Another thing: you know that in PyTorch there are two methods to reshape a tensor, one called reshape and one called view. After transposing a matrix by swapping the strides of the two dimensions you want to transpose, you cannot reshape the tensor for free anymore. Why? Because of how the stride is computed: the stride of a dimension is just the product of the shapes of all the following dimensions, so the stride of dimension zero is the product of the shape starting from index one. It's not easy to see with a 2D matrix because we don't have enough dimensions, so let's do it with a 3D tensor of shape (2, 4, 3), which means two matrices, each made up of four rows and three columns. The stride is calculated as follows: the stride of dimension zero is the product 4 × 3 = 12, the stride of dimension one is 3 × 1 = 3 (there is no further dimension, so it is just 3), and the stride of the last dimension is 1. When we transpose by swapping strides, this property is lost, and we cannot do further reshaping operations without copying: the tensor is no longer contiguous. This is a fairly advanced detail, and it doesn't matter if you remember it or not, but if you are curious: in PyTorch you cannot view a tensor after it has been transposed, because PyTorch transposes just by swapping the two strides, and then the strides no longer match the products of the following shapes, so you need to actually reallocate the tensor if you want to reshape it after it has been transposed. It's just a curiosity; a small demonstration is below.
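A quick demonstration of that lost contiguity in PyTorch (the shapes here are just for illustration):

```python
import torch

t = torch.arange(6, dtype=torch.float16).reshape(2, 3)
tt = t.t()                    # stride (1, 3): no longer row-major contiguous
print(tt.is_contiguous())     # False

try:
    tt.view(2, 3)             # view cannot work on this non-contiguous layout
except RuntimeError as e:
    print("view failed:", e)

print(tt.contiguous().view(2, 3).shape)  # reallocating first makes it possible
# tt.reshape(2, 3) would also work, copying under the hood when needed.
```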
Anyway, what is the stride used for? Three things. First, it tells us how to index the tensor: just from the pointer to its starting address we can access any row or any column. Second, it lets us reshape the tensor for free, without rearranging the elements in memory. Third, it lets us transpose the tensor however we like, just by swapping the strides of the two dimensions we want to transpose. Now that we have also seen how a tensor is stored in memory, we can finally go and see Triton, with some examples.
All right guys, now that we have seen how tensors and their layouts work and how CUDA works, we can look at some examples of Triton kernels to see how Triton differs from CUDA. If you go to the Triton website you will find some tutorials in this section, so let's work through one together. First of all, the code I will be writing for flash attention is based on the fused attention tutorial you can see there, but with some modifications, because I simplified the code a lot: I removed, for example, the fp8 implementation; also, that fused attention code only supports the causal case in the backward pass, while my code will work for both causal and non-causal attention. Another modification is that instead of using the base-2 exponential that they use to make things faster (probably because exp2 is implemented by a faster unit), I use the original flash attention formulation with the base-e exponential. So I simplified my code as much as possible to make it easy to follow instead of making it optimized: my code will certainly be slower than the fused attention you see there, but it should be more comprehensible and easier to follow.
Anyway, let's go to the vector addition tutorial, which shows how to do a vector addition with Triton; this should get you into the mindset of how to write Triton kernels. Instead of writing the kernel first and then calling it, let's do the opposite: let's see how the kernel is called and then explore how it works. I have already copied the vector addition tutorial from the website. What do we want to achieve? We have an input vector X and an input vector Y, and we want to compute the vector addition: with torch we do the obvious x + y, and we also want to do the same operation with Triton by calling this add method, and then compare the two outputs: they should be equal, or at least their difference should be very small, because there is always some rounding error when working with floating point numbers. The size of the vector is about 98,000 elements, and we want to work in a blocked way: as you remember, with CUDA you could do a vector addition by spawning one thread per element, but when you don't have enough threads you divide the input vector into blocks, and that is what we are going to do here. Let's look at this add method: it first allocates the memory for the output vector, then it computes the launch grid. The launch grid tells Triton, just like in CUDA, how many blocks of threads we want to launch. If you remember, in the CUDA kernel we specified how many blocks we want and then how many threads for each block; in Triton we only say how many blocks we want, and we don't force the number of threads: it is Triton that chooses how many threads to launch, and we just describe what each group of threads should do. In this case we divide the number of elements (about 98,000) into blocks of BLOCK_SIZE, which is initialized to 1024, so the grid size is the ceiling division ceil(n_elements / BLOCK_SIZE), which is the meaning of this expression: how many blocks we want. What each block should do is inside the kernel, and when we launch the kernel we specify the launch grid in the square brackets and the arguments of the kernel in the round parentheses. So let's go to the kernel.
In the kernel, we see that Triton will not give us access to the tensor X: it gives us a pointer to its first element, and this takes us back to tensor layouts. The reason we studied tensor layouts and strides is that this add kernel runs on the GPU, and the GPU does not index tensors like PyTorch, with all the dimensions, broadcasting and other fancy stuff: it just gives you the pointer to the first element of the tensor in memory, and then it is up to you to compute the addresses of all the elements you want to access. So x_ptr is the pointer to the first element of the X vector, y_ptr is the pointer to the first element of the Y vector, then we have the pointer to the output vector where we want to store the result, we specify how many elements our vectors have, and the block size, that is, how many items each block should process, which may not correspond to how many threads each kernel instance will actually have. You may be confused: in CUDA we specified how many threads each block should have, so the granularity we manage is the thread level; here we are saying that a group of threads should work with this quantity of data, and it is up to Triton to decide how many threads it will actually use (there are ways to control this, by specifying the number of warps, but we will see that later). For now, just remember that this kernel will process BLOCK_SIZE elements of the input vectors. First we need to identify which block we are: in CUDA we used the variable blockIdx.x to identify the block, which tells us which group of elements we should work with; in Triton you do the same with the program ID. In CUDA the block ID can be along the x, y and z axes; in Triton these are called dimensions 0, 1 and 2. Here we have one-dimensional data, so we only use one axis for the block index. So we get the block index, which in Triton is called the program ID; it is intuitive to think of it as a program running in parallel with other programs that have different program IDs. Based on the program ID we can work out the starting element this program (this block of threads) should work with: it is just pid multiplied by BLOCK_SIZE, so program 0 starts at element 0, program 1 starts at element 1024, and program 2 skips the first 2048 elements and starts at the element with index 2048.
Next we define which elements to load from the X and Y vectors, given the pointers. To do that we build a list of offsets relative to the starting address: because each program in Triton works with a block of elements, not a single element, we need to say which elements to load. For program 0 the offsets are the block start (0) plus the indices from 0 up to 1024 excluded. For program 1 the offsets are 1024, 1025, 1026, 1027, and so on up to 2047. For program 2 the offsets are 2048, 2049, and so on up to 3071. Now, as you remember, the number of threads launched is not always based on the number of elements in the block or in your vector: it is always a multiple of a base number, usually 32, which means a program may have more threads than it needs, and some threads should not load any data or compute any sum. That is why we need this mask. It says that the offsets we load must be less than n_elements, because imagine the vector has not 98,000 but only 2060 elements: then the third program of this kernel would compute offsets going from 2048, 2049, and so on up to 2059, but also 2060, 2061, 2062 and all the way up to 3071, and those elements don't exist. So we need a way to tell the threads working on those offsets not to load anything, and the mask does exactly that: among all the offsets this block should work with, load only the elements that actually exist, those for which the mask is true. Then we load the elements of this program, the group of elements defined by these offsets, only where the mask is true; all the others are ignored (we could also specify what to load where the mask is false via another parameter, but we won't use it here). We also load the corresponding group of elements of the Y vector.
Then we compute the output x + y. If you remember, in CUDA we did something like out[i] = x[i] + y[i], one element at a time, because each thread worked with one index; here we are working with a group of elements, so this x is a block of elements of size (at most) BLOCK_SIZE, this y is the corresponding block from the Y vector, and we compute the output group by group: this sums a group of elements of X with the corresponding group of Y. Then we need to store this result in the output tensor, through output_ptr, the pointer to the first element of the output vector. Where should we store it? At the same offsets from which we loaded X: if this program worked with indices 2048, 2049, and so on, the output should be written at those same offsets, again using the mask, because we don't want to write a full block if we don't have enough elements; we only write the ones that are actually present in the vector. So, to repeat, the reason we need the mask is that CUDA launches a number of threads that is always a multiple of a base unit, which may not be a multiple of the vector size we are working with, so we need a way to tell some threads to do nothing when their data is not available. Let's rehearse what we have seen so far: in CUDA, the program we write is at the thread level, describing what each thread should do; in Triton we work at the level of a block of threads, describing what data each block of threads should work with. The whole example, put together, looks like the sketch below.
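For reference, here is a compact version of that vector-addition example, close to the official Triton tutorial but slightly simplified (treat it as a sketch, not the exact code on screen):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                       # which block of data we are
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)  # indices this program handles
    mask = offsets < n_elements                       # guard the tail of the vector
    x = tl.load(x_ptr + offsets, mask=mask)           # HBM -> SRAM
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)     # SRAM -> HBM

def add(x, y, BLOCK_SIZE=1024):
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), BLOCK_SIZE),)      # ceil(n / BLOCK_SIZE) programs
    add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=BLOCK_SIZE)
    return out

x = torch.randn(98_000, device="cuda")
y = torch.randn(98_000, device="cuda")
assert torch.allclose(add(x, y), x + y)
```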
All right guys, finally the moment has come: we are going to code the flash attention forward pass in Triton. But first let's rehearse the algorithm. The goal of the attention mechanism, and specifically of flash attention, is to compute the output of the following formula: the query multiplied by the transpose of the key, divided by the square root of the head dimension, then the softmax, and then all of that multiplied by V. In this video we will code the forward pass and also the backward pass, but before coding the backward pass we need to understand how autograd works: what the gradient is, what the Jacobian is, how to derive the gradient of the softmax and of the matrix multiplication, and so on; that will be another part of the video. For now let's concentrate on the forward pass. We now have some tools: we know that the GPU can parallelize operations across many cores; we know that in CUDA we parallelize by writing a program that defines what each thread should do, or we can follow the Triton programming model, which describes in Python what each group of threads should do. The mapping between each thread and the element it should work with is up to us, the programmers, and the same holds in Triton: we say how many blocks of threads we want and how much data each block of threads should process (that is the block size we saw in the vector addition), but the mapping between the elements of the data and the identity of each group of threads (the program ID) is up to us. The same will happen when we code flash attention. Let's see what we can parallelize in flash attention.
First of all, the forward pass of flash attention takes as input query, key and value matrices of shape N × d. However, in a Transformer network we usually don't have just one sequence of d-dimensional vectors: we have many sequences, and this lowercase d is the number of dimensions dedicated to each head, and we don't have only one head, we have multiple heads. So the algorithm you see here is what each head of each batch element should do. Moreover, we have seen before, when talking about block matrix multiplication, that we can parallelize the computation of the output: the first output block depends on the first query block and all the keys, the second output block depends on the second query block and all the keys, the third on the third query block and all the keys, and so on. Because each output block depends only on its own query block, the query blocks can be processed independently from each other, sharing, of course, the keys.
Another thing we need to understand about Triton is the shared memory. In the GPU we have the high-bandwidth memory (HBM), which is kind of like the RAM: when you buy an A100 and they tell you it has 40 GB, that is the amount of HBM, the DRAM. Let's look at the structure of the GPU: we have this DRAM, the big memory of the GPU, and then each streaming multiprocessor (think of it as handling a block of threads) also has a piece of memory called the shared memory, which is much, much smaller than the DRAM. What changes between these two memories? Access to the DRAM is very slow, while access to the shared memory is very, very fast. One difference between CUDA and Triton is that in CUDA, when you load some data, you are loading it directly from the global memory: when we launch a CUDA kernel, as you remember from my C++ code, we first copy the tensors from the CPU to the GPU, where they reside in the GPU's global memory, and then we load elements directly from that global memory, which is much slower. The problem with the naive attention computation, the one we can do with torch, is exactly that it is slow because it keeps accessing the global memory. So we want to use the shared memory as much as possible: we want to reuse the elements loaded from the global memory into the shared memory, so that we don't need to go back to the global memory every time we load elements from the vectors or matrices. This is what happens in Triton: whenever you load some data, you are copying it from the global memory to the shared memory; whatever operations you do are done on the shared memory; and when you store, you copy the data from the shared memory back to the global memory. This makes it much faster, because we always work with elements that have been loaded into the shared memory, and this shared memory is shared by all the threads that belong to the same block. In Triton we have an abstraction level that does not make us work directly with threads: we always work with a group of threads belonging to the same block, which share this shared memory. So in Triton we copy information from the global memory to the shared memory, do some operations with it, and store it back to the global memory, and this is what we are going to do with flash attention.
let's review the algorithm of Flesh
attention so in flesh attention we have
to go an out for Loop that is among all
the between all the keys and then an
inner loop that is sorry between all the
query blocks and then an inner loop that
is um through all the key
Block in the original flash attention
algorithm the flashh attention one the
outer block was on the keys and the
inner block was on the queries this made
it less parallelizable why because the
Outer Loop is on the queries and we have
seen before that the the output of this
um attention can be computed
independently for each block of queries
so it's much easier to parallelize so
this outer for Loop actually we don't
have to run a for Loop we just spawn
many kernels each working with one
iteration of this outer for Loop so each
working with a different query block of
this outer for Loop and the inner for
Loop is something that we have to
iterate through so each Trion kernel
will work with one query block and then
iterate through all the key blocks um
and inside of this key block we have
already seen the operations that we are
going to do which the we we we explored
before and at the end of this for Loop
we need to store back the output in the
high bandwidth
memory um and this is how it's going to
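To make the loop structure concrete, here is a pure-PyTorch sketch of the forward pass for a single (batch, head) slice, non-causal, using the online-softmax correction factor we derived earlier. This is an illustration of the algorithm, not the Triton kernel we are about to write, and all the names are mine:

```python
import torch

def blockwise_attention(Q, K, V, softmax_scale, BLOCK_Q=16, BLOCK_KV=16):
    # Q, K, V: (SEQ_LEN, HEAD_DIM) for one head of one sequence.
    seq_len, _ = Q.shape
    O = torch.empty_like(Q)
    for q_start in range(0, seq_len, BLOCK_Q):      # outer loop: query blocks (parallel in Triton)
        q_blk = Q[q_start:q_start + BLOCK_Q]
        m = torch.full((q_blk.shape[0],), float("-inf"),
                       device=Q.device, dtype=Q.dtype)          # running row maximum
        l = torch.zeros(q_blk.shape[0], device=Q.device, dtype=Q.dtype)  # running denominator
        acc = torch.zeros_like(q_blk)                            # unnormalized output
        for k_start in range(0, seq_len, BLOCK_KV):  # inner loop: key/value blocks
            k_blk = K[k_start:k_start + BLOCK_KV]
            v_blk = V[k_start:k_start + BLOCK_KV]
            s = (q_blk @ k_blk.T) * softmax_scale    # scores for this block pair
            m_new = torch.maximum(m, s.max(dim=1).values)
            p = torch.exp(s - m_new[:, None])        # local (shifted) exponentials
            correction = torch.exp(m - m_new)        # fixes what used the old maximum
            l = l * correction + p.sum(dim=1)
            acc = acc * correction[:, None] + p @ v_blk
            m = m_new
        O[q_start:q_start + BLOCK_Q] = acc / l[:, None]  # normalize once per query block
    return O

# Sanity check against the naive formula (float32 keeps the comparison tight):
Q = torch.randn(64, 32); K = torch.randn(64, 32); V = torch.randn(64, 32)
scale = 1.0 / (32 ** 0.5)
ref = torch.softmax(Q @ K.T * scale, dim=-1) @ V
assert torch.allclose(blockwise_attention(Q, K, V, scale), ref, atol=1e-5)
```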
Another thing we should notice is that Q, K and V here are N × d, but, as I said before, in a Transformer model we don't have only one sequence: we have many sequences, so we can also parallelize over the number of sequences in the batch, because each one can be processed independently, and each sequence has multiple heads, and each head can also be processed independently (we know from the Attention Is All You Need paper that this is the whole point of multi-head attention: each head computes its attention independently). So we will also parallelize along the head dimension. Moreover, looking at the definition of the query block, we can split the queries into blocks, and each query block can work independently from the other query blocks, producing one output block. This is how we are going to parallelize: over each sequence in the batch, inside each sequence over each head, and inside each head over each query block. So how many programs will we have working in parallel, at most? It will be the number of sequences in the batch (the batch size), multiplied by the number of heads, multiplied by the number of blocks into which we divide the query sequence, that is, SEQ_LEN / BLOCK_SIZE_Q; a tiny numeric example follows below.
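Just to make that count concrete with some made-up sizes:

```python
# Hypothetical sizes, only to make the program count concrete.
BATCH_SIZE, NUM_HEADS, SEQ_LEN, BLOCK_SIZE_Q = 4, 8, 1024, 64
num_q_blocks = (SEQ_LEN + BLOCK_SIZE_Q - 1) // BLOCK_SIZE_Q    # ceil(1024 / 64) = 16
max_parallel_programs = BATCH_SIZE * NUM_HEADS * num_q_blocks  # 4 * 8 * 16 = 512
print(max_parallel_programs)
```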
All right, now that we have seen this, let's actually go and code it. I have already introduced some of the differences between my implementation of flash attention and the one you can find in the Triton documentation. First, I don't work with fp8, because I believe it is unnecessary for our explanation; it is of course much faster, because recent GPUs also support fp8. The second difference is that in the flash attention on the Triton website the backward pass is only implemented for causal attention, while in my case I implement it for both causal and non-causal attention, even if it is slower (later I will actually give you an exercise on how to improve it). The third main difference is that I make explicit use of the softmax scale, so I apply the scale where it is needed. Another difference is that in the online Triton implementation of flash attention, e^x is not really computed as e^x but as 2^x, compensated with a logarithm, probably because the implementation of 2^x is faster than e^x; in my case I keep the original exponential, because I want to follow the original algorithm and make it easier to line the code up with the flash attention paper.
I know I have created a lot of hype, so let's do it. Let's start by creating a new file, call it program.py. Just like before, when I introduced Triton, I will first write the code that will use our kernel and then write the kernel itself, and we will only be coding the forward pass of the kernel. Let's start by importing what we need, which is just torch and triton, and then (the copilot is already off, so I don't have to worry about that) let's write the code that will test our Triton implementation and compare it with a naive implementation of the attention mechanism. We create query, key and value sequences for testing: the query, as you remember, has a batch-size dimension because we have multiple sequences, each sequence has a number of heads, it is made up of SEQ_LEN tokens, and each token is identified by HEAD_DIM dimensions. This is because we have already split each token across heads, each with its own head dimension; if you remove the NUM_HEADS dimension, you get back the concatenation of all the head dimensions. We initialize query, key and value with a normal distribution (this code I took from the Triton tutorial, so nothing new) and we require the gradient, because we want to compute the gradient with respect to query, key and value; we will see later why, because we also want to test the backward pass, even though we will not be coding it now. The first thing we do is define our softmax scale: as you remember, the formula is the query multiplied by the transpose of the keys, divided by the square root of the head dimension (d_k, or d_head as it is sometimes called), so the scale is one over the square root of the head dimension, and we can compute it right away. Then we also define dO; later we will see what it is, but it is needed for the backward pass, so don't worry if you don't understand what dO is yet.
Let's do the naive implementation of the attention, which is very simple. First we define the mask, and we use it only if the attention we are computing is causal: as you can see, we pass this parameter called causal, which tells whether we want causal or non-causal attention, and the dtype, which is float16 because we want to work directly with 16-bit floating point numbers (we will not work with fp8; my implementation is not as fast as the one in the Triton tutorial, but I believe it is much easier to understand). So: we define the mask, we compute the product of the query with the transpose of the key divided by the square root of the head dimension (that is why we multiply by softmax_scale); if the attention is causal we apply the mask we computed, replacing all the dot products where the mask is zero with minus infinity, so that the softmax will then turn those minus infinities into zeros; and then we apply the softmax, row by row, just like in the normal attention. The second thing we do: the output is the product of the softmax output with V, and this is the reference output of the naive implementation of the attention mechanism. Then we also derive the gradients of the output with respect to the inputs, in this case V, K and Q; later we will see what we are doing there. A sketch of what this reference computation looks like is below.
see what are we doing here then we want
also to we want to compare this
reference implementation with our Tron
implementation so let's do it so our
Tron implementation will be implemented
as a class called Tron attention that we
will call using this method called apply
and later we will see what is this
method in which we pass the query key
and value if we want to compute the Cal
attention the soft Mark scale that it
should be using and it should prod
produce some output which is the output
of the output of the softw multiplied by
V then we can run also the backward and
this backward will be the the the the
same backward that we will compute with
the um Tron
attention and then we
compare okay and then we can compare uh
the result of our implementation so this
Tron attention not apply with the
reference implementation which is this
one here and they should be uh we use
the the function all clause which
basically compares the elements of two
tensors and make sure that their
absolute difference is no more than this
one we are not using the relative
distance we are just using the absolute
distance between the two elements
corresponding elements of two vectors
this uh implementation that we have that
we will build will work with the causal
attention and also with not causal
attention while the the one that we saw
in the website of Tron it only works
with the uh the forward pass actually
works with the causal and non causal
while the backward pass only works in
the case of the Cal attention um okay
but it's highly optimized the one online
so if you want to learn a little more
tricks on how to optimize Triton kernels
there is a lot of knowledge there anyway
guys now let's try to uh implement this
Tron atten at least the forward pass so
let's go to implement this Triton
attention
Every time you want to introduce a new operation into torch, you need to implement it by deriving from the torch.autograd.Function class. Every operation in torch, whether it is the softmax, the ReLU or whatever, is implemented as a class deriving from this Function, and it should provide two methods: one called forward and one called backward. The forward should produce the output of the operation, and the backward should compute the gradient of the loss with respect to the inputs of that function; later we will see how that works. For now let's concentrate on the forward pass. To implement it we create a static method called forward, which takes as input something called the context. As you know, when training neural networks with autograd we have a forward pass and a backward pass, and during the backward pass we need to reuse the activations of the computation nodes from the forward pass; this context lets us save the information, the necessary activations, that we will need during the backward pass. Later we will see, in the flash attention algorithm, exactly what we need to save: for example, in the backward pass we will recompute on the fly the query multiplied by the transpose of the keys for each block, but we don't want to recompute the normalization factor or the maximum value for each row, so we will save those; actually we will not save two values, we will save one value, using a trick called the logsumexp trick that we will see later. Anyway, this context is just a kind of storage area where we can save whatever will be necessary to compute the backward pass. Then we have the inputs of the operation: the query, key and value tensors, the causal flag that tells whether we are computing causal attention, and the softmax scale, which is one over the square root of the head dimension (we could also compute it on the fly from the tensor shapes, but it doesn't matter). A minimal skeleton of such a class is sketched below.
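Something like this, as a skeleton only (the kernel launches and the saved tensors are filled in as the video progresses; the names are mine):

```python
import torch

class TritonAttention(torch.autograd.Function):
    # Minimal skeleton, assuming the actual Triton kernel call is added later.
    @staticmethod
    def forward(ctx, Q, K, V, causal, softmax_scale):
        O = torch.empty_like(Q)              # pre-allocated output
        # ... launch the Triton forward kernel here, filling O (and the logsumexp M) ...
        # ctx.save_for_backward(Q, K, V, O, M)   # what the backward pass will need
        ctx.causal = causal
        ctx.softmax_scale = softmax_scale
        return O

    @staticmethod
    def backward(ctx, dO):
        # ... launch the Triton backward kernels here ...
        raise NotImplementedError("backward pass comes later in the video")

# Called as: O = TritonAttention.apply(Q, K, V, causal, softmax_scale)
```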
The first thing we do is extract the shapes of these objects and make sure they are what we expect: the shape of query, key and value is (batch_size, num_heads, seq_len, head_dim), and we check that the head dimension matches for query, key and value, since each vector should be the same size. Then we pre-allocate the output tensor where we will store our output. As you remember, the output of the attention mechanism has the same shape as the query, key and value sequences, where by query, key and value I mean not the input of the attention (a sequence of tokens) but the outputs of W_Q, W_K and W_V: flash attention is not concerned with optimizing those matrix multiplications, only with what comes after them. So we pre-allocate an output tensor with the same shape as the query. Actually, to be precise, it has the same shape as the query but not necessarily the same as the key and value, because of cross attention, where query, key and value are projections through W_Q, W_K and W_V of two different sequences: the query comes from one sequence, and the key and value come from another sequence, which passes through its own W_K and W_V, and the two sequences may not have the same length. So the shape of the attention output depends only on the shape of the query sequence, not of the key and value sequences. This happens during cross attention, but in language models we usually work with self attention, so it should not happen, at least in causal language models.
Then we have the stage. Later we will see what this stage is; basically it's just a number that tells whether the operation we are going to do is for causal attention or non-causal attention. Then we need to define our launch grid. The launch grid tells us how many parallel programs need to be launched by Triton (actually they will be launched by CUDA, but we always work with Triton as an interface to CUDA). As I said before, we want to parallelize along the batch dimension, so each sequence in the batch works independently of the others, and inside each sequence each head also works independently, so we have at least batch size multiplied by number of heads programs. For each of these programs we have another dimension: we divide the query into blocks of queries. As you remember from block matrix multiplication, we don't work with the original query matrix where each query is one vector, one token; we work with groups of queries, so each block of queries is a group of tokens in the query sequence. So we are launching programs along two dimensions, just like a CUDA kernel can be launched along the X and Y dimensions: one dimension tells us which head of which batch element we are going to work with, and inside that, which group of queries of the sequence we are going to work with. The number of query blocks is the sequence length divided by the number of queries we group together, so BLOCK_SIZE_Q tells us how many queries there are in each block of queries, and cdiv is just the ceiling division, so it is equal to ceil(SEQ_LEN / BLOCK_SIZE_Q).
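A minimal sketch of such a launch grid (variable names are assumptions; BLOCK_SIZE_Q is taken from the meta-parameters, e.g. provided by autotuning):

```python
# One program per (query block, batch element * head); the third grid axis is unused.
grid = lambda args: (
    triton.cdiv(SEQ_LEN, args["BLOCK_SIZE_Q"]),  # how many blocks of Q
    BATCH_SIZE * NUM_HEADS,                      # which (batch, head) pair
    1,                                           # no extra level of parallelism
)
```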
This tells us how many blocks of Q we have. Let's rehearse: we have a tensor Q of shape batch size by number of heads by sequence length by head dimension, and each flash attention program works with one (sequence length, head dimension) slice. Moreover, we have seen that flash attention has two loops: the outer loop over all the query blocks and the inner loop over all the key blocks, and we have seen that the query blocks can be processed independently of each other, so we can spawn as many parallel programs as there are blocks of Q. This grid tells us how many programs could run in parallel; then it is the GPU that, based on its resources, decides how many actually run in parallel. If it has enough resources to run them all in parallel, wonderful; if it doesn't, it will schedule them sequentially, one after another. The last dimension is like the Z dimension of a CUDA launch grid, and we don't use it because we don't want an additional level of parallelism. All right, this is our launch grid: we will launch a number of parallel programs (each Triton program is a group of threads, like a CUDA block) equal to batch size multiplied by number of heads multiplied by the number of blocks we divided the Q sequence into. Okay, let's continue.
into okay let's continue so the we will
see what is this one so this m is
another Matrix that we will need and
it's the log sum Expo for the backward
pass and we will see at the end of this
video what it not at the end of this
video but at the end of the forward pass
what it's needed for but basically this
is you can think of it as the maximum
for each row um you we to to recompute
the query multiply by the key in the
backward pass we should also have if we
don't want to recompute the maximum for
hro and the normalization factor of the
softmax we should save two things one is
the maximum for each row and one is the
uh the normalization factor however by
using the log sum X trick we can only
save one value which is the as you can
see in the um algorithm of FL attention
it's this stuff here which is uh let's
see here it's this stuff here so this Li
which is the maximum for each row plus
the logarithm of the um of the
normalization
factor um and
basically in when Computing the back
pass we need to recompute on the Fly
this block here so the square multip by
the transpose of but to apply the soft
Max as you remember we need to have the
maximum for each row and the
normalization factor so uh we don't um
we don't recompute them during the
backward because we have already
computed them during the forward so we
save this information but we don't need
to save this two information separately
we can aggregate it into one single
value called Li and later we will see
how we can use
it all
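As a minimal sketch of this buffer (names assumed): one log-sum-exp value per query row, i.e. one per (batch, head, query position), kept in float32.

```python
# Will hold m_i + log(l_i) for the backward pass, one value per query row.
M = torch.empty(
    (BATCH_SIZE, NUM_HEADS, SEQ_LEN), device=Q.device, dtype=torch.float32
)
```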
All right, so we have also defined this one, and we can proceed further. Now we launch our kernel. Don't be scared, it's going to be a little long. Here we are launching the kernel for the forward pass by defining the launch grid (so how many of these programs should run in parallel at most), and we pass the query, the key, the values, the softmax scale, the M tensor — which is the information we need to save for the backward pass; it's actually called L in the pseudocode of the flash attention algorithm, but I call it M here, I think because it was also called M in the original code — and O, where our kernel should save its output. Then, as you remember, we don't get the nice tensor indexing we are used to in PyTorch: we only get a pointer to the first element of Q, a pointer to the first element of K and a pointer to the first element of V, and we have to figure out the memory locations of all the other elements ourselves. To calculate an index we need the strides, because the stride tells us how many elements to skip to move by one along a given dimension, and that is why we pass the stride of each dimension of each tensor. Actually, in our case Q, K and V have the same dtype and the same shape, so we should not need to pass all the strides for each of these tensors, because they should have the same strides; however, the original code passed them, so I kept it. The strides will allow us to index these pointers, that is, to access any element of a tensor starting from the pointer to its first element. Then we pass the shape information — the batch size, the number of heads, the sequence length and the head dimension, which is the same for all of them — and then the stage, which indicates whether we are going to compute causal or non-causal attention. Let's not implement the kernel just yet, and let's continue writing this method.
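A minimal sketch of what this launch looks like (argument names are assumptions; the block sizes are assumed to come from autotuning, which is discussed later):

```python
# Launch the forward kernel over the grid; strides let the kernel index raw pointers.
_attn_fwd[grid](
    Q=Q, K=K, V=V, softmax_scale=softmax_scale, M=M, O=O,
    stride_Q_batch=Q.stride(0), stride_Q_head=Q.stride(1),
    stride_Q_seq=Q.stride(2), stride_Q_dim=Q.stride(3),
    # ... the same four strides for K, V and O ...
    BATCH_SIZE=Q.shape[0], NUM_HEADS=Q.shape[1],
    SEQ_LEN=Q.shape[2], HEAD_DIM=Q.shape[3],
    STAGE=3 if causal else 1,
)
```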
Then we need to save some information that will be needed for the backward pass, using the context variable I told you about before. We save the query, key and value, which are the tensors for which we want to compute the gradient during the backward pass, and we also need to store the M tensor and the O tensor. Then we need to store the causal flag as well, because if we computed causal attention during the forward pass, then during the backward pass we need that information, since we need to mask out the positions that should not contribute to the gradient. But we will see that later, when computing the backward pass. For now let's concentrate on this attention forward.
So we need to implement this forward kernel, the _attn_fwd method. Now, a Triton kernel is just a Python method with a particular decorator called triton.jit, so we copy and paste the signature: this decorator is what makes a method become a Triton kernel. As you can see, we pass the query, key and value matrices along with other information, and the M matrix — please don't confuse the M matrix with the mask that we will apply; the mask we will generate on the fly, because in this case we are only concerned with causal or non-causal attention and we do not accept custom masks here. Then we pass the strides of all these tensors; the batch size, the number of heads, the sequence length and the head dimension, which describe the shape of each of these tensors; and BLOCK_SIZE_Q and BLOCK_SIZE_KV: BLOCK_SIZE_Q indicates how many queries we group together to make one block of the Q matrix, and BLOCK_SIZE_KV indicates how many keys and values we put together to make one block of the K and V matrices, which is what we do when we do block matrix multiplication. The stage is a number that indicates whether we are doing causal or non-causal attention: it will be three in case it's causal and one in case it's not causal.
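A minimal sketch of such a kernel signature (parameter names are assumptions, following the structure described above):

```python
import triton
import triton.language as tl

@triton.jit
def _attn_fwd(
    Q, K, V, softmax_scale, M, O,              # raw pointers to the tensors
    stride_Q_batch, stride_Q_head, stride_Q_seq, stride_Q_dim,
    stride_K_batch, stride_K_head, stride_K_seq, stride_K_dim,
    stride_V_batch, stride_V_head, stride_V_seq, stride_V_dim,
    stride_O_batch, stride_O_head, stride_O_seq, stride_O_dim,
    BATCH_SIZE, NUM_HEADS, SEQ_LEN,
    HEAD_DIM: tl.constexpr,
    BLOCK_SIZE_Q: tl.constexpr,                # queries per Q block
    BLOCK_SIZE_KV: tl.constexpr,               # keys/values per KV block
    STAGE: tl.constexpr,                       # 3 = causal, 1 = non-causal
):
    pass  # body built up step by step below
```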
The first thing we do is verify some information: we verify that the block size of the KV is less than or equal to the head dimension. To be honest I don't think we need it with my code, because I removed most of the constraints, but this check was also present in the original code so I kept it. It all depends on the autotuning process — later we will see what variables we autotune for, how many stages we choose, how many warps we choose, and so on — so let's leave it for later. You can comment it out or keep it; it shouldn't matter.
The first thing we do, as I said before: we launched a grid, which is a series of programs, and each program gets some identifiers, just like in CUDA we had an identifier for the blocks on the X axis and on the Y axis. In Triton we get these identifiers for the programs we launched: SEQ_LEN divided by BLOCK_SIZE_Q programs along axis zero, and batch size multiplied by number of heads along axis one of the launch grid. These identifiers let us figure out which part of the query we are going to work with in this program, in this kernel, and also in which batch and on which head this program should work. So that's what we do now: we figure out which part of the input we should work with based on the IDs of the program, which correspond to the block ID in CUDA. The program ID on axis zero tells us which block of queries we are going to work with. Why do we have blocks of queries? Because, as we saw before, the output can be computed independently for each block of queries, while each block of queries has to iterate through all the keys and values. So this is the index of the block of queries this particular program works with. Then we can also work out which batch and which head this program is associated with: the number of programs on axis one is the product of batch size and number of heads, so the program ID on axis one tells us both which batch and which head this particular program is associated with. To get the index of the batch we do integer division of this number by the number of heads, and to get the head index inside that batch we take this number modulo the number of heads.
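In code (inside the kernel), this step is a minimal sketch along these lines, with names assumed:

```python
# Recover which query block and which (batch, head) pair this program handles.
block_index_q = tl.program_id(0)          # which block of queries
index_batch_head = tl.program_id(1)       # which (batch, head) pair
index_batch = index_batch_head // NUM_HEADS
index_head = index_batch_head % NUM_HEADS
```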
Okay, the next thing we need to do: first of all, when we pass a tensor — as you can see here, the Q parameter of this attention forward method is a tensor, because it is the input of the forward function, this forward function is called when we do attention.apply, and this Q was created as a tensor — when we pass a tensor to a Triton kernel it's not really a tensor, it is a pointer to the first element of that tensor in memory. Now, since we know which batch and which head we are going to work with, we need to index this tensor to select the right batch and the right head inside it, which basically means doing something like Q[index_batch, index_head] and selecting everything inside those indices: we need to enter the tensor at the location where the (sequence length, head dimension) slice for this batch and this head starts. For that we need to generate an offset by which to move this pointer, because this pointer is currently pointing at the beginning of the entire tensor, and we need to move along the batch dimension and along the number-of-heads dimension. To do that we generate the following offset, which tells us where this particular batch and this particular head start in this tensor, and to do that we need to use the strides. So we create qkv_offset, which is the index of the batch multiplied by the stride of the batch dimension — the stride tells us how many elements we need to skip to get to the next batch, so for batch zero we don't skip anything, because we are already pointing to the first element of that batch, but if we are at batch one we skip that many elements — plus we also need to skip some heads, depending on which head we are going to work with, and what tells us how to go from one head to the next is the stride of the head dimension, so we multiply the index of the head we should be working with by the head stride. All right.
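A minimal sketch of this offset computation (stride names are assumptions):

```python
# Move the base pointer to the start of the (batch, head) slice this program works on.
qkv_offset = index_batch * stride_Q_batch + index_head * stride_Q_head
```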
Then, to select blocks, Triton helps us with a fairly recent function that lets us index elements inside a tensor without having to deal with all the complex indexing math that can be confusing for beginners, so I will use it to help with this indexing. This function is called make_block_ptr, and it works as follows: make_block_ptr takes as input a pointer (not a vector), and we are saying, starting from this pointer, create a block with a given shape, in this case sequence length by head dimension. Let me go through it one by one, because I don't want to confuse you with all of this at once. There is a pointer that right now points at Q + qkv_offset, so it is no longer pointing at the first batch but exactly at the batch this particular program should be working with, and inside this batch at the particular head this program should be working with. That is basically the same as saying we are pointing to the tensor Q[index_batch, index_head, :, :]: we have selected the right batch and the right head, and we are pointing to the first element of that slice. Because we have already selected the batch and the head, it is a two-dimensional tensor with shape (sequence length, head dim). So we are saying: take this pointer, which contains a tensor of shape (sequence length, head dimension), and here are the strides of those two dimensions — the sequence stride and the head-dim stride of the Q tensor. And inside this query tensor we want to select a block of queries, the block of queries this program should be working with. I think I should probably use the iPad, otherwise it can be very confusing to visualize, so let's do it: let me create another page and let's use the iPad.
All right. So we have a Q tensor — and this construct will be used for all the other tensors, so if you understand it for one tensor you understand it for all the others — of shape batch by number of heads by sequence length by head dimension. With this line, when we write Q + qkv_offset, we are already selecting the right batch and the right head, which means we have already advanced our Q pointer so that it no longer points to the first batch and first head but to the exact batch and the exact head this program is working with. So right now it points at a tensor made up of the remaining two dimensions. Inside this tensor we also need to select the right block of queries this program should work with, and the sequence dimension contains all the queries, so we need to skip some queries. How do we skip them? We say that we need to skip block_index_q multiplied by BLOCK_SIZE_Q queries, because those will be processed by other programs whose program ID is different. So with this line we are selecting not only the right batch and the right head, but also the right position along the sequence dimension: it points to the starting point of the exact query block this particular program should be working with. We are also declaring the block shape, and later we will see how it is used: this block has two dimensions, the sequence dimension and the head-dim dimension, and we are already pointing at the right beginning of the sequence dimension because we have already skipped the queries that will be handled by other programs with a different block_index_q. As for this order argument, honestly I don't know exactly what it does — you can try (0, 1) or (1, 0); I think it is some optimization hint that Triton uses. I read the online documentation and couldn't find anything about it, so it is something I will investigate, but even if you change it, it doesn't seem to matter. I think it tells Triton whether you want the transposed or the non-transposed version of this block, and later we will see how we can transpose the key block without doing any transpose operation, just by swapping the strides, like we have seen before. Now, make_block_ptr is not necessary, but it makes our life easier when we index this pointer: we can treat it almost like a PyTorch tensor, advancing one index along one dimension without computing strides by hand. Later, when doing the backward pass, I will not use it and will do all the pointer indexing by hand, so you can check the difference between indexing a tensor with make_block_ptr and without it. Anyway, to rehearse what we are creating: a pointer to the right batch index, to the right head index, that is already skipping some queries based on block_index_q, so this pointer already points to the right block of queries this particular program should be working with.
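A minimal sketch of this Q block pointer (argument values are assumptions): a 2D (SEQ_LEN, HEAD_DIM) view of the selected batch/head slice, with the offset already skipping the query blocks handled by other programs.

```python
Q_block_ptr = tl.make_block_ptr(
    base=Q + qkv_offset,
    shape=(SEQ_LEN, HEAD_DIM),
    strides=(stride_Q_seq, stride_Q_dim),
    offsets=(block_index_q * BLOCK_SIZE_Q, 0),   # skip other programs' query blocks
    block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
    order=(1, 0),
)
```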
Let's look instead at the V and the K blocks now. The V block pointer is similar to the query one, but we do not go inside the sequence: we only index by the batch and the head. So while for the query we were in the right batch, in the right head, and already skipping block_index_q * BLOCK_SIZE_Q queries, here we are just doing V[index_batch, index_head] and not skipping anything, because, as you can see, the offsets are zero in both dimensions: we skip nothing along the sequence length and nothing along the head dimension.
All right, let's look at the K block pointer, and this is different, because as you know, when computing the flash attention algorithm we need access to the block of queries and all the blocks of keys transposed. So when accessing K we should not access it like we access Q: we should swap the two dimensions we want to transpose, and that is very simple with make_block_ptr, as you can see here. We say we want to point to the right batch and the right head, and for the tensor inside: go to the K tensor, select the right batch, select the right head, select everything inside, so it is a two-dimensional tensor with the sequence length and the head dim. But we don't want first the sequence length and then the head dim; we want first the head dim and then the sequence length — we want to transpose it. How do we transpose it? We just say that you need to read this tensor with the two strides swapped: first use the stride of the head-dim dimension and then the stride of the sequence dimension, and the block shape is not (sequence, head dim) but (head dim, sequence), and it is one block of KVs. Why are we not putting the whole sequence dimension here? Because we want to advance block by block later: we are not selecting the whole sequence length in the sequence dimension, just one block of KVs, and later we will use another method to go to the next block. I hope that by showing you the indexing like this it is a little easier to follow. For each tensor we go to the right batch and the right head, and for the query we also skip some query blocks, because each program works with a different small query block, while for the key and value each program needs to iterate through all of them, so we just point to the first key and value block and then we will advance one block at a time during the for loop we will do later.
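A minimal sketch of the K block pointer (names assumed): the same trick as for Q, but with the strides and the block shape swapped so the block is read already transposed, i.e. as a (HEAD_DIM, BLOCK_SIZE_KV) block, without any explicit transpose operation.

```python
K_block_ptr = tl.make_block_ptr(
    base=K + qkv_offset,
    shape=(HEAD_DIM, SEQ_LEN),
    strides=(stride_K_dim, stride_K_seq),   # strides swapped = transposed view
    offsets=(0, 0),                         # start at the first KV block
    block_shape=(HEAD_DIM, BLOCK_SIZE_KV),
    order=(0, 1),
)
```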
Then, for the output, we also make a block pointer. This basically creates a pointer just like in the query, key and value case: we index by the batch, we index by the head, and we are not selecting everything inside — also in this case we are skipping some blocks of queries, because, as I said before, the output has the same shape as the query. The program with this particular block_index_q will only work with one block of the queries, which produces exactly one block of the output matrix, and we need to select exactly that one, so we point this pointer exactly to where we should start writing: also in this case we skip block_index_q * BLOCK_SIZE_Q rows. So we select exactly the block that this particular program will produce — and when I say this particular program I mean the program identified by this program ID on axis zero and this program ID on axis one, because each of these programs will hopefully run in parallel and each of them has a different value of block_index_q and index_batch_head. Okay, now we have pointed our pointers to the right positions where they should either read some information or write some information. By using make_block_ptr these pointers can also be treated almost like tensors — that's why we specify their shapes — because Triton now provides methods to work directly with such block pointers as if we were accessing tensors, so we can index them like tensors. Basically Triton is doing some stride arithmetic for you so you don't have to do it by hand; later, when we do the backward pass, we will avoid using make_block_ptr and we will see the indexing done by hand.
All right, as you know we are processing a single block of queries, so let's go back to the algorithm, otherwise we lose sight of what we are doing. Let's go to the iPad. As you know, each program parallelizes along the query block dimension, so each program works with a different query block and then does a for loop over all the key and value blocks. Right now we just moved our pointers to the right positions: to the right query block we should work with, and to the beginning of the keys and values blocks we should work with, based on which batch and which head this particular program should be working with. Now that we have pointed our pointers to the right positions inside the big tensors — which have the batch dimension, the number-of-heads dimension, the sequence-length dimension and the head dimension — and because we are pointing to the right batch and the right head, these tensors have effectively become two-dimensional tensors: they only span the sequence length and the head dimension. Now we need some more information that we will use later.
The first piece of information we need is the offsets of each query inside the current block of queries this program works with, given by the following line. The offsets of the queries: how many of them are there? BLOCK_SIZE_Q, because each block of queries is made up of BLOCK_SIZE_Q queries. Each query is a token, and along the head dimension we don't have the whole embedding of the token but only the part of the embedding corresponding to the head this particular program works on. So we are generating the offsets that will load these particular queries from the big tensor containing all queries, and we know that our queries start at position block_index_q * BLOCK_SIZE_Q: if this is program number zero and the block size is four, they will be the queries with index 0, 1, 2 and 3, but if we are program number three we need to skip 3 * 4 = 12 queries, so they will be the queries 12, 13, 14 and 15, and so on. We do the same for the keys and values: the KV offsets are the range of keys and values we need at each iteration, and at the beginning, because our K and V pointers point to the beginning of the key/value sequence for this particular batch and this particular head, we are pointing to the first block of keys and values. We are not skipping anything here — in the query case we skip because our program only works with one single block of queries, but here we need to iterate through all the keys and values, so we point to the first block. Imagine BLOCK_SIZE_KV is four: then these offsets will be 0, 1, 2 and 3. All right.
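A minimal sketch of these offsets (names assumed): the absolute query indices handled by this program, and the initial key/value indices of the first KV block.

```python
offs_q = block_index_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)
offs_kv = tl.arange(0, BLOCK_SIZE_KV)
```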
Now, as you remember, inside the flash attention algorithm we need to compute a block of queries multiplied by the transpose of the keys, and to each of these blocks we need to apply the softmax star — remember, the softmax star is the softmax without the normalization. While computing the softmax star we also compute the normalization factor without applying it, and we apply the normalization factor at the end. So for each block of query multiplied by the transpose of the keys we need the maximum for each row in this particular block and the normalization factor for each row; that is why we need these two statistics. The first is a block of numbers, one per query in our block of queries, each initialized to minus infinity, just like in the algorithm I showed before — you can go back to the slides or check the flash attention algorithm: we initialize m with minus infinities. So far we are creating this m_i; then we will initialize the l_i and the O block, and then we will write the inner loop, exactly as in the algorithm we have seen before. Now we also initialize the l's, with the value you see here, and the O block, which, as we can see from the flash attention algorithm, is initialized with zeros. That is why we initialize a block of zeros: this is the output block that this particular program will compute, based on its position in the batch and its indices, and it is one block of shape (BLOCK_SIZE_Q, HEAD_DIM), so the number of queries in this block by the head dimension — if you want to visualize it, it is one block of the output matrix, one block of rows. So now we have initialized these little things: the output, m_i and l_i, where m_i is the maximum for each row in this particular query block and l_i is the normalization factor for each of the rows in our query block.
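A minimal sketch of these running statistics (names assumed; in the paper the normalizer l starts from zero, and the accumulator is kept in float32):

```python
m_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32) - float("inf")  # running row maxima
l_i = tl.zeros([BLOCK_SIZE_Q], dtype=tl.float32)                 # running normalizers
O_block = tl.zeros([BLOCK_SIZE_Q, HEAD_DIM], dtype=tl.float32)   # output accumulator
```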
Now we need to do the inner loop of the flash attention algorithm. We will create a separate method that runs the inner loop — I am following the same structure as the code you see in the tutorial on the Triton website. Basically, whether we are running causal attention or not, we make this for loop, and then we will make another for loop, and I will show you why; let me write it first and then we will see. This function is the inner loop: it needs to go through all the key and value blocks one by one, and for each of them it needs to fix the previously computed softmax star block. So basically we need a function that iterates over all the KV blocks, computes the query multiplied by the transpose of the keys — using the query block that is fixed for this program and the key block we are currently iterating over — computes the maximum for each row, computes the softmax star (the softmax without the normalization factor), keeps updating the statistic l, which is the normalization factor we will apply at the end of the loop, and at the same time updates the output. As you remember, the output is P11 times V1 plus P12 times V2, but we need to fix the previous P11, so every time we add to the output we first fix the output of the previous iteration and then introduce the P and V block of the current iteration. Here the author of the code you see on the Triton website decided to split this for loop into two steps. Why? Because in causal attention we don't want a query to attend to keys that come after it, while in non-causal attention we let all the queries attend to all the keys. That would mean having some kind of if statement inside this loop over all the keys and values, checking whether this particular query comes before or after the key and value, in case we are doing causal attention. Instead of iterating through all the keys and values even in the causal case, we split it into two steps: first we iterate through all the KV blocks whose index is smaller than the current query block — for these we compute the attention in both the causal and the non-causal case — then, for all the blocks to the right, i.e. those for which the key index is greater than the query index, in the causal case we don't need to compute anything, because after the softmax they become zeros and will not contribute to the output, so we don't even compute them. This is why we split this for loop into two steps: first we iterate over everything to the left of the diagonal of the query-key matrix, and then we skip everything to the right of the diagonal in case we are working with a causal mask; in the non-causal case we compute both the left part and the right part of the diagonal. All right, don't worry, when we write this for loop it will be much clearer; I just wanted to give a little introduction.
So let's go code this inner loop. What will it do? It works with the particular query block we have selected, this Q block — why don't I see the Q block? Because I didn't load it; we actually forgot to load it, so let's load the query block. As you remember, in Triton we load data from the high-bandwidth memory to the SRAM (the shared memory) with the load statement, and we are telling it to load the query block we should be working with: the pointer Q_block_ptr already points to the right block, already skipping all the blocks that other programs should work on, so it will load a tensor of size (BLOCK_SIZE_Q, HEAD_DIM), the right block of queries. We pass it to this inner loop, together with: the output block, so where it should write its output; l_i and m_i, which are the per-row statistics (m_i the maximum for each row of each query, l_i the normalization factor for each query); the query block this program should be working with; the pointers to the beginning of the key and value blocks, because we need to iterate through them inside the inner for loop, so we just point them at the beginning; the softmax scale that we should use when computing the query multiplied by the transpose of the keys; the block sizes, so how many queries we have in each block of Q and how many keys and values in each block of KV; the stage, which tells us whether we are on the left side of the diagonal or on the diagonal, so whether we need to apply the causal mask or not; the offsets offs_q and offs_kv, which are the lists of indices of the queries and keys inside each Q and KV block; and then the whole sequence length, because in the for loop we need to iterate over the entire sequence length, block of KV by block of KV. All right, let's write this method — later we will need to come back and continue this one — so let me go here.
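A minimal sketch of the load just described: the query block is brought from HBM into SRAM once and stays resident while the inner loop iterates over all the KV blocks.

```python
Q_block = tl.load(Q_block_ptr)   # shape (BLOCK_SIZE_Q, HEAD_DIM)
```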
So, this method — we have already seen the signature — is just another kernel, so it can be called by the first kernel; this is something you can also do in CUDA, where one CUDA kernel can call another CUDA kernel. Then, based on the stage of this inner loop, we decide what we need to do. When we are using causal attention we only want to apply attention for queries whose index is greater than or equal to that of the key — we only want the query to never attend to keys and values that come after it — and then we pass the value three for the stage parameter. In the causal case this becomes 4 minus 3, which is equal to 1, so what will happen is that we will only work with the range of keys and values from zero up to the current block of Q, i.e. all the keys whose index is less than the index of the queries we are working with: the left part of the causal mask. Let me draw it, otherwise I think it's going to be very difficult to follow, so let's open a new page and do it.
All right, in this drawing I want you to think of the following matrix, the query multiplied by the transpose of the keys, as a block matrix. Along the rows we have the query blocks — we are not looking at one single block, we are looking at all the blocks now: query block one, query block two, query block three, query block four, each made up of multiple query tokens — and along the columns we have the key blocks: key block one, key block two, key block three, key block four. When calculating causal attention, so with the causal mask, we only want a query to attend to keys that come before it, so when we apply the causal mask everything above the diagonal is made up of zeros. We never have to mask out anything when we are in the blocks strictly below the diagonal, because all the keys in those blocks have an index smaller than the index of the corresponding queries, in case the block size of the query and the key match. Imagine each block of queries is made up of three queries: this is query number 0, 1 and 2; this is query number 3, 4 and 5; this is query number 6, 7 and 8; and this is query number 9, 10 and 11 — in total we have 12 queries. We will have the same indices for the keys if we choose the same block size: this key block is keys 0, 1 and 2; this one is keys 3, 4 and 5; this one is keys 6, 7 and 8, and so on. Now, in the blocks below the diagonal the indices of the keys are always smaller than the indices of the queries, so we don't need to mask out anything even in the case of the causal mask: we are sure that none of those dot products will ever be masked out. Along the diagonal, however, inside each block some of the queries have an index that is bigger than that of the keys and some of them do not, because these are blocks of queries against blocks of keys, so some of the dot products need to be masked out and some don't. So we are dividing our for loop into multiple steps: the first step is everything to the left of this diagonal, where we don't need to mask out anything; then we will see another step here, exactly on the diagonal, where we do need to mask; and everything to the right of this we will not even compute in the causal case, because we already know it is made up of zeros after the softmax. If you look at the flash attention algorithm, the contribution of this part is zero, because we are multiplying zeros with V, so it does not change the output. So why even compute this part of the matrix if we already know it's not going to contribute to the output? We just skip all those iterations, and this is why we are splitting the for loop. I hope now it's much clearer.
All right, so let's go back. So we are to the left part of the diagonal in the case of stage number one; in the case of stage number two it's the part exactly on the diagonal, in which we need to do some dot products and some other dot products we don't need to do; and then for non-causal attention we just go from zero to the sequence length without doing this multi-step, because we don't need to mask out anything. This is why we have this stage: it tells us the lower and higher index of the key blocks that this particular stage should be working with. All right.
Now, this multiple_of function here is just telling Triton that this number is a multiple of that number, so Triton can make some optimizations. Stage one happens when we are doing causal attention — stage number three in the outer function, and 4 minus 3 becomes one — so imagine we are in the causal case: we will go through the key and value blocks that are to the left of the diagonal with respect to the query block we are working with. In case we are doing non-causal attention, in this first call to the inner function the stage will be one, so 4 minus stage will be equal to three, and we will execute that part of the if statement: we will go through all the keys and values. For causal attention only, as you can see here, we will do another iteration that is only done along the diagonal, where we need to mask out some elements and not others, because inside each block there will be some keys with an index below that of the query and some above. So only in the causal case do we call this function twice, the first time with stage equal to one and the second time with stage equal to two, and the second time we will only iterate over the group of KV blocks that lie exactly on the diagonal of the big blockwise matrix, the query multiplied by the transpose of the keys made up of all the blocks. All right, now that this should be clear, let's proceed further.
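A minimal sketch of how the stage selects the KV range (names assumed, following the description above):

```python
if STAGE == 1:
    # causal, blocks strictly to the left of the diagonal: no masking needed
    lo, hi = 0, block_index_q * BLOCK_SIZE_Q
elif STAGE == 2:
    # causal, the block(s) on the diagonal: masking needed
    lo, hi = block_index_q * BLOCK_SIZE_Q, (block_index_q + 1) * BLOCK_SIZE_Q
    lo = tl.multiple_of(lo, BLOCK_SIZE_Q)   # compiler hint
else:
    # non-causal: iterate over the whole key/value sequence
    lo, hi = 0, SEQ_LEN
```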
Because we need to do the inner for loop of flash attention, let's go and load the first blocks of keys and values, which are exactly the ones the K and V block pointers are currently pointing at, the (0, 0) block. So we point the K and V block pointers to the first KV block this for loop should be working with, which depends on the stage: if it's the first call to this function, they will be pointing to the first block, in both the causal and non-causal case; if it's the second call, which only happens in the case of causal attention, they will be pointing exactly to the key and value block on the diagonal. All right, then we need to make the for loop, so let's loop over the keys and values. We let the compiler know that start_kv will always be a multiple of BLOCK_SIZE_KV, because we will be moving from one KV block to the next, block by block. This doesn't change anything from a logic point of view; we are just giving a hint to the compiler so Triton can do some other optimizations. Now, the first thing we see in the flash attention algorithm is that we need to compute the product of the particular block of the query we are working with and the current KV block of this iteration. So let's do it: the query has already been loaded by the caller of this function — we loaded it there — so we load the current block of K, indicated by the K pointer, and we do the matrix multiplication of the current block of queries with the current block of K, which is already transposed, because when we defined the K block pointer we defined it with the strides swapped, so we are reading K already transposed. So we are effectively doing the query multiplied by the transpose of the keys.
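A minimal sketch of this first step of one iteration (names assumed): load the current K block, which is already transposed through its strides, and compute S = Q·Kᵀ for this pair of blocks.

```python
K_block = tl.load(K_block_ptr)               # (HEAD_DIM, BLOCK_SIZE_KV)
QK_block = tl.dot(Q_block, K_block)          # (BLOCK_SIZE_Q, BLOCK_SIZE_KV)
```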
Okay, now let's do this part here, which is basically saying: if the stage is two — and stage two is when we are exactly on the diagonal — we know that some of the queries have an index greater than or equal to that of the keys and some have a smaller index, so we need to apply the causal mask only in this case. So basically we define the mask we should apply: it masks out all the values for which the condition is not true, and the condition is true when the index of the query is greater than or equal to the index of the key/value. Then we apply the softmax scale — as you remember, here we only computed the query multiplied by the transpose of the keys, but we also need to divide by the square root of the head dimension, and we do it here. And because we have already computed the product, we can calculate the maximum for each row, and then we subtract it, because later in the flash attention algorithm we have another operation, the one I call the softmax star, and as you remember the softmax star takes each element of the S matrix, the query multiplied by the transpose of the keys, minus the maximum for its row. So we can already compute the maximum for each row, and before computing it we mask out all the elements that need to be masked in stage number two, which is along the diagonal. How do we mask out? We just replace with minus infinity, before applying the softmax, all the values for which the mask is false. So right now, what have we computed? The query multiplied by the transpose of the keys; we have masked it out in case we need to mask — and we only need to mask along the diagonal, in all the other cases we don't mask anything, we just multiply by the softmax scale — and then we subtract m_ij, the maximum value for each row, because we need to compute the softmax star operation, the softmax without the normalization, which in the flash attention algorithm is exactly the operation that produces P.
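A minimal sketch of this masking and softmax-star step (names and the large negative constant are assumptions; note the running maximum m_i is folded in so the new m_ij never decreases across iterations):

```python
if STAGE == 2:
    # on the diagonal: mask out positions where the key comes after the query
    mask = offs_q[:, None] >= (start_kv + offs_kv[None, :])
    QK_block = QK_block * softmax_scale + tl.where(mask, 0.0, -1.0e6)
    m_ij = tl.maximum(m_i, tl.max(QK_block, 1))
    QK_block -= m_ij[:, None]
else:
    # off the diagonal (or non-causal): no masking, just scale and subtract the max
    m_ij = tl.maximum(m_i, tl.max(QK_block, 1) * softmax_scale)
    QK_block = QK_block * softmax_scale - m_ij[:, None]
```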
Okay, so now we can compute the P block, which is the exponential of the QK block variable, from which we have already subtracted the m — we subtracted m_ij in the previous instruction — so now we can just apply the exponential, and this is what we are doing here. Then we need to compute the row sums that go into the normalization factor: for the current KV block we have a P block, and to compute the normalization factor of the softmax we need to keep summing up these exponentials; later we will fix the normalization factor computed at the previous step, but we will do that in a moment. For now we just compute the normalization term for the current block, which is just the sum of all the values on each single row — the same as what we did before in the algorithm: for each block we do the row sum of the P matrix, where the P matrix is the exponential of S minus M, and for now we haven't applied the correction to the previous block. That's it.
So we computed l_ij for the current KV block, and then we compute the correction factor for the previous block. The correction factor for the previous block, if you remember the formula from the paper, is the exponential of the previous estimate of the maximum minus the current estimate of the maximum, which is exactly this: m_i is the previous estimate and m_ij is the current estimate, because it comes from the current block we are computing. m_i is, let's say, the one of the previous iteration, because later we will overwrite m_i with m_ij; I am just following the flash attention algorithm. So far I am computing the correction factor for the previous l_i — in the flash attention algorithm it is this term here — and then we apply it: the new l_i is the previous l_i multiplied by the correction factor, plus the current l_ij, the one coming from the current P block, the one computed with the current KV block in the current iteration. All right.
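A minimal sketch of the whole online-softmax update just described (names assumed): alpha corrects the statistics accumulated with the old maximum.

```python
P_block = tl.exp(QK_block)                   # softmax* (no normalization yet)
l_ij = tl.sum(P_block, 1)                    # row sums of the current block
alpha = tl.exp(m_i - m_ij)                   # correction for the previous estimate
l_i = l_i * alpha + l_ij                     # running normalization factor
```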
Then, as you remember, the formula says that after calculating the P block we need to multiply it by the V block, so we need to load the V block. So let's load it: we load the V block based on the V block pointer, which at the beginning of this iteration points to the current block — in case we are in stage number three, so for example non-causal attention, it will be pointing to the first V block. Then there is just a type conversion here — we make sure this is in floating point 16 — and then we compute the output block: we take P, multiply it by V and add it to the output, and this is what we are doing here. Let's go through this line by line. First of all we need to fix the previous O block with the correction factor, this alpha term here, which is the correction factor for the previous block; so we have just fixed the previous block, but we haven't added the new P·V yet. To add the new P·V we do the matrix multiplication of P and V, and the third argument tells this matrix multiplication — it's not really a dot product, it's a matrix multiplication — to use this element here as the accumulator. This is exactly the same as doing O_block += P_block multiplied by V_block; it is just optimized, because anyway this dot function needs some place to store the intermediate results, so why not store them directly where they should go? And because a matrix multiplication is just repeated sums of products, this accumulator will keep summing the result into this block, which results in exactly the same thing as doing the matrix multiplication separately and adding it to the O block. That is why this argument is called the accumulator.
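A minimal sketch of this accumulation step (names assumed): rescale the previous output block, then accumulate P·V directly into it.

```python
V_block = tl.load(V_block_ptr)                         # (BLOCK_SIZE_KV, HEAD_DIM)
O_block = O_block * alpha[:, None]                     # fix the previous iterations
O_block = tl.dot(P_block.to(tl.float16), V_block, acc=O_block)  # O_block += P @ V
```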
All right, so we have also computed the output, and then we save the new estimate of the maximum for the current iteration: it becomes m_i, so at the next iteration we can use it to calculate the correction factor. Then we have finished with the current block and can move on to the next one, so we advance our K and V pointers by one block of K and V. We advance them differently, because the V block is a pointer to a tensor of shape (sequence length, head dim), so we increase the sequence dimension by BLOCK_SIZE_KV, while the K block is actually the K-transposed block — it is transposed because we have exchanged the strides and the shape — so its shape is (head dim, sequence length): we don't change the head dimension, we just advance the sequence dimension by BLOCK_SIZE_KV. So basically we are just going to point to the next block of K and to the next block of V.
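A minimal sketch of this end-of-iteration bookkeeping (names assumed): V is (SEQ_LEN, HEAD_DIM), K is read transposed as (HEAD_DIM, SEQ_LEN), so they advance along different axes.

```python
m_i = m_ij                                             # current max becomes "previous"
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_SIZE_KV, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_SIZE_KV))
```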
I hope you were able to follow the algorithm of flash attention. I tried to use the same names and more or less the same logic, always writing the formula I am referring to, so hopefully you didn't get lost. I think the only difference between the flash attention algorithm as written in the paper and this code is probably this alpha, the correction factor, but I hope it's easily understandable. Anyway, then we just return the O block; l_i, which is the normalization factor for each row in the current output block (which is also a Q block, because we work with one Q block independently of the other programs); and m_i, the maximum value for each row, which will be needed for the backward pass, because in the backward pass we will compute the query multiplied by the transpose of the keys block on the fly and we need to also apply the softmax — but instead of recomputing the statistics we already computed during the forward pass, we just save them and reuse them during the backward pass, which will save us some computation.
Now it's time to talk about the log-sum-exp trick, because we are going to use it, so let's go back to the old method. So we have two calls to this function in case we are computing causal attention: we call it once to work with all the KV blocks that are to the left of the diagonal of the query-key matrix, and then we call it again to work only with those blocks of keys that lie exactly on the diagonal of the query-key matrix, because in that case some of the values need to be masked out and some do not. Moreover, by doing this we avoid computing the dot products, in the causal case, for all the positions where the index of the key is higher than the index of the query, saving some computation, because after the softmax they would result in zeros anyway and would not contribute to the output, so it should be faster. Okay, now let's go back to the calling method, and there is one last thing we need to do: we need to compute the log-sum-exp, and now I will show you what it is. In order for the backward pass to recompute the softmax without having to recalculate the normalization factor and the maximum value for each row, we should actually be saving two different things: the maximum for each row in the query block and the normalization factor for each query in the query block. However, there is a trick — okay, it's not really called the log-sum-exp trick, because the log-sum-exp trick is usually used for another purpose, so let's call it log-sum-exp trick number two.
number two so um the logm X trick number
two is something as like this so let me
open the slides so when we do um query
multiplied by the transpose of the keys, we get a matrix made up of dot products, something like this: this is one dot product, Q1·K1ᵀ, then Q1·K2ᵀ, then Q2·K1ᵀ, and Q2·K2ᵀ. Then we need to apply the softmax, so let's write the formula of the softmax for each of these vectors (each row is a vector, because we apply the softmax row by row). For each of these vectors it modifies each element as follows: softmax(x_i) = exp(x_i − x_max) / Σ_{j=1..N} exp(x_j − x_max), where x_max is the maximum of the vector to which we are applying the softmax, the denominator is the normalization factor, and N in this case is two, because each vector is made up of two elements. Now, we already have x_max and we already have this summation: in the flash attention forward pass the summation is called L_i and the maximum is called
M_i. What we are going to save in the code, as you can see here, is actually not M_i and L_i separately: we will save M_i plus the logarithm of L_i. When we compute the backward pass we need to recreate this matrix on the fly, which means recomputing the query multiplied by the transpose of the keys and then applying the softmax. To apply the softmax we would need both the maximum and the normalization factor, but we only have one saved value, M_i + log(L_i). So we define a new softmax, call it softmax₂ so we don't confuse it with the ordinary one: softmax₂(x_i) = exp(x_i − (M_i + log(L_i))). If we expand this expression, using the fact that the exponential of a sum is the product of the exponentials (and the exponential of a difference is the quotient), it becomes exp(x_i − M_i) / exp(log(L_i)) = exp(x_i − M_i) / L_i, which is exactly the softmax with the maximum subtracted and divided by the normalization factor. So instead of saving two values we save only one, and when we apply it, the properties of the exponential automatically take care of the normalization. If you don't remember those properties, they are very simple: exp(a + b) = exp(a)·exp(b) and exp(a − b) = exp(a)/exp(b). That is the trick we are using, and that's why we don't need to save two different values. A small numerical check of this identity is shown right after this paragraph.
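Here is a tiny, self-contained check (my own illustration, not code from the video): we compute the row-wise softmax the standard safe way, then recompute it using only the single saved value m + log(l), and compare.

```python
import torch

torch.manual_seed(0)
S = torch.randn(4, 8)  # one block of query-key scores; each row is softmax-ed independently

# Forward-pass quantities: row maximum m and normalization factor l
m = S.max(dim=-1, keepdim=True).values
l = torch.exp(S - m).sum(dim=-1, keepdim=True)
logsumexp = m + torch.log(l)          # the single per-row value that gets stored

# Reference softmax vs. softmax reconstructed from the saved logsumexp only
P_ref = torch.exp(S - m) / l
P_rec = torch.exp(S - logsumexp)      # exp(x - (m + log l)) = exp(x - m) / l

print(torch.allclose(P_ref, P_rec, atol=1e-6))  # True
```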
All right, let's move forward. So we have also created this value that we will use during the backward pass. Now, as you remember, in the flash attention algorithm we don't normalize each block while computing it; we normalize the output at the end, and this is exactly what we do here: we normalize the block at the end, after we have computed all the normalization factors we need for all the rows that belong to the current output block. Then we save this M_i, which encodes the normalization factor and the maximum for each row that we will need for the backward pass, so we need to store it in a tensor that we will read during the backward pass. We need to understand which tensor this is: it's the tensor we called M, which has dimensions (BATCH_SIZE, NUM_HEADS, SEQ_LEN). So we need to select the right place in this tensor where these M_i values should be saved: we need to select the right batch index and the right head index, so we advance this pointer by the following offset, which is M
plus index_batch_head multiplied by SEQ_LEN (and then offs_q, as we will see in a moment). The index_batch_head is the index of the current program, and it encodes which head and which batch we are working with. For each batch and for each head we have one row of SEQ_LEN values, because each token in the sequence has a maximum value and a normalization value, so based on the current combination of batch and head we can skip the SEQ_LEN entries that other programs will process. Since in this tensor the sequence length is the last dimension and the combined batch–head index is given by program index number one (the index_batch_head we have here), we skip index_batch_head * SEQ_LEN elements; M itself points to the first element of the entire tensor, so this skips over all the heads and batches that other programs handle. Then we add offs_q, because each instance of this kernel (the attention forward method) works with one query block, and each query block covers some specific query indices; offs_q contains how many queries we need to skip because they are processed by other programs, plus the range of queries that this particular block contains. For example, if this program works with the queries from 12 to 16, then offs_q will be 12, 13, 14, 15, and the normalization factor and maximum value we computed exist only for those query indices, so that is why we also offset by the queries this particular program works with, which is already included in the offs_q variable. All right, so now we can store M_i, because we have the pointer to where it should be saved, and we can also store the output that was computed by our inner for loop (a rough sketch follows). And that, guys, is the forward step of flash attention.
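As a rough sketch of that last step (again a kernel fragment; the names M, SEQ_LEN, index_batch_head, offs_q, m_i, O_block and O_block_ptr are my assumptions and may differ from the video's actual code):

```python
# Hedged sketch: store the per-row logsumexp (m_i already has log(l_i) added to it)
# and the normalized output block computed by the inner loop.
m_ptrs = M + index_batch_head * SEQ_LEN + offs_q   # one slot per query row of this program
tl.store(m_ptrs, m_i)
tl.store(O_block_ptr, O_block)                     # the normalized output block
```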
Now we should move on to the backward pass. We already have part of the ingredients for computing it, because we have seen the logsumexp trick, so we know how to use it to recompute the query–key block on the fly during the backward pass. What is missing? Well, we need to understand what the backward pass is and why we even need one, what PyTorch's autograd is and how it works, what a gradient is and how to compute it, and what a Jacobian is and whether we even need to compute it when computing gradients in the backward pass. So we will derive all the formulas of the backward pass by hand; if you are in for the challenge, let's continue. All right, before looking at the flash attention backward-pass algorithm, we need to understand why we even need a backward pass, and before looking at PyTorch's autograd we should review what derivatives, gradients and Jacobians are, so that when we talk about them we don't feel lost. I will do a very fast rehearsal of these topics. Now, what is the
derivative? When you have a function that takes a real value as input and outputs a real value, we talk about derivatives, defined as follows: the derivative of the function with respect to its variable x is the limit, as the step size h goes to zero, of [f(x + h) − f(x)] / h. Intuitively, it is the ratio of how much the output changes to how much the input has changed. This also gives you the intuition of why the derivative is the slope of the tangent line to the function at the point where it is evaluated. I will also use the following notations for the derivative: I am used to writing it as f′(x), but you can also write it as df(x)/dx, or dy/dx where y is the output of the function; they all mean the same thing, namely the definition above. If we rearrange this definition and bring h to the other side, we can also write the following: if we want to evaluate the function at the position x + h, we can approximate it as f(x + h) ≈ f′(x)·h + f(x), the derivative at x multiplied by the step size, plus f(x). This is actually also how the Euler method for solving differential equations is derived, but that's not today's topic. We can also call this h "Δx", so f(x + Δx) ≈ f′(x)·Δx + f(x); it is only approximate because the limit only holds when h is very, very small. You can read it as follows: if x changes by a little amount Δx, then y changes by approximately dy/dx multiplied by Δx. So dy/dx tells us how much y changes for a small change in x, and if we multiply it by the actual change in x, it tells us by how much y will be affected. I don't want to stay too long on this, but I would like to use this intuition to introduce the chain
rule. Imagine we have a function of a function, so z = f(g(x)): we can think of x being mapped to a variable y through the function g, and then y being mapped to the variable z through the function f. If x changes by a little bit, Δx, how much does y change? It changes by Δy, and Δy is the derivative of y with respect to x multiplied by Δx. But if y changes, it also affects z, because there is a direct mapping between y and z. How much does z change for a small change in y? If y moves by a small step Δy, then z changes by some Δz, and Δz is dz/dy multiplied by Δy. If we replace this Δy with the Δy we computed in the expression above, we arrive at the chain rule: the effect on z of a small change in x is the product of the two derivatives, dy/dx and dz/dy. This is the chain rule we study in high school: dz/dx = dz/dy · dy/dx. It is very intuitive if you think about the following example: think of z as the price of a car and x as the price of oil. How much will a small change in the price of oil affect the price of a car? The change in the price of oil affects, for example, a variable y, which could be the price of electricity, and the price of electricity in turn affects the price of the car through the derivative of the price of the car with respect to the price of electricity. To get the effect of the price of oil on the price of the car we just multiply the two effects, and this is the intuition behind the chain rule; a quick numerical sanity check is shown right after this.
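Here is a tiny, self-contained check (my own illustration, not code from the video) that the chain rule dz/dx = dz/dy · dy/dx matches a finite-difference estimate for z = f(g(x)) with g(x) = x² and f(y) = sin(y):

```python
import math

def g(x):   # y = g(x) = x^2, so dy/dx = 2x
    return x * x

def f(y):   # z = f(y) = sin(y), so dz/dy = cos(y)
    return math.sin(y)

x = 1.3
y = g(x)

# Chain rule: dz/dx = dz/dy * dy/dx
chain_rule = math.cos(y) * (2 * x)

# Finite-difference estimate: (z(x + h) - z(x)) / h for a small step h
h = 1e-6
finite_diff = (f(g(x + h)) - f(g(x))) / h

print(chain_rule, finite_diff)  # the two values agree to several decimal places
```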
Anyway, let's talk about gradients. When we have a function that takes a vector as input and produces a scalar, we no longer talk about derivatives; we talk about gradients. Imagine a function that takes as input a vector of two dimensions (or n dimensions in general) and produces a scalar. When do we deal with this kind of function? Loss functions, for example: a loss function always produces a scalar as output, and as input it takes tensors. Think of the cross-entropy loss: it takes a sequence of tokens, each with its own logits, and computes one single number, the loss. So how do we view the effect on the output with respect to the input in this case? Well, if x changes by a little amount, this little amount is no longer a number but a vector, so the new x is the old x plus Δx, a vector sum. Then y is affected by dy/dx multiplied by Δx; but since Δx is a vector (x1 changes a little, x2 changes a little, x3 changes a little, and so on up to xn), this is actually a dot product of the gradient vector with the Δx vector. Why a dot product? Because y is affected by the change in x1, by the change in x2, by the change in x3, up to xn: the contribution of x1 is the partial derivative of y with respect to x1 multiplied by how much x1 changed, the contribution of x2 is the partial derivative of y with respect to x2 multiplied by how much x2 changed, and so on until the last contribution, of xn. The chain rule in this case applies in exactly the same way as in the scalar case, so its formula does not change. I just want to remind you that in this case we are talking about a gradient, and the gradient is just the vector made up of all the partial derivatives of the output with respect to each of the input variables in the input vector; written out, this is the expression shown below.
that have as input a vector and produces
a vector then we don't talk about
gradient anymore we talk about jacobians
so if our input X the input X of this
function changes by a little amount and
this Delta X is a vector then the output
y will also change and this output y
will change by a Delta y that is not a
number anymore it is a vector and this
Vector is the result of this quantity Dy
on DX multiplied by Delta X Delta X is a
vector so this one to be a vector it has
this one here has to be a matrix and
this Matrix is called the Jacobian it is
a matrix that has as many rows
later we will talk about the denotations
so it has as many rows as there are
output variables and as many columns as
there are input variables the first row
will be the partial derivative of the
first output variable with respect to
all the input variables the second row
will be the partial derivative of the
second output variable with respect to
all the input variables and the last row
will be the partial derivatives of the
last output uh variable with respect to
all the input variable in the input
Vector um now let's talk about notations
The Jacobian I have written here is written according to the numerator layout; this is called the numerator convention, and there is another convention, called the denominator convention (or notation), in which the number of rows is not equal to the number of output variables but to the number of input variables. So the fact that we write the Jacobian this way is just a convention: you can also write it in the denominator convention simply by transposing it, and the formula for the chain rule then changes accordingly. For now I want to keep the chain-rule formula looking like the scalar case, which is why I am using this notation; later we can switch between one notation and the other just by doing a transposition. Okay, now that we have
reviewed what a derivative, a gradient, and a Jacobian are, let's talk about what happens when we take derivatives with respect to tensors, that is, of a tensor with respect to another tensor. In this case we still talk about a Jacobian, but it's called the generalized Jacobian. If we have a function whose input is a tensor with d_x dimensions, where the shape is (n_1, n_2, ..., n_{d_x}), and it produces an output tensor of shape (m_1, m_2, ..., m_{d_y}), the formula for the chain rule doesn't change: if x changes by a little amount Δx, which is a tensor, then y is affected by dy/dx multiplied by Δx, and this is a tensor product with a generalized Jacobian whose shape is all the dimensions of the output followed by all the dimensions of the input. This is very abstract for now, but we will see a concrete case, because we will derive the gradient of the loss, when computing the backward pass, with respect to each input of a matrix multiplication, and we will do the same for the softmax and for the attention. I don't want to jump between too many topics; I just wanted to get us into the right mindset, so: derivatives when we have scalar functions, gradients when the output is a scalar and the input is a vector, Jacobians when the input and output are both vectors, generalized Jacobians when the input and output are tensors, and the chain rule always works in the same way. All right, let's talk about autograd. I will do the scalar case and
then we will extend it to the tensor
case so imagine we have a very simple
computation graph why we have
computation graph because we are talking
about neural networks and neural
networks are nothing more than
computation graphs where we have some
inputs, we have some parameters, and we do some operations with this input
and these parameters. Suppose you have an input a, and this input a is multiplied by a parameter weight w1 (just a scalar), producing an output y1; this y1 is then summed with another number, b1, producing y2; this y2 is then raised to the power of two (the z² node just squares its input), producing y3; and this y3 is our loss function, so it's a scalar. Now, to apply gradient descent we want to compute the gradient of the loss function with respect to each of the inputs of this computation graph, that is, each of its leaves; the leaves are the parameter nodes and the input nodes. There are two ways to do that. One: if you have access to the expression that directly relates the input to the output (the loss), you can compute the derivative directly (in this case it's a derivative, not a gradient, because it's scalar versus scalar). For example, to compute the derivative of the loss with respect to w1, if we have the exact expression relating w1 to φ (our loss), we just differentiate that expression with respect to w1: since it is the square of a function, it is two times that function times the derivative of the inside of the function with respect to the variable we are differentiating with respect to, which gives the expression shown. The other way is to use the chain rule: the derivative of φ with respect to w1 is the derivative of φ with respect to y3 (the output of the last node), multiplied by the derivative of y3 with respect to y2, multiplied by the derivative of y2 with respect to y1, multiplied by the derivative of y1 with respect to w1. If we do this whole chain of multiplications we obtain the same result, and you can see that this expression is exactly equal to the one computed directly. By doing this procedure we notice something (let me zoom out a little bit): to compute the derivative of φ with respect to w1 we are doing this chain of multiplications, but what is each factor in the sequence? This part is nothing more than the derivative of φ with respect to y2, these multiplications are nothing more than the derivative of φ with respect to y1, and all of them combined are the derivative of φ with respect to w1. What PyTorch will do is the following: PyTorch
will run the backward pass, because PyTorch knows the computation graph that relates the output (the loss function, in this case) and the variable for which we want to compute the gradient. Right now we are talking about derivatives, not gradients, but the mechanism is exactly the same. PyTorch is like a person that knocks on the door of this operation and says: hey, "power of two" operation, if I give you the gradient of the loss with respect to y3 (which is one, because the loss and y3 are the same thing), can you give me the gradient of the loss with respect to y2? This is needed because PyTorch's autograd does not do symbolic mathematics: it does not know the symbolic expression that produced the output, it just knows which functions computed it, and each function is a class in Python that implements two methods, a forward step and a backward step. The forward step takes the input (in this case y2) and computes the output (y3); the backward step takes the gradient of the loss with respect to its output and needs to compute the gradient of the loss with respect to its input. A minimal sketch of such a class is shown right after this.
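Here is a minimal, self-contained sketch (my own illustration, not the video's code) of such a class for the "power of two" node: the forward pass computes y3 = y2², and the backward pass receives dLoss/dy3 and returns dLoss/dy2 via the chain rule.

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, y2):
        ctx.save_for_backward(y2)      # remember the input for the backward step
        return y2 * y2                 # y3 = y2^2

    @staticmethod
    def backward(ctx, grad_y3):        # grad_y3 = dLoss/dy3, handed to us by autograd
        (y2,) = ctx.saved_tensors
        return grad_y3 * 2 * y2        # chain rule: dLoss/dy2 = dLoss/dy3 * dy3/dy2

# Usage: behaves like an ordinary differentiable operation.
y2 = torch.tensor(3.0, requires_grad=True)
y3 = Square.apply(y2)
y3.backward()                          # dLoss/dy3 = 1, since y3 is the "loss" here
print(y2.grad)                         # tensor(6.) = 2 * y2
```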
How can we do that? Well, it's very simple. Let me copy this here, otherwise it's not easy to go back and forth; okay, let's paste it here. PyTorch will knock on the door of this function and say: hey, if I give you the gradient of the loss function with respect to your output, can you give me the gradient of the loss function with respect to your input? Yes, the function can do it, thanks to the chain rule: this operator just takes the gradient of the loss with respect to its output, multiplies it by the Jacobian (or, in this case, the derivative) of its output with respect to its input, and the result is the gradient of the loss with respect to its input. Then PyTorch will take this one
and knock the door at the next operator
which is this one this summation and
we'll say hey if I give you the gradient
of the loss with respect to your output
can you give me the gradient of the loss
with respect to your input yes this
operator can do it because this operator
just needs to apply the chain rule so it
will take the gradient of the loss with
respect to y2, which is provided by PyTorch, and multiplying it by the Jacobian (in this case just the derivative) of its output with respect to its input, it can compute the gradient of the loss with respect to its input. Then PyTorch
will take this output of this backward
pass and will knock the door of the next
operator which is this product and we
ask again the same question hey if I
give you the gradient of the loss with
respect to your output can you give me
the gradient of the loss with respect to
your input yes this will do the same
exact job it will take the gradient of
the loss with respect to the output
multiplied by the Jacobian of the output
with respect to the input and obtain the
gradient of the loss with respect to the
input. And this is how PyTorch runs the backward step: it runs one operator at a time, backwards through the computation graph, knocking on the door of each operator and always asking the same question: if I give you the gradient of the loss with respect to your output, can you give me the gradient of the loss with respect to your input? Each operator just applies the chain rule to calculate the gradient that PyTorch needs. Why can't PyTorch do it by itself? Because PyTorch does not do symbolic mathematics: it does not have access to the exact expression each function computes, it just uses the function as a black box with a forward and a backward. However, with the Jacobian we have a problem; let's see what it is. All right, so up to now we have been working with a computation graph made up of scalars, but the things we have said work both in the scalar case and in the tensor case. So let's go back and look at our computation graph. We have seen that
asking always the same question if I
give you the gradient of the loss with
respect to your output can you compute
me the gradient of the loss with respect
to your input and each operator can just
apply the chain rule to compute that uh
imagine now that all of these operators
are working not with scolars but are
working with tensors which means that
the derivative of the output with
respect to the input of each operator is
not a derivative it will be a Jacobian
because the output will be a tensor a
generalized Jacobian and input will be a
tensor which means also that this
quantity here so the derivative of the
loss with respect to the input in this
case will not be a derivative it will be
a gradient because the output the loss
is a number always while the input in
this case y1 will be a tensor so number
output input is a tensor then we talk
about gradients so this will be a
gradient the and we will call it the
downstream gradient that the operator
needs to compute this will be the
Upstream gradient that pytorch will give
to the each of these operators so the
gradient of the loss with respect to the
output of each operator and each
operator needs to come up with this
Downstream gradient by using the
Jacobian however the Jacobian has a
problem let's see so uh imagine we are
implementing a simple operation that is
the matrix multiplication and the matrix
multiplication is takes as input X
tensor it multiplies it by a w Matrix
made up of parameters and produces a y
Matrix as output suppose that X is let's
call it n by D Matrix W is uh let's say
d by m
Matrix and so y will be a n by uh M
Matrix usually the input X is a sequence
of T of let's say vectors each of each
with d Dimensions so you can think of it
as a sequence of tokens each token is a
vector made up of the dimensions usually
we have many tokens so suppose that n
usually is at least
1,24 at least in the most recent
language models we even have millions of
tokens actually so uh ND is also
actually quite big it usually it is at
least 1,24 also so also this one
is24 um D and M M is also at least 1024
so we can actually become 22 2048 let's
say so I I like the powers of two by the
way so the problem of the Jacobian is
this if we compute want to compute this
Downstream gradient by multiplying the
Upstream gradient with the Jacobian this
Jacobian matrix is huge because look at
the dimensions here this will be a
matrix that
is um it will be well n by
m multiplied so it will be a generalized
Jacobian so it will be a tensor that has
a shape n m uh and then the input is X
so it is n by D so how many elements it
will have well it will have
1,24 multiplied by m which is
248 multip by 1024 multi by D which is
1,24 so it is at least wow it's billions
more than 1 billion
elements so it is impossible actually to
materialize this Matrix here in the
memory because in the ram of the GPU
because it will be too
big so but we need to compute this down
screen gradient because pytorch needs it
to continue calculating the gradient uh
of the loss function with respect to
each of the nodes in the computation
graph so how can we proceed the first
thing that we should notice is that this
G ve this uh Jacobian is actually a
sparse Matrix and I want to show you why
it is a actually is a super super super
sparse Matrix because um if if you look
at the input what is the effect of the
input on the output the input is a
sequence of
tokens so this is the token number one
it's a vector of some Dimensions 1,24
dimensions then we have another token as
input then we have another tokens as
input then we have another tokens as
input and we multiply by the W Matrix
which is made up of some columns uh some
uh columns so this one is n by D
right yes and W is uh d by m so d by m
this will produce a matrix that is n by
m so it will be also a sequence of
tokens each made up of M Dimensions so
it will be a matrix like this so this
will be the first output token this will
be the second output token this will be
the third output token and this will be
the fourth output
token now this output row here
is the dotproduct of this input row with
all the columns so the derivative of
each of these Dimensions with respect to
the dimensions of all the other tokens
will be zero because they do not
contribute to this output so the
Jacobian will have zeros every time the
we are calculating the derivative of
this First Dimension with respect to any
other element of other tokens that's why
we always can come up with a better
formula for computing this Downstream
gradient that does not involve the
materialization of the Jacobian because
the M the Jacobian itself is sparse so
let's see how we can optimize this uh
computation without materializing the
Jacobian in the case of matrix
multiplication because we need it for
flesh
attention all right guys so before
proceeding to the backward uh watch the
formulas of the backward path of the
flesh attention uh let's look at how to
compute the gradient of the matrix
multiplication operation with respect to
its inputs. Okay, PyTorch already knows how to compute the gradient of the loss with respect to the inputs of a matrix multiplication, but in flash attention we are creating a custom kernel, and the custom kernel fuses multiple operations into one. So when PyTorch knocks on the door of our operator, the Triton attention operator we have built, it will ask for the gradient of the loss function with respect to Q, K and V, because those are the inputs of our function. If we look at the code we have built so far, our Triton attention is a node in the computation graph that takes Q, K and V as input and produces an output; PyTorch will give us the gradient of the loss with respect to that output, dO, and will then ask this class (Triton attention) to compute the gradient of the loss with respect to Q, K and V. Because we are fusing multiple operations together (we compute on the fly the softmax of the query multiplied by the transpose of the key, and then multiply it by V to get the output), we need to compute these gradients internally. And since among the fused operations there is a matrix multiplication, we need to derive by hand the gradient of the loss function with respect to the inputs of the matrix multiplication operation, so that we can provide it to PyTorch. That's why we need to derive this formula. I will derive it in a very simple way, and then we will do the same for the softmax, because these are the two things we need to derive by hand to get the formulas of the flash attention backward pass. So let's start: imagine we have a node in the computation graph called "matrix multiplication", and this node computes the operation Y = X·W. Now, what will PyTorch give us as
input when computing the backward pass of this node? PyTorch will give us the gradient of the loss with respect to Y, the output of this node, and it will ask us to compute the gradient of the loss function with respect to X and the gradient of the loss function with respect to W. The one with respect to X is the easier one to work with and the one I will show; the other I will not show in the video, but I will attach the PDF slides on how it is computed, because the two derivations are very similar and I don't want to make the video too long for unnecessary reasons. Let's compute the gradient of the loss function with respect to the input, that is, with respect to X. All right, how do we do that by hand without materializing the Jacobian? As we have seen, we cannot just apply the chain rule by materializing the Jacobian (which would be the easiest way), because the Jacobian is a very big matrix that cannot even fit in the memory of the GPU. We need a smarter way: we exploit the fact that the Jacobian is sparse, so hopefully we will get a formula that does not involve materializing a very big, sparse Jacobian. Let's see.
When dealing with this kind of derivation I always recommend making some example tensors. So suppose X is a tensor of size N by D, where N = 1 and D = 3, and W is a matrix of shape D by M, where M = 4; as a consequence Y has shape N by M, that is, 1 by 4. What PyTorch will give us is the gradient of the loss function with respect to the output of this operator, which is Y, so it will give us a tensor of dimension N by M, and we need to compute the gradient of the loss function with respect to X, which should be a tensor of shape N by D, because a gradient always has the shape of the input variable: it is a scalar output differentiated with respect to each element of the input, so it has the same shape as the
denominator. All right, when dealing with these kinds of problems I always recommend creating example matrices, working out what happens to the output, and then trying to work out the gradient matrix. So let's do it. How is the output computed? The output is a 1 by 4 matrix, computed as follows: the input is 1 by 3, so let's call it X = [x11, x12, x13], and it is multiplied by the matrix W of dimension 3 by 4 (three rows by four columns), with entries w11, w12, w13, w14 in the first row, w21, w22, w23, w24 in the second row, and w31, w32, w33, w34 in the third row. If we do this matrix multiplication (one row by three columns times three rows by four columns), the output is a matrix with one row and four columns:
y11 = x11·w11 + x12·w21 + x13·w31,
y12 = x11·w12 + x12·w22 + x13·w32,
y13 = x11·w13 + x12·w23 + x13·w33,
y14 = x11·w14 + x12·w24 + x13·w34.
This is the output Y of the matrix multiplication. What PyTorch will give us is the gradient of the loss, dφ/dY; because it is a gradient it has the same shape as the denominator, so its shape is 1 by 4. We don't know what these values will be (they are provided to us by PyTorch), so let's just give them generic names: dy11, dy12, dy13 and dy14, like
this. Now, to compute the downstream gradient that we need to provide to PyTorch, we would have to materialize the Jacobian. Okay, let's write the chain-rule formula: we need to provide dφ/dX, which is equal to dφ/dY (provided by PyTorch) multiplied by the Jacobian, dY/dX. Instead of giving up on this Jacobian right away, let's materialize it now and do the multiplication of these two quantities, to see if something simplifies. dY/dX means the derivative of every output element with respect to every input element; we have four output elements and three input elements in the X matrix, so it is a 4 by 3 matrix with the following entries (let me copy things over, because my screen is not big enough). The first row is the derivative of y11 with respect to all the inputs: y11 contains x11 only multiplied by w11, so the derivative with respect to x11 is w11; with respect to x12 it is w21; with respect to x13 it is w31. The second row is the partial derivatives of the second output, y12, with respect to all the inputs, which are w12, w22 and w32 (let me check that what I'm doing is correct; yes, I've already done it, so I can always double-check). Then we have the partial derivatives of y13 with respect to all the inputs, which are w13, w23 and w33, and the partial derivatives of the last output, y14, with respect to all the inputs, which are w14, w24 and w34. We obtain the following Jacobian, and as you can see it is just equal to W transposed. So we don't need to
materialize the Jacobian: we can just take whatever gradient PyTorch gives us, multiply it by W transposed, and we get the downstream gradient. Let me rewrite it so we know what we are doing: dφ/dX = dφ/dY · dY/dX, but we have seen that dY/dX is just Wᵀ, so dφ/dX = dφ/dY · Wᵀ, and this gives us the downstream gradient. So in order to provide the downstream gradient that PyTorch needs, we just take whatever gradient PyTorch gives us, multiply it by Wᵀ, and that is the gradient of the loss function with respect to the input X of the matrix multiplication. In the same way we can also arrive at the formula for the gradient of the loss function with respect to W, which is equal to Xᵀ multiplied by dφ/dY.
How do you remember these formulas? There is a mnemonic rule: these are the only possible ways for the results to have the shape of X and the shape of W. The upstream gradient dφ/dY has the same shape as Y, so it is N by M; Wᵀ is the transpose of W, which is D by M, so Wᵀ is M by D; and the result of this matrix (or tensor) multiplication is N by D, which is exactly the shape of X. In the other case, Xᵀ is the transpose of X, and X is N by D, so Xᵀ is D by N; dφ/dY is a gradient, so it has the same shape as the denominator, N by M; and the output has shape D by M, which is exactly the shape of W. If you want to remember the formulas, this is the only way the shapes work out; otherwise they don't. So this is a mnemonic for how to compute the gradients of the inputs of a matrix multiplication, given the gradient of the loss with respect to the output of the matrix multiplication, where the inputs are the input matrix X and the parameter matrix W. A quick numerical check against PyTorch's autograd is shown below.
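To sanity-check the two formulas numerically (my own illustration, not code from the video), we can compare them against what PyTorch's autograd computes for Y = X·W:

```python
import torch

torch.manual_seed(0)
N, D, M = 2, 3, 4
X = torch.randn(N, D, requires_grad=True)
W = torch.randn(D, M, requires_grad=True)

Y = X @ W
dY = torch.randn(N, M)       # pretend this upstream gradient dL/dY comes from PyTorch
Y.backward(dY)               # autograd fills in X.grad = dL/dX and W.grad = dL/dW

# Our hand-derived formulas
dX_manual = dY @ W.T         # dL/dX = dL/dY @ W^T  -> shape (N, D)
dW_manual = X.T @ dY         # dL/dW = X^T @ dL/dY  -> shape (D, M)

print(torch.allclose(X.grad, dX_manual))  # True
print(torch.allclose(W.grad, dW_manual))  # True
```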
Now we need to derive the gradient of the output of the softmax with respect to the input of the softmax, because that's another operation we perform in our fused attention: we are fusing several operations together, namely matrix multiplications and the softmax. This is the second ingredient we need to understand the backward pass of flash attention, so let's do it. To make this derivation I
will use the same notation as in the
flash attention paper. First of all, let's write the title of this part, which is "the gradient through the softmax". The first operation we do during the computation of the attention is the product of the query with the transpose of the keys; we do it block-wise, block by block, but that doesn't matter because the end result is the same, so we can write S = Q·Kᵀ. Then we apply the softmax to the result of this operation and call the output P, so P = softmax(S). After we have applied the softmax, we take its output and multiply it by V to obtain the output, so O = P·V. Now, as I said before, PyTorch's autograd works in the following way: PyTorch will treat our attention computation as a black box, so we will have a computation graph like the following:
we will have a query input a key input
and a value input which are sequences of
tokens each one with some embedding
dimension. These are fed into the black box we call "attention", which is our implementation of the attention, the function we started coding before. They are the inputs to this node in the computation graph, and the computation graph outputs an output tensor O.
What will PyTorch give us? PyTorch will give us the gradient of the loss with respect to the output. As you remember, PyTorch knocks on the door of each operator and says: if I give you the gradient of the loss with respect to your output, can you give me the gradient of the loss with respect to your inputs? And this is what we need to figure out: given the gradient of the loss with respect to the output, we need to understand how to compute the gradient of the loss with respect to Q, the gradient of the loss with respect to K, and the gradient of the loss with respect to V. However, there is no direct connection between Q and O, or K and O, because there are intermediate operations: first a matrix multiplication, then a softmax, then an additional matrix multiplication. We do have a tool that lets us understand how the gradient propagates through multiple operations applied in sequence, the chain rule, but we have seen that applying the chain rule naively, by materializing the Jacobian, is infeasible. So we need to understand how to apply the chain rule without materializing the Jacobian, and that's what we are going to figure out for one of the operations inside this attention computation: the softmax. That's why we are doing this derivation, which I promise is the last one, and then we will finally code the backward pass of flash attention. We cannot proceed directly to coding it, because if we looked at the formulas for how it is computed, we would not understand where the derivation comes from. Okay, now we can
start. So let me delete this stuff, and imagine for simplicity that we apply the softmax row-wise to this S matrix, so each row is softmaxed independently of the others. Let's see what happens to one single row of this matrix; let's call it S_i, one row of the S matrix. In PyTorch tensor notation it would be S[i, :], taking the i-th row and all the columns. I know it's ugly notation, but it helps you understand; this is a vector of N dimensions. We apply the softmax to this vector and obtain an output vector we call P_i, so P_i = softmax(S_i). As we have seen, the softmax operation does not change the shape of its input, it just modifies each number element-wise, so the output is also a vector in R^N. Now, what is the softmax? It is defined as follows: the j-th element of the P_i vector, P_ij, is equal to the exponential of the j-th element of the S_i vector, divided by a normalization factor, which is the sum (let's use the index l, not j or k) for l = 1 up to N of the exponential of S_il; the formula is written out below.
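Written out (my own rendering of the spoken formula):

```latex
P_{ij} \;=\; \mathrm{softmax}(S_i)_j \;=\; \frac{e^{S_{ij}}}{\displaystyle\sum_{l=1}^{N} e^{S_{il}}}
```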
All right, first of all you may be wondering: the softmax we apply during the forward pass of the attention computation is not really this softmax, because, if you remember, we were applying the softmax where each argument of the exponential is reduced by the maximum element of the vector to which we apply it, so it was more or less exp(S_ij − S_i,max), with the argument in the denominator reduced by S_i,max as well. However, we also proved that this is equivalent to the standard softmax without that reduction in the argument: the reduction is only added to make the computation numerically safe, and from a mathematical point of view it is the same thing (on a computer, of course, the unreduced version can become numerically unstable). This also means that it doesn't matter how you compute the forward pass: if it's equivalent to another mathematical definition, you can always use the other definition to compute the backward pass, and it will give the same value. If that wasn't clear, here is a simpler example. Do you remember the formula from high school, cos²(x) + sin²(x) = 1? Imagine we compute an output y = cos²(x) and then need the derivative of y with respect to x: it doesn't matter whether you compute it as the derivative of cos²(x) or as the derivative of 1 − sin²(x), because they give exactly the same result, since the two definitions are equivalent. That's why we don't need to carry the extra factor in the exponential through the derivation: the two definitions are mathematically equivalent, and we just use the numerically safe one on the computer because we need something stable that will not overflow. A quick numerical check of this equivalence is below.
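A tiny, self-contained check (my own illustration) that subtracting the row maximum does not change the softmax; it only makes it safe to compute:

```python
import torch

torch.manual_seed(0)
s = torch.randn(5) * 10          # moderate values: the naive version is still finite here

naive = torch.exp(s) / torch.exp(s).sum()              # softmax without max subtraction
m = s.max()
safe = torch.exp(s - m) / torch.exp(s - m).sum()       # numerically safe softmax

print(torch.allclose(naive, safe))   # True: same values, the subtraction only prevents overflow
# With s = torch.tensor([1000., 1001.]) the naive version overflows to nan,
# while the safe version still returns [0.2689, 0.7311].
```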
All right, now what do we want to obtain? We want the gradient of the loss with respect to the input vector of the softmax, which is the S_i vector, given the gradient of the loss with respect to the output of the softmax, which is the P_i vector. We can obtain it with the chain rule, by multiplying that upstream gradient by the Jacobian of P_i with respect to S_i. The chain rule is always valid;
let's see what this Jacobian looks like. This Jacobian is dP_i/dS_i, and we will look at what each of its elements looks like: the j-th element of the output with respect to the k-th element of the input. That is what a Jacobian is: each element of the numerator differentiated with respect to each element of the denominator of this fraction, in other words each element of the output vector differentiated with respect to each element of the input vector. How is the output obtained? By the definition of the softmax, P_ij = e^{S_ij} divided by the normalization factor, the sum for l = 1 to N of e^{S_il}, and we differentiate this with respect to S_ik. To make it concrete, suppose the P vector has three elements, P11, P12 and P13, and the S vector also has three elements, S11, S12 and S13. The first row of the Jacobian is the derivative of P11 with respect to all the input elements, the second row is the derivative of P12 with respect to each input element, and the third row is the derivative of P13 with respect to each input element of the S vector. We are trying to understand what the generic element of this Jacobian looks like, in terms of the j-th element of the output vector (the index j refers to the output) and the k-th element of the input
vector. All right, what can happen when we compute this Jacobian? Each entry is the derivative of a fraction of two functions, and we know from high school that the derivative of a quotient is (f(x)/g(x))′ = (f′(x)·g(x) − g′(x)·f(x)) / g(x)². Now let's apply it here. There are two cases: either the variable we are differentiating with respect to, S_ik, has the same index as the element being differentiated, so we are computing the derivative of P11 with respect to S11, or it has a different index, like P11 with respect to S12 or S13. So there are two cases to consider. Suppose first that we are differentiating P11 with respect to S11, or P12 with respect to S12, or P13 with respect to S13, that is, the element of the output with respect to the element of the input with the same index. In this case the derivative looks like the following: it is the derivative of f, the numerator, with respect to the variable that has the same index, so we are saying that in this case j is equal to k.
The derivative of the numerator e^{S_ij} with respect to S_ij is e^{S_ij} itself (just like the derivative of e^{x_1} with respect to x_1 is e^{x_1}). So, applying the quotient rule, we get e^{S_ij} multiplied by the denominator of the fraction, the summation over all l of e^{S_il}, minus the derivative of the denominator with respect to the variable we are differentiating by. The denominator is the sum of the exponentials of all the input elements, so when we differentiate it with respect to one particular input element only the one term containing that element survives and all the others give zero; the surviving term is the derivative of e^{S_ik} with respect to S_ik, which is e^{S_ik}. So we write minus e^{S_ik} multiplied by the numerator e^{S_ij}, all divided by the denominator squared, (Σ_{l=1..N} e^{S_il})². The two terms in the numerator have the factor e^{S_ij} in common, so we can collect it: e^{S_ij}·(Σ_l e^{S_il} − e^{S_ik}), divided by the squared denominator (let me copy and paste it, and rotate it, because I always write too small). Because the denominator is squared, we can split this into two factors: e^{S_ij} divided by Σ_{l=1..N} e^{S_il}, multiplied by (Σ_{l=1..N} e^{S_il} − e^{S_ik}) divided by the same Σ_{l=1..N} e^{S_il}. Now, the first factor is nothing more than the output element P_ij, because it is just the softmax applied to the S_ij element (and we called the output of the softmax applied to S_i the vector P_i). The second factor equals 1 minus e^{S_ik} divided by the same denominator, and that ratio is the softmax applied to the S_ik element, that is P_ik. So the whole thing is equal to P_ij·(1 − P_ik). And this is the case in which the variable we differentiate with respect to has the same index as the numerator of this derivative. The
other case is when the two variables so
the the output the index of the output
with respect to the index of the input
are not the same in this case we will
have another case so we will have that
J uh let me write it again so this stuff
here hope I can copy it all
without in the other case in which s is
not equal to
J uh not s it's j not equal to K so J is
not equal to K what
happens in this case? Well, we need to apply again this formula here. The derivative of the numerator with respect to something that is not the same variable will be zero, because it's like computing the derivative of e to the power of x_1 with respect to x_2, which is zero. So the whole first term here will become zero, no matter what g(x) is. Then minus the derivative of the denominator of this fraction with respect to the variable s_ik, so g prime of s_ik: the denominator sums over all the input variables, and we are deriving it with respect to one particular variable of the input, so only one item in the summation survives, the one for s_ik, so it will be e to the power of s_ik, multiplied by f(x), which is the numerator of this fraction, which is e to the power of—oh, we forgot a minus—minus e to the power of s_ij (let me see if I forgot something), all divided by the denominator of this fraction to the power of two, so the summation over l = 1 up to N of e to the power of s_il, all to the power of two. I believe I didn't forget anything, so let's continue. Here also
anything so let's continue so here also
we can see that this one here is because
okay let's separate it minus E to the^
of s i k divided by the summation l = 1
up to n of e^ of s i l multiplied by e^
of s i j divided by the summation l = 1
up to n of e to the power of s i
l this stuff here is nothing more than
the soft Max appli to the K element of
the SI Vector this one here is nothing
more than the softmax applied to the J
element of the s I Vector so we know
what these are we know that we call them
P minus p i k p i
So in the end we have two cases. One is when the numerator and the variable we derive with respect to have the same index, so j equal to k; in that case each item of the Jacobian looks like the following (this notation here is a bit loose, I shouldn't be writing it with the equal sign, but it doesn't matter, guys): p_ij multiplied by 1 minus p_ik. The other case is when j is not equal to k; then it will be equal to minus p_ik multiplied by p_ij.
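As a compact summary of the two cases just derived (my notation, not from the video; δ is the Kronecker delta):

```latex
\frac{\partial p_{ij}}{\partial s_{ik}} =
\begin{cases}
p_{ij}\,(1 - p_{ik}) & \text{if } j = k \\[4pt]
-\,p_{ij}\,p_{ik}    & \text{if } j \neq k
\end{cases}
\;=\; p_{ij}\left(\delta_{jk} - p_{ik}\right)
```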
Now that we know what the two typical cases of this Jacobian look like, let's actually look at what this Jacobian looks like in matrix form. So this Jacobian will look like the following: it will be a matrix that is
more or less like the following it will
be an n byn Matrix where n is the size
of the input vector and the output
vector and here the first element of the
Jacobian as you saw as you remember uh
the first row of the Jacobian in the
numerator convention is the uh
derivative of the first output with
respect to all the input so this first
term here will be the derivative of P11
with respect to
S11 so in this case J and K match so we
know that it will be equal to P1 1 * 1 -
P11 the second element to the right of
this one so the element one two will be
uh the derivative of p12 with respect to
uh sorry the P11 with respect to S12 the
J and K do not match so we will be in
this case here so it will be minus
P11
p12 the third element you can check it
by yourself it will be minus P1 1 p13
blah blah blah until the end which will
be - P11
p1n the second row of this Jacobian will
be uh will look like this so it will be
the derivative of p12 with respect to
S11 the J and K do not match so we are
in this case here so it will be minus P1
2 P11
then the next element it will be the
derivative of p12 with respect to S12 so
J and K match so we are in the first
case so it will be P1 2 * 1 minus
p12. Then the third element will be minus p12 multiplied by p13, blah blah blah, until we arrive at the last one, which is minus p12 multiplied by p1n (not "with respect to", I mean multiplied by). And all the elements continue like this until
the last row the last row will be the
the first element of the last row will
be the derivative of the uh last output
element with respect to the first input
element so it will be the derivative of
p1n with respect to S11 so um the two
indices do not match so we are in the
second case so it will be minus P1 n
P11 this will be minus P1 n p12 etc etc
etc let me do also the third element
since we are here so minus P1 n P1 3 etc
etc etc until the last element of the
last row, which will be minus p1n multiplied by p1n, I guess—oh no, that's wrong guys, because the two indices match, so it should be p1n multiplied by 1 minus p1n. This is what the Jacobian will look
like let's see if we can find a better
um uh how to generate this Jacobian with
some pattern
recognition let's write it in a
different way first of all the thing
first thing that we can notice is that
this Jacobian is symmetric so you can
see that this element is equal to this
element if you expand the third row you
will see that it's equal to this element
this one on the top right corner is
equal to the one in the top bottom left
corner um so this Matrix is
symmetric the second thing that we can
notice is that only the element in the
diagonal are different they have an
additional term because you can look at
this element here so let me write this
element here can also be written as
P11 minus P11 * by P11 the second
element here in the second uh row so
second diagonal element of this Matrix
is p12 minus p12 * p12 so this the
element on the diagonal actually look
like just like the other elements they
just have an additional
term which is P11 in the first diagonal
element p12 in the second diagonal
element so we can also say that this
Matrix here is the product of all the
possible combinations of P J with p i
k which you we can obtain with an outer
product or even with the product of one
column with the transpose of the same
column so if you do one column Vector
for example imagine p is a column vector
and you do p multiplied by PT you obtain
all the possible combinations of
products of these two vectors because
this will be one
I can do a simple case so P11 P1 let's
call it P2
P3 uh multiplied by the row
Vector P1 P2 P3 this will generate all
the possible uh combinations of products
between P1 and the P the first vector
and the second Vector because this will
be a 3X one this is 1 by3 so it will be
generated 3x3 vector and it will be uh
equal to P1 P1
uh P1 P2 P1 P2 P1 P3 etc etc etc
moreover, we can see that on the diagonal of the matrix we have this additional term: p_i1 in the first diagonal element, p_i2 in the second diagonal element, p_i3 in the third diagonal element. Actually, calling them p_1, p_2 and so on is sloppy, because I should call them p_i1, p_i2, etc.—that's why I didn't want to carry the i index around—since we are doing this for the generic i-th row, the p_i vector. So let me fix the indices accordingly.
Okay, so we can also write this Jacobian as the diagonal matrix that has on its diagonal all the elements of the p_i vector, minus the p_i vector multiplied by the transpose of itself, because we need all the elements to be a combination of one element of p with another element of p, plus, only on the diagonal, this additional term, which are the elements of p; and all the elements of the output of this p multiplied by p transposed are negated, that's why we need the minus sign.
So if you look at the Flash Attention paper, they give you this formula: they say that if y is equal to the softmax of x, then the Jacobian will be diag(y) minus y y transposed, where y is a column vector.
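A quick numerical sanity check of this result (a sketch I added, not code from the video): build the Jacobian entry by entry from the two derived cases, check that it is symmetric, that it equals diag(p) - p pᵀ, and that it matches PyTorch's autograd Jacobian.

```python
import torch

x = torch.randn(5, dtype=torch.float64)
p = torch.softmax(x, dim=-1)

# build the Jacobian entry by entry from the two cases derived above
J = torch.empty(5, 5, dtype=torch.float64)
for j in range(5):
    for k in range(5):
        J[j, k] = p[j] * (1 - p[k]) if j == k else -p[j] * p[k]

assert torch.allclose(J, J.T)                                # symmetric
assert torch.allclose(J, torch.diag(p) - torch.outer(p, p))  # diag(p) - p p^T

# cross-check against PyTorch's autograd Jacobian of softmax
J_autograd = torch.autograd.functional.jacobian(
    lambda t: torch.softmax(t, dim=-1), x
)
assert torch.allclose(J, J_autograd)
```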
All right guys, I know this has been long, so let's take a pause. First, let's check the mathematics of the backward pass of Flash Attention; we will see it briefly, I will not do any more derivations, but I will explain it, and then we finally switch to coding it. So let's go. All right guys, now finally we can see the backward pass of Flash Attention. We will be looking at the algorithm, and if you look at the appendix of the Flash Attention paper, you will see this part B.2 where they derive the backward pass step by step.
Now, I don't want to do all the steps of this derivation, because it's going to be too long, but I want to give you all the tools necessary to understand it. Let's start from the conventions and notation they use in this paper. The first thing we need to rehearse is the naming of each matrix: as you know, in the forward pass we do the query multiplied by the transpose of the key, and the output of this we call S; then we apply the softmax to this S matrix and it becomes the P matrix (the softmax is applied by rows); then we take this P matrix and multiply it by the V matrix to obtain the output of the attention.
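For reference, a minimal sketch of this naming on a single (batch, head) slice (my toy shapes; the scale factor is included only for completeness, it is discussed later in the video):

```python
import torch

SEQ_LEN, HEAD_DIM = 8, 4
Q = torch.randn(SEQ_LEN, HEAD_DIM)
K = torch.randn(SEQ_LEN, HEAD_DIM)
V = torch.randn(SEQ_LEN, HEAD_DIM)

softmax_scale = HEAD_DIM ** -0.5
S = (Q @ K.T) * softmax_scale      # (SEQ_LEN, SEQ_LEN)
P = torch.softmax(S, dim=-1)       # softmax applied row by row
O = P @ V                          # (SEQ_LEN, HEAD_DIM), the attention output
```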
Let's look, for example, at how the i-th row of the output is computed based on the P matrix and the V matrix, so we can understand the notation they are using in the paper. The way I read this formula is: the i-th row of the output, written as a column vector—because in linear algebra, whenever we write the name of a vector it is by convention a column vector, even though the origin of this particular vector is actually a row of the output matrix. Let's try to understand what an output row of a matrix multiplication is, so that we can understand how to go from here to here.
um so let's write a generic matrix
multiplication for example an a matrix
let's say that it is the
following and we only write one row
actually let me Zoom again and I want to
write smaller so we have enough space so
we make a matrix that has a row let's
call it A1 A2
A3 and then we multiply this will be a
matrix with many rows like the this one
because we want to study the effect only
of one row and we multiply it by another
Matrix let's call it this one is the
Matrix a and it has I don't know let's
say n rows by three columns then we
should have another Matrix B with three
columns and some number of um three rows
and some number of column let's say four
columns so we call the first uh row
let's call
it let me Zoom more so it's
b11
B12 B1 3
b14; then the next row should be B21,
B22
b23
B24 this should be b31
b32 uh
b33 B3 4 Etc I know I am not very
rigorous in my notation I should have
called all these elements with the
capital letter a and the capital letter
B so this is the notation that you use
when referring to single item of a a
matrix but please forgive me for this so
the output of this matrix multiplication
will be another Matrix that is n
by4 so it will be n by 4 so we will have
four columns for each uh row of the
output I want to write the output in a
different way so I want to write it as
follows as a vector only so the first
output row as a vector and want to
understand what is is each dimension of
this Vector so because otherwise I don't
have enough space to write it here so
the uh first uh let's write it so let's
call it o I want to write what is O of
one which is the first row of the output
but written as a column Vector so o of
one will be here we should use the small
letter O of one should be a vector where
the First Dimension is the dot product
of this stuff here so the first row of
the a matrix with the First Column of
the B Matrix so the first uh let's say
Dimension will be A1 with
b11 uh I should also call this one a11
a12
actually and a13 so A1
3 uh because we have many rows in the a
matrix so let me use the correct naming
so this will be a11 with b11 a11
b11 plus
a12 * by
B21 plus
a13 with
b31 and this will be the first dimension
of the first row of the output Matrix
o the second dimension of the first row
of the output Matrix o will be the
dotproduct of this row of the a matrix
with the second column of the B Matrix
and let me write here B so it will be
a11
B12 plus
a12
B22 plus
a13 b
32 the third dimension will be a11
B1
3 plus A1 2 B2
3 plus A1 3
b33 the fourth dimension will be
a11 uh
b14 + A1 2 b 2 4 plus A1 3 B 3
4 now this is the output the first
output row of the O Matrix and it's a
vector called o1 and these are this is
the first dimension of this Vector this
is the second this one is the third and
this is the fourth dimension and each of
this stuff here is one
scalar um so the
output o1 which is the first row of the
output Matrix can also be written
as the first element as you can see in
is a sum of many vectors where the first
element is
a11 Multiplied let me use a smaller uh
this one but I want to use a smaller I
can't change the size
here okay doesn't matter so as you can
see here there is A1 multiplying a
different B number every time so this is
b11 B12 B13 b14 what is b11 B12 B13 b14
it is the first row of the B Matrix so
it is equal to B uh 1 and all the
dimensions of the first row then plus
then we have the element
a12 multiplied by B21 B22 b23 Etc and
this is uh the second row of the B
matrix, so we use the tensor notation of PyTorch to describe this row, which is B[2, :], i.e. row two of B with all its dimensions. So this is a scalar-vector product. Plus a13 multiplied by B[3, :], i.e. row three of B with all its dimensions. This can also be written
as a summation over an index that goes from 1 to 3, where 3 is how many columns there are in the A matrix—let's call this index j. The generic i-th row of the output matrix will be a_i1, a_i2 and a_i3, each one multiplied by the corresponding row of the B matrix, so we can write it as the sum over j of a_ij multiplied by b_j, where b_j is the j-th row of B, and we can write it like this to indicate that this is a vector. And this
output in the output Matrix when we do
the multiplication P multipli by V the E
row of the output Matrix we call it oi
which is a vector but by notation it is
a column Vector where the elements of
this column Vector are actually the
elements of the E row of O uh this is
only by notation guys uh is equal to the
E row of P so the E row of the Matrix
that is on the left in the matrix
multiplication multiply by all the
columns of the V Matrix which can also
be written as the summation over all the
elements of the I row of P so all the
elements of the I row of the first
Matrix the one on the left in the matrix
multiplication multiplied by each Vector
in the v Matrix where the J Matrix here
in V is each row of the V Matrix
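A tiny sketch of this identity (my toy tensors, not from the video): the i-th row of P @ V equals the sum of the rows of V weighted by the i-th row of P.

```python
import torch

# O[i] = sum_j P[i, j] * V[j]  -- row i of the output as a weighted sum of rows of V
P = torch.randn(4, 6)
V = torch.randn(6, 8)

i = 2
row_via_matmul = (P @ V)[i]
row_via_sum = sum(P[i, j] * V[j] for j in range(V.shape[0]))

assert torch.allclose(row_via_matmul, row_via_sum)
```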
And P_ij can also be written in terms of what it is: the output of the softmax, so e to the power of the softmax input, and the input of the softmax is the query multiplied by the transpose of the keys, i.e. a dot product between one query and one key—that's why you have this dot product inside the exponential. So this is the first step in understanding this derivation. Another
thing that we have studied so far is how to derive the backward pass of the matrix multiplication and of the softmax, so now let's use it. Let's rehearse the formula for the matrix multiplication: given a matrix multiplication Y = X multiplied by W, and given the gradient of the loss function with respect to Y (the output of this operation), we know how to derive the gradient of the loss with respect to either input, X or W. To get the gradient with respect to X, we take the upstream gradient (the gradient with respect to the output) and multiply it by the transpose of W; to get the gradient with respect to W, we take X transposed multiplied by the upstream gradient. One of these is the formula that we didn't derive and the other is the formula that we derived, but how to derive them is exactly the same procedure.
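A quick autograd check of those two formulas (a sketch I added; shapes are arbitrary):

```python
import torch

# For Y = X @ W with upstream gradient dY:  dX = dY @ W.T  and  dW = X.T @ dY
X = torch.randn(3, 4, requires_grad=True)
W = torch.randn(4, 5, requires_grad=True)
Y = X @ W

dY = torch.randn_like(Y)   # upstream gradient dL/dY
Y.backward(dY)

assert torch.allclose(X.grad, dY @ W.T)
assert torch.allclose(W.grad, X.T @ dY)
```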
In attention, the last product that we do is O equal to P multiplied by V. What PyTorch will give us as input during the backward
pass is the gradient of the loss with
respect to the output and we need to use
this gradient of the loss with respect
to the output of the attention to derive
the gradient of the loss with respect to
Q with respect to K and with respect to
V so that it can then be used by the
operators in the backward pass in the in
computation graph in the operations
before okay so but in order to arrive to
the gradient with respect to query key
and value we need to derive the gradient
with respect to each intermediate operation. The last operation that we do is O equal to P multiplied by V, so computing the gradient of the loss with respect to V, given the gradient of the loss with respect to O, is exactly like computing the gradient of the loss with respect to W in a matrix multiplication, and we know that it is equal to P transposed multiplied by the upstream gradient. So just by analogy, guys—this is our reference point and I am just changing the names here, and you should understand what the analogy is. So the
gradient of the loss with respect to V, which is the matrix on the right, is like computing it with respect to W, so just like this formula here: the transpose of the matrix on the left multiplied by the upstream gradient, which in the paper they write as dV equal to P transposed multiplied by dO, and it's the formula that you can see here. The other derivation is the gradient with respect to P: dP is just like deriving the gradient of the loss with respect to the matrix that is on the left side of the matrix multiplication, so it is just like deriving the gradient of the loss with respect to X in the reference formulas, which is equal to the upstream gradient multiplied by the transpose of the other matrix; in the notation of the paper they write it as dP equal to dO multiplied by V transposed, and it's this formula here. How they compute this stuff here is exactly as above, so as in this derivation here: they call v_j the j-th row of the V matrix and they write dv_j as the sum over i of P_ij multiplied by dO_i. How
let's do it
so let me write let's see
okay theoretically we know that from
this derivation here so from this
derivation here or from this derivation
here we know that the I row of the
output in a matrix multiplication first
of all let's simplify our life every
time you see a transposed and you don't
like work with the transposed in a
matrix multiplication just give it a
different name and then work with the
different name and after when you have
derived the formula you resubstitute the
transpose operation. In this case we are doing dV equal to P transposed multiplied by dO; let's give P transposed a name that we didn't use so far, let's call it F (I always use F when it's available). So we call dV equal to F dO. We know from the derivation above that the j-th row of the output of a matrix multiplication, dv_j, is equal to a summation over the elements of the j-th row of the first matrix—here they do the sum over i, so let's do it over i: it's the sum over all possible i of the i-th element in the j-th row of the first matrix, so F_ji (this is a scalar, so it's a scalar-vector multiplication), multiplied by a vector which, according to the formula above, is the i-th row of the other matrix, so dO_i. But we also know that F is not a matrix that we actually have: it's the transpose of P, which means that F_ji is equal to P_ij, because in a matrix transposition you invert the two indices. So this is the summation over all possible i of P, not ji but ij, multiplied by dO_i, and this should be equal to the same formula that you see on the right here. This allows you to compute one output row of dV.
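A short sketch of exactly that row formula (my toy tensors): dV = Pᵀ @ dO, and row j of dV is the sum over i of P[i, j] · dO[i].

```python
import torch

P = torch.randn(4, 4)    # softmax output, shape (seq, seq)
dO = torch.randn(4, 8)   # upstream gradient w.r.t. O, shape (seq, head_dim)

dV = P.T @ dO
j = 1
row_via_sum = sum(P[i, j] * dO[i] for i in range(P.shape[0]))

assert torch.allclose(dV[j], row_via_sum)
```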
Okay, and we know that P is just the output of the softmax, which is the exponential of the input of the softmax divided by the normalization factor associated with that row; since we are iterating through the rows indexed by i, it is the normalization factor associated with the i-th row. So we know that P is equal to the softmax of S, the i-th row of P is the softmax of the i-th row of S, and this is what is written here. We know from our derivation
that the Jacobian of the softmax operation—so if we have an input x and the output y of the softmax—the Jacobian of y with respect to x is equal to diag(y), a diagonal matrix of the elements of the vector y, minus y multiplied by y transposed, and we have also seen before that this matrix is
from our uh in the chain rule we always
write it like this we always write that
the downstream gradient so the D um uh
Fe of let's say um TX
should be equal to the uh Upstream
gradient so d f with respect to Dy
multiplied by Dy and um with respect to
DX this only works if you make this
Matrix here as in the numerator
convention the numerator convention is
one of the two convention in which you
can create a Jacobian we so far we have
always written it as the numerator
convention if you use the Ator
convention this is a row vector and this
is a row Vector however if you want to
treat this stuff here as a column Vector
then you need to take the transposed or
you need to make the Jacobian in the
denominator
convention how to get this formula here
because this formula here is basically
doing the Jacobi and multiply by the
Upstream um uh gradient not the gradient
Upstream gradient multiplied by the
Jacobian and it's only because here we
treat it as a column vector and when you
do the you want to transform a row
Vector into a column Vector you take the
transpose of both sides of the equation
and let's do it actually so we apply the
transpose to the both side of the
equation okay um in a matrix
multiplication, if you do (A B) transposed, it becomes B transposed multiplied by A transposed: the transpose is applied independently to each input of the matrix multiplication, but we invert the order, because matrix multiplication is not commutative. So what we do here is: the dΦ/dx, which here they call dS_i, if you treat it as a column vector, will be equal to dy/dx as a Jacobian (in denominator layout in this case) multiplied by dΦ/dy as a column vector. This one is a column vector, this is a column vector, and this is what you see here; that's why the Jacobian is on the left side of the upstream gradient.
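A numerical check of that per-row backward-through-softmax formula (a sketch I added): for p_i = softmax(s_i), the column-vector form is ds_i = (diag(p_i) - p_i p_iᵀ) dp_i, and since this Jacobian is symmetric, no extra transpose is needed.

```python
import torch

s = torch.randn(7, dtype=torch.float64, requires_grad=True)
p = torch.softmax(s, dim=-1)

dp = torch.randn_like(p)        # upstream gradient dL/dp
p.backward(dp)

ds_closed_form = (torch.diag(p.detach()) - torch.outer(p.detach(), p.detach())) @ dp
assert torch.allclose(s.grad, ds_closed_form)
```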
What else do we need? Well, I know there is a lot going on in this derivation, but I prefer to go directly to the code, otherwise I think it's going to be too boring. So let's go to the code, and
while writing the code I go back to the
formulas in which we can find the
association of what we are doing and the
formula in the paper I think this is the
best way. So let's proceed further. All right guys, now we can finally code the backward pass. Before we code it, let's look at the algorithm of the backward pass as written in the paper (this is the Flash Attention 1 paper). We will follow the structure of the code that is present on the Triton website, so it's not my idea to split it like this, but I simplified it, so it's different from the one that you can find online, because mine is a simplified version and mine works with causal and non-causal attention. First, if you look at this algorithm, you can see that we have an outer loop through all the K and V blocks and an inner loop through all the query blocks. However, to compute dQ, which is the downstream gradient of the loss with respect to the Q matrix, we need an iteration through all the Ks, and to compute each dK block we need an iteration through all the Qs. So if we followed the loop as it is, it would involve writing to the high-bandwidth memory, so to the DRAM of the GPU, at every inner iteration, and that is not so
efficient. Also, if we don't want to write, it would require some sort of synchronization between blocks, which is also not very efficient. So we will split this loop into two parts, because we can see that each dQ depends on a loop over the Ks and each dK depends on a loop over all the Qs. To compute dK we will fix the K block and iterate through all the Q blocks; then we will do another iteration in which we fix the Q block and iterate through all the KV blocks to compute dQ. This is what we are going to follow, and this is an idea that I took from the original implementation that is present on the Triton website. Another thing that we can notice here is that to compute a dQ vector and a dK vector we need this piece of information here called D_i, and it's shared between the two, so we can precompute it and then reuse it to compute the dQ vector and the dK vector. What is this D_i? D_i is introduced here, and it's the dot product of the dO_i vector with the O_i vector. So the first thing that we will do is a loop over all the vectors in O and dO and compute their dot products to get this D_i element; then we will use this D_i element to compute dQ and dK, and we will also have another two loops, one in which we fix Q and iterate through all the keys, and one in which we fix the keys and iterate through all the queries. So let's start.
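To make the two-pass split concrete, here is a runnable structural toy I added (not the real kernels): it only covers the S = Q Kᵀ matmul piece, ignoring the softmax and the online rescaling, and just illustrates that each output block is accumulated locally and written exactly once.

```python
import torch

SEQ, DIM, BLOCK = 8, 4, 2
Q = torch.randn(SEQ, DIM)
K = torch.randn(SEQ, DIM)
dS = torch.randn(SEQ, SEQ)          # pretend upstream gradient of S = Q @ K.T

dK = torch.zeros_like(K)
dQ = torch.zeros_like(Q)

# Pass 1: fix a K block (parallel over programs in the real kernel), loop over Q blocks.
for kv in range(0, SEQ, BLOCK):
    acc = torch.zeros(BLOCK, DIM)
    for q in range(0, SEQ, BLOCK):
        acc += dS[q:q+BLOCK, kv:kv+BLOCK].T @ Q[q:q+BLOCK]
    dK[kv:kv+BLOCK] = acc            # single write per fixed block, no synchronization

# Pass 2: fix a Q block, loop over K blocks.
for q in range(0, SEQ, BLOCK):
    acc = torch.zeros(BLOCK, DIM)
    for kv in range(0, SEQ, BLOCK):
        acc += dS[q:q+BLOCK, kv:kv+BLOCK] @ K[kv:kv+BLOCK]
    dQ[q:q+BLOCK] = acc

assert torch.allclose(dK, dS.T @ Q)  # dK = dS^T @ Q for S = Q @ K^T
assert torch.allclose(dQ, dS @ K)    # dQ = dS @ K
```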
Now that we know more or less the structure of the code, we start by writing this backward function
here. Let me check, yeah, okay. Do you remember this saved_tensors? These are all the pieces of information that we save during the forward pass to compute the backward pass. To optimize memory utilization in Flash Attention we don't save the query multiplied by the transpose of the key matrix, because that would be a sequence-by-sequence matrix, too big to save into the HBM (the DRAM) during the forward pass and then get back from the HBM into the local memory. I want to remind you that in Triton, compared to CUDA, what we do is load stuff from the high-bandwidth memory into the shared memory (the SRAM), do all the operations there, and then, when we call the store method, save the elements from the shared memory into the high-bandwidth memory. So materializing this S matrix in its entirety, saving it to the HBM and then getting it back, would be very slow and, secondly, very expensive: right now we are computing attention on thousands and thousands of tokens, so imagine saving a matrix that is 5,000 by 5,000 for each batch and for each head—that would be really too expensive. The idea in Flash Attention is to recompute what we can on the fly during the backward pass, because if we were to load it, it would be memory-IO bound anyway, so it's faster to recompute than to save it and restore it from memory. This is the idea of Flash Attention. Okay, so we saved some stuff during the forward pass and now we can access it back during the backward pass; this stuff is saved in the context, which is a kind of dictionary made available by PyTorch. All right, so we get back the
query, key and values, and as you know, PyTorch during autograd will just give us the gradient of the loss with respect to the output of our implementation of the attention (this TritonAttention), and then we need to compute dQ, dK and dV using only the gradient of the loss with respect to the output. We do some checks; here, I know I could optimize this code and make it even smaller, for example by checking the strides that I am using—inside the code I always pretend that the strides are the same—but it doesn't matter, I just took the code from Triton and tried to simplify it. My goal was to simplify it, not optimize it.
vectors the tensors in which we will
store the result of this backward pass
which is the DQ DK and DV and as you
know from what we have seen of the
definition of the gradient the size of
the output of the gradient Vector is the
size of the uh Vector with respect to
which we calculate the gradient because
in the numerator is always a scalar and
we compute the gradient with respect to
all the elements in the input Vector so
the output the gradient itself is a
vector of the same size of the element
by which we compute the gradient with
respect to. So we got some information on the batch size, blah blah blah, and later we will see what this number of warps and number of stages is; I will not explain it now. The number of warps is an indication of how many threads we want to use for each program in our grid, and the number of stages is the number of stages used in software pipelining; we will see later what software pipelining is when we talk about the auto-tuning. Then we define some
blocks. In the original code I think they call them block kv1, kv2, q1 and q2; I think it was confusing, so I call them block macro and block micro, because of the things that we will fix and the things that we will iterate over: once it's the query, so we fix the query block and iterate through all the keys, and then we fix the keys and values block and iterate through the queries. The one that we iterate over is the micro one and the one that we fix is the macro one; this is the naming that I am using. Then, as I said before, we
need to precompute the D_i elements that we saw in the paper before, so that's the first kernel that we are going to launch, and this kernel will have its own launch grid, because later we want to tune it with respect to its own parameters. So, the first kernel that we are going to launch is this pre-process kernel. This preprocess kernel will precompute all the D_i elements that we need to compute dQ and dK, and this D_i element depends only on O and dO. So let's do it: let's create another function called backward pre-process. What is the pre-process grid? It's the launch grid of this kernel, and it will be launched independently for each batch and for each head; moreover, it will work with a block of vectors of O. What is this number of vectors of O? It is the block size macro, so 128 vectors of O. Let me copy the signature of this function; it's here, so let's write it here, that's fine. Okay, this function takes the
matrix O—so it's a pointer to the matrix O—a pointer to dO, and a pointer to the matrix D where we will store these D_i elements, and we have one for each vector in the output. That's why the shape of this D is batch size, number of heads, sequence length: there is one for each output element of the attention. This D, where is it—actually it's not this one, it's this one, yeah—like M, so it has the same shape as M, which as you can see is this size here, so batch size, number of heads and sequence length. M, if you remember, is the matrix that we saved during the forward pass, which includes the normalization factor of the softmax and also the maximum element, but in log-sum-exp format, so that when we apply it, it will automatically apply the maximum element for each row and also normalize at the same time, which I think I proved previously.
So let me do it; we write it like this. We extract the index of this program: this program has two identifiers (this is equivalent to the CUDA block identifier), and this one is along axis zero. Let's see what we launched on axis zero: on axis zero of this launch grid we defined which block of vectors of O this particular program will work with, and the second axis is which batch and which head inside of each batch this particular program will work with. So this identifies the block index of Q, so which group of vectors in the O matrix this particular program will work with; it is called Q here, I believe, because I copied it from the original code where they call it Q, but I could also have called it O. So basically this means that for this program we need to skip some query vectors that have been or will be processed by other programs in parallel, so we will only work with a block of query vectors inside of O that have the following
indices. So imagine that the query block size is 128 the way we have defined it, but suppose it's 4 for simplicity, and suppose the query vectors are sequence length in total, let's say 64. If the first 32 are managed by other programs, then this particular offs_q will be equal to 32, 33, 34 and 35 (zero-based). This tells me which query vectors, or which vectors in the output O matrix among all the vectors in O, this particular program is going to work with.
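The same index arithmetic in a tiny sketch (my toy numbers, NumPy instead of Triton): the offsets are block index times block size plus a range.

```python
import numpy as np

BLOCK_SIZE_Q = 4
block_index_q = 8   # this program handles the 9th block of 4 rows

offs_q = block_index_q * BLOCK_SIZE_Q + np.arange(BLOCK_SIZE_Q)
print(offs_q)       # [32 33 34 35]  -> rows of O handled by this program
```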
Okay, so then we also extract the index of the batch, which tells us which batch and which head inside of each batch this particular program is going to work with, and this is dimension one of our launch grid.
And then we define the offsets of the dimension, because we need to load all the dimensions of each vector. So this is a vector that tells which dimensions we need to load from each vector, and we will load all of them: we don't split the load on the head dimension among multiple programs, we only split on the sequence length dimension. You will see in this part of the video, when we are writing the backward pass, that we will not be using make_block_ptr like we did during the forward pass; in this function we will work directly with indexing by using the strides. So let's do it: let's load a single block of rows
of O, which I want to remind you has the same shape as Q, and that's why we can call it block size Q. The load function accepts a pointer to what it should load—actually not a single pointer, it accepts an array of pointers, or a multi-dimensional array of pointers in case you want to load multi-dimensional data. In this case we are going to load two-dimensional data, which is a block of rows of O, so a tensor of shape block size Q by head dimension. But we need to tell it where in this O matrix it should find this block: first of all we need to skip some batches and some heads, based on which head and which batch will be processed by other programs. So, based on the batch-head index that this program will process, we need to skip all the other batches and heads.
Let's write the shape of this tensor: the O tensor has shape batch size, number of heads, sequence length and head dimension. Each batch and each head will have sequence length multiplied by head dim number of items, so based on our index we skip index multiplied by head dimension multiplied by sequence length items. What I mean is this: batch zero, head zero will have sequence length multiplied by head dimension items; batch zero, head one will have the same number of items; and batch zero, head two will also have the same number of items. So the number of items (sequence length multiplied by head dimension) that we need to skip from the start of the O tensor is equal to
the index of the current batch-head indicator, because this index indicates both the batch and the head inside of each batch (it's already a combination of the head and the batch), so it tells how many we skip. After we point to the starting point of the current batch and the current head, we need to select a two-dimensional tensor where the offsets for the rows are indicated by offs_q—that's why we have this one, with the None/colon indexing that takes all these row offsets and adds an additional dimension for the columns—and the columns will be offs_dim. So basically this will select a tensor of the following shape inside of this big tensor that includes batch size and number of heads. This is what we are doing: we are saying, select a tensor of this size inside of one that is made up of four dimensions, by skipping the elements of all the batches and heads that will be processed by other programs. I always talk in terms of programs, because in Triton these are called programs; in CUDA you would refer to them as thread blocks. All right, so this one is done; I hope it is reasonably clear.
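A NumPy sketch of this pointer arithmetic (my toy shapes; the flat offsets are compared against plain 4D indexing to show they select the same tile):

```python
import numpy as np

BATCH, HEADS, SEQ_LEN, HEAD_DIM = 2, 3, 8, 4
O = np.arange(BATCH * HEADS * SEQ_LEN * HEAD_DIM, dtype=np.float32)
O = O.reshape(BATCH, HEADS, SEQ_LEN, HEAD_DIM)

index_batch_head = 4            # combined batch*HEADS + head (here batch 1, head 1)
offs_q = 2 + np.arange(3)       # rows 2, 3, 4 of the sequence
offs_dim = np.arange(HEAD_DIM)  # all head dimensions

flat = O.reshape(-1)
base = index_batch_head * SEQ_LEN * HEAD_DIM                    # skip previous (batch, head) slices
ptrs = base + offs_q[:, None] * HEAD_DIM + offs_dim[None, :]    # (3, HEAD_DIM) flat offsets

tile = flat[ptrs]
assert np.array_equal(tile, O[1, 1, 2:5, :])
```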
So then we also load a single block of D in the same way, because we are going to load a group of vectors from
the sequence length from D as well, and D has the same shape as O, which has the same shape as Q, and that's why we can use the same block index that we call Q—it's equivalent, because they have the same shape. Okay, and how do we compute this D_i element? Well, it's written in the paper: if we go here, it shows you how to compute D_i given a block of dO and a block of O. It tells you that D_i is the row sum, which means the sum by rows: for each row, for each vector in the O matrix, we will have one sum of the element-wise product. So this stuff here is the element-wise product of dO_i multiplied by O_i; it's not a matrix multiplication, it's an element-wise
product, which means each element of one matrix is multiplied with the corresponding element of the second matrix, and the output shape will be the same as the two matrices, which must have the same shape.
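A minimal sketch of this precomputation (my toy tensors): D_i is the row-wise sum of dO ⊙ O, one scalar per output row.

```python
import torch

O = torch.randn(16, 64)    # (seq_len, head_dim)
dO = torch.randn_like(O)   # upstream gradient w.r.t. O

D = (dO * O).sum(dim=-1)   # shape (seq_len,), one value per row

# equivalently, each D[i] is the dot product of row i of dO with row i of O
assert torch.allclose(D[0], torch.dot(dO[0], O[0]))
```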
Okay, so we compute this D_i block, which will have shape block size Q, because we will have one sum for each vector. Then we need to store it
somewhere, so we need to calculate where to store it inside of the D matrix. Well, the D matrix, if I remember correctly, has the same shape as M, so it should be batch size, number of heads and sequence length. So we need to select the right batch and the right head, and also the right position inside of the sequence length, based on the block index Q that we have. Okay, so let me index it. Again, just like before, we know that D is of this size: each batch and each head will have sequence length number of elements, so the number of elements we need to skip from the start of the tensor is sequence length multiplied by the combined batch-head index. Plus, we also need to skip some queries based on our block index Q, and this skipping is already done inside of offs_q, so we add offs_q, and
then, once we have computed the index where we should store this D_i block, let's store it. (I think this naming was already in the original code, but this is D_i, and this big matrix D is the matrix that includes all the D_i values, one for each token in the sequence length.) All right, so the pre-processing
has been done now we need to do prepare
the two for Loops as you remember I said
before we will be doing two for Loops
one in which we fix the query and we
iterate through all the keys and values
and one in which we fix the key and
value block and we iterate through all
the queries and while coding it I will
always show you the formula from the
paper so don't worry let's start with
the next iteration. First we create the launch grid for the next iteration; the launch grid is always the same, because we need to keep one block fixed and iterate through all the other blocks. For the block that we keep fixed, we define how many programs run in parallel, and the block that is fixed has block size macro number of elements; that's why we create sequence length divided by block size macro number of blocks—thread blocks, or programs—on this axis. The second axis in this grid (I could also have arranged the axes differently, but I think it was already done like this in the original code) will indicate which batch and which head inside of each batch we are going to work with. And just like in the forward pass, we will also use a variable called stage: if the attention that we are computing is causal it will be equal to 3, and if we are computing non-causal attention it will be equal to 1. In the first iteration we will fix the K
and V blocks and we will iterate through
all the Q blocks in size of block size
micro number of query
vectors. So let's look at the signature. We launch it with a launch grid, and we have defined how many programs we have: the number of KV blocks is sequence length divided by the block size macro, because that's the block that we keep fixed in this for-loop, in this function, and then we go through all the query blocks in steps of block size micro, which I defined as 32 (later we will talk about auto-tuning and how to tune these values). All right, so I pass the query, the key and the V—sorry, not vectors, tensors—the query tensor, K tensor and V tensor, and they are pointing to the beginning of the tensors, which means they point to the first batch, the first head, the first token and the first dimension. Then we pass the softmax scale, we pass dO, dQ, dK and dV. M is the one that is needed to
compute—as you remember from what we said before, we did not save the P matrix in the HBM, because we want to recompute it on the fly during the backward pass: the query multiplied by the transpose of the keys is a very big matrix to save in the HBM and restore. So we want to compute it on the fly, but we don't need to recompute the normalization factor and the maximum element of each row to apply the softmax: that was already computed during the
forward pass and saved into this matrix M, which stores, for each row, the maximum of the row plus the logarithm of the normalization factor (the log-sum-exp); with the log-sum-exp trick we can just apply it and it will also normalize each value.
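A small sketch of why that works (my toy tensors): with M_i = max_j(S_ij) + log Σ_j exp(S_ij - max_j S_ij), computing exp(S - M) applies the max subtraction and the normalization in one shot, so it equals the row-wise softmax of S.

```python
import torch

S = torch.randn(4, 6)
row_max = S.max(dim=-1, keepdim=True).values
M = row_max + torch.log(torch.exp(S - row_max).sum(dim=-1, keepdim=True))
# equivalently: M = torch.logsumexp(S, dim=-1, keepdim=True)

P = torch.exp(S - M)
assert torch.allclose(P, torch.softmax(S, dim=-1))
```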
Then we have the D tensor that we computed here, with all the D_i values, one for each vector in the O tensor. Then
we need to pass the number of heads, the sequence length, the block size that we want to use for the KV, which is the macro block size; the micro block size is always the one that we iterate on. I think using these names makes it easier to understand which one we iterate over and which one we keep fixed: the fixed one is macro and the iterated one is micro. Then the head dimension—and later we will see why we use a different block size to iterate over, because this is related to the number of stages that Triton can divide your for-loop into thanks to software pipelining. Then we have the stage, which indicates whether the attention computed in the forward pass was causal or not causal, the number of warps and the number of stages, which we defined as fixed, but later we will talk about auto-tuning (sometimes I repeat the same stuff over and over, so I should change that).
Okay, let's write the signature of this function; let's put it here. We already described the signature, so let's go directly to the meat. The first thing we need to do is understand the offset by which we need to move this query, key and value; first of all we need to enter the right batch and the right head inside of each batch. We compute the index of the batch just like during the forward pass: the program index combines the index of the head and of the batch, so we divide it by the number of heads to get which batch this program is working with, and to get the head we just take the modulus, just like in the forward
pass. The offset batch head—let me check what it is for—okay, it enters the right batch and the right head. What is the stride? If you remember correctly, the stride tells us how many items you need to skip in that dimension to arrive at the next index in the same dimension. So if we want to skip index_batch number of batches, we multiply it by the batch stride, which is how many elements you need to skip to arrive at the next batch; plus we also need to enter the right head, so we add the index of the head multiplied by the stride of the head, to enter exactly that head in the tensor, for each of the Q, K and V matrices. Plus we also have this other offset, which will be used, if I remember correctly, for M and D, because M and D don't have the head dimension: they are only batch size, number of heads, sequence length. So we just use the combined batch-head index multiplied by sequence length, because for each batch and each head we will have sequence length number of items; you can think of it as the stride to move from one batch-head to the next batch-head.
Let's move the pointers. So we move the pointers Q, K and V by the offset batch head, because we want to enter the right batch and the right head inside of these big tensors, and we do the same for dO, dQ, dK and dV, because they have the same shape as Q, K and V (dO also has the same shape as Q), so we move them by the same offset. All right, and then we move M and D to the right starting point at which the sequence of the current batch and the current head starts, so they point to the first vector of the sequence dedicated to the current batch and the current head, and the same is true for Q, K, V and dO, dQ, dK and dV. Okay, then we load some other stuff,
because in this method we are going to do a for-loop in which we fix KV and iterate through Q, so we first need to load this block of KV, and we do it as follows. We know we need to load a 2D tensor, so we need to define the ranges in the second dimension of each K and V vector that we need to load, and that is defined by this factor. Then we want to understand which KV block this particular program is going to work with: this particular program is going to skip some KVs that are already managed by other programs that may be running in parallel, and we understand what this program should work with based on the index of the program on axis zero, which is defined over sequence length divided by the block size macro (and if you remember, block size macro is the thing that we fix). So this program ID zero tells us how many blocks of block-size-macro KVs are already being managed by other programs, so we shouldn't care about them and we skip them. Let's go back here: this is the number of vectors that we need to skip, so our KVs start from start_kv, and how many we load depends on block KV; this block KV is equal to block size macro, so it will be 128 vectors. So we define our two-dimensional tensors that we will store in the SRAM—because in Triton, every time you load something, you load it from the HBM into the SRAM—so we define where they should be saved in the SRAM, and they are initially zeros. Now we load them as
follows. In the K tensor pointer, which is already pointing to the right batch and to the right head (that's something that we did here), we need to load the right sequence of keys, which should start from the KV offsets, because these already include how many we should skip in the sequence length dimension, and for each of these vectors we need to load all the dimensions in the head dimension, because K, I want to remind you, is batch, number of heads, sequence length and head dim. By using this line we are skipping to the right batch and to the right head, so it's like we already indexed here and here. So right now this K is pointing to the beginning of a two-dimensional tensor, and we say: we don't want all the sequence, we want the part of this sequence indicated by this start_kv, from start_kv to start_kv plus block KV, so we want exactly this number of vectors at this location; and for the head dimension we want to select all the dimensions, so we say that we want from zero to head dimension, which is exactly this offs_dim.
Okay, we do it for the K block and we do it for the V block. Here I think I didn't update the comment: this should be block KV and this should be block KV; before it was called block kv1, like in the original code. I simplified the naming a little bit; I think this version is easier to follow, because in the original code they also do two for-loops, but in the second for-loop they iterate backward just to not change the structure of the loops. Mine is more verbose but easier to understand, and probably less efficient—mine is much less efficient. Then we have offs_q, because we need to understand, for each block of queries, which vectors we need to load, and it's indicated by this offs_q; and how many there are is block Q, which in the caller of this method was block size micro, so it is 32
vectors. Okay, now we need to access the Q vectors—no, the Q vectors already transposed—and the dO vectors; we need to access them because we are going to iterate through queries and dO vectors. Why? Let's look at the formulas in the paper: to compute dV_j, which is what we are trying to compute here, we need to iterate through all the dO vectors, and to compute dK we need to iterate through all the Q_i vectors (where Q_i is a block of vectors). And why do we need to access Q as transposed? Because we need to compute P transposed: P is the softmax of the query multiplied by the transpose of the keys, but if you want the transpose of P, then you need the transpose of S, which is K multiplied by the query transposed. That's why we access Q transposed instead of Q. The way we access Q transposed is just by playing with the strides, so let's do it like this, and I have also written a comment on why we can do
it. So this is equivalent to accessing the query. First, okay, what is this operation here? This is saying: go to the query starting pointer, which is already pointing to the right batch and the right head that this particular program should work with, and select a two-dimensional tensor where you repeat the query starting points along—in this case along the columns, but we should be repeating them along the rows, because we want to select rows of queries; however, if we want to select the query transposed, we just invert the two dimensions. Let me actually show you without the transpose first, so let's simplify. To access the query pointers without transposition we can just do this: go to the query tensor and create a 2D tensor where in the rows you put the starting point of each query that you want to get, and replicate each of these pointers also along the columns. That's the meaning of adding this None dimension: it is equivalent to doing unsqueeze(1) in PyTorch, so it adds a column dimension to this tensor, and the values are repeated on the columns—how many columns there will be is determined by broadcasting when we sum it with the offs_dim tensor here. So this is a combination of unsqueezing and broadcasting: we take the query vectors indicated by offs_q, and for each query vector we select all the head dimensions indicated by offs_dim. If you invert this broadcasting, it will create the transpose of the query block that you're trying to access. So this stuff here is equivalent to these two lines, accessing the query and then transposing, and it's something that you can do.
I could write down what is happening at the pointer level. Basically, you need to think of offs_q as being a vector of pointers, which we multiply by the sequence stride, which tells us how many elements we need to skip to go from one query vector to the next. In the case where the head dimension is 128, the stride of the sequence dimension will be 128: it means that to go from one query vector to the next you need to move forward by 128 elements, because I want to remind you that in memory, tensors are always stored flattened, each dimension flattened into the next. So imagine you have three rows and four columns: you will have the first row (the first four columns), then the next four columns, then the next four columns, row after row. It's difficult to visualize until you write it down, so how do we write it down? Create a vector offs_q, which at the beginning is a range from 0 to 32: 0, 1, 2, 3, 4, 5, 6, 7, etc. We multiply each one by the stride of the sequence, so the first will not skip any element, the next will skip exactly 128 elements (assuming the head dimension is 128), the next will skip two times 128 elements, the next three times 128, and then we add another dimension to this vector: you broadcast it over head-dimension number of columns and to each of them you add one number. Okay, let me just draw it, guys, otherwise I think it's too
confusing. Okay, so we have a vector that is as follows: zero, then 128, then 2 times 128, then 3 times 128, etc. We are adding a number of columns indicated by offs_dim; offs_dim has head-dim number of columns. For simplicity let's pretend it's not 128 dimensions but four, so this will be 4, this will be 2 times 4, this will be 3 times 4. We are adding another dimension, the dim dimension, each element multiplied by the stride of dim, which will be one because it's the last dimension. So we are adding four columns: we add 0, 1, 2, 3 to this one, 0, 1, 2, 3 also to this one, 0, 1, 2, 3 to this one, and 0, 1, 2, 3 to this one as well. So what will this select? Starting from the pointer Q, it will select element zero, then element one, then element two, and then element three, which is exactly the head dimension of the first vector that we should be selecting. Let me write the result of this operation: the first row will be 0, 1, 2, 3; then it will select elements 4, 5, 6, 7; then elements 8, 9, 10, 11; and then elements 12, 13, 14, 15. So from the starting point of where
this Q is pointing it will select the
first element right after this que the
second element right after this que the
third element right after this Q etc etc
and this will be the f you can see that
this will be the first query Vector this
will be the second query Vector this
will be the third query Vector this is
the fourth query Vector because in the
memory they are stored one after another
they are flattened so in the memory they
are stored like this they are stored
like the
following they are stored like this one
after another so it will select all of
them and we also create a virtual tensor
with the right shape that we want to
visualize it into so as you saw as we
saw before when you work with a tensor
layout in memory you can always view it
as whatever shape you like based on the
shape that you want and the reshaping is
always free doesn't involve changing the
arrangement of the elements in the
memory I hope now it's more clear so now
we can proceed
Oh my God, that was quite complicated. Whenever I get stuck I just draw things, and I think you should do it too, because that's the only way to learn; if you try to imagine everything in your head it's always difficult.
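To make the drawing above concrete, here is a tiny PyTorch sketch of the same offset arithmetic. The sizes and variable names are illustrative, mirroring the spoken ones: 4 query vectors, head dimension 4, so the sequence stride is 4 and the last-dimension stride is 1.

```python
import torch

offs_q = torch.arange(4)          # [0, 1, 2, 3], one entry per query vector
offs_dim = torch.arange(4)        # [0, 1, 2, 3], one entry per head dimension
stride_seq, stride_dim = 4, 1

# offs_q[:, None] is the "unsqueeze" mentioned above; the sum broadcasts the
# row offsets against the column offsets.
offsets = offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim
print(offsets)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])

# Swapping which operand gets the extra dimension gives the transposed view:
offsets_T = offs_dim[:, None] * stride_dim + offs_q[None, :] * stride_seq
print(offsets_T)                  # the transpose of the matrix above
```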
And we do the same job for the dO vectors, except that we don't access dO as transposed, because we don't need it transposed; only Q do we need transposed. Then we iterate through the sequence dimension of the query: we start from query number zero and go through the whole sequence length, because only for the keys do we select one specific block to work with. Remember, here we fix the key block and go through all the queries, and the queries go from zero up to the sequence length. So the number of steps of this for loop will be the sequence length divided by BLOCK_Q: if we have 1,024 elements in the sequence (1,000 would be a bad choice since it's not divisible) and BLOCK_Q is 32, it will be 1,024 divided by 32. Then at each step of this for loop we load a block of Q, the first one indicated by our pointer, and at the end of each iteration we will advance it to the next block of Q.
Okay, we also load the log-sum-exp values that are stored in the M tensor, because we want to compute P^T on the fly. P^T is the transpose of the softmax of Q multiplied by K^T, but instead of taking Q times K^T and then transposing, we already access Q as transposed, so we can compute P^T directly instead of computing P and then transposing it. So we load the offsets of the elements that we need from this log-sum-exp matrix, which is the M tensor we computed during the forward pass, and we access one block of Q at a time, the one we are currently working with in this iteration. Then we multiply with the keys; Q is already accessed as transposed, so if you want to get P^T (actually this is not P yet because we didn't apply the softmax, it is S^T), you need the softmax of S^T. And what is S^T? It is the transpose of S, and S is Q multiplied by K^T, so to get S^T you need K multiplied by Q^T. As you remember, when you transpose a matrix multiplication you also need to swap the two operands, and that's why we do K multiplied by Q^T: this gives us S^T. We also scale it with the softmax scale. To apply the softmax we would normally take the exponential of each element minus its row maximum, divided by the normalization value; but with the log-sum-exp trick we just subtract the M value from each element, because M already includes the normalization factor. I already did the derivation of this, so we don't need to go through it again. So now we have the P^T block; actually, in this formula I should have written S^T.
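As a quick sanity check of that log-sum-exp subtraction, here is a hedged PyTorch sketch with toy sizes and illustrative names (this is not the kernel code):

```python
import torch

# With m_i = max_j(S_ij) + log(sum_j exp(S_ij - max_j S_ij)), we have
# softmax(S)_ij == exp(S_ij - m_i): one subtraction handles both the max
# and the normalization.
S = torch.randn(4, 6)                              # toy block of scores
row_max = S.max(dim=-1, keepdim=True).values
M = row_max + torch.log(torch.exp(S - row_max).sum(dim=-1, keepdim=True))

P_ref = torch.softmax(S, dim=-1)
P_lse = torch.exp(S - M)                           # subtract M, nothing else
print(torch.allclose(P_ref, P_lse, atol=1e-6))     # True

# In the backward kernel we hold S^T, so the subtraction is per column of S^T:
P_T = torch.exp(S.T - M.squeeze(-1)[None, :])
print(torch.allclose(P_T, P_ref.T, atol=1e-6))     # True
```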
Okay. Then, when doing causal attention, we also need to mask out some values. As you can see here, in this case the causal mask is applied after the softmax has been computed. You are used to applying the causal mask before computing the softmax, but that is during the forward pass, because there you don't want the normalization factor to be affected by elements that should be zero. Here the normalization factor has already been computed, so it cannot be affected anymore; we can mask out after applying the softmax, because the normalization factor was already calculated with the mask applied, and that's why we can apply it afterwards. The mask itself is the same as always: it is true for all the values that do not need to be masked, which are the positions where the query index is greater than or equal to the key index, and all the other values will be replaced with zeros.
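A small runnable sketch of that masking step, in PyTorch with made-up block sizes; note that because we hold P^T here, the rows are keys and the columns are queries:

```python
import torch

# Keep positions where query index >= key index (attention allowed), zero the
# rest. The normalization in M already accounted for the mask in the forward pass.
BLOCK_Q, BLOCK_KV = 4, 4
offs_q = torch.arange(BLOCK_Q)                      # query indices of this iteration
offs_kv = torch.arange(BLOCK_KV)                    # key indices fixed for this program
P_T_block = torch.rand(BLOCK_KV, BLOCK_Q)           # stand-in for exp(S^T - M)

mask = offs_q[None, :] >= offs_kv[:, None]          # shape (BLOCK_KV, BLOCK_Q)
P_T_block = torch.where(mask, P_T_block, torch.zeros_like(P_T_block))
```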
All right, so after we have the P^T block already masked, we can calculate dV. Let me point to the right formula in the paper. Why do we need to load a block of dO? Look at the paper: the dV block is computed as the old dV plus a repeated sum (you can see the plus-equal here), so the old dV plus P^T_dropped multiplied by dO_i, where P_dropped indicates P_ij after applying dropout. In this implementation we don't support dropout, and very few models actually use dropout in the attention anyway. dO_i and Q_i always refer to the same block of rows in their respective tensors, because the inner iteration index i indicates a block of Q and a block of dO, and they refer to the same positions, since dO has the same shape as Q. So we go through the blocks of Q and dO simultaneously, because one is needed for dV and the other for dK: for dV we need dO, and for dK we need Q. And that's why we compute dV just like in the paper: the P^T block multiplied by the dO block, as you can see, a transpose multiplied by dO. So we have computed the dO block; then we need to load the D_i elements that we precomputed with the first call to the function, the attention backward pre-process, because we will need them for dK. How many of them are we loading? Exactly the same number as the queries we load, because we always load the same number of them, BLOCK_SIZE_MICRO vectors.
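Here is a hedged PyTorch sketch of that dV accumulation (toy shapes and illustrative names; in the kernel the same product is a block-level dot on data kept in SRAM):

```python
import torch

# dV_j += P_ij^T @ dO_i, accumulated over query blocks i while the K/V block j
# stays fixed (dropout omitted, as in this implementation).
BLOCK_KV, BLOCK_Q, HEAD_DIM = 4, 4, 8
dV_block = torch.zeros(BLOCK_KV, HEAD_DIM)          # accumulator for this K/V block
P_T_block = torch.rand(BLOCK_KV, BLOCK_Q)           # exp(S^T - M), already masked
dO_block = torch.randn(BLOCK_Q, HEAD_DIM)           # block of upstream gradients
dV_block += P_T_block @ dO_block                    # same shape as the V block
```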
Okay, I will copy some code and explain it step by step. The next operation is to compute dK; to compute dK we need dS^T, and to compute dS^T we need dP^T. So let's go one by one, from the end to the beginning of these formulas, so we see where everything is used and where everything is created. Let's start with dK. If you look at the paper, dK is equal to the old dK plus dS^T multiplied by a block of Q, and this is what is written here. Plus-equal just means an incremental addition: increment the old dK with some new stuff, which is the softmax scale (there is also a scale, this tau here) multiplied by the matrix multiplication between the dS^T block and Q. You can see this Q here, but we don't have Q, we have Q^T, so we take the transpose of Q^T and it becomes Q again. Now let's look at this dS^T block. In the formula of the paper we have dS, which is here, and it is equal to the block P_ij multiplied element-wise with dP_i minus D_i. We don't need dS, we need dS^T, and since this is an element-wise multiplication, not a matrix multiplication, when you transpose it you don't need to swap anything: you just take the transpose of the two operands. So to compute dS^T we take the transpose of P, which is P^T and we already have it, and the transpose of everything inside the parentheses, so dP^T minus D_i with rows and columns swapped. And what is dP^T? In the paper we know the formula for dP: dP is here, and it is equal to dO multiplied by V^T. But we don't need dP, we need dP^T, and this one is a matrix multiplication, so to get dP^T we take the transpose of the two operands and also invert their order. V^T transposed becomes V, so the V block is multiplied by the other operand, dO_i transposed, and that's why we are doing the transpose of dO.
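Putting that chain together, here is a hedged PyTorch sketch with toy shapes and illustrative names (the kernel works on blocks in SRAM, but the algebra is the same):

```python
import torch

# dP^T = V @ dO^T, dS^T = P^T * (dP^T - D), dK += softmax_scale * dS^T @ Q
BLOCK_KV, BLOCK_Q, HEAD_DIM = 4, 4, 8
softmax_scale = HEAD_DIM ** -0.5

V_block = torch.randn(BLOCK_KV, HEAD_DIM)
dO_block = torch.randn(BLOCK_Q, HEAD_DIM)
O_block = torch.randn(BLOCK_Q, HEAD_DIM)            # forward-pass output block
D_block = (dO_block * O_block).sum(dim=-1)          # D_i = rowsum(dO_i * O_i), precomputed
Q_T_block = torch.randn(HEAD_DIM, BLOCK_Q)          # Q loaded already transposed
P_T_block = torch.rand(BLOCK_KV, BLOCK_Q)           # softmax output, transposed

dP_T = V_block @ dO_block.T                         # (dO @ V^T)^T: swap and transpose operands
dS_T = P_T_block * (dP_T - D_block[None, :])        # element-wise, so just transpose both sides
dK_block = torch.zeros(BLOCK_KV, HEAD_DIM)
dK_block += softmax_scale * (dS_T @ Q_T_block.T)    # dK_j += tau * dS^T @ Q
```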
Right now I'm not going through every single pointer, because I already told you how to check what a pointer points to and what an offset refers to. I hope you now have a better understanding of how these pointers work in Triton, which is also the way they work in CUDA, because on the GPU we only get a pointer to the starting address of the tensor and then we need to work out all the indices ourselves. We have computed the dK block, so now we go to the next block of queries, because we are fixing the K and V blocks and iterating through all the queries. So we advance the Q^T pointers by the sequence stride, which tells us how to go from one query to the next, multiplied by BLOCK_Q, the number of query vectors we just processed; and we do the same for dO, using BLOCK_Q and the same stride, because dO and Q have the same shape.
Okay, after we have run the for loop over all the queries, we can store the dK and dV blocks, so we write them back as follows, and this is the end of this function, guys. We save the dV block exactly at the right position: dV is, I believe, already pointing to the right batch and the right head, because we incremented the pointer earlier, and the same holds for dK. Then we need to tell it where in the sequence dimension to save this block, and this is indicated by these offsets. We create the pointers just like before (guys, don't make me do it again, it's really easy if you write it down): you take this vector of key and value offsets, which are not really pointers, they are a range of key and value positions in the sequence dimension; you add another dimension so each value is repeated over the columns, and then you add the head dimension offsets. Anyway, after we have calculated the pointers where we should store dK and dV, we store them into the dV tensor and the dK tensor. What do we save? The dV block and the dK block, the ones we were incrementally updating in the for loop we have written. Okay, now that we have finished this one, we can go to the next function that will do the other for loop, so let's do it.
Okay, so now we do the second part of the iteration, which is this one. Let me just copy it and then we describe it; let's write it here. We use the same launch grid as before, and of course we need to declare this function. Again, the grid is defined by BLOCK_SIZE_MACRO, which is the thing we keep fixed, and inside the for loop we take steps of BLOCK_SIZE_MICRO. In this case we are fixing Q and iterating through K and V, because now we need to compute dQ; we have already computed dK and dV. I believe the arguments are the same as before. This is also the reason why, in the original implementation on the Triton website, the author decided to use the same for loop with different arguments, which I found a little confusing; that's why I separated them and just repeated the code twice. The goal of this video is to be as easy to understand as possible, not as efficient as possible. So let's go: let me copy the signature again and define this function here. Again we first need to move the query, key, and value pointers to the right place, pointing to the exact batch and the exact head that this program works with. Let me check where the code is: the first part is exactly the same as the other for loop we have written, I really just copied it. We compute the batch-head index, we move the query, key, and value pointers to the right place, then dO, dQ, dK, dV, and then M and D, exactly like before, so I don't think I need to explain that again.
Then we load a block of Q, the one we will keep fixed, and dQ; let me load a lot of stuff here, actually. We define the offsets that we will need to load the blocks of K and V along the head dimension, because we are going to iterate over K and V. We will access them as transposed blocks: instead of accessing them directly as K and V, we access them as K^T and V^T, and you know that's possible just by changing the strides. In this case, because we are treating them as 2D blocks, when you want to access K not transposed you treat offs_kv as a column vector, so each key offset is repeated along the rows; here instead we treat it as a row vector, so it is broadcast along the column dimension, and that's how you access the transposed version of K and the transposed version of V.
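If it helps, here is the same stride trick verified on a flat buffer (PyTorch, illustrative sizes and names): loading with the swapped offsets really does give the transpose of the normal load.

```python
import torch

BLOCK_KV, HEAD_DIM = 4, 8
stride_seq, stride_dim = HEAD_DIM, 1
offs_kv = torch.arange(BLOCK_KV)
offs_dim = torch.arange(HEAD_DIM)

# Non-transposed K block: offs_kv as a column vector -> shape (BLOCK_KV, HEAD_DIM)
k_offsets = offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim
# Transposed K block: offs_kv as a row vector -> shape (HEAD_DIM, BLOCK_KV)
kT_offsets = offs_dim[:, None] * stride_dim + offs_kv[None, :] * stride_seq

flat_K = torch.arange(BLOCK_KV * HEAD_DIM)          # K as it sits flattened in memory
print(torch.equal(flat_K[kT_offsets], flat_K[k_offsets].T))   # True
```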
Another thing we are doing is loading the Q block. Which one? It is based on offs_q, which is based on start_q, which is the exact starting point this particular program should work on. This program is indexed along two dimensions: the first dimension indicates which batch and which head this program works with, and the second dimension (the program index number zero) indicates which queries, among the whole sequence length, this program is going to work with. This is indicated by the block index; it should actually be named index_q, I forgot to change the name, so let me change it. We skip some queries, and how many we skip is given by the index of the current program multiplied by how many query vectors have already been processed by the previous programs. This tells us, inside the sequence length, which queries this program needs to select, and that's why we use the start query plus a range of BLOCK_Q. Imagine the starting query for this program is 100: then this will load query rows 100, 101, 102, and so on, up to 100 plus BLOCK_Q minus one. This is the range of query vectors that we will load in this program. We load the block of Q using the Q pointer plus these offsets treated as a column vector and broadcast along the rows, where each column is one head dimension multiplied by its stride. Actually we could even skip multiplying by the stride here, because the stride of the last dimension is always one: to go from one element to the next you just move by one element. Then we have dQ, which is the thing we are going to compute in this iteration, and we load dO using the same offsets as Q, because dO has the same shape as O, which has the same shape as Q. Plus we need to load the M normalization factors, which are in the M tensor, the ones corresponding to this particular group of queries that this program works with.
As you can see, the offsets start at the first block of KV, from position zero, because we will iterate through all the KVs starting from key vector zero and value vector zero, and then we will move forward by BLOCK_KV vectors at each iteration. I hope I didn't go too fast, but most of the things written here are very similar to what we already did in the other for loop, and I don't want to repeat myself too much. What matters is the formulas that we will use, which are exactly the ones in the paper.
So we go through these blocks of KV: we load the first block of K transposed and V transposed, which is loaded as usual, by telling Triton which pointers, that is, which elements, you want to load; it loads the requested block into SRAM. All of this resides in SRAM; Q also resides in SRAM, and so does dO. Then we compute the query multiplied by the transpose of the keys, because we need to compute the P block: the QK block is just the current query block times the transpose of the current key block. We access the keys already as transposed, so we don't need to transpose them, and even if we did, transposing a matrix requires no computation, we just access it in a different way, because in the memory layout it is always stored as a flattened array. Then we compute the P block, which is the output of the softmax: from each query-key score we subtract the log-sum-exp value for this block of queries; that's why, for loading the M block, we use the offsets of the queries we are loading. As you remember, the M block already includes the normalization factor, because each m is the maximum value of the row plus the logarithm of the normalization factor, which, through the properties of the exponential, ends up in the denominator.
denominator okay and then we apply again
the auto regressive
masking
oops what did I
do let me go back to the code here so we
have the stage this one so when we
launch the back backward pass stage
three indicates that it's a the in the
forward pass we computed the caal
attention and the one indicates that we
computed the non-causal attention so if
we computed the causal attention in the
forward pass we also need to mask out
these elements in the backward
pass so we check um we created The Mask
which tells us which index uh this mask
is true for only for the elements for
which the query index is more than the
key index and if this is true then we uh
we don't mask otherwise we
Let's compute the next operation, which is to compute dP and dS. Actually, let's compute dQ directly and then explain it like before, starting from the end and going back to where each piece is computed. Let me check this formula; let's go to the iPad. What we are trying to compute here is dQ. As you can see in the paper, dQ is equal to the old dQ plus tau, which is the softmax scale, this factor here, multiplied by the matrix multiplication between dS and the K block. The dS block is here, and the K block is the transpose of the K^T block, because we are accessing K already as a transposed block. We could also access K directly as non-transposed by inverting the offsets: if you want to access it as a transposed block you do it like here, with None treating it as a row vector and broadcasting along the columns; otherwise you would change this one too and treat it as a column vector. If you want K instead of K^T you just swap those two operations. I hope I didn't mess up anything, so let's move forward. Okay, so the formula for dQ is exactly the same as in the paper, but what is this dS block? Looking at the paper, dS is the P block element-wise multiplied by dP_i minus D_i. What is the P block? It is exactly the output of the softmax, which we already have. What is the dP block? It is exactly dO multiplied by V^T: dO we already loaded, and here it is multiplied by the transpose of V, which we already load as transposed. And this is how we compute dQ.
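Again a hedged PyTorch sketch of that chain, with toy shapes and illustrative names (this time the query block is fixed and we accumulate over K/V blocks):

```python
import torch

# dP_i = dO_i @ V_j^T, dS_ij = P_ij * (dP_i - D_i), dQ_i += softmax_scale * dS_ij @ K_j
BLOCK_Q, BLOCK_KV, HEAD_DIM = 4, 4, 8
softmax_scale = HEAD_DIM ** -0.5

dO_block = torch.randn(BLOCK_Q, HEAD_DIM)
O_block = torch.randn(BLOCK_Q, HEAD_DIM)
D_block = (dO_block * O_block).sum(dim=-1)          # precomputed D_i = rowsum(dO_i * O_i)
V_T_block = torch.randn(HEAD_DIM, BLOCK_KV)         # V loaded already transposed
K_T_block = torch.randn(HEAD_DIM, BLOCK_KV)         # K loaded already transposed
P_block = torch.rand(BLOCK_Q, BLOCK_KV)             # softmax output for this block

dQ_block = torch.zeros(BLOCK_Q, HEAD_DIM)
dP_block = dO_block @ V_T_block                     # shape (BLOCK_Q, BLOCK_KV)
dS_block = P_block * (dP_block - D_block[:, None])
dQ_block += softmax_scale * (dS_block @ K_T_block.T)   # dS @ K, with K = (K^T)^T
```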
Then of course we need to move to the next block of KVs, so we increment the pointers just like before: we move to the next block of keys and values, and we also advance the other pointers just like before. Then we need to store the result of dQ, and this way we only need to do one write to HBM, thanks to splitting the for loops the way we did. If you look at the original algorithm (I don't know whether the algorithm in the paper actually corresponds to the CUDA implementation they wrote, but I don't think so, because it would not be very optimized), the paper says you should go through all the keys, and while going through the keys go through all the queries, and for each query you visit you write back dQ while updating it, which is not optimized. That's why we needed two for loops: one in which we fix the keys and update dK and dV while iterating through all the queries, because each key block's update depends on all the blocks of Q; and one in which we fix the queries and iterate through all the keys, because one block of dQ depends on all the blocks of K. This is why we split it, and this is the second loop that we have written.
Now we have written everything that we need for Flash Attention, the forward pass and the backward pass, so we should be ready to launch the kernel. I hope I didn't make any mistakes in copying; I don't think so. I will try to launch it, and if there is any error I will just use my reference code, which I have already written and used as the source for copying. The only difference so far between my reference code and the one we have written is the auto-tuning, which I didn't explain, so let's talk about it. The auto-tuning was already present in the original implementation and I kept it as is; I only removed the auto-tuning for the backward pass. In the forward pass, if you check, there is this code here that defines the auto-tuning configurations for Triton. Triton basically cannot know beforehand what the best block size is for the query, or for the keys and values, or for any other dimension we have; we need to try, based on the hardware we are running on, on the available SRAM, and on the thread coarsening that Triton can apply. I didn't talk about thread coarsening: basically in CUDA you can choose whether each thread does one elementary operation (for example, in a matrix addition, each thread adds one particular element of the output matrix) or whether it manages multiple elements; this is called thread coarsening. I didn't check the documentation, but I believe Triton does it for you based on the block size that you give it and the number of warps that you want. A warp is a group of 32 threads that work cooperatively, always running the same instruction at the same time. The number of stages is more interesting: it is an optimization that Triton does, and it is not loop unrolling; so let's talk about software pipelining, because this is the last part of this code that we need to understand. I believe the most interesting part here is not choosing BLOCK_SIZE_Q and BLOCK_SIZE_KV, because for those you just try whatever configuration works best based on timing: Triton will actually run all these configurations for you every time the sequence length or the head dimension changes, and for every pair of head dimension and sequence length it will choose the configuration that runs in the least amount of time, which gives you the best throughput. So let's look at num_stages: what is it and how does it work? Let's do it.
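Before that, for reference, here is a minimal sketch of what such an auto-tuning configuration list can look like in Triton. The meta-parameter names BLOCK_SIZE_Q and BLOCK_SIZE_KV, the key names, and the candidate values are illustrative assumptions, not necessarily the ones used in the kernel shown in the video.

```python
import triton

# Triton benchmarks every config whenever a value in `key` changes
# and caches the fastest one for that (SEQ_LEN, HEAD_DIM) pair.
configs = [
    triton.Config(
        {"BLOCK_SIZE_Q": bq, "BLOCK_SIZE_KV": bkv},
        num_stages=num_stages,
        num_warps=num_warps,
    )
    for bq in [64, 128]
    for bkv in [32, 64]
    for num_stages in [3, 4]
    for num_warps in [4, 8]
]

# Applied as a decorator on the kernel, for example:
# @triton.autotune(configs=configs, key=["SEQ_LEN", "HEAD_DIM"])
# @triton.jit
# def _attn_fwd(...): ...
```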
Okay, so software pipelining is used when you have a for loop, a sequence of iterations in which each iteration does not depend on the previous one: the operations you do in one iteration are independent of what you did in the previous iteration, which is more or less what happens in our for loops. Actually, I believe there are conditions under which this doesn't have to be true, so the operations can depend on each other and you can still do software pipelining. For example, imagine the following for loop, running from 1 to N: first you load some data, then you load some other data, then you do a matrix multiplication, and then you store some data. So here you are reading data, here you are reading data, here you are computing something, and here you are writing data. If we look at what happens at each iteration, we will see the following picture. Imagine our GPU is made up of a compute unit and a unit dedicated to loading and storing, so reading from memory and writing to memory. On the timeline, at the first iteration, first we read some data while the compute unit is idle because it needs this data; then we read some more data while the compute unit is still idle; then finally we have enough data and we can run the computation while the loading unit is idle; then we write some data back to memory while the compute unit is idle again, and it will stay idle for another two time steps until it has enough data to run the next computation. As you can see, this is not very efficient, because at any point in time only one unit is working and the other is sitting idle. One way to optimize this for loop is software pipelining, and you can tell Triton to do it for your for loops by telling it how many stages you want. Let's see how it works. To pipeline a for loop means, first of all, converting all these operations into async operations. In CUDA, at least on NVIDIA GPUs, there are async loads from memory and async stores to memory, which basically means that I spawn a load operation, the instruction returns immediately and moves on to the next instruction, and I only check whether the load has completed when I actually need the data. So I can spawn two reads immediately and just check later whether they have completed. With software pipelining we interleave operations of different iterations into a single iteration. First we read the first matrix that we need for the matrix multiplication; at the next iteration we read the first matrix of the second iteration and also the second matrix of the first iteration. I call these read A and read B, meaning read the first matrix we need and read the second matrix we need; all these operations are asynchronous. Then at the third iteration I spawn another asynchronous operation that reads the first matrix of the third iteration and the second matrix of the second iteration, and then I compute the matrix multiplication, because by the third iteration the first two loads should have completed. But while computing the matrix multiplication I don't keep the loading unit idle, because it is still performing these other loads; this can only work if you can spawn async operations. At the fourth iteration I spawn the loads of the data for the fourth and third iterations while computing the matrix multiplication of the second iteration, whose data should have completed by now. Actually, it's not that we just expect it to have completed: there are primitives in the CUDA language to check whether an async operation has finished, so before doing the matrix multiplication we actually check that the async loads are done; it is not based purely on timing. It is like promises in JavaScript: you can spawn as many promises as you want and await them only when you actually need them, while the others keep running in the background asynchronously (in C#, I believe, they are called tasks). This is the whole idea of software pipelining. As you can see, software pipelining only works when you have async operations, and it also increases the memory requirements of your program, because when matrix multiplication one is about to run we may already be holding the data of the first two iterations plus half the data of the third iteration, so we increase the requirement on the SRAM.
SRAM okay and P Tron will do this
software pipelining for you it will
convert all the load all the stores and
maybe also the metrix multiplication
into a sync operations and do this
pipelining for you if you are confused
by how it works there is another easy
solution to explain you how it works
because it's already something that we
do in model training it is called
pipeline parallelism
so in pipeline parallelism it works as
follows we have a very big neural
network that does not fit in a single
GPU so imagine this neural network is
made up of three layers layer one layer
two and layer three but this is so big
it does not fit entirely in one simple
GPU so one way would be to put this each
layer into one GPU so we put for example
layer one into GPU
one a layer two into GPU 2 layer three
into GPU number three so imagine we have
an input for this Neal Network so we put
it to the first GPU the GPU one will
process the layer one and generate some
output which will be transferred to the
gpu2 which will calculate its own output
and transfer it to the gp3 which will
compute its own output and finally we
will have the output of the Nal Network
the problem is when you send the output
of the gpu1 to the gpu2 for the gpu2 to
do its own thing the gpu1 now is free so
it is a waste of resources we could
always should keep the gpus busy so what
one thing that we can do is instead of
sending all the the mega batch to the
gpu1 we send many smaller batches how
does it work imagine that we send the
batch number zero so patch zero uh to
the gpu1 the gpu1 will compute its
output and send it to the gpu2 so now
the gpu2 is Computing the batch number
zero so now the batch zero is not here
anymore but now the GPU one is free so
we send another micro batch called the
batch
one then gpu2 will finish processing the
batch zero and we'll send it to the
batch to the GPU number three so now the
GPU 3 has the batch number
zero and the gpu2 now is free so we
transferred and hopefully also gpu1 has
finished so we transfer the batch number
one from gpu1 to gpu2 the bches um and
then the GP one will be free so so we
transfer here becomes one and now this
one is free so because it's gpu1 is free
we can introduce another batch so batch
number two etc etc etc so we always
introduce when while moving one batch
from one GPU to the other we introduce a
new batch at the beginning of the
pipeline and they shift by one position
at every iteration this will keep the
gpus always busy uh there is only one
problem of the pipeline parallelism
which is the this bubbling effect
because to create this pipeline you at
the beginning of this um okay actually
in the pipeline paradism you also have
the problem of the backward step so the
backward step has to run exactly in
Reverse in the order in which you
receive the micro batches while in Tron
when doing software
pipelining you have the problem of the
prologue and the epilog because you need
to create this Pipeline and and
to start the pipelining and at the end
of the pipeline you need to uh use all
the stuff that is currently in the
pipeline so only in the beginning step
and in the last step of this for Loop
your um all the units of this GPU may
not be working
simultaneously which what does it mean
it means that in order to use pipelining
you want the number of iterations of
your for Loop to be much more bigger
than the number of stages in which your
iteration is divided into in this case
we have four stages these are called
stages so you want the number of
iterations to be much more to to be much
larger than the number of stages all
All right guys, finally I have completed the video; I hope you learned a lot from it. I believe we can run the Triton code, so let's run it. Let's see, I copied everything; I believe we also put in the code to test it, but we didn't add the main method, which we can copy right now. I hope there is no error, I really hope so. Let me check if I am on the right machine, and I am, so let's just run it and pray. If there is an error I will just copy my own reference implementation, but I hope it works, because otherwise I forgot something. I'm running my code on an H100, because my company has H100s; if you have a smaller GPU, you can reduce the sequence length, you can reduce the batch size (I think it's already one when we call it... no, you can reduce the batch size), the number of heads, the sequence length, and you can even set the head dimension to 8 and the sequence length to 16. Let's check. We get an error: the Triton backward returned an incorrect number of gradients, expected five, got one. We probably forgot some return statement, I believe. Yes, I forgot the return statement here: in the backward pass, after running the last for loop, we need to return the stuff that we have computed. Fingers crossed again... okay, it passed. So the backward pass computed by torch is equivalent to our backward pass, up to an absolute error of 10^-2. As you can see, the backward that we run here is different from the backward that we run there, because when you apply the Triton attention it introduces a new operator into the computation graph of our tensors, this Triton attention operator, and when PyTorch wants to compute the backward pass it just calls the backward function of this Triton attention and populates the grad attribute of all the tensors that are inputs to it. And this is how PyTorch autograd works, guys.
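To make that concrete, here is a hedged sketch of such an autograd wrapper, not the video's exact code: the kernel launchers _attention_forward and _attention_backward are hypothetical names, and the assumption that the five forward inputs are Q, K, V, a causal flag, and the softmax scale is an illustration of why backward has to return five values.

```python
import torch

class TritonAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, K, V, causal, softmax_scale):
        # Hypothetical launcher for the Triton forward kernel; returns the
        # output O and the logsumexp tensor M needed by the backward pass.
        O, M = _attention_forward(Q, K, V, causal, softmax_scale)
        ctx.save_for_backward(Q, K, V, O, M)
        ctx.causal, ctx.softmax_scale = causal, softmax_scale
        return O

    @staticmethod
    def backward(ctx, dO):
        Q, K, V, O, M = ctx.saved_tensors
        # Hypothetical launcher for the preprocess + two backward loops.
        dQ, dK, dV = _attention_backward(Q, K, V, O, M, dO, ctx.causal, ctx.softmax_scale)
        # One return value per forward input: gradients for Q, K, V and None
        # for the non-tensor arguments. Returning only one value triggers the
        # "expected 5, got 1" error seen in the video.
        return dQ, dK, dV, None, None
```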
Thank you for watching my video, guys, it has been super demanding. I spent many months, first of all to teach myself Triton, CUDA, Flash Attention, etc. I also have a full-time job, so it's really hard to make videos like this: I need to dedicate the nights, the mornings, the weekends. I spent three days just recording this video, because sometimes I don't like how I explain something, sometimes I make a mistake, or sometimes I need to restart because what I'm doing is wrong. I believe there should be no big errors in what I have done so far, but for sure my notation is not great: all the mathematics I know is self-taught, and because I didn't learn it in academia I have bad habits that I'm trying to get rid of, so sometimes I use a capital letter, sometimes lowercase, sometimes I just forget an index, and so on. I'm trying to fix these problems. I believe I have explained everything, so you should have all the knowledge needed to derive all the formulas you see in the Flash Attention paper, and you should also have a mental image of how the attention computation works block by block. I know I could have spent 20 hours explaining things better, but I also have a life and I also have a wife, so I cannot make 100-hour videos. There were also some interruptions while making this video: I had some wisdom teeth removed, and it took me more than a week to recover because it was so painful. So thank you guys for watching my video, I hope you learned a lot this time too. As you can see, Triton is something new and there is not much documentation, so something I have said about Triton may not be totally correct; really, there is very little documentation, and all the Triton I have learned is from looking at code written by others and trying to understand it. And I think that's it, guys, so I wish you a wonderful day and see you next time on my channel.