#!/usr/bin/python3

from __future__ import print_function

import sys, socket, os, random, time, threading, getopt

# Usage text printed when the script is started without arguments.
HELP_TEXT='''\
Usage: smtptest.py [-c count] [-t threads] [-h hostname] [-p port] [-f prefix]
                   [--random-ip] [-s filename] [-r email] [-q] [-x] [-n]
                   [--warning sec] [--critical sec] [-v]
  -c count        number of tests per thread (default 1)
  -t threads      number of threads run in parallel (default 1)
  -h hostname     hostname of SMTP or policy server (default localhost)
  --random-ip     SMTP modes: use random IPs from the -f prefix as hostname
  -p port         TCP port of SMTP (default 25) or policy (default 29) server
  -f prefix       IP prefix, by default 127.0
  -s filename     run in SMTP test mode.
                  Send filename as DATA content (mbox or classic file format)
  -r email        email recipient; without -h the MX host of its domain
                  is used (needs pydns or py3dns)
  -q              short/quick SMTP test mode (no DATA)
  -x              hide tests that ended OK (SMTP: OK/450, policy: dunno)
                  and took less than 5 seconds
  -n              nagios mode
  --warning sec   warning time for nagios mode (default 5)
  --critical sec  critical time for nagios mode (default 8)
  -v              verbose mode

Environment:
  MAILFROM                sender address (default foo@bar.tld)
  RCPTTO                  recipient address (default root@localhost)
  DEFAULT_SOCKET_TIMEOUT  socket timeout in seconds (default 30)

Examples:
  Policy test:
    smtptest.py -c 1
  SMTP test with virus file:
    smtptest.py -s ./virus -r email@domain.tld
  Nagios mode:
    smtptest.py -c 1 -n
'''

if not sys.argv[1:]:
  # Started without any option: show the usage text instead of testing.
  print(HELP_TEXT)
  sys.exit()

TIMEOUT = int(os.environ.get('DEFAULT_SOCKET_TIMEOUT', 30))
SENDER = os.environ.get('MAILFROM', 'foo@bar.tld').encode()
RECIPIENT = os.environ.get('RCPTTO', 'root@localhost').encode()

SHORT_SMTP_TEMPLATE = b'''\
HELO '''+socket.gethostname().encode()+b'''
XFORWARD addr=%s
MAIL FROM:<'''+SENDER+b'''>
RCPT TO:<'''+RECIPIENT+b'''>
QUIT
'''

LONG_SMTP_TEMPLATE = b'''\
HELO '''+socket.gethostname().encode()+b'''
XFORWARD addr=%s
MAIL FROM:<'''+SENDER+b'''>
RCPT TO:<'''+RECIPIENT+b'''>
DATA
%s
.
QUIT
'''

POLICY_TEMPLATE = b'''\
request=smtpd_access_policy
protocol_state=RCPT
protocol_name=SMTP
helo_name=some.domain.tld
queue_id=8045F2AB23
sender='''+SENDER+b'''
recipient=+'''+RECIPIENT+b'''
recipient_count=1
client_address=%s
client_name=another.domain.tld
sasl_username='''+os.environ.get("USER", "root").encode()+b'''
reverse_client_name=another.domain.tld
instance=123.456.7

'''

# Module-wide generator shared by all workers.
rng = random.Random()
rng.seed()


def rnd(max):
    """Return a pseudo-random int n with 0 <= n < max (0 when max is 0)."""
    scaled = max * rng.random()
    return int(scaled)

def random_ip(self):
    """Return a random dotted IPv4 address (bytes) that starts with PREFIX.

    Octets missing from the global PREFIX are filled with random values
    0-255 until there are four.  A PREFIX that already has four or more
    parts is returned unchanged.  Trailing dots in PREFIX (e.g. "10.0.")
    are ignored, and an empty PREFIX yields a fully random address.
    ``self`` is unused: this function is also attached to the worker
    classes as their ``get_ip`` method.
    """
    prefix = PREFIX.rstrip('.')
    octets = prefix.split('.') if prefix else []
    while len(octets) < 4:
        octets.append(str(rnd(256)))
    return '.'.join(octets).encode()

