1 module libd.datastructures.sumtype;
2 
3 import libd.meta : assertIsPartOfUnion, assertAllSatisfy, Parameters, isSomeFunction;
4 
/++
 + A tagged union: holds a value of (at most) one member type of `UnionT`, together with a
 + `Kind` tag recording which member is active.
 +
 + The held value's lifetime is managed by this struct: copying a SumType runs the active
 + member's postblit, and destroying it (or switching to a different member type) runs the
 + active member's destructor.
 +
 + Params:
 +  UnionT = The union whose members are the storable types. Member types are presumably
 +           expected to be distinct: `kindOf` resolves a type to the first member declared
 +           with it.
 +/
@nogc nothrow
struct SumType(UnionT)
{
    /// The underlying union type.
    alias Union = UnionT;

    /++
     + Identifies which member of `UnionT` is active.
     +
     + `static foreach` can't be used inside an enum body, so this is a struct with one
     + `static immutable Kind` constant per union member (named after that member) instead.
     + Generating an enum through a CTFE-only function would run into -betterC restrictions,
     + which apply inside CTFE-only functions too.
     +
     + Constants are numbered from 1 in declaration order; 0 means "no value held".
     +/
    struct Kind
    {
        private int i;
        static foreach(i2, name; __traits(allMembers, UnionT))
            mixin("static immutable Kind "~name~" = Kind(i2+1);");

        /// Returns: `true` for the "no value held" kind (0).
        @safe @nogc
        private bool isUnassigned() nothrow const
        {
            return this.i == 0;
        }

        /// Returns: The raw kind number: 0 when unassigned, otherwise the 1-based member position.
        @safe @nogc
        size_t index() nothrow const
        {
            return this.i;
        }
    }

    @nogc nothrow:

    private Kind   _kind;
    private UnionT _value;

    /// Constructs a SumType holding `value`. `T` must be a member type of `UnionT`.
    this(T)(auto ref T value)
    {
        this.set(value);
    }

    // Copying a SumType blits `_value`. The union cannot tell which member is live, so the
    // active member's postblit (if it has one) is run here by hand.
    this(this)
    {
        if(this._kind.isUnassigned)
            return;

        switch(this._kind.i)
        {
            static foreach(member; __traits(allMembers, UnionT))
            {{
                alias Type = typeof(__traits(getMember, UnionT, member));
                case typeof(this).kindOf!Type.i:
                    static if(__traits(compiles, Type.init.__xpostblit()))
                        this.get!Type.__xpostblit();
                    return;
            }}

            default: assert(false, "??");
        }
    }

    /// Destroys the held value, if any.
    ~this()
    {
        this.dtorValue();
    }

    /// Returns: The `Kind` of the first member of `UnionT` whose type is `T`.
    static Kind kindOf(T)()
    if(assertIsPartOfUnion!(UnionT, T))
    {
        static foreach(member; __traits(allMembers, UnionT))
        {
            static if(is(T == typeof(__traits(getMember, UnionT, member))))
                mixin("return Kind."~member~";");
        }
        assert(false, "assertIsPartOfUnion should've triggered.");
    }

    /++
     + Stores `value`, making `T` the active member type.
     +
     + If a value of a *different* member type is held, it is destroyed first.
     + If a value of type `T` is already held, `value` is assigned over it directly, using `T`'s
     + own assignment. Destroying first would also reset any storage `value` aliases (e.g.
     + `sum = sum.get!T`, where `value` binds by ref to the held value), losing the value.
     +/
    void set(T)(auto ref T value)
    if(assertIsPartOfUnion!(UnionT, T))
    {
        enum NewKind = typeof(this).kindOf!T;

        if(this._kind != NewKind)
        {
            this.dtorValue();
            this._kind = NewKind;
        }

        static foreach(member; __traits(allMembers, UnionT))
        {{
            alias Type = typeof(__traits(getMember, UnionT, member));
            static if(is(Type == T))
                mixin("this._value."~member~" = value;");
        }}
    }

    /++
     + Returns: A reference to the held value, viewed as a `T`.
     +
     + Asserts that the active member type is `T`.
     +/
    ref T get(T)()
    if(assertIsPartOfUnion!(UnionT, T))
    {
        static foreach(member; __traits(allMembers, UnionT))
        {{
            alias Type = typeof(__traits(getMember, UnionT, member));
            static if(is(Type == T))
            {
                assert(this._kind == typeof(this).kindOf!T,
                    "This SumType holds a value of type `"~"(TODO)"~"`, but a value of `"
                  ~ T.stringof~"` was asked for."
                );
                mixin("return this._value."~member~";");
            }
        }}
    }

    /// Returns: `true` if the active member type is `T`.
    bool contains(T)()
    {
        return this._kind == typeof(this).kindOf!T;
    }

    /// Assigns a value of one of `UnionT`'s member types (see `set`).
    void opAssign(T)(auto ref T value)
    if(!is(T == typeof(this)))
    {
        this.set(value);
    }

    /++
     + Dispatches on the held value via `_visit`.
     +
     + This aliases a free function rather than a member, so the SumType is passed explicitly:
     + `sum.visit!(handlers)(sum)`.
     +/
    alias visit(Handlers...) = _visit!(typeof(this), Handlers);

    /// Returns: The kind of the currently held value (unassigned when nothing is held).
    @property @safe @nogc
    Kind kind() nothrow const
    {
        return this._kind;
    }

    // Runs the active member's destructor (if it has one), marks the SumType as holding no value,
    // and resets `_value` to `UnionT.init`. Calling it again while unassigned is a no-op.
    // Presumably templated (`()`) so that attributes are inferred from the member destructors.
    private void dtorValue()()
    {
        if(this._kind.isUnassigned)
            return;

        auto kind = this._kind;
        this._kind = Kind(0);

        switch(kind.i)
        {
            static foreach(member; __traits(allMembers, UnionT))
            {{
                alias Type = typeof(__traits(getMember, UnionT, member));
                enum TypeKind = typeof(this).kindOf!Type;

                case TypeKind.i:
                    static if(__traits(compiles, Type.init.__xdtor()))
                        mixin("this._value."~member~".__xdtor();");
                    this._value = UnionT.init;
                    return;
            }}

            default: assert(false, "Unknown Kind ID, wut?");
        }
    }
}
///
@("SumType")
unittest
{
    union U
    {
        int a;
        string b;
    }

    alias Sum = SumType!U;

    // Kinds are named after the union members, and can also be looked up by member type.
    static assert(Sum.kindOf!int    == Sum.Kind.a);
    static assert(Sum.kindOf!string == Sum.Kind.b);

    // Constructing from an int selects member `a`.
    auto holder = Sum(20);
    assert(holder.kind == Sum.Kind.a);
    assert(holder.kind == Sum.kindOf!int);
    assert(holder.contains!int);

    // Disabled check: asking for a string while an int is held should fail.
    // bool threw = false;
    // try holder.get!string();
    // catch(Error error)
    //     threw = true;
    // assert(threw);

    assert(holder.get!int == 20);

    // Assigning a string switches the active member to `b`.
    holder = "lol";
    assert(holder.kind == Sum.Kind.b);
    assert(holder.kind == Sum.kindOf!string);
    assert(holder.contains!string);
    assert(holder.get!string == "lol");

    // Typed handlers receive the held value by reference; the parameterless one is the fallback.
    int unhandled;
    void applyVisit(ref Sum target)
    {
        target.visit!(
            (ref int n)    { n *= 2; },
            (ref string s) { s = "lel"; },
            ()             { unhandled++; }
        )(target);
    }

    holder = 20;
    applyVisit(holder);
    assert(holder.get!int == 40);

    holder = "lol";
    applyVisit(holder);
    assert(holder.get!string == "lel");

    // Disabled check: a value with no typed handler should reach the fallback.
    // holder = new Object();
    // applyVisit(holder);
    // assert(unhandled == 1);
}
207 
@("SumType - PostBlit")
unittest
{
    int destroyed = 0;
    int copied    = 0;

    // Counts its own postblits and destructions through the pointers it was built with.
    static struct S
    {
        int* onDestroy;
        int* onCopy;

        @nogc nothrow:

        this(this)
        {
            if(this.onCopy !is null)
                ++*this.onCopy;
        }

        ~this()
        {
            if(this.onDestroy !is null)
                ++*this.onDestroy;
        }
    }

    static union U
    {
        S s;
        int _;
    }

    alias Sum = SumType!U;

    // Checks the running postblit/destructor totals.
    void expectCounts(int copies, int destructions)
    {
        assert(copied == copies);
        assert(destroyed == destructions);
    }

    auto s = S(&destroyed, &copied);
    expectCounts(0, 0);

    // Storing an S copies it once.
    auto box = Sum(s);
    expectCounts(1, 0);

    // Switching to an int destroys the held S.
    box = 0;
    expectCounts(1, 1);

    // int -> int touches neither counter.
    box = 1;
    expectCounts(1, 1);

    // Destroying while an int is held touches neither counter.
    box.__xdtor();
    expectCounts(1, 1);

    // Storing an S again copies it.
    box = s;
    expectCounts(2, 1);

    // Copying the SumType runs the held S's postblit.
    auto boxCopy = box;
    expectCounts(3, 1);

    // Each SumType destroys its S once; the repeated calls are no-ops.
    box.__xdtor();
    box.__xdtor();
    boxCopy.__xdtor();
    boxCopy.__xdtor();
    expectCounts(3, 3);
}
276 
/++
 + Dispatches on the value held by `sum`.
 +
 + Handlers with at least one parameter are typed handlers. The first one whose first
 + parameter type is the active member type is called with a reference to the held value.
 + If no typed handler matches, the first parameterless handler (the default) is called.
 + If there is no default either, an assertion fails.
 +
 + Typed handlers and the default are checked in two separate passes, so the default's
 + position within `Handlers` does not matter. A `scope(exit)` inside the `static foreach`
 + body would fire at the end of that iteration's `{{ }}` scope, not at function exit.
 +
 + Params:
 +  SumT     = The `SumType` instantiation being visited.
 +  Handlers = The handler functions/lambdas.
 +  sum      = The SumType to dispatch on.
 +/
private void _visit(SumT, Handlers...)(ref SumT sum)
if(assertAllSatisfy!(isSomeFunction, Handlers))
{
    bool handled = false;

    // Pass 1: the first typed handler matching the active member type.
    static foreach(handler; Handlers)
    {{
        alias Params = Parameters!handler;
        static if(Params.length != 0)
        {
            if(!handled && sum._kind.i == SumT.kindOf!(Params[0]).i)
            {
                handler(sum.get!(Params[0]));
                handled = true;
            }
        }
    }}

    // Pass 2: nothing matched, so fall back to the first parameterless handler, if any.
    static foreach(handler; Handlers)
    {{
        alias Params = Parameters!handler;
        static if(Params.length == 0)
        {
            if(!handled)
            {
                handled = true;
                handler();
            }
        }
    }}

    assert(handled, "No handler was provided for the type held within this SumType, and no default handler was provided.");
}