#!/usr/bin/python3
# -*- coding: utf-8 -*-
'''
    saltbroker: A ZeroMQ Proxy (broker) for Salt Minions

    The main process spawns a process for each channel of Salt ZMQ transport:

    - PubChannelProxy process provides the PUB channel for the minions
    - RetChannelProxy process provides the RET channel for the minions

    Also acts as a supervisor for the child processes, respawning them if they die.

    :depends:   python-PyYAML
    :depends:   python-pyzmq

    Copyright (c) 2016 SUSE LINUX Products GmbH, Nuernberg, Germany.

    All modifications and additions to the file contributed by third parties
    remain the property of their copyright owners, unless otherwise agreed
    upon. The license for this file, and modifications and additions to the
    file, is the same license as for the pristine package itself (unless the
    license for the pristine package is not an Open Source License, in which
    case the license is the MIT License). An "Open Source License" is a
    license that conforms to the Open Source Definition (Version 1.9)
    published by the Open Source Initiative.

    Please submit bugfixes or comments via http://bugs.opensuse.org/
'''

# Import python libs
import logging
import logging.handlers
import multiprocessing
import os
import signal
import socket
import sys
import threading
import time
import traceback
import yaml

# Import RHN libs
from spacewalk.common.rhnConfig import RHNOptions

# Import pyzmq lib
import zmq

from zmq.utils.monitor import recv_monitor_message


# Path of the RHN configuration file. Not referenced by name in the code
# visible here; the main block reads the "proxy" component via RHNOptions.
RHN_CONF_FILE = "/etc/rhn/rhn.conf"
# YAML file with the broker options; its entries override the defaults.
SALT_BROKER_CONF_FILE = "/etc/salt/broker"
# Target of the rotating file handler used when "log_to_file" is enabled.
SALT_BROKER_LOGFILE = "/var/log/salt/broker"
# Seconds the supervisor loop sleeps between two liveness checks.
SUPERVISOR_TIMEOUT = 5

# Module logger. It lets every record from DEBUG upwards through; the
# handlers are attached later by the main block.
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)


