diff --git a/scripts/README.md b/scripts/README.md index f909e84..c74814d 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -13,7 +13,7 @@ can download it your self at https://premake.github.io/download. python3 download_premake.py ``` -It will download and place the `premake.exe` (in windows) binary next to +It will download and place the `premake5.exe` (in windows) binary next to it, next run the following command to generate Visual studio solution files. ``` diff --git a/scripts/premake5.lua b/scripts/premake5.lua index d778f07..dc19c37 100644 --- a/scripts/premake5.lua +++ b/scripts/premake5.lua @@ -68,13 +68,3 @@ project "pocket_cli" cli_dir .. "**.h" } includedirs ({ src_dir .. "include/" }) links { "pocket_static" } - - - - - - - - - - diff --git a/src/pk_compiler.c b/src/pk_compiler.c index 65b3a6f..8acd8fb 100644 --- a/src/pk_compiler.c +++ b/src/pk_compiler.c @@ -127,7 +127,7 @@ typedef enum { TK_TRUE, // true TK_FALSE, // false TK_SELF, // self - // TODO: TK_SUPER + TK_SUPER, // super TK_DO, // do TK_THEN, // then @@ -191,6 +191,7 @@ static _Keyword _keywords[] = { { "true", 4, TK_TRUE }, { "false", 5, TK_FALSE }, { "self", 4, TK_SELF }, + { "super", 5, TK_SUPER }, { "do", 2, TK_DO }, { "then", 4, TK_THEN }, { "while", 5, TK_WHILE }, @@ -1533,6 +1534,7 @@ static void exprSubscript(Compiler* compiler); static void exprValue(Compiler* compiler); static void exprSelf(Compiler* compiler); +static void exprSuper(Compiler* compiler); #define NO_RULE { NULL, NULL, PREC_NONE } #define NO_INFIX PREC_NONE @@ -1600,7 +1602,8 @@ GrammarRule rules[] = { // Prefix Infix Infix Precedence /* TK_NOT */ { exprUnaryOp, NULL, PREC_UNARY }, /* TK_TRUE */ { exprValue, NULL, NO_INFIX }, /* TK_FALSE */ { exprValue, NULL, NO_INFIX }, - /* TK_FALSE */ { exprSelf, NULL, NO_INFIX }, + /* TK_SELF */ { exprSelf, NULL, NO_INFIX }, + /* TK_SUPER */ { exprSuper, NULL, NO_INFIX }, /* TK_DO */ NO_RULE, /* TK_THEN */ NO_RULE, /* TK_WHILE */ NO_RULE, @@ -2020,7 +2023,9 @@ static void exprMap(Compiler* compiler) { // is OP_METHOD_CALL the [method] should refer a string in the module's // constant pool, otherwise it's ignored. static void _compileCall(Compiler* compiler, Opcode call_type, int method) { - ASSERT((call_type == OP_CALL) || (call_type == OP_METHOD_CALL), OOPS); + ASSERT((call_type == OP_CALL) || + (call_type == OP_METHOD_CALL) || + (call_type == OP_SUPER_CALL), OOPS); // Compile parameters. int argc = 0; @@ -2038,7 +2043,7 @@ static void _compileCall(Compiler* compiler, Opcode call_type, int method) { emitByte(compiler, argc); - if (call_type == OP_METHOD_CALL) { + if ((call_type == OP_METHOD_CALL) || (call_type == OP_SUPER_CALL)) { ASSERT_INDEX(method, (int)compiler->module->constants.count); emitShort(compiler, method); } @@ -2146,6 +2151,42 @@ static void exprSelf(Compiler* compiler) { } } +static void exprSuper(Compiler* compiler) { + + if (compiler->func->type != FUNC_CONSTRUCTOR && + compiler->func->type != FUNC_METHOD) { + semanticError(compiler, compiler->parser.previous, + "Invalid use of 'super'."); + return; + } + + ASSERT(compiler->func->ptr != NULL, OOPS); + + int index = 0; + const char* name = compiler->func->ptr->name; + int name_length = -1; + + if (!match(compiler, TK_LPARAN)) { // super.method(). + consume(compiler, TK_DOT, "Invalid use of 'super'."); + + consume(compiler, TK_NAME, "Expected a method name after 'super'."); + name = compiler->parser.previous.start; + name_length = compiler->parser.previous.length; + + consume(compiler, TK_LPARAN, "Expected symbol '('."); + + } else { // super(). + name_length = (int)strlen(name); + } + + if (compiler->parser.has_syntax_error) return; + + emitOpcode(compiler, OP_PUSH_SELF); + moduleAddString(compiler->module, compiler->parser.vm, + name, name_length, &index); + _compileCall(compiler, OP_SUPER_CALL, index); +} + static void parsePrecedence(Compiler* compiler, Precedence precedence) { lexToken(compiler); if (compiler->parser.has_syntax_error) return; diff --git a/src/pk_core.c b/src/pk_core.c index 292d5f5..0871bb9 100644 --- a/src/pk_core.c +++ b/src/pk_core.c @@ -968,21 +968,31 @@ Class* getClass(PKVM* vm, Var instance) { return inst->cls; } -bool hasMethod(PKVM* vm, Var self, String* name, Closure** _method) { - Class* cls = getClass(vm, self); - ASSERT(cls != NULL, OOPS); - +// Returns a method on a class (it'll walk up the inheritance tree to search +// and if the method not found, it'll return NULL. +static inline Closure* clsGetMethod(Class* cls, String* name) { Class* cls_ = cls; do { for (int i = 0; i < (int)cls_->methods.count; i++) { Closure* method_ = cls_->methods.data[i]; if (IS_CSTR_EQ(name, method_->fn->name, name->length)) { - if (_method) *_method = method_; - return true; + return method_; } } cls_ = cls_->super_class; } while (cls_ != NULL); + return NULL; +} + +bool hasMethod(PKVM* vm, Var self, String* name, Closure** _method) { + Class* cls = getClass(vm, self); + ASSERT(cls != NULL, OOPS); + + Closure* method_ = clsGetMethod(cls, name); + if (method_ != NULL) { + *_method = method_; + return true; + } return false; } @@ -1000,6 +1010,22 @@ Var getMethod(PKVM* vm, Var self, String* name, bool* is_method) { return varGetAttrib(vm, self, name); } +Closure* getSuperMethod(PKVM* vm, Var self, String* name) { + Class* super = getClass(vm, self)->super_class; + if (super == NULL) { + VM_SET_ERROR(vm, stringFormat(vm, "'$' object has no parent class.", \ + varTypeName(self))); + return NULL; + }; + + Closure* method = clsGetMethod(super, name); + if (method == NULL) { + VM_SET_ERROR(vm, stringFormat(vm, "'@' class has no method named '@'.", \ + super->name, name)); + } + return method; +} + #define UNSUPPORTED_UNARY_OP(op) \ VM_SET_ERROR(vm, stringFormat(vm, "Unsupported operand ($) for " \ "unary operator " op ".", varTypeName(v))) diff --git a/src/pk_core.h b/src/pk_core.h index 78c8d6a..3c00a12 100644 --- a/src/pk_core.h +++ b/src/pk_core.h @@ -75,6 +75,10 @@ Class* getClass(PKVM* vm, Var instance); // If the method / attribute not found, it'll set a runtime error on the VM. Var getMethod(PKVM* vm, Var self, String* name, bool* is_method); +// Returns the method (closure) from the instance's super class. If the method +// doesn't exists, it'll set an error on the VM. +Closure* getSuperMethod(PKVM* vm, Var self, String* name); + // Unlike getMethod this will not set error and will not try to get attribute // with the same name. It'll return true if the method exists on [self], false // otherwise and if the [method] argument is not NULL, method will be set. diff --git a/src/pk_debug.c b/src/pk_debug.c index e25e9fb..689d315 100644 --- a/src/pk_debug.c +++ b/src/pk_debug.c @@ -512,6 +512,7 @@ void dumpFunctionCode(PKVM* vm, Function* func) { break; } + case OP_SUPER_CALL: case OP_METHOD_CALL: { int argc = READ_BYTE(); diff --git a/src/pk_opcodes.h b/src/pk_opcodes.h index a84c3a2..4e79e5d 100644 --- a/src/pk_opcodes.h +++ b/src/pk_opcodes.h @@ -133,7 +133,14 @@ OPCODE(POP, 0, -1) // params: 2 byte name index. OPCODE(IMPORT, 2, 1) -// Call a method on the variable at the stack top. See opcode CALL for detail. +// Call a super class's method on the variable at (stack_top - argc). +// See opcode CALL for detail. +// params: 2 bytes method name index in the constant pool. +// 1 byte argc. +OPCODE(SUPER_CALL, 3, -0) //< Stack size will be calculated at compile time. + +// Call a method on the variable at (stack_top - argc). See opcode CALL for +// detail. // params: 2 bytes method name index in the constant pool. // 1 byte argc. OPCODE(METHOD_CALL, 3, -0) //< Stack size will be calculated at compile time. diff --git a/src/pk_vm.c b/src/pk_vm.c index b0caea9..9acb6c1 100644 --- a/src/pk_vm.c +++ b/src/pk_vm.c @@ -889,13 +889,10 @@ L_vm_main_loop: if (strcmp(method->fn->name, CTOR_NAME) == 0) { cls->ctor = method; - } else { - // TODO: The method buffer should be ordered with it's name and - // inserted in a way to preserve the order to implement binary search - // to find a method. - pkClosureBufferWrite(&cls->methods, vm, method); } + pkClosureBufferWrite(&cls->methods, vm, method); + DROP(); DISPATCH(); } @@ -952,14 +949,28 @@ L_vm_main_loop: Var callable; const Closure* closure; + uint16_t index; //< To get the method name. + String* name; //< The method name. + + OPCODE(SUPER_CALL): + argc = READ_BYTE(); + fiber->ret = (fiber->sp - argc - 1); + fiber->self = *fiber->ret; //< Self for the next call. + index = READ_SHORT(); + name = moduleGetStringAt(module, (int)index); + Closure* super_method = getSuperMethod(vm, fiber->self, name); + CHECK_ERROR(); // Will return if super_method is NULL. + callable = VAR_OBJ(super_method); + goto L_do_call; + OPCODE(METHOD_CALL): argc = READ_BYTE(); fiber->ret = (fiber->sp - argc - 1); fiber->self = *fiber->ret; //< Self for the next call. - uint16_t index = READ_SHORT(); + index = READ_SHORT(); + name = moduleGetStringAt(module, (int)index); bool is_method; - String* name = moduleGetStringAt(module, (int)index); callable = getMethod(vm, fiber->self, name, &is_method); CHECK_ERROR(); goto L_do_call; @@ -1055,16 +1066,18 @@ L_do_call: } else { - if (instruction == OP_CALL || instruction == OP_METHOD_CALL) { + if (instruction == OP_TAIL_CALL) { + reuseCallFrame(vm, closure); + LOAD_FRAME(); //< Re-load the frame to vm's execution variables. + + } else { + ASSERT((instruction == OP_CALL) || + (instruction == OP_METHOD_CALL) || + (instruction == OP_SUPER_CALL), OOPS); + UPDATE_FRAME(); //< Update the current frame's ip. pushCallFrame(vm, closure, fiber->ret); LOAD_FRAME(); //< Load the top frame to vm's execution variables. - - } else { - ASSERT(instruction == OP_TAIL_CALL, OOPS); - - reuseCallFrame(vm, closure); - LOAD_FRAME(); //< Re-load the frame to vm's execution variables. } } diff --git a/tests/lang/class.pk b/tests/lang/class.pk index fecfa6c..badff7f 100644 --- a/tests/lang/class.pk +++ b/tests/lang/class.pk @@ -193,4 +193,44 @@ assert(!n5 == 3) assert(!n5 == 2) assert(!n5 == 1) +############################################################################### +## SUPER CLASS METHOD +############################################################################### + +class A + def _init() + print("A init") + end + def foo() + print("A foo") + ret = self.bar() + assert(ret == "B.bar") + return "A.foo" + end + def bar() + print("A bar") + return "A.bar" + end +end + +class B is A + def _init() + super() + print("B init") + end + def foo() + print("B foo") + ret = super() + assert(ret == "A.foo") + return super.bar() + end + def bar() + print("B bar") + return "B.bar" + end +end + +b = B() +assert(b.foo() == "A.bar") + print('ALL TESTS PASSED')