From 6cf8f0c75670f66ea827fbaeae4916ebddd1ff22 Mon Sep 17 00:00:00 2001
From: Miles Vella <673-vella@users.noreply.gitlab.cs.uni-duesseldorf.de>
Date: Wed, 12 Mar 2025 17:14:59 +0100
Subject: [PATCH] Implement basic support for LET expression & predicate

---
 .../de/tlc4b/analysis/MachineContext.java     |  26 +++++
 src/main/java/de/tlc4b/analysis/Renamer.java  |  30 +++++
 .../java/de/tlc4b/analysis/Typechecker.java   |  32 ++++++
 .../analysis/UnsupportedConstructsFinder.java |   3 -
 .../typerestriction/TypeRestrictor.java       |  18 +++
 .../java/de/tlc4b/prettyprint/TLAPrinter.java | 104 ++++++++++++++++--
 .../analysis/UnsupportedConstructsTest.java   |  12 --
 .../de/tlc4b/prettyprint/OperationsTest.java  |  43 +++++++-
 8 files changed, 241 insertions(+), 27 deletions(-)

diff --git a/src/main/java/de/tlc4b/analysis/MachineContext.java b/src/main/java/de/tlc4b/analysis/MachineContext.java
index ff19458..ae6f85e 100644
--- a/src/main/java/de/tlc4b/analysis/MachineContext.java
+++ b/src/main/java/de/tlc4b/analysis/MachineContext.java
@@ -635,6 +635,31 @@ public class MachineContext extends DepthFirstAdapter {
 		}
 		node.getPredicate().apply(this);
 		node.getSubstitution().apply(this);
+		contextTable.remove(contextTable.size() - 1);
+	}
+
+	@Override
+	public void caseALetExpressionExpression(ALetExpressionExpression node) {
+		contextTable.add(new LinkedHashMap<>());
+		List<PExpression> copy = new ArrayList<>(node.getIdentifiers());
+		for (PExpression e : copy) {
+			putLocalVariableIntoCurrentScope((AIdentifierExpression) e);
+		}
+		node.getAssignment().apply(this);
+		node.getExpr().apply(this);
+		contextTable.remove(contextTable.size() - 1);
+	}
+
+	@Override
+	public void caseALetPredicatePredicate(ALetPredicatePredicate node) {
+		contextTable.add(new LinkedHashMap<>());
+		List<PExpression> copy = new ArrayList<>(node.getIdentifiers());
+		for (PExpression e : copy) {
+			putLocalVariableIntoCurrentScope((AIdentifierExpression) e);
+		}
+		node.getAssignment().apply(this);
+		node.getPred().apply(this);
+		contextTable.remove(contextTable.size() - 1);
 	}
 
 	@Override
@@ -646,6 +671,7 @@ public class MachineContext extends DepthFirstAdapter {
 		}
 		node.getWhere().apply(this);
 		node.getThen().apply(this);
+		contextTable.remove(contextTable.size() - 1);
 	}
 
 	@Override
diff --git a/src/main/java/de/tlc4b/analysis/Renamer.java b/src/main/java/de/tlc4b/analysis/Renamer.java
index d98a9ff..dd9b2d4 100644
--- a/src/main/java/de/tlc4b/analysis/Renamer.java
+++ b/src/main/java/de/tlc4b/analysis/Renamer.java
@@ -19,6 +19,8 @@ import de.be4.classicalb.core.parser.node.AGeneralProductExpression;
 import de.be4.classicalb.core.parser.node.AGeneralSumExpression;
 import de.be4.classicalb.core.parser.node.AIdentifierExpression;
 import de.be4.classicalb.core.parser.node.ALambdaExpression;
+import de.be4.classicalb.core.parser.node.ALetExpressionExpression;
+import de.be4.classicalb.core.parser.node.ALetPredicatePredicate;
 import de.be4.classicalb.core.parser.node.ALetSubstitution;
 import de.be4.classicalb.core.parser.node.AOperation;
 import de.be4.classicalb.core.parser.node.APredicateDefinitionDefinition;