class AbstractChannelProxy(multiprocessing.Process):
    """
    Abstract base class for the channel proxy processes.

    A proxy connects a "backend" socket to the Salt master, binds a
    "frontend" socket for the minions and forwards everything between the
    two with ``zmq.proxy``.

    Subclasses must define the attributes read by ``run()``:
    ``_backend_sock_type``, ``_frontend_sock_type``, ``backend_type``,
    ``frontend_type``, ``_backend_uri`` and ``_frontend_uri``.

    ``_BACKEND_SOCKOPTS`` / ``_FRONTEND_SOCKOPTS`` are sequences of
    ``(zmq_option, source)`` pairs. ``source`` is either the name of a key
    in ``opts`` (the option is set only if that key is present) or a
    one-element tuple/list holding a literal value to set unconditionally.
    """

    class ChannelException(Exception):
        """
        Raised when a channel proxy cannot be configured or its sockets fail.
        """

        pass

    _BACKEND_SOCKOPTS = (
        (zmq.TCP_KEEPALIVE, "tcp_keepalive"),
        (zmq.TCP_KEEPALIVE_IDLE, "tcp_keepalive_idle"),
        (zmq.TCP_KEEPALIVE_CNT, "tcp_keepalive_cnt"),
        (zmq.TCP_KEEPALIVE_INTVL, "tcp_keepalive_intvl"),
        (zmq.CONNECT_TIMEOUT, "connect_timeout"),
        (zmq.RECONNECT_IVL, "reconnect_ivl"),
        (zmq.HEARTBEAT_IVL, "heartbeat_ivl"),
        (zmq.HEARTBEAT_TIMEOUT, "heartbeat_timeout"),
    )
    _FRONTEND_SOCKOPTS = ()

    def __init__(self, opts):
        """
        Store the options and resolve the master host name.

        :param opts: dict of broker options; must contain "master". The
            resolved address is written back as ``opts["master_ip"]``.
        :raises ChannelException: if "master" is missing or cannot be
            resolved.
        """
        self.opts = opts
        self.backend_connected = False
        self.backend_ever_connected = False
        # Defined up front so that the ZMQError handler in run() can always
        # read it, even when the failure happens very early.
        self.reconnect_retries = -1
        if "master" not in self.opts:
            raise self.ChannelException(
                '[{}] No "master" opts is provided'.format(
                    self.__class__.__name__
                )
            )
        try:
            self.opts["master_ip"] = socket.gethostbyname(self.opts["master"])
        except socket.gaierror as exc:
            raise self.ChannelException(
                "[{}] Error trying to resolve '{}': {}".format(
                    self.__class__.__name__, self.opts["master"], exc
                )
            ) from exc
        super().__init__()

    def run(self):
        """
        Process entry point: set up both sockets and forward messages.

        :raises ChannelException: on a ZMQ error, unless the backend was
            dropped on purpose after the configured number of reconnect
            retries (``reconnect_retries == 0``), in which case the method
            just returns.
        """
        drop_after_retries = self.opts.get("drop_after_retries", -1)
        wait_for_backend = self.opts.get("wait_for_backend", False)
        self.reconnect_retries = drop_after_retries

        try:
            context = zmq.Context()

            log.debug(
                "Setting up a %s sock on %s", self.backend_type, self._backend_uri
            )
            self.backend = context.socket(self._backend_sock_type)
            self.set_sockopts(self.backend, self._BACKEND_SOCKOPTS)

            # The monitor thread is the only writer of backend_connected, so
            # it is needed for "wait_for_backend" as well as for the retry
            # limit; otherwise the wait loop below could never finish.
            if drop_after_retries != -1 or wait_for_backend:
                self.monitor_socket = self.backend.get_monitor_socket()
                # Daemon thread: it blocks in poll() and must not keep the
                # process alive once run() has returned or raised.
                self.monitor_thread = threading.Thread(
                    target=self.backend_socket_monitor,
                    args=(self.monitor_socket,),
                    daemon=True,
                )
                self.monitor_thread.start()

            self.backend.connect(self._backend_uri)

            if wait_for_backend:
                # Do not accept minions before the master is reachable.
                while not self.backend_connected:
                    time.sleep(0.5)

            log.debug(
                "Setting up a %s sock on %s", self.frontend_type, self._frontend_uri
            )

            self.frontend = context.socket(self._frontend_sock_type)
            self.set_sockopts(self.frontend, self._FRONTEND_SOCKOPTS)

            self.frontend.bind(self._frontend_uri)

            # Forward all messages
            zmq.proxy(self.frontend, self.backend)

        except zmq.ZMQError as zmq_error:
            if self.reconnect_retries == 0:
                # Do not raise error if drop_after_retries was used
                return
            msg = "ZMQ Error: {}".format(zmq_error)
            log.error(msg)
            raise self.ChannelException(msg) from zmq_error

        except Exception as exc:
            # Best effort: log and let the process end; the supervisor
            # respawns it.
            log.error("Exception: %s", exc)
            log.debug("Traceback: %s", traceback.format_exc())

    def set_sockopts(self, socket, sockopts):
        """
        Apply a table of socket options to a ZMQ socket.

        :param socket: the ZMQ socket to configure.
        :param sockopts: iterable of ``(zmq_option, source)`` pairs. A
            string ``source`` names a key in ``self.opts`` and the option is
            skipped when the key is absent. A non-empty tuple/list
            ``source`` carries the literal value in its first element.
        """
        for opt, opt_src in sockopts:
            if isinstance(opt_src, (tuple, list)):
                # Literal value, independent of the configuration.
                if not opt_src:
                    continue
                value = opt_src[0]
            elif opt_src in self.opts:
                value = self.opts[opt_src]
            else:
                continue
            socket.setsockopt(opt, value)

    def backend_socket_monitor(self, monitor_socket):
        """
        Thread target: follow the connection state of the backend socket.

        Updates ``backend_connected`` / ``backend_ever_connected`` and, when
        "drop_after_retries" is not -1, closes the backend socket once that
        many consecutive reconnect attempts have failed.

        :param monitor_socket: monitor socket of the backend; it is closed
            when the loop ends.
        """
        drop_after_retries = self.opts.get("drop_after_retries", -1)
        wait_for_backend = self.opts.get("wait_for_backend", False)
        try:
            while monitor_socket.poll():
                event = recv_monitor_message(monitor_socket)["event"]
                if event == zmq.EVENT_MONITOR_STOPPED:
                    break
                if event == zmq.EVENT_DISCONNECTED:
                    log.warning("%s socket disconnected", self.backend_type)
                    self.backend_connected = False
                elif event == zmq.EVENT_CONNECTED:
                    log.info("%s socket connected", self.backend_type)
                    self.backend_connected = True
                    self.backend_ever_connected = True
                    # A successful connection resets the retry budget.
                    self.reconnect_retries = drop_after_retries
                elif event == zmq.EVENT_CONNECT_RETRIED:
                    if drop_after_retries == -1:
                        # Retry limit disabled: reconnect forever.
                        continue
                    if not self.backend_ever_connected and wait_for_backend:
                        # Still waiting for the very first connection;
                        # these retries do not count.
                        continue
                    if self.reconnect_retries == 0:
                        log.warning(
                            "Closing %s socket due to retry attempts reached!",
                            self.backend_type,
                        )
                        # NOTE(review): presumably this makes zmq.proxy() in
                        # run() fail with ZMQError, which run() treats as a
                        # clean exit while reconnect_retries is 0 - confirm.
                        self.backend.close()
                        break
                    self.reconnect_retries -= 1
        finally:
            monitor_socket.close()

    def terminate(self):
        """
        custom terminate function for the child process
        """
        log.info("Terminate called. Exiting")
        super().terminate()


