#!/usr/bin/python3

'''
igrep, Internet sniffer.

(c) 2015 Jan ONDREJ (SAL) <ondrejj(at)salstar.sk>

 This program is free software; you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation; either version 2 of the License, or
 (at your option) any later version.

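Usage: igrep.py [options] [pcap_filter]

  --help | -h		This help.
  -i interface		Set sniffing interface to "interface" (default: any).
  -f filename		Read pcap dump from filename.
  -e expr		Show only TCP/UDP packets whose payload matches regex "expr" (ICMP is always shown).
  -s size		Truncate each printed packet to "size" characters.
  -c count		Print "count" packets, then exit.
  -p 			Put interface into promiscuous mode.

  Long forms: --interface, --filename, --expression, --size, --count,
  --promiscuous.  pcap_filter is an optional BPF filter expression,
  e.g. "tcp port 80".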
'''

import sys, re, time, pcapy, getopt, string, array, struct

class unknown:
  '''Fallback decoder for a transport protocol without a dedicated class.

  With data_offset 0 the whole payload counts as data; subclasses set their
  own type and data_offset.
  '''
  type = 'unknown'
  data_offset = 0

  def __init__(self, payload):
      self.payload = payload

  def data(self):
      '''Return the part of the payload that follows the header.'''
      start = self.data_offset
      return self.payload[start:]

class tcp(unknown):
  '''TCP segment decoder; also the base class of udp and icmp.

  payload starts at the TCP header (bytes on Python 3, str on Python 2).
  '''
  # https://en.wikipedia.org/wiki/Transmission_Control_Protocol
  type = 'tcp'
  # Flag names in bit order: bits 0-7 of header byte 13 (FIN..CWR), then
  # NS, which lives in bit 0 of header byte 12.
  all_flags = 'FIN SYN RST PSH ACK URG ECE CWR NS'.split(" ")
  def __init__(self, payload):
      self.payload = payload
      self.src = struct.unpack(">H", self.payload[0:2])[0]
      self.dst = struct.unpack(">H", self.payload[2:4])[0]
      if len(self.payload)>12:
        # Data offset is the high nibble of byte 12 counted in 32-bit
        # words; (x & 0xf0) >> 2 == (x >> 4) * 4 converts it to bytes.
        # bytearray() indexes to an int on Python 2 and 3 alike, whereas
        # ord() rejects the int a Python 3 bytes index already returns.
        offset = (bytearray(self.payload[12:13])[0] & 0xf0) >> 2
        # A corrupt offset must not point past the captured data.
        self.data_offset = min(offset, len(self.payload))
      else:
        self.data_offset = len(self.payload)
      # Number of payload bytes after the header.
      self.len = len(self.payload)-self.data_offset
  def get_flags(self):
      '''Yield the names of the TCP flags set in the header, FIN first.

      Yields nothing when the captured header is too short to hold the
      flag byte.
      '''
      if len(self.payload) < 14:
        return
      header = bytearray(self.payload[12:14])
      # 9-bit flag field: byte 13 holds FIN..CWR, bit 0 of byte 12 is NS.
      flags = header[1] | ((header[0] & 0x01) << 8)
      for flag in self.all_flags:
        if flags & 1:
          yield flag
        flags >>= 1
  def as_string(self, dg):
      '''Yield a "src:port -> dst:port; FLAGS" line, then the payload text.

      dg is the enclosing IP datagram decoder (supplies src/dst addresses).
      '''
      yield "%s:%d -> %s:%d; %s\n" % (
        dg.src, self.src, dg.dst, self.dst,
        ','.join(self.get_flags())
      )
      yield printable(self.data())

class udp(tcp):
  '''UDP datagram decoder; reuses tcp's as_string() output format.'''
  type = 'udp'
  data_offset = 8   # UDP header length is fixed at 8 bytes
  def __init__(self, payload):
      # Do not call tcp.__init__: it derives data_offset from byte 12,
      # which in a UDP datagram is already payload data.
      self.payload = payload
      self.src = struct.unpack(">H", payload[0:2])[0]
      self.dst = struct.unpack(">H", payload[2:4])[0]
      # Clamp so a truncated datagram yields len 0 instead of a negative.
      self.data_offset = min(8, len(payload))
      self.len = len(payload) - self.data_offset
  def get_flags(self):
      '''UDP has no flags; yield nothing (tcp's version would decode data).'''
      return iter(())

class icmp(tcp):
  '''ICMP (IPv4) message decoder.

  Overrides tcp's __init__ and as_string; data() returns the bytes after
  the 8-byte ICMP header.
  '''
  # https://en.wikipedia.org/wiki/Internet_Control_Message_Protocol
  type = 'icmp'
  types = {
    0:	"Echo reply",
    3:	"Destination Unreachable",
    8:	"Echo Request"
  }
  codes = {
    3: {
      0:	"Destination network unreachable",
      1:	"Destination host unreachable",
      2:	"Destination protocol unreachable",
      3:	"Destination port unreachable",
      4:	"Fragmentation required, and DF flag set",
      5:	"Source route failed",
      6:	"Destination network unknown",
      7:	"Destination host unknown",
      8:	"Source host isolated",
      9:	"Network administratively prohibited",
      10:	"Host administratively prohibited",
      11:	"Network unreachable for TOS",
      12:	"Host unreachable for TOS",
      13:	"Communication administratively prohibited",
      14:	"Host Precedence Violation",
      15:	"Precedence cutoff in effect"
    }
  }
  def __init__(self, payload):
      self.payload = payload
      # bytearray() indexes to an int on Python 2 and 3 (ord() fails on
      # the int a Python 3 bytes index returns).
      header = bytearray(payload[0:2])
      self.protocol_type = header[0]
      self.protocol_code = header[1]
      if self.protocol_type==3:
        # For Destination Unreachable, the bytes after the 8-byte ICMP
        # header are decoded as the quoted IPv4 datagram. Despite its
        # name, self.tcp holds that ipv4 object (or None).
        self.tcp = ipv4(payload[8:])
      else:
        self.tcp = None
      self.data_offset = 8
  def as_string(self, dg):
      '''Yield text lines describing this ICMP message carried in dg.'''
      yield "%s -> %s\n" % (dg.src, dg.dst)
      if self.protocol_type in self.codes:
        yield "ICMP: %s: %s\n" % (
          self.types.get(self.protocol_type, self.protocol_type),
          self.codes[self.protocol_type].get(
            self.protocol_code, self.protocol_code)
        )
      else:
        yield "ICMP: %s\n" \
              % self.types.get(self.protocol_type, self.protocol_type)
      if self.tcp is not None:
        try:
          inner = self.tcp.data()
        except (IndexError, struct.error):
          return  # quoted transport header too short to decode
        # Only the tcp/udp decoders expose src/dst ports.
        if inner.type in ('tcp', 'udp'):
          yield "ICMP connection: %s: %s:%d -> %s:%d\n" % (
            inner.type.upper(),
            self.tcp.src, inner.src, self.tcp.dst, inner.dst
          )

class ipv4:
  '''IPv4 datagram decoder.

  payload starts at the IPv4 header (bytes on Python 3, str on Python 2).
  '''
  # IP protocol number -> transport decoder class
  protocols = {
    1: icmp,
    6: tcp,
    17: udp
  }
  def __init__(self, payload):
      self.payload = payload
      self._src = array.array("B", payload[12:16])
      self._dst = array.array("B", payload[16:20])
      self.src = self.ip2str(self._src)
      self.dst = self.ip2str(self._dst)
      # bytearray() indexes to an int on Python 2 (str) and Python 3
      # (bytes); ord() raises TypeError on a Python 3 bytes item.
      # A short header still raises IndexError, as before.
      octets = bytearray(payload[:10])
      # IHL is the low nibble of byte 0 in 32-bit words; << 2 gives bytes.
      self.ihl = (octets[0] & 0x0f) << 2
      self.protocol = octets[9]
  def ip2str(self, addr):
      '''Format an iterable of octets as dotted-decimal text.'''
      return '.'.join([str(x) for x in addr])
  def data(self):
      '''Decode the transport payload; unlisted protocols map to unknown.'''
      protocol = self.protocols.get(self.protocol, unknown)
      return protocol(self.payload[self.ihl:])

class ipv6(ipv4):
  '''IPv6 datagram decoder for the fixed 40-byte header.

  Reuses ipv4's protocol table and data(). Extension headers are not
  followed: the Next Header value is looked up directly in protocols.
  '''
  def __init__(self, payload):
      self.payload = payload
      self.ihl = 40   # fixed IPv6 header length; used by the inherited data()
      # Next Header field. bytearray() indexes to an int on Python 2 and 3.
      self.protocol = bytearray(payload[6:7])[0]
      self._src = struct.unpack(">HHHHHHHH", payload[8:24])
      self._dst = struct.unpack(">HHHHHHHH", payload[24:40])
      self.src = self.ip2str(self._src)
      self.dst = self.ip2str(self._dst)
  def ip2str(self, addr):
      '''Format eight 16-bit groups as a bracketed, uncompressed address.'''
      return '['+':'.join(["%04x" % x for x in addr])+']'

def si(key, unit='B', delimeter=' '):
    '''Format key with a 1024-based size prefix, e.g. "   2 kB" for 2048.

    The scaled value is truncated to an integer and right-aligned in four
    columns, followed by delimeter, a one-column prefix (blank below
    1024) and unit. Scaling stops at 'Y'; larger values just widen the
    numeric field.
    '''
    prefixes = ['', 'k', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y']
    counter = 0
    # The loop bound alone keeps counter a valid index into prefixes.
    while key >= 1024 and counter < len(prefixes) - 1:
      key /= 1024
      counter += 1
    return "%4d%s%1s%s" % (key, delimeter, prefixes[counter], unit)

def printable(s):
    '''Return packet data as readable text for display.

    s may be str, or bytes/bytearray (what pcapy returns on Python 3).
    Unprintable characters keep their repr()-style escapes, except that
    escaped CR LF and LF become real newlines, TAB a real tab, and an
    escaped single quote a plain one.
    '''
    if isinstance(s, bytearray):
        s = bytes(s)
    # Prefixing both quote characters forces repr() to wrap the text in
    # single quotes and escape only our leading one, so the escapes inside
    # are predictable. The repr may also start with a b/u marker, so find
    # the end of the prefix by its double quote instead of a fixed index.
    prefix = b'\'"' if isinstance(s, bytes) else '\'"'
    quoted = repr(prefix + s)
    text = quoted[quoted.index('"') + 1:-1]
    return text.replace(
      "\\r\\n", "\n").replace(
      "\\n", "\n").replace(
      "\\t", "\t").replace(
      "\\'", "'")

def format_date(ts):
    '''Return the local calendar date of a pcap timestamp as YYYY-MM-DD.

    ts is indexable; only ts[0] (whole seconds) is used.
    '''
    seconds = ts[0]
    local = time.localtime(seconds)
    return time.strftime("%Y-%m-%d", local)

def format_time(ts):
    '''Return the local time of a pcap timestamp as HH:MM:SS.uuuuuu.

    ts[0] is whole seconds and ts[1] the microsecond part.
    '''
    seconds = ts[0]
    micros = ts[1]
    # The zero-padded microseconds are baked into the strftime pattern.
    pattern = "%H:%M:%S" + ".%06d" % micros
    return time.strftime(pattern, time.localtime(seconds))

def stats(arr, max_count=5):
    '''Summarise a {key: count} dict as "key:countx, ..." for top entries.

    Entries are ordered by descending count (ties keep dict order) and at
    most max_count of them are shown.
    '''
    ranked = sorted(arr.items(), key=lambda pair: pair[1], reverse=True)
    labels = ["%d:%dx" % pair for pair in ranked]
    return ', '.join(labels[:max_count])

if __name__ == "__main__":
  opts, args = getopt.gnu_getopt(sys.argv[1:], 'hi:f:e:s:c:p', [
    'help', 'interface=', 'filename=', 'expression=',
    'size=', 'count=', 'promiscuous'
  ])
  opts = dict(opts)
  if "--help" in opts or "-h" in opts:
    print(__doc__)
    sys.exit(0)

  # Every option has a long and a short form; the long form wins.
  interface = opts.get('--interface', opts.get('-i', 'any'))
  filename = opts.get('--filename', opts.get('-f'))
  promisc = 1 if ("--promiscuous" in opts or "-p" in opts) else 0

  if filename:
    cap = pcapy.open_offline(filename)
  else:
    cap = pcapy.open_live(interface, 1500, promisc, 100)
    print("Capture interface %s: %s/%s"
          % (interface, cap.getnet(), cap.getmask()))

  expression = opts.get('--expression', opts.get('-e'))
  if expression is not None:
    pattern = re.compile(expression)
    def grep(x):
      # pcapy returns bytes on Python 3; latin-1 maps every byte to
      # exactly one character, so the text pattern can scan raw data.
      if isinstance(x, bytes) and not isinstance(x, str):
        x = x.decode('latin-1')
      return pattern.search(x)
  else:
    def grep(x):
        return True

  size = opts.get('--size', opts.get('-s'))
  max_size = None if size is None else int(size)
  count = opts.get('--count', opts.get('-c'))
  packet_count = None if count is None else int(count)

  # All non-option arguments form the BPF filter; a single word such as
  # "icmp" is already a complete filter.
  if args:
    ip_filter = ' '.join(args)
    print("Filter: %s" % ip_filter)
    cap.setfilter(ip_filter)

  # Bytes preceding the IP header for each libpcap link-layer type (DLT_*).
  # Unlisted types are assumed to be Ethernet, as before.
  link_offsets = {
    0: 4,      # DLT_NULL: BSD loopback, 4-byte address family
    1: 14,     # DLT_EN10MB: Ethernet
    12: 0,     # DLT_RAW: raw IP
    101: 0,    # LINKTYPE_RAW
    108: 4,    # DLT_LOOP: OpenBSD loopback
    113: 16,   # DLT_LINUX_SLL: Linux cooked capture (the "any" device)
    276: 20,   # DLT_LINUX_SLL2
  }
  link_offset = link_offsets.get(cap.datalink(), 14)

  last_date = None
  packets = 0
  addrs = {}
  ports = {}
  try:
    # Stop as soon as the requested number of packets has been printed.
    while packet_count is None or packets < packet_count:
      try:
        header, payload = cap.next()
      except pcapy.PcapError as err:
        if filename:
          break
        if str(err):
          print('PcapError:', str(err))
        continue
      if header is None:
        # Depending on the pcapy version, a live-capture timeout or the
        # end of an offline file may arrive as (None, empty) instead of
        # raising PcapError.
        if filename:
          break
        continue
      ts = header.getts()
      date = format_date(ts)
      if date != last_date:
        print("DATE:", date)
        last_date = date
      out = [format_time(ts) + ": "]
      if len(payload) <= link_offset:
        continue  # truncated frame: no network-layer byte captured
      # High nibble of the first network-layer byte is the IP version.
      # bytearray() indexes to an int for Python 2 str and Python 3 bytes.
      ipv = bytearray(payload[link_offset:link_offset + 1])[0] >> 4
      if ipv == 4:
        decoder = ipv4
      elif ipv == 6:
        decoder = ipv6
      elif ipv in (0, 8):
        continue  # not IP (presumably e.g. ARP); skipped silently as before
      else:
        print("UNKNOWN PROTOCOL:", ipv)
        continue
      try:
        dg = decoder(payload[link_offset:])
        proto = dg.data()
        if proto.type == "tcp" or proto.type == "udp":
          data = proto.data()
          if grep(data):
            out.extend(proto.as_string(dg))
            addrs[dg.src] = 1
            addrs[dg.dst] = 2
            ports[proto.src] = ports.get(proto.src, 0) + 1
            ports[proto.dst] = ports.get(proto.dst, 0) + 1
        elif proto.type == "icmp":
          out.extend(proto.as_string(dg))
      except (IndexError, struct.error) as err:
        # Truncated or malformed headers must not stop the sniffer.
        print("DECODE ERROR:", str(err))
        continue
      if len(out) > 1:
        packets += 1
        text = ''.join(out)
        print(text if max_size is None else text[:max_size])
  except KeyboardInterrupt:
    print("")
  # Summary is printed on Ctrl-C, on end of file and when -c is reached.
  print("Packets:", packets)
  print("IP addresses:", len(addrs))
  print("Ports: %s [%s]" % (len(ports), stats(ports)))
