Thursday, May 23, 2024

Implement KV Cache Quantization on Any AI Model - Simple Tutorial

This video is a simple tutorial to explain what is KV Cache Quantization and a hands-on demo to see how to implement it.



Diagram:


                                  +---------------+

                                  |  Token  Generator  |

                                  +---------------+

                                             |

                                             |

                                             v

                                  +---------------+

                                  |  KV Cache  |

                                  |  (stores previous  |

                                  |   calculations)  |

                                  +---------------+

                                             |

                                             |

                                             v

                                  +---------------+

                                  |  Token 1  |-------->  Cache Key: 1

                                  |  (input)  |-------->  Cache Value: Matrix Multiplication Results

                                  +---------------+

                                             |

                                             |

                                             v

                                  +---------------+

                                  |  Token 2  |-------->  Cache Key: 2

                                  |  (input)  |-------->  Cache Value: Matrix Multiplication Results

                                  +---------------+

                                             |

                                             |

                                             v

                                  ...

                                  +---------------+

                                  |  Token 999  |-------->  Cache Key: 999

                                  |  (input)  |-------->  Cache Value: Matrix Multiplication Results

                                  +---------------+

                                             |

                                             |

                                             v

                                  +---------------+

                                  |  Token 1000 |-------->  Cache Key: 1000

                                  |  (input)  |-------->  Cache Value: Matrix Multiplication Results

                                  +---------------+

                                             |

                                             |

                                             v

                                  +---------------+

                                  |  Token 1001 |-------->  Cache Key: 1001

                                  |  (input)  |-------->  Cache Value: Reuses cached results from

                                  |             |         Tokens 1-999, and new calculations

                                  |             |         from Token 1000

                                  +---------------+


Code:

!pip install -q git+https://github.com/huggingface/transformers
!pip install -q quanto bitsandbytes accelerate datasets

!huggingface-cli login

import torch

if not torch.cuda.is_available():
    logging.warning('GPU device not found. Go to Runtime > Change Runtime type and set Hardware accelerator to "GPU"')
    logging.warning('If you use CPU it will be very slow')
else:
    print(f"Cuda device found: {torch.cuda.get_device_name(0)}")
   
   
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and model

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")

# Set pad token for batched generation

tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer(["Hello, how are you?", "I am going to subscribe to this channel because"], padding=True, return_tensors="pt").to(model.device)

# Feel free to play with generation kwargs. See (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationConfig) for more

generation_kwargs = {"do_sample": False, "temperature": 1.0, "top_p": 1.0, "max_new_tokens": 20, "min_new_tokens": 20}

# Let's generate one with quantized cache, and another with original precision cache. Then check the quality of generations

out = model.generate(**inputs, cache_implementation="quantized", cache_config={"backend": "quanto", "nbits": 4, "q_group_size": 32, "residual_length": 64})

out_fp16 = model.generate(**inputs, **generation_kwargs)

print(f"text with quant cache: {tokenizer.batch_decode(out)}")
print(f"text with fp16 cache: {tokenizer.batch_decode(out_fp16)}")


No comments: