Source code for cfl_data_utils.database._database

#!/usr/bin/env python
"""Extendable Database class for connecting to different database types within Chetwood Financial.

TODO
    * Implement time out functionality for reducing idle connections
"""

from os import path
from warnings import warn

from pandas import read_sql_query, DataFrame
from sqlalchemy import create_engine
from sqlalchemy.engine.url import URL

__author__ = 'Will Garside'
__email__ = 'worgarside@gmail.com'
__status__ = 'Production'


class Database:
    """Base database class

    This is to be used as a super class by different database connectors

    Attributes:
        db_name (str): Name of the database to connect to
        ssh_host (str): Host IP of the SSH tunnel
        ssh_username (str): User for access to SSH tunnel
        pkey_path (str): Path to the .pem file containing the PKey
        db_bind_address (str): Binding address for the database
        db_user (str): Database username
        db_password (str): Database password
        db_host (str): Host IP for the database, usually 127.0.0.1 after SSH tunneling
        db_port (int): DB port
        ssh_port (int): SSH tunnel port, defaults to 22
        stubbed (bool): Flag to stub database calls
        max_idle_time (int): Maximum time a connection can idle before being disconnected
    """

    def __init__(self, ssh_host=None, ssh_port=22, ssh_username=None, pkey_path=None, db_name=None,
                 db_bind_address=None, db_host='127.0.0.1', db_port=None, db_user=None, db_password=None,
                 stubbed=False, max_idle_time=30, disable_ssh_tunnel=False):
        """Inits Database and validates setup config"""
        self.conn = None
        self.cur = None
        self.server = None
        self.dialect = None
        self.driver = None
        self.error_state = False
        self.stubbed = stubbed
        self.max_idle_time = max_idle_time
        self.db_user = db_user
        self.db_password = db_password
        self.db_host = db_host
        self.db_port = db_port
        self.db_name = db_name
        self.db_bind_address = db_bind_address
        self.ssh_host = ssh_host
        self.ssh_port = int(ssh_port)
        self.ssh_username = ssh_username
        self.pkey_path = pkey_path
        self.required_creds = {}

        if self.stubbed:
            print(f'\033[31mWarning: Connection to \033[1m{self.db_name}\033[31m is stubbed.'
                  f' No data will be affected.\033[0m')

        self.setup()
        self.validate_setup(disable_ssh_tunnel)
        self.connect_to_db(disable_ssh_tunnel)
    def setup(self):
        """Allows each database subclass to be set up differently according to the relevant dialect and driver"""
    def connect_to_db(self, disable_ssh_tunnel):
        """Placeholder method for the actual connection to the DB

        Args:
            disable_ssh_tunnel (bool): Determines whether an SSH tunnel should be used
        """
    def validate_setup(self, disable_ssh_tunnel):
        """Validate the setup of the database: args, connections etc.

        Args:
            disable_ssh_tunnel (bool): Determines whether an SSH tunnel should be used
        """
        missing_params = [k for k, v in self.__dict__.items() if not v and k in self.required_creds]

        if missing_params:
            self.error_state = True
            raise TypeError(f"Database instance missing params: {missing_params}")

        if not disable_ssh_tunnel and self.pkey_path and not path.isfile(self.pkey_path):
            self.error_state = True
            raise IOError(f"{self.pkey_path} is not a valid file."
                          f" Make sure you have provided the absolute path")
    def query(self, sql, commit=True):
        """Executes a query using the DatabaseManager object

        Args:
            sql (str): The SQL query to be executed by the Cursor object
            commit (bool): Committing on queries can be disabled for rapid writing (e.g. one commit at the end)

        Returns:
            the Cursor object for continued usage
        """
        if self.stubbed:
            return self.cur

        self.cur.execute(sql)
        if commit:
            self.commit()
        return self.cur
    def df_from_query(self, stmt, index_col=None, coerce_float=True, params=None, parse_dates=None,
                      chunksize=None):
        """Generates a pandas DataFrame from an SQL query. All parameter defaults match those of read_sql_query.

        Args:
            stmt (str): The query to be executed by the Cursor object
            index_col (Union[Iterable[str], str]): Column(s) to set as index (MultiIndex)
            coerce_float (bool): Attempts to convert values of non-string, non-numeric objects (like
                decimal.Decimal) to floating point. Useful for SQL result sets.
            params (Union[List, Tuple, Dict]): List of parameters to pass to the execute method. The syntax
                used to pass parameters is database driver dependent.
            parse_dates (Union[List, Dict]):
                - List of column names to parse as dates.
                - Dict of ``{column_name: format string}`` where format string is strftime compatible in case
                  of parsing string times, or is one of (D, s, ns, ms, us) in case of parsing integer
                  timestamps.
                - Dict of ``{column_name: arg dict}``, where the arg dict corresponds to the keyword arguments
                  of :func:`pandas.to_datetime`. Especially useful with databases without native Datetime
                  support, such as SQLite.
            chunksize (int): If specified, return an iterator where `chunksize` is the number of rows to
                include in each chunk.

        Returns:
            DataFrame object
        """
        if self.stubbed:
            return DataFrame()

        df: DataFrame = read_sql_query(stmt, self.conn, index_col=index_col, coerce_float=coerce_float,
                                       params=params, parse_dates=parse_dates, chunksize=chunksize)

        return df
    def df_to_table(self, df, name, schema=None, if_exists='fail', index=True, index_label=None,
                    chunksize=None, dtype=None, method=None):
        """Write records stored in a DataFrame to a SQL database.

        Args:
            df (DataFrame): DataFrame to convert
            name (str): Name of table to be created
            schema (str): Specify the schema (if database flavor supports this)
            if_exists (str): How to behave if the table already exists
            index (bool): Write DataFrame index as a column. Uses index_label as the column name in the table.
            index_label (Union[str, List]): Column label for index column(s). If None is given (default) and
                index is True, then the index names are used. A sequence should be given if the DataFrame uses
                MultiIndex.
            chunksize (int): Rows will be written in batches of this size at a time. By default, all rows will
                be written at once.
            dtype (dict): Specifies the datatype for columns. The keys should be the column names and the
                values should be the SQLAlchemy types or strings for the sqlite3 legacy mode.
            method (Union[None, str, callable]): Controls the SQL insertion clause used:
                - None: Uses standard SQL INSERT clause (one per row).
                - 'multi': Pass multiple values in a single INSERT clause.
                - callable with signature (pd_table, conn, keys, data_iter)
        """
        if self.stubbed:
            return

        if if_exists not in {None, 'fail', 'replace', 'append'}:
            warn(f"Parameter if_exists has invalid value: {if_exists}. "
                 f"Should be one of {{None, 'fail', 'replace', 'append'}}")
            if_exists = None

        if isinstance(method, str) and not method == 'multi':
            warn(f"Parameter method has invalid value: {method}. Should be one of {{None, 'multi', callable}}")
            method = None

        engine = create_engine(
            URL(
                drivername=f'{self.dialect}+{self.driver}',
                username=self.db_user,
                password=self.db_password,
                host=self.db_host,
                port=self.db_port,
                database=self.db_name
            )
        )

        df.to_sql(name=name, con=engine, schema=schema, if_exists=if_exists, index=index,
                  index_label=index_label, chunksize=chunksize, dtype=dtype, method=method)
    def executemany(self, stmt: str, data):
        """Executes a statement against every parameter sequence in `data`

        Args:
            stmt (str): The query to be executed by the Cursor object
            data (Union[List, Tuple]): The data to be processed

        Returns:
            the Cursor object for continued usage
        """
        if self.stubbed:
            return self.cur

        self.cur.executemany(stmt, data)
        return self.cur
    def commit(self):
        """Commits all changes to the database

        Returns:
            the Cursor object for continued usage
        """
        self.conn.commit()
        return self.cur
    def disconnect(self, silent: bool = False):
        """Disconnects from the database and the SSH tunnel

        Args:
            silent (bool): suppresses final print statement
        """
        self.conn.close()

        if self.server is not None:
            self.server.close()

        if not silent:
            print(f"{self.db_name if self.db_name else 'Database'} disconnected.")
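
A minimal usage sketch follows, showing how a subclass is expected to override setup() and connect_to_db() and then drive the base-class helpers. The SQLiteDatabase name, the standard-library sqlite3 driver, the query strings and the example.db path are illustrative assumptions only and are not part of this module; the real Chetwood Financial connectors may differ.

from sqlite3 import connect  # assumed driver for this sketch only


class SQLiteDatabase(Database):
    """Hypothetical connector used purely to illustrate the subclassing pattern."""

    def setup(self):
        # Declare the SQLAlchemy dialect/driver used by df_to_table, plus the
        # credentials this connector cannot run without; validate_setup()
        # raises a TypeError if any entry in required_creds is missing.
        self.dialect = 'sqlite'
        self.driver = 'pysqlite'
        self.required_creds = {'db_name'}

    def connect_to_db(self, disable_ssh_tunnel):
        # SQLite is file-based, so the SSH tunnel arguments are ignored here.
        self.conn = connect(self.db_name)
        self.cur = self.conn.cursor()


if __name__ == '__main__':
    db = SQLiteDatabase(db_name='example.db', disable_ssh_tunnel=True)
    db.query("CREATE TABLE IF NOT EXISTS person (id INTEGER PRIMARY KEY, name TEXT)")
    db.executemany("INSERT INTO person (name) VALUES (?)", [('Ada',), ('Grace',)])
    db.commit()
    print(db.df_from_query("SELECT * FROM person"))
    db.disconnect()

Constructing with stubbed=True instead would make query(), executemany() and df_from_query() return without touching the database, which appears to be how the base class supports dry runs.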