Commit 50baac49 authored by Jay's avatar Jay Committed by jay
Browse files

Added the ability to save / load a graph

parent fa3fcf72
Loading
Loading
Loading
Loading
+34 −0
Original line number Diff line number Diff line
import os
import pickle

import networkx as nx
import numpy as np
@@ -58,6 +59,24 @@ class CandidateGraph(nx.Graph):
        for s, d, edge in self.edges_iter(data=True):
            self.edge[s][d] = Edge(self.node[s], self.node[d])

    @classmethod
    def from_graph(cls, graph):
        """
        Return a graph object from a pickled file
        Parameters
        ----------
        graph : str
                PATH to the graph object

        Returns
        -------
        graph : object
                CandidateGraph object
        """
        with open(graph, 'rb') as f:
            graph = pickle.load(f)
        return graph

    @classmethod
    def from_adjacency(cls, input_adjacency, basepath=None):
        """
@@ -459,6 +478,21 @@ class CandidateGraph(nx.Graph):
        """
        return sorted(nx.connected_components(self), key=len, reverse=True)

    def save(self, filename):
        """
        Save the graph object to disk.
        Parameters
        ----------
        filename : str
                   The relative or absolute PATH where the network is saved
        """
        for i, node in self.nodes_iter(data=True):
            # Close the file handle because pickle doesn't handle SwigPyObjects
            node._handle = None

        with open(filename, 'wb') as f:
            pickle.dump(self, f)

    # TODO: The Edge object requires a get method in order to be plottable, probably Node as well.
    # This is a function of being a dict in NetworkX
    def plot(self, ax=None, **kwargs):
+8 −2
Original line number Diff line number Diff line
@@ -2,11 +2,9 @@ import os
import sys
sys.path.insert(0, os.path.abspath('..'))

import numpy as np
import unittest

from autocnet.examples import get_path
from autocnet.fileio.io_gdal import GeoDataset

from .. import network

@@ -43,6 +41,14 @@ class TestCandidateGraph(unittest.TestCase):
        island = self.graph.island_nodes()[0]
        self.assertTrue(island in subgraph_list[1])

    def test_save_load(self):
        self.graph.save('test_save.cg')
        loaded = self.graph.from_graph('test_save.cg')

        self.assertEqual(self.graph.node[0].nkeypoints, loaded.node[0].nkeypoints)
        self.assertEqual(self.graph.edge[0][1], self.graph.edge[0][1])
        os.remove('test_save.cg')

    def tearDown(self):
        pass