"""Safe pickle serialization with HMAC integrity validation.
This module provides HMAC-signed pickle operations to prevent arbitrary code
execution from tampered cache files. All file-based pickle operations in the
codebase should use these functions instead of raw pickle.load/pickle.dump.
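The HMAC key is stored in a `.pickle_hmac_key` file inside the directory
given by the ``key_dir`` argument (default: ``~/.ergodic_insurance/``). Files
written with safe_dump can only be loaded by safe_load if the HMAC signature
matches, preventing deserialization of untrusted data. This guarantee holds
only while the key file stays private: anyone who can read the key can forge
a valid signature for arbitrary pickle data.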
Also provides deterministic_hash() as a replacement for Python's
non-deterministic built-in hash() function.
"""
import hashlib
import hmac
import os
from pathlib import Path
import pickle
import secrets
from typing import Any, Optional, Union
# The HMAC key lives at <key_dir>/<_KEY_FILENAME>; when no key_dir is given,
# this per-user directory under the home folder is used instead.
_DEFAULT_KEY_DIR = Path.home().joinpath(".ergodic_insurance")
_KEY_FILENAME = ".pickle_hmac_key"
# Size in bytes of the HMAC-SHA256 tag stored in front of every payload.
_SIGNATURE_LENGTH = hashlib.sha256().digest_size  # == 32
def _get_key_path(key_dir: Optional[Union[str, Path]] = None) -> Path:
    """Get the path to the HMAC key file.

    Args:
        key_dir: Directory holding the key file, as a ``Path`` or a plain
            string (any ``os.PathLike`` also works). ``None`` or an empty
            string selects the default directory ``~/.ergodic_insurance``.

    Returns:
        ``<key_dir>/.pickle_hmac_key`` as a ``Path``.
    """
    # Wrap in Path so str input works (``str / str`` raises TypeError).
    # Truthiness rather than ``is None`` keeps "" meaning "use the default".
    directory = Path(key_dir) if key_dir else _DEFAULT_KEY_DIR
    return directory / _KEY_FILENAME
def _get_or_create_hmac_key(key_dir: Optional[Path] = None) -> bytes:
    """Get or create a persistent HMAC key for pickle validation.

    The key is read from ``<key_dir>/.pickle_hmac_key``. If that file does
    not exist, a fresh random 32-byte key is generated and written with
    owner-only permissions (mode ``0o600``, further restricted by the umask),
    so other local users cannot read it and forge signatures. The file is
    created with ``O_EXCL``: if several processes race to create it, exactly
    one succeeds and the others reuse its key instead of overwriting it.

    Args:
        key_dir: Directory to store the key file. Defaults to ~/.ergodic_insurance/

    Returns:
        32-byte HMAC key

    Raises:
        ValueError: If an existing key file holds fewer than 32 bytes (for
            example, it is empty after an interrupted write). Such a key would
            make signatures easy to forge, so it is refused.
    """
    key_length = 32
    key_path = _get_key_path(key_dir)
    try:
        key = key_path.read_bytes()
    except FileNotFoundError:
        key_path.parent.mkdir(parents=True, exist_ok=True)
        new_key = secrets.token_bytes(key_length)
        # O_EXCL makes creation atomic: if another process created the file
        # after our read failed, os.open raises instead of clobbering it.
        # O_BINARY exists only on Windows (prevents newline translation).
        flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | getattr(os, "O_BINARY", 0)
        try:
            fd = os.open(key_path, flags, 0o600)
        except FileExistsError:
            # Lost the creation race: adopt the winner's key so data signed
            # by either process remains verifiable by both.
            key = key_path.read_bytes()
        else:
            try:
                with os.fdopen(fd, "wb") as fh:
                    fh.write(new_key)
            except BaseException:
                # Do not leave an empty/partial key file behind; cleanup is
                # best-effort and the original error is re-raised.
                try:
                    key_path.unlink()
                except OSError:
                    pass
                raise
            return new_key
    # A concurrent reader can observe the file before the creator's 32 bytes
    # land, and a crash mid-write can leave it empty. Fail loudly rather than
    # sign/verify with a short or empty key.
    if len(key) < key_length:
        raise ValueError(
            f"HMAC key file {key_path} is too short ({len(key)} bytes; "
            f"expected at least {key_length}). Delete it to generate a new key."
        )
    return key
[docs]
def safe_dump(
    obj: Any,
    f,
    protocol: int = pickle.HIGHEST_PROTOCOL,
    key_dir: Optional[Path] = None,
) -> None:
    """Pickle dump with HMAC signature prepended.

    Writes a 32-byte HMAC-SHA256 signature followed by the pickled bytes of
    ``obj``; the result can be read back with ``safe_load``.

    Args:
        obj: Object to serialize
        f: Writable binary file object
        protocol: Pickle protocol version
        key_dir: Directory containing the HMAC key (a new key is created
            there if none exists yet)
    """
    # Reuse safe_dumps so the signed format comes from a single code path
    # (the file and in-memory variants cannot drift apart), and emit it with
    # one write call instead of two.
    f.write(safe_dumps(obj, protocol=protocol, key_dir=key_dir))
[docs]
def safe_load(f, key_dir: Optional[Path] = None) -> Any:
    """Read a signed pickle from ``f``, verify it, and return the object.

    The stream must hold a 32-byte HMAC-SHA256 tag followed by the pickled
    payload (the layout ``safe_dump`` writes). The payload is unpickled only
    after the tag has been compared in constant time.

    Args:
        f: Readable binary file object; it is read to the end.
        key_dir: Directory containing the HMAC key (a new key is created
            there if none exists yet).

    Returns:
        The deserialized object.

    Raises:
        ValueError: If the content is shorter than a signature, or the HMAC
            does not match.
    """
    blob = f.read()
    if len(blob) < _SIGNATURE_LENGTH:
        raise ValueError("Invalid pickle file: too short for HMAC verification")
    tag, payload = blob[:_SIGNATURE_LENGTH], blob[_SIGNATURE_LENGTH:]
    expected = hmac.new(
        _get_or_create_hmac_key(key_dir), payload, hashlib.sha256
    ).digest()
    if hmac.compare_digest(tag, expected):
        return pickle.loads(payload)  # noqa: S301 - HMAC-verified before loading
    raise ValueError(
        "Pickle file integrity check failed: HMAC mismatch. "
        "File may have been tampered with or was created by a different key."
    )
[docs]
def safe_dumps(
    obj: Any,
    protocol: int = pickle.HIGHEST_PROTOCOL,
    key_dir: Optional[Path] = None,
) -> bytes:
    """Serialize ``obj`` and return it prefixed with its HMAC-SHA256 tag.

    Args:
        obj: Object to serialize.
        protocol: Pickle protocol version.
        key_dir: Directory containing the HMAC key (a new key is created
            there if none exists yet).

    Returns:
        The 32-byte HMAC-SHA256 of the pickled bytes, immediately followed by
        the pickled bytes themselves.
    """
    payload = pickle.dumps(obj, protocol=protocol)
    mac = hmac.new(_get_or_create_hmac_key(key_dir), payload, hashlib.sha256)
    return b"".join((mac.digest(), payload))
[docs]
def safe_loads(data: bytes, key_dir: Optional[Path] = None) -> Any:
    """Verify an in-memory signed pickle and return the decoded object.

    ``data`` must be a 32-byte HMAC-SHA256 tag followed by the pickled
    payload (the layout ``safe_dumps`` returns). Unpickling happens only
    after the tag has been compared in constant time.

    Args:
        data: Signature bytes followed by pickled bytes.
        key_dir: Directory containing the HMAC key (a new key is created
            there if none exists yet).

    Returns:
        The deserialized object.

    Raises:
        ValueError: If ``data`` is shorter than a signature, or the HMAC does
            not match.
    """
    sig_len = _SIGNATURE_LENGTH
    if len(data) < sig_len:
        raise ValueError("Invalid pickle data: too short for HMAC verification")
    received_tag = data[:sig_len]
    body = data[sig_len:]
    mac = hmac.new(_get_or_create_hmac_key(key_dir), body, hashlib.sha256)
    if not hmac.compare_digest(received_tag, mac.digest()):
        raise ValueError(
            "Pickle data integrity check failed: HMAC mismatch. "
            "Data may have been tampered with or was created by a different key."
        )
    return pickle.loads(body)  # noqa: S301 - HMAC-verified before loading
[docs]
def deterministic_hash(*args: str, length: int = 16) -> str:
    """Generate a deterministic hash from string arguments.

    Uses SHA-256 instead of Python's built-in hash(), whose result for strings
    varies between processes unless PYTHONHASHSEED is fixed. This function
    returns the same value across process restarts.

    Each argument is converted with ``str()`` and the pieces are joined with
    ``"|"`` before hashing. The separator is not escaped, so argument lists
    that join to the same string collide (e.g. ``("a|b",)`` and
    ``("a", "b")``). The encoding is kept as-is so previously produced hashes
    stay valid.

    Args:
        *args: Values to hash (each converted with ``str()``).
        length: Number of hex characters to return. Values above 64 (the size
            of a full SHA-256 hex digest) return the whole 64-character digest.

    Returns:
        Lowercase hex digest prefix of ``min(length, 64)`` characters.

    Raises:
        ValueError: If ``length`` is less than 1.
    """
    if length < 1:
        # Slice semantics would otherwise give "" for 0 and drop characters
        # from the END for negatives (length=-1 -> 63 chars), neither of which
        # is a meaningful hash prefix.
        raise ValueError(f"length must be a positive integer, got {length!r}")
    combined = "|".join(str(a) for a in args)
    return hashlib.sha256(combined.encode("utf-8")).hexdigest()[:length]