Changed setup of EnergyZero services (#106224)

* Changed setup of energyzero services

* PR review updates

* Dict access instead of get

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Added tests for unloaded state

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Robert Groot 2024-01-02 13:24:17 +01:00 committed by GitHub
parent 0d7bb2d124
commit 2df9e5e7b9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 155 additions and 15 deletions

View file

@ -5,12 +5,23 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN from .const import DOMAIN
from .coordinator import EnergyZeroDataUpdateCoordinator from .coordinator import EnergyZeroDataUpdateCoordinator
from .services import async_register_services from .services import async_setup_services
PLATFORMS = [Platform.SENSOR] PLATFORMS = [Platform.SENSOR]
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up EnergyZero services."""
async_setup_services(hass)
return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
@ -27,8 +38,6 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
async_register_services(hass, coordinator)
return True return True

View file

@ -9,6 +9,7 @@ from typing import Final
from energyzero import Electricity, Gas, VatOption from energyzero import Electricity, Gas, VatOption
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.core import ( from homeassistant.core import (
HomeAssistant, HomeAssistant,
ServiceCall, ServiceCall,
@ -17,11 +18,13 @@ from homeassistant.core import (
callback, callback,
) )
from homeassistant.exceptions import ServiceValidationError from homeassistant.exceptions import ServiceValidationError
from homeassistant.helpers import selector
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from .const import DOMAIN from .const import DOMAIN
from .coordinator import EnergyZeroDataUpdateCoordinator from .coordinator import EnergyZeroDataUpdateCoordinator
ATTR_CONFIG_ENTRY: Final = "config_entry"
ATTR_START: Final = "start" ATTR_START: Final = "start"
ATTR_END: Final = "end" ATTR_END: Final = "end"
ATTR_INCL_VAT: Final = "incl_vat" ATTR_INCL_VAT: Final = "incl_vat"
@ -30,6 +33,11 @@ GAS_SERVICE_NAME: Final = "get_gas_prices"
ENERGY_SERVICE_NAME: Final = "get_energy_prices" ENERGY_SERVICE_NAME: Final = "get_energy_prices"
SERVICE_SCHEMA: Final = vol.Schema( SERVICE_SCHEMA: Final = vol.Schema(
{ {
vol.Required(ATTR_CONFIG_ENTRY): selector.ConfigEntrySelector(
{
"integration": DOMAIN,
}
),
vol.Required(ATTR_INCL_VAT): bool, vol.Required(ATTR_INCL_VAT): bool,
vol.Optional(ATTR_START): str, vol.Optional(ATTR_START): str,
vol.Optional(ATTR_END): str, vol.Optional(ATTR_END): str,
@ -75,12 +83,43 @@ def __serialize_prices(prices: Electricity | Gas) -> ServiceResponse:
} }
def __get_coordinator(
hass: HomeAssistant, call: ServiceCall
) -> EnergyZeroDataUpdateCoordinator:
"""Get the coordinator from the entry."""
entry_id: str = call.data[ATTR_CONFIG_ENTRY]
entry: ConfigEntry | None = hass.config_entries.async_get_entry(entry_id)
if not entry:
raise ServiceValidationError(
f"Invalid config entry: {entry_id}",
translation_domain=DOMAIN,
translation_key="invalid_config_entry",
translation_placeholders={
"config_entry": entry_id,
},
)
if entry.state != ConfigEntryState.LOADED:
raise ServiceValidationError(
f"{entry.title} is not loaded",
translation_domain=DOMAIN,
translation_key="unloaded_config_entry",
translation_placeholders={
"config_entry": entry.title,
},
)
return hass.data[DOMAIN][entry_id]
async def __get_prices( async def __get_prices(
call: ServiceCall, call: ServiceCall,
*, *,
coordinator: EnergyZeroDataUpdateCoordinator, hass: HomeAssistant,
price_type: PriceType, price_type: PriceType,
) -> ServiceResponse: ) -> ServiceResponse:
coordinator = __get_coordinator(hass, call)
start = __get_date(call.data.get(ATTR_START)) start = __get_date(call.data.get(ATTR_START))
end = __get_date(call.data.get(ATTR_END)) end = __get_date(call.data.get(ATTR_END))
@ -108,22 +147,20 @@ async def __get_prices(
@callback @callback
def async_register_services( def async_setup_services(hass: HomeAssistant) -> None:
hass: HomeAssistant, coordinator: EnergyZeroDataUpdateCoordinator
):
"""Set up EnergyZero services.""" """Set up EnergyZero services."""
hass.services.async_register( hass.services.async_register(
DOMAIN, DOMAIN,
GAS_SERVICE_NAME, GAS_SERVICE_NAME,
partial(__get_prices, coordinator=coordinator, price_type=PriceType.GAS), partial(__get_prices, hass=hass, price_type=PriceType.GAS),
schema=SERVICE_SCHEMA, schema=SERVICE_SCHEMA,
supports_response=SupportsResponse.ONLY, supports_response=SupportsResponse.ONLY,
) )
hass.services.async_register( hass.services.async_register(
DOMAIN, DOMAIN,
ENERGY_SERVICE_NAME, ENERGY_SERVICE_NAME,
partial(__get_prices, coordinator=coordinator, price_type=PriceType.ENERGY), partial(__get_prices, hass=hass, price_type=PriceType.ENERGY),
schema=SERVICE_SCHEMA, schema=SERVICE_SCHEMA,
supports_response=SupportsResponse.ONLY, supports_response=SupportsResponse.ONLY,
) )

View file

@ -1,5 +1,10 @@
get_gas_prices: get_gas_prices:
fields: fields:
config_entry:
required: true
selector:
config_entry:
integration: energyzero
incl_vat: incl_vat:
required: true required: true
default: true default: true
@ -17,6 +22,11 @@ get_gas_prices:
datetime: datetime:
get_energy_prices: get_energy_prices:
fields: fields:
config_entry:
required: true
selector:
config_entry:
integration: energyzero
incl_vat: incl_vat:
required: true required: true
default: true default: true

View file

@ -12,6 +12,12 @@
"exceptions": { "exceptions": {
"invalid_date": { "invalid_date": {
"message": "Invalid date provided. Got {date}" "message": "Invalid date provided. Got {date}"
},
"invalid_config_entry": {
"message": "Invalid config entry provided. Got {config_entry}"
},
"unloaded_config_entry": {
"message": "Invalid config entry provided. {config_entry} is not loaded."
} }
}, },
"entity": { "entity": {
@ -50,6 +56,10 @@
"name": "Get gas prices", "name": "Get gas prices",
"description": "Request gas prices from EnergyZero.", "description": "Request gas prices from EnergyZero.",
"fields": { "fields": {
"config_entry": {
"name": "Config Entry",
"description": "The config entry to use for this service."
},
"incl_vat": { "incl_vat": {
"name": "Including VAT", "name": "Including VAT",
"description": "Include VAT in the prices." "description": "Include VAT in the prices."
@ -68,6 +78,10 @@
"name": "Get energy prices", "name": "Get energy prices",
"description": "Request energy prices from EnergyZero.", "description": "Request energy prices from EnergyZero.",
"fields": { "fields": {
"config_entry": {
"name": "[%key:component::energyzero::services::get_gas_prices::fields::config_entry::name%]",
"description": "[%key:component::energyzero::services::get_gas_prices::fields::config_entry::description%]"
},
"incl_vat": { "incl_vat": {
"name": "[%key:component::energyzero::services::get_gas_prices::fields::incl_vat::name%]", "name": "[%key:component::energyzero::services::get_gas_prices::fields::incl_vat::name%]",
"description": "[%key:component::energyzero::services::get_gas_prices::fields::incl_vat::description%]" "description": "[%key:component::energyzero::services::get_gas_prices::fields::incl_vat::description%]"

View file

@ -6,12 +6,15 @@ import voluptuous as vol
from homeassistant.components.energyzero.const import DOMAIN from homeassistant.components.energyzero.const import DOMAIN
from homeassistant.components.energyzero.services import ( from homeassistant.components.energyzero.services import (
ATTR_CONFIG_ENTRY,
ENERGY_SERVICE_NAME, ENERGY_SERVICE_NAME,
GAS_SERVICE_NAME, GAS_SERVICE_NAME,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ServiceValidationError from homeassistant.exceptions import ServiceValidationError
from tests.common import MockConfigEntry
@pytest.mark.usefixtures("init_integration") @pytest.mark.usefixtures("init_integration")
async def test_has_services( async def test_has_services(
@ -29,6 +32,7 @@ async def test_has_services(
@pytest.mark.parametrize("end", [{"end": "2023-01-01 00:00:00"}, {}]) @pytest.mark.parametrize("end", [{"end": "2023-01-01 00:00:00"}, {}])
async def test_service( async def test_service(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
service: str, service: str,
incl_vat: dict[str, bool], incl_vat: dict[str, bool],
@ -36,8 +40,9 @@ async def test_service(
end: dict[str, str], end: dict[str, str],
) -> None: ) -> None:
"""Test the EnergyZero Service.""" """Test the EnergyZero Service."""
entry = {ATTR_CONFIG_ENTRY: mock_config_entry.entry_id}
data = incl_vat | start | end data = entry | incl_vat | start | end
assert snapshot == await hass.services.async_call( assert snapshot == await hass.services.async_call(
DOMAIN, DOMAIN,
@ -48,32 +53,72 @@ async def test_service(
) )
@pytest.fixture
def config_entry_data(
mock_config_entry: MockConfigEntry, request: pytest.FixtureRequest
) -> dict[str, str]:
"""Fixture for the config entry."""
if "config_entry" in request.param and request.param["config_entry"] is True:
return {"config_entry": mock_config_entry.entry_id}
return request.param
@pytest.mark.usefixtures("init_integration") @pytest.mark.usefixtures("init_integration")
@pytest.mark.parametrize("service", [GAS_SERVICE_NAME, ENERGY_SERVICE_NAME]) @pytest.mark.parametrize("service", [GAS_SERVICE_NAME, ENERGY_SERVICE_NAME])
@pytest.mark.parametrize( @pytest.mark.parametrize(
("service_data", "error", "error_message"), ("config_entry_data", "service_data", "error", "error_message"),
[ [
({}, vol.er.Error, "required key not provided .+"), ({}, {}, vol.er.Error, "required key not provided .+"),
( (
{"config_entry": True},
{},
vol.er.Error,
"required key not provided .+",
),
(
{},
{"incl_vat": True},
vol.er.Error,
"required key not provided .+",
),
(
{"config_entry": True},
{"incl_vat": "incorrect vat"}, {"incl_vat": "incorrect vat"},
vol.er.Error, vol.er.Error,
"expected bool for dictionary value .+", "expected bool for dictionary value .+",
), ),
( (
{"incl_vat": True, "start": "incorrect date"}, {"config_entry": "incorrect entry"},
{"incl_vat": True},
ServiceValidationError,
"Invalid config entry.+",
),
(
{"config_entry": True},
{
"incl_vat": True,
"start": "incorrect date",
},
ServiceValidationError, ServiceValidationError,
"Invalid datetime provided.", "Invalid datetime provided.",
), ),
( (
{"incl_vat": True, "end": "incorrect date"}, {"config_entry": True},
{
"incl_vat": True,
"end": "incorrect date",
},
ServiceValidationError, ServiceValidationError,
"Invalid datetime provided.", "Invalid datetime provided.",
), ),
], ],
indirect=["config_entry_data"],
) )
async def test_service_validation( async def test_service_validation(
hass: HomeAssistant, hass: HomeAssistant,
service: str, service: str,
config_entry_data: dict[str, str],
service_data: dict[str, str], service_data: dict[str, str],
error: type[Exception], error: type[Exception],
error_message: str, error_message: str,
@ -84,7 +129,32 @@ async def test_service_validation(
await hass.services.async_call( await hass.services.async_call(
DOMAIN, DOMAIN,
service, service,
service_data, config_entry_data | service_data,
blocking=True,
return_response=True,
)
@pytest.mark.usefixtures("init_integration")
@pytest.mark.parametrize("service", [GAS_SERVICE_NAME, ENERGY_SERVICE_NAME])
async def test_service_called_with_unloaded_entry(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
service: str,
) -> None:
"""Test service calls with unloaded config entry."""
await mock_config_entry.async_unload(hass)
data = {"config_entry": mock_config_entry.entry_id, "incl_vat": True}
with pytest.raises(
ServiceValidationError, match=f"{mock_config_entry.title} is not loaded"
):
await hass.services.async_call(
DOMAIN,
service,
data,
blocking=True, blocking=True,
return_response=True, return_response=True,
) )