#!/opt/vmware/bin/python

import argparse
import copy
import errno
import json
import logging
import os
import platform
import pprint
import subprocess
import sys
sys.path.append(os.environ['VMWARE_PYTHON_PATH'])
from cis.utils import setupLogging
from cis.defaults import get_cis_log_dir

# This script manipulates per-service configuration files.  Note that it
# performs no locking: we assume that a given service's configuration file is
# never modified by more than one caller at a time.

# Default property values for rules, per section.  serviceConfig.addRule fills
# these into a new rule for every property the rule does not set itself.
defaults = {
    "firewall": {
        "direction": "inbound",
        "portoffset": 0,
        "porttype": "dst",
        "protocol": "tcp",
    },
    "internal-ports": {},
}

# Property that identifies a rule within each section.  addRule replaces an
# existing rule with the same key value, and updateRuleByKey uses it to find
# the rule to update.
keys = {
    "firewall": "name",
    "internal-ports": "name",
}

def execute(cmd, outfile=subprocess.PIPE, errfile=subprocess.PIPE):
    """
    Run a command (given as an argument list) and wait for it to finish.

    :param cmd: list containing the program and its arguments
    :param outfile: destination for the child's stdout (piped by default)
    :param errfile: destination for the child's stderr (piped by default)
    :returns: tuple (exit status, stdout, stderr) as produced by
              Popen.communicate()

    When the exit status is non-zero, a report that includes the command and
    its output is printed and logged as an error.
    """
    proc = subprocess.Popen(cmd, stdout=outfile, stderr=errfile)
    out, err = proc.communicate()
    status = proc.wait()
    if status:
        report = ''.join([
            "Command failed with exitStatus = %d: %s\n" % (status, ' '.join(cmd)),
            "Stdout: %s\n" % out,
            "Stderr: %s\n" % err,
        ])
        print(report)
        logging.error(report)
    return status, out, err

def loadServicesFile(file):
    """
    Load a JSON service configuration file into a dictionary and return it.

    If the file does not exist, return an empty dictionary.  Any other I/O
    error (for example, permission denied or the path is a directory) is
    logged as a warning and also yields an empty dictionary.

    If the file is not valid JSON, or its top-level value is not a JSON
    object, an error is logged and printed and the process exits with
    status 1.

    :param file: path of the configuration file
    :returns: dict parsed from the file, or {} if it could not be read
    """
    try:
        with open(file, "r") as f:
            svcs = json.load(f)
    except ValueError as e:
        msg = "Malformed service configuration file (%s)" % file
        logging.error("%s: %s" % (msg, e))
        print(msg)
        sys.exit(1)
    except (IOError, OSError) as e:
        # A missing file simply means "no configuration yet".  Other errors
        # are still treated as best-effort, but they should not be silent.
        if getattr(e, 'errno', None) != errno.ENOENT:
            logging.warning("Unable to read service configuration file "
                            "(%s): %s" % (file, e))
        return {}

    # Callers index the result by service/section name, so anything other
    # than a JSON object at the top level is unusable.
    if not isinstance(svcs, dict):
        msg = "Malformed service configuration file (%s)" % file
        logging.error("%s: top-level value is not a JSON object" % msg)
        print(msg)
        sys.exit(1)
    return svcs

