Refactor NextBus integration to use new API (#121133)

* Refactor NextBus integration to use new API

This removes the `messages`, `directions`, and `attribution` attributes
from the sensor. Those may be added back in the future with additional
refactoring.

Some existing sensors may be broken today because of deprecated Agency
names. This patch will not migrate them as the migration path is
ambiguous. Setting up again should work though.

* Move result indexing outside of try/except
This commit is contained in:
Ian 2024-07-24 09:18:21 -07:00 committed by GitHub
parent 3c4f2c2dcf
commit 3e8d3083ac
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 191 additions and 264 deletions

View file

@ -37,52 +37,33 @@ def _dict_to_select_selector(options: dict[str, str]) -> SelectSelector:
def _get_agency_tags(client: NextBusClient) -> dict[str, str]:
return {a["tag"]: a["title"] for a in client.get_agency_list()["agency"]}
return {a["id"]: a["name"] for a in client.agencies()}
def _get_route_tags(client: NextBusClient, agency_tag: str) -> dict[str, str]:
return {a["tag"]: a["title"] for a in client.get_route_list(agency_tag)["route"]}
return {a["id"]: a["title"] for a in client.routes(agency_tag)}
def _get_stop_tags(
client: NextBusClient, agency_tag: str, route_tag: str
) -> dict[str, str]:
route_config = client.get_route_config(route_tag, agency_tag)
tags = {a["tag"]: a["title"] for a in route_config["route"]["stop"]}
title_counts = Counter(tags.values())
route_config = client.route_details(route_tag, agency_tag)
stop_ids = {a["id"]: a["name"] for a in route_config["stops"]}
title_counts = Counter(stop_ids.values())
stop_directions: dict[str, str] = {}
for direction in listify(route_config["route"]["direction"]):
for stop in direction["stop"]:
stop_directions[stop["tag"]] = direction["name"]
for direction in listify(route_config["directions"]):
if not direction["useForUi"]:
continue
for stop in direction["stops"]:
stop_directions[stop] = direction["name"]
# Append directions for stops with shared titles
for tag, title in tags.items():
for stop_id, title in stop_ids.items():
if title_counts[title] > 1:
tags[tag] = f"{title} ({stop_directions.get(tag, tag)})"
stop_ids[stop_id] = f"{title} ({stop_directions.get(stop_id, stop_id)})"
return tags
def _validate_import(
client: NextBusClient, agency_tag: str, route_tag: str, stop_tag: str
) -> str | tuple[str, str, str]:
agency_tags = _get_agency_tags(client)
agency = agency_tags.get(agency_tag)
if not agency:
return "invalid_agency"
route_tags = _get_route_tags(client, agency_tag)
route = route_tags.get(route_tag)
if not route:
return "invalid_route"
stop_tags = _get_stop_tags(client, agency_tag, route_tag)
stop = stop_tags.get(stop_tag)
if not stop:
return "invalid_stop"
return agency, route, stop
return stop_ids
def _unique_id_from_data(data: dict[str, str]) -> str:
@ -101,7 +82,7 @@ class NextBusFlowHandler(ConfigFlow, domain=DOMAIN):
def __init__(self):
"""Initialize NextBus config flow."""
self.data: dict[str, str] = {}
self._client = NextBusClient(output_format="json")
self._client = NextBusClient()
async def async_step_user(
self,

View file

@ -2,16 +2,16 @@
from datetime import timedelta
import logging
from typing import Any, cast
from typing import Any
from py_nextbus import NextBusClient
from py_nextbus.client import NextBusFormatError, NextBusHTTPError, RouteStop
from py_nextbus.client import NextBusFormatError, NextBusHTTPError
from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from .const import DOMAIN
from .util import listify
from .util import RouteStop
_LOGGER = logging.getLogger(__name__)
@ -27,53 +27,48 @@ class NextBusDataUpdateCoordinator(DataUpdateCoordinator):
name=DOMAIN,
update_interval=timedelta(seconds=30),
)
self.client = NextBusClient(output_format="json", agency=agency)
self.client = NextBusClient(agency_id=agency)
self._agency = agency
self._stop_routes: set[RouteStop] = set()
self._route_stops: set[RouteStop] = set()
self._predictions: dict[RouteStop, dict[str, Any]] = {}
def add_stop_route(self, stop_tag: str, route_tag: str) -> None:
def add_stop_route(self, stop_id: str, route_id: str) -> None:
"""Tell coordinator to start tracking a given stop and route."""
self._stop_routes.add(RouteStop(route_tag, stop_tag))
self._route_stops.add(RouteStop(route_id, stop_id))
def remove_stop_route(self, stop_tag: str, route_tag: str) -> None:
def remove_stop_route(self, stop_id: str, route_id: str) -> None:
"""Tell coordinator to stop tracking a given stop and route."""
self._stop_routes.remove(RouteStop(route_tag, stop_tag))
self._route_stops.remove(RouteStop(route_id, stop_id))
def get_prediction_data(
self, stop_tag: str, route_tag: str
) -> dict[str, Any] | None:
def get_prediction_data(self, stop_id: str, route_id: str) -> dict[str, Any] | None:
"""Get prediction result for a given stop and route."""
return self._predictions.get(RouteStop(route_tag, stop_tag))
def _calc_predictions(self, data: dict[str, Any]) -> None:
self._predictions = {
RouteStop(prediction["routeTag"], prediction["stopTag"]): prediction
for prediction in listify(data.get("predictions", []))
}
def get_attribution(self) -> str | None:
"""Get attribution from api results."""
return self.data.get("copyright")
return self._predictions.get(RouteStop(route_id, stop_id))
def has_routes(self) -> bool:
"""Check if this coordinator is tracking any routes."""
return len(self._stop_routes) > 0
return len(self._route_stops) > 0
async def _async_update_data(self) -> dict[str, Any]:
"""Fetch data from NextBus."""
self.logger.debug("Updating data from API. Routes: %s", str(self._stop_routes))
self.logger.debug("Updating data from API. Routes: %s", str(self._route_stops))
def _update_data() -> dict:
"""Fetch data from NextBus."""
self.logger.debug("Updating data from API (executor)")
try:
data = self.client.get_predictions_for_multi_stops(self._stop_routes)
# Casting here because we expect dict and not a str due to the input format selected being JSON
data = cast(dict[str, Any], data)
self._calc_predictions(data)
except (NextBusHTTPError, NextBusFormatError) as ex:
raise UpdateFailed("Failed updating nextbus data", ex) from ex
return data
predictions: dict[RouteStop, dict[str, Any]] = {}
for route_stop in self._route_stops:
prediction_results: list[dict[str, Any]] = []
try:
prediction_results = self.client.predictions_for_stop(
route_stop.stop_id, route_stop.route_id
)
except (NextBusHTTPError, NextBusFormatError) as ex:
raise UpdateFailed("Failed updating nextbus data", ex) from ex
if prediction_results:
predictions[route_stop] = prediction_results[0]
self._predictions = predictions
return predictions
return await self.hass.async_add_executor_job(_update_data)

View file

@ -6,5 +6,5 @@
"documentation": "https://www.home-assistant.io/integrations/nextbus",
"iot_class": "cloud_polling",
"loggers": ["py_nextbus"],
"requirements": ["py-nextbusnext==1.0.2"]
"requirements": ["py-nextbusnext==2.0.3"]
}

