from __future__ import annotations

import asyncio

from abc import ABC, abstractmethod
from ssl import SSLContext
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Optional,
    Union,
    cast,
)

from sanic.compat import Header
from sanic.constants import LocalCertCreator
from sanic.exceptions import (
    BadRequest,
    PayloadTooLarge,
    SanicException,
    ServerError,
)
from sanic.helpers import has_message_body
from sanic.http.constants import Stage
from sanic.http.stream import Stream
from sanic.http.tls.context import CertSelector, SanicSSLContext
from sanic.log import Colors, logger
from sanic.models.protocol_types import TransportProtocol
from sanic.models.server_types import ConnInfo


try:
    from aioquic.h0.connection import H0_ALPN, H0Connection
    from aioquic.h3.connection import H3_ALPN, H3Connection
    from aioquic.h3.events import (
        DatagramReceived,
        DataReceived,
        H3Event,
        HeadersReceived,
        WebTransportStreamDataReceived,
    )
    from aioquic.quic.configuration import QuicConfiguration
    from aioquic.tls import SessionTicket

    HTTP3_AVAILABLE = True
except ModuleNotFoundError:  # no cov
    HTTP3_AVAILABLE = False

if TYPE_CHECKING:
    from sanic import Sanic
    from sanic.request import Request
    from sanic.response import BaseHTTPResponse
    from sanic.server.protocols.http_protocol import Http3Protocol

    HttpConnection = Union[H0Connection, H3Connection]


class HTTP3Transport(TransportProtocol):
    """HTTP/3 transport implementation."""

    __slots__ = ("_protocol",)

    def __init__(self, protocol: Http3Protocol):
        self._protocol = protocol

    def get_protocol(self) -> Http3Protocol:
        return self._protocol

    def get_extra_info(self, info: str, default: Any = None) -> Any:
        if (
            info in ("socket", "sockname", "peername")
            and self._protocol._transport
        ):
            return self._protocol._transport.get_extra_info(info, default)
        elif info == "network_paths":
            return self._protocol._quic._network_paths
        elif info == "ssl_context":
            return self._protocol.app.state.ssl
        return default


class Receiver(ABC):
    """HTTP/3 receiver base class."""

    future: asyncio.Future

    def __init__(self, transmit, protocol, request: Request) -> None:
        self.transmit = transmit
        self.protocol = protocol
        self.request = request

    @abstractmethod
    async def run(self):  # no cov
        ...


