#!/usr/bin/env python3
import os
import stat
import logging
from sys import argv
from http.client import HTTPConnection
from urllib.request import AbstractHTTPHandler, HTTPHandler, HTTPSHandler, OpenerDirector
import argparse
import json
import socket
from functools import lru_cache
import re

# Root logger (getLogger() is called without a name); used for debug output of response bodies.
logger = logging.getLogger()
# Project metadata
__author__ = 'Tim Laurence'
__copyright__ = "Copyright 2017"
__credits__ = ['Tim Laurence']
__license__ = "GPL"
__version__ = "2.0.0"

'''
nrpe compatible check for docker swarm

Requires Python 3

Note: I really would have preferred to have used requests for all the network connections but that would have added a
dependency.
'''

# Connection defaults
DEFAULT_SOCKET = '/var/run/docker.sock'
DEFAULT_TIMEOUT = 10.0  # seconds
DEFAULT_PORT = 2375
# Headers attached to every request sent through the urllib openers
DEFAULT_HEADERS = [('Accept', 'application/vnd.docker.distribution.manifest.v2+json')]

# Plugin return codes
OK_RC = 0
WARNING_RC = 1
CRITICAL_RC = 2
UNKNOWN_RC = 3

# Every 2xx status counts as success. range() excludes its stop value, so the stop has to be
# 300 for 299 to be included.
HTTP_GOOD_CODES = range(200, 300)

# These hold the final results
rc = -1  # -1 means no check has reported a result yet
messages = []


# Hacked up urllib to handle sockets
#############################################################################################
# Docker serves its HTTP API over a unix socket file. http.client knows how to speak HTTP
# but lacks some niceties. urllib wraps it and makes up for some of those deficiencies, but
# cannot fix the fact that http.client can't connect to socket files. To take advantage of
# both urllib's and http.client's capabilities, the class below tweaks HTTPConnection and
# registers it with urllib as the handler for socket:// connections.


class SocketFileHandler(AbstractHTTPHandler):
    """urllib handler that serves socket:// URLs by talking HTTP over a unix socket file."""

    class SocketFileToHttpConnectionAdaptor(HTTPConnection):
        """HTTPConnection variant whose connect() opens a unix socket file instead of host:port."""

        def __init__(self, socket_file, timeout=DEFAULT_TIMEOUT):
            # Host and port are placeholders; connect() below only uses socket_file.
            super().__init__(host='', port=0, timeout=timeout)
            self.socket_file = socket_file

        def connect(self):
            unix_socket = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0, fileno=None)
            self.sock = unix_socket
            unix_socket.settimeout(self.timeout)
            unix_socket.connect(self.socket_file)

    def socket_open(self, req):
        # The selector has the form "<socket file>:<request path>"; split on the first colon.
        pieces = req.selector.split(':', 1)
        socket_file, path = pieces
        req.host = socket_file
        req.selector = path
        adaptor_class = self.SocketFileToHttpConnectionAdaptor
        return self.do_open(adaptor_class, req)


# Opener used for the GET requests issued by get_url(). It starts from a bare OpenerDirector,
# so only the three handlers registered here (http, https and socket file) are installed.
better_urllib_get = OpenerDirector()
better_urllib_get.addheaders = DEFAULT_HEADERS.copy()
better_urllib_get.add_handler(HTTPHandler())
better_urllib_get.add_handler(HTTPSHandler())
better_urllib_get.add_handler(SocketFileHandler())

# Second opener with the same headers and handlers, tagged with a 'HEAD' method attribute.
# NOTE(review): nothing in the visible part of this file uses better_urllib_head, and it is
# not evident that OpenerDirector reads a `method` attribute -- confirm before relying on
# this opener to send HEAD requests.
better_urllib_head = OpenerDirector()
better_urllib_head.method = 'HEAD'
better_urllib_head.addheaders = DEFAULT_HEADERS.copy()
better_urllib_head.add_handler(HTTPHandler())
better_urllib_head.add_handler(HTTPSHandler())
better_urllib_head.add_handler(SocketFileHandler())


# Util functions
#############################################################################################


@lru_cache()
def get_url(url):
    """Fetch url with the GET opener and return (decoded JSON body, HTTP status).

    Results are memoised per url by lru_cache. The module-level timeout is applied
    to the request.
    """
    reply = better_urllib_get.open(url, timeout=timeout)
    parsed_body = process_urllib_response(reply)
    return parsed_body, reply.status


def process_urllib_response(response):
    """Read the whole response, decode it as UTF-8, log it at debug level and parse it as JSON."""
    body = response.read().decode('utf-8')
    logger.debug(body)
    parsed = json.loads(body)
    return parsed


def get_swarm_status():
    """Return the HTTP status code the daemon answers with for its /swarm endpoint."""
    swarm_reply = get_url(daemon + '/swarm')
    _body, http_status = swarm_reply
    return http_status


def get_service_info(name):
    """Return the (parsed body, HTTP status) pair for the named service's endpoint."""
    service_path = '/services/{service}'.format(service=name)
    return get_url(daemon + service_path)


def get_services(names):
    """Resolve service name patterns to the names of services the daemon reports.

    names is a collection of regular expressions; if it contains the literal 'all',
    every service name is returned unfiltered. Each pattern has to match a service
    name in full. A pattern that matches nothing records a CRITICAL result.

    Returns a set of service names. When the service list cannot be retrieved, a
    CRITICAL (status 406) or UNKNOWN (any other non-2xx status) result is recorded
    and an empty list is returned instead.
    """
    services_list, status = get_url(daemon + '/services')
    if status == 406:
        critical("Error checking service status, node is not in swarm mode")
        return []
    elif status not in HTTP_GOOD_CODES:
        unknown("Could not retrieve service info")
        return []

    all_services_names = {x['Spec']['Name'] for x in services_list}
    if 'all' in names:
        return all_services_names

    filtered = set()
    for matcher in names:
        # fullmatch() anchors the pattern as a whole. Wrapping it in "^...$" anchored only the
        # outer alternatives ("web|db" became "^web|db$", which also matches "webserver") and
        # moved leading inline flags such as "(?i)" away from the start of the expression.
        matched = {candidate for candidate in all_services_names if re.fullmatch(matcher, candidate)}
        # If we don't find a service that matches our regex
        if not matched:
            critical("No services match {}".format(matcher))
        filtered |= matched

    return filtered


def set_rc(new_rc):
    """Raise the module-level return code to new_rc if that is higher; never lower it."""
    global rc
    if new_rc > rc:
        rc = new_rc


def ok(message):
    """Record an OK result: bump the return code to at least OK_RC and queue the message."""
    set_rc(OK_RC)
    line = 'OK: ' + message
    messages.append(line)


def warning(message):
    """Record a WARNING result: bump the return code to at least WARNING_RC and queue the message."""
    set_rc(WARNING_RC)
    line = 'WARNING: ' + message
    messages.append(line)


def critical(message):
    """Record a CRITICAL result: bump the return code to at least CRITICAL_RC and queue the message."""
    set_rc(CRITICAL_RC)
    line = 'CRITICAL: ' + message
    messages.append(line)


def unknown(message):
    """Record an UNKNOWN result: bump the return code to UNKNOWN_RC and queue the message."""
    set_rc(UNKNOWN_RC)
    line = 'UNKNOWN: ' + message
    messages.append(line)


# Checks
#############################################################################################
def check_swarm():
    """Query the swarm endpoint and record a result based on the HTTP status it returns."""
    swarm_messages = dict(ok_msg='Node is in a swarm',
                          critical_msg='Node is not in a swarm',
                          unknown_msg='Error accessing swarm info')
    process_url_status(get_swarm_status(), **swarm_messages)


def check_service(name):
    """Look up one service by name and record a result based on the HTTP status.

    A 2xx status records OK, the statuses process_url_status() treats as failures
    record CRITICAL, and anything else records UNKNOWN.
    """
    _, status = get_service_info(name)
    # unknown_msg must be supplied: without it an unexpected status would hand None to
    # unknown(), which fails while concatenating the message prefix.
    process_url_status(status,
                       ok_msg='Service {service} is up and running'.format(service=name),
                       critical_msg='Service {service} was not found on the swarm'.format(service=name),
                       unknown_msg='Error accessing info for service {service}'.format(service=name))


def process_url_status(status, ok_msg=None, critical_msg=None, unknown_msg=None):
    """Translate an HTTP status into a recorded result.

    Statuses in HTTP_GOOD_CODES report ok_msg as OK; 503, 404 and 406 report
    critical_msg as CRITICAL; every other status reports unknown_msg as UNKNOWN.
    """
    if status in HTTP_GOOD_CODES:
        reporter, text = ok, ok_msg
    elif status in (503, 404, 406):
        reporter, text = critical, critical_msg
    else:
        reporter, text = unknown, unknown_msg
    return reporter(text)


def process_args(args):
    """Parse the command line and derive the module-level connection settings.

    Sets the globals timeout, daemon and connection_type from the parsed options and
    returns the argparse namespace. The help text is printed first when args is empty.
    """
    parser = argparse.ArgumentParser(description='Check docker swarm.')

    # The daemon is reached through either a plain endpoint or a TLS one, never both
    endpoint_options = parser.add_mutually_exclusive_group()
    endpoint_options.add_argument(
        '--connection',
        dest='connection',
        action='store',
        default=DEFAULT_SOCKET,
        type=str,
        metavar='[/<path to>/docker.socket|<ip/host address>:<port>]',
        help='Where to find docker daemon socket. (default: %(default)s)')
    endpoint_options.add_argument(
        '--secure-connection',
        dest='secure_connection',
        action='store',
        type=str,
        metavar='[<ip/host address>:<port>]',
        help='Where to find TLS protected docker daemon socket.')

    # Connection timeout
    parser.add_argument(
        '--timeout',
        dest='timeout',
        action='store',
        type=float,
        default=DEFAULT_TIMEOUT,
        help='Connection timeout in seconds. (default: %(default)s)')

    # Exactly one kind of check has to be requested
    check_options = parser.add_mutually_exclusive_group(required=True)
    check_options.add_argument(
        '--swarm',
        dest='swarm',
        default=None,
        action='store_true',
        help='Check swarm status')
    check_options.add_argument(
        '--service',
        dest='service',
        action='store',
        type=str,
        nargs='+',
        default=[],
        help='One or more RegEx that match the names of the services(s) to check.')

    no_arguments_given = len(args) == 0
    if no_arguments_given:
        parser.print_help()

    parsed_args = parser.parse_args(args=args)

    global timeout, daemon, connection_type
    timeout = parsed_args.timeout

    if parsed_args.secure_connection:
        connection_type = 'https'
        daemon = 'https://' + parsed_args.secure_connection
    elif parsed_args.connection:
        # A leading slash marks a socket file path; anything else is taken as host:port
        if parsed_args.connection.startswith('/'):
            connection_type = 'socket'
            daemon = 'socket://' + parsed_args.connection + ':'
        else:
            connection_type = 'http'
            daemon = 'http://' + parsed_args.connection

    return parsed_args


def socketfile_permissions_failure(parsed_args):
    """Return True when a socket-file connection is configured but the file is unusable.

    The file is usable when it exists, is a socket, and is both readable and writable
    for the current user. Connection types other than 'socket' always return False.
    """
    if connection_type != 'socket':
        return False

    usable = (os.path.exists(parsed_args.connection)
              and stat.S_ISSOCK(os.stat(parsed_args.connection).st_mode)
              and os.access(parsed_args.connection, os.R_OK)
              and os.access(parsed_args.connection, os.W_OK))
    return not usable


def print_results():
    """Print every collected message on a single line, separated by '; '."""
    summary = '; '.join(messages)
    print(summary)


def perform_checks(raw_args):
    """Parse raw_args, run the requested check and print the collected results.

    Any exception raised while checking is reported as an UNKNOWN result instead of
    propagating. An UNKNOWN result is also recorded when no check produced any output,
    so that a result is always available once this function returns.
    """
    args = process_args(raw_args)
    if socketfile_permissions_failure(args):
        unknown("Cannot access docker socket file. User ID={}, socket file={}".format(os.getuid(), args.connection))
    else:
        # Here is where all the work happens
        #############################################################################################
        try:
            if args.swarm:
                check_swarm()
            elif args.service:
                # get_services() records a CRITICAL result itself for every pattern that
                # matches nothing, so an empty result needs no extra handling here.
                for service in get_services(args.service):
                    check_service(service)

        except Exception as e:
            unknown("Exception raised during check: {}".format(repr(e)))

    # '--service all' against a swarm without any services runs no check at all. Without this
    # guard the return code would stay at its initial -1 and the output would be empty.
    if not messages:
        unknown("No checks were performed")

    print_results()


if __name__ == '__main__':
    perform_checks(argv[1:])
    # The exit() builtin is installed by the site module for interactive sessions and is
    # missing under `python -S`; raising SystemExit sets the same exit status without it.
    raise SystemExit(rc)
