#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright 2011 OpenStack Foundation.
#
# (c) Copyright 2015-2016 Hewlett Packard Enterprise Development LP
# (c) Copyright 2017-2018 SUSE LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
DOCUMENTATION = '''
---
module: iptables
short_description: A module for adding iptables rules.

'''

import contextlib
import errno
import fcntl
import hashlib
import logging
import logging.handlers
import os
import signal
import socket
import time


# Module-level logger; init_logging() attaches the syslog handler to it.
LOG = logging.getLogger(__name__)

# Name of the root chain that main() plumbs into INPUT. It is also the marker
# used to find and purge previously generated rules.
ROOT_PREFIX = 'ardana-INPUT'
# Upper bound on the length of generated chain names
# (presumably the iptables chain-name limit -- TODO confirm).
MAX_CHAIN_NAME_LEN = 28
# Rule comments are truncated to this many characters.
MAX_COMMENT_LEN = 256
# Table name passed to `ip(6)tables-save -t`.
FILTER_TABLE_NAME = 'filter'
# Lines that open and close the filter table section in save/restore format.
FILTER_LINE = '*filter'
COMMIT_LINE = 'COMMIT'


def main():
    '''Ansible entry point.

       Two mutually exclusive modes are supported:
       - rules:   regenerate the complete set of 'ardana-INPUT' filter rules
                  for iptables and ip6tables, load them into the kernel under
                  the inter-process lock and then persist them to disk.
       - command: run one arbitrary ip(6)tables command under the same lock.

       NOTE: when neither 'rules' nor 'command' is supplied the function
       falls through without emitting a module result.
    '''
    module = AnsibleModule(
        argument_spec=dict(
            enable=dict(required=False,
                        choices=BOOLEANS+['True', True,
                                          'False', False],
                        default=True),
            logging=dict(required=False,
                         choices=BOOLEANS+['True', True,
                                           'False', False],
                         default=True),
            rules=dict(type='dict', required=False, default=None),
            ardana_chains=dict(type='list', default=[], required=False),
            command=dict(required=False, default=None),
            lock_path=dict(required=True),
            lock_name=dict(required=True),
            lock_timeout=dict(type='int', required=False, default=120),
            synchronized_prefix=dict(required=False, default=None),
            os_family=dict(required=True),
        ),
        supports_check_mode=False
    )
    init_logging()

    enable = module.boolean(module.params['enable'])
    # Named log_dropped (not 'logging') so the logging module is not shadowed.
    log_dropped = module.boolean(module.params['logging'])
    rules = module.params['rules']
    chains = module.params['ardana_chains']
    command = module.params['command']
    lock_path = module.params['lock_path']
    lock_name = module.params['lock_name']
    lock_timeout = module.params['lock_timeout']
    synchronized_prefix = module.params['synchronized_prefix']

    # Location of the persisted rule files differs per distribution family.
    if module.params['os_family'] == 'Debian':
        persist_file_ip4 = '/etc/iptables/rules.v4'
        persist_file_ip6 = '/etc/iptables/rules.v6'
    else:
        persist_file_ip4 = '/etc/sysconfig/iptables'
        persist_file_ip6 = '/etc/sysconfig/ip6tables'

    root_prefix = ROOT_PREFIX

    new_cmds = []
    new_cmds_ip6 = []

    if rules and command:
        # fail_json must receive the message as the 'msg' keyword.
        module.fail_json(msg='Cannot specify both rules and command')
    if rules:
        LOG.info("Installing iptables/ip6tables filter rules")
        # Pre-build the ip4 and ip6 rules we want to install. With
        # enable=False the lists stay empty, so the update below only purges
        # the previously installed rules.
        if enable:
            create_root_chains(chains, root_prefix, new_cmds, new_cmds_ip6)
            append_rules(module, rules, root_prefix, new_cmds, new_cmds_ip6)
            append_default_rules(chains, root_prefix, log_dropped, new_cmds,
                                 new_cmds_ip6)

        try:
            with lock(lock_name, synchronized_prefix, lock_path, lock_timeout):
                update_iptables_filters(module, root_prefix, new_cmds,
                                        'iptables')
                update_iptables_filters(module, root_prefix, new_cmds_ip6,
                                        'ip6tables')
            persist_iptables_filters(module, persist_file_ip4, new_cmds,
                                     root_prefix)
            persist_iptables_filters(module, persist_file_ip6, new_cmds_ip6,
                                     root_prefix)
        except Exception as e:
            LOG.error("Installing iptables/ip6tables filter rules: Failed")
            module.fail_json(msg='Exception: %s' % e)
        else:
            LOG.info("Installing iptables/ip6tables filter rules: Done")
            module.exit_json(**dict(changed=True, enabled=enable,
                                    firewall_rules=rules,
                                    generated_rules=new_cmds))

    elif command:
        # given a command string such as, '-A INPUT...'
        LOG.info("Executing iptables command '%s'", command)
        try:
            with lock(lock_name, synchronized_prefix, lock_path, lock_timeout):
                rc, stdout, stderr = run_iptables_cmd(module, command)
        except Exception as e:
            LOG.error("Failed to execute iptables command")
            module.fail_json(msg='Exception: %s' % e)
        else:
            changed = (rc == 0)
            module.exit_json(command=command,
                             changed=changed,
                             stderr=stderr.rstrip("\r\n"),
                             stdout=stdout.rstrip("\r\n"),
                             rc=rc)


