'''
db.py - Database connection module.

(c) 2005-2016,2019 Jan ONDREJ (SAL) <ondrejj(at)salstar.sk>

 This program is free software; you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
 the Free Software Foundation; either version 2 of the License, or
 (at your option) any later version.

'''

from avlib import debug, safe, tostr_list
import sys, time, re

# Names of the connection classes this module provides, one per backend.
classes = [
  'MySQLdb',
  'pymysql',
  'sqlite',
  'pg',
  'pgdb',
  'psycopg',
]

class dbconn:
  '''
  Base class for database connections.

  A subclass passes a DB-API module and a dict of connection parameters
  to __init__ and implements connect(), which must store the open
  connection in self.DBC.
  '''
  name = 'connection_template'
  params = ('host', 'port', 'dbname', 'dbuser', 'dbpasswd')
  next_refresh = 0
  refresh_time = 60 # 60 seconds
  # pyformat placeholder such as "%(name)s"; rewritten to "?" for qmark drivers
  _pyformat_re = re.compile(r'%\((.+?)\)s')
  def __init__(self, db=None, params=None):
      '''
      db     - DB-API module; its IntegrityError and paramstyle are used
      params - dict of connection parameters read by connect()
      '''
      self.DB = db
      # fresh dict per instance (a shared [] default had no .get())
      self.PARAMS = {} if params is None else params
      self.IntegrityError = db.IntegrityError
  def __del__(self):
      self.close()
  def refresh(self):
      '''Reopen the connection if refresh_time seconds have passed.'''
      t0 = time.time()
      if t0 >= self.next_refresh:
        debug.echo(6, "%s: Database connection refresh" % self.name)
        self.close()
        self.connect()
        self.next_refresh = t0+self.refresh_time
  def execute(self, sql_cmd, sql_args=[], commit=True):
      '''
      Execute one SQL statement and return its cursor.

      For 'qmark' drivers, "%s" placeholders become "?".  When sql_args is
      a dict, "%(name)s" placeholders also become "?" and sql_args is
      converted to a tuple in placeholder order.
      On failure (other than IntegrityError) the connection is reopened
      and the statement retried once.  Sets self.rowcount; commits
      afterwards when commit is true.
      '''
      if self.DB.paramstyle == 'qmark':
        sql_cmd = sql_cmd.replace("%s", "?")
        # try to replace pyformat
        if isinstance(sql_args, dict):
          names = self._pyformat_re.findall(sql_cmd)
          if names:
            sql_args = tuple([sql_args.get(x) for x in names])
            sql_cmd = self._pyformat_re.sub('?', sql_cmd)
      debug.echo(9, sql_cmd, " ", [tostr_list(sql_args)])
      try:
        cur = self._run(sql_cmd, sql_args)
      except self.IntegrityError:
        # a constraint violation will not be cured by reconnecting
        raise
      except Exception:
        # connection not opened yet or gone stale: reconnect, retry once
        self.connect()
        cur = self._run(sql_cmd, sql_args)
      self.rowcount = cur.rowcount
      if commit:
        self.DBC.commit()
      return cur
  def _run(self, sql_cmd, sql_args):
      '''Execute sql_cmd on a new cursor of self.DBC and return the cursor.'''
      cur = self.DBC.cursor()
      if sql_args == []:
        cur.execute(sql_cmd)
      else:
        cur.execute(sql_cmd, sql_args)
      return cur
  def execute_cycle(self, sql_cmd, sql_args=[], commit=True, times=5):
      '''
      Call execute() up to `times` times (at least once), sleeping a
      little longer after each failure.  IntegrityError is re-raised at
      once; otherwise the exception of the last attempt is re-raised.
      '''
      last_exc = None
      for counter in range(max(times, 1)):
        try:
          return self.execute(sql_cmd, sql_args, commit)
        except self.IntegrityError:
          raise
        except Exception as es:
          debug.traceback(9, "db connection:")
          time.sleep(0.1*counter)
          last_exc = es
      raise last_exc
  def query(self, sql_cmd, sql_args=[]):
      '''Execute sql_cmd without committing and return all fetched rows.'''
      r = self.execute(sql_cmd, sql_args, False).fetchall()
      # simulate rowcount if the driver does not report it (-1 or None)
      if self.rowcount is None or self.rowcount < 0:
        self.rowcount = len(r)
      return r
  def commit(self):
      try:
        self.DBC.commit()
      except AttributeError:
        pass # self.DBC is not defined
  def quote(self, s):
      '''Escape s for use inside a single-quoted SQL string literal.'''
      return s.replace("'", "''").replace('\\', '\\\\')
  def close(self):
      '''Best-effort commit and close; errors (or no connection) are ignored.'''
      try:
        self.DBC.commit()
      except Exception:
        pass
      try:
        self.DBC.close()
      except Exception:
        pass

