From e4980cb59f29c966859221cb4c0e4c627b190ded Mon Sep 17 00:00:00 2001 From: Miles Vella <673-vella@users.noreply.gitlab.cs.uni-duesseldorf.de> Date: Wed, 25 Jun 2025 21:31:33 +0200 Subject: [PATCH] Fix a lot of typechecking bugs with user definitions and ints vs reals --- .../java/de/tla2b/analysis/TypeChecker.java | 336 ++++++++++-------- .../java/de/tla2b/output/TlaTypePrinter.java | 5 + .../de/tla2b/output/TypeVisitorInterface.java | 2 + .../de/tla2b/types/AbstractHasFollowers.java | 46 ++- .../java/de/tla2b/types/FunctionType.java | 10 +- .../java/de/tla2b/types/IDefaultableType.java | 7 + src/main/java/de/tla2b/types/IntType.java | 6 +- .../de/tla2b/types/IntegerOrRealType.java | 65 ++++ src/main/java/de/tla2b/types/PairType.java | 15 +- src/main/java/de/tla2b/types/RealType.java | 6 +- src/main/java/de/tla2b/types/SetType.java | 13 +- .../de/tla2b/types/StructOrFunctionType.java | 12 +- src/main/java/de/tla2b/types/StructType.java | 9 +- src/main/java/de/tla2b/types/TLAType.java | 1 + .../java/de/tla2b/types/TupleOrFunction.java | 5 +- src/main/java/de/tla2b/types/TupleType.java | 6 +- src/main/java/de/tla2b/types/UntypedType.java | 8 +- .../tla2b/typechecking/DefinitionsTest.java | 101 ++++-- .../java/de/tla2b/typechecking/TupleTest.java | 5 +- 19 files changed, 418 insertions(+), 240 deletions(-) create mode 100644 src/main/java/de/tla2b/types/IDefaultableType.java create mode 100644 src/main/java/de/tla2b/types/IntegerOrRealType.java diff --git a/src/main/java/de/tla2b/analysis/TypeChecker.java b/src/main/java/de/tla2b/analysis/TypeChecker.java index 7a23248..e3145ba 100644 --- a/src/main/java/de/tla2b/analysis/TypeChecker.java +++ b/src/main/java/de/tla2b/analysis/TypeChecker.java @@ -1,18 +1,56 @@ package de.tla2b.analysis; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + import de.tla2b.config.ConfigfileEvaluator; import de.tla2b.config.TLCValueNode; -import de.tla2b.exceptions.*; +import de.tla2b.exceptions.NotImplementedException; +import de.tla2b.exceptions.TLA2BException; +import de.tla2b.exceptions.TLA2BFrontEndException; +import de.tla2b.exceptions.TypeErrorException; +import de.tla2b.exceptions.UnificationException; import de.tla2b.global.BBuildIns; import de.tla2b.global.BBuiltInOPs; import de.tla2b.global.TranslationGlobals; -import de.tla2b.types.*; +import de.tla2b.types.AbstractHasFollowers; +import de.tla2b.types.BoolType; +import de.tla2b.types.FunctionType; +import de.tla2b.types.IDefaultableType; +import de.tla2b.types.IntType; +import de.tla2b.types.IntegerOrRealType; +import de.tla2b.types.RealType; +import de.tla2b.types.SetType; +import de.tla2b.types.StringType; +import de.tla2b.types.StructOrFunctionType; +import de.tla2b.types.StructType; +import de.tla2b.types.TLAType; +import de.tla2b.types.TupleOrFunction; +import de.tla2b.types.TupleType; +import de.tla2b.types.UntypedType; import de.tla2b.util.DebugUtils; -import tla2sany.semantic.*; -import tlc2.tool.BuiltInOPs; -import java.util.*; -import java.util.Map.Entry; +import tla2sany.semantic.AssumeNode; +import tla2sany.semantic.AtNode; +import tla2sany.semantic.ExprNode; +import tla2sany.semantic.ExprOrOpArgNode; +import tla2sany.semantic.FormalParamNode; +import tla2sany.semantic.LetInNode; +import tla2sany.semantic.ModuleNode; +import tla2sany.semantic.NumeralNode; +import tla2sany.semantic.OpApplNode; +import tla2sany.semantic.OpDeclNode; +import tla2sany.semantic.OpDefNode; +import tla2sany.semantic.SemanticNode; +import tla2sany.semantic.StringNode; +import tla2sany.semantic.SymbolNode; +import tlc2.tool.BuiltInOPs; public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlobals { @@ -25,7 +63,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo private final Set<OpDefNode> bDefinitions; private final List<SymbolNode> symbolNodeList = new ArrayList<>(); - private final List<SemanticNode> tupleNodeList = new ArrayList<>(); + private final List<IDefaultableType> possibleUnfinishedTypes = new ArrayList<>(); private final ModuleNode moduleNode; private List<OpDeclNode> bConstList; @@ -67,18 +105,18 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo public void start() throws TLA2BException { for (OpDeclNode con : moduleNode.getConstantDecls()) { if (constantAssignments != null && constantAssignments.containsKey(con)) { - setTypeAndFollowers(con, constantAssignments.get(con).getType()); + setLocalTypeAndFollowers(con, constantAssignments.get(con).getType()); } else { // if constant already has a type: keep type; otherwise add an untyped type - if (getType(con) == null) - setTypeAndFollowers(con, new UntypedType()); + if (getLocalType(con) == null) + setLocalTypeAndFollowers(con, new UntypedType()); } } for (OpDeclNode var : moduleNode.getVariableDecls()) { // if variable already has a type: keep type; otherwise add an untyped type - if (getType(var) == null) - setTypeAndFollowers(var, new UntypedType()); + if (getLocalType(var) == null) + setLocalTypeAndFollowers(var, new UntypedType()); } evalDefinitions(moduleNode.getOpDefs()); @@ -89,10 +127,10 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo if (!bConstList.contains(con)) continue; - TLAType defType = getType(entry.getValue()); - TLAType conType = getType(con); + TLAType defType = getLocalType(entry.getValue()); + TLAType conType = getLocalType(con); try { - setType(con, defType.unify(conType)); + setLocalType(con, defType.unify(conType)); } catch (UnificationException e) { throw new TypeErrorException( String.format("Expected %s, found %s at constant '%s'.", defType, conType, con.getName())); @@ -116,10 +154,14 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo } private void checkIfAllIdentifiersHaveAType() throws TypeErrorException { + for (IDefaultableType type : possibleUnfinishedTypes) { + type.setToDefault(); + } + // check if a variable has no type for (OpDeclNode var : moduleNode.getVariableDecls()) { - TLAType varType = getType(var); - if (varType.isUntyped()) { + TLAType varType = getLocalType(var); + if (varType == null || varType.isUntyped()) { throw new TypeErrorException( "The type of the variable '" + var.getName() + "' can not be inferred: " + varType); } @@ -129,28 +171,20 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo // the resulting B Machine are considered for (OpDeclNode con : moduleNode.getConstantDecls()) { if (bConstList == null || bConstList.contains(con)) { - TLAType conType = getType(con); - if (conType.isUntyped()) { + TLAType conType = getLocalType(con); + if (conType == null || conType.isUntyped()) { throw new TypeErrorException( "The type of constant " + con.getName() + " is still untyped: " + conType); } } } - for (SymbolNode symbol : symbolNodeList) { - TLAType type = getType(symbol); - if (type.isUntyped()) { + /* TODO: for (SymbolNode symbol : symbolNodeList) { + TLAType type = getLocalType(symbol); + if (type == null || type.isUntyped()) { throw new TypeErrorException("Symbol '" + symbol.getName() + "' has no type.\n" + symbol.getLocation()); } - } - - for (SemanticNode tuple : tupleNodeList) { - TLAType type = getType(tuple); - if (type instanceof TupleOrFunction) { - ((TupleOrFunction) type).getFinalType(); - } - // TODO: does this do anything? - } + }*/ } private void evalDefinitions(OpDefNode[] opDefs) throws TLA2BException { @@ -167,16 +201,13 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo public void visitOpDefNode(OpDefNode def) throws TLA2BException { for (FormalParamNode p : def.getParams()) { if (p.getArity() > 0) { - throw new TLA2BFrontEndException(String.format("TLA2B do not support 2nd-order operators: '%s'%n %s ", + throw new TLA2BFrontEndException(String.format("TLA2B does not support 2nd-order operators: '%s'%n %s ", def.getName(), def.getLocation())); } - setTypeAndFollowers(p, new UntypedType(), paramId); + setLocalTypeAndFollowers(p, new UntypedType()); } - UntypedType u = new UntypedType(); - // TODO: check this - // def.setToolObject(TYPE_ID, u); - // u.addFollower(def); - setTypeAndFollowers(def, visitExprNode(def.getBody(), u)); + TLAType found = visitExprNode(def.getBody(), new UntypedType()); + setLocalTypeAndFollowers(def, found); } private void evalAssumptions(AssumeNode[] assumptions) throws TLA2BException { @@ -198,7 +229,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo case TLCValueKind: { TLCValueNode valueNode = (TLCValueNode) exprNode; return unify(valueNode.getType(), expected, valueNode.getValue().toString() - + " (assigned in the configuration file)", exprNode); + + " (assigned in the configuration file)", exprNode); } case OpApplKind: return visitOpApplNode((OpApplNode) exprNode, expected); @@ -209,8 +240,8 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo case StringKind: return unify(StringType.getInstance(), expected, ((StringNode) exprNode).getRep().toString(), exprNode); case AtNodeKind: { // @ - TLAType type = getType((((AtNode) exprNode).getExceptComponentRef()).getArgs()[1]); // right side - return unifyAndSetTypeWithFollowers(type, expected, "@", exprNode); + TLAType type = getLocalType((((AtNode) exprNode).getExceptComponentRef()).getArgs()[1]); // right side + return unifyAndSetLocalTypeWithFollowers(type, expected, "@", exprNode); } case LetInKind: { LetInNode l = (LetInNode) exprNode; @@ -229,28 +260,29 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo private TLAType visitOpApplNode(OpApplNode n, TLAType expected) throws TLA2BException { switch (n.getOperator().getKind()) { case ConstantDeclKind: { - OpDeclNode con = (OpDeclNode) n.getOperator(); - TLAType c = getType(con); + SymbolNode symbolNode = n.getOperator(); + String vName = symbolNode.getName().toString(); + TLAType c = getLocalType(symbolNode); if (c == null) { - throw new TypeErrorException(con.getName() + " has no type yet!"); + throw new TypeErrorException(vName + " has no type yet!"); } - return unifyAndSetType(c, expected, con.getName().toString(), con); + return unifyAndSetLocalTypeWithFollowers(c, expected, vName, n); } case VariableDeclKind: { SymbolNode symbolNode = n.getOperator(); String vName = symbolNode.getName().toString(); - TLAType v = getType(symbolNode); + TLAType v = getLocalType(symbolNode); if (v == null) { SymbolNode var = this.specAnalyser.getSymbolNodeByName(vName); if (var != null) { // symbolNode is variable of an expression, e.g. v + 1 - v = getType(var); + v = getLocalType(var); } else { throw new TypeErrorException(vName + " has no type yet!"); } } - return unifyAndSetType(v, expected, vName, n); + return unifyAndSetLocalTypeWithFollowers(v, expected, vName, n); } case BuiltInKind: @@ -258,22 +290,13 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo case FormalParamKind: { SymbolNode symbolNode = n.getOperator(); - TLAType t = getType(symbolNode, paramId); - if (t == null) { // no temp type - t = getType(symbolNode); - if (t == null) { // no type at all - t = new UntypedType(); // TODO is this correct? - // throw new RuntimeException(); - } - } - try { - TLAType result = expected.unify(t); - setType(symbolNode, result, paramId); - return result; - } catch (UnificationException e) { - throw new TypeErrorException(String.format("Expected %s, found %s at parameter '%s',%n%s", expected, t, - symbolNode.getName(), n.getLocation())); + String vName = symbolNode.getName().toString(); + TLAType t = getLocalType(symbolNode); + if (t == null) { + t = new UntypedType(); // TODO is this correct? + // throw new RuntimeException(); } + return unifyAndSetLocalType(t, expected, vName, n); } case UserDefinedOpKind: { @@ -287,22 +310,33 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo } // the definition might be generic, so we have to re-evaluate - // the definition body with the concrete types we have here as args + // its body with the concrete types we have here as args // set param types assert params.length == args.length; for (int i = 0; i < args.length; i++) { TLAType argType = visitExprOrOpArgNode(args[i], new UntypedType()); - setType(params[i], argType.cloneTLAType(), TEMP_TYPE_ID); + + int prevParamId = paramId; + paramId = TEMP_TYPE_ID; + try { + setLocalType(params[i], argType); + } finally { + paramId = prevParamId; + } } // re-evaluate definition body + int prevParamId = paramId; paramId = TEMP_TYPE_ID; - TLAType found = visitExprNode(def.getBody(), expected); - paramId = TYPE_ID; + TLAType found; + try { + found = visitExprNode(def.getBody(), expected); + } finally { + paramId = prevParamId; + } - setType(n, found); - return found; + return unify(found, expected, n); } default: @@ -447,27 +481,24 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo found = new FunctionType(IntType.getInstance(), list.get(0)); } else { found = TupleOrFunction.createTupleOrFunctionType(list); - // found = new TupleType(list); } - tupleNodeList.add(n); - return unifyAndSetTypeWithFollowers(found, expected, "tuple", n); + return unifyAndSetLocalTypeWithFollowers(found, expected, "tuple", n); } /* * Function constructors */ - case OPCODE_rfs: { // recursive function ( f[x\in Nat] == IF x = 0 THEN 1 - // ELSE f[n-1] + case OPCODE_rfs: { // recursive function ( f[x\in Nat] == IF x = 0 THEN 1 ELSE f[n-1] FormalParamNode recFunc = n.getUnbdedQuantSymbols()[0]; symbolNodeList.add(recFunc); - setTypeAndFollowers(recFunc, new FunctionType()); + setLocalTypeAndFollowers(recFunc, new FunctionType()); TLAType domainType = evalBoundedVariables(n); FunctionType found = new FunctionType(domainType, new UntypedType()); visitExprOrOpArgNode(n.getArgs()[0], found.getRange()); found = (FunctionType) unify(found, expected, n); - return unify(found, getType(recFunc), n); + return unify(found, getLocalType(recFunc), n); } case OPCODE_nrfs: // succ[n \in Nat] == n + 1 @@ -490,17 +521,16 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo domList.add(visitExprOrOpArgNode(arg, new UntypedType())); } domType = domList.size() == 1 - ? new FunctionType(IntType.getInstance(), domList.get(0)) // one-tuple - : new TupleType(domList); + ? new FunctionType(IntType.getInstance(), domList.get(0)) // one-tuple + : new TupleType(domList); } else if (dom instanceof NumeralNode) { NumeralNode num = (NumeralNode) dom; UntypedType u = new UntypedType(); - setTypeAndFollowers(n, u); + setLocalTypeAndFollowers(n, u); TLAType res = visitExprOrOpArgNode(n.getArgs()[0], new TupleOrFunction(num.val(), u)); - setTypeAndFollowers(n.getArgs()[0], res); - tupleNodeList.add(n.getArgs()[0]); - return unify(getType(n), expected, n.getArgs()[0].toString(), n); + setLocalTypeAndFollowers(n.getArgs()[0], res); + return unify(getLocalType(n), expected, n.getArgs()[0].toString(), n); } else { domType = visitExprOrOpArgNode(dom, new UntypedType()); } @@ -561,7 +591,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo SetType fieldType = (SetType) visitExprOrOpArgNode(pair.getArgs()[1], new SetType(new UntypedType())); struct.add(field.getRep().toString(), fieldType.getSubType()); } - return unifyAndSetTypeWithFollowers(new SetType(struct), expected, "set of records", n); + return unifyAndSetLocalTypeWithFollowers(new SetType(struct), expected, "set of records", n); } case OPCODE_rc: { // [h_1 |-> 1, h_2 |-> 2] @@ -572,7 +602,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo TLAType fieldType = visitExprOrOpArgNode(pair.getArgs()[1], new UntypedType()); found.add(field.getRep().toString(), fieldType); } - return unifyAndSetTypeWithFollowers(found, expected, "record constructor", n); + return unifyAndSetLocalTypeWithFollowers(found, expected, "record constructor", n); } case OPCODE_rs: { // $RcdSelect r.c @@ -588,7 +618,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo throw new TypeErrorException(String.format("Struct has no field %s with type %s: %s%n%s", fieldName, r.getType(fieldName), r, n.getLocation())); } - setTypeAndFollowers(n.getArgs()[0], r); + setLocalTypeAndFollowers(n.getArgs()[0], r); return r.getType(fieldName); } @@ -599,7 +629,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo visitExprOrOpArgNode(n.getArgs()[0], BoolType.getInstance()); TLAType then = visitExprOrOpArgNode(n.getArgs()[1], expected); TLAType eelse = visitExprOrOpArgNode(n.getArgs()[2], then); - setTypeAndFollowers(n, eelse); + setLocalTypeAndFollowers(n, eelse); return eelse; } @@ -623,7 +653,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo for (FormalParamNode param : n.getUnbdedQuantSymbols()) { TLAType paramType = new UntypedType(); symbolNodeList.add(param); - setTypeAndFollowers(param, paramType); + setLocalTypeAndFollowers(param, paramType); list.add(paramType); } TLAType found; @@ -632,7 +662,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo } else { found = new TupleType(list); } - found = unifyAndSetTypeWithFollowers(found, expected, n.getOperator().getName().toString(), n); + found = unifyAndSetLocalTypeWithFollowers(found, expected, n.getOperator().getName().toString(), n); visitExprOrOpArgNode(n.getArgs()[0], BoolType.getInstance()); return found; } @@ -655,7 +685,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo default: throw new NotImplementedException( - "Not supported Operator: " + n.getOperator().getName() + "\n" + n.getLocation()); + "Not supported Operator: " + n.getOperator().getName() + "\n" + n.getLocation()); } } @@ -674,7 +704,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo expected = (FunctionType) unify(expected, subType, "parameter " + p.getName(), bounds[i]); domList.add(expected); symbolNodeList.add(p); - setTypeAndFollowers(p, expected.getRange()); + setLocalTypeAndFollowers(p, expected.getRange()); } else { TupleType tuple = new TupleType(params[i].length); tuple = (TupleType) unify(tuple, subType, "tuple", bounds[i]); @@ -682,7 +712,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo for (int j = 0; j < params[i].length; j++) { FormalParamNode p = params[i][j]; symbolNodeList.add(p); - setTypeAndFollowers(p, tuple.getTypes().get(j)); + setLocalTypeAndFollowers(p, tuple.getTypes().get(j)); } } } else { // is not a tuple: all parameter have the same type @@ -690,7 +720,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo domList.add(subType); FormalParamNode p = params[i][j]; symbolNodeList.add(p); - setTypeAndFollowers(p, subType); + setLocalTypeAndFollowers(p, subType); } } } @@ -699,13 +729,13 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo private TLAType evalExcept(OpApplNode n, TLAType expected) throws TLA2BException { TLAType t = visitExprOrOpArgNode(n.getArgs()[0], expected); - setTypeAndFollowers(n, t); + setLocalTypeAndFollowers(n, t); for (int i = 1; i < n.getArgs().length; i++) { // start at 1 OpApplNode pair = (OpApplNode) n.getArgs()[i]; // stored for @ node UntypedType untyped = new UntypedType(); - setTypeAndFollowers(pair.getArgs()[1], untyped); + setLocalTypeAndFollowers(pair.getArgs()[1], untyped); TLAType valueType = visitExprOrOpArgNode(pair.getArgs()[1], untyped); // right side OpApplNode seq = (OpApplNode) pair.getArgs()[0]; // left side @@ -726,7 +756,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo domList.add(visitExprOrOpArgNode(arg, new UntypedType())); } domType = new TupleType(domList); - setType(domExpr, domType); // store type + setLocalType(domExpr, domType); // store type } else { domType = visitExprOrOpArgNode(domExpr, new UntypedType()); } @@ -764,14 +794,10 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo case B_OPCODE_leq: // <= case B_OPCODE_geq: { // >= TLAType boolType = unify(BoolType.getInstance(), expected, n); - try { - for (ExprOrOpArgNode arg : n.getArgs()) { - visitExprOrOpArgNode(arg, IntType.getInstance()); - } - } catch (TypeErrorException e) { - for (ExprOrOpArgNode arg : n.getArgs()) { - visitExprOrOpArgNode(arg, RealType.getInstance()); - } + TLAType numberType = new IntegerOrRealType(); + for (ExprOrOpArgNode arg : n.getArgs()) { + numberType = visitExprOrOpArgNode(arg, numberType); + setLocalTypeAndFollowers(arg, numberType); } return boolType; } @@ -780,24 +806,15 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo // for UntypedTypes the default is integer; if this leads to a TypeErrorException real is tried instead case B_OPCODE_plus: // + case B_OPCODE_minus: // - + case B_OPCODE_uminus: // -x case B_OPCODE_times: // * case B_OPCODE_div: // / case B_OPCODE_mod: // % modulo - case B_OPCODE_exp: { // x hoch y, x^y - TLAType type; - try { - IntType.getInstance().unify(expected); // throws UnificationException - type = IntType.getInstance(); - for (ExprOrOpArgNode arg : n.getArgs()) { - // throws TypeErrorException; check whether IntType is OK, else try the same with RealType - visitExprOrOpArgNode(arg, type); - } - } catch (UnificationException | TypeErrorException e) { - type = unify(RealType.getInstance(), expected, n); - for (ExprOrOpArgNode arg : n.getArgs()) { - // if TypeErrorException is thrown here, the type is incompatible and it is a real type error! - visitExprOrOpArgNode(arg, type); - } + case B_OPCODE_exp: { // x to the power of y, x^y + TLAType type = unify(new IntegerOrRealType(), expected, n); + for (ExprOrOpArgNode arg : n.getArgs()) { + type = visitExprOrOpArgNode(arg, type); + setLocalTypeAndFollowers(arg, type); } return type; } @@ -830,21 +847,6 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo case B_OPCODE_real: // Real return unify(new SetType(RealType.getInstance()), expected, n); - case B_OPCODE_uminus: { // -x - TLAType type; - try { - IntType.getInstance().unify(expected); // throws UnificationException - type = IntType.getInstance(); - // throws TypeErrorException; check whether IntType is OK, else try the same with RealType - visitExprOrOpArgNode(n.getArgs()[0], type); - } catch (UnificationException | TypeErrorException e) { - type = unify(RealType.getInstance(), expected, n); - // if TypeErrorException is thrown here, the type is incompatible and it is a real type error! - visitExprOrOpArgNode(n.getArgs()[0], type); - } - return type; - } - /* * Standard Module FiniteSets */ @@ -975,34 +977,74 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo /* * Utility methods */ - private static void setTypeAndFollowers(SemanticNode node, TLAType type, int paramId) { - setType(node, type, paramId); + private static boolean hasGlobalTyping(SemanticNode node) { + SymbolNode symbol = null; + if (node instanceof SymbolNode) { + symbol = (SymbolNode) node; + } else if (node instanceof OpApplNode) { + symbol = ((OpApplNode) node).getOperator(); + } + + return symbol != null && (symbol.getKind() == ConstantDeclKind || symbol.getKind() == VariableDeclKind); + } + + private void setLocalTypeAndFollowers(SemanticNode node, TLAType type) { + setLocalType(node, type); if (type instanceof AbstractHasFollowers) { ((AbstractHasFollowers) type).addFollower(node); } } - private static void setTypeAndFollowers(SemanticNode node, TLAType type) { - setTypeAndFollowers(node, type, TYPE_ID); + public static void updateTypeAndFollowers(SemanticNode node, TLAType oldType, TLAType newType) { + if (getType(node, TYPE_ID) == oldType) { + setType(node, newType, TYPE_ID); + if (newType instanceof AbstractHasFollowers) { + ((AbstractHasFollowers) newType).addFollower(node); + } + } + if (getType(node, TEMP_TYPE_ID) == oldType) { + setType(node, newType, TEMP_TYPE_ID); + if (newType instanceof AbstractHasFollowers) { + ((AbstractHasFollowers) newType).addFollower(node); + } + } } private static void setType(SemanticNode node, TLAType type, int paramId) { node.setToolObject(paramId, type); } + private void setLocalType(SemanticNode node, TLAType type) { + if (type instanceof IDefaultableType) { + this.possibleUnfinishedTypes.add((IDefaultableType) type); + } + if (paramId != TYPE_ID && hasGlobalTyping(node)) { + setType(node, type, TYPE_ID); + } else { + setType(node, type, paramId); + } + } + public static void setType(SemanticNode node, TLAType type) { setType(node, type, TYPE_ID); } - private static TLAType getType(SemanticNode n, int paramId) { - return (TLAType) n.getToolObject(paramId); + private static TLAType getType(SemanticNode node, int paramId) { + return (TLAType) node.getToolObject(paramId); + } + + private TLAType getLocalType(SemanticNode node) { + if (paramId != TYPE_ID && hasGlobalTyping(node)) { + return getType(node, TYPE_ID); + } + return getType(node, paramId); } - public static TLAType getType(SemanticNode n) { - return getType(n, TYPE_ID); + public static TLAType getType(SemanticNode node) { + return getType(node, TYPE_ID); } - private TLAType unify(TLAType toUnify, TLAType expected, String opMsg, SemanticNode n) throws TypeErrorException { + private static TLAType unify(TLAType toUnify, TLAType expected, String opMsg, SemanticNode n) throws TypeErrorException { TLAType found = toUnify; DebugUtils.printDebugMsg("Unify " + found + " and " + expected + " at '" + opMsg + "' (" + n.getLocation() + ")"); try { @@ -1014,19 +1056,19 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo return found; } - private TLAType unify(TLAType toUnify, TLAType expected, OpApplNode n) throws TypeErrorException { + private static TLAType unify(TLAType toUnify, TLAType expected, OpApplNode n) throws TypeErrorException { return unify(toUnify, expected, n.getOperator().getName().toString(), n); } - private TLAType unifyAndSetTypeWithFollowers(TLAType toUnify, TLAType expected, String opMsg, SemanticNode n) throws TypeErrorException { + private TLAType unifyAndSetLocalTypeWithFollowers(TLAType toUnify, TLAType expected, String opMsg, SemanticNode n) throws TypeErrorException { TLAType found = unify(toUnify, expected, opMsg, n); - setTypeAndFollowers(n, found); + setLocalTypeAndFollowers(n, found); return found; } - private TLAType unifyAndSetType(TLAType toUnify, TLAType expected, String opMsg, SemanticNode n) throws TypeErrorException { + private TLAType unifyAndSetLocalType(TLAType toUnify, TLAType expected, String opMsg, SemanticNode n) throws TypeErrorException { TLAType found = unify(toUnify, expected, opMsg, n); - setType(n, found); + setLocalType(n, found); return found; } } diff --git a/src/main/java/de/tla2b/output/TlaTypePrinter.java b/src/main/java/de/tla2b/output/TlaTypePrinter.java index 1092ed7..24ec63b 100644 --- a/src/main/java/de/tla2b/output/TlaTypePrinter.java +++ b/src/main/java/de/tla2b/output/TlaTypePrinter.java @@ -78,6 +78,11 @@ public class TlaTypePrinter extends ClassicalPositionPrinter implements TypeVisi pout.printAtom("real"); } + @Override + public void caseIntegerOrRealType(IntegerOrRealType type) { + throw new NotImplementedException("should not happen"); + } + public void caseSetType(SetType type) { pout.openTerm("set"); type.getSubType().apply(this); diff --git a/src/main/java/de/tla2b/output/TypeVisitorInterface.java b/src/main/java/de/tla2b/output/TypeVisitorInterface.java index 9c42f4b..19bcf93 100644 --- a/src/main/java/de/tla2b/output/TypeVisitorInterface.java +++ b/src/main/java/de/tla2b/output/TypeVisitorInterface.java @@ -18,6 +18,8 @@ public interface TypeVisitorInterface { void caseRealType(RealType type); + void caseIntegerOrRealType(IntegerOrRealType type); + void caseSetType(SetType type); void caseStringType(StringType type); diff --git a/src/main/java/de/tla2b/types/AbstractHasFollowers.java b/src/main/java/de/tla2b/types/AbstractHasFollowers.java index afbbc88..7ddf2c0 100644 --- a/src/main/java/de/tla2b/types/AbstractHasFollowers.java +++ b/src/main/java/de/tla2b/types/AbstractHasFollowers.java @@ -1,27 +1,28 @@ package de.tla2b.types; -import de.tla2b.analysis.TypeChecker; -import tla2sany.semantic.SemanticNode; - import java.util.ArrayList; import java.util.List; +import de.tla2b.analysis.TypeChecker; + +import tla2sany.semantic.SemanticNode; + public abstract class AbstractHasFollowers extends TLAType { - private List<Object> followers = new ArrayList<>(); + private List<Object> followers = null; public AbstractHasFollowers(int t) { super(t); } - public List<Object> getFollowers() { - return followers; - } - public void addFollower(Object o) { + if (followers == null) { + followers = new ArrayList<>(); + } // only (partial) untyped types need follower - if (followers != null && !followers.contains(o)) + if (!followers.contains(o)) { followers.add(o); + } } public void deleteFollowers() { @@ -29,31 +30,26 @@ public abstract class AbstractHasFollowers extends TLAType { } public void removeFollower(Object o) { - followers.remove(o); + if (this.hasFollowers()) { + followers.remove(o); + } } protected void setFollowersTo(TLAType newType) { - if (this.followers == null) + if (!this.hasFollowers()) { return; + } // avoid concurrent modification: new ArrayList<>(followers).forEach(follower -> { + //this.removeFollower(follower); if (follower instanceof SemanticNode) { - TypeChecker.setType((SemanticNode) follower, newType); - if (newType instanceof AbstractHasFollowers) { - ((AbstractHasFollowers) newType).addFollower(follower); - } + TypeChecker.updateTypeAndFollowers((SemanticNode) follower, this, newType); } else if (follower instanceof SetType) { - ((SetType) follower).setSubType(newType); + ((SetType) follower).update(this, newType); } else if (follower instanceof TupleType) { ((TupleType) follower).update(this, newType); } else if (follower instanceof PairType) { - PairType pair = ((PairType) follower); - if (pair.getFirst() == this) { - pair.setFirst(newType); - } - if (pair.getSecond() == this) { - pair.setSecond(newType); - } + ((PairType) follower).update(this, newType); } else if (follower instanceof FunctionType) { ((FunctionType) follower).update(this, newType); } else if (follower instanceof StructType) { @@ -68,7 +64,7 @@ public abstract class AbstractHasFollowers extends TLAType { }); } - public boolean hasFollower() { - return !followers.isEmpty(); + public boolean hasFollowers() { + return followers != null && !followers.isEmpty(); } } diff --git a/src/main/java/de/tla2b/types/FunctionType.java b/src/main/java/de/tla2b/types/FunctionType.java index e2f2f7e..7f7b507 100644 --- a/src/main/java/de/tla2b/types/FunctionType.java +++ b/src/main/java/de/tla2b/types/FunctionType.java @@ -26,10 +26,12 @@ public class FunctionType extends AbstractHasFollowers { } public void update(TLAType oldType, TLAType newType) { - if (domain == oldType) - setDomain(newType); - if (range == oldType) - setRange(newType); + if (domain == oldType) { + setDomain(newType); + } + if (range == oldType) { + setRange(newType); + } } @Override diff --git a/src/main/java/de/tla2b/types/IDefaultableType.java b/src/main/java/de/tla2b/types/IDefaultableType.java new file mode 100644 index 0000000..faf6411 --- /dev/null +++ b/src/main/java/de/tla2b/types/IDefaultableType.java @@ -0,0 +1,7 @@ +package de.tla2b.types; + +public interface IDefaultableType { + + TLAType setToDefault(); + +} diff --git a/src/main/java/de/tla2b/types/IntType.java b/src/main/java/de/tla2b/types/IntType.java index bdde67f..6560716 100644 --- a/src/main/java/de/tla2b/types/IntType.java +++ b/src/main/java/de/tla2b/types/IntType.java @@ -29,15 +29,15 @@ public class IntType extends TLAType { @Override public boolean compare(TLAType o) { - return o.getKind() == UNTYPED || o.getKind() == INTEGER; + return o.getKind() == UNTYPED || o.getKind() == INTEGER || o.getKind() == INTEGER_OR_REAL; } @Override public IntType unify(TLAType o) throws UnificationException { if (o.getKind() == INTEGER) { return this; - } else if (o instanceof UntypedType) { - ((UntypedType) o).setFollowersTo(this); + } else if (o.getKind() == INTEGER_OR_REAL || o instanceof UntypedType) { + ((AbstractHasFollowers) o).setFollowersTo(this); return this; } else throw new UnificationException(); diff --git a/src/main/java/de/tla2b/types/IntegerOrRealType.java b/src/main/java/de/tla2b/types/IntegerOrRealType.java new file mode 100644 index 0000000..b2fd72b --- /dev/null +++ b/src/main/java/de/tla2b/types/IntegerOrRealType.java @@ -0,0 +1,65 @@ +package de.tla2b.types; + +import de.be4.classicalb.core.parser.node.PExpression; +import de.tla2b.exceptions.UnificationException; +import de.tla2b.output.TypeVisitorInterface; + +public final class IntegerOrRealType extends AbstractHasFollowers implements IDefaultableType { + + public IntegerOrRealType() { + super(INTEGER_OR_REAL); + } + + @Override + public String toString() { + return "INTEGER_OR_REAL"; + } + + @Override + public boolean isUntyped() { + return true; + } + + @Override + public boolean compare(TLAType o) { + return o.getKind() == UNTYPED || o.getKind() == INTEGER || o.getKind() == REAL || o.getKind() == INTEGER_OR_REAL; + } + + @Override + public TLAType unify(TLAType o) throws UnificationException { + if (o.getKind() == REAL || o.getKind() == INTEGER) { + return o; + } else if (o.getKind() == INTEGER_OR_REAL || o instanceof UntypedType) { + ((AbstractHasFollowers) o).setFollowersTo(this); + return this; + } else { + throw new UnificationException(); + } + } + + @Override + public IntegerOrRealType cloneTLAType() { + return new IntegerOrRealType(); + } + + @Override + public boolean contains(TLAType o) { + return false; + } + + @Override + public TLAType setToDefault() { + TLAType type = IntType.getInstance(); + this.setFollowersTo(type); + return type; + } + + @Override + public PExpression getBNode() { + throw new UnsupportedOperationException("IntegerOrRealType has no corresponding B node."); + } + + public void apply(TypeVisitorInterface visitor) { + visitor.caseIntegerOrRealType(this); + } +} diff --git a/src/main/java/de/tla2b/types/PairType.java b/src/main/java/de/tla2b/types/PairType.java index 450e179..c51bc60 100644 --- a/src/main/java/de/tla2b/types/PairType.java +++ b/src/main/java/de/tla2b/types/PairType.java @@ -20,13 +20,11 @@ public class PairType extends AbstractHasFollowers { super(PAIR); this.first = f; if (first instanceof AbstractHasFollowers) { - AbstractHasFollowers firstHasFollowers = (AbstractHasFollowers) first; - firstHasFollowers.addFollower(this); + ((AbstractHasFollowers) first).addFollower(this); } this.second = s; if (second instanceof AbstractHasFollowers) { - AbstractHasFollowers secondHasFollowers = (AbstractHasFollowers) second; - secondHasFollowers.addFollower(this); + ((AbstractHasFollowers) second).addFollower(this); } } @@ -68,6 +66,15 @@ public class PairType extends AbstractHasFollowers { } } + public void update(TLAType oldType, TLAType newType) { + if (this.first == oldType) { + this.setFirst(newType); + } + if (this.second == oldType) { + this.setSecond(newType); + } + } + @Override public boolean isUntyped() { return first.isUntyped() || second.isUntyped(); diff --git a/src/main/java/de/tla2b/types/RealType.java b/src/main/java/de/tla2b/types/RealType.java index f0becfe..fdb3d91 100644 --- a/src/main/java/de/tla2b/types/RealType.java +++ b/src/main/java/de/tla2b/types/RealType.java @@ -29,15 +29,15 @@ public class RealType extends TLAType { @Override public boolean compare(TLAType o) { - return o.getKind() == UNTYPED || o.getKind() == REAL; + return o.getKind() == UNTYPED || o.getKind() == REAL || o.getKind() == INTEGER_OR_REAL; } @Override public RealType unify(TLAType o) throws UnificationException { if (o.getKind() == REAL) { return this; - } else if (o instanceof UntypedType) { - ((UntypedType) o).setFollowersTo(this); + } else if (o.getKind() == INTEGER_OR_REAL || o instanceof UntypedType) { + ((AbstractHasFollowers) o).setFollowersTo(this); return this; } else throw new UnificationException(); diff --git a/src/main/java/de/tla2b/types/SetType.java b/src/main/java/de/tla2b/types/SetType.java index c1b36c2..06ff98a 100644 --- a/src/main/java/de/tla2b/types/SetType.java +++ b/src/main/java/de/tla2b/types/SetType.java @@ -18,16 +18,11 @@ public class SetType extends AbstractHasFollowers { } public void setSubType(TLAType t) { - // if (subType instanceof AbstractHasFollowers) { - // // delete old reference - // ((AbstractHasFollowers) subType).removeFollower(this); - // } - + this.subType = t; if (t instanceof AbstractHasFollowers) { // set new reference ((AbstractHasFollowers) t).addFollower(this); } - this.subType = t; // setting subType can lead to a completely typed type if (!this.isUntyped()) { @@ -36,6 +31,12 @@ public class SetType extends AbstractHasFollowers { } } + public void update(TLAType oldType, TLAType newType) { + if (this.subType == oldType) { + this.setSubType(newType); + } + } + public SetType unify(TLAType o) throws UnificationException { if (!this.compare(o) || this.contains(o)) { throw new UnificationException(); diff --git a/src/main/java/de/tla2b/types/StructOrFunctionType.java b/src/main/java/de/tla2b/types/StructOrFunctionType.java index be4ec5b..0ed99da 100644 --- a/src/main/java/de/tla2b/types/StructOrFunctionType.java +++ b/src/main/java/de/tla2b/types/StructOrFunctionType.java @@ -7,8 +7,6 @@ import de.tla2b.output.TypeVisitorInterface; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.Map; -import java.util.Map.Entry; -import java.util.Set; public class StructOrFunctionType extends AbstractHasFollowers { private final Map<String, TLAType> types = new LinkedHashMap<>(); @@ -22,13 +20,13 @@ public class StructOrFunctionType extends AbstractHasFollowers { super(STRUCT_OR_FUNCTION); } - public void setNewType(TLAType old, TLAType New) { + public void setNewType(TLAType oldType, TLAType newType) { types.forEach((name, type) -> { - if (type == old) { - if (New instanceof AbstractHasFollowers) { // set new reference - ((AbstractHasFollowers) New).addFollower(this); + if (type == oldType) { + types.put(name, newType); + if (newType instanceof AbstractHasFollowers) { // set new reference + ((AbstractHasFollowers) newType).addFollower(this); } - types.put(name, New); } }); testRecord(); diff --git a/src/main/java/de/tla2b/types/StructType.java b/src/main/java/de/tla2b/types/StructType.java index 98e3455..0c1ff50 100644 --- a/src/main/java/de/tla2b/types/StructType.java +++ b/src/main/java/de/tla2b/types/StructType.java @@ -33,16 +33,17 @@ public class StructType extends AbstractHasFollowers { } public void add(String name, TLAType type) { + types.put(name, type); if (type instanceof AbstractHasFollowers) { // set new reference ((AbstractHasFollowers) type).addFollower(this); } - types.put(name, type); } - public void setNewType(TLAType old, TLAType New) { + public void setNewType(TLAType old, TLAType newType) { types.forEach((name, type) -> { - if (type == old) - add(name, New); + if (type == old) { + add(name, newType); + } }); } diff --git a/src/main/java/de/tla2b/types/TLAType.java b/src/main/java/de/tla2b/types/TLAType.java index 368fc4b..fbfe5f0 100644 --- a/src/main/java/de/tla2b/types/TLAType.java +++ b/src/main/java/de/tla2b/types/TLAType.java @@ -20,6 +20,7 @@ public abstract class TLAType { static int TUPLE = 11; static int TUPLE_OR_FUNCTION = 12; static int REAL = 13; + static int INTEGER_OR_REAL = 14; private final int kind; diff --git a/src/main/java/de/tla2b/types/TupleOrFunction.java b/src/main/java/de/tla2b/types/TupleOrFunction.java index 396f87a..c951408 100644 --- a/src/main/java/de/tla2b/types/TupleOrFunction.java +++ b/src/main/java/de/tla2b/types/TupleOrFunction.java @@ -7,7 +7,7 @@ import de.tla2b.output.TypeVisitorInterface; import java.util.*; import java.util.stream.Collectors; -public class TupleOrFunction extends AbstractHasFollowers { +public class TupleOrFunction extends AbstractHasFollowers implements IDefaultableType { private final Map<Integer, TLAType> types = new LinkedHashMap<>(); public TupleOrFunction(Integer index, TLAType type) { @@ -249,7 +249,8 @@ public class TupleOrFunction extends AbstractHasFollowers { update(); } - public TLAType getFinalType() { + @Override + public TLAType setToDefault() { List<TLAType> list = new ArrayList<>(this.types.values()); if (comparable(list)) { FunctionType func = new FunctionType(IntType.getInstance(), new UntypedType()); diff --git a/src/main/java/de/tla2b/types/TupleType.java b/src/main/java/de/tla2b/types/TupleType.java index 557c1b7..43c09e9 100644 --- a/src/main/java/de/tla2b/types/TupleType.java +++ b/src/main/java/de/tla2b/types/TupleType.java @@ -4,7 +4,6 @@ import de.be4.classicalb.core.parser.node.PExpression; import de.be4.classicalb.core.parser.util.ASTBuilder; import de.tla2b.exceptions.UnificationException; import de.tla2b.output.TypeVisitorInterface; -import de.tla2bAst.BAstCreator; import java.util.ArrayList; import java.util.List; @@ -46,10 +45,11 @@ public class TupleType extends AbstractHasFollowers { TLAType t = types.get(i); if (oldType == t) { types.set(i, newType); + if (newType instanceof AbstractHasFollowers) { + ((AbstractHasFollowers) newType).addFollower(this); + } } } - if (oldType instanceof AbstractHasFollowers) - ((AbstractHasFollowers) oldType).addFollower(this); } @Override diff --git a/src/main/java/de/tla2b/types/UntypedType.java b/src/main/java/de/tla2b/types/UntypedType.java index 54e9c3f..a00944e 100644 --- a/src/main/java/de/tla2b/types/UntypedType.java +++ b/src/main/java/de/tla2b/types/UntypedType.java @@ -4,7 +4,7 @@ import de.be4.classicalb.core.parser.node.PExpression; import de.tla2b.exceptions.UnificationException; import de.tla2b.output.TypeVisitorInterface; -public class UntypedType extends AbstractHasFollowers { +public final class UntypedType extends AbstractHasFollowers { public UntypedType() { super(UNTYPED); @@ -14,6 +14,12 @@ public class UntypedType extends AbstractHasFollowers { if (!this.compare(o)) { throw new UnificationException(); } + + // if the other type is just an empty untyped one we can return this + if (o instanceof UntypedType && !((UntypedType) o).hasFollowers()) { + return this; + } + // u2 contains more or equal type information than untyped (this) this.setFollowersTo(o); //this.deleteFollowers(); diff --git a/src/test/java/de/tla2b/typechecking/DefinitionsTest.java b/src/test/java/de/tla2b/typechecking/DefinitionsTest.java index 8f2dc45..82d9fb2 100644 --- a/src/test/java/de/tla2b/typechecking/DefinitionsTest.java +++ b/src/test/java/de/tla2b/typechecking/DefinitionsTest.java @@ -8,7 +8,6 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; - public class DefinitionsTest { /* @@ -17,8 +16,8 @@ public class DefinitionsTest { @Test public void testDefinition() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" - + "foo(a,b) == a = 1 /\\ b = TRUE \n" - + "Next == foo(1,TRUE) \n" + + "foo(a,b) == a = 1 /\\ b = TRUE \n" + + "Next == foo(1,TRUE) \n" + "================================="; TestTypeChecker t = TestUtil.typeCheckString(module); assertEquals("BOOL", t.getDefinitionType("foo")); @@ -30,7 +29,7 @@ public class DefinitionsTest { public void testDefinition2() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" + "CONSTANTS k, k2 \n" - + "foo(a,b) == a = k /\\ b = k2 \n" + + "foo(a,b) == a = k /\\ b = k2 \n" + "bar == k = 1 /\\ k2 = TRUE \n" + "ASSUME foo(1,FALSE) /\\ bar \n" + "================================="; @@ -46,9 +45,10 @@ public class DefinitionsTest { public void testDefinition3() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" + "CONSTANTS k \n" - + "foo == k \n" + + "foo == k \n" + "bar == foo = 1 \n" - + "ASSUME bar \n" + "================================="; + + "ASSUME bar \n" + + "================================="; TestTypeChecker t = TestUtil.typeCheckString(module); assertEquals("INTEGER", t.getDefinitionType("foo")); assertEquals("BOOL", t.getDefinitionType("bar")); @@ -58,8 +58,8 @@ public class DefinitionsTest { public void testDefinition4() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" + "CONSTANTS k, k2 \n" - + "foo(var, value) == var = value \n" - + "ASSUME foo(k,1) /\\ foo(k2,TRUE) \n" + + "foo(var, value) == var = value \n" + + "ASSUME foo(k,1) /\\ foo(k2,TRUE) \n" + "================================="; TestTypeChecker t = TestUtil.typeCheckString(module); assertEquals("BOOL", t.getDefinitionType("foo")); @@ -74,7 +74,7 @@ public class DefinitionsTest { @Test public void testDefinitionCall() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" - + "foo(a) == TRUE \n" + + "foo(a) == TRUE \n" + "bar == foo(1) \n" + "ASSUME bar \n" + "================================="; @@ -86,7 +86,7 @@ public class DefinitionsTest { @Test public void testDefinitionCall2() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" - + "foo(a) == a \n" + + "foo(a) == a \n" + "bar == foo(1) \n" + "baz == foo(TRUE) \n" + "ASSUME baz /\\ bar = bar" @@ -101,9 +101,9 @@ public class DefinitionsTest { public void testDefinitionCall3() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" + "CONSTANTS k, k2 \n" - + "foo(a) == a \n" + + "foo(a) == a \n" + "bar == foo(1) \n" - + "baz == k = foo(k2) /\\ k2 = bar \n" + + "baz == k = foo(k2) /\\ k2 = bar \n" + "ASSUME baz \n" + "================================="; TestTypeChecker t = TestUtil.typeCheckString(module); @@ -118,7 +118,7 @@ public class DefinitionsTest { public void testDefinitionCall4() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" + "CONSTANTS k, k2 \n" - + "foo(a,b) == a \\cup b \n" + + "foo(a,b) == a \\cup b \n" + "bar == foo({1}, k) \n" + "baz == foo({TRUE}, k2)\n" + "ASSUME baz = baz /\\ bar = bar" @@ -135,7 +135,7 @@ public class DefinitionsTest { public void testDefinitionCall5() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" + "CONSTANTS k \n" - + "foo(a,b) == a = b \n" + + "foo(a,b) == a = b \n" + "bar == foo(1,k) \n" + "ASSUME bar \n" + "================================="; @@ -148,32 +148,74 @@ public class DefinitionsTest { public void testDefinitionCall6() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" + "CONSTANTS k, k2 \n" - + "foo(a,b) == a = b \n" + + "foo(a,b) == a = b \n" + "bar == foo(k, k2) /\\ k2 = 1 \n" - + "ASSUME bar \n" + + "ASSUME bar /\\ foo(TRUE,TRUE) \n" + "================================="; TestTypeChecker t = TestUtil.typeCheckString(module); assertEquals("BOOL", t.getDefinitionType("foo")); assertTrue(t.getDefinitionParamType("foo", "a").startsWith("UNTYPED")); - assertTrue(t.getDefinitionParamType("foo", "b").startsWith("UNTYPED")); + assertEquals(t.getDefinitionParamType("foo", "a"), t.getDefinitionParamType("foo", "b")); assertEquals("BOOL", t.getDefinitionType("bar")); assertEquals("INTEGER", t.getConstantType("k")); assertEquals("INTEGER", t.getConstantType("k2")); } + @Test + public void testDefinitionCall7Simpler1() throws TLA2BException { + final String module = "-------------- MODULE Testing ----------------\n" + + "CONSTANTS k \n" + + "foo(x) == x = {1} \n" + + "ASSUME foo(k) \n" + + "================================="; + TestTypeChecker t = TestUtil.typeCheckString(module); + assertEquals("BOOL", t.getDefinitionType("foo")); + assertEquals("POW(INTEGER)", t.getDefinitionParamType("foo", "x")); + assertEquals("POW(INTEGER)", t.getConstantType("k")); + } + + @Test + public void testDefinitionCall7Simpler2() throws TLA2BException { + final String module = "-------------- MODULE Testing ----------------\n" + + "CONSTANTS k \n" + + "foo(a,b) == a \\cup b \n" + + "ASSUME foo(k,{}) = {1} \n" + + "================================="; + TestTypeChecker t = TestUtil.typeCheckString(module); + assertTrue(t.getDefinitionType("foo").startsWith("POW(UNTYPED")); + assertTrue(t.getDefinitionParamType("foo", "a").startsWith("POW(UNTYPED")); + assertEquals(t.getDefinitionParamType("foo", "a"), t.getDefinitionParamType("foo", "b")); + assertEquals("POW(INTEGER)", t.getConstantType("k")); + } + + @Test + public void testDefinitionCall7Simpler3() throws TLA2BException { + final String module = "-------------- MODULE Testing ----------------\n" + + "CONSTANTS k, k2, k3 \n" + + "foo(a,b) == a \\cup b \n" + + "ASSUME k2 = foo(k3, k) /\\ k3 = {1} \n" + + "================================="; + TestTypeChecker t = TestUtil.typeCheckString(module); + assertTrue(t.getDefinitionType("foo").startsWith("POW(UNTYPED")); + assertTrue(t.getDefinitionParamType("foo", "a").startsWith("POW(UNTYPED")); + assertEquals(t.getDefinitionParamType("foo", "a"), t.getDefinitionParamType("foo", "b")); + assertEquals("POW(INTEGER)", t.getConstantType("k")); + assertEquals("POW(INTEGER)", t.getConstantType("k2")); + assertEquals("POW(INTEGER)", t.getConstantType("k3")); + } + @Test public void testDefinitionCall7() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" + "CONSTANTS k, k2, k3 \n" - + "foo(a,b) == a \\cup b \n" - + "bar(x,y) == x = foo(y, k) /\\ y ={1} \n" - + "ASSUME bar(k2,k3) \n" + "================================="; + + "foo(a,b) == a \\cup b \n" + + "bar(x,y) == x = foo(y, k) /\\ y = {1} \n" + + "ASSUME bar(k2,k3) \n" + + "================================="; TestTypeChecker t = TestUtil.typeCheckString(module); assertTrue(t.getDefinitionType("foo").startsWith("POW(UNTYPED")); - assertTrue(t.getDefinitionParamType("foo", "a").startsWith( - "POW(UNTYPED")); - assertTrue(t.getDefinitionParamType("foo", "b").startsWith( - "POW(UNTYPED")); + assertTrue(t.getDefinitionParamType("foo", "a").startsWith("POW(UNTYPED")); + assertEquals(t.getDefinitionParamType("foo", "a"), t.getDefinitionParamType("foo", "b")); assertEquals("BOOL", t.getDefinitionType("bar")); assertEquals("POW(INTEGER)", t.getDefinitionParamType("bar", "x")); assertEquals("POW(INTEGER)", t.getDefinitionParamType("bar", "y")); @@ -184,10 +226,11 @@ public class DefinitionsTest { public void testDefinitionCall8() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" + "CONSTANTS k, k2 \n" - + "foo(a) == k = a \n" - + "bar == foo(k2)\n" + + "foo(a) == k = a \n" + + "bar == foo(k2) \n" + "baz == k2 = 1 \n" - + "ASSUME bar /\\ baz \n" + "================================="; + + "ASSUME bar /\\ baz \n" + + "================================="; TestTypeChecker t = TestUtil.typeCheckString(module); assertEquals("INTEGER", t.getConstantType("k")); assertEquals("INTEGER", t.getConstantType("k2")); @@ -201,7 +244,7 @@ public class DefinitionsTest { public void testDefinitionCall9() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" + "CONSTANTS k, k2 \n" - + "foo(a,b) == a = b \n" + + "foo(a,b) == a = b \n" + "ASSUME foo(k, 1) /\\ foo(k2, TRUE) \n" + "================================="; TestTypeChecker t = TestUtil.typeCheckString(module); @@ -215,7 +258,7 @@ public class DefinitionsTest { @Test public void testDefinitionCall10() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" - + "foo(a,b) == a= 1 /\\ b = TRUE \n" + + "foo(a,b) == a = 1 /\\ b = TRUE \n" + "ASSUME foo(1, TRUE) \n" + "================================="; TestTypeChecker t = TestUtil.typeCheckString(module); diff --git a/src/test/java/de/tla2b/typechecking/TupleTest.java b/src/test/java/de/tla2b/typechecking/TupleTest.java index e764548..0683bbf 100644 --- a/src/test/java/de/tla2b/typechecking/TupleTest.java +++ b/src/test/java/de/tla2b/typechecking/TupleTest.java @@ -119,14 +119,15 @@ public class TupleTest { assertEquals("POW(INTEGER*BOOL)", t.getConstantType("k")); } - @Test(expected = TypeErrorException.class) + @Test public void testTuple2Elements() throws TLA2BException { final String module = "-------------- MODULE Testing ----------------\n" + "CONSTANTS k, k2, k3 \n" + "ASSUME k = <<k2, k3>> /\\ k3 = TRUE \n" + "================================="; - TestUtil.typeCheckString(module); + TestTypeChecker t = TestUtil.typeCheckString(module); + assertEquals("POW(INTEGER*BOOL)", t.getConstantType("k")); } @Test -- GitLab