1 /**
2     SumTypes, this is mostly taken from libphobos.
3     
4     Copyright:
5         Copyright © 2005-2009, The D Language Foundation.
6         Copyright © 2023-2025, Kitsunebi Games
7         Copyright © 2023-2025, Inochi2D Project
8     
9     License:   $(LINK2 http://www.boost.org/LICENSE_1_0.txt, Boost License 1.0)
10     Authors:
11         Paul Backus,
12         Luna Nielsen
13 */
14 module nulib.data.sumtype;
15 import numem.core.lifetime : forward;
16 import numem.core.traits;
17 import numem.core.meta;
18 import numem;
19 
/**
    Evaluates to `true` when `T` passes the check
    `is(T : SumType!Args, Args...)`, i.e. when it is accepted as a `SumType`
    by `match`, `has` and `get`.
*/
template isSumType(T) {
    enum isSumType = is(T : SumType!Args, Args...);
}
24 
/**
    A tagged union that holds exactly one value out of `Types` at a time.

    The index of the held member is kept in `tag`, whose type is the smallest
    unsigned integer type able to index every entry of `Types`. The held value
    is read through `match`, `has` and `get`.
*/
struct SumType(Types...) if (Types.length > 0) {
private:
@nogc:
    // Smallest unsigned integer type able to index every member type.
    enum bool canHoldTag(T) = Types.length <= T.max;
    alias Tag = Filter!(canHoldTag, AliasSeq!(ubyte, ushort, uint, ulong))[0];

    // Meta-info enums.
    enum hasDtor = anySatisfy!(hasElaborateDestructor, Types);

    // NOTE(review): nothing below is generated from this flag. No copy
    // constructor or postblit exists, so copying a SumType copies `storage`
    // bitwise without running the held value's copy constructor, while `~this`
    // still destroys every copy. Confirm whether such member types are supported.
    enum hasCCtor = anySatisfy!(hasElaborateCopyConstructor, Types);

    // Storage for the held value; member `value_N` holds a `Types[N]`.
    union Storage {
        static foreach(typeId, T; Types) {
            mixin("T ", memberNameOf!typeId, ";");
        }
    }

    Tag tag;
    Storage storage;

    // Returns the held value as `Types[tid]`, asserting that it is the held member.
    @trusted
    auto ref getByIndex(size_t tid)() inout 
    if (tid < Types.length) {
        assert(tag == tid, "Does not contain requested type!");
        return storage.tupleof[tid];
    }

public:

    /**
        The types which the sumtype may contain
    */
    alias AllowedTypes = Types;

    /**
        Destructor; hands the held value to `destroyIfOwner`.
    */
    static if (hasDtor)
    ~this() {
        this.match!destroyIfOwner;
    }

