Add SSL Cipher option to aiohttp async_get_clientsession (#126317)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
starkillerOG 2024-09-24 21:31:52 +02:00 committed by GitHub
parent b370893e58
commit 69ecdda5f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 164 additions and 68 deletions

View file

@ -32,11 +32,11 @@ if TYPE_CHECKING:
from aiohttp.typedefs import JSONDecoder
DATA_CONNECTOR: HassKey[dict[tuple[bool, int], aiohttp.BaseConnector]] = HassKey(
DATA_CONNECTOR: HassKey[dict[tuple[bool, int, str], aiohttp.BaseConnector]] = HassKey(
"aiohttp_connector"
)
DATA_CLIENTSESSION: HassKey[dict[tuple[bool, int], aiohttp.ClientSession]] = HassKey(
"aiohttp_clientsession"
DATA_CLIENTSESSION: HassKey[dict[tuple[bool, int, str], aiohttp.ClientSession]] = (
HassKey("aiohttp_clientsession")
)
SERVER_SOFTWARE = (
@ -86,12 +86,13 @@ def async_get_clientsession(
hass: HomeAssistant,
verify_ssl: bool = True,
family: socket.AddressFamily = socket.AF_UNSPEC,
ssl_cipher: ssl_util.SSLCipherList = ssl_util.SSLCipherList.PYTHON_DEFAULT,
) -> aiohttp.ClientSession:
"""Return default aiohttp ClientSession.
This method must be run in the event loop.
"""
session_key = _make_key(verify_ssl, family)
session_key = _make_key(verify_ssl, family, ssl_cipher)
sessions = hass.data.setdefault(DATA_CLIENTSESSION, {})
if session_key not in sessions:
@ -100,6 +101,7 @@ def async_get_clientsession(
verify_ssl,
auto_cleanup_method=_async_register_default_clientsession_shutdown,
family=family,
ssl_cipher=ssl_cipher,
)
sessions[session_key] = session
else:
@ -115,6 +117,7 @@ def async_create_clientsession(
verify_ssl: bool = True,
auto_cleanup: bool = True,
family: socket.AddressFamily = socket.AF_UNSPEC,
ssl_cipher: ssl_util.SSLCipherList = ssl_util.SSLCipherList.PYTHON_DEFAULT,
**kwargs: Any,
) -> aiohttp.ClientSession:
"""Create a new ClientSession with kwargs, i.e. for cookies.
@ -135,6 +138,7 @@ def async_create_clientsession(
verify_ssl,
auto_cleanup_method=auto_cleanup_method,
family=family,
ssl_cipher=ssl_cipher,
**kwargs,
)
@ -146,11 +150,12 @@ def _async_create_clientsession(
auto_cleanup_method: Callable[[HomeAssistant, aiohttp.ClientSession], None]
| None = None,
family: socket.AddressFamily = socket.AF_UNSPEC,
ssl_cipher: ssl_util.SSLCipherList = ssl_util.SSLCipherList.PYTHON_DEFAULT,
**kwargs: Any,
) -> aiohttp.ClientSession:
"""Create a new ClientSession with kwargs, i.e. for cookies."""
clientsession = aiohttp.ClientSession(
connector=_async_get_connector(hass, verify_ssl, family),
connector=_async_get_connector(hass, verify_ssl, family, ssl_cipher),
json_serialize=json_dumps,
response_class=HassClientResponse,
**kwargs,
@ -279,10 +284,12 @@ def _async_register_default_clientsession_shutdown(
@callback
def _make_key(
verify_ssl: bool = True, family: socket.AddressFamily = socket.AF_UNSPEC
) -> tuple[bool, socket.AddressFamily]:
verify_ssl: bool = True,
family: socket.AddressFamily = socket.AF_UNSPEC,
ssl_cipher: ssl_util.SSLCipherList = ssl_util.SSLCipherList.PYTHON_DEFAULT,
) -> tuple[bool, socket.AddressFamily, ssl_util.SSLCipherList]:
"""Make a key for connector or session pool."""
return (verify_ssl, family)
return (verify_ssl, family, ssl_cipher)
class HomeAssistantTCPConnector(aiohttp.TCPConnector):
@ -305,21 +312,22 @@ def _async_get_connector(
hass: HomeAssistant,
verify_ssl: bool = True,
family: socket.AddressFamily = socket.AF_UNSPEC,
ssl_cipher: ssl_util.SSLCipherList = ssl_util.SSLCipherList.PYTHON_DEFAULT,
) -> aiohttp.BaseConnector:
"""Return the connector pool for aiohttp.
This method must be run in the event loop.
"""
connector_key = _make_key(verify_ssl, family)
connector_key = _make_key(verify_ssl, family, ssl_cipher)
connectors = hass.data.setdefault(DATA_CONNECTOR, {})
if connector_key in connectors:
return connectors[connector_key]
if verify_ssl:
ssl_context: SSLContext = ssl_util.get_default_context()
ssl_context: SSLContext = ssl_util.client_context(ssl_cipher)
else:
ssl_context = ssl_util.get_default_no_verify_context()
ssl_context = ssl_util.client_context_no_verify(ssl_cipher)
connector = HomeAssistantTCPConnector(
family=family,

View file

@ -15,6 +15,7 @@ class SSLCipherList(StrEnum):
PYTHON_DEFAULT = "python_default"
INTERMEDIATE = "intermediate"
MODERN = "modern"
INSECURE = "insecure"
SSL_CIPHER_LISTS = {
@ -58,11 +59,12 @@ SSL_CIPHER_LISTS = {
"ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:"
"ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256"
),
SSLCipherList.INSECURE: "DEFAULT:@SECLEVEL=0",
}
@cache
def _create_no_verify_ssl_context(ssl_cipher_list: SSLCipherList) -> ssl.SSLContext:
def _client_context_no_verify(ssl_cipher_list: SSLCipherList) -> ssl.SSLContext:
# This is a copy of aiohttp's create_default_context() function, with the
# ssl verify turned off.
# https://github.com/aio-libs/aiohttp/blob/33953f110e97eecc707e1402daa8d543f38a189b/aiohttp/connector.py#L911
@ -80,16 +82,10 @@ def _create_no_verify_ssl_context(ssl_cipher_list: SSLCipherList) -> ssl.SSLCont
return sslcontext
def create_no_verify_ssl_context(
@cache
def _client_context(
ssl_cipher_list: SSLCipherList = SSLCipherList.PYTHON_DEFAULT,
) -> ssl.SSLContext:
"""Return an SSL context that does not verify the server certificate."""
return _create_no_verify_ssl_context(ssl_cipher_list=ssl_cipher_list)
@cache
def _client_context(ssl_cipher_list: SSLCipherList) -> ssl.SSLContext:
# Reuse environment variable definition from requests, since it's already a
# requirement. If the environment variable has no value, fall back to using
# certs from certifi package.
@ -104,17 +100,19 @@ def _client_context(ssl_cipher_list: SSLCipherList) -> ssl.SSLContext:
return sslcontext
def client_context(
ssl_cipher_list: SSLCipherList = SSLCipherList.PYTHON_DEFAULT,
) -> ssl.SSLContext:
"""Return an SSL context for making requests."""
return _client_context(ssl_cipher_list=ssl_cipher_list)
# Create this only once and reuse it
_DEFAULT_SSL_CONTEXT = client_context()
_DEFAULT_NO_VERIFY_SSL_CONTEXT = create_no_verify_ssl_context()
_DEFAULT_SSL_CONTEXT = _client_context(SSLCipherList.PYTHON_DEFAULT)
_DEFAULT_NO_VERIFY_SSL_CONTEXT = _client_context_no_verify(SSLCipherList.PYTHON_DEFAULT)
_NO_VERIFY_SSL_CONTEXTS = {
SSLCipherList.INTERMEDIATE: _client_context_no_verify(SSLCipherList.INTERMEDIATE),
SSLCipherList.MODERN: _client_context_no_verify(SSLCipherList.MODERN),
SSLCipherList.INSECURE: _client_context_no_verify(SSLCipherList.INSECURE),
}
_SSL_CONTEXTS = {
SSLCipherList.INTERMEDIATE: _client_context(SSLCipherList.INTERMEDIATE),
SSLCipherList.MODERN: _client_context(SSLCipherList.MODERN),
SSLCipherList.INSECURE: _client_context(SSLCipherList.INSECURE),
}
def get_default_context() -> ssl.SSLContext:
@ -127,6 +125,27 @@ def get_default_no_verify_context() -> ssl.SSLContext:
return _DEFAULT_NO_VERIFY_SSL_CONTEXT
def client_context_no_verify(
ssl_cipher_list: SSLCipherList = SSLCipherList.PYTHON_DEFAULT,
) -> ssl.SSLContext:
"""Return a SSL context with no verification with a specific ssl cipher."""
return _NO_VERIFY_SSL_CONTEXTS.get(ssl_cipher_list, _DEFAULT_NO_VERIFY_SSL_CONTEXT)
def client_context(
ssl_cipher_list: SSLCipherList = SSLCipherList.PYTHON_DEFAULT,
) -> ssl.SSLContext:
"""Return an SSL context for making requests."""
return _SSL_CONTEXTS.get(ssl_cipher_list, _DEFAULT_SSL_CONTEXT)
def create_no_verify_ssl_context(
ssl_cipher_list: SSLCipherList = SSLCipherList.PYTHON_DEFAULT,
) -> ssl.SSLContext:
"""Return an SSL context that does not verify the server certificate."""
return _client_context_no_verify(ssl_cipher_list)
def server_context_modern() -> ssl.SSLContext:
"""Return an SSL context following the Mozilla recommendations.

View file

@ -23,6 +23,7 @@ from homeassistant.const import (
from homeassistant.core import HomeAssistant
import homeassistant.helpers.aiohttp_client as client
from homeassistant.util.color import RGBColor
from homeassistant.util.ssl import SSLCipherList
from tests.common import (
MockConfigEntry,
@ -62,11 +63,14 @@ async def test_get_clientsession_with_ssl(hass: HomeAssistant) -> None:
"""Test init clientsession with ssl."""
client.async_get_clientsession(hass)
verify_ssl = True
ssl_cipher = SSLCipherList.PYTHON_DEFAULT
family = 0
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)]
client_session = hass.data[client.DATA_CLIENTSESSION][
(verify_ssl, family, ssl_cipher)
]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family, ssl_cipher)]
assert isinstance(connector, aiohttp.TCPConnector)
@ -74,33 +78,63 @@ async def test_get_clientsession_without_ssl(hass: HomeAssistant) -> None:
"""Test init clientsession without ssl."""
client.async_get_clientsession(hass, verify_ssl=False)
verify_ssl = False
ssl_cipher = SSLCipherList.PYTHON_DEFAULT
family = 0
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)]
client_session = hass.data[client.DATA_CLIENTSESSION][
(verify_ssl, family, ssl_cipher)
]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family, ssl_cipher)]
assert isinstance(connector, aiohttp.TCPConnector)
@pytest.mark.parametrize(
("verify_ssl", "expected_family"),
("verify_ssl", "expected_family", "ssl_cipher"),
[
(True, socket.AF_UNSPEC),
(False, socket.AF_UNSPEC),
(True, socket.AF_INET),
(False, socket.AF_INET),
(True, socket.AF_INET6),
(False, socket.AF_INET6),
(True, socket.AF_UNSPEC, SSLCipherList.PYTHON_DEFAULT),
(True, socket.AF_INET, SSLCipherList.PYTHON_DEFAULT),
(True, socket.AF_INET6, SSLCipherList.PYTHON_DEFAULT),
(True, socket.AF_UNSPEC, SSLCipherList.INTERMEDIATE),
(True, socket.AF_INET, SSLCipherList.INTERMEDIATE),
(True, socket.AF_INET6, SSLCipherList.INTERMEDIATE),
(True, socket.AF_UNSPEC, SSLCipherList.MODERN),
(True, socket.AF_INET, SSLCipherList.MODERN),
(True, socket.AF_INET6, SSLCipherList.MODERN),
(True, socket.AF_UNSPEC, SSLCipherList.INSECURE),
(True, socket.AF_INET, SSLCipherList.INSECURE),
(True, socket.AF_INET6, SSLCipherList.INSECURE),
(False, socket.AF_UNSPEC, SSLCipherList.PYTHON_DEFAULT),
(False, socket.AF_INET, SSLCipherList.PYTHON_DEFAULT),
(False, socket.AF_INET6, SSLCipherList.PYTHON_DEFAULT),
(False, socket.AF_UNSPEC, SSLCipherList.INTERMEDIATE),
(False, socket.AF_INET, SSLCipherList.INTERMEDIATE),
(False, socket.AF_INET6, SSLCipherList.INTERMEDIATE),
(False, socket.AF_UNSPEC, SSLCipherList.MODERN),
(False, socket.AF_INET, SSLCipherList.MODERN),
(False, socket.AF_INET6, SSLCipherList.MODERN),
(False, socket.AF_UNSPEC, SSLCipherList.INSECURE),
(False, socket.AF_INET, SSLCipherList.INSECURE),
(False, socket.AF_INET6, SSLCipherList.INSECURE),
],
)
async def test_get_clientsession(
hass: HomeAssistant, verify_ssl: bool, expected_family: int
hass: HomeAssistant,
verify_ssl: bool,
expected_family: int,
ssl_cipher: SSLCipherList,
) -> None:
"""Test init clientsession combinations."""
client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family)
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)]
client.async_get_clientsession(
hass, verify_ssl=verify_ssl, family=expected_family, ssl_cipher=ssl_cipher
)
client_session = hass.data[client.DATA_CLIENTSESSION][
(verify_ssl, expected_family, ssl_cipher)
]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)]
connector = hass.data[client.DATA_CONNECTOR][
(verify_ssl, expected_family, ssl_cipher)
]
assert isinstance(connector, aiohttp.TCPConnector)
@ -110,10 +144,11 @@ async def test_create_clientsession_with_ssl_and_cookies(hass: HomeAssistant) ->
assert isinstance(session, aiohttp.ClientSession)
verify_ssl = True
ssl_cipher = SSLCipherList.PYTHON_DEFAULT
family = 0
assert client.DATA_CLIENTSESSION not in hass.data
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family, ssl_cipher)]
assert isinstance(connector, aiohttp.TCPConnector)
@ -125,26 +160,61 @@ async def test_create_clientsession_without_ssl_and_cookies(
assert isinstance(session, aiohttp.ClientSession)
verify_ssl = False
ssl_cipher = SSLCipherList.PYTHON_DEFAULT
family = 0
assert client.DATA_CLIENTSESSION not in hass.data
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family)]
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, family, ssl_cipher)]
assert isinstance(connector, aiohttp.TCPConnector)
@pytest.mark.parametrize(
("verify_ssl", "expected_family"),
[(True, 0), (False, 0), (True, 4), (False, 4), (True, 6), (False, 6)],
("verify_ssl", "expected_family", "ssl_cipher"),
[
(True, 0, SSLCipherList.PYTHON_DEFAULT),
(True, 4, SSLCipherList.PYTHON_DEFAULT),
(True, 6, SSLCipherList.PYTHON_DEFAULT),
(True, 0, SSLCipherList.INTERMEDIATE),
(True, 4, SSLCipherList.INTERMEDIATE),
(True, 6, SSLCipherList.INTERMEDIATE),
(True, 0, SSLCipherList.MODERN),
(True, 4, SSLCipherList.MODERN),
(True, 6, SSLCipherList.MODERN),
(True, 0, SSLCipherList.INSECURE),
(True, 4, SSLCipherList.INSECURE),
(True, 6, SSLCipherList.INSECURE),
(False, 0, SSLCipherList.PYTHON_DEFAULT),
(False, 4, SSLCipherList.PYTHON_DEFAULT),
(False, 6, SSLCipherList.PYTHON_DEFAULT),
(False, 0, SSLCipherList.INTERMEDIATE),
(False, 4, SSLCipherList.INTERMEDIATE),
(False, 6, SSLCipherList.INTERMEDIATE),
(False, 0, SSLCipherList.MODERN),
(False, 4, SSLCipherList.MODERN),
(False, 6, SSLCipherList.MODERN),
(False, 0, SSLCipherList.INSECURE),
(False, 4, SSLCipherList.INSECURE),
(False, 6, SSLCipherList.INSECURE),
],
)
async def test_get_clientsession_cleanup(
hass: HomeAssistant, verify_ssl: bool, expected_family: int
hass: HomeAssistant,
verify_ssl: bool,
expected_family: int,
ssl_cipher: SSLCipherList,
) -> None:
"""Test init clientsession cleanup."""
client.async_get_clientsession(hass, verify_ssl=verify_ssl, family=expected_family)
client.async_get_clientsession(
hass, verify_ssl=verify_ssl, family=expected_family, ssl_cipher=ssl_cipher
)
client_session = hass.data[client.DATA_CLIENTSESSION][(verify_ssl, expected_family)]
client_session = hass.data[client.DATA_CLIENTSESSION][
(verify_ssl, expected_family, ssl_cipher)
]
assert isinstance(client_session, aiohttp.ClientSession)
connector = hass.data[client.DATA_CONNECTOR][(verify_ssl, expected_family)]
connector = hass.data[client.DATA_CONNECTOR][
(verify_ssl, expected_family, ssl_cipher)
]
assert isinstance(connector, aiohttp.TCPConnector)
hass.bus.async_fire(EVENT_HOMEASSISTANT_CLOSE)
@ -158,17 +228,19 @@ async def test_get_clientsession_patched_close(hass: HomeAssistant) -> None:
"""Test closing clientsession does not work."""
verify_ssl = True
ssl_cipher = SSLCipherList.PYTHON_DEFAULT
family = 0
with patch("aiohttp.ClientSession.close") as mock_close:
session = client.async_get_clientsession(hass)
assert isinstance(
hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family)],
hass.data[client.DATA_CLIENTSESSION][(verify_ssl, family, ssl_cipher)],
aiohttp.ClientSession,
)
assert isinstance(
hass.data[client.DATA_CONNECTOR][(verify_ssl, family)], aiohttp.TCPConnector
hass.data[client.DATA_CONNECTOR][(verify_ssl, family, ssl_cipher)],
aiohttp.TCPConnector,
)
with pytest.raises(RuntimeError):

