More data entry flow and HTTP related type hints (#34430)

This commit is contained in:
Ville Skyttä 2020-05-26 17:28:22 +03:00 committed by GitHub
parent bc1dac80b6
commit f8416484f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 79 additions and 56 deletions

View file

@ -172,7 +172,7 @@ class CalendarEventView(http.HomeAssistantView):
url = "/api/calendars/{entity_id}"
name = "api:calendars:calendar"
def __init__(self, component):
def __init__(self, component: EntityComponent) -> None:
"""Initialize calendar view."""
self.component = component
@ -200,11 +200,11 @@ class CalendarListView(http.HomeAssistantView):
url = "/api/calendars"
name = "api:calendars"
def __init__(self, component):
def __init__(self, component: EntityComponent) -> None:
"""Initialize calendar view."""
self.component = component
async def get(self, request):
async def get(self, request: web.Request) -> web.Response:
"""Retrieve calendar list."""
hass = request.app["hass"]
calendar_list = []

View file

@ -473,11 +473,11 @@ class CameraView(HomeAssistantView):
requires_auth = False
def __init__(self, component):
def __init__(self, component: EntityComponent) -> None:
"""Initialize a basic camera view."""
self.component = component
async def get(self, request, entity_id):
async def get(self, request: web.Request, entity_id: str) -> web.Response:
"""Start a GET request."""
camera = self.component.get_entity(entity_id)
@ -509,7 +509,7 @@ class CameraImageView(CameraView):
url = "/api/camera_proxy/{entity_id}"
name = "api:camera:image"
async def handle(self, request, camera):
async def handle(self, request: web.Request, camera: Camera) -> web.Response:
"""Serve camera image."""
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
async with async_timeout.timeout(10):
@ -527,7 +527,7 @@ class CameraMjpegStream(CameraView):
url = "/api/camera_proxy_stream/{entity_id}"
name = "api:camera:stream"
async def handle(self, request, camera):
async def handle(self, request: web.Request, camera: Camera) -> web.Response:
"""Serve camera stream, possibly with interval."""
interval = request.query.get("interval")
if interval is None:

View file

@ -4,7 +4,9 @@ from datetime import timedelta
from itertools import groupby
import logging
import time
from typing import Optional, cast
from aiohttp import web
from sqlalchemy import and_, func
import voluptuous as vol
@ -337,20 +339,22 @@ class HistoryPeriodView(HomeAssistantView):
self.filters = filters
self.use_include_order = use_include_order
async def get(self, request, datetime=None):
async def get(
self, request: web.Request, datetime: Optional[str] = None
) -> web.Response:
"""Return history over a period of time."""
if datetime:
datetime = dt_util.parse_datetime(datetime)
datetime_ = dt_util.parse_datetime(datetime)
if datetime is None:
if datetime_ is None:
return self.json_message("Invalid datetime", HTTP_BAD_REQUEST)
now = dt_util.utcnow()
one_day = timedelta(days=1)
if datetime:
start_time = dt_util.as_utc(datetime)
if datetime_:
start_time = dt_util.as_utc(datetime_)
else:
start_time = now - one_day
@ -376,14 +380,17 @@ class HistoryPeriodView(HomeAssistantView):
hass = request.app["hass"]
return await hass.async_add_executor_job(
self._sorted_significant_states_json,
hass,
start_time,
end_time,
entity_ids,
include_start_time_state,
significant_changes_only,
return cast(
web.Response,
await hass.async_add_executor_job(
self._sorted_significant_states_json,
hass,
start_time,
end_time,
entity_ids,
include_start_time_state,
significant_changes_only,
),
)
def _sorted_significant_states_json(

View file

@ -1,12 +1,14 @@
"""Decorator for view methods to help with data validation."""
from functools import wraps
import logging
from typing import Any, Awaitable, Callable
from aiohttp import web
import voluptuous as vol
from homeassistant.const import HTTP_BAD_REQUEST
# mypy: allow-untyped-defs
from .view import HomeAssistantView
_LOGGER = logging.getLogger(__name__)
@ -20,7 +22,7 @@ class RequestDataValidator:
Will return a 400 if no JSON provided or doesn't match schema.
"""
def __init__(self, schema, allow_empty=False):
def __init__(self, schema: vol.Schema, allow_empty: bool = False) -> None:
"""Initialize the decorator."""
if isinstance(schema, dict):
schema = vol.Schema(schema)
@ -28,11 +30,15 @@ class RequestDataValidator:
self._schema = schema
self._allow_empty = allow_empty
def __call__(self, method):
def __call__(
self, method: Callable[..., Awaitable[web.StreamResponse]]
) -> Callable:
"""Decorate a function."""
@wraps(method)
async def wrapper(view, request, *args, **kwargs):
async def wrapper(
view: HomeAssistantView, request: web.Request, *args: Any, **kwargs: Any
) -> web.StreamResponse:
"""Wrap a request handler with data validation."""
data = None
try:

View file

@ -2,9 +2,10 @@
import asyncio
import json
import logging
from typing import List, Optional
from typing import Any, Callable, List, Optional
from aiohttp import web
from aiohttp.typedefs import LooseHeaders
from aiohttp.web_exceptions import (
HTTPBadRequest,
HTTPInternalServerError,
@ -22,9 +23,6 @@ from .const import KEY_AUTHENTICATED, KEY_HASS, KEY_REAL_IP
_LOGGER = logging.getLogger(__name__)
# mypy: allow-untyped-defs, no-check-untyped-defs
class HomeAssistantView:
"""Base view for all views."""
@ -35,7 +33,7 @@ class HomeAssistantView:
cors_allowed = False
@staticmethod
def context(request):
def context(request: web.Request) -> Context:
"""Generate a context from a request."""
user = request.get("hass_user")
if user is None:
@ -44,7 +42,9 @@ class HomeAssistantView:
return Context(user_id=user.id)
@staticmethod
def json(result, status_code=HTTP_OK, headers=None):
def json(
result: Any, status_code: int = HTTP_OK, headers: Optional[LooseHeaders] = None,
) -> web.Response:
"""Return a JSON response."""
try:
msg = json.dumps(
@ -63,15 +63,19 @@ class HomeAssistantView:
return response
def json_message(
self, message, status_code=HTTP_OK, message_code=None, headers=None
):
self,
message: str,
status_code: int = HTTP_OK,
message_code: Optional[str] = None,
headers: Optional[LooseHeaders] = None,
) -> web.Response:
"""Return a JSON message response."""
data = {"message": message}
if message_code is not None:
data["code"] = message_code
return self.json(data, status_code, headers=headers)
def register(self, app, router):
def register(self, app: web.Application, router: web.UrlDispatcher) -> None:
"""Register the view with a router."""
assert self.url is not None, "No url set for view"
urls = [self.url] + self.extra_urls
@ -95,13 +99,13 @@ class HomeAssistantView:
app["allow_cors"](route)
def request_handler_factory(view, handler):
def request_handler_factory(view: HomeAssistantView, handler: Callable) -> Callable:
"""Wrap the handler classes."""
assert asyncio.iscoroutinefunction(handler) or is_callback(
handler
), "Handler should be a coroutine or a callback."
async def handle(request):
async def handle(request: web.Request) -> web.StreamResponse:
"""Handle incoming request."""
if not request.app[KEY_HASS].is_running:
return web.Response(status=503)
@ -139,15 +143,17 @@ def request_handler_factory(view, handler):
if isinstance(result, tuple):
result, status_code = result
if isinstance(result, str):
result = result.encode("utf-8")
if isinstance(result, bytes):
bresult = result
elif isinstance(result, str):
bresult = result.encode("utf-8")
elif result is None:
result = b""
elif not isinstance(result, bytes):
bresult = b""
else:
assert (
False
), f"Result should be None, string, bytes or Response. Got: {result}"
return web.Response(body=result, status=status_code)
return web.Response(body=bresult, status=status_code)
return handle

View file

@ -200,7 +200,7 @@ class MailboxPlatformsView(MailboxView):
url = "/api/mailbox/platforms"
name = "api:mailbox:platforms"
async def get(self, request):
async def get(self, request: web.Request) -> web.Response:
"""Retrieve list of platforms."""
platforms = []
for mailbox in self.mailboxes:

View file

@ -12,6 +12,7 @@ from urllib.parse import urlparse
from aiohttp import web
from aiohttp.hdrs import CACHE_CONTROL, CONTENT_TYPE
from aiohttp.typedefs import LooseHeaders
import async_timeout
import voluptuous as vol
@ -863,7 +864,7 @@ class MediaPlayerImageView(HomeAssistantView):
"""Initialize a media player view."""
self.component = component
async def get(self, request, entity_id):
async def get(self, request: web.Request, entity_id: str) -> web.Response:
"""Start a get request."""
player = self.component.get_entity(entity_id)
if player is None:
@ -883,7 +884,7 @@ class MediaPlayerImageView(HomeAssistantView):
if data is None:
return web.Response(status=HTTP_INTERNAL_SERVER_ERROR)
headers = {CACHE_CONTROL: "max-age=3600"}
headers: LooseHeaders = {CACHE_CONTROL: "max-age=3600"}
return web.Response(body=data, content_type=content_type, headers=headers)

View file

@ -530,7 +530,7 @@ class TextToSpeechUrlView(HomeAssistantView):
"""Initialize a tts view."""
self.tts = tts
async def post(self, request):
async def post(self, request: web.Request) -> web.Response:
"""Generate speech and provide url."""
try:
data = await request.json()
@ -570,7 +570,7 @@ class TextToSpeechView(HomeAssistantView):
"""Initialize a tts view."""
self.tts = tts
async def get(self, request, filename):
async def get(self, request: web.Request, filename: str) -> web.Response:
"""Start a get request."""
try:
content, data = await self.tts.async_read_tts(filename)

View file

@ -42,7 +42,7 @@ class WebsocketAPIView(HomeAssistantView):
url = URL
requires_auth = False
async def get(self, request):
async def get(self, request: web.Request) -> web.WebSocketResponse:
"""Handle an incoming websocket connection."""
return await WebSocketHandler(request.app["hass"], request).async_handle()
@ -148,7 +148,7 @@ class WebSocketHandler:
self._handle_task.cancel()
self._writer_task.cancel()
async def async_handle(self):
async def async_handle(self) -> web.WebSocketResponse:
"""Handle a websocket response."""
request = self.request
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)

View file

@ -1,5 +1,8 @@
"""Helpers for the data entry flow."""
from typing import Any, Dict
from aiohttp import web
import voluptuous as vol
from homeassistant import config_entries, data_entry_flow
@ -8,18 +11,16 @@ from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import HTTP_NOT_FOUND
import homeassistant.helpers.config_validation as cv
# mypy: allow-untyped-calls, allow-untyped-defs
class _BaseFlowManagerView(HomeAssistantView):
"""Foundation for flow manager views."""
def __init__(self, flow_mgr):
def __init__(self, flow_mgr: data_entry_flow.FlowManager) -> None:
"""Initialize the flow manager index view."""
self._flow_mgr = flow_mgr
# pylint: disable=no-self-use
def _prepare_result_json(self, result):
def _prepare_result_json(self, result: Dict[str, Any]) -> Dict[str, Any]:
"""Convert result to JSON."""
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
data = result.copy()
@ -57,7 +58,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
extra=vol.ALLOW_EXTRA,
)
)
async def post(self, request, data):
async def post(self, request: web.Request, data: Dict[str, Any]) -> web.Response:
"""Handle a POST request."""
if isinstance(data["handler"], list):
handler = tuple(data["handler"])
@ -66,7 +67,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
try:
result = await self._flow_mgr.async_init(
handler,
handler, # type: ignore
context={
"source": config_entries.SOURCE_USER,
"show_advanced_options": data["show_advanced_options"],
@ -85,7 +86,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
class FlowManagerResourceView(_BaseFlowManagerView):
"""View to interact with the flow manager."""
async def get(self, request, flow_id):
async def get(self, request: web.Request, flow_id: str) -> web.Response:
"""Get the current state of a data_entry_flow."""
try:
result = await self._flow_mgr.async_configure(flow_id)
@ -97,7 +98,9 @@ class FlowManagerResourceView(_BaseFlowManagerView):
return self.json(result)
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
async def post(self, request, flow_id, data):
async def post(
self, request: web.Request, flow_id: str, data: Dict[str, Any]
) -> web.Response:
"""Handle a POST request."""
try:
result = await self._flow_mgr.async_configure(flow_id, data)
@ -110,7 +113,7 @@ class FlowManagerResourceView(_BaseFlowManagerView):
return self.json(result)
async def delete(self, request, flow_id):
async def delete(self, request: web.Request, flow_id: str) -> web.Response:
"""Cancel a flow in progress."""
try:
self._flow_mgr.async_abort(flow_id)