module libd.datastructures.sumtype;

import libd.meta : assertIsPartOfUnion, assertAllSatisfy, Parameters, isSomeFunction;

/++
 + A tagged union over the members of `UnionT`.
 +
 + Each member of `UnionT` is one alternative. Which member is currently held is tracked by a
 + `Kind` tag, where `Kind(0)` (the `.init` value) means nothing is held. Because D does not run
 + postblits or destructors for union members, this wrapper invokes the held member's
 + `__xpostblit` / `__xdtor` itself when the member's type has them.
 +
 + Values are looked up by type (`kindOf`, `get`, `set`, `contains`), so each member of `UnionT`
 + should have a distinct type.
 + ++/
@nogc nothrow
struct SumType(UnionT)
{
    /// The union type wrapped by this SumType.
    alias Union = UnionT;

    /++
     + Tag identifying which member of `UnionT` is currently held.
     +
     + Has one static constant per member of `UnionT`, named after that member and numbered
     + from 1 in declaration order. `Kind(0)` means "unassigned".
     + ++/
    // We can't `static foreach` inside an enum, so this struct is the next best thing.
    // It really, REALLY sucks that -betterC limitations also apply inside CTFE-only functions.
    struct Kind
    {
        private int i;
        static foreach(i2, name; __traits(allMembers, UnionT))
            mixin("static immutable Kind "~name~" = Kind(i2+1);");

        @safe @nogc
        private bool isUnassigned() nothrow const
        {
            return this.i == 0;
        }

        /// Returns: The numeric ID of this kind: 0 when unassigned, otherwise the 1-based member index.
        @safe @nogc
        size_t index() nothrow const
        {
            return this.i;
        }
    }

    @nogc nothrow:

    private Kind _kind;
    private UnionT _value;

    /// Constructs a SumType holding `value`. `T` must be the type of one of `UnionT`'s members.
    this(T)(auto ref T value)
    {
        this.set(value);
    }

    // The compiler has already blitted `_value`; run the held member's own postblit (if its
    // type has one), since D does not do this for union members.
    this(this)
    {
        if(this._kind.isUnassigned)
            return;

        switch(this._kind.i)
        {
            static foreach(member; __traits(allMembers, UnionT))
            {{
                alias Type = typeof(__traits(getMember, UnionT, member));
                case typeof(this).kindOf!Type.i:
                    static if(__traits(compiles, Type.init.__xpostblit()))
                        this.get!Type.__xpostblit();
                    return;
            }}

            default: assert(false, "??");
        }
    }

    /// Destroys the held value, if any.
    ~this()
    {
        this.dtorValue();
    }

    /++
     + Returns: The `Kind` constant for the member of `UnionT` whose type is `T`.
     +         The first such member is used.
     + ++/
    static Kind kindOf(T)()
    if(assertIsPartOfUnion!(UnionT, T))
    {
        static foreach(member; __traits(allMembers, UnionT))
        {
            static if(is(T == typeof(__traits(getMember, UnionT, member))))
                mixin("return Kind."~member~";");
        }
        assert(false, "assertIsPartOfUnion should've triggered.");
    }

    /++
     + Replaces the held value with `value`.
     +
     + The previously held value (if any) is destroyed first, then the member of type `T` is
     + assigned and the kind is updated to `kindOf!T`.
     + ++/
    void set(T)(auto ref T value)
    if(assertIsPartOfUnion!(UnionT, T))
    {
        enum NewKind = typeof(this).kindOf!T;
        this.dtorValue();
        this._kind = NewKind;

        static foreach(member; __traits(allMembers, UnionT))
        {{
            alias Type = typeof(__traits(getMember, UnionT, member));
            static if(is(Type == T))
                mixin("this._value."~member~" = value;");
        }}
    }

    /++
     + Returns: A reference to the held value as a `T`.
     +
     + When asserts are enabled, halts with a message naming both the held type and `T` if
     + this SumType does not currently hold a `T`. With asserts disabled there is no check.
     + ++/
    ref T get(T)()
    if(assertIsPartOfUnion!(UnionT, T))
    {
        version(assert)
            this.assertHolds!T();

        static foreach(member; __traits(allMembers, UnionT))
        {{
            alias Type = typeof(__traits(getMember, UnionT, member));
            static if(is(Type == T))
                mixin("return this._value."~member~";");
        }}
    }

    /// Returns: `true` if this SumType currently holds a value of type `T`.
    bool contains(T)()
    if(assertIsPartOfUnion!(UnionT, T))
    {
        return this._kind == typeof(this).kindOf!T;
    }

    /// Assigning any non-SumType value forwards to `set`.
    void opAssign(T)(auto ref T value)
    if(!is(T == typeof(this)))
    {
        this.set(value);
    }

    /++
     + Calls the handler from `Handlers` that matches the held value; see `_visit`.
     +
     + Usage: `sum.visit!(handlers...)(sum)`.
     + ++/
    alias visit(Handlers...) = _visit!(typeof(this), Handlers);

    /// Returns: The kind of the currently held value (`Kind(0)` when unassigned).
    @property @safe @nogc
    Kind kind() nothrow const
    {
        return this._kind;
    }

    // Halts with a message naming both the held type and `T` when this SumType does not hold
    // a `T`. Every message is a compile-time constant, so this stays @nogc.
    private void assertHolds(T)()
    {
        if(this._kind == typeof(this).kindOf!T)
            return;

        if(this._kind.isUnassigned)
            assert(false, "This SumType holds no value, but a value of `"~T.stringof~"` was asked for.");

        switch(this._kind.i)
        {
            static foreach(member; __traits(allMembers, UnionT))
            {{
                alias Held = typeof(__traits(getMember, UnionT, member));
                case typeof(this).kindOf!Held.i:
                    assert(false,
                        "This SumType holds a value of type `"~Held.stringof~"`, but a value of `"
                        ~ T.stringof~"` was asked for."
                    );
            }}

            default: assert(false, "Unknown Kind ID, wut?");
        }
    }