def get_chain_name_hashed(prefix, chain_name):
    '''Return a chain name of the form '<prefix>-<md5 hex digits>'.

       The MD5 hex digest of chain_name is truncated so that the complete
       name does not exceed MAX_CHAIN_NAME_LEN characters. The same
       (prefix, chain_name) pair always maps to the same chain.
    '''
    # hashlib needs bytes: text is encoded explicitly so that Python 3 str
    # and non-ASCII Python 2 unicode values hash instead of raising.
    if not isinstance(chain_name, bytes):
        chain_name = chain_name.encode('utf-8')
    chain_name_hash = hashlib.md5(chain_name).hexdigest()
    postfix_length = MAX_CHAIN_NAME_LEN - len(prefix + '-')
    return prefix + '-' + chain_name_hash[:postfix_length]


def update_iptables_filters(module, root_prefix, cmds, cmd_prefix):
    '''Rebuild the live filter table for one address family.

       The active rules are fetched, every line mentioning root_prefix is
       dropped, cmds are spliced in just ahead of the COMMIT line and the
       result is loaded back through <cmd_prefix>-restore.
    '''
    LOG.info("Updating active '%s' %s rules", root_prefix, cmd_prefix)
    active = get_iptables_filter_rules(module, cmd_prefix)
    kept = purge_all_rules_using_prefix(module, active, root_prefix)

    # New rules go immediately before COMMIT so they belong to the table.
    insert_at = kept.index(COMMIT_LINE)
    merged = kept[:insert_at] + list(cmds) + kept[insert_at:]

    push_iptables_rules(module, merged, cmd_prefix)


def get_iptables_filter_rules(module, cmd_prefix):
    '''Fetch the active filter table via <cmd_prefix>-save.

       Returns the command output split into lines, or a minimal empty
       filter table (FILTER_LINE, COMMIT_LINE) when the save command fails.
    '''
    save_cmd = '%s-save -c -t %s' % (cmd_prefix, FILTER_TABLE_NAME)
    rc, stdout, stderr = module.run_command(save_cmd)
    if rc != 0:
        return [FILTER_LINE, COMMIT_LINE]
    return stdout.split('\n')


