Use singleton enum for "not set" sentinels (#41990)

* Use singleton enum for "not set" sentinel

https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions

* Remove unused variable
This commit is contained in:
Ville Skyttä 2020-12-19 13:46:27 +02:00 committed by GitHub
parent de04a1ed67
commit 317ed418dd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 139 additions and 131 deletions

View file

@ -1,11 +1,12 @@
"""Preference management for camera component.""" """Preference management for camera component."""
from homeassistant.helpers.typing import UNDEFINED
from .const import DOMAIN, PREF_PRELOAD_STREAM from .const import DOMAIN, PREF_PRELOAD_STREAM
# mypy: allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-defs, no-check-untyped-defs
STORAGE_KEY = DOMAIN STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1 STORAGE_VERSION = 1
_UNDEF = object()
class CameraEntityPreferences: class CameraEntityPreferences:
@ -44,14 +45,14 @@ class CameraPreferences:
self._prefs = prefs self._prefs = prefs
async def async_update( async def async_update(
self, entity_id, *, preload_stream=_UNDEF, stream_options=_UNDEF self, entity_id, *, preload_stream=UNDEFINED, stream_options=UNDEFINED
): ):
"""Update camera preferences.""" """Update camera preferences."""
if not self._prefs.get(entity_id): if not self._prefs.get(entity_id):
self._prefs[entity_id] = {} self._prefs[entity_id] = {}
for key, value in ((PREF_PRELOAD_STREAM, preload_stream),): for key, value in ((PREF_PRELOAD_STREAM, preload_stream),):
if value is not _UNDEF: if value is not UNDEFINED:
self._prefs[entity_id][key] = value self._prefs[entity_id][key] = value
await self._store.async_save(self._prefs) await self._store.async_save(self._prefs)

View file

@ -5,6 +5,7 @@ from typing import List, Optional
from homeassistant.auth.const import GROUP_ID_ADMIN from homeassistant.auth.const import GROUP_ID_ADMIN
from homeassistant.auth.models import User from homeassistant.auth.models import User
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.helpers.typing import UNDEFINED
from homeassistant.util.logging import async_create_catching_coro from homeassistant.util.logging import async_create_catching_coro
from .const import ( from .const import (
@ -36,7 +37,6 @@ from .const import (
STORAGE_KEY = DOMAIN STORAGE_KEY = DOMAIN
STORAGE_VERSION = 1 STORAGE_VERSION = 1
_UNDEF = object()
class CloudPreferences: class CloudPreferences:
@ -74,18 +74,18 @@ class CloudPreferences:
async def async_update( async def async_update(
self, self,
*, *,
google_enabled=_UNDEF, google_enabled=UNDEFINED,
alexa_enabled=_UNDEF, alexa_enabled=UNDEFINED,
remote_enabled=_UNDEF, remote_enabled=UNDEFINED,
google_secure_devices_pin=_UNDEF, google_secure_devices_pin=UNDEFINED,
cloudhooks=_UNDEF, cloudhooks=UNDEFINED,
cloud_user=_UNDEF, cloud_user=UNDEFINED,
google_entity_configs=_UNDEF, google_entity_configs=UNDEFINED,
alexa_entity_configs=_UNDEF, alexa_entity_configs=UNDEFINED,
alexa_report_state=_UNDEF, alexa_report_state=UNDEFINED,
google_report_state=_UNDEF, google_report_state=UNDEFINED,
alexa_default_expose=_UNDEF, alexa_default_expose=UNDEFINED,
google_default_expose=_UNDEF, google_default_expose=UNDEFINED,
): ):
"""Update user preferences.""" """Update user preferences."""
prefs = {**self._prefs} prefs = {**self._prefs}
@ -104,7 +104,7 @@ class CloudPreferences:
(PREF_ALEXA_DEFAULT_EXPOSE, alexa_default_expose), (PREF_ALEXA_DEFAULT_EXPOSE, alexa_default_expose),
(PREF_GOOGLE_DEFAULT_EXPOSE, google_default_expose), (PREF_GOOGLE_DEFAULT_EXPOSE, google_default_expose),
): ):
if value is not _UNDEF: if value is not UNDEFINED:
prefs[key] = value prefs[key] = value
if remote_enabled is True and self._has_local_trusted_network: if remote_enabled is True and self._has_local_trusted_network:
@ -121,10 +121,10 @@ class CloudPreferences:
self, self,
*, *,
entity_id, entity_id,
override_name=_UNDEF, override_name=UNDEFINED,
disable_2fa=_UNDEF, disable_2fa=UNDEFINED,
aliases=_UNDEF, aliases=UNDEFINED,
should_expose=_UNDEF, should_expose=UNDEFINED,
): ):
"""Update config for a Google entity.""" """Update config for a Google entity."""
entities = self.google_entity_configs entities = self.google_entity_configs
@ -137,7 +137,7 @@ class CloudPreferences:
(PREF_ALIASES, aliases), (PREF_ALIASES, aliases),
(PREF_SHOULD_EXPOSE, should_expose), (PREF_SHOULD_EXPOSE, should_expose),
): ):
if value is not _UNDEF: if value is not UNDEFINED:
changes[key] = value changes[key] = value
if not changes: if not changes:
@ -149,7 +149,7 @@ class CloudPreferences:
await self.async_update(google_entity_configs=updated_entities) await self.async_update(google_entity_configs=updated_entities)
async def async_update_alexa_entity_config( async def async_update_alexa_entity_config(
self, *, entity_id, should_expose=_UNDEF self, *, entity_id, should_expose=UNDEFINED
): ):
"""Update config for an Alexa entity.""" """Update config for an Alexa entity."""
entities = self.alexa_entity_configs entities = self.alexa_entity_configs
@ -157,7 +157,7 @@ class CloudPreferences:
changes = {} changes = {}
for key, value in ((PREF_SHOULD_EXPOSE, should_expose),): for key, value in ((PREF_SHOULD_EXPOSE, should_expose),):
if value is not _UNDEF: if value is not UNDEFINED:
changes[key] = value changes[key] = value
if not changes: if not changes:

