Add Swiss public transport via stations (#115891)

* add via stations

* bump minor version due to backwards incompatibility

* better coverage of many via station options in unit tests

* fix migration unit test for new minor version 1.3

* switch version bump to major and improve migration test

* fixes

* improve error messages

* use placeholders for strings
This commit is contained in:
Cyrill Raccaud 2024-06-23 19:26:55 +02:00 committed by GitHub
parent 5fbb965624
commit 7efd1079bd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 216 additions and 65 deletions

View file

@ -14,8 +14,9 @@ from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import CONF_DESTINATION, CONF_START, DOMAIN
from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, PLACEHOLDERS
from .coordinator import SwissPublicTransportDataUpdateCoordinator
from .helper import unique_id_from_config
_LOGGER = logging.getLogger(__name__)
@ -33,19 +34,28 @@ async def async_setup_entry(
destination = config[CONF_DESTINATION]
session = async_get_clientsession(hass)
opendata = OpendataTransport(start, destination, session)
opendata = OpendataTransport(start, destination, session, via=config.get(CONF_VIA))
try:
await opendata.async_get_data()
except OpendataTransportConnectionError as e:
raise ConfigEntryNotReady(
f"Timeout while connecting for entry '{start} {destination}'"
translation_domain=DOMAIN,
translation_key="request_timeout",
translation_placeholders={
"config_title": entry.title,
"error": e,
},
) from e
except OpendataTransportError as e:
raise ConfigEntryError(
f"Setup failed for entry '{start} {destination}' with invalid data, check "
"at http://transport.opendata.ch/examples/stationboard.html if your "
"station names are valid"
translation_domain=DOMAIN,
translation_key="invalid_data",
translation_placeholders={
**PLACEHOLDERS,
"config_title": entry.title,
"error": e,
},
) from e
coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata)
@ -72,15 +82,13 @@ async def async_migrate_entry(
"""Migrate config entry."""
_LOGGER.debug("Migrating from version %s", config_entry.version)
if config_entry.minor_version > 3:
if config_entry.version > 2:
# This means the user has downgraded from a future version
return False
if config_entry.minor_version == 1:
if config_entry.version == 1 and config_entry.minor_version == 1:
# Remove wrongly registered devices and entries
new_unique_id = (
f"{config_entry.data[CONF_START]} {config_entry.data[CONF_DESTINATION]}"
)
new_unique_id = unique_id_from_config(config_entry.data)
entity_registry = er.async_get(hass)
device_registry = dr.async_get(hass)
device_entries = dr.async_entries_for_config_entry(
@ -109,6 +117,10 @@ async def async_migrate_entry(
config_entry, unique_id=new_unique_id, minor_version=2
)
if config_entry.version < 2:
# Via stations now available, which are not backwards compatible if used, changes unique id
hass.config_entries.async_update_entry(config_entry, version=2, minor_version=1)
_LOGGER.debug(
"Migration to version %s.%s successful",
config_entry.version,

View file

@ -13,12 +13,24 @@ import voluptuous as vol
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.selector import (
TextSelector,
TextSelectorConfig,
TextSelectorType,
)
from .const import CONF_DESTINATION, CONF_START, DOMAIN, PLACEHOLDERS
from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, MAX_VIA, PLACEHOLDERS
from .helper import unique_id_from_config
DATA_SCHEMA = vol.Schema(
{
vol.Required(CONF_START): cv.string,
vol.Optional(CONF_VIA): TextSelector(
TextSelectorConfig(
type=TextSelectorType.TEXT,
multiple=True,
),
),
vol.Required(CONF_DESTINATION): cv.string,
}
)
@ -29,8 +41,8 @@ _LOGGER = logging.getLogger(__name__)
class SwissPublicTransportConfigFlow(ConfigFlow, domain=DOMAIN):
"""Swiss public transport config flow."""
VERSION = 1
MINOR_VERSION = 2
VERSION = 2
MINOR_VERSION = 1
async def async_step_user(
self, user_input: dict[str, Any] | None = None
@ -38,29 +50,34 @@ class SwissPublicTransportConfigFlow(ConfigFlow, domain=DOMAIN):
"""Async user step to set up the connection."""
errors: dict[str, str] = {}
if user_input is not None:
await self.async_set_unique_id(
f"{user_input[CONF_START]} {user_input[CONF_DESTINATION]}"
)
unique_id = unique_id_from_config(user_input)
await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()
session = async_get_clientsession(self.hass)
opendata = OpendataTransport(
user_input[CONF_START], user_input[CONF_DESTINATION], session
)
try:
await opendata.async_get_data()
except OpendataTransportConnectionError:
errors["base"] = "cannot_connect"
except OpendataTransportError:
errors["base"] = "bad_config"
except Exception:
_LOGGER.exception("Unknown error")
errors["base"] = "unknown"
if CONF_VIA in user_input and len(user_input[CONF_VIA]) > MAX_VIA:
errors["base"] = "too_many_via_stations"
else:
return self.async_create_entry(
title=f"{user_input[CONF_START]} {user_input[CONF_DESTINATION]}",
data=user_input,
session = async_get_clientsession(self.hass)
opendata = OpendataTransport(
user_input[CONF_START],
user_input[CONF_DESTINATION],
session,
via=user_input.get(CONF_VIA),
)
try:
await opendata.async_get_data()
except OpendataTransportConnectionError:
errors["base"] = "cannot_connect"
except OpendataTransportError:
errors["base"] = "bad_config"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unknown error")
errors["base"] = "unknown"
else:
return self.async_create_entry(
title=unique_id,
data=user_input,
)
return self.async_show_form(
step_id="user",

View file

@ -1,12 +1,16 @@
"""Constants for the swiss_public_transport integration."""
from typing import Final
DOMAIN = "swiss_public_transport"
CONF_DESTINATION = "to"
CONF_START = "from"
CONF_DESTINATION: Final = "to"
CONF_START: Final = "from"
CONF_VIA: Final = "via"
DEFAULT_NAME = "Next Destination"
MAX_VIA = 5
SENSOR_CONNECTIONS_COUNT = 3

View file

@ -0,0 +1,15 @@
"""Helper functions for swiss_public_transport."""
from types import MappingProxyType
from typing import Any
from .const import CONF_DESTINATION, CONF_START, CONF_VIA
def unique_id_from_config(config: MappingProxyType[str, Any] | dict[str, Any]) -> str:
"""Build a unique id from a config entry."""
return f"{config[CONF_START]} {config[CONF_DESTINATION]}" + (
" via " + ", ".join(config[CONF_VIA])
if CONF_VIA in config and len(config[CONF_VIA]) > 0
else ""
)

View file

@ -3,6 +3,7 @@
"error": {
"cannot_connect": "Cannot connect to server",
"bad_config": "Request failed due to bad config: Check at [stationboard]({stationboard_url}) if your station names are valid",
"too_many_via_stations": "Too many via stations, only up to 5 via stations are allowed per connection.",
"unknown": "An unknown error was raised by python-opendata-transport"
},
"abort": {
@ -15,9 +16,10 @@
"user": {
"data": {
"from": "Start station",
"to": "End station"
"to": "End station",
"via": "List of up to 5 via stations"
},
"description": "Provide start and end station for your connection\n\nCheck the [stationboard]({stationboard_url}) for valid stations.",
"description": "Provide start and end station for your connection,\nand optionally up to 5 via stations.\n\nCheck the [stationboard]({stationboard_url}) for valid stations.",
"title": "Swiss Public Transport"
}
}
@ -46,5 +48,13 @@
"name": "Delay"
}
}
},
"exceptions": {
"invalid_data": {
"message": "Setup failed for entry {config_title} with invalid data, check at the [stationboard]({stationboard_url}) if your station names are valid.\n{error}"
},
"request_timeout": {
"message": "Timeout while connecting for entry {config_title}.\n{error}"
}
}
}

View file

@ -12,7 +12,10 @@ from homeassistant.components.swiss_public_transport import config_flow
from homeassistant.components.swiss_public_transport.const import (
CONF_DESTINATION,
CONF_START,
CONF_VIA,
MAX_VIA,
)
from homeassistant.components.swiss_public_transport.helper import unique_id_from_config
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
@ -25,8 +28,36 @@ MOCK_DATA_STEP = {
CONF_DESTINATION: "test_destination",
}
MOCK_DATA_STEP_ONE_VIA = {
**MOCK_DATA_STEP,
CONF_VIA: ["via_station"],
}
async def test_flow_user_init_data_success(hass: HomeAssistant) -> None:
MOCK_DATA_STEP_MANY_VIA = {
**MOCK_DATA_STEP,
CONF_VIA: ["via_station_1", "via_station_2", "via_station_3"],
}
MOCK_DATA_STEP_TOO_MANY_STATIONS = {
**MOCK_DATA_STEP,
CONF_VIA: MOCK_DATA_STEP_ONE_VIA[CONF_VIA] * (MAX_VIA + 1),
}
@pytest.mark.parametrize(
("user_input", "config_title"),
[
(MOCK_DATA_STEP, "test_start test_destination"),
(MOCK_DATA_STEP_ONE_VIA, "test_start test_destination via via_station"),
(
MOCK_DATA_STEP_MANY_VIA,
"test_start test_destination via via_station_1, via_station_2, via_station_3",
),
],
)
async def test_flow_user_init_data_success(
hass: HomeAssistant, user_input, config_title
) -> None:
"""Test success response."""
result = await hass.config_entries.flow.async_init(
config_flow.DOMAIN, context={"source": "user"}
@ -47,25 +78,26 @@ async def test_flow_user_init_data_success(hass: HomeAssistant) -> None:
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=MOCK_DATA_STEP,
user_input=user_input,
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["result"].title == "test_start test_destination"
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["result"].title == config_title
assert result["data"] == MOCK_DATA_STEP
assert result["data"] == user_input
@pytest.mark.parametrize(
("raise_error", "text_error"),
("raise_error", "text_error", "user_input_error"),
[
(OpendataTransportConnectionError(), "cannot_connect"),
(OpendataTransportError(), "bad_config"),
(IndexError(), "unknown"),
(OpendataTransportConnectionError(), "cannot_connect", MOCK_DATA_STEP),
(OpendataTransportError(), "bad_config", MOCK_DATA_STEP),
(None, "too_many_via_stations", MOCK_DATA_STEP_TOO_MANY_STATIONS),
(IndexError(), "unknown", MOCK_DATA_STEP),
],
)
async def test_flow_user_init_data_error_and_recover(
hass: HomeAssistant, raise_error, text_error
hass: HomeAssistant, raise_error, text_error, user_input_error
) -> None:
"""Test unknown errors."""
with patch(
@ -78,7 +110,7 @@ async def test_flow_user_init_data_error_and_recover(
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=MOCK_DATA_STEP,
user_input=user_input_error,
)
assert result["type"] is FlowResultType.FORM
@ -92,7 +124,7 @@ async def test_flow_user_init_data_error_and_recover(
user_input=MOCK_DATA_STEP,
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["result"].title == "test_start test_destination"
assert result["data"] == MOCK_DATA_STEP
@ -104,7 +136,7 @@ async def test_flow_user_init_data_already_configured(hass: HomeAssistant) -> No
entry = MockConfigEntry(
domain=config_flow.DOMAIN,
data=MOCK_DATA_STEP,
unique_id=f"{MOCK_DATA_STEP[CONF_START]} {MOCK_DATA_STEP[CONF_DESTINATION]}",
unique_id=unique_id_from_config(MOCK_DATA_STEP),
)
entry.add_to_hass(hass)

View file

@ -2,22 +2,32 @@
from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.components.swiss_public_transport.const import (
CONF_DESTINATION,
CONF_START,
CONF_VIA,
DOMAIN,
)
from homeassistant.components.swiss_public_transport.helper import unique_id_from_config
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from tests.common import MockConfigEntry
MOCK_DATA_STEP = {
MOCK_DATA_STEP_BASE = {
CONF_START: "test_start",
CONF_DESTINATION: "test_destination",
}
MOCK_DATA_STEP_VIA = {
**MOCK_DATA_STEP_BASE,
CONF_VIA: ["via_station"],
}
CONNECTIONS = [
{
"departure": "2024-01-06T18:03:00+0100",
@ -46,19 +56,38 @@ CONNECTIONS = [
]
async def test_migration_1_1_to_1_2(
hass: HomeAssistant, entity_registry: er.EntityRegistry
@pytest.mark.parametrize(
(
"from_version",
"from_minor_version",
"config_data",
"overwrite_unique_id",
),
[
(1, 1, MOCK_DATA_STEP_BASE, "None_departure"),
(1, 2, MOCK_DATA_STEP_BASE, None),
(2, 1, MOCK_DATA_STEP_VIA, None),
],
)
async def test_migration_from(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
from_version,
from_minor_version,
config_data,
overwrite_unique_id,
) -> None:
"""Test successful setup."""
config_entry_faulty = MockConfigEntry(
config_entry = MockConfigEntry(
domain=DOMAIN,
data=MOCK_DATA_STEP,
title="MIGRATION_TEST",
version=1,
minor_version=1,
data=config_data,
title=f"MIGRATION_TEST from {from_version}.{from_minor_version}",
version=from_version,
minor_version=from_minor_version,
unique_id=overwrite_unique_id or unique_id_from_config(config_data),
)
config_entry_faulty.add_to_hass(hass)
config_entry.add_to_hass(hass)
with patch(
"homeassistant.components.swiss_public_transport.OpendataTransport",
@ -67,21 +96,53 @@ async def test_migration_1_1_to_1_2(
mock().connections = CONNECTIONS
# Setup the config entry
await hass.config_entries.async_setup(config_entry_faulty.entry_id)
unique_id = unique_id_from_config(config_entry.data)
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
assert entity_registry.async_is_registered(
entity_registry.entities.get_entity_id(
(Platform.SENSOR, DOMAIN, "test_start test_destination_departure")
(
Platform.SENSOR,
DOMAIN,
f"{unique_id}_departure",
)
)
)
# Check change in config entry
assert config_entry_faulty.minor_version == 2
assert config_entry_faulty.unique_id == "test_start test_destination"
# Check change in config entry and verify most recent version
assert config_entry.version == 2
assert config_entry.minor_version == 1
assert config_entry.unique_id == unique_id
# Check "None" is gone
# Check "None" is gone from version 1.1 to 1.2
assert not entity_registry.async_is_registered(
entity_registry.entities.get_entity_id(
(Platform.SENSOR, DOMAIN, "None_departure")
)
)
async def test_migrate_error_from_future(hass: HomeAssistant) -> None:
"""Test a future version isn't migrated."""
mock_entry = MockConfigEntry(
domain=DOMAIN,
version=3,
minor_version=1,
unique_id="some_crazy_future_unique_id",
data=MOCK_DATA_STEP_BASE,
)
mock_entry.add_to_hass(hass)
with patch(
"homeassistant.components.swiss_public_transport.OpendataTransport",
return_value=AsyncMock(),
) as mock:
mock().connections = CONNECTIONS
await hass.config_entries.async_setup(mock_entry.entry_id)
await hass.async_block_till_done()
entry = hass.config_entries.async_get_entry(mock_entry.entry_id)
assert entry.state is ConfigEntryState.MIGRATION_ERROR