'''
Basic realscanners for sagator

(c) 2003-2024 Jan ONDREJ (SAL) <ondrejj(at)salstar.sk>

 This program is free software; you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation; either version 2 of the License, or
 (at your option) any later version.
    
'''

from avlib import *
import re, os, time, socket, filetype

# Realscanner classes exported by "from <this module> import *".
# The to_list() helper below is not exported.
__all__ = ['regexp_scan', 'string_scan', 'smtp_comm',
           'const', 'max_file_size',
           'file_type', 'file_magic',
           'sender_regexp']

def to_list(args, fx=None):
    ''' Convert dictionary to list if type is dict or return original.

    args is either a dict {name: values} or a list of
    [name, value, ...] sequences.

    For a dict, every name is converted with tobytes(). A single str/bytes
    value is treated as one value, not as a sequence of characters.
    When fx is given, every value is converted with tobytes() and passed
    through fx (e.g. a regexp compiler). For a list input, names are
    converted with tobytes() as well.

    Returns a list of [name, value, ...] lists. A non-dict args with
    fx=None is returned unchanged.
    '''
    if isinstance(args, dict):
      ret = []
      for vname, vals in args.items():
        vname = tobytes(vname)
        # {'Vir': 'pattern'} means one pattern; iterating a string would
        # otherwise yield its single characters.
        if isinstance(vals, (str, bytes)):
          vals = [vals]
        if fx is not None:
          ret.append([vname]+[fx(tobytes(x)) for x in vals])
        else:
          ret.append([vname]+list(vals))
      return ret
    elif fx is not None:
      return [[tobytes(arg[0])]+[fx(tobytes(x)) for x in arg[1:]]
              for arg in args]
    else:
      return args

class regexp_scan(ascanner):
  r'''
  Primitive regexp pattern scanner.
  
  There can be more patterns for one virus. All patterns in []
  must match to assign a buffer as virus (AND operator).
  There can also be more virnames in one dictionary.

  Usage: regexp_scan([['VirName', 'RegExp_Pattern...', ...], ...], size=0, flags=0)

  Where: 'VirName' is a string, which identifies defined virus
         'RegExp_Pattern...' is a regexp pattern
         size is a number, which defines, how many bytes may be checked.
           If it is 0 or not defined, whole buffer is scanned.
           If it is negative (e.g. -1), only the email header
           (buffer up to mail.bodypos) is scanned.
         flags is an integer, which defines regular expression flags.
           re.M is always added; by default no other flags are used.
    
  Example: regexp_scan([
             # Scan for a part of EICAR virus test file pattern
             ['EICAR', r'^X5O!P%@AP\[4.*EICAR-STANDARD-ANTIVIRUS-TEST-FILE'],
             # Scan for an EXE file pattern encoded as base64.
             ['UnknownEXE', '^TVqQ']
           ])
  '''
  name = 'RegExpScanner()'
  def __init__(self, regexps, size=0, flags=0):
      self.size = size
      self.flags = re.M|flags
      # compile regexps (self.flags must be set before reg_compile runs)
      self.regexps = to_list(regexps, self.reg_compile)
  def reg_compile(self, arg):
      ''' Compile one pattern as a bytes regexp with self.flags. '''
      return re.compile(tobytes(arg), self.flags)
  def scanbuffer(self, buffer, args={}):
      ''' Return a virus result when all patterns of a virname match. '''
      # Select the scanned region once instead of re-slicing per pattern.
      if self.size>0:
        data = buffer[:self.size]
      elif self.size<0:
        data = buffer[:mail.bodypos]
      else:
        data = buffer
      for regs in self.regexps:
        vname, patterns = regs[0], regs[1:]
        # AND semantics: every pattern must match. An entry without
        # patterns never matches.
        if patterns and all(reg.search(data) for reg in patterns):
          mail.addheader("X-Sagator-RegExp", vname)
          return iret(1.0, vname, [b'regexp: '+vname+b' FOUND!\n'])
      return 0.0, b'', []

