"""
Websocket based API for Home Assistant.

For more details about this component, please refer to the documentation at
https://home-assistant.io/developers/websocket_api/
"""
import asyncio
from concurrent import futures
from contextlib import suppress
from functools import partial
import json
import logging

from aiohttp import web
import voluptuous as vol
from voluptuous.humanize import humanize_error

from homeassistant.const import (
    MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
    __version__)
from homeassistant.core import callback
from homeassistant.loader import bind_hass
from homeassistant.remote import JSONEncoder
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.auth import validate_password
from homeassistant.components.http.const import KEY_AUTHENTICATED
from homeassistant.components.http.ban import process_wrong_login

# Name of this component; also the hass.data key under which registered
# websocket command handlers are stored (see async_register_command).
DOMAIN = 'websocket_api'

# Endpoint served by WebsocketAPIView.
URL = '/api/websocket'
# async_setup registers its view through hass.http, so the http component
# is required.
DEPENDENCIES = ('http',)

# Maximum number of outgoing messages that may be queued for one connection.
# When the queue overflows, the connection is shut down (see
# ActiveConnection.send_message_outside and the QueueFull handling in
# ActiveConnection.handle).
MAX_PENDING_MSG = 512

# Error codes sent in the 'error' -> 'code' field of failed result messages.
ERR_ID_REUSE = 1
ERR_INVALID_FORMAT = 2
ERR_NOT_FOUND = 3
ERR_UNKNOWN_COMMAND = 4

# Values of the 'type' field of incoming and outgoing messages.
TYPE_AUTH = 'auth'
TYPE_AUTH_INVALID = 'auth_invalid'
TYPE_AUTH_OK = 'auth_ok'
TYPE_AUTH_REQUIRED = 'auth_required'
TYPE_CALL_SERVICE = 'call_service'
TYPE_EVENT = 'event'
TYPE_GET_CONFIG = 'get_config'
TYPE_GET_SERVICES = 'get_services'
TYPE_GET_STATES = 'get_states'
TYPE_PING = 'ping'
TYPE_PONG = 'pong'
TYPE_RESULT = 'result'
TYPE_SUBSCRIBE_EVENTS = 'subscribe_events'
TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events'

_LOGGER = logging.getLogger(__name__)

# Serializer used by the connection writer for outgoing messages. It uses
# Home Assistant's JSONEncoder, presumably so core objects (e.g. states
# returned by get_states) can be encoded; verify against
# homeassistant.remote.JSONEncoder.
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)

# Schema for the client's reply to auth_required. 'api_password' and
# 'access_token' share the vol.Exclusive group 'auth', so at most one may be
# given. If neither is given, the message still validates but authentication
# fails in ActiveConnection.handle.
AUTH_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('type'): TYPE_AUTH,
    vol.Exclusive('api_password', 'auth'): str,
    vol.Exclusive('access_token', 'auth'): str,
})

# Minimal requirements of a message: an 'id' (validated by cv.positive_int)
# and a string 'type'. Extra keys are allowed here; the schema registered for
# the command validates the full message afterwards.
MINIMAL_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('id'): cv.positive_int,
    vol.Required('type'): cv.string,
}, extra=vol.ALLOW_EXTRA)
# Base schema to extend by message handlers. It only covers 'id'; each
# command schema adds its own 'type' and payload keys.
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({
    vol.Required('id'): cv.positive_int,
})


# subscribe_events: 'event_type' defaults to MATCH_ALL.
SCHEMA_SUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({
    vol.Required('type'): TYPE_SUBSCRIBE_EVENTS,
    vol.Optional('event_type', default=MATCH_ALL): str,
})


# unsubscribe_events: 'subscription' is the 'id' of the subscribe_events
# message that created the subscription.
SCHEMA_UNSUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({
    vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS,
    vol.Required('subscription'): cv.positive_int,
})