View file

@ -1,8 +1,8 @@
"""Support for deCONZ devices.""" """Support for deCONZ devices."""
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import _UNDEF
from homeassistant.const import EVENT_HOMEASSISTANT_STOP from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.helpers.typing import UNDEFINED
from .config_flow import get_master_gateway from .config_flow import get_master_gateway
from .const import CONF_BRIDGE_ID, CONF_GROUP_ID_BASE, CONF_MASTER_GATEWAY, DOMAIN from .const import CONF_BRIDGE_ID, CONF_GROUP_ID_BASE, CONF_MASTER_GATEWAY, DOMAIN
@ -39,7 +39,7 @@ async def async_setup_entry(hass, config_entry):
# 0.104 introduced config entry unique id, this makes upgrading possible # 0.104 introduced config entry unique id, this makes upgrading possible
if config_entry.unique_id is None: if config_entry.unique_id is None:
new_data = _UNDEF new_data = UNDEFINED
if CONF_BRIDGE_ID in config_entry.data: if CONF_BRIDGE_ID in config_entry.data:
new_data = dict(config_entry.data) new_data = dict(config_entry.data)
new_data[CONF_GROUP_ID_BASE] = config_entry.data[CONF_BRIDGE_ID] new_data[CONF_GROUP_ID_BASE] = config_entry.data[CONF_BRIDGE_ID]

View file

@ -87,8 +87,6 @@ CONFIG_SCHEMA = vol.Schema(
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
) )
_UNDEF = object()
@bind_hass @bind_hass
async def async_create_person(hass, name, *, user_id=None, device_trackers=None): async def async_create_person(hass, name, *, user_id=None, device_trackers=None):

View file

