#! /opt/imh-python/bin/python
''' Watch live Apache domlogs gathering data about connections '''

# Author: Daniel K

from sh import tail
from sh import ErrorReturnCode
import errno
import sys
import glob
import os
import re
from collections import defaultdict
from datetime import datetime
import time
from rads.common import colors
import curses

def colorize(text, color, end='none'):
    """Return text wrapped between the codes for `color` and for `end`.

    Codes are looked up in the mapping returned by colors(); `end`
    defaults to the 'none' entry.
    """
    prefix = colors()[color]
    suffix = colors()[end]
    return '%s%s%s' % (prefix, text, suffix)

def print_hit(hit):
    ''' Placeholder for displaying a single log entry.

    Output is currently disabled, so this ignores `hit` and returns None.
    '''
    return None

class Rate(object):
    ''' Estimate hits per second from three staggered counting windows

    Each window counts hits for `interval` seconds and restarts on the
    first hit arriving after that span.  reset() staggers the initial
    window start times by int(interval / 3) seconds so the three windows
    do not all restart together.  `avg` holds the mean of the three
    windows' hits-per-second.
    '''

    # Window length in seconds
    interval = 3

    # Start time of each window, in seconds since self.epoch
    time_1_last = 0
    time_2_last = 1
    time_3_last = 3
    start_time = None

    # Hits counted in each window, and in total since construction
    time_1_hits = 0
    time_2_hits = 0
    time_3_hits = 0
    hits_total = 0

    # Most recent value computed by average()
    avg = 0

    # Reference datetime used to convert "now" into seconds
    epoch = None

    # Offset between the windows' initial start times, in whole seconds
    distance = 0

    def __init__(self, interval=3):
        ''' Start counting with windows of `interval` seconds '''

        self.interval = interval
        self.distance = int(interval / 3)
        self.epoch = datetime(1970, 1, 1)

        self.reset()
        self.start_time = datetime.now()

    def _seconds_now(self):
        ''' Seconds between self.epoch and the current local time '''
        return (datetime.now() - self.epoch).total_seconds()

    def reset(self):
        ''' Restart the three windows with staggered start times '''

        now = self._seconds_now()

        self.time_1_last = now
        self.time_2_last = now + self.distance
        self.time_3_last = now + (self.distance * 2)

        # Clear the per-window counts too, so a reset really starts fresh.
        self.time_1_hits = 0
        self.time_2_hits = 0
        self.time_3_hits = 0

    def add_occurence(self):
        ''' Record one occurrence now and refresh self.avg '''

        now = self._seconds_now()

        # A window whose span has elapsed restarts at this hit.
        if now > (self.time_1_last + self.interval):
            self.time_1_last = now
            self.time_1_hits = 0

        if now > (self.time_2_last + self.interval):
            self.time_2_last = now
            self.time_2_hits = 0

        if now > (self.time_3_last + self.interval):
            self.time_3_last = now
            self.time_3_hits = 0

        self.time_1_hits += 1
        self.time_2_hits += 1
        self.time_3_hits += 1
        self.hits_total += 1

        self.average()

    def average(self):
        ''' Compute, store in self.avg and return mean hits per second '''

        # float() keeps Python 2 from truncating the rates with int division.
        interval = float(self.interval)

        hit_avg_1 = self.time_1_hits / interval
        hit_avg_2 = self.time_2_hits / interval
        hit_avg_3 = self.time_3_hits / interval

        self.avg = (hit_avg_1 + hit_avg_2 + hit_avg_3) / 3.0
        return self.avg

    def total(self):
        ''' Overall occurrences per second since construction.

        Returns 0.0 if no measurable time has elapsed yet.
        '''
        elapsed = (datetime.now() - self.start_time).total_seconds()
        if elapsed <= 0:
            return 0.0
        return self.hits_total / elapsed

