Commit 6ef538d5 authored by Jay's avatar Jay
Browse files

More connection updates

parent 2aedc6d5
Loading
Loading
Loading
Loading
+17 −11
Original line number Original line Diff line number Diff line
import json
import json
import socket
import sys
import sys


from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool

from autocnet.graph.node import NetworkNode
from autocnet.graph.node import NetworkNode
from autocnet.graph.edge import NetworkEdge
from autocnet.graph.edge import NetworkEdge
from autocnet.utils.utils import import_func
from autocnet.utils.utils import import_func
@@ -31,7 +27,14 @@ apply_iterable_options = {
                5: Images
                5: Images
            }
            }



def set_srids(spatial):
    latitudinal_srid = spatial['latitudinal_srid']
    rectangular_srid = spatial['rectangular_srid']
    for cls in [Points, Overlay, Images]:
        setattr(cls, 'latitudinal_srid', latitudinal_srid)
        setattr(cls, 'rectangular_srid', rectangular_srid)
    Points.semimajor_rad = spatial['semimajor_rad']
    Points.semiminor_rad = spatial['semiminor_rad']


@retry(max_retries=5)
@retry(max_retries=5)
def _instantiate_obj(msg):
def _instantiate_obj(msg):
@@ -82,8 +85,13 @@ def process(msg):
    # Deserialize the message
    # Deserialize the message
    msg = json.loads(msg, object_hook=object_hook)
    msg = json.loads(msg, object_hook=object_hook)


    _, engine = new_connection(msg['config']['database'])
    # Get the database connection
    engine = new_connection(msg['config']['database'])
    
    # Set the SRIDs on the table objects based on the passed config
    set_srids(msg['config']['spatial'])


    # Instantiate the objects to be used
    if msg['along'] in ['node', 'edge']:
    if msg['along'] in ['node', 'edge']:
        obj = _instantiate_obj(msg)
        obj = _instantiate_obj(msg)
    elif msg['along'] in ['points', 'measures', 'overlaps', 'images']:
    elif msg['along'] in ['points', 'measures', 'overlaps', 'images']:
@@ -115,13 +123,11 @@ def process(msg):
    # Now run the function.
    # Now run the function.
    res = execute_func(func,*msg['args'], **msg['kwargs'])
    res = execute_func(func,*msg['args'], **msg['kwargs'])


    del Session
    del engine

    # Update the message with the True/False
    # Update the message with the True/False
    msg['results'] = res
    msg['results'] = res
    # Update the message with the correct callback function
    
    
    engine.dispose()
    del engine
    return msg
    return msg


def main():
def main():
+1 −1
Original line number Original line Diff line number Diff line
@@ -1620,7 +1620,7 @@ class NetworkCandidateGraph(CandidateGraph):
        self._Session = Session
        self._Session = Session


    def _setup_database(self):
    def _setup_database(self):
        self.Session, self.engine = new_connection(self.config['database'])
        self.Session, self.engine = new_connection(self.config['database'], with_session=True)
        try_db_creation(self.engine, self.config)
        try_db_creation(self.engine, self.config)


    # def _setup_nodes(self):
    # def _setup_nodes(self):
+36 −33
Original line number Original line Diff line number Diff line
from contextlib import nullcontext
from contextlib import nullcontext
import logging
import logging
import warnings


from sqlalchemy.sql.expression import bindparam
from sqlalchemy.sql.expression import bindparam


@@ -8,7 +7,11 @@ from autocnet.io.db.connection import retry
from autocnet.io.db.model import Images, Overlay, Points, Measures
from autocnet.io.db.model import Images, Overlay, Points, Measures
from autocnet.graph.node import NetworkNode
from autocnet.graph.node import NetworkNode


@retry
# set up the logger file
log = logging.getLogger(__name__)


@retry()
def update_measures(ncg, session, measures_iterable_to_update):
def update_measures(ncg, session, measures_iterable_to_update):
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        stmt = Measures.__table__.update().\
        stmt = Measures.__table__.update().\
@@ -23,7 +26,7 @@ def update_measures(ncg, session, measures_iterable_to_update):
        session.execute(stmt, measures_iterable_to_update)
        session.execute(stmt, measures_iterable_to_update)
    return
    return


@retry
@retry()
def ignore_measures(ncg, session, measures_iterable_to_ignore, chooser):
def ignore_measures(ncg, session, measures_iterable_to_ignore, chooser):
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        measures_to_set_false = [{'_id':i} for i in measures_to_set_false]
        measures_to_set_false = [{'_id':i} for i in measures_to_set_false]
@@ -33,49 +36,48 @@ def ignore_measures(ncg, session, measures_iterable_to_ignore, chooser):
                                values({'measureIgnore':True,
                                values({'measureIgnore':True,
                                        'ChooserName':chooser})
                                        'ChooserName':chooser})
        session.execute(stmt, measures_to_set_false)
        session.execute(stmt, measures_to_set_false)
    return 


@retry
#@retry(wait_time=30)
def get_nodes_for_overlap(ncg, session, overlap):
def get_nodes_for_overlap(ncg, session, overlap):
    # If an NCG is passed, instantiate a session off the NCG, else just pass the session through
    # If an NCG is passed, instantiate a session off the NCG, else just pass the session through
    ids = tuple([i for i in overlap.intersections])
    nodes = []
    nodes = []
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        for id in overlap.intersections:
        results = session.query(Images).filter(Images.id.in_(ids)).all()
            try:
    
                res = session.query(Images).filter(Images.id == id).one()
        for res in results:
            except Exception as e:
            nn = NetworkNode(node_id=res.id, 
                warnings.warn(f'Unable to instantiate image with id: {id} with error: {e}')
                continue
            nn = NetworkNode(node_id=id, 
                            image_path=res.path, 
                            image_path=res.path, 
                            cam_type=res.cam_type,
                            cam_type=res.cam_type,
                            dem=res.dem,
                            dem=res.dem,
                            dem_type=res.dem_type)
                            dem_type=res.dem_type)
            nodes.append(nn)
            nodes.append(nn)
    return nodes


@retry
@retry(wait_time=30)
def get_nodes_for_measures(ncg, session, measures):
def get_nodes_for_measures(ncg, session, measures):
    nodes = {}
    nodes = {}
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
            for measure in measures:
        imageids = tuple([measure.imageid for measure in measures])
                res = session.query(Images).filter(Images.id == measure.imageid).one()
        results = session.query(Images).filter(Images.id.in_(imageids)).all()
                logging.debug(f'Node instantiation image query result: {res.path, res.cam_type, res.dem, res.dem_type}')
        for res in results:
                nn = NetworkNode(node_id=measure.imageid, 
            nn = NetworkNode(node_id=res.imageid, 
                            image_path=res.path,
                            image_path=res.path,
                            cam_type=res.cam_type,
                            cam_type=res.cam_type,
                            dem=res.dem,
                            dem=res.dem,
                            dem_type=res.dem_type)
                            dem_type=res.dem_type)
                nodes[measure.imageid] = nn
            nodes[res.imageid] = nn
            session.expunge_all()  
    return nodes  
    return nodes  


@retry
@retry(wait_time=30)
def get_overlap(ncg, session, overlapid):
def get_overlap(ncg, session, overlapid):
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        overlap = session.query(Overlay).filter(Overlay.id == overlapid).one()
        overlap = session.query(Overlay).filter(Overlay.id == overlapid).one()
        session.expunge_all()
        session.expunge_all()
    return overlap
    return overlap


@retry
@retry(wait_time=30)
def get_point(ncg, session, pointid):
def get_point(ncg, session, pointid):
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
    with ncg.session_scope() if ncg is not None else nullcontext(session) as session:
        point = session.query(Points).filter(Points.id == pointid).one()
        point = session.query(Points).filter(Points.id == pointid).one()
@@ -84,8 +86,9 @@ def get_point(ncg, session, pointid):






@retry

def bulk_commit(ncg, session, iterable_of_objs_to_commit):
def bulk_commit(ncg, session, iterable_of_objs_to_commit):
    with ncg.session_scope() if ncg else nullcontext(session) as session:
    with ncg.session_scope() if ncg else nullcontext(session) as session:
        session.add_all(iterable_of_objs_to_commit)
        session.add_all(iterable_of_objs_to_commit)
        session.commit()
        session.commit()
    return