1 module libd.datastructures.hashstuff;
2 
3 import libd.memory, libd.data, libd.datastructures.array, libd.datastructures.growth, libd.algorithm.common,
4        libd.meta.traits;
5 
/// A key and its associated value, both held by value.
struct KeyValuePair(alias KeyT, alias ValueT)
{
    /// The key half of the pair.
    KeyT key;

    /// The value half of the pair.
    ValueT value;
}
11 
/// A key and its associated value, both referred to by pointer rather than copied.
struct KeyValueRefPair(alias KeyT, alias ValueT)
{
    /// Points at the key; `null` in a default-initialised pair.
    KeyT* key;

    /// Points at the value; `null` in a default-initialised pair.
    ValueT* value;
}
17 
/// A hash map using open addressing with Robin Hood probing.
///
/// All entries live in one `Array!(Node, AllocT)`. The table holds `prime + probeLimit` slots,
/// where `prime` is supplied by `nextPrimeSize` and `probeLimit` is `log2(prime)`. A key is only
/// ever looked for in the `probeLimit` slots starting at `hash % prime`. The table is rebuilt at
/// the next prime size once the entry count reaches `ceil(prime * maxLoadFactor)`, or when an
/// insertion cannot be placed within the probe limit.
///
/// Params:
///  KeyT          = Type of the keys.
///  ValueT        = Type of the values.
///  AllocT        = Allocator type handed to the backing array.
///  Hasher        = Hash function forwarded to `toHashToPrimeIndex`.
///  maxLoadFactor = Fraction of the prime-sized capacity that may fill up before growing. Must be > 0.
struct RobinHoodHashMap(
    alias KeyT,
    alias ValueT,
    alias AllocT = SystemAllocator,
    alias Hasher = murmur3_32HashOf,
    double maxLoadFactor = 0.8
)
{
    static assert(maxLoadFactor > 0, "The load factor cannot be 0 or negative.");

    // Moving things we don't need to move is *suuuper* slow, so types can explicitly say if they prefer being moved.
    enum KeyOptimise   = BitmaskUda!(OptimisationHint, KeyT);
    enum ValueOptimise = BitmaskUda!(OptimisationHint, ValueT);
    enum MoveKey       = (KeyOptimise & OptimisationHint.preferMoveOverCopy) > 0   || !isCopyable!KeyT;
    enum MoveValue     = (ValueOptimise & OptimisationHint.preferMoveOverCopy) > 0 || !isCopyable!ValueT;

    /// One slot of the table.
    static struct Node
    {
        KeyT key;
        ValueT value;

        /// How many slots this entry sits past its home slot; `ubyte.max` marks a vacant slot.
        ubyte distance = ubyte.max;
    }

    private Array!(Node, AllocT) _array;   // The slots: `_fakeCapacity + _probeLimit` of them once allocated.
    private size_t _fakeCapacity;          // The prime the hash is reduced by.
    private size_t _fakeMaxLoadCapacity;   // Entry count at which `put` grows the table.
    private size_t _length;                // Number of stored entries.
    private ubyte  _primeIndex;            // Index of the *next* prime to grow to; the current one is at `_primeIndex - 1`.
    private ubyte  _probeLimit;            // Number of slots examined per key, starting at its home slot.

    @nogc nothrow:

    /// Creates an empty map whose backing array uses the given allocator.
    this(AllocatorWrapperOf!AllocT alloc)
    {
        this._array = typeof(_array)(alloc);
    }

    /// Stores `value` under `key`, overwriting the value of an existing entry with an equal key.
    ///
    /// The table is rebuilt at the next prime size (possibly several times) when the load limit
    /// has been reached or when the entry cannot be placed within the probe limit.
    void put()(KeyT key, ValueT value)
    {
        // `___`, `_` and `__` are throwaway outputs used while re-inserting existing nodes.
        bool alreadyExists, wasSwap, wasSwapAtAnyPoint, ___;
        KeyT currKey, _;
        ValueT currValue, __;
        if(this._length >= this._fakeMaxLoadCapacity
         || !this.putInto(this._array, key, value, alreadyExists, wasSwap, currKey, currValue) // failed insertion
        )
        {
            import libd.util.maths : log2, ceilToInt;

            typeof(_array) nextArray;
            const oldLength = this._length;
            while(true)
            {
                // Note that `nextPrimeSize` also advances `_primeIndex`.
                const nextPrime    = nextPrimeSize(this._primeIndex);
                const nextLimit    = cast(ubyte)log2(nextPrime);
                const nextRealSize = nextPrime + nextLimit;
                const nextMaxSize  = ceilToInt!size_t(cast(double)nextPrime * maxLoadFactor);
                this._probeLimit   = nextLimit;
                this._fakeCapacity = nextPrime;
                this._fakeMaxLoadCapacity = nextMaxSize;
                nextArray.length = nextRealSize;

                // Re-insert every occupied node of the old table into the new one.
                bool reloop = false;
                size_t insertCount;
                foreach(ref node; this._array)
                {
                    if(node.distance == ubyte.max)
                        continue;
                    if(!this.putInto(nextArray, node.key, node.value, alreadyExists, ___, _, __))
                    {
                        reloop = true;
                        break;
                    }
                    if(++insertCount == oldLength)
                        break;
                }

                // The new size was still too cramped: start over with the next prime.
                if(reloop)
                {
                    emplaceInit(nextArray);
                    continue;
                }

                // If an earlier attempt displaced an entry, the pair still waiting for a slot is in
                // `currKey`/`currValue`; otherwise it is the caller's original pair.
                const result = (wasSwap || wasSwapAtAnyPoint)
                    ? this.putInto(nextArray, currKey, currValue, alreadyExists, wasSwap, currKey, currValue)
                    : this.putInto(nextArray, key, value, alreadyExists, wasSwap, currKey, currValue);
                wasSwapAtAnyPoint = wasSwapAtAnyPoint || wasSwap;
                if(!result)
                {
                    emplaceInit(nextArray);
                    continue;
                }

                move(nextArray, this._array);
                break;
            }
        }

        this._length += !alreadyExists;
    }

    /// Removes the entry stored under `key`, discarding its value.
    ///
    /// Returns: `true` if an entry was removed, `false` if the key was not present.
    bool removeAt()(auto ref KeyT key)
    {
        ValueT v;
        return this.removeAt(key, v);
    }

    /// Removes the entry stored under `key` and moves its value into `outValue`.
    ///
    /// Entries following the removed one are shifted back by one slot until a vacant slot or an
    /// entry already sitting in its home slot is reached.
    ///
    /// Returns: `true` if an entry was removed, `false` if the key was not present (in which case
    ///          `outValue` is left untouched).
    bool removeAt()(auto ref KeyT key, ref ValueT outValue)
    {
        if(this._primeIndex == 0)
            return false;
        const index = toHashToPrimeIndex!(Hasher, KeyT)(key, this._primeIndex - 1);
        foreach(i; 0..this._probeLimit)
        {
            auto ptr = &this._array[index+i];

            // A vacant slot holds `KeyT.init`, which must never be mistaken for a stored key.
            if(ptr.distance != ubyte.max && ptr.key == key)
            {
                ValueT value;
                move(ptr.value, value);
                emplaceInit(*ptr);

                auto shiftIndex = index+i+1;
                auto lastPtr    = ptr;
                while(shiftIndex < this._array.length)
                {
                    auto currPtr = &this._array[shiftIndex++];
                    if(currPtr.distance == 0 || currPtr.distance == ubyte.max)
                        break;
                    move(*currPtr, *lastPtr);
                    // The node is now one slot closer to home. Without this, a later removal could
                    // shift it in front of its home slot, where lookups would never find it.
                    lastPtr.distance--;
                    lastPtr = currPtr;
                }

                move(value, outValue);
                this._length--;
                return true;
            }
        }
        return false;
    }

    /// Returns: `true` if an entry is stored under `key`.
    bool containsKey()(auto ref KeyT key) const
    {
        return this.getNodeAt(key) !is null;
    }

    /// Returns: the value stored under `key`. Asserts that the key exists.
    inout(ValueT) getAt()(auto ref KeyT key) inout
    {
        auto result = this.getNodeAt(key);
        assert(result !is null, "Could not find key.");
        return result.value;
    }

    /// Returns: a reference to the value stored under `key`. Asserts that the key exists.
    ref inout(ValueT) getAtByRef()(auto ref KeyT key) inout
    {
        auto ptr = this.getNodeAt(key);
        assert(ptr !is null, "Could not find key.");
        return ptr.value;
    }

    /// Returns: the value stored under `key`, or `default_` if the key is not present.
    inout(ValueT) getAtOrDefault()(auto ref KeyT key, auto ref scope return ValueT default_ = ValueT.init) inout
    {
        auto result = this.getNodeAt(key);
        return (result) ? result.value : cast(inout)default_;
    }

    /// Returns: a pointer to the value stored under `key`, or `null` if the key is not present.
    ///          The pointer refers directly into the table's storage.
    inout(ValueT)* getPtrUnsafeAt()(auto ref KeyT key) inout
    {
        auto result = this.getNodeAt(key);
        return (result) ? &result.value : null;
    }

    /// Returns: the number of entries in the map.
    @property @safe
    size_t length() const
    {
        return this._length;
    }

    /// Returns: an input range over the entries, yielding `KeyValueRefPair`s that point into the
    ///          table. The range keeps a pointer to this map and asserts that the map's length
    ///          does not change while it is being iterated.
    @property
    auto range()
    {
        alias HashMapT = typeof(this);

        static struct R
        {
            HashMapT* ptr;
            KeyValueRefPair!(KeyT, ValueT) _front;
            size_t lengthAtStart; // Map length when the range was created.
            size_t iteratedOver;  // Number of entries produced so far.
            size_t index;         // Slot at which the next search starts.
            bool _empty = true;

            @nogc nothrow:

            this(HashMapT* hashmap)
            {
                this.ptr = hashmap;
                this.lengthAtStart = hashmap.length;
                this._empty = false;
                this.popFront(); // Position on the first entry (or become empty straight away).
            }

            void popFront()
            {
                assert(!this.empty, "Cannot pop an empty range.");
                if(this.iteratedOver == this.lengthAtStart)
                {
                    this._empty = true;
                    return;
                }
                foreach(i; this.index..this.ptr._array.length)
                {
                    if(ptr._array[i].distance != ubyte.max)
                    {
                        this.iteratedOver++;
                        this._front = typeof(_front)(
                            &ptr._array[i].key,
                            &ptr._array[i].value
                        );
                        this.index = i+1;
                        return;
                    }
                }
                assert(false, "?? Could not find next front?");
            }

            bool empty()
            {
                if(this.ptr is null)
                    return true;
                assert(this.ptr.length == this.lengthAtStart, "Please do not modify the hashmap during iteration.");
                return this._empty;
            }

            typeof(_front) front()
            {
                assert(!this.empty, "Cannot access front of empty range.");
                return this._front;
            }
        }

        return R(&this);
    }

    // Finds the occupied node holding `key` by scanning the `_probeLimit` slots starting at the
    // key's home slot. Returns null if there is no such node or nothing has been allocated yet.
    private inout(Node)* getNodeAt()(auto ref KeyT key) inout
    {
        if(this._primeIndex == 0)
            return null;
        const index  = toHashToPrimeIndex!(Hasher, KeyT)(key, this._primeIndex - 1);
        auto nodePtr = &this._array[index]; // bypass bounds checking, as our overallocation should ensure this is always in bounds.
        foreach(i; 0..this._probeLimit)
        {
            auto ptr = &nodePtr[i];

            // Vacant slots hold `KeyT.init`; skip them so that looking up a key equal to
            // `KeyT.init` does not report a phantom entry.
            if(ptr.distance != ubyte.max && ptr.key == key)
                return ptr;
        }
        return null;
    }

    // Tries to place `key`/`value` into `array` using Robin Hood probing.
    //
    // The pair being placed is carried in `currKey`/`currValue`. Whenever it has travelled further
    // from its home slot than the resident of the slot under inspection, the two are swapped and
    // probing continues with the displaced resident.
    //
    // Outputs:
    //  alreadyExists = true if an equal key was found and its value overwritten.
    //  wasSwap       = true if at least one resident was displaced.
    //  currKey/currValue = on failure, the pair that is still waiting for a slot.
    //
    // Returns: false if the probe limit (or the end of the array) was reached without placing the
    //          carried pair; true otherwise.
    //
    // Interestingly, LDC automatically inlines this function when optimising o.o
    private bool putInto()(
        ref typeof(_array) array,
        auto ref KeyT key,
        auto ref ValueT value,
        out bool alreadyExists,
        out bool wasSwap,
        ref KeyT currKey,
        ref ValueT currValue
    )
    {
        // Hash first: `key` may be moved out just below, after which hashing it would be
        // hashing a moved-from value.
        const index    = toHashToPrimeIndex!(Hasher, KeyT)(key, this._primeIndex-1);

        static if(MoveKey)   move(key, currKey);     else currKey = key;
        static if(MoveValue) move(value, currValue); else currValue = value;

        const length   = array.length;
        auto arrayPtr  = array[].ptr; // bypass bounds checking as this should in theory be completely @safe to access.
        ubyte distance = 255; // Wraps around to 0 on the first increment below.
        for(size_t i = index; i < length; i++)
        {
            distance++;
            if(distance >= this._probeLimit)
                break;

            auto nodePtr = &arrayPtr[i];

            if(nodePtr.distance == ubyte.max)
            {
                // Vacant slot: the carried pair settles here.
                static if(MoveKey)   move(currKey, nodePtr.key);     else nodePtr.key = currKey;
                static if(MoveValue) move(currValue, nodePtr.value); else nodePtr.value = currValue;
                nodePtr.distance = distance;
                return true;
            }
            else if(nodePtr.distance < distance)
            {
                // The resident is closer to its home than the carried pair: swap them.
                wasSwap = true;

                KeyT   tempKey;
                ValueT tempValue;
                ubyte  tempDistance;

                static if(MoveKey)   move(nodePtr.key, tempKey);     else tempKey = nodePtr.key;
                static if(MoveValue) move(nodePtr.value, tempValue); else tempValue = nodePtr.value;
                tempDistance = nodePtr.distance;

                static if(MoveKey)   move(currKey, nodePtr.key);     else nodePtr.key = currKey;
                static if(MoveValue) move(currValue, nodePtr.value); else nodePtr.value = currValue;
                nodePtr.distance = distance;

                static if(MoveKey)   move(tempKey, currKey);     else currKey   = tempKey;
                static if(MoveValue) move(tempValue, currValue); else currValue = tempValue;
                distance = tempDistance;
            }
            else if(nodePtr.key == currKey)
            {
                // Same key: overwrite the stored value in place.
                static if(MoveKey)   move(currKey, nodePtr.key); // So things stay predictable in terms of what OnMove and such do.
                static if(MoveValue) move(currValue, nodePtr.value); else nodePtr.value = currValue;
                alreadyExists = true;
                return true;
            }
        }

        return false;
    }
}
@("RobinHoodHashMap")
unittest
{
    // Incremented by the postblit and the destructor of `Counted` respectively. Whenever
    // `copyCount - destroyCount` equals N, exactly N copies are still alive.
    uint copyCount;
    uint destroyCount;
    //@(OptimisationHint.preferMoveOverCopy)
    static struct Counted
    {
        @nogc nothrow:
        uint* copies;
        uint* destroys;

        this(this)
        {
            if(copies)
                (*copies)++;
        }

        ~this()
        {
            if(destroys)
                (*destroys)++;
        }
    }

    auto tracked = Counted(&copyCount, &destroyCount);
    RobinHoodHashMap!(string, Counted) byName;

    // A single entry keeps exactly one live copy, released when the map is destroyed.
    byName.put("test", tracked);
    assert(byName.length == 1);
    assert(byName.containsKey("test"));
    assert(!byName.containsKey("tesT"));
    assert(copyCount == destroyCount + 1);
    byName.__xdtor();
    assert(copyCount == destroyCount);

    // Putting the same key twice overwrites instead of adding a second entry.
    copyCount = 0;
    destroyCount = 0;
    emplaceInit(byName);
    byName.put("test", tracked);
    byName.put("test", tracked);
    assert(byName.length == 1);
    assert(copyCount == destroyCount + 1);
    byName.put("test2", tracked);
    assert(byName.length == 2);
    assert(copyCount == destroyCount + 2);
    byName.__xdtor();
    assert(copyCount == destroyCount);

    // Removing an entry releases its value.
    copyCount = 0;
    destroyCount = 0;
    emplaceInit(byName);
    byName.put("test", tracked);
    byName.put("test2", tracked);
    assert(copyCount == destroyCount + 2);
    byName.removeAt("test");
    assert(byName.length == 1);
    assert(copyCount == destroyCount + 1);
    byName.__xdtor();
    assert(copyCount == destroyCount);

    // Growing the table many times must not leak or double-destroy values.
    RobinHoodHashMap!(int, Counted) byNumber;
    copyCount = 0;
    destroyCount = 0;
    foreach(n; 0..10_000)
        byNumber.put(n, tracked);
    byNumber.__xdtor();
    assert(copyCount == destroyCount);

    RobinHoodHashMap!(string, int) numbers;

    // Lookups, in-place mutation through a pointer, and defaults.
    numbers.put("one", 0);
    numbers.put("two", 2);
    assert(numbers.getPtrUnsafeAt("one") !is null);
    *numbers.getPtrUnsafeAt("one") = 1;
    assert(numbers.getAt("one") == 1);
    assert(numbers.getAt("two") == 2);
    assert(numbers.getAtOrDefault("three", 3) == 3);

    // Removal hands back the value and only succeeds once per key.
    int scratch;
    assert(numbers.removeAt("two", scratch));
    assert(scratch == 2);
    assert(!numbers.removeAt("two"));
    assert(numbers.removeAt("one"));

    numbers.put("one", 2);
    numbers.put("two", 4);
    assert(numbers.length == 2);
    assert(numbers.getAt("two") == 4);

    // Manual iteration: two entries in unspecified order, then empty.
    auto pairs = numbers.range;
    assert(!pairs.empty);
    assert(*pairs.front.key == "one" || *pairs.front.key == "two");
    assert(*pairs.front.value == 2 || *pairs.front.value == 4);
    pairs.popFront();
    assert(!pairs.empty);
    assert(*pairs.front.key == "one" || *pairs.front.key == "two");
    assert(*pairs.front.value == 2 || *pairs.front.value == 4);
    pairs.popFront();
    assert(pairs.empty);

    // The range also works with foreach.
    scratch = 0;
    foreach(kvp; numbers.range)
        scratch += *kvp.value;
    assert(scratch == 6);
}
@("RobinHoodHashMap - 10_000 ints")
unittest
{
    // Inserts enough keys to force the table through many growth steps, then checks that every
    // key is still present and maps to the value it was stored with.
    enum AMOUNT = 10_000;
    RobinHoodHashMap!(int, int) h;
    foreach(i; 0..AMOUNT)
    {
        // The length must grow by exactly one per distinct key.
        if(h.length != i)
            assert(false);
        h.put(i, i);
    }
    assert(h.length == AMOUNT);
    foreach(i; 0..AMOUNT)
        assert(h.getAt(i) == i); // Verify the stored value, not merely that the key exists.
}
458 
// Returns the entry of `PrimeNumberForSizeLookup` at `index`, and advances `index` by one so
// that the next call yields the following entry.
@safe @nogc
private size_t nextPrimeSize(ref ubyte index) nothrow pure
{
    const current = index;
    index++;
    return PrimeNumberForSizeLookup[current];
}
464 
// Hashes `value` and reduces the hash modulo the prime found at `primeIndex` in
// `PrimeNumberForSizeLookup`.
//
// Structs that declare a `toHash` member are hashed with it; everything else goes through `Hasher`.
// Asserts if `primeIndex` does not correspond to an entry of the lookup table.
@trusted @nogc
private auto toHashToPrimeIndex(alias Hasher, alias T)(auto ref T value, const int primeIndex) nothrow pure
{
    enum hasOwnHash = is(T == struct) && __traits(hasMember, T, "toHash");

    static if(hasOwnHash)
        const digest = value.toHash();
    else
        const digest = Hasher(value);

    // Division (inc modulo) is super super super slow on unknown numbers.
    // Emitting one case per table entry gives the compiler a compile-time-known divisor in each
    // branch, so it can generate optimised opcodes.
    switch(primeIndex)
    {
        static foreach(slot, divisor; PrimeNumberForSizeLookup)
        {
            case slot:
                return digest % divisor;
        }

        default:
            break;
    }

    assert(false, "Index too high?");
}