View file

@ -5,7 +5,6 @@ from unittest.mock import MagicMock, Mock, patch
import pytest
from homeassistant.util.ssl import (
SSL_CIPHER_LISTS,
SSLCipherList,
client_context,
create_no_verify_ssl_context,
@ -25,14 +24,13 @@ def test_client_context(mock_sslcontext) -> None:
mock_sslcontext.set_ciphers.assert_not_called()
client_context(SSLCipherList.MODERN)
mock_sslcontext.set_ciphers.assert_called_with(
SSL_CIPHER_LISTS[SSLCipherList.MODERN]
)
mock_sslcontext.set_ciphers.assert_not_called()
client_context(SSLCipherList.INTERMEDIATE)
mock_sslcontext.set_ciphers.assert_called_with(
SSL_CIPHER_LISTS[SSLCipherList.INTERMEDIATE]
)
mock_sslcontext.set_ciphers.assert_not_called()
client_context(SSLCipherList.INSECURE)
mock_sslcontext.set_ciphers.assert_not_called()
def test_no_verify_ssl_context(mock_sslcontext) -> None:
@ -42,14 +40,13 @@ def test_no_verify_ssl_context(mock_sslcontext) -> None:
mock_sslcontext.set_ciphers.assert_not_called()
create_no_verify_ssl_context(SSLCipherList.MODERN)
mock_sslcontext.set_ciphers.assert_called_with(
SSL_CIPHER_LISTS[SSLCipherList.MODERN]
)
mock_sslcontext.set_ciphers.assert_not_called()
create_no_verify_ssl_context(SSLCipherList.INTERMEDIATE)
mock_sslcontext.set_ciphers.assert_called_with(
SSL_CIPHER_LISTS[SSLCipherList.INTERMEDIATE]
)
mock_sslcontext.set_ciphers.assert_not_called()
create_no_verify_ssl_context(SSLCipherList.INSECURE)
mock_sslcontext.set_ciphers.assert_not_called()
def test_ssl_context_caching() -> None: