Precise update output length calculation to allocate exactly sized buffers.

This commit is contained in:
Leo Vasanko
2025-11-06 07:45:15 -06:00
parent 438627e0db
commit bcf4655f64
7 changed files with 77 additions and 42 deletions

View File

@@ -19,6 +19,7 @@ ABYTES_MIN = _lib.aegis128l_abytes_min()
ABYTES_MAX = _lib.aegis128l_abytes_max()
TAILBYTES_MAX = _lib.aegis128l_tailbytes_max()
ALIGNMENT = 32
RATE = 32
def calc_update_output_size(bytes_in: int, bytes_next: int) -> int:
@@ -31,8 +32,7 @@ def calc_update_output_size(bytes_in: int, bytes_next: int) -> int:
Returns:
Number of bytes that the next update will write.
"""
remainder = bytes_in % ALIGNMENT
return ((remainder + int(bytes_next)) // ALIGNMENT) * ALIGNMENT
return (((bytes_in % RATE) + bytes_next) // RATE) * RATE
def _ptr(buf):
@@ -651,7 +651,7 @@ class Encryptor:
f"state encrypt update failed: {err_name} written {written[0]}"
)
w = int(written[0])
# Advance counters for this update call
assert w == expected_out
self._bytes_in += message.nbytes
self._bytes_out += w
return out[:w]
@@ -810,11 +810,10 @@ class Decryptor:
RuntimeError: If the C update call fails.
"""
ct = memoryview(ct)
produced = calc_update_output_size(self._bytes_in, ct.nbytes)
requirement = produced + ALIGNMENT # libaegis requires larger than actual w
out = into if into is not None else bytearray(requirement)
expected_out = calc_update_output_size(self._bytes_in, ct.nbytes)
out = into if into is not None else bytearray(expected_out)
out = memoryview(out)
if out.nbytes < requirement:
if out.nbytes < expected_out:
raise TypeError("into length must be >= required capacity for this update")
written = ffi.new("size_t *")
rc = _lib.aegis128l_state_decrypt_detached_update(
@@ -830,6 +829,7 @@ class Decryptor:
err_name = errno.errorcode.get(err_num, f"errno_{err_num}")
raise RuntimeError(f"state decrypt update failed: {err_name}")
w = int(written[0])
assert w == expected_out
self._bytes_in += ct.nbytes
self._bytes_out += w
return out[:w]
@@ -898,6 +898,7 @@ __all__ = [
"ABYTES_MAX",
"TAILBYTES_MAX",
"ALIGNMENT",
"RATE",
# helpers
"calc_update_output_size",
# one-shot functions

View File

@@ -19,6 +19,7 @@ ABYTES_MIN = _lib.aegis128x2_abytes_min()
ABYTES_MAX = _lib.aegis128x2_abytes_max()
TAILBYTES_MAX = _lib.aegis128x2_tailbytes_max()
ALIGNMENT = 64
RATE = 64
def calc_update_output_size(bytes_in: int, bytes_next: int) -> int:
@@ -31,8 +32,7 @@ def calc_update_output_size(bytes_in: int, bytes_next: int) -> int:
Returns:
Number of bytes that the next update will write.
"""
remainder = bytes_in % ALIGNMENT
return ((remainder + int(bytes_next)) // ALIGNMENT) * ALIGNMENT
return (((bytes_in % RATE) + bytes_next) // RATE) * RATE
def _ptr(buf):
@@ -651,7 +651,7 @@ class Encryptor:
f"state encrypt update failed: {err_name} written {written[0]}"
)
w = int(written[0])
# Advance counters for this update call
assert w == expected_out
self._bytes_in += message.nbytes
self._bytes_out += w
return out[:w]
@@ -810,11 +810,10 @@ class Decryptor:
RuntimeError: If the C update call fails.
"""
ct = memoryview(ct)
produced = calc_update_output_size(self._bytes_in, ct.nbytes)
requirement = produced + ALIGNMENT # libaegis requires larger than actual w
out = into if into is not None else bytearray(requirement)
expected_out = calc_update_output_size(self._bytes_in, ct.nbytes)
out = into if into is not None else bytearray(expected_out)
out = memoryview(out)
if out.nbytes < requirement:
if out.nbytes < expected_out:
raise TypeError("into length must be >= required capacity for this update")
written = ffi.new("size_t *")
rc = _lib.aegis128x2_state_decrypt_detached_update(
@@ -830,6 +829,7 @@ class Decryptor:
err_name = errno.errorcode.get(err_num, f"errno_{err_num}")
raise RuntimeError(f"state decrypt update failed: {err_name}")
w = int(written[0])
assert w == expected_out
self._bytes_in += ct.nbytes
self._bytes_out += w
return out[:w]
@@ -898,6 +898,7 @@ __all__ = [
"ABYTES_MAX",
"TAILBYTES_MAX",
"ALIGNMENT",
"RATE",
# helpers
"calc_update_output_size",
# one-shot functions

View File

@@ -19,6 +19,7 @@ ABYTES_MIN = _lib.aegis128x4_abytes_min()
ABYTES_MAX = _lib.aegis128x4_abytes_max()
TAILBYTES_MAX = _lib.aegis128x4_tailbytes_max()
ALIGNMENT = 64
RATE = 128
def calc_update_output_size(bytes_in: int, bytes_next: int) -> int:
@@ -31,8 +32,7 @@ def calc_update_output_size(bytes_in: int, bytes_next: int) -> int:
Returns:
Number of bytes that the next update will write.
"""
remainder = bytes_in % ALIGNMENT
return ((remainder + int(bytes_next)) // ALIGNMENT) * ALIGNMENT
return (((bytes_in % RATE) + bytes_next) // RATE) * RATE
def _ptr(buf):
@@ -651,7 +651,7 @@ class Encryptor:
f"state encrypt update failed: {err_name} written {written[0]}"
)
w = int(written[0])
# Advance counters for this update call
assert w == expected_out
self._bytes_in += message.nbytes
self._bytes_out += w
return out[:w]
@@ -810,11 +810,10 @@ class Decryptor:
RuntimeError: If the C update call fails.
"""
ct = memoryview(ct)
produced = calc_update_output_size(self._bytes_in, ct.nbytes)
requirement = produced + ALIGNMENT # libaegis requires larger than actual w
out = into if into is not None else bytearray(requirement)
expected_out = calc_update_output_size(self._bytes_in, ct.nbytes)
out = into if into is not None else bytearray(expected_out)
out = memoryview(out)
if out.nbytes < requirement:
if out.nbytes < expected_out:
raise TypeError("into length must be >= required capacity for this update")
written = ffi.new("size_t *")
rc = _lib.aegis128x4_state_decrypt_detached_update(
@@ -830,6 +829,7 @@ class Decryptor:
err_name = errno.errorcode.get(err_num, f"errno_{err_num}")
raise RuntimeError(f"state decrypt update failed: {err_name}")
w = int(written[0])
assert w == expected_out
self._bytes_in += ct.nbytes
self._bytes_out += w
return out[:w]
@@ -898,6 +898,7 @@ __all__ = [
"ABYTES_MAX",
"TAILBYTES_MAX",
"ALIGNMENT",
"RATE",
# helpers
"calc_update_output_size",
# one-shot functions

View File

@@ -19,6 +19,7 @@ ABYTES_MIN = _lib.aegis256_abytes_min()
ABYTES_MAX = _lib.aegis256_abytes_max()
TAILBYTES_MAX = _lib.aegis256_tailbytes_max()
ALIGNMENT = 16
RATE = 16
def calc_update_output_size(bytes_in: int, bytes_next: int) -> int:
@@ -31,8 +32,7 @@ def calc_update_output_size(bytes_in: int, bytes_next: int) -> int:
Returns:
Number of bytes that the next update will write.
"""
remainder = bytes_in % ALIGNMENT
return ((remainder + int(bytes_next)) // ALIGNMENT) * ALIGNMENT
return (((bytes_in % RATE) + bytes_next) // RATE) * RATE
def _ptr(buf):
@@ -651,7 +651,7 @@ class Encryptor:
f"state encrypt update failed: {err_name} written {written[0]}"
)
w = int(written[0])
# Advance counters for this update call
assert w == expected_out
self._bytes_in += message.nbytes
self._bytes_out += w
return out[:w]
@@ -810,11 +810,10 @@ class Decryptor:
RuntimeError: If the C update call fails.
"""
ct = memoryview(ct)
produced = calc_update_output_size(self._bytes_in, ct.nbytes)
requirement = produced + ALIGNMENT # libaegis requires larger than actual w
out = into if into is not None else bytearray(requirement)
expected_out = calc_update_output_size(self._bytes_in, ct.nbytes)
out = into if into is not None else bytearray(expected_out)
out = memoryview(out)
if out.nbytes < requirement:
if out.nbytes < expected_out:
raise TypeError("into length must be >= required capacity for this update")
written = ffi.new("size_t *")
rc = _lib.aegis256_state_decrypt_detached_update(
@@ -830,6 +829,7 @@ class Decryptor:
err_name = errno.errorcode.get(err_num, f"errno_{err_num}")
raise RuntimeError(f"state decrypt update failed: {err_name}")
w = int(written[0])
assert w == expected_out
self._bytes_in += ct.nbytes
self._bytes_out += w
return out[:w]
@@ -898,6 +898,7 @@ __all__ = [
"ABYTES_MAX",
"TAILBYTES_MAX",
"ALIGNMENT",
"RATE",
# helpers
"calc_update_output_size",
# one-shot functions

View File

@@ -19,6 +19,7 @@ ABYTES_MIN = _lib.aegis256x2_abytes_min()
ABYTES_MAX = _lib.aegis256x2_abytes_max()
TAILBYTES_MAX = _lib.aegis256x2_tailbytes_max()
ALIGNMENT = 32
RATE = 32
def calc_update_output_size(bytes_in: int, bytes_next: int) -> int:
@@ -31,8 +32,7 @@ def calc_update_output_size(bytes_in: int, bytes_next: int) -> int:
Returns:
Number of bytes that the next update will write.
"""
remainder = bytes_in % ALIGNMENT
return ((remainder + int(bytes_next)) // ALIGNMENT) * ALIGNMENT
return (((bytes_in % RATE) + bytes_next) // RATE) * RATE
def _ptr(buf):
@@ -651,7 +651,7 @@ class Encryptor:
f"state encrypt update failed: {err_name} written {written[0]}"
)
w = int(written[0])
# Advance counters for this update call
assert w == expected_out
self._bytes_in += message.nbytes
self._bytes_out += w
return out[:w]
@@ -810,11 +810,10 @@ class Decryptor:
RuntimeError: If the C update call fails.
"""
ct = memoryview(ct)
produced = calc_update_output_size(self._bytes_in, ct.nbytes)
requirement = produced + ALIGNMENT # libaegis requires larger than actual w
out = into if into is not None else bytearray(requirement)
expected_out = calc_update_output_size(self._bytes_in, ct.nbytes)
out = into if into is not None else bytearray(expected_out)
out = memoryview(out)
if out.nbytes < requirement:
if out.nbytes < expected_out:
raise TypeError("into length must be >= required capacity for this update")
written = ffi.new("size_t *")
rc = _lib.aegis256x2_state_decrypt_detached_update(
@@ -830,6 +829,7 @@ class Decryptor:
err_name = errno.errorcode.get(err_num, f"errno_{err_num}")
raise RuntimeError(f"state decrypt update failed: {err_name}")
w = int(written[0])
assert w == expected_out
self._bytes_in += ct.nbytes
self._bytes_out += w
return out[:w]
@@ -898,6 +898,7 @@ __all__ = [
"ABYTES_MAX",
"TAILBYTES_MAX",
"ALIGNMENT",
"RATE",
# helpers
"calc_update_output_size",
# one-shot functions

View File

@@ -19,6 +19,7 @@ ABYTES_MIN = _lib.aegis256x4_abytes_min()
ABYTES_MAX = _lib.aegis256x4_abytes_max()
TAILBYTES_MAX = _lib.aegis256x4_tailbytes_max()
ALIGNMENT = 64
RATE = 64
def calc_update_output_size(bytes_in: int, bytes_next: int) -> int:
@@ -31,8 +32,7 @@ def calc_update_output_size(bytes_in: int, bytes_next: int) -> int:
Returns:
Number of bytes that the next update will write.
"""
remainder = bytes_in % ALIGNMENT
return ((remainder + int(bytes_next)) // ALIGNMENT) * ALIGNMENT
return (((bytes_in % RATE) + bytes_next) // RATE) * RATE
def _ptr(buf):
@@ -651,7 +651,7 @@ class Encryptor:
f"state encrypt update failed: {err_name} written {written[0]}"
)
w = int(written[0])
# Advance counters for this update call
assert w == expected_out
self._bytes_in += message.nbytes
self._bytes_out += w
return out[:w]
@@ -810,11 +810,10 @@ class Decryptor:
RuntimeError: If the C update call fails.
"""
ct = memoryview(ct)
produced = calc_update_output_size(self._bytes_in, ct.nbytes)
requirement = produced + ALIGNMENT # libaegis requires larger than actual w
out = into if into is not None else bytearray(requirement)
expected_out = calc_update_output_size(self._bytes_in, ct.nbytes)
out = into if into is not None else bytearray(expected_out)
out = memoryview(out)
if out.nbytes < requirement:
if out.nbytes < expected_out:
raise TypeError("into length must be >= required capacity for this update")
written = ffi.new("size_t *")
rc = _lib.aegis256x4_state_decrypt_detached_update(
@@ -830,6 +829,7 @@ class Decryptor:
err_name = errno.errorcode.get(err_num, f"errno_{err_num}")
raise RuntimeError(f"state decrypt update failed: {err_name}")
w = int(written[0])
assert w == expected_out
self._bytes_in += ct.nbytes
self._bytes_out += w
return out[:w]
@@ -898,6 +898,7 @@ __all__ = [
"ABYTES_MAX",
"TAILBYTES_MAX",
"ALIGNMENT",
"RATE",
# helpers
"calc_update_output_size",
# one-shot functions

View File

@@ -6,6 +6,7 @@ Changes per variant:
- Replace module name (aegis256x4 -> target)
- Replace label (AEGIS-256X4 -> target label like AEGIS-128L)
- Replace only the ALIGNMENT = <int> value
- Replace only the RATE = <int> value
We do not touch alloc_aligned(...) calls or any code formatting. Blank lines
after ALIGNMENT are preserved.
@@ -35,10 +36,21 @@ VARIANT_ALIGN = {
"aegis128x4": 64,
}
# Variants and their RATE values
VARIANT_RATE = {
"aegis256": 16,
"aegis256x2": 32,
"aegis256x4": 64,
"aegis128l": 32,
"aegis128x2": 64,
"aegis128x4": 128,
}
TEMPLATE_NAME = "aegis256x4"
TEMPLATE_LABEL = "AEGIS-256X4"
ALIGNMENT_LINE_RE = re.compile(r"^(ALIGNMENT\s*=\s*)(\d+)(\s*)$", re.MULTILINE)
RATE_LINE_RE = re.compile(r"^(RATE\s*=\s*)(\d+)(\s*)$", re.MULTILINE)
def set_alignment_only(text: str, value: int) -> str:
@@ -55,6 +67,20 @@ def set_alignment_only(text: str, value: int) -> str:
return ALIGNMENT_LINE_RE.sub(_sub, text)
def set_rate_only(text: str, value: int) -> str:
"""Replace only the numeric RATE value, preserving surrounding whitespace and lines.
This preserves any empty lines following the RATE assignment because
the line ending is not part of the match; we keep any trailing spaces too.
"""
def _sub(m: re.Match[str]) -> str:
prefix, _num, suffix = m.group(1), m.group(2), m.group(3)
return f"{prefix}{value}{suffix}"
return RATE_LINE_RE.sub(_sub, text)
def algo_label(name: str) -> str:
"""Return the canonical label like AEGIS-256X4 for a module name like aegis256x4."""
if not name.startswith("aegis"):
@@ -70,6 +96,9 @@ def generate_variant(template_src: str, variant: str) -> str:
# 3) set ALIGNMENT constant value using fallback map
align_value = VARIANT_ALIGN.get(variant, 64)
s = set_alignment_only(s, align_value)
# 4) set RATE constant value using fallback map
rate_value = VARIANT_RATE.get(variant, 64)
s = set_rate_only(s, rate_value)
return s