from __future__ import annotations
import abc
import base64
import contextlib
import os
from typing import TYPE_CHECKING, Any, Callable
from sqlalchemy import String, Text, TypeDecorator
from sqlalchemy import func as sql_func
cryptography = None
with contextlib.suppress(ImportError):
from cryptography.fernet import Fernet
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
if TYPE_CHECKING:
from sqlalchemy.engine import Dialect
class EncryptionBackend(abc.ABC):
    """Abstract interface for an encryption engine.

    Concrete backends implement :meth:`init_engine`, :meth:`encrypt` and
    :meth:`decrypt`.  :meth:`mount_vault` is the entry point that receives
    the raw key; subclasses may override it to derive a key before
    initialising the engine.
    """

    def mount_vault(self, key: str | bytes) -> None:
        """Normalise ``key`` to bytes and initialise the engine with it.

        Args:
            key: The secret.  A ``str`` is encoded with ``str.encode()``
                (UTF-8) before being handed to :meth:`init_engine`.
        """
        if isinstance(key, str):
            key = key.encode()
        # The encoded key used to be dropped here, which left backends that
        # rely on this inherited implementation without an initialised engine.
        self.init_engine(key)

    @abc.abstractmethod
    def init_engine(self, key: bytes | str) -> None:  # pragma: nocover
        """Prepare the backend for use with ``key``.

        Args:
            key: Key material as bytes or text.
        """
        pass

    @abc.abstractmethod
    def encrypt(self, value: Any) -> str:  # pragma: nocover
        """Encrypt ``value`` and return the encrypted representation.

        Args:
            value: The plain value to protect.
        """
        pass

    @abc.abstractmethod
    def decrypt(self, value: Any) -> str:  # pragma: nocover
        """Decrypt ``value`` and return the plain text.

        Args:
            value: A value previously produced by :meth:`encrypt`.
        """
        pass
class PGCryptoBackend(EncryptionBackend):
    """Backend that delegates to the ``pgp_sym_encrypt`` / ``pgp_sym_decrypt`` SQL functions."""

    def init_engine(self, key: bytes | str) -> None:
        """Store the URL-safe base64 encoding of ``key`` as ``self.passphrase``.

        Args:
            key: Key material; text is encoded to bytes first.
        """
        raw_key = key.encode() if isinstance(key, str) else key
        self.passphrase = base64.urlsafe_b64encode(raw_key)

    def encrypt(self, value: Any) -> str:
        """Build a ``pgp_sym_encrypt`` call for ``value``.

        Args:
            value: Value to encrypt; non-strings are converted with ``repr``.

        Returns:
            The SQL function expression produced by ``sql_func`` (typed as
            ``str`` to satisfy the backend interface).
        """
        plain = value
        if not isinstance(plain, str):  # pragma: nocover
            plain = repr(plain)
        payload = plain.encode()
        return sql_func.pgp_sym_encrypt(payload, self.passphrase)  # type: ignore[return-value]

    def decrypt(self, value: Any) -> str:
        """Build a ``pgp_sym_decrypt`` call for ``value``.

        Args:
            value: Stored value; non-strings are converted with ``str``.

        Returns:
            The SQL function expression produced by ``sql_func`` (typed as
            ``str`` to satisfy the backend interface).
        """
        stored = value
        if not isinstance(stored, str):  # pragma: nocover
            stored = str(stored)
        return sql_func.pgp_sym_decrypt(stored, self.passphrase)  # type: ignore[return-value]
