Mixtral-8x7B MoE
For me, the best description of mixtral-8x7b is the following picture:
Only a few days passed between the release of gemini and the release of mixtral-8x7b. For the first two days after the release of gemini there was a lot of talk about that model, but as soon as mixtral-8x7b was released, gemini was completely forgotten and the whole community was talking about mixtral-8x7b.
And no wonder: looking at its benchmarks, we can see that it is at the level of models such as llama2-70B and GPT3.5, with the difference that while mixtral-8x7b has only 46.7B parameters, llama2-70B has 70B and GPT3.5 has 175B.
This notebook has been automatically translated to make it accessible to more people; please let me know if you see any typos.
Number of parameters
As the name suggests, mixtral-8x7b is a set of 8 models of 7B parameters each, so we might think it has 56B parameters (7B x 8), but that is not the case. As Andrej Karpathy explains, only the Feed forward blocks of the transformer are multiplied by 8; the rest of the parameters are shared among the 8 models. So in the end the model has 46.7B parameters.
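We can sanity-check that figure with a quick calculation. The sketch below uses the dimensions published in the Hugging Face config of mistralai/Mixtral-8x7B-v0.1 (hidden size 4096, expert intermediate size 14336, 32 layers, 32 query heads with 8 key/value heads of dimension 128, 32000-token vocabulary); treat it as an approximation, not an official breakdown.
hidden = 4096        # model (embedding) dimension
ffn = 14336          # intermediate size of each expert's Feed forward block
layers = 32          # number of decoder layers
experts = 8          # experts per layer
vocab = 32000        # vocabulary size
kv_dim = 8 * 128     # 8 key/value heads of dimension 128 (grouped-query attention)

# Each expert is a SwiGLU Feed forward block: three hidden x ffn matrices,
# and only these weights are multiplied by 8
expert_params = experts * layers * 3 * hidden * ffn

# Attention (shared): query and output projections are hidden x hidden,
# key and value projections are hidden x kv_dim
attn_params = layers * (2 * hidden * hidden + 2 * hidden * kv_dim)

# Token embeddings, output head and the per-layer routers (also shared)
other_params = 2 * vocab * hidden + layers * hidden * experts

total = expert_params + attn_params + other_params
print(f"experts:   {expert_params / 1e9:.1f}B")   # ~45.1B
print(f"attention: {attn_params / 1e9:.1f}B")     # ~1.3B
print(f"other:     {other_params / 1e9:.1f}B")    # ~0.3B
print(f"total:     {total / 1e9:.1f}B")           # ~46.7B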
Mixture of Experts (MoE)
As we have said, the model is a set of 8 models of 7B parameters, hence MoE, which stands for Mixture of Experts. Each of the 8 models is trained independently, but when inference is done a router decides which model's output is used.
The following image shows the architecture of a Transformer.
If you don't know it, the important thing is that this architecture consists of an encoder and a decoder.
LLMs are decoder-only models, so they do not have an encoder. You can see that in the architecture there are three attention modules, one of which actually connects the encoder to the decoder. But since LLMs do not have an encoder, there is no need for the attention module that connects the encoder and the decoder.
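To make this concrete, here is a minimal sketch of a decoder-only block: masked self-attention followed by a Feed forward layer, with no cross-attention to any encoder. It is a simplified illustration, not the actual code of any particular model.
import torch
import torch.nn as nn

class DecoderOnlyBlock(nn.Module):
    # Simplified decoder block of an LLM: self-attention + Feed forward.
    # There is no encoder, so there is no cross-attention module.
    def __init__(self, d_model=512, n_heads=8, d_ff=2048):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Causal mask: each token can only attend to itself and previous tokens
        seq_len = x.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=mask)
        x = x + attn_out
        x = x + self.ff(self.norm2(x))
        return x

block = DecoderOnlyBlock()
tokens = torch.randn(1, 10, 512)   # (batch, sequence length, embedding dimension)
print(block(tokens).shape)         # torch.Size([1, 10, 512])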
Now that we know what the architecture of an LLM looks like, we can see what the architecture of mixtral-8x7b looks like. In the following image we can see the architecture of the model.
As you can see, the architecture consists of a 7B-parameter Transformer decoder, except that the Feed forward layer consists of 8 Feed forward layers with a router that chooses which of the 8 Feed forward layers to use. In the image above only four Feed forward layers are shown, I suppose to simplify the diagram, but in reality there are 8 Feed forward layers. You can also see two paths for two different words, the word More and the word Parameters, and how the router chooses which Feed forward layer to use for each word.
Looking at the architecture we can understand why the model has 46.7B parameters and not 56B. As we have said, only the Feed forward blocks are multiplied by 8; the rest of the parameters are shared among the 8 models.
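To make the idea concrete, below is a minimal sketch of such a MoE Feed forward layer, written by hand as an illustration rather than taken from Mixtral's actual implementation: a small linear router scores the 8 experts for each token, and each token is processed only by the experts it is routed to. In Mixtral the router selects the top-2 experts per token and mixes their outputs weighted by the router scores, which is what the sketch does too.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoEFeedForward(nn.Module):
    # Sketch of a Mixture of Experts Feed forward layer with top-2 routing
    def __init__(self, d_model=512, d_ff=2048, n_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        # 8 independent Feed forward blocks (the "experts")
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)
        ])
        # The router is just a linear layer that scores each expert per token
        self.router = nn.Linear(d_model, n_experts)

    def forward(self, x):                       # x: (n_tokens, d_model)
        scores = self.router(x)                 # (n_tokens, n_experts)
        weights, chosen = scores.topk(self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1)    # mixing weights for the chosen experts
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = chosen[:, k] == e        # tokens routed to expert e in slot k
                if mask.any():
                    out[mask] += weights[mask, k].unsqueeze(-1) * expert(x[mask])
        return out

tokens = torch.randn(4, 512)      # 4 tokens, e.g. "More" and "Parameters" among them
moe = MoEFeedForward()
print(moe(tokens).shape)          # torch.Size([4, 512])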
Use of Mixtral-8x7b in the cloud
Unfortunately, using mixtral-8x7b locally is complicated, because the hardware requirements are as follows:
- float32: VRAM > 180 GB, i.e., since each parameter occupies 4 bytes, we need 46.7B * 4 = 186.8 GB of VRAM just to store the model, plus the VRAM needed to store the input and output data.
- float16: VRAM > 90 GB, in this case each parameter occupies 2 bytes, so we need 46.7B * 2 = 93.4 GB of VRAM just to store the model, plus the VRAM needed to store the input and output data.
- 8-bit: VRAM > 45 GB, here each parameter occupies 1 byte, so we need 46.7B * 1 = 46.7 GB of VRAM just to store the model, plus the VRAM needed to store the input and output data.
- 4-bit: VRAM > 23 GB, here each parameter occupies 0.5 bytes, so we need 46.7B * 0.5 = 23.35 GB of VRAM just to store the model, plus the VRAM needed to store the input and output data.
We need very powerful GPUs to run it, even when using the 4-bit quantized model.
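These figures are simply the number of parameters multiplied by the bytes each parameter occupies (weights only, ignoring activations, the KV cache and framework overhead). A quick calculation reproduces them:
params = 46.7e9
bytes_per_param = {"float32": 4, "float16": 2, "8-bit": 1, "4-bit": 0.5}

for dtype, nbytes in bytes_per_param.items():
    print(f"{dtype}: {params * nbytes / 1e9:.2f} GB")
# float32: 186.80 GB
# float16: 93.40 GB
# 8-bit: 46.70 GB
# 4-bit: 23.35 GB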
So the easiest way to use Mixtral-8x7B is to use it already deployed in the cloud. I have found several sites where you can use it:
Use of Mixtral-8x7b in huggingface chat
The first one is huggingface chat. To use it you have to click on the cogwheel inside the Current Model box and select Mistral AI - Mixtral-8x7B. Once selected, you can start talking to the model.
Once inside, select mistralai/Mixtral-8x7B-Instruct-v0.1 and finally click on the Activate button. Now we can test the model.
As you can see, I asked it in Spanish what MoE is and it explained it to me.
Using Mixtral-8x7b in Perplexity Labs
Another option is to use Perplexity Labs. Once inside, you have to select mixtral-8x7b-instruct in a drop-down menu in the lower right corner.
As you can see, I also asked it in Spanish what MoE is and it explained it to me.
Using Mixtral-8x7b locally via the huggingface API
One way to use it locally, whatever hardware resources you have, is through the huggingface API. To do this you have to install Hugging Face's huggingface-hub library.
%pip install huggingface-hub
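Before the full chat interface, a minimal sketch of a single call to the hosted model through the Inference API could look like this (it assumes you are authenticated with a Hugging Face token, for example via huggingface-cli login):
from huggingface_hub import InferenceClient

# Client pointing to the hosted Mixtral instruct model
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# The instruct version of Mixtral expects the [INST] ... [/INST] prompt format
prompt = "<s>[INST] What is a Mixture of Experts (MoE)? [/INST]"

response = client.text_generation(prompt, max_new_tokens=256)
print(response)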
Here is an implementation with gradio.
from huggingface_hub import InferenceClient
import gradio as gr

client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

def format_prompt(message, history):
    # Build the Mixtral instruct prompt: each previous turn wrapped in [INST] ... [/INST]
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
    # Stream the generation token by token so the chat updates as the model writes
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        yield output
    return output

additional_inputs = [
    gr.Textbox(label="System Prompt", max_lines=1, interactive=True),
    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
    gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=1048, step=64, interactive=True, info="The maximum number of new tokens"),
    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens"),
]

gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Mixtral 46.7B",
    concurrency_limit=20,
).launch(show_api=False)
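Note that format_prompt reproduces the [INST] ... [/INST] format that the instruct version of Mixtral was trained with, so the conversation history reaches the model in the format it expects, and the response is streamed token by token so the chat updates while the model is still writing.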