"""Config flow for Tradfri."""
import asyncio
from collections import OrderedDict
from uuid import uuid4

import async_timeout
import voluptuous as vol

from homeassistant import config_entries

from .const import (
    CONF_IMPORT_GROUPS,
    CONF_IDENTITY,
    CONF_HOST,
    CONF_KEY,
    CONF_GATEWAY_ID,
)

KEY_SECURITY_CODE = "security_code"
KEY_IMPORT_GROUPS = "import_groups"


class AuthError(Exception):
    """Exception if authentication occurs."""

    def __init__(self, code):
        """Initialize exception."""
        super().__init__()
        self.code = code


@config_entries.HANDLERS.register("tradfri")
class FlowHandler(config_entries.ConfigFlow):
    """Handle a config flow."""

    VERSION = 1

    def __init__(self):
        """Initialize flow."""
        self._host = None
        self._import_groups = False

    async def async_step_user(self, user_input=None):
        """Handle a flow initialized by the user."""
        return await self.async_step_auth()

    async def async_step_auth(self, user_input=None):
        """Handle the authentication with a gateway."""
        errors = {}

        if user_input is not None:
            host = user_input.get(CONF_HOST, self._host)
            try:
                auth = await authenticate(
                    self.hass, host, user_input[KEY_SECURITY_CODE]
                )

                # We don't ask for import group anymore as group state
                # is not reliable, don't want to show that to the user.
                # But we still allow specifying import group via config yaml.
                auth[CONF_IMPORT_GROUPS] = self._import_groups

                return await self._entry_from_data(auth)

            except AuthError as err:
                if err.code == "invalid_security_code":
                    errors[KEY_SECURITY_CODE] = err.code
                else:
                    errors["base"] = err.code

        fields = OrderedDict()

        if self._host is None:
            fields[vol.Required(CONF_HOST)] = str

        fields[vol.Required(KEY_SECURITY_CODE)] = str

        return self.async_show_form(
            step_id="auth", data_schema=vol.Schema(fields), errors=errors
        )

    async def async_step_zeroconf(self, user_input):
        """Handle zeroconf discovery."""
        host = user_input["host"]

        # pylint: disable=unsupported-assignment-operation
        self.context["host"] = host

        if any(host == flow["context"]["host"] for flow in self._async_in_progress()):
            return self.async_abort(reason="already_in_progress")

        for entry in self._async_current_entries():
            if entry.data[CONF_HOST] == host:
                return self.async_abort(reason="already_configured")

        self._host = host
        return await self.async_step_auth()

    async_step_homekit = async_step_zeroconf

    async def async_step_import(self, user_input):
        """Import a config entry."""
        for entry in self._async_current_entries():
            if entry.data[CONF_HOST] == user_input["host"]:
                return self.async_abort(reason="already_configured")

        # Happens if user has host directly in configuration.yaml
        if "key" not in user_input:
            self._host = user_input["host"]
            self._import_groups = user_input[CONF_IMPORT_GROUPS]
            return await self.async_step_auth()

        try:
            data = await get_gateway_info(
                self.hass,
                user_input["host"],
                # Old config format had a fixed identity
                user_input.get("identity", "homeassistant"),
                user_input["key"],
            )

            data[CONF_IMPORT_GROUPS] = user_input[CONF_IMPORT_GROUPS]

            return await self._entry_from_data(data)
        except AuthError:
            # If we fail to connect, just pass it on to discovery
            self._host = user_input["host"]
            return await self.async_step_auth()

    async def _entry_from_data(self, data):
        """Create an entry from data."""
        host = data[CONF_HOST]
        gateway_id = data[CONF_GATEWAY_ID]

        same_hub_entries = [
            entry.entry_id
            for entry in self._async_current_entries()
            if entry.data[CONF_GATEWAY_ID] == gateway_id
            or entry.data[CONF_HOST] == host
        ]

        if same_hub_entries:
            await asyncio.wait(
                [
                    self.hass.config_entries.async_remove(entry_id)
                    for entry_id in same_hub_entries
                ]
            )

        return self.async_create_entry(title=host, data=data)


async def authenticate(hass, host, security_code):
    """Authenticate with a Tradfri hub."""
    from pytradfri.api.aiocoap_api import APIFactory
    from pytradfri import RequestError

    identity = uuid4().hex

    api_factory = APIFactory(host, psk_id=identity)

    try:
        with async_timeout.timeout(5):
            key = await api_factory.generate_psk(security_code)
    except RequestError:
        raise AuthError("invalid_security_code")
    except asyncio.TimeoutError:
        raise AuthError("timeout")

    return await get_gateway_info(hass, host, identity, key)


async def get_gateway_info(hass, host, identity, key):
    """Return info for the gateway."""
    from pytradfri.api.aiocoap_api import APIFactory
    from pytradfri import Gateway, RequestError

    try:
        factory = APIFactory(host, psk_id=identity, psk=key, loop=hass.loop)

        api = factory.request
        gateway = Gateway()
        gateway_info_result = await api(gateway.get_gateway_info())

        await factory.shutdown()
    except (OSError, RequestError):
        # We're also catching OSError as PyTradfri doesn't catch that one yet
        # Upstream PR: https://github.com/ggravlingen/pytradfri/pull/189
        raise AuthError("cannot_connect")

    return {
        CONF_HOST: host,
        CONF_IDENTITY: identity,
        CONF_KEY: key,
        CONF_GATEWAY_ID: gateway_info_result.id,
    }