class string_scan(ascanner):
  '''
  Primitive string pattern scanner.
  
  There can be more patterns for one virus. All patterns
  must match to assign a buffer as virus.
  There can also be more virnames in one dictionary.

  Usage: string_scan([['VirName', 'Pattern1...', ...], ...], size=0)

  Where: 'VirName' is a string, which identifies defined virus
         'Pattern...' is a string pattern
         size is a number, which defines, how many bytes may be checked.
           If it is 0 or not defined, whole buffer is scanned.
           If it is negative (e.g. -1), only the email header
           (buffer up to mail.bodypos) is scanned.
    
  Example: string_scan([
             # Scan for a part of EICAR virus test file pattern
             ['EICAR', 'X5O!P%@AP[4', 'EICAR-STANDARD-ANTIVIRUS-TEST-FILE'],
             # Scan for an EXE file pattern encoded as base64.
             ['UnknownEXE', 'TVqQ']
           ])
  '''
  name = 'StringScanner()'
  def __init__(self, strings, size=0):
      self.strings = to_list(strings, tobytes)
      self.size = size
  def scanbuffer(self, buffer, args={}):
      ''' Return a virus result when all strings of a virname are found. '''
      # Select the scanned region once instead of re-slicing per pattern.
      if self.size>0:
        data = buffer[:self.size]
      elif self.size<0:
        data = buffer[:mail.bodypos]
      else:
        data = buffer
      for regs in self.strings:
        vname, patterns = regs[0], regs[1:]
        # Check every pattern by position, not by comparing values with the
        # last one, so duplicate patterns cannot end the check early.
        if patterns and all(pattern in data for pattern in patterns):
          mail.addheader("X-Sagator-StringScan", vname)
          return iret(1.0, vname, [b'string: '+vname+b' FOUND!\n'])
      return 0.0, b'', []

class smtp_comm(ascanner):
  '''
  Primitive regexp pattern scanner for SMTP communication.
  
  For more information about this scanner see documentation for
  regexp_scan realscanner.
  
  This scanner also can be used to apply some scanner to defined
  email addresses. For example:
      smtp_comm([['D', '^RCPT TO:.*anydomain.dom']])
        & buffer2mbox(libclam())
  In this example, libclam() scanner is used only if mail is
  addressed to a user on anydomain.dom.

  Usage: smtp_comm([['VirName', 'RegExp_Pattern...', ...], ...], flags=re.I|re.M)

  Where: 'VirName' is a string, which identifies defined virus
         'RegExp_Pattern...' is a regexp pattern; all patterns of one
           virname must match (AND operator)
         flags is an integer, which defines regular expression flags.
           By default re.IGNORECASE|re.MULTILINE is used.
  '''
  name = 'smtp_comm()'
  def __init__(self, regexps, flags=re.I|re.M):
      self.flags = flags
      # compile regexps (self.flags must be set before reg_compile runs)
      self.regexps = to_list(regexps, self.reg_compile)
  def reg_compile(self, arg):
      ''' Compile one (already tobytes()-converted) pattern. '''
      return re.compile(arg, self.flags)
  def scanbuffer(self, buffer, args={}):
      ''' Match the patterns against mail.comm; buffer is not scanned. '''
      if globals.scan_only:
        return 0.0, b'', []
      comm = mail.comm
      for regs in self.regexps:
        vname, patterns = regs[0], regs[1:]
        # Equal patterns compile to equal Pattern objects, so compare by
        # position (all()), not by == against the last element.
        if patterns and all(reg.search(comm) for reg in patterns):
          mail.addheader("X-Sagator-SMTP-comm", vname)
          return iret(1.0, vname, [b'smtp_comm: '+vname+b' FOUND!\n'])
      return 0.0, b'', []
  def scanfile(self, files, dir='', args={}):
      ''' Files are ignored; delegates to scanbuffer(). '''
      return self.scanbuffer('', args)

#####################################################################
### primitive class

