#!/usr/bin/python3

from datetime import datetime
import argparse
import logging
import subprocess
import uvicorn
from dbus import SystemBus, Interface, DBusException
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel


# Send all log records (DEBUG and above) to a file. This runs at import time,
# before the __main__ block below parses any arguments.
# NOTE(review): the path is relative, so the log file ends up in whatever working
# directory the server process is started from.
logging.basicConfig(filename = "mcp-server-snapper.log", level = logging.DEBUG)


class Snapshot(BaseModel):
    """
    A snapper snapshot as returned by the list_snapshots tool.
    """
    # Snapshot type: "single", "pre" or "post".
    type: str
    # Snapshot number.
    number: int
    # Number of the corresponding pre snapshot; only set for "post" snapshots.
    pre_number: int | None = None
    # Snapshot date formatted as "YYYY-MM-DD HH:MM:SS" in local time, or None when
    # snapper reports no date.
    date: str | None = None
    # Description of the snapshot.
    description: str
    # Cleanup algorithm of the snapshot, e.g. 'number' or 'timeline'.
    cleanup: str
    # User-defined key-value pairs attached to the snapshot.
    userdata: dict[str, str]


# The MCP server instance. Every function below decorated with @mcp.tool() is
# exposed through it as a tool.
mcp = FastMCP("SnapperServer")


# TODO: Especially with the HTTP transport, the D-Bus connection to snapperd
# stays open forever.


def get_snapper():
    """
    Connect to snapperd over the D-Bus system bus.
    :returns: Proxy for the 'org.opensuse.Snapper' interface of the
              '/org/opensuse/Snapper' object.
    """
    # TODO reuse - seems to work already
    remote_object = SystemBus().get_object('org.opensuse.Snapper', '/org/opensuse/Snapper')
    return Interface(remote_object, dbus_interface = 'org.opensuse.Snapper')


@mcp.tool()
def list_configs() -> dict[str, str]:
    """
    Return the available snapper configs.
    :returns: Available snapper configs as a dictionary of key-value pairs with the config
              name as the key and the subvolume path as the value.
    :rtype: dict[str, str]
    """

    try:
        # Only the first two fields of each entry are used: config name and subvolume.
        result: dict[str, str] = {
            entry[0]: entry[1] for entry in get_snapper().ListConfigs()
        }
    except DBusException as e:
        logging.error(f"Snapper error: {e}")
        raise Exception("snapper error")

    logging.info(f"list of snapper configs: {result}")
    return result


@mcp.tool()
def get_config(config: str) -> dict[str, str]:
    """
    Return the config values of a snapper config.
    :param config: Snapper config to use. Often 'root'. Use the list_configs tool to
           query all values.
    :returns: Config values of a snapper config as a dictionary of key-value pairs.
    :rtype: dict[str, str]
    """

    try:
        reply = get_snapper().GetConfig(config)
    except DBusException as e:
        logging.error(f"Snapper error: {e}")
        raise Exception("snapper error")

    # The third element of the GetConfig reply is the key/value map; D-Bus string
    # types are converted to plain str.
    result: dict[str, str] = {str(key): str(value) for key, value in reply[2].items()}

    logging.info(f"snapper config: {result}")
    return result


@mcp.tool()
def set_config(config: str, values: dict[str, str]) -> None:
    """
    Set configuration values of a snapper config.
    :param config: Snapper config to use. Often 'root'. Use the list_configs tool to
           query all values.
    :param values: Dictionary of key-value pairs to set. Use the get_config tool to
           query the current values.
    """

    try:
        snapper = get_snapper()

        snapper.SetConfig(config, values)

    except DBusException as e:
        logging.error(f"Snapper error: {e}")
        # Chain the D-Bus error so the original cause stays visible in tracebacks.
        raise Exception("snapper error") from e

    logging.info(f"snapper config {config} updated: {values}")


@mcp.tool()
def list_snapshots(config: str) -> list[Snapshot]:
    """
    List file system snapshots using snapper.
    :param config: Snapper config to use. Often 'root'. Use the list_configs tool to
           query all values.
    :returns: Snapshots.
    :rtype: list[Snapshot]
    """

    # Numeric snapshot types reported by snapperd and their names; any other value
    # ends up as None.
    type_names = { 0: "single", 1: "pre", 2: "post" }

    try:
        raw_snapshots = get_snapper().ListSnapshots(config)

        result: list[Snapshot] = []

        for entry in raw_snapshots:
            kind = type_names.get(entry[1])

            # Only post snapshots carry a reference to their pre snapshot.
            pre_number = entry[2] if kind == "post" else None

            # A timestamp of -1 is reported as "no date".
            date = None
            if entry[3] != -1:
                date = datetime.fromtimestamp(entry[3]).strftime("%Y-%m-%d %H:%M:%S")

            result.append(Snapshot(type = kind, number = entry[0], pre_number = pre_number,
                                   date = date, description = entry[5], cleanup = entry[6],
                                   userdata = entry[7]))

        logging.info(f"list of snapper snapshots: {result}")

        return result

    except DBusException as e:
        logging.error(f"Snapper error: {e}")
        raise Exception("snapper error")


