closures were completely added.

This commit is contained in:
Thakee Nathees 2022-04-13 14:31:57 +05:30
parent 2e1d8d89dd
commit 8d062e38b8
9 changed files with 300 additions and 22 deletions

View File

@ -1332,7 +1332,7 @@ typedef enum {
NAME_LOCAL_VAR, //< Including parameter.
NAME_UPVALUE, //< Local to an enclosing function.
NAME_GLOBAL_VAR,
NAME_BUILTIN, //< Native builtin function.
NAME_BUILTIN_FN, //< Native builtin function.
} NameDefnType;
// Identifier search result.
@ -1384,7 +1384,7 @@ static NameSearchResult compilerSearchName(Compiler* compiler,
// Search through builtin functions.
index = findBuiltinFunction(compiler->parser.vm, name, length);
if (index != -1) {
result.type = NAME_BUILTIN;
result.type = NAME_BUILTIN_FN;
result.index = index;
return result;
}
@ -1418,7 +1418,7 @@ static void compilerChangeStack(Compiler* compiler, int num);
// Forward declaration of grammar functions.
static void parsePrecedence(Compiler* compiler, Precedence precedence);
static int compileFunction(Compiler* compiler, bool is_literal);
static void compileFunction(Compiler* compiler, bool is_literal);
static void compileExpression(Compiler* compiler);
static void exprLiteral(Compiler* compiler);
@ -1540,12 +1540,13 @@ static void emitStoreGlobal(Compiler* compiler, int index) {
// Emit opcode to push the named value at the [index] in it's array.
static void emitPushName(Compiler* compiler, NameDefnType type, int index) {
ASSERT(index >= 0, OOPS);
switch (type) {
case NAME_NOT_DEFINED:
UNREACHABLE();
case NAME_LOCAL_VAR:
ASSERT(index >= 0, OOPS);
if (index < 9) { //< 0..8 locals have single opcode.
emitOpcode(compiler, (Opcode)(OP_PUSH_LOCAL_0 + index));
} else {
@ -1553,11 +1554,18 @@ static void emitPushName(Compiler* compiler, NameDefnType type, int index) {
emitByte(compiler, index);
}
return;
case NAME_UPVALUE:
emitOpcode(compiler, OP_PUSH_UPVALUE);
emitByte(compiler, index);
return;
case NAME_GLOBAL_VAR:
emitOpcode(compiler, OP_PUSH_GLOBAL);
emitByte(compiler, index);
return;
case NAME_BUILTIN:
case NAME_BUILTIN_FN:
emitOpcode(compiler, OP_PUSH_BUILTIN_FN);
emitByte(compiler, index);
return;
@ -1567,13 +1575,14 @@ static void emitPushName(Compiler* compiler, NameDefnType type, int index) {
// Emit opcode to store the stack top value to the named value at the [index]
// in it's array.
static void emitStoreName(Compiler* compiler, NameDefnType type, int index) {
ASSERT(index >= 0, OOPS);
switch (type) {
case NAME_NOT_DEFINED:
case NAME_BUILTIN:
case NAME_BUILTIN_FN:
UNREACHABLE();
case NAME_LOCAL_VAR:
ASSERT(index >= 0, OOPS);
if (index < 9) { //< 0..8 locals have single opcode.
emitOpcode(compiler, (Opcode)(OP_STORE_LOCAL_0 + index));
} else {
@ -1581,6 +1590,12 @@ static void emitStoreName(Compiler* compiler, NameDefnType type, int index) {
emitByte(compiler, index);
}
return;
case NAME_UPVALUE:
emitOpcode(compiler, OP_STORE_UPVALUE);
emitByte(compiler, index);
return;
case NAME_GLOBAL_VAR:
emitStoreGlobal(compiler, index);
return;
@ -1652,9 +1667,7 @@ static void exprInterpolation(Compiler* compiler) {
}
static void exprFunc(Compiler* compiler) {
int fn_index = compileFunction(compiler, true);
emitOpcode(compiler, OP_PUSH_CLOSURE);
emitShort(compiler, fn_index);
compileFunction(compiler, true);
}
static void exprName(Compiler* compiler) {
@ -1684,7 +1697,7 @@ static void exprName(Compiler* compiler) {
// like python does) and it's recommended to define all the globals
// before entering a local scope.
if (result.type == NAME_NOT_DEFINED || result.type == NAME_BUILTIN) {
if (result.type == NAME_NOT_DEFINED || result.type == NAME_BUILTIN_FN) {
name_type = (compiler->scope_depth == DEPTH_GLOBAL)
? NAME_GLOBAL_VAR
: NAME_LOCAL_VAR;
@ -2115,7 +2128,12 @@ static int compilerPopLocals(Compiler* compiler, int depth) {
// continue). So we need the pop instruction here but we still need the
// locals to continue parsing the next statements in the scope. They'll be
// popped once the scope is ended.
emitByte(compiler, OP_POP);
if (compiler->func->locals[local].is_upvalue) {
emitByte(compiler, OP_CLOSE_UPVALUE);
} else {
emitByte(compiler, OP_POP);
}
local--;
}
@ -2248,7 +2266,7 @@ static void compileStatement(Compiler* compiler);
static void compileBlockBody(Compiler* compiler, BlockType type);
// Compile a class definition.
static int compileClass(Compiler* compiler) {
static void compileClass(Compiler* compiler) {
// Consume the name of the type.
consume(compiler, TK_NAME, "Expected a type name.");
@ -2332,12 +2350,10 @@ static int compileClass(Compiler* compiler) {
compilerExitBlock(compiler);
emitFunctionEnd(compiler);
compilerPopFunc(compiler);
return -1; // TODO;
}
// Compile a function. If it's a literal function, the opcode that pushes a
// closure of it is emitted into the enclosing function.
static int compileFunction(Compiler* compiler, bool is_literal) {
static void compileFunction(Compiler* compiler, bool is_literal) {
const char* name;
int name_length;
@ -2428,7 +2444,20 @@ static int compileFunction(Compiler* compiler, bool is_literal) {
compilerPopFunc(compiler);
return fn_index;
// Note: After the above compilerPopFunc() call, we're back in the outer
// function of this function, and the emit calls below will write to the
// outer function. If it's a literal function, we need to push a closure
// of it on the stack.
if (is_literal) {
emitOpcode(compiler, OP_PUSH_CLOSURE);
emitShort(compiler, fn_index);
// Capture the upvalues when the closure is created.
for (int i = 0; i < curr_fn.ptr->upvalue_count; i++) {
emitByte(compiler, (curr_fn.upvalues[i].is_immediate) ? 1 : 0);
emitByte(compiler, curr_fn.upvalues[i].index);
}
}
}
// Finish a block body.
@ -2618,7 +2647,7 @@ static int compilerImportName(Compiler* compiler, int line,
// Make it possible to override any name (ie. the syntax `print = 1`
// should pass) and allow imported entries to have the same name of
// builtin functions.
case NAME_BUILTIN:
case NAME_BUILTIN_FN:
parseError(compiler, "Name '%.*s' already exists.", length, name);
return -1;
}

View File

@ -241,6 +241,15 @@ void dumpFunctionCode(PKVM* vm, Function* func) {
break;
}
case OP_PUSH_UPVALUE:
case OP_STORE_UPVALUE:
{
int index = READ_BYTE();
PRINT_INT(index);
NEWLINE();
break;
}
case OP_PUSH_CLOSURE:
{
int index = READ_SHORT();
@ -256,7 +265,11 @@ void dumpFunctionCode(PKVM* vm, Function* func) {
break;
}
case OP_POP: NO_ARGS(); break;
case OP_CLOSE_UPVALUE:
case OP_POP:
NO_ARGS();
break;
case OP_IMPORT:
{
int index = READ_SHORT();

View File

@ -78,7 +78,7 @@
// Name of a literal function. All literal function will have the same name but
// they're uniquely identified by their index in the script's function buffer.
#define LITERAL_FN_NAME "@literalFn"
#define LITERAL_FN_NAME "@func"
/*****************************************************************************/
/* ALLOCATION MACROS */

View File

@ -97,11 +97,24 @@ OPCODE(STORE_GLOBAL, 1, 0)
// params: 1 bytes index.
OPCODE(PUSH_BUILTIN_FN, 1, 1)
// Push an upvalue of the current closure at the index which is the first one
// byte argument.
// params: 1 byte index.
OPCODE(PUSH_UPVALUE, 1, 1)
// Store the stack top value to the upvalue at the index in the current
// closure's upvalues array and don't pop it, since it's the result of the
// assignment.
// params: 1 byte index.
OPCODE(STORE_UPVALUE, 1, 0)
// Push a closure for the function at the constant pool with index of the
// first 2 bytes arguments.
// params: 2 byte index.
OPCODE(PUSH_CLOSURE, 2, 1)
// Close the upvalue for the local at the stack top and pop it.
OPCODE(CLOSE_UPVALUE, 0, -1)
// Pop the stack top.
OPCODE(POP, 0, -1)

View File

@ -510,7 +510,7 @@ Fiber* newFiber(PKVM* vm, Closure* closure) {
fiber->sp = fiber->stack + 1;
} else {
// Allocate stack.
// Calculate the stack size.
int stack_size = utilPowerOf2Ceil(closure->fn->fn->stack_size + 1);
if (stack_size < MIN_STACK_SIZE) stack_size = MIN_STACK_SIZE;
fiber->stack = ALLOCATE_ARRAY(vm, Var, stack_size);
@ -529,6 +529,8 @@ Fiber* newFiber(PKVM* vm, Closure* closure) {
fiber->frames[0].rbp = fiber->ret;
}
fiber->open_upvalues = NULL;
// Initialize the return value to null (doesn't really have to do that here
// but if we're trying to debug it may crash when dumping the return value).
*fiber->ret = VAR_NULL;

View File

@ -465,6 +465,11 @@ struct Fiber {
// The stack pointer (%rsp) pointing to the stack top.
Var* sp;
// All the open upvalues will form a linked list in the fiber and the
// upvalues are sorted in the same order as their locals in the stack. The
// pointer below is the head of those upvalues, the one nearest the stack top.
Upvalue* open_upvalues;
// The stack base pointer of the current frame. It'll be updated before
// calling a native function. (`fiber->ret` === `curr_call_frame->rbp`). And
// also updated if the stack is reallocated (that's when it's about to get

View File

@ -568,6 +568,93 @@ static inline void reuseCallFrame(PKVM* vm, const Closure* closure) {
if (vm->fiber->stack_size <= needed) growStack(vm, needed);
}
// Return the open upvalue that refers to the stack slot [local]. If the fiber
// already has an open upvalue for that slot it is shared, otherwise a new one
// is created with newUpvalue() and linked into the fiber's open upvalue list.
//
// The list starts at fiber->open_upvalues and is kept ordered by the slot
// each upvalue points to, highest slot address first:
//
//   head -> u0 (slot s5) -> u1 (slot s3) -> u2 (slot s1) -> NULL
//
//   1. Capturing s3 finds u1 and returns it.
//   2. Capturing s4 creates a new upvalue between u0 and u1.
//   3. Capturing s6 creates a new upvalue that becomes the head.
//   4. Capturing s0 creates a new upvalue appended after u2.
//
// A linear walk is used; the search starts from the head, which is the
// highest captured slot.
static Upvalue* captureUpvalue(PKVM* vm, Fiber* fiber, Var* local) {

  // [link] is the pointer (either the list head or some upvalue's 'next'
  // field) that has to end up pointing at the upvalue of [local]. Step over
  // every open upvalue whose slot is located above [local].
  Upvalue** link = &fiber->open_upvalues;
  while (*link != NULL && (*link)->ptr > local) {
    link = &(*link)->next;
  }

  // The slot is already captured: share the existing upvalue.
  if (*link != NULL && (*link)->ptr == local) return *link;

  // Not captured yet. Splice a new upvalue in front of whatever [link]
  // points to (NULL when the new upvalue becomes the tail or the list was
  // empty), which keeps the list ordered.
  Upvalue* created = newUpvalue(vm, local);
  created->next = *link;
  *link = created;
  return created;
}
// Close every open upvalue of the [fiber] whose slot is at [top] or above it.
// Closing copies the slot's current value into the upvalue itself, redirects
// the upvalue's pointer to that copy and unlinks it from the head of the
// fiber's open upvalue list.
static void closeUpvalues(Fiber* fiber, Var* top) {
  for (;;) {
    Upvalue* head = fiber->open_upvalues;

    // Stop at the end of the list or at the first upvalue below [top].
    if (head == NULL || head->ptr < top) return;

    head->closed = *head->ptr;   // Take a copy of the value in the slot.
    head->ptr = &head->closed;   // From now on the upvalue owns its value.
    fiber->open_upvalues = head->next;
  }
}
static void reportError(PKVM* vm) {
ASSERT(VM_HAS_ERROR(vm), "runtimeError() should be called after an error.");
// TODO: pass the error to the caller of the fiber.
@ -863,13 +950,51 @@ L_vm_main_loop:
DISPATCH();
}
OPCODE(PUSH_UPVALUE):
{
uint8_t index = READ_BYTE();
PUSH(*(frame->closure->upvalues[index]->ptr));
DISPATCH();
}
OPCODE(STORE_UPVALUE):
{
uint8_t index = READ_BYTE();
*(frame->closure->upvalues[index]->ptr) = PEEK(-1);
DISPATCH();
}
OPCODE(PUSH_CLOSURE):
{
uint16_t index = READ_SHORT();
ASSERT_INDEX(index, module->constants.count);
ASSERT(IS_OBJ_TYPE(module->constants.data[index], OBJ_FUNC), OOPS);
Function* fn = (Function*)AS_OBJ(module->constants.data[index]);
PUSH(VAR_OBJ(newClosure(vm, fn)));
Closure* closure = newClosure(vm, fn);
// Capture the upvalues.
for (int i = 0; i < fn->upvalue_count; i++) {
uint8_t is_immediate = READ_BYTE();
uint8_t index = READ_BYTE();
if (is_immediate) {
// rbp[0] is the return value, rbp + 1 is the first local and so on.
closure->upvalues[i] = captureUpvalue(vm, vm->fiber,
(rbp + 1 + index));
} else {
// The upvalue is already captured by the current function, reuse it.
closure->upvalues[i] = frame->closure->upvalues[index];
}
}
PUSH(VAR_OBJ(closure));
DISPATCH();
}
OPCODE(CLOSE_UPVALUE):
{
closeUpvalues(vm->fiber, vm->fiber->sp - 1);
DROP();
DISPATCH();
}
@ -1169,6 +1294,9 @@ L_vm_main_loop:
OPCODE(RETURN):
{
// Close all the locals of the current frame.
closeUpvalues(vm->fiber, rbp + 1);
// Set the return value.
Var ret_value = POP();

87
tests/lang/closure.pk Normal file
View File

@ -0,0 +1,87 @@
## Simple upvalue.
## The function literal returned by f1 reads 'local', a variable of f1, and
## it's called only after f1 has returned.
def f1
local = "foo"
return func
return local
end
end
assert(f1()() == "foo")
## Parameters as upvalues.
## The innermost function reads 'x' (parameter of add3), 'y' (parameter of
## the middle function) and its own parameter 'z'. add3 is called twice with
## different arguments and each call chain is expected to keep its own values.
def add3(x)
return func(y)
return func(z)
return x + y + z
end
end
end
assert(add3(1)(2)(3) == 6);
assert(add3(7)(6)(4) == 17);
## Upvalue external to the inner function.
## 'local' belongs to f2, but it's read by 'fn' which is defined inside
## another function literal, so the value has to pass through the middle
## function that doesn't use it itself.
def f2
local = "bar"
return func
fn = func
return local
end
return fn
end
end
assert(f2()()() == "bar")
## Check if upvalues are shared between closures.
## _fn1 writes to 'local' and _fn2 reads it. Both are called after f3 has
## returned, and the value written through _fn1 must be visible to _fn2.
def f3
local = "baz"
_fn1 = func(x)
local = x
end
_fn2 = func
return local
end
return [_fn1, _fn2]
end
fns = f3()
fns[0]("qux")
assert(fns[1]() == "qux")
## Closures created inside a loop.
## Every closure appended to 'a' reads 'x' and 'j'. 'x' is assigned 20 after
## the closures were created and the asserts below expect the closures to see
## that new value, together with the value 'j' had in their own iteration.
def f4
a = []
x = 10
for i in 0..2
j = i ## 'i' is shared between the iterations, but 'j' isn't.
list_append(
a,
func
return x + j
end
)
end
x = 20
return a
end
a = f4()
assert(a[0]() == 20)
assert(a[1]() == 21)
## Deeply nested closures.
## c5 reads 'l1' (local of f5), 'l2' (local of c1) and 'l3' (local of c2).
## c3 and c4 don't use any of them themselves, they only sit between the
## functions that define the locals and the function that reads them.
def f5
  l1 = 12
  return func ## c1
    l2 = 34
    return func ## c2
      l3 = 56
      return func ## c3
        return func ## c4
          return func ## c5
            return l1 + l2 + l3
          end
        end
      end
    end
  end
end
## This has to be an assert: print() would only write the comparison result
## and a wrong value would go unnoticed by the test suite.
assert(f5()()()()()() == 102)
print('All TESTS PASSED')

View File

@ -18,6 +18,7 @@ TEST_SUITE = {
"lang/basics.pk",
"lang/builtin_fn.pk",
"lang/class.pk",
"lang/closure.pk",
"lang/core.pk",
"lang/controlflow.pk",
"lang/fibers.pk",