class LogEntry(object):
    ''' Individual domlog entry '''

    hostname = ''
    ip = ''
    method = ''
    query = ''
    status = ''
    bytes = 0
    referer = ''
    user_agent = ''

    def __init__(self, hostname, regex_match):
        ''' Populate the entry from a regex match of one log line.

        :param hostname: domain the line belongs to; None is treated as ''.
        :param regex_match: match object whose groups 1-7 are client IP,
            method, request path, status, response size ('-' when no size
            was logged), referer and user agent.
        :raises AssertionError: if regex_match is None (not raised under -O).
        '''

        assert (regex_match is not None), "Log entry did not match"

        # hostname may be None before any "==> file <==" header is seen;
        # use '' rather than failing on None + str.
        self.hostname = hostname or ''

        (ip, method, path, status, size, referer, agent) = \
            regex_match.group(1, 2, 3, 4, 5, 6, 7)

        self.ip = str(ip)
        self.method = str(method)
        self.query = str(self.hostname + path)
        self.status = str(status)
        # '-' (or any other non-numeric size) counts as zero bytes instead
        # of crashing int().
        self.bytes = int(size) if size.isdigit() else 0
        self.referer = str(referer)
        self.user_agent = str(agent)

class IPHistory(object):
    ''' History of activity from one client IP '''

    # Log entries seen from this IP, oldest first (set per instance)
    hit_list = None
    # Sum of the `bytes` field over hit_list
    bytes_read = 0
    # Rate tracker for this IP (created per instance in __init__)
    rate = None

    def __init__(self):
        ''' Start with no hits and a fresh Rate(8) tracker '''
        self.hit_list = []
        self.bytes_read = 0
        # Each IP needs its own Rate: a class-level Rate(8) is shared by
        # every IPHistory, so every IP would report the global hit rate.
        self.rate = Rate(8)

    def add_hit(self, hit):
        ''' Record a log entry and update byte and rate counters '''

        self.bytes_read = self.bytes_read + hit.bytes
        self.hit_list.append(hit)
        self.rate.add_occurence()

    def print_activit(self, limit=5):
        ''' Pass the `limit` most recent hits to print_hit() '''

        recent = self.hit_list[-limit:] if limit > 0 else []
        for hit in recent:
            print_hit(hit)


class LogHandler(object):
    ''' Follow Apache domlogs with `tail -f` and tally hits per IP and query '''

    # Directory scanned for per-domain logs when none is given
    DOMLOG_DIR = '/usr/local/apache/domlogs'

    # Log file basenames to skip: names containing "_log" or ".offset", or
    # starting with "ftp".  Matched against the basename so "^ftp" can
    # actually match (a full path always starts with "/").
    SKIP_RX = re.compile(r"_log|^ftp|\.offset")

    # One request line: ip ident user [time] "METHOD path proto" status
    # size "referer" "user agent" -- seven capture groups used by LogEntry.
    HIT_RX = re.compile(
        r'^([0-9.:]+)'
        r'\s+\S+\s+\S+\s+'
        r'\[[^]]+\]\s+'
        r'"([A-Z]+)\s+(\S+)\s+[^"]*"\s+'
        r'([0-9]+)\s+(\S+)\s+'
        r'"([^"]*)"\s+"(.*?)"$'
    )

    # "==> /path/to/file <==" header printed by tail for multiple files;
    # group 2 is the file's basename.
    HEADER_RX = re.compile(r"(==>[^<]*?)([^/]+) <==")

    # Shell process to poll for new log lines
    proc = None

    # Basename of the log file the incoming lines belong to
    domain = None

    def __init__(self, log_dir=DOMLOG_DIR):
        ''' Start tailing every domain log in `log_dir`.

        :param log_dir: directory holding the domlogs.
        :raises IOError: if no log files are found (tail would otherwise
            fall back to reading standard input).
        '''
        # Per-instance tallies; class-level defaultdicts would be shared by
        # every LogHandler.
        self.ips = defaultdict(IPHistory)
        self.queries = defaultdict(int)

        file_list = [
            path for path in glob.glob(os.path.join(log_dir, '*'))
            if os.path.isfile(path)
            and self.SKIP_RX.search(os.path.basename(path)) is None
        ]
        if not file_list:
            raise IOError(errno.ENOENT, "No domain logs found", log_dir)

        # tail prints no "==> file <==" header when following one file, so
        # the domain would never be set from the stream in that case.
        if len(file_list) == 1:
            self.domain = os.path.basename(file_list[0])

        self.proc = tail(
            "-f",
            file_list,
            _iter_noblock=True
        )

    def __enter__(self):
        ''' Begin using class by passing self '''
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        ''' Stop the tail process; exceptions are not suppressed '''

        if self.proc is not None:
            self.proc.kill()

    def hit_details(self, hit):
        ''' Summarize a recorded hit for display.

        :returns: (ip, hits from that ip, ip rate avg, query,
            hits on that query, status)
        '''
        history = self.ips[hit.ip]
        return (
            hit.ip,
            len(history.hit_list),
            history.rate.avg,
            hit.query,
            self.queries[hit.query],
            hit.status,
        )

    def add_hit(self, hit):
        ''' Record log entry to appropriate lists '''

        self.ips[hit.ip].add_hit(hit)
        self.queries[hit.query] += 1

    def process_line(self, line):
        ''' Parse one request line; return hit_details() or None if no match '''

        hit_match = self.HIT_RX.search(line)
        if hit_match is None:
            return None

        hit = LogEntry(self.domain, hit_match)
        self.add_hit(hit)
        return self.hit_details(hit)

    def poll(self):
        ''' Read one line from tail.

        :returns: hit_details() for a request line, otherwise None (no data
            yet, blank line, file header, or unparseable line).
        '''
        assert (self.proc is not None)

        # next() builtin works for both Python 2 and 3 iterators.
        line = next(self.proc)

        # errno.EWOULDBLOCK is yielded when no new line is available.
        if line == errno.EWOULDBLOCK or line == '\n':
            return None

        domain_match = self.HEADER_RX.search(line)
        if domain_match is not None:
            self.domain = domain_match.group(2)
            return None

        return self.process_line(line)

    def top_ips(self, num=5):
        ''' Yield (ip, hits, bytes_read, rate avg) for the `num` busiest IPs '''

        ranked = sorted(
            self.ips.items(),
            key=lambda item: len(item[1].hit_list),
            reverse=True,
        )
        for ip, history in ranked[:max(num, 0)]:
            yield (ip, len(history.hit_list), history.bytes_read,
                   history.rate.avg)

    def top_queries(self, num=5):
        ''' Yield (query, hits) for the `num` most requested queries '''

        ranked = sorted(self.queries, key=self.queries.get, reverse=True)
        for query in ranked[:max(num, 0)]:
            yield (query, self.queries[query])
        