class PubChannelProxy(AbstractChannelProxy):
    """
    Proxy process for the Salt PUB channel.

    The backend is an XSUB socket pointed at the publish port of the Salt
    master; the frontend is an XPUB socket on the local interface, using
    the same port number, for the minions to subscribe to.
    """

    # Prevent stopping publishing messages on XPUB socket. (bsc#1182954)
    # Each entry pairs a ZMQ option with a one-element tuple wrapping 1.
    _FRONTEND_SOCKOPTS = (
        (zmq.XPUB_VERBOSE, (1,)),
        (zmq.XPUB_VERBOSER, (1,)),
    )

    def __init__(self, opts):
        """
        :param opts: dict of broker options; "publish_port" and "interface"
            are read here, the rest is handled by the base class.
        """
        super().__init__(opts)
        self.name = "PubChannelProxy"

        # Socket types and the labels used for them in log messages.
        self._backend_sock_type, self.backend_type = zmq.XSUB, "XSUB"
        self._frontend_sock_type, self.frontend_type = zmq.XPUB, "XPUB"

        uri_template = "tcp://{host}:{port}"
        self._backend_uri = uri_template.format(
            host=self.opts["master_ip"], port=self.opts["publish_port"]
        )
        self._frontend_uri = uri_template.format(
            host=self.opts["interface"], port=self.opts["publish_port"]
        )


class RetChannelProxy(AbstractChannelProxy):
    """
    Proxy process for the Salt RET channel.

    The backend is a DEALER socket pointed at the return port of the Salt
    master; the frontend is a ROUTER socket on the local interface, using
    the same port number, that receives the messages sent by the minions.
    """

    def __init__(self, opts):
        """
        :param opts: dict of broker options; "ret_port" and "interface" are
            read here, the rest is handled by the base class.
        """
        super().__init__(opts)
        self.name = "RetChannelProxy"

        # Socket types and the labels used for them in log messages.
        self._backend_sock_type, self.backend_type = zmq.DEALER, "DEALER"
        self._frontend_sock_type, self.frontend_type = zmq.ROUTER, "ROUTER"

        uri_template = "tcp://{host}:{port}"
        self._backend_uri = uri_template.format(
            host=self.opts["master_ip"], port=self.opts["ret_port"]
        )
        self._frontend_uri = uri_template.format(
            host=self.opts["interface"], port=self.opts["ret_port"]
        )