@ -13,12 +13,12 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError
from homeassistant.helpers import entity_registry from homeassistant.helpers import entity_registry
from homeassistant.helpers.event import Event from homeassistant.helpers.event import Event
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
from homeassistant.setup import async_process_deps_reqs, async_setup_component from homeassistant.setup import async_process_deps_reqs, async_setup_component
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
import homeassistant.util.uuid as uuid_util import homeassistant.util.uuid as uuid_util
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_UNDEF: dict = {}
SOURCE_DISCOVERY = "discovery" SOURCE_DISCOVERY = "discovery"
SOURCE_HASSIO = "hassio" SOURCE_HASSIO = "hassio"
@ -760,12 +760,11 @@ class ConfigEntries:
self, self,
entry: ConfigEntry, entry: ConfigEntry,
*, *,
# pylint: disable=dangerous-default-value # _UNDEFs not modified unique_id: Union[str, dict, None, UndefinedType] = UNDEFINED,
unique_id: Union[str, dict, None] = _UNDEF, title: Union[str, dict, UndefinedType] = UNDEFINED,
title: Union[str, dict] = _UNDEF, data: Union[dict, UndefinedType] = UNDEFINED,
data: dict = _UNDEF, options: Union[dict, UndefinedType] = UNDEFINED,
options: dict = _UNDEF, system_options: Union[dict, UndefinedType] = UNDEFINED,
system_options: dict = _UNDEF,
) -> bool: ) -> bool:
"""Update a config entry. """Update a config entry.
@ -777,24 +776,24 @@ class ConfigEntries:
""" """
changed = False changed = False
if unique_id is not _UNDEF and entry.unique_id != unique_id: if unique_id is not UNDEFINED and entry.unique_id != unique_id:
changed = True changed = True
entry.unique_id = cast(Optional[str], unique_id) entry.unique_id = cast(Optional[str], unique_id)
if title is not _UNDEF and entry.title != title: if title is not UNDEFINED and entry.title != title:
changed = True changed = True
entry.title = cast(str, title) entry.title = cast(str, title)
if data is not _UNDEF and entry.data != data: # type: ignore if data is not UNDEFINED and entry.data != data: # type: ignore
changed = True changed = True
entry.data = MappingProxyType(data) entry.data = MappingProxyType(data)
if options is not _UNDEF and entry.options != options: # type: ignore if options is not UNDEFINED and entry.options != options: # type: ignore
changed = True changed = True
entry.options = MappingProxyType(options) entry.options = MappingProxyType(options)
if ( if (
system_options is not _UNDEF system_options is not UNDEFINED
and entry.system_options.as_dict() != system_options and entry.system_options.as_dict() != system_options
): ):
changed = True changed = True

View file

@ -89,7 +89,7 @@ block_async_io.enable()
fix_threading_exception_logging() fix_threading_exception_logging()
T = TypeVar("T") T = TypeVar("T")
_UNDEF: dict = {} _UNDEF: dict = {} # Internal; not helpers.typing.UNDEFINED due to circular dependency
# pylint: disable=invalid-name # pylint: disable=invalid-name
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable)
CALLBACK_TYPE = Callable[[], None] CALLBACK_TYPE = Callable[[], None]

View file

