diff --git a/packages/google-auth/google/auth/aio/transport/mtls.py b/packages/google-auth/google/auth/aio/transport/mtls.py index b85d30b53485..bbbd39ff1950 100644 --- a/packages/google-auth/google/auth/aio/transport/mtls.py +++ b/packages/google-auth/google/auth/aio/transport/mtls.py @@ -17,42 +17,17 @@ """ import asyncio -import contextlib import logging -import os import ssl -import tempfile from typing import Optional from google.auth import exceptions -import google.auth.transport._mtls_helper +from google.auth.transport._mtls_helper import secure_cert_key_paths import google.auth.transport.mtls _LOGGER = logging.getLogger(__name__) -@contextlib.contextmanager -def _create_temp_file(content: bytes): - """Creates a temporary file with the given content. - - Args: - content (bytes): The content to write to the file. - - Yields: - str: The path to the temporary file. - """ - # Create a temporary file that is readable only by the owner. - fd, file_path = tempfile.mkstemp() - try: - with os.fdopen(fd, "wb") as f: - f.write(content) - yield file_path - finally: - # Securely delete the file after use. - if os.path.exists(file_path): - os.remove(file_path) - - def make_client_cert_ssl_context( cert_bytes: bytes, key_bytes: bytes, passphrase: Optional[bytes] = None ) -> ssl.SSLContext: @@ -71,13 +46,17 @@ def make_client_cert_ssl_context( Raises: google.auth.exceptions.TransportError: If there is an error loading the certificate. """ - with _create_temp_file(cert_bytes) as cert_path, _create_temp_file( - key_bytes - ) as key_path: + with secure_cert_key_paths(cert_bytes, key_bytes, passphrase=passphrase) as ( + cert_path, + key_path, + passphrase_val, + ): try: context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) context.load_cert_chain( - certfile=cert_path, keyfile=key_path, password=passphrase + certfile=cert_path, + keyfile=key_path, + password=passphrase_val or "", ) return context except (ssl.SSLError, OSError, IOError, ValueError, RuntimeError) as exc: diff --git a/packages/google-auth/google/auth/compute_engine/_mtls.py b/packages/google-auth/google/auth/compute_engine/_mtls.py index a427e66a89b3..d475c1e59a9e 100644 --- a/packages/google-auth/google/auth/compute_engine/_mtls.py +++ b/packages/google-auth/google/auth/compute_engine/_mtls.py @@ -120,7 +120,7 @@ def __init__( self.ssl_context = ssl.create_default_context() self.ssl_context.load_verify_locations(cafile=mds_mtls_config.ca_cert_path) self.ssl_context.load_cert_chain( - certfile=mds_mtls_config.client_combined_cert_path + certfile=mds_mtls_config.client_combined_cert_path, password="" ) super(MdsMtlsAdapter, self).__init__(*args, **kwargs) diff --git a/packages/google-auth/google/auth/environment_vars.py b/packages/google-auth/google/auth/environment_vars.py index c7d706467ed4..c622f1773531 100644 --- a/packages/google-auth/google/auth/environment_vars.py +++ b/packages/google-auth/google/auth/environment_vars.py @@ -129,3 +129,6 @@ "GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES" ) """Environment variable to prevent agent token sharing for GCP services.""" + +GOOGLE_API_USE_MTLS_ENDPOINT = "GOOGLE_API_USE_MTLS_ENDPOINT" +"""Environment variable controlling whether to use mTLS endpoint or not.""" diff --git a/packages/google-auth/google/auth/identity_pool.py b/packages/google-auth/google/auth/identity_pool.py index 30819ef0485a..ca13a2b9f927 100644 --- a/packages/google-auth/google/auth/identity_pool.py +++ b/packages/google-auth/google/auth/identity_pool.py @@ -152,13 +152,9 @@ def __init__(self, trust_chain_path, leaf_cert_callback): @_helpers.copy_docstring(SubjectTokenSupplier) def get_subject_token(self, context, request): - # Import OpennSSL inline because it is an extra import only required by customers - # using mTLS. - from OpenSSL import crypto + from cryptography import x509 - leaf_cert = crypto.load_certificate( - crypto.FILETYPE_PEM, self._leaf_cert_callback() - ) + leaf_cert = x509.load_pem_x509_certificate(self._leaf_cert_callback()) trust_chain = self._read_trust_chain() cert_chain = [] @@ -184,9 +180,7 @@ def get_subject_token(self, context, request): return json.dumps(cert_chain) def _read_trust_chain(self): - # Import OpennSSL inline because it is an extra import only required by customers - # using mTLS. - from OpenSSL import crypto + from cryptography import x509 certificate_trust_chain = [] # If no trust chain path was provided, return an empty list. @@ -204,9 +198,7 @@ def _read_trust_chain(self): cert_data = b"-----BEGIN CERTIFICATE-----" + cert_block try: # Load each certificate and add it to the trust chain. - cert = crypto.load_certificate( - crypto.FILETYPE_PEM, cert_data - ) + cert = x509.load_pem_x509_certificate(cert_data) certificate_trust_chain.append(cert) except Exception as e: raise exceptions.RefreshError( @@ -221,13 +213,11 @@ def _read_trust_chain(self): ) def _encode_cert(cert): - # Import OpennSSL inline because it is an extra import only required by customers - # using mTLS. - from OpenSSL import crypto + from cryptography.hazmat.primitives import serialization - return base64.b64encode( - crypto.dump_certificate(crypto.FILETYPE_ASN1, cert) - ).decode("utf-8") + return base64.b64encode(cert.public_bytes(serialization.Encoding.DER)).decode( + "utf-8" + ) def _parse_token_data(token_content, format_type="text", subject_token_field_name=None): diff --git a/packages/google-auth/google/auth/transport/_custom_tls_signer.py b/packages/google-auth/google/auth/transport/_custom_tls_signer.py index 9279158d45c6..1ac0d081e2da 100644 --- a/packages/google-auth/google/auth/transport/_custom_tls_signer.py +++ b/packages/google-auth/google/auth/transport/_custom_tls_signer.py @@ -23,8 +23,6 @@ import os import sys -import cffi # type: ignore - from google.auth import exceptions _LOGGER = logging.getLogger(__name__) @@ -45,11 +43,6 @@ ) -# Cast SSL_CTX* to void* -def _cast_ssl_ctx_to_void_p_pyopenssl(ssl_ctx): - return ctypes.cast(int(cffi.FFI().cast("intptr_t", ssl_ctx)), ctypes.c_void_p) - - # Cast SSL_CTX* to void* def _cast_ssl_ctx_to_void_p_stdlib(context): return ctypes.c_void_p.from_address( @@ -274,7 +267,7 @@ def attach_to_ssl_context(self, ctx): if not self._offload_lib.ConfigureSslContext( self._sign_callback, ctypes.c_char_p(self._cert), - _cast_ssl_ctx_to_void_p_pyopenssl(ctx._ctx._context), + _cast_ssl_ctx_to_void_p_stdlib(ctx), ): raise exceptions.MutualTLSChannelError( "failed to configure ECP Offload SSL context" diff --git a/packages/google-auth/google/auth/transport/_mtls_helper.py b/packages/google-auth/google/auth/transport/_mtls_helper.py index d6450291c7f2..fbead4fc54ea 100644 --- a/packages/google-auth/google/auth/transport/_mtls_helper.py +++ b/packages/google-auth/google/auth/transport/_mtls_helper.py @@ -14,11 +14,13 @@ """Helper functions for getting mTLS cert and key.""" +import contextlib import json import logging from os import environ, getenv, path import re import subprocess +from typing import cast, Generator, Optional, Tuple, Union from google.auth import _agent_identity_utils from google.auth import environment_vars @@ -65,6 +67,237 @@ ) +@contextlib.contextmanager +def secure_cert_key_paths( + cert: Union[str, bytes], + key: Union[str, bytes], + passphrase: Optional[bytes] = None, +) -> Generator[Tuple[str, str, Optional[bytes]], None, None]: + """Provides secure file paths for certificate and key. + + Standard TLS libraries (like Python's standard library `ssl`) require file paths to + load credentials. To minimize exposure of raw private key bytes on physical storage, + this context manager implements a three-tier fallback strategy: yielding pass-through + paths (Tier 1), using RAM-backed virtual files on Linux (Tier 2), or falling back + to encrypted temporary files on disk (Tier 3). + + Args: + cert (Union[str, bytes]): Certificate path or raw PEM content bytes. + key (Union[str, bytes]): Private key path or raw PEM content bytes. + passphrase (Optional[bytes]): Optional passphrase for the private key. + + Yields: + Tuple[str, str, Optional[bytes]]: The certificate path, key path, and + the passphrase needed to load the key (either the user's original, + or the newly generated one if Tier 3 had to encrypt the key). + """ + import os + import sys + + # Tier 1: Pass-through (No-op). If the caller already provided file paths, + # we yield them directly to avoid any unnecessary file creation. + if isinstance(cert, str) and isinstance(key, str): + yield cert, key, passphrase + return + + cert_bytes = cert if isinstance(cert, bytes) else None + key_bytes = key if isinstance(key, bytes) else None + + # Tier 2: Linux RAM-backed virtual files. If supported by the OS, we write + # the bytes to anonymous in-memory files using memfd_create. This yields + # /proc/self/fd/... paths, keeping the private key entirely in memory. + if sys.platform == "linux" and hasattr(os, "memfd_create"): + cm = _memfd_cert_key_paths(cert_bytes, key_bytes) + try: + cert_path, key_path = cm.__enter__() + except OSError: + pass # Fallback to Tier 3 on failure. + else: + try: + # Handle cases where path exists but might be restricted. + if (cert_path is None or os.path.exists(cert_path)) and ( + key_path is None or os.path.exists(key_path) + ): + yield cast(str, cert_path or cert), cast( + str, key_path or key + ), passphrase + return + finally: + import sys + + exc_info = sys.exc_info() + cm.__exit__( + *(exc_info if exc_info[0] is not None else (None, None, None)) + ) + # If verification failed, fall through to Tier 3. + + # Tier 3: Fallback Encrypted Temp Files. If in-memory files are not supported + # (macOS/Windows), we write to disk. To protect the key, we encrypt plaintext + # keys on-the-fly and securely wipe the files with null bytes during cleanup. + with _tempfile_cert_key_paths(cert_bytes, key_bytes, passphrase) as ( + cert_path, + key_path, + new_passphrase, + ): + yield cast(str, cert_path or cert), cast(str, key_path or key), new_passphrase + + +def _encrypt_key_if_plaintext( + key_bytes: bytes, passphrase: Optional[bytes] +) -> Tuple[bytes, Optional[bytes]]: + """Encrypts a plaintext PEM key if necessary, returning the bytes and passphrase. + + If the key is already encrypted, returns it as-is. + """ + from cryptography.hazmat.primitives import serialization + import secrets + + try: + pkey = serialization.load_pem_private_key(key_bytes, password=None) + # It's plaintext, encrypt it. + target_passphrase = ( + passphrase + if passphrase is not None + else secrets.token_hex(32).encode("utf-8") + ) + encrypted_content = pkey.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.BestAvailableEncryption( + target_passphrase + ), + ) + return encrypted_content, target_passphrase + except (ValueError, TypeError): + # Likely already encrypted or invalid, return as-is. + return key_bytes, passphrase + + +def _secure_wipe_and_remove(file_path: str): + """Overwrites a file with null bytes before deleting it. + + This is an extra security measure to make file recovery harder. However, on modern + solid-state drives (SSDs), the hardware optimizes where data is written, meaning + the original private key bytes might still physically remain on the storage chips + until the drive cleans them up. + """ + import os + + if not os.path.exists(file_path): + return + try: + size = os.path.getsize(file_path) + with open(file_path, "r+b") as f: + f.write(b"\0" * size) + f.flush() + os.fsync(f.fileno()) + except OSError: + pass # Ignore permission/lock errors during cleanup. + finally: + try: + os.remove(file_path) + except OSError: + pass + + +@contextlib.contextmanager +def _memfd_cert_key_paths( + cert_bytes: Optional[bytes], key_bytes: Optional[bytes] +) -> Generator[Tuple[Optional[str], Optional[str]], None, None]: + """Creates secure, in-memory virtual files on Linux using memfd_create. + + Yields: + Tuple[Optional[str], Optional[str]]: In-memory file paths pointing to + the active descriptors (e.g., '/proc/self/fd/3'). + """ + import os + + cleanup_fds = [] + cert_path, key_path = None, None + + try: + if cert_bytes is not None: + # MFD_CLOEXEC prevents FD leaks to spawned subprocesses. + fd_cert = os.memfd_create("mtls_cert", os.MFD_CLOEXEC) # type: ignore[attr-defined] + cleanup_fds.append(fd_cert) + with os.fdopen(fd_cert, "wb", closefd=False) as f: + f.write(cert_bytes) + cert_path = f"/proc/self/fd/{fd_cert}" + + if key_bytes is not None: + fd_key = os.memfd_create("mtls_key", os.MFD_CLOEXEC) # type: ignore[attr-defined] + cleanup_fds.append(fd_key) + with os.fdopen(fd_key, "wb", closefd=False) as f: + f.write(key_bytes) + key_path = f"/proc/self/fd/{fd_key}" + + yield cert_path, key_path + finally: + # Closing the descriptors automatically frees the RAM allocation. + for fd in cleanup_fds: + try: + os.close(fd) + except OSError: + pass + + +@contextlib.contextmanager +def _tempfile_cert_key_paths( + cert_bytes: Optional[bytes], + key_bytes: Optional[bytes], + passphrase: Optional[bytes], +) -> Generator[Tuple[Optional[str], Optional[str], Optional[bytes]], None, None]: + """Creates secure temporary file paths on disk, encrypting private keys. + + Yields: + Tuple[Optional[str], Optional[str], Optional[bytes]]: The temporary file + paths and the passphrase needed to load the key. + """ + import os + import tempfile + + # Prioritize RAM-backed /dev/shm to avoid writing secrets to physical storage. + tmp_dir = "/dev/shm" if os.path.isdir("/dev/shm") else None + cert_path, key_path = None, None + cleanup_files = [] + new_passphrase = passphrase + + try: + if cert_bytes is not None: + fd, cert_path = tempfile.mkstemp(dir=tmp_dir) + cleanup_files.append(cert_path) + with os.fdopen(fd, "wb") as f: + f.write(cert_bytes) + f.flush() + os.fsync(f.fileno()) + + if key_bytes is not None: + # Encrypt plaintext keys on-the-fly before dropping to disk. + encrypted_key_bytes, new_passphrase = _encrypt_key_if_plaintext( + key_bytes, passphrase + ) + + fd, key_path = tempfile.mkstemp(dir=tmp_dir) + cleanup_files.append(key_path) + with os.fdopen(fd, "wb") as f: + f.write(encrypted_key_bytes) + f.flush() + os.fsync(f.fileno()) + + yield cert_path, key_path, new_passphrase + finally: + for file_path in cleanup_files: + try: + # Wiping the private key with null bytes before removal. + if file_path == key_path: + _secure_wipe_and_remove(file_path) + else: + if os.path.exists(file_path): + os.remove(file_path) + except OSError: + pass + + def _check_config_path(config_path): """Checks for config file path. If it exists, returns the absolute path with user expansion; otherwise returns None. @@ -436,16 +669,19 @@ def client_cert_callback(): bytes: The decrypted private key in PEM format. Raises: - ImportError: If pyOpenSSL is not installed. - OpenSSL.crypto.Error: If there is any problem decrypting the private key. + ValueError: If there is any problem decrypting the private key. """ - from OpenSSL import crypto + from cryptography.hazmat.primitives import serialization # First convert encrypted_key_bytes to PKey object - pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key, passphrase=passphrase) + pkey = serialization.load_pem_private_key(key, password=passphrase) # Then dump the decrypted key bytes - return crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey) + return pkey.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) def check_use_client_cert(): diff --git a/packages/google-auth/google/auth/transport/mtls.py b/packages/google-auth/google/auth/transport/mtls.py index 666a6ca1fd91..c3c1cf186c2b 100644 --- a/packages/google-auth/google/auth/transport/mtls.py +++ b/packages/google-auth/google/auth/transport/mtls.py @@ -14,12 +14,19 @@ """Utilites for mutual TLS.""" +import logging from os import getenv +import ssl +from typing import Optional +from google.auth import environment_vars from google.auth import exceptions from google.auth.transport import _mtls_helper +_LOGGER = logging.getLogger(__name__) + + def has_default_client_cert_source(include_context_aware=True): """Check if default client SSL credentials exists on the device. @@ -60,7 +67,7 @@ def default_client_cert_source(): client certificate bytes and private key bytes, both in PEM format. Raises: - google.auth.exceptions.DefaultClientCertSourceError: If the default + google.auth.exceptions.MutualTLSChannelError: If the default client SSL credentials don't exist or are malformed. """ if not has_default_client_cert_source(include_context_aware=True): @@ -71,7 +78,12 @@ def default_client_cert_source(): def callback(): try: _, cert_bytes, key_bytes = _mtls_helper.get_client_cert_and_key() - except (OSError, RuntimeError, ValueError) as caught_exc: + except ( + exceptions.ClientCertError, + OSError, + RuntimeError, + ValueError, + ) as caught_exc: new_exc = exceptions.MutualTLSChannelError(caught_exc) raise new_exc from caught_exc @@ -96,7 +108,7 @@ def default_client_encrypted_cert_source(cert_path, key_path): returns the cert_path, key_path and passphrase bytes. Raises: - google.auth.exceptions.DefaultClientCertSourceError: If any problem + google.auth.exceptions.MutualTLSChannelError: If any problem occurs when loading or saving the client certificate and key. """ if not has_default_client_cert_source(include_context_aware=True): @@ -140,3 +152,163 @@ def should_use_client_cert(): bool: indicating whether the client certificate should be used for mTLS. """ return _mtls_helper.check_use_client_cert() + + +def load_client_cert_into_context( + ctx: ssl.SSLContext, + cert_bytes: bytes, + key_bytes: bytes, + passphrase: Optional[bytes] = None, +) -> None: + """Load a client certificate and key into an SSL context. + + Args: + ctx (ssl.SSLContext): The SSL context to load the certificate and key into. + cert_bytes (bytes): The client certificate bytes in PEM format. + key_bytes (bytes): The client private key bytes in PEM format. + passphrase (Optional[bytes]): The passphrase for the client private key. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If the SSL context is invalid, + or if loading the certificate and key fails. + """ + if ctx is None or not hasattr(ctx, "load_cert_chain"): + raise exceptions.MutualTLSChannelError( + "Failed to load client certificate and key for mTLS. The provided context " + "object is invalid or does not support loading certificate chains." + ) + + try: + with _mtls_helper.secure_cert_key_paths( + cert_bytes, key_bytes, passphrase=passphrase + ) as ( + cert_path, + key_path, + passphrase_val, + ): + ctx.load_cert_chain( + certfile=cert_path, keyfile=key_path, password=passphrase_val + ) + except ( + ssl.SSLError, + OSError, + ValueError, + RuntimeError, + ) as caught_exc: + new_exc = exceptions.MutualTLSChannelError( + "Failed to load client certificate and key for mTLS." + ) + raise new_exc from caught_exc + + +def make_client_cert_ssl_context( + cert_bytes: bytes, + key_bytes: bytes, + passphrase: Optional[bytes] = None, +) -> ssl.SSLContext: + """Create a default SSL context loaded with the client certificate and key. + + Args: + cert_bytes (bytes): The client certificate bytes in PEM format. + key_bytes (bytes): The client private key bytes in PEM format. + passphrase (Optional[bytes]): The passphrase for the client private key. + + Returns: + ssl.SSLContext: The SSL context loaded with the client certificate and key. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If loading the certificate and key fails. + """ + ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + load_client_cert_into_context(ctx, cert_bytes, key_bytes, passphrase=passphrase) + return ctx + + +def load_default_client_cert(ctx: ssl.SSLContext) -> bool: + """Load the default client certificate and key into an SSL context if configured. + + If client certificates are enabled and a default client certificate source is + found, the certificate and key are loaded into the SSL context. + + Args: + ctx (ssl.SSLContext): The SSL context to load the default client certificate + and key into. + + Returns: + bool: True if client certificates are enabled and the default client + certificate was successfully loaded. False if client certificates + are disabled or if no default certificate source is configured. + + Raises: + google.auth.exceptions.ClientCertError: If the default client certificate + source exists but cannot be loaded or parsed. + google.auth.exceptions.MutualTLSChannelError: If the default client certificate + or key is malformed. + """ + if not should_use_client_cert() or not has_default_client_cert_source(): + return False + ( + has_cert, + cert_bytes, + key_bytes, + passphrase, + ) = _mtls_helper.get_client_ssl_credentials() + if not has_cert: + return False + load_client_cert_into_context(ctx, cert_bytes, key_bytes, passphrase) + return True + + +def get_default_ssl_context() -> Optional[ssl.SSLContext]: + """Get a default SSL context loaded with the default client certificate. + + Returns: + ssl.SSLContext: An SSL context loaded with the default client + certificate, or None if client certificates are not configured + or available. + + Raises: + google.auth.exceptions.ClientCertError: If the default client certificate + source exists but cannot be loaded or parsed. + google.auth.exceptions.MutualTLSChannelError: If the default client certificate + or key is malformed. + """ + ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + return ctx if load_default_client_cert(ctx) else None + + +def should_use_mtls_endpoint( + client_cert_available: Optional[bool] = None, +) -> bool: + """Determine whether to use an mTLS endpoint. + + This relies on the GOOGLE_API_USE_MTLS_ENDPOINT environment variable. If set to + "always", returns True. If set to "never", returns False. If set to "auto" + or unset, returns whether a client certificate is available. + + Args: + client_cert_available (bool): indicating if a client certificate + is available. If None, this is determined by checking if client + certificates are enabled and a default source is present. + + Returns: + bool: indicating if an mTLS endpoint should be used. + """ + if client_cert_available is None: + client_cert_available = should_use_client_cert() + + use_mtls_endpoint = getenv(environment_vars.GOOGLE_API_USE_MTLS_ENDPOINT, "auto") + use_mtls_endpoint = use_mtls_endpoint.lower() + if use_mtls_endpoint == "always": + return True + if use_mtls_endpoint == "never": + return False + if use_mtls_endpoint == "auto": + return client_cert_available + + _LOGGER.warning( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value %r. Accepted " + "values: never, auto, always. Defaulting to auto.", + use_mtls_endpoint, + ) + return client_cert_available diff --git a/packages/google-auth/google/auth/transport/requests.py b/packages/google-auth/google/auth/transport/requests.py index 9735762c4414..4c435eac3cfc 100644 --- a/packages/google-auth/google/auth/transport/requests.py +++ b/packages/google-auth/google/auth/transport/requests.py @@ -204,30 +204,41 @@ class _MutualTlsAdapter(requests.adapters.HTTPAdapter): key (bytes): client private key in PEM format Raises: - ImportError: if certifi or pyOpenSSL is not installed - OpenSSL.crypto.Error: if client cert or key is invalid + ImportError: if certifi is not installed """ def __init__(self, cert, key): import certifi - from OpenSSL import crypto - import urllib3.contrib.pyopenssl # type: ignore - - urllib3.contrib.pyopenssl.inject_into_urllib3() - - pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key) - x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + import ssl ctx_poolmanager = create_urllib3_context() ctx_poolmanager.load_verify_locations(cafile=certifi.where()) - ctx_poolmanager._ctx.use_certificate(x509) - ctx_poolmanager._ctx.use_privatekey(pkey) - self._ctx_poolmanager = ctx_poolmanager ctx_proxymanager = create_urllib3_context() ctx_proxymanager.load_verify_locations(cafile=certifi.where()) - ctx_proxymanager._ctx.use_certificate(x509) - ctx_proxymanager._ctx.use_privatekey(pkey) + + with _mtls_helper.secure_cert_key_paths(cert, key) as ( + cert_path, + key_path, + passphrase, + ): + try: + ctx_poolmanager.load_cert_chain( + certfile=cert_path, + keyfile=key_path, + password=passphrase or "", + ) + ctx_proxymanager.load_cert_chain( + certfile=cert_path, + keyfile=key_path, + password=passphrase or "", + ) + except (ssl.SSLError, OSError, IOError, ValueError, RuntimeError) as exc: + raise exceptions.MutualTLSChannelError( + "Failed to configure client certificate and key for mTLS." + ) from exc + + self._ctx_poolmanager = ctx_poolmanager self._ctx_proxymanager = ctx_proxymanager super(_MutualTlsAdapter, self).__init__() @@ -258,7 +269,7 @@ class _MutualTlsOffloadAdapter(requests.adapters.HTTPAdapter): } Raises: - ImportError: if certifi or pyOpenSSL is not installed + ImportError: if certifi is not installed google.auth.exceptions.MutualTLSChannelError: If mutual TLS channel creation failed for any reason. """ @@ -270,10 +281,6 @@ def __init__(self, enterprise_cert_file_path): self.signer = _custom_tls_signer.CustomTlsSigner(enterprise_cert_file_path) self.signer.load_libraries() - import urllib3.contrib.pyopenssl - - urllib3.contrib.pyopenssl.inject_into_urllib3() - poolmanager = create_urllib3_context() poolmanager.load_verify_locations(cafile=certifi.where()) self.signer.attach_to_ssl_context(poolmanager) @@ -449,11 +456,6 @@ def configure_mtls_channel(self, client_cert_callback=None): if not use_client_cert: self._is_mtls = False return - try: - import OpenSSL - except ImportError as caught_exc: - new_exc = exceptions.MutualTLSChannelError(caught_exc) - raise new_exc from caught_exc try: ( @@ -471,10 +473,14 @@ def configure_mtls_channel(self, client_cert_callback=None): except ( exceptions.ClientCertError, ImportError, - OpenSSL.crypto.Error, + ValueError, ) as caught_exc: + self._is_mtls = False new_exc = exceptions.MutualTLSChannelError(caught_exc) raise new_exc from caught_exc + except Exception: + self._is_mtls = False + raise def request( self, diff --git a/packages/google-auth/google/auth/transport/urllib3.py b/packages/google-auth/google/auth/transport/urllib3.py index de07007a946c..239cbbb2f455 100644 --- a/packages/google-auth/google/auth/transport/urllib3.py +++ b/packages/google-auth/google/auth/transport/urllib3.py @@ -174,22 +174,29 @@ def _make_mutual_tls_http(cert, key): urllib3.PoolManager: Mutual TLS HTTP connection. Raises: - ImportError: If certifi or pyOpenSSL is not installed. - OpenSSL.crypto.Error: If the cert or key is invalid. + ValueError: If the cert or key is invalid. """ import certifi - from OpenSSL import crypto - import urllib3.contrib.pyopenssl # type: ignore + import ssl - urllib3.contrib.pyopenssl.inject_into_urllib3() ctx = urllib3.util.ssl_.create_urllib3_context() ctx.load_verify_locations(cafile=certifi.where()) - pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key) - x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) - - ctx._ctx.use_certificate(x509) - ctx._ctx.use_privatekey(pkey) + with _mtls_helper.secure_cert_key_paths(cert, key) as ( + cert_path, + key_path, + passphrase, + ): + try: + ctx.load_cert_chain( + certfile=cert_path, + keyfile=key_path, + password=passphrase or "", + ) + except (ssl.SSLError, OSError, IOError, ValueError, RuntimeError) as exc: + raise exceptions.MutualTLSChannelError( + "Failed to configure client certificate and key for mTLS." + ) from exc http = urllib3.PoolManager(ssl_context=ctx) return http @@ -341,11 +348,6 @@ def configure_mtls_channel(self, client_cert_callback=None): return False else: self._is_mtls = True - try: - import OpenSSL - except ImportError as caught_exc: - new_exc = exceptions.MutualTLSChannelError(caught_exc) - raise new_exc from caught_exc try: found_cert_key, cert, key = transport._mtls_helper.get_client_cert_and_key( @@ -357,13 +359,18 @@ def configure_mtls_channel(self, client_cert_callback=None): self._cached_cert = cert else: self.http = _make_default_http() + self._is_mtls = False except ( exceptions.ClientCertError, ImportError, - OpenSSL.crypto.Error, + ValueError, ) as caught_exc: + self._is_mtls = False new_exc = exceptions.MutualTLSChannelError(caught_exc) raise new_exc from caught_exc + except Exception: + self._is_mtls = False + raise if self._has_user_provided_http: self._has_user_provided_http = False diff --git a/packages/google-auth/noxfile.py b/packages/google-auth/noxfile.py index 5962f96bf094..752d719ebcf5 100644 --- a/packages/google-auth/noxfile.py +++ b/packages/google-auth/noxfile.py @@ -150,7 +150,6 @@ def mypy(session): "mypy", "types-certifi", "types-freezegun", - "types-pyOpenSSL", "types-requests", "types-setuptools", "types-mock", diff --git a/packages/google-auth/setup.py b/packages/google-auth/setup.py index cf3148130d6e..85902dbb32ce 100644 --- a/packages/google-auth/setup.py +++ b/packages/google-auth/setup.py @@ -35,10 +35,7 @@ reauth_extra_require = ["pyu2f>=0.1.5"] -# TODO(https://github.com/googleapis/google-auth-library-python/issues/1738): Add bounds for pyopenssl dependency. -enterprise_cert_extra_require = ["pyopenssl"] - -pyopenssl_extra_require = ["pyopenssl>=20.0.0"] +enterprise_cert_extra_require = cryptography_base_require # TODO(https://github.com/googleapis/google-auth-library-python/issues/1739): Add bounds for urllib3 and packaging dependencies. urllib3_extra_require = ["urllib3", "packaging"] @@ -55,7 +52,6 @@ "pytest", "pytest-cov", "pytest-localserver", - *pyopenssl_extra_require, *reauth_extra_require, "responses", *urllib3_extra_require, @@ -63,10 +59,6 @@ *aiohttp_extra_require, "aioresponses", "pytest-asyncio", - # TODO(https://github.com/googleapis/google-auth-library-python/issues/1665): Remove the pinned version of pyopenssl - # once `TestDecryptPrivateKey::test_success` is updated to remove the deprecated `OpenSSL.crypto.sign` and - # `OpenSSL.crypto.verify` methods. See: https://www.pyopenssl.org/en/latest/changelog.html#id3. - "pyopenssl < 24.3.0", # TODO(https://github.com/googleapis/google-auth-library-python/issues/1722): `test_aiohttp_requests` depend on # aiohttp < 3.10.0 which is a bug. Investigate and remove the pinned aiohttp version. "aiohttp < 3.10.0", @@ -77,7 +69,6 @@ "cryptography": cryptography_base_require, "aiohttp": aiohttp_extra_require, "enterprise_cert": enterprise_cert_extra_require, - "pyopenssl": pyopenssl_extra_require, "pyjwt": pyjwt_extra_require, "reauth": reauth_extra_require, "requests": requests_extra_require, diff --git a/packages/google-auth/system_tests/noxfile.py b/packages/google-auth/system_tests/noxfile.py index 2cc4d122cf02..825ef0aab509 100644 --- a/packages/google-auth/system_tests/noxfile.py +++ b/packages/google-auth/system_tests/noxfile.py @@ -322,7 +322,7 @@ def urllib3(session): @nox.session(python=PYTHON_VERSIONS_SYNC) def mtls_http(session): session.install(LIBRARY_DIR) - session.install(*TEST_DEPENDENCIES_SYNC, "pyopenssl") + session.install(*TEST_DEPENDENCIES_SYNC) session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE default( session, diff --git a/packages/google-auth/tests/compute_engine/test__mtls.py b/packages/google-auth/tests/compute_engine/test__mtls.py index 2effa29bbdc2..eb7b919fd374 100644 --- a/packages/google-auth/tests/compute_engine/test__mtls.py +++ b/packages/google-auth/tests/compute_engine/test__mtls.py @@ -123,7 +123,7 @@ def test_mds_mtls_adapter_init(mock_ssl_context, mock_mds_mtls_config): cafile=mock_mds_mtls_config.ca_cert_path ) adapter.ssl_context.load_cert_chain.assert_called_once_with( - certfile=mock_mds_mtls_config.client_combined_cert_path + certfile=mock_mds_mtls_config.client_combined_cert_path, password="" ) diff --git a/packages/google-auth/tests/test_identity_pool.py b/packages/google-auth/tests/test_identity_pool.py index c68fac64708d..cfc4f3589bc4 100644 --- a/packages/google-auth/tests/test_identity_pool.py +++ b/packages/google-auth/tests/test_identity_pool.py @@ -20,7 +20,8 @@ from unittest import mock import urllib -from OpenSSL import crypto +from cryptography import x509 +from cryptography.hazmat.primitives import serialization import pytest # type: ignore from google.auth import _helpers, external_account @@ -69,17 +70,15 @@ JSON_FILE_SUBJECT_TOKEN = JSON_FILE_CONTENT.get(SUBJECT_TOKEN_FIELD_NAME) with open(CERT_FILE, "rb") as f: + cert = x509.load_pem_x509_certificate(f.read()) CERT_FILE_CONTENT = base64.b64encode( - crypto.dump_certificate( - crypto.FILETYPE_ASN1, crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) - ) + cert.public_bytes(serialization.Encoding.DER) ).decode("utf-8") with open(OTHER_CERT_FILE, "rb") as f: + cert = x509.load_pem_x509_certificate(f.read()) OTHER_CERT_FILE_CONTENT = base64.b64encode( - crypto.dump_certificate( - crypto.FILETYPE_ASN1, crypto.load_certificate(crypto.FILETYPE_PEM, f.read()) - ) + cert.public_bytes(serialization.Encoding.DER) ).decode("utf-8") TOKEN_URL = "https://sts.googleapis.com/v1/token" diff --git a/packages/google-auth/tests/transport/test__custom_tls_signer.py b/packages/google-auth/tests/transport/test__custom_tls_signer.py index 3ecb29a60516..c0e40466e17e 100644 --- a/packages/google-auth/tests/transport/test__custom_tls_signer.py +++ b/packages/google-auth/tests/transport/test__custom_tls_signer.py @@ -22,12 +22,6 @@ from google.auth import exceptions from google.auth.transport import _custom_tls_signer -urllib3_pyopenssl = pytest.importorskip( - "urllib3.contrib.pyopenssl", - reason="urllib3.contrib.pyopenssl not available in this environment", -) - -urllib3_pyopenssl.inject_into_urllib3() FAKE_ENTERPRISE_CERT_FILE_PATH = "/path/to/enterprise/cert/file" ENTERPRISE_CERT_FILE = os.path.join( diff --git a/packages/google-auth/tests/transport/test__mtls_helper.py b/packages/google-auth/tests/transport/test__mtls_helper.py index 078df67470d2..ae34107f7ac6 100644 --- a/packages/google-auth/tests/transport/test__mtls_helper.py +++ b/packages/google-auth/tests/transport/test__mtls_helper.py @@ -14,28 +14,35 @@ import os import re +import sys +import tempfile from unittest import mock -from OpenSSL import crypto +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec import pytest # type: ignore from google.auth import environment_vars, exceptions from google.auth.transport import _mtls_helper +if not hasattr(os, "MFD_CLOEXEC"): + setattr(os, "MFD_CLOEXEC", 1) + CERT_MOCK_VAL = b"cert" KEY_MOCK_VAL = b"key" CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]} ENCRYPTED_EC_PRIVATE_KEY = b"""-----BEGIN ENCRYPTED PRIVATE KEY----- -MIHkME8GCSqGSIb3DQEFDTBCMCkGCSqGSIb3DQEFDDAcBAgl2/yVgs1h3QICCAAw -DAYIKoZIhvcNAgkFADAVBgkrBgEEAZdVAQIECJk2GRrvxOaJBIGQXIBnMU4wmciT -uA6yD8q0FxuIzjG7E2S6tc5VRgSbhRB00eBO3jWmO2pBybeQW+zVioDcn50zp2ts -wYErWC+LCm1Zg3r+EGnT1E1GgNoODbVQ3AEHlKh1CGCYhEovxtn3G+Fjh7xOBrNB -saVVeDb4tHD4tMkiVVUBrUcTZPndP73CtgyGHYEphasYPzEz3+AU +MIH0MF8GCSqGSIb3DQEFDTBSMDEGCSqGSIb3DQEFDDAkBBClWcQyUELNC9Hjr+Sp +WK85AgIIADAMBggqhkiG9w0CCQUAMB0GCWCGSAFlAwQBKgQQ6uJeoqE7P9HtxAgS +n6rBFgSBkMRDYXLucNp7ew7LbQmkZCmjnRhgyw6b0dD3eK8f3jisj8UiR8aj9a2S +1FZiNHKLmI7hkZHH+d2DPWYhe/tf5SS4iLzpZogBehMv4UDNnNaj0dvQZgpnpciK +1H+0u/i+crc1WAGlemLAi7dktCCBTzeX19cRMGHie68rx1C82LHLZmefr7AEIVxp +uUoJ+sLhBw== -----END ENCRYPTED PRIVATE KEY-----""" EC_PUBLIC_KEY = b"""-----BEGIN PUBLIC KEY----- -MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEvCNi1NoDY1oMqPHIgXI8RBbTYGi/ -brEjbre1nSiQW11xRTJbVeETdsuP0EAu2tG3PcRhhwDfeJ8zXREgTBurNw== +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwdsHzL05VUmqYJat2yGdbSHQAg49 +Wc+fhwLH3b+SCC/2/TqPNDy9yMdMxMtEfZfKal2EaeE2erJrtu7WNfjD0Q== -----END PUBLIC KEY-----""" PASSPHRASE = b"""-----BEGIN PASSPHRASE----- @@ -757,17 +764,15 @@ def test_success(self): decrypted_key = _mtls_helper.decrypt_private_key( ENCRYPTED_EC_PRIVATE_KEY, PASSPHRASE_VALUE ) - private_key = crypto.load_privatekey(crypto.FILETYPE_PEM, decrypted_key) - public_key = crypto.load_publickey(crypto.FILETYPE_PEM, EC_PUBLIC_KEY) - x509 = crypto.X509() - x509.set_pubkey(public_key) + private_key = serialization.load_pem_private_key(decrypted_key, password=None) + public_key = serialization.load_pem_public_key(EC_PUBLIC_KEY) # Test the decrypted key works by signing and verification. - signature = crypto.sign(private_key, b"data", "sha256") - crypto.verify(x509, signature, b"data", "sha256") + signature = private_key.sign(b"data", ec.ECDSA(hashes.SHA256())) + public_key.verify(signature, b"data", ec.ECDSA(hashes.SHA256())) def test_crypto_error(self): - with pytest.raises(crypto.Error): + with pytest.raises(ValueError): _mtls_helper.decrypt_private_key( ENCRYPTED_EC_PRIVATE_KEY, b"wrong_password" ) @@ -992,3 +997,359 @@ def test_call_client_cert_callback(self, mock_get_client_ssl_credentials): mock_get_client_ssl_credentials.assert_called_once_with( generate_encrypted_key=True ) + + +class TestSecureCertKeyPaths(object): + def test_tier1_pass_through(self): + with _mtls_helper.secure_cert_key_paths( + "/path/to/cert", "/path/to/key", b"passphrase" + ) as (cert_path, key_path, passphrase): + assert cert_path == "/path/to/cert" + assert key_path == "/path/to/key" + assert passphrase == b"passphrase" + + @mock.patch.object(sys, "platform", "linux") + @mock.patch.object(os, "memfd_create", create=True) + @mock.patch.object(_mtls_helper, "_memfd_cert_key_paths", autospec=True) + def test_tier2_memfd_success(self, mock_memfd_cm, mock_memfd_create): + mock_memfd_ctx = mock.MagicMock() + mock_memfd_ctx.__enter__.return_value = ( + "/proc/self/fd/3", + "/proc/self/fd/4", + ) + mock_memfd_cm.return_value = mock_memfd_ctx + + with mock.patch.object(os.path, "exists", return_value=True): + with _mtls_helper.secure_cert_key_paths( + pytest.public_cert_bytes, + pytest.private_key_bytes, + b"passphrase", + ) as (cert_path, key_path, passphrase): + assert cert_path == "/proc/self/fd/3" + assert key_path == "/proc/self/fd/4" + assert passphrase == b"passphrase" + assert mock_memfd_ctx.__exit__.called + + @mock.patch.object(sys, "platform", "linux") + @mock.patch.object(os, "memfd_create", create=True) + @mock.patch.object(_mtls_helper, "_memfd_cert_key_paths", autospec=True) + @mock.patch.object(_mtls_helper, "_tempfile_cert_key_paths", autospec=True) + def test_tier2_restricted_filesystem( + self, mock_tempfile_cm, mock_memfd_cm, mock_memfd_create + ): + mock_memfd_ctx = mock.MagicMock() + mock_memfd_ctx.__enter__.return_value = ( + "/proc/self/fd/3", + "/proc/self/fd/4", + ) + mock_memfd_cm.return_value = mock_memfd_ctx + + mock_tempfile_ctx = mock.MagicMock() + mock_tempfile_ctx.__enter__.return_value = ( + "/tmp/cert", + "/tmp/key", + b"new_pass", + ) + mock_tempfile_cm.return_value = mock_tempfile_ctx + + with mock.patch.object(os.path, "exists", return_value=False): + with _mtls_helper.secure_cert_key_paths( + pytest.public_cert_bytes, pytest.private_key_bytes, b"passphrase" + ) as (cert_path, key_path, passphrase): + assert cert_path == "/tmp/cert" + assert key_path == "/tmp/key" + assert passphrase == b"new_pass" + mock_memfd_ctx.__exit__.assert_called_once_with(None, None, None) + + @mock.patch.object(sys, "platform", "linux") + @mock.patch.object(os, "memfd_create", create=True) + @mock.patch.object(_mtls_helper, "_memfd_cert_key_paths", autospec=True) + @mock.patch.object(_mtls_helper, "_tempfile_cert_key_paths", autospec=True) + def test_tier2_fallback_to_tier3_on_oserror( + self, mock_tempfile_cm, mock_memfd_cm, mock_memfd_create + ): + mock_memfd_ctx = mock.MagicMock() + mock_memfd_ctx.__enter__.side_effect = OSError("memfd failed") + mock_memfd_cm.return_value = mock_memfd_ctx + + mock_tempfile_ctx = mock.MagicMock() + mock_tempfile_ctx.__enter__.return_value = ( + "/tmp/cert", + "/tmp/key", + b"new_pass", + ) + mock_tempfile_cm.return_value = mock_tempfile_ctx + + with _mtls_helper.secure_cert_key_paths( + pytest.public_cert_bytes, pytest.private_key_bytes, b"passphrase" + ) as (cert_path, key_path, passphrase): + assert cert_path == "/tmp/cert" + assert key_path == "/tmp/key" + assert passphrase == b"new_pass" + + @mock.patch.object(sys, "platform", "darwin") + @mock.patch.object(_mtls_helper, "_tempfile_cert_key_paths", autospec=True) + def test_tier3_tempfile_success_non_linux(self, mock_tempfile_cm): + mock_tempfile_ctx = mock.MagicMock() + mock_tempfile_ctx.__enter__.return_value = ( + "/tmp/cert", + "/tmp/key", + b"new_pass", + ) + mock_tempfile_cm.return_value = mock_tempfile_ctx + + with _mtls_helper.secure_cert_key_paths( + pytest.public_cert_bytes, pytest.private_key_bytes, b"passphrase" + ) as (cert_path, key_path, passphrase): + assert cert_path == "/tmp/cert" + assert key_path == "/tmp/key" + assert passphrase == b"new_pass" + + @mock.patch.object(sys, "platform", "darwin") + @mock.patch.object(_mtls_helper, "_tempfile_cert_key_paths", autospec=True) + def test_hybrid_inputs(self, mock_tempfile_cm): + mock_tempfile_ctx = mock.MagicMock() + mock_tempfile_ctx.__enter__.return_value = ( + None, + "/tmp/key", + b"new_pass", + ) + mock_tempfile_cm.return_value = mock_tempfile_ctx + + with _mtls_helper.secure_cert_key_paths( + "/pass/through/cert.pem", pytest.private_key_bytes, b"passphrase" + ) as (cert_path, key_path, passphrase): + assert cert_path == "/pass/through/cert.pem" + assert key_path == "/tmp/key" + assert passphrase == b"new_pass" + + +class TestMemfdCertKeyPaths(object): + @mock.patch.object(os, "memfd_create", create=True) + @mock.patch.object(os, "fdopen") + @mock.patch.object(os, "close") + def test_success_both_bytes(self, mock_close, mock_fdopen, mock_memfd_create): + mock_memfd_create.side_effect = [10, 11] + mock_file_cert = mock.MagicMock() + mock_file_cert.__enter__.return_value = mock_file_cert + mock_file_key = mock.MagicMock() + mock_file_key.__enter__.return_value = mock_file_key + mock_fdopen.side_effect = [mock_file_cert, mock_file_key] + with _mtls_helper._memfd_cert_key_paths(b"cert", b"key") as ( + cert_path, + key_path, + ): + assert cert_path == "/proc/self/fd/10" + assert key_path == "/proc/self/fd/11" + mock_fdopen.assert_has_calls( + [mock.call(10, "wb", closefd=False), mock.call(11, "wb", closefd=False)] + ) + mock_file_cert.write.assert_called_once_with(b"cert") + mock_file_key.write.assert_called_once_with(b"key") + assert mock_close.call_count == 2 + + @mock.patch.object(os, "memfd_create", create=True) + @mock.patch.object(os, "fdopen") + @mock.patch.object(os, "close") + def test_close_ignores_oserror(self, mock_close, mock_fdopen, mock_memfd_create): + mock_memfd_create.return_value = 12 + mock_close.side_effect = OSError("close error") + mock_file = mock.MagicMock() + mock_file.__enter__.return_value = mock_file + mock_fdopen.return_value = mock_file + with _mtls_helper._memfd_cert_key_paths(b"cert", None) as (cert_path, key_path): + assert cert_path == "/proc/self/fd/12" + assert key_path is None + mock_fdopen.assert_called_once_with(12, "wb", closefd=False) + mock_file.write.assert_called_once_with(b"cert") + mock_close.assert_called_once_with(12) + + @mock.patch.object(os, "memfd_create", create=True) + @mock.patch.object(os, "fdopen") + @mock.patch.object(os, "close") + def test_write_oserror_prevents_fd_leak( + self, mock_close, mock_fdopen, mock_memfd_create + ): + mock_memfd_create.return_value = 15 + mock_file = mock.MagicMock() + mock_file.__enter__.return_value = mock_file + mock_file.write.side_effect = OSError("write fault") + mock_fdopen.return_value = mock_file + with pytest.raises(OSError): + with _mtls_helper._memfd_cert_key_paths(b"cert", None): + pass + mock_fdopen.assert_called_once_with(15, "wb", closefd=False) + mock_file.write.assert_called_once_with(b"cert") + mock_close.assert_called_once_with(15) + + +class TestTempfileCertKeyPaths(object): + @mock.patch.object(os.path, "isdir", return_value=True) + @mock.patch.object(tempfile, "mkstemp") + @mock.patch.object(os, "fdopen") + @mock.patch.object(_mtls_helper, "_encrypt_key_if_plaintext", autospec=True) + @mock.patch.object(_mtls_helper, "_secure_wipe_and_remove", autospec=True) + def test_success_shm( + self, + mock_wipe, + mock_encrypt, + mock_fdopen, + mock_mkstemp, + mock_isdir, + ): + mock_mkstemp.side_effect = [(1, "/shm/cert"), (2, "/shm/key")] + mock_encrypt.return_value = (b"encrypted_key", b"new_pass") + mock_file = mock.MagicMock() + mock_file.fileno.return_value = 1 + mock_fdopen.return_value.__enter__.return_value = mock_file + + with mock.patch.object(os, "remove") as mock_remove, mock.patch.object( + os.path, "exists", return_value=True + ): + with _mtls_helper._tempfile_cert_key_paths(b"cert", b"key", b"pass") as ( + cert_path, + key_path, + passphrase, + ): + assert cert_path == "/shm/cert" + assert key_path == "/shm/key" + assert passphrase == b"new_pass" + mock_remove.assert_called_once_with("/shm/cert") + + mock_mkstemp.assert_has_calls( + [mock.call(dir="/dev/shm"), mock.call(dir="/dev/shm")] + ) + mock_wipe.assert_called_once_with("/shm/key") + + @mock.patch.object(os.path, "isdir", return_value=True) + @mock.patch.object(tempfile, "mkstemp") + @mock.patch.object(os, "fdopen") + @mock.patch.object(_mtls_helper, "_encrypt_key_if_plaintext", autospec=True) + @mock.patch.object(_mtls_helper, "_secure_wipe_and_remove", autospec=True) + def test_permission_error_loop_resilience( + self, + mock_wipe, + mock_encrypt, + mock_fdopen, + mock_mkstemp, + mock_isdir, + ): + mock_mkstemp.side_effect = [(1, "/shm/cert"), (2, "/shm/key")] + mock_encrypt.return_value = (b"encrypted_key", b"new_pass") + mock_file = mock.MagicMock() + mock_file.fileno.return_value = 1 + mock_fdopen.return_value.__enter__.return_value = mock_file + + mock_wipe.side_effect = PermissionError("lock error") + + with mock.patch.object(os, "remove") as mock_remove, mock.patch.object( + os.path, "exists", return_value=True + ): + with _mtls_helper._tempfile_cert_key_paths(b"cert", b"key", b"pass"): + pass + mock_remove.assert_called_once_with("/shm/cert") + + +class TestEncryptKeyIfPlaintext(object): + def test_encrypts_plaintext_key(self): + encrypted_bytes, passphrase = _mtls_helper._encrypt_key_if_plaintext( + pytest.private_key_bytes, b"my_passphrase" + ) + assert passphrase == b"my_passphrase" + assert encrypted_bytes != pytest.private_key_bytes + assert b"ENCRYPTED PRIVATE KEY" in encrypted_bytes + + decrypted = serialization.load_pem_private_key( + encrypted_bytes, password=b"my_passphrase" + ) + assert decrypted + + @mock.patch("secrets.token_hex", return_value="0123456789abcdef0123456789abcdef") + def test_default_passphrase_generation(self, mock_secrets): + encrypted_bytes, passphrase = _mtls_helper._encrypt_key_if_plaintext( + pytest.private_key_bytes, None + ) + assert passphrase == b"0123456789abcdef0123456789abcdef" + assert b"ENCRYPTED PRIVATE KEY" in encrypted_bytes + + def test_returns_encrypted_key_asis(self): + encrypted_bytes, passphrase = _mtls_helper._encrypt_key_if_plaintext( + ENCRYPTED_EC_PRIVATE_KEY, b"passphrase" + ) + assert encrypted_bytes == ENCRYPTED_EC_PRIVATE_KEY + assert passphrase == b"passphrase" + + def test_returns_invalid_key_asis(self): + invalid_bytes = b"not a valid key" + encrypted_bytes, passphrase = _mtls_helper._encrypt_key_if_plaintext( + invalid_bytes, b"passphrase" + ) + assert encrypted_bytes == invalid_bytes + assert passphrase == b"passphrase" + + +class TestSecureWipeAndRemove(object): + @mock.patch.object(os.path, "exists", return_value=True) + @mock.patch.object(os.path, "getsize", return_value=10) + @mock.patch("builtins.open", autospec=True) + @mock.patch.object(os, "fsync") + @mock.patch.object(os, "remove") + def test_success( + self, mock_remove, mock_fsync, mock_open, mock_getsize, mock_exists + ): + mock_fh = mock.MagicMock() + mock_fh.fileno.return_value = 1 + mock_open.return_value.__enter__.return_value = mock_fh + + _mtls_helper._secure_wipe_and_remove("/path/to/secret") + + mock_open.assert_called_once_with("/path/to/secret", "r+b") + mock_fh.write.assert_called_once_with(b"\0" * 10) + mock_fsync.assert_called_once() + mock_remove.assert_called_once_with("/path/to/secret") + + @mock.patch.object(os.path, "exists", return_value=False) + @mock.patch.object(os, "remove") + def test_file_not_found(self, mock_remove, mock_exists): + _mtls_helper._secure_wipe_and_remove("/path/to/nonexistent") + + mock_exists.assert_called_once_with("/path/to/nonexistent") + mock_remove.assert_not_called() + + @mock.patch.object(os.path, "exists", return_value=True) + @mock.patch.object(os.path, "getsize", return_value=10) + @mock.patch("builtins.open", autospec=True) + @mock.patch.object(os, "fsync") + @mock.patch.object(os, "remove") + def test_write_oserror_ignored( + self, mock_remove, mock_fsync, mock_open, mock_getsize, mock_exists + ): + mock_fh = mock.MagicMock() + mock_fh.fileno.return_value = 1 + mock_fh.write.side_effect = OSError("write fault") + mock_open.return_value.__enter__.return_value = mock_fh + + _mtls_helper._secure_wipe_and_remove("/path/to/secret") + + mock_open.assert_called_once_with("/path/to/secret", "r+b") + mock_fsync.assert_not_called() + mock_remove.assert_called_once_with("/path/to/secret") + + @mock.patch.object(os.path, "exists", return_value=True) + @mock.patch.object(os.path, "getsize", return_value=10) + @mock.patch("builtins.open", autospec=True) + @mock.patch.object(os, "fsync") + @mock.patch.object(os, "remove") + def test_remove_oserror_ignored( + self, mock_remove, mock_fsync, mock_open, mock_getsize, mock_exists + ): + mock_fh = mock.MagicMock() + mock_fh.fileno.return_value = 1 + mock_open.return_value.__enter__.return_value = mock_fh + mock_remove.side_effect = OSError("remove fault") + + _mtls_helper._secure_wipe_and_remove("/path/to/secret") + + mock_open.assert_called_once_with("/path/to/secret", "r+b") + mock_fsync.assert_called_once() + mock_remove.assert_called_once_with("/path/to/secret") diff --git a/packages/google-auth/tests/transport/test_aio_mtls_helper.py b/packages/google-auth/tests/transport/test_aio_mtls_helper.py index bc9cde7d793b..2af155d9ee83 100644 --- a/packages/google-auth/tests/transport/test_aio_mtls_helper.py +++ b/packages/google-auth/tests/transport/test_aio_mtls_helper.py @@ -26,24 +26,6 @@ class TestMTLS: - @pytest.mark.asyncio - async def test__create_temp_file(self): - """Tests that _create_temp_file creates a file with correct content and deletes it.""" - content = b"test cert data" - - # Test file creation and content - with mtls._create_temp_file(content) as file_path: - assert os.path.exists(file_path) - # Verify file is not readable by others (mkstemp default) - if os.name == "posix": - assert (os.stat(file_path).st_mode & 0o777) == 0o600 - - with open(file_path, "rb") as f: - assert f.read() == content - - # Test file deletion after context exit - assert not os.path.exists(file_path) - @pytest.mark.asyncio async def test_make_client_cert_ssl_context_success(self): """Tests successful creation of an SSLContext with client certificates.""" diff --git a/packages/google-auth/tests/transport/test_mtls.py b/packages/google-auth/tests/transport/test_mtls.py index 405cb496cad2..69cbe792575d 100644 --- a/packages/google-auth/tests/transport/test_mtls.py +++ b/packages/google-auth/tests/transport/test_mtls.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import ssl from unittest import mock import pytest # type: ignore @@ -135,6 +137,12 @@ def test_default_client_cert_source( with pytest.raises(exceptions.MutualTLSChannelError): callback() + # Test bad callback which throws ClientCertError. + get_client_cert_and_key.side_effect = exceptions.ClientCertError() + callback = mtls.default_client_cert_source() + with pytest.raises(exceptions.MutualTLSChannelError): + callback() + @mock.patch( "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True @@ -173,3 +181,258 @@ def test_should_use_client_cert(check_use_client_cert): check_use_client_cert.return_value = False assert not mtls.should_use_client_cert() + + +@contextlib.contextmanager +def _fake_secure_paths(cert_bytes, key_bytes, passphrase=None): + yield "cert_path", "key_path", passphrase + + +@mock.patch( + "google.auth.transport._mtls_helper.secure_cert_key_paths", + side_effect=_fake_secure_paths, +) +def test_load_client_cert_into_context_success(mock_secure_paths): + mock_ctx = mock.Mock(spec=ssl.SSLContext) + result = mtls.load_client_cert_into_context( + mock_ctx, b"cert", b"key", passphrase=b"passphrase" + ) + assert result is None + mock_ctx.load_cert_chain.assert_called_once_with( + certfile="cert_path", keyfile="key_path", password=b"passphrase" + ) + + +@mock.patch( + "google.auth.transport._mtls_helper.secure_cert_key_paths", + side_effect=_fake_secure_paths, +) +def test_load_client_cert_into_context_error(mock_secure_paths): + mock_ctx = mock.Mock(spec=ssl.SSLContext) + mock_ctx.load_cert_chain.side_effect = ssl.SSLError("boom") + with pytest.raises(exceptions.MutualTLSChannelError) as exc_info: + mtls.load_client_cert_into_context(mock_ctx, b"cert", b"key") + assert "Failed to load client certificate and key" in str(exc_info.value) + assert isinstance(exc_info.value.__cause__, ssl.SSLError) + + +def test_load_client_cert_into_context_invalid_ctx(): + with pytest.raises(exceptions.MutualTLSChannelError) as exc_info: + mtls.load_client_cert_into_context(None, b"cert", b"key") + assert ( + "The provided context object is invalid or does not support loading certificate chains" + in str(exc_info.value) + ) + assert exc_info.value.__cause__ is None + + +@mock.patch( + "google.auth.transport._mtls_helper.secure_cert_key_paths", + side_effect=_fake_secure_paths, +) +def test_load_client_cert_into_context_load_chain_type_error(mock_secure_paths): + mock_ctx = mock.Mock(spec=ssl.SSLContext) + mock_ctx.load_cert_chain.side_effect = TypeError("invalid password type") + with pytest.raises(TypeError) as exc_info: + mtls.load_client_cert_into_context(mock_ctx, b"cert", b"key") + assert "invalid password type" in str(exc_info.value) + + +@mock.patch("google.auth.transport.mtls.load_client_cert_into_context", autospec=True) +@mock.patch("ssl.create_default_context", autospec=True) +def test_make_client_cert_ssl_context(mock_create_context, mock_load_cert): + mock_ctx = mock.Mock(spec=ssl.SSLContext) + mock_create_context.return_value = mock_ctx + + result = mtls.make_client_cert_ssl_context(b"cert", b"key", b"passphrase") + + assert result == mock_ctx + mock_create_context.assert_called_once_with(ssl.Purpose.SERVER_AUTH) + mock_load_cert.assert_called_once_with( + mock_ctx, b"cert", b"key", passphrase=b"passphrase" + ) + + +@mock.patch("google.auth.transport.mtls.should_use_client_cert", autospec=True) +def test_load_default_client_cert_disabled(mock_should_use): + mock_should_use.return_value = False + mock_ctx = mock.Mock(spec=ssl.SSLContext) + assert mtls.load_default_client_cert(mock_ctx) is False + mock_ctx.load_cert_chain.assert_not_called() + + +@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True) +@mock.patch("google.auth.transport.mtls.should_use_client_cert", autospec=True) +def test_load_default_client_cert_no_source(mock_should_use, mock_has_source): + mock_should_use.return_value = True + mock_has_source.return_value = False + mock_ctx = mock.Mock(spec=ssl.SSLContext) + assert mtls.load_default_client_cert(mock_ctx) is False + mock_ctx.load_cert_chain.assert_not_called() + + +@mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True +) +@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True) +@mock.patch("google.auth.transport.mtls.should_use_client_cert", autospec=True) +def test_load_default_client_cert_no_cert( + mock_should_use, mock_has_source, mock_get_credentials +): + mock_should_use.return_value = True + mock_has_source.return_value = True + mock_get_credentials.return_value = (False, None, None, None) + mock_ctx = mock.Mock(spec=ssl.SSLContext) + assert mtls.load_default_client_cert(mock_ctx) is False + mock_ctx.load_cert_chain.assert_not_called() + + +@mock.patch("google.auth.transport.mtls.load_client_cert_into_context", autospec=True) +@mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True +) +@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True) +@mock.patch("google.auth.transport.mtls.should_use_client_cert", autospec=True) +def test_load_default_client_cert_success( + mock_should_use, mock_has_source, mock_get_credentials, mock_load_cert +): + mock_should_use.return_value = True + mock_has_source.return_value = True + mock_get_credentials.return_value = (True, b"cert", b"key", b"passphrase") + mock_ctx = mock.Mock(spec=ssl.SSLContext) + + assert mtls.load_default_client_cert(mock_ctx) is True + mock_load_cert.assert_called_once_with(mock_ctx, b"cert", b"key", b"passphrase") + + +@mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True +) +@mock.patch("google.auth.transport.mtls.has_default_client_cert_source", autospec=True) +@mock.patch("google.auth.transport.mtls.should_use_client_cert", autospec=True) +def test_load_default_client_cert_propagates_client_cert_error( + mock_should_use, mock_has_source, mock_get_credentials +): + mock_should_use.return_value = True + mock_has_source.return_value = True + mock_get_credentials.side_effect = exceptions.ClientCertError("credentials failure") + mock_ctx = mock.Mock(spec=ssl.SSLContext) + + with pytest.raises(exceptions.ClientCertError) as exc_info: + mtls.load_default_client_cert(mock_ctx) + assert "credentials failure" in str(exc_info.value) + + +@mock.patch("google.auth.transport.mtls.load_default_client_cert", autospec=True) +@mock.patch("ssl.create_default_context", autospec=True) +def test_get_default_ssl_context_configured(mock_create_context, mock_load_default): + mock_ctx = mock.Mock(spec=ssl.SSLContext) + mock_create_context.return_value = mock_ctx + mock_load_default.return_value = True + + result = mtls.get_default_ssl_context() + + assert result == mock_ctx + mock_create_context.assert_called_once_with(ssl.Purpose.SERVER_AUTH) + mock_load_default.assert_called_once_with(mock_ctx) + + +@mock.patch("google.auth.transport.mtls.load_default_client_cert", autospec=True) +@mock.patch("ssl.create_default_context", autospec=True) +def test_get_default_ssl_context_unconfigured(mock_create_context, mock_load_default): + mock_ctx = mock.Mock(spec=ssl.SSLContext) + mock_create_context.return_value = mock_ctx + mock_load_default.return_value = False + + result = mtls.get_default_ssl_context() + + assert result is None + mock_create_context.assert_called_once_with(ssl.Purpose.SERVER_AUTH) + mock_load_default.assert_called_once_with(mock_ctx) + + +@pytest.mark.parametrize( + "env_val,client_cert_available,expected", + [ + ("always", True, True), + ("always", False, True), + ("never", True, False), + ("never", False, False), + ("auto", True, True), + ("auto", False, False), + (None, True, True), # Defaults to auto + (None, False, False), # Defaults to auto + ("ALWAYS", True, True), + ("ALWAYS", False, True), + ("NEVER", True, False), + ("NEVER", False, False), + ("AUTO", True, True), + ("AUTO", False, False), + ("invalid_val", True, True), + ("invalid_val", False, False), + ], +) +@mock.patch( + "google.auth.environment_vars.GOOGLE_API_USE_MTLS_ENDPOINT", + "GOOGLE_API_USE_MTLS_ENDPOINT", +) +@mock.patch("google.auth.transport.mtls.getenv", autospec=True) +def test_should_use_mtls_endpoint( + mock_getenv, env_val, client_cert_available, expected +): + mock_getenv.side_effect = ( + lambda var, default=None: env_val + if (var == "GOOGLE_API_USE_MTLS_ENDPOINT" and env_val is not None) + else default + ) + result = mtls.should_use_mtls_endpoint(client_cert_available) + assert result == expected + + +@mock.patch( + "google.auth.environment_vars.GOOGLE_API_USE_MTLS_ENDPOINT", + "GOOGLE_API_USE_MTLS_ENDPOINT", +) +@mock.patch("google.auth.transport.mtls.getenv", autospec=True) +def test_should_use_mtls_endpoint_invalid_value(mock_getenv, caplog): + mock_getenv.side_effect = ( + lambda var, default=None: "invalid_value" + if var == "GOOGLE_API_USE_MTLS_ENDPOINT" + else default + ) + with caplog.at_level("WARNING"): + assert mtls.should_use_mtls_endpoint(True) is True + assert "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value" in caplog.text + assert "Defaulting to auto" in caplog.text + + caplog.clear() + + with caplog.at_level("WARNING"): + assert mtls.should_use_mtls_endpoint(False) is False + assert "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value" in caplog.text + assert "Defaulting to auto" in caplog.text + + +@mock.patch( + "google.auth.environment_vars.GOOGLE_API_USE_MTLS_ENDPOINT", + "GOOGLE_API_USE_MTLS_ENDPOINT", +) +@mock.patch("google.auth.transport.mtls.should_use_client_cert", autospec=True) +@mock.patch("google.auth.transport.mtls.getenv", autospec=True) +def test_should_use_mtls_endpoint_default_client_cert( + mock_getenv, mock_should_use_client_cert +): + mock_getenv.side_effect = ( + lambda var, default=None: "auto" + if var == "GOOGLE_API_USE_MTLS_ENDPOINT" + else default + ) + mock_should_use_client_cert.return_value = True + assert mtls.should_use_mtls_endpoint() is True + mock_should_use_client_cert.assert_called_once() + + mock_should_use_client_cert.reset_mock() + + mock_should_use_client_cert.return_value = False + assert mtls.should_use_mtls_endpoint() is False + mock_should_use_client_cert.assert_called_once() diff --git a/packages/google-auth/tests/transport/test_requests.py b/packages/google-auth/tests/transport/test_requests.py index c9fab036e17b..972379159f29 100644 --- a/packages/google-auth/tests/transport/test_requests.py +++ b/packages/google-auth/tests/transport/test_requests.py @@ -16,11 +16,9 @@ import functools import http.client as http_client import os -import sys from unittest import mock import freezegun -import OpenSSL import pytest # type: ignore import requests import requests.adapters @@ -192,18 +190,11 @@ def test_success(self, mock_proxy_manager_for, mock_init_poolmanager): mock_proxy_manager_for.assert_called_with(ssl_context=adapter._ctx_proxymanager) def test_invalid_cert_or_key(self): - with pytest.raises(OpenSSL.crypto.Error): + with pytest.raises(exceptions.MutualTLSChannelError): google.auth.transport.requests._MutualTlsAdapter( b"invalid cert", b"invalid key" ) - @mock.patch.dict("sys.modules", {"OpenSSL.crypto": None}) - def test_import_error(self): - with pytest.raises(ImportError): - google.auth.transport.requests._MutualTlsAdapter( - pytest.public_cert_bytes, pytest.private_key_bytes - ) - def make_response(status=http_client.OK, data=None): response = requests.Response() @@ -491,9 +482,29 @@ def test_configure_mtls_channel_exceptions(self, mock_get_client_cert_and_key): ): auth_session.configure_mtls_channel() - mock_get_client_cert_and_key.return_value = (False, None, None) - with mock.patch.dict("sys.modules"): - sys.modules["OpenSSL"] = None + @mock.patch( + "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True + ) + @mock.patch("google.auth.transport.requests.create_urllib3_context", autospec=True) + def test_configure_mtls_channel_cert_loading_exceptions( + self, mock_create_urllib3_context, mock_get_client_cert_and_key + ): + import ssl + + mock_get_client_cert_and_key.return_value = ( + True, + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + for exception_type in [ValueError("error"), ssl.SSLError("error")]: + mock_ctx = mock.Mock() + mock_ctx.load_cert_chain.side_effect = exception_type + mock_create_urllib3_context.return_value = mock_ctx + + auth_session = google.auth.transport.requests.AuthorizedSession( + credentials=mock.Mock() + ) with pytest.raises(exceptions.MutualTLSChannelError): with mock.patch.dict( os.environ, @@ -501,6 +512,8 @@ def test_configure_mtls_channel_exceptions(self, mock_get_client_cert_and_key): ): auth_session.configure_mtls_channel() + assert not auth_session.is_mtls + @mock.patch( "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True ) diff --git a/packages/google-auth/tests/transport/test_urllib3.py b/packages/google-auth/tests/transport/test_urllib3.py index b29e4e950433..99c359ad398d 100644 --- a/packages/google-auth/tests/transport/test_urllib3.py +++ b/packages/google-auth/tests/transport/test_urllib3.py @@ -14,10 +14,8 @@ import http.client as http_client import os -import sys from unittest import mock -import OpenSSL import pytest # type: ignore import urllib3 # type: ignore @@ -103,18 +101,11 @@ def test_success(self): assert isinstance(http, urllib3.PoolManager) def test_crypto_error(self): - with pytest.raises(OpenSSL.crypto.Error): + with pytest.raises(exceptions.MutualTLSChannelError): google.auth.transport.urllib3._make_mutual_tls_http( b"invalid cert", b"invalid key" ) - @mock.patch.dict("sys.modules", {"OpenSSL.crypto": None}) - def test_import_error(self): - with pytest.raises(ImportError): - google.auth.transport.urllib3._make_mutual_tls_http( - pytest.public_cert_bytes, pytest.private_key_bytes - ) - class TestAuthorizedHttp(object): TEST_URL = "http://example.com" @@ -280,9 +271,33 @@ def test_configure_mtls_channel_exceptions(self, mock_get_client_cert_and_key): ): authed_http.configure_mtls_channel() - mock_get_client_cert_and_key.return_value = (False, None, None) - with mock.patch.dict("sys.modules"): - sys.modules["OpenSSL"] = None + @mock.patch( + "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True + ) + @mock.patch( + "google.auth.transport.urllib3.urllib3.util.ssl_.create_urllib3_context", + autospec=True, + ) + def test_configure_mtls_channel_cert_loading_exceptions( + self, mock_create_urllib3_context, mock_get_client_cert_and_key + ): + import ssl + + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials=mock.Mock() + ) + + mock_get_client_cert_and_key.return_value = ( + True, + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + for exception_type in [ValueError("error"), ssl.SSLError("error")]: + mock_ctx = mock.Mock() + mock_ctx.load_cert_chain.side_effect = exception_type + mock_create_urllib3_context.return_value = mock_ctx + with pytest.raises(exceptions.MutualTLSChannelError): with mock.patch.dict( os.environ, @@ -290,6 +305,8 @@ def test_configure_mtls_channel_exceptions(self, mock_get_client_cert_and_key): ): authed_http.configure_mtls_channel() + assert not authed_http._is_mtls + @mock.patch( "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True )