# call_service: 'service_data' is optional and forwarded to the service call.
SCHEMA_CALL_SERVICE = BASE_COMMAND_MESSAGE_SCHEMA.extend({
    vol.Required('type'): TYPE_CALL_SERVICE,
    vol.Required('domain'): str,
    vol.Required('service'): str,
    vol.Optional('service_data'): dict
})


# The following commands carry no payload beyond 'id' and 'type'.
SCHEMA_GET_STATES = BASE_COMMAND_MESSAGE_SCHEMA.extend({
    vol.Required('type'): TYPE_GET_STATES,
})


SCHEMA_GET_SERVICES = BASE_COMMAND_MESSAGE_SCHEMA.extend({
    vol.Required('type'): TYPE_GET_SERVICES,
})


SCHEMA_GET_CONFIG = BASE_COMMAND_MESSAGE_SCHEMA.extend({
    vol.Required('type'): TYPE_GET_CONFIG,
})


SCHEMA_PING = BASE_COMMAND_MESSAGE_SCHEMA.extend({
    vol.Required('type'): TYPE_PING,
})


# Define the possible errors that occur when connections are cancelled.
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
# that futures.CancelledErrors can also occur in some situations.
# Used by the connection writer's suppress() and by ActiveConnection.handle
# to treat cancellation as a normal shutdown.
CANCELLATION_ERRORS = (asyncio.CancelledError, futures.CancelledError)


def auth_ok_message():
    """Build the message telling the client it is authenticated.

    The payload includes the Home Assistant version the server runs.
    """
    return dict(type=TYPE_AUTH_OK, ha_version=__version__)


def auth_required_message():
    """Build the message asking the client to authenticate.

    The payload includes the Home Assistant version the server runs.
    """
    return dict(type=TYPE_AUTH_REQUIRED, ha_version=__version__)


def auth_invalid_message(message):
    """Build the message rejecting authentication, with a reason text."""
    return dict(type=TYPE_AUTH_INVALID, message=message)


def event_message(iden, event):
    """Wrap a bus event for the subscription identified by ``iden``.

    The event is serialized through its ``as_dict()`` method.
    """
    serialized = event.as_dict()
    return dict(id=iden, type=TYPE_EVENT, event=serialized)


def error_message(iden, code, message):
    """Build a failed result for request ``iden``.

    ``code`` is one of the ERR_* constants; ``message`` is a human-readable
    explanation. Both are nested under the 'error' key.
    """
    details = dict(code=code, message=message)
    return dict(id=iden, type=TYPE_RESULT, success=False, error=details)


def pong_message(iden):
    """Build the pong reply to the ping with id ``iden``."""
    return dict(id=iden, type=TYPE_PONG)


def result_message(iden, result=None):
    """Build a successful result for request ``iden``.

    ``result`` is the payload; it stays ``None`` for commands that only
    acknowledge.
    """
    return dict(id=iden, type=TYPE_RESULT, success=True, result=result)


@bind_hass
@callback
def async_register_command(hass, command, handler, schema):
    """Register a websocket command.

    The registry lives in ``hass.data[DOMAIN]``. It maps the command type to
    a ``(handler, schema)`` tuple and is created on first use. Registering
    the same command again replaces the previous entry.
    """
    registry = hass.data.get(DOMAIN)
    if registry is None:
        registry = {}
        hass.data[DOMAIN] = registry
    registry[command] = (handler, schema)


async def async_setup(hass, config):
    """Initialize the websocket API.

    Serves the websocket endpoint and registers the built-in commands.
    """
    hass.http.register_view(WebsocketAPIView)

    # (command type, handler, schema), registered in this order.
    builtin_commands = (
        (TYPE_SUBSCRIBE_EVENTS, handle_subscribe_events,
         SCHEMA_SUBSCRIBE_EVENTS),
        (TYPE_UNSUBSCRIBE_EVENTS, handle_unsubscribe_events,
         SCHEMA_UNSUBSCRIBE_EVENTS),
        (TYPE_CALL_SERVICE, handle_call_service, SCHEMA_CALL_SERVICE),
        (TYPE_GET_STATES, handle_get_states, SCHEMA_GET_STATES),
        (TYPE_GET_SERVICES, handle_get_services, SCHEMA_GET_SERVICES),
        (TYPE_GET_CONFIG, handle_get_config, SCHEMA_GET_CONFIG),
        (TYPE_PING, handle_ping, SCHEMA_PING),
    )
    for command, handler, schema in builtin_commands:
        async_register_command(hass, command, handler, schema)

    return True


