#!/usr/bin/python3
# -*- coding: UTF-8 -*-

import unittest, inspect, os, socket, time
from threading import Thread

from avlib import ascanner, tobytes, mail, debug, smtpc
from avir import *
from aspam import *
from interscan import *
from srv import *

# Test-wide configuration of the shared avlib state.
# Debug verbosity comes from the DEBUG_LEVEL environment variable.
debug.set_level(int(os.environ.get("DEBUG_LEVEL", "0")))
mail.addr = (b"127.0.0.1", 25, b"noname")
mail.recip = ["x@y.z"]
# NOTE: the sender's local part contains an invisible U+FEFF (zero-width
# no-break space) between "wrong" and "string" -- presumably a deliberately
# malformed address.  It is spelled as an escape so it stays visible in
# review and cannot be silently stripped by editors; the runtime value is
# unchanged.
mail.sender = "wrong\ufeffstring@x.y"
# Address that smtpd_simulator binds to.
smtpc.SMTP_SERVER = ('127.0.0.1', 2525)
# Private/loopback IPv4 prefixes: 192.168/16, 172.16/12, 10/8, 127/8.
LOCAL_IPS = r'^(192\.168|172\.(1[6789]|2[0-9]|3[01])|10|127)\.'
# Canned scanner results: a clean verdict and a level-1.0 "virus" hit.
CLEAN = const(0.0)
VIRUS = const(1.0, "virus")
# Minimal message from mail.sender to the first recipient.
# 16 Apr 2019 was a Tuesday, so the Date header must say "Tue".
TEST_EMAIL = b'''\
From: Test User <%s>
To: Test User2 <%s>
Date: Tue, 16 Apr 2019 15:38:51 +0200 (CEST)
Subject: testing mail

simple test mail
''' % (tobytes(mail.sender), tobytes(mail.recip[0]))

class smtpd_simulator:
  """Minimal SMTP server that stands in for a real relay during tests.

  It listens on ``smtpc.SMTP_SERVER`` and serves one client at a time. Its
  replies are:

  * ``220`` as the greeting;
  * ``354`` after DATA, then ``250`` once the lone ``.`` line arrives;
  * ``221`` on QUIT;
  * ``250 OK`` for every other command.

  The reply codes of the most recent session are kept in ``self.replies``.
  """
  def __init__(self):
      self.stop = False    # ends loop(); also set when accept() times out
      self.replies = []    # reply codes sent in the most recent session
      self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
      try:
        self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # A bounded accept() lets loop() finish once clients stop arriving.
        self.s.settimeout(1)
        self.s.bind(smtpc.SMTP_SERVER)
        debug.echo(2, "SMTPD-SIM: LISTEN ON:", smtpc.SMTP_SERVER)
        self.s.listen(1)
      except OSError:
        # Don't leak the socket if, e.g., the port is already in use.
        self.s.close()
        raise
  def loop(self):
      """Serve clients until ``stop`` is set or no client connects within
      the 1 s accept timeout, then close the listening socket."""
      self.stop = False
      while not self.stop:
        self.accept()
      debug.echo(2, "SMTPD-SIM: STOP")
      self.s.close()
  def reply(self, code=250, txt=b"OK"):
      """Send one ``<code> <txt>`` reply line and record *code*."""
      self.conn.sendall(b"%d %s\r\n" % (code, txt))
      debug.echo(2, "SMTPD-SIM: %d" % code)
      self.replies.append(code)
  def accept(self):
      """Wait for one client and run an SMTP session with it.

      Sets ``self.stop`` when no client connects within the timeout.  A
      client that drops its connection ends only its own session, not the
      simulator.
      """
      try:
        self.conn, addr = self.s.accept()
        debug.echo(2, "SMTPD_SIM: ACCEPT FROM:", addr)
      except socket.timeout:
        self.stop = True
        return
      self.replies = []
      self.f = self.conn.makefile("rwb")
      try:
        self._session()
      except ConnectionError as err:
        debug.echo(2, "SMTPD-SIM: CONNECTION ERROR:", err)
      finally:
        self.f.close()
        try:
          self.conn.shutdown(socket.SHUT_RDWR)
        except OSError:
          pass    # peer already gone / socket no longer connected
        self.conn.close()
  def _session(self):
      """Answer commands read from ``self.f`` until QUIT or end of stream."""
      data = False
      self.reply(220)
      while True:
        line = self.f.readline()
        if not line:
          # EOF without QUIT: readline() keeps returning b"", so without
          # this check the loop would spin forever (or reply to a dead peer).
          debug.echo(2, "SMTPD-SIM: EOF")
          break
        if data:
          # Inside DATA: swallow the body until the terminating "." line.
          if line.strip() == b".":
            data = False
            self.reply()
          continue
        debug.echo(2, "SMTPD-SIM: ", line.strip().decode("utf8", "replace"))
        if line.startswith(b"DATA"):
          self.reply(354)
          data = True
          continue
        if line.startswith(b"QUIT"):
          self.reply(221)
          break
        self.reply()
