Clean mysensors gateway type selection (#51531)

* Clean mysensors gateway type selection

* Fix comment grammar
This commit is contained in:
Martin Hjelmare 2021-06-07 15:45:58 +02:00 committed by GitHub
parent 4c51299dcc
commit 564042ec67
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 35 deletions

View file

@ -347,7 +347,7 @@ class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
break
# if no errors so far, try to connect
if not errors and not await try_connect(self.hass, user_input):
if not errors and not await try_connect(self.hass, gw_type, user_input):
errors["base"] = "cannot_connect"
return errors

View file

@ -22,6 +22,9 @@ import homeassistant.helpers.config_validation as cv
from .const import (
CONF_BAUD_RATE,
CONF_DEVICE,
CONF_GATEWAY_TYPE,
CONF_GATEWAY_TYPE_MQTT,
CONF_GATEWAY_TYPE_SERIAL,
CONF_PERSISTENCE_FILE,
CONF_RETAIN,
CONF_TCP_PORT,
@ -31,6 +34,7 @@ from .const import (
DOMAIN,
MYSENSORS_GATEWAY_START_TASK,
MYSENSORS_GATEWAYS,
ConfGatewayType,
GatewayId,
)
from .handler import HANDLERS
@ -66,10 +70,12 @@ def is_socket_address(value):
raise vol.Invalid("Device is not a valid domain name or ip address") from err
async def try_connect(hass: HomeAssistant, user_input: dict[str, Any]) -> bool:
async def try_connect(
hass: HomeAssistant, gateway_type: ConfGatewayType, user_input: dict[str, Any]
) -> bool:
"""Try to connect to a gateway and report if it worked."""
if user_input[CONF_DEVICE] == MQTT_COMPONENT:
return True # dont validate mqtt. mqtt gateways dont send ready messages :(
if gateway_type == "MQTT":
return True # Do not validate MQTT, as that does not use connection made.
try:
gateway_ready = asyncio.Event()
@ -78,6 +84,7 @@ async def try_connect(hass: HomeAssistant, user_input: dict[str, Any]) -> bool:
gateway: BaseAsyncGateway | None = await _get_gateway(
hass,
gateway_type,
device=user_input[CONF_DEVICE],
version=user_input[CONF_VERSION],
event_callback=lambda _: None,
@ -128,6 +135,7 @@ async def setup_gateway(
ready_gateway = await _get_gateway(
hass,
gateway_type=entry.data[CONF_GATEWAY_TYPE],
device=entry.data[CONF_DEVICE],
version=entry.data[CONF_VERSION],
event_callback=_gw_callback_factory(hass, entry.entry_id),
@ -145,6 +153,7 @@ async def setup_gateway(
async def _get_gateway(
hass: HomeAssistant,
gateway_type: ConfGatewayType,
device: str,
version: str,
event_callback: Callable[[Message], None],
@ -154,15 +163,16 @@ async def _get_gateway(
topic_in_prefix: str | None = None,
topic_out_prefix: str | None = None,
retain: bool = False,
persistence: bool = True, # old persistence option has been deprecated. kwarg is here so we can run try_connect() without persistence
persistence: bool = True,
) -> BaseAsyncGateway | None:
"""Return gateway after setup of the gateway."""
if persistence_file is not None:
# interpret relative paths to be in hass config folder. absolute paths will be left as they are
# Interpret relative paths to be in hass config folder.
# Absolute paths will be left as they are.
persistence_file = hass.config.path(persistence_file)
if device == MQTT_COMPONENT:
if gateway_type == CONF_GATEWAY_TYPE_MQTT:
# Make sure the mqtt integration is set up.
# Naive check that doesn't consider config entry state.
if MQTT_DOMAIN not in hass.config.components:
@ -195,35 +205,26 @@ async def _get_gateway(
persistence_file=persistence_file,
protocol_version=version,
)
elif gateway_type == CONF_GATEWAY_TYPE_SERIAL:
gateway = mysensors.AsyncSerialGateway(
device,
baud=baud_rate,
loop=hass.loop,
event_callback=None,
persistence=persistence,
persistence_file=persistence_file,
protocol_version=version,
)
else:
try:
await hass.async_add_executor_job(is_serial_port, device)
gateway = mysensors.AsyncSerialGateway(
device,
baud=baud_rate,
loop=hass.loop,
event_callback=None,
persistence=persistence,
persistence_file=persistence_file,
protocol_version=version,
)
except vol.Invalid:
try:
await hass.async_add_executor_job(is_socket_address, device)
# valid ip address
gateway = mysensors.AsyncTCPGateway(
device,
port=tcp_port,
loop=hass.loop,
event_callback=None,
persistence=persistence,
persistence_file=persistence_file,
protocol_version=version,
)
except vol.Invalid:
# invalid ip address
_LOGGER.error("Connect failed: Invalid device %s", device)
return None
gateway = mysensors.AsyncTCPGateway(
device,
port=tcp_port,
loop=hass.loop,
event_callback=None,
persistence=persistence,
persistence_file=persistence_file,
protocol_version=version,
)
gateway.event_callback = event_callback
if persistence:
await gateway.start_persistence()