Initial commit
This commit is contained in:
4
compression_benchmark/__main__.py
Normal file
4
compression_benchmark/__main__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Package entry script: delegates to the benchmark's command-line entry point.
from .bench import main

# Only start the benchmark when this file is executed as the main program,
# never on a plain import.
if __name__ == "__main__":
    main()
|
||||
245
compression_benchmark/bench.py
Normal file
245
compression_benchmark/bench.py
Normal file
@@ -0,0 +1,245 @@
|
||||
import argparse
|
||||
import gc
|
||||
import gzip
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import brotli
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
from zstandard import ZstdCompressor, ZstdDecompressor
|
||||
|
||||
# Registry of the benchmarked codecs, keyed by algorithm name. Each entry holds:
#   "levels": the compression levels/qualities to benchmark,
#   "c":      callable ``(data, level) -> compressed bytes``,
#   "d":      callable ``(compressed) -> decompressed bytes``.
ALGS = {
    "zstd": {
        "levels": [1, 2, 3, 4, 5, 6, 7, 9, 10, 12, 14, 16, 22],
        "c": lambda b, q: ZstdCompressor(level=q).compress(b),
        "d": lambda b: ZstdDecompressor().decompress(b),
    },
    "brotli": {
        "levels": [1, 2, 3, 4, 5, 7, 9, 11],
        "c": lambda b, q: brotli.compress(b, quality=q),
        "d": lambda b: brotli.decompress(b),
    },
    "gzip": {
        "levels": [1, 3, 4, 5, 9],
        # mtime=0 pins the timestamp stored in the gzip header. Without it the
        # header carries the current time, so identical input would compress to
        # different bytes depending on when the call happens.
        "c": lambda b, q: gzip.compress(b, compresslevel=q, mtime=0),
        "d": lambda b: gzip.decompress(b),
    },
}
|
||||
|
||||
|
||||
@dataclass
class BenchPoint:
    """One benchmark configuration (algorithm + level) together with its results."""

    # Algorithm name; the plotting code groups points by this value.
    name: str
    # Compression level passed as the second argument to ``compress``.
    lvl: int
    # Invoked as ``compress(data, lvl)``; returns the compressed payload.
    compress: Callable[[bytes, int], bytes]
    # Invoked as ``decompress(payload)``.
    decompress: Callable[[bytes], bytes]
    # Compressed payloads collected by the benchmark, one per input file.
    compdata: list
    # Time to compress all files, in nanoseconds (0 until measured).
    comp_time_ns: int = 0
    # Time to decompress all payloads, in nanoseconds (0 until measured).
    decomp_time_ns: int = 0
