Você está em uma entrevista para Engenheiro de IA na OpenAI. O entrevistador pergunta: "Nosso modelo GPT gera 100 tokens em 42 segundos. Como você faria para torná-lo 5x mais rápido?" Você: "Vou alocar mais GPUs para uma geração mais rápida." Entrevista encerrada. Aqui está o que você perdeu:
O verdadeiro gargalo não é o cálculo, é o cálculo redundante. Sem cache KV, o seu modelo recalcula chaves e valores para cada token, repetindo trabalho. - com cache KV → 9 segundos - sem cache KV → 42 segundos (~5x mais lento) Vamos mergulhar para entender como funciona!
Para entender o caching KV, devemos saber como os LLMs produzem tokens. - O Transformer produz estados ocultos para todos os tokens. - Os estados ocultos são projetados para o espaço de vocabulário. - Os logits do último token são usados para gerar o próximo token. - Repetir para os tokens subsequentes. Verifique isto👇
Assim, para gerar um novo token, precisamos apenas do estado oculto do token mais recente. Nenhum dos outros estados ocultos é necessário. A seguir, vamos ver como o último estado oculto é calculado dentro da camada do transformador a partir do mecanismo de atenção.
Durante a atenção: A última linha do produto consulta-chave envolve: - o último vetor de consulta. - todos os vetores de chave. Além disso, a última linha do resultado final da atenção envolve: - o último vetor de consulta. - todos os vetores de chave e valor. Verifique esta visualização para entender melhor:
A visão acima sugere que, para gerar um novo token, cada operação de atenção na rede precisa apenas de: - vetor de consulta do último token. - todos os vetores de chave e valor. Mas, há mais uma visão chave aqui.
À medida que geramos novos tokens: - Os vetores KV usados para TODOS os tokens anteriores não mudam. Assim, só precisamos gerar um vetor KV para o token gerado um passo antes. O resto dos vetores KV pode ser recuperado de um cache para economizar computação e tempo.
Isto é chamado de caching KV! Para reiterar, em vez de calcular redundante os vetores KV de todos os tokens de contexto, armazene-os em cache. Para gerar um token: - Gere o vetor QKV para o token gerado um passo antes. - Obtenha todos os outros vetores KV do cache. - Calcule a atenção. Verifique isto👇
O caching KV acelera a inferência ao calcular o cache KV do prompt antes de gerar tokens. É exatamente por isso que o ChatGPT demora mais a gerar o primeiro token do que os restantes. Este atraso é conhecido como tempo-para-o-primeiro-token (TTFT). Melhorar o TTFT é um tópico para outro dia!
211