246 lines
8.1 KiB
Python
246 lines
8.1 KiB
Python
import argparse
|
|
import gc
|
|
import gzip
|
|
import time
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Callable
|
|
|
|
import brotli
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from tqdm.auto import tqdm
|
|
from zstandard import ZstdCompressor, ZstdDecompressor
|
|
|
|
# Codecs under test. Each entry maps an algorithm name to:
#   "levels": compression levels/qualities to benchmark, in plot order;
#   "c":      compress(data: bytes, level: int) -> bytes;
#   "d":      decompress(data: bytes) -> bytes.
ALGS = {
    "zstd": {
        "levels": [1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 16, 22],
        "c": lambda b, q: ZstdCompressor(level=q).compress(b),
        "d": lambda b: ZstdDecompressor().decompress(b),
    },
    "brotli": {
        "levels": [1, 2, 3, 4, 5, 7, 9, 11],
        "c": lambda b, q: brotli.compress(b, quality=q),
        "d": lambda b: brotli.decompress(b),
    },
    "gzip": {
        "levels": [1, 3, 4, 5, 9],
        # mtime=0 keeps the wall-clock timestamp out of the gzip header, so
        # identical input always compresses to identical bytes (the benchmark
        # asserts this across rounds).
        "c": lambda b, q: gzip.compress(b, compresslevel=q, mtime=0),
        "d": lambda b: gzip.decompress(b),
    },
}
|
|
|
|
|
|
@dataclass
class BenchPoint:
    """State for a single (algorithm, level) benchmark configuration.

    ``run_benchmark`` fills ``compdata`` with one compressed blob per input
    file during the first round. Afterwards it stores, in the two
    ``*_time_ns`` fields, the median across rounds of the per-round total
    compression/decompression time.
    """

    # Algorithm key as used in ALGS, e.g. "zstd".
    name: str
    # Level/quality handed to ``compress`` as its second argument.
    lvl: int
    # Called as compress(data, lvl) -> compressed bytes.
    compress: Callable[[bytes, int], bytes]
    # Called as decompress(compressed) -> original bytes.
    decompress: Callable[[bytes], bytes]
    # Compressed output for each input file, in file order.
    compdata: list
    # Median summed compression time across rounds, in ns; 0 until measured.
    comp_time_ns: int = 0
    # Median summed decompression time across rounds, in ns; 0 until measured.
    decomp_time_ns: int = 0
|
|
|
|
|
|
def _format_time_label(ms):
    """Return a compact label for a duration of ``ms`` milliseconds.

    Below 10 ms one decimal is kept, below 1000 ms the value is rounded to
    whole milliseconds, and longer durations are shown in seconds with one
    decimal.
    """
    if ms < 10:
        return f"{ms:.1f} ms"
    if ms < 1000:
        return f"{ms:.0f} ms"
    return f"{ms / 1000:.1f} s"


def _spaced_values(values, lo, hi, min_spacing):
    """Return distinct ``values`` inside ``[lo, hi]``, ascending, thinned so
    each kept value is at least ``min_spacing`` away from every earlier one.

    Selection is greedy from the smallest value upward.
    """
    kept = []
    for v in sorted(set(values)):
        if lo <= v <= hi and all(abs(v - k) >= min_spacing for k in kept):
            kept.append(v)
    return kept


def _time_label(ax, x, y, text, color, alpha, box_alpha):
    """Draw ``text`` left-aligned at data point (x, y) on a borderless white box."""
    ax.text(
        x,
        y,
        text,
        verticalalignment="center",
        horizontalalignment="left",
        fontsize=7,
        color=color,
        alpha=alpha,
        bbox=dict(
            boxstyle="round,pad=0.1",
            facecolor="white",
            alpha=box_alpha,
            edgecolor="none",
        ),
    )