class WebsocketAPIView(HomeAssistantView):
    """Serve the websocket API endpoint.

    HTTP-level authentication is not required (``requires_auth = False``).
    The connection runs its own auth handshake in ActiveConnection.handle.
    """

    name = "websocketapi"
    url = URL
    requires_auth = False

    async def get(self, request):
        """Upgrade the request to a websocket and serve it until closed."""
        connection = ActiveConnection(request.app['hass'], request)
        return await connection.handle()


class ActiveConnection:
    """Handle an active websocket client connection.

    One instance is created per incoming websocket request. ``handle`` runs
    the authentication phase and then the command loop. A separate writer
    task (``_writer``) drains the ``to_write`` queue to the socket.
    """

    def __init__(self, hass, request):
        """Initialize an active connection.

        The websocket response and the handler/writer tasks are only
        created once ``handle`` runs.
        """
        self.hass = hass
        self.request = request
        self.wsock = None
        # Subscription id (id of the subscribe_events message) -> the
        # unsubscribe callable from hass.bus.async_listen. All entries are
        # called when the connection ends.
        self.event_listeners = {}
        # Outgoing messages. The queue is bounded so a client that stops
        # reading cannot make the server buffer without limit.
        self.to_write = asyncio.Queue(maxsize=MAX_PENDING_MSG, loop=hass.loop)
        self._handle_task = None
        self._writer_task = None

    def debug(self, message1, message2=''):
        """Log a debug message tagged with this connection's socket id."""
        _LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2)

    def log_error(self, message1, message2=''):
        """Log an error message tagged with this connection's socket id."""
        _LOGGER.error("WS %s: %s %s", id(self.wsock), message1, message2)

    async def _writer(self):
        """Write queued outgoing messages to the socket.

        Runs until it dequeues the ``None`` sentinel, the socket closes, or
        the task is cancelled. A message that cannot be serialized to JSON is
        logged and skipped.
        """
        # Exceptions if Socket disconnected or cancelled by connection handler
        with suppress(RuntimeError, *CANCELLATION_ERRORS):
            while not self.wsock.closed:
                message = await self.to_write.get()
                if message is None:
                    break
                self.debug("Sending", message)
                try:
                    await self.wsock.send_json(message, dumps=JSON_DUMP)
                except TypeError as err:
                    _LOGGER.error('Unable to serialize to JSON: %s\n%s',
                                  err, message)

    @callback
    def send_message_outside(self, message):
        """Send a message to the client outside of the main task.

        Closes connection if the client is not reading the messages.

        Async friendly.
        """
        try:
            self.to_write.put_nowait(message)
        except asyncio.QueueFull:
            self.log_error("Client exceeded max pending messages [2]:",
                           MAX_PENDING_MSG)
            self.cancel()

    @callback
    def cancel(self):
        """Cancel the connection handler task and the writer task."""
        self._handle_task.cancel()
        self._writer_task.cancel()

    async def handle(self):
        """Handle the websocket connection.

        Steps:
        1. Prepare the websocket.
        2. Authenticate the client, unless the HTTP request is already
           authenticated.
        3. Dispatch each incoming message to the handler registered for its
           ``type``.

        Any error ends the connection. On the way out, all event
        subscriptions are removed, pending messages are flushed where
        possible, and the socket is closed.

        Returns the ``web.WebSocketResponse`` to aiohttp.
        """
        request = self.request
        wsock = self.wsock = web.WebSocketResponse(heartbeat=55)
        await wsock.prepare(request)
        self.debug("Connected")

        # Get a reference to current task so we can cancel our connection
        self._handle_task = asyncio.Task.current_task(loop=self.hass.loop)

        @callback
        def handle_hass_stop(event):
            """Cancel this connection."""
            self.cancel()

        unsub_stop = self.hass.bus.async_listen(
            EVENT_HOMEASSISTANT_STOP, handle_hass_stop)
        self._writer_task = self.hass.async_add_job(self._writer())
        final_message = None
        msg = None
        authenticated = False

        try:
            if request[KEY_AUTHENTICATED]:
                authenticated = True

            else:
                await self.wsock.send_json(auth_required_message())
                msg = await wsock.receive_json()
                msg = AUTH_MESSAGE_SCHEMA(msg)

                if 'api_password' in msg:
                    authenticated = validate_password(
                        request, msg['api_password'])

                elif 'access_token' in msg:
                    token = self.hass.auth.async_get_access_token(
                        msg['access_token'])
                    authenticated = token is not None

            if not authenticated:
                self.debug("Invalid password")
                await self.wsock.send_json(
                    auth_invalid_message('Invalid password'))
                await process_wrong_login(request)
                return wsock

            await self.wsock.send_json(auth_ok_message())

            # ---------- AUTH PHASE OVER ----------

            msg = await wsock.receive_json()
            # Highest message id accepted so far; ids must strictly increase.
            last_id = 0
            handlers = self.hass.data[DOMAIN]

            while msg:
                self.debug("Received", msg)
                msg = MINIMAL_MESSAGE_SCHEMA(msg)
                cur_id = msg['id']

                if cur_id <= last_id:
                    self.to_write.put_nowait(error_message(
                        cur_id, ERR_ID_REUSE,
                        'Identifier values have to increase.'))

                elif msg['type'] not in handlers:
                    self.log_error(
                        'Received invalid command: {}'.format(msg['type']))
                    self.to_write.put_nowait(error_message(
                        cur_id, ERR_UNKNOWN_COMMAND,
                        'Unknown command.'))

                else:
                    handler, schema = handlers[msg['type']]
                    handler(self.hass, self, schema(msg))

                # Never lower the high-water mark. If a rejected
                # (non-increasing) id reset it, already-used ids would be
                # accepted again. For example, a reused subscribe_events id
                # would overwrite its entry in event_listeners and leak the
                # old bus listener.
                last_id = max(last_id, cur_id)
                msg = await wsock.receive_json()

        except vol.Invalid as err:
            error_msg = "Message incorrectly formatted: "
            if msg:
                error_msg += humanize_error(msg, err)
            else:
                error_msg += str(err)

            self.log_error(error_msg)

            if not authenticated:
                final_message = auth_invalid_message(error_msg)

            else:
                if isinstance(msg, dict):
                    iden = msg.get('id')
                else:
                    iden = None

                final_message = error_message(
                    iden, ERR_INVALID_FORMAT, error_msg)

        except TypeError:
            if wsock.closed:
                self.debug("Connection closed by client")
            else:
                _LOGGER.exception("Unexpected TypeError: %s", msg)

        except ValueError as err:
            # Use a dedicated local instead of rebinding ``msg`` (the last
            # received message) to an error string.
            error_msg = "Received invalid JSON"
            value = getattr(err, 'doc', None)  # Py3.5+ only
            if value:
                error_msg += ': {}'.format(value)
            self.log_error(error_msg)
            self._writer_task.cancel()

        except CANCELLATION_ERRORS:
            self.debug("Connection cancelled by server")

        except asyncio.QueueFull:
            self.log_error("Client exceeded max pending messages [1]:",
                           MAX_PENDING_MSG)
            self._writer_task.cancel()

        except Exception:  # pylint: disable=broad-except
            error = "Unexpected error inside websocket API. "
            if msg is not None:
                error += str(msg)
            _LOGGER.exception(error)

        finally:
            unsub_stop()

            for unsub in self.event_listeners.values():
                unsub()

            try:
                if final_message is not None:
                    self.to_write.put_nowait(final_message)
                # The None sentinel tells the writer to stop.
                self.to_write.put_nowait(None)
                # Make sure all error messages are written before closing
                await self._writer_task
            except asyncio.QueueFull:
                self._writer_task.cancel()

            await wsock.close()
            self.debug("Closed connection")

        return wsock