def parseArguments():
    """
    Parse and validate the command line arguments.

    Invalid option combinations cause an error message to be printed and the
    process to exit with status 1.

    :returns: argparse.Namespace with the parsed options
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--service", help="The name of a service")
    parser.add_argument("--section", help="Section of config file (e.g., "
                        "'firewall' or 'internal-ports')")
    parser.add_argument("--dump",
                        help="Dump all or section config for service",
                        action="store_true")
    parser.add_argument("--enable",
                        help="Enable specified section of provided service",
                        action="store_true")
    parser.add_argument("--disable",
                        help="Disable specified section of provided service",
                        action="store_true")
    parser.add_argument("--reload-firewall",
                        help="Reload firewall after update",
                        action="store_true")
    parser.add_argument("--get-rules",
                        help="Get rules in specified section",
                        action="store_true")
    parser.add_argument("--update-rule-by-key",
                        help="Update rule with particular key (usually 'name')")
    parser.add_argument("--add-rule",
                        help="Add rule in specified section")

    args = parser.parse_args()

    # --service is required, except with --update-rule-by-key (which scans
    # every service) or a bare --reload-firewall.
    if not args.service and not args.update_rule_by_key and not args.reload_firewall:
        print("ERROR: --service must be specified (except when using --update-rule-by-key or --reload-firewall)")
        sys.exit(1)
    if args.service and args.update_rule_by_key:
        print("ERROR: --service must not be used with --update-rule-by-key")
        sys.exit(1)

    # Options that modify one service's configuration still need --service
    # when combined with --reload-firewall.  Without it there is no
    # configuration file to modify.
    if not args.service and not args.update_rule_by_key and \
       (args.enable or args.disable or args.add_rule):
        print("ERROR: --service must be specified when using --enable, --disable or --add-rule")
        sys.exit(1)

    # Check that --section is specified when using --enable, --disable,
    # --get-rules, --update-rule-by-key, or --add-rule.
    if not args.section and (args.enable or args.disable or
                             args.get_rules or args.update_rule_by_key or
                             args.add_rule):
        print("ERROR: --section must be used with any of the following:")
        print("       --enable, --disable, --get-rules, --update-rule-by-key,")
        print("       and --add-rule.")
        sys.exit(1)

    # Check that --dump is only used with --service and --section (nothing
    # else).
    if args.dump and (args.enable or args.disable or args.reload_firewall or
                      args.get_rules or args.update_rule_by_key or
                      args.add_rule):
        print("ERROR: --dump can only be used with --service and --section")
        sys.exit(1)

    # Check that --reload-firewall is only used alone or with --service,
    # '--section firewall', --enable (optional), --disable (optional),
    # --update-rule-by-key (optional), --add-rule (optional), nothing else.
    if args.reload_firewall and (args.dump or args.get_rules):
        print("ERROR: --reload-firewall cannot be used with --dump or --get-rules")
        sys.exit(1)
    if args.reload_firewall and args.section and args.section != "firewall":
        print("ERROR: --reload-firewall cannot be used with '--section %s'" % args.section)
        sys.exit(1)

    # Check that only one of --enable or --disable are provided at a time.
    if args.enable and args.disable:
        print("ERROR: Only one of --enable and --disable can be used")
        sys.exit(1)

    return args

class serviceConfig(object):
    """
    Read and modify the configuration of a single service.

    Configuration comes from two places:

    * a per-service JSON file in ``_svcCfgDir``, named after the service
      (or ``vmware-<service>`` when only that variant exists), and
    * the service's entry in the global ``services.conf`` file.

    The per-service configuration, when non-empty, takes precedence over the
    global entry (see ``_dominantDict``).  Modifications are always made to
    the per-service configuration, and they are written to disk only by
    ``sync``.
    """
    _globalCfgPath = os.path.join(os.environ['VMWARE_CFG_DIR'], 'appliance',
                                  'services.conf')
    _svcCfgDir = os.path.join(os.environ['VMWARE_CFG_DIR'], 'appliance',
                              'firewall')
    _changed = False

    def __init__(self, service, section=None):
        """
        :param service: service name.  If falsy, no configuration is loaded
                        and no file is associated with the instance; such an
                        instance is only useful for reloadFirewall().
        :param section: optional section name (e.g. 'firewall')
        """
        # Always initialize the attributes so that every method works (or
        # fails cleanly) even when no service was given.
        self._service = service
        self._section = section
        self._svcCfgPath = None
        self._serviceDict = {}
        self._globalDict = {}
        self._changed = False
        if not service:
            return
        self._svcCfgPath = os.path.join(self._svcCfgDir, service)
        altPath = os.path.join(self._svcCfgDir, "vmware-%s" % service)
        if not os.path.exists(self._svcCfgPath) and os.path.exists(altPath):
            self._svcCfgPath = altPath
        self._serviceDict = loadServicesFile(self._svcCfgPath)
        globalDict = loadServicesFile(self._globalCfgPath)
        if service in globalDict:
            self._globalDict = globalDict[service]

    def _dominantDict(self):
        """
        If service-specific configuration exists, return it.  Otherwise, return global configuration.
        """
        d = dict()
        if self._serviceDict:
            d = self._serviceDict
        elif self._globalDict:
            d = self._globalDict
        return d

    def getConfig(self):
        """
        Return the service configuration (limiting to a particular section if
        one was specified).

        Return empty dictionary if not found.
        """
        result = self._dominantDict()
        if self._section:
            if self._section in result:
                result = result[self._section]
            else:
                result = {}
        return result

    def getRules(self):
        """
        Return rules in particular section of a particular service.

        Return empty array if section is undefined or does not contain rules.
        """
        rules = []
        d = self._dominantDict()
        if self._section in d and 'rules' in d[self._section]:
            rules = d[self._section]['rules']
        return rules

    def _getSectionDict(self):
        """
        Return the current section's dictionary in the per-service
        configuration, creating the section if it does not exist.

        If there is no per-service configuration yet but a global one exists,
        the global configuration is copied in first.  Once the per-service
        configuration is non-empty, _dominantDict ignores the global
        configuration entirely.  Starting from an empty dict would therefore
        silently discard every other setting (e.g. existing rules) that was
        in effect.
        """
        if not self._serviceDict and self._globalDict:
            self._serviceDict = copy.deepcopy(self._globalDict)
        if self._section not in self._serviceDict:
            self._serviceDict[self._section] = dict()
        return self._serviceDict[self._section]

    def addRule(self, rule):
        """
        Add a new rule to the section.

        Section defaults (from ``defaults``) are filled into the rule for any
        property it does not set.  Note that this modifies *rule* in place.
        If the section has a key property (from ``keys``) and an existing
        rule has the same key value as *rule*, the existing rule is replaced.
        Rules without the key property are never removed.
        """
        sectionDict = self._getSectionDict()

        # Add defaults to rule.
        for prop, value in defaults.get(self._section, {}).items():
            if prop not in rule:
                rule[prop] = value

        # Create rules property if not already present.
        rules = sectionDict.setdefault('rules', [])

        # Drop only the rules whose key matches the new rule's key.
        if self._section in keys:
            k = keys[self._section]
            if k in rule:
                rules = [r for r in rules
                         if not (k in r and r[k] == rule[k])]
                sectionDict['rules'] = rules

        # Add new rule.
        rules.append(rule)
        self._changed = True

    def updateRuleByKey(self, rule):
        """
        Search the section's rules for ones whose key (typically 'name')
        matches the key of *rule*, and augment each match with the
        properties of *rule*.

        This is not a full replacement: properties of the old rule that
        *rule* does not specify are kept.

        :returns: True if at least one rule was updated, else False (also
                  False if the section has no key or *rule* lacks the key).
        """
        sectionDict = self._getSectionDict()
        if 'rules' not in sectionDict or self._section not in keys:
            return False
        key = keys[self._section]
        if key not in rule:
            return False
        wanted = rule[key]
        found = False
        for existing in sectionDict['rules']:
            if key in existing and existing[key] == wanted:
                # Use distinct names here: reusing the key variable would
                # change which property later rules are matched on.
                for prop, value in rule.items():
                    existing[prop] = value
                found = True
        if found:
            self._changed = True
        return found

    def enable(self):
        """
        Enable the section of the service.
        """
        sectionDict = self._getSectionDict()
        sectionDict['enable'] = True
        self._changed = True

    def disable(self):
        """
        Disable the section of the service.
        """
        sectionDict = self._getSectionDict()
        sectionDict['enable'] = False
        self._changed = True

    def sync(self):
        """
        If the configuration has been updated, write it back to the
        per-service configuration file.

        The data is written to a temporary file, flushed to disk, and then
        renamed over the original, so readers never see a partial file.
        """
        if not self._changed:
            return
        if self._svcCfgPath is None:
            msg = "No service specified; configuration changes were not saved."
            print(msg)
            logging.error(msg)
            return
        tmpFile = self._svcCfgPath + '.tmp'
        with open(tmpFile, "w") as f:
            json.dump(self._serviceDict, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        # os.rename cannot replace an existing file on Windows.
        if platform.system() == "Windows" and os.path.isfile(self._svcCfgPath):
            os.remove(self._svcCfgPath)
        os.rename(tmpFile, self._svcCfgPath)
        self._changed = False

    def reloadFirewall(self):
        """
        Reload the firewall by running the appliance firewall-reload command.

        :returns: the command's exit status, or None on Windows (where this is
                  not implemented)
        """
        if platform.system() == "Windows":
            msg = "Windows firewall configuration is not yet implemented."
            print(msg)
            logging.error(msg)
            return None
        cmd = ['/usr/lib/applmgmt/networking/bin/firewall-reload']
        rc, _, _ = execute(cmd)
        return rc

#
# The work starts here.
#
logDir = os.path.join(get_cis_log_dir(), 'cloudvm')
setupLogging('service-config', logMechanism='file', logDir=logDir)

pp = pprint.PrettyPrinter(indent=2)

args = parseArguments()

if args.update_rule_by_key:
    # Update the matching rule in every per-service configuration file.  If
    # no file has a rule with that name, add the rule to the service named by
    # the part of the rule name before the first '.'.
    svcCfgDir = serviceConfig._svcCfgDir
    try:
        info = json.loads(args.update_rule_by_key)
        if not isinstance(info, dict):
            raise ValueError("Rule must be a JSON object")
        if 'name' not in info:
            raise ValueError("Rule must include a 'name' field")
    except ValueError as e:
        print("Malformed rule '%s': %s" % (args.update_rule_by_key, e))
        sys.exit(1)

    found = False
    if os.path.isdir(svcCfgDir):
        for service in os.listdir(svcCfgDir):
            # Skip hidden entries, subdirectories, and temporary files left
            # behind by an interrupted serviceConfig.sync().
            if service.startswith('.') or service.endswith('.tmp') or \
               not os.path.isfile(os.path.join(svcCfgDir, service)):
                continue
            cfg = serviceConfig(service, args.section)
            found = cfg.updateRuleByKey(info) or found
            cfg.sync()
    if not found:
        serviceName = info['name'].split('.')[0]
        cfg = serviceConfig(serviceName, args.section)
        cfg.addRule(info)
        cfg.enable()
        cfg.sync()
    # parseArguments allows --reload-firewall with --update-rule-by-key, so
    # honor it here before exiting.
    if args.reload_firewall:
        serviceConfig(None).reloadFirewall()
    sys.exit(0)

cfg = serviceConfig(args.service, args.section)

if args.dump:
    pp.pprint(cfg.getConfig())
    sys.exit(0)

if args.get_rules:
    pp.pprint(cfg.getRules())
    sys.exit(0)

if args.add_rule:
    try:
        rule = json.loads(args.add_rule)
        if not isinstance(rule, dict):
            raise ValueError("rule must be a JSON object")
    except ValueError as e:
        print("ERROR: Malformed rule: %s" % e)
        sys.exit(1)
    cfg.addRule(rule)

if args.enable:
    cfg.enable()
elif args.disable:
    cfg.disable()

cfg.sync()

if args.reload_firewall:
    cfg.reloadFirewall()

