Stanford CS336 Language Modeling from Scratch | Spring 2025 | Lecture 4: Mixture of experts
By Stanford Online
Summary
Topics Covered
- Replace Dense FFN with Sparse Routed Experts
- Mixture of Experts Outperforms Dense at Same FLOPs
- Top-K Token Choice Routing Dominates
- Load Balancing Prevents Expert Collapse
- DeepSeek V3 Evolves with Fine-Grained Experts
Full Transcript
So, we'll get started. Today, we're going to cover a mixture of experts.
Last year, this was kind of a fun bonus lecture that I threw together. Um, but
this year, thanks to, you know, lots of people doing MoEs, this has become a much more critical lecture. So I've added a lot of the recent developments, and at the end we'll try to walk through DeepSeek V3 and try to understand what are all the components that make up a state-of-the-art open-source system, or at least what that looks like on the architecture side. So mixture of experts is how a lot of the most modern high-performance systems today are built and deployed. There was the funny Nvidia leak of GPT-4 potentially being revealed as a GPT-MoE-1.8T model. But more broadly, others like Grok and DeepSeek and Llama 4 now have all adopted a mixture-of-experts architecture, and it seems at this point in 2025 that the advantage of mixtures of experts over dense architectures is very much clear: at almost all compute scales, training a mixture-of-experts model, if you do it well, is going to give you benefits over a dense model, and so everyone seems to be doing it, in both the East and the West. And so this will be an important thing to understand if you're trying to build the best model that you can for the flops that you
have. So mixture of experts is very simple. It's a terribly named concept. I think you hear "mixture of experts" and you think, oh, there must be experts specialized for different domains, and they're doing different things: there's a coding expert and an English expert and an other-languages expert. It is very far from that mental model. A mixture of experts is a type of fancy architecture that has several subcomponents, called experts, that are activated sparsely. And in particular, when you think about mixture of experts, you should be thinking about the MLPs. This is where all the action is. An MoE architecture and a non-MoE architecture are going to be similar in almost all of their components except for one. If you look at this slide over here, these are the components of a standard transformer: you've got your self-attention, you've got your FFN. If you zoom in, in a dense model the feed-forward component is just there, one big block. In a sparse model, what you would do is take this FFN and split it up, or copy it, depending on how you're setting things up: you have multiple copies, let's say, of your FFN, your fully connected networks, and you have a router that picks some smaller number of those in each forward pass or at each inference step. So this is the basic idea behind MoEs: we're going to replace this one big feed-forward on the left side with a selector layer and many smaller ones. And what's the advantage of this thing? Well, if it's sparsely activated, that is, let's say the router only picks one expert and an expert is the same size as your dense FFN, then the flops between the left side and the right side, the dense model and the MoE model, are the same. They're doing the same matrix multiplies as you do your forward pass. So you have more parameters without affecting your flops. And if you're a believer that what matters is having more parameters to, for example, memorize facts about the world, well, this is a great architecture. So you can kind of see the intuition behind MoEs. Hopefully that's all very clear.
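To make that parameter-versus-flops accounting concrete, here is a tiny back-of-the-envelope sketch in Python (the layer sizes and the 8-expert, top-1 configuration are made up purely for illustration):

```python
# Hypothetical sizes, purely illustrative.
d_model, d_ff = 1024, 4096                 # dense FFN: W_in is d_model x d_ff, W_out is d_ff x d_model
dense_ffn_params = 2 * d_model * d_ff

num_experts, top_k = 8, 1                  # replace the FFN with 8 experts, route each token to 1
total_moe_params = num_experts * dense_ffn_params
active_params_per_token = top_k * dense_ffn_params

# Matmul FLOPs per token are ~2 * (active parameters), so per-token compute is
# unchanged while the parameter count has grown by the number of experts.
print(total_moe_params // dense_ffn_params)         # 8  -> 8x the parameters
print(active_params_per_token // dense_ffn_params)  # 1  -> same per-token flops as the dense FFN
```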
And you might wonder, okay, so it makes sense that you can get more parameters per flop, but does that translate to actually better performance for the models that you're training? There have been, I think at this point, many, many papers showing that at the same flop count, the same amount of training flops, you get better performance out of a mixture of experts than out of a dense model. So today I'm going to go over a couple of the classic Google papers that put this field together, and this is one of them, by Fedus et al. 2022,
where they show that if you flops-match your training flops, so the same amount of compute is used for training, then as you increase the number of experts, the training loss of your language model just keeps going down and down and down. So more experts is better. Of course, the experts aren't free: you need to store the memory for these experts, and when you do parallelism you're going to have to think about routing your data into 256 separate experts, so there are going to be systems complexities. But if you're only thinking about flops, this is a great chart, because you have the same flops but you've gotten free test loss here. And you see the same thing reflected on the right side: as you train for longer and longer, the switch-base model with 128 experts, the model with more experts, gets better perplexity faster. So hopefully that is quite clear. You might say, well, this is a 2022 paper; is this still true on modern architectures at modern scales? It continues to very much be true. AI2 had a very nice paper, OLMoE, which did a whole bunch of ablations and carefully controlled comparisons of dense versus MoE architectures, and they see exactly the same thing. So here on the left side, this is still from Fedus et al.: you see the 7x speedup from having many experts. On the right side, this is the OLMoE comparison: the pink one is the MoE and the teal one is dense, and the training loss for the dense model goes down much more slowly than the MoE's. So hopefully I have, in some sense, sold you on the value of MoEs and on learning this slightly new architecture. We're going to pay a price for all of this, but at least at the flops level this looks very compelling.
Yes, question? In the last lecture you mentioned the bias-term part: even though it's a pretty cheap computation, it affects our actual wall-clock pretty badly because of loading things in and out. Is there anything like that here? Right, so the question was: in the last lecture I was saying that even small, negligible-flops operations can be really big in wall clock; is anything in the MoE world going to look like that? I think one of the drawbacks, one of the reasons this isn't the standard thing being taught at, let's say, 224N, is that there are significant systems complexities to making this thing efficient. So I'll get to that. It's possible to make these things very efficient, especially if each expert lives on a separate device, so that you're routing data to different places; you can be very efficient when you do that, but it's not easy. So there are a lot of infrastructural concerns, and you're going to see a lot of complexities to get this thing to work.
But when it does work, you're putting all of your flops to use. Okay. And then the last one that I wanted to show is something a lot of the companies really love, because you get to present plots that look very compelling, like this one. This was from the DeepSeek V2 paper. On the x-axis, and this is a little bit of sleight of hand, this is only activated parameters, so only the parameters that are used for computation; you ignore all the deactivated experts. And the y-axis is MMLU performance. And we see DeepSeek V2: wow, look, very few activated parameters, really good MMLU performance. And so if you're interested in both training and inference flops, activated parameters is the name of the game, and you get really good performance here. And this is not just an ablation; this is a real system that someone spent a lot of money to train and deployed out in the wild. And we'll see this sort of pattern recur in other examples as well.
Oh, was there a question? Oh, no. All right. And so the systems thing that is also a benefit is that MoEs allow us to have another axis of parallelism. I'm going to get into parallelism in much more detail in the systems lectures, when I'll talk about how you take your model, cut it up into many small pieces, and lay them out across many different devices. But at a very high level: when you have experts, there's a very natural way to parallelize at the expert level. You have multiple different feed-forward blocks; you can take each of these experts and put them on a different device. And because experts are sparsely activated, all you have to do is take your token, route it to the appropriate device, and the computation will happen on that device. So it's a natural cutting point to shard your model onto different devices. This is called expert parallelism, and it's another reason why MoEs are very popular: if you really want to parallelize really big models, this is a thing that you're going to have to do. And interestingly enough, I think MoEs were developed at Google, and many of the frontier labs, the closed labs, were doing it, but the open results very frequently actually came from China. Qwen and DeepSeek were doing a lot of work last year, and it's only really recently that western open source groups have started to do more work, so Mixtral, Grok, I guess Grok's not open, and now Llama is an MoE architecture. And so here, Llama 4 just got released, latest and greatest; this is also a sparse MoE, and I'll talk about Llama 4 as well as I go through the lecture.
As I said before, one of the starting points for this is that some of the Chinese groups, Qwen and DeepSeek, have actually done some really nice work benchmarking, understanding, and evaluating some of these results. Qwen 1.5 was one of the first models that I knew of to have this kind of large-scale, well-tested, well-documented MoE. What they did was take a Qwen 1.5 dense model and use a nice trick to upcycle it into a mixture of experts: a clever trick to take a dense model and then turn it into an MoE. And they showed significant gains, at least in terms of compute efficiency, while decreasing the total number of parameters relative to their 7B model. DeepSeek, which is now famous but originally, when these papers were coming out, was not quite as famous, did some of the, I think, really foundational work in the open-source world. A big part of this lecture is actually going to be tracing the trajectory of the DeepSeek MoE architecture. If you look at their original DeepSeek MoE paper, you'll see very nice comparisons showing things like what happens when you train a dense model with a particular amount of flops, what happens when you train a really naive MoE that doesn't do very smart routing, and then what happens if you use a smarter routing like the Switch-style router. So you'll see all of these very carefully controlled comparisons, and as you go from dense to sparse, that's the leftmost column to the rightmost column, you see all of these benchmark metrics very consistently improve for a fixed amount of flops. And one thing that I think almost everyone at this point has probably heard of is DeepSeek V3, which is in some sense a culmination of this line of work. But if you had been following and were excited about this branch of neural networks and language modeling, you would have actually known about DeepSeek long before V3 got popular. And we'll see at the very end of this lecture that DeepSeek V3 is actually not very different from the very earliest DeepSeek MoEs. Architecturally, they had kind of nailed it way back when they were training these much smaller two-billion-parameter models. They really just got the engineering right to get something that is actually quite remarkably good, which is their V3 model.
Okay, so now I think I have spent quite a few minutes trying to really hype up MoEs, and they really are, I think, worth hyping up. They're very good. But there's a question of why they haven't been more popular. Why isn't this the standard thing we teach in NLP and language modeling classes? It's just that they're very complex and very messy, and I'm hoping they'll get simplified over the next few years, but they still remain pretty nasty. So one of the things is that the infrastructure is very complex, and the biggest advantages of MoEs really happen when you're doing multi-node training: when you have to split up your models anyway, then it starts to make sense to shard experts across different nodes; that's a very natural thing to do. But until you get to that point, maybe MoEs are not quite as good. Some of the earlier Google papers really talk about this trade-off, where they say that when you get to these really big models that you have to split up, then MoEs become uniquely good. There are also other things that are really tricky. If you think about it carefully, this decision of which expert to route tokens to is a very difficult thing to learn. In deep learning, we really like differentiable objectives, very smooth things that we can take gradients of. Routing decisions are not differentiable, because we have to pick and commit to a particular expert. So if we're doing that, we're going to have a very tricky optimization problem, and the training objectives to make that work are either heuristic and/or unstable. We're going to have to really carefully engineer those to get them to work. So those are two reasons why you might not want to do this normally. So what do MoEs look like? As I started this lecture with, the classic MoE that you should think of is: you take the densely connected layers, the FFNs, and you split them up or copy them, and you have sparse routing decisions among them. Of course, you could apply the same idea elsewhere. You could have a sparsely routed attention layer, and some people have done this; there have been a couple of papers and a couple of releases that have taken this approach. But it is actually quite rare to see this in the major model releases. I've seen people talking on the internet saying this approach is much more unstable and very difficult to train consistently. I haven't really seen the ablations to back that out, but certainly there haven't been many people training those kinds of models with MoE attention. So now I've told you about the basic architecture. It's really simple: you have a router of some kind, you route, and then you have different MLPs. So what are the things that might vary across different choices? You might ask: how do we route? The routing function is an obviously important choice. How many experts, and how big should the experts be? That's another choice. And the final one is: how do we train this router, this non-differentiable objective that seems very difficult to train? Those are very important design questions, and we're going to go through each one, hopefully covering the design space of all these things. Okay. Any questions before I get into each of these different subcomponents? Good. Okay.
So if you're interested in a broad overview, at least circa 2022, there's a really nice survey-style review paper by Fedus et al. in 2022 that covers a lot of this, and many of my figures are credited to that paper. If we're thinking about how we're going to route, essentially how to match tokens to experts, this is the core component of an MoE, because what an MoE does is: tokens are going to be coming in, you have your sequence that you're processing, and those tokens are going to be assigned to experts. Not all experts will process every token; that's the whole point of a sparsely routed MoE. And so you can ask: how are these routing decisions made? You can have three different kinds of choices. You can have token choice, where each token has a routing preference over different experts and I choose the top-k experts for each token. Or I can have expert choice, where each expert has a ranked preference over tokens and then I choose the top-k tokens for each expert; this has the really nice benefit of being balanced over experts. And the last one is that you could solve some sort of complicated optimization problem to make sure that the mapping between experts and tokens is somehow balanced. This is global assignment.
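As a rough sketch of the difference between the first two options (the score matrix and sizes here are made up; this is not any particular paper's implementation): token choice takes a top-k along the expert axis for each token, while expert choice takes a top-k along the token axis for each expert.

```python
import torch

num_tokens, num_experts, k = 8, 4, 2
scores = torch.randn(num_tokens, num_experts)   # token-expert affinities from some router

# Token choice: each token picks its own top-k experts (popular experts can get overloaded).
token_choice_experts = scores.topk(k, dim=-1).indices        # shape (num_tokens, k)

# Expert choice: each expert picks its top-c tokens, so every expert gets exactly c tokens.
capacity = num_tokens * k // num_experts                      # tokens per expert
expert_choice_tokens = scores.topk(capacity, dim=0).indices   # shape (capacity, num_experts)
```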
And just to give you a bit of a teaser here: almost all the MoEs do token choice top-k. In the early days of MoEs, people tried many different things spanning this whole spectrum of the router design space, but if you look at the big releases, they have all converged to basically one class of routing mechanisms, which is token choice top-k. So each token rank-orders experts by affinity, and then there's a top-k choice for each one. And OLMoE, which I'll keep referring to throughout this lecture because they have a really nice series of ablations that are nice to teach off of, has exactly this ablation: they compare token choice routing versus expert choice routing, and they show that if you look at validation loss, token choice is much nicer behaved, much faster in loss decay. Yes? Is this a function of the token itself or its position? It's a function of the hidden state. So the token gets processed with all the position embeddings and so on, and then the hidden state comes in and gets processed by the MLP. And so for the other two, like the experts choosing the token, and also the next one, when you say it's more balanced across the experts, is it still for the current token sequence but forcing them to be more distributed? It's still going to be the same set of tokens, but really it's about the ranking selector function. In token choice, I'm just going to take the top k amongst the columns, maybe the scores are even identical; in expert choice, I'm going to take the top k amongst the rows. And top k amongst the columns is kind of nice because you might be able to say, oh, I can define a scoring function such that the score is how well each token gets processed by each expert, and token choice will route me to the best expert for that token. So that makes sense from a processing standpoint. But expert choice has the benefit that each expert gets exactly the same number of tokens, so if you're putting different experts on different devices, you've got balanced utilization. So there are different trade-offs at play as you think about routing. Yes? How does a token know which expert is the best? Good, yes. The question was: how does each token know which expert is good? That is exactly the role of the router. And I'll give you the router equation, but as a bit of a spoiler, the routers are much more lightweight than you think. Your token, let's say, is represented by a vector x; that's your hidden residual stream coming in. Now x gets multiplied by a matrix W, and then you just take a sigmoid or something, and that's the score. So it's really just a vector-vector inner product, almost like an attention operation in a way. Yes?
Right, so the question was: is k equal to 1 here? K is actually a hyperparameter, and different MoEs will choose different things. I will talk about this again, but to give you the high-level intuition, the initial argument the earliest MoE papers made was that k should be greater than one, because that way you get some exploration. If you're doing k equals 1, maybe you're just always exploiting the best arm and you'll never know about the potential other things you could do; but if k is two, then maybe that second arm can tell you a little bit of exploration information. So k equals 2 was the canonical choice, and k equals 2 actually continues to be very popular. That would be like double the flops? That's right, that would double the flops. And so when people talk about MoEs, they usually say things like "x number of activated parameters", and that accounts for the fact that you're putting in two MLPs. Yes? So when k is greater than one, how do we combine the outputs of the different experts? Yes, the question was: when k is greater than one, do the outputs get combined? That's right. If you look at the diagram over there, you've got the router, it routes to two MLPs up top, and then they get combined together right after. So in that case you can just take a simple average? So the question was how the aggregation happens: it's just the sum. So I'm going to go over the variants, the very common variants that people use, and really, in some ways, all you need to know is top-k in order to implement a high-performance MoE. But I'll give you the other variants because they're natural things you might think of.
Top-k routing is what is used in most MoEs. How it works is: you have your residual stream inputs x. That goes into a router, and as I said, a router is really kind of like the attention operation: there's a linear inner product and then a softmax, and then you pick the top-k most highly activated experts, and those outputs are gated. Depending on the implementation, you might weight the outputs based on this router weight, or you might not, and then you output the weighted average or just a straight sum, depending on how your implementation works. And so a lot of the MoE papers and methods use top-k: Switch Transformer, GShard, Grok, Mixtral, Qwen, all the DeepSeek variants use different flavors of top-k. Maybe a very surprising fact, and this should really make you think about what's going on with MoEs, is that there are a lot of results showing you don't even need a smart router at all. You can actually just use a hashing function at the very bottom to map these x's onto your experts. And even if you're doing hashing, so no semantic information at all, you will still get gains from a hashing-based router, which is pretty wild.
Some of the earliest MoE work, I think, had the very smart idea, and in many ways the right idea if you're thinking about this top-down, of using RL to learn the routing behavior. Of course, the choice of where to route is a discrete decision, and RL is great for learning discrete decisions, so why not use RL to learn routing? It was used in some of the earliest work on mixture of experts, but as far as I know, basically no one does this now: the compute cost is too prohibitive, and you already have stability issues, so you might not want to add to them. There have also been a couple of papers exploring things like solving linear assignment problems or optimal-transport-style problems. They're very elegant, but once again, the cost of doing this is much higher than the benefits it gives you in practice, I think, and it hasn't really been adopted. But there are a lot of really interesting things people are doing like this to try to improve the routing. So now I can point at this slide and really talk through how routing works in detail. This is the kind of top-k routing that almost everyone has converged to now. This is the router that's used in DeepSeek v1 and v2; Qwen and Grok do almost exactly this. Instead of having a softmax directly at the bottom here, DeepSeek V3, Mixtral, and DBRx don't have the softmax at the bottom, but they'll softmax the gate values g instead; it's a very minor difference. So let's walk through what's going on here and try to reason about the behavior of this. At the very bottom we've got our inputs. This is our u_t input, and I would like to take this residual stream input and process it. The first thing I have to do is figure out which experts are going to be activated. How am I going to do that? In a way that is very similar to attention: I take my u, which is my residual stream input, and I take inner products with the e_i's. These are learned vectors, one per expert, that say "I'm an expert that points in this direction." So with this inner product I'm computing expert-input affinity, and I compute a softmax to determine, for each token, what the best experts are. I normalize, and that's s_{i,t}. Now I take the s_{i,t} and go through a top-k function: I only select the k best weights and use those as my gates. I zero out everything else, take the weighted average of each expert's output, and then add that to my original residual stream and return that. So this is hopefully very familiar, in terms of how a transformer works, with only the difference of this top-k routing piece.
Is that clear to everyone how this thing works? Good. Excellent. So in some sense the mechanics of the forward process of the routing are very simple.
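Here is a minimal PyTorch sketch of that forward pass, following the equations just described: softmax affinities against learned expert vectors e_i, a top-k gate, a gated sum of expert outputs, and the residual added back. The module structure and sizes are placeholders, not any particular model's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKRoutedFFN(nn.Module):
    def __init__(self, d_model: int, d_ff: int, num_experts: int, k: int):
        super().__init__()
        self.k = k
        # e_i: one learned routing vector per expert (separate from the expert FFNs).
        self.expert_embed = nn.Parameter(torch.randn(num_experts, d_model) * 0.02)
        self.experts = nn.ModuleList(
            [nn.Sequential(nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))
             for _ in range(num_experts)]
        )

    def forward(self, u: torch.Tensor) -> torch.Tensor:      # u: (num_tokens, d_model)
        s = F.softmax(u @ self.expert_embed.T, dim=-1)        # s_{i,t}: expert-input affinities
        gate_vals, gate_idx = s.topk(self.k, dim=-1)          # keep only the k best experts per token
        out = torch.zeros_like(u)
        for slot in range(self.k):                            # simple loop form, for clarity not speed
            idx, g = gate_idx[:, slot], gate_vals[:, slot:slot + 1]
            for e, expert in enumerate(self.experts):
                sel = idx == e
                if sel.any():
                    out[sel] += g[sel] * expert(u[sel])       # gated expert output
        return u + out                                        # add back the residual stream
```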
What is kind of mystifying is the fact that you can learn this very well.
This is, in some sense, a fairly complicated set of things for a model to learn to do well. Yes? So we're using softmax here. Previously, one of the benefits of softmax was that it pushes you pretty extremely toward choosing a single max. It's not a hard max, but I'm having trouble seeing the intuition of combining the softmax with the top-k, where you're getting multiple experts and yet using something that pushes you toward choosing just one thing. Yeah, I think one way of thinking about the softmax is that its whole purpose here is just to make it so that when I average my experts later, it sums to one. Don't think of the softmax as a "soft max" operation, even though that's literally the name. Really, the softmax here is a normalize-to-one operation, and the normalize-to-one operation is what makes that a weighted average up top. The other thing that's very important: you might think, why can't I just get rid of the top-k, why don't I just use the softmax here and gate all the experts? Well, then you immediately lose the systems-efficiency aspect of this. You have to have top-k during training, otherwise you pay the training cost of all capital-N of your experts. This is the key thing: we have to do all this gymnastics to make sure that both at training time and inference time we have a sparse number of activated experts. That's why we go through the top-k. Okay. Yes, from the back.
Yeah, so because you're doing softmax first and then the top-k gets the weights, you no longer have that guarantee? So the question was: if you softmax first and then take the top-k, you no longer sum to one. And yes, that's absolutely right, you no longer sum to one. And in some ways there's no requirement that you have to sum to one, because the next layer can magnify it back up; there are layer norms everywhere. It's not as if it has to sum to one. But I think that is the reason why some of the other architectures basically move the location of the softmax. There's a kind of aesthetic choice about whether you really want that weight to be normalized to one or not. Yes? I was wondering how the e vector here relates to the weights of the feed-forward. Okay, so the question was whether and how the e vectors relate to the feed-forward. They're not really tied in any way. The e vectors are just learned vectors; think of the e's as parameters for the router. They're just separate objects from the FFN. Yeah, I was just wondering how it compares to sampling from the softmax. Great, the question was about how it compares to sampling from the softmax. You can sample from the softmax, and some methods actually do a kind of soft sampling from the softmax. Specifically, one of the Google papers has a procedure where they take the top element of the softmax and then randomly sample the second element proportional to the remainder of the softmax. That gives you more exploration, which is good, but the drawback is that if you don't sample at test time, now you've got a train-test mismatch.
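A tiny sketch of that sampling procedure (roughly in the spirit of the "keep the top expert, sample the second in proportion to what's left" trick just described; the details are illustrative rather than a faithful reproduction of any paper's code):

```python
import torch

def top1_plus_sampled_second(probs: torch.Tensor):
    # probs: (num_tokens, num_experts), softmaxed router scores.
    top1 = probs.argmax(dim=-1)                              # always keep the best expert
    rest = probs.scatter(-1, top1.unsqueeze(-1), 0.0)        # zero out the winner...
    rest = rest / rest.sum(dim=-1, keepdim=True)             # ...and renormalize the remainder
    second = torch.multinomial(rest, 1).squeeze(-1)          # sample the second expert
    return top1, second   # at test time you would deterministically take the top two instead
```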
Okay, yes? Why not just renormalize after the top-k? Why not renormalize after the top-k, was the question, is that right? Some models do that. Some models do renormalize after the top-k, but that's a choice; some architectures do it, some don't. It doesn't actually matter much, because the scale can basically be adjusted post hoc; there's no reason why it has to sum to one after the gating operation. Cool. Oh, sorry, yes? The bias term there, that's u_t up there? Yeah. So the first term of the sum, if g is approximating a probability vector, could be seen as an expectation of the FFN function, plus u? So actually, this is not an expectation of the FFN, because each FFN_i is a different function, so it's not actually an expectation, and the gates are sparse. So this is like a weighted selection operation over k, or actually capital-N, different FFNs. And the u_t at the very end there, if you remember the transformer, that's the residual stream: I'm adding back the inputs because I want an identity connection throughout. Okay, oh, there's another question: why does the router have such a basic parameterization? Like, what happens if you put more weights into your router function? Right, the question was: why is the router so basic? It seems like if you're going to have experts, it's important to route to the right experts, so why don't you do that? I think there have been some ablations in some of the earlier Google papers on having MLP routers and more sophisticated things. I think the complicated answer here is that the systems concerns weigh heavily: if you're using a lot of flops to make routing decisions, you have to pay for those flops, and so you have to get performance improvements just from the routing. And the other thing to appreciate here is that there are really big limits to how well you can route, because the learning process for this routing is actually pretty dicey. How are you going to get gradients for which routes are good or bad? Well, the only thing you have is, if you have top-two, you can compare the two things you did evaluate and push gradients into the s_{i,t}, because your g is a weight and the s_{i,t} might inform your inner products. But that's a very indirect way to learn your affinity. So even if you make the router complex, there's no guarantee that you're going to really learn the optimal router. Great. Okay.
So I think one of the great innovations of the DeepSeek MoE, which was very quickly adopted by all the other Chinese MoE releases, is this idea of both a shared expert and fine-grained experts. The basic structure that was originally proposed is to take your dense architecture and copy the experts over. In that case, if you have top-two routing, you're going to have twice the activated parameters of your original dense model: you take your FFN, copy it over, and activate k equals 2. This is what you might think of as the vanilla or basic MoE that you might start with. People realized fairly quickly that having lots of experts is good. And the logical next step beyond "having lots of experts is good" is: I want lots of experts, but I don't want to pay the parameter cost of having lots of experts. And so DeepSeek basically argued that the right thing to do was to cut each expert up into smaller pieces. Remember, last lecture I was telling you that the golden rule, in some sense, is to take your hidden dimension and multiply it by four, and that gives you your projection dimension. So now, instead of multiplying by four, you might multiply by two. Now you have smaller matrices, you have more fine-grained experts, and you can have twice as many of them. And you can take that logic further to the extreme: you can quadruple the count, or multiply by eight, and keep decreasing the size of your projection dimension. That's fine-grained experts. There are drawbacks I'll talk about later; it doesn't come for free, so you have to be very careful about how you structure these things.
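Here is the slicing arithmetic written out (the dimensions are hypothetical; the 64-expert, 6-routed, 2-shared, quarter-width configuration mirrors the DeepSeek-v1-style setup described later in the lecture):

```python
d_model = 2048
dense_d_ff = 4 * d_model                  # the usual ~4x rule-of-thumb projection

slice_factor = 4                          # cut each expert to 1/4 of the dense FFN width
fine_d_ff = dense_d_ff // slice_factor

num_experts, num_routed, num_shared = 64, 6, 2
params_per_fine_expert = 2 * d_model * fine_d_ff
dense_ffn_params = 2 * d_model * dense_d_ff

total_params = num_experts * params_per_fine_expert           # 16x a dense FFN in parameters
active_params = (num_routed + num_shared) * params_per_fine_expert
print(active_params / dense_ffn_params)   # 2.0 -> ~2x the per-token flops of one dense FFN
```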
The other thing that has been studied and noted is that maybe it's helpful to have at least some MLP that can capture shared structure. Maybe there's just processing that always needs to happen no matter which token you're processing, and in that case it seems like kind of a waste to do all this routing work and have all these parameters spread out everywhere, when we can just have one, or a few, shared experts whose job it is to handle all of this shared processing. And so those are shared experts. This setup of using fine-grained experts plus shared experts originally came out in DeepSeek, although I think the original inspiration came from DeepSpeed, and Qwen and others use it too. Almost all of the open releases since DeepSeek have adopted some set of these innovations, because it's quite clear that fine-grained experts, especially, is just really, really useful. That's a kind of no-brainer at this point. One of the things I really like about reading DeepSeek papers is that they do ablations; it's not like a typical sales tech report, they actually care about whether or not their methods work. And so they have this lovely ablation in the DeepSeek MoE paper where they show: the blue bar over here is GShard, a very basic vanilla MoE implementation; you can have one shared expert, that's the orange bar, and that gives you a big boost on some tasks and no boost on others; you can have fine-grained experts, that's the green and orange bars, and you get further boosts from that. And if you compare the blue to the orange, composing all these differences gives you quite a big boost over the others. So we can see that more experts and shared experts generally seem to help. Okay, yes, question? When it says seven out of something, does that mean it's doing top seven? Yes, sorry, I should have explained that. That's right: X out of Y means X activated out of Y total routed experts. And you can see the pattern here as well: as you increase the number of experts, you also often increase the number of activated experts. Especially if you're doing fine-grained experts, flops-wise it's free, because each expert is now smaller. Good. Okay. So OLMoE has basically corroborating evidence that shows really nicely that these things work.
The bottom one I think I'll start with, because it's more decisive: it shows fine-grained experts going from 8 to 32 to 64, mirroring in some sense the DeepSeek ablations, and you see very clear trends in losses and other kinds of metrics, with improvements going from 8 to 32 to 64. Fine-grained experts are great. Shared experts, which is purple versus teal at the very top, you actually don't see really any gains from, at least in the OLMoE setup, so they actually end up going with no shared experts, even though the DeepSeek paper seemed to show more gain. So that one is maybe more mixed, given this sort of follow-up, this third-party replication of these kinds of ideas.
So at this point you might be wondering: what are common configurations? I'm going to take a page out of last lecture's playbook of looking at a lot of the recent releases, looking at what people do, and trying to talk a little bit about the patterns that have arisen. Some of the early Google papers, so GShard, Switch Transformer, ST-MoE, some of them had really large numbers of routed experts, and there was a lot of really interesting stuff going on in those papers; I'd encourage you to read them. Some of them were in LSTMs and other kinds of architectures. Regardless, very quickly I think there was a period of 8 to 16 experts, like Mixtral, DBRx, Grok, with two active experts. Those worked reasonably well, but then DeepSeek MoE, you know, v1, comes out. That has the prototypical configuration I told you about: fine-grained experts, 64 of them, six actively routed, two shared experts, and each expert is roughly one-quarter the size of a normally sized expert. Take that last column with a grain of salt, because I had to back these out from config files and things like that, so I'm not 100% sure about the exact ratios here. We've then got essentially Qwen 1.5, DeepSeek V3, MiniMax: these are Chinese models that follow essentially in the same footsteps as DeepSeek v1. The specific numbers are different, but in the sense that they use fine-grained experts and often have shared experts, they're very similar to this original DeepSeek configuration. OLMoE, MiniMax, and Llama are very recent; they definitely do all this fine-grained expert stuff, and Llama 4 also uses a shared expert. You see variations in configuration, but what's basically shared is this fine-grained experts idea, and especially for the big models like Llama 4 and DeepSeek, very, very large numbers of routed, or sorry, not routed, total experts. Yes? Can you explain what the ratios are?
The ratio is representing roughly how much each expert is sliced relative to the standard dense configuration. In terms of hyperparameters, you know that if you're following the rule of thumb, your hidden dimension and your projection dimension in your MLP should be about 1 to 4, or roughly 1 to 2.7 (8/3) if you're doing a gated network. And so by looking at the hidden layers of these architectures, you can see how many times they sliced up that original feed-forward size. So for those experts, does that mean they're still increasing the total by that factor? That's right. You can think of this as: roughly, they have 16 normally sized experts' worth of parameters, so they of course have more parameters than the dense equivalent. They have six routed plus the shared ones, so eight total active experts at any time, each of which is quarter-sized, and so you should think of them as roughly double the flops of a dense equivalent. So, some arithmetic,
but hopefully the math is clear and consistent. Yes? For some of the exotic ratios, like the 1/14, I'm not quite sure why they're that way, but they are very precisely whole numbers when you take the ratios between the FFNs and the implied hyperparameters, and so I think those are exactly the split counts of how much they were sliced; I'm just not sure why they chose something like 1/14. Do they ever project to a smaller dimension, because that ratio is so small in the MLP? Oh, that's why you're asking: do they down-project? Yeah, that's right. In some of them the expert MLPs are actually smaller; I don't remember which models in particular, but in some of them I do remember they actually down-project.
Yes? What is the intuition for wanting more than one shared expert? Yeah, it does kind of seem like there was a period where some of the Chinese LM companies tried many shared experts, and then people have come back to zero or one. And if you look at the OLMoE ablations, it's not quite clear that even one shared expert is decisively useful. I think the original motivation was that then you have equally sized experts: these are both one-quarter-sized experts and now you have eight active experts total, so you can keep the sizes consistent. Otherwise, I don't really see a particular justification for why it should be two smaller ones versus one larger one. Okay, cool. So hopefully you get a sense of how the routing works for a lot of these and how it's all set up; the forward pass hopefully you fully understand. Now we need to think about training, and training is pretty gnarly. The major challenge I foreshadowed earlier: when we train, we cannot turn on all the experts, because if we do, then we pay the full flops cost of all the experts, and having a model that's, I don't know, 256 times more expensive to train is a total no-go. So we need train-time sparsity, but sparse gating decisions are obviously not differentiable, and we now have a kind of annoying RL-ish problem. And so we could do any of these things: RL to optimize gating policies; bandit-inspired things, using randomization to do exploration; or we can just have some heuristics that try to balance things out, like putting some loss terms in there and hoping things work out. Having gone through deep learning classes of many kinds, you can probably guess internally which one people use in practice. I'll talk about each of these three in turn. Okay, so RL, I think, is one of the earliest things that people tried. It's probably the most principled thing you can do in this space: you have a non-differentiable routing decision, well, think of that as a policy, throw RL at it, and solve the problem. Unfortunately, it's not better than a lot of the other things you can do. There is a paper by Clark et al. in 2022 that was exploring various scaling-related questions in MoEs, and they do have an RL baseline that I was able to dig up. Unfortunately, it's not really that much better than, say, using hashing for routing decisions, and they were really interested in benchmarking the thing on the left called S-BASE, which is a linear-assignment kind of method, and that thing handily beats doing RL. I think, in practice, the gradient variances and complexity mean that it's pretty finicky to use, and no one at scale has really used an RL-based approach to optimize these gating decisions, as far as I know. A thing that has been done much more at scale is stochastic approximations of various kinds.
So what they might do is add a bit of perturbation. Here is an example of one, from Shazeer et al. in 2017, one of the early papers. They're still going to do top-k routing: they keep the top-k elements of this H(x) operation and softmax that to get the gate. But the way we get this H(x) is the following. We have our original linear affinity, identical to what we were doing before: we're basically just computing our inputs x against a learned weight for each gate. So this part's the same, but now I'm going to jitter it a little bit: I add Gaussian noise, scaled by a learned W_noise term, and this thing controls how much noise to inject into the process. You can think of this as a stochastic exploration policy, and by manipulating W_noise in particular ways, like annealing it down or doing various other things, I can control the exploration-exploitation trade-off that this is going to have. So this gives you one solution to the explore-exploit dilemma, and especially if you're noising things up, each expert might randomly get some tokens it wasn't expecting to get. That will lead to experts that are less specialized but maybe a little bit more robust, and so that seems generally quite nice.
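A minimal sketch of that noisy top-k gate, following the Shazeer et al. (2017) form described above (x·W_gate plus Gaussian noise scaled by a learned, input-dependent softplus term); shapes and scaling here are illustrative:

```python
import torch
import torch.nn.functional as F

def noisy_topk_gates(x, w_gate, w_noise, k, training=True):
    # x: (num_tokens, d_model); w_gate, w_noise: (d_model, num_experts)
    clean = x @ w_gate                                   # the usual linear affinities
    if training:
        noise_std = F.softplus(x @ w_noise)              # learned, per-token noise scale
        h = clean + torch.randn_like(clean) * noise_std  # jittered logits H(x)
    else:
        h = clean                                        # no jitter at inference time
    vals, idx = h.topk(k, dim=-1)
    masked = torch.full_like(h, float("-inf")).scatter(-1, idx, vals)
    return F.softmax(masked, dim=-1)                     # zero weight on non-selected experts
```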
Of course, the stochasticity also means that you don't get as much specialization, and that leads to a loss of efficiency. There's another approach people have used where they apply a multiplicative perturbation to the router logits, with the goal of getting less brittle experts. But this sort of jitter process was removed in some of the later papers, because they found it just didn't work as well as some of the heuristic loss-based approaches. So these kinds of stochastic routing tricks were tried in a couple of the early Google papers, but I think they have generally been abandoned by a lot of the people training these models. Okay. Yes: for the stochastic version, what problem does that solve? Because we're still taking the top-k, so we still can't differentiate backwards, right? So the question was: we still can't differentiate because we're taking the top-k. But if you change your interpretation of the problem a little bit, if you think about a bandit problem, it has the same structure as this: you pull a bandit arm and you don't see any of the other arms, so you can't really allocate your resources efficiently. If you pull some of the other ones at random, now you've got enough data to be able to do some optimization. And so this jittering is very similar in spirit to an epsilon-greedy-style exploration scheme, where you're randomly pulling some of the other arms with some probability, and the probability itself depends on how confident you are about the routing decision. So that's the intuition, and that's going to give you some way of getting signal back. Okay. So the thing that people have ended up with in practice is: we don't do any of that. We don't do RL, we don't do stochastic exploration. We rely on really another mechanism to keep things reasonable. If we're doing top-two routing, technically speaking we do get some signal in the gradient descent process, because we can compare the top two experts that we did evaluate. So it's possible to do some optimization, but if we drop all the other constraints, the big issue that arises is that you just end up picking one expert all the time, and that expert is good at everything and all the other experts are terrible. You end up in this local minimum where you've routed all of your tokens to one expert all the time. So really the key game becomes: how do we get out of that local minimum? And loss balancing, or balancing losses, is really the key trick to get out of this. This is important to understand, because this is the loss that mostly everyone actually uses to train the MoE. So if you were zoning out earlier, you should probably make sure to pay attention to this particular set of equations. This is originally from the Switch Transformer, from Fedus et al. in 2022, and they add this particular loss, where what they're going to do is loop over each of the experts, and you can think of it as an inner product between the vector f and the vector P. And what are these vectors? Well, f_i is, for each of the experts, the fraction of the tokens that were allocated to expert i. So you can think of f as a probability vector telling me what fraction of the tokens in my batch, or whatever the unit is here, I routed to expert i. Now P_i is the fraction of the router probability that was allocated to expert i. The router probability is the original softmaxed routing decision that I was intending to send. So P_i is measuring the intended probability from the router, and f_i is the actual routing decision made by the top-k method. And one thing that's interesting to look at here is: let's say we take the derivative of that loss with respect to P_i. This is a linear function with respect to P_i, and you'll see that the strongest down-weighting action happens on the biggest experts with the biggest allocations: it's in fact proportional to the number of tokens each expert got, so you're going to be pushed down more strongly if you got more tokens. That is the basic behavior of this loss, and almost everybody uses this kind of f·P trick to try to balance tokens across different units.
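As a sketch, the Switch-Transformer-style f·P auxiliary loss looks roughly like this in code (top-1 dispatch assumed; the coefficient alpha is the usual small hyperparameter, and its value here is arbitrary):

```python
import torch

def load_balancing_loss(router_probs, chosen_expert, num_experts, alpha=0.01):
    # router_probs:  (num_tokens, num_experts) softmaxed routing probabilities
    # chosen_expert: (num_tokens,) expert each token was actually dispatched to
    f = torch.bincount(chosen_expert, minlength=num_experts).float() / chosen_expert.numel()
    p = router_probs.mean(dim=0)          # average router probability per expert
    # f has no gradient path; the gradient flows through p and pushes down
    # whichever experts are currently receiving the most tokens.
    return alpha * num_experts * torch.dot(f, p)
```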
different units. So the basic unit that you might want to balance over um initially is batches. You might want each batch to get allocated evenly to experts but you might actually have other kinds um of uh balancing that you might want to do. Um and Deepseek does uh exactly this kind of thing. I'll talk
about all the variants that they've thrown in but you know the first thing is per expert balancing per batch. So each batch they want to make sure experts get an even number of tokens and you know this is from the deepseek paper and hopefully this looks you know very familiar to you. This is exactly the
same, you know, f·P inner product structure as you saw before. You know, P_i is defined a little bit differently — that's s_{i,t} — but that should be familiar from earlier as well; that's the softmax pre-top-K, right? So hopefully this looks all pretty good to you. Um, the other thing you might want
though is you know you might want to balance across experts that's all well and good but you might also want to think about the systems concerns right because you're going to shard your experts onto different devices and you might want to balance per device right and so you might have another loss
that's essentially the same structure, but instead of summing, you know, which tokens go to which experts, you might measure which tokens go to which devices, right? And that's going to be a different f that's measured over the device groups rather than over each expert. And so now you can set up a different loss to balance over devices. If you optimize this, you're naturally going to learn routing functions that make sure each GPU or each TPU, what have you, gets an even number of tokens, leading to even utilization, right? And that would be great from a systems perspective.
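As a sketch of how the same f·P loss can be re-aggregated per device rather than per expert, assuming experts are laid out contiguously with experts_per_device on each device (the names and layout here are assumptions):

    # Device-level variant of the same f·P balancing loss.
    import torch

    def device_balancing_loss(f, P, experts_per_device, alpha=0.01):
        # f, P: [num_experts] tensors, as computed in the previous sketch
        num_devices = f.numel() // experts_per_device
        f_dev = f.view(num_devices, experts_per_device).sum(dim=1)   # fraction of tokens per device
        P_dev = P.view(num_devices, experts_per_device).sum(dim=1)   # router probability mass per device
        return alpha * num_devices * torch.sum(f_dev * P_dev)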
So, basically, everyone does, you know, kind of this kind of a thing. Um, and so, DeepSeek V3 actually kind of innovates a little bit. This is kind of cool, and I
don't think I've seen this before. It's one of the first things in the MoE world that doesn't actually come from Google, really. Which is that they have gotten rid of this per-expert balancing term. They've gotten rid of this entirely, and instead what they now do is they basically take their softmax scores and they add a little fudge factor b_i, where b_i is a little fudge-factor score for each expert, right? So expert i, you know, might get upweighted or downweighted. So, if an expert isn't getting enough tokens, you know, it's going to be given a higher b_i, and then that's going to allow it to grab more tokens. Um, and the way that this works is that they're going to learn b_i through a really simple online update scheme, online learning. And so they're going to measure, at each batch, you know, what are each of the experts getting — like, are they getting an even number of tokens? And if they're not getting enough tokens, they add gamma, some learning rate, to b_i, making it higher. If they're getting too many tokens, they subtract gamma, making that expert slightly less attractive, right? So they're just learning little, you know, offsets for each of the s_i. And notice here, you know, you're only using the b_i to make the routing decisions. You're not actually sending it over as part of your gating weights, right? That's a sort of somewhat important thing to do. So they call this auxiliary-loss-free balancing.
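Here's a minimal sketch of that mechanism; the sigmoid scoring, the gamma step size, and the sign-based update are my reading of the described scheme, so treat the details as assumptions:

    # Auxiliary-loss-free balancing, sketched: the bias b_i only affects which experts
    # get selected, never the gating weights that multiply the expert outputs.
    import torch

    def select_experts(scores, bias, k):
        # scores: [num_tokens, num_experts] affinity scores (e.g. sigmoid of router logits)
        # bias:   [num_experts] per-expert offset used only for selection
        topk = torch.topk(scores + bias, k, dim=-1).indices    # routing decision uses s_i + b_i
        gates = torch.gather(scores, -1, topk)                 # gating weights use s_i only
        return topk, gates

    def update_bias(bias, topk, num_experts, gamma=1e-3):
        # Online update: push the bias up for underloaded experts, down for overloaded ones.
        counts = torch.bincount(topk.flatten(), minlength=num_experts).float()
        return bias + gamma * torch.sign(counts.mean() - counts)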
If you go and read the DeepSeek V3 paper, which all of you should because it's a really nice paper, um, they'll make a big deal about how this makes training so stable, so great, so wonderful. Um, and then of course you keep reading the section and they're like, actually, but we decided that, you know, for each sequence maybe we still want to be balanced, and this doesn't work well enough on its own, so we've added the, you know, the heuristic loss back.
So they do have, um, something called a complementary sequence-wise auxiliary loss that, you know, is basically exactly the auxiliary loss that they decided they needed, because what they wanted to do was to load balance the experts at a per-sequence level rather than a per-batch level. Um, I'm not sure why they do this particular thing rather than any other sort of, you know, b_i-style trick, but that's just kind of what they do in DeepSeek V3. So it's not fully auxiliary-loss-free, as they'd like you to believe. Okay. Oh yes. Question.
This is a bit of an unfair question, but if we did not have to worry about systems optimizations, do you think the performance of this model would be a lot better, or would it stay roughly the same? If we did not think about systems optimization, would the performance of this model be better or stay the same? When you say this model, what do you mean? DeepSeek V3, or just MoEs in general? So are you saying, like, if we ignore the systems concerns, do we think MoEs are still good? Is that kind of one way of asking that question? Like, would the performance on downstream tasks, for example, be better than what we have right now? Yeah. So I think, um, if I didn't have to balance this — like, if I didn't have to send roughly an equal number of tokens to every expert — yeah.
That's right. That's right. Well, I think actually per-expert balancing — this term, right? — this is not a systems concern. So, you still want to do this, because if you don't do this, what you'll find — and actually, you know, I'm going to keep referring to the OLMoE paper because they have so many ablations — they have a really nice ablation where they get rid of exactly this. Um, and what they find is, basically, early on in training the model just picks like one or two experts and all the other experts are dead. Like, the router never sends anything to them. So, you're just wasting memory at that point, right? So now you've just lost performance for free. You've effectively gotten a smaller model. And so even if you ignore all the other device balancing, parallelism concerns, you've just gotten a worse model because you didn't properly allocate your experts, right? It's the same way as, like, you want to use all your parameters, right? You would like to effectively use your parameters. You want to do expert load balancing.
Sorry — uh, "device": what does device refer to? Yeah, actually, um, so normally this would refer to like a GPU or TPU. There is a subtlety. I'll talk about this maybe in the very last or second-to-last slide. Um, there are more sophisticated and cool versions of this where you try to balance things to
minimize communication costs as well. And so there's you know broader notions of device like you know one rack or whatever else but here it usually refers to like GPU.
Yes. Going back to the fact that hashing as a routing algorithm seems to improve performance — is there intuition for that? Because that's effectively just randomly choosing one of the feed-forward members to send it through, right? So why does having multiple copies of that, I guess each of which gets less data — why does that make performance better? Yes, the question was, um, why does hashing do anything at all. Um, I don't have a really precise intuition for this. But you can make arguments either of two ways. One is, you know, even if you're hashing, the same tokens — or the same kinds of sequences — are going to go to the same expert every time, right? And so each expert will still get some deterministic subset of the inputs. And so there's some specialization that can still occur. It's just non-semantic or, you know, non-learned. Um, and if your distribution is Zipfian, like, the word 'the' might dominate one expert, you know, and so you might still get actually semantic specialization where, like, one expert is effectively dominated by very frequent things. But, like, a random routing function probably wouldn't — a pure random thing that's not dependent on the input — yeah, I would bet that that would be really terrible. Yes, I have never run or seen that, but yes, I think that would be horrible. Good. Yes. Yeah. So, for like a large LM, you have many layers, right, many transformer layers — I think in the lecture you mentioned that each expert — okay, so like, if you do 32 layers and 64 experts, that's like a lot of GPUs, or I wonder if experts are bundled together on, like, a single GPU? So the question was, like, won't you need lots of GPUs if you have lots of layers and lots of experts? Yeah, if you exclusively give a GPU to a single expert, yes, that would be
kind of crazy. Um, but you would kind of shard things so that each GPU would hold, you know, enough of these units to, you know, effectively use memory, right? The name of the game in parallelism is you always want to use up all of your memory, because that's one of your resources, right? You don't want to parallelize more than you have to. Cool. Okay. Excellent. Oh, okay. I did put the ablation in here. Yeah. So,
this speaks exactly to the question of what happens if you don't do, you know, the expert balancing loss. I think the great picture to see is this bottom-left one. If you don't do load balancing, you know, what are the tokens assigned to which expert? You see the pink and the yellow expert. They just kind of take over. They take up, you know, about 50% of the tokens. All the other experts are dead. They do nothing, right? And so you've wasted, you know, the majority of your experts at this point — six out of eight of your experts. Um, and you've created a two-expert MoE unintentionally. And, you know, that gives you worse losses, as seen up on the top right, the teal lines. Um, of course maybe that's still better than the dense model, because at least you've got two experts going. Um, but you could
have done better, right, counterfactually speaking. Okay. So, um, I won't go quite as deep as I could into the systems side, because I haven't really started to cover, you know, the core systems concepts necessary for you to deeply appreciate a lot of the parallelism concerns — like, you know, basically the hierarchy of communication speeds in a data center and so on. Um, but really, as I said before, you know, one thing to keep in mind is just how nicely MoEs can fit into
devices. You know, the thing that people say is expert parallelism — you know, that involves sending or putting one or a few experts onto each device. And what happens when you are basically processing a token? Well, you would hit the router, and after the router, you now have picked a few experts. And so now you would have a collective communication call, like an all-to-all communication dispatch, that would send the tokens to the relevant devices. You know, the feed-forwards would compute their outputs, and then you would return the tokens to sort of where they belong — or you would, you know, combine, I guess, multiple experts — and so you would need another collective communication call. And so if your feed-forward computations are sort of big and beefy enough, you can kind of pay for the cost of basically doing this expert parallelism. Um, and one of the things that's nice about this is that it's another form of parallelism in your toolkit.
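A very rough sketch of that dispatch/compute/combine pattern with torch.distributed, assuming one expert per rank, tokens already sorted by destination rank, and precomputed send/receive counts (real implementations also handle top-k gating, capacity, and padding):

    # Expert parallelism, sketched: all-to-all dispatch, local FFN, all-to-all combine.
    import torch
    import torch.distributed as dist

    def expert_parallel_ffn(sorted_tokens, send_counts, recv_counts, local_expert):
        # sorted_tokens: [num_local_tokens, d_model], grouped by destination rank
        recv = sorted_tokens.new_empty((sum(recv_counts), sorted_tokens.shape[-1]))
        dist.all_to_all_single(recv, sorted_tokens,
                               output_split_sizes=recv_counts,
                               input_split_sizes=send_counts)   # dispatch tokens to their experts
        out = local_expert(recv)                                # local feed-forward compute
        back = torch.empty_like(sorted_tokens)
        dist.all_to_all_single(back, out,
                               output_split_sizes=send_counts,
                               input_split_sizes=recv_counts)   # combine: send results back home
        return back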
So you've got on the right side, you know, you know, data parallelism, model parallelism of, you know, two or three different kinds and then you've got expert parallelism and you can combine all of them to come up with sort of ways of trading off all the resources you have. So the communication speed, the
amount of data that you have, your batch size, um, and your your, um, number of experts and your memory. So, um, I'm not going to go into too much detail about how specifically this is going to help, but keep in mind that this gives you
another sort of tool in your expert toolkit. Another thing um that is also you know useful is let's say you have multiple experts um on a single device you know you might hope that because the computations are sparse like let's say
you know, token one — this first token — gets multiplied by expert zero, the second one is expert one, and this third one's expert two. So this is really three matrix multiplies that are small and sparse, and you might hope that modern GPUs can take advantage of these kinds of, you know, sparse matrix multiplications. Um, and that's exactly right. So if you lay out your experts correctly and the weights are fused in the right way, then modern sparse matrix multiply engines can effectively make sure that you're not wasting any flops in doing this one big matrix multiply. So modern libraries like MegaBlocks can basically take advantage of this, you know, device-level sparsity support to do multiple expert computations all at once.
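Here's a toy way to see the "many small expert matmuls as one big sparse matmul" structure. Real libraries like MegaBlocks use block-sparse kernels that never materialize the zero blocks, so this dense version is purely illustrative:

    # Toy illustration: per-expert matmuls expressed as one block-diagonal matmul.
    import torch

    def grouped_ffn(tokens_per_expert, expert_weights):
        # tokens_per_expert: list of [n_i, d_in] tensors, tokens already grouped by expert
        # expert_weights:    list of [d_in, d_out] tensors, one per expert
        x_blocks = torch.block_diag(*tokens_per_expert)   # [sum n_i, E * d_in], mostly zeros
        w_stack = torch.cat(expert_weights, dim=0)        # [E * d_in, d_out]
        return x_blocks @ w_stack                         # row group i equals tokens_i @ W_i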
So this is yet another advantage that you get, um, with MoEs. Um, so one fun side thing, which maybe isn't mysterious to you all anymore because you've sort of grown up
in the era of GPT4. Um but when the GPT4 API first came out, it was kind of mysterious to me because when you set the temperature to zero, you know, you kind of got different responses even though it was supposed to be
deterministic. Um, and lots of people speculated about why that would be. Um, I'm not saying this is the answer to that question. Um, but there is actually an interesting source of randomness in MoEs, right. So in MoEs, think about, you know, what happens. You're going to route your tokens to experts, right? And experts live on different devices. Um, it could be that, you know, you have a lot of examples. You're going to, of course, batch your queries when you're processing them. And so if you've batched your queries, these tokens are going to get routed to different experts. So imagine you've got, you know, this batch to process and you've got a bunch of experts, but
for whatever reason, this batch really loves expert number three. Like all the tokens go to expert number three. So now what happens? Well, the device for expert number three doesn't have enough memory to load all of those tokens. Um
and then what happens is what people call token dropping. And this happens at training time as well. You often have what's called a load factor where you're sort of controlling the maximum number of allowed tokens. And if the router just allocates too many tokens to an expert, you just drop those tokens off
either for systems reasons or because you're just worried that that expert is going to take over at least in the training time. And so now this token has gotten dropped and it's not going to get anything at all. Like the MLP is just going to do a zero computation and the residual connection is just going to
pass things straight forward. Um, and then you're going to return an output.
And so if your token got dropped, you're going to get a different result than if your token didn't get dropped. And so, based on, you know, who else is in your batch, MoEs can induce stochasticity both at training time and inference time, which is kind of an interesting thing that you don't normally think about, because you almost never think about cross-batch effects when doing inference.
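A minimal sketch of that token-dropping behavior with a fixed per-expert capacity; the first-come-first-served rule and top-1 routing here are simplifying assumptions:

    # Capacity-based token dropping: each expert accepts at most `capacity` tokens;
    # overflow tokens are dropped and only flow through the residual connection.
    import torch

    def drop_overflow(expert_indices, num_experts, capacity):
        # expert_indices: [num_tokens] chosen expert per token (top-1 for simplicity)
        keep = torch.zeros_like(expert_indices, dtype=torch.bool)
        slots_used = torch.zeros(num_experts, dtype=torch.long)
        for t, e in enumerate(expert_indices.tolist()):    # earlier tokens win the slots
            if slots_used[e] < capacity:
                keep[t] = True
                slots_used[e] += 1
        return keep    # for dropped tokens, the MoE layer contributes zero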
Okay, so that's kind of the main bits — the main basic components — of building an MoE, plus that fun side thing. If you were to actually go out tomorrow and try to train an MoE, um, I think the systems side will make you a little bit sad, but the other thing that would make you sad is probably the stability side of things. Um, so MoEs kind of have this property that sometimes they'll just kind of blow up on you if you try to fine-tune them. They're very difficult to fine-tune and they'll sometimes blow up on you. Um, and so, you know, Barret Zoph and others really studied this — they had a whole paper on basically trying to make MoEs more stable — and there's a paper, which is the one I'm
referencing here, um, whose entire purpose is to stabilize training, and there's a couple of tricks that I'll mention that I think are relevant and that people do. Um, the first one is, you know, if you're doing the router softmax — so this goes back to last lecture about stability, right? Like, what did I say about stability? Well, the thing to be afraid of is the softmaxes, right? The
softmax is always where you want to be afraid. And so, for the router, they do all the computations in float32, just to be safe, right? Um, and sometimes they also add, you know, an auxiliary z-loss. So hopefully you remember — it was just last lecture — you know, you take the log of the sum of the exponentiated values in the softmax, you square that, and you add that as an extra loss, right? So this is going to keep the normalizer values near one, which is nice for stability. Um, so this is actually one of the places where the z-loss was used earlier, before it got sort of more popular for training models more broadly.
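Concretely, the router z-loss being described looks something like this (the coefficient value is an assumption; the point is just to penalize the squared log-normalizer):

    # Router z-loss: keep the log of the softmax normalizer near zero, i.e. the
    # normalizer itself near one, which keeps the router softmax well behaved.
    import torch

    def router_z_loss(router_logits, coeff=1e-3):
        # router_logits: [num_tokens, num_experts]
        z = torch.logsumexp(router_logits, dim=-1)     # log of the sum of exponentiated logits
        return coeff * (z ** 2).mean()                 # squared, then averaged over tokens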
You can kind of see the effects here. Um, if you look at the losses — I think the center plot, the second plot here, is maybe a great one — you know, if you remove the z-loss from your routing function, you see these giant loss spikes in your validation loss, where, you know, the model just kind of goes a little bit crazy for a couple of iterations and then gets kind of pulled back. Of course, it still trains okay, but you are better off having the z-loss than not having a z-loss. There is a pretty noticeable gap in the validation loss by the end here, right? Um, other things that can happen: um, people, you know — of course you want to fine-tune your MoE, or also RLHF your MoE, if you're going to, you know, ship and release it. Um, but this turns out to be, uh, kind of problematic. Some of the earlier work, you know, when people were
starting to do MoEs, this was back in kind of the BERT and T5 era. So there was a lot of fine-tuning going on. Um, and you know, one of the things that people saw was, you know, actually there's a lot of overfitting that happens if you were doing sparse models. You see this big gap between train and val, right? This blue and orange line. Um, whereas the dense model, this green and red line, has a smaller train-test gap. Um, and so there was a lot of worry about overfitting, because you have these gigantic-parameter models that you're fine-tuning on small data. Um, one of the solutions that was proposed at the time — I don't think this is very popular, as far as I understand — is to architect your MoE such that not every layer is an MoE layer, but you, let's say, alternate dense layers and MoE layers. Then you can just fine-tune the dense layers, and that will still be fine, right? That behaves just like a dense model. Um, so that was fine. Another solution — the one that we saw
in the DeepSeekMoE paper — is to just kind of use a lot of data. Like, if overfitting is a problem, you know, we have access to lots and lots of SFT data; just shovel all of those guys in. So in the case of DeepSeek, they use 1.4 million training examples. Um, then maybe you're not quite as worried
about these overfitting concerns. Um, the last thing I'll end with, which is a trick in the toolkit that people have done and seen, is upcycling. And so the idea is to take a dense model, like the one over here, and then you take your MLP and you make a bunch of copies of it. Um, and then you maybe perturb it, and then you have your router that's initialized from scratch, and then you just pretend this is an MoE and you train it from that point on, right? You just initialize the MoE from a dense model, and this is a trick that's kind of called upcycling. Um, and, you know, people have shown that if you can get it to work, it is a very, very cost-effective way of getting an MoE, right? And the MoE is great for inference, because not every MLP is going to be active at inference time, right? So you might effectively get a much larger parameter model without doing the training of a much larger parameter model. Um, and several people have succeeded at this.
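Sketched in code, upcycling is roughly this — copy the dense MLP into each expert, optionally perturb, and initialize a fresh router. The perturbation scale and module names are made up for illustration:

    # Upcycling a dense FFN into an MoE layer.
    import copy
    import torch
    import torch.nn as nn

    def upcycle(dense_mlp: nn.Module, d_model: int, num_experts: int, noise_std: float = 1e-3):
        experts = nn.ModuleList()
        for _ in range(num_experts):
            expert = copy.deepcopy(dense_mlp)                  # every expert starts as a copy
            with torch.no_grad():
                for p in expert.parameters():
                    p.add_(noise_std * torch.randn_like(p))    # small perturbation to break symmetry
            experts.append(expert)
        router = nn.Linear(d_model, num_experts, bias=False)   # router is trained from scratch
        return experts, router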
Um, MiniCPM, which I'll mention again in the scaling laws lecture — this is a Chinese open LLM effort that basically tried to build really good small language models, and they succeeded at taking a dense model and upcycling it into an MoE. You can see that their numbers get significantly better in the last two rows, right? So going from the dense models to the MoE, they get a pretty non-trivial bump in performance. Um, Qwen — I mentioned them at the start of this lecture — one of their earliest attempts at an MoE was taking one of their dense models and then building an upcycled MoE.
Um and they got, you know, fairly significant uh performance gains relative to sort of smaller models at the time. Like they got models on par
with their 7B models with a 2.7 billion parameter active model.
Um, so, to wrap up, I want to walk through the DeepSeek architecture at the very end here. Um, and hopefully this will give you a sense of — you know, the first thing I want to do is I want you to understand the DeepSeek V3 architecture setup and all the changes that they did, um, because that's an example of a modern high-performance open-source system. I also want you to maybe appreciate that architectures don't change that much. DeepSeek v1 — you know, the first DeepSeek MoE — is, uh, you know, it's not that new; it's like maybe a year and a half or maybe two years old, um, and they basically nailed the architecture at that point, right? So I want you to see what they changed from that earliest attempt to their big training run. So this is the very first starting point. This is, um, DeepSeek — I'm calling it v1, but actually, you know, probably the right way to refer to it is DeepSeekMoE. It's a 16 billion parameter model with 2.8 billion of those parameters active. Um, and you've seen already this diagram over here. This is the two shared plus 64 fine-grained experts, of which about six are active at a time. Um, and the
routing, you know, you've already seen this — I presented this in the middle of the lecture here. This is the very standard top-K routing, where the softmax is at the bottom, before the top-K selection. Um, and for balancing, right, at training time all they do is add this auxiliary loss balancing term — both the expert and device-level balancing terms, right? So hopefully, you know, you remember those from earlier. So that's the v1. Um, and then they saw how effective their model was. So, I guess, to add some more context, right, DeepSeek originally had a dense model, and then they had an MoE model, and the MoE model was remarkably good. And so when they went to v2, they went straight to the MoE, and now this is a 236 billion parameter model of which 21 billion of those parameters are active, right? So you need a lot of memory, but your flops consumption for inferencing this model is not so bad. Now, um, the architecture is identical — I copied literally the same figure, because the architecture is literally the same, minus changes to the number of, you know, experts that are active. Um, and we've got now some new things
happening, but not too many new things. So the top-K selector is the same. So the equation from before — this previous equation — this is identical. This is still how they do things. Um, but they have this very clever trick that they add on. And this is, um — you know, I was going to say at the very beginning, you know, what's the drawback of having fine-grained experts? Why can't I have, I don't know, 1024 fine-grained experts or 2048 fine-grained experts?
Well, the problem is when you shard your experts very finely and you have a lot of active experts, right, you're going to have to route to those experts, right? So your communication costs potentially grow, and if you're very fragmented, you might have to send a lot of tokens to a lot of devices, right? And so the clever thing they come up with is to say: I'm not just going to, you know, for each batch route to the top-K experts naively, which might force me to send my tokens to lots of devices. What I'm going to do is first pick the top M devices, right? So I'm going to do my normal scoring calculation, but I'm first going to subset the set of allowed devices to the top M, right? And once I've picked my devices, then I'm going to pick the top K for each token within those devices, right? So now I've restricted the devices. This really controls the communication cost.
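Here's a sketch of that device-limited routing; the contiguous expert-to-device layout and tensor shapes are assumptions, but the two-stage top-M-then-top-K selection is the idea being described:

    # DeepSeek-V2-style device-limited routing: restrict each token to its top-M devices
    # (scored by the best expert on each device), then take top-K experts within that subset.
    import torch

    def device_limited_topk(scores, experts_per_device, m_devices, k):
        # scores: [num_tokens, num_experts]; experts laid out contiguously per device
        num_tokens, num_experts = scores.shape
        num_devices = num_experts // experts_per_device
        per_device = scores.view(num_tokens, num_devices, experts_per_device)
        device_score = per_device.max(dim=-1).values                    # best expert on each device
        top_devices = torch.topk(device_score, m_devices, dim=-1).indices
        allowed = torch.zeros(num_tokens, num_devices, dtype=torch.bool)
        allowed.scatter_(1, top_devices, torch.ones_like(top_devices, dtype=torch.bool))
        expert_mask = allowed.repeat_interleave(experts_per_device, dim=1)
        masked = scores.masked_fill(~expert_mask, float("-inf"))
        return torch.topk(masked, k, dim=-1).indices                    # top-K among allowed experts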
And now this gives you more efficient training when you're scaling up to these gigantic sizes, right? You need to start really engaging with the systems aspect of things when you're training a 236 billion parameter model. The other thing, which reflects the systems concerns that are necessary at this scale, is that they add a
communication balancing loss. Um, one way of thinking about things is, you know, for an expert, there's kind of inputs and outputs, right? The inputs
are, you know, the the token comes in and you route to your expert. And the
outputs are, you know, you have to kind of bring the tokens back where they belong, right? So, if a token belongs on this device, it has to go back to where the original device was. So we have to think about both the input communication cost and the output communication cost. And so they add a balancing loss to try to balance out the output communication cost as well, not just the input side. Um, so that's a minor note, but you can kind of see their attention to detail in trying to make sure all the different systems aspects are properly taken care of. And then finally we kind of get to the, you
know, the big DeepSeek V3 — sorry, that should say V3, not V2, up there — 671 billion parameters, of which 37 billion are active. You know, once again, exactly the same figure, because the MoE architecture itself doesn't change. That's stayed the same since DeepSeekMoE, right? Like, if it works, don't change it. Um, they do change a couple of things. Um, maybe they were, you know, hearing you all say, why don't you normalize to one? And so, you know, they've normalized the gate to one. They've moved kind of the softmax normalizer operation up there. Um, but they're not actually exponentiating
the gating decisions. They're actually taking sigmoids, which is a softer, more nicely behaved operation, you know, than the softmax. So they've got some changes here, but conceptually this is still the same as the top-K routing decision, right? You hopefully see very, very similar things happening. And then in terms of the losses, um, they've gone to this auxiliary-loss-free trick of this b_i being incremented or decremented based on the expert load. And then they have a
sequence-wise auxiliary loss. Um and just to, you know, add some context, why would you want to balance different um uh experts on a single sequence? Well,
the thing that they're very concerned about is at training time, you know, it's fine to to not have a sequence-wise balancing loss, but at inference time, it might be the case that someone sends you very out of distribution sequences, and that might overwhelm certain experts, right? So, at inference time,
you can't control which sequences you get. So, you might want sort of stronger balancing that operates at a single sequence level rather than overall batch
level. Okay. And then in the — oh, sorry. Yes. Um, does V3 still do, like, the top-M devices — like, does it keep the V2 improvement? Yeah, they keep the top-M improvement. They do not keep, for example, the communication loss. So they've jettisoned some things, but top-M is — I mean, it seems like a pretty clever idea — they keep it. Yeah. But it's not like they always add things; they have removed some of the things. Um, and so in the
last two or so minutes of the class, I'm going to go over the non-MoE parts of DeepSeek V3, because I think, you know, we're already at the point where I've explained most of DeepSeek V3. I might as well go through the steps of explaining the rest of DeepSeek V3 at this point. So you all know kind
of how that works. So, um, they have a clever optimization for the attention piece called MLA, or multi-head latent attention. And, um, you all actually already know all the ingredients that you need to understand this, because at the end of last lecture, you know, I talked about GQA and MHA, right? So those are all inference optimizations that you need in order to optimize the size of the KV cache. So the DeepSeek folks take a different tack, or a different approach, at optimizing this: instead of reducing the number of heads,
they're actually going to project the heads into a lower-dimensional space. So you have your inputs h_t, and instead of generating the K's and V's directly from these h_t's, what I'm going to do is first generate a low-dimensional c. You can think of this as, like, a compressed version of h. And this c is going to be smaller and easier to cache. And I'm just going to cache these c's. And whenever I need, you know, these K's and V's, well, I can up-project from this KV latent, conceptually speaking. And then, you know, I can take the inner products with the Q's, right? So you can kind of see how this would be a KV cache
savings if I only have to save the c instead of the higher-dimensional h_t. Um, and that's exactly the idea. So you take your h_t, you project it into a lower-dimensional c, and then you up-project this back into the K's and V's, right? And if the c's are small, well, you've compressed the KV cache. That's good.
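A minimal sketch of that compression path; the dimensions and module names are made up, and the actual DeepSeek design also handles the RoPE carve-out mentioned in a moment:

    # MLA-style KV compression: cache a small latent c_t instead of full keys/values,
    # and up-project to K and V only when they're needed.
    import torch
    import torch.nn as nn

    class LatentKV(nn.Module):
        def __init__(self, d_model=1024, d_latent=128, d_head=64, n_heads=16):
            super().__init__()
            self.down = nn.Linear(d_model, d_latent, bias=False)            # compress h_t -> c_t
            self.up_k = nn.Linear(d_latent, n_heads * d_head, bias=False)   # W^UK
            self.up_v = nn.Linear(d_latent, n_heads * d_head, bias=False)   # W^UV

        def forward(self, h):
            c = self.down(h)     # this low-dimensional latent is all that goes in the KV cache
            k = self.up_k(c)     # reconstruct keys on the fly
            v = self.up_v(c)     # reconstruct values on the fly
            return c, k, v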
Um, and then, you know, in terms of the computation, right? If you're thinking about flops, well, you might think, well, this is not good, because I have to multiply an extra matrix W^UK, right? I didn't have this matrix before. That's an extra matrix multiply that I have to pay for. But kind of the clever thing here is: remember that on the other side of K, right, I'm going to take Q·K — that Q·K is going to be an inner product in the attention operation, right? And Q itself has a projection matrix. And so the trick here is you can merge this W^UK and this query projection matrix together into one matrix. So I haven't gotten any extra matrix multiplies; I've just merged this new matrix multiply into my other one, right? This is, you know, just associativity. I can just merge the two. Um, they also compress the queries for memory savings during training, but really that one is not quite as necessary, because it doesn't interact at all with the KV cache.
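To spell out that merging step (a sketch; here W^Q stands for the per-head query projection and W^UK for the key up-projection from the snippet above):

    q_t^\top k_s = (W^Q h_t)^\top (W^{UK} c_s) = h_t^\top \big( (W^Q)^\top W^{UK} \big) c_s

so the product (W^Q)^T W^UK can be precomputed once, and at inference you go straight from h_t and the cached c_s to the attention score without paying for an extra per-token matrix multiply.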
I'm only going to mention this last one in passing. Um, because it is a subtlety, but it's kind of a clever subtlety that you realize, which is that this original trick, this sort of thing that I just described at the top is not compatible with rope, right? And the reason is because, you know, the rope
matrices, you know, basically you have the Q's and the K's and you rotate each of those Q's and the K's by multiplying with a rotation matrix RQ and RK. But if
you do that, then these RQs and RKs are in between the query projection and this up uh latent vector up projection matrix. And since I can't reorder these matrix multiplies, you know, rope kind of gets in the way. And they still have
a a solution of basically um doing rope on non-compressed dimensions. That's
kind of a side point. I think it's not quite as important. You can kind of look at the paper if you're super interested. The other thing that they do, and this is the last thing I promise, is that they have a a minor change in their loss
function called MTP where they predict multiple tokens um in parallel. And so
what they can do is: normally, right, you have your inputs, you shift them to the left by one, so you're predicting one token in the future, and then your transformer is going to predict all those tokens, right? That's your normal transformer loss. But then what you can do is, right before you make those predictions, you can take, you know, the hidden state, pass it to a very lightweight one-layer transformer, and that model can predict, you know, one more token into the future. So now the model is not just predicting the next token; it's predicting two tokens into the future, right? So that hopefully all makes sense — this is just a small, lightweight model that can do that.
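Here's a rough sketch of such a lightweight one-layer prediction head; the way the trunk hidden state and the next-token embedding are combined, and the module sizes, are assumptions rather than DeepSeek's exact design:

    # One-depth multi-token-prediction (MTP) head: one extra transformer layer that
    # predicts the token two positions ahead, reusing the trunk's hidden states.
    import torch
    import torch.nn as nn

    class MTPHead(nn.Module):
        def __init__(self, d_model, vocab_size, nhead=8):
            super().__init__()
            self.proj = nn.Linear(2 * d_model, d_model)
            self.block = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
            self.out = nn.Linear(d_model, vocab_size)

        def forward(self, trunk_hidden, next_token_emb, causal_mask=None):
            # trunk_hidden:   [B, T, d] hidden states of the main model (which predicts t+1)
            # next_token_emb: [B, T, d] embeddings of the t+1 tokens (teacher forcing)
            h = self.proj(torch.cat([trunk_hidden, next_token_emb], dim=-1))
            h = self.block(h, src_mask=causal_mask)    # pass a causal mask during training
            return self.out(h)                         # logits for token t+2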
You can sort of see the architecture right here. Um, the one thing that is kind of disappointing, which I learned as I was researching for this lecture, is that they actually only do MTP with one token ahead. So even though they have this very complicated diagram of how they could do it for many tokens, um, it turns out it's only done for one token. Okay, so now I'm all done. Um, MoEs are
kind of now at the core of how you would build and deploy, you know, a really high-performance, large-scale system. And they take advantage of kind of the sparsity idea that you don't need all of the parameters all the time. And discrete routing is the real big challenge. And this is, I think, one of the big reasons why MoEs didn't immediately catch on. It's very scary to have to try to optimize these top-K routing decisions. Um, but heuristics somehow seem to work, right? Like, they just do. And so there's a lot of empirical evidence now that, at least for flop-constrained settings, MoE is just a good idea. It's cost effective. Um, you should do it. So definitely worth learning. Um,
thanks a lot uh for listening.