@callback
def handle_subscribe_events(hass, connection, msg):
    """Subscribe the connection to bus events of ``msg['event_type']``.

    Every matching event except EVENT_TIME_CHANGED is forwarded to the
    client, tagged with this message's id. The unsubscribe callable is
    stored in ``connection.event_listeners`` under that id. A success result
    is queued right away.

    Async friendly.
    """
    subscription_id = msg['id']

    async def forward_events(event):
        """Forward events to websocket."""
        if event.event_type != EVENT_TIME_CHANGED:
            connection.send_message_outside(
                event_message(subscription_id, event))

    unsubscribe = hass.bus.async_listen(msg['event_type'], forward_events)
    connection.event_listeners[subscription_id] = unsubscribe

    connection.to_write.put_nowait(result_message(subscription_id))


@callback
def handle_unsubscribe_events(hass, connection, msg):
    """Cancel the subscription whose id is ``msg['subscription']``.

    Replies with a success result. If the connection has no such
    subscription, replies with an ERR_NOT_FOUND error instead.

    Async friendly.
    """
    listeners = connection.event_listeners
    subscription = msg['subscription']

    if subscription not in listeners:
        connection.to_write.put_nowait(error_message(
            msg['id'], ERR_NOT_FOUND, 'Subscription not found.'))
        return

    unsubscribe = listeners.pop(subscription)
    unsubscribe()
    connection.to_write.put_nowait(result_message(msg['id']))