class const(ascanner):
  '''
  Realscanner to return a constant value (virus or clean).
  
  This scanner has no special functionality. It always returns
  a defined virus, or clean (if virus is not defined).
  This scanner has no error codes.
  
  Usage: const(level, VirName, ret=[])
      or const()
  
  Where: level is a float which defines returned infection level.
           If level is not defined, a ScannerError is raised when scanning.
         VirName is a returned virus name.
         ret is a list of strings returned as the scanner's report.
  
  Examples: const(0.0)          # Return clean
            const(1.0, 'Virus') # Return virus name "Virus"
            const()             # Raise an error
  '''
  name = 'const()'
  ignore_name = True
  def __init__(self, level=None, vir='', ret=None):
      self.level = level
      self.vir = tobytes(vir)
      # A fresh list per instance: a shared mutable default would be handed
      # out by every const() scanner and could be mutated by callers.
      self.ret = [] if ret is None else ret
      self.filename = ''
  def scanbuffer(self, buffer, args={}):
      ''' Return the configured (level, vir, ret); raise if no level. '''
      if self.level is None:
        raise ScannerError('Forced error by const() scanner')
      if self.filename:
        debug.echo(6, self.name, ": Scanning file: ", self.filename)
      return self.level, self.vir, self.ret
  def scanfile(self, files, dirname='', args={}):
      ''' Log each file name and return the configured result. '''
      if self.level is None:
        raise ScannerError('Forced error by const() scanner')
      for fn in files:
        debug.echo(6, self.name, ": Scanning file: ", fn)
      return self.level, self.vir, self.ret

#####################################################################
### max_file_size class

class max_file_size(ascanner):
  '''
  Realscanner to test the size of a scanned buffer or file.
  
  Usage: max_file_size(size, vname='FileSizeOverrun')
  
  Where: size is a number of bytes; anything strictly larger than it
           is reported as virus
         vname is a string, which identifies the virus name.
           If this parameter is not present, 'FileSizeOverrun' is returned.

  Example: max_file_size(1024*1024*10)
           # all files larger than 10MB will be reported as virus
  '''
  name = 'max_file_size()'
  def __init__(self, size, vname='FileSizeOverrun'):
      self.size = size
      self.vname = tobytes(vname)
  def scanbuffer(self, buffer, args={}):
      ''' Report vname when the buffer is longer than self.size. '''
      if len(buffer)>self.size:
        return 1.0, self.vname, [tobytes(self.filename)+b": "+self.vname]
      return 0.0, b'', []
  def scanfile(self, files, dir='', args={}):
      ''' Report the first file whose lstat() size exceeds self.size. '''
      oversized = (path for path in files
                   if os.lstat(safe.fn(path)).st_size>self.size)
      first = next(oversized, None)
      if first is None:
        return 0.0, b'', []
      return 1.0, self.vname, [tobytes(first)+b': '+self.vname]

#####################################################################
### other scanners

class file_type(ascanner):
  '''
  Realscanner, which scans type of a file.
  
  A scanner to check content of a file for a special type (like zip, ...).
  
  Usage: file_type({'type': 'vir', ...})
  
  Where 'type' is a name of type returned by filetype.* function
        'vir' is a vir name returned, if file type is matched
  
  Example: # scan for executables (COM, EXE, PIF and similar types)
           parsemail(file_type({'exe': 'MS Executable',
                                'elf': 'Linux Executable'}))

  Obsolete since 1.1.1, use file_magic() instead.
  '''
  name = 'file_type()'
  def __init__(self, types):
      # Type names are stored as bytes; presumably this matches what
      # filetype.what()/filetype.file() return -- confirm in filetype.
      self.types = {tobytes(x): y for x, y in types.items()}
  def scanbuffer(self, buffer, args={}):
      ''' Report the first configured type detected in buffer. '''
      fta = filetype.what(buffer)
      for ft in fta:
        if ft in self.types:
          # Report lines are bytes, as in max_file_size() and file_magic().
          return 1.0, tobytes(self.types[ft]), \
                 [tobytes(self.filename)+b": "+tobytes(str(fta))]
      return 0.0, b'', []
  def scanfile(self, files, dirname='', args={}):
      ''' Report the first file whose detected types include a configured one. '''
      for fname in files:
        fta = filetype.file(fname)
        for ft in fta:
          if ft in self.types:
            return 1.0, tobytes(self.types[ft]), \
                   [tobytes(fname)+b": "+tobytes(str(fta))]
      return 0.0, b'', []

