"""
PostgreSQL database utilities.

Wrapper for psycopg2 that handles connection pooling, transactions, cursors, and makes the API
easier to deal with.
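
Typical usage (a minimal sketch; the connection details and the 'foo' table are
placeholder assumptions, not part of this module):

    configure_pool(minconn=1, maxconn=5,
                   host='localhost', dbname='mydb', user='myuser')

    for row in fetchmany('select id, name from foo'):
        print(row['id'], row['name'])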
"""
from collections import defaultdict
import contextlib
import functools
import logging
import re
import threading
import time
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union  # Generator[YieldType, SendType, ReturnType]

import psycopg2
import psycopg2.extensions
import psycopg2.pool
import psycopg2.extras
from tqdm.auto import tqdm
from deprecated import deprecated

LOG = logging.getLogger(__name__)
FORMAT_SQL_QUERY_PAT = re.compile(r'(?:\s|\t|\n){2,}')

_default_pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None
_default_pool_args: Optional[Tuple] = None


# TODO: rename to configure_default_pool
def configure_pool(*args, **kwargs) -> None:
    """
    Configure default connection pool options. These options will be set lazily,
    then used when get_pool() is called for the first time.

    If your application needs to use the default connection pool (i.e. calls get_pool())
    then you MUST call this function during app initialization, before get_pool().

    For available pool args/kwargs, see:
    https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS
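
    Example (a sketch; the pool sizes and connection keyword arguments are placeholders
    and are passed straight through to psycopg2.pool.ThreadedConnectionPool):

        configure_pool(minconn=1, maxconn=10,
                       host='localhost', dbname='mydb',
                       user='myuser', password='secret')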

    :param args: pool args.
    :param kwargs: pool keyword args.
    """
    global _default_pool_args

    if _default_pool_args is not None:
        LOG.warning('configure_pool() should only be called once')

    _default_pool_args = (args, kwargs)


# TODO: rename to get_default_pool
def get_pool() -> psycopg2.pool.ThreadedConnectionPool:
    """
    Get the default database connection pool.

    This pool is the default connection pool used by all functions in this module,
    where a connection isn't explicitly passed as an argument.

    You generally don't need to call this directly to get connections, and should
    almost always use the connection() context manager instead (so that your connection
    is automatically returned to the pool when you're finished).

    :return: default connection pool.
    """
    global _default_pool

    if _default_pool is None:
        if _default_pool_args is None:
            raise RuntimeError('You must call configure_pool() before get_pool()')

        _default_pool = psycopg2.pool.ThreadedConnectionPool(*_default_pool_args[0],
                                                             **_default_pool_args[1])

    return _default_pool


def get_conn(pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None,
             retry_limit: Optional[int] = 5) -> \
        Tuple[psycopg2.extensions.connection, psycopg2.pool.ThreadedConnectionPool]:
    """
    Checks out a database connection from a pool and tests it to ensure it's working.
    If the connection is bad, it will be discarded and a new connection will be attempted
    up to retry_limit times.

    You generally don't need to call this directly to get connections, and should
    almost always use the connection() context manager instead (so that your connection
    is automatically returned to the pool when you're finished).

    :param pool: optional connection pool to use (if unspecified, the default pool will be used).
    :param retry_limit: max number of times to get a connection from the pool if it is bad.
                        If None, the limit is infinite.
    :return: tuple of:
             - connection that was fetched from the pool.
             - pool the connection came from.
    """
    # TODO: Handle PoolError('connection pool exhausted') with optional blocking
    if pool is None:
        pool = get_pool()

    @retry((psycopg2.Error,), retry_limit)
    def try_get_conn() -> psycopg2.extensions.connection:
        conn = pool.getconn()

        # Test connection to ensure it's alive
        cur = conn.cursor()
        try:
            cur.execute('select 1')
        except psycopg2.Error:
            # Connection is bad: return it to the pool to be discarded, then re-raise
            # so the retry decorator checks out a fresh connection.
            pool.putconn(conn, close=True)
            raise
        finally:
            if not cur.closed:
                cur.close()

        return conn

    return try_get_conn(), pool


@contextlib.contextmanager
def connection(pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None) -> \
        Generator[psycopg2.extensions.connection, None, None]:
    """
    Context manager that returns a connection, then cleans it up and returns it to the pool
    when finished. Use this instead of get_conn().
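
    Example (a minimal sketch):

        with connection() as conn:
            with conn.cursor() as cur:
                cur.execute('select 1')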

    :param pool: optional connection pool to use (will fallback to default pool if unspecified).
    :yield: connection.
    """
    conn, pool = get_conn(pool)
    try:
        yield conn
    finally:
        if conn is not None:
            pool.putconn(conn)


@contextlib.contextmanager
def _ensure_connection(conn: Optional[psycopg2.extensions.connection] = None) -> \
        Generator[psycopg2.extensions.connection, None, None]:
    """
    Internal context manager that will ensure the block is supplied with a user-defined or default
    connection.

    :param conn: optional user-defined connection (if unspecified, default connection will be used).
    :yield: connection.
    """
    if conn:
        yield conn
    else:
        with connection() as conn:
            yield conn


class cursor:
    """
    Context manager for a database cursor that has the following behavior (see the example below):

    - Yields a DictCursor that returns dict-like rows.
    - Handles transaction behavior:
        - Commits transaction upon exit (or rolls back if there was an exception).
        - When nested (and when a user-defined connection is passed), supports nested
          transaction behavior using savepoints.
    - Closes the cursor when finished.
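
    Example (a sketch; 'foo' is a placeholder table):

        with connection() as conn:
            with cursor(conn=conn) as cur:                 # outermost: commits on exit
                cur.execute('insert into foo (id) values (1)')
                with cursor(conn=conn) as nested_cur:      # nested: records a savepoint on success
                    nested_cur.execute('insert into foo (id) values (2)')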
    """
    use_savepoints: bool = True

    class ConnOpts:
        """
        Per-connection options.
        """
        def __init__(self):
            self.nest_level = 0  # Nested transaction level
            self.thread_lock = threading.Lock()  # Thread-safe access to vars

    _conn_opts: Dict[psycopg2.extensions.connection, ConnOpts] = defaultdict(ConnOpts)

    def __init__(self,
                 conn: Optional[psycopg2.extensions.connection] = None,
                 pool: Optional[psycopg2.pool.ThreadedConnectionPool] = None,
                 **cursor_args):
        """
        :param conn: optional connection (if unspecified, default connection will be used).
        :param pool: optional connection pool (if conn is None and you need to pass a custom pool).
        """
        if conn is not None and pool is not None:
            raise ValueError('conn and pool are mutually exclusive')

        if conn is None:
            conn, pool = get_conn(pool)
            self._is_managed_connection = True
        else:
            self._is_managed_connection = False

        self._conn = conn
        self._pool = pool
        self._opts = self._conn_opts[conn]  # TODO: convert keys to weakref.ref()

        self._cursor_args = cursor_args
        if 'cursor_factory' not in self._cursor_args:
            self._cursor_args['cursor_factory'] = psycopg2.extras.DictCursor

    def __enter__(self) -> psycopg2.extensions.cursor:
        """
        :return: cursor.
        """
        with self._opts.thread_lock:
            self._opts.nest_level += 1

        self._cur = self._conn.cursor(**self._cursor_args)
        if self._cur.closed:
            raise RuntimeError('Connection returned a closed cursor')

        return self._cur

    def __exit__(self,
                 exc_type: Optional[Type[BaseException]],
                 exc_val: Exception,
                 exc_tb):
        with self._opts.thread_lock:
            nest_level = self._opts.nest_level

        try:
            if exc_type is None:
                # Success
                if not self._is_managed_connection and nest_level > 1:
                    # Nested transaction (create new savepoint)
                    if self.use_savepoints:
                        with _ensure_cursor(self._cur) as cur:
                            cur.execute(f'savepoint level_{nest_level}')
                else:
                    # Topmost transaction (commit transaction)
                    self._conn.commit()
            else:
                # Failure
                LOG.error(f'Rolling back {"to previous savepoint" if self.use_savepoints and nest_level > 2 else "transaction"} because of {exc_type} error: {exc_val!r}', exc_info=exc_val)
                if not self.use_savepoints or nest_level <= 2:
                    # Level 1 or 2 transaction (roll back the whole transaction; there is no previous savepoint to roll back to)
                    try:
                        self._conn.rollback()
                    except psycopg2.Error as ex:
                        LOG.error(f'Rollback failed because of error: {ex!r}', exc_info=True)
                elif self.use_savepoints and not self._is_managed_connection:
                    # Nested transaction (roll back to previous savepoint)
                    try:
                        with cursor(conn=self._conn) as cur:  # Create new cursor (current may now be invalid)
                            cur.execute(f'rollback to savepoint level_{nest_level - 1}')
                    except psycopg2.Error as ex:
                        # Compound failure (rollback entire transaction to be safe)
                        LOG.error(f'Falling back to transaction rollback because savepoint rollback failed: {ex!r}', exc_info=True)
                        try:
                            self._conn.rollback()
                        except psycopg2.Error as ex:
                            LOG.error(f'Fallback rollback failed because of error: {ex!r}', exc_info=True)
        finally:
            # Cleanup
            if not self._cur.closed:
                self._cur.close()

            with self._opts.thread_lock:
                self._opts.nest_level -= 1

            if self._is_managed_connection and self._conn is not None:
                self._pool.putconn(self._conn)

        return False  # Re-raise exceptions


@contextlib.contextmanager
def _ensure_cursor(cur: Optional[psycopg2.extensions.cursor] = None, **cursor_args) -> \
        Generator[psycopg2.extensions.cursor, None, None]:
    """
    Internal context manager that will ensure the block is supplied with a user-defined or default
    cursor.

    :param cur: optional user-defined cursor (if unspecified, a default cursor/connection will be used).
    :yield: cursor.
    """
    if cur:
        yield cur
    else:
        with cursor(**cursor_args) as cur:
            yield cur


def fetchmany(statement: str,
              params: Optional[Tuple] = None,
              use_tqdm: Union[bool, dict] = False,
              cur: Optional[psycopg2.extensions.cursor] = None) -> \
        Generator[psycopg2.extras.DictRow, None, int]:
    """
    Executes SQL and returns multi-row results.

    Example:

        for row in fetchmany('select id, name from foo where bar = %s', (some_var,)):
            print(row['id'])

    :param statement: SQL statement.
    :param params: SQL statement parameters.
    :param use_tqdm: whether or not to use a tqdm progress bar (can also be an options dict for tqdm).
    :param cur: optional user-defined cursor.
    :yield: row.
    :return: row count.
    """
    with _ensure_cursor(cur) as cur:
        query = cur.mogrify(statement, params)
        LOG.debug(f'SQL query: {_format_sql_query(query)}')
        cur.execute(query)

        if (use_tqdm is True or isinstance(use_tqdm, dict)) and cur.rowcount > 0:
            # Show progress bar
            tqdm_opts = {}
            if isinstance(use_tqdm, dict):
                tqdm_opts = use_tqdm

            progress = tqdm(total=cur.rowcount, **tqdm_opts)
        else:
            # Hide progress bar
            progress = None

        while True:
            rows = cur.fetchmany(cur.arraysize)
            if len(rows) == 0:
                break

            for row in rows:
                yield row

                if progress is not None:
                    progress.update()

        if progress is not None:
            progress.close()

        return cur.rowcount


def fetchone(statement: str,
             params: Optional[Tuple] = None,
             cur: Optional[psycopg2.extensions.cursor] = None) -> psycopg2.extras.DictRow:
    """
    Executes SQL and returns a single-row result.
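
    Example (a sketch; 'foo' is a placeholder table):

        row = fetchone('select name from foo where id = %s', (42,))
        if row is not None:
            print(row['name'])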

    :param statement: SQL statement.
    :param params: SQL statement parameters.
    :param cur: optional user-defined cursor.
    :return: row.
    """
    with _ensure_cursor(cur) as cur:
        query = cur.mogrify(statement, params)
        LOG.debug(f'SQL query: {_format_sql_query(query)}')
        cur.execute(query)

        return cur.fetchone()


def execute(statement: str,
            params: Optional[Union[Tuple, Dict]] = None,
            cur: Optional[psycopg2.extensions.cursor] = None) -> None:
    """
    Execute a database operation (query or command).

    Parameters may be provided as sequence or mapping and will be bound to variables in the operation.
    Variables are specified either with positional (%s) or named (%(name)s) placeholders.

    The method returns None. If a query was executed, the returned values can be retrieved using
    fetch*() methods.
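
    Example (a sketch; 'foo' is a placeholder table):

        execute('update foo set name = %(name)s where id = %(id)s',
                {'name': 'bar', 'id': 42})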

    :param statement: SQL statement.
    :param params: SQL statement parameters.
    :param cur: optional user-defined cursor.
    """
    with _ensure_cursor(cur) as cur:
        query = cur.mogrify(statement, params)
        LOG.debug(f'SQL query: {_format_sql_query(query)}')
        cur.execute(query)


def execute_values(statement: str,
                   values: List[Tuple],
                   template: Optional[str] = None,
                   page_size: Optional[int] = None,
                   fetch: bool = False,
                   cur: Optional[psycopg2.extensions.cursor] = None) -> \
        Generator[psycopg2.extras.DictRow, None, int]:
    """
    Execute a statement using VALUES with a sequence of parameters.
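
    Example (a sketch; 'mytable' matches the statement example below). Note that this
    function is a generator, so it must be iterated for the statement to run:

        for row in execute_values('insert into mytable (id, f1) values %s returning id',
                                  [(1, 'a'), (2, 'b')],
                                  fetch=True):
            print(row['id'])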

    :param statement: SQL statement to execute. It must contain a single %s placeholder, which will
                      be replaced by a VALUES list.
                      Example: "INSERT INTO mytable (id, f1, f2) VALUES %s".
    :param values: sequence of sequences or dictionaries with the arguments to send to the query.
                   The type and content must be consistent with template.
    :param template: the snippet to merge to every item in argslist to compose the query.
                     - If the argslist items are sequences it should contain positional placeholders
                       (e.g. "(%s, %s, %s)", or "(%s, %s, 42)” if there are constants value…).
                     - If the argslist items are mappings it should contain named placeholders
                       (e.g. "(%(id)s, %(f1)s, 42)").
                     If not specified, the arguments are assumed to be sequences and a simple positional
                     template is used (i.e. "(%s, %s, ...)"), with the number of placeholders sniffed from the first element in values.
    :param page_size: maximum number of argslist items to include in every statement.
                      If there are more items the function will execute more than one statement.
                      Defaults to the length of the values parameter.
    :param fetch: if True return the query results into a list (like in a fetchall()).
                  Useful for queries with RETURNING clause.
    :param cur: optional user-defined cursor.
    :yield: row (if fetch parameter is True).
    :return: row count (if fetch parameter is True).
    """
    # TODO: Add tqdm_opts parameter like with fetchmany()
    if page_size is None:
        page_size = len(values)

    with _ensure_cursor(cur) as cur:
        LOG.debug(f'SQL query: {_format_sql_query(statement.encode("utf-8"))} -> {values!r}')
        result = psycopg2.extras.execute_values(cur,
                                                statement,
                                                values,
                                                template=template,
                                                page_size=page_size,
                                                fetch=fetch)
        if fetch:
            for row in result:
                yield row

            return cur.rowcount


def upsert(row: Dict[str, Any],
           table_name: str,
           primary_key: List[str],
           include_keys: Optional[List[str]] = None,
           exclude_keys: Optional[List[str]] = None,
           cur: Optional[psycopg2.extensions.cursor] = None) -> Any:
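    """
    Inserts a row, or updates it if it already exists, using PostgreSQL's
    INSERT ... ON CONFLICT ... DO UPDATE ("upsert").

    Example (a sketch; 'foo' is a placeholder table with primary key 'id'):

        upsert({'id': 1, 'name': 'bar'}, 'foo', primary_key=['id'])

    :param row: mapping of column names to values.
    :param table_name: name of the target table.
    :param primary_key: column names that form the conflict target (usually the primary key).
    :param include_keys: optional subset of row keys to write (defaults to all keys in row).
    :param exclude_keys: optional row keys to skip.
    :param cur: optional user-defined cursor.
    :return: the resulting row (from the RETURNING clause).
    """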
    if include_keys is None:
        include_keys = row.keys()

    if exclude_keys is None:
        exclude_keys = []

    item_keys = [k for k in include_keys if k not in exclude_keys]

    with _ensure_cursor(cur) as cur:
        return fetchone(
            f'''
                insert into {table_name} ({', '.join(item_keys)}) 
                    values ({', '.join(['%s' for _ in item_keys])}) 
                    on conflict ({', '.join(primary_key)}) do update set 
                    {', '.join([f'{k} = excluded.{k}' for k in item_keys if k not in primary_key])}
                    returning *
            ''',
            tuple([row.get(key) for key in item_keys]),
            cur=cur)


def _format_sql_query(query: bytes) -> str:
    """
    Strips extra whitespace and newlines from SQL queries, so that they are easier to read in logs.

    :param query: query to format.
    :return: formatted query.
    """
    return FORMAT_SQL_QUERY_PAT.sub(' ', normalize_line_endings(query.decode('utf-8'))).strip()


def normalize_line_endings(s: str) -> str:
    """
    Converts various line ending characters/pairs into '\\n'.

    :param s: string with possibly abnormal line endings.
    :return: normalized string.
    """
    return s.replace('\r\n', '\n').replace('\r', '\n')


def retry(exc_types: Sequence[Type],
          max_attempts: Optional[int] = None,
          delay: int = 0,
          error_fn: Optional[Callable[[BaseException], None]] = None) -> Callable:
    """
    Decorator that automatically re-calls a function if it throws a set of expected exception types.
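
    Example (a sketch; flaky_query is a placeholder function):

        @retry((psycopg2.OperationalError,), max_attempts=3, delay=1)
        def flaky_query():
            return fetchone('select 1')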

    :param exc_types: exception classes to retry on.
    :param max_attempts: max number of attempts to retry before re-throwing.
                         If None, there is no limit.
    :param delay: optional time delay between retry attempts (in seconds).
    :param error_fn: optional function to call (with exception) when an error occurs.
    """
    def retry_decorator(f: Callable) -> Callable:
        def retryable_func(*args, **kwargs):
            attempt = 0
            while True:
                try:
                    return f(*args, **kwargs)
                except tuple(exc_types) as ex:
                    if error_fn is not None:
                        error_fn(ex)

                    attempt += 1
                    if max_attempts is not None and attempt >= max_attempts:
                        # Out of attempts: re-throw the last error
                        raise

                    LOG.warning(f'Retrying because of {ex.__class__.__name__} error: {ex!r}')
                    if delay > 0:
                        time.sleep(delay)

        return functools.wraps(f)(retryable_func)

    return retry_decorator
