Skip to content
Snippets Groups Projects
Verified Commit e4980cb5 authored by Miles Vella's avatar Miles Vella
Browse files

Fix a lot of typechecking bugs with user definitions and ints vs reals

parent 2e7e1854
Branches
Tags
No related merge requests found
Pipeline #157587 passed
Showing
with 418 additions and 240 deletions
package de.tla2b.analysis; package de.tla2b.analysis;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import de.tla2b.config.ConfigfileEvaluator; import de.tla2b.config.ConfigfileEvaluator;
import de.tla2b.config.TLCValueNode; import de.tla2b.config.TLCValueNode;
import de.tla2b.exceptions.*; import de.tla2b.exceptions.NotImplementedException;
import de.tla2b.exceptions.TLA2BException;
import de.tla2b.exceptions.TLA2BFrontEndException;
import de.tla2b.exceptions.TypeErrorException;
import de.tla2b.exceptions.UnificationException;
import de.tla2b.global.BBuildIns; import de.tla2b.global.BBuildIns;
import de.tla2b.global.BBuiltInOPs; import de.tla2b.global.BBuiltInOPs;
import de.tla2b.global.TranslationGlobals; import de.tla2b.global.TranslationGlobals;
import de.tla2b.types.*; import de.tla2b.types.AbstractHasFollowers;
import de.tla2b.types.BoolType;
import de.tla2b.types.FunctionType;
import de.tla2b.types.IDefaultableType;
import de.tla2b.types.IntType;
import de.tla2b.types.IntegerOrRealType;
import de.tla2b.types.RealType;
import de.tla2b.types.SetType;
import de.tla2b.types.StringType;
import de.tla2b.types.StructOrFunctionType;
import de.tla2b.types.StructType;
import de.tla2b.types.TLAType;
import de.tla2b.types.TupleOrFunction;
import de.tla2b.types.TupleType;
import de.tla2b.types.UntypedType;
import de.tla2b.util.DebugUtils; import de.tla2b.util.DebugUtils;
import tla2sany.semantic.*;
import tlc2.tool.BuiltInOPs;
import java.util.*; import tla2sany.semantic.AssumeNode;
import java.util.Map.Entry; import tla2sany.semantic.AtNode;
import tla2sany.semantic.ExprNode;
import tla2sany.semantic.ExprOrOpArgNode;
import tla2sany.semantic.FormalParamNode;
import tla2sany.semantic.LetInNode;
import tla2sany.semantic.ModuleNode;
import tla2sany.semantic.NumeralNode;
import tla2sany.semantic.OpApplNode;
import tla2sany.semantic.OpDeclNode;
import tla2sany.semantic.OpDefNode;
import tla2sany.semantic.SemanticNode;
import tla2sany.semantic.StringNode;
import tla2sany.semantic.SymbolNode;
import tlc2.tool.BuiltInOPs;
public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlobals { public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlobals {
...@@ -25,7 +63,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -25,7 +63,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
private final Set<OpDefNode> bDefinitions; private final Set<OpDefNode> bDefinitions;
private final List<SymbolNode> symbolNodeList = new ArrayList<>(); private final List<SymbolNode> symbolNodeList = new ArrayList<>();
private final List<SemanticNode> tupleNodeList = new ArrayList<>(); private final List<IDefaultableType> possibleUnfinishedTypes = new ArrayList<>();
private final ModuleNode moduleNode; private final ModuleNode moduleNode;
private List<OpDeclNode> bConstList; private List<OpDeclNode> bConstList;
...@@ -67,18 +105,18 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -67,18 +105,18 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
public void start() throws TLA2BException { public void start() throws TLA2BException {
for (OpDeclNode con : moduleNode.getConstantDecls()) { for (OpDeclNode con : moduleNode.getConstantDecls()) {
if (constantAssignments != null && constantAssignments.containsKey(con)) { if (constantAssignments != null && constantAssignments.containsKey(con)) {
setTypeAndFollowers(con, constantAssignments.get(con).getType()); setLocalTypeAndFollowers(con, constantAssignments.get(con).getType());
} else { } else {
// if constant already has a type: keep type; otherwise add an untyped type // if constant already has a type: keep type; otherwise add an untyped type
if (getType(con) == null) if (getLocalType(con) == null)
setTypeAndFollowers(con, new UntypedType()); setLocalTypeAndFollowers(con, new UntypedType());
} }
} }
for (OpDeclNode var : moduleNode.getVariableDecls()) { for (OpDeclNode var : moduleNode.getVariableDecls()) {
// if variable already has a type: keep type; otherwise add an untyped type // if variable already has a type: keep type; otherwise add an untyped type
if (getType(var) == null) if (getLocalType(var) == null)
setTypeAndFollowers(var, new UntypedType()); setLocalTypeAndFollowers(var, new UntypedType());
} }
evalDefinitions(moduleNode.getOpDefs()); evalDefinitions(moduleNode.getOpDefs());
...@@ -89,10 +127,10 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -89,10 +127,10 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
if (!bConstList.contains(con)) if (!bConstList.contains(con))
continue; continue;
TLAType defType = getType(entry.getValue()); TLAType defType = getLocalType(entry.getValue());
TLAType conType = getType(con); TLAType conType = getLocalType(con);
try { try {
setType(con, defType.unify(conType)); setLocalType(con, defType.unify(conType));
} catch (UnificationException e) { } catch (UnificationException e) {
throw new TypeErrorException( throw new TypeErrorException(
String.format("Expected %s, found %s at constant '%s'.", defType, conType, con.getName())); String.format("Expected %s, found %s at constant '%s'.", defType, conType, con.getName()));
...@@ -116,10 +154,14 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -116,10 +154,14 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
} }
private void checkIfAllIdentifiersHaveAType() throws TypeErrorException { private void checkIfAllIdentifiersHaveAType() throws TypeErrorException {
for (IDefaultableType type : possibleUnfinishedTypes) {
type.setToDefault();
}
// check if a variable has no type // check if a variable has no type
for (OpDeclNode var : moduleNode.getVariableDecls()) { for (OpDeclNode var : moduleNode.getVariableDecls()) {
TLAType varType = getType(var); TLAType varType = getLocalType(var);
if (varType.isUntyped()) { if (varType == null || varType.isUntyped()) {
throw new TypeErrorException( throw new TypeErrorException(
"The type of the variable '" + var.getName() + "' can not be inferred: " + varType); "The type of the variable '" + var.getName() + "' can not be inferred: " + varType);
} }
...@@ -129,28 +171,20 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -129,28 +171,20 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
// the resulting B Machine are considered // the resulting B Machine are considered
for (OpDeclNode con : moduleNode.getConstantDecls()) { for (OpDeclNode con : moduleNode.getConstantDecls()) {
if (bConstList == null || bConstList.contains(con)) { if (bConstList == null || bConstList.contains(con)) {
TLAType conType = getType(con); TLAType conType = getLocalType(con);
if (conType.isUntyped()) { if (conType == null || conType.isUntyped()) {
throw new TypeErrorException( throw new TypeErrorException(
"The type of constant " + con.getName() + " is still untyped: " + conType); "The type of constant " + con.getName() + " is still untyped: " + conType);
} }
} }
} }
for (SymbolNode symbol : symbolNodeList) { /* TODO: for (SymbolNode symbol : symbolNodeList) {
TLAType type = getType(symbol); TLAType type = getLocalType(symbol);
if (type.isUntyped()) { if (type == null || type.isUntyped()) {
throw new TypeErrorException("Symbol '" + symbol.getName() + "' has no type.\n" + symbol.getLocation()); throw new TypeErrorException("Symbol '" + symbol.getName() + "' has no type.\n" + symbol.getLocation());
} }
} }*/
for (SemanticNode tuple : tupleNodeList) {
TLAType type = getType(tuple);
if (type instanceof TupleOrFunction) {
((TupleOrFunction) type).getFinalType();
}
// TODO: does this do anything?
}
} }
private void evalDefinitions(OpDefNode[] opDefs) throws TLA2BException { private void evalDefinitions(OpDefNode[] opDefs) throws TLA2BException {
...@@ -167,16 +201,13 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -167,16 +201,13 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
public void visitOpDefNode(OpDefNode def) throws TLA2BException { public void visitOpDefNode(OpDefNode def) throws TLA2BException {
for (FormalParamNode p : def.getParams()) { for (FormalParamNode p : def.getParams()) {
if (p.getArity() > 0) { if (p.getArity() > 0) {
throw new TLA2BFrontEndException(String.format("TLA2B do not support 2nd-order operators: '%s'%n %s ", throw new TLA2BFrontEndException(String.format("TLA2B does not support 2nd-order operators: '%s'%n %s ",
def.getName(), def.getLocation())); def.getName(), def.getLocation()));
} }
setTypeAndFollowers(p, new UntypedType(), paramId); setLocalTypeAndFollowers(p, new UntypedType());
} }
UntypedType u = new UntypedType(); TLAType found = visitExprNode(def.getBody(), new UntypedType());
// TODO: check this setLocalTypeAndFollowers(def, found);
// def.setToolObject(TYPE_ID, u);
// u.addFollower(def);
setTypeAndFollowers(def, visitExprNode(def.getBody(), u));
} }
private void evalAssumptions(AssumeNode[] assumptions) throws TLA2BException { private void evalAssumptions(AssumeNode[] assumptions) throws TLA2BException {
...@@ -209,8 +240,8 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -209,8 +240,8 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
case StringKind: case StringKind:
return unify(StringType.getInstance(), expected, ((StringNode) exprNode).getRep().toString(), exprNode); return unify(StringType.getInstance(), expected, ((StringNode) exprNode).getRep().toString(), exprNode);
case AtNodeKind: { // @ case AtNodeKind: { // @
TLAType type = getType((((AtNode) exprNode).getExceptComponentRef()).getArgs()[1]); // right side TLAType type = getLocalType((((AtNode) exprNode).getExceptComponentRef()).getArgs()[1]); // right side
return unifyAndSetTypeWithFollowers(type, expected, "@", exprNode); return unifyAndSetLocalTypeWithFollowers(type, expected, "@", exprNode);
} }
case LetInKind: { case LetInKind: {
LetInNode l = (LetInNode) exprNode; LetInNode l = (LetInNode) exprNode;
...@@ -229,28 +260,29 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -229,28 +260,29 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
private TLAType visitOpApplNode(OpApplNode n, TLAType expected) throws TLA2BException { private TLAType visitOpApplNode(OpApplNode n, TLAType expected) throws TLA2BException {
switch (n.getOperator().getKind()) { switch (n.getOperator().getKind()) {
case ConstantDeclKind: { case ConstantDeclKind: {
OpDeclNode con = (OpDeclNode) n.getOperator(); SymbolNode symbolNode = n.getOperator();
TLAType c = getType(con); String vName = symbolNode.getName().toString();
TLAType c = getLocalType(symbolNode);
if (c == null) { if (c == null) {
throw new TypeErrorException(con.getName() + " has no type yet!"); throw new TypeErrorException(vName + " has no type yet!");
} }
return unifyAndSetType(c, expected, con.getName().toString(), con); return unifyAndSetLocalTypeWithFollowers(c, expected, vName, n);
} }
case VariableDeclKind: { case VariableDeclKind: {
SymbolNode symbolNode = n.getOperator(); SymbolNode symbolNode = n.getOperator();
String vName = symbolNode.getName().toString(); String vName = symbolNode.getName().toString();
TLAType v = getType(symbolNode); TLAType v = getLocalType(symbolNode);
if (v == null) { if (v == null) {
SymbolNode var = this.specAnalyser.getSymbolNodeByName(vName); SymbolNode var = this.specAnalyser.getSymbolNodeByName(vName);
if (var != null) { if (var != null) {
// symbolNode is variable of an expression, e.g. v + 1 // symbolNode is variable of an expression, e.g. v + 1
v = getType(var); v = getLocalType(var);
} else { } else {
throw new TypeErrorException(vName + " has no type yet!"); throw new TypeErrorException(vName + " has no type yet!");
} }
} }
return unifyAndSetType(v, expected, vName, n); return unifyAndSetLocalTypeWithFollowers(v, expected, vName, n);
} }
case BuiltInKind: case BuiltInKind:
...@@ -258,22 +290,13 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -258,22 +290,13 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
case FormalParamKind: { case FormalParamKind: {
SymbolNode symbolNode = n.getOperator(); SymbolNode symbolNode = n.getOperator();
TLAType t = getType(symbolNode, paramId); String vName = symbolNode.getName().toString();
if (t == null) { // no temp type TLAType t = getLocalType(symbolNode);
t = getType(symbolNode); if (t == null) {
if (t == null) { // no type at all
t = new UntypedType(); // TODO is this correct? t = new UntypedType(); // TODO is this correct?
// throw new RuntimeException(); // throw new RuntimeException();
} }
} return unifyAndSetLocalType(t, expected, vName, n);
try {
TLAType result = expected.unify(t);
setType(symbolNode, result, paramId);
return result;
} catch (UnificationException e) {
throw new TypeErrorException(String.format("Expected %s, found %s at parameter '%s',%n%s", expected, t,
symbolNode.getName(), n.getLocation()));
}
} }
case UserDefinedOpKind: { case UserDefinedOpKind: {
...@@ -287,22 +310,33 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -287,22 +310,33 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
} }
// the definition might be generic, so we have to re-evaluate // the definition might be generic, so we have to re-evaluate
// the definition body with the concrete types we have here as args // its body with the concrete types we have here as args
// set param types // set param types
assert params.length == args.length; assert params.length == args.length;
for (int i = 0; i < args.length; i++) { for (int i = 0; i < args.length; i++) {
TLAType argType = visitExprOrOpArgNode(args[i], new UntypedType()); TLAType argType = visitExprOrOpArgNode(args[i], new UntypedType());
setType(params[i], argType.cloneTLAType(), TEMP_TYPE_ID);
int prevParamId = paramId;
paramId = TEMP_TYPE_ID;
try {
setLocalType(params[i], argType);
} finally {
paramId = prevParamId;
}
} }
// re-evaluate definition body // re-evaluate definition body
int prevParamId = paramId;
paramId = TEMP_TYPE_ID; paramId = TEMP_TYPE_ID;
TLAType found = visitExprNode(def.getBody(), expected); TLAType found;
paramId = TYPE_ID; try {
found = visitExprNode(def.getBody(), expected);
} finally {
paramId = prevParamId;
}
setType(n, found); return unify(found, expected, n);
return found;
} }
default: default:
...@@ -447,27 +481,24 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -447,27 +481,24 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
found = new FunctionType(IntType.getInstance(), list.get(0)); found = new FunctionType(IntType.getInstance(), list.get(0));
} else { } else {
found = TupleOrFunction.createTupleOrFunctionType(list); found = TupleOrFunction.createTupleOrFunctionType(list);
// found = new TupleType(list);
} }
tupleNodeList.add(n); return unifyAndSetLocalTypeWithFollowers(found, expected, "tuple", n);
return unifyAndSetTypeWithFollowers(found, expected, "tuple", n);
} }
/* /*
* Function constructors * Function constructors
*/ */
case OPCODE_rfs: { // recursive function ( f[x\in Nat] == IF x = 0 THEN 1 case OPCODE_rfs: { // recursive function ( f[x\in Nat] == IF x = 0 THEN 1 ELSE f[n-1]
// ELSE f[n-1]
FormalParamNode recFunc = n.getUnbdedQuantSymbols()[0]; FormalParamNode recFunc = n.getUnbdedQuantSymbols()[0];
symbolNodeList.add(recFunc); symbolNodeList.add(recFunc);
setTypeAndFollowers(recFunc, new FunctionType()); setLocalTypeAndFollowers(recFunc, new FunctionType());
TLAType domainType = evalBoundedVariables(n); TLAType domainType = evalBoundedVariables(n);
FunctionType found = new FunctionType(domainType, new UntypedType()); FunctionType found = new FunctionType(domainType, new UntypedType());
visitExprOrOpArgNode(n.getArgs()[0], found.getRange()); visitExprOrOpArgNode(n.getArgs()[0], found.getRange());
found = (FunctionType) unify(found, expected, n); found = (FunctionType) unify(found, expected, n);
return unify(found, getType(recFunc), n); return unify(found, getLocalType(recFunc), n);
} }
case OPCODE_nrfs: // succ[n \in Nat] == n + 1 case OPCODE_nrfs: // succ[n \in Nat] == n + 1
...@@ -495,12 +526,11 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -495,12 +526,11 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
} else if (dom instanceof NumeralNode) { } else if (dom instanceof NumeralNode) {
NumeralNode num = (NumeralNode) dom; NumeralNode num = (NumeralNode) dom;
UntypedType u = new UntypedType(); UntypedType u = new UntypedType();
setTypeAndFollowers(n, u); setLocalTypeAndFollowers(n, u);
TLAType res = visitExprOrOpArgNode(n.getArgs()[0], new TupleOrFunction(num.val(), u)); TLAType res = visitExprOrOpArgNode(n.getArgs()[0], new TupleOrFunction(num.val(), u));
setTypeAndFollowers(n.getArgs()[0], res); setLocalTypeAndFollowers(n.getArgs()[0], res);
tupleNodeList.add(n.getArgs()[0]); return unify(getLocalType(n), expected, n.getArgs()[0].toString(), n);
return unify(getType(n), expected, n.getArgs()[0].toString(), n);
} else { } else {
domType = visitExprOrOpArgNode(dom, new UntypedType()); domType = visitExprOrOpArgNode(dom, new UntypedType());
} }
...@@ -561,7 +591,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -561,7 +591,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
SetType fieldType = (SetType) visitExprOrOpArgNode(pair.getArgs()[1], new SetType(new UntypedType())); SetType fieldType = (SetType) visitExprOrOpArgNode(pair.getArgs()[1], new SetType(new UntypedType()));
struct.add(field.getRep().toString(), fieldType.getSubType()); struct.add(field.getRep().toString(), fieldType.getSubType());
} }
return unifyAndSetTypeWithFollowers(new SetType(struct), expected, "set of records", n); return unifyAndSetLocalTypeWithFollowers(new SetType(struct), expected, "set of records", n);
} }
case OPCODE_rc: { // [h_1 |-> 1, h_2 |-> 2] case OPCODE_rc: { // [h_1 |-> 1, h_2 |-> 2]
...@@ -572,7 +602,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -572,7 +602,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
TLAType fieldType = visitExprOrOpArgNode(pair.getArgs()[1], new UntypedType()); TLAType fieldType = visitExprOrOpArgNode(pair.getArgs()[1], new UntypedType());
found.add(field.getRep().toString(), fieldType); found.add(field.getRep().toString(), fieldType);
} }
return unifyAndSetTypeWithFollowers(found, expected, "record constructor", n); return unifyAndSetLocalTypeWithFollowers(found, expected, "record constructor", n);
} }
case OPCODE_rs: { // $RcdSelect r.c case OPCODE_rs: { // $RcdSelect r.c
...@@ -588,7 +618,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -588,7 +618,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
throw new TypeErrorException(String.format("Struct has no field %s with type %s: %s%n%s", fieldName, throw new TypeErrorException(String.format("Struct has no field %s with type %s: %s%n%s", fieldName,
r.getType(fieldName), r, n.getLocation())); r.getType(fieldName), r, n.getLocation()));
} }
setTypeAndFollowers(n.getArgs()[0], r); setLocalTypeAndFollowers(n.getArgs()[0], r);
return r.getType(fieldName); return r.getType(fieldName);
} }
...@@ -599,7 +629,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -599,7 +629,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
visitExprOrOpArgNode(n.getArgs()[0], BoolType.getInstance()); visitExprOrOpArgNode(n.getArgs()[0], BoolType.getInstance());
TLAType then = visitExprOrOpArgNode(n.getArgs()[1], expected); TLAType then = visitExprOrOpArgNode(n.getArgs()[1], expected);
TLAType eelse = visitExprOrOpArgNode(n.getArgs()[2], then); TLAType eelse = visitExprOrOpArgNode(n.getArgs()[2], then);
setTypeAndFollowers(n, eelse); setLocalTypeAndFollowers(n, eelse);
return eelse; return eelse;
} }
...@@ -623,7 +653,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -623,7 +653,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
for (FormalParamNode param : n.getUnbdedQuantSymbols()) { for (FormalParamNode param : n.getUnbdedQuantSymbols()) {
TLAType paramType = new UntypedType(); TLAType paramType = new UntypedType();
symbolNodeList.add(param); symbolNodeList.add(param);
setTypeAndFollowers(param, paramType); setLocalTypeAndFollowers(param, paramType);
list.add(paramType); list.add(paramType);
} }
TLAType found; TLAType found;
...@@ -632,7 +662,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -632,7 +662,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
} else { } else {
found = new TupleType(list); found = new TupleType(list);
} }
found = unifyAndSetTypeWithFollowers(found, expected, n.getOperator().getName().toString(), n); found = unifyAndSetLocalTypeWithFollowers(found, expected, n.getOperator().getName().toString(), n);
visitExprOrOpArgNode(n.getArgs()[0], BoolType.getInstance()); visitExprOrOpArgNode(n.getArgs()[0], BoolType.getInstance());
return found; return found;
} }
...@@ -674,7 +704,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -674,7 +704,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
expected = (FunctionType) unify(expected, subType, "parameter " + p.getName(), bounds[i]); expected = (FunctionType) unify(expected, subType, "parameter " + p.getName(), bounds[i]);
domList.add(expected); domList.add(expected);
symbolNodeList.add(p); symbolNodeList.add(p);
setTypeAndFollowers(p, expected.getRange()); setLocalTypeAndFollowers(p, expected.getRange());
} else { } else {
TupleType tuple = new TupleType(params[i].length); TupleType tuple = new TupleType(params[i].length);
tuple = (TupleType) unify(tuple, subType, "tuple", bounds[i]); tuple = (TupleType) unify(tuple, subType, "tuple", bounds[i]);
...@@ -682,7 +712,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -682,7 +712,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
for (int j = 0; j < params[i].length; j++) { for (int j = 0; j < params[i].length; j++) {
FormalParamNode p = params[i][j]; FormalParamNode p = params[i][j];
symbolNodeList.add(p); symbolNodeList.add(p);
setTypeAndFollowers(p, tuple.getTypes().get(j)); setLocalTypeAndFollowers(p, tuple.getTypes().get(j));
} }
} }
} else { // is not a tuple: all parameter have the same type } else { // is not a tuple: all parameter have the same type
...@@ -690,7 +720,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -690,7 +720,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
domList.add(subType); domList.add(subType);
FormalParamNode p = params[i][j]; FormalParamNode p = params[i][j];
symbolNodeList.add(p); symbolNodeList.add(p);
setTypeAndFollowers(p, subType); setLocalTypeAndFollowers(p, subType);
} }
} }
} }
...@@ -699,13 +729,13 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -699,13 +729,13 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
private TLAType evalExcept(OpApplNode n, TLAType expected) throws TLA2BException { private TLAType evalExcept(OpApplNode n, TLAType expected) throws TLA2BException {
TLAType t = visitExprOrOpArgNode(n.getArgs()[0], expected); TLAType t = visitExprOrOpArgNode(n.getArgs()[0], expected);
setTypeAndFollowers(n, t); setLocalTypeAndFollowers(n, t);
for (int i = 1; i < n.getArgs().length; i++) { // start at 1 for (int i = 1; i < n.getArgs().length; i++) { // start at 1
OpApplNode pair = (OpApplNode) n.getArgs()[i]; OpApplNode pair = (OpApplNode) n.getArgs()[i];
// stored for @ node // stored for @ node
UntypedType untyped = new UntypedType(); UntypedType untyped = new UntypedType();
setTypeAndFollowers(pair.getArgs()[1], untyped); setLocalTypeAndFollowers(pair.getArgs()[1], untyped);
TLAType valueType = visitExprOrOpArgNode(pair.getArgs()[1], untyped); // right side TLAType valueType = visitExprOrOpArgNode(pair.getArgs()[1], untyped); // right side
OpApplNode seq = (OpApplNode) pair.getArgs()[0]; // left side OpApplNode seq = (OpApplNode) pair.getArgs()[0]; // left side
...@@ -726,7 +756,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -726,7 +756,7 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
domList.add(visitExprOrOpArgNode(arg, new UntypedType())); domList.add(visitExprOrOpArgNode(arg, new UntypedType()));
} }
domType = new TupleType(domList); domType = new TupleType(domList);
setType(domExpr, domType); // store type setLocalType(domExpr, domType); // store type
} else { } else {
domType = visitExprOrOpArgNode(domExpr, new UntypedType()); domType = visitExprOrOpArgNode(domExpr, new UntypedType());
} }
...@@ -764,14 +794,10 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -764,14 +794,10 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
case B_OPCODE_leq: // <= case B_OPCODE_leq: // <=
case B_OPCODE_geq: { // >= case B_OPCODE_geq: { // >=
TLAType boolType = unify(BoolType.getInstance(), expected, n); TLAType boolType = unify(BoolType.getInstance(), expected, n);
try { TLAType numberType = new IntegerOrRealType();
for (ExprOrOpArgNode arg : n.getArgs()) {
visitExprOrOpArgNode(arg, IntType.getInstance());
}
} catch (TypeErrorException e) {
for (ExprOrOpArgNode arg : n.getArgs()) { for (ExprOrOpArgNode arg : n.getArgs()) {
visitExprOrOpArgNode(arg, RealType.getInstance()); numberType = visitExprOrOpArgNode(arg, numberType);
} setLocalTypeAndFollowers(arg, numberType);
} }
return boolType; return boolType;
} }
...@@ -780,24 +806,15 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -780,24 +806,15 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
// for UntypedTypes the default is integer; if this leads to a TypeErrorException real is tried instead // for UntypedTypes the default is integer; if this leads to a TypeErrorException real is tried instead
case B_OPCODE_plus: // + case B_OPCODE_plus: // +
case B_OPCODE_minus: // - case B_OPCODE_minus: // -
case B_OPCODE_uminus: // -x
case B_OPCODE_times: // * case B_OPCODE_times: // *
case B_OPCODE_div: // / case B_OPCODE_div: // /
case B_OPCODE_mod: // % modulo case B_OPCODE_mod: // % modulo
case B_OPCODE_exp: { // x hoch y, x^y case B_OPCODE_exp: { // x to the power of y, x^y
TLAType type; TLAType type = unify(new IntegerOrRealType(), expected, n);
try {
IntType.getInstance().unify(expected); // throws UnificationException
type = IntType.getInstance();
for (ExprOrOpArgNode arg : n.getArgs()) { for (ExprOrOpArgNode arg : n.getArgs()) {
// throws TypeErrorException; check whether IntType is OK, else try the same with RealType type = visitExprOrOpArgNode(arg, type);
visitExprOrOpArgNode(arg, type); setLocalTypeAndFollowers(arg, type);
}
} catch (UnificationException | TypeErrorException e) {
type = unify(RealType.getInstance(), expected, n);
for (ExprOrOpArgNode arg : n.getArgs()) {
// if TypeErrorException is thrown here, the type is incompatible and it is a real type error!
visitExprOrOpArgNode(arg, type);
}
} }
return type; return type;
} }
...@@ -830,21 +847,6 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -830,21 +847,6 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
case B_OPCODE_real: // Real case B_OPCODE_real: // Real
return unify(new SetType(RealType.getInstance()), expected, n); return unify(new SetType(RealType.getInstance()), expected, n);
case B_OPCODE_uminus: { // -x
TLAType type;
try {
IntType.getInstance().unify(expected); // throws UnificationException
type = IntType.getInstance();
// throws TypeErrorException; check whether IntType is OK, else try the same with RealType
visitExprOrOpArgNode(n.getArgs()[0], type);
} catch (UnificationException | TypeErrorException e) {
type = unify(RealType.getInstance(), expected, n);
// if TypeErrorException is thrown here, the type is incompatible and it is a real type error!
visitExprOrOpArgNode(n.getArgs()[0], type);
}
return type;
}
/* /*
* Standard Module FiniteSets * Standard Module FiniteSets
*/ */
...@@ -975,34 +977,74 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -975,34 +977,74 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
/* /*
* Utility methods * Utility methods
*/ */
private static void setTypeAndFollowers(SemanticNode node, TLAType type, int paramId) { private static boolean hasGlobalTyping(SemanticNode node) {
setType(node, type, paramId); SymbolNode symbol = null;
if (node instanceof SymbolNode) {
symbol = (SymbolNode) node;
} else if (node instanceof OpApplNode) {
symbol = ((OpApplNode) node).getOperator();
}
return symbol != null && (symbol.getKind() == ConstantDeclKind || symbol.getKind() == VariableDeclKind);
}
private void setLocalTypeAndFollowers(SemanticNode node, TLAType type) {
setLocalType(node, type);
if (type instanceof AbstractHasFollowers) { if (type instanceof AbstractHasFollowers) {
((AbstractHasFollowers) type).addFollower(node); ((AbstractHasFollowers) type).addFollower(node);
} }
} }
private static void setTypeAndFollowers(SemanticNode node, TLAType type) { public static void updateTypeAndFollowers(SemanticNode node, TLAType oldType, TLAType newType) {
setTypeAndFollowers(node, type, TYPE_ID); if (getType(node, TYPE_ID) == oldType) {
setType(node, newType, TYPE_ID);
if (newType instanceof AbstractHasFollowers) {
((AbstractHasFollowers) newType).addFollower(node);
}
}
if (getType(node, TEMP_TYPE_ID) == oldType) {
setType(node, newType, TEMP_TYPE_ID);
if (newType instanceof AbstractHasFollowers) {
((AbstractHasFollowers) newType).addFollower(node);
}
}
} }
private static void setType(SemanticNode node, TLAType type, int paramId) { private static void setType(SemanticNode node, TLAType type, int paramId) {
node.setToolObject(paramId, type); node.setToolObject(paramId, type);
} }
private void setLocalType(SemanticNode node, TLAType type) {
if (type instanceof IDefaultableType) {
this.possibleUnfinishedTypes.add((IDefaultableType) type);
}
if (paramId != TYPE_ID && hasGlobalTyping(node)) {
setType(node, type, TYPE_ID);
} else {
setType(node, type, paramId);
}
}
public static void setType(SemanticNode node, TLAType type) { public static void setType(SemanticNode node, TLAType type) {
setType(node, type, TYPE_ID); setType(node, type, TYPE_ID);
} }
private static TLAType getType(SemanticNode n, int paramId) { private static TLAType getType(SemanticNode node, int paramId) {
return (TLAType) n.getToolObject(paramId); return (TLAType) node.getToolObject(paramId);
}
private TLAType getLocalType(SemanticNode node) {
if (paramId != TYPE_ID && hasGlobalTyping(node)) {
return getType(node, TYPE_ID);
}
return getType(node, paramId);
} }
public static TLAType getType(SemanticNode n) { public static TLAType getType(SemanticNode node) {
return getType(n, TYPE_ID); return getType(node, TYPE_ID);
} }
private TLAType unify(TLAType toUnify, TLAType expected, String opMsg, SemanticNode n) throws TypeErrorException { private static TLAType unify(TLAType toUnify, TLAType expected, String opMsg, SemanticNode n) throws TypeErrorException {
TLAType found = toUnify; TLAType found = toUnify;
DebugUtils.printDebugMsg("Unify " + found + " and " + expected + " at '" + opMsg + "' (" + n.getLocation() + ")"); DebugUtils.printDebugMsg("Unify " + found + " and " + expected + " at '" + opMsg + "' (" + n.getLocation() + ")");
try { try {
...@@ -1014,19 +1056,19 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo ...@@ -1014,19 +1056,19 @@ public class TypeChecker extends BuiltInOPs implements BBuildIns, TranslationGlo
return found; return found;
} }
private TLAType unify(TLAType toUnify, TLAType expected, OpApplNode n) throws TypeErrorException { private static TLAType unify(TLAType toUnify, TLAType expected, OpApplNode n) throws TypeErrorException {
return unify(toUnify, expected, n.getOperator().getName().toString(), n); return unify(toUnify, expected, n.getOperator().getName().toString(), n);
} }
private TLAType unifyAndSetTypeWithFollowers(TLAType toUnify, TLAType expected, String opMsg, SemanticNode n) throws TypeErrorException { private TLAType unifyAndSetLocalTypeWithFollowers(TLAType toUnify, TLAType expected, String opMsg, SemanticNode n) throws TypeErrorException {
TLAType found = unify(toUnify, expected, opMsg, n); TLAType found = unify(toUnify, expected, opMsg, n);
setTypeAndFollowers(n, found); setLocalTypeAndFollowers(n, found);
return found; return found;
} }
private TLAType unifyAndSetType(TLAType toUnify, TLAType expected, String opMsg, SemanticNode n) throws TypeErrorException { private TLAType unifyAndSetLocalType(TLAType toUnify, TLAType expected, String opMsg, SemanticNode n) throws TypeErrorException {
TLAType found = unify(toUnify, expected, opMsg, n); TLAType found = unify(toUnify, expected, opMsg, n);
setType(n, found); setLocalType(n, found);
return found; return found;
} }
} }
...@@ -78,6 +78,11 @@ public class TlaTypePrinter extends ClassicalPositionPrinter implements TypeVisi ...@@ -78,6 +78,11 @@ public class TlaTypePrinter extends ClassicalPositionPrinter implements TypeVisi
pout.printAtom("real"); pout.printAtom("real");
} }
@Override
public void caseIntegerOrRealType(IntegerOrRealType type) {
throw new NotImplementedException("should not happen");
}
public void caseSetType(SetType type) { public void caseSetType(SetType type) {
pout.openTerm("set"); pout.openTerm("set");
type.getSubType().apply(this); type.getSubType().apply(this);
......
...@@ -18,6 +18,8 @@ public interface TypeVisitorInterface { ...@@ -18,6 +18,8 @@ public interface TypeVisitorInterface {
void caseRealType(RealType type); void caseRealType(RealType type);
void caseIntegerOrRealType(IntegerOrRealType type);
void caseSetType(SetType type); void caseSetType(SetType type);
void caseStringType(StringType type); void caseStringType(StringType type);
......
package de.tla2b.types; package de.tla2b.types;
import de.tla2b.analysis.TypeChecker;
import tla2sany.semantic.SemanticNode;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import de.tla2b.analysis.TypeChecker;
import tla2sany.semantic.SemanticNode;
public abstract class AbstractHasFollowers extends TLAType { public abstract class AbstractHasFollowers extends TLAType {
private List<Object> followers = new ArrayList<>(); private List<Object> followers = null;
public AbstractHasFollowers(int t) { public AbstractHasFollowers(int t) {
super(t); super(t);
} }
public List<Object> getFollowers() {
return followers;
}
public void addFollower(Object o) { public void addFollower(Object o) {
if (followers == null) {
followers = new ArrayList<>();
}
// only (partial) untyped types need follower // only (partial) untyped types need follower
if (followers != null && !followers.contains(o)) if (!followers.contains(o)) {
followers.add(o); followers.add(o);
} }
}
public void deleteFollowers() { public void deleteFollowers() {
followers = null; followers = null;
} }
public void removeFollower(Object o) { public void removeFollower(Object o) {
if (this.hasFollowers()) {
followers.remove(o); followers.remove(o);
} }
}
protected void setFollowersTo(TLAType newType) { protected void setFollowersTo(TLAType newType) {
if (this.followers == null) if (!this.hasFollowers()) {
return; return;
}
// avoid concurrent modification: // avoid concurrent modification:
new ArrayList<>(followers).forEach(follower -> { new ArrayList<>(followers).forEach(follower -> {
//this.removeFollower(follower);
if (follower instanceof SemanticNode) { if (follower instanceof SemanticNode) {
TypeChecker.setType((SemanticNode) follower, newType); TypeChecker.updateTypeAndFollowers((SemanticNode) follower, this, newType);
if (newType instanceof AbstractHasFollowers) {
((AbstractHasFollowers) newType).addFollower(follower);
}
} else if (follower instanceof SetType) { } else if (follower instanceof SetType) {
((SetType) follower).setSubType(newType); ((SetType) follower).update(this, newType);
} else if (follower instanceof TupleType) { } else if (follower instanceof TupleType) {
((TupleType) follower).update(this, newType); ((TupleType) follower).update(this, newType);
} else if (follower instanceof PairType) { } else if (follower instanceof PairType) {
PairType pair = ((PairType) follower); ((PairType) follower).update(this, newType);
if (pair.getFirst() == this) {
pair.setFirst(newType);
}
if (pair.getSecond() == this) {
pair.setSecond(newType);
}
} else if (follower instanceof FunctionType) { } else if (follower instanceof FunctionType) {
((FunctionType) follower).update(this, newType); ((FunctionType) follower).update(this, newType);
} else if (follower instanceof StructType) { } else if (follower instanceof StructType) {
...@@ -68,7 +64,7 @@ public abstract class AbstractHasFollowers extends TLAType { ...@@ -68,7 +64,7 @@ public abstract class AbstractHasFollowers extends TLAType {
}); });
} }
public boolean hasFollower() { public boolean hasFollowers() {
return !followers.isEmpty(); return followers != null && !followers.isEmpty();
} }
} }
...@@ -26,11 +26,13 @@ public class FunctionType extends AbstractHasFollowers { ...@@ -26,11 +26,13 @@ public class FunctionType extends AbstractHasFollowers {
} }
public void update(TLAType oldType, TLAType newType) { public void update(TLAType oldType, TLAType newType) {
if (domain == oldType) if (domain == oldType) {
setDomain(newType); setDomain(newType);
if (range == oldType) }
if (range == oldType) {
setRange(newType); setRange(newType);
} }
}
@Override @Override
public boolean compare(TLAType other) { public boolean compare(TLAType other) {
......
package de.tla2b.types;
public interface IDefaultableType {
TLAType setToDefault();
}
...@@ -29,15 +29,15 @@ public class IntType extends TLAType { ...@@ -29,15 +29,15 @@ public class IntType extends TLAType {
@Override @Override
public boolean compare(TLAType o) { public boolean compare(TLAType o) {
return o.getKind() == UNTYPED || o.getKind() == INTEGER; return o.getKind() == UNTYPED || o.getKind() == INTEGER || o.getKind() == INTEGER_OR_REAL;
} }
@Override @Override
public IntType unify(TLAType o) throws UnificationException { public IntType unify(TLAType o) throws UnificationException {
if (o.getKind() == INTEGER) { if (o.getKind() == INTEGER) {
return this; return this;
} else if (o instanceof UntypedType) { } else if (o.getKind() == INTEGER_OR_REAL || o instanceof UntypedType) {
((UntypedType) o).setFollowersTo(this); ((AbstractHasFollowers) o).setFollowersTo(this);
return this; return this;
} else } else
throw new UnificationException(); throw new UnificationException();
......
package de.tla2b.types;
import de.be4.classicalb.core.parser.node.PExpression;
import de.tla2b.exceptions.UnificationException;
import de.tla2b.output.TypeVisitorInterface;
public final class IntegerOrRealType extends AbstractHasFollowers implements IDefaultableType {
public IntegerOrRealType() {
super(INTEGER_OR_REAL);
}
@Override
public String toString() {
return "INTEGER_OR_REAL";
}
@Override
public boolean isUntyped() {
return true;
}
@Override
public boolean compare(TLAType o) {
return o.getKind() == UNTYPED || o.getKind() == INTEGER || o.getKind() == REAL || o.getKind() == INTEGER_OR_REAL;
}
@Override
public TLAType unify(TLAType o) throws UnificationException {
if (o.getKind() == REAL || o.getKind() == INTEGER) {
return o;
} else if (o.getKind() == INTEGER_OR_REAL || o instanceof UntypedType) {
((AbstractHasFollowers) o).setFollowersTo(this);
return this;
} else {
throw new UnificationException();
}
}
@Override
public IntegerOrRealType cloneTLAType() {
return new IntegerOrRealType();
}
@Override
public boolean contains(TLAType o) {
return false;
}
@Override
public TLAType setToDefault() {
TLAType type = IntType.getInstance();
this.setFollowersTo(type);
return type;
}
@Override
public PExpression getBNode() {
throw new UnsupportedOperationException("IntegerOrRealType has no corresponding B node.");
}
public void apply(TypeVisitorInterface visitor) {
visitor.caseIntegerOrRealType(this);
}
}
...@@ -20,13 +20,11 @@ public class PairType extends AbstractHasFollowers { ...@@ -20,13 +20,11 @@ public class PairType extends AbstractHasFollowers {
super(PAIR); super(PAIR);
this.first = f; this.first = f;
if (first instanceof AbstractHasFollowers) { if (first instanceof AbstractHasFollowers) {
AbstractHasFollowers firstHasFollowers = (AbstractHasFollowers) first; ((AbstractHasFollowers) first).addFollower(this);
firstHasFollowers.addFollower(this);
} }
this.second = s; this.second = s;
if (second instanceof AbstractHasFollowers) { if (second instanceof AbstractHasFollowers) {
AbstractHasFollowers secondHasFollowers = (AbstractHasFollowers) second; ((AbstractHasFollowers) second).addFollower(this);
secondHasFollowers.addFollower(this);
} }
} }
...@@ -68,6 +66,15 @@ public class PairType extends AbstractHasFollowers { ...@@ -68,6 +66,15 @@ public class PairType extends AbstractHasFollowers {
} }
} }
public void update(TLAType oldType, TLAType newType) {
if (this.first == oldType) {
this.setFirst(newType);
}
if (this.second == oldType) {
this.setSecond(newType);
}
}
@Override @Override
public boolean isUntyped() { public boolean isUntyped() {
return first.isUntyped() || second.isUntyped(); return first.isUntyped() || second.isUntyped();
......
...@@ -29,15 +29,15 @@ public class RealType extends TLAType { ...@@ -29,15 +29,15 @@ public class RealType extends TLAType {
@Override @Override
public boolean compare(TLAType o) { public boolean compare(TLAType o) {
return o.getKind() == UNTYPED || o.getKind() == REAL; return o.getKind() == UNTYPED || o.getKind() == REAL || o.getKind() == INTEGER_OR_REAL;
} }
@Override @Override
public RealType unify(TLAType o) throws UnificationException { public RealType unify(TLAType o) throws UnificationException {
if (o.getKind() == REAL) { if (o.getKind() == REAL) {
return this; return this;
} else if (o instanceof UntypedType) { } else if (o.getKind() == INTEGER_OR_REAL || o instanceof UntypedType) {
((UntypedType) o).setFollowersTo(this); ((AbstractHasFollowers) o).setFollowersTo(this);
return this; return this;
} else } else
throw new UnificationException(); throw new UnificationException();
......
...@@ -18,16 +18,11 @@ public class SetType extends AbstractHasFollowers { ...@@ -18,16 +18,11 @@ public class SetType extends AbstractHasFollowers {
} }
public void setSubType(TLAType t) { public void setSubType(TLAType t) {
// if (subType instanceof AbstractHasFollowers) { this.subType = t;
// // delete old reference
// ((AbstractHasFollowers) subType).removeFollower(this);
// }
if (t instanceof AbstractHasFollowers) { if (t instanceof AbstractHasFollowers) {
// set new reference // set new reference
((AbstractHasFollowers) t).addFollower(this); ((AbstractHasFollowers) t).addFollower(this);
} }
this.subType = t;
// setting subType can lead to a completely typed type // setting subType can lead to a completely typed type
if (!this.isUntyped()) { if (!this.isUntyped()) {
...@@ -36,6 +31,12 @@ public class SetType extends AbstractHasFollowers { ...@@ -36,6 +31,12 @@ public class SetType extends AbstractHasFollowers {
} }
} }
public void update(TLAType oldType, TLAType newType) {
if (this.subType == oldType) {
this.setSubType(newType);
}
}
public SetType unify(TLAType o) throws UnificationException { public SetType unify(TLAType o) throws UnificationException {
if (!this.compare(o) || this.contains(o)) { if (!this.compare(o) || this.contains(o)) {
throw new UnificationException(); throw new UnificationException();
......
...@@ -7,8 +7,6 @@ import de.tla2b.output.TypeVisitorInterface; ...@@ -7,8 +7,6 @@ import de.tla2b.output.TypeVisitorInterface;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
public class StructOrFunctionType extends AbstractHasFollowers { public class StructOrFunctionType extends AbstractHasFollowers {
private final Map<String, TLAType> types = new LinkedHashMap<>(); private final Map<String, TLAType> types = new LinkedHashMap<>();
...@@ -22,13 +20,13 @@ public class StructOrFunctionType extends AbstractHasFollowers { ...@@ -22,13 +20,13 @@ public class StructOrFunctionType extends AbstractHasFollowers {
super(STRUCT_OR_FUNCTION); super(STRUCT_OR_FUNCTION);
} }
public void setNewType(TLAType old, TLAType New) { public void setNewType(TLAType oldType, TLAType newType) {
types.forEach((name, type) -> { types.forEach((name, type) -> {
if (type == old) { if (type == oldType) {
if (New instanceof AbstractHasFollowers) { // set new reference types.put(name, newType);
((AbstractHasFollowers) New).addFollower(this); if (newType instanceof AbstractHasFollowers) { // set new reference
((AbstractHasFollowers) newType).addFollower(this);
} }
types.put(name, New);
} }
}); });
testRecord(); testRecord();
......
...@@ -33,16 +33,17 @@ public class StructType extends AbstractHasFollowers { ...@@ -33,16 +33,17 @@ public class StructType extends AbstractHasFollowers {
} }
public void add(String name, TLAType type) { public void add(String name, TLAType type) {
types.put(name, type);
if (type instanceof AbstractHasFollowers) { // set new reference if (type instanceof AbstractHasFollowers) { // set new reference
((AbstractHasFollowers) type).addFollower(this); ((AbstractHasFollowers) type).addFollower(this);
} }
types.put(name, type);
} }
public void setNewType(TLAType old, TLAType New) { public void setNewType(TLAType old, TLAType newType) {
types.forEach((name, type) -> { types.forEach((name, type) -> {
if (type == old) if (type == old) {
add(name, New); add(name, newType);
}
}); });
} }
......
...@@ -20,6 +20,7 @@ public abstract class TLAType { ...@@ -20,6 +20,7 @@ public abstract class TLAType {
static int TUPLE = 11; static int TUPLE = 11;
static int TUPLE_OR_FUNCTION = 12; static int TUPLE_OR_FUNCTION = 12;
static int REAL = 13; static int REAL = 13;
static int INTEGER_OR_REAL = 14;
private final int kind; private final int kind;
......
...@@ -7,7 +7,7 @@ import de.tla2b.output.TypeVisitorInterface; ...@@ -7,7 +7,7 @@ import de.tla2b.output.TypeVisitorInterface;
import java.util.*; import java.util.*;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class TupleOrFunction extends AbstractHasFollowers { public class TupleOrFunction extends AbstractHasFollowers implements IDefaultableType {
private final Map<Integer, TLAType> types = new LinkedHashMap<>(); private final Map<Integer, TLAType> types = new LinkedHashMap<>();
public TupleOrFunction(Integer index, TLAType type) { public TupleOrFunction(Integer index, TLAType type) {
...@@ -249,7 +249,8 @@ public class TupleOrFunction extends AbstractHasFollowers { ...@@ -249,7 +249,8 @@ public class TupleOrFunction extends AbstractHasFollowers {
update(); update();
} }
public TLAType getFinalType() { @Override
public TLAType setToDefault() {
List<TLAType> list = new ArrayList<>(this.types.values()); List<TLAType> list = new ArrayList<>(this.types.values());
if (comparable(list)) { if (comparable(list)) {
FunctionType func = new FunctionType(IntType.getInstance(), new UntypedType()); FunctionType func = new FunctionType(IntType.getInstance(), new UntypedType());
......
...@@ -4,7 +4,6 @@ import de.be4.classicalb.core.parser.node.PExpression; ...@@ -4,7 +4,6 @@ import de.be4.classicalb.core.parser.node.PExpression;
import de.be4.classicalb.core.parser.util.ASTBuilder; import de.be4.classicalb.core.parser.util.ASTBuilder;
import de.tla2b.exceptions.UnificationException; import de.tla2b.exceptions.UnificationException;
import de.tla2b.output.TypeVisitorInterface; import de.tla2b.output.TypeVisitorInterface;
import de.tla2bAst.BAstCreator;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
...@@ -46,10 +45,11 @@ public class TupleType extends AbstractHasFollowers { ...@@ -46,10 +45,11 @@ public class TupleType extends AbstractHasFollowers {
TLAType t = types.get(i); TLAType t = types.get(i);
if (oldType == t) { if (oldType == t) {
types.set(i, newType); types.set(i, newType);
if (newType instanceof AbstractHasFollowers) {
((AbstractHasFollowers) newType).addFollower(this);
}
} }
} }
if (oldType instanceof AbstractHasFollowers)
((AbstractHasFollowers) oldType).addFollower(this);
} }
@Override @Override
......
...@@ -4,7 +4,7 @@ import de.be4.classicalb.core.parser.node.PExpression; ...@@ -4,7 +4,7 @@ import de.be4.classicalb.core.parser.node.PExpression;
import de.tla2b.exceptions.UnificationException; import de.tla2b.exceptions.UnificationException;
import de.tla2b.output.TypeVisitorInterface; import de.tla2b.output.TypeVisitorInterface;
public class UntypedType extends AbstractHasFollowers { public final class UntypedType extends AbstractHasFollowers {
public UntypedType() { public UntypedType() {
super(UNTYPED); super(UNTYPED);
...@@ -14,6 +14,12 @@ public class UntypedType extends AbstractHasFollowers { ...@@ -14,6 +14,12 @@ public class UntypedType extends AbstractHasFollowers {
if (!this.compare(o)) { if (!this.compare(o)) {
throw new UnificationException(); throw new UnificationException();
} }
// if the other type is just an empty untyped one we can return this
if (o instanceof UntypedType && !((UntypedType) o).hasFollowers()) {
return this;
}
// u2 contains more or equal type information than untyped (this) // u2 contains more or equal type information than untyped (this)
this.setFollowersTo(o); this.setFollowersTo(o);
//this.deleteFollowers(); //this.deleteFollowers();
......
...@@ -8,7 +8,6 @@ import org.junit.Test; ...@@ -8,7 +8,6 @@ import org.junit.Test;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
public class DefinitionsTest { public class DefinitionsTest {
/* /*
...@@ -48,7 +47,8 @@ public class DefinitionsTest { ...@@ -48,7 +47,8 @@ public class DefinitionsTest {
+ "CONSTANTS k \n" + "CONSTANTS k \n"
+ "foo == k \n" + "foo == k \n"
+ "bar == foo = 1 \n" + "bar == foo = 1 \n"
+ "ASSUME bar \n" + "================================="; + "ASSUME bar \n"
+ "=================================";
TestTypeChecker t = TestUtil.typeCheckString(module); TestTypeChecker t = TestUtil.typeCheckString(module);
assertEquals("INTEGER", t.getDefinitionType("foo")); assertEquals("INTEGER", t.getDefinitionType("foo"));
assertEquals("BOOL", t.getDefinitionType("bar")); assertEquals("BOOL", t.getDefinitionType("bar"));
...@@ -150,30 +150,72 @@ public class DefinitionsTest { ...@@ -150,30 +150,72 @@ public class DefinitionsTest {
+ "CONSTANTS k, k2 \n" + "CONSTANTS k, k2 \n"
+ "foo(a,b) == a = b \n" + "foo(a,b) == a = b \n"
+ "bar == foo(k, k2) /\\ k2 = 1 \n" + "bar == foo(k, k2) /\\ k2 = 1 \n"
+ "ASSUME bar \n" + "ASSUME bar /\\ foo(TRUE,TRUE) \n"
+ "================================="; + "=================================";
TestTypeChecker t = TestUtil.typeCheckString(module); TestTypeChecker t = TestUtil.typeCheckString(module);
assertEquals("BOOL", t.getDefinitionType("foo")); assertEquals("BOOL", t.getDefinitionType("foo"));
assertTrue(t.getDefinitionParamType("foo", "a").startsWith("UNTYPED")); assertTrue(t.getDefinitionParamType("foo", "a").startsWith("UNTYPED"));
assertTrue(t.getDefinitionParamType("foo", "b").startsWith("UNTYPED")); assertEquals(t.getDefinitionParamType("foo", "a"), t.getDefinitionParamType("foo", "b"));
assertEquals("BOOL", t.getDefinitionType("bar")); assertEquals("BOOL", t.getDefinitionType("bar"));
assertEquals("INTEGER", t.getConstantType("k")); assertEquals("INTEGER", t.getConstantType("k"));
assertEquals("INTEGER", t.getConstantType("k2")); assertEquals("INTEGER", t.getConstantType("k2"));
} }
@Test
public void testDefinitionCall7Simpler1() throws TLA2BException {
final String module = "-------------- MODULE Testing ----------------\n"
+ "CONSTANTS k \n"
+ "foo(x) == x = {1} \n"
+ "ASSUME foo(k) \n"
+ "=================================";
TestTypeChecker t = TestUtil.typeCheckString(module);
assertEquals("BOOL", t.getDefinitionType("foo"));
assertEquals("POW(INTEGER)", t.getDefinitionParamType("foo", "x"));
assertEquals("POW(INTEGER)", t.getConstantType("k"));
}
@Test
public void testDefinitionCall7Simpler2() throws TLA2BException {
final String module = "-------------- MODULE Testing ----------------\n"
+ "CONSTANTS k \n"
+ "foo(a,b) == a \\cup b \n"
+ "ASSUME foo(k,{}) = {1} \n"
+ "=================================";
TestTypeChecker t = TestUtil.typeCheckString(module);
assertTrue(t.getDefinitionType("foo").startsWith("POW(UNTYPED"));
assertTrue(t.getDefinitionParamType("foo", "a").startsWith("POW(UNTYPED"));
assertEquals(t.getDefinitionParamType("foo", "a"), t.getDefinitionParamType("foo", "b"));
assertEquals("POW(INTEGER)", t.getConstantType("k"));
}
@Test
public void testDefinitionCall7Simpler3() throws TLA2BException {
final String module = "-------------- MODULE Testing ----------------\n"
+ "CONSTANTS k, k2, k3 \n"
+ "foo(a,b) == a \\cup b \n"
+ "ASSUME k2 = foo(k3, k) /\\ k3 = {1} \n"
+ "=================================";
TestTypeChecker t = TestUtil.typeCheckString(module);
assertTrue(t.getDefinitionType("foo").startsWith("POW(UNTYPED"));
assertTrue(t.getDefinitionParamType("foo", "a").startsWith("POW(UNTYPED"));
assertEquals(t.getDefinitionParamType("foo", "a"), t.getDefinitionParamType("foo", "b"));
assertEquals("POW(INTEGER)", t.getConstantType("k"));
assertEquals("POW(INTEGER)", t.getConstantType("k2"));
assertEquals("POW(INTEGER)", t.getConstantType("k3"));
}
@Test @Test
public void testDefinitionCall7() throws TLA2BException { public void testDefinitionCall7() throws TLA2BException {
final String module = "-------------- MODULE Testing ----------------\n" final String module = "-------------- MODULE Testing ----------------\n"
+ "CONSTANTS k, k2, k3 \n" + "CONSTANTS k, k2, k3 \n"
+ "foo(a,b) == a \\cup b \n" + "foo(a,b) == a \\cup b \n"
+ "bar(x,y) == x = foo(y, k) /\\ y = {1} \n" + "bar(x,y) == x = foo(y, k) /\\ y = {1} \n"
+ "ASSUME bar(k2,k3) \n" + "================================="; + "ASSUME bar(k2,k3) \n"
+ "=================================";
TestTypeChecker t = TestUtil.typeCheckString(module); TestTypeChecker t = TestUtil.typeCheckString(module);
assertTrue(t.getDefinitionType("foo").startsWith("POW(UNTYPED")); assertTrue(t.getDefinitionType("foo").startsWith("POW(UNTYPED"));
assertTrue(t.getDefinitionParamType("foo", "a").startsWith( assertTrue(t.getDefinitionParamType("foo", "a").startsWith("POW(UNTYPED"));
"POW(UNTYPED")); assertEquals(t.getDefinitionParamType("foo", "a"), t.getDefinitionParamType("foo", "b"));
assertTrue(t.getDefinitionParamType("foo", "b").startsWith(
"POW(UNTYPED"));
assertEquals("BOOL", t.getDefinitionType("bar")); assertEquals("BOOL", t.getDefinitionType("bar"));
assertEquals("POW(INTEGER)", t.getDefinitionParamType("bar", "x")); assertEquals("POW(INTEGER)", t.getDefinitionParamType("bar", "x"));
assertEquals("POW(INTEGER)", t.getDefinitionParamType("bar", "y")); assertEquals("POW(INTEGER)", t.getDefinitionParamType("bar", "y"));
...@@ -187,7 +229,8 @@ public class DefinitionsTest { ...@@ -187,7 +229,8 @@ public class DefinitionsTest {
+ "foo(a) == k = a \n" + "foo(a) == k = a \n"
+ "bar == foo(k2) \n" + "bar == foo(k2) \n"
+ "baz == k2 = 1 \n" + "baz == k2 = 1 \n"
+ "ASSUME bar /\\ baz \n" + "================================="; + "ASSUME bar /\\ baz \n"
+ "=================================";
TestTypeChecker t = TestUtil.typeCheckString(module); TestTypeChecker t = TestUtil.typeCheckString(module);
assertEquals("INTEGER", t.getConstantType("k")); assertEquals("INTEGER", t.getConstantType("k"));
assertEquals("INTEGER", t.getConstantType("k2")); assertEquals("INTEGER", t.getConstantType("k2"));
......
...@@ -119,14 +119,15 @@ public class TupleTest { ...@@ -119,14 +119,15 @@ public class TupleTest {
assertEquals("POW(INTEGER*BOOL)", t.getConstantType("k")); assertEquals("POW(INTEGER*BOOL)", t.getConstantType("k"));
} }
@Test(expected = TypeErrorException.class) @Test
public void testTuple2Elements() throws TLA2BException { public void testTuple2Elements() throws TLA2BException {
final String module = "-------------- MODULE Testing ----------------\n" final String module = "-------------- MODULE Testing ----------------\n"
+ "CONSTANTS k, k2, k3 \n" + "CONSTANTS k, k2, k3 \n"
+ "ASSUME k = <<k2, k3>> /\\ k3 = TRUE \n" + "ASSUME k = <<k2, k3>> /\\ k3 = TRUE \n"
+ "================================="; + "=================================";
TestUtil.typeCheckString(module); TestTypeChecker t = TestUtil.typeCheckString(module);
assertEquals("POW(INTEGER*BOOL)", t.getConstantType("k"));
} }
@Test @Test
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment