Refactor Swiss Public Transport sensor (#9129)

* Refactor Swiss Public Transport sensor

* Minor change
This commit is contained in:
Fabian Affolter 2017-09-16 08:13:30 +02:00 committed by Paulus Schoutsen
parent 7b0628421d
commit 515982a692
2 changed files with 63 additions and 81 deletions

View file

@ -4,10 +4,10 @@ Support for transport.opendata.ch.
For more details about this platform, please refer to the documentation at
https://home-assistant.io/components/sensor.swiss_public_transport/
"""
import asyncio
import logging
from datetime import timedelta
import requests
import voluptuous as vol
import homeassistant.helpers.config_validation as cv
@ -15,15 +15,21 @@ import homeassistant.util.dt as dt_util
from homeassistant.components.sensor import PLATFORM_SCHEMA
from homeassistant.const import CONF_NAME, ATTR_ATTRIBUTION
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.aiohttp_client import async_get_clientsession
REQUIREMENTS = ['python_opendata_transport==0.0.2']
_LOGGER = logging.getLogger(__name__)
_RESOURCE = 'http://transport.opendata.ch/v1/'
ATTR_DEPARTURE_TIME1 = 'next_departure'
ATTR_DEPARTURE_TIME2 = 'next_on_departure'
ATTR_DURATION = 'duration'
ATTR_PLATFORM = 'platform'
ATTR_REMAINING_TIME = 'remaining_time'
ATTR_START = 'start'
ATTR_TARGET = 'destination'
ATTR_TRAIN_NUMBER = 'train_number'
ATTR_TRANSFERS = 'transfers'
CONF_ATTRIBUTION = "Data provided by transport.opendata.ch"
CONF_DESTINATION = 'to'
@ -33,9 +39,7 @@ DEFAULT_NAME = 'Next Departure'
ICON = 'mdi:bus'
SCAN_INTERVAL = timedelta(minutes=1)
TIME_STR_FORMAT = "%H:%M"
SCAN_INTERVAL = timedelta(seconds=90)
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
vol.Required(CONF_DESTINATION): cv.string,
@ -44,39 +48,39 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
})
def setup_platform(hass, config, add_devices, discovery_info=None):
@asyncio.coroutine
def async_setup_platform(hass, config, async_add_devices, discovery_info=None):
"""Set up the Swiss public transport sensor."""
name = config.get(CONF_NAME)
# journal contains [0] Station ID start, [1] Station ID destination
# [2] Station name start, and [3] Station name destination
journey = [config.get(CONF_START), config.get(CONF_DESTINATION)]
try:
for location in [config.get(CONF_START), config.get(CONF_DESTINATION)]:
# transport.opendata.ch doesn't play nice with requests.Session
result = requests.get(
'{}locations?query={}'.format(_RESOURCE, location), timeout=10)
journey.append(result.json()['stations'][0]['name'])
except KeyError:
_LOGGER.exception(
"Unable to determine stations. "
"Check your settings and/or the availability of opendata.ch")
start = config.get(CONF_START)
destination = config.get(CONF_DESTINATION)
connection = SwissPublicTransportSensor(hass, start, destination, name)
yield from connection.async_update()
if connection.state is None:
_LOGGER.error(
"Check at http://transport.opendata.ch/examples/stationboard.html "
"if your station names are valid")
return False
data = PublicTransportData(journey)
add_devices([SwissPublicTransportSensor(data, journey, name)], True)
async_add_devices([connection])
class SwissPublicTransportSensor(Entity):
"""Implementation of an Swiss public transport sensor."""
def __init__(self, data, journey, name):
def __init__(self, hass, start, destination, name):
"""Initialize the sensor."""
self.data = data
from opendata_transport import OpendataTransport
self.hass = hass
self._name = name
self._state = None
self._times = None
self._from = journey[2]
self._to = journey[3]
self._from = start
self._to = destination
self._websession = async_get_clientsession(self.hass)
self._opendata = OpendataTransport(
self._from, self._to, self.hass.loop, self._websession)
@property
def name(self):
@ -86,70 +90,45 @@ class SwissPublicTransportSensor(Entity):
@property
def state(self):
"""Return the state of the sensor."""
return self._state
return self._opendata.connections[0]['departure'] \
if self._opendata is not None else None
@property
def device_state_attributes(self):
"""Return the state attributes."""
if self._times is not None:
return {
ATTR_DEPARTURE_TIME1: self._times[0],
ATTR_DEPARTURE_TIME2: self._times[1],
ATTR_START: self._from,
ATTR_TARGET: self._to,
ATTR_REMAINING_TIME: '{}'.format(
':'.join(str(self._times[2]).split(':')[:2])),
ATTR_ATTRIBUTION: CONF_ATTRIBUTION,
}
if self._opendata is None:
return
remaining_time = dt_util.parse_datetime(
self._opendata.connections[0]['departure']) -\
dt_util.as_local(dt_util.utcnow())
attr = {
ATTR_TRAIN_NUMBER: self._opendata.connections[0]['number'],
ATTR_PLATFORM: self._opendata.connections[0]['platform'],
ATTR_TRANSFERS: self._opendata.connections[0]['transfers'],
ATTR_DURATION: self._opendata.connections[0]['duration'],
ATTR_DEPARTURE_TIME1: self._opendata.connections[1]['departure'],
ATTR_DEPARTURE_TIME2: self._opendata.connections[2]['departure'],
ATTR_START: self._opendata.from_name,
ATTR_TARGET: self._opendata.to_name,
ATTR_REMAINING_TIME: '{}'.format(remaining_time),
ATTR_ATTRIBUTION: CONF_ATTRIBUTION,
}
return attr
@property
def icon(self):
"""Icon to use in the frontend, if any."""
return ICON
def update(self):
@asyncio.coroutine
def async_update(self):
"""Get the latest data from opendata.ch and update the states."""
self.data.update()
self._times = self.data.times
try:
self._state = self._times[0]
except TypeError:
pass
class PublicTransportData(object):
"""The Class for handling the data retrieval."""
def __init__(self, journey):
"""Initialize the data object."""
self.start = journey[0]
self.destination = journey[1]
self.times = {}
def update(self):
"""Get the latest data from opendata.ch."""
response = requests.get(
_RESOURCE +
'connections?' +
'from=' + self.start + '&' +
'to=' + self.destination + '&' +
'fields[]=connections/from/departureTimestamp/&' +
'fields[]=connections/',
timeout=10)
connections = response.json()['connections'][1:3]
from opendata_transport.exceptions import OpendataTransportError
try:
self.times = [
dt_util.as_local(
dt_util.utc_from_timestamp(
item['from']['departureTimestamp'])).strftime(
TIME_STR_FORMAT)
for item in connections
]
self.times.append(
dt_util.as_local(
dt_util.utc_from_timestamp(
connections[0]['from']['departureTimestamp'])) -
dt_util.as_local(dt_util.utcnow()))
except KeyError:
self.times = ['n/a']
yield from self._opendata.async_get_data()
except OpendataTransportError:
_LOGGER.error("Unable to retrieve data from transport.opendata.ch")
self._opendata = None

View file

@ -799,6 +799,9 @@ python-vlc==1.1.2
# homeassistant.components.wink
python-wink==1.5.1
# homeassistant.components.sensor.swiss_public_transport
python_opendata_transport==0.0.2
# homeassistant.components.zwave
python_openzwave==0.4.0.31