|
||||
|
||||
|
||||
def plot_xy(
    bench_points, time_attr, title, xlabel, ylabel, ylabel2, orig_total, out_path
):
    """Draw a compressed-size versus throughput chart and save it to ``out_path``.

    For every algorithm in ``ALGS`` one line is plotted with a marker per bench
    point: x is the total compressed size as a percentage of ``orig_total`` and
    y is ``orig_total * 8000.0 / time`` (Mbit/s for a time in nanoseconds). Each
    marker is labelled with its level. At the right edge, throughput values are
    annotated with the equivalent time for the whole input: black guides for
    measured points, dashed gray guides for round reference times.

    Args:
        bench_points: Bench points providing ``name``, ``lvl``, ``compdata`` and
            the attribute named by ``time_attr``.
        time_attr: Name of the bench-point attribute holding the time in ns.
        title: Chart title.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        ylabel2: Accepted for interface compatibility; not used by this function.
        orig_total: Total size of the uncompressed input in bytes.
        out_path: Destination passed to ``plt.savefig``.
    """

    def format_time_label(ms):
        # Sub-10 ms values keep one decimal, larger ones are rounded, and
        # anything from one second upwards is shown in seconds.
        if ms < 10:
            return f"{ms:.1f} ms"
        if ms < 1000:
            return f"{ms:.0f} ms"
        return f"{ms / 1000:.1f} s"

    fig, ax1 = plt.subplots(figsize=(9, 5), dpi=300)

    # bytes * 8000 / ns == Mbit/s, and the same factor converts back to ms below.
    scaled_bits = orig_total * 8000.0

    # Appearance of the small boxed level numbers drawn on top of each marker.
    level_box = {
        "ha": "center",
        "va": "center",
        "fontsize": 7,
        "zorder": 3,
        "bbox": {
            "boxstyle": "round,pad=0.2",
            "facecolor": "white",
            "edgecolor": "black",
            "lw": 0.4,
        },
    }

    scatter = []  # every plotted (size %, Mbit/s) pair across all algorithms
    for alg_name in ALGS:
        series = [bp for bp in bench_points if bp.name == alg_name]
        packed_sizes = np.array([sum(map(len, bp.compdata)) for bp in series])
        elapsed_ns = np.array([getattr(bp, time_attr) for bp in series])
        size_pct = 100.0 * packed_sizes / orig_total
        speed = scaled_bits / elapsed_ns
        scatter.extend(zip(size_pct, speed))
        ax1.plot(
            size_pct, speed, marker="s", linewidth=1.25, label=alg_name, zorder=2
        )
        for bp, px, py in zip(series, size_pct, speed):
            ax1.text(px, py, str(bp.lvl), **level_box)

    ax1.set_xlabel(xlabel)
    ax1.set_ylabel(ylabel)

    # Freeze the auto-scaled limits; all annotations below are laid out
    # relative to them.
    y_lo, y_hi = ax1.get_ylim()
    x_lo, x_hi = ax1.get_xlim()

    # Walk the measured throughputs in ascending order and keep only those at
    # least 5% of the axis height away from every value already kept, so the
    # right-hand labels do not overlap.
    min_gap = (y_hi - y_lo) * 0.05
    labelled = []
    for candidate in sorted({pt[1] for pt in scatter}):
        if not (y_lo <= candidate <= y_hi):
            continue
        if all(abs(candidate - kept) >= min_gap for kept in labelled):
            labelled.append(candidate)

    for kept in labelled:
        # The guide starts at the rightmost marker sitting at this throughput.
        guide_start = max(
            (px for px, py in scatter if abs(py - kept) < 1e-6), default=x_lo
        )
        ax1.plot(
            [guide_start, x_hi],
            [kept, kept],
            color="black",
            linestyle="-",
            alpha=0.3,
            linewidth=0.8,
            zorder=1,
        )
        ax1.text(
            x_hi,
            kept,
            format_time_label(scaled_bits / (kept * 1e6)),
            verticalalignment="center",
            horizontalalignment="left",
            fontsize=7,
            color="black",
            alpha=0.8,
            bbox={
                "boxstyle": "round,pad=0.1",
                "facecolor": "white",
                "alpha": 0.9,
                "edgecolor": "none",
            },
        )

    # Round reference times (in ms), drawn only where they fall inside the axis
    # and do not crowd one of the measured labels.
    for ref_ms in [0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]:
        ref_speed = scaled_bits / (ref_ms * 1e6)
        if not (y_lo <= ref_speed <= y_hi):
            continue
        if any(abs(ref_speed - kept) < min_gap for kept in labelled):
            continue
        ax1.axhline(
            y=ref_speed,
            color="gray",
            linestyle="--",
            alpha=0.3,
            linewidth=0.5,
            zorder=1,
        )
        ax1.text(
            x_hi,
            ref_speed,
            format_time_label(ref_ms),
            verticalalignment="center",
            horizontalalignment="left",
            fontsize=7,
            color="gray",
            alpha=0.6,
            bbox={
                "boxstyle": "round,pad=0.1",
                "facecolor": "white",
                "alpha": 0.8,
                "edgecolor": "none",
            },
        )

    ax1.set_title(title)
    ax1.grid(True, linestyle=":", linewidth=0.7, alpha=0.6, zorder=0)
    ax1.legend()
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close(fig)
|
||||
|
||||
|
||||
def run_benchmark(folder: Path, rounds: int):
    """Time every algorithm/level combination of ``ALGS`` on the files in *folder*.

    All regular files below *folder* (searched recursively) are read into
    memory. In each round, every bench point compresses each file and then
    decompresses each compressed payload; the elapsed ``perf_counter_ns`` time
    is summed per bench point and round. The per-round totals are reduced to
    their median and stored on the bench points.

    Args:
        folder: Directory whose files are used as benchmark input.
        rounds: Number of timing rounds; must be at least 1.

    Returns:
        A tuple ``(bench_points, files, orig_total)``: the populated bench
        points, the list of file contents, and their total size in bytes.

    Raises:
        AssertionError: If no files were found, or if a compressor returns
            different output for the same input in a later round.
        ValueError: If *rounds* is smaller than 1.
    """
    files = [p.read_bytes() for p in folder.rglob("*") if p.is_file()]
    orig_total = sum(len(b) for b in files)
    print(f"Loaded {len(files)} files, total {orig_total / 1e6:.2f} MB")
    # Raised explicitly instead of via ``assert`` so the check is not stripped
    # under ``python -O``; the exception type is kept for existing callers.
    if not files:
        raise AssertionError("No files found in the specified folder")
    # Without this, zero rounds would only fail later while converting a NaN
    # median to int, and negative rounds while allocating the timing arrays.
    if rounds < 1:
        raise ValueError(f"rounds must be at least 1, got {rounds}")

    bench_points = [
        BenchPoint(alg, lvl, spec["c"], spec["d"], [])
        for alg, spec in ALGS.items()
        for lvl in spec["levels"]
    ]
    P = len(bench_points)
    # One row per bench point, one column per round, in nanoseconds.
    C_ns = np.zeros((P, rounds), dtype=np.int64)
    D_ns = np.zeros((P, rounds), dtype=np.int64)

    gc_was_enabled = gc.isenabled()
    real_time = time.time
    # Keep garbage-collector pauses out of the measurements.
    gc.disable()
    # NOTE(review): presumably this freezes the timestamp that gzip writes into
    # its header, so the byte-for-byte comparison between rounds holds — confirm.
    time.time = lambda: 0
    try:
        progress = tqdm(range(rounds), desc="Rounds")
        for r in progress:
            for i, bp in enumerate(bench_points):
                progress.set_postfix_str(f"Compress {bp.name}/{bp.lvl}")
                for f, data in enumerate(files):
                    t0 = time.perf_counter_ns()
                    out = bp.compress(data, bp.lvl)
                    t1 = time.perf_counter_ns()
                    C_ns[i, r] += t1 - t0
                    if r == 0:
                        # The first round's output is kept as the reference
                        # and as the input of the decompression timing.
                        bp.compdata.append(out)
                    elif out != bp.compdata[f]:
                        raise AssertionError(
                            f"Compressed data changed between rounds for {bp.name}, level={bp.lvl}"
                        )
                progress.set_postfix_str(f"Decompress {bp.name}/{bp.lvl}")
                for payload in bp.compdata:
                    t0 = time.perf_counter_ns()
                    _ = bp.decompress(payload)
                    t1 = time.perf_counter_ns()
                    D_ns[i, r] += t1 - t0
            # Set while the bar is still open so the final line shows it.
            progress.set_postfix_str("Benchmark Done")
    finally:
        # Undo the process-wide changes even when a codec raises, and leave
        # the garbage collector in the state it was found in.
        time.time = real_time
        if gc_was_enabled:
            gc.enable()

    Cns, Dns = np.median(C_ns, 1), np.median(D_ns, 1)
    for i, bp in enumerate(bench_points):
        bp.comp_time_ns = int(Cns[i])
        bp.decomp_time_ns = int(Dns[i])
    return bench_points, files, orig_total
