#! /opt/imh-python/bin/python
""" Scan a URL, giving details about the process. """

import re
import sys
import logging
from argparse import ArgumentParser
from rads.common import colors
from rads.common import setup_logging
from strace_tools import strace_url
import json


# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)

# Authorship metadata for this script.
__author__ = "Daniel K"
__email__ = "danielk@inmotionhosting.com"


def colorize(text, color, end='none'):
    """Wrap text in a color code and a trailing code.

    `color` and `end` are keys into the mapping returned by
    `rads.common.colors()`; `end` defaults to the 'none' entry.
    Returns the start code, the text and the end code joined together.
    """
    start_code = colors()[color]
    end_code = colors()[end]
    return '%s%s%s' % (start_code, text, end_code)


def set_logging(is_quiet, verbosity, output_file):
    ''' Configure logging from the command line switches.

    is_quiet    -- when true, only CRITICAL messages are logged and
                   `verbosity` is ignored.
    verbosity   -- count of -v flags (None when the flag was not given):
                   None -> ERROR, 1 -> WARNING, 2 -> INFO, anything
                   else -> DEBUG.
    output_file -- path of a log file, or '' to log to stderr.
    '''
    if is_quiet:
        level = logging.CRITICAL
    else:
        # Any verbosity not listed here (3 and above) means DEBUG.
        level_by_verbosity = {
            None: logging.ERROR,
            1: logging.WARNING,
            2: logging.INFO,
        }
        level = level_by_verbosity.get(verbosity, logging.DEBUG)

    if output_file != '':
        # A log file was requested: write there and keep the terminal quiet.
        setup_logging(logfile=output_file, loglevel=level, print_out=False)
        return

    setup_logging(loglevel=level, print_out=sys.stderr)


def parse_args(argv=None):
    ''' Parse command line arguments and configure logging.

    argv -- optional list of argument strings; when None (the default)
            argparse reads sys.argv[1:], as before.

    Returns a (url, raw, debug) tuple.

    Exits with status 1, after writing a message to stderr, when no URL
    was supplied.
    '''

    parser = ArgumentParser(description=__doc__)

    parser.add_argument(
        "-v", "--verbose", action='count',
        help="Print verbose output. May be added multiple times."
    )

    parser.add_argument(
        "-q", "--quiet", action='store_true',
        help="Do not output logging. Overrides -v."
    )

    parser.add_argument(
        "-o", "--output", action='store', type=str, default='',
        help="Output logging to the specified file."
    )

    output_parser_group = parser.add_argument_group("Output options")
    output_group = output_parser_group.add_mutually_exclusive_group()

    output_group.add_argument(
        "-r", "--raw", action='store_true',
        help="Output strace raw data."
    )

    output_group.add_argument(
        "-d", "--debug", action='store_true',
        help="Print verbose diagnostic information."
    )

    # NOTE(review): --errors and --files are accepted but their values are
    # not part of the returned tuple, so they currently have no effect.
    output_group.add_argument(
        "-e", "--errors", action='store_true',
        help="Output data sent to error logs."
    )

    output_group.add_argument(
        "-f", "--files", action='store_true',
        help="Show file activity."
    )

    parser.add_argument(
        'url', metavar='URL', type=str, nargs='?',
        help="URL to scan."
    )

    args = parser.parse_args(argv)

    if args.url is None:
        # Diagnostics belong on stderr so they never mix with scan output.
        sys.stderr.write("URL not given\n")
        sys.exit(1)

    set_logging(args.quiet, args.verbose, args.output)

    return args.url, args.raw, args.debug


def dump_callback(command, line, strace_handler, additional_data=None):
    ''' Dump data collected from command. This is for debugging.

    command         -- name of the traced call.
    line            -- the strace line that triggered the callback.
    strace_handler  -- unused here; accepted so the signature matches the
                       other callbacks.
    additional_data -- optional mapping of extra values; each key/value is
                       printed on its own indented line. None prints none.
    '''
    print("Command [%s] - %s" % (command, line))
    # The default is None, which cannot be iterated; treat it as empty.
    for key in (additional_data or {}):
        print("\t%s: %s" % (key, additional_data[key]))
    print("\n")


def make_error_hash(filename, data):
    ''' Build a key used to decide whether an error is a duplicate.

    Newlines are dropped from `data`, spaces become underscores, and only
    the first 15 characters of the result are kept, prefixed by
    `filename` and an underscore.
    '''
    flattened = data.replace('\n', '').replace(' ', '_')
    prefix = flattened[:15]
    return "%s_%s" % (filename, prefix)


class ErrorHandler(object):
    ''' Collect, de-duplicate and report problems seen while tracing a URL.

    The check_* / catch_* methods share the callback signature
    (command, line, strace_handler, additional_data=None).

    Instance attributes:
      errors       -- recorded error messages, capped at max_errors.
      error_hashes -- keys from make_error_hash() used to skip duplicates.
      max_errors   -- maximum number of messages kept.
      timeouts     -- pending timeout messages, keyed by address/filename;
                      an entry is dropped when that key is later used
                      successfully.
    '''
    # Default cap; __init__ sets the per-instance value.
    max_errors = 5

    # PHP include/require failures ("failed to open stream" / "Failed opening").
    _INCLUDE_ERROR_RE = re.compile(
        r'(include|require)(_once)?'
        r'('
        r'\((?P<filename>[^)]+)\): '
        r'failed to open stream: No such file or directory '
        r'|'
        r"\(\): Failed opening '(?P<filename2>[^']+)' "
        r'for inclusion \([^)]+\) '
        r')'
        r'in (?P<source>(?P<usrroot>/home[0-9]*/[^/]+/)?[^ ]+) '
        r'on line (?P<line>[0-9]+)'
    )
    # "headers already sent" failures for sessions and header changes.
    _SESSION_AFTER_HEADER_RE = re.compile(
        r'('
        r'session_start\(\): Cannot send session (cookie|cache limiter)'
        r'|'
        r'Cannot modify header information'
        r') - headers already sent (by )?'
        r'\(output started at (?P<sentheader>[^:]+):(?P<sentline>[0-9]+)\)'
        r' in (?P<session>[^\s]+) on line (?P<sessline>[0-9]+)'
    )
    # Messages emitted by old Joomla code on newer PHP versions.
    _OLD_JOOMLA_RE = re.compile(
        r'('
        r'Non-static method J[A-Z][a-zA-Z]+::'
        r'|'
        r'Redefining already defined constructor for class J[A-Z][a-zA-Z]+'
        r'|'
        r'should be compatible with( &)? J[A-Z][a-zA-Z]+'
        r')'
    )
    # MySQL login failure. The trailing "in <source> on line <n>" part is
    # optional, so the same pattern serves PHP log lines and raw reads.
    _USER_ACCESS_RE = re.compile(
        r"Access denied for user "
        r"'(?P<username>[a-zA-Z0-9_]*)'@'(?P<hostname>[a-zA-Z0-9._-]+)' "
        r"\(using password: (?P<password>YES|NO)\)"
        r"( in (?P<source>[^ ]+) on line (?P<line>[0-9]+))?"
    )
    # MySQL user lacking rights on a specific database.
    _DB_ACCESS_RE = re.compile(
        r"Access denied for user "
        r"'(?P<username>[a-zA-Z0-9_]*)'@'(?P<hostname>[a-zA-Z0-9._-]+)' "
        r"to database '(?P<database>[a-zA-Z0-9._-]+)'"
    )
    _MISSING_TABLE_RE = re.compile(
        r"Table '(?P<table>[^']+)' doesn't exist"
    )
    # A bracketed timestamp followed by "PHP <level>: <message>".
    _PHP_ERROR_RE = re.compile(
        r'\[[a-zA-Z0-9-]+\s+[0-9:]+\s+[^]]+\] '
        r'PHP (Warning|[^:]+):\s+(?P<error_message>.*$)'
    )
    _TITLE_RE = re.compile(
        r'<title>(?P<title>[^<]+)(</title>)?'
    )
    # HTTP 4xx/5xx status line in data received from a remote host.
    _HTTP_FAIL_RE = re.compile(
        r"HTTP/1.[10] (?P<status>([45][0-9]+) ([^\\])*)"
    )
    # strace poll() line that ended in "(Timeout)".
    _POLL_TIMEOUT_RE = re.compile(
        r'poll\(\[{fd=(?P<handle>[0-9+]+),'
        r'\s+events=([A-Z_|]+)}\],\s+[0-9]+,'
        r'\s+[0-9]+\)\s+=\s+[0-9]+ \(Timeout\)'
    )

    def __init__(self, max_errors=5):
        ''' Create a handler keeping at most `max_errors` messages. '''
        self.max_errors = max_errors
        # Created per instance: class-level lists would be shared by every
        # ErrorHandler and leak errors between scans.
        self.error_hashes = []
        self.errors = []
        self.timeouts = {}

    def record(self, filename, error):
        ''' Record an incoming issue.

        Nothing is stored once max_errors messages are held, or when
        make_error_hash(filename, error) has already been seen.
        '''
        if len(self.errors) >= self.max_errors:
            return

        error_hash = make_error_hash(filename, error)
        if error_hash in self.error_hashes:
            return

        self.error_hashes.append(error_hash)
        self.errors.append(error)

    def list(self):
        ''' Print the recorded issues, or a note that there were none. '''
        if not self.errors:
            print(colorize("No errors recorded\n", 'green'))
            return

        print("\nErrors were detected:\n")
        for error in self.errors:
            print(" * " + colorize(error, 'red') + "\n")

    def record_timeout(self, ip_address, error):
        ''' Record a connection timeout to check later '''
        self.timeouts.setdefault(ip_address, []).append(error)

    def clear_timeouts(self, ip_address):
        ''' Clear a timeout. Useful if we connected later. '''
        self.timeouts.pop(ip_address, None)

    def get_remaining_timeouts(self):
        ''' Return the first message of every timeout not yet cleared. '''
        return [self.timeouts[key][0] for key in self.timeouts]

    def parse_php_error(self, filename, error):
        ''' Turn a PHP error message into a short diagnosis and record it.

        filename -- where the message was written; used for duplicate
                    detection (include errors use the missing file instead).
        error    -- the message text following the "PHP <level>:" prefix.

        Patterns are tried in a fixed order and the first one that matches
        decides the recorded message; anything unrecognised is recorded as
        "PHP ERROR: <message>".
        '''
        include_match = self._INCLUDE_ERROR_RE.search(error)
        if include_match is not None:
            missing = include_match.group('filename')
            if missing is None:
                missing = include_match.group('filename2')
            source = include_match.group('source')
            usrroot = include_match.group('usrroot')

            message = "Missing file %s in %s" % (missing, source)
            # An absolute path outside the including script's /home*/<user>/
            # prefix is reported as a hard-coded path problem.
            if (
                    usrroot is not None and
                    missing.startswith('/') and
                    not missing.startswith(usrroot)
            ):
                message = (
                    "Hard-coded path wrong for %s in %s" % (missing, source)
                )
            self.record(missing, message)
            return

        session_match = self._SESSION_AFTER_HEADER_RE.search(error)
        if session_match is not None:
            self.record(
                filename,
                "Session can't start because "
                "output already sent from %s on line %s" % (
                    session_match.group('sentheader'),
                    session_match.group('sentline')
                )
            )
            return

        if self._OLD_JOOMLA_RE.search(error) is not None:
            self.record(filename, "Outdated Joomla site")
            return

        user_match = self._USER_ACCESS_RE.search(error)
        if user_match is not None:
            if user_match.group('source') is not None:
                message = (
                    "Database authentication error for '%s' on %s. "
                    "Found in %s on line %s." % (
                        user_match.group('username'),
                        user_match.group('hostname'),
                        user_match.group('source'),
                        user_match.group('line')
                    )
                )
            else:
                message = "Database authentication error for '%s' on %s" % (
                    user_match.group('username'),
                    user_match.group('hostname')
                )
            self.record(filename, message)
            return

        db_match = self._DB_ACCESS_RE.search(error)
        if db_match is not None:
            self.record(
                filename,
                "'%s' can't connect to database '%s'" % (
                    db_match.group('username'),
                    db_match.group('database'),
                )
            )
        elif 'is deprecated' in error:
            self.record(
                filename,
                "Software is outdated and expects an older PHP version."
            )
        elif 'PHP Startup:' in error:
            self.record(filename, "Startup error: %s" % error)
        else:
            self.record(filename, "PHP ERROR: %s" % error)

    def parse_error(self, filename, data):
        ''' Classify text written to an error destination and record it.

        filename -- name of the destination the data was written to.
        data     -- the text that was written.
        '''
        # Same cap test as record(): once full, skip the regex work.
        if len(self.errors) >= self.max_errors:
            return

        php_error_match = self._PHP_ERROR_RE.search(data)
        if php_error_match is not None:
            self.parse_php_error(
                filename,
                php_error_match.group('error_message')
            )
        elif 'Out of memory' in data:
            self.record(filename, "Out of Memory")
        else:
            self.record(
                filename,
                "Error written to %s: %s\n" % (filename, data)
            )

    def check_for_written_error(
            self,
            command,
            line,
            strace_handler,
            additional_data=None
    ):
        ''' Check for any messages written to error logs or output.

        Reads 'filename', 'data' and (for STDOUT) 'current_data' from
        additional_data. Writes to names containing STDERR, "error" or
        /dev/null are parsed as errors; STDOUT writes are scanned for a
        page title; a write to anything else clears that name's pending
        timeouts. Does nothing when additional_data is missing or has no
        'filename'.
        '''
        if not additional_data or 'filename' not in additional_data:
            return

        filename = additional_data['filename']
        if (
                'STDERR' in filename or
                'error' in filename or
                '/dev/null' in filename
        ):
            self.parse_error(filename, additional_data['data'])
        elif 'STDOUT' in filename:
            title_match = self._TITLE_RE.search(additional_data['data'])
            if (
                    title_match is not None and
                    '</title>' in additional_data['current_data']
            ):
                print(
                    "Page Title: " +
                    colorize(title_match.group('title'), 'green')
                )
        else:
            self.clear_timeouts(filename)

    def check_for_read_error(
            self,
            command,
            line,
            strace_handler,
            additional_data=None
    ):
        ''' Check for error messages which were read in.

        Reads 'filename', 'status', 'data' and (on failure)
        'error_message' from additional_data. A status of '-1' is stored as
        a pending timeout; any other status clears pending timeouts for
        that name and the data is scanned for HTTP 4xx/5xx replies
        (recvfrom only) and MySQL error text. Does nothing when
        additional_data is missing or has no 'filename'.
        '''
        if not additional_data or 'filename' not in additional_data:
            return

        filename = additional_data['filename']

        if additional_data['status'] == '-1':
            self.record_timeout(
                filename,
                "Error reading from %s: %s" % (
                    filename,
                    additional_data['error_message']
                )
            )
            return

        self.clear_timeouts(filename)
        data = additional_data['data']

        if command == 'recvfrom':
            fail_match = self._HTTP_FAIL_RE.search(data)
            if fail_match is not None:
                self.record(
                    filename,
                    "Remote connection error from %s: %s" % (
                        filename,
                        fail_match.group('status')
                    )
                )

        user_match = self._USER_ACCESS_RE.search(data)
        if user_match is not None:
            self.record(
                filename,
                "Database authentication error for '%s' on %s" % (
                    user_match.group('username'),
                    user_match.group('hostname')
                )
            )
            return

        db_match = self._DB_ACCESS_RE.search(data)
        if db_match is not None:
            self.record(
                filename,
                "'%s' can't connect to database '%s'" % (
                    db_match.group('username'),
                    db_match.group('database'),
                )
            )
            return

        table_match = self._MISSING_TABLE_RE.search(data)
        if table_match is not None:
            self.record(
                filename,
                "Missing database table: '%s'" % (
                    table_match.group('table'),
                )
            )

    def check_for_connect_errors(
            self,
            command,
            line,
            strace_handler,
            additional_data=None
    ):
        ''' Check for errors with the connections.

        Reads 'status', 'filename', 'ip', 'port' and 'status_message' from
        additional_data. A status of '0' is success and is ignored. A
        failed mysql.sock connection is recorded immediately; other
        failures are stored as pending timeouts. Does nothing when
        additional_data is missing or has no 'status'.
        '''
        if not additional_data or 'status' not in additional_data:
            return
        if additional_data['status'] == '0':
            return

        filename = additional_data['filename']

        if 'mysql.sock' in filename:
            self.record(
                filename,
                "MySQL seems to be down. Unable to connect to socket."
            )
        elif additional_data['ip'] is not None:
            port = additional_data['port']
            if port in ('80', '443'):
                connection_type = "Web connection"
            elif port == '3306':
                connection_type = "Database connection"
            else:
                connection_type = "Connection"

            self.record_timeout(
                additional_data['ip'],
                "%s timed out to %s." % (
                    connection_type,
                    additional_data['ip']
                )
            )
        else:
            self.record_timeout(
                filename,
                "Unable to connect to %s: %s" % (
                    filename,
                    additional_data['status_message']
                )
            )

    def catch_timeout(
            self,
            command,
            line,
            strace_handler,
            additional_data=None
    ):
        ''' Check for connection timeouts.

        When `line` is a poll() call that ended in "(Timeout)", looks up
        the polled descriptor in strace_handler.file_handle and stores a
        pending timeout under that entry's filename.
        '''
        timeout_match = self._POLL_TIMEOUT_RE.search(line)
        if timeout_match is None:
            return

        handle = int(timeout_match.group('handle'))
        filename = strace_handler.file_handle[handle].filename
        self.record_timeout(filename, "Poll timeout on %s" % filename)

    def catch_exit(
            self,
            command,
            line,
            strace_handler,
            additional_data=None
    ):
        ''' Notify if exit is not a success.

        Records "Exited with status <n>" when additional_data carries a
        'status' other than '0'.
        '''
        if not additional_data or 'status' not in additional_data:
            return

        if additional_data['status'] != '0':
            self.record(
                "exit",
                "Exited with status %s" % additional_data['status']
            )


def main():
    ''' Entry point: trace the URL, print results, then list errors.

    In debug mode every strace callback is dumped verbatim; otherwise an
    ErrorHandler inspects writes, reads, polls, connects and the exit.
    '''
    url, raw, debug = parse_args()

    handler = ErrorHandler(5)

    if debug:
        callbacks = {'default': dump_callback}
    else:
        callbacks = {
            'write': handler.check_for_written_error,
            'read': handler.check_for_read_error,
            'recvfrom': handler.check_for_read_error,
            'poll': handler.catch_timeout,
            'connect': handler.check_for_connect_errors,
            'end': handler.catch_exit,
        }

    for item in strace_url(url, callbacks, raw):
        if raw:
            # Raw mode yields strace lines; strip the trailing whitespace.
            print(item.rstrip())
        elif item is None:
            continue
        elif isinstance(item, basestring):
            print(item)
        else:
            # Structured results are pretty-printed as JSON.
            print(json.dumps(
                item,
                sort_keys=True,
                indent=4,
                separators=(',', ': ')
            ))

    # Timeouts that were never cleared by a later success become errors.
    for message in handler.get_remaining_timeouts():
        handler.record("timeout", message)

    handler.list()

# Run the scan only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