@mcp.tool()
def create_snapshot(config: str, type: str, pre_number: int, description: str, cleanup: str,
                    userdata: dict[str, str]) -> int:
    """
    Create a file system snapshot using snapper.
    :param config: Snapper config to use. Often 'root'. Use the list_configs tool to
           query all values.
    :param type: Type for the snapshot, either 'single', 'pre' or 'post'.
    :param pre_number: Number of the corresponding pre snapshot. Required if type is 'post',
           otherwise ignored.
    :param description: Description for the snapshot.
    :param cleanup: Cleanup algorithm for the snapshot like 'number' or 'timeline'.
    :param userdata: List of key-value pairs.
    :returns: Number of the created snapshot.
    :rtype: int
    """

    try:
        snapper = get_snapper()

        if type == "single":
            number = snapper.CreateSingleSnapshot(config, description, cleanup, userdata)
        elif type == "pre":
            number = snapper.CreatePreSnapshot(config, description, cleanup, userdata)
        elif type == "post":
            number = snapper.CreatePostSnapshot(config, pre_number, description, cleanup, userdata)
        else:
            logging.error(f"Invalid snapshot type: {type}")
            raise Exception("invalid snapshot type")

        logging.info(f"snapper number of created snapshot: {number}")

        return number

    except DBusException as e:
        logging.error(f"Snapper error: {e}")
        raise Exception("snapper error")


@mcp.tool()
def delete_snapshots(config: str, numbers: list[int]) -> None:
    """
    Delete one or more file system snapshots using snapper.
    :param config: Snapper config to use. Often 'root'. Use the list_configs tool to
           query all values.
    :param numbers: Numbers of the snapshots to delete.
    """

    try:
        get_snapper().DeleteSnapshots(config, numbers)
    except DBusException as e:
        logging.error(f"Snapper error: {e}")
        raise Exception("snapper error")


def run_and_log_result(cmd_args):
    """
    Run a command (without a shell) and log its captured output.
    :param cmd_args: Program followed by its arguments.
    :returns: The command's exit code, or 127 if the program was not found.
    :rtype: int
    """

    logging.info(cmd_args)

    try:
        completed = subprocess.run(cmd_args, capture_output = True, text = True,
                                   check = False)
    except FileNotFoundError:
        logging.error(f"Command not found: {cmd_args[0]}")
        # 127 mirrors the shell's "command not found" exit status.
        return 127

    out_text = completed.stdout.strip()
    if out_text:
        logging.info(f"STDOUT:\n{out_text}")

    err_text = completed.stderr.strip()
    if err_text:
        logging.error(f"STDERR:\n{err_text}")

    return completed.returncode


@mcp.tool()
def rollback(config: str, number: int | None, description: str, cleanup: str,
             userdata: dict[str, str]) -> None:
    """
    Roll back to a snapshot.
    :param config: Snapper config to use. Often 'root'. Use the list_configs tool to
           query all values.
    :param number: Optionally the number of the snapshot to rollback to.
    :param description: Description for the new snapshot.
    :param cleanup: Cleanup algorithm for the new snapshot like 'number' or 'timeline'.
    :param userdata: List of key-value pairs for the new snapshot.
    """

    # Unlike the other tools this one runs the snapper command line tool rather than
    # talking to snapperd over D-Bus, so failures are detected via the exit code.
    cmd_args = [ "/usr/bin/snapper", "--config", config, "rollback" ]

    if description:
        cmd_args.extend([ "--description", description ])

    if cleanup:
        # Use the full documented option name instead of relying on getopt prefix
        # matching of "--cleanup".
        cmd_args.extend([ "--cleanup-algorithm", cleanup ])

    if userdata:
        # snapper takes userdata as one "key1=value1,key2=value2" argument.
        # NOTE(review): keys/values containing ',' or '=' cannot be represented.
        cmd_args.extend([ "--userdata",
                          ",".join(f"{key}={value}" for key, value in userdata.items()) ])

    # Both None and 0 leave the snapshot number off the command line.
    if number:
        cmd_args.append(str(number))

    exit_code = run_and_log_result(cmd_args)

    if exit_code != 0:
        logging.error(f"Snapper error: {exit_code}")
        raise Exception("snapper error")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description = "Run the snapper MCP server.")

    parser.add_argument("--transport", choices = [ "stdio", "http", "https" ], default = "stdio",
                        help = "Transport type (default: stdio)")

    # NOTE: the tools can change configs, delete snapshots and roll back, and nothing in
    # this file adds authentication. Consider binding to 127.0.0.1 unless remote access
    # is really wanted.
    parser.add_argument("--host", type = str, default = "0.0.0.0",
                        help = "Address to bind for HTTP/HTTPS (default: 0.0.0.0)")

    parser.add_argument("--port", type = int, default = 8000, help = "Port for HTTP/HTTPS")

    parser.add_argument("--key", type = str, help = "Key for HTTPS")
    parser.add_argument("--cert", type = str, help = "Cert for HTTPS")

    args = parser.parse_args()

    if args.transport == "https" and (not args.key or not args.cert):
        # parser.error() prints the message and exits.
        parser.error("--transport https requires both --key and --cert")

    logging.info("Server started")

    if args.transport == "stdio":
        mcp.run(transport = "stdio")
    else:
        # http and https serve the same SSE app; https only adds the TLS files.
        tls_options = {}
        if args.transport == "https":
            tls_options = { "ssl_keyfile": args.key, "ssl_certfile": args.cert }

        uvicorn.run(mcp.sse_app(), host = args.host, port = args.port, **tls_options)
