1 module lumars.state;
2 
3 import bindbc.lua, std, taggedalgebraic, lumars;
4 import taggedalgebraic : visit;
5 
/// Empty marker type that stands in for LUA's `nil` value on the D side.
struct LuaNil
{
}
8 
/// The set of alternatives stored by `LuaValue`; see `LuaValue` for the tagged form.
///
/// NOTE: the field order is significant - `TaggedUnion` derives its `Kind` enum from it,
/// so fields must not be reordered.
union LuaValueUnion
{
    /// The LUA `nil` value.
    LuaNil nil;

    /// A LUA number.
    lua_Number number;

    /// Text that is owned by LUA rather than D's GC. It is a weak reference, so it is
    /// unsafe to let it escape.
    const(char)[] textWeak;

    /// Text that is owned by D's GC.
    string text;

    /// A boolean.
    bool boolean;

    /// Weak reference to a table that currently lives on the LUA stack.
    LuaTableWeak tableWeak;

    /// Strong reference to a table that is kept in the LUA registry.
    LuaTable table;

    /// Weak reference to a function that currently lives on the LUA stack.
    LuaFuncWeak funcWeak;

    /// Strong reference to a function that is kept in the LUA registry.
    LuaFunc func;

    /// An opaque pointer (`LuaState` maps this alternative to LUA's light userdata).
    void* userData;
}
41 
/// Status codes that LUA may hand back from calls such as `lua_pcall`.
enum LuaStatus
{
    /// No error (always zero).
    ok        = 0,
    /// Mirrors `LUA_YIELD`.
    yield     = LUA_YIELD,
    /// Mirrors `LUA_ERRRUN`.
    errRun    = LUA_ERRRUN,
    /// Mirrors `LUA_ERRSYNTAX`.
    errSyntax = LUA_ERRSYNTAX,
    /// Mirrors `LUA_ERRMEM`.
    errMem    = LUA_ERRMEM,
    /// Mirrors `LUA_ERRERR`.
    errErr    = LUA_ERRERR,
}
52 
/// A `TaggedUnion` over `LuaValueUnion`, used to bridge the gap between D values and LUA values.
alias LuaValue = TaggedUnion!LuaValueUnion;

/// D-style name for LUA's number type, `lua_Number`.
alias LuaNumber = lua_Number;

/// D-style name for LUA's C function pointer type, `lua_CFunction`.
alias LuaCFunc = lua_CFunction;
57 
/++
 + A light wrapper around `lua_State` with some higher level functions for quality of life purposes.
 +
 + This struct cannot be copied (its postblit is disabled), so put it on the heap or store it as a global.
 +
 + Most members are one-line forwards to the LUA C API function named in their documentation.
 + ++/
struct LuaState
{
    @disable this(this);

    private
    {
        lua_State*  _handle;
        bool        _isWrapper; // true when `_handle` was supplied by the caller and must not be closed.
    }

    /++
     + Creates a wrapper around the given `lua_State`, or creates a new state (with the
     + standard libraries opened) if the given value is null.
     +
     + Params:
     +  wrapAround = An existing state to wrap, or `null` to create a fresh one.
     + ++/
    @trusted @nogc
    this(lua_State* wrapAround) nothrow
    {
        if(wrapAround)
        {
            this._handle = wrapAround;
            this._isWrapper = true;
            return;
        }

        this._isWrapper = false;
        this._handle = luaL_newstate();
        assert(this._handle !is null, "luaL_newstate failed to create a new LUA state.");
        luaL_openlibs(this._handle);
    }

    /// For non-wrappers, destroy the lua state.
    @trusted @nogc
    ~this() nothrow
    {
        if(this._handle && !this._isWrapper)
            lua_close(this._handle);
    }

    /++
     + Returns: A pseudo-table bound to `LUA_GLOBALSINDEX` of this state.
     +
     + The table is built on every call rather than cached at construction time: a value
     + cached in the constructor would hold the address the struct had *during*
     + construction, which dangles as soon as the struct is moved.
     + ++/
    @nogc
    LuaTablePseudo globalTable() nothrow
    {
        return LuaTablePseudo(&this, LUA_GLOBALSINDEX);
    }

