'''
lmtpd.py - LMTP daemon service for sagator.

(c) 2003-2019 Jan ONDREJ (SAL) <ondrejj(at)salstar.sk>

 This program is free software; you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation; either version 2 of the License, or
 (at your option) any later version.

'''

from aglib import *
import stats

__all__=['lmtpd']

class lmtps(smtp):
  '''
  LMTP server class.

  One instance serves one LMTP connection: it reads commands from the
  client, opens one SMTP client (smtpc) per distinct recipient signature,
  scans the received message once per signature and then sends one reply
  per accepted recipient (LMTP, RFC 2033, requires a reply per recipient).

  The whole session runs inside __init__; when the constructor returns,
  the conversation is over.
  '''
  # Lines are read from a binary socket file, so these patterns are bytes.
  reg_helo=re.compile(rb"^LHLO +(.*)$",re.IGNORECASE)
  reg_xforward=re.compile(rb"^XFORWARD ",re.IGNORECASE)
  def __init__(self,scanners,conn,addr):
      '''
      Serve one LMTP session on conn until it is finished.

      scanners -- list of scanner objects used to check the message
      conn     -- accepted client socket
      addr     -- peer address tuple as returned by socket.accept()
      '''
      self.SCANNERS=scanners
      # Binary mode: the protocol code below works on bytes (it compares
      # lines with b'.\r\n' and forwards them with sendall()), which a
      # text mode ('rw') file would never return.
      self.f=conn.makefile('rwb',self.bufsize)
      try:
        self.write=conn.sendall
        self.readline=self.f.readline
        self.RECV=RECV_SMTP
        self.smtpcs={} # smtp clients for each recipient signature
        self.mail_from=''
        self.rcpt_to_seq=[] # (signature, email) for each accepted RCPT TO
        self.write(b"220 sagator.localhost LMTP server for SAGATOR\r\n")
        mail.__init__()
        mail.addr = (
          addr[0].encode(), addr[1], mail.getnamebyaddr(addr[0]).encode()
        )
        while not self.recv():
          pass
      finally:
        # Drop the file object's reference to the socket; the socket
        # itself is shut down and closed by the caller.
        self.f.close()
  def send_ok(self,r="Ok",code=250):
      '''
      Send a one-line reply "<code> <r>" to the LMTP client.

      r may be str or bytes; str is encoded first, because bytes
      %-formatting does not accept str arguments in Python 3.
      '''
      if not isinstance(r, bytes):
        r=r.encode()
      self.write(b"%03d %s\r\n" % (code,r))
  def shutdown_clients(self):
      ''' Try to shutdown and close all SMTP clients (best effort). '''
      for client in self.smtpcs.values():
        try:
          client.shutdown(socket.SHUT_RDWR)
        except Exception:
          pass # client may already be disconnected; still close it below
        try:
          client.close()
        except Exception:
          pass
  def recv(self):
      '''
      Read and process one line from the LMTP client.

      Returns 1 when the session is finished (read error, EOF, QUIT or
      end of DATA) and 0 when more lines are expected.
      '''
      try:
        line=self.readline()
      except IOError as err:
        # Do not unpack err.args: socket.timeout carries a single argument.
        debug.echo(0, "lmtps(): ", err.strerror or str(err))
        self.shutdown_clients()
        return 1
      if not line:
        debug.echo(9, "lmtps(): Empty line")
        self.shutdown_clients()
        return 1
      if not self.RECV: # command phase
        return self._lmtp_command(line)
      if self.RECV==RECV_QUIT: # waiting for quit?
        if self.reg_quit.search(line):
          self.write(b"221 Bye\r\n")
          self.f.close()
          self.shutdown_clients()
          return 1
        self.write(b"502 SMTP communication stopped by Sagator\r\n")
        return 0
      if line in (b'.\r\n', b'.\n'):
        return self._lmtp_end_of_data(line)
      mail.df.write(line) # message body line
      return 0
  def _lmtp_command(self, line):
      ''' Handle one LMTP command line. Returns 1 after QUIT, 0 otherwise. '''
      mail.comm=mail.comm+line
      rcptto=self.reg_rcptto.search(line)
      if rcptto:
        self._lmtp_rcpt_to(line, rcptto)
        return 0
      mailfrom=self.reg_mailfrom.search(line)
      if mailfrom:
        self.send_ok()
        debug.echo(2,'LMTPS: '+tostr(line))
        try:
          self.mail_from=mailfrom.group(1)
          mail.sender=parseaddr(tostr(self.mail_from))[1]
        except Exception:
          mail.sender=self.mail_from
      elif self.reg_data.search(line):
        if not self.rcpt_to_seq:
          # RFC 2033: DATA without a successful RCPT must fail with 503.
          self.send_ok("Error: no valid recipients",503)
          debug.echo(5,'LMTPS: DATA without valid recipients')
          return 0
        self.send_ok("End data with <CR><LF>.<CR><LF>",354)
        debug.echo(5,'LMTPS: '+tostr(line))
        self.RECV=RECV_BODY
      elif self.reg_quit.search(line):
        self.send_ok()
        debug.echo(5,'LMTPS: '+tostr(line))
        self.shutdown_clients()
        return 1
      elif self.reg_helo.search(line):
        self.send_ok()
      else:
        # includes XFORWARD, which is not supported
        debug.echo(6, 'LMTPS: UNKNOWN COMMAND: '+tostr(line))
        self.send_ok("Error: command not implemented", 502)
      return 0
  def _lmtp_rcpt_to(self, line, rcptto):
      '''
      Handle RCPT TO: pass it to the SMTP client for the recipient's
      signature (creating that client on first use) and relay its reply.
      '''
      try:
        rcpt_email=parseaddr(tostr(rcptto.group(1)))[1]
        mail.recip.append(rcpt_email)
      except Exception:
        rcpt_email=rcptto.group(1)
      # Recipients with an equal signature share one SMTP client, so the
      # message is scanned and sent only once per signature.
      rcpt_sig=",".join([x.rcpt_signature(rcpt_email) for x in self.SCANNERS])
      client=self.smtpcs.get(rcpt_sig)
      if client is None:
        client=smtpc(self.mail_from)
        client.email_recipients=[]
        self.smtpcs[rcpt_sig]=client
      sline=client.send(line.rstrip())
      self.write(sline)
      if self.reg_smtp_reply_ok.search(sline):
        debug.echo(2,'LMTPS: %s' % tostr(line.rstrip()))
        debug.echo(3,'lmtpd(): signature: [', rcpt_email, ': ', rcpt_sig, ']')
        self.rcpt_to_seq.append((rcpt_sig,rcpt_email))
        client.email_recipients.append(rcpt_email)
  def _lmtp_end_of_data(self, line):
      '''
      Finish the DATA phase: scan/forward once per recipient signature and
      send one reply per accepted recipient. Always returns 1.
      '''
      mail.close()
      debug.echo(3,"LMTPS: BODY DONE, size: ", len(mail.data), " B")
      self.RECV=RECV_SMTP
      vir_status={} # signature -> (status, virname, upstream reply)
      for rcpt_sig,rcpt_email in self.rcpt_to_seq:
        if rcpt_sig not in vir_status:
          vir_status[rcpt_sig]=self._lmtp_scan_and_forward(rcpt_sig, line)
        self._lmtp_reply(rcpt_email, *vir_status[rcpt_sig])
      self.shutdown_clients()
      return 1
  def _lmtp_scan_and_forward(self, rcpt_sig, line):
      '''
      Scan the message for the recipients of one SMTP client and send it
      through that client when it may be delivered; update statistics.

      Returns (status, virname, reply); reply is the upstream reply line,
      or b'' when the message was not sent.
      '''
      client=self.smtpcs[rcpt_sig]
      self.stats=stats.statistics()
      mail.recip=client.email_recipients
      mail.store()
      v,level,virname=checkvir(self.SCANNERS)
      if v in (S_OK,S_FORCE_SEND):
        debug.echo(5,"LMTPS: Sending data for %s" % str(rcpt_sig))
        client.send('DATA',354)
        client.conn.sendall(mail.xheader)
        client.conn.sendall(mail.data)
        client.conn.sendall(line) # send the terminating "."
        reply=client.readline()
      else:
        reply=b''
      if v!=S_TEMPFAIL:
        debug.echo(2,"STATS: %3.2f seconds, %d bytes, status: %s"
                   % (self.stats.end(), len(mail.data), tostr(virname)))
        if v in (S_OK,S_FORCE_SEND):
          self.stats.update(len(mail.data))
        elif v in (S_REJECT,S_DROP,S_CUSTOM):
          self.stats.update(len(mail.data), 1)
      else: # S_TEMPFAIL
        self.stats.update(tempfail=1) # update fail statistics
      mail.restore()
      return v,virname,reply
  def _lmtp_reply(self, rcpt_email, v, virname, reply):
      ''' Send the LMTP reply for one recipient according to status v. '''
      if v in (S_OK,S_FORCE_SEND):
        debug.echo(1,"LMTPS: <%s> OK: 250 Ok" % rcpt_email)
        self.write(reply)
      elif v==S_REJECT:
        debug.echo(1,"LMTPS: <%s> REJECT: 550 Content rejected - %s"
                   % (rcpt_email, tostr(virname)))
        self.write(b"550 Content rejected - %s\r\n" % (virname,))
      elif v==S_DROP:
        debug.echo(1,"LMTPS: <%s> DROP: 250 mail dropped - %s"
                   % (rcpt_email, tostr(virname)))
        self.write(b"250 mail dropped - %s\r\n" % (virname,))
      elif v==S_CUSTOM:
        debug.echo(1, "LMTPS: <%s> CUSTOM: %s"
                   % (rcpt_email, tostr(globals.REPLY)))
        self.write(b"%s %s\r\n" % globals.REPLY)
      else: # S_TEMPFAIL
        debug.echo(1,"LMTPS: <%s> TEMPFAIL: 451 %s"
                   % (rcpt_email, tostr(virname)))
        self.write(b"451 %s\r\n" % (virname,))

