From e1230900b60b83d627fae433f1255e7274668c32 Mon Sep 17 00:00:00 2001
From: Daniel Plagge <plagge@cs.uni-duesseldorf.de>
Date: Mon, 2 May 2011 12:28:45 +0000
Subject: [PATCH] Added workaround for Kodkod's limitation when coping with
 quantification over n-ary (n>1) tuples

git-svn-id: https://cobra.cs.uni-duesseldorf.de/prob/trunk/experimental/plagge/probkodkod@7784 7aec93f6-bc54-0410-ac70-7d7c9efa889a
---
 src/de/stups/probkodkod/KodkodAnalysis.java | 247 ++++++++++++--------
 test/de/stups/probkodkod/KodkodTest.java    |  16 ++
 test/de/stups/probkodkod/relquant.kodkod    |   9 +
 3 files changed, 180 insertions(+), 92 deletions(-)
 create mode 100644 test/de/stups/probkodkod/relquant.kodkod

diff --git a/src/de/stups/probkodkod/KodkodAnalysis.java b/src/de/stups/probkodkod/KodkodAnalysis.java
index 018c2a4..cae49e6 100644
--- a/src/de/stups/probkodkod/KodkodAnalysis.java
+++ b/src/de/stups/probkodkod/KodkodAnalysis.java
@@ -139,6 +139,7 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 	private static final Map<String, ExprCompOperator> COMPOPS = new HashMap<String, ExprCompOperator>();
 	private static final Map<String, FormulaOperator> BINFORMOPS = new HashMap<String, FormulaOperator>();
 	private static final Map<String, Quantifier> QUANTIFIERS = new HashMap<String, Quantifier>();
+	private static final Map<String, FormulaOperator> QUANTIFIER_FOP = new HashMap<String, FormulaOperator>();
 
 	private static final Map<String, ExprOperator> BINEXPROPS = new HashMap<String, ExprOperator>();
 	private static final Map<String, ExprOperator> MULTIEXPROPS = new HashMap<String, ExprOperator>();
@@ -161,6 +162,10 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 		BINFORMOPS.put(AIffLogopBinary.class.getName(), FormulaOperator.IFF);
 		QUANTIFIERS.put(AAllQuantifier.class.getName(), Quantifier.ALL);
 		QUANTIFIERS.put(AExistsQuantifier.class.getName(), Quantifier.SOME);
+		QUANTIFIER_FOP.put(AAllQuantifier.class.getName(),
+				FormulaOperator.IMPLIES);
+		QUANTIFIER_FOP.put(AExistsQuantifier.class.getName(),
+				FormulaOperator.AND);
 
 		MULTIEXPROPS.put(AProductExprMultop.class.getName(),
 				ExprOperator.PRODUCT);
@@ -211,7 +216,7 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 	private final Stack<Formula> formulaStack = new Stack<Formula>();
 	private final Stack<Expression> expressionStack = new Stack<Expression>();
 	private final Stack<IntExpression> intExpressionStack = new Stack<IntExpression>();
-	private Map<String, Variable> variables = new HashMap<String, Variable>();
+	private Map<String, Expression> variables = new HashMap<String, Expression>();
 
 	public KodkodAnalysis(final KodkodSession session,
 			final IPrologTermOutput pto) {
@@ -263,15 +268,6 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 		}
 	}
 
