{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "1883540a",
"metadata": {},
"outputs": [],
"source": [
"# /// script\n",
"# dependencies = [\n",
"#   \"zstandard\",\n",
"#   \"brotli\",\n",
"#   \"matplotlib\",\n",
"#   \"tqdm\",\n",
"#   \"numpy\",\n",
"# ]\n",
"# ///\n",
"\n",
"import gc\n",
"import gzip\n",
"import time\n",
"from dataclasses import dataclass\n",
"from pathlib import Path\n",
"from typing import Callable\n",
"\n",
"import brotli\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from tqdm.auto import tqdm\n",
"from zstandard import ZstdCompressor, ZstdDecompressor\n",
"\n",
"# What should we benchmark?\n",
"FOLDER = Path(\".\")  # every regular file below this folder (recursively) is benchmarked\n",
"ROUNDS = 5  # timing rounds per (algorithm, level); the median round is reported\n",
"ALGS = {\n",
"    \"zstd\": {\n",
"        \"levels\": [1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 16, 22],\n",
"        \"c\": lambda b, q: ZstdCompressor(level=q).compress(b),\n",
"        \"d\": lambda b: ZstdDecompressor().decompress(b),\n",
"    },\n",
"    \"brotli\": {\n",
"        \"levels\": [1, 2, 3, 4, 5, 7, 9, 11],\n",
"        \"c\": lambda b, q: brotli.compress(b, quality=q),\n",
"        \"d\": lambda b: brotli.decompress(b),\n",
"    },\n",
"    \"gzip\": {\n",
"        \"levels\": [1, 3, 4, 5, 9],\n",
"        # mtime=0 pins the timestamp gzip writes into its header, so every round\n",
"        # produces identical bytes without monkeypatching time.time() globally.\n",
"        \"c\": lambda b, q: gzip.compress(b, compresslevel=q, mtime=0),\n",
"        \"d\": lambda b: gzip.decompress(b),\n",
"    },\n",
"}\n",
"\n",
"# In case you were wondering:\n",
"# sharing one ZstdCompressor across operations, or using threads, makes little\n",
"# difference (with small files).\n",
"\n",
"# Sorted so the corpus is always processed in the same order.\n",
"files = [p.read_bytes() for p in sorted(FOLDER.rglob(\"*\")) if p.is_file()]\n",
"assert files, \"No files found in the specified folder\"\n",
"orig_total = sum(len(b) for b in files)\n",
"print(f\"Loaded {len(files)} files, total {orig_total / 1_048_576:.2f} MiB\")\n",
"\n",
"\n",
"@dataclass\n",
"class BenchPoint:\n",
"    \"\"\"One (algorithm, level) pair and the results measured for it.\"\"\"\n",
"\n",
"    name: str  # key into ALGS, e.g. \"zstd\"\n",
"    lvl: int  # level/quality passed as the second argument of `compress`\n",
"    compress: Callable[[bytes, int], bytes]  # (data, level) -> compressed bytes\n",
"    decompress: Callable[[bytes], bytes]  # compressed bytes -> original bytes\n",
"    compdata: list[bytes]  # compressed output per file, in the same order as `files`\n",
"    comp_time_ns: int = 0  # median over rounds of the summed compression time of all files (ns)\n",
"    decomp_time_ns: int = 0  # median over rounds of the summed decompression time of all files (ns)\n",
"\n",
"\n",
"bench_points = [\n",
"    BenchPoint(alg, lvl, spec[\"c\"], spec[\"d\"], [])\n",
"    for alg, spec in ALGS.items()\n",
"    for lvl in spec[\"levels\"]\n",
"]\n",
"P = len(bench_points)\n",
"# Per-round totals over all files: one row per bench point, one column per round.\n",
"C_ns = np.zeros((P, ROUNDS), dtype=np.int64)  # compression time (ns)\n",
"D_ns = np.zeros((P, ROUNDS), dtype=np.int64)  # decompression time (ns)\n",
"\n",
"# Benchmark\n",
"# GC is paused so collections don't land inside timed regions; try/finally switches it\n",
"# back on even if an assertion fails or the run is interrupted.\n",
"gc.disable()\n",
"try:\n",
"    progress = tqdm(range(ROUNDS), desc=\"Rounds\")\n",
"    for r in progress:\n",
"        for i, bp in enumerate(bench_points):\n",
"            progress.set_postfix_str(f\"Compress {bp.name}/{bp.lvl}\")\n",
"            for f, data in enumerate(files):\n",
"                t0 = time.perf_counter_ns()\n",
"                out = bp.compress(data, bp.lvl)\n",
"                t1 = time.perf_counter_ns()\n",
"                C_ns[i, r] += t1 - t0\n",
"                if r == 0:\n",
"                    bp.compdata.append(out)\n",
"                else:\n",
"                    assert out == bp.compdata[f], (\n",
"                        f\"Compressed data changed between rounds for {bp.name}, level={bp.lvl}\"\n",
"                    )\n",
"            progress.set_postfix_str(f\"Decompress {bp.name}/{bp.lvl}\")\n",
"            for f, comp in enumerate(bp.compdata):\n",
"                t0 = time.perf_counter_ns()\n",
"                restored = bp.decompress(comp)\n",
"                t1 = time.perf_counter_ns()\n",
"                D_ns[i, r] += t1 - t0\n",
"                # Verify the round trip once, outside the timed region.\n",
"                if r == 0:\n",
"                    assert restored == files[f], (\n",
"                        f\"Round trip mismatch for {bp.name}, level={bp.lvl}\"\n",
"                    )\n",
"        progress.set_postfix_str(\"Benchmark Done\")\n",
"finally:\n",
"    gc.enable()\n",
"\n",
"# Store median times directly in bench points\n",
"Cns, Dns = np.median(C_ns, axis=1), np.median(D_ns, axis=1)\n",
"for i, bp in enumerate(bench_points):\n",
"    bp.comp_time_ns = int(Cns[i])\n",
"    bp.decomp_time_ns = int(Dns[i])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12a0cf49",
"metadata": {},
"outputs": [],
"source": [
"def plot_xy(bench_points, time_attr, title, xlabel, ylabel, ylabel2):\n",
"    \"\"\"Plot throughput against compressed size, one line per algorithm in ALGS.\n",
"\n",
"    Every marker is labelled with its level. The y axis is throughput in Mbit/s of\n",
"    uncompressed data (`orig_total` bytes), so faster settings sit higher. Along the\n",
"    right edge the same throughput is restated as wall time for the whole corpus:\n",
"    solid black guides start at measured points (thinned so their labels don't\n",
"    collide) and dashed grey guides mark round reference times.\n",
"\n",
"    Args:\n",
"        bench_points: BenchPoint objects whose `compdata` and timing fields are filled.\n",
"        time_attr: name of the BenchPoint attribute holding the time in ns,\n",
"            e.g. \"comp_time_ns\" or \"decomp_time_ns\".\n",
"        title: axes title.\n",
"        xlabel: x-axis label.\n",
"        ylabel: y-axis (throughput) label.\n",
"        ylabel2: heading placed above the column of time labels on the right edge.\n",
"\n",
"    Reads the notebook globals ALGS (series order) and orig_total (corpus size, bytes).\n",
"    \"\"\"\n",
"    fig, ax1 = plt.subplots(figsize=(9, 5), dpi=300)\n",
"\n",
"    def to_ms(mbps):\n",
"        # Wall time (ms) to process orig_total bytes at `mbps` Mbit/s.\n",
"        return orig_total * 8000.0 / (mbps * 1e6)\n",
"\n",
"    def to_mbps(ms):\n",
"        # Inverse of to_ms: the throughput (Mbit/s) that takes `ms` ms for orig_total bytes.\n",
"        return orig_total * 8000.0 / (ms * 1e6)\n",
"\n",
"    def format_time_label(ms):\n",
"        # One decimal below 10 ms, whole ms below 1 s, seconds with one decimal above.\n",
"        if ms < 10:\n",
"            return f\"{ms:.1f} ms\"\n",
"        if ms < 1000:\n",
"            return f\"{ms:.0f} ms\"\n",
"        return f\"{ms / 1000:.1f} s\"\n",
"\n",
"    level_style = {\n",
"        \"ha\": \"center\",\n",
"        \"va\": \"center\",\n",
"        \"fontsize\": 7,\n",
"        \"zorder\": 3,\n",
"        \"bbox\": dict(boxstyle=\"round,pad=0.2\", facecolor=\"white\", edgecolor=\"black\", lw=0.4),\n",
"    }\n",
"\n",
"    # (x, y) of every plotted marker; used to place the right-edge guides.\n",
"    all_data_points = []\n",
"    for alg_name in ALGS:\n",
"        alg_points = [bp for bp in bench_points if bp.name == alg_name]\n",
"        if not alg_points:\n",
"            continue  # nothing measured for this algorithm; keep it out of the legend\n",
"        comp_totals = np.array([sum(len(b) for b in bp.compdata) for bp in alg_points])\n",
"        times_ns = np.array([getattr(bp, time_attr) for bp in alg_points])\n",
"        x_pct = 100.0 * comp_totals / orig_total  # compressed size, % of original\n",
"        y_mbps = orig_total * 8000.0 / times_ns  # bits/ns is Gbit/s, x1000 gives Mbit/s\n",
"        all_data_points.extend(zip(x_pct, y_mbps))\n",
"\n",
"        # Plot using Mbit/s (so faster operations are at the top)\n",
"        ax1.plot(x_pct, y_mbps, marker=\"s\", linewidth=1.25, label=alg_name, zorder=2)\n",
"        for xx, yy, bp in zip(x_pct, y_mbps, alg_points):\n",
"            ax1.text(xx, yy, str(bp.lvl), **level_style)\n",
"\n",
"    ax1.set_xlabel(xlabel)\n",
"    ax1.set_ylabel(ylabel)\n",
"\n",
"    # Freeze the autoscaled view before adding guides. Left on autoscale, the guide\n",
"    # lines (which end at x_max) would widen the x limits again by the default margin,\n",
"    # so they would stop short of the frame and the right-edge labels would land\n",
"    # inside the plot; reference lines near the edges could also stretch the y range.\n",
"    x_min, x_max = ax1.get_xlim()\n",
"    y_min, y_max = ax1.get_ylim()\n",
"    ax1.set_xlim(x_min, x_max)\n",
"    ax1.set_ylim(y_min, y_max)\n",
"\n",
"    # Thin the measured throughputs so guide labels don't collide: walk them from\n",
"    # slowest to fastest and keep one only if it is at least 5% of the y range above\n",
"    # the last kept value (the nearest kept one, since the walk is ascending).\n",
"    min_spacing = (y_max - y_min) * 0.05\n",
"    selected_y_values = []\n",
"    for y in sorted({yy for _, yy in all_data_points}):\n",
"        if not y_min <= y <= y_max:\n",
"            continue\n",
"        if not selected_y_values or y - selected_y_values[-1] >= min_spacing:\n",
"            selected_y_values.append(y)\n",
"\n",
"    def label_right_edge(y, ms, color, alpha, box_alpha):\n",
"        # Time label just right of the axes frame, vertically centred on y.\n",
"        ax1.text(\n",
"            x_max,\n",
"            y,\n",
"            format_time_label(ms),\n",
"            verticalalignment=\"center\",\n",
"            horizontalalignment=\"left\",\n",
"            fontsize=7,\n",
"            color=color,\n",
"            alpha=alpha,\n",
"            bbox=dict(boxstyle=\"round,pad=0.1\", facecolor=\"white\", alpha=box_alpha, edgecolor=\"none\"),\n",
"        )\n",
"\n",
"    # Solid black guide from the right-most marker at each selected throughput to the edge.\n",
"    for y in selected_y_values:\n",
"        rightmost_x = max(\n",
"            (xx for xx, yy in all_data_points if abs(yy - y) < 1e-6),\n",
"            default=x_min,\n",
"        )\n",
"        ax1.plot(\n",
"            [rightmost_x, x_max],\n",
"            [y, y],\n",
"            color=\"black\",\n",
"            linestyle=\"-\",\n",
"            alpha=0.3,\n",
"            linewidth=0.8,\n",
"            zorder=1,\n",
"        )\n",
"        label_right_edge(y, to_ms(y), \"black\", 0.8, 0.9)\n",
"\n",
"    # Dashed grey guides at round reference times, skipped where they would crowd a black one.\n",
"    for ms in (0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000):\n",
"        mbps = to_mbps(ms)\n",
"        if not y_min <= mbps <= y_max:\n",
"            continue\n",
"        if any(abs(mbps - y) < min_spacing for y in selected_y_values):\n",
"            continue\n",
"        ax1.axhline(y=mbps, color=\"gray\", linestyle=\"--\", alpha=0.3, linewidth=0.5, zorder=1)\n",
"        label_right_edge(mbps, ms, \"gray\", 0.6, 0.8)\n",
"\n",
"    # Heading for the column of time labels along the right edge.\n",
"    ax1.annotate(\n",
"        ylabel2,\n",
"        xy=(1.0, 1.0),\n",
"        xycoords=\"axes fraction\",\n",
"        xytext=(0, 4),\n",
"        textcoords=\"offset points\",\n",
"        ha=\"left\",\n",
"        va=\"bottom\",\n",
"        fontsize=7,\n",
"    )\n",
"\n",
"    ax1.set_title(title)\n",
"    ax1.grid(True, linestyle=\":\", linewidth=0.7, alpha=0.6, zorder=0)\n",
"    ax1.legend()\n",
"    fig.tight_layout()\n",
"    plt.show()\n",
"\n",
"\n",
"# Compression figure\n",
"plot_xy(\n",
"    bench_points=bench_points,\n",
"    time_attr=\"comp_time_ns\",\n",
"    title=f\"Compression — {len(files)} files, total {orig_total / 1_048_576:.2f} MiB (median of {ROUNDS} rounds)\",\n",
"    xlabel=\"Compressed size (% of original)\",\n",
"    ylabel=\"Compression speed (Mbit/s)\",\n",
"    ylabel2=\"Time (ms)\",\n",
")\n",
"# Decompression figure (same x axis and time labels, different timing attribute)\n",
"plot_xy(\n",
"    bench_points=bench_points,\n",
"    time_attr=\"decomp_time_ns\",\n",
"    title=f\"Decompression — {len(files)} files, total {orig_total / 1_048_576:.2f} MiB (median of {ROUNDS} rounds)\",\n",
"    xlabel=\"Compressed size (% of original)\",\n",
"    ylabel=\"Decompression speed (Mbit/s)\",\n",
"    ylabel2=\"Time (ms)\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8a2bdcd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}