"""Benchmark zstd, brotli and gzip over a folder of files and plot the results.

Every algorithm/level pair in ``ALGS`` compresses and decompresses every file
under the given folder for several rounds. The median total time per pair is
plotted against the compressed size as two PNGs (compression, decompression).
"""

import argparse
import gc
import gzip
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable

import brotli
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
from zstandard import ZstdCompressor, ZstdDecompressor

# Algorithms under test. For each one:
#   "levels": quality levels to sweep,
#   "c":      compress(data: bytes, level: int) -> bytes,
#   "d":      decompress(data: bytes) -> bytes.
ALGS = {
    "zstd": {
        "levels": [1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 16, 22],
        "c": lambda b, q: ZstdCompressor(level=q).compress(b),
        "d": lambda b: ZstdDecompressor().decompress(b),
    },
    "brotli": {
        "levels": [1, 2, 3, 4, 5, 7, 9, 11],
        "c": lambda b, q: brotli.compress(b, quality=q),
        "d": lambda b: brotli.decompress(b),
    },
    "gzip": {
        "levels": [1, 3, 4, 5, 9],
        # mtime=0 fixes the timestamp written into the gzip header, so the
        # output is byte-identical across rounds (run_benchmark checks this).
        "c": lambda b, q: gzip.compress(b, compresslevel=q, mtime=0),
        "d": lambda b: gzip.decompress(b),
    },
}


@dataclass
class BenchPoint:
    """One algorithm/level combination and its measured results."""

    name: str  # key into ALGS
    lvl: int  # level/quality passed to ``compress``
    compress: Callable  # compress(data, lvl) -> bytes
    decompress: Callable  # decompress(data) -> bytes
    compdata: list  # compressed bytes per input file, filled during round 0
    comp_time_ns: int = 0  # median (over rounds) of total compression time
    decomp_time_ns: int = 0  # median (over rounds) of total decompression time


def _format_time_label(ms):
    """Format a duration in milliseconds as a short label ("3.2 ms", "1.5 s")."""
    if ms < 10:
        return f"{ms:.1f} ms"
    if ms < 1000:
        return f"{ms:.0f} ms"
    return f"{ms / 1000:.1f} s"


def _pick_spaced_values(sorted_values, min_spacing):
    """Greedily keep values that are at least ``min_spacing`` from every kept one."""
    kept = []
    for value in sorted_values:
        if all(abs(value - other) >= min_spacing for other in kept):
            kept.append(value)
    return kept