    /// Forwards to `lua_atpanic`. Returns: whatever `lua_atpanic` returns.
    @nogc
    lua_CFunction atPanic(lua_CFunction func) nothrow
    {
        return lua_atpanic(this.handle, func);
    }

    /// Forwards to `lua_call`.
    @nogc
    void call(int nargs, int nresults) nothrow
    {
        lua_call(this.handle, nargs, nresults);
    }

    /// Forwards to `lua_checkstack`. Returns: `true` if it returned non-zero.
    @nogc
    bool checkStack(int amount) nothrow
    {
        return lua_checkstack(this.handle, amount) != 0;
    }

    /// Forwards to `lua_concat`.
    @nogc
    void concat(int nargs) nothrow
    {
        lua_concat(this.handle, nargs);
    }

    /// Forwards to `lua_equal`. Returns: `true` if it returned non-zero.
    @nogc
    bool equal(int index1, int index2) nothrow
    {
        return lua_equal(this.handle, index1, index2) != 0;
    }

    /// Forwards to `lua_error`.
    @nogc
    void error() nothrow
    {
        lua_error(this.handle);
    }

    /// Forwards to `luaL_error`, using `msg` as the (already formatted) error message.
    void error(const char[] msg) nothrow
    {
        luaL_error(this.handle, "%s", msg.toStringz);
    }

    /++
     + Pushes the metatable of the value at `ofIndex` and returns a weak reference to it.
     +
     + Throws: `Exception` if `lua_getmetatable` reports that there is no metatable. In that
     +         case nothing was pushed, so there is no table to reference.
     + ++/
    LuaTableWeak pushMetatable(int ofIndex)
    {
        enforce(
            lua_getmetatable(this.handle, ofIndex) != 0,
            "Value at stack index %s does not have a metatable".format(ofIndex)
        );
        return LuaTableWeak(&this, -1);
    }

    /++
     + Returns a strong reference to the metatable of the value at `ofIndex`.
     +
     + Throws: `Exception` if `lua_getmetatable` reports that there is no metatable. In that
     +         case nothing was pushed, so making a reference would consume an unrelated value.
     + ++/
    LuaTable getMetatable(int ofIndex)
    {
        enforce(
            lua_getmetatable(this.handle, ofIndex) != 0,
            "Value at stack index %s does not have a metatable".format(ofIndex)
        );
        return LuaTable.makeRef(&this);
    }

    /// Forwards to `lua_lessthan`. Returns: `true` if it returned non-zero.
    @nogc
    bool lessThan(int index1, int index2) nothrow
    {
        return lua_lessthan(this.handle, index1, index2) != 0;
    }

    /// Forwards to `lua_rawequal`. Returns: `true` if it returned non-zero.
    @nogc
    bool rawEqual(int index1, int index2) nothrow
    {
        return lua_rawequal(this.handle, index1, index2) != 0;
    }

    /// Forwards to `lua_gettable`.
    @nogc
    void pushTable(int tableIndex) nothrow
    {
        lua_gettable(this.handle, tableIndex);
    }

    /// Forwards to `lua_insert`.
    @nogc
    void insert(int index) nothrow
    {
        lua_insert(this.handle, index);
    }

    /// Forwards to `lua_objlen`.
    @nogc
    size_t len(int index) nothrow
    {
        return lua_objlen(this.handle, index);
    }

    /// Forwards to `lua_pcall`. Returns: its result converted to a `LuaStatus`.
    @nogc
    LuaStatus pcall(int nargs, int nresults, int errFuncIndex) nothrow
    {
        return cast(LuaStatus)lua_pcall(this.handle, nargs, nresults, errFuncIndex);
    }

    /// Forwards to `lua_pushvalue`.
    @nogc
    void copy(int index) nothrow
    {
        lua_pushvalue(this.handle, index);
    }

    /// Forwards to `lua_rawget`.
    @nogc
    void rawGet(int tableIndex) nothrow
    {
        lua_rawget(this.handle, tableIndex);
    }

