]> git.mdlowis.com Git - proto/cerise-os.git/commitdiff
non-crashing, non-exiting, multi-threaded job queue
authorMichael D. Lowis <mike.lowis@gentex.com>
Tue, 6 Sep 2022 20:38:33 +0000 (16:38 -0400)
committerMichael D. Lowis <mike.lowis@gentex.com>
Tue, 6 Sep 2022 20:38:33 +0000 (16:38 -0400)
experiments/ctxswitch.c

index 06bf48c665e7f6527a16969bd0caf9cd3c304c79..6a4ebd96ff5057b0ad78a5eda380e4e562e7aef1 100644 (file)
@@ -1,5 +1,20 @@
 #include <stdio.h>
 #include <stdlib.h>
+#include <stdbool.h>
+#include <unistd.h>
+#include <pthread.h>
+#include <assert.h>
+
+#define TRY_ELSE(cond) if ((cond) < 0)
+
+#define TRY(cond) ((cond) < 0 ? die("try:") : 0)
+
+static int die(char* str)
+{
+    perror(str);
+    abort();
+    return 0;
+}
 
 /* TODO:
 
@@ -15,12 +30,25 @@ typedef struct Task_T {
     int id;
 } Task_T;
 
-static struct {
+typedef struct {
+    pthread_t thread;
+    pthread_cond_t cond;
+    pthread_mutex_t mutex;
+    Task_T* idle;
     Task_T* curr;
     Task_T* head;
     Task_T* tail;
     Task_T* dead;
-} Queue = { 0 };
+} Thread_T;
+
+static int ThreadCount;
+static __thread int ThreadID;
+static Thread_T* Threads;
+static struct {
+    pthread_cond_t cond;
+    pthread_cond_t mutex;
+    long int value;
+} TaskCount;
 
 extern void Task_Switch(Task_T* prev, Task_T* next);
 asm (
@@ -45,74 +73,88 @@ asm (
 "    ret\n"
  );
 
-void Enqueue(Task_T* task)
+static void Enqueue(int id, Task_T* task)
 {
-    /* enqueue the currently running task */
-    if (task && task != Queue.dead)
+    TRY( pthread_mutex_lock(&Threads[id].mutex) );
+    printf("Scheduling task %d on thread %d\n", task->id, id);
+    if (task)
     {
-        if (Queue.tail)
+        if (Threads[id].tail)
         {
-            Queue.tail->next = task;
+            Threads[id].tail->next = task;
             task->next = NULL;
         }
-        Queue.tail = task;
-        if (!Queue.head)
+        Threads[id].tail = task;
+        if (!Threads[id].head)
         {
-            Queue.head = Queue.tail;
+            Threads[id].head = Threads[id].tail;
         }
+        printf("signaling thread %d to run\n", id);
+        TRY( pthread_cond_signal(&Threads[id].cond) );
     }
+
+    TRY( pthread_mutex_unlock(&Threads[id].mutex) );
 }
 
 void Task_Yield(void)
 {
-    if (Queue.head)
+    if (Threads[ThreadID].head)
     {
         /* unload the current task */
-        Task_T* curr = Queue.curr;
-        Enqueue(curr);
-        Queue.curr = NULL;
-
-        /* now pick the next task and start it */
-        Queue.curr = Queue.head;
-        Queue.head = Queue.head->next;
-        Queue.curr->next = NULL;
-        if (!Queue.head)
+        printf("thread %d unloading current task\n", ThreadID);
+        Task_T* curr = Threads[ThreadID].curr;
+        if (curr != Threads[ThreadID].idle)
         {
-            Queue.tail = NULL;
+            Enqueue(ThreadID, curr);
         }
+        Threads[ThreadID].curr = NULL;
 
-        /* run the selected task */
-        Task_Switch(curr, Queue.curr);
-
-        /* clean up the dead task if we have one */
-        if (Queue.dead)
+        /* now pick the next task and start it */
+        printf("thread %d picking next task\n", ThreadID);
+        Threads[ThreadID].curr = Threads[ThreadID].head;
+        Threads[ThreadID].head = Threads[ThreadID].head->next;
+        Threads[ThreadID].curr->next = NULL;
+        if (!Threads[ThreadID].head)
         {
-            printf("destroying task %d\n", Queue.dead->id);
-            free(Queue.dead->stack_base);
-            free(Queue.dead);
-            Queue.dead = NULL;
+            Threads[ThreadID].tail = NULL;
         }
+
+        printf("thread %d switching to task\n", ThreadID);
+        Task_Switch(curr, Threads[ThreadID].curr);
     }
 }
 
 void Task_Exit(void)
 {
-    printf("exiting task %d\n", Queue.curr->id);
-    Queue.dead = Queue.curr;
-    if (Queue.head)
+    printf("thread %d exiting task %d\n", ThreadID, Threads[ThreadID].curr->id);
+
+    TRY( pthread_mutex_lock(&TaskCount.mutex) );
+    TaskCount.value--;
+    printf("TASK COUNT %d\n", TaskCount.value);
+    TRY( pthread_cond_signal(&TaskCount.cond) );
+    TRY( pthread_mutex_unlock(&TaskCount.mutex) );
+
+
+    while(true)
     {
         Task_Yield();
     }
-    else
-    {
-        exit(0);
-    }
+
+//    Queue.dead = Queue.curr;
+//    if (Queue.head)
+//    {
+//        Task_Yield();
+//    }
+//    else
+//    {
+//        exit(0);
+//    }
 }
 
