#!/usr/bin/env python3

# MIT License
#
# Copyright (c) 2025 Onyx and Iris
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""
sendtext.py

This script provides functionality to send text messages over a network using the VBAN protocol.
It includes classes and functions to construct VBAN packets carrying text data, send them over UDP
with an optional rate limit, and handle command-line arguments for customization.

Classes:
    RTHeader: A dataclass representing the header of a VBAN packet, with methods to convert it to bytes.

Functions:
    ratelimit(func): A decorator to enforce a rate limit on a function.
    send(sock, args, cmd, framecounter): Sends a text message using the provided socket and packet.
    parse_args(): Parses command-line arguments and returns them as an argparse.Namespace object.
    main(args): Main function that sends each text command as a VBAN text packet over UDP.

Usage:
    Run the script with appropriate command-line arguments to send text messages.
    Example:
        To send a single message:
            python sendtext.py --host localhost --port 6980 --streamname Command1 --bps 256000 --channel 1 "strip[0].mute=1"
        To send multiple messages:
            python sendtext.py --host localhost --port 6980 --streamname Command1 --bps 256000 --channel 1 "strip[0].mute=1" "strip[1].mute=0"
"""

import argparse
import functools
import logging
import socket
import time
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


def ratelimit(func):
    """
    Decorator to enforce a minimum interval between calls to a function.

    The interval (in seconds) is read on every call from the 'ratelimit'
    attribute of the decorated function's 'args' parameter. 'args' is looked
    up first as a keyword argument and then as the second positional argument,
    so the limit applies however the caller passes it. A missing, None or zero
    value disables rate limiting for that call.

    The timestamp of the previous call lives in the closure, so it is shared by
    every call made through the same decorated function. A monotonic clock is
    used so that adjustments to the system clock cannot shorten or lengthen the
    enforced interval.

    Args:
        func: The function to be rate limited. Must accept 'args' as a parameter
              with a 'ratelimit' attribute.

    Returns:
        The decorated function with rate limiting applied.
    """
    # Monotonic timestamp of the previous call; None until the first call so
    # that the very first invocation is never delayed.
    last_call_time = None

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal last_call_time

        # Locate the object carrying the 'ratelimit' attribute, whether it was
        # passed by keyword or as the second positional argument.
        if "args" in kwargs:
            namespace = kwargs["args"]
        elif len(args) > 1:
            namespace = args[1]
        else:
            namespace = None
        # 'or 0.0' turns a None attribute into "no limit" instead of a TypeError.
        rate_limit = getattr(namespace, "ratelimit", None) or 0.0

        if last_call_time is not None:
            time_since_last_call = time.monotonic() - last_call_time
            if time_since_last_call < rate_limit:
                sleep_time = rate_limit - time_since_last_call
                logger.debug(f"Rate limiting: sleeping for {sleep_time:.3f} seconds")
                time.sleep(sleep_time)

        # Record the time just before the call so the interval is measured
        # from call start to call start.
        last_call_time = time.monotonic()
        return func(*args, **kwargs)

    return wrapper


@dataclass
class RTHeader:
    """RTHeader represents the header of a VBAN packet for sending text messages.
    It includes fields for the stream name, bits per second, channel, and frame counter,
    as well as methods to convert these fields into the appropriate byte format for transmission.

    The serialised header is always 28 bytes: 4 (magic) + 4 (one byte each for
    sr, nbs, nbc, bit) + 16 (stream name) + 4 (frame counter).

    Attributes:
        name (str): The name of the VBAN stream (at most 16 bytes once UTF-8 encoded).
        bps (int): The bits per second for the VBAN stream, must be one of the predefined options.
        channel (int): The channel number for the VBAN stream (must fit in one unsigned byte).
        framecounter (int): A counter for the frames being sent, default is 0.
        BPS_OPTS (list[int]): A list of valid bits per second options for VBAN streams.
        bps_index (int): Position of 'bps' within BPS_OPTS, set by __post_init__.

    Methods:
        __post_init__(): Validates the encoded stream name length and bits per second value.
        vban(): Returns the VBAN magic bytes.
        sr(): Returns the sample rate byte: 0x40 plus the bits per second index.
        nbs(): Returns the number of bits per sample byte (currently set to 0).
        nbc(): Returns the number of channels byte based on the channel attribute.
        bit(): Returns the bit depth byte (currently set to 0x10).
        streamname(): Returns the stream name as bytes, zero-padded to exactly 16 bytes.
        to_bytes(name, bps, channel, framecounter): Class method to create a byte representation of the RTHeader with the given parameters.

    Raises:
        ValueError: If the encoded stream name exceeds 16 bytes or if the bits per second value is not in the predefined options.
    """

    name: str
    bps: int
    channel: int
    framecounter: int = 0
    # fmt: off
    BPS_OPTS: list[int] = field(default_factory=lambda: [
        0, 110, 150, 300, 600, 1200, 2400, 4800, 9600, 14400, 19200, 31250,
        38400, 57600, 115200, 128000, 230400, 250000, 256000, 460800, 921600,
        1000000, 1500000, 2000000, 3000000
    ])
    # fmt: on

    def __post_init__(self):
        # The header reserves 16 *bytes* for the name, so the limit must be
        # checked on the encoded form: a non-ASCII character occupies more
        # than one byte and would otherwise overflow the fixed-width field.
        if len(self.name.encode()) > 16:
            raise ValueError(
                f"Stream name got: '{self.name}', want: must be 16 bytes or fewer when UTF-8 encoded"
            )
        try:
            self.bps_index = self.BPS_OPTS.index(self.bps)
        except ValueError as e:
            # Re-raise the original ValueError, annotated with the valid options.
            ERR_MSG = f"Invalid bps: {self.bps}, must be one of {self.BPS_OPTS}"
            e.add_note(ERR_MSG)
            raise

    @property
    def vban(self) -> bytes:
        """Magic bytes that open every header."""
        return b"VBAN"

    @property
    def sr(self) -> bytes:
        """Single byte: 0x40 plus the index of 'bps' in BPS_OPTS."""
        # NOTE(review): 0x40 is presumably the VBAN text sub-protocol selector
        # - confirm against the VBAN specification.
        return (0x40 + self.bps_index).to_bytes(1, "little")

    @property
    def nbs(self) -> bytes:
        """Single zero byte."""
        return (0).to_bytes(1, "little")

    @property
    def nbc(self) -> bytes:
        """Single byte holding 'channel'; raises OverflowError outside 0-255."""
        return (self.channel).to_bytes(1, "little")

    @property
    def bit(self) -> bytes:
        """Single byte with the constant value 0x10."""
        # NOTE(review): 0x10 presumably marks the payload as UTF-8 text -
        # confirm against the VBAN specification.
        return (0x10).to_bytes(1, "little")

    @property
    def streamname(self) -> bytes:
        """Encoded stream name, right-padded with zero bytes to 16 bytes."""
        # Pad on the encoded length so the field is exactly 16 bytes even when
        # the name contains multi-byte characters.
        return self.name.encode().ljust(16, b"\x00")

    @classmethod
    def to_bytes(cls, name: str, bps: int, channel: int, framecounter: int) -> bytes:
        """Build and serialise a header in one step.

        Args:
            name: Stream name (at most 16 bytes once encoded).
            bps: Bits per second; must be one of BPS_OPTS.
            channel: Channel number, written as one unsigned byte.
            framecounter: Frame counter, written as 4 little-endian bytes.

        Returns:
            The 28-byte header.

        Raises:
            ValueError: If the name is too long or 'bps' is not a valid option.
        """
        header = cls(name=name, bps=bps, channel=channel, framecounter=framecounter)

        data = bytearray()
        data.extend(header.vban)
        data.extend(header.sr)
        data.extend(header.nbs)
        data.extend(header.nbc)
        data.extend(header.bit)
        data.extend(header.streamname)
        data.extend(header.framecounter.to_bytes(4, "little"))
        return bytes(data)


@ratelimit
def send(
    sock: socket.socket, args: argparse.Namespace, cmd: str, framecounter: int
) -> None:
    """
    Build one VBAN text packet for 'cmd' and transmit it over 'sock'.

    The packet is the RTHeader bytes followed by the UTF-8 encoded command.
    Calls are spaced out by the 'ratelimit' decorator.

    Args:
        sock (socket.socket): Socket used to transmit the datagram.
        args (argparse.Namespace): Supplies 'streamname', 'bps' and 'channel' for the header, and 'host' and 'port' for the destination.
        cmd (str): The text command that forms the packet payload.
        framecounter (int): Frame counter value written into the header.

    Returns:
        None
    """

    header = RTHeader.to_bytes(
        args.streamname, args.bps, args.channel, framecounter
    )
    payload = cmd.encode("utf-8")
    datagram = header + payload

    logger.debug("Sending packet: %s", datagram)
    destination = (args.host, args.port)
    sock.sendto(datagram, destination)


def parse_args() -> argparse.Namespace:
    """
    Parse command-line arguments.
    Returns:
        argparse.Namespace: Parsed command-line arguments.
    Command-line arguments:
        --host, -H: VBAN host to send to (default: localhost)
        --port, -P: VBAN port to send to (default: 6980)
        --streamname, -s: VBAN stream name (default: Command1)
        --bps, -b: Bits per second for VBAN stream (default: 256000)
        --channel, -c: Channel number for VBAN stream (default: 1)
        text: Text to send (positional argument)
    """

    parser = argparse.ArgumentParser(description="Voicemeeter VBAN Send Text CLI")
    parser.add_argument(
        "--host",
        "-H",
        type=str,
        default="localhost",
        help="VBAN host to send to (default: localhost)",
    )
    parser.add_argument(
        "--port",
        "-P",
        type=int,
        default=6980,
        help="VBAN port to send to (default: 6980)",
    )
    parser.add_argument(
        "--streamname",
        "-s",
        type=str,
        default="Command1",
        help="VBAN stream name (default: Command1)",
    )
    parser.add_argument(
        "--bps",
        "-b",
        type=int,
        default=256000,
        help="Bits per second for VBAN stream (default: 256000)",
    )
    parser.add_argument(
        "--channel",
        "-c",
        type=int,
        default=1,
        help="Channel number for VBAN stream (default: 1)",
    )
    parser.add_argument(
        "--ratelimit",
        "-r",
        type=float,
        default=0.2,
        help="Minimum time in seconds between sending messages (default: 0.2)",
    )
    parser.add_argument(
        "--loglevel",
        "-l",
        type=str,
        default="info",
        choices=["debug", "info", "warning", "error", "critical"],
        help="Set the logging level (default: info)",
    )
    parser.add_argument(
        "cmds", nargs="+", type=str, help="Text to send (positional argument)"
    )
    return parser.parse_args()


def main(args: argparse.Namespace):
    """
    Send every command in 'args.cmds' as a VBAN text packet.
    Args:
        args (argparse.Namespace): Parsed command-line arguments.
    Behavior:
        Opens a single UDP socket for the whole run and passes each command to 'send', using the command's position in the list (starting at 0) as its frame counter. The socket is closed when the loop ends.
    """

    with socket.socket(
        family=socket.AF_INET, type=socket.SOCK_DGRAM
    ) as udp_sock:
        for index, cmd in enumerate(args.cmds):
            logger.debug(f"Sending: {cmd}")
            send(udp_sock, args, cmd, index)


if __name__ == "__main__":
    # Script entry point: parse the CLI, configure logging, then send.
    args = parse_args()

    # Map the validated level name (e.g. "info") to the logging constant.
    log_level = getattr(logging, args.loglevel.upper())
    logging.basicConfig(level=log_level)

    main(args)