class shortsmtpthread(threading.Thread):
  """Worker that runs `counts` short SMTP sessions (HELO .. RCPT TO, QUIT).

  The main script assigns ``get_ip`` (returns the host to connect to) on
  the class and ``ID`` on each instance before the worker is run.  After
  run(), the caller reads ``summary`` (status -> count), ``min_time``,
  ``max_time`` and ``sum_time``.
  """

  def get_template(self, ip):
    """Return the SMTP dialogue to send; *ip* (bytes) fills XFORWARD addr."""
    if ip == b"localhost":
      ip = b"127.0.0.1"
    return SHORT_SMTP_TEMPLATE % ip

  def test(self):
    """Run one SMTP session and return ``(ip, status, status_line)``.

    ``status`` is 'OK' unless the server sent a reply line starting with
    '4'.  In that case it is that line's first three characters (the last
    such line wins), and ``status_line`` is the line itself.  At most 20
    reply lines are read.  Reading stops early at the '221' QUIT reply or
    at end of stream.

    Raises socket.error (OSError) on connect/send/receive failure.  The
    socket is closed in every case.
    """
    ip = self.get_ip()
    # random_ip() returns bytes while the fixed-host getter returns str:
    # work with str here and encode only for the template.
    if isinstance(ip, bytes):
      ip = ip.decode('ascii')
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    f = None
    try:
      s.settimeout(TIMEOUT)
      if not quiet_mode:
        print("Connecting to: %s:%s" % (ip, port))
      s.connect((ip, port))
      if not quiet_mode:
        print("Sending...")
      s.sendall(self.get_template(ip=ip.encode()))
      # errors='replace': an undecodable byte in a reply must not kill
      # the worker thread.
      f = s.makefile('r', errors='replace')
      status = 'OK'
      status_line = ''
      for _ in range(20):
        line = f.readline()
        if not line or line.startswith('221'):
          break
        # NOTE(review): only 4xx replies are recorded; a 5xx rejection
        # still reports 'OK' -- confirm this is intended.
        if line.startswith('4'):
          status = line[:3]
          status_line = line
    finally:
      # A makefile() object keeps the descriptor open until it is closed
      # as well, so close both the reader and the socket.
      if f is not None:
        f.close()
      try:
        s.shutdown(socket.SHUT_RDWR)
      except socket.error:
        pass
      s.close()
    return ip, status, status_line

  def run(self):
    """Run `counts` tests, collecting per-status counts and timing stats."""
    self.summary = {}
    self.min_time = 9999
    self.max_time = 0
    self.sum_time = 0
    t0 = time.time()
    for i in range(counts):
      t1 = time.time()
      try:
        ip, status, line = self.test()
      except socket.error as e:
        # Failed attempts are reported but not counted in summary/timings.
        if not quiet_mode:
          print("ERROR: %s" % e)
        continue
      self.summary[status] = self.summary.get(status, 0) + 1
      test_time = time.time()-t1
      self.min_time = min(self.min_time, test_time)
      self.max_time = max(self.max_time, test_time)
      self.sum_time += test_time
      if exclude_ok and status in ('OK', '450') and test_time<5:
        # Uninteresting result: only refresh a progress counter.
        sys.stdout.write("%d          \r" % (i+1))
        sys.stdout.flush()
        continue
      if not quiet_mode:
        print("%s: %15s, %3d:%-5d, %6.2f/s, %6.2fs,  status: %s" \
              % (time.strftime("%b %e %H:%M:%S"), ip,
                 self.ID, i+1, (i+1)/(time.time()-t0), test_time, status))
      if verbose_mode and line and status not in ('OK', '450'):
        if not quiet_mode:
          print(line, end=' ')

class longsmtpthread(shortsmtpthread):
  """SMTP worker that also sends DATA_CONTENT; test()/run() are inherited."""

  def get_template(self, ip):
      """Return the full dialogue; *ip* (bytes) fills XFORWARD addr."""
      xforward_addr = b"127.0.0.1" if ip == b"localhost" else ip
      return LONG_SMTP_TEMPLATE % (xforward_addr, DATA_CONTENT)