class file_magic(ascanner):
  '''
  File magic test (like "file -i command").
  
  This scanner can be used to test file content for a special type.
  You need a python module "magic", which is shipped with file's
  source, but in most distributions it is not compiled in package.
  
  Usage: file_magic({'VirName': 'regexp_pattern', ..}, flags=re.I, raw=False)
  
  Where: 'VirName' is a string, which identifies defined virus
         'regexp_pattern' is a regular expression pattern, matched
           against the libmagic description of the buffer/file
         flags is an int, which defines regular expression flags
           (default re.I, which ignores case).
         raw is a boolean; when True, MAGIC_RAW is used instead of
           MAGIC_MIME (default False).
  '''
  name = 'file_magic()'
  def __init__(self, magics, flags=re.I, raw=False):
      import magic
      if raw:
        self.magic = magic.open(magic.MAGIC_RAW)
      else:
        self.magic = magic.open(magic.MAGIC_MIME)
      self.magic.load()
      # compile regexps
      self.magics = {}
      for vname, reg in magics.items():
        self.magics[tobytes(vname)] = re.compile(reg, flags)
  def __del__(self):
      # __init__ may have failed (e.g. "magic" not importable) before
      # self.magic was assigned; do not raise AttributeError then.
      handle = getattr(self, 'magic', None)
      if handle is not None:
        handle.close()
  def scanbuffer(self, buffer, args={}):
      ''' Report the first virname whose pattern matches the buffer's magic. '''
      # Ask libmagic once; every pattern is matched against the same result.
      description = self.magic.buffer(buffer)
      for vname, reg in self.magics.items():
        if reg.search(description):
          mail.addheader("X-Sagator-FileMagic", vname)
          return iret(1.0, vname, [b'file_magic: '+vname+b' found in '+tobytes(self.filename)+b'!\n'])
      return 0.0, b'', []
  def scanfile(self, files, dirname='', args={}):
      ''' Report the first file whose magic matches any configured pattern. '''
      for fname in files:
        description = self.magic.file(fname)
        for vname, reg in self.magics.items():
          if reg.search(description):
            # Name the file that actually matched.
            return iret(1.0, vname, [tobytes(self.name)+b": "+vname+b' found in '+tobytes(fname)+b'!\n'])
      return 0.0, b'', []

class sender_regexp(regexp_scan):
  r'''
  Sender IP address regexp scanner.
  
  This scanner can be used to scan for sender's IP address. You can use
  it to perform some actions (for example send report to admin) if
  a virus is coming from your local addresses.

  Usage: sender_regexp([['VirName', 'RegExp_Pattern...'], ...])

  Where: 'VirName' is a string, which identifies defined virus
         'RegExp_Pattern...' is a regexp pattern, matched against the
           sender address instead of the scanned buffer
    
  Example: sender_regexp([
             ['LOCAL_IP', r'^(192\.168|172\.(1[6789]|2[0-9]|3[01])|10)\.']
           ])
  '''
  def scanbuffer(self, buffer, args={}):
      ''' Run the inherited regexp scan on mail.getsender()['ADDR']. '''
      ip = mail.getsender()['ADDR']
      debug.echo(7, self.name+': IP=', ip)
      # Patterns are compiled as bytes by regexp_scan.reg_compile(); a str
      # subject would raise TypeError, so encode it first.
      if isinstance(ip, str):
        ip = tobytes(ip)
      return regexp_scan.scanbuffer(self, ip, args)
