2025-08-17 22:57:52 -06:00

246 lines
8.1 KiB
Python

import argparse
import gc
import gzip
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
import brotli
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
from zstandard import ZstdCompressor, ZstdDecompressor
# Codecs under test: name -> {"levels": levels to benchmark,
#                             "c": compress(data: bytes, level: int) -> bytes,
#                             "d": decompress(data: bytes) -> bytes}.
# Codec objects are created inside the lambdas, so their construction cost is
# part of every timed call.
ALGS = {
    "zstd": {
        "levels": [1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 16, 22],
        "c": lambda b, q: ZstdCompressor(level=q).compress(b),
        "d": lambda b: ZstdDecompressor().decompress(b),
    },
    "brotli": {
        "levels": [1, 2, 3, 4, 5, 7, 9, 11],
        "c": lambda b, q: brotli.compress(b, quality=q),
        "d": lambda b: brotli.decompress(b),
    },
    "gzip": {
        "levels": [1, 3, 4, 5, 9],
        # mtime=0 pins the header timestamp. Without it gzip stamps the current
        # time into every member, so identical input would produce different
        # bytes from one run (or round) to the next.
        "c": lambda b, q: gzip.compress(b, compresslevel=q, mtime=0),
        "d": lambda b: gzip.decompress(b),
    },
}
@dataclass
class BenchPoint:
    """A single (algorithm, level) configuration plus its measured results.

    ``compdata`` starts empty and is meant to collect one compressed blob per
    input file. The two timing fields are in nanoseconds and default to 0
    until a benchmark run fills them in.
    """

    # Algorithm identifier (the benchmark uses the keys of ALGS).
    name: str
    # Level/quality value handed to ``compress``.
    lvl: int
    # Invoked as compress(data, lvl) and expected to return compressed bytes.
    compress: Callable[[bytes, int], bytes]
    # Invoked as decompress(blob) and expected to return the original bytes.
    decompress: Callable[[bytes], bytes]
    # Compressed output, one entry per input file.
    compdata: list
    # Compression / decompression time in nanoseconds.
    comp_time_ns: int = 0
    decomp_time_ns: int = 0