def push_iptables_rules(module, rules_list, cmd_prefix):
    '''Load rules_list into the kernel with <cmd_prefix>-restore.

       rules_list is joined with newlines and fed to the restore command on
       stdin. A non-zero exit status is logged and reported through
       module.fail_json.
    '''
    rules = '\n'.join(rules_list)

    cmd = '%s-restore -c' % cmd_prefix
    rc, stdout, stderr = module.run_command(cmd, data=rules)
    if rc != 0:
        # Each %r needs its own argument; passing a single tuple makes the
        # log record fail to format and the message is lost.
        LOG.error("cmd failed: %r, data: %r", cmd, rules)
        LOG.debug("rc: %r", rc)
        LOG.debug("stdout: %r", stdout)
        LOG.debug("stderr: %r", stderr)
        module.fail_json(msg="cmd failed: %s\nstderr: %s\ndata: %s" %
                             (cmd, stderr, rules))


def purge_all_rules_using_prefix(module, rules_list, prefix):
    '''Return a new list holding only the entries of rules_list in which
       prefix does not occur. The module argument is unused.
    '''
    kept = []
    for line in rules_list:
        if prefix in line:
            continue
        kept.append(line)
    return kept


def create_root_chains(chains, prefix, cmds, cmds_ip6):
    ''' Declare the chain structure and hook it into the system INPUT chain.

        The root chain is named 'prefix'; every entry of chains gets its own
        hashed chain which the root chain jumps to for traffic arriving on
        that entry's interface.
        Note: the chain specification is identical for ipv4 and ipv6, so
        the ipv4 commands are simply copied into cmds_ip6.
    '''
    cmds.extend([':%s -' % prefix,
                 '-A INPUT -j %s' % prefix])

    # one chain per entry, each linked from the root chain by interface
    for entry in chains:
        hashed = get_chain_name_hashed(prefix, entry['name'])
        cmds.append(':%s -' % hashed)
        cmds.append('-A %s -i %s -j %s' %
                    (prefix, entry['interface'], hashed))

    # clone for ip6
    cmds_ip6.extend(cmds)


def append_rules(module, rules, prefix, cmds, ip6_cmds):
    ''' Create the commands to populate the chains with the requested rules.

        rules maps a destination address to its list of firewall rules.
        Commands for IPv6 addresses are appended to ip6_cmds, everything
        else is appended to cmds.
    '''
    # items() works on both Python 2 and Python 3 dictionaries.
    for address, firewall_rules in rules.items():
        # Because the netaddr module is not available in ansible, use
        # socket instead to determine the address protocol. Any '/prefix'
        # suffix is ignored for the purpose of this detection only.
        try:
            socket.inet_pton(socket.AF_INET6, address.split('/')[0])
        except socket.error:
            is_ipv6 = False
        else:
            is_ipv6 = True

        # Rule generation is kept outside the try block so an error raised
        # there cannot be mistaken for "not an IPv6 address".
        if is_ipv6:
            ip6_cmds.extend(process_rules(module, address,
                                          firewall_rules, prefix, 'ipv6'))
        else:
            cmds.extend(process_rules(module, address, firewall_rules, prefix))


def append_default_rules(chains, prefix, logging, cmds, cmds_ip6):
    ''' Terminate every chain with the 'default' rules, in this order:
        - multicast accept - IP4 only
               (todo: Needs to be service specified)
        - icmpv6 types 130-132 and 134-136 accept - IP6 only
        - related,established accept
        - rate limited logging - only when logging is true
        - drop
    '''
    icmpv6_types = (130, 131, 132, 134, 135, 136)

    for entry in chains:
        target = get_chain_name_hashed(prefix, entry['name'])

        cmds.append('-A %s -m pkttype --pkt-type multicast -j ACCEPT' %
                    target)

        # Open ports needed for icmpv6 functionality
        cmds_ip6.extend('-A %s -p icmpv6 --icmpv6-type %s -j ACCEPT' %
                        (target, icmp_type)
                        for icmp_type in icmpv6_types)

        # The remaining rules are identical for both address families.
        shared = ['-A %s -m state --state RELATED,ESTABLISHED -j ACCEPT' %
                  target]
        if logging:
            shared.append(('-A %s -m limit --limit 2/min -j LOG '
                           '--log-prefix "IPTables-Dropped: " --log-level 4')
                          % target)
        shared.append('-A %s -j DROP' % target)

        cmds.extend(shared)
        cmds_ip6.extend(shared)