def plot_xy(
    bench_points, time_attr, title, xlabel, ylabel, ylabel2, orig_total, out_path
):
    """Plot compressed size vs. throughput for every algorithm and save a PNG.

    x is the total compressed size as a percentage of ``orig_total``. y is the
    throughput in Mbit/s derived from the BenchPoint attribute ``time_attr``.
    Each point is labelled with its level. Total-time annotations (ms/s) are
    drawn just outside the right edge of the axes.

    Args:
        bench_points: BenchPoint results, e.g. from ``run_benchmark``.
        time_attr: Name of the BenchPoint attribute holding the time in ns
            (``"comp_time_ns"`` or ``"decomp_time_ns"``).
        title: Figure title.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        ylabel2: Heading shown above the right-edge time annotations.
        orig_total: Total uncompressed size in bytes.
        out_path: Destination path of the PNG.
    """
    fig, ax1 = plt.subplots(figsize=(9, 5), dpi=300)

    level_style = {
        "ha": "center",
        "va": "center",
        "fontsize": 7,
        "zorder": 3,
        "bbox": dict(
            boxstyle="round,pad=0.2",
            facecolor="white",
            edgecolor="black",
            lw=0.4,
        ),
    }

    all_data_points = []  # (x, y) of every plotted point, across algorithms
    for alg_name in ALGS:
        alg_points = [bp for bp in bench_points if bp.name == alg_name]
        comp_totals = np.array([sum(len(b) for b in bp.compdata) for bp in alg_points])
        times_ns = np.array([getattr(bp, time_attr) for bp in alg_points])
        levels = [bp.lvl for bp in alg_points]

        x = 100.0 * comp_totals / orig_total
        # bytes * 8 bit / (ns * 1e-9 s) / 1e6 == bytes * 8000 / ns  [Mbit/s]
        y_mbps = orig_total * 8000.0 / times_ns
        all_data_points.extend(zip(x, y_mbps))

        ax1.plot(x, y_mbps, marker="s", linewidth=1.25, label=alg_name, zorder=2)
        for xx, yy, lvl in zip(x, y_mbps, levels):
            ax1.text(xx, yy, str(lvl), **level_style)

    ax1.set_xlabel(xlabel)
    ax1.set_ylabel(ylabel)

    y_min, y_max = ax1.get_ylim()
    x_min, x_max = ax1.get_xlim()
    # Freeze the limits. The guide lines below extend to x_max and would
    # otherwise re-trigger autoscaling (with margins), moving the right edge
    # outward and leaving the time labels at x_max inside the plot.
    ax1.set_xlim(x_min, x_max)
    ax1.set_ylim(y_min, y_max)

    min_spacing = (y_max - y_min) * 0.05
    visible_y = sorted({yy for _, yy in all_data_points if y_min <= yy <= y_max})
    selected_y_values = _pick_spaced_values(visible_y, min_spacing)

    # Solid guides: one per (well-separated) measured speed, labelled with
    # the corresponding total time.
    for y_sel in selected_y_values:
        # Start the guide at the right-most point with this exact speed so it
        # does not cross that point's marker.
        rightmost_x = max(
            (px for px, py in all_data_points if abs(py - y_sel) < 1e-6),
            default=x_min,
        )
        ms = orig_total * 8000.0 / (y_sel * 1e6)
        ax1.plot(
            [rightmost_x, x_max],
            [y_sel, y_sel],
            color="black",
            linestyle="-",
            alpha=0.3,
            linewidth=0.8,
            zorder=1,
        )
        ax1.text(
            x_max,
            y_sel,
            _format_time_label(ms),
            verticalalignment="center",
            horizontalalignment="left",
            fontsize=7,
            color="black",
            alpha=0.8,
            bbox=dict(
                boxstyle="round,pad=0.1", facecolor="white", alpha=0.9, edgecolor="none"
            ),
        )

    # Dashed guides at round total times, skipped where they would crowd a
    # measured-speed label.
    ms_reference_values = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]
    for ms in ms_reference_values:
        mbps = orig_total * 8000.0 / (ms * 1e6)
        if not y_min <= mbps <= y_max:
            continue
        if any(abs(mbps - y_sel) < min_spacing for y_sel in selected_y_values):
            continue
        ax1.axhline(
            y=mbps,
            color="gray",
            linestyle="--",
            alpha=0.3,
            linewidth=0.5,
            zorder=1,
        )
        ax1.text(
            x_max,
            mbps,
            _format_time_label(ms),
            verticalalignment="center",
            horizontalalignment="left",
            fontsize=7,
            color="gray",
            alpha=0.6,
            bbox=dict(
                boxstyle="round,pad=0.1",
                facecolor="white",
                alpha=0.8,
                edgecolor="none",
            ),
        )

    # Heading for the time annotations, just above the top-right corner.
    ax1.text(
        1.0,
        1.0,
        ylabel2,
        transform=ax1.transAxes,
        horizontalalignment="left",
        verticalalignment="bottom",
        fontsize=7,
    )

    ax1.set_title(title)
    ax1.grid(True, linestyle=":", linewidth=0.7, alpha=0.6, zorder=0)
    ax1.legend()
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def run_benchmark(folder: Path, rounds: int):
    """Time compression and decompression of every file for every ALGS level.

    In each round, every BenchPoint first compresses all files, accumulating
    the time. It then decompresses its round-0 outputs, also accumulating the
    time. The median across rounds is stored in ``comp_time_ns`` and
    ``decomp_time_ns``. The garbage collector is disabled while timing.

    Args:
        folder: Directory whose files (searched recursively) are benchmarked.
        rounds: Number of timing rounds; must be >= 1.

    Returns:
        ``(bench_points, files, orig_total)``: the BenchPoints with results,
        the raw file contents, and their total size in bytes.

    Raises:
        ValueError: If ``rounds`` < 1, or ``folder`` has no non-empty input.
        RuntimeError: If an algorithm's output changes between rounds or a
            file does not survive the compress/decompress round trip.
    """
    if rounds < 1:
        raise ValueError(f"rounds must be >= 1, got {rounds}")

    files = [p.read_bytes() for p in folder.rglob("*") if p.is_file()]
    if not files:
        raise ValueError("No files found in the specified folder")
    orig_total = sum(len(b) for b in files)
    print(f"Loaded {len(files)} files, total {orig_total / 1e6:.2f} MB")
    if orig_total == 0:
        # Sizes and speeds are computed relative to orig_total.
        raise ValueError("All files in the specified folder are empty")

    bench_points = [
        BenchPoint(alg, lvl, spec["c"], spec["d"], [])
        for alg, spec in ALGS.items()
        for lvl in spec["levels"]
    ]
    n_points = len(bench_points)
    comp_ns = np.zeros((n_points, rounds), dtype=np.int64)
    decomp_ns = np.zeros((n_points, rounds), dtype=np.int64)

    # Keep GC pauses out of the timed regions, and always re-enable it even
    # if a check below raises.
    gc.disable()
    try:
        # Drive the bar manually: an iterator-driven tqdm closes itself when
        # exhausted, so a final postfix set after the loop would never show.
        with tqdm(total=rounds, desc="Rounds") as progress:
            for r in range(rounds):
                for i, bp in enumerate(bench_points):
                    progress.set_postfix_str(f"Compress {bp.name}/{bp.lvl}")
                    for f, data in enumerate(files):
                        t0 = time.perf_counter_ns()
                        out = bp.compress(data, bp.lvl)
                        t1 = time.perf_counter_ns()
                        comp_ns[i, r] += t1 - t0
                        if r == 0:
                            bp.compdata.append(out)
                        elif out != bp.compdata[f]:
                            raise RuntimeError(
                                f"Compressed data changed between rounds for {bp.name}, level={bp.lvl}"
                            )

                    progress.set_postfix_str(f"Decompress {bp.name}/{bp.lvl}")
                    for f, blob in enumerate(bp.compdata):
                        t0 = time.perf_counter_ns()
                        restored = bp.decompress(blob)
                        t1 = time.perf_counter_ns()
                        decomp_ns[i, r] += t1 - t0
                        # Verify the round trip once, outside the timed region.
                        if r == 0 and restored != files[f]:
                            raise RuntimeError(
                                f"Round-trip mismatch for {bp.name}, level={bp.lvl}"
                            )
                progress.update()
            progress.set_postfix_str("Benchmark Done")
    finally:
        gc.enable()

    comp_median = np.median(comp_ns, axis=1)
    decomp_median = np.median(decomp_ns, axis=1)
    for bp, c_med, d_med in zip(bench_points, comp_median, decomp_median):
        bp.comp_time_ns = int(c_med)
        bp.decomp_time_ns = int(d_med)

    return bench_points, files, orig_total


