Black
This commit is contained in:
parent
da05dfe708
commit
4de97abc3a
2676 changed files with 163166 additions and 140084 deletions
374
tests/common.py
374
tests/common.py
|
@ -19,31 +19,48 @@ import homeassistant.util.yaml.loader as yaml_loader
|
|||
|
||||
from homeassistant import auth, config_entries, core as ha, loader
|
||||
from homeassistant.auth import (
|
||||
models as auth_models, auth_store, providers as auth_providers,
|
||||
permissions as auth_permissions)
|
||||
models as auth_models,
|
||||
auth_store,
|
||||
providers as auth_providers,
|
||||
permissions as auth_permissions,
|
||||
)
|
||||
from homeassistant.auth.permissions import system_policies
|
||||
from homeassistant.components import mqtt, recorder
|
||||
from homeassistant.config import async_process_component_config
|
||||
from homeassistant.const import (
|
||||
ATTR_DISCOVERED, ATTR_SERVICE, DEVICE_DEFAULT_NAME,
|
||||
EVENT_HOMEASSISTANT_CLOSE, EVENT_PLATFORM_DISCOVERED, EVENT_STATE_CHANGED,
|
||||
EVENT_TIME_CHANGED, SERVER_PORT, STATE_ON, STATE_OFF)
|
||||
ATTR_DISCOVERED,
|
||||
ATTR_SERVICE,
|
||||
DEVICE_DEFAULT_NAME,
|
||||
EVENT_HOMEASSISTANT_CLOSE,
|
||||
EVENT_PLATFORM_DISCOVERED,
|
||||
EVENT_STATE_CHANGED,
|
||||
EVENT_TIME_CHANGED,
|
||||
SERVER_PORT,
|
||||
STATE_ON,
|
||||
STATE_OFF,
|
||||
)
|
||||
from homeassistant.core import State
|
||||
from homeassistant.helpers import (
|
||||
area_registry, device_registry, entity, entity_platform, entity_registry,
|
||||
intent, restore_state, storage)
|
||||
area_registry,
|
||||
device_registry,
|
||||
entity,
|
||||
entity_platform,
|
||||
entity_registry,
|
||||
intent,
|
||||
restore_state,
|
||||
storage,
|
||||
)
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
from homeassistant.setup import async_setup_component, setup_component
|
||||
from homeassistant.util.unit_system import METRIC_SYSTEM
|
||||
from homeassistant.util.async_ import (
|
||||
run_callback_threadsafe, run_coroutine_threadsafe)
|
||||
from homeassistant.util.async_ import run_callback_threadsafe, run_coroutine_threadsafe
|
||||
|
||||
|
||||
_TEST_INSTANCE_PORT = SERVER_PORT
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
INSTANCES = []
|
||||
CLIENT_ID = 'https://example.com/app'
|
||||
CLIENT_REDIRECT_URI = 'https://example.com/app/callback'
|
||||
CLIENT_ID = "https://example.com/app"
|
||||
CLIENT_REDIRECT_URI = "https://example.com/app/callback"
|
||||
|
||||
|
||||
def threadsafe_callback_factory(func):
|
||||
|
@ -51,12 +68,14 @@ def threadsafe_callback_factory(func):
|
|||
|
||||
Callback needs to have `hass` as first argument.
|
||||
"""
|
||||
|
||||
@ft.wraps(func)
|
||||
def threadsafe(*args, **kwargs):
|
||||
"""Call func threadsafe."""
|
||||
hass = args[0]
|
||||
return run_callback_threadsafe(
|
||||
hass.loop, ft.partial(func, *args, **kwargs)).result()
|
||||
hass.loop, ft.partial(func, *args, **kwargs)
|
||||
).result()
|
||||
|
||||
return threadsafe
|
||||
|
||||
|
@ -66,19 +85,19 @@ def threadsafe_coroutine_factory(func):
|
|||
|
||||
Callback needs to have `hass` as first argument.
|
||||
"""
|
||||
|
||||
@ft.wraps(func)
|
||||
def threadsafe(*args, **kwargs):
|
||||
"""Call func threadsafe."""
|
||||
hass = args[0]
|
||||
return run_coroutine_threadsafe(
|
||||
func(*args, **kwargs), hass.loop).result()
|
||||
return run_coroutine_threadsafe(func(*args, **kwargs), hass.loop).result()
|
||||
|
||||
return threadsafe
|
||||
|
||||
|
||||
def get_test_config_dir(*add_path):
|
||||
"""Return a path to a test config dir."""
|
||||
return os.path.join(os.path.dirname(__file__), 'testing_config', *add_path)
|
||||
return os.path.join(os.path.dirname(__file__), "testing_config", *add_path)
|
||||
|
||||
|
||||
def get_test_home_assistant():
|
||||
|
@ -155,12 +174,12 @@ async def async_test_home_assistant(loop):
|
|||
hass.async_add_executor_job = async_add_executor_job
|
||||
hass.async_create_task = async_create_task
|
||||
|
||||
hass.config.location_name = 'test home'
|
||||
hass.config.location_name = "test home"
|
||||
hass.config.config_dir = get_test_config_dir()
|
||||
hass.config.latitude = 32.87336
|
||||
hass.config.longitude = -117.22743
|
||||
hass.config.elevation = 0
|
||||
hass.config.time_zone = date_util.get_time_zone('US/Pacific')
|
||||
hass.config.time_zone = date_util.get_time_zone("US/Pacific")
|
||||
hass.config.units = METRIC_SYSTEM
|
||||
hass.config.skip_pip = True
|
||||
|
||||
|
@ -176,8 +195,9 @@ async def async_test_home_assistant(loop):
|
|||
async def mock_async_start():
|
||||
"""Start the mocking."""
|
||||
# We only mock time during tests and we want to track tasks
|
||||
with patch('homeassistant.core._async_create_timer'), \
|
||||
patch.object(hass, 'async_stop_track_tasks'):
|
||||
with patch("homeassistant.core._async_create_timer"), patch.object(
|
||||
hass, "async_stop_track_tasks"
|
||||
):
|
||||
await orig_start()
|
||||
|
||||
hass.async_start = mock_async_start
|
||||
|
@ -214,8 +234,7 @@ def async_mock_service(hass, domain, service, schema=None):
|
|||
"""Mock service call."""
|
||||
calls.append(call)
|
||||
|
||||
hass.services.async_register(
|
||||
domain, service, mock_service_log, schema=schema)
|
||||
hass.services.async_register(domain, service, mock_service_log, schema=schema)
|
||||
|
||||
return calls
|
||||
|
||||
|
@ -246,9 +265,9 @@ def async_mock_intent(hass, intent_typ):
|
|||
def async_fire_mqtt_message(hass, topic, payload, qos=0, retain=False):
|
||||
"""Fire the MQTT message."""
|
||||
if isinstance(payload, str):
|
||||
payload = payload.encode('utf-8')
|
||||
payload = payload.encode("utf-8")
|
||||
msg = mqtt.Message(topic, payload, qos, retain)
|
||||
hass.data['mqtt']._mqtt_handle_message(msg)
|
||||
hass.data["mqtt"]._mqtt_handle_message(msg)
|
||||
|
||||
|
||||
fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)
|
||||
|
@ -257,7 +276,7 @@ fire_mqtt_message = threadsafe_callback_factory(async_fire_mqtt_message)
|
|||
@ha.callback
|
||||
def async_fire_time_changed(hass, time):
|
||||
"""Fire a time changes event."""
|
||||
hass.bus.async_fire(EVENT_TIME_CHANGED, {'now': date_util.as_utc(time)})
|
||||
hass.bus.async_fire(EVENT_TIME_CHANGED, {"now": date_util.as_utc(time)})
|
||||
|
||||
|
||||
fire_time_changed = threadsafe_callback_factory(async_fire_time_changed)
|
||||
|
@ -265,37 +284,32 @@ fire_time_changed = threadsafe_callback_factory(async_fire_time_changed)
|
|||
|
||||
def fire_service_discovered(hass, service, info):
|
||||
"""Fire the MQTT message."""
|
||||
hass.bus.fire(EVENT_PLATFORM_DISCOVERED, {
|
||||
ATTR_SERVICE: service,
|
||||
ATTR_DISCOVERED: info
|
||||
})
|
||||
hass.bus.fire(
|
||||
EVENT_PLATFORM_DISCOVERED, {ATTR_SERVICE: service, ATTR_DISCOVERED: info}
|
||||
)
|
||||
|
||||
|
||||
@ha.callback
|
||||
def async_fire_service_discovered(hass, service, info):
|
||||
"""Fire the MQTT message."""
|
||||
hass.bus.async_fire(EVENT_PLATFORM_DISCOVERED, {
|
||||
ATTR_SERVICE: service,
|
||||
ATTR_DISCOVERED: info
|
||||
})
|
||||
hass.bus.async_fire(
|
||||
EVENT_PLATFORM_DISCOVERED, {ATTR_SERVICE: service, ATTR_DISCOVERED: info}
|
||||
)
|
||||
|
||||
|
||||
def load_fixture(filename):
|
||||
"""Load a fixture."""
|
||||
path = os.path.join(os.path.dirname(__file__), 'fixtures', filename)
|
||||
with open(path, encoding='utf-8') as fptr:
|
||||
path = os.path.join(os.path.dirname(__file__), "fixtures", filename)
|
||||
with open(path, encoding="utf-8") as fptr:
|
||||
return fptr.read()
|
||||
|
||||
|
||||
def mock_state_change_event(hass, new_state, old_state=None):
|
||||
"""Mock state change envent."""
|
||||
event_data = {
|
||||
'entity_id': new_state.entity_id,
|
||||
'new_state': new_state,
|
||||
}
|
||||
event_data = {"entity_id": new_state.entity_id, "new_state": new_state}
|
||||
|
||||
if old_state:
|
||||
event_data['old_state'] = old_state
|
||||
event_data["old_state"] = old_state
|
||||
|
||||
hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context)
|
||||
|
||||
|
@ -303,24 +317,23 @@ def mock_state_change_event(hass, new_state, old_state=None):
|
|||
async def async_mock_mqtt_component(hass, config=None):
|
||||
"""Mock the MQTT component."""
|
||||
if config is None:
|
||||
config = {mqtt.CONF_BROKER: 'mock-broker'}
|
||||
config = {mqtt.CONF_BROKER: "mock-broker"}
|
||||
|
||||
with patch('paho.mqtt.client.Client') as mock_client:
|
||||
with patch("paho.mqtt.client.Client") as mock_client:
|
||||
mock_client().connect.return_value = 0
|
||||
mock_client().subscribe.return_value = (0, 0)
|
||||
mock_client().unsubscribe.return_value = (0, 0)
|
||||
mock_client().publish.return_value = (0, 0)
|
||||
|
||||
result = await async_setup_component(hass, mqtt.DOMAIN, {
|
||||
mqtt.DOMAIN: config
|
||||
})
|
||||
result = await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config})
|
||||
assert result
|
||||
await hass.async_block_till_done()
|
||||
|
||||
hass.data['mqtt'] = MagicMock(spec_set=hass.data['mqtt'],
|
||||
wraps=hass.data['mqtt'])
|
||||
hass.data["mqtt"] = MagicMock(
|
||||
spec_set=hass.data["mqtt"], wraps=hass.data["mqtt"]
|
||||
)
|
||||
|
||||
return hass.data['mqtt']
|
||||
return hass.data["mqtt"]
|
||||
|
||||
|
||||
mock_mqtt_component = threadsafe_coroutine_factory(async_mock_mqtt_component)
|
||||
|
@ -365,15 +378,11 @@ def mock_device_registry(hass, mock_entries=None):
|
|||
class MockGroup(auth_models.Group):
|
||||
"""Mock a group in Home Assistant."""
|
||||
|
||||
def __init__(self, id=None, name='Mock Group',
|
||||
policy=system_policies.ADMIN_POLICY):
|
||||
def __init__(self, id=None, name="Mock Group", policy=system_policies.ADMIN_POLICY):
|
||||
"""Mock a group."""
|
||||
kwargs = {
|
||||
'name': name,
|
||||
'policy': policy,
|
||||
}
|
||||
kwargs = {"name": name, "policy": policy}
|
||||
if id is not None:
|
||||
kwargs['id'] = id
|
||||
kwargs["id"] = id
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
@ -391,19 +400,26 @@ class MockGroup(auth_models.Group):
|
|||
class MockUser(auth_models.User):
|
||||
"""Mock a user in Home Assistant."""
|
||||
|
||||
def __init__(self, id=None, is_owner=False, is_active=True,
|
||||
name='Mock User', system_generated=False, groups=None):
|
||||
def __init__(
|
||||
self,
|
||||
id=None,
|
||||
is_owner=False,
|
||||
is_active=True,
|
||||
name="Mock User",
|
||||
system_generated=False,
|
||||
groups=None,
|
||||
):
|
||||
"""Initialize mock user."""
|
||||
kwargs = {
|
||||
'is_owner': is_owner,
|
||||
'is_active': is_active,
|
||||
'name': name,
|
||||
'system_generated': system_generated,
|
||||
'groups': groups or [],
|
||||
'perm_lookup': None,
|
||||
"is_owner": is_owner,
|
||||
"is_active": is_active,
|
||||
"name": name,
|
||||
"system_generated": system_generated,
|
||||
"groups": groups or [],
|
||||
"perm_lookup": None,
|
||||
}
|
||||
if id is not None:
|
||||
kwargs['id'] = id
|
||||
kwargs["id"] = id
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def add_to_hass(self, hass):
|
||||
|
@ -418,20 +434,20 @@ class MockUser(auth_models.User):
|
|||
|
||||
def mock_policy(self, policy):
|
||||
"""Mock a policy for a user."""
|
||||
self._permissions = auth_permissions.PolicyPermissions(
|
||||
policy, self.perm_lookup)
|
||||
self._permissions = auth_permissions.PolicyPermissions(policy, self.perm_lookup)
|
||||
|
||||
|
||||
async def register_auth_provider(hass, config):
|
||||
"""Register an auth provider."""
|
||||
provider = await auth_providers.auth_provider_from_config(
|
||||
hass, hass.auth._store, config)
|
||||
assert provider is not None, 'Invalid config specified'
|
||||
hass, hass.auth._store, config
|
||||
)
|
||||
assert provider is not None, "Invalid config specified"
|
||||
key = (provider.type, provider.id)
|
||||
providers = hass.auth._providers
|
||||
|
||||
if key in providers:
|
||||
raise ValueError('Provider already registered')
|
||||
raise ValueError("Provider already registered")
|
||||
|
||||
providers[key] = provider
|
||||
return provider
|
||||
|
@ -449,15 +465,25 @@ class MockModule:
|
|||
"""Representation of a fake module."""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def __init__(self, domain=None, dependencies=None, setup=None,
|
||||
requirements=None, config_schema=None, platform_schema=None,
|
||||
platform_schema_base=None, async_setup=None,
|
||||
async_setup_entry=None, async_unload_entry=None,
|
||||
async_migrate_entry=None, async_remove_entry=None,
|
||||
partial_manifest=None):
|
||||
def __init__(
|
||||
self,
|
||||
domain=None,
|
||||
dependencies=None,
|
||||
setup=None,
|
||||
requirements=None,
|
||||
config_schema=None,
|
||||
platform_schema=None,
|
||||
platform_schema_base=None,
|
||||
async_setup=None,
|
||||
async_setup_entry=None,
|
||||
async_unload_entry=None,
|
||||
async_migrate_entry=None,
|
||||
async_remove_entry=None,
|
||||
partial_manifest=None,
|
||||
):
|
||||
"""Initialize the mock module."""
|
||||
self.__name__ = 'homeassistant.components.{}'.format(domain)
|
||||
self.__file__ = 'homeassistant/components/{}'.format(domain)
|
||||
self.__name__ = "homeassistant.components.{}".format(domain)
|
||||
self.__file__ = "homeassistant/components/{}".format(domain)
|
||||
self.DOMAIN = domain
|
||||
self.DEPENDENCIES = dependencies or []
|
||||
self.REQUIREMENTS = requirements or []
|
||||
|
@ -499,20 +525,26 @@ class MockModule:
|
|||
"""Generate a mock manifest to represent this module."""
|
||||
return {
|
||||
**loader.manifest_from_legacy_module(self.DOMAIN, self),
|
||||
**(self._partial_manifest or {})
|
||||
**(self._partial_manifest or {}),
|
||||
}
|
||||
|
||||
|
||||
class MockPlatform:
|
||||
"""Provide a fake platform."""
|
||||
|
||||
__name__ = 'homeassistant.components.light.bla'
|
||||
__file__ = 'homeassistant/components/blah/light'
|
||||
__name__ = "homeassistant.components.light.bla"
|
||||
__file__ = "homeassistant/components/blah/light"
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def __init__(self, setup_platform=None, dependencies=None,
|
||||
platform_schema=None, async_setup_platform=None,
|
||||
async_setup_entry=None, scan_interval=None):
|
||||
def __init__(
|
||||
self,
|
||||
setup_platform=None,
|
||||
dependencies=None,
|
||||
platform_schema=None,
|
||||
async_setup_platform=None,
|
||||
async_setup_entry=None,
|
||||
scan_interval=None,
|
||||
):
|
||||
"""Initialize the platform."""
|
||||
self.DEPENDENCIES = dependencies or []
|
||||
|
||||
|
@ -540,22 +572,22 @@ class MockEntityPlatform(entity_platform.EntityPlatform):
|
|||
"""Mock class with some mock defaults."""
|
||||
|
||||
def __init__(
|
||||
self, hass,
|
||||
self,
|
||||
hass,
|
||||
logger=None,
|
||||
domain='test_domain',
|
||||
platform_name='test_platform',
|
||||
domain="test_domain",
|
||||
platform_name="test_platform",
|
||||
platform=None,
|
||||
scan_interval=timedelta(seconds=15),
|
||||
entity_namespace=None,
|
||||
async_entities_added_callback=lambda: None
|
||||
async_entities_added_callback=lambda: None,
|
||||
):
|
||||
"""Initialize a mock entity platform."""
|
||||
if logger is None:
|
||||
logger = logging.getLogger('homeassistant.helpers.entity_platform')
|
||||
logger = logging.getLogger("homeassistant.helpers.entity_platform")
|
||||
|
||||
# Otherwise the constructor will blow up.
|
||||
if (isinstance(platform, Mock) and
|
||||
isinstance(platform.PARALLEL_UPDATES, Mock)):
|
||||
if isinstance(platform, Mock) and isinstance(platform.PARALLEL_UPDATES, Mock):
|
||||
platform.PARALLEL_UPDATES = 0
|
||||
|
||||
super().__init__(
|
||||
|
@ -582,29 +614,29 @@ class MockToggleDevice(entity.ToggleEntity):
|
|||
@property
|
||||
def name(self):
|
||||
"""Return the name of the device if any."""
|
||||
self.calls.append(('name', {}))
|
||||
self.calls.append(("name", {}))
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
"""Return the name of the device if any."""
|
||||
self.calls.append(('state', {}))
|
||||
self.calls.append(("state", {}))
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def is_on(self):
|
||||
"""Return true if device is on."""
|
||||
self.calls.append(('is_on', {}))
|
||||
self.calls.append(("is_on", {}))
|
||||
return self._state == STATE_ON
|
||||
|
||||
def turn_on(self, **kwargs):
|
||||
"""Turn the device on."""
|
||||
self.calls.append(('turn_on', kwargs))
|
||||
self.calls.append(("turn_on", kwargs))
|
||||
self._state = STATE_ON
|
||||
|
||||
def turn_off(self, **kwargs):
|
||||
"""Turn the device off."""
|
||||
self.calls.append(('turn_off', kwargs))
|
||||
self.calls.append(("turn_off", kwargs))
|
||||
self._state = STATE_OFF
|
||||
|
||||
def last_call(self, method=None):
|
||||
|
@ -614,8 +646,7 @@ class MockToggleDevice(entity.ToggleEntity):
|
|||
if method is None:
|
||||
return self.calls[-1]
|
||||
try:
|
||||
return next(call for call in reversed(self.calls)
|
||||
if call[0] == method)
|
||||
return next(call for call in reversed(self.calls) if call[0] == method)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
|
@ -623,24 +654,33 @@ class MockToggleDevice(entity.ToggleEntity):
|
|||
class MockConfigEntry(config_entries.ConfigEntry):
|
||||
"""Helper for creating config entries that adds some defaults."""
|
||||
|
||||
def __init__(self, *, domain='test', data=None, version=1, entry_id=None,
|
||||
source=config_entries.SOURCE_USER, title='Mock Title',
|
||||
state=None, options={},
|
||||
connection_class=config_entries.CONN_CLASS_UNKNOWN):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
domain="test",
|
||||
data=None,
|
||||
version=1,
|
||||
entry_id=None,
|
||||
source=config_entries.SOURCE_USER,
|
||||
title="Mock Title",
|
||||
state=None,
|
||||
options={},
|
||||
connection_class=config_entries.CONN_CLASS_UNKNOWN,
|
||||
):
|
||||
"""Initialize a mock config entry."""
|
||||
kwargs = {
|
||||
'entry_id': entry_id or uuid.uuid4().hex,
|
||||
'domain': domain,
|
||||
'data': data or {},
|
||||
'options': options,
|
||||
'version': version,
|
||||
'title': title,
|
||||
'connection_class': connection_class,
|
||||
"entry_id": entry_id or uuid.uuid4().hex,
|
||||
"domain": domain,
|
||||
"data": data or {},
|
||||
"options": options,
|
||||
"version": version,
|
||||
"title": title,
|
||||
"connection_class": connection_class,
|
||||
}
|
||||
if source is not None:
|
||||
kwargs['source'] = source
|
||||
kwargs["source"] = source
|
||||
if state is not None:
|
||||
kwargs['state'] = state
|
||||
kwargs["state"] = state
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def add_to_hass(self, hass):
|
||||
|
@ -663,7 +703,7 @@ def patch_yaml_files(files_dict, endswith=True):
|
|||
if fname in files_dict:
|
||||
_LOGGER.debug("patch_yaml_files match %s", fname)
|
||||
res = StringIO(files_dict[fname])
|
||||
setattr(res, 'name', fname)
|
||||
setattr(res, "name", fname)
|
||||
return res
|
||||
|
||||
# Match using endswith
|
||||
|
@ -671,18 +711,18 @@ def patch_yaml_files(files_dict, endswith=True):
|
|||
if fname.endswith(ends):
|
||||
_LOGGER.debug("patch_yaml_files end match %s: %s", ends, fname)
|
||||
res = StringIO(files_dict[ends])
|
||||
setattr(res, 'name', fname)
|
||||
setattr(res, "name", fname)
|
||||
return res
|
||||
|
||||
# Fallback for hass.components (i.e. services.yaml)
|
||||
if 'homeassistant/components' in fname:
|
||||
if "homeassistant/components" in fname:
|
||||
_LOGGER.debug("patch_yaml_files using real file: %s", fname)
|
||||
return open(fname, encoding='utf-8')
|
||||
return open(fname, encoding="utf-8")
|
||||
|
||||
# Not found
|
||||
raise FileNotFoundError("File not found: {}".format(fname))
|
||||
|
||||
return patch.object(yaml_loader, 'open', mock_open_f, create=True)
|
||||
return patch.object(yaml_loader, "open", mock_open_f, create=True)
|
||||
|
||||
|
||||
def mock_coro(return_value=None, exception=None):
|
||||
|
@ -692,6 +732,7 @@ def mock_coro(return_value=None, exception=None):
|
|||
|
||||
def mock_coro_func(return_value=None, exception=None):
|
||||
"""Return a method to create a coro function that returns a value."""
|
||||
|
||||
@asyncio.coroutine
|
||||
def coro(*args, **kwargs):
|
||||
"""Fake coroutine."""
|
||||
|
@ -720,39 +761,40 @@ def assert_setup_component(count, domain=None):
|
|||
async def mock_psc(hass, config_input, integration):
|
||||
"""Mock the prepare_setup_component to capture config."""
|
||||
domain_input = integration.domain
|
||||
res = await async_process_component_config(
|
||||
hass, config_input, integration)
|
||||
res = await async_process_component_config(hass, config_input, integration)
|
||||
config[domain_input] = None if res is None else res.get(domain_input)
|
||||
_LOGGER.debug("Configuration for %s, Validated: %s, Original %s",
|
||||
domain_input,
|
||||
config[domain_input],
|
||||
config_input.get(domain_input))
|
||||
_LOGGER.debug(
|
||||
"Configuration for %s, Validated: %s, Original %s",
|
||||
domain_input,
|
||||
config[domain_input],
|
||||
config_input.get(domain_input),
|
||||
)
|
||||
return res
|
||||
|
||||
assert isinstance(config, dict)
|
||||
with patch('homeassistant.config.async_process_component_config',
|
||||
mock_psc):
|
||||
with patch("homeassistant.config.async_process_component_config", mock_psc):
|
||||
yield config
|
||||
|
||||
if domain is None:
|
||||
assert len(config) == 1, ('assert_setup_component requires DOMAIN: {}'
|
||||
.format(list(config.keys())))
|
||||
assert len(config) == 1, "assert_setup_component requires DOMAIN: {}".format(
|
||||
list(config.keys())
|
||||
)
|
||||
domain = list(config.keys())[0]
|
||||
|
||||
res = config.get(domain)
|
||||
res_len = 0 if res is None else len(res)
|
||||
assert res_len == count, 'setup_component failed, expected {} got {}: {}' \
|
||||
.format(count, res_len, res)
|
||||
assert res_len == count, "setup_component failed, expected {} got {}: {}".format(
|
||||
count, res_len, res
|
||||
)
|
||||
|
||||
|
||||
def init_recorder_component(hass, add_config=None):
|
||||
"""Initialize the recorder."""
|
||||
config = dict(add_config) if add_config else {}
|
||||
config[recorder.CONF_DB_URL] = 'sqlite://' # In memory DB
|
||||
config[recorder.CONF_DB_URL] = "sqlite://" # In memory DB
|
||||
|
||||
with patch('homeassistant.components.recorder.migration.migrate_schema'):
|
||||
assert setup_component(hass, recorder.DOMAIN,
|
||||
{recorder.DOMAIN: config})
|
||||
with patch("homeassistant.components.recorder.migration.migrate_schema"):
|
||||
assert setup_component(hass, recorder.DOMAIN, {recorder.DOMAIN: config})
|
||||
assert recorder.DOMAIN in hass.config.components
|
||||
_LOGGER.info("In-memory recorder successfully started")
|
||||
|
||||
|
@ -766,14 +808,17 @@ def mock_restore_cache(hass, states):
|
|||
last_states = {}
|
||||
for state in states:
|
||||
restored_state = state.as_dict()
|
||||
restored_state['attributes'] = json.loads(json.dumps(
|
||||
restored_state['attributes'], cls=JSONEncoder))
|
||||
restored_state["attributes"] = json.loads(
|
||||
json.dumps(restored_state["attributes"], cls=JSONEncoder)
|
||||
)
|
||||
last_states[state.entity_id] = restore_state.StoredState(
|
||||
State.from_dict(restored_state), now)
|
||||
State.from_dict(restored_state), now
|
||||
)
|
||||
data.last_states = last_states
|
||||
_LOGGER.debug('Restore cache: %s', data.last_states)
|
||||
assert len(data.last_states) == len(states), \
|
||||
"Duplicate entity_id? {}".format(states)
|
||||
_LOGGER.debug("Restore cache: %s", data.last_states)
|
||||
assert len(data.last_states) == len(states), "Duplicate entity_id? {}".format(
|
||||
states
|
||||
)
|
||||
|
||||
async def get_restore_state_data() -> restore_state.RestoreStateData:
|
||||
return data
|
||||
|
@ -792,6 +837,7 @@ class MockDependency:
|
|||
|
||||
def __enter__(self):
|
||||
"""Start mocking."""
|
||||
|
||||
def resolve(mock, path):
|
||||
"""Resolve a mock."""
|
||||
if not path:
|
||||
|
@ -801,12 +847,12 @@ class MockDependency:
|
|||
|
||||
base = MagicMock()
|
||||
to_mock = {
|
||||
"{}.{}".format(self.root, tom): resolve(base, tom.split('.'))
|
||||
"{}.{}".format(self.root, tom): resolve(base, tom.split("."))
|
||||
for tom in self.submodules
|
||||
}
|
||||
to_mock[self.root] = base
|
||||
|
||||
self.patcher = patch.dict('sys.modules', to_mock)
|
||||
self.patcher = patch.dict("sys.modules", to_mock)
|
||||
self.patcher.start()
|
||||
return base
|
||||
|
||||
|
@ -817,6 +863,7 @@ class MockDependency:
|
|||
|
||||
def __call__(self, func):
|
||||
"""Apply decorator."""
|
||||
|
||||
def run_mocked(*args, **kwargs):
|
||||
"""Run with mocked dependencies."""
|
||||
with self as base:
|
||||
|
@ -833,33 +880,33 @@ class MockEntity(entity.Entity):
|
|||
"""Initialize an entity."""
|
||||
self._values = values
|
||||
|
||||
if 'entity_id' in values:
|
||||
self.entity_id = values['entity_id']
|
||||
if "entity_id" in values:
|
||||
self.entity_id = values["entity_id"]
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Return the name of the entity."""
|
||||
return self._handle('name')
|
||||
return self._handle("name")
|
||||
|
||||
@property
|
||||
def should_poll(self):
|
||||
"""Return the ste of the polling."""
|
||||
return self._handle('should_poll')
|
||||
return self._handle("should_poll")
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return the unique ID of the entity."""
|
||||
return self._handle('unique_id')
|
||||
return self._handle("unique_id")
|
||||
|
||||
@property
|
||||
def available(self):
|
||||
"""Return True if entity is available."""
|
||||
return self._handle('available')
|
||||
return self._handle("available")
|
||||
|
||||
@property
|
||||
def device_info(self):
|
||||
"""Info how it links to a device."""
|
||||
return self._handle('device_info')
|
||||
return self._handle("device_info")
|
||||
|
||||
def _handle(self, attr):
|
||||
"""Return attribute value."""
|
||||
|
@ -890,7 +937,7 @@ def mock_storage(data=None):
|
|||
|
||||
mock_data = data.get(store.key)
|
||||
|
||||
if 'data' not in mock_data or 'version' not in mock_data:
|
||||
if "data" not in mock_data or "version" not in mock_data:
|
||||
_LOGGER.error('Mock data needs "version" and "data"')
|
||||
raise ValueError('Mock data needs "version" and "data"')
|
||||
|
||||
|
@ -898,20 +945,24 @@ def mock_storage(data=None):
|
|||
|
||||
# Route through original load so that we trigger migration
|
||||
loaded = await orig_load(store)
|
||||
_LOGGER.info('Loading data for %s: %s', store.key, loaded)
|
||||
_LOGGER.info("Loading data for %s: %s", store.key, loaded)
|
||||
return loaded
|
||||
|
||||
def mock_write_data(store, path, data_to_write):
|
||||
"""Mock version of write data."""
|
||||
_LOGGER.info('Writing data to %s: %s', store.key, data_to_write)
|
||||
_LOGGER.info("Writing data to %s: %s", store.key, data_to_write)
|
||||
# To ensure that the data can be serialized
|
||||
data[store.key] = json.loads(json.dumps(
|
||||
data_to_write, cls=store._encoder))
|
||||
data[store.key] = json.loads(json.dumps(data_to_write, cls=store._encoder))
|
||||
|
||||
with patch('homeassistant.helpers.storage.Store._async_load',
|
||||
side_effect=mock_async_load, autospec=True), \
|
||||
patch('homeassistant.helpers.storage.Store._write_data',
|
||||
side_effect=mock_write_data, autospec=True):
|
||||
with patch(
|
||||
"homeassistant.helpers.storage.Store._async_load",
|
||||
side_effect=mock_async_load,
|
||||
autospec=True,
|
||||
), patch(
|
||||
"homeassistant.helpers.storage.Store._write_data",
|
||||
side_effect=mock_write_data,
|
||||
autospec=True,
|
||||
):
|
||||
yield data
|
||||
|
||||
|
||||
|
@ -925,19 +976,20 @@ async def flush_store(store):
|
|||
|
||||
async def get_system_health_info(hass, domain):
|
||||
"""Get system health info."""
|
||||
return await hass.data['system_health']['info'][domain](hass)
|
||||
return await hass.data["system_health"]["info"][domain](hass)
|
||||
|
||||
|
||||
def mock_integration(hass, module):
|
||||
"""Mock an integration."""
|
||||
integration = loader.Integration(
|
||||
hass, 'homeassistant.components.{}'.format(module.DOMAIN), None,
|
||||
module.mock_manifest())
|
||||
hass,
|
||||
"homeassistant.components.{}".format(module.DOMAIN),
|
||||
None,
|
||||
module.mock_manifest(),
|
||||
)
|
||||
|
||||
_LOGGER.info("Adding mock integration: %s", module.DOMAIN)
|
||||
hass.data.setdefault(
|
||||
loader.DATA_INTEGRATIONS, {}
|
||||
)[module.DOMAIN] = integration
|
||||
hass.data.setdefault(loader.DATA_INTEGRATIONS, {})[module.DOMAIN] = integration
|
||||
hass.data.setdefault(loader.DATA_COMPONENTS, {})[module.DOMAIN] = module
|
||||
|
||||
|
||||
|
@ -947,7 +999,7 @@ def mock_entity_platform(hass, platform_path, module):
|
|||
platform_path is in form light.hue. Will create platform
|
||||
hue.light.
|
||||
"""
|
||||
domain, platform_name = platform_path.split('.')
|
||||
domain, platform_name = platform_path.split(".")
|
||||
integration_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
|
||||
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue