1 module libd.async.coroutine;
2 
// NOTE: This module doesn't go through the normal allocators for stack allocation (coroutine stacks come straight from PageAllocator), since stacks are a bit of a special case in terms of memory allocation and management.
4 import libd.datastructures : LinkedList, SumType;
5 import libd.memory : g_alloc, PageAllocator, PageAllocation;
6 import libd.util.maths : alignTo;
7 
// Default minimum size, in bytes, requested for a standalone coroutine stack: ten 0x1000-byte pages.
// Keep this a multiple of 0x1000 (the page size). Stack tracing on Windows depends on it:
// after a debugging session in x64dbg, the author found that an internal function deep inside
// dbghelp.dll has a hard expectation of this alignment.
enum DEFAULT_COROUTINE_STACK_SIZE = 10 * 0x1000;
12 
// Platform detection. Only x86_64 Windows and x86_64 Linux are supported; every other
// target fails to compile. `Win64` / `SysV` select the calling-convention-specific
// register layout (REGISTERS) used by the context switch.
version(X86_64)
{
    version(Windows)
    {
        // Microsoft x64 calling convention.
        private enum Win64 = true;
        private enum SysV  = false;
    }
    else version(linux)
    {
        // System V AMD64 calling convention.
        private enum Win64 = false;
        private enum SysV  = true;
    }
    else static assert(false, "Unsupported platform.");
}
else static assert(false, "libd only targets x86_64");
28 
// CONSTANTS

// Indices into `Coroutine.registers`: the slots saved/restored when switching coroutines.
// `rsp` is the stack pointer to resume with. `ret` is the address execution continues at
// (coroutineStart stores `&routineMain` there for a fresh coroutine). The remaining entries
// are the general-purpose callee-saved registers of the respective ABI.
// NOTE(review): the member order presumably has to match the offsets hard-coded in the
// NASM implementation of `coroutineSwap` - do not reorder without updating it.
static if(Win64)
{
    // NOTE(review): Win64 also treats XMM6-XMM15 as non-volatile; they are not listed
    // here - confirm whether coroutineSwap preserves them some other way.
    private enum REGISTERS : size_t
    {
        rsp,
        ret,
        r12,
        r13,
        r14,
        r15,
        rdi,
        rsi,
        rbx,
        rbp,
        // coroutineStart fills these with 0, the stack's alignedBot and alignedTop.
        // Presumably swapped into gs:[0] / gs:[8] / gs:[16] of the TIB (exception list,
        // stack base, stack limit) - TODO confirm against the NASM code.
        gs0,
        gs8,
        gs16,

        COUNT // Number of slots; sizes Coroutine.registers.
    }
}
else static if(SysV)
{
    private enum REGISTERS : size_t
    {
        rsp,
        ret,
        rbx,
        rbp,
        r12,
        r13,
        r14,
        r15,

        COUNT // Number of slots; sizes Coroutine.registers.
    }
}
67 
68 package @nogc nothrow:
69 
/// Entry point of a coroutine: a plain function taking no arguments that must be `nothrow` and `@nogc`.
alias CoroutineFunc = void function() nothrow @nogc;
71 
// TODO: Either fix TLS, or find a new mechanism to handle this.
// NOTE: both globals are __gshared (process-wide), so despite their names they are not
// per-thread; only one thread can safely drive coroutines at a time.

// The coroutine currently executing. Null until the first coroutineStart/coroutineResume;
// afterwards it can point at g_currentThreadMainRoutine once control returns to the main code.
__gshared Coroutine* g_currentThreadRoutine;
// Stand-in Coroutine for the non-coroutine code that starts/resumes coroutines. It is used
// as the `from` side of a switch whenever g_currentThreadRoutine is null.
__gshared Coroutine  g_currentThreadMainRoutine;
75 
/// Lifecycle state of a Coroutine.
enum CoroutineState : ubyte
{
    start,      // Created or reset (this is `.init`); coroutineStart may be called.
    running,    // Currently executing.
    suspended,  // Paused: either it yielded, or it is waiting on a coroutine it started/resumed.
    end         // Finished via coroutineExit (explicitly or by returning from its entry point); coroutineReset may be called.
}
83 
/// Per-coroutine state. Also used, via g_currentThreadMainRoutine, to represent the main code.
struct Coroutine
{
    // Saved register values, indexed by REGISTERS.
    // NOTE(review): presumably accessed directly by the NASM coroutineSwap, so this field's
    // position in the struct may be part of that contract - confirm before moving it.
    ulong[REGISTERS.COUNT] registers;
    CoroutineState state;
    CoroutineFunc entryPoint; // Function run by routineMain when the coroutine is started.
    void* context;            // User pointer handed back by coroutineGetContext().
    // Routines waiting on this one: coroutineStart/coroutineResume push the caller, and
    // yieldImpl removes an entry and switches back to it.
    LinkedList!(Coroutine*) callStack;
    CoroutineStack stack;     // Stack this coroutine runs on (shared handle; not freed by coroutineDestroy).
    CoroutineSuspendedStack suspendedStack; // NOTE(review): not referenced anywhere in this file.

    /// Returns: true when the call stack is empty. That holds for the main routine while it
    /// runs, but also for any coroutine that is not currently running.
    @safe @nogc nothrow pure const
    bool isMain()
    {
        return this.callStack.length == 0;
    }
}
100 
/// The kinds of stack a coroutine can run on; each member is one alternative of CoroutineStack.
/// Only standalone stacks exist at the moment.
union CoroutineStackUnion
{
    StandaloneStack* standalone;
}

/// A dedicated page-allocated stack that at most one coroutine at a time may run on
/// (coroutineStart asserts `owner is null` before claiming it).
struct StandaloneStack
{
    StackContext context; // Page allocation plus the aligned usable bounds.
    Coroutine* owner;     // Coroutine currently claiming this stack, or null when free.
}

/// Tagged union over the members of CoroutineStackUnion; handled with `visit`.
alias CoroutineStack = SumType!CoroutineStackUnion;

/// NOTE(review): not referenced anywhere in this file; presumably meant to hold a copy of a
/// suspended coroutine's stack - confirm before relying on it.
struct CoroutineSuspendedStack
{
    ubyte[] memory;
}
118 
/// Context switch from `from` to `to`; defined outside this file.
/// NOTE(review): presumably saves the current callee-saved registers and stack pointer into
/// `from.registers`, then loads `to.registers` and continues at `to.registers[REGISTERS.ret]`
/// with RSP = `to.registers[REGISTERS.rsp]` - confirm against the assembly.
extern(C) void coroutineSwap(Coroutine* from, Coroutine* to); // Implemented in NASM since D's inline ASM is a bit limited.
120 
/// Allocates a dedicated stack that coroutines can run on, one at a time.
///
/// Params:
///   minMemory    = minimum stack size in bytes, passed to PageAllocator.allocInBytesToPages.
///   useGuardPage = forwarded to the page allocator.
/// Returns: a CoroutineStack wrapping the new StandaloneStack; free it with coroutineDestroyStack.
///          onOutOfMemoryError is invoked if the StandaloneStack bookkeeping cannot be allocated.
CoroutineStack coroutineCreateStandaloneStack(
    size_t minMemory = DEFAULT_COROUTINE_STACK_SIZE,
    bool useGuardPage = true
)
{
    StackContext pages = pageAlloc(minMemory, useGuardPage);

    auto standalone = g_alloc.make!StandaloneStack(pages);
    if(standalone is null)
        onOutOfMemoryError(null);

    return CoroutineStack(standalone.ptr);
}
132 
/// Frees a stack created by coroutineCreateStandaloneStack (its pages and bookkeeping),
/// then leaves `stack` in its `.init` state.
/// NOTE(review): no check is made that a coroutine is no longer using the stack.
void coroutineDestroyStack(ref CoroutineStack stack)
{
    releaseMemoryResources(stack);

    // Leave the caller's handle empty so the freed stack cannot be reached through it.
    stack = typeof(stack).init;
}
138 
/// Allocates a new coroutine in the `start` state.
///
/// Params:
///   func    = entry point executed once the coroutine is started with coroutineStart.
///   stack   = stack the coroutine will run on. The handle is stored by value; the underlying
///             stack is not owned, so the caller still frees it with coroutineDestroyStack.
///   context = user pointer later returned by coroutineGetContext.
/// Returns: the new coroutine; release it with coroutineDestroy.
Coroutine* coroutineCreate(
    CoroutineFunc func,
    CoroutineStack stack,
    void* context,
)
{
    auto ptr = g_alloc.make!Coroutine();

    // Same out-of-memory handling as coroutineCreateStandaloneStack. A plain assert is
    // compiled out in release builds, which would turn a failed allocation into a null
    // dereference on the next line.
    if(ptr is null)
        onOutOfMemoryError(null);

    ptr.entryPoint = func;
    ptr.context    = context;
    ptr.stack      = stack;
    return ptr;
}
153 
/// Releases a coroutine created by coroutineCreate. Its claim on its stack is dropped
/// (the stack memory itself is not freed), then the Coroutine is disposed via g_alloc.
void coroutineDestroy(ref Coroutine* routine)
{
    // Give up the stack claim while the Coroutine is still alive to be inspected.
    releaseMemoryResources(routine, false);

    g_alloc.dispose(routine);
}
159 
/// Starts `to` from its entry point and suspends the caller until `to` yields or exits.
///
/// The caller is the current coroutine, or the main routine when no coroutine has run yet.
/// `to` must be in the `start` state and have an entry point. Its standalone stack must not be
/// claimed by another coroutine; `to` claims it here.
void coroutineStart(Coroutine* to)
{
    auto from = g_currentThreadRoutine;
    if(from is null)
        from = &g_currentThreadMainRoutine;
    assert(to !is null, "To is null");
    assert(to.state == CoroutineState.start, "Child is not in the `start` state.");
    assert(to.entryPoint !is null, "Child has no entry point.");

    // `from` waits on `to` until `to` yields or exits.
    to.callStack.put(from);
    from.state = CoroutineState.suspended;
    to.state = CoroutineState.running;

    // The first switch into `to` lands in routineMain, which calls the entry point.
    to.registers[REGISTERS.ret] = cast(ulong)&routineMain;

    to.stack.visit!(
        (StandaloneStack* standalone)
        {
            assert(standalone.owner is null, "There is currently another coroutine making use of this standalone stack.");
            standalone.owner = to;

            // routineMain must be entered as if it had just been called, i.e. with (RSP % 16) == 8.
            // So RSP points at the 8-byte return-address slot just below alignedBot.
            ubyte* returnSlot = standalone.context.alignedBot - 8;
            to.registers[REGISTERS.rsp] = cast(ulong)returnSlot;

            static if(Win64)
            {
                // NOTE(review): presumably loaded into the TIB (gs:[0], gs:[8] stack base,
                // gs:[16] stack limit) by coroutineSwap - confirm against the NASM code.
                to.registers[REGISTERS.gs0]  = 0;
                to.registers[REGISTERS.gs8]  = cast(ulong)standalone.context.alignedBot;
                to.registers[REGISTERS.gs16] = cast(ulong)standalone.context.alignedTop;
            }

            // Fake return address for routineMain, stored in the slot RSP points at.
            // Subtract in bytes *before* casting: `cast(void**)alignedBot - 8` steps back
            // 8 pointers (64 bytes) and writes well below the slot actually used.
            *cast(void**)returnSlot = &coroutineExit;
        }
    )(to.stack);
    g_currentThreadRoutine = to;
    coroutineSwap(from, to);
}
194 
/// First code run on a freshly started coroutine (coroutineStart stores its address in the
/// `ret` slot). Runs the current coroutine's entry point, then finishes the coroutine.
/// coroutineExit is not expected to return here.
private void routineMain()
{
    auto current = g_currentThreadRoutine;
    current.entryPoint();

    // Returning from the entry point counts as an implicit coroutineExit.
    coroutineExit();
    assert(0);
}
201 
/// Puts a finished coroutine back into the `start` state so coroutineStart can run it again.
///
/// Params:
///   routine       = coroutine in the `end` state with an empty call stack.
///   newContext    = replacement context pointer; null keeps the current one.
///   newEntryPoint = replacement entry point; null keeps the current one.
///
/// The saved registers are zeroed and the coroutine's claim on its stack is released.
void coroutineReset(
    Coroutine* routine,
    void* newContext = null,
    CoroutineFunc newEntryPoint = null,
)
{
    assert(routine.state == CoroutineState.end, "Routine is not in the `end` state.");
    assert(routine.callStack.length == 0, "Routine still has values on the call stack?");

    foreach(ref reg; routine.registers)
        reg = 0;

    if(newContext !is null)    routine.context    = newContext;
    if(newEntryPoint !is null) routine.entryPoint = newEntryPoint;

    // Must happen while the state is still `end`: releaseMemoryResources inspects it.
    releaseMemoryResources(routine, true);
    routine.state = CoroutineState.start;
}
218 
/// Resumes a suspended coroutine. The caller (the current coroutine, or the main routine
/// when none is current) is suspended until `routine` yields or exits.
void coroutineResume(Coroutine* routine)
{
    auto active = g_currentThreadRoutine;
    Coroutine* caller = (active is null) ? &g_currentThreadMainRoutine : active;

    assert(routine !is null, "Routine is null.");
    assert(routine.state == CoroutineState.suspended, "Routine is not in the `suspended` state.");

    // Remember who to return to, then hand the "running" role over to `routine`.
    routine.callStack.put(caller);
    caller.state  = CoroutineState.suspended;
    routine.state = CoroutineState.running;
    g_currentThreadRoutine = routine;

    coroutineSwap(caller, routine);
}
233 
/// Returns: the `context` pointer of the coroutine that is currently running.
/// Must be called from inside a coroutine, not from the main code.
void* coroutineGetContext()
{
    auto routine = g_currentThreadRoutine;

    // Once control has returned to the main code, g_currentThreadRoutine points at
    // g_currentThreadMainRoutine instead of being null, so both cases must be rejected.
    assert(
        routine !is null && routine !is &g_currentThreadMainRoutine,
        "Cannot call this function when not inside a coroutine."
    );
    return routine.context;
}
240 
/// Suspends the current coroutine and switches back to the routine that started or last
/// resumed it. Execution continues after this call on the next coroutineResume.
void coroutineYield()
{
    enum nextState = CoroutineState.suspended;
    yieldImpl(nextState);
}
245 
/// Finishes the current coroutine (state `end`) and switches back to the routine that
/// started or last resumed it. Not expected to return; the coroutine only runs again after
/// coroutineReset followed by coroutineStart.
void coroutineExit()
{
    enum finalState = CoroutineState.end;
    yieldImpl(finalState);
    assert(0); // A finished coroutine is never switched back into at this point.
}
251 
/// Shared body of coroutineYield / coroutineExit. It marks the current coroutine with
/// `endState`, takes the waiting routine off its call stack, and switches to that routine.
private void yieldImpl(CoroutineState endState)
{
    Coroutine* current = g_currentThreadRoutine;
    assert(current !is null, "Cannot call this function when not inside a coroutine.");

    const depth = current.callStack.length;
    assert(depth > 0, "Coroutine has no call stack?");

    current.state = endState;

    auto waiting = current.callStack.removeAtTail(depth - 1);
    assert(waiting.state == CoroutineState.suspended, "Call stack routine is not in suspended state?");

    waiting.state = CoroutineState.running;
    g_currentThreadRoutine = waiting;
    coroutineSwap(current, waiting);
}
265 
/// Drops `routine`'s claim on the stack it runs on so another coroutine can use it.
/// The stack memory itself is not freed here (coroutineDestroyStack does that).
///
/// Params:
///   routine    = coroutine being destroyed or reset.
///   isForReset = true when called from coroutineReset; currently unused.
private void releaseMemoryResources(Coroutine* routine, bool isForReset = false)
{
    routine.stack.visit!(
        (StandaloneStack* standalone)
        {
            // A running or suspended coroutine must be the one holding its stack.
            if(routine.state == CoroutineState.running || routine.state == CoroutineState.suspended)
                assert(standalone.owner is routine, "??");

            // Only give up a claim this routine actually holds. A routine that was never
            // started (or was already reset) does not own the stack. Clearing the owner
            // unconditionally would steal the claim of another coroutine suspended on the
            // same stack, and let a third coroutine start on it and corrupt that stack.
            if(standalone.owner is routine)
                standalone.owner = null;
        }
    )(routine.stack);
}
277 
/// Frees a stack's page allocation and then its StandaloneStack bookkeeping.
/// NOTE(review): `owner` is not checked; a coroutine still referencing this stack
/// would be left pointing at freed memory.
private void releaseMemoryResources(CoroutineStack stack)
{
    stack.visit!(
        (StandaloneStack* standaloneStack)
        {
            // The pages are read out of the StandaloneStack, so free them before disposing it.
            pageFree(standaloneStack.context.pages);
            g_alloc.dispose(standaloneStack);
        }
    )(stack);
}
288 
289 private @nogc nothrow:
290 
/// Bookkeeping for a page-allocated coroutine stack (filled in by pageAlloc).
/// Naming note: "top" is the LOWEST usable address and "bot" the HIGHEST, since the stack grows down.
struct StackContext
{
    ubyte* base;       // Start of the allocation's memory, unadjusted.
    ubyte* alignedTop; // Start of the memory passed through alignTo!16; used as the stack limit.
    ubyte* alignedBot; // 56 bytes below the end of the memory, passed through alignTo!16; a new coroutine's RSP starts 8 bytes below it.
    shared PageAllocation pages; // The allocation itself; handed to pageFree to release it.
}
298 
/// Allocates pages for a coroutine stack and computes its aligned bounds.
///
/// Params:
///   minSize      = minimum number of bytes, passed to PageAllocator.allocInBytesToPages.
///   useGuardPage = forwarded to the page allocator.
/// Returns: a StackContext whose `alignedTop` is the start of the memory and whose
///          `alignedBot` is 56 bytes below its end, both passed through alignTo!16.
StackContext pageAlloc(size_t minSize, bool useGuardPage)
{
    auto allocation = PageAllocator.allocInBytesToPages(minSize, useGuardPage);

    auto lowest  = cast(ubyte*)allocation.memory.ptr;
    auto highest = cast(ubyte*)(allocation.memory.ptr + allocation.memory.length);

    // Headroom kept above the initial stack pointer. The author's note attributes this to the
    // Win64 ABI's 32-byte shadow space plus 8 bytes for the default return address (40 bytes);
    // the reservation actually made is 56 bytes.
    highest -= 56;

    StackContext result;
    result.base       = lowest;
    result.alignedTop = cast(ubyte*)((cast(ulong)lowest).alignTo!16);
    result.alignedBot = cast(ubyte*)((cast(ulong)highest).alignTo!16);
    result.pages      = allocation;
    return result;
}
315 
/// Returns a page allocation (as stored in StackContext.pages) to PageAllocator.
void pageFree(shared(PageAllocation) pages)
{
    PageAllocator.free(pages);
}
320 
@("coroutine - Create and Free stack")
unittest
{
    // A tiny request (200 bytes) with a guard page must still allocate and free cleanly.
    auto smallStack = coroutineCreateStandaloneStack(200, true);
    coroutineDestroyStack(smallStack);
}
327 
@("coroutine - Create and Free routine")
unittest
{
    static void emptyBody()
    {
    }

    // A coroutine that is never started can be destroyed before its stack.
    auto routineStack = coroutineCreateStandaloneStack();
    auto neverStarted = coroutineCreate(&emptyBody, routineStack, null);
    coroutineDestroy(neverStarted);
    coroutineDestroyStack(routineStack);
}
340 
@("coroutine - Explicit exit")
unittest
{
    static void exitImmediately()
    {
        coroutineExit();
    }

    // coroutineExit inside the body hands control straight back to coroutineStart's caller.
    auto exitStack = coroutineCreateStandaloneStack();
    auto exiting   = coroutineCreate(&exitImmediately, exitStack, null);
    coroutineStart(exiting);
    coroutineDestroy(exiting);
    coroutineDestroyStack(exitStack);
}
355 
@("coroutine - Implicit exit")
unittest
{
    static void returnsNormally()
    {
    }

    // Simply returning from the entry point must also finish the coroutine.
    auto exitStack = coroutineCreateStandaloneStack();
    auto finisher  = coroutineCreate(&returnsNormally, exitStack, null);
    coroutineStart(finisher);
    coroutineDestroy(finisher);
    coroutineDestroyStack(exitStack);
}
369 
@("coroutine - Suspend")
unittest
{
    __gshared static int counter;

    static void countTwice()
    {
        counter++;
        coroutineYield(); // back to the test between the two increments
        counter++;
    }

    auto yieldStack = coroutineCreateStandaloneStack();
    auto yielder    = coroutineCreate(&countTwice, yieldStack, null);

    coroutineStart(yielder);  // runs up to the yield
    assert(counter == 1);
    coroutineResume(yielder); // runs from the yield to the end
    assert(counter == 2);

    coroutineDestroy(yielder);
    coroutineDestroyStack(yieldStack);
}
393 
@("coroutine - Context")
unittest
{
    int result;

    static void writeThroughContext()
    {
        auto target = cast(int*)coroutineGetContext();
        assert(target !is null);
        *target = 200;
    }

    // The context pointer given at creation is visible from inside the coroutine.
    auto contextStack = coroutineCreateStandaloneStack();
    auto writer       = coroutineCreate(&writeThroughContext, contextStack, &result);

    coroutineStart(writer);
    assert(result == 200);

    coroutineDestroy(writer);
    coroutineDestroyStack(contextStack);
}
415 
@("coroutine - Reset")
unittest
{
    int runs;

    static void incrementContext()
    {
        auto counter = cast(int*)coroutineGetContext();
        assert(counter !is null);
        *counter += 1;
    }

    auto resetStack = coroutineCreateStandaloneStack();
    auto reusable   = coroutineCreate(&incrementContext, resetStack, &runs);

    coroutineStart(reusable);
    assert(runs == 1);
    coroutineReset(reusable); // back to `start`, keeping the same context and entry point
    coroutineStart(reusable);
    assert(runs == 2);

    coroutineDestroy(reusable);
    coroutineDestroyStack(resetStack);
}