Use assignment expressions 03 (#57710)

This commit is contained in:
Marc Mueller 2021-10-17 20:08:11 +02:00 committed by GitHub
parent 2a8eaf0e0f
commit 238b488642
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 49 additions and 102 deletions

View file

@ -192,8 +192,7 @@ def _async_register_clientsession_shutdown(
EVENT_HOMEASSISTANT_CLOSE, _async_close_websession EVENT_HOMEASSISTANT_CLOSE, _async_close_websession
) )
config_entry = config_entries.current_entry.get() if not (config_entry := config_entries.current_entry.get()):
if not config_entry:
return return
config_entry.async_on_unload(unsub) config_entry.async_on_unload(unsub)

View file

@ -328,9 +328,8 @@ def async_numeric_state( # noqa: C901
if isinstance(entity, str): if isinstance(entity, str):
entity_id = entity entity_id = entity
entity = hass.states.get(entity)
if entity is None: if (entity := hass.states.get(entity)) is None:
raise ConditionErrorMessage("numeric_state", f"unknown entity {entity_id}") raise ConditionErrorMessage("numeric_state", f"unknown entity {entity_id}")
else: else:
entity_id = entity.entity_id entity_id = entity.entity_id
@ -371,8 +370,7 @@ def async_numeric_state( # noqa: C901
if below is not None: if below is not None:
if isinstance(below, str): if isinstance(below, str):
below_entity = hass.states.get(below) if not (below_entity := hass.states.get(below)):
if not below_entity:
raise ConditionErrorMessage( raise ConditionErrorMessage(
"numeric_state", f"unknown 'below' entity {below}" "numeric_state", f"unknown 'below' entity {below}"
) )
@ -400,8 +398,7 @@ def async_numeric_state( # noqa: C901
if above is not None: if above is not None:
if isinstance(above, str): if isinstance(above, str):
above_entity = hass.states.get(above) if not (above_entity := hass.states.get(above)):
if not above_entity:
raise ConditionErrorMessage( raise ConditionErrorMessage(
"numeric_state", f"unknown 'above' entity {above}" "numeric_state", f"unknown 'above' entity {above}"
) )
@ -497,9 +494,8 @@ def state(
if isinstance(entity, str): if isinstance(entity, str):
entity_id = entity entity_id = entity
entity = hass.states.get(entity)
if entity is None: if (entity := hass.states.get(entity)) is None:
raise ConditionErrorMessage("state", f"unknown entity {entity_id}") raise ConditionErrorMessage("state", f"unknown entity {entity_id}")
else: else:
entity_id = entity.entity_id entity_id = entity.entity_id
@ -526,8 +522,7 @@ def state(
isinstance(req_state_value, str) isinstance(req_state_value, str)
and INPUT_ENTITY_ID.match(req_state_value) is not None and INPUT_ENTITY_ID.match(req_state_value) is not None
): ):
state_entity = hass.states.get(req_state_value) if not (state_entity := hass.states.get(req_state_value)):
if not state_entity:
raise ConditionErrorMessage( raise ConditionErrorMessage(
"state", f"the 'state' entity {req_state_value} is unavailable" "state", f"the 'state' entity {req_state_value} is unavailable"
) )
@ -738,8 +733,7 @@ def time(
if after is None: if after is None:
after = dt_util.dt.time(0) after = dt_util.dt.time(0)
elif isinstance(after, str): elif isinstance(after, str):
after_entity = hass.states.get(after) if not (after_entity := hass.states.get(after)):
if not after_entity:
raise ConditionErrorMessage("time", f"unknown 'after' entity {after}") raise ConditionErrorMessage("time", f"unknown 'after' entity {after}")
if after_entity.domain == "input_datetime": if after_entity.domain == "input_datetime":
after = dt_util.dt.time( after = dt_util.dt.time(
@ -763,8 +757,7 @@ def time(
if before is None: if before is None:
before = dt_util.dt.time(23, 59, 59, 999999) before = dt_util.dt.time(23, 59, 59, 999999)
elif isinstance(before, str): elif isinstance(before, str):
before_entity = hass.states.get(before) if not (before_entity := hass.states.get(before)):
if not before_entity:
raise ConditionErrorMessage("time", f"unknown 'before' entity {before}") raise ConditionErrorMessage("time", f"unknown 'before' entity {before}")
if before_entity.domain == "input_datetime": if before_entity.domain == "input_datetime":
before = dt_util.dt.time( before = dt_util.dt.time(
@ -840,9 +833,8 @@ def zone(
if isinstance(zone_ent, str): if isinstance(zone_ent, str):
zone_ent_id = zone_ent zone_ent_id = zone_ent
zone_ent = hass.states.get(zone_ent)
if zone_ent is None: if (zone_ent := hass.states.get(zone_ent)) is None:
raise ConditionErrorMessage("zone", f"unknown zone {zone_ent_id}") raise ConditionErrorMessage("zone", f"unknown zone {zone_ent_id}")
if entity is None: if entity is None:
@ -850,9 +842,8 @@ def zone(
if isinstance(entity, str): if isinstance(entity, str):
entity_id = entity entity_id = entity
entity = hass.states.get(entity)
if entity is None: if (entity := hass.states.get(entity)) is None:
raise ConditionErrorMessage("zone", f"unknown entity {entity_id}") raise ConditionErrorMessage("zone", f"unknown entity {entity_id}")
else: else:
entity_id = entity.entity_id entity_id = entity.entity_id
@ -1029,9 +1020,7 @@ def async_extract_devices(config: ConfigType | Template) -> set[str]:
if condition != "device": if condition != "device":
continue continue
device_id = config.get(CONF_DEVICE_ID) if (device_id := config.get(CONF_DEVICE_ID)) is not None:
if device_id is not None:
referenced.add(device_id) referenced.add(device_id)
return referenced return referenced

View file

@ -129,14 +129,10 @@ class LocalOAuth2Implementation(AbstractOAuth2Implementation):
@property @property
def redirect_uri(self) -> str: def redirect_uri(self) -> str:
"""Return the redirect uri.""" """Return the redirect uri."""
req = http.current_request.get() if (req := http.current_request.get()) is None:
if req is None:
raise RuntimeError("No current request in context") raise RuntimeError("No current request in context")
ha_host = req.headers.get(HEADER_FRONTEND_BASE) if (ha_host := req.headers.get(HEADER_FRONTEND_BASE)) is None:
if ha_host is None:
raise RuntimeError("No header in request") raise RuntimeError("No header in request")
return f"{ha_host}{AUTH_CALLBACK_PATH}" return f"{ha_host}{AUTH_CALLBACK_PATH}"
@ -501,9 +497,7 @@ async def async_oauth2_request(
@callback @callback
def _encode_jwt(hass: HomeAssistant, data: dict) -> str: def _encode_jwt(hass: HomeAssistant, data: dict) -> str:
"""JWT encode data.""" """JWT encode data."""
secret = hass.data.get(DATA_JWT_SECRET) if (secret := hass.data.get(DATA_JWT_SECRET)) is None:
if secret is None:
secret = hass.data[DATA_JWT_SECRET] = secrets.token_hex() secret = hass.data[DATA_JWT_SECRET] = secrets.token_hex()
return jwt.encode(data, secret, algorithm="HS256") return jwt.encode(data, secret, algorithm="HS256")

View file

@ -38,8 +38,7 @@ class _BaseFlowManagerView(HomeAssistantView):
data = result.copy() data = result.copy()
schema = data["data_schema"] if (schema := data["data_schema"]) is None:
if schema is None:
data["data_schema"] = [] data["data_schema"] = []
else: else:
data["data_schema"] = voluptuous_serialize.convert( data["data_schema"] = voluptuous_serialize.convert(

View file

@ -111,9 +111,7 @@ def async_listen_platform(
async def discovery_platform_listener(discovered: DiscoveryDict) -> None: async def discovery_platform_listener(discovered: DiscoveryDict) -> None:
"""Listen for platform discovery events.""" """Listen for platform discovery events."""
platform = discovered["platform"] if not (platform := discovered["platform"]):
if not platform:
return return
task = hass.async_run_hass_job(job, platform, discovered.get("discovered")) task = hass.async_run_hass_job(job, platform, discovered.get("discovered"))

View file

@ -727,8 +727,7 @@ current_platform: ContextVar[EntityPlatform | None] = ContextVar(
@callback @callback
def async_get_current_platform() -> EntityPlatform: def async_get_current_platform() -> EntityPlatform:
"""Get the current platform from context.""" """Get the current platform from context."""
platform = current_platform.get() if (platform := current_platform.get()) is None:
if platform is None:
raise RuntimeError("Cannot get non-set current platform") raise RuntimeError("Cannot get non-set current platform")
return platform return platform

View file

@ -33,8 +33,7 @@ SPEECH_TYPE_SSML = "ssml"
@bind_hass @bind_hass
def async_register(hass: HomeAssistant, handler: IntentHandler) -> None: def async_register(hass: HomeAssistant, handler: IntentHandler) -> None:
"""Register an intent with Home Assistant.""" """Register an intent with Home Assistant."""
intents = hass.data.get(DATA_KEY) if (intents := hass.data.get(DATA_KEY)) is None:
if intents is None:
intents = hass.data[DATA_KEY] = {} intents = hass.data[DATA_KEY] = {}
assert handler.intent_type is not None, "intent_type cannot be None" assert handler.intent_type is not None, "intent_type cannot be None"

View file

@ -51,9 +51,7 @@ def find_coordinates(
hass: HomeAssistant, entity_id: str, recursion_history: list | None = None hass: HomeAssistant, entity_id: str, recursion_history: list | None = None
) -> str | None: ) -> str | None:
"""Find the gps coordinates of the entity in the form of '90.000,180.000'.""" """Find the gps coordinates of the entity in the form of '90.000,180.000'."""
entity_state = hass.states.get(entity_id) if (entity_state := hass.states.get(entity_id)) is None:
if entity_state is None:
_LOGGER.error("Unable to find entity %s", entity_id) _LOGGER.error("Unable to find entity %s", entity_id)
return None return None

View file

@ -118,8 +118,7 @@ def get_url(
def _get_request_host() -> str | None: def _get_request_host() -> str | None:
"""Get the host address of the current request.""" """Get the host address of the current request."""
request = http.current_request.get() if (request := http.current_request.get()) is None:
if request is None:
raise NoURLAvailableError raise NoURLAvailableError
return yarl.URL(request.url).host return yarl.URL(request.url).host

View file

@ -78,8 +78,7 @@ class KeyedRateLimit:
if rate_limit is None: if rate_limit is None:
return None return None
last_triggered = self._last_triggered.get(key) if not (last_triggered := self._last_triggered.get(key)):
if not last_triggered:
return None return None
next_call_time = last_triggered + rate_limit next_call_time = last_triggered + rate_limit

View file

@ -953,8 +953,7 @@ class Script:
variables: ScriptVariables | None = None, variables: ScriptVariables | None = None,
) -> None: ) -> None:
"""Initialize the script.""" """Initialize the script."""
all_scripts = hass.data.get(DATA_SCRIPTS) if not (all_scripts := hass.data.get(DATA_SCRIPTS)):
if not all_scripts:
all_scripts = hass.data[DATA_SCRIPTS] = [] all_scripts = hass.data[DATA_SCRIPTS] = []
hass.bus.async_listen_once( hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass) EVENT_HOMEASSISTANT_STOP, partial(_async_stop_scripts_at_shutdown, hass)
@ -1273,8 +1272,7 @@ class Script:
config_cache_key = config.template config_cache_key = config.template
else: else:
config_cache_key = frozenset((k, str(v)) for k, v in config.items()) config_cache_key = frozenset((k, str(v)) for k, v in config.items())
cond = self._config_cache.get(config_cache_key) if not (cond := self._config_cache.get(config_cache_key)):
if not cond:
cond = await condition.async_from_config(self._hass, config, False) cond = await condition.async_from_config(self._hass, config, False)
self._config_cache[config_cache_key] = cond self._config_cache[config_cache_key] = cond
return cond return cond
@ -1297,8 +1295,7 @@ class Script:
return sub_script return sub_script
def _get_repeat_script(self, step: int) -> Script: def _get_repeat_script(self, step: int) -> Script:
sub_script = self._repeat_script.get(step) if not (sub_script := self._repeat_script.get(step)):
if not sub_script:
sub_script = self._prep_repeat_script(step) sub_script = self._prep_repeat_script(step)
self._repeat_script[step] = sub_script self._repeat_script[step] = sub_script
return sub_script return sub_script
@ -1351,8 +1348,7 @@ class Script:
return {"choices": choices, "default": default_script} return {"choices": choices, "default": default_script}
async def _async_get_choose_data(self, step: int) -> _ChooseData: async def _async_get_choose_data(self, step: int) -> _ChooseData:
choose_data = self._choose_data.get(step) if not (choose_data := self._choose_data.get(step)):
if not choose_data:
choose_data = await self._async_prep_choose_data(step) choose_data = await self._async_prep_choose_data(step)
self._choose_data[step] = choose_data self._choose_data[step] = choose_data
return choose_data return choose_data

View file

@ -22,9 +22,7 @@ def validate_selector(config: Any) -> dict:
selector_type = list(config)[0] selector_type = list(config)[0]
selector_class = SELECTORS.get(selector_type) if (selector_class := SELECTORS.get(selector_type)) is None:
if selector_class is None:
raise vol.Invalid(f"Unknown selector type {selector_type} found") raise vol.Invalid(f"Unknown selector type {selector_type} found")
# Selectors can be empty # Selectors can be empty

View file

@ -396,10 +396,11 @@ async def async_extract_config_entry_ids(
# Some devices may have no entities # Some devices may have no entities
for device_id in referenced.referenced_devices: for device_id in referenced.referenced_devices:
if device_id in dev_reg.devices: if (
device = dev_reg.async_get(device_id) device_id in dev_reg.devices
if device is not None: and (device := dev_reg.async_get(device_id)) is not None
config_entry_ids.update(device.config_entries) ):
config_entry_ids.update(device.config_entries)
for entity_id in referenced.referenced | referenced.indirectly_referenced: for entity_id in referenced.referenced | referenced.indirectly_referenced:
entry = ent_reg.async_get(entity_id) entry = ent_reg.async_get(entity_id)

View file

@ -813,8 +813,7 @@ class TemplateState(State):
def _collect_state(hass: HomeAssistant, entity_id: str) -> None: def _collect_state(hass: HomeAssistant, entity_id: str) -> None:
entity_collect = hass.data.get(_RENDER_INFO) if (entity_collect := hass.data.get(_RENDER_INFO)) is not None:
if entity_collect is not None:
entity_collect.entities.add(entity_id) entity_collect.entities.add(entity_id)
@ -1188,8 +1187,7 @@ def state_attr(hass: HomeAssistant, entity_id: str, name: str) -> Any:
def now(hass: HomeAssistant) -> datetime: def now(hass: HomeAssistant) -> datetime:
"""Record fetching now.""" """Record fetching now."""
render_info = hass.data.get(_RENDER_INFO) if (render_info := hass.data.get(_RENDER_INFO)) is not None:
if render_info is not None:
render_info.has_time = True render_info.has_time = True
return dt_util.now() return dt_util.now()
@ -1197,8 +1195,7 @@ def now(hass: HomeAssistant) -> datetime:
def utcnow(hass: HomeAssistant) -> datetime: def utcnow(hass: HomeAssistant) -> datetime:
"""Record fetching utcnow.""" """Record fetching utcnow."""
render_info = hass.data.get(_RENDER_INFO) if (render_info := hass.data.get(_RENDER_INFO)) is not None:
if render_info is not None:
render_info.has_time = True render_info.has_time = True
return dt_util.utcnow() return dt_util.utcnow()
@ -1843,9 +1840,7 @@ class TemplateEnvironment(ImmutableSandboxedEnvironment):
# any instance of this. # any instance of this.
return super().compile(source, name, filename, raw, defer_init) return super().compile(source, name, filename, raw, defer_init)
cached = self.template_cache.get(source) if (cached := self.template_cache.get(source)) is None:
if cached is None:
cached = self.template_cache[source] = super().compile(source) cached = self.template_cache[source] = super().compile(source)
return cached return cached

View file

@ -113,8 +113,7 @@ def trace_id_get() -> tuple[tuple[str, str], str] | None:
def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None: def trace_stack_push(trace_stack_var: ContextVar, node: Any) -> None:
"""Push an element to the top of a trace stack.""" """Push an element to the top of a trace stack."""
trace_stack = trace_stack_var.get() if (trace_stack := trace_stack_var.get()) is None:
if trace_stack is None:
trace_stack = [] trace_stack = []
trace_stack_var.set(trace_stack) trace_stack_var.set(trace_stack)
trace_stack.append(node) trace_stack.append(node)
@ -149,8 +148,7 @@ def trace_path_pop(count: int) -> None:
def trace_path_get() -> str: def trace_path_get() -> str:
"""Return a string representing the current location in the config tree.""" """Return a string representing the current location in the config tree."""
path = trace_path_stack_cv.get() if not (path := trace_path_stack_cv.get()):
if not path:
return "" return ""
return "/".join(path) return "/".join(path)
@ -160,12 +158,10 @@ def trace_append_element(
maxlen: int | None = None, maxlen: int | None = None,
) -> None: ) -> None:
"""Append a TraceElement to trace[path].""" """Append a TraceElement to trace[path]."""
path = trace_element.path if (trace := trace_cv.get()) is None:
trace = trace_cv.get()
if trace is None:
trace = {} trace = {}
trace_cv.set(trace) trace_cv.set(trace)
if path not in trace: if (path := trace_element.path) not in trace:
trace[path] = deque(maxlen=maxlen) trace[path] = deque(maxlen=maxlen)
trace[path].append(trace_element) trace[path].append(trace_element)
@ -213,16 +209,14 @@ class StopReason:
def script_execution_set(reason: str) -> None: def script_execution_set(reason: str) -> None:
"""Set stop reason.""" """Set stop reason."""
data = script_execution_cv.get() if (data := script_execution_cv.get()) is None:
if data is None:
return return
data.script_execution = reason data.script_execution = reason
def script_execution_get() -> str | None: def script_execution_get() -> str | None:
"""Return the current trace.""" """Return the current trace."""
data = script_execution_cv.get() if (data := script_execution_cv.get()) is None:
if data is None:
return None return None
return data.script_execution return data.script_execution

View file

@ -146,9 +146,7 @@ async def async_get_custom_components(
hass: HomeAssistant, hass: HomeAssistant,
) -> dict[str, Integration]: ) -> dict[str, Integration]:
"""Return cached list of custom integrations.""" """Return cached list of custom integrations."""
reg_or_evt = hass.data.get(DATA_CUSTOM_COMPONENTS) if (reg_or_evt := hass.data.get(DATA_CUSTOM_COMPONENTS)) is None:
if reg_or_evt is None:
evt = hass.data[DATA_CUSTOM_COMPONENTS] = asyncio.Event() evt = hass.data[DATA_CUSTOM_COMPONENTS] = asyncio.Event()
reg = await _async_get_custom_components(hass) reg = await _async_get_custom_components(hass)
@ -543,8 +541,7 @@ class Integration:
async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration: async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration:
"""Get an integration.""" """Get an integration."""
cache = hass.data.get(DATA_INTEGRATIONS) if (cache := hass.data.get(DATA_INTEGRATIONS)) is None:
if cache is None:
if not _async_mount_config_dir(hass): if not _async_mount_config_dir(hass):
raise IntegrationNotFound(domain) raise IntegrationNotFound(domain)
cache = hass.data[DATA_INTEGRATIONS] = {} cache = hass.data[DATA_INTEGRATIONS] = {}
@ -553,12 +550,11 @@ async def async_get_integration(hass: HomeAssistant, domain: str) -> Integration
if isinstance(int_or_evt, asyncio.Event): if isinstance(int_or_evt, asyncio.Event):
await int_or_evt.wait() await int_or_evt.wait()
int_or_evt = cache.get(domain, _UNDEF)
# When we have waited and it's _UNDEF, it doesn't exist # When we have waited and it's _UNDEF, it doesn't exist
# We don't cache that it doesn't exist, or else people can't fix it # We don't cache that it doesn't exist, or else people can't fix it
# and then restart, because their config will never be valid. # and then restart, because their config will never be valid.
if int_or_evt is _UNDEF: if (int_or_evt := cache.get(domain, _UNDEF)) is _UNDEF:
raise IntegrationNotFound(domain) raise IntegrationNotFound(domain)
if int_or_evt is not _UNDEF: if int_or_evt is not _UNDEF:
@ -630,8 +626,7 @@ def _load_file(
with suppress(KeyError): with suppress(KeyError):
return hass.data[DATA_COMPONENTS][comp_or_platform] # type: ignore return hass.data[DATA_COMPONENTS][comp_or_platform] # type: ignore
cache = hass.data.get(DATA_COMPONENTS) if (cache := hass.data.get(DATA_COMPONENTS)) is None:
if cache is None:
if not _async_mount_config_dir(hass): if not _async_mount_config_dir(hass):
return None return None
cache = hass.data[DATA_COMPONENTS] = {} cache = hass.data[DATA_COMPONENTS] = {}

View file

@ -60,8 +60,7 @@ async def async_get_integration_with_requirements(
if hass.config.skip_pip: if hass.config.skip_pip:
return integration return integration
cache = hass.data.get(DATA_INTEGRATIONS_WITH_REQS) if (cache := hass.data.get(DATA_INTEGRATIONS_WITH_REQS)) is None:
if cache is None:
cache = hass.data[DATA_INTEGRATIONS_WITH_REQS] = {} cache = hass.data[DATA_INTEGRATIONS_WITH_REQS] = {}
int_or_evt: Integration | asyncio.Event | None | UndefinedType = cache.get( int_or_evt: Integration | asyncio.Event | None | UndefinedType = cache.get(
@ -71,12 +70,10 @@ async def async_get_integration_with_requirements(
if isinstance(int_or_evt, asyncio.Event): if isinstance(int_or_evt, asyncio.Event):
await int_or_evt.wait() await int_or_evt.wait()
int_or_evt = cache.get(domain, UNDEFINED)
# When we have waited and it's UNDEFINED, it doesn't exist # When we have waited and it's UNDEFINED, it doesn't exist
# We don't cache that it doesn't exist, or else people can't fix it # We don't cache that it doesn't exist, or else people can't fix it
# and then restart, because their config will never be valid. # and then restart, because their config will never be valid.
if int_or_evt is UNDEFINED: if (int_or_evt := cache.get(domain, UNDEFINED)) is UNDEFINED:
raise IntegrationNotFound(domain) raise IntegrationNotFound(domain)
if int_or_evt is not UNDEFINED: if int_or_evt is not UNDEFINED:
@ -154,8 +151,7 @@ async def async_process_requirements(
This method is a coroutine. It will raise RequirementsNotFound This method is a coroutine. It will raise RequirementsNotFound
if an requirement can't be satisfied. if an requirement can't be satisfied.
""" """
pip_lock = hass.data.get(DATA_PIP_LOCK) if (pip_lock := hass.data.get(DATA_PIP_LOCK)) is None:
if pip_lock is None:
pip_lock = hass.data[DATA_PIP_LOCK] = asyncio.Lock() pip_lock = hass.data[DATA_PIP_LOCK] = asyncio.Lock()
install_failure_history = hass.data.get(DATA_INSTALL_FAILURE_HISTORY) install_failure_history = hass.data.get(DATA_INSTALL_FAILURE_HISTORY)
if install_failure_history is None: if install_failure_history is None:

View file

@ -83,8 +83,7 @@ class HassEventLoopPolicy(asyncio.DefaultEventLoopPolicy): # type: ignore[valid
def _async_loop_exception_handler(_: Any, context: dict[str, Any]) -> None: def _async_loop_exception_handler(_: Any, context: dict[str, Any]) -> None:
"""Handle all exception inside the core loop.""" """Handle all exception inside the core loop."""
kwargs = {} kwargs = {}
exception = context.get("exception") if exception := context.get("exception"):
if exception:
kwargs["exc_info"] = (type(exception), exception, exception.__traceback__) kwargs["exc_info"] = (type(exception), exception, exception.__traceback__)
logging.getLogger(__package__).error( logging.getLogger(__package__).error(