    // Destroys the held value (if any):
    //  - marks the kind as unassigned before running the destructor;
    //  - calls the member's `__xdtor` when its type has one;
    //  - resets `_value` to `UnionT.init`.
    // Declared as a template `()()`; presumably for attribute inference. TODO confirm.
    private void dtorValue()()
    {
        if(this._kind.isUnassigned)
            return;

        auto kind = this._kind;
        this._kind = Kind(0);

        switch(kind.i)
        {
            static foreach(member; __traits(allMembers, UnionT))
            {{
                alias Type = typeof(__traits(getMember, UnionT, member));
                enum TypeKind = typeof(this).kindOf!Type;

                case TypeKind.i:
                    static if(__traits(compiles, Type.init.__xdtor()))
                        mixin("this._value."~member~".__xdtor();");
                    this._value = UnionT.init;
                    return;
            }}

            default: assert(false, "Unknown Kind ID, wut?");
        }
    }
}
///
@("SumType")
unittest
{
    union U
    {
        int a;
        string b;
    }

    alias Sum = SumType!U;

    static assert(Sum.kindOf!int == Sum.Kind.a);
    static assert(Sum.kindOf!string == Sum.Kind.b);

    auto value = Sum(20);
    assert(value.kind == Sum.Kind.a && value.kind == Sum.kindOf!int);
    assert(value.contains!int);

    // bool threw = false;
    // try value.get!string();
    // catch(Error error)
    //     threw = true;
    // assert(threw);

    assert(value.get!int == 20);
    value = "lol";
    assert(value.kind == Sum.Kind.b && value.kind == Sum.kindOf!string);
    assert(value.contains!string);
    assert(value.get!string == "lol");

    int unhandled;
    void visitTest(ref Sum sum)
    {
        sum.visit!(
            (ref int i) { i *= 2; },
            (ref string b) { b = "lel"; },
            () { unhandled++; }
        )(sum);
    }

    value = 20;
    visitTest(value);
    assert(value.get!int == 40);

    value = "lol";
    visitTest(value);
    assert(value.get!string == "lel");

    // A default handler must only run when no typed handler matches,
    // regardless of where it appears in the handler list.
    value = 20;
    value.visit!(
        () { unhandled++; },
        (ref int i) { i += 1; }
    )(value);
    assert(unhandled == 0);
    assert(value.get!int == 21);

    value.visit!(
        (ref string s) { },
        () { unhandled++; }
    )(value);
    assert(unhandled == 1);
    assert(value.get!int == 21);

    // value = new Object();
    // visitTest(value);
    // assert(unhandled == 1);
}

@("SumType - PostBlit")
unittest
{
    auto dtor = 0;
    auto pblit = 0;

    static struct S
    {
        int* dtor;
        int* pblit;

        @nogc nothrow:

        this(this)
        {
            if(this.pblit)
                (*this.pblit)++;
        }

        ~this()
        {
            if(this.dtor)
                (*this.dtor)++;
        }
    }

    static union U
    {
        S s;
        int _;
    }

    alias Sum = SumType!U;

    auto s = S(&dtor, &pblit);
    assert(dtor == 0 && pblit == 0);

    auto sum = Sum(s);
    assert(pblit == 1);
    assert(dtor == 0);

    sum = 0;
    assert(pblit == 1);
    assert(dtor == 1);

    sum = 1;
    assert(pblit == 1);
    assert(dtor == 1);

    sum.__xdtor();
    assert(pblit == 1);
    assert(dtor == 1);

    sum = s;
    assert(pblit == 2);
    assert(dtor == 1);

    auto sum2 = sum;
    assert(pblit == 3);
    assert(dtor == 1);

    sum.__xdtor();
    sum.__xdtor();
    sum2.__xdtor();
    sum2.__xdtor();
    assert(pblit == 3);
    assert(dtor == 3);
}

/++
 + Implementation of `SumType.visit`: dispatches on the kind currently held by `sum`.
 +
 + Handlers that take at least one parameter are "typed": the type of their first parameter
 + selects the member they handle, and they are called with `sum.get!(ThatType)`. Only the first
 + typed handler matching the held kind is called.
 +
 + Handlers that take no parameters are defaults. The first one is called only when no typed
 + handler matched, regardless of its position in `Handlers`.
 +
 + Asserts if no handler at all was called.
 + ++/
private void _visit(SumT, Handlers...)(ref SumT sum)
if(assertAllSatisfy!(isSomeFunction, Handlers))
{
    bool handled = false;

    // Try every typed handler first, so a default handler listed earlier cannot pre-empt
    // a matching typed handler listed later.
    static foreach(handler; Handlers)
    {{
        alias Params = Parameters!handler;
        static if(Params.length != 0)
        {
            if(!handled && sum._kind.i == SumT.kindOf!(Params[0]).i)
            {
                handler(sum.get!(Params[0]));
                handled = true;
            }
        }
    }}

    // Nothing typed matched: fall back to the first parameterless (default) handler, if any.
    static foreach(handler; Handlers)
    {{
        static if(Parameters!handler.length == 0)
        {
            if(!handled)
            {
                handled = true;
                handler();
            }
        }
    }}

    assert(handled, "No handler was provided for the type held within this SumType, and no default handler was provided.");
}