# Module-level instance (shadows the class name); binds the port at import.
smtpd_simulator = smtpd_simulator()

class smtpc_tester(smtpc):
  """SMTP client aimed at 127.0.0.1:2727 instead of ``smtpc.SMTP_SERVER``.

  2727 is the port on which the (currently disabled) SmtpdService tests
  start the smtpd service under test. smtpd_simulator, by contrast, listens
  on ``smtpc.SMTP_SERVER``.
  """
  SMTP_SERVER = ('127.0.0.1', 2727)

### Testers

class ScannerTest(unittest.TestCase):
  """Common base for the scanner test cases.

  Each check re-initialises the scanner, then scans ``buffer`` (a fake
  "MZ" executable) with ``args``. Only the first two items of the verdict
  are compared: the level and the detected name as bytes.
  """
  buffer = b"MZ_executable\0_SomethingFound"
  args = {}

  def _level_and_name(self, scanner):
      """Reset *scanner*, scan ``self.buffer`` and return (level, name)."""
      scanner.reinit()
      verdict = scanner.scanbuffer(self.buffer, self.args)
      return verdict[:2]

  def assertClean(self, scanner):
      """Assert that *scanner* reports level 0.0 with an empty name."""
      observed = self._level_and_name(scanner)
      self.assertEqual(observed, (0.0, b""))

  def assertInf(self, scanner, virus=VIRUS.vir, level=1.0):
      """Assert that *scanner* reports *level* under the name *virus*."""
      observed = self._level_and_name(scanner)
      self.assertEqual(observed, (level, tobytes(virus)))

class BasicVirScanners(ScannerTest):
  """Tests for the basic detection scanners (hit and no-hit configurations)."""
  def test_regexp_scan(self):
      # The buffer itself serves as the pattern, so it must match.
      self.assertInf(regexp_scan([["VIR_NAME", self.buffer]]), "VIR_NAME")
      self.assertClean(regexp_scan([["VIR_NAME", "NOT FOUND"]]))
  def test_string_scan(self):
      self.assertInf(string_scan([["VIR_NAME", self.buffer]]), "VIR_NAME")
      self.assertClean(string_scan([["VIR_NAME", "NOT FOUND"]]))
  def test_smtp_comm(self):
      # Not written yet: report it as skipped rather than as a silent pass.
      self.skipTest("TODO: smtp_comm test not implemented")
  def test_const(self):
      self.assertClean(CLEAN)
      self.assertInf(const(1.0, "VIR_NAME"), "VIR_NAME")
  def test_max_file_size(self):
      # A limit one byte below the buffer size triggers; one above does not.
      self.assertInf(max_file_size(len(self.buffer)-1), "FileSizeOverrun")
      self.assertClean(max_file_size(len(self.buffer)+1))
  def test_file_type(self):
      self.assertInf(file_type({"exe": "mexec"}), "mexec")
      self.assertClean(file_type({"doc": "document"}))
  def test_file_magic(self):
      self.assertInf(file_magic({"exe": "MS-DOS executable"}, raw=True), "exe")
      self.assertClean(file_magic({"doc": "document"}))
  def test_sender_regexp(self):
      # mail.addr is configured as 127.0.0.1 at module level.
      self.assertInf(
        sender_regexp([["VIR_NAME", "127.0.0.1"]]), "VIR_NAME"
      )
      self.assertClean(
        sender_regexp([["VIR_NAME", "127.0.0.2"]])
      )

class FilterScanner(ScannerTest):
  """Tests for the external-command ``filter`` scanner."""

  def test_filter(self):
      # ``filter`` is presumably the star-imported scanner, shadowing the
      # builtin; the test expects a clean verdict for this command.
      grep_filter = filter("grep ^MZ")
      self.assertClean(grep_filter)