def main():
    """Parse CLI arguments, run the benchmark and write both plots."""
    parser = argparse.ArgumentParser(description="Compression benchmark.")
    parser.add_argument("folder", type=str, help="Folder to benchmark")
    parser.add_argument("--rounds", type=int, default=5, help="Number of rounds")
    parser.add_argument(
        "--outdir", type=str, default=".", help="Output directory for PNGs"
    )
    args = parser.parse_args()

    folder = Path(args.folder)
    rounds = args.rounds
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    bench_points, files, orig_total = run_benchmark(folder, rounds)

    plot_xy(
        bench_points,
        "comp_time_ns",
        f"Compression — {len(files)} files, total {orig_total / 1_048_576:.2f} MiB (median of {rounds} rounds)",
        "Compressed size (% of original)",
        "Compression speed (Mbit/s)",
        "Time (ms)",
        orig_total,
        outdir / "compression.png",
    )
    plot_xy(
        bench_points,
        "decomp_time_ns",
        f"Decompression — {len(files)} files, total {orig_total / 1_048_576:.2f} MiB (median of {rounds} rounds)",
        "Compressed size (% of original)",
        "Decompression speed (Mbit/s)",
        "Time (ms)",
        orig_total,
        outdir / "decompression.png",
    )
    print(f"Saved plots to {outdir}")