From e78f0a03216859c6714a731b4901601880a5a9e7 Mon Sep 17 00:00:00 2001
From: Fabian Vu <Fabian.Vu@hhu.de>
Date: Sat, 23 Nov 2024 17:49:26 +0100
Subject: [PATCH] Fix identity for relation in C++

---
 btypes_big_integer/src/main/cpp/BRelation.hpp   | 16 ++++++++++++++++
 btypes_primitives/src/main/cpp/BRelation.hpp    | 17 +++++++++++++++++
 .../generators/ExpressionGenerator.java         |  5 +++++
 .../de/hhu/stups/codegenerator/CppTemplate.stg  |  6 +++++-
 4 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/btypes_big_integer/src/main/cpp/BRelation.hpp b/btypes_big_integer/src/main/cpp/BRelation.hpp
index 0219e90ef..9fff4ac93 100644
--- a/btypes_big_integer/src/main/cpp/BRelation.hpp
+++ b/btypes_big_integer/src/main/cpp/BRelation.hpp
@@ -967,6 +967,22 @@ class BRelation : public BObject {
             return BRelation<T,T>(resultMap);
         }
 
+        static BRelation<BTuple<S,T>,BTuple<S,T>> identity2(const BRelation<S,T>& relation) {
+            immer::map<BTuple<S,T>,immer::set<BTuple<S,T>, typename BSet<BTuple<S,T>>::Hash, typename BSet<BTuple<S,T>>::HashEqual>,
+                                                               typename BSet<BTuple<S,T>>::Hash,
+                                                               typename BSet<BTuple<S,T>>::HashEqual> resultMap;
+            for(const std::pair<S,immer::set<T, typename BSet<T>::Hash, typename BSet<T>::HashEqual>>& pair : relation.map) {
+                T domainElement = pair.first;
+                immer::set<T, typename BSet<T>::Hash, typename BSet<T>::HashEqual> range = pair.second;
+                for(const T& rangeElement : range) {
+                    immer::set<BTuple<S,T>, typename BSet<BTuple<S,T>>::Hash, typename BSet<BTuple<S,T>>::HashEqual> range;
+                    range = range.insert(BTuple<S,T>(domainElement, rangeElement));
+                    resultMap = resultMap.set(BTuple<S,T>(domainElement, rangeElement), range);
+                }
+            }
+            return BRelation<BTuple<S,T>,BTuple<S,T>>(resultMap);
+        }
+
     	BRelation<S,S> iterate(const BInteger& n) const {
     		BRelation<S,S> thisRelation = (BRelation<S,S>) *this;
             BRelation<S,S> result = BRelation<S,S>::identity(this->domain()._union(thisRelation.range()));
diff --git a/btypes_primitives/src/main/cpp/BRelation.hpp b/btypes_primitives/src/main/cpp/BRelation.hpp
index 12097c5d9..f6f4eb08f 100644
--- a/btypes_primitives/src/main/cpp/BRelation.hpp
+++ b/btypes_primitives/src/main/cpp/BRelation.hpp
@@ -967,6 +967,23 @@ class BRelation : public BObject {
             return BRelation<T,T>(resultMap);
         }
 
+
+        static BRelation<BTuple<S,T>,BTuple<S,T>> identity2(const BRelation<S,T>& relation) {
+            immer::map<BTuple<S,T>,immer::set<BTuple<S,T>, typename BSet<BTuple<S,T>>::Hash, typename BSet<BTuple<S,T>>::HashEqual>,
+                                                               typename BSet<BTuple<S,T>>::Hash,
+                                                               typename BSet<BTuple<S,T>>::HashEqual> resultMap;
+            for(const std::pair<S,immer::set<T, typename BSet<T>::Hash, typename BSet<T>::HashEqual>>& pair : relation.map) {
+                T domainElement = pair.first;
+                immer::set<T, typename BSet<T>::Hash, typename BSet<T>::HashEqual> range = pair.second;
+                for(const T& rangeElement : range) {
+                    immer::set<BTuple<S,T>, typename BSet<BTuple<S,T>>::Hash, typename BSet<BTuple<S,T>>::HashEqual> range;
+                    range = range.insert(BTuple<S,T>(domainElement, rangeElement));
+                    resultMap = resultMap.set(BTuple<S,T>(domainElement, rangeElement), range);
+                }
+            }
+            return BRelation<BTuple<S,T>,BTuple<S,T>>(resultMap);
+        }
+
     	BRelation<S,S> iterate(const BInteger& n) const {
     		BRelation<S,S> thisRelation = (BRelation<S,S>) *this;
             BRelation<S,S> result = BRelation<S,S>::identity(this->domain()._union(thisRelation.range()));
diff --git a/src/main/java/de/hhu/stups/codegenerator/generators/ExpressionGenerator.java b/src/main/java/de/hhu/stups/codegenerator/generators/ExpressionGenerator.java
index 03bb7c3a4..b03809d24 100644
--- a/src/main/java/de/hhu/stups/codegenerator/generators/ExpressionGenerator.java
+++ b/src/main/java/de/hhu/stups/codegenerator/generators/ExpressionGenerator.java
@@ -794,8 +794,13 @@ public class ExpressionGenerator {
     */
     private String generateIdentity(List<String> expressionList, BType type) {
         ST identity = currentGroup.getInstanceOf("identity");
+        if(type instanceof CoupleType) {
+            TemplateHandler.add(identity, "leftType", typeGenerator.generate(((CoupleType) type).getLeft()));
+            TemplateHandler.add(identity, "rightType", typeGenerator.generate(((CoupleType) type).getRight()));
+        }
         TemplateHandler.add(identity, "type", typeGenerator.generate(type));
         TemplateHandler.add(identity, "arg", expressionList.get(0));
+        TemplateHandler.add(identity, "relationalArg", type instanceof CoupleType);
         return identity.render();
     }
 
diff --git a/src/main/resources/de/hhu/stups/codegenerator/CppTemplate.stg b/src/main/resources/de/hhu/stups/codegenerator/CppTemplate.stg
index 2a6377c51..5e8a39fa9 100644
--- a/src/main/resources/de/hhu/stups/codegenerator/CppTemplate.stg
+++ b/src/main/resources/de/hhu/stups/codegenerator/CppTemplate.stg
@@ -779,8 +779,12 @@ projection_tuple(arg, isProjection1) ::= <<
 (<arg>.<if(isProjection1)>projection1<else>projection2<endif>())
 >>
 
-identity(type, arg) ::= <<
+identity(leftType, rightType, type, arg, relationalArg) ::= <<
+<if(relationalArg)>
+(BRelation\<<leftType>, <rightType> >::identity2(<arg>))
+<else>
 (BRelation\<<type>, <type> >::identity(<arg>))
+<endif>
 >>
 
 cartesian_product(leftType, rightType, arg1, arg2) ::= <<
-- 
GitLab