class SaltBroker(object):
    """
    Creates a SaltBroker that forwards messages and responses between the
    minions and the Salt Master by running one ZeroMQ proxy process for each
    of the PUB/RET channels of the Salt ZMQ transport, and supervises those
    processes.
    """

    def __init__(self, opts):
        """
        :param opts: dict of broker options; handed over unchanged to every
            channel proxy that gets spawned.
        """
        log.debug("Readed config: %s", opts)
        self.opts = opts
        # Set by the SIGTERM handler to stop the supervisor loop.
        self.exit = False
        # Handler that was active before the broker installed its own.
        self.default_sigterm = signal.getsignal(signal.SIGTERM)
        self.pub_proxy_proc = None
        self.ret_proxy_proc = None
        super().__init__()

    def _spawn_proxy(self, proxy_cls, channel):
        """
        Create and start one channel proxy process.

        :param proxy_cls: proxy class to instantiate with ``opts``.
        :param channel: channel label used in the log message ("PUB"/"RET").
        :returns: the started process object.
        :raises AbstractChannelProxy.ChannelException: if the proxy cannot
            be constructed.
        """
        # setting up the default SIGTERM handler for the new process
        signal.signal(signal.SIGTERM, self.default_sigterm)
        try:
            proxy = proxy_cls(opts=self.opts)
            proxy.start()
        finally:
            # setting up again the custom SIGTERM handler, also when the
            # proxy could not be created
            signal.signal(signal.SIGTERM, self.sigterm_clean)

        log.info("Spawning %s channel proxy process [PID: %s]", channel, proxy.pid)

        return proxy

    def _start_pub_proxy(self):
        """
        Spawn a new PubChannelProxy process
        """
        return self._spawn_proxy(PubChannelProxy, "PUB")

    def _start_ret_proxy(self):
        """
        Spawn a new RetChannelProxy process
        """
        return self._spawn_proxy(RetChannelProxy, "RET")

    def _respawn_if_dead(self, proc, channel, starter):
        """
        Return a live process for a channel, respawning it when it has died.

        :param proc: the process currently registered for the channel.
        :param channel: channel label used in log messages.
        :param starter: callable that spawns a replacement process.
        :returns: ``proc`` if it is alive, if shutdown was requested, or if
            the respawn failed (it is retried on the next supervisor cycle);
            otherwise the new process.
        """
        if self.exit or proc.is_alive():
            return proc
        log.error("%s channel proxy has died. Respawning", channel)
        try:
            return starter()
        except AbstractChannelProxy.ChannelException as exc:
            # e.g. the master name cannot be resolved right now; keep the
            # supervisor running instead of crashing it.
            log.error("Exception: %s", exc)
            return proc

    def sigterm_clean(self, signum, frame):
        """
        Custom SIGTERM handler: terminate both proxies and stop supervising.
        """
        log.info("Caught signal %s, stopping all channels", signum)

        if self.pub_proxy_proc:
            self.pub_proxy_proc.terminate()
        if self.ret_proxy_proc:
            self.ret_proxy_proc.terminate()

        self.exit = True
        log.info("Terminating main process")

    def start(self):
        """
        Starts a SaltBroker. It spawns the PubChannelProxy and
        RetChannelProxy processes and then acts as a supervisor of these
        child processes, respawning them if they die, until SIGTERM is
        received.

        Exits the program if the proxies cannot be created at startup.
        """
        log.info("Starting Salt ZeroMQ Proxy [PID: %s]", os.getpid())

        # Attach a handler for SIGTERM signal
        signal.signal(signal.SIGTERM, self.sigterm_clean)

        try:
            self.pub_proxy_proc = self._start_pub_proxy()
            self.ret_proxy_proc = self._start_ret_proxy()
        except AbstractChannelProxy.ChannelException as exc:
            log.error("Exception: %s", exc)
            log.error("Exiting")
            # A proxy that did start must be stopped, otherwise the still
            # running child keeps the interpreter from exiting.
            for proc in (self.pub_proxy_proc, self.ret_proxy_proc):
                if proc:
                    proc.terminate()
            sys.exit(exc)

        # Supervisor. Restart a channel if died
        while not self.exit:
            self.pub_proxy_proc = self._respawn_if_dead(
                self.pub_proxy_proc, "PUB", self._start_pub_proxy
            )
            self.ret_proxy_proc = self._respawn_if_dead(
                self.ret_proxy_proc, "RET", self._start_ret_proxy
            )
            time.sleep(SUPERVISOR_TIMEOUT)

# Default broker options; entries of the config file take precedence.
_DEFAULT_OPTS = {
    "publish_port": "4505",
    "ret_port": "4506",
    "interface": "0.0.0.0",
    "tcp_keepalive": True,
    "tcp_keepalive_idle": 300,
    "tcp_keepalive_cnt": -1,
    "tcp_keepalive_intvl": -1,
    "log_to_file": 1,
    "connect_timeout": 0,
    "reconnect_ivl": 100,
    "heartbeat_ivl": 0,
    "heartbeat_timeout": 0,
    "drop_after_retries": -1,
    "wait_for_backend": False,
}


def _load_broker_config():
    """
    Read the broker config file and return its content as a dict.

    An empty file yields an empty dict. Exits the program with a message if
    the file is missing, is not valid YAML or does not contain a mapping.
    """
    if not os.path.isfile(SALT_BROKER_CONF_FILE):
        sys.exit("Config file not found: {0}".format(SALT_BROKER_CONF_FILE))

    try:
        # "with" closes the file handle even when parsing fails.
        with open(SALT_BROKER_CONF_FILE) as conf_file:
            config = yaml.load(conf_file, Loader=yaml.SafeLoader)
    except yaml.YAMLError as exc:
        # YAMLError is the common base of scanner and parser errors.
        sys.exit("Error reading YAML config file: {0}".format(exc))

    if not config:
        return {}
    if not isinstance(config, dict):
        sys.exit("Bad format in config file: {0}".format(SALT_BROKER_CONF_FILE))
    return config


def _setup_logging(opts):
    """
    Attach the handlers to the module logger.

    With a true "log_to_file" option records go to a rotating log file;
    otherwise records below ERROR go to stdout and the others to stderr.
    """
    formatter = logging.Formatter(
        "%(asctime)s [%(levelname)-8s][%(processName)-16s][%(process)s] %(message)s",
    )
    if opts.get("log_to_file"):
        fileloghandler = logging.handlers.RotatingFileHandler(
            SALT_BROKER_LOGFILE, maxBytes=200000, backupCount=5
        )
        fileloghandler.setFormatter(formatter)
        log.addHandler(fileloghandler)
        return

    # prepare two log handlers, 1 for stdout and 1 for stderr
    stdout_handler = logging.StreamHandler(sys.stdout)
    stderr_handler = logging.StreamHandler(sys.stderr)
    # stdout handler drops everything at ERROR level and above
    stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
    # stderr handler only takes ERROR level and above
    stderr_handler.setLevel(logging.ERROR)
    # same format for both handlers
    stdout_handler.setFormatter(formatter)
    stderr_handler.setFormatter(formatter)
    log.addHandler(stdout_handler)
    log.addHandler(stderr_handler)


def main():
    """
    Build the broker options, configure logging and run the broker.

    Precedence of the options, lowest first: built-in defaults, "master"
    taken from 'rhn_parent' of the RHN proxy configuration, config file.
    """
    # Try to get config from /etc/rhn/rhn.conf
    rhn_parent = None
    rhn_proxy_conf = RHNOptions(component="proxy")
    rhn_proxy_conf.parse()
    if rhn_proxy_conf.get("rhn_parent"):
        rhn_parent = rhn_proxy_conf["rhn_parent"]

    config = _load_broker_config()

    saltbroker_opts = _DEFAULT_OPTS.copy()
    if rhn_parent:
        saltbroker_opts.update({"master": rhn_parent})
    saltbroker_opts.update(config)

    _setup_logging(saltbroker_opts)

    # Logged only now that handlers exist, and only when the config file
    # does not override "master", so the message is both visible and true.
    if rhn_parent and "master" not in config:
        log.debug("Using 'rhn_parent' from /etc/rhn/rhn.conf as 'master'")

    proxy = SaltBroker(opts=saltbroker_opts)
    proxy.start()


if __name__ == "__main__":
    main()