class ActionsScanners(ScannerTest):
  """Tests for action scanners that wrap another scanner's verdict."""

  def _wraps_verdict(self, make):
      """Check that make(CLEAN) reports clean, then make(VIRUS) infected."""
      self.assertClean(make(CLEAN))
      self.assertInf(make(VIRUS))

  def test_quarantine(self):
      clean_quarantine = quarantine('/tmp', '', CLEAN)
      self.assertClean(clean_quarantine)
      # Infected case currently disabled:
      #self.assertInf(quarantine('/tmp', '', VIRUS))

  def test_drop(self):
      self._wraps_verdict(lambda inner: drop('.', inner))

  def test_deliver(self):
      self._wraps_verdict(lambda inner: deliver('.', inner))

  def test_deliver_to(self):
      self._wraps_verdict(lambda inner: deliver_to(['x@y.z'], inner))

  def test_custom_action(self):
      rejecter = custom_action('.', 550, "Rejected %(VIRNAME)s", VIRUS)
      self.assertInf(rejecter)

  def test_rename(self):
      renamed = rename("new_vir_name", const(1.0))
      self.assertInf(renamed, "new_vir_name")

  def test_time_limit(self):
      limited = time_limit(1, VIRUS)
      self.assertInf(limited)

  def test_cache(self):
      # Populate two keys, then query "var_cl" without a wrapped scanner
      # (presumably answered from the cache).
      self.assertClean(cache('var_cl', CLEAN))
      self.assertInf(cache('var_inf', VIRUS))
      self.assertClean(cache("var_cl"))
      # TODO: the infected key is not read back yet:
      #self.assertInf(cache("var_inf"))

class ConditionScanners(ScannerTest):
  """Tests for scanners that choose or gate other scanners."""
  def test_sql_find(self):
      # Not written yet: report it as skipped rather than as a silent pass.
      self.skipTest("TODO: sql_find test not implemented")
  def test_regexp_find(self):
      # Not written yet: report it as skipped rather than as a silent pass.
      self.skipTest("TODO: regexp_find test not implemented")
  def test_check_level(self):
      # The scored level 4.0 lies only in the (1.0, 5.0) band, so the
      # verdict must be that band's const(1.0, "1to5").
      self.assertInf(
        check_level(const(4.0), {
          (1.0, 5.0): const(1.0, "1to5"),
          (5.0, 9.0): const(5.0, "5to9")
        }),
        "1to5"
      )
  # Obsolete tests, kept commented out for reference:
  #def test_rpct_in_sql(self):
  #    pass # TODO
  #def test_rcpt_in_txt(self):
  #    fn = "/tmp/rcpt.txt"
  #    open(fn, "wt").write("%s\n" % tostr(mail.recip[0]))
  #    self.assertInf(rcpt_in_txt(fn) & VIRUS)

class FileBufScanners(ScannerTest):
  """Buffer/file/mbox conversion wrappers must pass the inner verdict through."""

  def _passes_verdict_through(self, wrap):
      # Clean input must stay clean; infected input must stay infected.
      self.assertClean(wrap(CLEAN))
      self.assertInf(wrap(VIRUS))

  def test_buffer2file(self):
      self._passes_verdict_through(b2f)

  def test_buffer2mbox(self):
      self._passes_verdict_through(buffer2mbox)

  def test_file2buffer(self):
      self._passes_verdict_through(lambda inner: b2f(f2b(inner)))

class HeaderScanners(ScannerTest):
  """Tests for header-manipulating actions."""

  def test_add_header(self):
      adder = add_header('X-Header', '%S %L %V', VIRUS)
      self.assertInf(adder)
      # %S, %L and %V appear to expand to scanner, level and virus name.
      added_value = mail.xhdra['X-Header']
      self.assertEqual(added_value, "AScanner() 1.0 virus")

  def test_modify_header(self):
      modifier = modify_header('^X-Header', 'X-Header: test', VIRUS)
      self.assertInf(modifier)

  def test_modify_subject(self):
      retitler = modify_subject("Test subject", VIRUS)
      self.assertInf(retitler)

  def test_remove_headers(self):
      remover = remove_headers("X-Header")
      self.assertClean(remover)

class LoggerScanners(ScannerTest):
  """Placeholder for logger scanner tests; none are written yet."""