@@ -350,6 +352,34 @@ public class Renamer extends DepthFirstAdapter {
 		removeLastContext();
 	}
 
+	@Override
+	public void caseALetExpressionExpression(ALetExpressionExpression node) {
+		List<PExpression> list = new ArrayList<>(node.getIdentifiers());
+		evalBoundedVariables(node, list);
+
+		List<PExpression> copy = new ArrayList<>(node.getIdentifiers());
+		for (PExpression e : copy) {
+			e.apply(this);
+		}
+		node.getAssignment().apply(this);
+		node.getExpr().apply(this);
+		removeLastContext();
+	}
+
+	@Override
+	public void caseALetPredicatePredicate(ALetPredicatePredicate node) {
+		List<PExpression> list = new ArrayList<>(node.getIdentifiers());
+		evalBoundedVariables(node, list);
+
+		List<PExpression> copy = new ArrayList<>(node.getIdentifiers());
+		for (PExpression e : copy) {
+			e.apply(this);
+		}
+		node.getAssignment().apply(this);
+		node.getPred().apply(this);
+		removeLastContext();
+	}
+
 	@Override
 	public void caseAVarSubstitution(AVarSubstitution node) {
 		List<PExpression> list = new ArrayList<>(node.getIdentifiers());
diff --git a/src/main/java/de/tlc4b/analysis/Typechecker.java b/src/main/java/de/tlc4b/analysis/Typechecker.java
index 55e34f7..4503bc3 100644
--- a/src/main/java/de/tlc4b/analysis/Typechecker.java
+++ b/src/main/java/de/tlc4b/analysis/Typechecker.java
@@ -605,6 +605,38 @@ public class Typechecker extends DepthFirstAdapter implements ITypechecker {
 		node.getSubstitution().apply(this);
 	}
 
+	@Override
+	public void caseALetExpressionExpression(ALetExpressionExpression node) {
+		List<PExpression> copy = new ArrayList<>(node.getIdentifiers());
+		for (PExpression e : copy) {
+			AIdentifierExpression v = (AIdentifierExpression) e;
+			setType(v, new UntypedType());
+		}
+
+		setType(node.getAssignment(), BoolType.getInstance());
+		node.getAssignment().apply(this);
+
+		setType(node.getExpr(), new UntypedType());
+		node.getExpr().apply(this);
+		unify(getType(node), getType(node.getExpr()), node);
+	}
+
+	@Override
+	public void caseALetPredicatePredicate(ALetPredicatePredicate node) {
+		List<PExpression> copy = new ArrayList<>(node.getIdentifiers());
+		for (PExpression e : copy) {
+			AIdentifierExpression v = (AIdentifierExpression) e;
+			setType(v, new UntypedType());
+		}
+
+		setType(node.getAssignment(), BoolType.getInstance());
+		node.getAssignment().apply(this);
+
+		setType(node.getPred(), BoolType.getInstance());
+		node.getPred().apply(this);
+		unify(getType(node), getType(node.getPred()), node);
+	}
+
 	/****************************************************************************
 	 * Arithmetic operators *
 	 ****************************************************************************/
diff --git a/src/main/java/de/tlc4b/analysis/UnsupportedConstructsFinder.java b/src/main/java/de/tlc4b/analysis/UnsupportedConstructsFinder.java
index 15cca65..55459d4 100644
--- a/src/main/java/de/tlc4b/analysis/UnsupportedConstructsFinder.java
+++ b/src/main/java/de/tlc4b/analysis/UnsupportedConstructsFinder.java
@@ -37,9 +37,6 @@ public class UnsupportedConstructsFinder extends DepthFirstAdapter {
 		// should have been rewritten in parser
 		add(AIfPredicatePredicate.class);
 		add(AIfElsifPredicatePredicate.class);
-
-		add(ALetExpressionExpression.class);
-		add(ALetPredicatePredicate.class);
 	}
 
 	private static void add(Class<? extends Node> clazz) {
diff --git a/src/main/java/de/tlc4b/analysis/typerestriction/TypeRestrictor.java b/src/main/java/de/tlc4b/analysis/typerestriction/TypeRestrictor.java
index 2d2a7d6..9cc265d 100644
--- a/src/main/java/de/tlc4b/analysis/typerestriction/TypeRestrictor.java
+++ b/src/main/java/de/tlc4b/analysis/typerestriction/TypeRestrictor.java
@@ -25,6 +25,8 @@ import de.be4.classicalb.core.parser.node.AImplicationPredicate;
 import de.be4.classicalb.core.parser.node.AInitialisationMachineClause;
 import de.be4.classicalb.core.parser.node.AIntersectionExpression;
 import de.be4.classicalb.core.parser.node.ALambdaExpression;
+import de.be4.classicalb.core.parser.node.ALetExpressionExpression;
+import de.be4.classicalb.core.parser.node.ALetPredicatePredicate;
 import de.be4.classicalb.core.parser.node.ALetSubstitution;
 import de.be4.classicalb.core.parser.node.AMemberPredicate;
 import de.be4.classicalb.core.parser.node.ANotMemberPredicate;
@@ -456,6 +458,22 @@ public class TypeRestrictor extends DepthFirstAdapter {
 		createRestrictedTypeofLocalVariables(new HashSet<>(node.getIdentifiers()), false);
 	}
 
+	@Override
+	public void inALetExpressionExpression(ALetExpressionExpression node) {
+		super.inALetExpressionExpression(node);
+		// no type restriction, will use "LET var == value IN expr" construct
+	}
+
+	@Override
+	public void inALetPredicatePredicate(ALetPredicatePredicate node) {
+		HashSet<Node> list = new HashSet<>();
+		List<PExpression> copy = new ArrayList<>(node.getIdentifiers());
+		list.addAll(copy);
+		list.addAll(getExpectedIdentifier(node));
+		analysePredicate(node.getAssignment(), list, new HashSet<>());
+		createRestrictedTypeofLocalVariables(new HashSet<>(node.getIdentifiers()), false);
+	}
+
 	private Hashtable<Node, Node> variablesHashTable;
 
 	public void inABecomesSuchSubstitution(ABecomesSuchSubstitution node) {
diff --git a/src/main/java/de/tlc4b/prettyprint/TLAPrinter.java b/src/main/java/de/tlc4b/prettyprint/TLAPrinter.java
index aa0c57c..9474ba3 100644
--- a/src/main/java/de/tlc4b/prettyprint/TLAPrinter.java
+++ b/src/main/java/de/tlc4b/prettyprint/TLAPrinter.java
@@ -2,10 +2,12 @@ package de.tlc4b.prettyprint;
 
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 
 import de.be4.classicalb.core.parser.analysis.DepthFirstAdapter;
 import de.be4.classicalb.core.parser.node.*;
@@ -28,6 +30,7 @@ import de.tlc4b.btypes.PairType;
 import de.tlc4b.btypes.SetType;
 import de.tlc4b.btypes.StructType;
 import de.tlc4b.btypes.UntypedType;
+import de.tlc4b.exceptions.TranslationException;
 import de.tlc4b.ltl.LTLFormulaVisitor;
 import de.tlc4b.tla.ConfigFile;
 import de.tlc4b.tla.TLADefinition;
@@ -997,20 +1000,18 @@ public class TLAPrinter extends DepthFirstAdapter {
 	@Override
 	public void caseALetSubstitution(ALetSubstitution node) {
 		inALetSubstitution(node);
+		moduleStringAppend("\\E ");
 		List<PExpression> copy = new ArrayList<>(node.getIdentifiers());
-		if (!copy.isEmpty()) {
-			moduleStringAppend("\\E ");
-			for (int i = 0; i < copy.size(); i++) {
-				PExpression e = copy.get(i);
-				e.apply(this);
-				moduleStringAppend(" \\in ");
-				typeRestrictor.getRestrictedNode(e).apply(this);
-				if (i < copy.size() - 1) {
-					moduleStringAppend(", ");
-				}
+		for (int i = 0; i < copy.size(); i++) {
+			PExpression e = copy.get(i);
+			e.apply(this);
+			moduleStringAppend(" \\in ");
+			typeRestrictor.getRestrictedNode(e).apply(this);
+			if (i < copy.size() - 1) {
+				moduleStringAppend(", ");
 			}
-			moduleStringAppend(" : ");
 		}
+		moduleStringAppend(" : ");
 
 		if (typeRestrictor.isARemovedNode(node.getPredicate())) {
 			moduleStringAppend("TRUE");
@@ -1025,6 +1026,87 @@ public class TLAPrinter extends DepthFirstAdapter {
 		outALetSubstitution(node);
 	}
 
+	@Override
+	public void caseALetPredicatePredicate(ALetPredicatePredicate node) {
+		inALetPredicatePredicate(node);
+		moduleStringAppend("\\E ");
+		List<PExpression> copy = new ArrayList<>(node.getIdentifiers());
+		for (int i = 0; i < copy.size(); i++) {
+			PExpression e = copy.get(i);
+			e.apply(this);
+			moduleStringAppend(" \\in ");
+			typeRestrictor.getRestrictedNode(e).apply(this);
+			if (i < copy.size() - 1) {
+				moduleStringAppend(", ");
+			}
+		}
+		moduleStringAppend(" : ");
+
+		if (typeRestrictor.isARemovedNode(node.getAssignment())) {
+			moduleStringAppend("TRUE");
+		} else {
+			node.getAssignment().apply(this);
+		}
+
+		moduleStringAppend(" /\\ ");
+		node.getPred().apply(this);
+
+		outALetPredicatePredicate(node);
+	}
+
+	private static Map<String, PExpression> getValuesFromEqualPredicates(PPredicate p, Map<String, PExpression> values) {
+		if (p instanceof AEqualPredicate) {
+			AEqualPredicate eq = (AEqualPredicate) p;
+			PExpression left = eq.getLeft();
+			if (left instanceof AIdentifierExpression) {
+				String s = Utils.getAIdentifierAsString((AIdentifierExpression) left);
+				if (values.containsKey(s)) {
+					throw new TranslationException("invalid predicate in LET expr: " + p);
+				}
+				values.put(s, eq.getRight());
+			} else {
+				throw new TranslationException("invalid predicate in LET expr: " + p);
+			}
+		} else if (p instanceof AConjunctPredicate) {
+			AConjunctPredicate conj = (AConjunctPredicate) p;
+			getValuesFromEqualPredicates(conj.getLeft(), values);
+			getValuesFromEqualPredicates(conj.getRight(), values);
+		} else {
+			throw new TranslationException("invalid predicate in LET expr: " + p);
+		}
+
+		return values;
+	}
+
+	@Override
+	public void caseALetExpressionExpression(ALetExpressionExpression node) {
+		inALetExpressionExpression(node);
+		moduleStringAppend("LET ");
+		Map<String, PExpression> values = getValuesFromEqualPredicates(node.getAssignment(), new HashMap<>());
+		List<PExpression> copy = new ArrayList<>(node.getIdentifiers());
+		for (int i = 0; i < copy.size(); i++) {
+			PExpression e = copy.get(i);
+			e.apply(this);
+			moduleStringAppend(" == ");
+
+			String identifier = Utils.getAIdentifierAsString((AIdentifierExpression) e);
+			PExpression value = values.get(identifier);
+			if (value == null) {
+				throw new TranslationException("no equals predicate for identifier " + identifier + " in LET expr");
+			}
+			value.apply(this);
+
+			if (i < copy.size() - 1) {
+				moduleStringAppend(", ");
+			}
+		}
+
+		moduleStringAppend(" IN ");
+		node.getExpr().apply(this);
+
+		outALetExpressionExpression(node);
+	}
+
 	@Override
 	public void caseAOperation(AOperation node) {
 		String name = renamer.getNameOfRef(node);
diff --git a/src/test/java/de/tlc4b/analysis/UnsupportedConstructsTest.java b/src/test/java/de/tlc4b/analysis/UnsupportedConstructsTest.java
index dd99257..d1720ca 100644
--- a/src/test/java/de/tlc4b/analysis/UnsupportedConstructsTest.java
+++ b/src/test/java/de/tlc4b/analysis/UnsupportedConstructsTest.java
@@ -25,16 +25,4 @@ public class UnsupportedConstructsTest {
 		final String machine = "MACHINE M FREETYPES F = F1, F2(INTEGER) END";
 		translate(machine);
 	}
-
-	@Test(expected = NotSupportedException.class)
-	public void testLetExpr() throws Exception {
-		final String machine = "MACHINE M VARIABLES x INVARIANT x : INTEGER INITIALISATION x := (LET foo BE foo=42 IN foo END) END";
-		translate(machine);
-	}
-
-	@Test(expected = NotSupportedException.class)
-	public void testLetPred() throws Exception {
-		final String machine = "MACHINE M VARIABLES x INVARIANT x : INTEGER INITIALISATION x : (LET foo BE foo=42 IN x=foo END) END";
-		translate(machine);
-	}
 }
diff --git a/src/test/java/de/tlc4b/prettyprint/OperationsTest.java b/src/test/java/de/tlc4b/prettyprint/OperationsTest.java
index 1741ed1..bad0b01 100644
--- a/src/test/java/de/tlc4b/prettyprint/OperationsTest.java
+++ b/src/test/java/de/tlc4b/prettyprint/OperationsTest.java
@@ -426,5 +426,46 @@ public class OperationsTest {
 				+ "====";
 		compare(expected, machine);
 	}
-	
+
+	@Test
+	public void testOperationWithLetExpr() throws Exception {
+		String machine = "MACHINE test\n"
+				+ "VARIABLES x\n"
+				+ "INVARIANT x : 1..10\n"
+				+ "INITIALISATION x := 1\n"
+				+ "OPERATIONS\n"
+				+ "inc = x := LET y BE y = x+1 IN y+1 END\n"
+				+ "END";
+
+		String expected = "---- MODULE test ----\n"
+				+ "EXTENDS Naturals\n"
+				+ "VARIABLES x\n"
+				+ "Invariant1 == x \\in 1..10\n"
+				+ "Init == x = 1\n"
+				+ "inc == x' = LET y == x+1 IN y+1\n"
+				+ "Next == \\/ inc\n"
+				+ "====";
+		compare(expected, machine);
+	}
+
+	@Test
+	public void testOperationWithLetPred() throws Exception {
+		String machine = "MACHINE test\n"
+				+ "VARIABLES x\n"
+				+ "INVARIANT x : 1..10\n"
+				+ "INITIALISATION x := 1\n"
+				+ "OPERATIONS\n"
+				+ "inc = SELECT (LET y BE y = 1 IN x=y END) THEN x := x+1 END\n"
+				+ "END";
+
+		String expected = "---- MODULE test ----\n"
+				+ "EXTENDS Naturals\n"
+				+ "VARIABLES x\n"
+				+ "Invariant1 == x \\in 1..10\n"
+				+ "Init == x = 1 \n"
+				+ "inc == (\\E y \\in {1} : TRUE /\\ x = y /\\ x' = x + 1)\n"
+				+ "Next == \\/ inc\n"
+				+ "====";
+		compare(expected, machine);
+	}
 }
-- 
GitLab