Unverified Commit 3c61d5e3 authored by jlaura's avatar jlaura Committed by GitHub
Browse files

Adds arbitrary iterable submission (#492)

* Adds arbitrary iterable submission

* Updates per comments
parent 309e352b
Loading
Loading
Loading
Loading
+23 −6
Original line number Diff line number Diff line
@@ -1450,7 +1450,7 @@ class NetworkCandidateGraph(CandidateGraph):
                                        weights=json.dumps({})))
            session.add_all(to_add)
            session.commit()
            print(len(to_add))

    def _setup_queues(self):
        """
        Setup a 2 queue redis connection for pushing and pulling work/results
@@ -1556,6 +1556,18 @@ class NetworkCandidateGraph(CandidateGraph):
            assert len(res) == self.queue_length
        return len(res)

    def _push_iterable_message(self, iterable, function, walltime, args, kwargs):
        for job_counter, item in enumerate(iterable):
            msg = {'along':item,
                    'func':function,
                    'args':args,
                    'kwargs':kwargs,
                    'walltime':walltime}
            msg['config'] = self.config
            self.redis_queue.rpush(self.processing_queue,
                                   json.dumps(msg, cls=JsonEncoder))
        return job_counter + 1

    def apply(self, function, on='edge', args=(), walltime='01:00:00', chunksize=1000, arraychunk=25, filters={}, query_string='', reapply=False, **kwargs):
        """
        A mirror of the apply function from the standard CandidateGraph object. This implementation
@@ -1641,8 +1653,10 @@ class NetworkCandidateGraph(CandidateGraph):

        if not reapply:
            # Determine which obj will be called
            if isinstance(on, str):
                onobj = self.apply_iterable_options[on]
            res = []
            elif isinstance(on, list):
                onobj = on
                
            # This method support arbitrary functions. The name needs to be a string for the log name.
            if not isinstance(function, (str, bytes)):
@@ -1653,9 +1667,12 @@ class NetworkCandidateGraph(CandidateGraph):
            # Dispatch to either the database object message generator or the autocnet object message generator
            if isinstance(onobj, DeclarativeMeta):
                job_counter = self._push_row_messages(onobj, on, function, walltime, filters, query_string, args, kwargs)
            else:
            elif isinstance(onobj, list):
                job_counter = self._push_iterable_message(onobj, function, walltime, args, kwargs)
            elif isinstance(onobj, (Node, NetworkNode, Edge, NetworkEdge)):
                job_counter = self._push_obj_messages(onobj, function, walltime, args, kwargs)

            else:
                raise TypeError('The type of the `on` argument is not understood. Must be a database model, iterable, Node or Edge.')

        # Submit the jobs
        rconf = self.config['redis']
+3 −2
Original line number Diff line number Diff line
@@ -50,14 +50,14 @@ def _instantiate_row(msg, ncg):
    return res

def main(msg):

    ncg = NetworkCandidateGraph()
    ncg.config_from_dict(msg['config'])

    if msg['along'] in ['node', 'edge']:
        obj = _instantiate_obj(msg, ncg)
    elif msg['along'] in ['points', 'measures', 'overlaps']:
        obj = _instantiate_row(msg, ncg)
    else:
        obj = msg['along']

    # Grab the function and apply. This assumes that the func is going to
    # have a True/False return value. Basically, all processing needs to
@@ -67,6 +67,7 @@ def main(msg):
    func = msg['func']
    if callable(func):  # The function is a de-serialzied function
        msg['args'] = (obj, *msg['args'])
        msg['kwargs']['ncg'] = ncg
    elif hasattr(obj, msg['func']):  # The function is a method on the object
        func = getattr(obj, msg['func'])
    else:  # The func is a function from a library to be imported