-	// private long extractZNumber(final PZnumber znumber) {
-	// if (znumber instanceof APosZnumber)
-	// return extractInt(((APosZnumber) znumber).getNumber());
-	// else if (znumber instanceof ANegZnumber)
-	// return -extractInt(((ANegZnumber) znumber).getNumber());
-	// else
-	// throw new IllegalStateException("Unexpected ZNumber case");
-	// }
-
 	/**
 	 * A request has been entered. It will be added to the {@link KodkodSession}
 	 * and the first solutions will be send directly.
@@ -297,7 +293,7 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 	 */
 	@Override
 	public void outAList(final AList node) {
-		String id = extractIdentifier(node.getProblem());
+		final String id = extractIdentifier(node.getProblem());
 		final ImmutableProblem problem = session.getProblem(id);
 		if (problem != null) {
 			final int size = extractInt(node.getSize());
@@ -330,8 +326,8 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 
 	@Override
 	public void outAConstInnerformula(final AConstInnerformula node) {
-		String name = node.getLogConst().getClass().getName();
-		Formula formula = CONSTFORM.get(name);
+		final String name = node.getLogConst().getClass().getName();
+		final Formula formula = CONSTFORM.get(name);
 		if (formula == null)
 			throw new IllegalStateException("Unexpected constant " + name);
 		formulaStack.push(formula);
@@ -339,8 +335,8 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 
 	@Override
 	public void outAMultInnerformula(final AMultInnerformula node) {
-		String name = node.getMultiplicity().getClass().getName();
-		Multiplicity multiplicity = MULTIPLICITIES.get(name);
+		final String name = node.getMultiplicity().getClass().getName();
+		final Multiplicity multiplicity = MULTIPLICITIES.get(name);
 		if (multiplicity == null)
 			throw new IllegalStateException("Unexpected multiplicity " + name);
 		formulaStack.push(expressionStack.pop().apply(multiplicity));
@@ -348,13 +344,13 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 
 	@Override
 	public void outARelInnerformula(final ARelInnerformula node) {
-		String name = node.getLogopRel().getClass().getName();
-		ExprCompOperator op = COMPOPS.get(name);
+		final String name = node.getLogopRel().getClass().getName();
+		final ExprCompOperator op = COMPOPS.get(name);
 		if (op == null)
 			throw new IllegalStateException("Unexpected relation operator "
 					+ name);
-		Expression b = expressionStack.pop();
-		Expression a = expressionStack.pop();
+		final Expression b = expressionStack.pop();
+		final Expression a = expressionStack.pop();
 		formulaStack.push(a.compare(op, b));
 	}
 
@@ -366,7 +362,7 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 	@Override
 	public void caseAAndInnerformula(final AAndInnerformula node) {
 		final Collection<PFormula> nodes = node.getFormula();
-		int size = nodes == null ? 0 : nodes.size();
+		final int size = nodes == null ? 0 : nodes.size();
 		if (size == 0) {
 			formulaStack.push(Formula.TRUE);
 		} else if (size == 1) {
@@ -386,50 +382,55 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 
 	@Override
 	public void outABinaryInnerformula(final ABinaryInnerformula node) {
-		String name = node.getLogopBinary().getClass().getName();
-		FormulaOperator op = BINFORMOPS.get(name);
+		final String name = node.getLogopBinary().getClass().getName();
+		final FormulaOperator op = BINFORMOPS.get(name);
 		if (op == null)
 			throw new IllegalStateException("Unexpected operator " + name);
-		Formula b = formulaStack.pop();
-		Formula a = formulaStack.pop();
+		final Formula b = formulaStack.pop();
+		final Formula a = formulaStack.pop();
 		formulaStack.push(a.compose(op, b));
 	}
 
 	@Override
 	public void caseAQuantInnerformula(final AQuantInnerformula node) {
-		String name = node.getQuantifier().getClass().getName();
-		Quantifier quantifier = QUANTIFIERS.get(name);
-		if (quantifier == null)
+		final String name = node.getQuantifier().getClass().getName();
+		final Quantifier quantifier = QUANTIFIERS.get(name);
+		final FormulaOperator formulaOp = QUANTIFIER_FOP.get(name);
+		if (quantifier == null || formulaOp == null)
 			throw new IllegalStateException("Unexpected quantifier " + name);
-		Map<String, Variable> oldVars = variables;
-		variables = new HashMap<String, Variable>(variables);
-		Decls decls = extractDecls(node.getDecls());
+		final Map<String, Expression> oldVars = variables;
+		variables = new HashMap<String, Expression>(variables);
+		final Declaration decls = extractDecls(node.getDecls());
 		node.getFormula().apply(this);
-		formulaStack.push(formulaStack.pop().quantify(quantifier, decls));
+		final Formula formula = decls.applyFormula(formulaStack.pop(),
+				formulaOp);
+		final Formula quantify = formula.quantify(quantifier,
+				decls.getDeclarations());
+		formulaStack.push(quantify);
 		variables = oldVars;
 	}
 
 	@Override
 	public void outAIntInnerformula(final AIntInnerformula node) {
-		String name = node.getIntCompOp().getClass().getName();
-		IntCompOperator op = BININTCOMPS.get(name);
+		final String name = node.getIntCompOp().getClass().getName();
+		final IntCompOperator op = BININTCOMPS.get(name);
 		if (op == null)
 			throw new IllegalStateException(
 					"Unexpected integer comparision operator " + name);
-		IntExpression b = intExpressionStack.pop();
-		IntExpression a = intExpressionStack.pop();
+		final IntExpression b = intExpressionStack.pop();
+		final IntExpression a = intExpressionStack.pop();
 		formulaStack.push(a.compare(op, b));
 	}
 
 	@Override
 	public void outAFuncInnerformula(final AFuncInnerformula node) {
-		Expression range = expressionStack.pop();
-		Expression domain = expressionStack.pop();
-		Expression obj = expressionStack.pop();
-		PLogopFunction op = node.getLogopFunction();
-		Formula formula;
+		final Expression range = expressionStack.pop();
+		final Expression domain = expressionStack.pop();
+		final Expression obj = expressionStack.pop();
+		final PLogopFunction op = node.getLogopFunction();
+		final Formula formula;
 		if (obj instanceof Relation) {
-			Relation rel = (Relation) obj;
+			final Relation rel = (Relation) obj;
 			if (op instanceof ATotalLogopFunction) {
 				formula = rel.function(domain, range);
 			} else if (op instanceof APartialLogopFunction) {
@@ -438,7 +439,7 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 				throw new IllegalStateException("unexpected function operator "
 						+ op.getClass().getName());
 		} else {
-			Multiplicity mult;
+			final Multiplicity mult;
 			if (op instanceof ATotalLogopFunction) {
 				mult = Multiplicity.ONE;
 			} else if (op instanceof APartialLogopFunction) {
@@ -446,10 +447,10 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 			} else
 				throw new IllegalStateException("unexpected function operator "
 						+ op.getClass().getName());
-			Variable v = Variable.nary("__func", domain.arity());
-			Decl decl = v.declare(Multiplicity.ONE, domain);
-			Formula subset = obj.in(domain.product(range));
-			Formula unique = v.join(obj).apply(mult).forAll(decl);
+			final Variable v = Variable.nary("__func", domain.arity());
+			final Decl decl = v.declare(Multiplicity.ONE, domain);
+			final Formula subset = obj.in(domain.product(range));
+			final Formula unique = v.join(obj).apply(mult).forAll(decl);
 			formula = subset.and(unique);
 		}
 		formulaStack.push(formula);
@@ -477,19 +478,19 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 
 	@Override
 	public void outABinaryInnerexpression(final ABinaryInnerexpression node) {
-		String name = node.getExprBinop().getClass().getName();
-		ExprOperator op = BINEXPROPS.get(name);
+		final String name = node.getExprBinop().getClass().getName();
+		final ExprOperator op = BINEXPROPS.get(name);
 		if (op == null)
 			throw new IllegalStateException("Unexpected operator " + name);
-		Expression b = expressionStack.pop();
-		Expression a = expressionStack.pop();
+		final Expression b = expressionStack.pop();
+		final Expression a = expressionStack.pop();
 		expressionStack.push(a.compose(op, b));
 	}
 
 	@Override
 	public void outAUnaryInnerexpression(final AUnaryInnerexpression node) {
-		String name = node.getExprUnop().getClass().getName();
-		ExprOperator op = UNEXPROPS.get(name);
+		final String name = node.getExprUnop().getClass().getName();
+		final ExprOperator op = UNEXPROPS.get(name);
 		if (op == null)
 			throw new IllegalStateException("Unexpected operator " + name);
 		expressionStack.push(expressionStack.pop().apply(op));
@@ -497,15 +498,15 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 
 	@Override
 	public void outARelrefInnerexpression(final ARelrefInnerexpression node) {
-		String id = extractIdentifier(node.getIdentifier());
-		Relation relation = problem.lookupRelation(id);
+		final String id = extractIdentifier(node.getIdentifier());
+		final Relation relation = problem.lookupRelation(id);
 		expressionStack.push(relation);
 	}
 
 	@Override
 	public void outAVarrefInnerexpression(final AVarrefInnerexpression node) {
-		String id = extractIdentifier(node.getIdentifier());
-		Variable var = variables.get(id);
+		final String id = extractIdentifier(node.getIdentifier());
+		final Expression var = variables.get(id);
 		if (var == null)
 			throw new IllegalStateException("unknown variable " + id);
 		expressionStack.push(var);
@@ -513,8 +514,8 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 
 	@Override
 	public void outAConstInnerexpression(final AConstInnerexpression node) {
-		String name = node.getExprConst().getClass().getName();
-		Expression expression = CONSTEXPR.get(name);
+		final String name = node.getExprConst().getClass().getName();
+		final Expression expression = CONSTEXPR.get(name);
 		if (expression == null)
 			throw new IllegalStateException("Unexpected constant " + name);
 		expressionStack.push(expression);
@@ -522,14 +523,14 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 
 	@Override
 	public void outAConstInnerintexpression(final AConstInnerintexpression node) {
-		int value = extractInt(node.getZnumber());
+		final int value = extractInt(node.getZnumber());
 		intExpressionStack.push(IntConstant.constant(value));
 	}
 
 	@Override
 	public void outAPrjInnerexpression(final APrjInnerexpression node) {
-		int[] numbers = extractNumbers(node.getNumbers());
-		IntExpression[] prjs = new IntExpression[numbers.length];
+		final int[] numbers = extractNumbers(node.getNumbers());
+		final IntExpression[] prjs = new IntExpression[numbers.length];
 		for (int i = 0; i < numbers.length; i++) {
 			prjs[i] = IntConstant.constant(numbers[i]);
 		}
@@ -538,9 +539,9 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 
 	@Override
 	public void outACastInnerexpression(final ACastInnerexpression node) {
-		IntExpression integer = intExpressionStack.pop();
+		final IntExpression integer = intExpressionStack.pop();
 		final String castName = node.getExprCast().getClass().getName();
-		IntCastOperator op = INTCASTS.get(castName);
+		final IntCastOperator op = INTCASTS.get(castName);
 		if (op == null)
 			throw new IllegalStateException("Unexpected integer cast operator "
 					+ castName);
@@ -557,20 +558,20 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 
 	@Override
 	public void outACastInnerintexpression(final ACastInnerintexpression node) {
-		Expression expr = expressionStack.pop();
+		final Expression expr = expressionStack.pop();
 		intExpressionStack.push(expr.sum());
 	}
 
 	@Override
 	public void outABinaryInnerintexpression(
 			final ABinaryInnerintexpression node) {
-		String name = node.getIntexprBinop().getClass().getName();
-		IntOperator op = BININTEXPROPS.get(name);
+		final String name = node.getIntexprBinop().getClass().getName();
+		final IntOperator op = BININTEXPROPS.get(name);
 		if (op == null)
 			throw new IllegalStateException("Unexpected integer operator "
 					+ name);
-		IntExpression b = intExpressionStack.pop();
-		IntExpression a = intExpressionStack.pop();
+		final IntExpression b = intExpressionStack.pop();
+		final IntExpression a = intExpressionStack.pop();
 		intExpressionStack.push(a.compose(op, b));
 	}
 
@@ -581,11 +582,13 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 
 	@Override
 	public void caseACompInnerexpression(final ACompInnerexpression node) {
-		Map<String, Variable> oldVars = variables;
-		variables = new HashMap<String, Variable>(variables);
-		Decls decls = extractDecls(node.getDecls());
+		final Map<String, Expression> oldVars = variables;
+		variables = new HashMap<String, Expression>(variables);
+		final Declaration decls = extractDecls(node.getDecls());
 		node.getFormula().apply(this);
-		expressionStack.push(formulaStack.pop().comprehension(decls));
+		final Formula formula = decls.applyFormula(formulaStack.pop(),
+				FormulaOperator.AND);
+		expressionStack.push(formula.comprehension(decls.getDeclarations()));
 		variables = oldVars;
 	}
 
@@ -593,34 +596,61 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 		return reqtype instanceof APosReqtype;
 	}
 
-	private Decls extractDecls(PDecls node) {
-		Decls decls = null;
+	private Declaration extractDecls(PDecls node) {
+		final Declaration declaration = new Declaration();
 		while (node instanceof AConsDecls) {
-			AConsDecls cons = (AConsDecls) node;
-			String id = extractIdentifier(cons.getId());
-			int arity = extractInt(cons.getArity());
-			String mname = cons.getMultiplicity().getClass().getName();
-			Multiplicity multiplicity = MULTIPLICITIES.get(mname);
+			final AConsDecls cons = (AConsDecls) node;
+			final String id = extractIdentifier(cons.getId());
+			final int arity = extractInt(cons.getArity());
+			final String mname = cons.getMultiplicity().getClass().getName();
+			cons.getExpression().apply(this);
+			final Expression expression = expressionStack.pop();
+
+			final Multiplicity multiplicity = MULTIPLICITIES.get(mname);
 			if (multiplicity == null)
 				throw new IllegalStateException("Unexpected multiplicity "
 						+ mname);
-			cons.getExpression().apply(this);
-			Expression expression = expressionStack.pop();
-			Variable variable = Variable.nary(id, arity);
-			variables.put(id, variable);
-			Decl decl = variable.declare(multiplicity, expression);
-			if (decls == null) {
-				decls = decl;
+			if (arity > 1 && Multiplicity.ONE.equals(multiplicity)) {
+				createRelationDecl(id, arity, expression, declaration);
 			} else {
-				decls = decls.and(decl);
+				declare(id, arity, multiplicity, expression, declaration);
 			}
 
 			node = cons.getDecls();
 		}
-		if (decls == null)
+		if (declaration.isEmpty())
 			throw new IllegalStateException(
 					"no declarations in quantified formula");
-		return decls;
+		return declaration;
+	}
+
+	private void createRelationDecl(final String id, final int arity,
+			Expression expression, final Declaration declaration) {
+		Expression substitution = null;
+		for (int p = 0; p < arity; p++) {
+			final String tmpId = id + "_#_" + p;
+			final Expression prj = expression.project(IntConstant.constant(p));
+			final Decl tmpDecl = declare(tmpId, 1, Multiplicity.ONE, prj,
+					declaration);
+			final Variable tmpVar = tmpDecl.variable();
+			if (substitution == null) {
+				substitution = tmpVar;
+			} else {
+				substitution = substitution.product(tmpVar);
+			}
+		}
+		variables.put(id, substitution);
+		declaration.addFormula(substitution.in(expression));
+	}
+
+	private Decl declare(final String id, final int arity,
+			final Multiplicity multiplicity, final Expression expression,
+			final Declaration declaration) {
+		final Variable variable = Variable.nary(id, arity);
+		variables.put(id, variable);
+		final Decl decl = variable.declare(multiplicity, expression);
+		declaration.addDeclaration(decl);
+		return decl;
 	}
 
 	private void addRelations(final Collection<PRelation> nodes) {
@@ -718,7 +748,7 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 								+ ptuples.size());
 		}
 		final int arity = tupleType.getArity();
-		Collection<int[]> tuples = new ArrayList<int[]>();
+		final Collection<int[]> tuples = new ArrayList<int[]>();
 		for (final PTuple pTuple : ptuples) {
 			final ATuple aTuple = (ATuple) pTuple;
 			final int[] numbers = extractNumbers(aTuple.getNumbers());
@@ -731,7 +761,7 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 	}
 
 	private int[] extractNumbers(final Collection<TNumber> nodes) {
-		int[] result = new int[nodes.size()];
+		final int[] result = new int[nodes.size()];
 		int i = 0;
 		for (final TNumber node : nodes) {
 			result[i] = extractInt(node);
@@ -760,7 +790,7 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 	}
 
 	private Type[] extractTypes(final Collection<TIdentifier> nodes) {
-		Type[] result = new Type[nodes.size()];
+		final Type[] result = new Type[nodes.size()];
 		int i = 0;
 		for (final TIdentifier node : nodes) {
 			result[i] = lookupTypeInterval(node);
@@ -772,4 +802,37 @@ public class KodkodAnalysis extends DepthFirstAdapter {
 	private static String extractIdentifier(final TIdentifier node) {
 		return node == null ? null : node.getText();
 	}
+
+	private static final class Declaration {
+		private Decls declarations;
+		private Formula formula;
+
+		public void addDeclaration(final Decls decls) {
+			if (declarations == null) {
+				declarations = decls;
+			} else {
+				declarations = declarations.and(decls);
+			}
+		}
+
+		public boolean isEmpty() {
+			return declarations == null;
+		}
+
+		public void addFormula(final Formula form) {
+			if (formula == null) {
+				formula = form;
+			} else {
+				formula = formula.and(form);
+			}
+		}
+
+		public Decls getDeclarations() {
+			return declarations;
+		}
+
+		public Formula applyFormula(Formula orig, FormulaOperator op) {
+			return formula == null ? orig : formula.compose(op, orig);
+		}
+	}
 }
diff --git a/test/de/stups/probkodkod/KodkodTest.java b/test/de/stups/probkodkod/KodkodTest.java
index 1fa5a35..95d2521 100644
--- a/test/de/stups/probkodkod/KodkodTest.java
+++ b/test/de/stups/probkodkod/KodkodTest.java
@@ -15,6 +15,21 @@ import de.stups.probkodkod.test.Result;
 import de.stups.probkodkod.test.ResultSetBuilder;
 
 public class KodkodTest extends InteractionTestBase {
+	@Test
+	public void testQuantificationOnRelations() throws ParserException,
+			LexerException, IOException {
+		final String problem = load("relquant.kodkod");
+		sendMessage(problem + ".");
+
+		sendMessage("request relquant 10 pos ().");
+		List<SortedMap<String, Result>> sol = new LinkedList<SortedMap<String, Result>>();
+		getSolutions(false, sol);
+
+		ResultSetBuilder b = new ResultSetBuilder();
+		b.set("f", t(0, 0), t(0, 1), t(1, 0), t(1, 1)).store();
+		checkSolutions(b.toCollection(), sol);
+	}
+
 	@Test
 	public void testLoop() throws ParserException, LexerException, IOException,
 			InterruptedException {
@@ -159,4 +174,5 @@ public class KodkodTest extends InteractionTestBase {
 		b.single("x", t(0)).store();
 		checkSolutions(b.toCollection(), sol);
 	}
+
 }
diff --git a/test/de/stups/probkodkod/relquant.kodkod b/test/de/stups/probkodkod/relquant.kodkod
new file mode 100644
index 0000000..a2e3140
--- /dev/null
+++ b/test/de/stups/probkodkod/relquant.kodkod
@@ -0,0 +1,9 @@
+problem relquant
+
+ ((d 2) (r 2))
+ 
+ ((f subset d r))
+ 
+ (all 
+   ((e 2 one (product (relref d) (relref r))))
+   (in (varref e) (relref f)))
-- 
GitLab