Migrate rest notify to httpx (#90769)

This commit is contained in:
epenet 2023-05-11 08:26:16 +01:00 committed by GitHub
parent 26f7843800
commit 949e8f7b13
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -5,8 +5,7 @@ from http import HTTPStatus
import logging import logging
from typing import Any from typing import Any
import requests import httpx
from requests.auth import AuthBase, HTTPBasicAuth, HTTPDigestAuth
import voluptuous as vol import voluptuous as vol
from homeassistant.components.notify import ( from homeassistant.components.notify import (
@ -32,6 +31,7 @@ from homeassistant.const import (
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.httpx_client import get_async_client
from homeassistant.helpers.template import Template from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
@ -72,7 +72,7 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def get_service( async def async_get_service(
hass: HomeAssistant, hass: HomeAssistant,
config: ConfigType, config: ConfigType,
discovery_info: DiscoveryInfoType | None = None, discovery_info: DiscoveryInfoType | None = None,
@ -91,12 +91,12 @@ def get_service(
password: str | None = config.get(CONF_PASSWORD) password: str | None = config.get(CONF_PASSWORD)
verify_ssl: bool = config[CONF_VERIFY_SSL] verify_ssl: bool = config[CONF_VERIFY_SSL]
auth: AuthBase | None = None auth: httpx.Auth | None = None
if username and password: if username and password:
if config.get(CONF_AUTHENTICATION) == HTTP_DIGEST_AUTHENTICATION: if config.get(CONF_AUTHENTICATION) == HTTP_DIGEST_AUTHENTICATION:
auth = HTTPDigestAuth(username, password) auth = httpx.DigestAuth(username, password)
else: else:
auth = HTTPBasicAuth(username, password) auth = httpx.BasicAuth(username, password)
return RestNotificationService( return RestNotificationService(
hass, hass,
@ -129,7 +129,7 @@ class RestNotificationService(BaseNotificationService):
target_param_name: str | None, target_param_name: str | None,
data: dict[str, Any] | None, data: dict[str, Any] | None,
data_template: dict[str, Any] | None, data_template: dict[str, Any] | None,
auth: AuthBase | None, auth: httpx.Auth | None,
verify_ssl: bool, verify_ssl: bool,
) -> None: ) -> None:
"""Initialize the service.""" """Initialize the service."""
@ -146,7 +146,7 @@ class RestNotificationService(BaseNotificationService):
self._auth = auth self._auth = auth
self._verify_ssl = verify_ssl self._verify_ssl = verify_ssl
def send_message(self, message: str = "", **kwargs: Any) -> None: async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
"""Send a message to a user.""" """Send a message to a user."""
data = {self._message_param_name: message} data = {self._message_param_name: message}
@ -179,34 +179,32 @@ class RestNotificationService(BaseNotificationService):
if self._data_template: if self._data_template:
data.update(_data_template_creator(self._data_template)) data.update(_data_template_creator(self._data_template))
websession = get_async_client(self._hass, self._verify_ssl)
if self._method == "POST": if self._method == "POST":
response = requests.post( response = await websession.post(
self._resource, self._resource,
headers=self._headers, headers=self._headers,
params=self._params, params=self._params,
data=data, data=data,
timeout=10, timeout=10,
auth=self._auth, auth=self._auth or httpx.USE_CLIENT_DEFAULT,
verify=self._verify_ssl,
) )
elif self._method == "POST_JSON": elif self._method == "POST_JSON":
response = requests.post( response = await websession.post(
self._resource, self._resource,
headers=self._headers, headers=self._headers,
params=self._params, params=self._params,
json=data, json=data,
timeout=10, timeout=10,
auth=self._auth, auth=self._auth or httpx.USE_CLIENT_DEFAULT,
verify=self._verify_ssl,
) )
else: # default GET else: # default GET
response = requests.get( response = await websession.get(
self._resource, self._resource,
headers=self._headers, headers=self._headers,
params={**self._params, **data} if self._params else data, params={**self._params, **data} if self._params else data,
timeout=10, timeout=10,
auth=self._auth, auth=self._auth,
verify=self._verify_ssl,
) )
if ( if (
@ -214,21 +212,29 @@ class RestNotificationService(BaseNotificationService):
and response.status_code < 600 and response.status_code < 600
): ):
_LOGGER.exception( _LOGGER.exception(
"Server error. Response %d: %s:", response.status_code, response.reason "Server error. Response %d: %s:",
response.status_code,
response.reason_phrase,
) )
elif ( elif (
response.status_code >= HTTPStatus.BAD_REQUEST response.status_code >= HTTPStatus.BAD_REQUEST
and response.status_code < HTTPStatus.INTERNAL_SERVER_ERROR and response.status_code < HTTPStatus.INTERNAL_SERVER_ERROR
): ):
_LOGGER.exception( _LOGGER.exception(
"Client error. Response %d: %s:", response.status_code, response.reason "Client error. Response %d: %s:",
response.status_code,
response.reason_phrase,
) )
elif ( elif (
response.status_code >= HTTPStatus.OK response.status_code >= HTTPStatus.OK
and response.status_code < HTTPStatus.MULTIPLE_CHOICES and response.status_code < HTTPStatus.MULTIPLE_CHOICES
): ):
_LOGGER.debug( _LOGGER.debug(
"Success. Response %d: %s:", response.status_code, response.reason "Success. Response %d: %s:",
response.status_code,
response.reason_phrase,
) )
else: else:
_LOGGER.debug("Response %d: %s:", response.status_code, response.reason) _LOGGER.debug(
"Response %d: %s:", response.status_code, response.reason_phrase
)