2025-08-17 22:57:52 -06:00

330 lines
12 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "1883540a",
"metadata": {},
"outputs": [],
"source": [
"# /// script\n",
"# dependencies = [\n",
"# \"zstandard\",\n",
"# \"brotli\",\n",
"# \"matplotlib\",\n",
"# \"tqdm\",\n",
"# \"numpy\",\n",
"# ]\n",
"# ///\n",
"\n",
"import gc\n",
"import gzip\n",
"import time\n",
"from dataclasses import dataclass\n",
"from pathlib import Path\n",
"from typing import Callable\n",
"\n",
"import brotli\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from tqdm.auto import tqdm\n",
"from zstandard import ZstdCompressor, ZstdDecompressor\n",
"\n",
"# What should we benchmark?\n",
"FOLDER = Path(\".\")  # every regular file under this folder (recursively)\n",
"ROUNDS = 5  # repetitions per (algorithm, level); medians are reported\n",
"# Per-algorithm spec: the level sweep plus compress (\"c\": bytes, level -> bytes)\n",
"# and decompress (\"d\": bytes -> bytes) callables.\n",
"ALGS = {\n",
"    \"zstd\": {\n",
"        \"levels\": [1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 16, 22],\n",
"        \"c\": lambda b, q: ZstdCompressor(level=q).compress(b),\n",
"        \"d\": lambda b: ZstdDecompressor().decompress(b),\n",
"    },\n",
"    \"brotli\": {\n",
"        \"levels\": [1, 2, 3, 4, 5, 7, 9, 11],\n",
"        \"c\": lambda b, q: brotli.compress(b, quality=q),\n",
"        \"d\": lambda b: brotli.decompress(b),\n",
"    },\n",
"    \"gzip\": {\n",
"        \"levels\": [1, 3, 4, 5, 9],\n",
"        # mtime=0 pins the timestamp field of the gzip header, so output is\n",
"        # deterministic across rounds without monkeypatching time.time().\n",
"        \"c\": lambda b, q: gzip.compress(b, compresslevel=q, mtime=0),\n",
"        \"d\": lambda b: gzip.decompress(b),\n",
"    },\n",
"}\n",
"\n",
"# In case you were wondering:\n",
"# Sharing the ZstdCompressor across operations or using threads makes a little difference (with small files)\n",
"\n",
"files = [p.read_bytes() for p in FOLDER.rglob(\"*\") if p.is_file()]\n",
"orig_total = sum(len(b) for b in files)\n",
"print(f\"Loaded {len(files)} files, total {orig_total / 1e6:.2f} MB\")\n",
"assert files, \"No files found in the specified folder\"\n",
"\n",
"\n",
"@dataclass\n",
"class BenchPoint:\n",
" name: str\n",
" lvl: int\n",
" compress: Callable\n",
" decompress: Callable\n",
" compdata: list[bytes]\n",
" comp_time_ns: int = 0 # median compression time (ns)\n",
" decomp_time_ns: int = 0 # median decompression time (ns)\n",
"\n",
"\n",
"# One BenchPoint per (algorithm, level) pair, in ALGS declaration order.\n",
"bench_points = [\n",
"    BenchPoint(alg, lvl, spec[\"c\"], spec[\"d\"], [])\n",
"    for alg, spec in ALGS.items()\n",
"    for lvl in spec[\"levels\"]\n",
"]\n",
"P = len(bench_points)\n",
"C_ns = np.zeros((P, ROUNDS), dtype=np.int64)  # total compression time per round (ns)\n",
"D_ns = np.zeros((P, ROUNDS), dtype=np.int64)  # total decompression time per round (ns)\n",
"\n",
"# Benchmark\n",
"gc.disable()  # keep GC pauses out of the timed sections\n",
"_orig_time = time.time\n",
"time.time = lambda: 0  # make gzip deterministic (timestamp header)\n",
"try:\n",
"    progress = tqdm(range(ROUNDS), desc=\"Rounds\")\n",
"    for r in progress:\n",
"        for i, bp in enumerate(bench_points):\n",
"            progress.set_postfix_str(f\"Compress {bp.name}/{bp.lvl}\")\n",
"            for f, data in enumerate(files):\n",
"                t0 = time.perf_counter_ns()\n",
"                out = bp.compress(data, bp.lvl)\n",
"                t1 = time.perf_counter_ns()\n",
"                C_ns[i, r] += t1 - t0\n",
"                if r == 0:\n",
"                    bp.compdata.append(out)\n",
"                else:\n",
"                    assert out == bp.compdata[f], (\n",
"                        f\"Compressed data changed between rounds for {bp.name}, level={bp.lvl}\"\n",
"                    )\n",
"            progress.set_postfix_str(f\"Decompress {bp.name}/{bp.lvl}\")\n",
"            for f, blob in enumerate(bp.compdata):\n",
"                t0 = time.perf_counter_ns()\n",
"                restored = bp.decompress(blob)\n",
"                t1 = time.perf_counter_ns()\n",
"                D_ns[i, r] += t1 - t0\n",
"                if r == 0:\n",
"                    # Sanity check (outside the timed section): the round trip\n",
"                    # must reproduce the original file bytes.\n",
"                    assert restored == files[f], (\n",
"                        f\"Round-trip mismatch for {bp.name}, level={bp.lvl}\"\n",
"                    )\n",
"    progress.set_postfix_str(\"Benchmark Done\")\n",
"finally:\n",
"    # Always undo the global tweaks, even if a round raised (the original code\n",
"    # left GC disabled on error and never restored the patched clock).\n",
"    gc.enable()\n",
"    time.time = _orig_time\n",
"\n",
"# Store median times directly in bench points\n",
"Cns, Dns = np.median(C_ns, 1), np.median(D_ns, 1)\n",
"for i, bp in enumerate(bench_points):\n",
"    bp.comp_time_ns = int(Cns[i])\n",
"    bp.decomp_time_ns = int(Dns[i])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12a0cf49",
"metadata": {},
"outputs": [],
"source": [
"def plot_xy(bench_points, time_attr, title, xlabel, ylabel, ylabel2):\n",
" fig, ax1 = plt.subplots(figsize=(9, 5), dpi=300)\n",
"\n",
" # Collect all y values (Mbit/s) and their corresponding x values for overlap detection\n",
" all_data_points = []\n",
"\n",
" for alg_name in ALGS:\n",
" alg_points = [bp for bp in bench_points if bp.name == alg_name]\n",
" comp_totals = np.array([sum(len(b) for b in bp.compdata) for bp in alg_points])\n",
" times_ns = np.array([getattr(bp, time_attr) for bp in alg_points])\n",
" levels = [bp.lvl for bp in alg_points]\n",
"\n",
" x = 100.0 * comp_totals / orig_total # compression percentages\n",
" y_mbps = orig_total * 8000.0 / times_ns # mbps - this is what we plot\n",
"\n",
" # Store all data points for overlap detection\n",
" for xx, yy in zip(x, y_mbps):\n",
" all_data_points.append((xx, yy))\n",
"\n",
" # Plot using Mbit/s (so faster operations are at the top)\n",
" ax1.plot(x, y_mbps, marker=\"s\", linewidth=1.25, label=alg_name, zorder=2)\n",
" levelstyle = {\n",
" \"ha\": \"center\",\n",
" \"va\": \"center\",\n",
" \"fontsize\": 7,\n",
" \"zorder\": 3,\n",
" \"bbox\": dict(\n",
" boxstyle=\"round,pad=0.2\",\n",
" facecolor=\"white\",\n",
" edgecolor=\"black\",\n",
" lw=0.4,\n",
" ),\n",
" }\n",
" for xx, yy, L in zip(x, y_mbps, levels):\n",
" ax1.text(xx, yy, str(L), **levelstyle)\n",
"\n",
" # Set up the axis\n",
" ax1.set_xlabel(xlabel)\n",
" ax1.set_ylabel(ylabel) # Compression/Decompression speed (Mbit/s)\n",
"\n",
" # Get plot dimensions\n",
" y_min, y_max = ax1.get_ylim()\n",
" x_min, x_max = ax1.get_xlim()\n",
"\n",
" # Extract unique y values and sort them\n",
" unique_y_values = sorted(set(point[1] for point in all_data_points))\n",
"\n",
" # Filter to y values within the plot range\n",
" visible_y_values = [y for y in unique_y_values if y_min <= y <= y_max]\n",
"\n",
" # Determine minimum spacing to avoid overlap (as a percentage of the y-range)\n",
" y_range = y_max - y_min\n",
" min_spacing = y_range * 0.05 # 5% of the y-range as minimum spacing\n",
"\n",
" # Select y values with sufficient spacing\n",
" selected_y_values = []\n",
" for y in visible_y_values:\n",
" # Check if this y value is far enough from already selected ones\n",
" if not selected_y_values or all(\n",
" abs(y - selected) >= min_spacing for selected in selected_y_values\n",
" ):\n",
" selected_y_values.append(y)\n",
"\n",
" # Function to format time labels\n",
" def format_time_label(ms):\n",
" if ms < 10:\n",
" return f\"{ms:.1f} ms\" # One decimal for <10ms\n",
" elif ms < 1000:\n",
" return f\"{ms:.0f} ms\" # No decimal for <1000ms\n",
" else:\n",
" seconds = ms / 1000\n",
" return f\"{seconds:.1f} s\" # Seconds with one decimal above 1000ms\n",
"\n",
" # Draw black lines from data points to right margin with labels\n",
" for y_mbps in selected_y_values:\n",
" # Find the rightmost x position for this y value\n",
" rightmost_x = max(\n",
" [point[0] for point in all_data_points if abs(point[1] - y_mbps) < 1e-6],\n",
" default=x_min,\n",
" )\n",
"\n",
" # Convert Mbit/s to milliseconds for the label\n",
" ms = orig_total * 8000.0 / (y_mbps * 1e6)\n",
"\n",
" # Draw a black line from the rightmost data point to the right margin\n",
" ax1.plot(\n",
" [rightmost_x, x_max],\n",
" [y_mbps, y_mbps],\n",
" color=\"black\",\n",
" linestyle=\"-\",\n",
" alpha=0.3,\n",
" linewidth=0.8,\n",
" zorder=1,\n",
" )\n",
"\n",
" # Add text label on the right side with improved formatting\n",
" ms_label = format_time_label(ms)\n",
"\n",
" ax1.text(\n",
" x_max,\n",
" y_mbps,\n",
" ms_label,\n",
" verticalalignment=\"center\",\n",
" horizontalalignment=\"left\",\n",
" fontsize=7,\n",
" color=\"black\",\n",
" alpha=0.8,\n",
" bbox=dict(\n",
" boxstyle=\"round,pad=0.1\", facecolor=\"white\", alpha=0.9, edgecolor=\"none\"\n",
" ),\n",
" )\n",
"\n",
" # Add the original fixed millisecond reference lines (only if they don't conflict)\n",
" ms_reference_values = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]\n",
"\n",
" for ms in ms_reference_values:\n",
" # Convert milliseconds to Mbit/s: mbps = orig_total * 8000.0 / (ms * 1e6)\n",
" mbps = orig_total * 8000.0 / (ms * 1e6)\n",
"\n",
" # Only draw the line if it falls within the current y-axis range and isn't too close to data points\n",
" if y_min <= mbps <= y_max:\n",
" # Check if this reference line is far enough from data points\n",
" if not any(abs(mbps - y) < min_spacing for y in selected_y_values):\n",
" ax1.axhline(\n",
" y=mbps,\n",
" color=\"gray\",\n",
" linestyle=\"--\",\n",
" alpha=0.3,\n",
" linewidth=0.5,\n",
" zorder=1,\n",
" )\n",
" # Add text label on the right side with improved formatting\n",
" ms_label = format_time_label(ms)\n",
" ax1.text(\n",
" x_max,\n",
" mbps,\n",
" ms_label,\n",
" verticalalignment=\"center\",\n",
" horizontalalignment=\"left\",\n",
" fontsize=7,\n",
" color=\"gray\",\n",
" alpha=0.6,\n",
" bbox=dict(\n",
" boxstyle=\"round,pad=0.1\",\n",
" facecolor=\"white\",\n",
" alpha=0.8,\n",
" edgecolor=\"none\",\n",
" ),\n",
" )\n",
"\n",
" ax1.set_title(title)\n",
" ax1.grid(True, linestyle=\":\", linewidth=0.7, alpha=0.6, zorder=0)\n",
" ax1.legend()\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"\n",
"# Render both figures (compression and decompression) from the same corpus;\n",
"# only the timing attribute and the operation name in the labels differ.\n",
"for attr, op in ((\"comp_time_ns\", \"Compression\"), (\"decomp_time_ns\", \"Decompression\")):\n",
"    plot_xy(\n",
"        bench_points,\n",
"        attr,\n",
"        f\"{op} — {len(files)} files, total {orig_total / 1_048_576:.2f} MiB (median of {ROUNDS} rounds)\",\n",
"        \"Compressed size (% of original)\",\n",
"        f\"{op} speed (Mbit/s)\",\n",
"        \"Time (ms)\",\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8a2bdcd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}