Add and fix type hints (#36501)

* Fix exceptions.Unauthorized.permission type

* Use auth.permission consts more

* Auth typing improvements

* Helpers typing improvements

* Calculate self.state only once
This commit is contained in:
Ville Skyttä 2020-06-06 21:34:56 +03:00 committed by GitHub
parent 49747684a0
commit 0c5ca3084e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 53 additions and 43 deletions

View file

@ -117,7 +117,8 @@ class TotpAuthModule(MultiFactorAuthModule):
Mfa module should extend SetupFlow Mfa module should extend SetupFlow
""" """
user = await self.hass.auth.async_get_user(user_id) # type: ignore user = await self.hass.auth.async_get_user(user_id)
assert user is not None
return TotpSetupFlow(self, self.input_schema, user) return TotpSetupFlow(self, self.input_schema, user)
async def async_setup_user(self, user_id: str, setup_data: Any) -> str: async def async_setup_user(self, user_id: str, setup_data: Any) -> str:

View file

@ -175,7 +175,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
"""Initialize the login flow.""" """Initialize the login flow."""
self._auth_provider = auth_provider self._auth_provider = auth_provider
self._auth_module_id: Optional[str] = None self._auth_module_id: Optional[str] = None
self._auth_manager = auth_provider.hass.auth # type: ignore self._auth_manager = auth_provider.hass.auth
self.available_mfa_modules: Dict[str, str] = {} self.available_mfa_modules: Dict[str, str] = {}
self.created_at = dt_util.utcnow() self.created_at = dt_util.utcnow()
self.invalid_mfa_times = 0 self.invalid_mfa_times = 0
@ -224,6 +224,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
errors = {} errors = {}
assert self._auth_module_id is not None
auth_module = self._auth_manager.get_auth_mfa_module(self._auth_module_id) auth_module = self._auth_manager.get_auth_mfa_module(self._auth_module_id)
if auth_module is None: if auth_module is None:
# Given an invalid input to async_step_select_mfa_module # Given an invalid input to async_step_select_mfa_module
@ -234,7 +235,9 @@ class LoginFlow(data_entry_flow.FlowHandler):
auth_module, "async_initialize_login_mfa_step" auth_module, "async_initialize_login_mfa_step"
): ):
try: try:
await auth_module.async_initialize_login_mfa_step(self.user.id) await auth_module.async_initialize_login_mfa_step( # type: ignore
self.user.id
)
except HomeAssistantError: except HomeAssistantError:
_LOGGER.exception("Error initializing MFA step") _LOGGER.exception("Error initializing MFA step")
return self.async_abort(reason="unknown_error") return self.async_abort(reason="unknown_error")

View file

@ -4,7 +4,7 @@ import voluptuous as vol
import voluptuous_serialize import voluptuous_serialize
from homeassistant import config_entries, data_entry_flow from homeassistant import config_entries, data_entry_flow
from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES, POLICY_EDIT
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.const import HTTP_NOT_FOUND from homeassistant.const import HTTP_NOT_FOUND
@ -180,7 +180,7 @@ class OptionManagerFlowIndexView(FlowManagerIndexView):
handler in request is entry_id. handler in request is entry_id.
""" """
if not request["hass_user"].is_admin: if not request["hass_user"].is_admin:
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="edit") raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
# pylint: disable=no-value-for-parameter # pylint: disable=no-value-for-parameter
return await super().post(request) return await super().post(request)
@ -195,7 +195,7 @@ class OptionManagerFlowResourceView(FlowManagerResourceView):
async def get(self, request, flow_id): async def get(self, request, flow_id):
"""Get the current state of a data_entry_flow.""" """Get the current state of a data_entry_flow."""
if not request["hass_user"].is_admin: if not request["hass_user"].is_admin:
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="edit") raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
return await super().get(request, flow_id) return await super().get(request, flow_id)
@ -203,7 +203,7 @@ class OptionManagerFlowResourceView(FlowManagerResourceView):
async def post(self, request, flow_id): async def post(self, request, flow_id):
"""Handle a POST request.""" """Handle a POST request."""
if not request["hass_user"].is_admin: if not request["hass_user"].is_admin:
raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission="edit") raise Unauthorized(perm_category=CAT_CONFIG_ENTRIES, permission=POLICY_EDIT)
# pylint: disable=no-value-for-parameter # pylint: disable=no-value-for-parameter
return await super().post(request, flow_id) return await super().post(request, flow_id)

View file