class MySQLdb(dbconn):
  '''
  MySQL database connection.
  
  Requires: MySQLdb python module

  Usage: db.MySQLdb(host='127.0.0.1',port=3306,dbname='sagator',
                    dbuser='sagator',dbpasswd='xxxxxx')

  Parameters:
	host		- database server hostname or unix socket path
	port		- tcp socket port for database server
	dbname		- database name
	dbuser		- authorization username
	dbpasswd	- authorization password
	charset		- connection character set (default: utf8)

  Database creation script:	scripts/db/mysql.sh
  '''
  name = 'MySQLdb'
  params = ('host', 'port', 'dbname', 'dbuser', 'dbpasswd')
  def __init__(self, *args, **kwargs):
      # parameters come as one positional dict or as keyword arguments
      sqlmod = self.module_import()
      if args:
        dbconn.__init__(self, sqlmod, args[0])
      else:
        dbconn.__init__(self, sqlmod, kwargs)
      # Try to connect right away; a failure is not fatal here because
      # execute() reconnects when its first attempt fails.
      try: self.connect()
      except Exception: pass
  def module_import(self):
      '''Import and return the driver module; exit the program if missing.'''
      try:
        import MySQLdb, encodings.latin_1, encodings.utf_8
        return MySQLdb
      except ImportError as es:
        sys.exit("Error importing MySQLdb module for python: "+str(es))
  def connect(self):
      '''
      Open self.DBC.  A host starting with "/" is used as a unix socket
      path, anything else as a TCP host together with port.  Parameters
      that are unset (None) are not passed, leaving the driver defaults.
      '''
      host = self.PARAMS.get('host')
      kwargs = {
        'db': self.PARAMS.get('dbname'),
        'user': self.PARAMS.get('dbuser'),
        'passwd': self.PARAMS.get('dbpasswd'),
        'connect_timeout': 5,
        'charset': self.PARAMS.get('charset', 'utf8'),
      }
      # host[0] used to raise TypeError/IndexError for a missing/empty host
      if host and host.startswith('/'):
        kwargs['unix_socket'] = host
      else:
        kwargs['host'] = host
        kwargs['port'] = self.PARAMS.get('port')
      self.DBC = self.DB.connect(
        **{k: v for k, v in kwargs.items() if v is not None}
      )
  def quote(self, s):
      # NOTE(review): with mysqlclient on Python 3, escape_string() may
      # return bytes rather than str - confirm what callers expect.
      return self.DB.escape_string(s)

class pymysql(MySQLdb):
  # update docs by replacing MySQLdb to pymysql
  __doc__ = MySQLdb.__doc__.replace("MySQLdb", "pymysql")
  name = "pymysql"
  def module_import(self):
      '''Import and return the pymysql module; exit the program if missing.'''
      try:
        import pymysql, encodings.latin_1, encodings.utf_8
        return pymysql
      except ImportError as es:
        sys.exit("Error importing pymysql module for python: "+str(es))
  def quote(self, s):
      '''Escape s for use inside a quoted SQL string literal.'''
      escape = getattr(self.DB, 'escape_string', None)
      if escape is None:
        # PyMySQL 1.0 removed escape_string() from the package root;
        # it is still available in pymysql.converters.
        from pymysql.converters import escape_string as escape
      return escape(s)

class sqlite(dbconn):
  '''
  SQLite database connection.

  Requires: sqlite python module

  Usage: db.sqlite(dbname='/var/lib/sagator/sqlitedb')

  Parameters:
	dbname		- database name

  Database creation script:	scripts/db/sqlite.sh
  '''
  name = 'sqlite'
  # one-element tuple; ('dbname') without the comma is just a string
  params = ('dbname',)
  def __init__(self, *args, **kwargs):
      try:
        # standard library sqlite3 module
        import sqlite3 as sqlite
      except ImportError as es0:
        try:
          # pysqlite 2.x version
          from pysqlite2 import dbapi2 as sqlite
        except ImportError as es1:
          try:
            # pysqlite 1.x version
            import sqlite
          except ImportError as es2:
            sys.exit("Error importing SQLite module for python: %s, %s or %s" \
                     % (str(es0), str(es1), str(es2)))
      if args:
        dbconn.__init__(self, sqlite, args[0])
      elif kwargs:
        dbconn.__init__(self, sqlite, kwargs)
      else:
        dbconn.__init__(self, sqlite, {'dbname': '/var/lib/sagator/sqlitedb'})
      # No connection is opened here; execute() calls connect() when its
      # first attempt fails because self.DBC does not exist yet.
  def connect(self):
      '''Open self.DBC on the database file named by the dbname parameter.'''
      self.DBC = self.DB.connect(
        safe.fn(self.PARAMS.get('dbname'))
      )

