{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "cab91cfc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/Power-RAG/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import copy\n", "import dataclasses\n", "import os\n", "import time\n", "import pathlib\n", "import itertools\n", "import multiprocessing\n", "import scipy\n", "import numpy as np\n", "import pandas as pd\n", "import pickle\n", "import gzip\n", "import threading\n", "import queue\n", "import pytz\n", "import traceback\n", "from datetime import datetime\n", "from tqdm.auto import tqdm, trange\n", "from typing import Any\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib.ticker as mtick\n", "%matplotlib inline\n", "%config InlineBackend.figure_format='retina'" ] }, { "cell_type": "code", "execution_count": 2, "id": "8d24fbd7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sat Apr 12 00:10:05 2025 \n", "+-----------------------------------------------------------------------------------------+\n", "| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |\n", "|-----------------------------------------+------------------------+----------------------+\n", "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n", "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n", "| | | MIG M. |\n", "|=========================================+========================+======================|\n", "| 0 NVIDIA A10G Off | 00000000:00:1E.0 Off | 0 |\n", "| 0% 27C P8 15W / 300W | 4MiB / 23028MiB | 0% Default |\n", "| | | N/A |\n", "+-----------------------------------------+------------------------+----------------------+\n", " \n", "+-----------------------------------------------------------------------------------------+\n", "| Processes: |\n", "| GPU GI CI PID Type Process name GPU Memory |\n", "| ID ID Usage |\n", "|=========================================================================================|\n", "| No running processes found |\n", "+-----------------------------------------------------------------------------------------+\n" ] } ], "source": [ "!nvidia-smi" ] }, { "cell_type": "code", "execution_count": 3, "id": "538b2c11", "metadata": {}, "outputs": [], "source": [ "def benchmark(f, *, f_setup=None, min_repeat: int, min_secs: float, tqdm_kwargs: dict | None=None) -> np.ndarray:\n", " latency = []\n", " \n", " # First run, ignore min_secs\n", " if f_setup is not None:\n", " f_setup()\n", " st = time.perf_counter_ns()\n", " f()\n", " ed = time.perf_counter_ns()\n", " latency.append((ed-st)/1e9)\n", " \n", " # Subsequent runs, until reaching both min_repeat and min_secs\n", " min_nanos = int(min_secs * 1e9)\n", " start_nanos = time.perf_counter_ns()\n", " while True:\n", " now_nanos = time.perf_counter_ns()\n", " if len(latency) > min_repeat and now_nanos - start_nanos > min_nanos:\n", " break\n", " if f_setup is not None:\n", " f_setup()\n", " st = time.perf_counter_ns()\n", " f()\n", " ed = time.perf_counter_ns()\n", " latency.append((ed-st)/1e9)\n", " return np.array(latency)\n", "\n", "def tail_mean(xs, skip=0.2):\n", " return xs[int(len(xs) * skip):].mean()" ] }, { "cell_type": "code", "execution_count": 4, "id": "02c9c9b1", "metadata": {}, "outputs": [ { 
"data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import torch\n", "torch.set_grad_enabled(False)" ] }, { "cell_type": "code", "execution_count": 5, "id": "3405fdc7", "metadata": {}, "outputs": [], "source": [ "nd_list = list(itertools.chain(itertools.product([12, 3], [256])))\n", "seqlen_list = [256]\n", "bs_list = [2,4,8,16,32,64,128,256,512,1024,2048]" ] }, { "cell_type": "code", "execution_count": 6, "id": "10dc981a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[(12, 256), (3, 256)]\n", "[256]\n", "[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n" ] } ], "source": [ "print(nd_list)\n", "print(seqlen_list)\n", "print(bs_list)" ] }, { "cell_type": "code", "execution_count": 7, "id": "7e0ee385", "metadata": {}, "outputs": [], "source": [ "def benchmark_dense(out, nd_list, seqlen_list, bs_list):\n", " seqlen_list = [1] + seqlen_list\n", " total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n", " pbar = tqdm(total=total)\n", " for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n", " h = n * d\n", " maxbs = max(bs_list)\n", " print(maxbs, n, d, seqlen)\n", " cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n", " X = torch.rand((maxbs, seqlen, h), dtype=torch.bfloat16, device=\"cuda:0\")\n", " W = torch.rand((h, h), dtype=torch.bfloat16, device=\"cuda:0\")\n", " torch.cuda.synchronize()\n", " for bs in reversed(bs_list):\n", " pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n", " def run():\n", " torch.matmul(X[:bs], W)\n", " torch.cuda.synchronize()\n", " def clear_cache():\n", " cache.zero_()\n", " torch.cuda.synchronize()\n", " latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n", " l = tail_mean(latency)\n", " out.append({\n", " \"n\": n,\n", " \"d\": d,\n", " \"seqlen\": seqlen,\n", " \"bs\": bs,\n", " \"latency\": l\n", " })\n", " pbar.update()\n", " del cache, X, W\n", " torch.cuda.empty_cache()\n", " pbar.close()" ] }, { "cell_type": "code", "execution_count": 8, "id": "c206a502", "metadata": {}, "outputs": [], "source": [ "def benchmark_qk_init(out, nd_list, seqlen_list, bs_list):\n", " total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n", " pbar = tqdm(total=total)\n", " for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n", " h = n * d\n", " try:\n", " maxbs = max(b for b in bs_list if b*n*seqlen*d*2*2+b*n*seqlen**2*2 < 80e9)\n", " except ValueError:\n", " pbar.update(len(bs_list))\n", " continue\n", " cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n", " Qmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n", " Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n", " torch.cuda.synchronize()\n", " for bs in reversed(bs_list):\n", " pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n", " if bs > maxbs:\n", " pbar.update()\n", " continue\n", " Q = Qmax[:bs]\n", " K = Kmax[:bs]\n", " def run():\n", " torch.bmm(Q.view(bs * n, seqlen, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n", " torch.cuda.synchronize()\n", " def clear_cache():\n", " cache.zero_()\n", " torch.cuda.synchronize()\n", " latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n", " l = tail_mean(latency)\n", " out.append({\n", " \"n\": n,\n", " \"d\": d,\n", " \"seqlen\": seqlen,\n", " \"bs\": bs,\n", " \"latency\": l\n", " })\n", " pbar.update()\n", " del 
{ "cell_type": "code", "execution_count": 9, "id": "a3a2103c", "metadata": {}, "outputs": [], "source": [ "def benchmark_qk_ar(out, nd_list, seqlen_list, bs_list):\n", "    total = len(list(itertools.product(nd_list, seqlen_list, bs_list)))\n", "    pbar = tqdm(total=total)\n", "    for (n, d), seqlen in reversed(list(itertools.product(nd_list, seqlen_list))):\n", "        h = n * d\n", "        try:\n", "            # Largest batch whose bf16 single-token Q, full K, and score tensors fit in the memory budget\n", "            maxbs = max(b for b in bs_list if b*n*(1+seqlen)*d*2 + b*n*seqlen*2 < 80e9)\n", "        except ValueError:\n", "            pbar.update(len(bs_list))\n", "            continue\n", "        # 256 MB scratch buffer; zeroing it evicts the GPU L2 cache between runs\n", "        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=\"cuda:0\")\n", "        Qmax = torch.rand((maxbs, n, 1, d), dtype=torch.bfloat16, device=\"cuda:0\")\n", "        Kmax = torch.rand((maxbs, n, seqlen, d), dtype=torch.bfloat16, device=\"cuda:0\")\n", "        torch.cuda.synchronize()\n", "        for bs in reversed(bs_list):\n", "            pbar.set_postfix(n=n, h=h, d=d, seqlen=seqlen, bs=bs)\n", "            if bs > maxbs:\n", "                pbar.update()\n", "                continue\n", "            Q = Qmax[:bs]\n", "            K = Kmax[:bs]\n", "            def run():\n", "                torch.bmm(Q.view(bs * n, 1, d), K.view(bs * n, seqlen, d).transpose(1, 2))\n", "                torch.cuda.synchronize()\n", "            def clear_cache():\n", "                cache.zero_()\n", "                torch.cuda.synchronize()\n", "            latency = benchmark(run, f_setup=clear_cache, min_repeat=20, min_secs=2)\n", "            l = tail_mean(latency)\n", "            out.append({\n", "                \"n\": n,\n", "                \"d\": d,\n", "                \"seqlen\": seqlen,\n", "                \"bs\": bs,\n", "                \"latency\": l\n", "            })\n", "            pbar.update()\n", "        del cache, Q, K, Qmax, Kmax\n", "        torch.cuda.empty_cache()\n", "    pbar.close()" ] },
{ "cell_type": "code", "execution_count": 10, "id": "3aaad98a", "metadata": {}, "outputs": [], "source": [ "data = {}" ] },
{ "cell_type": "code", "execution_count": 11, "id": "18137de3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/22 [00:00" ] } ], "source": [] },
{ "cell_type": "code", "execution_count": 39, "id": "9f2b4c6e", "metadata": {}, "outputs": [], "source": [ "pd\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "63b8a531", "metadata": {}, "outputs": [], "source": [ "import transformers" ] },
{ "cell_type": "code", "execution_count": null, "id": "af90eff1", "metadata": {}, "outputs": [], "source": [ "def _gen_opt_cfg(n_layers: int, d_model: int, n_heads: int, **kwargs) -> transformers.OPTConfig:\n", "    return transformers.OPTConfig(\n", "        num_hidden_layers=n_layers,\n", "        hidden_size=d_model,\n", "        ffn_dim=d_model * 4,\n", "        num_attention_heads=n_heads,\n", "        **kwargs\n", "    )\n", "optcfg = {\n", "    # https://arxiv.org/pdf/2205.01068.pdf Table 2.1\n", "    \"125m\": _gen_opt_cfg(12, 768, 12),\n", "    \"350m\": _gen_opt_cfg(24, 1024, 16),\n", "    \"760m\": _gen_opt_cfg(24, 1536, 16),\n", "    \"1.3b\": _gen_opt_cfg(24, 2048, 32),\n", "    \"2.7b\": _gen_opt_cfg(32, 2560, 32),\n", "    \"6.7b\": _gen_opt_cfg(32, 4096, 32),\n", "    \"13b\": _gen_opt_cfg(40, 5120, 40),\n", "    \"13b_1layer\": _gen_opt_cfg(1, 5120, 40),\n", "    \"30b\": _gen_opt_cfg(48, 7168, 56),\n", "    \"66b\": _gen_opt_cfg(64, 9216, 72),\n", "    \"175b\": _gen_opt_cfg(96, 12288, 96),\n", "    \"175b_1layer\": _gen_opt_cfg(1, 12288, 96),\n", "}" ] },
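{ "cell_type": "code", "execution_count": null, "id": "c9d0e1f2", "metadata": {}, "outputs": [], "source": [ "# Illustrative sanity check (added): approximate the decoder parameter count\n", "# implied by each config and compare it with the size in its name. The formula\n", "# is an assumption that counts only attention and MLP weight matrices,\n", "# ignoring embeddings, biases, and layer norms.\n", "def approx_params(cfg):\n", "    h = cfg.hidden_size\n", "    per_layer = 4 * h * h + 2 * h * cfg.ffn_dim  # Q/K/V/out projections + 2 MLP matrices\n", "    return cfg.num_hidden_layers * per_layer\n", "\n", "for name in [\"125m\", \"1.3b\", \"6.7b\", \"175b\"]:\n", "    print(name, f\"{approx_params(optcfg[name]) / 1e9:.2f}B\")" ] },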
{ "cell_type": "code", "execution_count": null, "id": "5b9ebbec", "metadata": {}, "outputs": [], "source": [ "def greedy_sample_one(model, input_ids, attention_mask=None, past_key_values=None):\n", "    \"\"\"Run one forward step, reusing the KV cache when provided.\"\"\"\n", "    bs, tgt_len = input_ids.shape\n", "    if past_key_values is not None:\n", "        _bs, _num_heads, src_len, _head_dims = past_key_values[0][0].shape\n", "        assert bs == _bs\n", "    else:\n", "        src_len = 0\n", "    if attention_mask is None:\n", "        attention_mask = torch.ones((bs, src_len + tgt_len), device=model.device)\n", "    ret = model(\n", "        input_ids=input_ids,\n", "        attention_mask=attention_mask,\n", "        past_key_values=past_key_values,\n", "        use_cache=True, output_hidden_states=False, return_dict=True,\n", "    )\n", "    return ret\n", "\n", "def time_greedy_generate(model, input_ids, new_tokens):\n", "    \"\"\"Greedy-decode new_tokens steps; returns per-token latencies in seconds.\"\"\"\n", "    ts = []\n", "    output = input_ids\n", "    past_key_values = None\n", "    # 256 MB scratch buffer; zeroing it evicts the GPU L2 cache between steps\n", "    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=model.device)\n", "    attention_mask = torch.ones(input_ids.shape, device=model.device)\n", "    for _ in range(new_tokens):\n", "        cache.zero_()\n", "        torch.cuda.synchronize()\n", "        st = time.perf_counter_ns()\n", "\n", "        ret = greedy_sample_one(model, input_ids, attention_mask, past_key_values)\n", "        input_ids = torch.argmax(ret.logits[:, -1, :], dim=-1)[:, None]\n", "        output = torch.cat([output, input_ids], dim=1)\n", "        past_key_values = ret.past_key_values\n", "        attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)\n", "\n", "        torch.cuda.synchronize()\n", "        ed = time.perf_counter_ns()\n", "        ts.append((ed - st) / 1e9)\n", "    return np.array(ts)" ] },
{ "cell_type": "code", "execution_count": null, "id": "fc92f940", "metadata": {}, "outputs": [], "source": [ "opt_config = optcfg[\"6.7b\"]\n", "\n", "# Random bf16 weights (no checkpoint load); sufficient for latency measurement\n", "torch.set_default_dtype(torch.bfloat16)\n", "with transformers.modeling_utils.no_init_weights():\n", "    model = transformers.models.opt.OPTForCausalLM(opt_config).to(\"cuda\")\n", "torch.set_default_dtype(torch.float32)" ] },
{ "cell_type": "code", "execution_count": null, "id": "c19fa396", "metadata": {}, "outputs": [], "source": [ "db = {}\n", "input_tokens = 200\n", "new_tokens = 500\n", "for bs in tqdm(list(itertools.chain(range(1, 8), range(8, 16, 2), [16]))):\n", "    x = torch.randint(1000, 10000, (bs, input_tokens), device=model.device)\n", "    stack = []\n", "    for _ in range(10):\n", "        l = time_greedy_generate(model, x, new_tokens=new_tokens)\n", "        stack.append(l)\n", "    db[bs] = np.median(np.stack(stack), axis=0)  # per-token median over 10 runs\n", "    del x\n", "    torch.cuda.empty_cache()\n", "del model\n", "torch.cuda.empty_cache()\n", "\n", "os.makedirs(\"data\", exist_ok=True)\n", "with gzip.open(\"data/20230516-e2e-text-generation-batch.pkl.gz\", \"wb\") as f:\n", "    pickle.dump(db, f)" ] }
], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }