'''
Basic interscanners for sagator, version 0.6.2

(c) 2003-2018 Jan ONDREJ (SAL) <ondrejj(at)salstar.sk>

 This program is free software; you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation; either version 2 of the License, or
 (at your option) any later version.

'''

from avlib import *

# public interscanners exported by "from <this module> import *"
__all__ = [
    'match_all',
    'alternatives',
    'retry',
    'match_any',
    'recover',
    'nothing',
]

#####################################################################
### match classes

class match_class(interscanner):
  '''parent class for all match classes

  Holds a flat list of sub-scanners (self.scanners) and forwards the
  housekeeping calls (destroy, reinit, param, help, get, rcpt_signature)
  to them. Subclasses implement scanbuffer() to combine the sub-scanner
  results.
  '''
  def __init__(self, *scanners):
      self.scanners=self.get_only_scanners(scanners)
      self.rename(self.scanners)
  def get_only_scanners(self, scanners):
      '''Return a flat list of the scanner objects found in scanners.

      Nested tuples/lists are flattened recursively. Any other object is
      kept only if its "is_scanner" attribute is true; objects without
      that attribute (plain parameters) are dropped.
      '''
      ret=[]
      for scanner in scanners:
        if type(scanner) in (tuple, list):
          ret.extend(self.get_only_scanners(scanner))
        elif getattr(scanner, 'is_scanner', False):
          # a missing attribute means "not a scanner"; any other error
          # raised by the object is no longer silently hidden
          ret.append(scanner)
      return ret
  def destroy(self):
      '''Call destroy() on every sub-scanner.'''
      for scnr in self.scanners:
        scnr.destroy()
  def param(self,key,value=None):
      '''Return the first true result of scnr.param(key, value) over the
      sub-scanners (later sub-scanners are not called), or '' if none
      returns a true value.'''
      for scnr in self.scanners:
        n=scnr.param(key,value)
        if n:
          return n
      return ''
  def help(self):
      '''Merge the help() dictionaries of all sub-scanners; on duplicate
      keys the later sub-scanner wins.'''
      h={}
      for scnr in self.scanners:
        h.update(scnr.help())
      return h
  def reinit(self):
      '''Call reinit() on every sub-scanner.'''
      for scnr in self.scanners:
        scnr.reinit()
  def get(self,var):
      '''Return attribute var of this object if it exists and is true.

      Otherwise return the concatenation (with "+") of scnr.get(var)
      over all sub-scanners. Sub-scanners raising AttributeError are
      skipped. If no sub-scanner provides a value, the object's own
      (empty) value is returned, or '' if it has no such attribute.
      '''
      missing=object()
      own=getattr(self, var, missing)
      if own is not missing and own:
        return own
      r=missing
      for scnr in self.scanners:
        try:
          value=scnr.get(var)
        except AttributeError:
          # this sub-scanner does not provide var; keep what we have
          continue
        if r is missing:
          r=value
        else:
          # build a new object: "+=" would extend a list owned by a
          # sub-scanner (or by self) in place
          r=r+value
      if r is missing:
        return '' if own is missing else own
      return r
  def rcpt_signature(self,rcpt):
      '''Return "<name>(<sub-signature>,...)" built from this scanner's
      name and the sub-scanners' rcpt_signature(rcpt) values.'''
      return '%s(%s)' % (
        self.name.split('(',1)[0], # only its name, without parameters
        ",".join([x.rcpt_signature(rcpt) for x in self.scanners])
      )

class match_all(match_class):
  '''
  Returns a virus only if all scanners returns a virus.

  Sub-scanners are run in order; the first one returning no virus name
  stops the scan and a clean result (0.0, b'', []) is returned.
  Otherwise the levels are summed, the virus names are joined with
  b',' and the reports are concatenated.

  Usage: match_all(scanner1(),scanner2(),...)

  Where: scanner1(),... are scanners, which must find a virus
  '''
  name='match_all()'
  def scanbuffer(self, buffer, args={}):
      total_level = 0.0
      found = b''
      reports = []
      for scnr in self.scanners:
        # remember the currently running sub-scanner on the instance
        self.scanner = scnr
        scnr.prescan()
        sub_level, sub_vir, sub_ret = scnr.scanbuffer(buffer, args)
        scnr.postscan(sub_level, sub_vir, sub_ret)
        if not sub_vir:
          # a single clean sub-result makes the whole result clean
          return 0.0, b'', []
        if not found:
          found = sub_vir
        else:
          found += b',' + sub_vir
        reports += sub_ret
        total_level += sub_level
      return total_level, found, reports

class alternatives(match_class):
  '''
  Alternative scanners.

  Returns the result of the first scanner which finishes without an
  error, whether it reports a virus or clean. The next scanner from the
  list is used only if the previous one fails (raises an error).
  Raises ScannerError if all scanners failed.

  KeyboardInterrupt and SystemExit are not treated as scanner failures;
  they are propagated immediately.

  Usage: alternatives(scanner1(),scanner2(),...)

  Where: scanner1(),... are scanners, which may find a virus
  '''
  name='alternatives()'
  def scanbuffer(self, buffer, args={}):
      for self.scanner in self.scanners:
        try:
          self.scanner.prescan()
          level, vir, ret = self.scanner.scanbuffer(buffer, args)
          self.scanner.postscan(level, vir, ret)
          return level, vir, ret
        except (KeyboardInterrupt, SystemExit):
          # a shutdown request is not a scanner failure: do not fall
          # through to the next alternative
          raise
        except BaseException:
          # any other error: log it and try the next scanner
          debug.echo(3,self.name,": scanner ",self.scanner.name," failed: ",
                debug.traceback_value_str().rstrip())
          debug.traceback(4, self.name)
      debug.echo(3, self.name, ": All scanners failed!")
      raise ScannerError('All alternatives failed!')

class match_any(match_class):
  '''
  Match for any sub-scanner is required.

  Sub-scanners are run in order until one of them reports an infection
  (is_infected(level) is true); that scanner's result is returned.
  If no scanner reports an infection, the result of the last scanner is
  returned, or (0.0, b'', []) if there are no scanners.

  If a scanner fails (raises an error), the error is logged and
  re-raised; the remaining scanners are not run. Use recover() to
  continue with the next scanner after a failure.

  Usage: match_any(scanner1(),scanner2(),...)

  Where: scanner1(),... are scanners, which may find a virus
  '''
  name='match_any()'
  def scanbuffer(self, buffer, args={}):
      # returned unchanged when there are no sub-scanners
      level, vir, ret = 0.0, b'', []
      for self.scanner in self.scanners:
        try:
          self.scanner.prescan()
          level, vir, ret = self.scanner.scanbuffer(buffer, args)
          self.scanner.postscan(level, vir, ret)
          if is_infected(level):
            return level, vir, ret
        except BaseException:
          # log which sub-scanner failed, then propagate the error
          # (retry() depends on receiving it)
          debug.echo(3, self.name,": scanner ", self.scanner.name, " failed: ",
                     debug.traceback_value_str().rstrip())
          raise
      # no infection reported: result of the last scanner
      return level, vir, ret

class retry(match_any):
  '''
  Retry scanner more times.

  Runs the defined scanner (through match_any) up to count times, until
  one run finishes without raising an error. The result of that
  successful run is returned.

  Raises ScannerError if all count attempts failed; the last failure is
  attached as its __cause__. KeyboardInterrupt and SystemExit are not
  retried, they are propagated immediately.

  Usage: retry(count,scanner())

  Where: count is an integer, which defines retry count
         scanner(),... is a scanner, which may find a virus

  Example: retry(5,
             alternatives(
               spamassassind(...),
               spamassassind(...)
             )
           )
  '''
  name='retry()'
  def __init__(self, count, scanner):
      self.COUNT = count
      match_any.__init__(self, [scanner])
  def scanbuffer(self, buffer, args={}):
      last_error = None
      for _ in range(self.COUNT):
        try:
          return match_any.scanbuffer(self, buffer, args)
        except (KeyboardInterrupt, SystemExit):
          # shutdown requests must not be retried or converted
          raise
        except BaseException as err:
          # keep the failure so it can be chained to the final error
          last_error = err
          debug.echo(3,self.name,": scanner ",self.scanner.name," failed: ",
                debug.traceback_value_str().rstrip())
      debug.echo(3, self.name, ": All scanners failed")
      raise ScannerError('Retry count raised!') from last_error

class recover(match_any):
  '''
  Recover from an error.

  This scanner can be used to recover from scanner failures. Sub-scanners
  are run in order; a scanner which fails (raises an error) is logged
  and skipped. The result of the first scanner reporting an infection
  (is_infected(level) is true) is returned; if none does, no virus is
  returned (0.0, b'', []).

  KeyboardInterrupt and SystemExit are not recovered from; they are
  propagated immediately.

  Usage: recover(scanner1(),scanner2(),...)

  Where: scanner1... are scanners
  '''
  name='recover()'
  def scanbuffer(self, buffer, args={}):
      for self.scanner in self.scanners:
        try:
          self.scanner.prescan()
          level, vir, ret = self.scanner.scanbuffer(buffer, args)
          self.scanner.postscan(level, vir, ret)
          if is_infected(level):
            return level, vir, ret
        except (KeyboardInterrupt, SystemExit):
          # shutdown requests are not scanner failures
          raise
        except BaseException:
          # best effort: log the failure and continue with next scanner
          debug.echo(3,self.name,": scanner ",self.scanner.name," failed: ",
                     debug.traceback_value_str().rstrip())
      return 0.0, b'', []

class nothing(match_any):
  '''
  This scanner does nothing. :-)

  Usage: nothing(some parameters...)

  Where: some parameters... are ignored :)

  NOTE: only non-scanner parameters are ignored. Any scanner passed as
  a parameter is kept and run exactly as match_any() would run it.
  With no scanner parameters, scanbuffer() returns 0.0, b'', [].
  '''
  name='N()'
  def scanbuffer(self, buffer, args={}):
      # delegate to match_any; with no sub-scanners it returns
      # 0.0, b'', [] without calling anything
      return match_any.scanbuffer(self, buffer, args)
