]> git.mdlowis.com Git - proto/obnc.git/commitdiff
fixup ssa variable versioning logic
authorMichael D. Lowis <mike.lowis@gentex.com>
Wed, 26 Jan 2022 21:46:30 +0000 (16:46 -0500)
committerMichael D. Lowis <mike.lowis@gentex.com>
Wed, 26 Jan 2022 21:46:30 +0000 (16:46 -0500)
cerise/inc/cerise.h
cerise/src/grammar.c
cerise/src/ssa.c
cerise/tests/Module.m

index ad687020ba9535dd9329f4557ebd7e1911bd1e72..c66d6209762b7af8647c4df2b25957c3e8e837d1 100644 (file)
@@ -138,6 +138,8 @@ typedef struct SsaPhiVar {
 typedef struct SsaPhi {
     struct SsaPhi* next;
     size_t symid;
+    size_t backup_ver;
+    size_t latest_ver;
     SsaPhiVar* vars;
 } SsaPhi;
 
@@ -168,7 +170,7 @@ typedef struct Symbol {
     struct Symbol* desc;
     SsaNode* value;
     long nargs;
-    long version;
+    size_t version;
     size_t module;
     int export : 1;
     int global : 1;
@@ -256,6 +258,9 @@ bool ssa_asbool(SsaNode* node);
 long long ssa_asint(SsaNode* node);
 double ssa_asreal(SsaNode* node);
 
+void ssa_join(Parser* p);
+void ssa_reset_vars(Parser* p);
+
 SsaNode* ssa_ident(Parser* p, long long index);
 SsaNode* ssa_bool(Parser* p, bool val);
 SsaNode* ssa_int(Parser* p, long long val);
index c46d0ff269ec82ca1ac37c2a7b989904102d42d0..a7c884f67609c4ae5ca967642ef75c195d81c279 100644 (file)
@@ -417,8 +417,8 @@ static SsaBlock* statement_seq(Parser* p)
 
             SsaBlock* block = p->curr_block;
             block->links[0] = statement_seq(p);
-            /* reset vars to backup values */
             block->links[0]->links[0] = p->curr_join;
+            ssa_reset_vars(p);
 
             if (accept(p, ELSE))
             {
@@ -432,10 +432,8 @@ static SsaBlock* statement_seq(Parser* p)
 
             if_node->left.block = block->links[0];
             if_node->right.block = block->links[1];
+            ssa_join(p);
 
-            /* pop the join node */
-            p->curr_block = p->curr_join;
-            p->curr_join = p->curr_join->next;
             expect(p, END);
         }
         else if (matches(p, END) || matches(p,RETURN))
@@ -649,7 +647,6 @@ static void module(Parser* p)
 
     if (accept(p, BEGIN))
     {
-        //proc_start();
         SsaBlock* block = ssa_block(p);
         p->curr_join = ssa_block(p);
         block->links[1] = p->curr_join;
@@ -660,14 +657,12 @@ static void module(Parser* p)
             block->links[0] = seqblock;
         }
         expect(p, END);
+        ssa_join(p);
 
-//        ssa_print_block(p, block);
+        /* debug dump the result */
         ssa_print_graph(p, block);
-
         extern void ssa_print_asm(Parser* p, SsaBlock* block);
         ssa_print_asm(p, block);
-
-        //proc_end();
     }
 
     if (!matches(p, END_FILE))
index 50bf064315d526b5dead740ed444f37cf60b2f46..8c83492cfbba46d81a70062cae8ae1262f129df1 100644 (file)
@@ -25,6 +25,108 @@ double ssa_asreal(SsaNode* node)
     return (node->left.val.f);
 }
 
