This commit is contained in:
Paulus Schoutsen 2019-07-31 12:25:30 -07:00
parent da05dfe708
commit 4de97abc3a
2676 changed files with 163166 additions and 140084 deletions

View file

@ -7,9 +7,7 @@ import platform
import subprocess import subprocess
import sys import sys
import threading import threading
from typing import ( # noqa pylint: disable=unused-import from typing import List, Dict, Any, TYPE_CHECKING # noqa pylint: disable=unused-import
List, Dict, Any, TYPE_CHECKING
)
from homeassistant import monkey_patch from homeassistant import monkey_patch
from homeassistant.const import ( from homeassistant.const import (
@ -30,11 +28,12 @@ def set_loop() -> None:
policy = None policy = None
if sys.platform == 'win32': if sys.platform == "win32":
if hasattr(asyncio, 'WindowsProactorEventLoopPolicy'): if hasattr(asyncio, "WindowsProactorEventLoopPolicy"):
# pylint: disable=no-member # pylint: disable=no-member
policy = asyncio.WindowsProactorEventLoopPolicy() policy = asyncio.WindowsProactorEventLoopPolicy()
else: else:
class ProactorPolicy(BaseDefaultEventLoopPolicy): class ProactorPolicy(BaseDefaultEventLoopPolicy):
"""Event loop policy to create proactor loops.""" """Event loop policy to create proactor loops."""
@ -56,28 +55,40 @@ def set_loop() -> None:
def validate_python() -> None: def validate_python() -> None:
"""Validate that the right Python version is running.""" """Validate that the right Python version is running."""
if sys.version_info[:3] < REQUIRED_PYTHON_VER: if sys.version_info[:3] < REQUIRED_PYTHON_VER:
print("Home Assistant requires at least Python {}.{}.{}".format( print(
*REQUIRED_PYTHON_VER)) "Home Assistant requires at least Python {}.{}.{}".format(
*REQUIRED_PYTHON_VER
)
)
sys.exit(1) sys.exit(1)
def ensure_config_path(config_dir: str) -> None: def ensure_config_path(config_dir: str) -> None:
"""Validate the configuration directory.""" """Validate the configuration directory."""
import homeassistant.config as config_util import homeassistant.config as config_util
lib_dir = os.path.join(config_dir, 'deps')
lib_dir = os.path.join(config_dir, "deps")
# Test if configuration directory exists # Test if configuration directory exists
if not os.path.isdir(config_dir): if not os.path.isdir(config_dir):
if config_dir != config_util.get_default_config_dir(): if config_dir != config_util.get_default_config_dir():
print(('Fatal Error: Specified configuration directory does ' print(
'not exist {} ').format(config_dir)) (
"Fatal Error: Specified configuration directory does "
"not exist {} "
).format(config_dir)
)
sys.exit(1) sys.exit(1)
try: try:
os.mkdir(config_dir) os.mkdir(config_dir)
except OSError: except OSError:
print(('Fatal Error: Unable to create default configuration ' print(
'directory {} ').format(config_dir)) (
"Fatal Error: Unable to create default configuration "
"directory {} "
).format(config_dir)
)
sys.exit(1) sys.exit(1)
# Test if library directory exists # Test if library directory exists
@ -85,20 +96,22 @@ def ensure_config_path(config_dir: str) -> None:
try: try:
os.mkdir(lib_dir) os.mkdir(lib_dir)
except OSError: except OSError:
print(('Fatal Error: Unable to create library ' print(
'directory {} ').format(lib_dir)) ("Fatal Error: Unable to create library " "directory {} ").format(
lib_dir
)
)
sys.exit(1) sys.exit(1)
async def ensure_config_file(hass: 'core.HomeAssistant', config_dir: str) \ async def ensure_config_file(hass: "core.HomeAssistant", config_dir: str) -> str:
-> str:
"""Ensure configuration file exists.""" """Ensure configuration file exists."""
import homeassistant.config as config_util import homeassistant.config as config_util
config_path = await config_util.async_ensure_config_exists(
hass, config_dir) config_path = await config_util.async_ensure_config_exists(hass, config_dir)
if config_path is None: if config_path is None:
print('Error getting configuration path') print("Error getting configuration path")
sys.exit(1) sys.exit(1)
return config_path return config_path
@ -107,71 +120,72 @@ async def ensure_config_file(hass: 'core.HomeAssistant', config_dir: str) \
def get_arguments() -> argparse.Namespace: def get_arguments() -> argparse.Namespace:
"""Get parsed passed in arguments.""" """Get parsed passed in arguments."""
import homeassistant.config as config_util import homeassistant.config as config_util
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Home Assistant: Observe, Control, Automate.") description="Home Assistant: Observe, Control, Automate."
parser.add_argument('--version', action='version', version=__version__) )
parser.add_argument("--version", action="version", version=__version__)
parser.add_argument( parser.add_argument(
'-c', '--config', "-c",
metavar='path_to_config_dir', "--config",
metavar="path_to_config_dir",
default=config_util.get_default_config_dir(), default=config_util.get_default_config_dir(),
help="Directory that contains the Home Assistant configuration") help="Directory that contains the Home Assistant configuration",
)
parser.add_argument( parser.add_argument(
'--demo-mode', "--demo-mode", action="store_true", help="Start Home Assistant in demo mode"
action='store_true', )
help='Start Home Assistant in demo mode')
parser.add_argument( parser.add_argument(
'--debug', "--debug", action="store_true", help="Start Home Assistant in debug mode"
action='store_true', )
help='Start Home Assistant in debug mode')
parser.add_argument( parser.add_argument(
'--open-ui', "--open-ui", action="store_true", help="Open the webinterface in a browser"
action='store_true', )
help='Open the webinterface in a browser')
parser.add_argument( parser.add_argument(
'--skip-pip', "--skip-pip",
action='store_true', action="store_true",
help='Skips pip install of required packages on startup') help="Skips pip install of required packages on startup",
)
parser.add_argument( parser.add_argument(
'-v', '--verbose', "-v", "--verbose", action="store_true", help="Enable verbose logging to file."
action='store_true', )
help="Enable verbose logging to file.")
parser.add_argument( parser.add_argument(
'--pid-file', "--pid-file",
metavar='path_to_pid_file', metavar="path_to_pid_file",
default=None, default=None,
help='Path to PID file useful for running as daemon') help="Path to PID file useful for running as daemon",
)
parser.add_argument( parser.add_argument(
'--log-rotate-days', "--log-rotate-days",
type=int, type=int,
default=None, default=None,
help='Enables daily log rotation and keeps up to the specified days') help="Enables daily log rotation and keeps up to the specified days",
)
parser.add_argument( parser.add_argument(
'--log-file', "--log-file",
type=str, type=str,
default=None, default=None,
help='Log file to write to. If not set, CONFIG/home-assistant.log ' help="Log file to write to. If not set, CONFIG/home-assistant.log " "is used",
'is used') )
parser.add_argument( parser.add_argument(
'--log-no-color', "--log-no-color", action="store_true", help="Disable color logs"
action='store_true', )
help="Disable color logs")
parser.add_argument( parser.add_argument(
'--runner', "--runner",
action='store_true', action="store_true",
help='On restart exit with code {}'.format(RESTART_EXIT_CODE)) help="On restart exit with code {}".format(RESTART_EXIT_CODE),
)
parser.add_argument( parser.add_argument(
'--script', "--script", nargs=argparse.REMAINDER, help="Run one of the embedded scripts"
nargs=argparse.REMAINDER, )
help='Run one of the embedded scripts')
if os.name == "posix": if os.name == "posix":
parser.add_argument( parser.add_argument(
'--daemon', "--daemon", action="store_true", help="Run Home Assistant as daemon"
action='store_true', )
help='Run Home Assistant as daemon')
arguments = parser.parse_args() arguments = parser.parse_args()
if os.name != "posix" or arguments.debug or arguments.runner: if os.name != "posix" or arguments.debug or arguments.runner:
setattr(arguments, 'daemon', False) setattr(arguments, "daemon", False)
return arguments return arguments
@ -192,8 +206,8 @@ def daemonize() -> None:
sys.exit(0) sys.exit(0)
# redirect standard file descriptors to devnull # redirect standard file descriptors to devnull
infd = open(os.devnull, 'r') infd = open(os.devnull, "r")
outfd = open(os.devnull, 'a+') outfd = open(os.devnull, "a+")
sys.stdout.flush() sys.stdout.flush()
sys.stderr.flush() sys.stderr.flush()
os.dup2(infd.fileno(), sys.stdin.fileno()) os.dup2(infd.fileno(), sys.stdin.fileno())
@ -205,7 +219,7 @@ def check_pid(pid_file: str) -> None:
"""Check that Home Assistant is not already running.""" """Check that Home Assistant is not already running."""
# Check pid file # Check pid file
try: try:
with open(pid_file, 'r') as file: with open(pid_file, "r") as file:
pid = int(file.readline()) pid = int(file.readline())
except IOError: except IOError:
# PID File does not exist # PID File does not exist
@ -220,7 +234,7 @@ def check_pid(pid_file: str) -> None:
except OSError: except OSError:
# PID does not exist # PID does not exist
return return
print('Fatal Error: HomeAssistant is already running.') print("Fatal Error: HomeAssistant is already running.")
sys.exit(1) sys.exit(1)
@ -228,10 +242,10 @@ def write_pid(pid_file: str) -> None:
"""Create a PID File.""" """Create a PID File."""
pid = os.getpid() pid = os.getpid()
try: try:
with open(pid_file, 'w') as file: with open(pid_file, "w") as file:
file.write(str(pid)) file.write(str(pid))
except IOError: except IOError:
print('Fatal Error: Unable to write pid file {}'.format(pid_file)) print("Fatal Error: Unable to write pid file {}".format(pid_file))
sys.exit(1) sys.exit(1)
@ -255,17 +269,15 @@ def closefds_osx(min_fd: int, max_fd: int) -> None:
def cmdline() -> List[str]: def cmdline() -> List[str]:
"""Collect path and arguments to re-execute the current hass instance.""" """Collect path and arguments to re-execute the current hass instance."""
if os.path.basename(sys.argv[0]) == '__main__.py': if os.path.basename(sys.argv[0]) == "__main__.py":
modulepath = os.path.dirname(sys.argv[0]) modulepath = os.path.dirname(sys.argv[0])
os.environ['PYTHONPATH'] = os.path.dirname(modulepath) os.environ["PYTHONPATH"] = os.path.dirname(modulepath)
return [sys.executable] + [arg for arg in sys.argv if return [sys.executable] + [arg for arg in sys.argv if arg != "--daemon"]
arg != '--daemon']
return [arg for arg in sys.argv if arg != '--daemon'] return [arg for arg in sys.argv if arg != "--daemon"]
async def setup_and_run_hass(config_dir: str, async def setup_and_run_hass(config_dir: str, args: argparse.Namespace) -> int:
args: argparse.Namespace) -> int:
"""Set up HASS and run.""" """Set up HASS and run."""
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
from homeassistant import bootstrap, core from homeassistant import bootstrap, core
@ -273,21 +285,29 @@ async def setup_and_run_hass(config_dir: str,
hass = core.HomeAssistant() hass = core.HomeAssistant()
if args.demo_mode: if args.demo_mode:
config = { config = {"frontend": {}, "demo": {}} # type: Dict[str, Any]
'frontend': {},
'demo': {}
} # type: Dict[str, Any]
bootstrap.async_from_config_dict( bootstrap.async_from_config_dict(
config, hass, config_dir=config_dir, verbose=args.verbose, config,
skip_pip=args.skip_pip, log_rotate_days=args.log_rotate_days, hass,
log_file=args.log_file, log_no_color=args.log_no_color) config_dir=config_dir,
verbose=args.verbose,
skip_pip=args.skip_pip,
log_rotate_days=args.log_rotate_days,
log_file=args.log_file,
log_no_color=args.log_no_color,
)
else: else:
config_file = await ensure_config_file(hass, config_dir) config_file = await ensure_config_file(hass, config_dir)
print('Config directory:', config_dir) print("Config directory:", config_dir)
await bootstrap.async_from_config_file( await bootstrap.async_from_config_file(
config_file, hass, verbose=args.verbose, skip_pip=args.skip_pip, config_file,
log_rotate_days=args.log_rotate_days, log_file=args.log_file, hass,
log_no_color=args.log_no_color) verbose=args.verbose,
skip_pip=args.skip_pip,
log_rotate_days=args.log_rotate_days,
log_file=args.log_file,
log_no_color=args.log_no_color,
)
if args.open_ui: if args.open_ui:
# Imported here to avoid importing asyncio before monkey patch # Imported here to avoid importing asyncio before monkey patch
@ -297,12 +317,14 @@ async def setup_and_run_hass(config_dir: str,
"""Open the web interface in a browser.""" """Open the web interface in a browser."""
if hass.config.api is not None: if hass.config.api is not None:
import webbrowser import webbrowser
webbrowser.open(hass.config.api.base_url) webbrowser.open(hass.config.api.base_url)
run_callback_threadsafe( run_callback_threadsafe(
hass.loop, hass.loop,
hass.bus.async_listen_once, hass.bus.async_listen_once,
EVENT_HOMEASSISTANT_START, open_browser EVENT_HOMEASSISTANT_START,
open_browser,
) )
return await hass.async_run() return await hass.async_run()
@ -312,17 +334,17 @@ def try_to_restart() -> None:
"""Attempt to clean up state and start a new Home Assistant instance.""" """Attempt to clean up state and start a new Home Assistant instance."""
# Things should be mostly shut down already at this point, now just try # Things should be mostly shut down already at this point, now just try
# to clean up things that may have been left behind. # to clean up things that may have been left behind.
sys.stderr.write('Home Assistant attempting to restart.\n') sys.stderr.write("Home Assistant attempting to restart.\n")
# Count remaining threads, ideally there should only be one non-daemonized # Count remaining threads, ideally there should only be one non-daemonized
# thread left (which is us). Nothing we really do with it, but it might be # thread left (which is us). Nothing we really do with it, but it might be
# useful when debugging shutdown/restart issues. # useful when debugging shutdown/restart issues.
try: try:
nthreads = sum(thread.is_alive() and not thread.daemon nthreads = sum(
for thread in threading.enumerate()) thread.is_alive() and not thread.daemon for thread in threading.enumerate()
)
if nthreads > 1: if nthreads > 1:
sys.stderr.write( sys.stderr.write("Found {} non-daemonic threads.\n".format(nthreads))
"Found {} non-daemonic threads.\n".format(nthreads))
# Somehow we sometimes seem to trigger an assertion in the python threading # Somehow we sometimes seem to trigger an assertion in the python threading
# module. It seems we find threads that have no associated OS level thread # module. It seems we find threads that have no associated OS level thread
@ -336,7 +358,7 @@ def try_to_restart() -> None:
except ValueError: except ValueError:
max_fd = 256 max_fd = 256
if platform.system() == 'Darwin': if platform.system() == "Darwin":
closefds_osx(3, max_fd) closefds_osx(3, max_fd)
else: else:
os.closerange(3, max_fd) os.closerange(3, max_fd)
@ -355,15 +377,15 @@ def main() -> int:
validate_python() validate_python()
monkey_patch_needed = sys.version_info[:3] < (3, 6, 3) monkey_patch_needed = sys.version_info[:3] < (3, 6, 3)
if monkey_patch_needed and os.environ.get('HASS_NO_MONKEY') != '1': if monkey_patch_needed and os.environ.get("HASS_NO_MONKEY") != "1":
monkey_patch.disable_c_asyncio() monkey_patch.disable_c_asyncio()
monkey_patch.patch_weakref_tasks() monkey_patch.patch_weakref_tasks()
set_loop() set_loop()
# Run a simple daemon runner process on Windows to handle restarts # Run a simple daemon runner process on Windows to handle restarts
if os.name == 'nt' and '--runner' not in sys.argv: if os.name == "nt" and "--runner" not in sys.argv:
nt_args = cmdline() + ['--runner'] nt_args = cmdline() + ["--runner"]
while True: while True:
try: try:
subprocess.check_call(nt_args) subprocess.check_call(nt_args)
@ -378,6 +400,7 @@ def main() -> int:
if args.script is not None: if args.script is not None:
from homeassistant import scripts from homeassistant import scripts
return scripts.run(args.script) return scripts.run(args.script)
config_dir = os.path.join(os.getcwd(), args.config) config_dir = os.path.join(os.getcwd(), args.config)
@ -392,6 +415,7 @@ def main() -> int:
write_pid(args.pid_file) write_pid(args.pid_file)
from homeassistant.util.async_ import asyncio_run from homeassistant.util.async_ import asyncio_run
exit_code = asyncio_run(setup_and_run_hass(config_dir, args)) exit_code = asyncio_run(setup_and_run_hass(config_dir, args))
if exit_code == RESTART_EXIT_CODE and not args.runner: if exit_code == RESTART_EXIT_CODE and not args.runner:
try_to_restart() try_to_restart()

View file

@ -17,8 +17,8 @@ from .const import GROUP_ID_ADMIN
from .mfa_modules import auth_mfa_module_from_config, MultiFactorAuthModule from .mfa_modules import auth_mfa_module_from_config, MultiFactorAuthModule
from .providers import auth_provider_from_config, AuthProvider, LoginFlow from .providers import auth_provider_from_config, AuthProvider, LoginFlow
EVENT_USER_ADDED = 'user_added' EVENT_USER_ADDED = "user_added"
EVENT_USER_REMOVED = 'user_removed' EVENT_USER_REMOVED = "user_removed"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_MfaModuleDict = Dict[str, MultiFactorAuthModule] _MfaModuleDict = Dict[str, MultiFactorAuthModule]
@ -27,9 +27,10 @@ _ProviderDict = Dict[_ProviderKey, AuthProvider]
async def auth_manager_from_config( async def auth_manager_from_config(
hass: HomeAssistant, hass: HomeAssistant,
provider_configs: List[Dict[str, Any]], provider_configs: List[Dict[str, Any]],
module_configs: List[Dict[str, Any]]) -> 'AuthManager': module_configs: List[Dict[str, Any]],
) -> "AuthManager":
"""Initialize an auth manager from config. """Initialize an auth manager from config.
CORE_CONFIG_SCHEMA will make sure do duplicated auth providers or CORE_CONFIG_SCHEMA will make sure do duplicated auth providers or
@ -38,8 +39,11 @@ async def auth_manager_from_config(
store = auth_store.AuthStore(hass) store = auth_store.AuthStore(hass)
if provider_configs: if provider_configs:
providers = await asyncio.gather( providers = await asyncio.gather(
*(auth_provider_from_config(hass, store, config) *(
for config in provider_configs)) auth_provider_from_config(hass, store, config)
for config in provider_configs
)
)
else: else:
providers = () providers = ()
# So returned auth providers are in same order as config # So returned auth providers are in same order as config
@ -50,8 +54,8 @@ async def auth_manager_from_config(
if module_configs: if module_configs:
modules = await asyncio.gather( modules = await asyncio.gather(
*(auth_mfa_module_from_config(hass, config) *(auth_mfa_module_from_config(hass, config) for config in module_configs)
for config in module_configs)) )
else: else:
modules = () modules = ()
# So returned auth modules are in same order as config # So returned auth modules are in same order as config
@ -66,17 +70,21 @@ async def auth_manager_from_config(
class AuthManager: class AuthManager:
"""Manage the authentication for Home Assistant.""" """Manage the authentication for Home Assistant."""
def __init__(self, hass: HomeAssistant, store: auth_store.AuthStore, def __init__(
providers: _ProviderDict, mfa_modules: _MfaModuleDict) \ self,
-> None: hass: HomeAssistant,
store: auth_store.AuthStore,
providers: _ProviderDict,
mfa_modules: _MfaModuleDict,
) -> None:
"""Initialize the auth manager.""" """Initialize the auth manager."""
self.hass = hass self.hass = hass
self._store = store self._store = store
self._providers = providers self._providers = providers
self._mfa_modules = mfa_modules self._mfa_modules = mfa_modules
self.login_flow = data_entry_flow.FlowManager( self.login_flow = data_entry_flow.FlowManager(
hass, self._async_create_login_flow, hass, self._async_create_login_flow, self._async_finish_login_flow
self._async_finish_login_flow) )
@property @property
def support_legacy(self) -> bool: def support_legacy(self) -> bool:
@ -86,7 +94,7 @@ class AuthManager:
Should be removed when we removed legacy_api_password auth providers. Should be removed when we removed legacy_api_password auth providers.
""" """
for provider_type, _ in self._providers: for provider_type, _ in self._providers:
if provider_type == 'legacy_api_password': if provider_type == "legacy_api_password":
return True return True
return False return False
@ -100,20 +108,21 @@ class AuthManager:
"""Return a list of available auth modules.""" """Return a list of available auth modules."""
return list(self._mfa_modules.values()) return list(self._mfa_modules.values())
def get_auth_provider(self, provider_type: str, provider_id: str) \ def get_auth_provider(
-> Optional[AuthProvider]: self, provider_type: str, provider_id: str
) -> Optional[AuthProvider]:
"""Return an auth provider, None if not found.""" """Return an auth provider, None if not found."""
return self._providers.get((provider_type, provider_id)) return self._providers.get((provider_type, provider_id))
def get_auth_providers(self, provider_type: str) \ def get_auth_providers(self, provider_type: str) -> List[AuthProvider]:
-> List[AuthProvider]:
"""Return a List of auth provider of one type, Empty if not found.""" """Return a List of auth provider of one type, Empty if not found."""
return [provider return [
for (p_type, _), provider in self._providers.items() provider
if p_type == provider_type] for (p_type, _), provider in self._providers.items()
if p_type == provider_type
]
def get_auth_mfa_module(self, module_id: str) \ def get_auth_mfa_module(self, module_id: str) -> Optional[MultiFactorAuthModule]:
-> Optional[MultiFactorAuthModule]:
"""Return a multi-factor auth module, None if not found.""" """Return a multi-factor auth module, None if not found."""
return self._mfa_modules.get(module_id) return self._mfa_modules.get(module_id)
@ -135,7 +144,8 @@ class AuthManager:
return await self._store.async_get_group(group_id) return await self._store.async_get_group(group_id)
async def async_get_user_by_credentials( async def async_get_user_by_credentials(
self, credentials: models.Credentials) -> Optional[models.User]: self, credentials: models.Credentials
) -> Optional[models.User]:
"""Get a user by credential, return None if not found.""" """Get a user by credential, return None if not found."""
for user in await self.async_get_users(): for user in await self.async_get_users():
for creds in user.credentials: for creds in user.credentials:
@ -145,57 +155,50 @@ class AuthManager:
return None return None
async def async_create_system_user( async def async_create_system_user(
self, name: str, self, name: str, group_ids: Optional[List[str]] = None
group_ids: Optional[List[str]] = None) -> models.User: ) -> models.User:
"""Create a system user.""" """Create a system user."""
user = await self._store.async_create_user( user = await self._store.async_create_user(
name=name, name=name, system_generated=True, is_active=True, group_ids=group_ids or []
system_generated=True,
is_active=True,
group_ids=group_ids or [],
) )
self.hass.bus.async_fire(EVENT_USER_ADDED, { self.hass.bus.async_fire(EVENT_USER_ADDED, {"user_id": user.id})
'user_id': user.id
})
return user return user
async def async_create_user(self, name: str) -> models.User: async def async_create_user(self, name: str) -> models.User:
"""Create a user.""" """Create a user."""
kwargs = { kwargs = {
'name': name, "name": name,
'is_active': True, "is_active": True,
'group_ids': [GROUP_ID_ADMIN] "group_ids": [GROUP_ID_ADMIN],
} # type: Dict[str, Any] } # type: Dict[str, Any]
if await self._user_should_be_owner(): if await self._user_should_be_owner():
kwargs['is_owner'] = True kwargs["is_owner"] = True
user = await self._store.async_create_user(**kwargs) user = await self._store.async_create_user(**kwargs)
self.hass.bus.async_fire(EVENT_USER_ADDED, { self.hass.bus.async_fire(EVENT_USER_ADDED, {"user_id": user.id})
'user_id': user.id
})
return user return user
async def async_get_or_create_user(self, credentials: models.Credentials) \ async def async_get_or_create_user(
-> models.User: self, credentials: models.Credentials
) -> models.User:
"""Get or create a user.""" """Get or create a user."""
if not credentials.is_new: if not credentials.is_new:
user = await self.async_get_user_by_credentials(credentials) user = await self.async_get_user_by_credentials(credentials)
if user is None: if user is None:
raise ValueError('Unable to find the user.') raise ValueError("Unable to find the user.")
return user return user
auth_provider = self._async_get_auth_provider(credentials) auth_provider = self._async_get_auth_provider(credentials)
if auth_provider is None: if auth_provider is None:
raise RuntimeError('Credential with unknown provider encountered') raise RuntimeError("Credential with unknown provider encountered")
info = await auth_provider.async_user_meta_for_credentials( info = await auth_provider.async_user_meta_for_credentials(credentials)
credentials)
user = await self._store.async_create_user( user = await self._store.async_create_user(
credentials=credentials, credentials=credentials,
@ -204,14 +207,13 @@ class AuthManager:
group_ids=[GROUP_ID_ADMIN], group_ids=[GROUP_ID_ADMIN],
) )
self.hass.bus.async_fire(EVENT_USER_ADDED, { self.hass.bus.async_fire(EVENT_USER_ADDED, {"user_id": user.id})
'user_id': user.id
})
return user return user
async def async_link_user(self, user: models.User, async def async_link_user(
credentials: models.Credentials) -> None: self, user: models.User, credentials: models.Credentials
) -> None:
"""Link credentials to an existing user.""" """Link credentials to an existing user."""
await self._store.async_link_user(user, credentials) await self._store.async_link_user(user, credentials)
@ -227,19 +229,20 @@ class AuthManager:
await self._store.async_remove_user(user) await self._store.async_remove_user(user)
self.hass.bus.async_fire(EVENT_USER_REMOVED, { self.hass.bus.async_fire(EVENT_USER_REMOVED, {"user_id": user.id})
'user_id': user.id
})
async def async_update_user(self, user: models.User, async def async_update_user(
name: Optional[str] = None, self,
group_ids: Optional[List[str]] = None) -> None: user: models.User,
name: Optional[str] = None,
group_ids: Optional[List[str]] = None,
) -> None:
"""Update a user.""" """Update a user."""
kwargs = {} # type: Dict[str,Any] kwargs = {} # type: Dict[str,Any]
if name is not None: if name is not None:
kwargs['name'] = name kwargs["name"] = name
if group_ids is not None: if group_ids is not None:
kwargs['group_ids'] = group_ids kwargs["group_ids"] = group_ids
await self._store.async_update_user(user, **kwargs) await self._store.async_update_user(user, **kwargs)
async def async_activate_user(self, user: models.User) -> None: async def async_activate_user(self, user: models.User) -> None:
@ -249,47 +252,52 @@ class AuthManager:
async def async_deactivate_user(self, user: models.User) -> None: async def async_deactivate_user(self, user: models.User) -> None:
"""Deactivate a user.""" """Deactivate a user."""
if user.is_owner: if user.is_owner:
raise ValueError('Unable to deactive the owner') raise ValueError("Unable to deactive the owner")
await self._store.async_deactivate_user(user) await self._store.async_deactivate_user(user)
async def async_remove_credentials( async def async_remove_credentials(self, credentials: models.Credentials) -> None:
self, credentials: models.Credentials) -> None:
"""Remove credentials.""" """Remove credentials."""
provider = self._async_get_auth_provider(credentials) provider = self._async_get_auth_provider(credentials)
if (provider is not None and if provider is not None and hasattr(provider, "async_will_remove_credentials"):
hasattr(provider, 'async_will_remove_credentials')):
# https://github.com/python/mypy/issues/1424 # https://github.com/python/mypy/issues/1424
await provider.async_will_remove_credentials( # type: ignore await provider.async_will_remove_credentials( # type: ignore
credentials) credentials
)
await self._store.async_remove_credentials(credentials) await self._store.async_remove_credentials(credentials)
async def async_enable_user_mfa(self, user: models.User, async def async_enable_user_mfa(
mfa_module_id: str, data: Any) -> None: self, user: models.User, mfa_module_id: str, data: Any
) -> None:
"""Enable a multi-factor auth module for user.""" """Enable a multi-factor auth module for user."""
if user.system_generated: if user.system_generated:
raise ValueError('System generated users cannot enable ' raise ValueError(
'multi-factor auth module.') "System generated users cannot enable " "multi-factor auth module."
)
module = self.get_auth_mfa_module(mfa_module_id) module = self.get_auth_mfa_module(mfa_module_id)
if module is None: if module is None:
raise ValueError('Unable find multi-factor auth module: {}' raise ValueError(
.format(mfa_module_id)) "Unable find multi-factor auth module: {}".format(mfa_module_id)
)
await module.async_setup_user(user.id, data) await module.async_setup_user(user.id, data)
async def async_disable_user_mfa(self, user: models.User, async def async_disable_user_mfa(
mfa_module_id: str) -> None: self, user: models.User, mfa_module_id: str
) -> None:
"""Disable a multi-factor auth module for user.""" """Disable a multi-factor auth module for user."""
if user.system_generated: if user.system_generated:
raise ValueError('System generated users cannot disable ' raise ValueError(
'multi-factor auth module.') "System generated users cannot disable " "multi-factor auth module."
)
module = self.get_auth_mfa_module(mfa_module_id) module = self.get_auth_mfa_module(mfa_module_id)
if module is None: if module is None:
raise ValueError('Unable find multi-factor auth module: {}' raise ValueError(
.format(mfa_module_id)) "Unable find multi-factor auth module: {}".format(mfa_module_id)
)
await module.async_depose_user(user.id) await module.async_depose_user(user.id)
@ -302,20 +310,23 @@ class AuthManager:
return modules return modules
async def async_create_refresh_token( async def async_create_refresh_token(
self, user: models.User, client_id: Optional[str] = None, self,
client_name: Optional[str] = None, user: models.User,
client_icon: Optional[str] = None, client_id: Optional[str] = None,
token_type: Optional[str] = None, client_name: Optional[str] = None,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION) \ client_icon: Optional[str] = None,
-> models.RefreshToken: token_type: Optional[str] = None,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
) -> models.RefreshToken:
"""Create a new refresh token for a user.""" """Create a new refresh token for a user."""
if not user.is_active: if not user.is_active:
raise ValueError('User is not active') raise ValueError("User is not active")
if user.system_generated and client_id is not None: if user.system_generated and client_id is not None:
raise ValueError( raise ValueError(
'System generated users cannot have refresh tokens connected ' "System generated users cannot have refresh tokens connected "
'to a client.') "to a client."
)
if token_type is None: if token_type is None:
if user.system_generated: if user.system_generated:
@ -325,61 +336,76 @@ class AuthManager:
if user.system_generated != (token_type == models.TOKEN_TYPE_SYSTEM): if user.system_generated != (token_type == models.TOKEN_TYPE_SYSTEM):
raise ValueError( raise ValueError(
'System generated users can only have system type ' "System generated users can only have system type " "refresh tokens"
'refresh tokens') )
if token_type == models.TOKEN_TYPE_NORMAL and client_id is None: if token_type == models.TOKEN_TYPE_NORMAL and client_id is None:
raise ValueError('Client is required to generate a refresh token.') raise ValueError("Client is required to generate a refresh token.")
if (token_type == models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN and if (
client_name is None): token_type == models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
raise ValueError('Client_name is required for long-lived access ' and client_name is None
'token') ):
raise ValueError("Client_name is required for long-lived access " "token")
if token_type == models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN: if token_type == models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN:
for token in user.refresh_tokens.values(): for token in user.refresh_tokens.values():
if (token.client_name == client_name and token.token_type == if (
models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN): token.client_name == client_name
and token.token_type == models.TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN
):
# Each client_name can only have one # Each client_name can only have one
# long_lived_access_token type of refresh token # long_lived_access_token type of refresh token
raise ValueError('{} already exists'.format(client_name)) raise ValueError("{} already exists".format(client_name))
return await self._store.async_create_refresh_token( return await self._store.async_create_refresh_token(
user, client_id, client_name, client_icon, user,
token_type, access_token_expiration) client_id,
client_name,
client_icon,
token_type,
access_token_expiration,
)
async def async_get_refresh_token( async def async_get_refresh_token(
self, token_id: str) -> Optional[models.RefreshToken]: self, token_id: str
) -> Optional[models.RefreshToken]:
"""Get refresh token by id.""" """Get refresh token by id."""
return await self._store.async_get_refresh_token(token_id) return await self._store.async_get_refresh_token(token_id)
async def async_get_refresh_token_by_token( async def async_get_refresh_token_by_token(
self, token: str) -> Optional[models.RefreshToken]: self, token: str
) -> Optional[models.RefreshToken]:
"""Get refresh token by token.""" """Get refresh token by token."""
return await self._store.async_get_refresh_token_by_token(token) return await self._store.async_get_refresh_token_by_token(token)
async def async_remove_refresh_token(self, async def async_remove_refresh_token(
refresh_token: models.RefreshToken) \ self, refresh_token: models.RefreshToken
-> None: ) -> None:
"""Delete a refresh token.""" """Delete a refresh token."""
await self._store.async_remove_refresh_token(refresh_token) await self._store.async_remove_refresh_token(refresh_token)
@callback @callback
def async_create_access_token(self, def async_create_access_token(
refresh_token: models.RefreshToken, self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
remote_ip: Optional[str] = None) -> str: ) -> str:
"""Create a new access token.""" """Create a new access token."""
self._store.async_log_refresh_token_usage(refresh_token, remote_ip) self._store.async_log_refresh_token_usage(refresh_token, remote_ip)
now = dt_util.utcnow() now = dt_util.utcnow()
return jwt.encode({ return jwt.encode(
'iss': refresh_token.id, {
'iat': now, "iss": refresh_token.id,
'exp': now + refresh_token.access_token_expiration, "iat": now,
}, refresh_token.jwt_key, algorithm='HS256').decode() "exp": now + refresh_token.access_token_expiration,
},
refresh_token.jwt_key,
algorithm="HS256",
).decode()
async def async_validate_access_token( async def async_validate_access_token(
self, token: str) -> Optional[models.RefreshToken]: self, token: str
) -> Optional[models.RefreshToken]:
"""Return refresh token if an access token is valid.""" """Return refresh token if an access token is valid."""
try: try:
unverif_claims = jwt.decode(token, verify=False) unverif_claims = jwt.decode(token, verify=False)
@ -387,23 +413,18 @@ class AuthManager:
return None return None
refresh_token = await self.async_get_refresh_token( refresh_token = await self.async_get_refresh_token(
cast(str, unverif_claims.get('iss'))) cast(str, unverif_claims.get("iss"))
)
if refresh_token is None: if refresh_token is None:
jwt_key = '' jwt_key = ""
issuer = '' issuer = ""
else: else:
jwt_key = refresh_token.jwt_key jwt_key = refresh_token.jwt_key
issuer = refresh_token.id issuer = refresh_token.id
try: try:
jwt.decode( jwt.decode(token, jwt_key, leeway=10, issuer=issuer, algorithms=["HS256"])
token,
jwt_key,
leeway=10,
issuer=issuer,
algorithms=['HS256']
)
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
return None return None
@ -413,31 +434,32 @@ class AuthManager:
return refresh_token return refresh_token
async def _async_create_login_flow( async def _async_create_login_flow(
self, handler: _ProviderKey, *, context: Optional[Dict], self, handler: _ProviderKey, *, context: Optional[Dict], data: Optional[Any]
data: Optional[Any]) -> data_entry_flow.FlowHandler: ) -> data_entry_flow.FlowHandler:
"""Create a login flow.""" """Create a login flow."""
auth_provider = self._providers[handler] auth_provider = self._providers[handler]
return await auth_provider.async_login_flow(context) return await auth_provider.async_login_flow(context)
async def _async_finish_login_flow( async def _async_finish_login_flow(
self, flow: LoginFlow, result: Dict[str, Any]) \ self, flow: LoginFlow, result: Dict[str, Any]
-> Dict[str, Any]: ) -> Dict[str, Any]:
"""Return a user as result of login flow.""" """Return a user as result of login flow."""
if result['type'] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY: if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
return result return result
# we got final result # we got final result
if isinstance(result['data'], models.User): if isinstance(result["data"], models.User):
result['result'] = result['data'] result["result"] = result["data"]
return result return result
auth_provider = self._providers[result['handler']] auth_provider = self._providers[result["handler"]]
credentials = await auth_provider.async_get_or_create_credentials( credentials = await auth_provider.async_get_or_create_credentials(
result['data']) result["data"]
)
if flow.context is not None and flow.context.get('credential_only'): if flow.context is not None and flow.context.get("credential_only"):
result['result'] = credentials result["result"] = credentials
return result return result
# multi-factor module cannot enabled for new credential # multi-factor module cannot enabled for new credential
@ -452,15 +474,18 @@ class AuthManager:
flow.available_mfa_modules = modules flow.available_mfa_modules = modules
return await flow.async_step_select_mfa_module() return await flow.async_step_select_mfa_module()
result['result'] = await self.async_get_or_create_user(credentials) result["result"] = await self.async_get_or_create_user(credentials)
return result return result
@callback @callback
def _async_get_auth_provider( def _async_get_auth_provider(
self, credentials: models.Credentials) -> Optional[AuthProvider]: self, credentials: models.Credentials
) -> Optional[AuthProvider]:
"""Get auth provider from a set of credentials.""" """Get auth provider from a set of credentials."""
auth_provider_key = (credentials.auth_provider_type, auth_provider_key = (
credentials.auth_provider_id) credentials.auth_provider_type,
credentials.auth_provider_id,
)
return self._providers.get(auth_provider_key) return self._providers.get(auth_provider_key)
async def _user_should_be_owner(self) -> bool: async def _user_should_be_owner(self) -> bool:

View file

@ -16,10 +16,10 @@ from .permissions import PermissionLookup, system_policies
from .permissions.types import PolicyType # noqa: F401 from .permissions.types import PolicyType # noqa: F401
STORAGE_VERSION = 1 STORAGE_VERSION = 1
STORAGE_KEY = 'auth' STORAGE_KEY = "auth"
GROUP_NAME_ADMIN = 'Administrators' GROUP_NAME_ADMIN = "Administrators"
GROUP_NAME_USER = "Users" GROUP_NAME_USER = "Users"
GROUP_NAME_READ_ONLY = 'Read Only' GROUP_NAME_READ_ONLY = "Read Only"
class AuthStore: class AuthStore:
@ -37,8 +37,9 @@ class AuthStore:
self._users = None # type: Optional[Dict[str, models.User]] self._users = None # type: Optional[Dict[str, models.User]]
self._groups = None # type: Optional[Dict[str, models.Group]] self._groups = None # type: Optional[Dict[str, models.Group]]
self._perm_lookup = None # type: Optional[PermissionLookup] self._perm_lookup = None # type: Optional[PermissionLookup]
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY, self._store = hass.helpers.storage.Store(
private=True) STORAGE_VERSION, STORAGE_KEY, private=True
)
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
async def async_get_groups(self) -> List[models.Group]: async def async_get_groups(self) -> List[models.Group]:
@ -74,11 +75,14 @@ class AuthStore:
return self._users.get(user_id) return self._users.get(user_id)
async def async_create_user( async def async_create_user(
self, name: Optional[str], is_owner: Optional[bool] = None, self,
is_active: Optional[bool] = None, name: Optional[str],
system_generated: Optional[bool] = None, is_owner: Optional[bool] = None,
credentials: Optional[models.Credentials] = None, is_active: Optional[bool] = None,
group_ids: Optional[List[str]] = None) -> models.User: system_generated: Optional[bool] = None,
credentials: Optional[models.Credentials] = None,
group_ids: Optional[List[str]] = None,
) -> models.User:
"""Create a new user.""" """Create a new user."""
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
@ -87,28 +91,28 @@ class AuthStore:
assert self._groups is not None assert self._groups is not None
groups = [] groups = []
for group_id in (group_ids or []): for group_id in group_ids or []:
group = self._groups.get(group_id) group = self._groups.get(group_id)
if group is None: if group is None:
raise ValueError('Invalid group specified {}'.format(group_id)) raise ValueError("Invalid group specified {}".format(group_id))
groups.append(group) groups.append(group)
kwargs = { kwargs = {
'name': name, "name": name,
# Until we get group management, we just put everyone in the # Until we get group management, we just put everyone in the
# same group. # same group.
'groups': groups, "groups": groups,
'perm_lookup': self._perm_lookup, "perm_lookup": self._perm_lookup,
} # type: Dict[str, Any] } # type: Dict[str, Any]
if is_owner is not None: if is_owner is not None:
kwargs['is_owner'] = is_owner kwargs["is_owner"] = is_owner
if is_active is not None: if is_active is not None:
kwargs['is_active'] = is_active kwargs["is_active"] = is_active
if system_generated is not None: if system_generated is not None:
kwargs['system_generated'] = system_generated kwargs["system_generated"] = system_generated
new_user = models.User(**kwargs) new_user = models.User(**kwargs)
@ -122,8 +126,9 @@ class AuthStore:
await self.async_link_user(new_user, credentials) await self.async_link_user(new_user, credentials)
return new_user return new_user
async def async_link_user(self, user: models.User, async def async_link_user(
credentials: models.Credentials) -> None: self, user: models.User, credentials: models.Credentials
) -> None:
"""Add credentials to an existing user.""" """Add credentials to an existing user."""
user.credentials.append(credentials) user.credentials.append(credentials)
self._async_schedule_save() self._async_schedule_save()
@ -139,9 +144,12 @@ class AuthStore:
self._async_schedule_save() self._async_schedule_save()
async def async_update_user( async def async_update_user(
self, user: models.User, name: Optional[str] = None, self,
is_active: Optional[bool] = None, user: models.User,
group_ids: Optional[List[str]] = None) -> None: name: Optional[str] = None,
is_active: Optional[bool] = None,
group_ids: Optional[List[str]] = None,
) -> None:
"""Update a user.""" """Update a user."""
assert self._groups is not None assert self._groups is not None
@ -156,10 +164,7 @@ class AuthStore:
user.groups = groups user.groups = groups
user.invalidate_permission_cache() user.invalidate_permission_cache()
for attr_name, value in ( for attr_name, value in (("name", name), ("is_active", is_active)):
('name', name),
('is_active', is_active),
):
if value is not None: if value is not None:
setattr(user, attr_name, value) setattr(user, attr_name, value)
@ -175,8 +180,7 @@ class AuthStore:
user.is_active = False user.is_active = False
self._async_schedule_save() self._async_schedule_save()
async def async_remove_credentials( async def async_remove_credentials(self, credentials: models.Credentials) -> None:
self, credentials: models.Credentials) -> None:
"""Remove credentials.""" """Remove credentials."""
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
@ -197,23 +201,25 @@ class AuthStore:
self._async_schedule_save() self._async_schedule_save()
async def async_create_refresh_token( async def async_create_refresh_token(
self, user: models.User, client_id: Optional[str] = None, self,
client_name: Optional[str] = None, user: models.User,
client_icon: Optional[str] = None, client_id: Optional[str] = None,
token_type: str = models.TOKEN_TYPE_NORMAL, client_name: Optional[str] = None,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION) \ client_icon: Optional[str] = None,
-> models.RefreshToken: token_type: str = models.TOKEN_TYPE_NORMAL,
access_token_expiration: timedelta = ACCESS_TOKEN_EXPIRATION,
) -> models.RefreshToken:
"""Create a new token for a user.""" """Create a new token for a user."""
kwargs = { kwargs = {
'user': user, "user": user,
'client_id': client_id, "client_id": client_id,
'token_type': token_type, "token_type": token_type,
'access_token_expiration': access_token_expiration "access_token_expiration": access_token_expiration,
} # type: Dict[str, Any] } # type: Dict[str, Any]
if client_name: if client_name:
kwargs['client_name'] = client_name kwargs["client_name"] = client_name
if client_icon: if client_icon:
kwargs['client_icon'] = client_icon kwargs["client_icon"] = client_icon
refresh_token = models.RefreshToken(**kwargs) refresh_token = models.RefreshToken(**kwargs)
user.refresh_tokens[refresh_token.id] = refresh_token user.refresh_tokens[refresh_token.id] = refresh_token
@ -222,7 +228,8 @@ class AuthStore:
return refresh_token return refresh_token
async def async_remove_refresh_token( async def async_remove_refresh_token(
self, refresh_token: models.RefreshToken) -> None: self, refresh_token: models.RefreshToken
) -> None:
"""Remove a refresh token.""" """Remove a refresh token."""
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
@ -234,7 +241,8 @@ class AuthStore:
break break
async def async_get_refresh_token( async def async_get_refresh_token(
self, token_id: str) -> Optional[models.RefreshToken]: self, token_id: str
) -> Optional[models.RefreshToken]:
"""Get refresh token by id.""" """Get refresh token by id."""
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
@ -248,7 +256,8 @@ class AuthStore:
return None return None
async def async_get_refresh_token_by_token( async def async_get_refresh_token_by_token(
self, token: str) -> Optional[models.RefreshToken]: self, token: str
) -> Optional[models.RefreshToken]:
"""Get refresh token by token.""" """Get refresh token by token."""
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
@ -265,8 +274,8 @@ class AuthStore:
@callback @callback
def async_log_refresh_token_usage( def async_log_refresh_token_usage(
self, refresh_token: models.RefreshToken, self, refresh_token: models.RefreshToken, remote_ip: Optional[str] = None
remote_ip: Optional[str] = None) -> None: ) -> None:
"""Update refresh token last used information.""" """Update refresh token last used information."""
refresh_token.last_used_at = dt_util.utcnow() refresh_token.last_used_at = dt_util.utcnow()
refresh_token.last_used_ip = remote_ip refresh_token.last_used_ip = remote_ip
@ -292,9 +301,7 @@ class AuthStore:
if self._users is not None: if self._users is not None:
return return
self._perm_lookup = perm_lookup = PermissionLookup( self._perm_lookup = perm_lookup = PermissionLookup(ent_reg, dev_reg)
ent_reg, dev_reg
)
if data is None: if data is None:
self._set_defaults() self._set_defaults()
@ -317,24 +324,24 @@ class AuthStore:
# prevents crashing if user rolls back HA version after a new property # prevents crashing if user rolls back HA version after a new property
# was added. # was added.
for group_dict in data.get('groups', []): for group_dict in data.get("groups", []):
policy = None # type: Optional[PolicyType] policy = None # type: Optional[PolicyType]
if group_dict['id'] == GROUP_ID_ADMIN: if group_dict["id"] == GROUP_ID_ADMIN:
has_admin_group = True has_admin_group = True
name = GROUP_NAME_ADMIN name = GROUP_NAME_ADMIN
policy = system_policies.ADMIN_POLICY policy = system_policies.ADMIN_POLICY
system_generated = True system_generated = True
elif group_dict['id'] == GROUP_ID_USER: elif group_dict["id"] == GROUP_ID_USER:
has_user_group = True has_user_group = True
name = GROUP_NAME_USER name = GROUP_NAME_USER
policy = system_policies.USER_POLICY policy = system_policies.USER_POLICY
system_generated = True system_generated = True
elif group_dict['id'] == GROUP_ID_READ_ONLY: elif group_dict["id"] == GROUP_ID_READ_ONLY:
has_read_only_group = True has_read_only_group = True
name = GROUP_NAME_READ_ONLY name = GROUP_NAME_READ_ONLY
@ -342,18 +349,18 @@ class AuthStore:
system_generated = True system_generated = True
else: else:
name = group_dict['name'] name = group_dict["name"]
policy = group_dict.get('policy') policy = group_dict.get("policy")
system_generated = False system_generated = False
# We don't want groups without a policy that are not system groups # We don't want groups without a policy that are not system groups
# This is part of migrating from state 1 # This is part of migrating from state 1
if policy is None: if policy is None:
group_without_policy = group_dict['id'] group_without_policy = group_dict["id"]
continue continue
groups[group_dict['id']] = models.Group( groups[group_dict["id"]] = models.Group(
id=group_dict['id'], id=group_dict["id"],
name=name, name=name,
policy=policy, policy=policy,
system_generated=system_generated, system_generated=system_generated,
@ -361,8 +368,7 @@ class AuthStore:
# If there are no groups, add all existing users to the admin group. # If there are no groups, add all existing users to the admin group.
# This is part of migrating from state 2 # This is part of migrating from state 2
migrate_users_to_admin_group = (not groups and migrate_users_to_admin_group = not groups and group_without_policy is None
group_without_policy is None)
# If we find a no_policy_group, we need to migrate all users to the # If we find a no_policy_group, we need to migrate all users to the
# admin group. We only do this if there are no other groups, as is # admin group. We only do this if there are no other groups, as is
@ -385,82 +391,86 @@ class AuthStore:
user_group = _system_user_group() user_group = _system_user_group()
groups[user_group.id] = user_group groups[user_group.id] = user_group
for user_dict in data['users']: for user_dict in data["users"]:
# Collect the users group. # Collect the users group.
user_groups = [] user_groups = []
for group_id in user_dict.get('group_ids', []): for group_id in user_dict.get("group_ids", []):
# This is part of migrating from state 1 # This is part of migrating from state 1
if group_id == group_without_policy: if group_id == group_without_policy:
group_id = GROUP_ID_ADMIN group_id = GROUP_ID_ADMIN
user_groups.append(groups[group_id]) user_groups.append(groups[group_id])
# This is part of migrating from state 2 # This is part of migrating from state 2
if (not user_dict['system_generated'] and if not user_dict["system_generated"] and migrate_users_to_admin_group:
migrate_users_to_admin_group):
user_groups.append(groups[GROUP_ID_ADMIN]) user_groups.append(groups[GROUP_ID_ADMIN])
users[user_dict['id']] = models.User( users[user_dict["id"]] = models.User(
name=user_dict['name'], name=user_dict["name"],
groups=user_groups, groups=user_groups,
id=user_dict['id'], id=user_dict["id"],
is_owner=user_dict['is_owner'], is_owner=user_dict["is_owner"],
is_active=user_dict['is_active'], is_active=user_dict["is_active"],
system_generated=user_dict['system_generated'], system_generated=user_dict["system_generated"],
perm_lookup=perm_lookup, perm_lookup=perm_lookup,
) )
for cred_dict in data['credentials']: for cred_dict in data["credentials"]:
users[cred_dict['user_id']].credentials.append(models.Credentials( users[cred_dict["user_id"]].credentials.append(
id=cred_dict['id'], models.Credentials(
is_new=False, id=cred_dict["id"],
auth_provider_type=cred_dict['auth_provider_type'], is_new=False,
auth_provider_id=cred_dict['auth_provider_id'], auth_provider_type=cred_dict["auth_provider_type"],
data=cred_dict['data'], auth_provider_id=cred_dict["auth_provider_id"],
)) data=cred_dict["data"],
)
)
for rt_dict in data['refresh_tokens']: for rt_dict in data["refresh_tokens"]:
# Filter out the old keys that don't have jwt_key (pre-0.76) # Filter out the old keys that don't have jwt_key (pre-0.76)
if 'jwt_key' not in rt_dict: if "jwt_key" not in rt_dict:
continue continue
created_at = dt_util.parse_datetime(rt_dict['created_at']) created_at = dt_util.parse_datetime(rt_dict["created_at"])
if created_at is None: if created_at is None:
getLogger(__name__).error( getLogger(__name__).error(
'Ignoring refresh token %(id)s with invalid created_at ' "Ignoring refresh token %(id)s with invalid created_at "
'%(created_at)s for user_id %(user_id)s', rt_dict) "%(created_at)s for user_id %(user_id)s",
rt_dict,
)
continue continue
token_type = rt_dict.get('token_type') token_type = rt_dict.get("token_type")
if token_type is None: if token_type is None:
if rt_dict['client_id'] is None: if rt_dict["client_id"] is None:
token_type = models.TOKEN_TYPE_SYSTEM token_type = models.TOKEN_TYPE_SYSTEM
else: else:
token_type = models.TOKEN_TYPE_NORMAL token_type = models.TOKEN_TYPE_NORMAL
# old refresh_token don't have last_used_at (pre-0.78) # old refresh_token don't have last_used_at (pre-0.78)
last_used_at_str = rt_dict.get('last_used_at') last_used_at_str = rt_dict.get("last_used_at")
if last_used_at_str: if last_used_at_str:
last_used_at = dt_util.parse_datetime(last_used_at_str) last_used_at = dt_util.parse_datetime(last_used_at_str)
else: else:
last_used_at = None last_used_at = None
token = models.RefreshToken( token = models.RefreshToken(
id=rt_dict['id'], id=rt_dict["id"],
user=users[rt_dict['user_id']], user=users[rt_dict["user_id"]],
client_id=rt_dict['client_id'], client_id=rt_dict["client_id"],
# use dict.get to keep backward compatibility # use dict.get to keep backward compatibility
client_name=rt_dict.get('client_name'), client_name=rt_dict.get("client_name"),
client_icon=rt_dict.get('client_icon'), client_icon=rt_dict.get("client_icon"),
token_type=token_type, token_type=token_type,
created_at=created_at, created_at=created_at,
access_token_expiration=timedelta( access_token_expiration=timedelta(
seconds=rt_dict['access_token_expiration']), seconds=rt_dict["access_token_expiration"]
token=rt_dict['token'], ),
jwt_key=rt_dict['jwt_key'], token=rt_dict["token"],
jwt_key=rt_dict["jwt_key"],
last_used_at=last_used_at, last_used_at=last_used_at,
last_used_ip=rt_dict.get('last_used_ip'), last_used_ip=rt_dict.get("last_used_ip"),
) )
users[rt_dict['user_id']].refresh_tokens[token.id] = token users[rt_dict["user_id"]].refresh_tokens[token.id] = token
self._groups = groups self._groups = groups
self._users = users self._users = users
@ -481,12 +491,12 @@ class AuthStore:
users = [ users = [
{ {
'id': user.id, "id": user.id,
'group_ids': [group.id for group in user.groups], "group_ids": [group.id for group in user.groups],
'is_owner': user.is_owner, "is_owner": user.is_owner,
'is_active': user.is_active, "is_active": user.is_active,
'name': user.name, "name": user.name,
'system_generated': user.system_generated, "system_generated": user.system_generated,
} }
for user in self._users.values() for user in self._users.values()
] ]
@ -494,23 +504,23 @@ class AuthStore:
groups = [] groups = []
for group in self._groups.values(): for group in self._groups.values():
g_dict = { g_dict = {
'id': group.id, "id": group.id,
# Name not read for sys groups. Kept here for backwards compat # Name not read for sys groups. Kept here for backwards compat
'name': group.name "name": group.name,
} # type: Dict[str, Any] } # type: Dict[str, Any]
if not group.system_generated: if not group.system_generated:
g_dict['policy'] = group.policy g_dict["policy"] = group.policy
groups.append(g_dict) groups.append(g_dict)
credentials = [ credentials = [
{ {
'id': credential.id, "id": credential.id,
'user_id': user.id, "user_id": user.id,
'auth_provider_type': credential.auth_provider_type, "auth_provider_type": credential.auth_provider_type,
'auth_provider_id': credential.auth_provider_id, "auth_provider_id": credential.auth_provider_id,
'data': credential.data, "data": credential.data,
} }
for user in self._users.values() for user in self._users.values()
for credential in user.credentials for credential in user.credentials
@ -518,31 +528,30 @@ class AuthStore:
refresh_tokens = [ refresh_tokens = [
{ {
'id': refresh_token.id, "id": refresh_token.id,
'user_id': user.id, "user_id": user.id,
'client_id': refresh_token.client_id, "client_id": refresh_token.client_id,
'client_name': refresh_token.client_name, "client_name": refresh_token.client_name,
'client_icon': refresh_token.client_icon, "client_icon": refresh_token.client_icon,
'token_type': refresh_token.token_type, "token_type": refresh_token.token_type,
'created_at': refresh_token.created_at.isoformat(), "created_at": refresh_token.created_at.isoformat(),
'access_token_expiration': "access_token_expiration": refresh_token.access_token_expiration.total_seconds(),
refresh_token.access_token_expiration.total_seconds(), "token": refresh_token.token,
'token': refresh_token.token, "jwt_key": refresh_token.jwt_key,
'jwt_key': refresh_token.jwt_key, "last_used_at": refresh_token.last_used_at.isoformat()
'last_used_at': if refresh_token.last_used_at
refresh_token.last_used_at.isoformat() else None,
if refresh_token.last_used_at else None, "last_used_ip": refresh_token.last_used_ip,
'last_used_ip': refresh_token.last_used_ip,
} }
for user in self._users.values() for user in self._users.values()
for refresh_token in user.refresh_tokens.values() for refresh_token in user.refresh_tokens.values()
] ]
return { return {
'users': users, "users": users,
'groups': groups, "groups": groups,
'credentials': credentials, "credentials": credentials,
'refresh_tokens': refresh_tokens, "refresh_tokens": refresh_tokens,
} }
def _set_defaults(self) -> None: def _set_defaults(self) -> None:

View file

@ -4,6 +4,6 @@ from datetime import timedelta
ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30) ACCESS_TOKEN_EXPIRATION = timedelta(minutes=30)
MFA_SESSION_EXPIRATION = timedelta(minutes=5) MFA_SESSION_EXPIRATION = timedelta(minutes=5)
GROUP_ID_ADMIN = 'system-admin' GROUP_ID_ADMIN = "system-admin"
GROUP_ID_USER = 'system-users' GROUP_ID_USER = "system-users"
GROUP_ID_READ_ONLY = 'system-read-only' GROUP_ID_READ_ONLY = "system-read-only"

View file

@ -15,14 +15,17 @@ from homeassistant.util.decorator import Registry
MULTI_FACTOR_AUTH_MODULES = Registry() MULTI_FACTOR_AUTH_MODULES = Registry()
MULTI_FACTOR_AUTH_MODULE_SCHEMA = vol.Schema({ MULTI_FACTOR_AUTH_MODULE_SCHEMA = vol.Schema(
vol.Required(CONF_TYPE): str, {
vol.Optional(CONF_NAME): str, vol.Required(CONF_TYPE): str,
# Specify ID if you have two mfa auth module for same type. vol.Optional(CONF_NAME): str,
vol.Optional(CONF_ID): str, # Specify ID if you have two mfa auth module for same type.
}, extra=vol.ALLOW_EXTRA) vol.Optional(CONF_ID): str,
},
extra=vol.ALLOW_EXTRA,
)
DATA_REQS = 'mfa_auth_module_reqs_processed' DATA_REQS = "mfa_auth_module_reqs_processed"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -30,7 +33,7 @@ _LOGGER = logging.getLogger(__name__)
class MultiFactorAuthModule: class MultiFactorAuthModule:
"""Multi-factor Auth Module of validation function.""" """Multi-factor Auth Module of validation function."""
DEFAULT_TITLE = 'Unnamed auth module' DEFAULT_TITLE = "Unnamed auth module"
MAX_RETRY_TIME = 3 MAX_RETRY_TIME = 3
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None: def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
@ -63,7 +66,7 @@ class MultiFactorAuthModule:
"""Return a voluptuous schema to define mfa auth module's input.""" """Return a voluptuous schema to define mfa auth module's input."""
raise NotImplementedError raise NotImplementedError
async def async_setup_flow(self, user_id: str) -> 'SetupFlow': async def async_setup_flow(self, user_id: str) -> "SetupFlow":
"""Return a data entry flow handler for setup module. """Return a data entry flow handler for setup module.
Mfa module should extend SetupFlow Mfa module should extend SetupFlow
@ -82,8 +85,7 @@ class MultiFactorAuthModule:
"""Return whether user is setup.""" """Return whether user is setup."""
raise NotImplementedError raise NotImplementedError
async def async_validate( async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
self, user_id: str, user_input: Dict[str, Any]) -> bool:
"""Return True if validation passed.""" """Return True if validation passed."""
raise NotImplementedError raise NotImplementedError
@ -91,17 +93,17 @@ class MultiFactorAuthModule:
class SetupFlow(data_entry_flow.FlowHandler): class SetupFlow(data_entry_flow.FlowHandler):
"""Handler for the setup flow.""" """Handler for the setup flow."""
def __init__(self, auth_module: MultiFactorAuthModule, def __init__(
setup_schema: vol.Schema, self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str
user_id: str) -> None: ) -> None:
"""Initialize the setup flow.""" """Initialize the setup flow."""
self._auth_module = auth_module self._auth_module = auth_module
self._setup_schema = setup_schema self._setup_schema = setup_schema
self._user_id = user_id self._user_id = user_id
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None
-> Dict[str, Any]: ) -> Dict[str, Any]:
"""Handle the first step of setup flow. """Handle the first step of setup flow.
Return self.async_show_form(step_id='init') if user_input is None. Return self.async_show_form(step_id='init') if user_input is None.
@ -110,23 +112,19 @@ class SetupFlow(data_entry_flow.FlowHandler):
errors = {} # type: Dict[str, str] errors = {} # type: Dict[str, str]
if user_input: if user_input:
result = await self._auth_module.async_setup_user( result = await self._auth_module.async_setup_user(self._user_id, user_input)
self._user_id, user_input)
return self.async_create_entry( return self.async_create_entry(
title=self._auth_module.name, title=self._auth_module.name, data={"result": result}
data={'result': result}
) )
return self.async_show_form( return self.async_show_form(
step_id='init', step_id="init", data_schema=self._setup_schema, errors=errors
data_schema=self._setup_schema,
errors=errors
) )
async def auth_mfa_module_from_config( async def auth_mfa_module_from_config(
hass: HomeAssistant, config: Dict[str, Any]) \ hass: HomeAssistant, config: Dict[str, Any]
-> MultiFactorAuthModule: ) -> MultiFactorAuthModule:
"""Initialize an auth module from a config.""" """Initialize an auth module from a config."""
module_name = config[CONF_TYPE] module_name = config[CONF_TYPE]
module = await _load_mfa_module(hass, module_name) module = await _load_mfa_module(hass, module_name)
@ -134,26 +132,29 @@ async def auth_mfa_module_from_config(
try: try:
config = module.CONFIG_SCHEMA(config) # type: ignore config = module.CONFIG_SCHEMA(config) # type: ignore
except vol.Invalid as err: except vol.Invalid as err:
_LOGGER.error('Invalid configuration for multi-factor module %s: %s', _LOGGER.error(
module_name, humanize_error(config, err)) "Invalid configuration for multi-factor module %s: %s",
module_name,
humanize_error(config, err),
)
raise raise
return MULTI_FACTOR_AUTH_MODULES[module_name](hass, config) # type: ignore return MULTI_FACTOR_AUTH_MODULES[module_name](hass, config) # type: ignore
async def _load_mfa_module(hass: HomeAssistant, module_name: str) \ async def _load_mfa_module(hass: HomeAssistant, module_name: str) -> types.ModuleType:
-> types.ModuleType:
"""Load an mfa auth module.""" """Load an mfa auth module."""
module_path = 'homeassistant.auth.mfa_modules.{}'.format(module_name) module_path = "homeassistant.auth.mfa_modules.{}".format(module_name)
try: try:
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
except ImportError as err: except ImportError as err:
_LOGGER.error('Unable to load mfa module %s: %s', module_name, err) _LOGGER.error("Unable to load mfa module %s: %s", module_name, err)
raise HomeAssistantError('Unable to load mfa module {}: {}'.format( raise HomeAssistantError(
module_name, err)) "Unable to load mfa module {}: {}".format(module_name, err)
)
if hass.config.skip_pip or not hasattr(module, 'REQUIREMENTS'): if hass.config.skip_pip or not hasattr(module, "REQUIREMENTS"):
return module return module
processed = hass.data.get(DATA_REQS) processed = hass.data.get(DATA_REQS)
@ -164,12 +165,13 @@ async def _load_mfa_module(hass: HomeAssistant, module_name: str) \
# https://github.com/python/mypy/issues/1424 # https://github.com/python/mypy/issues/1424
req_success = await requirements.async_process_requirements( req_success = await requirements.async_process_requirements(
hass, module_path, module.REQUIREMENTS) # type: ignore hass, module_path, module.REQUIREMENTS
) # type: ignore
if not req_success: if not req_success:
raise HomeAssistantError( raise HomeAssistantError(
'Unable to process requirements of mfa module {}'.format( "Unable to process requirements of mfa module {}".format(module_name)
module_name)) )
processed.add(module_name) processed.add(module_name)
return module return module

View file

@ -6,39 +6,45 @@ import voluptuous as vol
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from . import MultiFactorAuthModule, MULTI_FACTOR_AUTH_MODULES, \ from . import (
MULTI_FACTOR_AUTH_MODULE_SCHEMA, SetupFlow MultiFactorAuthModule,
MULTI_FACTOR_AUTH_MODULES,
MULTI_FACTOR_AUTH_MODULE_SCHEMA,
SetupFlow,
)
CONFIG_SCHEMA = MULTI_FACTOR_AUTH_MODULE_SCHEMA.extend({ CONFIG_SCHEMA = MULTI_FACTOR_AUTH_MODULE_SCHEMA.extend(
vol.Required('data'): [vol.Schema({ {
vol.Required('user_id'): str, vol.Required("data"): [
vol.Required('pin'): str, vol.Schema({vol.Required("user_id"): str, vol.Required("pin"): str})
})] ]
}, extra=vol.PREVENT_EXTRA) },
extra=vol.PREVENT_EXTRA,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@MULTI_FACTOR_AUTH_MODULES.register('insecure_example') @MULTI_FACTOR_AUTH_MODULES.register("insecure_example")
class InsecureExampleModule(MultiFactorAuthModule): class InsecureExampleModule(MultiFactorAuthModule):
"""Example auth module validate pin.""" """Example auth module validate pin."""
DEFAULT_TITLE = 'Insecure Personal Identify Number' DEFAULT_TITLE = "Insecure Personal Identify Number"
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None: def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
"""Initialize the user data store.""" """Initialize the user data store."""
super().__init__(hass, config) super().__init__(hass, config)
self._data = config['data'] self._data = config["data"]
@property @property
def input_schema(self) -> vol.Schema: def input_schema(self) -> vol.Schema:
"""Validate login flow input data.""" """Validate login flow input data."""
return vol.Schema({'pin': str}) return vol.Schema({"pin": str})
@property @property
def setup_schema(self) -> vol.Schema: def setup_schema(self) -> vol.Schema:
"""Validate async_setup_user input data.""" """Validate async_setup_user input data."""
return vol.Schema({'pin': str}) return vol.Schema({"pin": str})
async def async_setup_flow(self, user_id: str) -> SetupFlow: async def async_setup_flow(self, user_id: str) -> SetupFlow:
"""Return a data entry flow handler for setup module. """Return a data entry flow handler for setup module.
@ -50,21 +56,21 @@ class InsecureExampleModule(MultiFactorAuthModule):
async def async_setup_user(self, user_id: str, setup_data: Any) -> Any: async def async_setup_user(self, user_id: str, setup_data: Any) -> Any:
"""Set up user to use mfa module.""" """Set up user to use mfa module."""
# data shall has been validate in caller # data shall has been validate in caller
pin = setup_data['pin'] pin = setup_data["pin"]
for data in self._data: for data in self._data:
if data['user_id'] == user_id: if data["user_id"] == user_id:
# already setup, override # already setup, override
data['pin'] = pin data["pin"] = pin
return return
self._data.append({'user_id': user_id, 'pin': pin}) self._data.append({"user_id": user_id, "pin": pin})
async def async_depose_user(self, user_id: str) -> None: async def async_depose_user(self, user_id: str) -> None:
"""Remove user from mfa module.""" """Remove user from mfa module."""
found = None found = None
for data in self._data: for data in self._data:
if data['user_id'] == user_id: if data["user_id"] == user_id:
found = data found = data
break break
if found: if found:
@ -73,17 +79,16 @@ class InsecureExampleModule(MultiFactorAuthModule):
async def async_is_user_setup(self, user_id: str) -> bool: async def async_is_user_setup(self, user_id: str) -> bool:
"""Return whether user is setup.""" """Return whether user is setup."""
for data in self._data: for data in self._data:
if data['user_id'] == user_id: if data["user_id"] == user_id:
return True return True
return False return False
async def async_validate( async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
self, user_id: str, user_input: Dict[str, Any]) -> bool:
"""Return True if validation passed.""" """Return True if validation passed."""
for data in self._data: for data in self._data:
if data['user_id'] == user_id: if data["user_id"] == user_id:
# user_input has been validate in caller # user_input has been validate in caller
if data['pin'] == user_input['pin']: if data["pin"] == user_input["pin"]:
return True return True
return False return False

View file

@ -15,26 +15,32 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ServiceNotFound from homeassistant.exceptions import ServiceNotFound
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from . import MultiFactorAuthModule, MULTI_FACTOR_AUTH_MODULES, \ from . import (
MULTI_FACTOR_AUTH_MODULE_SCHEMA, SetupFlow MultiFactorAuthModule,
MULTI_FACTOR_AUTH_MODULES,
MULTI_FACTOR_AUTH_MODULE_SCHEMA,
SetupFlow,
)
REQUIREMENTS = ['pyotp==2.2.7'] REQUIREMENTS = ["pyotp==2.2.7"]
CONF_MESSAGE = 'message' CONF_MESSAGE = "message"
CONFIG_SCHEMA = MULTI_FACTOR_AUTH_MODULE_SCHEMA.extend({ CONFIG_SCHEMA = MULTI_FACTOR_AUTH_MODULE_SCHEMA.extend(
vol.Optional(CONF_INCLUDE): vol.All(cv.ensure_list, [cv.string]), {
vol.Optional(CONF_EXCLUDE): vol.All(cv.ensure_list, [cv.string]), vol.Optional(CONF_INCLUDE): vol.All(cv.ensure_list, [cv.string]),
vol.Optional(CONF_MESSAGE, vol.Optional(CONF_EXCLUDE): vol.All(cv.ensure_list, [cv.string]),
default='{} is your Home Assistant login code'): str vol.Optional(CONF_MESSAGE, default="{} is your Home Assistant login code"): str,
}, extra=vol.PREVENT_EXTRA) },
extra=vol.PREVENT_EXTRA,
)
STORAGE_VERSION = 1 STORAGE_VERSION = 1
STORAGE_KEY = 'auth_module.notify' STORAGE_KEY = "auth_module.notify"
STORAGE_USERS = 'users' STORAGE_USERS = "users"
STORAGE_USER_ID = 'user_id' STORAGE_USER_ID = "user_id"
INPUT_FIELD_CODE = 'code' INPUT_FIELD_CODE = "code"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -42,24 +48,28 @@ _LOGGER = logging.getLogger(__name__)
def _generate_secret() -> str: def _generate_secret() -> str:
"""Generate a secret.""" """Generate a secret."""
import pyotp import pyotp
return str(pyotp.random_base32()) return str(pyotp.random_base32())
def _generate_random() -> int: def _generate_random() -> int:
"""Generate a 8 digit number.""" """Generate a 8 digit number."""
import pyotp import pyotp
return int(pyotp.random_base32(length=8, chars=list('1234567890')))
return int(pyotp.random_base32(length=8, chars=list("1234567890")))
def _generate_otp(secret: str, count: int) -> str: def _generate_otp(secret: str, count: int) -> str:
"""Generate one time password.""" """Generate one time password."""
import pyotp import pyotp
return str(pyotp.HOTP(secret).at(count)) return str(pyotp.HOTP(secret).at(count))
def _verify_otp(secret: str, otp: str, count: int) -> bool: def _verify_otp(secret: str, otp: str, count: int) -> bool:
"""Verify one time password.""" """Verify one time password."""
import pyotp import pyotp
return bool(pyotp.HOTP(secret).verify(otp, count)) return bool(pyotp.HOTP(secret).verify(otp, count))
@ -67,7 +77,7 @@ def _verify_otp(secret: str, otp: str, count: int) -> bool:
class NotifySetting: class NotifySetting:
"""Store notify setting for one user.""" """Store notify setting for one user."""
secret = attr.ib(type=str, factory=_generate_secret) # not persistent secret = attr.ib(type=str, factory=_generate_secret) # not persistent
counter = attr.ib(type=int, factory=_generate_random) # not persistent counter = attr.ib(type=int, factory=_generate_random) # not persistent
notify_service = attr.ib(type=Optional[str], default=None) notify_service = attr.ib(type=Optional[str], default=None)
target = attr.ib(type=Optional[str], default=None) target = attr.ib(type=Optional[str], default=None)
@ -76,18 +86,19 @@ class NotifySetting:
_UsersDict = Dict[str, NotifySetting] _UsersDict = Dict[str, NotifySetting]
@MULTI_FACTOR_AUTH_MODULES.register('notify') @MULTI_FACTOR_AUTH_MODULES.register("notify")
class NotifyAuthModule(MultiFactorAuthModule): class NotifyAuthModule(MultiFactorAuthModule):
"""Auth module send hmac-based one time password by notify service.""" """Auth module send hmac-based one time password by notify service."""
DEFAULT_TITLE = 'Notify One-Time Password' DEFAULT_TITLE = "Notify One-Time Password"
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None: def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
"""Initialize the user data store.""" """Initialize the user data store."""
super().__init__(hass, config) super().__init__(hass, config)
self._user_settings = None # type: Optional[_UsersDict] self._user_settings = None # type: Optional[_UsersDict]
self._user_store = hass.helpers.storage.Store( self._user_store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True) STORAGE_VERSION, STORAGE_KEY, private=True
)
self._include = config.get(CONF_INCLUDE, []) self._include = config.get(CONF_INCLUDE, [])
self._exclude = config.get(CONF_EXCLUDE, []) self._exclude = config.get(CONF_EXCLUDE, [])
self._message_template = config[CONF_MESSAGE] self._message_template = config[CONF_MESSAGE]
@ -119,22 +130,27 @@ class NotifyAuthModule(MultiFactorAuthModule):
if self._user_settings is None: if self._user_settings is None:
return return
await self._user_store.async_save({STORAGE_USERS: { await self._user_store.async_save(
user_id: attr.asdict( {
notify_setting, filter=attr.filters.exclude( STORAGE_USERS: {
attr.fields(NotifySetting).secret, user_id: attr.asdict(
attr.fields(NotifySetting).counter, notify_setting,
)) filter=attr.filters.exclude(
for user_id, notify_setting attr.fields(NotifySetting).secret,
in self._user_settings.items() attr.fields(NotifySetting).counter,
}}) ),
)
for user_id, notify_setting in self._user_settings.items()
}
}
)
@callback @callback
def aync_get_available_notify_services(self) -> List[str]: def aync_get_available_notify_services(self) -> List[str]:
"""Return list of notify services.""" """Return list of notify services."""
unordered_services = set() unordered_services = set()
for service in self.hass.services.async_services().get('notify', {}): for service in self.hass.services.async_services().get("notify", {}):
if service not in self._exclude: if service not in self._exclude:
unordered_services.add(service) unordered_services.add(service)
@ -149,8 +165,8 @@ class NotifyAuthModule(MultiFactorAuthModule):
Mfa module should extend SetupFlow Mfa module should extend SetupFlow
""" """
return NotifySetupFlow( return NotifySetupFlow(
self, self.input_schema, user_id, self, self.input_schema, user_id, self.aync_get_available_notify_services()
self.aync_get_available_notify_services()) )
async def async_setup_user(self, user_id: str, setup_data: Any) -> Any: async def async_setup_user(self, user_id: str, setup_data: Any) -> Any:
"""Set up auth module for user.""" """Set up auth module for user."""
@ -159,8 +175,8 @@ class NotifyAuthModule(MultiFactorAuthModule):
assert self._user_settings is not None assert self._user_settings is not None
self._user_settings[user_id] = NotifySetting( self._user_settings[user_id] = NotifySetting(
notify_service=setup_data.get('notify_service'), notify_service=setup_data.get("notify_service"),
target=setup_data.get('target'), target=setup_data.get("target"),
) )
await self._async_save() await self._async_save()
@ -182,8 +198,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
return user_id in self._user_settings return user_id in self._user_settings
async def async_validate( async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
self, user_id: str, user_input: Dict[str, Any]) -> bool:
"""Return True if validation passed.""" """Return True if validation passed."""
if self._user_settings is None: if self._user_settings is None:
await self._async_load() await self._async_load()
@ -195,9 +210,11 @@ class NotifyAuthModule(MultiFactorAuthModule):
# user_input has been validate in caller # user_input has been validate in caller
return await self.hass.async_add_executor_job( return await self.hass.async_add_executor_job(
_verify_otp, notify_setting.secret, _verify_otp,
user_input.get(INPUT_FIELD_CODE, ''), notify_setting.secret,
notify_setting.counter) user_input.get(INPUT_FIELD_CODE, ""),
notify_setting.counter,
)
async def async_initialize_login_mfa_step(self, user_id: str) -> None: async def async_initialize_login_mfa_step(self, user_id: str) -> None:
"""Generate code and notify user.""" """Generate code and notify user."""
@ -207,7 +224,7 @@ class NotifyAuthModule(MultiFactorAuthModule):
notify_setting = self._user_settings.get(user_id, None) notify_setting = self._user_settings.get(user_id, None)
if notify_setting is None: if notify_setting is None:
raise ValueError('Cannot find user_id') raise ValueError("Cannot find user_id")
def generate_secret_and_one_time_password() -> str: def generate_secret_and_one_time_password() -> str:
"""Generate and send one time password.""" """Generate and send one time password."""
@ -215,11 +232,11 @@ class NotifyAuthModule(MultiFactorAuthModule):
# secret and counter are not persistent # secret and counter are not persistent
notify_setting.secret = _generate_secret() notify_setting.secret = _generate_secret()
notify_setting.counter = _generate_random() notify_setting.counter = _generate_random()
return _generate_otp( return _generate_otp(notify_setting.secret, notify_setting.counter)
notify_setting.secret, notify_setting.counter)
code = await self.hass.async_add_executor_job( code = await self.hass.async_add_executor_job(
generate_secret_and_one_time_password) generate_secret_and_one_time_password
)
await self.async_notify_user(user_id, code) await self.async_notify_user(user_id, code)
@ -231,105 +248,107 @@ class NotifyAuthModule(MultiFactorAuthModule):
notify_setting = self._user_settings.get(user_id, None) notify_setting = self._user_settings.get(user_id, None)
if notify_setting is None: if notify_setting is None:
_LOGGER.error('Cannot find user %s', user_id) _LOGGER.error("Cannot find user %s", user_id)
return return
await self.async_notify( # type: ignore await self.async_notify( # type: ignore
code, notify_setting.notify_service, notify_setting.target) code, notify_setting.notify_service, notify_setting.target
)
async def async_notify(self, code: str, notify_service: str, async def async_notify(
target: Optional[str] = None) -> None: self, code: str, notify_service: str, target: Optional[str] = None
) -> None:
"""Send code by notify service.""" """Send code by notify service."""
data = {'message': self._message_template.format(code)} data = {"message": self._message_template.format(code)}
if target: if target:
data['target'] = [target] data["target"] = [target]
await self.hass.services.async_call('notify', notify_service, data) await self.hass.services.async_call("notify", notify_service, data)
class NotifySetupFlow(SetupFlow): class NotifySetupFlow(SetupFlow):
"""Handler for the setup flow.""" """Handler for the setup flow."""
def __init__(self, auth_module: NotifyAuthModule, def __init__(
setup_schema: vol.Schema, self,
user_id: str, auth_module: NotifyAuthModule,
available_notify_services: List[str]) -> None: setup_schema: vol.Schema,
user_id: str,
available_notify_services: List[str],
) -> None:
"""Initialize the setup flow.""" """Initialize the setup flow."""
super().__init__(auth_module, setup_schema, user_id) super().__init__(auth_module, setup_schema, user_id)
# to fix typing complaint # to fix typing complaint
self._auth_module = auth_module # type: NotifyAuthModule self._auth_module = auth_module # type: NotifyAuthModule
self._available_notify_services = available_notify_services self._available_notify_services = available_notify_services
self._secret = None # type: Optional[str] self._secret = None # type: Optional[str]
self._count = None # type: Optional[int] self._count = None # type: Optional[int]
self._notify_service = None # type: Optional[str] self._notify_service = None # type: Optional[str]
self._target = None # type: Optional[str] self._target = None # type: Optional[str]
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None
-> Dict[str, Any]: ) -> Dict[str, Any]:
"""Let user select available notify services.""" """Let user select available notify services."""
errors = {} # type: Dict[str, str] errors = {} # type: Dict[str, str]
hass = self._auth_module.hass hass = self._auth_module.hass
if user_input: if user_input:
self._notify_service = user_input['notify_service'] self._notify_service = user_input["notify_service"]
self._target = user_input.get('target') self._target = user_input.get("target")
self._secret = await hass.async_add_executor_job(_generate_secret) self._secret = await hass.async_add_executor_job(_generate_secret)
self._count = await hass.async_add_executor_job(_generate_random) self._count = await hass.async_add_executor_job(_generate_random)
return await self.async_step_setup() return await self.async_step_setup()
if not self._available_notify_services: if not self._available_notify_services:
return self.async_abort(reason='no_available_service') return self.async_abort(reason="no_available_service")
schema = OrderedDict() # type: Dict[str, Any] schema = OrderedDict() # type: Dict[str, Any]
schema['notify_service'] = vol.In(self._available_notify_services) schema["notify_service"] = vol.In(self._available_notify_services)
schema['target'] = vol.Optional(str) schema["target"] = vol.Optional(str)
return self.async_show_form( return self.async_show_form(
step_id='init', step_id="init", data_schema=vol.Schema(schema), errors=errors
data_schema=vol.Schema(schema),
errors=errors
) )
async def async_step_setup( async def async_step_setup(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None
-> Dict[str, Any]: ) -> Dict[str, Any]:
"""Verify user can recevie one-time password.""" """Verify user can recevie one-time password."""
errors = {} # type: Dict[str, str] errors = {} # type: Dict[str, str]
hass = self._auth_module.hass hass = self._auth_module.hass
if user_input: if user_input:
verified = await hass.async_add_executor_job( verified = await hass.async_add_executor_job(
_verify_otp, self._secret, user_input['code'], self._count) _verify_otp, self._secret, user_input["code"], self._count
)
if verified: if verified:
await self._auth_module.async_setup_user( await self._auth_module.async_setup_user(
self._user_id, { self._user_id,
'notify_service': self._notify_service, {"notify_service": self._notify_service, "target": self._target},
'target': self._target,
})
return self.async_create_entry(
title=self._auth_module.name,
data={}
) )
return self.async_create_entry(title=self._auth_module.name, data={})
errors['base'] = 'invalid_code' errors["base"] = "invalid_code"
# generate code every time, no retry logic # generate code every time, no retry logic
assert self._secret and self._count assert self._secret and self._count
code = await hass.async_add_executor_job( code = await hass.async_add_executor_job(
_generate_otp, self._secret, self._count) _generate_otp, self._secret, self._count
)
assert self._notify_service assert self._notify_service
try: try:
await self._auth_module.async_notify( await self._auth_module.async_notify(
code, self._notify_service, self._target) code, self._notify_service, self._target
)
except ServiceNotFound: except ServiceNotFound:
return self.async_abort(reason='notify_service_not_exist') return self.async_abort(reason="notify_service_not_exist")
return self.async_show_form( return self.async_show_form(
step_id='setup', step_id="setup",
data_schema=self._setup_schema, data_schema=self._setup_schema,
description_placeholders={'notify_service': self._notify_service}, description_placeholders={"notify_service": self._notify_service},
errors=errors, errors=errors,
) )

View file

@ -9,23 +9,26 @@ import voluptuous as vol
from homeassistant.auth.models import User from homeassistant.auth.models import User
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from . import MultiFactorAuthModule, MULTI_FACTOR_AUTH_MODULES, \ from . import (
MULTI_FACTOR_AUTH_MODULE_SCHEMA, SetupFlow MultiFactorAuthModule,
MULTI_FACTOR_AUTH_MODULES,
MULTI_FACTOR_AUTH_MODULE_SCHEMA,
SetupFlow,
)
REQUIREMENTS = ['pyotp==2.2.7', 'PyQRCode==1.2.1'] REQUIREMENTS = ["pyotp==2.2.7", "PyQRCode==1.2.1"]
CONFIG_SCHEMA = MULTI_FACTOR_AUTH_MODULE_SCHEMA.extend({ CONFIG_SCHEMA = MULTI_FACTOR_AUTH_MODULE_SCHEMA.extend({}, extra=vol.PREVENT_EXTRA)
}, extra=vol.PREVENT_EXTRA)
STORAGE_VERSION = 1 STORAGE_VERSION = 1
STORAGE_KEY = 'auth_module.totp' STORAGE_KEY = "auth_module.totp"
STORAGE_USERS = 'users' STORAGE_USERS = "users"
STORAGE_USER_ID = 'user_id' STORAGE_USER_ID = "user_id"
STORAGE_OTA_SECRET = 'ota_secret' STORAGE_OTA_SECRET = "ota_secret"
INPUT_FIELD_CODE = 'code' INPUT_FIELD_CODE = "code"
DUMMY_SECRET = 'FPPTH34D4E3MI2HG' DUMMY_SECRET = "FPPTH34D4E3MI2HG"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -38,10 +41,15 @@ def _generate_qr_code(data: str) -> str:
with BytesIO() as buffer: with BytesIO() as buffer:
qr_code.svg(file=buffer, scale=4) qr_code.svg(file=buffer, scale=4)
return '{}'.format( return "{}".format(
buffer.getvalue().decode("ascii").replace('\n', '') buffer.getvalue()
.replace('<?xml version="1.0" encoding="UTF-8"?>' .decode("ascii")
'<svg xmlns="http://www.w3.org/2000/svg"', '<svg') .replace("\n", "")
.replace(
'<?xml version="1.0" encoding="UTF-8"?>'
'<svg xmlns="http://www.w3.org/2000/svg"',
"<svg",
)
) )
@ -51,16 +59,17 @@ def _generate_secret_and_qr_code(username: str) -> Tuple[str, str, str]:
ota_secret = pyotp.random_base32() ota_secret = pyotp.random_base32()
url = pyotp.totp.TOTP(ota_secret).provisioning_uri( url = pyotp.totp.TOTP(ota_secret).provisioning_uri(
username, issuer_name="Home Assistant") username, issuer_name="Home Assistant"
)
image = _generate_qr_code(url) image = _generate_qr_code(url)
return ota_secret, url, image return ota_secret, url, image
@MULTI_FACTOR_AUTH_MODULES.register('totp') @MULTI_FACTOR_AUTH_MODULES.register("totp")
class TotpAuthModule(MultiFactorAuthModule): class TotpAuthModule(MultiFactorAuthModule):
"""Auth module validate time-based one time password.""" """Auth module validate time-based one time password."""
DEFAULT_TITLE = 'Time-based One Time Password' DEFAULT_TITLE = "Time-based One Time Password"
MAX_RETRY_TIME = 5 MAX_RETRY_TIME = 5
def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None: def __init__(self, hass: HomeAssistant, config: Dict[str, Any]) -> None:
@ -68,7 +77,8 @@ class TotpAuthModule(MultiFactorAuthModule):
super().__init__(hass, config) super().__init__(hass, config)
self._users = None # type: Optional[Dict[str, str]] self._users = None # type: Optional[Dict[str, str]]
self._user_store = hass.helpers.storage.Store( self._user_store = hass.helpers.storage.Store(
STORAGE_VERSION, STORAGE_KEY, private=True) STORAGE_VERSION, STORAGE_KEY, private=True
)
self._init_lock = asyncio.Lock() self._init_lock = asyncio.Lock()
@property @property
@ -93,14 +103,13 @@ class TotpAuthModule(MultiFactorAuthModule):
"""Save data.""" """Save data."""
await self._user_store.async_save({STORAGE_USERS: self._users}) await self._user_store.async_save({STORAGE_USERS: self._users})
def _add_ota_secret(self, user_id: str, def _add_ota_secret(self, user_id: str, secret: Optional[str] = None) -> str:
secret: Optional[str] = None) -> str:
"""Create a ota_secret for user.""" """Create a ota_secret for user."""
import pyotp import pyotp
ota_secret = secret or pyotp.random_base32() # type: str ota_secret = secret or pyotp.random_base32() # type: str
self._users[user_id] = ota_secret # type: ignore self._users[user_id] = ota_secret # type: ignore
return ota_secret return ota_secret
async def async_setup_flow(self, user_id: str) -> SetupFlow: async def async_setup_flow(self, user_id: str) -> SetupFlow:
@ -108,7 +117,7 @@ class TotpAuthModule(MultiFactorAuthModule):
Mfa module should extend SetupFlow Mfa module should extend SetupFlow
""" """
user = await self.hass.auth.async_get_user(user_id) # type: ignore user = await self.hass.auth.async_get_user(user_id) # type: ignore
return TotpSetupFlow(self, self.input_schema, user) return TotpSetupFlow(self, self.input_schema, user)
async def async_setup_user(self, user_id: str, setup_data: Any) -> str: async def async_setup_user(self, user_id: str, setup_data: Any) -> str:
@ -117,7 +126,8 @@ class TotpAuthModule(MultiFactorAuthModule):
await self._async_load() await self._async_load()
result = await self.hass.async_add_executor_job( result = await self.hass.async_add_executor_job(
self._add_ota_secret, user_id, setup_data.get('secret')) self._add_ota_secret, user_id, setup_data.get("secret")
)
await self._async_save() await self._async_save()
return result return result
@ -127,7 +137,7 @@ class TotpAuthModule(MultiFactorAuthModule):
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
if self._users.pop(user_id, None): # type: ignore if self._users.pop(user_id, None): # type: ignore
await self._async_save() await self._async_save()
async def async_is_user_setup(self, user_id: str) -> bool: async def async_is_user_setup(self, user_id: str) -> bool:
@ -135,10 +145,9 @@ class TotpAuthModule(MultiFactorAuthModule):
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
return user_id in self._users # type: ignore return user_id in self._users # type: ignore
async def async_validate( async def async_validate(self, user_id: str, user_input: Dict[str, Any]) -> bool:
self, user_id: str, user_input: Dict[str, Any]) -> bool:
"""Return True if validation passed.""" """Return True if validation passed."""
if self._users is None: if self._users is None:
await self._async_load() await self._async_load()
@ -146,7 +155,8 @@ class TotpAuthModule(MultiFactorAuthModule):
# user_input has been validate in caller # user_input has been validate in caller
# set INPUT_FIELD_CODE as vol.Required is not user friendly # set INPUT_FIELD_CODE as vol.Required is not user friendly
return await self.hass.async_add_executor_job( return await self.hass.async_add_executor_job(
self._validate_2fa, user_id, user_input.get(INPUT_FIELD_CODE, '')) self._validate_2fa, user_id, user_input.get(INPUT_FIELD_CODE, "")
)
def _validate_2fa(self, user_id: str, code: str) -> bool: def _validate_2fa(self, user_id: str, code: str) -> bool:
"""Validate two factor authentication code.""" """Validate two factor authentication code."""
@ -165,9 +175,9 @@ class TotpAuthModule(MultiFactorAuthModule):
class TotpSetupFlow(SetupFlow): class TotpSetupFlow(SetupFlow):
"""Handler for the setup flow.""" """Handler for the setup flow."""
def __init__(self, auth_module: TotpAuthModule, def __init__(
setup_schema: vol.Schema, self, auth_module: TotpAuthModule, setup_schema: vol.Schema, user: User
user: User) -> None: ) -> None:
"""Initialize the setup flow.""" """Initialize the setup flow."""
super().__init__(auth_module, setup_schema, user.id) super().__init__(auth_module, setup_schema, user.id)
# to fix typing complaint # to fix typing complaint
@ -178,8 +188,8 @@ class TotpSetupFlow(SetupFlow):
self._image = None # type Optional[str] self._image = None # type Optional[str]
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None
-> Dict[str, Any]: ) -> Dict[str, Any]:
"""Handle the first step of setup flow. """Handle the first step of setup flow.
Return self.async_show_form(step_id='init') if user_input is None. Return self.async_show_form(step_id='init') if user_input is None.
@ -191,30 +201,31 @@ class TotpSetupFlow(SetupFlow):
if user_input: if user_input:
verified = await self.hass.async_add_executor_job( # type: ignore verified = await self.hass.async_add_executor_job( # type: ignore
pyotp.TOTP(self._ota_secret).verify, user_input['code']) pyotp.TOTP(self._ota_secret).verify, user_input["code"]
)
if verified: if verified:
result = await self._auth_module.async_setup_user( result = await self._auth_module.async_setup_user(
self._user_id, {'secret': self._ota_secret}) self._user_id, {"secret": self._ota_secret}
)
return self.async_create_entry( return self.async_create_entry(
title=self._auth_module.name, title=self._auth_module.name, data={"result": result}
data={'result': result}
) )
errors['base'] = 'invalid_code' errors["base"] = "invalid_code"
else: else:
hass = self._auth_module.hass hass = self._auth_module.hass
self._ota_secret, self._url, self._image = \ self._ota_secret, self._url, self._image = await hass.async_add_executor_job( # type: ignore
await hass.async_add_executor_job( # type: ignore _generate_secret_and_qr_code, str(self._user.name)
_generate_secret_and_qr_code, str(self._user.name)) )
return self.async_show_form( return self.async_show_form(
step_id='init', step_id="init",
data_schema=self._setup_schema, data_schema=self._setup_schema,
description_placeholders={ description_placeholders={
'code': self._ota_secret, "code": self._ota_secret,
'url': self._url, "url": self._url,
'qr_code': self._image "qr_code": self._image,
}, },
errors=errors errors=errors,
) )

View file

@ -11,9 +11,9 @@ from . import permissions as perm_mdl
from .const import GROUP_ID_ADMIN from .const import GROUP_ID_ADMIN
from .util import generate_secret from .util import generate_secret
TOKEN_TYPE_NORMAL = 'normal' TOKEN_TYPE_NORMAL = "normal"
TOKEN_TYPE_SYSTEM = 'system' TOKEN_TYPE_SYSTEM = "system"
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = 'long_lived_access_token' TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
@attr.s(slots=True) @attr.s(slots=True)
@ -32,7 +32,7 @@ class User:
name = attr.ib(type=str) # type: Optional[str] name = attr.ib(type=str) # type: Optional[str]
perm_lookup = attr.ib( perm_lookup = attr.ib(
type=perm_mdl.PermissionLookup, cmp=False, type=perm_mdl.PermissionLookup, cmp=False
) # type: perm_mdl.PermissionLookup ) # type: perm_mdl.PermissionLookup
id = attr.ib(type=str, factory=lambda: uuid.uuid4().hex) id = attr.ib(type=str, factory=lambda: uuid.uuid4().hex)
is_owner = attr.ib(type=bool, default=False) is_owner = attr.ib(type=bool, default=False)
@ -42,9 +42,7 @@ class User:
groups = attr.ib(type=List, factory=list, cmp=False) # type: List[Group] groups = attr.ib(type=List, factory=list, cmp=False) # type: List[Group]
# List of credentials of a user. # List of credentials of a user.
credentials = attr.ib( credentials = attr.ib(type=list, factory=list, cmp=False) # type: List[Credentials]
type=list, factory=list, cmp=False
) # type: List[Credentials]
# Tokens associated with a user. # Tokens associated with a user.
refresh_tokens = attr.ib( refresh_tokens = attr.ib(
@ -52,10 +50,7 @@ class User:
) # type: Dict[str, RefreshToken] ) # type: Dict[str, RefreshToken]
_permissions = attr.ib( _permissions = attr.ib(
type=Optional[perm_mdl.PolicyPermissions], type=Optional[perm_mdl.PolicyPermissions], init=False, cmp=False, default=None
init=False,
cmp=False,
default=None,
) )
@property @property
@ -68,9 +63,9 @@ class User:
return self._permissions return self._permissions
self._permissions = perm_mdl.PolicyPermissions( self._permissions = perm_mdl.PolicyPermissions(
perm_mdl.merge_policies([ perm_mdl.merge_policies([group.policy for group in self.groups]),
group.policy for group in self.groups]), self.perm_lookup,
self.perm_lookup) )
return self._permissions return self._permissions
@ -80,8 +75,7 @@ class User:
if self.is_owner: if self.is_owner:
return True return True
return self.is_active and any( return self.is_active and any(gr.id == GROUP_ID_ADMIN for gr in self.groups)
gr.id == GROUP_ID_ADMIN for gr in self.groups)
def invalidate_permission_cache(self) -> None: def invalidate_permission_cache(self) -> None:
"""Invalidate permission cache.""" """Invalidate permission cache."""
@ -97,10 +91,13 @@ class RefreshToken:
access_token_expiration = attr.ib(type=timedelta) access_token_expiration = attr.ib(type=timedelta)
client_name = attr.ib(type=Optional[str], default=None) client_name = attr.ib(type=Optional[str], default=None)
client_icon = attr.ib(type=Optional[str], default=None) client_icon = attr.ib(type=Optional[str], default=None)
token_type = attr.ib(type=str, default=TOKEN_TYPE_NORMAL, token_type = attr.ib(
validator=attr.validators.in_(( type=str,
TOKEN_TYPE_NORMAL, TOKEN_TYPE_SYSTEM, default=TOKEN_TYPE_NORMAL,
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN))) validator=attr.validators.in_(
(TOKEN_TYPE_NORMAL, TOKEN_TYPE_SYSTEM, TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN)
),
)
id = attr.ib(type=str, factory=lambda: uuid.uuid4().hex) id = attr.ib(type=str, factory=lambda: uuid.uuid4().hex)
created_at = attr.ib(type=datetime, factory=dt_util.utcnow) created_at = attr.ib(type=datetime, factory=dt_util.utcnow)
token = attr.ib(type=str, factory=lambda: generate_secret(64)) token = attr.ib(type=str, factory=lambda: generate_secret(64))
@ -124,5 +121,4 @@ class Credentials:
is_new = attr.ib(type=bool, default=True) is_new = attr.ib(type=bool, default=True)
UserMeta = NamedTuple("UserMeta", UserMeta = NamedTuple("UserMeta", [("name", Optional[str]), ("is_active", bool)])
[('name', Optional[str]), ('is_active', bool)])

View file

@ -1,8 +1,17 @@
"""Permissions for Home Assistant.""" """Permissions for Home Assistant."""
import logging import logging
from typing import ( # noqa: F401 from typing import ( # noqa: F401
cast, Any, Callable, Dict, List, Mapping, Set, Tuple, Union, cast,
TYPE_CHECKING) Any,
Callable,
Dict,
List,
Mapping,
Set,
Tuple,
Union,
TYPE_CHECKING,
)
import voluptuous as vol import voluptuous as vol
@ -14,9 +23,7 @@ from .merge import merge_policies # noqa
from .util import test_all from .util import test_all
POLICY_SCHEMA = vol.Schema({ POLICY_SCHEMA = vol.Schema({vol.Optional(CAT_ENTITIES): ENTITY_POLICY_SCHEMA})
vol.Optional(CAT_ENTITIES): ENTITY_POLICY_SCHEMA
})
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -47,8 +54,7 @@ class AbstractPermissions:
class PolicyPermissions(AbstractPermissions): class PolicyPermissions(AbstractPermissions):
"""Handle permissions.""" """Handle permissions."""
def __init__(self, policy: PolicyType, def __init__(self, policy: PolicyType, perm_lookup: PermissionLookup) -> None:
perm_lookup: PermissionLookup) -> None:
"""Initialize the permission class.""" """Initialize the permission class."""
self._policy = policy self._policy = policy
self._perm_lookup = perm_lookup self._perm_lookup = perm_lookup
@ -59,14 +65,12 @@ class PolicyPermissions(AbstractPermissions):
def _entity_func(self) -> Callable[[str, str], bool]: def _entity_func(self) -> Callable[[str, str], bool]:
"""Return a function that can test entity access.""" """Return a function that can test entity access."""
return compile_entities(self._policy.get(CAT_ENTITIES), return compile_entities(self._policy.get(CAT_ENTITIES), self._perm_lookup)
self._perm_lookup)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
"""Equals check.""" """Equals check."""
# pylint: disable=protected-access # pylint: disable=protected-access
return (isinstance(other, PolicyPermissions) and return isinstance(other, PolicyPermissions) and other._policy == self._policy
other._policy == self._policy)
class _OwnerPermissions(AbstractPermissions): class _OwnerPermissions(AbstractPermissions):

View file

@ -1,8 +1,8 @@
"""Permission constants.""" """Permission constants."""
CAT_ENTITIES = 'entities' CAT_ENTITIES = "entities"
CAT_CONFIG_ENTRIES = 'config_entries' CAT_CONFIG_ENTRIES = "config_entries"
SUBCAT_ALL = 'all' SUBCAT_ALL = "all"
POLICY_READ = 'read' POLICY_READ = "read"
POLICY_CONTROL = 'control' POLICY_CONTROL = "control"
POLICY_EDIT = 'edit' POLICY_EDIT = "edit"

View file

@ -7,51 +7,59 @@ import voluptuous as vol
from .const import SUBCAT_ALL, POLICY_READ, POLICY_CONTROL, POLICY_EDIT from .const import SUBCAT_ALL, POLICY_READ, POLICY_CONTROL, POLICY_EDIT
from .models import PermissionLookup from .models import PermissionLookup
from .types import CategoryType, SubCategoryDict, ValueType from .types import CategoryType, SubCategoryDict, ValueType
# pylint: disable=unused-import # pylint: disable=unused-import
from .util import SubCatLookupType, lookup_all, compile_policy # noqa from .util import SubCatLookupType, lookup_all, compile_policy # noqa
SINGLE_ENTITY_SCHEMA = vol.Any(True, vol.Schema({ SINGLE_ENTITY_SCHEMA = vol.Any(
vol.Optional(POLICY_READ): True, True,
vol.Optional(POLICY_CONTROL): True, vol.Schema(
vol.Optional(POLICY_EDIT): True, {
})) vol.Optional(POLICY_READ): True,
vol.Optional(POLICY_CONTROL): True,
vol.Optional(POLICY_EDIT): True,
}
),
)
ENTITY_DOMAINS = 'domains' ENTITY_DOMAINS = "domains"
ENTITY_AREAS = 'area_ids' ENTITY_AREAS = "area_ids"
ENTITY_DEVICE_IDS = 'device_ids' ENTITY_DEVICE_IDS = "device_ids"
ENTITY_ENTITY_IDS = 'entity_ids' ENTITY_ENTITY_IDS = "entity_ids"
ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({ ENTITY_VALUES_SCHEMA = vol.Any(True, vol.Schema({str: SINGLE_ENTITY_SCHEMA}))
str: SINGLE_ENTITY_SCHEMA
}))
ENTITY_POLICY_SCHEMA = vol.Any(True, vol.Schema({ ENTITY_POLICY_SCHEMA = vol.Any(
vol.Optional(SUBCAT_ALL): SINGLE_ENTITY_SCHEMA, True,
vol.Optional(ENTITY_AREAS): ENTITY_VALUES_SCHEMA, vol.Schema(
vol.Optional(ENTITY_DEVICE_IDS): ENTITY_VALUES_SCHEMA, {
vol.Optional(ENTITY_DOMAINS): ENTITY_VALUES_SCHEMA, vol.Optional(SUBCAT_ALL): SINGLE_ENTITY_SCHEMA,
vol.Optional(ENTITY_ENTITY_IDS): ENTITY_VALUES_SCHEMA, vol.Optional(ENTITY_AREAS): ENTITY_VALUES_SCHEMA,
})) vol.Optional(ENTITY_DEVICE_IDS): ENTITY_VALUES_SCHEMA,
vol.Optional(ENTITY_DOMAINS): ENTITY_VALUES_SCHEMA,
vol.Optional(ENTITY_ENTITY_IDS): ENTITY_VALUES_SCHEMA,
}
),
)
def _lookup_domain(perm_lookup: PermissionLookup, def _lookup_domain(
domains_dict: SubCategoryDict, perm_lookup: PermissionLookup, domains_dict: SubCategoryDict, entity_id: str
entity_id: str) -> Optional[ValueType]: ) -> Optional[ValueType]:
"""Look up entity permissions by domain.""" """Look up entity permissions by domain."""
return domains_dict.get(entity_id.split(".", 1)[0]) return domains_dict.get(entity_id.split(".", 1)[0])
def _lookup_area(perm_lookup: PermissionLookup, area_dict: SubCategoryDict, def _lookup_area(
entity_id: str) -> Optional[ValueType]: perm_lookup: PermissionLookup, area_dict: SubCategoryDict, entity_id: str
) -> Optional[ValueType]:
"""Look up entity permissions by area.""" """Look up entity permissions by area."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id) entity_entry = perm_lookup.entity_registry.async_get(entity_id)
if entity_entry is None or entity_entry.device_id is None: if entity_entry is None or entity_entry.device_id is None:
return None return None
device_entry = perm_lookup.device_registry.async_get( device_entry = perm_lookup.device_registry.async_get(entity_entry.device_id)
entity_entry.device_id
)
if device_entry is None or device_entry.area_id is None: if device_entry is None or device_entry.area_id is None:
return None return None
@ -59,9 +67,9 @@ def _lookup_area(perm_lookup: PermissionLookup, area_dict: SubCategoryDict,
return area_dict.get(device_entry.area_id) return area_dict.get(device_entry.area_id)
def _lookup_device(perm_lookup: PermissionLookup, def _lookup_device(
devices_dict: SubCategoryDict, perm_lookup: PermissionLookup, devices_dict: SubCategoryDict, entity_id: str
entity_id: str) -> Optional[ValueType]: ) -> Optional[ValueType]:
"""Look up entity permissions by device.""" """Look up entity permissions by device."""
entity_entry = perm_lookup.entity_registry.async_get(entity_id) entity_entry = perm_lookup.entity_registry.async_get(entity_id)
@ -71,15 +79,16 @@ def _lookup_device(perm_lookup: PermissionLookup,
return devices_dict.get(entity_entry.device_id) return devices_dict.get(entity_entry.device_id)
def _lookup_entity_id(perm_lookup: PermissionLookup, def _lookup_entity_id(
entities_dict: SubCategoryDict, perm_lookup: PermissionLookup, entities_dict: SubCategoryDict, entity_id: str
entity_id: str) -> Optional[ValueType]: ) -> Optional[ValueType]:
"""Look up entity permission by entity id.""" """Look up entity permission by entity id."""
return entities_dict.get(entity_id) return entities_dict.get(entity_id)
def compile_entities(policy: CategoryType, perm_lookup: PermissionLookup) \ def compile_entities(
-> Callable[[str, str], bool]: policy: CategoryType, perm_lookup: PermissionLookup
) -> Callable[[str, str], bool]:
"""Compile policy into a function that tests policy.""" """Compile policy into a function that tests policy."""
subcategories = OrderedDict() # type: SubCatLookupType subcategories = OrderedDict() # type: SubCatLookupType
subcategories[ENTITY_ENTITY_IDS] = _lookup_entity_id subcategories[ENTITY_ENTITY_IDS] = _lookup_entity_id

View file

@ -1,6 +1,5 @@
"""Merging of policies.""" """Merging of policies."""
from typing import ( # noqa: F401 from typing import cast, Dict, List, Set # noqa: F401
cast, Dict, List, Set)
from .types import PolicyType, CategoryType from .types import PolicyType, CategoryType
@ -14,8 +13,9 @@ def merge_policies(policies: List[PolicyType]) -> PolicyType:
if category in seen: if category in seen:
continue continue
seen.add(category) seen.add(category)
new_policy[category] = _merge_policies([ new_policy[category] = _merge_policies(
policy.get(category) for policy in policies]) [policy.get(category) for policy in policies]
)
cast(PolicyType, new_policy) cast(PolicyType, new_policy)
return new_policy return new_policy

View file

@ -5,17 +5,13 @@ import attr
if TYPE_CHECKING: if TYPE_CHECKING:
# pylint: disable=unused-import # pylint: disable=unused-import
from homeassistant.helpers import ( # noqa from homeassistant.helpers import entity_registry as ent_reg # noqa
entity_registry as ent_reg, from homeassistant.helpers import device_registry as dev_reg # noqa
)
from homeassistant.helpers import ( # noqa
device_registry as dev_reg,
)
@attr.s(slots=True) @attr.s(slots=True)
class PermissionLookup: class PermissionLookup:
"""Class to hold data for permission lookups.""" """Class to hold data for permission lookups."""
entity_registry = attr.ib(type='ent_reg.EntityRegistry') entity_registry = attr.ib(type="ent_reg.EntityRegistry")
device_registry = attr.ib(type='dev_reg.DeviceRegistry') device_registry = attr.ib(type="dev_reg.DeviceRegistry")

View file

@ -1,18 +1,8 @@
"""System policies.""" """System policies."""
from .const import CAT_ENTITIES, SUBCAT_ALL, POLICY_READ from .const import CAT_ENTITIES, SUBCAT_ALL, POLICY_READ
ADMIN_POLICY = { ADMIN_POLICY = {CAT_ENTITIES: True}
CAT_ENTITIES: True,
}
USER_POLICY = { USER_POLICY = {CAT_ENTITIES: True}
CAT_ENTITIES: True,
}
READ_ONLY_POLICY = { READ_ONLY_POLICY = {CAT_ENTITIES: {SUBCAT_ALL: {POLICY_READ: True}}}
CAT_ENTITIES: {
SUBCAT_ALL: {
POLICY_READ: True
}
}
}

View file

@ -7,17 +7,13 @@ ValueType = Union[
# Example: entities.all = { read: true, control: true } # Example: entities.all = { read: true, control: true }
Mapping[str, bool], Mapping[str, bool],
bool, bool,
None None,
] ]
# Example: entities.domains = { light: … } # Example: entities.domains = { light: … }
SubCategoryDict = Mapping[str, ValueType] SubCategoryDict = Mapping[str, ValueType]
SubCategoryType = Union[ SubCategoryType = Union[SubCategoryDict, bool, None]
SubCategoryDict,
bool,
None
]
CategoryType = Union[ CategoryType = Union[
# Example: entities.domains # Example: entities.domains
@ -25,7 +21,7 @@ CategoryType = Union[
# Example: entities.all # Example: entities.all
Mapping[str, ValueType], Mapping[str, ValueType],
bool, bool,
None None,
] ]
# Example: { entities: … } # Example: { entities: … }

View file

@ -7,28 +7,28 @@ from .const import SUBCAT_ALL
from .models import PermissionLookup from .models import PermissionLookup
from .types import CategoryType, SubCategoryDict, ValueType from .types import CategoryType, SubCategoryDict, ValueType
LookupFunc = Callable[[PermissionLookup, SubCategoryDict, str], LookupFunc = Callable[[PermissionLookup, SubCategoryDict, str], Optional[ValueType]]
Optional[ValueType]]
SubCatLookupType = Dict[str, LookupFunc] SubCatLookupType = Dict[str, LookupFunc]
def lookup_all(perm_lookup: PermissionLookup, lookup_dict: SubCategoryDict, def lookup_all(
object_id: str) -> ValueType: perm_lookup: PermissionLookup, lookup_dict: SubCategoryDict, object_id: str
) -> ValueType:
"""Look up permission for all.""" """Look up permission for all."""
# In case of ALL category, lookup_dict IS the schema. # In case of ALL category, lookup_dict IS the schema.
return cast(ValueType, lookup_dict) return cast(ValueType, lookup_dict)
def compile_policy( def compile_policy(
policy: CategoryType, subcategories: SubCatLookupType, policy: CategoryType, subcategories: SubCatLookupType, perm_lookup: PermissionLookup
perm_lookup: PermissionLookup ) -> Callable[[str, str], bool]: # noqa
) -> Callable[[str, str], bool]: # noqa
"""Compile policy into a function that tests policy. """Compile policy into a function that tests policy.
Subcategories are mapping key -> lookup function, ordered by highest Subcategories are mapping key -> lookup function, ordered by highest
priority first. priority first.
""" """
# None, False, empty dict # None, False, empty dict
if not policy: if not policy:
def apply_policy_deny_all(entity_id: str, key: str) -> bool: def apply_policy_deny_all(entity_id: str, key: str) -> bool:
"""Decline all.""" """Decline all."""
return False return False
@ -36,6 +36,7 @@ def compile_policy(
return apply_policy_deny_all return apply_policy_deny_all
if policy is True: if policy is True:
def apply_policy_allow_all(entity_id: str, key: str) -> bool: def apply_policy_allow_all(entity_id: str, key: str) -> bool:
"""Approve all.""" """Approve all."""
return True return True
@ -54,8 +55,7 @@ def compile_policy(
return lambda object_id, key: True return lambda object_id, key: True
if lookup_value is not None: if lookup_value is not None:
funcs.append(_gen_dict_test_func( funcs.append(_gen_dict_test_func(perm_lookup, lookup_func, lookup_value))
perm_lookup, lookup_func, lookup_value))
if len(funcs) == 1: if len(funcs) == 1:
func = funcs[0] func = funcs[0]
@ -79,15 +79,13 @@ def compile_policy(
def _gen_dict_test_func( def _gen_dict_test_func(
perm_lookup: PermissionLookup, perm_lookup: PermissionLookup, lookup_func: LookupFunc, lookup_dict: SubCategoryDict
lookup_func: LookupFunc, ) -> Callable[[str, str], Optional[bool]]: # noqa
lookup_dict: SubCategoryDict
) -> Callable[[str, str], Optional[bool]]: # noqa
"""Generate a lookup function.""" """Generate a lookup function."""
def test_value(object_id: str, key: str) -> Optional[bool]: def test_value(object_id: str, key: str) -> Optional[bool]:
"""Test if permission is allowed based on the keys.""" """Test if permission is allowed based on the keys."""
schema = lookup_func( schema = lookup_func(perm_lookup, lookup_dict, object_id) # type: ValueType
perm_lookup, lookup_dict, object_id) # type: ValueType
if schema is None or isinstance(schema, bool): if schema is None or isinstance(schema, bool):
return schema return schema

View file

@ -19,25 +19,29 @@ from ..const import MFA_SESSION_EXPIRATION
from ..models import Credentials, User, UserMeta # noqa: F401 from ..models import Credentials, User, UserMeta # noqa: F401
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DATA_REQS = 'auth_prov_reqs_processed' DATA_REQS = "auth_prov_reqs_processed"
AUTH_PROVIDERS = Registry() AUTH_PROVIDERS = Registry()
AUTH_PROVIDER_SCHEMA = vol.Schema({ AUTH_PROVIDER_SCHEMA = vol.Schema(
vol.Required(CONF_TYPE): str, {
vol.Optional(CONF_NAME): str, vol.Required(CONF_TYPE): str,
# Specify ID if you have two auth providers for same type. vol.Optional(CONF_NAME): str,
vol.Optional(CONF_ID): str, # Specify ID if you have two auth providers for same type.
}, extra=vol.ALLOW_EXTRA) vol.Optional(CONF_ID): str,
},
extra=vol.ALLOW_EXTRA,
)
class AuthProvider: class AuthProvider:
"""Provider of user authentication.""" """Provider of user authentication."""
DEFAULT_TITLE = 'Unnamed auth provider' DEFAULT_TITLE = "Unnamed auth provider"
def __init__(self, hass: HomeAssistant, store: AuthStore, def __init__(
config: Dict[str, Any]) -> None: self, hass: HomeAssistant, store: AuthStore, config: Dict[str, Any]
) -> None:
"""Initialize an auth provider.""" """Initialize an auth provider."""
self.hass = hass self.hass = hass
self.store = store self.store = store
@ -73,22 +77,22 @@ class AuthProvider:
credentials credentials
for user in users for user in users
for credentials in user.credentials for credentials in user.credentials
if (credentials.auth_provider_type == self.type and if (
credentials.auth_provider_id == self.id) credentials.auth_provider_type == self.type
and credentials.auth_provider_id == self.id
)
] ]
@callback @callback
def async_create_credentials(self, data: Dict[str, str]) -> Credentials: def async_create_credentials(self, data: Dict[str, str]) -> Credentials:
"""Create credentials.""" """Create credentials."""
return Credentials( return Credentials(
auth_provider_type=self.type, auth_provider_type=self.type, auth_provider_id=self.id, data=data
auth_provider_id=self.id,
data=data,
) )
# Implement by extending class # Implement by extending class
async def async_login_flow(self, context: Optional[Dict]) -> 'LoginFlow': async def async_login_flow(self, context: Optional[Dict]) -> "LoginFlow":
"""Return the data flow for logging in with auth provider. """Return the data flow for logging in with auth provider.
Auth provider should extend LoginFlow and return an instance. Auth provider should extend LoginFlow and return an instance.
@ -96,12 +100,14 @@ class AuthProvider:
raise NotImplementedError raise NotImplementedError
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials: self, flow_result: Dict[str, str]
) -> Credentials:
"""Get credentials based on the flow result.""" """Get credentials based on the flow result."""
raise NotImplementedError raise NotImplementedError
async def async_user_meta_for_credentials( async def async_user_meta_for_credentials(
self, credentials: Credentials) -> UserMeta: self, credentials: Credentials
) -> UserMeta:
"""Return extra user metadata for credentials. """Return extra user metadata for credentials.
Will be used to populate info when creating a new user. Will be used to populate info when creating a new user.
@ -114,8 +120,8 @@ class AuthProvider:
async def auth_provider_from_config( async def auth_provider_from_config(
hass: HomeAssistant, store: AuthStore, hass: HomeAssistant, store: AuthStore, config: Dict[str, Any]
config: Dict[str, Any]) -> AuthProvider: ) -> AuthProvider:
"""Initialize an auth provider from a config.""" """Initialize an auth provider from a config."""
provider_name = config[CONF_TYPE] provider_name = config[CONF_TYPE]
module = await load_auth_provider_module(hass, provider_name) module = await load_auth_provider_module(hass, provider_name)
@ -123,25 +129,31 @@ async def auth_provider_from_config(
try: try:
config = module.CONFIG_SCHEMA(config) # type: ignore config = module.CONFIG_SCHEMA(config) # type: ignore
except vol.Invalid as err: except vol.Invalid as err:
_LOGGER.error('Invalid configuration for auth provider %s: %s', _LOGGER.error(
provider_name, humanize_error(config, err)) "Invalid configuration for auth provider %s: %s",
provider_name,
humanize_error(config, err),
)
raise raise
return AUTH_PROVIDERS[provider_name](hass, store, config) # type: ignore return AUTH_PROVIDERS[provider_name](hass, store, config) # type: ignore
async def load_auth_provider_module( async def load_auth_provider_module(
hass: HomeAssistant, provider: str) -> types.ModuleType: hass: HomeAssistant, provider: str
) -> types.ModuleType:
"""Load an auth provider.""" """Load an auth provider."""
try: try:
module = importlib.import_module( module = importlib.import_module(
'homeassistant.auth.providers.{}'.format(provider)) "homeassistant.auth.providers.{}".format(provider)
)
except ImportError as err: except ImportError as err:
_LOGGER.error('Unable to load auth provider %s: %s', provider, err) _LOGGER.error("Unable to load auth provider %s: %s", provider, err)
raise HomeAssistantError('Unable to load auth provider {}: {}'.format( raise HomeAssistantError(
provider, err)) "Unable to load auth provider {}: {}".format(provider, err)
)
if hass.config.skip_pip or not hasattr(module, 'REQUIREMENTS'): if hass.config.skip_pip or not hasattr(module, "REQUIREMENTS"):
return module return module
processed = hass.data.get(DATA_REQS) processed = hass.data.get(DATA_REQS)
@ -154,12 +166,13 @@ async def load_auth_provider_module(
# https://github.com/python/mypy/issues/1424 # https://github.com/python/mypy/issues/1424
reqs = module.REQUIREMENTS # type: ignore reqs = module.REQUIREMENTS # type: ignore
req_success = await requirements.async_process_requirements( req_success = await requirements.async_process_requirements(
hass, 'auth provider {}'.format(provider), reqs) hass, "auth provider {}".format(provider), reqs
)
if not req_success: if not req_success:
raise HomeAssistantError( raise HomeAssistantError(
'Unable to process requirements of auth provider {}'.format( "Unable to process requirements of auth provider {}".format(provider)
provider)) )
processed.add(provider) processed.add(provider)
return module return module
@ -179,8 +192,8 @@ class LoginFlow(data_entry_flow.FlowHandler):
self.user = None # type: Optional[User] self.user = None # type: Optional[User]
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None
-> Dict[str, Any]: ) -> Dict[str, Any]:
"""Handle the first step of login flow. """Handle the first step of login flow.
Return self.async_show_form(step_id='init') if user_input is None. Return self.async_show_form(step_id='init') if user_input is None.
@ -189,80 +202,75 @@ class LoginFlow(data_entry_flow.FlowHandler):
raise NotImplementedError raise NotImplementedError
async def async_step_select_mfa_module( async def async_step_select_mfa_module(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None
-> Dict[str, Any]: ) -> Dict[str, Any]:
"""Handle the step of select mfa module.""" """Handle the step of select mfa module."""
errors = {} errors = {}
if user_input is not None: if user_input is not None:
auth_module = user_input.get('multi_factor_auth_module') auth_module = user_input.get("multi_factor_auth_module")
if auth_module in self.available_mfa_modules: if auth_module in self.available_mfa_modules:
self._auth_module_id = auth_module self._auth_module_id = auth_module
return await self.async_step_mfa() return await self.async_step_mfa()
errors['base'] = 'invalid_auth_module' errors["base"] = "invalid_auth_module"
if len(self.available_mfa_modules) == 1: if len(self.available_mfa_modules) == 1:
self._auth_module_id = list(self.available_mfa_modules.keys())[0] self._auth_module_id = list(self.available_mfa_modules.keys())[0]
return await self.async_step_mfa() return await self.async_step_mfa()
return self.async_show_form( return self.async_show_form(
step_id='select_mfa_module', step_id="select_mfa_module",
data_schema=vol.Schema({ data_schema=vol.Schema(
'multi_factor_auth_module': vol.In(self.available_mfa_modules) {"multi_factor_auth_module": vol.In(self.available_mfa_modules)}
}), ),
errors=errors, errors=errors,
) )
async def async_step_mfa( async def async_step_mfa(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None
-> Dict[str, Any]: ) -> Dict[str, Any]:
"""Handle the step of mfa validation.""" """Handle the step of mfa validation."""
assert self.user assert self.user
errors = {} errors = {}
auth_module = self._auth_manager.get_auth_mfa_module( auth_module = self._auth_manager.get_auth_mfa_module(self._auth_module_id)
self._auth_module_id)
if auth_module is None: if auth_module is None:
# Given an invalid input to async_step_select_mfa_module # Given an invalid input to async_step_select_mfa_module
# will show invalid_auth_module error # will show invalid_auth_module error
return await self.async_step_select_mfa_module(user_input={}) return await self.async_step_select_mfa_module(user_input={})
if user_input is None and hasattr(auth_module, if user_input is None and hasattr(
'async_initialize_login_mfa_step'): auth_module, "async_initialize_login_mfa_step"
):
try: try:
await auth_module.async_initialize_login_mfa_step(self.user.id) await auth_module.async_initialize_login_mfa_step(self.user.id)
except HomeAssistantError: except HomeAssistantError:
_LOGGER.exception('Error initializing MFA step') _LOGGER.exception("Error initializing MFA step")
return self.async_abort(reason='unknown_error') return self.async_abort(reason="unknown_error")
if user_input is not None: if user_input is not None:
expires = self.created_at + MFA_SESSION_EXPIRATION expires = self.created_at + MFA_SESSION_EXPIRATION
if dt_util.utcnow() > expires: if dt_util.utcnow() > expires:
return self.async_abort( return self.async_abort(reason="login_expired")
reason='login_expired'
)
result = await auth_module.async_validate( result = await auth_module.async_validate(self.user.id, user_input)
self.user.id, user_input)
if not result: if not result:
errors['base'] = 'invalid_code' errors["base"] = "invalid_code"
self.invalid_mfa_times += 1 self.invalid_mfa_times += 1
if self.invalid_mfa_times >= auth_module.MAX_RETRY_TIME > 0: if self.invalid_mfa_times >= auth_module.MAX_RETRY_TIME > 0:
return self.async_abort( return self.async_abort(reason="too_many_retry")
reason='too_many_retry'
)
if not errors: if not errors:
return await self.async_finish(self.user) return await self.async_finish(self.user)
description_placeholders = { description_placeholders = {
'mfa_module_name': auth_module.name, "mfa_module_name": auth_module.name,
'mfa_module_id': auth_module.id, "mfa_module_id": auth_module.id,
} # type: Dict[str, Optional[str]] } # type: Dict[str, Optional[str]]
return self.async_show_form( return self.async_show_form(
step_id='mfa', step_id="mfa",
data_schema=auth_module.input_schema, data_schema=auth_module.input_schema,
description_placeholders=description_placeholders, description_placeholders=description_placeholders,
errors=errors, errors=errors,
@ -270,7 +278,4 @@ class LoginFlow(data_entry_flow.FlowHandler):
async def async_finish(self, flow_result: Any) -> Dict: async def async_finish(self, flow_result: Any) -> Dict:
"""Handle the pass of login flow.""" """Handle the pass of login flow."""
return self.async_create_entry( return self.async_create_entry(title=self._auth_provider.name, data=flow_result)
title=self._auth_provider.name,
data=flow_result
)

View file

@ -19,15 +19,16 @@ CONF_COMMAND = "command"
CONF_ARGS = "args" CONF_ARGS = "args"
CONF_META = "meta" CONF_META = "meta"
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({ CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend(
vol.Required(CONF_COMMAND): vol.All( {
str, vol.Required(CONF_COMMAND): vol.All(
os.path.normpath, str, os.path.normpath, msg="must be an absolute path"
msg="must be an absolute path" ),
), vol.Optional(CONF_ARGS, default=None): vol.Any(vol.DefaultTo(list), [str]),
vol.Optional(CONF_ARGS, default=None): vol.Any(vol.DefaultTo(list), [str]), vol.Optional(CONF_META, default=False): bool,
vol.Optional(CONF_META, default=False): bool, },
}, extra=vol.PREVENT_EXTRA) extra=vol.PREVENT_EXTRA,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -60,29 +61,27 @@ class CommandLineAuthProvider(AuthProvider):
async def async_validate_login(self, username: str, password: str) -> None: async def async_validate_login(self, username: str, password: str) -> None:
"""Validate a username and password.""" """Validate a username and password."""
env = { env = {"username": username, "password": password}
"username": username,
"password": password,
}
try: try:
# pylint: disable=no-member # pylint: disable=no-member
process = await asyncio.subprocess.create_subprocess_exec( process = await asyncio.subprocess.create_subprocess_exec(
self.config[CONF_COMMAND], *self.config[CONF_ARGS], self.config[CONF_COMMAND],
*self.config[CONF_ARGS],
env=env, env=env,
stdout=asyncio.subprocess.PIPE stdout=asyncio.subprocess.PIPE if self.config[CONF_META] else None,
if self.config[CONF_META] else None,
) )
stdout, _ = (await process.communicate()) stdout, _ = await process.communicate()
except OSError as err: except OSError as err:
# happens when command doesn't exist or permission is denied # happens when command doesn't exist or permission is denied
_LOGGER.error("Error while authenticating %r: %s", _LOGGER.error("Error while authenticating %r: %s", username, err)
username, err)
raise InvalidAuthError raise InvalidAuthError
if process.returncode != 0: if process.returncode != 0:
_LOGGER.error("User %r failed to authenticate, command exited " _LOGGER.error(
"with code %d.", "User %r failed to authenticate, command exited " "with code %d.",
username, process.returncode) username,
process.returncode,
)
raise InvalidAuthError raise InvalidAuthError
if self.config[CONF_META]: if self.config[CONF_META]:
@ -103,7 +102,7 @@ class CommandLineAuthProvider(AuthProvider):
self._user_meta[username] = meta self._user_meta[username] = meta
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str] self, flow_result: Dict[str, str]
) -> Credentials: ) -> Credentials:
"""Get credentials based on the flow result.""" """Get credentials based on the flow result."""
username = flow_result["username"] username = flow_result["username"]
@ -112,29 +111,24 @@ class CommandLineAuthProvider(AuthProvider):
return credential return credential
# Create new credentials. # Create new credentials.
return self.async_create_credentials({ return self.async_create_credentials({"username": username})
"username": username,
})
async def async_user_meta_for_credentials( async def async_user_meta_for_credentials(
self, credentials: Credentials self, credentials: Credentials
) -> UserMeta: ) -> UserMeta:
"""Return extra user metadata for credentials. """Return extra user metadata for credentials.
Currently, only name is supported. Currently, only name is supported.
""" """
meta = self._user_meta.get(credentials.data["username"], {}) meta = self._user_meta.get(credentials.data["username"], {})
return UserMeta( return UserMeta(name=meta.get("name"), is_active=True)
name=meta.get("name"),
is_active=True,
)
class CommandLineLoginFlow(LoginFlow): class CommandLineLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None self, user_input: Optional[Dict[str, str]] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Handle the step of the form.""" """Handle the step of the form."""
errors = {} errors = {}
@ -142,10 +136,9 @@ class CommandLineLoginFlow(LoginFlow):
if user_input is not None: if user_input is not None:
user_input["username"] = user_input["username"].strip() user_input["username"] = user_input["username"].strip()
try: try:
await cast(CommandLineAuthProvider, self._auth_provider) \ await cast(
.async_validate_login( CommandLineAuthProvider, self._auth_provider
user_input["username"], user_input["password"] ).async_validate_login(user_input["username"], user_input["password"])
)
except InvalidAuthError: except InvalidAuthError:
errors["base"] = "invalid_auth" errors["base"] = "invalid_auth"
@ -158,7 +151,5 @@ class CommandLineLoginFlow(LoginFlow):
schema["password"] = str schema["password"] = str
return self.async_show_form( return self.async_show_form(
step_id="init", step_id="init", data_schema=vol.Schema(schema), errors=errors
data_schema=vol.Schema(schema),
errors=errors,
) )

View file

@ -19,14 +19,13 @@ from ..models import Credentials, UserMeta
STORAGE_VERSION = 1 STORAGE_VERSION = 1
STORAGE_KEY = 'auth_provider.homeassistant' STORAGE_KEY = "auth_provider.homeassistant"
def _disallow_id(conf: Dict[str, Any]) -> Dict[str, Any]: def _disallow_id(conf: Dict[str, Any]) -> Dict[str, Any]:
"""Disallow ID in config.""" """Disallow ID in config."""
if CONF_ID in conf: if CONF_ID in conf:
raise vol.Invalid( raise vol.Invalid("ID is not allowed for the homeassistant auth provider.")
'ID is not allowed for the homeassistant auth provider.')
return conf return conf
@ -51,8 +50,9 @@ class Data:
def __init__(self, hass: HomeAssistant) -> None: def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the user data store.""" """Initialize the user data store."""
self.hass = hass self.hass = hass
self._store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY, self._store = hass.helpers.storage.Store(
private=True) STORAGE_VERSION, STORAGE_KEY, private=True
)
self._data = None # type: Optional[Dict[str, Any]] self._data = None # type: Optional[Dict[str, Any]]
# Legacy mode will allow usernames to start/end with whitespace # Legacy mode will allow usernames to start/end with whitespace
# and will compare usernames case-insensitive. # and will compare usernames case-insensitive.
@ -72,14 +72,12 @@ class Data:
data = await self._store.async_load() data = await self._store.async_load()
if data is None: if data is None:
data = { data = {"users": []}
'users': []
}
seen = set() # type: Set[str] seen = set() # type: Set[str]
for user in data['users']: for user in data["users"]:
username = user['username'] username = user["username"]
# check if we have duplicates # check if we have duplicates
folded = username.casefold() folded = username.casefold()
@ -90,7 +88,9 @@ class Data:
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
"Home Assistant auth provider is running in legacy mode " "Home Assistant auth provider is running in legacy mode "
"because we detected usernames that are case-insensitive" "because we detected usernames that are case-insensitive"
"equivalent. Please change the username: '%s'.", username) "equivalent. Please change the username: '%s'.",
username,
)
break break
@ -103,7 +103,9 @@ class Data:
logging.getLogger(__name__).warning( logging.getLogger(__name__).warning(
"Home Assistant auth provider is running in legacy mode " "Home Assistant auth provider is running in legacy mode "
"because we detected usernames that start or end in a " "because we detected usernames that start or end in a "
"space. Please change the username: '%s'.", username) "space. Please change the username: '%s'.",
username,
)
break break
@ -112,7 +114,7 @@ class Data:
@property @property
def users(self) -> List[Dict[str, str]]: def users(self) -> List[Dict[str, str]]:
"""Return users.""" """Return users."""
return self._data['users'] # type: ignore return self._data["users"] # type: ignore
def validate_login(self, username: str, password: str) -> None: def validate_login(self, username: str, password: str) -> None:
"""Validate a username and password. """Validate a username and password.
@ -120,32 +122,30 @@ class Data:
Raises InvalidAuth if auth invalid. Raises InvalidAuth if auth invalid.
""" """
username = self.normalize_username(username) username = self.normalize_username(username)
dummy = b'$2b$12$CiuFGszHx9eNHxPuQcwBWez4CwDTOcLTX5CbOpV6gef2nYuXkY7BO' dummy = b"$2b$12$CiuFGszHx9eNHxPuQcwBWez4CwDTOcLTX5CbOpV6gef2nYuXkY7BO"
found = None found = None
# Compare all users to avoid timing attacks. # Compare all users to avoid timing attacks.
for user in self.users: for user in self.users:
if self.normalize_username(user['username']) == username: if self.normalize_username(user["username"]) == username:
found = user found = user
if found is None: if found is None:
# check a hash to make timing the same as if user was found # check a hash to make timing the same as if user was found
bcrypt.checkpw(b'foo', bcrypt.checkpw(b"foo", dummy)
dummy)
raise InvalidAuth raise InvalidAuth
user_hash = base64.b64decode(found['password']) user_hash = base64.b64decode(found["password"])
# bcrypt.checkpw is timing-safe # bcrypt.checkpw is timing-safe
if not bcrypt.checkpw(password.encode(), if not bcrypt.checkpw(password.encode(), user_hash):
user_hash):
raise InvalidAuth raise InvalidAuth
# pylint: disable=no-self-use # pylint: disable=no-self-use
def hash_password(self, password: str, for_storage: bool = False) -> bytes: def hash_password(self, password: str, for_storage: bool = False) -> bytes:
"""Encode a password.""" """Encode a password."""
hashed = bcrypt.hashpw(password.encode(), bcrypt.gensalt(rounds=12)) \ hashed = bcrypt.hashpw(password.encode(), bcrypt.gensalt(rounds=12))
# type: bytes # type: bytes
if for_storage: if for_storage:
hashed = base64.b64encode(hashed) hashed = base64.b64encode(hashed)
return hashed return hashed
@ -154,14 +154,17 @@ class Data:
"""Add a new authenticated user/pass.""" """Add a new authenticated user/pass."""
username = self.normalize_username(username) username = self.normalize_username(username)
if any(self.normalize_username(user['username']) == username if any(
for user in self.users): self.normalize_username(user["username"]) == username for user in self.users
):
raise InvalidUser raise InvalidUser
self.users.append({ self.users.append(
'username': username, {
'password': self.hash_password(password, True).decode(), "username": username,
}) "password": self.hash_password(password, True).decode(),
}
)
@callback @callback
def async_remove_auth(self, username: str) -> None: def async_remove_auth(self, username: str) -> None:
@ -170,7 +173,7 @@ class Data:
index = None index = None
for i, user in enumerate(self.users): for i, user in enumerate(self.users):
if self.normalize_username(user['username']) == username: if self.normalize_username(user["username"]) == username:
index = i index = i
break break
@ -187,9 +190,8 @@ class Data:
username = self.normalize_username(username) username = self.normalize_username(username)
for user in self.users: for user in self.users:
if self.normalize_username(user['username']) == username: if self.normalize_username(user["username"]) == username:
user['password'] = self.hash_password( user["password"] = self.hash_password(new_password, True).decode()
new_password, True).decode()
break break
else: else:
raise InvalidUser raise InvalidUser
@ -199,11 +201,11 @@ class Data:
await self._store.async_save(self._data) await self._store.async_save(self._data)
@AUTH_PROVIDERS.register('homeassistant') @AUTH_PROVIDERS.register("homeassistant")
class HassAuthProvider(AuthProvider): class HassAuthProvider(AuthProvider):
"""Auth provider based on a local storage of users in HASS config dir.""" """Auth provider based on a local storage of users in HASS config dir."""
DEFAULT_TITLE = 'Home Assistant Local' DEFAULT_TITLE = "Home Assistant Local"
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize an Home Assistant auth provider.""" """Initialize an Home Assistant auth provider."""
@ -221,8 +223,7 @@ class HassAuthProvider(AuthProvider):
await data.async_load() await data.async_load()
self.data = data self.data = data
async def async_login_flow( async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
self, context: Optional[Dict]) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
return HassLoginFlow(self) return HassLoginFlow(self)
@ -233,41 +234,41 @@ class HassAuthProvider(AuthProvider):
assert self.data is not None assert self.data is not None
await self.hass.async_add_executor_job( await self.hass.async_add_executor_job(
self.data.validate_login, username, password) self.data.validate_login, username, password
)
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials: self, flow_result: Dict[str, str]
) -> Credentials:
"""Get credentials based on the flow result.""" """Get credentials based on the flow result."""
if self.data is None: if self.data is None:
await self.async_initialize() await self.async_initialize()
assert self.data is not None assert self.data is not None
norm_username = self.data.normalize_username norm_username = self.data.normalize_username
username = norm_username(flow_result['username']) username = norm_username(flow_result["username"])
for credential in await self.async_credentials(): for credential in await self.async_credentials():
if norm_username(credential.data['username']) == username: if norm_username(credential.data["username"]) == username:
return credential return credential
# Create new credentials. # Create new credentials.
return self.async_create_credentials({ return self.async_create_credentials({"username": username})
'username': username
})
async def async_user_meta_for_credentials( async def async_user_meta_for_credentials(
self, credentials: Credentials) -> UserMeta: self, credentials: Credentials
) -> UserMeta:
"""Get extra info for this credential.""" """Get extra info for this credential."""
return UserMeta(name=credentials.data['username'], is_active=True) return UserMeta(name=credentials.data["username"], is_active=True)
async def async_will_remove_credentials( async def async_will_remove_credentials(self, credentials: Credentials) -> None:
self, credentials: Credentials) -> None:
"""When credentials get removed, also remove the auth.""" """When credentials get removed, also remove the auth."""
if self.data is None: if self.data is None:
await self.async_initialize() await self.async_initialize()
assert self.data is not None assert self.data is not None
try: try:
self.data.async_remove_auth(credentials.data['username']) self.data.async_remove_auth(credentials.data["username"])
await self.data.async_save() await self.data.async_save()
except InvalidUser: except InvalidUser:
# Can happen if somehow we didn't clean up a credential # Can happen if somehow we didn't clean up a credential
@ -278,29 +279,27 @@ class HassLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None
-> Dict[str, Any]: ) -> Dict[str, Any]:
"""Handle the step of the form.""" """Handle the step of the form."""
errors = {} errors = {}
if user_input is not None: if user_input is not None:
try: try:
await cast(HassAuthProvider, self._auth_provider)\ await cast(HassAuthProvider, self._auth_provider).async_validate_login(
.async_validate_login(user_input['username'], user_input["username"], user_input["password"]
user_input['password']) )
except InvalidAuth: except InvalidAuth:
errors['base'] = 'invalid_auth' errors["base"] = "invalid_auth"
if not errors: if not errors:
user_input.pop('password') user_input.pop("password")
return await self.async_finish(user_input) return await self.async_finish(user_input)
schema = OrderedDict() # type: Dict[str, type] schema = OrderedDict() # type: Dict[str, type]
schema['username'] = str schema["username"] = str
schema['password'] = str schema["password"] = str
return self.async_show_form( return self.async_show_form(
step_id='init', step_id="init", data_schema=vol.Schema(schema), errors=errors
data_schema=vol.Schema(schema),
errors=errors,
) )

View file

@ -12,23 +12,25 @@ from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
from ..models import Credentials, UserMeta from ..models import Credentials, UserMeta
USER_SCHEMA = vol.Schema({ USER_SCHEMA = vol.Schema(
vol.Required('username'): str, {
vol.Required('password'): str, vol.Required("username"): str,
vol.Optional('name'): str, vol.Required("password"): str,
}) vol.Optional("name"): str,
}
)
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({ CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend(
vol.Required('users'): [USER_SCHEMA] {vol.Required("users"): [USER_SCHEMA]}, extra=vol.PREVENT_EXTRA
}, extra=vol.PREVENT_EXTRA) )
class InvalidAuthError(HomeAssistantError): class InvalidAuthError(HomeAssistantError):
"""Raised when submitting invalid authentication.""" """Raised when submitting invalid authentication."""
@AUTH_PROVIDERS.register('insecure_example') @AUTH_PROVIDERS.register("insecure_example")
class ExampleAuthProvider(AuthProvider): class ExampleAuthProvider(AuthProvider):
"""Example auth provider based on hardcoded usernames and passwords.""" """Example auth provider based on hardcoded usernames and passwords."""
@ -42,47 +44,48 @@ class ExampleAuthProvider(AuthProvider):
user = None user = None
# Compare all users to avoid timing attacks. # Compare all users to avoid timing attacks.
for usr in self.config['users']: for usr in self.config["users"]:
if hmac.compare_digest(username.encode('utf-8'), if hmac.compare_digest(
usr['username'].encode('utf-8')): username.encode("utf-8"), usr["username"].encode("utf-8")
):
user = usr user = usr
if user is None: if user is None:
# Do one more compare to make timing the same as if user was found. # Do one more compare to make timing the same as if user was found.
hmac.compare_digest(password.encode('utf-8'), hmac.compare_digest(password.encode("utf-8"), password.encode("utf-8"))
password.encode('utf-8'))
raise InvalidAuthError raise InvalidAuthError
if not hmac.compare_digest(user['password'].encode('utf-8'), if not hmac.compare_digest(
password.encode('utf-8')): user["password"].encode("utf-8"), password.encode("utf-8")
):
raise InvalidAuthError raise InvalidAuthError
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials: self, flow_result: Dict[str, str]
) -> Credentials:
"""Get credentials based on the flow result.""" """Get credentials based on the flow result."""
username = flow_result['username'] username = flow_result["username"]
for credential in await self.async_credentials(): for credential in await self.async_credentials():
if credential.data['username'] == username: if credential.data["username"] == username:
return credential return credential
# Create new credentials. # Create new credentials.
return self.async_create_credentials({ return self.async_create_credentials({"username": username})
'username': username
})
async def async_user_meta_for_credentials( async def async_user_meta_for_credentials(
self, credentials: Credentials) -> UserMeta: self, credentials: Credentials
) -> UserMeta:
"""Return extra user metadata for credentials. """Return extra user metadata for credentials.
Will be used to populate info when creating a new user. Will be used to populate info when creating a new user.
""" """
username = credentials.data['username'] username = credentials.data["username"]
name = None name = None
for user in self.config['users']: for user in self.config["users"]:
if user['username'] == username: if user["username"] == username:
name = user.get('name') name = user.get("name")
break break
return UserMeta(name=name, is_active=True) return UserMeta(name=name, is_active=True)
@ -92,29 +95,27 @@ class ExampleLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None
-> Dict[str, Any]: ) -> Dict[str, Any]:
"""Handle the step of the form.""" """Handle the step of the form."""
errors = {} errors = {}
if user_input is not None: if user_input is not None:
try: try:
cast(ExampleAuthProvider, self._auth_provider)\ cast(ExampleAuthProvider, self._auth_provider).async_validate_login(
.async_validate_login(user_input['username'], user_input["username"], user_input["password"]
user_input['password']) )
except InvalidAuthError: except InvalidAuthError:
errors['base'] = 'invalid_auth' errors["base"] = "invalid_auth"
if not errors: if not errors:
user_input.pop('password') user_input.pop("password")
return await self.async_finish(user_input) return await self.async_finish(user_input)
schema = OrderedDict() # type: Dict[str, type] schema = OrderedDict() # type: Dict[str, type]
schema['username'] = str schema["username"] = str
schema['password'] = str schema["password"] = str
return self.async_show_form( return self.async_show_form(
step_id='init', step_id="init", data_schema=vol.Schema(schema), errors=errors
data_schema=vol.Schema(schema),
errors=errors,
) )

View file

@ -16,27 +16,26 @@ from . import AuthProvider, AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, LoginFlow
from .. import AuthManager from .. import AuthManager
from ..models import Credentials, UserMeta, User from ..models import Credentials, UserMeta, User
AUTH_PROVIDER_TYPE = 'legacy_api_password' AUTH_PROVIDER_TYPE = "legacy_api_password"
CONF_API_PASSWORD = 'api_password' CONF_API_PASSWORD = "api_password"
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({ CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend(
vol.Required(CONF_API_PASSWORD): cv.string, {vol.Required(CONF_API_PASSWORD): cv.string}, extra=vol.PREVENT_EXTRA
}, extra=vol.PREVENT_EXTRA) )
LEGACY_USER_NAME = 'Legacy API password user' LEGACY_USER_NAME = "Legacy API password user"
class InvalidAuthError(HomeAssistantError): class InvalidAuthError(HomeAssistantError):
"""Raised when submitting invalid authentication.""" """Raised when submitting invalid authentication."""
async def async_validate_password(hass: HomeAssistant, password: str)\ async def async_validate_password(hass: HomeAssistant, password: str) -> Optional[User]:
-> Optional[User]:
"""Return a user if password is valid. None if not.""" """Return a user if password is valid. None if not."""
auth = cast(AuthManager, hass.auth) # type: ignore auth = cast(AuthManager, hass.auth) # type: ignore
providers = auth.get_auth_providers(AUTH_PROVIDER_TYPE) providers = auth.get_auth_providers(AUTH_PROVIDER_TYPE)
if not providers: if not providers:
raise ValueError('Legacy API password provider not found') raise ValueError("Legacy API password provider not found")
try: try:
provider = cast(LegacyApiPasswordAuthProvider, providers[0]) provider = cast(LegacyApiPasswordAuthProvider, providers[0])
@ -52,7 +51,7 @@ async def async_validate_password(hass: HomeAssistant, password: str)\
class LegacyApiPasswordAuthProvider(AuthProvider): class LegacyApiPasswordAuthProvider(AuthProvider):
"""An auth provider support legacy api_password.""" """An auth provider support legacy api_password."""
DEFAULT_TITLE = 'Legacy API Password' DEFAULT_TITLE = "Legacy API Password"
@property @property
def api_password(self) -> str: def api_password(self) -> str:
@ -68,12 +67,14 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
"""Validate password.""" """Validate password."""
api_password = str(self.config[CONF_API_PASSWORD]) api_password = str(self.config[CONF_API_PASSWORD])
if not hmac.compare_digest(api_password.encode('utf-8'), if not hmac.compare_digest(
password.encode('utf-8')): api_password.encode("utf-8"), password.encode("utf-8")
):
raise InvalidAuthError raise InvalidAuthError
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials: self, flow_result: Dict[str, str]
) -> Credentials:
"""Return credentials for this login.""" """Return credentials for this login."""
credentials = await self.async_credentials() credentials = await self.async_credentials()
if credentials: if credentials:
@ -82,7 +83,8 @@ class LegacyApiPasswordAuthProvider(AuthProvider):
return self.async_create_credentials({}) return self.async_create_credentials({})
async def async_user_meta_for_credentials( async def async_user_meta_for_credentials(
self, credentials: Credentials) -> UserMeta: self, credentials: Credentials
) -> UserMeta:
""" """
Return info for the user. Return info for the user.
@ -95,23 +97,22 @@ class LegacyLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None
-> Dict[str, Any]: ) -> Dict[str, Any]:
"""Handle the step of the form.""" """Handle the step of the form."""
errors = {} errors = {}
if user_input is not None: if user_input is not None:
try: try:
cast(LegacyApiPasswordAuthProvider, self._auth_provider)\ cast(
.async_validate_login(user_input['password']) LegacyApiPasswordAuthProvider, self._auth_provider
).async_validate_login(user_input["password"])
except InvalidAuthError: except InvalidAuthError:
errors['base'] = 'invalid_auth' errors["base"] = "invalid_auth"
if not errors: if not errors:
return await self.async_finish({}) return await self.async_finish({})
return self.async_show_form( return self.async_show_form(
step_id='init', step_id="init", data_schema=vol.Schema({"password": str}), errors=errors
data_schema=vol.Schema({'password': str}),
errors=errors,
) )

View file

@ -3,8 +3,7 @@
It shows list of users if access from trusted network. It shows list of users if access from trusted network.
Abort login flow if not access from trusted network. Abort login flow if not access from trusted network.
""" """
from ipaddress import ip_network, IPv4Address, IPv6Address, IPv4Network,\ from ipaddress import ip_network, IPv4Address, IPv6Address, IPv4Network, IPv6Network
IPv6Network
from typing import Any, Dict, List, Optional, Union, cast from typing import Any, Dict, List, Optional, Union, cast
import voluptuous as vol import voluptuous as vol
@ -18,27 +17,32 @@ from ..models import Credentials, UserMeta
IPAddress = Union[IPv4Address, IPv6Address] IPAddress = Union[IPv4Address, IPv6Address]
IPNetwork = Union[IPv4Network, IPv6Network] IPNetwork = Union[IPv4Network, IPv6Network]
CONF_TRUSTED_NETWORKS = 'trusted_networks' CONF_TRUSTED_NETWORKS = "trusted_networks"
CONF_TRUSTED_USERS = 'trusted_users' CONF_TRUSTED_USERS = "trusted_users"
CONF_GROUP = 'group' CONF_GROUP = "group"
CONF_ALLOW_BYPASS_LOGIN = 'allow_bypass_login' CONF_ALLOW_BYPASS_LOGIN = "allow_bypass_login"
CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend({ CONFIG_SCHEMA = AUTH_PROVIDER_SCHEMA.extend(
vol.Required(CONF_TRUSTED_NETWORKS): vol.All( {
cv.ensure_list, [ip_network] vol.Required(CONF_TRUSTED_NETWORKS): vol.All(cv.ensure_list, [ip_network]),
), vol.Optional(CONF_TRUSTED_USERS, default={}): vol.Schema(
vol.Optional(CONF_TRUSTED_USERS, default={}): vol.Schema( # we only validate the format of user_id or group_id
# we only validate the format of user_id or group_id {
{ip_network: vol.All( ip_network: vol.All(
cv.ensure_list, cv.ensure_list,
[vol.Or( [
cv.uuid4_hex, vol.Or(
vol.Schema({vol.Required(CONF_GROUP): cv.uuid4_hex}), cv.uuid4_hex,
)], vol.Schema({vol.Required(CONF_GROUP): cv.uuid4_hex}),
)} )
), ],
vol.Optional(CONF_ALLOW_BYPASS_LOGIN, default=False): cv.boolean, )
}, extra=vol.PREVENT_EXTRA) }
),
vol.Optional(CONF_ALLOW_BYPASS_LOGIN, default=False): cv.boolean,
},
extra=vol.PREVENT_EXTRA,
)
class InvalidAuthError(HomeAssistantError): class InvalidAuthError(HomeAssistantError):
@ -49,14 +53,14 @@ class InvalidUserError(HomeAssistantError):
"""Raised when try to login as invalid user.""" """Raised when try to login as invalid user."""
@AUTH_PROVIDERS.register('trusted_networks') @AUTH_PROVIDERS.register("trusted_networks")
class TrustedNetworksAuthProvider(AuthProvider): class TrustedNetworksAuthProvider(AuthProvider):
"""Trusted Networks auth provider. """Trusted Networks auth provider.
Allow passwordless access from trusted network. Allow passwordless access from trusted network.
""" """
DEFAULT_TITLE = 'Trusted Networks' DEFAULT_TITLE = "Trusted Networks"
@property @property
def trusted_networks(self) -> List[IPNetwork]: def trusted_networks(self) -> List[IPNetwork]:
@ -76,49 +80,58 @@ class TrustedNetworksAuthProvider(AuthProvider):
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow: async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
"""Return a flow to login.""" """Return a flow to login."""
assert context is not None assert context is not None
ip_addr = cast(IPAddress, context.get('ip_address')) ip_addr = cast(IPAddress, context.get("ip_address"))
users = await self.store.async_get_users() users = await self.store.async_get_users()
available_users = [user for user in users available_users = [
if not user.system_generated and user.is_active] user for user in users if not user.system_generated and user.is_active
]
for ip_net, user_or_group_list in self.trusted_users.items(): for ip_net, user_or_group_list in self.trusted_users.items():
if ip_addr in ip_net: if ip_addr in ip_net:
user_list = [user_id for user_id in user_or_group_list user_list = [
if isinstance(user_id, str)] user_id
group_list = [group[CONF_GROUP] for group in user_or_group_list for user_id in user_or_group_list
if isinstance(group, dict)] if isinstance(user_id, str)
flattened_group_list = [group for sublist in group_list ]
for group in sublist] group_list = [
group[CONF_GROUP]
for group in user_or_group_list
if isinstance(group, dict)
]
flattened_group_list = [
group for sublist in group_list for group in sublist
]
available_users = [ available_users = [
user for user in available_users user
if (user.id in user_list or for user in available_users
any([group.id in flattened_group_list if (
for group in user.groups])) user.id in user_list
or any(
[group.id in flattened_group_list for group in user.groups]
)
)
] ]
break break
return TrustedNetworksLoginFlow( return TrustedNetworksLoginFlow(
self, self,
ip_addr, ip_addr,
{ {user.id: user.name for user in available_users},
user.id: user.name for user in available_users
},
self.config[CONF_ALLOW_BYPASS_LOGIN], self.config[CONF_ALLOW_BYPASS_LOGIN],
) )
async def async_get_or_create_credentials( async def async_get_or_create_credentials(
self, flow_result: Dict[str, str]) -> Credentials: self, flow_result: Dict[str, str]
) -> Credentials:
"""Get credentials based on the flow result.""" """Get credentials based on the flow result."""
user_id = flow_result['user'] user_id = flow_result["user"]
users = await self.store.async_get_users() users = await self.store.async_get_users()
for user in users: for user in users:
if (not user.system_generated and if not user.system_generated and user.is_active and user.id == user_id:
user.is_active and
user.id == user_id):
for credential in await self.async_credentials(): for credential in await self.async_credentials():
if credential.data['user_id'] == user_id: if credential.data["user_id"] == user_id:
return credential return credential
cred = self.async_create_credentials({'user_id': user_id}) cred = self.async_create_credentials({"user_id": user_id})
await self.store.async_link_user(user, cred) await self.store.async_link_user(user, cred)
return cred return cred
@ -126,7 +139,8 @@ class TrustedNetworksAuthProvider(AuthProvider):
raise InvalidUserError raise InvalidUserError
async def async_user_meta_for_credentials( async def async_user_meta_for_credentials(
self, credentials: Credentials) -> UserMeta: self, credentials: Credentials
) -> UserMeta:
"""Return extra user metadata for credentials. """Return extra user metadata for credentials.
Trusted network auth provider should never create new user. Trusted network auth provider should never create new user.
@ -141,20 +155,24 @@ class TrustedNetworksAuthProvider(AuthProvider):
Raise InvalidAuthError if trusted_networks is not configured. Raise InvalidAuthError if trusted_networks is not configured.
""" """
if not self.trusted_networks: if not self.trusted_networks:
raise InvalidAuthError('trusted_networks is not configured') raise InvalidAuthError("trusted_networks is not configured")
if not any(ip_addr in trusted_network for trusted_network if not any(
in self.trusted_networks): ip_addr in trusted_network for trusted_network in self.trusted_networks
raise InvalidAuthError('Not in trusted_networks') ):
raise InvalidAuthError("Not in trusted_networks")
class TrustedNetworksLoginFlow(LoginFlow): class TrustedNetworksLoginFlow(LoginFlow):
"""Handler for the login flow.""" """Handler for the login flow."""
def __init__(self, auth_provider: TrustedNetworksAuthProvider, def __init__(
ip_addr: IPAddress, self,
available_users: Dict[str, Optional[str]], auth_provider: TrustedNetworksAuthProvider,
allow_bypass_login: bool) -> None: ip_addr: IPAddress,
available_users: Dict[str, Optional[str]],
allow_bypass_login: bool,
) -> None:
"""Initialize the login flow.""" """Initialize the login flow."""
super().__init__(auth_provider) super().__init__(auth_provider)
self._available_users = available_users self._available_users = available_users
@ -162,27 +180,26 @@ class TrustedNetworksLoginFlow(LoginFlow):
self._allow_bypass_login = allow_bypass_login self._allow_bypass_login = allow_bypass_login
async def async_step_init( async def async_step_init(
self, user_input: Optional[Dict[str, str]] = None) \ self, user_input: Optional[Dict[str, str]] = None
-> Dict[str, Any]: ) -> Dict[str, Any]:
"""Handle the step of the form.""" """Handle the step of the form."""
try: try:
cast(TrustedNetworksAuthProvider, self._auth_provider)\ cast(
.async_validate_access(self._ip_address) TrustedNetworksAuthProvider, self._auth_provider
).async_validate_access(self._ip_address)
except InvalidAuthError: except InvalidAuthError:
return self.async_abort( return self.async_abort(reason="not_whitelisted")
reason='not_whitelisted'
)
if user_input is not None: if user_input is not None:
return await self.async_finish(user_input) return await self.async_finish(user_input)
if self._allow_bypass_login and len(self._available_users) == 1: if self._allow_bypass_login and len(self._available_users) == 1:
return await self.async_finish({ return await self.async_finish(
'user': next(iter(self._available_users.keys())) {"user": next(iter(self._available_users.keys()))}
}) )
return self.async_show_form( return self.async_show_form(
step_id='init', step_id="init",
data_schema=vol.Schema({'user': vol.In(self._available_users)}), data_schema=vol.Schema({"user": vol.In(self._available_users)}),
) )

View file

@ -10,4 +10,4 @@ def generate_secret(entropy: int = 32) -> str:
Event loop friendly. Event loop friendly.
""" """
return binascii.hexlify(os.urandom(entropy)).decode('ascii') return binascii.hexlify(os.urandom(entropy)).decode("ascii")

View file

@ -20,32 +20,33 @@ from homeassistant.exceptions import HomeAssistantError
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ERROR_LOG_FILENAME = 'home-assistant.log' ERROR_LOG_FILENAME = "home-assistant.log"
# hass.data key for logging information. # hass.data key for logging information.
DATA_LOGGING = 'logging' DATA_LOGGING = "logging"
DEBUGGER_INTEGRATIONS = {'ptvsd', } DEBUGGER_INTEGRATIONS = {"ptvsd"}
CORE_INTEGRATIONS = ('homeassistant', 'persistent_notification') CORE_INTEGRATIONS = ("homeassistant", "persistent_notification")
LOGGING_INTEGRATIONS = {'logger', 'system_log'} LOGGING_INTEGRATIONS = {"logger", "system_log"}
STAGE_1_INTEGRATIONS = { STAGE_1_INTEGRATIONS = {
# To record data # To record data
'recorder', "recorder",
# To make sure we forward data to other instances # To make sure we forward data to other instances
'mqtt_eventstream', "mqtt_eventstream",
} }
async def async_from_config_dict(config: Dict[str, Any], async def async_from_config_dict(
hass: core.HomeAssistant, config: Dict[str, Any],
config_dir: Optional[str] = None, hass: core.HomeAssistant,
enable_log: bool = True, config_dir: Optional[str] = None,
verbose: bool = False, enable_log: bool = True,
skip_pip: bool = False, verbose: bool = False,
log_rotate_days: Any = None, skip_pip: bool = False,
log_file: Any = None, log_rotate_days: Any = None,
log_no_color: bool = False) \ log_file: Any = None,
-> Optional[core.HomeAssistant]: log_no_color: bool = False,
) -> Optional[core.HomeAssistant]:
"""Try to configure Home Assistant from a configuration dictionary. """Try to configure Home Assistant from a configuration dictionary.
Dynamically loads required components and its dependencies. Dynamically loads required components and its dependencies.
@ -54,28 +55,30 @@ async def async_from_config_dict(config: Dict[str, Any],
start = time() start = time()
if enable_log: if enable_log:
async_enable_logging(hass, verbose, log_rotate_days, log_file, async_enable_logging(hass, verbose, log_rotate_days, log_file, log_no_color)
log_no_color)
hass.config.skip_pip = skip_pip hass.config.skip_pip = skip_pip
if skip_pip: if skip_pip:
_LOGGER.warning("Skipping pip installation of required modules. " _LOGGER.warning(
"This may cause issues") "Skipping pip installation of required modules. " "This may cause issues"
)
core_config = config.get(core.DOMAIN, {}) core_config = config.get(core.DOMAIN, {})
api_password = config.get('http', {}).get('api_password') api_password = config.get("http", {}).get("api_password")
trusted_networks = config.get('http', {}).get('trusted_networks') trusted_networks = config.get("http", {}).get("trusted_networks")
try: try:
await conf_util.async_process_ha_core_config( await conf_util.async_process_ha_core_config(
hass, core_config, api_password, trusted_networks) hass, core_config, api_password, trusted_networks
)
except vol.Invalid as config_err: except vol.Invalid as config_err:
conf_util.async_log_exception( conf_util.async_log_exception(config_err, "homeassistant", core_config, hass)
config_err, 'homeassistant', core_config, hass)
return None return None
except HomeAssistantError: except HomeAssistantError:
_LOGGER.error("Home Assistant core failed to initialize. " _LOGGER.error(
"Further initialization aborted") "Home Assistant core failed to initialize. "
"Further initialization aborted"
)
return None return None
# Make a copy because we are mutating it. # Make a copy because we are mutating it.
@ -83,7 +86,8 @@ async def async_from_config_dict(config: Dict[str, Any],
# Merge packages # Merge packages
await conf_util.merge_packages_config( await conf_util.merge_packages_config(
hass, config, core_config.get(conf_util.CONF_PACKAGES, {})) hass, config, core_config.get(conf_util.CONF_PACKAGES, {})
)
hass.config_entries = config_entries.ConfigEntries(hass, config) hass.config_entries = config_entries.ConfigEntries(hass, config)
await hass.config_entries.async_initialize() await hass.config_entries.async_initialize()
@ -91,19 +95,20 @@ async def async_from_config_dict(config: Dict[str, Any],
await _async_set_up_integrations(hass, config) await _async_set_up_integrations(hass, config)
stop = time() stop = time()
_LOGGER.info("Home Assistant initialized in %.2fs", stop-start) _LOGGER.info("Home Assistant initialized in %.2fs", stop - start)
return hass return hass
async def async_from_config_file(config_path: str, async def async_from_config_file(
hass: core.HomeAssistant, config_path: str,
verbose: bool = False, hass: core.HomeAssistant,
skip_pip: bool = True, verbose: bool = False,
log_rotate_days: Any = None, skip_pip: bool = True,
log_file: Any = None, log_rotate_days: Any = None,
log_no_color: bool = False)\ log_file: Any = None,
-> Optional[core.HomeAssistant]: log_no_color: bool = False,
) -> Optional[core.HomeAssistant]:
"""Read the configuration file and try to start all the functionality. """Read the configuration file and try to start all the functionality.
Will add functionality to 'hass' parameter. Will add functionality to 'hass' parameter.
@ -116,15 +121,14 @@ async def async_from_config_file(config_path: str,
if not is_virtual_env(): if not is_virtual_env():
await async_mount_local_lib_path(config_dir) await async_mount_local_lib_path(config_dir)
async_enable_logging(hass, verbose, log_rotate_days, log_file, async_enable_logging(hass, verbose, log_rotate_days, log_file, log_no_color)
log_no_color)
await hass.async_add_executor_job( await hass.async_add_executor_job(conf_util.process_ha_config_upgrade, hass)
conf_util.process_ha_config_upgrade, hass)
try: try:
config_dict = await hass.async_add_executor_job( config_dict = await hass.async_add_executor_job(
conf_util.load_yaml_config_file, config_path) conf_util.load_yaml_config_file, config_path
)
except HomeAssistantError as err: except HomeAssistantError as err:
_LOGGER.error("Error loading %s: %s", config_path, err) _LOGGER.error("Error loading %s: %s", config_path, err)
return None return None
@ -132,43 +136,48 @@ async def async_from_config_file(config_path: str,
clear_secret_cache() clear_secret_cache()
return await async_from_config_dict( return await async_from_config_dict(
config_dict, hass, enable_log=False, skip_pip=skip_pip) config_dict, hass, enable_log=False, skip_pip=skip_pip
)
@core.callback @core.callback
def async_enable_logging(hass: core.HomeAssistant, def async_enable_logging(
verbose: bool = False, hass: core.HomeAssistant,
log_rotate_days: Optional[int] = None, verbose: bool = False,
log_file: Optional[str] = None, log_rotate_days: Optional[int] = None,
log_no_color: bool = False) -> None: log_file: Optional[str] = None,
log_no_color: bool = False,
) -> None:
"""Set up the logging. """Set up the logging.
This method must be run in the event loop. This method must be run in the event loop.
""" """
fmt = ("%(asctime)s %(levelname)s (%(threadName)s) " fmt = "%(asctime)s %(levelname)s (%(threadName)s) " "[%(name)s] %(message)s"
"[%(name)s] %(message)s") datefmt = "%Y-%m-%d %H:%M:%S"
datefmt = '%Y-%m-%d %H:%M:%S'
if not log_no_color: if not log_no_color:
try: try:
from colorlog import ColoredFormatter from colorlog import ColoredFormatter
# basicConfig must be called after importing colorlog in order to # basicConfig must be called after importing colorlog in order to
# ensure that the handlers it sets up wraps the correct streams. # ensure that the handlers it sets up wraps the correct streams.
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
colorfmt = "%(log_color)s{}%(reset)s".format(fmt) colorfmt = "%(log_color)s{}%(reset)s".format(fmt)
logging.getLogger().handlers[0].setFormatter(ColoredFormatter( logging.getLogger().handlers[0].setFormatter(
colorfmt, ColoredFormatter(
datefmt=datefmt, colorfmt,
reset=True, datefmt=datefmt,
log_colors={ reset=True,
'DEBUG': 'cyan', log_colors={
'INFO': 'green', "DEBUG": "cyan",
'WARNING': 'yellow', "INFO": "green",
'ERROR': 'red', "WARNING": "yellow",
'CRITICAL': 'red', "ERROR": "red",
} "CRITICAL": "red",
)) },
)
)
except ImportError: except ImportError:
pass pass
@ -177,9 +186,9 @@ def async_enable_logging(hass: core.HomeAssistant,
logging.basicConfig(format=fmt, datefmt=datefmt, level=logging.INFO) logging.basicConfig(format=fmt, datefmt=datefmt, level=logging.INFO)
# Suppress overly verbose logs from libraries that aren't helpful # Suppress overly verbose logs from libraries that aren't helpful
logging.getLogger('requests').setLevel(logging.WARNING) logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger('urllib3').setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger('aiohttp.access').setLevel(logging.WARNING) logging.getLogger("aiohttp.access").setLevel(logging.WARNING)
# Log errors to a file if we have write access to file or config dir # Log errors to a file if we have write access to file or config dir
if log_file is None: if log_file is None:
@ -192,16 +201,16 @@ def async_enable_logging(hass: core.HomeAssistant,
# Check if we can write to the error log if it exists or that # Check if we can write to the error log if it exists or that
# we can create files in the containing directory if not. # we can create files in the containing directory if not.
if (err_path_exists and os.access(err_log_path, os.W_OK)) or \ if (err_path_exists and os.access(err_log_path, os.W_OK)) or (
(not err_path_exists and os.access(err_dir, os.W_OK)): not err_path_exists and os.access(err_dir, os.W_OK)
):
if log_rotate_days: if log_rotate_days:
err_handler = logging.handlers.TimedRotatingFileHandler( err_handler = logging.handlers.TimedRotatingFileHandler(
err_log_path, when='midnight', err_log_path, when="midnight", backupCount=log_rotate_days
backupCount=log_rotate_days) # type: logging.FileHandler ) # type: logging.FileHandler
else: else:
err_handler = logging.FileHandler( err_handler = logging.FileHandler(err_log_path, mode="w", delay=True)
err_log_path, mode='w', delay=True)
err_handler.setLevel(logging.INFO if verbose else logging.WARNING) err_handler.setLevel(logging.INFO if verbose else logging.WARNING)
err_handler.setFormatter(logging.Formatter(fmt, datefmt=datefmt)) err_handler.setFormatter(logging.Formatter(fmt, datefmt=datefmt))
@ -210,21 +219,19 @@ def async_enable_logging(hass: core.HomeAssistant,
async def async_stop_async_handler(_: Any) -> None: async def async_stop_async_handler(_: Any) -> None:
"""Cleanup async handler.""" """Cleanup async handler."""
logging.getLogger('').removeHandler(async_handler) # type: ignore logging.getLogger("").removeHandler(async_handler) # type: ignore
await async_handler.async_close(blocking=True) await async_handler.async_close(blocking=True)
hass.bus.async_listen_once( hass.bus.async_listen_once(EVENT_HOMEASSISTANT_CLOSE, async_stop_async_handler)
EVENT_HOMEASSISTANT_CLOSE, async_stop_async_handler)
logger = logging.getLogger('') logger = logging.getLogger("")
logger.addHandler(async_handler) # type: ignore logger.addHandler(async_handler) # type: ignore
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
# Save the log file location for access by other components. # Save the log file location for access by other components.
hass.data[DATA_LOGGING] = err_log_path hass.data[DATA_LOGGING] = err_log_path
else: else:
_LOGGER.error( _LOGGER.error("Unable to set up error log %s (access denied)", err_log_path)
"Unable to set up error log %s (access denied)", err_log_path)
async def async_mount_local_lib_path(config_dir: str) -> str: async def async_mount_local_lib_path(config_dir: str) -> str:
@ -232,7 +239,7 @@ async def async_mount_local_lib_path(config_dir: str) -> str:
This function is a coroutine. This function is a coroutine.
""" """
deps_dir = os.path.join(config_dir, 'deps') deps_dir = os.path.join(config_dir, "deps")
lib_dir = await async_get_user_site(deps_dir) lib_dir = await async_get_user_site(deps_dir)
if lib_dir not in sys.path: if lib_dir not in sys.path:
sys.path.insert(0, lib_dir) sys.path.insert(0, lib_dir)
@ -243,21 +250,21 @@ async def async_mount_local_lib_path(config_dir: str) -> str:
def _get_domains(hass: core.HomeAssistant, config: Dict[str, Any]) -> Set[str]: def _get_domains(hass: core.HomeAssistant, config: Dict[str, Any]) -> Set[str]:
"""Get domains of components to set up.""" """Get domains of components to set up."""
# Filter out the repeating and common config section [homeassistant] # Filter out the repeating and common config section [homeassistant]
domains = set(key.split(' ')[0] for key in config.keys() domains = set(key.split(" ")[0] for key in config.keys() if key != core.DOMAIN)
if key != core.DOMAIN)
# Add config entry domains # Add config entry domains
domains.update(hass.config_entries.async_domains()) # type: ignore domains.update(hass.config_entries.async_domains()) # type: ignore
# Make sure the Hass.io component is loaded # Make sure the Hass.io component is loaded
if 'HASSIO' in os.environ: if "HASSIO" in os.environ:
domains.add('hassio') domains.add("hassio")
return domains return domains
async def _async_set_up_integrations( async def _async_set_up_integrations(
hass: core.HomeAssistant, config: Dict[str, Any]) -> None: hass: core.HomeAssistant, config: Dict[str, Any]
) -> None:
"""Set up all the integrations.""" """Set up all the integrations."""
domains = _get_domains(hass, config) domains = _get_domains(hass, config)
@ -265,27 +272,33 @@ async def _async_set_up_integrations(
debuggers = domains & DEBUGGER_INTEGRATIONS debuggers = domains & DEBUGGER_INTEGRATIONS
if debuggers: if debuggers:
_LOGGER.debug("Starting up debuggers %s", debuggers) _LOGGER.debug("Starting up debuggers %s", debuggers)
await asyncio.gather(*( await asyncio.gather(
async_setup_component(hass, domain, config) *(async_setup_component(hass, domain, config) for domain in debuggers)
for domain in debuggers)) )
domains -= DEBUGGER_INTEGRATIONS domains -= DEBUGGER_INTEGRATIONS
# Resolve all dependencies of all components so we can find the logging # Resolve all dependencies of all components so we can find the logging
# and integrations that need faster initialization. # and integrations that need faster initialization.
resolved_domains_task = asyncio.gather(*( resolved_domains_task = asyncio.gather(
loader.async_component_dependencies(hass, domain) *(loader.async_component_dependencies(hass, domain) for domain in domains),
for domain in domains return_exceptions=True,
), return_exceptions=True) )
# Set up core. # Set up core.
_LOGGER.debug("Setting up %s", CORE_INTEGRATIONS) _LOGGER.debug("Setting up %s", CORE_INTEGRATIONS)
if not all(await asyncio.gather(*( if not all(
async_setup_component(hass, domain, config) await asyncio.gather(
for domain in CORE_INTEGRATIONS *(
))): async_setup_component(hass, domain, config)
_LOGGER.error("Home Assistant core failed to initialize. " for domain in CORE_INTEGRATIONS
"Further initialization aborted") )
)
):
_LOGGER.error(
"Home Assistant core failed to initialize. "
"Further initialization aborted"
)
return return
_LOGGER.debug("Home Assistant core initialized") _LOGGER.debug("Home Assistant core initialized")
@ -305,36 +318,32 @@ async def _async_set_up_integrations(
if logging_domains: if logging_domains:
_LOGGER.info("Setting up %s", logging_domains) _LOGGER.info("Setting up %s", logging_domains)
await asyncio.gather(*( await asyncio.gather(
async_setup_component(hass, domain, config) *(async_setup_component(hass, domain, config) for domain in logging_domains)
for domain in logging_domains )
))
# Kick off loading the registries. They don't need to be awaited. # Kick off loading the registries. They don't need to be awaited.
asyncio.gather( asyncio.gather(
hass.helpers.device_registry.async_get_registry(), hass.helpers.device_registry.async_get_registry(),
hass.helpers.entity_registry.async_get_registry(), hass.helpers.entity_registry.async_get_registry(),
hass.helpers.area_registry.async_get_registry()) hass.helpers.area_registry.async_get_registry(),
)
if stage_1_domains: if stage_1_domains:
await asyncio.gather(*( await asyncio.gather(
async_setup_component(hass, domain, config) *(async_setup_component(hass, domain, config) for domain in stage_1_domains)
for domain in stage_1_domains )
))
# Load all integrations # Load all integrations
after_dependencies = {} # type: Dict[str, Set[str]] after_dependencies = {} # type: Dict[str, Set[str]]
for int_or_exc in await asyncio.gather(*( for int_or_exc in await asyncio.gather(
loader.async_get_integration(hass, domain) *(loader.async_get_integration(hass, domain) for domain in stage_2_domains),
for domain in stage_2_domains return_exceptions=True,
), return_exceptions=True): ):
# Exceptions are handled in async_setup_component. # Exceptions are handled in async_setup_component.
if (isinstance(int_or_exc, loader.Integration) and if isinstance(int_or_exc, loader.Integration) and int_or_exc.after_dependencies:
int_or_exc.after_dependencies): after_dependencies[int_or_exc.domain] = set(int_or_exc.after_dependencies)
after_dependencies[int_or_exc.domain] = set(
int_or_exc.after_dependencies
)
last_load = None last_load = None
while stage_2_domains: while stage_2_domains:
@ -344,8 +353,7 @@ async def _async_set_up_integrations(
after_deps = after_dependencies.get(domain) after_deps = after_dependencies.get(domain)
# Load if integration has no after_dependencies or they are # Load if integration has no after_dependencies or they are
# all loaded # all loaded
if (not after_deps or if not after_deps or not after_deps - hass.config.components:
not after_deps-hass.config.components):
domains_to_load.add(domain) domains_to_load.add(domain)
if not domains_to_load or domains_to_load == last_load: if not domains_to_load or domains_to_load == last_load:
@ -353,10 +361,9 @@ async def _async_set_up_integrations(
_LOGGER.debug("Setting up %s", domains_to_load) _LOGGER.debug("Setting up %s", domains_to_load)
await asyncio.gather(*( await asyncio.gather(
async_setup_component(hass, domain, config) *(async_setup_component(hass, domain, config) for domain in domains_to_load)
for domain in domains_to_load )
))
last_load = domains_to_load last_load = domains_to_load
stage_2_domains -= domains_to_load stage_2_domains -= domains_to_load
@ -366,10 +373,9 @@ async def _async_set_up_integrations(
if stage_2_domains: if stage_2_domains:
_LOGGER.debug("Final set up: %s", stage_2_domains) _LOGGER.debug("Final set up: %s", stage_2_domains)
await asyncio.gather(*( await asyncio.gather(
async_setup_component(hass, domain, config) *(async_setup_component(hass, domain, config) for domain in stage_2_domains)
for domain in stage_2_domains )
))
# Wrap up startup # Wrap up startup
await hass.async_block_till_done() await hass.async_block_till_done()

View file

@ -31,11 +31,10 @@ def is_on(hass, entity_id=None):
component = getattr(hass.components, domain) component = getattr(hass.components, domain)
except ImportError: except ImportError:
_LOGGER.error('Failed to call %s.is_on: component not found', _LOGGER.error("Failed to call %s.is_on: component not found", domain)
domain)
continue continue
if not hasattr(component, 'is_on'): if not hasattr(component, "is_on"):
_LOGGER.warning("Integration %s has no is_on method.", domain) _LOGGER.warning("Integration %s has no is_on method.", domain)
continue continue

View file

@ -6,9 +6,18 @@ from requests.exceptions import HTTPError, ConnectTimeout
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
ATTR_ATTRIBUTION, ATTR_DATE, ATTR_TIME, ATTR_ENTITY_ID, CONF_USERNAME, ATTR_ATTRIBUTION,
CONF_PASSWORD, CONF_EXCLUDE, CONF_NAME, CONF_LIGHTS, ATTR_DATE,
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START) ATTR_TIME,
ATTR_ENTITY_ID,
CONF_USERNAME,
CONF_PASSWORD,
CONF_EXCLUDE,
CONF_NAME,
CONF_LIGHTS,
EVENT_HOMEASSISTANT_STOP,
EVENT_HOMEASSISTANT_START,
)
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers import discovery from homeassistant.helpers import discovery
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
@ -17,77 +26,88 @@ _LOGGER = logging.getLogger(__name__)
ATTRIBUTION = "Data provided by goabode.com" ATTRIBUTION = "Data provided by goabode.com"
CONF_POLLING = 'polling' CONF_POLLING = "polling"
DOMAIN = 'abode' DOMAIN = "abode"
DEFAULT_CACHEDB = './abodepy_cache.pickle' DEFAULT_CACHEDB = "./abodepy_cache.pickle"
NOTIFICATION_ID = 'abode_notification' NOTIFICATION_ID = "abode_notification"
NOTIFICATION_TITLE = 'Abode Security Setup' NOTIFICATION_TITLE = "Abode Security Setup"
EVENT_ABODE_ALARM = 'abode_alarm' EVENT_ABODE_ALARM = "abode_alarm"
EVENT_ABODE_ALARM_END = 'abode_alarm_end' EVENT_ABODE_ALARM_END = "abode_alarm_end"
EVENT_ABODE_AUTOMATION = 'abode_automation' EVENT_ABODE_AUTOMATION = "abode_automation"
EVENT_ABODE_FAULT = 'abode_panel_fault' EVENT_ABODE_FAULT = "abode_panel_fault"
EVENT_ABODE_RESTORE = 'abode_panel_restore' EVENT_ABODE_RESTORE = "abode_panel_restore"
SERVICE_SETTINGS = 'change_setting' SERVICE_SETTINGS = "change_setting"
SERVICE_CAPTURE_IMAGE = 'capture_image' SERVICE_CAPTURE_IMAGE = "capture_image"
SERVICE_TRIGGER = 'trigger_quick_action' SERVICE_TRIGGER = "trigger_quick_action"
ATTR_DEVICE_ID = 'device_id' ATTR_DEVICE_ID = "device_id"
ATTR_DEVICE_NAME = 'device_name' ATTR_DEVICE_NAME = "device_name"
ATTR_DEVICE_TYPE = 'device_type' ATTR_DEVICE_TYPE = "device_type"
ATTR_EVENT_CODE = 'event_code' ATTR_EVENT_CODE = "event_code"
ATTR_EVENT_NAME = 'event_name' ATTR_EVENT_NAME = "event_name"
ATTR_EVENT_TYPE = 'event_type' ATTR_EVENT_TYPE = "event_type"
ATTR_EVENT_UTC = 'event_utc' ATTR_EVENT_UTC = "event_utc"
ATTR_SETTING = 'setting' ATTR_SETTING = "setting"
ATTR_USER_NAME = 'user_name' ATTR_USER_NAME = "user_name"
ATTR_VALUE = 'value' ATTR_VALUE = "value"
ABODE_DEVICE_ID_LIST_SCHEMA = vol.Schema([str]) ABODE_DEVICE_ID_LIST_SCHEMA = vol.Schema([str])
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema(
DOMAIN: vol.Schema({ {
vol.Required(CONF_USERNAME): cv.string, DOMAIN: vol.Schema(
vol.Required(CONF_PASSWORD): cv.string, {
vol.Optional(CONF_NAME): cv.string, vol.Required(CONF_USERNAME): cv.string,
vol.Optional(CONF_POLLING, default=False): cv.boolean, vol.Required(CONF_PASSWORD): cv.string,
vol.Optional(CONF_EXCLUDE, default=[]): ABODE_DEVICE_ID_LIST_SCHEMA, vol.Optional(CONF_NAME): cv.string,
vol.Optional(CONF_LIGHTS, default=[]): ABODE_DEVICE_ID_LIST_SCHEMA vol.Optional(CONF_POLLING, default=False): cv.boolean,
}), vol.Optional(CONF_EXCLUDE, default=[]): ABODE_DEVICE_ID_LIST_SCHEMA,
}, extra=vol.ALLOW_EXTRA) vol.Optional(CONF_LIGHTS, default=[]): ABODE_DEVICE_ID_LIST_SCHEMA,
}
)
},
extra=vol.ALLOW_EXTRA,
)
CHANGE_SETTING_SCHEMA = vol.Schema({ CHANGE_SETTING_SCHEMA = vol.Schema(
vol.Required(ATTR_SETTING): cv.string, {vol.Required(ATTR_SETTING): cv.string, vol.Required(ATTR_VALUE): cv.string}
vol.Required(ATTR_VALUE): cv.string )
})
CAPTURE_IMAGE_SCHEMA = vol.Schema({ CAPTURE_IMAGE_SCHEMA = vol.Schema({ATTR_ENTITY_ID: cv.entity_ids})
ATTR_ENTITY_ID: cv.entity_ids,
})
TRIGGER_SCHEMA = vol.Schema({ TRIGGER_SCHEMA = vol.Schema({ATTR_ENTITY_ID: cv.entity_ids})
ATTR_ENTITY_ID: cv.entity_ids,
})
ABODE_PLATFORMS = [ ABODE_PLATFORMS = [
'alarm_control_panel', 'binary_sensor', 'lock', 'switch', 'cover', "alarm_control_panel",
'camera', 'light', 'sensor' "binary_sensor",
"lock",
"switch",
"cover",
"camera",
"light",
"sensor",
] ]
class AbodeSystem: class AbodeSystem:
"""Abode System class.""" """Abode System class."""
def __init__(self, username, password, cache, def __init__(self, username, password, cache, name, polling, exclude, lights):
name, polling, exclude, lights):
"""Initialize the system.""" """Initialize the system."""
import abodepy import abodepy
self.abode = abodepy.Abode( self.abode = abodepy.Abode(
username, password, auto_login=True, get_devices=True, username,
get_automations=True, cache_path=cache) password,
auto_login=True,
get_devices=True,
get_automations=True,
cache_path=cache,
)
self.name = name self.name = name
self.polling = polling self.polling = polling
self.exclude = exclude self.exclude = exclude
@ -106,9 +126,9 @@ class AbodeSystem:
"""Check if a switch device is configured as a light.""" """Check if a switch device is configured as a light."""
import abodepy.helpers.constants as CONST import abodepy.helpers.constants as CONST
return (device.generic_type == CONST.TYPE_LIGHT or return device.generic_type == CONST.TYPE_LIGHT or (
(device.generic_type == CONST.TYPE_SWITCH and device.generic_type == CONST.TYPE_SWITCH and device.device_id in self.lights
device.device_id in self.lights)) )
def setup(hass, config): def setup(hass, config):
@ -126,16 +146,18 @@ def setup(hass, config):
try: try:
cache = hass.config.path(DEFAULT_CACHEDB) cache = hass.config.path(DEFAULT_CACHEDB)
hass.data[DOMAIN] = AbodeSystem( hass.data[DOMAIN] = AbodeSystem(
username, password, cache, name, polling, exclude, lights) username, password, cache, name, polling, exclude, lights
)
except (AbodeException, ConnectTimeout, HTTPError) as ex: except (AbodeException, ConnectTimeout, HTTPError) as ex:
_LOGGER.error("Unable to connect to Abode: %s", str(ex)) _LOGGER.error("Unable to connect to Abode: %s", str(ex))
hass.components.persistent_notification.create( hass.components.persistent_notification.create(
'Error: {}<br />' "Error: {}<br />"
'You will need to restart hass after fixing.' "You will need to restart hass after fixing."
''.format(ex), "".format(ex),
title=NOTIFICATION_TITLE, title=NOTIFICATION_TITLE,
notification_id=NOTIFICATION_ID) notification_id=NOTIFICATION_ID,
)
return False return False
setup_hass_services(hass) setup_hass_services(hass)
@ -166,8 +188,11 @@ def setup_hass_services(hass):
"""Capture a new image.""" """Capture a new image."""
entity_ids = call.data.get(ATTR_ENTITY_ID) entity_ids = call.data.get(ATTR_ENTITY_ID)
target_devices = [device for device in hass.data[DOMAIN].devices target_devices = [
if device.entity_id in entity_ids] device
for device in hass.data[DOMAIN].devices
if device.entity_id in entity_ids
]
for device in target_devices: for device in target_devices:
device.capture() device.capture()
@ -176,27 +201,31 @@ def setup_hass_services(hass):
"""Trigger a quick action.""" """Trigger a quick action."""
entity_ids = call.data.get(ATTR_ENTITY_ID, None) entity_ids = call.data.get(ATTR_ENTITY_ID, None)
target_devices = [device for device in hass.data[DOMAIN].devices target_devices = [
if device.entity_id in entity_ids] device
for device in hass.data[DOMAIN].devices
if device.entity_id in entity_ids
]
for device in target_devices: for device in target_devices:
device.trigger() device.trigger()
hass.services.register( hass.services.register(
DOMAIN, SERVICE_SETTINGS, change_setting, DOMAIN, SERVICE_SETTINGS, change_setting, schema=CHANGE_SETTING_SCHEMA
schema=CHANGE_SETTING_SCHEMA) )
hass.services.register( hass.services.register(
DOMAIN, SERVICE_CAPTURE_IMAGE, capture_image, DOMAIN, SERVICE_CAPTURE_IMAGE, capture_image, schema=CAPTURE_IMAGE_SCHEMA
schema=CAPTURE_IMAGE_SCHEMA) )
hass.services.register( hass.services.register(
DOMAIN, SERVICE_TRIGGER, trigger_quick_action, DOMAIN, SERVICE_TRIGGER, trigger_quick_action, schema=TRIGGER_SCHEMA
schema=TRIGGER_SCHEMA) )
def setup_hass_events(hass): def setup_hass_events(hass):
"""Home Assistant start and stop callbacks.""" """Home Assistant start and stop callbacks."""
def startup(event): def startup(event):
"""Listen for push events.""" """Listen for push events."""
hass.data[DOMAIN].abode.events.start() hass.data[DOMAIN].abode.events.start()
@ -222,28 +251,32 @@ def setup_abode_events(hass):
def event_callback(event, event_json): def event_callback(event, event_json):
"""Handle an event callback from Abode.""" """Handle an event callback from Abode."""
data = { data = {
ATTR_DEVICE_ID: event_json.get(ATTR_DEVICE_ID, ''), ATTR_DEVICE_ID: event_json.get(ATTR_DEVICE_ID, ""),
ATTR_DEVICE_NAME: event_json.get(ATTR_DEVICE_NAME, ''), ATTR_DEVICE_NAME: event_json.get(ATTR_DEVICE_NAME, ""),
ATTR_DEVICE_TYPE: event_json.get(ATTR_DEVICE_TYPE, ''), ATTR_DEVICE_TYPE: event_json.get(ATTR_DEVICE_TYPE, ""),
ATTR_EVENT_CODE: event_json.get(ATTR_EVENT_CODE, ''), ATTR_EVENT_CODE: event_json.get(ATTR_EVENT_CODE, ""),
ATTR_EVENT_NAME: event_json.get(ATTR_EVENT_NAME, ''), ATTR_EVENT_NAME: event_json.get(ATTR_EVENT_NAME, ""),
ATTR_EVENT_TYPE: event_json.get(ATTR_EVENT_TYPE, ''), ATTR_EVENT_TYPE: event_json.get(ATTR_EVENT_TYPE, ""),
ATTR_EVENT_UTC: event_json.get(ATTR_EVENT_UTC, ''), ATTR_EVENT_UTC: event_json.get(ATTR_EVENT_UTC, ""),
ATTR_USER_NAME: event_json.get(ATTR_USER_NAME, ''), ATTR_USER_NAME: event_json.get(ATTR_USER_NAME, ""),
ATTR_DATE: event_json.get(ATTR_DATE, ''), ATTR_DATE: event_json.get(ATTR_DATE, ""),
ATTR_TIME: event_json.get(ATTR_TIME, ''), ATTR_TIME: event_json.get(ATTR_TIME, ""),
} }
hass.bus.fire(event, data) hass.bus.fire(event, data)
events = [TIMELINE.ALARM_GROUP, TIMELINE.ALARM_END_GROUP, events = [
TIMELINE.PANEL_FAULT_GROUP, TIMELINE.PANEL_RESTORE_GROUP, TIMELINE.ALARM_GROUP,
TIMELINE.AUTOMATION_GROUP] TIMELINE.ALARM_END_GROUP,
TIMELINE.PANEL_FAULT_GROUP,
TIMELINE.PANEL_RESTORE_GROUP,
TIMELINE.AUTOMATION_GROUP,
]
for event in events: for event in events:
hass.data[DOMAIN].abode.events.add_event_callback( hass.data[DOMAIN].abode.events.add_event_callback(
event, event, partial(event_callback, event)
partial(event_callback, event)) )
class AbodeDevice(Entity): class AbodeDevice(Entity):
@ -258,7 +291,8 @@ class AbodeDevice(Entity):
"""Subscribe Abode events.""" """Subscribe Abode events."""
self.hass.async_add_job( self.hass.async_add_job(
self._data.abode.events.add_device_callback, self._data.abode.events.add_device_callback,
self._device.device_id, self._update_callback self._device.device_id,
self._update_callback,
) )
@property @property
@ -280,10 +314,10 @@ class AbodeDevice(Entity):
"""Return the state attributes.""" """Return the state attributes."""
return { return {
ATTR_ATTRIBUTION: ATTRIBUTION, ATTR_ATTRIBUTION: ATTRIBUTION,
'device_id': self._device.device_id, "device_id": self._device.device_id,
'battery_low': self._device.battery_low, "battery_low": self._device.battery_low,
'no_response': self._device.no_response, "no_response": self._device.no_response,
'device_type': self._device.type "device_type": self._device.type,
} }
def _update_callback(self, device): def _update_callback(self, device):
@ -305,7 +339,8 @@ class AbodeAutomation(Entity):
if self._event: if self._event:
self.hass.async_add_job( self.hass.async_add_job(
self._data.abode.events.add_event_callback, self._data.abode.events.add_event_callback,
self._event, self._update_callback self._event,
self._update_callback,
) )
@property @property
@ -327,9 +362,9 @@ class AbodeAutomation(Entity):
"""Return the state attributes.""" """Return the state attributes."""
return { return {
ATTR_ATTRIBUTION: ATTRIBUTION, ATTR_ATTRIBUTION: ATTRIBUTION,
'automation_id': self._automation.automation_id, "automation_id": self._automation.automation_id,
'type': self._automation.type, "type": self._automation.type,
'sub_type': self._automation.sub_type "sub_type": self._automation.sub_type,
} }
def _update_callback(self, device): def _update_callback(self, device):

View file

@ -3,14 +3,17 @@ import logging
import homeassistant.components.alarm_control_panel as alarm import homeassistant.components.alarm_control_panel as alarm
from homeassistant.const import ( from homeassistant.const import (
ATTR_ATTRIBUTION, STATE_ALARM_ARMED_AWAY, STATE_ALARM_ARMED_HOME, ATTR_ATTRIBUTION,
STATE_ALARM_DISARMED) STATE_ALARM_ARMED_AWAY,
STATE_ALARM_ARMED_HOME,
STATE_ALARM_DISARMED,
)
from . import ATTRIBUTION, DOMAIN as ABODE_DOMAIN, AbodeDevice from . import ATTRIBUTION, DOMAIN as ABODE_DOMAIN, AbodeDevice
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ICON = 'mdi:security' ICON = "mdi:security"
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -72,7 +75,7 @@ class AbodeAlarm(AbodeDevice, alarm.AlarmControlPanel):
"""Return the state attributes.""" """Return the state attributes."""
return { return {
ATTR_ATTRIBUTION: ATTRIBUTION, ATTR_ATTRIBUTION: ATTRIBUTION,
'device_id': self._device.device_id, "device_id": self._device.device_id,
'battery_backup': self._device.battery, "battery_backup": self._device.battery,
'cellular_backup': self._device.is_cellular, "cellular_backup": self._device.is_cellular,
} }

View file

@ -15,9 +15,13 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
data = hass.data[ABODE_DOMAIN] data = hass.data[ABODE_DOMAIN]
device_types = [CONST.TYPE_CONNECTIVITY, CONST.TYPE_MOISTURE, device_types = [
CONST.TYPE_MOTION, CONST.TYPE_OCCUPANCY, CONST.TYPE_CONNECTIVITY,
CONST.TYPE_OPENING] CONST.TYPE_MOISTURE,
CONST.TYPE_MOTION,
CONST.TYPE_OCCUPANCY,
CONST.TYPE_OPENING,
]
devices = [] devices = []
for device in data.abode.get_devices(generic_type=device_types): for device in data.abode.get_devices(generic_type=device_types):
@ -26,13 +30,15 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
devices.append(AbodeBinarySensor(data, device)) devices.append(AbodeBinarySensor(data, device))
for automation in data.abode.get_automations( for automation in data.abode.get_automations(generic_type=CONST.TYPE_QUICK_ACTION):
generic_type=CONST.TYPE_QUICK_ACTION):
if data.is_automation_excluded(automation): if data.is_automation_excluded(automation):
continue continue
devices.append(AbodeQuickActionBinarySensor( devices.append(
data, automation, TIMELINE.AUTOMATION_EDIT_GROUP)) AbodeQuickActionBinarySensor(
data, automation, TIMELINE.AUTOMATION_EDIT_GROUP
)
)
data.devices.extend(devices) data.devices.extend(devices)

View file

@ -49,7 +49,8 @@ class AbodeCamera(AbodeDevice, Camera):
self.hass.async_add_job( self.hass.async_add_job(
self._data.abode.events.add_timeline_callback, self._data.abode.events.add_timeline_callback,
self._event, self._capture_callback self._event,
self._capture_callback,
) )
def capture(self): def capture(self):
@ -66,8 +67,7 @@ class AbodeCamera(AbodeDevice, Camera):
"""Attempt to download the most recent capture.""" """Attempt to download the most recent capture."""
if self._device.image_url: if self._device.image_url:
try: try:
self._response = requests.get( self._response = requests.get(self._device.image_url, stream=True)
self._device.image_url, stream=True)
self._response.raise_for_status() self._response.raise_for_status()
except requests.HTTPError as err: except requests.HTTPError as err:

View file

@ -3,10 +3,18 @@ import logging
from math import ceil from math import ceil
from homeassistant.components.light import ( from homeassistant.components.light import (
ATTR_BRIGHTNESS, ATTR_COLOR_TEMP, ATTR_HS_COLOR, SUPPORT_BRIGHTNESS, ATTR_BRIGHTNESS,
SUPPORT_COLOR, SUPPORT_COLOR_TEMP, Light) ATTR_COLOR_TEMP,
ATTR_HS_COLOR,
SUPPORT_BRIGHTNESS,
SUPPORT_COLOR,
SUPPORT_COLOR_TEMP,
Light,
)
from homeassistant.util.color import ( from homeassistant.util.color import (
color_temperature_kelvin_to_mired, color_temperature_mired_to_kelvin) color_temperature_kelvin_to_mired,
color_temperature_mired_to_kelvin,
)
from . import DOMAIN as ABODE_DOMAIN, AbodeDevice from . import DOMAIN as ABODE_DOMAIN, AbodeDevice
@ -42,8 +50,8 @@ class AbodeLight(AbodeDevice, Light):
"""Turn on the light.""" """Turn on the light."""
if ATTR_COLOR_TEMP in kwargs and self._device.is_color_capable: if ATTR_COLOR_TEMP in kwargs and self._device.is_color_capable:
self._device.set_color_temp( self._device.set_color_temp(
int(color_temperature_mired_to_kelvin( int(color_temperature_mired_to_kelvin(kwargs[ATTR_COLOR_TEMP]))
kwargs[ATTR_COLOR_TEMP]))) )
if ATTR_HS_COLOR in kwargs and self._device.is_color_capable: if ATTR_HS_COLOR in kwargs and self._device.is_color_capable:
self._device.set_color(kwargs[ATTR_HS_COLOR]) self._device.set_color(kwargs[ATTR_HS_COLOR])

View file

@ -2,7 +2,10 @@
import logging import logging
from homeassistant.const import ( from homeassistant.const import (
DEVICE_CLASS_HUMIDITY, DEVICE_CLASS_ILLUMINANCE, DEVICE_CLASS_TEMPERATURE) DEVICE_CLASS_HUMIDITY,
DEVICE_CLASS_ILLUMINANCE,
DEVICE_CLASS_TEMPERATURE,
)
from . import DOMAIN as ABODE_DOMAIN, AbodeDevice from . import DOMAIN as ABODE_DOMAIN, AbodeDevice
@ -10,9 +13,9 @@ _LOGGER = logging.getLogger(__name__)
# Sensor types: Name, icon # Sensor types: Name, icon
SENSOR_TYPES = { SENSOR_TYPES = {
'temp': ['Temperature', DEVICE_CLASS_TEMPERATURE], "temp": ["Temperature", DEVICE_CLASS_TEMPERATURE],
'humidity': ['Humidity', DEVICE_CLASS_HUMIDITY], "humidity": ["Humidity", DEVICE_CLASS_HUMIDITY],
'lux': ['Lux', DEVICE_CLASS_ILLUMINANCE], "lux": ["Lux", DEVICE_CLASS_ILLUMINANCE],
} }
@ -42,8 +45,9 @@ class AbodeSensor(AbodeDevice):
"""Initialize a sensor for an Abode device.""" """Initialize a sensor for an Abode device."""
super().__init__(data, device) super().__init__(data, device)
self._sensor_type = sensor_type self._sensor_type = sensor_type
self._name = '{0} {1}'.format( self._name = "{0} {1}".format(
self._device.name, SENSOR_TYPES[self._sensor_type][0]) self._device.name, SENSOR_TYPES[self._sensor_type][0]
)
self._device_class = SENSOR_TYPES[self._sensor_type][1] self._device_class = SENSOR_TYPES[self._sensor_type][1]
@property @property
@ -59,19 +63,19 @@ class AbodeSensor(AbodeDevice):
@property @property
def state(self): def state(self):
"""Return the state of the sensor.""" """Return the state of the sensor."""
if self._sensor_type == 'temp': if self._sensor_type == "temp":
return self._device.temp return self._device.temp
if self._sensor_type == 'humidity': if self._sensor_type == "humidity":
return self._device.humidity return self._device.humidity
if self._sensor_type == 'lux': if self._sensor_type == "lux":
return self._device.lux return self._device.lux
@property @property
def unit_of_measurement(self): def unit_of_measurement(self):
"""Return the units of measurement.""" """Return the units of measurement."""
if self._sensor_type == 'temp': if self._sensor_type == "temp":
return self._device.temp_unit return self._device.temp_unit
if self._sensor_type == 'humidity': if self._sensor_type == "humidity":
return self._device.humidity_unit return self._device.humidity_unit
if self._sensor_type == 'lux': if self._sensor_type == "lux":
return self._device.lux_unit return self._device.lux_unit

View file

@ -25,13 +25,13 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
devices.append(AbodeSwitch(data, device)) devices.append(AbodeSwitch(data, device))
# Get all Abode automations that can be enabled/disabled # Get all Abode automations that can be enabled/disabled
for automation in data.abode.get_automations( for automation in data.abode.get_automations(generic_type=CONST.TYPE_AUTOMATION):
generic_type=CONST.TYPE_AUTOMATION):
if data.is_automation_excluded(automation): if data.is_automation_excluded(automation):
continue continue
devices.append(AbodeAutomationSwitch( devices.append(
data, automation, TIMELINE.AUTOMATION_EDIT_GROUP)) AbodeAutomationSwitch(data, automation, TIMELINE.AUTOMATION_EDIT_GROUP)
)
data.devices.extend(devices) data.devices.extend(devices)

View file

@ -4,50 +4,58 @@ import re
import voluptuous as vol import voluptuous as vol
from homeassistant.components.switch import (SwitchDevice, PLATFORM_SCHEMA) from homeassistant.components.switch import SwitchDevice, PLATFORM_SCHEMA
from homeassistant.const import ( from homeassistant.const import (
STATE_ON, STATE_OFF, STATE_UNKNOWN, CONF_NAME, CONF_FILENAME) STATE_ON,
STATE_OFF,
STATE_UNKNOWN,
CONF_NAME,
CONF_FILENAME,
)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
CONF_TIMEOUT = 'timeout' CONF_TIMEOUT = "timeout"
CONF_WRITE_TIMEOUT = 'write_timeout' CONF_WRITE_TIMEOUT = "write_timeout"
DEFAULT_NAME = 'Acer Projector' DEFAULT_NAME = "Acer Projector"
DEFAULT_TIMEOUT = 1 DEFAULT_TIMEOUT = 1
DEFAULT_WRITE_TIMEOUT = 1 DEFAULT_WRITE_TIMEOUT = 1
ECO_MODE = 'ECO Mode' ECO_MODE = "ECO Mode"
ICON = 'mdi:projector' ICON = "mdi:projector"
INPUT_SOURCE = 'Input Source' INPUT_SOURCE = "Input Source"
LAMP = 'Lamp' LAMP = "Lamp"
LAMP_HOURS = 'Lamp Hours' LAMP_HOURS = "Lamp Hours"
MODEL = 'Model' MODEL = "Model"
# Commands known to the projector # Commands known to the projector
CMD_DICT = { CMD_DICT = {
LAMP: '* 0 Lamp ?\r', LAMP: "* 0 Lamp ?\r",
LAMP_HOURS: '* 0 Lamp\r', LAMP_HOURS: "* 0 Lamp\r",
INPUT_SOURCE: '* 0 Src ?\r', INPUT_SOURCE: "* 0 Src ?\r",
ECO_MODE: '* 0 IR 052\r', ECO_MODE: "* 0 IR 052\r",
MODEL: '* 0 IR 035\r', MODEL: "* 0 IR 035\r",
STATE_ON: '* 0 IR 001\r', STATE_ON: "* 0 IR 001\r",
STATE_OFF: '* 0 IR 002\r', STATE_OFF: "* 0 IR 002\r",
} }
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_FILENAME): cv.isdevice, {
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Required(CONF_FILENAME): cv.isdevice,
vol.Optional(CONF_TIMEOUT, default=DEFAULT_TIMEOUT): cv.positive_int, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_WRITE_TIMEOUT, default=DEFAULT_WRITE_TIMEOUT): vol.Optional(CONF_TIMEOUT, default=DEFAULT_TIMEOUT): cv.positive_int,
cv.positive_int, vol.Optional(
}) CONF_WRITE_TIMEOUT, default=DEFAULT_WRITE_TIMEOUT
): cv.positive_int,
}
)
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -66,9 +74,10 @@ class AcerSwitch(SwitchDevice):
def __init__(self, serial_port, name, timeout, write_timeout, **kwargs): def __init__(self, serial_port, name, timeout, write_timeout, **kwargs):
"""Init of the Acer projector.""" """Init of the Acer projector."""
import serial import serial
self.ser = serial.Serial( self.ser = serial.Serial(
port=serial_port, timeout=timeout, write_timeout=write_timeout, port=serial_port, timeout=timeout, write_timeout=write_timeout, **kwargs
**kwargs) )
self._serial_port = serial_port self._serial_port = serial_port
self._name = name self._name = name
self._state = False self._state = False
@ -82,6 +91,7 @@ class AcerSwitch(SwitchDevice):
def _write_read(self, msg): def _write_read(self, msg):
"""Write to the projector and read the return.""" """Write to the projector and read the return."""
import serial import serial
ret = "" ret = ""
# Sometimes the projector won't answer for no reason or the projector # Sometimes the projector won't answer for no reason or the projector
# was disconnected during runtime. # was disconnected during runtime.
@ -89,14 +99,14 @@ class AcerSwitch(SwitchDevice):
try: try:
if not self.ser.is_open: if not self.ser.is_open:
self.ser.open() self.ser.open()
msg = msg.encode('utf-8') msg = msg.encode("utf-8")
self.ser.write(msg) self.ser.write(msg)
# Size is an experience value there is no real limit. # Size is an experience value there is no real limit.
# AFAIK there is no limit and no end character so we will usually # AFAIK there is no limit and no end character so we will usually
# need to wait for timeout # need to wait for timeout
ret = self.ser.read_until(size=20).decode('utf-8') ret = self.ser.read_until(size=20).decode("utf-8")
except serial.SerialException: except serial.SerialException:
_LOGGER.error('Problem communicating with %s', self._serial_port) _LOGGER.error("Problem communicating with %s", self._serial_port)
self.ser.close() self.ser.close()
return ret return ret
@ -104,7 +114,7 @@ class AcerSwitch(SwitchDevice):
"""Write msg, obtain answer and format output.""" """Write msg, obtain answer and format output."""
# answers are formatted as ***\answer\r*** # answers are formatted as ***\answer\r***
awns = self._write_read(msg) awns = self._write_read(msg)
match = re.search(r'\r(.+)\r', awns) match = re.search(r"\r(.+)\r", awns)
if match: if match:
return match.group(1) return match.group(1)
return STATE_UNKNOWN return STATE_UNKNOWN
@ -133,10 +143,10 @@ class AcerSwitch(SwitchDevice):
"""Get the latest state from the projector.""" """Get the latest state from the projector."""
msg = CMD_DICT[LAMP] msg = CMD_DICT[LAMP]
awns = self._write_read_format(msg) awns = self._write_read_format(msg)
if awns == 'Lamp 1': if awns == "Lamp 1":
self._state = True self._state = True
self._available = True self._available = True
elif awns == 'Lamp 0': elif awns == "Lamp 0":
self._state = False self._state = False
self._available = True self._available = True
else: else:

View file

@ -8,22 +8,28 @@ import voluptuous as vol
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from homeassistant.components.device_tracker import ( from homeassistant.components.device_tracker import (
DOMAIN, PLATFORM_SCHEMA, DeviceScanner) DOMAIN,
PLATFORM_SCHEMA,
DeviceScanner,
)
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_LEASES_REGEX = re.compile( _LEASES_REGEX = re.compile(
r'(?P<ip>([0-9]{1,3}[\.]){3}[0-9]{1,3})' + r"(?P<ip>([0-9]{1,3}[\.]){3}[0-9]{1,3})"
r'\smac:\s(?P<mac>([0-9a-f]{2}[:-]){5}([0-9a-f]{2}))' + + r"\smac:\s(?P<mac>([0-9a-f]{2}[:-]){5}([0-9a-f]{2}))"
r'\svalid\sfor:\s(?P<timevalid>(-?\d+))' + + r"\svalid\sfor:\s(?P<timevalid>(-?\d+))"
r'\ssec') + r"\ssec"
)
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_HOST): cv.string, {
vol.Required(CONF_PASSWORD): cv.string, vol.Required(CONF_HOST): cv.string,
vol.Required(CONF_USERNAME): cv.string vol.Required(CONF_PASSWORD): cv.string,
}) vol.Required(CONF_USERNAME): cv.string,
}
)
def get_scanner(hass, config): def get_scanner(hass, config):
@ -32,7 +38,7 @@ def get_scanner(hass, config):
return scanner if scanner.success_init else None return scanner if scanner.success_init else None
Device = namedtuple('Device', ['mac', 'ip', 'last_update']) Device = namedtuple("Device", ["mac", "ip", "last_update"])
class ActiontecDeviceScanner(DeviceScanner): class ActiontecDeviceScanner(DeviceScanner):
@ -75,9 +81,11 @@ class ActiontecDeviceScanner(DeviceScanner):
actiontec_data = self.get_actiontec_data() actiontec_data = self.get_actiontec_data()
if not actiontec_data: if not actiontec_data:
return False return False
self.last_results = [Device(data['mac'], name, now) self.last_results = [
for name, data in actiontec_data.items() Device(data["mac"], name, now)
if data['timevalid'] > -60] for name, data in actiontec_data.items()
if data["timevalid"] > -60
]
_LOGGER.info("Scan successful") _LOGGER.info("Scan successful")
return True return True
@ -85,17 +93,16 @@ class ActiontecDeviceScanner(DeviceScanner):
"""Retrieve data from Actiontec MI424WR and return parsed result.""" """Retrieve data from Actiontec MI424WR and return parsed result."""
try: try:
telnet = telnetlib.Telnet(self.host) telnet = telnetlib.Telnet(self.host)
telnet.read_until(b'Username: ') telnet.read_until(b"Username: ")
telnet.write((self.username + '\n').encode('ascii')) telnet.write((self.username + "\n").encode("ascii"))
telnet.read_until(b'Password: ') telnet.read_until(b"Password: ")
telnet.write((self.password + '\n').encode('ascii')) telnet.write((self.password + "\n").encode("ascii"))
prompt = telnet.read_until( prompt = telnet.read_until(b"Wireless Broadband Router> ").split(b"\n")[-1]
b'Wireless Broadband Router> ').split(b'\n')[-1] telnet.write("firewall mac_cache_dump\n".encode("ascii"))
telnet.write('firewall mac_cache_dump\n'.encode('ascii')) telnet.write("\n".encode("ascii"))
telnet.write('\n'.encode('ascii'))
telnet.read_until(prompt) telnet.read_until(prompt)
leases_result = telnet.read_until(prompt).split(b'\n')[1:-1] leases_result = telnet.read_until(prompt).split(b"\n")[1:-1]
telnet.write('exit\n'.encode('ascii')) telnet.write("exit\n".encode("ascii"))
except EOFError: except EOFError:
_LOGGER.exception("Unexpected response from router") _LOGGER.exception("Unexpected response from router")
return return
@ -105,11 +112,11 @@ class ActiontecDeviceScanner(DeviceScanner):
devices = {} devices = {}
for lease in leases_result: for lease in leases_result:
match = _LEASES_REGEX.search(lease.decode('utf-8')) match = _LEASES_REGEX.search(lease.decode("utf-8"))
if match is not None: if match is not None:
devices[match.group('ip')] = { devices[match.group("ip")] = {
'ip': match.group('ip'), "ip": match.group("ip"),
'mac': match.group('mac').upper(), "mac": match.group("mac").upper(),
'timevalid': int(match.group('timevalid')) "timevalid": int(match.group("timevalid")),
} }
return devices return devices

View file

@ -6,13 +6,27 @@ from adguardhome import AdGuardHome, AdGuardHomeError
import voluptuous as vol import voluptuous as vol
from homeassistant.components.adguard.const import ( from homeassistant.components.adguard.const import (
CONF_FORCE, DATA_ADGUARD_CLIENT, DATA_ADGUARD_VERION, DOMAIN, CONF_FORCE,
SERVICE_ADD_URL, SERVICE_DISABLE_URL, SERVICE_ENABLE_URL, SERVICE_REFRESH, DATA_ADGUARD_CLIENT,
SERVICE_REMOVE_URL) DATA_ADGUARD_VERION,
DOMAIN,
SERVICE_ADD_URL,
SERVICE_DISABLE_URL,
SERVICE_ENABLE_URL,
SERVICE_REFRESH,
SERVICE_REMOVE_URL,
)
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
CONF_HOST, CONF_NAME, CONF_PASSWORD, CONF_PORT, CONF_SSL, CONF_URL, CONF_HOST,
CONF_USERNAME, CONF_VERIFY_SSL) CONF_NAME,
CONF_PASSWORD,
CONF_PORT,
CONF_SSL,
CONF_URL,
CONF_USERNAME,
CONF_VERIFY_SSL,
)
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
@ -34,9 +48,7 @@ async def async_setup(hass: HomeAssistantType, config: ConfigType) -> bool:
return True return True
async def async_setup_entry( async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry) -> bool:
hass: HomeAssistantType, entry: ConfigEntry
) -> bool:
"""Set up AdGuard Home from a config entry.""" """Set up AdGuard Home from a config entry."""
session = async_get_clientsession(hass, entry.data[CONF_VERIFY_SSL]) session = async_get_clientsession(hass, entry.data[CONF_VERIFY_SSL])
adguard = AdGuardHome( adguard = AdGuardHome(
@ -52,7 +64,7 @@ async def async_setup_entry(
hass.data.setdefault(DOMAIN, {})[DATA_ADGUARD_CLIENT] = adguard hass.data.setdefault(DOMAIN, {})[DATA_ADGUARD_CLIENT] = adguard
for component in 'sensor', 'switch': for component in "sensor", "switch":
hass.async_create_task( hass.async_create_task(
hass.config_entries.async_forward_entry_setup(entry, component) hass.config_entries.async_forward_entry_setup(entry, component)
) )
@ -98,9 +110,7 @@ async def async_setup_entry(
return True return True
async def async_unload_entry( async def async_unload_entry(hass: HomeAssistantType, entry: ConfigType) -> bool:
hass: HomeAssistantType, entry: ConfigType
) -> bool:
"""Unload AdGuard Home config entry.""" """Unload AdGuard Home config entry."""
hass.services.async_remove(DOMAIN, SERVICE_ADD_URL) hass.services.async_remove(DOMAIN, SERVICE_ADD_URL)
hass.services.async_remove(DOMAIN, SERVICE_REMOVE_URL) hass.services.async_remove(DOMAIN, SERVICE_REMOVE_URL)
@ -108,7 +118,7 @@ async def async_unload_entry(
hass.services.async_remove(DOMAIN, SERVICE_DISABLE_URL) hass.services.async_remove(DOMAIN, SERVICE_DISABLE_URL)
hass.services.async_remove(DOMAIN, SERVICE_REFRESH) hass.services.async_remove(DOMAIN, SERVICE_REFRESH)
for component in 'sensor', 'switch': for component in "sensor", "switch":
await hass.config_entries.async_forward_entry_unload(entry, component) await hass.config_entries.async_forward_entry_unload(entry, component)
del hass.data[DOMAIN] del hass.data[DOMAIN]
@ -166,15 +176,10 @@ class AdGuardHomeDeviceEntity(AdGuardHomeEntity):
def device_info(self) -> Dict[str, Any]: def device_info(self) -> Dict[str, Any]:
"""Return device information about this AdGuard Home instance.""" """Return device information about this AdGuard Home instance."""
return { return {
'identifiers': { "identifiers": {
( (DOMAIN, self.adguard.host, self.adguard.port, self.adguard.base_path)
DOMAIN,
self.adguard.host,
self.adguard.port,
self.adguard.base_path,
)
}, },
'name': 'AdGuard Home', "name": "AdGuard Home",
'manufacturer': 'AdGuard Team', "manufacturer": "AdGuard Team",
'sw_version': self.hass.data[DOMAIN].get(DATA_ADGUARD_VERION), "sw_version": self.hass.data[DOMAIN].get(DATA_ADGUARD_VERION),
} }

View file

@ -8,8 +8,13 @@ from homeassistant import config_entries
from homeassistant.components.adguard.const import DOMAIN from homeassistant.components.adguard.const import DOMAIN
from homeassistant.config_entries import ConfigFlow from homeassistant.config_entries import ConfigFlow
from homeassistant.const import ( from homeassistant.const import (
CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_SSL, CONF_USERNAME, CONF_HOST,
CONF_VERIFY_SSL) CONF_PASSWORD,
CONF_PORT,
CONF_SSL,
CONF_USERNAME,
CONF_VERIFY_SSL,
)
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -31,7 +36,7 @@ class AdGuardHomeFlowHandler(ConfigFlow):
async def _show_setup_form(self, errors=None): async def _show_setup_form(self, errors=None):
"""Show the setup form to the user.""" """Show the setup form to the user."""
return self.async_show_form( return self.async_show_form(
step_id='user', step_id="user",
data_schema=vol.Schema( data_schema=vol.Schema(
{ {
vol.Required(CONF_HOST): str, vol.Required(CONF_HOST): str,
@ -48,10 +53,8 @@ class AdGuardHomeFlowHandler(ConfigFlow):
async def _show_hassio_form(self, errors=None): async def _show_hassio_form(self, errors=None):
"""Show the Hass.io confirmation form to the user.""" """Show the Hass.io confirmation form to the user."""
return self.async_show_form( return self.async_show_form(
step_id='hassio_confirm', step_id="hassio_confirm",
description_placeholders={ description_placeholders={"addon": self._hassio_discovery["addon"]},
'addon': self._hassio_discovery['addon']
},
data_schema=vol.Schema({}), data_schema=vol.Schema({}),
errors=errors or {}, errors=errors or {},
) )
@ -59,16 +62,14 @@ class AdGuardHomeFlowHandler(ConfigFlow):
async def async_step_user(self, user_input=None): async def async_step_user(self, user_input=None):
"""Handle a flow initiated by the user.""" """Handle a flow initiated by the user."""
if self._async_current_entries(): if self._async_current_entries():
return self.async_abort(reason='single_instance_allowed') return self.async_abort(reason="single_instance_allowed")
if user_input is None: if user_input is None:
return await self._show_setup_form(user_input) return await self._show_setup_form(user_input)
errors = {} errors = {}
session = async_get_clientsession( session = async_get_clientsession(self.hass, user_input[CONF_VERIFY_SSL])
self.hass, user_input[CONF_VERIFY_SSL]
)
adguard = AdGuardHome( adguard = AdGuardHome(
user_input[CONF_HOST], user_input[CONF_HOST],
@ -84,7 +85,7 @@ class AdGuardHomeFlowHandler(ConfigFlow):
try: try:
await adguard.version() await adguard.version()
except AdGuardHomeConnectionError: except AdGuardHomeConnectionError:
errors['base'] = 'connection_error' errors["base"] = "connection_error"
return await self._show_setup_form(errors) return await self._show_setup_form(errors)
return self.async_create_entry( return self.async_create_entry(
@ -112,25 +113,30 @@ class AdGuardHomeFlowHandler(ConfigFlow):
cur_entry = entries[0] cur_entry = entries[0]
if (cur_entry.data[CONF_HOST] == user_input[CONF_HOST] and if (
cur_entry.data[CONF_PORT] == user_input[CONF_PORT]): cur_entry.data[CONF_HOST] == user_input[CONF_HOST]
return self.async_abort(reason='single_instance_allowed') and cur_entry.data[CONF_PORT] == user_input[CONF_PORT]
):
return self.async_abort(reason="single_instance_allowed")
is_loaded = cur_entry.state == config_entries.ENTRY_STATE_LOADED is_loaded = cur_entry.state == config_entries.ENTRY_STATE_LOADED
if is_loaded: if is_loaded:
await self.hass.config_entries.async_unload(cur_entry.entry_id) await self.hass.config_entries.async_unload(cur_entry.entry_id)
self.hass.config_entries.async_update_entry(cur_entry, data={ self.hass.config_entries.async_update_entry(
**cur_entry.data, cur_entry,
CONF_HOST: user_input[CONF_HOST], data={
CONF_PORT: user_input[CONF_PORT], **cur_entry.data,
}) CONF_HOST: user_input[CONF_HOST],
CONF_PORT: user_input[CONF_PORT],
},
)
if is_loaded: if is_loaded:
await self.hass.config_entries.async_setup(cur_entry.entry_id) await self.hass.config_entries.async_setup(cur_entry.entry_id)
return self.async_abort(reason='existing_instance_updated') return self.async_abort(reason="existing_instance_updated")
async def async_step_hassio_confirm(self, user_input=None): async def async_step_hassio_confirm(self, user_input=None):
"""Confirm Hass.io discovery.""" """Confirm Hass.io discovery."""
@ -152,11 +158,11 @@ class AdGuardHomeFlowHandler(ConfigFlow):
try: try:
await adguard.version() await adguard.version()
except AdGuardHomeConnectionError: except AdGuardHomeConnectionError:
errors['base'] = 'connection_error' errors["base"] = "connection_error"
return await self._show_hassio_form(errors) return await self._show_hassio_form(errors)
return self.async_create_entry( return self.async_create_entry(
title=self._hassio_discovery['addon'], title=self._hassio_discovery["addon"],
data={ data={
CONF_HOST: self._hassio_discovery[CONF_HOST], CONF_HOST: self._hassio_discovery[CONF_HOST],
CONF_PORT: self._hassio_discovery[CONF_PORT], CONF_PORT: self._hassio_discovery[CONF_PORT],

View file

@ -1,14 +1,14 @@
"""Constants for the AdGuard Home integration.""" """Constants for the AdGuard Home integration."""
DOMAIN = 'adguard' DOMAIN = "adguard"
DATA_ADGUARD_CLIENT = 'adguard_client' DATA_ADGUARD_CLIENT = "adguard_client"
DATA_ADGUARD_VERION = 'adguard_version' DATA_ADGUARD_VERION = "adguard_version"
CONF_FORCE = 'force' CONF_FORCE = "force"
SERVICE_ADD_URL = 'add_url' SERVICE_ADD_URL = "add_url"
SERVICE_DISABLE_URL = 'disable_url' SERVICE_DISABLE_URL = "disable_url"
SERVICE_ENABLE_URL = 'enable_url' SERVICE_ENABLE_URL = "enable_url"
SERVICE_REFRESH = 'refresh' SERVICE_REFRESH = "refresh"
SERVICE_REMOVE_URL = 'remove_url' SERVICE_REMOVE_URL = "remove_url"

View file

@ -6,7 +6,10 @@ from adguardhome import AdGuardHomeConnectionError
from homeassistant.components.adguard import AdGuardHomeDeviceEntity from homeassistant.components.adguard import AdGuardHomeDeviceEntity
from homeassistant.components.adguard.const import ( from homeassistant.components.adguard.const import (
DATA_ADGUARD_CLIENT, DATA_ADGUARD_VERION, DOMAIN) DATA_ADGUARD_CLIENT,
DATA_ADGUARD_VERION,
DOMAIN,
)
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.exceptions import PlatformNotReady from homeassistant.exceptions import PlatformNotReady
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.typing import HomeAssistantType
@ -18,7 +21,7 @@ PARALLEL_UPDATES = 4
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistantType, entry: ConfigEntry, async_add_entities hass: HomeAssistantType, entry: ConfigEntry, async_add_entities
) -> None: ) -> None:
"""Set up AdGuard Home sensor based on a config entry.""" """Set up AdGuard Home sensor based on a config entry."""
adguard = hass.data[DOMAIN][DATA_ADGUARD_CLIENT] adguard = hass.data[DOMAIN][DATA_ADGUARD_CLIENT]
@ -48,12 +51,7 @@ class AdGuardHomeSensor(AdGuardHomeDeviceEntity):
"""Defines a AdGuard Home sensor.""" """Defines a AdGuard Home sensor."""
def __init__( def __init__(
self, self, adguard, name: str, icon: str, measurement: str, unit_of_measurement: str
adguard,
name: str,
icon: str,
measurement: str,
unit_of_measurement: str,
) -> None: ) -> None:
"""Initialize AdGuard Home sensor.""" """Initialize AdGuard Home sensor."""
self._state = None self._state = None
@ -65,12 +63,12 @@ class AdGuardHomeSensor(AdGuardHomeDeviceEntity):
@property @property
def unique_id(self) -> str: def unique_id(self) -> str:
"""Return the unique ID for this sensor.""" """Return the unique ID for this sensor."""
return '_'.join( return "_".join(
[ [
DOMAIN, DOMAIN,
self.adguard.host, self.adguard.host,
str(self.adguard.port), str(self.adguard.port),
'sensor', "sensor",
self.measurement, self.measurement,
] ]
) )
@ -92,11 +90,7 @@ class AdGuardHomeDNSQueriesSensor(AdGuardHomeSensor):
def __init__(self, adguard): def __init__(self, adguard):
"""Initialize AdGuard Home sensor.""" """Initialize AdGuard Home sensor."""
super().__init__( super().__init__(
adguard, adguard, "AdGuard DNS Queries", "mdi:magnify", "dns_queries", "queries"
'AdGuard DNS Queries',
'mdi:magnify',
'dns_queries',
'queries',
) )
async def _adguard_update(self) -> None: async def _adguard_update(self) -> None:
@ -111,10 +105,10 @@ class AdGuardHomeBlockedFilteringSensor(AdGuardHomeSensor):
"""Initialize AdGuard Home sensor.""" """Initialize AdGuard Home sensor."""
super().__init__( super().__init__(
adguard, adguard,
'AdGuard DNS Queries Blocked', "AdGuard DNS Queries Blocked",
'mdi:magnify-close', "mdi:magnify-close",
'blocked_filtering', "blocked_filtering",
'queries', "queries",
) )
async def _adguard_update(self) -> None: async def _adguard_update(self) -> None:
@ -129,10 +123,10 @@ class AdGuardHomePercentageBlockedSensor(AdGuardHomeSensor):
"""Initialize AdGuard Home sensor.""" """Initialize AdGuard Home sensor."""
super().__init__( super().__init__(
adguard, adguard,
'AdGuard DNS Queries Blocked Ratio', "AdGuard DNS Queries Blocked Ratio",
'mdi:magnify-close', "mdi:magnify-close",
'blocked_percentage', "blocked_percentage",
'%', "%",
) )
async def _adguard_update(self) -> None: async def _adguard_update(self) -> None:
@ -148,10 +142,10 @@ class AdGuardHomeReplacedParentalSensor(AdGuardHomeSensor):
"""Initialize AdGuard Home sensor.""" """Initialize AdGuard Home sensor."""
super().__init__( super().__init__(
adguard, adguard,
'AdGuard Parental Control Blocked', "AdGuard Parental Control Blocked",
'mdi:human-male-girl', "mdi:human-male-girl",
'blocked_parental', "blocked_parental",
'requests', "requests",
) )
async def _adguard_update(self) -> None: async def _adguard_update(self) -> None:
@ -166,10 +160,10 @@ class AdGuardHomeReplacedSafeBrowsingSensor(AdGuardHomeSensor):
"""Initialize AdGuard Home sensor.""" """Initialize AdGuard Home sensor."""
super().__init__( super().__init__(
adguard, adguard,
'AdGuard Safe Browsing Blocked', "AdGuard Safe Browsing Blocked",
'mdi:shield-half-full', "mdi:shield-half-full",
'blocked_safebrowsing', "blocked_safebrowsing",
'requests', "requests",
) )
async def _adguard_update(self) -> None: async def _adguard_update(self) -> None:
@ -184,10 +178,10 @@ class AdGuardHomeReplacedSafeSearchSensor(AdGuardHomeSensor):
"""Initialize AdGuard Home sensor.""" """Initialize AdGuard Home sensor."""
super().__init__( super().__init__(
adguard, adguard,
'Searches Safe Search Enforced', "Searches Safe Search Enforced",
'mdi:shield-search', "mdi:shield-search",
'enforced_safesearch', "enforced_safesearch",
'requests', "requests",
) )
async def _adguard_update(self) -> None: async def _adguard_update(self) -> None:
@ -202,10 +196,10 @@ class AdGuardHomeAverageProcessingTimeSensor(AdGuardHomeSensor):
"""Initialize AdGuard Home sensor.""" """Initialize AdGuard Home sensor."""
super().__init__( super().__init__(
adguard, adguard,
'AdGuard Average Processing Speed', "AdGuard Average Processing Speed",
'mdi:speedometer', "mdi:speedometer",
'average_speed', "average_speed",
'ms', "ms",
) )
async def _adguard_update(self) -> None: async def _adguard_update(self) -> None:
@ -220,11 +214,7 @@ class AdGuardHomeRulesCountSensor(AdGuardHomeSensor):
def __init__(self, adguard): def __init__(self, adguard):
"""Initialize AdGuard Home sensor.""" """Initialize AdGuard Home sensor."""
super().__init__( super().__init__(
adguard, adguard, "AdGuard Rules Count", "mdi:counter", "rules_count", "rules"
'AdGuard Rules Count',
'mdi:counter',
'rules_count',
'rules',
) )
async def _adguard_update(self) -> None: async def _adguard_update(self) -> None:

View file

@ -6,7 +6,10 @@ from adguardhome import AdGuardHomeConnectionError, AdGuardHomeError
from homeassistant.components.adguard import AdGuardHomeDeviceEntity from homeassistant.components.adguard import AdGuardHomeDeviceEntity
from homeassistant.components.adguard.const import ( from homeassistant.components.adguard.const import (
DATA_ADGUARD_CLIENT, DATA_ADGUARD_VERION, DOMAIN) DATA_ADGUARD_CLIENT,
DATA_ADGUARD_VERION,
DOMAIN,
)
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.exceptions import PlatformNotReady from homeassistant.exceptions import PlatformNotReady
from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity import ToggleEntity
@ -19,7 +22,7 @@ PARALLEL_UPDATES = 1
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistantType, entry: ConfigEntry, async_add_entities hass: HomeAssistantType, entry: ConfigEntry, async_add_entities
) -> None: ) -> None:
"""Set up AdGuard Home switch based on a config entry.""" """Set up AdGuard Home switch based on a config entry."""
adguard = hass.data[DOMAIN][DATA_ADGUARD_CLIENT] adguard = hass.data[DOMAIN][DATA_ADGUARD_CLIENT]
@ -54,14 +57,8 @@ class AdGuardHomeSwitch(ToggleEntity, AdGuardHomeDeviceEntity):
@property @property
def unique_id(self) -> str: def unique_id(self) -> str:
"""Return the unique ID for this sensor.""" """Return the unique ID for this sensor."""
return '_'.join( return "_".join(
[ [DOMAIN, self.adguard.host, str(self.adguard.port), "switch", self._key]
DOMAIN,
self.adguard.host,
str(self.adguard.port),
'switch',
self._key,
]
) )
@property @property
@ -74,9 +71,7 @@ class AdGuardHomeSwitch(ToggleEntity, AdGuardHomeDeviceEntity):
try: try:
await self._adguard_turn_off() await self._adguard_turn_off()
except AdGuardHomeError: except AdGuardHomeError:
_LOGGER.error( _LOGGER.error("An error occurred while turning off AdGuard Home switch.")
"An error occurred while turning off AdGuard Home switch."
)
self._available = False self._available = False
async def _adguard_turn_off(self) -> None: async def _adguard_turn_off(self) -> None:
@ -88,9 +83,7 @@ class AdGuardHomeSwitch(ToggleEntity, AdGuardHomeDeviceEntity):
try: try:
await self._adguard_turn_on() await self._adguard_turn_on()
except AdGuardHomeError: except AdGuardHomeError:
_LOGGER.error( _LOGGER.error("An error occurred while turning on AdGuard Home switch.")
"An error occurred while turning on AdGuard Home switch."
)
self._available = False self._available = False
async def _adguard_turn_on(self) -> None: async def _adguard_turn_on(self) -> None:
@ -104,7 +97,7 @@ class AdGuardHomeProtectionSwitch(AdGuardHomeSwitch):
def __init__(self, adguard) -> None: def __init__(self, adguard) -> None:
"""Initialize AdGuard Home switch.""" """Initialize AdGuard Home switch."""
super().__init__( super().__init__(
adguard, "AdGuard Protection", 'mdi:shield-check', 'protection' adguard, "AdGuard Protection", "mdi:shield-check", "protection"
) )
async def _adguard_turn_off(self) -> None: async def _adguard_turn_off(self) -> None:
@ -126,7 +119,7 @@ class AdGuardHomeParentalSwitch(AdGuardHomeSwitch):
def __init__(self, adguard) -> None: def __init__(self, adguard) -> None:
"""Initialize AdGuard Home switch.""" """Initialize AdGuard Home switch."""
super().__init__( super().__init__(
adguard, "AdGuard Parental Control", 'mdi:shield-check', 'parental' adguard, "AdGuard Parental Control", "mdi:shield-check", "parental"
) )
async def _adguard_turn_off(self) -> None: async def _adguard_turn_off(self) -> None:
@ -148,7 +141,7 @@ class AdGuardHomeSafeSearchSwitch(AdGuardHomeSwitch):
def __init__(self, adguard) -> None: def __init__(self, adguard) -> None:
"""Initialize AdGuard Home switch.""" """Initialize AdGuard Home switch."""
super().__init__( super().__init__(
adguard, "AdGuard Safe Search", 'mdi:shield-check', 'safesearch' adguard, "AdGuard Safe Search", "mdi:shield-check", "safesearch"
) )
async def _adguard_turn_off(self) -> None: async def _adguard_turn_off(self) -> None:
@ -170,10 +163,7 @@ class AdGuardHomeSafeBrowsingSwitch(AdGuardHomeSwitch):
def __init__(self, adguard) -> None: def __init__(self, adguard) -> None:
"""Initialize AdGuard Home switch.""" """Initialize AdGuard Home switch."""
super().__init__( super().__init__(
adguard, adguard, "AdGuard Safe Browsing", "mdi:shield-check", "safebrowsing"
"AdGuard Safe Browsing",
'mdi:shield-check',
'safebrowsing',
) )
async def _adguard_turn_off(self) -> None: async def _adguard_turn_off(self) -> None:
@ -194,9 +184,7 @@ class AdGuardHomeFilteringSwitch(AdGuardHomeSwitch):
def __init__(self, adguard) -> None: def __init__(self, adguard) -> None:
"""Initialize AdGuard Home switch.""" """Initialize AdGuard Home switch."""
super().__init__( super().__init__(adguard, "AdGuard Filtering", "mdi:shield-check", "filtering")
adguard, "AdGuard Filtering", 'mdi:shield-check', 'filtering'
)
async def _adguard_turn_off(self) -> None: async def _adguard_turn_off(self) -> None:
"""Turn off the switch.""" """Turn off the switch."""
@ -216,9 +204,7 @@ class AdGuardHomeQueryLogSwitch(AdGuardHomeSwitch):
def __init__(self, adguard) -> None: def __init__(self, adguard) -> None:
"""Initialize AdGuard Home switch.""" """Initialize AdGuard Home switch."""
super().__init__( super().__init__(adguard, "AdGuard Query Log", "mdi:shield-check", "querylog")
adguard, "AdGuard Query Log", 'mdi:shield-check', 'querylog'
)
async def _adguard_turn_off(self) -> None: async def _adguard_turn_off(self) -> None:
"""Turn off the switch.""" """Turn off the switch."""

View file

@ -10,57 +10,76 @@ import async_timeout
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
CONF_DEVICE, CONF_IP_ADDRESS, CONF_PORT, EVENT_HOMEASSISTANT_STOP) CONF_DEVICE,
CONF_IP_ADDRESS,
CONF_PORT,
EVENT_HOMEASSISTANT_STOP,
)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DATA_ADS = 'data_ads' DATA_ADS = "data_ads"
# Supported Types # Supported Types
ADSTYPE_BOOL = 'bool' ADSTYPE_BOOL = "bool"
ADSTYPE_BYTE = 'byte' ADSTYPE_BYTE = "byte"
ADSTYPE_DINT = 'dint' ADSTYPE_DINT = "dint"
ADSTYPE_INT = 'int' ADSTYPE_INT = "int"
ADSTYPE_UDINT = 'udint' ADSTYPE_UDINT = "udint"
ADSTYPE_UINT = 'uint' ADSTYPE_UINT = "uint"
CONF_ADS_FACTOR = 'factor' CONF_ADS_FACTOR = "factor"
CONF_ADS_TYPE = 'adstype' CONF_ADS_TYPE = "adstype"
CONF_ADS_VALUE = 'value' CONF_ADS_VALUE = "value"
CONF_ADS_VAR = 'adsvar' CONF_ADS_VAR = "adsvar"
CONF_ADS_VAR_BRIGHTNESS = 'adsvar_brightness' CONF_ADS_VAR_BRIGHTNESS = "adsvar_brightness"
CONF_ADS_VAR_POSITION = 'adsvar_position' CONF_ADS_VAR_POSITION = "adsvar_position"
STATE_KEY_STATE = 'state' STATE_KEY_STATE = "state"
STATE_KEY_BRIGHTNESS = 'brightness' STATE_KEY_BRIGHTNESS = "brightness"
STATE_KEY_POSITION = 'position' STATE_KEY_POSITION = "position"
DOMAIN = 'ads' DOMAIN = "ads"
SERVICE_WRITE_DATA_BY_NAME = 'write_data_by_name' SERVICE_WRITE_DATA_BY_NAME = "write_data_by_name"
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema(
DOMAIN: vol.Schema({ {
vol.Required(CONF_DEVICE): cv.string, DOMAIN: vol.Schema(
vol.Required(CONF_PORT): cv.port, {
vol.Optional(CONF_IP_ADDRESS): cv.string, vol.Required(CONF_DEVICE): cv.string,
}) vol.Required(CONF_PORT): cv.port,
}, extra=vol.ALLOW_EXTRA) vol.Optional(CONF_IP_ADDRESS): cv.string,
}
)
},
extra=vol.ALLOW_EXTRA,
)
SCHEMA_SERVICE_WRITE_DATA_BY_NAME = vol.Schema({ SCHEMA_SERVICE_WRITE_DATA_BY_NAME = vol.Schema(
vol.Required(CONF_ADS_TYPE): {
vol.In([ADSTYPE_INT, ADSTYPE_UINT, ADSTYPE_BYTE, ADSTYPE_BOOL, vol.Required(CONF_ADS_TYPE): vol.In(
ADSTYPE_DINT, ADSTYPE_UDINT]), [
vol.Required(CONF_ADS_VALUE): vol.Coerce(int), ADSTYPE_INT,
vol.Required(CONF_ADS_VAR): cv.string, ADSTYPE_UINT,
}) ADSTYPE_BYTE,
ADSTYPE_BOOL,
ADSTYPE_DINT,
ADSTYPE_UDINT,
]
),
vol.Required(CONF_ADS_VALUE): vol.Coerce(int),
vol.Required(CONF_ADS_VAR): cv.string,
}
)
def setup(hass, config): def setup(hass, config):
"""Set up the ADS component.""" """Set up the ADS component."""
import pyads import pyads
conf = config[DOMAIN] conf = config[DOMAIN]
net_id = conf.get(CONF_DEVICE) net_id = conf.get(CONF_DEVICE)
@ -91,7 +110,10 @@ def setup(hass, config):
except pyads.ADSError: except pyads.ADSError:
_LOGGER.error( _LOGGER.error(
"Could not connect to ADS host (netid=%s, ip=%s, port=%s)", "Could not connect to ADS host (netid=%s, ip=%s, port=%s)",
net_id, ip_address, port) net_id,
ip_address,
port,
)
return False return False
hass.data[DATA_ADS] = ads hass.data[DATA_ADS] = ads
@ -109,15 +131,18 @@ def setup(hass, config):
_LOGGER.error(err) _LOGGER.error(err)
hass.services.register( hass.services.register(
DOMAIN, SERVICE_WRITE_DATA_BY_NAME, handle_write_data_by_name, DOMAIN,
schema=SCHEMA_SERVICE_WRITE_DATA_BY_NAME) SERVICE_WRITE_DATA_BY_NAME,
handle_write_data_by_name,
schema=SCHEMA_SERVICE_WRITE_DATA_BY_NAME,
)
return True return True
# Tuple to hold data needed for notification # Tuple to hold data needed for notification
NotificationItem = namedtuple( NotificationItem = namedtuple(
'NotificationItem', 'hnotify huser name plc_datatype callback' "NotificationItem", "hnotify huser name plc_datatype callback"
) )
@ -137,15 +162,17 @@ class AdsHub:
def shutdown(self, *args, **kwargs): def shutdown(self, *args, **kwargs):
"""Shutdown ADS connection.""" """Shutdown ADS connection."""
import pyads import pyads
_LOGGER.debug("Shutting down ADS") _LOGGER.debug("Shutting down ADS")
for notification_item in self._notification_items.values(): for notification_item in self._notification_items.values():
_LOGGER.debug( _LOGGER.debug(
"Deleting device notification %d, %d", "Deleting device notification %d, %d",
notification_item.hnotify, notification_item.huser) notification_item.hnotify,
notification_item.huser,
)
try: try:
self._client.del_device_notification( self._client.del_device_notification(
notification_item.hnotify, notification_item.hnotify, notification_item.huser
notification_item.huser
) )
except pyads.ADSError as err: except pyads.ADSError as err:
_LOGGER.error(err) _LOGGER.error(err)
@ -161,6 +188,7 @@ class AdsHub:
def write_by_name(self, name, value, plc_datatype): def write_by_name(self, name, value, plc_datatype):
"""Write a value to the device.""" """Write a value to the device."""
import pyads import pyads
with self._lock: with self._lock:
try: try:
return self._client.write_by_name(name, value, plc_datatype) return self._client.write_by_name(name, value, plc_datatype)
@ -170,6 +198,7 @@ class AdsHub:
def read_by_name(self, name, plc_datatype): def read_by_name(self, name, plc_datatype):
"""Read a value from the device.""" """Read a value from the device."""
import pyads import pyads
with self._lock: with self._lock:
try: try:
return self._client.read_by_name(name, plc_datatype) return self._client.read_by_name(name, plc_datatype)
@ -179,22 +208,25 @@ class AdsHub:
def add_device_notification(self, name, plc_datatype, callback): def add_device_notification(self, name, plc_datatype, callback):
"""Add a notification to the ADS devices.""" """Add a notification to the ADS devices."""
import pyads import pyads
attr = pyads.NotificationAttrib(ctypes.sizeof(plc_datatype)) attr = pyads.NotificationAttrib(ctypes.sizeof(plc_datatype))
with self._lock: with self._lock:
try: try:
hnotify, huser = self._client.add_device_notification( hnotify, huser = self._client.add_device_notification(
name, attr, self._device_notification_callback) name, attr, self._device_notification_callback
)
except pyads.ADSError as err: except pyads.ADSError as err:
_LOGGER.error("Error subscribing to %s: %s", name, err) _LOGGER.error("Error subscribing to %s: %s", name, err)
else: else:
hnotify = int(hnotify) hnotify = int(hnotify)
self._notification_items[hnotify] = NotificationItem( self._notification_items[hnotify] = NotificationItem(
hnotify, huser, name, plc_datatype, callback) hnotify, huser, name, plc_datatype, callback
)
_LOGGER.debug( _LOGGER.debug(
"Added device notification %d for variable %s", "Added device notification %d for variable %s", hnotify, name
hnotify, name) )
def _device_notification_callback(self, notification, name): def _device_notification_callback(self, notification, name):
"""Handle device notifications.""" """Handle device notifications."""
@ -213,17 +245,17 @@ class AdsHub:
# Parse data to desired datatype # Parse data to desired datatype
if notification_item.plc_datatype == self.PLCTYPE_BOOL: if notification_item.plc_datatype == self.PLCTYPE_BOOL:
value = bool(struct.unpack('<?', bytearray(data)[:1])[0]) value = bool(struct.unpack("<?", bytearray(data)[:1])[0])
elif notification_item.plc_datatype == self.PLCTYPE_INT: elif notification_item.plc_datatype == self.PLCTYPE_INT:
value = struct.unpack('<h', bytearray(data)[:2])[0] value = struct.unpack("<h", bytearray(data)[:2])[0]
elif notification_item.plc_datatype == self.PLCTYPE_BYTE: elif notification_item.plc_datatype == self.PLCTYPE_BYTE:
value = struct.unpack('<B', bytearray(data)[:1])[0] value = struct.unpack("<B", bytearray(data)[:1])[0]
elif notification_item.plc_datatype == self.PLCTYPE_UINT: elif notification_item.plc_datatype == self.PLCTYPE_UINT:
value = struct.unpack('<H', bytearray(data)[:2])[0] value = struct.unpack("<H", bytearray(data)[:2])[0]
elif notification_item.plc_datatype == self.PLCTYPE_DINT: elif notification_item.plc_datatype == self.PLCTYPE_DINT:
value = struct.unpack('<i', bytearray(data)[:4])[0] value = struct.unpack("<i", bytearray(data)[:4])[0]
elif notification_item.plc_datatype == self.PLCTYPE_UDINT: elif notification_item.plc_datatype == self.PLCTYPE_UDINT:
value = struct.unpack('<I', bytearray(data)[:4])[0] value = struct.unpack("<I", bytearray(data)[:4])[0]
else: else:
value = bytearray(data) value = bytearray(data)
_LOGGER.warning("No callback available for this datatype") _LOGGER.warning("No callback available for this datatype")
@ -245,11 +277,13 @@ class AdsEntity(Entity):
self._event = None self._event = None
async def async_initialize_device( async def async_initialize_device(
self, ads_var, plctype, state_key=STATE_KEY_STATE, factor=None): self, ads_var, plctype, state_key=STATE_KEY_STATE, factor=None
):
"""Register device notification.""" """Register device notification."""
def update(name, value): def update(name, value):
"""Handle device notifications.""" """Handle device notifications."""
_LOGGER.debug('Variable %s changed its value to %d', name, value) _LOGGER.debug("Variable %s changed its value to %d", name, value)
if factor is None: if factor is None:
self._state_dict[state_key] = value self._state_dict[state_key] = value
@ -266,14 +300,13 @@ class AdsEntity(Entity):
self._event = asyncio.Event() self._event = asyncio.Event()
await self.hass.async_add_executor_job( await self.hass.async_add_executor_job(
self._ads_hub.add_device_notification, self._ads_hub.add_device_notification, ads_var, plctype, update
ads_var, plctype, update) )
try: try:
with async_timeout.timeout(10): with async_timeout.timeout(10):
await self._event.wait() await self._event.wait()
except asyncio.TimeoutError: except asyncio.TimeoutError:
_LOGGER.debug('Variable %s: Timeout during first update', _LOGGER.debug("Variable %s: Timeout during first update", ads_var)
ads_var)
@property @property
def name(self): def name(self):

View file

@ -4,7 +4,10 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components.binary_sensor import ( from homeassistant.components.binary_sensor import (
DEVICE_CLASSES_SCHEMA, PLATFORM_SCHEMA, BinarySensorDevice) DEVICE_CLASSES_SCHEMA,
PLATFORM_SCHEMA,
BinarySensorDevice,
)
from homeassistant.const import CONF_DEVICE_CLASS, CONF_NAME from homeassistant.const import CONF_DEVICE_CLASS, CONF_NAME
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -12,12 +15,14 @@ from . import CONF_ADS_VAR, DATA_ADS, AdsEntity, STATE_KEY_STATE
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DEFAULT_NAME = 'ADS binary sensor' DEFAULT_NAME = "ADS binary sensor"
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_ADS_VAR): cv.string, {
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Required(CONF_ADS_VAR): cv.string,
vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
}) vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA,
}
)
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -38,12 +43,11 @@ class AdsBinarySensor(AdsEntity, BinarySensorDevice):
def __init__(self, ads_hub, name, ads_var, device_class): def __init__(self, ads_hub, name, ads_var, device_class):
"""Initialize ADS binary sensor.""" """Initialize ADS binary sensor."""
super().__init__(ads_hub, name, ads_var) super().__init__(ads_hub, name, ads_var)
self._device_class = device_class or 'moving' self._device_class = device_class or "moving"
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Register device notification.""" """Register device notification."""
await self.async_initialize_device(self._ads_var, await self.async_initialize_device(self._ads_var, self._ads_hub.PLCTYPE_BOOL)
self._ads_hub.PLCTYPE_BOOL)
@property @property
def is_on(self): def is_on(self):

View file

@ -4,35 +4,48 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components.cover import ( from homeassistant.components.cover import (
PLATFORM_SCHEMA, SUPPORT_OPEN, SUPPORT_CLOSE, SUPPORT_STOP, PLATFORM_SCHEMA,
SUPPORT_SET_POSITION, ATTR_POSITION, DEVICE_CLASSES_SCHEMA, SUPPORT_OPEN,
CoverDevice) SUPPORT_CLOSE,
from homeassistant.const import ( SUPPORT_STOP,
CONF_NAME, CONF_DEVICE_CLASS) SUPPORT_SET_POSITION,
ATTR_POSITION,
DEVICE_CLASSES_SCHEMA,
CoverDevice,
)
from homeassistant.const import CONF_NAME, CONF_DEVICE_CLASS
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from . import CONF_ADS_VAR, CONF_ADS_VAR_POSITION, DATA_ADS, \ from . import (
AdsEntity, STATE_KEY_STATE, STATE_KEY_POSITION CONF_ADS_VAR,
CONF_ADS_VAR_POSITION,
DATA_ADS,
AdsEntity,
STATE_KEY_STATE,
STATE_KEY_POSITION,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DEFAULT_NAME = 'ADS Cover' DEFAULT_NAME = "ADS Cover"
CONF_ADS_VAR_SET_POS = 'adsvar_set_position' CONF_ADS_VAR_SET_POS = "adsvar_set_position"
CONF_ADS_VAR_OPEN = 'adsvar_open' CONF_ADS_VAR_OPEN = "adsvar_open"
CONF_ADS_VAR_CLOSE = 'adsvar_close' CONF_ADS_VAR_CLOSE = "adsvar_close"
CONF_ADS_VAR_STOP = 'adsvar_stop' CONF_ADS_VAR_STOP = "adsvar_stop"
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Optional(CONF_ADS_VAR): cv.string, {
vol.Optional(CONF_ADS_VAR_POSITION): cv.string, vol.Optional(CONF_ADS_VAR): cv.string,
vol.Optional(CONF_ADS_VAR_SET_POS): cv.string, vol.Optional(CONF_ADS_VAR_POSITION): cv.string,
vol.Optional(CONF_ADS_VAR_CLOSE): cv.string, vol.Optional(CONF_ADS_VAR_SET_POS): cv.string,
vol.Optional(CONF_ADS_VAR_OPEN): cv.string, vol.Optional(CONF_ADS_VAR_CLOSE): cv.string,
vol.Optional(CONF_ADS_VAR_STOP): cv.string, vol.Optional(CONF_ADS_VAR_OPEN): cv.string,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_ADS_VAR_STOP): cv.string,
vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
}) vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA,
}
)
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -48,24 +61,38 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
name = config[CONF_NAME] name = config[CONF_NAME]
device_class = config.get(CONF_DEVICE_CLASS) device_class = config.get(CONF_DEVICE_CLASS)
add_entities([AdsCover(ads_hub, add_entities(
ads_var_is_closed, [
ads_var_position, AdsCover(
ads_var_pos_set, ads_hub,
ads_var_open, ads_var_is_closed,
ads_var_close, ads_var_position,
ads_var_stop, ads_var_pos_set,
name, ads_var_open,
device_class)]) ads_var_close,
ads_var_stop,
name,
device_class,
)
]
)
class AdsCover(AdsEntity, CoverDevice): class AdsCover(AdsEntity, CoverDevice):
"""Representation of ADS cover.""" """Representation of ADS cover."""
def __init__(self, ads_hub, def __init__(
ads_var_is_closed, ads_var_position, self,
ads_var_pos_set, ads_var_open, ads_hub,
ads_var_close, ads_var_stop, name, device_class): ads_var_is_closed,
ads_var_position,
ads_var_pos_set,
ads_var_open,
ads_var_close,
ads_var_stop,
name,
device_class,
):
"""Initialize AdsCover entity.""" """Initialize AdsCover entity."""
super().__init__(ads_hub, name, ads_var_is_closed) super().__init__(ads_hub, name, ads_var_is_closed)
if self._ads_var is None: if self._ads_var is None:
@ -87,13 +114,14 @@ class AdsCover(AdsEntity, CoverDevice):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Register device notification.""" """Register device notification."""
if self._ads_var is not None: if self._ads_var is not None:
await self.async_initialize_device(self._ads_var, await self.async_initialize_device(
self._ads_hub.PLCTYPE_BOOL) self._ads_var, self._ads_hub.PLCTYPE_BOOL
)
if self._ads_var_position is not None: if self._ads_var_position is not None:
await self.async_initialize_device(self._ads_var_position, await self.async_initialize_device(
self._ads_hub.PLCTYPE_BYTE, self._ads_var_position, self._ads_hub.PLCTYPE_BYTE, STATE_KEY_POSITION
STATE_KEY_POSITION) )
@property @property
def device_class(self): def device_class(self):
@ -130,29 +158,33 @@ class AdsCover(AdsEntity, CoverDevice):
def stop_cover(self, **kwargs): def stop_cover(self, **kwargs):
"""Fire the stop action.""" """Fire the stop action."""
if self._ads_var_stop: if self._ads_var_stop:
self._ads_hub.write_by_name(self._ads_var_stop, True, self._ads_hub.write_by_name(
self._ads_hub.PLCTYPE_BOOL) self._ads_var_stop, True, self._ads_hub.PLCTYPE_BOOL
)
def set_cover_position(self, **kwargs): def set_cover_position(self, **kwargs):
"""Set cover position.""" """Set cover position."""
position = kwargs[ATTR_POSITION] position = kwargs[ATTR_POSITION]
if self._ads_var_pos_set is not None: if self._ads_var_pos_set is not None:
self._ads_hub.write_by_name(self._ads_var_pos_set, position, self._ads_hub.write_by_name(
self._ads_hub.PLCTYPE_BYTE) self._ads_var_pos_set, position, self._ads_hub.PLCTYPE_BYTE
)
def open_cover(self, **kwargs): def open_cover(self, **kwargs):
"""Move the cover up.""" """Move the cover up."""
if self._ads_var_open is not None: if self._ads_var_open is not None:
self._ads_hub.write_by_name(self._ads_var_open, True, self._ads_hub.write_by_name(
self._ads_hub.PLCTYPE_BOOL) self._ads_var_open, True, self._ads_hub.PLCTYPE_BOOL
)
elif self._ads_var_pos_set is not None: elif self._ads_var_pos_set is not None:
self.set_cover_position(position=100) self.set_cover_position(position=100)
def close_cover(self, **kwargs): def close_cover(self, **kwargs):
"""Move the cover down.""" """Move the cover down."""
if self._ads_var_close is not None: if self._ads_var_close is not None:
self._ads_hub.write_by_name(self._ads_var_close, True, self._ads_hub.write_by_name(
self._ads_hub.PLCTYPE_BOOL) self._ads_var_close, True, self._ads_hub.PLCTYPE_BOOL
)
elif self._ads_var_pos_set is not None: elif self._ads_var_pos_set is not None:
self.set_cover_position(position=0) self.set_cover_position(position=0)
@ -160,6 +192,8 @@ class AdsCover(AdsEntity, CoverDevice):
def available(self): def available(self):
"""Return False if state has not been updated yet.""" """Return False if state has not been updated yet."""
if self._ads_var is not None or self._ads_var_position is not None: if self._ads_var is not None or self._ads_var_position is not None:
return self._state_dict[STATE_KEY_STATE] is not None or \ return (
self._state_dict[STATE_KEY_POSITION] is not None self._state_dict[STATE_KEY_STATE] is not None
or self._state_dict[STATE_KEY_POSITION] is not None
)
return True return True

View file

@ -4,20 +4,32 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components.light import ( from homeassistant.components.light import (
ATTR_BRIGHTNESS, PLATFORM_SCHEMA, SUPPORT_BRIGHTNESS, Light) ATTR_BRIGHTNESS,
PLATFORM_SCHEMA,
SUPPORT_BRIGHTNESS,
Light,
)
from homeassistant.const import CONF_NAME from homeassistant.const import CONF_NAME
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from . import CONF_ADS_VAR, CONF_ADS_VAR_BRIGHTNESS, DATA_ADS, \ from . import (
AdsEntity, STATE_KEY_BRIGHTNESS, STATE_KEY_STATE CONF_ADS_VAR,
CONF_ADS_VAR_BRIGHTNESS,
DATA_ADS,
AdsEntity,
STATE_KEY_BRIGHTNESS,
STATE_KEY_STATE,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DEFAULT_NAME = 'ADS Light' DEFAULT_NAME = "ADS Light"
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_ADS_VAR): cv.string, {
vol.Optional(CONF_ADS_VAR_BRIGHTNESS): cv.string, vol.Required(CONF_ADS_VAR): cv.string,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string vol.Optional(CONF_ADS_VAR_BRIGHTNESS): cv.string,
}) vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
}
)
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -28,8 +40,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
ads_var_brightness = config.get(CONF_ADS_VAR_BRIGHTNESS) ads_var_brightness = config.get(CONF_ADS_VAR_BRIGHTNESS)
name = config.get(CONF_NAME) name = config.get(CONF_NAME)
add_entities([AdsLight(ads_hub, ads_var_enable, ads_var_brightness, add_entities([AdsLight(ads_hub, ads_var_enable, ads_var_brightness, name)])
name)])
class AdsLight(AdsEntity, Light): class AdsLight(AdsEntity, Light):
@ -43,13 +54,14 @@ class AdsLight(AdsEntity, Light):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Register device notification.""" """Register device notification."""
await self.async_initialize_device(self._ads_var, await self.async_initialize_device(self._ads_var, self._ads_hub.PLCTYPE_BOOL)
self._ads_hub.PLCTYPE_BOOL)
if self._ads_var_brightness is not None: if self._ads_var_brightness is not None:
await self.async_initialize_device(self._ads_var_brightness, await self.async_initialize_device(
self._ads_hub.PLCTYPE_UINT, self._ads_var_brightness,
STATE_KEY_BRIGHTNESS) self._ads_hub.PLCTYPE_UINT,
STATE_KEY_BRIGHTNESS,
)
@property @property
def brightness(self): def brightness(self):
@ -72,14 +84,13 @@ class AdsLight(AdsEntity, Light):
def turn_on(self, **kwargs): def turn_on(self, **kwargs):
"""Turn the light on or set a specific dimmer value.""" """Turn the light on or set a specific dimmer value."""
brightness = kwargs.get(ATTR_BRIGHTNESS) brightness = kwargs.get(ATTR_BRIGHTNESS)
self._ads_hub.write_by_name(self._ads_var, True, self._ads_hub.write_by_name(self._ads_var, True, self._ads_hub.PLCTYPE_BOOL)
self._ads_hub.PLCTYPE_BOOL)
if self._ads_var_brightness is not None and brightness is not None: if self._ads_var_brightness is not None and brightness is not None:
self._ads_hub.write_by_name(self._ads_var_brightness, brightness, self._ads_hub.write_by_name(
self._ads_hub.PLCTYPE_UINT) self._ads_var_brightness, brightness, self._ads_hub.PLCTYPE_UINT
)
def turn_off(self, **kwargs): def turn_off(self, **kwargs):
"""Turn the light off.""" """Turn the light off."""
self._ads_hub.write_by_name(self._ads_var, False, self._ads_hub.write_by_name(self._ads_var, False, self._ads_hub.PLCTYPE_BOOL)
self._ads_hub.PLCTYPE_BOOL)

View file

@ -8,21 +8,28 @@ from homeassistant.components.sensor import PLATFORM_SCHEMA
from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT from homeassistant.const import CONF_NAME, CONF_UNIT_OF_MEASUREMENT
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from . import CONF_ADS_FACTOR, CONF_ADS_TYPE, CONF_ADS_VAR, \ from . import CONF_ADS_FACTOR, CONF_ADS_TYPE, CONF_ADS_VAR, AdsEntity, STATE_KEY_STATE
AdsEntity, STATE_KEY_STATE
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DEFAULT_NAME = "ADS sensor" DEFAULT_NAME = "ADS sensor"
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_ADS_VAR): cv.string, {
vol.Optional(CONF_ADS_FACTOR): cv.positive_int, vol.Required(CONF_ADS_VAR): cv.string,
vol.Optional(CONF_ADS_TYPE, default=ads.ADSTYPE_INT): vol.Optional(CONF_ADS_FACTOR): cv.positive_int,
vol.In([ads.ADSTYPE_INT, ads.ADSTYPE_UINT, ads.ADSTYPE_BYTE, vol.Optional(CONF_ADS_TYPE, default=ads.ADSTYPE_INT): vol.In(
ads.ADSTYPE_DINT, ads.ADSTYPE_UDINT]), [
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, ads.ADSTYPE_INT,
vol.Optional(CONF_UNIT_OF_MEASUREMENT, default=''): cv.string, ads.ADSTYPE_UINT,
}) ads.ADSTYPE_BYTE,
ads.ADSTYPE_DINT,
ads.ADSTYPE_UDINT,
]
),
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_UNIT_OF_MEASUREMENT, default=""): cv.string,
}
)
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -35,8 +42,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT) unit_of_measurement = config.get(CONF_UNIT_OF_MEASUREMENT)
factor = config.get(CONF_ADS_FACTOR) factor = config.get(CONF_ADS_FACTOR)
entity = AdsSensor( entity = AdsSensor(ads_hub, ads_var, ads_type, name, unit_of_measurement, factor)
ads_hub, ads_var, ads_type, name, unit_of_measurement, factor)
add_entities([entity]) add_entities([entity])
@ -44,8 +50,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
class AdsSensor(AdsEntity): class AdsSensor(AdsEntity):
"""Representation of an ADS sensor entity.""" """Representation of an ADS sensor entity."""
def __init__(self, ads_hub, ads_var, ads_type, name, unit_of_measurement, def __init__(self, ads_hub, ads_var, ads_type, name, unit_of_measurement, factor):
factor):
"""Initialize AdsSensor entity.""" """Initialize AdsSensor entity."""
super().__init__(ads_hub, name, ads_var) super().__init__(ads_hub, name, ads_var)
self._unit_of_measurement = unit_of_measurement self._unit_of_measurement = unit_of_measurement
@ -58,7 +63,8 @@ class AdsSensor(AdsEntity):
self._ads_var, self._ads_var,
self._ads_hub.ADS_TYPEMAP[self._ads_type], self._ads_hub.ADS_TYPEMAP[self._ads_type],
STATE_KEY_STATE, STATE_KEY_STATE,
self._factor) self._factor,
)
@property @property
def state(self): def state(self):

View file

@ -11,12 +11,11 @@ from . import CONF_ADS_VAR, DATA_ADS, AdsEntity, STATE_KEY_STATE
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DEFAULT_NAME = 'ADS Switch' DEFAULT_NAME = "ADS Switch"
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_ADS_VAR): cv.string, {vol.Required(CONF_ADS_VAR): cv.string, vol.Optional(CONF_NAME): cv.string}
vol.Optional(CONF_NAME): cv.string, )
})
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -34,8 +33,7 @@ class AdsSwitch(AdsEntity, SwitchDevice):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Register device notification.""" """Register device notification."""
await self.async_initialize_device(self._ads_var, await self.async_initialize_device(self._ads_var, self._ads_hub.PLCTYPE_BOOL)
self._ads_hub.PLCTYPE_BOOL)
@property @property
def is_on(self): def is_on(self):
@ -44,10 +42,8 @@ class AdsSwitch(AdsEntity, SwitchDevice):
def turn_on(self, **kwargs): def turn_on(self, **kwargs):
"""Turn the switch on.""" """Turn the switch on."""
self._ads_hub.write_by_name( self._ads_hub.write_by_name(self._ads_var, True, self._ads_hub.PLCTYPE_BOOL)
self._ads_var, True, self._ads_hub.PLCTYPE_BOOL)
def turn_off(self, **kwargs): def turn_off(self, **kwargs):
"""Turn the switch off.""" """Turn the switch off."""
self._ads_hub.write_by_name( self._ads_hub.write_by_name(self._ads_var, False, self._ads_hub.PLCTYPE_BOOL)
self._ads_var, False, self._ads_hub.PLCTYPE_BOOL)

View file

@ -1,2 +1,2 @@
"""Constants for the Aftership integration.""" """Constants for the Aftership integration."""
DOMAIN = 'aftership' DOMAIN = "aftership"

View file

@ -15,24 +15,24 @@ from .const import DOMAIN
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ATTRIBUTION = 'Information provided by AfterShip' ATTRIBUTION = "Information provided by AfterShip"
ATTR_TRACKINGS = 'trackings' ATTR_TRACKINGS = "trackings"
BASE = 'https://track.aftership.com/' BASE = "https://track.aftership.com/"
CONF_SLUG = 'slug' CONF_SLUG = "slug"
CONF_TITLE = 'title' CONF_TITLE = "title"
CONF_TRACKING_NUMBER = 'tracking_number' CONF_TRACKING_NUMBER = "tracking_number"
DEFAULT_NAME = 'aftership' DEFAULT_NAME = "aftership"
UPDATE_TOPIC = DOMAIN + '_update' UPDATE_TOPIC = DOMAIN + "_update"
ICON = 'mdi:package-variant-closed' ICON = "mdi:package-variant-closed"
MIN_TIME_BETWEEN_UPDATES = timedelta(minutes=5) MIN_TIME_BETWEEN_UPDATES = timedelta(minutes=5)
SERVICE_ADD_TRACKING = 'add_tracking' SERVICE_ADD_TRACKING = "add_tracking"
SERVICE_REMOVE_TRACKING = 'remove_tracking' SERVICE_REMOVE_TRACKING = "remove_tracking"
ADD_TRACKING_SERVICE_SCHEMA = vol.Schema( ADD_TRACKING_SERVICE_SCHEMA = vol.Schema(
{ {
@ -43,18 +43,18 @@ ADD_TRACKING_SERVICE_SCHEMA = vol.Schema(
) )
REMOVE_TRACKING_SERVICE_SCHEMA = vol.Schema( REMOVE_TRACKING_SERVICE_SCHEMA = vol.Schema(
{vol.Required(CONF_SLUG): cv.string, {vol.Required(CONF_SLUG): cv.string, vol.Required(CONF_TRACKING_NUMBER): cv.string}
vol.Required(CONF_TRACKING_NUMBER): cv.string}
) )
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_API_KEY): cv.string, {
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Required(CONF_API_KEY): cv.string,
}) vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
}
)
async def async_setup_platform( async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
hass, config, async_add_entities, discovery_info=None):
"""Set up the AfterShip sensor platform.""" """Set up the AfterShip sensor platform."""
from pyaftership.tracker import Tracking from pyaftership.tracker import Tracking
@ -66,9 +66,10 @@ async def async_setup_platform(
await aftership.get_trackings() await aftership.get_trackings()
if not aftership.meta or aftership.meta['code'] != 200: if not aftership.meta or aftership.meta["code"] != 200:
_LOGGER.error("No tracking data found. Check API key is correct: %s", _LOGGER.error(
aftership.meta) "No tracking data found. Check API key is correct: %s", aftership.meta
)
return return
instance = AfterShipSensor(aftership, name) instance = AfterShipSensor(aftership, name)
@ -130,7 +131,7 @@ class AfterShipSensor(Entity):
@property @property
def unit_of_measurement(self): def unit_of_measurement(self):
"""Return the unit of measurement of this entity, if any.""" """Return the unit of measurement of this entity, if any."""
return 'packages' return "packages"
@property @property
def device_state_attributes(self): def device_state_attributes(self):
@ -145,7 +146,8 @@ class AfterShipSensor(Entity):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Register callbacks.""" """Register callbacks."""
self.hass.helpers.dispatcher.async_dispatcher_connect( self.hass.helpers.dispatcher.async_dispatcher_connect(
UPDATE_TOPIC, self.force_update) UPDATE_TOPIC, self.force_update
)
async def force_update(self): async def force_update(self):
"""Force update of data.""" """Force update of data."""
@ -160,40 +162,40 @@ class AfterShipSensor(Entity):
if not self.aftership.meta: if not self.aftership.meta:
_LOGGER.error("Unknown errors when querying") _LOGGER.error("Unknown errors when querying")
return return
if self.aftership.meta['code'] != 200: if self.aftership.meta["code"] != 200:
_LOGGER.error( _LOGGER.error(
"Errors when querying AfterShip. %s", str(self.aftership.meta)) "Errors when querying AfterShip. %s", str(self.aftership.meta)
)
return return
status_to_ignore = {'delivered'} status_to_ignore = {"delivered"}
status_counts = {} status_counts = {}
trackings = [] trackings = []
not_delivered_count = 0 not_delivered_count = 0
for track in self.aftership.trackings['trackings']: for track in self.aftership.trackings["trackings"]:
status = track['tag'].lower() status = track["tag"].lower()
name = ( name = (
track['tracking_number'] track["tracking_number"] if track["title"] is None else track["title"]
if track['title'] is None
else track['title']
) )
last_checkpoint = ( last_checkpoint = (
"Shipment pending" "Shipment pending"
if track['tag'] == "Pending" if track["tag"] == "Pending"
else track['checkpoints'][-1] else track["checkpoints"][-1]
) )
status_counts[status] = status_counts.get(status, 0) + 1 status_counts[status] = status_counts.get(status, 0) + 1
trackings.append({ trackings.append(
'name': name, {
'tracking_number': track['tracking_number'], "name": name,
'slug': track['slug'], "tracking_number": track["tracking_number"],
'link': '%s%s/%s' % "slug": track["slug"],
(BASE, track['slug'], track['tracking_number']), "link": "%s%s/%s" % (BASE, track["slug"], track["tracking_number"]),
'last_update': track['updated_at'], "last_update": track["updated_at"],
'expected_delivery': track['expected_delivery'], "expected_delivery": track["expected_delivery"],
'status': track['tag'], "status": track["tag"],
'last_checkpoint': last_checkpoint "last_checkpoint": last_checkpoint,
}) }
)
if status not in status_to_ignore: if status not in status_to_ignore:
not_delivered_count += 1 not_delivered_count += 1

View file

@ -4,50 +4,53 @@ import logging
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.config_validation import ( # noqa from homeassistant.helpers.config_validation import ( # noqa
PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE) PLATFORM_SCHEMA,
PLATFORM_SCHEMA_BASE,
)
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ATTR_AQI = 'air_quality_index' ATTR_AQI = "air_quality_index"
ATTR_ATTRIBUTION = 'attribution' ATTR_ATTRIBUTION = "attribution"
ATTR_CO2 = 'carbon_dioxide' ATTR_CO2 = "carbon_dioxide"
ATTR_CO = 'carbon_monoxide' ATTR_CO = "carbon_monoxide"
ATTR_N2O = 'nitrogen_oxide' ATTR_N2O = "nitrogen_oxide"
ATTR_NO = 'nitrogen_monoxide' ATTR_NO = "nitrogen_monoxide"
ATTR_NO2 = 'nitrogen_dioxide' ATTR_NO2 = "nitrogen_dioxide"
ATTR_OZONE = 'ozone' ATTR_OZONE = "ozone"
ATTR_PM_0_1 = 'particulate_matter_0_1' ATTR_PM_0_1 = "particulate_matter_0_1"
ATTR_PM_10 = 'particulate_matter_10' ATTR_PM_10 = "particulate_matter_10"
ATTR_PM_2_5 = 'particulate_matter_2_5' ATTR_PM_2_5 = "particulate_matter_2_5"
ATTR_SO2 = 'sulphur_dioxide' ATTR_SO2 = "sulphur_dioxide"
DOMAIN = 'air_quality' DOMAIN = "air_quality"
ENTITY_ID_FORMAT = DOMAIN + '.{}' ENTITY_ID_FORMAT = DOMAIN + ".{}"
SCAN_INTERVAL = timedelta(seconds=30) SCAN_INTERVAL = timedelta(seconds=30)
PROP_TO_ATTR = { PROP_TO_ATTR = {
'air_quality_index': ATTR_AQI, "air_quality_index": ATTR_AQI,
'attribution': ATTR_ATTRIBUTION, "attribution": ATTR_ATTRIBUTION,
'carbon_dioxide': ATTR_CO2, "carbon_dioxide": ATTR_CO2,
'carbon_monoxide': ATTR_CO, "carbon_monoxide": ATTR_CO,
'nitrogen_oxide': ATTR_N2O, "nitrogen_oxide": ATTR_N2O,
'nitrogen_monoxide': ATTR_NO, "nitrogen_monoxide": ATTR_NO,
'nitrogen_dioxide': ATTR_NO2, "nitrogen_dioxide": ATTR_NO2,
'ozone': ATTR_OZONE, "ozone": ATTR_OZONE,
'particulate_matter_0_1': ATTR_PM_0_1, "particulate_matter_0_1": ATTR_PM_0_1,
'particulate_matter_10': ATTR_PM_10, "particulate_matter_10": ATTR_PM_10,
'particulate_matter_2_5': ATTR_PM_2_5, "particulate_matter_2_5": ATTR_PM_2_5,
'sulphur_dioxide': ATTR_SO2, "sulphur_dioxide": ATTR_SO2,
} }
async def async_setup(hass, config): async def async_setup(hass, config):
"""Set up the air quality component.""" """Set up the air quality component."""
component = hass.data[DOMAIN] = EntityComponent( component = hass.data[DOMAIN] = EntityComponent(
_LOGGER, DOMAIN, hass, SCAN_INTERVAL) _LOGGER, DOMAIN, hass, SCAN_INTERVAL
)
await component.async_setup(config) await component.async_setup(config)
return True return True

View file

@ -6,118 +6,96 @@ import voluptuous as vol
from homeassistant.components.sensor import PLATFORM_SCHEMA from homeassistant.components.sensor import PLATFORM_SCHEMA
from homeassistant.const import ( from homeassistant.const import (
ATTR_ATTRIBUTION, ATTR_LATITUDE, ATTR_LONGITUDE, CONF_API_KEY, ATTR_ATTRIBUTION,
CONF_LATITUDE, CONF_LONGITUDE, CONF_MONITORED_CONDITIONS, ATTR_LATITUDE,
CONF_SCAN_INTERVAL, CONF_STATE, CONF_SHOW_ON_MAP) ATTR_LONGITUDE,
CONF_API_KEY,
CONF_LATITUDE,
CONF_LONGITUDE,
CONF_MONITORED_CONDITIONS,
CONF_SCAN_INTERVAL,
CONF_STATE,
CONF_SHOW_ON_MAP,
)
from homeassistant.helpers import aiohttp_client, config_validation as cv from homeassistant.helpers import aiohttp_client, config_validation as cv
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.util import Throttle from homeassistant.util import Throttle
_LOGGER = getLogger(__name__) _LOGGER = getLogger(__name__)
ATTR_CITY = 'city' ATTR_CITY = "city"
ATTR_COUNTRY = 'country' ATTR_COUNTRY = "country"
ATTR_POLLUTANT_SYMBOL = 'pollutant_symbol' ATTR_POLLUTANT_SYMBOL = "pollutant_symbol"
ATTR_POLLUTANT_UNIT = 'pollutant_unit' ATTR_POLLUTANT_UNIT = "pollutant_unit"
ATTR_REGION = 'region' ATTR_REGION = "region"
CONF_CITY = 'city' CONF_CITY = "city"
CONF_COUNTRY = 'country' CONF_COUNTRY = "country"
DEFAULT_ATTRIBUTION = "Data provided by AirVisual" DEFAULT_ATTRIBUTION = "Data provided by AirVisual"
DEFAULT_SCAN_INTERVAL = timedelta(minutes=10) DEFAULT_SCAN_INTERVAL = timedelta(minutes=10)
MASS_PARTS_PER_MILLION = 'ppm' MASS_PARTS_PER_MILLION = "ppm"
MASS_PARTS_PER_BILLION = 'ppb' MASS_PARTS_PER_BILLION = "ppb"
VOLUME_MICROGRAMS_PER_CUBIC_METER = 'µg/m3' VOLUME_MICROGRAMS_PER_CUBIC_METER = "µg/m3"
SENSOR_TYPE_LEVEL = 'air_pollution_level' SENSOR_TYPE_LEVEL = "air_pollution_level"
SENSOR_TYPE_AQI = 'air_quality_index' SENSOR_TYPE_AQI = "air_quality_index"
SENSOR_TYPE_POLLUTANT = 'main_pollutant' SENSOR_TYPE_POLLUTANT = "main_pollutant"
SENSORS = [ SENSORS = [
(SENSOR_TYPE_LEVEL, 'Air Pollution Level', 'mdi:gauge', None), (SENSOR_TYPE_LEVEL, "Air Pollution Level", "mdi:gauge", None),
(SENSOR_TYPE_AQI, 'Air Quality Index', 'mdi:chart-line', 'AQI'), (SENSOR_TYPE_AQI, "Air Quality Index", "mdi:chart-line", "AQI"),
(SENSOR_TYPE_POLLUTANT, 'Main Pollutant', 'mdi:chemical-weapon', None), (SENSOR_TYPE_POLLUTANT, "Main Pollutant", "mdi:chemical-weapon", None),
] ]
POLLUTANT_LEVEL_MAPPING = [{ POLLUTANT_LEVEL_MAPPING = [
'label': 'Good', {"label": "Good", "icon": "mdi:emoticon-excited", "minimum": 0, "maximum": 50},
'icon': 'mdi:emoticon-excited', {"label": "Moderate", "icon": "mdi:emoticon-happy", "minimum": 51, "maximum": 100},
'minimum': 0, {
'maximum': 50 "label": "Unhealthy for sensitive groups",
}, { "icon": "mdi:emoticon-neutral",
'label': 'Moderate', "minimum": 101,
'icon': 'mdi:emoticon-happy', "maximum": 150,
'minimum': 51, },
'maximum': 100 {"label": "Unhealthy", "icon": "mdi:emoticon-sad", "minimum": 151, "maximum": 200},
}, { {
'label': 'Unhealthy for sensitive groups', "label": "Very Unhealthy",
'icon': 'mdi:emoticon-neutral', "icon": "mdi:emoticon-dead",
'minimum': 101, "minimum": 201,
'maximum': 150 "maximum": 300,
}, { },
'label': 'Unhealthy', {"label": "Hazardous", "icon": "mdi:biohazard", "minimum": 301, "maximum": 10000},
'icon': 'mdi:emoticon-sad', ]
'minimum': 151,
'maximum': 200
}, {
'label': 'Very Unhealthy',
'icon': 'mdi:emoticon-dead',
'minimum': 201,
'maximum': 300
}, {
'label': 'Hazardous',
'icon': 'mdi:biohazard',
'minimum': 301,
'maximum': 10000
}]
POLLUTANT_MAPPING = { POLLUTANT_MAPPING = {
'co': { "co": {"label": "Carbon Monoxide", "unit": MASS_PARTS_PER_MILLION},
'label': 'Carbon Monoxide', "n2": {"label": "Nitrogen Dioxide", "unit": MASS_PARTS_PER_BILLION},
'unit': MASS_PARTS_PER_MILLION "o3": {"label": "Ozone", "unit": MASS_PARTS_PER_BILLION},
}, "p1": {"label": "PM10", "unit": VOLUME_MICROGRAMS_PER_CUBIC_METER},
'n2': { "p2": {"label": "PM2.5", "unit": VOLUME_MICROGRAMS_PER_CUBIC_METER},
'label': 'Nitrogen Dioxide', "s2": {"label": "Sulfur Dioxide", "unit": MASS_PARTS_PER_BILLION},
'unit': MASS_PARTS_PER_BILLION
},
'o3': {
'label': 'Ozone',
'unit': MASS_PARTS_PER_BILLION
},
'p1': {
'label': 'PM10',
'unit': VOLUME_MICROGRAMS_PER_CUBIC_METER
},
'p2': {
'label': 'PM2.5',
'unit': VOLUME_MICROGRAMS_PER_CUBIC_METER
},
's2': {
'label': 'Sulfur Dioxide',
'unit': MASS_PARTS_PER_BILLION
},
} }
SENSOR_LOCALES = {'cn': 'Chinese', 'us': 'U.S.'} SENSOR_LOCALES = {"cn": "Chinese", "us": "U.S."}
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_API_KEY): cv.string, {
vol.Required(CONF_MONITORED_CONDITIONS, default=list(SENSOR_LOCALES)): vol.Required(CONF_API_KEY): cv.string,
vol.All(cv.ensure_list, [vol.In(SENSOR_LOCALES)]), vol.Required(CONF_MONITORED_CONDITIONS, default=list(SENSOR_LOCALES)): vol.All(
vol.Inclusive(CONF_CITY, 'city'): cv.string, cv.ensure_list, [vol.In(SENSOR_LOCALES)]
vol.Inclusive(CONF_COUNTRY, 'city'): cv.string, ),
vol.Inclusive(CONF_LATITUDE, 'coords'): cv.latitude, vol.Inclusive(CONF_CITY, "city"): cv.string,
vol.Inclusive(CONF_LONGITUDE, 'coords'): cv.longitude, vol.Inclusive(CONF_COUNTRY, "city"): cv.string,
vol.Optional(CONF_SHOW_ON_MAP, default=True): cv.boolean, vol.Inclusive(CONF_LATITUDE, "coords"): cv.latitude,
vol.Inclusive(CONF_STATE, 'city'): cv.string, vol.Inclusive(CONF_LONGITUDE, "coords"): cv.longitude,
vol.Optional(CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL): vol.Optional(CONF_SHOW_ON_MAP, default=True): cv.boolean,
cv.time_period vol.Inclusive(CONF_STATE, "city"): cv.string,
}) vol.Optional(CONF_SCAN_INTERVAL, default=DEFAULT_SCAN_INTERVAL): cv.time_period,
}
)
async def async_setup_platform( async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
hass, config, async_add_entities, discovery_info=None):
"""Configure the platform and add the sensors.""" """Configure the platform and add the sensors."""
from pyairvisual import Client from pyairvisual import Client
@ -132,25 +110,27 @@ async def async_setup_platform(
if city and state and country: if city and state and country:
_LOGGER.debug( _LOGGER.debug(
"Using city, state, and country: %s, %s, %s", city, state, country) "Using city, state, and country: %s, %s, %s", city, state, country
location_id = ','.join((city, state, country)) )
location_id = ",".join((city, state, country))
data = AirVisualData( data = AirVisualData(
Client(websession, api_key=config[CONF_API_KEY]), Client(websession, api_key=config[CONF_API_KEY]),
city=city, city=city,
state=state, state=state,
country=country, country=country,
show_on_map=config[CONF_SHOW_ON_MAP], show_on_map=config[CONF_SHOW_ON_MAP],
scan_interval=config[CONF_SCAN_INTERVAL]) scan_interval=config[CONF_SCAN_INTERVAL],
)
else: else:
_LOGGER.debug( _LOGGER.debug("Using latitude and longitude: %s, %s", latitude, longitude)
"Using latitude and longitude: %s, %s", latitude, longitude) location_id = ",".join((str(latitude), str(longitude)))
location_id = ','.join((str(latitude), str(longitude)))
data = AirVisualData( data = AirVisualData(
Client(websession, api_key=config[CONF_API_KEY]), Client(websession, api_key=config[CONF_API_KEY]),
latitude=latitude, latitude=latitude,
longitude=longitude, longitude=longitude,
show_on_map=config[CONF_SHOW_ON_MAP], show_on_map=config[CONF_SHOW_ON_MAP],
scan_interval=config[CONF_SCAN_INTERVAL]) scan_interval=config[CONF_SCAN_INTERVAL],
)
await data.async_update() await data.async_update()
@ -158,8 +138,8 @@ async def async_setup_platform(
for locale in config[CONF_MONITORED_CONDITIONS]: for locale in config[CONF_MONITORED_CONDITIONS]:
for kind, name, icon, unit in SENSORS: for kind, name, icon, unit in SENSORS:
sensors.append( sensors.append(
AirVisualSensor( AirVisualSensor(data, kind, name, icon, unit, locale, location_id)
data, kind, name, icon, unit, locale, location_id)) )
async_add_entities(sensors, True) async_add_entities(sensors, True)
@ -186,8 +166,8 @@ class AirVisualSensor(Entity):
self._attrs[ATTR_LATITUDE] = self.airvisual.latitude self._attrs[ATTR_LATITUDE] = self.airvisual.latitude
self._attrs[ATTR_LONGITUDE] = self.airvisual.longitude self._attrs[ATTR_LONGITUDE] = self.airvisual.longitude
else: else:
self._attrs['lati'] = self.airvisual.latitude self._attrs["lati"] = self.airvisual.latitude
self._attrs['long'] = self.airvisual.longitude self._attrs["long"] = self.airvisual.longitude
return self._attrs return self._attrs
@ -204,7 +184,7 @@ class AirVisualSensor(Entity):
@property @property
def name(self): def name(self):
"""Return the name.""" """Return the name."""
return '{0} {1}'.format(SENSOR_LOCALES[self._locale], self._name) return "{0} {1}".format(SENSOR_LOCALES[self._locale], self._name)
@property @property
def state(self): def state(self):
@ -214,8 +194,7 @@ class AirVisualSensor(Entity):
@property @property
def unique_id(self): def unique_id(self):
"""Return a unique, HASS-friendly identifier for this entity.""" """Return a unique, HASS-friendly identifier for this entity."""
return '{0}_{1}_{2}'.format( return "{0}_{1}_{2}".format(self._location_id, self._locale, self._type)
self._location_id, self._locale, self._type)
@property @property
def unit_of_measurement(self): def unit_of_measurement(self):
@ -231,22 +210,25 @@ class AirVisualSensor(Entity):
return return
if self._type == SENSOR_TYPE_LEVEL: if self._type == SENSOR_TYPE_LEVEL:
aqi = data['aqi{0}'.format(self._locale)] aqi = data["aqi{0}".format(self._locale)]
[level] = [ [level] = [
i for i in POLLUTANT_LEVEL_MAPPING i
if i['minimum'] <= aqi <= i['maximum'] for i in POLLUTANT_LEVEL_MAPPING
if i["minimum"] <= aqi <= i["maximum"]
] ]
self._state = level['label'] self._state = level["label"]
self._icon = level['icon'] self._icon = level["icon"]
elif self._type == SENSOR_TYPE_AQI: elif self._type == SENSOR_TYPE_AQI:
self._state = data['aqi{0}'.format(self._locale)] self._state = data["aqi{0}".format(self._locale)]
elif self._type == SENSOR_TYPE_POLLUTANT: elif self._type == SENSOR_TYPE_POLLUTANT:
symbol = data['main{0}'.format(self._locale)] symbol = data["main{0}".format(self._locale)]
self._state = POLLUTANT_MAPPING[symbol]['label'] self._state = POLLUTANT_MAPPING[symbol]["label"]
self._attrs.update({ self._attrs.update(
ATTR_POLLUTANT_SYMBOL: symbol, {
ATTR_POLLUTANT_UNIT: POLLUTANT_MAPPING[symbol]['unit'] ATTR_POLLUTANT_SYMBOL: symbol,
}) ATTR_POLLUTANT_UNIT: POLLUTANT_MAPPING[symbol]["unit"],
}
)
class AirVisualData: class AirVisualData:
@ -263,8 +245,7 @@ class AirVisualData:
self.show_on_map = kwargs.get(CONF_SHOW_ON_MAP) self.show_on_map = kwargs.get(CONF_SHOW_ON_MAP)
self.state = kwargs.get(CONF_STATE) self.state = kwargs.get(CONF_STATE)
self.async_update = Throttle( self.async_update = Throttle(kwargs[CONF_SCAN_INTERVAL])(self._async_update)
kwargs[CONF_SCAN_INTERVAL])(self._async_update)
async def _async_update(self): async def _async_update(self):
"""Update AirVisual data.""" """Update AirVisual data."""
@ -272,23 +253,21 @@ class AirVisualData:
try: try:
if self.city and self.state and self.country: if self.city and self.state and self.country:
resp = await self._client.api.city( resp = await self._client.api.city(self.city, self.state, self.country)
self.city, self.state, self.country) self.longitude, self.latitude = resp["location"]["coordinates"]
self.longitude, self.latitude = resp['location']['coordinates']
else: else:
resp = await self._client.api.nearest_city( resp = await self._client.api.nearest_city(
self.latitude, self.longitude) self.latitude, self.longitude
)
_LOGGER.debug("New data retrieved: %s", resp) _LOGGER.debug("New data retrieved: %s", resp)
self.pollution_info = resp['current']['pollution'] self.pollution_info = resp["current"]["pollution"]
except (KeyError, AirVisualError) as err: except (KeyError, AirVisualError) as err:
if self.city and self.state and self.country: if self.city and self.state and self.country:
location = (self.city, self.state, self.country) location = (self.city, self.state, self.country)
else: else:
location = (self.latitude, self.longitude) location = (self.latitude, self.longitude)
_LOGGER.error( _LOGGER.error("Can't retrieve data for location: %s (%s)", location, err)
"Can't retrieve data for location: %s (%s)", location,
err)
self.pollution_info = {} self.pollution_info = {}

View file

@ -3,30 +3,39 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components.cover import (CoverDevice, PLATFORM_SCHEMA, from homeassistant.components.cover import (
SUPPORT_OPEN, SUPPORT_CLOSE) CoverDevice,
from homeassistant.const import (CONF_USERNAME, CONF_PASSWORD, STATE_CLOSED, PLATFORM_SCHEMA,
STATE_OPENING, STATE_CLOSING, STATE_OPEN) SUPPORT_OPEN,
SUPPORT_CLOSE,
)
from homeassistant.const import (
CONF_USERNAME,
CONF_PASSWORD,
STATE_CLOSED,
STATE_OPENING,
STATE_CLOSING,
STATE_OPEN,
)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
NOTIFICATION_ID = 'aladdin_notification' NOTIFICATION_ID = "aladdin_notification"
NOTIFICATION_TITLE = 'Aladdin Connect Cover Setup' NOTIFICATION_TITLE = "Aladdin Connect Cover Setup"
STATES_MAP = { STATES_MAP = {
'open': STATE_OPEN, "open": STATE_OPEN,
'opening': STATE_OPENING, "opening": STATE_OPENING,
'closed': STATE_CLOSED, "closed": STATE_CLOSED,
'closing': STATE_CLOSING "closing": STATE_CLOSING,
} }
SUPPORTED_FEATURES = SUPPORT_OPEN | SUPPORT_CLOSE SUPPORTED_FEATURES = SUPPORT_OPEN | SUPPORT_CLOSE
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_USERNAME): cv.string, {vol.Required(CONF_USERNAME): cv.string, vol.Required(CONF_PASSWORD): cv.string}
vol.Required(CONF_PASSWORD): cv.string )
})
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -44,11 +53,12 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
except (TypeError, KeyError, NameError, ValueError) as ex: except (TypeError, KeyError, NameError, ValueError) as ex:
_LOGGER.error("%s", ex) _LOGGER.error("%s", ex)
hass.components.persistent_notification.create( hass.components.persistent_notification.create(
'Error: {}<br />' "Error: {}<br />"
'You will need to restart hass after fixing.' "You will need to restart hass after fixing."
''.format(ex), "".format(ex),
title=NOTIFICATION_TITLE, title=NOTIFICATION_TITLE,
notification_id=NOTIFICATION_ID) notification_id=NOTIFICATION_ID,
)
class AladdinDevice(CoverDevice): class AladdinDevice(CoverDevice):
@ -57,15 +67,15 @@ class AladdinDevice(CoverDevice):
def __init__(self, acc, device): def __init__(self, acc, device):
"""Initialize the cover.""" """Initialize the cover."""
self._acc = acc self._acc = acc
self._device_id = device['device_id'] self._device_id = device["device_id"]
self._number = device['door_number'] self._number = device["door_number"]
self._name = device['name'] self._name = device["name"]
self._status = STATES_MAP.get(device['status']) self._status = STATES_MAP.get(device["status"])
@property @property
def device_class(self): def device_class(self):
"""Define this cover as a garage door.""" """Define this cover as a garage door."""
return 'garage' return "garage"
@property @property
def supported_features(self): def supported_features(self):
@ -75,7 +85,7 @@ class AladdinDevice(CoverDevice):
@property @property
def unique_id(self): def unique_id(self):
"""Return a unique ID.""" """Return a unique ID."""
return '{}-{}'.format(self._device_id, self._number) return "{}-{}".format(self._device_id, self._number)
@property @property
def name(self): def name(self):

View file

@ -5,59 +5,65 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
ATTR_CODE, ATTR_CODE_FORMAT, SERVICE_ALARM_TRIGGER, SERVICE_ALARM_DISARM, ATTR_CODE,
SERVICE_ALARM_ARM_HOME, SERVICE_ALARM_ARM_AWAY, SERVICE_ALARM_ARM_NIGHT, ATTR_CODE_FORMAT,
SERVICE_ALARM_ARM_CUSTOM_BYPASS) SERVICE_ALARM_TRIGGER,
SERVICE_ALARM_DISARM,
SERVICE_ALARM_ARM_HOME,
SERVICE_ALARM_ARM_AWAY,
SERVICE_ALARM_ARM_NIGHT,
SERVICE_ALARM_ARM_CUSTOM_BYPASS,
)
from homeassistant.helpers.config_validation import ( # noqa from homeassistant.helpers.config_validation import ( # noqa
ENTITY_SERVICE_SCHEMA, PLATFORM_SCHEMA, PLATFORM_SCHEMA_BASE) ENTITY_SERVICE_SCHEMA,
PLATFORM_SCHEMA,
PLATFORM_SCHEMA_BASE,
)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.entity_component import EntityComponent
DOMAIN = 'alarm_control_panel' DOMAIN = "alarm_control_panel"
SCAN_INTERVAL = timedelta(seconds=30) SCAN_INTERVAL = timedelta(seconds=30)
ATTR_CHANGED_BY = 'changed_by' ATTR_CHANGED_BY = "changed_by"
FORMAT_TEXT = 'text' FORMAT_TEXT = "text"
FORMAT_NUMBER = 'number' FORMAT_NUMBER = "number"
ATTR_CODE_ARM_REQUIRED = 'code_arm_required' ATTR_CODE_ARM_REQUIRED = "code_arm_required"
ENTITY_ID_FORMAT = DOMAIN + '.{}' ENTITY_ID_FORMAT = DOMAIN + ".{}"
ALARM_SERVICE_SCHEMA = ENTITY_SERVICE_SCHEMA.extend({ ALARM_SERVICE_SCHEMA = ENTITY_SERVICE_SCHEMA.extend(
vol.Optional(ATTR_CODE): cv.string, {vol.Optional(ATTR_CODE): cv.string}
}) )
async def async_setup(hass, config): async def async_setup(hass, config):
"""Track states and offer events for sensors.""" """Track states and offer events for sensors."""
component = hass.data[DOMAIN] = EntityComponent( component = hass.data[DOMAIN] = EntityComponent(
logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL) logging.getLogger(__name__), DOMAIN, hass, SCAN_INTERVAL
)
await component.async_setup(config) await component.async_setup(config)
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_ALARM_DISARM, ALARM_SERVICE_SCHEMA, SERVICE_ALARM_DISARM, ALARM_SERVICE_SCHEMA, "async_alarm_disarm"
'async_alarm_disarm'
) )
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_ALARM_ARM_HOME, ALARM_SERVICE_SCHEMA, SERVICE_ALARM_ARM_HOME, ALARM_SERVICE_SCHEMA, "async_alarm_arm_home"
'async_alarm_arm_home'
) )
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_ALARM_ARM_AWAY, ALARM_SERVICE_SCHEMA, SERVICE_ALARM_ARM_AWAY, ALARM_SERVICE_SCHEMA, "async_alarm_arm_away"
'async_alarm_arm_away'
) )
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_ALARM_ARM_NIGHT, ALARM_SERVICE_SCHEMA, SERVICE_ALARM_ARM_NIGHT, ALARM_SERVICE_SCHEMA, "async_alarm_arm_night"
'async_alarm_arm_night'
) )
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_ALARM_ARM_CUSTOM_BYPASS, ALARM_SERVICE_SCHEMA, SERVICE_ALARM_ARM_CUSTOM_BYPASS,
'async_alarm_arm_custom_bypass' ALARM_SERVICE_SCHEMA,
"async_alarm_arm_custom_bypass",
) )
component.async_register_entity_service( component.async_register_entity_service(
SERVICE_ALARM_TRIGGER, ALARM_SERVICE_SCHEMA, SERVICE_ALARM_TRIGGER, ALARM_SERVICE_SCHEMA, "async_alarm_trigger"
'async_alarm_trigger'
) )
return True return True
@ -156,8 +162,7 @@ class AlarmControlPanel(Entity):
This method must be run in the event loop and returns a coroutine. This method must be run in the event loop and returns a coroutine.
""" """
return self.hass.async_add_executor_job( return self.hass.async_add_executor_job(self.alarm_arm_custom_bypass, code)
self.alarm_arm_custom_bypass, code)
@property @property
def state_attributes(self): def state_attributes(self):
@ -165,6 +170,6 @@ class AlarmControlPanel(Entity):
state_attr = { state_attr = {
ATTR_CODE_FORMAT: self.code_format, ATTR_CODE_FORMAT: self.code_format,
ATTR_CHANGED_BY: self.changed_by, ATTR_CHANGED_BY: self.changed_by,
ATTR_CODE_ARM_REQUIRED: self.code_arm_required ATTR_CODE_ARM_REQUIRED: self.code_arm_required,
} }
return state_attr return state_attr

View file

@ -12,85 +12,105 @@ from homeassistant.components.binary_sensor import DEVICE_CLASSES_SCHEMA
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DOMAIN = 'alarmdecoder' DOMAIN = "alarmdecoder"
DATA_AD = 'alarmdecoder' DATA_AD = "alarmdecoder"
CONF_DEVICE = 'device' CONF_DEVICE = "device"
CONF_DEVICE_BAUD = 'baudrate' CONF_DEVICE_BAUD = "baudrate"
CONF_DEVICE_PATH = 'path' CONF_DEVICE_PATH = "path"
CONF_DEVICE_PORT = 'port' CONF_DEVICE_PORT = "port"
CONF_DEVICE_TYPE = 'type' CONF_DEVICE_TYPE = "type"
CONF_PANEL_DISPLAY = 'panel_display' CONF_PANEL_DISPLAY = "panel_display"
CONF_ZONE_NAME = 'name' CONF_ZONE_NAME = "name"
CONF_ZONE_TYPE = 'type' CONF_ZONE_TYPE = "type"
CONF_ZONE_LOOP = 'loop' CONF_ZONE_LOOP = "loop"
CONF_ZONE_RFID = 'rfid' CONF_ZONE_RFID = "rfid"
CONF_ZONES = 'zones' CONF_ZONES = "zones"
CONF_RELAY_ADDR = 'relayaddr' CONF_RELAY_ADDR = "relayaddr"
CONF_RELAY_CHAN = 'relaychan' CONF_RELAY_CHAN = "relaychan"
DEFAULT_DEVICE_TYPE = 'socket' DEFAULT_DEVICE_TYPE = "socket"
DEFAULT_DEVICE_HOST = 'localhost' DEFAULT_DEVICE_HOST = "localhost"
DEFAULT_DEVICE_PORT = 10000 DEFAULT_DEVICE_PORT = 10000
DEFAULT_DEVICE_PATH = '/dev/ttyUSB0' DEFAULT_DEVICE_PATH = "/dev/ttyUSB0"
DEFAULT_DEVICE_BAUD = 115200 DEFAULT_DEVICE_BAUD = 115200
DEFAULT_PANEL_DISPLAY = False DEFAULT_PANEL_DISPLAY = False
DEFAULT_ZONE_TYPE = 'opening' DEFAULT_ZONE_TYPE = "opening"
SIGNAL_PANEL_MESSAGE = 'alarmdecoder.panel_message' SIGNAL_PANEL_MESSAGE = "alarmdecoder.panel_message"
SIGNAL_PANEL_ARM_AWAY = 'alarmdecoder.panel_arm_away' SIGNAL_PANEL_ARM_AWAY = "alarmdecoder.panel_arm_away"
SIGNAL_PANEL_ARM_HOME = 'alarmdecoder.panel_arm_home' SIGNAL_PANEL_ARM_HOME = "alarmdecoder.panel_arm_home"
SIGNAL_PANEL_DISARM = 'alarmdecoder.panel_disarm' SIGNAL_PANEL_DISARM = "alarmdecoder.panel_disarm"
SIGNAL_ZONE_FAULT = 'alarmdecoder.zone_fault' SIGNAL_ZONE_FAULT = "alarmdecoder.zone_fault"
SIGNAL_ZONE_RESTORE = 'alarmdecoder.zone_restore' SIGNAL_ZONE_RESTORE = "alarmdecoder.zone_restore"
SIGNAL_RFX_MESSAGE = 'alarmdecoder.rfx_message' SIGNAL_RFX_MESSAGE = "alarmdecoder.rfx_message"
SIGNAL_REL_MESSAGE = 'alarmdecoder.rel_message' SIGNAL_REL_MESSAGE = "alarmdecoder.rel_message"
DEVICE_SOCKET_SCHEMA = vol.Schema({ DEVICE_SOCKET_SCHEMA = vol.Schema(
vol.Required(CONF_DEVICE_TYPE): 'socket', {
vol.Optional(CONF_HOST, default=DEFAULT_DEVICE_HOST): cv.string, vol.Required(CONF_DEVICE_TYPE): "socket",
vol.Optional(CONF_DEVICE_PORT, default=DEFAULT_DEVICE_PORT): cv.port}) vol.Optional(CONF_HOST, default=DEFAULT_DEVICE_HOST): cv.string,
vol.Optional(CONF_DEVICE_PORT, default=DEFAULT_DEVICE_PORT): cv.port,
}
)
DEVICE_SERIAL_SCHEMA = vol.Schema({ DEVICE_SERIAL_SCHEMA = vol.Schema(
vol.Required(CONF_DEVICE_TYPE): 'serial', {
vol.Optional(CONF_DEVICE_PATH, default=DEFAULT_DEVICE_PATH): cv.string, vol.Required(CONF_DEVICE_TYPE): "serial",
vol.Optional(CONF_DEVICE_BAUD, default=DEFAULT_DEVICE_BAUD): cv.string}) vol.Optional(CONF_DEVICE_PATH, default=DEFAULT_DEVICE_PATH): cv.string,
vol.Optional(CONF_DEVICE_BAUD, default=DEFAULT_DEVICE_BAUD): cv.string,
}
)
DEVICE_USB_SCHEMA = vol.Schema({ DEVICE_USB_SCHEMA = vol.Schema({vol.Required(CONF_DEVICE_TYPE): "usb"})
vol.Required(CONF_DEVICE_TYPE): 'usb'})
ZONE_SCHEMA = vol.Schema({ ZONE_SCHEMA = vol.Schema(
vol.Required(CONF_ZONE_NAME): cv.string, {
vol.Optional(CONF_ZONE_TYPE, vol.Required(CONF_ZONE_NAME): cv.string,
default=DEFAULT_ZONE_TYPE): vol.Any(DEVICE_CLASSES_SCHEMA), vol.Optional(CONF_ZONE_TYPE, default=DEFAULT_ZONE_TYPE): vol.Any(
vol.Optional(CONF_ZONE_RFID): cv.string, DEVICE_CLASSES_SCHEMA
vol.Optional(CONF_ZONE_LOOP): ),
vol.All(vol.Coerce(int), vol.Range(min=1, max=4)), vol.Optional(CONF_ZONE_RFID): cv.string,
vol.Inclusive(CONF_RELAY_ADDR, 'relaylocation', vol.Optional(CONF_ZONE_LOOP): vol.All(vol.Coerce(int), vol.Range(min=1, max=4)),
'Relay address and channel must exist together'): cv.byte, vol.Inclusive(
vol.Inclusive(CONF_RELAY_CHAN, 'relaylocation', CONF_RELAY_ADDR,
'Relay address and channel must exist together'): cv.byte}) "relaylocation",
"Relay address and channel must exist together",
): cv.byte,
vol.Inclusive(
CONF_RELAY_CHAN,
"relaylocation",
"Relay address and channel must exist together",
): cv.byte,
}
)
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema(
DOMAIN: vol.Schema({ {
vol.Required(CONF_DEVICE): vol.Any( DOMAIN: vol.Schema(
DEVICE_SOCKET_SCHEMA, DEVICE_SERIAL_SCHEMA, {
DEVICE_USB_SCHEMA), vol.Required(CONF_DEVICE): vol.Any(
vol.Optional(CONF_PANEL_DISPLAY, DEVICE_SOCKET_SCHEMA, DEVICE_SERIAL_SCHEMA, DEVICE_USB_SCHEMA
default=DEFAULT_PANEL_DISPLAY): cv.boolean, ),
vol.Optional(CONF_ZONES): {vol.Coerce(int): ZONE_SCHEMA}, vol.Optional(
}), CONF_PANEL_DISPLAY, default=DEFAULT_PANEL_DISPLAY
}, extra=vol.ALLOW_EXTRA) ): cv.boolean,
vol.Optional(CONF_ZONES): {vol.Coerce(int): ZONE_SCHEMA},
}
)
},
extra=vol.ALLOW_EXTRA,
)
def setup(hass, config): def setup(hass, config):
"""Set up for the AlarmDecoder devices.""" """Set up for the AlarmDecoder devices."""
from alarmdecoder import AlarmDecoder from alarmdecoder import AlarmDecoder
from alarmdecoder.devices import (SocketDevice, SerialDevice, USBDevice) from alarmdecoder.devices import SocketDevice, SerialDevice, USBDevice
conf = config.get(DOMAIN) conf = config.get(DOMAIN)
@ -115,13 +135,15 @@ def setup(hass, config):
def open_connection(now=None): def open_connection(now=None):
"""Open a connection to AlarmDecoder.""" """Open a connection to AlarmDecoder."""
from alarmdecoder.util import NoDeviceError from alarmdecoder.util import NoDeviceError
nonlocal restart nonlocal restart
try: try:
controller.open(baud) controller.open(baud)
except NoDeviceError: except NoDeviceError:
_LOGGER.debug("Failed to connect. Retrying in 5 seconds") _LOGGER.debug("Failed to connect. Retrying in 5 seconds")
hass.helpers.event.track_point_in_time( hass.helpers.event.track_point_in_time(
open_connection, dt_util.utcnow() + timedelta(seconds=5)) open_connection, dt_util.utcnow() + timedelta(seconds=5)
)
return return
_LOGGER.debug("Established a connection with the alarmdecoder") _LOGGER.debug("Established a connection with the alarmdecoder")
restart = True restart = True
@ -137,39 +159,34 @@ def setup(hass, config):
def handle_message(sender, message): def handle_message(sender, message):
"""Handle message from AlarmDecoder.""" """Handle message from AlarmDecoder."""
hass.helpers.dispatcher.dispatcher_send( hass.helpers.dispatcher.dispatcher_send(SIGNAL_PANEL_MESSAGE, message)
SIGNAL_PANEL_MESSAGE, message)
def handle_rfx_message(sender, message): def handle_rfx_message(sender, message):
"""Handle RFX message from AlarmDecoder.""" """Handle RFX message from AlarmDecoder."""
hass.helpers.dispatcher.dispatcher_send( hass.helpers.dispatcher.dispatcher_send(SIGNAL_RFX_MESSAGE, message)
SIGNAL_RFX_MESSAGE, message)
def zone_fault_callback(sender, zone): def zone_fault_callback(sender, zone):
"""Handle zone fault from AlarmDecoder.""" """Handle zone fault from AlarmDecoder."""
hass.helpers.dispatcher.dispatcher_send( hass.helpers.dispatcher.dispatcher_send(SIGNAL_ZONE_FAULT, zone)
SIGNAL_ZONE_FAULT, zone)
def zone_restore_callback(sender, zone): def zone_restore_callback(sender, zone):
"""Handle zone restore from AlarmDecoder.""" """Handle zone restore from AlarmDecoder."""
hass.helpers.dispatcher.dispatcher_send( hass.helpers.dispatcher.dispatcher_send(SIGNAL_ZONE_RESTORE, zone)
SIGNAL_ZONE_RESTORE, zone)
def handle_rel_message(sender, message): def handle_rel_message(sender, message):
"""Handle relay message from AlarmDecoder.""" """Handle relay message from AlarmDecoder."""
hass.helpers.dispatcher.dispatcher_send( hass.helpers.dispatcher.dispatcher_send(SIGNAL_REL_MESSAGE, message)
SIGNAL_REL_MESSAGE, message)
controller = False controller = False
if device_type == 'socket': if device_type == "socket":
host = device.get(CONF_HOST) host = device.get(CONF_HOST)
port = device.get(CONF_DEVICE_PORT) port = device.get(CONF_DEVICE_PORT)
controller = AlarmDecoder(SocketDevice(interface=(host, port))) controller = AlarmDecoder(SocketDevice(interface=(host, port)))
elif device_type == 'serial': elif device_type == "serial":
path = device.get(CONF_DEVICE_PATH) path = device.get(CONF_DEVICE_PATH)
baud = device.get(CONF_DEVICE_BAUD) baud = device.get(CONF_DEVICE_BAUD)
controller = AlarmDecoder(SerialDevice(interface=path)) controller = AlarmDecoder(SerialDevice(interface=path))
elif device_type == 'usb': elif device_type == "usb":
AlarmDecoder(USBDevice.find()) AlarmDecoder(USBDevice.find())
return False return False
@ -186,13 +203,12 @@ def setup(hass, config):
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_alarmdecoder) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_alarmdecoder)
load_platform(hass, 'alarm_control_panel', DOMAIN, conf, config) load_platform(hass, "alarm_control_panel", DOMAIN, conf, config)
if zones: if zones:
load_platform( load_platform(hass, "binary_sensor", DOMAIN, {CONF_ZONES: zones}, config)
hass, 'binary_sensor', DOMAIN, {CONF_ZONES: zones}, config)
if display: if display:
load_platform(hass, 'sensor', DOMAIN, conf, config) load_platform(hass, "sensor", DOMAIN, conf, config)
return True return True

View file

@ -5,18 +5,20 @@ import voluptuous as vol
import homeassistant.components.alarm_control_panel as alarm import homeassistant.components.alarm_control_panel as alarm
from homeassistant.const import ( from homeassistant.const import (
ATTR_CODE, STATE_ALARM_ARMED_AWAY, STATE_ALARM_ARMED_HOME, ATTR_CODE,
STATE_ALARM_DISARMED, STATE_ALARM_TRIGGERED) STATE_ALARM_ARMED_AWAY,
STATE_ALARM_ARMED_HOME,
STATE_ALARM_DISARMED,
STATE_ALARM_TRIGGERED,
)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from . import DATA_AD, SIGNAL_PANEL_MESSAGE from . import DATA_AD, SIGNAL_PANEL_MESSAGE
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SERVICE_ALARM_TOGGLE_CHIME = 'alarmdecoder_alarm_toggle_chime' SERVICE_ALARM_TOGGLE_CHIME = "alarmdecoder_alarm_toggle_chime"
ALARM_TOGGLE_CHIME_SCHEMA = vol.Schema({ ALARM_TOGGLE_CHIME_SCHEMA = vol.Schema({vol.Required(ATTR_CODE): cv.string})
vol.Required(ATTR_CODE): cv.string,
})
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -30,8 +32,11 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
device.alarm_toggle_chime(code) device.alarm_toggle_chime(code)
hass.services.register( hass.services.register(
alarm.DOMAIN, SERVICE_ALARM_TOGGLE_CHIME, alarm_toggle_chime_handler, alarm.DOMAIN,
schema=ALARM_TOGGLE_CHIME_SCHEMA) SERVICE_ALARM_TOGGLE_CHIME,
alarm_toggle_chime_handler,
schema=ALARM_TOGGLE_CHIME_SCHEMA,
)
class AlarmDecoderAlarmPanel(alarm.AlarmControlPanel): class AlarmDecoderAlarmPanel(alarm.AlarmControlPanel):
@ -55,7 +60,8 @@ class AlarmDecoderAlarmPanel(alarm.AlarmControlPanel):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Register callbacks.""" """Register callbacks."""
self.hass.helpers.dispatcher.async_dispatcher_connect( self.hass.helpers.dispatcher.async_dispatcher_connect(
SIGNAL_PANEL_MESSAGE, self._message_callback) SIGNAL_PANEL_MESSAGE, self._message_callback
)
def _message_callback(self, message): def _message_callback(self, message):
"""Handle received messages.""" """Handle received messages."""
@ -104,15 +110,15 @@ class AlarmDecoderAlarmPanel(alarm.AlarmControlPanel):
def device_state_attributes(self): def device_state_attributes(self):
"""Return the state attributes.""" """Return the state attributes."""
return { return {
'ac_power': self._ac_power, "ac_power": self._ac_power,
'backlight_on': self._backlight_on, "backlight_on": self._backlight_on,
'battery_low': self._battery_low, "battery_low": self._battery_low,
'check_zone': self._check_zone, "check_zone": self._check_zone,
'chime': self._chime, "chime": self._chime,
'entry_delay_off': self._entry_delay_off, "entry_delay_off": self._entry_delay_off,
'programming_mode': self._programming_mode, "programming_mode": self._programming_mode,
'ready': self._ready, "ready": self._ready,
'zone_bypassed': self._zone_bypassed, "zone_bypassed": self._zone_bypassed,
} }
def alarm_disarm(self, code=None): def alarm_disarm(self, code=None):

View file

@ -4,20 +4,30 @@ import logging
from homeassistant.components.binary_sensor import BinarySensorDevice from homeassistant.components.binary_sensor import BinarySensorDevice
from . import ( from . import (
CONF_RELAY_ADDR, CONF_RELAY_CHAN, CONF_ZONE_LOOP, CONF_ZONE_NAME, CONF_RELAY_ADDR,
CONF_ZONE_RFID, CONF_ZONE_TYPE, CONF_ZONES, SIGNAL_REL_MESSAGE, CONF_RELAY_CHAN,
SIGNAL_RFX_MESSAGE, SIGNAL_ZONE_FAULT, SIGNAL_ZONE_RESTORE, ZONE_SCHEMA) CONF_ZONE_LOOP,
CONF_ZONE_NAME,
CONF_ZONE_RFID,
CONF_ZONE_TYPE,
CONF_ZONES,
SIGNAL_REL_MESSAGE,
SIGNAL_RFX_MESSAGE,
SIGNAL_ZONE_FAULT,
SIGNAL_ZONE_RESTORE,
ZONE_SCHEMA,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ATTR_RF_BIT0 = 'rf_bit0' ATTR_RF_BIT0 = "rf_bit0"
ATTR_RF_LOW_BAT = 'rf_low_battery' ATTR_RF_LOW_BAT = "rf_low_battery"
ATTR_RF_SUPERVISED = 'rf_supervised' ATTR_RF_SUPERVISED = "rf_supervised"
ATTR_RF_BIT3 = 'rf_bit3' ATTR_RF_BIT3 = "rf_bit3"
ATTR_RF_LOOP3 = 'rf_loop3' ATTR_RF_LOOP3 = "rf_loop3"
ATTR_RF_LOOP2 = 'rf_loop2' ATTR_RF_LOOP2 = "rf_loop2"
ATTR_RF_LOOP4 = 'rf_loop4' ATTR_RF_LOOP4 = "rf_loop4"
ATTR_RF_LOOP1 = 'rf_loop1' ATTR_RF_LOOP1 = "rf_loop1"
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -34,8 +44,8 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
relay_addr = device_config_data.get(CONF_RELAY_ADDR) relay_addr = device_config_data.get(CONF_RELAY_ADDR)
relay_chan = device_config_data.get(CONF_RELAY_CHAN) relay_chan = device_config_data.get(CONF_RELAY_CHAN)
device = AlarmDecoderBinarySensor( device = AlarmDecoderBinarySensor(
zone_num, zone_name, zone_type, zone_rfid, zone_loop, relay_addr, zone_num, zone_name, zone_type, zone_rfid, zone_loop, relay_addr, relay_chan
relay_chan) )
devices.append(device) devices.append(device)
add_entities(devices) add_entities(devices)
@ -46,8 +56,16 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
class AlarmDecoderBinarySensor(BinarySensorDevice): class AlarmDecoderBinarySensor(BinarySensorDevice):
"""Representation of an AlarmDecoder binary sensor.""" """Representation of an AlarmDecoder binary sensor."""
def __init__(self, zone_number, zone_name, zone_type, zone_rfid, zone_loop, def __init__(
relay_addr, relay_chan): self,
zone_number,
zone_name,
zone_type,
zone_rfid,
zone_loop,
relay_addr,
relay_chan,
):
"""Initialize the binary_sensor.""" """Initialize the binary_sensor."""
self._zone_number = zone_number self._zone_number = zone_number
self._zone_type = zone_type self._zone_type = zone_type
@ -62,16 +80,20 @@ class AlarmDecoderBinarySensor(BinarySensorDevice):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Register callbacks.""" """Register callbacks."""
self.hass.helpers.dispatcher.async_dispatcher_connect( self.hass.helpers.dispatcher.async_dispatcher_connect(
SIGNAL_ZONE_FAULT, self._fault_callback) SIGNAL_ZONE_FAULT, self._fault_callback
)
self.hass.helpers.dispatcher.async_dispatcher_connect( self.hass.helpers.dispatcher.async_dispatcher_connect(
SIGNAL_ZONE_RESTORE, self._restore_callback) SIGNAL_ZONE_RESTORE, self._restore_callback
)
self.hass.helpers.dispatcher.async_dispatcher_connect( self.hass.helpers.dispatcher.async_dispatcher_connect(
SIGNAL_RFX_MESSAGE, self._rfx_message_callback) SIGNAL_RFX_MESSAGE, self._rfx_message_callback
)
self.hass.helpers.dispatcher.async_dispatcher_connect( self.hass.helpers.dispatcher.async_dispatcher_connect(
SIGNAL_REL_MESSAGE, self._rel_message_callback) SIGNAL_REL_MESSAGE, self._rel_message_callback
)
@property @property
def name(self): def name(self):
@ -130,9 +152,9 @@ class AlarmDecoderBinarySensor(BinarySensorDevice):
def _rel_message_callback(self, message): def _rel_message_callback(self, message):
"""Update relay state.""" """Update relay state."""
if (self._relay_addr == message.address and if self._relay_addr == message.address and self._relay_chan == message.channel:
self._relay_chan == message.channel): _LOGGER.debug(
_LOGGER.debug("Relay %d:%d value:%d", message.address, "Relay %d:%d value:%d", message.address, message.channel, message.value
message.channel, message.value) )
self._state = message.value self._state = message.value
self.schedule_update_ha_state() self.schedule_update_ha_state()

View file

@ -24,13 +24,14 @@ class AlarmDecoderSensor(Entity):
"""Initialize the alarm panel.""" """Initialize the alarm panel."""
self._display = "" self._display = ""
self._state = None self._state = None
self._icon = 'mdi:alarm-check' self._icon = "mdi:alarm-check"
self._name = 'Alarm Panel Display' self._name = "Alarm Panel Display"
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Register callbacks.""" """Register callbacks."""
self.hass.helpers.dispatcher.async_dispatcher_connect( self.hass.helpers.dispatcher.async_dispatcher_connect(
SIGNAL_PANEL_MESSAGE, self._message_callback) SIGNAL_PANEL_MESSAGE, self._message_callback
)
def _message_callback(self, message): def _message_callback(self, message):
if self._display != message.text: if self._display != message.text:

View file

@ -7,25 +7,32 @@ import voluptuous as vol
import homeassistant.components.alarm_control_panel as alarm import homeassistant.components.alarm_control_panel as alarm
from homeassistant.components.alarm_control_panel import PLATFORM_SCHEMA from homeassistant.components.alarm_control_panel import PLATFORM_SCHEMA
from homeassistant.const import ( from homeassistant.const import (
CONF_CODE, CONF_NAME, CONF_PASSWORD, CONF_USERNAME, STATE_ALARM_ARMED_AWAY, CONF_CODE,
STATE_ALARM_ARMED_HOME, STATE_ALARM_DISARMED) CONF_NAME,
CONF_PASSWORD,
CONF_USERNAME,
STATE_ALARM_ARMED_AWAY,
STATE_ALARM_ARMED_HOME,
STATE_ALARM_DISARMED,
)
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DEFAULT_NAME = 'Alarm.com' DEFAULT_NAME = "Alarm.com"
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_PASSWORD): cv.string, {
vol.Required(CONF_USERNAME): cv.string, vol.Required(CONF_PASSWORD): cv.string,
vol.Optional(CONF_CODE): cv.positive_int, vol.Required(CONF_USERNAME): cv.string,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_CODE): cv.positive_int,
}) vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
}
)
async def async_setup_platform(hass, config, async_add_entities, async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
discovery_info=None):
"""Set up a Alarm.com control panel.""" """Set up a Alarm.com control panel."""
name = config.get(CONF_NAME) name = config.get(CONF_NAME)
code = config.get(CONF_CODE) code = config.get(CONF_CODE)
@ -43,7 +50,8 @@ class AlarmDotCom(alarm.AlarmControlPanel):
def __init__(self, hass, name, code, username, password): def __init__(self, hass, name, code, username, password):
"""Initialize the Alarm.com status.""" """Initialize the Alarm.com status."""
from pyalarmdotcom import Alarmdotcom from pyalarmdotcom import Alarmdotcom
_LOGGER.debug('Setting up Alarm.com...')
_LOGGER.debug("Setting up Alarm.com...")
self._hass = hass self._hass = hass
self._name = name self._name = name
self._code = str(code) if code else None self._code = str(code) if code else None
@ -51,8 +59,7 @@ class AlarmDotCom(alarm.AlarmControlPanel):
self._password = password self._password = password
self._websession = async_get_clientsession(self._hass) self._websession = async_get_clientsession(self._hass)
self._state = None self._state = None
self._alarm = Alarmdotcom( self._alarm = Alarmdotcom(username, password, self._websession, hass.loop)
username, password, self._websession, hass.loop)
async def async_login(self): async def async_login(self):
"""Login to Alarm.com.""" """Login to Alarm.com."""
@ -73,27 +80,25 @@ class AlarmDotCom(alarm.AlarmControlPanel):
"""Return one or more digits/characters.""" """Return one or more digits/characters."""
if self._code is None: if self._code is None:
return None return None
if isinstance(self._code, str) and re.search('^\\d+$', self._code): if isinstance(self._code, str) and re.search("^\\d+$", self._code):
return alarm.FORMAT_NUMBER return alarm.FORMAT_NUMBER
return alarm.FORMAT_TEXT return alarm.FORMAT_TEXT
@property @property
def state(self): def state(self):
"""Return the state of the device.""" """Return the state of the device."""
if self._alarm.state.lower() == 'disarmed': if self._alarm.state.lower() == "disarmed":
return STATE_ALARM_DISARMED return STATE_ALARM_DISARMED
if self._alarm.state.lower() == 'armed stay': if self._alarm.state.lower() == "armed stay":
return STATE_ALARM_ARMED_HOME return STATE_ALARM_ARMED_HOME
if self._alarm.state.lower() == 'armed away': if self._alarm.state.lower() == "armed away":
return STATE_ALARM_ARMED_AWAY return STATE_ALARM_ARMED_AWAY
return None return None
@property @property
def device_state_attributes(self): def device_state_attributes(self):
"""Return the state attributes.""" """Return the state attributes."""
return { return {"sensor_status": self._alarm.sensor_status}
'sensor_status': self._alarm.sensor_status
}
async def async_alarm_disarm(self, code=None): async def async_alarm_disarm(self, code=None):
"""Send disarm command.""" """Send disarm command."""

View file

@ -7,51 +7,65 @@ import voluptuous as vol
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.components.notify import ( from homeassistant.components.notify import (
ATTR_MESSAGE, ATTR_TITLE, ATTR_DATA, DOMAIN as DOMAIN_NOTIFY) ATTR_MESSAGE,
ATTR_TITLE,
ATTR_DATA,
DOMAIN as DOMAIN_NOTIFY,
)
from homeassistant.const import ( from homeassistant.const import (
CONF_ENTITY_ID, STATE_IDLE, CONF_NAME, CONF_STATE, STATE_ON, STATE_OFF, CONF_ENTITY_ID,
SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE, ATTR_ENTITY_ID) STATE_IDLE,
CONF_NAME,
CONF_STATE,
STATE_ON,
STATE_OFF,
SERVICE_TURN_ON,
SERVICE_TURN_OFF,
SERVICE_TOGGLE,
ATTR_ENTITY_ID,
)
from homeassistant.helpers import service, event from homeassistant.helpers import service, event
from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity import ToggleEntity
from homeassistant.util.dt import now from homeassistant.util.dt import now
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DOMAIN = 'alert' DOMAIN = "alert"
ENTITY_ID_FORMAT = DOMAIN + '.{}' ENTITY_ID_FORMAT = DOMAIN + ".{}"
CONF_CAN_ACK = 'can_acknowledge' CONF_CAN_ACK = "can_acknowledge"
CONF_NOTIFIERS = 'notifiers' CONF_NOTIFIERS = "notifiers"
CONF_REPEAT = 'repeat' CONF_REPEAT = "repeat"
CONF_SKIP_FIRST = 'skip_first' CONF_SKIP_FIRST = "skip_first"
CONF_ALERT_MESSAGE = 'message' CONF_ALERT_MESSAGE = "message"
CONF_DONE_MESSAGE = 'done_message' CONF_DONE_MESSAGE = "done_message"
CONF_TITLE = 'title' CONF_TITLE = "title"
CONF_DATA = 'data' CONF_DATA = "data"
DEFAULT_CAN_ACK = True DEFAULT_CAN_ACK = True
DEFAULT_SKIP_FIRST = False DEFAULT_SKIP_FIRST = False
ALERT_SCHEMA = vol.Schema({ ALERT_SCHEMA = vol.Schema(
vol.Required(CONF_NAME): cv.string, {
vol.Required(CONF_ENTITY_ID): cv.entity_id, vol.Required(CONF_NAME): cv.string,
vol.Required(CONF_STATE, default=STATE_ON): cv.string, vol.Required(CONF_ENTITY_ID): cv.entity_id,
vol.Required(CONF_REPEAT): vol.All(cv.ensure_list, [vol.Coerce(float)]), vol.Required(CONF_STATE, default=STATE_ON): cv.string,
vol.Required(CONF_CAN_ACK, default=DEFAULT_CAN_ACK): cv.boolean, vol.Required(CONF_REPEAT): vol.All(cv.ensure_list, [vol.Coerce(float)]),
vol.Required(CONF_SKIP_FIRST, default=DEFAULT_SKIP_FIRST): cv.boolean, vol.Required(CONF_CAN_ACK, default=DEFAULT_CAN_ACK): cv.boolean,
vol.Optional(CONF_ALERT_MESSAGE): cv.template, vol.Required(CONF_SKIP_FIRST, default=DEFAULT_SKIP_FIRST): cv.boolean,
vol.Optional(CONF_DONE_MESSAGE): cv.template, vol.Optional(CONF_ALERT_MESSAGE): cv.template,
vol.Optional(CONF_TITLE): cv.template, vol.Optional(CONF_DONE_MESSAGE): cv.template,
vol.Optional(CONF_DATA): dict, vol.Optional(CONF_TITLE): cv.template,
vol.Required(CONF_NOTIFIERS): cv.ensure_list}) vol.Optional(CONF_DATA): dict,
vol.Required(CONF_NOTIFIERS): cv.ensure_list,
}
)
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema(
DOMAIN: cv.schema_with_slug_keys(ALERT_SCHEMA), {DOMAIN: cv.schema_with_slug_keys(ALERT_SCHEMA)}, extra=vol.ALLOW_EXTRA
}, extra=vol.ALLOW_EXTRA) )
ALERT_SERVICE_SCHEMA = vol.Schema({ ALERT_SERVICE_SCHEMA = vol.Schema({vol.Required(ATTR_ENTITY_ID): cv.entity_ids})
vol.Required(ATTR_ENTITY_ID): cv.entity_ids,
})
def is_on(hass, entity_id): def is_on(hass, entity_id):
@ -79,11 +93,23 @@ async def async_setup(hass, config):
title_template = cfg.get(CONF_TITLE) title_template = cfg.get(CONF_TITLE)
data = cfg.get(CONF_DATA) data = cfg.get(CONF_DATA)
entities.append(Alert(hass, object_id, name, entities.append(
watched_entity_id, alert_state, repeat, Alert(
skip_first, message_template, hass,
done_message_template, notifiers, object_id,
can_ack, title_template, data)) name,
watched_entity_id,
alert_state,
repeat,
skip_first,
message_template,
done_message_template,
notifiers,
can_ack,
title_template,
data,
)
)
if not entities: if not entities:
return False return False
@ -107,14 +133,17 @@ async def async_setup(hass, config):
# Setup service calls # Setup service calls
hass.services.async_register( hass.services.async_register(
DOMAIN, SERVICE_TURN_OFF, async_handle_alert_service, DOMAIN,
schema=ALERT_SERVICE_SCHEMA) SERVICE_TURN_OFF,
async_handle_alert_service,
schema=ALERT_SERVICE_SCHEMA,
)
hass.services.async_register( hass.services.async_register(
DOMAIN, SERVICE_TURN_ON, async_handle_alert_service, DOMAIN, SERVICE_TURN_ON, async_handle_alert_service, schema=ALERT_SERVICE_SCHEMA
schema=ALERT_SERVICE_SCHEMA) )
hass.services.async_register( hass.services.async_register(
DOMAIN, SERVICE_TOGGLE, async_handle_alert_service, DOMAIN, SERVICE_TOGGLE, async_handle_alert_service, schema=ALERT_SERVICE_SCHEMA
schema=ALERT_SERVICE_SCHEMA) )
tasks = [alert.async_update_ha_state() for alert in entities] tasks = [alert.async_update_ha_state() for alert in entities]
if tasks: if tasks:
@ -126,10 +155,22 @@ async def async_setup(hass, config):
class Alert(ToggleEntity): class Alert(ToggleEntity):
"""Representation of an alert.""" """Representation of an alert."""
def __init__(self, hass, entity_id, name, watched_entity_id, def __init__(
state, repeat, skip_first, message_template, self,
done_message_template, notifiers, can_ack, title_template, hass,
data): entity_id,
name,
watched_entity_id,
state,
repeat,
skip_first,
message_template,
done_message_template,
notifiers,
can_ack,
title_template,
data,
):
"""Initialize the alert.""" """Initialize the alert."""
self.hass = hass self.hass = hass
self._name = name self._name = name
@ -162,7 +203,8 @@ class Alert(ToggleEntity):
self.entity_id = ENTITY_ID_FORMAT.format(entity_id) self.entity_id = ENTITY_ID_FORMAT.format(entity_id)
event.async_track_state_change( event.async_track_state_change(
hass, watched_entity_id, self.watched_entity_change) hass, watched_entity_id, self.watched_entity_change
)
@property @property
def name(self): def name(self):
@ -224,8 +266,9 @@ class Alert(ToggleEntity):
"""Schedule a notification.""" """Schedule a notification."""
delay = self._delay[self._next_delay] delay = self._delay[self._next_delay]
next_msg = now() + delay next_msg = now() + delay
self._cancel = \ self._cancel = event.async_track_point_in_time(
event.async_track_point_in_time(self.hass, self._notify, next_msg) self.hass, self._notify, next_msg
)
self._next_delay = min(self._next_delay + 1, len(self._delay) - 1) self._next_delay = min(self._next_delay + 1, len(self._delay) - 1)
async def _notify(self, *args): async def _notify(self, *args):
@ -270,8 +313,7 @@ class Alert(ToggleEntity):
_LOGGER.debug(msg_payload) _LOGGER.debug(msg_payload)
for target in self._notifiers: for target in self._notifiers:
await self.hass.services.async_call( await self.hass.services.async_call(DOMAIN_NOTIFY, target, msg_payload)
DOMAIN_NOTIFY, target, msg_payload)
async def async_turn_on(self, **kwargs): async def async_turn_on(self, **kwargs):
"""Async Unacknowledge alert.""" """Async Unacknowledge alert."""

View file

@ -9,45 +9,68 @@ from homeassistant.const import CONF_NAME
from . import flash_briefings, intent, smart_home_http from . import flash_briefings, intent, smart_home_http
from .const import ( from .const import (
CONF_AUDIO, CONF_CLIENT_ID, CONF_CLIENT_SECRET, CONF_DISPLAY_URL, CONF_AUDIO,
CONF_ENDPOINT, CONF_TEXT, CONF_TITLE, CONF_UID, DOMAIN, CONF_FILTER, CONF_CLIENT_ID,
CONF_ENTITY_CONFIG, CONF_DESCRIPTION, CONF_DISPLAY_CATEGORIES) CONF_CLIENT_SECRET,
CONF_DISPLAY_URL,
CONF_ENDPOINT,
CONF_TEXT,
CONF_TITLE,
CONF_UID,
DOMAIN,
CONF_FILTER,
CONF_ENTITY_CONFIG,
CONF_DESCRIPTION,
CONF_DISPLAY_CATEGORIES,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
CONF_FLASH_BRIEFINGS = 'flash_briefings' CONF_FLASH_BRIEFINGS = "flash_briefings"
CONF_SMART_HOME = 'smart_home' CONF_SMART_HOME = "smart_home"
ALEXA_ENTITY_SCHEMA = vol.Schema({ ALEXA_ENTITY_SCHEMA = vol.Schema(
vol.Optional(CONF_DESCRIPTION): cv.string, {
vol.Optional(CONF_DISPLAY_CATEGORIES): cv.string, vol.Optional(CONF_DESCRIPTION): cv.string,
vol.Optional(CONF_NAME): cv.string, vol.Optional(CONF_DISPLAY_CATEGORIES): cv.string,
}) vol.Optional(CONF_NAME): cv.string,
SMART_HOME_SCHEMA = vol.Schema({
vol.Optional(CONF_ENDPOINT): cv.string,
vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_CLIENT_SECRET): cv.string,
vol.Optional(CONF_FILTER, default={}): entityfilter.FILTER_SCHEMA,
vol.Optional(CONF_ENTITY_CONFIG): {cv.entity_id: ALEXA_ENTITY_SCHEMA}
})
CONFIG_SCHEMA = vol.Schema({
DOMAIN: {
CONF_FLASH_BRIEFINGS: {
cv.string: vol.All(cv.ensure_list, [{
vol.Optional(CONF_UID): cv.string,
vol.Required(CONF_TITLE): cv.template,
vol.Optional(CONF_AUDIO): cv.template,
vol.Required(CONF_TEXT, default=""): cv.template,
vol.Optional(CONF_DISPLAY_URL): cv.template,
}]),
},
# vol.Optional here would mean we couldn't distinguish between an empty
# smart_home: and none at all.
CONF_SMART_HOME: vol.Any(SMART_HOME_SCHEMA, None),
} }
}, extra=vol.ALLOW_EXTRA) )
SMART_HOME_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ENDPOINT): cv.string,
vol.Optional(CONF_CLIENT_ID): cv.string,
vol.Optional(CONF_CLIENT_SECRET): cv.string,
vol.Optional(CONF_FILTER, default={}): entityfilter.FILTER_SCHEMA,
vol.Optional(CONF_ENTITY_CONFIG): {cv.entity_id: ALEXA_ENTITY_SCHEMA},
}
)
CONFIG_SCHEMA = vol.Schema(
{
DOMAIN: {
CONF_FLASH_BRIEFINGS: {
cv.string: vol.All(
cv.ensure_list,
[
{
vol.Optional(CONF_UID): cv.string,
vol.Required(CONF_TITLE): cv.template,
vol.Optional(CONF_AUDIO): cv.template,
vol.Required(CONF_TEXT, default=""): cv.template,
vol.Optional(CONF_DISPLAY_URL): cv.template,
}
],
)
},
# vol.Optional here would mean we couldn't distinguish between an empty
# smart_home: and none at all.
CONF_SMART_HOME: vol.Any(SMART_HOME_SCHEMA, None),
}
},
extra=vol.ALLOW_EXTRA,
)
async def async_setup(hass, config): async def async_setup(hass, config):

View file

@ -13,12 +13,10 @@ from homeassistant.util import dt
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
LWA_TOKEN_URI = "https://api.amazon.com/auth/o2/token" LWA_TOKEN_URI = "https://api.amazon.com/auth/o2/token"
LWA_HEADERS = { LWA_HEADERS = {"Content-Type": "application/x-www-form-urlencoded;charset=UTF-8"}
"Content-Type": "application/x-www-form-urlencoded;charset=UTF-8"
}
PREEMPTIVE_REFRESH_TTL_IN_SECONDS = 300 PREEMPTIVE_REFRESH_TTL_IN_SECONDS = 300
STORAGE_KEY = 'alexa_auth' STORAGE_KEY = "alexa_auth"
STORAGE_VERSION = 1 STORAGE_VERSION = 1
STORAGE_EXPIRE_TIME = "expire_time" STORAGE_EXPIRE_TIME = "expire_time"
STORAGE_ACCESS_TOKEN = "access_token" STORAGE_ACCESS_TOKEN = "access_token"
@ -49,10 +47,12 @@ class Auth:
"grant_type": "authorization_code", "grant_type": "authorization_code",
"code": accept_grant_code, "code": accept_grant_code,
"client_id": self.client_id, "client_id": self.client_id,
"client_secret": self.client_secret "client_secret": self.client_secret,
} }
_LOGGER.debug("Calling LWA to get the access token (first time), " _LOGGER.debug(
"with: %s", json.dumps(lwa_params)) "Calling LWA to get the access token (first time), " "with: %s",
json.dumps(lwa_params),
)
return await self._async_request_new_token(lwa_params) return await self._async_request_new_token(lwa_params)
@ -74,7 +74,7 @@ class Auth:
"grant_type": "refresh_token", "grant_type": "refresh_token",
"refresh_token": self._prefs[STORAGE_REFRESH_TOKEN], "refresh_token": self._prefs[STORAGE_REFRESH_TOKEN],
"client_id": self.client_id, "client_id": self.client_id,
"client_secret": self.client_secret "client_secret": self.client_secret,
} }
_LOGGER.debug("Calling LWA to refresh the access token.") _LOGGER.debug("Calling LWA to refresh the access token.")
@ -88,7 +88,8 @@ class Auth:
expire_time = dt.parse_datetime(self._prefs[STORAGE_EXPIRE_TIME]) expire_time = dt.parse_datetime(self._prefs[STORAGE_EXPIRE_TIME])
preemptive_expire_time = expire_time - timedelta( preemptive_expire_time = expire_time - timedelta(
seconds=PREEMPTIVE_REFRESH_TTL_IN_SECONDS) seconds=PREEMPTIVE_REFRESH_TTL_IN_SECONDS
)
return dt.utcnow() < preemptive_expire_time return dt.utcnow() < preemptive_expire_time
@ -97,10 +98,12 @@ class Auth:
try: try:
session = aiohttp_client.async_get_clientsession(self.hass) session = aiohttp_client.async_get_clientsession(self.hass)
with async_timeout.timeout(10): with async_timeout.timeout(10):
response = await session.post(LWA_TOKEN_URI, response = await session.post(
headers=LWA_HEADERS, LWA_TOKEN_URI,
data=lwa_params, headers=LWA_HEADERS,
allow_redirects=True) data=lwa_params,
allow_redirects=True,
)
except (asyncio.TimeoutError, aiohttp.ClientError): except (asyncio.TimeoutError, aiohttp.ClientError):
_LOGGER.error("Timeout calling LWA to get auth token.") _LOGGER.error("Timeout calling LWA to get auth token.")
@ -121,8 +124,9 @@ class Auth:
expires_in = response_json["expires_in"] expires_in = response_json["expires_in"]
expire_time = dt.utcnow() + timedelta(seconds=expires_in) expire_time = dt.utcnow() + timedelta(seconds=expires_in)
await self._async_update_preferences(access_token, refresh_token, await self._async_update_preferences(
expire_time.isoformat()) access_token, refresh_token, expire_time.isoformat()
)
return access_token return access_token
@ -134,11 +138,10 @@ class Auth:
self._prefs = { self._prefs = {
STORAGE_ACCESS_TOKEN: None, STORAGE_ACCESS_TOKEN: None,
STORAGE_REFRESH_TOKEN: None, STORAGE_REFRESH_TOKEN: None,
STORAGE_EXPIRE_TIME: None STORAGE_EXPIRE_TIME: None,
} }
async def _async_update_preferences(self, access_token, refresh_token, async def _async_update_preferences(self, access_token, refresh_token, expire_time):
expire_time):
"""Update user preferences.""" """Update user preferences."""
if self._prefs is None: if self._prefs is None:
await self.async_load_preferences() await self.async_load_preferences()

View file

@ -13,11 +13,7 @@ from homeassistant.const import (
STATE_UNLOCKED, STATE_UNLOCKED,
) )
import homeassistant.components.climate.const as climate import homeassistant.components.climate.const as climate
from homeassistant.components import ( from homeassistant.components import light, fan, cover
light,
fan,
cover,
)
import homeassistant.util.color as color_util import homeassistant.util.color as color_util
from .const import ( from .const import (
@ -85,35 +81,35 @@ class AlexaCapibility:
def serialize_discovery(self): def serialize_discovery(self):
"""Serialize according to the Discovery API.""" """Serialize according to the Discovery API."""
result = { result = {
'type': 'AlexaInterface', "type": "AlexaInterface",
'interface': self.name(), "interface": self.name(),
'version': '3', "version": "3",
'properties': { "properties": {
'supported': self.properties_supported(), "supported": self.properties_supported(),
'proactivelyReported': self.properties_proactively_reported(), "proactivelyReported": self.properties_proactively_reported(),
'retrievable': self.properties_retrievable(), "retrievable": self.properties_retrievable(),
}, },
} }
# pylint: disable=assignment-from-none # pylint: disable=assignment-from-none
supports_deactivation = self.supports_deactivation() supports_deactivation = self.supports_deactivation()
if supports_deactivation is not None: if supports_deactivation is not None:
result['supportsDeactivation'] = supports_deactivation result["supportsDeactivation"] = supports_deactivation
return result return result
def serialize_properties(self): def serialize_properties(self):
"""Return properties serialized for an API response.""" """Return properties serialized for an API response."""
for prop in self.properties_supported(): for prop in self.properties_supported():
prop_name = prop['name'] prop_name = prop["name"]
# pylint: disable=assignment-from-no-return # pylint: disable=assignment-from-no-return
prop_value = self.get_property(prop_name) prop_value = self.get_property(prop_name)
if prop_value is not None: if prop_value is not None:
yield { yield {
'name': prop_name, "name": prop_name,
'namespace': self.name(), "namespace": self.name(),
'value': prop_value, "value": prop_value,
'timeOfSample': datetime.now().strftime(DATE_FORMAT), "timeOfSample": datetime.now().strftime(DATE_FORMAT),
'uncertaintyInMilliseconds': 0 "uncertaintyInMilliseconds": 0,
} }
@ -130,11 +126,11 @@ class AlexaEndpointHealth(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.EndpointHealth' return "Alexa.EndpointHealth"
def properties_supported(self): def properties_supported(self):
"""Return what properties this entity supports.""" """Return what properties this entity supports."""
return [{'name': 'connectivity'}] return [{"name": "connectivity"}]
def properties_proactively_reported(self): def properties_proactively_reported(self):
"""Return True if properties asynchronously reported.""" """Return True if properties asynchronously reported."""
@ -146,12 +142,12 @@ class AlexaEndpointHealth(AlexaCapibility):
def get_property(self, name): def get_property(self, name):
"""Read and return a property.""" """Read and return a property."""
if name != 'connectivity': if name != "connectivity":
raise UnsupportedProperty(name) raise UnsupportedProperty(name)
if self.entity.state == STATE_UNAVAILABLE: if self.entity.state == STATE_UNAVAILABLE:
return {'value': 'UNREACHABLE'} return {"value": "UNREACHABLE"}
return {'value': 'OK'} return {"value": "OK"}
class AlexaPowerController(AlexaCapibility): class AlexaPowerController(AlexaCapibility):
@ -162,11 +158,11 @@ class AlexaPowerController(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.PowerController' return "Alexa.PowerController"
def properties_supported(self): def properties_supported(self):
"""Return what properties this entity supports.""" """Return what properties this entity supports."""
return [{'name': 'powerState'}] return [{"name": "powerState"}]
def properties_proactively_reported(self): def properties_proactively_reported(self):
"""Return True if properties asynchronously reported.""" """Return True if properties asynchronously reported."""
@ -178,7 +174,7 @@ class AlexaPowerController(AlexaCapibility):
def get_property(self, name): def get_property(self, name):
"""Read and return a property.""" """Read and return a property."""
if name != 'powerState': if name != "powerState":
raise UnsupportedProperty(name) raise UnsupportedProperty(name)
if self.entity.domain == climate.DOMAIN: if self.entity.domain == climate.DOMAIN:
@ -187,7 +183,7 @@ class AlexaPowerController(AlexaCapibility):
else: else:
is_on = self.entity.state != STATE_OFF is_on = self.entity.state != STATE_OFF
return 'ON' if is_on else 'OFF' return "ON" if is_on else "OFF"
class AlexaLockController(AlexaCapibility): class AlexaLockController(AlexaCapibility):
@ -198,11 +194,11 @@ class AlexaLockController(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.LockController' return "Alexa.LockController"
def properties_supported(self): def properties_supported(self):
"""Return what properties this entity supports.""" """Return what properties this entity supports."""
return [{'name': 'lockState'}] return [{"name": "lockState"}]
def properties_retrievable(self): def properties_retrievable(self):
"""Return True if properties can be retrieved.""" """Return True if properties can be retrieved."""
@ -214,14 +210,14 @@ class AlexaLockController(AlexaCapibility):
def get_property(self, name): def get_property(self, name):
"""Read and return a property.""" """Read and return a property."""
if name != 'lockState': if name != "lockState":
raise UnsupportedProperty(name) raise UnsupportedProperty(name)
if self.entity.state == STATE_LOCKED: if self.entity.state == STATE_LOCKED:
return 'LOCKED' return "LOCKED"
if self.entity.state == STATE_UNLOCKED: if self.entity.state == STATE_UNLOCKED:
return 'UNLOCKED' return "UNLOCKED"
return 'JAMMED' return "JAMMED"
class AlexaSceneController(AlexaCapibility): class AlexaSceneController(AlexaCapibility):
@ -237,7 +233,7 @@ class AlexaSceneController(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.SceneController' return "Alexa.SceneController"
class AlexaBrightnessController(AlexaCapibility): class AlexaBrightnessController(AlexaCapibility):
@ -248,11 +244,11 @@ class AlexaBrightnessController(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.BrightnessController' return "Alexa.BrightnessController"
def properties_supported(self): def properties_supported(self):
"""Return what properties this entity supports.""" """Return what properties this entity supports."""
return [{'name': 'brightness'}] return [{"name": "brightness"}]
def properties_proactively_reported(self): def properties_proactively_reported(self):
"""Return True if properties asynchronously reported.""" """Return True if properties asynchronously reported."""
@ -264,10 +260,10 @@ class AlexaBrightnessController(AlexaCapibility):
def get_property(self, name): def get_property(self, name):
"""Read and return a property.""" """Read and return a property."""
if name != 'brightness': if name != "brightness":
raise UnsupportedProperty(name) raise UnsupportedProperty(name)
if 'brightness' in self.entity.attributes: if "brightness" in self.entity.attributes:
return round(self.entity.attributes['brightness'] / 255.0 * 100) return round(self.entity.attributes["brightness"] / 255.0 * 100)
return 0 return 0
@ -279,11 +275,11 @@ class AlexaColorController(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.ColorController' return "Alexa.ColorController"
def properties_supported(self): def properties_supported(self):
"""Return what properties this entity supports.""" """Return what properties this entity supports."""
return [{'name': 'color'}] return [{"name": "color"}]
def properties_retrievable(self): def properties_retrievable(self):
"""Return True if properties can be retrieved.""" """Return True if properties can be retrieved."""
@ -291,17 +287,15 @@ class AlexaColorController(AlexaCapibility):
def get_property(self, name): def get_property(self, name):
"""Read and return a property.""" """Read and return a property."""
if name != 'color': if name != "color":
raise UnsupportedProperty(name) raise UnsupportedProperty(name)
hue, saturation = self.entity.attributes.get( hue, saturation = self.entity.attributes.get(light.ATTR_HS_COLOR, (0, 0))
light.ATTR_HS_COLOR, (0, 0))
return { return {
'hue': hue, "hue": hue,
'saturation': saturation / 100.0, "saturation": saturation / 100.0,
'brightness': self.entity.attributes.get( "brightness": self.entity.attributes.get(light.ATTR_BRIGHTNESS, 0) / 255.0,
light.ATTR_BRIGHTNESS, 0) / 255.0,
} }
@ -313,11 +307,11 @@ class AlexaColorTemperatureController(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.ColorTemperatureController' return "Alexa.ColorTemperatureController"
def properties_supported(self): def properties_supported(self):
"""Return what properties this entity supports.""" """Return what properties this entity supports."""
return [{'name': 'colorTemperatureInKelvin'}] return [{"name": "colorTemperatureInKelvin"}]
def properties_retrievable(self): def properties_retrievable(self):
"""Return True if properties can be retrieved.""" """Return True if properties can be retrieved."""
@ -325,11 +319,12 @@ class AlexaColorTemperatureController(AlexaCapibility):
def get_property(self, name): def get_property(self, name):
"""Read and return a property.""" """Read and return a property."""
if name != 'colorTemperatureInKelvin': if name != "colorTemperatureInKelvin":
raise UnsupportedProperty(name) raise UnsupportedProperty(name)
if 'color_temp' in self.entity.attributes: if "color_temp" in self.entity.attributes:
return color_util.color_temperature_mired_to_kelvin( return color_util.color_temperature_mired_to_kelvin(
self.entity.attributes['color_temp']) self.entity.attributes["color_temp"]
)
return 0 return 0
@ -341,11 +336,11 @@ class AlexaPercentageController(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.PercentageController' return "Alexa.PercentageController"
def properties_supported(self): def properties_supported(self):
"""Return what properties this entity supports.""" """Return what properties this entity supports."""
return [{'name': 'percentage'}] return [{"name": "percentage"}]
def properties_retrievable(self): def properties_retrievable(self):
"""Return True if properties can be retrieved.""" """Return True if properties can be retrieved."""
@ -353,7 +348,7 @@ class AlexaPercentageController(AlexaCapibility):
def get_property(self, name): def get_property(self, name):
"""Read and return a property.""" """Read and return a property."""
if name != 'percentage': if name != "percentage":
raise UnsupportedProperty(name) raise UnsupportedProperty(name)
if self.entity.domain == fan.DOMAIN: if self.entity.domain == fan.DOMAIN:
@ -375,7 +370,7 @@ class AlexaSpeaker(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.Speaker' return "Alexa.Speaker"
class AlexaStepSpeaker(AlexaCapibility): class AlexaStepSpeaker(AlexaCapibility):
@ -386,7 +381,7 @@ class AlexaStepSpeaker(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.StepSpeaker' return "Alexa.StepSpeaker"
class AlexaPlaybackController(AlexaCapibility): class AlexaPlaybackController(AlexaCapibility):
@ -397,7 +392,7 @@ class AlexaPlaybackController(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.PlaybackController' return "Alexa.PlaybackController"
class AlexaInputController(AlexaCapibility): class AlexaInputController(AlexaCapibility):
@ -408,7 +403,7 @@ class AlexaInputController(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.InputController' return "Alexa.InputController"
class AlexaTemperatureSensor(AlexaCapibility): class AlexaTemperatureSensor(AlexaCapibility):
@ -424,11 +419,11 @@ class AlexaTemperatureSensor(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.TemperatureSensor' return "Alexa.TemperatureSensor"
def properties_supported(self): def properties_supported(self):
"""Return what properties this entity supports.""" """Return what properties this entity supports."""
return [{'name': 'temperature'}] return [{"name": "temperature"}]
def properties_proactively_reported(self): def properties_proactively_reported(self):
"""Return True if properties asynchronously reported.""" """Return True if properties asynchronously reported."""
@ -440,19 +435,15 @@ class AlexaTemperatureSensor(AlexaCapibility):
def get_property(self, name): def get_property(self, name):
"""Read and return a property.""" """Read and return a property."""
if name != 'temperature': if name != "temperature":
raise UnsupportedProperty(name) raise UnsupportedProperty(name)
unit = self.entity.attributes.get(ATTR_UNIT_OF_MEASUREMENT) unit = self.entity.attributes.get(ATTR_UNIT_OF_MEASUREMENT)
temp = self.entity.state temp = self.entity.state
if self.entity.domain == climate.DOMAIN: if self.entity.domain == climate.DOMAIN:
unit = self.hass.config.units.temperature_unit unit = self.hass.config.units.temperature_unit
temp = self.entity.attributes.get( temp = self.entity.attributes.get(climate.ATTR_CURRENT_TEMPERATURE)
climate.ATTR_CURRENT_TEMPERATURE) return {"value": float(temp), "scale": API_TEMP_UNITS[unit]}
return {
'value': float(temp),
'scale': API_TEMP_UNITS[unit],
}
class AlexaContactSensor(AlexaCapibility): class AlexaContactSensor(AlexaCapibility):
@ -473,11 +464,11 @@ class AlexaContactSensor(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.ContactSensor' return "Alexa.ContactSensor"
def properties_supported(self): def properties_supported(self):
"""Return what properties this entity supports.""" """Return what properties this entity supports."""
return [{'name': 'detectionState'}] return [{"name": "detectionState"}]
def properties_proactively_reported(self): def properties_proactively_reported(self):
"""Return True if properties asynchronously reported.""" """Return True if properties asynchronously reported."""
@ -489,12 +480,12 @@ class AlexaContactSensor(AlexaCapibility):
def get_property(self, name): def get_property(self, name):
"""Read and return a property.""" """Read and return a property."""
if name != 'detectionState': if name != "detectionState":
raise UnsupportedProperty(name) raise UnsupportedProperty(name)
if self.entity.state == STATE_ON: if self.entity.state == STATE_ON:
return 'DETECTED' return "DETECTED"
return 'NOT_DETECTED' return "NOT_DETECTED"
class AlexaMotionSensor(AlexaCapibility): class AlexaMotionSensor(AlexaCapibility):
@ -510,11 +501,11 @@ class AlexaMotionSensor(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.MotionSensor' return "Alexa.MotionSensor"
def properties_supported(self): def properties_supported(self):
"""Return what properties this entity supports.""" """Return what properties this entity supports."""
return [{'name': 'detectionState'}] return [{"name": "detectionState"}]
def properties_proactively_reported(self): def properties_proactively_reported(self):
"""Return True if properties asynchronously reported.""" """Return True if properties asynchronously reported."""
@ -526,12 +517,12 @@ class AlexaMotionSensor(AlexaCapibility):
def get_property(self, name): def get_property(self, name):
"""Read and return a property.""" """Read and return a property."""
if name != 'detectionState': if name != "detectionState":
raise UnsupportedProperty(name) raise UnsupportedProperty(name)
if self.entity.state == STATE_ON: if self.entity.state == STATE_ON:
return 'DETECTED' return "DETECTED"
return 'NOT_DETECTED' return "NOT_DETECTED"
class AlexaThermostatController(AlexaCapibility): class AlexaThermostatController(AlexaCapibility):
@ -547,17 +538,17 @@ class AlexaThermostatController(AlexaCapibility):
def name(self): def name(self):
"""Return the Alexa API name of this interface.""" """Return the Alexa API name of this interface."""
return 'Alexa.ThermostatController' return "Alexa.ThermostatController"
def properties_supported(self): def properties_supported(self):
"""Return what properties this entity supports.""" """Return what properties this entity supports."""
properties = [{'name': 'thermostatMode'}] properties = [{"name": "thermostatMode"}]
supported = self.entity.attributes.get(ATTR_SUPPORTED_FEATURES, 0) supported = self.entity.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
if supported & climate.SUPPORT_TARGET_TEMPERATURE: if supported & climate.SUPPORT_TARGET_TEMPERATURE:
properties.append({'name': 'targetSetpoint'}) properties.append({"name": "targetSetpoint"})
if supported & climate.SUPPORT_TARGET_TEMPERATURE_RANGE: if supported & climate.SUPPORT_TARGET_TEMPERATURE_RANGE:
properties.append({'name': 'lowerSetpoint'}) properties.append({"name": "lowerSetpoint"})
properties.append({'name': 'upperSetpoint'}) properties.append({"name": "upperSetpoint"})
return properties return properties
def properties_proactively_reported(self): def properties_proactively_reported(self):
@ -570,7 +561,7 @@ class AlexaThermostatController(AlexaCapibility):
def get_property(self, name): def get_property(self, name):
"""Read and return a property.""" """Read and return a property."""
if name == 'thermostatMode': if name == "thermostatMode":
preset = self.entity.attributes.get(climate.ATTR_PRESET_MODE) preset = self.entity.attributes.get(climate.ATTR_PRESET_MODE)
if preset in API_THERMOSTAT_PRESETS: if preset in API_THERMOSTAT_PRESETS:
@ -580,17 +571,19 @@ class AlexaThermostatController(AlexaCapibility):
if mode is None: if mode is None:
_LOGGER.error( _LOGGER.error(
"%s (%s) has unsupported state value '%s'", "%s (%s) has unsupported state value '%s'",
self.entity.entity_id, type(self.entity), self.entity.entity_id,
self.entity.state) type(self.entity),
self.entity.state,
)
raise UnsupportedProperty(name) raise UnsupportedProperty(name)
return mode return mode
unit = self.hass.config.units.temperature_unit unit = self.hass.config.units.temperature_unit
if name == 'targetSetpoint': if name == "targetSetpoint":
temp = self.entity.attributes.get(ATTR_TEMPERATURE) temp = self.entity.attributes.get(ATTR_TEMPERATURE)
elif name == 'lowerSetpoint': elif name == "lowerSetpoint":
temp = self.entity.attributes.get(climate.ATTR_TARGET_TEMP_LOW) temp = self.entity.attributes.get(climate.ATTR_TARGET_TEMP_LOW)
elif name == 'upperSetpoint': elif name == "upperSetpoint":
temp = self.entity.attributes.get(climate.ATTR_TARGET_TEMP_HIGH) temp = self.entity.attributes.get(climate.ATTR_TARGET_TEMP_HIGH)
else: else:
raise UnsupportedProperty(name) raise UnsupportedProperty(name)
@ -598,7 +591,4 @@ class AlexaThermostatController(AlexaCapibility):
if temp is None: if temp is None:
return None return None
return { return {"value": float(temp), "scale": API_TEMP_UNITS[unit]}
'value': float(temp),
'scale': API_TEMP_UNITS[unit],
}

View file

@ -1,78 +1,68 @@
"""Constants for the Alexa integration.""" """Constants for the Alexa integration."""
from collections import OrderedDict from collections import OrderedDict
from homeassistant.const import ( from homeassistant.const import TEMP_CELSIUS, TEMP_FAHRENHEIT
TEMP_CELSIUS,
TEMP_FAHRENHEIT,
)
from homeassistant.components.climate import const as climate from homeassistant.components.climate import const as climate
from homeassistant.components import fan from homeassistant.components import fan
DOMAIN = 'alexa' DOMAIN = "alexa"
# Flash briefing constants # Flash briefing constants
CONF_UID = 'uid' CONF_UID = "uid"
CONF_TITLE = 'title' CONF_TITLE = "title"
CONF_AUDIO = 'audio' CONF_AUDIO = "audio"
CONF_TEXT = 'text' CONF_TEXT = "text"
CONF_DISPLAY_URL = 'display_url' CONF_DISPLAY_URL = "display_url"
CONF_FILTER = 'filter' CONF_FILTER = "filter"
CONF_ENTITY_CONFIG = 'entity_config' CONF_ENTITY_CONFIG = "entity_config"
CONF_ENDPOINT = 'endpoint' CONF_ENDPOINT = "endpoint"
CONF_CLIENT_ID = 'client_id' CONF_CLIENT_ID = "client_id"
CONF_CLIENT_SECRET = 'client_secret' CONF_CLIENT_SECRET = "client_secret"
ATTR_UID = 'uid' ATTR_UID = "uid"
ATTR_UPDATE_DATE = 'updateDate' ATTR_UPDATE_DATE = "updateDate"
ATTR_TITLE_TEXT = 'titleText' ATTR_TITLE_TEXT = "titleText"
ATTR_STREAM_URL = 'streamUrl' ATTR_STREAM_URL = "streamUrl"
ATTR_MAIN_TEXT = 'mainText' ATTR_MAIN_TEXT = "mainText"
ATTR_REDIRECTION_URL = 'redirectionURL' ATTR_REDIRECTION_URL = "redirectionURL"
SYN_RESOLUTION_MATCH = 'ER_SUCCESS_MATCH' SYN_RESOLUTION_MATCH = "ER_SUCCESS_MATCH"
DATE_FORMAT = '%Y-%m-%dT%H:%M:%S.0Z' DATE_FORMAT = "%Y-%m-%dT%H:%M:%S.0Z"
API_DIRECTIVE = 'directive' API_DIRECTIVE = "directive"
API_ENDPOINT = 'endpoint' API_ENDPOINT = "endpoint"
API_EVENT = 'event' API_EVENT = "event"
API_CONTEXT = 'context' API_CONTEXT = "context"
API_HEADER = 'header' API_HEADER = "header"
API_PAYLOAD = 'payload' API_PAYLOAD = "payload"
API_SCOPE = 'scope' API_SCOPE = "scope"
API_CHANGE = 'change' API_CHANGE = "change"
CONF_DESCRIPTION = 'description' CONF_DESCRIPTION = "description"
CONF_DISPLAY_CATEGORIES = 'display_categories' CONF_DISPLAY_CATEGORIES = "display_categories"
API_TEMP_UNITS = { API_TEMP_UNITS = {TEMP_FAHRENHEIT: "FAHRENHEIT", TEMP_CELSIUS: "CELSIUS"}
TEMP_FAHRENHEIT: 'FAHRENHEIT',
TEMP_CELSIUS: 'CELSIUS',
}
# Needs to be ordered dict for `async_api_set_thermostat_mode` which does a # Needs to be ordered dict for `async_api_set_thermostat_mode` which does a
# reverse mapping of this dict and we want to map the first occurrance of OFF # reverse mapping of this dict and we want to map the first occurrance of OFF
# back to HA state. # back to HA state.
API_THERMOSTAT_MODES = OrderedDict([ API_THERMOSTAT_MODES = OrderedDict(
(climate.HVAC_MODE_HEAT, 'HEAT'), [
(climate.HVAC_MODE_COOL, 'COOL'), (climate.HVAC_MODE_HEAT, "HEAT"),
(climate.HVAC_MODE_HEAT_COOL, 'AUTO'), (climate.HVAC_MODE_COOL, "COOL"),
(climate.HVAC_MODE_AUTO, 'AUTO'), (climate.HVAC_MODE_HEAT_COOL, "AUTO"),
(climate.HVAC_MODE_OFF, 'OFF'), (climate.HVAC_MODE_AUTO, "AUTO"),
(climate.HVAC_MODE_FAN_ONLY, 'OFF'), (climate.HVAC_MODE_OFF, "OFF"),
(climate.HVAC_MODE_DRY, 'OFF'), (climate.HVAC_MODE_FAN_ONLY, "OFF"),
]) (climate.HVAC_MODE_DRY, "OFF"),
API_THERMOSTAT_PRESETS = { ]
climate.PRESET_ECO: 'ECO' )
} API_THERMOSTAT_PRESETS = {climate.PRESET_ECO: "ECO"}
PERCENTAGE_FAN_MAP = { PERCENTAGE_FAN_MAP = {fan.SPEED_LOW: 33, fan.SPEED_MEDIUM: 66, fan.SPEED_HIGH: 100}
fan.SPEED_LOW: 33,
fan.SPEED_MEDIUM: 66,
fan.SPEED_HIGH: 100,
}
class Cause: class Cause:
@ -84,25 +74,25 @@ class Cause:
# Indicates that the event was caused by a customer interaction with an # Indicates that the event was caused by a customer interaction with an
# application. For example, a customer switches on a light, or locks a door # application. For example, a customer switches on a light, or locks a door
# using the Alexa app or an app provided by a device vendor. # using the Alexa app or an app provided by a device vendor.
APP_INTERACTION = 'APP_INTERACTION' APP_INTERACTION = "APP_INTERACTION"
# Indicates that the event was caused by a physical interaction with an # Indicates that the event was caused by a physical interaction with an
# endpoint. For example manually switching on a light or manually locking a # endpoint. For example manually switching on a light or manually locking a
# door lock # door lock
PHYSICAL_INTERACTION = 'PHYSICAL_INTERACTION' PHYSICAL_INTERACTION = "PHYSICAL_INTERACTION"
# Indicates that the event was caused by the periodic poll of an appliance, # Indicates that the event was caused by the periodic poll of an appliance,
# which found a change in value. For example, you might poll a temperature # which found a change in value. For example, you might poll a temperature
# sensor every hour, and send the updated temperature to Alexa. # sensor every hour, and send the updated temperature to Alexa.
PERIODIC_POLL = 'PERIODIC_POLL' PERIODIC_POLL = "PERIODIC_POLL"
# Indicates that the event was caused by the application of a device rule. # Indicates that the event was caused by the application of a device rule.
# For example, a customer configures a rule to switch on a light if a # For example, a customer configures a rule to switch on a light if a
# motion sensor detects motion. In this case, Alexa receives an event from # motion sensor detects motion. In this case, Alexa receives an event from
# the motion sensor, and another event from the light to indicate that its # the motion sensor, and another event from the light to indicate that its
# state change was caused by the rule. # state change was caused by the rule.
RULE_TRIGGER = 'RULE_TRIGGER' RULE_TRIGGER = "RULE_TRIGGER"
# Indicates that the event was caused by a voice interaction with Alexa. # Indicates that the event was caused by a voice interaction with Alexa.
# For example a user speaking to their Echo device. # For example a user speaking to their Echo device.
VOICE_INTERACTION = 'VOICE_INTERACTION' VOICE_INTERACTION = "VOICE_INTERACTION"

View file

@ -14,8 +14,21 @@ from homeassistant.const import (
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
from homeassistant.components.climate import const as climate from homeassistant.components.climate import const as climate
from homeassistant.components import ( from homeassistant.components import (
alert, automation, binary_sensor, cover, fan, group, alert,
input_boolean, light, lock, media_player, scene, script, sensor, switch) automation,
binary_sensor,
cover,
fan,
group,
input_boolean,
light,
lock,
media_player,
scene,
script,
sensor,
switch,
)
from .const import CONF_DESCRIPTION, CONF_DISPLAY_CATEGORIES from .const import CONF_DESCRIPTION, CONF_DISPLAY_CATEGORIES
from .capabilities import ( from .capabilities import (
@ -129,7 +142,7 @@ class AlexaEntity:
def alexa_id(self): def alexa_id(self):
"""Return the Alexa API entity id.""" """Return the Alexa API entity id."""
return self.entity.entity_id.replace('.', '#') return self.entity.entity_id.replace(".", "#")
def display_categories(self): def display_categories(self):
"""Return a list of display categories.""" """Return a list of display categories."""
@ -171,15 +184,13 @@ class AlexaEntity:
def serialize_discovery(self): def serialize_discovery(self):
"""Serialize the entity for discovery.""" """Serialize the entity for discovery."""
return { return {
'displayCategories': self.display_categories(), "displayCategories": self.display_categories(),
'cookie': {}, "cookie": {},
'endpointId': self.alexa_id(), "endpointId": self.alexa_id(),
'friendlyName': self.friendly_name(), "friendlyName": self.friendly_name(),
'description': self.description(), "description": self.description(),
'manufacturerName': 'Home Assistant', "manufacturerName": "Home Assistant",
'capabilities': [ "capabilities": [i.serialize_discovery() for i in self.interfaces()],
i.serialize_discovery() for i in self.interfaces()
]
} }
@ -220,8 +231,10 @@ class GenericCapabilities(AlexaEntity):
def interfaces(self): def interfaces(self):
"""Yield the supported interfaces.""" """Yield the supported interfaces."""
return [AlexaPowerController(self.entity), return [
AlexaEndpointHealth(self.hass, self.entity)] AlexaPowerController(self.entity),
AlexaEndpointHealth(self.hass, self.entity),
]
@ENTITY_ADAPTERS.register(switch.DOMAIN) @ENTITY_ADAPTERS.register(switch.DOMAIN)
@ -234,8 +247,10 @@ class SwitchCapabilities(AlexaEntity):
def interfaces(self): def interfaces(self):
"""Yield the supported interfaces.""" """Yield the supported interfaces."""
return [AlexaPowerController(self.entity), return [
AlexaEndpointHealth(self.hass, self.entity)] AlexaPowerController(self.entity),
AlexaEndpointHealth(self.hass, self.entity),
]
@ENTITY_ADAPTERS.register(climate.DOMAIN) @ENTITY_ADAPTERS.register(climate.DOMAIN)
@ -249,8 +264,7 @@ class ClimateCapabilities(AlexaEntity):
def interfaces(self): def interfaces(self):
"""Yield the supported interfaces.""" """Yield the supported interfaces."""
# If we support two modes, one being off, we allow turning on too. # If we support two modes, one being off, we allow turning on too.
if (climate.HVAC_MODE_OFF in if climate.HVAC_MODE_OFF in self.entity.attributes[climate.ATTR_HVAC_MODES]:
self.entity.attributes[climate.ATTR_HVAC_MODES]):
yield AlexaPowerController(self.entity) yield AlexaPowerController(self.entity)
yield AlexaThermostatController(self.hass, self.entity) yield AlexaThermostatController(self.hass, self.entity)
@ -324,8 +338,10 @@ class LockCapabilities(AlexaEntity):
def interfaces(self): def interfaces(self):
"""Yield the supported interfaces.""" """Yield the supported interfaces."""
return [AlexaLockController(self.entity), return [
AlexaEndpointHealth(self.hass, self.entity)] AlexaLockController(self.entity),
AlexaEndpointHealth(self.hass, self.entity),
]
@ENTITY_ADAPTERS.register(media_player.const.DOMAIN) @ENTITY_ADAPTERS.register(media_player.const.DOMAIN)
@ -345,16 +361,20 @@ class MediaPlayerCapabilities(AlexaEntity):
if supported & media_player.const.SUPPORT_VOLUME_SET: if supported & media_player.const.SUPPORT_VOLUME_SET:
yield AlexaSpeaker(self.entity) yield AlexaSpeaker(self.entity)
step_volume_features = (media_player.const.SUPPORT_VOLUME_MUTE | step_volume_features = (
media_player.const.SUPPORT_VOLUME_STEP) media_player.const.SUPPORT_VOLUME_MUTE
| media_player.const.SUPPORT_VOLUME_STEP
)
if supported & step_volume_features: if supported & step_volume_features:
yield AlexaStepSpeaker(self.entity) yield AlexaStepSpeaker(self.entity)
playback_features = (media_player.const.SUPPORT_PLAY | playback_features = (
media_player.const.SUPPORT_PAUSE | media_player.const.SUPPORT_PLAY
media_player.const.SUPPORT_STOP | | media_player.const.SUPPORT_PAUSE
media_player.const.SUPPORT_NEXT_TRACK | | media_player.const.SUPPORT_STOP
media_player.const.SUPPORT_PREVIOUS_TRACK) | media_player.const.SUPPORT_NEXT_TRACK
| media_player.const.SUPPORT_PREVIOUS_TRACK
)
if supported & playback_features: if supported & playback_features:
yield AlexaPlaybackController(self.entity) yield AlexaPlaybackController(self.entity)
@ -369,7 +389,7 @@ class SceneCapabilities(AlexaEntity):
def description(self): def description(self):
"""Return the description of the entity.""" """Return the description of the entity."""
# Required description as per Amazon Scene docs # Required description as per Amazon Scene docs
scene_fmt = '{} (Scene connected via Home Assistant)' scene_fmt = "{} (Scene connected via Home Assistant)"
return scene_fmt.format(AlexaEntity.description(self)) return scene_fmt.format(AlexaEntity.description(self))
def default_display_categories(self): def default_display_categories(self):
@ -378,8 +398,7 @@ class SceneCapabilities(AlexaEntity):
def interfaces(self): def interfaces(self):
"""Yield the supported interfaces.""" """Yield the supported interfaces."""
return [AlexaSceneController(self.entity, return [AlexaSceneController(self.entity, supports_deactivation=False)]
supports_deactivation=False)]
@ENTITY_ADAPTERS.register(script.DOMAIN) @ENTITY_ADAPTERS.register(script.DOMAIN)
@ -392,9 +411,8 @@ class ScriptCapabilities(AlexaEntity):
def interfaces(self): def interfaces(self):
"""Yield the supported interfaces.""" """Yield the supported interfaces."""
can_cancel = bool(self.entity.attributes.get('can_cancel')) can_cancel = bool(self.entity.attributes.get("can_cancel"))
return [AlexaSceneController(self.entity, return [AlexaSceneController(self.entity, supports_deactivation=can_cancel)]
supports_deactivation=can_cancel)]
@ENTITY_ADAPTERS.register(sensor.DOMAIN) @ENTITY_ADAPTERS.register(sensor.DOMAIN)
@ -410,10 +428,7 @@ class SensorCapabilities(AlexaEntity):
def interfaces(self): def interfaces(self):
"""Yield the supported interfaces.""" """Yield the supported interfaces."""
attrs = self.entity.attributes attrs = self.entity.attributes
if attrs.get(ATTR_UNIT_OF_MEASUREMENT) in ( if attrs.get(ATTR_UNIT_OF_MEASUREMENT) in (TEMP_FAHRENHEIT, TEMP_CELSIUS):
TEMP_FAHRENHEIT,
TEMP_CELSIUS,
):
yield AlexaTemperatureSensor(self.hass, self.entity) yield AlexaTemperatureSensor(self.hass, self.entity)
yield AlexaEndpointHealth(self.hass, self.entity) yield AlexaEndpointHealth(self.hass, self.entity)
@ -422,8 +437,8 @@ class SensorCapabilities(AlexaEntity):
class BinarySensorCapabilities(AlexaEntity): class BinarySensorCapabilities(AlexaEntity):
"""Class to represent BinarySensor capabilities.""" """Class to represent BinarySensor capabilities."""
TYPE_CONTACT = 'contact' TYPE_CONTACT = "contact"
TYPE_MOTION = 'motion' TYPE_MOTION = "motion"
def default_display_categories(self): def default_display_categories(self):
"""Return the display categories for this entity.""" """Return the display categories for this entity."""
@ -446,12 +461,7 @@ class BinarySensorCapabilities(AlexaEntity):
def get_type(self): def get_type(self):
"""Return the type of binary sensor.""" """Return the type of binary sensor."""
attrs = self.entity.attributes attrs = self.entity.attributes
if attrs.get(ATTR_DEVICE_CLASS) in ( if attrs.get(ATTR_DEVICE_CLASS) in ("door", "garage_door", "opening", "window"):
'door',
'garage_door',
'opening',
'window',
):
return self.TYPE_CONTACT return self.TYPE_CONTACT
if attrs.get(ATTR_DEVICE_CLASS) == 'motion': if attrs.get(ATTR_DEVICE_CLASS) == "motion":
return self.TYPE_MOTION return self.TYPE_MOTION

View file

@ -35,12 +35,12 @@ class AlexaError(Exception):
class AlexaInvalidEndpointError(AlexaError): class AlexaInvalidEndpointError(AlexaError):
"""The endpoint in the request does not exist.""" """The endpoint in the request does not exist."""
namespace = 'Alexa' namespace = "Alexa"
error_type = 'NO_SUCH_ENDPOINT' error_type = "NO_SUCH_ENDPOINT"
def __init__(self, endpoint_id): def __init__(self, endpoint_id):
"""Initialize invalid endpoint error.""" """Initialize invalid endpoint error."""
msg = 'The endpoint {} does not exist'.format(endpoint_id) msg = "The endpoint {} does not exist".format(endpoint_id)
AlexaError.__init__(self, msg) AlexaError.__init__(self, msg)
self.endpoint_id = endpoint_id self.endpoint_id = endpoint_id
@ -48,38 +48,32 @@ class AlexaInvalidEndpointError(AlexaError):
class AlexaInvalidValueError(AlexaError): class AlexaInvalidValueError(AlexaError):
"""Class to represent InvalidValue errors.""" """Class to represent InvalidValue errors."""
namespace = 'Alexa' namespace = "Alexa"
error_type = 'INVALID_VALUE' error_type = "INVALID_VALUE"
class AlexaUnsupportedThermostatModeError(AlexaError): class AlexaUnsupportedThermostatModeError(AlexaError):
"""Class to represent UnsupportedThermostatMode errors.""" """Class to represent UnsupportedThermostatMode errors."""
namespace = 'Alexa.ThermostatController' namespace = "Alexa.ThermostatController"
error_type = 'UNSUPPORTED_THERMOSTAT_MODE' error_type = "UNSUPPORTED_THERMOSTAT_MODE"
class AlexaTempRangeError(AlexaError): class AlexaTempRangeError(AlexaError):
"""Class to represent TempRange errors.""" """Class to represent TempRange errors."""
namespace = 'Alexa' namespace = "Alexa"
error_type = 'TEMPERATURE_VALUE_OUT_OF_RANGE' error_type = "TEMPERATURE_VALUE_OUT_OF_RANGE"
def __init__(self, hass, temp, min_temp, max_temp): def __init__(self, hass, temp, min_temp, max_temp):
"""Initialize TempRange error.""" """Initialize TempRange error."""
unit = hass.config.units.temperature_unit unit = hass.config.units.temperature_unit
temp_range = { temp_range = {
'minimumValue': { "minimumValue": {"value": min_temp, "scale": API_TEMP_UNITS[unit]},
'value': min_temp, "maximumValue": {"value": max_temp, "scale": API_TEMP_UNITS[unit]},
'scale': API_TEMP_UNITS[unit],
},
'maximumValue': {
'value': max_temp,
'scale': API_TEMP_UNITS[unit],
},
} }
payload = {'validRange': temp_range} payload = {"validRange": temp_range}
msg = 'The requested temperature {} is out of range'.format(temp) msg = "The requested temperature {} is out of range".format(temp)
AlexaError.__init__(self, msg, payload) AlexaError.__init__(self, msg, payload)
@ -87,5 +81,5 @@ class AlexaTempRangeError(AlexaError):
class AlexaBridgeUnreachableError(AlexaError): class AlexaBridgeUnreachableError(AlexaError):
"""Class to represent BridgeUnreachable errors.""" """Class to represent BridgeUnreachable errors."""
namespace = 'Alexa' namespace = "Alexa"
error_type = 'BRIDGE_UNREACHABLE' error_type = "BRIDGE_UNREACHABLE"

View file

@ -9,27 +9,36 @@ from homeassistant.core import callback
from homeassistant.helpers import template from homeassistant.helpers import template
from .const import ( from .const import (
ATTR_MAIN_TEXT, ATTR_REDIRECTION_URL, ATTR_STREAM_URL, ATTR_TITLE_TEXT, ATTR_MAIN_TEXT,
ATTR_UID, ATTR_UPDATE_DATE, CONF_AUDIO, CONF_DISPLAY_URL, CONF_TEXT, ATTR_REDIRECTION_URL,
CONF_TITLE, CONF_UID, DATE_FORMAT) ATTR_STREAM_URL,
ATTR_TITLE_TEXT,
ATTR_UID,
ATTR_UPDATE_DATE,
CONF_AUDIO,
CONF_DISPLAY_URL,
CONF_TEXT,
CONF_TITLE,
CONF_UID,
DATE_FORMAT,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
FLASH_BRIEFINGS_API_ENDPOINT = '/api/alexa/flash_briefings/{briefing_id}' FLASH_BRIEFINGS_API_ENDPOINT = "/api/alexa/flash_briefings/{briefing_id}"
@callback @callback
def async_setup(hass, flash_briefing_config): def async_setup(hass, flash_briefing_config):
"""Activate Alexa component.""" """Activate Alexa component."""
hass.http.register_view( hass.http.register_view(AlexaFlashBriefingView(hass, flash_briefing_config))
AlexaFlashBriefingView(hass, flash_briefing_config))
class AlexaFlashBriefingView(http.HomeAssistantView): class AlexaFlashBriefingView(http.HomeAssistantView):
"""Handle Alexa Flash Briefing skill requests.""" """Handle Alexa Flash Briefing skill requests."""
url = FLASH_BRIEFINGS_API_ENDPOINT url = FLASH_BRIEFINGS_API_ENDPOINT
name = 'api:alexa:flash_briefings' name = "api:alexa:flash_briefings"
def __init__(self, hass, flash_briefings): def __init__(self, hass, flash_briefings):
"""Initialize Alexa view.""" """Initialize Alexa view."""
@ -40,13 +49,12 @@ class AlexaFlashBriefingView(http.HomeAssistantView):
@callback @callback
def get(self, request, briefing_id): def get(self, request, briefing_id):
"""Handle Alexa Flash Briefing request.""" """Handle Alexa Flash Briefing request."""
_LOGGER.debug("Received Alexa flash briefing request for: %s", _LOGGER.debug("Received Alexa flash briefing request for: %s", briefing_id)
briefing_id)
if self.flash_briefings.get(briefing_id) is None: if self.flash_briefings.get(briefing_id) is None:
err = "No configured Alexa flash briefing was found for: %s" err = "No configured Alexa flash briefing was found for: %s"
_LOGGER.error(err, briefing_id) _LOGGER.error(err, briefing_id)
return b'', 404 return b"", 404
briefing = [] briefing = []
@ -76,10 +84,8 @@ class AlexaFlashBriefingView(http.HomeAssistantView):
output[ATTR_STREAM_URL] = item.get(CONF_AUDIO) output[ATTR_STREAM_URL] = item.get(CONF_AUDIO)
if item.get(CONF_DISPLAY_URL) is not None: if item.get(CONF_DISPLAY_URL) is not None:
if isinstance(item.get(CONF_DISPLAY_URL), if isinstance(item.get(CONF_DISPLAY_URL), template.Template):
template.Template): output[ATTR_REDIRECTION_URL] = item[CONF_DISPLAY_URL].async_render()
output[ATTR_REDIRECTION_URL] = \
item[CONF_DISPLAY_URL].async_render()
else: else:
output[ATTR_REDIRECTION_URL] = item.get(CONF_DISPLAY_URL) output[ATTR_REDIRECTION_URL] = item.get(CONF_DISPLAY_URL)

View file

@ -7,29 +7,44 @@ from homeassistant import core as ha
from homeassistant.components import cover, fan, group, light, media_player from homeassistant.components import cover, fan, group, light, media_player
from homeassistant.components.climate import const as climate from homeassistant.components.climate import const as climate
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, ATTR_SUPPORTED_FEATURES, ATTR_TEMPERATURE, SERVICE_LOCK, ATTR_ENTITY_ID,
SERVICE_MEDIA_NEXT_TRACK, SERVICE_MEDIA_PAUSE, SERVICE_MEDIA_PLAY, ATTR_SUPPORTED_FEATURES,
SERVICE_MEDIA_PREVIOUS_TRACK, SERVICE_MEDIA_STOP, ATTR_TEMPERATURE,
SERVICE_SET_COVER_POSITION, SERVICE_TURN_OFF, SERVICE_TURN_ON, SERVICE_LOCK,
SERVICE_UNLOCK, SERVICE_VOLUME_DOWN, SERVICE_VOLUME_MUTE, SERVICE_MEDIA_NEXT_TRACK,
SERVICE_VOLUME_SET, SERVICE_VOLUME_UP, TEMP_CELSIUS, TEMP_FAHRENHEIT) SERVICE_MEDIA_PAUSE,
SERVICE_MEDIA_PLAY,
SERVICE_MEDIA_PREVIOUS_TRACK,
SERVICE_MEDIA_STOP,
SERVICE_SET_COVER_POSITION,
SERVICE_TURN_OFF,
SERVICE_TURN_ON,
SERVICE_UNLOCK,
SERVICE_VOLUME_DOWN,
SERVICE_VOLUME_MUTE,
SERVICE_VOLUME_SET,
SERVICE_VOLUME_UP,
TEMP_CELSIUS,
TEMP_FAHRENHEIT,
)
import homeassistant.util.color as color_util import homeassistant.util.color as color_util
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
from homeassistant.util.temperature import convert as convert_temperature from homeassistant.util.temperature import convert as convert_temperature
from .const import ( from .const import API_TEMP_UNITS, API_THERMOSTAT_MODES, API_THERMOSTAT_PRESETS, Cause
API_TEMP_UNITS, API_THERMOSTAT_MODES, API_THERMOSTAT_PRESETS, Cause)
from .entities import async_get_entities from .entities import async_get_entities
from .errors import ( from .errors import (
AlexaInvalidValueError, AlexaTempRangeError, AlexaInvalidValueError,
AlexaUnsupportedThermostatModeError) AlexaTempRangeError,
AlexaUnsupportedThermostatModeError,
)
from .state_report import async_enable_proactive_mode from .state_report import async_enable_proactive_mode
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
HANDLERS = Registry() HANDLERS = Registry()
@HANDLERS.register(('Alexa.Discovery', 'Discover')) @HANDLERS.register(("Alexa.Discovery", "Discover"))
async def async_api_discovery(hass, config, directive, context): async def async_api_discovery(hass, config, directive, context):
"""Create a API formatted discovery response. """Create a API formatted discovery response.
@ -42,19 +57,19 @@ async def async_api_discovery(hass, config, directive, context):
] ]
return directive.response( return directive.response(
name='Discover.Response', name="Discover.Response",
namespace='Alexa.Discovery', namespace="Alexa.Discovery",
payload={'endpoints': discovery_endpoints}, payload={"endpoints": discovery_endpoints},
) )
@HANDLERS.register(('Alexa.Authorization', 'AcceptGrant')) @HANDLERS.register(("Alexa.Authorization", "AcceptGrant"))
async def async_api_accept_grant(hass, config, directive, context): async def async_api_accept_grant(hass, config, directive, context):
"""Create a API formatted AcceptGrant response. """Create a API formatted AcceptGrant response.
Async friendly. Async friendly.
""" """
auth_code = directive.payload['grant']['code'] auth_code = directive.payload["grant"]["code"]
_LOGGER.debug("AcceptGrant code: %s", auth_code) _LOGGER.debug("AcceptGrant code: %s", auth_code)
if config.supports_auth: if config.supports_auth:
@ -64,12 +79,11 @@ async def async_api_accept_grant(hass, config, directive, context):
await async_enable_proactive_mode(hass, config) await async_enable_proactive_mode(hass, config)
return directive.response( return directive.response(
name='AcceptGrant.Response', name="AcceptGrant.Response", namespace="Alexa.Authorization", payload={}
namespace='Alexa.Authorization', )
payload={})
@HANDLERS.register(('Alexa.PowerController', 'TurnOn')) @HANDLERS.register(("Alexa.PowerController", "TurnOn"))
async def async_api_turn_on(hass, config, directive, context): async def async_api_turn_on(hass, config, directive, context):
"""Process a turn on request.""" """Process a turn on request."""
entity = directive.entity entity = directive.entity
@ -82,19 +96,22 @@ async def async_api_turn_on(hass, config, directive, context):
service = cover.SERVICE_OPEN_COVER service = cover.SERVICE_OPEN_COVER
elif domain == media_player.DOMAIN: elif domain == media_player.DOMAIN:
supported = entity.attributes.get(ATTR_SUPPORTED_FEATURES, 0) supported = entity.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
power_features = (media_player.SUPPORT_TURN_ON | power_features = media_player.SUPPORT_TURN_ON | media_player.SUPPORT_TURN_OFF
media_player.SUPPORT_TURN_OFF)
if not supported & power_features: if not supported & power_features:
service = media_player.SERVICE_MEDIA_PLAY service = media_player.SERVICE_MEDIA_PLAY
await hass.services.async_call(domain, service, { await hass.services.async_call(
ATTR_ENTITY_ID: entity.entity_id domain,
}, blocking=False, context=context) service,
{ATTR_ENTITY_ID: entity.entity_id},
blocking=False,
context=context,
)
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.PowerController', 'TurnOff')) @HANDLERS.register(("Alexa.PowerController", "TurnOff"))
async def async_api_turn_off(hass, config, directive, context): async def async_api_turn_off(hass, config, directive, context):
"""Process a turn off request.""" """Process a turn off request."""
entity = directive.entity entity = directive.entity
@ -107,89 +124,104 @@ async def async_api_turn_off(hass, config, directive, context):
service = cover.SERVICE_CLOSE_COVER service = cover.SERVICE_CLOSE_COVER
elif domain == media_player.DOMAIN: elif domain == media_player.DOMAIN:
supported = entity.attributes.get(ATTR_SUPPORTED_FEATURES, 0) supported = entity.attributes.get(ATTR_SUPPORTED_FEATURES, 0)
power_features = (media_player.SUPPORT_TURN_ON | power_features = media_player.SUPPORT_TURN_ON | media_player.SUPPORT_TURN_OFF
media_player.SUPPORT_TURN_OFF)
if not supported & power_features: if not supported & power_features:
service = media_player.SERVICE_MEDIA_STOP service = media_player.SERVICE_MEDIA_STOP
await hass.services.async_call(domain, service, { await hass.services.async_call(
ATTR_ENTITY_ID: entity.entity_id domain,
}, blocking=False, context=context) service,
{ATTR_ENTITY_ID: entity.entity_id},
blocking=False,
context=context,
)
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.BrightnessController', 'SetBrightness')) @HANDLERS.register(("Alexa.BrightnessController", "SetBrightness"))
async def async_api_set_brightness(hass, config, directive, context): async def async_api_set_brightness(hass, config, directive, context):
"""Process a set brightness request.""" """Process a set brightness request."""
entity = directive.entity entity = directive.entity
brightness = int(directive.payload['brightness']) brightness = int(directive.payload["brightness"])
await hass.services.async_call(entity.domain, SERVICE_TURN_ON, { await hass.services.async_call(
ATTR_ENTITY_ID: entity.entity_id, entity.domain,
light.ATTR_BRIGHTNESS_PCT: brightness, SERVICE_TURN_ON,
}, blocking=False, context=context) {ATTR_ENTITY_ID: entity.entity_id, light.ATTR_BRIGHTNESS_PCT: brightness},
blocking=False,
context=context,
)
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.BrightnessController', 'AdjustBrightness')) @HANDLERS.register(("Alexa.BrightnessController", "AdjustBrightness"))
async def async_api_adjust_brightness(hass, config, directive, context): async def async_api_adjust_brightness(hass, config, directive, context):
"""Process an adjust brightness request.""" """Process an adjust brightness request."""
entity = directive.entity entity = directive.entity
brightness_delta = int(directive.payload['brightnessDelta']) brightness_delta = int(directive.payload["brightnessDelta"])
# read current state # read current state
try: try:
current = math.floor( current = math.floor(
int(entity.attributes.get(light.ATTR_BRIGHTNESS)) / 255 * 100) int(entity.attributes.get(light.ATTR_BRIGHTNESS)) / 255 * 100
)
except ZeroDivisionError: except ZeroDivisionError:
current = 0 current = 0
# set brightness # set brightness
brightness = max(0, brightness_delta + current) brightness = max(0, brightness_delta + current)
await hass.services.async_call(entity.domain, SERVICE_TURN_ON, { await hass.services.async_call(
ATTR_ENTITY_ID: entity.entity_id, entity.domain,
light.ATTR_BRIGHTNESS_PCT: brightness, SERVICE_TURN_ON,
}, blocking=False, context=context) {ATTR_ENTITY_ID: entity.entity_id, light.ATTR_BRIGHTNESS_PCT: brightness},
blocking=False,
context=context,
)
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.ColorController', 'SetColor')) @HANDLERS.register(("Alexa.ColorController", "SetColor"))
async def async_api_set_color(hass, config, directive, context): async def async_api_set_color(hass, config, directive, context):
"""Process a set color request.""" """Process a set color request."""
entity = directive.entity entity = directive.entity
rgb = color_util.color_hsb_to_RGB( rgb = color_util.color_hsb_to_RGB(
float(directive.payload['color']['hue']), float(directive.payload["color"]["hue"]),
float(directive.payload['color']['saturation']), float(directive.payload["color"]["saturation"]),
float(directive.payload['color']['brightness']) float(directive.payload["color"]["brightness"]),
) )
await hass.services.async_call(entity.domain, SERVICE_TURN_ON, { await hass.services.async_call(
ATTR_ENTITY_ID: entity.entity_id, entity.domain,
light.ATTR_RGB_COLOR: rgb, SERVICE_TURN_ON,
}, blocking=False, context=context) {ATTR_ENTITY_ID: entity.entity_id, light.ATTR_RGB_COLOR: rgb},
blocking=False,
context=context,
)
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.ColorTemperatureController', 'SetColorTemperature')) @HANDLERS.register(("Alexa.ColorTemperatureController", "SetColorTemperature"))
async def async_api_set_color_temperature(hass, config, directive, context): async def async_api_set_color_temperature(hass, config, directive, context):
"""Process a set color temperature request.""" """Process a set color temperature request."""
entity = directive.entity entity = directive.entity
kelvin = int(directive.payload['colorTemperatureInKelvin']) kelvin = int(directive.payload["colorTemperatureInKelvin"])
await hass.services.async_call(entity.domain, SERVICE_TURN_ON, { await hass.services.async_call(
ATTR_ENTITY_ID: entity.entity_id, entity.domain,
light.ATTR_KELVIN: kelvin, SERVICE_TURN_ON,
}, blocking=False, context=context) {ATTR_ENTITY_ID: entity.entity_id, light.ATTR_KELVIN: kelvin},
blocking=False,
context=context,
)
return directive.response() return directive.response()
@HANDLERS.register( @HANDLERS.register(("Alexa.ColorTemperatureController", "DecreaseColorTemperature"))
('Alexa.ColorTemperatureController', 'DecreaseColorTemperature'))
async def async_api_decrease_color_temp(hass, config, directive, context): async def async_api_decrease_color_temp(hass, config, directive, context):
"""Process a decrease color temperature request.""" """Process a decrease color temperature request."""
entity = directive.entity entity = directive.entity
@ -197,16 +229,18 @@ async def async_api_decrease_color_temp(hass, config, directive, context):
max_mireds = int(entity.attributes.get(light.ATTR_MAX_MIREDS)) max_mireds = int(entity.attributes.get(light.ATTR_MAX_MIREDS))
value = min(max_mireds, current + 50) value = min(max_mireds, current + 50)
await hass.services.async_call(entity.domain, SERVICE_TURN_ON, { await hass.services.async_call(
ATTR_ENTITY_ID: entity.entity_id, entity.domain,
light.ATTR_COLOR_TEMP: value, SERVICE_TURN_ON,
}, blocking=False, context=context) {ATTR_ENTITY_ID: entity.entity_id, light.ATTR_COLOR_TEMP: value},
blocking=False,
context=context,
)
return directive.response() return directive.response()
@HANDLERS.register( @HANDLERS.register(("Alexa.ColorTemperatureController", "IncreaseColorTemperature"))
('Alexa.ColorTemperatureController', 'IncreaseColorTemperature'))
async def async_api_increase_color_temp(hass, config, directive, context): async def async_api_increase_color_temp(hass, config, directive, context):
"""Process an increase color temperature request.""" """Process an increase color temperature request."""
entity = directive.entity entity = directive.entity
@ -214,63 +248,70 @@ async def async_api_increase_color_temp(hass, config, directive, context):
min_mireds = int(entity.attributes.get(light.ATTR_MIN_MIREDS)) min_mireds = int(entity.attributes.get(light.ATTR_MIN_MIREDS))
value = max(min_mireds, current - 50) value = max(min_mireds, current - 50)
await hass.services.async_call(entity.domain, SERVICE_TURN_ON, { await hass.services.async_call(
ATTR_ENTITY_ID: entity.entity_id, entity.domain,
light.ATTR_COLOR_TEMP: value, SERVICE_TURN_ON,
}, blocking=False, context=context) {ATTR_ENTITY_ID: entity.entity_id, light.ATTR_COLOR_TEMP: value},
blocking=False,
context=context,
)
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.SceneController', 'Activate')) @HANDLERS.register(("Alexa.SceneController", "Activate"))
async def async_api_activate(hass, config, directive, context): async def async_api_activate(hass, config, directive, context):
"""Process an activate request.""" """Process an activate request."""
entity = directive.entity entity = directive.entity
domain = entity.domain domain = entity.domain
await hass.services.async_call(domain, SERVICE_TURN_ON, { await hass.services.async_call(
ATTR_ENTITY_ID: entity.entity_id domain,
}, blocking=False, context=context) SERVICE_TURN_ON,
{ATTR_ENTITY_ID: entity.entity_id},
blocking=False,
context=context,
)
payload = { payload = {
'cause': {'type': Cause.VOICE_INTERACTION}, "cause": {"type": Cause.VOICE_INTERACTION},
'timestamp': '%sZ' % (datetime.utcnow().isoformat(),) "timestamp": "%sZ" % (datetime.utcnow().isoformat(),),
} }
return directive.response( return directive.response(
name='ActivationStarted', name="ActivationStarted", namespace="Alexa.SceneController", payload=payload
namespace='Alexa.SceneController',
payload=payload,
) )
@HANDLERS.register(('Alexa.SceneController', 'Deactivate')) @HANDLERS.register(("Alexa.SceneController", "Deactivate"))
async def async_api_deactivate(hass, config, directive, context): async def async_api_deactivate(hass, config, directive, context):
"""Process a deactivate request.""" """Process a deactivate request."""
entity = directive.entity entity = directive.entity
domain = entity.domain domain = entity.domain
await hass.services.async_call(domain, SERVICE_TURN_OFF, { await hass.services.async_call(
ATTR_ENTITY_ID: entity.entity_id domain,
}, blocking=False, context=context) SERVICE_TURN_OFF,
{ATTR_ENTITY_ID: entity.entity_id},
blocking=False,
context=context,
)
payload = { payload = {
'cause': {'type': Cause.VOICE_INTERACTION}, "cause": {"type": Cause.VOICE_INTERACTION},
'timestamp': '%sZ' % (datetime.utcnow().isoformat(),) "timestamp": "%sZ" % (datetime.utcnow().isoformat(),),
} }
return directive.response( return directive.response(
name='DeactivationStarted', name="DeactivationStarted", namespace="Alexa.SceneController", payload=payload
namespace='Alexa.SceneController',
payload=payload,
) )
@HANDLERS.register(('Alexa.PercentageController', 'SetPercentage')) @HANDLERS.register(("Alexa.PercentageController", "SetPercentage"))
async def async_api_set_percentage(hass, config, directive, context): async def async_api_set_percentage(hass, config, directive, context):
"""Process a set percentage request.""" """Process a set percentage request."""
entity = directive.entity entity = directive.entity
percentage = int(directive.payload['percentage']) percentage = int(directive.payload["percentage"])
service = None service = None
data = {ATTR_ENTITY_ID: entity.entity_id} data = {ATTR_ENTITY_ID: entity.entity_id}
@ -291,16 +332,17 @@ async def async_api_set_percentage(hass, config, directive, context):
data[cover.ATTR_POSITION] = percentage data[cover.ATTR_POSITION] = percentage
await hass.services.async_call( await hass.services.async_call(
entity.domain, service, data, blocking=False, context=context) entity.domain, service, data, blocking=False, context=context
)
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.PercentageController', 'AdjustPercentage')) @HANDLERS.register(("Alexa.PercentageController", "AdjustPercentage"))
async def async_api_adjust_percentage(hass, config, directive, context): async def async_api_adjust_percentage(hass, config, directive, context):
"""Process an adjust percentage request.""" """Process an adjust percentage request."""
entity = directive.entity entity = directive.entity
percentage_delta = int(directive.payload['percentageDelta']) percentage_delta = int(directive.payload["percentageDelta"])
service = None service = None
data = {ATTR_ENTITY_ID: entity.entity_id} data = {ATTR_ENTITY_ID: entity.entity_id}
@ -338,44 +380,51 @@ async def async_api_adjust_percentage(hass, config, directive, context):
data[cover.ATTR_POSITION] = max(0, percentage_delta + current) data[cover.ATTR_POSITION] = max(0, percentage_delta + current)
await hass.services.async_call( await hass.services.async_call(
entity.domain, service, data, blocking=False, context=context) entity.domain, service, data, blocking=False, context=context
)
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.LockController', 'Lock')) @HANDLERS.register(("Alexa.LockController", "Lock"))
async def async_api_lock(hass, config, directive, context): async def async_api_lock(hass, config, directive, context):
"""Process a lock request.""" """Process a lock request."""
entity = directive.entity entity = directive.entity
await hass.services.async_call(entity.domain, SERVICE_LOCK, { await hass.services.async_call(
ATTR_ENTITY_ID: entity.entity_id entity.domain,
}, blocking=False, context=context) SERVICE_LOCK,
{ATTR_ENTITY_ID: entity.entity_id},
blocking=False,
context=context,
)
response = directive.response() response = directive.response()
response.add_context_property({ response.add_context_property(
'name': 'lockState', {"name": "lockState", "namespace": "Alexa.LockController", "value": "LOCKED"}
'namespace': 'Alexa.LockController', )
'value': 'LOCKED'
})
return response return response
# Not supported by Alexa yet # Not supported by Alexa yet
@HANDLERS.register(('Alexa.LockController', 'Unlock')) @HANDLERS.register(("Alexa.LockController", "Unlock"))
async def async_api_unlock(hass, config, directive, context): async def async_api_unlock(hass, config, directive, context):
"""Process an unlock request.""" """Process an unlock request."""
entity = directive.entity entity = directive.entity
await hass.services.async_call(entity.domain, SERVICE_UNLOCK, { await hass.services.async_call(
ATTR_ENTITY_ID: entity.entity_id entity.domain,
}, blocking=False, context=context) SERVICE_UNLOCK,
{ATTR_ENTITY_ID: entity.entity_id},
blocking=False,
context=context,
)
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.Speaker', 'SetVolume')) @HANDLERS.register(("Alexa.Speaker", "SetVolume"))
async def async_api_set_volume(hass, config, directive, context): async def async_api_set_volume(hass, config, directive, context):
"""Process a set volume request.""" """Process a set volume request."""
volume = round(float(directive.payload['volume'] / 100), 2) volume = round(float(directive.payload["volume"] / 100), 2)
entity = directive.entity entity = directive.entity
data = { data = {
@ -384,31 +433,31 @@ async def async_api_set_volume(hass, config, directive, context):
} }
await hass.services.async_call( await hass.services.async_call(
entity.domain, SERVICE_VOLUME_SET, entity.domain, SERVICE_VOLUME_SET, data, blocking=False, context=context
data, blocking=False, context=context) )
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.InputController', 'SelectInput')) @HANDLERS.register(("Alexa.InputController", "SelectInput"))
async def async_api_select_input(hass, config, directive, context): async def async_api_select_input(hass, config, directive, context):
"""Process a set input request.""" """Process a set input request."""
media_input = directive.payload['input'] media_input = directive.payload["input"]
entity = directive.entity entity = directive.entity
# attempt to map the ALL UPPERCASE payload name to a source # attempt to map the ALL UPPERCASE payload name to a source
source_list = entity.attributes[ source_list = entity.attributes[media_player.const.ATTR_INPUT_SOURCE_LIST] or []
media_player.const.ATTR_INPUT_SOURCE_LIST] or []
for source in source_list: for source in source_list:
# response will always be space separated, so format the source in the # response will always be space separated, so format the source in the
# most likely way to find a match # most likely way to find a match
formatted_source = source.lower().replace('-', ' ').replace('_', ' ') formatted_source = source.lower().replace("-", " ").replace("_", " ")
if formatted_source in media_input.lower(): if formatted_source in media_input.lower():
media_input = source media_input = source
break break
else: else:
msg = 'failed to map input {} to a media source on {}'.format( msg = "failed to map input {} to a media source on {}".format(
media_input, entity.entity_id) media_input, entity.entity_id
)
raise AlexaInvalidValueError(msg) raise AlexaInvalidValueError(msg)
data = { data = {
@ -417,20 +466,23 @@ async def async_api_select_input(hass, config, directive, context):
} }
await hass.services.async_call( await hass.services.async_call(
entity.domain, media_player.SERVICE_SELECT_SOURCE, entity.domain,
data, blocking=False, context=context) media_player.SERVICE_SELECT_SOURCE,
data,
blocking=False,
context=context,
)
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.Speaker', 'AdjustVolume')) @HANDLERS.register(("Alexa.Speaker", "AdjustVolume"))
async def async_api_adjust_volume(hass, config, directive, context): async def async_api_adjust_volume(hass, config, directive, context):
"""Process an adjust volume request.""" """Process an adjust volume request."""
volume_delta = int(directive.payload['volume']) volume_delta = int(directive.payload["volume"])
entity = directive.entity entity = directive.entity
current_level = entity.attributes.get( current_level = entity.attributes.get(media_player.const.ATTR_MEDIA_VOLUME_LEVEL)
media_player.const.ATTR_MEDIA_VOLUME_LEVEL)
# read current state # read current state
try: try:
@ -446,43 +498,41 @@ async def async_api_adjust_volume(hass, config, directive, context):
} }
await hass.services.async_call( await hass.services.async_call(
entity.domain, SERVICE_VOLUME_SET, entity.domain, SERVICE_VOLUME_SET, data, blocking=False, context=context
data, blocking=False, context=context) )
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.StepSpeaker', 'AdjustVolume')) @HANDLERS.register(("Alexa.StepSpeaker", "AdjustVolume"))
async def async_api_adjust_volume_step(hass, config, directive, context): async def async_api_adjust_volume_step(hass, config, directive, context):
"""Process an adjust volume step request.""" """Process an adjust volume step request."""
# media_player volume up/down service does not support specifying steps # media_player volume up/down service does not support specifying steps
# each component handles it differently e.g. via config. # each component handles it differently e.g. via config.
# For now we use the volumeSteps returned to figure out if we # For now we use the volumeSteps returned to figure out if we
# should step up/down # should step up/down
volume_step = directive.payload['volumeSteps'] volume_step = directive.payload["volumeSteps"]
entity = directive.entity entity = directive.entity
data = { data = {ATTR_ENTITY_ID: entity.entity_id}
ATTR_ENTITY_ID: entity.entity_id,
}
if volume_step > 0: if volume_step > 0:
await hass.services.async_call( await hass.services.async_call(
entity.domain, SERVICE_VOLUME_UP, entity.domain, SERVICE_VOLUME_UP, data, blocking=False, context=context
data, blocking=False, context=context) )
elif volume_step < 0: elif volume_step < 0:
await hass.services.async_call( await hass.services.async_call(
entity.domain, SERVICE_VOLUME_DOWN, entity.domain, SERVICE_VOLUME_DOWN, data, blocking=False, context=context
data, blocking=False, context=context) )
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.StepSpeaker', 'SetMute')) @HANDLERS.register(("Alexa.StepSpeaker", "SetMute"))
@HANDLERS.register(('Alexa.Speaker', 'SetMute')) @HANDLERS.register(("Alexa.Speaker", "SetMute"))
async def async_api_set_mute(hass, config, directive, context): async def async_api_set_mute(hass, config, directive, context):
"""Process a set mute request.""" """Process a set mute request."""
mute = bool(directive.payload['mute']) mute = bool(directive.payload["mute"])
entity = directive.entity entity = directive.entity
data = { data = {
@ -491,83 +541,77 @@ async def async_api_set_mute(hass, config, directive, context):
} }
await hass.services.async_call( await hass.services.async_call(
entity.domain, SERVICE_VOLUME_MUTE, entity.domain, SERVICE_VOLUME_MUTE, data, blocking=False, context=context
data, blocking=False, context=context) )
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.PlaybackController', 'Play')) @HANDLERS.register(("Alexa.PlaybackController", "Play"))
async def async_api_play(hass, config, directive, context): async def async_api_play(hass, config, directive, context):
"""Process a play request.""" """Process a play request."""
entity = directive.entity entity = directive.entity
data = { data = {ATTR_ENTITY_ID: entity.entity_id}
ATTR_ENTITY_ID: entity.entity_id
}
await hass.services.async_call( await hass.services.async_call(
entity.domain, SERVICE_MEDIA_PLAY, entity.domain, SERVICE_MEDIA_PLAY, data, blocking=False, context=context
data, blocking=False, context=context) )
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.PlaybackController', 'Pause')) @HANDLERS.register(("Alexa.PlaybackController", "Pause"))
async def async_api_pause(hass, config, directive, context): async def async_api_pause(hass, config, directive, context):
"""Process a pause request.""" """Process a pause request."""
entity = directive.entity entity = directive.entity
data = { data = {ATTR_ENTITY_ID: entity.entity_id}
ATTR_ENTITY_ID: entity.entity_id
}
await hass.services.async_call( await hass.services.async_call(
entity.domain, SERVICE_MEDIA_PAUSE, entity.domain, SERVICE_MEDIA_PAUSE, data, blocking=False, context=context
data, blocking=False, context=context) )
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.PlaybackController', 'Stop')) @HANDLERS.register(("Alexa.PlaybackController", "Stop"))
async def async_api_stop(hass, config, directive, context): async def async_api_stop(hass, config, directive, context):
"""Process a stop request.""" """Process a stop request."""
entity = directive.entity entity = directive.entity
data = { data = {ATTR_ENTITY_ID: entity.entity_id}
ATTR_ENTITY_ID: entity.entity_id
}
await hass.services.async_call( await hass.services.async_call(
entity.domain, SERVICE_MEDIA_STOP, entity.domain, SERVICE_MEDIA_STOP, data, blocking=False, context=context
data, blocking=False, context=context) )
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.PlaybackController', 'Next')) @HANDLERS.register(("Alexa.PlaybackController", "Next"))
async def async_api_next(hass, config, directive, context): async def async_api_next(hass, config, directive, context):
"""Process a next request.""" """Process a next request."""
entity = directive.entity entity = directive.entity
data = { data = {ATTR_ENTITY_ID: entity.entity_id}
ATTR_ENTITY_ID: entity.entity_id
}
await hass.services.async_call( await hass.services.async_call(
entity.domain, SERVICE_MEDIA_NEXT_TRACK, entity.domain, SERVICE_MEDIA_NEXT_TRACK, data, blocking=False, context=context
data, blocking=False, context=context) )
return directive.response() return directive.response()
@HANDLERS.register(('Alexa.PlaybackController', 'Previous')) @HANDLERS.register(("Alexa.PlaybackController", "Previous"))
async def async_api_previous(hass, config, directive, context): async def async_api_previous(hass, config, directive, context):
"""Process a previous request.""" """Process a previous request."""
entity = directive.entity entity = directive.entity
data = { data = {ATTR_ENTITY_ID: entity.entity_id}
ATTR_ENTITY_ID: entity.entity_id
}
await hass.services.async_call( await hass.services.async_call(
entity.domain, SERVICE_MEDIA_PREVIOUS_TRACK, entity.domain,
data, blocking=False, context=context) SERVICE_MEDIA_PREVIOUS_TRACK,
data,
blocking=False,
context=context,
)
return directive.response() return directive.response()
@ -576,11 +620,11 @@ def temperature_from_object(hass, temp_obj, interval=False):
"""Get temperature from Temperature object in requested unit.""" """Get temperature from Temperature object in requested unit."""
to_unit = hass.config.units.temperature_unit to_unit = hass.config.units.temperature_unit
from_unit = TEMP_CELSIUS from_unit = TEMP_CELSIUS
temp = float(temp_obj['value']) temp = float(temp_obj["value"])
if temp_obj['scale'] == 'FAHRENHEIT': if temp_obj["scale"] == "FAHRENHEIT":
from_unit = TEMP_FAHRENHEIT from_unit = TEMP_FAHRENHEIT
elif temp_obj['scale'] == 'KELVIN': elif temp_obj["scale"] == "KELVIN":
# convert to Celsius if absolute temperature # convert to Celsius if absolute temperature
if not interval: if not interval:
temp -= 273.15 temp -= 273.15
@ -588,7 +632,7 @@ def temperature_from_object(hass, temp_obj, interval=False):
return convert_temperature(temp, from_unit, to_unit, interval) return convert_temperature(temp, from_unit, to_unit, interval)
@HANDLERS.register(('Alexa.ThermostatController', 'SetTargetTemperature')) @HANDLERS.register(("Alexa.ThermostatController", "SetTargetTemperature"))
async def async_api_set_target_temp(hass, config, directive, context): async def async_api_set_target_temp(hass, config, directive, context):
"""Process a set target temperature request.""" """Process a set target temperature request."""
entity = directive.entity entity = directive.entity
@ -596,51 +640,59 @@ async def async_api_set_target_temp(hass, config, directive, context):
max_temp = entity.attributes.get(climate.ATTR_MAX_TEMP) max_temp = entity.attributes.get(climate.ATTR_MAX_TEMP)
unit = hass.config.units.temperature_unit unit = hass.config.units.temperature_unit
data = { data = {ATTR_ENTITY_ID: entity.entity_id}
ATTR_ENTITY_ID: entity.entity_id
}
payload = directive.payload payload = directive.payload
response = directive.response() response = directive.response()
if 'targetSetpoint' in payload: if "targetSetpoint" in payload:
temp = temperature_from_object(hass, payload['targetSetpoint']) temp = temperature_from_object(hass, payload["targetSetpoint"])
if temp < min_temp or temp > max_temp: if temp < min_temp or temp > max_temp:
raise AlexaTempRangeError(hass, temp, min_temp, max_temp) raise AlexaTempRangeError(hass, temp, min_temp, max_temp)
data[ATTR_TEMPERATURE] = temp data[ATTR_TEMPERATURE] = temp
response.add_context_property({ response.add_context_property(
'name': 'targetSetpoint', {
'namespace': 'Alexa.ThermostatController', "name": "targetSetpoint",
'value': {'value': temp, 'scale': API_TEMP_UNITS[unit]}, "namespace": "Alexa.ThermostatController",
}) "value": {"value": temp, "scale": API_TEMP_UNITS[unit]},
if 'lowerSetpoint' in payload: }
temp_low = temperature_from_object(hass, payload['lowerSetpoint']) )
if "lowerSetpoint" in payload:
temp_low = temperature_from_object(hass, payload["lowerSetpoint"])
if temp_low < min_temp or temp_low > max_temp: if temp_low < min_temp or temp_low > max_temp:
raise AlexaTempRangeError(hass, temp_low, min_temp, max_temp) raise AlexaTempRangeError(hass, temp_low, min_temp, max_temp)
data[climate.ATTR_TARGET_TEMP_LOW] = temp_low data[climate.ATTR_TARGET_TEMP_LOW] = temp_low
response.add_context_property({ response.add_context_property(
'name': 'lowerSetpoint', {
'namespace': 'Alexa.ThermostatController', "name": "lowerSetpoint",
'value': {'value': temp_low, 'scale': API_TEMP_UNITS[unit]}, "namespace": "Alexa.ThermostatController",
}) "value": {"value": temp_low, "scale": API_TEMP_UNITS[unit]},
if 'upperSetpoint' in payload: }
temp_high = temperature_from_object(hass, payload['upperSetpoint']) )
if "upperSetpoint" in payload:
temp_high = temperature_from_object(hass, payload["upperSetpoint"])
if temp_high < min_temp or temp_high > max_temp: if temp_high < min_temp or temp_high > max_temp:
raise AlexaTempRangeError(hass, temp_high, min_temp, max_temp) raise AlexaTempRangeError(hass, temp_high, min_temp, max_temp)
data[climate.ATTR_TARGET_TEMP_HIGH] = temp_high data[climate.ATTR_TARGET_TEMP_HIGH] = temp_high
response.add_context_property({ response.add_context_property(
'name': 'upperSetpoint', {
'namespace': 'Alexa.ThermostatController', "name": "upperSetpoint",
'value': {'value': temp_high, 'scale': API_TEMP_UNITS[unit]}, "namespace": "Alexa.ThermostatController",
}) "value": {"value": temp_high, "scale": API_TEMP_UNITS[unit]},
}
)
await hass.services.async_call( await hass.services.async_call(
entity.domain, climate.SERVICE_SET_TEMPERATURE, data, blocking=False, entity.domain,
context=context) climate.SERVICE_SET_TEMPERATURE,
data,
blocking=False,
context=context,
)
return response return response
@HANDLERS.register(('Alexa.ThermostatController', 'AdjustTargetTemperature')) @HANDLERS.register(("Alexa.ThermostatController", "AdjustTargetTemperature"))
async def async_api_adjust_target_temp(hass, config, directive, context): async def async_api_adjust_target_temp(hass, config, directive, context):
"""Process an adjust target temperature request.""" """Process an adjust target temperature request."""
entity = directive.entity entity = directive.entity
@ -649,53 +701,50 @@ async def async_api_adjust_target_temp(hass, config, directive, context):
unit = hass.config.units.temperature_unit unit = hass.config.units.temperature_unit
temp_delta = temperature_from_object( temp_delta = temperature_from_object(
hass, directive.payload['targetSetpointDelta'], interval=True) hass, directive.payload["targetSetpointDelta"], interval=True
)
target_temp = float(entity.attributes.get(ATTR_TEMPERATURE)) + temp_delta target_temp = float(entity.attributes.get(ATTR_TEMPERATURE)) + temp_delta
if target_temp < min_temp or target_temp > max_temp: if target_temp < min_temp or target_temp > max_temp:
raise AlexaTempRangeError(hass, target_temp, min_temp, max_temp) raise AlexaTempRangeError(hass, target_temp, min_temp, max_temp)
data = { data = {ATTR_ENTITY_ID: entity.entity_id, ATTR_TEMPERATURE: target_temp}
ATTR_ENTITY_ID: entity.entity_id,
ATTR_TEMPERATURE: target_temp,
}
response = directive.response() response = directive.response()
await hass.services.async_call( await hass.services.async_call(
entity.domain, climate.SERVICE_SET_TEMPERATURE, data, blocking=False, entity.domain,
context=context) climate.SERVICE_SET_TEMPERATURE,
response.add_context_property({ data,
'name': 'targetSetpoint', blocking=False,
'namespace': 'Alexa.ThermostatController', context=context,
'value': {'value': target_temp, 'scale': API_TEMP_UNITS[unit]}, )
}) response.add_context_property(
{
"name": "targetSetpoint",
"namespace": "Alexa.ThermostatController",
"value": {"value": target_temp, "scale": API_TEMP_UNITS[unit]},
}
)
return response return response
@HANDLERS.register(('Alexa.ThermostatController', 'SetThermostatMode')) @HANDLERS.register(("Alexa.ThermostatController", "SetThermostatMode"))
async def async_api_set_thermostat_mode(hass, config, directive, context): async def async_api_set_thermostat_mode(hass, config, directive, context):
"""Process a set thermostat mode request.""" """Process a set thermostat mode request."""
entity = directive.entity entity = directive.entity
mode = directive.payload['thermostatMode'] mode = directive.payload["thermostatMode"]
mode = mode if isinstance(mode, str) else mode['value'] mode = mode if isinstance(mode, str) else mode["value"]
data = { data = {ATTR_ENTITY_ID: entity.entity_id}
ATTR_ENTITY_ID: entity.entity_id,
}
ha_preset = next( ha_preset = next((k for k, v in API_THERMOSTAT_PRESETS.items() if v == mode), None)
(k for k, v in API_THERMOSTAT_PRESETS.items() if v == mode),
None
)
if ha_preset: if ha_preset:
presets = entity.attributes.get(climate.ATTR_PRESET_MODES, []) presets = entity.attributes.get(climate.ATTR_PRESET_MODES, [])
if ha_preset not in presets: if ha_preset not in presets:
msg = 'The requested thermostat mode {} is not supported'.format( msg = "The requested thermostat mode {} is not supported".format(ha_preset)
ha_preset
)
raise AlexaUnsupportedThermostatModeError(msg) raise AlexaUnsupportedThermostatModeError(msg)
service = climate.SERVICE_SET_PRESET_MODE service = climate.SERVICE_SET_PRESET_MODE
@ -703,14 +752,9 @@ async def async_api_set_thermostat_mode(hass, config, directive, context):
else: else:
operation_list = entity.attributes.get(climate.ATTR_HVAC_MODES) operation_list = entity.attributes.get(climate.ATTR_HVAC_MODES)
ha_mode = next( ha_mode = next((k for k, v in API_THERMOSTAT_MODES.items() if v == mode), None)
(k for k, v in API_THERMOSTAT_MODES.items() if v == mode),
None
)
if ha_mode not in operation_list: if ha_mode not in operation_list:
msg = 'The requested thermostat mode {} is not supported'.format( msg = "The requested thermostat mode {} is not supported".format(mode)
mode
)
raise AlexaUnsupportedThermostatModeError(msg) raise AlexaUnsupportedThermostatModeError(msg)
service = climate.SERVICE_SET_HVAC_MODE service = climate.SERVICE_SET_HVAC_MODE
@ -718,18 +762,20 @@ async def async_api_set_thermostat_mode(hass, config, directive, context):
response = directive.response() response = directive.response()
await hass.services.async_call( await hass.services.async_call(
climate.DOMAIN, service, data, climate.DOMAIN, service, data, blocking=False, context=context
blocking=False, context=context) )
response.add_context_property({ response.add_context_property(
'name': 'thermostatMode', {
'namespace': 'Alexa.ThermostatController', "name": "thermostatMode",
'value': mode, "namespace": "Alexa.ThermostatController",
}) "value": mode,
}
)
return response return response
@HANDLERS.register(('Alexa', 'ReportState')) @HANDLERS.register(("Alexa", "ReportState"))
async def async_api_reportstate(hass, config, directive, context): async def async_api_reportstate(hass, config, directive, context):
"""Process a ReportState request.""" """Process a ReportState request."""
return directive.response(name='StateReport') return directive.response(name="StateReport")

View file

@ -14,27 +14,24 @@ _LOGGER = logging.getLogger(__name__)
HANDLERS = Registry() HANDLERS = Registry()
INTENTS_API_ENDPOINT = '/api/alexa' INTENTS_API_ENDPOINT = "/api/alexa"
class SpeechType(enum.Enum): class SpeechType(enum.Enum):
"""The Alexa speech types.""" """The Alexa speech types."""
plaintext = 'PlainText' plaintext = "PlainText"
ssml = 'SSML' ssml = "SSML"
SPEECH_MAPPINGS = { SPEECH_MAPPINGS = {"plain": SpeechType.plaintext, "ssml": SpeechType.ssml}
'plain': SpeechType.plaintext,
'ssml': SpeechType.ssml,
}
class CardType(enum.Enum): class CardType(enum.Enum):
"""The Alexa card types.""" """The Alexa card types."""
simple = 'Simple' simple = "Simple"
link_account = 'LinkAccount' link_account = "LinkAccount"
@callback @callback
@ -51,44 +48,50 @@ class AlexaIntentsView(http.HomeAssistantView):
"""Handle Alexa requests.""" """Handle Alexa requests."""
url = INTENTS_API_ENDPOINT url = INTENTS_API_ENDPOINT
name = 'api:alexa' name = "api:alexa"
async def post(self, request): async def post(self, request):
"""Handle Alexa.""" """Handle Alexa."""
hass = request.app['hass'] hass = request.app["hass"]
message = await request.json() message = await request.json()
_LOGGER.debug("Received Alexa request: %s", message) _LOGGER.debug("Received Alexa request: %s", message)
try: try:
response = await async_handle_message(hass, message) response = await async_handle_message(hass, message)
return b'' if response is None else self.json(response) return b"" if response is None else self.json(response)
except UnknownRequest as err: except UnknownRequest as err:
_LOGGER.warning(str(err)) _LOGGER.warning(str(err))
return self.json(intent_error_response( return self.json(intent_error_response(hass, message, str(err)))
hass, message, str(err)))
except intent.UnknownIntent as err: except intent.UnknownIntent as err:
_LOGGER.warning(str(err)) _LOGGER.warning(str(err))
return self.json(intent_error_response( return self.json(
hass, message, intent_error_response(
"This intent is not yet configured within Home Assistant.")) hass,
message,
"This intent is not yet configured within Home Assistant.",
)
)
except intent.InvalidSlotInfo as err: except intent.InvalidSlotInfo as err:
_LOGGER.error("Received invalid slot data from Alexa: %s", err) _LOGGER.error("Received invalid slot data from Alexa: %s", err)
return self.json(intent_error_response( return self.json(
hass, message, intent_error_response(
"Invalid slot information received for this intent.")) hass, message, "Invalid slot information received for this intent."
)
)
except intent.IntentError as err: except intent.IntentError as err:
_LOGGER.exception(str(err)) _LOGGER.exception(str(err))
return self.json(intent_error_response( return self.json(
hass, message, "Error handling intent.")) intent_error_response(hass, message, "Error handling intent.")
)
def intent_error_response(hass, message, error): def intent_error_response(hass, message, error):
"""Return an Alexa response that will speak the error message.""" """Return an Alexa response that will speak the error message."""
alexa_intent_info = message.get('request').get('intent') alexa_intent_info = message.get("request").get("intent")
alexa_response = AlexaResponse(hass, alexa_intent_info) alexa_response = AlexaResponse(hass, alexa_intent_info)
alexa_response.add_speech(SpeechType.plaintext, error) alexa_response.add_speech(SpeechType.plaintext, error)
return alexa_response.as_dict() return alexa_response.as_dict()
@ -104,25 +107,25 @@ async def async_handle_message(hass, message):
- intent.IntentError - intent.IntentError
""" """
req = message.get('request') req = message.get("request")
req_type = req['type'] req_type = req["type"]
handler = HANDLERS.get(req_type) handler = HANDLERS.get(req_type)
if not handler: if not handler:
raise UnknownRequest('Received unknown request {}'.format(req_type)) raise UnknownRequest("Received unknown request {}".format(req_type))
return await handler(hass, message) return await handler(hass, message)
@HANDLERS.register('SessionEndedRequest') @HANDLERS.register("SessionEndedRequest")
async def async_handle_session_end(hass, message): async def async_handle_session_end(hass, message):
"""Handle a session end request.""" """Handle a session end request."""
return None return None
@HANDLERS.register('IntentRequest') @HANDLERS.register("IntentRequest")
@HANDLERS.register('LaunchRequest') @HANDLERS.register("LaunchRequest")
async def async_handle_intent(hass, message): async def async_handle_intent(hass, message):
"""Handle an intent request. """Handle an intent request.
@ -132,33 +135,37 @@ async def async_handle_intent(hass, message):
- intent.IntentError - intent.IntentError
""" """
req = message.get('request') req = message.get("request")
alexa_intent_info = req.get('intent') alexa_intent_info = req.get("intent")
alexa_response = AlexaResponse(hass, alexa_intent_info) alexa_response = AlexaResponse(hass, alexa_intent_info)
if req['type'] == 'LaunchRequest': if req["type"] == "LaunchRequest":
intent_name = message.get('session', {}) \ intent_name = (
.get('application', {}) \ message.get("session", {}).get("application", {}).get("applicationId")
.get('applicationId') )
else: else:
intent_name = alexa_intent_info['name'] intent_name = alexa_intent_info["name"]
intent_response = await intent.async_handle( intent_response = await intent.async_handle(
hass, DOMAIN, intent_name, hass,
{key: {'value': value} for key, value DOMAIN,
in alexa_response.variables.items()}) intent_name,
{key: {"value": value} for key, value in alexa_response.variables.items()},
)
for intent_speech, alexa_speech in SPEECH_MAPPINGS.items(): for intent_speech, alexa_speech in SPEECH_MAPPINGS.items():
if intent_speech in intent_response.speech: if intent_speech in intent_response.speech:
alexa_response.add_speech( alexa_response.add_speech(
alexa_speech, alexa_speech, intent_response.speech[intent_speech]["speech"]
intent_response.speech[intent_speech]['speech']) )
break break
if 'simple' in intent_response.card: if "simple" in intent_response.card:
alexa_response.add_card( alexa_response.add_card(
CardType.simple, intent_response.card['simple']['title'], CardType.simple,
intent_response.card['simple']['content']) intent_response.card["simple"]["title"],
intent_response.card["simple"]["content"],
)
return alexa_response.as_dict() return alexa_response.as_dict()
@ -168,23 +175,23 @@ def resolve_slot_synonyms(key, request):
# Default to the spoken slot value if more than one or none are found. For # Default to the spoken slot value if more than one or none are found. For
# reference to the request object structure, see the Alexa docs: # reference to the request object structure, see the Alexa docs:
# https://tinyurl.com/ybvm7jhs # https://tinyurl.com/ybvm7jhs
resolved_value = request['value'] resolved_value = request["value"]
if ('resolutions' in request and if (
'resolutionsPerAuthority' in request['resolutions'] and "resolutions" in request
len(request['resolutions']['resolutionsPerAuthority']) >= 1): and "resolutionsPerAuthority" in request["resolutions"]
and len(request["resolutions"]["resolutionsPerAuthority"]) >= 1
):
# Extract all of the possible values from each authority with a # Extract all of the possible values from each authority with a
# successful match # successful match
possible_values = [] possible_values = []
for entry in request['resolutions']['resolutionsPerAuthority']: for entry in request["resolutions"]["resolutionsPerAuthority"]:
if entry['status']['code'] != SYN_RESOLUTION_MATCH: if entry["status"]["code"] != SYN_RESOLUTION_MATCH:
continue continue
possible_values.extend([item['value']['name'] possible_values.extend([item["value"]["name"] for item in entry["values"]])
for item
in entry['values']])
# If there is only one match use the resolved value, otherwise the # If there is only one match use the resolved value, otherwise the
# resolution cannot be determined, so use the spoken slot value # resolution cannot be determined, so use the spoken slot value
@ -192,9 +199,9 @@ def resolve_slot_synonyms(key, request):
resolved_value = possible_values[0] resolved_value = possible_values[0]
else: else:
_LOGGER.debug( _LOGGER.debug(
'Found multiple synonym resolutions for slot value: {%s: %s}', "Found multiple synonym resolutions for slot value: {%s: %s}",
key, key,
request['value'] request["value"],
) )
return resolved_value return resolved_value
@ -215,12 +222,12 @@ class AlexaResponse:
# Intent is None if request was a LaunchRequest or SessionEndedRequest # Intent is None if request was a LaunchRequest or SessionEndedRequest
if intent_info is not None: if intent_info is not None:
for key, value in intent_info.get('slots', {}).items(): for key, value in intent_info.get("slots", {}).items():
# Only include slots with values # Only include slots with values
if 'value' not in value: if "value" not in value:
continue continue
_key = key.replace('.', '_') _key = key.replace(".", "_")
self.variables[_key] = resolve_slot_synonyms(key, value) self.variables[_key] = resolve_slot_synonyms(key, value)
@ -228,9 +235,7 @@ class AlexaResponse:
"""Add a card to the response.""" """Add a card to the response."""
assert self.card is None assert self.card is None
card = { card = {"type": card_type.value}
"type": card_type.value
}
if card_type == CardType.link_account: if card_type == CardType.link_account:
self.card = card self.card = card
@ -244,43 +249,36 @@ class AlexaResponse:
"""Add speech to the response.""" """Add speech to the response."""
assert self.speech is None assert self.speech is None
key = 'ssml' if speech_type == SpeechType.ssml else 'text' key = "ssml" if speech_type == SpeechType.ssml else "text"
self.speech = { self.speech = {"type": speech_type.value, key: text}
'type': speech_type.value,
key: text
}
def add_reprompt(self, speech_type, text): def add_reprompt(self, speech_type, text):
"""Add reprompt if user does not answer.""" """Add reprompt if user does not answer."""
assert self.reprompt is None assert self.reprompt is None
key = 'ssml' if speech_type == SpeechType.ssml else 'text' key = "ssml" if speech_type == SpeechType.ssml else "text"
self.reprompt = { self.reprompt = {
'type': speech_type.value, "type": speech_type.value,
key: text.async_render(self.variables) key: text.async_render(self.variables),
} }
def as_dict(self): def as_dict(self):
"""Return response in an Alexa valid dict.""" """Return response in an Alexa valid dict."""
response = { response = {"shouldEndSession": self.should_end_session}
'shouldEndSession': self.should_end_session
}
if self.card is not None: if self.card is not None:
response['card'] = self.card response["card"] = self.card
if self.speech is not None: if self.speech is not None:
response['outputSpeech'] = self.speech response["outputSpeech"] = self.speech
if self.reprompt is not None: if self.reprompt is not None:
response['reprompt'] = { response["reprompt"] = {"outputSpeech": self.reprompt}
'outputSpeech': self.reprompt
}
return { return {
'version': '1.0', "version": "1.0",
'sessionAttributes': self.session_attributes, "sessionAttributes": self.session_attributes,
'response': response, "response": response,
} }

View file

@ -23,8 +23,8 @@ class AlexaDirective:
def __init__(self, request): def __init__(self, request):
"""Initialize a directive.""" """Initialize a directive."""
self._directive = request[API_DIRECTIVE] self._directive = request[API_DIRECTIVE]
self.namespace = self._directive[API_HEADER]['namespace'] self.namespace = self._directive[API_HEADER]["namespace"]
self.name = self._directive[API_HEADER]['name'] self.name = self._directive[API_HEADER]["name"]
self.payload = self._directive[API_PAYLOAD] self.payload = self._directive[API_PAYLOAD]
self.has_endpoint = API_ENDPOINT in self._directive self.has_endpoint = API_ENDPOINT in self._directive
@ -44,27 +44,23 @@ class AlexaDirective:
Will raise AlexaInvalidEndpointError if the endpoint in the request is Will raise AlexaInvalidEndpointError if the endpoint in the request is
malformed or nonexistant. malformed or nonexistant.
""" """
_endpoint_id = self._directive[API_ENDPOINT]['endpointId'] _endpoint_id = self._directive[API_ENDPOINT]["endpointId"]
self.entity_id = _endpoint_id.replace('#', '.') self.entity_id = _endpoint_id.replace("#", ".")
self.entity = hass.states.get(self.entity_id) self.entity = hass.states.get(self.entity_id)
if not self.entity or not config.should_expose(self.entity_id): if not self.entity or not config.should_expose(self.entity_id):
raise AlexaInvalidEndpointError(_endpoint_id) raise AlexaInvalidEndpointError(_endpoint_id)
self.endpoint = ENTITY_ADAPTERS[self.entity.domain]( self.endpoint = ENTITY_ADAPTERS[self.entity.domain](hass, config, self.entity)
hass, config, self.entity)
def response(self, def response(self, name="Response", namespace="Alexa", payload=None):
name='Response',
namespace='Alexa',
payload=None):
"""Create an API formatted response. """Create an API formatted response.
Async friendly. Async friendly.
""" """
response = AlexaResponse(name, namespace, payload) response = AlexaResponse(name, namespace, payload)
token = self._directive[API_HEADER].get('correlationToken') token = self._directive[API_HEADER].get("correlationToken")
if token: if token:
response.set_correlation_token(token) response.set_correlation_token(token)
@ -74,31 +70,30 @@ class AlexaDirective:
return response return response
def error( def error(
self, self,
namespace='Alexa', namespace="Alexa",
error_type='INTERNAL_ERROR', error_type="INTERNAL_ERROR",
error_message="", error_message="",
payload=None payload=None,
): ):
"""Create a API formatted error response. """Create a API formatted error response.
Async friendly. Async friendly.
""" """
payload = payload or {} payload = payload or {}
payload['type'] = error_type payload["type"] = error_type
payload['message'] = error_message payload["message"] = error_message
_LOGGER.info("Request %s/%s error %s: %s", _LOGGER.info(
self._directive[API_HEADER]['namespace'], "Request %s/%s error %s: %s",
self._directive[API_HEADER]['name'], self._directive[API_HEADER]["namespace"],
error_type, error_message) self._directive[API_HEADER]["name"],
error_type,
return self.response( error_message,
name='ErrorResponse',
namespace=namespace,
payload=payload
) )
return self.response(name="ErrorResponse", namespace=namespace, payload=payload)
class AlexaResponse: class AlexaResponse:
"""Class to hold a response.""" """Class to hold a response."""
@ -109,10 +104,10 @@ class AlexaResponse:
self._response = { self._response = {
API_EVENT: { API_EVENT: {
API_HEADER: { API_HEADER: {
'namespace': namespace, "namespace": namespace,
'name': name, "name": name,
'messageId': str(uuid4()), "messageId": str(uuid4()),
'payloadVersion': '3', "payloadVersion": "3",
}, },
API_PAYLOAD: payload, API_PAYLOAD: payload,
} }
@ -121,12 +116,12 @@ class AlexaResponse:
@property @property
def name(self): def name(self):
"""Return the name of this response.""" """Return the name of this response."""
return self._response[API_EVENT][API_HEADER]['name'] return self._response[API_EVENT][API_HEADER]["name"]
@property @property
def namespace(self): def namespace(self):
"""Return the namespace of this response.""" """Return the namespace of this response."""
return self._response[API_EVENT][API_HEADER]['namespace'] return self._response[API_EVENT][API_HEADER]["namespace"]
def set_correlation_token(self, token): def set_correlation_token(self, token):
"""Set the correlationToken. """Set the correlationToken.
@ -134,7 +129,7 @@ class AlexaResponse:
This should normally mirror the value from a request, and is set by This should normally mirror the value from a request, and is set by
AlexaDirective.response() usually. AlexaDirective.response() usually.
""" """
self._response[API_EVENT][API_HEADER]['correlationToken'] = token self._response[API_EVENT][API_HEADER]["correlationToken"] = token
def set_endpoint_full(self, bearer_token, endpoint_id, cookie=None): def set_endpoint_full(self, bearer_token, endpoint_id, cookie=None):
"""Set the endpoint dictionary. """Set the endpoint dictionary.
@ -142,17 +137,14 @@ class AlexaResponse:
This is used to send proactive messages to Alexa. This is used to send proactive messages to Alexa.
""" """
self._response[API_EVENT][API_ENDPOINT] = { self._response[API_EVENT][API_ENDPOINT] = {
API_SCOPE: { API_SCOPE: {"type": "BearerToken", "token": bearer_token}
'type': 'BearerToken',
'token': bearer_token
}
} }
if endpoint_id is not None: if endpoint_id is not None:
self._response[API_EVENT][API_ENDPOINT]['endpointId'] = endpoint_id self._response[API_EVENT][API_ENDPOINT]["endpointId"] = endpoint_id
if cookie is not None: if cookie is not None:
self._response[API_EVENT][API_ENDPOINT]['cookie'] = cookie self._response[API_EVENT][API_ENDPOINT]["cookie"] = cookie
def set_endpoint(self, endpoint): def set_endpoint(self, endpoint):
"""Set the endpoint. """Set the endpoint.
@ -164,7 +156,7 @@ class AlexaResponse:
def _properties(self): def _properties(self):
context = self._response.setdefault(API_CONTEXT, {}) context = self._response.setdefault(API_CONTEXT, {})
return context.setdefault('properties', []) return context.setdefault("properties", [])
def add_context_property(self, prop): def add_context_property(self, prop):
"""Add a property to the response context. """Add a property to the response context.
@ -189,10 +181,10 @@ class AlexaResponse:
Handlers should be using .add_context_property(). Handlers should be using .add_context_property().
""" """
properties = self._properties() properties = self._properties()
already_set = {(p['namespace'], p['name']) for p in properties} already_set = {(p["namespace"], p["name"]) for p in properties}
for prop in endpoint.serialize_properties(): for prop in endpoint.serialize_properties():
if (prop['namespace'], prop['name']) not in already_set: if (prop["namespace"], prop["name"]) not in already_set:
self.add_context_property(prop) self.add_context_property(prop)
def serialize(self): def serialize(self):

View file

@ -4,32 +4,23 @@ import logging
import homeassistant.core as ha import homeassistant.core as ha
from .const import API_DIRECTIVE, API_HEADER from .const import API_DIRECTIVE, API_HEADER
from .errors import ( from .errors import AlexaError, AlexaBridgeUnreachableError
AlexaError,
AlexaBridgeUnreachableError,
)
from .handlers import HANDLERS from .handlers import HANDLERS
from .messages import AlexaDirective from .messages import AlexaDirective
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
EVENT_ALEXA_SMART_HOME = 'alexa_smart_home' EVENT_ALEXA_SMART_HOME = "alexa_smart_home"
async def async_handle_message( async def async_handle_message(hass, config, request, context=None, enabled=True):
hass,
config,
request,
context=None,
enabled=True,
):
"""Handle incoming API messages. """Handle incoming API messages.
If enabled is False, the response to all messagess will be a If enabled is False, the response to all messagess will be a
BRIDGE_UNREACHABLE error. This can be used if the API has been disabled in BRIDGE_UNREACHABLE error. This can be used if the API has been disabled in
configuration. configuration.
""" """
assert request[API_DIRECTIVE][API_HEADER]['payloadVersion'] == '3' assert request[API_DIRECTIVE][API_HEADER]["payloadVersion"] == "3"
if context is None: if context is None:
context = ha.Context() context = ha.Context()
@ -39,7 +30,8 @@ async def async_handle_message(
try: try:
if not enabled: if not enabled:
raise AlexaBridgeUnreachableError( raise AlexaBridgeUnreachableError(
'Alexa API not enabled in Home Assistant configuration') "Alexa API not enabled in Home Assistant configuration"
)
if directive.has_endpoint: if directive.has_endpoint:
directive.load_entity(hass, config) directive.load_entity(hass, config)
@ -51,30 +43,26 @@ async def async_handle_message(
response.merge_context_properties(directive.endpoint) response.merge_context_properties(directive.endpoint)
else: else:
_LOGGER.warning( _LOGGER.warning(
"Unsupported API request %s/%s", "Unsupported API request %s/%s", directive.namespace, directive.name
directive.namespace,
directive.name,
) )
response = directive.error() response = directive.error()
except AlexaError as err: except AlexaError as err:
response = directive.error( response = directive.error(
error_type=err.error_type, error_type=err.error_type, error_message=err.error_message
error_message=err.error_message) )
request_info = { request_info = {"namespace": directive.namespace, "name": directive.name}
'namespace': directive.namespace,
'name': directive.name,
}
if directive.has_endpoint: if directive.has_endpoint:
request_info['entity_id'] = directive.entity_id request_info["entity_id"] = directive.entity_id
hass.bus.async_fire(EVENT_ALEXA_SMART_HOME, { hass.bus.async_fire(
'request': request_info, EVENT_ALEXA_SMART_HOME,
'response': { {
'namespace': response.namespace, "request": request_info,
'name': response.name, "response": {"namespace": response.namespace, "name": response.name},
} },
}, context=context) context=context,
)
return response.serialize() return response.serialize()

View file

@ -11,13 +11,13 @@ from .const import (
CONF_CLIENT_SECRET, CONF_CLIENT_SECRET,
CONF_ENDPOINT, CONF_ENDPOINT,
CONF_ENTITY_CONFIG, CONF_ENTITY_CONFIG,
CONF_FILTER CONF_FILTER,
) )
from .state_report import async_enable_proactive_mode from .state_report import async_enable_proactive_mode
from .smart_home import async_handle_message from .smart_home import async_handle_message
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SMART_HOME_HTTP_ENDPOINT = '/api/alexa/smart_home' SMART_HOME_HTTP_ENDPOINT = "/api/alexa/smart_home"
class AlexaConfig(AbstractConfig): class AlexaConfig(AbstractConfig):
@ -29,8 +29,7 @@ class AlexaConfig(AbstractConfig):
self._config = config self._config = config
if config.get(CONF_CLIENT_ID) and config.get(CONF_CLIENT_SECRET): if config.get(CONF_CLIENT_ID) and config.get(CONF_CLIENT_SECRET):
self._auth = Auth(hass, config[CONF_CLIENT_ID], self._auth = Auth(hass, config[CONF_CLIENT_ID], config[CONF_CLIENT_SECRET])
config[CONF_CLIENT_SECRET])
else: else:
self._auth = None self._auth = None
@ -87,7 +86,7 @@ class SmartHomeView(HomeAssistantView):
"""Expose Smart Home v3 payload interface via HTTP POST.""" """Expose Smart Home v3 payload interface via HTTP POST."""
url = SMART_HOME_HTTP_ENDPOINT url = SMART_HOME_HTTP_ENDPOINT
name = 'api:alexa:smart_home' name = "api:alexa:smart_home"
def __init__(self, smart_home_config): def __init__(self, smart_home_config):
"""Initialize.""" """Initialize."""
@ -100,15 +99,14 @@ class SmartHomeView(HomeAssistantView):
Lambda, which will need to forward the requests to here and pass back Lambda, which will need to forward the requests to here and pass back
the response. the response.
""" """
hass = request.app['hass'] hass = request.app["hass"]
user = request['hass_user'] user = request["hass_user"]
message = await request.json() message = await request.json()
_LOGGER.debug("Received Alexa Smart Home request: %s", message) _LOGGER.debug("Received Alexa Smart Home request: %s", message)
response = await async_handle_message( response = await async_handle_message(
hass, self.smart_home_config, message, hass, self.smart_home_config, message, context=core.Context(user_id=user.id)
context=core.Context(user_id=user.id)
) )
_LOGGER.debug("Sending Alexa Smart Home response: %s", response) _LOGGER.debug("Sending Alexa Smart Home response: %s", response)
return b'' if response is None else self.json(response) return b"" if response is None else self.json(response)

View file

@ -24,8 +24,7 @@ async def async_enable_proactive_mode(hass, smart_home_config):
# Validate we can get access token. # Validate we can get access token.
await smart_home_config.async_get_access_token() await smart_home_config.async_get_access_token()
async def async_entity_state_listener(changed_entity, old_state, async def async_entity_state_listener(changed_entity, old_state, new_state):
new_state):
if not new_state: if not new_state:
return return
@ -33,18 +32,18 @@ async def async_enable_proactive_mode(hass, smart_home_config):
return return
if not smart_home_config.should_expose(changed_entity): if not smart_home_config.should_expose(changed_entity):
_LOGGER.debug("Not exposing %s because filtered by config", _LOGGER.debug("Not exposing %s because filtered by config", changed_entity)
changed_entity)
return return
alexa_changed_entity = \ alexa_changed_entity = ENTITY_ADAPTERS[new_state.domain](
ENTITY_ADAPTERS[new_state.domain](hass, smart_home_config, hass, smart_home_config, new_state
new_state) )
for interface in alexa_changed_entity.interfaces(): for interface in alexa_changed_entity.interfaces():
if interface.properties_proactively_reported(): if interface.properties_proactively_reported():
await async_send_changereport_message(hass, smart_home_config, await async_send_changereport_message(
alexa_changed_entity) hass, smart_home_config, alexa_changed_entity
)
return return
return hass.helpers.event.async_track_state_change( return hass.helpers.event.async_track_state_change(
@ -59,9 +58,7 @@ async def async_send_changereport_message(hass, config, alexa_entity):
""" """
token = await config.async_get_access_token() token = await config.async_get_access_token()
headers = { headers = {"Authorization": "Bearer {}".format(token)}
"Authorization": "Bearer {}".format(token)
}
endpoint = alexa_entity.alexa_id() endpoint = alexa_entity.alexa_id()
@ -71,14 +68,10 @@ async def async_send_changereport_message(hass, config, alexa_entity):
properties = list(alexa_entity.serialize_properties()) properties = list(alexa_entity.serialize_properties())
payload = { payload = {
API_CHANGE: { API_CHANGE: {"cause": {"type": Cause.APP_INTERACTION}, "properties": properties}
'cause': {'type': Cause.APP_INTERACTION},
'properties': properties
}
} }
message = AlexaResponse(name='ChangeReport', namespace='Alexa', message = AlexaResponse(name="ChangeReport", namespace="Alexa", payload=payload)
payload=payload)
message.set_endpoint_full(token, endpoint) message.set_endpoint_full(token, endpoint)
message_serialized = message.serialize() message_serialized = message.serialize()
@ -86,10 +79,12 @@ async def async_send_changereport_message(hass, config, alexa_entity):
try: try:
with async_timeout.timeout(DEFAULT_TIMEOUT): with async_timeout.timeout(DEFAULT_TIMEOUT):
response = await session.post(config.endpoint, response = await session.post(
headers=headers, config.endpoint,
json=message_serialized, headers=headers,
allow_redirects=True) json=message_serialized,
allow_redirects=True,
)
except (asyncio.TimeoutError, aiohttp.ClientError): except (asyncio.TimeoutError, aiohttp.ClientError):
_LOGGER.error("Timeout sending report to Alexa.") _LOGGER.error("Timeout sending report to Alexa.")
@ -102,9 +97,11 @@ async def async_send_changereport_message(hass, config, alexa_entity):
if response.status != 202: if response.status != 202:
response_json = json.loads(response_text) response_json = json.loads(response_text)
_LOGGER.error("Error when sending ChangeReport to Alexa: %s: %s", _LOGGER.error(
response_json["payload"]["code"], "Error when sending ChangeReport to Alexa: %s: %s",
response_json["payload"]["description"]) response_json["payload"]["code"],
response_json["payload"]["description"],
)
async def async_send_add_or_update_message(hass, config, entity_ids): async def async_send_add_or_update_message(hass, config, entity_ids):
@ -114,35 +111,27 @@ async def async_send_add_or_update_message(hass, config, entity_ids):
""" """
token = await config.async_get_access_token() token = await config.async_get_access_token()
headers = { headers = {"Authorization": "Bearer {}".format(token)}
"Authorization": "Bearer {}".format(token)
}
endpoints = [] endpoints = []
for entity_id in entity_ids: for entity_id in entity_ids:
domain = entity_id.split('.', 1)[0] domain = entity_id.split(".", 1)[0]
alexa_entity = ENTITY_ADAPTERS[domain]( alexa_entity = ENTITY_ADAPTERS[domain](hass, config, hass.states.get(entity_id))
hass, config, hass.states.get(entity_id)
)
endpoints.append(alexa_entity.serialize_discovery()) endpoints.append(alexa_entity.serialize_discovery())
payload = { payload = {"endpoints": endpoints, "scope": {"type": "BearerToken", "token": token}}
'endpoints': endpoints,
'scope': {
'type': 'BearerToken',
'token': token,
}
}
message = AlexaResponse( message = AlexaResponse(
name='AddOrUpdateReport', namespace='Alexa.Discovery', payload=payload) name="AddOrUpdateReport", namespace="Alexa.Discovery", payload=payload
)
message_serialized = message.serialize() message_serialized = message.serialize()
session = hass.helpers.aiohttp_client.async_get_clientsession() session = hass.helpers.aiohttp_client.async_get_clientsession()
return await session.post(config.endpoint, headers=headers, return await session.post(
json=message_serialized, allow_redirects=True) config.endpoint, headers=headers, json=message_serialized, allow_redirects=True
)
async def async_send_delete_message(hass, config, entity_ids): async def async_send_delete_message(hass, config, entity_ids):
@ -152,34 +141,24 @@ async def async_send_delete_message(hass, config, entity_ids):
""" """
token = await config.async_get_access_token() token = await config.async_get_access_token()
headers = { headers = {"Authorization": "Bearer {}".format(token)}
"Authorization": "Bearer {}".format(token)
}
endpoints = [] endpoints = []
for entity_id in entity_ids: for entity_id in entity_ids:
domain = entity_id.split('.', 1)[0] domain = entity_id.split(".", 1)[0]
alexa_entity = ENTITY_ADAPTERS[domain]( alexa_entity = ENTITY_ADAPTERS[domain](hass, config, hass.states.get(entity_id))
hass, config, hass.states.get(entity_id) endpoints.append({"endpointId": alexa_entity.alexa_id()})
)
endpoints.append({
'endpointId': alexa_entity.alexa_id()
})
payload = { payload = {"endpoints": endpoints, "scope": {"type": "BearerToken", "token": token}}
'endpoints': endpoints,
'scope': {
'type': 'BearerToken',
'token': token,
}
}
message = AlexaResponse(name='DeleteReport', namespace='Alexa.Discovery', message = AlexaResponse(
payload=payload) name="DeleteReport", namespace="Alexa.Discovery", payload=payload
)
message_serialized = message.serialize() message_serialized = message.serialize()
session = hass.helpers.aiohttp_client.async_get_clientsession() session = hass.helpers.aiohttp_client.async_get_clientsession()
return await session.post(config.endpoint, headers=headers, return await session.post(
json=message_serialized, allow_redirects=True) config.endpoint, headers=headers, json=message_serialized, allow_redirects=True
)

View file

@ -5,56 +5,59 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components.sensor import PLATFORM_SCHEMA from homeassistant.components.sensor import PLATFORM_SCHEMA
from homeassistant.const import ( from homeassistant.const import ATTR_ATTRIBUTION, CONF_API_KEY, CONF_CURRENCY, CONF_NAME
ATTR_ATTRIBUTION, CONF_API_KEY, CONF_CURRENCY, CONF_NAME)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ATTR_CLOSE = 'close' ATTR_CLOSE = "close"
ATTR_HIGH = 'high' ATTR_HIGH = "high"
ATTR_LOW = 'low' ATTR_LOW = "low"
ATTRIBUTION = "Stock market information provided by Alpha Vantage" ATTRIBUTION = "Stock market information provided by Alpha Vantage"
CONF_FOREIGN_EXCHANGE = 'foreign_exchange' CONF_FOREIGN_EXCHANGE = "foreign_exchange"
CONF_FROM = 'from' CONF_FROM = "from"
CONF_SYMBOL = 'symbol' CONF_SYMBOL = "symbol"
CONF_SYMBOLS = 'symbols' CONF_SYMBOLS = "symbols"
CONF_TO = 'to' CONF_TO = "to"
ICONS = { ICONS = {
'BTC': 'mdi:currency-btc', "BTC": "mdi:currency-btc",
'EUR': 'mdi:currency-eur', "EUR": "mdi:currency-eur",
'GBP': 'mdi:currency-gbp', "GBP": "mdi:currency-gbp",
'INR': 'mdi:currency-inr', "INR": "mdi:currency-inr",
'RUB': 'mdi:currency-rub', "RUB": "mdi:currency-rub",
'TRY': 'mdi:currency-try', "TRY": "mdi:currency-try",
'USD': 'mdi:currency-usd', "USD": "mdi:currency-usd",
} }
SCAN_INTERVAL = timedelta(minutes=5) SCAN_INTERVAL = timedelta(minutes=5)
SYMBOL_SCHEMA = vol.Schema({ SYMBOL_SCHEMA = vol.Schema(
vol.Required(CONF_SYMBOL): cv.string, {
vol.Optional(CONF_CURRENCY): cv.string, vol.Required(CONF_SYMBOL): cv.string,
vol.Optional(CONF_NAME): cv.string, vol.Optional(CONF_CURRENCY): cv.string,
}) vol.Optional(CONF_NAME): cv.string,
}
)
CURRENCY_SCHEMA = vol.Schema({ CURRENCY_SCHEMA = vol.Schema(
vol.Required(CONF_FROM): cv.string, {
vol.Required(CONF_TO): cv.string, vol.Required(CONF_FROM): cv.string,
vol.Optional(CONF_NAME): cv.string, vol.Required(CONF_TO): cv.string,
}) vol.Optional(CONF_NAME): cv.string,
}
)
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_API_KEY): cv.string, {
vol.Optional(CONF_FOREIGN_EXCHANGE): vol.Required(CONF_API_KEY): cv.string,
vol.All(cv.ensure_list, [CURRENCY_SCHEMA]), vol.Optional(CONF_FOREIGN_EXCHANGE): vol.All(cv.ensure_list, [CURRENCY_SCHEMA]),
vol.Optional(CONF_SYMBOLS): vol.Optional(CONF_SYMBOLS): vol.All(cv.ensure_list, [SYMBOL_SCHEMA]),
vol.All(cv.ensure_list, [SYMBOL_SCHEMA]), }
}) )
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -67,9 +70,8 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
conversions = config.get(CONF_FOREIGN_EXCHANGE, []) conversions = config.get(CONF_FOREIGN_EXCHANGE, [])
if not symbols and not conversions: if not symbols and not conversions:
msg = 'Warning: No symbols or currencies configured.' msg = "Warning: No symbols or currencies configured."
hass.components.persistent_notification.create( hass.components.persistent_notification.create(msg, "Sensor alpha_vantage")
msg, 'Sensor alpha_vantage')
_LOGGER.warning(msg) _LOGGER.warning(msg)
return return
@ -78,12 +80,10 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
dev = [] dev = []
for symbol in symbols: for symbol in symbols:
try: try:
_LOGGER.debug("Configuring timeseries for symbols: %s", _LOGGER.debug("Configuring timeseries for symbols: %s", symbol[CONF_SYMBOL])
symbol[CONF_SYMBOL])
timeseries.get_intraday(symbol[CONF_SYMBOL]) timeseries.get_intraday(symbol[CONF_SYMBOL])
except ValueError: except ValueError:
_LOGGER.error( _LOGGER.error("API Key is not valid or symbol '%s' not known", symbol)
"API Key is not valid or symbol '%s' not known", symbol)
dev.append(AlphaVantageSensor(timeseries, symbol)) dev.append(AlphaVantageSensor(timeseries, symbol))
forex = ForeignExchange(key=api_key) forex = ForeignExchange(key=api_key)
@ -92,12 +92,13 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
to_cur = conversion.get(CONF_TO) to_cur = conversion.get(CONF_TO)
try: try:
_LOGGER.debug("Configuring forex %s - %s", from_cur, to_cur) _LOGGER.debug("Configuring forex %s - %s", from_cur, to_cur)
forex.get_currency_exchange_rate( forex.get_currency_exchange_rate(from_currency=from_cur, to_currency=to_cur)
from_currency=from_cur, to_currency=to_cur)
except ValueError as error: except ValueError as error:
_LOGGER.error( _LOGGER.error(
"API Key is not valid or currencies '%s'/'%s' not known", "API Key is not valid or currencies '%s'/'%s' not known",
from_cur, to_cur) from_cur,
to_cur,
)
_LOGGER.debug(str(error)) _LOGGER.debug(str(error))
dev.append(AlphaVantageForeignExchange(forex, conversion)) dev.append(AlphaVantageForeignExchange(forex, conversion))
@ -115,7 +116,7 @@ class AlphaVantageSensor(Entity):
self._timeseries = timeseries self._timeseries = timeseries
self.values = None self.values = None
self._unit_of_measurement = symbol.get(CONF_CURRENCY, self._symbol) self._unit_of_measurement = symbol.get(CONF_CURRENCY, self._symbol)
self._icon = ICONS.get(symbol.get(CONF_CURRENCY, 'USD')) self._icon = ICONS.get(symbol.get(CONF_CURRENCY, "USD"))
@property @property
def name(self): def name(self):
@ -130,7 +131,7 @@ class AlphaVantageSensor(Entity):
@property @property
def state(self): def state(self):
"""Return the state of the sensor.""" """Return the state of the sensor."""
return self.values['1. open'] return self.values["1. open"]
@property @property
def device_state_attributes(self): def device_state_attributes(self):
@ -138,9 +139,9 @@ class AlphaVantageSensor(Entity):
if self.values is not None: if self.values is not None:
return { return {
ATTR_ATTRIBUTION: ATTRIBUTION, ATTR_ATTRIBUTION: ATTRIBUTION,
ATTR_CLOSE: self.values['4. close'], ATTR_CLOSE: self.values["4. close"],
ATTR_HIGH: self.values['2. high'], ATTR_HIGH: self.values["2. high"],
ATTR_LOW: self.values['3. low'], ATTR_LOW: self.values["3. low"],
} }
@property @property
@ -167,9 +168,9 @@ class AlphaVantageForeignExchange(Entity):
if CONF_NAME in config: if CONF_NAME in config:
self._name = config.get(CONF_NAME) self._name = config.get(CONF_NAME)
else: else:
self._name = '{}/{}'.format(self._to_currency, self._from_currency) self._name = "{}/{}".format(self._to_currency, self._from_currency)
self._unit_of_measurement = self._to_currency self._unit_of_measurement = self._to_currency
self._icon = ICONS.get(self._from_currency, 'USD') self._icon = ICONS.get(self._from_currency, "USD")
self.values = None self.values = None
@property @property
@ -185,7 +186,7 @@ class AlphaVantageForeignExchange(Entity):
@property @property
def state(self): def state(self):
"""Return the state of the sensor.""" """Return the state of the sensor."""
return round(float(self.values['5. Exchange Rate']), 4) return round(float(self.values["5. Exchange Rate"]), 4)
@property @property
def icon(self): def icon(self):
@ -204,9 +205,16 @@ class AlphaVantageForeignExchange(Entity):
def update(self): def update(self):
"""Get the latest data and updates the states.""" """Get the latest data and updates the states."""
_LOGGER.debug("Requesting new data for forex %s - %s", _LOGGER.debug(
self._from_currency, self._to_currency) "Requesting new data for forex %s - %s",
self._from_currency,
self._to_currency,
)
self.values, _ = self._foreign_exchange.get_currency_exchange_rate( self.values, _ = self._foreign_exchange.get_currency_exchange_rate(
from_currency=self._from_currency, to_currency=self._to_currency) from_currency=self._from_currency, to_currency=self._to_currency
_LOGGER.debug("Received new data for forex %s - %s", )
self._from_currency, self._to_currency) _LOGGER.debug(
"Received new data for forex %s - %s",
self._from_currency,
self._to_currency,
)

View file

@ -8,108 +8,145 @@ import homeassistant.helpers.config_validation as cv
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
CONF_REGION = 'region_name' CONF_REGION = "region_name"
CONF_ACCESS_KEY_ID = 'aws_access_key_id' CONF_ACCESS_KEY_ID = "aws_access_key_id"
CONF_SECRET_ACCESS_KEY = 'aws_secret_access_key' CONF_SECRET_ACCESS_KEY = "aws_secret_access_key"
CONF_PROFILE_NAME = 'profile_name' CONF_PROFILE_NAME = "profile_name"
ATTR_CREDENTIALS = 'credentials' ATTR_CREDENTIALS = "credentials"
DEFAULT_REGION = 'us-east-1' DEFAULT_REGION = "us-east-1"
SUPPORTED_REGIONS = ['us-east-1', 'us-east-2', 'us-west-1', 'us-west-2', SUPPORTED_REGIONS = [
'ca-central-1', 'eu-west-1', 'eu-central-1', 'eu-west-2', "us-east-1",
'eu-west-3', 'ap-southeast-1', 'ap-southeast-2', "us-east-2",
'ap-northeast-2', 'ap-northeast-1', 'ap-south-1', "us-west-1",
'sa-east-1'] "us-west-2",
"ca-central-1",
CONF_VOICE = 'voice' "eu-west-1",
CONF_OUTPUT_FORMAT = 'output_format' "eu-central-1",
CONF_SAMPLE_RATE = 'sample_rate' "eu-west-2",
CONF_TEXT_TYPE = 'text_type' "eu-west-3",
"ap-southeast-1",
SUPPORTED_VOICES = [ "ap-southeast-2",
'Zhiyu', # Chinese "ap-northeast-2",
'Mads', 'Naja', # Danish "ap-northeast-1",
'Ruben', 'Lotte', # Dutch "ap-south-1",
'Russell', 'Nicole', # English Austrailian "sa-east-1",
'Brian', 'Amy', 'Emma', # English
'Aditi', 'Raveena', # English, Indian
'Joey', 'Justin', 'Matthew', 'Ivy', 'Joanna', 'Kendra', 'Kimberly',
'Salli', # English
'Geraint', # English Welsh
'Mathieu', 'Celine', 'Lea', # French
'Chantal', # French Canadian
'Hans', 'Marlene', 'Vicki', # German
'Aditi', # Hindi
'Karl', 'Dora', # Icelandic
'Giorgio', 'Carla', 'Bianca', # Italian
'Takumi', 'Mizuki', # Japanese
'Seoyeon', # Korean
'Liv', # Norwegian
'Jacek', 'Jan', 'Ewa', 'Maja', # Polish
'Ricardo', 'Vitoria', # Portuguese, Brazilian
'Cristiano', 'Ines', # Portuguese, European
'Carmen', # Romanian
'Maxim', 'Tatyana', # Russian
'Enrique', 'Conchita', 'Lucia', # Spanish European
'Mia', # Spanish Mexican
'Miguel', 'Penelope', # Spanish US
'Astrid', # Swedish
'Filiz', # Turkish
'Gwyneth', # Welsh
] ]
SUPPORTED_OUTPUT_FORMATS = ['mp3', 'ogg_vorbis', 'pcm'] CONF_VOICE = "voice"
CONF_OUTPUT_FORMAT = "output_format"
CONF_SAMPLE_RATE = "sample_rate"
CONF_TEXT_TYPE = "text_type"
SUPPORTED_SAMPLE_RATES = ['8000', '16000', '22050'] SUPPORTED_VOICES = [
"Zhiyu", # Chinese
"Mads",
"Naja", # Danish
"Ruben",
"Lotte", # Dutch
"Russell",
"Nicole", # English Austrailian
"Brian",
"Amy",
"Emma", # English
"Aditi",
"Raveena", # English, Indian
"Joey",
"Justin",
"Matthew",
"Ivy",
"Joanna",
"Kendra",
"Kimberly",
"Salli", # English
"Geraint", # English Welsh
"Mathieu",
"Celine",
"Lea", # French
"Chantal", # French Canadian
"Hans",
"Marlene",
"Vicki", # German
"Aditi", # Hindi
"Karl",
"Dora", # Icelandic
"Giorgio",
"Carla",
"Bianca", # Italian
"Takumi",
"Mizuki", # Japanese
"Seoyeon", # Korean
"Liv", # Norwegian
"Jacek",
"Jan",
"Ewa",
"Maja", # Polish
"Ricardo",
"Vitoria", # Portuguese, Brazilian
"Cristiano",
"Ines", # Portuguese, European
"Carmen", # Romanian
"Maxim",
"Tatyana", # Russian
"Enrique",
"Conchita",
"Lucia", # Spanish European
"Mia", # Spanish Mexican
"Miguel",
"Penelope", # Spanish US
"Astrid", # Swedish
"Filiz", # Turkish
"Gwyneth", # Welsh
]
SUPPORTED_OUTPUT_FORMATS = ["mp3", "ogg_vorbis", "pcm"]
SUPPORTED_SAMPLE_RATES = ["8000", "16000", "22050"]
SUPPORTED_SAMPLE_RATES_MAP = { SUPPORTED_SAMPLE_RATES_MAP = {
'mp3': ['8000', '16000', '22050'], "mp3": ["8000", "16000", "22050"],
'ogg_vorbis': ['8000', '16000', '22050'], "ogg_vorbis": ["8000", "16000", "22050"],
'pcm': ['8000', '16000'], "pcm": ["8000", "16000"],
} }
SUPPORTED_TEXT_TYPES = ['text', 'ssml'] SUPPORTED_TEXT_TYPES = ["text", "ssml"]
CONTENT_TYPE_EXTENSIONS = { CONTENT_TYPE_EXTENSIONS = {"audio/mpeg": "mp3", "audio/ogg": "ogg", "audio/pcm": "pcm"}
'audio/mpeg': 'mp3',
'audio/ogg': 'ogg',
'audio/pcm': 'pcm',
}
DEFAULT_VOICE = 'Joanna' DEFAULT_VOICE = "Joanna"
DEFAULT_OUTPUT_FORMAT = 'mp3' DEFAULT_OUTPUT_FORMAT = "mp3"
DEFAULT_TEXT_TYPE = 'text' DEFAULT_TEXT_TYPE = "text"
DEFAULT_SAMPLE_RATES = { DEFAULT_SAMPLE_RATES = {"mp3": "22050", "ogg_vorbis": "22050", "pcm": "16000"}
'mp3': '22050',
'ogg_vorbis': '22050',
'pcm': '16000',
}
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Optional(CONF_REGION, default=DEFAULT_REGION): {
vol.In(SUPPORTED_REGIONS), vol.Optional(CONF_REGION, default=DEFAULT_REGION): vol.In(SUPPORTED_REGIONS),
vol.Inclusive(CONF_ACCESS_KEY_ID, ATTR_CREDENTIALS): cv.string, vol.Inclusive(CONF_ACCESS_KEY_ID, ATTR_CREDENTIALS): cv.string,
vol.Inclusive(CONF_SECRET_ACCESS_KEY, ATTR_CREDENTIALS): cv.string, vol.Inclusive(CONF_SECRET_ACCESS_KEY, ATTR_CREDENTIALS): cv.string,
vol.Exclusive(CONF_PROFILE_NAME, ATTR_CREDENTIALS): cv.string, vol.Exclusive(CONF_PROFILE_NAME, ATTR_CREDENTIALS): cv.string,
vol.Optional(CONF_VOICE, default=DEFAULT_VOICE): vol.In(SUPPORTED_VOICES), vol.Optional(CONF_VOICE, default=DEFAULT_VOICE): vol.In(SUPPORTED_VOICES),
vol.Optional(CONF_OUTPUT_FORMAT, default=DEFAULT_OUTPUT_FORMAT): vol.Optional(CONF_OUTPUT_FORMAT, default=DEFAULT_OUTPUT_FORMAT): vol.In(
vol.In(SUPPORTED_OUTPUT_FORMATS), SUPPORTED_OUTPUT_FORMATS
vol.Optional(CONF_SAMPLE_RATE): ),
vol.All(cv.string, vol.In(SUPPORTED_SAMPLE_RATES)), vol.Optional(CONF_SAMPLE_RATE): vol.All(
vol.Optional(CONF_TEXT_TYPE, default=DEFAULT_TEXT_TYPE): cv.string, vol.In(SUPPORTED_SAMPLE_RATES)
vol.In(SUPPORTED_TEXT_TYPES), ),
}) vol.Optional(CONF_TEXT_TYPE, default=DEFAULT_TEXT_TYPE): vol.In(
SUPPORTED_TEXT_TYPES
),
}
)
def get_engine(hass, config): def get_engine(hass, config):
"""Set up Amazon Polly speech component.""" """Set up Amazon Polly speech component."""
output_format = config.get(CONF_OUTPUT_FORMAT) output_format = config.get(CONF_OUTPUT_FORMAT)
sample_rate = config.get( sample_rate = config.get(CONF_SAMPLE_RATE, DEFAULT_SAMPLE_RATES[output_format])
CONF_SAMPLE_RATE, DEFAULT_SAMPLE_RATES[output_format])
if sample_rate not in SUPPORTED_SAMPLE_RATES_MAP.get(output_format): if sample_rate not in SUPPORTED_SAMPLE_RATES_MAP.get(output_format):
_LOGGER.error("%s is not a valid sample rate for %s", _LOGGER.error(
sample_rate, output_format) "%s is not a valid sample rate for %s", sample_rate, output_format
)
return None return None
config[CONF_SAMPLE_RATE] = sample_rate config[CONF_SAMPLE_RATE] = sample_rate
@ -131,7 +168,7 @@ def get_engine(hass, config):
del config[CONF_ACCESS_KEY_ID] del config[CONF_ACCESS_KEY_ID]
del config[CONF_SECRET_ACCESS_KEY] del config[CONF_SECRET_ACCESS_KEY]
polly_client = boto3.client('polly', **aws_config) polly_client = boto3.client("polly", **aws_config)
supported_languages = [] supported_languages = []
@ -139,27 +176,25 @@ def get_engine(hass, config):
all_voices_req = polly_client.describe_voices() all_voices_req = polly_client.describe_voices()
for voice in all_voices_req.get('Voices'): for voice in all_voices_req.get("Voices"):
all_voices[voice.get('Id')] = voice all_voices[voice.get("Id")] = voice
if voice.get('LanguageCode') not in supported_languages: if voice.get("LanguageCode") not in supported_languages:
supported_languages.append(voice.get('LanguageCode')) supported_languages.append(voice.get("LanguageCode"))
return AmazonPollyProvider( return AmazonPollyProvider(polly_client, config, supported_languages, all_voices)
polly_client, config, supported_languages, all_voices)
class AmazonPollyProvider(Provider): class AmazonPollyProvider(Provider):
"""Amazon Polly speech api provider.""" """Amazon Polly speech api provider."""
def __init__(self, polly_client, config, supported_languages, def __init__(self, polly_client, config, supported_languages, all_voices):
all_voices):
"""Initialize Amazon Polly provider for TTS.""" """Initialize Amazon Polly provider for TTS."""
self.client = polly_client self.client = polly_client
self.config = config self.config = config
self.supported_langs = supported_languages self.supported_langs = supported_languages
self.all_voices = all_voices self.all_voices = all_voices
self.default_voice = self.config.get(CONF_VOICE) self.default_voice = self.config.get(CONF_VOICE)
self.name = 'Amazon Polly' self.name = "Amazon Polly"
@property @property
def supported_languages(self): def supported_languages(self):
@ -169,7 +204,7 @@ class AmazonPollyProvider(Provider):
@property @property
def default_language(self): def default_language(self):
"""Return the default language.""" """Return the default language."""
return self.all_voices.get(self.default_voice).get('LanguageCode') return self.all_voices.get(self.default_voice).get("LanguageCode")
@property @property
def default_options(self): def default_options(self):
@ -185,9 +220,8 @@ class AmazonPollyProvider(Provider):
"""Request TTS file from Polly.""" """Request TTS file from Polly."""
voice_id = options.get(CONF_VOICE, self.default_voice) voice_id = options.get(CONF_VOICE, self.default_voice)
voice_in_dict = self.all_voices.get(voice_id) voice_in_dict = self.all_voices.get(voice_id)
if language != voice_in_dict.get('LanguageCode'): if language != voice_in_dict.get("LanguageCode"):
_LOGGER.error("%s does not support the %s language", _LOGGER.error("%s does not support the %s language", voice_id, language)
voice_id, language)
return None, None return None, None
resp = self.client.synthesize_speech( resp = self.client.synthesize_speech(
@ -195,8 +229,10 @@ class AmazonPollyProvider(Provider):
SampleRate=self.config[CONF_SAMPLE_RATE], SampleRate=self.config[CONF_SAMPLE_RATE],
Text=message, Text=message,
TextType=self.config[CONF_TEXT_TYPE], TextType=self.config[CONF_TEXT_TYPE],
VoiceId=voice_id VoiceId=voice_id,
) )
return (CONTENT_TYPE_EXTENSIONS[resp.get('ContentType')], return (
resp.get('AudioStream').read()) CONTENT_TYPE_EXTENSIONS[resp.get("ContentType")],
resp.get("AudioStream").read(),
)

View file

@ -12,11 +12,12 @@ _LOGGER = logging.getLogger(__name__)
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {
DOMAIN: DOMAIN: vol.Schema(
vol.Schema({ {
vol.Required(CONF_CLIENT_ID): cv.string, vol.Required(CONF_CLIENT_ID): cv.string,
vol.Required(CONF_CLIENT_SECRET): cv.string, vol.Required(CONF_CLIENT_SECRET): cv.string,
}) }
)
}, },
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
@ -30,15 +31,16 @@ async def async_setup(hass, config):
conf = config[DOMAIN] conf = config[DOMAIN]
config_flow.register_flow_implementation( config_flow.register_flow_implementation(
hass, conf[CONF_CLIENT_ID], hass, conf[CONF_CLIENT_ID], conf[CONF_CLIENT_SECRET]
conf[CONF_CLIENT_SECRET]) )
return True return True
async def async_setup_entry(hass, entry): async def async_setup_entry(hass, entry):
"""Set up Ambiclimate from a config entry.""" """Set up Ambiclimate from a config entry."""
hass.async_create_task(hass.config_entries.async_forward_entry_setup( hass.async_create_task(
entry, 'climate')) hass.config_entries.async_forward_entry_setup(entry, "climate")
)
return True return True

View file

@ -7,35 +7,41 @@ import voluptuous as vol
from homeassistant.components.climate import ClimateDevice from homeassistant.components.climate import ClimateDevice
from homeassistant.components.climate.const import ( from homeassistant.components.climate.const import (
SUPPORT_TARGET_TEMPERATURE, HVAC_MODE_OFF, HVAC_MODE_HEAT) SUPPORT_TARGET_TEMPERATURE,
HVAC_MODE_OFF,
HVAC_MODE_HEAT,
)
from homeassistant.const import ATTR_NAME, ATTR_TEMPERATURE, TEMP_CELSIUS from homeassistant.const import ATTR_NAME, ATTR_TEMPERATURE, TEMP_CELSIUS
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import (ATTR_VALUE, CONF_CLIENT_ID, CONF_CLIENT_SECRET, from .const import (
DOMAIN, SERVICE_COMFORT_FEEDBACK, SERVICE_COMFORT_MODE, ATTR_VALUE,
SERVICE_TEMPERATURE_MODE, STORAGE_KEY, STORAGE_VERSION) CONF_CLIENT_ID,
CONF_CLIENT_SECRET,
DOMAIN,
SERVICE_COMFORT_FEEDBACK,
SERVICE_COMFORT_MODE,
SERVICE_TEMPERATURE_MODE,
STORAGE_KEY,
STORAGE_VERSION,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SUPPORT_FLAGS = SUPPORT_TARGET_TEMPERATURE SUPPORT_FLAGS = SUPPORT_TARGET_TEMPERATURE
SEND_COMFORT_FEEDBACK_SCHEMA = vol.Schema({ SEND_COMFORT_FEEDBACK_SCHEMA = vol.Schema(
vol.Required(ATTR_NAME): cv.string, {vol.Required(ATTR_NAME): cv.string, vol.Required(ATTR_VALUE): cv.string}
vol.Required(ATTR_VALUE): cv.string, )
})
SET_COMFORT_MODE_SCHEMA = vol.Schema({ SET_COMFORT_MODE_SCHEMA = vol.Schema({vol.Required(ATTR_NAME): cv.string})
vol.Required(ATTR_NAME): cv.string,
})
SET_TEMPERATURE_MODE_SCHEMA = vol.Schema({ SET_TEMPERATURE_MODE_SCHEMA = vol.Schema(
vol.Required(ATTR_NAME): cv.string, {vol.Required(ATTR_NAME): cv.string, vol.Required(ATTR_VALUE): cv.string}
vol.Required(ATTR_VALUE): cv.string, )
})
async def async_setup_platform(hass, config, async_add_entities, async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
discovery_info=None):
"""Set up the Ambicliamte device.""" """Set up the Ambicliamte device."""
@ -46,10 +52,12 @@ async def async_setup_entry(hass, entry, async_add_entities):
store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY) store = hass.helpers.storage.Store(STORAGE_VERSION, STORAGE_KEY)
token_info = await store.async_load() token_info = await store.async_load()
oauth = ambiclimate.AmbiclimateOAuth(config[CONF_CLIENT_ID], oauth = ambiclimate.AmbiclimateOAuth(
config[CONF_CLIENT_SECRET], config[CONF_CLIENT_ID],
config['callback_url'], config[CONF_CLIENT_SECRET],
websession) config["callback_url"],
websession,
)
try: try:
token_info = await oauth.refresh_access_token(token_info) token_info = await oauth.refresh_access_token(token_info)
@ -62,9 +70,9 @@ async def async_setup_entry(hass, entry, async_add_entities):
await store.async_save(token_info) await store.async_save(token_info)
data_connection = ambiclimate.AmbiclimateConnection(oauth, data_connection = ambiclimate.AmbiclimateConnection(
token_info=token_info, oauth, token_info=token_info, websession=websession
websession=websession) )
if not await data_connection.find_devices(): if not await data_connection.find_devices():
_LOGGER.error("No devices found") _LOGGER.error("No devices found")
@ -88,10 +96,12 @@ async def async_setup_entry(hass, entry, async_add_entities):
if device: if device:
await device.set_comfort_feedback(service.data[ATTR_VALUE]) await device.set_comfort_feedback(service.data[ATTR_VALUE])
hass.services.async_register(DOMAIN, hass.services.async_register(
SERVICE_COMFORT_FEEDBACK, DOMAIN,
send_comfort_feedback, SERVICE_COMFORT_FEEDBACK,
schema=SEND_COMFORT_FEEDBACK_SCHEMA) send_comfort_feedback,
schema=SEND_COMFORT_FEEDBACK_SCHEMA,
)
async def set_comfort_mode(service): async def set_comfort_mode(service):
"""Set comfort mode.""" """Set comfort mode."""
@ -100,10 +110,9 @@ async def async_setup_entry(hass, entry, async_add_entities):
if device: if device:
await device.set_comfort_mode() await device.set_comfort_mode()
hass.services.async_register(DOMAIN, hass.services.async_register(
SERVICE_COMFORT_MODE, DOMAIN, SERVICE_COMFORT_MODE, set_comfort_mode, schema=SET_COMFORT_MODE_SCHEMA
set_comfort_mode, )
schema=SET_COMFORT_MODE_SCHEMA)
async def set_temperature_mode(service): async def set_temperature_mode(service):
"""Set temperature mode.""" """Set temperature mode."""
@ -112,10 +121,12 @@ async def async_setup_entry(hass, entry, async_add_entities):
if device: if device:
await device.set_temperature_mode(service.data[ATTR_VALUE]) await device.set_temperature_mode(service.data[ATTR_VALUE])
hass.services.async_register(DOMAIN, hass.services.async_register(
SERVICE_TEMPERATURE_MODE, DOMAIN,
set_temperature_mode, SERVICE_TEMPERATURE_MODE,
schema=SET_TEMPERATURE_MODE_SCHEMA) set_temperature_mode,
schema=SET_TEMPERATURE_MODE_SCHEMA,
)
class AmbiclimateEntity(ClimateDevice): class AmbiclimateEntity(ClimateDevice):
@ -141,11 +152,9 @@ class AmbiclimateEntity(ClimateDevice):
def device_info(self): def device_info(self):
"""Return the device info.""" """Return the device info."""
return { return {
'identifiers': { "identifiers": {(DOMAIN, self.unique_id)},
(DOMAIN, self.unique_id) "name": self.name,
}, "manufacturer": "Ambiclimate",
'name': self.name,
'manufacturer': 'Ambiclimate',
} }
@property @property
@ -156,7 +165,7 @@ class AmbiclimateEntity(ClimateDevice):
@property @property
def target_temperature(self): def target_temperature(self):
"""Return the target temperature.""" """Return the target temperature."""
return self._data.get('target_temperature') return self._data.get("target_temperature")
@property @property
def target_temperature_step(self): def target_temperature_step(self):
@ -166,12 +175,12 @@ class AmbiclimateEntity(ClimateDevice):
@property @property
def current_temperature(self): def current_temperature(self):
"""Return the current temperature.""" """Return the current temperature."""
return self._data.get('temperature') return self._data.get("temperature")
@property @property
def current_humidity(self): def current_humidity(self):
"""Return the current humidity.""" """Return the current humidity."""
return self._data.get('humidity') return self._data.get("humidity")
@property @property
def min_temp(self): def min_temp(self):
@ -196,7 +205,7 @@ class AmbiclimateEntity(ClimateDevice):
@property @property
def hvac_mode(self): def hvac_mode(self):
"""Return current operation.""" """Return current operation."""
if self._data.get('power', '').lower() == 'on': if self._data.get("power", "").lower() == "on":
return HVAC_MODE_HEAT return HVAC_MODE_HEAT
return HVAC_MODE_OFF return HVAC_MODE_OFF

View file

@ -7,10 +7,17 @@ from homeassistant import config_entries
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import (AUTH_CALLBACK_NAME, AUTH_CALLBACK_PATH, CONF_CLIENT_ID, from .const import (
CONF_CLIENT_SECRET, DOMAIN, STORAGE_VERSION, STORAGE_KEY) AUTH_CALLBACK_NAME,
AUTH_CALLBACK_PATH,
CONF_CLIENT_ID,
CONF_CLIENT_SECRET,
DOMAIN,
STORAGE_VERSION,
STORAGE_KEY,
)
DATA_AMBICLIMATE_IMPL = 'ambiclimate_flow_implementation' DATA_AMBICLIMATE_IMPL = "ambiclimate_flow_implementation"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -30,7 +37,7 @@ def register_flow_implementation(hass, client_id, client_secret):
} }
@config_entries.HANDLERS.register('ambiclimate') @config_entries.HANDLERS.register("ambiclimate")
class AmbiclimateFlowHandler(config_entries.ConfigFlow): class AmbiclimateFlowHandler(config_entries.ConfigFlow):
"""Handle a config flow.""" """Handle a config flow."""
@ -45,54 +52,52 @@ class AmbiclimateFlowHandler(config_entries.ConfigFlow):
async def async_step_user(self, user_input=None): async def async_step_user(self, user_input=None):
"""Handle external yaml configuration.""" """Handle external yaml configuration."""
if self.hass.config_entries.async_entries(DOMAIN): if self.hass.config_entries.async_entries(DOMAIN):
return self.async_abort(reason='already_setup') return self.async_abort(reason="already_setup")
config = self.hass.data.get(DATA_AMBICLIMATE_IMPL, {}) config = self.hass.data.get(DATA_AMBICLIMATE_IMPL, {})
if not config: if not config:
_LOGGER.debug("No config") _LOGGER.debug("No config")
return self.async_abort(reason='no_config') return self.async_abort(reason="no_config")
return await self.async_step_auth() return await self.async_step_auth()
async def async_step_auth(self, user_input=None): async def async_step_auth(self, user_input=None):
"""Handle a flow start.""" """Handle a flow start."""
if self.hass.config_entries.async_entries(DOMAIN): if self.hass.config_entries.async_entries(DOMAIN):
return self.async_abort(reason='already_setup') return self.async_abort(reason="already_setup")
errors = {} errors = {}
if user_input is not None: if user_input is not None:
errors['base'] = 'follow_link' errors["base"] = "follow_link"
if not self._registered_view: if not self._registered_view:
self._generate_view() self._generate_view()
return self.async_show_form( return self.async_show_form(
step_id='auth', step_id="auth",
description_placeholders={'authorization_url': description_placeholders={
await self._get_authorize_url(), "authorization_url": await self._get_authorize_url(),
'cb_url': self._cb_url()}, "cb_url": self._cb_url(),
},
errors=errors, errors=errors,
) )
async def async_step_code(self, code=None): async def async_step_code(self, code=None):
"""Received code for authentication.""" """Received code for authentication."""
if self.hass.config_entries.async_entries(DOMAIN): if self.hass.config_entries.async_entries(DOMAIN):
return self.async_abort(reason='already_setup') return self.async_abort(reason="already_setup")
token_info = await self._get_token_info(code) token_info = await self._get_token_info(code)
if token_info is None: if token_info is None:
return self.async_abort(reason='access_token') return self.async_abort(reason="access_token")
config = self.hass.data[DATA_AMBICLIMATE_IMPL].copy() config = self.hass.data[DATA_AMBICLIMATE_IMPL].copy()
config['callback_url'] = self._cb_url() config["callback_url"] = self._cb_url()
return self.async_create_entry( return self.async_create_entry(title="Ambiclimate", data=config)
title="Ambiclimate",
data=config,
)
async def _get_token_info(self, code): async def _get_token_info(self, code):
oauth = self._generate_oauth() oauth = self._generate_oauth()
@ -116,15 +121,16 @@ class AmbiclimateFlowHandler(config_entries.ConfigFlow):
clientsession = async_get_clientsession(self.hass) clientsession = async_get_clientsession(self.hass)
callback_url = self._cb_url() callback_url = self._cb_url()
oauth = ambiclimate.AmbiclimateOAuth(config.get(CONF_CLIENT_ID), oauth = ambiclimate.AmbiclimateOAuth(
config.get(CONF_CLIENT_SECRET), config.get(CONF_CLIENT_ID),
callback_url, config.get(CONF_CLIENT_SECRET),
clientsession) callback_url,
clientsession,
)
return oauth return oauth
def _cb_url(self): def _cb_url(self):
return '{}{}'.format(self.hass.config.api.base_url, return "{}{}".format(self.hass.config.api.base_url, AUTH_CALLBACK_PATH)
AUTH_CALLBACK_PATH)
async def _get_authorize_url(self): async def _get_authorize_url(self):
oauth = self._generate_oauth() oauth = self._generate_oauth()
@ -140,14 +146,13 @@ class AmbiclimateAuthCallbackView(HomeAssistantView):
async def get(self, request): async def get(self, request):
"""Receive authorization token.""" """Receive authorization token."""
code = request.query.get('code') code = request.query.get("code")
if code is None: if code is None:
return "No code" return "No code"
hass = request.app['hass'] hass = request.app["hass"]
hass.async_create_task( hass.async_create_task(
hass.config_entries.flow.async_init( hass.config_entries.flow.async_init(
DOMAIN, DOMAIN, context={"source": "code"}, data=code
context={'source': 'code'}, )
data=code, )
))
return "OK!" return "OK!"

View file

@ -1,14 +1,14 @@
"""Constants used by the Ambiclimate component.""" """Constants used by the Ambiclimate component."""
ATTR_VALUE = 'value' ATTR_VALUE = "value"
CONF_CLIENT_ID = 'client_id' CONF_CLIENT_ID = "client_id"
CONF_CLIENT_SECRET = 'client_secret' CONF_CLIENT_SECRET = "client_secret"
DOMAIN = 'ambiclimate' DOMAIN = "ambiclimate"
SERVICE_COMFORT_FEEDBACK = 'send_comfort_feedback' SERVICE_COMFORT_FEEDBACK = "send_comfort_feedback"
SERVICE_COMFORT_MODE = 'set_comfort_mode' SERVICE_COMFORT_MODE = "set_comfort_mode"
SERVICE_TEMPERATURE_MODE = 'set_temperature_mode' SERVICE_TEMPERATURE_MODE = "set_temperature_mode"
STORAGE_KEY = 'ambiclimate_auth' STORAGE_KEY = "ambiclimate_auth"
STORAGE_VERSION = 1 STORAGE_VERSION = 1
AUTH_CALLBACK_NAME = 'api:ambiclimate' AUTH_CALLBACK_NAME = "api:ambiclimate"
AUTH_CALLBACK_PATH = '/api/ambiclimate' AUTH_CALLBACK_PATH = "/api/ambiclimate"

View file

@ -7,220 +7,235 @@ import voluptuous as vol
from homeassistant.config_entries import SOURCE_IMPORT from homeassistant.config_entries import SOURCE_IMPORT
from homeassistant.const import ( from homeassistant.const import (
ATTR_NAME, ATTR_LOCATION, CONF_API_KEY, EVENT_HOMEASSISTANT_STOP) ATTR_NAME,
ATTR_LOCATION,
CONF_API_KEY,
EVENT_HOMEASSISTANT_STOP,
)
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import aiohttp_client, config_validation as cv from homeassistant.helpers import aiohttp_client, config_validation as cv
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_send) async_dispatcher_connect,
async_dispatcher_send,
)
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_call_later from homeassistant.helpers.event import async_call_later
from .config_flow import configured_instances from .config_flow import configured_instances
from .const import ( from .const import (
ATTR_LAST_DATA, CONF_APP_KEY, DATA_CLIENT, DOMAIN, TOPIC_UPDATE, ATTR_LAST_DATA,
TYPE_BINARY_SENSOR, TYPE_SENSOR) CONF_APP_KEY,
DATA_CLIENT,
DOMAIN,
TOPIC_UPDATE,
TYPE_BINARY_SENSOR,
TYPE_SENSOR,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DATA_CONFIG = 'config' DATA_CONFIG = "config"
DEFAULT_SOCKET_MIN_RETRY = 15 DEFAULT_SOCKET_MIN_RETRY = 15
DEFAULT_WATCHDOG_SECONDS = 5 * 60 DEFAULT_WATCHDOG_SECONDS = 5 * 60
TYPE_24HOURRAININ = '24hourrainin' TYPE_24HOURRAININ = "24hourrainin"
TYPE_BAROMABSIN = 'baromabsin' TYPE_BAROMABSIN = "baromabsin"
TYPE_BAROMRELIN = 'baromrelin' TYPE_BAROMRELIN = "baromrelin"
TYPE_BATT1 = 'batt1' TYPE_BATT1 = "batt1"
TYPE_BATT10 = 'batt10' TYPE_BATT10 = "batt10"
TYPE_BATT2 = 'batt2' TYPE_BATT2 = "batt2"
TYPE_BATT3 = 'batt3' TYPE_BATT3 = "batt3"
TYPE_BATT4 = 'batt4' TYPE_BATT4 = "batt4"
TYPE_BATT5 = 'batt5' TYPE_BATT5 = "batt5"
TYPE_BATT6 = 'batt6' TYPE_BATT6 = "batt6"
TYPE_BATT7 = 'batt7' TYPE_BATT7 = "batt7"
TYPE_BATT8 = 'batt8' TYPE_BATT8 = "batt8"
TYPE_BATT9 = 'batt9' TYPE_BATT9 = "batt9"
TYPE_BATTOUT = 'battout' TYPE_BATTOUT = "battout"
TYPE_CO2 = 'co2' TYPE_CO2 = "co2"
TYPE_DAILYRAININ = 'dailyrainin' TYPE_DAILYRAININ = "dailyrainin"
TYPE_DEWPOINT = 'dewPoint' TYPE_DEWPOINT = "dewPoint"
TYPE_EVENTRAININ = 'eventrainin' TYPE_EVENTRAININ = "eventrainin"
TYPE_FEELSLIKE = 'feelsLike' TYPE_FEELSLIKE = "feelsLike"
TYPE_HOURLYRAININ = 'hourlyrainin' TYPE_HOURLYRAININ = "hourlyrainin"
TYPE_HUMIDITY = 'humidity' TYPE_HUMIDITY = "humidity"
TYPE_HUMIDITY1 = 'humidity1' TYPE_HUMIDITY1 = "humidity1"
TYPE_HUMIDITY10 = 'humidity10' TYPE_HUMIDITY10 = "humidity10"
TYPE_HUMIDITY2 = 'humidity2' TYPE_HUMIDITY2 = "humidity2"
TYPE_HUMIDITY3 = 'humidity3' TYPE_HUMIDITY3 = "humidity3"
TYPE_HUMIDITY4 = 'humidity4' TYPE_HUMIDITY4 = "humidity4"
TYPE_HUMIDITY5 = 'humidity5' TYPE_HUMIDITY5 = "humidity5"
TYPE_HUMIDITY6 = 'humidity6' TYPE_HUMIDITY6 = "humidity6"
TYPE_HUMIDITY7 = 'humidity7' TYPE_HUMIDITY7 = "humidity7"
TYPE_HUMIDITY8 = 'humidity8' TYPE_HUMIDITY8 = "humidity8"
TYPE_HUMIDITY9 = 'humidity9' TYPE_HUMIDITY9 = "humidity9"
TYPE_HUMIDITYIN = 'humidityin' TYPE_HUMIDITYIN = "humidityin"
TYPE_LASTRAIN = 'lastRain' TYPE_LASTRAIN = "lastRain"
TYPE_MAXDAILYGUST = 'maxdailygust' TYPE_MAXDAILYGUST = "maxdailygust"
TYPE_MONTHLYRAININ = 'monthlyrainin' TYPE_MONTHLYRAININ = "monthlyrainin"
TYPE_RELAY1 = 'relay1' TYPE_RELAY1 = "relay1"
TYPE_RELAY10 = 'relay10' TYPE_RELAY10 = "relay10"
TYPE_RELAY2 = 'relay2' TYPE_RELAY2 = "relay2"
TYPE_RELAY3 = 'relay3' TYPE_RELAY3 = "relay3"
TYPE_RELAY4 = 'relay4' TYPE_RELAY4 = "relay4"
TYPE_RELAY5 = 'relay5' TYPE_RELAY5 = "relay5"
TYPE_RELAY6 = 'relay6' TYPE_RELAY6 = "relay6"
TYPE_RELAY7 = 'relay7' TYPE_RELAY7 = "relay7"
TYPE_RELAY8 = 'relay8' TYPE_RELAY8 = "relay8"
TYPE_RELAY9 = 'relay9' TYPE_RELAY9 = "relay9"
TYPE_SOILHUM1 = 'soilhum1' TYPE_SOILHUM1 = "soilhum1"
TYPE_SOILHUM10 = 'soilhum10' TYPE_SOILHUM10 = "soilhum10"
TYPE_SOILHUM2 = 'soilhum2' TYPE_SOILHUM2 = "soilhum2"
TYPE_SOILHUM3 = 'soilhum3' TYPE_SOILHUM3 = "soilhum3"
TYPE_SOILHUM4 = 'soilhum4' TYPE_SOILHUM4 = "soilhum4"
TYPE_SOILHUM5 = 'soilhum5' TYPE_SOILHUM5 = "soilhum5"
TYPE_SOILHUM6 = 'soilhum6' TYPE_SOILHUM6 = "soilhum6"
TYPE_SOILHUM7 = 'soilhum7' TYPE_SOILHUM7 = "soilhum7"
TYPE_SOILHUM8 = 'soilhum8' TYPE_SOILHUM8 = "soilhum8"
TYPE_SOILHUM9 = 'soilhum9' TYPE_SOILHUM9 = "soilhum9"
TYPE_SOILTEMP1F = 'soiltemp1f' TYPE_SOILTEMP1F = "soiltemp1f"
TYPE_SOILTEMP10F = 'soiltemp10f' TYPE_SOILTEMP10F = "soiltemp10f"
TYPE_SOILTEMP2F = 'soiltemp2f' TYPE_SOILTEMP2F = "soiltemp2f"
TYPE_SOILTEMP3F = 'soiltemp3f' TYPE_SOILTEMP3F = "soiltemp3f"
TYPE_SOILTEMP4F = 'soiltemp4f' TYPE_SOILTEMP4F = "soiltemp4f"
TYPE_SOILTEMP5F = 'soiltemp5f' TYPE_SOILTEMP5F = "soiltemp5f"
TYPE_SOILTEMP6F = 'soiltemp6f' TYPE_SOILTEMP6F = "soiltemp6f"
TYPE_SOILTEMP7F = 'soiltemp7f' TYPE_SOILTEMP7F = "soiltemp7f"
TYPE_SOILTEMP8F = 'soiltemp8f' TYPE_SOILTEMP8F = "soiltemp8f"
TYPE_SOILTEMP9F = 'soiltemp9f' TYPE_SOILTEMP9F = "soiltemp9f"
TYPE_SOLARRADIATION = 'solarradiation' TYPE_SOLARRADIATION = "solarradiation"
TYPE_SOLARRADIATION_LX = 'solarradiation_lx' TYPE_SOLARRADIATION_LX = "solarradiation_lx"
TYPE_TEMP10F = 'temp10f' TYPE_TEMP10F = "temp10f"
TYPE_TEMP1F = 'temp1f' TYPE_TEMP1F = "temp1f"
TYPE_TEMP2F = 'temp2f' TYPE_TEMP2F = "temp2f"
TYPE_TEMP3F = 'temp3f' TYPE_TEMP3F = "temp3f"
TYPE_TEMP4F = 'temp4f' TYPE_TEMP4F = "temp4f"
TYPE_TEMP5F = 'temp5f' TYPE_TEMP5F = "temp5f"
TYPE_TEMP6F = 'temp6f' TYPE_TEMP6F = "temp6f"
TYPE_TEMP7F = 'temp7f' TYPE_TEMP7F = "temp7f"
TYPE_TEMP8F = 'temp8f' TYPE_TEMP8F = "temp8f"
TYPE_TEMP9F = 'temp9f' TYPE_TEMP9F = "temp9f"
TYPE_TEMPF = 'tempf' TYPE_TEMPF = "tempf"
TYPE_TEMPINF = 'tempinf' TYPE_TEMPINF = "tempinf"
TYPE_TOTALRAININ = 'totalrainin' TYPE_TOTALRAININ = "totalrainin"
TYPE_UV = 'uv' TYPE_UV = "uv"
TYPE_WEEKLYRAININ = 'weeklyrainin' TYPE_WEEKLYRAININ = "weeklyrainin"
TYPE_WINDDIR = 'winddir' TYPE_WINDDIR = "winddir"
TYPE_WINDDIR_AVG10M = 'winddir_avg10m' TYPE_WINDDIR_AVG10M = "winddir_avg10m"
TYPE_WINDDIR_AVG2M = 'winddir_avg2m' TYPE_WINDDIR_AVG2M = "winddir_avg2m"
TYPE_WINDGUSTDIR = 'windgustdir' TYPE_WINDGUSTDIR = "windgustdir"
TYPE_WINDGUSTMPH = 'windgustmph' TYPE_WINDGUSTMPH = "windgustmph"
TYPE_WINDSPDMPH_AVG10M = 'windspdmph_avg10m' TYPE_WINDSPDMPH_AVG10M = "windspdmph_avg10m"
TYPE_WINDSPDMPH_AVG2M = 'windspdmph_avg2m' TYPE_WINDSPDMPH_AVG2M = "windspdmph_avg2m"
TYPE_WINDSPEEDMPH = 'windspeedmph' TYPE_WINDSPEEDMPH = "windspeedmph"
TYPE_YEARLYRAININ = 'yearlyrainin' TYPE_YEARLYRAININ = "yearlyrainin"
SENSOR_TYPES = { SENSOR_TYPES = {
TYPE_24HOURRAININ: ('24 Hr Rain', 'in', TYPE_SENSOR, None), TYPE_24HOURRAININ: ("24 Hr Rain", "in", TYPE_SENSOR, None),
TYPE_BAROMABSIN: ('Abs Pressure', 'inHg', TYPE_SENSOR, 'pressure'), TYPE_BAROMABSIN: ("Abs Pressure", "inHg", TYPE_SENSOR, "pressure"),
TYPE_BAROMRELIN: ('Rel Pressure', 'inHg', TYPE_SENSOR, 'pressure'), TYPE_BAROMRELIN: ("Rel Pressure", "inHg", TYPE_SENSOR, "pressure"),
TYPE_BATT10: ('Battery 10', None, TYPE_BINARY_SENSOR, 'battery'), TYPE_BATT10: ("Battery 10", None, TYPE_BINARY_SENSOR, "battery"),
TYPE_BATT1: ('Battery 1', None, TYPE_BINARY_SENSOR, 'battery'), TYPE_BATT1: ("Battery 1", None, TYPE_BINARY_SENSOR, "battery"),
TYPE_BATT2: ('Battery 2', None, TYPE_BINARY_SENSOR, 'battery'), TYPE_BATT2: ("Battery 2", None, TYPE_BINARY_SENSOR, "battery"),
TYPE_BATT3: ('Battery 3', None, TYPE_BINARY_SENSOR, 'battery'), TYPE_BATT3: ("Battery 3", None, TYPE_BINARY_SENSOR, "battery"),
TYPE_BATT4: ('Battery 4', None, TYPE_BINARY_SENSOR, 'battery'), TYPE_BATT4: ("Battery 4", None, TYPE_BINARY_SENSOR, "battery"),
TYPE_BATT5: ('Battery 5', None, TYPE_BINARY_SENSOR, 'battery'), TYPE_BATT5: ("Battery 5", None, TYPE_BINARY_SENSOR, "battery"),
TYPE_BATT6: ('Battery 6', None, TYPE_BINARY_SENSOR, 'battery'), TYPE_BATT6: ("Battery 6", None, TYPE_BINARY_SENSOR, "battery"),
TYPE_BATT7: ('Battery 7', None, TYPE_BINARY_SENSOR, 'battery'), TYPE_BATT7: ("Battery 7", None, TYPE_BINARY_SENSOR, "battery"),
TYPE_BATT8: ('Battery 8', None, TYPE_BINARY_SENSOR, 'battery'), TYPE_BATT8: ("Battery 8", None, TYPE_BINARY_SENSOR, "battery"),
TYPE_BATT9: ('Battery 9', None, TYPE_BINARY_SENSOR, 'battery'), TYPE_BATT9: ("Battery 9", None, TYPE_BINARY_SENSOR, "battery"),
TYPE_BATTOUT: ('Battery', None, TYPE_BINARY_SENSOR, 'battery'), TYPE_BATTOUT: ("Battery", None, TYPE_BINARY_SENSOR, "battery"),
TYPE_CO2: ('co2', 'ppm', TYPE_SENSOR, None), TYPE_CO2: ("co2", "ppm", TYPE_SENSOR, None),
TYPE_DAILYRAININ: ('Daily Rain', 'in', TYPE_SENSOR, None), TYPE_DAILYRAININ: ("Daily Rain", "in", TYPE_SENSOR, None),
TYPE_DEWPOINT: ('Dew Point', '°F', TYPE_SENSOR, 'temperature'), TYPE_DEWPOINT: ("Dew Point", "°F", TYPE_SENSOR, "temperature"),
TYPE_EVENTRAININ: ('Event Rain', 'in', TYPE_SENSOR, None), TYPE_EVENTRAININ: ("Event Rain", "in", TYPE_SENSOR, None),
TYPE_FEELSLIKE: ('Feels Like', '°F', TYPE_SENSOR, 'temperature'), TYPE_FEELSLIKE: ("Feels Like", "°F", TYPE_SENSOR, "temperature"),
TYPE_HOURLYRAININ: ('Hourly Rain Rate', 'in/hr', TYPE_SENSOR, None), TYPE_HOURLYRAININ: ("Hourly Rain Rate", "in/hr", TYPE_SENSOR, None),
TYPE_HUMIDITY10: ('Humidity 10', '%', TYPE_SENSOR, 'humidity'), TYPE_HUMIDITY10: ("Humidity 10", "%", TYPE_SENSOR, "humidity"),
TYPE_HUMIDITY1: ('Humidity 1', '%', TYPE_SENSOR, 'humidity'), TYPE_HUMIDITY1: ("Humidity 1", "%", TYPE_SENSOR, "humidity"),
TYPE_HUMIDITY2: ('Humidity 2', '%', TYPE_SENSOR, 'humidity'), TYPE_HUMIDITY2: ("Humidity 2", "%", TYPE_SENSOR, "humidity"),
TYPE_HUMIDITY3: ('Humidity 3', '%', TYPE_SENSOR, 'humidity'), TYPE_HUMIDITY3: ("Humidity 3", "%", TYPE_SENSOR, "humidity"),
TYPE_HUMIDITY4: ('Humidity 4', '%', TYPE_SENSOR, 'humidity'), TYPE_HUMIDITY4: ("Humidity 4", "%", TYPE_SENSOR, "humidity"),
TYPE_HUMIDITY5: ('Humidity 5', '%', TYPE_SENSOR, 'humidity'), TYPE_HUMIDITY5: ("Humidity 5", "%", TYPE_SENSOR, "humidity"),
TYPE_HUMIDITY6: ('Humidity 6', '%', TYPE_SENSOR, 'humidity'), TYPE_HUMIDITY6: ("Humidity 6", "%", TYPE_SENSOR, "humidity"),
TYPE_HUMIDITY7: ('Humidity 7', '%', TYPE_SENSOR, 'humidity'), TYPE_HUMIDITY7: ("Humidity 7", "%", TYPE_SENSOR, "humidity"),
TYPE_HUMIDITY8: ('Humidity 8', '%', TYPE_SENSOR, 'humidity'), TYPE_HUMIDITY8: ("Humidity 8", "%", TYPE_SENSOR, "humidity"),
TYPE_HUMIDITY9: ('Humidity 9', '%', TYPE_SENSOR, 'humidity'), TYPE_HUMIDITY9: ("Humidity 9", "%", TYPE_SENSOR, "humidity"),
TYPE_HUMIDITY: ('Humidity', '%', TYPE_SENSOR, 'humidity'), TYPE_HUMIDITY: ("Humidity", "%", TYPE_SENSOR, "humidity"),
TYPE_HUMIDITYIN: ('Humidity In', '%', TYPE_SENSOR, 'humidity'), TYPE_HUMIDITYIN: ("Humidity In", "%", TYPE_SENSOR, "humidity"),
TYPE_LASTRAIN: ('Last Rain', None, TYPE_SENSOR, 'timestamp'), TYPE_LASTRAIN: ("Last Rain", None, TYPE_SENSOR, "timestamp"),
TYPE_MAXDAILYGUST: ('Max Gust', 'mph', TYPE_SENSOR, None), TYPE_MAXDAILYGUST: ("Max Gust", "mph", TYPE_SENSOR, None),
TYPE_MONTHLYRAININ: ('Monthly Rain', 'in', TYPE_SENSOR, None), TYPE_MONTHLYRAININ: ("Monthly Rain", "in", TYPE_SENSOR, None),
TYPE_RELAY10: ('Relay 10', None, TYPE_BINARY_SENSOR, 'connectivity'), TYPE_RELAY10: ("Relay 10", None, TYPE_BINARY_SENSOR, "connectivity"),
TYPE_RELAY1: ('Relay 1', None, TYPE_BINARY_SENSOR, 'connectivity'), TYPE_RELAY1: ("Relay 1", None, TYPE_BINARY_SENSOR, "connectivity"),
TYPE_RELAY2: ('Relay 2', None, TYPE_BINARY_SENSOR, 'connectivity'), TYPE_RELAY2: ("Relay 2", None, TYPE_BINARY_SENSOR, "connectivity"),
TYPE_RELAY3: ('Relay 3', None, TYPE_BINARY_SENSOR, 'connectivity'), TYPE_RELAY3: ("Relay 3", None, TYPE_BINARY_SENSOR, "connectivity"),
TYPE_RELAY4: ('Relay 4', None, TYPE_BINARY_SENSOR, 'connectivity'), TYPE_RELAY4: ("Relay 4", None, TYPE_BINARY_SENSOR, "connectivity"),
TYPE_RELAY5: ('Relay 5', None, TYPE_BINARY_SENSOR, 'connectivity'), TYPE_RELAY5: ("Relay 5", None, TYPE_BINARY_SENSOR, "connectivity"),
TYPE_RELAY6: ('Relay 6', None, TYPE_BINARY_SENSOR, 'connectivity'), TYPE_RELAY6: ("Relay 6", None, TYPE_BINARY_SENSOR, "connectivity"),
TYPE_RELAY7: ('Relay 7', None, TYPE_BINARY_SENSOR, 'connectivity'), TYPE_RELAY7: ("Relay 7", None, TYPE_BINARY_SENSOR, "connectivity"),
TYPE_RELAY8: ('Relay 8', None, TYPE_BINARY_SENSOR, 'connectivity'), TYPE_RELAY8: ("Relay 8", None, TYPE_BINARY_SENSOR, "connectivity"),
TYPE_RELAY9: ('Relay 9', None, TYPE_BINARY_SENSOR, 'connectivity'), TYPE_RELAY9: ("Relay 9", None, TYPE_BINARY_SENSOR, "connectivity"),
TYPE_SOILHUM10: ('Soil Humidity 10', '%', TYPE_SENSOR, 'humidity'), TYPE_SOILHUM10: ("Soil Humidity 10", "%", TYPE_SENSOR, "humidity"),
TYPE_SOILHUM1: ('Soil Humidity 1', '%', TYPE_SENSOR, 'humidity'), TYPE_SOILHUM1: ("Soil Humidity 1", "%", TYPE_SENSOR, "humidity"),
TYPE_SOILHUM2: ('Soil Humidity 2', '%', TYPE_SENSOR, 'humidity'), TYPE_SOILHUM2: ("Soil Humidity 2", "%", TYPE_SENSOR, "humidity"),
TYPE_SOILHUM3: ('Soil Humidity 3', '%', TYPE_SENSOR, 'humidity'), TYPE_SOILHUM3: ("Soil Humidity 3", "%", TYPE_SENSOR, "humidity"),
TYPE_SOILHUM4: ('Soil Humidity 4', '%', TYPE_SENSOR, 'humidity'), TYPE_SOILHUM4: ("Soil Humidity 4", "%", TYPE_SENSOR, "humidity"),
TYPE_SOILHUM5: ('Soil Humidity 5', '%', TYPE_SENSOR, 'humidity'), TYPE_SOILHUM5: ("Soil Humidity 5", "%", TYPE_SENSOR, "humidity"),
TYPE_SOILHUM6: ('Soil Humidity 6', '%', TYPE_SENSOR, 'humidity'), TYPE_SOILHUM6: ("Soil Humidity 6", "%", TYPE_SENSOR, "humidity"),
TYPE_SOILHUM7: ('Soil Humidity 7', '%', TYPE_SENSOR, 'humidity'), TYPE_SOILHUM7: ("Soil Humidity 7", "%", TYPE_SENSOR, "humidity"),
TYPE_SOILHUM8: ('Soil Humidity 8', '%', TYPE_SENSOR, 'humidity'), TYPE_SOILHUM8: ("Soil Humidity 8", "%", TYPE_SENSOR, "humidity"),
TYPE_SOILHUM9: ('Soil Humidity 9', '%', TYPE_SENSOR, 'humidity'), TYPE_SOILHUM9: ("Soil Humidity 9", "%", TYPE_SENSOR, "humidity"),
TYPE_SOILTEMP10F: ('Soil Temp 10', '°F', TYPE_SENSOR, 'temperature'), TYPE_SOILTEMP10F: ("Soil Temp 10", "°F", TYPE_SENSOR, "temperature"),
TYPE_SOILTEMP1F: ('Soil Temp 1', '°F', TYPE_SENSOR, 'temperature'), TYPE_SOILTEMP1F: ("Soil Temp 1", "°F", TYPE_SENSOR, "temperature"),
TYPE_SOILTEMP2F: ('Soil Temp 2', '°F', TYPE_SENSOR, 'temperature'), TYPE_SOILTEMP2F: ("Soil Temp 2", "°F", TYPE_SENSOR, "temperature"),
TYPE_SOILTEMP3F: ('Soil Temp 3', '°F', TYPE_SENSOR, 'temperature'), TYPE_SOILTEMP3F: ("Soil Temp 3", "°F", TYPE_SENSOR, "temperature"),
TYPE_SOILTEMP4F: ('Soil Temp 4', '°F', TYPE_SENSOR, 'temperature'), TYPE_SOILTEMP4F: ("Soil Temp 4", "°F", TYPE_SENSOR, "temperature"),
TYPE_SOILTEMP5F: ('Soil Temp 5', '°F', TYPE_SENSOR, 'temperature'), TYPE_SOILTEMP5F: ("Soil Temp 5", "°F", TYPE_SENSOR, "temperature"),
TYPE_SOILTEMP6F: ('Soil Temp 6', '°F', TYPE_SENSOR, 'temperature'), TYPE_SOILTEMP6F: ("Soil Temp 6", "°F", TYPE_SENSOR, "temperature"),
TYPE_SOILTEMP7F: ('Soil Temp 7', '°F', TYPE_SENSOR, 'temperature'), TYPE_SOILTEMP7F: ("Soil Temp 7", "°F", TYPE_SENSOR, "temperature"),
TYPE_SOILTEMP8F: ('Soil Temp 8', '°F', TYPE_SENSOR, 'temperature'), TYPE_SOILTEMP8F: ("Soil Temp 8", "°F", TYPE_SENSOR, "temperature"),
TYPE_SOILTEMP9F: ('Soil Temp 9', '°F', TYPE_SENSOR, 'temperature'), TYPE_SOILTEMP9F: ("Soil Temp 9", "°F", TYPE_SENSOR, "temperature"),
TYPE_SOLARRADIATION: ('Solar Rad', 'W/m^2', TYPE_SENSOR, None), TYPE_SOLARRADIATION: ("Solar Rad", "W/m^2", TYPE_SENSOR, None),
TYPE_SOLARRADIATION_LX: ( TYPE_SOLARRADIATION_LX: ("Solar Rad (lx)", "lx", TYPE_SENSOR, "illuminance"),
'Solar Rad (lx)', 'lx', TYPE_SENSOR, 'illuminance'), TYPE_TEMP10F: ("Temp 10", "°F", TYPE_SENSOR, "temperature"),
TYPE_TEMP10F: ('Temp 10', '°F', TYPE_SENSOR, 'temperature'), TYPE_TEMP1F: ("Temp 1", "°F", TYPE_SENSOR, "temperature"),
TYPE_TEMP1F: ('Temp 1', '°F', TYPE_SENSOR, 'temperature'), TYPE_TEMP2F: ("Temp 2", "°F", TYPE_SENSOR, "temperature"),
TYPE_TEMP2F: ('Temp 2', '°F', TYPE_SENSOR, 'temperature'), TYPE_TEMP3F: ("Temp 3", "°F", TYPE_SENSOR, "temperature"),
TYPE_TEMP3F: ('Temp 3', '°F', TYPE_SENSOR, 'temperature'), TYPE_TEMP4F: ("Temp 4", "°F", TYPE_SENSOR, "temperature"),
TYPE_TEMP4F: ('Temp 4', '°F', TYPE_SENSOR, 'temperature'), TYPE_TEMP5F: ("Temp 5", "°F", TYPE_SENSOR, "temperature"),
TYPE_TEMP5F: ('Temp 5', '°F', TYPE_SENSOR, 'temperature'), TYPE_TEMP6F: ("Temp 6", "°F", TYPE_SENSOR, "temperature"),
TYPE_TEMP6F: ('Temp 6', '°F', TYPE_SENSOR, 'temperature'), TYPE_TEMP7F: ("Temp 7", "°F", TYPE_SENSOR, "temperature"),
TYPE_TEMP7F: ('Temp 7', '°F', TYPE_SENSOR, 'temperature'), TYPE_TEMP8F: ("Temp 8", "°F", TYPE_SENSOR, "temperature"),
TYPE_TEMP8F: ('Temp 8', '°F', TYPE_SENSOR, 'temperature'), TYPE_TEMP9F: ("Temp 9", "°F", TYPE_SENSOR, "temperature"),
TYPE_TEMP9F: ('Temp 9', '°F', TYPE_SENSOR, 'temperature'), TYPE_TEMPF: ("Temp", "°F", TYPE_SENSOR, "temperature"),
TYPE_TEMPF: ('Temp', '°F', TYPE_SENSOR, 'temperature'), TYPE_TEMPINF: ("Inside Temp", "°F", TYPE_SENSOR, "temperature"),
TYPE_TEMPINF: ('Inside Temp', '°F', TYPE_SENSOR, 'temperature'), TYPE_TOTALRAININ: ("Lifetime Rain", "in", TYPE_SENSOR, None),
TYPE_TOTALRAININ: ('Lifetime Rain', 'in', TYPE_SENSOR, None), TYPE_UV: ("uv", "Index", TYPE_SENSOR, None),
TYPE_UV: ('uv', 'Index', TYPE_SENSOR, None), TYPE_WEEKLYRAININ: ("Weekly Rain", "in", TYPE_SENSOR, None),
TYPE_WEEKLYRAININ: ('Weekly Rain', 'in', TYPE_SENSOR, None), TYPE_WINDDIR: ("Wind Dir", "°", TYPE_SENSOR, None),
TYPE_WINDDIR: ('Wind Dir', '°', TYPE_SENSOR, None), TYPE_WINDDIR_AVG10M: ("Wind Dir Avg 10m", "°", TYPE_SENSOR, None),
TYPE_WINDDIR_AVG10M: ('Wind Dir Avg 10m', '°', TYPE_SENSOR, None), TYPE_WINDDIR_AVG2M: ("Wind Dir Avg 2m", "mph", TYPE_SENSOR, None),
TYPE_WINDDIR_AVG2M: ('Wind Dir Avg 2m', 'mph', TYPE_SENSOR, None), TYPE_WINDGUSTDIR: ("Gust Dir", "°", TYPE_SENSOR, None),
TYPE_WINDGUSTDIR: ('Gust Dir', '°', TYPE_SENSOR, None), TYPE_WINDGUSTMPH: ("Wind Gust", "mph", TYPE_SENSOR, None),
TYPE_WINDGUSTMPH: ('Wind Gust', 'mph', TYPE_SENSOR, None), TYPE_WINDSPDMPH_AVG10M: ("Wind Avg 10m", "mph", TYPE_SENSOR, None),
TYPE_WINDSPDMPH_AVG10M: ('Wind Avg 10m', 'mph', TYPE_SENSOR, None), TYPE_WINDSPDMPH_AVG2M: ("Wind Avg 2m", "mph", TYPE_SENSOR, None),
TYPE_WINDSPDMPH_AVG2M: ('Wind Avg 2m', 'mph', TYPE_SENSOR, None), TYPE_WINDSPEEDMPH: ("Wind Speed", "mph", TYPE_SENSOR, None),
TYPE_WINDSPEEDMPH: ('Wind Speed', 'mph', TYPE_SENSOR, None), TYPE_YEARLYRAININ: ("Yearly Rain", "in", TYPE_SENSOR, None),
TYPE_YEARLYRAININ: ('Yearly Rain', 'in', TYPE_SENSOR, None),
} }
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema(
DOMAIN: {
vol.Schema({ DOMAIN: vol.Schema(
vol.Required(CONF_APP_KEY): cv.string, {
vol.Required(CONF_API_KEY): cv.string, vol.Required(CONF_APP_KEY): cv.string,
}) vol.Required(CONF_API_KEY): cv.string,
}, extra=vol.ALLOW_EXTRA) }
)
},
extra=vol.ALLOW_EXTRA,
)
async def async_setup(hass, config): async def async_setup(hass, config):
@ -242,11 +257,10 @@ async def async_setup(hass, config):
hass.async_create_task( hass.async_create_task(
hass.config_entries.flow.async_init( hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
context={'source': SOURCE_IMPORT}, context={"source": SOURCE_IMPORT},
data={ data={CONF_API_KEY: conf[CONF_API_KEY], CONF_APP_KEY: conf[CONF_APP_KEY]},
CONF_API_KEY: conf[CONF_API_KEY], )
CONF_APP_KEY: conf[CONF_APP_KEY] )
}))
return True return True
@ -257,18 +271,23 @@ async def async_setup_entry(hass, config_entry):
try: try:
ambient = AmbientStation( ambient = AmbientStation(
hass, config_entry, hass,
config_entry,
Client( Client(
config_entry.data[CONF_API_KEY], config_entry.data[CONF_API_KEY],
config_entry.data[CONF_APP_KEY], session)) config_entry.data[CONF_APP_KEY],
session,
),
)
hass.loop.create_task(ambient.ws_connect()) hass.loop.create_task(ambient.ws_connect())
hass.data[DOMAIN][DATA_CLIENT][config_entry.entry_id] = ambient hass.data[DOMAIN][DATA_CLIENT][config_entry.entry_id] = ambient
except WebsocketError as err: except WebsocketError as err:
_LOGGER.error('Config entry failed: %s', err) _LOGGER.error("Config entry failed: %s", err)
raise ConfigEntryNotReady raise ConfigEntryNotReady
hass.bus.async_listen_once( hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, ambient.client.websocket.disconnect()) EVENT_HOMEASSISTANT_STOP, ambient.client.websocket.disconnect()
)
return True return True
@ -278,9 +297,8 @@ async def async_unload_entry(hass, config_entry):
ambient = hass.data[DOMAIN][DATA_CLIENT].pop(config_entry.entry_id) ambient = hass.data[DOMAIN][DATA_CLIENT].pop(config_entry.entry_id)
hass.async_create_task(ambient.ws_disconnect()) hass.async_create_task(ambient.ws_disconnect())
for component in ('binary_sensor', 'sensor'): for component in ("binary_sensor", "sensor"):
await hass.config_entries.async_forward_entry_unload( await hass.config_entries.async_forward_entry_unload(config_entry, component)
config_entry, component)
return True return True
@ -289,7 +307,7 @@ async def async_migrate_entry(hass, config_entry):
"""Migrate old entry.""" """Migrate old entry."""
version = config_entry.version version = config_entry.version
_LOGGER.debug('Migrating from version %s', version) _LOGGER.debug("Migrating from version %s", version)
# 1 -> 2: Unique ID format changed, so delete and re-import: # 1 -> 2: Unique ID format changed, so delete and re-import:
if version == 1: if version == 1:
@ -302,7 +320,7 @@ async def async_migrate_entry(hass, config_entry):
version = config_entry.version = 2 version = config_entry.version = 2
hass.config_entries.async_update_entry(config_entry) hass.config_entries.async_update_entry(config_entry)
_LOGGER.info('Migration to version %s successful', version) _LOGGER.info("Migration to version %s successful", version)
return True return True
@ -327,71 +345,70 @@ class AmbientStation:
await self.client.websocket.connect() await self.client.websocket.connect()
except WebsocketError as err: except WebsocketError as err:
_LOGGER.error("Error with the websocket connection: %s", err) _LOGGER.error("Error with the websocket connection: %s", err)
self._ws_reconnect_delay = min( self._ws_reconnect_delay = min(2 * self._ws_reconnect_delay, 480)
2 * self._ws_reconnect_delay, 480) async_call_later(self._hass, self._ws_reconnect_delay, self.ws_connect)
async_call_later(
self._hass, self._ws_reconnect_delay, self.ws_connect)
async def ws_connect(self): async def ws_connect(self):
"""Register handlers and connect to the websocket.""" """Register handlers and connect to the websocket."""
async def _ws_reconnect(event_time): async def _ws_reconnect(event_time):
"""Forcibly disconnect from and reconnect to the websocket.""" """Forcibly disconnect from and reconnect to the websocket."""
_LOGGER.debug('Watchdog expired; forcing socket reconnection') _LOGGER.debug("Watchdog expired; forcing socket reconnection")
await self.client.websocket.disconnect() await self.client.websocket.disconnect()
await self._attempt_connect() await self._attempt_connect()
def on_connect(): def on_connect():
"""Define a handler to fire when the websocket is connected.""" """Define a handler to fire when the websocket is connected."""
_LOGGER.info('Connected to websocket') _LOGGER.info("Connected to websocket")
_LOGGER.debug('Watchdog starting') _LOGGER.debug("Watchdog starting")
if self._watchdog_listener is not None: if self._watchdog_listener is not None:
self._watchdog_listener() self._watchdog_listener()
self._watchdog_listener = async_call_later( self._watchdog_listener = async_call_later(
self._hass, DEFAULT_WATCHDOG_SECONDS, _ws_reconnect) self._hass, DEFAULT_WATCHDOG_SECONDS, _ws_reconnect
)
def on_data(data): def on_data(data):
"""Define a handler to fire when the data is received.""" """Define a handler to fire when the data is received."""
mac_address = data['macAddress'] mac_address = data["macAddress"]
if data != self.stations[mac_address][ATTR_LAST_DATA]: if data != self.stations[mac_address][ATTR_LAST_DATA]:
_LOGGER.debug('New data received: %s', data) _LOGGER.debug("New data received: %s", data)
self.stations[mac_address][ATTR_LAST_DATA] = data self.stations[mac_address][ATTR_LAST_DATA] = data
async_dispatcher_send(self._hass, TOPIC_UPDATE) async_dispatcher_send(self._hass, TOPIC_UPDATE)
_LOGGER.debug('Resetting watchdog') _LOGGER.debug("Resetting watchdog")
self._watchdog_listener() self._watchdog_listener()
self._watchdog_listener = async_call_later( self._watchdog_listener = async_call_later(
self._hass, DEFAULT_WATCHDOG_SECONDS, _ws_reconnect) self._hass, DEFAULT_WATCHDOG_SECONDS, _ws_reconnect
)
def on_disconnect(): def on_disconnect():
"""Define a handler to fire when the websocket is disconnected.""" """Define a handler to fire when the websocket is disconnected."""
_LOGGER.info('Disconnected from websocket') _LOGGER.info("Disconnected from websocket")
def on_subscribed(data): def on_subscribed(data):
"""Define a handler to fire when the subscription is set.""" """Define a handler to fire when the subscription is set."""
for station in data['devices']: for station in data["devices"]:
if station['macAddress'] in self.stations: if station["macAddress"] in self.stations:
continue continue
_LOGGER.debug('New station subscription: %s', data) _LOGGER.debug("New station subscription: %s", data)
self.monitored_conditions = [ self.monitored_conditions = [
k for k in station['lastData'] k for k in station["lastData"] if k in SENSOR_TYPES
if k in SENSOR_TYPES
] ]
# If the user is monitoring brightness (in W/m^2), # If the user is monitoring brightness (in W/m^2),
# make sure we also add a calculated sensor for the # make sure we also add a calculated sensor for the
# same data measured in lx: # same data measured in lx:
if TYPE_SOLARRADIATION in self.monitored_conditions: if TYPE_SOLARRADIATION in self.monitored_conditions:
self.monitored_conditions.append( self.monitored_conditions.append(TYPE_SOLARRADIATION_LX)
TYPE_SOLARRADIATION_LX)
self.stations[station['macAddress']] = { self.stations[station["macAddress"]] = {
ATTR_LAST_DATA: station['lastData'], ATTR_LAST_DATA: station["lastData"],
ATTR_LOCATION: station.get('info', {}).get('location'), ATTR_LOCATION: station.get("info", {}).get("location"),
ATTR_NAME: ATTR_NAME: station.get("info", {}).get(
station.get('info', {}).get( "name", station["macAddress"]
'name', station['macAddress']), ),
} }
# If the websocket disconnects and reconnects, the on_subscribed # If the websocket disconnects and reconnects, the on_subscribed
@ -399,10 +416,12 @@ class AmbientStation:
# attempt forward setup of the config entry (because it will have # attempt forward setup of the config entry (because it will have
# already been done): # already been done):
if not self._entry_setup_complete: if not self._entry_setup_complete:
for component in ('binary_sensor', 'sensor'): for component in ("binary_sensor", "sensor"):
self._hass.async_create_task( self._hass.async_create_task(
self._hass.config_entries.async_forward_entry_setup( self._hass.config_entries.async_forward_entry_setup(
self._config_entry, component)) self._config_entry, component
)
)
self._entry_setup_complete = True self._entry_setup_complete = True
self._ws_reconnect_delay = DEFAULT_SOCKET_MIN_RETRY self._ws_reconnect_delay = DEFAULT_SOCKET_MIN_RETRY
@ -423,8 +442,8 @@ class AmbientWeatherEntity(Entity):
"""Define a base Ambient PWS entity.""" """Define a base Ambient PWS entity."""
def __init__( def __init__(
self, ambient, mac_address, station_name, sensor_type, self, ambient, mac_address, station_name, sensor_type, sensor_name, device_class
sensor_name, device_class): ):
"""Initialize the sensor.""" """Initialize the sensor."""
self._ambient = ambient self._ambient = ambient
self._device_class = device_class self._device_class = device_class
@ -443,10 +462,18 @@ class AmbientWeatherEntity(Entity):
# solarradiation_lx sensor shows as available if the solarradiation # solarradiation_lx sensor shows as available if the solarradiation
# sensor is available: # sensor is available:
if self._sensor_type == TYPE_SOLARRADIATION_LX: if self._sensor_type == TYPE_SOLARRADIATION_LX:
return self._ambient.stations[self._mac_address][ return (
ATTR_LAST_DATA].get(TYPE_SOLARRADIATION) is not None self._ambient.stations[self._mac_address][ATTR_LAST_DATA].get(
return self._ambient.stations[self._mac_address][ATTR_LAST_DATA].get( TYPE_SOLARRADIATION
self._sensor_type) is not None )
is not None
)
return (
self._ambient.stations[self._mac_address][ATTR_LAST_DATA].get(
self._sensor_type
)
is not None
)
@property @property
def device_class(self): def device_class(self):
@ -457,17 +484,15 @@ class AmbientWeatherEntity(Entity):
def device_info(self): def device_info(self):
"""Return device registry information for this entity.""" """Return device registry information for this entity."""
return { return {
'identifiers': { "identifiers": {(DOMAIN, self._mac_address)},
(DOMAIN, self._mac_address) "name": self._station_name,
}, "manufacturer": "Ambient Weather",
'name': self._station_name,
'manufacturer': 'Ambient Weather',
} }
@property @property
def name(self): def name(self):
"""Return the name of the sensor.""" """Return the name of the sensor."""
return '{0}_{1}'.format(self._station_name, self._sensor_name) return "{0}_{1}".format(self._station_name, self._sensor_name)
@property @property
def should_poll(self): def should_poll(self):
@ -477,17 +502,19 @@ class AmbientWeatherEntity(Entity):
@property @property
def unique_id(self): def unique_id(self):
"""Return a unique, unchanging string that represents this sensor.""" """Return a unique, unchanging string that represents this sensor."""
return '{0}_{1}'.format(self._mac_address, self._sensor_type) return "{0}_{1}".format(self._mac_address, self._sensor_type)
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Register callbacks.""" """Register callbacks."""
@callback @callback
def update(): def update():
"""Update the state.""" """Update the state."""
self.async_schedule_update_ha_state(True) self.async_schedule_update_ha_state(True)
self._async_unsub_dispatcher_connect = async_dispatcher_connect( self._async_unsub_dispatcher_connect = async_dispatcher_connect(
self.hass, TOPIC_UPDATE, update) self.hass, TOPIC_UPDATE, update
)
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
"""Disconnect dispatcher listener when removed.""" """Disconnect dispatcher listener when removed."""

View file

@ -5,16 +5,26 @@ from homeassistant.components.binary_sensor import BinarySensorDevice
from homeassistant.const import ATTR_NAME from homeassistant.const import ATTR_NAME
from . import ( from . import (
SENSOR_TYPES, TYPE_BATT1, TYPE_BATT2, TYPE_BATT3, TYPE_BATT4, TYPE_BATT5, SENSOR_TYPES,
TYPE_BATT6, TYPE_BATT7, TYPE_BATT8, TYPE_BATT9, TYPE_BATT10, TYPE_BATTOUT, TYPE_BATT1,
AmbientWeatherEntity) TYPE_BATT2,
TYPE_BATT3,
TYPE_BATT4,
TYPE_BATT5,
TYPE_BATT6,
TYPE_BATT7,
TYPE_BATT8,
TYPE_BATT9,
TYPE_BATT10,
TYPE_BATTOUT,
AmbientWeatherEntity,
)
from .const import ATTR_LAST_DATA, DATA_CLIENT, DOMAIN, TYPE_BINARY_SENSOR from .const import ATTR_LAST_DATA, DATA_CLIENT, DOMAIN, TYPE_BINARY_SENSOR
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_setup_platform( async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
hass, config, async_add_entities, discovery_info=None):
"""Set up Ambient PWS binary sensors based on the old way.""" """Set up Ambient PWS binary sensors based on the old way."""
pass pass
@ -30,8 +40,14 @@ async def async_setup_entry(hass, entry, async_add_entities):
if kind == TYPE_BINARY_SENSOR: if kind == TYPE_BINARY_SENSOR:
binary_sensor_list.append( binary_sensor_list.append(
AmbientWeatherBinarySensor( AmbientWeatherBinarySensor(
ambient, mac_address, station[ATTR_NAME], condition, ambient,
name, device_class)) mac_address,
station[ATTR_NAME],
condition,
name,
device_class,
)
)
async_add_entities(binary_sensor_list, True) async_add_entities(binary_sensor_list, True)
@ -42,15 +58,25 @@ class AmbientWeatherBinarySensor(AmbientWeatherEntity, BinarySensorDevice):
@property @property
def is_on(self): def is_on(self):
"""Return the status of the sensor.""" """Return the status of the sensor."""
if self._sensor_type in (TYPE_BATT1, TYPE_BATT10, TYPE_BATT2, if self._sensor_type in (
TYPE_BATT3, TYPE_BATT4, TYPE_BATT5, TYPE_BATT1,
TYPE_BATT6, TYPE_BATT7, TYPE_BATT8, TYPE_BATT10,
TYPE_BATT9, TYPE_BATTOUT): TYPE_BATT2,
TYPE_BATT3,
TYPE_BATT4,
TYPE_BATT5,
TYPE_BATT6,
TYPE_BATT7,
TYPE_BATT8,
TYPE_BATT9,
TYPE_BATTOUT,
):
return self._state == 0 return self._state == 0
return self._state == 1 return self._state == 1
async def async_update(self): async def async_update(self):
"""Fetch new state data for the entity.""" """Fetch new state data for the entity."""
self._state = self._ambient.stations[ self._state = self._ambient.stations[self._mac_address][ATTR_LAST_DATA].get(
self._mac_address][ATTR_LAST_DATA].get(self._sensor_type) self._sensor_type
)

View file

@ -13,8 +13,8 @@ from .const import CONF_APP_KEY, DOMAIN
def configured_instances(hass): def configured_instances(hass):
"""Return a set of configured Ambient PWS instances.""" """Return a set of configured Ambient PWS instances."""
return set( return set(
entry.data[CONF_APP_KEY] entry.data[CONF_APP_KEY] for entry in hass.config_entries.async_entries(DOMAIN)
for entry in hass.config_entries.async_entries(DOMAIN)) )
@config_entries.HANDLERS.register(DOMAIN) @config_entries.HANDLERS.register(DOMAIN)
@ -26,15 +26,12 @@ class AmbientStationFlowHandler(config_entries.ConfigFlow):
async def _show_form(self, errors=None): async def _show_form(self, errors=None):
"""Show the form to the user.""" """Show the form to the user."""
data_schema = vol.Schema({ data_schema = vol.Schema(
vol.Required(CONF_API_KEY): str, {vol.Required(CONF_API_KEY): str, vol.Required(CONF_APP_KEY): str}
vol.Required(CONF_APP_KEY): str, )
})
return self.async_show_form( return self.async_show_form(
step_id='user', step_id="user", data_schema=data_schema, errors=errors if errors else {}
data_schema=data_schema,
errors=errors if errors else {},
) )
async def async_step_import(self, import_config): async def async_step_import(self, import_config):
@ -50,22 +47,22 @@ class AmbientStationFlowHandler(config_entries.ConfigFlow):
return await self._show_form() return await self._show_form()
if user_input[CONF_APP_KEY] in configured_instances(self.hass): if user_input[CONF_APP_KEY] in configured_instances(self.hass):
return await self._show_form({CONF_APP_KEY: 'identifier_exists'}) return await self._show_form({CONF_APP_KEY: "identifier_exists"})
session = aiohttp_client.async_get_clientsession(self.hass) session = aiohttp_client.async_get_clientsession(self.hass)
client = Client( client = Client(user_input[CONF_API_KEY], user_input[CONF_APP_KEY], session)
user_input[CONF_API_KEY], user_input[CONF_APP_KEY], session)
try: try:
devices = await client.api.get_devices() devices = await client.api.get_devices()
except AmbientError: except AmbientError:
return await self._show_form({'base': 'invalid_key'}) return await self._show_form({"base": "invalid_key"})
if not devices: if not devices:
return await self._show_form({'base': 'no_devices'}) return await self._show_form({"base": "no_devices"})
# The Application Key (which identifies each config entry) is too long # The Application Key (which identifies each config entry) is too long
# to show nicely in the UI, so we take the first 12 characters (similar # to show nicely in the UI, so we take the first 12 characters (similar
# to how GitHub does it): # to how GitHub does it):
return self.async_create_entry( return self.async_create_entry(
title=user_input[CONF_APP_KEY][:12], data=user_input) title=user_input[CONF_APP_KEY][:12], data=user_input
)

View file

@ -1,13 +1,13 @@
"""Define constants for the Ambient PWS component.""" """Define constants for the Ambient PWS component."""
DOMAIN = 'ambient_station' DOMAIN = "ambient_station"
ATTR_LAST_DATA = 'last_data' ATTR_LAST_DATA = "last_data"
CONF_APP_KEY = 'app_key' CONF_APP_KEY = "app_key"
DATA_CLIENT = 'data_client' DATA_CLIENT = "data_client"
TOPIC_UPDATE = 'update' TOPIC_UPDATE = "update"
TYPE_BINARY_SENSOR = 'binary_sensor' TYPE_BINARY_SENSOR = "binary_sensor"
TYPE_SENSOR = 'sensor' TYPE_SENSOR = "sensor"

View file

@ -4,15 +4,17 @@ import logging
from homeassistant.const import ATTR_NAME from homeassistant.const import ATTR_NAME
from . import ( from . import (
SENSOR_TYPES, TYPE_SOLARRADIATION, TYPE_SOLARRADIATION_LX, SENSOR_TYPES,
AmbientWeatherEntity) TYPE_SOLARRADIATION,
TYPE_SOLARRADIATION_LX,
AmbientWeatherEntity,
)
from .const import ATTR_LAST_DATA, DATA_CLIENT, DOMAIN, TYPE_SENSOR from .const import ATTR_LAST_DATA, DATA_CLIENT, DOMAIN, TYPE_SENSOR
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
async def async_setup_platform( async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
hass, config, async_add_entities, discovery_info=None):
"""Set up Ambient PWS sensors based on existing config.""" """Set up Ambient PWS sensors based on existing config."""
pass pass
@ -28,8 +30,15 @@ async def async_setup_entry(hass, entry, async_add_entities):
if kind == TYPE_SENSOR: if kind == TYPE_SENSOR:
sensor_list.append( sensor_list.append(
AmbientWeatherSensor( AmbientWeatherSensor(
ambient, mac_address, station[ATTR_NAME], condition, ambient,
name, device_class, unit)) mac_address,
station[ATTR_NAME],
condition,
name,
device_class,
unit,
)
)
async_add_entities(sensor_list, True) async_add_entities(sensor_list, True)
@ -38,16 +47,19 @@ class AmbientWeatherSensor(AmbientWeatherEntity):
"""Define an Ambient sensor.""" """Define an Ambient sensor."""
def __init__( def __init__(
self, ambient, mac_address, station_name, sensor_type, sensor_name, self,
device_class, unit): ambient,
mac_address,
station_name,
sensor_type,
sensor_name,
device_class,
unit,
):
"""Initialize the sensor.""" """Initialize the sensor."""
super().__init__( super().__init__(
ambient, ambient, mac_address, station_name, sensor_type, sensor_name, device_class
mac_address, )
station_name,
sensor_type,
sensor_name,
device_class)
self._unit = unit self._unit = unit
@ -67,9 +79,11 @@ class AmbientWeatherSensor(AmbientWeatherEntity):
# If the user requests the solarradiation_lx sensor, use the # If the user requests the solarradiation_lx sensor, use the
# value of the solarradiation sensor and apply a very accurate # value of the solarradiation sensor and apply a very accurate
# approximation of converting sunlight W/m^2 to lx: # approximation of converting sunlight W/m^2 to lx:
w_m2_brightness_val = self._ambient.stations[ w_m2_brightness_val = self._ambient.stations[self._mac_address][
self._mac_address][ATTR_LAST_DATA].get(TYPE_SOLARRADIATION) ATTR_LAST_DATA
self._state = round(float(w_m2_brightness_val)/0.0079) ].get(TYPE_SOLARRADIATION)
self._state = round(float(w_m2_brightness_val) / 0.0079)
else: else:
self._state = self._ambient.stations[ self._state = self._ambient.stations[self._mac_address][ATTR_LAST_DATA].get(
self._mac_address][ATTR_LAST_DATA].get(self._sensor_type) self._sensor_type
)

View file

@ -13,14 +13,24 @@ from homeassistant.components.camera import DOMAIN as CAMERA
from homeassistant.components.sensor import DOMAIN as SENSOR from homeassistant.components.sensor import DOMAIN as SENSOR
from homeassistant.components.switch import DOMAIN as SWITCH from homeassistant.components.switch import DOMAIN as SWITCH
from homeassistant.const import ( from homeassistant.const import (
ATTR_ENTITY_ID, CONF_AUTHENTICATION, CONF_BINARY_SENSORS, CONF_HOST, ATTR_ENTITY_ID,
CONF_NAME, CONF_PASSWORD, CONF_PORT, CONF_SCAN_INTERVAL, CONF_SENSORS, CONF_AUTHENTICATION,
CONF_SWITCHES, CONF_USERNAME, ENTITY_MATCH_ALL, HTTP_BASIC_AUTHENTICATION) CONF_BINARY_SENSORS,
CONF_HOST,
CONF_NAME,
CONF_PASSWORD,
CONF_PORT,
CONF_SCAN_INTERVAL,
CONF_SENSORS,
CONF_SWITCHES,
CONF_USERNAME,
ENTITY_MATCH_ALL,
HTTP_BASIC_AUTHENTICATION,
)
from homeassistant.exceptions import Unauthorized, UnknownUser from homeassistant.exceptions import Unauthorized, UnknownUser
from homeassistant.helpers import discovery from homeassistant.helpers import discovery
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import async_dispatcher_send, dispatcher_send
async_dispatcher_send, dispatcher_send)
from homeassistant.helpers.event import track_time_interval from homeassistant.helpers.event import track_time_interval
from homeassistant.helpers.service import async_extract_entity_ids from homeassistant.helpers.service import async_extract_entity_ids
@ -33,31 +43,26 @@ from .switch import SWITCHES
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
CONF_RESOLUTION = 'resolution' CONF_RESOLUTION = "resolution"
CONF_STREAM_SOURCE = 'stream_source' CONF_STREAM_SOURCE = "stream_source"
CONF_FFMPEG_ARGUMENTS = 'ffmpeg_arguments' CONF_FFMPEG_ARGUMENTS = "ffmpeg_arguments"
CONF_CONTROL_LIGHT = 'control_light' CONF_CONTROL_LIGHT = "control_light"
DEFAULT_NAME = 'Amcrest Camera' DEFAULT_NAME = "Amcrest Camera"
DEFAULT_PORT = 80 DEFAULT_PORT = 80
DEFAULT_RESOLUTION = 'high' DEFAULT_RESOLUTION = "high"
DEFAULT_ARGUMENTS = '-pred 1' DEFAULT_ARGUMENTS = "-pred 1"
MAX_ERRORS = 5 MAX_ERRORS = 5
RECHECK_INTERVAL = timedelta(minutes=1) RECHECK_INTERVAL = timedelta(minutes=1)
NOTIFICATION_ID = 'amcrest_notification' NOTIFICATION_ID = "amcrest_notification"
NOTIFICATION_TITLE = 'Amcrest Camera Setup' NOTIFICATION_TITLE = "Amcrest Camera Setup"
RESOLUTION_LIST = { RESOLUTION_LIST = {"high": 0, "low": 1}
'high': 0,
'low': 1,
}
SCAN_INTERVAL = timedelta(seconds=10) SCAN_INTERVAL = timedelta(seconds=10)
AUTHENTICATION_LIST = { AUTHENTICATION_LIST = {"basic": "basic"}
'basic': 'basic'
}
def _deprecated_sensor_values(sensors): def _deprecated_sensor_values(sensors):
@ -66,8 +71,11 @@ def _deprecated_sensor_values(sensors):
"The '%s' option value '%s' is deprecated, " "The '%s' option value '%s' is deprecated, "
"please remove it from your configuration and use " "please remove it from your configuration and use "
"the '%s' option with value '%s' instead", "the '%s' option with value '%s' instead",
CONF_SENSORS, SENSOR_MOTION_DETECTOR, CONF_BINARY_SENSORS, CONF_SENSORS,
BINARY_SENSOR_MOTION_DETECTED) SENSOR_MOTION_DETECTOR,
CONF_BINARY_SENSORS,
BINARY_SENSOR_MOTION_DETECTED,
)
return sensors return sensors
@ -77,7 +85,9 @@ def _deprecated_switches(config):
"The '%s' option (with value %s) is deprecated, " "The '%s' option (with value %s) is deprecated, "
"please remove it from your configuration and use " "please remove it from your configuration and use "
"services and attributes instead", "services and attributes instead",
CONF_SWITCHES, config[CONF_SWITCHES]) CONF_SWITCHES,
config[CONF_SWITCHES],
)
return config return config
@ -88,37 +98,41 @@ def _has_unique_names(devices):
AMCREST_SCHEMA = vol.All( AMCREST_SCHEMA = vol.All(
vol.Schema({ vol.Schema(
vol.Required(CONF_HOST): cv.string, {
vol.Required(CONF_USERNAME): cv.string, vol.Required(CONF_HOST): cv.string,
vol.Required(CONF_PASSWORD): cv.string, vol.Required(CONF_USERNAME): cv.string,
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Required(CONF_PASSWORD): cv.string,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_AUTHENTICATION, default=HTTP_BASIC_AUTHENTICATION): vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.All(vol.In(AUTHENTICATION_LIST)), vol.Optional(
vol.Optional(CONF_RESOLUTION, default=DEFAULT_RESOLUTION): CONF_AUTHENTICATION, default=HTTP_BASIC_AUTHENTICATION
vol.All(vol.In(RESOLUTION_LIST)), ): vol.All(vol.In(AUTHENTICATION_LIST)),
vol.Optional(CONF_STREAM_SOURCE, default=STREAM_SOURCE_LIST[0]): vol.Optional(CONF_RESOLUTION, default=DEFAULT_RESOLUTION): vol.All(
vol.All(vol.In(STREAM_SOURCE_LIST)), vol.In(RESOLUTION_LIST)
vol.Optional(CONF_FFMPEG_ARGUMENTS, default=DEFAULT_ARGUMENTS): ),
cv.string, vol.Optional(CONF_STREAM_SOURCE, default=STREAM_SOURCE_LIST[0]): vol.All(
vol.Optional(CONF_SCAN_INTERVAL, default=SCAN_INTERVAL): vol.In(STREAM_SOURCE_LIST)
cv.time_period, ),
vol.Optional(CONF_BINARY_SENSORS): vol.Optional(CONF_FFMPEG_ARGUMENTS, default=DEFAULT_ARGUMENTS): cv.string,
vol.All(cv.ensure_list, [vol.In(BINARY_SENSORS)]), vol.Optional(CONF_SCAN_INTERVAL, default=SCAN_INTERVAL): cv.time_period,
vol.Optional(CONF_SENSORS): vol.Optional(CONF_BINARY_SENSORS): vol.All(
vol.All(cv.ensure_list, [vol.In(SENSORS)], cv.ensure_list, [vol.In(BINARY_SENSORS)]
_deprecated_sensor_values), ),
vol.Optional(CONF_SWITCHES): vol.Optional(CONF_SENSORS): vol.All(
vol.All(cv.ensure_list, [vol.In(SWITCHES)]), cv.ensure_list, [vol.In(SENSORS)], _deprecated_sensor_values
vol.Optional(CONF_CONTROL_LIGHT, default=True): cv.boolean, ),
}), vol.Optional(CONF_SWITCHES): vol.All(cv.ensure_list, [vol.In(SWITCHES)]),
_deprecated_switches vol.Optional(CONF_CONTROL_LIGHT, default=True): cv.boolean,
}
),
_deprecated_switches,
) )
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema(
DOMAIN: vol.All(cv.ensure_list, [AMCREST_SCHEMA], _has_unique_names) {DOMAIN: vol.All(cv.ensure_list, [AMCREST_SCHEMA], _has_unique_names)},
}, extra=vol.ALLOW_EXTRA) extra=vol.ALLOW_EXTRA,
)
# pylint: disable=too-many-ancestors # pylint: disable=too-many-ancestors
@ -132,8 +146,9 @@ class AmcrestChecker(Http):
self._wrap_errors = 0 self._wrap_errors = 0
self._wrap_lock = threading.Lock() self._wrap_lock = threading.Lock()
self._unsub_recheck = None self._unsub_recheck = None
super().__init__(host, port, user, password, retries_connection=1, super().__init__(
timeout_protocol=3.05) host, port, user, password, retries_connection=1, timeout_protocol=3.05
)
@property @property
def available(self): def available(self):
@ -148,17 +163,16 @@ class AmcrestChecker(Http):
with self._wrap_lock: with self._wrap_lock:
was_online = self.available was_online = self.available
self._wrap_errors += 1 self._wrap_errors += 1
_LOGGER.debug('%s camera errs: %i', self._wrap_name, _LOGGER.debug("%s camera errs: %i", self._wrap_name, self._wrap_errors)
self._wrap_errors)
offline = not self.available offline = not self.available
if offline and was_online: if offline and was_online:
_LOGGER.error( _LOGGER.error("%s camera offline: Too many errors", self._wrap_name)
'%s camera offline: Too many errors', self._wrap_name)
dispatcher_send( dispatcher_send(
self._hass, self._hass, service_signal(SERVICE_UPDATE, self._wrap_name)
service_signal(SERVICE_UPDATE, self._wrap_name)) )
self._unsub_recheck = track_time_interval( self._unsub_recheck = track_time_interval(
self._hass, self._wrap_test_online, RECHECK_INTERVAL) self._hass, self._wrap_test_online, RECHECK_INTERVAL
)
raise raise
with self._wrap_lock: with self._wrap_lock:
was_offline = not self.available was_offline = not self.available
@ -166,9 +180,8 @@ class AmcrestChecker(Http):
if was_offline: if was_offline:
self._unsub_recheck() self._unsub_recheck()
self._unsub_recheck = None self._unsub_recheck = None
_LOGGER.error('%s camera back online', self._wrap_name) _LOGGER.error("%s camera back online", self._wrap_name)
dispatcher_send( dispatcher_send(self._hass, service_signal(SERVICE_UPDATE, self._wrap_name))
self._hass, service_signal(SERVICE_UPDATE, self._wrap_name))
return ret return ret
def _wrap_test_online(self, now): def _wrap_test_online(self, now):
@ -190,9 +203,8 @@ def setup(hass, config):
try: try:
api = AmcrestChecker( api = AmcrestChecker(
hass, name, hass, name, device[CONF_HOST], device[CONF_PORT], username, password
device[CONF_HOST], device[CONF_PORT], )
username, password)
except LoginError as ex: except LoginError as ex:
_LOGGER.error("Login error for %s camera: %s", name, ex) _LOGGER.error("Login error for %s camera: %s", name, ex)
@ -214,41 +226,40 @@ def setup(hass, config):
authentication = None authentication = None
hass.data[DATA_AMCREST][DEVICES][name] = AmcrestDevice( hass.data[DATA_AMCREST][DEVICES][name] = AmcrestDevice(
api, authentication, ffmpeg_arguments, stream_source, api,
resolution, control_light) authentication,
ffmpeg_arguments,
stream_source,
resolution,
control_light,
)
discovery.load_platform( discovery.load_platform(hass, CAMERA, DOMAIN, {CONF_NAME: name}, config)
hass, CAMERA, DOMAIN, {
CONF_NAME: name,
}, config)
if binary_sensors: if binary_sensors:
discovery.load_platform( discovery.load_platform(
hass, BINARY_SENSOR, DOMAIN, { hass,
CONF_NAME: name, BINARY_SENSOR,
CONF_BINARY_SENSORS: binary_sensors DOMAIN,
}, config) {CONF_NAME: name, CONF_BINARY_SENSORS: binary_sensors},
config,
)
if sensors: if sensors:
discovery.load_platform( discovery.load_platform(
hass, SENSOR, DOMAIN, { hass, SENSOR, DOMAIN, {CONF_NAME: name, CONF_SENSORS: sensors}, config
CONF_NAME: name, )
CONF_SENSORS: sensors,
}, config)
if switches: if switches:
discovery.load_platform( discovery.load_platform(
hass, SWITCH, DOMAIN, { hass, SWITCH, DOMAIN, {CONF_NAME: name, CONF_SWITCHES: switches}, config
CONF_NAME: name, )
CONF_SWITCHES: switches
}, config)
if not hass.data[DATA_AMCREST][DEVICES]: if not hass.data[DATA_AMCREST][DEVICES]:
return False return False
def have_permission(user, entity_id): def have_permission(user, entity_id):
return not user or user.permissions.check_entity( return not user or user.permissions.check_entity(entity_id, POLICY_CONTROL)
entity_id, POLICY_CONTROL)
async def async_extract_from_service(call): async def async_extract_from_service(call):
if call.context.user_id: if call.context.user_id:
@ -261,7 +272,8 @@ def setup(hass, config):
if call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL: if call.data.get(ATTR_ENTITY_ID) == ENTITY_MATCH_ALL:
# Return all entity_ids user has permission to control. # Return all entity_ids user has permission to control.
return [ return [
entity_id for entity_id in hass.data[DATA_AMCREST][CAMERAS] entity_id
for entity_id in hass.data[DATA_AMCREST][CAMERAS]
if have_permission(user, entity_id) if have_permission(user, entity_id)
] ]
@ -272,9 +284,7 @@ def setup(hass, config):
continue continue
if not have_permission(user, entity_id): if not have_permission(user, entity_id):
raise Unauthorized( raise Unauthorized(
context=call.context, context=call.context, entity_id=entity_id, permission=POLICY_CONTROL
entity_id=entity_id,
permission=POLICY_CONTROL
) )
entity_ids.append(entity_id) entity_ids.append(entity_id)
return entity_ids return entity_ids
@ -284,15 +294,10 @@ def setup(hass, config):
for arg in CAMERA_SERVICES[call.service][2]: for arg in CAMERA_SERVICES[call.service][2]:
args.append(call.data[arg]) args.append(call.data[arg])
for entity_id in await async_extract_from_service(call): for entity_id in await async_extract_from_service(call):
async_dispatcher_send( async_dispatcher_send(hass, service_signal(call.service, entity_id), *args)
hass,
service_signal(call.service, entity_id),
*args
)
for service, params in CAMERA_SERVICES.items(): for service, params in CAMERA_SERVICES.items():
hass.services.async_register( hass.services.async_register(DOMAIN, service, async_service_handler, params[0])
DOMAIN, service, async_service_handler, params[0])
return True return True
@ -300,8 +305,15 @@ def setup(hass, config):
class AmcrestDevice: class AmcrestDevice:
"""Representation of a base Amcrest discovery device.""" """Representation of a base Amcrest discovery device."""
def __init__(self, api, authentication, ffmpeg_arguments, def __init__(
stream_source, resolution, control_light): self,
api,
authentication,
ffmpeg_arguments,
stream_source,
resolution,
control_light,
):
"""Initialize the entity.""" """Initialize the entity."""
self.api = api self.api = api
self.authentication = authentication self.authentication = authentication

View file

@ -5,29 +5,35 @@ import logging
from amcrest import AmcrestError from amcrest import AmcrestError
from homeassistant.components.binary_sensor import ( from homeassistant.components.binary_sensor import (
BinarySensorDevice, DEVICE_CLASS_CONNECTIVITY, DEVICE_CLASS_MOTION) BinarySensorDevice,
DEVICE_CLASS_CONNECTIVITY,
DEVICE_CLASS_MOTION,
)
from homeassistant.const import CONF_NAME, CONF_BINARY_SENSORS from homeassistant.const import CONF_NAME, CONF_BINARY_SENSORS
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from .const import ( from .const import (
BINARY_SENSOR_SCAN_INTERVAL_SECS, DATA_AMCREST, DEVICES, SERVICE_UPDATE) BINARY_SENSOR_SCAN_INTERVAL_SECS,
DATA_AMCREST,
DEVICES,
SERVICE_UPDATE,
)
from .helpers import log_update_error, service_signal from .helpers import log_update_error, service_signal
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SCAN_INTERVAL = timedelta(seconds=BINARY_SENSOR_SCAN_INTERVAL_SECS) SCAN_INTERVAL = timedelta(seconds=BINARY_SENSOR_SCAN_INTERVAL_SECS)
BINARY_SENSOR_MOTION_DETECTED = 'motion_detected' BINARY_SENSOR_MOTION_DETECTED = "motion_detected"
BINARY_SENSOR_ONLINE = 'online' BINARY_SENSOR_ONLINE = "online"
# Binary sensor types are defined like: Name, device class # Binary sensor types are defined like: Name, device class
BINARY_SENSORS = { BINARY_SENSORS = {
BINARY_SENSOR_MOTION_DETECTED: ('Motion Detected', DEVICE_CLASS_MOTION), BINARY_SENSOR_MOTION_DETECTED: ("Motion Detected", DEVICE_CLASS_MOTION),
BINARY_SENSOR_ONLINE: ('Online', DEVICE_CLASS_CONNECTIVITY), BINARY_SENSOR_ONLINE: ("Online", DEVICE_CLASS_CONNECTIVITY),
} }
async def async_setup_platform(hass, config, async_add_entities, async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
discovery_info=None):
"""Set up a binary sensor for an Amcrest IP Camera.""" """Set up a binary sensor for an Amcrest IP Camera."""
if discovery_info is None: if discovery_info is None:
return return
@ -35,9 +41,12 @@ async def async_setup_platform(hass, config, async_add_entities,
name = discovery_info[CONF_NAME] name = discovery_info[CONF_NAME]
device = hass.data[DATA_AMCREST][DEVICES][name] device = hass.data[DATA_AMCREST][DEVICES][name]
async_add_entities( async_add_entities(
[AmcrestBinarySensor(name, device, sensor_type) [
for sensor_type in discovery_info[CONF_BINARY_SENSORS]], AmcrestBinarySensor(name, device, sensor_type)
True) for sensor_type in discovery_info[CONF_BINARY_SENSORS]
],
True,
)
class AmcrestBinarySensor(BinarySensorDevice): class AmcrestBinarySensor(BinarySensorDevice):
@ -45,7 +54,7 @@ class AmcrestBinarySensor(BinarySensorDevice):
def __init__(self, name, device, sensor_type): def __init__(self, name, device, sensor_type):
"""Initialize entity.""" """Initialize entity."""
self._name = '{} {}'.format(name, BINARY_SENSORS[sensor_type][0]) self._name = "{} {}".format(name, BINARY_SENSORS[sensor_type][0])
self._signal_name = name self._signal_name = name
self._api = device.api self._api = device.api
self._sensor_type = sensor_type self._sensor_type = sensor_type
@ -82,7 +91,7 @@ class AmcrestBinarySensor(BinarySensorDevice):
"""Update entity.""" """Update entity."""
if not self.available: if not self.available:
return return
_LOGGER.debug('Updating %s binary sensor', self._name) _LOGGER.debug("Updating %s binary sensor", self._name)
try: try:
if self._sensor_type == BINARY_SENSOR_MOTION_DETECTED: if self._sensor_type == BINARY_SENSOR_MOTION_DETECTED:
@ -91,8 +100,7 @@ class AmcrestBinarySensor(BinarySensorDevice):
elif self._sensor_type == BINARY_SENSOR_ONLINE: elif self._sensor_type == BINARY_SENSOR_ONLINE:
self._state = self._api.available self._state = self._api.available
except AmcrestError as error: except AmcrestError as error:
log_update_error( log_update_error(_LOGGER, "update", self.name, "binary sensor", error)
_LOGGER, 'update', self.name, 'binary sensor', error)
async def async_on_demand_update(self): async def async_on_demand_update(self):
"""Update state.""" """Update state."""
@ -101,8 +109,10 @@ class AmcrestBinarySensor(BinarySensorDevice):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Subscribe to update signal.""" """Subscribe to update signal."""
self._unsub_dispatcher = async_dispatcher_connect( self._unsub_dispatcher = async_dispatcher_connect(
self.hass, service_signal(SERVICE_UPDATE, self._signal_name), self.hass,
self.async_on_demand_update) service_signal(SERVICE_UPDATE, self._signal_name),
self.async_on_demand_update,
)
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
"""Disconnect from update signal.""" """Disconnect from update signal."""

View file

@ -8,83 +8,85 @@ from amcrest import AmcrestError
import voluptuous as vol import voluptuous as vol
from homeassistant.components.camera import ( from homeassistant.components.camera import (
Camera, CAMERA_SERVICE_SCHEMA, SUPPORT_ON_OFF, SUPPORT_STREAM) Camera,
CAMERA_SERVICE_SCHEMA,
SUPPORT_ON_OFF,
SUPPORT_STREAM,
)
from homeassistant.components.ffmpeg import DATA_FFMPEG from homeassistant.components.ffmpeg import DATA_FFMPEG
from homeassistant.const import ( from homeassistant.const import CONF_NAME, STATE_ON, STATE_OFF
CONF_NAME, STATE_ON, STATE_OFF)
from homeassistant.helpers.aiohttp_client import ( from homeassistant.helpers.aiohttp_client import (
async_aiohttp_proxy_stream, async_aiohttp_proxy_web, async_aiohttp_proxy_stream,
async_get_clientsession) async_aiohttp_proxy_web,
async_get_clientsession,
)
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from .const import ( from .const import (
CAMERA_WEB_SESSION_TIMEOUT, CAMERAS, DATA_AMCREST, DEVICES, SERVICE_UPDATE) CAMERA_WEB_SESSION_TIMEOUT,
CAMERAS,
DATA_AMCREST,
DEVICES,
SERVICE_UPDATE,
)
from .helpers import log_update_error, service_signal from .helpers import log_update_error, service_signal
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SCAN_INTERVAL = timedelta(seconds=15) SCAN_INTERVAL = timedelta(seconds=15)
STREAM_SOURCE_LIST = [ STREAM_SOURCE_LIST = ["snapshot", "mjpeg", "rtsp"]
'snapshot',
'mjpeg',
'rtsp',
]
_SRV_EN_REC = 'enable_recording' _SRV_EN_REC = "enable_recording"
_SRV_DS_REC = 'disable_recording' _SRV_DS_REC = "disable_recording"
_SRV_EN_AUD = 'enable_audio' _SRV_EN_AUD = "enable_audio"
_SRV_DS_AUD = 'disable_audio' _SRV_DS_AUD = "disable_audio"
_SRV_EN_MOT_REC = 'enable_motion_recording' _SRV_EN_MOT_REC = "enable_motion_recording"
_SRV_DS_MOT_REC = 'disable_motion_recording' _SRV_DS_MOT_REC = "disable_motion_recording"
_SRV_GOTO = 'goto_preset' _SRV_GOTO = "goto_preset"
_SRV_CBW = 'set_color_bw' _SRV_CBW = "set_color_bw"
_SRV_TOUR_ON = 'start_tour' _SRV_TOUR_ON = "start_tour"
_SRV_TOUR_OFF = 'stop_tour' _SRV_TOUR_OFF = "stop_tour"
_ATTR_PRESET = 'preset' _ATTR_PRESET = "preset"
_ATTR_COLOR_BW = 'color_bw' _ATTR_COLOR_BW = "color_bw"
_CBW_COLOR = 'color' _CBW_COLOR = "color"
_CBW_AUTO = 'auto' _CBW_AUTO = "auto"
_CBW_BW = 'bw' _CBW_BW = "bw"
_CBW = [_CBW_COLOR, _CBW_AUTO, _CBW_BW] _CBW = [_CBW_COLOR, _CBW_AUTO, _CBW_BW]
_SRV_GOTO_SCHEMA = CAMERA_SERVICE_SCHEMA.extend({ _SRV_GOTO_SCHEMA = CAMERA_SERVICE_SCHEMA.extend(
vol.Required(_ATTR_PRESET): vol.All(vol.Coerce(int), vol.Range(min=1)), {vol.Required(_ATTR_PRESET): vol.All(vol.Coerce(int), vol.Range(min=1))}
}) )
_SRV_CBW_SCHEMA = CAMERA_SERVICE_SCHEMA.extend({ _SRV_CBW_SCHEMA = CAMERA_SERVICE_SCHEMA.extend(
vol.Required(_ATTR_COLOR_BW): vol.In(_CBW), {vol.Required(_ATTR_COLOR_BW): vol.In(_CBW)}
}) )
CAMERA_SERVICES = { CAMERA_SERVICES = {
_SRV_EN_REC: (CAMERA_SERVICE_SCHEMA, 'async_enable_recording', ()), _SRV_EN_REC: (CAMERA_SERVICE_SCHEMA, "async_enable_recording", ()),
_SRV_DS_REC: (CAMERA_SERVICE_SCHEMA, 'async_disable_recording', ()), _SRV_DS_REC: (CAMERA_SERVICE_SCHEMA, "async_disable_recording", ()),
_SRV_EN_AUD: (CAMERA_SERVICE_SCHEMA, 'async_enable_audio', ()), _SRV_EN_AUD: (CAMERA_SERVICE_SCHEMA, "async_enable_audio", ()),
_SRV_DS_AUD: (CAMERA_SERVICE_SCHEMA, 'async_disable_audio', ()), _SRV_DS_AUD: (CAMERA_SERVICE_SCHEMA, "async_disable_audio", ()),
_SRV_EN_MOT_REC: ( _SRV_EN_MOT_REC: (CAMERA_SERVICE_SCHEMA, "async_enable_motion_recording", ()),
CAMERA_SERVICE_SCHEMA, 'async_enable_motion_recording', ()), _SRV_DS_MOT_REC: (CAMERA_SERVICE_SCHEMA, "async_disable_motion_recording", ()),
_SRV_DS_MOT_REC: ( _SRV_GOTO: (_SRV_GOTO_SCHEMA, "async_goto_preset", (_ATTR_PRESET,)),
CAMERA_SERVICE_SCHEMA, 'async_disable_motion_recording', ()), _SRV_CBW: (_SRV_CBW_SCHEMA, "async_set_color_bw", (_ATTR_COLOR_BW,)),
_SRV_GOTO: (_SRV_GOTO_SCHEMA, 'async_goto_preset', (_ATTR_PRESET,)), _SRV_TOUR_ON: (CAMERA_SERVICE_SCHEMA, "async_start_tour", ()),
_SRV_CBW: (_SRV_CBW_SCHEMA, 'async_set_color_bw', (_ATTR_COLOR_BW,)), _SRV_TOUR_OFF: (CAMERA_SERVICE_SCHEMA, "async_stop_tour", ()),
_SRV_TOUR_ON: (CAMERA_SERVICE_SCHEMA, 'async_start_tour', ()),
_SRV_TOUR_OFF: (CAMERA_SERVICE_SCHEMA, 'async_stop_tour', ()),
} }
_BOOL_TO_STATE = {True: STATE_ON, False: STATE_OFF} _BOOL_TO_STATE = {True: STATE_ON, False: STATE_OFF}
async def async_setup_platform(hass, config, async_add_entities, async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
discovery_info=None):
"""Set up an Amcrest IP Camera.""" """Set up an Amcrest IP Camera."""
if discovery_info is None: if discovery_info is None:
return return
name = discovery_info[CONF_NAME] name = discovery_info[CONF_NAME]
device = hass.data[DATA_AMCREST][DEVICES][name] device = hass.data[DATA_AMCREST][DEVICES][name]
async_add_entities([ async_add_entities([AmcrestCam(name, device, hass.data[DATA_FFMPEG])], True)
AmcrestCam(name, device, hass.data[DATA_FFMPEG])], True)
class AmcrestCam(Camera): class AmcrestCam(Camera):
@ -118,56 +120,59 @@ class AmcrestCam(Camera):
available = self.available available = self.available
if not available or not self.is_on: if not available or not self.is_on:
_LOGGER.warning( _LOGGER.warning(
'Attempt to take snaphot when %s camera is %s', self.name, "Attempt to take snaphot when %s camera is %s",
'offline' if not available else 'off') self.name,
"offline" if not available else "off",
)
return None return None
async with self._snapshot_lock: async with self._snapshot_lock:
try: try:
# Send the request to snap a picture and return raw jpg data # Send the request to snap a picture and return raw jpg data
response = await self.hass.async_add_executor_job( response = await self.hass.async_add_executor_job(self._api.snapshot)
self._api.snapshot)
return response.data return response.data
except (AmcrestError, HTTPError) as error: except (AmcrestError, HTTPError) as error:
log_update_error( log_update_error(_LOGGER, "get image from", self.name, "camera", error)
_LOGGER, 'get image from', self.name, 'camera', error)
return None return None
async def handle_async_mjpeg_stream(self, request): async def handle_async_mjpeg_stream(self, request):
"""Return an MJPEG stream.""" """Return an MJPEG stream."""
# The snapshot implementation is handled by the parent class # The snapshot implementation is handled by the parent class
if self._stream_source == 'snapshot': if self._stream_source == "snapshot":
return await super().handle_async_mjpeg_stream(request) return await super().handle_async_mjpeg_stream(request)
if not self.available: if not self.available:
_LOGGER.warning( _LOGGER.warning(
'Attempt to stream %s when %s camera is offline', "Attempt to stream %s when %s camera is offline",
self._stream_source, self.name) self._stream_source,
self.name,
)
return None return None
if self._stream_source == 'mjpeg': if self._stream_source == "mjpeg":
# stream an MJPEG image stream directly from the camera # stream an MJPEG image stream directly from the camera
websession = async_get_clientsession(self.hass) websession = async_get_clientsession(self.hass)
streaming_url = self._api.mjpeg_url(typeno=self._resolution) streaming_url = self._api.mjpeg_url(typeno=self._resolution)
stream_coro = websession.get( stream_coro = websession.get(
streaming_url, auth=self._token, streaming_url, auth=self._token, timeout=CAMERA_WEB_SESSION_TIMEOUT
timeout=CAMERA_WEB_SESSION_TIMEOUT) )
return await async_aiohttp_proxy_web( return await async_aiohttp_proxy_web(self.hass, request, stream_coro)
self.hass, request, stream_coro)
# streaming via ffmpeg # streaming via ffmpeg
from haffmpeg.camera import CameraMjpeg from haffmpeg.camera import CameraMjpeg
streaming_url = self._rtsp_url streaming_url = self._rtsp_url
stream = CameraMjpeg(self._ffmpeg.binary, loop=self.hass.loop) stream = CameraMjpeg(self._ffmpeg.binary, loop=self.hass.loop)
await stream.open_camera( await stream.open_camera(streaming_url, extra_cmd=self._ffmpeg_arguments)
streaming_url, extra_cmd=self._ffmpeg_arguments)
try: try:
stream_reader = await stream.get_reader() stream_reader = await stream.get_reader()
return await async_aiohttp_proxy_stream( return await async_aiohttp_proxy_stream(
self.hass, request, stream_reader, self.hass,
self._ffmpeg.ffmpeg_stream_content_type) request,
stream_reader,
self._ffmpeg.ffmpeg_stream_content_type,
)
finally: finally:
await stream.close() await stream.close()
@ -191,10 +196,11 @@ class AmcrestCam(Camera):
"""Return the Amcrest-specific camera state attributes.""" """Return the Amcrest-specific camera state attributes."""
attr = {} attr = {}
if self._audio_enabled is not None: if self._audio_enabled is not None:
attr['audio'] = _BOOL_TO_STATE.get(self._audio_enabled) attr["audio"] = _BOOL_TO_STATE.get(self._audio_enabled)
if self._motion_recording_enabled is not None: if self._motion_recording_enabled is not None:
attr['motion_recording'] = _BOOL_TO_STATE.get( attr["motion_recording"] = _BOOL_TO_STATE.get(
self._motion_recording_enabled) self._motion_recording_enabled
)
if self._color_bw is not None: if self._color_bw is not None:
attr[_ATTR_COLOR_BW] = self._color_bw attr[_ATTR_COLOR_BW] = self._color_bw
return attr return attr
@ -249,13 +255,20 @@ class AmcrestCam(Camera):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Subscribe to signals and add camera to list.""" """Subscribe to signals and add camera to list."""
for service, params in CAMERA_SERVICES.items(): for service, params in CAMERA_SERVICES.items():
self._unsub_dispatcher.append(async_dispatcher_connect( self._unsub_dispatcher.append(
async_dispatcher_connect(
self.hass,
service_signal(service, self.entity_id),
getattr(self, params[1]),
)
)
self._unsub_dispatcher.append(
async_dispatcher_connect(
self.hass, self.hass,
service_signal(service, self.entity_id), service_signal(SERVICE_UPDATE, self._name),
getattr(self, params[1]))) self.async_on_demand_update,
self._unsub_dispatcher.append(async_dispatcher_connect( )
self.hass, service_signal(SERVICE_UPDATE, self._name), )
self.async_on_demand_update))
self.hass.data[DATA_AMCREST][CAMERAS].append(self.entity_id) self.hass.data[DATA_AMCREST][CAMERAS].append(self.entity_id)
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
@ -270,32 +283,29 @@ class AmcrestCam(Camera):
if not self.available: if not self.available:
self._update_succeeded = False self._update_succeeded = False
return return
_LOGGER.debug('Updating %s camera', self.name) _LOGGER.debug("Updating %s camera", self.name)
try: try:
if self._brand is None: if self._brand is None:
resp = self._api.vendor_information.strip() resp = self._api.vendor_information.strip()
if resp.startswith('vendor='): if resp.startswith("vendor="):
self._brand = resp.split('=')[-1] self._brand = resp.split("=")[-1]
else: else:
self._brand = 'unknown' self._brand = "unknown"
if self._model is None: if self._model is None:
resp = self._api.device_type.strip() resp = self._api.device_type.strip()
if resp.startswith('type='): if resp.startswith("type="):
self._model = resp.split('=')[-1] self._model = resp.split("=")[-1]
else: else:
self._model = 'unknown' self._model = "unknown"
self.is_streaming = self._api.video_enabled self.is_streaming = self._api.video_enabled
self._is_recording = self._api.record_mode == 'Manual' self._is_recording = self._api.record_mode == "Manual"
self._motion_detection_enabled = ( self._motion_detection_enabled = self._api.is_motion_detector_on()
self._api.is_motion_detector_on())
self._audio_enabled = self._api.audio_enabled self._audio_enabled = self._api.audio_enabled
self._motion_recording_enabled = ( self._motion_recording_enabled = self._api.is_record_on_motion_detection()
self._api.is_record_on_motion_detection())
self._color_bw = _CBW[self._api.day_night_color] self._color_bw = _CBW[self._api.day_night_color]
self._rtsp_url = self._api.rtsp_url(typeno=self._resolution) self._rtsp_url = self._api.rtsp_url(typeno=self._resolution)
except AmcrestError as error: except AmcrestError as error:
log_update_error( log_update_error(_LOGGER, "get", self.name, "camera attributes", error)
_LOGGER, 'get', self.name, 'camera attributes', error)
self._update_succeeded = False self._update_succeeded = False
else: else:
self._update_succeeded = True self._update_succeeded = True
@ -338,13 +348,11 @@ class AmcrestCam(Camera):
async def async_enable_motion_recording(self): async def async_enable_motion_recording(self):
"""Call the job and enable motion recording.""" """Call the job and enable motion recording."""
await self.hass.async_add_executor_job(self._enable_motion_recording, await self.hass.async_add_executor_job(self._enable_motion_recording, True)
True)
async def async_disable_motion_recording(self): async def async_disable_motion_recording(self):
"""Call the job and disable motion recording.""" """Call the job and disable motion recording."""
await self.hass.async_add_executor_job(self._enable_motion_recording, await self.hass.async_add_executor_job(self._enable_motion_recording, False)
False)
async def async_goto_preset(self, preset): async def async_goto_preset(self, preset):
"""Call the job and move camera to preset position.""" """Call the job and move camera to preset position."""
@ -375,8 +383,12 @@ class AmcrestCam(Camera):
self._api.video_enabled = enable self._api.video_enabled = enable
except AmcrestError as error: except AmcrestError as error:
log_update_error( log_update_error(
_LOGGER, 'enable' if enable else 'disable', self.name, _LOGGER,
'camera video stream', error) "enable" if enable else "disable",
self.name,
"camera video stream",
error,
)
else: else:
self.is_streaming = enable self.is_streaming = enable
self.schedule_update_ha_state() self.schedule_update_ha_state()
@ -390,14 +402,17 @@ class AmcrestCam(Camera):
# video stream off if recording is being turned on. # video stream off if recording is being turned on.
if not self.is_streaming and enable: if not self.is_streaming and enable:
self._enable_video_stream(True) self._enable_video_stream(True)
rec_mode = {'Automatic': 0, 'Manual': 1} rec_mode = {"Automatic": 0, "Manual": 1}
try: try:
self._api.record_mode = rec_mode[ self._api.record_mode = rec_mode["Manual" if enable else "Automatic"]
'Manual' if enable else 'Automatic']
except AmcrestError as error: except AmcrestError as error:
log_update_error( log_update_error(
_LOGGER, 'enable' if enable else 'disable', self.name, _LOGGER,
'camera recording', error) "enable" if enable else "disable",
self.name,
"camera recording",
error,
)
else: else:
self._is_recording = enable self._is_recording = enable
self.schedule_update_ha_state() self.schedule_update_ha_state()
@ -408,8 +423,12 @@ class AmcrestCam(Camera):
self._api.motion_detection = str(enable).lower() self._api.motion_detection = str(enable).lower()
except AmcrestError as error: except AmcrestError as error:
log_update_error( log_update_error(
_LOGGER, 'enable' if enable else 'disable', self.name, _LOGGER,
'camera motion detection', error) "enable" if enable else "disable",
self.name,
"camera motion detection",
error,
)
else: else:
self._motion_detection_enabled = enable self._motion_detection_enabled = enable
self.schedule_update_ha_state() self.schedule_update_ha_state()
@ -420,8 +439,12 @@ class AmcrestCam(Camera):
self._api.audio_enabled = enable self._api.audio_enabled = enable
except AmcrestError as error: except AmcrestError as error:
log_update_error( log_update_error(
_LOGGER, 'enable' if enable else 'disable', self.name, _LOGGER,
'camera audio stream', error) "enable" if enable else "disable",
self.name,
"camera audio stream",
error,
)
else: else:
self._audio_enabled = enable self._audio_enabled = enable
self.schedule_update_ha_state() self.schedule_update_ha_state()
@ -432,12 +455,18 @@ class AmcrestCam(Camera):
"""Enable or disable indicator light.""" """Enable or disable indicator light."""
try: try:
self._api.command( self._api.command(
'configManager.cgi?action=setConfig&LightGlobal[0].Enable={}' "configManager.cgi?action=setConfig&LightGlobal[0].Enable={}".format(
.format(str(enable).lower())) str(enable).lower()
)
)
except AmcrestError as error: except AmcrestError as error:
log_update_error( log_update_error(
_LOGGER, 'enable' if enable else 'disable', self.name, _LOGGER,
'indicator light', error) "enable" if enable else "disable",
self.name,
"indicator light",
error,
)
def _enable_motion_recording(self, enable): def _enable_motion_recording(self, enable):
"""Enable or disable motion recording.""" """Enable or disable motion recording."""
@ -445,8 +474,12 @@ class AmcrestCam(Camera):
self._api.motion_recording = str(enable).lower() self._api.motion_recording = str(enable).lower()
except AmcrestError as error: except AmcrestError as error:
log_update_error( log_update_error(
_LOGGER, 'enable' if enable else 'disable', self.name, _LOGGER,
'camera motion recording', error) "enable" if enable else "disable",
self.name,
"camera motion recording",
error,
)
else: else:
self._motion_recording_enabled = enable self._motion_recording_enabled = enable
self.schedule_update_ha_state() self.schedule_update_ha_state()
@ -454,12 +487,11 @@ class AmcrestCam(Camera):
def _goto_preset(self, preset): def _goto_preset(self, preset):
"""Move camera position and zoom to preset.""" """Move camera position and zoom to preset."""
try: try:
self._api.go_to_preset( self._api.go_to_preset(action="start", preset_point_number=preset)
action='start', preset_point_number=preset)
except AmcrestError as error: except AmcrestError as error:
log_update_error( log_update_error(
_LOGGER, 'move', self.name, _LOGGER, "move", self.name, "camera to preset {}".format(preset), error
'camera to preset {}'.format(preset), error) )
def _set_color_bw(self, cbw): def _set_color_bw(self, cbw):
"""Set camera color mode.""" """Set camera color mode."""
@ -467,8 +499,8 @@ class AmcrestCam(Camera):
self._api.day_night_color = _CBW.index(cbw) self._api.day_night_color = _CBW.index(cbw)
except AmcrestError as error: except AmcrestError as error:
log_update_error( log_update_error(
_LOGGER, 'set', self.name, _LOGGER, "set", self.name, "camera color mode to {}".format(cbw), error
'camera color mode to {}'.format(cbw), error) )
else: else:
self._color_bw = cbw self._color_bw = cbw
self.schedule_update_ha_state() self.schedule_update_ha_state()
@ -479,5 +511,5 @@ class AmcrestCam(Camera):
self._api.tour(start=start) self._api.tour(start=start)
except AmcrestError as error: except AmcrestError as error:
log_update_error( log_update_error(
_LOGGER, 'start' if start else 'stop', self.name, _LOGGER, "start" if start else "stop", self.name, "camera tour", error
'camera tour', error) )

View file

@ -1,11 +1,11 @@
"""Constants for amcrest component.""" """Constants for amcrest component."""
DOMAIN = 'amcrest' DOMAIN = "amcrest"
DATA_AMCREST = DOMAIN DATA_AMCREST = DOMAIN
CAMERAS = 'cameras' CAMERAS = "cameras"
DEVICES = 'devices' DEVICES = "devices"
BINARY_SENSOR_SCAN_INTERVAL_SECS = 5 BINARY_SENSOR_SCAN_INTERVAL_SECS = 5
CAMERA_WEB_SESSION_TIMEOUT = 10 CAMERA_WEB_SESSION_TIMEOUT = 10
SENSOR_SCAN_INTERVAL_SECS = 10 SENSOR_SCAN_INTERVAL_SECS = 10
SERVICE_UPDATE = 'update' SERVICE_UPDATE = "update"

View file

@ -4,14 +4,18 @@ from .const import DOMAIN
def service_signal(service, ident=None): def service_signal(service, ident=None):
"""Encode service and identifier into signal.""" """Encode service and identifier into signal."""
signal = '{}_{}'.format(DOMAIN, service) signal = "{}_{}".format(DOMAIN, service)
if ident: if ident:
signal += '_{}'.format(ident.replace('.', '_')) signal += "_{}".format(ident.replace(".", "_"))
return signal return signal
def log_update_error(logger, action, name, entity_type, error): def log_update_error(logger, action, name, entity_type, error):
"""Log an update error.""" """Log an update error."""
logger.error( logger.error(
'Could not %s %s %s due to error: %s', "Could not %s %s %s due to error: %s",
action, name, entity_type, error.__class__.__name__) action,
name,
entity_type,
error.__class__.__name__,
)

View file

@ -8,27 +8,25 @@ from homeassistant.const import CONF_NAME, CONF_SENSORS
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from .const import ( from .const import DATA_AMCREST, DEVICES, SENSOR_SCAN_INTERVAL_SECS, SERVICE_UPDATE
DATA_AMCREST, DEVICES, SENSOR_SCAN_INTERVAL_SECS, SERVICE_UPDATE)
from .helpers import log_update_error, service_signal from .helpers import log_update_error, service_signal
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SCAN_INTERVAL = timedelta(seconds=SENSOR_SCAN_INTERVAL_SECS) SCAN_INTERVAL = timedelta(seconds=SENSOR_SCAN_INTERVAL_SECS)
SENSOR_MOTION_DETECTOR = 'motion_detector' SENSOR_MOTION_DETECTOR = "motion_detector"
SENSOR_PTZ_PRESET = 'ptz_preset' SENSOR_PTZ_PRESET = "ptz_preset"
SENSOR_SDCARD = 'sdcard' SENSOR_SDCARD = "sdcard"
# Sensor types are defined like: Name, units, icon # Sensor types are defined like: Name, units, icon
SENSORS = { SENSORS = {
SENSOR_MOTION_DETECTOR: ['Motion Detected', None, 'mdi:run'], SENSOR_MOTION_DETECTOR: ["Motion Detected", None, "mdi:run"],
SENSOR_PTZ_PRESET: ['PTZ Preset', None, 'mdi:camera-iris'], SENSOR_PTZ_PRESET: ["PTZ Preset", None, "mdi:camera-iris"],
SENSOR_SDCARD: ['SD Used', '%', 'mdi:sd'], SENSOR_SDCARD: ["SD Used", "%", "mdi:sd"],
} }
async def async_setup_platform( async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
hass, config, async_add_entities, discovery_info=None):
"""Set up a sensor for an Amcrest IP Camera.""" """Set up a sensor for an Amcrest IP Camera."""
if discovery_info is None: if discovery_info is None:
return return
@ -36,9 +34,12 @@ async def async_setup_platform(
name = discovery_info[CONF_NAME] name = discovery_info[CONF_NAME]
device = hass.data[DATA_AMCREST][DEVICES][name] device = hass.data[DATA_AMCREST][DEVICES][name]
async_add_entities( async_add_entities(
[AmcrestSensor(name, device, sensor_type) [
for sensor_type in discovery_info[CONF_SENSORS]], AmcrestSensor(name, device, sensor_type)
True) for sensor_type in discovery_info[CONF_SENSORS]
],
True,
)
class AmcrestSensor(Entity): class AmcrestSensor(Entity):
@ -46,7 +47,7 @@ class AmcrestSensor(Entity):
def __init__(self, name, device, sensor_type): def __init__(self, name, device, sensor_type):
"""Initialize a sensor for Amcrest camera.""" """Initialize a sensor for Amcrest camera."""
self._name = '{} {}'.format(name, SENSORS[sensor_type][0]) self._name = "{} {}".format(name, SENSORS[sensor_type][0])
self._signal_name = name self._signal_name = name
self._api = device.api self._api = device.api
self._sensor_type = sensor_type self._sensor_type = sensor_type
@ -95,7 +96,7 @@ class AmcrestSensor(Entity):
try: try:
if self._sensor_type == SENSOR_MOTION_DETECTOR: if self._sensor_type == SENSOR_MOTION_DETECTOR:
self._state = self._api.is_motion_detected self._state = self._api.is_motion_detected
self._attrs['Record Mode'] = self._api.record_mode self._attrs["Record Mode"] = self._api.record_mode
elif self._sensor_type == SENSOR_PTZ_PRESET: elif self._sensor_type == SENSOR_PTZ_PRESET:
self._state = self._api.ptz_presets_count self._state = self._api.ptz_presets_count
@ -103,20 +104,19 @@ class AmcrestSensor(Entity):
elif self._sensor_type == SENSOR_SDCARD: elif self._sensor_type == SENSOR_SDCARD:
storage = self._api.storage_all storage = self._api.storage_all
try: try:
self._attrs['Total'] = '{:.2f} {}'.format( self._attrs["Total"] = "{:.2f} {}".format(*storage["total"])
*storage['total'])
except ValueError: except ValueError:
self._attrs['Total'] = '{} {}'.format(*storage['total']) self._attrs["Total"] = "{} {}".format(*storage["total"])
try: try:
self._attrs['Used'] = '{:.2f} {}'.format(*storage['used']) self._attrs["Used"] = "{:.2f} {}".format(*storage["used"])
except ValueError: except ValueError:
self._attrs['Used'] = '{} {}'.format(*storage['used']) self._attrs["Used"] = "{} {}".format(*storage["used"])
try: try:
self._state = '{:.2f}'.format(storage['used_percent']) self._state = "{:.2f}".format(storage["used_percent"])
except ValueError: except ValueError:
self._state = storage['used_percent'] self._state = storage["used_percent"]
except AmcrestError as error: except AmcrestError as error:
log_update_error(_LOGGER, 'update', self.name, 'sensor', error) log_update_error(_LOGGER, "update", self.name, "sensor", error)
async def async_on_demand_update(self): async def async_on_demand_update(self):
"""Update state.""" """Update state."""
@ -125,8 +125,10 @@ class AmcrestSensor(Entity):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Subscribe to update signal.""" """Subscribe to update signal."""
self._unsub_dispatcher = async_dispatcher_connect( self._unsub_dispatcher = async_dispatcher_connect(
self.hass, service_signal(SERVICE_UPDATE, self._signal_name), self.hass,
self.async_on_demand_update) service_signal(SERVICE_UPDATE, self._signal_name),
self.async_on_demand_update,
)
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
"""Disconnect from update signal.""" """Disconnect from update signal."""

View file

@ -12,17 +12,16 @@ from .helpers import log_update_error, service_signal
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
MOTION_DETECTION = 'motion_detection' MOTION_DETECTION = "motion_detection"
MOTION_RECORDING = 'motion_recording' MOTION_RECORDING = "motion_recording"
# Switch types are defined like: Name, icon # Switch types are defined like: Name, icon
SWITCHES = { SWITCHES = {
MOTION_DETECTION: ['Motion Detection', 'mdi:run-fast'], MOTION_DETECTION: ["Motion Detection", "mdi:run-fast"],
MOTION_RECORDING: ['Motion Recording', 'mdi:record-rec'] MOTION_RECORDING: ["Motion Recording", "mdi:record-rec"],
} }
async def async_setup_platform( async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
hass, config, async_add_entities, discovery_info=None):
"""Set up the IP Amcrest camera switch platform.""" """Set up the IP Amcrest camera switch platform."""
if discovery_info is None: if discovery_info is None:
return return
@ -30,9 +29,12 @@ async def async_setup_platform(
name = discovery_info[CONF_NAME] name = discovery_info[CONF_NAME]
device = hass.data[DATA_AMCREST][DEVICES][name] device = hass.data[DATA_AMCREST][DEVICES][name]
async_add_entities( async_add_entities(
[AmcrestSwitch(name, device, setting) [
for setting in discovery_info[CONF_SWITCHES]], AmcrestSwitch(name, device, setting)
True) for setting in discovery_info[CONF_SWITCHES]
],
True,
)
class AmcrestSwitch(ToggleEntity): class AmcrestSwitch(ToggleEntity):
@ -40,7 +42,7 @@ class AmcrestSwitch(ToggleEntity):
def __init__(self, name, device, setting): def __init__(self, name, device, setting):
"""Initialize the Amcrest switch.""" """Initialize the Amcrest switch."""
self._name = '{} {}'.format(name, SWITCHES[setting][0]) self._name = "{} {}".format(name, SWITCHES[setting][0])
self._signal_name = name self._signal_name = name
self._api = device.api self._api = device.api
self._setting = setting self._setting = setting
@ -64,11 +66,11 @@ class AmcrestSwitch(ToggleEntity):
return return
try: try:
if self._setting == MOTION_DETECTION: if self._setting == MOTION_DETECTION:
self._api.motion_detection = 'true' self._api.motion_detection = "true"
elif self._setting == MOTION_RECORDING: elif self._setting == MOTION_RECORDING:
self._api.motion_recording = 'true' self._api.motion_recording = "true"
except AmcrestError as error: except AmcrestError as error:
log_update_error(_LOGGER, 'turn on', self.name, 'switch', error) log_update_error(_LOGGER, "turn on", self.name, "switch", error)
def turn_off(self, **kwargs): def turn_off(self, **kwargs):
"""Turn setting off.""" """Turn setting off."""
@ -76,11 +78,11 @@ class AmcrestSwitch(ToggleEntity):
return return
try: try:
if self._setting == MOTION_DETECTION: if self._setting == MOTION_DETECTION:
self._api.motion_detection = 'false' self._api.motion_detection = "false"
elif self._setting == MOTION_RECORDING: elif self._setting == MOTION_RECORDING:
self._api.motion_recording = 'false' self._api.motion_recording = "false"
except AmcrestError as error: except AmcrestError as error:
log_update_error(_LOGGER, 'turn off', self.name, 'switch', error) log_update_error(_LOGGER, "turn off", self.name, "switch", error)
@property @property
def available(self): def available(self):
@ -100,7 +102,7 @@ class AmcrestSwitch(ToggleEntity):
detection = self._api.is_record_on_motion_detection() detection = self._api.is_record_on_motion_detection()
self._state = detection self._state = detection
except AmcrestError as error: except AmcrestError as error:
log_update_error(_LOGGER, 'update', self.name, 'switch', error) log_update_error(_LOGGER, "update", self.name, "switch", error)
@property @property
def icon(self): def icon(self):
@ -114,8 +116,10 @@ class AmcrestSwitch(ToggleEntity):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Subscribe to update signal.""" """Subscribe to update signal."""
self._unsub_dispatcher = async_dispatcher_connect( self._unsub_dispatcher = async_dispatcher_connect(
self.hass, service_signal(SERVICE_UPDATE, self._signal_name), self.hass,
self.async_on_demand_update) service_signal(SERVICE_UPDATE, self._signal_name),
self.async_on_demand_update,
)
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self):
"""Disconnect from update signal.""" """Disconnect from update signal."""

View file

@ -4,8 +4,7 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components.air_quality import ( from homeassistant.components.air_quality import PLATFORM_SCHEMA, AirQualityEntity
PLATFORM_SCHEMA, AirQualityEntity)
from homeassistant.const import CONF_NAME from homeassistant.const import CONF_NAME
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -13,18 +12,16 @@ from homeassistant.util import Throttle
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ATTRIBUTION = 'Data provided by Ampio' ATTRIBUTION = "Data provided by Ampio"
CONF_STATION_ID = 'station_id' CONF_STATION_ID = "station_id"
SCAN_INTERVAL = timedelta(minutes=10) SCAN_INTERVAL = timedelta(minutes=10)
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_STATION_ID): cv.string, {vol.Required(CONF_STATION_ID): cv.string, vol.Optional(CONF_NAME): cv.string}
vol.Optional(CONF_NAME): cv.string, )
})
async def async_setup_platform( async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
hass, config, async_add_entities, discovery_info=None):
"""Set up the Ampio Smog air quality platform.""" """Set up the Ampio Smog air quality platform."""
from asmog import AmpioSmog from asmog import AmpioSmog

View file

@ -7,140 +7,182 @@ import voluptuous as vol
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.const import ( from homeassistant.const import (
CONF_NAME, CONF_HOST, CONF_PORT, CONF_USERNAME, CONF_PASSWORD, CONF_NAME,
CONF_SENSORS, CONF_SWITCHES, CONF_TIMEOUT, CONF_SCAN_INTERVAL, CONF_HOST,
CONF_PLATFORM) CONF_PORT,
CONF_USERNAME,
CONF_PASSWORD,
CONF_SENSORS,
CONF_SWITCHES,
CONF_TIMEOUT,
CONF_SCAN_INTERVAL,
CONF_PLATFORM,
)
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers import discovery from homeassistant.helpers import discovery
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_send, async_dispatcher_connect) async_dispatcher_send,
async_dispatcher_connect,
)
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.event import async_track_point_in_utc_time from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from homeassistant.components.mjpeg.camera import ( from homeassistant.components.mjpeg.camera import CONF_MJPEG_URL, CONF_STILL_IMAGE_URL
CONF_MJPEG_URL, CONF_STILL_IMAGE_URL)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
ATTR_AUD_CONNS = 'Audio Connections' ATTR_AUD_CONNS = "Audio Connections"
ATTR_HOST = 'host' ATTR_HOST = "host"
ATTR_VID_CONNS = 'Video Connections' ATTR_VID_CONNS = "Video Connections"
CONF_MOTION_SENSOR = 'motion_sensor' CONF_MOTION_SENSOR = "motion_sensor"
DATA_IP_WEBCAM = 'android_ip_webcam' DATA_IP_WEBCAM = "android_ip_webcam"
DEFAULT_NAME = 'IP Webcam' DEFAULT_NAME = "IP Webcam"
DEFAULT_PORT = 8080 DEFAULT_PORT = 8080
DEFAULT_TIMEOUT = 10 DEFAULT_TIMEOUT = 10
DOMAIN = 'android_ip_webcam' DOMAIN = "android_ip_webcam"
SCAN_INTERVAL = timedelta(seconds=10) SCAN_INTERVAL = timedelta(seconds=10)
SIGNAL_UPDATE_DATA = 'android_ip_webcam_update' SIGNAL_UPDATE_DATA = "android_ip_webcam_update"
KEY_MAP = { KEY_MAP = {
'audio_connections': 'Audio Connections', "audio_connections": "Audio Connections",
'adet_limit': 'Audio Trigger Limit', "adet_limit": "Audio Trigger Limit",
'antibanding': 'Anti-banding', "antibanding": "Anti-banding",
'audio_only': 'Audio Only', "audio_only": "Audio Only",
'battery_level': 'Battery Level', "battery_level": "Battery Level",
'battery_temp': 'Battery Temperature', "battery_temp": "Battery Temperature",
'battery_voltage': 'Battery Voltage', "battery_voltage": "Battery Voltage",
'coloreffect': 'Color Effect', "coloreffect": "Color Effect",
'exposure': 'Exposure Level', "exposure": "Exposure Level",
'exposure_lock': 'Exposure Lock', "exposure_lock": "Exposure Lock",
'ffc': 'Front-facing Camera', "ffc": "Front-facing Camera",
'flashmode': 'Flash Mode', "flashmode": "Flash Mode",
'focus': 'Focus', "focus": "Focus",
'focus_homing': 'Focus Homing', "focus_homing": "Focus Homing",
'focus_region': 'Focus Region', "focus_region": "Focus Region",
'focusmode': 'Focus Mode', "focusmode": "Focus Mode",
'gps_active': 'GPS Active', "gps_active": "GPS Active",
'idle': 'Idle', "idle": "Idle",
'ip_address': 'IPv4 Address', "ip_address": "IPv4 Address",
'ipv6_address': 'IPv6 Address', "ipv6_address": "IPv6 Address",
'ivideon_streaming': 'Ivideon Streaming', "ivideon_streaming": "Ivideon Streaming",
'light': 'Light Level', "light": "Light Level",
'mirror_flip': 'Mirror Flip', "mirror_flip": "Mirror Flip",
'motion': 'Motion', "motion": "Motion",
'motion_active': 'Motion Active', "motion_active": "Motion Active",
'motion_detect': 'Motion Detection', "motion_detect": "Motion Detection",
'motion_event': 'Motion Event', "motion_event": "Motion Event",
'motion_limit': 'Motion Limit', "motion_limit": "Motion Limit",
'night_vision': 'Night Vision', "night_vision": "Night Vision",
'night_vision_average': 'Night Vision Average', "night_vision_average": "Night Vision Average",
'night_vision_gain': 'Night Vision Gain', "night_vision_gain": "Night Vision Gain",
'orientation': 'Orientation', "orientation": "Orientation",
'overlay': 'Overlay', "overlay": "Overlay",
'photo_size': 'Photo Size', "photo_size": "Photo Size",
'pressure': 'Pressure', "pressure": "Pressure",
'proximity': 'Proximity', "proximity": "Proximity",
'quality': 'Quality', "quality": "Quality",
'scenemode': 'Scene Mode', "scenemode": "Scene Mode",
'sound': 'Sound', "sound": "Sound",
'sound_event': 'Sound Event', "sound_event": "Sound Event",
'sound_timeout': 'Sound Timeout', "sound_timeout": "Sound Timeout",
'torch': 'Torch', "torch": "Torch",
'video_connections': 'Video Connections', "video_connections": "Video Connections",
'video_chunk_len': 'Video Chunk Length', "video_chunk_len": "Video Chunk Length",
'video_recording': 'Video Recording', "video_recording": "Video Recording",
'video_size': 'Video Size', "video_size": "Video Size",
'whitebalance': 'White Balance', "whitebalance": "White Balance",
'whitebalance_lock': 'White Balance Lock', "whitebalance_lock": "White Balance Lock",
'zoom': 'Zoom' "zoom": "Zoom",
} }
ICON_MAP = { ICON_MAP = {
'audio_connections': 'mdi:speaker', "audio_connections": "mdi:speaker",
'battery_level': 'mdi:battery', "battery_level": "mdi:battery",
'battery_temp': 'mdi:thermometer', "battery_temp": "mdi:thermometer",
'battery_voltage': 'mdi:battery-charging-100', "battery_voltage": "mdi:battery-charging-100",
'exposure_lock': 'mdi:camera', "exposure_lock": "mdi:camera",
'ffc': 'mdi:camera-front-variant', "ffc": "mdi:camera-front-variant",
'focus': 'mdi:image-filter-center-focus', "focus": "mdi:image-filter-center-focus",
'gps_active': 'mdi:crosshairs-gps', "gps_active": "mdi:crosshairs-gps",
'light': 'mdi:flashlight', "light": "mdi:flashlight",
'motion': 'mdi:run', "motion": "mdi:run",
'night_vision': 'mdi:weather-night', "night_vision": "mdi:weather-night",
'overlay': 'mdi:monitor', "overlay": "mdi:monitor",
'pressure': 'mdi:gauge', "pressure": "mdi:gauge",
'proximity': 'mdi:map-marker-radius', "proximity": "mdi:map-marker-radius",
'quality': 'mdi:quality-high', "quality": "mdi:quality-high",
'sound': 'mdi:speaker', "sound": "mdi:speaker",
'sound_event': 'mdi:speaker', "sound_event": "mdi:speaker",
'sound_timeout': 'mdi:speaker', "sound_timeout": "mdi:speaker",
'torch': 'mdi:white-balance-sunny', "torch": "mdi:white-balance-sunny",
'video_chunk_len': 'mdi:video', "video_chunk_len": "mdi:video",
'video_connections': 'mdi:eye', "video_connections": "mdi:eye",
'video_recording': 'mdi:record-rec', "video_recording": "mdi:record-rec",
'whitebalance_lock': 'mdi:white-balance-auto' "whitebalance_lock": "mdi:white-balance-auto",
} }
SWITCHES = ['exposure_lock', 'ffc', 'focus', 'gps_active', SWITCHES = [
'motion_detect', 'night_vision', 'overlay', "exposure_lock",
'torch', 'whitebalance_lock', 'video_recording'] "ffc",
"focus",
"gps_active",
"motion_detect",
"night_vision",
"overlay",
"torch",
"whitebalance_lock",
"video_recording",
]
SENSORS = ['audio_connections', 'battery_level', 'battery_temp', SENSORS = [
'battery_voltage', 'light', 'motion', 'pressure', 'proximity', "audio_connections",
'sound', 'video_connections'] "battery_level",
"battery_temp",
"battery_voltage",
"light",
"motion",
"pressure",
"proximity",
"sound",
"video_connections",
]
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema(
DOMAIN: vol.All(cv.ensure_list, [vol.Schema({ {
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, DOMAIN: vol.All(
vol.Required(CONF_HOST): cv.string, cv.ensure_list,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port, [
vol.Optional(CONF_TIMEOUT, default=DEFAULT_TIMEOUT): cv.positive_int, vol.Schema(
vol.Optional(CONF_SCAN_INTERVAL, default=SCAN_INTERVAL): {
cv.time_period, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Inclusive(CONF_USERNAME, 'authentication'): cv.string, vol.Required(CONF_HOST): cv.string,
vol.Inclusive(CONF_PASSWORD, 'authentication'): cv.string, vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.Optional(CONF_SWITCHES): vol.Optional(
vol.All(cv.ensure_list, [vol.In(SWITCHES)]), CONF_TIMEOUT, default=DEFAULT_TIMEOUT
vol.Optional(CONF_SENSORS): ): cv.positive_int,
vol.All(cv.ensure_list, [vol.In(SENSORS)]), vol.Optional(
vol.Optional(CONF_MOTION_SENSOR): cv.boolean, CONF_SCAN_INTERVAL, default=SCAN_INTERVAL
})]) ): cv.time_period,
}, extra=vol.ALLOW_EXTRA) vol.Inclusive(CONF_USERNAME, "authentication"): cv.string,
vol.Inclusive(CONF_PASSWORD, "authentication"): cv.string,
vol.Optional(CONF_SWITCHES): vol.All(
cv.ensure_list, [vol.In(SWITCHES)]
),
vol.Optional(CONF_SENSORS): vol.All(
cv.ensure_list, [vol.In(SENSORS)]
),
vol.Optional(CONF_MOTION_SENSOR): cv.boolean,
}
)
],
)
},
extra=vol.ALLOW_EXTRA,
)
async def async_setup(hass, config): async def async_setup(hass, config):
@ -163,30 +205,33 @@ async def async_setup(hass, config):
# Init ip webcam # Init ip webcam
cam = PyDroidIPCam( cam = PyDroidIPCam(
hass.loop, websession, host, cam_config[CONF_PORT], hass.loop,
username=username, password=password, websession,
timeout=cam_config[CONF_TIMEOUT] host,
cam_config[CONF_PORT],
username=username,
password=password,
timeout=cam_config[CONF_TIMEOUT],
) )
if switches is None: if switches is None:
switches = [setting for setting in cam.enabled_settings switches = [
if setting in SWITCHES] setting for setting in cam.enabled_settings if setting in SWITCHES
]
if sensors is None: if sensors is None:
sensors = [sensor for sensor in cam.enabled_sensors sensors = [sensor for sensor in cam.enabled_sensors if sensor in SENSORS]
if sensor in SENSORS] sensors.extend(["audio_connections", "video_connections"])
sensors.extend(['audio_connections', 'video_connections'])
if motion is None: if motion is None:
motion = 'motion_active' in cam.enabled_sensors motion = "motion_active" in cam.enabled_sensors
async def async_update_data(now): async def async_update_data(now):
"""Update data from IP camera in SCAN_INTERVAL.""" """Update data from IP camera in SCAN_INTERVAL."""
await cam.update() await cam.update()
async_dispatcher_send(hass, SIGNAL_UPDATE_DATA, host) async_dispatcher_send(hass, SIGNAL_UPDATE_DATA, host)
async_track_point_in_utc_time( async_track_point_in_utc_time(hass, async_update_data, utcnow() + interval)
hass, async_update_data, utcnow() + interval)
await async_update_data(None) await async_update_data(None)
@ -194,42 +239,50 @@ async def async_setup(hass, config):
webcams[host] = cam webcams[host] = cam
mjpeg_camera = { mjpeg_camera = {
CONF_PLATFORM: 'mjpeg', CONF_PLATFORM: "mjpeg",
CONF_MJPEG_URL: cam.mjpeg_url, CONF_MJPEG_URL: cam.mjpeg_url,
CONF_STILL_IMAGE_URL: cam.image_url, CONF_STILL_IMAGE_URL: cam.image_url,
CONF_NAME: name, CONF_NAME: name,
} }
if username and password: if username and password:
mjpeg_camera.update({ mjpeg_camera.update({CONF_USERNAME: username, CONF_PASSWORD: password})
CONF_USERNAME: username,
CONF_PASSWORD: password
})
hass.async_create_task(discovery.async_load_platform( hass.async_create_task(
hass, 'camera', 'mjpeg', mjpeg_camera, config)) discovery.async_load_platform(hass, "camera", "mjpeg", mjpeg_camera, config)
)
if sensors: if sensors:
hass.async_create_task(discovery.async_load_platform( hass.async_create_task(
hass, 'sensor', DOMAIN, { discovery.async_load_platform(
CONF_NAME: name, hass,
CONF_HOST: host, "sensor",
CONF_SENSORS: sensors, DOMAIN,
}, config)) {CONF_NAME: name, CONF_HOST: host, CONF_SENSORS: sensors},
config,
)
)
if switches: if switches:
hass.async_create_task(discovery.async_load_platform( hass.async_create_task(
hass, 'switch', DOMAIN, { discovery.async_load_platform(
CONF_NAME: name, hass,
CONF_HOST: host, "switch",
CONF_SWITCHES: switches, DOMAIN,
}, config)) {CONF_NAME: name, CONF_HOST: host, CONF_SWITCHES: switches},
config,
)
)
if motion: if motion:
hass.async_create_task(discovery.async_load_platform( hass.async_create_task(
hass, 'binary_sensor', DOMAIN, { discovery.async_load_platform(
CONF_HOST: host, hass,
CONF_NAME: name, "binary_sensor",
}, config)) DOMAIN,
{CONF_HOST: host, CONF_NAME: name},
config,
)
)
tasks = [async_setup_ipcamera(conf) for conf in config[DOMAIN]] tasks = [async_setup_ipcamera(conf) for conf in config[DOMAIN]]
if tasks: if tasks:
@ -248,6 +301,7 @@ class AndroidIPCamEntity(Entity):
async def async_added_to_hass(self): async def async_added_to_hass(self):
"""Register update dispatcher.""" """Register update dispatcher."""
@callback @callback
def async_ipcam_update(host): def async_ipcam_update(host):
"""Update callback.""" """Update callback."""
@ -255,8 +309,7 @@ class AndroidIPCamEntity(Entity):
return return
self.async_schedule_update_ha_state(True) self.async_schedule_update_ha_state(True)
async_dispatcher_connect( async_dispatcher_connect(self.hass, SIGNAL_UPDATE_DATA, async_ipcam_update)
self.hass, SIGNAL_UPDATE_DATA, async_ipcam_update)
@property @property
def should_poll(self): def should_poll(self):
@ -275,9 +328,7 @@ class AndroidIPCamEntity(Entity):
if self._ipcam.status_data is None: if self._ipcam.status_data is None:
return state_attr return state_attr
state_attr[ATTR_VID_CONNS] = \ state_attr[ATTR_VID_CONNS] = self._ipcam.status_data.get("video_connections")
self._ipcam.status_data.get('video_connections') state_attr[ATTR_AUD_CONNS] = self._ipcam.status_data.get("audio_connections")
state_attr[ATTR_AUD_CONNS] = \
self._ipcam.status_data.get('audio_connections')
return state_attr return state_attr

View file

@ -4,8 +4,7 @@ from homeassistant.components.binary_sensor import BinarySensorDevice
from . import CONF_HOST, CONF_NAME, DATA_IP_WEBCAM, KEY_MAP, AndroidIPCamEntity from . import CONF_HOST, CONF_NAME, DATA_IP_WEBCAM, KEY_MAP, AndroidIPCamEntity
async def async_setup_platform( async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
hass, config, async_add_entities, discovery_info=None):
"""Set up the IP Webcam binary sensors.""" """Set up the IP Webcam binary sensors."""
if discovery_info is None: if discovery_info is None:
return return
@ -14,8 +13,7 @@ async def async_setup_platform(
name = discovery_info[CONF_NAME] name = discovery_info[CONF_NAME]
ipcam = hass.data[DATA_IP_WEBCAM][host] ipcam = hass.data[DATA_IP_WEBCAM][host]
async_add_entities( async_add_entities([IPWebcamBinarySensor(name, host, ipcam, "motion_active")], True)
[IPWebcamBinarySensor(name, host, ipcam, 'motion_active')], True)
class IPWebcamBinarySensor(AndroidIPCamEntity, BinarySensorDevice): class IPWebcamBinarySensor(AndroidIPCamEntity, BinarySensorDevice):
@ -27,7 +25,7 @@ class IPWebcamBinarySensor(AndroidIPCamEntity, BinarySensorDevice):
self._sensor = sensor self._sensor = sensor
self._mapped_name = KEY_MAP.get(self._sensor, self._sensor) self._mapped_name = KEY_MAP.get(self._sensor, self._sensor)
self._name = '{} {}'.format(name, self._mapped_name) self._name = "{} {}".format(name, self._mapped_name)
self._state = None self._state = None
self._unit = None self._unit = None
@ -49,4 +47,4 @@ class IPWebcamBinarySensor(AndroidIPCamEntity, BinarySensorDevice):
@property @property
def device_class(self): def device_class(self):
"""Return the class of this device, from component DEVICE_CLASSES.""" """Return the class of this device, from component DEVICE_CLASSES."""
return 'motion' return "motion"

View file

@ -2,12 +2,17 @@
from homeassistant.helpers.icon import icon_for_battery_level from homeassistant.helpers.icon import icon_for_battery_level
from . import ( from . import (
CONF_HOST, CONF_NAME, CONF_SENSORS, DATA_IP_WEBCAM, ICON_MAP, KEY_MAP, CONF_HOST,
AndroidIPCamEntity) CONF_NAME,
CONF_SENSORS,
DATA_IP_WEBCAM,
ICON_MAP,
KEY_MAP,
AndroidIPCamEntity,
)
async def async_setup_platform( async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
hass, config, async_add_entities, discovery_info=None):
"""Set up the IP Webcam Sensor.""" """Set up the IP Webcam Sensor."""
if discovery_info is None: if discovery_info is None:
return return
@ -34,7 +39,7 @@ class IPWebcamSensor(AndroidIPCamEntity):
self._sensor = sensor self._sensor = sensor
self._mapped_name = KEY_MAP.get(self._sensor, self._sensor) self._mapped_name = KEY_MAP.get(self._sensor, self._sensor)
self._name = '{} {}'.format(name, self._mapped_name) self._name = "{} {}".format(name, self._mapped_name)
self._state = None self._state = None
self._unit = None self._unit = None
@ -55,17 +60,17 @@ class IPWebcamSensor(AndroidIPCamEntity):
async def async_update(self): async def async_update(self):
"""Retrieve latest state.""" """Retrieve latest state."""
if self._sensor in ('audio_connections', 'video_connections'): if self._sensor in ("audio_connections", "video_connections"):
if not self._ipcam.status_data: if not self._ipcam.status_data:
return return
self._state = self._ipcam.status_data.get(self._sensor) self._state = self._ipcam.status_data.get(self._sensor)
self._unit = 'Connections' self._unit = "Connections"
else: else:
self._state, self._unit = self._ipcam.export_sensor(self._sensor) self._state, self._unit = self._ipcam.export_sensor(self._sensor)
@property @property
def icon(self): def icon(self):
"""Return the icon for the sensor.""" """Return the icon for the sensor."""
if self._sensor == 'battery_level' and self._state is not None: if self._sensor == "battery_level" and self._state is not None:
return icon_for_battery_level(int(self._state)) return icon_for_battery_level(int(self._state))
return ICON_MAP.get(self._sensor, 'mdi:eye') return ICON_MAP.get(self._sensor, "mdi:eye")

View file

@ -2,12 +2,17 @@
from homeassistant.components.switch import SwitchDevice from homeassistant.components.switch import SwitchDevice
from . import ( from . import (
CONF_HOST, CONF_NAME, CONF_SWITCHES, DATA_IP_WEBCAM, ICON_MAP, KEY_MAP, CONF_HOST,
AndroidIPCamEntity) CONF_NAME,
CONF_SWITCHES,
DATA_IP_WEBCAM,
ICON_MAP,
KEY_MAP,
AndroidIPCamEntity,
)
async def async_setup_platform( async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
hass, config, async_add_entities, discovery_info=None):
"""Set up the IP Webcam switch platform.""" """Set up the IP Webcam switch platform."""
if discovery_info is None: if discovery_info is None:
return return
@ -34,7 +39,7 @@ class IPWebcamSettingsSwitch(AndroidIPCamEntity, SwitchDevice):
self._setting = setting self._setting = setting
self._mapped_name = KEY_MAP.get(self._setting, self._setting) self._mapped_name = KEY_MAP.get(self._setting, self._setting)
self._name = '{} {}'.format(name, self._mapped_name) self._name = "{} {}".format(name, self._mapped_name)
self._state = False self._state = False
@property @property
@ -53,11 +58,11 @@ class IPWebcamSettingsSwitch(AndroidIPCamEntity, SwitchDevice):
async def async_turn_on(self, **kwargs): async def async_turn_on(self, **kwargs):
"""Turn device on.""" """Turn device on."""
if self._setting == 'torch': if self._setting == "torch":
await self._ipcam.torch(activate=True) await self._ipcam.torch(activate=True)
elif self._setting == 'focus': elif self._setting == "focus":
await self._ipcam.focus(activate=True) await self._ipcam.focus(activate=True)
elif self._setting == 'video_recording': elif self._setting == "video_recording":
await self._ipcam.record(record=True) await self._ipcam.record(record=True)
else: else:
await self._ipcam.change_setting(self._setting, True) await self._ipcam.change_setting(self._setting, True)
@ -66,11 +71,11 @@ class IPWebcamSettingsSwitch(AndroidIPCamEntity, SwitchDevice):
async def async_turn_off(self, **kwargs): async def async_turn_off(self, **kwargs):
"""Turn device off.""" """Turn device off."""
if self._setting == 'torch': if self._setting == "torch":
await self._ipcam.torch(activate=False) await self._ipcam.torch(activate=False)
elif self._setting == 'focus': elif self._setting == "focus":
await self._ipcam.focus(activate=False) await self._ipcam.focus(activate=False)
elif self._setting == 'video_recording': elif self._setting == "video_recording":
await self._ipcam.record(record=False) await self._ipcam.record(record=False)
else: else:
await self._ipcam.change_setting(self._setting, False) await self._ipcam.change_setting(self._setting, False)
@ -80,4 +85,4 @@ class IPWebcamSettingsSwitch(AndroidIPCamEntity, SwitchDevice):
@property @property
def icon(self): def icon(self):
"""Return the icon for the switch.""" """Return the icon for the switch."""
return ICON_MAP.get(self._setting, 'mdi:flash') return ICON_MAP.get(self._setting, "mdi:flash")

View file

@ -3,81 +3,113 @@ import functools
import logging import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components.media_player import ( from homeassistant.components.media_player import MediaPlayerDevice, PLATFORM_SCHEMA
MediaPlayerDevice, PLATFORM_SCHEMA)
from homeassistant.components.media_player.const import ( from homeassistant.components.media_player.const import (
SUPPORT_NEXT_TRACK, SUPPORT_PAUSE, SUPPORT_PLAY, SUPPORT_PREVIOUS_TRACK, SUPPORT_NEXT_TRACK,
SUPPORT_SELECT_SOURCE, SUPPORT_STOP, SUPPORT_TURN_OFF, SUPPORT_TURN_ON, SUPPORT_PAUSE,
SUPPORT_VOLUME_MUTE, SUPPORT_VOLUME_STEP) SUPPORT_PLAY,
SUPPORT_PREVIOUS_TRACK,
SUPPORT_SELECT_SOURCE,
SUPPORT_STOP,
SUPPORT_TURN_OFF,
SUPPORT_TURN_ON,
SUPPORT_VOLUME_MUTE,
SUPPORT_VOLUME_STEP,
)
from homeassistant.const import ( from homeassistant.const import (
ATTR_COMMAND, ATTR_ENTITY_ID, CONF_DEVICE_CLASS, CONF_HOST, CONF_NAME, ATTR_COMMAND,
CONF_PORT, STATE_IDLE, STATE_OFF, STATE_PAUSED, STATE_PLAYING, ATTR_ENTITY_ID,
STATE_STANDBY) CONF_DEVICE_CLASS,
CONF_HOST,
CONF_NAME,
CONF_PORT,
STATE_IDLE,
STATE_OFF,
STATE_PAUSED,
STATE_PLAYING,
STATE_STANDBY,
)
from homeassistant.exceptions import PlatformNotReady from homeassistant.exceptions import PlatformNotReady
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
ANDROIDTV_DOMAIN = 'androidtv' ANDROIDTV_DOMAIN = "androidtv"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SUPPORT_ANDROIDTV = SUPPORT_PAUSE | SUPPORT_PLAY | \ SUPPORT_ANDROIDTV = (
SUPPORT_TURN_ON | SUPPORT_TURN_OFF | SUPPORT_PREVIOUS_TRACK | \ SUPPORT_PAUSE
SUPPORT_NEXT_TRACK | SUPPORT_STOP | SUPPORT_VOLUME_MUTE | \ | SUPPORT_PLAY
SUPPORT_VOLUME_STEP | SUPPORT_TURN_ON
| SUPPORT_TURN_OFF
| SUPPORT_PREVIOUS_TRACK
| SUPPORT_NEXT_TRACK
| SUPPORT_STOP
| SUPPORT_VOLUME_MUTE
| SUPPORT_VOLUME_STEP
)
SUPPORT_FIRETV = SUPPORT_PAUSE | SUPPORT_PLAY | \ SUPPORT_FIRETV = (
SUPPORT_TURN_ON | SUPPORT_TURN_OFF | SUPPORT_PREVIOUS_TRACK | \ SUPPORT_PAUSE
SUPPORT_NEXT_TRACK | SUPPORT_SELECT_SOURCE | SUPPORT_STOP | SUPPORT_PLAY
| SUPPORT_TURN_ON
| SUPPORT_TURN_OFF
| SUPPORT_PREVIOUS_TRACK
| SUPPORT_NEXT_TRACK
| SUPPORT_SELECT_SOURCE
| SUPPORT_STOP
)
CONF_ADBKEY = 'adbkey' CONF_ADBKEY = "adbkey"
CONF_ADB_SERVER_IP = 'adb_server_ip' CONF_ADB_SERVER_IP = "adb_server_ip"
CONF_ADB_SERVER_PORT = 'adb_server_port' CONF_ADB_SERVER_PORT = "adb_server_port"
CONF_APPS = 'apps' CONF_APPS = "apps"
CONF_GET_SOURCES = 'get_sources' CONF_GET_SOURCES = "get_sources"
CONF_TURN_ON_COMMAND = 'turn_on_command' CONF_TURN_ON_COMMAND = "turn_on_command"
CONF_TURN_OFF_COMMAND = 'turn_off_command' CONF_TURN_OFF_COMMAND = "turn_off_command"
DEFAULT_NAME = 'Android TV' DEFAULT_NAME = "Android TV"
DEFAULT_PORT = 5555 DEFAULT_PORT = 5555
DEFAULT_ADB_SERVER_PORT = 5037 DEFAULT_ADB_SERVER_PORT = 5037
DEFAULT_GET_SOURCES = True DEFAULT_GET_SOURCES = True
DEFAULT_DEVICE_CLASS = 'auto' DEFAULT_DEVICE_CLASS = "auto"
DEVICE_ANDROIDTV = 'androidtv' DEVICE_ANDROIDTV = "androidtv"
DEVICE_FIRETV = 'firetv' DEVICE_FIRETV = "firetv"
DEVICE_CLASSES = [DEFAULT_DEVICE_CLASS, DEVICE_ANDROIDTV, DEVICE_FIRETV] DEVICE_CLASSES = [DEFAULT_DEVICE_CLASS, DEVICE_ANDROIDTV, DEVICE_FIRETV]
SERVICE_ADB_COMMAND = 'adb_command' SERVICE_ADB_COMMAND = "adb_command"
SERVICE_ADB_COMMAND_SCHEMA = vol.Schema({ SERVICE_ADB_COMMAND_SCHEMA = vol.Schema(
vol.Required(ATTR_ENTITY_ID): cv.entity_ids, {vol.Required(ATTR_ENTITY_ID): cv.entity_ids, vol.Required(ATTR_COMMAND): cv.string}
vol.Required(ATTR_COMMAND): cv.string, )
})
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_HOST): cv.string, {
vol.Optional(CONF_DEVICE_CLASS, default=DEFAULT_DEVICE_CLASS): vol.Required(CONF_HOST): cv.string,
vol.In(DEVICE_CLASSES), vol.Optional(CONF_DEVICE_CLASS, default=DEFAULT_DEVICE_CLASS): vol.In(
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, DEVICE_CLASSES
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port, ),
vol.Optional(CONF_ADBKEY): cv.isfile, vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
vol.Optional(CONF_ADB_SERVER_IP): cv.string, vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
vol.Optional(CONF_ADB_SERVER_PORT, default=DEFAULT_ADB_SERVER_PORT): vol.Optional(CONF_ADBKEY): cv.isfile,
cv.port, vol.Optional(CONF_ADB_SERVER_IP): cv.string,
vol.Optional(CONF_GET_SOURCES, default=DEFAULT_GET_SOURCES): cv.boolean, vol.Optional(CONF_ADB_SERVER_PORT, default=DEFAULT_ADB_SERVER_PORT): cv.port,
vol.Optional(CONF_APPS, default=dict()): vol.Optional(CONF_GET_SOURCES, default=DEFAULT_GET_SOURCES): cv.boolean,
vol.Schema({cv.string: cv.string}), vol.Optional(CONF_APPS, default=dict()): vol.Schema({cv.string: cv.string}),
vol.Optional(CONF_TURN_ON_COMMAND): cv.string, vol.Optional(CONF_TURN_ON_COMMAND): cv.string,
vol.Optional(CONF_TURN_OFF_COMMAND): cv.string vol.Optional(CONF_TURN_OFF_COMMAND): cv.string,
}) }
)
# Translate from `AndroidTV` / `FireTV` reported state to HA state. # Translate from `AndroidTV` / `FireTV` reported state to HA state.
ANDROIDTV_STATES = {'off': STATE_OFF, ANDROIDTV_STATES = {
'idle': STATE_IDLE, "off": STATE_OFF,
'standby': STATE_STANDBY, "idle": STATE_IDLE,
'playing': STATE_PLAYING, "standby": STATE_STANDBY,
'paused': STATE_PAUSED} "playing": STATE_PLAYING,
"paused": STATE_PAUSED,
}
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -86,14 +118,15 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
hass.data.setdefault(ANDROIDTV_DOMAIN, {}) hass.data.setdefault(ANDROIDTV_DOMAIN, {})
host = '{0}:{1}'.format(config[CONF_HOST], config[CONF_PORT]) host = "{0}:{1}".format(config[CONF_HOST], config[CONF_PORT])
if CONF_ADB_SERVER_IP not in config: if CONF_ADB_SERVER_IP not in config:
# Use "python-adb" (Python ADB implementation) # Use "python-adb" (Python ADB implementation)
adb_log = "using Python ADB implementation " adb_log = "using Python ADB implementation "
if CONF_ADBKEY in config: if CONF_ADBKEY in config:
aftv = setup(host, config[CONF_ADBKEY], aftv = setup(
device_class=config[CONF_DEVICE_CLASS]) host, config[CONF_ADBKEY], device_class=config[CONF_DEVICE_CLASS]
)
adb_log += "with adbkey='{0}'".format(config[CONF_ADBKEY]) adb_log += "with adbkey='{0}'".format(config[CONF_ADBKEY])
else: else:
@ -101,44 +134,52 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
adb_log += "without adbkey authentication" adb_log += "without adbkey authentication"
else: else:
# Use "pure-python-adb" (communicate with ADB server) # Use "pure-python-adb" (communicate with ADB server)
aftv = setup(host, adb_server_ip=config[CONF_ADB_SERVER_IP], aftv = setup(
adb_server_port=config[CONF_ADB_SERVER_PORT], host,
device_class=config[CONF_DEVICE_CLASS]) adb_server_ip=config[CONF_ADB_SERVER_IP],
adb_server_port=config[CONF_ADB_SERVER_PORT],
device_class=config[CONF_DEVICE_CLASS],
)
adb_log = "using ADB server at {0}:{1}".format( adb_log = "using ADB server at {0}:{1}".format(
config[CONF_ADB_SERVER_IP], config[CONF_ADB_SERVER_PORT]) config[CONF_ADB_SERVER_IP], config[CONF_ADB_SERVER_PORT]
)
if not aftv.available: if not aftv.available:
# Determine the name that will be used for the device in the log # Determine the name that will be used for the device in the log
if CONF_NAME in config: if CONF_NAME in config:
device_name = config[CONF_NAME] device_name = config[CONF_NAME]
elif config[CONF_DEVICE_CLASS] == DEVICE_ANDROIDTV: elif config[CONF_DEVICE_CLASS] == DEVICE_ANDROIDTV:
device_name = 'Android TV device' device_name = "Android TV device"
elif config[CONF_DEVICE_CLASS] == DEVICE_FIRETV: elif config[CONF_DEVICE_CLASS] == DEVICE_FIRETV:
device_name = 'Fire TV device' device_name = "Fire TV device"
else: else:
device_name = 'Android TV / Fire TV device' device_name = "Android TV / Fire TV device"
_LOGGER.warning("Could not connect to %s at %s %s", _LOGGER.warning("Could not connect to %s at %s %s", device_name, host, adb_log)
device_name, host, adb_log)
raise PlatformNotReady raise PlatformNotReady
if host in hass.data[ANDROIDTV_DOMAIN]: if host in hass.data[ANDROIDTV_DOMAIN]:
_LOGGER.warning("Platform already setup on %s, skipping", host) _LOGGER.warning("Platform already setup on %s, skipping", host)
else: else:
if aftv.DEVICE_CLASS == DEVICE_ANDROIDTV: if aftv.DEVICE_CLASS == DEVICE_ANDROIDTV:
device = AndroidTVDevice(aftv, config[CONF_NAME], device = AndroidTVDevice(
config[CONF_APPS], aftv,
config.get(CONF_TURN_ON_COMMAND), config[CONF_NAME],
config.get(CONF_TURN_OFF_COMMAND)) config[CONF_APPS],
device_name = config[CONF_NAME] if CONF_NAME in config \ config.get(CONF_TURN_ON_COMMAND),
else 'Android TV' config.get(CONF_TURN_OFF_COMMAND),
)
device_name = config[CONF_NAME] if CONF_NAME in config else "Android TV"
else: else:
device = FireTVDevice(aftv, config[CONF_NAME], config[CONF_APPS], device = FireTVDevice(
config[CONF_GET_SOURCES], aftv,
config.get(CONF_TURN_ON_COMMAND), config[CONF_NAME],
config.get(CONF_TURN_OFF_COMMAND)) config[CONF_APPS],
device_name = config[CONF_NAME] if CONF_NAME in config \ config[CONF_GET_SOURCES],
else 'Fire TV' config.get(CONF_TURN_ON_COMMAND),
config.get(CONF_TURN_OFF_COMMAND),
)
device_name = config[CONF_NAME] if CONF_NAME in config else "Fire TV"
add_entities([device]) add_entities([device])
_LOGGER.debug("Setup %s at %s%s", device_name, host, adb_log) _LOGGER.debug("Setup %s at %s%s", device_name, host, adb_log)
@ -151,26 +192,38 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
"""Dispatch service calls to target entities.""" """Dispatch service calls to target entities."""
cmd = service.data.get(ATTR_COMMAND) cmd = service.data.get(ATTR_COMMAND)
entity_id = service.data.get(ATTR_ENTITY_ID) entity_id = service.data.get(ATTR_ENTITY_ID)
target_devices = [dev for dev in hass.data[ANDROIDTV_DOMAIN].values() target_devices = [
if dev.entity_id in entity_id] dev
for dev in hass.data[ANDROIDTV_DOMAIN].values()
if dev.entity_id in entity_id
]
for target_device in target_devices: for target_device in target_devices:
output = target_device.adb_command(cmd) output = target_device.adb_command(cmd)
# log the output, if there is any # log the output, if there is any
if output: if output:
_LOGGER.info("Output of command '%s' from '%s': %s", _LOGGER.info(
cmd, target_device.entity_id, output) "Output of command '%s' from '%s': %s",
cmd,
target_device.entity_id,
output,
)
hass.services.register(ANDROIDTV_DOMAIN, SERVICE_ADB_COMMAND, hass.services.register(
service_adb_command, ANDROIDTV_DOMAIN,
schema=SERVICE_ADB_COMMAND_SCHEMA) SERVICE_ADB_COMMAND,
service_adb_command,
schema=SERVICE_ADB_COMMAND_SCHEMA,
)
def adb_decorator(override_available=False): def adb_decorator(override_available=False):
"""Send an ADB command if the device is available and catch exceptions.""" """Send an ADB command if the device is available and catch exceptions."""
def _adb_decorator(func): def _adb_decorator(func):
"""Wait if previous ADB commands haven't finished.""" """Wait if previous ADB commands haven't finished."""
@functools.wraps(func) @functools.wraps(func)
def _adb_exception_catcher(self, *args, **kwargs): def _adb_exception_catcher(self, *args, **kwargs):
# If the device is unavailable, don't do anything # If the device is unavailable, don't do anything
@ -182,7 +235,9 @@ def adb_decorator(override_available=False):
except self.exceptions as err: except self.exceptions as err:
_LOGGER.error( _LOGGER.error(
"Failed to execute an ADB command. ADB connection re-" "Failed to execute an ADB command. ADB connection re-"
"establishing attempt in the next update. Error: %s", err) "establishing attempt in the next update. Error: %s",
err,
)
self._available = False # pylint: disable=protected-access self._available = False # pylint: disable=protected-access
return None return None
@ -194,8 +249,7 @@ def adb_decorator(override_available=False):
class ADBDevice(MediaPlayerDevice): class ADBDevice(MediaPlayerDevice):
"""Representation of an Android TV or Fire TV device.""" """Representation of an Android TV or Fire TV device."""
def __init__(self, aftv, name, apps, turn_on_command, def __init__(self, aftv, name, apps, turn_on_command, turn_off_command):
turn_off_command):
"""Initialize the Android TV / Fire TV device.""" """Initialize the Android TV / Fire TV device."""
from androidtv.constants import APPS, KEYS from androidtv.constants import APPS, KEYS
@ -211,15 +265,23 @@ class ADBDevice(MediaPlayerDevice):
# ADB exceptions to catch # ADB exceptions to catch
if not self.aftv.adb_server_ip: if not self.aftv.adb_server_ip:
# Using "python-adb" (Python ADB implementation) # Using "python-adb" (Python ADB implementation)
from adb.adb_protocol import (InvalidChecksumError, from adb.adb_protocol import (
InvalidCommandError, InvalidChecksumError,
InvalidResponseError) InvalidCommandError,
InvalidResponseError,
)
from adb.usb_exceptions import TcpTimeoutException from adb.usb_exceptions import TcpTimeoutException
self.exceptions = (AttributeError, BrokenPipeError, TypeError, self.exceptions = (
ValueError, InvalidChecksumError, AttributeError,
InvalidCommandError, InvalidResponseError, BrokenPipeError,
TcpTimeoutException) TypeError,
ValueError,
InvalidChecksumError,
InvalidCommandError,
InvalidResponseError,
TcpTimeoutException,
)
else: else:
# Using "pure-python-adb" (communicate with ADB server) # Using "pure-python-adb" (communicate with ADB server)
self.exceptions = (ConnectionResetError, RuntimeError) self.exceptions = (ConnectionResetError, RuntimeError)
@ -248,7 +310,7 @@ class ADBDevice(MediaPlayerDevice):
@property @property
def device_state_attributes(self): def device_state_attributes(self):
"""Provide the last ADB command's response as an attribute.""" """Provide the last ADB command's response as an attribute."""
return {'adb_response': self._adb_response} return {"adb_response": self._adb_response}
@property @property
def name(self): def name(self):
@ -311,12 +373,12 @@ class ADBDevice(MediaPlayerDevice):
"""Send an ADB command to an Android TV / Fire TV device.""" """Send an ADB command to an Android TV / Fire TV device."""
key = self._keys.get(cmd) key = self._keys.get(cmd)
if key: if key:
self.aftv.adb_shell('input keyevent {}'.format(key)) self.aftv.adb_shell("input keyevent {}".format(key))
self._adb_response = None self._adb_response = None
self.schedule_update_ha_state() self.schedule_update_ha_state()
return return
if cmd == 'GET_PROPERTIES': if cmd == "GET_PROPERTIES":
self._adb_response = str(self.aftv.get_properties_dict()) self._adb_response = str(self.aftv.get_properties_dict())
self.schedule_update_ha_state() self.schedule_update_ha_state()
return self._adb_response return self._adb_response
@ -334,16 +396,14 @@ class ADBDevice(MediaPlayerDevice):
class AndroidTVDevice(ADBDevice): class AndroidTVDevice(ADBDevice):
"""Representation of an Android TV device.""" """Representation of an Android TV device."""
def __init__(self, aftv, name, apps, turn_on_command, def __init__(self, aftv, name, apps, turn_on_command, turn_off_command):
turn_off_command):
"""Initialize the Android TV device.""" """Initialize the Android TV device."""
super().__init__(aftv, name, apps, turn_on_command, super().__init__(aftv, name, apps, turn_on_command, turn_off_command)
turn_off_command)
self._device = None self._device = None
self._device_properties = self.aftv.device_properties self._device_properties = self.aftv.device_properties
self._is_volume_muted = None self._is_volume_muted = None
self._unique_id = self._device_properties.get('serialno') self._unique_id = self._device_properties.get("serialno")
self._volume_level = None self._volume_level = None
@adb_decorator(override_available=True) @adb_decorator(override_available=True)
@ -362,8 +422,9 @@ class AndroidTVDevice(ADBDevice):
return return
# Get the updated state and attributes. # Get the updated state and attributes.
state, self._current_app, self._device, self._is_volume_muted, \ state, self._current_app, self._device, self._is_volume_muted, self._volume_level = (
self._volume_level = self.aftv.update() self.aftv.update()
)
self._state = ANDROIDTV_STATES[state] self._state = ANDROIDTV_STATES[state]
@ -416,11 +477,11 @@ class AndroidTVDevice(ADBDevice):
class FireTVDevice(ADBDevice): class FireTVDevice(ADBDevice):
"""Representation of a Fire TV device.""" """Representation of a Fire TV device."""
def __init__(self, aftv, name, apps, get_sources, def __init__(
turn_on_command, turn_off_command): self, aftv, name, apps, get_sources, turn_on_command, turn_off_command
):
"""Initialize the Fire TV device.""" """Initialize the Fire TV device."""
super().__init__(aftv, name, apps, turn_on_command, super().__init__(aftv, name, apps, turn_on_command, turn_off_command)
turn_off_command)
self._get_sources = get_sources self._get_sources = get_sources
self._running_apps = None self._running_apps = None
@ -441,8 +502,9 @@ class FireTVDevice(ADBDevice):
return return
# Get the `state`, `current_app`, and `running_apps`. # Get the `state`, `current_app`, and `running_apps`.
state, self._current_app, self._running_apps = \ state, self._current_app, self._running_apps = self.aftv.update(
self.aftv.update(self._get_sources) self._get_sources
)
self._state = ANDROIDTV_STATES[state] self._state = ANDROIDTV_STATES[state]
@ -474,7 +536,7 @@ class FireTVDevice(ADBDevice):
opening it. opening it.
""" """
if isinstance(source, str): if isinstance(source, str):
if not source.startswith('!'): if not source.startswith("!"):
self.aftv.launch_app(source) self.aftv.launch_app(source)
else: else:
self.aftv.stop_app(source[1:].lstrip()) self.aftv.stop_app(source[1:].lstrip())

View file

@ -6,24 +6,26 @@ from datetime import timedelta
import voluptuous as vol import voluptuous as vol
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.components.switch import (SwitchDevice, PLATFORM_SCHEMA) from homeassistant.components.switch import SwitchDevice, PLATFORM_SCHEMA
from homeassistant.const import (CONF_HOST, CONF_PASSWORD, CONF_USERNAME) from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME
from homeassistant.util import Throttle from homeassistant.util import Throttle
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
CONF_PORT_RECV = 'port_recv' CONF_PORT_RECV = "port_recv"
CONF_PORT_SEND = 'port_send' CONF_PORT_SEND = "port_send"
MIN_TIME_BETWEEN_UPDATES = timedelta(seconds=5) MIN_TIME_BETWEEN_UPDATES = timedelta(seconds=5)
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_PORT_RECV): cv.port, {
vol.Required(CONF_PORT_SEND): cv.port, vol.Required(CONF_PORT_RECV): cv.port,
vol.Required(CONF_USERNAME): cv.string, vol.Required(CONF_PORT_SEND): cv.port,
vol.Required(CONF_PASSWORD): cv.string, vol.Required(CONF_USERNAME): cv.string,
vol.Optional(CONF_HOST): cv.string, vol.Required(CONF_PASSWORD): cv.string,
}) vol.Optional(CONF_HOST): cv.string,
}
)
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
@ -38,8 +40,11 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
try: try:
master = DeviceMaster( master = DeviceMaster(
username=username, password=password, read_port=port_send, username=username,
write_port=port_recv) password=password,
read_port=port_send,
write_port=port_recv,
)
master.query(ip_addr=host) master.query(ip_addr=host)
except socket.error as ex: except socket.error as ex:
_LOGGER.error("Unable to discover PwrCtrl device: %s", str(ex)) _LOGGER.error("Unable to discover PwrCtrl device: %s", str(ex))
@ -49,8 +54,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
for device in master.devices.values(): for device in master.devices.values():
parent_device = PwrCtrlDevice(device) parent_device = PwrCtrlDevice(device)
devices.extend( devices.extend(
PwrCtrlSwitch(switch, parent_device) PwrCtrlSwitch(switch, parent_device) for switch in device.switches.values()
for switch in device.switches.values()
) )
add_entities(devices) add_entities(devices)
@ -72,9 +76,8 @@ class PwrCtrlSwitch(SwitchDevice):
@property @property
def unique_id(self): def unique_id(self):
"""Return the unique ID of the device.""" """Return the unique ID of the device."""
return '{device}-{switch_idx}'.format( return "{device}-{switch_idx}".format(
device=self._port.device.host, device=self._port.device.host, switch_idx=self._port.get_index()
switch_idx=self._port.get_index()
) )
@property @property

View file

@ -3,34 +3,48 @@ import logging
import voluptuous as vol import voluptuous as vol
from homeassistant.components.media_player import ( from homeassistant.components.media_player import MediaPlayerDevice, PLATFORM_SCHEMA
MediaPlayerDevice, PLATFORM_SCHEMA)
from homeassistant.components.media_player.const import ( from homeassistant.components.media_player.const import (
SUPPORT_SELECT_SOURCE, SUPPORT_TURN_OFF, SUPPORT_TURN_ON, SUPPORT_SELECT_SOURCE,
SUPPORT_VOLUME_MUTE, SUPPORT_VOLUME_SET) SUPPORT_TURN_OFF,
SUPPORT_TURN_ON,
SUPPORT_VOLUME_MUTE,
SUPPORT_VOLUME_SET,
)
from homeassistant.const import ( from homeassistant.const import (
CONF_HOST, CONF_NAME, CONF_PORT, EVENT_HOMEASSISTANT_STOP, STATE_OFF, CONF_HOST,
STATE_ON) CONF_NAME,
CONF_PORT,
EVENT_HOMEASSISTANT_STOP,
STATE_OFF,
STATE_ON,
)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DOMAIN = 'anthemav' DOMAIN = "anthemav"
DEFAULT_PORT = 14999 DEFAULT_PORT = 14999
SUPPORT_ANTHEMAV = SUPPORT_VOLUME_SET | SUPPORT_VOLUME_MUTE | \ SUPPORT_ANTHEMAV = (
SUPPORT_TURN_ON | SUPPORT_TURN_OFF | SUPPORT_SELECT_SOURCE SUPPORT_VOLUME_SET
| SUPPORT_VOLUME_MUTE
| SUPPORT_TURN_ON
| SUPPORT_TURN_OFF
| SUPPORT_SELECT_SOURCE
)
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Required(CONF_HOST): cv.string, {
vol.Optional(CONF_NAME): cv.string, vol.Required(CONF_HOST): cv.string,
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port, vol.Optional(CONF_NAME): cv.string,
}) vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
}
)
async def async_setup_platform(hass, config, async_add_entities, async def async_setup_platform(hass, config, async_add_entities, discovery_info=None):
discovery_info=None):
"""Set up our socket to the AVR.""" """Set up our socket to the AVR."""
import anthemav import anthemav
@ -47,8 +61,8 @@ async def async_setup_platform(hass, config, async_add_entities,
hass.async_create_task(device.async_update_ha_state()) hass.async_create_task(device.async_update_ha_state())
avr = await anthemav.Connection.create( avr = await anthemav.Connection.create(
host=host, port=port, host=host, port=port, update_callback=async_anthemav_update_callback
update_callback=async_anthemav_update_callback) )
device = AnthemAVR(avr, name) device = AnthemAVR(avr, name)
@ -84,12 +98,12 @@ class AnthemAVR(MediaPlayerDevice):
@property @property
def name(self): def name(self):
"""Return name of device.""" """Return name of device."""
return self._name or self._lookup('model') return self._name or self._lookup("model")
@property @property
def state(self): def state(self):
"""Return state of power on/off.""" """Return state of power on/off."""
pwrstate = self._lookup('power') pwrstate = self._lookup("power")
if pwrstate is True: if pwrstate is True:
return STATE_ON return STATE_ON
@ -100,64 +114,64 @@ class AnthemAVR(MediaPlayerDevice):
@property @property
def is_volume_muted(self): def is_volume_muted(self):
"""Return boolean reflecting mute state on device.""" """Return boolean reflecting mute state on device."""
return self._lookup('mute', False) return self._lookup("mute", False)
@property @property
def volume_level(self): def volume_level(self):
"""Return volume level from 0 to 1.""" """Return volume level from 0 to 1."""
return self._lookup('volume_as_percentage', 0.0) return self._lookup("volume_as_percentage", 0.0)
@property @property
def media_title(self): def media_title(self):
"""Return current input name (closest we have to media title).""" """Return current input name (closest we have to media title)."""
return self._lookup('input_name', 'No Source') return self._lookup("input_name", "No Source")
@property @property
def app_name(self): def app_name(self):
"""Return details about current video and audio stream.""" """Return details about current video and audio stream."""
return self._lookup('video_input_resolution_text', '') + ' ' \ return (
+ self._lookup('audio_input_name', '') self._lookup("video_input_resolution_text", "")
+ " "
+ self._lookup("audio_input_name", "")
)
@property @property
def source(self): def source(self):
"""Return currently selected input.""" """Return currently selected input."""
return self._lookup('input_name', "Unknown") return self._lookup("input_name", "Unknown")
@property @property
def source_list(self): def source_list(self):
"""Return all active, configured inputs.""" """Return all active, configured inputs."""
return self._lookup('input_list', ["Unknown"]) return self._lookup("input_list", ["Unknown"])
async def async_select_source(self, source): async def async_select_source(self, source):
"""Change AVR to the designated source (by name).""" """Change AVR to the designated source (by name)."""
self._update_avr('input_name', source) self._update_avr("input_name", source)
async def async_turn_off(self): async def async_turn_off(self):
"""Turn AVR power off.""" """Turn AVR power off."""
self._update_avr('power', False) self._update_avr("power", False)
async def async_turn_on(self): async def async_turn_on(self):
"""Turn AVR power on.""" """Turn AVR power on."""
self._update_avr('power', True) self._update_avr("power", True)
async def async_set_volume_level(self, volume): async def async_set_volume_level(self, volume):
"""Set AVR volume (0 to 1).""" """Set AVR volume (0 to 1)."""
self._update_avr('volume_as_percentage', volume) self._update_avr("volume_as_percentage", volume)
async def async_mute_volume(self, mute): async def async_mute_volume(self, mute):
"""Engage AVR mute.""" """Engage AVR mute."""
self._update_avr('mute', mute) self._update_avr("mute", mute)
def _update_avr(self, propname, value): def _update_avr(self, propname, value):
"""Update a property in the AVR.""" """Update a property in the AVR."""
_LOGGER.info( _LOGGER.info("Sending command to AVR: set %s to %s", propname, str(value))
"Sending command to AVR: set %s to %s", propname, str(value))
setattr(self.avr.protocol, propname, value) setattr(self.avr.protocol, propname, value)
@property @property
def dump_avrdata(self): def dump_avrdata(self):
"""Return state of avr object for debugging forensics.""" """Return state of avr object for debugging forensics."""
attrs = vars(self) attrs = vars(self)
return( return "dump_avrdata: " + ", ".join("%s: %s" % item for item in attrs.items())
'dump_avrdata: '
+ ', '.join('%s: %s' % item for item in attrs.items()))

View file

@ -7,26 +7,36 @@ from aiokafka import AIOKafkaProducer
import voluptuous as vol import voluptuous as vol
from homeassistant.const import ( from homeassistant.const import (
CONF_IP_ADDRESS, CONF_PORT, EVENT_HOMEASSISTANT_STOP, EVENT_STATE_CHANGED, CONF_IP_ADDRESS,
STATE_UNAVAILABLE, STATE_UNKNOWN) CONF_PORT,
EVENT_HOMEASSISTANT_STOP,
EVENT_STATE_CHANGED,
STATE_UNAVAILABLE,
STATE_UNKNOWN,
)
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entityfilter import FILTER_SCHEMA from homeassistant.helpers.entityfilter import FILTER_SCHEMA
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DOMAIN = 'apache_kafka' DOMAIN = "apache_kafka"
CONF_FILTER = 'filter' CONF_FILTER = "filter"
CONF_TOPIC = 'topic' CONF_TOPIC = "topic"
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema(
DOMAIN: vol.Schema({ {
vol.Required(CONF_IP_ADDRESS): cv.string, DOMAIN: vol.Schema(
vol.Required(CONF_PORT): cv.port, {
vol.Required(CONF_TOPIC): cv.string, vol.Required(CONF_IP_ADDRESS): cv.string,
vol.Optional(CONF_FILTER, default={}): FILTER_SCHEMA, vol.Required(CONF_PORT): cv.port,
}), vol.Required(CONF_TOPIC): cv.string,
}, extra=vol.ALLOW_EXTRA) vol.Optional(CONF_FILTER, default={}): FILTER_SCHEMA,
}
)
},
extra=vol.ALLOW_EXTRA,
)
async def async_setup(hass, config): async def async_setup(hass, config):
@ -38,7 +48,8 @@ async def async_setup(hass, config):
conf[CONF_IP_ADDRESS], conf[CONF_IP_ADDRESS],
conf[CONF_PORT], conf[CONF_PORT],
conf[CONF_TOPIC], conf[CONF_TOPIC],
conf[CONF_FILTER]) conf[CONF_FILTER],
)
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, kafka.shutdown()) hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, kafka.shutdown())
@ -63,13 +74,7 @@ class DateTimeJSONEncoder(json.JSONEncoder):
class KafkaManager: class KafkaManager:
"""Define a manager to buffer events to Kafka.""" """Define a manager to buffer events to Kafka."""
def __init__( def __init__(self, hass, ip_address, port, topic, entities_filter):
self,
hass,
ip_address,
port,
topic,
entities_filter):
"""Initialize.""" """Initialize."""
self._encoder = DateTimeJSONEncoder() self._encoder = DateTimeJSONEncoder()
self._entities_filter = entities_filter self._entities_filter = entities_filter
@ -83,16 +88,17 @@ class KafkaManager:
def _encode_event(self, event): def _encode_event(self, event):
"""Translate events into a binary JSON payload.""" """Translate events into a binary JSON payload."""
state = event.data.get('new_state') state = event.data.get("new_state")
if (state is None if (
or state.state in (STATE_UNKNOWN, '', STATE_UNAVAILABLE) state is None
or not self._entities_filter(state.entity_id)): or state.state in (STATE_UNKNOWN, "", STATE_UNAVAILABLE)
or not self._entities_filter(state.entity_id)
):
return return
return json.dumps( return json.dumps(obj=state.as_dict(), default=self._encoder.encode).encode(
obj=state.as_dict(), "utf-8"
default=self._encoder.encode )
).encode('utf-8')
async def start(self): async def start(self):
"""Start the Kafka manager.""" """Start the Kafka manager."""

View file

@ -4,31 +4,36 @@ from datetime import timedelta
import voluptuous as vol import voluptuous as vol
from homeassistant.const import (CONF_HOST, CONF_PORT) from homeassistant.const import CONF_HOST, CONF_PORT
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util import Throttle from homeassistant.util import Throttle
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
CONF_TYPE = 'type' CONF_TYPE = "type"
DATA = None DATA = None
DEFAULT_HOST = 'localhost' DEFAULT_HOST = "localhost"
DEFAULT_PORT = 3551 DEFAULT_PORT = 3551
DOMAIN = 'apcupsd' DOMAIN = "apcupsd"
KEY_STATUS = 'STATUS' KEY_STATUS = "STATUS"
MIN_TIME_BETWEEN_UPDATES = timedelta(seconds=60) MIN_TIME_BETWEEN_UPDATES = timedelta(seconds=60)
VALUE_ONLINE = 'ONLINE' VALUE_ONLINE = "ONLINE"
CONFIG_SCHEMA = vol.Schema({ CONFIG_SCHEMA = vol.Schema(
DOMAIN: vol.Schema({ {
vol.Optional(CONF_HOST, default=DEFAULT_HOST): cv.string, DOMAIN: vol.Schema(
vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port, {
}), vol.Optional(CONF_HOST, default=DEFAULT_HOST): cv.string,
}, extra=vol.ALLOW_EXTRA) vol.Optional(CONF_PORT, default=DEFAULT_PORT): cv.port,
}
)
},
extra=vol.ALLOW_EXTRA,
)
def setup(hass, config): def setup(hass, config):
@ -60,6 +65,7 @@ class APCUPSdData:
def __init__(self, host, port): def __init__(self, host, port):
"""Initialize the data object.""" """Initialize the data object."""
from apcaccess import status from apcaccess import status
self._host = host self._host = host
self._port = port self._port = port
self._status = None self._status = None

View file

@ -1,16 +1,15 @@
"""Support for tracking the online status of a UPS.""" """Support for tracking the online status of a UPS."""
import voluptuous as vol import voluptuous as vol
from homeassistant.components.binary_sensor import ( from homeassistant.components.binary_sensor import BinarySensorDevice, PLATFORM_SCHEMA
BinarySensorDevice, PLATFORM_SCHEMA)
from homeassistant.const import CONF_NAME from homeassistant.const import CONF_NAME
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.components import apcupsd from homeassistant.components import apcupsd
DEFAULT_NAME = 'UPS Online Status' DEFAULT_NAME = "UPS Online Status"
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, {vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string}
}) )
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):

Some files were not shown because too many files have changed in this diff Show more