diff --git a/cli/main.c b/cli/main.c index 827f368..636d8ef 100644 --- a/cli/main.c +++ b/cli/main.c @@ -202,7 +202,7 @@ int main(int argc, const char** argv) { if (cmd != NULL) { // pocket -c "print('foo')" PkStringPtr source = { cmd, NULL, NULL, 0, 0 }; - PkStringPtr path = { "$(Source)", NULL, NULL, 0, 0 }; + PkStringPtr path = { "@(Source)", NULL, NULL, 0, 0 }; PkResult result = pkInterpretSource(vm, source, path, NULL); exitcode = (int)result; diff --git a/src/pk_compiler.c b/src/pk_compiler.c index b2aa1a7..cded684 100644 --- a/src/pk_compiler.c +++ b/src/pk_compiler.c @@ -21,6 +21,10 @@ // limited by it's opcode which is using a short value to identify. #define MAX_CONSTANTS (1 << 16) +// The maximum number of upvalues a literal function can capture from its +// enclosing function. +#define MAX_UPVALUES 256 + // The maximum number of names that were used before defined. Its just the size // of the Forward buffer of the compiler. Feel free to increase it if it // require more. @@ -244,6 +248,7 @@ typedef struct { const char* name; //< Directly points into the source string. uint32_t length; //< Length of the name. int depth; //< The depth the local is defined in. + bool is_upvalue; //< Is this an upvalue for a nested function. int line; //< The line variable declared for debugging. } Local; @@ -293,6 +298,23 @@ typedef struct sForwardName { } ForwardName; +// This struct is used to keep track of the information of the upvalues for +// the current function to generate opcodes to capture them. +typedef struct sUpvalueInfo { + + // If it's true the external local belongs to the immediate enclosing + // function and the below [index] refers to the locals of that function. + // If it's false the external local of the upvalue doesn't belong to the + // immediate enclosing function and the [index] will refer to the upvalues + // array of the enclosing function. 
+ bool is_immediate; + + // Index of the upvalue's external local variable, in the local or upvalues + // array of the enclosing function. + int index; + +} UpvalueInfo; + typedef struct sFunc { // Scope of the function. -2 for module body function, -1 for top level @@ -302,6 +324,8 @@ typedef struct sFunc { Local locals[MAX_VARIABLES]; //< Variables in the current context. int local_count; //< Number of locals in [locals]. + UpvalueInfo upvalues[MAX_UPVALUES]; //< Upvalues in the current context. + int stack_size; //< Current size including locals ind temps. // The actual function pointer which is being compiled. @@ -1218,9 +1242,7 @@ static bool matchAssignment(Compiler* compiler) { // if not found returns -1. static int findBuiltinFunction(const PKVM* vm, const char* name, uint32_t length) { - for (int i = 0; i < vm->builtins_count; i++) { - uint32_t bfn_length = (uint32_t)strlen(vm->builtins[i]->fn->name); if (bfn_length != length) continue; if (strncmp(name, vm->builtins[i]->fn->name, length) == 0) { @@ -1233,6 +1255,7 @@ static int findBuiltinFunction(const PKVM* vm, // Find the local with the [name] in the given function [func] and return // it's index, if not found returns -1. static int findLocal(Func* func, const char* name, uint32_t length) { + ASSERT(func != NULL, OOPS); for (int i = 0; i < func->local_count; i++) { if (func->locals[i].length != length) continue; if (strncmp(func->locals[i].name, name, length) == 0) { @@ -1242,12 +1265,74 @@ static int findLocal(Func* func, const char* name, uint32_t length) { return -1; } +// Add the upvalue to the given function and return its index, if the upvalue +// is already present in the function's upvalue array it'll return that index. +static int addUpvalue(Compiler* compiler, Func* func, + int index, bool is_immediate) { + + // Search for the upvalue in the existing upvalues array. 
+ for (int i = 0; i < func->ptr->upvalue_count; i++) { + UpvalueInfo info = func->upvalues[i]; + if (info.index == index && info.is_immediate == is_immediate) { + return i; + } + } + + if (func->ptr->upvalue_count == MAX_UPVALUES) { + parseError(compiler, "A function cannot capture more than %d upvalues.", + MAX_UPVALUES); + return -1; + } + + func->upvalues[func->ptr->upvalue_count].index = index; + func->upvalues[func->ptr->upvalue_count].is_immediate = is_immediate; + return func->ptr->upvalue_count++; +} + +// Search for an upvalue with the given [name] for the current function [func]. +// If an upvalue is found, it'll add the upvalue info to the upvalue info array +// of the [func] and return the index of the upvalue in the current function's +// upvalues array. +static int findUpvalue(Compiler* compiler, Func* func, const char* name, + uint32_t length) { + // TODO: + // check if the function is a method of a class and return -1 for them as + // well (once methods implemented). + // + // Toplevel functions cannot have upvalues. + if (func->depth <= DEPTH_GLOBAL) return -1; + + // Search in the immediate enclosing function's locals. + int index = findLocal(func->outer_func, name, length); + if (index != -1) { + + // Mark the local as an upvalue to close it when it goes out of the scope. + func->outer_func->locals[index].is_upvalue = true; + + // Add upvalue to the function and return its index. + return addUpvalue(compiler, func, index, true); + } + + // Recursively search for the upvalue in the outer function. If we found one + // all the outer function in the chain would have captured the upvalue for + // the local, we can add it to the current function as non-immediate upvalue. + index = findUpvalue(compiler, func->outer_func, name, length); + + if (index != -1) { + return addUpvalue(compiler, func, index, false); + } + + // If we reached here, the upvalue doesn't exist. + return -1; +} + // Result type for an identifier definition. 
typedef enum { NAME_NOT_DEFINED, NAME_LOCAL_VAR, //< Including parameter. + NAME_UPVALUE, //< Local to an enclosing function. NAME_GLOBAL_VAR, - NAME_BUILTIN, //< Native builtin function. + NAME_BUILTIN_FN, //< Native builtin function. } NameDefnType; // Identifier search result. @@ -1280,6 +1365,14 @@ static NameSearchResult compilerSearchName(Compiler* compiler, return result; } + // Search through upvalues. + index = findUpvalue(compiler, compiler->func, name, length); + if (index != -1) { + result.type = NAME_UPVALUE; + result.index = index; + return result; + } + // Search through globals. index = moduleGetGlobalIndex(compiler->module, name, length); if (index != -1) { @@ -1291,7 +1384,7 @@ static NameSearchResult compilerSearchName(Compiler* compiler, // Search through builtin functions. index = findBuiltinFunction(compiler->parser.vm, name, length); if (index != -1) { - result.type = NAME_BUILTIN; + result.type = NAME_BUILTIN_FN; result.index = index; return result; } @@ -1325,7 +1418,7 @@ static void compilerChangeStack(Compiler* compiler, int num); // Forward declaration of grammar functions. static void parsePrecedence(Compiler* compiler, Precedence precedence); -static int compileFunction(Compiler* compiler, bool is_literal); +static void compileFunction(Compiler* compiler, bool is_literal); static void compileExpression(Compiler* compiler); static void exprLiteral(Compiler* compiler); @@ -1447,12 +1540,13 @@ static void emitStoreGlobal(Compiler* compiler, int index) { // Emit opcode to push the named value at the [index] in it's array. static void emitPushName(Compiler* compiler, NameDefnType type, int index) { + ASSERT(index >= 0, OOPS); + switch (type) { case NAME_NOT_DEFINED: UNREACHABLE(); case NAME_LOCAL_VAR: - ASSERT(index >= 0, OOPS); if (index < 9) { //< 0..8 locals have single opcode. 
emitOpcode(compiler, (Opcode)(OP_PUSH_LOCAL_0 + index)); } else { @@ -1460,11 +1554,18 @@ static void emitPushName(Compiler* compiler, NameDefnType type, int index) { emitByte(compiler, index); } return; + + case NAME_UPVALUE: + emitOpcode(compiler, OP_PUSH_UPVALUE); + emitByte(compiler, index); + return; + case NAME_GLOBAL_VAR: emitOpcode(compiler, OP_PUSH_GLOBAL); emitByte(compiler, index); return; - case NAME_BUILTIN: + + case NAME_BUILTIN_FN: emitOpcode(compiler, OP_PUSH_BUILTIN_FN); emitByte(compiler, index); return; @@ -1474,13 +1575,14 @@ static void emitPushName(Compiler* compiler, NameDefnType type, int index) { // Emit opcode to store the stack top value to the named value at the [index] // in it's array. static void emitStoreName(Compiler* compiler, NameDefnType type, int index) { + ASSERT(index >= 0, OOPS); + switch (type) { case NAME_NOT_DEFINED: - case NAME_BUILTIN: + case NAME_BUILTIN_FN: UNREACHABLE(); case NAME_LOCAL_VAR: - ASSERT(index >= 0, OOPS); if (index < 9) { //< 0..8 locals have single opcode. emitOpcode(compiler, (Opcode)(OP_STORE_LOCAL_0 + index)); } else { @@ -1488,6 +1590,12 @@ static void emitStoreName(Compiler* compiler, NameDefnType type, int index) { emitByte(compiler, index); } return; + + case NAME_UPVALUE: + emitOpcode(compiler, OP_STORE_UPVALUE); + emitByte(compiler, index); + return; + case NAME_GLOBAL_VAR: emitStoreGlobal(compiler, index); return; @@ -1559,9 +1667,7 @@ static void exprInterpolation(Compiler* compiler) { } static void exprFunc(Compiler* compiler) { - int fn_index = compileFunction(compiler, true); - emitOpcode(compiler, OP_PUSH_CLOSURE); - emitShort(compiler, fn_index); + compileFunction(compiler, true); } static void exprName(Compiler* compiler) { @@ -1591,7 +1697,7 @@ static void exprName(Compiler* compiler) { // like python does) and it's recommented to define all the globals // before entering a local scope. 
- if (result.type == NAME_NOT_DEFINED || result.type == NAME_BUILTIN) { + if (result.type == NAME_NOT_DEFINED || result.type == NAME_BUILTIN_FN) { name_type = (compiler->scope_depth == DEPTH_GLOBAL) ? NAME_GLOBAL_VAR : NAME_LOCAL_VAR; @@ -1948,6 +2054,7 @@ static int compilerAddVariable(Compiler* compiler, const char* name, local->name = name; local->length = length; local->depth = compiler->scope_depth; + local->is_upvalue = false; local->line = line; return compiler->func->local_count++; } @@ -2021,7 +2128,12 @@ static int compilerPopLocals(Compiler* compiler, int depth) { // continue). So we need the pop instruction here but we still need the // locals to continue parsing the next statements in the scope. They'll be // popped once the scope is ended. - emitByte(compiler, OP_POP); + + if (compiler->func->locals[local].is_upvalue) { + emitByte(compiler, OP_CLOSE_UPVALUE); + } else { + emitByte(compiler, OP_POP); + } local--; } @@ -2154,7 +2266,7 @@ static void compileStatement(Compiler* compiler); static void compileBlockBody(Compiler* compiler, BlockType type); // Compile a class and return it's index in the module's types buffer. -static int compileClass(Compiler* compiler) { +static void compileClass(Compiler* compiler) { // Consume the name of the type. consume(compiler, TK_NAME, "Expected a type name."); @@ -2238,12 +2350,10 @@ static int compileClass(Compiler* compiler) { compilerExitBlock(compiler); emitFunctionEnd(compiler); compilerPopFunc(compiler); - - return -1; // TODO; } // Compile a function and return it's index in the module's function buffer. 
-static int compileFunction(Compiler* compiler, bool is_literal) { +static void compileFunction(Compiler* compiler, bool is_literal) { const char* name; int name_length; @@ -2334,7 +2444,20 @@ static int compileFunction(Compiler* compiler, bool is_literal) { compilerPopFunc(compiler); - return fn_index; + // Note: After the above compilerPopFunc() call, now we're at the outer + // function of this function, and the below emit calls will write to the + // outer function. If it's a literal function, we need to push a closure + // of it on the stack. + if (is_literal) { + emitOpcode(compiler, OP_PUSH_CLOSURE); + emitShort(compiler, fn_index); + + // Emit an (is_immediate, index) byte pair for each upvalue to capture. + for (int i = 0; i < curr_fn.ptr->upvalue_count; i++) { + emitByte(compiler, (curr_fn.upvalues[i].is_immediate) ? 1 : 0); + emitByte(compiler, curr_fn.upvalues[i].index); + } + } } // Finish a block body. @@ -2524,7 +2647,7 @@ static int compilerImportName(Compiler* compiler, int line, // Make it possible to override any name (ie. the syntax `print = 1` // should pass) and allow imported entries to have the same name of // builtin functions. 
- case NAME_BUILTIN: + case NAME_BUILTIN_FN: parseError(compiler, "Name '%.*s' already exists.", length, name); return -1; } diff --git a/src/pk_debug.c b/src/pk_debug.c index 454aed7..a0feb51 100644 --- a/src/pk_debug.c +++ b/src/pk_debug.c @@ -241,6 +241,15 @@ void dumpFunctionCode(PKVM* vm, Function* func) { break; } + case OP_PUSH_UPVALUE: + case OP_STORE_UPVALUE: + { + int index = READ_BYTE(); + PRINT_INT(index); + NEWLINE(); + break; + } + case OP_PUSH_CLOSURE: { int index = READ_SHORT(); @@ -256,7 +265,11 @@ void dumpFunctionCode(PKVM* vm, Function* func) { break; } - case OP_POP: NO_ARGS(); break; + case OP_CLOSE_UPVALUE: + case OP_POP: + NO_ARGS(); + break; + case OP_IMPORT: { int index = READ_SHORT(); diff --git a/src/pk_internal.h b/src/pk_internal.h index b224527..69fcbbe 100644 --- a/src/pk_internal.h +++ b/src/pk_internal.h @@ -78,7 +78,7 @@ // Name of a literal function. All literal function will have the same name but // they're uniquely identified by their index in the script's function buffer. -#define LITERAL_FN_NAME "@literalFn" +#define LITERAL_FN_NAME "@func" /*****************************************************************************/ /* ALLOCATION MACROS */ diff --git a/src/pk_opcodes.h b/src/pk_opcodes.h index 4706f72..c6eaf38 100644 --- a/src/pk_opcodes.h +++ b/src/pk_opcodes.h @@ -97,11 +97,24 @@ OPCODE(STORE_GLOBAL, 1, 0) // params: 1 bytes index. OPCODE(PUSH_BUILTIN_FN, 1, 1) +// Push an upvalue of the current closure at the index which is the first one +// byte argument. +// params: 1 byte index. +OPCODE(PUSH_UPVALUE, 1, 1) + +// Store the stack top value to the upvalues of the current function's upvalues +// array and don't pop it, since it's the result of the assignment. +// params: 1 byte index. +OPCODE(STORE_UPVALUE, 1, 0) + // Push a closure for the function at the constant pool with index of the // first 2 bytes arguments. // params: 2 byte index. 
OPCODE(PUSH_CLOSURE, 2, 1) +// Close the upvalue for the local at the stack top and pop it. +OPCODE(CLOSE_UPVALUE, 0, -1) + // Pop the stack top. OPCODE(POP, 0, -1) diff --git a/src/pk_value.c b/src/pk_value.c index 2e2c244..4695f70 100644 --- a/src/pk_value.c +++ b/src/pk_value.c @@ -206,7 +206,7 @@ static void popMarkedObjectsInternal(Object* obj, PKVM* vm) { markVarBuffer(vm, &module->globals); vm->bytes_allocated += sizeof(Var) * module->globals.capacity; - // Integer buffer has no gray call. + // Integer buffer has no mark call. vm->bytes_allocated += sizeof(uint32_t) * module->global_names.capacity; markVarBuffer(vm, &module->constants); @@ -510,7 +510,7 @@ Fiber* newFiber(PKVM* vm, Closure* closure) { fiber->sp = fiber->stack + 1; } else { - // Allocate stack. + // Calculate the stack size. int stack_size = utilPowerOf2Ceil(closure->fn->fn->stack_size + 1); if (stack_size < MIN_STACK_SIZE) stack_size = MIN_STACK_SIZE; fiber->stack = ALLOCATE_ARRAY(vm, Var, stack_size); @@ -529,6 +529,8 @@ Fiber* newFiber(PKVM* vm, Closure* closure) { fiber->frames[0].rbp = fiber->ret; } + fiber->open_upvalues = NULL; + // Initialize the return value to null (doesn't really have to do that here // but if we're trying to debut it may crash when dumping the return value). *fiber->ret = VAR_NULL; diff --git a/src/pk_value.h b/src/pk_value.h index 1387161..ae310dc 100644 --- a/src/pk_value.h +++ b/src/pk_value.h @@ -465,6 +465,11 @@ struct Fiber { // The stack pointer (%rsp) pointing to the stack top. Var* sp; + // All the open upvalues will form a linked list in the fiber and the + // upvalues are sorted in the same order as their locals in the stack. The + // below pointer is the head of those upvalues near the stack top. + Upvalue* open_upvalues; + // The stack base pointer of the current frame. It'll be updated before // calling a native function. (`fiber->ret` === `curr_call_frame->rbp`). 
And // also updated if the stack is reallocated (that's when it's about to get diff --git a/src/pk_vm.c b/src/pk_vm.c index cf4e283..d5601ff 100644 --- a/src/pk_vm.c +++ b/src/pk_vm.c @@ -568,6 +568,93 @@ static inline void reuseCallFrame(PKVM* vm, const Closure* closure) { if (vm->fiber->stack_size <= needed) growStack(vm, needed); } +// Capture the [local] into an upvalue and return it. If the upvalue already +// exists on the fiber, it'll return it. +static Upvalue* captureUpvalue(PKVM* vm, Fiber* fiber, Var* local) { + + // If the fiber doesn't have any upvalues yet, create new one and add it. + if (fiber->open_upvalues == NULL) { + Upvalue* upvalue = newUpvalue(vm, local); + fiber->open_upvalues = upvalue; + return upvalue; + } + + // In the below diagram 'u0' is the head of the open upvalues of the fiber. + // We'll walk through the upvalues to see if any of them already points + // to the [local] we want to capture. + // + // This can be optimized with binary search since the upvalues are sorted + // but it's not a frequent task, the number of upvalues would usually be + // few and the local is mostly located near the stack top. + // + // 1. If say 'l3' is what we want to capture, that local already has an + // upvalue 'u1' return it. + // 2. If say 'l4' is what we want to capture, it doesn't have an upvalue yet. + // Create a new upvalue and insert to the linked list (ie. u1.next = u3, + // u3.next = u2) and return it. + // + // | | + // | l1 | <-- u0 (u0.value = l1) + // | l2 | | + // | l3 | <-- u1 (u1.value = l3) + // | l4 | | + // | l5 | <-- u2 (u2.value = l5) + // '------' | + // stack NULL + + // Edge case: if the local is located higher than all the open upvalues, we + // cannot walk the chain, it's going to be the new head of the open upvalues. 
+ if (fiber->open_upvalues->ptr < local) { + Upvalue* head = newUpvalue(vm, local); + head->next = fiber->open_upvalues; + fiber->open_upvalues = head; + return head; + } + + // Now we walk the chain of open upvalues and if we find an upvalue for the + // local return it, otherwise insert it in the chain. + Upvalue* last = NULL; + Upvalue* current = fiber->open_upvalues; + + while (current->ptr > local) { + last = current; + current = current->next; + + // If the current is NULL, we've walked all the way to the end of the open + // upvalues, and there isn't one upvalue for the local. + if (current == NULL) { + last->next = newUpvalue(vm, local); + return last->next; + } + } + + // If [current] is the upvalue that captured [local] then return it. + if (current->ptr == local) return current; + + ASSERT(last != NULL, OOPS); + + // If we've reached here, the upvalue isn't found, create a new one and + // insert it to the chain. + Upvalue* upvalue = newUpvalue(vm, local); + last->next = upvalue; + upvalue->next = current; + return upvalue; +} + +// Close all the upvalues for the locals including [top] and higher in the +// stack. +static void closeUpvalues(Fiber* fiber, Var* top) { + + while (fiber->open_upvalues != NULL && fiber->open_upvalues->ptr >= top) { + Upvalue* upvalue = fiber->open_upvalues; + upvalue->closed = *upvalue->ptr; + upvalue->ptr = &upvalue->closed; + + fiber->open_upvalues = upvalue->next; + } + +} + static void reportError(PKVM* vm) { ASSERT(VM_HAS_ERROR(vm), "runtimeError() should be called after an error."); // TODO: pass the error to the caller of the fiber. 
@@ -863,13 +950,51 @@ L_vm_main_loop: DISPATCH(); } + OPCODE(PUSH_UPVALUE): + { + uint8_t index = READ_BYTE(); + PUSH(*(frame->closure->upvalues[index]->ptr)); + DISPATCH(); + } + + OPCODE(STORE_UPVALUE): + { + uint8_t index = READ_BYTE(); + *(frame->closure->upvalues[index]->ptr) = PEEK(-1); + DISPATCH(); + } + OPCODE(PUSH_CLOSURE): { uint16_t index = READ_SHORT(); ASSERT_INDEX(index, module->constants.count); ASSERT(IS_OBJ_TYPE(module->constants.data[index], OBJ_FUNC), OOPS); Function* fn = (Function*)AS_OBJ(module->constants.data[index]); - PUSH(VAR_OBJ(newClosure(vm, fn))); + Closure* closure = newClosure(vm, fn); + + // Capture the upvalues. TODO(review): confirm GC can't collect [closure]. + for (int i = 0; i < fn->upvalue_count; i++) { + uint8_t is_immediate = READ_BYTE(); + uint8_t index = READ_BYTE(); + + if (is_immediate) { + // rbp[0] is the return value, rbp + 1 is the first local and so on. + closure->upvalues[i] = captureUpvalue(vm, vm->fiber, + (rbp + 1 + index)); + } else { + // The upvalue is already captured by the current function, reuse it. + closure->upvalues[i] = frame->closure->upvalues[index]; + } + } + + PUSH(VAR_OBJ(closure)); + DISPATCH(); + } + + OPCODE(CLOSE_UPVALUE): + { + closeUpvalues(vm->fiber, vm->fiber->sp - 1); + DROP(); DISPATCH(); } @@ -1169,6 +1294,9 @@ L_vm_main_loop: OPCODE(RETURN): { + // Close all the open upvalues of the current frame's locals. + closeUpvalues(vm->fiber, rbp + 1); + // Set the return value. Var ret_value = POP(); diff --git a/tests/lang/closure.pk b/tests/lang/closure.pk new file mode 100644 index 0000000..40331e0 --- /dev/null +++ b/tests/lang/closure.pk @@ -0,0 +1,87 @@ + +## Simple upvalue. +def f1 + local = "foo" + return func + return local + end +end +assert(f1()() == "foo") + +def add3(x) + return func(y) + return func(z) + return x + y + z + end + end +end +assert(add3(1)(2)(3) == 6); +assert(add3(7)(6)(4) == 17); + +## Upvalue external to the inner function. 
+def f2 + local = "bar" + return func + fn = func + return local + end + return fn + end +end +assert(f2()()() == "bar") + +## Check if upvalues are shared between closures. +def f3 + local = "baz" + _fn1 = func(x) + local = x + end + _fn2 = func + return local + end + return [_fn1, _fn2] +end +fns = f3() +fns[0]("qux") +assert(fns[1]() == "qux") + +def f4 + a = [] + x = 10 + for i in 0..2 + j = i ## 'i' is shared, but 'j' isn't + list_append( + a, + func + return x + j + end + ) + end + x = 20 + return a +end +a = f4() +assert(a[0]() == 20) +assert(a[1]() == 21) + +def f5 + l1 = 12 + return func ## c1 + l2 = 34 + return func ## c2 + l3 = 56 + return func ## c3 + return func ## c4 + return func ## c5 + return l1 + l2 + l3 + end + end + end + end + end +end + +assert(f5()()()()()() == 102) + +print('All TESTS PASSED') + diff --git a/tests/tests.py b/tests/tests.py index bb339e4..0777118 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -18,6 +18,7 @@ TEST_SUITE = { "lang/basics.pk", "lang/builtin_fn.pk", "lang/class.pk", + "lang/closure.pk", "lang/core.pk", "lang/controlflow.pk", "lang/fibers.pk",