Migrate rest notify to httpx (#90769)
This commit is contained in:
parent
26f7843800
commit
949e8f7b13
1 changed files with 26 additions and 20 deletions
|
@ -5,8 +5,7 @@ from http import HTTPStatus
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import requests
|
import httpx
|
||||||
from requests.auth import AuthBase, HTTPBasicAuth, HTTPDigestAuth
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.notify import (
|
from homeassistant.components.notify import (
|
||||||
|
@ -32,6 +31,7 @@ from homeassistant.const import (
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
|
from homeassistant.helpers.httpx_client import get_async_client
|
||||||
from homeassistant.helpers.template import Template
|
from homeassistant.helpers.template import Template
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_service(
|
async def async_get_service(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
config: ConfigType,
|
config: ConfigType,
|
||||||
discovery_info: DiscoveryInfoType | None = None,
|
discovery_info: DiscoveryInfoType | None = None,
|
||||||
|
@ -91,12 +91,12 @@ def get_service(
|
||||||
password: str | None = config.get(CONF_PASSWORD)
|
password: str | None = config.get(CONF_PASSWORD)
|
||||||
verify_ssl: bool = config[CONF_VERIFY_SSL]
|
verify_ssl: bool = config[CONF_VERIFY_SSL]
|
||||||
|
|
||||||
auth: AuthBase | None = None
|
auth: httpx.Auth | None = None
|
||||||
if username and password:
|
if username and password:
|
||||||
if config.get(CONF_AUTHENTICATION) == HTTP_DIGEST_AUTHENTICATION:
|
if config.get(CONF_AUTHENTICATION) == HTTP_DIGEST_AUTHENTICATION:
|
||||||
auth = HTTPDigestAuth(username, password)
|
auth = httpx.DigestAuth(username, password)
|
||||||
else:
|
else:
|
||||||
auth = HTTPBasicAuth(username, password)
|
auth = httpx.BasicAuth(username, password)
|
||||||
|
|
||||||
return RestNotificationService(
|
return RestNotificationService(
|
||||||
hass,
|
hass,
|
||||||
|
@ -129,7 +129,7 @@ class RestNotificationService(BaseNotificationService):
|
||||||
target_param_name: str | None,
|
target_param_name: str | None,
|
||||||
data: dict[str, Any] | None,
|
data: dict[str, Any] | None,
|
||||||
data_template: dict[str, Any] | None,
|
data_template: dict[str, Any] | None,
|
||||||
auth: AuthBase | None,
|
auth: httpx.Auth | None,
|
||||||
verify_ssl: bool,
|
verify_ssl: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the service."""
|
"""Initialize the service."""
|
||||||
|
@ -146,7 +146,7 @@ class RestNotificationService(BaseNotificationService):
|
||||||
self._auth = auth
|
self._auth = auth
|
||||||
self._verify_ssl = verify_ssl
|
self._verify_ssl = verify_ssl
|
||||||
|
|
||||||
def send_message(self, message: str = "", **kwargs: Any) -> None:
|
async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
|
||||||
"""Send a message to a user."""
|
"""Send a message to a user."""
|
||||||
data = {self._message_param_name: message}
|
data = {self._message_param_name: message}
|
||||||
|
|
||||||
|
@ -179,34 +179,32 @@ class RestNotificationService(BaseNotificationService):
|
||||||
if self._data_template:
|
if self._data_template:
|
||||||
data.update(_data_template_creator(self._data_template))
|
data.update(_data_template_creator(self._data_template))
|
||||||
|
|
||||||
|
websession = get_async_client(self._hass, self._verify_ssl)
|
||||||
if self._method == "POST":
|
if self._method == "POST":
|
||||||
response = requests.post(
|
response = await websession.post(
|
||||||
self._resource,
|
self._resource,
|
||||||
headers=self._headers,
|
headers=self._headers,
|
||||||
params=self._params,
|
params=self._params,
|
||||||
data=data,
|
data=data,
|
||||||
timeout=10,
|
timeout=10,
|
||||||
auth=self._auth,
|
auth=self._auth or httpx.USE_CLIENT_DEFAULT,
|
||||||
verify=self._verify_ssl,
|
|
||||||
)
|
)
|
||||||
elif self._method == "POST_JSON":
|
elif self._method == "POST_JSON":
|
||||||
response = requests.post(
|
response = await websession.post(
|
||||||
self._resource,
|
self._resource,
|
||||||
headers=self._headers,
|
headers=self._headers,
|
||||||
params=self._params,
|
params=self._params,
|
||||||
json=data,
|
json=data,
|
||||||
timeout=10,
|
timeout=10,
|
||||||
auth=self._auth,
|
auth=self._auth or httpx.USE_CLIENT_DEFAULT,
|
||||||
verify=self._verify_ssl,
|
|
||||||
)
|
)
|
||||||
else: # default GET
|
else: # default GET
|
||||||
response = requests.get(
|
response = await websession.get(
|
||||||
self._resource,
|
self._resource,
|
||||||
headers=self._headers,
|
headers=self._headers,
|
||||||
params={**self._params, **data} if self._params else data,
|
params={**self._params, **data} if self._params else data,
|
||||||
timeout=10,
|
timeout=10,
|
||||||
auth=self._auth,
|
auth=self._auth,
|
||||||
verify=self._verify_ssl,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
@ -214,21 +212,29 @@ class RestNotificationService(BaseNotificationService):
|
||||||
and response.status_code < 600
|
and response.status_code < 600
|
||||||
):
|
):
|
||||||
_LOGGER.exception(
|
_LOGGER.exception(
|
||||||
"Server error. Response %d: %s:", response.status_code, response.reason
|
"Server error. Response %d: %s:",
|
||||||
|
response.status_code,
|
||||||
|
response.reason_phrase,
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
response.status_code >= HTTPStatus.BAD_REQUEST
|
response.status_code >= HTTPStatus.BAD_REQUEST
|
||||||
and response.status_code < HTTPStatus.INTERNAL_SERVER_ERROR
|
and response.status_code < HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
):
|
):
|
||||||
_LOGGER.exception(
|
_LOGGER.exception(
|
||||||
"Client error. Response %d: %s:", response.status_code, response.reason
|
"Client error. Response %d: %s:",
|
||||||
|
response.status_code,
|
||||||
|
response.reason_phrase,
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
response.status_code >= HTTPStatus.OK
|
response.status_code >= HTTPStatus.OK
|
||||||
and response.status_code < HTTPStatus.MULTIPLE_CHOICES
|
and response.status_code < HTTPStatus.MULTIPLE_CHOICES
|
||||||
):
|
):
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Success. Response %d: %s:", response.status_code, response.reason
|
"Success. Response %d: %s:",
|
||||||
|
response.status_code,
|
||||||
|
response.reason_phrase,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
_LOGGER.debug("Response %d: %s:", response.status_code, response.reason)
|
_LOGGER.debug(
|
||||||
|
"Response %d: %s:", response.status_code, response.reason_phrase
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Reference in a new issue