def process_rules(module, address, rules, prefix, family=None):
    '''Translate the firewall rules for one destination address into
       '-A <chain> ...' command strings and return them as a list.

       A rule whose type starts with 'allow' jumps to ACCEPT and 'deny'
       jumps to DROP; any other type is reported through module.fail_json.
       For family 'ipv6' the protocol 'icmp' is rewritten to 'icmpv6'.
    '''
    built = []
    for rule in rules:
        chain = rule['chain']
        target_chain = get_chain_name_hashed(prefix, chain)
        action = rule.get('type', 'allow')
        proto = rule.get('protocol', 'tcp')
        if family == 'ipv6' and proto == 'icmp':
            proto = 'icmpv6'

        parts = ['-A', target_chain]
        parts.extend(_address_arg(rule.get('remote-ip-prefix', None), 's'))
        parts.extend(_address_arg(address, 'd'))
        parts.extend(_protocol_arg(proto))
        parts.extend(_port_arg(rule, 'dport', proto))
        # the original chain name is kept as a comment for readability
        parts.extend(['-m', 'comment', '--comment',
                      '"%s"' % chain[:MAX_COMMENT_LEN]])

        if action.startswith('allow'):
            parts.extend(['-j', 'ACCEPT'])
        elif action == 'deny':
            parts.extend(['-j', 'DROP'])
        else:
            LOG.error("rule.type not supported: %r", rule)
            module.fail_json(msg="rule.type not supported: %s" % rule)

        built.append(' '.join(parts))
    return built


def _address_arg(address, direction):
    if address:
        return ['-%s' % direction, address]
    else:
        return []


def _protocol_arg(protocol):
    return ['-p', str(protocol)]


def _port_arg(rule, direction, protocol):
    minPort = rule.get('port-range-min', None)
    maxPort = rule.get('port-range-max', None)

    # We don't do any input verification elsewhere, so we need to handle the
    # odd cases here.
    #
    # If minPort and maxPort are None, return nothing and caller will
    #    open all ports for the protocol
    #
    # If minPort is None and maxPort is not None, reverse them, user error.
    #    We can't set them equal as that will break the ICMP code below
    #
    if minPort is None:
        if maxPort is None:
            return []
        # Handle the case where someone has set max but not min, bad user!
        minPort = maxPort
        maxPort = None

    if protocol in ['icmp', 'icmpv6']:
        # minPort/maxPort represent icmp type/code when protocol is icmp or
        # icmpv6. Since icmp code can be 0 we cannot use "if maxPort" here
        if protocol == 'icmpv6' and minPort == '8':
            #Make sure that the intended echo-request type is specified
            minPort = 128
        if maxPort is not None:
            return ['--%s-type' % str(protocol), '%s/%s' % (minPort, maxPort)]
        return ['--%s-type' % str(protocol), '%s' % minPort]
    elif minPort == maxPort or maxPort is None:
        return ['--%s' % direction, '%s' % minPort]
    else:
        return ['-m', 'multiport', '--%ss' % direction,
                '%s:%s' % (minPort, maxPort)]


