From c8d48040b6e6b271f511d0977754c7ea36ff7cb7 Mon Sep 17 00:00:00 2001
From: dgelessus <dgelessus@users.noreply.github.com>
Date: Mon, 8 Nov 2021 13:46:58 +0100
Subject: [PATCH] Automatically insert local variables when parsing formulas

This matches what almost all code already did anyway. It will also help
with fixing the problems with local variables in Event-B mode - the fix
will require conditionally parsing multiple formulas.
---
 .../java/de/prob2/jupyter/CommandUtils.java   |  8 +++++
 .../java/de/prob2/jupyter/ProBKernel.java     | 30 +++++++++++++------
 .../prob2/jupyter/commands/AssertCommand.java |  5 ++--
 .../de/prob2/jupyter/commands/DotCommand.java | 21 ++++++-------
 .../prob2/jupyter/commands/EvalCommand.java   |  5 ++--
 .../prob2/jupyter/commands/FindCommand.java   |  7 ++---
 .../de/prob2/jupyter/commands/LetCommand.java |  5 ++--
 .../jupyter/commands/PrettyPrintCommand.java  |  5 ++--
 .../prob2/jupyter/commands/SolveCommand.java  |  3 +-
 .../prob2/jupyter/commands/TableCommand.java  |  3 +-
 .../prob2/jupyter/commands/TypeCommand.java   |  3 +-
 11 files changed, 51 insertions(+), 44 deletions(-)

diff --git a/src/main/java/de/prob2/jupyter/CommandUtils.java b/src/main/java/de/prob2/jupyter/CommandUtils.java
index 65a6c09..659e297 100644
--- a/src/main/java/de/prob2/jupyter/CommandUtils.java
+++ b/src/main/java/de/prob2/jupyter/CommandUtils.java
@@ -211,6 +211,14 @@ public final class CommandUtils {
 		});
 	}
 	