class Interface(object):
    ''' Curses screen showing top IPs, top queries and recent requests '''

    # Minimum seconds between screen redraws
    refresh_interval = 0.5
    last_refresh = 0

    # Text for the request, IP and query panes
    main_lines = []
    ip_lines = []
    queries_lines = []

    scr = None
    main_win = None
    ip_win = None
    query_win = None

    # Request pane size: text rows and columns
    main_win_y = 5
    main_win_x = 78

    # Width of the query pane
    queries_width = 20

    is_clean = False
    is_init = False

    # Set when pane contents changed since the last redraw
    needs_refresh = False

    def __init__(self, main_main_win_y=20):
        ''' Initialize state and start curses.

        :param main_main_win_y: initial request-pane height; init()
            replaces it with a value derived from the terminal size.
        '''

        self.main_lines = []
        self.ip_lines = []
        self.queries_lines = []
        self.main_win_y = main_main_win_y
        self.main_win_x = 120

        self.init()

    def __enter__(self):
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        ''' Restore the terminal; exceptions are not suppressed '''
        self.clean()

    def __del__(self):
        self.clean()

    def init(self):
        ''' Start curses and create the panes; no-op if already started '''
        if self.is_init:
            return

        self.scr = curses.initscr()

        (max_y, max_x) = self.scr.getmaxyx()

        # The request pane starts at row 24, below the 13-row IP/query
        # panes.  Keep at least one text row so out() always has room.
        self.main_win_y = max(1, max_y - 4 - 25)
        self.main_win_x = max_x - 4

        # NOTE(review): cbreak() after raw() presumably re-enables signal
        # keys so Ctrl-C still raises KeyboardInterrupt -- confirm.
        curses.raw()
        curses.cbreak()
        self.scr.nodelay(1)
        curses.noecho()
        curses.curs_set(0)
        curses.nl()

        self.queries_width = max_x - 40 - 4

        self.main_win = curses.newwin(self.main_win_y + 2, self.main_win_x + 2, 24, 1)
        self.ip_win = curses.newwin(13, 35, 0, 0)
        self.query_win = curses.newwin(13, self.queries_width + 1, 0, 40)

        self.main_win.clear()
        self.scr.clear()
        self.main_win.refresh()
        self.scr.refresh()

        self.is_clean = False
        self.is_init = True

    def clean(self):
        ''' Restore normal terminal modes; safe to call more than once '''
        # Nothing to undo if already cleaned or curses was never started.
        if self.is_clean or self.scr is None:
            return

        curses.nocbreak()
        curses.noraw()
        self.scr.keypad(0)
        curses.echo()
        # Show the cursor again before leaving curses mode.
        curses.curs_set(1)
        curses.endwin()

        self.is_init = False
        self.is_clean = True

    def print_ips(self):
        ''' Redraw the IP pane from self.ip_lines '''
        self.ip_win.clear()
        self.ip_win.addstr(" -- Top IPs --\n")
        # Column order matches main()'s "ip - hits bytes rate" rows.
        self.ip_win.addstr("    IP     Hits Bytes Rate\n")
        for line in self.ip_lines:
            # Truncate to fit the 35-column pane.
            self.ip_win.addstr("%s\n" % line[:34])
        self.ip_win.refresh()

    def print_queries(self):
        ''' Redraw the query pane from self.queries_lines '''
        self.query_win.clear()
        self.query_win.addstr(" -- Queries --\n")
        for line in self.queries_lines:
            self.query_win.addstr("%s\n" % line[:self.queries_width])
        self.query_win.refresh()

    def print_main(self):
        ''' Redraw the request pane from self.main_lines '''
        self.main_win.clear()
        self.main_win.addstr(" -- Current Requests --\n")
        for line in self.main_lines:
            self.main_win.addstr("%s\n" % line[:self.main_win_x - 7])
        self.main_win.refresh()

    def out_ips(self, lines=None):
        ''' Replace the IP pane contents with `lines` '''
        self.ip_lines = lines if lines is not None else []
        self.needs_refresh = True

    def out_queries(self, lines=None):
        ''' Replace the query pane contents with `lines` '''
        self.queries_lines = lines if lines is not None else []
        self.needs_refresh = True

    def out(self, line):
        ''' Append a line to the request pane, keeping only the newest rows '''
        self.main_lines.append(line)
        # Never let the buffer drop to zero lines, even if main_win_y <= 0.
        overflow = len(self.main_lines) - max(self.main_win_y, 1)
        if overflow > 0:
            del self.main_lines[:overflow]
        self.needs_refresh = True

    def print_all(self):
        ''' Redraw all panes if anything changed, at most once per interval '''
        if time.time() < (self.last_refresh + self.refresh_interval):
            return
        if self.needs_refresh:
            self.print_ips()
            self.print_queries()
            self.print_main()
            self.scr.refresh()
        self.needs_refresh = False
        self.last_refresh = time.time()

    def poll(self):
        ''' Return a pressed key code, or redraw and return False if none '''
        ch = self.scr.getch()
        if ch != -1:
            return ch

        self.print_all()
        return False


def main():
    ''' Main function for watch_apache: follow domlogs until 'q' or Ctrl-C '''

    iface = Interface()
    interrupted = False

    try:
        with LogHandler() as logs:
            while True:
                result = logs.poll()
                if result is not None:
                    iface.out(
                        "ip: %s (%d rate: %d)"
                        "    -    "
                        "%s (%d) Status: %s"
                        % result
                    )
                    iface.out_ips([
                        "%s - %d %d %d" % (ip, hits, size, avg)
                        for (ip, hits, size, avg) in logs.top_ips()
                    ])
                    iface.out_queries([
                        "%d: %s" % (count, query)
                        for (query, count) in logs.top_queries()
                    ])
                # Interface.poll() returns a key code, or False for no key.
                ch = iface.poll()
                if ch in (ord('q'), ord('Q')):
                    break
    except KeyboardInterrupt:
        interrupted = True
    finally:
        iface.clean()

    # Print only after curses has released the terminal so the text shows.
    if interrupted:
        print("\nAll done\n")


if __name__ == '__main__':
    # Run the watcher, then exit explicitly with a success status.
    main()
    sys.exit(0)
