1 module lumars.state;
2 
3 import bindbc.lua, taggedalgebraic, lumars;
4 import taggedalgebraic : visit;
5 
/++
 + Marker type that stands in for LUA's `nil`.
 +
 + It carries no data: pushing a `LuaNil` pushes `nil`, and reading a `nil` back yields `LuaNil()`.
 + ++/
struct LuaNil
{
}
8 
/++
 + The set of values a `LuaValue` can hold; see `LuaValue`.
 +
 + `LuaValue.Kind` is generated from these fields (its members are named after them, e.g.
 + `LuaValue.Kind.nil`), so renaming or reordering fields changes `LuaValue.Kind`.
 + ++/
union LuaValueUnion
{
    /// LUA `nil`
    LuaNil nil;

    /// A lua number
    lua_Number number;

    /// A weak reference to some text. This text is managed by LUA and not D's GC so is unsafe to escape
    const(char)[] textWeak;

    /// GC-managed text
    string text;

    /// A bool
    bool boolean;

    /// A weak reference to a table currently on the LUA stack.
    LuaTableWeak tableWeak;

    /// A strong reference to a table which is in the LUA registry.
    LuaTable table;

    /// A weak reference to a function currently on the LUA stack.
    LuaFuncWeak funcWeak;

    /// A strong reference to a function which is in the LUA registry.
    LuaFunc func;

    /// A raw pointer, used to represent LUA light userdata.
    void* userData;
}
41 
/// An enumeration of various status codes LUA may return.
enum LuaStatus
{
    /// Success.
    ok = 0,
    /// A coroutine yielded (`LUA_YIELD`).
    yield = LUA_YIELD,
    /// A runtime error (`LUA_ERRRUN`).
    errRun = LUA_ERRRUN,
    /// A syntax error while compiling a chunk (`LUA_ERRSYNTAX`).
    errSyntax = LUA_ERRSYNTAX,
    /// A memory allocation error (`LUA_ERRMEM`).
    errMem = LUA_ERRMEM,
    /// An error while running the error handler (`LUA_ERRERR`).
    errErr = LUA_ERRERR,
}
52 
/// LUA's native number type (`lua_Number`).
alias LuaNumber = lua_Number;

/// The signature of a raw LUA C function (`lua_CFunction`).
alias LuaCFunc = lua_CFunction;

/// A `TaggedUnion` over `LuaValueUnion`: the bridge type between D values and LUA values.
alias LuaValue = TaggedUnion!LuaValueUnion;
57 
58 /++
59  + A light wrapper around `lua_State` with some higher level functions for quality of life purposes.
60  +
61  + This struct cannot be copied, so put it on the heap or store it as a global.
62  + ++/
63 struct LuaState
64 {
65     import std.string : toStringz;
66     import std.stdio : writefln, writeln, writef;
67 
68     @disable this(this){}
69 
70     private
71     {
72         lua_State*      _handle;
73         LuaTablePseudo  _G;
74     }
75     package bool        _isWrapper;
76 
77     /// Creates a wrapper around the given `lua_state`, or creates a new state if the given value is null.
78     @trusted @nogc
79     this(lua_State* wrapAround) nothrow
80     {
81         if(wrapAround)
82         {
83             this._handle = wrapAround;
84             this._isWrapper = true;
85         }
86         else
87         {
88             this._handle = luaL_newstate();
89             luaL_openlibs(this.handle);
90         }
91 
92         this._G = LuaTablePseudo(&this, LUA_GLOBALSINDEX);
93     }
94 
95     /// For non-wrappers, destroy the lua state.
96     @trusted @nogc
97     ~this() nothrow
98     {
99         if(this._handle && !this._isWrapper)
100             lua_close(this._handle);
101     }
102 
103     @nogc
104     LuaTablePseudo globalTable() nothrow
105     {
106         return this._G;
107     }
108     
109     @nogc
110     lua_CFunction atPanic(lua_CFunction func) nothrow
111     {
112         return lua_atpanic(this.handle, func);
113     }
114 
115     @nogc
116     void call(int nargs, int nresults) nothrow
117     {
118         lua_call(this.handle, nargs, nresults);
119     }
120 
121     @nogc
122     bool checkStack(int amount) nothrow
123     {
124         return lua_checkstack(this.handle, amount) != 0;
125     }
126 
127     @nogc
128     void concat(int nargs) nothrow
129     {
130         lua_concat(this.handle, nargs);
131     }
132 
133     @nogc
134     bool equal(int index1, int index2) nothrow
135     {
136         return lua_equal(this.handle, index1, index2) != 0;
137     }
138 
139     @nogc
140     void error() nothrow
141     {
142         lua_error(this.handle);
143     }
144 
145     void error(const char[] msg) nothrow
146     {
147         luaL_error(this.handle, "%s", msg.toStringz);
148     }
149 
150     LuaTableWeak pushMetatable(int ofIndex)
151     {
152         lua_getmetatable(this.handle, ofIndex);
153         return LuaTableWeak(&this, -1);
154     }
155 
156     LuaTable getMetatable(int ofIndex)
157     {
158         lua_getmetatable(this.handle, ofIndex);
159         return LuaTable.makeRef(&this);
160     }
161 
162     @nogc
163     bool lessThan(int index1, int index2) nothrow
164     {
165         return lua_lessthan(this.handle, index1, index2) != 0;
166     }
167     
168     @nogc
169     bool rawEqual(int index1, int index2) nothrow
170     {
171         return lua_rawequal(this.handle, index1, index2) != 0;
172     }
173 
174     @nogc
175     void pushTable(int tableIndex) nothrow
176     {
177         return lua_gettable(this.handle, tableIndex);
178     }
179 
180     @nogc
181     void insert(int index) nothrow
182     {
183         return lua_insert(this.handle, index);
184     }
185 
186     @nogc
187     size_t len(int index) nothrow
188     {
189         return lua_objlen(this.handle, index);
190     }
191 
192     @nogc
193     LuaStatus pcall(int nargs, int nresults, int errFuncIndex) nothrow
194     {
195         return cast(LuaStatus)lua_pcall(this.handle, nargs, nresults, errFuncIndex);
196     }
197 
198     @nogc
199     void copy(int index) nothrow
200     {
201         lua_pushvalue(this.handle, index);
202     }
203 
204     @nogc
205     void rawGet(int tableIndex) nothrow
206     {
207         lua_rawget(this.handle, tableIndex);
208     }
209 
210     @nogc
211     void rawGet(int tableIndex, int indexIntoTable) nothrow
212     {
213         lua_rawgeti(this.handle, tableIndex, indexIntoTable);
214     }
215 
216     @nogc
217     void rawSet(int tableIndex) nothrow
218     {
219         lua_rawset(this.handle, tableIndex);
220     }
221 
222     @nogc
223     void rawSet(int tableIndex, int indexIntoTable) nothrow
224     {
225         lua_rawseti(this.handle, tableIndex, indexIntoTable);
226     }
227 
228     void getGlobal(const char[] name)
229     {
230         lua_getglobal(this.handle, name.toStringz);
231     }
232 
233     void setGlobal(const char[] name)
234     {
235         lua_setglobal(this.handle, name.toStringz);
236     }
237 
238     void register(const char[] name, LuaCFunc func) nothrow
239     {
240         lua_register(this.handle, name.toStringz, func);
241     }
242 
243     void register(alias Func)(const char[] name)
244     {
245         this.register(name, &luaCWrapperSmart!Func);
246     }
247 
248     void register(Args...)(const char[] libname)
249     if(Args.length % 2 == 0)
250     {
251         luaL_Reg[(Args.length / 2) + 1] reg;
252 
253         static foreach(i; 0..Args.length/2)
254             reg[i] = luaL_Reg(Args[i*2].ptr, &luaCWrapperSmart!(Args[i*2+1]));
255 
256         luaL_register(this.handle, libname.toStringz, reg.ptr);
257     }
258 
259     @nogc
260     void remove(int index) nothrow
261     {
262         lua_remove(this.handle, index);
263     }
264 
265     @nogc
266     void replace(int index) nothrow
267     {
268         lua_replace(this.handle, index);
269     }
270 
271     @nogc
272     void setMetatable(int ofIndex) nothrow
273     {
274         lua_setmetatable(this.handle, ofIndex);
275     }
276 
277     void checkArg(bool condition, int argNum, const(char)[] extraMsg = null) nothrow
278     {
279         luaL_argcheck(this.handle, condition ? 1 : 0, argNum, extraMsg ? extraMsg.toStringz : null);
280     }
281 
282     bool callMetamethod(int index, const char[] method) nothrow
283     {
284         return luaL_callmeta(this.handle, index, method.toStringz) != 0;
285     }
286 
287     @nogc
288     void checkAny(int arg) nothrow
289     {
290         luaL_checkany(this.handle, arg);
291     }
292 
293     @nogc
294     ptrdiff_t checkInt(int arg) nothrow
295     {
296         return luaL_checkinteger(this.handle, arg);
297     }
298 
299     @nogc
300     const(char)[] checkStringWeak(int arg) nothrow
301     {
302         size_t len;
303         const ptr = luaL_checklstring(this.handle, arg, &len);
304         return ptr[0..len];
305     }
306 
307     string checkString(int arg) nothrow
308     {
309         return checkStringWeak(arg).idup;
310     }
311 
312     @nogc
313     LuaNumber checkNumber(int arg) nothrow
314     {
315         return luaL_checknumber(this.handle, arg);
316     }
317 
318     @nogc
319     void checkType(LuaValue.Kind type, int arg) nothrow
320     {
321         int t;
322 
323         final switch(type) with(LuaValue.Kind)
324         {
325             case nil: t = LUA_TNIL; break;
326             case number: t = LUA_TNUMBER; break;
327             case textWeak:
328             case text: t = LUA_TSTRING; break;
329             case boolean: t = LUA_TBOOLEAN; break;
330             case tableWeak:
331             case table: t = LUA_TTABLE; break;
332             case funcWeak:
333             case func: t = LUA_TFUNCTION; break;
334             case userData: t = LUA_TLIGHTUSERDATA; break;
335         }
336 
337         luaL_checktype(this.handle, arg, t);
338     }
339 
340     void doFile(const char[] file)
341     {
342         const status = luaL_dofile(this.handle, file.toStringz);
343         if(status != LuaStatus.ok)
344         {
345             const error = this.get!string(-1);
346             this.pop(1);
347             throw new Exception(error);
348         }
349     }
350 
351     void doString(const char[] str)
352     {
353         const status = luaL_dostring(this.handle, str.toStringz);
354         if(status != LuaStatus.ok)
355         {
356             const error = this.get!string(-1);
357             this.pop(1);
358             throw new Exception(error);
359         }
360     }
361 
362     void loadFile(const char[] file)
363     {
364         const status = luaL_loadfile(this.handle, file.toStringz);
365         if(status != LuaStatus.ok)
366         {
367             const error = this.get!string(-1);
368             this.pop(1);
369             throw new Exception(error);
370         }
371     }
372 
373     void loadString(const char[] str)
374     {
375         const status = luaL_loadstring(this.handle, str.toStringz);
376         if(status != LuaStatus.ok)
377         {
378             const error = this.get!string(-1);
379             this.pop(1);
380             throw new Exception(error);
381         }
382     }
383 
384     @nogc
385     ptrdiff_t optInt(int arg, ptrdiff_t default_) nothrow
386     {
387         return luaL_optinteger(this.handle, arg, default_);
388     }
389 
390     @nogc
391     LuaNumber optNumber(int arg, LuaNumber default_) nothrow
392     {
393         return luaL_optnumber(this.handle, arg, default_);
394     }
395 
396     void printStack()
397     {
398         writeln("[LUA STACK]");
399         foreach(i; 0..this.top)
400         {
401             const type = lua_type(this.handle, i+1);
402             writef("\t[%s] \t", i+1);
403 
404             switch(type)
405             {
406                 case LUA_TBOOLEAN: writefln("%s\t%s", "BOOL", this.get!bool(i+1)); break;
407                 case LUA_TFUNCTION: writefln("%s\t%s", "FUNC", lua_tocfunction(this.handle, i+1)); break;
408                 case LUA_TLIGHTUSERDATA: writefln("%s\t%s", "LIGHT", lua_touserdata(this.handle, i+1)); break;
409                 case LUA_TNIL: writefln("%s", "NIL"); break;
410                 case LUA_TNUMBER: writefln("%s\t%s", "NUM", this.get!lua_Number(i+1)); break;
411                 case LUA_TSTRING: writefln("%s\t%s", "STR", this.get!(const(char)[])(i+1)); break;
412                 case LUA_TTABLE: writefln("%s", "TABL"); break;
413                 case LUA_TTHREAD: writefln("%s\t%s", "THRD", lua_tothread(this.handle, i+1)); break;
414                 case LUA_TUSERDATA: writefln("%s\t%s", "USER", lua_touserdata(this.handle, i+1)); break;
415                 default: writefln("%s\t%s", "UNKN", type); break;
416             }
417         }
418     }
419 
420     void push(T)(T value)
421     {
422         import std.conv : to;
423         import std.traits : isNumeric, isDynamicArray, isAssociativeArray, isDelegate, isPointer, isFunction,
424                             PointerTarget, KeyType, ValueType;
425 
426         static if(is(T == typeof(null)) || is(T == LuaNil))
427             lua_pushnil(this.handle);
428         else static if(is(T : const(char)[]))
429             lua_pushlstring(this.handle, value.ptr, value.length);
430         else static if(isNumeric!T)
431             lua_pushnumber(this.handle, value.to!lua_Number);
432         else static if(is(T : const(bool)))
433             lua_pushboolean(this.handle, value ? 1 : 0);
434         else static if(isDynamicArray!T)
435         {
436             alias ValueT = typeof(value[0]);
437 
438             lua_createtable(this.handle, 0, value.length.to!int);
439             foreach(i, v; value)
440             {
441                 this.push(v);
442                 lua_rawseti(this.handle, -2, cast(int)i+1);
443             }
444         }
445         else static if(isAssociativeArray!T)
446         {
447             alias KeyT = KeyType!T;
448             alias ValueT = ValueType!T;
449 
450             lua_createtable(this.handle, 0, value.length.to!int);
451             foreach(k, v; value)
452             {
453                 this.push(k);
454                 this.push(v);
455                 lua_rawset(this.handle, -3);
456             }
457         }
458         else static if(is(T == LuaTable) || is(T == LuaFunc))
459             value.push();
460         else static if(is(T == LuaTableWeak) || is(T == LuaFuncWeak))
461             this.copy(value.push());
462         else static if(is(T : lua_CFunction))
463             lua_pushcfunction(this.handle, value);
464         else static if(isDelegate!T)
465         {
466             lua_pushlightuserdata(this.handle, value.ptr);
467             lua_pushlightuserdata(this.handle, value.funcptr);
468             lua_pushcclosure(this.handle, &luaCWrapperSmart!(T, LuaFuncWrapperType.isDelegate), 2);
469         }
470         else static if(isPointer!T && isFunction!(PointerTarget!T))
471         {
472             lua_pushlightuserdata(this.handle, value);
473             lua_pushcclosure(this.handle, &luaCWrapperSmart!(T, LuaFuncWrapperType.isFunction), 1);
474         }
475         else static if(isPointer!T)
476             lua_pushlightuserdata(this.handle, value);
477         else static if(is(T == class))
478             lua_pushlightuserdata(this.handle, cast(void*)value);
479         else static if(is(T == struct))
480         {
481             lua_newtable(this.handle);
482 
483             static foreach(member; __traits(allMembers, T))
484             {
485                 this.push(member);
486                 this.push(mixin("value."~member));
487                 lua_settable(this.handle, -3);
488             }
489         }
490         else static assert(false, "Don't know how to push type: "~T.stringof);
491     }
492 
493     void push(LuaValue value)
494     {
495         value.visit!(
496             (_){ this.push(_); }
497         );
498     }
499 
500     @nogc
501     int top() nothrow
502     {
503         return lua_gettop(this.handle);
504     }
505 
506     @nogc
507     void pop(int amount) nothrow
508     {
509         lua_pop(this.handle, amount);
510     }
511 
512     T get(T)(int index)
513     {
514         import std.conv : to;
515         import std.traits : isNumeric, isDynamicArray, isAssociativeArray, isPointer, KeyType, ValueType;
516 
517         static if(is(T == string))
518         {
519             this.enforceType(LuaValue.Kind.text, index);
520             size_t len;
521             auto ptr = lua_tolstring(this.handle, index, &len);
522             return ptr[0..len].idup;
523         }
524         else static if(is(T == const(char)[]))
525         {
526             this.enforceType(LuaValue.Kind.text, index);
527             size_t len;
528             auto ptr = lua_tolstring(this.handle, index, &len);
529             return ptr[0..len];
530         }
531         else static if(is(T : const(bool)))
532         {
533             this.enforceType(LuaValue.Kind.boolean, index);
534             return lua_toboolean(this.handle, index) != 0;
535         }
536         else static if(isNumeric!T)
537         {
538             this.enforceType(LuaValue.Kind.number, index);
539             return lua_tonumber(this.handle, index).to!T;
540         }
541         else static if(is(T == typeof(null)) || is(T == LuaNil))
542         {
543             this.enforceType(LuaValue.Kind.nil, index);
544             return LuaNil();
545         }
546         else static if(is(T == LuaTableWeak))
547         {
548             this.enforceType(LuaValue.Kind.table, index);
549             return T(&this, index);
550         }
551         else static if(is(T == LuaTable))
552         {
553             this.enforceType(LuaValue.Kind.table, index);
554             this.copy(index);
555             return T.makeRef(&this);
556         }
557         else static if(isDynamicArray!T)
558         {
559             this.enforceType(LuaValue.Kind.table, index);
560             T ret;
561             ret.length = lua_objlen(this.handle, index);
562 
563             this.push(null);
564             const tableIndex = index < 0 ? index - 1 : index;
565             while(this.next(tableIndex))
566             {
567                 ret[this.get!size_t(-2) - 1] = this.get!(typeof(ret[0]))(-1);
568                 this.pop(1);
569             }
570 
571             return ret;
572         }
573         else static if(isAssociativeArray!T)
574         {
575             this.enforceType(LuaValue.Kind.table, index);
576             T ret;
577 
578             this.push(null);
579             const tableIndex = index < 0 ? index - 1 : index;
580             while(this.next(tableIndex))
581             {
582                 ret[this.get!(KeyType!T)(-2)] = this.get!(ValueType!T)(-1);
583                 this.pop(1);
584             }
585 
586             return ret;
587         }
588         else static if(is(T == LuaCFunc))
589         {
590             this.enforceType(LuaValue.Kind.func, index);
591             return lua_tocfunction(this.handle, index);
592         }
593         else static if(is(T == LuaFuncWeak))
594         {
595             this.enforceType(LuaValue.Kind.func, index);
596             return LuaFuncWeak(&this, index);
597         }
598         else static if(is(T == LuaFunc))
599         {
600             this.enforceType(LuaValue.Kind.func, index);
601             this.copy(index);
602             return T.makeRef(&this);
603         }
604         else static if(isPointer!T || is(T == class))
605         {
606             this.enforceType(LuaValue.Kind.userData, index);
607             return cast(T)lua_touserdata(this.handle, index);
608         }
609         else static if(is(T == LuaValue))
610         {
611             switch(this.type(index))
612             {
613                 case LuaValue.Kind.text: return LuaValue(this.get!string(index));
614                 case LuaValue.Kind.number: return LuaValue(this.get!lua_Number(index));
615                 case LuaValue.Kind.boolean: return LuaValue(this.get!bool(index));
616                 case LuaValue.Kind.nil: return LuaValue(this.get!LuaNil(index));
617                 case LuaValue.Kind.table: return LuaValue(this.get!LuaTable(index));
618                 case LuaValue.Kind.func: return LuaValue(this.get!LuaFunc(index));
619                 case LuaValue.Kind.userData: return LuaValue(this.get!(void*)(index));
620                 default: throw new Exception("Don't know how to convert type into a LuaValue: "~this.type(index).to!string);
621             }
622         }
623         else static if(is(T == struct))
624         {
625             this.enforceType(LuaValue.Kind.table, index);
626             T ret;
627 
628             this.push(null);
629             const tableIndex = index < 0 ? index - 1 : index;
630             While: while(this.next(tableIndex))
631             {
632                 const field = this.get!(const(char)[])(-2);
633 
634                 static foreach(member; __traits(allMembers, T))
635                 {
636                     if(field == member)
637                     {
638                         mixin("ret."~member~"= this.get!(typeof(ret."~member~"))(-1);");
639                         this.pop(1);
640                         continue While;
641                     }
642                 }
643 
644                 this.pop(1);
645             }
646             return ret;
647         }
648         else static assert(false, "Don't know how to convert any LUA values into type: "~T.stringof);
649     }
650 
651     @nogc
652     bool next(int index) nothrow
653     {
654         this.assertIndex(index);
655         return lua_next(this.handle, index) != 0;
656     }
657 
658     void enforceType(LuaValue.Kind expected, int index)
659     {
660         import std.exception : enforce;
661         import std.format    : format;
662         const type = this.type(index);
663         enforce(type == expected, "Expected value at stack index %s to be of type %s but it is %s".format(
664             index, expected, type
665         ));
666     }
667 
668     @nogc
669     LuaValue.Kind type(int index) nothrow
670     {
671         assert(this.top > 0, "Stack is empty.");
672         this.assertIndex(index);
673         const type = lua_type(this.handle, index);
674 
675         switch(type)
676         {
677             case LUA_TBOOLEAN: return LuaValue.Kind.boolean;
678             case LUA_TNIL: return LuaValue.Kind.nil;
679             case LUA_TNUMBER: return LuaValue.Kind.number;
680             case LUA_TSTRING: return LuaValue.Kind.text;
681             case LUA_TTABLE: return LuaValue.Kind.table;
682             case LUA_TFUNCTION: return LuaValue.Kind.func;
683             case LUA_TLIGHTUSERDATA: return LuaValue.Kind.userData;
684 
685             default: 
686                 return LuaValue.Kind.nil;
687         }
688     }
689 
690     @property @safe @nogc
691     inout(lua_State*) handle() nothrow pure inout
692     {
693         return this._handle;
694     }
695 
696     @nogc
697     private void assertIndex(int index) nothrow
698     {
699         if(index > 0)
700             assert(this.top >= index, "Index out of bounds");
701         else
702             assert(this.top + index >= 0, "Index out of bounds");
703     }
704 }
705 
unittest
{
    alias Kind = LuaValue.Kind;
    auto lua = LuaState(null);

    // Both spellings of nil round-trip as nil.
    lua.push(null);
    assert(lua.type(-1) == Kind.nil);
    assert(lua.get!LuaValue(-1).kind == Kind.nil);
    lua.pop(1);

    lua.push(LuaNil());
    assert(lua.type(-1) == Kind.nil);
    assert(lua.get!LuaValue(-1).kind == Kind.nil);
    lua.pop(1);

    // Scalars.
    lua.push(false);
    assert(lua.get!LuaValue(-1).kind == Kind.boolean);
    assert(lua.get!bool(-1) == false);
    lua.pop(1);

    lua.push(20);
    assert(lua.get!LuaValue(-1).kind == Kind.number);
    assert(lua.get!int(-1) == 20);
    lua.pop(1);

    // Strings, read back both as GC copies and as weak slices.
    lua.push("abc");
    assert(lua.get!LuaValue(-1).kind == Kind.text);
    assert(lua.get!string(-1) == "abc");
    assert(lua.get!(const(char)[])(-1) == "abc");
    lua.pop(1);

    // Arrays travel as 1-based tables.
    lua.push(["abc", "one"]);
    assert(lua.get!(string[])(-1) == ["abc", "one"]);
    lua.pop(1);

    lua.push([LuaValue(200), LuaValue("abc")]);
    assert(lua.get!(LuaValue[])(-1) == [LuaValue(200), LuaValue("abc")]);
    lua.pop(1);
}
743 
unittest
{
    // A D lambda registered as a global can be called from LUA code.
    auto state = LuaState(null);
    state.register!(() => 123)("abc");
    state.doString("assert(abc() == 123)");
}
750 
unittest
{
    // Several D lambdas registered together end up in one library table.
    auto state = LuaState(null);
    state.register!("funcA", () => "a", "funcB", () => "b")("lib");
    state.doString("assert(lib.funcA() == 'a') assert(lib.funcB() == 'b')");
}
760 
unittest
{
    auto state = LuaState(null);

    // Written from LUA, read from D.
    state.doString("abba = 'chicken tikka'");
    assert(state.globalTable.get!string("abba") == "chicken tikka");

    // Written from D, read from D through the same global table.
    state.globalTable["baab"] = "tikka chicken";
    assert(state.globalTable.get!string("baab") == "tikka chicken");
}
769 
unittest
{
    static struct B { string a; }
    static struct C { string a; }
    static struct A
    {
        string a;
        B[] b;
        C[string] c;
    }

    // A struct containing arrays and associative arrays of structs survives a push/get round trip.
    auto original = A("bc", [B("c")], ["c": C("123")]);

    auto state = LuaState(null);
    state.push(original);

    auto copy = state.get!A(-1);
    assert(copy.a == "bc");

    assert(copy.b.length == 1);
    assert(copy.b == [B("c")]);

    assert(copy.c.length == 1);
    assert(copy.c["c"] == C("123"));
}