    /// Forwards to `lua_rawgeti`.
    @nogc
    void rawGet(int tableIndex, int indexIntoTable) nothrow
    {
        lua_rawgeti(this.handle, tableIndex, indexIntoTable);
    }

    /// Forwards to `lua_rawset`.
    @nogc
    void rawSet(int tableIndex) nothrow
    {
        lua_rawset(this.handle, tableIndex);
    }

    /// Forwards to `lua_rawseti`.
    @nogc
    void rawSet(int tableIndex, int indexIntoTable) nothrow
    {
        lua_rawseti(this.handle, tableIndex, indexIntoTable);
    }

    /// Forwards to `lua_getglobal` with a zero-terminated copy of `name`.
    void getGlobal(const char[] name)
    {
        lua_getglobal(this.handle, name.toStringz);
    }

    /// Forwards to `lua_setglobal` with a zero-terminated copy of `name`.
    void setGlobal(const char[] name)
    {
        lua_setglobal(this.handle, name.toStringz);
    }

    /// Forwards to `lua_register`, registering the raw C function `func` under `name`.
    void register(const char[] name, LuaCFunc func) nothrow
    {
        lua_register(this.handle, name.toStringz, func);
    }

    /// Registers the D callable `Func` under `name`, wrapped by `luaCWrapperSmart`.
    void register(alias Func)(const char[] name)
    {
        this.register(name, &luaCWrapperSmart!Func);
    }

    /++
     + Registers a library of functions via `luaL_register`.
     +
     + `Args` is a flat list of `name, function` pairs; each function is wrapped by
     + `luaCWrapperSmart`.
     +
     + NOTE(review): per the LUA 5.1 manual `luaL_register` leaves the library table on top
     + of the stack; this function does not pop it - confirm whether callers rely on that.
     + ++/
    void register(Args...)(const char[] libname)
    if(Args.length % 2 == 0)
    {
        // One extra, default-initialised entry terminates the list.
        luaL_Reg[(Args.length / 2) + 1] reg;

        static foreach(i; 0..Args.length/2)
            reg[i] = luaL_Reg(Args[i*2].ptr, &luaCWrapperSmart!(Args[i*2+1]));

        luaL_register(this.handle, libname.toStringz, reg.ptr);
    }

    /// Forwards to `lua_remove`.
    @nogc
    void remove(int index) nothrow
    {
        lua_remove(this.handle, index);
    }

    /// Forwards to `lua_replace`.
    @nogc
    void replace(int index) nothrow
    {
        lua_replace(this.handle, index);
    }

    /// Forwards to `lua_setmetatable`.
    @nogc
    void setMetatable(int ofIndex) nothrow
    {
        lua_setmetatable(this.handle, ofIndex);
    }

    /// Forwards to `luaL_argcheck`; `extraMsg` is optional.
    void checkArg(bool condition, int argNum, const(char)[] extraMsg = null) nothrow
    {
        luaL_argcheck(this.handle, condition ? 1 : 0, argNum, extraMsg ? extraMsg.toStringz : null);
    }

    /// Forwards to `luaL_callmeta`. Returns: `true` if it returned non-zero.
    bool callMetamethod(int index, const char[] method) nothrow
    {
        return luaL_callmeta(this.handle, index, method.toStringz) != 0;
    }

    /// Forwards to `luaL_checkany`.
    @nogc
    void checkAny(int arg) nothrow
    {
        luaL_checkany(this.handle, arg);
    }

    /// Forwards to `luaL_checkinteger`.
    @nogc
    ptrdiff_t checkInt(int arg) nothrow
    {
        return luaL_checkinteger(this.handle, arg);
    }

    /// Forwards to `luaL_checklstring`. Returns: a slice over the memory it returned (not a copy).
    @nogc
    const(char)[] checkStringWeak(int arg) nothrow
    {
        size_t len;
        const ptr = luaL_checklstring(this.handle, arg, &len);
        return ptr[0..len];
    }

    /// Same as `checkStringWeak`, but returns a GC-allocated copy.
    string checkString(int arg) nothrow
    {
        return checkStringWeak(arg).idup;
    }

