diff --git a/homeassistant/components/tibber/services.py b/homeassistant/components/tibber/services.py index 82353bb78d7..35facbcd545 100644 --- a/homeassistant/components/tibber/services.py +++ b/homeassistant/components/tibber/services.py @@ -3,7 +3,7 @@ from __future__ import annotations import datetime as dt -from datetime import date, datetime +from datetime import datetime from functools import partial from typing import Any, Final @@ -61,27 +61,24 @@ async def __get_prices(call: ServiceCall, *, hass: HomeAssistant) -> ServiceResp ] selected_data = [ - price - for price in price_data - if price["start_time"].replace(tzinfo=None) >= start - and price["start_time"].replace(tzinfo=None) < end + price for price in price_data if start <= price["start_time"] < end ] tibber_prices[home_nickname] = selected_data return {"prices": tibber_prices} -def __get_date(date_input: str | None, mode: str | None) -> date | datetime: +def __get_date(date_input: str | None, mode: str | None) -> datetime: """Get date.""" if not date_input: if mode == "end": increment = dt.timedelta(days=1) else: increment = dt.timedelta() - return datetime.fromisoformat(dt_util.now().date().isoformat()) + increment + return dt_util.start_of_local_day() + increment if value := dt_util.parse_datetime(date_input): - return value + return dt_util.as_local(value) raise ServiceValidationError( "Invalid datetime provided.", diff --git a/tests/components/tibber/test_services.py b/tests/components/tibber/test_services.py index e9bee3ba31f..1df91d719fe 100644 --- a/tests/components/tibber/test_services.py +++ b/tests/components/tibber/test_services.py @@ -11,8 +11,11 @@ from homeassistant.components.tibber.const import DOMAIN from homeassistant.components.tibber.services import PRICE_SERVICE_NAME, __get_prices from homeassistant.core import ServiceCall from homeassistant.exceptions import ServiceValidationError +from homeassistant.util import dt as dt_util -STARTTIME = dt.datetime.fromtimestamp(1615766400) +STARTTIME = dt.datetime.fromtimestamp(1615766400).replace( + tzinfo=dt_util.get_default_time_zone() +) def generate_mock_home_data(): @@ -246,6 +249,75 @@ async def test_get_prices_start_tomorrow( } +@pytest.mark.parametrize( + "start_time", + [ + STARTTIME.isoformat(), + STARTTIME.replace(tzinfo=None).isoformat(), + (STARTTIME + dt.timedelta(hours=4)) + .replace(tzinfo=dt.timezone(dt.timedelta(hours=4))) + .isoformat(), + ], +) +async def test_get_prices_with_timezones( + freezer: FrozenDateTimeFactory, + start_time: str, +) -> None: + """Test __get_prices with timezone and without.""" + freezer.move_to(STARTTIME) + call = ServiceCall(DOMAIN, PRICE_SERVICE_NAME, {"start": start_time}) + + result = await __get_prices(call, hass=create_mock_hass()) + + assert result == { + "prices": { + "first_home": [ + { + "start_time": STARTTIME, + "price": 0.46914, + "level": "VERY_EXPENSIVE", + }, + { + "start_time": STARTTIME + dt.timedelta(hours=1), + "price": 0.46914, + "level": "VERY_EXPENSIVE", + }, + ], + "second_home": [ + { + "start_time": STARTTIME, + "price": 0.46914, + "level": "VERY_EXPENSIVE", + }, + { + "start_time": STARTTIME + dt.timedelta(hours=1), + "price": 0.46914, + "level": "VERY_EXPENSIVE", + }, + ], + } + } + + +@pytest.mark.parametrize( + "start_time", + [ + (STARTTIME + dt.timedelta(hours=4)).isoformat(), + (STARTTIME + dt.timedelta(hours=4)).replace(tzinfo=None).isoformat(), + ], +) +async def test_get_prices_with_wrong_timezones( + freezer: FrozenDateTimeFactory, + start_time: str, +) -> None: + """Test __get_prices with timezone and without, while expecting it to fail.""" + freezer.move_to(STARTTIME) + call = ServiceCall(DOMAIN, PRICE_SERVICE_NAME, {"start": start_time}) + + result = await __get_prices(call, hass=create_mock_hass()) + assert result == {"prices": {"first_home": [], "second_home": []}} + + async def test_get_prices_invalid_input() -> None: """Test __get_prices with invalid input."""