    static foreach(typeId, T; Types) {

        /**
            Constructs a SumType holding `value` as a `Types[typeId]`.

            Movable values are moved in with `move()`; others are forwarded.
        */
        this()(auto ref T value) {
            // Build a whole union through an initializer so the member is
            // constructed in place, instead of calling `T`'s assignment on
            // storage that does not currently hold a `T`.
            static if (isMovable!T) {
                mixin("Storage newStorage = { ", memberNameOf!typeId, ": value.move() };");
            } else {
                mixin("Storage newStorage = { ", memberNameOf!typeId, ": forward!value };");
            }

            storage = newStorage;
            static if (Types.length > 1)
                this.tag = typeId;
        }

        // Assignment
        static if (isAssignable!T) {

            /**
                Replaces the held value with `rhs`.

                The previously held value is passed to `destroyIfOwner` first.
            */
            ref SumType opAssign(T rhs) {
                this.match!destroyIfOwner;

                static if (isMovable!T) {
                    mixin("Storage newStorage = { ", memberNameOf!typeId, ": forward!rhs };");
                } else {
                    mixin("Storage newStorage = { ", memberNameOf!typeId, ": __ctfe ? rhs : forward!rhs };");
                }

                storage = newStorage;
                static if (Types.length > 1)
                    this.tag = typeId;
                
                return this;
            }
        }

    }
}
115 
/**
    Evaluates to `true` when `handler` can be called with one argument of each
    of `Types`, every argument passed by `ref`.

    See the documentation for $(D match) for a full explanation of how matches are
    chosen.
*/
template canMatch(alias handler, Types...)
if (Types.length > 0) {
    // The probe lambda only type-checks when the call `handler(params)` does.
    enum canMatch = is(typeof(
        (ref Types params) => handler(params)
    ));
}
126 
/**
    Calls a type-appropriate function with the value held in a [SumType].

    For each possible type the [SumType] can hold, the given handlers are
    checked, in order, to see whether they accept a single argument of that type.
    The first one that does is chosen as the match for that type. (Note that the
    first match may not always be the most exact match: a handler taking `long`
    listed before one taking `int` also receives held `int` values.)

    Every type must have a matching handler, and every handler must match at
    least one type. This is enforced at compile time.

    Handlers may be functions, delegates, or objects with `opCall` overloads. If
    a function with more than one overload is given as a handler, all of the
    overloads are considered as potential matches.

    Templated handlers are also accepted, and will match any type for which they
    can be [implicitly instantiated](https://dlang.org/glossary.html#ifti).
    (Remember that a $(DDSUBLINK spec/expression,function_literals, function literal)
    without an explicit argument type is considered a template.)

    If multiple [SumType]s are passed to match, their values are passed to the
    handlers as separate arguments, and matching is done for each possible
    combination of value types.

    Returns:
        The value returned from the handler that matches the currently-held type.
*/
template match(handlers...) {
    auto ref match(SumTypes...)(auto ref SumTypes args)
    if (allSatisfy!(isSumType, SumTypes)) {
        // `try_ = false` makes matchImpl reject, at compile time, any held-type
        // combination that no handler accepts, as documented above.
        return matchImpl!(false, handlers)(args);
    }
}

/**
    Checks whether a `SumType` currently holds a `T`.

    Params:
        T = the type to test for. It is compared with `is(Value == T)`, so the
            held type must match `T` exactly, qualifiers included.
*/
template has(T) {

    /**
        The actual `has` function.

        Params:
            self = the `SumType` to check.

        Returns: true if `self` contains a `T`, otherwise false.
    */
    bool has(Self)(auto ref Self self)
    if (isSumType!Self) {
        return self.match!checkType;
    }

    // One generic handler accepting every member type.
    private
    bool checkType(Value)(ref Value value) {
        return is(Value == T);
    }
}

