Fix snmp doing blocking I/O in the event loop (#118521)
This commit is contained in:
parent
a23b5e97e6
commit
76391d71d6
6 changed files with 176 additions and 111 deletions
|
@ -1 +1,5 @@
|
|||
"""The snmp component."""
|
||||
|
||||
from .util import async_get_snmp_engine
|
||||
|
||||
__all__ = ["async_get_snmp_engine"]
|
||||
|
|
|
@ -4,14 +4,11 @@ from __future__ import annotations
|
|||
|
||||
import binascii
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pysnmp.error import PySnmpError
|
||||
from pysnmp.hlapi.asyncio import (
|
||||
CommunityData,
|
||||
ContextData,
|
||||
ObjectIdentity,
|
||||
ObjectType,
|
||||
SnmpEngine,
|
||||
Udp6TransportTarget,
|
||||
UdpTransportTarget,
|
||||
UsmUserData,
|
||||
|
@ -43,6 +40,7 @@ from .const import (
|
|||
DEFAULT_VERSION,
|
||||
SNMP_VERSIONS,
|
||||
)
|
||||
from .util import RequestArgsType, async_create_request_cmd_args
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -62,7 +60,7 @@ async def async_get_scanner(
|
|||
) -> SnmpScanner | None:
|
||||
"""Validate the configuration and return an SNMP scanner."""
|
||||
scanner = SnmpScanner(config[DOMAIN])
|
||||
await scanner.async_init()
|
||||
await scanner.async_init(hass)
|
||||
|
||||
return scanner if scanner.success_init else None
|
||||
|
||||
|
@ -99,33 +97,29 @@ class SnmpScanner(DeviceScanner):
|
|||
if not privkey:
|
||||
privproto = "none"
|
||||
|
||||
request_args = [
|
||||
SnmpEngine(),
|
||||
UsmUserData(
|
||||
community,
|
||||
authKey=authkey or None,
|
||||
privKey=privkey or None,
|
||||
authProtocol=authproto,
|
||||
privProtocol=privproto,
|
||||
),
|
||||
target,
|
||||
ContextData(),
|
||||
]
|
||||
self._auth_data = UsmUserData(
|
||||
community,
|
||||
authKey=authkey or None,
|
||||
privKey=privkey or None,
|
||||
authProtocol=authproto,
|
||||
privProtocol=privproto,
|
||||
)
|
||||
else:
|
||||
request_args = [
|
||||
SnmpEngine(),
|
||||
CommunityData(community, mpModel=SNMP_VERSIONS[DEFAULT_VERSION]),
|
||||
target,
|
||||
ContextData(),
|
||||
]
|
||||
self._auth_data = CommunityData(
|
||||
community, mpModel=SNMP_VERSIONS[DEFAULT_VERSION]
|
||||
)
|
||||
|
||||
self.request_args = request_args
|
||||
self._target = target
|
||||
self.request_args: RequestArgsType | None = None
|
||||
self.baseoid = baseoid
|
||||
self.last_results = []
|
||||
self.success_init = False
|
||||
|
||||
async def async_init(self):
|
||||
async def async_init(self, hass: HomeAssistant) -> None:
|
||||
"""Make a one-off read to check if the target device is reachable and readable."""
|
||||
self.request_args = await async_create_request_cmd_args(
|
||||
hass, self._auth_data, self._target, self.baseoid
|
||||
)
|
||||
data = await self.async_get_snmp_data()
|
||||
self.success_init = data is not None
|
||||
|
||||
|
@ -156,12 +150,18 @@ class SnmpScanner(DeviceScanner):
|
|||
async def async_get_snmp_data(self):
|
||||
"""Fetch MAC addresses from access point via SNMP."""
|
||||
devices = []
|
||||
if TYPE_CHECKING:
|
||||
assert self.request_args is not None
|
||||
|
||||
engine, auth_data, target, context_data, object_type = self.request_args
|
||||
walker = bulkWalkCmd(
|
||||
*self.request_args,
|
||||
engine,
|
||||
auth_data,
|
||||
target,
|
||||
context_data,
|
||||
0,
|
||||
50,
|
||||
ObjectType(ObjectIdentity(self.baseoid)),
|
||||
object_type,
|
||||
lexicographicMode=False,
|
||||
)
|
||||
async for errindication, errstatus, errindex, res in walker:
|
||||
|
|
|
@ -11,10 +11,6 @@ from pysnmp.error import PySnmpError
|
|||
import pysnmp.hlapi.asyncio as hlapi
|
||||
from pysnmp.hlapi.asyncio import (
|
||||
CommunityData,
|
||||
ContextData,
|
||||
ObjectIdentity,
|
||||
ObjectType,
|
||||
SnmpEngine,
|
||||
Udp6TransportTarget,
|
||||
UdpTransportTarget,
|
||||
UsmUserData,
|
||||
|
@ -71,6 +67,7 @@ from .const import (
|
|||
MAP_PRIV_PROTOCOLS,
|
||||
SNMP_VERSIONS,
|
||||
)
|
||||
from .util import async_create_request_cmd_args
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -119,7 +116,7 @@ async def async_setup_platform(
|
|||
host = config.get(CONF_HOST)
|
||||
port = config.get(CONF_PORT)
|
||||
community = config.get(CONF_COMMUNITY)
|
||||
baseoid = config.get(CONF_BASEOID)
|
||||
baseoid: str = config[CONF_BASEOID]
|
||||
version = config[CONF_VERSION]
|
||||
username = config.get(CONF_USERNAME)
|
||||
authkey = config.get(CONF_AUTH_KEY)
|
||||
|
@ -145,27 +142,18 @@ async def async_setup_platform(
|
|||
authproto = "none"
|
||||
if not privkey:
|
||||
privproto = "none"
|
||||
|
||||
request_args = [
|
||||
SnmpEngine(),
|
||||
UsmUserData(
|
||||
username,
|
||||
authKey=authkey or None,
|
||||
privKey=privkey or None,
|
||||
authProtocol=getattr(hlapi, MAP_AUTH_PROTOCOLS[authproto]),
|
||||
privProtocol=getattr(hlapi, MAP_PRIV_PROTOCOLS[privproto]),
|
||||
),
|
||||
target,
|
||||
ContextData(),
|
||||
]
|
||||
auth_data = UsmUserData(
|
||||
username,
|
||||
authKey=authkey or None,
|
||||
privKey=privkey or None,
|
||||
authProtocol=getattr(hlapi, MAP_AUTH_PROTOCOLS[authproto]),
|
||||
privProtocol=getattr(hlapi, MAP_PRIV_PROTOCOLS[privproto]),
|
||||
)
|
||||
else:
|
||||
request_args = [
|
||||
SnmpEngine(),
|
||||
CommunityData(community, mpModel=SNMP_VERSIONS[version]),
|
||||
target,
|
||||
ContextData(),
|
||||
]
|
||||
get_result = await getCmd(*request_args, ObjectType(ObjectIdentity(baseoid)))
|
||||
auth_data = CommunityData(community, mpModel=SNMP_VERSIONS[version])
|
||||
|
||||
request_args = await async_create_request_cmd_args(hass, auth_data, target, baseoid)
|
||||
get_result = await getCmd(*request_args)
|
||||
errindication, _, _, _ = get_result
|
||||
|
||||
if errindication and not accept_errors:
|
||||
|
@ -244,9 +232,7 @@ class SnmpData:
|
|||
async def async_update(self):
|
||||
"""Get the latest data from the remote SNMP capable host."""
|
||||
|
||||
get_result = await getCmd(
|
||||
*self._request_args, ObjectType(ObjectIdentity(self._baseoid))
|
||||
)
|
||||
get_result = await getCmd(*self._request_args)
|
||||
errindication, errstatus, errindex, restable = get_result
|
||||
|
||||
if errindication and not self._accept_errors:
|
||||
|
|
|
@ -8,10 +8,6 @@ from typing import Any
|
|||
import pysnmp.hlapi.asyncio as hlapi
|
||||
from pysnmp.hlapi.asyncio import (
|
||||
CommunityData,
|
||||
ContextData,
|
||||
ObjectIdentity,
|
||||
ObjectType,
|
||||
SnmpEngine,
|
||||
UdpTransportTarget,
|
||||
UsmUserData,
|
||||
getCmd,
|
||||
|
@ -67,6 +63,7 @@ from .const import (
|
|||
MAP_PRIV_PROTOCOLS,
|
||||
SNMP_VERSIONS,
|
||||
)
|
||||
from .util import RequestArgsType, async_create_request_cmd_args
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -132,40 +129,54 @@ async def async_setup_platform(
|
|||
host = config.get(CONF_HOST)
|
||||
port = config.get(CONF_PORT)
|
||||
community = config.get(CONF_COMMUNITY)
|
||||
baseoid = config.get(CONF_BASEOID)
|
||||
baseoid: str = config[CONF_BASEOID]
|
||||
command_oid = config.get(CONF_COMMAND_OID)
|
||||
command_payload_on = config.get(CONF_COMMAND_PAYLOAD_ON)
|
||||
command_payload_off = config.get(CONF_COMMAND_PAYLOAD_OFF)
|
||||
version = config.get(CONF_VERSION)
|
||||
version: str = config[CONF_VERSION]
|
||||
username = config.get(CONF_USERNAME)
|
||||
authkey = config.get(CONF_AUTH_KEY)
|
||||
authproto = config.get(CONF_AUTH_PROTOCOL)
|
||||
authproto: str = config[CONF_AUTH_PROTOCOL]
|
||||
privkey = config.get(CONF_PRIV_KEY)
|
||||
privproto = config.get(CONF_PRIV_PROTOCOL)
|
||||
privproto: str = config[CONF_PRIV_PROTOCOL]
|
||||
payload_on = config.get(CONF_PAYLOAD_ON)
|
||||
payload_off = config.get(CONF_PAYLOAD_OFF)
|
||||
vartype = config.get(CONF_VARTYPE)
|
||||
|
||||
if version == "3":
|
||||
if not authkey:
|
||||
authproto = "none"
|
||||
if not privkey:
|
||||
privproto = "none"
|
||||
|
||||
auth_data = UsmUserData(
|
||||
username,
|
||||
authKey=authkey or None,
|
||||
privKey=privkey or None,
|
||||
authProtocol=getattr(hlapi, MAP_AUTH_PROTOCOLS[authproto]),
|
||||
privProtocol=getattr(hlapi, MAP_PRIV_PROTOCOLS[privproto]),
|
||||
)
|
||||
else:
|
||||
auth_data = CommunityData(community, mpModel=SNMP_VERSIONS[version])
|
||||
|
||||
request_args = await async_create_request_cmd_args(
|
||||
hass, auth_data, UdpTransportTarget((host, port)), baseoid
|
||||
)
|
||||
|
||||
async_add_entities(
|
||||
[
|
||||
SnmpSwitch(
|
||||
name,
|
||||
host,
|
||||
port,
|
||||
community,
|
||||
baseoid,
|
||||
command_oid,
|
||||
version,
|
||||
username,
|
||||
authkey,
|
||||
authproto,
|
||||
privkey,
|
||||
privproto,
|
||||
payload_on,
|
||||
payload_off,
|
||||
command_payload_on,
|
||||
command_payload_off,
|
||||
vartype,
|
||||
request_args,
|
||||
)
|
||||
],
|
||||
True,
|
||||
|
@ -180,21 +191,15 @@ class SnmpSwitch(SwitchEntity):
|
|||
name,
|
||||
host,
|
||||
port,
|
||||
community,
|
||||
baseoid,
|
||||
commandoid,
|
||||
version,
|
||||
username,
|
||||
authkey,
|
||||
authproto,
|
||||
privkey,
|
||||
privproto,
|
||||
payload_on,
|
||||
payload_off,
|
||||
command_payload_on,
|
||||
command_payload_off,
|
||||
vartype,
|
||||
):
|
||||
request_args,
|
||||
) -> None:
|
||||
"""Initialize the switch."""
|
||||
|
||||
self._name = name
|
||||
|
@ -206,35 +211,11 @@ class SnmpSwitch(SwitchEntity):
|
|||
self._command_payload_on = command_payload_on or payload_on
|
||||
self._command_payload_off = command_payload_off or payload_off
|
||||
|
||||
self._state = None
|
||||
self._state: bool | None = None
|
||||
self._payload_on = payload_on
|
||||
self._payload_off = payload_off
|
||||
|
||||
if version == "3":
|
||||
if not authkey:
|
||||
authproto = "none"
|
||||
if not privkey:
|
||||
privproto = "none"
|
||||
|
||||
self._request_args = [
|
||||
SnmpEngine(),
|
||||
UsmUserData(
|
||||
username,
|
||||
authKey=authkey or None,
|
||||
privKey=privkey or None,
|
||||
authProtocol=getattr(hlapi, MAP_AUTH_PROTOCOLS[authproto]),
|
||||
privProtocol=getattr(hlapi, MAP_PRIV_PROTOCOLS[privproto]),
|
||||
),
|
||||
UdpTransportTarget((host, port)),
|
||||
ContextData(),
|
||||
]
|
||||
else:
|
||||
self._request_args = [
|
||||
SnmpEngine(),
|
||||
CommunityData(community, mpModel=SNMP_VERSIONS[version]),
|
||||
UdpTransportTarget((host, port)),
|
||||
ContextData(),
|
||||
]
|
||||
self._target = UdpTransportTarget((host, port))
|
||||
self._request_args: RequestArgsType = request_args
|
||||
|
||||
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||
"""Turn on the switch."""
|
||||
|
@ -259,9 +240,7 @@ class SnmpSwitch(SwitchEntity):
|
|||
|
||||
async def async_update(self) -> None:
|
||||
"""Update the state."""
|
||||
get_result = await getCmd(
|
||||
*self._request_args, ObjectType(ObjectIdentity(self._baseoid))
|
||||
)
|
||||
get_result = await getCmd(*self._request_args)
|
||||
errindication, errstatus, errindex, restable = get_result
|
||||
|
||||
if errindication:
|
||||
|
@ -296,6 +275,4 @@ class SnmpSwitch(SwitchEntity):
|
|||
return self._state
|
||||
|
||||
async def _set(self, value):
|
||||
await setCmd(
|
||||
*self._request_args, ObjectType(ObjectIdentity(self._commandoid), value)
|
||||
)
|
||||
await setCmd(*self._request_args, value)
|
||||
|
|
76
homeassistant/components/snmp/util.py
Normal file
76
homeassistant/components/snmp/util.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
"""Support for displaying collected data over SNMP."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from pysnmp.hlapi.asyncio import (
|
||||
CommunityData,
|
||||
ContextData,
|
||||
ObjectIdentity,
|
||||
ObjectType,
|
||||
SnmpEngine,
|
||||
Udp6TransportTarget,
|
||||
UdpTransportTarget,
|
||||
UsmUserData,
|
||||
)
|
||||
from pysnmp.hlapi.asyncio.cmdgen import lcd, vbProcessor
|
||||
from pysnmp.smi.builder import MibBuilder
|
||||
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import Event, HomeAssistant, callback
|
||||
from homeassistant.helpers.singleton import singleton
|
||||
|
||||
DATA_SNMP_ENGINE = "snmp_engine"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
type RequestArgsType = tuple[
|
||||
SnmpEngine,
|
||||
UsmUserData | CommunityData,
|
||||
UdpTransportTarget | Udp6TransportTarget,
|
||||
ContextData,
|
||||
ObjectType,
|
||||
]
|
||||
|
||||
|
||||
async def async_create_request_cmd_args(
|
||||
hass: HomeAssistant,
|
||||
auth_data: UsmUserData | CommunityData,
|
||||
target: UdpTransportTarget | Udp6TransportTarget,
|
||||
object_id: str,
|
||||
) -> RequestArgsType:
|
||||
"""Create request arguments."""
|
||||
return (
|
||||
await async_get_snmp_engine(hass),
|
||||
auth_data,
|
||||
target,
|
||||
ContextData(),
|
||||
ObjectType(ObjectIdentity(object_id)),
|
||||
)
|
||||
|
||||
|
||||
@singleton(DATA_SNMP_ENGINE)
|
||||
async def async_get_snmp_engine(hass: HomeAssistant) -> SnmpEngine:
|
||||
"""Get the SNMP engine."""
|
||||
engine = await hass.async_add_executor_job(_get_snmp_engine)
|
||||
|
||||
@callback
|
||||
def _async_shutdown_listener(ev: Event) -> None:
|
||||
_LOGGER.debug("Unconfiguring SNMP engine")
|
||||
lcd.unconfigure(engine, None)
|
||||
|
||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_shutdown_listener)
|
||||
return engine
|
||||
|
||||
|
||||
def _get_snmp_engine() -> SnmpEngine:
|
||||
"""Return a cached instance of SnmpEngine."""
|
||||
engine = SnmpEngine()
|
||||
mib_controller = vbProcessor.getMibViewController(engine)
|
||||
# Actually load the MIBs from disk so we do
|
||||
# not do it in the event loop
|
||||
builder: MibBuilder = mib_controller.mibBuilder
|
||||
if "PYSNMP-MIB" not in builder.mibSymbols:
|
||||
builder.loadModules()
|
||||
return engine
|
22
tests/components/snmp/test_init.py
Normal file
22
tests/components/snmp/test_init.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
"""SNMP tests."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from pysnmp.hlapi.asyncio import SnmpEngine
|
||||
from pysnmp.hlapi.asyncio.cmdgen import lcd
|
||||
|
||||
from homeassistant.components import snmp
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
async def test_async_get_snmp_engine(hass: HomeAssistant) -> None:
|
||||
"""Test async_get_snmp_engine."""
|
||||
engine = await snmp.async_get_snmp_engine(hass)
|
||||
assert isinstance(engine, SnmpEngine)
|
||||
engine2 = await snmp.async_get_snmp_engine(hass)
|
||||
assert engine is engine2
|
||||
with patch.object(lcd, "unconfigure") as mock_unconfigure:
|
||||
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
|
||||
await hass.async_block_till_done()
|
||||
assert mock_unconfigure.called
|
Loading…
Add table
Reference in a new issue