class HTTPReceiver(Receiver, Stream):
    """HTTP/3 receiver implementation."""

    stage: Stage
    request: Request

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.request_body = None
        self.stage = Stage.IDLE
        self.headers_sent = False
        self.response: Optional[BaseHTTPResponse] = None
        self.request_max_size = self.protocol.request_max_size
        self.request_bytes = 0

    async def run(self, exception: Optional[Exception] = None):
        """Handle the request and response cycle."""
        self.stage = Stage.HANDLER
        self.head_only = self.request.method.upper() == "HEAD"

        if exception:
            logger.info(  # no cov
                f"{Colors.BLUE}[exception]: "
                f"{Colors.RED}{exception}{Colors.END}",
                exc_info=True,
                extra={"verbosity": 1},
            )
            await self.error_response(exception)
        else:
            try:
                logger.info(  # no cov
                    f"{Colors.BLUE}[request]:{Colors.END} {self.request}",
                    extra={"verbosity": 1},
                )
                await self.protocol.request_handler(self.request)
            except Exception as e:  # no cov
                # This should largely be handled within the request handler.
                # But, just in case...
                await self.run(e)
        self.stage = Stage.IDLE

    async def error_response(self, exception: Exception) -> None:
        """Handle response when exception encountered"""
        # From request and handler states we can respond, otherwise be silent
        app = self.protocol.app

        await app.handle_exception(self.request, exception)

    def _prepare_headers(
        self, response: BaseHTTPResponse
    ) -> list[tuple[bytes, bytes]]:
        size = len(response.body) if response.body else 0
        headers = response.headers
        status = response.status

        if not has_message_body(status) and (
            size
            or "content-length" in headers
            or "transfer-encoding" in headers
        ):
            headers.pop("content-length", None)
            headers.pop("transfer-encoding", None)
            logger.warning(  # no cov
                f"Message body set in response on {self.request.path}. "
                f"A {status} response may only have headers, no body."
            )
        elif "content-length" not in headers:
            if size:
                headers["content-length"] = size
            else:
                headers["transfer-encoding"] = "chunked"

        headers = [
            (b":status", str(response.status).encode()),
            *response.processed_headers,
        ]
        return headers

    def send_headers(self) -> None:
        """Send response headers to client"""
        logger.debug(  # no cov
            f"{Colors.BLUE}[send]: {Colors.GREEN}HEADERS{Colors.END}",
            extra={"verbosity": 2},
        )
        if not self.response:
            raise RuntimeError("no response")

        response = self.response
        headers = self._prepare_headers(response)

        self.protocol.connection.send_headers(
            stream_id=self.request.stream_id,
            headers=headers,
        )
        self.headers_sent = True
        self.stage = Stage.RESPONSE

        if self.response.body and not self.head_only:
            self._send(self.response.body, False)
        elif self.head_only:
            self.future.cancel()

    def respond(self, response: BaseHTTPResponse) -> BaseHTTPResponse:
        """Prepare response to client"""
        logger.debug(  # no cov
            f"{Colors.BLUE}[respond]:{Colors.END} {response}",
            extra={"verbosity": 2},
        )

        if self.stage is not Stage.HANDLER:
            self.stage = Stage.FAILED
            raise RuntimeError("Response already started")

        # Disconnect any earlier but unused response object
        if self.response is not None:
            self.response.stream = None

        self.response, response.stream = response, self

        return response

    def receive_body(self, data: bytes) -> None:
        """Receive request body from client"""
        self.request_bytes += len(data)
        if self.request_bytes > self.request_max_size:
            raise PayloadTooLarge("Request body exceeds the size limit")

        self.request.body += data

    async def send(self, data: bytes, end_stream: bool) -> None:
        """Send data to client"""
        logger.debug(  # no cov
            f"{Colors.BLUE}[send]: {Colors.GREEN}data={data.decode()} "
            f"end_stream={end_stream}{Colors.END}",
            extra={"verbosity": 2},
        )
        self._send(data, end_stream)

    def _send(self, data: bytes, end_stream: bool) -> None:
        if not self.headers_sent:
            self.send_headers()
        if self.stage is not Stage.RESPONSE:
            raise ServerError(f"not ready to send: {self.stage}")

        # Chunked
        if (
            self.response
            and self.response.headers.get("transfer-encoding") == "chunked"
        ):
            size = len(data)
            if end_stream:
                data = (
                    b"%x\r\n%b\r\n0\r\n\r\n" % (size, data)
                    if size
                    else b"0\r\n\r\n"
                )
            elif size:
                data = b"%x\r\n%b\r\n" % (size, data)

        logger.debug(  # no cov
            f"{Colors.BLUE}[transmitting]{Colors.END}",
            extra={"verbosity": 2},
        )
        self.protocol.connection.send_data(
            stream_id=self.request.stream_id,
            data=data,
            end_stream=end_stream,
        )
        self.transmit()

        if end_stream:
            self.stage = Stage.IDLE


class WebsocketReceiver(Receiver):  # noqa
    """Websocket receiver implementation."""

    async def run(self): ...


class WebTransportReceiver(Receiver):  # noqa
    """WebTransport receiver implementation."""

    async def run(self): ...