class pg(dbconn):
  '''
  PostgreSQL support via pg python module.

  Requires: pg python module

  Usage: db.pg(host='127.0.0.1',port=5432,dbname='sagator',
               dbuser='sagator',dbpasswd='xxxxxx')

  Parameters:
	host		- database server hostname
	port		- tcp socket port for database server
	dbname		- database name
	dbuser		- authorization username
	dbpasswd	- authorization password

  Database creation script:	scripts/db/pgsql.sh
  '''
  name = 'pg'
  params = ('host', 'port', 'dbname', 'dbuser', 'dbpasswd')
  def __init__(self, *args, **kwargs):
      try:
        import pg
      except ImportError as es:
        sys.exit("Error importing PostgreSQL module for python: "+str(es))
      if args:
        dbconn.__init__(self, pg, args[0])
      else:
        dbconn.__init__(self, pg, kwargs)
      # No connection is opened here; execute() calls connect() when its
      # first attempt fails because self.DBC does not exist yet.
  def connect(self):
      '''
      Open self.DBC with pg.connect().  Parameters that are unset (None)
      are not passed, so pg's own defaults apply (e.g. for a missing port).
      '''
      kwargs = {
        'dbname': self.PARAMS.get('dbname'),
        'host': self.PARAMS.get('host'),
        'port': self.PARAMS.get('port'),
        'user': self.PARAMS.get('dbuser'),
        'passwd': self.PARAMS.get('dbpasswd'),
      }
      self.DBC = self.DB.connect(
        **{k: v for k, v in kwargs.items() if v is not None}
      )

class pgdb(dbconn):
  '''
  PostgreSQL support via pgdb python module (PyGreSQL DB-API interface).

  Requires: pgdb python module (part of PyGreSQL)

  Usage: db.pgdb(host='127.0.0.1',port=5432,dbname='sagator',
                 dbuser='sagator',dbpasswd='xxxxxx')

  Parameters:
	host		- database server hostname
	port		- tcp socket port for database server
	dbname		- database name
	dbuser		- authorization username
	dbpasswd	- authorization password

  Database creation script:	scripts/db/pgsql.sh
  '''
  name = 'pgdb'
  params = ('host', 'port', 'dbname', 'dbuser', 'dbpasswd')
  def __init__(self, *args, **kwargs):
      try:
        import pgdb
      except ImportError as es:
        sys.exit("Error importing PostgreSQL module for python: "+str(es))
      if args:
        dbconn.__init__(self, pgdb, args[0])
      else:
        dbconn.__init__(self, pgdb, kwargs)
      # No connection is opened here; execute() calls connect() when its
      # first attempt fails because self.DBC does not exist yet.
  def connect(self):
      '''
      Open self.DBC.  pgdb takes the port as part of host ("host:port");
      it is appended only when both are set, so a missing value no longer
      produces "host:None" or "None:port".
      '''
      host = self.PARAMS.get('host')
      port = self.PARAMS.get('port')
      if host and port:
        host = "%s:%s" % (host, port)
      self.DBC = self.DB.connect(
        user=self.PARAMS.get('dbuser'),
        password=self.PARAMS.get('dbpasswd'),
        host=host,
        database=self.PARAMS.get('dbname')
      )

class psycopg(dbconn):
  '''
  PostgreSQL support via psycopg python module.

  Requires: psycopg python module

  Usage: db.psycopg(host='127.0.0.1',port=5432,dbname='sagator',
                    dbuser='sagator',dbpasswd='xxxxxx')

  Parameters:
	host		- database server hostname
	port		- tcp socket port for database server (default: 5432)
	dbname		- database name
	dbuser		- authorization username
	dbpasswd	- authorization password

  Database creation script:	scripts/db/pgsql.sh
  '''
  name = 'psycopg'
  params = ('host', 'port', 'dbname', 'dbuser', 'dbpasswd')
  def __init__(self, *args, **kwargs):
      try:
        import psycopg
      except ImportError as es:
        sys.exit("Error importing psycopg module for python: "+str(es))
      if args:
        dbconn.__init__(self, psycopg, args[0])
      else:
        dbconn.__init__(self, psycopg, kwargs)
      # No connection is opened here; execute() calls connect() when its
      # first attempt fails because self.DBC does not exist yet.
  def connect(self):
      '''
      Open self.DBC from a libpq "key=value" connection string.  Each value
      is quoted and escaped, so spaces or quotes (e.g. in a password) do
      not break the string.  Parameters that are unset (None) are left out
      instead of being sent as the text "None".
      '''
      fields = (
        ('host', self.PARAMS.get('host')),
        ('port', int(self.PARAMS.get('port') or 5432)),
        ('dbname', self.PARAMS.get('dbname')),
        ('user', self.PARAMS.get('dbuser')),
        ('password', self.PARAMS.get('dbpasswd')),
      )
      self.DBC = self.DB.connect(" ".join(
        "%s=%s" % (key, self._dsn_quote(value))
        for key, value in fields if value is not None
      ))
  @staticmethod
  def _dsn_quote(value):
      '''Quote one conninfo value: wrap it in single quotes, escaping
      backslashes and single quotes with a backslash.'''
      return "'%s'" % str(value).replace('\\', '\\\\').replace("'", "\\'")
