More data entry flow and HTTP related type hints (#34430)
This commit is contained in:
parent
bc1dac80b6
commit
f8416484f8
10 changed files with 79 additions and 56 deletions
|
@ -172,7 +172,7 @@ class CalendarEventView(http.HomeAssistantView):
|
|||
url = "/api/calendars/{entity_id}"
|
||||
name = "api:calendars:calendar"
|
||||
|
||||
def __init__(self, component):
|
||||
def __init__(self, component: EntityComponent) -> None:
|
||||
"""Initialize calendar view."""
|
||||
self.component = component
|
||||
|
||||
|
@ -200,11 +200,11 @@ class CalendarListView(http.HomeAssistantView):
|
|||
url = "/api/calendars"
|
||||
name = "api:calendars"
|
||||
|
||||
def __init__(self, component):
|
||||
def __init__(self, component: EntityComponent) -> None:
|
||||
"""Initialize calendar view."""
|
||||
self.component = component
|
||||
|
||||
async def get(self, request):
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Retrieve calendar list."""
|
||||
hass = request.app["hass"]
|
||||
calendar_list = []
|
||||
|
|
|
@ -473,11 +473,11 @@ class CameraView(HomeAssistantView):
|
|||
|
||||
requires_auth = False
|
||||
|
||||
def __init__(self, component):
|
||||
def __init__(self, component: EntityComponent) -> None:
|
||||
"""Initialize a basic camera view."""
|
||||
self.component = component
|
||||
|
||||
async def get(self, request, entity_id):
|
||||
async def get(self, request: web.Request, entity_id: str) -> web.Response:
|
||||
"""Start a GET request."""
|
||||
camera = self.component.get_entity(entity_id)
|
||||
|
||||
|
@ -509,7 +509,7 @@ class CameraImageView(CameraView):
|
|||
url = "/api/camera_proxy/{entity_id}"
|
||||
name = "api:camera:image"
|
||||
|
||||
async def handle(self, request, camera):
|
||||
async def handle(self, request: web.Request, camera: Camera) -> web.Response:
|
||||
"""Serve camera image."""
|
||||
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
|
||||
async with async_timeout.timeout(10):
|
||||
|
@ -527,7 +527,7 @@ class CameraMjpegStream(CameraView):
|
|||
url = "/api/camera_proxy_stream/{entity_id}"
|
||||
name = "api:camera:stream"
|
||||
|
||||
async def handle(self, request, camera):
|
||||
async def handle(self, request: web.Request, camera: Camera) -> web.Response:
|
||||
"""Serve camera stream, possibly with interval."""
|
||||
interval = request.query.get("interval")
|
||||
if interval is None:
|
||||
|
|
|
@ -4,7 +4,9 @@ from datetime import timedelta
|
|||
from itertools import groupby
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, cast
|
||||
|
||||
from aiohttp import web
|
||||
from sqlalchemy import and_, func
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -337,20 +339,22 @@ class HistoryPeriodView(HomeAssistantView):
|
|||
self.filters = filters
|
||||
self.use_include_order = use_include_order
|
||||
|
||||
async def get(self, request, datetime=None):
|
||||
async def get(
|
||||
self, request: web.Request, datetime: Optional[str] = None
|
||||
) -> web.Response:
|
||||
"""Return history over a period of time."""
|
||||
|
||||
if datetime:
|
||||
datetime = dt_util.parse_datetime(datetime)
|
||||
datetime_ = dt_util.parse_datetime(datetime)
|
||||
|
||||
if datetime is None:
|
||||
if datetime_ is None:
|
||||
return self.json_message("Invalid datetime", HTTP_BAD_REQUEST)
|
||||
|
||||
now = dt_util.utcnow()
|
||||
|
||||
one_day = timedelta(days=1)
|
||||
if datetime:
|
||||
start_time = dt_util.as_utc(datetime)
|
||||
if datetime_:
|
||||
start_time = dt_util.as_utc(datetime_)
|
||||
else:
|
||||
start_time = now - one_day
|
||||
|
||||
|
@ -376,14 +380,17 @@ class HistoryPeriodView(HomeAssistantView):
|
|||
|
||||
hass = request.app["hass"]
|
||||
|
||||
return await hass.async_add_executor_job(
|
||||
self._sorted_significant_states_json,
|
||||
hass,
|
||||
start_time,
|
||||
end_time,
|
||||
entity_ids,
|
||||
include_start_time_state,
|
||||
significant_changes_only,
|
||||
return cast(
|
||||
web.Response,
|
||||
await hass.async_add_executor_job(
|
||||
self._sorted_significant_states_json,
|
||||
hass,
|
||||
start_time,
|
||||
end_time,
|
||||
entity_ids,
|
||||
include_start_time_state,
|
||||
significant_changes_only,
|
||||
),
|
||||
)
|
||||
|
||||
def _sorted_significant_states_json(
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
"""Decorator for view methods to help with data validation."""
|
||||
from functools import wraps
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import HTTP_BAD_REQUEST
|
||||
|
||||
# mypy: allow-untyped-defs
|
||||
from .view import HomeAssistantView
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -20,7 +22,7 @@ class RequestDataValidator:
|
|||
Will return a 400 if no JSON provided or doesn't match schema.
|
||||
"""
|
||||
|
||||
def __init__(self, schema, allow_empty=False):
|
||||
def __init__(self, schema: vol.Schema, allow_empty: bool = False) -> None:
|
||||
"""Initialize the decorator."""
|
||||
if isinstance(schema, dict):
|
||||
schema = vol.Schema(schema)
|
||||
|
@ -28,11 +30,15 @@ class RequestDataValidator:
|
|||
self._schema = schema
|
||||
self._allow_empty = allow_empty
|
||||
|
||||
def __call__(self, method):
|
||||
def __call__(
|
||||
self, method: Callable[..., Awaitable[web.StreamResponse]]
|
||||
) -> Callable:
|
||||
"""Decorate a function."""
|
||||
|
||||
@wraps(method)
|
||||
async def wrapper(view, request, *args, **kwargs):
|
||||
async def wrapper(
|
||||
view: HomeAssistantView, request: web.Request, *args: Any, **kwargs: Any
|
||||
) -> web.StreamResponse:
|
||||
"""Wrap a request handler with data validation."""
|
||||
data = None
|
||||
try:
|
||||
|
|
|
@ -2,9 +2,10 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.typedefs import LooseHeaders
|
||||
from aiohttp.web_exceptions import (
|
||||
HTTPBadRequest,
|
||||
HTTPInternalServerError,
|
||||
|
@ -22,9 +23,6 @@ from .const import KEY_AUTHENTICATED, KEY_HASS, KEY_REAL_IP
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
|
||||
class HomeAssistantView:
|
||||
"""Base view for all views."""
|
||||
|
||||
|
@ -35,7 +33,7 @@ class HomeAssistantView:
|
|||
cors_allowed = False
|
||||
|
||||
@staticmethod
|
||||
def context(request):
|
||||
def context(request: web.Request) -> Context:
|
||||
"""Generate a context from a request."""
|
||||
user = request.get("hass_user")
|
||||
if user is None:
|
||||
|
@ -44,7 +42,9 @@ class HomeAssistantView:
|
|||
return Context(user_id=user.id)
|
||||
|
||||
@staticmethod
|
||||
def json(result, status_code=HTTP_OK, headers=None):
|
||||
def json(
|
||||
result: Any, status_code: int = HTTP_OK, headers: Optional[LooseHeaders] = None,
|
||||
) -> web.Response:
|
||||
"""Return a JSON response."""
|
||||
try:
|
||||
msg = json.dumps(
|
||||
|
@ -63,15 +63,19 @@ class HomeAssistantView:
|
|||
return response
|
||||
|
||||
def json_message(
|
||||
self, message, status_code=HTTP_OK, message_code=None, headers=None
|
||||
):
|
||||
self,
|
||||
message: str,
|
||||
status_code: int = HTTP_OK,
|
||||
message_code: Optional[str] = None,
|
||||
headers: Optional[LooseHeaders] = None,
|
||||
) -> web.Response:
|
||||
"""Return a JSON message response."""
|
||||
data = {"message": message}
|
||||
if message_code is not None:
|
||||
data["code"] = message_code
|
||||
return self.json(data, status_code, headers=headers)
|
||||
|
||||
def register(self, app, router):
|
||||
def register(self, app: web.Application, router: web.UrlDispatcher) -> None:
|
||||
"""Register the view with a router."""
|
||||
assert self.url is not None, "No url set for view"
|
||||
urls = [self.url] + self.extra_urls
|
||||
|
@ -95,13 +99,13 @@ class HomeAssistantView:
|
|||
app["allow_cors"](route)
|
||||
|
||||
|
||||
def request_handler_factory(view, handler):
|
||||
def request_handler_factory(view: HomeAssistantView, handler: Callable) -> Callable:
|
||||
"""Wrap the handler classes."""
|
||||
assert asyncio.iscoroutinefunction(handler) or is_callback(
|
||||
handler
|
||||
), "Handler should be a coroutine or a callback."
|
||||
|
||||
async def handle(request):
|
||||
async def handle(request: web.Request) -> web.StreamResponse:
|
||||
"""Handle incoming request."""
|
||||
if not request.app[KEY_HASS].is_running:
|
||||
return web.Response(status=503)
|
||||
|
@ -139,15 +143,17 @@ def request_handler_factory(view, handler):
|
|||
if isinstance(result, tuple):
|
||||
result, status_code = result
|
||||
|
||||
if isinstance(result, str):
|
||||
result = result.encode("utf-8")
|
||||
if isinstance(result, bytes):
|
||||
bresult = result
|
||||
elif isinstance(result, str):
|
||||
bresult = result.encode("utf-8")
|
||||
elif result is None:
|
||||
result = b""
|
||||
elif not isinstance(result, bytes):
|
||||
bresult = b""
|
||||
else:
|
||||
assert (
|
||||
False
|
||||
), f"Result should be None, string, bytes or Response. Got: {result}"
|
||||
|
||||
return web.Response(body=result, status=status_code)
|
||||
return web.Response(body=bresult, status=status_code)
|
||||
|
||||
return handle
|
||||
|
|
|
@ -200,7 +200,7 @@ class MailboxPlatformsView(MailboxView):
|
|||
url = "/api/mailbox/platforms"
|
||||
name = "api:mailbox:platforms"
|
||||
|
||||
async def get(self, request):
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Retrieve list of platforms."""
|
||||
platforms = []
|
||||
for mailbox in self.mailboxes:
|
||||
|
|
|
@ -12,6 +12,7 @@ from urllib.parse import urlparse
|
|||
|
||||
from aiohttp import web
|
||||
from aiohttp.hdrs import CACHE_CONTROL, CONTENT_TYPE
|
||||
from aiohttp.typedefs import LooseHeaders
|
||||
import async_timeout
|
||||
import voluptuous as vol
|
||||
|
||||
|
@ -863,7 +864,7 @@ class MediaPlayerImageView(HomeAssistantView):
|
|||
"""Initialize a media player view."""
|
||||
self.component = component
|
||||
|
||||
async def get(self, request, entity_id):
|
||||
async def get(self, request: web.Request, entity_id: str) -> web.Response:
|
||||
"""Start a get request."""
|
||||
player = self.component.get_entity(entity_id)
|
||||
if player is None:
|
||||
|
@ -883,7 +884,7 @@ class MediaPlayerImageView(HomeAssistantView):
|
|||
if data is None:
|
||||
return web.Response(status=HTTP_INTERNAL_SERVER_ERROR)
|
||||
|
||||
headers = {CACHE_CONTROL: "max-age=3600"}
|
||||
headers: LooseHeaders = {CACHE_CONTROL: "max-age=3600"}
|
||||
return web.Response(body=data, content_type=content_type, headers=headers)
|
||||
|
||||
|
||||
|
|
|
@ -530,7 +530,7 @@ class TextToSpeechUrlView(HomeAssistantView):
|
|||
"""Initialize a tts view."""
|
||||
self.tts = tts
|
||||
|
||||
async def post(self, request):
|
||||
async def post(self, request: web.Request) -> web.Response:
|
||||
"""Generate speech and provide url."""
|
||||
try:
|
||||
data = await request.json()
|
||||
|
@ -570,7 +570,7 @@ class TextToSpeechView(HomeAssistantView):
|
|||
"""Initialize a tts view."""
|
||||
self.tts = tts
|
||||
|
||||
async def get(self, request, filename):
|
||||
async def get(self, request: web.Request, filename: str) -> web.Response:
|
||||
"""Start a get request."""
|
||||
try:
|
||||
content, data = await self.tts.async_read_tts(filename)
|
||||
|
|
|
@ -42,7 +42,7 @@ class WebsocketAPIView(HomeAssistantView):
|
|||
url = URL
|
||||
requires_auth = False
|
||||
|
||||
async def get(self, request):
|
||||
async def get(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle an incoming websocket connection."""
|
||||
return await WebSocketHandler(request.app["hass"], request).async_handle()
|
||||
|
||||
|
@ -148,7 +148,7 @@ class WebSocketHandler:
|
|||
self._handle_task.cancel()
|
||||
self._writer_task.cancel()
|
||||
|
||||
async def async_handle(self):
|
||||
async def async_handle(self) -> web.WebSocketResponse:
|
||||
"""Handle a websocket response."""
|
||||
request = self.request
|
||||
wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
"""Helpers for the data entry flow."""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from aiohttp import web
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
|
@ -8,18 +11,16 @@ from homeassistant.components.http.data_validator import RequestDataValidator
|
|||
from homeassistant.const import HTTP_NOT_FOUND
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs
|
||||
|
||||
|
||||
class _BaseFlowManagerView(HomeAssistantView):
|
||||
"""Foundation for flow manager views."""
|
||||
|
||||
def __init__(self, flow_mgr):
|
||||
def __init__(self, flow_mgr: data_entry_flow.FlowManager) -> None:
|
||||
"""Initialize the flow manager index view."""
|
||||
self._flow_mgr = flow_mgr
|
||||
|
||||
# pylint: disable=no-self-use
|
||||
def _prepare_result_json(self, result):
|
||||
def _prepare_result_json(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert result to JSON."""
|
||||
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||
data = result.copy()
|
||||
|
@ -57,7 +58,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
|
|||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
)
|
||||
async def post(self, request, data):
|
||||
async def post(self, request: web.Request, data: Dict[str, Any]) -> web.Response:
|
||||
"""Handle a POST request."""
|
||||
if isinstance(data["handler"], list):
|
||||
handler = tuple(data["handler"])
|
||||
|
@ -66,7 +67,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
|
|||
|
||||
try:
|
||||
result = await self._flow_mgr.async_init(
|
||||
handler,
|
||||
handler, # type: ignore
|
||||
context={
|
||||
"source": config_entries.SOURCE_USER,
|
||||
"show_advanced_options": data["show_advanced_options"],
|
||||
|
@ -85,7 +86,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
|
|||
class FlowManagerResourceView(_BaseFlowManagerView):
|
||||
"""View to interact with the flow manager."""
|
||||
|
||||
async def get(self, request, flow_id):
|
||||
async def get(self, request: web.Request, flow_id: str) -> web.Response:
|
||||
"""Get the current state of a data_entry_flow."""
|
||||
try:
|
||||
result = await self._flow_mgr.async_configure(flow_id)
|
||||
|
@ -97,7 +98,9 @@ class FlowManagerResourceView(_BaseFlowManagerView):
|
|||
return self.json(result)
|
||||
|
||||
@RequestDataValidator(vol.Schema(dict), allow_empty=True)
|
||||
async def post(self, request, flow_id, data):
|
||||
async def post(
|
||||
self, request: web.Request, flow_id: str, data: Dict[str, Any]
|
||||
) -> web.Response:
|
||||
"""Handle a POST request."""
|
||||
try:
|
||||
result = await self._flow_mgr.async_configure(flow_id, data)
|
||||
|
@ -110,7 +113,7 @@ class FlowManagerResourceView(_BaseFlowManagerView):
|
|||
|
||||
return self.json(result)
|
||||
|
||||
async def delete(self, request, flow_id):
|
||||
async def delete(self, request: web.Request, flow_id: str) -> web.Response:
|
||||
"""Cancel a flow in progress."""
|
||||
try:
|
||||
self._flow_mgr.async_abort(flow_id)
|
||||
|
|
Loading…
Add table
Reference in a new issue