+	public static <T> T withSourceCode(final @NotNull IEvalElement formula, final Supplier<T> action) {
+		return withSourceCode(formula.getCode(), action);
+	}
+	
+	public static void withSourceCode(final @NotNull IEvalElement formula, final Runnable action) {
+		withSourceCode(formula.getCode(), action);
+	}
+	
 	public static @NotNull String insertLetVariables(final @NotNull String code, final @NotNull Map<@NotNull String, @NotNull String> variables) {
 		if (variables.isEmpty()) {
 			return code;
diff --git a/src/main/java/de/prob2/jupyter/ProBKernel.java b/src/main/java/de/prob2/jupyter/ProBKernel.java
index 0751bc9..bbfbf75 100644
--- a/src/main/java/de/prob2/jupyter/ProBKernel.java
+++ b/src/main/java/de/prob2/jupyter/ProBKernel.java
@@ -347,12 +347,14 @@ public final class ProBKernel extends BaseKernel {
 	/**
 	 * Parse the given formula code into an {@link IEvalElement}.
 	 * The language used for parsing depends on the current formula language (see {@link #getCurrentFormulaLanguage()}.
+	 * Unlike {@link #parseFormula(String, FormulaExpand)},
+	 * this method does not automatically insert local variables into the code.
 	 * 
 	 * @param code the formula code
 	 * @param expand the expansion mode to use when evaluating the formula
 	 * @return the parsed formula
 	 */
-	public IEvalElement parseFormula(final String code, final FormulaExpand expand) {
+	public IEvalElement parseFormulaWithoutLetVariables(final String code, final FormulaExpand expand) {
 		switch (this.getCurrentFormulaLanguage()) {
 			case DEFAULT:
 				return this.animationSelector.getCurrentTrace().getModel().parseFormula(code, expand);
@@ -368,16 +370,30 @@ public final class ProBKernel extends BaseKernel {
 		}
 	}
 	
+	/**
+	 * Parse the given formula code into an {@link IEvalElement}.
+	 * The language used for parsing depends on the current formula language (see {@link #getCurrentFormulaLanguage()}.
+	 * Any currently defined local variables are automatically inserted before parsing.
+	 * This can be avoided if necessary using {@link #parseFormulaWithoutLetVariables(String, FormulaExpand)}.
+	 * 
+	 * @param code the formula code
+	 * @param expand the expansion mode to use when evaluating the formula
+	 * @return the parsed formula
+	 */
+	public IEvalElement parseFormula(final String code, final FormulaExpand expand) {
+		return this.parseFormulaWithoutLetVariables(CommandUtils.insertLetVariables(code, this.getVariables()), expand);
+	}
+	
 	public @NotNull DisplayData executeOperation(final @NotNull String name, final @Nullable String predicate) {
 		final Trace trace = this.animationSelector.getCurrentTrace();
 		final String translatedOpName = Transition.unprettifyName(name);
-		final String modifiedPredicate;
+		final IEvalElement parsedPredicate;
 		if (predicate == null) {
-			modifiedPredicate = "1=1";
+			parsedPredicate = this.parseFormulaWithoutLetVariables("1=1", FormulaExpand.EXPAND);
 		} else {
-			modifiedPredicate = this.insertLetVariables(predicate);
+			parsedPredicate = this.parseFormula(predicate, FormulaExpand.EXPAND);
 		}
-		final List<Transition> ops = trace.getStateSpace().transitionFromPredicate(trace.getCurrentState(), translatedOpName, modifiedPredicate, 1);
+		final List<Transition> ops = trace.getStateSpace().transitionFromPredicate(trace.getCurrentState(), translatedOpName, parsedPredicate, 1);
 		assert !ops.isEmpty();
 		final Transition op = ops.get(0);
 		
@@ -827,8 +843,4 @@ public final class ProBKernel extends BaseKernel {
 			throw e2;
 		}
 	}
-	
-	public @NotNull String insertLetVariables(final @NotNull String code) {
-		return CommandUtils.insertLetVariables(code, this.getVariables());
-	}
 }
diff --git a/src/main/java/de/prob2/jupyter/commands/AssertCommand.java b/src/main/java/de/prob2/jupyter/commands/AssertCommand.java
index b879046..1ba3378 100644
--- a/src/main/java/de/prob2/jupyter/commands/AssertCommand.java
+++ b/src/main/java/de/prob2/jupyter/commands/AssertCommand.java
@@ -69,9 +69,8 @@ public final class AssertCommand implements Command {
 	@Override
 	public @NotNull DisplayData run(final @NotNull ParsedArguments args) {
 		final ProBKernel kernel = this.kernelProvider.get();
-		final String code = kernel.insertLetVariables(args.get(FORMULA_PARAM));
-		final IEvalElement formula = kernel.parseFormula(code, FormulaExpand.TRUNCATE);
-		final AbstractEvalResult result = CommandUtils.withSourceCode(code, () -> this.animationSelector.getCurrentTrace().evalCurrent(formula));
+		final IEvalElement formula = kernel.parseFormula(args.get(FORMULA_PARAM), FormulaExpand.TRUNCATE);
+		final AbstractEvalResult result = CommandUtils.withSourceCode(formula, () -> this.animationSelector.getCurrentTrace().evalCurrent(formula));
 		if (result instanceof EvalResult && "TRUE".equals(((EvalResult)result).getValue())) {
 			// Use EvalResult.TRUE instead of the real result so that solution variables are not displayed.
 			return CommandUtils.displayDataForEvalResult(EvalResult.TRUE);
diff --git a/src/main/java/de/prob2/jupyter/commands/DotCommand.java b/src/main/java/de/prob2/jupyter/commands/DotCommand.java
index ef05549..ea3bd3c 100644
--- a/src/main/java/de/prob2/jupyter/commands/DotCommand.java
+++ b/src/main/java/de/prob2/jupyter/commands/DotCommand.java
@@ -88,16 +88,12 @@ public final class DotCommand implements Command {
 	@Override
 	public @NotNull DisplayData run(final @NotNull ParsedArguments args) {
 		final String command = args.get(COMMAND_PARAM);
-		final List<IEvalElement> dotCommandArgs;
-		final String code;
+		final IEvalElement formula;
 		if (args.get(FORMULA_PARAM).isPresent()) {
 			final ProBKernel kernel = this.kernelProvider.get();
-			code = kernel.insertLetVariables(args.get(FORMULA_PARAM).get());
-			final IEvalElement formula = CommandUtils.withSourceCode(code, () -> kernel.parseFormula(code, FormulaExpand.EXPAND));
-			dotCommandArgs = Collections.singletonList(formula);
+			formula = kernel.parseFormula(args.get(FORMULA_PARAM).get(), FormulaExpand.EXPAND);
 		} else {
-			code = null;
-			dotCommandArgs = Collections.emptyList();
+			formula = null;
 		}
 		
 		final Trace trace = this.animationSelector.getCurrentTrace();
@@ -108,13 +104,14 @@ public final class DotCommand implements Command {
 			throw new UserErrorException("No such dot command: " + command, e);
 		}
 		
-		// Provide source code (if any) to error highlighter
-		final Supplier<String> execute = () -> dotCommand.visualizeAsSvgToString(dotCommandArgs);
+		final List<IEvalElement> dotCommandArgs;
 		final String svg;
-		if (code != null) {
-			svg = CommandUtils.withSourceCode(code, execute);
+		if (formula != null) {
+			dotCommandArgs = Collections.singletonList(formula);
+			svg = CommandUtils.withSourceCode(formula, () -> dotCommand.visualizeAsSvgToString(dotCommandArgs));
 		} else {
-			svg = execute.get();
+			dotCommandArgs = Collections.emptyList();
+			svg = dotCommand.visualizeAsSvgToString(dotCommandArgs);
 		}
 		final DisplayData result = new DisplayData(String.format("<Dot visualization: %s %s>", command, dotCommandArgs));
 		result.putSVG(svg);
diff --git a/src/main/java/de/prob2/jupyter/commands/EvalCommand.java b/src/main/java/de/prob2/jupyter/commands/EvalCommand.java
index add281a..bc6271f 100644
--- a/src/main/java/de/prob2/jupyter/commands/EvalCommand.java
+++ b/src/main/java/de/prob2/jupyter/commands/EvalCommand.java
@@ -64,9 +64,8 @@ public final class EvalCommand implements Command {
 	@Override
 	public @NotNull DisplayData run(final @NotNull ParsedArguments args) {
 		final ProBKernel kernel = this.injector.getInstance(ProBKernel.class);
-		final String code = kernel.insertLetVariables(args.get(FORMULA_PARAM));
-		final IEvalElement formula = kernel.parseFormula(code, FormulaExpand.EXPAND);
-		return CommandUtils.displayDataForEvalResult(CommandUtils.withSourceCode(code, () -> this.animationSelector.getCurrentTrace().evalCurrent(formula)));
+		final IEvalElement formula = kernel.parseFormula(args.get(FORMULA_PARAM), FormulaExpand.EXPAND);
+		return CommandUtils.displayDataForEvalResult(CommandUtils.withSourceCode(formula, () -> this.animationSelector.getCurrentTrace().evalCurrent(formula)));
 	}
 	
 	@Override
diff --git a/src/main/java/de/prob2/jupyter/commands/FindCommand.java b/src/main/java/de/prob2/jupyter/commands/FindCommand.java
index df91c04..92ba83e 100644
--- a/src/main/java/de/prob2/jupyter/commands/FindCommand.java
+++ b/src/main/java/de/prob2/jupyter/commands/FindCommand.java
@@ -66,11 +66,8 @@ public final class FindCommand implements Command {
 	public @NotNull DisplayData run(final @NotNull ParsedArguments args) {
 		final ProBKernel kernel = this.kernelProvider.get();
 		final Trace trace = this.animationSelector.getCurrentTrace();
-		final String code = kernel.insertLetVariables(args.get(PREDICATE_PARAM));
-		final Trace newTrace = CommandUtils.withSourceCode(code, () -> {
-			final IEvalElement pred = kernel.parseFormula(code, FormulaExpand.EXPAND);
-			return trace.getStateSpace().getTraceToState(pred);
-		});
+		final IEvalElement pred = kernel.parseFormula(args.get(PREDICATE_PARAM), FormulaExpand.EXPAND);
+		final Trace newTrace = CommandUtils.withSourceCode(pred, () -> trace.getStateSpace().getTraceToState(pred));
 		this.animationSelector.changeCurrentAnimation(newTrace);
 		return new DisplayData("Found a matching state and made it current state");
 	}
diff --git a/src/main/java/de/prob2/jupyter/commands/LetCommand.java b/src/main/java/de/prob2/jupyter/commands/LetCommand.java
index 7e4a10e..814fb52 100644
--- a/src/main/java/de/prob2/jupyter/commands/LetCommand.java
+++ b/src/main/java/de/prob2/jupyter/commands/LetCommand.java
@@ -69,9 +69,8 @@ public final class LetCommand implements Command {
 	public @NotNull DisplayData run(final @NotNull ParsedArguments args) {
 		final String name = args.get(NAME_PARAM);
 		final ProBKernel kernel = this.kernelProvider.get();
-		final String expr = kernel.insertLetVariables(args.get(EXPRESSION_PARAM));
-		final IEvalElement formula = kernel.parseFormula(expr, FormulaExpand.EXPAND);
-		final AbstractEvalResult evaluated = CommandUtils.withSourceCode(expr, () -> this.animationSelector.getCurrentTrace().evalCurrent(formula));
+		final IEvalElement formula = kernel.parseFormula(args.get(EXPRESSION_PARAM), FormulaExpand.EXPAND);
+		final AbstractEvalResult evaluated = CommandUtils.withSourceCode(formula, () -> this.animationSelector.getCurrentTrace().evalCurrent(formula));
 		if (evaluated instanceof EvalResult) {
 			kernel.getVariables().put(name, evaluated.toString());
 		}
diff --git a/src/main/java/de/prob2/jupyter/commands/PrettyPrintCommand.java b/src/main/java/de/prob2/jupyter/commands/PrettyPrintCommand.java
index e8f0fa7..789559f 100644
--- a/src/main/java/de/prob2/jupyter/commands/PrettyPrintCommand.java
+++ b/src/main/java/de/prob2/jupyter/commands/PrettyPrintCommand.java
@@ -64,14 +64,13 @@ public final class PrettyPrintCommand implements Command {
 	
 	@Override
 	public @NotNull DisplayData run(final @NotNull ParsedArguments args) {
-		final String code = args.get(PREDICATE_PARAM);
-		final IEvalElement formula = this.kernelProvider.get().parseFormula(code, FormulaExpand.EXPAND);
+		final IEvalElement formula = this.kernelProvider.get().parseFormula(args.get(PREDICATE_PARAM), FormulaExpand.EXPAND);
 		
 		final PrettyPrintFormulaCommand cmdUnicode = new PrettyPrintFormulaCommand(formula, PrettyPrintFormulaCommand.Mode.UNICODE);
 		cmdUnicode.setOptimize(false);
 		final PrettyPrintFormulaCommand cmdLatex = new PrettyPrintFormulaCommand(formula, PrettyPrintFormulaCommand.Mode.LATEX);
 		cmdLatex.setOptimize(false);
-		CommandUtils.withSourceCode(code, () -> this.animationSelector.getCurrentTrace().getStateSpace().execute(cmdUnicode, cmdLatex));
+		CommandUtils.withSourceCode(formula, () -> this.animationSelector.getCurrentTrace().getStateSpace().execute(cmdUnicode, cmdLatex));
 		
 		final DisplayData ret = new DisplayData(cmdUnicode.getPrettyPrint());
 		ret.putLatex('$' + cmdLatex.getPrettyPrint() + '$');
diff --git a/src/main/java/de/prob2/jupyter/commands/SolveCommand.java b/src/main/java/de/prob2/jupyter/commands/SolveCommand.java
index 028758e..c92210b 100644
--- a/src/main/java/de/prob2/jupyter/commands/SolveCommand.java
+++ b/src/main/java/de/prob2/jupyter/commands/SolveCommand.java
@@ -86,8 +86,7 @@ public final class SolveCommand implements Command {
 		if (solver == null) {
 			throw new UserErrorException("Unknown solver: " + args.get(SOLVER_PARAM));
 		}
-		final String code = kernel.insertLetVariables(args.get(PREDICATE_PARAM));
-		final IEvalElement predicate = CommandUtils.withSourceCode(code, () -> kernel.parseFormula(code, FormulaExpand.EXPAND));
+		final IEvalElement predicate = kernel.parseFormula(args.get(PREDICATE_PARAM), FormulaExpand.EXPAND);
 		
 		final CbcSolveCommand cmd = new CbcSolveCommand(predicate, solver, this.animationSelector.getCurrentTrace().getCurrentState());
 		trace.getStateSpace().execute(cmd);
diff --git a/src/main/java/de/prob2/jupyter/commands/TableCommand.java b/src/main/java/de/prob2/jupyter/commands/TableCommand.java
index cc54b69..d302afe 100644
--- a/src/main/java/de/prob2/jupyter/commands/TableCommand.java
+++ b/src/main/java/de/prob2/jupyter/commands/TableCommand.java
@@ -70,8 +70,7 @@ public final class TableCommand implements Command {
 	public @NotNull DisplayData run(final @NotNull ParsedArguments args) {
 		final ProBKernel kernel = this.kernelProvider.get();
 		final Trace trace = this.animationSelector.getCurrentTrace();
-		final String code = kernel.insertLetVariables(args.get(EXPRESSION_PARAM));
-		final IEvalElement formula = CommandUtils.withSourceCode(code, () -> kernel.parseFormula(code, FormulaExpand.EXPAND));
+		final IEvalElement formula = kernel.parseFormula(args.get(EXPRESSION_PARAM), FormulaExpand.EXPAND);
 		
 		final TableData table = TableVisualizationCommand.getByName(TableVisualizationCommand.EXPRESSION_AS_TABLE_NAME, trace.getCurrentState())
 			.visualize(Collections.singletonList(formula));
diff --git a/src/main/java/de/prob2/jupyter/commands/TypeCommand.java b/src/main/java/de/prob2/jupyter/commands/TypeCommand.java
index dc5da6d..de56ed6 100644
--- a/src/main/java/de/prob2/jupyter/commands/TypeCommand.java
+++ b/src/main/java/de/prob2/jupyter/commands/TypeCommand.java
@@ -67,8 +67,7 @@ public final class TypeCommand implements Command {
 	public @NotNull DisplayData run(final @NotNull ParsedArguments args) {
 		final ProBKernel kernel = this.kernelProvider.get();
 		final Trace trace = this.animationSelector.getCurrentTrace();
-		final String code = kernel.insertLetVariables(args.get(FORMULA_PARAM));
-		final IEvalElement formula = CommandUtils.withSourceCode(code, () -> kernel.parseFormula(code, FormulaExpand.EXPAND));
+		final IEvalElement formula = kernel.parseFormula(args.get(FORMULA_PARAM), FormulaExpand.EXPAND);
 		final TypeCheckResult result = trace.getStateSpace().typeCheck(formula);
 		if (result.isOk()) {
 			return new DisplayData(result.getType());
-- 
GitLab