def plot_xy(
    bench_points, time_attr, title, xlabel, ylabel, ylabel2, orig_total, out_path
):
    """Plot compressed size against throughput for every benchmarked config.

    Draws one line per algorithm, in order of first appearance in
    ``bench_points``. Each marker is labelled with its level.

    - x: total compressed size as a percentage of ``orig_total``.
    - y: throughput in Mbit/s, i.e. ``orig_total * 8`` bits divided by the
      nanosecond time read from ``time_attr``.

    Horizontal guides running to the right edge translate throughputs back
    into total time, labelled just outside the axes.

    Args:
        bench_points: ``BenchPoint`` objects with ``compdata`` and timings set.
        time_attr: name of the nanosecond timing attribute to plot, e.g.
            ``"comp_time_ns"`` or ``"decomp_time_ns"``.
        title: figure title.
        xlabel: x-axis label.
        ylabel: y-axis label.
        ylabel2: heading written above the column of time labels on the right.
        orig_total: total uncompressed size in bytes (must be > 0).
        out_path: destination passed to ``savefig``.
    """
    fig, ax = plt.subplots(figsize=(9, 5), dpi=300)
    try:
        level_style = {
            "ha": "center",
            "va": "center",
            "fontsize": 7,
            "zorder": 3,
            "bbox": dict(
                boxstyle="round,pad=0.2",
                facecolor="white",
                edgecolor="black",
                lw=0.4,
            ),
        }
        points = []  # (x, y) of every plotted marker, all algorithms

        for alg_name in dict.fromkeys(bp.name for bp in bench_points):
            alg_points = [bp for bp in bench_points if bp.name == alg_name]
            comp_totals = np.array(
                [sum(len(b) for b in bp.compdata) for bp in alg_points]
            )
            times_ns = np.array([getattr(bp, time_attr) for bp in alg_points])
            xs = 100.0 * comp_totals / orig_total
            # (orig_total * 8) bits / (times_ns / 1000) us == Mbit/s
            ys = orig_total * 8000.0 / times_ns
            points.extend(zip(xs, ys))
            ax.plot(xs, ys, marker="s", linewidth=1.25, label=alg_name, zorder=2)
            for px, py, bp in zip(xs, ys, alg_points):
                ax.text(px, py, str(bp.lvl), **level_style)

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        x_min, x_max = ax.get_xlim()
        y_min, y_max = ax.get_ylim()
        # Freeze the view. Otherwise the guide segments drawn below (which end
        # at x_max) re-trigger autoscaling and add a new margin, so the time
        # labels placed at x_max would no longer sit on the right border.
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)

        min_spacing = (y_max - y_min) * 0.05  # 5 % of the visible y range
        marked = _spaced_values((py for _, py in points), y_min, y_max, min_spacing)

        # Solid guide from the rightmost marker at each selected throughput to
        # the right edge, labelled with the matching total time.
        for mbps in marked:
            start = max(
                (px for px, py in points if abs(py - mbps) < 1e-6), default=x_min
            )
            ax.plot(
                [start, x_max],
                [mbps, mbps],
                color="black",
                linestyle="-",
                alpha=0.3,
                linewidth=0.8,
                zorder=1,
            )
            ms = orig_total * 8000.0 / (mbps * 1e6)  # inverse of the y formula
            _time_label(ax, x_max, mbps, _format_time_label(ms), "black", 0.8, 0.9)

        # Dashed guides at round time values, skipped where they would crowd a
        # marker-derived guide.
        for ms in (0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000):
            mbps = orig_total * 8000.0 / (ms * 1e6)
            if not y_min <= mbps <= y_max:
                continue
            if any(abs(mbps - y) < min_spacing for y in marked):
                continue
            ax.axhline(
                y=mbps,
                color="gray",
                linestyle="--",
                alpha=0.3,
                linewidth=0.5,
                zorder=1,
            )
            _time_label(ax, x_max, mbps, _format_time_label(ms), "gray", 0.6, 0.8)

        # Heading for the time-label column, just above the top-right corner.
        ax.annotate(
            ylabel2,
            xy=(1.0, 1.0),
            xycoords="axes fraction",
            xytext=(3, 3),
            textcoords="offset points",
            ha="left",
            va="bottom",
            fontsize=7,
        )

        ax.set_title(title)
        ax.grid(True, linestyle=":", linewidth=0.7, alpha=0.6, zorder=0)
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
|
def _time_compress(bp, files, first_round):
    """Compress each of ``files`` with ``bp`` and return the summed time in ns.

    On the first round the outputs are appended to ``bp.compdata``. On later
    rounds each output must equal the one recorded for that file.
    """
    total_ns = 0
    for f, data in enumerate(files):
        t0 = time.perf_counter_ns()
        out = bp.compress(data, bp.lvl)
        t1 = time.perf_counter_ns()
        total_ns += t1 - t0
        if first_round:
            bp.compdata.append(out)
        else:
            assert out == bp.compdata[f], (
                f"Compressed data changed between rounds for {bp.name}, level={bp.lvl}"
            )
    return total_ns


def _time_decompress(bp, files, verify):
    """Decompress every blob in ``bp.compdata`` and return the summed time in ns.

    When ``verify`` is true, each result is checked against the matching
    original file. The check runs outside the timed region.
    """
    total_ns = 0
    for f, blob in enumerate(bp.compdata):
        t0 = time.perf_counter_ns()
        restored = bp.decompress(blob)
        t1 = time.perf_counter_ns()
        total_ns += t1 - t0
        if verify:
            assert restored == files[f], (
                f"Round-trip mismatch for {bp.name}, level={bp.lvl}, file #{f}"
            )
    return total_ns