def run_iptables_cmd(module, command):
    '''Run command through both iptables and ip6tables.

       The command can be any arbitrary command and the address protocol is
       unknown, so both tools are tried. Returns one (rc, stdout, stderr)
       tuple:
       - if the return codes differ and one tool succeeded, that tool's
         result
       - otherwise, for a '-C' check, the IPv4 result when its stderr ends
         with one of the two known messages, else the IPv6 result with
         'ip6tables' replaced by 'iptables' in stderr
       - otherwise the IPv4 result, after logging every failed command
    '''
    is_check = command.startswith("-C")

    v4_cmd = 'iptables ' + command
    v4_rc, v4_out, v4_err = module.run_command(v4_cmd)
    v6_cmd = 'ip6tables ' + command
    v6_rc, v6_out, v6_err = module.run_command(v6_cmd)

    if v4_rc != v6_rc:
        if v4_rc == 0:
            # IPv4 command version succeeded, report that
            return v4_rc, v4_out, v4_err
        if v6_rc == 0:
            # IPv6 command version succeeded, report that
            return v6_rc, v6_out, v6_err

    if is_check:
        # Only checking to see if a rule/table/chain already
        # exists.  Need to handle special.
        known_endings = ("No chain/target/match by that name.\n",
                         "Chain already exists.\n")
        if v4_err.endswith(known_endings):
            return v4_rc, v4_out, v4_err
        return v6_rc, v6_out, v6_err.replace("ip6tables", "iptables")

    # if we have gotten this far, log errors before returning
    attempts = ((v4_cmd, v4_rc, v4_out, v4_err),
                (v6_cmd, v6_rc, v6_out, v6_err))
    for failed_cmd, rc, out, err in attempts:
        if rc != 0:
            LOG.error("cmd failed: %r", failed_cmd)
            LOG.debug("rc: %r", rc)
            LOG.debug("stdout: %r", out)
            LOG.debug("stderr: %r", err)

    # Report the IPv4 version for backwards compatibility in playbooks
    return v4_rc, v4_out, v4_err


def persist_iptables_filters(module, persist_file, cmds, prefix):
    '''Write out the new set of filter rules to the file specified.

       The file is read, its filter section is updated in memory with cmds
       and the result is written back to the same path.
    '''
    LOG.info("Updating persisted '%s' rules in %s", prefix, persist_file)

    persisted_lines = read_persisted_iptables_filters(module, persist_file)
    update_persisted_iptables_filters(module, persisted_lines, cmds, prefix)
    write_persisted_iptables_filters(module, persist_file, persisted_lines)


def read_persisted_iptables_filters(module, filename):
    '''Return the content of the file specified as a list of strings.

       A file that cannot be opened or read (typically because it does not
       exist yet) yields an empty list. The module argument is unused.
    '''
    try:
        with open(filename, 'r') as pfile:
            return pfile.read().splitlines()
    except EnvironmentError:
        # Failure to open is OK, the file may not exist. Only I/O errors
        # are ignored, so KeyboardInterrupt/SystemExit still propagate.
        return []


def write_persisted_iptables_filters(module, filename, content_list):
    '''Overwrite filename with the entries of content_list, one per line,
       ending with a newline.
       An exception is raised if the file cannot be created or written.
       The module argument is unused.
    '''
    payload = '\n'.join(content_list) + '\n'

    with open(filename, "w") as out:
        out.write(payload)


def update_persisted_iptables_filters(module, persist_file, cmds, prefix):
    '''Update the netfilter-persistent rules with our new content.

       persist_file is the content of the rules file as a list of lines and
       is modified in place. If it holds a '*filter' section, that section
       is taken out, stripped of every line containing prefix, extended with
       cmds just before its COMMIT line and put back at the same position.
       Otherwise a default filter section carrying cmds is appended at the
       end. In both cases a comment recording the edit is added directly
       after the '*filter' line.
    '''
    # assume we don't have a persisted set of filter rules
    filter_table = [FILTER_LINE,
                    ':INPUT ACCEPT',
                    ':FORWARD ACCEPT',
                    ':OUTPUT ACCEPT',
                    COMMIT_LINE]
    # Without an existing section the new table goes at the very end. An
    # index of -1 would insert it before the last line and so split
    # whatever table precedes it.
    filter_idx = len(persist_file)

    # if there is a '*filter' table in the file, use that
    if FILTER_LINE in persist_file:
        filter_idx = persist_file.index(FILTER_LINE)
        commit_idx = (filter_idx +
                      persist_file[filter_idx:].index(COMMIT_LINE))

        # copy/remove the filter table section
        filter_table = persist_file[filter_idx:commit_idx+1]
        persist_file[filter_idx:commit_idx+1] = []

        # purge rules (and comment) matching the prefix from filter table
        filter_table = purge_all_rules_using_prefix(module,
                                                    filter_table,
                                                    prefix)

    # Insert our rules into the filter table before the COMMIT
    commit_idx = filter_table.index(COMMIT_LINE)
    filter_table[commit_idx:commit_idx] = cmds

    # insert a comment after the *filter to track the edit; it contains the
    # prefix, so the next update purges it together with the rules
    filter_table.insert(1, '# %s: filter section updated on %s' %
                           (prefix, time.ctime()))

    # put the filter table back where we found it
    persist_file[filter_idx:filter_idx] = filter_table


