Move thread safety check in async_register/async_remove (#116077)

This commit is contained in:
J. Nick Koston 2024-04-24 10:41:11 +02:00 committed by GitHub
parent 5bded2a52d
commit e0b58c3f45
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 63 additions and 4 deletions

View file

@ -2456,7 +2456,7 @@ class ServiceRegistry:
""" """
run_callback_threadsafe( run_callback_threadsafe(
self._hass.loop, self._hass.loop,
self.async_register, self._async_register,
domain, domain,
service, service,
service_func, service_func,
@ -2484,6 +2484,33 @@ class ServiceRegistry:
Schema is called to coerce and validate the service data. Schema is called to coerce and validate the service data.
This method must be run in the event loop.
"""
self._hass.verify_event_loop_thread("async_register")
self._async_register(
domain, service, service_func, schema, supports_response, job_type
)
@callback
def _async_register(
self,
domain: str,
service: str,
service_func: Callable[
[ServiceCall],
Coroutine[Any, Any, ServiceResponse | EntityServiceResponse]
| ServiceResponse
| EntityServiceResponse
| None,
],
schema: vol.Schema | None = None,
supports_response: SupportsResponse = SupportsResponse.NONE,
job_type: HassJobType | None = None,
) -> None:
"""Register a service.
Schema is called to coerce and validate the service data.
This method must be run in the event loop. This method must be run in the event loop.
""" """
domain = domain.lower() domain = domain.lower()
@ -2502,20 +2529,29 @@ class ServiceRegistry:
else: else:
self._services[domain] = {service: service_obj} self._services[domain] = {service: service_obj}
self._hass.bus.async_fire( self._hass.bus.async_fire_internal(
EVENT_SERVICE_REGISTERED, {ATTR_DOMAIN: domain, ATTR_SERVICE: service} EVENT_SERVICE_REGISTERED, {ATTR_DOMAIN: domain, ATTR_SERVICE: service}
) )
def remove(self, domain: str, service: str) -> None: def remove(self, domain: str, service: str) -> None:
"""Remove a registered service from service handler.""" """Remove a registered service from service handler."""
run_callback_threadsafe( run_callback_threadsafe(
self._hass.loop, self.async_remove, domain, service self._hass.loop, self._async_remove, domain, service
).result() ).result()
@callback @callback
def async_remove(self, domain: str, service: str) -> None: def async_remove(self, domain: str, service: str) -> None:
"""Remove a registered service from service handler. """Remove a registered service from service handler.
This method must be run in the event loop.
"""
self._hass.verify_event_loop_thread("async_remove")
self._async_remove(domain, service)
@callback
def _async_remove(self, domain: str, service: str) -> None:
"""Remove a registered service from service handler.
This method must be run in the event loop. This method must be run in the event loop.
""" """
domain = domain.lower() domain = domain.lower()
@ -2530,7 +2566,7 @@ class ServiceRegistry:
if not self._services[domain]: if not self._services[domain]:
self._services.pop(domain) self._services.pop(domain)
self._hass.bus.async_fire( self._hass.bus.async_fire_internal(
EVENT_SERVICE_REMOVED, {ATTR_DOMAIN: domain, ATTR_SERVICE: service} EVENT_SERVICE_REMOVED, {ATTR_DOMAIN: domain, ATTR_SERVICE: service}
) )

View file

@ -3457,3 +3457,26 @@ async def test_async_fire_thread_safety(hass: HomeAssistant) -> None:
await hass.async_add_executor_job(hass.bus.async_fire, "test_event") await hass.async_add_executor_job(hass.bus.async_fire, "test_event")
assert len(events) == 1 assert len(events) == 1
async def test_async_register_thread_safety(hass: HomeAssistant) -> None:
"""Test async_register thread safety."""
with pytest.raises(
RuntimeError, match="Detected code that calls async_register from a thread."
):
await hass.async_add_executor_job(
hass.services.async_register,
"test_domain",
"test_service",
lambda call: None,
)
async def test_async_remove_thread_safety(hass: HomeAssistant) -> None:
"""Test async_remove thread safety."""
with pytest.raises(
RuntimeError, match="Detected code that calls async_remove from a thread."
):
await hass.async_add_executor_job(
hass.services.async_remove, "test_domain", "test_service"
)