Use _get_reconfigure_entry in here_travel_time (#127294)

This commit is contained in:
epenet 2024-10-02 14:31:45 +02:00 committed by GitHub
parent f0f924a0a2
commit 5ed7efb01d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,7 +4,7 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
import logging import logging
from typing import TYPE_CHECKING, Any from typing import Any
from here_routing import ( from here_routing import (
HERERoutingApi, HERERoutingApi,
@ -17,6 +17,7 @@ from here_transit import HERETransitError
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ( from homeassistant.config_entries import (
SOURCE_RECONFIGURE,
ConfigEntry, ConfigEntry,
ConfigFlow, ConfigFlow,
ConfigFlowResult, ConfigFlowResult,
@ -79,7 +80,7 @@ async def async_validate_api_key(api_key: str) -> None:
) )
def get_user_step_schema(data: dict[str, Any]) -> vol.Schema: def get_user_step_schema(data: Mapping[str, Any]) -> vol.Schema:
"""Get a populated schema or default.""" """Get a populated schema or default."""
travel_mode = data.get(CONF_MODE, TRAVEL_MODE_CAR) travel_mode = data.get(CONF_MODE, TRAVEL_MODE_CAR)
if travel_mode == "publicTransportTimeTable": if travel_mode == "publicTransportTimeTable":
@ -102,11 +103,11 @@ class HERETravelTimeConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
_entry: ConfigEntry
def __init__(self) -> None: def __init__(self) -> None:
"""Init Config Flow.""" """Init Config Flow."""
self._config: dict[str, Any] = {} self._config: dict[str, Any] = {}
self._entry: ConfigEntry | None = None
self._is_reconfigure_flow: bool = False
@staticmethod @staticmethod
@callback @callback
@ -122,21 +123,19 @@ class HERETravelTimeConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle the initial step.""" """Handle the initial step."""
errors = {} errors = {}
user_input = user_input or {} user_input = user_input or {}
if not self._is_reconfigure_flow: # Always show form first for reconfiguration if user_input:
if user_input: try:
try: await async_validate_api_key(user_input[CONF_API_KEY])
await async_validate_api_key(user_input[CONF_API_KEY]) except HERERoutingUnauthorizedError:
except HERERoutingUnauthorizedError: errors["base"] = "invalid_auth"
errors["base"] = "invalid_auth" except (HERERoutingError, HERETransitError):
except (HERERoutingError, HERETransitError): _LOGGER.exception("Unexpected exception")
_LOGGER.exception("Unexpected exception") errors["base"] = "unknown"
errors["base"] = "unknown" if not errors:
if not errors: self._config[CONF_NAME] = user_input[CONF_NAME]
self._config[CONF_NAME] = user_input[CONF_NAME] self._config[CONF_API_KEY] = user_input[CONF_API_KEY]
self._config[CONF_API_KEY] = user_input[CONF_API_KEY] self._config[CONF_MODE] = user_input[CONF_MODE]
self._config[CONF_MODE] = user_input[CONF_MODE] return await self.async_step_origin_menu()
return await self.async_step_origin_menu()
self._is_reconfigure_flow = False
return self.async_show_form( return self.async_show_form(
step_id="user", data_schema=get_user_step_schema(user_input), errors=errors step_id="user", data_schema=get_user_step_schema(user_input), errors=errors
) )
@ -145,12 +144,10 @@ class HERETravelTimeConfigFlow(ConfigFlow, domain=DOMAIN):
self, entry_data: Mapping[str, Any] self, entry_data: Mapping[str, Any]
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Handle reconfiguration.""" """Handle reconfiguration."""
self._is_reconfigure_flow = True self._entry = self._get_reconfigure_entry()
self._entry = self.hass.config_entries.async_get_entry(self.context["entry_id"]) return self.async_show_form(
if TYPE_CHECKING: step_id="user", data_schema=get_user_step_schema(entry_data)
assert self._entry )
self._config = self._entry.data.copy()
return await self.async_step_user(self._config)
async def async_step_origin_menu(self, _: None = None) -> ConfigFlowResult: async def async_step_origin_menu(self, _: None = None) -> ConfigFlowResult:
"""Show the origin menu.""" """Show the origin menu."""
@ -233,7 +230,7 @@ class HERETravelTimeConfigFlow(ConfigFlow, domain=DOMAIN):
] ]
# Remove possible previous configuration using an entity_id # Remove possible previous configuration using an entity_id
self._config.pop(CONF_DESTINATION_ENTITY_ID, None) self._config.pop(CONF_DESTINATION_ENTITY_ID, None)
if self._entry: if self.source == SOURCE_RECONFIGURE:
return self.async_update_reload_and_abort( return self.async_update_reload_and_abort(
self._entry, self._entry,
title=self._config[CONF_NAME], title=self._config[CONF_NAME],
@ -278,7 +275,7 @@ class HERETravelTimeConfigFlow(ConfigFlow, domain=DOMAIN):
# Remove possible previous configuration using coordinates # Remove possible previous configuration using coordinates
self._config.pop(CONF_DESTINATION_LATITUDE, None) self._config.pop(CONF_DESTINATION_LATITUDE, None)
self._config.pop(CONF_DESTINATION_LONGITUDE, None) self._config.pop(CONF_DESTINATION_LONGITUDE, None)
if self._entry: if self.source == SOURCE_RECONFIGURE:
return self.async_update_reload_and_abort( return self.async_update_reload_and_abort(
self._entry, data=self._config, reason="reconfigure_successful" self._entry, data=self._config, reason="reconfigure_successful"
) )