/**
    Accesses a `SumType`'s value.

    The value must be of the specified type; use [has] to check. Accessing any
    other type fails an assertion.

    Params:
        T = the type of the value being accessed.

    Returns:
        A reference to the held value when `self` is an lvalue; otherwise the
        value itself, moved out when `move` applies (copied during CTFE).
*/
template get(T) {
    auto ref T get(Self)(auto ref Self self)
    if (isSumType!Self) {
        // Pass the handler templates themselves. Instantiating them as
        // `getLValue!T` would pin `Value` to `T`, leaving every other
        // member type without a handler.
        static if (__traits(isRef, self))
            return self.match!getLValue;
        else
            return self.match!getRValue;
    }
    
    // Handler for lvalue access: returns the held value by reference.
    private
    ref T getLValue(Value)(ref Value value) {
        static if (is(Value == T))
            return value;
        else
            assert(false, "Could not get value!");
    }
    
    // Handler for rvalue access: moves the held value out when possible.
    private
    T getRValue(Value)(ref Value value) {
        static if (is(Value == T)) {
            static if(is(typeof(move(value))))
                return __ctfe ? value : move(value);
            else
                return value;
        } else {
            // Mirror getLValue instead of silently returning T.init.
            assert(false, "Could not get value!");
        }
    }
}
221 
222 private:
223 
/*
    Produces the compile-time sequence `0, 1, ..., n - 1` as an `AliasSeq`.
    `iota!0` is the empty sequence.
*/
template iota(size_t n) {
    alias iota = AliasSeq!();

    // Append the loop index itself (previously `n` was appended every time).
    static foreach(i; 0..n)
        iota = AliasSeq!(iota, i);
}
230 
/*
    Builds the field name used for member `id` of the storage union, e.g.
    `value_0`, `value_1`; used for procedural member generation.
*/
template memberNameOf(size_t id) {
    import std.conv : to;
    enum string memberNameOf = "value_" ~ id.to!string;
}
239 
240 /*
241     A TagTuple represents a single possible set of tags that the arguments to
242     `matchImpl` could have at runtime.
243 
244     Because D does not allow a struct to be the controlling expression
245     of a switch statement, we cannot dispatch on the TagTuple directly.
246     Instead, we must map each TagTuple to a unique integer and generate
247     a case label for each of those integers.
248 
249     This mapping is implemented in `fromCaseId` and `toCaseId`. It uses
250     the same technique that's used to map index tuples to memory offsets
251     in a multidimensional static array.
252 
253     For example, when `args` consists of two SumTypes with two member
254     types each, the TagTuples corresponding to each case label are:
255 
256     case 0:  TagTuple([0, 0])
257     case 1:  TagTuple([1, 0])
258     case 2:  TagTuple([0, 1])
259     case 3:  TagTuple([1, 1])
260 
261     When there is only one argument, the caseId is equal to that
262     argument's tag.
263 */
struct TagTuple(typeCounts...) {
    size_t[typeCounts.length] tags;
    alias tags this;
    
    // Case-ID weight of dimension `i`: the product of the type counts of all
    // dimensions before it. (Was malformed: unnamed parameter and a
    // runtime-call syntax in place of a template instantiation.)
    alias stride(size_t i) = .stride!(i, typeCounts);

    invariant {
        static foreach(i; 0..tags.length)
            assert(tags[i] < typeCounts[i], "Invalid tag");
    }

    // Captures the current tag of each argument, in order.
    this(SumTypes...)(ref const SumTypes args)
    if (allSatisfy!(isSumType, SumTypes) && args.length == typeCounts.length) {
        static foreach (i; 0..tags.length)
            tags[i] = args[i].tag;
    }
    
    // Decodes a case ID produced by `toCaseId` back into per-argument tags.
    static TagTuple fromCaseId(size_t caseId) {
        TagTuple result;

        // Most-significant to least-significant
        static foreach_reverse (i; 0..result.length) {
            result[i] = caseId / stride!i;
            caseId %= stride!i;
        }

        return result;
    }