View file

@ -2,7 +2,6 @@
from __future__ import annotations
from itertools import chain
import logging
from typing import cast
@ -16,7 +15,7 @@ from homeassistant.util.dt import utc_from_timestamp
from .const import CONF_AGENCY, CONF_ROUTE, DOMAIN
from .coordinator import NextBusDataUpdateCoordinator
from .util import listify, maybe_first
from .util import maybe_first
_LOGGER = logging.getLogger(__name__)
@ -76,7 +75,11 @@ class NextBusDepartureSensor(
self.agency = agency
self.route = route
self.stop = stop
self._attr_extra_state_attributes: dict[str, str] = {}
self._attr_extra_state_attributes: dict[str, str] = {
"agency": agency,
"route": route,
"stop": stop,
}
self._attr_unique_id = unique_id
self._attr_name = name
@ -99,11 +102,10 @@ class NextBusDepartureSensor(
def _handle_coordinator_update(self) -> None:
"""Update sensor with new departures times."""
results = self.coordinator.get_prediction_data(self.stop, self.route)
self._attr_attribution = self.coordinator.get_attribution()
self._log_debug("Predictions results: %s", results)
if not results or "Error" in results:
if not results:
self._log_err("Error getting predictions: %s", str(results))
self._attr_native_value = None
self._attr_extra_state_attributes.pop("upcoming", None)
@ -112,31 +114,13 @@ class NextBusDepartureSensor(
# Set detailed attributes
self._attr_extra_state_attributes.update(
{
"agency": str(results.get("agencyTitle")),
"route": str(results.get("routeTitle")),
"stop": str(results.get("stopTitle")),
"route": str(results["route"]["title"]),
"stop": str(results["stop"]["name"]),
}
)
# List all messages in the attributes
messages = listify(results.get("message", []))
self._log_debug("Messages: %s", messages)
self._attr_extra_state_attributes["message"] = " -- ".join(
message.get("text", "") for message in messages
)
# List out all directions in the attributes
directions = listify(results.get("direction", []))
self._attr_extra_state_attributes["direction"] = ", ".join(
direction.get("title", "") for direction in directions
)
# Chain all predictions together
predictions = list(
chain(
*(listify(direction.get("prediction", [])) for direction in directions)
)
)
predictions = results["values"]
# Short circuit if we don't have any actual bus predictions
if not predictions:
@ -146,12 +130,12 @@ class NextBusDepartureSensor(
else:
# Generate list of upcoming times
self._attr_extra_state_attributes["upcoming"] = ", ".join(
sorted((p["minutes"] for p in predictions), key=int)
str(p["minutes"]) for p in predictions
)
latest_prediction = maybe_first(predictions)
self._attr_native_value = utc_from_timestamp(
int(latest_prediction["epochTime"]) / 1000
latest_prediction["timestamp"] / 1000
)
self.async_write_ha_state()

View file

@ -1,6 +1,6 @@
"""Utils for NextBus integration module."""
from typing import Any
from typing import Any, NamedTuple
def listify(maybe_list: Any) -> list[Any]:
@ -24,3 +24,10 @@ def maybe_first(maybe_list: list[Any] | None) -> Any:
return maybe_list[0]
return maybe_list
class RouteStop(NamedTuple):
"""NamedTuple for a route and stop combination."""
route_id: str
stop_id: str

View file

@ -1644,7 +1644,7 @@ py-madvr2==1.6.29
py-melissa-climate==2.1.4
# homeassistant.components.nextbus
py-nextbusnext==1.0.2
py-nextbusnext==2.0.3
# homeassistant.components.nightscout
py-nightscout==1.2.2

View file

@ -1333,7 +1333,7 @@ py-madvr2==1.6.29
py-melissa-climate==2.1.4
# homeassistant.components.nextbus
py-nextbusnext==1.0.2
py-nextbusnext==2.0.3
# homeassistant.components.nightscout
py-nightscout==1.2.2

View file

@ -8,15 +8,32 @@ import pytest
@pytest.fixture(
params=[
{"name": "Outbound", "stop": [{"tag": "5650"}]},
[
{
"name": "Outbound",
"stop": [{"tag": "5650"}],
"shortName": "Outbound",
"useForUi": True,
"stops": ["5184"],
},
{
"name": "Outbound - Hidden",
"shortName": "Outbound - Hidden",
"useForUi": False,
"stops": ["5651"],
},
],
[
{
"name": "Outbound",
"shortName": "Outbound",
"useForUi": True,
"stops": ["5184"],
},
{
"name": "Inbound",
"stop": [{"tag": "5651"}],
"shortName": "Inbound",
"useForUi": True,
"stops": ["5651"],
},
],
]
@ -35,22 +52,65 @@ def mock_nextbus_lists(
) -> MagicMock:
"""Mock all list functions in nextbus to test validate logic."""
instance = mock_nextbus.return_value
instance.get_agency_list.return_value = {
"agency": [{"tag": "sf-muni", "title": "San Francisco Muni"}]
}
instance.get_route_list.return_value = {
"route": [{"tag": "F", "title": "F - Market & Wharves"}]
}
instance.get_route_config.return_value = {
"route": {
"stop": [
{"tag": "5650", "title": "Market St & 7th St"},
{"tag": "5651", "title": "Market St & 7th St"},
# Error case test. Duplicate title with no unique direction
{"tag": "5652", "title": "Market St & 7th St"},
],
"direction": route_config_direction,
instance.agencies.return_value = [
{
"id": "sfmta-cis",
"name": "San Francisco Muni CIS",
"shortName": "SF Muni CIS",
"region": "",
"website": "",
"logo": "",
"nxbs2RedirectUrl": "",
}
]
instance.routes.return_value = [
{
"id": "F",
"rev": 1057,
"title": "F Market & Wharves",
"description": "7am-10pm daily",
"color": "",
"textColor": "",
"hidden": False,
"timestamp": "2024-06-23T03:06:58Z",
},
]
instance.route_details.return_value = {
"id": "F",
"rev": 1057,
"title": "F Market & Wharves",
"description": "7am-10pm daily",
"color": "",
"textColor": "",
"hidden": False,
"boundingBox": {},
"stops": [
{
"id": "5184",
"lat": 37.8071299,
"lon": -122.41732,
"name": "Jones St & Beach St",
"code": "15184",
"hidden": False,
"showDestinationSelector": True,
"directions": ["F_0_var1", "F_0_var0"],
},
{
"id": "5651",
"lat": 37.8071299,
"lon": -122.41732,
"name": "Jones St & Beach St",
"code": "15651",
"hidden": False,
"showDestinationSelector": True,
"directions": ["F_0_var1", "F_0_var0"],
},
],
"directions": route_config_direction,
"paths": [],
"timestamp": "2024-06-23T03:06:58Z",
}
return instance

View file

@ -44,7 +44,7 @@ async def test_user_config(
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_AGENCY: "sf-muni",
CONF_AGENCY: "sfmta-cis",
},
)
await hass.async_block_till_done()
@ -68,16 +68,16 @@ async def test_user_config(
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_STOP: "5650",
CONF_STOP: "5184",
},
)
await hass.async_block_till_done()
assert result.get("type") is FlowResultType.CREATE_ENTRY
assert result.get("data") == {
"agency": "sf-muni",
"agency": "sfmta-cis",
"route": "F",
"stop": "5650",
"stop": "5184",
}
assert len(mock_setup_entry.mock_calls) == 1

View file

@ -18,9 +18,9 @@ from homeassistant.helpers.update_coordinator import UpdateFailed
from tests.common import MockConfigEntry
VALID_AGENCY = "sf-muni"
VALID_AGENCY = "sfmta-cis"
VALID_ROUTE = "F"
VALID_STOP = "5650"
VALID_STOP = "5184"
VALID_AGENCY_TITLE = "San Francisco Muni"
VALID_ROUTE_TITLE = "F-Market & Wharves"
VALID_STOP_TITLE = "Market St & 7th St"
@ -44,25 +44,38 @@ CONFIG_BASIC = {
}
}
BASIC_RESULTS = {
"predictions": {
"agencyTitle": VALID_AGENCY_TITLE,
"agencyTag": VALID_AGENCY,
"routeTitle": VALID_ROUTE_TITLE,
"routeTag": VALID_ROUTE,
"stopTitle": VALID_STOP_TITLE,
"stopTag": VALID_STOP,
"direction": {
"title": "Outbound",
"prediction": [
{"minutes": "1", "epochTime": "1553807371000"},
{"minutes": "2", "epochTime": "1553807372000"},
{"minutes": "3", "epochTime": "1553807373000"},
{"minutes": "10", "epochTime": "1553807380000"},
],
BASIC_RESULTS = [
{
"route": {
"title": VALID_ROUTE_TITLE,
"id": VALID_ROUTE,
},
"stop": {
"name": VALID_STOP_TITLE,
"id": VALID_STOP,
},
"values": [
{"minutes": 1, "timestamp": 1553807371000},
{"minutes": 2, "timestamp": 1553807372000},
{"minutes": 3, "timestamp": 1553807373000},
{"minutes": 10, "timestamp": 1553807380000},
],
}
}
]
NO_UPCOMING = [
{
"route": {
"title": VALID_ROUTE_TITLE,
"id": VALID_ROUTE,
},
"stop": {
"name": VALID_STOP_TITLE,
"id": VALID_STOP,
},
"values": [],
}
]
@pytest.fixture
@ -78,9 +91,9 @@ def mock_nextbus_predictions(
) -> Generator[MagicMock]:
"""Create a mock of NextBusClient predictions."""
instance = mock_nextbus.return_value
instance.get_predictions_for_multi_stops.return_value = BASIC_RESULTS
instance.predictions_for_stop.return_value = BASIC_RESULTS
return instance.get_predictions_for_multi_stops
return instance.predictions_for_stop
async def assert_setup_sensor(
@ -105,117 +118,23 @@ async def assert_setup_sensor(
return config_entry
async def test_message_dict(
hass: HomeAssistant,
mock_nextbus: MagicMock,
mock_nextbus_lists: MagicMock,
mock_nextbus_predictions: MagicMock,
) -> None:
"""Verify that a single dict message is rendered correctly."""
mock_nextbus_predictions.return_value = {
"predictions": {
"agencyTitle": VALID_AGENCY_TITLE,
"agencyTag": VALID_AGENCY,
"routeTitle": VALID_ROUTE_TITLE,
"routeTag": VALID_ROUTE,
"stopTitle": VALID_STOP_TITLE,
"stopTag": VALID_STOP,
"message": {"text": "Message"},
"direction": {
"title": "Outbound",
"prediction": [
{"minutes": "1", "epochTime": "1553807371000"},
{"minutes": "2", "epochTime": "1553807372000"},
{"minutes": "3", "epochTime": "1553807373000"},
],
},
}
}
await assert_setup_sensor(hass, CONFIG_BASIC)
state = hass.states.get(SENSOR_ID)
assert state is not None
assert state.attributes["message"] == "Message"
async def test_message_list(
async def test_predictions(
hass: HomeAssistant,
mock_nextbus: MagicMock,
mock_nextbus_lists: MagicMock,
mock_nextbus_predictions: MagicMock,
) -> None:
"""Verify that a list of messages are rendered correctly."""
mock_nextbus_predictions.return_value = {
"predictions": {
"agencyTitle": VALID_AGENCY_TITLE,
"agencyTag": VALID_AGENCY,
"routeTitle": VALID_ROUTE_TITLE,
"routeTag": VALID_ROUTE,
"stopTitle": VALID_STOP_TITLE,
"stopTag": VALID_STOP,
"message": [{"text": "Message 1"}, {"text": "Message 2"}],
"direction": {
"title": "Outbound",
"prediction": [
{"minutes": "1", "epochTime": "1553807371000"},
{"minutes": "2", "epochTime": "1553807372000"},
{"minutes": "3", "epochTime": "1553807373000"},
],
},
}
}
await assert_setup_sensor(hass, CONFIG_BASIC)
state = hass.states.get(SENSOR_ID)
assert state is not None
assert state.attributes["message"] == "Message 1 -- Message 2"
async def test_direction_list(
hass: HomeAssistant,
mock_nextbus: MagicMock,
mock_nextbus_lists: MagicMock,
mock_nextbus_predictions: MagicMock,
) -> None:
"""Verify that a list of messages are rendered correctly."""
mock_nextbus_predictions.return_value = {
"predictions": {
"agencyTitle": VALID_AGENCY_TITLE,
"agencyTag": VALID_AGENCY,
"routeTitle": VALID_ROUTE_TITLE,
"routeTag": VALID_ROUTE,
"stopTitle": VALID_STOP_TITLE,
"stopTag": VALID_STOP,
"message": [{"text": "Message 1"}, {"text": "Message 2"}],
"direction": [
{
"title": "Outbound",
"prediction": [
{"minutes": "1", "epochTime": "1553807371000"},
{"minutes": "2", "epochTime": "1553807372000"},
{"minutes": "3", "epochTime": "1553807373000"},
],
},
{
"title": "Outbound 2",
"prediction": {"minutes": "0", "epochTime": "1553807374000"},
},
],
}
}
await assert_setup_sensor(hass, CONFIG_BASIC)
state = hass.states.get(SENSOR_ID)
assert state is not None
assert state.state == "2019-03-28T21:09:31+00:00"
assert state.attributes["agency"] == VALID_AGENCY_TITLE
assert state.attributes["agency"] == VALID_AGENCY
assert state.attributes["route"] == VALID_ROUTE_TITLE
assert state.attributes["stop"] == VALID_STOP_TITLE
assert state.attributes["direction"] == "Outbound, Outbound 2"
assert state.attributes["upcoming"] == "0, 1, 2, 3"
assert state.attributes["upcoming"] == "1, 2, 3, 10"
@pytest.mark.parametrize(
@ -256,27 +175,19 @@ async def test_custom_name(
assert state.name == "Custom Name"
@pytest.mark.parametrize(
"prediction_results",
[
{},
{"Error": "Failed"},
],
)
async def test_no_predictions(
async def test_verify_no_predictions(
hass: HomeAssistant,
mock_nextbus: MagicMock,
mock_nextbus_predictions: MagicMock,
mock_nextbus_lists: MagicMock,
prediction_results: dict[str, str],
mock_nextbus_predictions: MagicMock,
) -> None:
"""Verify there are no exceptions when no predictions are returned."""
mock_nextbus_predictions.return_value = prediction_results
"""Verify attributes are set despite no upcoming times."""
mock_nextbus_predictions.return_value = []
await assert_setup_sensor(hass, CONFIG_BASIC)
state = hass.states.get(SENSOR_ID)
assert state is not None
assert "upcoming" not in state.attributes
assert state.state == "unknown"
@ -287,21 +198,10 @@ async def test_verify_no_upcoming(
mock_nextbus_predictions: MagicMock,
) -> None:
"""Verify attributes are set despite no upcoming times."""
mock_nextbus_predictions.return_value = {
"predictions": {
"agencyTitle": VALID_AGENCY_TITLE,
"agencyTag": VALID_AGENCY,
"routeTitle": VALID_ROUTE_TITLE,
"routeTag": VALID_ROUTE,
"stopTitle": VALID_STOP_TITLE,
"stopTag": VALID_STOP,
"direction": {"title": "Outbound", "prediction": []},
}
}
mock_nextbus_predictions.return_value = NO_UPCOMING
await assert_setup_sensor(hass, CONFIG_BASIC)
state = hass.states.get(SENSOR_ID)
assert state is not None
assert state.state == "unknown"
assert state.attributes["upcoming"] == "No upcoming predictions"
assert state.state == "unknown"