import sys
from vtk import *

class filterBase:
    """Common plumbing shared by the simple graph-annotating filter helpers.

    Each instance wraps a ``vtkProgrammableFilter`` whose execute method is
    this object's ``callback``. Subclasses override ``callback`` to populate
    the wrapped filter's output. The ``GetOutputPort``/``GetOutput``/``Update``
    methods forward to the wrapped filter, so a helper can be wired into a
    VTK pipeline much like a regular algorithm.
    """

    def __init__(self, inputPort):
        """Create the wrapped programmable filter, fed from ``inputPort``."""
        # Optional label for the helper; assigned through SetName().
        self.Name = None
        programmable = vtkProgrammableFilter()
        programmable.AddInputConnection(inputPort)
        # A bound method, so the subclass override runs when the filter executes.
        programmable.SetExecuteMethod(self.callback)
        self._filter = programmable

    def __str__(self):
        """Return the string form of the wrapped VTK filter."""
        return str(self._filter)

    def callback(self):
        """Execute method of the wrapped filter; subclasses must override it.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError("A custom callback() method was not found!")

    def GetOutputPort(self):
        """Return the wrapped filter's output port, for pipeline connections."""
        return self._filter.GetOutputPort()

    def GetOutput(self):
        """Return the wrapped filter's output data object."""
        return self._filter.GetOutput()

    def Update(self):
        """Call ``Update()`` on the wrapped filter."""
        self._filter.Update()

    def SetName(self, newName):
        """Store ``newName`` in this helper's ``Name`` attribute."""
        self.Name = newName



class createEdgeIndexFilter(filterBase):
    """A helper filter that adds an edge-index attribute array to a graph.

    On execution the input is shallow-copied to the output. An integer
    edge-data array named ``self.Name`` ("edge_index" unless changed with
    SetName) is then added, whose value at position i is i.
    """

    def __init__(self, inputPort):
        """Connect to ``inputPort`` and name the generated array "edge_index"."""
        filterBase.__init__(self, inputPort)
        self.SetName("edge_index")

    def callback(self):
        """Execute method: copy the input and attach the edge-index array."""
        # Not called ``input`` so the builtin of that name is not shadowed.
        source_data = self._filter.GetInput()
        output = self._filter.GetOutput()
        output.ShallowCopy(source_data)

        # Query the edge count once; it does not change while the array is filled.
        num_edges = output.GetNumberOfEdges()
        indices = vtkIntArray()
        indices.SetName(self.Name)
        indices.SetNumberOfTuples(num_edges)
        for edge_id in range(num_edges):
            indices.SetValue(edge_id, edge_id)
        output.GetEdgeData().AddArray(indices)



class createVertexIndexFilter(filterBase):
    """A helper filter that adds a vertex-index attribute array to a graph.

    On execution the input is shallow-copied to the output. An integer
    vertex-data array named ``self.Name`` ("vertex_index" unless changed with
    SetName) is then added, whose value at position i is i.
    """

    def __init__(self, inputPort):
        """Connect to ``inputPort`` and name the generated array "vertex_index"."""
        filterBase.__init__(self, inputPort)
        self.SetName("vertex_index")

    def callback(self):
        """Execute method: copy the input and attach the vertex-index array."""
        # Not called ``input`` so the builtin of that name is not shadowed.
        source_data = self._filter.GetInput()
        output = self._filter.GetOutput()
        output.ShallowCopy(source_data)

        # Query the vertex count once; it does not change while the array is filled.
        num_vertices = output.GetNumberOfVertices()
        indices = vtkIntArray()
        indices.SetName(self.Name)
        indices.SetNumberOfTuples(num_vertices)
        for vertex_id in range(num_vertices):
            indices.SetValue(vertex_id, vertex_id)
        output.GetVertexData().AddArray(indices)



def main():
    """Show a subgraph of a random directed graph, selected by edge index.

    Builds a 50-vertex random graph and annotates it with "edge_index" and
    "vertex_index" arrays. It then extracts the edges whose "edge_index"
    value lies between 1 and 15, removing vertices left isolated, and shows
    the result in a labelled graph layout view. Finally it starts the
    render-window interactor and calls ``sys.exit`` with what the
    interactor's ``Start()`` returns.
    """
    ## Set up the theme
    theme = vtkViewTheme.CreateOceanTheme()
    theme.SetEdgeLabelColor(.9, .0, .25)
    theme.SetPointSize(5.0)
    theme.SetLineWidth(1.5)
    theme.SetPointOpacity(0.4)
    theme.SetCellOpacity(0.4)

    ## Create a random graph
    G = vtkRandomGraphSource()
    G.SetNumberOfVertices(50)
    G.SetEdgeProbability(.02)
    G.SetUseEdgeProbability(True)
    G.SetStartWithTree(True)
    G.SetDirected(True)

    ## Add edge-index attribute array, "edge_index"
    eindx_builder = createEdgeIndexFilter(G.GetOutputPort())

    ## Add vertex-index attribute array, "vertex_index"
    vindx_builder = createVertexIndexFilter(eindx_builder.GetOutputPort())

    ## Create an edge-selection using Thresholds: edges whose "edge_index"
    ## value is between 1 and 15.
    selection = vtkSelectionSource()
    selection.SetContentType(7)            # content type 7: thresholds
    selection.SetFieldType(4)              # field type 4: edges
    selection.SetArrayName("edge_index")
    selection.AddThreshold(1, 15)

    ## Extract a subgraph using our threshold selection
    subgraph = vtkExtractSelectedGraph()
    subgraph.SetRemoveIsolatedVertices(True)
    subgraph.SetInputConnection(vindx_builder.GetOutputPort())
    subgraph.SetSelectionConnection(selection.GetOutputPort())

    ## Create a view for the subgraph
    subGraphView = vtkGraphLayoutView()
    subGraphView.AddRepresentationFromInputConnection(subgraph.GetOutputPort())
    subGraphView.SetVertexLabelArrayName("vertex_index")
    subGraphView.SetEdgeLabelArrayName("edge_index")
    subGraphView.SetVertexLabelVisibility(True)
    subGraphView.SetEdgeLabelVisibility(True)
    subGraphView.ApplyViewTheme(theme)
    # The theme came from a VTK factory method whose returned reference the
    # caller must release; release it only after its last use (just above).
    theme.FastDelete()

    ## Set up the render window and attach sub graph view to it.
    subGraphWindow = vtkRenderWindow()
    subGraphWindow.SetSize(800, 800)

    subGraphView.SetupRenderWindow(subGraphWindow)
    sys.exit(subGraphWindow.GetInteractor().Start())


# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()