// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" #include "ScopedFlags.h" #include "doctest.h" using namespace Luau; namespace Luau { namespace Compile { uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount, const DenseHashMap& builtins); int computeCost(uint64_t model, const bool* varsConst, size_t varCount); } // namespace Compile } // namespace Luau TEST_SUITE_BEGIN("CostModel"); static uint64_t modelFunction(const char* source) { Allocator allocator; AstNameTable names(allocator); ParseResult result = Parser::parse(source, strlen(source), names, allocator); REQUIRE(result.root != nullptr); AstStatFunction* func = result.root->body.data[0]->as(); REQUIRE(func); return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size, DenseHashMap{nullptr}); } TEST_CASE("Expression") { uint64_t model = modelFunction(R"( function test(a, b, c) return a + (b + 1) * (b + 1) - c end )"); const bool args1[] = {false, false, false}; const bool args2[] = {false, true, false}; CHECK_EQ(5, Luau::Compile::computeCost(model, args1, 3)); CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 3)); } TEST_CASE("PropagateVariable") { uint64_t model = modelFunction(R"( function test(a) local b = a * a * a return b * b end )"); const bool args1[] = {false}; const bool args2[] = {true}; CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); CHECK_EQ(0, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("LoopAssign") { uint64_t model = modelFunction(R"( function test(a) for i=1,3 do a[i] = i end end )"); const bool args1[] = {false}; const bool args2[] = {true}; // loop baseline cost is 5 CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("MutableVariable") { uint64_t model = modelFunction(R"( function test(a, b) local x = a * a x += b return x * x end )"); const bool args1[] = {false}; const bool args2[] = {true}; CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("ImportCall") { uint64_t model = modelFunction(R"( function test(a) return Instance.new(a) end )"); const bool args1[] = {false}; const bool args2[] = {true}; CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("FastCall") { uint64_t model = modelFunction(R"( function test(a) return math.abs(a + 1) end )"); const bool args1[] = {false}; const bool args2[] = {true}; // note: we currently don't treat fast calls differently from cost model perspective CHECK_EQ(6, Luau::Compile::computeCost(model, args1, 1)); CHECK_EQ(5, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("ControlFlow") { uint64_t model = modelFunction(R"( function test(a) while a < 0 do a += 1 end for i=10,1,-1 do a += 1 end for i in pairs({}) do a += 1 if a % 2 == 0 then continue end end repeat a += 1 if a % 2 == 0 then break end until a > 10 return a end )"); const bool args1[] = {false}; const bool args2[] = {true}; CHECK_EQ(76, Luau::Compile::computeCost(model, args1, 1)); CHECK_EQ(73, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("Conditional") { uint64_t model = modelFunction(R"( function test(a) return if a < 0 then -a else a end )"); const bool args1[] = {false}; const bool args2[] = {true}; CHECK_EQ(4, Luau::Compile::computeCost(model, args1, 1)); CHECK_EQ(2, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("VarArgs") { uint64_t model = modelFunction(R"( function test(...) return select('#', ...) :: number end )"); CHECK_EQ(8, Luau::Compile::computeCost(model, nullptr, 0)); } TEST_CASE("TablesFunctions") { uint64_t model = modelFunction(R"( function test() return { 42, op = function() end } end )"); CHECK_EQ(22, Luau::Compile::computeCost(model, nullptr, 0)); } TEST_CASE("CostOverflow") { uint64_t model = modelFunction(R"( function test() return {{{{{{{{{{{{{{{}}}}}}}}}}}}}}} end )"); CHECK_EQ(127, Luau::Compile::computeCost(model, nullptr, 0)); } TEST_CASE("TableAssign") { uint64_t model = modelFunction(R"( function test(a) for i=1,#a do a[i] = i end end )"); const bool args1[] = {false}; const bool args2[] = {true}; CHECK_EQ(7, Luau::Compile::computeCost(model, args1, 1)); CHECK_EQ(6, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("InterpString") { uint64_t model = modelFunction(R"( function test(a) return `hello, {a}!` end )"); const bool args1[] = {false}; const bool args2[] = {true}; CHECK_EQ(3, Luau::Compile::computeCost(model, args1, 1)); CHECK_EQ(3, Luau::Compile::computeCost(model, args2, 1)); } TEST_CASE("MultipleAssignments") { uint64_t model = modelFunction(R"( function test(a) local x = 0 x = a x = a + 1 x, x, x = a x = a, a, a end )"); const bool args1[] = {false}; const bool args2[] = {true}; CHECK_EQ(8, Luau::Compile::computeCost(model, args1, 1)); CHECK_EQ(7, Luau::Compile::computeCost(model, args2, 1)); } TEST_SUITE_END();