diff --git a/homeassistant/components/snmp/__init__.py b/homeassistant/components/snmp/__init__.py index a4c922877f3..4a049ee1553 100644 --- a/homeassistant/components/snmp/__init__.py +++ b/homeassistant/components/snmp/__init__.py @@ -1 +1,5 @@ """The snmp component.""" + +from .util import async_get_snmp_engine + +__all__ = ["async_get_snmp_engine"] diff --git a/homeassistant/components/snmp/device_tracker.py b/homeassistant/components/snmp/device_tracker.py index 5d4f9e5e0d9..d336838117f 100644 --- a/homeassistant/components/snmp/device_tracker.py +++ b/homeassistant/components/snmp/device_tracker.py @@ -4,14 +4,11 @@ from __future__ import annotations import binascii import logging +from typing import TYPE_CHECKING from pysnmp.error import PySnmpError from pysnmp.hlapi.asyncio import ( CommunityData, - ContextData, - ObjectIdentity, - ObjectType, - SnmpEngine, Udp6TransportTarget, UdpTransportTarget, UsmUserData, @@ -43,6 +40,7 @@ from .const import ( DEFAULT_VERSION, SNMP_VERSIONS, ) +from .util import RequestArgsType, async_create_request_cmd_args _LOGGER = logging.getLogger(__name__) @@ -62,7 +60,7 @@ async def async_get_scanner( ) -> SnmpScanner | None: """Validate the configuration and return an SNMP scanner.""" scanner = SnmpScanner(config[DOMAIN]) - await scanner.async_init() + await scanner.async_init(hass) return scanner if scanner.success_init else None @@ -99,33 +97,29 @@ class SnmpScanner(DeviceScanner): if not privkey: privproto = "none" - request_args = [ - SnmpEngine(), - UsmUserData( - community, - authKey=authkey or None, - privKey=privkey or None, - authProtocol=authproto, - privProtocol=privproto, - ), - target, - ContextData(), - ] + self._auth_data = UsmUserData( + community, + authKey=authkey or None, + privKey=privkey or None, + authProtocol=authproto, + privProtocol=privproto, + ) else: - request_args = [ - SnmpEngine(), - CommunityData(community, mpModel=SNMP_VERSIONS[DEFAULT_VERSION]), - target, - ContextData(), - ] + self._auth_data = CommunityData( + community, mpModel=SNMP_VERSIONS[DEFAULT_VERSION] + ) - self.request_args = request_args + self._target = target + self.request_args: RequestArgsType | None = None self.baseoid = baseoid self.last_results = [] self.success_init = False - async def async_init(self): + async def async_init(self, hass: HomeAssistant) -> None: """Make a one-off read to check if the target device is reachable and readable.""" + self.request_args = await async_create_request_cmd_args( + hass, self._auth_data, self._target, self.baseoid + ) data = await self.async_get_snmp_data() self.success_init = data is not None @@ -156,12 +150,18 @@ class SnmpScanner(DeviceScanner): async def async_get_snmp_data(self): """Fetch MAC addresses from access point via SNMP.""" devices = [] + if TYPE_CHECKING: + assert self.request_args is not None + engine, auth_data, target, context_data, object_type = self.request_args walker = bulkWalkCmd( - *self.request_args, + engine, + auth_data, + target, + context_data, 0, 50, - ObjectType(ObjectIdentity(self.baseoid)), + object_type, lexicographicMode=False, ) async for errindication, errstatus, errindex, res in walker: diff --git a/homeassistant/components/snmp/sensor.py b/homeassistant/components/snmp/sensor.py index 939cb13ae35..0e5b215dcd4 100644 --- a/homeassistant/components/snmp/sensor.py +++ b/homeassistant/components/snmp/sensor.py @@ -11,10 +11,6 @@ from pysnmp.error import PySnmpError import pysnmp.hlapi.asyncio as hlapi from pysnmp.hlapi.asyncio import ( CommunityData, - ContextData, - ObjectIdentity, - ObjectType, - SnmpEngine, Udp6TransportTarget, UdpTransportTarget, UsmUserData, @@ -71,6 +67,7 @@ from .const import ( MAP_PRIV_PROTOCOLS, SNMP_VERSIONS, ) +from .util import async_create_request_cmd_args _LOGGER = logging.getLogger(__name__) @@ -119,7 +116,7 @@ async def async_setup_platform( host = config.get(CONF_HOST) port = config.get(CONF_PORT) community = config.get(CONF_COMMUNITY) - baseoid = config.get(CONF_BASEOID) + baseoid: str = config[CONF_BASEOID] version = config[CONF_VERSION] username = config.get(CONF_USERNAME) authkey = config.get(CONF_AUTH_KEY) @@ -145,27 +142,18 @@ async def async_setup_platform( authproto = "none" if not privkey: privproto = "none" - - request_args = [ - SnmpEngine(), - UsmUserData( - username, - authKey=authkey or None, - privKey=privkey or None, - authProtocol=getattr(hlapi, MAP_AUTH_PROTOCOLS[authproto]), - privProtocol=getattr(hlapi, MAP_PRIV_PROTOCOLS[privproto]), - ), - target, - ContextData(), - ] + auth_data = UsmUserData( + username, + authKey=authkey or None, + privKey=privkey or None, + authProtocol=getattr(hlapi, MAP_AUTH_PROTOCOLS[authproto]), + privProtocol=getattr(hlapi, MAP_PRIV_PROTOCOLS[privproto]), + ) else: - request_args = [ - SnmpEngine(), - CommunityData(community, mpModel=SNMP_VERSIONS[version]), - target, - ContextData(), - ] - get_result = await getCmd(*request_args, ObjectType(ObjectIdentity(baseoid))) + auth_data = CommunityData(community, mpModel=SNMP_VERSIONS[version]) + + request_args = await async_create_request_cmd_args(hass, auth_data, target, baseoid) + get_result = await getCmd(*request_args) errindication, _, _, _ = get_result if errindication and not accept_errors: @@ -244,9 +232,7 @@ class SnmpData: async def async_update(self): """Get the latest data from the remote SNMP capable host.""" - get_result = await getCmd( - *self._request_args, ObjectType(ObjectIdentity(self._baseoid)) - ) + get_result = await getCmd(*self._request_args) errindication, errstatus, errindex, restable = get_result if errindication and not self._accept_errors: diff --git a/homeassistant/components/snmp/switch.py b/homeassistant/components/snmp/switch.py index a447cdc8e9c..40083ed4213 100644 --- a/homeassistant/components/snmp/switch.py +++ b/homeassistant/components/snmp/switch.py @@ -8,10 +8,6 @@ from typing import Any import pysnmp.hlapi.asyncio as hlapi from pysnmp.hlapi.asyncio import ( CommunityData, - ContextData, - ObjectIdentity, - ObjectType, - SnmpEngine, UdpTransportTarget, UsmUserData, getCmd, @@ -67,6 +63,7 @@ from .const import ( MAP_PRIV_PROTOCOLS, SNMP_VERSIONS, ) +from .util import RequestArgsType, async_create_request_cmd_args _LOGGER = logging.getLogger(__name__) @@ -132,40 +129,54 @@ async def async_setup_platform( host = config.get(CONF_HOST) port = config.get(CONF_PORT) community = config.get(CONF_COMMUNITY) - baseoid = config.get(CONF_BASEOID) + baseoid: str = config[CONF_BASEOID] command_oid = config.get(CONF_COMMAND_OID) command_payload_on = config.get(CONF_COMMAND_PAYLOAD_ON) command_payload_off = config.get(CONF_COMMAND_PAYLOAD_OFF) - version = config.get(CONF_VERSION) + version: str = config[CONF_VERSION] username = config.get(CONF_USERNAME) authkey = config.get(CONF_AUTH_KEY) - authproto = config.get(CONF_AUTH_PROTOCOL) + authproto: str = config[CONF_AUTH_PROTOCOL] privkey = config.get(CONF_PRIV_KEY) - privproto = config.get(CONF_PRIV_PROTOCOL) + privproto: str = config[CONF_PRIV_PROTOCOL] payload_on = config.get(CONF_PAYLOAD_ON) payload_off = config.get(CONF_PAYLOAD_OFF) vartype = config.get(CONF_VARTYPE) + if version == "3": + if not authkey: + authproto = "none" + if not privkey: + privproto = "none" + + auth_data = UsmUserData( + username, + authKey=authkey or None, + privKey=privkey or None, + authProtocol=getattr(hlapi, MAP_AUTH_PROTOCOLS[authproto]), + privProtocol=getattr(hlapi, MAP_PRIV_PROTOCOLS[privproto]), + ) + else: + auth_data = CommunityData(community, mpModel=SNMP_VERSIONS[version]) + + request_args = await async_create_request_cmd_args( + hass, auth_data, UdpTransportTarget((host, port)), baseoid + ) + async_add_entities( [ SnmpSwitch( name, host, port, - community, baseoid, command_oid, - version, - username, - authkey, - authproto, - privkey, - privproto, payload_on, payload_off, command_payload_on, command_payload_off, vartype, + request_args, ) ], True, @@ -180,21 +191,15 @@ class SnmpSwitch(SwitchEntity): name, host, port, - community, baseoid, commandoid, - version, - username, - authkey, - authproto, - privkey, - privproto, payload_on, payload_off, command_payload_on, command_payload_off, vartype, - ): + request_args, + ) -> None: """Initialize the switch.""" self._name = name @@ -206,35 +211,11 @@ class SnmpSwitch(SwitchEntity): self._command_payload_on = command_payload_on or payload_on self._command_payload_off = command_payload_off or payload_off - self._state = None + self._state: bool | None = None self._payload_on = payload_on self._payload_off = payload_off - - if version == "3": - if not authkey: - authproto = "none" - if not privkey: - privproto = "none" - - self._request_args = [ - SnmpEngine(), - UsmUserData( - username, - authKey=authkey or None, - privKey=privkey or None, - authProtocol=getattr(hlapi, MAP_AUTH_PROTOCOLS[authproto]), - privProtocol=getattr(hlapi, MAP_PRIV_PROTOCOLS[privproto]), - ), - UdpTransportTarget((host, port)), - ContextData(), - ] - else: - self._request_args = [ - SnmpEngine(), - CommunityData(community, mpModel=SNMP_VERSIONS[version]), - UdpTransportTarget((host, port)), - ContextData(), - ] + self._target = UdpTransportTarget((host, port)) + self._request_args: RequestArgsType = request_args async def async_turn_on(self, **kwargs: Any) -> None: """Turn on the switch.""" @@ -259,9 +240,7 @@ class SnmpSwitch(SwitchEntity): async def async_update(self) -> None: """Update the state.""" - get_result = await getCmd( - *self._request_args, ObjectType(ObjectIdentity(self._baseoid)) - ) + get_result = await getCmd(*self._request_args) errindication, errstatus, errindex, restable = get_result if errindication: @@ -296,6 +275,4 @@ class SnmpSwitch(SwitchEntity): return self._state async def _set(self, value): - await setCmd( - *self._request_args, ObjectType(ObjectIdentity(self._commandoid), value) - ) + await setCmd(*self._request_args, value) diff --git a/homeassistant/components/snmp/util.py b/homeassistant/components/snmp/util.py new file mode 100644 index 00000000000..23adbdf0b90 --- /dev/null +++ b/homeassistant/components/snmp/util.py @@ -0,0 +1,76 @@ +"""Support for displaying collected data over SNMP.""" + +from __future__ import annotations + +import logging + +from pysnmp.hlapi.asyncio import ( + CommunityData, + ContextData, + ObjectIdentity, + ObjectType, + SnmpEngine, + Udp6TransportTarget, + UdpTransportTarget, + UsmUserData, +) +from pysnmp.hlapi.asyncio.cmdgen import lcd, vbProcessor +from pysnmp.smi.builder import MibBuilder + +from homeassistant.const import EVENT_HOMEASSISTANT_STOP +from homeassistant.core import Event, HomeAssistant, callback +from homeassistant.helpers.singleton import singleton + +DATA_SNMP_ENGINE = "snmp_engine" + +_LOGGER = logging.getLogger(__name__) + +type RequestArgsType = tuple[ + SnmpEngine, + UsmUserData | CommunityData, + UdpTransportTarget | Udp6TransportTarget, + ContextData, + ObjectType, +] + + +async def async_create_request_cmd_args( + hass: HomeAssistant, + auth_data: UsmUserData | CommunityData, + target: UdpTransportTarget | Udp6TransportTarget, + object_id: str, +) -> RequestArgsType: + """Create request arguments.""" + return ( + await async_get_snmp_engine(hass), + auth_data, + target, + ContextData(), + ObjectType(ObjectIdentity(object_id)), + ) + + +@singleton(DATA_SNMP_ENGINE) +async def async_get_snmp_engine(hass: HomeAssistant) -> SnmpEngine: + """Get the SNMP engine.""" + engine = await hass.async_add_executor_job(_get_snmp_engine) + + @callback + def _async_shutdown_listener(ev: Event) -> None: + _LOGGER.debug("Unconfiguring SNMP engine") + lcd.unconfigure(engine, None) + + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_shutdown_listener) + return engine + + +def _get_snmp_engine() -> SnmpEngine: + """Return a cached instance of SnmpEngine.""" + engine = SnmpEngine() + mib_controller = vbProcessor.getMibViewController(engine) + # Actually load the MIBs from disk so we do + # not do it in the event loop + builder: MibBuilder = mib_controller.mibBuilder + if "PYSNMP-MIB" not in builder.mibSymbols: + builder.loadModules() + return engine diff --git a/tests/components/snmp/test_init.py b/tests/components/snmp/test_init.py new file mode 100644 index 00000000000..0aa97dcc475 --- /dev/null +++ b/tests/components/snmp/test_init.py @@ -0,0 +1,22 @@ +"""SNMP tests.""" + +from unittest.mock import patch + +from pysnmp.hlapi.asyncio import SnmpEngine +from pysnmp.hlapi.asyncio.cmdgen import lcd + +from homeassistant.components import snmp +from homeassistant.const import EVENT_HOMEASSISTANT_STOP +from homeassistant.core import HomeAssistant + + +async def test_async_get_snmp_engine(hass: HomeAssistant) -> None: + """Test async_get_snmp_engine.""" + engine = await snmp.async_get_snmp_engine(hass) + assert isinstance(engine, SnmpEngine) + engine2 = await snmp.async_get_snmp_engine(hass) + assert engine is engine2 + with patch.object(lcd, "unconfigure") as mock_unconfigure: + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + assert mock_unconfigure.called