class policythread(threading.Thread):
  """Worker that sends `counts` policy requests over one persistent connection.

  Each request carries a random client_address from random_ip().  The
  connection is opened in the constructor, so a connection failure raises
  there.  ``ID`` is assigned by the main script before the worker is run.
  After run(), the caller reads ``summary``, ``min_time``, ``max_time`` and
  ``sum_time``.
  """

  def __init__(self):
    threading.Thread.__init__(self)
    self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self.s.settimeout(TIMEOUT)
    try:
      self.s.connect((hostname, port))
    except socket.error:
      self.s.close()
      raise
    # errors='replace': an undecodable reply byte must not kill the thread.
    self.f = self.s.makefile('r', errors='replace')

  def run(self):
    """Send `counts` requests, collecting per-action counts and timings.

    The connection is closed when the loop ends, including on error.
    """
    self.summary = {}
    self.min_time = 9999
    self.max_time = 0
    self.sum_time = 0
    t0 = time.time()
    try:
      for i in range(counts):
        t1 = time.time()
        ip = random_ip(self)
        self.s.sendall(POLICY_TEMPLATE % ip)
        reply = self.f.readline()
        if not reply:
          # End of stream: the server closed the connection.
          if not quiet_mode:
            print("ERROR: connection closed by server")
          break
        line = reply.rstrip('\n')
        # A reply is one "action=..." line followed by an empty line.
        if self.f.readline().strip():
          if not quiet_mode:
            print("ERROR")
        # Keep the first two words after "action=",
        # e.g. "action=REJECT 5.7.1 Denied" -> "REJECT 5.7.1".
        status = ' '.join(line[7:].split(' ', 2)[:2]).rstrip(': ')
        self.summary[status] = self.summary.get(status, 0) + 1
        test_time = time.time()-t1
        self.min_time = min(self.min_time, test_time)
        self.max_time = max(self.max_time, test_time)
        self.sum_time += test_time
        # Compare case-insensitively so both "DUNNO" and "dunno" replies
        # count as uninteresting.
        if exclude_ok and status.lower()=='dunno' and test_time<5:
          sys.stdout.write("%d          \r" % (i+1))
          sys.stdout.flush()
          continue
        if not quiet_mode:
          print("%15s: %3d:%-5d, %8.2f/s,  status: %s" \
                % (ip.decode(), self.ID, i+1, (i+1)/(time.time()-t0), status))
        if verbose_mode:
          if not quiet_mode:
            # The newline was stripped above; print() adds it back.
            print(line)
    finally:
      self.f.close()
      self.s.close()

# Parse the command line.  Every option consulted below must be declared
# here; otherwise gnu_getopt rejects it.
try:
  opts, files = getopt.gnu_getopt(sys.argv[1:], 'h:s:r:qxvc:t:p:f:n',
    ['help', 'count=', 'threads=', 'hostname=', 'port=', 'prefix=',
     'smtp=', 'recipient=', 'quick', 'exclude-ok', 'nagios', 'random-ip',
     'verbose', 'warning=', 'critical='])
except getopt.GetoptError as err:
  print("Error:", err.msg)
  sys.exit(1)
if any(opt == '--help' for opt, _ in opts):
  print(HELP_TEXT)
  sys.exit()
if files:
  print("Unknown parameter(s):", ' '.join(files))
  sys.exit(1)

opts = dict(opts)


def _option(short_name, long_name):
    """Value of the short spelling, else the long one (falsy falls through)."""
    return opts.get(short_name) or opts.get(long_name)


def _flag(short_name, long_name):
    """True when either spelling of a no-argument option was given."""
    return short_name in opts or long_name in opts


counts = int(_option('-c', '--count') or 1)
threads = int(_option('-t', '--threads') or 1)
hostname = _option('-h', '--hostname') or 'localhost'
port = _option('-p', '--port')
recipient = _option('-r', '--recipient') or ('root@%s' % hostname)
PREFIX = _option('-f', '--prefix') or '127.0'
exclude_ok = _flag('-x', '--exclude-ok')
verbose_mode = _flag('-v', '--verbose')
nagios_mode = _flag('-n', '--nagios')
# Nagios mode prints only the final status line.
quiet_mode = nagios_mode

if ('-s' in opts) or ('--smtp' in opts):
  # Full SMTP mode: send the given file as the message body.
  thread_type = longsmtpthread
  port = int(port or 25)
  with open(opts.get('-s') or opts.get('--smtp'), 'rb') as f:
    if not f.readline().startswith(b'From '):
      # Not mbox: the first line belongs to the message, so rewind.
      f.seek(0)
    DATA_CONTENT = f.read()
  # SMTP transparency (RFC 5321, section 4.5.2): prefix one more '.' to every
  # body line that starts with '.'.  Otherwise a line consisting of "."
  # ends DATA early and the server strips the body's own leading dots.
  DATA_CONTENT = b'\n'.join(
    b'.' + line if line.startswith(b'.') else line
    for line in DATA_CONTENT.split(b'\n'))
elif ('-q' in opts) or ('--quick' in opts):
  # Quick mode: SMTP dialogue without DATA.
  thread_type = shortsmtpthread
  port = int(port or 25)
else:
  # Default: policy delegation test.
  thread_type = policythread
  port = int(port or 29)