    /// Forwards to `luaL_checknumber`.
    @nogc
    LuaNumber checkNumber(int arg) nothrow
    {
        return luaL_checknumber(this.handle, arg);
    }

    /// Maps `type` onto the matching `LUA_T*` constant and forwards to `luaL_checktype`.
    @nogc
    void checkType(LuaValue.Kind type, int arg) nothrow
    {
        int t;

        final switch(type) with(LuaValue.Kind)
        {
            case nil: t = LUA_TNIL; break;
            case number: t = LUA_TNUMBER; break;
            case textWeak:
            case text: t = LUA_TSTRING; break;
            case boolean: t = LUA_TBOOLEAN; break;
            case tableWeak:
            case table: t = LUA_TTABLE; break;
            case funcWeak:
            case func: t = LUA_TFUNCTION; break;
            case userData: t = LUA_TLIGHTUSERDATA; break;
        }

        luaL_checktype(this.handle, arg, t);
    }

    /++
     + Runs `luaL_dofile` on `file`.
     +
     + Throws: `Exception` carrying the LUA error message if the call does not report `LuaStatus.ok`.
     + ++/
    void doFile(const char[] file)
    {
        const status = luaL_dofile(this.handle, file.toStringz);
        if(status != LuaStatus.ok)
            this.throwErrorOnTop();
    }

    /++
     + Runs `luaL_dostring` on `str`.
     +
     + Throws: `Exception` carrying the LUA error message if the call does not report `LuaStatus.ok`.
     + ++/
    void doString(const char[] str)
    {
        const status = luaL_dostring(this.handle, str.toStringz);
        if(status != LuaStatus.ok)
            this.throwErrorOnTop();
    }

    /++
     + Runs `luaL_loadfile` on `file`.
     +
     + Throws: `Exception` carrying the LUA error message if the call does not report `LuaStatus.ok`.
     + ++/
    void loadFile(const char[] file)
    {
        const status = luaL_loadfile(this.handle, file.toStringz);
        if(status != LuaStatus.ok)
            this.throwErrorOnTop();
    }

    /++
     + Runs `luaL_loadstring` on `str`.
     +
     + Throws: `Exception` carrying the LUA error message if the call does not report `LuaStatus.ok`.
     + ++/
    void loadString(const char[] str)
    {
        const status = luaL_loadstring(this.handle, str.toStringz);
        if(status != LuaStatus.ok)
            this.throwErrorOnTop();
    }

    /// Forwards to `luaL_optinteger`.
    @nogc
    ptrdiff_t optInt(int arg, ptrdiff_t default_) nothrow
    {
        return luaL_optinteger(this.handle, arg, default_);
    }

    /// Forwards to `luaL_optnumber`.
    @nogc
    LuaNumber optNumber(int arg, LuaNumber default_) nothrow
    {
        return luaL_optnumber(this.handle, arg, default_);
    }

    /// Debug helper: writes one line per stack slot (index, type tag and, where available, value) to stdout.
    void printStack()
    {
        writeln("[LUA STACK]");
        foreach(slot; 1..this.top + 1)
        {
            const type = lua_type(this.handle, slot);
            writef("\t[%s] \t", slot);

            switch(type)
            {
                case LUA_TBOOLEAN: writefln("%s\t%s", "BOOL", this.get!bool(slot)); break;
                case LUA_TFUNCTION: writefln("%s\t%s", "FUNC", lua_tocfunction(this.handle, slot)); break;
                case LUA_TLIGHTUSERDATA: writefln("%s\t%s", "LIGHT", lua_touserdata(this.handle, slot)); break;
                case LUA_TNIL: writefln("%s", "NIL"); break;
                case LUA_TNUMBER: writefln("%s\t%s", "NUM", this.get!lua_Number(slot)); break;
                case LUA_TSTRING: writefln("%s\t%s", "STR", this.get!(const(char)[])(slot)); break;
                case LUA_TTABLE: writefln("%s", "TABL"); break;
                case LUA_TTHREAD: writefln("%s\t%s", "THRD", lua_tothread(this.handle, slot)); break;
                case LUA_TUSERDATA: writefln("%s\t%s", "USER", lua_touserdata(this.handle, slot)); break;
                default: writefln("%s\t%s", "UNKN", type); break;
            }
        }
    }

    /++
     + Pushes a D value onto the LUA stack, converting it according to its type:
     +
     + $(UL
     +  $(LI `null` / `LuaNil` -> nil)
     +  $(LI anything convertible to `const(char)[]` -> string)
     +  $(LI numeric types -> number)
     +  $(LI `bool` -> boolean)
     +  $(LI dynamic arrays -> table with keys `1..length`)
     +  $(LI associative arrays -> table of key/value pairs)
     +  $(LI `LuaTable`, `LuaFunc`, `LuaTableWeak`, `LuaFuncWeak` -> the referenced value)
     +  $(LI `lua_CFunction` -> C function)
     +  $(LI delegates and function pointers -> C closure around `luaCWrapperSmart`)
     +  $(LI other pointers and class references -> light userdata)
     +  $(LI structs -> table with one entry per member)
     + )
     + ++/
    void push(T)(T value)
    {
        static if(is(T == typeof(null)) || is(T == LuaNil))
            lua_pushnil(this.handle);
        else static if(is(T : const(char)[]))
            lua_pushlstring(this.handle, value.ptr, value.length);
        else static if(isNumeric!T)
            lua_pushnumber(this.handle, value.to!lua_Number);
        else static if(is(T : const(bool)))
            lua_pushboolean(this.handle, value ? 1 : 0);
        else static if(isDynamicArray!T)
        {
            // `lua_createtable(L, narr, nrec)`: the elements go into the array part,
            // so the size hint belongs in `narr`.
            lua_createtable(this.handle, value.length.to!int, 0);
            foreach(i, v; value)
            {
                this.push(v);
                lua_rawseti(this.handle, -2, cast(int)i+1);
            }
        }
        else static if(isAssociativeArray!T)
        {
            lua_createtable(this.handle, 0, value.length.to!int);
            foreach(k, v; value)
            {
                this.push(k);
                this.push(v);
                lua_rawset(this.handle, -3);
            }
        }
        else static if(is(T == LuaTable) || is(T == LuaFunc))
            value.push();
        else static if(is(T == LuaTableWeak) || is(T == LuaFuncWeak))
            this.copy(value.push());
        else static if(is(T : lua_CFunction))
            lua_pushcfunction(this.handle, value);
        else static if(isDelegate!T)
        {
            // Context pointer and function pointer become the closure's two upvalues.
            lua_pushlightuserdata(this.handle, value.ptr);
            lua_pushlightuserdata(this.handle, value.funcptr);
            lua_pushcclosure(this.handle, &luaCWrapperSmart!(T, LuaFuncWrapperType.isDelegate), 2);
        }
        else static if(isPointer!T && isFunction!(PointerTarget!T))
        {
            lua_pushlightuserdata(this.handle, value);
            lua_pushcclosure(this.handle, &luaCWrapperSmart!(T, LuaFuncWrapperType.isFunction), 1);
        }
        else static if(isPointer!T)
            lua_pushlightuserdata(this.handle, value);
        else static if(is(T == class))
            lua_pushlightuserdata(this.handle, cast(void*)value);
        else static if(is(T == struct))
        {
            lua_newtable(this.handle);

            static foreach(member; __traits(allMembers, T))
            {
                this.push(member);
                this.push(mixin("value."~member));
                lua_settable(this.handle, -3);
            }
        }
        else static assert(false, "Don't know how to push type: "~T.stringof);
    }

    /// Pushes whichever alternative `value` currently holds.
    void push(LuaValue value)
    {
        value.visit!(
            (_){ this.push(_); }
        );
    }

    /// Forwards to `lua_gettop`.
    @nogc
    int top() nothrow
    {
        return lua_gettop(this.handle);
    }

    /// Forwards to `lua_pop`.
    @nogc
    void pop(int amount) nothrow
    {
        lua_pop(this.handle, amount);
    }

    /++
     + Converts the LUA value at `index` into a `T`. The value is left on the stack.
     +
     + For arrays, associative arrays and structs the table is walked with `next`; if a
     + nested conversion throws, the key/value pair pushed by that walk is popped again so
     + the stack is left as it was found.
     +
     + Throws: `Exception` if the value at `index` is not of the LUA type required by `T`
     +         (see `enforceType`), or if a table key does not fit an array being read.
     + ++/
    T get(T)(int index)
    {
        static if(is(T == string))
        {
            this.enforceType(LuaValue.Kind.text, index);
            size_t len;
            auto ptr = lua_tolstring(this.handle, index, &len);
            return ptr[0..len].idup;
        }
        else static if(is(T == const(char)[]))
        {
            this.enforceType(LuaValue.Kind.text, index);
            size_t len;
            auto ptr = lua_tolstring(this.handle, index, &len);
            return ptr[0..len];
        }
        else static if(is(T : const(bool)))
        {
            this.enforceType(LuaValue.Kind.boolean, index);
            return lua_toboolean(this.handle, index) != 0;
        }
        else static if(isNumeric!T)
        {
            this.enforceType(LuaValue.Kind.number, index);
            return lua_tonumber(this.handle, index).to!T;
        }
        else static if(is(T == typeof(null)) || is(T == LuaNil))
        {
            this.enforceType(LuaValue.Kind.nil, index);
            // `T.init` is `null` for `typeof(null)` and `LuaNil()` for `LuaNil`.
            return T.init;
        }
        else static if(is(T == LuaTableWeak))
        {
            this.enforceType(LuaValue.Kind.table, index);
            return T(&this, index);
        }
        else static if(is(T == LuaTable))
        {
            this.enforceType(LuaValue.Kind.table, index);
            this.copy(index);
            return T.makeRef(&this);
        }
        else static if(isDynamicArray!T)
        {
            this.enforceType(LuaValue.Kind.table, index);
            T ret;
            ret.length = lua_objlen(this.handle, index);

            this.push(null);
            const tableIndex = this.indexAfterPush(index);
            while(this.next(tableIndex))
            {
                // Key and value are on the stack here; drop both if a conversion throws.
                scope(failure) this.pop(2);

                const luaIndex = this.get!size_t(-2);
                enforce(
                    luaIndex >= 1 && luaIndex <= ret.length,
                    "Table key %s is outside of the array range 1..%s".format(luaIndex, ret.length)
                );
                ret[luaIndex - 1] = this.get!(typeof(ret[0]))(-1);
                this.pop(1);
            }

            return ret;
        }
        else static if(isAssociativeArray!T)
        {
            this.enforceType(LuaValue.Kind.table, index);
            T ret;

            this.push(null);
            const tableIndex = this.indexAfterPush(index);
            while(this.next(tableIndex))
            {
                // Key and value are on the stack here; drop both if a conversion throws.
                scope(failure) this.pop(2);

                ret[this.get!(KeyType!T)(-2)] = this.get!(ValueType!T)(-1);
                this.pop(1);
            }

            return ret;
        }
        else static if(is(T == LuaCFunc))
        {
            this.enforceType(LuaValue.Kind.func, index);
            return lua_tocfunction(this.handle, index);
        }
        else static if(is(T == LuaFuncWeak))
        {
            this.enforceType(LuaValue.Kind.func, index);
            return LuaFuncWeak(&this, index);
        }
        else static if(is(T == LuaFunc))
        {
            this.enforceType(LuaValue.Kind.func, index);
            this.copy(index);
            return T.makeRef(&this);
        }
        else static if(isPointer!T || is(T == class))
        {
            this.enforceType(LuaValue.Kind.userData, index);
            return cast(T)lua_touserdata(this.handle, index);
        }
        else static if(is(T == LuaValue))
        {
            switch(this.type(index))
            {
                case LuaValue.Kind.text: return LuaValue(this.get!string(index));
                case LuaValue.Kind.number: return LuaValue(this.get!lua_Number(index));
                case LuaValue.Kind.boolean: return LuaValue(this.get!bool(index));
                case LuaValue.Kind.nil: return LuaValue(this.get!LuaNil(index));
                case LuaValue.Kind.table: return LuaValue(this.get!LuaTable(index));
                case LuaValue.Kind.func: return LuaValue(this.get!LuaFunc(index));
                case LuaValue.Kind.userData: return LuaValue(this.get!(void*)(index));
                default: throw new Exception("Don't know how to convert type into a LuaValue: "~this.type(index).to!string);
            }
        }
        else static if(is(T == struct))
        {
            this.enforceType(LuaValue.Kind.table, index);
            T ret;

            this.push(null);
            const tableIndex = this.indexAfterPush(index);
            While: while(this.next(tableIndex))
            {
                // Key and value are on the stack here; drop both if a conversion throws.
                scope(failure) this.pop(2);

                const field = this.get!(const(char)[])(-2);

                static foreach(member; __traits(allMembers, T))
                {
                    if(field == member)
                    {
                        mixin("ret."~member~"= this.get!(typeof(ret."~member~"))(-1);");
                        this.pop(1);
                        continue While;
                    }
                }

                // Key does not name a member of `T`: skip it.
                this.pop(1);
            }
            return ret;
        }
        else static assert(false, "Don't know how to convert any LUA values into type: "~T.stringof);
    }

