"""Config flow for MySensors."""
from __future__ import annotations

import os
from typing import Any

from awesomeversion import (
    AwesomeVersion,
    AwesomeVersionStrategy,
    AwesomeVersionStrategyException,
)
import voluptuous as vol

from homeassistant import config_entries
from homeassistant.components.mqtt import (
    DOMAIN as MQTT_DOMAIN,
    valid_publish_topic,
    valid_subscribe_topic,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
import homeassistant.helpers.config_validation as cv

from . import DEFAULT_BAUD_RATE, DEFAULT_TCP_PORT, DEFAULT_VERSION, is_persistence_file
from .const import (
    CONF_BAUD_RATE,
    CONF_DEVICE,
    CONF_GATEWAY_TYPE,
    CONF_GATEWAY_TYPE_ALL,
    CONF_GATEWAY_TYPE_MQTT,
    CONF_GATEWAY_TYPE_SERIAL,
    CONF_GATEWAY_TYPE_TCP,
    CONF_PERSISTENCE_FILE,
    CONF_RETAIN,
    CONF_TCP_PORT,
    CONF_TOPIC_IN_PREFIX,
    CONF_TOPIC_OUT_PREFIX,
    CONF_VERSION,
    DOMAIN,
    ConfGatewayType,
)
from .gateway import MQTT_COMPONENT, is_serial_port, is_socket_address, try_connect


def _get_schema_common(user_input: dict[str, str]) -> dict:
    """Create a schema with options common to all gateway types."""
    schema = {
        vol.Required(
            CONF_VERSION,
            description={
                "suggested_value": user_input.get(CONF_VERSION, DEFAULT_VERSION)
            },
        ): str,
        vol.Optional(CONF_PERSISTENCE_FILE): str,
    }
    return schema


def _validate_version(version: str) -> dict[str, str]:
    """Validate a version string from the user."""
    version_okay = True
    try:
        AwesomeVersion(
            version,
            [AwesomeVersionStrategy.SIMPLEVER, AwesomeVersionStrategy.SEMVER],
        )
    except AwesomeVersionStrategyException:
        version_okay = False

    if version_okay:
        return {}
    return {CONF_VERSION: "invalid_version"}


def _is_same_device(
    gw_type: ConfGatewayType, user_input: dict[str, Any], entry: ConfigEntry
) -> bool:
    """Check if another ConfigDevice is actually the same as user_input.

    This function only compares addresses and tcp ports, so it is possible to fool it with tricks like port forwarding.
    """
    if entry.data[CONF_DEVICE] != user_input[CONF_DEVICE]:
        return False
    if gw_type == CONF_GATEWAY_TYPE_TCP:
        entry_tcp_port: int = entry.data[CONF_TCP_PORT]
        input_tcp_port: int = user_input[CONF_TCP_PORT]
        return entry_tcp_port == input_tcp_port
    if gw_type == CONF_GATEWAY_TYPE_MQTT:
        entry_topics = {
            entry.data[CONF_TOPIC_IN_PREFIX],
            entry.data[CONF_TOPIC_OUT_PREFIX],
        }
        return (
            user_input.get(CONF_TOPIC_IN_PREFIX) in entry_topics
            or user_input.get(CONF_TOPIC_OUT_PREFIX) in entry_topics
        )
    return True


class MySensorsConfigFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
    """Handle a config flow."""

    def __init__(self) -> None:
        """Set up config flow."""
        self._gw_type: str | None = None

    async def async_step_import(self, user_input: dict[str, Any]) -> FlowResult:
        """Import a config entry.

        This method is called by async_setup and it has already
        prepared the dict to be compatible with what a user would have
        entered from the frontend.
        Therefore we process it as though it came from the frontend.
        """
        if user_input[CONF_DEVICE] == MQTT_COMPONENT:
            user_input[CONF_GATEWAY_TYPE] = CONF_GATEWAY_TYPE_MQTT
        else:
            try:
                await self.hass.async_add_executor_job(
                    is_serial_port, user_input[CONF_DEVICE]
                )
            except vol.Invalid:
                user_input[CONF_GATEWAY_TYPE] = CONF_GATEWAY_TYPE_TCP
            else:
                user_input[CONF_GATEWAY_TYPE] = CONF_GATEWAY_TYPE_SERIAL

        result: FlowResult = await self.async_step_user(user_input=user_input)
        if errors := result.get("errors"):
            return self.async_abort(reason=next(iter(errors.values())))
        return result

    async def async_step_user(
        self, user_input: dict[str, str] | None = None
    ) -> FlowResult:
        """Create a config entry from frontend user input."""
        schema = {vol.Required(CONF_GATEWAY_TYPE): vol.In(CONF_GATEWAY_TYPE_ALL)}
        schema = vol.Schema(schema)
        errors = {}

        if user_input is not None:
            gw_type = self._gw_type = user_input[CONF_GATEWAY_TYPE]
            input_pass = user_input if CONF_DEVICE in user_input else None
            if gw_type == CONF_GATEWAY_TYPE_MQTT:
                # Naive check that doesn't consider config entry state.
                if MQTT_DOMAIN in self.hass.config.components:
                    return await self.async_step_gw_mqtt(input_pass)

                errors["base"] = "mqtt_required"
            if gw_type == CONF_GATEWAY_TYPE_TCP:
                return await self.async_step_gw_tcp(input_pass)
            if gw_type == CONF_GATEWAY_TYPE_SERIAL:
                return await self.async_step_gw_serial(input_pass)

        return self.async_show_form(step_id="user", data_schema=schema, errors=errors)

    async def async_step_gw_serial(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Create config entry for a serial gateway."""
        errors: dict[str, str] = {}
        if user_input is not None:
            errors.update(
                await self.validate_common(CONF_GATEWAY_TYPE_SERIAL, errors, user_input)
            )
            if not errors:
                return self._async_create_entry(user_input)

        user_input = user_input or {}
        schema = _get_schema_common(user_input)
        schema[
            vol.Required(
                CONF_BAUD_RATE,
                default=user_input.get(CONF_BAUD_RATE, DEFAULT_BAUD_RATE),
            )
        ] = cv.positive_int
        schema[
            vol.Required(
                CONF_DEVICE, default=user_input.get(CONF_DEVICE, "/dev/ttyACM0")
            )
        ] = str

        schema = vol.Schema(schema)
        return self.async_show_form(
            step_id="gw_serial", data_schema=schema, errors=errors
        )

    async def async_step_gw_tcp(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Create a config entry for a tcp gateway."""
        errors = {}
        if user_input is not None:
            if CONF_TCP_PORT in user_input:
                port: int = user_input[CONF_TCP_PORT]
                if not (0 < port <= 65535):
                    errors[CONF_TCP_PORT] = "port_out_of_range"

            errors.update(
                await self.validate_common(CONF_GATEWAY_TYPE_TCP, errors, user_input)
            )
            if not errors:
                return self._async_create_entry(user_input)

        user_input = user_input or {}
        schema = _get_schema_common(user_input)
        schema[
            vol.Required(CONF_DEVICE, default=user_input.get(CONF_DEVICE, "127.0.0.1"))
        ] = str
        # Don't use cv.port as that would show a slider *facepalm*
        schema[
            vol.Optional(
                CONF_TCP_PORT, default=user_input.get(CONF_TCP_PORT, DEFAULT_TCP_PORT)
            )
        ] = vol.Coerce(int)

        schema = vol.Schema(schema)
        return self.async_show_form(step_id="gw_tcp", data_schema=schema, errors=errors)

    def _check_topic_exists(self, topic: str) -> bool:
        for other_config in self._async_current_entries():
            if topic == other_config.data.get(
                CONF_TOPIC_IN_PREFIX
            ) or topic == other_config.data.get(CONF_TOPIC_OUT_PREFIX):
                return True
        return False

    async def async_step_gw_mqtt(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Create a config entry for a mqtt gateway."""
        errors = {}
        if user_input is not None:
            user_input[CONF_DEVICE] = MQTT_COMPONENT

            try:
                valid_subscribe_topic(user_input[CONF_TOPIC_IN_PREFIX])
            except vol.Invalid:
                errors[CONF_TOPIC_IN_PREFIX] = "invalid_subscribe_topic"
            else:
                if self._check_topic_exists(user_input[CONF_TOPIC_IN_PREFIX]):
                    errors[CONF_TOPIC_IN_PREFIX] = "duplicate_topic"

            try:
                valid_publish_topic(user_input[CONF_TOPIC_OUT_PREFIX])
            except vol.Invalid:
                errors[CONF_TOPIC_OUT_PREFIX] = "invalid_publish_topic"
            if not errors:
                if (
                    user_input[CONF_TOPIC_IN_PREFIX]
                    == user_input[CONF_TOPIC_OUT_PREFIX]
                ):
                    errors[CONF_TOPIC_OUT_PREFIX] = "same_topic"
                elif self._check_topic_exists(user_input[CONF_TOPIC_OUT_PREFIX]):
                    errors[CONF_TOPIC_OUT_PREFIX] = "duplicate_topic"

            errors.update(
                await self.validate_common(CONF_GATEWAY_TYPE_MQTT, errors, user_input)
            )
            if not errors:
                return self._async_create_entry(user_input)

        user_input = user_input or {}
        schema = _get_schema_common(user_input)
        schema[
            vol.Required(CONF_RETAIN, default=user_input.get(CONF_RETAIN, True))
        ] = bool
        schema[
            vol.Required(
                CONF_TOPIC_IN_PREFIX, default=user_input.get(CONF_TOPIC_IN_PREFIX, "")
            )
        ] = str
        schema[
            vol.Required(
                CONF_TOPIC_OUT_PREFIX, default=user_input.get(CONF_TOPIC_OUT_PREFIX, "")
            )
        ] = str

        schema = vol.Schema(schema)
        return self.async_show_form(
            step_id="gw_mqtt", data_schema=schema, errors=errors
        )

    @callback
    def _async_create_entry(self, user_input: dict[str, Any]) -> FlowResult:
        """Create the config entry."""
        return self.async_create_entry(
            title=f"{user_input[CONF_DEVICE]}",
            data={**user_input, CONF_GATEWAY_TYPE: self._gw_type},
        )

    def _normalize_persistence_file(self, path: str) -> str:
        return os.path.realpath(os.path.normcase(self.hass.config.path(path)))

    async def validate_common(
        self,
        gw_type: ConfGatewayType,
        errors: dict[str, str],
        user_input: dict[str, Any],
    ) -> dict[str, str]:
        """Validate parameters common to all gateway types."""
        errors.update(_validate_version(user_input[CONF_VERSION]))

        if gw_type != CONF_GATEWAY_TYPE_MQTT:
            if gw_type == CONF_GATEWAY_TYPE_TCP:
                verification_func = is_socket_address
            else:
                verification_func = is_serial_port

            try:
                await self.hass.async_add_executor_job(
                    verification_func, user_input.get(CONF_DEVICE)
                )
            except vol.Invalid:
                errors[CONF_DEVICE] = (
                    "invalid_ip"
                    if gw_type == CONF_GATEWAY_TYPE_TCP
                    else "invalid_serial"
                )
        if CONF_PERSISTENCE_FILE in user_input:
            try:
                is_persistence_file(user_input[CONF_PERSISTENCE_FILE])
            except vol.Invalid:
                errors[CONF_PERSISTENCE_FILE] = "invalid_persistence_file"
            else:
                real_persistence_path = user_input[
                    CONF_PERSISTENCE_FILE
                ] = self._normalize_persistence_file(user_input[CONF_PERSISTENCE_FILE])
                for other_entry in self._async_current_entries():
                    if CONF_PERSISTENCE_FILE not in other_entry.data:
                        continue
                    if real_persistence_path == self._normalize_persistence_file(
                        other_entry.data[CONF_PERSISTENCE_FILE]
                    ):
                        errors[CONF_PERSISTENCE_FILE] = "duplicate_persistence_file"
                        break

        for other_entry in self._async_current_entries():
            if _is_same_device(gw_type, user_input, other_entry):
                errors["base"] = "already_configured"
                break

        # if no errors so far, try to connect
        if not errors and not await try_connect(self.hass, gw_type, user_input):
            errors["base"] = "cannot_connect"

        return errors