+/*
+    * store backup values for var versions, restore at the end of block
+    * popping a join node processes phis as stores propagating phis to outer join nodes
+    * multiple assignments in blocks update a single var entry.
+*/
+
+static size_t phi_add(Parser* p, SsaVar var)
+{
+    Symbol* sym = symbol_getbyid(p, var.symid);
+    printf("phi_add()\n");
+
+    printf("    %s.%lu\n", sym->name, sym->version);
+
+    /* first, append the phi function to the list */
+    SsaPhi** phis = &(p->curr_join->phis);
+    for (; *phis; phis = &((*phis)->next))
+    {
+        if ((*phis)->symid == var.symid)
+        {
+            break;
+        }
+    }
+    if (!*phis)
+    {
+        puts("new phi()");
+        *phis = calloc(1, sizeof(SsaPhi));
+        (*phis)->symid = var.symid;
+        (*phis)->backup_ver = sym->version;
+    }
+    else if ((*phis)->latest_ver > sym->version)
+    {
+        sym->version = (*phis)->latest_ver;
+        printf("    setting to latest: %s.%lu\n", sym->name, sym->version);
+    }
+
+    /* now add or update the variable entry */
+    SsaPhiVar** vars = &((*phis)->vars);
+    for (; *vars; vars = &((*vars)->next))
+    {
+        if ((*vars)->block == p->curr_block->id)
+        {
+            break;
+        }
+    }
+    if (!*vars)
+    {
+        *vars = calloc(1, sizeof(SsaPhiVar));
+        (*vars)->block = p->curr_block->id;
+        printf("    new var %s from block %lu\n", sym->name, p->curr_block->id);
+    }
+
+    sym->version++;
+    (*vars)->version = sym->version;
+    (*phis)->latest_ver = (*vars)->version;
+
+    printf("    %s.%lu\n", sym->name, sym->version);
+
+    return sym->version;
+}
+
+
+void ssa_reset_vars(Parser* p)
+{
+    puts("ssa_reset_vars");
+    for (SsaPhi* phi = p->curr_join->phis; phi; phi = phi->next)
+    {
+        Symbol* s = symbol_getbyid(p, phi->symid);
+        s->version = phi->backup_ver;
+    }
+}
+
+void ssa_join(Parser* p)
+{
+    puts("ssa_join");
+    ssa_reset_vars(p);
+
+    /* pop the join node off the list  since we're done with it */
+    SsaBlock* block = p->curr_join;
+    p->curr_block = p->curr_join;
+    p->curr_join = p->curr_join->next;
+
+    /* update variable versions based on phi functions */
+    for (SsaPhi* phi = block->phis; phi; phi = phi->next)
+    {
+        Symbol* s = symbol_getbyid(p, phi->symid);
+        s->version = phi->latest_ver;
+        phi->latest_ver++;
+        printf("block %lu: %s.%lu\n", block->id, s->name, s->version);
+
+        if (p->curr_join)
+        {
+            s->version = phi_add(p, (SsaVar){
+                .symid = phi->symid,
+                .symver = s->version
+            });
+        }
+
+        printf("block %lu: %s.%lu\n", block->id, s->name, s->version);
+    }
+}
+
+
 static SsaNode* ssa_node(int code, int mode)
 {
     SsaNode* node = calloc(1, sizeof(SsaNode));
@@ -92,59 +194,14 @@ SsaNode* ssa_op(Parser* p, int op, SsaNode* left, SsaNode* right)
         : unop(p, op, left);
 }
 
