Add more timestamp sensors to swiss_public_transport (#107916)

* add more timestamp sensors

* more generic definition for future sensors

* add entity descriptor

* use enable property to prevent sensors from getting added

* set legacy attribute flag for first sensor

* remove departure from extra attributes

* remove breaking changes again and keep for next pr

* fix multiline statements

* outsource the multiline ifs into function

---------

Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
This commit is contained in:
Cyrill Raccaud 2024-03-26 10:17:25 +01:00 committed by GitHub
parent 0338aaf577
commit e136847b89
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 115 additions and 57 deletions

View file

@ -7,6 +7,8 @@ CONF_START = "from"
DEFAULT_NAME = "Next Destination"
SENSOR_CONNECTIONS_COUNT = 3
PLACEHOLDERS = {
"stationboard_url": "http://transport.opendata.ch/examples/stationboard.html",

View file

@ -14,7 +14,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
import homeassistant.util.dt as dt_util
from .const import DOMAIN
from .const import DOMAIN, SENSOR_CONNECTIONS_COUNT
_LOGGER = logging.getLogger(__name__)
@ -23,8 +23,8 @@ class DataConnection(TypedDict):
"""A connection data class."""
departure: datetime | None
next_departure: str | None
next_on_departure: str | None
next_departure: datetime | None
next_on_departure: datetime | None
duration: str
platform: str
remaining_time: str
@ -35,7 +35,9 @@ class DataConnection(TypedDict):
delay: int
class SwissPublicTransportDataUpdateCoordinator(DataUpdateCoordinator[DataConnection]):
class SwissPublicTransportDataUpdateCoordinator(
DataUpdateCoordinator[list[DataConnection]]
):
"""A SwissPublicTransport Data Update Coordinator."""
config_entry: ConfigEntry
@ -50,7 +52,22 @@ class SwissPublicTransportDataUpdateCoordinator(DataUpdateCoordinator[DataConnec
)
self._opendata = opendata
async def _async_update_data(self) -> DataConnection:
def remaining_time(self, departure) -> timedelta | None:
"""Calculate the remaining time for the departure."""
departure_datetime = dt_util.parse_datetime(departure)
if departure_datetime:
return departure_datetime - dt_util.as_local(dt_util.utcnow())
return None
def nth_departure_time(self, i: int) -> datetime | None:
"""Get nth departure time."""
connections = self._opendata.connections
if len(connections) > i and connections[i] is not None:
return dt_util.parse_datetime(connections[i]["departure"])
return None
async def _async_update_data(self) -> list[DataConnection]:
try:
await self._opendata.async_get_data()
except OpendataTransportError as e:
@ -59,41 +76,22 @@ class SwissPublicTransportDataUpdateCoordinator(DataUpdateCoordinator[DataConnec
)
raise UpdateFailed from e
departure_time = (
dt_util.parse_datetime(self._opendata.connections[0]["departure"])
if self._opendata.connections[0] is not None
else None
)
next_departure_time = (
dt_util.parse_datetime(self._opendata.connections[1]["departure"])
if self._opendata.connections[1] is not None
else None
)
next_on_departure_time = (
dt_util.parse_datetime(self._opendata.connections[2]["departure"])
if self._opendata.connections[2] is not None
else None
)
connections = self._opendata.connections
if departure_time:
remaining_time = departure_time - dt_util.as_local(dt_util.utcnow())
else:
remaining_time = None
return DataConnection(
departure=departure_time,
next_departure=next_departure_time.isoformat()
if next_departure_time is not None
else None,
next_on_departure=next_on_departure_time.isoformat()
if next_on_departure_time is not None
else None,
train_number=self._opendata.connections[0]["number"],
platform=self._opendata.connections[0]["platform"],
transfers=self._opendata.connections[0]["transfers"],
duration=self._opendata.connections[0]["duration"],
start=self._opendata.from_name,
destination=self._opendata.to_name,
remaining_time=f"{remaining_time}",
delay=self._opendata.connections[0]["delay"],
)
return [
DataConnection(
departure=self.nth_departure_time(i),
next_departure=self.nth_departure_time(i + 1),
next_on_departure=self.nth_departure_time(i + 2),
train_number=connections[i]["number"],
platform=connections[i]["platform"],
transfers=connections[i]["transfers"],
duration=connections[i]["duration"],
start=self._opendata.from_name,
destination=self._opendata.to_name,
remaining_time=str(self.remaining_time(connections[i]["departure"])),
delay=connections[i]["delay"],
)
for i in range(SENSOR_CONNECTIONS_COUNT)
if len(connections) > i and connections[i] is not None
]

View file

@ -2,6 +2,8 @@
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
import logging
from typing import TYPE_CHECKING
@ -13,6 +15,7 @@ from homeassistant.components.sensor import (
PLATFORM_SCHEMA,
SensorDeviceClass,
SensorEntity,
SensorEntityDescription,
)
from homeassistant.config_entries import SOURCE_IMPORT
from homeassistant.const import CONF_NAME
@ -25,8 +28,15 @@ from homeassistant.helpers.issue_registry import IssueSeverity, async_create_iss
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import CONF_DESTINATION, CONF_START, DEFAULT_NAME, DOMAIN, PLACEHOLDERS
from .coordinator import SwissPublicTransportDataUpdateCoordinator
from .const import (
CONF_DESTINATION,
CONF_START,
DEFAULT_NAME,
DOMAIN,
PLACEHOLDERS,
SENSOR_CONNECTIONS_COUNT,
)
from .coordinator import DataConnection, SwissPublicTransportDataUpdateCoordinator
_LOGGER = logging.getLogger(__name__)
@ -41,6 +51,33 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
)
@dataclass(kw_only=True, frozen=True)
class SwissPublicTransportSensorEntityDescription(SensorEntityDescription):
"""Describes swiss public transport sensor entity."""
exists_fn: Callable[[DataConnection], bool]
value_fn: Callable[[DataConnection], datetime | None]
index: int
has_legacy_attributes: bool
SENSORS: tuple[SwissPublicTransportSensorEntityDescription, ...] = (
*[
SwissPublicTransportSensorEntityDescription(
key=f"departure{i or ''}",
translation_key=f"departure{i}",
device_class=SensorDeviceClass.TIMESTAMP,
has_legacy_attributes=i == 0,
value_fn=lambda data_connection: data_connection["departure"],
exists_fn=lambda data_connection: data_connection is not None,
index=i,
)
for i in range(SENSOR_CONNECTIONS_COUNT)
],
)
async def async_setup_entry(
hass: core.HomeAssistant,
config_entry: config_entries.ConfigEntry,
@ -55,7 +92,8 @@ async def async_setup_entry(
assert unique_id
async_add_entities(
[SwissPublicTransportSensor(coordinator, unique_id)],
SwissPublicTransportSensor(coordinator, description, unique_id)
for description in SENSORS
)
@ -108,34 +146,51 @@ class SwissPublicTransportSensor(
):
"""Implementation of a Swiss public transport sensor."""
entity_description: SwissPublicTransportSensorEntityDescription
_attr_attribution = "Data provided by transport.opendata.ch"
_attr_has_entity_name = True
_attr_translation_key = "departure"
_attr_device_class = SensorDeviceClass.TIMESTAMP
def __init__(
self,
coordinator: SwissPublicTransportDataUpdateCoordinator,
entity_description: SwissPublicTransportSensorEntityDescription,
unique_id: str,
) -> None:
"""Initialize the sensor."""
super().__init__(coordinator)
self._attr_unique_id = f"{unique_id}_departure"
self.entity_description = entity_description
self._attr_unique_id = f"{unique_id}_{entity_description.key}"
self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, unique_id)},
manufacturer="Opendata.ch",
entry_type=DeviceEntryType.SERVICE,
)
@property
def enabled(self) -> bool:
"""Enable the sensor if data is available."""
return self.entity_description.exists_fn(
self.coordinator.data[self.entity_description.index]
)
@property
def native_value(self) -> datetime | None:
"""Return the state of the sensor."""
return self.entity_description.value_fn(
self.coordinator.data[self.entity_description.index]
)
async def async_added_to_hass(self) -> None:
"""Prepare the extra attributes at start."""
self._async_update_attrs()
if self.entity_description.has_legacy_attributes:
self._async_update_attrs()
await super().async_added_to_hass()
@callback
def _handle_coordinator_update(self) -> None:
"""Handle the state update and prepare the extra state attributes."""
self._async_update_attrs()
if self.entity_description.has_legacy_attributes:
self._async_update_attrs()
return super()._handle_coordinator_update()
@callback
@ -143,11 +198,8 @@ class SwissPublicTransportSensor(
"""Update the extra state attributes based on the coordinator data."""
self._attr_extra_state_attributes = {
key: value
for key, value in self.coordinator.data.items()
for key, value in self.coordinator.data[
self.entity_description.index
].items()
if key not in {"departure"}
}
@property
def native_value(self) -> datetime | None:
"""Return the state of the sensor."""
return self.coordinator.data["departure"]

View file

@ -24,8 +24,14 @@
},
"entity": {
"sensor": {
"departure": {
"departure0": {
"name": "Departure"
},
"departure1": {
"name": "Departure +1"
},
"departure2": {
"name": "Departure +2"
}
}
},