class Http3:
    """Internal helper for managing the HTTP/3 request/response cycle"""

    if HTTP3_AVAILABLE:
        HANDLER_PROPERTY_MAPPING = {
            DataReceived: "stream_id",
            HeadersReceived: "stream_id",
            DatagramReceived: "flow_id",
            WebTransportStreamDataReceived: "session_id",
        }

    def __init__(
        self,
        protocol: Http3Protocol,
        transmit: Callable[[], None],
    ) -> None:
        self.protocol = protocol
        self.transmit = transmit
        self.receivers: dict[int, Receiver] = {}

    def http_event_received(self, event: H3Event) -> None:
        logger.debug(  # no cov
            f"{Colors.BLUE}[http_event_received]: "
            f"{Colors.YELLOW}{event}{Colors.END}",
            extra={"verbosity": 2},
        )
        receiver, created_new = self.get_or_make_receiver(event)
        receiver = cast(HTTPReceiver, receiver)

        if isinstance(event, HeadersReceived) and created_new:
            receiver.future = asyncio.ensure_future(receiver.run())
        elif isinstance(event, DataReceived):
            try:
                receiver.receive_body(event.data)
            except Exception as e:
                receiver.future.cancel()
                receiver.future = asyncio.ensure_future(receiver.run(e))
        else:
            ...  # Intentionally here to help out Touchup
            logger.debug(  # no cov
                f"{Colors.RED}DOING NOTHING{Colors.END}",
                extra={"verbosity": 2},
            )

    def get_or_make_receiver(self, event: H3Event) -> tuple[Receiver, bool]:
        if (
            isinstance(event, HeadersReceived)
            and event.stream_id not in self.receivers
        ):
            request = self._make_request(event)
            receiver = HTTPReceiver(self.transmit, self.protocol, request)
            request.stream = receiver

            self.receivers[event.stream_id] = receiver
            return receiver, True
        else:
            ident = getattr(event, self.HANDLER_PROPERTY_MAPPING[type(event)])
            return self.receivers[ident], False

    def get_receiver_by_stream_id(self, stream_id: int) -> Receiver:
        return self.receivers[stream_id]

    def _make_request(self, event: HeadersReceived) -> Request:
        try:
            headers = Header(
                (
                    (k.decode("ASCII"), v.decode(errors="surrogateescape"))
                    for k, v in event.headers
                )
            )
        except UnicodeDecodeError:
            raise BadRequest(
                "Header names may only contain US-ASCII characters."
            )
        method = headers[":method"]
        path = headers[":path"]
        scheme = headers.pop(":scheme", "")
        authority = headers.pop(":authority", "")

        if authority:
            headers["host"] = authority

        try:
            url_bytes = path.encode("ASCII")
        except UnicodeEncodeError:
            raise BadRequest("URL may only contain US-ASCII characters.")

        transport = HTTP3Transport(self.protocol)
        request = self.protocol.request_class(
            url_bytes,
            headers,
            "3",
            method,
            transport,
            self.protocol.app,
            b"",
        )
        request.conn_info = ConnInfo(transport)
        request._stream_id = event.stream_id
        request._scheme = scheme

        return request


class SessionTicketStore:
    """
    Simple in-memory store for session tickets.
    """

    def __init__(self) -> None:
        self.tickets: dict[bytes, SessionTicket] = {}

    def add(self, ticket: SessionTicket) -> None:
        self.tickets[ticket.ticket] = ticket

    def pop(self, label: bytes) -> Optional[SessionTicket]:
        return self.tickets.pop(label, None)


def get_config(
    app: Sanic, ssl: Union[SanicSSLContext, CertSelector, SSLContext]
):
    # TODO:
    # - proper selection needed if service with multiple certs insted of
    #   just taking the first
    if isinstance(ssl, CertSelector):
        ssl = cast(SanicSSLContext, ssl.sanic_select[0])
    if app.config.LOCAL_CERT_CREATOR is LocalCertCreator.TRUSTME:
        raise SanicException(
            "Sorry, you cannot currently use trustme as a local certificate "
            "generator for an HTTP/3 server. This is not yet supported. You "
            "should be able to use mkcert instead. For more information, see: "
            "https://github.com/aiortc/aioquic/issues/295."
        )
    if not isinstance(ssl, SanicSSLContext):
        raise SanicException("SSLContext is not SanicSSLContext")

    config = QuicConfiguration(
        alpn_protocols=H3_ALPN + H0_ALPN + ["siduck"],
        is_client=False,
        max_datagram_frame_size=65536,
    )
    password = app.config.TLS_CERT_PASSWORD or None

    config.load_cert_chain(
        ssl.sanic["cert"], ssl.sanic["key"], password=password
    )

    return config
