
Flash Attention derived and coded from first principles with Triton (Python)

By Umar Jamil

Summary

## Key takeaways

- **Deriving Flash Attention from First Principles**: The video derives and codes Flash Attention from scratch, treating the original paper as if it never existed, to build a deep understanding of its mechanics. [08:08], [21:22]
- **Optimizing Attention by Focusing on Softmax**: Flash Attention optimizes only the softmax(QK^T / sqrt(d_k))V part of attention, where the softmax is the main obstacle. The W_Q, W_K, W_V, and W_O projections are plain matrix multiplications, which are already highly optimized on GPUs. [01:10:12], [07:57:53]
- **Addressing Softmax Numerical Instability**: The video explains why large input values make the softmax numerically unstable and details the "online softmax" technique, which fuses the maximum and normalization passes into a single loop. [20:11:12], [38:38:43]
- **GPU Architecture and Parallelism**: Understanding the GPU's architecture, with its many cores optimized for parallel computation and limited control units, is crucial for writing efficient CUDA and Triton kernels. [01:21:22], [01:51:07]
- **Triton Kernels and Memory Management**: Triton lets you write GPU kernels in Python, abstracting away some CUDA complexity, and efficient kernels keep their working blocks in the GPU's fast shared memory (SRAM) instead of repeatedly accessing the slower global memory (HBM). [02:40:48], [02:54:26]
- **Flash Attention Forward Pass Mechanics**: The forward pass works block by block and corrects the inaccuracies of block-local softmax computations with the online softmax approach and a correction factor. [01:44:01], [01:46:47]

Topics Covered

  • GPU Memory Hierarchy: Why Attention is IO-Bound.
  • Online Softmax: Fusing Passes for Numerical Stability.
  • Block Matrix Multiplication for Efficient Parallelization.
  • GPU Parallelism: Mapping Work to Threads and Blocks.
  • LogSumExp Trick: Optimizing Backward Pass Memory.

Full Transcript

Hello guys, welcome back to my channel.

Today we are going to explore Flash Attention, and we are going to explore it from first principles, which means that not only will we code Flash Attention, we will actually derive it. We pretend that the Flash Attention paper never existed: we look at the attention computation, we look at the problems it has, and we try to solve them step by step. This will give us a deep understanding of how it works, and we will also combine theory with practice, because we will code it.

Now, in order to code Flash Attention we will need to write a kernel for our GPU. In my specific case I will be using an NVIDIA GPU, so a CUDA kernel, but instead of writing C++ code we will use Triton, which is a way of converting Python directly into CUDA kernels that can run on the GPU. You can think of Triton as a compiler that takes in Python and converts it into something that can run on the GPU.

So let's look at the topics for today. First of all, I will give an introduction to multi-head attention, because we need to look at what attention is, how it's computed, and what the problems are in computing it. Then we will look at the most critical part of the attention computation, the softmax, and how it impacts the computation and its complexity, and we will look at what online softmax is. Then we will explore what a GPU is, because we are going to write a kernel that runs on the GPU, so we need to understand, for example, the difference between the CPU and the GPU, what a kernel is, and how it differs from a normal program that you write for the CPU. We will look at how tensors are laid out in memory (row-major layout, column-major layout, strides, etc.), and we are going to look at block matrix multiplication, the Triton software pipeline, and all the optimizations that Triton applies to our code. Finally, we will be able to code the Flash Attention forward pass.

But of course we are not satisfied with coding only the forward pass; we also want to code the backward pass. To code the backward pass we need to understand how autograd and gradient descent work in the case of custom operations, so we need to understand what derivatives, gradients, and Jacobians are. Then we calculate the gradients of the common operations that we use in Flash Attention, and finally we will have enough knowledge to code the backward pass. For this reason this video is going to be super long, but I hope you don't mind, because we are going to learn a lot.

Of course, you may be thinking that all of this requires a lot of knowledge that you may not have, but that's not a problem, because that's my problem. In this video I will make sure that if you have high school calculus (so you know what derivatives are), the basics of linear algebra (you know what matrix multiplication and the transpose of a matrix are), a basic knowledge of the attention mechanism (for example, you have watched my previous video on the "Attention Is All You Need" paper), and a lot of patience, that is enough to understand all of this video. Every topic I introduce, I will introduce pretending that you don't know anything about it, so we derive everything from first principles, everything from scratch.

Okay, now that we have seen the introduction, let's go to the first part of the video, which is multi-head attention.

All right, let's talk about multi-head attention. I am using the slides from my previous video on "Attention Is All You Need", so we can look very quickly at what multi-head attention is and how it works. I hope you remember the formula, the softmax of the query multiplied by the transpose of the key, divided by the square root of d_k, all multiplied by V, because we will be using it a lot throughout the video.

Multi-head attention starts from an input sequence, or two input sequences in the case of cross attention. In the simple case of self-attention we have one input sequence, which in the case of a language model is a sequence of tokens: we have seq tokens, and each token is represented by an embedding, a vector with d_model dimensions. The first thing that we do is convert this input sequence into query, key, and value through three linear projections, one called W_Q, one called W_K, and one called W_V, which in PyTorch are represented as linear layers. These linear layers are d_model by d_model, so they do not change the shape of the input tensor.

After we project the input, we get three different sequences, one called query, one called key, and one called value; here I'm calling them Q', K', and V'. Then we divide them into smaller embeddings: each token, which is made up of d_model dimensions, is split into smaller pieces. Suppose we have four heads; then each piece has d_model divided by four dimensions. So this one is a sequence of tokens where each token is not the entire embedding but a part of it, this one is another part of the embedding of the tokens, this one is another part, etc., and we do this for the query, key, and value sequences.

Then we compute the attention as follows: the softmax of the query multiplied by the transpose of the key, divided by the square root of d_k, where d_k is the dimension of each head (how many dimensions each head works with), and then we multiply by V. This gives us the output of the attention mechanism for each head, and this is done independently for each head. This should be clear to you; if it's not, please watch my previous video on the attention mechanism, because we will be working with this scenario a lot.

Then we take the output of each head and concatenate them back in order to get the representation of each token as a full embedding. Before, we split the embedding into smaller embeddings, here called Q1, Q2, Q3, Q4; after we compute the attention, we get back the output of each head and concatenate them together to recover the full embedding dimension, which is this H here. We run it through another linear projection called W_O, which gives the output of multi-head attention. Now, Flash Attention is not concerned with all of these operations.
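The multi-head pipeline described above (project with W_Q, W_K, W_V, split into heads, attend per head, concatenate, project with W_O) can be sketched in plain Python. This is a minimal illustration under my own assumptions: matrices are lists of rows, the weights are random, and helper names such as `multi_head_attention` are hypothetical, not taken from the video or from any library.

```python
import math
import random


def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix; matrices are lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax_rows(s):
    """Plain row-wise softmax (not yet the numerically safe variant)."""
    out = []
    for row in s:
        exps = [math.exp(x) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for a single head."""
    d_k = len(q[0])
    scores = [[x / math.sqrt(d_k) for x in row] for row in matmul(q, transpose(k))]
    return matmul(softmax_rows(scores), v)


def multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """x is seq x d_model; every weight matrix is d_model x d_model."""
    d_model = len(x[0])
    d_k = d_model // num_heads
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)  # Q', K', V'
    heads = []
    for h in range(num_heads):
        cols = slice(h * d_k, (h + 1) * d_k)  # this head's slice of each embedding
        heads.append(attention([r[cols] for r in q],
                               [r[cols] for r in k],
                               [r[cols] for r in v]))
    # Concatenate the head outputs back into full d_model-sized embeddings.
    concat = [[val for head in heads for val in head[i]] for i in range(len(x))]
    return matmul(concat, w_o)


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
seq_len, d_model, num_heads = 3, 8, 4
x = random_matrix(seq_len, d_model, rng)
w_q, w_k, w_v, w_o = (random_matrix(d_model, d_model, rng) for _ in range(4))
out = multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(len(out), len(out[0]))  # 3 8: same shape as the input sequence
```

Only `attention` corresponds to the part Flash Attention targets; every other step is an ordinary matrix multiplication.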

Actually, Flash Attention is only concerned with the operations that require optimization, and the operation that requires optimization is this one: the softmax of the query multiplied by the transpose of the key, divided by the square root of d_k, multiplied by V. This means that the projection of the input sequence through W_Q, W_K, and W_V is not something Flash Attention is concerned about, because that's a matrix multiplication: when you use a linear layer, it's just a matrix multiplication of the input with the weight matrix of the linear layer.

This kind of operation, matrix multiplication, is one of the most optimized operations that we have on the GPU, because the GPU manufacturer usually also releases the necessary libraries for computing matrix multiplications, so these are already quite fast and do not require any optimization. So Flash Attention will assume that the query has already been passed through W_Q, the key through W_K, and the value through W_V. Moreover, Flash Attention is not concerned with the projection through W_O, because that's also a matrix multiplication: W_O is always represented in PyTorch as a linear layer, and matrix multiplications, as we have seen, are already very optimized, so there is nothing to optimize there.

What we need to optimize in terms of speed is this operation here: the softmax of the query multiplied by the transpose of the keys, divided by the square root of d_k, multiplied by V.

All right guys, now that we have rehearsed what multi-head attention is, I also want to give you a visualization. Here, in the multi-head attention figure from the paper, we can see that we have the inputs Q, K, and V, and each of them runs through a linear layer, W_Q, W_K, and W_V. Then we do the scaled dot-product attention, which is done independently for each head: each head computes the query multiplied by the transpose of the key, divided by the square root of d_k, where each query and each key is not the full embedding of a token but a part of it, because we split them into smaller embeddings. Eventually we take the outputs of all these heads, which are computed in parallel (that's why you see this dimension h in the depth), we concatenate them, and then we run them through W_O. What are we concerned with? We are concerned with optimizing this particular block here, the scaled dot-product attention.

So let's start our journey. One thing that is very important to understand is why we even need a better implementation of the attention mechanism. If you look at the Flash Attention paper, you will notice the following part. This is the Flash Attention 1 paper, and in it they describe the attention implementation as it's done naively when using PyTorch.

First we multiply the query by the transpose of the keys, then we apply the softmax to the output of this operation, and finally we multiply the output of the softmax with the V matrix to obtain the output of the attention. This is how the implementation is done by PyTorch without any optimization. First of all, all these tensors reside on the GPU, and the GPU is made up of two main memories. One is called the HBM, which is the DRAM, the RAM of the GPU, for example the 40 GB of an A100; it's the biggest memory that we have on the GPU. Then there is the shared memory. The problem with the GPU is that accessing the HBM, which is also called the global memory, is very, very slow compared to the shared memory; however, the shared memory is much, much smaller than the HBM.

What they claim in the Flash Attention paper is that the attention operation is IO-bound: computing the attention is slow not because the computations themselves are slow, but because we keep accessing the global memory, which is slow. We call these kinds of operations IO-bound. So the only way to improve this situation is to compute the attention inside the shared memory of the GPU, which is much smaller but much closer to the cores that actually do the computation. This means we will need to split the attention computation into smaller blocks that can reside in the shared memory, and we will see later how this is possible through block matrix multiplication. In the paper they call this tiling, and it's a very commonly used technique when writing GPU kernels, which usually involve some kind of matrix multiplication.
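The block (tiled) matrix multiplication mentioned here can be illustrated in plain Python. This is only a sketch of the blocking idea, not a GPU kernel: the tile size of 2 is an arbitrary assumption, and on a GPU each tile would be loaded into shared memory instead of being read from a Python list.

```python
def matmul_tiled(a, b, tile=2):
    """Compute C = A @ B one output tile at a time.

    Output tile C[I, J] is accumulated as the sum over blocks K of A[I, K] @ B[K, J],
    so each step only touches three tile-sized pieces of data.
    """
    n, inner, m = len(a), len(b), len(b[0])
    c = [[0.0] * m for _ in range(n)]
    for i0 in range(0, n, tile):              # block row of C
        for j0 in range(0, m, tile):          # block column of C
            for k0 in range(0, inner, tile):  # walk along the shared dimension
                # Multiply block A[i0:i0+tile, k0:k0+tile] by block B[k0:k0+tile, j0:j0+tile]
                # and add the result into block C[i0:i0+tile, j0:j0+tile].
                for i in range(i0, min(i0 + tile, n)):
                    for j in range(j0, min(j0 + tile, m)):
                        for t in range(k0, min(k0 + tile, inner)):
                            c[i][j] += a[i][t] * b[t][j]
    return c


A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
B = [[9, 8, 7], [6, 5, 4], [3, 2, 1]]
print(matmul_tiled(A, B))  # [[30.0, 24.0, 18.0], [84.0, 69.0, 54.0], [138.0, 114.0, 90.0]]
```

The `min(...)` bounds let the tile size not divide the matrix dimensions evenly, which is the kind of edge case a real kernel has to handle with masking.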

Now we know what problem Flash Attention is trying to solve. It avoids accessing the HBM (the high-bandwidth memory) while computing the attention, by copying only a part of each matrix into the local memory, the shared memory of the GPU that is closer to the cores, computing a part of the output matrix there, then copying that part to the output residing in the HBM, and repeating this for all the blocks into which we can divide the query, key, and value matrices. Later we will see how this blocked computation is done, but we will also see that the biggest problem in this blocked computation is the softmax, because the softmax needs to access the entire row of the S matrix.

This is because the softmax needs a normalization factor, which is the sum of the exponentials of all the values in the row to which it is applied, and we will see later how to solve this problem. So let's move on.

All right guys (and when I say guys I mean guys and girls; I usually just say guys, but please, girls, don't feel excluded). We saw that Flash Attention is only concerned with optimizing the softmax of the query multiplied by the transpose of the key, divided by the square root of d_k, multiplied by V, and we need to introduce a little bit of notation so that we don't get lost in the future slides. First of all, these are the formulas I took from the Flash Attention paper, but for now let's pretend Flash Attention never existed, so we are solving the problem step by step.

We should treat this Q as the input sequence after it has already passed through W_Q, K as something that has already passed through W_K, and V as something that has already passed through W_V, because we don't want to optimize the matrix multiplications; they are already fast enough. Another thing: let's talk about the dimensions of these matrices, so we can then understand the dimensions of the outputs of these operations. We will treat Q as a sequence of N tokens, where each token has d dimensions (lowercase d). Why? Because usually we take the queries and split them into multiple heads, so we pretend we have already done this splitting: we already took our input sequence, ran it through W_Q, and split it into multiple heads, and each of these heads performs the usual operation we already saw, the query multiplied by the transpose of the keys, and so on. Each head works with these dimensions for the query, the key, and the value sequences.

Let's look at the dimensions of the output. The first operation is the query multiplied by the transpose of the keys, where the key matrix is originally N by d and, once transposed, becomes d by N, so the result is an N by N matrix, because in a matrix multiplication the outer dimensions become the dimensions of the output matrix. What is the next operation? We take the output of the query multiplied by the transpose of the keys and run it through a softmax operation, which we will look at shortly. The softmax preserves the shape of the input: it doesn't change the shape of the matrix, it only changes its values. Then we take the output of the softmax and multiply it by V, which of course changes the shape: the P matrix is N by N and V is N by d, so the output is N by d, the outer dimensions of this matrix multiplication.

Now let's look at the details of each of these operations. When we multiply the query by the transpose of the keys, we get an N by N matrix where each value is the dot product of a row of Q and a column of K^T. In particular, the first element of this matrix is the dot product of the first query with the first key vector, the second element is the dot product of the first query with the second key vector, the third element is the first query with the third key, and so on, and the last row of this matrix contains the dot product of the last query with the first key, then the last query with the second key, the last query with the third key, and so on until the last query with the last key.

You may also notice that here I have written the query transposed times the key. What is q1? First of all, q1 is the first row of the query matrix. A little bit of background on matrix multiplication: each output element is the dot product of one row of the first matrix with one column of the second matrix. But we are multiplying the first matrix by the transpose of the second, so each element is the dot product of one row of the query matrix with one row of the key matrix, because we are multiplying by K transpose.
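The element-by-element description of S = QK^T can be checked with a tiny example. In this sketch the Q and K values are made up, the 1/sqrt(d_k) scaling is omitted, and `scores` is a hypothetical helper name; it only shows that entry (i, j) is the dot product of query row i with key row j.

```python
def dot(u, v):
    return sum(x * y for x, y in zip(u, v))


def scores(q, k):
    """S = Q K^T. Entry S[i][j] is the dot product of query row i with key row j,
    so K^T never has to be built explicitly."""
    return [[dot(q_i, k_j) for k_j in k] for q_i in q]


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # N = 3 queries, d = 2
K = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # N = 3 keys,    d = 2
S = scores(Q, K)
print(len(S), len(S[0]))  # 3 3 -> N x N
print(S[2])               # [3.0, 7.0, 11.0]: q3 with k1, k2, k3
```

Iterating over the rows of K directly is exactly why the "row of Q with row of K" reading works: the transpose is only notation.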

When you take a vector from a matrix, the usual convention in linear algebra is that a vector is a column vector. So we cannot just write q multiplied by k, because that would mean a matrix multiplication of a column vector with a column vector, which is not possible because the shapes do not match. As a notation, we write the first vector transposed, which turns the column vector into a row vector, multiplied by the second vector. This is just notation, guys: you just need to read it as the first query with the first key, then the first query with the second key, the first query with the third key, etc. We are doing dot products of vectors.

Then we apply the softmax operation. The softmax transforms each of these dot products, which are scalars (the output of a dot product is a scalar), in such a way that they become a probability distribution row-wise: each number is between 0 and 1, and the numbers in a row sum up to 1. This property holds for each row, so this row sums up to 1, this row sums up to 1, this row sums up to 1, etc.

Let's see what the softmax operation is. Given a vector, let's call it x, made up of n dimensions, the softmax transforms it into another vector with the same dimension, where the i-th element of the output is the exponential of the i-th input element divided by the sum of the exponentials of all the elements of the vector: softmax(x)_i = e^(x_i) / Σ_j e^(x_j).

The denominator is called the normalization factor, because it makes all these numbers lie between zero and one. We use the exponential because we want each of these numbers to be positive; we don't want the output of this operation to be negative. But there is a problem. Imagine our input vector is made up of numbers that are large, for example x1 = 100, x2 = 200, x3 = 300, which can happen. The exponential of 100 is a huge number, very close to infinity at least compared to what we can store in a computer: e^100 is about 2.7 × 10^43, which does not fit into a 32-bit floating-point number (the largest is about 3.4 × 10^38), a 16-bit floating-point number, or even a 32-bit integer. So we cannot compute it, because it overflows the variable that stores the output. In this case we talk about numerical instability: every time you hear the term numerical instability in computer science, it means that a number cannot be represented within the fixed representation given by the bits we have available, which are usually 32 or 16 bits. We also have 64 bits, but that would be too expensive to use.
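The overflow described here can be observed directly. This sketch relies on an assumption of my own about how to "store" a value in a narrower type: it uses Python's `struct` module with the standard-size formats `'<e'` (float16) and `'<f'` (float32), which raise `OverflowError` when a value is too large for the format. Python's own floats are 64-bit, so `math.exp(100)` itself succeeds.

```python
import math
import struct


def fits_in(value, fmt):
    """True if `value` can be stored in the given IEEE format without overflow.
    fmt: '<e' = float16, '<f' = float32 (standard-size struct formats)."""
    try:
        struct.pack(fmt, value)
        return True
    except OverflowError:
        return False


for x in (10.0, 100.0):
    e = math.exp(x)  # computed in Python's 64-bit float
    print(f"exp({x:g}) = {e:.3e}   float16 ok: {fits_in(e, '<e')}   "
          f"float32 ok: {fits_in(e, '<f')}")
```

With these inputs, e^10 fits in both formats while e^100 fits in neither. Even 64-bit floats give up eventually: `math.exp(1000)` raises `OverflowError`.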

So let's try to find a way to make this computable and numerically stable. To make the softmax numerically stable, meaning that we want these numbers neither to explode nor to become so small that they are not representable, we need a trick, and luckily it's quite easy. The softmax, as we have seen, is the following formula: each number is exponentiated and then divided by the normalization factor, which is just the sum of the exponentials of all the elements of the input vector. If we multiply the numerator and the denominator of a fraction by the same number, the fraction does not change, so that's what we do: we multiply the numerator and the denominator by a constant c, which must be positive because we will take its log. Then, using the distributive property of the product over the sum, we can bring this c inside the summation. We can also write every number as the exponential of its own log, because the exponential and the log cancel out. Then, using the property that the product of two exponentials is the exponential of the sum of their arguments, we combine them in the numerator and in the denominator. Finally, we give the quantity -log c a new name, k, that is k = -log c, so we can replace this quantity with k.
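The manipulation described in this paragraph, written out as equations (with c > 0 so that log c exists, and k = -log c):

```latex
\mathrm{softmax}(x)_i
  = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}
  = \frac{c\, e^{x_i}}{\sum_{j=1}^{n} c\, e^{x_j}}
  = \frac{e^{\log c}\, e^{x_i}}{\sum_{j=1}^{n} e^{\log c}\, e^{x_j}}
  = \frac{e^{x_i + \log c}}{\sum_{j=1}^{n} e^{x_j + \log c}}
  = \frac{e^{x_i - k}}{\sum_{j=1}^{n} e^{x_j - k}},
  \qquad k = -\log c .
```

Because c can be any positive number, k can be any real number, which is what makes it free to choose.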

We can do that because c is a constant that we have chosen, and we are just assigning it to another constant. By doing this derivation, we can see that we can sneak a value into the exponential that, if chosen carefully, reduces its argument. We will choose k equal to the maximum element of the input vector we are applying the softmax to, so that each argument is either zero, when x_i is the maximum element, or less than zero. The exponential of zero is one, and when the argument is negative the exponential lies between zero and one, which is easily representable in, for example, 32-bit floating point. So the exponential will not explode anymore. To apply the softmax to a vector in a numerically safe way, we need to find the constant k, the maximum value of the vector, and subtract it from each element before exponentiating.

So let's look at the algorithm to compute the softmax. We want to apply the softmax to this N by N matrix here, so we need to go through each row of the matrix, and for each row we need to find the maximum value among the elements, which takes time linear in the size of the row. Then we need to compute the normalization factor, this sum here, and we cannot compute it before step one, because we need the maximum element to compute this summation. After we have calculated the normalization factor, we can divide each element's exponential by it, and we cannot do step three before calculating the normalization factor, because we need to divide each number by it. If you like pseudocode for algorithms, this is the algorithm for computing the softmax that we have just seen.

First we find the maximum of the row to which we are applying the softmax, then we compute the normalization factor, and then we apply the softmax to each element, which means that we compute the exponential of each element minus the maximum value of the vector, divided by the normalization factor.
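The three-step pseudocode just described can be sketched as a plain Python function working on one row. The name `safe_softmax` is my own; the structure follows the three sequential passes from the text.

```python
import math


def safe_softmax(row):
    """Three sequential passes over `row`:
    1) find the maximum, 2) accumulate the normalization factor,
    3) normalize each element."""
    # Pass 1: maximum element (the constant k from the derivation).
    m = -math.inf
    for x in row:
        m = max(m, x)
    # Pass 2: normalization factor l = sum_j exp(x_j - m); needs m from pass 1.
    l = 0.0
    for x in row:
        l += math.exp(x - m)
    # Pass 3: outputs; needs l from pass 2.
    return [math.exp(x - m) / l for x in row]


print(safe_softmax([3, 2, 5, 1]))
print(safe_softmax([100, 200, 300]))  # no overflow: the largest exponent is e^0
```

Each pass depends on the result of the previous one, which is exactly why the three loops cannot be reordered or run at the same time.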

This pseudocode describes an algorithm that is quite slow. To see why, look at a practical example: imagine we have this vector here.

First we need to do step one, find the maximum value in this vector, which is 5, and this takes linear time. Then we need to calculate the normalization constant, which is the sum of the exponentials of each element minus the maximum value, so e^(3-5) + e^(2-5) + etc., and we will call it l. Then we need to go through this vector again and take the exponential of each element minus the maximum, divided by the normalization factor. So to apply the softmax to an N by N matrix, we need to go through each element of this matrix three times, and these operations must be done sequentially: we cannot start operation two until we have done operation one, and we cannot start operation three until we have done one and two. This is quite slow just to apply an operation that doesn't even change the shape of the matrix; it only normalizes the values. So there must be a better way that does not involve three sequential operations in which we go through this matrix three times. Let's see.

All right guys, let's rehearse the problem that we are trying to solve. The problem statement is the following: can we find a better way to compute the softmax that does not involve going through the vector three times? Let's look at the pseudocode of the algorithm for computing the softmax that we have found so far.

Imagine we have a vector made up of four elements. The first thing we need to do is compute the maximum element of this vector, which means going through this for loop here: we start from the left side of the vector and move iteratively to the right side, from the first element to the end, comparing the previously found maximum with the current element to find the global maximum. I know this is very simple, and I'm pretty sure you don't need this example, but working through it will help us understand what we do next, so please bear with me even if it's super simple.

At the beginning, m0 is equal to minus infinity. m1 is the result of iteration number one of the for loop: it is the maximum of the previous estimate, minus infinity, and the current element, 3, so it becomes 3. Then m2 is the maximum of the previously computed maximum, m1 = 3, and the current element, 2, so it is 3. m3 is the maximum of the previously computed maximum, 3, and the current element, 5, so it is 5, and m4 is the maximum of the previously computed maximum and the current element, 1, so it is 5. So at the fourth iteration we have the global maximum, whatever the input array is.

Okay, after we have computed the maximum, which we know is 5, we can compute the normalization factor. Let's start with l0, which is equal to 0. l1 is equal to l0 plus the exponential of the current element, 3, minus the maximum element we found in the previous for loop, 5. Then l2 is equal to l1 plus the exponential of the current element, 2, minus the maximum. Then l3 is equal to l2 plus the exponential of 5 - 5, and l4 is equal to l3 + e^(1-5). If you expand it, l4 = e^(3-5) + e^(2-5) + e^(5-5) + e^(1-5).

After we have computed this normalization factor, we can use it to normalize each element in the input vector, which means that the new x1, x1', is equal to e^(3-5) divided by the l computed in the previous for loop, the l at the fourth iteration. The new x2, x2', is equal to e^(2-5) / l4, x3' is equal to e^(5-5) / l4, and so on for all the elements. I know this is super simple, but it will help us later. So in this algorithm we need to go through the vector three times, because first we need to run this for loop, then this one, and then another one. We cannot run them in a different order, because the second for loop needs the maximum element, and we cannot run the third for loop until we have finished the previous one, because we need the normalization factor.
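The walkthrough above on the vector [3, 2, 5, 1] can be reproduced by recording the running values of the first two loops. The function name `running_values` is hypothetical; it simply stores m1..m4 and l1..l4 so they can be compared with the numbers in the text.

```python
import math


def running_values(row):
    """Return [m_1..m_n] from the max loop and [l_1..l_n] from the sum loop.
    The sum loop can only start once the max loop has finished (it uses m_n)."""
    m = -math.inf                # m_0
    m_hist = []
    for x in row:
        m = max(m, x)            # m_i = max(m_{i-1}, x_i)
        m_hist.append(m)
    l = 0.0                      # l_0
    l_hist = []
    for x in row:
        l += math.exp(x - m)     # l_i = l_{i-1} + e^(x_i - m_n)
        l_hist.append(l)
    return m_hist, l_hist


m_hist, l_hist = running_values([3, 2, 5, 1])
print(m_hist)  # [3, 3, 5, 5]
print(l_hist)
```

The last entry of `l_hist` is the expanded sum e^(3-5) + e^(2-5) + e^(5-5) + e^(1-5) from the text.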

However, we are stubborn, so let's try to fuse the first two operations into one for loop: we go through the array and, in the same iteration, compute both m_i and l_i. Of course we will not be able to compute l_i exactly, because we don't have the global maximum yet, since we haven't gone through the whole array. However, let's use whatever estimate of the maximum we have so far: instead of the global maximum m_n, let's use m_i, the local maximum that we have computed so far.

if we apply the soft Max in this way in

this fused way to this Vector we will

have the following

iterations um so this is our array or

vector and the first step is MI so M1

will be equal to the previous maximum

which is minus infinity with the current

element so the maximum minus infinity

and the current element is equal to

three and L1 will be equal to the

previous l so l0 which is starts from

zero plus e to the power of the current

element minus we should be using the

global maximum but we don't have the

global maximum so let's use the whatever

maximum we have so far so we can use

three now at the second iteration we are

at this element of the vector and we

comput the maximum so far so the maximum

so far is the previous maximum and the

current element so the maximum of the

previous maximum and the current element

which is the maximum between three and

two which is three uh and the the

normalization factor is the previous

normalization Factor plus exponential of

2 - 3 which is the current element minus

whatever maximum we have so far now if

our array were made only of these two

elements so three and two then whatever

we have computed is actually correct

because the maximum that we have found

is a three and it's actually the global

maximum and the um normalization factor

that we have computed is actually

correct because each of the exponent

itial has been computed with the global

maximum because the first element was

computed using three as the uh with the

argument minus three and also the second

element was computed with the argument

with the argument having minus three in

the in the argument which is the global

maximum of the vector however when we

arrive at the third iteration so let me

delete this Vector so we arrive here at

the third iteration the maximum will

change which will also uh cause our

normalization factor to get to to be

wrong because we arrive at the element

number three uh so the number five here

and we computed the maximum so the

maximum is the comparison of the

previous maximum and the current element

so the new maximum becomes five and the

normalization factor is the previous

normalization Factor so L2 plus the

exponential of the current element minus

the current estimate of the maximum

which is five however

if you look at this L3 this is wrong why

because L3 is equal to if you expand

this summation it will be equal to e to

the power of 3 - 3 + e to power of 2 - 3

+ e to the power of 5 minus 5 this

exponential here is using five as the

global maximum this exponential here is

using three as the global maximum and

this one is using three as the global

maximum so the first two element have

been computed

thinking that the global maximum is

three but actually we later we found a

better Global maximum which is five so

which makes this normalization Factor

wrong however can we fix at the third

iteration whatever normalization we have

computed so far up to the second

iteration actually we can because if we

expand this so as we have here we have

expanded

it what we need here is here to have

minus5 because that's actually the

Global maximum that we have found so far

not the minus three that we had at the

previous iteration so and here we also

need to fix this replace this -3 with

minus5 how can we do that well if we

multiply this one here and this one here

with a correction factor that will sneak

in a new maximum inside of this

exponential then we solve the problem

and actually this correction factor is

very easy to calculate because at the

third iteration if we multiply L2 so the

previous prly computed normalization

factor with this Factor here which is

the exponential of the previous estimate

of the maximum minus the current

estimate of the maximum so five we will

see that by the properties of the

exponentials this one here will become e

to the^ of 3 - 3 + 3 - 5 so this minus 3

will cancel out with this three and also

the second Factor will have this three

will cancel out with this minus three

will cancel out with this three and they

will become e to ^ of 3 - 5 and 2 to the

power of u e to the^ of 2 - 5 which is

actually correct because at the third

iteration we should be actually have we

should be using minus5 as the maximum of

the array so far um so basically what we

have found is a way to fix whatever

normalization Factor we have computed so

far while iterating through the array

when we found we when we find a better

maximum compared to what we have so far

and when we don't need to fix anything

then the formula still stands because

what we did here as a multiplic as a

correction factor so this is the

correction factor this correction factor

is nothing more than the previous um

previous maximum so the previous

estimate of the maximum minus the

current estimates of the maximum at the

current iteration so the current

Max um so this is basically M of IUS one

and this is M of I so the current

maximum at the current iteration and let

me delete it otherwise it remains

forever in my slides um so basically

when we arrive to the last element we

will see that the maximum doesn't change

because we compare the previous maximum

with the current element which is less

than the previous maximum so the maximum

doesn't change and we don't need to fix

anything because the the the the

previous L3 so the previously computed

uh normalization factor is correct

because they have all been using the

minus5 so when we don't need to fix

anything we just multiply by e to the

power of the previous maximum minus the

current maximum which is e to the power

of zero in this case so it's not fixing

anything so we have found a way to fix

the previously computed normalization

Factor while going through the array

even if at the current iteration we

don't have the global maximum yet so

that every time the maximum changes we

can fix and every time it doesn't change

we just multiply with e to the^ of Z

which is like multiplying with one so
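The fused loop with the correction factor can be sketched in plain Python. This is a minimal illustration, assuming the input is a non-empty list of finite floats; the function names are hypothetical and this is not the Triton kernel written later. `fused_without_correction` is the "stubborn" first attempt that uses the local maximum but never rescales the running sum, while `online_max_and_norm` multiplies the previous sum by e^(m_{i-1} - m_i) at every step.

```python
import math


def fused_without_correction(x):
    """First attempt: running max m_i, but the old sum is never rescaled (wrong in general)."""
    m, l = -math.inf, 0.0                  # m_0 = -inf, l_0 = 0
    for xi in x:
        m = max(m, xi)
        l += math.exp(xi - m)              # earlier terms keep their stale maximum
    return m, l


def online_max_and_norm(x):
    """One pass: running max m_i and running normalization factor l_i with correction."""
    m, l = -math.inf, 0.0                  # m_0 = -inf, l_0 = 0
    for xi in x:
        m_new = max(m, xi)
        # Correction factor exp(m_{i-1} - m_i): 1.0 when the max is unchanged,
        # 0.0 on the first step because exp(-inf) == 0.0 (so l_0 * 0 = 0).
        l = l * math.exp(m - m_new) + math.exp(xi - m_new)
        m = m_new
    return m, l


def softmax_online(x):
    """Softmax in two passes: fused max/normalizer pass, then one normalization pass."""
    m, l = online_max_and_norm(x)
    return [math.exp(xi - m) / l for xi in x]


x = [3.0, 2.0, 5.0, 1.0]
print(online_max_and_norm(x))        # m_4 = 5.0, l_4 = e^-2 + e^-3 + e^0 + e^-4
print(fused_without_correction(x))   # same max, but l = e^0 + e^-1 + e^0 + e^-4 (wrong)
```

Only the first function needs the corrected sum; the final normalization pass is unchanged from the three-pass version.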

So the new algorithm we have found for the softmax is the following. We start with m_0 = -infinity and l_0 = 0. We go through the array and compute the local maximum m_i, that is, the maximum from the first element up to the i-th element, the one at which we are in the iteration. The previously computed l_{i-1} is fixed with the correction factor e^(m_{i-1} - m_i), and then we add the exponential of the current element minus the current estimate of the maximum:

l_i = l_{i-1} * e^(m_{i-1} - m_i) + e^(x_i - m_i)

In this way we go through the array only once and obtain two values at the same time: the global maximum at the end, and the normalization factor. Then we can use them to compute the softmax. So we transformed the three passes through the array into two passes, and this is very important: we will see how we actually use it to derive Flash Attention.

The example I have given you so far is not really a proof that our algorithm works in every case, because it was a very simple example with a vector of four elements. Does our new algorithm work in every single case, whatever the numbers are? We need to prove that, and we will do it by induction.

First of all, what are we trying to prove? We have fused the first two for loops into one. What we expect is that at the end of this for loop, m_N, the m at the last iteration, is actually the global maximum of the vector, and l_N, the l at the last iteration, is equal to the sum over all elements of e^(x_j - m_N), where m_N is the global maximum. What I did before was only an example, not a rigorous proof, and we will prove it by induction, which is a typical way of proving this kind of theorem.

Proof by induction works in the following way. We prove that our algorithm works for a base case, for example N = 1. Then we assume that the algorithm works for N, and we prove that it also works for N + 1. If this holds, then we have proven our algorithm for every possible N: it works for the base case N = 1, and by the induction step, if it works for N then it works for N + 1, so it also works for 2. If it works for 2, it works for 3; if it works for 3, it works for 4; and so on, up to infinity.

Let's prove the base case, N = 1, which is very simple. With N = 1, the for loop has only one iteration, producing m_1 and l_1. m_1 is the maximum of the previous m, which is minus infinity because we initialize m_0 = -infinity, and the current element x_1, so it is equal to x_1, whatever x_1 is. x_1 can never be minus infinity, because it is a finite number. Since we have only one element, m_1 is also the last m of the for loop, and it equals the global maximum of a vector made of only one element. l_1 is equal to the previous l, l_0 = 0, multiplied by the correction factor, which in this case is e^(-infinity): the correction factor uses the previous estimate of the max minus the current estimate of the max, and -infinity - x_1 = -infinity, so this term vanishes. Then we add e^(x_1 - m_1) = e^(x_1 - x_1). This is exactly the sum of the exponentials of all the elements of the vector, which has only one element, minus the maximum element of the array, x_1. So we have proven that it works for N = 1.

Now we assume that it works for N. Does it also work for a vector of size N + 1? Let's see what happens at iteration N + 1. We take the maximum of the previous estimate m_N and the current element x_{N+1}. By the properties of the max function, this is the maximum of the whole vector up to element N + 1, because the max picks whichever is larger between the previous estimate and the current element. The normalization factor at iteration N + 1 is l_N, the normalization factor at iteration N, multiplied by the correction factor e^(m_N - m_{N+1}), plus the exponential of the current element minus the current estimate of the maximum, e^(x_{N+1} - m_{N+1}).

But we assumed that the algorithm works up to N, so l_N is for sure equal to the sum, for j from 1 to N, of e^(x_j - m_N), where m_N is the local maximum of the vector up to the N-th element. We multiply it by the correction factor, if there is something to correct, and add the exponential of the current element minus the current estimate of the maximum. By the properties of exponentials, we can bring the correction factor inside the summation: each term becomes e^(x_j - m_N + m_N - m_{N+1}), so the m_N terms cancel out and we obtain the sum, for j from 1 to N, of e^(x_j - m_{N+1}), plus the term e^(x_{N+1} - m_{N+1}), which remains unchanged. This extra term is exactly the argument of the summation evaluated at j = N + 1, so we can extend the upper index of the summation by one and get the same result. So we have proven that at iteration N + 1, l is equal to the sum of the exponentials of all the elements of the array up to element N + 1, each minus the maximum up to element N + 1. If it works for N, it also works for N + 1, and this is enough to prove that it works for arrays of any size.

Don't worry if you didn't get the proof by induction: if it's the first time you see this kind of proof, it may take a little while to get it. If you want to learn a bit more about proof by induction, I recommend looking at a few other proofs. It's very simple; you just need to get into the right mindset. Anyway, let's move forward.
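For reference, here is the same induction written out as equations, using the running maximum m_i and running normalization factor l_i defined above, with m_0 = -\infty and l_0 = 0. It is only a compact restatement of the spoken argument:

```latex
\begin{aligned}
&\text{Base case } (N = 1): \\
&\quad m_1 = \max(-\infty, x_1) = x_1, \qquad
  l_1 = 0 \cdot e^{-\infty - x_1} + e^{x_1 - m_1} = \sum_{j=1}^{1} e^{x_j - m_1}. \\[4pt]
&\text{Inductive step (assume } m_N = \max_{j \le N} x_j,\; l_N = \textstyle\sum_{j=1}^{N} e^{x_j - m_N}\text{):} \\
&\quad m_{N+1} = \max(m_N, x_{N+1}) = \max_{j \le N+1} x_j, \\
&\quad l_{N+1} = l_N\, e^{m_N - m_{N+1}} + e^{x_{N+1} - m_{N+1}}
  = \sum_{j=1}^{N} e^{x_j - m_N + m_N - m_{N+1}} + e^{x_{N+1} - m_{N+1}}
  = \sum_{j=1}^{N+1} e^{x_j - m_{N+1}}.
\end{aligned}
```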

All right, let's talk about block matrix multiplication. I know that you want to jump to the code immediately, and we will get there; we just need a little more theory. Imagine we are doing a matrix multiplication: we have a matrix A that we want to multiply by a matrix B, producing an output matrix C. If the first matrix is M by K and the second is K by N, the output matrix is M by N.

Now imagine we want to parallelize the computation of this output matrix. I haven't talked about GPUs yet, so let's not talk about GPUs; let's talk about parallelization on a multicore CPU, which you are very probably familiar with. Nowadays, when you buy a computer, you can get a single-core CPU or a multicore one, with two, four, eight cores, and so on. Each of these cores is a kind of small CPU inside your CPU that can execute operations in parallel.

How do we parallelize the matrix multiplication? Each output element of the C matrix is the dot product of a row of A with a column of B. For example, the top-left element is the dot product of the first row of A and the first column of B; the top-right element of C is the dot product of the first row of A and the last column of B; the bottom-left element is the dot product of the last row of A and the first column of B; and so on for all the other elements. To fully parallelize this computation, we would need as many cores as there are elements in C. If M and N are very small, maybe we have enough cores, but imagine M and N are quite big, say 100 by 100: we don't have 10,000 cores in today's CPUs. So how can we parallelize a matrix operation using fewer cores than there are elements in the matrix itself? That's where block matrix multiplication comes in.

Block matrix multiplication means that you can divide the original matrices into smaller blocks of elements, and the matrix multiplication can then be computed between these blocks. For example, imagine a matrix A that is 8 x 4, meaning eight rows and four columns, so 32 elements, multiplied by another matrix B that is 4 x 8, so four rows and eight columns, also 32 elements. The output matrix should have 64 elements. We don't have 64 cores, so how can we parallelize it? Imagine we only have eight cores. We can divide the original matrix A into four blocks of 4 x 2 elements: eight elements on the top left, eight on the top right, eight on the bottom left and eight on the bottom right. Then we divide the B matrix into eight blocks of 2 x 2, so four elements each: B11 is the top-left four elements of the original matrix, B14 is the top-right four elements, B21 is the bottom-left four elements, and so on.

How do we do this block matrix multiplication? We can view each of these matrices as made up only of its blocks. The output is then computed in the same way as with the original matrices, except that each dot product produces not a single element of the output matrix but a block of its elements. For example, the top-left block of the output is the dot product of the first block row of A with the first block column of B, computed as A11 * B11 + A12 * B21. This output is not a single scalar but a block. How do we find its dimensions? A11 is 4 x 2, a smaller matrix of eight elements arranged in four rows and two columns, and we multiply it by B11, which is 2 x 2, so four elements. Multiplying a 4 x 2 by a 2 x 2 produces a 4 x 2 output block, so eight elements.

So if we do the computation block by block, each step produces a block of output elements of the original matrix, not a single scalar, which makes it very easy to parallelize. The 8 x 8 output is made of 2 x 4 = eight such 4 x 2 blocks, so with eight cores we can assign each output block to one core, and each core will produce not one output element but eight elements of the original matrix, as a 4 x 2 matrix.

So block matrices allow us to do the matrix multiplication either element by element, as with the original matrices (each row with each column), or block by block in the same way, because the multiplication between blocks follows exactly the same procedure as the multiplication of the original matrices; it just produces a block instead of a scalar. Now let's see why this is very important for us.
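Before moving on, the block-by-block procedure above can be sketched in plain Python. This is a minimal, CPU-only illustration, assuming the block sizes evenly divide the matrix dimensions; the helper names (`matmul`, `add`, `get_block`, `block_matmul`) are hypothetical. Each output block C_ij = sum_p A_ip B_pj is computed independently of the others, which is the unit of work we could hand to one core.

```python
def matmul(a, b):
    """Plain triple-loop matrix product of two lists of lists."""
    n, k, m = len(a), len(b), len(b[0])
    assert all(len(row) == k for row in a), "inner dimensions must match"
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)] for i in range(n)]


def add(x, y):
    """Element-wise sum of two equally shaped matrices."""
    return [[xi + yi for xi, yi in zip(rx, ry)] for rx, ry in zip(x, y)]


def get_block(mat, bi, bj, rows, cols):
    """Extract block (bi, bj), 0-indexed in block coordinates, of size rows x cols."""
    return [row[bj * cols:(bj + 1) * cols] for row in mat[bi * rows:(bi + 1) * rows]]


def block_matmul(a, b, br, bk, bc):
    """C = A @ B computed block by block.

    A is split into br x bk blocks and B into bk x bc blocks, so every output
    block C_ij = sum_p A_ip @ B_pj is a br x bc matrix that one core could own.
    """
    n, k, m = len(a), len(b), len(b[0])
    c = [[0] * m for _ in range(n)]
    for bi in range(n // br):              # block rows of C
        for bj in range(m // bc):          # block columns of C
            acc = [[0] * bc for _ in range(br)]
            for bp in range(k // bk):      # the "dot product" over blocks
                acc = add(acc, matmul(get_block(a, bi, bp, br, bk),
                                      get_block(b, bp, bj, bk, bc)))
            for r in range(br):
                c[bi * br + r][bj * bc:(bj + 1) * bc] = acc[r]
    return c


# The example from the text: A is 8x4 (four 4x2 blocks), B is 4x8 (eight 2x2 blocks).
A = [[i * 4 + j for j in range(4)] for i in range(8)]
B = [[(i + j) % 5 for j in range(8)] for i in range(4)]
C = block_matmul(A, B, br=4, bk=2, bc=2)
```

With these block sizes the two outer loops visit 2 x 4 = 8 output blocks, matching the eight-core example, and each block comes out of A_i1 B_1j + A_i2 B_2j, just like C11 = A11 * B11 + A12 * B21.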

So why should we care about block matrix multiplication? Because we are trying to compute the following operation: the query multiplied by the transpose of the keys, then the softmax of the result, and then the output of the softmax multiplied by V. For now, let's ignore the softmax and pretend that we are not going to apply it: we take the output of the query multiplied by the transpose of the keys and multiply it directly by V to obtain the output of the attention. This is wrong, of course, but it simplifies what we are going to do next. The result is a matrix that is N by d, so N tokens, each made up of an embedding of d dimensions (lowercase d), and we know that the query, key and value matrices are themselves N by d.

Imagine we have query, key and value matrices that are 8 by 128, so eight tokens, each made up of 128 dimensions. As we have seen, when we compute a matrix multiplication we can divide our matrices into blocks. How we choose the blocks is up to us, as long as the shapes of the blocks match in the matrix multiplication. For example, in the previous case we divided matrix A into blocks such that the shape of its block matrix, the matrix made up only of the blocks, was compatible with the block matrix of B, so that the operation was possible. This is the only requirement: the shapes of the block matrices must match in the matrix multiplication. Apart from that, it doesn't matter how we divide them.

So imagine that we divide the query matrix into blocks of rows. We don't necessarily have to divide the columns too; we can divide only the rows, so that each Q is not a single row but a group of two rows: Q1 is the first two rows of the Q sequence, Q2 is the next two rows, and so on. We do the same for V. For K we don't, because we are actually going to multiply by K transposed, so we do the subdivision directly on K transposed. So we have Q divided into groups of rows, and K transposed, which is a 128 x 8 matrix because it is the transpose of the keys, which are 8 x 128. We divide K transposed into groups of columns: K1 is the first two columns of K transposed, K2 is the second group of two columns, and so on until K4, the last two columns.

The first operation is the query multiplied by the transpose of the keys, which means multiplying the first query with all the keys, then the second query with all the keys, and so on. Now each query is not a single row of the Q sequence but a group of two rows, and each K is not a single column of K transposed but a group of two columns. It doesn't matter: as we have seen, if we write the matrices as made up of blocks, we compute the product exactly like a normal matrix multiplication. Viewed as a block matrix, Q has four block rows, each spanning all 128 dimensions, and K transposed has 128 rows and four block columns. (I didn't draw all the dimensions because there are too many, but you need to pretend they are there: 128 for each vector.) The output of this multiplication is 4 x 4 blocks, because those are the outer dimensions of the two block matrices being multiplied. The first output element is the dot product of the first row with the first column, the second is the first row with the second column, and so on. However, these are not vectors but blocks, so each of these is actually a matrix multiplication, and each output element is not a scalar but a group of elements of the output matrix. How many elements? Q1 is 2 x 128 and K1 is 128 x 2, so each is a 2 x 2 group of elements of the output matrix. So in the first row we multiply Q1 with K1, then Q1 with K2, Q1 with K3 and Q1 with K4; the second row is Q2 with all the Ks, then Q3 with all the Ks and Q4 with all the Ks. As you can see, when we do matrix multiplication we don't even care whether the underlying element is a block, a vector or a scalar: we apply the same procedure, the first row of the block matrix with the first column of the second matrix, then the first row with the second column, the first row with the third column, and so on.

Then, because the formula says that after multiplying the query by the transpose of the keys we multiply by V, let's do that; all of these are block matrices. (As you can see from my use of colors, every time I refer to the original matrix I use blue, and every time I refer to the block matrix I use pink.) We are skipping the softmax for now, and later we will see why. Block matrix multiplication ignores the fact that the matrix is made up of blocks and just does a normal matrix multiplication, row by column. V, split into groups of rows, has a single block column made up of V1, V2, V3 and V4. So the first output block is the first row of the previous result multiplied with that column, in a dot-product fashion (it's really a matrix multiplication): the first element with V1, plus the second element with V2, plus the third element with V3, plus the fourth element with V4. The second output block is the second row with the same column, again element by element with V1 through V4, and the same goes for the third and fourth output blocks.

Let's look at what each block is made of. The first element of the row is Q1 multiplied by K1 transposed, because it is the result of the query multiplied by the keys, and it is multiplied by V1 of the second matrix; then the next element with V2, and so on. So the first output block is Q1 multiplied by K1 with the result multiplied by V1, plus Q1 with K2 multiplied by V2, plus Q1 with K3 multiplied by V3, plus Q1 with K4 multiplied by V4. This is the dot product of this row with this column, both made up of blocks.

So the pseudocode for generating this output (which is not really the attention mechanism, because we skip the softmax, but I want you to get into the habit of thinking in terms of blocks) is the following. To generate the first row, we take query block number one, iterate through the key and value blocks from one to four, and sum iteratively. Each row of the output is a different query block with all the keys and values: the second row is query block two with all the keys and values, the third is query block three, and the fourth is query block four. So to generate this output matrix, we iterate through the query blocks, each giving one row of blocks of the output, and for each one we compute the iterative sum of query block i multiplied by the j-th K and then by the j-th V, that is, O_i = sum over j of (Q_i K_j^T) V_j. That produces the output matrix. I know that what I have done so far is not directly useful for Flash Attention, but it's useful to get into the mindset of computing this product by blocks, because later we will also use it with the softmax.

All right guys, I know that what we have computed so far is not really the attention mechanism, because we have skipped the softmax, so somehow we need to restore it. The next 10 to 20 minutes are going to be really challenging, because I am going to do a lot of operations involving many different blocks, many matrix multiplications and variants of the softmax, so it may be difficult to follow. Don't give up: you can watch this part two or three times, and every time you will have a better understanding. I also recommend watching it until we reach the Flash Attention algorithm before going back to rewatch it, because reaching the algorithm will give you a better understanding of what has happened so far, and then you can rewatch it to deepen your understanding. Another thing I recommend is to take pen and paper, write down exactly the operations you are seeing, and write the shapes of each of the blocks involved in these matrix multiplications, so that you better understand what is happening and better remember what I mean when I refer to a particular element or a particular block.
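Going back to the softmax-free version of attention described earlier, the block-wise loop O_i = sum_j (Q_i K_j^T) V_j can be sketched in plain Python. This is only a sketch under stated assumptions: Q, K and V are N x d lists of lists, the block size divides N, the helper names are hypothetical, and the softmax is deliberately omitted, so this is not attention yet.

```python
def matmul(a, b):
    """Matrix product of two lists of lists (rows of a times columns of b)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def add(x, y):
    return [[p + q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]


def blockwise_qktv(q, k, v, block):
    """(Q K^T) V computed one query block at a time; the softmax is deliberately skipped."""
    n, d = len(q), len(q[0])
    out = []
    for i in range(0, n, block):
        q_i = q[i:i + block]                       # block x d rows of Q
        o_i = [[0] * d for _ in range(len(q_i))]   # running sum for this output block
        for j in range(0, n, block):
            k_j = k[j:j + block]                   # rows of K = columns of K^T
            v_j = v[j:j + block]                   # block x d rows of V
            s_ij = matmul(q_i, transpose(k_j))     # the block S_ij = Q_i K_j^T (block x block)
            o_i = add(o_i, matmul(s_ij, v_j))      # O_i += S_ij V_j
        out.extend(o_i)
    return out


# Toy integer inputs matching the example: N = 8 tokens, d = 128, blocks of 2 rows.
N, D = 8, 128
Q = [[(i + 2 * c) % 5 - 2 for c in range(D)] for i in range(N)]
K = [[(3 * i + c) % 7 - 3 for c in range(D)] for i in range(N)]
V = [[(i * c) % 4 - 1 for c in range(D)] for i in range(N)]
O = blockwise_qktv(Q, K, V, block=2)
```

Integer inputs keep the comparison with the unblocked product exact; the inner loop over j is the "iterative sum" from the pseudocode.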

okay after giving this small uh

motivational speech Let's Start so what

we have done so far was query multipli

by the transpose of the keys however

each query is not a single row of the

query sequence but it's a block of

queries it's a block of rows in our

particular case this q1 is not one row

of the query sequence it's two rows of

the query sequence because we have

chosen as a block size a group of two

rows and this K transposed one is not

one column of the K transposed Matrix is

two columns of the K transposed Matrix

because we have chosen it like this and

if you don't remember let's go back to

see it uh here we have chosen K1 is 2

two columns and q1 is two rows of the

query original Matrix and every time I

use the blue color I am referring to the

original shape and every time I'm using

the pink or Violet whatever it is I am

referring to the block metrix so it's a

block of elements of the original

Matrix okay now uh the first thing that

we did was multiply the query by the transpose of the keys, and this produces as output a block matrix that we will call S, where each element S_ij is the product of query block i with transposed key block j. So the S11 element of this matrix is query one with K transposed one, S12 is query one with K transposed two, S13 is query one with K transposed three, etc., for all the rows and all the columns. Then we should apply the softmax because, if you remember, the formula is the softmax of the query multiplied by the transpose of the keys. However, I want to redefine the softmax operation with a twist: we will apply a simplified version of the softmax, which we will call softmax star, and which is just the softmax without the normalization. Let me write down for you what it

means uh let's do it with the same color

that I chose for the softmax which is

orange so the soft max if you remember

correctly if we remember it's the soft

Max

of to of a vector we apply it element

wise so each element is modified

according to the following formula so

the E element of the output Vector to

which we are applying the softmax is

equal to the

exponential of the E element of the

input Vector minus the maximum element

in the input Vector divided by a

normalization factor that is calculated

according to this

summation that is going from Jal to 1 up

to n of the

exponential of x i minus x max so

basically uh we are doing the

exponential of each element minus this x

Max and why are if you remember

correctly why are we subtracting this x

max to make this exponential numerically

stable computable because otherwise it

will explode and because we are applying

it to the numerator we also need to

apply it to the denominator okay the

softmax star operation is exactly like the softmax but without the normalization part, which means that it's just the numerator of the softmax. So we modify each element of the vector to which we apply softmax star according to this formula (let me move it so it's better aligned, like this): it's just an element-wise operation, the exponential of each element minus the maximum of the vector to which we are applying softmax star:

softmax*(x)_i = exp(x_i - x_max)
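A minimal pure-Python sketch of the two operations, assuming the input is a plain list of floats; the function names `softmax` and `softmax_star` are only illustrative. It shows that dividing softmax star by its own sum recovers the real softmax, and that subtracting the maximum keeps large inputs computable (in Python, `math.exp(1000.0)` on its own raises `OverflowError`).

```python
import math

def softmax(x):
    """Safe softmax: exp(x_i - x_max) / sum_j exp(x_j - x_max)."""
    x_max = max(x)
    exps = [math.exp(v - x_max) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_star(x):
    """softmax*: only the numerator of the safe softmax, no normalization."""
    x_max = max(x)
    return [math.exp(v - x_max) for v in x]

x = [2.0, 1.0, 0.5, 3.0]
star = softmax_star(x)
normalizer = sum(star)
recovered = [s / normalizer for s in star]  # softmax* divided by its sum == softmax

big = [1000.0, 999.0]       # exp(1000.0) by itself would overflow
print(softmax_star(big))    # first entry is exp(0) = 1.0
print(softmax(big))
```

The normalization is just one division per element, which is why it can be postponed to the very end, as the derivation does later.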

Okay, now why did I introduce this softmax star operation? Because we will apply it to the matrix we have computed so far, this S matrix. So we apply softmax star to each element of this S matrix, but each element of S is itself a matrix, because S is a block matrix. For example, the element S11 is a 2x2 matrix, because it comes from the product of two matrices: a group of rows from Q and a group of columns from K transposed. Let's actually draw it. S11 will be made up of four elements; let's give them generic names: a, b, c and d. When we apply softmax star to this S11, it will result in a

matrix in which each element is the exponential of that element minus the maximum of its row. We don't know which one is the maximum, so let's choose: suppose that the maximum of the top row is a and the maximum of the bottom row is d. Then the first element of the output of softmax star applied to the block S11 will be the exponential of a minus a, because that's what we chose as the maximum for this row, and the second element will be the exponential of b minus a, because a is the maximum for that row. In the bottom row, it will be the exponential of c minus d, because d is the maximum for the bottom row, and then the exponential of d minus d. That's how softmax star modifies each block in this block matrix. Let me delete this stuff, otherwise it will remain in my slides forever, and later I want to share the slides with you guys so you can use the same slides. Okay.
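Here is a small pure-Python sketch of softmax star applied to one 2x2 block S11 = Q1 K1 transposed, using each row's local maximum exactly as in the a, b, c, d example. The numbers are made up so that the top row's maximum is its first element (like a) and the bottom row's maximum is its second element (like d); a head dimension of 3 is assumed only to keep the example small, and the helper names are hypothetical.

```python
import math

def matmul(A, B):
    """Plain matrix product of nested lists: (n x k) @ (k x m)."""
    return [[sum(A[i][t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax_star_rows(S_block):
    """softmax* on each row of a block, using that row's LOCAL max.
    Returns the block P and the per-row maxima m."""
    P, m = [], []
    for row in S_block:
        row_max = max(row)
        m.append(row_max)
        P.append([math.exp(v - row_max) for v in row])
    return P, m

Q1 = [[2.0, 0.0, 1.0], [0.0, 1.0, 1.0]]  # block of 2 query rows
K1 = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]  # block of 2 key rows
S11 = matmul(Q1, transpose(K1))          # 2 x 2: [[a, b], [c, d]] = [[3, 1], [1, 2]]
P11, m11 = softmax_star_rows(S11)        # [[exp(a-a), exp(b-a)], [exp(c-d), exp(d-d)]]
print(S11, m11, P11)
```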

After we have applied softmax star to each of the elements of this S matrix, we call the result the P matrix, and each element, like P11, will again be a block of 2x2 elements. So P11 is softmax star applied to S11, where S11 is query one times K transposed one; P12 is softmax star applied to S12, where S12 is query one multiplied by K transposed two; etc., for all the elements of S. Okay, now that we have applied this

softmax star operation, the next operation we should do, according to the formula of attention (the softmax of the query multiplied by the transpose of the keys, and then the result of the softmax multiplied by V), is the product with V. I know that we didn't apply the real softmax: we applied softmax star, which is the softmax without the normalization. Later we will see how to compensate for this lack of normalization, because we will do it at the end, and that's something we are allowed to do. Okay, so we take this P matrix, which is the result of softmax star applied to the S matrix, and we multiply it by

V. How do we do it? Well, it's a matrix made up of blocks of matrices: P11 is actually not a scalar but a 2x2 matrix, and we need to multiply it by V. But we don't multiply by the original sequence V; we multiply by the blocked sequence V where, just like before, each V is not one row of V but a group of rows of V. How many rows? Two rows of V. For now, please completely ignore whatever I have written here, because we will use it later. So we need to compute the product of this matrix, which, remember, is made up of blocks, with this other matrix, which is also made up of blocks: V is made up of four rows, where each row is not really a row but a block of rows, and P is made up of 4x4 elements, where each element is not really a scalar but a matrix. As you remember, in block matrix multiplication the algorithm is the same as for normal matrix multiplication, except that we use blocks. So what I am doing, guys, is the following operation; let's write it somewhere: O is equal to P multiplied by V. Okay, so the first output

row, which is not really a row but a block row, will be computed as follows: the first block row of the P matrix with the first block column of the V matrix, treating both as block matrices. So it will be P11 multiplied by V1, plus P12 multiplied by V2, plus P13 multiplied by V3, plus P14 multiplied by V4. This produces the first output row of O, but it's not really a row, because it's made up of two rows, and we can prove that. P11 is a 2x2 matrix, and we are multiplying it by V1, which is a block of two rows of V with 128 dimensions each, so it is 2x128. The product is therefore 2x128, so the output block we are computing is a block of two rows of the output matrix. I know this is really difficult to follow, because it involves blocks, so we need to visualize the matrices at the same time as blocks and as the original matrices. That's why I highly recommend that you pause the video, think it through and write down whatever you need to write down, because it's not easy to follow just by memorizing the shapes; you actually need to write things down.
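To check the shapes, here is a pure-Python sketch of block matrix multiplication for the first output block row, O1 = P11 V1 + P12 V2 + ..., assuming a toy sequence length of 4, a block size of 2 and a head dimension of 3 instead of 128. The numbers in P and V are arbitrary (P is just a generic matrix here, not a softmax output), and the helper names are hypothetical. The block result has shape 2 x d and matches the first two rows of the ordinary product P V.

```python
def matmul(A, B):
    """Plain matrix product of nested lists: (n x k) @ (k x m)."""
    return [[sum(A[i][t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def block(M, r0, r1, c0, c1):
    """Rows r0:r1 and columns c0:c1 of M (like M[r0:r1, c0:c1] in NumPy)."""
    return [row[c0:c1] for row in M[r0:r1]]

N, d, B = 4, 3, 2  # sequence length, head dimension, rows per block
P = [[float(i * N + j) for j in range(N)] for i in range(N)]  # N x N
V = [[float(i + j) for j in range(d)] for i in range(N)]      # N x d

O_full = matmul(P, V)  # ordinary product, N x d

# First output block row: O1 = P11 V1 + P12 V2 (N / B = 2 key blocks here).
O1 = [[0.0] * d for _ in range(B)]
for j in range(N // B):
    P1j = block(P, 0, B, j * B, (j + 1) * B)  # 2 x 2 block
    Vj = block(V, j * B, (j + 1) * B, 0, d)   # 2 x d block of rows of V
    O1 = add(O1, matmul(P1j, Vj))             # (2 x 2) @ (2 x d) -> 2 x d

print(len(O1), len(O1[0]))  # 2 3
```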

Anyway, we are computing the first output block of the output matrix O. Now, if you remember, this output should be the output of the softmax multiplied by V. But this softmax has not been applied to the entire row of the S matrix: to compute softmax star, we computed it on each block independently of the other blocks, which means that the maximum used to compute each softmax star is not the global maximum for the row of the S matrix but the local maximum of each block. And this is actually wrong, because when we apply the softmax we should be using the maximum of the whole row. I want to give you an example

without using blocks, because otherwise I think it's not easy to follow. When we do normal attention, we multiply the query by the transpose of the keys, and this produces a matrix that is N by N, so sequence by sequence. Let's say it has six columns: 1, 2, 3, 4, 5, 6. This first element should be the dot product of the first query with the first key, which I write as query one transposed times key one. As I said before, this is because when we take the product of two vectors we always treat them as column vectors, and to write the dot product you cannot multiply two column vectors: you need to multiply a row vector with a column vector, which is why we transpose this one. If it confuses you, you can also write Q1 K1; that's totally fine, it's just wrong from a notation point of view. Anyway, the first element will be the dot product of query one with K1, the second element the dot product of query one with K2, the third query one with K3, etc. So this is Q1 with K1, Q1 with K2, Q1 with K3, Q1 with K4.

Anyway, when we do the softmax, we calculate the maximum over this entire row. However, what we are actually doing is a block matrix multiplication and, as you remember, when we work by blocks we group together rows of queries and rows of keys. In this particular case, we group two queries together to create one block of queries, and two keys together to create one block of keys. So we need another row here: query 2 with K1, query 2 with K2, query 2 with K3, query 2 with K4, query 2 with K5 and query 2 with K6. Each of these blocks computes 2 by 2 elements of the original matrix, the one we would have if we had never applied the blocks, so this block is computing these four elements

here. And if we apply softmax star to each of these blocks, we are not using the maximum element of this row; we are only using the maximum element of each block. This means that when we use the result in the downstream product with the V matrix, we will be summing values that are wrong, because each of these values is based on a maximum that is not the global maximum for this row but the local maximum of its block: this block uses the local maximum of this block, that block uses the local maximum of that block, etc. What I'm trying to say is that when you sum P11 V1 with the other terms, P11 may have a local maximum that is different from the local maximum of P12, and P13 may have a local maximum different from those of P11 and P12. So we need to find a way to fix the maximum that was used to compute the exponentials here with the maximum found here, in case the maximum here is higher than the one local to P11. If, for example, we find here a maximum that is higher than the maximum used here and here, then we need to fix this one and this one, because the maximum in the softmax should be the maximum of the whole row, not the one belonging to each block.

block and this leads to our next step

how to fix

this first of all let me introduce a

little pseo code for computing this

output Matrix here which is an output

block Matrix and later we will use this

pelo code to adjust the error that we

have made in some blocks in case the

future blocks so the p13 has a better

maximum than P11 or p12 so to compute

this output Matrix

O, for example its first row, let's go back to what P11 is (let me also delete this one, it's not needed anymore). P11 is softmax star of Q1 K1, P12 is softmax star of Q1 K2, P13 is softmax star of Q1 K3 and P14 is softmax star of Q1 K4. This means that to compute this output block we first need to compute P11. What is P11? It's softmax star of a block of Q and a block of K, which, for the first row of the output matrix, means softmax star of query 1 with K1, then softmax star of query 1 with K2, of query 1 with K3 and of query 1 with K4. So we need a for loop through all the keys while keeping the query fixed. To compute the first output row, we compute softmax star of query 1 K1 to produce P11, multiply it by V1 and add it to the output, which we initialize with zeros, because we need to initialize our output somehow. Then we add the next term, P12 times V2, where P12 comes from query 1 with K2, then P13 times V3, which comes from query 1 with K3, etc. That's why we have this inner loop

here. All right, however, this output that we are computing is wrong because, as I told you, we computed softmax star using a statistic, the maximum value, that belongs to each block and not to the overall row of the original matrix. How do we fix that? We have a tool: we saw before an algorithm called the online softmax (I don't know if I referred to it by that name before, but it's called the online softmax), which allows us to fix the previous iterations while computing the current iteration.
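As a reminder of what the online softmax does, here is a minimal pure-Python sketch for a single vector: one pass keeps a running maximum m and a running normalizer l, rescaling l by exp(m_old - m_new) whenever a larger maximum appears, and a second pass produces the normalized outputs. It is compared against the plain safe softmax; the function names are illustrative.

```python
import math

def online_softmax(x):
    """One pass for (max, normalizer), then one pass for the outputs."""
    m = float("-inf")  # running maximum of x[0..i]
    l = 0.0            # running sum of exp(x_j - m) over x[0..i]
    for xi in x:
        m_new = max(m, xi)
        # exp(m - m_new) fixes everything summed so far with the old maximum;
        # on the first step math.exp(-inf) == 0.0, so l effectively starts from zero.
        l = l * math.exp(m - m_new) + math.exp(xi - m_new)
        m = m_new
    return [math.exp(xi - m) / l for xi in x]

def reference_softmax(x):
    x_max = max(x)
    exps = [math.exp(v - x_max) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

x = [1.0, 3.0, 2.0, 5.0, 4.0]
print(online_softmax(x))
print(reference_softmax(x))
```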

How? Well, let's review the online softmax. Imagine we are working with one single vector made up of N elements. We do a for loop in which we iteratively compute the maximum up to the i-th element, and we fix the normalization factor computed in the previous iterations in case we found a better maximum at the current element. If this is not clear, guys, go back and watch the online softmax part, because it's very important: this is what we are going to use to fix the P11 and P12 blocks in case we find a better maximum in P13 or P14, etc. So let's see how to apply this

online softmax to this case here. You may be wondering why we are going through all this trouble. First of all, why did we introduce block matrix multiplication? Because we want to compute the matrix multiplication in parallel. Each of these blocks, like P11, is independent of the others, and because each of them uses the maximum belonging to its own block, they can be computed independently of each other. However, we then need to aggregate their values somehow, and to aggregate them we need to fix the values that were calculated independently: when computing values independently we don't have a global view, only a local one. So we compute the local blocks P11, P12, P13, etc., and then, when we aggregate these values, we need to fix them. That's why we are trying to come up with this system for fixing values that have been calculated

independently. So how do we fix this? Let's look at the following algorithm. First of all, this O block, as I said before, is a block of two rows, where each row is made up of 128 dimensions; we saw that by checking the dimensions of P11, of V1 and of the result P11 V1. This means that for each output block we need to take care of two maxima and two normalization factors. Up to now I haven't used the normalization factor: we said that we apply softmax star, which is the softmax without the normalization, but eventually we will need to compute this normalization. So we want an algorithm that fixes the maximum used to compute each of these P blocks and simultaneously computes the normalization factor, which we apply at the end. The way we do it is as follows. We start by initializing the maximum to minus infinity, one for each row we are computing: our output block is made up of two rows, so we need one maximum for the top row and one maximum for the bottom row. We also initialize the normalization factor to zero, because we haven't summed anything yet, and the output to all zeros, because we haven't added anything to it yet either.

To compute this output block, we need to go through all the keys to produce P11, P12, P13 and P14, while the query is the query block number one. The first step is to compute the maximum of the first block, which is the row max, that is, the maximum of each row of the block Q1 K1. I said P11, but it's not P11, it's S11, sorry guys; on the slide this block is simply called S1, as you can see here. We compute its row maximum, m1. Then we can calculate P11, which is softmax star: the exponential of Q1 K1, that is S1, minus m1, the maximum of the local block S1. Then we multiply it by V1 and add it to our output. For now the output is initialized with zeros, so ignore this part here, which I will explain later: at this point O1 is simply equal to P11 V1.

Now, at step number two, we may find in the local block S12 (this one here is S12) a better maximum for the top row and for the bottom row. This maximum is m2, which may be better than the previous maximum for each of these two rows, but may also not be. So we need a way to fix the output in case it's better, and to not change anything in case it's not.
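Here is a minimal pure-Python sketch of the correction step, assuming made-up numbers for one 2-row query block: O1 was built with the old row maxima m1, and multiplying it on the left by diag(exp(m1 - m2)) gives the same result as if m2 had been used from the start. In this toy data the top row finds a larger maximum while the bottom row does not, so the bottom row's factor is exp(0) = 1 and nothing changes there. The helpers `pv` and `rescale_rows` are hypothetical names.

```python
import math

def pv(S, m, V):
    """softmax*(S) using the given per-row maxima m, multiplied by V."""
    P = [[math.exp(s - mr) for s in row] for row, mr in zip(S, m)]
    return [[sum(P[i][t] * V[t][j] for t in range(len(V))) for j in range(len(V[0]))]
            for i in range(len(P))]

def rescale_rows(O, m_old, m_new):
    """diag(exp(m_old - m_new)) @ O: row r is scaled by exp(m_old[r] - m_new[r]);
    the off-diagonal zeros mean rows never mix."""
    return [[o * math.exp(a - b) for o in row] for row, a, b in zip(O, m_old, m_new)]

S1 = [[3.0, 1.0], [1.0, 2.0]]  # query block 1 against key block 1
V1 = [[1.0, 2.0], [3.0, 4.0]]  # value block 1 (head dim 2 for brevity)
m1 = [max(row) for row in S1]  # [3.0, 2.0]: maxima after step 1
m2 = [4.0, 2.0]                # after step 2: top row found a larger max, bottom did not

O1 = pv(S1, m1, V1)               # what step 1 produced
fixed = rescale_rows(O1, m1, m2)  # corrected after the fact
direct = pv(S1, m2, V1)           # as if m2 had been known from the start
print(fixed)
print(direct)
```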

And the way we do it is this. We compute the new maximum m2 using the current local block S2, and we calculate P12, which is softmax star of S2, that is, the exponential of S2 minus m2, the current maximum. Then we need to add it to the output. However, in this case we may have found a better maximum, so how do we fix O1, which only used the maximum that was local to S1? We know that we can fix that using exponentials, because each element of O1 is built from exponentials without the normalization, since we are applying softmax star. So how do we fix an exponential with another exponential? Basically, we multiply O1, which is a matrix. Let me show you what this matrix is: O1 is made up of two rows; as you can see here, the shape of O1 is 2 by 128. The top row is o_11, o_12, and so on up to o_1,128; the bottom row is o_21, o_22, and so on up to o_2,128. We need to fix these values. How?

Basically, by using the same exponential that we used in the online softmax we saw before. We multiply this matrix by a diagonal matrix made as follows. It's a diagonal matrix with two elements on the diagonal, because the exponential of m1 minus m2 is a vector of two elements: m1 and m2 each hold one maximum per row, and the exponential is applied element-wise, so the result is another vector of two elements. The diag notation means a diagonal matrix whose diagonal holds the elements of the vector to which we apply the diag operation. So the first diagonal entry is the exponential of m1 minus m2, taking the first element of that vector; the off-diagonal entries are zero; and the second diagonal entry is the exponential of m1 minus m2, taking the second element of that vector. In other words, diag just takes the vector and distributes it along the diagonal of an n by n matrix, where n is the size of the vector, and all the other elements of the matrix are zeros; that's what this notation means. If we do this operation here, we

will see that the output of this multiplication fixes each element of the top row with this exponential and each element of the bottom row with that exponential, which basically cancels out the m1 computed in the previous iteration and introduces the m2 computed in the current iteration, in each of the elements of this O block matrix. Okay, so in the output, this element will be multiplied by this one, so o_11 is fixed by this factor here, while o_21 is multiplied by zero and does not contribute to this first output element. So this output element only depends on o_11, fixed by the exponential of m1 minus m2, first element of that vector. Then o_12 is also fixed by this exponential here, but not by that one, and so on: all the dimensions of the first row are fixed by this exponential, and all the dimensions of the second row are fixed by this scalar here, which is the second element of the vector exp(m1 - m2). Okay, this one was really challenging. So what we are doing is: we

compute P12, and we fix all the elements of O1 by multiplying them by this diagonal factor matrix here; when we compute step three, we will fix step two, etc. Now let's talk about the normalization factor, because so far we have been ignoring it. The normalization factor is something we can compute while computing these maxima, because it's provided in the pseudocode of the online algorithm for the softmax that we saw before: while computing the maximum, we can also compute the normalization factor by fixing the normalization factor of the previous iteration, and this is exactly what we are doing here.
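A pure-Python sketch of this running normalization factor for one 2-row query block split into two key blocks, assuming made-up scores: l starts at zero, and at each step l_new = exp(m_old - m_new) * l_old + rowsum(P), where P = exp(S - m_new). At the end it matches the normalizer computed on the full rows with the true maximum.

```python
import math

# One query block (2 rows) of S, split into two key blocks of 2 columns each.
S_blocks = [
    [[3.0, 1.0], [1.0, 2.0]],  # S_11
    [[4.0, 0.0], [0.5, 1.5]],  # S_12
]

m = [float("-inf")] * 2  # running row maxima
l = [0.0] * 2            # running normalization factors (l_0 = 0)
for S in S_blocks:
    m_new = [max(mr, max(row)) for mr, row in zip(m, S)]
    P = [[math.exp(s - mn) for s in row] for row, mn in zip(S, m_new)]
    # l_new = exp(m_old - m_new) * l_old + rowsum(P); exp(-inf) == 0.0 on step one
    l = [math.exp(mo - mn) * lo + sum(prow) for mo, mn, lo, prow in zip(m, m_new, l, P)]
    m = m_new

# Reference: normalizer of the full rows, computed with the true global maximum.
full_rows = [S_blocks[0][r] + S_blocks[1][r] for r in range(2)]
l_ref = [sum(math.exp(s - max(row)) for s in row) for row in full_rows]
print(l, l_ref)
```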

So at the first iteration we compute the normalization factor using the local maximum. For now you can ignore this first term, because we are not really fixing l0 with anything: l0 is zero, so this term is zero and we are just computing this summation here. When computing l2, the normalization factor at the second iteration, we fix l1 with an exponential which, guess what, is exactly the same exponential that fixes the output built from P11: the previous estimate of the maximum minus the current estimate of the maximum. To that we add the new normalization term computed with the current maximum, and we keep doing this. At the end we obtain a correct output for this block, but without the normalization. How do we apply the

normalization? We need to divide each element of O by the normalization factor. Because we also calculate the normalization factor while iterating through these for loops, we keep accumulating it until we reach the end of the iteration, and then we apply it: we take the last output and divide it by l4, the normalization factor calculated at the fourth iteration, and that fixes the softmax. All right guys, so now

that we have derived the algorithm for computing the attention output block-wise while also fixing the softmax, which is computed independently in each block, we know that the normalization is done at the end, and I also want to prove it. When we introduced the algorithm that computes the softmax in an online way, we proved by induction that it is correct, so at the end of this algorithm the l of the last iteration is actually the normalization factor that we can apply to get the softmax. So we don't apply the normalization while computing this output in an online, iterative way, multiplying the query with all the blocks of keys; we apply it at the end of the fourth iteration, when we have the last output. And we also know that the last l contains the exact normalization factor we need to apply to each row, because O4 is a block of output rows. If you remember from the attention mechanism, the output of attention has the same shape as the input query, which is a sequence of tokens, so this O is a sequence of tokens to which we need to apply the normalization, and we know the correct factor is l4. So let's prove

this simple formula. l4 is a vector that contains as many elements as there are rows in O4, this block of rows. Suppose that it contains two rows, as in the algorithm I have described so far, in which we pretend to group two rows of queries and two columns of keys together: then the block O contains two rows of the output, so we have two normalization factors in this vector l4. What we are doing with this formula is taking the vector l4, creating a diagonal matrix with it, and then computing the inverse of this diagonal matrix. So l4 is a vector that contains two normalization factors; let's call them l4 element 1 and l4 element 2. This is our l4 vector. Then we have O4 which, as you can see from the shape, is a 2 by 128 matrix: two rows, the first with 128 dimensions and the second with 128 dimensions. The first thing we do with l4 is convert it into a diagonal matrix, which is 2x2 because l4 contains two elements, so it becomes something like this: the first element of l4, zero, and then zero and the second element of l4. Then we compute the inverse

of this matrix. The inverse of a diagonal matrix is just the diagonal matrix in which each element on the diagonal becomes its reciprocal; this is from linear algebra, I'm not making it up. So the inverse of this matrix is the same diagonal matrix, but where each element is inverted: 1 over the first element of l4, then 0, 0, and 1 over the second element of l4. Then we multiply this by O, which is a 2 by 128 matrix (let me delete some stuff), so 2x2 times 2x128 gives a matrix that is 2 by 128,

where the first element of the first row of the output of this operation is the dot product of this row with the first column: basically we are dividing this element by the first element of l4. The second output element is the dot product of this row with the second column, so we are dividing the second element of this input row by the first element of l4, because all the elements of the second row are multiplied by zero and do not contribute to this output row. In the second output row, this element is the dot product of this row with the first column: the first element is multiplied by zero, so it does not contribute to this output, and only the first element of the second row of the input matrix is divided by l4 element 2. So this factor divides all the elements of the second row and this one divides all the elements of the first row, producing this output, which is exactly what we need to do when we want to normalize: apply the normalization factor to each row. This should help you better visualize why this operation normalizes the vectors of the output at the end and still gives the same result. Now let's proceed further. All right guys, finally we are ready to see

the Flash Attention forward pass, also comparing it with what we have derived so far. If you look at the Flash Attention paper, this is the Flash Attention 2 forward pass; later I will explain the differences between Flash Attention 1 and Flash Attention 2. I didn't want to jump directly to this forward pass because, even if the derivation was a little difficult to follow, I believe it gave you some intuition about what is happening. Even if you understood 50% of it, that's enough, because later we will also code it and you should reach something like 90% understanding: every time we introduce some new information, it should improve your understanding.

Basically, in Flash Attention, and in Flash Attention 2 especially, we take as input our queries, keys and values, which are sequences of tokens, each token being a vector of lowercase d dimensions. We divide the query, guess what, into blocks. How many blocks? That depends on the parameter B_r, which is the size of the query block we choose, that is, how many rows of the query we want to group together into one block. We do the same with K and V, dividing them into blocks according to the parameter B_c. Then we initialize the output, which is what we want to produce. So what is Flash Attention computing? It computes the softmax of the query multiplied by the transpose of the keys, divided by a scaling factor (the square root of the head dimension), and multiplies that by V. And it computes it this way. First of all, there is

an outer loop through the queries, which corresponds to the same pseudocode we saw before, because we want to compute each block of the output matrix in parallel with respect to the others: we want to compute this output block and that output block independently. This output block here depends on query 1 and all the keys, this one depends on query 2 and all the keys, and this one depends on query 3 and all the keys, where query 1 is not the first query but the first group, or block, of queries, and query 2 is not the second row of the query matrix but the second block of the query matrix, etc. That's why we have this outer iteration over all the blocks: we want to compute all those blocks of the output matrix in parallel. But to compute each of these output blocks we need to iterate over all the keys, and that's why we have an inner loop on the keys.
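Putting the pieces together, here is a minimal pure-Python sketch of the block-wise forward pass: an outer loop over query blocks of size Br, an inner loop over key/value blocks of size Bc, online softmax with the diag(exp(m_old - m_new)) correction, and a final division by l. It only illustrates the arithmetic (no SRAM/HBM, no parallelism), the function names are hypothetical, and the 1/sqrt(d) scaling of the scores is an assumption taken from standard scaled dot-product attention. It is checked against naive attention, including a sequence length that is not a multiple of the block sizes.

```python
import math

def naive_attention(Q, K, V, scale):
    """softmax(scale * Q K^T) V, one query row at a time."""
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        l = sum(p)
        out.append([sum(pi * v[j] for pi, v in zip(p, V)) / l for j in range(len(V[0]))])
    return out

def flash_attention_forward(Q, K, V, Br=2, Bc=2, scale=None):
    """Block-wise attention forward pass with online softmax (sketch)."""
    N, d = len(Q), len(Q[0])
    if scale is None:
        scale = 1.0 / math.sqrt(d)
    O = []
    for i0 in range(0, N, Br):                   # outer loop: query blocks Q_i
        Qi = Q[i0:i0 + Br]
        m = [float("-inf")] * len(Qi)            # running row maxima
        l = [0.0] * len(Qi)                      # running normalization factors
        Oi = [[0.0] * len(V[0]) for _ in Qi]     # unnormalized output block
        for j0 in range(0, len(K), Bc):          # inner loop: key/value blocks
            Kj, Vj = K[j0:j0 + Bc], V[j0:j0 + Bc]
            S = [[scale * sum(a * b for a, b in zip(q, k)) for k in Kj] for q in Qi]
            m_new = [max(mr, max(row)) for mr, row in zip(m, S)]
            P = [[math.exp(s - mn) for s in row] for row, mn in zip(S, m_new)]
            corr = [math.exp(mo - mn) for mo, mn in zip(m, m_new)]  # exp(m_old - m_new)
            l = [c * lo + sum(prow) for c, lo, prow in zip(corr, l, P)]
            # O_i <- diag(corr) O_i + P V_j
            Oi = [[c * o + sum(p * v[j] for p, v in zip(prow, Vj)) for j, o in enumerate(orow)]
                  for c, orow, prow in zip(corr, Oi, P)]
            m = m_new
        # diag(l)^-1 O_i: divide each row by its own normalization factor
        O.extend([[o / lr for o in orow] for orow, lr in zip(Oi, l)])
    return O

# Usage: 5 tokens, head dimension 4, value dimension 3 (arbitrary deterministic data).
Q = [[math.sin(i + 2 * j) for j in range(4)] for i in range(5)]
K = [[math.cos(i - j) for j in range(4)] for i in range(5)]
V = [[float(i * j) / 4.0 for j in range(3)] for i in range(5)]
out = flash_attention_forward(Q, K, V, Br=2, Bc=2)
ref = naive_attention(Q, K, V, 1.0 / math.sqrt(4))
print(out[0], ref[0])
```

Because every query block only reads its own rows of Q and writes its own rows of O, the outer loop is the part that can be spread across parallel workers; the inner loop has to stay sequential because each step corrects the previous one.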

And we do exactly the same operations that we did so far by hand. First we compute the S matrix, that is, each block of queries with the corresponding block of keys. Then we compute the local maximum of the current S block and compare it with the maximum of the previous iteration, because that's what we do in the online softmax. Then we compute the P block, which is softmax star of the S block: the exponential of the S block minus the current maximum. Then we compute the normalization factor. What is it? It's the row sum of these softmax star values, plus the normalization factor of the previous step after fixing it, and we know how to fix it: we multiply it by the exponential of the previous maximum minus the current maximum; that's what this factor is. Then we compute the output using exactly the same correction factor we saw before: the diagonal matrix whose diagonal holds the elements of this vector, the exponential of the previous maximum minus the current maximum, multiplied by the output of the previous step (because we want to fix the previous step, which was based on the previous P, computed with the previous maximum), plus the current P times V, which is based on the current maximum and will be fixed by the next iteration. Okay, and at the end, after

we have gone through all the cas so we

have computed all the output block but

we didn't apply the normalization factor

and it's applied at the end because

while going through each key we are

calculating the L normalization factor

for the softmax because inside of this

for Loop we are just Computing the

softmax star so we are not normalizing

each value so at the end someone has to

normalize it and it will be this um this

this um instruction here which is use

the normalization factor that we have

computed over all the iterations and

apply it to each element of O because

the difference between the softx star

and the actual sofx is just the division

by the um the normalization factor and
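Putting the steps together, here is a minimal pure-Python sketch of the blocked forward pass as described: an outer loop over query blocks, an inner loop over key/value blocks, the exponential correction of the running output, and the final division by l. It is a teaching sketch on nested lists, not the Flash Attention kernel. The helper names are made up, and I leave out the 1/sqrt(d) scaling to match the simplified derivation. It compares the result against a naive softmax(Q K^T) V.

```python
import math
import random

def matmul(x, y):
    """(n x k) @ (k x m) on plain nested lists."""
    return [[sum(x[i][t] * y[t][j] for t in range(len(y))) for j in range(len(y[0]))]
            for i in range(len(x))]

def transpose(x):
    return [list(col) for col in zip(*x)]

def blocked_attention_forward(Q, K, V, block_q, block_kv):
    """softmax(Q K^T) V, one query block at a time, with the online-softmax fix-ups."""
    n_q, n_kv, d_v = len(Q), len(K), len(V[0])
    O = []
    for qs in range(0, n_q, block_q):                  # outer loop: query blocks
        Q_i = Q[qs:qs + block_q]
        rows = len(Q_i)
        m = [-math.inf] * rows                         # running row maxima
        l = [0.0] * rows                               # running normalization factors
        O_i = [[0.0] * d_v for _ in range(rows)]       # unnormalized output block
        for ks in range(0, n_kv, block_kv):            # inner loop: key/value blocks
            K_j, V_j = K[ks:ks + block_kv], V[ks:ks + block_kv]
            S_ij = matmul(Q_i, transpose(K_j))
            for r in range(rows):
                m_new = max(m[r], max(S_ij[r]))
                P_r = [math.exp(s - m_new) for s in S_ij[r]]      # softmax* row
                corr = math.exp(m[r] - m_new)                     # diagonal entry
                l[r] = corr * l[r] + sum(P_r)
                PV_r = [sum(P_r[t] * V_j[t][c] for t in range(len(V_j))) for c in range(d_v)]
                O_i[r] = [corr * o + pv for o, pv in zip(O_i[r], PV_r)]
                m[r] = m_new
        # Final step: divide each row by its normalization factor.
        O.extend([o / l[r] for o in O_i[r]] for r in range(rows))
    return O

def naive_attention(Q, K, V):
    """Reference: full S, row-wise safe softmax, then P V."""
    P = []
    for row in matmul(Q, transpose(K)):
        mx = max(row)
        e = [math.exp(s - mx) for s in row]
        z = sum(e)
        P.append([x / z for x in e])
    return matmul(P, V)

rng = random.Random(0)
Q = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(6)]
K = [[rng.uniform(-1, 1) for _ in range(4)] for _ in range(5)]
V = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(5)]
out = blocked_attention_forward(Q, K, V, block_q=2, block_kv=2)
ref = naive_attention(Q, K, V)
err = max(abs(x - y) for ro, rr in zip(out, ref) for x, y in zip(ro, rr))
print(err < 1e-9)  # True
```

Five keys with blocks of two leave a last block of one key, so the sketch also covers uneven splits.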

This instruction divides each row of O by its corresponding normalization factor: one factor for each row of the output block that we are computing.

Later we will also see what the SRAM and the HBM are. For now I just want you to concentrate on the operations we are doing, and they are exactly the operations we have done so far. Later we will also see why we need to save this stuff here, and so on. But for now you should have enough knowledge to follow the forward pass algorithm as written in the Flash Attention paper. Basically, what we are doing is block matrix multiplication, and while computing each block we fix the previous one using tricks with the exponential.

All right, now we have seen the forward pass of Flash Attention. Before we can implement it we still lack a bit of knowledge, because we don't know anything about GPUs, CUDA, or Triton. That's what we are going to see next.

All right guys, it's finally time for us to explore the GPU and the CUDA programming model. Let's start by comparing the CPU and the GPU; this will help us understand how CUDA works.

So, first of all, what is CUDA and what is the GPU? The GPU is the hardware unit that we buy, and CUDA is a software stack made by NVIDIA for writing software for the GPUs they sell. AMD has its own software stack, and other manufacturers have theirs. In this video we will be seeing examples of CUDA kernels, but the knowledge you get applies to other GPUs as well.

The first difference between a CPU and a GPU is their purpose. Right now your computer is running on a CPU, and your operating system interfaces with the CPU through the so-called scheduler. You are probably running a browser and some other software, and the scheduler switches between them very fast on your CPU, so that it looks to you as if the processes are running concurrently. This is actually a fake kind of parallelism, unless your CPU has multiple cores, which CPUs nowadays do. So a CPU usually has one or more cores, but not many of them: usually a dual-core, quad-core, or eight-core CPU, and each of these cores can execute instructions in parallel. The main purpose of the CPU is to execute many different tasks and switch between them very fast. Maybe you have a browser running a small game, a movie player, a word processor, and maybe some utility to download files, and so on.

Most of these programs are actually not compute-intensive; they are I/O-bound, meaning that most of the time they are waiting either for the network or for the disk. They are also very different from each other in purpose: a browser is completely different from a movie player, and both are completely different from a word processor. So the job of the CPU is to reduce the latency of processing all these operations, and each of its execution units, called cores, is highly optimized for that. Each core has a part tasked with figuring out which instruction to run next, or predicting which branch the program will take based on the conditions it evaluates. For example, if you have an if condition, the branch predictor can guess the most likely next instruction and do some optimizations. The CPU also has a lot of caches to reduce the latency of loading data from all the devices it interfaces with: the RAM for sure, but also the disk and peripherals like the printer, the mouse, the keyboard, and so on.

The GPU, on the other hand, is not tasked with doing many different things at the same time. Its job is to do one or a few things on a massive amount of data. The operations we run on the GPU require a lot of computation, and for this reason most of the physical area of the GPU is dedicated to compute units: this green stuff you can see here, which are called cores. You can also see that the control part is very small: the part tasked with figuring out the next instruction to run or optimizing the program. You may be thinking: does that make the GPU slower than the CPU? Not really, because we have many more cores, and they compensate for these higher latencies.

Okay, I could give you a lot of theoretical knowledge about the GPU, but I think the best way to understand the CUDA programming model is to jump into the code, so we don't get bored.

Imagine we have a very simple task. We have two vectors, A and B, and we want to calculate their sum and save the result into another vector C, where each item is the element-wise sum of the corresponding items of A and B. How would you proceed with this task on the CPU?

Well, you would write a for loop that starts from the first index, index 0: C[0] is equal to A[0] plus B[0], then C[1] is equal to A[1] plus B[1], and so on, looping over all the elements of the vector.

On the GPU we want to do the same operation, but in parallel, because we have a lot of compute units, called cores, and we want all of them to work in parallel. So the first thing we need to understand is how to divide the work into subunits and assign each core to one subunit. One simple subdivision would be: the first core does this summation, the second core does this summation, the third core does this one, and so on. So if we have an eight-element vector, we need eight cores to do this element-wise summation. We will call the cores threads, because that should remind you of the multithreading we already use in operating systems: multiple threads working concurrently on the same or a similar job, here on the GPU. Let's look at the code now.
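Before the CUDA version, here is a plain-Python sketch (not the kernel from the video) of the idea just described. It contrasts the CPU-style loop with the same work split into one independent unit per index, which is the job a single GPU thread would do. The function names are made up for illustration. The shuffled order only mimics the fact that parallel threads come with no ordering guarantee.

```python
import random

def add_sequential(a, b):
    """CPU style: one loop visits every index in order."""
    c = [0.0] * len(a)
    for i in range(len(a)):
        c[i] = a[i] + b[i]
    return c

def add_one_element(i, a, b, c):
    """One unit of work: what a single thread would do for index i."""
    c[i] = a[i] + b[i]

def add_as_independent_units(a, b, seed=0):
    """Run one unit per index in an arbitrary order; the units don't depend on each other."""
    c = [0.0] * len(a)
    order = list(range(len(a)))
    random.Random(seed).shuffle(order)
    for i in order:
        add_one_element(i, a, b, c)
    return c

a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
b = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]
print(add_sequential(a, b))  # [11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0]
print(add_as_independent_units(a, b) == add_sequential(a, b))  # True
```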

The code I am going to show you is a CUDA kernel, and it's written in C. You don't have to understand C, and you don't have to understand this code in detail. What I want you to understand is the intuition behind it, because later we will take this knowledge and convert it into Triton, which is Python, and you should already be familiar with Python.

So let's go to the code. I have a very simple vector addition; we can see it here. First of all, how do we do a vector summation? Usually the GPU is interfaced with a CPU, and the CPU first has to tell the GPU what data it is going to work with. So the CPU needs to have these vectors and transfer them to the GPU. Then the GPU does the vector summation, and then the CPU copies the output back from the GPU and makes it available to the program. This is what we are going to do here.

We allocate three vectors of size N: one called A, one called B, and the output vector. We initialize their items randomly, so A[i] is a random number between 0 and 100, with 100 excluded. Then we allocate memory on the GPU to hold these vectors and copy A and B to the GPU. We don't copy the output vector, because that's what we want the GPU to populate; we just allocate it on the GPU, since on the CPU side it only contains random values.

Then we launch the kernel. Launching the kernel means launching a program that the GPU executes in parallel on multiple threads, or multiple cores, where each thread does a unit of work that is independent from the others. Threads can actually depend on each other, but we will not be talking about synchronization. What this line says is: launch one block of threads made of N threads, that is, N parallel operations. We will see later what blocks are, so you can ignore that part for now. The kernel gets the following arguments: the output where we want to save the data, the input arrays A and B, and the number of elements. Let's see what happens inside this method.

This method follows a CUDA-specific syntax: this `__global__` keyword is an addition. CUDA C is like a superset of the C language, with some additional keywords that belong to CUDA, so it's not really C, it's CUDA C. As you can see, it's a very simple method.

The first thing to understand is that CUDA cannot know what each thread should do; we have to tell each thread what to do. The mapping between the data and the work of each thread is up to us as software engineers. When we ask CUDA to launch N threads in parallel, it allocates N threads and assigns a unique identifier to each of them. In our simple case we can picture it like this. Imagine we have a vector of eight elements: CUDA assigns index 0 to the first thread. Here I wrote one, which is wrong, so let me write another number. This is actually thread 0, then thread 1, thread 2, thread 3, thread 4, thread 5, thread 6 and thread 7. Let me delete this one so we don't get confused.

What we are saying is that the item each thread should process is equal to its thread index. Thread 0 processes the item with index 0, thread 1 processes the item with index 1, and thread 2 processes the item with index 2. This is what this line of code does: the item each thread processes is exactly its thread identifier, the thread ID. We will see later why we have this `.x`. Next, the output at position i is equal to A at position i plus B at position i, a very simple element-wise summation.

You may have noticed this if statement. Why do we need it? We are launching N threads, so each thread ID will be between 0 and N minus 1, and i should always be less than N. The condition is needed because in general we launch more threads than there are elements. The GPU schedules threads in groups of 32, called warps, and we usually round the thread count up, for example to a whole number of blocks, as we will see in the next example. So with 34 elements we might end up launching 64 threads, a multiple of 32, which is the warp size. We only want the threads that have a corresponding element to do any work. All the others, for which the vector is not large enough, must not enter this if statement.

There is another thing we should learn. In the CUDA programming model (and I believe on other GPUs as well), a group of 32 threads is called a warp, and these 32 threads share the same control unit. Let's go back to the slide. As you can see here, we have this yellow unit in the GPU, and a group of threads shares the same control unit. What is this control unit? It's the part of the GPU hardware tasked with figuring out the next instruction to run.
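To see why the `if` guard matters, here is a plain-Python stand-in for the single-block vector-add kernel. The names `vector_add_kernel` and `launch_one_block` are hypothetical, and the "launch" simply calls the kernel body once per thread ID, one after another, not in parallel. It deliberately rounds 34 threads up to 64, a whole number of warps. Without the guard, threads 34 to 63 would index past the end of the arrays: an `IndexError` in Python, an out-of-bounds access in CUDA.

```python
def vector_add_kernel(thread_idx, out, a, b, n):
    """Body of the kernel as seen by one thread: its index is its element."""
    i = thread_idx
    if i < n:  # threads without a matching element do nothing
        out[i] = a[i] + b[i]

def launch_one_block(kernel, num_threads, *args):
    """Sequential stand-in for launching one block: call the body once per thread ID."""
    for thread_idx in range(num_threads):
        kernel(thread_idx, *args)

WARP_SIZE = 32
n = 34
a = [float(i) for i in range(n)]
b = [100.0] * n
out = [0.0] * n
# Round the thread count up to a whole number of warps: 34 -> 64.
num_threads = ((n + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE
launch_one_block(vector_add_kernel, num_threads, out, a, b, n)
print(num_threads)       # 64
print(out[:3], out[-1])  # [100.0, 101.0, 102.0] 133.0
```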

If a group of threads shares the same control unit, the whole group always executes the same statement at any given time: the threads work in lockstep on the same instruction. One thread cannot be working on one instruction while another works on a different one. What does this mean at the programming level? Remember that we may launch more threads than we need when the number of elements of our vector is not a multiple of 32. These threads first execute this operation, and each of them has its own value of the thread ID. So they execute the same instruction, but the data may differ, because each thread has its own registers. For example, when they reach this statement, the first thread has i equal to 0, the second has i equal to 1, and so on, even though they are executing the same instruction. This programming model is called single instruction, multiple data (SIMD); NVIDIA likes to call it single instruction, multiple threads (SIMT). The name doesn't matter for us: it just means that the threads always execute the same instruction, but the values of their variables may be different.

After executing this statement, they reach the if statement. Some of them evaluate the condition to true and some to false, which means that some of them should enter this if statement and some of them should not.

However, because the control unit is the same for all of them, they are all forced to go through this if statement, even the ones that should not. So how does CUDA manage this control divergence? All the threads for which the condition is true enter the if and execute the instructions inside it. The threads for which the condition is false also go through the if block, because they cannot skip it: they must always be executing the same instruction as the others. But they don't do any operation inside it; they just sit idle. This is called control divergence, and it can reduce the throughput of your program, so you want to minimize it.

You may be wondering why the GPU doesn't dedicate a control unit to each core, so that the cores can work independently from each other. The reason is that a control unit is expensive in terms of chip area, and it's much more efficient to add more workers than to add a control unit for each worker. This is a design choice of the GPU, and it works fine.

Okay, now that we have seen how a kernel works, let's move on to another example. The next example is the same as before: we are going to do a vector addition, but imagine that we have a very large vector.

Say it has 1 million elements. Of course we could do as before and launch a kernel with a single block of 1 million threads, but CUDA will reject it, because a single block can contain at most 1,024 threads. So how can we proceed in this case? Usually we work with very big matrices or very big vectors, so we need to process a massive amount of data. How do we manage a parallel computation when we don't have enough compute cores?

One way is to divide the input vector into blocks of elements. For example, imagine our GPU has only 32 cores in total. We may divide our input vector into blocks of size 32: the first 32 elements form the first block, the next 32 the second block, the next 32 the third block, and so on, until the last 32 elements form the last block. This way we can ask the GPU to work on one block at a time: work on the first block, and once it is processed, work on the second, then the third, then the fourth. Blocks also let the GPU manage the subunits of work itself. Imagine we have blocks of 32 elements but a GPU with 64 cores: the GPU can then schedule two blocks at the same time, because it has enough cores. So we split our data into finer-grained units of work and let the GPU decide how many blocks to schedule at once. This is the reason CUDA has blocks.

Let me make a concrete example with a very simple assumption. Imagine our GPU has only four cores, and we have N equal to eight elements. We can divide this vector into groups of four elements or even fewer; let's say two elements at a time. So this is block number 1, this is block number 2, then block number 3 and block number 4. We can ask CUDA to launch a kernel made up of four blocks, where each block is made up of two threads. Let's look at the code for the launch now.

In this launch instruction, the first value inside the triple angle brackets tells CUDA how many blocks we have, and the second value tells it how many threads we have in each block. In our case we want N divided by the block size blocks, where the block size in my picture is two. So the number of blocks is N / 2, with 2 being the block size, which gives four blocks of size two each. This is what we do here: the number of blocks is the ceiling of N divided by the block size, because N may not be a multiple of the block size. This defines our grid. The grid tells how many blocks we have, and each block is made up of block-size threads.

Now the problem is how to assign work to each of these threads. When we launch a kernel with this configuration, a number of blocks and a number of threads per block, CUDA does the following. It assigns each block an index called the block ID, and the first block has block ID 0. Let me write it here: the first block has block ID 0. Within each block, CUDA assigns a thread ID: the first thread of each block is thread 0 and the second is thread 1. The second block has block ID 1; its first thread is thread 0 and its second thread is thread 1. The third block has block ID 2, again with threads 0 and 1, and so on until the last block, with block ID 3, which also has thread 0 and thread 1.

Now, based only on the index of the block and the index of the thread, how can we map each thread to the element of the vector it should work with? You can see that in this case we need this thread here to work with element 0, this one with element 1, this one with element 2, this one with element 3, and then elements 4, 5, 6 and 7. (This five is so ugly, let me write it again.) So how can we find the mapping given only the block ID and the thread ID? It's a very simple formula: the element index, which in the code I call i, is equal to the block ID multiplied by the block size, plus the thread ID.

For the first thread this is 0 × 2 + 0 = 0. For the next one it is 0 × 2 + 1 = 1. For the first thread of the second block the block ID is 1, so 1 × 2 + 0 = 2, and so on. You can see that this formula works for all the threads.

So when we launch a CUDA kernel, we tell the GPU how many blocks we want and how many threads there are in each block. But CUDA has no way of knowing how to map each thread to the element it should work with; that's up to us. That's what we are doing in this kernel. Each thread works with the i-th element of the vector, where i is the ID of the block the thread belongs to, multiplied by the block size (the number of threads in each block), plus the thread ID. This tells us which element this particular thread should work with.
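The block-ID/thread-ID mapping above can be checked with a small plain-Python stand-in. The launch helper just loops over every (block, thread) pair in turn; a real GPU may run several blocks at once. The names are hypothetical. The first part reproduces the table from the slides (N = 8, blocks of 2 threads). The second part uses a length that is not a multiple of the block size, so the ceiling and the `if` guard both come into play.

```python
import math

def blocked_add_kernel(block_id, thread_id, block_size, out, a, b, n):
    """Element handled by this thread: block ID * block size + thread ID."""
    i = block_id * block_size + thread_id
    if i < n:  # needed when n is not a multiple of block_size
        out[i] = a[i] + b[i]

def launch_1d(kernel, num_blocks, block_size, *args):
    """Sequential stand-in for a 1D grid launch: every (block, thread) pair runs once."""
    for block_id in range(num_blocks):
        for thread_id in range(block_size):
            kernel(block_id, thread_id, block_size, *args)

# The example from the slides: N = 8, blocks of 2 threads -> 4 blocks.
block_size = 2
table = [(blk, t, blk * block_size + t) for blk in range(4) for t in range(block_size)]
print(table)  # [(0, 0, 0), (0, 1, 1), (1, 0, 2), (1, 1, 3), (2, 0, 4), (2, 1, 5), (3, 0, 6), (3, 1, 7)]

# A length that is not a multiple of the block size: 9 elements -> 5 blocks, 10 threads.
n = 9
a = list(range(n))
b = [10] * n
out = [0] * n
num_blocks = math.ceil(n / block_size)
launch_1d(blocked_add_kernel, num_blocks, block_size, out, a, b, n)
print(num_blocks, out)  # 5 [10, 11, 12, 13, 14, 15, 16, 17, 18]
```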

Let's go back to the slides. With a block size of two and four cores, the GPU can choose to run one block or two blocks concurrently, if it has enough free cores. That's why we want to work block by block: it lets the GPU choose how to parallelize the operations depending on how many cores it has. We don't need N cores for an N-element vector; we can divide it into smaller blocks and let the GPU manage the scheduling.

Let's see one last example, and then we move on to Triton. Imagine now that we want to do a matrix addition instead of a vector addition. In a matrix addition our data lies along two axes: the rows and the columns. Usually we represent the vertical axis as the y-axis and the horizontal axis as the x-axis. We can divide the work of our matrix addition into blocks using the same intuition as before, splitting the input data into blocks. For example, we can divide our rows into blocks: this one is block 0, this one block 1 and this one block 2. We can do the same on the x-axis: this one is block 0, this one block 1 and this one block 2, where x is the column axis and y is the row axis. We don't even have to choose the same block size for the rows and the columns: we could group together three columns and two rows instead of two and two.

As we said before, when we launch a CUDA kernel, CUDA just assigns IDs to the blocks and to the threads in each block. It's up to us to work out how to map the block ID and the thread ID to the data element that this particular thread should work with. In the case of matrix addition, we could say that each thread works with one element of the output matrix C: it computes the sum of the corresponding elements of A and B and writes it into C. So how do we do it? Imagine we have six rows and six columns. One easy way would be to divide the rows into three blocks of two rows each, and the columns into three blocks of two columns each.

CUDA will launch as many blocks as there are combinations of row blocks and column blocks. In this case we have three blocks for the columns and three for the rows, so it will launch nine blocks. CUDA identifies each block by its coordinates along the axes in which we have divided the data. We call the columns the X dimension and the rows the Y dimension, so it launches as many blocks as there are combinations of x and y, nine in this case. This will be block (0, 0), this one block (0, 1), this one block (0, 2), then (1, 0), (1, 1), (1, 2), and so on. Inside each block we also divide the threads into x threads and y threads along the two dimensions. These are thread 0 and thread 1 along the x-axis of the block, and these are thread 0 and thread 1 along the y-axis. So each block has two threads along each axis, identified as thread 0 and thread 1.

So let's look at how the launch grid works in this case. Imagine we have a matrix with a given number of rows and number of columns. We want to divide the rows into blocks of a rows block size and the columns into blocks of a columns block size. The number of blocks we need is defined like this: this is just a fancy way of writing the ceiling of the number of rows divided by the rows block size, and this is just a fancy way of writing the ceiling of the number of columns divided by the columns block size. Together they tell us how many blocks we will have along the rows and how many along the columns. The grid tells how many blocks we have, and it is a tuple of three values: how many blocks we want along the X dimension, along the Y dimension and along the Z dimension. We are not going to use the Z dimension, because we only have a matrix. Then, inside each block, we say how many threads we want along the X dimension and along the Y dimension. Since we have chosen the columns as the X dimension, we give the number of column blocks and the number of row blocks. Then, inside each block, we give the number of threads for the column block and for the row block. This defines our launch grid.
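The "fancy way of writing the ceiling" is usually integer ceiling division. I am assuming the common idiom `(n + block_size - 1) / block_size` in integer arithmetic; the exact expression in the video's code may differ. Here is a small Python sketch with a hypothetical `grid_for_matrix` helper that returns the grid as (x = column blocks, y = row blocks, z = 1).

```python
import math

def ceil_div(n, block_size):
    """Ceiling of n / block_size using only integer arithmetic (n >= 0, block_size > 0)."""
    return (n + block_size - 1) // block_size

def grid_for_matrix(num_rows, num_cols, rows_block_size, cols_block_size):
    """Blocks along (x = columns, y = rows, z unused)."""
    return (ceil_div(num_cols, cols_block_size), ceil_div(num_rows, rows_block_size), 1)

# The integer trick agrees with the floating-point ceiling on small inputs.
assert all(ceil_div(n, b) == math.ceil(n / b) for n in range(0, 200) for b in range(1, 40))

print(grid_for_matrix(6, 6, 2, 2))  # (3, 3, 1): the nine blocks of the example
print(grid_for_matrix(7, 6, 2, 3))  # (2, 4, 1)
```

Adding `block_size - 1` before the floor division pushes any remainder up to one extra block. That extra, partially filled block is exactly why the kernel needs its bounds check.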

CUDA then launches this configuration: as many blocks as there are combinations of x and y. Inside each block it assigns thread IDs so that there are two threads along the x-axis and two along the y-axis. Now let's try to understand how to map each thread to one element of the output matrix, based only on the block IDs and the thread IDs along the x- and y-axes.

Let's look at the code. Each element of a matrix is identified by two indices: a row index and a column index. First, we use the following formula to identify which row a thread should work with: the block ID on the y-axis multiplied by the block size, plus the thread ID on the y-axis. Let's see why it makes sense. For example, this thread works with row 0, because its block ID on the y-axis is 0 and its thread ID is 0: the block ID multiplied by the block size is 0, and adding 0 gives 0. So this thread works with row number 0. And which column will it work with?

It will be the block ID on the x-axis, 0, multiplied by the block size on the columns, plus thread 0. The block size is two, but multiplied by zero it gives zero, so the column is 0. For this thread here, the row is again the block ID on the y-axis multiplied by the block size plus the thread ID, so row 0, and the column is 1.

Let's see another one, for example this thread here: which element will it work with? The row is the block ID on the y-axis multiplied by the block size, 1 × 2, plus its y thread ID, 0, which gives row 2. That makes sense: this is row 0, this is row 1 and this is row 2. The column is the block ID on the x-axis, which in this case is 1, multiplied by the block size, 2, plus the thread ID 1: 2 + 1 = 3. So this thread works with element (2, 3), and the formula makes sense. This is how we use the block ID and the thread ID inside each block to map each thread to the element it should work with. As I said before, CUDA has no way of knowing which element a particular thread should work with. That's up to us, based only on the block ID and thread ID that CUDA assigns.

Then we make sure that the row index is less than the number of rows and the column index is less than the number of columns. Why? As I said before, the number of threads we launch is in general rounded up to whole blocks (and whole warps), so some of these threads have no data to work with. All the threads without a corresponding element just sit idle at this if statement, while the ones that have one enter it and do the work.

Inside the if, we calculate the index of the element this thread should work with: the row index multiplied by the number of columns, plus the column index. This is just another way of writing A[row_index][col_index]. In C or C++ these arrays are allocated as a flattened array, where all the rows are stored one after another.
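To make the 2D mapping and the row-major flattening concrete, here is a sequential plain-Python stand-in for the matrix-addition kernel. The names are hypothetical and the launch is a plain loop over every block and thread. It uses a 5 × 7 matrix with blocks of 3 columns by 2 rows, so the edge blocks are only partially filled and the bounds check is exercised.

```python
import math

def matrix_add_kernel(block_x, block_y, thread_x, thread_y, block_dim_x, block_dim_y,
                      out, a, b, num_rows, num_cols):
    """What one thread does: x indexes columns, y indexes rows."""
    row = block_y * block_dim_y + thread_y
    col = block_x * block_dim_x + thread_x
    if row < num_rows and col < num_cols:   # threads past the edge stay idle
        idx = row * num_cols + col          # row-major: rows stored one after another
        out[idx] = a[idx] + b[idx]

def launch_2d(kernel, grid, block_dim, *args):
    """Sequential stand-in for a 2D launch; grid and block_dim are (x, y) pairs."""
    for by in range(grid[1]):
        for bx in range(grid[0]):
            for ty in range(block_dim[1]):
                for tx in range(block_dim[0]):
                    kernel(bx, by, tx, ty, block_dim[0], block_dim[1], *args)

num_rows, num_cols = 5, 7     # deliberately not multiples of the block shape
block_dim = (3, 2)            # 3 columns (x) by 2 rows (y) per block
grid = (math.ceil(num_cols / block_dim[0]), math.ceil(num_rows / block_dim[1]))
a = [float(k) for k in range(num_rows * num_cols)]
b = [1.0] * (num_rows * num_cols)
out = [0.0] * (num_rows * num_cols)
launch_2d(matrix_add_kernel, grid, block_dim, out, a, b, num_rows, num_cols)

# Element (row 2, column 3) lives at flat index 2 * 7 + 3 = 17.
print(grid, out[2 * num_cols + 3])  # (3, 3) 18.0
```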

So we need to identify the element inside the array from its row index and column index, and this is the formula we use to do it. If you have never worked with arrays in C or C++, it doesn't matter, because later we will see tensor layouts and this will become much clearer. If you have, you already know how to index an element of a multi-dimensional array in C++. Then we compute the output as usual.

I know this has been a lot of information, so what should we remember from this? The first thing is that we decide how to divide the work on whatever matrix or vector we are working with. We tell CUDA how many blocks we want and how many threads we want in each block.

based on the identifier of the block ID

and the thread ID we should come up with

a strategy on how to map it to a subunit

of work so which part of the Matrix or

which part of the vector that particular

thread should work with um now the next

step for us is to understand the tensor

layouts because we are going to work

with the tensors and we need to

understand how the tensors are lay out

in the memory of the GPU or in the CPU

as well actually so we need to

understand what is the row column row

major layout and the column major layout

what is the stride Etc and convert all

the knowledge that we have about Cuda

into Triton so that we can then code

with Triton our kernel so let's
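Before moving on, the row/column formula and the bounds check described above can be tried without a GPU. The sketch below is a sequential pure-Python emulation, not CUDA code: the helper names `thread_to_element` and `emulate_matrix_add` and the 3×5 example matrix are hypothetical. In a real kernel the nested loops disappear, because the hardware runs every (block, thread) pair in parallel and each thread computes `row = blockIdx.y * blockDim.y + threadIdx.y` and `col = blockIdx.x * blockDim.x + threadIdx.x` for itself.

```python
def thread_to_element(block_x, block_y, thread_x, thread_y, block_dim_x, block_dim_y):
    """Map CUDA-style (blockIdx, threadIdx) coordinates to a (row, col) of the matrix."""
    row = block_y * block_dim_y + thread_y
    col = block_x * block_dim_x + thread_x
    return row, col


def emulate_matrix_add(a, b, num_rows, num_cols, block_dim_x=2, block_dim_y=2):
    """Sequentially emulate a 2D launch; a and b are flat, row-major lists."""
    # Ceiling division so that the grid covers every row and every column.
    grid_x = (num_cols + block_dim_x - 1) // block_dim_x
    grid_y = (num_rows + block_dim_y - 1) // block_dim_y
    out = [None] * (num_rows * num_cols)
    idle_threads = 0
    for block_y in range(grid_y):
        for block_x in range(grid_x):
            for thread_y in range(block_dim_y):
                for thread_x in range(block_dim_x):
                    row, col = thread_to_element(block_x, block_y, thread_x, thread_y,
                                                 block_dim_x, block_dim_y)
                    if row < num_rows and col < num_cols:
                        idx = row * num_cols + col  # A[row][col] in a flattened array
                        out[idx] = a[idx] + b[idx]
                    else:
                        idle_threads += 1  # this thread has no element to work with
    return out, idle_threads


print(thread_to_element(1, 1, 1, 0, 2, 2))  # (2, 3)
out, idle = emulate_matrix_add(list(range(15)), [10] * 15, num_rows=3, num_cols=5)
print(idle)  # 9
```

With 2×2 blocks, a 3×5 matrix needs a 3×2 grid of blocks, so 24 threads are launched for 15 elements and 9 of them fall outside the matrix, which is exactly why the `if` check is needed.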

All right guys, finally it's time for us to explore tensor layouts. Now, why do we need to explore tensor layouts? Because before, we have seen some examples of CUDA kernels, and when you give a matrix or a vector to a CUDA kernel, CUDA will not give you the entire matrix as in Python, where you can access each element by its index. CUDA will just give you a pointer to the starting element of that matrix or vector, and then it's up to you to calculate the memory address of all the remaining elements.

Suppose that we have a simple vector in PyTorch. This simple vector could be the following, which is a vector of shape (7,): it is a tensor with only one dimension, and 7 is the number of elements in that first dimension. For now, ignore this property called the stride; later I will explain what it is. How will this tensor be saved in the memory of the CPU or of the GPU? As follows. Suppose that the starting address of the first element is address 100, and suppose that each element is a 16-bit floating-point number, so each element occupies two bytes. Then the start address of the second element will be 102, the third element will be at 104, the fourth at 106, etc. This is exactly what you get in C when you allocate a vector or a matrix with malloc: the memory allocator just allocates enough memory to store all the elements and gives you a pointer to the start address of this memory. Then it's up to you to work out where each element is stored in that block of memory.

To do this, we introduce a property called the stride. The stride tells us how many elements we need to skip to arrive at the next element in a particular dimension. In the case of a vector, for example, we only have one dimension, which you can think of as the x dimension, or the columns dimension: this is the first column, this is the second, the third, the fourth, the fifth, etc. To go from one element to the next, we just need to skip one element: we increase our pointer by one element, and to go to the following one we increase the pointer by one element again, and so on. This allows us to write a for loop over this tensor.
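The address arithmetic for the vector example above is one line: start address plus index times stride, scaled by the element size. This is a hypothetical helper (`element_address` is not a library function); note that strides count elements, so they must be multiplied by the bytes per element (2 for a 16-bit float) to get a byte address. In PyTorch, `Tensor.stride()` likewise reports strides in elements, and `Tensor.element_size()` gives the bytes per element.

```python
def element_address(base_address, index, stride, element_size_bytes):
    """Byte address of element `index` in a 1D tensor.

    The stride is measured in elements, so it is scaled by the element size
    to turn it into a byte offset.
    """
    return base_address + index * stride * element_size_bytes


# A 1D tensor of shape (7,) with stride (1,), 16-bit elements starting at address 100.
addresses = [element_address(100, i, 1, 2) for i in range(7)]
print(addresses)  # [100, 102, 104, 106, 108, 110, 112]
```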

Let's look at a more complicated case, like the matrix. The matrix is two-dimensional. Suppose we have the following matrix, made up of six elements arranged in two rows and three columns, so the shape of this tensor is 2×3. How will this matrix be saved in memory? In memory it will just be a flattened matrix. This is called the row-major layout; there is also another one, called the column-major layout, that we will not be discussing. In the row-major layout, the elements of the first row are followed immediately by the elements of the second row. Imagine that the memory address of the first element is 62: to go to the next element we need to increase the memory address by the number of bytes that each element occupies, which is two bytes, so the address of the second element will be 64, the third element will be at 66, and the next row will start immediately after the end of the first row.

Let's introduce the property called the stride. The stride tells us how many elements you need to skip in each dimension to arrive at the next element of that dimension. For example, imagine we want to get all the elements of the first row. Let's call this tensor T: T[0] says "give me all the elements of the first row", that is, select only the first row and return all of its elements. How does this indexing work? Starting from the pointer to the first element, it selects only the first row and then moves the index one element after another, so it selects the first one, the second one and the third one. How does it know that it needs to move one element at a time? Because in this dimension the stride is 1, and the stride tells us how many elements you need to skip to arrive at the next element in that dimension.

Now imagine that we want to get T[1], all the elements of the second row. First of all, it needs to skip some elements of the first dimension: it needs to skip index 0, because we are not selecting it; we only want index 1 of the first dimension, which means the row with index 1. Because it starts from the pointer to the first element, it needs to know how many elements to skip, and that is given by the stride: the stride of the first dimension tells us how many elements to skip to arrive at the next element of the first dimension. So in this case it takes the pointer to the first element, skips three elements, and starts at the second row. Then, inside this row, it goes through the indices of the second dimension, in which the stride is 1, so it just goes one element after another, and it returns only this part of the memory.
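Indexing with strides generalizes to any number of dimensions: the offset of an element is the sum of each index times the stride of its dimension. Below is a minimal sketch under the assumption of a flat Python list standing in for GPU memory; the helpers `flat_offset` and `select_row` and the matrix values are hypothetical and only illustrate how T[0] and T[1] are found.

```python
def flat_offset(indices, strides):
    """Element offset of a multi-dimensional index: sum of index_k * stride_k."""
    return sum(i * s for i, s in zip(indices, strides))


def select_row(storage, num_cols, strides, row):
    """Return T[row] for a 2D tensor whose elements live in the flat list `storage`."""
    return [storage[flat_offset((row, col), strides)] for col in range(num_cols)]


# Row-major 2x3 matrix: the rows are stored one after another, so strides = (3, 1).
storage = [1, 2, 3, 4, 5, 6]
strides = (3, 1)
print(select_row(storage, 3, strides, 0))  # [1, 2, 3]
print(select_row(storage, 3, strides, 1))  # [4, 5, 6]

# Byte address of element (1, 2) if the first element sits at address 62 (2-byte elements).
print(62 + flat_offset((1, 2), strides) * 2)  # 72
```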

So, to recap: the stride is just a number that tells us how many elements you need to skip in each dimension to arrive at the next index in that dimension. Here it means that to go from one row to the next we need to skip three elements, and to go from one column to the next we need to skip one element.

Why is the stride useful? Because it allows us to reshape tensors very easily and without doing any computation. Let's see. Imagine we want to reshape a matrix whose initial shape is 2×3, so two rows by three columns, with the stride calculated as follows: to go from one row to the next you need to skip three elements, and to go from one column to the next you need to skip one element. If we want to reshape it into the shape 3×2, so three rows and two columns, we can do it without changing its memory layout, just by changing the stride. Look at the physical configuration of the tensor: we can access the same tensor with this shape or with that shape using exactly the same physical memory. In the 2×3 view, to go from one row to the next the stride is 3, so we need to skip three elements: the starting element of the second row is the start pointer plus three elements, and each element of the second row comes one after another, because the stride of the second dimension is 1. So to get the second row we just start from here and take all these elements one after another, which is exactly the second row.

Now suppose we want to obtain the second row of the reshaped 3×2 view. How do we do that? Let's look at the stride: it is now 2 for the rows, meaning that to go from one row to the next we need to skip two elements. So to select this row we start from the start pointer, skip the first two elements, because the stride says that to go from one row to the next you need to skip two elements, and then select exactly two elements, one after another, because the stride in the second dimension is 1. So the stride allows us to reshape the tensor without changing the physical layout of how it is stored in memory.
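The reshape described above can be checked with a tiny view function that only ever reads from one flat buffer: the 2×3 and 3×2 shapes differ only in the strides passed in. This is a hypothetical helper (`view` here is not PyTorch's method), written as a sketch; for comparison, in PyTorch `torch.arange(6).view(3, 2).stride()` is `(2, 1)`.

```python
def view(storage, shape, strides):
    """Read a 2D view of `storage`: element (i, j) lives at i * strides[0] + j * strides[1].

    The buffer itself is never rearranged; only the shape and strides change.
    """
    rows, cols = shape
    return [[storage[i * strides[0] + j * strides[1]] for j in range(cols)]
            for i in range(rows)]


storage = [1, 2, 3, 4, 5, 6]              # one physical buffer
print(view(storage, (2, 3), (3, 1)))      # [[1, 2, 3], [4, 5, 6]]
print(view(storage, (3, 2), (2, 1)))      # [[1, 2], [3, 4], [5, 6]]
```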

Moreover, the stride also allows us to get the transpose of a matrix without changing how it is stored in memory, that is, without changing the arrangement of the elements in memory. This is very cool, because we can view the same matrix both with and without the transpose without changing anything in memory: it comes for free, just by working with the index and the stride. To transpose a matrix along two dimensions, we just need to swap the strides of the two dimensions that we want to transpose. For example, imagine we want the transpose of this matrix: we just swap the strides. If we want the second row of the transposed matrix, how do we get it? We always have the pointer to the first element, where the tensor begins in memory, and the stride says that to go from one row to the next we need to skip one element. That is correct, because as you can see, the first element of the second row of the transposed matrix is exactly the second element in memory, so we just skip by one and get the starting point of the second row. Then, to go from one element to the next within the same row, we need to skip three elements, so the second element of the second row is three elements after the first element of the second row: after 2 we skip three elements, this one and this one, and we arrive at this one, 8, which is exactly the second column of the second row.

So the stride, as you can see, allows us to do two things. First, it allows us to reshape the tensor without having to reallocate it in another configuration in memory. Second, it allows us to transpose a matrix without having to rearrange the elements in memory. This is great because moving memory around is expensive, and rearranging memory is expensive, so it's great that this stuff comes for free.
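Transposing by swapping strides can be verified the same way: swap the sizes and the strides of the two dimensions, read through the same flat buffer, and compare with an explicit transpose. The helpers below (`view`, `transpose_view`) are hypothetical and written only for this sketch; in PyTorch, `torch.arange(6).view(2, 3).t().stride()` is `(1, 3)`, the swapped strides.

```python
def view(storage, shape, strides):
    """Read a 2D view of a flat buffer using the given shape and strides."""
    rows, cols = shape
    return [[storage[i * strides[0] + j * strides[1]] for j in range(cols)]
            for i in range(rows)]


def transpose_view(shape, strides):
    """Transpose dimensions 0 and 1 by swapping both their sizes and their strides."""
    return (shape[1], shape[0]), (strides[1], strides[0])


storage = [1, 2, 3, 4, 5, 6]
shape, strides = (2, 3), (3, 1)
t_shape, t_strides = transpose_view(shape, strides)
print(t_shape, t_strides)                  # (3, 2) (1, 3)
print(view(storage, t_shape, t_strides))   # [[1, 4], [2, 5], [3, 6]]
```

The second row of the transposed view starts one element into the buffer and then jumps three elements at a time, which is exactly the walk described above.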

Another thing. You know that in PyTorch there are two methods to reshape a tensor: one is called reshape and the other is called view. After transposing a matrix by swapping the strides of the two dimensions you want to transpose, you can no longer reshape the tensor for free with view. Why? Let's look at how the stride is computed, with a concrete example. For a tensor laid out contiguously, the stride of each dimension is just the product of the sizes of all the following dimensions: the stride of dimension 0 is the product of the shape starting from index 1. This is not easy to see with a matrix, because it doesn't have enough dimensions, so let's do it with a 3D tensor. This tensor has three dimensions with shape (2, 4, 3), which means we have two matrices, each made up of four rows and three columns. The stride is calculated as follows: the stride of dimension 0 is 4 × 3 = 12, the stride of dimension 1 is 3, and the stride of the last dimension is 1, because there is no following dimension. When we transpose, this property is lost: after transposing by swapping the strides, the tensor is no longer contiguous, so we cannot do further reshaping operations for free. This is a fairly advanced point, and it doesn't matter whether you know it, but if you are curious: in PyTorch you generally cannot view a tensor after it has been transposed, because to transpose a tensor PyTorch just swaps the two strides, and the strides are then no longer the products of the following sizes. So you need to actually reallocate the tensor if you want to reshape it after it has been transposed (this is what .contiguous() does, and what .reshape() does automatically when a view is not possible).
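The "product of the following sizes" rule is easy to compute, and comparing a tensor's actual strides against it tells us whether the tensor is still contiguous. A minimal sketch with hypothetical helper names; unlike PyTorch's real `is_contiguous`, it ignores the special case of size-1 dimensions, whose strides PyTorch does not check.

```python
def contiguous_strides(shape):
    """Row-major strides: each dimension's stride is the product of all later sizes."""
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """Simplified check: the strides must match the row-major rule exactly."""
    return tuple(strides) == contiguous_strides(shape)


print(contiguous_strides((2, 4, 3)))  # (12, 3, 1)

# Transposing the last two dimensions swaps their sizes and their strides.
t_shape, t_strides = (2, 3, 4), (12, 1, 3)
print(is_contiguous(t_shape, t_strides))  # False: a free view() is no longer possible
```

In PyTorch terms, `x.transpose(1, 2).is_contiguous()` returns `False` for such a tensor, and reshaping it then requires a copy.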

It doesn't matter if you remember this; it's just a curiosity. Anyway, what is the stride used for? It is used for three things. First of all, it tells us how to index the tensor: just by having a pointer to the starting address of the tensor, we can index it however we like, accessing any row or any column. Second, it allows us to reshape the tensor for free, without rearranging the elements in memory. Third, it allows us to transpose the tensor however we like, just by swapping the strides of the two dimensions that we want to transpose.

Now that we have also seen how tensors are stored in memory, we can finally go and see Triton with some examples. All right guys, now that we have seen how tensors work, how the tensor layout works and how CUDA works, we can look at some examples of Triton kernels to see how Triton differs from CUDA. If you go to the Triton website, you will find some tutorials, like in this section here. Let's work through one tutorial together to understand how Triton is different from CUDA. If you go to the tutorials there are many examples. First of all, the code that I will write for Flash Attention is based on this tutorial here, Fused Attention, but with some modifications, because I simplified the code a lot. For example, I removed the FP8 implementation. Also, the backward pass of the fused attention code here only works for causal attention, while my code will work for both causal and non-causal attention. Another modification is that, instead of using the base-2 exponential that they use here, probably to make things faster because exp2 is implemented with a faster unit, I use the exponential with base e, as in the original Flash Attention implementation. So I simplified my code as much as possible to make it easy to follow rather than optimized: for sure my code will be slower than the fused attention that you see here, but mine should be easier to understand and follow.

Anyway, let's go to the vector addition tutorial. There you will find an example of how to do vector addition with Triton, which should help you get into the mindset of writing kernels with Triton. Instead of writing the kernel first and then calling it, let's do the opposite: let's see how to call this kernel and then explore how it works. I have already copied the vector addition tutorial from the website. First of all, let's look at what we want to achieve. We have an input vector x and an input vector y, and we want to compute the vector addition: with torch we do the following operation, and we also want to do the same operation with Triton by calling this method add. Then we compare the two output vectors, which should be equal, or at least their difference should be very, very small, because there is always some rounding error when you work with floating-point numbers. The size of these vectors is about 98,000 elements, and we want to work in a blocked way. As you remember, with CUDA you can do vector addition by spawning a large number of threads, each doing one operation, but when the number of threads you have is not enough, you need to divide the input vector into blocks, and this is what we are going to do here.

Let's look at this add method. First of all, it allocates the necessary memory for the output vector. Then it computes the launch grid. The launch grid tells Triton, just like in CUDA, how many blocks of threads we want to launch. If you remember, in a CUDA kernel we specify how many blocks we want and then how many threads we want for each block. In Triton we say how many blocks we want, but we don't force how many threads to launch: Triton chooses how many threads to launch, and we just say what each group of threads should do. In this case, for example, we divide our number of elements n, which is about 98,000, into blocks of size BLOCK_SIZE, which is initialized to 1024. To calculate the grid size we do a ceiling division, so this means ceil(n_elements / BLOCK_SIZE): that's how many blocks we want. What each block should do is defined inside the kernel. When we launch the kernel, we specify the launch grid in square brackets and then the arguments of the kernel in round brackets. So let's go to the kernel.
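The grid computation above is just an integer ceiling division. The sketch below reproduces it in plain Python; in the Triton tutorial the same thing is written with `triton.cdiv` inside a callable of the meta-parameters, as shown in the comment. The value `n_elements = 98_000` is the approximate size mentioned above, not the tutorial's exact number.

```python
def cdiv(n, d):
    """Ceiling division for positive integers, without going through floats."""
    return (n + d - 1) // d


n_elements = 98_000  # approximate vector size from the walkthrough
BLOCK_SIZE = 1024

# Triton form: grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
grid = lambda meta: (cdiv(n_elements, meta["BLOCK_SIZE"]),)
print(grid({"BLOCK_SIZE": BLOCK_SIZE}))  # (96,)
```

96 programs of 1024 elements cover 98,304 positions, 304 more than the vector holds, so the last program must be told which of its positions do not exist.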

We see that Triton will not give us access to the tensor x; it will give us a pointer to the first element of this tensor. This takes us back to tensor layouts: the reason we studied tensor layouts, strides and all that stuff is that this add kernel runs on the GPU, and the GPU does not index tensors like PyTorch does, with all the dimensions, broadcasting and all this fancy stuff. The GPU will just give you the pointer to the first element of the tensor in memory, and then it's up to you to compute the indexes of all the elements that you want to access. So x_ptr is the pointer to the first element of the x vector, y_ptr is the pointer to the first element of the y vector, and then we have the pointer to the output vector, where we want to store the result of this vector addition. We also specify how many elements our vectors have and the block size, that is, how many items each block should process, which may not correspond to how many threads each program will have.

You may be confused, because in CUDA we specified how many threads each block should have, so the granularity we manage is the thread level. Here we are saying that a group of threads should work with this quantity of data, and then it's up to Triton to optimize the number of threads it actually uses. Actually, there are ways to say how many threads we want, by specifying the number of warps, but we will see that later. For now, just remember that this kernel processes BLOCK_SIZE elements of the input vectors.

First of all, we need to identify which block we are in. In CUDA we use the variable blockIdx.x to identify the block, which tells us which group of elements we should be working with. In Triton you do the same using the program ID. In CUDA the block ID can be along the x, y and z axes; in Triton these are called dimensions 0, 1 and 2. Here we have one-dimensional data, so we only use one axis to specify the block index. So we get the block index, pid, which in Triton is called the program ID. It's more intuitive to think of it as a program: this is a kind of program running in parallel with other programs that have different program IDs, and based on the program ID we can work out the starting element that this program, this block of threads, should work with. To get it, we just multiply pid by the block size. So program 0 works with the elements starting from element 0, program 1 starts from element 1024, and program 2 starts from element 2048, so it skips the first 2048 elements and starts with the element with index 2048.

Next, we define how to load these elements of the x and y vectors based on their pointers.

To do that, we specify a list of offsets with respect to the starting address. Because each program in Triton works with a group of data, not a single element but a block of elements, we need to work out which elements to load, that is, the offsets of these elements. Program 0 will load block_start, which is 0, plus the offsets from 0 to 1024, excluded. For program 1, this results in a vector of offsets 1024, 1025, 1026, 1027, etc., until 2047. For program 2, the offsets will be 2048, 2049, and so on, until 3071.

Now, as you remember, when we launch a grid, the number of positions covered is not based on the exact number of elements in your vector: the grid covers whole blocks (and threads themselves come in multiples of a base unit, usually 32), which means this program may cover more positions than it needs. Some threads should therefore not do anything: they should not load any data and should not compute any sum. This is why we need this mask: all the offsets that we load must be less than n_elements. Imagine that instead of about 98,000 elements you have a vector of 2,060 elements. The third program of this kernel would load the offsets 2048, 2049, ..., 2059, but also 2060, 2061, etc., up to 3071, and we said we only have 2,060 elements, so all the elements from index 2060 onward don't exist. We need to tell the threads working with those elements not to load anything, and that's why we need this mask: it tells load to read, among all the offsets this block works with, only the elements that actually exist, those for which the mask is true. So we load the elements of the current program, the group of elements defined by these offsets, but only those for which the mask is true; all the others are ignored. We can also specify what to load where the mask is false, with another parameter (other), but we will not see that here. We also load the group of elements of the y vector.
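The offsets and the mask for one program can be reproduced in plain Python. This sketch mirrors the two tutorial lines `offsets = block_start + tl.arange(0, BLOCK_SIZE)` and `mask = offsets < n_elements`, using lists instead of Triton tensors; `program_offsets_and_mask` is a hypothetical helper and the 2,060-element vector is the hypothetical example above.

```python
BLOCK_SIZE = 1024
n_elements = 2060  # the small hypothetical vector from the example above


def program_offsets_and_mask(pid):
    """Offsets and mask for one program (lists standing in for Triton tensors)."""
    block_start = pid * BLOCK_SIZE
    offsets = [block_start + i for i in range(BLOCK_SIZE)]
    mask = [off < n_elements for off in offsets]
    return offsets, mask


offsets, mask = program_offsets_and_mask(2)
print(offsets[0], offsets[-1])  # 2048 3071
print(sum(mask))                # 12, i.e. only offsets 2048..2059 exist
```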

And then we compute the output x + y. If you remember, previously in CUDA we did something like output[i] = x[i] + y[i], one element at a time, because each thread was working with one index. Here we are working with a group of elements: this x is a block of BLOCK_SIZE elements of the x vector, this y is the corresponding group of elements of the y vector, and we compute the output group by group, summing a group of elements of x with the corresponding group of y and writing the result to output. Then we need to store this output in the output tensor through output_ptr, which is a pointer to the first element of the output vector. Where should we store this output block, whose shape is BLOCK_SIZE? At the same offsets from which we loaded x. So if this program worked with indices 2048, 2049, etc., the output is written at the same offsets, 2048, 2049, etc., up to 3071, using the mask as well, because we don't want to write all BLOCK_SIZE values: maybe we don't have enough elements, so we only write the ones that are actually present in the vector. So the reason we need this mask is that the launch covers a number of positions (and threads) that is a multiple of a base unit, which may not match the size of the vector we are working with, so we need a way to tell some threads not to do anything when their data is not available.

So let's recap what we have seen so far. In CUDA, the program we write is at the thread level: what each thread should do. In Triton, we work with blocks of data and blocks of threads: what data each block of threads should work with.
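Putting the pieces together, the whole vector-add kernel can be emulated sequentially: each call to `add_program` plays the role of one Triton program, with list comprehensions standing in for the block-wide `tl.load`, `+` and `tl.store`. This is a pure-Python sketch with hypothetical function names, not the tutorial's Triton code; on a GPU all programs would run in parallel and each operation would act on the whole block at once.

```python
def add_program(pid, x, y, output, n_elements, BLOCK_SIZE):
    """One 'program': handle the block of BLOCK_SIZE offsets that belongs to pid."""
    block_start = pid * BLOCK_SIZE
    offsets = [block_start + i for i in range(BLOCK_SIZE)]
    mask = [off < n_elements for off in offsets]
    # Masked load: positions that do not exist are not read (0.0 is only a placeholder).
    xs = [x[o] if m else 0.0 for o, m in zip(offsets, mask)]
    ys = [y[o] if m else 0.0 for o, m in zip(offsets, mask)]
    block_out = [a + b for a, b in zip(xs, ys)]  # the whole block is summed at once
    # Masked store: write back only the positions that exist in the vector.
    for o, m, value in zip(offsets, mask, block_out):
        if m:
            output[o] = value


def add(x, y, BLOCK_SIZE=1024):
    n_elements = len(x)
    output = [None] * n_elements
    num_programs = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE  # ceil division
    for pid in range(num_programs):  # on the GPU these programs run in parallel
        add_program(pid, x, y, output, n_elements, BLOCK_SIZE)
    return output


result = add([float(i) for i in range(2060)], [1.0] * 2060)
print(result[0], result[-1])  # 1.0 2060.0
```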

All right guys, finally the moment has come: we are going to code the Flash Attention forward pass in Triton right now. But first let's recap the algorithm. The goal of the attention mechanism, and specifically of Flash Attention, is to compute the attention output, that is, the output of the following formula: the query multiplied by the transpose of the key, divided by the square root of the head dimension; then we apply the softmax, and then we multiply everything by V, so O = softmax(QKᵀ / √d) V.
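As a reference point for everything that follows, here is the formula computed naively for a single head, in pure Python with lists of rows. It is only a correctness baseline (hypothetical helper names, no blocking, no GPU); the softmax subtracts the row maximum, the numerically safe form discussed earlier in the video.

```python
import math


def matmul(a, b):
    """Plain matrix product of two lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax_rows(s):
    out = []
    for row in s:
        mx = max(row)  # subtract the row max so exp() cannot overflow
        e = [math.exp(v - mx) for v in row]
        total = sum(e)
        out.append([v / total for v in e])
    return out


def naive_attention(q, k, v):
    """O = softmax(Q K^T / sqrt(d)) V for one head; q, k, v are lists of rows."""
    scale = 1.0 / math.sqrt(len(q[0]))
    scores = [[x * scale for x in row] for row in matmul(q, transpose(k))]
    return matmul(softmax_rows(scores), v)


# A zero query gives equal weight to every key, so the output is the mean of V's rows.
print(naive_attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [3.0, 4.0]]))
# [[2.0, 2.0]]
```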

Now, in this video we will code the forward pass and also the backward pass, but before coding the backward pass we need to understand how autograd works: what the gradient is, what the Jacobian is, how to derive the gradient of the softmax operation, how to derive the gradient of the matrix multiplication, etc. That is going to be another part of the video; for now, let's concentrate on the forward pass. Right now we have some tools. We know that we have this thing called the GPU, which can parallelize operations across multiple cores. We know that in CUDA we can parallelize operations by writing a program that defines what each thread should do, or we can follow the Triton programming model, which means describing in Python what each group of threads should do.

The mapping between what each thread should do and which element it should work with is up to us, the programmers. The same happens in Triton: we say how many blocks of threads we want and how much data each block of threads should process (that's the BLOCK_SIZE we saw in the vector addition), but the mapping between the elements of the vector and the identity of each group of threads, the program ID that we saw, is up to us. The same will happen when we code Flash Attention.

Let's see what we can parallelize in Flash Attention. First of all, the forward pass of Flash Attention that you see here takes as input query, key and value matrices of size N × d. However, in a Transformer network we usually don't have only one sequence made up of d dimensions; we have many sequences, and this lowercase d is the number of dimensions dedicated to each head. We also don't have only one head: we have multiple heads, so the algorithm you see here is what each head of each batch should do. Moreover, we have seen before, when talking about block matrix multiplication, that we can parallelize the computation of the output: this output block here depends on query block 1 and all the keys (and values), this one depends on query block 2 and all the keys, this one on query block 3 and all the keys, etc. Because this output block only depends on query block 1 and this one only on query block 2, they can be computed independently from each other, while of course sharing the keys.
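The independence of output blocks can be checked directly: compute attention for each block of query rows separately, always against all the keys and values, then concatenate, and the result matches computing everything at once. The sketch below runs the blocks one after another in pure Python (hypothetical helper names, deterministic toy data); on a GPU each query block, for each batch and head, could be given its own program ID, though the exact grid layout used later is not specified here.

```python
import math


def attention(q, k, v):
    """softmax(q k^T / sqrt(d)) v, computed row by row for lists of row vectors."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        mx = max(scores)
        w = [math.exp(s - mx) for s in scores]
        total = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / total
                    for c in range(len(v[0]))])
    return out


def attention_by_query_blocks(q, k, v, block_q):
    """Each query block needs only its own queries plus ALL keys and values,
    so the blocks could run as independent programs; here they run in sequence."""
    out = []
    for start in range(0, len(q), block_q):
        out.extend(attention(q[start:start + block_q], k, v))
    return out


q = [[((i * 7 + j * 3) % 5) / 5.0 for j in range(4)] for i in range(7)]
k = [[((i * 5 + j * 2) % 7) / 7.0 for j in range(4)] for i in range(6)]
v = [[((i + j) % 3) - 1.0 for j in range(3)] for i in range(6)]
full = attention(q, k, v)
blocked = attention_by_query_blocks(q, k, v, block_q=3)
```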

understand about a Tron is the shared

memory so um the in the GPU we have the

high bandwidth memory and which is the

like kind of the ram so the when you buy

an a100 they tell you that it has 4 40

GB. That's the amount of memory in the high bandwidth memory (HBM), the DRAM. Let's actually look at the structure of the GPU. Here we have the DRAM, which is the big memory that the GPU has. Each streaming multiprocessor, which we can think of as running a block of threads, also has a shared memory. So inside the GPU we have these streaming multiprocessors, and each of them has a region of memory called the shared memory, which is much, much smaller than the DRAM. What changes between these two memories? Access to the DRAM is very slow, and access to the shared memory is very, very fast.

One thing that is different between CUDA and Triton is this: in the simple CUDA kernels we wrote before, whenever we load some information we load it directly from global memory. When we launch a CUDA kernel, as you remember from my C++ code, we first copy the tensors (or vectors) from the CPU to the GPU, where they reside in the GPU's global memory, and then we load the elements directly from global memory. But access to global memory is usually much, much slower. This is why the naive version of the attention computation, the one we can write with Torch, is slow: access to global memory is slow. So we want to use the shared memory as much as possible, reusing the elements loaded from global memory into shared memory, so that we don't need to access global memory every time we load elements of the vectors or matrices.

This is also what happens in Triton. In Triton, whenever you load some data, you are conceptually copying the information from global memory to shared memory; whatever operations you do are done on the shared memory; and when you store the information, you copy the data from shared memory back to global memory. This makes it much faster, because we always work with elements that have been loaded into the shared memory. This shared memory is shared by all the threads that belong to the same block. In Triton we have an abstraction level that doesn't make us work directly with individual threads: we always work with a group of threads that belong to the same block and share this shared memory. So in Triton we copy information from global memory to shared memory, do some operations with it, and then store it back to global memory, and this is what we are going to do with Flash Attention.
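
As a rough, idealized illustration of why reusing data held in shared memory pays off, here is a small sketch that counts how many elements a square matrix multiplication C = A @ B reads from global memory with and without tiling. It assumes that the naive version re-reads every operand element from global memory on each use, and that the tiled version loads each T×T tile into shared memory once per step; real hardware also has caches, so treat this as a model, not a measurement. The function names are hypothetical.

```python
def naive_global_loads(n: int) -> int:
    # Each of the n*n outputs reads one row of A (n elements) and one column of B (n elements).
    return n * n * 2 * n

def tiled_global_loads(n: int, t: int) -> int:
    # Each of the (n/t)^2 output tiles walks n/t steps along the shared dimension,
    # loading one t x t tile of A and one t x t tile of B per step.
    assert n % t == 0, "this simple model assumes t divides n"
    tiles = (n // t) ** 2
    steps = n // t
    return tiles * steps * 2 * t * t

n, t = 128, 32
print(naive_global_loads(n))     # 4194304
print(tiled_global_loads(n, t))  # 131072, i.e. t times fewer global loads
```

In this model tiling divides global-memory traffic by the tile size, which is the same reason Flash Attention works on blocks of Q, K and V held close to the compute units.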

Let's review the algorithm of Flash Attention. In Flash Attention we have an outer for loop over all the query blocks, and then an inner loop over all the key blocks. In the original Flash Attention algorithm, Flash Attention 1, the outer loop was over the keys and the inner loop was over the queries, and this made it less parallelizable. Why is the new order better? Because now the outer loop is over the queries, and we have seen before that the output of the attention can be computed independently for each block of queries, so it's much easier to parallelize. For this outer for loop we don't actually have to run a loop: we just spawn many programs, each working on one iteration of the outer loop, that is, each working with a different query block. The inner for loop is something we do have to iterate through, so each Triton program will work with one query block and then iterate through all the key blocks. Inside this loop over the key blocks we do the operations we have already explored before, and at the end of the loop we need to store the output back to the high bandwidth memory. This is how we are going to work.
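
To make the loop structure concrete, here is a NumPy sketch of the block-wise forward pass for a single batch element and a single head. It is not the Triton kernel and not the author's code; it only mirrors the algorithm: an outer loop over query blocks, an inner loop over key/value blocks, a running row maximum `m`, a running denominator `l`, and the correction factor `exp(m_old - m_new)` of the online softmax. The block sizes are arbitrary and deliberately do not divide the sequence lengths.

```python
import numpy as np

def naive_attention(q, k, v, scale):
    # Reference: materialize the full score matrix, softmax each row, multiply by V.
    s = (q @ k.T) * scale
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

def blockwise_attention(q, k, v, scale, block_q=4, block_kv=3):
    out = np.empty((q.shape[0], v.shape[1]))
    # Outer loop over query blocks: iterations are independent, which is
    # why each one can become its own parallel program.
    for qs in range(0, q.shape[0], block_q):
        q_blk = q[qs:qs + block_q]
        rows = q_blk.shape[0]
        m = np.full(rows, -np.inf)           # running maximum of each row
        l = np.zeros(rows)                   # running softmax denominator
        acc = np.zeros((rows, v.shape[1]))   # running, not yet normalized output
        # Inner loop over key/value blocks, done sequentially.
        for ks in range(0, k.shape[0], block_kv):
            s = (q_blk @ k[ks:ks + block_kv].T) * scale
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            alpha = np.exp(m - m_new)        # rescales statistics computed with the old max
            l = alpha * l + p.sum(axis=1)
            acc = alpha[:, None] * acc + p @ v[ks:ks + block_kv]
            m = m_new
        # Normalize once, then write this output block back (the "store to HBM" step).
        out[qs:qs + rows] = acc / l[:, None]
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal((10, 8))
k = rng.standard_normal((7, 8))
v = rng.standard_normal((7, 8))
scale = 1.0 / np.sqrt(8)
print(np.allclose(blockwise_attention(q, k, v, scale), naive_attention(q, k, v, scale)))  # True
```

On the first key block `m` is minus infinity, so `alpha` is exactly zero and the empty accumulators are simply replaced by the first block's values.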

Another thing we should notice is that the Q, K and V sequences are N by d, as I said before, but usually in a Transformer model we don't have only one sequence, we have many sequences. So we can also parallelize over the sequences in the batch, because each sequence can be processed independently of the others. Moreover, each sequence has multiple heads, and each head can also work independently of the others: we know this from the "Attention Is All You Need" paper, that's the meaning of multi-head attention, each head computes its attention independently. So we will also parallelize along the head dimension. And if you look at the definition of the query block, we can also split the query into blocks, and each query block can work independently of the other query blocks, producing one output block.

So this is how we are going to parallelize: over each sequence in the batch, inside each sequence over each head, and inside each head over each query block. How many programs will we have working in parallel, at most? The batch size (the number of sequences in the batch), multiplied by the number of heads, multiplied by the number of blocks we divide the query sequence into, where each block holds, let's call it, BLOCK_SIZE_Q queries. All right, now that we have seen this, let's actually code it.

So I have already introduced some of the differences between my implementation of Flash Attention and the one you can find in the Triton documentation. First of all, I don't work with FP8, because I believe it is unnecessary for our explanation; of course it is much faster, because recent GPUs support FP8. The second difference is that in the Flash Attention on the Triton website, the backward pass is only implemented for causal attention, while I implement it for both causal and non-causal attention, even if it's slower; later I want to give you an exercise on how to improve it. The third main difference is that I make explicit use of the softmax scale, so I actually apply the scale where it is needed. Another difference is that in the online Triton implementation of Flash Attention, the exponential is not really e to the power of x but 2 to the power of x, and they compensate for it by using the logarithm, probably because 2 to the power of x is faster to compute than e to the power of x. In my case I retain the original exponential, because I want to follow the original algorithm and make it simpler to read the code alongside the algorithm in the Flash Attention paper.
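
The base-2 trick mentioned above rests on a single identity: e^x = 2^(x · log2(e)). The sketch below just checks that identity numerically in plain Python; `exp_via_exp2` is a hypothetical helper, and the claim that exp2 is faster on the GPU is, as said above, only the likely motivation. In a kernel that uses this trick (Triton provides `tl.exp2`), the extra multiply by log2(e) can be folded into the softmax scale, and any saved logarithm must then be interpreted in base 2.

```python
import math

LOG2_E = math.log2(math.e)   # = 1 / ln(2), about 1.4427

def exp_via_exp2(x: float) -> float:
    # e**x == 2**(x * log2(e)), so an exp2 unit can evaluate exp after one extra multiply.
    return 2.0 ** (x * LOG2_E)

for x in (-3.0, 0.0, 0.5, 2.0):
    assert math.isclose(exp_via_exp2(x), math.exp(x), rel_tol=1e-12)
print("exp(x) == exp2(x * log2(e)) for the sampled points")
```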

I know I have created a lot of hype, so let's do it. Let's start by creating a new file; let's call it program.py. Just like before, when I introduced Triton, I will start by writing the code that will use our kernel, and then we will code the kernel itself. For now we will only be coding the forward pass of the kernel.

So let's start by importing what we need, which is just torch and triton. Secondly, let me check... okay, Copilot is already off, so I don't have to worry about that. Let's start implementing the code that will test our Triton implementation and compare it with the naive implementation of the attention mechanism. We create our query, key and value sequences for testing. As you remember, the query has the batch size as its first dimension, because we have multiple sequences; each sequence has a number of heads, and it is made up of SEQ_LEN tokens, each token identified by HEAD_DIM dimensions. This is because we have already split each token into smaller tokens, each with its own head dimension: if you remove the NUM_HEADS dimension, you concatenate back all the head dimensions of each token. We initialize the query, key and value sequences using a normal distribution; I took this code from the Triton tutorial, so it's nothing different. We also set requires_grad, because we want to compute the gradient with respect to the query, key and value. We will see later why: we also want to test the backward pass, even though we will not be coding it now.
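
The head split described above is just a reshape plus a transpose. The NumPy sketch below (the video uses torch; the sizes here are made up) turns a (batch, seq_len, d_model) tensor into the (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM) layout used for Q, K and V, and then concatenates the heads back to recover the original tensor.

```python
import numpy as np

batch, seq_len, num_heads, head_dim = 2, 5, 4, 3
d_model = num_heads * head_dim

x = np.arange(batch * seq_len * d_model, dtype=np.float32).reshape(batch, seq_len, d_model)

# Split each token's d_model features into num_heads chunks of head_dim,
# then move the head axis in front of the sequence axis.
heads = x.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)
print(heads.shape)  # (2, 4, 5, 3) = (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)

# Undo it: move the heads back next to head_dim and concatenate them.
merged = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)
assert np.array_equal(merged, x)
```

Head 2 of token 3 in sequence 1 is simply features 6 to 8 of that token, which is why each head can be processed independently.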

The first thing we do is define our softmax scale. As you remember, the formula is the query multiplied by the transpose of the keys, divided by the square root of the head dimension, sometimes called d_k or d_head. We can already compute this factor: it's one over the square root of the head dimension. Then we also define dO; later we will see what it is, but basically it will be needed for the backward pass. Don't worry if you don't understand what dO is; we will see it later.

Let's do the naive implementation of the attention, which is very simple. First we define the mask, and we use this mask only if the attention we are computing is causal. As you can see, we pass a parameter called causal that tells whether we want to compute causal or non-causal attention, and the dtype, which is float16, because we want to work directly with 16-bit floating-point numbers. We will not be working with FP8, so my implementation is actually not as fast as the one in the Triton tutorial, but I believe it is much easier to comprehend. So we define the mask and compute the product of the query with the transpose of the key, divided by the square root of the head dimension, which is why we multiply by the softmax scale. If the attention we are computing is causal, we use the mask we have computed: we replace all the dot products where the mask is equal to zero with minus infinity, and the softmax will then turn these minus infinities into zeros, because the softmax is applied row by row, just like in normal attention. The output is then the product of the output of the softmax with V. This is the reference output of the naive implementation of the attention mechanism.
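
Here is a NumPy stand-in for the naive reference just described (the video's version is written in PyTorch; this is not that code). It builds a lower-triangular mask, scales Q Kᵀ by the softmax scale, replaces masked scores with minus infinity when `causal` is set, applies a row-wise softmax and multiplies by V. The function name and test sizes are assumptions for illustration.

```python
import numpy as np

def reference_attention(q, k, v, causal, softmax_scale):
    # q, k, v: (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
    seq_len = q.shape[2]
    mask = np.tril(np.ones((seq_len, seq_len)))        # 1 where key index <= query index
    p = (q @ np.swapaxes(k, -1, -2)) * softmax_scale    # (B, H, S, S) scaled dot products
    if causal:
        p = np.where(mask == 0, -np.inf, p)            # hide keys from the future
    p = p - p.max(axis=-1, keepdims=True)               # subtract the row max for stability
    p = np.exp(p)                                       # exp(-inf) becomes exactly 0
    p = p / p.sum(axis=-1, keepdims=True)               # softmax along each row
    return p @ v

rng = np.random.default_rng(0)
B, H, S, D = 1, 2, 6, 4
q, k, v = (rng.standard_normal((B, H, S, D)) for _ in range(3))
out = reference_attention(q, k, v, causal=True, softmax_scale=1.0 / np.sqrt(D))
# With a causal mask the first query can only see the first key, so its output is that key's value.
print(np.allclose(out[:, :, 0], v[:, :, 0]))  # True
```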

Then we also want to derive the gradients of the output with respect to the inputs, in this case V, K and Q; later we will see what we are doing here. Then we want to compare this reference implementation with our Triton implementation. Our Triton implementation will be a class called TritonAttention, which we call through a method called apply (later we will see what this method is), passing the query, key and value, whether we want to compute causal attention, and the softmax scale it should use. It should produce an output, which is the output of the softmax multiplied by V. Then we can also run the backward, and it should produce the same gradients as the reference implementation. Then we compare the results of our implementation, TritonAttention.apply, with the reference implementation. We use the function allclose, which compares the elements of two tensors and makes sure that their absolute difference is no more than the given tolerance: we are not using the relative distance, only the absolute distance between corresponding elements of the two tensors. The implementation we will build works with both causal and non-causal attention, while in the one on the Triton website the forward pass works with both causal and non-causal attention but the backward pass only works for causal attention. The online one is highly optimized, though, so if you want to learn more tricks on how to optimize Triton kernels, there is a lot of knowledge there.
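
The absolute-only comparison described above can be sketched in plain Python with `math.isclose` and `rel_tol=0.0`. In PyTorch, `torch.allclose(a, b, rtol=0.0, atol=...)` behaves the same way, since it checks |a - b| <= atol + rtol * |b|. The tolerance the video actually uses is not stated in this excerpt, so the values below are made up, and `all_close_abs` is a hypothetical helper.

```python
import math

def all_close_abs(actual, expected, atol):
    # Absolute-only check: |a - e| <= atol for every pair, with no relative term.
    return len(actual) == len(expected) and all(
        math.isclose(a, e, rel_tol=0.0, abs_tol=atol) for a, e in zip(actual, expected)
    )

ref = [1000.0, 0.5, -2.0]
tri = [1000.005, 0.509, -2.0]
print(all_close_abs(ref, tri, atol=1e-2))  # True
print(all_close_abs(ref, tri, atol=1e-3))  # False: 1000.0 vs 1000.005 differs by 0.005
```

With an absolute tolerance, a large element such as 1000.0 gets no extra slack, which is a stricter test than a relative one for big values.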

Anyway, guys, now let's try to implement this TritonAttention, at least the forward pass. Let's go implement the TritonAttention class. Every time you want to introduce a new operation into torch with its own gradient, you implement it by deriving from the torch.autograd.Function class. Every differentiable operation in torch, whether it's the softmax, the ReLU or anything else, conceptually has a forward and a backward, and a custom operation is written as a class that derives from Function and provides these two methods. The forward should produce the output of the operation, and the backward should compute the gradient of the loss with respect to the inputs of that function; later we will see how that works. For now let's concentrate on the forward pass. To implement it, we need to create a static method called forward, which takes as input something called the context. As you know, when training neural networks with autograd we have the forward pass and the backward pass, and when computing the backward pass we need to reuse the activations of each computation node from the forward pass. This context allows us to save the activations we will need during the backward pass. Later we will see what information the Flash Attention algorithm needs to save in order to compute the backward pass. For example, during the backward pass we will recompute on the fly the query multiplied by the transpose of the keys for each block, but we don't want to recompute the normalization factor or the maximum value of each row, so we will save those. Actually, we will not save two values: we will save one value, using a trick called the logsumexp trick, which we will see later. Anyway, this context is just a kind of storage area where we can save whatever we need to compute the backward pass; you can save whatever you like.
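
Here is a minimal sketch of the `torch.autograd.Function` pattern just described, using a hypothetical toy operation y = exp(scale · x) instead of attention. It shows the static `forward(ctx, ...)`, saving tensors with `ctx.save_for_backward`, stashing a plain Python value as an attribute on `ctx`, and a `backward` that returns one gradient per forward input, with `None` for the non-tensor argument (the same would apply to `causal` and `softmax_scale` in a TritonAttention class).

```python
import torch

class ScaledExp(torch.autograd.Function):
    """Toy op y = exp(scale * x), written the way a custom attention op would be."""

    @staticmethod
    def forward(ctx, x, scale):
        y = torch.exp(scale * x)
        ctx.save_for_backward(y)   # tensors needed later go through save_for_backward
        ctx.scale = scale          # plain Python values can be stored as attributes
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        # dy/dx = scale * exp(scale * x) = scale * y; `scale` is not a tensor, so its gradient is None.
        return grad_output * ctx.scale * y, None

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
y = ScaledExp.apply(x, 0.5)
y.sum().backward()
print(torch.allclose(x.grad, 0.5 * torch.exp(0.5 * x.detach())))  # True
```

Saving the output y here plays the same role as saving O and the per-row statistics in Flash Attention: the backward reuses forward results instead of recomputing them.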

Then we have the inputs of this operation: the query, key and value, which are three tensors; causal, which says whether we are going to compute causal attention; and the softmax scale we should apply, which is one over the square root of the head dimension (we could also compute it on the fly by checking the shape of the inputs, but it doesn't matter). The first thing we do is extract the shapes of these objects and make sure they are what we expect. The shape of the query, key and value is batch size by number of heads by sequence length by head dimension. We make sure that the head dimension matches for the query, key and value, because each vector should be of the same size. Then we pre-allocate the output tensor, where we will save our output. As you remember, the output of the attention mechanism has the same shape as the query, key and value sequences, where, I want to remind you, the query, key and value are not the input of the attention, which is a sequence of tokens, but already the output of W_Q, W_K and W_V, because Flash Attention is not concerned with optimizing those matrix multiplications, only what comes after them. So we pre-allocate the output tensor, which has the same shape as the query, key and value... Actually, no, that's not quite true: it has the same shape as the query, but not necessarily the same as the key and value. Why? Because of something called cross attention, where the query, key and value are projections through W_Q, W_K and W_V not of the same input sequence but of two different sequences. Cross attention happens when the query comes from one sequence and the key and value come from another sequence, passed through their own W_K and W_V, and the two sequences may not have the same length. So the shape of the attention output depends only on the shape of the query sequence, not on the key and value sequences. This happens during cross attention, but usually in language models we work with self attention, so it should not happen, at least in causal language models.
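
The shape rule for cross attention can be checked with a few lines of NumPy. In this sketch (sizes made up, and the `attention` helper is hypothetical) the keys and values have 11 positions while the queries have 7, and the output follows the query: it is pre-allocated like `q`, not like `k` or `v`.

```python
import numpy as np

def attention(q, k, v, scale):
    s = (q @ np.swapaxes(k, -1, -2)) * scale           # (..., SEQ_Q, SEQ_KV)
    s = np.exp(s - s.max(axis=-1, keepdims=True))
    return (s / s.sum(axis=-1, keepdims=True)) @ v      # (..., SEQ_Q, HEAD_DIM)

rng = np.random.default_rng(0)
B, H, SEQ_Q, SEQ_KV, D = 2, 4, 7, 11, 16
q = rng.standard_normal((B, H, SEQ_Q, D))
k = rng.standard_normal((B, H, SEQ_KV, D))
v = rng.standard_normal((B, H, SEQ_KV, D))

o = np.empty_like(q)          # pre-allocate the output like the query, not like k or v
o[...] = attention(q, k, v, 1 / np.sqrt(D))
print(o.shape)  # (2, 4, 7, 16)
```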

Then we have the stage, which we will look at later; basically it's just a number that tells whether the operation we are going to do is for causal or non-causal attention. Then we need to define our launch grid. The launch grid tells how many parallel programs need to be launched by Triton (actually they are launched by CUDA, but we always use Triton as an interface to CUDA). As I said before, we want to parallelize along the batch dimension, so each sequence in the batch works independently of the others, and inside each sequence, each head works independently of the others. So we have at least batch size multiplied by number of heads programs. And for each of these programs we have another dimension, because we divide the query into blocks of queries. As you remember from block matrix multiplication, we don't work with the query as the original matrix, where each query is one vector, one token; we work with groups of queries, so each block of queries is a group of tokens in the query sequence.

So we are saying that we want to launch a number of programs (kernel instances, or blocks of threads) along two dimensions, just like a CUDA kernel can be launched along two dimensions, X and Y. One dimension tells us which head of which batch element we are going to work with, and the other tells us which group of queries of that sequence we are going to work with. The number of query groups is the sequence length divided by the number of queries we group together: BLOCK_SIZE_Q tells us how many queries there are in each block. The cdiv here is just ceiling division, so this is equal to ceil(SEQ_LEN / BLOCK_SIZE_Q), which tells us how many blocks of Q we have.

Let's rehearse. We have a tensor Q that is batch size by number of heads by sequence length by head dimension, and each Flash Attention program works with one (sequence length, head dimension) slice of it. Moreover, we have seen that Flash Attention has two loops: the outer loop over all the query blocks and the inner loop over all the key blocks. We have seen that the query blocks can work independently of each other, so we can spawn as many programs in parallel as there are blocks of Q. So this grid tells how many programs can work in parallel; then it's the GPU that, based on its resources, decides how many of them actually run in parallel. If it has enough resources to run them all in parallel, wonderful; if not, it will launch them sequentially, one batch after another. The last dimension is like the Z dimension in the CUDA launch grid, and we don't use it because we don't want an additional level of parallelism. All right, this is our launch grid: we will launch a number of parallel programs, where each program in Triton is a group of threads, equal to batch size multiplied by number of heads multiplied by the number of blocks we divided the Q sequence into. Okay, let's continue.
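
The launch grid just described can be sketched in plain Python. In Triton code a grid is often written as a function of the kernel's meta-parameters and passed as `kernel[grid](...)`; here `cdiv` reimplements the arithmetic of `triton.cdiv` so the snippet runs without a GPU, and the sizes are made up.

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, the same arithmetic as triton.cdiv.
    return (a + b - 1) // b

BATCH_SIZE, NUM_HEADS, SEQ_LEN = 2, 8, 1000

# The grid depends on BLOCK_SIZE_Q, which is only known once the kernel's
# meta-parameters are fixed (for example by autotuning).
grid = lambda meta: (
    cdiv(SEQ_LEN, meta["BLOCK_SIZE_Q"]),  # axis 0: which block of queries
    BATCH_SIZE * NUM_HEADS,               # axis 1: which (batch, head) pair
    1,                                    # axis 2: unused, like an unused Z dimension
)

print(grid({"BLOCK_SIZE_Q": 64}))  # (16, 16, 1), i.e. 256 programs in total
```

Ceiling division matters when SEQ_LEN is not a multiple of BLOCK_SIZE_Q: 1000 queries need 16 blocks of 64, the last one only partially filled.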

Let's see what this one is. This M is another tensor that we will need: it holds the logsumexp for the backward pass. We will see what it is needed for at the end of the forward pass, but basically you can think of it as the maximum of each row. To recompute the query multiplied by the keys in the backward pass without recomputing the maximum of each row and the normalization factor of the softmax, we would need to save two things: the maximum of each row and the normalization factor. However, by using the logsumexp trick we can save only one value, which, as you can see in the Flash Attention algorithm, is this L_i here: the maximum of each row plus the logarithm of the normalization factor. When computing the backward pass we need to recompute this block on the fly, the query multiplied by the transpose of the keys, but to apply the softmax we need the maximum of each row and the normalization factor. We don't recompute them during the backward pass, because we have already computed them during the forward pass, so we save this information; and we don't need to save the two pieces separately, because we can aggregate them into one single value called L_i. Later we will see how we can use it. All right, so we have also defined this one, and we can proceed further.
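
Here is a short numerical check of the logsumexp trick just described, in plain Python. For one row of scores, saving L_i = m_i + log(l_i) (row max plus the log of the normalization factor computed after subtracting the max) is enough to recover every softmax probability later as exp(s - L_i), because exp(s - m_i - log l_i) = exp(s - m_i) / l_i. The row values are made up.

```python
import math

def softmax(row):
    m = max(row)
    e = [math.exp(x - m) for x in row]
    z = sum(e)
    return [x / z for x in e]

row = [2.0, -1.0, 0.5, 3.0]                  # one row of S = Q K^T * scale

m_i = max(row)                               # row maximum
l_i = sum(math.exp(x - m_i) for x in row)    # normalization factor (after subtracting the max)
L_i = m_i + math.log(l_i)                    # the single value saved for the backward pass

# Backward pass: recompute the row of S, then recover P with one subtraction and one exp.
p_from_L = [math.exp(x - L_i) for x in row]
print(all(math.isclose(a, b) for a, b in zip(p_from_L, softmax(row))))  # True
```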

So now we launch our grid, our kernel. Don't be scared, it's going to be a little long. We launch the kernel for the forward pass by giving it the launch grid, that is, how many of these programs should run in parallel at most. We pass the query, the key, the values, the softmax scale, and M, the information we need to save for the backward pass. M is actually the L in the pseudocode of the Flash Attention algorithm; I call it M, I think, because it was also called M in the original code. Then we pass O, where our kernel should save its output. As you remember, inside the kernel we don't get the nice indexing we are used to in torch: we only get a pointer to the starting element of Q, a pointer to the starting element of K and one to the starting element of V, and then we have to figure out the memory index of all the other elements. To calculate the index we need the strides, because the stride of a dimension tells us how many elements to skip to move by one position along that dimension. That's why we pass the stride for each dimension of each tensor. In our case we are only working with Q, K and V of the same dtype and the same shape, so we should not actually need to pass the strides for each of these tensors, because they should have the same strides; however, I believe the original code was passing them, so I kept it. With the pointer to the starting element and the strides, we will be able to index any element we want in the tensor. Then we pass the shape information: the batch size, the number of heads, the sequence length and the head dimension, which is the same for all of them. And then the stage, which indicates whether we are computing causal or non-causal attention. Let's not implement the kernel yet; let's continue writing this method.
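
To see how a pointer to the first element plus the strides is enough to reach any element, here is a plain-Python sketch for a contiguous (row-major) 4D tensor. The helper names are hypothetical; for a contiguous PyTorch tensor, `tensor.stride()` returns the same tuple as `contiguous_strides(tensor.shape)`.

```python
def contiguous_strides(shape):
    # Row-major (C-contiguous) strides, in elements: the last dimension has stride 1.
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))

def flat_index(index, strides):
    # Offset of element `index` relative to the pointer to element 0.
    return sum(i * s for i, s in zip(index, strides))

shape = (2, 3, 4, 5)                         # (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
strides = contiguous_strides(shape)
print(strides)                               # (60, 20, 5, 1)
print(flat_index((1, 2, 3, 4), strides))     # 119, the very last of the 120 elements
```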

Then we need to save some information that will be needed for the backward pass, in the context variable I told you about before. We save the query, key and value, which are the tensors for which we want to compute the gradient during the backward pass, and we also need to store the M tensor and the O tensor. We also need to store the causal variable, because if we computed causal attention during the forward pass, then during the backward pass we need to know it, so that we mask out the things we don't want to contribute to the gradient. We will see that later when computing the backward pass; for now let's concentrate on the attention forward.

So we need to implement this forward kernel, the attention forward method you can see here. A Triton kernel is just a Python function with a particular decorator called triton.jit; this is what makes a function become a Triton kernel. So we copy and paste the signature. As you can see, we pass the query, key and value matrices along with other information, such as the M matrix. Please don't confuse the M matrix with the mask: we will generate the mask on the fly, because in this case we are only concerned with causal or non-causal attention; we do not accept custom masks here. Then we pass the strides of all these tensors; the batch size, number of heads, sequence length and head dimension, which are the shape of each of these tensors; and BLOCK_SIZE_Q and BLOCK_SIZE_KV. BLOCK_SIZE_Q indicates how many queries we want to group together to make one block of the Q matrix, and BLOCK_SIZE_KV indicates how many keys and values we want to put together to make one block of the K and V matrices, which is what we do in block matrix multiplication. The stage is a number that indicates whether we are doing causal or non-causal attention: it will be 3 if it's causal and 1 if it's not causal.
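
Generating the causal mask on the fly, rather than passing a mask tensor, only needs the absolute positions covered by the current tile. The plain-Python sketch below (the helper `causal_tile_mask` is hypothetical) builds the mask for one (query block, key block) tile; in a kernel the same comparison can be expressed on vectors of offsets with broadcasting.

```python
def causal_tile_mask(q_block, kv_block, block_q, block_kv):
    # Absolute positions of the queries and keys covered by this tile.
    offs_q = range(q_block * block_q, (q_block + 1) * block_q)
    offs_kv = range(kv_block * block_kv, (kv_block + 1) * block_kv)
    # A query may attend to a key only if the key is not in its future.
    return [[qi >= kj for kj in offs_kv] for qi in offs_q]

# A tile on the diagonal is lower triangular.
for row in causal_tile_mask(q_block=1, kv_block=1, block_q=4, block_kv=4):
    print("".join("1" if keep else "0" for keep in row))
# 1000
# 1100
# 1110
# 1111
```

Tiles entirely below the diagonal need no mask at all, and tiles entirely above it can be skipped, which is one reason causal and non-causal attention are handled as different cases.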

Okay, the first thing we do is verify some information: we check that BLOCK_SIZE_KV is less than or equal to the head dimension. To be honest, I don't think we need it with my code, because I removed most of the constraints; this check was present in the original code, so I kept it. It all depends on the autotuning process, which we will see later, along with which variables we autotune over, how many stages and how many warps we choose, and so on. So let's leave it for later: you can comment it out or keep it, it shouldn't matter.

As I said before, we launch a grid, which is a series of programs, each with some identifiers. In CUDA we had an identifier for the blocks on the X axis and on the Y axis; in Triton we get an identifier for each program we launched. We launched ceil(SEQ_LEN / BLOCK_SIZE_Q) programs along axis 0 and batch size multiplied by number of heads programs along axis 1 of the launch grid. These identifiers tell us which part of the query this program should work with, and also which batch element and which head it should work with. That's what we are going to do now: figure out which part of the input we should work with based on the IDs of the program, which correspond to the block IDs in CUDA.

Program ID 0 tells us which block of queries we are going to work with. Why do we have blocks of queries? Because, as we saw before, the output can be computed independently for each block of queries, while each block of queries has to iterate through all the keys and values. So this gives the index of the query block this particular program works with. Then we can also work out which batch element and which head this program is associated with. Axis 1 has as many programs as the product of the batch size and the number of heads, so the program ID on axis 1 tells us which batch element and which head this particular program is associated with. To get the batch index, we divide this number by the number of heads (integer division); to get the head index inside that batch element, we take this number modulo the number of heads.
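
The decoding of program IDs just described is plain integer arithmetic. In this sketch, `pid0` and `pid1` stand for the values a kernel would read with `tl.program_id(0)` and `tl.program_id(1)`; the helper and `NUM_HEADS = 8` are assumptions for illustration. Axis 1 enumerates (batch, head) pairs as batch * NUM_HEADS + head.

```python
NUM_HEADS = 8

def decode_program_ids(pid0, pid1, num_heads=NUM_HEADS):
    block_index_q = pid0                 # which block of queries
    index_batch = pid1 // num_heads      # which sequence in the batch
    index_head = pid1 % num_heads        # which head inside that sequence
    return block_index_q, index_batch, index_head

print(decode_program_ids(pid0=3, pid1=19))  # (3, 2, 3): query block 3, batch 2, head 3
```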

Okay, the next thing we need to do. First of all, when we pass a tensor to this kernel (as you can see, the Q parameter of this attention forward method is a tensor, because it's the input of the forward function, which is called when we do TritonAttention.apply, and this Q was created as a tensor), what the Triton kernel receives is not really a tensor: it is a pointer to the first element of that tensor in memory. Now that we know which batch element and which head we are going to work with, we need to index this tensor to select the right batch element and the right head inside it. Basically, with this Q tensor we want to do something like Q[index_batch, index_head, :, :], selecting everything inside those indices. So we need to enter the tensor at the location where the (sequence length, head dimension) data for this batch element and this head starts. For that we generate an offset by which we move this pointer, because the pointer points at the beginning of the entire tensor, so we need to move along the batch size dimension and the number of heads dimension. This offset tells us where this particular batch element and this particular head start in the tensor, and to compute it we use the strides. We create the qkv_offset, which is the batch index multiplied by the stride of the batch dimension, which tells us how many elements we need to skip to get to the next batch element. For batch element zero we don't skip anything, because we are already pointing at its first element, but for batch element one we skip that many elements. Plus, we also need to skip some heads: how many depends on which head we are going to work with, and what tells us how to go from one head to the next is the stride of the head dimension. So we multiply index_head, the head we should be working with, by stride_Q_head. All right.
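
Here is a plain-Python sketch of the qkv_offset computation, using a flat list as a stand-in for global memory and made-up sizes. Starting from the pointer to element 0, we skip `index_batch` whole batch elements and `index_head` whole heads; from that base, the remaining two strides address the (SEQ_LEN, HEAD_DIM) slice Q[index_batch, index_head]. The function and variable names are illustrative, not taken from the video's code.

```python
BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = 2, 4, 6, 8
stride_batch = NUM_HEADS * SEQ_LEN * HEAD_DIM   # 192
stride_head = SEQ_LEN * HEAD_DIM                # 48
stride_seq, stride_dim = HEAD_DIM, 1            # 8, 1

# A flat buffer standing in for global memory, holding Q in row-major order.
flat_q = list(range(BATCH_SIZE * NUM_HEADS * SEQ_LEN * HEAD_DIM))

def qkv_offset(index_batch, index_head):
    # Skip whole earlier batch elements, then whole earlier heads.
    return index_batch * stride_batch + index_head * stride_head

base = qkv_offset(index_batch=1, index_head=2)
q_1_2 = [[flat_q[base + i * stride_seq + j * stride_dim] for j in range(HEAD_DIM)]
         for i in range(SEQ_LEN)]
print(base)      # 288
print(q_1_2[0])  # [288, 289, 290, 291, 292, 293, 294, 295]
```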

Tron helps us with a new function that I

think it was quite recent that helps us

index element inside of a tensor without

having to deal with all the complex um

indexing maths that can be confusing for

beginners so I will be using few methods

to help us with this um uh with this

indexing and this function is called

make block pointer and it's this

following so basically this make block

pointer takes as input a

vector and sorry a pointer not a vector

takes as input a pointer in this case we

are

saying create a

block that has the following shape that

is sequence length by head dimmension so

let me do it one by one actually I don't

want to confuse you guys with all this

stuff Al together okay so take start um

There is a pointer that right now points at Q + qkv_offset. So it is not pointing at the first batch. It points exactly to our batch, the batch that this particular program should be working with, and inside this batch to the particular head that this program should be working with. In other words, we are pointing to Q[index_batch, index_head, :, :]. That is the right batch and the right head, with everything inside selected, so the pointer points to the first element of this particular tensor. Because we have already selected the batch and the head, this is a two-dimensional tensor whose remaining dimensions are sequence length and head dim. So we are saying: take this pointer, which contains a tensor of shape (SEQ_LEN, HEAD_DIM). I'm also giving you the strides of these two dimensions, which for the Q tensor are the stride of the sequence dimension and the stride of the head-dim dimension.

Inside this query tensor, we want to select the block of queries that this program should be working with. I think I need to use the iPad here, otherwise it can be very confusing to visualize. So let me create another page and use the iPad. All right.

Okay, so we have a Q tensor. We will use this construct for all the other tensors too, so if you understand it for one tensor, you understand it for all of them. The Q tensor has shape batch × number of heads × sequence length × head dimension. With this line here, when we compute Q + qkv_offset, we are already selecting the right batch and the right head. This means we have already moved our Q pointer forward. It no longer points to the first batch and the first head, but to the exact batch and the exact head that this program is working with. So right now it points at a tensor made up of just these last two dimensions.

Now, inside this tensor we also need to select the right block of queries that this program should work with. This dimension here, the sequence dimension, contains all the queries. We need to select the right ones, which means skipping some queries. How do we skip them? We skip block_index_q × BLOCK_SIZE_Q queries, because those will be processed by other programs, which have a different program ID. So with this line we select more than the right batch index and the right head inside Q. We also select the right position in the sequence-length dimension: the starting point of the exact query block that this particular program should be working with. That is what is happening here.

We are also describing this block; later we will see how it is used. We tell it the size of this tensor, which has only two dimensions because we are pointing to the beginning of the right query sequence: the sequence dimension and the head-dim dimension. We are also already pointing to the right starting position in the sequence dimension, because we have already skipped some queries. Why skip them? Because those queries will be handled by other programs with different values of block_index_q.

As for this order argument, I actually don't know exactly what it does. You can try (0, 1) or (1, 0). I think it's some optimization that Triton does. I have read the online documentation and couldn't find much about it, so this is something I will investigate, but even if you put (0, 1) it doesn't seem to matter. I think it's a way to tell Triton whether you want the transposed or the non-transposed version of this block. Later we will see how we can transpose the key block without any transpose operation: we will just change the strides, as we have seen before.

Now, make_block_ptr is not strictly necessary, but it makes our life easier when we index this particular pointer. We can treat it nearly the same way we work with a tensor in PyTorch, and we can increase the index in one dimension without computing the strides ourselves. Later, in the backward pass, I will not use it and will do all the pointer indexing by hand, so you can compare indexing a tensor with and without make_block_ptr. Anyway, to recap what we are creating: a pointer to the right index in the batch and the right index in the head dimension, already skipping some queries based on block_index_q. So this pointer already points to the right block of queries that this particular program should be working with.
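This pure-Python sketch imitates what the Q block pointer describes. It holds a base offset into a flat buffer (`qkv_offset`), the 2-D shape (SEQ_LEN, HEAD_DIM) and its strides, the starting offsets (block_index_q * BLOCK_SIZE_Q, 0), and the block shape (BLOCK_SIZE_Q, HEAD_DIM). `BlockPtr` and its `load` method are hypothetical stand-ins. In Triton, `tl.make_block_ptr` receives the same kind of information (plus `order`), and the address arithmetic happens on the GPU.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockPtr:
    """Hypothetical pure-Python stand-in for a Triton block pointer."""
    base: int           # flat offset of element (0, 0) of the 2-D view
    shape: tuple        # shape of the 2-D tensor being viewed
    strides: tuple      # stride of each dimension, in elements
    offsets: tuple      # where the block starts inside the 2-D tensor
    block_shape: tuple  # size of the block that load() returns

    def load(self, buffer):
        rows, cols = self.block_shape
        r0, c0 = self.offsets
        assert r0 + rows <= self.shape[0] and c0 + cols <= self.shape[1]
        return [
            [buffer[self.base + (r0 + r) * self.strides[0] + (c0 + c) * self.strides[1]]
             for c in range(cols)]
            for r in range(rows)
        ]


BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, BLOCK_SIZE_Q = 2, 2, 8, 4, 4
stride_batch, stride_head = NUM_HEADS * SEQ_LEN * HEAD_DIM, SEQ_LEN * HEAD_DIM
stride_seq, stride_dim = HEAD_DIM, 1
Q = list(range(BATCH_SIZE * NUM_HEADS * SEQ_LEN * HEAD_DIM))  # flat buffer standing in for HBM

index_batch, index_head, block_index_q = 1, 0, 1
Q_block_ptr = BlockPtr(
    base=index_batch * stride_batch + index_head * stride_head,  # qkv_offset
    shape=(SEQ_LEN, HEAD_DIM),
    strides=(stride_seq, stride_dim),
    offsets=(block_index_q * BLOCK_SIZE_Q, 0),  # skip the queries of other programs
    block_shape=(BLOCK_SIZE_Q, HEAD_DIM),
)
Q_block = Q_block_ptr.load(Q)
print(Q_block[0], Q_block[-1])  # [80, 81, 82, 83] [92, 93, 94, 95]
```

Because the buffer holds each element's own flat index, the loaded values are exactly the addresses the block pointer resolves to.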

Let's look instead at the V and K blocks now. Let's copy the V block, which is similar to the query, but we are not going inside: we only index by index_batch and index_head. Let me write it here. The Q block pointer is equivalent to Q[index_batch, index_head, block_index_q × BLOCK_SIZE_Q:, :]. We are in the right batch and the right head, and we are skipping some queries. Here, instead, we only index by batch and head, so we are doing V[index_batch, index_head, :, :] and not skipping anything. As you can see, the offsets are zero in both the first and the second dimension, so we skip nothing along the sequence length and nothing along the head dimension.

All right, so let's look at the K block pointer, which is different. As you know, when computing the flash attention algorithm, we need access to the block of queries and to all the blocks of the keys, transposed. So we shouldn't access K the way we access Q. We should invert the two dimensions that we want to transpose, and that's very simple with make_block_ptr. You can see it here: we point to the right batch index and the right head, and to the tensor inside it. Let me write it here so I can explain it line by line.

What we are doing is: go to the K tensor, select the right batch, select the right head, and select everything inside. So it's a two-dimensional tensor with dimensions sequence length and head dim, as you can see here. But we don't want sequence length first and then head dim. We want head dim first and then sequence length, which means we want to transpose it. How do we transpose it? We just say that this tensor must be read with the two strides swapped: first use the stride of the head dimension, then the stride of the sequence dimension. The shape of this tensor is not (sequence, head dim) but (head dim, sequence), and the block is head dim by a block of KVs.

Why don't we put the full sequence dimension in the block shape? Because we want to move block by block later. So we are not selecting the whole sequence length, just one block of KVs, and later we will use another method to go to the next block.
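This sketch checks that swapping the strides really yields the transpose without moving any data. It reads the same flat, row-major K buffer twice: once with shape (SEQ_LEN, HEAD_DIM) and strides (HEAD_DIM, 1), and once with shape (HEAD_DIM, SEQ_LEN) and strides (1, HEAD_DIM). `view_2d` is a hypothetical helper written only for this illustration.

```python
def view_2d(buffer, base, shape, strides):
    """Read a 2-D view of a flat buffer using explicit strides."""
    rows, cols = shape
    return [[buffer[base + r * strides[0] + c * strides[1]] for c in range(cols)]
            for r in range(rows)]


SEQ_LEN, HEAD_DIM = 3, 2
stride_seq, stride_dim = HEAD_DIM, 1   # row-major K[seq, dim]
K_flat = [10, 11, 20, 21, 30, 31]      # K = [[10, 11], [20, 21], [30, 31]]

K = view_2d(K_flat, 0, (SEQ_LEN, HEAD_DIM), (stride_seq, stride_dim))
K_T = view_2d(K_flat, 0, (HEAD_DIM, SEQ_LEN), (stride_dim, stride_seq))  # strides swapped

print(K_T)  # [[10, 20, 30], [11, 21, 31]]
```

Only the shape and the order of the strides change. The buffer is untouched, which is why the kernel can read K already transposed for free.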

I hope that showing the indexing like this makes it a little easier to follow. For each tensor we go to the right batch and the right head. For the query we also skip some query blocks, because each program works with a different query block. For the keys and values, however, each program needs to iterate through all of them. So we just point to the first key and value block, and then we advance one block at a time during the for loop that we will write later.
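This sketch shows how the K and V block offsets walk along the sequence during the loop, assuming BLOCK_SIZE_KV divides SEQ_LEN. Both start at 0, and each iteration moves them forward by BLOCK_SIZE_KV. K is viewed as HEAD_DIM × SEQ_LEN, so for K the sequence is the second dimension; V is viewed as SEQ_LEN × HEAD_DIM. `advance` is a hypothetical helper that mirrors the idea of Triton's `tl.advance`, which returns a block pointer with shifted offsets.

```python
def advance(offsets, delta):
    """Return block offsets shifted by delta, like advancing a block pointer."""
    return tuple(o + d for o, d in zip(offsets, delta))


SEQ_LEN, BLOCK_SIZE_KV = 12, 4
K_offsets = (0, 0)  # K viewed as (HEAD_DIM, SEQ_LEN): the sequence is dimension 1
V_offsets = (0, 0)  # V viewed as (SEQ_LEN, HEAD_DIM): the sequence is dimension 0

visited = []
for start_kv in range(0, SEQ_LEN, BLOCK_SIZE_KV):
    assert K_offsets[1] == start_kv and V_offsets[0] == start_kv
    visited.append(start_kv)
    K_offsets = advance(K_offsets, (0, BLOCK_SIZE_KV))
    V_offsets = advance(V_offsets, (BLOCK_SIZE_KV, 0))

print(visited)  # [0, 4, 8]
```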

Then for the output we can also make a block pointer. This creates a pointer just like in the query, key, and value cases: we select the right batch index and the right head. But we are not selecting everything inside. In this case, too, we skip some blocks of queries, because, as I said before, the output has the same shape as the query. This particular program, with this particular block_index_q, will only work with one block of queries. That produces only one block of the output matrix, and we need to select exactly that one so that this pointer points exactly where we should start writing. So here too we skip block_index_q × BLOCK_SIZE_Q rows, selecting exactly the block that this particular program will produce.

When I say "this particular program", I mean the program identified by this program ID on axis 0 and this program ID on axis 1. These programs will run in parallel, hopefully, and each of them has a different combination of block_index_q and index_batch_head.
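One way to picture the launch is as a 2-D grid. Axis 0 enumerates query blocks and axis 1 enumerates (batch, head) pairs. `index_batch_head` is split back into `index_batch` and `index_head` with integer division and modulo. This sketch assumes a grid of (ceil(SEQ_LEN / BLOCK_SIZE_Q), BATCH_SIZE * NUM_HEADS), as the Triton fused-attention tutorial uses. The nested loops only simulate sequentially what the GPU runs in parallel, and they check that every (batch, head, query block) is handled by exactly one program.

```python
import math

BATCH_SIZE, NUM_HEADS, SEQ_LEN, BLOCK_SIZE_Q = 2, 3, 12, 4
grid = (math.ceil(SEQ_LEN / BLOCK_SIZE_Q), BATCH_SIZE * NUM_HEADS)  # (3, 6)

work = []
for pid0 in range(grid[0]):        # plays the role of tl.program_id(0)
    for pid1 in range(grid[1]):    # plays the role of tl.program_id(1)
        block_index_q = pid0
        index_batch_head = pid1
        index_batch = index_batch_head // NUM_HEADS
        index_head = index_batch_head % NUM_HEADS
        work.append((index_batch, index_head, block_index_q * BLOCK_SIZE_Q))

print(grid, len(work))  # (3, 6) 18
```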

Okay, now our pointers point to the right positions, where they will either read or write some information. Because we used make_block_ptr, these pointers can also be treated directly as tensors; that's why we specify their shapes. Triton right now provides some methods to work with pointers as if we were accessing tensors, so we can index them like tensors. Basically, Triton is just doing some calculations for you based on the strides, so you don't have to do them by hand. Later, in the backward pass, we will avoid make_block_ptr and see the indexing done by hand.
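For comparison, this is roughly what indexing by hand looks like. We build a 2-D grid of element offsets from a vector of sequence positions and a vector of head-dim positions, using the strides. In Triton this broadcast pattern is usually written as `offs_seq[:, None] * stride_seq + offs_dim[None, :] * stride_dim`. The pure-Python helper below is hypothetical. The example values correspond to a (2, 2, 8, 4) tensor at batch 1, head 0, and query block 1, and the helper produces the same offsets a block pointer would compute for that Q block.

```python
def manual_offsets(base, offs_seq, offs_dim, stride_seq, stride_dim):
    """2-D grid of flat offsets: base + seq * stride_seq + dim * stride_dim."""
    return [[base + s * stride_seq + d * stride_dim for d in offs_dim] for s in offs_seq]


SEQ_LEN, HEAD_DIM, BLOCK_SIZE_Q = 8, 4, 4
stride_seq, stride_dim = HEAD_DIM, 1
qkv_offset = 64          # batch 1, head 0 of a contiguous (2, 2, 8, 4) tensor
block_index_q = 1

offs_q = [block_index_q * BLOCK_SIZE_Q + i for i in range(BLOCK_SIZE_Q)]  # [4, 5, 6, 7]
offs_dim = list(range(HEAD_DIM))                                           # [0, 1, 2, 3]
ptrs = manual_offsets(qkv_offset, offs_q, offs_dim, stride_seq, stride_dim)
print(ptrs[0], ptrs[-1])  # [80, 81, 82, 83] [92, 93, 94, 95]
```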

All right. As you know, we are processing a single block of queries, so let's go back to the algorithm, otherwise we lose sight of what we are doing. Let me show my iPad. As you know, we parallelize along the query-block dimension, so each program works with a different query block, and then we need a for loop over all the key and value blocks. Right now we have just moved our pointers to the right positions. They point to the query block that this program should work with, and to the beginning of the key and value blocks, based on which batch index and which head this particular program should be working with.

All right, now we have pointed our pointers to the right positions where our program should be working. These positions are inside the big tensors that have the batch dimension, the number-of-heads dimension, the sequence-length dimension, and the head dimension. Because we are pointing to the right batch and the right head, these tensors have become two-dimensional: they only span the sequence length and the head dimension. Now we need some more information that we will use later.

The first piece of information we need is the offset of each query inside the current block of queries that this particular program should be working with. That is given by the following line; let me copy and paste it. How many query offsets are there? BLOCK_SIZE_Q, because each block of queries is made up of BLOCK_SIZE_Q queries. What is each query? It is a token. Along the head dimension it is not the whole embedding of the token but only part of it: the part corresponding to the head that this particular program is going to work with. So we are generating the offsets that will load these particular queries from the big tensor that contains all the queries, and we know that our queries start at position block_index_q × BLOCK_SIZE_Q. Imagine the block size is four. If this is program number zero, the queries will be those with indices 0, 1, 2, and 3. But imagine we are program number three: then we need to skip 3 × 4 = 12 queries, so the offsets will point to queries 12, 13, 14, and 15, and so on.

All right, and we do the same for the keys and values. offs_kv is the range of key and value positions that we need at each iteration. At the beginning, our K and V pointers point to the beginning of the key and value sequence for this particular batch and this particular head, which is the first block of keys and values. So we are not skipping anything. In the query case we skip, because our program will only work with one single block of queries. Here we don't skip anything, because we need to iterate through all the keys and values. So imagine BLOCK_SIZE_KV is equal to four: then offs_kv will be 0, 1, 2, and 3.
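These are the offsets from the last two paragraphs, written out in plain Python (the list comprehensions stand in for `tl.arange`). The sketch includes the program-three example, which starts at query 12 with zero-based indexing. `query_offsets` is a hypothetical helper name.

```python
BLOCK_SIZE_Q, BLOCK_SIZE_KV = 4, 4


def query_offsets(block_index_q, block_size_q):
    """Query indices handled by one program: block_index_q * block_size_q + arange(block_size_q)."""
    return [block_index_q * block_size_q + i for i in range(block_size_q)]


offs_kv = list(range(BLOCK_SIZE_KV))  # K and V start at the first block, nothing skipped

print(query_offsets(0, BLOCK_SIZE_Q))  # [0, 1, 2, 3]
print(query_offsets(3, BLOCK_SIZE_Q))  # [12, 13, 14, 15]
print(offs_kv)                          # [0, 1, 2, 3]
```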

All right. As you remember, inside the flash attention algorithm we need to compute a block of queries multiplied by the transpose of the keys, and to each of these blocks we need to apply softmax*. If you remember, softmax* is the softmax without the normalization. While computing softmax*, we also compute the normalization factor without applying it, and we apply the normalization factor at the end. So for each block of queries multiplied by the transpose of the keys, we need the maximum of each row in this particular block and the normalization factor of each row. That's why we need the following two statistics.
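This minimal numerical sketch follows softmax* as described here. It subtracts the row maximum, exponentiates, keeps the row sum as the normalization factor, and divides only at the end. The function names are illustrative. Dividing the unnormalized values by l reproduces the ordinary softmax, and subtracting the maximum keeps large scores from overflowing.

```python
import math


def softmax_star(row):
    """exp(x - max(row)) for each x, plus the row max and the normalization factor."""
    m = max(row)
    p = [math.exp(x - m) for x in row]
    l = sum(p)
    return p, m, l


def naive_softmax(row):
    total = sum(math.exp(x) for x in row)
    return [math.exp(x) / total for x in row]


row = [1.0, 2.0, 3.0]
p, m, l = softmax_star(row)
normalized = [x / l for x in p]  # the normalization factor is applied only at the end
print(m, normalized)

# Subtracting the max keeps large scores finite; naive_softmax([1000.0, 1001.0])
# would raise OverflowError because math.exp(1000.0) overflows a float.
p_big, m_big, l_big = softmax_star([1000.0, 1001.0])
```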

The first one, m_i, is basically a block of numbers, one for each query in our block of queries, each initialized with minus infinity, just like in the algorithm I showed before. Let me go back to the slides in case we forgot, or you can also check the flash attention algorithm: we initialize it with minus infinity. So far we are creating this part here. We are initializing m_i, we will be initializing l_i and O, and then we will write the inner loop here. This is exactly the algorithm that we have seen before. So we initialize m with minus infinity, and now we also initialize the l's. Let me go back to the code. All right, so the l's are initialized with this value here.

As we can see from the flash attention algorithm, the O block is initialized with zeros. That's why we initialize this block here. It is the output block that this particular program will compute, which depends on the batch, the head, and the query block this program is working with. So it is one block of size BLOCK_SIZE_Q (how many queries there are in this block) by head dimension. If you want to visualize it, let's go back to the slides: it is one block of this matrix here, one block of rows of the output matrix. Okay, so let's go back to the code.

All right, so now we have initialized a few things here: the output, m_i, and l_i. m_i is the maximum of each row in this particular query block, and l_i is the normalization factor of each row in our query block.

block now we need to do the for Loop the

inner loop in the flashh attention

algorithm uh we will create a separate

method that will run the inner loop so

let's let me copy

it here and I am following the same

structure of the code that you see in

the tutorial of the Tron

website so um basically if we are

running the Cal attention or even if we

are not running the Cal attention we

make this for Loop and then we will make

another for Loop and I will show you why

so let me first write it and then we

This function here will be the inner loop. The inner loop needs to go through all the key and value blocks one by one, and for each key and value block it needs to fix the previously computed softmax* block. So we need to create a function like the following, where we iterate over all the key and value blocks. We compute the query multiplied by the transpose of the keys, using the query block that is fixed for this program and the key block that we are iterating through. For each row we calculate the maximum and compute softmax*, the softmax without the normalization factor. We also keep the statistic l, the normalization factor that we will apply at the end of the for loop. At the same time we need to update the output. As you remember, the output is P11 × V1 + P12 × V2, but we need to fix the previous P11. So every time we add to the output O, we first fix the output of the previous iteration and then add the P and V blocks of the current iteration.
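Below is a pure-Python sketch of this inner loop for one query block, without any causal mask. It assumes the standard online-softmax update. We keep a running row maximum `m_i`, a running normalization factor `l_i`, and an unnormalized output `O`. When a new KV block raises the maximum, the old `l_i` and `O` are multiplied by `alpha = exp(m_old - m_new)` before the new block's contribution is added. We divide by `l_i` only after the last block. The function names are hypothetical, and the result is compared against ordinary softmax attention. In the real kernel the same steps run on whole tiles (with `tl.dot`), not one row at a time.

```python
import math


def naive_attention(Q, K, V, scale):
    """Reference: softmax(scale * Q K^T) V, one query row at a time."""
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        l = sum(p)
        out.append([sum(pj * v[d] for pj, v in zip(p, V)) / l for d in range(len(V[0]))])
    return out


def flash_inner_loop(Q_block, K, V, scale, block_size_kv):
    """Online-softmax attention for one query block, iterating over KV blocks."""
    n_q, head_dim = len(Q_block), len(V[0])
    m_i = [-math.inf] * n_q                     # running row maxima
    l_i = [0.0] * n_q                           # running normalization factors
    O = [[0.0] * head_dim for _ in range(n_q)]  # unnormalized output block
    for start_kv in range(0, len(K), block_size_kv):
        K_blk = K[start_kv:start_kv + block_size_kv]
        V_blk = V[start_kv:start_kv + block_size_kv]
        for i, q in enumerate(Q_block):
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in K_blk]
            m_new = max(m_i[i], max(s))
            p = [math.exp(x - m_new) for x in s]  # softmax* of this block
            alpha = math.exp(m_i[i] - m_new)      # fixes l and O of earlier blocks (0 at first)
            l_i[i] = l_i[i] * alpha + sum(p)
            pv = [sum(pj * v[d] for pj, v in zip(p, V_blk)) for d in range(head_dim)]
            O[i] = [o * alpha + x for o, x in zip(O[i], pv)]
            m_i[i] = m_new
    return [[x / l for x in row] for row, l in zip(O, l_i)]  # normalize once, at the end


Q = [[1.0, 0.0], [0.5, -1.0]]
K = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [2.0, -0.5]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
scale = 1 / math.sqrt(len(Q[0]))
flash = flash_inner_loop(Q, K, V, scale, block_size_kv=2)
reference = naive_attention(Q, K, V, scale)
```

The result does not depend on the KV block size. With one block (block_size_kv=4) it reduces to the plain softmax, and smaller blocks only change how often the correction factor is applied.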

Here, the author of the code on the Triton website decided to split this for loop into two steps. Why? In causal attention we don't want a query to attend to keys that come after it, while in non-causal attention we let all the queries attend to all the keys. This means we would need some kind of if statement inside the for loop over all the keys and values. In the causal case, it would check whether the particular query we are working with comes before or after each key. So instead of iterating through all the keys and values in the causal case too, we split the loop. First, we iterate through all the keys and values whose index is smaller than the current query block; for these we need to compute the attention in both the causal and the non-causal case. Then come all the elements to the right of this block, where the key index is greater than the query index. In causal attention we don't need to compute anything there: those entries are masked out, so after the softmax they become zeros and do not contribute to the output. So we don't even have to compute them.

This is why we split the for loop into steps. First we iterate over all the parts to the left of the diagonal of the query-times-keys matrix, that is, all the values for which the key index is less than the query index. Then, if we are working with a causal mask, we skip all the parts to the right of this diagonal. With the non-causal mask we compute both the left and the right part of this diagonal.

All right, don't worry: when we code this for loop, it will be clearer; I just wanted to give a little introduction. So let's code this inner loop. What will it do? It will work with the particular query block that we have found, this Q block. Wait, why don't I see the Q block? Because I didn't load it. Yes, we actually forgot to load the query block, so let's load it. As you remember, in Triton we load data from the high-bandwidth memory to the SRAM, the shared memory, by using the load statement. Here we load the query block that we should be working with. The pointer Q_block_ptr already points to the right block, so it already skips all the blocks that other programs should be working with. It loads a tensor of size BLOCK_SIZE_Q × HEAD_DIM: the right block of queries.

We then pass it to this inner loop, together with several other arguments:

- The output, that is, where it should write.
- The per-row statistics: m_i, the maximum of each query's row, and l_i, the normalization factor of each query.
- The query block this program should be working with.
- The pointers to the beginning of the key and value blocks. We need to iterate through them, so we point to the beginning and then iterate inside the inner for loop.
- The softmax scale to use when computing the query multiplied by the transpose of the keys.
- The block sizes: how many queries are in each Q block and how many keys and values are in each KV block.
- The stage. It tells us whether we are to the left of the diagonal or exactly on it, and therefore whether we need to apply the causal mask.
- offs_q and offs_kv, the offsets of the queries and keys inside each Q and KV block, as lists of indices.
- The entire sequence length, because in the for loop we need to iterate over the whole sequence length block by block: one KV block, then the next KV block, and so on.

Let's write this method; later we will actually need to come back and continue it. Let me go here. We have already seen the signature of this method: it's just another Triton JIT function, so it can be called by the first kernel. You can do something similar in CUDA, where a kernel can call other device functions. Then, based on the stage of this inner loop, we decide what to do. With causal attention, we only want each query to attend to keys whose index is less than or equal to its own, so that a query never attends to keys and values that come after it. In that case we pass the value 3 for the STAGE parameter. The inner loop then receives 4 - 3 = 1. This means we only work with the range of keys and values from zero up to the current Q block: all the keys whose index is less than the indices of the queries we are working with, that is, the left part of the causal mask. Let me draw it, otherwise I think it's going to be very difficult to follow. Let's open a new page; we have used this one before, so let's clear it and reuse it.

I want you to think of the following matrix as a block matrix. Let's draw it in pink, since I have been drawing everything in pink. We know that the rows of this query-times-transposed-keys matrix correspond to blocks of queries. We are not looking at one single block now, but at all of them. So this is query block one, this is query block two, this is query block three, and this is query block four, and each query block is made up of multiple query tokens. Then we have the key blocks (very ugly, but okay): key block one, key block two, key block three, key block four. When you calculate causal attention, you want each query to attend only to keys that come before it. So when we apply the causal mask, this block here will be made up of zeros, and so will this one, this one, this one, this one, and this one. All the blocks above the diagonal are made up of zeros.

when we are in this case

because well when we are in this

particular scenario actually in this

particular scenario we don't need to

mask out anything for sure why because

all the key um keys in this block so in

this block of keys keys will have an

index that is smaller than the index of

the corresponding queries in case the uh

the key the block size of the query and

the key matches so imagine each query is

made up of three queries so each block

of query is made up of three queries so

this is the query number 0 1 and two

this is the query number 3 4 five really

3 4 five yeah this will be the number uh

six 7 and eight and this will be the

query number nine 10 10 and 11 in total

we have 12 queries we will have the same

indices also for the keys in case we

choose the same uh size for the blocks

so this key key block here will be the

key number 0 1 and

two this will be the key number three

four five this will be the six six 7 and

eight etc etc ET now what happens is

that in this case as you can see the key

in this in indices of the keys are

always smaller than the indices of the

queries so we don't need to mask out

anything even in the case of the Cal

mask because we are sure that in this

case all of these dot products will

never be masked out also in this case

all these dot products will never be

masked out and also in this case we'll

never be masked out we'll never be

masked out and we'll never be masked out

and in this case however along the

diagonal some of the queries will be

more have will have an index that is

bigger than than that of the keys and

some of them will not be uh will not

have an index that is bigger than that

of the keys because these are blocks of

queries and blocks of keys some of them

need to be masked out and some of them

don't need to be masked out so we are

So we are dividing our for loop into multiple steps. The first step covers everything to the left of this diagonal, where we don't need to mask out anything. The next step covers the diagonal, where we do need to mask. Everything to the right of it we will not even compute in the case of causal attention. We already know that the product of the query and the transpose of the keys will be made up of zeros there after the softmax. If you look at the flash attention algorithm, the contribution of these blocks is zero, because we are multiplying zeros by V, so they don't change the output. Why even compute this part of the matrix if we already know it's not going to contribute to the output? We just skip all those iterations, and this is why we are splitting the for loop. I hope it's much clearer now.
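This sketch reproduces the iPad drawing: 12 tokens, blocks of 3 queries and 3 keys, and the causal rule that query q may attend to key k only when k <= q. Classifying every (query block, key block) pair gives exactly the three cases just described. `classify_blocks` is a hypothetical helper, and it assumes equal query and key block sizes, as in the drawing.

```python
def classify_blocks(seq_len, block_size):
    """Label each (query block, key block) pair under the causal rule k <= q."""
    n_blocks = seq_len // block_size
    kinds = {}
    for qb in range(n_blocks):
        for kb in range(n_blocks):
            visible = [
                k <= q
                for q in range(qb * block_size, (qb + 1) * block_size)
                for k in range(kb * block_size, (kb + 1) * block_size)
            ]
            if all(visible):
                kinds[(qb, kb)] = "full"   # left of the diagonal: compute, no mask
            elif any(visible):
                kinds[(qb, kb)] = "diag"   # on the diagonal: compute, then mask
            else:
                kinds[(qb, kb)] = "skip"   # right of the diagonal: all zeros after softmax
    return kinds


kinds = classify_blocks(seq_len=12, block_size=3)
for qb in range(4):
    print([kinds[(qb, kb)] for kb in range(4)])
# ['diag', 'skip', 'skip', 'skip']
# ['full', 'diag', 'skip', 'skip']
# ['full', 'full', 'diag', 'skip']
# ['full', 'full', 'full', 'diag']
```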

All right, let's go back. Stage number one is the part to the left of the diagonal. Stage number two is the part exactly on the diagonal, where some dot products are kept and others are masked out. For non-causal attention we just go from zero to the sequence length without this multi-step approach, because we don't need to mask out anything. This is why we have this stage parameter: it tells us the lower and upper index of the key blocks that this particular stage should be working with.

Now, this function here, multiple_of, is just telling Triton that this number is a multiple of that number, so Triton can make some optimizations.

Stage one happens when we are doing causal attention: STAGE is 3 in this function, and 4 - 3 becomes 1. So in causal attention we go through the key and value blocks that are to the left of the diagonal with respect to the query block we are working with. With non-causal attention, STAGE is 1 in this first call to the inner function, so 4 - STAGE equals 3. We then execute this branch of the if statement and go through all the keys and values.

For causal attention only, as you can see here, we do another iteration that covers only the diagonal. There, some entries need to be masked out and some don't, because inside each block some keys have an index below that of the query and some have an index above it. So only in causal attention do we call this function twice: the first time with stage equal to one and the second time with stage equal to two. The second time we only iterate through the KV blocks that are exactly on the diagonal of the big matrix Q multiplied by the transpose of the keys, the matrix made up of all the blocks.
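This sketch covers the stage bookkeeping. It assumes the convention described above: the kernel's STAGE is 3 for causal and 1 for non-causal attention, and the inner loop is called with 4 - STAGE. In the causal case only, the inner loop is called again with 2. The sketch also assumes BLOCK_SIZE_Q equals BLOCK_SIZE_KV. `kv_range` and `keys_visited` are hypothetical helpers. The first returns the [lo, hi) key range of each call; the second returns the keys a query block ends up visiting.

```python
def kv_range(stage, block_index_q, block_size, seq_len):
    """Key range [lo, hi) that one call of the inner loop iterates over."""
    if stage == 1:  # causal, strictly left of the diagonal: no mask needed
        return 0, block_index_q * block_size
    if stage == 2:  # causal, the diagonal block: mask needed
        return block_index_q * block_size, (block_index_q + 1) * block_size
    return 0, seq_len  # stage 3: non-causal, every key, no mask


def keys_visited(causal, block_index_q, block_size, seq_len):
    STAGE = 3 if causal else 1
    stages = [4 - STAGE] + ([2] if STAGE == 3 else [])
    keys = []
    for stage in stages:
        lo, hi = kv_range(stage, block_index_q, block_size, seq_len)
        keys.extend(range(lo, hi))
    return stages, keys


print(keys_visited(causal=True, block_index_q=2, block_size=3, seq_len=12))
# ([1, 2], [0, 1, 2, 3, 4, 5, 6, 7, 8])
print(keys_visited(causal=False, block_index_q=2, block_size=3, seq_len=12)[0])
# [3]
```

In the causal case, query block 2 (queries 6 to 8) never touches keys 9 to 11. Those keys are exactly the skipped part to the right of the diagonal.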

All right, now that this is clear, let's proceed. Since we need to write the inner for loop of flash attention, let's position the first blocks of keys and values. The K and V block pointers currently point at the (0, 0) block. We move them to the first key and value block that this for loop should work with, which depends on the stage. In the first call to this function they point to the first block, in both the causal and the non-causal case. The second call only happens with causal attention, and there they point exactly to the key and value block on the diagonal.

All right, then we need to write the for loop over the keys and values. First, we let the compiler know that this number here, start_kv, will always be a multiple of BLOCK_SIZE_KV, because we move from one KV block to the next, block by block. This doesn't change anything from a logical point of view. We are just giving the compiler a hint so that Triton can do some other optimizations.

Now, the first thing we see in the flash attention algorithm is that we need to compute the product of the particular query block we are working with and the current KV block of this iteration. So let's do it. The query has already been loaded by the caller of this function (we loaded it here), but we need to load the current block of K, indicated by the K block pointer. Then we do the matrix multiplication of the query block with the current block of K, which is already transposed. When we defined the K block pointer, we defined it with the strides swapped, so we are reading K already transposed. So we are computing the query multiplied by the transpose of the keys.

basically okay now let's do

here this part here basically

saying okay if the stage is to when the

stages to is when we are exactly on the

diagonal we know that some of the

queries will have an index that is

bigger than that of the keys and some of

them we have an index that is smaller

than that of the keys so we need to

apply the Cal mask only in this case so

uh basically what we do is we Define the

mask that we should be applying so the

mask will mask out all the values for

which this mask is not true so when this

mask is true when the index of the query

is more than the index of the K and vs

Then we apply the softmax scale. As you remember, here we only computed the query multiplied by the transpose of the keys, but we also need to divide by the square root of the head dimension, and we do it here.

Then, since we have already computed the product, we can calculate the maximum for each row and subtract it. The reason is that later in the Flash Attention algorithm there is another operation, which I call softmax star, and as you remember, softmax star takes each element of the S matrix, that is, the query multiplied by the transpose of the keys, minus the maximum of its row. So we can already compute the maximum for each row. But before computing the maximum for each row, we need to mask out all the elements that must be masked in stage number 2, which is along the diagonal. To mask them out, we simply replace with minus infinity, before applying the softmax, all the values for which the mask is false.

So right now we have computed the query multiplied by the transpose of the keys and masked it if needed, which is only when we are along the diagonal; in all the other cases we don't need to mask anything, we just multiply by the softmax scale. Then we subtract m_ij, the maximum value for each row, because we need to compute the softmax star operation, which is the softmax without the normalization. In the Flash Attention algorithm, this is exactly the operation that produces P.
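The steps described so far for a single (query block, KV block) pair can be sketched in pure Python as follows. This is an illustration under simplifying assumptions, not the kernel: the blocks are plain lists, `softmax_star_block` is a hypothetical name, masked scores are set to minus infinity as described above (some implementations use a large negative constant instead), and the row maximum comes from this block only, whereas the kernel also folds in the running maximum m_i.

```python
import math


def softmax_star_block(q_block, k_block, offs_q, offs_kv, scale, on_diagonal):
    """Hypothetical helper: S = scale * Q K^T for one block pair, causal mask on the
    diagonal block, then P = exp(S - rowmax(S)) without normalization (softmax star)."""
    s = []
    for qi, q in zip(offs_q, q_block):
        row = []
        for kj, k in zip(offs_kv, k_block):
            score = scale * sum(a * b for a, b in zip(q, k))
            if on_diagonal and qi < kj:      # key comes after the query: mask it out
                score = -math.inf
            row.append(score)
        s.append(row)
    # Each row keeps at least its diagonal entry here, so the max is finite
    m = [max(row) for row in s]
    p = [[math.exp(x - mi) for x in row] for row, mi in zip(s, m)]  # exp(-inf) == 0.0
    return p, m


head_dim = 2
q_block = [[1.0, 0.0], [0.0, 1.0]]
k_block = [[1.0, 0.0], [0.0, 1.0]]
p, m = softmax_star_block(q_block, k_block, offs_q=[0, 1], offs_kv=[0, 1],
                          scale=1 / math.sqrt(head_dim), on_diagonal=True)
print(p[0])  # query 0 cannot see key 1: [1.0, 0.0]
```

Masked entries become exactly zero after the exponential, so they contribute nothing to the row sums or to P times V.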

Now we can compute the P block, which is the exponential of the QK block variable here. We already subtracted m_ij in the previous instruction, so now we just apply the exponential, and this is what we are doing here.

Then we need to compute the row sum that gives the normalization factor. For the current KV block we have the P block, and to compute the normalization factor of the softmax we need to keep summing up these exponentials; later we will fix the normalization factor that we computed at the previous step, but we will do that later. So for now we just compute the normalization factor for the current block, which is the sum of all the values on each row, the same as what we saw when I showed you the algorithm: for each block we take the row sum of the P matrix, where P is the exponential of S minus m. For now we haven't applied the correction to the previous block.

So we have computed l_ij for the current K and V block, and then we compute the correction factor for the previous block. If you remember the formula from the paper, the correction factor is the exponential of the previous estimate of the maximum minus the current estimate of the maximum, $\alpha = \exp(m_i - m_{ij})$, which is exactly this line. We will see better later why m_i is the previous estimate of the maximum and m_ij the current one: m_ij has just been updated with the block we are currently computing, while m_i is the one from the previous iteration, because at the end of the iteration we overwrite m_i with m_ij. I'm just following the Flash Attention algorithm so far. So this is the correction factor for the previous l_i, which in the Flash Attention algorithm is this term here.

Then we apply the correction factor: the new l_i is the previous l_i multiplied by the correction factor, plus the current l_ij, the one coming from the current P block, which we computed with the current K and V in this iteration. That is exactly this operation: $l_i \leftarrow \alpha\, l_i + l_{ij}$.

Next, as you remember from the formula, we calculate the P block and then multiply it by the V block, so we need to load the V block. Let's load it: we load the V block from the pointer V, which points to the current V block of this iteration; for example, if we are in stage number 3, that is, non-causal attention, at the beginning it points to the first V block. Here there is just a type conversion, to make sure P is in 16-bit floating point. Then we compute the output block: we take P, multiply it by V, and add it to the output, and this is what we are doing here.

Let's actually go through these lines one by one. First of all, we need to fix the previous O block with the correction factor alpha that we have here, which is the correction factor for the previous block. At this point we have only fixed the previous block; we haven't added the new PV yet. To add it we call dot on P and V, which is not really a dot product but a matrix multiplication, and the third argument tells this matrix multiplication to use this element here as the accumulator. So this is exactly the same as doing O_block += P_block @ V_block, that is, $O_i \leftarrow \alpha\, O_i + P V$. It is just optimized: the dot function needs some place to store its intermediate results anyway, so why not store them directly where they should go? A matrix multiplication is made of dot products, and a dot product is just a repeated sum, so the dot keeps summing its results into this block, which gives the same result as doing the matrix multiplication separately and then adding it to the O block. That's why this argument is called the accumulator.
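A pure-Python sketch of what "use this block as the accumulator" means: the products of the matrix multiplication are summed directly on top of the existing O values instead of into a fresh buffer. The function names are hypothetical and the lists stand in for the Triton blocks.

```python
def dot_acc(p, v, acc):
    """Hypothetical helper: multiply p (M x K) by v (K x N), summing the products into acc (M x N)."""
    for i in range(len(p)):
        for j in range(len(v[0])):
            total = acc[i][j]                  # start from the accumulator's current value
            for k in range(len(v)):
                total += p[i][k] * v[k][j]     # a dot product is just a repeated sum
            acc[i][j] = total
    return acc


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


p = [[1.0, 2.0], [3.0, 4.0]]
v = [[0.5, 1.0], [1.5, 2.0]]
o = [[10.0, 10.0], [10.0, 10.0]]

# Separate: compute P @ V, then add it to O
separate = [[o_ij + pv_ij for o_ij, pv_ij in zip(o_row, pv_row)]
            for o_row, pv_row in zip(o, matmul(p, v))]
# Fused: sum P @ V directly into (a copy of) O
fused = dot_acc(p, v, [row[:] for row in o])
print(fused == separate)  # True
```

With arbitrary floating-point inputs the two orders of summation can differ in the last bits; the values here are exactly representable, so the results match exactly.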

So we have also computed the output. Then we save the new estimate of the maximum for the current iteration, and it becomes m_i, so that at the next iteration we can use it to calculate the correction factor. Then we have finished with the current block and can move on to the next one, so we advance our K and V pointers by one block. We advance them differently: the V block is a pointer to a tensor of shape (sequence length, head dim), so we advance the sequence-length dimension by BLOCK_SIZE_KV, while the K block is actually the K-transpose block, transposed because we exchanged the strides and the shape, so its shape is (head dim, sequence length). So we don't change the head dimension; we just advance the sequence length by BLOCK_SIZE_KV. Basically, we point to the next block of K and to the next block of V.
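Here is a small pure-Python sketch of why the two pointers are advanced along different dimensions. It is an assumption-laden illustration: one row-major (SEQ_LEN, HEAD_DIM) buffer plays both K and V, offsets stand in for block pointers, and `read_block` is a hypothetical helper. Moving V by (BLOCK_SIZE_KV, 0) and K^T by (0, BLOCK_SIZE_KV) lands both on the same key positions.

```python
def read_block(buffer, strides, offsets, block_shape):
    """Hypothetical helper: read a block of a strided 2-D view of a flat buffer."""
    (r0, c0), (rows, cols) = offsets, block_shape
    stride_r, stride_c = strides
    return [[buffer[(r0 + r) * stride_r + (c0 + c) * stride_c] for c in range(cols)]
            for r in range(rows)]


SEQ_LEN, HEAD_DIM, BLOCK_SIZE_KV = 8, 2, 4
# One row-major (SEQ_LEN, HEAD_DIM) buffer, used here for both K and V for simplicity
kv_flat = [float(i) for i in range(SEQ_LEN * HEAD_DIM)]

v_strides = (HEAD_DIM, 1)   # V viewed as (SEQ_LEN, HEAD_DIM)
kt_strides = (1, HEAD_DIM)  # K^T viewed as (HEAD_DIM, SEQ_LEN): strides exchanged
v_offsets, kt_offsets = (0, 0), (0, 0)

# Advance by one KV block: V along its first (sequence) dimension,
# K^T along its second (sequence) dimension; the head dimension is untouched
v_offsets = (v_offsets[0] + BLOCK_SIZE_KV, v_offsets[1])
kt_offsets = (kt_offsets[0], kt_offsets[1] + BLOCK_SIZE_KV)

v_block = read_block(kv_flat, v_strides, v_offsets, (BLOCK_SIZE_KV, HEAD_DIM))
kt_block = read_block(kv_flat, kt_strides, kt_offsets, (HEAD_DIM, BLOCK_SIZE_KV))
print(v_block[0], kt_block[0])  # [8.0, 9.0] [8.0, 10.0, 12.0, 14.0]
```

Both blocks cover keys 4 to 7: the V block holds them as rows, the K^T block holds them as columns.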

I hope you were able to follow the Flash Attention algorithm. I tried to use the same names and more or less the same logic, always writing the formula I am referring to, so hopefully you didn't get lost. I think the only difference between the Flash Attention algorithm as written in the paper and this code is probably this alpha, the correction factor, but I hope it's easily understandable anyway.

Then we just return the O block; l_i, which is the normalization factor for each row of the current output block (which is also a Q block, because each program works with one Q block independently of the other programs); and m_i, the maximum value for each row. These are needed for the backward pass: there we will compute the query multiplied by the transpose of the key block on the fly and we need to apply the softmax again, but instead of recomputing values we already computed during the forward pass, we save them and reuse them in the backward pass, which saves us some computation.
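Putting the whole inner loop together, here is a minimal pure-Python sketch of the online softmax for one query row, checked against the naive one-shot attention. It is a simplified stand-in for the kernel under stated assumptions: `flash_attention_row` and `naive_attention_row` are hypothetical names, one query row is processed at a time (the kernel does all rows of a Q block at once), and no causal mask is applied.

```python
import math


def naive_attention_row(q, keys, values, scale):
    """Reference: softmax(scale * q K^T) V computed in one shot for a single query row."""
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    l = sum(w)
    return [sum(wj * v[d] for wj, v in zip(w, values)) / l for d in range(len(values[0]))]


def flash_attention_row(q, keys, values, scale, block_size):
    """Online softmax for one query row, visiting K and V one block at a time."""
    m_i, l_i = -math.inf, 0.0          # running maximum and running normalization factor
    o_i = [0.0] * len(values[0])       # unnormalized output accumulator
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_ij = max(m_i, max(s))                 # new estimate of the maximum
        p = [math.exp(x - m_ij) for x in s]     # softmax star: no normalization yet
        l_ij = sum(p)
        alpha = math.exp(m_i - m_ij)            # correction factor (0.0 on the first block)
        l_i = l_i * alpha + l_ij
        o_i = [o * alpha + sum(pj * v[d] for pj, v in zip(p, v_blk))
               for d, o in enumerate(o_i)]
        m_i = m_ij                              # becomes the "previous" maximum next time
    # Normalize once at the end; also return m_i + log(l_i) for the backward pass
    return [o / l_i for o in o_i], m_i + math.log(l_i)


q = [0.1, -0.3]
keys = [[0.5, 0.2], [-1.0, 0.3], [2.0, 1.0], [0.0, -0.5], [1.5, 0.5]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.3, 0.3]]
scale = 1 / math.sqrt(len(q))
out, lse = flash_attention_row(q, keys, values, scale, block_size=2)
ref = naive_attention_row(q, keys, values, scale)
print(max(abs(a - b) for a, b in zip(out, ref)))  # tiny: the two methods agree
```

Note that the last block has only one key; the online update does not require the sequence length to be a multiple of the block size.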

Now I think it's time to talk about the log-sum-exp trick, because we are going to use it. Let's go back to the other method, the caller, here. In case we are working with causal attention, we have made two calls of this function: we call it once to work with all the key blocks that are to the left of the diagonal of the query-key matrix, and then we call it again to work only with the key blocks that lie exactly on the diagonal of the query-key matrix, because there some of the values need to be masked out and some do not. Moreover, by doing this we avoid computing, in the causal case, the dot products for all the positions in which the index of the key is higher than the index of the query. That saves some computation, because after the softmax those values would be zeros anyway and would not contribute to the output, so it should be faster.
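A pure-Python sketch of which key positions each call visits. `kv_range` is a hypothetical helper that follows the three stages described in this section (1: causal, left of the diagonal; 2: causal, the diagonal block; 3: non-causal, everything), and it assumes BLOCK_SIZE_Q is a multiple of BLOCK_SIZE_KV so the diagonal block can be walked in whole KV blocks.

```python
def kv_range(stage, block_index_q, block_size_q, seq_len):
    """Hypothetical helper: key indices [lo, hi) visited by one query block in each stage.
    stage 1 -> causal, KV blocks strictly left of the diagonal (no mask needed)
    stage 2 -> causal, the diagonal block (the causal mask is applied)
    stage 3 -> non-causal, every KV block"""
    if stage == 1:
        return 0, block_index_q * block_size_q
    if stage == 2:
        return block_index_q * block_size_q, (block_index_q + 1) * block_size_q
    return 0, seq_len


SEQ_LEN, BLOCK_SIZE_Q = 16, 4
num_q_blocks = SEQ_LEN // BLOCK_SIZE_Q
for block_index_q in range(num_q_blocks):
    lo1, hi1 = kv_range(1, block_index_q, BLOCK_SIZE_Q, SEQ_LEN)
    lo2, hi2 = kv_range(2, block_index_q, BLOCK_SIZE_Q, SEQ_LEN)
    # The two causal calls are contiguous and stop at the end of the diagonal block
    assert hi1 == lo2 and hi2 == (block_index_q + 1) * BLOCK_SIZE_Q

# Key positions visited, summed over all query blocks
causal_keys = sum(kv_range(2, b, BLOCK_SIZE_Q, SEQ_LEN)[1] for b in range(num_q_blocks))
full_keys = sum(kv_range(3, b, BLOCK_SIZE_Q, SEQ_LEN)[1] for b in range(num_q_blocks))
print(causal_keys, full_keys)  # 40 64
```

In this toy setting the causal split visits 40 of the 64 key positions; everything it skips would have been masked to zero anyway.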

Now let's go back to this method here, the calling method. There is one last thing we need to do, which is computing the log-sum-exp, and now I will show you what it is. In order for the backward pass to recompute the softmax without having to recalculate the normalization factor and the maximum value for each row, we would actually need to save two different things: the maximum for each row in the query block, and the normalization factor for each query in the query block. However, there is a trick. It's not really called the log-sum-exp trick, because the log-sum-exp trick is used for another purpose, but let's call it log-sum-exp trick number two. Let me open the slides.

When we multiply the query by the transpose of the keys, we get a matrix made up of dot products, something like this: this entry is query 1 transposed times key 1, then query 1 transposed times key 2, this is query 2 transposed times key 1, and this is query 2 transposed times key 2. Then we need to apply the softmax, which is applied by rows, so each row is a vector, and for each of these vectors the softmax modifies each element as follows: the exponential of the element minus the maximum of the vector to which we are applying the softmax, divided by the normalization factor, which is the summation over all possible j (n is two in this case, because each vector is made up of two elements) of the exponential of x_j minus x_max:

$$\mathrm{softmax}(x_i) = \frac{\exp(x_i - x_{\max})}{\sum_{j=1}^{n} \exp(x_j - x_{\max})}$$

Now imagine we already have x_max and this summation. In the Flash Attention forward pass, the summation is called l_i and the maximum is called m_i. What we are going to save, as you can see here in the code, is not m_i and l_i separately: we save m_i plus the logarithm of l_i.

When we compute the backward pass, we need to recreate this matrix on the fly, which means recomputing the query multiplied by the transpose of the keys and then applying the softmax. To apply the softmax we would need both m_i and l_i, but we only have m_i plus the logarithm of l_i. So, when computing the softmax, we will define a new softmax (let me use another color); let's call it softmax 2 so as not to confuse it with the usual softmax. It is equal to the exponential of each element minus the saved value corresponding to the current row to which we are applying the softmax, so the exponential of x_i minus m_i minus the log of l_i.

If we expand this expression, since the exponential of a difference is the quotient of the two exponentials, we can write it as the exponential of x_i minus m_i divided by the exponential of the log of l_i, which, guess what, is equal to the exponential of x_i minus m_i divided by l_i, exactly the normalized softmax:

$$\mathrm{softmax}_2(x_i) = \exp\left(x_i - m_i - \log l_i\right) = \frac{\exp(x_i - m_i)}{\exp(\log l_i)} = \frac{\exp(x_i - m_i)}{l_i}$$

So instead of saving two values we save only one, and when we apply it, the properties of the exponential also take care of normalizing each value to which we apply it.

If you don't remember the properties of the exponential, they are very simple: the exponential of a plus b is equal to the exponential of a multiplied by the exponential of b, $e^{a+b} = e^a e^b$, and the exponential of a minus b is equal to the exponential of a divided by the exponential of b, $e^{a-b} = e^a / e^b$. This is the trick we are using, and that's why we don't need to save two different values: we save just one, and when we apply it, it automatically takes care of the normalization because of the properties of the exponential.
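A quick numerical check of log-sum-exp trick number two, as a minimal pure-Python sketch on a single made-up row of scores: softmax 2, which uses only the saved value m_i + log(l_i), matches the usual softmax, which needs both m_i and l_i.

```python
import math

x = [2.0, -1.0, 0.5, 3.0]                 # one row of Q K^T (made-up values)
m = max(x)                                # m_i
l = sum(math.exp(xi - m) for xi in x)     # l_i
lse = m + math.log(l)                     # the single value saved per row

softmax = [math.exp(xi - m) / l for xi in x]    # needs both m_i and l_i
softmax_2 = [math.exp(xi - lse) for xi in x]    # needs only m_i + log(l_i)
print(all(abs(a - b) < 1e-12 for a, b in zip(softmax, softmax_2)))  # True
```

Storing one number per row instead of two also halves the size of the auxiliary tensor the forward pass writes out.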

Let's move forward. So we have also created this value that we will use during the backward pass. Now, as you remember, in the Flash Attention algorithm we don't normalize each block while computing it; we normalize the output at the end, and this is exactly what we do here: we normalize the block at the end, after we have computed all the normalization factors we need for all the rows that belong to the current output block.

Then we save m_i. This m_i now contains both the normalization factor and the maximum for each row, which we will need for the backward pass, so we need to save it in a tensor that we will use during the backward pass. Which tensor is that? It's the tensor we called M, which has dimensions (batch size, number of heads, sequence length). So we need to select the right point in this tensor where we should save these m_i values, that is, the right batch index and the right head index.

So we advance this pointer by the following offset: M plus index_batch_head multiplied by the sequence length. index_batch_head is the index of the current program, and it encodes both which head and which batch we are working with. For each batch and each head we have a sequence length's worth of values, because each token in the sequence has a maximum value and a normalization value. In this tensor the sequence length is the last dimension, so, based on the current combination of batch and head, given by program index number 1, that is, index_batch_head, we can skip the stretches of sequence length that other programs will process. That's why we skip the sequence length multiplied by index_batch_head here: M points to the first element of the entire tensor, so we are skipping the heads and the batches according to the combined index that this particular program is working with.

Then we add offs_q, because each of these kernels, the attention forward method, works with one query block, and each query block contains some specific query indices. These are given by the offs_q variable that you can see here: the number of queries we need to skip because they will be processed by other programs, plus the range of queries that this particular block contains. Imagine this program is working with the queries from 12 to 15: then offs_q will be 12, 13, 14, 15, and we only have the normalization factor and the maximum value for those rows, 12, 13, 14 and 15. That's why we also need to skip to the queries this particular program works with, which is already included in the offs_q offset.
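The offset arithmetic can be checked in pure Python. This sketch assumes M is a contiguous tensor of shape (BATCH_SIZE, NUM_HEADS, SEQ_LEN) and that the combined program index is laid out as batch * NUM_HEADS + head; `m_offsets` is a hypothetical helper returning the flat positions where one program stores its m_i values.

```python
BATCH_SIZE, NUM_HEADS, SEQ_LEN, BLOCK_SIZE_Q = 2, 3, 16, 4


def m_offsets(index_batch_head, block_index_q):
    """Hypothetical helper: flat positions in M for one program's query rows."""
    offs_q = [block_index_q * BLOCK_SIZE_Q + i for i in range(BLOCK_SIZE_Q)]
    # Skip SEQ_LEN entries per (batch, head) pair before this one, then add the query indices
    return [index_batch_head * SEQ_LEN + q for q in offs_q]


batch, head = 1, 2
index_batch_head = batch * NUM_HEADS + head   # assumed layout of the combined index
print(m_offsets(index_batch_head, block_index_q=3))  # [92, 93, 94, 95]
```

These agree with indexing M[batch, head, q] directly: 1 * 3 * 16 + 2 * 16 + q = 80 + q for q = 12 to 15.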

So now we can store m_i, because we have the pointer to where it should be saved, and we can also store the output computed by our inner for loop. And this, guys, is the forward pass of Flash Attention.

Now we should move on to computing the backward pass. We already have some of the ingredients for it: we have seen the log-sum-exp trick, so we know how to use it to compute the query-key block on the fly during the backward pass. What are we missing to understand the backward pass well? First of all, we need to understand what the backward pass is and why we even need it. We need to understand what PyTorch autograd is and how it works, what the gradient is and how to compute it, what the Jacobian is, and whether we even need to compute it when computing the gradient in the backward pass. So we need to derive all the formulas of the backward pass by hand. If you are in for the challenge, let's continue.

Before looking at the Flash Attention backward pass algorithm, we need to understand why we even need a backward pass, and before looking at PyTorch autograd, we should look at what derivatives, gradients and Jacobians are, so that when we talk about them we don't feel lost. So I will do a very fast refresher of these topics.

What is the derivative? When we have a function that takes a real value as input and outputs a real value, we talk about derivatives, defined as follows: the derivative of the function with respect to its variable x is the limit, as the step size h goes to zero, of the function evaluated at x plus h minus the function evaluated at x, divided by the step size:

$$f'(x) = \lim_{h \to 0} \frac{f(x + h) - f(x)}{h}$$

Intuitively, it is the ratio of how much the output changes for a small change in the input. This also gives you the intuition of why the derivative tells you the slope of the tangent line to the function at the point where it is evaluated.

I will also use the following notation for the derivative: I usually write it as f'(x), but it's also possible to write it as df(x)/dx, or dy/dx where y is the output of the function, and they all denote the same thing, the definition above.

If we invert this formula and bring h to the other side, we can also write the following: if we want to evaluate the function at the position x + h, we can approximate it as f'(x), the derivative of the function at the point x, multiplied by h, the step size, plus f(x). This is actually also how we derive the Euler method for solving differential equations, but that's not today's topic. We can also call this h Δx. It is only an approximation, because the limit says this holds only when h is very, very small, so we write approximately equal:

$$f(x + \Delta x) \approx f'(x)\,\Delta x + f(x)$$
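A numerical check of this first-order approximation, as a minimal sketch: `derivative` is a hypothetical helper that uses the limit definition with a small but finite h, and sin is just an example function.

```python
import math


def derivative(f, x, h=1e-6):
    """Hypothetical helper: finite-difference estimate of f'(x) from the limit definition."""
    return (f(x + h) - f(x)) / h


f = math.sin
x, dx = 1.0, 1e-3
approx = derivative(f, x) * dx + f(x)     # f(x + dx) ≈ f'(x) dx + f(x)
exact = f(x + dx)
print(abs(approx - exact))                # small: the error shrinks roughly like dx ** 2
```

The leftover error comes from the terms the approximation drops, which is why it only holds for small Δx.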

You can also read this as follows: if x changes by a small amount Δx, how much will y change? y will change by the derivative of y with respect to x, dy/dx, multiplied by how much x has changed. So dy/dx tells us how much y changes for a small change in x, and if we multiply it by the actual change in x, it tells us approximately how much y will be affected. I don't want to dwell too much on this, but I would like to use this intuition to introduce the chain rule.

Imagine we have a function of a function, z = f(g(x)). We can think of x being mapped into a variable y through the function g, and then y being mapped into the variable z through the function f. If x changes by a little bit, Δx, how much will y change? y will change by Δy, which is the derivative of y with respect to x multiplied by the step Δx. But if y changes, it will also affect z, because there is a direct mapping between y and z. So how much will z change for a small change in y? If y changes by a small step Δy, then z will change by some Δz, equal to dz/dy multiplied by Δy. If we replace Δy with the expression we computed above, we arrive at the chain rule, which tells us how z is affected by a small change in x: Δz is the product of the two derivatives, the one of y with respect to x and the one of z with respect to y, times Δx. This is the chain rule that we study in high school:

$$\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx}$$

This is very intuitive if you think about the following example: think of z as the price of a car and x as the price of oil. How much will a small change in the price of oil affect the price of a car? The small change in the price of oil will affect, for example, a variable y, which could be the price of electricity, and the price of electricity affects the price of the car through the derivative of the price of the car with respect to the price of electricity. So to get the effect of the price of oil on the price of the car, we just multiply the two effects, and this is the intuition behind the chain rule.
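A minimal numerical check of the chain rule, with made-up functions g(x) = x² and f(y) = exp(y): the product of the two analytic derivatives matches a central finite-difference estimate of dz/dx for the composed function.

```python
import math


def derivative(f, x, h=1e-6):
    """Central finite difference (hypothetical helper)."""
    return (f(x + h) - f(x - h)) / (2 * h)


g = lambda x: x ** 2          # y = g(x)
f = math.exp                  # z = f(y)
x = 0.7
y = g(x)
dy_dx = 2 * x                 # analytic g'(x)
dz_dy = math.exp(y)           # analytic f'(y)
chain = dz_dy * dy_dx         # dz/dx = dz/dy * dy/dx
numeric = derivative(lambda t: f(g(t)), x)
print(abs(chain - numeric) < 1e-6)  # True
```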

Now let's talk about gradients. When we have a function that takes a vector as input and produces a scalar, we no longer talk about derivatives but about gradients. Imagine a function that takes as input a vector made up of two dimensions, or n dimensions in general, and produces a scalar. When do we have to deal with this kind of function? Loss functions, for example: a loss function always outputs a scalar and takes tensors as input. The cross-entropy loss, for instance, takes a sequence of tokens, each with its own logits, and computes one single number, which is the loss.

So how does the output react to a change in the input in this case? If x changes by a small amount, this small amount is no longer a number but a vector: x becomes x + Δx, a vector sum. Then y will be affected by dy/dx multiplied by Δx; however, Δx is now a vector, because x_1 may change by a little bit, x_2 may change by a little bit, x_3 by a little bit, and so on up to x_n. So this product is actually a dot product between two vectors. Why a dot product? Because y is affected by the change in x_1, by the change in x_2, by the change in x_3, up to x_n: the contribution of x_1 is the partial derivative of y with respect to x_1 multiplied by how much x_1 has changed, plus the contribution of x_2, which is the partial derivative of y with respect to x_2 multiplied by how much x_2 has changed, and so on up to the last contribution, that of x_n:

$$\Delta y \approx \nabla_x y \cdot \Delta x = \sum_{i=1}^{n} \frac{\partial y}{\partial x_i}\,\Delta x_i$$
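A small pure-Python check that the change in a scalar output is approximately the dot product of the gradient with the change in the input. The function f and its hand-derived gradient are made-up examples.

```python
def f(x):
    # Scalar output from a vector input: f(x) = x1^2 + 3*x1*x2 + x3
    return x[0] ** 2 + 3 * x[0] * x[1] + x[2]


def grad_f(x):
    # Vector of partial derivatives (df/dx1, df/dx2, df/dx3)
    return [2 * x[0] + 3 * x[1], 3 * x[0], 1.0]


x = [1.0, 2.0, -1.0]
dx = [1e-4, -2e-4, 3e-4]                 # every component changes a little
delta_y = f([xi + di for xi, di in zip(x, dx)]) - f(x)
predicted = sum(g * d for g, d in zip(grad_f(x), dx))   # gradient · Δx
print(abs(delta_y - predicted) < 1e-6)   # True
```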

The chain rule also applies in this case in the same way as in the scalar case, so its formula does not change. Here I just want to remind you that in this case we are talking about a gradient, and the gradient is just a vector made up of all the partial derivatives of the output with respect to each of the variables in the input vector.

When we talk about a function that takes a vector as input and produces a vector, we no longer talk about gradients but about Jacobians. If the input x of this function changes by a small amount Δx, which is a vector, then the output y changes by a Δy that is also a vector, given by dy/dx multiplied by Δx. Since Δx is a vector, for the result to be a vector dy/dx has to be a matrix, and this matrix is called the Jacobian. It has as many rows as there are output variables and as many columns as there are input variables (we will talk about the notation in a moment). The first row contains the partial derivatives of the first output variable with respect to all the input variables, the second row those of the second output variable, and the last row the partial derivatives of the last output variable with respect to all the variables in the input vector:

$$J = \frac{\partial y}{\partial x} = \begin{bmatrix} \frac{\partial y_1}{\partial x_1} & \cdots & \frac{\partial y_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_m}{\partial x_1} & \cdots & \frac{\partial y_m}{\partial x_n} \end{bmatrix}$$
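A minimal finite-difference sketch of a Jacobian in the layout just described (one row per output, one column per input). The function f from R² to R³ and the `jacobian` helper are made-up examples; transposing the result gives the other layout discussed next.

```python
import math


def f(x):
    # Example map from R^2 to R^3
    x1, x2 = x
    return [x1 * x2, math.sin(x1), x1 + x2 ** 2]


def jacobian(f, x, h=1e-6):
    """Hypothetical helper: one row per output variable, one column per input variable."""
    y = f(x)
    J = [[0.0] * len(x) for _ in y]
    for j in range(len(x)):
        x_plus = list(x)
        x_plus[j] += h
        x_minus = list(x)
        x_minus[j] -= h
        y_plus, y_minus = f(x_plus), f(x_minus)
        for i in range(len(y)):
            J[i][j] = (y_plus[i] - y_minus[i]) / (2 * h)   # central difference
    return J


x = [0.5, 2.0]
J = jacobian(f, x)
J_denominator = [list(col) for col in zip(*J)]   # the other convention: just transpose
print(len(J), len(J[0]))  # 3 2
```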

Now let's talk about notation. The Jacobian I have written here follows the numerator layout, also called the numerator convention. There is another convention, called the denominator layout, in which the number of rows is not equal to the number of output variables but to the number of input variables. So writing the Jacobian like this is just a convention: you can write it according to the denominator convention by transposing this Jacobian, and the formula for the chain rule changes accordingly. For now I want to keep the formula for the chain rule the same as in the scalar case, which is why I am using this notation, but later we can switch from one notation to the other just by transposing.

Now that we have reviewed what a derivative, a gradient and a Jacobian are, let's talk about what happens when we take derivatives of a tensor with respect to another tensor. In this case we still talk about a Jacobian, but it is called the generalized Jacobian. Suppose the function takes as input a tensor with D_x dimensions and shape (N_1, N_2, ..., N_{D_x}), and produces an output tensor with shape (M_1, M_2, ..., M_{D_y}). The formula for the chain rule doesn't change: if x changes by a small amount Δx, which is a tensor, y will change by dy/dx multiplied by Δx, where this is a tensor product and dy/dx is the generalized Jacobian, whose shape is all the dimensions of the output followed by all the dimensions of the input: $(M_1 \times \dots \times M_{D_y}) \times (N_1 \times \dots \times N_{D_x})$.

This is very abstract for now, but we will see a concrete case, because we will derive the gradient of the loss with respect to each of the inputs of a matrix multiplication in the backward pass, and we will do the same for the softmax and for attention. I don't want to jump across too many topics; I just wanted to get us into the right mindset. So: derivatives when we have scalar functions, gradients when the output is a scalar and the input is a vector, Jacobians when the input and output are both vectors, and generalized Jacobians when the input and output are tensors. The chain rule always works in the same way.
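To make the generalized Jacobian less abstract, here is a small finite-difference sketch for Y = X W with W held constant. The matrices are made-up examples and `generalized_jacobian` is a hypothetical helper; it only illustrates the shape rule (output dimensions first, then input dimensions) and the entries dY[i][j] / dX[k][l].

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def generalized_jacobian(f, x, h=1e-6):
    """Hypothetical helper: J[i][j][k][l] = d f(x)[i][j] / d x[k][l] (output dims, then input dims)."""
    y = f(x)
    rows_y, cols_y, rows_x, cols_x = len(y), len(y[0]), len(x), len(x[0])
    J = [[[[0.0] * cols_x for _ in range(rows_x)] for _ in range(cols_y)] for _ in range(rows_y)]
    for k in range(rows_x):
        for l in range(cols_x):
            x_plus = [row[:] for row in x]
            x_plus[k][l] += h
            y_plus = f(x_plus)
            for i in range(rows_y):
                for j in range(cols_y):
                    J[i][j][k][l] = (y_plus[i][j] - y[i][j]) / h
    return J


X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]      # shape (2, 3)
W = [[1.0, -1.0], [0.5, 2.0], [0.0, 1.0]]   # shape (3, 2), held constant
J = generalized_jacobian(lambda x: matmul(x, W), X)
print(len(J), len(J[0]), len(J[0][0]), len(J[0][0][0]))  # 2 2 2 3
```

Each entry equals W[l][j] when i == k and zero otherwise, so most of this four-dimensional object is zeros.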

All right, let's talk about autograd. I will do the scalar case and then we will extend it to the tensor case. Imagine we have a very simple computation graph. Why a computation graph? Because we are talking about neural networks, and neural networks are nothing more than computation graphs in which we have some inputs and some parameters and we do some operations with them. Suppose you have an input a, multiplied by a weight, a parameter w_1, which is just a scalar, producing an output y_1. This y_1 is then summed with another number called b_1, producing y_2. This y_2 is then raised to the power of two (this z² node just squares its input), producing y_3, and this y_3 becomes our loss function, so it's a scalar. To apply gradient descent, we want to compute the gradient of the loss function with respect to each of the inputs of this computation graph, that is, each of its leaves. What are the leaves? These nodes here: the parameter nodes and the input nodes.

ways one is if you have access to the uh

expression that relates Direct the input

to the uh uh output so the to the loss

then you can directly compute the the

gradient the derivative in this case

because it's not a gradient it's a

scalar versus scalar so in this case

imagine we want to compute the

derivative of the loss with respect to

W1 imagine we have access to the um

exact um expression that relates the W1

to uh to the five which is our loss we

can compute it as follows so we just

derive this expression with respect to

W1 which is two * because this is the

power of two of a function so it is two

multiplied by the function multiplied by

the the derivative of the content of

this function with respect to the

variable that we are deriving so it will

become the following expression there is

another way which is by using the chain

rule so we can use the derivative of

five with respect to Y W1 is the

derivative of I with respect to Y3 which

is the previous uh output of the

previous node then the derivative of 53

with respect to the previous the output

of the previous node so and then the

multiplied by the derivative of Y2 with

respect to the output of the previous

node and then the derivative of y1 with

respect to W1 if we do all this chain of

multiplication we will obtain the same

result and you can see that here this

stuff here is exactly equal to this

stuff
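As a quick sanity check of the scalar example, here is a minimal Python sketch that computes $\partial\phi/\partial w_1$ three ways: from the direct expression, as the chain-rule product, and with a central finite difference. The numeric values of $a$, $w_1$ and $b_1$ are arbitrary choices for illustration, not values from the video.

```python
# Scalar computation graph: y1 = a * w1, y2 = y1 + b1, y3 = y2 ** 2, loss phi = y3.

def loss(a: float, w1: float, b1: float) -> float:
    y1 = a * w1
    y2 = y1 + b1
    y3 = y2 ** 2
    return y3

def grad_direct(a: float, w1: float, b1: float) -> float:
    # Differentiate phi = (a*w1 + b1)^2 symbolically with respect to w1.
    return 2 * (a * w1 + b1) * a

def grad_chain_rule(a: float, w1: float, b1: float) -> float:
    # Multiply the local derivatives node by node, from the loss back to w1.
    y2 = a * w1 + b1
    dphi_dy3 = 1.0        # phi and y3 are the same quantity
    dy3_dy2 = 2 * y2      # derivative of z**2
    dy2_dy1 = 1.0         # derivative of y1 + b1 with respect to y1
    dy1_dw1 = a           # derivative of a*w1 with respect to w1
    return dphi_dy3 * dy3_dy2 * dy2_dy1 * dy1_dw1

def grad_finite_difference(a: float, w1: float, b1: float, eps: float = 1e-6) -> float:
    # Central difference, used only as an independent numerical check.
    return (loss(a, w1 + eps, b1) - loss(a, w1 - eps, b1)) / (2 * eps)

a, w1, b1 = 1.5, -0.7, 0.3
print(grad_direct(a, w1, b1), grad_chain_rule(a, w1, b1), grad_finite_difference(a, w1, b1))
```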

Here, by doing this procedure, we will notice something. Let me zoom out a little bit. To compute the derivative of $\phi$ with respect to $w_1$ we are doing this whole chain of multiplications, but what is each factor in this sequence? The first two factors together, $\frac{\partial \phi}{\partial y_3}\frac{\partial y_3}{\partial y_2}$, are nothing more than the derivative of $\phi$ with respect to $y_2$. The first three factors are nothing more than the derivative of $\phi$ with respect to $y_1$, and all of them combined are the derivative of $\phi$ with respect to $w_1$.

This is exactly what PyTorch does during the backward pass. PyTorch knows the computation graph that relates the output, the loss function in this case, to the variables for which we want to compute the gradient. Right now we are talking about derivatives rather than gradients, but the mechanism is exactly the same. PyTorch is like a person who knocks on the door of each operation and says: "Hey, power-of-two operation, if I give you the gradient of the loss with respect to $y_3$, which is one because the loss and $y_3$ are the same thing, can you give me the gradient of the loss with respect to $y_2$?" PyTorch has to ask, because its autograd does not work symbolically: it does not know the symbolic expressions that led to the output. It only knows which functions computed the output, and each function is a class in Python that implements two methods, a forward step and a backward step. The forward step takes the input, in this case $y_2$, and computes the output $y_3$. The backward step takes the gradient of the loss with respect to its output and must compute the gradient of the loss with respect to its input. How can it do that? It's very simple. Let me copy this chain of multiplications here, otherwise it's not easy to go back and forth.

Okay, let's paste it here. PyTorch will knock on the door of this function and say: "If I give you the gradient of the loss function with respect to your output, can you give me the gradient of the loss function with respect to your input?" Yes, the function can do it, thanks to the chain rule. This operator just takes the gradient of the loss with respect to its output and multiplies it by the Jacobian, or in this case simply the derivative, of its output with respect to its input, and the result is the gradient of the loss with respect to its input.

Then PyTorch takes this result and knocks on the door of the next operator, the summation, and asks: "If I give you the gradient of the loss with respect to your output, can you give me the gradient of the loss with respect to your input?" Yes, this operator can do it, because it just needs to apply the chain rule: it takes the gradient of the loss with respect to $y_2$, provided by PyTorch, multiplies it by the Jacobian, in this case the derivative of its output with respect to its input, and obtains the gradient of the loss with respect to its input.

Then PyTorch takes the output of this backward step and knocks on the door of the next operator, the product, asking the same question again. This operator does exactly the same job: it takes the gradient of the loss with respect to its output, multiplies it by the Jacobian of the output with respect to the input, and obtains the gradient of the loss with respect to the input.

This is how PyTorch runs the backward step: one operator at a time, backwards through the computation graph, knocking on the door of each operator and always asking the same question: "If I give you the gradient of the loss with respect to your output, can you give me the gradient of the loss with respect to your input?" Each operator applies the chain rule to calculate the gradient that PyTorch needs. Why can't PyTorch do it by itself? Because PyTorch does not do symbolic mathematics: it does not have access to the exact expression that each function computes, and it treats each function as a black box that computes a forward and a backward. However, with the Jacobian we have a problem, so let's see what the problem is.

All right, so up to now we have been working with a computation graph made up of scalars, but everything we have said works in the tensor case as well. Let's go back to our computation graph. We have seen that PyTorch goes operator by operator, always asking the same question: "If I give you the gradient of the loss with respect to your output, can you compute the gradient of the loss with respect to your input?" And each operator can just apply the chain rule to compute that.
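The door-knocking protocol can be imitated with a toy, scalar-only sketch. The classes below (`MulByWeight`, `AddBias`, `Square`) are hypothetical and only mirror the shape of the idea: each operator saves what it needs in `forward` and turns an upstream gradient into downstream gradients in `backward`. This is not PyTorch's API; in PyTorch a custom operator would instead subclass `torch.autograd.Function`. The numeric inputs are arbitrary.

```python
class MulByWeight:
    def forward(self, a, w):
        self.a, self.w = a, w          # save the inputs that backward will need
        return a * w

    def backward(self, grad_out):
        # Chain rule: upstream gradient times the local derivative of a*w.
        return grad_out * self.w, grad_out * self.a   # (dphi/da, dphi/dw)

class AddBias:
    def forward(self, y, b):
        return y + b

    def backward(self, grad_out):
        # d(y + b)/dy = d(y + b)/db = 1, so the upstream gradient passes through.
        return grad_out, grad_out                     # (dphi/dy, dphi/db)

class Square:
    def forward(self, y):
        self.y = y
        return y ** 2

    def backward(self, grad_out):
        return grad_out * 2 * self.y                  # dphi/dy

def run(a, w1, b1):
    mul, add, sq = MulByWeight(), AddBias(), Square()
    # Forward pass: each operator only sees its own inputs.
    y1 = mul.forward(a, w1)
    y2 = add.forward(y1, b1)
    phi = sq.forward(y2)
    # Backward pass: start from dphi/dphi = 1 and visit the operators in reverse.
    g_y2 = sq.backward(1.0)
    g_y1, g_b1 = add.backward(g_y2)
    g_a, g_w1 = mul.backward(g_y1)
    return phi, {"a": g_a, "w1": g_w1, "b1": g_b1}

phi, grads = run(1.5, -0.7, 0.3)
print(phi, grads)
```

No operator ever sees the full expression for $\phi$; each one only combines the upstream gradient with its own local derivative, which is why the same protocol carries over to tensors.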

Imagine now that all of these operators work not with scalars but with tensors. Then the derivative of the output of each operator with respect to its input is no longer a derivative: it is a Jacobian, or rather a generalized Jacobian, because both the output and the input are tensors. It also means that the derivative of the loss with respect to the input, in this case $y_1$, is not a derivative but a gradient: the loss is always a number, while the input is a tensor, and when the output is a number and the input is a tensor we talk about gradients. We will call this the downstream gradient, the one the operator needs to compute. The gradient of the loss with respect to the output of each operator, which PyTorch gives to the operator, we will call the upstream gradient. Each operator needs to produce the downstream gradient by using the Jacobian.

However, the Jacobian has a problem. Let's see. Imagine we are implementing a simple operation, the matrix multiplication. It takes as input a tensor $X$, multiplies it by a matrix $W$ made up of parameters, and produces a matrix $Y$ as output. Suppose that $X$ is an $N \times D$ matrix and $W$ is a $D \times M$ matrix, so $Y$ will be an $N \times M$ matrix. Usually the input $X$ is a sequence of $N$ vectors, each with $D$ dimensions, so you can think of it as a sequence of tokens where each token is a vector of $D$ dimensions. Usually we have many tokens: suppose that $N$ is at least 1,024; in the most recent language models we even have millions of tokens. $D$ is also quite big, usually at least 1,024, and $M$ is also at least 1,024; let's say 2,048. I like powers of two, by the way.

So the problem with the Jacobian is this: if we want to compute the downstream gradient by multiplying the upstream gradient with the Jacobian, this Jacobian is huge. Look at the dimensions: it is a generalized Jacobian, a tensor of shape $(N, M, N, D)$, because the output is $N \times M$ and the input $X$ is $N \times D$. How many elements will it have? $1{,}024 \times 2{,}048 \times 1{,}024 \times 1{,}024$, which is more than 2 trillion elements. So it is impossible to materialize this tensor in the memory of the GPU, because it would be far too big.

But we still need to compute this downstream gradient, because PyTorch needs it to continue calculating the gradient of the loss function with respect to each of the nodes in the computation graph. So how can we proceed? The first thing to notice is that this Jacobian is actually a sparse matrix, a super, super sparse one, and I want to show you why. Look at the effect of the input on the output. The input is a sequence of tokens: this is token number one, a vector of, say, 1,024 dimensions; then we have another token, then another, then another. We multiply it by the matrix $W$, which is made up of columns. $X$ is $N \times D$ and $W$ is $D \times M$, so the product is an $N \times M$ matrix, which is also a sequence of tokens, each made up of $M$ dimensions: the first output token, the second, the third and the fourth output token.

Now, each output row is the dot product of the corresponding input row with all the columns of $W$. So the derivative of each dimension of this output row with respect to the dimensions of all the other input tokens is zero, because those tokens do not contribute to this output. The Jacobian therefore has a zero every time we differentiate an element of one output token with respect to any element of a different input token. That's why we can always come up with a better formula for computing this downstream gradient, one that does not involve materializing the Jacobian: the Jacobian itself is sparse. So let's see how to compute this without materializing the Jacobian in the case of matrix multiplication, because we need it for Flash Attention.

All right, guys. Before looking at the formulas of the backward pass of Flash Attention, let's look at how to compute the gradient of the matrix multiplication operation with respect to its inputs. PyTorch already knows how to compute the gradient of the inputs of a matrix multiplication given the gradient of the loss with respect to its output. But in Flash Attention we are creating a custom kernel, which fuses multiple operations into one. So when PyTorch knocks on the door of our operator, the Triton attention operator that we have built, it will ask for the gradient of the loss function with respect to $Q$, $K$ and $V$, because those are the inputs of our function.

If we look at the code we have built so far, you can see that our TritonAttention will be a node in the computation graph that takes $Q$, $K$ and $V$ as input and produces an output.

Then PyTorch will give us the gradient of the loss with respect to that output, $dO$, and will ask this class, TritonAttention, to compute the gradient of the loss with respect to $Q$, $K$ and $V$. Because we are fusing multiple operations together (computing on the fly the softmax of the query multiplied by the transpose of the key, then multiplying it by $V$ to compute the output), we need to compute these gradients internally. And because one of the fused operations is a matrix multiplication, we need to derive by hand the gradient of the loss with respect to the inputs of the matrix multiplication, so that we can provide it to PyTorch. That's why we need to derive this formula, and I will derive it in a very simple way.

Then we will do the same for the softmax, because these are the two things we need to derive by hand to obtain the formulas of the Flash Attention backward pass.

So let's start. Imagine a node in the computation graph called matrix multiplication, which computes the operation $Y = XW$. When computing the backward pass of this node, PyTorch gives us as input the gradient of the loss with respect to the output of this node, $\frac{\partial \phi}{\partial Y}$, and asks us to compute the gradient of the loss function with respect to $X$ and the gradient of the loss function with respect to $W$.

I will show the easier of the two. The other one I will not show in the video, but I will attach a PDF slide on how it is computed, because the two are derived in a very similar way and I don't want to make the video unnecessarily long. So let's compute the gradient of the loss function with respect to the input $X$.

How can we do that by hand without materializing the Jacobian? As we have seen, we cannot just apply the chain rule by materializing the Jacobian, which would be the easiest way, because the Jacobian is a very big matrix that cannot even fit in the memory of the GPU. So we need a smarter way: we exploit the fact that the Jacobian is sparse, hoping to get a formula that does not involve materializing a very big sparse Jacobian. Let's see.
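To put numbers on both claims, here is a small sketch. It assumes the sizes used above ($N = D = 1024$, $M = 2048$) and 4-byte float32 entries to compute the size of the generalized Jacobian. It then materializes the Jacobian explicitly for tiny, made-up sizes and counts how many entries can be non-zero.

```python
# Part 1: size of the generalized Jacobian dY/dX for Y = X @ W with the shapes from the text.
N, D, M = 1024, 1024, 2048
num_elements = N * M * N * D                 # tensor of shape (N, M, N, D)
print(f"{num_elements:,} elements, {num_elements * 4 / 2**40:.0f} TiB as float32")

# Part 2: materialize the Jacobian for tiny, hypothetical sizes and count non-zero entries.
def jacobian_dY_dX(W, n):
    """Return J with J[i][j][k][l] = dY[i][j] / dX[k][l] for Y = X @ W.

    Since Y[i][j] = sum_l X[i][l] * W[l][j], the entry equals W[l][j] when the
    output row i is the same as the input row k, and is exactly zero otherwise.
    """
    d, m = len(W), len(W[0])
    return [[[[W[l][j] if i == k else 0.0 for l in range(d)]
              for k in range(n)]
             for j in range(m)]
            for i in range(n)]

def nonzero_fraction(J):
    values = [v for a in J for b in a for c in b for v in c]
    return sum(v != 0.0 for v in values) / len(values)

n, d, m = 4, 3, 5
W = [[1.0 + l + j for j in range(m)] for l in range(d)]   # all entries non-zero
J = jacobian_dY_dX(W, n)
print(nonzero_fraction(J))   # only entries with i == k survive: a fraction of 1/n
```

Only a fraction $1/N$ of the entries can be non-zero, so with $N = 1024$ it is less than 0.1% of a tensor that would need about 8 TiB in float32.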

When dealing with this kind of derivation, I always recommend making some example tensors. Suppose that $X$ is a tensor of size $N \times D$ with $N = 1$ and $D = 3$, and $W$ is a matrix of shape $D \times M$ with $M = 4$.

As a consequence, $Y$ will have shape $N \times M$, that is, $1 \times 4$. PyTorch will give us the gradient of the loss function with respect to the output of this operator, $Y$, so a tensor of shape $N \times M$. We need to compute the gradient of the loss function with respect to $X$, which must be a tensor of shape $N \times D$. A gradient always has the shape of the input variable, because it is the derivative of the output, a scalar, with respect to each element of the input; so it has the same shape as the denominator.

When dealing with these kinds of problems, I always recommend creating example matrices, working out what happens to the output, and then working out the gradient matrix. So let's do it.

Let's see how the output is computed. The output will be a $1 \times 4$ matrix. The input is $1 \times 3$; let's call its elements $x_{11}, x_{12}, x_{13}$. It is multiplied by the matrix $W$, which is $3 \times 4$, so three rows by four columns:
$$X = \begin{bmatrix} x_{11} & x_{12} & x_{13} \end{bmatrix}, \qquad W = \begin{bmatrix} w_{11} & w_{12} & w_{13} & w_{14} \\ w_{21} & w_{22} & w_{23} & w_{24} \\ w_{31} & w_{32} & w_{33} & w_{34} \end{bmatrix}$$
One row by three columns times three rows by four columns gives an output of one row by four columns. Let me write it smaller, otherwise it will never fit. The first element of the output is $x_{11}w_{11} + x_{12}w_{21} + x_{13}w_{31}$, and the second element is $x_{11}w_{12} + x_{12}w_{22} + x_{13}w_{32}$. For the third element, let me move this stuff to the left, otherwise it will never fit.

Okay, I think now it fits. The third element is $x_{11}w_{13} + x_{12}w_{23} + x_{13}w_{33}$, and then we multiply the same row by the last column to get $x_{11}w_{14} + x_{12}w_{24} + x_{13}w_{34}$. So the output of the matrix multiplication is:
$$Y = \begin{bmatrix} x_{11}w_{11} + x_{12}w_{21} + x_{13}w_{31} & x_{11}w_{12} + x_{12}w_{22} + x_{13}w_{32} & x_{11}w_{13} + x_{12}w_{23} + x_{13}w_{33} & x_{11}w_{14} + x_{12}w_{24} + x_{13}w_{34} \end{bmatrix}$$

PyTorch will give us the gradient of the loss, $\frac{\partial \phi}{\partial Y}$. Because it is a gradient, it has the same shape as the denominator, $1 \times 4$. We don't know these values, since PyTorch provides them, so let's give them generic names:
$$\frac{\partial \phi}{\partial Y} = \begin{bmatrix} dy_{11} & dy_{12} & dy_{13} & dy_{14} \end{bmatrix}$$

Now, to compute the downstream gradient that we need to provide to PyTorch, we would have to materialize the Jacobian. Let's write the chain rule formula. We need to provide
$$\frac{\partial \phi}{\partial X} = \frac{\partial \phi}{\partial Y}\cdot\frac{\partial Y}{\partial X}$$
where $\frac{\partial \phi}{\partial Y}$ is provided by PyTorch and $\frac{\partial Y}{\partial X}$ is the Jacobian. Instead of avoiding it, let's actually materialize this Jacobian now and multiply the two quantities to see if something simplifies. $\frac{\partial Y}{\partial X}$ is the derivative of every output $y$ with respect to every input $x$. We have four output elements and three input elements in the $X$ matrix. My screen is not big enough, so let me copy the output here, and remember that $X$ is made of $x_{11}$, $x_{12}$, $x_{13}$.

So $\frac{\partial Y}{\partial X}$ has the following entries. The first row is the partial derivative of the first output, $y_{11}$, with respect to each input. As you can see, $x_{11}$ appears in $y_{11}$ only multiplied by $w_{11}$, so the derivative with respect to $x_{11}$ is $w_{11}$. The derivative of $y_{11}$ with respect to $x_{12}$ is $w_{21}$, and with respect to $x_{13}$ it is $w_{31}$. The second row contains the partial derivatives of the second output, $y_{12}$, with respect to all the inputs, which are $w_{12}$, $w_{22}$ and $w_{32}$. (Let me check that what I'm doing is correct. Yes, I've already done it, so I can always double check.) Then we have the partial derivatives of the third output, $y_{13}$, with respect to all the inputs, which are $w_{13}$, $w_{23}$ and $w_{33}$. Finally come the partial derivatives of the last output, $y_{14}$, which are $w_{14}$, $w_{24}$ and $w_{34}$. We obtain the following Jacobian:
$$\frac{\partial Y}{\partial X} = \begin{bmatrix} w_{11} & w_{21} & w_{31} \\ w_{12} & w_{22} & w_{32} \\ w_{13} & w_{23} & w_{33} \\ w_{14} & w_{24} & w_{34} \end{bmatrix} = W^T$$
As you can see, this Jacobian is just $W$ transposed. So we don't need to materialize the Jacobian: we can take whatever gradient PyTorch gives us, multiply it by $W^T$, and we get the downstream gradient. Let me rewrite it so we know what we are doing:
$$\frac{\partial \phi}{\partial X} = \frac{\partial \phi}{\partial Y}\cdot\frac{\partial Y}{\partial X} = \frac{\partial \phi}{\partial Y}\, W^T$$
So to provide the downstream gradient PyTorch needs, we just take the upstream gradient it gives us and multiply it by $W^T$. This gives us the gradient of the loss function with respect to the input $X$ of the matrix multiplication. In the same way we can derive the formula for the gradient of the loss function with respect to $W$, which is $X^T$ multiplied by the upstream gradient:
$$\frac{\partial \phi}{\partial W} = X^T\,\frac{\partial \phi}{\partial Y}$$
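Both formulas can be checked numerically without ever forming the Jacobian. This sketch assumes a loss of the form $\phi = \sum_{ij} G_{ij} Y_{ij}$ for a fixed matrix $G$. That choice is only a testing trick: it makes $\partial\phi/\partial Y = G$ exactly, so $G$ plays the role of the upstream gradient PyTorch would provide. It uses $N = 2$ rather than 1 to show that the formulas hold for more than one token. The sizes and values are arbitrary.

```python
def matmul(A, B):
    # Plain nested-list matrix product: (n x t) @ (t x m) -> n x m.
    return [[sum(A[i][t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(A):
    return [list(row) for row in zip(*A)]

def loss(X, W, G):
    # phi = sum_ij G_ij * Y_ij, chosen so that dphi/dY is exactly G.
    Y = matmul(X, W)
    return sum(G[i][j] * Y[i][j] for i in range(len(Y)) for j in range(len(Y[0])))

def numerical_grad(f, A, eps=1e-6):
    # Central differences: nudge one entry of A at a time (in place), then restore it.
    grad = [[0.0] * len(A[0]) for _ in A]
    for i in range(len(A)):
        for j in range(len(A[0])):
            original = A[i][j]
            A[i][j] = original + eps
            plus = f()
            A[i][j] = original - eps
            minus = f()
            A[i][j] = original
            grad[i][j] = (plus - minus) / (2 * eps)
    return grad

N, D, M = 2, 3, 4
X = [[0.1 * (i + 2 * j + 1) for j in range(D)] for i in range(N)]
W = [[0.05 * (3 * i - j + 2) for j in range(M)] for i in range(D)]
G = [[0.2 * (i - j) + 0.3 for j in range(M)] for i in range(N)]  # stands in for dphi/dY

dX = matmul(G, transpose(W))   # (N x M) @ (M x D) -> N x D, the shape of X
dW = matmul(transpose(X), G)   # (D x N) @ (N x M) -> D x M, the shape of W

dX_num = numerical_grad(lambda: loss(X, W, G), X)
dW_num = numerical_grad(lambda: loss(X, W, G), W)
max_err = max(abs(a - b)
              for analytic, numeric in ((dX, dX_num), (dW, dW_num))
              for row_a, row_n in zip(analytic, numeric)
              for a, b in zip(row_a, row_n))
print(max_err)
```

The analytic gradients need only two small matrix products, while the Jacobian route would need a tensor of shape $(N, M, N, D)$.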

How do we remember these formulas? There is a mnemonic rule: these are the only possible products that give the first gradient the shape of $X$ and the second the shape of $W$. $\frac{\partial \phi}{\partial Y}$ has the same shape as $Y$, so it is $N \times M$. $W$ is $D \times M$, so $W^T$ is $M \times D$, and the product $\frac{\partial \phi}{\partial Y} W^T$ is $N \times D$, exactly the shape of $X$. In the second case, $X$ is $N \times D$, so $X^T$ is $D \times N$. It is multiplied by $\frac{\partial \phi}{\partial Y}$, which is a gradient and therefore has the shape of the denominator, $N \times M$. The result has shape $D \times M$, which is exactly the shape of $W$. This is the only way the shapes work out. So this is a mnemonic for computing the gradients with respect to the inputs of a matrix multiplication, the input matrix $X$ and the parameter matrix $W$, given the gradient of the loss with respect to its output.

Now we need to derive the gradient of the output of the softmax with respect to its input, because that's another operation in our fused attention: we are fusing several operations together, the matrix multiplications and the softmax. This is the second ingredient we need to understand the backward pass of Flash Attention, so let's do it.

For this derivation I will use the same notation as the Flash Attention paper. First, let's write the title: gradient through the softmax. The first operation during the computation of attention is the product of the query with the transpose of the keys. We do it block-wise, that is, block by block, but that doesn't matter because the end result is the same. So we can write $S = QK^T$. Then we apply the softmax to the result and call the output $P = \text{softmax}(S)$. After applying the softmax, we take its output and multiply it by $V$ to obtain the output:
$$S = QK^T, \qquad P = \text{softmax}(S), \qquad O = PV$$

Now, as I said before, PyTorch autograd works in the following way: it treats our attention computation as a black box, so we will have a computation graph like the following.

We have a query input, a key input and a value input, which are sequences of tokens, each with some embedding dimension. These are fed to a black box called attention, which is our implementation of the attention, the function we started coding before. This node of the computation graph outputs a tensor $O$, and PyTorch will give us the gradient of the loss with respect to this output. As you remember, PyTorch knocks on the door of each operator and says: "If I give you the gradient of the loss with respect to your output, can you give me the gradient of the loss with respect to your inputs?" This is what we need to figure out: given the gradient of the loss with respect to the output, how to compute the gradient of the loss with respect to $Q$, with respect to $K$ and with respect to $V$. However, there is no direct connection between $Q$ and $O$, or between $K$ and $O$, because there are intermediate operations: first a matrix multiplication, then a softmax, then another matrix multiplication. We do have a tool that tells us how the gradient propagates through multiple operations applied in sequence, and it's called the chain rule. However, we have seen that applying the chain rule naively, by materializing the Jacobian, is infeasible.

So we need to understand how to apply the chain rule without materializing the Jacobian, and that's what we are going to figure out for one of the operations inside this attention computation, the softmax. That's why we are going to do this derivation, which I promise is the last one, and then we will finally code the backward pass of Flash Attention. We cannot jump directly to coding the backward pass, because if we just looked at the formulas we would not understand how they are derived.

Okay, now we can start. Let me delete this stuff. Imagine, for simplicity, that we apply the softmax row-wise to this $S$ matrix, so each row is softmaxed independently from the others.
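For reference, here is a naive, fully materialized sketch of the computation being differentiated: $S = QK^T$, a row-wise softmax, and $O = PV$. It is only an illustration with plain Python lists and made-up inputs. It omits the $1/\sqrt{d}$ scaling to match the formulas in this derivation. It also stores the full $S$ and $P$ matrices, which is exactly what the block-wise Flash Attention kernel avoids.

```python
import math

def matmul(A, B):
    # Plain nested-list matrix product: (n x t) @ (t x m) -> n x m.
    return [[sum(A[i][t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(A):
    return [list(row) for row in zip(*A)]

def softmax_row(s):
    # Stable softmax of one row (subtracting the max leaves the result unchanged).
    s_max = max(s)
    e = [math.exp(v - s_max) for v in s]
    total = sum(e)
    return [v / total for v in e]

def naive_attention(Q, K, V):
    S = matmul(Q, transpose(K))          # scores, N x N
    P = [softmax_row(row) for row in S]  # each row of S softmaxed independently
    O = matmul(P, V)                     # output, N x d
    return S, P, O

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 2.0], [0.5, -1.0], [0.0, 0.3]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]]
S, P, O = naive_attention(Q, K, V)
print(S[0], [round(p, 3) for p in P[0]], O[0])
```

Because `softmax_row` is applied to each row separately, the derivation can focus on a single row $s_i$ and its output $P_i$.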

So let's see what happens to one single row of this matrix. For simplicity I could call it just $s$, a single row of the $S$ matrix. I could also call it $s_i$, but then we would have to carry over the index $i$. Okay guys, let's just do it: we will carry over the index.

All right, so let's call $s_i$ one row of the $S$ matrix. In PyTorch tensor notation it would be `S[i, :]`: from the tensor $S$ we take the $i$-th row and all the columns. This is the definition of $s_i$. I know it's ugly notation, but it helps you understand. It is a vector of $N$ dimensions, $s_i \in \mathbb{R}^N$. We apply the softmax to this vector and obtain an output vector, which we call $P_i$:
$$P_i = \text{softmax}(s_i)$$
As we have seen, the softmax does not change the shape of its input; it just transforms each number, so the output is also a vector in $\mathbb{R}^N$.

Now, what is the softmax? The $j$-th element of the vector $P_i$ is the exponential of the $j$-th element of the vector $s_i$, divided by a normalization factor. For the summation index we use neither $j$ nor $k$; let's use $l$:
$$P_{ij} = \frac{e^{s_{ij}}}{\sum_{l=1}^{N} e^{s_{il}}}$$

All right. First of all, you may be wondering whether this is really the softmax we apply during the forward pass of the attention. If you remember, we applied a softmax where each argument of the exponential is reduced by the maximum element of the vector, $s_{i,\max}$, and the arguments in the denominator are reduced by $s_{i,\max}$ as well. It was more or less like this:
$$P_{ij} = \frac{e^{s_{ij} - s_{i,\max}}}{\sum_{l=1}^{N} e^{s_{il} - s_{i,\max}}}$$
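The practical difference between the two forms shows up in a small Python sketch. With moderate scores, both give the same probabilities. With large scores, the direct definition overflows, because `math.exp` raises `OverflowError` once its argument exceeds roughly 709.78, while the max-subtracted version does not. The input values are made up for illustration.

```python
import math

def softmax_naive(s):
    # Direct definition: math.exp raises OverflowError for arguments above about 709.78.
    e = [math.exp(v) for v in s]
    total = sum(e)
    return [v / total for v in e]

def softmax_stable(s):
    # Subtract the row maximum first: every exponent is <= 0, so exp never overflows.
    s_max = max(s)
    e = [math.exp(v - s_max) for v in s]
    total = sum(e)
    return [v / total for v in e]

def overflows(f, s):
    # Report whether evaluating f on s raises OverflowError.
    try:
        f(s)
        return False
    except OverflowError:
        return True

moderate = [1.0, 2.0, 3.0]
large = [1000.0, 1001.0, 1002.0]
print(softmax_naive(moderate))
print(softmax_stable(moderate))
print(overflows(softmax_naive, large), softmax_stable(large))
```

The stable version gives identical results for `large` and `moderate`, because after subtracting the maximum both rows become `[-2.0, -1.0, 0.0]`.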

However, we also proved that this is equivalent to the standard softmax without the reduction in the argument. The reduction is only there to make the computation numerically safe. From a mathematical point of view it is equivalent to do it without, although on a computer the version without the reduction can become numerically unstable. This also means that it doesn't matter how you compute the forward pass: if it is equivalent to another mathematical definition, you can always use that other definition to compute the backward pass, and it will give the same value.

If you didn't understand what I said, let me give you a simpler example. Do you remember this formula from high school: $\cos^2(x) + \sin^2(x) = 1$? Now imagine we compute an output $y = \cos^2(x)$ and then need the derivative of $y$ with respect to $x$. It doesn't matter whether you compute it as the derivative of $\cos^2(x)$ or as the derivative of $1 - \sin^2(x)$: both give exactly the same result, $-2\sin(x)\cos(x)$, because the two definitions are equivalent. This is why we don't need to include the max term in the exponential for the derivation: the two definitions are mathematically equivalent. We only use the numerically safe one because, when computing on an actual computer, we need something numerically stable that will not overflow.

All right, now what do we want to obtain? We want the gradient of the loss with respect to the input vector of the softmax, $s_i$, given the gradient of the loss with respect to the output of the softmax, $P_i$. We can obtain it with the chain rule, multiplying by the Jacobian $\frac{\partial P_i}{\partial s_i}$:
$$\frac{\partial \phi}{\partial s_i} = \frac{\partial \phi}{\partial P_i}\cdot\frac{\partial P_i}{\partial s_i}$$
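Before deriving this Jacobian analytically, one way to see what it looks like is to estimate it numerically. The sketch below uses made-up values. It builds $\partial P_i/\partial s_i$ with central finite differences for a three-element row, then applies the chain rule as a vector-Jacobian product with a hypothetical upstream gradient $\partial\phi/\partial P_i$.

```python
import math

def softmax(s):
    s_max = max(s)
    e = [math.exp(v - s_max) for v in s]
    total = sum(e)
    return [v / total for v in e]

def numerical_jacobian(f, s, eps=1e-6):
    """J[j][k] ~ d f(s)[j] / d s[k], estimated with central differences."""
    n = len(s)
    J = [[0.0] * n for _ in range(n)]
    for k in range(n):
        plus = list(s)
        plus[k] += eps
        minus = list(s)
        minus[k] -= eps
        f_plus, f_minus = f(plus), f(minus)
        for j in range(n):
            J[j][k] = (f_plus[j] - f_minus[j]) / (2 * eps)
    return J

def vector_jacobian_product(g, J):
    # Row vector g (dphi/dP_i) times the Jacobian dP_i/ds_i gives dphi/ds_i.
    return [sum(g[j] * J[j][k] for j in range(len(g))) for k in range(len(J[0]))]

s_i = [0.5, -1.0, 2.0]           # one row of S (illustrative values)
J = numerical_jacobian(softmax, s_i)
upstream = [0.1, -0.3, 0.2]      # a made-up dphi/dP_i
print(J)
print(vector_jacobian_product(upstream, J))
```

Unlike the matrix-multiplication Jacobian, this one has no structural zeros: every output of the softmax depends on every input through the shared denominator.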

The chain rule is always valid, so let's see what this Jacobian, $\frac{\partial P_i}{\partial s_i}$, looks like. Let's look at a single element of it: the derivative of the $j$-th output element with respect to the $k$-th input element. Remember what the Jacobian is: each element of the numerator of this fraction differentiated with respect to each element of the denominator. In other words, it holds the derivative of each element of the output vector with respect to each element of the input vector. So how is the output obtained? By the definition of the softmax, $P_{ij}$ is $e^{s_{ij}}$ divided by the normalization factor $\sum_{l=1}^{N} e^{s_{il}}$, and we differentiate all of it with respect to $s_{ik}$:
$$\frac{\partial P_{ij}}{\partial s_{ik}} = \frac{\partial}{\partial s_{ik}}\left(\frac{e^{s_{ij}}}{\sum_{l=1}^{N} e^{s_{il}}}\right)$$

To picture it, suppose the $P$ vector has three elements, $P_{11}$, $P_{12}$ and $P_{13}$, and the $s$ vector also has three elements, $s_{11}$, $s_{12}$ and $s_{13}$. The first row of the Jacobian is the derivative of $P_{11}$ with respect to each element of the input vector. The second row is the derivative of $P_{12}$ with respect to each input element, and the third row is the derivative of $P_{13}$ with respect to each element of the $s$ vector. We are trying to understand what the generic element of this Jacobian looks like, where the index $j$ refers to the output vector and the index $k$ to the input vector.

This element is the derivative of a fraction of two functions. We know from high school that the derivative of a fraction of two functions with respect to $x$ is:
$$\left(\frac{f(x)}{g(x)}\right)' = \frac{f'(x)\,g(x) - g'(x)\,f(x)}{g(x)^2}$$

Now let's apply it here. There are two cases. Either the variable we differentiate with respect to, $s_{ik}$, has the same index as the element being differentiated, as in $P_{11}$ with respect to $s_{11}$, or it has a different index, as in $P_{11}$ with respect to $s_{12}$ or $s_{13}$. Suppose first that we are differentiating $P_{11}$ with respect to $s_{11}$, or $P_{12}$ with respect to $s_{12}$, or $P_{13}$ with respect to $s_{13}$. That is, an element of the output with respect to the element of the input with the same index, $j = k$.

In this case the derivative looks like the following. The derivative of the numerator $e^{s_{ij}}$ with respect to $s_{ij}$ is $e^{s_{ij}}$, because the derivative of $e^{x_1}$ with respect to $x_1$ is $e^{x_1}$ (let me reduce the size now). We multiply that by the denominator of the fraction, the summation over all possible $l$ of $e^{s_{il}}$. Then we subtract the derivative of the denominator with respect to the variable being differentiated. The denominator is the sum of the exponentials of all the input elements. If we differentiate it with respect to one particular input element, exactly one term contains that element and all the other terms give zero. So the only derivative that survives is that of $e^{s_{ik}}$ with respect to $s_{ik}$, and we write minus $e^{s_{ik}}$ multiplied by the numerator $e^{s_{ij}}$. All of this is divided by the denominator squared, that is, the summation $\sum_{l=1}^{N} e^{s_{il}}$ to the power of two:
$$\frac{\partial P_{ij}}{\partial s_{ik}}\bigg|_{j=k} = \frac{e^{s_{ij}}\sum_{l=1}^{N} e^{s_{il}} - e^{s_{ik}}\,e^{s_{ij}}}{\left(\sum_{l=1}^{N} e^{s_{il}}\right)^2}$$
and this stuff

here will be equal to well we can see

that the this two term this one and this

one have a one term factor in common

which is e to the power of s j so we can

collect that so e to the power of s i j

multiplied by the

summation minus e to the power of s i

k all this divided by the denominator uh

which is the power of two of this stuff

here so let me just copy and paste it

which is let me rotate it also because I

don't know why I always write

little

little yeah all right and this stuff

here is equal to well uh we can separate

the two terms so we can separate this

term here and this term here because the

denominator is to the power of

two so we can write it also as e to the

power of s

j mided by the denominator so which is

summation of l = 1 up to n e to the^ of

s i

l uh multiplied by this stuff here so

this stuff here divided by the same

denominator so there's summation of l =

1 up to n e to the^ of s i

l minus E to the^ of s i

k i Am s i k divided by the same uh

denominator

s i

l now this one can be written as this

stuff here is nothing more than the

output element P J because this one is

just the soft Max applied to the SI J

element which we know that the output of

the softmax applied to the SI element is

called P because it's one element of the

output Vector which we called P so this

stuff here is equal to p i

j multiplied by

this stuff here will be equal to 1 minus

this stuff here what is this stuff here

is the output of the soft Max appli to

the s i k element so it will be p i k so

it is equal to 1 minus p i

k okay and this is in the case the the

the variable with respect to which we

derive has the same index as the

numerator in this uh fraction

here uh in this derivative here uh the

other case is when the two variables so

the the output the index of the output

with respect to the index of the input

are not the same in this case we will

have another case so we will have that

J uh let me write it again so this stuff

here hope I can copy it all

without in the other case in which s is

not equal to

J uh not s it's j not equal to K so J is

not equal to K what

happens in this case it will be well uh

the derivative of the numerator because

we need to apply again this formula here

so derivative of the numerator with

respect to something that is not the

same variable it will be zero because

it's like Computing the derivative e to

the^ of X1 with respect to X X2 it will

be zero so it will be zero so all the

first term here will become zero no

matter what is g of x minus the

derivative of the denominator of this

fraction here with respect to the

variable SI

K uh G Prime of s i k so this is all the

variable in the input and we are

deriving it with respect to one

particular variable of the input so only

one item in the summation will survive

so it will be the item s i

k so it will be e to the power of s i k

multiplied by f of x which is the

numerator in this fraction which is e to

the power oh we forgot a minus e to the

power of s

j uh let me see if I forgot something

all divided by the denominator of this

fraction here to the power of two

so it is equal to the

summation l = 1 up to n of e to the

power of s i l all to the power of

two uh I believe I didn't forget

anything so let's continue so here also

we can see that this one here is because

okay let's separate it minus E to the^

of s i k divided by the summation l = 1

up to n of e^ of s i l multiplied by e^

of s i j divided by the summation l = 1

up to n of e to the power of s i

l this stuff here is nothing more than

the soft Max appli to the K element of

the SI Vector this one here is nothing

more than the softmax applied to the J

element of the s I Vector so we know

what these are we know that we call them

P minus p i k p i

j so in the end we have two cases one is

the derivative of this stuff here looks

like the

following each item in the Jacobian

looks like the following when the

numerator and the denominator have the

same index so J equal to K this stuff

here is equal to now this notation here

is wrong so I shouldn't be writing it

with equal sign but doesn't matter guys

it's we are doing a

little um okay so p i

j p i j multiplied by 1 minus p i k let

me check yes the other Cas is when the J

is not equal to k then this stuff here

let me write it like this

will be equal to minus p i k MTI p i
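To double-check the two cases, here is a minimal NumPy sketch (NumPy and the helper names `softmax` and `dP_dS` are illustrative assumptions, not the video's code) that compares each Jacobian element against a central finite difference. It works on a single row, so the row index i is dropped: `p[j]` plays the role of $P_{ij}$ and `s[k]` the role of $S_{ik}$.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())   # subtracting the max does not change the result
    return e / e.sum()

def dP_dS(s, j, k):
    """One Jacobian element dP_j / dS_k, using the two cases j == k and j != k."""
    p = softmax(s)
    if j == k:
        return p[j] * (1 - p[k])
    return -p[k] * p[j]

# Compare every element with a central finite difference
s = np.array([0.3, -1.2, 2.0])
eps = 1e-6
for j in range(3):
    for k in range(3):
        s_plus, s_minus = s.copy(), s.copy()
        s_plus[k] += eps
        s_minus[k] -= eps
        numeric = (softmax(s_plus)[j] - softmax(s_minus)[j]) / (2 * eps)
        assert abs(numeric - dP_dS(s, j, k)) < 1e-8
print("both cases match the finite differences")
```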

Now that we know what the two typical cases of this Jacobian look like, let's look at what the Jacobian looks like in matrix form. It will be an n × n matrix, where n is the size of the input vector and of the output vector. As you remember, the first row of the Jacobian in the numerator convention is the derivative of the first output with respect to all the inputs. So the first element is the derivative of $P_{11}$ with respect to $S_{11}$; here j and k match, so it will be $P_{11}(1 - P_{11})$. The second element, to its right, so element (1, 2), is the derivative of $P_{11}$ with respect to $S_{12}$; j and k do not match, so we are in the second case and it will be $-P_{11}P_{12}$. The third element, you can check it yourself, will be $-P_{11}P_{13}$, and so on until the last one, $-P_{11}P_{1n}$.

The second row of this Jacobian looks like this. Its first element is the derivative of $P_{12}$ with respect to $S_{11}$; j and k do not match, so we are in the second case and it is $-P_{12}P_{11}$. The next element is the derivative of $P_{12}$ with respect to $S_{12}$; j and k match, so we are in the first case and it is $P_{12}(1 - P_{12})$. The third element is $-P_{12}P_{13}$, and so on until the last one, $-P_{12}P_{1n}$ (a product, not a derivative). All the rows look like this until the last one. The first element of the last row is the derivative of the last output element with respect to the first input element, so the derivative of $P_{1n}$ with respect to $S_{11}$: the two indices do not match, so we are in the second case and it is $-P_{1n}P_{11}$. Then comes $-P_{1n}P_{12}$ and, since we are here, the third element $-P_{1n}P_{13}$, and so on until the last element of the last row, where the two indices match, so it is $P_{1n}(1 - P_{1n})$. This is what the Jacobian looks like:

$$J = \begin{bmatrix} P_{11}(1-P_{11}) & -P_{11}P_{12} & -P_{11}P_{13} & \cdots & -P_{11}P_{1n} \\ -P_{12}P_{11} & P_{12}(1-P_{12}) & -P_{12}P_{13} & \cdots & -P_{12}P_{1n} \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ -P_{1n}P_{11} & -P_{1n}P_{12} & -P_{1n}P_{13} & \cdots & P_{1n}(1-P_{1n}) \end{bmatrix}$$

Let's see if we can find a better way to generate this Jacobian with some pattern recognition. The first thing we can notice is that the Jacobian is symmetric: this element is equal to this one, if you expand the third row you will see that its first element equals this one, and the element in the top-right corner is equal to the one in the bottom-left corner. The second thing we can notice is that only the elements on the diagonal are different, because they have an additional term. Look at this element: $P_{11}(1 - P_{11})$ can also be written as $P_{11} - P_{11}P_{11}$, and the second diagonal element is $P_{12} - P_{12}P_{12}$. So the diagonal elements look just like the other elements; they only have an additional term, $P_{11}$ in the first diagonal element, $P_{12}$ in the second, and so on. So, apart from the sign and these extra diagonal terms, this matrix contains all the possible products $P_{ij}P_{ik}$, which we can obtain with an outer product, that is, the product of a column vector with its own transpose. Imagine p is a column vector: if you compute $pp^\top$, you obtain all the possible products between its elements. Let me do a simple case: the column vector $(P_1, P_2, P_3)$ multiplied by the row vector $(P_1, P_2, P_3)$. This is 3 × 1 times 1 × 3, so it generates a 3 × 3 matrix whose entries are $P_1P_1$, $P_1P_2$, $P_1P_3$, and so on.

Moreover, on the diagonal of the matrix we have the additional term: $P_1$ in the first diagonal element, $P_2$ in the second, $P_3$ in the third. Actually, calling them $P_1$ is not quite right (that's why I didn't want to bring in the i index): they should be $P_{i1}$, $P_{i2}$, and so on, because we are doing this for a generic row vector $P_i$. So let me fix the indices on the board: $P_{i1}, P_{i2}, P_{i3}, \ldots, P_{in}$.

So we can write this Jacobian as the diagonal matrix that has on its diagonal all the elements of the vector $P_i$, minus the vector $P_i$ multiplied by its own transpose. We need the outer product because every element is a combination of one element of P with another element of P; only on the diagonal do we need the additional term, which is just the corresponding element of P; and all the elements of $P_iP_i^\top$ are negated, which is why we need the minus sign:

$$\frac{\partial P_i}{\partial S_i} = \operatorname{diag}(P_i) - P_iP_i^\top$$

If you look at the Flash Attention paper, they give you exactly this formula: if $y = \operatorname{softmax}(x)$, then the Jacobian is $\operatorname{diag}(y) - yy^\top$, where y is a column vector.
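A small NumPy sketch of this matrix form, assuming a single input row and illustrative helper names: it builds $\operatorname{diag}(y) - yy^\top$ with `np.diag` and `np.outer`, checks it against the element-by-element cases $P_{ij}(1-P_{ik})$ and $-P_{ik}P_{ij}$, and confirms the symmetry noticed earlier.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())          # shift by the max for numerical stability
    return e / e.sum()

def softmax_jacobian(x):
    """Jacobian of y = softmax(x) in matrix form: diag(y) - y y^T."""
    y = softmax(x)
    return np.diag(y) - np.outer(y, y)

def softmax_jacobian_by_cases(x):
    """Same Jacobian built element by element from the cases j == k and j != k."""
    p = softmax(x)
    n = len(p)
    J = np.empty((n, n))
    for j in range(n):
        for k in range(n):
            J[j, k] = p[j] * (1 - p[k]) if j == k else -p[k] * p[j]
    return J

x = np.array([0.5, -1.0, 2.0, 0.0])
J = softmax_jacobian(x)
print(np.allclose(J, softmax_jacobian_by_cases(x)))  # True
print(np.allclose(J, J.T))                           # True: the Jacobian is symmetric
```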

All right, guys, I know this has been long, so let's take a pause, and then we are finally going to code. First of all, let's check the mathematics of the backward pass of Flash Attention. We will see it briefly: I will not do any more derivations, but I will explain it, and then we finally switch to coding it. So let's go.

All right, guys, now we can finally look at the backward pass of Flash Attention. We will be looking at the algorithm, and if you look at the appendix of the Flash Attention paper, you will see Section B.2, where they derive the backward pass step by step. I don't want to go through every step of this derivation because it would take too long, but I want to give you all the tools necessary to understand it. Let's start with the conventions and notation used in the paper. The first thing we need to rehearse is the name of each matrix. As you know, in the forward pass we multiply the query by the transpose of the key, and we call the output S. Then we apply the softmax to this S matrix, row by row, and it becomes the P matrix. Then we take this P matrix and multiply it by the V matrix to obtain the output of the attention.
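For reference, the naming above (S, P, O) can be sketched as a naive NumPy function. This minimal sketch materializes the full N × N matrices, which is exactly what Flash Attention avoids; the function name, the optional `scale` argument and the use of NumPy are assumptions for illustration, not the video's code.

```python
import numpy as np

def naive_attention_forward(Q, K, V, scale=1.0):
    """Reference forward pass that materializes S and P."""
    S = scale * Q @ K.T                       # (N, N) attention scores
    S_max = S.max(axis=1, keepdims=True)      # subtract the row max for numerical stability
    E = np.exp(S - S_max)
    P = E / E.sum(axis=1, keepdims=True)      # softmax applied row by row
    O = P @ V                                 # (N, d) attention output
    return S, P, O

rng = np.random.default_rng(0)
N, d = 4, 3
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
S, P, O = naive_attention_forward(Q, K, V)
print(P.sum(axis=1))  # each row of P sums to 1
```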

Let's look, for example, at how the i-th row of the output is computed from the P matrix and the V matrix, so we can understand the notation they use in the paper. The way I read this formula: the i-th row of the output is written as a column vector, because in linear algebra, whenever we write the name of a vector, it is by convention a column vector, even though this particular vector is actually a row of the output matrix. Let's try to understand what a row of the output of a matrix multiplication looks like, so that we can understand how to go from here to here.

So let's write a generic matrix multiplication. Let me zoom in and write smaller so we have enough space. We have a matrix A, and we write only one of its rows, let's call its elements a1, a2, a3: the matrix has many rows, but we want to study the effect of only one row. Let's say A has n rows and three columns. We multiply it by another matrix B, which must have three rows and some number of columns, let's say four. Its first row is $b_{11}, b_{12}, b_{13}, b_{14}$; the second row is $b_{21}, b_{22}, b_{23}, b_{24}$; the third row is $b_{31}, b_{32}, b_{33}, b_{34}$. I know I'm not very rigorous with my notation: I should have written these elements with the capital letters A and B, which is the notation for a single item of a matrix, but please forgive me for this.

The output of this matrix multiplication is another matrix, n by 4, so each output row has four columns. I want to write only the first output row, as a vector, and understand what each of its dimensions is, because otherwise I don't have enough space to write it here. Let's call the output O; I want to write $o_1$ (with a small letter), which is the first row of the output written as a column vector. Its first dimension is the dot product of the first row of A with the first column of B. Actually, I should call the elements of A $a_{11}$, $a_{12}$ and $a_{13}$, because A has many rows, so let me use the correct naming. The first dimension is then $a_{11}b_{11} + a_{12}b_{21} + a_{13}b_{31}$. The second dimension is the dot product of this row of A with the second column of B: $a_{11}b_{12} + a_{12}b_{22} + a_{13}b_{32}$. The third is $a_{11}b_{13} + a_{12}b_{23} + a_{13}b_{33}$, and the fourth is $a_{11}b_{14} + a_{12}b_{24} + a_{13}b_{34}$.

This is the first output row of O, a vector called $o_1$, with its first, second, third and fourth dimensions, each of which is a scalar. This vector can also be written as a sum of vectors. As you can see, $a_{11}$ multiplies a different element of B in every dimension: $b_{11}, b_{12}, b_{13}, b_{14}$. And what is $b_{11}, b_{12}, b_{13}, b_{14}$? It is the first row of B, which, borrowing PyTorch's slicing notation, we can write as $B_{1,:}$. Then we add $a_{12}$ multiplied by $b_{21}, b_{22}, b_{23}, b_{24}$, the second row of B, $B_{2,:}$; this is a scalar-vector product. Then we add $a_{13}$ multiplied by $B_{3,:}$, the third row of B:

$$o_1 = \begin{bmatrix} a_{11}b_{11} + a_{12}b_{21} + a_{13}b_{31} \\ a_{11}b_{12} + a_{12}b_{22} + a_{13}b_{32} \\ a_{11}b_{13} + a_{12}b_{23} + a_{13}b_{33} \\ a_{11}b_{14} + a_{12}b_{24} + a_{13}b_{34} \end{bmatrix} = a_{11}B_{1,:} + a_{12}B_{2,:} + a_{13}B_{3,:}$$

So this can also be written as a summation over j from 1 to 3, where 3 is the number of columns of A. For the generic i-th row of the output, each of $a_{i1}$, $a_{i2}$, $a_{i3}$ multiplies the corresponding row of B:

$$o_i = \sum_{j=1}^{3} a_{ij}\, B_{j,:}$$

where $B_{j,:}$ is the j-th row of B. We can also mark it with vector notation to show that it is a vector, and this is exactly what they do in the paper.
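The same decomposition can be checked numerically. This is a sketch with arbitrary random matrices of the shapes used on the board (n rows by 3 columns, times 3 by 4); NumPy uses 0-based indices, so `B[j, :]` is row j+1 in the board's notation.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 3))   # n = 5 rows, 3 columns
B = rng.standard_normal((3, 4))   # 3 rows, 4 columns
O = A @ B                         # (5, 4)

i = 2
# o_i as a sum of the rows of B, each scaled by one element of the i-th row of A
o_i = sum(A[i, j] * B[j, :] for j in range(A.shape[1]))
print(np.allclose(o_i, O[i, :]))  # True
```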

So, in the output matrix, when we do the multiplication P times V, the i-th row of the output, which we call $o_i$, is a vector. By notation it is a column vector, but its elements are actually the elements of the i-th row of O; this is only notation, guys. It is equal to the i-th row of P, the matrix on the left of the multiplication, multiplied by all the columns of the V matrix, which can also be written as the summation over all the elements of the i-th row of P, each multiplied by the corresponding row of V, where $v_j$ is the j-th row of V:

$$o_i = P_{i,:}V = \sum_{j} P_{ij}\, v_j = \sum_{j} \frac{e^{q_i^\top k_j}}{L_i}\, v_j$$

$P_{ij}$ can be written out explicitly because P is the output of the softmax, which is e to the power of the softmax's input element, divided by the normalization factor of that row, $L_i$. And what is the input element of the softmax? It is the query multiplied by the transpose of the keys, so it's a dot product between one query and one key, and that's why you have this term in the exponential.
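As a sketch of the formula for $o_i$ in the attention setting (random Q, K, V and 0-based NumPy indexing are assumptions for illustration; the max subtraction for numerical stability is skipped because the inputs are small), one output row can be rebuilt as a softmax-weighted sum of the rows of V and compared with the full product $PV$:

```python
import numpy as np

rng = np.random.default_rng(1)
N, d = 6, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

i = 3
scores = K @ Q[i]                            # q_i^T k_j for every key j
weights = np.exp(scores)
L_i = weights.sum()                          # normalization factor of row i
P_i = weights / L_i                          # i-th row of P
o_i = sum(P_i[j] * V[j] for j in range(N))   # softmax-weighted sum of the rows of V

S = Q @ K.T
P = np.exp(S) / np.exp(S).sum(axis=1, keepdims=True)
print(np.allclose(o_i, (P @ V)[i]))          # True
```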

So this is the first step in understanding this derivation. Another thing we have studied so far is how to derive the backward pass of the matrix multiplication and of the softmax, so now let's use it, starting with the matrix multiplication. Let's rehearse the formulas. Given a matrix multiplication $Y = XW$ and the gradient of the loss with respect to Y, the output of this operation, we know how to derive the gradient of the loss with respect to each input of the operation, X or W. To get the gradient with respect to X, we take the upstream gradient (the gradient with respect to the output) and multiply it by $W^\top$. To get the gradient with respect to W, we multiply $X^\top$, the transposed input, by the upstream gradient:

$$\frac{\partial \phi}{\partial X} = \frac{\partial \phi}{\partial Y}\, W^\top, \qquad \frac{\partial \phi}{\partial W} = X^\top\, \frac{\partial \phi}{\partial Y}$$

We derived one of these two formulas earlier and not the other, but deriving them follows exactly the same procedure.
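The two reference formulas can be verified with finite differences. In this sketch (illustrative names; NumPy rather than PyTorch), the scalar loss is chosen as $\sum G \odot (XW)$ so that its gradient with respect to $Y = XW$ is exactly the arbitrary matrix G, which therefore plays the role of the upstream gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 3))
W = rng.standard_normal((3, 5))
G = rng.standard_normal((4, 5))   # plays the role of the upstream gradient dL/dY

def loss(X, W):
    # A scalar loss whose gradient with respect to Y = X @ W is exactly G
    return np.sum(G * (X @ W))

dX = G @ W.T     # dL/dX = dL/dY @ W^T
dW = X.T @ G     # dL/dW = X^T @ dL/dY

def numerical_grad(f, M, eps=1e-6):
    """Central finite differences, one element of M at a time."""
    grad = np.zeros_like(M)
    for idx in np.ndindex(M.shape):
        M_plus, M_minus = M.copy(), M.copy()
        M_plus[idx] += eps
        M_minus[idx] -= eps
        grad[idx] = (f(M_plus) - f(M_minus)) / (2 * eps)
    return grad

print(np.allclose(dX, numerical_grad(lambda M: loss(M, W), X), atol=1e-5))  # True
print(np.allclose(dW, numerical_grad(lambda M: loss(X, M), W), atol=1e-5))  # True
```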

In attention, the last product we compute is $O = PV$. What PyTorch gives us as input during the backward pass is the gradient of the loss with respect to the output, and we need to use it to derive the gradient of the loss with respect to Q, K and V, so that these can then be used by the operations that come before ours in the computation graph. But in order to arrive at the gradients with respect to the query, key and value, we need to derive the gradient with respect to each intermediate operation. The last operation is $O = PV$, so computing the gradient of the loss with respect to V, given the gradient of the loss with respect to O, is exactly like computing the gradient with respect to W in the reference matrix multiplication, because V is the matrix on the right. So, just by analogy, guys: this is our reference point and I am only changing the names. The gradient of the loss with respect to V is the transpose of the matrix on the left multiplied by the upstream gradient, which in the paper they write as

$$dV = P^\top dO,$$

the formula you can see here. The other derivation is the gradient with respect to P. dP is like the gradient of the loss with respect to the matrix on the left side of the multiplication, so it's like the gradient with respect to X in the reference formulas, which is the upstream gradient multiplied by the transpose of the other matrix. In the notation of the paper:

$$dP = dO\, V^\top.$$

How they compute this is exactly as in the derivation above: they call $v_j$ the j-th row of the V matrix and write $dv_j = \sum_i P_{ij}\, do_i$. How do we arrive at this formula? Well, let's do it.

From the derivations above, we know what a row of the output of a matrix multiplication looks like. But first of all, let's simplify our life. Every time you see a transpose in a matrix multiplication and you don't like working with it, just give the transposed matrix a different name, work with that name, and once you have derived the formula, substitute the transpose back. Here we have $dV = P^\top dO$, so let's give $P^\top$ a name we haven't used so far; I always use F when it's available. So $dV = F\, dO$. Let's look at the j-th row of the output rather than the i-th: $dv_j$ is a summation over the elements of the j-th row of the first matrix, F. In the paper they sum over i, so let's do the same: it's the sum over all possible i of the i-th element in the j-th row of F, $F_{ji}$, which is a scalar (so this is a scalar-vector multiplication, not a dot product), multiplied by the i-th row of the other matrix, $do_i$. But F is not a matrix we actually have: it's the transpose of P, which means $F_{ji} = P_{ij}$, because transposing a matrix swaps the two indices:

$$dv_j = \sum_i F_{ji}\, do_i = \sum_i P_{ij}\, do_i$$

This is the same formula you see on the right here, and it allows you to compute one output row of the dV matrix.
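A quick NumPy check of $dV = P^\top dO$, $dP = dO\,V^\top$ and of the row-wise form $dv_j = \sum_i P_{ij}\,do_i$. Here P is just a random row-stochastic matrix and dO a random upstream gradient, both assumptions made only to exercise the algebra.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 5, 3
P = rng.random((N, N))
P /= P.sum(axis=1, keepdims=True)        # rows sum to 1, like a softmax output
V = rng.standard_normal((N, d))
dO = rng.standard_normal((N, d))         # upstream gradient dL/dO

dV = P.T @ dO                            # dV = P^T dO
dP = dO @ V.T                            # dP = dO V^T, so dP[i, j] = do_i . v_j

j = 1
dv_j = sum(P[i, j] * dO[i] for i in range(N))   # dv_j = sum_i P_ij do_i
print(np.allclose(dv_j, dV[j]))                  # True
print(np.isclose(dP[2, 3], dO[2] @ V[3]))        # True
```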

And we know that P is just the output of the softmax: the exponential of the softmax's input divided by the normalization factor associated with that row. Since we are iterating through the rows indexed by i, it is the normalization factor associated with row i. So $P = \operatorname{softmax}(S)$, and the i-th row of P is the softmax of the i-th row of S; this is what is written here. We know from our derivation that if the softmax operation has input x and output y, the Jacobian of y with respect to x is $\operatorname{diag}(y) - yy^\top$, a diagonal matrix of the elements of the vector y minus y multiplied by y transposed, and we have also seen that this matrix is symmetric.

However, you may not recognize this formula here, because in the chain rule we always write that the downstream gradient equals the upstream gradient multiplied by the Jacobian:

$$\frac{\partial \phi}{\partial x} = \frac{\partial \phi}{\partial y}\,\frac{\partial y}{\partial x}$$

This only works if you build the Jacobian in the numerator convention, which is one of the two conventions in which you can build a Jacobian; so far we have always used the numerator convention, and with it these gradients are row vectors. If you want to treat them as column vectors instead, you need to take the transpose, that is, build the Jacobian in the denominator convention. That is why this formula multiplies the Jacobian by the upstream gradient rather than the upstream gradient by the Jacobian: here the gradients are treated as column vectors. To turn a row vector into a column vector, you take the transpose of both sides of the equation, so let's actually do it. In a matrix multiplication, $(AB)^\top = B^\top A^\top$: the transpose is applied to each factor independently, but the order of the multiplication is inverted, and remember that matrix multiplication is not commutative. So $\partial\phi/\partial x$, which here they call $dS_i$, treated as a column vector, is equal to the Jacobian $\partial y/\partial x$ in denominator layout multiplied by $\partial\phi/\partial y$ as a column vector:

$$dS_i = \left(\frac{\partial P_i}{\partial S_i}\right)^{\top} dP_i = \left(\operatorname{diag}(P_i) - P_iP_i^\top\right) dP_i$$

where the last step uses the fact that this Jacobian is symmetric, so its transpose is the same matrix. Both $dS_i$ and $dP_i$ are column vectors, and this is what you see here; that's why the Jacobian is on the left side of the upstream gradient.
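Expanding $(\operatorname{diag}(P_i) - P_iP_i^\top)\,dP_i$ gives $P_i \odot (dP_i - P_i^\top dP_i)$, where $P_i^\top dP_i$ is a scalar, so the Jacobian never has to be built explicitly. The sketch below (NumPy, one random row, illustrative names) checks that the matrix form and the element-wise form agree.

```python
import numpy as np

rng = np.random.default_rng(0)
s_i = rng.standard_normal(5)            # one row of S
dP_i = rng.standard_normal(5)           # upstream gradient for the same row of P
P_i = np.exp(s_i) / np.exp(s_i).sum()   # softmax of the row

J = np.diag(P_i) - np.outer(P_i, P_i)   # Jacobian dP_i/dS_i (symmetric, so J.T == J)
dS_i = J @ dP_i                         # Jacobian on the left of the upstream gradient

# Element-wise form that avoids the n x n Jacobian
dS_i_fast = P_i * (dP_i - P_i @ dP_i)
print(np.allclose(dS_i, dS_i_fast))     # True
```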

What else do we need? I know there are a lot of things in this derivation, but I prefer going directly to the code, otherwise I think it's going to be too boring. So let's go to the code, and while writing it I will go back to the formulas and show how what we are doing maps to the formulas in the paper; I think this is the best way. So let's proceed.

All right, guys, now we can finally code the backward pass. Before we code it, let's look at the algorithm of the backward pass as written in the paper; this is the Flash Attention 1 paper. We will follow the structure of the code that is on the Triton website, so it's not my idea to split it like this, but I simplified it, so it's different from the one you can find online: mine is a simplified version, and it works with both causal and non-causal attention.

First, if you look at this algorithm, you can see that we have an outer loop through all the K and V blocks and an inner loop through all the query blocks. However, to compute dQ, the downstream gradient of the loss with respect to the Q matrix, we need to iterate through all the keys, and to compute each dK block we need to iterate through all the queries. So if we follow the loop as it is, we would have to write to the high bandwidth memory, the DRAM of the GPU, at every inner iteration, and that is not very efficient. If we don't want to write, it would require some sort of synchronization between blocks, which is also not very efficient. So we will split this for loop into two parts, because each dQ depends on a loop over all the keys and each dK depends on a loop over all the queries. To compute dK, we fix the K block and iterate through all the Q blocks; then we do another iteration in which we fix the Q block and iterate through all the KV blocks to compute dQ. This is the structure we are going to follow, and it is an idea I took from the original implementation that is on the Triton website.
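To make the loop split concrete, here is a hedged NumPy sketch of the two passes: one fixes a K/V block and iterates over all Q blocks to accumulate dK (and dV, which also needs every query block), and one fixes a Q block and iterates over all K/V blocks to accumulate dQ. It assumes non-causal attention, a block size that divides N, a softmax scale of $1/\sqrt{d}$, and that the forward pass saved the per-row logsumexp of S so P can be recomputed block by block; it also uses $D_i = do_i^\top o_i$ from the paper. It is not the Triton kernel from the video, only a CPU reference compared against a naive backward pass.

```python
import numpy as np

def reference_backward(Q, K, V, dO, scale):
    """Naive backward pass that materializes the full N x N matrices S, P, dP, dS."""
    S = scale * Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    O = P @ V
    dV = P.T @ dO
    dP = dO @ V.T
    dS = P * (dP - (dP * P).sum(axis=1, keepdims=True))
    return O, scale * dS @ K, scale * dS.T @ Q, dV   # O, dQ, dK, dV

def blockwise_backward(Q, K, V, O, dO, lse, scale, block=2):
    """Two-loop backward pass (non-causal, no masking)."""
    N = Q.shape[0]
    D = (dO * O).sum(axis=1)                   # D_i = dO_i . O_i, shared by both loops
    dQ, dK, dV = np.zeros_like(Q), np.zeros_like(K), np.zeros_like(V)

    # Loop 1: fix one K/V block, iterate over all Q blocks -> dK and dV
    for kv in range(0, N, block):
        Kj, Vj = K[kv:kv + block], V[kv:kv + block]
        for q in range(0, N, block):
            Qi, dOi = Q[q:q + block], dO[q:q + block]
            Pij = np.exp(scale * Qi @ Kj.T - lse[q:q + block, None])  # recompute P
            dV[kv:kv + block] += Pij.T @ dOi
            dSij = Pij * (dOi @ Vj.T - D[q:q + block, None])
            dK[kv:kv + block] += scale * dSij.T @ Qi

    # Loop 2: fix one Q block, iterate over all K/V blocks -> dQ
    for q in range(0, N, block):
        Qi, dOi = Q[q:q + block], dO[q:q + block]
        for kv in range(0, N, block):
            Kj, Vj = K[kv:kv + block], V[kv:kv + block]
            Pij = np.exp(scale * Qi @ Kj.T - lse[q:q + block, None])
            dSij = Pij * (dOi @ Vj.T - D[q:q + block, None])
            dQ[q:q + block] += scale * dSij @ Kj
    return dQ, dK, dV

rng = np.random.default_rng(0)
N, d = 6, 4
scale = 1.0 / np.sqrt(d)
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))
O, dQ_ref, dK_ref, dV_ref = reference_backward(Q, K, V, dO, scale)
S = scale * Q @ K.T
lse = np.log(np.exp(S).sum(axis=1))   # per-row logsumexp, assumed saved by the forward pass
dQ, dK, dV = blockwise_backward(Q, K, V, O, dO, lse, scale)
print(np.allclose(dQ, dQ_ref), np.allclose(dK, dK_ref), np.allclose(dV, dV_ref))  # True True True
```

Splitting the work this way means each pass only accumulates into the block it has fixed, which is the property the video relies on to avoid writing partial results to HBM at every inner iteration.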

Another thing we can notice here is that to compute dQ and dK, that is, a $dq_i$ vector and a $dk_j$ vector, we need a quantity called $D_i$, which is shared between the two, so we can precompute it once and reuse it for both. What is this $D_i$? It is introduced here, and it's the dot product of the vector $do_i$ with the vector $o_i$. So the first thing we will do is a loop over all the vectors in O and dO, computing their dot products to obtain the $D_i$ elements. Then we will use these $D_i$ elements to compute dQ and dK, with two more loops: one in which we fix Q and iterate through all the keys, and one in which we fix the keys and iterate through all the queries. So let's start.
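Why $D_i$ is cheap to precompute: $\sum_j P_{ij}\,dP_{ij} = \sum_j P_{ij}\,(do_i^\top v_j) = do_i^\top o_i$, so the row-wise term needed in $dS$ can be obtained from O and dO alone, without any N × N matrix. A small NumPy sketch with random inputs (an assumption for illustration) checks the identity:

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 5, 3
Q, K, V, dO = (rng.standard_normal((N, d)) for _ in range(4))
S = Q @ K.T
P = np.exp(S) / np.exp(S).sum(axis=1, keepdims=True)
O = P @ V
dP = dO @ V.T

D_from_O = (dO * O).sum(axis=1)   # one dot product per row: dO_i . O_i, O(N d) work
D_from_P = (dP * P).sum(axis=1)   # same quantity via the N x N matrices
print(np.allclose(D_from_O, D_from_P))  # True
```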

Now that we know more or less the structure of the code, we start by writing this backward function here. Let me check... yes, okay. Do you remember this? These are the saved tensors: all the information we saved during the forward pass in order to compute the backward pass. To optimize memory utilization, in Flash Attention we don't save the query multiplied by the transpose of the key, because that would be a sequence-by-sequence matrix, too big to save into the HBM (the DRAM) during the forward pass and then load back from the HBM into local memory. I want to remind you that in Triton, compared to CUDA, we load data from the high bandwidth memory into the shared memory (the SRAM), do all the operations there, and then, when we call the store method, save the elements from the shared memory back into the high bandwidth memory. So we avoid materializing this S matrix in its entirety, saving it to the HBM and then retrieving it, which could be very slow. Secondly, it is very expensive in memory, because nowadays we compute attention over thousands and thousands of tokens: imagine saving a matrix that is 5,000 by 5,000. That's a big matrix to store for each batch and for each head, so it would be far too expensive to save.

save so the idea in Flash attention is

to recompute what we can compute on the

fly during the backward pass because

anyway if we were to load it it would be

Memory IO found so it's faster to

recompute than to save it and restore it

from the memory this is the idea of

flashh

attention Okay so we saved some stuff

during the forward pass and now we can

access it back during the backward pass
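To put a rough number on the cost of materializing S for the 5,000-token example, here is the arithmetic, assuming fp16 storage (2 bytes per element). The batch size of 8 and the 32 heads in the second call are made-up illustration values, and score_matrix_bytes is a hypothetical helper.

```python
def score_matrix_bytes(seq_len, batch_size=1, num_heads=1, bytes_per_elem=2):
    """Bytes needed to store S = Q K^T for every (batch, head); fp16 (2 bytes) by default."""
    return seq_len * seq_len * batch_size * num_heads * bytes_per_elem


per_head = score_matrix_bytes(5_000)                              # one batch, one head
total = score_matrix_bytes(5_000, batch_size=8, num_heads=32)     # hypothetical model config
print(f"{per_head / 1e6:.0f} MB per (batch, head), {total / 1e9:.1f} GB in total")
```

This prints 50 MB per (batch, head) and 12.8 GB in total, and that is only the write: the backward pass would have to read all of it back from HBM.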

The tensors we saved are stored in the context, which is a kind of dictionary made available by PyTorch. All right, so we get back the query, key, and values, and, as you know, during autograd PyTorch will just give us the gradient of the loss with respect to the output of our implementation of the attention, our TritonAttention. Then we need to compute dQ, dK, and dV using only the gradient of the loss with respect to the output.

We do some checks. Here, I know I could optimize this code and make it even smaller: for example, I check the strides here, but inside the code I always assume that the strides are the same. It doesn't matter: I just took the code from Triton and tried to simplify it. My goal was to simplify it, not to optimize it.

So, all right, we create the tensors in which we will store the result of this backward pass, which are dQ, dK, and dV. As you know from the definition of the gradient, the gradient has the same size as the tensor with respect to which we compute it: the numerator is always a scalar, and we compute the gradient with respect to every element of the input, so the gradient itself is a tensor of the same size as the input. Then we get some information on the batch size, blah blah blah.

Later we will see what the number of warps and the number of stages are; I will not explain them in detail now. The number of warps indicates how many threads we want to launch for each program of our grid (a warp is a group of 32 threads), and the number of stages is the number of stages used in software pipelining. We will see what software pipelining is later, when we talk about autotuning.

Then we define some blocks. In the original code I think they call them BLOCK_KV1, BLOCK_KV2, BLOCK_Q1, and BLOCK_Q2. I found that confusing, so I call them BLOCK_SIZE_MACRO and BLOCK_SIZE_MICRO, after what we fix and what we iterate over: first we fix the query block and iterate through all the keys, and then we fix the keys and values block and iterate through the queries. The one we iterate over is the micro one, and the one we fix is the macro one. This is the naming I am using.

Then, as I said before, we need to precompute the D_i elements that we saw in the paper. That's the first kernel we are going to launch, and this kernel has its own launch grid, because later we want to tune this kernel with respect to its own parameters (we will talk about tuning later). So let me see what we are going to do here: the first kernel that we launch is this preprocess kernel.

This preprocess kernel will precompute all the D_i elements that we need to compute dQ and dK, and each D_i element depends only on O and dO. So let's do it and create another function called backward preprocess. What is the preprocess grid? It is the launch grid of this function, of this kernel: it will be launched independently for each batch and for each head, and, moreover, each program will work with a block of vectors of O. How many vectors of O? BLOCK_SIZE_MACRO of them, so 128 vectors of O. Let me copy the signature of this function; it's here, so let's write it here, it's fine.
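A sketch of how this launch grid could be computed on the host, assuming SEQ_LEN is a multiple of BLOCK_SIZE_MACRO: axis 0 indexes the block of O rows and axis 1 the combined batch/head index, as described above. The helper name preprocess_grid is hypothetical.

```python
def preprocess_grid(batch_size, num_heads, seq_len, block_size_macro=128):
    """Launch grid for the preprocessing kernel (hypothetical helper).

    Axis 0: which block of BLOCK_SIZE_MACRO rows of O a program handles.
    Axis 1: the combined (batch, head) index.
    """
    assert seq_len % block_size_macro == 0, "sketch assumes SEQ_LEN is a multiple of the block size"
    return (seq_len // block_size_macro, batch_size * num_heads)


print(preprocess_grid(batch_size=2, num_heads=8, seq_len=1024))   # (8, 16)
```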

Yeah, okay. This function takes the matrix O, so a pointer to the matrix O; a pointer to dO; and a pointer to the matrix D, where we will store these D_i elements. We have one of them for each vector in the output; that's why the shape of this D is (batch size, number of heads, sequence length): one for each output vector of the attention.

This D, where is it... actually it's not this one, it's this one, yeah: it has the same shape as M, which, as you can see, is batch size, number of heads, and sequence length. M, if you remember, is the matrix that we saved during the forward pass, which includes the normalization factor of the softmax and also the maximum element, but in log-sum-exp format, so that when we apply it, it automatically subtracts the maximum element of each row and also normalizes at the same time, which I think I proved previously.
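The log-sum-exp trick described above can be checked numerically in a few lines of NumPy, assuming M stores max + log(sum(exp(s − max))) for each row: subtracting M once both shifts by the row maximum and divides by the normalization factor. The example row is made up, with values large enough that a naive exp would overflow.

```python
import numpy as np

s = np.array([3.0, 1.0, 1000.0, 999.0])       # one row of scores, with values where exp() overflows
m = s.max()
M = m + np.log(np.sum(np.exp(s - m)))         # the per-row value the forward pass keeps

p_one_step = np.exp(s - M)                    # shift by the max and normalize in one subtraction
p_two_steps = np.exp(s - m) / np.sum(np.exp(s - m))
```

Both versions give the same probabilities, but the backward pass only needs to read one number per row.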

So let me do it; we write it like this. We extract the index of this program. This program has two indices, like identifiers; this is equivalent to the block index in CUDA, and this one is along axis zero. So let's see what we launched on axis zero: on axis zero of this launch grid we defined which block of vectors of O this particular program will work with, and the second axis is which batch and which head inside each batch this particular program will work with. So this identifies the block index of Q, that is, which group of vectors in the O matrix this particular program will work with. It's called Q here, I believe, because I copied it from the original code, where they call it Q, but I could also have called it O.

Basically, this means that this program needs to skip the query vectors that have been or will be processed by other programs in parallel, so it only works with the vectors of O that have the following indices. Imagine that the query block size is (I think it's 128 the way we defined it, but suppose it's four for simplicity). The number of query vectors is the sequence length; imagine there are, I don't know, 64 in total, and the first 32 are managed by other programs. Then this particular offs_q will be 32, 33, 34, and 35, counting from zero. This tells me which vectors, among all the vectors in the output matrix O, this particular program is going to work with.

Okay, then we also extract the combined batch and head index, which tells us which batch, and which head in that batch, this particular program is going to work with; that is dimension one of our launch grid.
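A small sketch of how offs_q works out for the example above, assuming a block size of 4 (the real code uses 128) and zero-based indices. program_rows is a hypothetical helper mirroring block_index_q * BLOCK_SIZE_Q + arange(BLOCK_SIZE_Q).

```python
import numpy as np


def program_rows(block_index_q, block_size_q):
    """offs_q for one program: the rows of O it owns (hypothetical helper)."""
    return block_index_q * block_size_q + np.arange(block_size_q)


# Block size 4, the first 32 rows belong to programs 0..7, so this is program 8.
print(program_rows(block_index_q=8, block_size_q=4))   # [32 33 34 35]
```

Taken together, the programs along axis zero cover every row exactly once, which is why no two programs ever write the same D_i.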

Then we define the offsets of the dimension, because we need to load all the dimensions of each vector. This is a vector that tells us which dimensions to load from each vector, and we will load all of them: we don't split the work among multiple programs along the head dimension, only along the sequence length dimension. You will see in this part of the video, while we write the backward pass, that we will not be using make_block_ptr like we did during the forward pass; in this function we will work directly with indexing, using the strides. So let's do it.

Let's load a single block of rows of O, which, I want to remind you, has the same shape as Q, and that's why we can call it BLOCK_SIZE_Q. The load function accepts a pointer to what it should load; actually not a pointer, it accepts a tensor of pointers, a multi-dimensional one in case you want to load multi-dimensional data. So load also lets you load two-dimensional data, and in this case that's what we load: a block of rows of O, which should be a tensor of shape BLOCK_SIZE_Q by HEAD_DIM.

But we need to tell it where in this O matrix to find this block. First of all, we need to skip the batches and heads that will be processed by other programs: based on the index of the batch and head that this program will process, we skip all the other batches and heads.

Let's write the shape of this tensor. The O tensor has shape (batch size, number of heads, sequence length, head dimension). Each batch and head has SEQ_LEN × HEAD_DIM items, so, based on our index, we skip our index multiplied by HEAD_DIM multiplied by SEQ_LEN items. What I mean is this: batch zero, head zero has SEQ_LEN × HEAD_DIM items; batch zero, head one has the same number of items; and batch zero, head two also has the same number of items. So how many groups of SEQ_LEN × HEAD_DIM items do we need to skip from the start of the O tensor? As many as the index of the current batch and head, because this index already encodes both the batch and the head inside the batch: it's the combined index of the two. So we skip as many as this index indicates.

After we point to the start of the current batch and the current head, we need to select a two-dimensional tensor where the row offsets are given by offs_q. That's why we have this [:, None] indexing: it adds an extra dimension for the columns to all the offsets in offs_q, and the columns are given by offs_dim. So basically this selects a tensor of shape (BLOCK_SIZE_Q, HEAD_DIM) inside this big tensor that also includes the batch size and the number of heads. This is what we are doing: select a tensor of this size inside one made up of four dimensions, skipping the elements of all the batches and heads that will be processed by other programs. I always talk in terms of programs because that's what they are called in Triton; in CUDA the closest equivalent would be a thread block.

All right, so this one is done; I hope it is sufficiently clear. Then we also load a single block of dO in the same way, because we are going to load a group of vectors along the sequence length from dO too, and dO has the same shape as O, which has the same shape as Q. That's why we can use the block index we call Q: it's equivalent, because they have the same shape.

shape okay and how to compute this di

element well it's written in the paper

so if we go in the in the what is it man

if we go here it shows you how to

compute the DI of each given a block of

D and um a block of O it tells you how

to compute di which is the row sum which

means the the sum of by rows for each

row we will have one sum for each Vector

in the O Matrix we will have one sum of

the element wise product so this stuff

here is the element wise product of D oi

multiplied by oi so it's not um uh

matrix multiplication it's element wise

product which means each element of one

Matrix with the corresponding element of

the second Matrix and the output shape

it will be the same as the two matric

which must have the same

shape okay so we compute this di

block which will have shape block size Q

because we will have one sum for each

Vector uh then well we need to store it

somewhere so we need to calculate where

to store it inside of the D Matrix uh

well the D Matrix is I remember

correctly has the same shape as M so it

should be batch

size uh number of heads and sequence

length so we need to select the right

batch and the right head and also the

right position inside of the sequence

length based on the block index Cube

that we

have uh okay so let me

index

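A minimal NumPy sketch of the preprocessing step for a single (batch, head) slice, assuming no softmax scaling and tiny made-up sizes. It computes D_i block by block as rowsum(dO_i ∘ O_i) and stores each block at positions offs_q; in the kernel, the flat store offset would additionally include index_batch_head × SEQ_LEN. It also checks the identity that makes D reusable: rowsum(dO ∘ O) equals rowsum(P ∘ dP) with dP = dO Vᵀ.

```python
import numpy as np

rng = np.random.default_rng(0)
SEQ_LEN, HEAD_DIM, BLOCK_Q = 8, 4, 4
Q, K, V, dO = (rng.standard_normal((SEQ_LEN, HEAD_DIM)) for _ in range(4))

# Forward pass for one (batch, head) slice, softmax scale omitted for brevity.
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
O = P @ V

# Preprocessing, one program per block of rows: element-wise product, then row sum.
D = np.empty(SEQ_LEN)
for block_index_q in range(SEQ_LEN // BLOCK_Q):
    offs_q = block_index_q * BLOCK_Q + np.arange(BLOCK_Q)
    # With index_batch_head = 0 the flat store offset index_batch_head * SEQ_LEN + offs_q is just offs_q.
    D[offs_q] = np.sum(dO[offs_q] * O[offs_q], axis=1)

# The same values written as rowsum(P * dP), dP = dO V^T: the quantity both dQ and dK need.
dP = dO @ V.T
```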

Okay, all right. Just like before, we know that D has this size: each batch and each head has SEQ_LEN elements. So the number of elements we need to skip from the start of the tensor is SEQ_LEN multiplied by the combined batch/head index. We also need to skip some queries based on our block_index_q, and that skipping is already included in offs_q, so we add offs_q. Once we have computed the index where we should store this D_i block... why did I even call it a block? Let's store it. I think the name was already in the original code, but this is D_i, and this big matrix D is the matrix that contains all the D_i, one for each token in the sequence.

All right, so the preprocessing is done. Now we need to prepare the two for loops. As I said before, we will do two for loops: one in which we fix the query and iterate through all the keys and values, and one in which we fix the key and value block and iterate through all the queries. While coding it I will always show you the formula from the paper, so don't worry.

Let's start with the next iteration. First we create its launch grid. The launch grid always follows the same logic: we keep one block fixed and iterate through all the other blocks. For the block that we keep fixed, we define how many programs run in parallel: the fixed block has BLOCK_SIZE_MACRO elements, so we create SEQ_LEN divided by BLOCK_SIZE_MACRO blocks (thread blocks, or programs) on this axis. Axis two of this grid (I could equally have used axis one; I think it was already done this way in the original code) indicates which batch, and which head inside each batch, we are going to work with.

with uh so so and just like the uh

forward pass we will also use a variable

called stage that if the attention that

we are Computing is caal it will be

equal to three and if we are Computing a

non-causal attention then it will be

equal to

one um the first iteration we will fix K

and V blocks and we will iterate through

all the Q blocks in size of block size

micro number of query

vectors uh so let's look at the

signature so we pass we we launch it as

a launch grid because um and we we have

defined how many programs we have so we

have how many uh KV blocks we will have

it's a sequence land divide by the block

size macro because that's the the block

that we will keep fixed in this uh for

Loop in this

function and then we go through all the

query blocks in size of block size micro

which I defined it as 32 and later we

will talk about autot tuning and how to

tune these

values all right so I passed the query

Vector the key vector and the V Vector

uh sorry not Vector tensors now the

query tensor K tensor and V tensor and

they are pointing to the beginning of

the tensor which means that they are

beginning to the first botch and the

first head and the first token and the

first dimension of the tensors then we

pass the soft Max scale we pass d o DQ

DK and DV m is the one that is needed to

compute as you remember from what we

said before we did not see if the P

Matrix in the hbm because we want to

recompute it on the fly during the

backward pass so the query multiply by

transpose of the keys it's a very big

Matrix to save in the hbm and restore it

so we want to compute it on the fly but

we don't need to recompute the

normalization factor and the maximum

element for each row to apply the soft

Max that was already computed during the

forward pass and saved into this Matrix

M which includes the log sum exp of the

maximum of each row plus the logarithm

of the normalization factor with the

theug sum X trick we can just apply it

and it will also normalize each value

then we have the d uh V uh tensor that

we computed here with all the DI values

one for each Vector in the O tensor then

we need to pass some uh the number of

heads the sequence length the block size

that we want to use for the KV which is

the macro block size and the micro block

size is always the one that we iterate

on I think using this name it should be

easier to understand which one we are

iterating and which we want keep fixed

so the fixed one is macro and the

iterating one is the micro um head

Dimension um and later we will see why

we use a different block size to iterate

from because this is related to the

number of stages that Tron can divide

your for Loop into thanks to soft

pipelining then we have head Dimension

the stage indicates if the attention

that we computed in the forward pass was

Cal or not

cal um the number of warps and the

number of stages which we defined as

fixed but later we will talk about Auto

tuning so uh sometimes I repeat the same

stuff over and over so I should change

that

Okay, let's write the signature of this function. Let's put it here. We already described the signature, so let's go directly to the meat. The first thing we need to do is figure out the offset by which to move the query, key, and value pointers. First of all, we need to enter the right batch and the right head inside each batch. We compute the index of the batch just like in the forward pass: the program index combines the index of the head and of the batch, so we divide it by the number of heads to get which batch this program is working with, and to get the head we just take the modulus, just like in the forward pass.

The offset_batch_head... let me check what it is for. Okay, it enters the right batch and the right head. If you remember, the stride tells us how many items we need to skip in a dimension to arrive at the next index in the same dimension. So to skip index_batch batches, we multiply it by the batch stride, which is how many elements we need to skip to reach the next batch; and to enter the right head, we multiply the index of the head by the head stride, landing exactly on that head in the tensor, for each of the Q, K, and V matrices.

Then we also have this other offset, which, if I remember correctly, is used for M and D, because M and D don't have the head dimension: they are only (batch size, number of heads, sequence length). So we just multiply the combined batch/head index by the sequence length, because for each batch and each head there are SEQ_LEN items. You can think of it as the stride to move from one batch-head pair to the next.
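The offset computations described above can be reproduced from the strides of a contiguous NumPy array. NumPy reports strides in bytes, so they are divided by the item size to get element strides, which is what Triton works with. The sizes and the helper name offsets are illustrative.

```python
import numpy as np

BATCH, NUM_HEADS, SEQ_LEN, HEAD_DIM = 2, 3, 8, 4
Q = np.zeros((BATCH, NUM_HEADS, SEQ_LEN, HEAD_DIM), dtype=np.float32)
# NumPy strides are in bytes; dividing by the item size gives element strides, as in Triton.
stride_batch, stride_head, stride_seq, stride_dim = (s // Q.itemsize for s in Q.strides)


def offsets(index_batch_head):
    """Split the combined index and compute both base offsets (hypothetical helper)."""
    index_batch = index_batch_head // NUM_HEADS
    index_head = index_batch_head % NUM_HEADS
    offset_batch_head = stride_batch * index_batch + stride_head * index_head  # for Q, K, V, dO, dQ, dK, dV
    offset_batch_head_seq = index_batch_head * SEQ_LEN                          # for M and D
    return index_batch, index_head, offset_batch_head, offset_batch_head_seq
```

For example, combined index 4 with 3 heads is batch 1, head 1, so Q is offset by 96 + 32 = 128 elements and M and D by 4 × 8 = 32.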

So let's move the pointers. We move the Q, K, and V pointers by offset_batch_head, because we want to enter the right batch and the right head inside these big tensors, and we do the same for dO, dQ, dK, and dV, because they have the same shape as Q, K, and V (dO has the same shape as Q), so we move them by the same offset. All right.

Then we move M and D to the point where the sequence of the current batch and the current head starts, so they point to the first element of the sequence dedicated to the current batch and head. The same is true for Q, K, and V and for dO, dQ, dK, and dV.

Okay, then we load some other stuff, because in this method we are going to do a for loop in which we fix KV and iterate through Q, so we first need to load this block of K and V. We do it as follows. We know we need to load a 2D tensor, so we need to define the range in the second dimension of each K and V vector that we need to load, and it is given by offs_dim.

Then we want to understand which KV block this particular program is going to work with. This program skips the keys and values that are managed by other programs that may be running in parallel. How does it know what it should work with? From program index zero, which is defined over SEQ_LEN divided by BLOCK_SIZE_MACRO, and, if you remember, BLOCK_SIZE_MACRO is the block we fix. So program ID zero tells us how many blocks of BLOCK_SIZE_MACRO keys and values are already managed by other programs, so we don't care about them and we skip them. Going back here, this is the number of vectors we need to skip: our KV block starts from start_kv, and how many vectors we load depends on BLOCK_KV, which is equal to BLOCK_SIZE_MACRO, so 128 vectors.

So we define our two-dimensional tensors that will live in the SRAM, because in Triton every time you load something you load it from the HBM into the SRAM. They are initialized to zeros, and now we load the blocks as follows. In the K tensor pointer, which already points to the right batch and the right head (that's what we did here), we say we need to load the right sequence of keys, starting from offs_kv, because it already includes how many we should skip in the sequence length dimension, and for each of these vectors we need to load all the dimensions along the head dimension. K, I want to remind you, is (batch, number of heads, sequence length, head dim), and with this line we skip to the right batch and the right head. So it's as if we had already indexed here and here, and now this K points to the beginning of a two-dimensional tensor. We tell it: we don't want the whole sequence, we want part of it. Which part? The one indicated by start_kv. And how many positions of the sequence do we want? I think it's easier to write it like this: from start_kv to start_kv + BLOCK_KV. So we want this many vectors exactly at this location, and for the head dimension we want all of the dimensions, from zero to HEAD_DIM, which is exactly offs_dim.

Okay, we do it for the K block, and we do it for the V block.
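Loading the fixed K block (and, identically, the V block) comes down to the offsets below. This sketch works on a single, already-offset (batch, head) slice stored contiguously and uses BLOCK_KV = 8 instead of 128 to keep it small; the helper name load_kv_block is hypothetical.

```python
import numpy as np

SEQ_LEN, HEAD_DIM, BLOCK_KV = 16, 4, 8
K = np.arange(SEQ_LEN * HEAD_DIM, dtype=np.float32).reshape(SEQ_LEN, HEAD_DIM)  # one (batch, head) slice
k_flat = K.ravel()
stride_seq, stride_dim = HEAD_DIM, 1


def load_kv_block(flat, program_id_0):
    """Rows [start_kv, start_kv + BLOCK_KV) with all HEAD_DIM columns (hypothetical helper)."""
    start_kv = program_id_0 * BLOCK_KV              # skip the KV rows owned by earlier programs
    offs_kv = start_kv + np.arange(BLOCK_KV)
    offs_dim = np.arange(HEAD_DIM)
    return flat[offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim]
```

Program 1 therefore gets rows 8 to 15, the same block that the slice K[8:16] selects.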

Here I think I didn't update the comment: this should be BLOCK_KV, and this should be BLOCK_KV. Before, it was called BLOCK_KV1, like in the original code. I simplified the naming a little; I think this one is easier to follow. In the original code they also do two for loops, but in the second one they iterate backward, just to keep the structure of the loops the same. Mine is more verbose but easier to understand, and also less efficient; mine is much less efficient.

Then we have offs_q, because for each block of queries we need to know which vectors to load. It's given by offs_q, and how many are there? BLOCK_Q, which in the caller of this method was BLOCK_SIZE_MICRO, so 32 vectors.

Okay, now we need to access the Q vectors, already transposed, and the dO vectors, because we are going to iterate through the queries and the dO vectors. Why? Let's look at the formulas in the paper: to compute dV_j, which is what we are computing here, we need to iterate through all the dO_i vectors, and to compute dK_j we need to iterate through all the Q_i, where each Q_i is a block of vectors.

And why do we need to access Q as transposed? Because, let me show you here, we need to compute P transpose, and for that we need Q transpose. P is the softmax of the query multiplied by the transpose of the keys, so if you want the transpose of P, you need to start from K multiplied by the transposed query. That's why we access the transposed query instead of the queries.

The way we access the transposed query is just by playing with the strides. Let's do it like this; I have also written a comment on why we can do it. So, what is this operation here? It says: go to the query starting pointer, which already points to the right batch and the right head for this particular program, and select a two-dimensional tensor where you repeat the query starting points, in this case along the columns. Normally we would repeat them along the rows, because we want to select rows of queries; however, if we want the transposed query, we just invert the two dimensions.

Let me show you the simplified version, without the transpose, like this. To get the query pointers without transposition, we go to the query tensor and create a 2D tensor where the rows hold the starting point of each query we want, and we replicate each of these pointers along the columns. That's the meaning of adding this None dimension: it's equivalent to calling unsqueeze(1) in PyTorch, which adds the column dimension to this tensor. How many columns will there be? That's decided by broadcasting when we add it to this other tensor here, so this is a combination of unsqueezing and broadcasting. We take the query vectors indicated by offs_q, and for each query vector we select all the head dimensions indicated by offs_dim. If you invert this broadcasting, you get the transpose of the query block you're trying to access. So this expression is equivalent to these two lines: accessing the query and then transposing it.

Let me write down what is happening at the pointer level. Basically, think of offs_q as a vector of offsets, which we multiply by the sequence stride, which tells us how many elements we need to skip to go from one query vector to the next. If the head dimension is 128 (and the tensor is contiguous), the stride of the sequence dimension is 128, which means that to go from one query vector to the next you need to move forward by 128 elements. I want to remind you that in memory, tensors are always stored flattened, with each dimension flattened into the next. Imagine you have three rows and four columns: in memory you first have the first row, so the first four elements, then the next four, then the next four, row after row.

row U it's difficult to visualize until

you write it down so uh how to write it

down take um create a vector of offs

skew so what is offc at the beginning is

is a range that is from here from 0 to

100 no 0 to 32 0 1 2 3 4 five 6 7 etc

etc we are multiplying each one with the

stride of the sequence so this will not

skip any element this will skip exactly

128 elements this will skip exactly

implying that the head Dimension is

128 uh this will skip two times 128

element this will skip three times 128

elements and then we are adding also

another Di di menion to this Vector so

this will be a vector then you broadcast

it on head Dimension number of columns

and to each of them you add one number

so it will become a vector like fall

okay let me just do it guys otherwise I

think it's too

confusing okay so we have a vector that

is as follows so zero then we have 128

then we have two * 128 then we have 3 *

128 etc etc we are adding how many

columns indicated by off dim so off dim

has how many columns it has a head dim

number of columns please for Simplicity

let's pretend it's not

128 Dimensions let's pretend it's four

dimensions so this will be four this

will be 2 * 4 this will be three * 4 we

are adding another dimension, the dim dimension, where each entry is multiplied by the stride of dim, which is 1 because it's the last dimension. How many columns are we adding? Four. So to this row we add 0, 1, 2, 3, to this one we add 0, 1, 2, 3, to this one 0, 1, 2, 3, and to this last one 0, 1, 2, 3 as well.

What will this select? Starting from where the pointer Q points, it will select elements 0, 1, 2 and 3, which is exactly the head dimension of the first vector we should be selecting. Let me write the result of this operation: the first row will be 0, 1, 2, 3, then it selects elements 4, 5, 6, 7, then 8, 9, 10, 11, and then 12, 13, 14, 15. So, from the starting point of Q, it selects the first element right after Q, the second element, the third element, and so on. You can see that the first row is the first query vector, the second row is the second query vector, the third row the third and the fourth row the fourth, because in memory they are stored one after another: they are flattened. So it selects all of them, and at the same time we create a virtual tensor with the shape we want to view it as. As we saw before, a tensor laid out in memory can always be viewed as whatever shape you like, and this reshaping is always free: it doesn't change the arrangement of the elements in memory. I hope it's clearer now, so we can proceed.

Oh my God, that was quite complicated. Whenever I get stuck I just draw things, and I think you should do it too, because that's the only way to learn; if you try to imagine everything in your head, it's always difficult.
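A minimal NumPy sketch of the same offset arithmetic, assuming a toy layout of 4 query vectors with head dimension 4 stored contiguously (so the sequence stride is 4 and the dim stride is 1). `q_memory` is a hypothetical stand-in for the flat buffer the Q pointer points into, and NumPy fancy indexing plays the role of the kernel's load.

```python
import numpy as np

# Toy layout (assumed): 4 query vectors of head dimension 4, stored one after
# another, so moving to the next query skips HEAD_DIM elements.
SEQ_LEN, HEAD_DIM = 4, 4
stride_seq, stride_dim = HEAD_DIM, 1

# Flat memory that the Q pointer starts at; the values make elements easy to spot.
q_memory = np.arange(SEQ_LEN * HEAD_DIM) * 10

offs_seq = np.arange(SEQ_LEN)   # which query vectors (one per row)
offs_dim = np.arange(HEAD_DIM)  # which elements inside each vector (columns)

# Column vector of row offsets + row vector of column offsets -> 2D grid of offsets.
offsets = offs_seq[:, None] * stride_seq + offs_dim[None, :] * stride_dim
print(offsets)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

# Gathering with the grid returns a block with the shape we asked for;
# the memory itself was never rearranged.
q_block = q_memory[offsets]
assert np.array_equal(q_block, q_memory.reshape(SEQ_LEN, HEAD_DIM))
```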

We do the same job for the dO vectors. We don't access dO transposed, because we don't need it transposed; only Q is needed transposed.

Then we iterate through the sequence dimension of the queries, starting from query number zero. For the queries we need to go through the whole sequence length, because only for the keys do we select the particular block this program works with. Remember: here we fix the keys and go through all the queries. So the queries run from 0 to the sequence length, and the number of steps of this for loop is the sequence length divided by block Q. If we have 1,024 elements in the sequence and block Q is 32, that's 1,024 / 32 = 32 steps (I first said 1,000, which was a bad choice, because it isn't divisible by 32). In each iteration we load a block of Q, the one indicated by our pointer, and at the end of the iteration we move the pointer to the next block of Q.

We also load the logsumexp values stored in the M matrix, because we want to compute P^T on the fly. P^T is the transpose of the softmax of the queries multiplied by the transposed keys, but we don't want to compute Q K^T and then transpose it: we already access Q as transposed, so we can compute P^T directly instead of computing P and then transposing it. So we load the offsets of the elements we need from this logsumexp matrix, the M matrix we computed during the forward pass, for the block of Q we are working with in the current iteration.

Then we compute the query-key product already transposed. Strictly speaking this is not P yet, because we haven't applied the softmax; it's S^T. To get P^T you need the softmax of S^T. S^T is the transpose of S, and S is Q multiplied by K^T, so to get S^T you compute K multiplied by Q^T: remember that when you transpose a matrix product you also have to invert the order of the two factors. That's why we compute K times Q^T, which gives us S^T, and we also scale it by the softmax scale. To apply the softmax we would normally take the exponential of each element minus its row maximum and divide by the normalization value, but with the logsumexp we just subtract from each element its M value, which already includes the normalization factor. I already did that derivation, so we don't need to go through it again.
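A NumPy check of this shortcut on assumed random toy data: the forward pass keeps only the logsumexp `M` of each query row, and the backward side rebuilds P^T from K Q^T with one subtraction and one exponential. The variable names are illustrative, not the kernel's.

```python
import numpy as np

rng = np.random.default_rng(0)
SEQ_LEN, HEAD_DIM = 8, 4                      # assumed toy sizes
softmax_scale = 1.0 / np.sqrt(HEAD_DIM)
Q = rng.standard_normal((SEQ_LEN, HEAD_DIM))
K = rng.standard_normal((SEQ_LEN, HEAD_DIM))

# Forward pass (reference): S = scale * Q K^T, a row-wise softmax, and the
# logsumexp M = rowmax + log(rowsum(exp(S - rowmax))) saved for the backward pass.
S = softmax_scale * Q @ K.T
row_max = S.max(axis=1)
M = row_max + np.log(np.exp(S - row_max[:, None]).sum(axis=1))
P = np.exp(S - row_max[:, None])
P /= P.sum(axis=1, keepdims=True)

# Backward pass: build S^T directly as K @ Q^T (transposing a product swaps
# the operands), then subtract M (broadcast along the key axis) and exponentiate.
Q_T = Q.T                                     # what the kernel loads via swapped strides
S_T = softmax_scale * K @ Q_T
P_T = np.exp(S_T - M[None, :])

assert np.allclose(P_T, P.T)
```

Each column of `P_T` belongs to one query, so its entries sum to one, which is a quick sanity check that `M` really carries the normalization.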

Okay, so now we have the P^T block. Actually, in this formula I should have written S^T.

Then, for causal attention, we also need to mask out some values. Notice that here the causal mask is applied after the softmax has been computed. You're used to applying the causal mask before computing the softmax, but that's during the forward pass, where you don't want the normalization factor to be affected by the elements that should be zero. Here the normalization factor has already been computed in the forward pass, with the mask applied, so it cannot be affected anymore, and that's why we can mask out after applying the softmax.
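To see why masking after the exponential is safe here, a small NumPy sketch on assumed random data: the forward pass masks before the softmax (so `M` only accounts for the visible keys), and the backward side zeroes the masked entries of P^T afterwards. The mask keeps the entries where the query index is greater than or equal to the key index.

```python
import numpy as np

rng = np.random.default_rng(1)
SEQ_LEN, HEAD_DIM = 6, 4                      # assumed toy sizes
scale = 1.0 / np.sqrt(HEAD_DIM)
Q = rng.standard_normal((SEQ_LEN, HEAD_DIM))
K = rng.standard_normal((SEQ_LEN, HEAD_DIM))
offs_q = np.arange(SEQ_LEN)
offs_kv = np.arange(SEQ_LEN)

# Forward: mask BEFORE the softmax, so masked scores never enter the max or
# the normalization sum; M is computed from the masked rows.
S = scale * Q @ K.T
causal = offs_q[:, None] >= offs_kv[None, :]
S_masked = np.where(causal, S, -np.inf)
row_max = S_masked.max(axis=1)                # finite: the diagonal is always visible
M = row_max + np.log(np.exp(S_masked - row_max[:, None]).sum(axis=1))
P = np.exp(S_masked - M[:, None])             # exp(-inf) = 0 where masked

# Backward (dK/dV loop): recompute P^T from the UNmasked scores and M, then
# zero the masked entries AFTER the exponential. Rows are keys, columns queries.
P_T = np.exp(scale * K @ Q.T - M[None, :])
mask_T = offs_q[None, :] >= offs_kv[:, None]
P_T = np.where(mask_T, P_T, 0.0)

assert np.allclose(P_T, P.T)
```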

The mask is always the same: it is true for all the values that do not need to be masked, which are the positions where the query index is greater than or equal to the key index, and all the other values are replaced with zeros.

All right, once we have the P^T block already masked, we can calculate dV. Let me point you to the right formula in the paper. We load a block of dO; why? Let's look at how the paper computes the dV block: it is the old dV plus (a running sum, as you can see from the plus-equals here) P^T multiplied by dO_i. Here P^dropped indicates P_ij after applying dropout; this implementation doesn't support dropout, and very few models actually use dropout in attention. So it's P^T multiplied by a block of dO_i, and dO_i and Q_i always refer to the same block of rows in their respective tensors: in the inner iteration, i indexes a block of Q and a block of dO at the same positions, because dO has the same shape as Q. That's why we go through the blocks of queries and dO simultaneously: for dV we need dO, and for dK we need Q. So we compute dV just like in the paper, the P^T block multiplied by the dO block, and that gives us the dV block.
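A NumPy sketch of this accumulation, assuming toy sizes and a hypothetical program that owns key/value rows 4 to 7: it walks over the query blocks, rebuilds each P^T block from the saved logsumexp, and adds P^T dO. The result should match the dense formula dV = P^T dO restricted to those rows. Dropout and masking are left out.

```python
import numpy as np

rng = np.random.default_rng(2)
SEQ_LEN, HEAD_DIM, BLOCK_Q, BLOCK_KV = 8, 4, 2, 4    # assumed toy sizes
scale = 1.0 / np.sqrt(HEAD_DIM)
Q, K, dO = (rng.standard_normal((SEQ_LEN, HEAD_DIM)) for _ in range(3))

# What the forward pass leaves behind: the logsumexp M of each query row.
S = scale * Q @ K.T
row_max = S.max(axis=1)
M = row_max + np.log(np.exp(S - row_max[:, None]).sum(axis=1))
P = np.exp(S - M[:, None])
dV_dense = P.T @ dO                                  # reference: dV = P^T dO

# Kernel-style: this program keeps one KV block fixed...
offs_kv = 4 + np.arange(BLOCK_KV)
K_block = K[offs_kv]
dV_block = np.zeros((BLOCK_KV, HEAD_DIM))
# ...and visits every query block; Q, M and dO share the same row offsets.
for start_q in range(0, SEQ_LEN, BLOCK_Q):
    offs_q = start_q + np.arange(BLOCK_Q)
    qT_block = Q[offs_q].T
    P_T_block = np.exp(scale * K_block @ qT_block - M[offs_q][None, :])
    dV_block += P_T_block @ dO[offs_q]               # dV_j += P_ij^T dO_i

assert np.allclose(dV_block, dV_dense[offs_kv])
```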

Then we need to load the D_i elements that we precomputed at the beginning, with the first call to the attention backward preprocessing function, because we will need them for dK. How many of them do we load? Exactly the same number as the queries we load, because we always load the same number of vectors, the micro block size.

Okay, I'll copy some code and explain it step by step. The next operation is to compute dK. To compute dK we need dS^T, and to compute dS^T we need dP^T. Let's go one by one, from the end of these formulas back to the beginning, that is, from where each quantity is used to where it is created.

Let's start from dK. In the paper, dK is equal to the old dK plus dS^T multiplied by a block of Q, and that's what is written here: plus-equals just means an incremental addition, the old dK plus some new stuff. The new stuff is the softmax scale (the paper also has the softmax scale, this τ here) multiplied by the matrix product of the dS^T block and Q. The paper uses Q, but we have Q^T, so we take the transpose of Q^T, which gives back Q.

Now let's look at the dS^T block. In the paper, dS is here: it is equal to the block P_ij multiplied element-wise by (dP_ij - D_i). We don't need dS, we need dS^T. This is an element-wise multiplication, not a matrix multiplication, so when you take its transpose you don't need to invert anything: you just transpose the two operands. So to compute dS^T we take the transpose of P, which is P^T and which we already have, and the transpose of everything inside the parentheses, that is dP^T - D_i, with rows and columns swapped.

And what is dP^T? In the paper, dP is here, and it is equal to dO multiplied by V^T. We don't need dP, we need dP^T, and this time it's not an element-wise multiplication but a matrix multiplication, so to get dP^T we transpose both operands and also invert their order. V^T transposed becomes V, so it's the V block multiplied by the other operand transposed, dO_i^T, and that's why we transpose dO here.
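The transposition bookkeeping can be checked numerically. This NumPy sketch, on assumed random data and without masking or blocking, computes dK once in the paper's orientation and once in the kernel's transposed orientation (dP^T = V dO^T, dS^T = P^T times (dP^T - D) element-wise), and also compares it with central finite differences of the scalar L = sum(O * dO), whose gradient with respect to O is exactly dO.

```python
import numpy as np

rng = np.random.default_rng(3)
SEQ_LEN, HEAD_DIM = 6, 4                         # assumed toy sizes
scale = 1.0 / np.sqrt(HEAD_DIM)
Q, K, V, dO = (rng.standard_normal((SEQ_LEN, HEAD_DIM)) for _ in range(4))

def attention(K_):
    """Plain softmax attention for a given K; returns O and P."""
    S_ = scale * Q @ K_.T
    P_ = np.exp(S_ - S_.max(axis=1, keepdims=True))
    P_ /= P_.sum(axis=1, keepdims=True)
    return P_ @ V, P_

O, P = attention(K)
D = (dO * O).sum(axis=1)                         # preprocessing: D_i = rowsum(dO_i * O_i)

# Paper orientation: dP = dO V^T, dS = P * (dP - D), dK = scale * dS^T Q.
dS = P * (dO @ V.T - D[:, None])
dK_paper = scale * dS.T @ Q

# Kernel orientation: build the transposed quantities directly.
qT = Q.T                                         # Q loaded already transposed
P_T = P.T
dP_T = V @ dO.T                                  # (dO V^T)^T = V dO^T: swap the operands
dS_T = P_T * (dP_T - D[None, :])                 # element-wise: transpose each operand
dK_kernel = scale * dS_T @ qT.T                  # the transpose of Q^T is Q again

# Independent check: central finite differences of L(K) = sum(O(K) * dO).
eps = 1e-6
dK_numeric = np.zeros_like(K)
for idx in np.ndindex(*K.shape):
    K_plus, K_minus = K.copy(), K.copy()
    K_plus[idx] += eps
    K_minus[idx] -= eps
    dK_numeric[idx] = (np.sum(attention(K_plus)[0] * dO)
                       - np.sum(attention(K_minus)[0] * dO)) / (2 * eps)

assert np.allclose(dK_kernel, dK_paper)
assert np.allclose(dK_kernel, dK_numeric, atol=1e-6)
```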

Right now I'm not going through every single pointer, because I already showed you how to check what a pointer points to and what an offset refers to. I hope you now have a better understanding of how pointers work in Triton, which is also how they work in CUDA: on the GPU we only get a pointer to the starting address of the tensor, and we have to work out all these indices ourselves.

We have computed the dK block, so we now move to the next block of queries, because we keep the K and V blocks fixed and iterate through all the queries. So we move the Q^T pointers forward by the sequence stride, which tells us how to go from one query to the next, multiplied by block Q, and we also advance the index of the current queries we are accessing. We do the same for dO, again moving by block Q elements along the sequence stride, because dO and Q have the same shape.

shape okay after we have run the for

Loop of all the queries we can store

this DK and DV block so we write it back

as

follows and this is the end of our

function guys so so we save the DV block

exactly in the position inside of the

current okay DV is already I believe

pointing to the right batch and to the

right head because we incremented it

here and also in the case of Decay then

we need to tell it in the sequence

Dimension where they should save this um

block of K and V and this is indicated

by this one we say and we create the the

pointers just like before guys don't

make me do it again it's a really really

um easy if you write it down like you

write this uh um Vector of key and

values pointers which is not pointers

actually they are a range of the of key

and value that you need to take from the

sequence

Dimension you add another dimension that

is the column so you repeat each value

in the columns and then you add the

dimension here for the head Dimension

anyway after we have Cal at the pointers

where we should store the DK and DV we

store them in the um the pointers of um

we store them in the DV um uh I mean we

store them in the DV uh tensor and the

DK tensor what do we save we save the DV

block and the DK block which is the one

that we were um uh incrementally

changing in the for Loop that we have

written okay now that we finish this one

we can go to the next function that will

do the other for Loop so let's do it

Okay, so now we do the second part of the iteration, this one. Let me copy it and then describe it. We use the same launch grid as before, and of course we need to declare this function. Again, the grid is defined by the macro block size, which is the thing we keep fixed, and inside the for loop we take steps of the micro block size. In this case we fix Q and iterate through K and V, because we need to compute dQ; so far we have computed dK and dV.

I believe the arguments are the same as before, and that's actually why, in the original implementation on the Triton website, the author decided to use the same for loop with different arguments. I found that a little confusing, so I separated them and simply repeated the code twice: the goal of this video is to be as easy to understand as possible, not to be as efficient as possible.
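To check that this split of the work covers everything, here is a pure-Python sketch under assumed sizes. Each hypothetical program is identified by a (query-block index, batch-head index) pair, keeps a macro-sized block of queries fixed, and steps through the keys in micro-sized chunks. Every (query, key) pair of every batch-head should be visited exactly once. The names and the grid layout are illustrative, not the kernel's exact code.

```python
from collections import Counter

# Assumed toy sizes; a real kernel would receive these as compile-time parameters.
BATCH_SIZE, NUM_HEADS, SEQ_LEN = 2, 2, 16
BLOCK_MACRO, BLOCK_MICRO = 8, 4

grid = (SEQ_LEN // BLOCK_MACRO, BATCH_SIZE * NUM_HEADS)
visits = Counter()
for index_q in range(grid[0]):                          # which block of queries
    for index_batch_head in range(grid[1]):             # which (batch, head) pair
        start_q = index_q * BLOCK_MACRO
        q_rows = range(start_q, start_q + BLOCK_MACRO)  # kept fixed by this program
        for start_kv in range(0, SEQ_LEN, BLOCK_MICRO): # inner loop over K and V
            for q in q_rows:
                for kv in range(start_kv, start_kv + BLOCK_MICRO):
                    visits[(index_batch_head, q, kv)] += 1
```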

So let's go here: let me copy the signature again, and we define this function here. Again, we first need to move the query, key and value pointers so they point to the exact batch and the exact head this program is working with. Let me check where the code is. This first part is exactly the same as in the other for loop we've written; I really just copied it. We compute the batch-head index, we move the query, key and value pointers to the right place, dO, dQ, dK and dV to the right place, and M and D to the right place, exactly like before, so I don't think I need to explain that again.

Then we load a block of Q, the one we will keep fixed, and prepare dQ. Let me load a lot of stuff here, actually.
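Moving the base pointers "to the right batch and head" is plain stride arithmetic. A NumPy sketch, assuming a contiguous (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM) tensor and a program index that flattens batch and head together; `index_batch_head` and the split into batch and head by division and remainder are assumptions for illustration.

```python
import numpy as np

BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = 2, 3, 4, 5     # assumed sizes
Q = np.arange(BATCH_SIZE * NUM_HEADS * SEQ_LEN * HEAD_DIM).reshape(
    BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
flat = Q.ravel()                                          # what the raw pointer sees
# NumPy reports strides in bytes; divide by the element size to get element strides.
stride_batch, stride_head, stride_seq, stride_dim = (s // Q.itemsize for s in Q.strides)

index_batch_head = 4                                      # hypothetical program index
index_batch = index_batch_head // NUM_HEADS
index_head = index_batch_head % NUM_HEADS
offset_batch_head = index_batch * stride_batch + index_head * stride_head

# After shifting the base pointer, the usual 2D offsets address this head's block.
offs_seq, offs_dim = np.arange(SEQ_LEN), np.arange(HEAD_DIM)
block = flat[offset_batch_head
             + offs_seq[:, None] * stride_seq + offs_dim[None, :] * stride_dim]
assert np.array_equal(block, Q[index_batch, index_head])
```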

Okay, we define the offsets we need to load the blocks of K and V in the head dimension, because we are going to iterate over K and V. We access them as transposed blocks: instead of accessing K and V directly, we access K^T and V^T, and you know that's possible just by changing the strides. Because we treat them as 2D blocks, when you want to access K non-transposed you treat offs_kv as a column vector, broadcast along the columns, so each key offset fills one row. Here instead we treat it as a row vector, broadcast along the rows, and that's how you access the transposed version of K and the transposed version of V.

Another thing we do is load the Q block. Which block? The one given by offs_q, which is based on start_q, the exact starting point this particular program should work with. This program works with two dimensions of the launch grid: one tells it which batch and which head to work with, and the other, program index number zero, tells it which queries, among the whole sequence length, it is going to work with. This is given by index_block; actually, in this case it should be Q, I forgot to change the name, so let me rename it to index_q. We skip some queries: how many? The index of the current program multiplied by the block size, which is how many queries the previous programs have already processed. This tells us which queries within the sequence this program needs to select, and that's why we use start_q plus the range of block Q. Imagine the starting query of this program is 100: then it loads query rows 100, 101, 102 and so on, up to 100 + block Q - 1. This is the range of query vectors that we load in this program.
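Both tricks in this passage are plain index arithmetic. The NumPy sketch below, on an assumed contiguous toy layout, builds the offsets of a K^T block by swapping which offset vector is broadcast along which axis, and computes the query rows of a hypothetical program from its index (`index_q`, an assumed name) and the block size.

```python
import numpy as np

SEQ_LEN, HEAD_DIM, BLOCK_KV, BLOCK_Q = 128, 4, 3, 32   # assumed toy sizes
stride_seq, stride_dim = HEAD_DIM, 1
k_memory = np.arange(SEQ_LEN * HEAD_DIM)
K = k_memory.reshape(SEQ_LEN, HEAD_DIM)

offs_kv = np.arange(BLOCK_KV)
offs_dim = np.arange(HEAD_DIM)
# K block: key offsets as a column vector, head-dim offsets as a row vector.
k_offsets = offs_kv[:, None] * stride_seq + offs_dim[None, :] * stride_dim
# K^T block: the same two vectors with their roles swapped.
kT_offsets = offs_kv[None, :] * stride_seq + offs_dim[:, None] * stride_dim
assert np.array_equal(k_memory[k_offsets], K[:BLOCK_KV])
assert np.array_equal(k_memory[kT_offsets], K[:BLOCK_KV].T)  # same buffer, new offsets

# Which queries does program number index_q handle?
index_q = 3                       # hypothetical value of program index 0
start_q = index_q * BLOCK_Q       # skip the queries handled by programs 0..2
offs_q = start_q + np.arange(BLOCK_Q)
print(offs_q[0], offs_q[-1])      # 96 127
```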

We load the block of Q using the Q pointer plus offs_q treated as a column vector, multiplied by the sequence stride and broadcast along the columns, plus the head-dimension offsets treated as a row vector, where each column is one element of the head dimension multiplied by its stride. In this case we could actually skip multiplying by that stride, because the stride of the last dimension is one: to go from one element to the next you move by exactly one element (this assumes the tensor is contiguous in its last dimension).

Then we prepare dQ, which is what we are going to compute in this iteration, and there is D that we need to load. dO uses the same offsets as Q, because dO and Q have the same shape and work in the same way: we load a block of Q and the corresponding block of dO, and dO has the same shape as O, which has the same shape as Q. We also need to load the M normalization factors from the M matrix, namely the ones corresponding to the particular group of queries this program works with.
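A NumPy sketch of these loads, assuming one (batch, head) slice stored contiguously with placeholder data. For a contiguous array the last-dimension stride is one element, so multiplying by it changes nothing; because dO has the same shape as Q, the same 2D offsets select the matching rows of both, while the per-row vectors M and D are indexed by `offs_q` alone. The starting query 4 is a hypothetical value.

```python
import numpy as np

SEQ_LEN, HEAD_DIM, BLOCK_Q = 8, 4, 2                   # assumed toy sizes
rng = np.random.default_rng(5)
Q = rng.standard_normal((SEQ_LEN, HEAD_DIM)).astype(np.float32)
dO = rng.standard_normal((SEQ_LEN, HEAD_DIM)).astype(np.float32)
M = rng.standard_normal(SEQ_LEN).astype(np.float32)    # one logsumexp per query row
D = rng.standard_normal(SEQ_LEN).astype(np.float32)    # one D_i per query row

# Strides in elements (NumPy reports bytes); the last one is 1 when contiguous.
stride_seq, stride_dim = (s // Q.itemsize for s in Q.strides)
assert (stride_seq, stride_dim) == (HEAD_DIM, 1)

offs_q = 4 + np.arange(BLOCK_Q)                        # hypothetical start_q = 4
offs_dim = np.arange(HEAD_DIM)
offsets = offs_q[:, None] * stride_seq + offs_dim[None, :] * stride_dim
# Multiplying by stride_dim == 1 is a no-op, which is why it could be omitted.
assert np.array_equal(offsets, offs_q[:, None] * stride_seq + offs_dim[None, :])

Q_block = Q.ravel()[offsets]        # the same offsets for Q ...
dO_block = dO.ravel()[offsets]      # ... and for dO, which has the same shape
M_block = M[offs_q][:, None]        # per-row values, shaped to broadcast over columns
Di = D[offs_q]
dQ_block = np.zeros((BLOCK_Q, HEAD_DIM), dtype=np.float32)   # accumulator to fill in
```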

For K and V, as you can see, the offsets start from the first block of KV, at position zero, because we will iterate through all the keys and values starting from key vector zero and value vector zero, and at each iteration we move forward by block KV vectors.

I hope I didn't go too fast, but most of what is written here is very similar to what we already did in the other for loop, and I don't want to repeat myself too much. What matters is the formulas we use, which are exactly the ones in the paper.

So we go through these blocks of K and V. We load the first block of K^T and V^T, as usual: you tell Triton the pointers of the elements you want to load, and it loads the block you are asking for into SRAM. So these blocks all reside in SRAM, and Q and dO reside in SRAM as well.

Then we compute the queries multiplied by the transposed keys, because we need the P block: the QK block is the product of the current query block with K^T of the current key block. We already access the keys transposed, so we don't need to transpose them; and anyway, transposing a matrix doesn't require any computation, we just access it in a different way, because in memory it's always stored as a flattened array.

Then we compute the P block, the output of the softmax: from each query-key score we subtract the logsumexp value of this block of queries, which is why we load the M block using the offsets of the queries we are loading. As you remember, the M block already includes the normalization factor, because each m is the maximum value of its row plus the logarithm of the normalization factor, which, by the properties of the exponential, ends up in the denominator. Okay, and then we apply the autoregressive masking again.

Oops, what did I do? Let me go back to the code. We have the STAGE here: when we launch the backward pass, stage 3 indicates that in the forward pass we computed causal attention, and stage 1 that we computed non-causal attention. If we computed causal attention in the forward pass, we also need to mask out these elements in the backward pass. So we create the mask, which is true only for the elements where the query index is greater than or equal to the key index: where it's true we don't mask, otherwise we mask.

Let's compute the next operation, which is dP and dS. Actually, let's compute dQ directly and then explain it like before: we start from the end and go back to where each needed quantity is computed. If you look at the formula... let me check; this one we don't need, I think.

Okay, let's go to the iPad. What we are trying to compute here is dQ. As you can see in the paper, dQ is equal to the old dQ plus τ, the softmax scale, which is this stuff here, multiplied by the matrix product of dS and the K block. The dS block is here, and the K block is the transpose of the K^T block, because we are accessing K already transposed. We could also access K directly as a non-transposed block by swapping the broadcasting: where the None goes here decides whether the offsets are treated as a row vector, broadcast along the rows, or as a column vector, broadcast along the columns. To load K non-transposed you would treat the key offsets as a column vector and the head-dimension offsets as a row vector; to access it as K^T you just invert these two. I hope I didn't mess anything up, so let's move forward.

So we know that the formula for dQ is exactly the one in the paper, but what is this dS block? Let's look at the paper: the dS block comes from this expression here, dS_i = P_i multiplied element-wise by (dP_i - D_i). What is the P block? It's exactly the output of the softmax, which we already have. What is the dP block? It's exactly dO multiplied by V^T: dO we already loaded, it's here, and it's multiplied by the transpose of V, which we already loaded as transposed. And this is how we compute dQ.

Then, of course, we need to move to the next block of K and V, so we increment the pointers just like before, moving to the next block of keys and values. Then we need to store the result, dQ, and by splitting the for loops this way we only need one write of dQ to HBM. If you look at the original algorithm (I don't know whether the algorithm in the paper actually corresponds to the CUDA implementation they wrote, but I don't think so, because it would not be very optimized), the paper says you go through all the keys and, while going through the keys, you go through all the queries, and for each query block you visit you write dQ back to HBM while updating it, which is not optimal. That's why we needed two for loops: one in which we fix the keys and values and iterate over all the queries, because each block of dK and dV depends on all the blocks of Q; and one in which we fix the queries and iterate through all the keys, because one block of dQ depends on all the blocks of K. This is why we split it, and this is the second loop we have written.

Now we have written everything we need for FlashAttention, both the forward pass and the backward pass, so we should be ready to launch the kernel. I hope I didn't make any mistakes in copying the code.
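Before launching, here is a NumPy sketch of the second loop on assumed random data: one program's query block stays fixed (here all 8 queries, for simplicity), the loop walks over key/value blocks loaded as K^T and V^T, accumulates dQ in a local block, and "stores" it once at the end. Masking is omitted. The result is compared with central finite differences of L = sum(O * dO).

```python
import numpy as np

rng = np.random.default_rng(4)
SEQ_LEN, HEAD_DIM, BLOCK_KV = 8, 4, 2                 # assumed toy sizes
scale = 1.0 / np.sqrt(HEAD_DIM)
Q, K, V, dO = (rng.standard_normal((SEQ_LEN, HEAD_DIM)) for _ in range(4))

def forward(Q_):
    """Reference attention: returns O and the logsumexp M of each query row."""
    S_ = scale * Q_ @ K.T
    m_ = S_.max(axis=1, keepdims=True)
    l_ = np.exp(S_ - m_).sum(axis=1, keepdims=True)
    return (np.exp(S_ - m_) / l_) @ V, (m_ + np.log(l_)).ravel()

O, M = forward(Q)
D = (dO * O).sum(axis=1)                              # preprocessing: D_i

dQ_block = np.zeros_like(Q)                           # local accumulator (SRAM in the kernel)
for start_kv in range(0, SEQ_LEN, BLOCK_KV):
    offs_kv = start_kv + np.arange(BLOCK_KV)
    kT_block, vT_block = K[offs_kv].T, V[offs_kv].T  # loaded already transposed
    P_block = np.exp(scale * Q @ kT_block - M[:, None])
    dP_block = dO @ vT_block                          # dP = dO V^T
    dS_block = P_block * (dP_block - D[:, None])      # dS = P * (dP - D)
    dQ_block += scale * dS_block @ kT_block.T         # dQ += scale * dS K
dQ = dQ_block                                         # the single write at the end

# Independent check: central finite differences of L(Q) = sum(O(Q) * dO).
eps = 1e-6
dQ_numeric = np.zeros_like(Q)
for idx in np.ndindex(*Q.shape):
    Q_plus, Q_minus = Q.copy(), Q.copy()
    Q_plus[idx] += eps
    Q_minus[idx] -= eps
    dQ_numeric[idx] = (np.sum(forward(Q_plus)[0] * dO)
                       - np.sum(forward(Q_minus)[0] * dO)) / (2 * eps)

assert np.allclose(dQ, dQ_numeric, atol=1e-6)
```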

I don't think I did, but I'll try to launch it, and if there is any error I'll just use my reference code, which I already wrote and used as a template. The only difference so far between my reference code and the one we've written is the autotuning, which I didn't explain, so let's talk about it.

The autotuning was already present in the original implementation, and I kept it as is, although I removed the autotuning for the backward pass. In the forward pass, if you check, there is this code here that specifies the autotuning configuration for Triton. Triton cannot know beforehand the best block size for the queries, or for the keys and values, or for any other dimension we have; it needs to try, based on the hardware we're running on, the available SRAM, and the thread coarsening Triton can apply.

I didn't talk about thread coarsening either. In CUDA you can choose whether each thread does one elementary operation (for example, in a matrix addition, each thread adds one particular element of the output matrix) or manages multiple elements; the latter is called thread coarsening. I didn't check the documentation, but I believe Triton does it for you based on the block size you give it and the number of warps you want. A warp is a group of 32 threads that work cooperatively, always running the same instruction at the same time.

The number of stages is more interesting. It's an optimization that Triton does, and it's not loop unrolling: it's software pipelining, which is the last part of this code we need to understand, the autotuning. I believe the most interesting part here is not choosing block Q and block K, because for those you just try whatever configuration works best based on timing: Triton will actually run all these configurations for you every time the sequence length or the head dimension changes, and for every pair of head dimension and sequence length it will choose the configuration that runs in the least amount of time, that is, the one that gives you the best throughput.

So let's look at num_stages: what it is and how it works. Let's do it.

Okay, so software pipelining is used when you have a for loop, that is, a sequential operation, in which each iteration does not depend on the previous one: the operations you do in one iteration are independent of what you did in the previous iteration, which is more or less what we have in our for loops. Actually, I believe there are conditions in which this doesn't have to be true, where the operations can depend on each other and you can still do software pipelining.

For example, imagine the following for loop running from 1 to N: first you load some data, then you load some other data, then you do a matrix multiplication, and then you store some data. So here you are reading data, here you are reading data, here you are computing, and here you are writing data. Imagine our GPU is made up of a compute unit and a unit dedicated to memory, that is, reading from or writing to memory. If we look at the timeline of each iteration, first we read some data and the compute unit is idle, because it needs this data; then we read some more data and the compute unit is still idle, because it needs this data too; then finally we have enough data to compute this operation, and the memory unit is idle; then we write some data back to memory and the compute unit is idle again, and it stays idle for another two time steps until it has enough data to run the next computation. As you can see, this is not very efficient, because at any point in time only one unit is working while the other sits idle.
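The idle time just described, and the pipelined schedule explained next, can be compared with a toy pure-Python simulation. It assumes one time step per operation, treats reads as asynchronous so two can be in flight at once, and leaves the stores out for brevity; "read A" and "read B" are the two operand loads of each iteration.

```python
def sequential_schedule(n_iters):
    """No overlap: each iteration reads A, reads B, then computes."""
    steps = []
    for i in range(n_iters):
        steps += [[f"read A{i}"], [f"read B{i}"], [f"matmul {i}"]]
    return steps

def pipelined_schedule(n_iters):
    """3-stage software pipeline: at step t, issue the async read of A for
    iteration t and of B for iteration t-1, and compute iteration t-2,
    whose two reads were issued at steps t-2 and t-1."""
    steps = []
    for t in range(n_iters + 2):             # two extra steps drain the pipeline
        ops = []
        if t < n_iters:
            ops.append(f"read A{t}")
        if 0 <= t - 1 < n_iters:
            ops.append(f"read B{t - 1}")
        if 0 <= t - 2 < n_iters:
            ops.append(f"matmul {t - 2}")
        steps.append(ops)
    return steps

def step_of(schedule, op):
    """Time step at which an operation is issued."""
    return next(t for t, ops in enumerate(schedule) if op in ops)

def compute_busy_fraction(schedule):
    return sum(any(op.startswith("matmul") for op in ops) for ops in schedule) / len(schedule)

seq, pip = sequential_schedule(4), pipelined_schedule(4)
for t, ops in enumerate(pip):
    print(t, ", ".join(ops))
print(compute_busy_fraction(seq), compute_busy_fraction(pip))
```

The compute unit is busy one step in three in the sequential version and two steps in three in the pipelined one for this short loop; the longer the loop, the closer the pipelined fraction gets to one.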

One way to optimize this for loop is software pipelining, and you can tell Triton to do it for your for loops by telling it how many stages you want. Let's see how it works. To pipeline a for loop, first of all you need to convert all these operations into asynchronous operations. In CUDA, at least on NVIDIA GPUs, there are asynchronous loads from memory and asynchronous writes to memory, which means that I spawn a load operation and only check whether it has completed when I actually need its result. I spawn this operation, the instruction returns immediately, and execution moves to the next instruction; here I spawn another load, which also returns immediately; and before computing, I just check whether these two operations have completed. So I can spawn two reads right away and then just check that they have completed.

With software pipelining, we pipeline operations of different iterations into a single iteration. First, we read the first matrix we need for this matrix multiplication. At the next iteration, we read the first matrix of the second iteration and also the second matrix of the first iteration. I call them read A and read B: A is the first matrix we need and B is the second. All these operations are asynchronous. Then, at the third iteration, I launch an asynchronous read of the first matrix of the third iteration and a read of the second matrix of the second iteration, and I compute the matrix multiplication of the first iteration, because by now its two reads should have completed. While computing that matrix multiplication I don't keep the loading unit idle, because it is still working on these loads; this only works if you can spawn asynchronous operations. So at the third iteration I compute the first matrix multiplication using these two matrices, which should have finished loading, while the asynchronous loads for the second and third iterations are still in flight. At the fourth iteration I spawn the load for the fourth iteration and the load for the third iteration while computing the matrix multiplication of the second iteration, whose data should have arrived by now.

Actually, it's not that we just expect them to have completed: there are primitives in the CUDA language to check whether an operation has completed, so before doing the matrix multiplication we check whether the asynchronous operations it needs have finished, rather than assuming they're done based on timing. It's like promises in JavaScript: you can spawn as many promises as you want and wait for a promise to finish only when you actually need its result. In C#, I think they're called tasks: you spawn as many tasks as you want, and when you need one you wait only for that one, while the others keep running asynchronously in the background. This is the whole idea of software pipelining.

As you can see, software pipelining only works when you have asynchronous operations, and it also increases the memory requirements of your program: when matrix multiplication number one runs, we may already be holding the data for the first two iterations plus half of the data for the third iteration, so we increase the SRAM requirement.

Okay, and Triton will do this software pipelining for you: it will convert all the loads, all the stores, and maybe also the matrix multiplications into asynchronous operations and do the pipelining for you. If you are confused by how it works, there is another easy way to explain it, because it's something we already do in model training: it's called pipeline parallelism.

Pipeline parallelism works as follows. We have a very big neural network that does not fit in a single GPU. Imagine this neural network is made up of three layers, layer 1, layer 2 and layer 3, and it's so big that it doesn't fit entirely in one GPU. One option is to put each layer on its own GPU: layer 1 on GPU 1, layer 2 on GPU 2, and layer 3 on GPU 3. Now imagine we have an input for this neural network. We send it to the first GPU; GPU 1 processes layer 1 and generates some output, which is transferred to GPU 2, which calculates its own output and transfers it to GPU 3, which computes its own output, and finally we have the output of the neural network.

The problem is that when you send the output of GPU 1 to GPU 2 so that GPU 2 can do its part, GPU 1 is now idle. That is a waste of resources, because we should always keep the GPUs busy. So instead of sending the whole mega-batch to GPU 1, we can send many smaller micro-batches. How does it work? Imagine that we send batch number 0 to GPU 1. GPU 1 computes its output and sends it to GPU 2, so now GPU 2 is computing batch 0. Batch 0 is no longer on GPU 1, and GPU 1 is free, so we send it another micro-batch, batch 1. Then GPU 2 finishes processing batch 0 and sends it to GPU 3, so now GPU 3 has batch 0 and GPU 2 is free. Hopefully GPU 1 has also finished, so we transfer batch 1 from GPU 1 to GPU 2, and GPU 1 becomes free again. Because GPU 1 is free, we can introduce another batch, batch 2, and so on. Every time we move the batches from one GPU to the next, we introduce a new batch at the beginning of the pipeline, and all the batches shift by one position at every iteration. This keeps the GPUs busy. There is only one problem with pipeline parallelism, and that is this bubble effect.
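The following minimal Python sketch makes the micro-batch schedule concrete. It covers the forward pass only, assumes every GPU takes exactly one time step per micro-batch, and ignores transfer time. `forward_schedule` and `bubble_slots` are hypothetical helper names. Each row is a time step, and `None` marks an idle GPU. The idle slots while the pipeline fills at the start and drains at the end are the bubble.

```python
def forward_schedule(num_gpus, num_microbatches):
    """Return one row per time step; row[g] is the micro-batch on GPU g, or None if idle."""
    steps = num_microbatches + num_gpus - 1
    schedule = []
    for t in range(steps):
        row = []
        for g in range(num_gpus):
            # Micro-batch mb enters GPU 0 at step mb and reaches GPU g at step mb + g.
            mb = t - g
            row.append(mb if 0 <= mb < num_microbatches else None)
        schedule.append(row)
    return schedule


def bubble_slots(schedule):
    """Count (GPU, time step) slots in which a GPU has nothing to do."""
    return sum(cell is None for row in schedule for cell in row)


if __name__ == "__main__":
    # 3 GPUs (one layer each) and 4 micro-batches.
    for t, row in enumerate(forward_schedule(3, 4)):
        print(f"t={t}: " + "  ".join("--" if mb is None else f"b{mb}" for mb in row))
```

Under these assumptions the number of idle slots is `num_gpus * (num_gpus - 1)` no matter how many micro-batches you send, so sending more micro-batches shrinks the bubble's share of the total time.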

The bubble appears because, while the pipeline is being created, some GPUs have nothing to process yet. Okay, pipeline parallelism also has the problem of the backward step. The backward step has to run in exactly the reverse order in which you received the micro-batches. In Triton, when doing software pipelining, you instead have the problem of the prologue and the epilogue. You need to fill the pipeline before pipelining can start, and at the end of the loop you need to consume everything that is still in the pipeline. So in the first and last steps of this for loop, not all the units of the GPU may be working simultaneously. What does this mean? It means that, to benefit from pipelining, you want the number of iterations of your for loop to be much larger than the number of stages that each iteration is divided into. In this case we have four stages (that is what these are called), so you want the number of iterations to be much larger than the number of stages.
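The "many more iterations than stages" rule can be quantified with a back-of-the-envelope sketch. It assumes an idealized pipeline where each of the `num_stages` stages takes one time step and a new iteration enters every step. Real GPU stages are not this uniform. `pipeline_utilization` is a hypothetical helper. With N iterations and S stages the loop takes N + S - 1 steps, and the extra S - 1 steps are the prologue and epilogue. So the fraction of busy stage slots is N / (N + S - 1).

```python
from fractions import Fraction


def pipeline_utilization(num_iterations, num_stages):
    """Fraction of (stage, time-step) slots doing useful work in an ideal pipeline.

    Every iteration occupies each stage once (num_iterations * num_stages busy
    slots) out of num_stages * (num_iterations + num_stages - 1) total slots.
    """
    total_steps = num_iterations + num_stages - 1
    return Fraction(num_iterations, total_steps)


if __name__ == "__main__":
    # Four stages, as in the example above.
    for n in (4, 16, 256):
        print(n, float(pipeline_utilization(n, num_stages=4)))
```

With four stages, 4 iterations give 4/7 (about 57% busy), while 256 iterations give 256/259 (about 99%). This is why a short loop gains little from deep pipelining.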

So, guys, I have finally completed the video. I hope that you learned a lot from it. I believe that we can run the Triton code, so let's actually run it. Let's see. I copied everything, and I believe we also put in the code to test it, but we didn't put in the main method, which we can copy right now. I hope there is no error. I really hope so.

Let me check if I am on the right machine. I am. So let's just run the program and pray. If there is an error, I will just copy my own reference implementation, but I hope it works, because otherwise it means I forgot something.

I'm running my code on an H100 because my company has H100s. If you have a smaller GPU, you can reduce the sequence length. You can also reduce the batch size (I thought it was already one when we call it, but it isn't), the number of heads and the sequence length. You can even set the head dimension to 8 and the sequence length to 16.

Let's check. The error says that the Triton attention backward returned an incorrect number of gradients: it expected five and got one. I believe we probably forgot a return statement.
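The error above comes from PyTorch checking that the `backward` of a custom `torch.autograd.Function` returns one value per input of `forward`. In the video's case that is five values, and only one was returned. Below is a minimal sketch of the rule, assuming PyTorch is installed. It uses a hypothetical three-input toy op, `ScaledMul`, not the attention operator. A non-tensor input such as a float scale gets `None` as its gradient. `BrokenScaledMul` (also hypothetical) reproduces the missing-return mistake.

```python
import torch


class ScaledMul(torch.autograd.Function):
    """Toy op with three forward inputs: out = scale * x * y."""

    @staticmethod
    def forward(ctx, x, y, scale):
        ctx.save_for_backward(x, y)
        ctx.scale = scale  # plain Python float, not a tensor
        return scale * x * y

    @staticmethod
    def backward(ctx, grad_out):
        x, y = ctx.saved_tensors
        # One value per forward input, in order: x, y, scale (None: not a tensor).
        return grad_out * ctx.scale * y, grad_out * ctx.scale * x, None


class BrokenScaledMul(torch.autograd.Function):
    """Same op, but backward forgets to return the other gradients."""

    @staticmethod
    def forward(ctx, x, y, scale):
        ctx.save_for_backward(x, y)
        ctx.scale = scale
        return scale * x * y

    @staticmethod
    def backward(ctx, grad_out):
        x, y = ctx.saved_tensors
        return grad_out * ctx.scale * y  # 1 value instead of 3 -> RuntimeError


if __name__ == "__main__":
    x = torch.randn(4, requires_grad=True)
    y = torch.randn(4, requires_grad=True)
    ScaledMul.apply(x, y, 2.0).sum().backward()
    print(torch.allclose(x.grad, 2.0 * y.detach()))  # True
    try:
        BrokenScaledMul.apply(x, y, 2.0).sum().backward()
    except RuntimeError as err:
        print(err)  # reports an incorrect number of gradients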

Yes, I forgot the return statement here. In the backward pass, after running the last for loop, we need to return the gradients that we have computed. Fingers crossed again. Okay, it passed. The backward pass computed by torch matches our backward pass up to an absolute error of 10^-2.
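Here is a sketch of the same style of check, with an absolute tolerance of 10^-2, assuming PyTorch is installed. To keep it runnable on a CPU, the custom operator is a hypothetical plain-PyTorch attention with a hand-written backward (`NaiveAttention`, no Triton kernel and no causal mask). `reference_attention` and `check_against_reference` are hypothetical helpers, and the sizes are arbitrary small values. The backward uses the standard attention gradients: dV = Pᵀ dO, dP = dO Vᵀ, dS = P ∘ (dP − rowsum(dO ∘ O)), then dQ and dK from dS times the softmax scale.

```python
import torch


def reference_attention(Q, K, V, softmax_scale):
    # Plain PyTorch ops; autograd differentiates through each of them.
    return torch.softmax(Q @ K.transpose(-2, -1) * softmax_scale, dim=-1) @ V


class NaiveAttention(torch.autograd.Function):
    """softmax(Q K^T * scale) V with a hand-written backward (no Triton, no mask)."""

    @staticmethod
    def forward(ctx, Q, K, V, softmax_scale):
        P = torch.softmax(Q @ K.transpose(-2, -1) * softmax_scale, dim=-1)
        O = P @ V
        ctx.save_for_backward(Q, K, V, P, O)
        ctx.softmax_scale = softmax_scale
        return O

    @staticmethod
    def backward(ctx, dO):
        Q, K, V, P, O = ctx.saved_tensors
        scale = ctx.softmax_scale
        dV = P.transpose(-2, -1) @ dO
        dP = dO @ V.transpose(-2, -1)
        # Softmax backward; rowsum(P * dP) equals rowsum(dO * O).
        Di = (dO * O).sum(dim=-1, keepdim=True)
        dS = P * (dP - Di)
        dQ = (dS @ K) * scale
        dK = (dS.transpose(-2, -1) @ Q) * scale
        return dQ, dK, dV, None  # one entry per forward input


def check_against_reference(atol=1e-2, seed=0):
    torch.manual_seed(seed)
    BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM = 1, 2, 16, 8  # small, CPU-friendly
    shape = (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
    Q, K, V = [torch.randn(shape, requires_grad=True) for _ in range(3)]
    dO = torch.randn(shape)
    scale = 1.0 / HEAD_DIM ** 0.5

    ref_O = reference_attention(Q, K, V, scale)
    ref_O.backward(dO)
    ref = {"O": ref_O.detach(), "dQ": Q.grad.clone(), "dK": K.grad.clone(), "dV": V.grad.clone()}
    Q.grad = K.grad = V.grad = None

    # Autograd sees NaiveAttention as a single node and calls its backward,
    # which fills Q.grad, K.grad and V.grad.
    out = NaiveAttention.apply(Q, K, V, scale)
    out.backward(dO)
    ours = {"O": out.detach(), "dQ": Q.grad, "dK": K.grad, "dV": V.grad}

    errors = {}
    for name in ref:
        assert torch.allclose(ref[name], ours[name], atol=atol, rtol=0.0), name
        errors[name] = (ref[name] - ours[name]).abs().max().item()
    return errors


if __name__ == "__main__":
    print(check_against_reference())
```

The two `backward(dO)` calls look the same, but they take different routes. The first goes through autograd's own derivatives of `matmul` and `softmax`. The second goes through the hand-written `NaiveAttention.backward`.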

Notice that the backward we run on the reference output is different from the backward we run on the Triton output. When you apply Triton attention, it introduces a new node into the computation graph of our tensors, and that node is this Triton attention operator. When PyTorch computes the backward pass, it just calls the backward function of the Triton attention operator, which populates the grad value of all the tensors that are inputs to the Triton attention. This is how PyTorch autograd works. Guys, thank you for watching my video. It has been super, super, super demanding. I spent many months, first of all, teaching myself Triton, CUDA, FlashAttention, etc. I also have a full-time job, so it's really hard to make videos like this. I need to dedicate the nights, the mornings and the weekends. I spent three days just recording this video, because sometimes I don't like how I explained something, sometimes I make a mistake, and sometimes I need to restart because what I'm doing is wrong.

I believe there should be no big errors in what I have done so far, but my notation is certainly bad. All the mathematics I know is self-taught, and because I didn't learn it in academia, I have bad habits that I'm trying to get rid of. My notation is inconsistent. Sometimes I use a capital letter, sometimes a lowercase one, and sometimes I just forget an index. I'm trying to fix these problems. I believe I have explained everything, so you should have all the knowledge you need to derive all the formulas in the FlashAttention paper. You should also have a mental picture of how the attention calculation works block by block. I know that I could have spent 20 hours explaining things better, but I also have a life and a wife, so I cannot make 100-hour videos. There were also some interruptions while making this video. I had some wisdom teeth removed, and it took me more than a week to recover because it was so painful. So thank you, guys, for watching my video. I hope you learned a lot. Also, as you can see, Triton is new and there is not much documentation, so some of what I have said about Triton may not be totally correct. Because documentation is so scarce, everything I have learned about Triton comes from reading code written by others and trying to understand it.

And I think that's it, guys. I wish you a wonderful day, and see you next time on my channel.