@callback
def handle_call_service(hass, connection, msg):
    """Call the requested service, then acknowledge it.

    The call runs as a scheduled job. After ``async_call`` returns, a
    success result for ``msg['id']`` is queued via ``send_message_outside``.

    Async friendly.
    """
    async def _call_then_ack():
        """Perform the service call, then queue the acknowledgement."""
        # The trailing positional True is presumably async_call's "blocking"
        # flag (wait for the service to finish) -- confirm against
        # ServiceRegistry.async_call.
        await hass.services.async_call(
            msg['domain'], msg['service'], msg.get('service_data'), True)
        connection.send_message_outside(result_message(msg['id']))

    hass.async_add_job(_call_then_ack())


@callback
def handle_get_states(hass, connection, msg):
    """Reply with all current states from ``hass.states.async_all()``.

    Async friendly.
    """
    reply_id = msg['id']
    states = hass.states.async_all()
    connection.to_write.put_nowait(result_message(reply_id, states))


@callback
def handle_get_services(hass, connection, msg):
    """Reply with the descriptions of all services.

    The descriptions are gathered asynchronously in a scheduled job, and the
    result is queued once they are available.

    Async friendly.
    """
    async def _describe_and_reply():
        """Collect service descriptions and queue the result message."""
        descriptions = await async_get_all_descriptions(hass)
        connection.send_message_outside(
            result_message(msg['id'], descriptions))

    hass.async_add_job(_describe_and_reply())


@callback
def handle_get_config(hass, connection, msg):
    """Reply with the Home Assistant configuration (``hass.config.as_dict()``).

    Async friendly.
    """
    reply_id = msg['id']
    config = hass.config.as_dict()
    connection.to_write.put_nowait(result_message(reply_id, config))


@callback
def handle_ping(hass, connection, msg):
    """Answer a ping with a pong that carries the same id.

    Async friendly.
    """
    pong = pong_message(msg['id'])
    connection.to_write.put_nowait(pong)