    // Encodes the tags as a single case ID: sum of tags[i] * stride!i.
    size_t toCaseId() {
        size_t result;
        static foreach (i; 0..tags.length)
            result += tags[i] * stride!i;

        return result;
    }
}
300 
// Gets the number of member types of a SumType, for use with staticMap.
// Reads the public `AllowedTypes` alias, as matchImpl does for single dispatch.
enum countSumType(T) = T.AllowedTypes.length;
303 
/*
    Returns the product of the first `dim` entries of `lengths` (1 when `dim`
    is 0), asserting that the multiplication does not overflow `size_t`.
*/
size_t stride(size_t dim, lengths...)() {
    import core.checkedint : mulu;

    bool overflowed = false;
    size_t product = 1;
    static foreach (d; 0 .. dim)
        product = mulu(product, lengths[d], overflowed);

    assert(!overflowed, "Integer overflow");
    return product;
}
316 
/*
    Builds, as D source text, the comma-separated argument list that
    `matchImpl` mixes into a multi-dispatch handler call for `caseId`:
    one `args[i].getByIndex!(tag)()` per argument, with the tags decoded
    from `caseId` by `TagTuple.fromCaseId`.
*/
template handlerArgs(size_t caseId, typeCounts...) {
    import std.format : format;

    enum tagTuple = TagTuple!typeCounts.fromCaseId(caseId);

    // A single string (not a sequence of strings) so that the mixin in
    // matchImpl receives properly comma-separated arguments.
    enum handlerArgs = () {
        string result;
        foreach (i, tag; tagTuple.tags) {
            if (i > 0)
                result ~= ", ";
            result ~= "args[%s].getByIndex!(%s)()".format(i, tag);
        }
        return result;
    }();
}
330 
template matchImpl(bool try_, handlers...) {
    // CTFE helper for the "handler never matches" check below.
    bool canFind(T)(T[] haystack, T needle) @nogc {
        foreach(ref item; haystack)
            if (item == needle)
                return true;

        return false;
    }

    /*
        Dispatches `args` to the first handler, in declaration order, that
        accepts the types currently held by every argument.

        If `try_` is false, a held-type combination that no handler accepts is
        a compile-time error. If `try_` is true such combinations compile, and
        reaching one at runtime fails an assertion.
    */
    auto ref matchImpl(SumTypes...)(auto ref SumTypes args)
    if (allSatisfy!(isSumType, SumTypes) && args.length > 0) {
        import std.format : format;

        // Generate dispatch.
        static if (args.length == 1) {

            // Single dispatch: the case ID is the argument's tag.
            enum handlerArgs(size_t caseId) = "args[0].getByIndex!(%s)()".format(caseId);
            enum numCases = SumTypes[0].AllowedTypes.length;
            alias valueTypes(size_t caseId) = typeof(args[0].getByIndex!(caseId)());
        } else {

            // Multi-dispatch: each case ID encodes one tag per argument (see TagTuple).
            alias typeCount = staticMap!(countSumType, SumTypes);
            alias stride(size_t i) = .stride!(i, typeCount);
            alias TagTuple = .TagTuple!typeCount;
            alias handlerArgs(size_t caseId) = .handlerArgs!(caseId, typeCount);

            template valueTypes(size_t caseId) {
                enum tags = TagTuple.fromCaseId(caseId);
                alias getType(size_t i) = typeof(args[i].getByIndex!(tags[i])());
                alias valueTypes = staticMap!(getType, iota!(tags.length));
            }

            enum numCases = stride!(SumTypes.length);
        }

        // No-match ID.
        enum noMatch = size_t.max;

        // Static array that maps case IDs to the first accepting handler's ID.
        enum matches = () {
            size_t[numCases] result;
            foreach(ref match; result) {
                match = noMatch;
            }

            static foreach(caseId; 0..numCases) {
                static foreach(handlerId, handler; handlers) {
                    static if (canMatch!(handler, valueTypes!caseId)) {
                        if (result[caseId] == noMatch)
                            result[caseId] = handlerId;
                    }
                }
            }

            return result;
        }();

        static foreach(handlerId, handler; handlers) {
            static assert(canFind(matches[], handlerId), "Handler "~typeof(handler).stringof~" never matches.");
        }

        enum handlerName(size_t hid) = "handler%s".format(hid);
        static foreach(size_t hid, handler; handlers) {
            mixin("alias ", handlerName!hid, " = handler;");
        }

        static if (args.length == 1)
            immutable argsId = args[0].tag;
        else
            immutable argsId = TagTuple(args).toCaseId;
        
        final switch(argsId) {
            static foreach(caseId; 0..numCases) {
                case caseId:
                    static if (matches[caseId] != noMatch) {
                        return mixin(handlerName!(matches[caseId]), "(", handlerArgs!caseId, ")");
                    } else static if (!try_) {
                        static assert(false, "No matching handler for types `"~valueTypes!(caseId).stringof~"`!");
                    } else {
                        // An empty case body would fall through into the next
                        // case and read a union member that is not held.
                        assert(false, "No matching handler for the held value!");
                    }
            }
        }

        assert(false, "unreachable");
    }
}
420 
/*
    Match handler used by SumType to tear down whichever member it holds;
    forwards the value to `nogc_delete`.
*/
void destroyIfOwner(T)(ref T value) {
    value.nogc_delete();
}