Add integration type (#68349)

This commit is contained in:
Paulus Schoutsen 2022-03-20 20:38:13 -07:00 committed by GitHub
parent 4f9df1fd0f
commit 3213091b8d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 608 additions and 499 deletions

View file

@ -1,13 +1,14 @@
"""Http views to control the config manager.""" """Http views to control the config manager."""
from __future__ import annotations from __future__ import annotations
import asyncio
from http import HTTPStatus from http import HTTPStatus
from aiohttp import web from aiohttp import web
import aiohttp.web_exceptions import aiohttp.web_exceptions
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries, data_entry_flow from homeassistant import config_entries, data_entry_flow, loader
from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES, POLICY_EDIT from homeassistant.auth.permissions.const import CAT_CONFIG_ENTRIES, POLICY_EDIT
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
@ -48,11 +49,36 @@ class ConfigManagerEntryIndexView(HomeAssistantView):
async def get(self, request): async def get(self, request):
"""List available config entries.""" """List available config entries."""
hass = request.app["hass"] hass: HomeAssistant = request.app["hass"]
return self.json( kwargs = {}
[entry_json(entry) for entry in hass.config_entries.async_entries()] if "domain" in request.query:
kwargs["domain"] = request.query["domain"]
entries = hass.config_entries.async_entries(**kwargs)
if "type" not in request.query:
return self.json([entry_json(entry) for entry in entries])
integrations = {}
type_filter = request.query["type"]
# Fetch all the integrations so we can check their type
for integration in await asyncio.gather(
*(
loader.async_get_integration(hass, domain)
for domain in {entry.domain for entry in entries}
) )
):
integrations[integration.domain] = integration
entries = [
entry
for entry in entries
if integrations[entry.domain].integration_type == type_filter
]
return self.json([entry_json(entry) for entry in entries])
class ConfigManagerEntryResourceView(HomeAssistantView): class ConfigManagerEntryResourceView(HomeAssistantView):
@ -179,7 +205,10 @@ class ConfigManagerAvailableFlowView(HomeAssistantView):
async def get(self, request): async def get(self, request):
"""List available flow handlers.""" """List available flow handlers."""
hass = request.app["hass"] hass = request.app["hass"]
return self.json(await async_get_config_flows(hass)) kwargs = {}
if "type" in request.query:
kwargs["type_filter"] = request.query["type"]
return self.json(await async_get_config_flows(hass, **kwargs))
class OptionManagerFlowIndexView(FlowManagerIndexView): class OptionManagerFlowIndexView(FlowManagerIndexView):

View file

@ -1,5 +1,6 @@
{ {
"domain": "derivative", "domain": "derivative",
"integration_type": "helper",
"name": "Derivative", "name": "Derivative",
"documentation": "https://www.home-assistant.io/integrations/derivative", "documentation": "https://www.home-assistant.io/integrations/derivative",
"codeowners": [ "codeowners": [

View file

@ -5,7 +5,8 @@ To update, run python3 -m script.hassfest
# fmt: off # fmt: off
FLOWS = [ FLOWS = {
"integration": [
"abode", "abode",
"accuweather", "accuweather",
"acmeda", "acmeda",
@ -66,7 +67,6 @@ FLOWS = [
"daikin", "daikin",
"deconz", "deconz",
"denonavr", "denonavr",
"derivative",
"devolo_home_control", "devolo_home_control",
"devolo_home_network", "devolo_home_network",
"dexcom", "dexcom",
@ -395,4 +395,8 @@ FLOWS = [
"zha", "zha",
"zwave_js", "zwave_js",
"zwave_me" "zwave_me"
],
"helper": [
"derivative"
] ]
}

View file

@ -16,7 +16,7 @@ import logging
import pathlib import pathlib
import sys import sys
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Any, TypedDict, TypeVar, cast from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeVar, cast
from awesomeversion import ( from awesomeversion import (
AwesomeVersion, AwesomeVersion,
@ -87,6 +87,7 @@ class Manifest(TypedDict, total=False):
name: str name: str
disabled: str disabled: str
domain: str domain: str
integration_type: Literal["integration", "helper"]
dependencies: list[str] dependencies: list[str]
after_dependencies: list[str] after_dependencies: list[str]
requirements: list[str] requirements: list[str]
@ -180,20 +181,29 @@ async def async_get_custom_components(
return cast(dict[str, "Integration"], reg_or_evt) return cast(dict[str, "Integration"], reg_or_evt)
async def async_get_config_flows(hass: HomeAssistant) -> set[str]: async def async_get_config_flows(
hass: HomeAssistant,
type_filter: Literal["helper", "integration"] | None = None,
) -> set[str]:
"""Return cached list of config flows.""" """Return cached list of config flows."""
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
from .generated.config_flows import FLOWS from .generated.config_flows import FLOWS
flows: set[str] = set()
flows.update(FLOWS)
integrations = await async_get_custom_components(hass) integrations = await async_get_custom_components(hass)
flows: set[str] = set()
if type_filter is not None:
flows.update(FLOWS[type_filter])
else:
for type_flows in FLOWS.values():
flows.update(type_flows)
flows.update( flows.update(
[ [
integration.domain integration.domain
for integration in integrations.values() for integration in integrations.values()
if integration.config_flow if integration.config_flow
and (type_filter is None or integration.integration_type == type_filter)
] ]
) )
@ -474,6 +484,11 @@ class Integration:
"""Return the integration IoT Class.""" """Return the integration IoT Class."""
return self.manifest.get("iot_class") return self.manifest.get("iot_class")
@property
def integration_type(self) -> Literal["integration", "helper"]:
"""Return the integration type."""
return self.manifest.get("integration_type", "integration")
@property @property
def mqtt(self) -> list[str] | None: def mqtt(self) -> list[str] | None:
"""Return Integration MQTT entries.""" """Return Integration MQTT entries."""

View file

@ -69,7 +69,10 @@ def validate_integration(config: Config, integration: Integration):
def generate_and_validate(integrations: dict[str, Integration], config: Config): def generate_and_validate(integrations: dict[str, Integration], config: Config):
"""Validate and generate config flow data.""" """Validate and generate config flow data."""
domains = [] domains = {
"integration": [],
"helper": [],
}
for domain in sorted(integrations): for domain in sorted(integrations):
integration = integrations[domain] integration = integrations[domain]
@ -79,7 +82,7 @@ def generate_and_validate(integrations: dict[str, Integration], config: Config):
validate_integration(config, integration) validate_integration(config, integration)
domains.append(domain) domains[integration.integration_type].append(domain)
return BASE.format(json.dumps(domains, indent=4)) return BASE.format(json.dumps(domains, indent=4))

View file

@ -152,6 +152,7 @@ MANIFEST_SCHEMA = vol.Schema(
{ {
vol.Required("domain"): str, vol.Required("domain"): str,
vol.Required("name"): str, vol.Required("name"): str,
vol.Optional("integration_type"): "helper",
vol.Optional("config_flow"): bool, vol.Optional("config_flow"): bool,
vol.Optional("mqtt"): [str], vol.Optional("mqtt"): [str],
vol.Optional("zeroconf"): [ vol.Optional("zeroconf"): [

View file

@ -112,6 +112,11 @@ class Integration:
"""List of dependencies.""" """List of dependencies."""
return self.manifest.get("dependencies", []) return self.manifest.get("dependencies", [])
@property
def integration_type(self) -> str:
"""Get integration_type."""
return self.manifest.get("integration_type", "integration")
def add_error(self, *args: Any, **kwargs: Any) -> None: def add_error(self, *args: Any, **kwargs: Any) -> None:
"""Add an error.""" """Add an error."""
self.errors.append(Error(*args, **kwargs)) self.errors.append(Error(*args, **kwargs))

View file

@ -23,6 +23,13 @@ from tests.common import (
) )
@pytest.fixture
def clear_handlers():
"""Clear config entry handlers."""
with patch.dict(HANDLERS, clear=True):
yield
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_test_component(hass): def mock_test_component(hass):
"""Ensure a component called 'test' exists.""" """Ensure a component called 'test' exists."""
@ -30,16 +37,20 @@ def mock_test_component(hass):
@pytest.fixture @pytest.fixture
def client(hass, hass_client): async def client(hass, hass_client):
"""Fixture that can interact with the config manager API.""" """Fixture that can interact with the config manager API."""
hass.loop.run_until_complete(async_setup_component(hass, "http", {})) await async_setup_component(hass, "http", {})
hass.loop.run_until_complete(config_entries.async_setup(hass)) await config_entries.async_setup(hass)
yield hass.loop.run_until_complete(hass_client()) return await hass_client()
async def test_get_entries(hass, client): async def test_get_entries(hass, client, clear_handlers):
"""Test get entries.""" """Test get entries."""
with patch.dict(HANDLERS, clear=True): mock_integration(hass, MockModule("comp1"))
mock_integration(
hass, MockModule("comp2", partial_manifest={"integration_type": "helper"})
)
mock_integration(hass, MockModule("comp3"))
@HANDLERS.register("comp1") @HANDLERS.register("comp1")
class Comp1ConfigFlow: class Comp1ConfigFlow:
@ -129,6 +140,31 @@ async def test_get_entries(hass, client):
}, },
] ]
resp = await client.get("/api/config/config_entries/entry?domain=comp3")
assert resp.status == HTTPStatus.OK
data = await resp.json()
assert len(data) == 1
assert data[0]["domain"] == "comp3"
resp = await client.get("/api/config/config_entries/entry?domain=comp3&type=helper")
assert resp.status == HTTPStatus.OK
data = await resp.json()
assert len(data) == 0
resp = await client.get(
"/api/config/config_entries/entry?domain=comp3&type=integration"
)
assert resp.status == HTTPStatus.OK
data = await resp.json()
assert len(data) == 1
resp = await client.get("/api/config/config_entries/entry?type=integration")
assert resp.status == HTTPStatus.OK
data = await resp.json()
assert len(data) == 2
assert data[0]["domain"] == "comp1"
assert data[1]["domain"] == "comp3"
async def test_remove_entry(hass, client): async def test_remove_entry(hass, client):
"""Test removing an entry via the API.""" """Test removing an entry via the API."""
@ -224,13 +260,28 @@ async def test_reload_entry_in_setup_retry(hass, client, hass_admin_user):
assert len(hass.config_entries.async_entries()) == 1 assert len(hass.config_entries.async_entries()) == 1
async def test_available_flows(hass, client): @pytest.mark.parametrize(
"type_filter,result",
(
(None, {"hello", "another", "world"}),
("integration", {"hello", "another"}),
("helper", {"world"}),
),
)
async def test_available_flows(hass, client, type_filter, result):
"""Test querying the available flows.""" """Test querying the available flows."""
with patch.object(config_flows, "FLOWS", ["hello", "world"]): with patch.object(
resp = await client.get("/api/config/config_entries/flow_handlers") config_flows,
"FLOWS",
{"integration": ["hello", "another"], "helper": ["world"]},
):
resp = await client.get(
"/api/config/config_entries/flow_handlers",
params={"type": type_filter} if type_filter else {},
)
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
data = await resp.json() data = await resp.json()
assert set(data) == {"hello", "world"} assert set(data) == result
############################ ############################

View file

@ -15,7 +15,7 @@ from homeassistant.setup import async_setup_component
@pytest.fixture @pytest.fixture
def mock_config_flows(): def mock_config_flows():
"""Mock the config flows.""" """Mock the config flows."""
flows = [] flows = {"integration": [], "helper": {}}
with patch.object(config_flows, "FLOWS", flows): with patch.object(config_flows, "FLOWS", flows):
yield flows yield flows
@ -124,7 +124,7 @@ async def test_get_translations(hass, mock_config_flows, enable_custom_integrati
async def test_get_translations_loads_config_flows(hass, mock_config_flows): async def test_get_translations_loads_config_flows(hass, mock_config_flows):
"""Test the get translations helper loads config flow translations.""" """Test the get translations helper loads config flow translations."""
mock_config_flows.append("component1") mock_config_flows["integration"].append("component1")
integration = Mock(file_path=pathlib.Path(__file__)) integration = Mock(file_path=pathlib.Path(__file__))
integration.name = "Component 1" integration.name = "Component 1"
@ -153,7 +153,7 @@ async def test_get_translations_loads_config_flows(hass, mock_config_flows):
assert "component1" not in hass.config.components assert "component1" not in hass.config.components
mock_config_flows.append("component2") mock_config_flows["integration"].append("component2")
integration = Mock(file_path=pathlib.Path(__file__)) integration = Mock(file_path=pathlib.Path(__file__))
integration.name = "Component 2" integration.name = "Component 2"