Commit 9bb0b92b authored by Laura, Jason R.'s avatar Laura, Jason R.
Browse files

Merge branch 'dbretry' into 'main'

Centralizes retry logic into db.io.connection.py and adds retry logic to the NCG setup.

See merge request astrogeology/autocnet!702
parents 0279ec96 6a9bfeca
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -38,6 +38,10 @@ release.
### Added
- Debug logging to `place_points_in_overlap` and `distribute_points_in_geom` to make debugging issues easier.

### Changed
- cluster submission object loading to use [`joinedload('*')`](https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#wildcard-loading-strategies) in order to properly load and expunge measures on point objects. This change is at the query level (and not the mapper level).
- Acceptable affine residual bumped to 1px, better support for KaguyaTC data.

### Fixed
- Error in `find_interesting_feature` that was mis-using the ROI API. This bug was introduced in 1.2.0 when the ROI API updated.

+4 −2
Original line number Diff line number Diff line
@@ -213,9 +213,11 @@ def main(): # pragma: no cover
    args = vars(parse_args())
    # set up the logger
    logging.basicConfig(level=os.environ.get("AUTOCNET_LOGLEVEL", "INFO"))
    # Get the message
    # Get the message; 30s timeout should be long enough to handle minor
    # network issues or congestion without holding onto cluster resources
    # for an undue amount of time.
    queue = StrictRedis(host=args['host'], port=args['port'], db=0,
                        socket_timeout=30, socket_connect_timeout=300)
                        socket_timeout=30, socket_connect_timeout=30)
    manage_messages(args, queue)

if __name__ == '__main__':
+30 −45
Original line number Diff line number Diff line
import json
import socket
import logging
import sys
from time import sleep

from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool
# Set the logging level
logging.basicConfig(level='INFO')
logger = logging.getLogger()

from autocnet.graph.node import NetworkNode
from autocnet.graph.edge import NetworkEdge
from autocnet.utils.utils import import_func
from autocnet.utils.serializers import object_hook
from autocnet.io.db.model import Measures, Points, Overlay, Images
from autocnet.io.db.connection import retry, new_connection

from sqlalchemy.orm import joinedload

apply_iterable_options = {
                'measures' : Measures,
@@ -31,21 +34,14 @@ apply_iterable_options = {
                5: Images
            }

def retry(max_retries=3, wait_time=300):
    def decorator(func):
        def wrapper(*args, **kwargs):
            retries = 0
            if retries < max_retries:
                try:
                    result = func(*args, **kwargs)
                    return result
                except:
                    retries += 1
                    sleep(wait_time)
            else:
                raise Exception(f"Maximum retires of {func} exceeded")
        return wrapper
    return decorator
def set_srids(spatial):
    latitudinal_srid = spatial['latitudinal_srid']
    rectangular_srid = spatial['rectangular_srid']
    for cls in [Points, Overlay, Images]:
        setattr(cls, 'latitudinal_srid', latitudinal_srid)
        setattr(cls, 'rectangular_srid', rectangular_srid)
    Points.semimajor_rad = spatial['semimajor_rad']
    Points.semiminor_rad = spatial['semiminor_rad']

@retry(max_retries=5)
def _instantiate_obj(msg):
@@ -73,27 +69,13 @@ def _instantiate_row(msg, session):
    """
    # Get the dict mapping iterable keyword types to the objects
    obj = apply_iterable_options[msg['along']]
    res = session.query(obj).filter(getattr(obj, 'id')==msg['id']).one()
    session.expunge_all() # Disconnect the object from the session
    res = session.query(obj). \
            filter(getattr(obj, 'id')==msg['id']). \
            options(joinedload('*')). \
            one()
    session.expunge_all()
    return res

@retry()
def get_db_connection(dbconfig):
    db_uri = 'postgresql://{}:{}@{}:{}/{}'.format(dbconfig['username'],
                                                  dbconfig['password'],
                                                  dbconfig['host'],
                                                  dbconfig['pgbouncer_port'],
                                                  dbconfig['name'])
    hostname = socket.gethostname()

    engine = create_engine(db_uri,
        poolclass=NullPool,
        connect_args={"application_name":f"AutoCNet_{hostname}"},
        isolation_level="AUTOCOMMIT",
        pool_pre_ping=True)
    return engine

@retry()
def execute_func(func, *args, **kwargs):
    return func(*args, **kwargs)

@@ -112,8 +94,13 @@ def process(msg):
    # Deserialize the message
    msg = json.loads(msg, object_hook=object_hook)

    engine = get_db_connection(msg['config']['database'])
    # Get the database connection
    engine = new_connection(msg['config']['database'])
    
    # Set the SRIDs on the table objects based on the passed config
    set_srids(msg['config']['spatial'])

    # Instantiate the objects to be used
    if msg['along'] in ['node', 'edge']:
        obj = _instantiate_obj(msg)
    elif msg['along'] in ['points', 'measures', 'overlaps', 'images']:
@@ -145,19 +132,17 @@ def process(msg):
    # Now run the function.
    res = execute_func(func,*msg['args'], **msg['kwargs'])

    del Session
    del engine

    # Update the message with the True/False
    msg['results'] = res
    # Update the message with the correct callback function
    
    engine.dispose()
    del engine
    return msg

def main():
    msg = ''.join(sys.argv[1:])
    result = process(msg)
    print(result)
    logging.info('Result: ', result)
    
if __name__ == '__main__':
    main()
+2 −17
Original line number Diff line number Diff line
@@ -1620,23 +1620,8 @@ class NetworkCandidateGraph(CandidateGraph):
        self._Session = Session

    def _setup_database(self):
        # A non-linear timeout if the DB is spinning up or loaded with many connections.
        sleeptime = 2
        retries = 0
        self.Session, self.engine = new_connection(self.config['database'])
        self.Session, self.engine = new_connection(self.config['database'], with_session=True)
        try_db_creation(self.engine, self.config)
        return
        while retries < 5:
            log.debug(f'Database connection attempt {retries}')
            try:
                self.Session, self.engine = new_connection(self.config['database'])

                # Attempt to create the database (if it does not exist)
                try_db_creation(self.engine, self.config)
                break
            except:
                retries += 1
                sleep(retries ** sleeptime)

    # def _setup_nodes(self):
    #     with self.session_scope() as session:
@@ -2433,7 +2418,7 @@ class NetworkCandidateGraph(CandidateGraph):
        >>> ncg.add_from_remote_database(source_db_config, outpath, query_string=query)
        """

        sourceSession, _ = new_connection(source_db_config)
        sourceSession, _ = new_connection(source_db_config, with_session=True)
        sourcesession = sourceSession()

        sourceimages = sourcesession.execute(query_string).fetchall()
+30 −16
Original line number Diff line number Diff line
import socket

import sqlalchemy
from sqlalchemy import create_engine, pool, orm
from sqlalchemy.orm import create_session, scoped_session, sessionmaker

import logging
import os
import socket
import warnings
import yaml
from time import sleep

from sqlalchemy import orm, create_engine, pool

# set up the logging file
log = logging.getLogger(__name__)
@@ -19,7 +13,24 @@ class Parent:
        self.session = Session()
        self.session.begin()

def new_connection(dbconfig):
def retry(max_retries=5, wait_time=300):
    def decorator(func):
        def wrapper(*args, **kwargs):
            retries = 0
            if retries < max_retries:
                try:
                    result = func(*args, **kwargs)
                    return result
                except:
                    retries += 1
                    sleep(wait_time)
            else:
                raise Exception(f"Maximum retries of {func} exceeded! Is the database accessible?")
        return wrapper
    return decorator

@retry()
def new_connection(dbconfig, with_session=False):
    """
    Using the user supplied config create a NullPool database connection.

@@ -29,6 +40,8 @@ def new_connection(dbconfig):
               Dictionary defining necessary parameters for the database
               connection

    with_session : boolean
                   If true return a SQL Alchemy session factory. Default False.
    Returns
    -------
    Session : object
@@ -43,11 +56,12 @@ def new_connection(dbconfig):
                                                  dbconfig['pgbouncer_port'],
                                                  dbconfig['name'])
    hostname = socket.gethostname()
    engine = sqlalchemy.create_engine(db_uri,
                poolclass=sqlalchemy.pool.NullPool,
    engine = create_engine(db_uri,
                poolclass=pool.NullPool,
                connect_args={"application_name":f"AutoCNet_{hostname}"},
                isolation_level="AUTOCOMMIT",
                pool_pre_ping=True)
    if with_session:
        Session = orm.sessionmaker(bind=engine, autocommit=False)
        log.debug(Session, engine)
        return Session, engine
    return engine
Loading