class lmtpd(service):
  '''
  LMTP daemon service.
  
  This service can be used to start sagator as a separate filtering LMTP
  daemon. It is useful for postfix and any other LMTP client which
  can use these filters.
  The LMTP protocol is useful if you want to set different filters for
  different users.
  
  Usage: lmtpd(scanners, host, port, prefork=2)
  
  Where: scanners is an array of scanners (see README.scanners for more info)
         host is an IP address to bind
         port is a port to bind
         prefork is a number, which defines preforked process count.
           Set this parameter to actual processor count + 1
           or leave its default (2).
  
  Example: lmtpd(SCANNERS, '127.0.0.1', 27)

  New in version 0.7.0.
  '''
  name='lmtpd()'
  def __init__(self,scanners,host,port,prefork=2):
      self.SCANNERS=scanners
      self.BINDTO=(host,port)
      self.LISTEN=5
      self.MIN_CHILDS=prefork
      self.EXITING=False
      # Timestamp of the last accepted connection and the number of
      # connections accepted within that second; both feed gen_id().
      self.time1=''
      self.timec=0
  def accept(self,connects=0):
      '''
      Accept one connection from the listening socket self.s and serve a
      single LMTP session on it. The connection is always closed on return.

      connects is not used by this implementation.
      '''
      conn,addr = self.s.accept()
      try:
        socket_settimeout(conn,120)
        # reinit scanners
        for scnr in self.SCANNERS:
          scnr.reinit()
        globals.reset()
        # generate ID: timestamp plus a counter of connections in that second
        self.time2=time.strftime("%Y%m%d-%H%M%S",time.localtime(time.time()))
        if self.time2!=self.time1:
          self.time1,self.timec=self.time2,1
        else:
          self.timec+=1
        globals.gen_id(self.time2,self.timec)
        debug.echo(1, "Connection from: "+addr[0]+" at "+
          time.strftime("%c", time.localtime()))
        debug.echo(1, "Process id: %s" % globals.id)
        lmtps(self.SCANNERS, conn, addr)
        try:
          conn.shutdown(socket.SHUT_RDWR)
        except (IOError, OSError):
          pass # the peer may already have reset the connection
      finally:
        debug.echo(1, "Closing connection.")
        conn.close()