def plot_xy(
bench_points, time_attr, title, xlabel, ylabel, ylabel2, orig_total, out_path
):
fig, ax1 = plt.subplots(figsize=(9, 5), dpi=300)
all_data_points = []
for alg_name in ALGS:
alg_points = [bp for bp in bench_points if bp.name == alg_name]
comp_totals = np.array([sum(len(b) for b in bp.compdata) for bp in alg_points])
times_ns = np.array([getattr(bp, time_attr) for bp in alg_points])
levels = [bp.lvl for bp in alg_points]
x = 100.0 * comp_totals / orig_total
y_mbps = orig_total * 8000.0 / times_ns
for xx, yy in zip(x, y_mbps):
all_data_points.append((xx, yy))
ax1.plot(x, y_mbps, marker="s", linewidth=1.25, label=alg_name, zorder=2)
levelstyle = {
"ha": "center",
"va": "center",
"fontsize": 7,
"zorder": 3,
"bbox": dict(
boxstyle="round,pad=0.2",
facecolor="white",
edgecolor="black",
lw=0.4,
),
}
for xx, yy, L in zip(x, y_mbps, levels):
ax1.text(xx, yy, str(L), **levelstyle)
ax1.set_xlabel(xlabel)
ax1.set_ylabel(ylabel)
y_min, y_max = ax1.get_ylim()
x_min, x_max = ax1.get_xlim()
unique_y_values = sorted(set(point[1] for point in all_data_points))
visible_y_values = [y for y in unique_y_values if y_min <= y <= y_max]
y_range = y_max - y_min
min_spacing = y_range * 0.05
selected_y_values = []
for y in visible_y_values:
if not selected_y_values or all(
abs(y - selected) >= min_spacing for selected in selected_y_values
):
selected_y_values.append(y)
def format_time_label(ms):
if ms < 10:
return f"{ms:.1f} ms"
elif ms < 1000:
return f"{ms:.0f} ms"
else:
seconds = ms / 1000
return f"{seconds:.1f} s"
for y_mbps in selected_y_values:
rightmost_x = max(
[point[0] for point in all_data_points if abs(point[1] - y_mbps) < 1e-6],
default=x_min,
)
ms = orig_total * 8000.0 / (y_mbps * 1e6)
ax1.plot(
[rightmost_x, x_max],
[y_mbps, y_mbps],
color="black",
linestyle="-",
alpha=0.3,
linewidth=0.8,
zorder=1,
)
ms_label = format_time_label(ms)
ax1.text(
x_max,
y_mbps,
ms_label,
verticalalignment="center",
horizontalalignment="left",
fontsize=7,
color="black",
alpha=0.8,
bbox=dict(
boxstyle="round,pad=0.1", facecolor="white", alpha=0.9, edgecolor="none"
),
)
ms_reference_values = [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]
for ms in ms_reference_values:
mbps = orig_total * 8000.0 / (ms * 1e6)
if y_min <= mbps <= y_max:
if not any(abs(mbps - y) < min_spacing for y in selected_y_values):
ax1.axhline(
y=mbps,
color="gray",
linestyle="--",
alpha=0.3,
linewidth=0.5,
zorder=1,
)
ms_label = format_time_label(ms)
ax1.text(
x_max,
mbps,
ms_label,
verticalalignment="center",
horizontalalignment="left",
fontsize=7,
color="gray",
alpha=0.6,
bbox=dict(
boxstyle="round,pad=0.1",
facecolor="white",
alpha=0.8,
edgecolor="none",
),
)
ax1.set_title(title)
ax1.grid(True, linestyle=":", linewidth=0.7, alpha=0.6, zorder=0)
ax1.legend()
plt.tight_layout()
plt.savefig(out_path)
plt.close(fig)
def run_benchmark(folder: Path, rounds: int):
    """Time every codec/level in ``ALGS`` on all files under ``folder``.

    In each round, every configuration first compresses all files and then
    decompresses all of its compressed blobs. The summed wall-clock time of
    each phase is taken with ``time.perf_counter_ns``. The median over rounds
    is stored on each ``BenchPoint``.

    Args:
        folder: Directory scanned recursively for input files.
        rounds: Number of timed rounds; must be >= 1.

    Returns:
        ``(bench_points, files, orig_total)``: the populated bench points, the
        raw file contents, and their total size in bytes.

    Raises:
        ValueError: if ``rounds`` < 1 or no files are found.
        RuntimeError: if a codec's output changes between rounds or does not
            decompress back to the original input.
    """
    if rounds < 1:
        raise ValueError(f"rounds must be >= 1, got {rounds}")
    # Sorted so the processing order is reproducible across runs.
    files = [p.read_bytes() for p in sorted(folder.rglob("*")) if p.is_file()]
    if not files:
        raise ValueError("No files found in the specified folder")
    orig_total = sum(len(b) for b in files)
    print(f"Loaded {len(files)} files, total {orig_total / 1e6:.2f} MB")

    bench_points = [
        BenchPoint(alg, lvl, spec["c"], spec["d"], [])
        for alg, spec in ALGS.items()
        for lvl in spec["levels"]
    ]
    comp_ns = np.zeros((len(bench_points), rounds), dtype=np.int64)
    decomp_ns = np.zeros((len(bench_points), rounds), dtype=np.int64)

    real_time = time.time
    gc.disable()  # keep collector pauses out of the timed regions
    # gzip.compress() writes time.time() into its header when no mtime is
    # given. Pinning the clock keeps gzip output identical across rounds so the
    # determinism check below holds. Both globals are restored in `finally`.
    time.time = lambda: 0
    try:
        progress = tqdm(range(rounds), desc="Rounds")
        for r in progress:
            for i, bp in enumerate(bench_points):
                progress.set_postfix_str(f"Compress {bp.name}/{bp.lvl}")
                for f, data in enumerate(files):
                    t0 = time.perf_counter_ns()
                    out = bp.compress(data, bp.lvl)
                    t1 = time.perf_counter_ns()
                    comp_ns[i, r] += t1 - t0
                    if r == 0:
                        bp.compdata.append(out)
                    elif out != bp.compdata[f]:
                        raise RuntimeError(
                            f"Compressed data changed between rounds for {bp.name}, level={bp.lvl}"
                        )
                progress.set_postfix_str(f"Decompress {bp.name}/{bp.lvl}")
                for f, blob in enumerate(bp.compdata):
                    t0 = time.perf_counter_ns()
                    restored = bp.decompress(blob)
                    t1 = time.perf_counter_ns()
                    decomp_ns[i, r] += t1 - t0
                    # Verify the round trip once, outside the timed span.
                    if r == 0 and restored != files[f]:
                        raise RuntimeError(
                            f"Round-trip mismatch for {bp.name}, level={bp.lvl}, file #{f}"
                        )
        progress.set_postfix_str("Benchmark Done")
    finally:
        time.time = real_time
        gc.enable()

    comp_median = np.median(comp_ns, axis=1)
    decomp_median = np.median(decomp_ns, axis=1)
    for i, bp in enumerate(bench_points):
        bp.comp_time_ns = int(comp_median[i])
        bp.decomp_time_ns = int(decomp_median[i])
    return bench_points, files, orig_total
def main():
    """Command-line entry point: benchmark a folder and write two PNG charts.

    The charts are saved as ``compression.png`` and ``decompression.png``
    inside ``--outdir``, which is created if it does not exist.
    """
    parser = argparse.ArgumentParser(description="Compression benchmark.")
    parser.add_argument("folder", type=str, help="Folder to benchmark")
    parser.add_argument("--rounds", type=int, default=5, help="Number of rounds")
    parser.add_argument(
        "--outdir", type=str, default=".", help="Output directory for PNGs"
    )
    args = parser.parse_args()

    out_dir = Path(args.outdir)
    out_dir.mkdir(parents=True, exist_ok=True)
    points, files, orig_total = run_benchmark(Path(args.folder), args.rounds)

    n_files = len(files)
    n_rounds = args.rounds
    # (timing attribute, chart title, y-axis label, output file name)
    charts = (
        (
            "comp_time_ns",
            f"Compression — {n_files} files, total {orig_total / 1_048_576:.2f} MiB (median of {n_rounds} rounds)",
            "Compression speed (Mbit/s)",
            "compression.png",
        ),
        (
            "decomp_time_ns",
            f"Decompression — {n_files} files, total {orig_total / 1_048_576:.2f} MiB (median of {n_rounds} rounds)",
            "Decompression speed (Mbit/s)",
            "decompression.png",
        ),
    )
    for time_attr, chart_title, y_label, file_name in charts:
        plot_xy(
            points,
            time_attr,
            chart_title,
            "Compressed size (% of original)",
            y_label,
            "Time (ms)",
            orig_total,
            out_dir / file_name,
        )
    print(f"Saved plots to {out_dir}")