@ -85,18 +85,12 @@ async def _validate_edit_permission(
"""Use for validating user control permissions.""" """Use for validating user control permissions."""
splited = split_entity_id(entity_id) splited = split_entity_id(entity_id)
if splited[0] != SWITCH_DOMAIN or not splited[1].startswith(DOMAIN): if splited[0] != SWITCH_DOMAIN or not splited[1].startswith(DOMAIN):
raise Unauthorized( raise Unauthorized(context=context, entity_id=entity_id, permission=POLICY_EDIT)
context=context, entity_id=entity_id, permission=(POLICY_EDIT,)
)
user = await hass.auth.async_get_user(context.user_id) user = await hass.auth.async_get_user(context.user_id)
if user is None: if user is None:
raise UnknownUser( raise UnknownUser(context=context, entity_id=entity_id, permission=POLICY_EDIT)
context=context, entity_id=entity_id, permission=(POLICY_EDIT,)
)
if not user.permissions.check_entity(entity_id, POLICY_EDIT): if not user.permissions.check_entity(entity_id, POLICY_EDIT):
raise Unauthorized( raise Unauthorized(context=context, entity_id=entity_id, permission=POLICY_EDIT)
context=context, entity_id=entity_id, permission=(POLICY_EDIT,)
)
async def async_setup(hass: HomeAssistantType, config: Dict) -> bool: async def async_setup(hass: HomeAssistantType, config: Dict) -> bool:

View file

@ -79,6 +79,7 @@ from homeassistant.util.unit_system import IMPERIAL_SYSTEM, METRIC_SYSTEM, UnitS
# Typing imports that create a circular dependency # Typing imports that create a circular dependency
if TYPE_CHECKING: if TYPE_CHECKING:
from homeassistant.auth import AuthManager
from homeassistant.config_entries import ConfigEntries from homeassistant.config_entries import ConfigEntries
from homeassistant.components.http import HomeAssistantHTTP from homeassistant.components.http import HomeAssistantHTTP
@ -174,6 +175,7 @@ class CoreState(enum.Enum):
class HomeAssistant: class HomeAssistant:
"""Root object of the Home Assistant home automation.""" """Root object of the Home Assistant home automation."""
auth: "AuthManager"
http: "HomeAssistantHTTP" = None # type: ignore http: "HomeAssistantHTTP" = None # type: ignore
config_entries: "ConfigEntries" = None # type: ignore config_entries: "ConfigEntries" = None # type: ignore

View file

@ -1,5 +1,5 @@
"""The exceptions used by Home Assistant.""" """The exceptions used by Home Assistant."""
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, Optional
import jinja2 import jinja2
@ -49,7 +49,7 @@ class Unauthorized(HomeAssistantError):
entity_id: Optional[str] = None, entity_id: Optional[str] = None,
config_entry_id: Optional[str] = None, config_entry_id: Optional[str] = None,
perm_category: Optional[str] = None, perm_category: Optional[str] = None,
permission: Optional[Tuple[str]] = None, permission: Optional[str] = None,
) -> None: ) -> None:
"""Unauthorized error.""" """Unauthorized error."""
super().__init__(self.__class__.__name__) super().__init__(self.__class__.__name__)

View file

