
YouTube Video

By Unknown

Summary

## Key takeaways

- **Memory coalescing can give a ~4x speedup**: When all threads in a warp access sequential memory addresses within the same burst section, the hardware automatically groups these accesses together instead of issuing individual memory requests, yielding roughly 4x the memory throughput. [44:12], [44:51]
- **Matrix multiplies are orders of magnitude faster**: GPU tensor cores are specialized for matrix operations; non-matrix operations are orders of magnitude slower, which is why most high-performance neural architectures structure their workloads around matrix multiplies. [21:22], [21:59]
- **Tiling reduces global memory reads by a factor of T**: The tiling strategy loads T×T tiles of the input matrices into fast shared memory, so each input element is read from slow global memory only N/T times instead of N times, drastically reducing memory bandwidth requirements. [52:55], [53:17]
- **Online softmax enables tile-by-tile Flash Attention**: Flash Attention uses an online softmax algorithm that maintains running max and normalizer values, allowing softmax to be computed tile by tile without ever materializing the full N² attention matrix in global memory. [07:50], [10:06]
- **Wave quantization causes performance cliffs**: When the number of tiles exceeds the number of SMs (e.g., 120 tiles on 108 SMs), the work runs in more than one wave, and in the last wave some SMs sit idle while others finish, causing dramatic utilization drops; the transition from a matrix size of 1792 to 1794 shows this effect. [01:02:34], [01:03:08]
- **Global memory is ~10x slower than on-SM memory**: On-SM registers and L1 cache take ~20 cycles to access, while L2 cache or global memory takes 200-300 cycles, creating a 10x speed gap that makes avoiding global memory traffic the primary optimization goal. [12:41], [12:59]
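You can check the wave-quantization takeaway with a little arithmetic. The sketch below is a simplified scheduling model, not how any particular kernel library assigns work. It makes two assumptions. First, each SM processes one output tile at a time. Second, the tile shape is a hypothetical 256×128, chosen because it reproduces the 120-tile example above. The SM count is the A100's 108.

```python
import math

NUM_SMS = 108              # SM count on an A100
TILE_M, TILE_N = 256, 128  # hypothetical output-tile shape (an assumption, not from the lecture)

def tile_count(m: int, n: int, tile_m: int = TILE_M, tile_n: int = TILE_N) -> int:
    """Output tiles needed to cover an m x n result; a partial tile still costs a whole tile."""
    return math.ceil(m / tile_m) * math.ceil(n / tile_n)

def wave_count(num_tiles: int, num_sms: int = NUM_SMS) -> int:
    """Rounds ("waves") of work if every SM processes one tile per round."""
    return math.ceil(num_tiles / num_sms)

def last_wave_occupancy(num_tiles: int, num_sms: int = NUM_SMS) -> float:
    """Fraction of SMs that have a tile to work on during the final wave."""
    leftover = num_tiles % num_sms
    return 1.0 if leftover == 0 else leftover / num_sms

if __name__ == "__main__":
    for size in (1792, 1794):
        tiles = tile_count(size, size)
        print(f"{size}: {tiles} tiles, {wave_count(tiles)} wave(s), "
              f"last wave uses {last_wave_occupancy(tiles):.0%} of SMs")
    # 1792: 98 tiles, 1 wave(s), last wave uses 91% of SMs
    # 1794: 120 tiles, 2 wave(s), last wave uses 11% of SMs
```

Growing the matrix from 1792 to 1794 adds almost no arithmetic. Under this model, though, the tile count jumps from 98 to 120, and the run needs a second wave in which most SMs are idle.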

Topics Covered

  • GPUs Trade Latency for Massive Throughput
  • Memory Is the Real Bottleneck, Not Compute
  • The Factory Model: Why Operator Fusion Works
  • Tile Alignment and Wave Quantization Explain Mysterious Performance
  • Flash Attention Combines Tiling and Recomputation to Eliminate N² Memory
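The online-softmax idea behind tile-by-tile Flash Attention can be sketched for a single query row. This is a plain-Python illustration under simplifying assumptions, not Flash Attention's actual kernel. It handles one query, uses scalar values per key, and takes a hypothetical `block_size` parameter. The code keeps a running max, a running normalizer, and a running weighted sum, and rescales all three whenever a new block raises the max. As a result, it never stores the full row of attention probabilities.

```python
import math

def online_attention_row(scores, values, block_size=4):
    """Softmax(scores)-weighted sum of values, visiting the scores one block at a time."""
    m = -math.inf  # running max of the scores seen so far
    l = 0.0        # running normalizer: sum of exp(score - m)
    acc = 0.0      # running numerator: sum of exp(score - m) * value
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s_blk))
        # Rescale the old statistics from the old max to the new max.
        # On the first block m is -inf, so scale is exp(-inf) == 0.0.
        scale = math.exp(m - m_new)
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * scale + sum(p)
        acc = acc * scale + sum(pi * vi for pi, vi in zip(p, v_blk))
        m = m_new
    return acc / l

def reference_attention_row(scores, values):
    """Standard version that materializes every softmax weight at once."""
    mx = max(scores)
    weights = [math.exp(s - mx) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

if __name__ == "__main__":
    scores = [0.5, 2.0, -1.0, 3.0, 0.0, 1.5, -2.0]
    values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    # The two results agree up to floating-point rounding.
    print(online_attention_row(scores, values, block_size=3))
    print(reference_attention_row(scores, values))
```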

Full Transcript

So hopefully everyone's having a good time with assignment one. It's due tonight; let us know if you need an extension. Assignment two is coming out soon. We're putting the finishing touches on some of the Triton stuff. Hopefully you'll enjoy it: you'll get to implement Flash Attention 2, or parts of Flash Attention 2, which I think will be nice.

So today we're going to talk about GPUs. GPUs are the thing that makes our language models go, so they're pretty critical to get right. And if you haven't really studied the hardware that makes your models run, they can seem pretty mysterious. So my goal today is to try to make CUDA and GPUs less magic. One of the things that I want to demystify (you don't have to understand the plot; I know there's a lot on the slide) is why GPUs get slow. And they get slow in very mysterious ways. I will try to talk through this plot towards the end of lecture. As you increase the size of your matrix multiplies, you might expect things to get either slower or faster. Instead you get these very unpredictable-looking, wavelike patterns, and you're like, why is my GPU fast at certain multiples of certain numbers and slow at others? It's very mysterious. We'll try to understand that.

The other thing is that we would like to understand how to make fast algorithms. I think almost all of you have heard of Flash Attention. It's the thing that makes much longer context possible by very cleverly computing the attention operation inside a transformer. And so maybe you would like to come up with new algorithms or new implementations like Flash Attention. What primitives and what components do we need to understand in order to be able to do that?

So those are the two learning goals of today. The first one is that by the end of the lecture you should feel comfortable with GPUs; you should understand how they work. The second one is that you should feel comfortable accelerating certain parts of your algorithms. If you make a new architecture, you should hopefully feel like you can try to accelerate it with CUDA.

Because hardware is not necessarily the domain in which I work, there are special resources that I have to give a lot of credit to. The main one is Horace He's blog, where he's got a lot of fun GPU facts that you can learn about. For example, why are matrix multiplies that are filled with zeros faster than ones that are not? You can learn that by going to his blog. There are also other resources that I've drawn from, like the CUDA MODE group and the nice TPU book from Google. If this topic interests you, I'd encourage you to go and look at those resources to learn more. This lecture is in some ways a shallow, but hopefully complete, coverage of the hardware.

So today we're only going to focus on the non-parallel parts of the hardware stack. We're going to study the GPU as a single accelerator in depth: how it works and its important parts. I'm also going to talk very briefly about TPUs. In some ways they're conceptually very similar to a GPU, so my discussion here is going to carry over. Then, once we understand the hardware and execution model of the GPU, we're going to try to understand what makes GPUs go fast on certain workloads and what makes them slow.

We're going to understand the performance. The last part is going to be almost a hands-on piece: I'm going to walk through Flash Attention. I'll take all the lessons that we've learned and show you, step by step, how it all comes together. So that's the last part of today's lecture.

Many of you have taken an NLP course, and these days an NLP course covers some amount of scaling laws, so you've probably seen this. This is just setting the context. We know that having more compute is helpful for training large language models. This is a pre-training scaling chart, but you could replace it with an inference scaling chart if you like. It's generally agreed that the more compute you have, the more processing you can do on your data. You can ingest more data and train larger models, and all of those lead to improved performance. So of course deep learning ideas matter, but what's really driven performance is faster hardware, better utilization, and improved parallelization. That's why hardware is important to understand.

And of course, once you think about compute scaling, you ask: how do we get compute scaling? How do we get our models to train faster? Consider the early days of semiconductor scaling and how CPUs got faster. They scaled under something called Dennard scaling. Under Moore's law, the number of transistors on a chip roughly doubled every couple of years. Alongside that doubling came Dennard scaling: smaller and smaller transistors could be driven at faster and faster clock speeds with lower and lower power, which in turn gave you more performance. That held from the 1980s into the 2000s, and then it tapped out. You can see in this chart by Hennessy and Patterson that single-thread performance (the blue dots here) basically started to taper off.

Of course, the number of transistors didn't really start falling off. You did have chips with higher and higher transistor densities, but that wasn't helpful: it wasn't giving you higher throughput on single threads. This means that we can't just do computation faster in absolute terms. What we have to make up for it with is parallel scaling. The story of scaling for deep learning and neural networks goes from single-thread scaling to parallel scaling. Single-thread scaling is just doing your computation faster in absolute terms. Parallel scaling means having a lot of workloads that are all computed at once.

This is one of my favorite compute scaling charts, from Bill Dally's keynote. He's showing the super-exponential increase in the number of integer operations per second, going from the earliest K20s to the H100. It's a really remarkable exponential or super-exponential curve. We have to understand how to take advantage of this curve to get the most out of our language models, so that's going to be our goal.

I've already hinted at this important difference. A CPU is something that I think everyone's familiar with once you start programming. It has this execution model: you have a program, and it goes through in a single thread, executing step by step. What do you need to support that kind of execution model? Well, you need big control units, and you need to generally run these things very quickly, because you have a lot of branching and a lot of conditional control logic.

So a CPU (this is an abstracted diagram) is going to dedicate a lot of its chip to large control and branch-prediction logic. It's going to run these very quickly because it doesn't have that many threads. There are CPUs with lots and lots of cores now, but compared to a GPU, it's almost nothing. In contrast, the GPU has tons and tons of compute units, or ALUs (the little green boxes), and much less of the chip is dedicated to control. A little bit of control logic orchestrates tons and tons of compute units operating in parallel. So this is the picture of what is being emphasized in a CPU versus a GPU.

If you look at the design goals, the two were designed for very different things. You can think of CPUs as optimizing for latency: I want to finish my tasks as quickly as possible. Say I have tasks T1 through T4, here on the right side. A CPU is going to try to finish each task as quickly as possible, so if you want any one of these tasks finished quickly, T1 is going to complete really quickly. In a GPU, you're optimizing for high throughput. I don't care about latency; I just want all of my tasks, in aggregate, to complete as quickly as possible. To support that, maybe you have lots of threads, and these threads can go to sleep and wake up very quickly. In the end, you finish your whole workload, T1 through T4, before the CPU does. That happens even though each individual task has higher latency. So they have different design principles and design goals.

Okay. So a GPU has a pretty different anatomy. I don't know if you've ever looked at a GPU layout diagram; I'll actually show you the chip figures in a moment. But here is the core idea, and it's an important concept behind a GPU: a GPU consists of many, many SMs, or streaming multiprocessors. You can think of a streaming multiprocessor as an atomic unit. When you're programming in something like Triton, you're going to operate at the level of an SM. Each SM contains many SPs, or streaming processors, and a streaming processor is going to execute a whole bunch of threads in parallel. One way to think about it is that an SM has a bunch of control logic. It can decide what to execute, and it can do branching, for example. SPs take the same instruction and apply it to many different pieces of data. So you can do tons and tons of parallel computation under this model. An SM is the granular unit of control, and an SP can do a lot of computation individually. Take an A100, which is the previous-generation GPU at this point: you've got 108 SMs, a lot more than most CPUs have cores. And each of these SMs has a very large number of SPs and specialized matrix-multiply units inside it. So that's the compute model. Was there a question? Sorry, yeah.

Going back to the slide before: is this GPU the same as that GPU? So the question was, is this GPU the same as that GPU? Yes. This is a cartoon version of it. You can think of each row as being an SM, and it's got its own control units. Each green block might be one of these green blocks here, like an FP32 processing unit inside the SM. And each SM can operate various pieces that it owns, like the tensor cores, to do computation. Cool.

Okay. There are going to be two important things to keep track of. You think of GPUs as computers, and they compute, but computation is actually only one of the two. Memory is arguably more important at this point, and it will continue to be more important for the performance profiles of our programs on the GPU. To understand memory, you have to understand the physical layout of the GPU and the chip. When you're operating at such fast speeds, the physical proximity of the memory starts to matter quite a bit. So I will show you how things are physically laid out and how that relates to how you should think about memory access and performance.

The closer a piece of memory is to each SM, the faster it's going to be. There are certain very, very fast kinds of memory, like L1 and shared memory, and those live inside the SM. That's going to be really fast. Registers are fast too. The things you're reading and writing very frequently, you're going to want to put into L1 and shared memory. Then there's the L2 cache. As you can see, there are these green areas, which are SMs, and then there are these blue areas on the GPU chip. The blue areas are L2 memory that sits right next to the SMs. They're not inside the SMs, but they're physically still quite close, and they're still pretty fast. They're a factor of 10 slower, but still reasonably fast. And then there's memory outside the chip itself. I think this is a 3090 card or something like that, or maybe a PCIe A100.

Oh, this is a PCIe A100. You've got your GPU here, and you've got DRAM living next to the chip. Reaching it means going physically outside of the chip and connecting. You can see these yellow connectors at the edges of this chip diagram. These are HBM connectors, and they connect to the DRAM chips outside of the actual GPU. You can also see how long it takes to access each level. The on-SM memory is much, much faster: something like 20 clock cycles to access. Accessing the L2 cache or global memory takes something like 200 or 300 clock cycles. This factor of 10 is going to hurt you real bad. Suppose a piece of computation requires you to access global memory. It might mean that your SM actually runs out of work: you've multiplied all the matrices you have, and now you just have to idle. So utilization won't be good. This will be a really key theme: thinking about memory is in some sense the key to thinking about how GPUs work.

In assignment two, you're going to actually be writing high-performance code for a GPU. So you have to think about the execution model, meaning how a GPU actually executes things. This is somewhat complicated, but not insanely so. There are three granularities you need to think about: blocks, warps, and threads, in order of narrowing granularity. Blocks are big groups of threads, and each block is going to be assigned to an SM. Think of each SM as a worker: it's its own autonomous unit, and a block gets assigned to an SM to process. So the block is the granular unit of work. Within these blocks are a whole bunch of threads, and each thread is a piece of a task that needs to be done. When these threads execute, they execute in groups called warps. You take a block, which is a collection of threads. The threads from that block execute 32 consecutively numbered threads at a time, and each such group of 32 is a warp. You can see what's happening in this diagram. You've got a bunch of blocks, and each block is assigned to a different SM. Within each block there are many different warps, and each warp consists of a whole bunch of threads. All of these threads execute the same instruction on different data. So that's the execution model. It probably seems mysterious what these blocks and warps and threads are. They will have important implications for our performance, for example in how we design CUDA kernels later. So hopefully you can remember this; I'll refresh your memory as we go.
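A serial emulation of a one-dimensional CUDA-style launch makes the block/warp/thread hierarchy concrete. This is a Python sketch, not real GPU code. `launch`, `add_kernel`, and `locate_thread` are hypothetical helpers, and the loops run one thread after another. On hardware, the 32 threads of a warp execute the same instruction together. The global-index formula mirrors CUDA's `blockIdx.x * blockDim.x + threadIdx.x`.

```python
import math

WARP_SIZE = 32  # threads per warp on NVIDIA GPUs

def locate_thread(global_id: int, threads_per_block: int):
    """Map a global thread id to (block, warp within the block, lane within the warp)."""
    block, thread_in_block = divmod(global_id, threads_per_block)
    warp, lane = divmod(thread_in_block, WARP_SIZE)
    return block, warp, lane

def launch(kernel, num_elements: int, threads_per_block: int, *args) -> int:
    """Serially emulate a 1-D grid with enough blocks to cover num_elements."""
    num_blocks = math.ceil(num_elements / threads_per_block)
    for block_idx in range(num_blocks):  # on a GPU, each block is scheduled onto one SM
        for warp_start in range(0, threads_per_block, WARP_SIZE):
            # One warp: up to 32 consecutively numbered threads running the same code.
            for thread_idx in range(warp_start, min(warp_start + WARP_SIZE, threads_per_block)):
                kernel(block_idx, threads_per_block, thread_idx, *args)
    return num_blocks

def add_kernel(block_idx, block_dim, thread_idx, x, y, out):
    """Every thread runs this same code on its own element (same instruction, different data)."""
    i = block_idx * block_dim + thread_idx  # like blockIdx.x * blockDim.x + threadIdx.x
    if i < len(out):  # threads in the last block may have no element to process
        out[i] = x[i] + y[i]

if __name__ == "__main__":
    x = [float(k) for k in range(100)]
    y = [2.0 * k for k in range(100)]
    out = [0.0] * 100
    print(launch(add_kernel, 100, 64, x, y, out))  # 2: two blocks of 64 threads cover 100 elements
    print(locate_thread(70, 64))                   # (1, 0, 6): block 1, warp 0, lane 6
```

The bounds check in `add_kernel` reflects a design choice. The grid is rounded up to whole blocks, so some threads in the last block have no element to process and must do nothing.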

Hopefully that's clear. So that was the logical execution model of a GPU, and if you understand that, you understand how GPUs execute things. There's also a logical memory model of a GPU. I'm not showing you the physical hardware now; this is just how you think about programming a GPU. There are registers, which are really fast storage for single numbers. You've got local memory, shared memory, and global memory, and each step down that memory hierarchy gets slower and slower. Code on the host can write to global memory. It can also write to constant memory, which is not something that's used too often. Each thread can access its own registers and its block's shared memory, but information that goes across blocks needs to be written to global memory.

This is actually quite important. It means that whenever you write a thread that executes something, ideally it operates on the same small amount of data. You load that small amount of data into shared memory, and all the threads are very happy accessing it. The thread terminates, and it's done. That would be a great execution model. Instead, suppose a thread needs to access data all over the place. It's going to have to access global memory, and that's very, very slow. This theme will come back as we talk about different ways of operating on a GPU.

Hopefully that's clear. That's the very high-level, four-slide overview of a GPU. If you have questions about how any of that works, feel free to ask me as I go on. Okay, so here's a side thread.

Last year I didn't cover this, because I think resources on TPUs were a little thin. But the nice TPU book, or website, that I mentioned at the start of the lecture came out, and it has a lot of nice details. I also talked to a few Google people about the TPU, and at a high level it's very, very similar to a GPU. So I want to talk for a moment about TPUs. You may never use a TPU, but I think it's important to understand that these alternative accelerators operate very similarly in many ways.

Here's a diagram of what a TPU looks like. There's something called a tensor core, and mentally you can think of a tensor core as similar to an SM, or streaming multiprocessor. Each of these is its own atomic unit that can operate on data. There's a scalar unit, which is basically a control unit, and it can also do arbitrary CPU-like things. You've got a vector unit that can operate on vectors: if you've got a vector and you want to operate entrywise on it, that's a good place to do it. And it has a very big, specialized part of the chip, called the MXU, dedicated to just doing matrix multiplies. Then it's got very fast memory: vector memory (VMEM) and scalar memory (SMEM). Both are very fast on-chip (or on-tensor-core) memories, and then there's high-bandwidth memory that lives outside of the chip. So hopefully you see the similarities to an SM. There's slow memory outside, very fast memory inside, and specialized hardware for matrix multiplication. The core structure is very much the same. The difference is in how the accelerators are connected together, which I'll talk about in the parallelism lecture next week. Also, notice that I didn't talk about warps or any of that other stuff. Tensor cores are in some ways very simple, because they're optimized to just do matrix multiplies. Unlike the GPU, the TPU tensor core doesn't attempt to do anything but that. So it's much simpler in architecture, but conceptually it's doing the same thing. Yes?

Is it called a tensor core because it's optimized for general tensors?

Yeah. So the question was, is it called a tensor core because it can operate on arbitrary tensors? It can operate on arbitrary tensors, including doing the indexing. But the operation that the MXU performs is a matrix multiply, so it would always be something like a batched matrix multiply operating on a tensor. So it's both a yes and a no, if that makes sense. They operate on tensors, but the operations they perform are always matrix multiplies, not more complicated tensor operations. Cool.

The GPU has been so successful because it scales up really easily. If you want more processing power, you just add more SMs. You don't have to worry about driving the clock faster and dealing with heat-dissipation problems. Programming-wise, CUDA is intimidating, but its programming model makes it less horrendous to program than it looks. Within each SM, threads execute the same instruction on a bunch of different pieces of data. That's conceptually easy to reason about; you can think through what it means. It's especially nice if you're operating over a matrix and doing very simple operations, which is exactly this kind of SIMD model. Finally, each of these threads is very lightweight, so threads can be stopped and started at any time. You might need to wait for another thread, or evict something and start another process. Either way, there's not much state associated with each thread. That cheap stopping and starting allows GPUs to get high utilization within each SM.

GPUs are, obviously, graphics processing units, and for much of their early life they were not used for scientific computing. But because they were programmable, researchers figured out how to use early NVIDIA GPUs to do fast matrix multiplies. This is one of the early papers on doing fast matrix multiplies with graphics hardware. It shows how you can hack things like the texture buffer to get the GPU to do matrix multiplies. So even without specific support for matmuls, researchers figured out how to do it. But in this day and age, NVIDIA and others have realized that matrix multiplies are special. If you're doing deep learning, most of your workload is matrix multiplies, so matrix multiplies are in some sense blessed operations.

This is a chart showing the teraflops per second of different generations of NVIDIA GPUs. The orange line is your matmul FLOPs, the performance you can get if you're doing matmuls. The blue line is your non-matmul FLOPs. A big gap opens up at the V100s, when NVIDIA started putting in tensor cores, which are specialized hardware for matrix multiplies. From there, matrix-multiply performance pulls gigantically ahead of non-matmul performance. So if you're going to design any sort of neural architecture, you have to have most of your workload be matrix multiplies (I was saying this in the architecture part as well). That's the thing that's orders of magnitude faster than any other operation you're going to be able to do on a GPU. If you make a non-matmul-based neural network, you're going to be in big, big trouble.
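Matmuls also suit the memory hierarchy described earlier, because they can be tiled. The summary's tiling takeaway says each input is read from global memory N/T times instead of N times, and you can check that by counting loads. Below is a pure-Python counting model, not a GPU kernel. The "shared memory" tiles are just Python lists, and the sketch assumes square N×N matrices with N divisible by the tile size T. It counts elements fetched from "global memory" in a tiled matmul. It compares that count with a naive matmul in which every output element re-reads a full row and a full column.

```python
def naive_global_reads(n: int) -> int:
    """Naive matmul: each of the n*n outputs reads one row of A and one column of B (2n values)."""
    return 2 * n ** 3

def tiled_matmul(A, B, T):
    """Tiled n x n matmul that also counts elements loaded from 'global memory'."""
    n = len(A)
    assert n % T == 0, "sketch assumes n is a multiple of the tile size T"
    C = [[0.0] * n for _ in range(n)]
    global_reads = 0
    for bi in range(0, n, T):          # (bi, bj) selects one T x T output tile,
        for bj in range(0, n, T):      # roughly the work of one thread block
            for bk in range(0, n, T):  # walk the shared dimension one tile at a time
                # Stage a T x T tile of A and of B in fast memory (counted as global loads).
                a_tile = [row[bk:bk + T] for row in A[bi:bi + T]]
                b_tile = [row[bj:bj + T] for row in B[bk:bk + T]]
                global_reads += 2 * T * T
                # The T*T*T multiply-adds below reuse the staged tiles; no further global loads.
                for i in range(T):
                    for j in range(T):
                        C[bi + i][bj + j] += sum(a_tile[i][k] * b_tile[k][j] for k in range(T))
    return C, global_reads

if __name__ == "__main__":
    n, T = 8, 4
    A = [[i + j for j in range(n)] for i in range(n)]
    B = [[i - j for j in range(n)] for i in range(n)]
    C, tiled = tiled_matmul(A, B, T)
    expected = [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(n)] for i in range(n)]
    assert C == expected
    print(naive_global_reads(n), tiled)  # 1024 256: T = 4 times fewer loads
```

Each A tile is loaded once for every output-tile column, which is N/T times. That gives 2N³/T loads in total, versus 2N³ for the naive version. The arithmetic itself is unchanged; only the global memory traffic shrinks.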

The last thing I want you to take away as general facts is this. Fast matmuls are one thing, but it's also important to remember how the different components of the GPU are scaling relative to each other. This is a very nice chart showing how quickly different components of the GPU, or let's call it the LM training stack, are scaling. The blue line is the connectivity from the GPU to the host, the server that it's attached to. You can use PCIe, NVLink, and all these fancy interconnects. They are growing, but somewhat slowly. This chart shows normalized scaling: bandwidth relative to the first generation of interconnects. The green line is the global memory speed. Going from GDDR to HBM2E is much, much faster; this is log scale, and it's about 100x. But that's still relatively slow scaling. The gray line here is compute scaling: the number of floating-point operations, counting matmul FLOPs. Compute has been scaling astoundingly fast, something like a 100,000-fold increase. So in the early days of scaling, maybe your problems were FLOPs-bound: you just didn't have enough FLOPs to do your matrix multiplications. But now, all the way to the right, the H100s are astoundingly fast GPUs. Your bottlenecks are probably going to end up being memory, because memory is not growing as fast. And this is not really going to change in the future. DRAM is very hard to scale, so you're going to keep getting this bigger and bigger gap. If you're ever designing hardware-efficient algorithms, you're going to have to think more and more about memory. We're going to keep a lookout for that, and I'm going to keep emphasizing it. It's one of the important themes in GPUs.

Okay, so I've been throwing lots
of GPU facts at you, especially if you haven't seen this recently and it's new to you. So just to recap: GPUs are massively parallel processing systems. The same instruction is applied across many different threads, and they have these things called SMs, which are kind of like cores, and there are many, many of them in a GPU. Compute, and matrix multiplies in particular, have scaled really fast, faster than memory, and that is an important characteristic to keep in mind about GPUs. But there is some fast memory; it's not as if everything is slow and there's nothing we can do. There's the memory hierarchy: some kinds of memory are very, very fast and other kinds are slow, and if we exploit this hierarchy, maybe we can make things really, really fast. Those are the things to remember about the GPU, and if you remember these facts, you'll be able to think pretty cleanly about the performance components that I'm going to talk about next. Any questions before I move on to the next

part? Okay, cool. So now you're all GPU experts, and what we would like to do is make machine learning workloads go very fast on a GPU. I'm going to start with this chart, and one of our goals will be to understand what

this chart is exactly. I think it'll be a good puzzle to get us motivated. Here we are multiplying square matrices together. The x-axis is the size of my square matrix multiplies, and the y-axis is the number of operations per second that I'm doing. So

you can think of the y-axis as hardware utilization. As I get bigger and bigger matrices, I get better and better hardware utilization, because I have more work to do, and that overwhelms the overhead of launching jobs and things like that. But there are all these weird things happening. You see one, two, three, four different lines, and each of these lines is wavy in a way that looks very unpredictable. So we would like to understand what exactly is going on with these lines. By the end of this section, my promise is that you will understand exactly each one of

these phenomena. You'll be able to say, "Yeah, that plot looks totally normal. That is a natural thing for a GPU to do." Okay, so the very first part: if you look at that plot, you will notice that it looks a little bit like this. If you've taken a systems hardware course, you should remember this as the roofline model. The roofline model basically says that if we're looking at throughput or utilization, there are two regimes.

There's going to be a regime that is memory-limited; that's the left side of this curve, in green. And then there's a part that is throughput-limited, on the right side. You can think of it this way: on the right side, we are fully utilizing our compute units. All the matrix multiply units are multiplying all the time. On the diagonal, we have some sort of memory bottleneck, so our ability to do computation is limited by our arithmetic intensity, the number of FLOPs per byte that we have. So we want to avoid being in the left-side region, where we're memory-bound, and we would like to be on the right side, where we're getting essentially full utilization of all of our compute units. That's the goal, and

hopefully this roofline model looks something like this: we've got this diagonal part, and then this flat part all the way at the top. So that's one part of the mystery. This turns out to be kind of complex. The simple way to say it is: let's make sure that we're not accessing memory unnecessarily, so that we have as few accesses to slow global memory as possible. But it turns out that to do that, we need a large array of tricks. There are a lot of different things that you could do that would mess you up and make you very slow.
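To make the two regimes concrete, here is a minimal roofline sketch in Python. PEAK_FLOPS and BANDWIDTH are made-up placeholder values, not the specs of any real GPU; the only point is that attainable throughput is the smaller of the flat compute roof and the bandwidth-times-intensity diagonal.

```python
# Roofline model sketch. PEAK_FLOPS and BANDWIDTH are hypothetical
# placeholder numbers, not the specification of any real GPU.
PEAK_FLOPS = 100e12   # assumed compute roof, FLOP/s
BANDWIDTH = 2e12      # assumed global-memory bandwidth, bytes/s


def attainable_flops(intensity: float) -> float:
    """Attainable FLOP/s for a kernel doing `intensity` FLOPs per byte moved."""
    # Diagonal (memory-bound): bytes/s * FLOPs/byte. Flat roof: peak compute.
    return min(PEAK_FLOPS, BANDWIDTH * intensity)


def ridge_point() -> float:
    """Intensity (FLOPs/byte) at which the diagonal meets the flat roof."""
    return PEAK_FLOPS / BANDWIDTH


if __name__ == "__main__":
    for ai in (0.125, 1, 10, 50, 200):
        regime = "memory-bound" if ai < ridge_point() else "compute-bound"
        print(f"{ai:>6} FLOP/B -> {attainable_flops(ai):.3e} FLOP/s ({regime})")
```

Below the ridge point (50 FLOPs/byte with these made-up numbers), adding compute does nothing; only moving fewer bytes per FLOP helps.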

The first one is not a memory bottleneck. I'll just mention it, since it doesn't come up too often; we'll get it out of the way and then talk about the remaining five items, which are really core to thinking about GPU performance. Okay, so the first thing that I want to talk about is conditionals. As I said before, the GPU execution model is something called SIMT: single instruction, multiple threads. Every thread in a warp is going to execute the same instruction, and it's going to do so

on different data. So what happens if I write a piece of code that looks like this? I have an if statement: if the thread index is less than four, do something (A); if the thread index is greater than or equal to four, do something else (X). It's a very simple conditional. If I run this on the GPU, what's going to happen is that I run the A instruction on four of my threads, and I actually pause my other four

threads, which are supposed to be executing the else branch. Then those other four threads come alive and execute X, while my original four threads go to sleep, and I just alternate executing each of these instructions. Why is that? I can't execute A and X at the same time on these different threads. As I said, every thread has to execute the same instruction. So conditional

statements within a single warp can be really damaging, because they force you to pause any threads that are not following the main control flow. Okay, so that was the only non-memory thing that I wanted to mention, and it should be fairly obvious that you should probably not be putting conditionals into your massively parallel compute unit. Once we've gotten that out of the way, the other tricks we need to consider are all memory

based. The first one I want to mention is lower precision. This is a big trick, an important trick, and you should do it all the time.

Going back to Bill Dally's plot, there's a sleight of hand here. It looks really good because the numbers keep going up, but if you look at what's driving GPU progress over all these years, you see that it's number representations. You go from FP32 to FP16 to INT8 and so on, and you get many orders of magnitude of gains just from having lower and lower precision in your GPU operations. Let me clarify why that's so important. If you have fewer bits in everything you're computing, your weights and so on, you have far fewer bits to move. So even if you're accessing these bits from global memory, they become much less of a concern. Let's give a simple example and think about the arithmetic intensity of a simple element-wise operation. So I'm

going to do it with a ReLU, x = max(0, x), applied to a vector of size n. Let's say I naively do this in float32. How many memory accesses do I have? I have to read my x, and I have to write the result of the x < 0 check. That's all in float32, so that's eight bytes per element. And how many operations do I do?

Well, I have to compute x < 0, which is one comparison operation, so I do one FLOP. That means I'm moving eight bytes per floating-point operation. If I now do this in float16, I haven't changed the FLOP count, but I've halved the memory access, so now I have four bytes per FLOP. In some sense, I've gotten double the memory bandwidth for free, assuming that I can get away with float16.
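The byte-per-FLOP accounting for the ReLU example can be written as a tiny Python function. It follows the counting used above (one read, one write, and one comparison counted as one FLOP per element); the dtype table and the function name are illustrative, not from any library.

```python
# Bytes moved per FLOP for an element-wise ReLU, x = max(0, x), using the
# accounting above: one read and one write per element, one comparison = 1 FLOP.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2}


def relu_bytes_per_flop(dtype: str, n: int = 1_000_000) -> float:
    width = BYTES_PER_ELEMENT[dtype]
    bytes_moved = n * (width + width)  # read x, write the result
    flops = n * 1                      # one comparison per element
    return bytes_moved / flops


if __name__ == "__main__":
    for dt in ("float32", "float16"):
        print(dt, relu_bytes_per_flop(dt), "bytes per FLOP")
```

The FLOP count is identical in both cases; only the bytes change, which is why halving the width halves the memory cost of a memory-bound kernel.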

This is a key part of how a lot of things are designed. In the assignment, you're going to play with various kinds of mixed-precision or low-precision training. A key point here is that not all parts of your network and your training algorithm should be put into low precision.

Let me give you an example with matrix multiplies. In a mixed-precision matrix multiply, your inputs are 16-bit, so they're low precision, and then you do the multiplication in full 32-bit. That's useful because for the intermediate computations, as you accumulate partial sums, you would like to be in high precision. So you accumulate with an FP32 accumulator, and the tensor core returns an FP32 result, which you can downcast back into 16-bit if you would like.
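A small stdlib-only Python emulation of why the accumulator precision matters. It rounds values through Python's `struct` half-precision (`'e'`) and single-precision (`'f'`) formats; this is a simplified model that rounds after every add, not how tensor cores are actually implemented, and the function names are hypothetical.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def dot(a, b, round_acc):
    """Dot product of fp16 inputs; the running sum is rounded by `round_acc`."""
    acc = 0.0
    for x, y in zip(a, b):
        # The product of two fp16 values is exact in Python's float64.
        acc = round_acc(acc + to_fp16(x) * to_fp16(y))
    return acc


if __name__ == "__main__":
    ones = [1.0] * 4096
    print("fp16 accumulator:", dot(ones, ones, to_fp16))
    print("fp32 accumulator:", dot(ones, ones, to_fp32))
```

Summing 4,096 ones in an fp16 accumulator gets stuck at 2048: above 2048 the spacing between fp16 values is 2, so 2048 + 1 rounds back to 2048. The fp32 accumulator returns the exact 4096, even though the inputs are the same 16-bit values.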

So we have our inputs in 16-bit, but things like the accumulation we might want to do in 32-bit. There are lots of different cases: there are operations that can use 16-bit storage, and there are operations that might need more precision, so you want to keep them in FP32 or FP16. And there are operations that need more range, like exp functions: if you don't have the dynamic range, they might blow up or zero out, so you might want to put those in BF16. There's a lot of careful

engineering that has to happen to make sure these models are actually stable when they're being trained in lower precision. But if you can do it, that's really great, because you've basically doubled the throughput of your bottleneck by going from 32 to 16 bits, if memory is your bottleneck. Okay, the next one, and I think this is what a lot of people think of when they say they're going to write a CUDA kernel, is operator fusion. It's both very intuitive and a fun, natural one to think about. One mental model of how a GPU and its memory work is this fun diagram of a factory

from Horace He. Imagine you have a factory, and your factory is your compute. It takes in little box widgets and outputs little triangle widgets. If you grow your compute but the conveyor belt that takes memory to compute has finite bandwidth, you're not going to be able to use your second factory. You're still capped by the speed at which you can transfer things from memory to compute, so you've got this bottleneck. Now, of course, you already knew that; I've been hammering in the memory bottleneck. But I think one insidious way to incur a ton of overhead without really realizing it is this left-hand-side computation pattern. Imagine the left side of this

plot is where the memory is, and the right side is your compute unit. To do computation, I start with a square and move my squares from memory to compute. I do some operation and turn them into triangles. Now I ship my triangles back to memory. Then I realize I need the triangles again, so I ship them back into the compute unit. Now the triangles become circles, and so on and so forth. I keep sending my data back and forth between compute and memory. You might call this a very naive approach. If you were just doing operations naively on

the GPU and shipping the results straight back to global memory, this is what you'd end up with. If you count the number of times a piece of data went back and forth, this is pretty terrible; you've incurred tons of memory overhead. Now, you should be able to realize that nothing here forces those round trips: I should be able to go square to triangle to circle to rectangle and ship only the rectangle back, keeping everything in the compute unit the whole time. That's the right-hand-side diagram, and this is the mental model of a fused kernel. You have a bunch of operations that happen on a piece of data in sequence, and instead of writing each result back to storage, I do as much of the computation as I can in one place and only ship the result back to memory when I have to. So that's the idea of kernel

fusion. Okay. There are some very simple examples of how, if you write naive code, you might get a naive set of kernel launches. So

here's an example. Let's say I write a little neural network module that takes in x and produces sin²(x) + cos²(x). Simple code. Now if I run this, the computation graph in PyTorch is going to look something like this, and it's going to launch a whole bunch of CUDA kernels: it takes in x and launches one CUDA kernel to compute sin(x), one to compute cos(x), then sin²(x), then cos²(x), and then sin²(x) + cos²(x). So there's a bunch of back and forth that has to happen in order to do this computation.
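The cost of that unfused graph can be counted with a toy Python model: plain lists stand in for global memory, and this is not PyTorch code. Each unfused "kernel" reads its inputs from global memory and writes its output back, while the fused version reads x once and writes the result once. In PyTorch itself, `torch.compile` is the tool the lecture points to for getting this kind of fusion automatically.

```python
import math

# Toy model of global-memory traffic for f(x) = sin(x)**2 + cos(x)**2.
# Traffic is counted in array elements moved to or from "global memory".


def unfused(x):
    traffic = 0

    def kernel(fn, *inputs):
        """One separate launch: read every input array, write one output array."""
        nonlocal traffic
        traffic += sum(len(a) for a in inputs)      # reads from global memory
        out = [fn(*vals) for vals in zip(*inputs)]
        traffic += len(out)                         # write back to global memory
        return out

    s = kernel(math.sin, x)
    c = kernel(math.cos, x)
    s2 = kernel(lambda v: v * v, s)
    c2 = kernel(lambda v: v * v, c)
    y = kernel(lambda a, b: a + b, s2, c2)
    return y, traffic


def fused(x):
    """One pass: read x once, keep intermediates in registers, write y once."""
    y = [math.sin(v) ** 2 + math.cos(v) ** 2 for v in x]
    return y, len(x) + len(y)
```

For n elements the unfused chain moves 11n elements (2n each for sin, cos, and the two squares, plus 3n for the add), while the fused version moves 2n for the same result.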

It's exactly the left-hand-side figure that I showed you before. But if you were a little smarter, and you either wrote your own CUDA kernel or used something like torch.compile, you could easily realize that those five operations don't depend on very much; they use only a little bit of memory. So you can fuse them into a single operation that does everything on the GPU, on a single thread per element, without sending things back to global memory. Really easy fusion operations like this can be done automatically by compilers. I just mentioned torch.compile: if you aren't already using it, you should strongly consider using torch.compile everywhere. We'll show you torch.compile in the assignment as well; it's pretty nice. Okay, so

I've gone through precision and fusion. If anyone has questions, let me know before I move on to recomputation and other kinds of tricks that we can do on the GPU. Okay, good. So another thing that we can do is called recomputation, and

recomputation is the idea of spending more compute to avoid having to do memory accesses. Remember back to your original backpropagation lecture; this diagram is actually from CS221. What do we do? We take our inputs at the very bottom, which are the yellow ones, and then we propagate activations upwards; those are also the yellow values on the tree. Then we compute the Jacobians backwards.

Those are the green values on the edges. Then, to compute my gradients, I multiply the Jacobians and the activations and propagate the gradients backward. Well, if you think about it, those yellow values from the forward pass have to be stored, and then they have to be taken from global memory, where I stored them, and put into the compute units. Mechanically, that's how it has to happen. But that might actually be a ton of memory input and output. Instead, you might be able to avoid this. So let me give you an example of how recomputation can

speed things up. Here's another silly function that I might write: I'm just going to stack three sigmoids on top of each other. Look at the left; that's the forward graph, and it should match your mental model of three sigmoids stacked on top of each other. Now, in the computation graph for this, I compute the sigmoids, store S1 and S2, which are the activations of the first two sigmoids, and produce my output. That's my forward pass.

Now, the backward pass here is kind of terrible. When I do my backward graph, I need to take S1 and S2, and I need to take the gradient coming backwards into this out box, and push them into the backward computation to get the gradient of x. So I need three memory reads and one memory write to compute the backward pass. Then, for the forward pass, I need one memory read of x and three memory writes for S1, S2, and out. Hopefully that's clear: that's a decent number of memory reads and writes, eight of them, and I have very low arithmetic intensity because I have no matrix multiplies at

all. So the idea of recomputation is to say: I don't want to store those activations at all. I'm not going to put them into memory; I'm just going to recompute them on the fly in my backward pass. So in my new forward pass, I don't store S1 and S2. I take x as input, compute my sigmoids, and get my output. That's one memory read for x and one memory write for out. Now in my backward pass, I don't have the activations anymore. So I'm going to get both dout, which is the backward signal coming in from above, and x, which is my input. Those are two memory reads. Then, on the fly in my SM's local memory, I compute each of these sigmoids and feed them into the backward graph; I recompute S1, S2, and out inside local memory. Because I do that, there are no additional global memory reads here. And then I have one memory write, which is dx.
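Here is the recomputation trick for the three stacked sigmoids, sketched per element in Python. The assumption is that on a GPU each thread would run this for one element, so only x and dout are read from global memory and only dx is written; S1, S2, and out are recomputed locally instead of being loaded. The function names are illustrative.

```python
import math


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def forward_no_store(x: float) -> float:
    """Forward pass that keeps no activations: read x, write only `out`."""
    return sigmoid(sigmoid(sigmoid(x)))


def backward_recompute(x: float, dout: float) -> float:
    """Backward pass that reads only x and dout, recomputing s1, s2, out locally."""
    s1 = sigmoid(x)      # recomputed instead of loaded from global memory
    s2 = sigmoid(s1)
    out = sigmoid(s2)
    # Chain rule with sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)).
    g = dout * out * (1.0 - out)   # gradient w.r.t. s2
    g = g * s2 * (1.0 - s2)        # gradient w.r.t. s1
    g = g * s1 * (1.0 - s1)        # gradient w.r.t. x
    return g                       # the single write: dx


if __name__ == "__main__":
    x, h = 0.3, 1e-6
    numeric = (forward_no_store(x + h) - forward_no_store(x - h)) / (2 * h)
    print(backward_recompute(x, 1.0), numeric)
```

The gradient matches a finite-difference check; the extra cost is three sigmoid evaluations per element, which is cheap if the kernel was memory-bound anyway.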

So now if you compare the two, I have 5/8 of the memory accesses for the exact same computation. The price we paid is that I have to recompute these three sigmoids. But if you were sitting idle anyway because you were memory-capped, this is a great trade-off. You would be very happy with it, because you've traded compute, which you have too much of, for memory bandwidth, which you had too little of. So this is one great way of trading something you have for something you need. And of course, it's the same trick as gradient checkpointing, that is, recomputing activations for memory savings, but here it's being done for a different reason: execution speed, not just because you're running out of memory. So

it's the same technique, but for a different goal. Okay, and this next one I think is a really interesting one, and not one that I knew until I started really looking into the hardware model of GPUs and DRAM. The slow memory, the global memory in a GPU, is DRAM, and it's actually very, very slow. To make it faster, certain optimizations are done at the hardware level. One of them is that when you go and read a piece of memory, you don't get just that value back; you get a whole chunk of memory back. This is called burst mode.

Let's say I tried to read the very first value of this big memory block. Instead of just giving me back 0, the memory would actually give me back 0, 1, 2, 3: four values at once. It's like, "Here you go, I'm sure you'll need 1, 2, and 3 in the future too." Each address space is cut up into what are called burst sections, and you're given the entire burst section rather than just what you asked for. This might seem very mystifying: why would the memory give you three extra bytes for free when you're just asking for one? There's a very interesting hardware reason. When you address into the memory, in order to send the signal out, those bytes have to be moved to an amplifier. That's the slow step, and once you've done it, you can get many bytes for free. That's why burst sections exist: they mask the more expensive step of moving the data from where it's stored to the amplifier. Regardless of the reason, this means we might be able to significantly accelerate our

memory accesses if the access pattern is good. If I want to read this entire block and I access it in random order, I'm basically going to have to issue a number of queries roughly equal to the length of the block. But if I check the very first value, I get the entire burst section at once, and if I then check number four, I get the second burst section at once. So I can basically get four times the throughput if I'm clever about my memory accesses and touch each burst section only once. This is called memory coalescing. If all the threads in a warp access addresses that fall within the same burst section, the hardware and programming model will group those queries. Instead of querying 0, 1, 2, 3 separately, it groups them into a single request for the burst starting at 0, and I can read out 0, 1, 2, 3 all at once from this burst-mode DRAM. Remember that a warp is 32 consecutively numbered threads, and memory accesses from a warp happen together. So when warps read from these burst sections, the hardware can get all four bytes at once rather than one at a time, and that gives 4x the throughput on your memory.
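A toy count of how many DRAM transactions one warp-wide load needs, in Python. The burst size of 4 addresses matches the example above but is a simplifying assumption, as is the rule that all of a warp's requests falling in one burst section merge into a single transaction; real hardware transaction sizes differ.

```python
# Count DRAM burst transactions needed to serve one warp-wide load.
# Assumptions (hypothetical, for illustration): a burst section holds BURST
# consecutive addresses, and all requests from one warp that fall in the same
# burst section are merged into a single transaction.
WARP_SIZE = 32
BURST = 4


def transactions(addresses) -> int:
    """Number of distinct burst sections touched by one warp's addresses."""
    return len({addr // BURST for addr in addresses})


def warp_addresses(stride: int, base: int = 0):
    """Address read by each thread when thread t reads base + t * stride."""
    return [base + t * stride for t in range(WARP_SIZE)]


if __name__ == "__main__":
    for stride in (1, 2, 4):
        print(f"stride {stride}: {transactions(warp_addresses(stride))} transactions")
```

With stride 1 the 32 threads touch 8 burst sections; with stride 4 every thread lands in its own burst, so the warp needs 32 transactions, which is the 4x gap described above.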

These are very simple things, but they're actually very important. Imagine I'm going to do matrix multiplications; that's a core thing you'd have to do a ton of if you were to implement, say, a neural network really from scratch in CUDA. In this case, imagine I can read my matrices in one of two ways. I can read by traversing the rows, so each thread traverses a row. Or I can read in column order, so each thread goes down a column. It turns out that the left one, where each thread traverses a row, moving across the columns, is going to be quite slow, because the memory reads are not going to be coalesced, whereas

if you go with the right side, where each thread goes down a column, so the threads are incrementing in rows, then these memory reads will be coalesced. You can think about it for a moment to see why this is true. When I first looked at this diagram, I thought, isn't it reversed? It's actually not; this is the correct one. The way to think about it is as follows.

Say I have a series of threads that each access their row from left to right. Each thread loads the very first element of its row; in the next time step, it loads the element in the second column, then the third column, the fourth column, and so on. So what happens at time step one? My first thread loads this point, the second thread loads this point, and then this point and that point. Those can't be coalesced at all, because they're reading different burst sections. And so

that means I have to read this entire chunk of memory in order to perform any operation. Instead, if the threads were going in the column direction, they would all be reading within a single burst section, so only one memory read operation needs to be performed and you get all the memory at once. This is a very low-level optimization, but it's very important: if your memory traversal order is all wrong, you will get much slower memory accesses than you really want.
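The row-versus-column question can be checked with the same kind of toy burst counter, in Python. This sketch assumes a row-major N x N matrix (element (r, c) at address r * N + c), 4 threads, and 4-address bursts; all three sizes are hypothetical.

```python
# Row-major N x N matrix: element (r, c) lives at address r * N + c.
# Count distinct burst sections touched per time step for two traversals.
# Hypothetical sizes chosen for illustration only.
N = 8
THREADS = 4
BURST = 4


def bursts_per_step(pattern: str):
    counts = []
    for step in range(N):
        if pattern == "row":
            # Each thread walks its own row: thread t reads (t, step).
            addrs = [t * N + step for t in range(THREADS)]
        elif pattern == "column":
            # Each thread walks its own column: thread t reads (step, t).
            addrs = [step * N + t for t in range(THREADS)]
        else:
            raise ValueError(f"unknown pattern: {pattern}")
        counts.append(len({a // BURST for a in addrs}))
    return counts


if __name__ == "__main__":
    print("threads walk rows:   ", bursts_per_step("row"))
    print("threads walk columns:", bursts_per_step("column"))
```

Walking rows costs 4 burst sections per time step, one per thread; walking columns costs 1, because at each step the threads read adjacent addresses in the same row.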

Okay, so that brings us to the very last, and big, one: tiling. Tiling is the idea that you would like to group together memory accesses in order to minimize the amount of global memory access you have to do. To explain this one, I'm going to go through an example of a matrix multiply. Hopefully I'll be able to explain why a naive algorithm for matrix multiply is going to be very problematic, and then I'll give you a tiled version of the same idea, and hopefully you'll see why that reduces the number of global memory reads you have to do. So let's start with this very simple matrix multiply

algorithm. I've got this M matrix on the left side and my N matrix on the top. In order to compute the matrix-matrix product, I'm going to traverse the rows of M and the columns of N, take the inner product, and store that into

the corresponding position of this P matrix. I've written out each of the threads, (0,0), (0,1), (1,0), and (1,1), according to where they store their outputs, along with the order in which they access each of the individual elements. Now notice that the memory access here is not coalesced: the rows of M are going to be accessed in a non-coalesced order. And I have repeated memory accesses: M0,0 is accessed by the first thread and again by another thread, and N0,0 and N1,0 are each accessed by two different threads. So these values are being read over and over from global memory into many different threads.
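A Python sketch of the naive algorithm that counts how often each element is read from global memory. Each output (i, j) plays the role of one thread that reads a full row of M and a full column of N; the `load` helper and the read counter are illustrative stand-ins for global-memory traffic, not a real GPU API.

```python
def naive_matmul(m_mat, n_mat):
    """Naive P = M @ N: 'thread' (i, j) reads row i of M and column j of N
    straight from 'global memory'. Returns P and a per-element read count."""
    n = len(m_mat)
    reads = {}  # (matrix name, row, col) -> times read from global memory

    def load(name, mat, r, c):
        reads[(name, r, c)] = reads.get((name, r, c), 0) + 1
        return mat[r][c]

    P = [[0.0] * n for _ in range(n)]
    for i in range(n):          # one "thread" per output element (i, j)
        for j in range(n):
            acc = 0.0
            for k in range(n):
                acc += load("M", m_mat, i, k) * load("N", n_mat, k, j)
            P[i][j] = acc
    return P, reads
```

For an n x n product every element is read n times, 2n³ reads in total. In the 2 x 2 case, M0,0 is read twice, once by each thread in the first row of P.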

This is going to be potentially very slow. So the question is: can we avoid having so many global memory reads and writes? Let me explain the ideal outcome first and then the algorithm. Ideally, I would spend one chunk of time loading pieces from global memory into shared memory, where things are fast, do a ton of computation in shared memory, and then be done with that piece of data. That's the ideal outcome: I've minimized my global memory accesses. So

how can I do this in the matrix multiply world? I'm going to take both the M matrix and the N matrix and cut them up into tiles. Here I've cut them into 2x2 tiles, so I've got a 2x2 M tile and a 2x2 N tile: smaller submatrices within each matrix. Now imagine that shared memory is big enough to fit these submatrices within each SM. This gives a very simple algorithm for the computation. I'm first going to load, say, the M0,0 tile at the top left, and I'm also going to

load my N00 tile um into shared memory here. Right? Um, so now I have these partial sums that I can compute. I can take, you know, the the row product of

m00 z m01 with n z n 0 and I can increment that into p 0. I can do the same with all the different submatrices that I can fill out over here. Right now

then once I'm completely done sort of processing these two tiles, then I can load a new tile over here. And then I can repeat that computation with my M tile and my N2.0 tile loaded into shared memory. And then I can sort of increment

my partial sums in P. Right? So now I've really sort of consolidated and reduced the amount of global memory access I have to do. Right? I I load as much memory as I can at once into shared memory. I do all of my sort of submatrix

computations on that tile that I can and then I move on to the next one. Right?
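
The tile-by-tile loop above can be sketched in pure Python. Each output tile stands in for one thread block, and the small `m_tile`/`n_tile` copies stand in for shared memory; the function name, the `global_loads` counter, and the assumption that T divides n are illustrative choices, not a GPU implementation.

```python
def tiled_matmul(M, N, T):
    """Tiled P = M @ N for square n x n lists of lists, with n divisible by T.

    Each (bi, bj) output tile mimics one thread block: at every step t it copies
    one T x T tile of M and one T x T tile of N (the "shared memory" copies),
    accumulates partial sums into the P tile, then moves on to the next pair.
    """
    n = len(M)
    assert n % T == 0, "sketch assumes the tile size divides n"
    P = [[0.0] * n for _ in range(n)]
    global_loads = 0
    for bi in range(0, n, T):            # output tile row (one "block")
        for bj in range(0, n, T):        # output tile column
            for t in range(0, n, T):     # walk along the shared dimension
                # Load the M tile at (bi, t) and the N tile at (t, bj) once.
                m_tile = [row[t:t + T] for row in M[bi:bi + T]]
                n_tile = [row[bj:bj + T] for row in N[t:t + T]]
                global_loads += 2 * T * T
                # All reuse now happens on the fast local copies.
                for i in range(T):
                    for j in range(T):
                        P[bi + i][bj + j] += sum(
                            m_tile[i][k] * n_tile[k][j] for k in range(T))
    return P, global_loads


n = 4
A = [[float(i * n + j) for j in range(n)] for i in range(n)]
B = [[float(i - j) for j in range(n)] for i in range(n)]
P, loads = tiled_matmul(A, B, T=2)
ref = [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(n)] for i in range(n)]
print(P == ref)  # True
print(loads)     # 64 element loads, versus 2 * n**3 = 128 for the naive version
```

With T = 2 the number of element loads from "global memory" halves compared with the naive loop, which is the factor-of-T saving in miniature.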

And of course the other nice thing is that because I'm loading an entire tile, I can traverse these submatrices in whatever order I want, column-major or row-major, so I can coalesce all the memory accesses whenever I'm loading a tile from global to shared memory. So there are wins all around when we tile our accesses.

We can do a little bit of tiling math. Let's say we've got a matrix A, a matrix B, and a matrix C; the full matrices are square, of size N, and I have a tile of size T. Oh yes, question: on the previous slide, why is M0,0 loaded again in step three? In that case I just wrote it for completeness. M0,0, let's say, is just stored in shared memory; we keep it cached and don't load it again. That's definitely just there for completeness. You wouldn't actually discard and reload the matrix; that would be kind of insane. Cool.

Okay, so we can do very simple tiling math to think about what's happening. Say I'm going to do an N-by-N matrix multiply. If I do a non-tiled matrix multiply, just going over rows and columns, then every time I process an input it has to come from global memory, so each input is read N times from global memory. If I do a tiled matrix multiply, the global reads operate over a tile, so I'm reading each input N/T times from global memory, and then T times from shared memory within each tile. Of course, I'm doing a matrix-matrix multiply, so I can't reduce the total number of reads; I have to read all the matrix elements. But I can shift the reads into fast shared memory. That's great, because if we have a big shared memory that can store big tiles, that's a factor of T reduction in the total amount of data that has to come from global memory. So tiling can be a really powerful idea when you're operating over matrices and can move things into shared memory. Tiling is also quite complex.
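
The read-count arithmetic fits in a couple of lines. In this sketch the sizes (N = 4096, T = 128) and the 4-byte fp32 elements are illustrative assumptions; the point is only that global traffic shrinks by exactly the tile size T, while each element is still read N times in total once the shared-memory reads are included.

```python
def reads_per_element(n, tile=None):
    """Times each input element of an n x n matmul is read from global memory:
    n times without tiling, n // tile times with tile x tile tiles (each tiled
    load is then reused `tile` times out of shared memory)."""
    if tile is None:
        return n
    assert n % tile == 0, "sketch assumes the tile size divides n"
    return n // tile


def global_bytes(n, tile=None, bytes_per_elem=4):
    """Total global-memory traffic for the two n x n inputs (fp32 assumed)."""
    return 2 * n * n * reads_per_element(n, tile) * bytes_per_elem


n, T = 4096, 128
print(reads_per_element(n), reads_per_element(n, T))  # 4096 32
print(global_bytes(n) // global_bytes(n, T))          # 128, i.e. a factor of T
```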

That complexity is the source of many confusing things about GPU and matrix multiply performance. One thing that happens once we start tiling is that we have to think about discretization.

Imagine I have a tile size of 128. That seems like a nice, round tile size. When I have a full matrix of size 256, that's great: it's 2x2 tiles, and things load nicely. Now let's say the matrix has 257 columns instead. This is a bad time, because I need six tiles to cover the matrix, and the two tiles on the right are very sparse; there's just not much in there. The problem is that each tile is going to be assigned to an SM: each tile is a block, and the threads operate within that tile. So those two tiles on the right aren't going to be doing very much at all, and those SMs are basically going to sit idle.
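
A small helper makes the discretization visible: it covers a matrix with square tiles and reports how full each tile is. The function name and the square-tile assumption are illustrative; 128 is the tile size from the example.

```python
import math


def tile_grid(rows, cols, tile=128):
    """Cover a rows x cols matrix with tile x tile tiles and report, for each
    tile, the fraction of it that holds real matrix elements."""
    n_r, n_c = math.ceil(rows / tile), math.ceil(cols / tile)
    fills = []
    for r in range(n_r):
        h = min(tile, rows - r * tile)        # rows actually present in this tile
        for c in range(n_c):
            w = min(tile, cols - c * tile)    # columns actually present
            fills.append(h * w / (tile * tile))
    return n_r * n_c, fills


count, fills = tile_grid(256, 256)
print(count, fills)  # 4 [1.0, 1.0, 1.0, 1.0]
count, fills = tile_grid(256, 257)
print(count)         # 6
print(min(fills))    # 0.0078125 (= 1/128): each right-hand tile holds one column
```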

If you were compute-capped, you would have wanted to distribute the load more evenly between SMs, so you have to optimize your tile sizes to try to avoid these kinds of scenarios. But in reality, a lot of complex things go into setting the tile size. Remember, you have to coalesce your memory accesses, so you have to think carefully about that. You must not exceed your shared memory size, so the tiles can't be too big. And you want to divide the matrix dimension evenly, or as close to evenly as possible, so you don't end up with an underutilized SM at the very end. Yes? With smaller sizes, would GPUs do something like fetching the next tile beforehand, and if so, at what level would that happen? So you're asking whether you can overlap memory reads and computation, and yes, that's naturally done in GPUs; they're always trying to use the available bandwidth. As long as shared memory is available, they can go and put things into it. The issue is that whenever you're effectively utilizing your SMs, you're basically maxed out on your shared memory; that's the bottlenecked resource, so in some sense there's no place to prefetch into. Cool.

Okay, the other thing that is very complex (we're getting into the weeds here) is the interaction between tiling and burst sections. Imagine I have a matrix layout like this, where I have my nice burst sections and each burst section lines up nicely with a row of a tile. To read this tile, all I have to do is fetch four different burst sections, and I've gotten the entire tile. Now imagine what happens if the matrix gets one element wider. Because of the way the matrix is laid out, my tile rows start to spill over the burst sections. When I load my tile, I load the first row, and that's really great: I get the entire first row as one burst section. But the second row actually belongs to two different burst sections, so I have to do two reads to get it, and so on for the rows after it. I've essentially doubled the number of memory accesses, because a single extra element per row shifted the alignment of my layout relative to the burst sections. So if your tiles or your matrix sizes aren't multiples of your burst section, you can easily end up in situations like this, where the rows don't line up with the burst sections and you've doubled the amount of memory access you have to do. The way to get around this is padding, so that you get nice round matrix sizes and your burst sections line up with the size of your tiles. This is getting very into the weeds, but if you really want to squeeze all the performance out of your matrix multiplies, these are the kinds of things you have to think about.

And you will get bitten by this if you're not thinking about it.
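
A toy model of the burst picture: a row-major matrix whose burst sections are 4 elements wide and aligned to multiples of 4, read one 4 x 4 tile at a time. The burst width, the alignment rule, and the function names are assumptions for illustration (real burst sizes are fixed by the hardware and measured in bytes). Widening each row by one element pushes the tile read from 4 bursts to 7, and padding the row length back up to a multiple of the burst width restores 4.

```python
def bursts_for_tile(row_stride, tile_rows, tile_cols, burst=4, col0=0):
    """Count the aligned burst sections touched when reading a tile from a
    row-major matrix whose rows start `row_stride` elements apart.
    Bursts are assumed to begin at element offsets that are multiples of `burst`."""
    total = 0
    for r in range(tile_rows):
        first = r * row_stride + col0    # offset of the first element in this tile row
        last = first + tile_cols - 1     # offset of the last element in this tile row
        total += last // burst - first // burst + 1
    return total


def round_up(n, multiple):
    """Pad a dimension up to the next multiple (e.g. of the burst width)."""
    return -(-n // multiple) * multiple


tile = 4
print(bursts_for_tile(row_stride=4, tile_rows=tile, tile_cols=tile))  # 4: one burst per row
print(bursts_for_tile(row_stride=5, tile_rows=tile, tile_cols=tile))  # 7: every row after the first straddles two bursts
print(bursts_for_tile(row_stride=round_up(5, 4), tile_rows=tile, tile_cols=tile))  # 4 again after padding to 8
```

Padding wastes a few elements of memory per row, but it buys back aligned reads for every tile.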

And of course, things like torch.compile and all the CUDA optimizations for matrix multiplies are doing exactly the kinds of things I just talked about; that's how you get better performance. All this matrix complexity ends up in situations like this one. I'm reading Andrej's tweet here: the most dramatic optimization to nanoGPT is to increase the vocab size from 50257 to 50304, which is the nearest multiple of 64, which gives you much higher occupancy. Careful with your powers of two. That's a 25% speedup from adding 47 dimensions to your vocab. How does that happen?

That brings us back to the mystery. I was dragging you through all the GPU details in the hope that you'd have a full understanding of all the performance characteristics, and in some sense the payoff is that I now get to explain how this chart comes to be, and by the end you won't find matrix multiply performance so mysterious or scary. The very first part is very simple: we understand compute intensity.

This is exactly the roofline that I pointed out at the very beginning. Up until here, at about 1536, there's just not enough matrix multiply work to do. Just loading the matrices and doing the very basic IO you have to do becomes the bottleneck below this point, so throughput falls through the floor.

In that regime, you just don't have enough memory bandwidth to support your compute units. Now on the right side, in theory, if I draw the upper envelope, that's the maximum achievable performance: it's possible up here to saturate all of my compute units and get really great performance. But if you mess up your matrix sizing, you can end up in these really weird places, and within each one of them you can end up in a weird trough. So let's think a little bit about why there are all these different places you can end up.

The very first thing, this first line here, is a tiling alignment issue. If you look at the multiples here, I've now colored each of these lines by the divisibility of the matrix size, that is, the number by which it's divisible. If it's divisible by 32, you're in good shape: you're in these purple dots up here. If you're divisible by 16, you're actually still up here; there are two colors.

Then if you're green, k equals 8, you're up here. If you're orange, k equals 2, and if k equals 1, you're all the way down here.

If your size isn't divisible by much of anything (don't pick prime dimensions), you're not going to get very good throughput on your matrix multiplies. A big part of this is that once you get to k equals 2 and k equals 1, you're basically forcing a situation where you can no longer read tiles in a way that's nicely aligned with your burst reads, and that leads to some serious issues. So that's one part of the mystery, but I think another part of the mystery remains.
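
One way to compute the color k of a point in this plot is as the largest power of two dividing the matrix size. That reading, and the cap at 32 (the top color mentioned above), are assumptions of this small helper rather than the plot's documented definition.

```python
def alignment_class(n, cap=32):
    """Largest power of two, up to `cap`, that divides the matrix size n."""
    k = 1
    while k < cap and n % (2 * k) == 0:
        k *= 2
    return k


for size in (1792, 1793, 1794, 1800):
    print(size, alignment_class(size))
# 1792 32
# 1793 1
# 1794 2
# 1800 8
```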

Within this orange line, if you zoom in here, you see this giant drop from this point all the way down to this point, and you're left wondering what happened. How could I lose so much performance by increasing my dimension by two? Let's look at these numbers. I think this is a fun puzzle, so I'm just going to walk you through it. This happens when you transition from 1792 to 1793 or 1794; let's say 1794, so that it's still divisible by two. Why does that happen?

Let's say we're using a tile size of 256x128. That's a pretty natural size. As a fun fact, the matrix multiply units in these GPUs naturally operate on matrices of roughly size 128, so 256x128 is a very nice tile size. How many tiles are there? There are 7 x 14 tiles, because we're dividing each dimension of the matrix by the size of our tiles (1792/256 = 7 and 1792/128 = 14). That's a total of 98 different tiles. If we increase the size even by one, we have to round up each of those counts (to 8 and 15), so we get a lot more tiles: 120 of them. So we've increased the number of tiles by quite a bit, and some of them have lower utilization, which is bad. But even worse, an A100 has 108 SMs. If you go all the way back to the GPU execution model, SMs execute in parallel; they're the execution units. So when you have 98 tiles, they all go and run.

You can dispatch them all: all the SMs are running, and you get great utilization. Once you go to 120 tiles, you've got more tiles than SMs. 108 of them execute, then you go back and execute the remaining 12, with most of the SMs sitting idle, and wait for those to complete. That's going to be really bad. If you look at your utilization, you get good utilization for a while, then you drop off a cliff, and then you finish up your job. This is called wave quantization. Ideally, the number of tiles is either much bigger than the number of SMs, or at least not just barely over it like this, where you've caused this additional quantization error. Cool. All right, I know these are low-level details, but I've been saying through many classes that language models and deep learning are about attention to detail.
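
The tile and wave counts above can be reproduced in a few lines. This sketch uses the 256 x 128 tile and the A100's 108 SMs from the example, and it also assumes each SM works on exactly one tile at a time; real kernels can keep several thread blocks resident per SM, so treat this as a simplified model of the cliff rather than an exact scheduler.

```python
import math


def waves(m, n, tile_m=256, tile_n=128, num_sms=108):
    """Tiles needed for an m x n output with tile_m x tile_n tiles, how many
    waves those tiles take on num_sms SMs (one tile per SM at a time assumed),
    and the fraction of SMs that are busy in the final wave."""
    tiles = math.ceil(m / tile_m) * math.ceil(n / tile_n)
    n_waves = math.ceil(tiles / num_sms)
    last_wave = tiles - (n_waves - 1) * num_sms
    return tiles, n_waves, last_wave / num_sms


for size in (1792, 1794):
    tiles, n_waves, last_frac = waves(size, size)
    print(size, tiles, n_waves, round(last_frac, 2))
# 1792 98 1 0.91
# 1794 120 2 0.11
```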

And this kind of attention to detail is what allows people to scale up LLMs to really large sizes and get great performance.

So it's worth knowing even if you're not going to do systems engineering. So what were the tricks? Here are the key ideas. The first one is that you have to reduce the amount of memory access, and there are lots of ways to do it. You can do coalescing, so that you reuse reads you're getting for free. You can do fusion, so that you fuse multiple operations together and avoid unnecessary reads and writes. You can move data into shared memory, so even if you're going to do reads, they come from much faster memory; that's the tiling tricks. And finally, you can trade memory for other resources that you do have: you can trade it for compute, which is recomputation, or for numerical precision or stability, which is quantization. So there's a big bag of tricks for getting performance out; you just have to be really mindful of the role that memory plays in the performance of a GPU. That's the key thing to get the most out of it. Cool. Any questions on that before I move to the final part, with FlashAttention?

Okay, good. All right, so now I'm going to put it all together. I want to make it so that all the tricks I taught you aren't random, disconnected facts about GPUs: they're part of the standard performance optimization toolkit, and FlashAttention and FlashAttention-2 will hopefully show you how it all comes together to build one of the foundations of modern high-performance transformers.

So we know that FlashAttention dramatically accelerates attention, and most of you probably know that's done through some CUDA kernel magic, but maybe you don't know all the details. One part of what the paper says is that if you start from attention in an unoptimized PyTorch transformer implementation, fuse the kernel, and do a few other things, you can get significant speedups. From the paper, they say they apply two established techniques, tiling and recomputation, to overcome the technical challenge of computing exact attention in sub-quadratic HBM accesses. So it's not sub-quadratic computation, because you can't do that; you still have to compute attention in general. But they get sub-quadratic accesses to the high-bandwidth, or global, memory. That's really the key: if memory is the bottleneck, you want to make that part not quadratic, so that at least you pay for the quadratic cost with your compute rather than with your memory.

Just for a really quick recap, since at this point you've implemented attention many times in many classes: it's matrix multiplies involving Q, K, and V, with a softmax in between. The matrix multiplies are pretty simple and can be done with tiling; I've shown you examples like that. What's different about attention? There's a softmax, and that's going to be the really tricky bit. Once we can deal with the softmax, all the matrix multiply things I was talking about just come into play. The matrix multiply, as I said before, is exactly what I taught you. If you look at Figure 1 from the FlashAttention paper, this is really just a simple tiled matrix multiply. You see the K matrix and the Q matrix cut up into small blocks; small blocks are copied to SRAM and multiplied, and then they're accumulated or sent to HBM, where you do softmaxes, and then you multiply with V. So this is all really simple in terms of the QKV matrix multiply, but now we have to think about what's going on with the softmax. Sorry, I'm going to roll back one step. What's the problem with the softmax? It's a global operation.
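
Written out for one row x_1, ..., x_N of the attention score matrix, the standard numerically stable softmax needs both the row maximum m and the full row sum d before any single output can be normalized, which is what makes it a global operation:

```latex
m = \max_{1 \le k \le N} x_k, \qquad
d = \sum_{k=1}^{N} e^{x_k - m}, \qquad
y_i = \frac{e^{x_i - m}}{d}, \quad i = 1, \dots, N
```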


The softmax in attention operates row by row: you have to sum the entire row to compute the normalizing term of the softmax. That's very problematic. If I have tiles, ideally I want to do everything within the tiles; I never want to have to write back to the big matrix. So I need a softmax that can be computed online within each tile, because I want to do as much computation within each tile as possible. The key thing here is to use what's called the online softmax.

So what is that? Normally, in the batch version of the softmax, you take all of x_1 through x_N, exponentiate them, sum them, and divide. That's what you'd do in a normal softmax. You would also compute the maximum value and subtract it to make this numerically stable, so the left side is the standard numerically stable softmax. The online softmax, which I've taken from Milakov and Gimelshein (2018), comes from realizing that, via a telescoping-sum kind of argument, you can pull out the current running normalizer, the sum of e^(x_i - max_k x_k) over the elements seen so far. What you do is maintain the current max m_j over x_1 through x_j, where j is the current iteration, and also maintain a running sum d_j. If the max updates, I multiply d_(j-1) by a correction term, e^(m_(j-1) - m_j), and then add the new term e^(x_j - m_j). So d_j tracks the denominator online, and at the end the final d_N is itself exactly the normalization term I need to get the normalized y_i. The key thing is that this can be done online. I don't need x_1 through x_N up front; all I need is a stream of x_1 through x_N. And that's really key, because I can now compute the softmax tile by tile.
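
Here is a small Python sketch of the online softmax that consumes a row a tile at a time rather than one element at a time (a direct generalization of the recurrence; the tile width and the function names are assumptions for illustration). Only the running max and running normalizer are carried between tiles, and the result matches the batch softmax up to floating-point round-off.

```python
import math


def online_softmax(xs, tile=2):
    """Softmax of one row, reading it `tile` elements at a time.

    Between tiles only two scalars are carried (Milakov & Gimelshein, 2018):
      m: the max of everything seen so far
      d: the sum of exp(x_k - m) over everything seen so far
    When a tile raises the max from m_old to m_new, the old sum is rescaled by
    exp(m_old - m_new) so that it is again measured relative to the current max.
    """
    m, d = -math.inf, 0.0
    for start in range(0, len(xs), tile):
        chunk = xs[start:start + tile]
        m_new = max(m, max(chunk))
        d = d * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in chunk)
        m = m_new
    # Second pass: normalize every element with the final (m, d).
    return [math.exp(x - m) / d for x in xs]


def batch_softmax(xs):
    """Reference: the standard numerically stable softmax over the whole row."""
    m = max(xs)
    d = sum(math.exp(x - m) for x in xs)
    return [math.exp(x - m) / d for x in xs]


xs = [1.0, 3.0, -2.0, 5.0, 0.5]
# Largest disagreement with the batch version: floating-point round-off only.
print(max(abs(a - b) for a, b in zip(online_softmax(xs), batch_softmax(xs))))
```

As in the original online-normalizer formulation, writing out the normalized values still takes a second pass over the row; the FlashAttention forward pass avoids storing anything N-by-N by rescaling its output accumulator instead.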

Within each tile, I can run this algorithm, and that lets me compute the partial softmax for that tile. Then, if I need to, I can write back all the components I'm keeping track of, and that's all I need to do this computation. I never have to materialize the full N-squared matrix in order to compute the softmax. And that's basically it. Once you have that, you've put it all together, and you can get the forward pass of FlashAttention.

If you go and look at the FlashAttention-2 paper, which describes what we're going to ask you to implement (you'll be following through these steps), you'll see exactly this idea. First you have your QK matrix multiply, and this is tiled: these are little tiled chunks, and they get multiplied. How do I compute the softmax? I maintain a running value of these exponentiated sums, keep incrementally updating it, and correct for the maximum terms. By doing that, I can compute all the necessary quantities tile by tile, going from one tile to the next, and then multiply once again by tiles of V, and that gives me my full softmax output. Yes? So we won't be able to compute that output until we've done the QK multiplication across all tiles, right, so we do have to double back on each one? So the question was: you can't compute this until you're done with all the tiles, so you have to double back on all of them, because you won't have built up that denominator sum until you've seen every tile. That's right: before you can output your softmax, you have to go through all the tiles. But let's say I go through all the tiles once. At that point I have all the components I need to directly output the softmax, and I don't have to do any recomputation, because I already have the normalizer terms from going through each of these tiles. At the end of going through all the tiles, I've built up the final ℓ, the sum of all the exponentiated terms, and I already have that in shared memory for this last tile. That allows me to exponentiate, divide, and return all the components.

Okay, the backward pass I'm not going to cover. You can do recomputation tile by tile, which lets you avoid storing the softmax. Remember, I always want to avoid storing anything of size N squared. Here I've been clever with the tiles so that I don't have to store any of the N-squared components when I'm computing, for example, the softmax. But in the backward pass, if I stored the activations, that would already be something N-squared-sized. So I don't want to store my N-squared activations; I'm going to recompute them on the fly, tile by tile, when I do the backward pass. That's another really key trick they use to make the backward pass possible. Otherwise it's fairly standard: it's really the same thing as computing the gradients, just tile by tile.
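
Putting the pieces together, here is a pure-Python sketch of a tiled attention forward pass in the spirit of the FlashAttention-2 loop: one head, no masking, no dropout, and plain lists standing in for SRAM tiles. The tile sizes `Br` and `Bc`, the function names, and saving one logsumexp value per query row for the backward pass (which is what the FlashAttention-2 paper does) are assumptions of this sketch, not a reference implementation. `recompute_probs` illustrates the backward-pass idea: any tile of the softmax matrix can be rebuilt from Q, K, and one saved scalar per row, so the N-by-N matrix never has to be stored.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def flash_attention_forward(Q, K, V, Br=2, Bc=2):
    """softmax(Q K^T / sqrt(d)) V for one head, computed tile by tile.

    For each block of Br query rows, stream over blocks of Bc key/value rows.
    Per query row keep: m (running max of scores), l (running sum of exp),
    and acc (unnormalized output). The full N x N score matrix is never built.
    Returns the output and the per-row logsumexp L = m + log(l).
    """
    scale = 1.0 / math.sqrt(len(Q[0]))
    out, lse = [], []
    for qs in range(0, len(Q), Br):
        q_tile = Q[qs:qs + Br]
        m = [-math.inf] * len(q_tile)
        l = [0.0] * len(q_tile)
        acc = [[0.0] * len(V[0]) for _ in q_tile]
        for ks in range(0, len(K), Bc):
            k_tile, v_tile = K[ks:ks + Bc], V[ks:ks + Bc]
            for r, q in enumerate(q_tile):
                s = [dot(q, k) * scale for k in k_tile]  # one row of the score tile
                m_new = max(m[r], max(s))
                corr = math.exp(m[r] - m_new)            # rescale old stats to the new max
                p = [math.exp(x - m_new) for x in s]     # unnormalized probabilities
                l[r] = l[r] * corr + sum(p)
                acc[r] = [a * corr + sum(pj * v_row[c] for pj, v_row in zip(p, v_tile))
                          for c, a in enumerate(acc[r])]
                m[r] = m_new
        for r in range(len(q_tile)):
            out.append([a / l[r] for a in acc[r]])       # normalize once, at the end
            lse.append(m[r] + math.log(l[r]))
    return out, lse


def recompute_probs(Q, K, lse, ks, Bc):
    """Backward-pass style recomputation: rebuild the softmax probabilities for
    key columns ks .. ks+Bc-1 from Q, K and the saved logsumexp, instead of
    having stored the N x N probability matrix during the forward pass."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    return [[math.exp(dot(q, k) * scale - L) for k in K[ks:ks + Bc]]
            for q, L in zip(Q, lse)]


Q = [[math.sin(i + c) for c in range(3)] for i in range(5)]
K = [[math.cos(i * c + 1) for c in range(3)] for i in range(5)]
V = [[float(i - c) for c in range(3)] for i in range(5)]
O, L = flash_attention_forward(Q, K, V)          # 5 rows: the last tiles are ragged
P = recompute_probs(Q, K, L, ks=0, Bc=len(K))    # all columns at once, for checking
print([round(sum(row), 6) for row in P])         # each recomputed row sums to 1
```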

Okay, that brings us to the end here. Hopefully you've seen how all the pieces I talked about, tiling, coalescing, and recomputation, come together to give you FlashAttention and all these really cool things that make your transformers go much faster. To recap the whole lecture: hardware is the thing that has really powered all of the language models we have today, and if you really want to leverage your hardware, you have to understand the low-level details. I think all the systems advances really engage with a lot of the concepts I taught today. The current GPU scaling trend (that plot is really the one you should remember) strongly encourages you to think about memory movement. Memory movement is the bottleneck in all of this. So you don't want to just think about how to reduce the number of flops; that's important too, but you really have to think about how to make your memory movement more efficient. Finally, if you have to do a certain amount of computation, the way to optimize things is to optimize your data movement: avoid as much movement from the high-bandwidth memory, or global memory, as possible, and keep everything in the very fast shared memory. That leads to good performance on things like FlashAttention. Thanks, everyone.