def init_logging():
    '''Attach a syslog handler (socket /dev/log) to the module logger and
       set the logger level to INFO.
    '''
    handler = logging.handlers.SysLogHandler(address='/dev/log')
    handler.setFormatter(logging.Formatter('%(module)s: %(message)s'))
    LOG.addHandler(handler)
    # Switch to logging.DEBUG here for verbose troubleshooting output.
    LOG.setLevel(logging.INFO)


# NOTE: This entire class can be removed if we have the fasteners
#       module, see oslo.concurrency code for more information.
class InterProcessLock(object):
    """An interprocess locking implementation.

    This is a lock implementation which allows multiple locks, working around
    issues like http://bugs.debian.org/cgi-bin/bugreport.cgi?bug=632857 and
    does not require any cleanup. Since the lock is always held on a file
    descriptor rather than outside of the process, the lock gets dropped
    automatically if the process crashes, even if ``__exit__`` is not
    executed.

    There are no guarantees regarding usage by multiple threads in a
    single process here. This lock works only between processes.

    Note these locks are released when the descriptor is closed, so it's not
    safe to close the file descriptor while another thread holds the
    lock. Just opening and closing the lock file can break synchronization,
    so lock files must be accessed only using this abstraction.
    """

    class LockError(Exception):
        """Raised when the lock cannot be taken (timeout or I/O failure)."""

    def __init__(self, path, timeout):
        """Remember the lock file path and the acquire timeout in seconds."""
        self.lockfile = None
        # Cleared by _ensure_tree() when the lock directory is missing, which
        # turns acquire() and release() into no-ops.
        self.needlock = True
        self.acquired = False
        self.path = path
        self.timeout = timeout

    def _try_acquire(self):
        """Take the lock, giving up after ``self.timeout`` seconds.

        trylock() does a blocking lockf() call, so a SIGALRM timer is used
        to interrupt it. The SIGALRM handler that was installed before the
        call is put back afterwards.

        :raises LockError: on timeout or when lockf() fails with IOError
        """
        lock_error = self.LockError

        def handler(signum, frame):
            raise lock_error()

        # set the timeout handler, remembering the one it replaces
        previous_handler = signal.signal(signal.SIGALRM, handler)
        signal.alarm(self.timeout)

        try:
            self.trylock()
        except lock_error:
            msg = 'Unable to acquire lock on `%s` due to timeout' % self.path
            raise lock_error(msg)
        except IOError as e:
            msg = ('Unable to acquire lock on `%s` due to exception %s' %
                   (self.path, e))
            raise lock_error(msg)
        finally:
            signal.alarm(0)
            # signal.signal() returns None for a handler that was not
            # installed from Python; such a handler cannot be restored.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

    def _do_open(self):
        """Open the lock file unless locking turned out to be unnecessary."""
        self._ensure_tree()
        if not self.needlock:
            return

        # Open in append mode so we don't overwrite any potential contents of
        # the target file. This eliminates the possibility of an attacker
        # creating a symlink to an important file in our lock path.
        if self.lockfile is None or self.lockfile.closed:
            self.lockfile = open(self.path, 'a')

    def acquire(self):
        """Attempt to acquire the given lock.

        There are no re-attempts, errors will be raised! Nothing is done
        when the lock directory does not exist.

        :returns: None
        :raises LockError: when the lock cannot be taken
        """
        self._do_open()
        if not self.needlock:
            return

        try:
            self._try_acquire()
        except Exception:
            # Don't leave the lock file open when the lock was not obtained.
            self._do_close()
            raise
        self.acquired = True
        LOG.debug('Acquired file lock `%s`', self.path)

    def _do_close(self):
        """Close the lock file if it is open."""
        if self.lockfile is not None:
            self.lockfile.close()
            self.lockfile = None

    def __enter__(self):
        self.acquire()
        return self

    def release(self):
        """Release the previously acquired lock and close the lock file.

        Does nothing when no lock was needed or no lock file is open.
        Failures to unlock or close are logged, not raised.
        """
        if not self.needlock or self.lockfile is None:
            return

        try:
            self.unlock()
        except IOError:
            LOG.exception("Could not unlock the acquired lock opened"
                          " on `%s`", self.path)
        else:
            self.acquired = False
            try:
                self._do_close()
            except IOError:
                LOG.exception("Could not close the file handle"
                              " opened on `%s`", self.path)
            else:
                LOG.debug("Unlocked and closed file lock open on"
                          " `%s`", self.path)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()

    def trylock(self):
        """Take an exclusive lockf() lock on the lock file (blocking)."""
        fcntl.lockf(self.lockfile, fcntl.LOCK_EX)

    def unlock(self):
        """Drop the lockf() lock held on the lock file."""
        fcntl.lockf(self.lockfile, fcntl.LOCK_UN)

    def _ensure_tree(self):
        """Decide whether a lock is needed and validate the lock directory.

        When the directory that should hold the lock file does not exist,
        ``needlock`` is cleared and no lock is taken.
        """
        basedir = os.path.dirname(self.path)

        # If the directory doesn't exist, we don't need a lock
        if not os.path.exists(basedir):
            LOG.debug("Directory `%s` does not exist, file lock not needed",
                      basedir)
            self.needlock = False
            return

        # Create a directory (and any ancestor directories required).
        # basedir exists at this point, so EEXIST is the expected outcome;
        # it is only an error when the existing path is not a directory.
        try:
            os.makedirs(basedir)
        except OSError as e:
            if e.errno == errno.EEXIST:
                if not os.path.isdir(basedir):
                    raise
            elif e.errno != errno.EISDIR:
                raise