@ -5,7 +5,7 @@ from datetime import datetime, timedelta
import functools as ft import functools as ft
import logging import logging
from timeit import default_timer as timer from timeit import default_timer as timer
from typing import Any, Dict, Iterable, List, Optional, Union from typing import Any, Awaitable, Dict, Iterable, List, Optional, Union
from homeassistant.config import DATA_CUSTOMIZE from homeassistant.config import DATA_CUSTOMIZE
from homeassistant.const import ( from homeassistant.const import (
@ -32,11 +32,10 @@ from homeassistant.helpers.entity_registry import (
EVENT_ENTITY_REGISTRY_UPDATED, EVENT_ENTITY_REGISTRY_UPDATED,
RegistryEntry, RegistryEntry,
) )
from homeassistant.helpers.event import Event
from homeassistant.util import dt as dt_util, ensure_unique_string, slugify from homeassistant.util import dt as dt_util, ensure_unique_string, slugify
from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.async_ import run_callback_threadsafe
# mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SLOW_UPDATE_WARNING = 10 SLOW_UPDATE_WARNING = 10
@ -258,7 +257,7 @@ class Entity(ABC):
self._context = context self._context = context
self._context_set = dt_util.utcnow() self._context_set = dt_util.utcnow()
async def async_update_ha_state(self, force_refresh=False): async def async_update_ha_state(self, force_refresh: bool = False) -> None:
"""Update Home Assistant with current state of entity. """Update Home Assistant with current state of entity.
If force_refresh == True will update entity before setting state. If force_refresh == True will update entity before setting state.
@ -294,14 +293,15 @@ class Entity(ABC):
f"No entity id specified for entity {self.name}" f"No entity id specified for entity {self.name}"
) )
self._async_write_ha_state() # type: ignore self._async_write_ha_state()
@callback @callback
def _async_write_ha_state(self): def _async_write_ha_state(self) -> None:
"""Write the state to the state machine.""" """Write the state to the state machine."""
if self.registry_entry and self.registry_entry.disabled_by: if self.registry_entry and self.registry_entry.disabled_by:
if not self._disabled_reported: if not self._disabled_reported:
self._disabled_reported = True self._disabled_reported = True
assert self.platform is not None
_LOGGER.warning( _LOGGER.warning(
"Entity %s is incorrectly being triggered for updates while it is disabled. This is a bug in the %s integration.", "Entity %s is incorrectly being triggered for updates while it is disabled. This is a bug in the %s integration.",
self.entity_id, self.entity_id,
@ -317,9 +317,8 @@ class Entity(ABC):
if not self.available: if not self.available:
state = STATE_UNAVAILABLE state = STATE_UNAVAILABLE
else: else:
state = self.state sstate = self.state
state = STATE_UNKNOWN if sstate is None else str(sstate)
state = STATE_UNKNOWN if state is None else str(state)
attr.update(self.state_attributes or {}) attr.update(self.state_attributes or {})
attr.update(self.device_state_attributes or {}) attr.update(self.device_state_attributes or {})
@ -383,6 +382,7 @@ class Entity(ABC):
) )
# Overwrite properties that have been set in the config file. # Overwrite properties that have been set in the config file.
assert self.hass is not None
if DATA_CUSTOMIZE in self.hass.data: if DATA_CUSTOMIZE in self.hass.data:
attr.update(self.hass.data[DATA_CUSTOMIZE].get(self.entity_id)) attr.update(self.hass.data[DATA_CUSTOMIZE].get(self.entity_id))
@ -403,7 +403,7 @@ class Entity(ABC):
pass pass
if ( if (
self._context is not None self._context_set is not None
and dt_util.utcnow() - self._context_set > self.context_recent_time and dt_util.utcnow() - self._context_set > self.context_recent_time
): ):
self._context = None self._context = None
@ -413,7 +413,7 @@ class Entity(ABC):
self.entity_id, state, attr, self.force_update, self._context self.entity_id, state, attr, self.force_update, self._context
) )
def schedule_update_ha_state(self, force_refresh=False): def schedule_update_ha_state(self, force_refresh: bool = False) -> None:
"""Schedule an update ha state change task. """Schedule an update ha state change task.
Scheduling the update avoids executor deadlocks. Scheduling the update avoids executor deadlocks.
@ -423,10 +423,11 @@ class Entity(ABC):
If state is changed more than once before the ha state change task has If state is changed more than once before the ha state change task has
been executed, the intermediate state transitions will be missed. been executed, the intermediate state transitions will be missed.
""" """
self.hass.add_job(self.async_update_ha_state(force_refresh)) assert self.hass is not None
self.hass.add_job(self.async_update_ha_state(force_refresh)) # type: ignore
@callback @callback
def async_schedule_update_ha_state(self, force_refresh=False): def async_schedule_update_ha_state(self, force_refresh: bool = False) -> None:
"""Schedule an update ha state change task. """Schedule an update ha state change task.
This method must be run in the event loop. This method must be run in the event loop.
@ -438,11 +439,12 @@ class Entity(ABC):
been executed, the intermediate state transitions will be missed. been executed, the intermediate state transitions will be missed.
""" """
if force_refresh: if force_refresh:
assert self.hass is not None
self.hass.async_create_task(self.async_update_ha_state(force_refresh)) self.hass.async_create_task(self.async_update_ha_state(force_refresh))
else: else:
self.async_write_ha_state() self.async_write_ha_state()
async def async_device_update(self, warning=True): async def async_device_update(self, warning: bool = True) -> None:
"""Process 'update' or 'async_update' from entity. """Process 'update' or 'async_update' from entity.
This method is a coroutine. This method is a coroutine.
@ -455,6 +457,7 @@ class Entity(ABC):
if self.parallel_updates: if self.parallel_updates:
await self.parallel_updates.acquire() await self.parallel_updates.acquire()
assert self.hass is not None
if warning: if warning:
update_warn = self.hass.loop.call_later( update_warn = self.hass.loop.call_later(
SLOW_UPDATE_WARNING, SLOW_UPDATE_WARNING,
@ -467,9 +470,11 @@ class Entity(ABC):
try: try:
# pylint: disable=no-member # pylint: disable=no-member
if hasattr(self, "async_update"): if hasattr(self, "async_update"):
await self.async_update() await self.async_update() # type: ignore
elif hasattr(self, "update"): elif hasattr(self, "update"):
await self.hass.async_add_executor_job(self.update) await self.hass.async_add_executor_job(
self.update # type: ignore
)
finally: finally:
self._update_staged = False self._update_staged = False
if warning: if warning:
@ -534,7 +539,7 @@ class Entity(ABC):
Not to be extended by integrations. Not to be extended by integrations.
""" """
async def _async_registry_updated(self, event): async def _async_registry_updated(self, event: Event) -> None:
"""Handle entity registry update.""" """Handle entity registry update."""
data = event.data data = event.data
if data["action"] == "remove" and data["entity_id"] == self.entity_id: if data["action"] == "remove" and data["entity_id"] == self.entity_id:
@ -547,24 +552,28 @@ class Entity(ABC):
): ):
return return
assert self.hass is not None
ent_reg = await self.hass.helpers.entity_registry.async_get_registry() ent_reg = await self.hass.helpers.entity_registry.async_get_registry()
old = self.registry_entry old = self.registry_entry
self.registry_entry = ent_reg.async_get(data["entity_id"]) self.registry_entry = ent_reg.async_get(data["entity_id"])
assert self.registry_entry is not None
if self.registry_entry.disabled_by is not None: if self.registry_entry.disabled_by is not None:
await self.async_remove() await self.async_remove()
return return
assert old is not None
if self.registry_entry.entity_id == old.entity_id: if self.registry_entry.entity_id == old.entity_id:
self.async_write_ha_state() self.async_write_ha_state()
return return
await self.async_remove() await self.async_remove()
assert self.platform is not None
self.entity_id = self.registry_entry.entity_id self.entity_id = self.registry_entry.entity_id
await self.platform.async_add_entities([self]) await self.platform.async_add_entities([self])
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
"""Return the comparison.""" """Return the comparison."""
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return False return False
@ -587,8 +596,7 @@ class Entity(ABC):
"""Return the representation.""" """Return the representation."""
return f"<Entity {self.name}: {self.state}>" return f"<Entity {self.name}: {self.state}>"
# call an requests async def async_request_call(self, coro: Awaitable) -> None:
async def async_request_call(self, coro):
"""Process request batched.""" """Process request batched."""
if self.parallel_updates: if self.parallel_updates:
await self.parallel_updates.acquire() await self.parallel_updates.acquire()
@ -617,16 +625,18 @@ class ToggleEntity(Entity):
"""Turn the entity on.""" """Turn the entity on."""
raise NotImplementedError() raise NotImplementedError()
async def async_turn_on(self, **kwargs): async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn the entity on.""" """Turn the entity on."""
assert self.hass is not None
await self.hass.async_add_executor_job(ft.partial(self.turn_on, **kwargs)) await self.hass.async_add_executor_job(ft.partial(self.turn_on, **kwargs))
def turn_off(self, **kwargs: Any) -> None: def turn_off(self, **kwargs: Any) -> None:
"""Turn the entity off.""" """Turn the entity off."""
raise NotImplementedError() raise NotImplementedError()
async def async_turn_off(self, **kwargs): async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn the entity off.""" """Turn the entity off."""
assert self.hass is not None
await self.hass.async_add_executor_job(ft.partial(self.turn_off, **kwargs)) await self.hass.async_add_executor_job(ft.partial(self.turn_off, **kwargs))
def toggle(self, **kwargs: Any) -> None: def toggle(self, **kwargs: Any) -> None:
@ -636,7 +646,7 @@ class ToggleEntity(Entity):
else: else:
self.turn_on(**kwargs) self.turn_on(**kwargs)
async def async_toggle(self, **kwargs): async def async_toggle(self, **kwargs: Any) -> None:
"""Toggle the entity.""" """Toggle the entity."""
if self.is_on: if self.is_on:
await self.async_turn_off(**kwargs) await self.async_turn_off(**kwargs)

View file

@ -542,7 +542,7 @@ class EntityPlatform:
for entity in self.entities.values(): for entity in self.entities.values():
if not entity.should_poll: if not entity.should_poll:
continue continue
tasks.append(entity.async_update_ha_state(True)) # type: ignore tasks.append(entity.async_update_ha_state(True))
if tasks: if tasks:
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

View file

@ -505,7 +505,7 @@ def async_register_admin_service(
"""Register a service that requires admin access.""" """Register a service that requires admin access."""
@wraps(service_func) @wraps(service_func)
async def admin_handler(call): async def admin_handler(call: ha.ServiceCall) -> None:
if call.context.user_id: if call.context.user_id:
user = await hass.auth.async_get_user(call.context.user_id) user = await hass.auth.async_get_user(call.context.user_id)
if user is None: if user is None: