'''
Antivir Gateway Library for Sagator

(c) 2003-2024 Jan ONDREJ (SAL) <ondrejj(at)salstar.sk>

This program is an email antivirus gateway for an SMTP server.

 This program is free software; you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation; either version 2 of the License, or
 (at your option) any later version.

'''

from __future__ import absolute_import
from __future__ import print_function

import errno
import os, sys, re, time, getopt, socket, select
if sys.version_info[0]>2:
  from socketserver import ForkingTCPServer, StreamRequestHandler
else:
  from SocketServer import ForkingTCPServer, StreamRequestHandler

try:
  # email.utils is preferred for python3
  from email.utils import parseaddr
except ImportError:
  # compatibility with python<=2.4
  from rfc822 import parseaddr

from avlib import *

# Minimal message (CRLF line endings) which service.test_scanners() feeds
# to every configured scanner to check it works before a service starts.
TEST_MSG = b'''\
From: Test User <test@tester.salstar.sk>\r
To: Test User2 <test2@tester2.salstar.sk>\r
Date: Fri, 16 Apr 2004 15:38:51 +0200 (CEST)\r
Subject: testing mail\r
\r
simple mail\r
'''

# Fake policy request stored in mail.policy_request during
# service.test_scanners(); the attribute names appear to follow the
# Postfix SMTP access policy delegation protocol.
# NOTE(review): keys are str here, while checkpolicy() reads
# mail.policy_request with bytes keys (b'client_address', b'sender', ...)
# -- confirm which key type the policy scanners expect.
TEST_REQUEST = {
	'request': b'smtpd_access_policy',
	'protocol_state': b'RCPT',
	'protocol_name': b'ESMTP',
	'client_address': b'127.0.0.1',
	'client_name': b'localhost.localdomain',
	'helo_name': b'localhost.localdomain',
	'sender': b'root@localhost.localdomain',
	'recipient': b'root@localhost',
	'queue_id': b'',
	'instance': b'123.12345678.0',
	'size': b'0',
	'sasl_method': b'',
	'sasl_username': b'',
	'sasl_sender': b'',
	'ccert_subject': b'',
	'ccert_issuer': b'',
	'ccert_fingerprint': b''
}

#########################################################################
### service class template

class service(object):
  ''' Sagator service template.

  Base class for preforking services.  start() optionally self-tests the
  scanners, binds a listening TCP socket and forks MIN_CHILDS workers.
  Every worker calls accept() in a loop until EXITING is set, therefore
  subclasses must override accept().
  '''
  name = 'ASrv()'
  is_service = True
  EXITING = False
  MIN_CHILDS = 0
  # Class-level default shared by all instances; start() rebinds
  # self.childs to a fresh list.
  childs = []
  def __init__(self, scanners, host, port, prefork=2, listen=5):
      '''
      scanners: scanners used by this service
      host, port: address to bind the listening socket to
      prefork: number of worker processes forked by start()
      listen: backlog passed to listen()
      '''
      self.SCANNERS = scanners
      self.BINDTO = (host, port)
      self.LISTEN = listen
      self.MIN_CHILDS = prefork
  def test_scanners(self, scanners):
      '''
      Self-test scanners (a list, or a dict of scanners) on TEST_MSG.

      The test runs in a forked child, because it chroots and drops
      privileges.  Failing scanners are only logged, not disabled.
      '''
      if type(scanners)==type({}):
        scanners = list(scanners.values())
      for scnr in scanners:
        scnr.reinit()
      pid = os.fork()
      if pid==0:
        # Child process: it must never fall back into the caller's code,
        # so every unexpected exception ends the process here.
        try:
          dochroot()
          globals.scan_only = True
          mail.data = TEST_MSG
          mail.policy_request = TEST_REQUEST
          for scnr in scanners:
            if scnr.name:
              debug.echo(4, "Testing "+scnr.name+"...")
            scan_report = do_scan(scnr)[3]
            if scan_report:
              debug.echo(0, "Scanner ",scnr.name,
                         " test failed! Disable it manually!")
              debug.echo(0, "  ",scan_report)
          mail.data = b''
        except SystemExit:
          raise
        except:
          debug.traceback(1, self.name+": test_scanners: ")
          sys.exit(1)
        sys.exit(0)
      # Parent: wait for this test child only.  A plain os.wait() could
      # reap an unrelated service child instead (and return while the test
      # still runs), or block if a SIGCHLD handler already reaped it.
      while True:
        try:
          os.waitpid(pid, 0)
        except OSError as err:
          if err.args and err.args[0]==errno.EINTR:
            continue
          # ECHILD: the test child was already reaped elsewhere
        break
  def bind(self):
      '''
      Create the listening TCP socket (self.s) on self.BINDTO.
      Exits the process with status 1 if bind() or listen() fails.
      '''
      self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
      self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
      try:
        self.s.bind(self.BINDTO)
        self.s.listen(self.LISTEN)
      except socket.error as err:
        # Not every socket.error carries an (errno, strerror) pair.
        es = getattr(err, 'strerror', None) or str(err)
        debug.echo(0, "%s: ERROR: BIND %s (%s)" % (self.name, es, self.BINDTO))
        self.s.close()
        sys.exit(1)
  def start(self, run_test=True):
      '''
      Start the service: optionally test scanners, bind and prefork
      MIN_CHILDS workers.  Returns the list of child pids.
      '''
      if run_test:
        self.test_scanners(self.SCANNERS)
      self.bind()
      # prefork
      self.PID = os.getpid()
      self.childs = []
      for i in range(self.MIN_CHILDS):
        self.fork()
      debug.echo(1, "%s: service started ... %s" \
                    % (self.name, str(self.childs)))
      return self.childs
  def stop(self):
      pass
  def cleanup(self):
      pass
  def accept(self, connects=0):
      ''' Serve one connection; must be overridden by subclasses. '''
      debug.echo(0, self.name+
                    ": This method must be overwritten in a subclass!")
      sys.exit(1)
  def fork(self):
      '''
      Fork one worker process.

      Parent: records the child pid, increments globals.fork_id and
      returns the pid, or returns -1 if fork() failed.
      Child: installs signal handlers, chroots and serves connections via
      accept() until EXITING is set; it leaves only through sys.exit().
      '''
      # Scanners are not reinit()ed here to share memory between children:
      # autoexit is not implemented yet, so it would have no effect.
      try:
        pid = os.fork()
      except OSError as err:
        debug.echo(1, "ERROR: Can't fork:",
                   getattr(err, 'strerror', None) or str(err))
        return -1
      if pid>0:
        self.childs.append(pid)
        globals.fork_id += 1
        return pid
      # --- child process ---
      signal.signal(signal.SIGTERM, self.sigterm)
      signal.signal(signal.SIGCHLD, signal.SIG_DFL)
      signal.signal(signal.SIGHUP, self.sighup) # reopen log
      dochroot()
      random.seed() # touch random number generator
      self.time1, self.timec = '', 1
      while True:
        if self.EXITING:
          debug.echo(6, "%s: Exit flag set, bye, bye, ..." % (self.name))
          self.s.close()
          sys.exit(0)
        try:
          self.accept()
        except SmtpcError:
          # error is already written, accept new connection
          continue
        except socket.error as err:
          if len(err.args)<2:
            continue # socket timeout, no (errno, strerror) pair
          ec, es = err.args[0], err.args[1]
          if ec==errno.ENOTCONN: # Transport endpoint is not connected
            debug.echo(1, self.name+": Connection closed by peer.")
          elif ec==errno.ECONNRESET: # Connection reset by peer
            debug.echo(1, self.name+": Connection reset by peer.")
          elif ec==errno.EPIPE: # Broken pipe
            debug.echo(0, self.name+": ERROR: Broken pipe.")
            debug.traceback(3, self.name+": ")
          elif ec!=errno.EINTR: # interrupted system calls are ignored
            debug.echo(0, self.name+': ERROR: SocketError: ', es)
            debug.traceback(4, self.name+": ")
        except KeyboardInterrupt:
          sys.exit(0)
        except SystemExit:
          raise
        except MemoryError:
          debug.echo(1, self.name+": ERROR: OUT OF MEMORY! EXITING PROCESS!")
          # it's safer to exit process now
          sys.exit(1)
        except:
          debug.echo(1, self.name+": ERROR: Unknown error!")
          debug.traceback(1, self.name+": ")
      return 0 # child process (not reached)
  def thread_start(self):
      ''' Serve accept() in a thread of this process instead of a fork. '''
      def loop():
          while not self.EXITING:
            self.accept()
      random.seed()
      self.time1, self.timec = '', 1
      self.bind()
      self.thread = Thread(target=loop)
      self.thread.start()
  def thread_stop(self):
      ''' Stop the thread started by thread_start() and close the socket. '''
      self.EXITING = True
      self.thread.join()
      self.s.close()
  def killall(self, sig=None):
      '''
      Close the listening socket and send sig to every child.
      With sig=None send SIGTERM, wait 0.5s, then SIGKILL and reap.
      '''
      if sig is None:
        self.killall(signal.SIGTERM)
        time.sleep(0.5)
        sig = signal.SIGKILL
      self.s.close()
      # Iterate a copy: the SIGCHLD handler may remove pids meanwhile.
      for pid in list(self.childs):
        debug.echo(2, "Killing pid=%d with signal %d" % (pid, sig))
        try:
          os.kill(pid, sig)
          if sig==signal.SIGKILL:
            debug.echo(2, "waitpid()=%s" % str(os.waitpid(pid, 0)))
        except OSError as err:
          # child already exited or was already reaped
          debug.echo(2, "Can't kill/reap pid=%d: %s" % (pid, err))
  def sighup(self, sn, stack):
      ''' signal: HUP, reopen log '''
      debug.reopen()
  def sigterm(self, sn, stack):
      ''' signal: TERM, ask the worker loop to exit '''
      debug.echo(6, '%s: TERM signal received.' % self.name)
      self.EXITING = True

class ServiceTCPServer(ForkingTCPServer, service):
  ''' Preforking Sagator service template.

  start() tests the scanners, binds through ForkingTCPServer and forks
  one child which runs serve_forever() in its own process group;
  ForkingTCPServer then forks a process for every request.
  '''
  name = 'ForkingService()'
  is_service = True
  MIN_CHILDS = 1
  childs = []
  allow_reuse_address = True
  pgrp = None # process group of the serving child (set in that child)
  def __init__(self, scanners, host, port, request_handler):
      # ForkingTCPServer.__init__() (bind/listen) is deferred to start().
      self.SCANNERS = scanners
      self.BINDTO = (host, port)
      self.REQUEST_HANDLER = request_handler
  def start(self):
      '''
      Test scanners, bind and fork the serving child.
      Returns the list of child pids.
      '''
      self.test_scanners(self.SCANNERS)
      ForkingTCPServer.__init__(self, self.BINDTO, self.REQUEST_HANDLER)
      self.PID = os.getpid()
      self.childs = []
      self.fork()
      debug.echo(1, "%s: service started ... %s" \
                    % (self.name, str(self.childs)))
      return self.childs
  def fork(self):
      '''
      Fork the serving child.

      Parent: records and returns the child pid, or -1 if fork() failed.
      Child: becomes a process group leader, installs signal handlers,
      replaces stdin with /dev/null, chroots and runs serve_forever();
      returns 0 when serving ends.
      '''
      try:
        pid = os.fork()
      except OSError as err:
        debug.echo(1, "ERROR: Can't fork:",
                   getattr(err, 'strerror', None) or str(err))
        return -1
      if pid>0:
        self.childs.append(pid)
        return pid
      # --- child process ---
      os.setpgrp()
      self.pgrp = os.getpgrp()
      signal.signal(signal.SIGTERM, self.sigterm)
      signal.signal(signal.SIGCHLD, signal.SIG_DFL)
      signal.signal(signal.SIGHUP, self.sighup) # reopen log
      os.close(0) # close stdin
      os.open('/dev/null', os.O_RDONLY) # reuses fd 0
      dochroot()
      random.seed() # touch random number generator
      while True:
        try:
          self.serve_forever()
        except select.error as err:
          # Don't unpack err.args blindly: a ValueError here would escape
          # the serve loop of this child.
          ec = err.args[0] if err.args else None
          es = err.args[1] if len(err.args)>1 else str(err)
          debug.echo(3, "%s: %s %s" % (self.name, ec, es))
          if ec==errno.EINTR: # Interrupted system call
            # Ignore this error, an signal received, no need to exit.
            continue
        break
      return 0 # child process
  def stop(self):
      '''
      In the serving child: TERM the whole process group, wait up to 3s
      for active_children, then KILL the group.  Elsewhere (pgrp unset)
      fall back to service.stop().
      '''
      if self.pgrp is None:
        return service.stop(self)
      pgrp = self.pgrp
      self.pgrp = None # avoid recursion
      # NOTE(review): with SIG_DFL the TERM sent to our own group below
      # most likely terminates this process too, before the wait loop
      # runs -- confirm this is the intended shutdown path.
      signal.signal(signal.SIGTERM, signal.SIG_DFL)
      debug.echo(4, "%s: Sending TERM signal to process group %d"
                    % (self.name, pgrp))
      try:
        os.killpg(pgrp, signal.SIGTERM)
      except OSError:
        pass # process group already gone
      # wait at most 3 seconds to before sending KILL signal
      for i in range(30):
        if not self.active_children:
          break
        debug.echo(6, "Waiting for ... %s" % list(self.active_children))
        time.sleep(0.1)
      # send KILL signal, if processess are still running
      if self.active_children:
        debug.echo(4, "%s: Sending KILL signal to process group %d"
                      % (self.name, pgrp))
        try:
          os.killpg(pgrp, signal.SIGKILL)
        except OSError:
          pass # process group already gone
  def cleanup(self):
      pass
  def sighup(self, sn, stack):
      ''' signal: HUP, reopen log '''
      debug.reopen()
  def sigterm(self, sn, stack):
      ''' signal: TERM, stop the process group and exit immediately '''
      debug.echo(6, '%s: TERM signal received.' % self.name)
      self.stop()
      os._exit(0)

def find_service(services, *names):
    '''
    Return the first service (in services order) whose name is one of
    names.  Raises TypeError if there is no such service.
    '''
    for srv in services:
      if srv.name in names:
        break
    else:
      raise TypeError('Required service not defined %s' % str(names))
    return srv

#########################################################################
### Some helpful functions

# chroot, setgid and setuid
def dochroot():
    '''
    Do chroot() and setuid(), setgid().

    Chroots into safe.ROOT_PATH (when set; on success ROOT_PATH becomes
    '/') and drops privileges to globals.GID / globals.UID (when set).
    Failures are only logged as warnings, but the process exits with
    status 10 if it still has a root uid/gid (real or effective).
    '''
    def func_err(fn, desc, *args):
        ''' Call os.<fn>(*args); return 0 on success, 1 on failure.
        desc is the call as shown in the warning message. '''
        func = getattr(os, fn, None)
        if func is None:
          debug.echo(1, "WARNING! Can't do "+desc+
                        ": Your python don't have this function!")
          return 1
        try:
          func(*args)
          return 0
        except OSError as err:
          debug.echo(1, "WARNING! Can't do %10s: %s"
                        % (fn, getattr(err, 'strerror', None) or err))
        return 1
    # preload some libraries (presumably because the python library path
    # is no longer reachable inside the chroot)
    import encodings.utf_8, encodings.latin_1, encodings.ascii, encodings.idna
    # change root; the os functions are called directly instead of
    # eval()ing a string built from configuration values
    if safe.ROOT_PATH:
      if not func_err('chroot', "os.chroot('%s')" % safe.ROOT_PATH,
                      safe.ROOT_PATH):
        safe.ROOT_PATH = '/'
        os.chdir("/") # change directory to new chroot
    if globals.GID:
      gid = int(globals.GID)
      func_err('setgroups', "os.setgroups([%d])" % gid, [gid])
      func_err('setgid', "os.setgid(%d)" % gid, gid)
    if globals.UID:
      uid = int(globals.UID)
      func_err('setuid', "os.setuid(%d)" % uid, uid)
    if 0 in (os.getuid(), os.getgid(), os.geteuid(), os.getegid()):
      debug.echo(0, "ERROR! I can't setuid and i am running as root!")
      sys.exit(10)

def do_scan(scnr, fn=''):
    '''
    Scan email (stored in mail class) for viruses.
    Use scnr scanner.

    fn, if given, is stored as scnr.filename before scanning.
    Returns a 5-tuple (level, detected, virlist, scan_report, err).
    On success scan_report is '' and err is None.  On failure scan_report
    is an error line for the log and err is a short reason, while level,
    detected and virlist keep the values known so far.
    KeyboardInterrupt is re-raised; other errors are reported this way.
    '''
    level = 0.0
    detected = b''
    virlist = []
    try:
      if fn:
        scnr.filename = fn
      scnr.prescan()
      level, detected, virlist = \
        scnr.scanbuffer(mail.data, {'dbc':globals.DBC})
      scnr.postscan(level, detected, virlist)
      scnr.destroy()
      return level, detected, virlist, '', None
    except socket.timeout as es:
      debug.traceback(4, "do_scan: ")
      scnr.destroy()
      return level, detected, virlist, \
        scnr.name+": SocketTimeout: %s" % es, 'Socket timeout: %s' % es
    except socket.error as err:
      # On python3 socket.error is OSError, so every OSError ends up here.
      # Not all of them carry an (errno, strerror) pair; unpacking
      # err.args blindly raised ValueError out of this function.
      if len(err.args)>1:
        es = str(err.args[1])
      else:
        es = str(err)
      debug.traceback(4, "do_scan: ")
      scnr.destroy()
      return level, detected, virlist, \
        scnr.name+": SocketError: "+es, 'Socket error: '+es
    except (IOError, OSError):
      debug.traceback(4, "do_scan: ")
      scnr.destroy()
      return level, detected, virlist, \
        scnr.name+": Error: "+debug.traceback_value_str(), \
        'Error: '+debug.traceback_value_str()
    except (ScannerError, ValueError) as es:
      debug.traceback(4, "do_scan: ")
      scnr.destroy()
      return level, detected, virlist, str(es), 'Error: '+str(es)
    except KeyboardInterrupt:
      raise
    except:
      debug.traceback(4, "do_scan: ")
      # Release scanner resources like every other failure path does; the
      # failing call may have been destroy() itself, so guard it.
      try:
        scnr.destroy()
      except Exception:
        debug.traceback(4, "do_scan: destroy: ")
      return level, detected, virlist, \
        scnr.name+": unknown error", 'Unknown error'

def checkvir(scanners):
    '''
    Check email (stored in mail class) for viruses.

    Runs scanners in order, summing their levels, and stops at the first
    infection.  Returns (status, level, reply):
      (S_OK, level, b'CLEAN') for a clean mail,
      (globals.ACTION, level, detected) for an infected mail,
      (S_TEMPFAIL, level, <error>) if a scanner or the check failed.
    '''
    level = 0.0
    detected = b''
    virlist = []
    scanarr, scannames = [], 'None'
    globals.reset()
    try:
      for scnr in scanners:
        scnr_level, detected, virlist, scan_reply, _err = do_scan(scnr)
        level += scnr_level
        if scan_reply:
          debug.echo(0, "CheckVir: ERROR: ", scan_reply)
          return S_TEMPFAIL, level, tobytes(scan_reply)
        scanarr.append(scnr.name)
        scannames = ' '.join(scanarr)
        if is_infected(level, detected):
          debug.echo(3, "Status: ",
                        [tostr(detected), tostr_list(virlist), level])
          break
      mail.addvirusinfo(detected, scannames)
      if not is_infected(level, detected):
        return S_OK, level, b'CLEAN' # mail delivered OK
      else: # VIRUS FOUND
        rstat = globals.ACTION
        debug.echo(5, "action: ", rstat)
        return rstat, level, detected
    # Not every socket.error/IOError carries an (errno, strerror) pair;
    # unpacking err.args blindly would raise out of this function instead
    # of returning a TEMPFAIL status.
    except socket.error as err:
      ec = err.args[0] if err.args else None
      es = err.args[1] if len(err.args)>1 else str(err)
      debug.echo(0, "checkvir: ERROR: socket.error: ", [[ec]], es)
      debug.traceback(4, "checkvir: ")
      return S_TEMPFAIL, level, tobytes(es)
    except IOError as err:
      # only reachable on python2 (python3: IOError is socket.error)
      es = err.args[1] if len(err.args)>1 else str(err)
      debug.echo(0, "checkvir: ERROR: IO: ", es)
      debug.traceback(4, "checkvir: ")
      return S_TEMPFAIL, level, b'IOError'
    except:
      debug.traceback(3, "checkvir: ")
      return S_TEMPFAIL, level, b'AnyError, see sagator logs'

def checkpolicy(scanners, smtp_style=True):
    '''
    Check a policy request (stored in mail class) with scanners.

    All scanners are reinitialized and globals are reset with action
    b'dunno'; then scanners run on mail.data until one reports an
    infection.  This sets globals.ACTION to b'reject <detected>' if it is
    still b'dunno'.
    With smtp_style the action is returned as an SMTP style reply:
    b'450 <msg>' for defer*, b'554 <msg>' for reject*, b'250 <action>' for
    dunno/permit/prepend.  Any other action raises ScannerError.
    Otherwise the action itself is returned, or b'prepend <PREPEND>' if
    the action is dunno and globals.PREPEND is set.
    '''
    # reinit scanners and globals
    for scnr in scanners:
      scnr.reinit()
    globals.reset(action=b'dunno')
    # scan it
    for scanner in scanners:
      level, detected, _ret = \
        scanner.scanbuffer(mail.data, {'dbc':globals.DBC})
      if is_infected(level, detected):
        # set ACTION=reject if not set yet
        if globals.ACTION==b'dunno':
          globals.ACTION = b'reject %s' % tobytes(detected)
        break
    debug.echo(3, 'checkpolicy: %s [%s] %s -> %s, action=%s' % (
      time.strftime("%c"),
      tostr(mail.policy_request.get(b'client_address')),
      tostr(mail.policy_request.get(b'sender')),
      tostr(mail.policy_request.get(b'recipient')),
      tostr(globals.ACTION))
    )
    if smtp_style:
      try:
        reply_action, reply_msg = globals.ACTION.split(b' ', 1)
      except ValueError:
        # a bare action without a message
        reply_action, reply_msg = globals.ACTION, b'Ok'
      # Lower-case in both cases: the bare-action fallback used to skip
      # this, so b'DUNNO' raised ScannerError while b'DUNNO x' worked.
      reply_action = reply_action.lower()
      if reply_action[:5]==b'defer':
        return b'450 %s' % reply_msg
      elif reply_action[:6]==b'reject':
        return b'554 %s' % reply_msg
      elif reply_action in (b'dunno', b'permit', b'prepend'):
        return b'250 %s' % tobytes(globals.ACTION)
      raise ScannerError("Unknown policy reply: %s" % tostr(globals.ACTION))
    if (globals.ACTION[:5]==b'dunno') and globals.PREPEND:
      return b'prepend %s' % tobytes(globals.PREPEND)
    return tobytes(globals.ACTION)

#########################################################################
### Signal handlers

def sighup(sn, stack):
    '''
    signal: HUP signal

    Reopen the log of this process and forward SIGHUP to the children of
    every service in globals.SRV, so they reopen theirs too.
    '''
    debug.echo(3, "SIGHUP: Reopening log ...")
    debug.reopen()
    # reopen logs for all childs
    for srv in globals.SRV:
      # iterate a copy: sigchld() may remove pids from srv.childs meanwhile
      for pid in list(srv.childs):
        try:
          os.kill(pid, signal.SIGHUP)
        except OSError:
          pass # child already exited

def sigchld(sn, stack):
    '''
    signal: Child process terminated

    Reap every exited child without blocking and remove its pid from the
    childs list of the owning service in globals.SRV.  Also called
    directly (sn=None) to collect zombies whose SIGCHLD may have been lost.
    '''
    debug.echo(6, "CHLD signal received ...")
    try:
      while True:
        # -1 means any child.  waitpid(0, ...) only matched children in our
        # own process group, so ServiceTCPServer children (which call
        # os.setpgrp()) were never reaped nor removed from srv.childs.
        pid, status = os.waitpid(-1, os.WNOHANG)
        if pid==0:
          break # remaining children are still running
        for srv in globals.SRV:
          if pid in srv.childs:
            debug.echo(1, "%s: Child exited ... pid=%d, status=%d" \
                          % (srv.name, pid, status))
            srv.childs.remove(pid)
            # No automatic restart here: restarting immediately after exit
            # does not work as expected for smtp() and some other services.
            # This is left to sagator's restart done each 60 seconds.
            break
    except OSError: # no child processes?
      pass

def sigterm(sn, stack=None):
    '''
    signal: Terminate process

    Works as the SIGTERM handler, or can be called directly with a string
    reason as sn.  Sends TERM to every service child and stops each
    service, reaps zombies, sends KILL to children still listed and cleans
    each service up, removes the pid file and exits via os._exit(0).
    '''
    def signal_childs(srv, sig, level, label, pause=0):
      # Send sig to every pid in srv.childs, ignoring already dead ones.
      for child_pid in srv.childs:
        debug.echo(level, '%s: killing child %d (%s signal)' \
                          % (srv.name, child_pid, label))
        try:
          os.kill(child_pid, sig)
          if pause:
            time.sleep(pause)
        except OSError:
          pass

    if type(sn) is str:
      reason = sn
    else:
      reason = 'SIGTERM'
    debug.echo(1, "SERVER: Exiting - %s ..." % reason)
    # polite pass: TERM to all children, then stop the service
    for srv in globals.SRV:
      srv.MIN_CHILDS = 0
      signal_childs(srv, signal.SIGTERM, 5, 'TERM', 0.1)
      srv.stop()
    sigchld(None, None) # unzombie all childs
    # forced pass: KILL whatever is still registered, then clean up
    for srv in globals.SRV:
      signal_childs(srv, signal.SIGKILL, 3, 'KILL')
      srv.cleanup()
    sigchld(None, None) # unzombie all childs
    try:
      if globals.pid_file:
        os.unlink(globals.pid_file)
    except OSError:
      pass # pid file already removed
    debug.echo(6, "All childs killed, shutting down ...")
    os._exit(0) # Exit immediatelly, don't use sys.exit()

def sigusr2(sn, stack):
    '''
    signal: USR2 signal

    Switch logging to the trace level and dump the current stack.
    '''
    debug.echo(1, "SIGUSR2: Tracelog request received ...")
    trace_level = debug.trace_level
    debug.debug_level = trace_level
    debug.stack(1, "SIGUSR2: ")

#########################################################################
### Sqback and queue file