@ -11,7 +11,7 @@ import homeassistant.util.uuid as uuid_util
from .debounce import Debouncer from .debounce import Debouncer
from .singleton import singleton from .singleton import singleton
from .typing import HomeAssistantType from .typing import UNDEFINED, HomeAssistantType
if TYPE_CHECKING: if TYPE_CHECKING:
from . import entity_registry from . import entity_registry
@ -19,7 +19,6 @@ if TYPE_CHECKING:
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_UNDEF = object()
DATA_REGISTRY = "device_registry" DATA_REGISTRY = "device_registry"
EVENT_DEVICE_REGISTRY_UPDATED = "device_registry_updated" EVENT_DEVICE_REGISTRY_UPDATED = "device_registry_updated"
@ -224,17 +223,17 @@ class DeviceRegistry:
config_entry_id, config_entry_id,
connections=None, connections=None,
identifiers=None, identifiers=None,
manufacturer=_UNDEF, manufacturer=UNDEFINED,
model=_UNDEF, model=UNDEFINED,
name=_UNDEF, name=UNDEFINED,
default_manufacturer=_UNDEF, default_manufacturer=UNDEFINED,
default_model=_UNDEF, default_model=UNDEFINED,
default_name=_UNDEF, default_name=UNDEFINED,
sw_version=_UNDEF, sw_version=UNDEFINED,
entry_type=_UNDEF, entry_type=UNDEFINED,
via_device=None, via_device=None,
# To disable a device if it gets created # To disable a device if it gets created
disabled_by=_UNDEF, disabled_by=UNDEFINED,
): ):
"""Get device. Create if it doesn't exist.""" """Get device. Create if it doesn't exist."""
if not identifiers and not connections: if not identifiers and not connections:
@ -261,27 +260,27 @@ class DeviceRegistry:
) )
self._add_device(device) self._add_device(device)
if default_manufacturer is not _UNDEF and device.manufacturer is None: if default_manufacturer is not UNDEFINED and device.manufacturer is None:
manufacturer = default_manufacturer manufacturer = default_manufacturer
if default_model is not _UNDEF and device.model is None: if default_model is not UNDEFINED and device.model is None:
model = default_model model = default_model
if default_name is not _UNDEF and device.name is None: if default_name is not UNDEFINED and device.name is None:
name = default_name name = default_name
if via_device is not None: if via_device is not None:
via = self.async_get_device({via_device}, set()) via = self.async_get_device({via_device}, set())
via_device_id = via.id if via else _UNDEF via_device_id = via.id if via else UNDEFINED
else: else:
via_device_id = _UNDEF via_device_id = UNDEFINED
return self._async_update_device( return self._async_update_device(
device.id, device.id,
add_config_entry_id=config_entry_id, add_config_entry_id=config_entry_id,
via_device_id=via_device_id, via_device_id=via_device_id,
merge_connections=connections or _UNDEF, merge_connections=connections or UNDEFINED,
merge_identifiers=identifiers or _UNDEF, merge_identifiers=identifiers or UNDEFINED,
manufacturer=manufacturer, manufacturer=manufacturer,
model=model, model=model,
name=name, name=name,
@ -295,16 +294,16 @@ class DeviceRegistry:
self, self,
device_id, device_id,
*, *,
area_id=_UNDEF, area_id=UNDEFINED,
manufacturer=_UNDEF, manufacturer=UNDEFINED,
model=_UNDEF, model=UNDEFINED,
name=_UNDEF, name=UNDEFINED,
name_by_user=_UNDEF, name_by_user=UNDEFINED,
new_identifiers=_UNDEF, new_identifiers=UNDEFINED,
sw_version=_UNDEF, sw_version=UNDEFINED,
via_device_id=_UNDEF, via_device_id=UNDEFINED,
remove_config_entry_id=_UNDEF, remove_config_entry_id=UNDEFINED,
disabled_by=_UNDEF, disabled_by=UNDEFINED,
): ):
"""Update properties of a device.""" """Update properties of a device."""
return self._async_update_device( return self._async_update_device(
@ -326,20 +325,20 @@ class DeviceRegistry:
self, self,
device_id, device_id,
*, *,
add_config_entry_id=_UNDEF, add_config_entry_id=UNDEFINED,
remove_config_entry_id=_UNDEF, remove_config_entry_id=UNDEFINED,
merge_connections=_UNDEF, merge_connections=UNDEFINED,
merge_identifiers=_UNDEF, merge_identifiers=UNDEFINED,
new_identifiers=_UNDEF, new_identifiers=UNDEFINED,
manufacturer=_UNDEF, manufacturer=UNDEFINED,
model=_UNDEF, model=UNDEFINED,
name=_UNDEF, name=UNDEFINED,
sw_version=_UNDEF, sw_version=UNDEFINED,
entry_type=_UNDEF, entry_type=UNDEFINED,
via_device_id=_UNDEF, via_device_id=UNDEFINED,
area_id=_UNDEF, area_id=UNDEFINED,
name_by_user=_UNDEF, name_by_user=UNDEFINED,
disabled_by=_UNDEF, disabled_by=UNDEFINED,
): ):
"""Update device attributes.""" """Update device attributes."""
old = self.devices[device_id] old = self.devices[device_id]
@ -349,13 +348,13 @@ class DeviceRegistry:
config_entries = old.config_entries config_entries = old.config_entries
if ( if (
add_config_entry_id is not _UNDEF add_config_entry_id is not UNDEFINED
and add_config_entry_id not in old.config_entries and add_config_entry_id not in old.config_entries
): ):
config_entries = old.config_entries | {add_config_entry_id} config_entries = old.config_entries | {add_config_entry_id}
if ( if (
remove_config_entry_id is not _UNDEF remove_config_entry_id is not UNDEFINED
and remove_config_entry_id in config_entries and remove_config_entry_id in config_entries
): ):
if config_entries == {remove_config_entry_id}: if config_entries == {remove_config_entry_id}:
@ -373,10 +372,10 @@ class DeviceRegistry:
): ):
old_value = getattr(old, attr_name) old_value = getattr(old, attr_name)
# If not undefined, check if `value` contains new items. # If not undefined, check if `value` contains new items.
if value is not _UNDEF and not value.issubset(old_value): if value is not UNDEFINED and not value.issubset(old_value):
changes[attr_name] = old_value | value changes[attr_name] = old_value | value
if new_identifiers is not _UNDEF: if new_identifiers is not UNDEFINED:
changes["identifiers"] = new_identifiers changes["identifiers"] = new_identifiers
for attr_name, value in ( for attr_name, value in (
@ -388,13 +387,13 @@ class DeviceRegistry:
("via_device_id", via_device_id), ("via_device_id", via_device_id),
("disabled_by", disabled_by), ("disabled_by", disabled_by),
): ):
if value is not _UNDEF and value != getattr(old, attr_name): if value is not UNDEFINED and value != getattr(old, attr_name):
changes[attr_name] = value changes[attr_name] = value
if area_id is not _UNDEF and area_id != old.area_id: if area_id is not UNDEFINED and area_id != old.area_id:
changes["area_id"] = area_id changes["area_id"] = area_id
if name_by_user is not _UNDEF and name_by_user != old.name_by_user: if name_by_user is not UNDEFINED and name_by_user != old.name_by_user:
changes["name_by_user"] = name_by_user changes["name_by_user"] = name_by_user
if old.is_new: if old.is_new:

View file

@ -39,7 +39,7 @@ from homeassistant.util import slugify
from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml import load_yaml
from .singleton import singleton from .singleton import singleton
from .typing import HomeAssistantType from .typing import UNDEFINED, HomeAssistantType
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.config_entries import ConfigEntry # noqa: F401 from homeassistant.config_entries import ConfigEntry # noqa: F401
@ -51,7 +51,6 @@ DATA_REGISTRY = "entity_registry"
EVENT_ENTITY_REGISTRY_UPDATED = "entity_registry_updated" EVENT_ENTITY_REGISTRY_UPDATED = "entity_registry_updated"
SAVE_DELAY = 10 SAVE_DELAY = 10
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_UNDEF = object()
DISABLED_CONFIG_ENTRY = "config_entry" DISABLED_CONFIG_ENTRY = "config_entry"
DISABLED_DEVICE = "device" DISABLED_DEVICE = "device"
DISABLED_HASS = "hass" DISABLED_HASS = "hass"
@ -225,15 +224,15 @@ class EntityRegistry:
if entity_id: if entity_id:
return self._async_update_entity( # type: ignore return self._async_update_entity( # type: ignore
entity_id, entity_id,
config_entry_id=config_entry_id or _UNDEF, config_entry_id=config_entry_id or UNDEFINED,
device_id=device_id or _UNDEF, device_id=device_id or UNDEFINED,
area_id=area_id or _UNDEF, area_id=area_id or UNDEFINED,
capabilities=capabilities or _UNDEF, capabilities=capabilities or UNDEFINED,
supported_features=supported_features or _UNDEF, supported_features=supported_features or UNDEFINED,
device_class=device_class or _UNDEF, device_class=device_class or UNDEFINED,
unit_of_measurement=unit_of_measurement or _UNDEF, unit_of_measurement=unit_of_measurement or UNDEFINED,
original_name=original_name or _UNDEF, original_name=original_name or UNDEFINED,
original_icon=original_icon or _UNDEF, original_icon=original_icon or UNDEFINED,
# When we changed our slugify algorithm, we invalidated some # When we changed our slugify algorithm, we invalidated some
# stored entity IDs with either a __ or ending in _. # stored entity IDs with either a __ or ending in _.
# Fix introduced in 0.86 (Jan 23, 2019). Next line can be # Fix introduced in 0.86 (Jan 23, 2019). Next line can be
@ -333,12 +332,12 @@ class EntityRegistry:
self, self,
entity_id, entity_id,
*, *,
name=_UNDEF, name=UNDEFINED,
icon=_UNDEF, icon=UNDEFINED,
area_id=_UNDEF, area_id=UNDEFINED,
new_entity_id=_UNDEF, new_entity_id=UNDEFINED,
new_unique_id=_UNDEF, new_unique_id=UNDEFINED,
disabled_by=_UNDEF, disabled_by=UNDEFINED,
): ):
"""Update properties of an entity.""" """Update properties of an entity."""
return cast( # cast until we have _async_update_entity type hinted return cast( # cast until we have _async_update_entity type hinted
@ -359,20 +358,20 @@ class EntityRegistry:
self, self,
entity_id, entity_id,
*, *,
name=_UNDEF, name=UNDEFINED,
icon=_UNDEF, icon=UNDEFINED,
config_entry_id=_UNDEF, config_entry_id=UNDEFINED,
new_entity_id=_UNDEF, new_entity_id=UNDEFINED,
device_id=_UNDEF, device_id=UNDEFINED,
area_id=_UNDEF, area_id=UNDEFINED,
new_unique_id=_UNDEF, new_unique_id=UNDEFINED,
disabled_by=_UNDEF, disabled_by=UNDEFINED,
capabilities=_UNDEF, capabilities=UNDEFINED,
supported_features=_UNDEF, supported_features=UNDEFINED,
device_class=_UNDEF, device_class=UNDEFINED,
unit_of_measurement=_UNDEF, unit_of_measurement=UNDEFINED,
original_name=_UNDEF, original_name=UNDEFINED,
original_icon=_UNDEF, original_icon=UNDEFINED,
): ):
"""Private facing update properties method.""" """Private facing update properties method."""
old = self.entities[entity_id] old = self.entities[entity_id]
@ -393,10 +392,10 @@ class EntityRegistry:
("original_name", original_name), ("original_name", original_name),
("original_icon", original_icon), ("original_icon", original_icon),
): ):
if value is not _UNDEF and value != getattr(old, attr_name): if value is not UNDEFINED and value != getattr(old, attr_name):
changes[attr_name] = value changes[attr_name] = value
if new_entity_id is not _UNDEF and new_entity_id != old.entity_id: if new_entity_id is not UNDEFINED and new_entity_id != old.entity_id:
if self.async_is_registered(new_entity_id): if self.async_is_registered(new_entity_id):
raise ValueError("Entity is already registered") raise ValueError("Entity is already registered")
@ -409,7 +408,7 @@ class EntityRegistry:
self.entities.pop(entity_id) self.entities.pop(entity_id)
entity_id = changes["entity_id"] = new_entity_id entity_id = changes["entity_id"] = new_entity_id
if new_unique_id is not _UNDEF: if new_unique_id is not UNDEFINED:
conflict_entity_id = self.async_get_entity_id( conflict_entity_id = self.async_get_entity_id(
old.domain, old.platform, new_unique_id old.domain, old.platform, new_unique_id
) )