def run_benchmark(folder: Path, rounds: int):
    """Time every (algorithm, level) in ALGS over all files under ``folder``.

    Each round compresses and then decompresses every file with every
    configuration. The per-round totals are reduced to their median across
    ``rounds`` and stored on the returned ``BenchPoint`` objects.

    Args:
        folder: directory searched recursively for input files.
        rounds: number of timing rounds (>= 1).

    Returns:
        ``(bench_points, files, orig_total)``: the measured configurations in
        ALGS order, the raw file contents, and their total size in bytes.

    Raises:
        ValueError: if ``rounds`` < 1, no files are found, or all are empty.
        AssertionError: if a codec's output differs between rounds or does not
            decompress back to the original input.
    """
    if rounds < 1:
        raise ValueError(f"rounds must be >= 1, got {rounds}")
    # Sorted so file order (and therefore compdata order) is reproducible.
    files = [p.read_bytes() for p in sorted(folder.rglob("*")) if p.is_file()]
    orig_total = sum(len(b) for b in files)
    print(f"Loaded {len(files)} files, total {orig_total / 1e6:.2f} MB")
    if not files:
        raise ValueError("No files found in the specified folder")
    if orig_total == 0:
        raise ValueError("All files in the specified folder are empty")

    bench_points = [
        BenchPoint(alg, lvl, spec["c"], spec["d"], [])
        for alg, spec in ALGS.items()
        for lvl in spec["levels"]
    ]
    comp_ns = np.zeros((len(bench_points), rounds), dtype=np.int64)
    decomp_ns = np.zeros_like(comp_ns)

    gc_was_enabled = gc.isenabled()
    real_time = time.time
    gc.disable()  # keep collector pauses out of the timed regions
    # gzip.compress() writes time.time() into its header when no mtime is
    # given. Pinning it to 0 keeps gzip output identical across rounds, which
    # the determinism check in _time_compress relies on.
    time.time = lambda: 0
    try:
        progress = tqdm(range(rounds), desc="Rounds")
        for r in progress:
            for i, bp in enumerate(bench_points):
                progress.set_postfix_str(f"Compress {bp.name}/{bp.lvl}")
                comp_ns[i, r] = _time_compress(bp, files, first_round=(r == 0))
                progress.set_postfix_str(f"Decompress {bp.name}/{bp.lvl}")
                decomp_ns[i, r] = _time_decompress(bp, files, verify=(r == 0))
        progress.set_postfix_str("Benchmark Done")
    finally:
        # Undo the process-wide changes even if a check above failed.
        time.time = real_time
        if gc_was_enabled:
            gc.enable()

    comp_med = np.median(comp_ns, axis=1)
    decomp_med = np.median(decomp_ns, axis=1)
    for bp, c, d in zip(bench_points, comp_med, decomp_med):
        bp.comp_time_ns = int(c)
        bp.decomp_time_ns = int(d)
    return bench_points, files, orig_total
|
|
|
|
|
|
def main():
    """Command-line entry point: benchmark a folder and write two PNG plots.

    Writes ``compression.png`` and ``decompression.png`` into ``--outdir``,
    creating the directory if needed.
    """
    parser = argparse.ArgumentParser(description="Compression benchmark.")
    parser.add_argument("folder", type=str, help="Folder to benchmark")
    parser.add_argument("--rounds", type=int, default=5, help="Number of rounds")
    parser.add_argument(
        "--outdir", type=str, default=".", help="Output directory for PNGs"
    )
    args = parser.parse_args()

    # Reject bad arguments up front with a usage message rather than failing
    # deep inside the benchmark.
    if args.rounds < 1:
        parser.error(f"--rounds must be at least 1 (got {args.rounds})")
    folder = Path(args.folder)
    if not folder.is_dir():
        parser.error(f"folder not found or not a directory: {folder}")
    rounds = args.rounds
    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    bench_points, files, orig_total = run_benchmark(folder, rounds)

    # Shared suffix for both plot titles.
    summary = (
        f"{len(files)} files, total {orig_total / 1_048_576:.2f} MiB "
        f"(median of {rounds} rounds)"
    )
    for phase, time_attr, filename in (
        ("Compression", "comp_time_ns", "compression.png"),
        ("Decompression", "decomp_time_ns", "decompression.png"),
    ):
        plot_xy(
            bench_points,
            time_attr,
            f"{phase} — {summary}",
            "Compressed size (% of original)",
            f"{phase} speed (Mbit/s)",
            "Time (ms)",
            orig_total,
            outdir / filename,
        )
    print(f"Saved plots to {outdir}")
|