def send_qfile(fname, smtp_server, rcptto=''):
    '''
    Send a file from quarantine to it's recipient.

    The file content (SMTP client commands and message data) is sent
    verbatim to smtp_server, an address for socket.connect().  If rcptto
    is given, the first RCPT TO command is rewritten to rcptto and every
    line between it and the DATA command is dropped.
    Returns a bytes log of the session, including all server replies.
    '''
    r_rcptto = re.compile(br"^RCPT TO:.*[\r]*$", re.IGNORECASE|re.MULTILINE)
    r_data = re.compile(br"^DATA[\r]*$", re.IGNORECASE|re.MULTILINE)
    conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
      conn.connect(smtp_server)
      out = b"Connection to: "+tobytes(str(smtp_server))+b" established\n"
      with open(fname, 'rb') as f:
        buf = f.read()
      if rcptto:
        buf = r_rcptto.sub(b"RCPT TO: "+tobytes(rcptto)+b"\r", buf, 1)
        while True:
          m_data = r_data.search(buf)
          m_rcpt = r_rcptto.search(buf)
          if m_data is None or m_rcpt is None:
            break # nothing (more) to drop
          s = m_data.start()
          e = m_rcpt.end()
          # skip the line terminator of the kept RCPT TO line
          while e<len(buf) and buf[e:e+1] in (b'\r', b'\n'):
            e += 1
          if e>=s:
            break
          buf = buf[:e]+buf[s:]
      conn.sendall(buf)
      # collect server replies until the server closes the connection
      while True:
        try:
          r = conn.recv(10240)
        except socket.error:
          out += b"Connection closed.\n"
          break
        if r==b'':
          break
        out += r.strip()+b'\n'
    finally:
      conn.close()
    return out

class qfile(smtp):
  '''
  A class to parse a quarantine file.

  The file is read as an SMTP client session: envelope commands up to
  DATA, then the message terminated by a line containing a single dot.
  After parse(), sender/recipients hold the envelope addresses and the
  message is available through message(), header() and body().
  '''
  def __init__(self):
      # NOTE(review): sender starts as str '' but parse() stores a regex
      # group taken from bytes lines -- confirm the type callers expect.
      self.sender = ''
      self.recipients = []
      self.data = BytesIO()
      self.body_pos = 0 # offset of the empty line ending the header
  def parse(self, f):
      '''
      Parse quarantine file. 'f' is an file like object.

      'f' must return bytes lines from readline().  Parsing stops at end
      of file even if DATA or the final dot line is missing.
      '''
      # envelope: MAIL FROM / RCPT TO commands up to DATA
      while True:
        line1 = f.readline()
        if not line1:
          return
        sender = self.reg_mailfrom.search(line1)
        if sender:
          self.sender = sender.group(1)
        recipient = self.reg_rcptto.search(line1)
        if recipient:
          self.recipients.append(recipient.group(1))
        if self.reg_data.search(line1):
          break
      # message: everything up to the terminating "." line
      while True:
        line1 = f.readline()
        if not line1:
          return
        if line1 in (b'.\n', b'.\r\n'):
          break
        if line1[:2]==b'..':
          # Undo SMTP dot-stuffing (RFC 5321, 4.5.2): the data section is
          # dot-terminated and send_qfile() replays it verbatim as DATA,
          # so a leading dot of the message is stored doubled.
          line1 = line1[1:]
        if (self.body_pos==0) and (line1 in (b'\n', b'\r\n')):
          self.body_pos = self.data.tell()
        self.data.write(line1)
  def message(self):
      ''' Return message's header + body '''
      return self.data.getvalue()
  def header(self):
      ''' Return message header (without the separating empty line) '''
      return self.data.getvalue()[:self.body_pos]
  def body(self):
      ''' Return message body (starting with the separating empty line) '''
      return self.data.getvalue()[self.body_pos:]

# getopt python2.2 compatibility: gnu_getopt() exists since python 2.3,
# older versions fall back to the plain getopt().
if not hasattr(getopt, 'gnu_getopt'):
  getopt.gnu_getopt = getopt.getopt