def _get_lock_path(name, lock_file_prefix, lock_path):
    # NOTE(mikal): the lock name cannot contain directory
    # separators
    name = name.replace(os.sep, '_')
    sep = '' if lock_file_prefix.endswith('-') else '-'
    name = '%s%s%s' % (lock_file_prefix, sep, name)

    return os.path.join(lock_path, name)


def external_lock(name, lock_file_prefix, lock_path, lock_timeout):
    """Build (without acquiring) the InterProcessLock for the given name."""
    return InterProcessLock(
        _get_lock_path(name, lock_file_prefix, lock_path),
        lock_timeout)


@contextlib.contextmanager
def lock(name, lock_file_prefix, lock_path, lock_timeout, do_log=True):
    """Context based lock

    Acquires the external lock for ``name``, yields the InterProcessLock
    instance and releases it again when the managed block is left.

    :param name: Name of the lock; used to build the lock file name.

    :param lock_file_prefix: The lock_file_prefix argument is used to provide
      lock files on disk with a meaningful prefix.

    :param lock_path: The path in which to store external lock files.  For
      external locking to work properly, this must be the same for all
      references to the lock.

    :param lock_timeout: Timeout handed to the InterProcessLock.

    :param do_log: Whether to log acquire/release messages.  This is primarily
      intended to reduce log message duplication when `lock` is used from the
      `synchronized` decorator.
    """
    if do_log:
        LOG.debug('Acquiring "%(lock)s"', {'lock': name})
    try:
        # InterProcessLock is itself a context manager: entering acquires,
        # leaving releases.
        with external_lock(name, lock_file_prefix, lock_path,
                           lock_timeout) as held:
            yield held
    finally:
        if do_log:
            LOG.debug('Releasing lock "%(lock)s"', {'lock': name})


# Wildcard import placed at the end of the file; presumably it supplies the
# AnsibleModule and BOOLEANS names that main() uses -- they are not defined
# anywhere else in this file.
from ansible.module_utils.basic import *
# Run the module only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