if ('-r' in opts) or ('--recipient' in opts):
  # The templates were built at import time from RECIPIENT (env RCPTTO).
  # Rewrite their recipient fields so that -r sets the recipient sent.
  import re
  RECIPIENT = recipient.encode()
  _rcpt = RECIPIENT.replace(b'%', b'%%')  # templates are %-formatted later
  # Callable replacements keep any backslash in the address literal.
  SHORT_SMTP_TEMPLATE = re.sub(rb'RCPT TO:<[^>\r\n]*>',
                               lambda m: b'RCPT TO:<' + _rcpt + b'>',
                               SHORT_SMTP_TEMPLATE)
  LONG_SMTP_TEMPLATE = re.sub(rb'RCPT TO:<[^>\r\n]*>',
                              lambda m: b'RCPT TO:<' + _rcpt + b'>',
                              LONG_SMTP_TEMPLATE)
  POLICY_TEMPLATE = re.sub(rb'(?m)^recipient=[^\r\n]*',
                           lambda m: b'recipient=' + _rcpt,
                           POLICY_TEMPLATE)
  if not ('-h' in opts) and not ('--hostname' in opts):
    # No explicit server: connect to the recipient domain's MX host.
    try:
      import DNS
    except ImportError:
      print('You need python-pydns or py3dns to use -r without specifying hostname (-h).')
      print('Please install pydns or add "-h hostname" parameter.')
      sys.exit(1)
    if '@' not in recipient:
      print('Recipient %r has no domain part; please add "-h hostname" parameter.' % recipient)
      sys.exit(1)
    domain = recipient.rpartition('@')[2]
    DNS.DiscoverNameServers()
    mx_records = DNS.mxlookup(domain)
    # Lowest preference wins.  Without any MX record, fall back to the
    # domain itself (RFC 5321 implicit MX).
    hostname = min(mx_records)[1] if mx_records else domain

def _configured_host(self):
    """get_ip() without --random-ip: the host chosen by -h or the MX lookup."""
    return hostname


# Only the SMTP workers call get_ip(); policythread always uses random_ip().
thread_type.get_ip = random_ip if '--random-ip' in opts else _configured_host

# Start the workers, wait for them, and merge their results.
t0 = time.time()
min_time = 9999
max_time = 0
sum_time = 0
summary = {}
processes = []
for p in range(threads):
  worker = thread_type()
  worker.ID = p+1
  processes.append(worker)
  if threads>1:
    worker.start()
  else:
    # Single worker: run it inline, no thread needed.
    worker.run()
for worker in processes:
  if threads>1:
    worker.join()
  for key, value in worker.summary.items():
    summary[key] = summary.get(key, 0) + value
  # Timing stats are per worker, so merge them once per worker (not once
  # per status key, which would count sum_time several times).
  min_time = min(min_time, worker.min_time)
  max_time = max(max_time, worker.max_time)
  sum_time += worker.sum_time
run_time = time.time()-t0

# Nagios plugin output: one status line, and exit code 0=OK, 1=WARNING,
# 2=CRITICAL, chosen by total run time.
if nagios_mode:
  warn = float(opts.get('--warning', '5'))
  crit = float(opts.get('--critical', '8'))
  if not summary:
    # Nothing completed (e.g. every SMTP connection failed and was
    # skipped).  A fast run without results must not be reported as OK.
    txt = 'no test completed'
    state, code = 'CRITICAL', 2
  else:
    txt = ','.join(summary)
    if run_time<warn:
      state, code = 'OK', 0
    elif run_time<crit:
      state, code = 'WARNING', 1
    else:
      state, code = 'CRITICAL', 2
  print("SMTP_POLICY %s - %s in %d ms|time=%s;%s;%s;0" \
        % (state, txt, run_time*1000, run_time, warn, crit))
  sys.exit(code)

print("Summary:")
attempts = threads*counts
for key in summary:
  # Share of all attempted tests (failed SMTP attempts have no entry).
  print("%40s: %5d = %10.6f%%" \
        % (key, summary[key], summary[key]*100.0/attempts))

print("Run time: %d:%02d:%02d" % (run_time/3600, run_time/60%60, run_time%60))
# Only completed tests contribute to the timings, so average over those.
# This also avoids dividing by zero for "-c 0" or "-t 0".
completed = sum(summary.values())
if completed:
  print("Min./max. time: %4.2f/%4.2f" % (min_time, max_time))
  print("Avg. time: %4.2f" % (sum_time/completed))
else:
  print("No test completed.")
