# Source code for discord.ui.modal

from __future__ import annotations

import asyncio
import os
import sys
import time
import traceback
from functools import partial
from itertools import groupby
from typing import TYPE_CHECKING, Any, Callable

from .input_text import InputText

# Public names this module exports via ``from ... import *``.
__all__ = (
    "Modal",
    "ModalStore",
)


if TYPE_CHECKING:
    from ..interactions import Interaction
    from ..state import ConnectionState





class _ModalWeights:
    """Book-keeping for how much width is occupied in each of five modal rows.

    ``weights[i]`` is the summed ``width`` of the items placed in row ``i``;
    no row may exceed a total width of 5. Items that request a specific
    ``row`` are placed before items that do not, and the latter go into the
    first row with enough room left.
    """

    __slots__ = ("weights",)

    def __init__(self, children: list[InputText]):
        self.weights: list[int] = [0, 0, 0, 0, 0]

        def _row_or_last(child: InputText) -> int:
            # Children without an explicit row sort after every pinned child,
            # so pinned placements are honoured before auto-placement runs.
            return sys.maxsize if child.row is None else child.row

        for child in sorted(children, key=_row_or_last):
            self.add_item(child)

    def find_open_space(self, item: InputText) -> int:
        """Return the index of the first row that can still fit ``item``.

        Raises:
            ValueError: every row is too full to take ``item.width`` more.
        """
        slot = next(
            (
                position
                for position, used in enumerate(self.weights)
                if used + item.width <= 5
            ),
            None,
        )
        if slot is None:
            raise ValueError("could not find open space for item")
        return slot

    def add_item(self, item: InputText) -> None:
        """Reserve width for ``item`` and record the row it was rendered in.

        Raises:
            ValueError: the requested row (or every row, when none was
                requested) lacks room for the item.
        """
        row = item.row
        if row is None:
            row = self.find_open_space(item)
            self.weights[row] += item.width
        else:
            total = self.weights[row] + item.width
            if total > 5:
                raise ValueError(
                    f"item would not fit at row {item.row} ({total} > 5 width)"
                )
            self.weights[row] = total
        item._rendered_row = row

    def remove_item(self, item: InputText) -> None:
        """Release the width held by ``item``; a no-op if it was never placed."""
        rendered = item._rendered_row
        if rendered is None:
            return
        self.weights[rendered] -= item.width
        item._rendered_row = None

    def clear(self) -> None:
        """Reset every row to empty (items' ``_rendered_row`` is left untouched)."""
        self.weights = [0] * 5


class ModalStore:
    """Registry of modals awaiting submission, keyed by ``(user_id, custom_id)``.

    :meth:`dispatch` routes an incoming submission to the matching modal and
    unregisters that modal once its callback has completed.
    """

    def __init__(self, state: ConnectionState) -> None:
        # (user_id, custom_id) : Modal
        self._modals: dict[tuple[int, str], Modal] = {}
        self._state: ConnectionState = state

    def add_modal(self, modal: Modal, user_id: int) -> None:
        """Register ``modal`` for ``user_id`` and start it listening.

        Any modal already registered under the same ``(user_id, custom_id)``
        key is replaced.
        """
        self._modals[(user_id, modal.custom_id)] = modal
        modal._start_listening_from_store(self)

    def remove_modal(self, modal: Modal, user_id: int) -> None:
        """Stop ``modal`` and unregister it for ``user_id``.

        The entry is deleted only while the key still refers to this exact
        ``modal``. If the key has since been re-bound to a newer modal (same
        user, same ``custom_id``), the newer modal stays registered. Removing
        a modal that is no longer registered is a no-op, not a ``KeyError``.
        """
        modal.stop()
        key = (user_id, modal.custom_id)
        # Identity check: never evict a different modal that reused this key.
        if self._modals.get(key) is modal:
            del self._modals[key]

    async def dispatch(self, user_id: int, custom_id: str, interaction: Interaction):
        """Deliver a modal submission to the modal registered for it.

        Does nothing (returns ``None``) when no modal is registered under
        ``(user_id, custom_id)``. Otherwise, each submitted component's data
        is passed to ``refresh_state`` of the first child with a matching
        ``custom_id``. The modal's ``callback`` is then awaited and the modal
        is removed from the store.

        Any exception raised along the way is handed to ``modal.on_error`` and
        its result is returned. In that case this method does not unregister
        the modal.
        """
        key = (user_id, custom_id)
        value = self._modals.get(key)
        if value is None:
            return

        try:
            # Submitted data is nested as rows -> components; flatten it.
            components = [
                component
                for parent_component in interaction.data["components"]
                for component in parent_component["components"]
            ]
            for component in components:
                for child in value.children:
                    if child.custom_id == component["custom_id"]:  # type: ignore
                        child.refresh_state(component)
                        break
            await value.callback(interaction)
            self.remove_modal(value, user_id)
        except Exception as e:
            return await value.on_error(e, interaction)