-void Task_Create(void (*task_fn)(void*), void *arg, int stacksize)
+static Task_T* CreateTask(void (*task_fn)(void*), void *arg, int stacksize)
 {
     static int i = 0;
-    if (stacksize == 0) { stacksize = 1024*1024; }
+    if (stacksize == 0) { stacksize = 32768; }
     Task_T* task = calloc(1, sizeof(Task_T));
     task->id = i++;
     task->stack_base = calloc(stacksize/sizeof(long), sizeof(long));
@@ -124,24 +166,72 @@ void Task_Create(void (*task_fn)(void*), void *arg, int stacksize)
     {
         *(--task->stack_top) = 0xdeadbeef; // initial values for saved registers
     }
+    printf("created task %d\n", task->id);
+    return task;
+}
 
-    /* enqueue the task */
-    if (!Queue.curr)
-    {
-        Queue.curr = task;
-    }
-    else
+static void* ThreadMain(void* arg)
+{
+    ThreadID = (long int)arg;
+    printf("Thread %d started\n", ThreadID);
+    Threads[ThreadID].idle = CreateTask(0, 0, 0);
+    Threads[ThreadID].curr = Threads[ThreadID].idle;
+
+    while (true)
     {
-        Enqueue(task);
+        /* let's wait for there to be work to do */
+        TRY_ELSE( pthread_cond_wait(&Threads[ThreadID].cond, &Threads[ThreadID].mutex) )
+        {
+            /* ignore error */
+        }
+
+        printf("thread %d woke up\n", ThreadID);
+        if (Threads[ThreadID].head)
+        {
+            printf("thread %d found something in the queue\n", ThreadID);
+            Task_Yield();
+        }
+        else
+        {
+            /* steal work here if possible */
+            printf("thread %d attempting to steal work\n", ThreadID);
+        }
     }
-    printf("created task %d\n", task->id);
-    Task_Yield();
+}
+
+void Task_Spawn(void (*task_fn)(void*), void *arg, int stacksize)
+{
+    TRY( pthread_mutex_lock(&TaskCount.mutex) );
+    Task_T* task = CreateTask(task_fn, arg, stacksize);
+    TaskCount.value++;
+    printf("TASK COUNT %d\n", TaskCount.value);
+    TRY( pthread_cond_signal(&TaskCount.cond) );
+    TRY( pthread_mutex_unlock(&TaskCount.mutex) );
+
+    Enqueue(rand() % ThreadCount, task);
 }
 
 void Task(void)
 {
-    /* create a task object for main */
-    Task_Create(0, 0, 0);
+    /* allocate the thread pool */
+    ThreadCount = sysconf(_SC_NPROCESSORS_ONLN);
+    assert(ThreadCount > 0);
+    printf("ThreadCount: %d\n", ThreadCount);
+    Threads = calloc(ThreadCount, sizeof(Thread_T));
+
+    TRY( pthread_cond_init(&TaskCount.cond, NULL) );
+    TRY( pthread_mutex_init(&TaskCount.mutex, NULL) );
+
+    /* allocate the worker threads */
+    for (long int i = 0; i < ThreadCount; i++)
+    {
+        TRY( pthread_create(&Threads[i].thread, NULL, ThreadMain, (void*)i) );
+        TRY( pthread_cond_init(&Threads[i].cond, NULL) );
+        TRY( pthread_mutex_init(&Threads[i].mutex, NULL) );
+    }
+
+    /* init PRNG for initial thread allocation */
+    srand(time(NULL));
 }
 
 /***************************************
@@ -153,7 +243,7 @@ void task_sub(void *arg)
     for (int i = 0; i < 5; i++)
     {
         Task_Yield();
-        printf("    Inside task %d\n", Queue.curr->id);
+        printf("    Inside task %d\n", Threads[ThreadID].curr->id);
     }
 }
 
@@ -162,23 +252,41 @@ void task_main(void *arg)
     for (int i = 0; i < 3; i++)
     {
         printf("spawning task %d\n", i);
-        Task_Create(task_sub, 0, 0);
+        Task_Spawn(task_sub, 0, 0);
     }
+    printf("hello from task_main\n");
 }
 
 int main(int argc, char** argv) {
-    if (argc != 2) { return 1; }
     Task();
+    Task_Spawn(task_main, 0, 0);
 
-    int count = atoi(argv[1]);
-    for (int i = 0; i < count; i++)
+    while (true)
     {
-        printf("spawning task %d\n", i+1);
-        int* id = calloc(1, sizeof(int));
-        *id = i;
-        Task_Create(task_sub, id, 0);
+        TRY_ELSE( pthread_cond_wait(&TaskCount.cond, &TaskCount.mutex) )
+        {
+            /* ignore error */
+        }
+        TRY( pthread_mutex_lock(&TaskCount.mutex) );
+        if(TaskCount.value == 0)
+        {
+            break;
+        }
+        TRY( pthread_mutex_unlock(&TaskCount.mutex) );
     }
-    Task_Exit();
+
+//    int count = atoi(argv[1]);
+//    int count = 5;
+//    for (int i = 0; i < count; i++)
+//    {
+//        printf("spawning task %d on thread %d\n", i+1, ThreadID);
+//        int* id = calloc(1, sizeof(int));
+//        *id = i;
+//        Task_Create(task_sub, id, 0);
+//    }
+//    Task_Exit();
+
     printf("This is unreachable!\n");
+
     return 0;
 }
\ No newline at end of file