    /// Forwards to `lua_next`. Returns: `true` if it returned non-zero.
    @nogc
    bool next(int index) nothrow
    {
        this.assertIndex(index);
        return lua_next(this.handle, index) != 0;
    }

    /++
     + Checks that the value at `index` has the `expected` kind.
     +
     + Throws: `Exception` naming the index, the expected kind and the actual kind on mismatch.
     + ++/
    void enforceType(LuaValue.Kind expected, int index)
    {
        const type = this.type(index);
        enforce(type == expected, "Expected value at stack index %s to be of type %s but it is %s".format(
            index, expected, type
        ));
    }

    /++
     + Returns: The kind of the value at `index`, derived from `lua_type`.
     +
     + Only the strong kinds (`text`, `table`, `func`) are ever returned, never their weak
     + counterparts. Any LUA type without a matching kind (e.g. `LUA_TUSERDATA`,
     + `LUA_TTHREAD`) is reported as `nil`.
     + ++/
    @nogc
    LuaValue.Kind type(int index) nothrow
    {
        // Pseudo-indices do not refer to stack slots, so an empty stack is fine for them.
        assert(isPseudoIndex(index) || this.top > 0, "Stack is empty.");
        this.assertIndex(index);
        const type = lua_type(this.handle, index);

        switch(type)
        {
            case LUA_TBOOLEAN: return LuaValue.Kind.boolean;
            case LUA_TNIL: return LuaValue.Kind.nil;
            case LUA_TNUMBER: return LuaValue.Kind.number;
            case LUA_TSTRING: return LuaValue.Kind.text;
            case LUA_TTABLE: return LuaValue.Kind.table;
            case LUA_TFUNCTION: return LuaValue.Kind.func;
            case LUA_TLIGHTUSERDATA: return LuaValue.Kind.userData;

            default:
                return LuaValue.Kind.nil;
        }
    }

    /// Returns: The underlying `lua_State*`.
    @property @safe @nogc
    inout(lua_State*) handle() nothrow pure inout
    {
        return this._handle;
    }

    /// Returns: `true` if `index` is a pseudo-index (registry, environment, globals, upvalues)
    ///          rather than a stack position.
    @nogc
    private static bool isPseudoIndex(int index) nothrow
    {
        return index <= LUA_REGISTRYINDEX;
    }

    /// Returns: The index that refers to the same value as `index` once one more value has
    ///          been pushed. Only relative (negative, non-pseudo) indices shift.
    @nogc
    private int indexAfterPush(int index) nothrow
    {
        return (index < 0 && !isPseudoIndex(index)) ? index - 1 : index;
    }

