Quantização de LLMs
Os modelos de linguagem estão ficando cada vez maiores, o que os torna cada vez mais caros e demorados para serem executados.
Por exemplo, o modelo chama 3 400B, se seus parâmetros forem armazenados no formato FP32, cada parâmetro ocupará 4 bytes, o que significa que são necessários 400(10e9)4 bytes = 1,6 TB de memória VRAM apenas para armazenar o modelo. Isso significa 20 GPUs com 80 GB de memória VRAM cada, o que não é barato.
Mas se deixarmos de lado os modelos gigantescos e passarmos a modelos com tamanhos mais comuns, por exemplo, 70 B de parâmetros, apenas armazenar o modelo significa 70(10e9)4 bytes = 280 GB de memória VRAM, o que significa 4 GPUs com 80 GB de memória VRAM cada.
Isso ocorre porque armazenamos os pesos no formato FP32, ou seja, cada parâmetro ocupa 4 bytes. Mas o que acontece se fizermos com que cada parâmetro ocupe menos bytes? Isso é chamado de quantização.
Por exemplo, se obtivermos um modelo com parâmetros 70B, cujos parâmetros ocupam meio byte, precisaremos apenas de 70(10e9)0,5 bytes = 35 GB de memória VRAM, o que significa 2 GPUs com 24 GB de memória VRAM cada, que já podem ser consideradas GPUs de usuário normais.
Este caderno foi traduzido automaticamente para torná-lo acessível a mais pessoas, por favor me avise se você vir algum erro de digitação..
Portanto, precisamos de maneiras de redimensionar esses modelos. Há três maneiras de fazer isso: destilação, poda e quantização.
A destilação consiste em treinar um modelo menor a partir das saídas do modelo maior. Ou seja, uma entrada é fornecida ao modelo pequeno e ao modelo grande, a saída correta é considerada a do modelo grande, de modo que o modelo pequeno é treinado de acordo com a saída do modelo grande. Mas isso exige que o modelo grande seja armazenado, o que não é o que queremos ou podemos fazer.
A poda consiste na remoção de parâmetros do modelo, tornando-o cada vez menor. Esse método baseia-se na ideia de que os modelos de linguagem atuais são superdimensionados e que apenas alguns parâmetros são os que realmente fornecem informações. Portanto, se conseguirmos eliminar os parâmetros que não fornecem informações, obteremos um modelo menor. Mas isso não é fácil atualmente, porque não temos como saber quais parâmetros são importantes e quais não são.
Por outro lado, a quantização consiste em reduzir o tamanho de cada um dos parâmetros do modelo. E é isso que vamos explicar nesta postagem.
Formato do parâmetro
Os parâmetros de peso podem ser armazenados em vários tipos de formatos
Originalmente, o FP32 era usado para armazenar os parâmetros, mas como começamos a ficar sem memória para armazenar os modelos, começamos a mudar para o FP16, que não apresentou resultados ruins.
No entanto, o problema com o FP16 é que ele não atinge valores tão altos quanto o FP32, portanto, pode ocorrer o caso de estouro de valor, ou seja, ao realizar cálculos internos na rede, o resultado é tão alto que não pode ser representado em FP16, o que gera erros. Isso ocorre porque o modelo foi treinado em FP32, o que possibilita todos os cálculos internos possíveis, mas, quando ele é alternado para FP16 para poder fazer inferências, alguns cálculos internos podem produzir transbordamentos.
Devido a esses erros de estouro, foram criados o TF32 e o BF16, que têm a mesma quantidade de bits de expoente, o que significa que eles podem atingir valores tão altos quanto o FP32, mas com a vantagem de ocupar menos memória porque têm menos bits. No entanto, ambos têm menos bits de mantissa, portanto, não podem representar números com a mesma precisão do FP32, o que pode levar a erros de arredondamento, mas, pelo menos, não teremos um erro ao executar a rede. O TF32 tem um total de 19 bits, enquanto o BF16 tem 16 bits. O BF16 costuma ser mais usado porque economiza mais memória.
Historicamente, existem os formatos INT8 e UINT8, que podem representar números de -128 a 127 e de 0 a 255, respectivamente. Embora sejam bons formatos porque economizam menos memória, já que cada parâmetro ocupa 1 byte em comparação com os 4 bytes do FP32, o problema é que eles só podem representar um pequeno intervalo de números e também apenas números inteiros, de modo que podem ocorrer os dois problemas vistos anteriormente, estouro e falta de precisão.
Para resolver o problema de os formatos INT8 e UINT8 representarem apenas números inteiros, foram criados os formatos FP8 e FP4, mas eles ainda não estão bem estabelecidos, nem têm um formato muito difundido.
Mesmo que tenhamos uma maneira de armazenar os parâmetros do modelo em formatos menores, e mesmo que consigamos resolver os problemas de estouro e arredondamento, temos outro problema, que é o fato de nem todas as GPUs serem capazes de representar todos os formatos. Isso ocorre porque esses problemas de memória são relativamente novos, de modo que as GPUs mais antigas não foram projetadas para resolver esses problemas e, portanto, não são capazes de representar todos os formatos.
Como último detalhe, a título de curiosidade, durante o treinamento dos modelos, é usado o que chamamos de precisão mista. Os pesos do modelo são armazenados no formato FP32, mas a passagem para frente e para trás é realizada em FP16 para torná-la mais rápida. Os gradientes resultantes da passagem para trás são armazenados em FP16 e são usados para modificar os valores FP32 dos pesos.
Tipos de quantização
Quantificação do ponto zero
Esse é o tipo mais simples de quantização. Ele consiste em reduzir o intervalo de valores de forma linear, o valor mínimo de FP32 corresponde ao valor mínimo do novo formato, o zero de FP32 corresponde ao zero do novo formato e o valor máximo de FP32 corresponde ao valor máximo do novo formato.
Por exemplo, se quisermos passar os números representados de -1 a 1 no formato UINT8, como os limites de UINT8 são -127 e 127, se quisermos representar o valor 0,3, o que faremos é multiplicar 0,3 por 127, o que dá 38,1 e arredondar para 38, que é o valor que seria armazenado em UINT8.
Se quisermos fazer a etapa oposta, para passar 38 para um formato entre -1 e 1, o que faremos é dividir 38 por 127, o que dá 0,2992, que é aproximadamente 0,3, e podemos ver que temos um erro de 0,008.
Quantificação afim
Nesse tipo de quantização, se você tiver uma matriz de valores em um formato e quiser mudar para outro, primeiro divida a matriz de números inteiros pelo valor máximo da matriz e, em seguida, multiplique a matriz de números inteiros pelo valor máximo do novo formato.
Por exemplo, na imagem acima, temos a matriz
[1.2, -0.5, -4.3, 1.2, -3.1, 0.8, 2.4, 5.4]
Como seu valor máximo é 5,4
, dividimos a matriz por esse valor e obtemos
[0.2222222222, -0.09259259259, -0.7962962963, 0.2222222222, -0.5740740741, 0.1481481481, 0.4444444444, 1]
Se agora multiplicarmos todos os valores por 127
, que é o valor máximo de UINT8, obteremos
[28,22222222, -11.75925926, -101.1296296, 28.22222222, -72.90740741, 18.81481481, 56.44444444, 127]
O que, arredondando, seria
[28, -12, -101, 28, -73, 19, 56, 127]
Se agora quiséssemos executar a etapa inversa, teríamos que dividir a matriz resultante por 127
, o que daria
[0,2204724409, -0.09448818898, -0.7952755906, 0.2204724409, -0.5748031496, 0.1496062992, 0.4409448819, 1]
E multiplique novamente por 5,4
, para obtermos
[1,190551181, -0.5102362205, -4.294488189, 1.190551181, -3.103937008, 0.8078740157, 2.381102362, 5.4]
Se a compararmos com a matriz original, veremos que temos um erro
Momentos de quantização
Quantificação pós-treinamento
Como o nome sugere, a quantização ocorre após o treinamento. O modelo é treinado em FP32 e, em seguida, quantizado em outro formato. Esse método é o mais simples, mas pode levar a erros de precisão na quantização.
Quantização durante o treinamento
Durante o treinamento, a passagem para frente é realizada no modelo original e em um modelo quantizado, e os possíveis erros decorrentes da quantização são vistos para atenuá-los. Esse processo torna o treinamento mais caro, porque é necessário ter o modelo original e o modelo quantizado armazenados na memória, e mais lento, porque é necessário executar a passagem para frente em dois modelos.
Métodos de quantização
Abaixo estão os links para as postagens em que explico cada um dos métodos para que esta postagem não fique muito longa.
- LLM.int8()
- GPTQ
- QLoRA
- AWQ
- QuIP
- GGUF
- HQQ
- AQLM
- FBGEMM FP8