diff --git a/homeassistant/components/logbook/const.py b/homeassistant/components/logbook/const.py index 3f0c6599724..d20acb553cc 100644 --- a/homeassistant/components/logbook/const.py +++ b/homeassistant/components/logbook/const.py @@ -30,13 +30,11 @@ LOGBOOK_ENTRY_NAME = "name" LOGBOOK_ENTRY_STATE = "state" LOGBOOK_ENTRY_WHEN = "when" -ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED = {EVENT_LOGBOOK_ENTRY, EVENT_CALL_SERVICE} -ENTITY_EVENTS_WITHOUT_CONFIG_ENTRY = { - EVENT_LOGBOOK_ENTRY, - EVENT_AUTOMATION_TRIGGERED, - EVENT_SCRIPT_STARTED, -} +# Automation events that can affect an entity_id or device_id +AUTOMATION_EVENTS = {EVENT_AUTOMATION_TRIGGERED, EVENT_SCRIPT_STARTED} +# Events that are built-in to the logbook or core +BUILT_IN_EVENTS = {EVENT_LOGBOOK_ENTRY, EVENT_CALL_SERVICE} LOGBOOK_FILTERS = "logbook_filters" LOGBOOK_ENTITIES_FILTER = "entities_filter" diff --git a/homeassistant/components/logbook/helpers.py b/homeassistant/components/logbook/helpers.py index de021994b8d..eec60ebe740 100644 --- a/homeassistant/components/logbook/helpers.py +++ b/homeassistant/components/logbook/helpers.py @@ -7,6 +7,7 @@ from typing import Any from homeassistant.components.sensor import ATTR_STATE_CLASS from homeassistant.const import ( ATTR_DEVICE_ID, + ATTR_DOMAIN, ATTR_ENTITY_ID, ATTR_UNIT_OF_MEASUREMENT, EVENT_LOGBOOK_ENTRY, @@ -21,13 +22,10 @@ from homeassistant.core import ( is_callback, ) from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.helpers.entityfilter import EntityFilter from homeassistant.helpers.event import async_track_state_change_event -from .const import ( - ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED, - DOMAIN, - ENTITY_EVENTS_WITHOUT_CONFIG_ENTRY, -) +from .const import AUTOMATION_EVENTS, BUILT_IN_EVENTS, DOMAIN from .models import LazyEventPartialState @@ -41,6 +39,25 @@ def async_filter_entities(hass: HomeAssistant, entity_ids: list[str]) -> list[st ] +@callback +def _async_config_entries_for_ids( + hass: HomeAssistant, entity_ids: list[str] | None, device_ids: list[str] | None +) -> set[str]: + """Find the config entry ids for a set of entities or devices.""" + config_entry_ids: set[str] = set() + if entity_ids: + eng_reg = er.async_get(hass) + for entity_id in entity_ids: + if (entry := eng_reg.async_get(entity_id)) and entry.config_entry_id: + config_entry_ids.add(entry.config_entry_id) + if device_ids: + dev_reg = dr.async_get(hass) + for device_id in device_ids: + if (device := dev_reg.async_get(device_id)) and device.config_entries: + config_entry_ids |= device.config_entries + return config_entry_ids + + def async_determine_event_types( hass: HomeAssistant, entity_ids: list[str] | None, device_ids: list[str] | None ) -> tuple[str, ...]: @@ -49,42 +66,91 @@ def async_determine_event_types( str, tuple[str, Callable[[LazyEventPartialState], dict[str, Any]]] ] = hass.data.get(DOMAIN, {}) if not entity_ids and not device_ids: - return (*ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED, *external_events) - config_entry_ids: set[str] = set() - intrested_event_types: set[str] = set() + return (*BUILT_IN_EVENTS, *external_events) + interested_domains: set[str] = set() + for entry_id in _async_config_entries_for_ids(hass, entity_ids, device_ids): + if entry := hass.config_entries.async_get_entry(entry_id): + interested_domains.add(entry.domain) + + # + # automations and scripts can refer to entities or devices + # but they do not have a config entry so we need + # to add them since we have historically included + # them when matching only on entities + # + intrested_event_types: set[str] = { + external_event + for external_event, domain_call in external_events.items() + if domain_call[0] in interested_domains + } | AUTOMATION_EVENTS if entity_ids: - # - # Home Assistant doesn't allow firing events from - # entities so we have a limited list to check - # - # automations and scripts can refer to entities - # but they do not have a config entry so we need - # to add them. - # - # We also allow entity_ids to be recorded via - # manual logbook entries. - # - intrested_event_types |= ENTITY_EVENTS_WITHOUT_CONFIG_ENTRY + # We also allow entity_ids to be recorded via manual logbook entries. + intrested_event_types.add(EVENT_LOGBOOK_ENTRY) - if device_ids: - dev_reg = dr.async_get(hass) - for device_id in device_ids: - if (device := dev_reg.async_get(device_id)) and device.config_entries: - config_entry_ids |= device.config_entries - interested_domains: set[str] = set() - for entry_id in config_entry_ids: - if entry := hass.config_entries.async_get_entry(entry_id): - interested_domains.add(entry.domain) - for external_event, domain_call in external_events.items(): - if domain_call[0] in interested_domains: - intrested_event_types.add(external_event) + return tuple(intrested_event_types) - return tuple( - event_type - for event_type in (EVENT_LOGBOOK_ENTRY, *external_events) - if event_type in intrested_event_types - ) + +@callback +def extract_attr(source: dict[str, Any], attr: str) -> list[str]: + """Extract an attribute as a list or string.""" + if (value := source.get(attr)) is None: + return [] + if isinstance(value, list): + return value + return str(value).split(",") + + +@callback +def event_forwarder_filtered( + target: Callable[[Event], None], + entities_filter: EntityFilter | None, + entity_ids: list[str] | None, + device_ids: list[str] | None, +) -> Callable[[Event], None]: + """Make a callable to filter events.""" + if not entities_filter and not entity_ids and not device_ids: + # No filter + # - Script Trace (context ids) + # - Automation Trace (context ids) + return target + + if entities_filter: + # We have an entity filter: + # - Logbook panel + + @callback + def _forward_events_filtered_by_entities_filter(event: Event) -> None: + assert entities_filter is not None + event_data = event.data + entity_ids = extract_attr(event_data, ATTR_ENTITY_ID) + if entity_ids and not any( + entities_filter(entity_id) for entity_id in entity_ids + ): + return + domain = event_data.get(ATTR_DOMAIN) + if domain and not entities_filter(f"{domain}._"): + return + target(event) + + return _forward_events_filtered_by_entities_filter + + # We are filtering on entity_ids and/or device_ids: + # - Areas + # - Devices + # - Logbook Card + entity_ids_set = set(entity_ids) if entity_ids else set() + device_ids_set = set(device_ids) if device_ids else set() + + @callback + def _forward_events_filtered_by_device_entity_ids(event: Event) -> None: + event_data = event.data + if entity_ids_set.intersection( + extract_attr(event_data, ATTR_ENTITY_ID) + ) or device_ids_set.intersection(extract_attr(event_data, ATTR_DEVICE_ID)): + target(event) + + return _forward_events_filtered_by_device_entity_ids @callback @@ -93,6 +159,7 @@ def async_subscribe_events( subscriptions: list[CALLBACK_TYPE], target: Callable[[Event], None], event_types: tuple[str, ...], + entities_filter: EntityFilter | None, entity_ids: list[str] | None, device_ids: list[str] | None, ) -> None: @@ -103,41 +170,31 @@ def async_subscribe_events( """ ent_reg = er.async_get(hass) assert is_callback(target), "target must be a callback" - event_forwarder = target - - if entity_ids or device_ids: - entity_ids_set = set(entity_ids) if entity_ids else set() - device_ids_set = set(device_ids) if device_ids else set() - - @callback - def _forward_events_filtered(event: Event) -> None: - event_data = event.data - if ( - entity_ids_set and event_data.get(ATTR_ENTITY_ID) in entity_ids_set - ) or (device_ids_set and event_data.get(ATTR_DEVICE_ID) in device_ids_set): - target(event) - - event_forwarder = _forward_events_filtered - + event_forwarder = event_forwarder_filtered( + target, entities_filter, entity_ids, device_ids + ) for event_type in event_types: subscriptions.append( hass.bus.async_listen(event_type, event_forwarder, run_immediately=True) ) - @callback - def _forward_state_events_filtered(event: Event) -> None: - if event.data.get("old_state") is None or event.data.get("new_state") is None: - return - state: State = event.data["new_state"] - if not _is_state_filtered(ent_reg, state): - target(event) - if device_ids and not entity_ids: # No entities to subscribe to but we are filtering # on device ids so we do not want to get any state # changed events return + @callback + def _forward_state_events_filtered(event: Event) -> None: + if event.data.get("old_state") is None or event.data.get("new_state") is None: + return + state: State = event.data["new_state"] + if _is_state_filtered(ent_reg, state) or ( + entities_filter and not entities_filter(state.entity_id) + ): + return + target(event) + if entity_ids: subscriptions.append( async_track_state_change_event( diff --git a/homeassistant/components/logbook/processor.py b/homeassistant/components/logbook/processor.py index e5cc0f124b0..82225df8364 100644 --- a/homeassistant/components/logbook/processor.py +++ b/homeassistant/components/logbook/processor.py @@ -5,8 +5,6 @@ from collections.abc import Callable, Generator from contextlib import suppress from dataclasses import dataclass from datetime import datetime as dt -import logging -import re from typing import Any from sqlalchemy.engine.row import Row @@ -30,7 +28,6 @@ from homeassistant.const import ( ) from homeassistant.core import HomeAssistant, split_entity_id from homeassistant.helpers import entity_registry as er -from homeassistant.helpers.entityfilter import EntityFilter import homeassistant.util.dt as dt_util from .const import ( @@ -46,7 +43,6 @@ from .const import ( CONTEXT_STATE, CONTEXT_USER_ID, DOMAIN, - LOGBOOK_ENTITIES_FILTER, LOGBOOK_ENTRY_DOMAIN, LOGBOOK_ENTRY_ENTITY_ID, LOGBOOK_ENTRY_ICON, @@ -62,11 +58,6 @@ from .models import EventAsRow, LazyEventPartialState, async_event_to_row from .queries import statement_for_request from .queries.common import PSUEDO_EVENT_STATE_CHANGED -_LOGGER = logging.getLogger(__name__) - -ENTITY_ID_JSON_EXTRACT = re.compile('"entity_id": ?"([^"]+)"') -DOMAIN_JSON_EXTRACT = re.compile('"domain": ?"([^"]+)"') - @dataclass class LogbookRun: @@ -106,10 +97,6 @@ class EventProcessor: self.device_ids = device_ids self.context_id = context_id self.filters: Filters | None = hass.data[LOGBOOK_FILTERS] - if self.limited_select: - self.entities_filter: EntityFilter | Callable[[str], bool] | None = None - else: - self.entities_filter = hass.data[LOGBOOK_ENTITIES_FILTER] format_time = ( _row_time_fired_timestamp if timestamp else _row_time_fired_isoformat ) @@ -183,7 +170,6 @@ class EventProcessor: return list( _humanify( row_generator, - self.entities_filter, self.ent_reg, self.logbook_run, self.context_augmenter, @@ -193,7 +179,6 @@ class EventProcessor: def _humanify( rows: Generator[Row | EventAsRow, None, None], - entities_filter: EntityFilter | Callable[[str], bool] | None, ent_reg: er.EntityRegistry, logbook_run: LogbookRun, context_augmenter: ContextAugmenter, @@ -208,29 +193,13 @@ def _humanify( include_entity_name = logbook_run.include_entity_name format_time = logbook_run.format_time - def _keep_row(row: EventAsRow) -> bool: - """Check if the entity_filter rejects a row.""" - assert entities_filter is not None - if entity_id := row.entity_id: - return entities_filter(entity_id) - if entity_id := row.data.get(ATTR_ENTITY_ID): - return entities_filter(entity_id) - if domain := row.data.get(ATTR_DOMAIN): - return entities_filter(f"{domain}._") - return True - # Process rows for row in rows: context_id = context_lookup.memorize(row) if row.context_only: continue event_type = row.event_type - if event_type == EVENT_CALL_SERVICE or ( - entities_filter - # We literally mean is EventAsRow not a subclass of EventAsRow - and type(row) is EventAsRow # pylint: disable=unidiomatic-typecheck - and not _keep_row(row) - ): + if event_type == EVENT_CALL_SERVICE: continue if event_type is PSUEDO_EVENT_STATE_CHANGED: entity_id = row.entity_id diff --git a/homeassistant/components/logbook/websocket_api.py b/homeassistant/components/logbook/websocket_api.py index b27ae65b70c..a8f9bc50920 100644 --- a/homeassistant/components/logbook/websocket_api.py +++ b/homeassistant/components/logbook/websocket_api.py @@ -16,9 +16,11 @@ from homeassistant.components.websocket_api import messages from homeassistant.components.websocket_api.connection import ActiveConnection from homeassistant.components.websocket_api.const import JSON_DUMP from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback +from homeassistant.helpers.entityfilter import EntityFilter from homeassistant.helpers.event import async_track_point_in_utc_time import homeassistant.util.dt as dt_util +from .const import LOGBOOK_ENTITIES_FILTER from .helpers import ( async_determine_event_types, async_filter_entities, @@ -365,8 +367,18 @@ async def ws_event_stream( ) _unsub() + entities_filter: EntityFilter | None = None + if not event_processor.limited_select: + entities_filter = hass.data[LOGBOOK_ENTITIES_FILTER] + async_subscribe_events( - hass, subscriptions, _queue_or_cancel, event_types, entity_ids, device_ids + hass, + subscriptions, + _queue_or_cancel, + event_types, + entities_filter, + entity_ids, + device_ids, ) subscriptions_setup_complete_time = dt_util.utcnow() connection.subscriptions[msg_id] = _unsub diff --git a/tests/components/logbook/common.py b/tests/components/logbook/common.py index b88c3854967..a41f983bfed 100644 --- a/tests/components/logbook/common.py +++ b/tests/components/logbook/common.py @@ -68,7 +68,6 @@ def mock_humanify(hass_, rows): return list( processor._humanify( rows, - None, ent_reg, logbook_run, context_augmenter, diff --git a/tests/components/logbook/test_websocket_api.py b/tests/components/logbook/test_websocket_api.py index 2623a5b17d5..ae1f7968e3b 100644 --- a/tests/components/logbook/test_websocket_api.py +++ b/tests/components/logbook/test_websocket_api.py @@ -27,8 +27,8 @@ from homeassistant.const import ( STATE_OFF, STATE_ON, ) -from homeassistant.core import Event, HomeAssistant, State -from homeassistant.helpers import device_registry +from homeassistant.core import Event, HomeAssistant, State, callback +from homeassistant.helpers import device_registry, entity_registry from homeassistant.helpers.entityfilter import CONF_ENTITY_GLOBS from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -51,22 +51,8 @@ def set_utc(hass): hass.config.set_time_zone("UTC") -async def _async_mock_device_with_logbook_platform(hass): - """Mock an integration that provides a device that are described by the logbook.""" - entry = MockConfigEntry(domain="test", data={"first": True}, options=None) - entry.add_to_hass(hass) - dev_reg = device_registry.async_get(hass) - device = dev_reg.async_get_or_create( - config_entry_id=entry.entry_id, - connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, - identifiers={("bridgeid", "0123")}, - sw_version="sw-version", - name="device name", - manufacturer="manufacturer", - model="model", - suggested_area="Game Room", - ) - +@callback +async def _async_mock_logbook_platform(hass: HomeAssistant) -> None: class MockLogbookPlatform: """Mock a logbook platform.""" @@ -90,6 +76,40 @@ async def _async_mock_device_with_logbook_platform(hass): async_describe_event("test", "mock_event", async_describe_test_event) await logbook._process_logbook_platform(hass, "test", MockLogbookPlatform) + + +async def _async_mock_entity_with_logbook_platform(hass): + """Mock an integration that provides an entity that are described by the logbook.""" + entry = MockConfigEntry(domain="test", data={"first": True}, options=None) + entry.add_to_hass(hass) + ent_reg = entity_registry.async_get(hass) + entry = ent_reg.async_get_or_create( + platform="test", + domain="sensor", + config_entry=entry, + unique_id="1234", + suggested_object_id="test", + ) + await _async_mock_logbook_platform(hass) + return entry + + +async def _async_mock_device_with_logbook_platform(hass): + """Mock an integration that provides a device that are described by the logbook.""" + entry = MockConfigEntry(domain="test", data={"first": True}, options=None) + entry.add_to_hass(hass) + dev_reg = device_registry.async_get(hass) + device = dev_reg.async_get_or_create( + config_entry_id=entry.entry_id, + connections={(device_registry.CONNECTION_NETWORK_MAC, "12:34:56:AB:CD:EF")}, + identifiers={("bridgeid", "0123")}, + sw_version="sw-version", + name="device name", + manufacturer="manufacturer", + model="model", + suggested_area="Game Room", + ) + await _async_mock_logbook_platform(hass) return device @@ -1786,6 +1806,103 @@ async def test_event_stream_bad_start_time(hass, hass_ws_client, recorder_mock): assert response["error"]["code"] == "invalid_start_time" +@patch("homeassistant.components.logbook.websocket_api.EVENT_COALESCE_TIME", 0) +async def test_logbook_stream_match_multiple_entities( + hass, recorder_mock, hass_ws_client +): + """Test logbook stream with a described integration that uses multiple entities.""" + now = dt_util.utcnow() + await asyncio.gather( + *[ + async_setup_component(hass, comp, {}) + for comp in ("homeassistant", "logbook", "automation", "script") + ] + ) + entry = await _async_mock_entity_with_logbook_platform(hass) + entity_id = entry.entity_id + hass.states.async_set(entity_id, STATE_ON) + + await hass.async_block_till_done() + init_count = sum(hass.bus.async_listeners().values()) + + await async_wait_recording_done(hass) + websocket_client = await hass_ws_client() + await websocket_client.send_json( + { + "id": 7, + "type": "logbook/event_stream", + "start_time": now.isoformat(), + "entity_ids": [entity_id], + } + ) + + msg = await asyncio.wait_for(websocket_client.receive_json(), 2) + assert msg["id"] == 7 + assert msg["type"] == TYPE_RESULT + assert msg["success"] + + # There are no answers to our initial query + # so we get an empty reply. This is to ensure + # consumers of the api know there are no results + # and its not a failure case. This is useful + # in the frontend so we can tell the user there + # are no results vs waiting for them to appear + msg = await asyncio.wait_for(websocket_client.receive_json(), 2) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"]["events"] == [] + await async_wait_recording_done(hass) + + hass.states.async_set("binary_sensor.should_not_appear", STATE_ON) + hass.states.async_set("binary_sensor.should_not_appear", STATE_OFF) + context = core.Context( + id="ac5bd62de45711eaaeb351041eec8dd9", + user_id="b400facee45711eaa9308bfd3d19e474", + ) + hass.bus.async_fire( + "mock_event", {"entity_id": ["sensor.any", entity_id]}, context=context + ) + hass.bus.async_fire("mock_event", {"entity_id": [f"sensor.any,{entity_id}"]}) + hass.bus.async_fire("mock_event", {"entity_id": ["sensor.no_match", "light.off"]}) + hass.states.async_set(entity_id, STATE_OFF, context=context) + await hass.async_block_till_done() + + msg = await asyncio.wait_for(websocket_client.receive_json(), 2) + assert msg["id"] == 7 + assert msg["type"] == "event" + assert msg["event"]["events"] == [ + { + "context_user_id": "b400facee45711eaa9308bfd3d19e474", + "domain": "test", + "message": "is on fire", + "name": "device name", + "when": ANY, + }, + { + "context_domain": "test", + "context_event_type": "mock_event", + "context_message": "is on fire", + "context_name": "device name", + "context_user_id": "b400facee45711eaa9308bfd3d19e474", + "entity_id": "sensor.test", + "state": "off", + "when": ANY, + }, + ] + + await websocket_client.send_json( + {"id": 8, "type": "unsubscribe_events", "subscription": 7} + ) + msg = await asyncio.wait_for(websocket_client.receive_json(), 2) + + assert msg["id"] == 8 + assert msg["type"] == TYPE_RESULT + assert msg["success"] + + # Check our listener got unsubscribed + assert sum(hass.bus.async_listeners().values()) == init_count + + async def test_event_stream_bad_end_time(hass, hass_ws_client, recorder_mock): """Test event_stream bad end time.""" await async_setup_component(hass, "logbook", {})