Unverified Commit 3f30fd3a authored by jlaura's avatar jlaura Committed by GitHub
Browse files

Adds parallel support for parallelization using ntasks (#611)

parent e6a712ad
Loading
Loading
Loading
Loading
+15 −4
Original line number Original line Diff line number Diff line
import json
import json
import time
import time


from sqlalchemy import insert, update
from sqlalchemy import insert
from sqlalchemy.sql.expression import bindparam
from sqlalchemy.sql.expression import bindparam


from autocnet.io.db.model import Points, Measures
from autocnet.io.db.model import Points, Measures
from autocnet.utils.serializers import object_hook
from autocnet.utils.serializers import object_hook
from autocnet.transformation.spatial import reproject, og2oc


def watch_insert_queue(queue, queue_name, counter_name, engine, stop_event, sleep_time=5):
def watch_insert_queue(queue, queue_name, counter_name, engine, stop_event, sleep_time=5):
    """
    """
@@ -57,7 +58,8 @@ def watch_insert_queue(queue, queue_name, counter_name, engine, stop_event, slee
        measures = []
        measures = []
        
        
        # Pull the SRID dynamically from the model (database)
        # Pull the SRID dynamically from the model (database)
        srid = Points.rectangular_srid
        rect_srid = Points.rectangular_srid
        lat_srid = Points.latitudinal_srid


        for i in range(0, read_length):
        for i in range(0, read_length):
            msg = json.loads(queue.lpop(queue_name), object_hook=object_hook)
            msg = json.loads(queue.lpop(queue_name), object_hook=object_hook)
@@ -68,7 +70,16 @@ def watch_insert_queue(queue, queue_name, counter_name, engine, stop_event, slee


                # Since this avoids the ORM, need to map the table names manually
                # Since this avoids the ORM, need to map the table names manually
                msg['pointType'] = msg['pointtype']  
                msg['pointType'] = msg['pointtype']  
                msg['adjusted'] = f'SRID={srid};' + msg["adjusted"].wkt  # Geometries go in as EWKT
                adjusted = msg['adjusted']
            
                msg['adjusted'] = f'SRID={rect_srid};' + adjusted.wkt  # Geometries go in as EWKT
                msg['apriori'] = f'SRID={rect_srid};' + adjusted.wkt

                lon_og, lat_og, _ = reproject([adjusted.x, adjusted.y, adjusted.z],
                                    Points.semimajor_rad, Points.semiminor_rad,
                                    'geocent', 'latlon')
                lon, lat = og2oc(lon_og, lat_og, Points.semimajor_rad, Points.semiminor_rad)
                msg['geom'] = f'SRID={lat_srid};Point({lon} {lat})' 


                # Measures are removed and manually added later
                # Measures are removed and manually added later
                point_measures = msg.pop('measures', [])
                point_measures = msg.pop('measures', [])
+57 −43
Original line number Original line Diff line number Diff line
@@ -20,11 +20,13 @@ from autocnet.utils.utils import import_func
from autocnet.utils.serializers import JsonEncoder, object_hook
from autocnet.utils.serializers import JsonEncoder, object_hook
from autocnet.io.db.model import JobsHistory
from autocnet.io.db.model import JobsHistory



def parse_args():  # pragma: no cover
def parse_args():  # pragma: no cover
    parser = argparse.ArgumentParser()
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--host', help='The host URL for the redis queue to to pull messages from.')
    parser.add_argument('-r', '--host', help='The host URL for the redis queue to to pull messages from.')
    parser.add_argument('-p', '--port', help='The port for used by redis.')
    parser.add_argument('-p', '--port', help='The port for used by redis.')
    parser.add_argument('-q', '--queue', default=False, action='store_true',
                        help='If passed, run in queue mode, where this job runs until either \
                              walltime is hit or the queue that is being processed is empty.')
    parser.add_argument('processing_queue', help='The name of the processing queue to draw messages from.')
    parser.add_argument('processing_queue', help='The name of the processing queue to draw messages from.')
    parser.add_argument('working_queue', help='The name of the queue to push messages to while they process.')
    parser.add_argument('working_queue', help='The name of the queue to push messages to while they process.')


@@ -55,7 +57,6 @@ def _instantiate_row(msg, ncg):
    """
    """
    # Get the dict mapping iterable keyword types to the objects
    # Get the dict mapping iterable keyword types to the objects
    objdict = ncg.apply_iterable_options
    objdict = ncg.apply_iterable_options
    rowid = msg['id']
    obj = objdict[msg['along']]
    obj = objdict[msg['along']]
    with ncg.session_scope() as session:
    with ncg.session_scope() as session:
        res = session.query(obj).filter(getattr(obj, 'id')==msg['id']).one()
        res = session.query(obj).filter(getattr(obj, 'id')==msg['id']).one()
@@ -167,14 +168,21 @@ def manage_messages(args, queue):
            A py-Redis queue object
            A py-Redis queue object


    """
    """
    processing = True
    
    while processing:
        # Pop the message from the left queue and push to the right queue; atomic operation
        # Pop the message from the left queue and push to the right queue; atomic operation
        msg = transfer_message_to_work_queue(queue,
        msg = transfer_message_to_work_queue(queue,
                                            args['processing_queue'],
                                            args['processing_queue'],
                                            args['working_queue'])
                                            args['working_queue'])
        
        
        if msg is None:
        if msg is None:
            if args['queue'] == False:
                warnings.warn('Expected to process a cluster job, but the message queue is empty.')
                warnings.warn('Expected to process a cluster job, but the message queue is empty.')
                return
                return
            elif args['queue'] == True:
                print(f'Completed processing from queue: {queue}.')
                return


        # The key to remove from the working queue is the message. Essentially, find this element
        # The key to remove from the working queue is the message. Essentially, find this element
        # in the list where the element is the JSON representation of the message. Maybe swap to a hash?
        # in the list where the element is the JSON representation of the message. Maybe swap to a hash?
@@ -196,22 +204,28 @@ def manage_messages(args, queue):
        # print to get everything on the logs in the directory
        # print to get everything on the logs in the directory
        print(out)
        print(out)


    serializedDict = json.loads(msg)
        sys.stdout.flush()
    results  = msgdict['results'] if msgdict['results'] else [{"status" : "success"}]
        stdout.flush()
    success = True if "success" in results[0]["status"].split(" ")[0].lower() else False

        #serializedDict = json.loads(msg)
        #results  = msgdict['results'] if msgdict['results'] else [{"status" : "success"}]
        #success = True if "success" in results[0]["status"].split(" ")[0].lower() else False


    jh = JobsHistory(jobId=int(os.environ["SLURM_JOB_ID"]), functionName=msgdict["func"], args={"args" : serializedDict["args"], "kwargs": serializedDict["kwargs"]}, results=msgdict["results"], logs=out, success=success)
        #jh = JobsHistory(jobId=int(os.environ["SLURM_JOB_ID"]), functionName=msgdict["func"], args={"args" : serializedDict["args"], "kwargs": serializedDict["kwargs"]}, results=msgdict["results"], logs=out, success=success)
        
        
    with response['kwargs']['Session']() as session:
        #with response['kwargs']['Session']() as session:
        session.add(jh)
            #session.add(jh)
        session.commit()
            #session.commit()


        finalize_message_from_work_queue(queue, args['working_queue'], remove_key)
        finalize_message_from_work_queue(queue, args['working_queue'], remove_key)


        # Process only a single job, else draw the next message off the queue if available.
        if args['queue'] == False:
            processing = False
        

def main():  # pragma: no cover
def main():  # pragma: no cover
    args = vars(parse_args())
    args = vars(parse_args())
    # Get the message
    # Get the message
    queue = StrictRedis(host=args['host'], port=args['port'], db=0)
    queue = StrictRedis(host=args['host'], port=args['port'], db=0)
    manage_messages(args, queue)
    manage_messages(args, queue)
    
+21 −5
Original line number Original line Diff line number Diff line
@@ -1658,8 +1658,10 @@ class NetworkCandidateGraph(CandidateGraph):
            on='edge',
            on='edge',
            args=(),
            args=(),
            walltime='01:00:00',
            walltime='01:00:00',
            jobname='AutoCNet',
            chunksize=1000,
            chunksize=1000,
            arraychunk=25,
            arraychunk=25,
            ntasks=1,
            filters={},
            filters={},
            query_string='',
            query_string='',
            reapply=False,
            reapply=False,
@@ -1706,6 +1708,14 @@ class NetworkCandidateGraph(CandidateGraph):
                     The number of concurrent jobs to run per job array. e.g. chunksize=100 and
                     The number of concurrent jobs to run per job array. e.g. chunksize=100 and
                     arraychunk=25 gives the job array 1-100%25
                     arraychunk=25 gives the job array 1-100%25


        ntasks : int
                 The number of tasks, distributed across the cluster on some set of nodes to be run.
                 When running apply with ntasks, set ntasks to some integer greater then 1. arraychunk and
                 chunksize arguments will then be ignored. In this mode, a number of non-communicating 
                 CPUs equal to ntasks are allocated and these CPUs run jobs. Changing from arrays to ntasks
                 also likely requires increasing the walltime of the job significantly since less jobs
                 will need to run for a longer duration.

        filters : dict
        filters : dict
                  Of simple filters to apply on database rows where the key is the attribute and
                  Of simple filters to apply on database rows where the key is the attribute and
                  the value used to check equivalency (e.g., attribute == value).
                  the value used to check equivalency (e.g., attribute == value).
@@ -1810,17 +1820,23 @@ class NetworkCandidateGraph(CandidateGraph):
        isissetup = f'export ISISROOT={isisroot} && export ISISDATA={isisdata}'
        isissetup = f'export ISISROOT={isisroot} && export ISISDATA={isisdata}'
        condasetup = f'conda activate {condaenv}'
        condasetup = f'conda activate {condaenv}'
        job = f'acn_submit -r={rhost} -p={rport} {processing_queue} {self.working_queue}'
        job = f'acn_submit -r={rhost} -p={rport} {processing_queue} {self.working_queue}'
        command = f'{condasetup} && {isissetup} && {job}'
        if ntasks > 1:
            job += ' --queue'  # Use queue mode where jobs run until the queue is empty
        command = f'{condasetup} && {isissetup} && srun {job}'


        if queue == None:
        if queue == None:
            queue = self.config['cluster']['queue']
            queue = self.config['cluster']['queue']


        submitter = Slurm(command,
        submitter = Slurm(command,
                     job_name='AutoCNet',
                     job_name=jobname,
                     mem_per_cpu=self.config['cluster']['processing_memory'],
                     mem_per_cpu=self.config['cluster']['processing_memory'],
                     time=walltime,
                     time=walltime,
                     partition=queue,
                     partition=queue,
                     ntasks=ntasks,
                     output=log_dir+f'/autocnet.{function}-%j')
                     output=log_dir+f'/autocnet.{function}-%j')
        if ntasks > 1:
            job_str = submitter.submit(exclude=exclude)
        else:
            job_str = submitter.submit(array='1-{}%{}'.format(job_counter,arraychunk),
            job_str = submitter.submit(array='1-{}%{}'.format(job_counter,arraychunk),
                                    chunksize=chunksize,
                                    chunksize=chunksize,
                                    exclude=exclude)
                                    exclude=exclude)
+2 −1
Original line number Original line Diff line number Diff line
@@ -16,7 +16,8 @@ from autocnet.io.db.model import Points, JobsHistory
@pytest.fixture
@pytest.fixture
def args():
def args():
    arg_dict = {'working_queue':'working',
    arg_dict = {'working_queue':'working',
                'processing_queue':'processing'}
                'processing_queue':'processing',
                'queue':False}
    return arg_dict
    return arg_dict


@pytest.fixture
@pytest.fixture