Fix Tibber get_prices when called with aware datetime (#123289)

* Tibber: Add extra test to expose aware/naive datetime issue

* Tibber: Fix get_prices action not working with aware datetimes

* Tibber: Simplify comparison

* Tibber: Combine timezone tests into single parametrized one

* Tibber: Split test again to prevent if statement
This commit is contained in:
functionpointer 2024-10-02 08:43:31 +02:00 committed by GitHub
parent cd090ff000
commit 5bd2d27488
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 78 additions and 9 deletions

View file

@ -3,7 +3,7 @@
from __future__ import annotations
import datetime as dt
from datetime import date, datetime
from datetime import datetime
from functools import partial
from typing import Any, Final
@ -61,27 +61,24 @@ async def __get_prices(call: ServiceCall, *, hass: HomeAssistant) -> ServiceResp
]
selected_data = [
price
for price in price_data
if price["start_time"].replace(tzinfo=None) >= start
and price["start_time"].replace(tzinfo=None) < end
price for price in price_data if start <= price["start_time"] < end
]
tibber_prices[home_nickname] = selected_data
return {"prices": tibber_prices}
def __get_date(date_input: str | None, mode: str | None) -> date | datetime:
def __get_date(date_input: str | None, mode: str | None) -> datetime:
"""Get date."""
if not date_input:
if mode == "end":
increment = dt.timedelta(days=1)
else:
increment = dt.timedelta()
return datetime.fromisoformat(dt_util.now().date().isoformat()) + increment
return dt_util.start_of_local_day() + increment
if value := dt_util.parse_datetime(date_input):
return value
return dt_util.as_local(value)
raise ServiceValidationError(
"Invalid datetime provided.",

View file

@ -11,8 +11,11 @@ from homeassistant.components.tibber.const import DOMAIN
from homeassistant.components.tibber.services import PRICE_SERVICE_NAME, __get_prices
from homeassistant.core import ServiceCall
from homeassistant.exceptions import ServiceValidationError
from homeassistant.util import dt as dt_util
STARTTIME = dt.datetime.fromtimestamp(1615766400)
STARTTIME = dt.datetime.fromtimestamp(1615766400).replace(
tzinfo=dt_util.get_default_time_zone()
)
def generate_mock_home_data():
@ -246,6 +249,75 @@ async def test_get_prices_start_tomorrow(
}
@pytest.mark.parametrize(
"start_time",
[
STARTTIME.isoformat(),
STARTTIME.replace(tzinfo=None).isoformat(),
(STARTTIME + dt.timedelta(hours=4))
.replace(tzinfo=dt.timezone(dt.timedelta(hours=4)))
.isoformat(),
],
)
async def test_get_prices_with_timezones(
freezer: FrozenDateTimeFactory,
start_time: str,
) -> None:
"""Test __get_prices with timezone and without."""
freezer.move_to(STARTTIME)
call = ServiceCall(DOMAIN, PRICE_SERVICE_NAME, {"start": start_time})
result = await __get_prices(call, hass=create_mock_hass())
assert result == {
"prices": {
"first_home": [
{
"start_time": STARTTIME,
"price": 0.46914,
"level": "VERY_EXPENSIVE",
},
{
"start_time": STARTTIME + dt.timedelta(hours=1),
"price": 0.46914,
"level": "VERY_EXPENSIVE",
},
],
"second_home": [
{
"start_time": STARTTIME,
"price": 0.46914,
"level": "VERY_EXPENSIVE",
},
{
"start_time": STARTTIME + dt.timedelta(hours=1),
"price": 0.46914,
"level": "VERY_EXPENSIVE",
},
],
}
}
@pytest.mark.parametrize(
"start_time",
[
(STARTTIME + dt.timedelta(hours=4)).isoformat(),
(STARTTIME + dt.timedelta(hours=4)).replace(tzinfo=None).isoformat(),
],
)
async def test_get_prices_with_wrong_timezones(
freezer: FrozenDateTimeFactory,
start_time: str,
) -> None:
"""Test __get_prices with timezone and without, while expecting it to fail."""
freezer.move_to(STARTTIME)
call = ServiceCall(DOMAIN, PRICE_SERVICE_NAME, {"start": start_time})
result = await __get_prices(call, hass=create_mock_hass())
assert result == {"prices": {"first_home": [], "second_home": []}}
async def test_get_prices_invalid_input() -> None:
"""Test __get_prices with invalid input."""