From 515982a6926b50355e606bef9999f9af1448c1c4 Mon Sep 17 00:00:00 2001 From: Fabian Affolter Date: Sat, 16 Sep 2017 08:13:30 +0200 Subject: [PATCH] Refactor Swiss Public Transport sensor (#9129) * Refactor Swiss Public Transport sensor * Minor change --- .../sensor/swiss_public_transport.py | 141 ++++++++---------- requirements_all.txt | 3 + 2 files changed, 63 insertions(+), 81 deletions(-) diff --git a/homeassistant/components/sensor/swiss_public_transport.py b/homeassistant/components/sensor/swiss_public_transport.py index 0febd8c95bc..973eac0bdde 100644 --- a/homeassistant/components/sensor/swiss_public_transport.py +++ b/homeassistant/components/sensor/swiss_public_transport.py @@ -4,10 +4,10 @@ Support for transport.opendata.ch. For more details about this platform, please refer to the documentation at https://home-assistant.io/components/sensor.swiss_public_transport/ """ +import asyncio import logging from datetime import timedelta -import requests import voluptuous as vol import homeassistant.helpers.config_validation as cv @@ -15,15 +15,21 @@ import homeassistant.util.dt as dt_util from homeassistant.components.sensor import PLATFORM_SCHEMA from homeassistant.const import CONF_NAME, ATTR_ATTRIBUTION from homeassistant.helpers.entity import Entity +from homeassistant.helpers.aiohttp_client import async_get_clientsession + +REQUIREMENTS = ['python_opendata_transport==0.0.2'] _LOGGER = logging.getLogger(__name__) -_RESOURCE = 'http://transport.opendata.ch/v1/' ATTR_DEPARTURE_TIME1 = 'next_departure' ATTR_DEPARTURE_TIME2 = 'next_on_departure' +ATTR_DURATION = 'duration' +ATTR_PLATFORM = 'platform' ATTR_REMAINING_TIME = 'remaining_time' ATTR_START = 'start' ATTR_TARGET = 'destination' +ATTR_TRAIN_NUMBER = 'train_number' +ATTR_TRANSFERS = 'transfers' CONF_ATTRIBUTION = "Data provided by transport.opendata.ch" CONF_DESTINATION = 'to' @@ -33,9 +39,7 @@ DEFAULT_NAME = 'Next Departure' ICON = 'mdi:bus' -SCAN_INTERVAL = timedelta(minutes=1) - -TIME_STR_FORMAT = "%H:%M" +SCAN_INTERVAL = timedelta(seconds=90) PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ vol.Required(CONF_DESTINATION): cv.string, @@ -44,39 +48,39 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ }) -def setup_platform(hass, config, add_devices, discovery_info=None): +@asyncio.coroutine +def async_setup_platform(hass, config, async_add_devices, discovery_info=None): """Set up the Swiss public transport sensor.""" name = config.get(CONF_NAME) - # journal contains [0] Station ID start, [1] Station ID destination - # [2] Station name start, and [3] Station name destination - journey = [config.get(CONF_START), config.get(CONF_DESTINATION)] - try: - for location in [config.get(CONF_START), config.get(CONF_DESTINATION)]: - # transport.opendata.ch doesn't play nice with requests.Session - result = requests.get( - '{}locations?query={}'.format(_RESOURCE, location), timeout=10) - journey.append(result.json()['stations'][0]['name']) - except KeyError: - _LOGGER.exception( - "Unable to determine stations. " - "Check your settings and/or the availability of opendata.ch") + start = config.get(CONF_START) + destination = config.get(CONF_DESTINATION) + + connection = SwissPublicTransportSensor(hass, start, destination, name) + yield from connection.async_update() + + if connection.state is None: + _LOGGER.error( + "Check at http://transport.opendata.ch/examples/stationboard.html " + "if your station names are valid") return False - data = PublicTransportData(journey) - add_devices([SwissPublicTransportSensor(data, journey, name)], True) + async_add_devices([connection]) class SwissPublicTransportSensor(Entity): """Implementation of an Swiss public transport sensor.""" - def __init__(self, data, journey, name): + def __init__(self, hass, start, destination, name): """Initialize the sensor.""" - self.data = data + from opendata_transport import OpendataTransport + + self.hass = hass self._name = name - self._state = None - self._times = None - self._from = journey[2] - self._to = journey[3] + self._from = start + self._to = destination + self._websession = async_get_clientsession(self.hass) + self._opendata = OpendataTransport( + self._from, self._to, self.hass.loop, self._websession) @property def name(self): @@ -86,70 +90,45 @@ class SwissPublicTransportSensor(Entity): @property def state(self): """Return the state of the sensor.""" - return self._state + return self._opendata.connections[0]['departure'] \ + if self._opendata is not None else None @property def device_state_attributes(self): """Return the state attributes.""" - if self._times is not None: - return { - ATTR_DEPARTURE_TIME1: self._times[0], - ATTR_DEPARTURE_TIME2: self._times[1], - ATTR_START: self._from, - ATTR_TARGET: self._to, - ATTR_REMAINING_TIME: '{}'.format( - ':'.join(str(self._times[2]).split(':')[:2])), - ATTR_ATTRIBUTION: CONF_ATTRIBUTION, - } + if self._opendata is None: + return + + remaining_time = dt_util.parse_datetime( + self._opendata.connections[0]['departure']) -\ + dt_util.as_local(dt_util.utcnow()) + + attr = { + ATTR_TRAIN_NUMBER: self._opendata.connections[0]['number'], + ATTR_PLATFORM: self._opendata.connections[0]['platform'], + ATTR_TRANSFERS: self._opendata.connections[0]['transfers'], + ATTR_DURATION: self._opendata.connections[0]['duration'], + ATTR_DEPARTURE_TIME1: self._opendata.connections[1]['departure'], + ATTR_DEPARTURE_TIME2: self._opendata.connections[2]['departure'], + ATTR_START: self._opendata.from_name, + ATTR_TARGET: self._opendata.to_name, + ATTR_REMAINING_TIME: '{}'.format(remaining_time), + ATTR_ATTRIBUTION: CONF_ATTRIBUTION, + } + return attr @property def icon(self): """Icon to use in the frontend, if any.""" return ICON - def update(self): + @asyncio.coroutine + def async_update(self): """Get the latest data from opendata.ch and update the states.""" - self.data.update() - self._times = self.data.times - try: - self._state = self._times[0] - except TypeError: - pass - - -class PublicTransportData(object): - """The Class for handling the data retrieval.""" - - def __init__(self, journey): - """Initialize the data object.""" - self.start = journey[0] - self.destination = journey[1] - self.times = {} - - def update(self): - """Get the latest data from opendata.ch.""" - response = requests.get( - _RESOURCE + - 'connections?' + - 'from=' + self.start + '&' + - 'to=' + self.destination + '&' + - 'fields[]=connections/from/departureTimestamp/&' + - 'fields[]=connections/', - timeout=10) - connections = response.json()['connections'][1:3] + from opendata_transport.exceptions import OpendataTransportError try: - self.times = [ - dt_util.as_local( - dt_util.utc_from_timestamp( - item['from']['departureTimestamp'])).strftime( - TIME_STR_FORMAT) - for item in connections - ] - self.times.append( - dt_util.as_local( - dt_util.utc_from_timestamp( - connections[0]['from']['departureTimestamp'])) - - dt_util.as_local(dt_util.utcnow())) - except KeyError: - self.times = ['n/a'] + yield from self._opendata.async_get_data() + except OpendataTransportError: + _LOGGER.error("Unable to retrieve data from transport.opendata.ch") + self._opendata = None diff --git a/requirements_all.txt b/requirements_all.txt index d4177d5570c..d7907bbcf02 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -799,6 +799,9 @@ python-vlc==1.1.2 # homeassistant.components.wink python-wink==1.5.1 +# homeassistant.components.sensor.swiss_public_transport +python_opendata_transport==0.0.2 + # homeassistant.components.zwave python_openzwave==0.4.0.31