Fix snmp doing blocking I/O in the event loop (#118521)

This commit is contained in:
J. Nick Koston 2024-05-31 02:44:28 -10:00 committed by GitHub
parent a23b5e97e6
commit 76391d71d6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 176 additions and 111 deletions

View file

@ -1 +1,5 @@
"""The snmp component."""
from .util import async_get_snmp_engine
__all__ = ["async_get_snmp_engine"]

View file

@ -4,14 +4,11 @@ from __future__ import annotations
import binascii
import logging
from typing import TYPE_CHECKING
from pysnmp.error import PySnmpError
from pysnmp.hlapi.asyncio import (
CommunityData,
ContextData,
ObjectIdentity,
ObjectType,
SnmpEngine,
Udp6TransportTarget,
UdpTransportTarget,
UsmUserData,
@ -43,6 +40,7 @@ from .const import (
DEFAULT_VERSION,
SNMP_VERSIONS,
)
from .util import RequestArgsType, async_create_request_cmd_args
_LOGGER = logging.getLogger(__name__)
@ -62,7 +60,7 @@ async def async_get_scanner(
) -> SnmpScanner | None:
"""Validate the configuration and return an SNMP scanner."""
scanner = SnmpScanner(config[DOMAIN])
await scanner.async_init()
await scanner.async_init(hass)
return scanner if scanner.success_init else None
@ -99,33 +97,29 @@ class SnmpScanner(DeviceScanner):
if not privkey:
privproto = "none"
request_args = [
SnmpEngine(),
UsmUserData(
community,
authKey=authkey or None,
privKey=privkey or None,
authProtocol=authproto,
privProtocol=privproto,
),
target,
ContextData(),
]
self._auth_data = UsmUserData(
community,
authKey=authkey or None,
privKey=privkey or None,
authProtocol=authproto,
privProtocol=privproto,
)
else:
request_args = [
SnmpEngine(),
CommunityData(community, mpModel=SNMP_VERSIONS[DEFAULT_VERSION]),
target,
ContextData(),
]
self._auth_data = CommunityData(
community, mpModel=SNMP_VERSIONS[DEFAULT_VERSION]
)
self.request_args = request_args
self._target = target
self.request_args: RequestArgsType | None = None
self.baseoid = baseoid
self.last_results = []
self.success_init = False
async def async_init(self):
async def async_init(self, hass: HomeAssistant) -> None:
"""Make a one-off read to check if the target device is reachable and readable."""
self.request_args = await async_create_request_cmd_args(
hass, self._auth_data, self._target, self.baseoid
)
data = await self.async_get_snmp_data()
self.success_init = data is not None
@ -156,12 +150,18 @@ class SnmpScanner(DeviceScanner):
async def async_get_snmp_data(self):
"""Fetch MAC addresses from access point via SNMP."""
devices = []
if TYPE_CHECKING:
assert self.request_args is not None
engine, auth_data, target, context_data, object_type = self.request_args
walker = bulkWalkCmd(
*self.request_args,
engine,
auth_data,
target,
context_data,
0,
50,
ObjectType(ObjectIdentity(self.baseoid)),
object_type,
lexicographicMode=False,
)
async for errindication, errstatus, errindex, res in walker:

View file

@ -11,10 +11,6 @@ from pysnmp.error import PySnmpError
import pysnmp.hlapi.asyncio as hlapi
from pysnmp.hlapi.asyncio import (
CommunityData,
ContextData,
ObjectIdentity,
ObjectType,
SnmpEngine,
Udp6TransportTarget,
UdpTransportTarget,
UsmUserData,
@ -71,6 +67,7 @@ from .const import (
MAP_PRIV_PROTOCOLS,
SNMP_VERSIONS,
)
from .util import async_create_request_cmd_args
_LOGGER = logging.getLogger(__name__)
@ -119,7 +116,7 @@ async def async_setup_platform(
host = config.get(CONF_HOST)
port = config.get(CONF_PORT)
community = config.get(CONF_COMMUNITY)
baseoid = config.get(CONF_BASEOID)
baseoid: str = config[CONF_BASEOID]
version = config[CONF_VERSION]
username = config.get(CONF_USERNAME)
authkey = config.get(CONF_AUTH_KEY)
@ -145,27 +142,18 @@ async def async_setup_platform(
authproto = "none"
if not privkey:
privproto = "none"
request_args = [
SnmpEngine(),
UsmUserData(
username,
authKey=authkey or None,
privKey=privkey or None,
authProtocol=getattr(hlapi, MAP_AUTH_PROTOCOLS[authproto]),
privProtocol=getattr(hlapi, MAP_PRIV_PROTOCOLS[privproto]),
),
target,
ContextData(),
]
auth_data = UsmUserData(
username,
authKey=authkey or None,
privKey=privkey or None,
authProtocol=getattr(hlapi, MAP_AUTH_PROTOCOLS[authproto]),
privProtocol=getattr(hlapi, MAP_PRIV_PROTOCOLS[privproto]),
)
else:
request_args = [
SnmpEngine(),
CommunityData(community, mpModel=SNMP_VERSIONS[version]),
target,
ContextData(),
]
get_result = await getCmd(*request_args, ObjectType(ObjectIdentity(baseoid)))
auth_data = CommunityData(community, mpModel=SNMP_VERSIONS[version])
request_args = await async_create_request_cmd_args(hass, auth_data, target, baseoid)
get_result = await getCmd(*request_args)
errindication, _, _, _ = get_result
if errindication and not accept_errors:
@ -244,9 +232,7 @@ class SnmpData:
async def async_update(self):
"""Get the latest data from the remote SNMP capable host."""
get_result = await getCmd(
*self._request_args, ObjectType(ObjectIdentity(self._baseoid))
)
get_result = await getCmd(*self._request_args)
errindication, errstatus, errindex, restable = get_result
if errindication and not self._accept_errors:

View file

@ -8,10 +8,6 @@ from typing import Any
import pysnmp.hlapi.asyncio as hlapi
from pysnmp.hlapi.asyncio import (
CommunityData,
ContextData,
ObjectIdentity,
ObjectType,
SnmpEngine,
UdpTransportTarget,
UsmUserData,
getCmd,
@ -67,6 +63,7 @@ from .const import (
MAP_PRIV_PROTOCOLS,
SNMP_VERSIONS,
)
from .util import RequestArgsType, async_create_request_cmd_args
_LOGGER = logging.getLogger(__name__)
@ -132,40 +129,54 @@ async def async_setup_platform(
host = config.get(CONF_HOST)
port = config.get(CONF_PORT)
community = config.get(CONF_COMMUNITY)
baseoid = config.get(CONF_BASEOID)
baseoid: str = config[CONF_BASEOID]
command_oid = config.get(CONF_COMMAND_OID)
command_payload_on = config.get(CONF_COMMAND_PAYLOAD_ON)
command_payload_off = config.get(CONF_COMMAND_PAYLOAD_OFF)
version = config.get(CONF_VERSION)
version: str = config[CONF_VERSION]
username = config.get(CONF_USERNAME)
authkey = config.get(CONF_AUTH_KEY)
authproto = config.get(CONF_AUTH_PROTOCOL)
authproto: str = config[CONF_AUTH_PROTOCOL]
privkey = config.get(CONF_PRIV_KEY)
privproto = config.get(CONF_PRIV_PROTOCOL)
privproto: str = config[CONF_PRIV_PROTOCOL]
payload_on = config.get(CONF_PAYLOAD_ON)
payload_off = config.get(CONF_PAYLOAD_OFF)
vartype = config.get(CONF_VARTYPE)
if version == "3":
if not authkey:
authproto = "none"
if not privkey:
privproto = "none"
auth_data = UsmUserData(
username,
authKey=authkey or None,
privKey=privkey or None,
authProtocol=getattr(hlapi, MAP_AUTH_PROTOCOLS[authproto]),
privProtocol=getattr(hlapi, MAP_PRIV_PROTOCOLS[privproto]),
)
else:
auth_data = CommunityData(community, mpModel=SNMP_VERSIONS[version])
request_args = await async_create_request_cmd_args(
hass, auth_data, UdpTransportTarget((host, port)), baseoid
)
async_add_entities(
[
SnmpSwitch(
name,
host,
port,
community,
baseoid,
command_oid,
version,
username,
authkey,
authproto,
privkey,
privproto,
payload_on,
payload_off,
command_payload_on,
command_payload_off,
vartype,
request_args,
)
],
True,
@ -180,21 +191,15 @@ class SnmpSwitch(SwitchEntity):
name,
host,
port,
community,
baseoid,
commandoid,
version,
username,
authkey,
authproto,
privkey,
privproto,
payload_on,
payload_off,
command_payload_on,
command_payload_off,
vartype,
):
request_args,
) -> None:
"""Initialize the switch."""
self._name = name
@ -206,35 +211,11 @@ class SnmpSwitch(SwitchEntity):
self._command_payload_on = command_payload_on or payload_on
self._command_payload_off = command_payload_off or payload_off
self._state = None
self._state: bool | None = None
self._payload_on = payload_on
self._payload_off = payload_off
if version == "3":
if not authkey:
authproto = "none"
if not privkey:
privproto = "none"
self._request_args = [
SnmpEngine(),
UsmUserData(
username,
authKey=authkey or None,
privKey=privkey or None,
authProtocol=getattr(hlapi, MAP_AUTH_PROTOCOLS[authproto]),
privProtocol=getattr(hlapi, MAP_PRIV_PROTOCOLS[privproto]),
),
UdpTransportTarget((host, port)),
ContextData(),
]
else:
self._request_args = [
SnmpEngine(),
CommunityData(community, mpModel=SNMP_VERSIONS[version]),
UdpTransportTarget((host, port)),
ContextData(),
]
self._target = UdpTransportTarget((host, port))
self._request_args: RequestArgsType = request_args
async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on the switch."""
@ -259,9 +240,7 @@ class SnmpSwitch(SwitchEntity):
async def async_update(self) -> None:
"""Update the state."""
get_result = await getCmd(
*self._request_args, ObjectType(ObjectIdentity(self._baseoid))
)
get_result = await getCmd(*self._request_args)
errindication, errstatus, errindex, restable = get_result
if errindication:
@ -296,6 +275,4 @@ class SnmpSwitch(SwitchEntity):
return self._state
async def _set(self, value):
await setCmd(
*self._request_args, ObjectType(ObjectIdentity(self._commandoid), value)
)
await setCmd(*self._request_args, value)

View file

@ -0,0 +1,76 @@
"""Support for displaying collected data over SNMP."""
from __future__ import annotations
import logging
from pysnmp.hlapi.asyncio import (
CommunityData,
ContextData,
ObjectIdentity,
ObjectType,
SnmpEngine,
Udp6TransportTarget,
UdpTransportTarget,
UsmUserData,
)
from pysnmp.hlapi.asyncio.cmdgen import lcd, vbProcessor
from pysnmp.smi.builder import MibBuilder
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.helpers.singleton import singleton
DATA_SNMP_ENGINE = "snmp_engine"
_LOGGER = logging.getLogger(__name__)
type RequestArgsType = tuple[
SnmpEngine,
UsmUserData | CommunityData,
UdpTransportTarget | Udp6TransportTarget,
ContextData,
ObjectType,
]
async def async_create_request_cmd_args(
hass: HomeAssistant,
auth_data: UsmUserData | CommunityData,
target: UdpTransportTarget | Udp6TransportTarget,
object_id: str,
) -> RequestArgsType:
"""Create request arguments."""
return (
await async_get_snmp_engine(hass),
auth_data,
target,
ContextData(),
ObjectType(ObjectIdentity(object_id)),
)
@singleton(DATA_SNMP_ENGINE)
async def async_get_snmp_engine(hass: HomeAssistant) -> SnmpEngine:
"""Get the SNMP engine."""
engine = await hass.async_add_executor_job(_get_snmp_engine)
@callback
def _async_shutdown_listener(ev: Event) -> None:
_LOGGER.debug("Unconfiguring SNMP engine")
lcd.unconfigure(engine, None)
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_shutdown_listener)
return engine
def _get_snmp_engine() -> SnmpEngine:
"""Return a cached instance of SnmpEngine."""
engine = SnmpEngine()
mib_controller = vbProcessor.getMibViewController(engine)
# Actually load the MIBs from disk so we do
# not do it in the event loop
builder: MibBuilder = mib_controller.mibBuilder
if "PYSNMP-MIB" not in builder.mibSymbols:
builder.loadModules()
return engine

View file

@ -0,0 +1,22 @@
"""SNMP tests."""
from unittest.mock import patch
from pysnmp.hlapi.asyncio import SnmpEngine
from pysnmp.hlapi.asyncio.cmdgen import lcd
from homeassistant.components import snmp
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import HomeAssistant
async def test_async_get_snmp_engine(hass: HomeAssistant) -> None:
"""Test async_get_snmp_engine."""
engine = await snmp.async_get_snmp_engine(hass)
assert isinstance(engine, SnmpEngine)
engine2 = await snmp.async_get_snmp_engine(hass)
assert engine is engine2
with patch.object(lcd, "unconfigure") as mock_unconfigure:
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
assert mock_unconfigure.called