class MatchScanners(ScannerTest):
  """Tests for scanners that combine the verdicts of several sub-scanners."""

  def test_match_all(self):
      # With every sub-scanner infected, the names are comma-joined and the
      # expected level is 3.0 for three level-1.0 hits.
      combined = match_all(VIRUS, VIRUS, VIRUS)
      self.assertInf(combined, virus="virus,virus,virus", level=3.0)
      self.assertClean(match_all(VIRUS, CLEAN))

  def test_match_any(self):
      some_hit = match_any(VIRUS, CLEAN)
      self.assertInf(some_hit)
      no_hit = match_any(CLEAN, CLEAN)
      self.assertClean(no_hit)

  def test_alternatives(self):
      # In every case the expected verdict follows the first sub-scanner.
      cases = (
        ((VIRUS, CLEAN), self.assertInf),
        ((CLEAN, VIRUS), self.assertClean),
        ((VIRUS, VIRUS), self.assertInf),
        ((CLEAN, CLEAN), self.assertClean),
      )
      for scanners, check in cases:
        check(alternatives(*scanners))

  def test_retry(self):
      for inner, check in ((CLEAN, self.assertClean), (VIRUS, self.assertInf)):
        check(retry(5, inner))

  def test_nothing(self):
      idle = nothing()
      self.assertClean(idle)

class PolicyScanners(ScannerTest):
  """Tests for policy-related scanners."""

  def test_set_action(self):
      with_action = set_action("dunno", VIRUS)
      self.assertInf(with_action)

  def test_spf_check(self):
      spf = spf_check()
      self.assertClean(spf)

  # Needs the GeoLite2 country database installed locally.
  @unittest.skipIf(
    not os.path.exists("/usr/share/GeoIP/GeoLite2-Country.mmdb"),
    "Missing mmdb file"
  )
  def test_geoip2_country(self):
      # NOTE(review): "EN" is not an ISO 3166 country code -- confirm intent.
      country_filter = geoip2_country(["EN"])
      self.assertClean(country_filter)

class ReportScanners(ScannerTest):
  """Tests for the report action."""

  def test_report(self):
      # Unconditional report.
      plain = report(['test@salstar.sk'], report.MSG_TMPL, VIRUS)
      self.assertInf(plain)
      # The same report, made conditional via ifscan() on a local-sender
      # regexp scanner.
      reporter = report(['test@salstar.sk'], report.MSG_TMPL, VIRUS)
      gated = reporter.ifscan(sender_regexp({'LOCAL_IP': [LOCAL_IPS]}))
      self.assertInf(gated)


# Disabled: this ClamAV test sits inside a bare string literal, so the class
# is never defined.  It would scan the EICAR test string through
# b2f(libclam()).  The EICAR string is split with spaces (removed at runtime),
# presumably so this source file is not itself flagged by virus scanners.
# To re-enable, remove the surrounding quotes (requires libclam/ClamAV).
'''
class ClamavScanner(ScannerTest):
  eicar = b"X5O!P%@ AP[4\\PZX54(P^) 7CC)7}$EIC AR-STAN DARD-ANTI VIRUS-TEST-FILE!$H+H*".replace(b" ", b"")

  def setUp(self):
      self.scanner = b2f(libclam())
      self.scanner.reinit()

  def test_libclam(self):
      self.assertClean(self.scanner)
      self.buffer = self.eicar
      self.assertInf(self.scanner, "Eicar-Test-Signature")
'''

### SERVICES

# Disabled: these service tests sit inside a bare string literal, so the
# class is never defined.  Each test does the following:
#   1. starts the smtpd service on 127.0.0.1:2727;
#   2. sends TEST_EMAIL through it with smtpc_tester, expecting ret_code;
#   3. checks the reply codes recorded by smtpd_simulator.
# NOTE(review): inside `if fail_replies:`, the check len(fail_replies)>0 is
# always true.
'''
class SmtpdService(ScannerTest):

  def assertService(self, service, ret_code):
      service.thread_start()
      smtpct = smtpc_tester()
      smtpct.sendmail(mail.sender, mail.recip, TEST_EMAIL, ret_code)
      smtpct.close()
      service.thread_stop()
      fail_replies = [x for x in smtpd_simulator.replies if x>=400]
      if fail_replies:
        self.assertTrue(len(smtpd_simulator.replies)<7)
        self.assertTrue(len(fail_replies)>0)

  def test_smtpd_virus(self):
      self.assertService(smtpd([VIRUS], '127.0.0.1', 2727, 1), b"550")
  def test_smtpd_clean(self):
      self.assertService(smtpd([CLEAN], '127.0.0.1', 2727, 1), b"")
'''


if __name__ == "__main__":
  # Serve the SMTP simulator in the background while the tests run.
  smtpd_thread = Thread(target=smtpd_simulator.loop)
  smtpd_thread.start()
  try:
    # unittest.main() finishes by raising SystemExit, so the shutdown must
    # live in a finally clause; otherwise it never runs.  The exit status
    # chosen by unittest is preserved.
    unittest.main()
  finally:
    smtpd_simulator.stop = True
    smtpd_thread.join()