View file

@ -1,4 +1,5 @@
"""Typing Helpers for Home Assistant.""" """Typing Helpers for Home Assistant."""
from enum import Enum
from typing import Any, Dict, Mapping, Optional, Tuple, Union from typing import Any, Dict, Mapping, Optional, Tuple, Union
import homeassistant.core import homeassistant.core
@ -16,3 +17,12 @@ TemplateVarsType = Optional[Mapping[str, Any]]
# Custom type for recorder Queries # Custom type for recorder Queries
QueryType = Any QueryType = Any
class UndefinedType(Enum):
"""Singleton type for use with not set sentinel values."""
_singleton = 0
UNDEFINED = UndefinedType._singleton # pylint: disable=protected-access

View file

@ -48,7 +48,7 @@ CUSTOM_WARNING = (
"cause stability problems, be sure to disable it if you " "cause stability problems, be sure to disable it if you "
"experience issues with Home Assistant." "experience issues with Home Assistant."
) )
_UNDEF = object() _UNDEF = object() # Internal; not helpers.typing.UNDEFINED due to circular dependency
MAX_LOAD_CONCURRENTLY = 4 MAX_LOAD_CONCURRENTLY = 4

View file

@ -5,6 +5,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import UNDEFINED, UndefinedType
from homeassistant.loader import Integration, IntegrationNotFound, async_get_integration from homeassistant.loader import Integration, IntegrationNotFound, async_get_integration
import homeassistant.util.package as pkg_util import homeassistant.util.package as pkg_util
@ -17,7 +18,6 @@ DISCOVERY_INTEGRATIONS: Dict[str, Iterable[str]] = {
"ssdp": ("ssdp",), "ssdp": ("ssdp",),
"zeroconf": ("zeroconf", "homekit"), "zeroconf": ("zeroconf", "homekit"),
} }
_UNDEF = object()
class RequirementsNotFound(HomeAssistantError): class RequirementsNotFound(HomeAssistantError):
@ -53,19 +53,21 @@ async def async_get_integration_with_requirements(
if cache is None: if cache is None:
cache = hass.data[DATA_INTEGRATIONS_WITH_REQS] = {} cache = hass.data[DATA_INTEGRATIONS_WITH_REQS] = {}
int_or_evt: Union[Integration, asyncio.Event, None] = cache.get(domain, _UNDEF) int_or_evt: Union[Integration, asyncio.Event, None, UndefinedType] = cache.get(
domain, UNDEFINED
)
if isinstance(int_or_evt, asyncio.Event): if isinstance(int_or_evt, asyncio.Event):
await int_or_evt.wait() await int_or_evt.wait()
int_or_evt = cache.get(domain, _UNDEF) int_or_evt = cache.get(domain, UNDEFINED)
# When we have waited and it's _UNDEF, it doesn't exist # When we have waited and it's UNDEFINED, it doesn't exist
# We don't cache that it doesn't exist, or else people can't fix it # We don't cache that it doesn't exist, or else people can't fix it
# and then restart, because their config will never be valid. # and then restart, because their config will never be valid.
if int_or_evt is _UNDEF: if int_or_evt is UNDEFINED:
raise IntegrationNotFound(domain) raise IntegrationNotFound(domain)
if int_or_evt is not _UNDEF: if int_or_evt is not UNDEFINED:
return cast(Integration, int_or_evt) return cast(Integration, int_or_evt)
event = cache[domain] = asyncio.Event() event = cache[domain] = asyncio.Event()