    /// Asserts that `index` is a pseudo-index or refers to an occupied stack slot.
    @nogc
    private void assertIndex(int index) nothrow
    {
        if(isPseudoIndex(index))
            return;

        assert(index != 0, "Index 0 is never a valid stack index");

        if(index > 0)
            assert(this.top >= index, "Index out of bounds");
        else
            assert(this.top + index >= 0, "Index out of bounds");
    }

    /// Pops the error value on top of the stack and throws it as an `Exception`.
    /// Error values that are not strings get a generic message instead of a confusing
    /// type-mismatch error, and are still popped.
    private void throwErrorOnTop()
    {
        string message;
        if(lua_type(this.handle, -1) == LUA_TSTRING)
            message = this.get!string(-1);
        else
            message = "LUA raised an error value that is not a string.";

        this.pop(1);
        throw new Exception(message);
    }
}
693 
// Round-trips nil, booleans, numbers, strings and arrays through the LUA stack.
unittest
{
    auto state = LuaState(null);

    // D's `null` becomes nil.
    state.push(null);
    assert(state.type(-1) == LuaValue.Kind.nil);
    assert(state.get!LuaValue(-1).kind == LuaValue.Kind.nil);
    state.pop(1);

    // So does `LuaNil`.
    state.push(LuaNil());
    assert(state.type(-1) == LuaValue.Kind.nil);
    assert(state.get!LuaValue(-1).kind == LuaValue.Kind.nil);
    state.pop(1);

    // Booleans.
    state.push(false);
    assert(state.get!LuaValue(-1).kind == LuaValue.Kind.boolean);
    assert(!state.get!bool(-1));
    state.pop(1);

    // Numbers.
    state.push(20);
    assert(state.get!LuaValue(-1).kind == LuaValue.Kind.number);
    assert(state.get!int(-1) == 20);
    state.pop(1);

    // Strings, both as GC copies and as weak slices.
    state.push("abc");
    assert(state.get!LuaValue(-1).kind == LuaValue.Kind.text);
    assert(state.get!string(-1) == "abc");
    assert(state.get!(const(char)[])(-1) == "abc");
    state.pop(1);

    // Homogeneous arrays.
    state.push(["abc", "one"]);
    assert(state.get!(string[])(-1) == ["abc", "one"]);
    state.pop(1);

    // Arrays of `LuaValue` holding mixed kinds.
    state.push([LuaValue(200), LuaValue("abc")]);
    assert(state.get!(LuaValue[])(-1) == [LuaValue(200), LuaValue("abc")]);
    state.pop(1);
}
731 
// A single D lambda registered as a global is callable from LUA code.
unittest
{
    auto state = LuaState(null);

    state.register!(() => 123)("abc");
    state.doString("assert(abc() == 123)");
}
738 
// A library of name/function pairs is reachable from LUA code as `lib.<name>`.
unittest
{
    auto state = LuaState(null);

    state.register!(
        "funcA", () => "a",
        "funcB", () => "b"
    )("lib");

    state.doString("assert(lib.funcA() == 'a') assert(lib.funcB() == 'b')");
}
748 
// Globals can be read and written through `globalTable`.
unittest
{
    auto state = LuaState(null);

    // Set from LUA, read from D.
    state.doString("abba = 'chicken tikka'");
    assert(state.globalTable.get!string("abba") == "chicken tikka");

    // Set from D, read from D.
    state.globalTable["baab"] = "tikka chicken";
    assert(state.globalTable.get!string("baab") == "tikka chicken");
}
757 
// A struct with a nested array and associative array of structs survives a push/get round trip.
unittest
{
    static struct B
    {
        string a;
    }

    static struct C
    {
        string a;
    }

    static struct A
    {
        string a;
        B[] b;
        C[string] c;
    }

    auto original = A(
        "bc",
        [B("c")],
        ["c": C("123")]
    );

    auto state = LuaState(null);
    state.push(original);

    auto roundTripped = state.get!A(-1);

    // Plain field.
    assert(roundTripped.a == "bc");

    // Array of structs.
    assert(roundTripped.b.length == 1);
    assert(roundTripped.b == [B("c")]);

    // Associative array of structs.
    assert(roundTripped.c.length == 1);
    assert(roundTripped.c["c"] == C("123"));
}