class FernetBackend(EncryptionBackend):
    """Backend that encrypts and decrypts with ``cryptography``'s Fernet."""

    def mount_vault(self, key: str | bytes) -> None:
        """Hash ``key`` with SHA-256 and initialise the engine with the digest.

        Args:
            key: The secret; text is encoded to bytes before hashing.
        """
        secret = key.encode() if isinstance(key, str) else key
        hasher = hashes.Hash(hashes.SHA256(), backend=default_backend())  # pyright: ignore[reportPossiblyUnboundVariable]
        hasher.update(secret)
        self.init_engine(hasher.finalize())

    def init_engine(self, key: bytes | str) -> None:
        """Create the Fernet instance from ``key``.

        The key is URL-safe base64 encoded and stored on ``self.key``; the
        resulting Fernet object is stored on ``self.fernet``.

        Args:
            key: Key material; text is encoded to bytes first.
        """
        raw_key = key.encode() if isinstance(key, str) else key
        self.key = base64.urlsafe_b64encode(raw_key)
        self.fernet = Fernet(self.key)  # pyright: ignore[reportPossiblyUnboundVariable]

    def encrypt(self, value: Any) -> str:
        """Encrypt ``value`` and return the token as text.

        Args:
            value: Value to encrypt; non-strings are converted with ``repr``.

        Returns:
            The Fernet token decoded as UTF-8.
        """
        plain = value if isinstance(value, str) else repr(value)
        token = self.fernet.encrypt(plain.encode())
        return token.decode("utf-8")

    def decrypt(self, value: Any) -> str:
        """Decrypt a token and return the plain text.

        Args:
            value: Token to decrypt; non-strings are converted with ``str``.

        Returns:
            The decrypted payload as text (bytes are decoded as UTF-8).
        """
        token = value
        if not isinstance(token, str):  # pragma: nocover
            token = str(token)
        plain: str | bytes = self.fernet.decrypt(token.encode())
        if isinstance(plain, str):
            return plain
        return plain.decode("utf-8")
class EncryptedString(TypeDecorator[str]):
    """Column type that encrypts values on the way in and decrypts them on the way out."""

    impl = String
    cache_ok = True

    def __init__(
        self,
        key: str | bytes | Callable[[], str | bytes] = os.urandom(32),
        backend: type[EncryptionBackend] = FernetBackend,
        **kwargs: Any,
    ) -> None:
        """Configure the key source and instantiate the backend.

        Args:
            key: The secret, or a zero-argument callable returning it.  The
                default ``os.urandom(32)`` is evaluated once, when this
                method is defined, so every instance relying on the default
                shares that one process-lifetime random value.
            backend: Backend class; it is instantiated with no arguments.
            **kwargs: Accepted but not used by this method.
        """
        super().__init__()
        self.key = key
        self.backend = backend()

    @property
    def python_type(self) -> type[str]:
        """The Python type handled by this column type: ``str``."""
        return str

    def load_dialect_impl(self, dialect: Dialect) -> Any:
        """Choose the underlying column type for ``dialect``.

        ``Text`` is used for mysql/mariadb, ``String(length=4000)`` for
        oracle and an unbounded ``String`` everywhere else.
        """
        dialect_name = dialect.name
        if dialect_name == "oracle":
            column_type = String(length=4000)
        elif dialect_name in {"mysql", "mariadb"}:
            column_type = Text()
        else:
            column_type = String()
        return dialect.type_descriptor(column_type)

    def process_bind_param(self, value: Any, dialect: Dialect) -> str | None:
        """Encrypt ``value`` before it is bound; ``None`` passes through."""
        if value is not None:
            self.mount_vault()
            return self.backend.encrypt(value)
        return None

    def process_result_value(self, value: Any, dialect: Dialect) -> str | None:
        """Decrypt ``value`` loaded from a result row; ``None`` passes through."""
        if value is not None:
            self.mount_vault()
            return self.backend.decrypt(value)
        return None

    def mount_vault(self) -> None:
        """Resolve the key (calling it when it is callable) and hand it to the backend."""
        if callable(self.key):
            resolved = self.key()
        else:
            resolved = self.key
        self.backend.mount_vault(resolved)
class EncryptedText(EncryptedString):
    """Encrypted column type backed by ``Text`` on every dialect."""

    impl = Text
    cache_ok = True

    def load_dialect_impl(self, dialect: Dialect) -> Any:
        """Return the dialect's descriptor for ``Text``, whatever the dialect."""
        text_type = Text()
        return dialect.type_descriptor(text_type)