Commit 6ef538d5 authored by Jay's avatar Jay
Browse files

More connection updates

parent 2aedc6d5
Loading
Loading
Loading
Loading
+17 −11
Original line number Diff line number Diff line
import json
import socket
import sys

from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool

from autocnet.graph.node import NetworkNode
from autocnet.graph.edge import NetworkEdge
from autocnet.utils.utils import import_func
@@ -31,7 +27,14 @@ apply_iterable_options = {
                5: Images
            }


def set_srids(spatial):
    latitudinal_srid = spatial['latitudinal_srid']
    rectangular_srid = spatial['rectangular_srid']
    for cls in [Points, Overlay, Images]:
        setattr(cls, 'latitudinal_srid', latitudinal_srid)
        setattr(cls, 'rectangular_srid', rectangular_srid)
    Points.semimajor_rad = spatial['semimajor_rad']
    Points.semiminor_rad = spatial['semiminor_rad']

@retry(max_retries=5)
def _instantiate_obj(msg):
@@ -82,8 +85,13 @@ def process(msg):
    # Deserialize the message
    msg = json.loads(msg, object_hook=object_hook)

    _, engine = new_connection(msg['config']['database'])
    # Get the database connection
    engine = new_connection(msg['config']['database'])
    
    # Set the SRIDs on the table objects based on the passed config
    set_srids(msg['config']['spatial'])

    # Instantiate the objects to be used
    if msg['along'] in ['node', 'edge']:
        obj = _instantiate_obj(msg)
    elif msg['along'] in ['points', 'measures', 'overlaps', 'images']:
@@ -115,13 +123,11 @@ def process(msg):
    # Now run the function.
    res = execute_func(func,*msg['args'], **msg['kwargs'])

    del Session
    del engine

    # Update the message with the True/False
    msg['results'] = res
    # Update the message with the correct callback function
    
    engine.dispose()
    del engine
    return msg

def main():
+1 −1
Original line number Diff line number Diff line
@@ -1620,7 +1620,7 @@ class NetworkCandidateGraph(CandidateGraph):
        self._Session = Session

    def _setup_database(self):
        self.Session, self.engine = new_connection(self.config['database'])
        self.Session, self.engine = new_connection(self.config['database'], with_session=True)
        try_db_creation(self.engine, self.config)

    # def _setup_nodes(self):
+36 −33
Original line number Diff line number Diff line
from contextlib import nullcontext
import logging
import warnings

from sqlalchemy.sql.expression import bindparam

@@ -8,7 +7,11 @@ from autocnet.io.db.connection import retry
from autocnet.io.db.model import Images, Overlay, Points, Measures
from autocnet.graph.node import NetworkNode

@retry
# set up the logger file
log = logging.getLogger(__name__)


@retry()
def update_measures(ncg, session, measures_iterable_to_update):
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        stmt = Measures.__table__.update().\
@@ -23,7 +26,7 @@ def update_measures(ncg, session, measures_iterable_to_update):
        session.execute(stmt, measures_iterable_to_update)
    return

@retry
@retry()
def ignore_measures(ncg, session, measures_iterable_to_ignore, chooser):
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        measures_to_set_false = [{'_id':i} for i in measures_to_set_false]
@@ -33,49 +36,48 @@ def ignore_measures(ncg, session, measures_iterable_to_ignore, chooser):
                                values({'measureIgnore':True,
                                        'ChooserName':chooser})
        session.execute(stmt, measures_to_set_false)
    return 

@retry
#@retry(wait_time=30)
def get_nodes_for_overlap(ncg, session, overlap):
    # If an NCG is passed, instantiate a session off the NCG, else just pass the session through
    ids = tuple([i for i in overlap.intersections])
    nodes = []
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        for id in overlap.intersections:
            try:
                res = session.query(Images).filter(Images.id == id).one()
            except Exception as e:
                warnings.warn(f'Unable to instantiate image with id: {id} with error: {e}')
                continue
            nn = NetworkNode(node_id=id, 
        results = session.query(Images).filter(Images.id.in_(ids)).all()
    
        for res in results:
            nn = NetworkNode(node_id=res.id, 
                            image_path=res.path, 
                            cam_type=res.cam_type,
                            dem=res.dem,
                            dem_type=res.dem_type)
            nodes.append(nn)
    return nodes

@retry
@retry(wait_time=30)
def get_nodes_for_measures(ncg, session, measures):
    nodes = {}
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
            for measure in measures:
                res = session.query(Images).filter(Images.id == measure.imageid).one()
                logging.debug(f'Node instantiation image query result: {res.path, res.cam_type, res.dem, res.dem_type}')
                nn = NetworkNode(node_id=measure.imageid, 
        imageids = tuple([measure.imageid for measure in measures])
        results = session.query(Images).filter(Images.id.in_(imageids)).all()
        for res in results:
            nn = NetworkNode(node_id=res.imageid, 
                            image_path=res.path,
                            cam_type=res.cam_type,
                            dem=res.dem,
                            dem_type=res.dem_type)
                nodes[measure.imageid] = nn
            session.expunge_all()  
            nodes[res.imageid] = nn
    return nodes  

@retry
@retry(wait_time=30)
def get_overlap(ncg, session, overlapid):
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        overlap = session.query(Overlay).filter(Overlay.id == overlapid).one()
        session.expunge_all()
    return overlap

@retry
@retry(wait_time=30)
def get_point(ncg, session, pointid):
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        point = session.query(Points).filter(Points.id == pointid).one()
@@ -84,8 +86,9 @@ def get_point(ncg, session, pointid):



@retry

def bulk_commit(ncg, session, iterable_of_objs_to_commit):
    with ncg.session_scope() if ncg else nullcontext(session) as session:
        session.add_all(iterable_of_objs_to_commit)
        session.commit()
    return