|
||||
|
||||
|
||||
def main():
    """Parse the command line, run the benchmark and write both charts.

    Reads the input folder, the number of rounds and the output directory from
    the command line, creates the output directory if needed, runs
    ``run_benchmark`` and saves ``compression.png`` and ``decompression.png``
    into the output directory.
    """
    cli = argparse.ArgumentParser(description="Compression benchmark.")
    cli.add_argument("folder", type=str, help="Folder to benchmark")
    cli.add_argument("--rounds", type=int, default=5, help="Number of rounds")
    cli.add_argument(
        "--outdir", type=str, default=".", help="Output directory for PNGs"
    )
    opts = cli.parse_args()

    rounds = opts.rounds
    outdir = Path(opts.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    bench_points, files, orig_total = run_benchmark(Path(opts.folder), rounds)

    # One row per chart: timing attribute, title, y-axis label, output file name.
    charts = (
        (
            "comp_time_ns",
            f"Compression — {len(files)} files, total {orig_total / 1_048_576:.2f} MiB (median of {rounds} rounds)",
            "Compression speed (Mbit/s)",
            "compression.png",
        ),
        (
            "decomp_time_ns",
            f"Decompression — {len(files)} files, total {orig_total / 1_048_576:.2f} MiB (median of {rounds} rounds)",
            "Decompression speed (Mbit/s)",
            "decompression.png",
        ),
    )
    for time_attr, title, speed_label, filename in charts:
        plot_xy(
            bench_points,
            time_attr,
            title,
            "Compressed size (% of original)",
            speed_label,
            "Time (ms)",
            orig_total,
            outdir / filename,
        )
    print(f"Saved plots to {outdir}")
|
||||
Reference in New Issue
Block a user