-static void phi_add(Parser* p, SsaVar var)
-{
-    /* for all enclosing join nodes */
-    for (SsaBlock* join = p->curr_join; join; join = join->next)
-    {
-        /* first, append the phi function to the list */
-        SsaPhi** phis = &(join->phis);
-        for (; *phis; phis = &((*phis)->next))
-        {
-            if ((*phis)->symid == var.symid)
-            {
-                break;
-            }
-        }
-        if (!*phis)
-        {
-            *phis = calloc(1, sizeof(SsaPhi));
-            (*phis)->symid = var.symid;
-        }
-
-        /* now add or update the variable entry */
-        SsaPhiVar** vars = &((*phis)->vars);
-        for (; *vars; vars = &((*vars)->next))
-        {
-            if ((*vars)->block == p->curr_block->id)
-            {
-                (*vars)->version = var.symver;
-                break;
-            }
-        }
-        if (!*vars)
-        {
-            *vars = calloc(1, sizeof(SsaPhiVar));
-            (*vars)->block = p->curr_block->id;
-            (*vars)->version = var.symver;
-        }
-    }
-}
-
 SsaNode* ssa_store(Parser* p, SsaNode* dest, SsaNode* value)
 {
     load(p, value);
-//    load(p, dest);
     SsaNode* node = ssa_node('=', MODE_VAR);
     node->type = dest->type;
     node->dest = dest->left.var;
     node->left.var = value->dest;
-
-    Symbol* sym = symbol_getbyid(p, node->dest.symid);
-    sym->version++;
-    node->dest.symver = sym->version;
-    phi_add(p, node->dest);
-
+    node->dest.symver = phi_add(p, node->dest);
     ssa_block_add(p->curr_block, node);
     node->loaded = 1;
     return node;
@@ -530,7 +587,6 @@ void ssa_print_block(Parser* p, Bitset* set, SsaBlock* block)
     for (SsaPhi* phi = block->phis; phi; phi = phi->next)
     {
         Symbol* s = symbol_getbyid(p, phi->symid);
-        s->version++;
         printf("    %s.%lu = phi(", s->name, s->version);
         for (SsaPhiVar* var = phi->vars; var; var = var->next)
         {
@@ -578,57 +634,55 @@ static void topsort(Bitset* set, SsaBlock** sorted, SsaBlock* block)
 
 void ssa_print_graph(Parser* p, SsaBlock* block)
 {
-    /* perform a topological sort of the nodes */
-    SsaBlock* sorted = NULL;
-    Bitset* set = bitset_new(p->blockid);
-    topsort(set, &sorted, block);
-
-
-    /* now let's print the plantuml representation */
-    printf("@startuml\n");
-    printf("[*] --> block%lu\n", block->id);
-    for (SsaBlock* curr = sorted; curr; curr = curr->next)
-    {
-
-        /* print the phis */
-        for (SsaPhi* phi = curr->phis; phi; phi = phi->next)
-        {
-            Symbol* s = symbol_getbyid(p, phi->symid);
-            s->version++;
-            printf("block%lu: %s.%lu = phi(", curr->id, s->name, s->version);
-            for (SsaPhiVar* var = phi->vars; var; var = var->next)
-            {
-                printf("%s.%lu", s->name, var->version);
-                if (var->next)
-                {
-                    printf(", ");
-                }
-            }
-            puts(")");
-        }
-
-        /* print the instructions */
-        for (SsaNode* node = curr->head; node; node = node->next)
-        {
-            printf("block%lu : ", curr->id);
-            print_dest(p, node);
-            ssa_print(p, node);
-            puts("");
-        }
-
-        /* print the links */
-        if (curr->links[1])
-        {
-            printf("block%lu --> block%lu\n", curr->id, curr->links[1]->id);
-        }
-        if (curr->links[0])
-        {
-            printf("block%lu --> block%lu\n", curr->id, curr->links[0]->id);
-        }
-
-        puts("");
-    }
-    printf("@enduml\n\n");
+//    /* perform a topological sort of the nodes */
+//    SsaBlock* sorted = NULL;
+//    Bitset* set = bitset_new(p->blockid);
+//    topsort(set, &sorted, block);
+//
+//
+//    /* now let's print the plantuml representation */
+//    printf("@startuml\n");
+//    printf("[*] --> block%lu\n", block->id);
+//    for (SsaBlock* curr = sorted; curr; curr = curr->next)
+//    {
+//        /* print the phis */
+//        for (SsaPhi* phi = curr->phis; phi; phi = phi->next)
+//        {
+//            Symbol* s = symbol_getbyid(p, phi->symid);
+//            printf("block%lu: %s.%lu = phi(", curr->id, s->name, s->version);
+//            for (SsaPhiVar* var = phi->vars; var; var = var->next)
+//            {
+//                printf("%s.%lu", s->name, var->version);
+//                if (var->next)
+//                {
+//                    printf(", ");
+//                }
+//            }
+//            puts(")");
+//        }
+//
+//        /* print the instructions */
+//        for (SsaNode* node = curr->head; node; node = node->next)
+//        {
+//            printf("block%lu : ", curr->id);
+//            print_dest(p, node);
+//            ssa_print(p, node);
+//            puts("");
+//        }
+//
+//        /* print the links */
+//        if (curr->links[1])
+//        {
+//            printf("block%lu --> block%lu\n", curr->id, curr->links[1]->id);
+//        }
+//        if (curr->links[0])
+//        {
+//            printf("block%lu --> block%lu\n", curr->id, curr->links[0]->id);
+//        }
+//
+//        puts("");
+//    }
+//    printf("@enduml\n\n");
 }
 
 
@@ -649,8 +703,7 @@ void ssa_print_asm(Parser* p, SsaBlock* block)
         for (SsaPhi* phi = curr->phis; phi; phi = phi->next)
         {
             Symbol* s = symbol_getbyid(p, phi->symid);
-            s->version++;
-            printf("    %s.%lu = phi(", s->name, s->version);
+            printf("    %s.%lu = phi(", s->name, phi->latest_ver);
             for (SsaPhiVar* var = phi->vars; var; var = var->next)
             {
                 printf("%s.%lu", s->name, var->version);
@@ -665,7 +718,6 @@ void ssa_print_asm(Parser* p, SsaBlock* block)
         /* print the instructions */
         for (SsaNode* node = curr->head; node; node = node->next)
         {
-//            printf("block%lu : ", curr->id);
             printf("    ");
             print_dest(p, node);
             ssa_print(p, node);
index ed147c01856b9dfee950f284a3944bcaa6f04ff4..45ffc3bbec03923be478af8819eef0427e191636 100644 (file)
@@ -70,21 +70,32 @@ begin
 end
 
 begin
-  b = 42;
-  b = -b;
-  c = b + 1;
-
+  b = 1;
   if c == b then
-    c = 42;
-  else
-    c = 24;
-  end
+    b = b - 1;
+    b = b - 1;
+ else
+    b = b + 1;
+    b = b + 1;
+ end
+ b = 4;
+ b = 5;
 
-  if c == b then
-    b = 42;
-  else
-    b = 24;
-  end
+#  b = 42;
+#  b = -b;
+#  c = b + 1;
+#
+#  if c == b then
+#    c = 42;
+#  else
+#    c = 24;
+#  end
+#
+#  if c == b then
+#    b = 42;
+#  else
+#    b = 24;
+#  end
 
 
 #    h[1].i = 42;