mirror of
https://github.com/zekexiao/pocketlang.git
synced 2025-02-05 20:26:53 +08:00
Merge pull request #196 from ThakeeNathees/upvalue
Closures implemented
This commit is contained in:
commit
9d6d37fa09
@ -202,7 +202,7 @@ int main(int argc, const char** argv) {
|
||||
if (cmd != NULL) { // pocket -c "print('foo')"
|
||||
|
||||
PkStringPtr source = { cmd, NULL, NULL, 0, 0 };
|
||||
PkStringPtr path = { "$(Source)", NULL, NULL, 0, 0 };
|
||||
PkStringPtr path = { "@(Source)", NULL, NULL, 0, 0 };
|
||||
PkResult result = pkInterpretSource(vm, source, path, NULL);
|
||||
exitcode = (int)result;
|
||||
|
||||
|
@ -21,6 +21,10 @@
|
||||
// limited by it's opcode which is using a short value to identify.
|
||||
#define MAX_CONSTANTS (1 << 16)
|
||||
|
||||
// The maximum number of upvaues a literal function can capture from it's
|
||||
// enclosing function.
|
||||
#define MAX_UPVALUES 256
|
||||
|
||||
// The maximum number of names that were used before defined. Its just the size
|
||||
// of the Forward buffer of the compiler. Feel free to increase it if it
|
||||
// require more.
|
||||
@ -244,6 +248,7 @@ typedef struct {
|
||||
const char* name; //< Directly points into the source string.
|
||||
uint32_t length; //< Length of the name.
|
||||
int depth; //< The depth the local is defined in.
|
||||
bool is_upvalue; //< Is this an upvalue for a nested function.
|
||||
int line; //< The line variable declared for debugging.
|
||||
} Local;
|
||||
|
||||
@ -293,6 +298,23 @@ typedef struct sForwardName {
|
||||
|
||||
} ForwardName;
|
||||
|
||||
// This struct is used to keep track about the information of the upvaues for
|
||||
// the current function to generate opcodes to capture them.
|
||||
typedef struct sUpvalueInfo {
|
||||
|
||||
// If it's true the extrenal local belongs to the immediate enclosing
|
||||
// function and the bellow [index] refering at the locals of that function.
|
||||
// If it's false the external local of the upvalue doesn't belongs to the
|
||||
// immediate enclosing function and the [index] will refering to the upvalues
|
||||
// array of the enclosing function.
|
||||
bool is_immediate;
|
||||
|
||||
// Index of the upvalue's external local variable, in the local or upvalues
|
||||
// array of the enclosing function.
|
||||
int index;
|
||||
|
||||
} UpvalueInfo;
|
||||
|
||||
typedef struct sFunc {
|
||||
|
||||
// Scope of the function. -2 for module body function, -1 for top level
|
||||
@ -302,6 +324,8 @@ typedef struct sFunc {
|
||||
Local locals[MAX_VARIABLES]; //< Variables in the current context.
|
||||
int local_count; //< Number of locals in [locals].
|
||||
|
||||
UpvalueInfo upvalues[MAX_UPVALUES]; //< Upvalues in the current context.
|
||||
|
||||
int stack_size; //< Current size including locals ind temps.
|
||||
|
||||
// The actual function pointer which is being compiled.
|
||||
@ -1218,9 +1242,7 @@ static bool matchAssignment(Compiler* compiler) {
|
||||
// if not found returns -1.
|
||||
static int findBuiltinFunction(const PKVM* vm,
|
||||
const char* name, uint32_t length) {
|
||||
|
||||
for (int i = 0; i < vm->builtins_count; i++) {
|
||||
|
||||
uint32_t bfn_length = (uint32_t)strlen(vm->builtins[i]->fn->name);
|
||||
if (bfn_length != length) continue;
|
||||
if (strncmp(name, vm->builtins[i]->fn->name, length) == 0) {
|
||||
@ -1233,6 +1255,7 @@ static int findBuiltinFunction(const PKVM* vm,
|
||||
// Find the local with the [name] in the given function [func] and return
|
||||
// it's index, if not found returns -1.
|
||||
static int findLocal(Func* func, const char* name, uint32_t length) {
|
||||
ASSERT(func != NULL, OOPS);
|
||||
for (int i = 0; i < func->local_count; i++) {
|
||||
if (func->locals[i].length != length) continue;
|
||||
if (strncmp(func->locals[i].name, name, length) == 0) {
|
||||
@ -1242,12 +1265,74 @@ static int findLocal(Func* func, const char* name, uint32_t length) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Add the upvalue to the given function and return it's index, if the upvalue
|
||||
// already present in the function's upvalue array it'll return it.
|
||||
static int addUpvalue(Compiler* compiler, Func* func,
|
||||
int index, bool is_immediate) {
|
||||
|
||||
// Search the upvalue in the existsing upvalues array.
|
||||
for (int i = 0; i < func->ptr->upvalue_count; i++) {
|
||||
UpvalueInfo info = func->upvalues[i];
|
||||
if (info.index == index && info.is_immediate == is_immediate) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
|
||||
if (func->ptr->upvalue_count == MAX_UPVALUES) {
|
||||
parseError(compiler, "A function cannot capture more thatn %d upvalues.",
|
||||
MAX_UPVALUES);
|
||||
return -1;
|
||||
}
|
||||
|
||||
func->upvalues[func->ptr->upvalue_count].index = index;
|
||||
func->upvalues[func->ptr->upvalue_count].is_immediate = is_immediate;
|
||||
return func->ptr->upvalue_count++;
|
||||
}
|
||||
|
||||
// Search for an upvalue with the given [name] for the current function [func].
|
||||
// If an upvalue found, it'll add the upvalue info to the upvalue infor array
|
||||
// of the [func] and return the index of the upvalue in the current function's
|
||||
// upvalues array.
|
||||
static int findUpvalue(Compiler* compiler, Func* func, const char* name,
|
||||
uint32_t length) {
|
||||
// TODO:
|
||||
// check if the function is a method of a class and return -1 for them as
|
||||
// well (once methods implemented).
|
||||
//
|
||||
// Toplevel functions cannot have upvalues.
|
||||
if (func->depth <= DEPTH_GLOBAL) return -1;
|
||||
|
||||
// Search in the immediate enclosing function's locals.
|
||||
int index = findLocal(func->outer_func, name, length);
|
||||
if (index != -1) {
|
||||
|
||||
// Mark the locals as an upvalue to close it when it goes out of the scope.
|
||||
func->outer_func->locals[index].is_upvalue = true;
|
||||
|
||||
// Add upvalue to the function and return it's index.
|
||||
return addUpvalue(compiler, func, index, true);
|
||||
}
|
||||
|
||||
// Recursively search for the upvalue in the outer function. If we found one
|
||||
// all the outer function in the chain would have captured the upvalue for
|
||||
// the local, we can add it to the current function as non-immediate upvalue.
|
||||
index = findUpvalue(compiler, func->outer_func, name, length);
|
||||
|
||||
if (index != -1) {
|
||||
return addUpvalue(compiler, func, index, false);
|
||||
}
|
||||
|
||||
// If we reached here, the upvalue doesn't exists.
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Result type for an identifier definition.
|
||||
typedef enum {
|
||||
NAME_NOT_DEFINED,
|
||||
NAME_LOCAL_VAR, //< Including parameter.
|
||||
NAME_UPVALUE, //< Local to an enclosing function.
|
||||
NAME_GLOBAL_VAR,
|
||||
NAME_BUILTIN, //< Native builtin function.
|
||||
NAME_BUILTIN_FN, //< Native builtin function.
|
||||
} NameDefnType;
|
||||
|
||||
// Identifier search result.
|
||||
@ -1280,6 +1365,14 @@ static NameSearchResult compilerSearchName(Compiler* compiler,
|
||||
return result;
|
||||
}
|
||||
|
||||
// Search through upvalues.
|
||||
index = findUpvalue(compiler, compiler->func, name, length);
|
||||
if (index != -1) {
|
||||
result.type = NAME_UPVALUE;
|
||||
result.index = index;
|
||||
return result;
|
||||
}
|
||||
|
||||
// Search through globals.
|
||||
index = moduleGetGlobalIndex(compiler->module, name, length);
|
||||
if (index != -1) {
|
||||
@ -1291,7 +1384,7 @@ static NameSearchResult compilerSearchName(Compiler* compiler,
|
||||
// Search through builtin functions.
|
||||
index = findBuiltinFunction(compiler->parser.vm, name, length);
|
||||
if (index != -1) {
|
||||
result.type = NAME_BUILTIN;
|
||||
result.type = NAME_BUILTIN_FN;
|
||||
result.index = index;
|
||||
return result;
|
||||
}
|
||||
@ -1325,7 +1418,7 @@ static void compilerChangeStack(Compiler* compiler, int num);
|
||||
|
||||
// Forward declaration of grammar functions.
|
||||
static void parsePrecedence(Compiler* compiler, Precedence precedence);
|
||||
static int compileFunction(Compiler* compiler, bool is_literal);
|
||||
static void compileFunction(Compiler* compiler, bool is_literal);
|
||||
static void compileExpression(Compiler* compiler);
|
||||
|
||||
static void exprLiteral(Compiler* compiler);
|
||||
@ -1447,12 +1540,13 @@ static void emitStoreGlobal(Compiler* compiler, int index) {
|
||||
|
||||
// Emit opcode to push the named value at the [index] in it's array.
|
||||
static void emitPushName(Compiler* compiler, NameDefnType type, int index) {
|
||||
ASSERT(index >= 0, OOPS);
|
||||
|
||||
switch (type) {
|
||||
case NAME_NOT_DEFINED:
|
||||
UNREACHABLE();
|
||||
|
||||
case NAME_LOCAL_VAR:
|
||||
ASSERT(index >= 0, OOPS);
|
||||
if (index < 9) { //< 0..8 locals have single opcode.
|
||||
emitOpcode(compiler, (Opcode)(OP_PUSH_LOCAL_0 + index));
|
||||
} else {
|
||||
@ -1460,11 +1554,18 @@ static void emitPushName(Compiler* compiler, NameDefnType type, int index) {
|
||||
emitByte(compiler, index);
|
||||
}
|
||||
return;
|
||||
|
||||
case NAME_UPVALUE:
|
||||
emitOpcode(compiler, OP_PUSH_UPVALUE);
|
||||
emitByte(compiler, index);
|
||||
return;
|
||||
|
||||
case NAME_GLOBAL_VAR:
|
||||
emitOpcode(compiler, OP_PUSH_GLOBAL);
|
||||
emitByte(compiler, index);
|
||||
return;
|
||||
case NAME_BUILTIN:
|
||||
|
||||
case NAME_BUILTIN_FN:
|
||||
emitOpcode(compiler, OP_PUSH_BUILTIN_FN);
|
||||
emitByte(compiler, index);
|
||||
return;
|
||||
@ -1474,13 +1575,14 @@ static void emitPushName(Compiler* compiler, NameDefnType type, int index) {
|
||||
// Emit opcode to store the stack top value to the named value at the [index]
|
||||
// in it's array.
|
||||
static void emitStoreName(Compiler* compiler, NameDefnType type, int index) {
|
||||
ASSERT(index >= 0, OOPS);
|
||||
|
||||
switch (type) {
|
||||
case NAME_NOT_DEFINED:
|
||||
case NAME_BUILTIN:
|
||||
case NAME_BUILTIN_FN:
|
||||
UNREACHABLE();
|
||||
|
||||
case NAME_LOCAL_VAR:
|
||||
ASSERT(index >= 0, OOPS);
|
||||
if (index < 9) { //< 0..8 locals have single opcode.
|
||||
emitOpcode(compiler, (Opcode)(OP_STORE_LOCAL_0 + index));
|
||||
} else {
|
||||
@ -1488,6 +1590,12 @@ static void emitStoreName(Compiler* compiler, NameDefnType type, int index) {
|
||||
emitByte(compiler, index);
|
||||
}
|
||||
return;
|
||||
|
||||
case NAME_UPVALUE:
|
||||
emitOpcode(compiler, OP_STORE_UPVALUE);
|
||||
emitByte(compiler, index);
|
||||
return;
|
||||
|
||||
case NAME_GLOBAL_VAR:
|
||||
emitStoreGlobal(compiler, index);
|
||||
return;
|
||||
@ -1559,9 +1667,7 @@ static void exprInterpolation(Compiler* compiler) {
|
||||
}
|
||||
|
||||
static void exprFunc(Compiler* compiler) {
|
||||
int fn_index = compileFunction(compiler, true);
|
||||
emitOpcode(compiler, OP_PUSH_CLOSURE);
|
||||
emitShort(compiler, fn_index);
|
||||
compileFunction(compiler, true);
|
||||
}
|
||||
|
||||
static void exprName(Compiler* compiler) {
|
||||
@ -1591,7 +1697,7 @@ static void exprName(Compiler* compiler) {
|
||||
// like python does) and it's recommented to define all the globals
|
||||
// before entering a local scope.
|
||||
|
||||
if (result.type == NAME_NOT_DEFINED || result.type == NAME_BUILTIN) {
|
||||
if (result.type == NAME_NOT_DEFINED || result.type == NAME_BUILTIN_FN) {
|
||||
name_type = (compiler->scope_depth == DEPTH_GLOBAL)
|
||||
? NAME_GLOBAL_VAR
|
||||
: NAME_LOCAL_VAR;
|
||||
@ -1948,6 +2054,7 @@ static int compilerAddVariable(Compiler* compiler, const char* name,
|
||||
local->name = name;
|
||||
local->length = length;
|
||||
local->depth = compiler->scope_depth;
|
||||
local->is_upvalue = false;
|
||||
local->line = line;
|
||||
return compiler->func->local_count++;
|
||||
}
|
||||
@ -2021,7 +2128,12 @@ static int compilerPopLocals(Compiler* compiler, int depth) {
|
||||
// continue). So we need the pop instruction here but we still need the
|
||||
// locals to continue parsing the next statements in the scope. They'll be
|
||||
// popped once the scope is ended.
|
||||
emitByte(compiler, OP_POP);
|
||||
|
||||
if (compiler->func->locals[local].is_upvalue) {
|
||||
emitByte(compiler, OP_CLOSE_UPVALUE);
|
||||
} else {
|
||||
emitByte(compiler, OP_POP);
|
||||
}
|
||||
|
||||
local--;
|
||||
}
|
||||
@ -2154,7 +2266,7 @@ static void compileStatement(Compiler* compiler);
|
||||
static void compileBlockBody(Compiler* compiler, BlockType type);
|
||||
|
||||
// Compile a class and return it's index in the module's types buffer.
|
||||
static int compileClass(Compiler* compiler) {
|
||||
static void compileClass(Compiler* compiler) {
|
||||
|
||||
// Consume the name of the type.
|
||||
consume(compiler, TK_NAME, "Expected a type name.");
|
||||
@ -2238,12 +2350,10 @@ static int compileClass(Compiler* compiler) {
|
||||
compilerExitBlock(compiler);
|
||||
emitFunctionEnd(compiler);
|
||||
compilerPopFunc(compiler);
|
||||
|
||||
return -1; // TODO;
|
||||
}
|
||||
|
||||
// Compile a function and return it's index in the module's function buffer.
|
||||
static int compileFunction(Compiler* compiler, bool is_literal) {
|
||||
static void compileFunction(Compiler* compiler, bool is_literal) {
|
||||
|
||||
const char* name;
|
||||
int name_length;
|
||||
@ -2334,7 +2444,20 @@ static int compileFunction(Compiler* compiler, bool is_literal) {
|
||||
|
||||
compilerPopFunc(compiler);
|
||||
|
||||
return fn_index;
|
||||
// Note: After the above compilerPopFunc() call, now we're at the outer
|
||||
// function of this function, and the bellow emit calls will write to the
|
||||
// outer function. If it's a literal function, we need to push a closure
|
||||
// of it on the stack.
|
||||
if (is_literal) {
|
||||
emitOpcode(compiler, OP_PUSH_CLOSURE);
|
||||
emitShort(compiler, fn_index);
|
||||
|
||||
// Capture the upvalues when the closure is created.
|
||||
for (int i = 0; i < curr_fn.ptr->upvalue_count; i++) {
|
||||
emitByte(compiler, (curr_fn.upvalues[i].is_immediate) ? 1 : 0);
|
||||
emitByte(compiler, curr_fn.upvalues[i].index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finish a block body.
|
||||
@ -2524,7 +2647,7 @@ static int compilerImportName(Compiler* compiler, int line,
|
||||
// Make it possible to override any name (ie. the syntax `print = 1`
|
||||
// should pass) and allow imported entries to have the same name of
|
||||
// builtin functions.
|
||||
case NAME_BUILTIN:
|
||||
case NAME_BUILTIN_FN:
|
||||
parseError(compiler, "Name '%.*s' already exists.", length, name);
|
||||
return -1;
|
||||
}
|
||||
|
@ -241,6 +241,15 @@ void dumpFunctionCode(PKVM* vm, Function* func) {
|
||||
break;
|
||||
}
|
||||
|
||||
case OP_PUSH_UPVALUE:
|
||||
case OP_STORE_UPVALUE:
|
||||
{
|
||||
int index = READ_BYTE();
|
||||
PRINT_INT(index);
|
||||
NEWLINE();
|
||||
break;
|
||||
}
|
||||
|
||||
case OP_PUSH_CLOSURE:
|
||||
{
|
||||
int index = READ_SHORT();
|
||||
@ -256,7 +265,11 @@ void dumpFunctionCode(PKVM* vm, Function* func) {
|
||||
break;
|
||||
}
|
||||
|
||||
case OP_POP: NO_ARGS(); break;
|
||||
case OP_CLOSE_UPVALUE:
|
||||
case OP_POP:
|
||||
NO_ARGS();
|
||||
break;
|
||||
|
||||
case OP_IMPORT:
|
||||
{
|
||||
int index = READ_SHORT();
|
||||
|
@ -78,7 +78,7 @@
|
||||
|
||||
// Name of a literal function. All literal function will have the same name but
|
||||
// they're uniquely identified by their index in the script's function buffer.
|
||||
#define LITERAL_FN_NAME "@literalFn"
|
||||
#define LITERAL_FN_NAME "@func"
|
||||
|
||||
/*****************************************************************************/
|
||||
/* ALLOCATION MACROS */
|
||||
|
@ -97,11 +97,24 @@ OPCODE(STORE_GLOBAL, 1, 0)
|
||||
// params: 1 bytes index.
|
||||
OPCODE(PUSH_BUILTIN_FN, 1, 1)
|
||||
|
||||
// Push an upvalue of the current closure at the index which is the first one
|
||||
// byte argument.
|
||||
// params: 1 byte index.
|
||||
OPCODE(PUSH_UPVALUE, 1, 1)
|
||||
|
||||
// Store the stack top value to the upvalues of the current function's upvalues
|
||||
// array and don't pop it, since it's the result of the assignment.
|
||||
// params: 1 byte index.
|
||||
OPCODE(STORE_UPVALUE, 1, 0)
|
||||
|
||||
// Push a closure for the function at the constant pool with index of the
|
||||
// first 2 bytes arguments.
|
||||
// params: 2 byte index.
|
||||
OPCODE(PUSH_CLOSURE, 2, 1)
|
||||
|
||||
// Close the upvalue for the local at the stack top and pop it.
|
||||
OPCODE(CLOSE_UPVALUE, 0, -1)
|
||||
|
||||
// Pop the stack top.
|
||||
OPCODE(POP, 0, -1)
|
||||
|
||||
|
@ -206,7 +206,7 @@ static void popMarkedObjectsInternal(Object* obj, PKVM* vm) {
|
||||
markVarBuffer(vm, &module->globals);
|
||||
vm->bytes_allocated += sizeof(Var) * module->globals.capacity;
|
||||
|
||||
// Integer buffer has no gray call.
|
||||
// Integer buffer has no mark call.
|
||||
vm->bytes_allocated += sizeof(uint32_t) * module->global_names.capacity;
|
||||
|
||||
markVarBuffer(vm, &module->constants);
|
||||
@ -510,7 +510,7 @@ Fiber* newFiber(PKVM* vm, Closure* closure) {
|
||||
fiber->sp = fiber->stack + 1;
|
||||
|
||||
} else {
|
||||
// Allocate stack.
|
||||
// Calculate the stack size.
|
||||
int stack_size = utilPowerOf2Ceil(closure->fn->fn->stack_size + 1);
|
||||
if (stack_size < MIN_STACK_SIZE) stack_size = MIN_STACK_SIZE;
|
||||
fiber->stack = ALLOCATE_ARRAY(vm, Var, stack_size);
|
||||
@ -529,6 +529,8 @@ Fiber* newFiber(PKVM* vm, Closure* closure) {
|
||||
fiber->frames[0].rbp = fiber->ret;
|
||||
}
|
||||
|
||||
fiber->open_upvalues = NULL;
|
||||
|
||||
// Initialize the return value to null (doesn't really have to do that here
|
||||
// but if we're trying to debut it may crash when dumping the return value).
|
||||
*fiber->ret = VAR_NULL;
|
||||
|
@ -465,6 +465,11 @@ struct Fiber {
|
||||
// The stack pointer (%rsp) pointing to the stack top.
|
||||
Var* sp;
|
||||
|
||||
// All the open upvalues will form a linked list in the fiber and the
|
||||
// upvalues are sorted in the same order their locals in the stack. The
|
||||
// bellow pointer is the head of those upvalues near the stack top.
|
||||
Upvalue* open_upvalues;
|
||||
|
||||
// The stack base pointer of the current frame. It'll be updated before
|
||||
// calling a native function. (`fiber->ret` === `curr_call_frame->rbp`). And
|
||||
// also updated if the stack is reallocated (that's when it's about to get
|
||||
|
130
src/pk_vm.c
130
src/pk_vm.c
@ -568,6 +568,93 @@ static inline void reuseCallFrame(PKVM* vm, const Closure* closure) {
|
||||
if (vm->fiber->stack_size <= needed) growStack(vm, needed);
|
||||
}
|
||||
|
||||
// Capture the [local] into an upvalue and return it. If the upvalue already
|
||||
// exists on the fiber, it'll return it.
|
||||
static Upvalue* captureUpvalue(PKVM* vm, Fiber* fiber, Var* local) {
|
||||
|
||||
// If the fiber doesn't have any upvalues yet, create new one and add it.
|
||||
if (fiber->open_upvalues == NULL) {
|
||||
Upvalue* upvalue = newUpvalue(vm, local);
|
||||
fiber->open_upvalues = upvalue;
|
||||
return upvalue;
|
||||
}
|
||||
|
||||
// In the bellow diagram 'u0' is the head of the open upvalues of the fiber.
|
||||
// We'll walk through the upvalues to see if any of it's value is similar
|
||||
// to the [local] we want to capture.
|
||||
//
|
||||
// This can be optimized with binary search since the upvalues are sorted
|
||||
// but it's not a frequent task neither the number of upvalues would be very
|
||||
// few and the local mostly located at the stack top.
|
||||
//
|
||||
// 1. If say 'l3' is what we want to capture, that local already has an
|
||||
// upavlue 'u1' return it.
|
||||
// 2. If say 'l4' is what we want to capture, It doesn't have an upvalue yet.
|
||||
// Create a new upvalue and insert to the link list (ie. u1.next = u3,
|
||||
// u3.next = u2) and return it.
|
||||
//
|
||||
// | |
|
||||
// | l1 | <-- u0 (u1.value = l3)
|
||||
// | l2 | |
|
||||
// | l3 | <-- u1 (u1.value = l3)
|
||||
// | l4 | |
|
||||
// | l5 | <-- u2 (u2.value = l5)
|
||||
// '------' |
|
||||
// stack NULL
|
||||
|
||||
// Edge case: if the local is located higher than all the open upvalues, we
|
||||
// cannot walk the chain, it's going to be the new head of the open upvalues.
|
||||
if (fiber->open_upvalues->ptr < local) {
|
||||
Upvalue* head = newUpvalue(vm, local);
|
||||
head->next = fiber->open_upvalues;
|
||||
fiber->open_upvalues = head;
|
||||
return head;
|
||||
}
|
||||
|
||||
// Now we walk the chain of open upvalues and if we find an upvalue for the
|
||||
// local return it, otherwise insert it in the chain.
|
||||
Upvalue* last = NULL;
|
||||
Upvalue* current = fiber->open_upvalues;
|
||||
|
||||
while (current->ptr > local) {
|
||||
last = current;
|
||||
current = current->next;
|
||||
|
||||
// If the current is NULL, we've walked all the way to the end of the open
|
||||
// upvalues, and there isn't one upvalue for the local.
|
||||
if (current == NULL) {
|
||||
last->next = newUpvalue(vm, local);
|
||||
return last->next;
|
||||
}
|
||||
}
|
||||
|
||||
// If [current] is the upvalue that captured [local] then return it.
|
||||
if (current->ptr == local) return current;
|
||||
|
||||
ASSERT(last != NULL, OOPS);
|
||||
|
||||
// If we've reached here, the upvalue isn't found, create a new one and
|
||||
// insert it to the chain.
|
||||
Upvalue* upvalue = newUpvalue(vm, local);
|
||||
last->next = upvalue;
|
||||
upvalue->next = current;
|
||||
return upvalue;
|
||||
}
|
||||
|
||||
// Close all the upvalues for the locals including [top] and higher in the
|
||||
// stack.
|
||||
static void closeUpvalues(Fiber* fiber, Var* top) {
|
||||
|
||||
while (fiber->open_upvalues != NULL && fiber->open_upvalues->ptr >= top) {
|
||||
Upvalue* upvalue = fiber->open_upvalues;
|
||||
upvalue->closed = *upvalue->ptr;
|
||||
upvalue->ptr = &upvalue->closed;
|
||||
|
||||
fiber->open_upvalues = upvalue->next;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static void reportError(PKVM* vm) {
|
||||
ASSERT(VM_HAS_ERROR(vm), "runtimeError() should be called after an error.");
|
||||
// TODO: pass the error to the caller of the fiber.
|
||||
@ -863,13 +950,51 @@ L_vm_main_loop:
|
||||
DISPATCH();
|
||||
}
|
||||
|
||||
OPCODE(PUSH_UPVALUE):
|
||||
{
|
||||
uint8_t index = READ_BYTE();
|
||||
PUSH(*(frame->closure->upvalues[index]->ptr));
|
||||
DISPATCH();
|
||||
}
|
||||
|
||||
OPCODE(STORE_UPVALUE):
|
||||
{
|
||||
uint8_t index = READ_BYTE();
|
||||
*(frame->closure->upvalues[index]->ptr) = PEEK(-1);
|
||||
DISPATCH();
|
||||
}
|
||||
|
||||
OPCODE(PUSH_CLOSURE):
|
||||
{
|
||||
uint16_t index = READ_SHORT();
|
||||
ASSERT_INDEX(index, module->constants.count);
|
||||
ASSERT(IS_OBJ_TYPE(module->constants.data[index], OBJ_FUNC), OOPS);
|
||||
Function* fn = (Function*)AS_OBJ(module->constants.data[index]);
|
||||
PUSH(VAR_OBJ(newClosure(vm, fn)));
|
||||
Closure* closure = newClosure(vm, fn);
|
||||
|
||||
// Capture the vaupes.
|
||||
for (int i = 0; i < fn->upvalue_count; i++) {
|
||||
uint8_t is_immediate = READ_BYTE();
|
||||
uint8_t index = READ_BYTE();
|
||||
|
||||
if (is_immediate) {
|
||||
// rbp[0] is the return value, rbp + 1 is the first local and so on.
|
||||
closure->upvalues[i] = captureUpvalue(vm, vm->fiber,
|
||||
(rbp + 1 + index));
|
||||
} else {
|
||||
// The upvalue is already captured by the current function, reuse it.
|
||||
closure->upvalues[i] = frame->closure->upvalues[index];
|
||||
}
|
||||
}
|
||||
|
||||
PUSH(VAR_OBJ(closure));
|
||||
DISPATCH();
|
||||
}
|
||||
|
||||
OPCODE(CLOSE_UPVALUE):
|
||||
{
|
||||
closeUpvalues(vm->fiber, vm->fiber->sp - 1);
|
||||
DROP();
|
||||
DISPATCH();
|
||||
}
|
||||
|
||||
@ -1169,6 +1294,9 @@ L_vm_main_loop:
|
||||
OPCODE(RETURN):
|
||||
{
|
||||
|
||||
// Close all the locals of the current frame.
|
||||
closeUpvalues(vm->fiber, rbp + 1);
|
||||
|
||||
// Set the return value.
|
||||
Var ret_value = POP();
|
||||
|
||||
|
87
tests/lang/closure.pk
Normal file
87
tests/lang/closure.pk
Normal file
@ -0,0 +1,87 @@
|
||||
|
||||
## Simple upvalue.
|
||||
def f1
|
||||
local = "foo"
|
||||
return func
|
||||
return local
|
||||
end
|
||||
end
|
||||
assert(f1()() == "foo")
|
||||
|
||||
def add3(x)
|
||||
return func(y)
|
||||
return func(z)
|
||||
return x + y + z
|
||||
end
|
||||
end
|
||||
end
|
||||
assert(add3(1)(2)(3) == 6);
|
||||
assert(add3(7)(6)(4) == 17);
|
||||
|
||||
## Upvalue external to the inner function.
|
||||
def f2
|
||||
local = "bar"
|
||||
return func
|
||||
fn = func
|
||||
return local
|
||||
end
|
||||
return fn
|
||||
end
|
||||
end
|
||||
assert(f2()()() == "bar")
|
||||
|
||||
## Check if upvalues are shared between closures.
|
||||
def f3
|
||||
local = "baz"
|
||||
_fn1 = func(x)
|
||||
local = x
|
||||
end
|
||||
_fn2 = func
|
||||
return local
|
||||
end
|
||||
return [_fn1, _fn2]
|
||||
end
|
||||
fns = f3()
|
||||
fns[0]("qux")
|
||||
assert(fns[1]() == "qux")
|
||||
|
||||
def f4
|
||||
a = []
|
||||
x = 10
|
||||
for i in 0..2
|
||||
j = i ## 'i' is shared, but 'j' doesn't
|
||||
list_append(
|
||||
a,
|
||||
func
|
||||
return x + j
|
||||
end
|
||||
)
|
||||
end
|
||||
x = 20
|
||||
return a
|
||||
end
|
||||
a = f4()
|
||||
assert(a[0]() == 20)
|
||||
assert(a[1]() == 21)
|
||||
|
||||
def f5
|
||||
l1 = 12
|
||||
return func ## c1
|
||||
l2 = 34
|
||||
return func ## c2
|
||||
l3 = 56
|
||||
return func ## c3
|
||||
return func ## c4
|
||||
return func ## c5
|
||||
return l1 + l2 + l3
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
print(f5()()()()()() == 102)
|
||||
|
||||
print('All TESTS PASSED')
|
||||
|
@ -18,6 +18,7 @@ TEST_SUITE = {
|
||||
"lang/basics.pk",
|
||||
"lang/builtin_fn.pk",
|
||||
"lang/class.pk",
|
||||
"lang/closure.pk",
|
||||
"lang/core.pk",
|
||||
"lang/controlflow.pk",
|
||||
"lang/fibers.pk",
|
||||
|
Loading…
Reference in New Issue
Block a user