diff --git a/maraboupy/Marabou.py b/maraboupy/Marabou.py index bc21c08917..8875acc746 100644 --- a/maraboupy/Marabou.py +++ b/maraboupy/Marabou.py @@ -3,6 +3,7 @@ - Christopher Lazarus - Kyle Julian - Andrew Wu + - Teruhiro Tagomori This file is part of the Marabou project. Copyright (c) 2017-2019 by the authors listed in the file AUTHORS @@ -29,6 +30,7 @@ from maraboupy.MarabouNetworkONNX import * except ImportError: warnings.warn("ONNX parser is unavailable because onnx or onnxruntime packages are not installed") +from maraboupy.MarabouNetworkComposition import MarabouNetworkComposition def read_nnet(filename, normalize=False): """Constructs a MarabouNetworkNnet object from a .nnet file @@ -74,6 +76,23 @@ def read_onnx(filename, inputNames=None, outputNames=None): """ return MarabouNetworkONNX(filename, inputNames, outputNames) +def read_onnx_with_threshold(filename, inputNames=None, outputNames=None, maxNumberOfLinearEquations=None): + """Constructs a MarabouNetworkComposition object from an ONNX file + + Args: + filename (str): Path to the ONNX file + inputNames (list of str, optional): List of node names corresponding to inputs + outputNames (list of str, optional): List of node names corresponding to outputs + maxNumberOfLinearEquations (int, optional): Threshold for the number of linear equations. + If the number of linear equations is greater than this threshold, + the network will be split into two networks. Defaults to None. 
+ + Returns: + :class:`~maraboupy.MarabouNetworkComposition.MarabouNetworkComposition` + """ + return MarabouNetworkComposition(filename, inputNames, outputNames, + maxNumberOfLinearEquations=maxNumberOfLinearEquations) + def load_query(filename): """Load the serialized inputQuery from the given filename diff --git a/maraboupy/MarabouNetworkComposition.py b/maraboupy/MarabouNetworkComposition.py new file mode 100644 index 0000000000..4e924ab65b --- /dev/null +++ b/maraboupy/MarabouNetworkComposition.py @@ -0,0 +1,301 @@ +''' +Top contributors (to current version): + - Teruhiro Tagomori + +This file is part of the Marabou project. +Copyright (c) 2017-2019 by the authors listed in the file AUTHORS +in the top-level source directory) and their institutional affiliations. +All rights reserved. See the file COPYING in the top-level source +directory for licensing information. + +MarabouNetworkComposition represents split subnets of a neural network with piecewise linear constraints derived from the ONNX format +''' + +import numpy as np +import onnx +import sys +import pathlib +import os +sys.path.insert(0, os.path.join(str(pathlib.Path(__file__).parent.absolute()), "../")) + +from maraboupy import Marabou +from maraboupy import MarabouCore +from maraboupy import MarabouNetwork +from maraboupy import MarabouNetworkONNX + +class MarabouNetworkComposition(MarabouNetwork.MarabouNetwork): + """Constructs a MarabouNetworkComposition object from an ONNX file + This class splits into subnets every time the number of linear equations reaches maxNumberOfLinearEquations. + It provides the function to propagate bounds for each subnet. + + Args: + filename (str): Path to the ONNX file + inputNames: (list of str, optional): List of node names corresponding to inputs + outputNames: (list of str, optional): List of node names corresponding to outputs + maxNumberOfLinearEquations (int, optional): Threshold for the number of linear equations. 
+ If the number of linear equations is greater than this threshold, + the network will be split into two networks. Defaults to None. + + Returns: + :class:`~maraboupy.MarabouNetworkComposition.MarabouNetworkComposition` + """ + def __init__(self, filename, inputNames=None, outputNames=None, maxNumberOfLinearEquations=None): + super().__init__() + self.shapeMap = {} + self.madeGraphEquations = [] + self.ipqs = [] + self.ipqToInVars = {} + self.ipqToOutVars = {} + self.inputVars, self.outputVars = self.getInputOutputVars(filename, inputNames, outputNames) + + # Instantiate the first subnet + network = MarabouNetworkONNX.MarabouNetworkONNX(filename, maxNumberOfLinearEquations=maxNumberOfLinearEquations) + + savedInputQueryName = 'q1.ipq' + network.saveQuery(savedInputQueryName) + self.ipqs.append(savedInputQueryName) + self.ipqToInVars[savedInputQueryName] = network.inputVars + self.ipqToOutVars[savedInputQueryName] = network.outputVars + + # index of ipq file + index = 2 + + while os.path.exists('post_split.onnx'): + # delete network + del network + + # Instantiate the next subnet + network = MarabouNetworkONNX.MarabouNetworkONNX('post_split.onnx', + maxNumberOfLinearEquations=maxNumberOfLinearEquations) + # name of the input query file + savedInputQueryName = f'q{index}.ipq' + + # save input query + network.saveQuery(savedInputQueryName) + + # append input query to the list + self.ipqs.append(savedInputQueryName) + + # save input and output variables so that this can map them to the next input query + self.ipqToInVars[savedInputQueryName] = network.inputVars + self.ipqToOutVars[savedInputQueryName] = network.outputVars + + # increment index + index += 1 + + def solve(self, filename="", verbose=True, options=None): + """Function to solve query represented by this network + + Args: + filename (string): Path for redirecting output (Only for the last subnet) + verbose (bool): If true, print out solution after solve finishes + options 
(:class:`~maraboupy.MarabouCore.Options`): Object for specifying Marabou options, defaults to None + + Returns: + (tuple): tuple containing: + - exitCode (str): A string representing the exit code (unsat/TIMEOUT/ERROR/UNKNOWN/QUIT_REQUESTED). + - vals (Dict[int, float]): Empty dictionary. This is for compatibility with MarabouNetwork. + - stats (:class:`~maraboupy.MarabouCore.Statistics`): A Statistics object describing how Marabou performed (Only for the last subnet) + """ + if options == None: + options = MarabouCore.Options() + + for i, ipqFile in enumerate(self.ipqs): + # load input query + ipq = Marabou.loadQuery(ipqFile) + + # If the first subnet, encode input variables with the input bounds of the original network + if i == 0: + self.encodeInput(ipq) + + # If not the first subnet, encode input variables with the output bounds of the previous subnet + if i > 0: + self.encodeCalculateInputBounds(ipq, i, bounds) + + if i == len(self.ipqs) - 1: + # If the last subnet, encode output variables with the output bounds of the original network + self.encodeOutput(ipq, i) + + # If the last subnet, propagate bounds and return the exit code, values, and statistics + exitCode, bounds, stats = MarabouCore.calculateBounds(ipq, options, str(filename)) + if exitCode == "": + exitCode = "UNKNOWN" + if verbose: + print(exitCode) + return [exitCode, {}, stats] + else: + # If not the last subnet, propagate bounds + _, bounds, _ = MarabouCore.calculateBounds(ipq, options) + + def encodeCalculateInputBounds(self, ipq, i, bounds): + """Function to encode input variables and set bounds for the current subnet + + Args: + ipq (:class:`~maraboupy.MarabouCore.InputQuery`): InputQuery object to encode input variables + i (int): Index of the previous subnet + bounds (dict): Dictionary containing bounds for variables of the previous subnet + + Returns: + None + + :meta private: + """ + # Output variables of the previous subnet + previousOutputVars = self.ipqToOutVars[f'q{i}.ipq'] + + # Input 
variables of the current subnet + currentInputVars = self.ipqToInVars[f'q{i+1}.ipq'] + + # Set bounds for the current subnet + for previousOutputVar, currentInputVar in zip(previousOutputVars, currentInputVars): + for previousOutputVarElement, currentInputVarElement in zip(previousOutputVar.flatten(), currentInputVar.flatten()): + ipq.setLowerBound(currentInputVarElement, bounds[previousOutputVarElement][0]) + ipq.setUpperBound(currentInputVarElement, bounds[previousOutputVarElement][1]) + + def encodeInput(self, ipq): + """Function to encode input variables + + Args: + ipq (:class:`~maraboupy.MarabouCore.InputQuery`): InputQuery object to encode input variables + Returns: + None + + :meta private: + """ + inputVars = self.ipqToInVars['q1.ipq'] + + # Set bounds for the first subnet + for array in inputVars: + for var in array.flatten(): + ipq.setLowerBound(var, self.lowerBounds[var]) + ipq.setUpperBound(var, self.upperBounds[var]) + + def encodeOutput(self, ipq, i): + """Function to encode output variables + Args: + ipq: (:class:`~maraboupy.MarabouCore.InputQuery`): InputQuery object to encode output variables + i: (int): Zero-based index of the current subnet + + Returns: + None + + :meta private: + """ + # Output variables of the current subnet + outputVars = self.ipqToOutVars[f'q{i+1}.ipq'] + + # Output variables of the original network + originalOutputVars = self.outputVars + + # Set bounds for the last subnet + for originalOutputVar, outputVar in zip(originalOutputVars, outputVars): + for originalOutputVarElement, outputVarElement in zip(originalOutputVar.flatten(), outputVar.flatten()): + if originalOutputVarElement in self.lowerBounds: + ipq.setLowerBound(outputVarElement, self.lowerBounds[originalOutputVarElement]) + if originalOutputVarElement in self.upperBounds: + ipq.setUpperBound(outputVarElement, self.upperBounds[originalOutputVarElement]) + + def getInputOutputVars(self, filename, inputNames, outputNames): + """Get input and output variables of an original network 
+ Args: + filename: (str): Path to the ONNX file + inputNames: (list of str): List of node names corresponding to inputs + outputNames: (list of str): List of node names corresponding to outputs + + :meta private: + """ + self.filename = filename + self.graph = onnx.load(filename).graph + + # Get default inputs/outputs if no names are provided + if not inputNames: + assert len(self.graph.input) >= 1 + initNames = [node.name for node in self.graph.initializer] + inputNames = [inp.name for inp in self.graph.input if inp.name not in initNames] + if not outputNames: + assert len(self.graph.output) >= 1 + initNames = [node.name for node in self.graph.initializer] + outputNames = [out.name for out in self.graph.output if out.name not in initNames] + elif isinstance(outputNames, str): + outputNames = [outputNames] + + # Check that input/outputs are in the graph + for name in inputNames: + if not len([nde for nde in self.graph.node if name in nde.input]): + raise RuntimeError("Input %s not found in graph!" % name) + for name in outputNames: + if not len([nde for nde in self.graph.node if name in nde.output]): + raise RuntimeError("Output %s not found in graph!" 
% name) + + self.inputNames = inputNames + self.outputNames = outputNames + + # Process the shapes and values of the graph while making Marabou equations and constraints + self.foundnInputFlags = 0 + + # Add shapes for the graph's inputs + inputVars = [] + for node in self.graph.input: + self.shapeMap[node.name] = list([dim.dim_value if dim.dim_value > 0 else 1 for dim in node.type.tensor_type.shape.dim]) + + # If we find one of the specified inputs, create new variables + if node.name in self.inputNames: + self.madeGraphEquations += [node.name] + self.foundnInputFlags += 1 + v = self.makeNewVariables(node.name) + inputVars += [v] + + # Add shapes for the graph's outputs + outputVars = [] + for node in self.graph.output: + self.shapeMap[node.name] = list([dim.dim_value if dim.dim_value > 0 else 1 for dim in node.type.tensor_type.shape.dim]) + + # If we find one of the specified outputs, create new variables + if node.name in self.outputNames: + self.madeGraphEquations += [node.name] + self.foundnInputFlags += 1 + v = self.makeNewVariables(node.name) + outputVars += [v] + return inputVars, outputVars + + def makeNewVariables(self, nodeName): + """Assuming the node's shape is known, return a set of new variables in the same shape + + Args: + nodeName (str): Name of node + + Returns: + (numpy array): Array of variable numbers + + :meta private: + """ + shape = self.shapeMap[nodeName] + size = np.prod(shape) + v = np.array([self.getNewVariable() for _ in range(size)]).reshape(shape) + assert all([np.equal(np.mod(i, 1), 0) for i in v.reshape(-1)]) # check if integers + return v + + def setLowerBound(self, x, v): + """Function to set lower bound for variable + + Args: + x (int): Variable number to set + v (float): Value representing lower bound + """ + if any(x in arr for arr in self.inputVars) or any(x in arr for arr in self.outputVars): + self.lowerBounds[x] = v + else: + raise RuntimeError("Can set bounds only on either input or output variables") + + def 
setUpperBound(self, x, v): + """Function to set upper bound for variable + + Args: + x (int): Variable number to set + v (float): Value representing upper bound + """ + if any(x in arr for arr in self.inputVars) or any(x in arr for arr in self.outputVars): + self.upperBounds[x] = v + else: + raise RuntimeError("Can set bounds only on either input or output variables") diff --git a/maraboupy/MarabouNetworkONNX.py b/maraboupy/MarabouNetworkONNX.py index 698d92da4b..6cadd71280 100644 --- a/maraboupy/MarabouNetworkONNX.py +++ b/maraboupy/MarabouNetworkONNX.py @@ -31,17 +31,34 @@ class MarabouNetworkONNX(MarabouNetwork): Returns: :class:`~maraboupy.Marabou.marabouNetworkONNX.marabouNetworkONNX` """ - def __init__(self, filename, inputNames=None, outputNames=None): + def __init__(self, filename, inputNames=None, outputNames=None, maxNumberOfLinearEquations=None): super().__init__() - self.readONNX(filename, inputNames, outputNames) + self.readONNX(filename, inputNames, outputNames, maxNumberOfLinearEquations=maxNumberOfLinearEquations) + + def readONNX(self, filename, inputNames=None, outputNames=None, preserveExistingConstraints=False, maxNumberOfLinearEquations=None): + """Read an ONNX file and create a MarabouNetworkONNX object + + Args: + filename: (str): Path to the ONNX file + inputNames: (list of str): List of node names corresponding to inputs + outputNames: (list of str): List of node names corresponding to outputs + preserveExistingConstraints (bool, optional): If True, preserve existing constraints in the network. Defaults to False. + maxNumberOfLinearEquations (int, optional): Threshold for the number of linear equations. + If the number of linear equations is greater than this threshold, + the network will be split into two networks. Defaults to None. 
+ :meta private: + """ + - def readONNX(self, filename, inputNames=None, outputNames=None, preserveExistingConstraints=False): if not preserveExistingConstraints: self.clear() self.filename = filename self.graph = onnx.load(filename).graph + if os.path.exists('post_split.onnx'): + os.remove('post_split.onnx') + # Setup input node names if inputNames is not None: # Check that input are in the graph @@ -71,7 +88,7 @@ def readONNX(self, filename, inputNames=None, outputNames=None, preserveExisting initNames = [node.name for node in self.graph.initializer] self.outputNames = [out.name for out in self.graph.output if out.name not in initNames] - ONNXParser.parse(self, self.graph, self.inputNames, self.outputNames) + ONNXParser.parse(self, self.graph, self.inputNames, self.outputNames, maxNumberOfLinearEquations=maxNumberOfLinearEquations) def getNode(self, nodeName): """Find the node in the graph corresponding to the given name @@ -169,4 +186,4 @@ def evaluateWithoutMarabou(self, inputValues): else: raise NotImplementedError("Inputs to network expected to be of type 'float', not %s" % onnxType) input_dict[inputName] = inputValues[i].reshape(self.inputVars[i].shape).astype(inputType) - return sess.run(self.outputNames, input_dict) \ No newline at end of file + return sess.run(self.outputNames, input_dict) diff --git a/maraboupy/parsers/ONNXParser.py b/maraboupy/parsers/ONNXParser.py index f9c80adc75..f66d35c133 100644 --- a/maraboupy/parsers/ONNXParser.py +++ b/maraboupy/parsers/ONNXParser.py @@ -32,7 +32,7 @@ class ONNXParser: """ @staticmethod - def parse(query:InputQueryBuilder, graph, inputNames:List[str], outputNames:List[str]): + def parse(query:InputQueryBuilder, graph, inputNames:List[str], outputNames:List[str], maxNumberOfLinearEquations=None): """ Parses the provided ONNX graph into constraints which are stored in the query argument. 
@@ -45,11 +45,11 @@ def parse(query:InputQueryBuilder, graph, inputNames:List[str], outputNames:List Returns: :class:`~maraboupy.Marabou.marabouNetworkONNX.marabouNetworkONNX` """ - parser = ONNXParser(query, graph, inputNames, outputNames) + parser = ONNXParser(query, graph, inputNames, outputNames, maxNumberOfLinearEquations=maxNumberOfLinearEquations) parser.parseGraph() - def __init__(self, query:InputQueryBuilder, graph, inputNames, outputNames): + def __init__(self, query:InputQueryBuilder, graph, inputNames, outputNames, maxNumberOfLinearEquations=None): """ Should not be called directly. Use `ONNXParser.parse` instead. @@ -66,6 +66,9 @@ def __init__(self, query:InputQueryBuilder, graph, inputNames, outputNames): self.constantMap = dict() self.shapeMap = dict() + self.maxNumberOfLinearEquations = maxNumberOfLinearEquations + self.thresholdReached = False + def parseGraph(self): """Read an ONNX file and create a MarabouNetworkONNX object @@ -84,7 +87,11 @@ def parseGraph(self): for outputName in self.outputNames: if outputName in self.constantMap: raise RuntimeError("Output variable %s is a constant, not the output of equations!" 
% outputName) - self.query.outputVars.extend([self.varMap[outputName] for outputName in self.outputNames]) + + for outputName in self.outputNames: + # If maxNumberOfLinearEquations is reached, the network is split and the outputVars are not set + if outputName in self.varMap: + self.query.outputVars.extend([self.varMap[outputName]]) def processGraph(self): """Processes the ONNX graph to produce Marabou equations @@ -142,8 +149,15 @@ def makeGraphEquations(self, nodeName, makeEquations): raise RuntimeError(err_msg) # Compute node's shape and create Marabou equations as needed + if self.thresholdReached: + return self.makeMarabouEquations(nodeName, makeEquations) + if self.maxNumberOfLinearEquations is not None: + if not self.thresholdReached and len(self.query.equList) > self.maxNumberOfLinearEquations: + if self.query.splitNetworkAtNode(nodeName, networkNamePostSplit='post_split.onnx'): + self.thresholdReached = True + # Create new variables when we find one of the inputs if nodeName in self.inputNames: self.makeNewVariables(nodeName) diff --git a/maraboupy/test/test_network_composition.py b/maraboupy/test/test_network_composition.py new file mode 100644 index 0000000000..f7d49d3e3d --- /dev/null +++ b/maraboupy/test/test_network_composition.py @@ -0,0 +1,118 @@ +from maraboupy import Marabou +import os + +# Global settings +OPT = Marabou.createOptions(verbosity = 0) # Turn off printing +TOL = 1e-6 # Set tolerance for checking Marabou evaluations +NETWORK_FOLDER = "../../resources/nnet/" # Folder for test networks +NETWORK_ONNX_FOLDER = "../../resources/onnx/" # Folder for test networks in onnx format + +def test_zero_split_unknown(): + """ + Tests that a network with no splits is correctly read and solved as unknown + """ + filename = 'fc1.onnx' + filename = os.path.join(os.path.dirname(__file__), NETWORK_ONNX_FOLDER, filename) + network = Marabou.read_onnx_with_threshold(filename, maxNumberOfLinearEquations=100) + + # check that the network has no splits + assert 
len(network.ipqs) == 1 + + network.setLowerBound(network.inputVars[0][0][0], 3) + network.setUpperBound(network.inputVars[0][0][0], 5) + network.setLowerBound(network.inputVars[0][0][1], 3) + network.setUpperBound(network.inputVars[0][0][1], 10) + + exitCode, _, _ = network.solve(options=OPT) + + assert exitCode == "UNKNOWN" + +def test_zero_split_unsat(): + """ + Tests that a network with no splits is correctly read and solved as unsat + """ + filename = 'fc1.onnx' + filename = os.path.join(os.path.dirname(__file__), NETWORK_ONNX_FOLDER, filename) + network = Marabou.read_onnx_with_threshold(filename, maxNumberOfLinearEquations=100) + + # check that the network has no splits + assert len(network.ipqs) == 1 + + network.setLowerBound(network.inputVars[0][0][0], 3) + network.setUpperBound(network.inputVars[0][0][0], 5) + network.setLowerBound(network.inputVars[0][0][1], 3) + network.setUpperBound(network.inputVars[0][0][1], 10) + + network.setLowerBound(network.outputVars[0][0][0], 100) + + exitCode, _, _ = network.solve(options=OPT) + + assert exitCode == "unsat" + + network = Marabou.read_onnx(filename) + network.setLowerBound(network.inputVars[0][0][0], 3) + network.setUpperBound(network.inputVars[0][0][0], 5) + network.setLowerBound(network.inputVars[0][0][1], 3) + network.setUpperBound(network.inputVars[0][0][1], 10) + + network.setLowerBound(network.outputVars[0][0][0], 100) + + exitCode2, _, _ = network.calculateBounds(options=OPT) + + # exitCode2 should be also unsat + assert exitCode == exitCode2 + +def test_one_split_unknown(): + """ + Tests that a network with one split is correctly read and solved as unknown + """ + filename = 'fc1.onnx' + filename = os.path.join(os.path.dirname(__file__), NETWORK_ONNX_FOLDER, filename) + network = Marabou.read_onnx_with_threshold(filename, maxNumberOfLinearEquations=50) + + # check that the network has one split + assert len(network.ipqs) == 2 + + network.setLowerBound(network.inputVars[0][0][0], 3) + 
network.setUpperBound(network.inputVars[0][0][0], 5) + network.setLowerBound(network.inputVars[0][0][1], 3) + network.setUpperBound(network.inputVars[0][0][1], 10) + + exitCode, _, _ = network.solve(options=OPT) + + assert exitCode == "UNKNOWN" + +def test_one_split_unsat(): + """ + Tests that a network with one split is correctly read and solved as unsat + """ + filename = 'fc1.onnx' + filename = os.path.join(os.path.dirname(__file__), NETWORK_ONNX_FOLDER, filename) + network = Marabou.read_onnx_with_threshold(filename, maxNumberOfLinearEquations=50) + + # check that the network has one split + assert len(network.ipqs) == 2 + + network.setLowerBound(network.inputVars[0][0][0], 3) + network.setUpperBound(network.inputVars[0][0][0], 5) + network.setLowerBound(network.inputVars[0][0][1], 3) + network.setUpperBound(network.inputVars[0][0][1], 10) + + network.setLowerBound(network.outputVars[0][0][0], 100) + + exitCode, _, _ = network.solve(options=OPT) + + assert exitCode == "unsat" + + network = Marabou.read_onnx(filename) + network.setLowerBound(network.inputVars[0][0][0], 3) + network.setUpperBound(network.inputVars[0][0][0], 5) + network.setLowerBound(network.inputVars[0][0][1], 3) + network.setUpperBound(network.inputVars[0][0][1], 10) + + network.setLowerBound(network.outputVars[0][0][0], 100) + + exitCode2, _, _ = network.calculateBounds(options=OPT) + + # exitCode2 should be also unsat + assert exitCode == exitCode2