Sync to upstream/release/536 (#592)

This commit is contained in:
Arseny Kapoulkine 2022-07-14 15:52:26 -07:00 committed by GitHub
parent e87009f5b2
commit 5b2e39c922
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
53 changed files with 3875 additions and 2168 deletions

View File

@ -78,16 +78,6 @@ struct AutocompleteResult
using ModuleName = std::string; using ModuleName = std::string;
using StringCompletionCallback = std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassTypeVar*> ctx)>; using StringCompletionCallback = std::function<std::optional<AutocompleteEntryMap>(std::string tag, std::optional<const ClassTypeVar*> ctx)>;
struct OwningAutocompleteResult
{
AutocompleteResult result;
ModulePtr module;
std::unique_ptr<SourceModule> sourceModule;
};
AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback);
// Deprecated, do not use in new work.
OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback);
} // namespace Luau } // namespace Luau

View File

@ -99,6 +99,7 @@ struct ConstraintGraphBuilder
void visit(NotNull<Scope2> scope, AstStat* stat); void visit(NotNull<Scope2> scope, AstStat* stat);
void visit(NotNull<Scope2> scope, AstStatBlock* block); void visit(NotNull<Scope2> scope, AstStatBlock* block);
void visit(NotNull<Scope2> scope, AstStatLocal* local); void visit(NotNull<Scope2> scope, AstStatLocal* local);
void visit(NotNull<Scope2> scope, AstStatFor* for_);
void visit(NotNull<Scope2> scope, AstStatLocalFunction* function); void visit(NotNull<Scope2> scope, AstStatLocalFunction* function);
void visit(NotNull<Scope2> scope, AstStatFunction* function); void visit(NotNull<Scope2> scope, AstStatFunction* function);
void visit(NotNull<Scope2> scope, AstStatReturn* ret); void visit(NotNull<Scope2> scope, AstStatReturn* ret);

View File

@ -127,13 +127,6 @@ struct Frontend
CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess CheckResult check(const ModuleName& name, std::optional<FrontendOptions> optionOverride = {}); // new shininess
LintResult lint(const ModuleName& name, std::optional<LintOptions> enabledLintWarnings = {}); LintResult lint(const ModuleName& name, std::optional<LintOptions> enabledLintWarnings = {});
/** Lint some code that has no associated DataModel object
*
* Since this source fragment has no name, we cannot cache its AST. Instead,
* we return it to the caller to use as they wish.
*/
std::pair<SourceModule, LintResult> lintFragment(std::string_view source, std::optional<LintOptions> enabledLintWarnings = {});
LintResult lint(const SourceModule& module, std::optional<LintOptions> enabledLintWarnings = {}); LintResult lint(const SourceModule& module, std::optional<LintOptions> enabledLintWarnings = {});
bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; bool isDirty(const ModuleName& name, bool forAutocomplete = false) const;

View File

@ -79,6 +79,7 @@ private:
void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifySingletons(TypeId subTy, TypeId superTy);
void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false);
void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false);
void tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed);
void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed);

View File

@ -7,7 +7,6 @@
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
#include "Luau/TypePack.h" #include "Luau/TypePack.h"
#include "Luau/Parser.h" // TODO: only needed for autocompleteSource which is deprecated
#include <algorithm> #include <algorithm>
#include <unordered_set> #include <unordered_set>
@ -1407,8 +1406,8 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point;
if (!FFlag::LuauSelfCallAutocompleteFix2 && isString(ty)) if (!FFlag::LuauSelfCallAutocompleteFix2 && isString(ty))
return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), return {
ancestry}; autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, ancestry), ancestry};
else else
return {autocompleteProps(*module, typeArena, ty, indexType, ancestry), ancestry}; return {autocompleteProps(*module, typeArena, ty, indexType, ancestry), ancestry};
} }
@ -1507,8 +1506,8 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M
else if (AstStatIf* statIf = node->as<AstStatIf>(); statIf && !statIf->elseLocation.has_value()) else if (AstStatIf* statIf = node->as<AstStatIf>(); statIf && !statIf->elseLocation.has_value())
{ {
return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, return {
ancestry}; {{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry};
} }
else if (AstStatIf* statIf = parent->as<AstStatIf>(); statIf && node->is<AstStatBlock>()) else if (AstStatIf* statIf = parent->as<AstStatIf>(); statIf && node->is<AstStatBlock>())
{ {
@ -1628,32 +1627,4 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName
return autocompleteResult; return autocompleteResult;
} }
OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback)
{
// TODO: Remove #include "Luau/Parser.h" with this function
auto sourceModule = std::make_unique<SourceModule>();
ParseOptions parseOptions;
parseOptions.captureComments = true;
ParseResult result = Parser::parse(source.data(), source.size(), *sourceModule->names, *sourceModule->allocator, parseOptions);
if (!result.root)
return {AutocompleteResult{}, {}, nullptr};
sourceModule->name = "FRAGMENT_SCRIPT";
sourceModule->root = result.root;
sourceModule->mode = Mode::Strict;
sourceModule->commentLocations = std::move(result.commentLocations);
TypeChecker& typeChecker = frontend.typeCheckerForAutocomplete;
ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict);
OwningAutocompleteResult autocompleteResult = {
autocomplete(*sourceModule, module, typeChecker, &frontend.arenaForAutocomplete, position, callback), std::move(module),
std::move(sourceModule)};
frontend.arenaForAutocomplete.clear();
return autocompleteResult;
}
} // namespace Luau } // namespace Luau

View File

@ -103,6 +103,8 @@ void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStat* stat)
visit(scope, s); visit(scope, s);
else if (auto s = stat->as<AstStatLocal>()) else if (auto s = stat->as<AstStatLocal>())
visit(scope, s); visit(scope, s);
else if (auto s = stat->as<AstStatFor>())
visit(scope, s);
else if (auto f = stat->as<AstStatFunction>()) else if (auto f = stat->as<AstStatFunction>())
visit(scope, f); visit(scope, f);
else if (auto f = stat->as<AstStatLocalFunction>()) else if (auto f = stat->as<AstStatLocalFunction>())
@ -167,6 +169,27 @@ void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatLocal* local)
} }
} }
void ConstraintGraphBuilder::visit(NotNull<Scope2> scope, AstStatFor* for_)
{
auto checkNumber = [&](AstExpr* expr)
{
if (!expr)
return;
TypeId t = check(scope, expr);
addConstraint(scope, SubtypeConstraint{t, singletonTypes.numberType});
};
checkNumber(for_->from);
checkNumber(for_->to);
checkNumber(for_->step);
NotNull<Scope2> forScope = childScope(for_->location, scope);
forScope->bindings[for_->var] = singletonTypes.numberType;
visit(forScope, for_->body);
}
void addConstraints(Constraint* constraint, NotNull<Scope2> scope) void addConstraints(Constraint* constraint, NotNull<Scope2> scope)
{ {
scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); scope->constraints.reserve(scope->constraints.size() + scope->constraints.size());

View File

@ -662,29 +662,6 @@ LintResult Frontend::lint(const ModuleName& name, std::optional<Luau::LintOption
return lint(*sourceModule, enabledLintWarnings); return lint(*sourceModule, enabledLintWarnings);
} }
std::pair<SourceModule, LintResult> Frontend::lintFragment(std::string_view source, std::optional<Luau::LintOptions> enabledLintWarnings)
{
LUAU_TIMETRACE_SCOPE("Frontend::lintFragment", "Frontend");
const Config& config = configResolver->getConfig("");
SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions);
uint64_t ignoreLints = LintWarning::parseMask(sourceModule.hotcomments);
Luau::LintOptions lintOptions = enabledLintWarnings.value_or(config.enabledLint);
lintOptions.warningMask &= ~ignoreLints;
double timestamp = getTimestamp();
std::vector<LintWarning> warnings = Luau::lint(sourceModule.root, *sourceModule.names.get(), typeChecker.globalScope, nullptr,
sourceModule.hotcomments, enabledLintWarnings.value_or(config.enabledLint));
stats.timeLint += getTimestamp() - timestamp;
return {std::move(sourceModule), classifyLints(warnings, config)};
}
LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOptions> enabledLintWarnings) LintResult Frontend::lint(const SourceModule& module, std::optional<Luau::LintOptions> enabledLintWarnings)
{ {
LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend");
@ -958,7 +935,7 @@ std::optional<ModuleInfo> FrontendModuleResolver::resolveModuleInfo(const Module
{ {
// CLI-43699 // CLI-43699
// If we can't find the current module name, that's because we bypassed the frontend's initializer // If we can't find the current module name, that's because we bypassed the frontend's initializer
// and called typeChecker.check directly. (This is done by autocompleteSource, for example). // and called typeChecker.check directly.
// In that case, requires will always fail. // In that case, requires will always fail.
return std::nullopt; return std::nullopt;
} }

View File

@ -2688,6 +2688,21 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
else else
seenMode = true; seenMode = true;
} }
else if (first == "optimize")
{
size_t notspace = hc.content.find_first_not_of(" \t", space);
if (space == std::string::npos || notspace == std::string::npos)
emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "optimize directive requires an optimization level");
else
{
const char* level = hc.content.c_str() + notspace;
if (strcmp(level, "0") && strcmp(level, "1") && strcmp(level, "2"))
emitWarning(context, LintWarning::Code_CommentDirective, hc.location,
"optimize directive uses unknown optimization level '%s', 0..2 expected", level);
}
}
else else
{ {
static const char* kHotComments[] = { static const char* kHotComments[] = {
@ -2695,6 +2710,7 @@ static void lintComments(LintContext& context, const std::vector<HotComment>& ho
"nocheck", "nocheck",
"nonstrict", "nonstrict",
"strict", "strict",
"optimize",
}; };
if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments))) if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments)))

View File

@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCopyBeforeNormalizing, false)
LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false);
LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeFlagIsConservative, false);
LUAU_FASTFLAGVARIABLE(LuauFixNormalizationOfCyclicUnions, false);
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(LuauQuantifyConstrained) LUAU_FASTFLAG(LuauQuantifyConstrained)
@ -340,13 +341,19 @@ struct Normalize final : TypeVarVisitor
return false; return false;
UnionTypeVar* utv = &const_cast<UnionTypeVar&>(utvRef); UnionTypeVar* utv = &const_cast<UnionTypeVar&>(utvRef);
std::vector<TypeId> options = std::move(utv->options);
// TODO: Clip tempOptions and optionsRef when clipping FFlag::LuauFixNormalizationOfCyclicUnions
std::vector<TypeId> tempOptions;
if (!FFlag::LuauFixNormalizationOfCyclicUnions)
tempOptions = std::move(utv->options);
std::vector<TypeId>& optionsRef = FFlag::LuauFixNormalizationOfCyclicUnions ? utv->options : tempOptions;
// We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar // We might transmute, so it's not safe to rely on the builtin traversal logic of visitTypeVar
for (TypeId option : options) for (TypeId option : optionsRef)
traverse(option); traverse(option);
std::vector<TypeId> newOptions = normalizeUnion(options); std::vector<TypeId> newOptions = normalizeUnion(optionsRef);
const bool normal = areNormal(newOptions, seen, ice); const bool normal = areNormal(newOptions, seen, ice);
@ -371,51 +378,106 @@ struct Normalize final : TypeVarVisitor
IntersectionTypeVar* itv = &const_cast<IntersectionTypeVar&>(itvRef); IntersectionTypeVar* itv = &const_cast<IntersectionTypeVar&>(itvRef);
std::vector<TypeId> oldParts = std::move(itv->parts); if (FFlag::LuauFixNormalizationOfCyclicUnions)
for (TypeId part : oldParts)
traverse(part);
std::vector<TypeId> tables;
for (TypeId part : oldParts)
{ {
part = follow(part); std::vector<TypeId> oldParts = itv->parts;
if (get<TableTypeVar>(part)) IntersectionTypeVar newIntersection;
tables.push_back(part);
else
{
Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD
combineIntoIntersection(replacer, itv, part);
}
}
// Don't allocate a new table if there's just one in the intersection. for (TypeId part : oldParts)
if (tables.size() == 1) traverse(part);
itv->parts.push_back(tables[0]);
else if (!tables.empty())
{
const TableTypeVar* first = get<TableTypeVar>(tables[0]);
LUAU_ASSERT(first);
TypeId newTable = arena.addType(TableTypeVar{first->state, first->level}); std::vector<TypeId> tables;
TableTypeVar* ttv = getMutable<TableTypeVar>(newTable); for (TypeId part : oldParts)
for (TypeId part : tables)
{ {
// Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need part = follow(part);
// to be rewritten to point at 'newTable' in the clone. if (get<TableTypeVar>(part))
Replacer replacer{&arena, part, newTable}; tables.push_back(part);
combineIntoTable(replacer, ttv, part); else
{
Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD
combineIntoIntersection(replacer, &newIntersection, part);
}
} }
itv->parts.push_back(newTable); // Don't allocate a new table if there's just one in the intersection.
if (tables.size() == 1)
newIntersection.parts.push_back(tables[0]);
else if (!tables.empty())
{
const TableTypeVar* first = get<TableTypeVar>(tables[0]);
LUAU_ASSERT(first);
TypeId newTable = arena.addType(TableTypeVar{first->state, first->level});
TableTypeVar* ttv = getMutable<TableTypeVar>(newTable);
for (TypeId part : tables)
{
// Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need
// to be rewritten to point at 'newTable' in the clone.
Replacer replacer{&arena, part, newTable};
combineIntoTable(replacer, ttv, part);
}
newIntersection.parts.push_back(newTable);
}
itv->parts = std::move(newIntersection.parts);
asMutable(ty)->normal = areNormal(itv->parts, seen, ice);
if (itv->parts.size() == 1)
{
TypeId part = itv->parts[0];
*asMutable(ty) = BoundTypeVar{part};
}
} }
else
asMutable(ty)->normal = areNormal(itv->parts, seen, ice);
if (itv->parts.size() == 1)
{ {
TypeId part = itv->parts[0]; std::vector<TypeId> oldParts = std::move(itv->parts);
*asMutable(ty) = BoundTypeVar{part};
for (TypeId part : oldParts)
traverse(part);
std::vector<TypeId> tables;
for (TypeId part : oldParts)
{
part = follow(part);
if (get<TableTypeVar>(part))
tables.push_back(part);
else
{
Replacer replacer{&arena, nullptr, nullptr}; // FIXME this is super super WEIRD
combineIntoIntersection(replacer, itv, part);
}
}
// Don't allocate a new table if there's just one in the intersection.
if (tables.size() == 1)
itv->parts.push_back(tables[0]);
else if (!tables.empty())
{
const TableTypeVar* first = get<TableTypeVar>(tables[0]);
LUAU_ASSERT(first);
TypeId newTable = arena.addType(TableTypeVar{first->state, first->level});
TableTypeVar* ttv = getMutable<TableTypeVar>(newTable);
for (TypeId part : tables)
{
// Intuition: If combineIntoTable() needs to clone a table, any references to 'part' are cyclic and need
// to be rewritten to point at 'newTable' in the clone.
Replacer replacer{&arena, part, newTable};
combineIntoTable(replacer, ttv, part);
}
itv->parts.push_back(newTable);
}
asMutable(ty)->normal = areNormal(itv->parts, seen, ice);
if (itv->parts.size() == 1)
{
TypeId part = itv->parts[0];
*asMutable(ty) = BoundTypeVar{part};
}
} }
return false; return false;
@ -590,6 +652,24 @@ struct Normalize final : TypeVarVisitor
table->props.insert({propName, prop}); table->props.insert({propName, prop});
} }
if (FFlag::LuauFixNormalizationOfCyclicUnions)
{
if (tyTable->indexer)
{
if (table->indexer)
{
table->indexer->indexType = combine(replacer, replacer.smartClone(tyTable->indexer->indexType), table->indexer->indexType);
table->indexer->indexResultType =
combine(replacer, replacer.smartClone(tyTable->indexer->indexResultType), table->indexer->indexResultType);
}
else
{
table->indexer =
TableIndexer{replacer.smartClone(tyTable->indexer->indexType), replacer.smartClone(tyTable->indexer->indexResultType)};
}
}
}
table->state = combineTableStates(table->state, tyTable->state); table->state = combineTableStates(table->state, tyTable->state);
table->level = max(table->level, tyTable->level); table->level = max(table->level, tyTable->level);
} }

View File

@ -19,7 +19,6 @@ LUAU_FASTFLAG(LuauUnknownAndNeverType)
* Fair warning: Setting this will break a lot of Luau unit tests. * Fair warning: Setting this will break a lot of Luau unit tests.
*/ */
LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false)
LUAU_FASTFLAGVARIABLE(LuauToStringTableBracesNewlines, false)
namespace Luau namespace Luau
{ {
@ -572,54 +571,22 @@ struct TypeVarStringifier
{ {
case TableState::Sealed: case TableState::Sealed:
state.result.invalid = true; state.result.invalid = true;
if (FFlag::LuauToStringTableBracesNewlines) openbrace = "{|";
{ closedbrace = "|}";
openbrace = "{|";
closedbrace = "|}";
}
else
{
openbrace = "{| ";
closedbrace = " |}";
}
break; break;
case TableState::Unsealed: case TableState::Unsealed:
if (FFlag::LuauToStringTableBracesNewlines) openbrace = "{";
{ closedbrace = "}";
openbrace = "{";
closedbrace = "}";
}
else
{
openbrace = "{ ";
closedbrace = " }";
}
break; break;
case TableState::Free: case TableState::Free:
state.result.invalid = true; state.result.invalid = true;
if (FFlag::LuauToStringTableBracesNewlines) openbrace = "{-";
{ closedbrace = "-}";
openbrace = "{-";
closedbrace = "-}";
}
else
{
openbrace = "{- ";
closedbrace = " -}";
}
break; break;
case TableState::Generic: case TableState::Generic:
state.result.invalid = true; state.result.invalid = true;
if (FFlag::LuauToStringTableBracesNewlines) openbrace = "{+";
{ closedbrace = "+}";
openbrace = "{+";
closedbrace = "+}";
}
else
{
openbrace = "{+ ";
closedbrace = " +}";
}
break; break;
} }
@ -638,8 +605,7 @@ struct TypeVarStringifier
bool comma = false; bool comma = false;
if (ttv.indexer) if (ttv.indexer)
{ {
if (FFlag::LuauToStringTableBracesNewlines) state.newline();
state.newline();
state.emit("["); state.emit("[");
stringify(ttv.indexer->indexType); stringify(ttv.indexer->indexType);
state.emit("]: "); state.emit("]: ");
@ -656,10 +622,8 @@ struct TypeVarStringifier
state.emit(","); state.emit(",");
state.newline(); state.newline();
} }
else if (FFlag::LuauToStringTableBracesNewlines) else
{
state.newline(); state.newline();
}
size_t length = state.result.name.length() - oldLength; size_t length = state.result.name.length() - oldLength;
@ -686,13 +650,10 @@ struct TypeVarStringifier
} }
state.dedent(); state.dedent();
if (FFlag::LuauToStringTableBracesNewlines) if (comma)
{ state.newline();
if (comma) else
state.newline(); state.emit(" ");
else
state.emit(" ");
}
state.emit(closedbrace); state.emit(closedbrace);
state.unsee(&ttv); state.unsee(&ttv);
@ -860,7 +821,6 @@ struct TypeVarStringifier
{ {
state.emit("never"); state.emit("never");
} }
}; };
struct TypePackStringifier struct TypePackStringifier

View File

@ -322,10 +322,13 @@ struct TypeChecker2 : public AstVisitor
{ {
pack = follow(pack); pack = follow(pack);
while (auto tp = get<TypePack>(pack)) while (true)
{ {
if (tp->head.empty() && tp->tail) auto tp = get<TypePack>(pack);
if (tp && tp->head.empty() && tp->tail)
pack = *tp->tail; pack = *tp->tail;
else
break;
} }
if (auto ty = first(pack)) if (auto ty = first(pack))

View File

@ -48,6 +48,8 @@ LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false)
LUAU_FASTFLAGVARIABLE(LuauCheckLenMT, false) LUAU_FASTFLAGVARIABLE(LuauCheckLenMT, false)
LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false) LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false)
LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false)
LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false)
LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false)
namespace Luau namespace Luau
{ {
@ -2443,8 +2445,15 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
operandType = stripFromNilAndReport(operandType, expr.location); operandType = stripFromNilAndReport(operandType, expr.location);
if (get<ErrorTypeVar>(operandType) || get<NeverTypeVar>(operandType)) // # operator is guaranteed to return number
return {!FFlag::LuauUnknownAndNeverType ? errorRecoveryType(scope) : operandType}; if ((FFlag::LuauNeverTypesAndOperatorsInference && get<AnyTypeVar>(operandType)) || get<ErrorTypeVar>(operandType) ||
get<NeverTypeVar>(operandType))
{
if (FFlag::LuauNeverTypesAndOperatorsInference)
return {numberType};
else
return {!FFlag::LuauUnknownAndNeverType ? errorRecoveryType(scope) : operandType};
}
DenseHashSet<TypeId> seen{nullptr}; DenseHashSet<TypeId> seen{nullptr};
@ -2610,6 +2619,13 @@ TypeId TypeChecker::checkRelationalOperation(
case AstExprBinary::CompareGe: case AstExprBinary::CompareGe:
case AstExprBinary::CompareLe: case AstExprBinary::CompareLe:
{ {
if (FFlag::LuauNeverTypesAndOperatorsInference)
{
// If one of the operand is never, it doesn't make sense to unify these.
if (get<NeverTypeVar>(lhsType) || get<NeverTypeVar>(rhsType))
return booleanType;
}
/* Subtlety here: /* Subtlety here:
* We need to do this unification first, but there are situations where we don't actually want to * We need to do this unification first, but there are situations where we don't actually want to
* report any problems that might have been surfaced as a result of this step because we might already * report any problems that might have been surfaced as a result of this step because we might already
@ -2787,8 +2803,10 @@ TypeId TypeChecker::checkBinaryOperation(
// If we know nothing at all about the lhs type, we can usually say nothing about the result. // If we know nothing at all about the lhs type, we can usually say nothing about the result.
// The notable exception to this is the equality and inequality operators, which always produce a boolean. // The notable exception to this is the equality and inequality operators, which always produce a boolean.
const bool lhsIsAny = get<AnyTypeVar>(lhsType) || get<ErrorTypeVar>(lhsType); const bool lhsIsAny = get<AnyTypeVar>(lhsType) || get<ErrorTypeVar>(lhsType) ||
const bool rhsIsAny = get<AnyTypeVar>(rhsType) || get<ErrorTypeVar>(rhsType); (FFlag::LuauUnknownAndNeverType && FFlag::LuauNeverTypesAndOperatorsInference && get<NeverTypeVar>(lhsType));
const bool rhsIsAny = get<AnyTypeVar>(rhsType) || get<ErrorTypeVar>(rhsType) ||
(FFlag::LuauUnknownAndNeverType && FFlag::LuauNeverTypesAndOperatorsInference && get<NeverTypeVar>(rhsType));
if (lhsIsAny) if (lhsIsAny)
return lhsType; return lhsType;
@ -3775,7 +3793,10 @@ void TypeChecker::checkArgumentList(
} }
TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}});
state.tryUnify(varPack, tail); if (FFlag::LuauReturnsFromCallsitesAreNotWidened)
state.tryUnify(tail, varPack);
else
state.tryUnify(varPack, tail);
return; return;
} }
} }
@ -4414,7 +4435,8 @@ WithPredicate<TypePackId> TypeChecker::checkExprList(const ScopePtr& scope, cons
if (FFlag::LuauUnknownAndNeverType && containsNever(typePack)) if (FFlag::LuauUnknownAndNeverType && containsNever(typePack))
{ {
// f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, ...never) // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never,
// ...never)
uninhabitable = true; uninhabitable = true;
continue; continue;
} }
@ -4436,7 +4458,8 @@ WithPredicate<TypePackId> TypeChecker::checkExprList(const ScopePtr& scope, cons
if (FFlag::LuauUnknownAndNeverType && get<NeverTypeVar>(type)) if (FFlag::LuauUnknownAndNeverType && get<NeverTypeVar>(type))
{ {
// f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never, ...never) // f(), g() where f() returns (never, string) or (string, never) means this whole TypePackId is uninhabitable, so return (never,
// ...never)
uninhabitable = true; uninhabitable = true;
continue; continue;
} }

View File

@ -26,6 +26,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit)
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAGVARIABLE(LuauDeduceGmatchReturnTypes, false) LUAU_FASTFLAGVARIABLE(LuauDeduceGmatchReturnTypes, false)
LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false)
LUAU_FASTFLAGVARIABLE(LuauDeduceFindMatchReturnTypes, false)
namespace Luau namespace Luau
{ {
@ -36,6 +37,12 @@ std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch( static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate); TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate);
TypeId follow(TypeId t) TypeId follow(TypeId t)
{ {
return follow(t, [](TypeId t) { return follow(t, [](TypeId t) {
@ -164,10 +171,12 @@ bool isNumber(TypeId ty)
// Returns true when ty is a subtype of string // Returns true when ty is a subtype of string
bool isString(TypeId ty) bool isString(TypeId ty)
{ {
if (isPrim(ty, PrimitiveTypeVar::String) || get<StringSingleton>(get<SingletonTypeVar>(follow(ty)))) ty = follow(ty);
if (isPrim(ty, PrimitiveTypeVar::String) || get<StringSingleton>(get<SingletonTypeVar>(ty)))
return true; return true;
if (auto utv = get<UnionTypeVar>(follow(ty))) if (auto utv = get<UnionTypeVar>(ty))
return std::all_of(begin(utv), end(utv), isString); return std::all_of(begin(utv), end(utv), isString);
return false; return false;
@ -178,8 +187,8 @@ bool maybeString(TypeId ty)
{ {
ty = follow(ty); ty = follow(ty);
if (isPrim(ty, PrimitiveTypeVar::String) || get<AnyTypeVar>(ty)) if (isPrim(ty, PrimitiveTypeVar::String) || get<AnyTypeVar>(ty))
return true; return true;
if (auto utv = get<UnionTypeVar>(ty)) if (auto utv = get<UnionTypeVar>(ty))
return std::any_of(begin(utv), end(utv), maybeString); return std::any_of(begin(utv), end(utv), maybeString);
@ -233,6 +242,8 @@ bool isOverloadedFunction(TypeId ty)
std::optional<TypeId> getMetatable(TypeId type) std::optional<TypeId> getMetatable(TypeId type)
{ {
type = follow(type);
if (const MetatableTypeVar* mtType = get<MetatableTypeVar>(type)) if (const MetatableTypeVar* mtType = get<MetatableTypeVar>(type))
return mtType->metatable; return mtType->metatable;
else if (const ClassTypeVar* classType = get<ClassTypeVar>(type)) else if (const ClassTypeVar* classType = get<ClassTypeVar>(type))
@ -765,18 +776,24 @@ TypeId SingletonTypes::makeStringMetatable()
makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})}); makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})});
attachMagicFunction(gmatchFunc, magicFunctionGmatch); attachMagicFunction(gmatchFunc, magicFunctionGmatch);
const TypeId matchFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}),
arena->addTypePack(TypePackVar{VariadicTypePack{FFlag::LuauDeduceFindMatchReturnTypes ? stringType : optionalString}})});
attachMagicFunction(matchFunc, magicFunctionMatch);
const TypeId findFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}),
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})});
attachMagicFunction(findFunc, magicFunctionFind);
TableTypeVar::Props stringLib = { TableTypeVar::Props stringLib = {
{"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}},
{"char", {arena->addType(FunctionTypeVar{numberVariadicList, arena->addTypePack({stringType})})}}, {"char", {arena->addType(FunctionTypeVar{numberVariadicList, arena->addTypePack({stringType})})}},
{"find", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), {"find", {findFunc}},
arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})})}},
{"format", {formatFn}}, // FIXME {"format", {formatFn}}, // FIXME
{"gmatch", {gmatchFunc}}, {"gmatch", {gmatchFunc}},
{"gsub", {gsubFunc}}, {"gsub", {gsubFunc}},
{"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}},
{"lower", {stringToStringType}}, {"lower", {stringToStringType}},
{"match", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), {"match", {matchFunc}},
arena->addTypePack(TypePackVar{VariadicTypePack{optionalString}})})}},
{"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}},
{"reverse", {stringToStringType}}, {"reverse", {stringToStringType}},
{"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}},
@ -1213,6 +1230,102 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})}; return WithPredicate<TypePackId>{arena.addTypePack({iteratorType})};
} }
static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
if (!FFlag::LuauDeduceFindMatchReturnTypes)
return std::nullopt;
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 3)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
std::vector<TypeId> returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() == 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, expr.args.data[initIndex]->location);
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate<TypePackId> withPredicate)
{
if (!FFlag::LuauDeduceFindMatchReturnTypes)
return std::nullopt;
auto [paramPack, _predicates] = withPredicate;
const auto& [params, tail] = flatten(paramPack);
if (params.size() < 2 || params.size() > 4)
return std::nullopt;
TypeArena& arena = typechecker.currentModule->internalTypes;
AstExprConstantString* pattern = nullptr;
size_t patternIndex = expr.self ? 0 : 1;
if (expr.args.size > patternIndex)
pattern = expr.args.data[patternIndex]->as<AstExprConstantString>();
if (!pattern)
return std::nullopt;
bool plain = false;
size_t plainIndex = expr.self ? 2 : 3;
if (expr.args.size > plainIndex)
{
AstExprConstantBool* p = expr.args.data[plainIndex]->as<AstExprConstantBool>();
plain = p && p->value;
}
std::vector<TypeId> returnTypes;
if (!plain)
{
returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size);
if (returnTypes.empty())
return std::nullopt;
}
typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}});
const TypeId optionalBoolean = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.booleanType}});
size_t initIndex = expr.self ? 1 : 2;
if (params.size() >= 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, expr.args.data[initIndex]->location);
if (params.size() == 4 && expr.args.size > plainIndex)
typechecker.unify(params[3], optionalBoolean, expr.args.data[plainIndex]->location);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});
const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList};
}
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate) std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate)
{ {
type = follow(type); type = follow(type);

View File

@ -21,6 +21,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation);
LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauErrorRecoveryType);
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(LuauQuantifyConstrained) LUAU_FASTFLAG(LuauQuantifyConstrained)
LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false)
namespace Luau namespace Luau
{ {
@ -432,7 +433,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
{ {
// Normally, if the subtype is free, it should not be bound to any, unknown, or error types. // Normally, if the subtype is free, it should not be bound to any, unknown, or error types.
// But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors. // But for bug compatibility, we'll only apply this rule to unknown. Doing this will silence cascading type errors.
if (get<UnknownTypeVar>(superTy)) if (log.get<UnknownTypeVar>(superTy))
return; return;
} }
@ -473,10 +474,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
return tryUnifyWithAny(superTy, subTy); return tryUnifyWithAny(superTy, subTy);
} }
if (get<ErrorTypeVar>(subTy)) if (log.get<ErrorTypeVar>(subTy))
return tryUnifyWithAny(superTy, subTy); return tryUnifyWithAny(superTy, subTy);
if (get<NeverTypeVar>(subTy)) if (log.get<NeverTypeVar>(subTy))
return tryUnifyWithAny(superTy, subTy); return tryUnifyWithAny(superTy, subTy);
auto& cache = sharedState.cachedUnify; auto& cache = sharedState.cachedUnify;
@ -538,6 +539,16 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool
{ {
tryUnifyTables(subTy, superTy, isIntersection); tryUnifyTables(subTy, superTy, isIntersection);
} }
else if (FFlag::LuauScalarShapeSubtyping && log.get<TableTypeVar>(superTy) &&
(log.get<PrimitiveTypeVar>(subTy) || log.get<SingletonTypeVar>(subTy)))
{
tryUnifyScalarShape(subTy, superTy, /*reversed*/ false);
}
else if (FFlag::LuauScalarShapeSubtyping && log.get<TableTypeVar>(subTy) &&
(log.get<PrimitiveTypeVar>(superTy) || log.get<SingletonTypeVar>(superTy)))
{
tryUnifyScalarShape(subTy, superTy, /*reversed*/ true);
}
// tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical.
else if (log.getMutable<MetatableTypeVar>(superTy)) else if (log.getMutable<MetatableTypeVar>(superTy))
@ -1600,6 +1611,60 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection)
} }
} }
void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed)
{
LUAU_ASSERT(FFlag::LuauScalarShapeSubtyping);
TypeId osubTy = subTy;
TypeId osuperTy = superTy;
if (reversed)
std::swap(subTy, superTy);
if (auto ttv = log.get<TableTypeVar>(superTy); !ttv || ttv->state != TableState::Free)
return reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}});
auto fail = [&](std::optional<TypeError> e) {
std::string reason = "The former's metatable does not satisfy the requirements.";
if (e)
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason, *e}});
else
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy, reason}});
};
// Given t1 where t1 = { lower: (t1) -> (a, b...) }
// It should be the case that `string <: t1` iff `(subtype's metatable).__index <: t1`
if (auto metatable = getMetatable(subTy))
{
auto mttv = log.get<TableTypeVar>(*metatable);
if (!mttv)
fail(std::nullopt);
if (auto it = mttv->props.find("__index"); it != mttv->props.end())
{
TypeId ty = it->second.type;
Unifier child = makeChildUnifier();
child.tryUnify_(ty, superTy);
if (auto e = hasUnificationTooComplex(child.errors))
reportError(*e);
else if (!child.errors.empty())
fail(child.errors.front());
log.concat(std::move(child.log));
return;
}
else
{
return fail(std::nullopt);
}
}
reportError(TypeError{location, TypeMismatch{osuperTy, osubTy}});
return;
}
TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> seen) TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> seen)
{ {
ty = follow(ty); ty = follow(ty);
@ -1916,7 +1981,8 @@ void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy)
sharedState.tempSeenTy.clear(); sharedState.tempSeenTy.clear();
sharedState.tempSeenTp.clear(); sharedState.tempSeenTp.clear();
Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, FFlag::LuauUnknownAndNeverType ? anyTy : getSingletonTypes().anyType, anyTp); Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types,
FFlag::LuauUnknownAndNeverType ? anyTy : getSingletonTypes().anyType, anyTp);
} }
void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp)

View File

@ -24,6 +24,8 @@ bool lua_telemetry_parsed_named_non_function_type = false;
LUAU_FASTFLAGVARIABLE(LuauErrorParseIntegerIssues, false) LUAU_FASTFLAGVARIABLE(LuauErrorParseIntegerIssues, false)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false)
LUAU_FASTFLAGVARIABLE(LuauAlwaysCaptureHotComments, false)
bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_bin_integer = false;
bool lua_telemetry_parsed_out_of_range_hex_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false;
bool lua_telemetry_parsed_double_prefix_hex_integer = false; bool lua_telemetry_parsed_double_prefix_hex_integer = false;
@ -2918,21 +2920,23 @@ AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const
void Parser::nextLexeme() void Parser::nextLexeme()
{ {
if (options.captureComments) if (options.captureComments || FFlag::LuauAlwaysCaptureHotComments)
{ {
Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type; Lexeme::Type type = lexer.next(/* skipComments= */ false, true).type;
while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment)
{ {
const Lexeme& lexeme = lexer.current(); const Lexeme& lexeme = lexer.current();
commentLocations.push_back(Comment{lexeme.type, lexeme.location});
if (options.captureComments)
commentLocations.push_back(Comment{lexeme.type, lexeme.location});
// Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme.
// The parser will turn this into a proper syntax error. // The parser will turn this into a proper syntax error.
if (lexeme.type == Lexeme::BrokenComment) if (lexeme.type == Lexeme::BrokenComment)
return; return;
// Comments starting with ! are called "hot comments" and contain directives for type checking / linting // Comments starting with ! are called "hot comments" and contain directives for type checking / linting / compiling
if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!')
{ {
const char* text = lexeme.data; const char* text = lexeme.data;

View File

@ -175,6 +175,7 @@ endif()
if(LUAU_BUILD_TESTS) if(LUAU_BUILD_TESTS)
target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS})
target_compile_definitions(Luau.UnitTest PRIVATE DOCTEST_CONFIG_DOUBLE_STRINGIFY)
target_include_directories(Luau.UnitTest PRIVATE extern) target_include_directories(Luau.UnitTest PRIVATE extern)
target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen) target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler Luau.CodeGen)

View File

@ -8,8 +8,8 @@
namespace Luau namespace Luau
{ {
class AstStatBlock;
class AstNameTable; class AstNameTable;
struct ParseResult;
class BytecodeBuilder; class BytecodeBuilder;
class BytecodeEncoder; class BytecodeEncoder;
@ -58,7 +58,7 @@ private:
}; };
// compiles bytecode into bytecode builder using either a pre-parsed AST or parsing it from source; throws on errors // compiles bytecode into bytecode builder using either a pre-parsed AST or parsing it from source; throws on errors
void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options = {}); void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& options = {});
void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const CompileOptions& options = {}, const ParseOptions& parseOptions = {}); void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const CompileOptions& options = {}, const ParseOptions& parseOptions = {});
// compiles bytecode into a bytecode blob, that either contains the valid bytecode or an encoded error that luau_load can decode // compiles bytecode into a bytecode blob, that either contains the valid bytecode or an encoded error that luau_load can decode

View File

@ -0,0 +1,463 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "BuiltinFolding.h"
#include "Luau/Bytecode.h"
#include <math.h>
namespace Luau
{
namespace Compile
{
const double kRadDeg = 3.14159265358979323846 / 180.0;
static Constant cvar()
{
return Constant();
}
static Constant cbool(bool v)
{
Constant res = {Constant::Type_Boolean};
res.valueBoolean = v;
return res;
}
static Constant cnum(double v)
{
Constant res = {Constant::Type_Number};
res.valueNumber = v;
return res;
}
static Constant cstring(const char* v)
{
Constant res = {Constant::Type_String};
res.stringLength = unsigned(strlen(v));
res.valueString = v;
return res;
}
static Constant ctype(const Constant& c)
{
LUAU_ASSERT(c.type != Constant::Type_Unknown);
switch (c.type)
{
case Constant::Type_Nil:
return cstring("nil");
case Constant::Type_Boolean:
return cstring("boolean");
case Constant::Type_Number:
return cstring("number");
case Constant::Type_String:
return cstring("string");
default:
LUAU_ASSERT(!"Unsupported constant type");
return cvar();
}
}
static uint32_t bit32(double v)
{
// convert through signed 64-bit integer to match runtime behavior and gracefully truncate negative integers
return uint32_t(int64_t(v));
}
Constant foldBuiltin(int bfid, const Constant* args, size_t count)
{
switch (bfid)
{
case LBF_MATH_ABS:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(fabs(args[0].valueNumber));
break;
case LBF_MATH_ACOS:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(acos(args[0].valueNumber));
break;
case LBF_MATH_ASIN:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(asin(args[0].valueNumber));
break;
case LBF_MATH_ATAN2:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
return cnum(atan2(args[0].valueNumber, args[1].valueNumber));
break;
case LBF_MATH_ATAN:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(atan(args[0].valueNumber));
break;
case LBF_MATH_CEIL:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(ceil(args[0].valueNumber));
break;
case LBF_MATH_COSH:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(cosh(args[0].valueNumber));
break;
case LBF_MATH_COS:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(cos(args[0].valueNumber));
break;
case LBF_MATH_DEG:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(args[0].valueNumber / kRadDeg);
break;
case LBF_MATH_EXP:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(exp(args[0].valueNumber));
break;
case LBF_MATH_FLOOR:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(floor(args[0].valueNumber));
break;
case LBF_MATH_FMOD:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
return cnum(fmod(args[0].valueNumber, args[1].valueNumber));
break;
// Note: FREXP isn't folded since it returns multiple values
case LBF_MATH_LDEXP:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
return cnum(ldexp(args[0].valueNumber, int(args[1].valueNumber)));
break;
case LBF_MATH_LOG10:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(log10(args[0].valueNumber));
break;
case LBF_MATH_LOG:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(log(args[0].valueNumber));
else if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
if (args[1].valueNumber == 2.0)
return cnum(log2(args[0].valueNumber));
else if (args[1].valueNumber == 10.0)
return cnum(log10(args[0].valueNumber));
else
return cnum(log(args[0].valueNumber) / log(args[1].valueNumber));
}
break;
case LBF_MATH_MAX:
if (count >= 1 && args[0].type == Constant::Type_Number)
{
double r = args[0].valueNumber;
for (size_t i = 1; i < count; ++i)
{
if (args[i].type != Constant::Type_Number)
return cvar();
double a = args[i].valueNumber;
r = (a > r) ? a : r;
}
return cnum(r);
}
break;
case LBF_MATH_MIN:
if (count >= 1 && args[0].type == Constant::Type_Number)
{
double r = args[0].valueNumber;
for (size_t i = 1; i < count; ++i)
{
if (args[i].type != Constant::Type_Number)
return cvar();
double a = args[i].valueNumber;
r = (a < r) ? a : r;
}
return cnum(r);
}
break;
// Note: MODF isn't folded since it returns multiple values
case LBF_MATH_POW:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
return cnum(pow(args[0].valueNumber, args[1].valueNumber));
break;
case LBF_MATH_RAD:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(args[0].valueNumber * kRadDeg);
break;
case LBF_MATH_SINH:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(sinh(args[0].valueNumber));
break;
case LBF_MATH_SIN:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(sin(args[0].valueNumber));
break;
case LBF_MATH_SQRT:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(sqrt(args[0].valueNumber));
break;
case LBF_MATH_TANH:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(tanh(args[0].valueNumber));
break;
case LBF_MATH_TAN:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(tan(args[0].valueNumber));
break;
case LBF_BIT32_ARSHIFT:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
uint32_t u = bit32(args[0].valueNumber);
int s = int(args[1].valueNumber);
if (unsigned(s) < 32)
return cnum(double(uint32_t(int32_t(u) >> s)));
}
break;
case LBF_BIT32_BAND:
if (count >= 1 && args[0].type == Constant::Type_Number)
{
uint32_t r = bit32(args[0].valueNumber);
for (size_t i = 1; i < count; ++i)
{
if (args[i].type != Constant::Type_Number)
return cvar();
r &= bit32(args[i].valueNumber);
}
return cnum(double(r));
}
break;
case LBF_BIT32_BNOT:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(double(uint32_t(~bit32(args[0].valueNumber))));
break;
case LBF_BIT32_BOR:
if (count >= 1 && args[0].type == Constant::Type_Number)
{
uint32_t r = bit32(args[0].valueNumber);
for (size_t i = 1; i < count; ++i)
{
if (args[i].type != Constant::Type_Number)
return cvar();
r |= bit32(args[i].valueNumber);
}
return cnum(double(r));
}
break;
case LBF_BIT32_BXOR:
if (count >= 1 && args[0].type == Constant::Type_Number)
{
uint32_t r = bit32(args[0].valueNumber);
for (size_t i = 1; i < count; ++i)
{
if (args[i].type != Constant::Type_Number)
return cvar();
r ^= bit32(args[i].valueNumber);
}
return cnum(double(r));
}
break;
case LBF_BIT32_BTEST:
if (count >= 1 && args[0].type == Constant::Type_Number)
{
uint32_t r = bit32(args[0].valueNumber);
for (size_t i = 1; i < count; ++i)
{
if (args[i].type != Constant::Type_Number)
return cvar();
r &= bit32(args[i].valueNumber);
}
return cbool(r != 0);
}
break;
case LBF_BIT32_EXTRACT:
if (count == 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number)
{
uint32_t u = bit32(args[0].valueNumber);
int f = int(args[1].valueNumber);
int w = int(args[2].valueNumber);
if (f >= 0 && w > 0 && f + w <= 32)
{
uint32_t m = ~(0xfffffffeu << (w - 1));
return cnum(double((u >> f) & m));
}
}
break;
case LBF_BIT32_LROTATE:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
uint32_t u = bit32(args[0].valueNumber);
int s = int(args[1].valueNumber);
return cnum(double((u << (s & 31)) | (u >> ((32 - s) & 31))));
}
break;
case LBF_BIT32_LSHIFT:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
uint32_t u = bit32(args[0].valueNumber);
int s = int(args[1].valueNumber);
if (unsigned(s) < 32)
return cnum(double(u << s));
}
break;
case LBF_BIT32_REPLACE:
if (count == 4 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number &&
args[3].type == Constant::Type_Number)
{
uint32_t n = bit32(args[0].valueNumber);
uint32_t v = bit32(args[1].valueNumber);
int f = int(args[2].valueNumber);
int w = int(args[3].valueNumber);
if (f >= 0 && w > 0 && f + w <= 32)
{
uint32_t m = ~(0xfffffffeu << (w - 1));
return cnum(double((n & ~(m << f)) | ((v & m) << f)));
}
}
break;
case LBF_BIT32_RROTATE:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
uint32_t u = bit32(args[0].valueNumber);
int s = int(args[1].valueNumber);
return cnum(double((u >> (s & 31)) | (u << ((32 - s) & 31))));
}
break;
case LBF_BIT32_RSHIFT:
if (count == 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number)
{
uint32_t u = bit32(args[0].valueNumber);
int s = int(args[1].valueNumber);
if (unsigned(s) < 32)
return cnum(double(u >> s));
}
break;
case LBF_TYPE:
if (count == 1 && args[0].type != Constant::Type_Unknown)
return ctype(args[0]);
break;
case LBF_STRING_BYTE:
if (count == 1 && args[0].type == Constant::Type_String)
{
if (args[0].stringLength > 0)
return cnum(double(uint8_t(args[0].valueString[0])));
}
else if (count == 2 && args[0].type == Constant::Type_String && args[1].type == Constant::Type_Number)
{
int i = int(args[1].valueNumber);
if (i > 0 && unsigned(i) <= args[0].stringLength)
return cnum(double(uint8_t(args[0].valueString[i - 1])));
}
break;
case LBF_STRING_LEN:
if (count == 1 && args[0].type == Constant::Type_String)
return cnum(double(args[0].stringLength));
break;
case LBF_TYPEOF:
if (count == 1 && args[0].type != Constant::Type_Unknown)
return ctype(args[0]);
break;
case LBF_MATH_CLAMP:
if (count == 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number)
{
double min = args[1].valueNumber;
double max = args[2].valueNumber;
if (min <= max)
{
double v = args[0].valueNumber;
v = v < min ? min : v;
v = v > max ? max : v;
return cnum(v);
}
}
break;
case LBF_MATH_SIGN:
if (count == 1 && args[0].type == Constant::Type_Number)
{
double v = args[0].valueNumber;
return cnum(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0);
}
break;
case LBF_MATH_ROUND:
if (count == 1 && args[0].type == Constant::Type_Number)
return cnum(round(args[0].valueNumber));
break;
}
return cvar();
}
} // namespace Compile
} // namespace Luau

View File

@ -0,0 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "ConstantFolding.h"
namespace Luau
{
namespace Compile
{
Constant foldBuiltin(int bfid, const Constant* args, size_t count);
} // namespace Compile
} // namespace Luau

View File

@ -40,11 +40,8 @@ Builtin getBuiltin(AstExpr* node, const DenseHashMap<AstName, Global>& globals,
} }
} }
int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options)
{ {
if (builtin.empty())
return -1;
if (builtin.isGlobal("assert")) if (builtin.isGlobal("assert"))
return LBF_ASSERT; return LBF_ASSERT;
@ -200,5 +197,49 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options)
return -1; return -1;
} }
struct BuiltinVisitor : AstVisitor
{
DenseHashMap<AstExprCall*, int>& result;
const DenseHashMap<AstName, Global>& globals;
const DenseHashMap<AstLocal*, Variable>& variables;
const CompileOptions& options;
BuiltinVisitor(DenseHashMap<AstExprCall*, int>& result, const DenseHashMap<AstName, Global>& globals,
const DenseHashMap<AstLocal*, Variable>& variables, const CompileOptions& options)
: result(result)
, globals(globals)
, variables(variables)
, options(options)
{
}
bool visit(AstExprCall* node) override
{
Builtin builtin = node->self ? Builtin() : getBuiltin(node->func, globals, variables);
if (builtin.empty())
return true;
int bfid = getBuiltinFunctionId(builtin, options);
// getBuiltinFunctionId optimistically assumes all select() calls are builtin but actually the second argument must be a vararg
if (bfid == LBF_SELECT_VARARG && !(node->args.size == 2 && node->args.data[1]->is<AstExprVarargs>()))
bfid = -1;
if (bfid >= 0)
result[node] = bfid;
return true; // propagate to nested calls
}
};
void analyzeBuiltins(DenseHashMap<AstExprCall*, int>& result, const DenseHashMap<AstName, Global>& globals,
const DenseHashMap<AstLocal*, Variable>& variables, const CompileOptions& options, AstNode* root)
{
BuiltinVisitor visitor{result, globals, variables, options};
root->visit(&visitor);
}
} // namespace Compile } // namespace Compile
} // namespace Luau } // namespace Luau

View File

@ -35,7 +35,9 @@ struct Builtin
}; };
Builtin getBuiltin(AstExpr* node, const DenseHashMap<AstName, Global>& globals, const DenseHashMap<AstLocal*, Variable>& variables); Builtin getBuiltin(AstExpr* node, const DenseHashMap<AstName, Global>& globals, const DenseHashMap<AstLocal*, Variable>& variables);
int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options);
void analyzeBuiltins(DenseHashMap<AstExprCall*, int>& result, const DenseHashMap<AstName, Global>& globals,
const DenseHashMap<AstLocal*, Variable>& variables, const CompileOptions& options, AstNode* root);
} // namespace Compile } // namespace Compile
} // namespace Luau } // namespace Luau

View File

@ -120,6 +120,17 @@ inline bool isSkipC(LuauOpcode op)
switch (op) switch (op)
{ {
case LOP_LOADB: case LOP_LOADB:
return true;
default:
return false;
}
}
inline bool isFastCall(LuauOpcode op)
{
switch (op)
{
case LOP_FASTCALL: case LOP_FASTCALL:
case LOP_FASTCALL1: case LOP_FASTCALL1:
case LOP_FASTCALL2: case LOP_FASTCALL2:
@ -137,6 +148,8 @@ static int getJumpTarget(uint32_t insn, uint32_t pc)
if (isJumpD(op)) if (isJumpD(op))
return int(pc + LUAU_INSN_D(insn) + 1); return int(pc + LUAU_INSN_D(insn) + 1);
else if (isFastCall(op))
return int(pc + LUAU_INSN_C(insn) + 2);
else if (isSkipC(op) && LUAU_INSN_C(insn)) else if (isSkipC(op) && LUAU_INSN_C(insn))
return int(pc + LUAU_INSN_C(insn) + 1); return int(pc + LUAU_INSN_C(insn) + 1);
else if (op == LOP_JUMPX) else if (op == LOP_JUMPX)
@ -479,7 +492,7 @@ bool BytecodeBuilder::patchSkipC(size_t jumpLabel, size_t targetLabel)
unsigned int jumpInsn = insns[jumpLabel]; unsigned int jumpInsn = insns[jumpLabel];
(void)jumpInsn; (void)jumpInsn;
LUAU_ASSERT(isSkipC(LuauOpcode(LUAU_INSN_OP(jumpInsn)))); LUAU_ASSERT(isSkipC(LuauOpcode(LUAU_INSN_OP(jumpInsn))) || isFastCall(LuauOpcode(LUAU_INSN_OP(jumpInsn))));
LUAU_ASSERT(LUAU_INSN_C(jumpInsn) == 0); LUAU_ASSERT(LUAU_INSN_C(jumpInsn) == 0);
int offset = int(targetLabel) - int(jumpLabel) - 1; int offset = int(targetLabel) - int(jumpLabel) - 1;

View File

@ -25,6 +25,9 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5)
LUAU_FASTFLAGVARIABLE(LuauCompileNoIpairs, false) LUAU_FASTFLAGVARIABLE(LuauCompileNoIpairs, false)
LUAU_FASTFLAGVARIABLE(LuauCompileFoldBuiltins, false)
LUAU_FASTFLAGVARIABLE(LuauCompileBetterMultret, false)
namespace Luau namespace Luau
{ {
@ -75,6 +78,12 @@ static BytecodeBuilder::StringRef sref(AstArray<char> data)
return {data.data, data.size}; return {data.data, data.size};
} }
static BytecodeBuilder::StringRef sref(AstArray<const char> data)
{
LUAU_ASSERT(data.data);
return {data.data, data.size};
}
struct Compiler struct Compiler
{ {
struct RegScope; struct RegScope;
@ -89,6 +98,7 @@ struct Compiler
, constants(nullptr) , constants(nullptr)
, locstants(nullptr) , locstants(nullptr)
, tableShapes(nullptr) , tableShapes(nullptr)
, builtins(nullptr)
{ {
// preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays // preallocate some buffers that are very likely to grow anyway; this works around std::vector's inefficient growth policy for small arrays
localStack.reserve(16); localStack.reserve(16);
@ -245,7 +255,7 @@ struct Compiler
{ {
f.canInline = true; f.canInline = true;
f.stackSize = stackSize; f.stackSize = stackSize;
f.costModel = modelCost(func->body, func->args.data, func->args.size); f.costModel = modelCost(func->body, func->args.data, func->args.size, builtins);
// track functions that only ever return a single value so that we can convert multret calls to fixedret calls // track functions that only ever return a single value so that we can convert multret calls to fixedret calls
if (allPathsEndWithReturn(func->body)) if (allPathsEndWithReturn(func->body))
@ -262,22 +272,63 @@ struct Compiler
return fid; return fid;
} }
// returns true if node can return multiple values; may conservatively return true even if expr is known to return just a single value
bool isExprMultRet(AstExpr* node)
{
if (!FFlag::LuauCompileBetterMultret)
return node->is<AstExprCall>() || node->is<AstExprVarargs>();
AstExprCall* expr = node->as<AstExprCall>();
if (!expr)
return node->is<AstExprVarargs>();
// conservative version, optimized for compilation throughput
if (options.optimizationLevel <= 1)
return true;
// handles builtin calls that can be constant-folded
// without this we may omit some optimizations eg compiling fast calls without use of FASTCALL2K
if (isConstant(expr))
return false;
// handles local function calls where we know only one argument is returned
AstExprFunction* func = getFunctionExpr(expr->func);
Function* fi = func ? functions.find(func) : nullptr;
if (fi && fi->returnsOne)
return false;
// unrecognized call, so we conservatively assume multret
return true;
}
// note: this doesn't just clobber target (assuming it's temp), but also clobbers *all* allocated registers >= target! // note: this doesn't just clobber target (assuming it's temp), but also clobbers *all* allocated registers >= target!
// this is important to be able to support "multret" semantics due to Lua call frame structure // this is important to be able to support "multret" semantics due to Lua call frame structure
bool compileExprTempMultRet(AstExpr* node, uint8_t target) bool compileExprTempMultRet(AstExpr* node, uint8_t target)
{ {
if (AstExprCall* expr = node->as<AstExprCall>()) if (AstExprCall* expr = node->as<AstExprCall>())
{ {
// Optimization: convert multret calls to functions that always return one value to fixedret calls; this facilitates inlining // Optimization: convert multret calls that always return one value to fixedret calls; this facilitates inlining/constant folding
if (options.optimizationLevel >= 2) if (options.optimizationLevel >= 2)
{ {
AstExprFunction* func = getFunctionExpr(expr->func); if (FFlag::LuauCompileBetterMultret)
Function* fi = func ? functions.find(func) : nullptr;
if (fi && fi->returnsOne)
{ {
compileExprTemp(node, target); if (!isExprMultRet(node))
return false; {
compileExprTemp(node, target);
return false;
}
}
else
{
AstExprFunction* func = getFunctionExpr(expr->func);
Function* fi = func ? functions.find(func) : nullptr;
if (fi && fi->returnsOne)
{
compileExprTemp(node, target);
return false;
}
} }
} }
@ -483,8 +534,7 @@ struct Compiler
varc[i] = isConstant(expr->args.data[i]); varc[i] = isConstant(expr->args.data[i]);
// if the last argument only returns a single value, all following arguments are nil // if the last argument only returns a single value, all following arguments are nil
if (expr->args.size != 0 && if (expr->args.size != 0 && !isExprMultRet(expr->args.data[expr->args.size - 1]))
!(expr->args.data[expr->args.size - 1]->is<AstExprCall>() || expr->args.data[expr->args.size - 1]->is<AstExprVarargs>()))
for (size_t i = expr->args.size; i < func->args.size && i < 8; ++i) for (size_t i = expr->args.size; i < func->args.size && i < 8; ++i)
varc[i] = true; varc[i] = true;
@ -523,7 +573,7 @@ struct Compiler
AstLocal* var = func->args.data[i]; AstLocal* var = func->args.data[i];
AstExpr* arg = i < expr->args.size ? expr->args.data[i] : nullptr; AstExpr* arg = i < expr->args.size ? expr->args.data[i] : nullptr;
if (i + 1 == expr->args.size && func->args.size > expr->args.size && (arg->is<AstExprCall>() || arg->is<AstExprVarargs>())) if (i + 1 == expr->args.size && func->args.size > expr->args.size && isExprMultRet(arg))
{ {
// if the last argument can return multiple values, we need to compute all of them into the remaining arguments // if the last argument can return multiple values, we need to compute all of them into the remaining arguments
unsigned int tail = unsigned(func->args.size - expr->args.size) + 1; unsigned int tail = unsigned(func->args.size - expr->args.size) + 1;
@ -591,7 +641,7 @@ struct Compiler
} }
// fold constant values updated above into expressions in the function body // fold constant values updated above into expressions in the function body
foldConstants(constants, variables, locstants, func->body); foldConstants(constants, variables, locstants, builtinsFold, func->body);
bool usedFallthrough = false; bool usedFallthrough = false;
@ -632,7 +682,7 @@ struct Compiler
if (Constant* var = locstants.find(func->args.data[i])) if (Constant* var = locstants.find(func->args.data[i]))
var->type = Constant::Type_Unknown; var->type = Constant::Type_Unknown;
foldConstants(constants, variables, locstants, func->body); foldConstants(constants, variables, locstants, builtinsFold, func->body);
} }
void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false)
@ -675,29 +725,23 @@ struct Compiler
int bfid = -1; int bfid = -1;
if (options.optimizationLevel >= 1) if (options.optimizationLevel >= 1 && !expr->self)
{ if (const int* id = builtins.find(expr))
Builtin builtin = getBuiltin(expr->func, globals, variables); bfid = *id;
bfid = getBuiltinFunctionId(builtin, options);
}
if (bfid == LBF_SELECT_VARARG) if (bfid == LBF_SELECT_VARARG)
{ {
// Optimization: compile select(_, ...) as FASTCALL1; the builtin will read variadic arguments directly // Optimization: compile select(_, ...) as FASTCALL1; the builtin will read variadic arguments directly
// note: for now we restrict this to single-return expressions since our runtime code doesn't deal with general cases // note: for now we restrict this to single-return expressions since our runtime code doesn't deal with general cases
if (multRet == false && targetCount == 1 && expr->args.size == 2 && expr->args.data[1]->is<AstExprVarargs>()) if (multRet == false && targetCount == 1)
return compileExprSelectVararg(expr, target, targetCount, targetTop, multRet, regs); return compileExprSelectVararg(expr, target, targetCount, targetTop, multRet, regs);
else else
bfid = -1; bfid = -1;
} }
// Optimization: for 1/2 argument fast calls use specialized opcodes // Optimization: for 1/2 argument fast calls use specialized opcodes
if (!expr->self && bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2 && !isExprMultRet(expr->args.data[expr->args.size - 1]))
{ return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid);
AstExpr* last = expr->args.data[expr->args.size - 1];
if (!last->is<AstExprCall>() && !last->is<AstExprVarargs>())
return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid);
}
if (expr->self) if (expr->self)
{ {
@ -2495,7 +2539,7 @@ struct Compiler
} }
AstLocal* var = stat->var; AstLocal* var = stat->var;
uint64_t costModel = modelCost(stat->body, &var, 1); uint64_t costModel = modelCost(stat->body, &var, 1, builtins);
// we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to unrolling // we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to unrolling
bool varc = true; bool varc = true;
@ -2533,7 +2577,7 @@ struct Compiler
locstants[var].type = Constant::Type_Number; locstants[var].type = Constant::Type_Number;
locstants[var].valueNumber = from + iv * step; locstants[var].valueNumber = from + iv * step;
foldConstants(constants, variables, locstants, stat); foldConstants(constants, variables, locstants, builtinsFold, stat);
size_t iterJumps = loopJumps.size(); size_t iterJumps = loopJumps.size();
@ -2561,7 +2605,7 @@ struct Compiler
// clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again // clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again
locstants[var].type = Constant::Type_Unknown; locstants[var].type = Constant::Type_Unknown;
foldConstants(constants, variables, locstants, stat); foldConstants(constants, variables, locstants, builtinsFold, stat);
} }
void compileStatFor(AstStatFor* stat) void compileStatFor(AstStatFor* stat)
@ -3368,7 +3412,11 @@ struct Compiler
bool visit(AstStatReturn* stat) override bool visit(AstStatReturn* stat) override
{ {
if (stat->list.size == 1) if (FFlag::LuauCompileBetterMultret)
{
returnsOne &= stat->list.size == 1 && !self->isExprMultRet(stat->list.data[0]);
}
else if (stat->list.size == 1)
{ {
AstExpr* value = stat->list.data[0]; AstExpr* value = stat->list.data[0];
@ -3487,6 +3535,8 @@ struct Compiler
DenseHashMap<AstExpr*, Constant> constants; DenseHashMap<AstExpr*, Constant> constants;
DenseHashMap<AstLocal*, Constant> locstants; DenseHashMap<AstLocal*, Constant> locstants;
DenseHashMap<AstExprTable*, TableShape> tableShapes; DenseHashMap<AstExprTable*, TableShape> tableShapes;
DenseHashMap<AstExprCall*, int> builtins;
const DenseHashMap<AstExprCall*, int>* builtinsFold = nullptr;
unsigned int regTop = 0; unsigned int regTop = 0;
unsigned int stackSize = 0; unsigned int stackSize = 0;
@ -3502,10 +3552,21 @@ struct Compiler
std::vector<Capture> captures; std::vector<Capture> captures;
}; };
void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options) void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& inputOptions)
{ {
LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler"); LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler");
LUAU_ASSERT(parseResult.root);
LUAU_ASSERT(parseResult.errors.empty());
CompileOptions options = inputOptions;
for (const HotComment& hc : parseResult.hotcomments)
if (hc.header && hc.content.compare(0, 9, "optimize ") == 0)
options.optimizationLevel = std::max(0, std::min(2, atoi(hc.content.c_str() + 9)));
AstStatBlock* root = parseResult.root;
Compiler compiler(bytecode, options); Compiler compiler(bytecode, options);
// since access to some global objects may result in values that change over time, we block imports from non-readonly tables // since access to some global objects may result in values that change over time, we block imports from non-readonly tables
@ -3514,10 +3575,17 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName
// this pass analyzes mutability of locals/globals and associates locals with their initial values // this pass analyzes mutability of locals/globals and associates locals with their initial values
trackValues(compiler.globals, compiler.variables, root); trackValues(compiler.globals, compiler.variables, root);
// builtin folding is enabled on optimization level 2 since we can't deoptimize folding at runtime
if (options.optimizationLevel >= 2 && FFlag::LuauCompileFoldBuiltins)
compiler.builtinsFold = &compiler.builtins;
if (options.optimizationLevel >= 1) if (options.optimizationLevel >= 1)
{ {
// this pass tracks which calls are builtins and can be compiled more efficiently
analyzeBuiltins(compiler.builtins, compiler.globals, compiler.variables, options, root);
// this pass analyzes constantness of expressions // this pass analyzes constantness of expressions
foldConstants(compiler.constants, compiler.variables, compiler.locstants, root); foldConstants(compiler.constants, compiler.variables, compiler.locstants, compiler.builtinsFold, root);
// this pass analyzes table assignments to estimate table shapes for initially empty tables // this pass analyzes table assignments to estimate table shapes for initially empty tables
predictTableShapes(compiler.tableShapes, root); predictTableShapes(compiler.tableShapes, root);
@ -3559,9 +3627,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const
if (!result.errors.empty()) if (!result.errors.empty())
throw ParseErrors(result.errors); throw ParseErrors(result.errors);
AstStatBlock* root = result.root; compileOrThrow(bytecode, result, names, options);
compileOrThrow(bytecode, root, names, options);
} }
std::string compile(const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions, BytecodeEncoder* encoder) std::string compile(const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions, BytecodeEncoder* encoder)
@ -3584,7 +3650,7 @@ std::string compile(const std::string& source, const CompileOptions& options, co
try try
{ {
BytecodeBuilder bcb(encoder); BytecodeBuilder bcb(encoder);
compileOrThrow(bcb, result.root, names, options); compileOrThrow(bcb, result, names, options);
return bcb.getBytecode(); return bcb.getBytecode();
} }

View File

@ -1,6 +1,8 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "ConstantFolding.h" #include "ConstantFolding.h"
#include "BuiltinFolding.h"
#include <math.h> #include <math.h>
namespace Luau namespace Luau
@ -193,13 +195,18 @@ struct ConstantVisitor : AstVisitor
DenseHashMap<AstLocal*, Variable>& variables; DenseHashMap<AstLocal*, Variable>& variables;
DenseHashMap<AstLocal*, Constant>& locals; DenseHashMap<AstLocal*, Constant>& locals;
const DenseHashMap<AstExprCall*, int>* builtins;
bool wasEmpty = false; bool wasEmpty = false;
ConstantVisitor( std::vector<Constant> builtinArgs;
DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables, DenseHashMap<AstLocal*, Constant>& locals)
ConstantVisitor(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables,
DenseHashMap<AstLocal*, Constant>& locals, const DenseHashMap<AstExprCall*, int>* builtins)
: constants(constants) : constants(constants)
, variables(variables) , variables(variables)
, locals(locals) , locals(locals)
, builtins(builtins)
{ {
// since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries // since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries
wasEmpty = constants.empty() && locals.empty(); wasEmpty = constants.empty() && locals.empty();
@ -253,8 +260,37 @@ struct ConstantVisitor : AstVisitor
{ {
analyze(expr->func); analyze(expr->func);
for (size_t i = 0; i < expr->args.size; ++i) if (const int* bfid = builtins ? builtins->find(expr) : nullptr)
analyze(expr->args.data[i]); {
// since recursive calls to analyze() may reuse the vector we need to be careful and preserve existing contents
size_t offset = builtinArgs.size();
bool canFold = true;
builtinArgs.reserve(offset + expr->args.size);
for (size_t i = 0; i < expr->args.size; ++i)
{
Constant ac = analyze(expr->args.data[i]);
if (ac.type == Constant::Type_Unknown)
canFold = false;
else
builtinArgs.push_back(ac);
}
if (canFold)
{
LUAU_ASSERT(builtinArgs.size() == offset + expr->args.size);
result = foldBuiltin(*bfid, builtinArgs.data() + offset, expr->args.size);
}
builtinArgs.resize(offset);
}
else
{
for (size_t i = 0; i < expr->args.size; ++i)
analyze(expr->args.data[i]);
}
} }
else if (AstExprIndexName* expr = node->as<AstExprIndexName>()) else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
{ {
@ -395,9 +431,9 @@ struct ConstantVisitor : AstVisitor
}; };
void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables, void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables,
DenseHashMap<AstLocal*, Constant>& locals, AstNode* root) DenseHashMap<AstLocal*, Constant>& locals, const DenseHashMap<AstExprCall*, int>* builtins, AstNode* root)
{ {
ConstantVisitor visitor{constants, variables, locals}; ConstantVisitor visitor{constants, variables, locals, builtins};
root->visit(&visitor); root->visit(&visitor);
} }

View File

@ -26,7 +26,7 @@ struct Constant
{ {
bool valueBoolean; bool valueBoolean;
double valueNumber; double valueNumber;
char* valueString = nullptr; // length stored in stringLength const char* valueString = nullptr; // length stored in stringLength
}; };
bool isTruthful() const bool isTruthful() const
@ -35,7 +35,7 @@ struct Constant
return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false); return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false);
} }
AstArray<char> getString() const AstArray<const char> getString() const
{ {
LUAU_ASSERT(type == Type_String); LUAU_ASSERT(type == Type_String);
return {valueString, stringLength}; return {valueString, stringLength};
@ -43,7 +43,7 @@ struct Constant
}; };
void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables, void foldConstants(DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables,
DenseHashMap<AstLocal*, Constant>& locals, AstNode* root); DenseHashMap<AstLocal*, Constant>& locals, const DenseHashMap<AstExprCall*, int>* builtins, AstNode* root);
} // namespace Compile } // namespace Compile
} // namespace Luau } // namespace Luau

View File

@ -6,6 +6,8 @@
#include <limits.h> #include <limits.h>
LUAU_FASTFLAGVARIABLE(LuauCompileModelBuiltins, false)
namespace Luau namespace Luau
{ {
namespace Compile namespace Compile
@ -113,11 +115,14 @@ struct Cost
struct CostVisitor : AstVisitor struct CostVisitor : AstVisitor
{ {
const DenseHashMap<AstExprCall*, int>& builtins;
DenseHashMap<AstLocal*, uint64_t> vars; DenseHashMap<AstLocal*, uint64_t> vars;
Cost result; Cost result;
CostVisitor() CostVisitor(const DenseHashMap<AstExprCall*, int>& builtins)
: vars(nullptr) : builtins(builtins)
, vars(nullptr)
{ {
} }
@ -148,14 +153,21 @@ struct CostVisitor : AstVisitor
} }
else if (AstExprCall* expr = node->as<AstExprCall>()) else if (AstExprCall* expr = node->as<AstExprCall>())
{ {
Cost cost = 3; // builtin cost modeling is different from regular calls because we use FASTCALL to compile these
cost += model(expr->func); // thus we use a cheaper baseline, don't account for function, and assume constant/local copy is free
bool builtin = FFlag::LuauCompileModelBuiltins && builtins.find(expr) != nullptr;
bool builtinShort = builtin && expr->args.size <= 2; // FASTCALL1/2
Cost cost = builtin ? 2 : 3;
if (!builtin)
cost += model(expr->func);
for (size_t i = 0; i < expr->args.size; ++i) for (size_t i = 0; i < expr->args.size; ++i)
{ {
Cost ac = model(expr->args.data[i]); Cost ac = model(expr->args.data[i]);
// for constants/locals we still need to copy them to the argument list // for constants/locals we still need to copy them to the argument list
cost += ac.model == 0 ? Cost(1) : ac; cost += ac.model == 0 && !builtinShort ? Cost(1) : ac;
} }
return cost; return cost;
@ -327,9 +339,9 @@ struct CostVisitor : AstVisitor
} }
}; };
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount) uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount, const DenseHashMap<AstExprCall*, int>& builtins)
{ {
CostVisitor visitor; CostVisitor visitor{builtins};
for (size_t i = 0; i < varCount && i < 7; ++i) for (size_t i = 0; i < varCount && i < 7; ++i)
visitor.vars[vars[i]] = 0xffull << (i * 8 + 8); visitor.vars[vars[i]] = 0xffull << (i * 8 + 8);

View File

@ -2,6 +2,7 @@
#pragma once #pragma once
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/DenseHash.h"
namespace Luau namespace Luau
{ {
@ -9,7 +10,7 @@ namespace Compile
{ {
// cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant // cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount); uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount, const DenseHashMap<AstExprCall*, int>& builtins);
// cost is computed as B - sum(Di * Ci), where B is baseline cost, Di is the discount for each variable and Ci is 1 when variable #i is constant // cost is computed as B - sum(Di * Ci), where B is baseline cost, Di is the discount for each variable and Ci is 1 when variable #i is constant
int computeCost(uint64_t model, const bool* varsConst, size_t varCount); int computeCost(uint64_t model, const bool* varsConst, size_t varCount);

View File

@ -50,6 +50,9 @@ TESTS_ARGS=
ifneq ($(flags),) ifneq ($(flags),)
TESTS_ARGS+=--fflags=$(flags) TESTS_ARGS+=--fflags=$(flags)
endif endif
ifneq ($(opt),)
TESTS_ARGS+=-O$(opt)
endif
OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS) OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS)
@ -104,7 +107,7 @@ $(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnaly
$(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include $(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include
$(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include
$(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include $(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include
$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern $(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY
$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -Iextern -Iextern/isocline/include $(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -Iextern -Iextern/isocline/include
$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -Iextern $(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -Iextern
$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IVM/include $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IVM/include

View File

@ -38,12 +38,14 @@ target_sources(Luau.Compiler PRIVATE
Compiler/src/BytecodeBuilder.cpp Compiler/src/BytecodeBuilder.cpp
Compiler/src/Compiler.cpp Compiler/src/Compiler.cpp
Compiler/src/Builtins.cpp Compiler/src/Builtins.cpp
Compiler/src/BuiltinFolding.cpp
Compiler/src/ConstantFolding.cpp Compiler/src/ConstantFolding.cpp
Compiler/src/CostModel.cpp Compiler/src/CostModel.cpp
Compiler/src/TableShape.cpp Compiler/src/TableShape.cpp
Compiler/src/ValueTracking.cpp Compiler/src/ValueTracking.cpp
Compiler/src/lcode.cpp Compiler/src/lcode.cpp
Compiler/src/Builtins.h Compiler/src/Builtins.h
Compiler/src/BuiltinFolding.h
Compiler/src/ConstantFolding.h Compiler/src/ConstantFolding.h
Compiler/src/CostModel.h Compiler/src/CostModel.h
Compiler/src/TableShape.h Compiler/src/TableShape.h

View File

@ -373,7 +373,7 @@ LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n);
LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n);
LUA_API void lua_singlestep(lua_State* L, int enabled); LUA_API void lua_singlestep(lua_State* L, int enabled);
LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled); LUA_API int lua_breakpoint(lua_State* L, int funcindex, int line, int enabled);
typedef void (*lua_Coverage)(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size); typedef void (*lua_Coverage)(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size);

View File

@ -34,6 +34,8 @@
* therefore call luaC_checkGC before luaC_checkthreadsleep to guarantee the object is pushed to an awake thread. * therefore call luaC_checkGC before luaC_checkthreadsleep to guarantee the object is pushed to an awake thread.
*/ */
LUAU_FASTFLAG(LuauLazyAtoms)
const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n"
"$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n"
"$URL: www.lua.org $\n"; "$URL: www.lua.org $\n";
@ -51,6 +53,13 @@ const char* luau_ident = "$Luau: Copyright (C) 2019-2022 Roblox Corporation $\n"
L->top++; \ L->top++; \
} }
#define updateatom(L, ts) \
if (FFlag::LuauLazyAtoms) \
{ \
if (ts->atom == ATOM_UNDEF) \
ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1; \
}
static Table* getcurrenv(lua_State* L) static Table* getcurrenv(lua_State* L)
{ {
if (L->ci == L->base_ci) /* no enclosing function? */ if (L->ci == L->base_ci) /* no enclosing function? */
@ -441,19 +450,25 @@ const char* lua_tostringatom(lua_State* L, int idx, int* atom)
StkId o = index2addr(L, idx); StkId o = index2addr(L, idx);
if (!ttisstring(o)) if (!ttisstring(o))
return NULL; return NULL;
const TString* s = tsvalue(o); TString* s = tsvalue(o);
if (atom) if (atom)
{
updateatom(L, s);
*atom = s->atom; *atom = s->atom;
}
return getstr(s); return getstr(s);
} }
const char* lua_namecallatom(lua_State* L, int* atom) const char* lua_namecallatom(lua_State* L, int* atom)
{ {
const TString* s = L->namecall; TString* s = L->namecall;
if (!s) if (!s)
return NULL; return NULL;
if (atom) if (atom)
{
updateatom(L, s);
*atom = s->atom; *atom = s->atom;
}
return getstr(s); return getstr(s);
} }

View File

@ -12,6 +12,8 @@
#include <string.h> #include <string.h>
#include <stdio.h> #include <stdio.h>
LUAU_FASTFLAGVARIABLE(LuauDebuggerBreakpointHitOnNextBestLine, false);
static const char* getfuncname(Closure* f); static const char* getfuncname(Closure* f);
static int currentpc(lua_State* L, CallInfo* ci) static int currentpc(lua_State* L, CallInfo* ci)
@ -367,14 +369,6 @@ void lua_singlestep(lua_State* L, int enabled)
L->singlestep = bool(enabled); L->singlestep = bool(enabled);
} }
void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled)
{
const TValue* func = luaA_toobject(L, funcindex);
api_check(L, ttisfunction(func) && !clvalue(func)->isC);
luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled));
}
static int getmaxline(Proto* p) static int getmaxline(Proto* p)
{ {
int result = -1; int result = -1;
@ -394,6 +388,71 @@ static int getmaxline(Proto* p)
return result; return result;
} }
// Find the line number with instructions. If the provided line doesn't have any instruction, it should return the next line number with
// instructions.
static int getnextline(Proto* p, int line)
{
int closest = -1;
if (p->lineinfo)
{
for (int i = 0; i < p->sizecode; ++i)
{
// note: we keep prologue as is, instead opting to break at the first meaningful instruction
if (LUAU_INSN_OP(p->code[i]) == LOP_PREPVARARGS)
continue;
int current = luaG_getline(p, i);
if (current >= line)
{
closest = current;
break;
}
}
}
for (int i = 0; i < p->sizep; ++i)
{
// Find the closest line number to the intended one.
int candidate = getnextline(p->p[i], line);
if (closest == -1 || (candidate >= line && candidate < closest))
{
closest = candidate;
}
}
return closest;
}
int lua_breakpoint(lua_State* L, int funcindex, int line, int enabled)
{
int target = -1;
if (FFlag::LuauDebuggerBreakpointHitOnNextBestLine)
{
const TValue* func = luaA_toobject(L, funcindex);
api_check(L, ttisfunction(func) && !clvalue(func)->isC);
Proto* p = clvalue(func)->l.p;
// Find line number to add the breakpoint to.
target = getnextline(p, line);
if (target != -1)
{
// Add breakpoint on the exact line
luaG_breakpoint(L, p, target, bool(enabled));
}
}
else
{
const TValue* func = luaA_toobject(L, funcindex);
api_check(L, ttisfunction(func) && !clvalue(func)->isC);
luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled));
}
return target;
}
static void getcoverage(Proto* p, int depth, int* buffer, size_t size, void* context, lua_Coverage callback) static void getcoverage(Proto* p, int depth, int* buffer, size_t size, void* context, lua_Coverage callback)
{ {
memset(buffer, -1, size * sizeof(int)); memset(buffer, -1, size * sizeof(int));

View File

@ -7,6 +7,8 @@
#include <string.h> #include <string.h>
LUAU_FASTFLAGVARIABLE(LuauLazyAtoms, false)
unsigned int luaS_hash(const char* str, size_t len) unsigned int luaS_hash(const char* str, size_t len)
{ {
// Note that this hashing algorithm is replicated in BytecodeBuilder.cpp, BytecodeBuilder::getStringHash // Note that this hashing algorithm is replicated in BytecodeBuilder.cpp, BytecodeBuilder::getStringHash
@ -82,7 +84,7 @@ static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h)
ts->memcat = L->activememcat; ts->memcat = L->activememcat;
memcpy(ts->data, str, l); memcpy(ts->data, str, l);
ts->data[l] = '\0'; /* ending 0 */ ts->data[l] = '\0'; /* ending 0 */
ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, l) : -1; ts->atom = FFlag::LuauLazyAtoms ? ATOM_UNDEF : L->global->cb.useratom ? L->global->cb.useratom(ts->data, l) : -1;
tb = &L->global->strt; tb = &L->global->strt;
h = lmod(h, tb->size); h = lmod(h, tb->size);
ts->next = tb->hash[h]; /* chain new entry */ ts->next = tb->hash[h]; /* chain new entry */
@ -165,7 +167,7 @@ TString* luaS_buffinish(lua_State* L, TString* ts)
ts->data[ts->len] = '\0'; // ending 0 ts->data[ts->len] = '\0'; // ending 0
// Complete string object // Complete string object
ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1; ts->atom = FFlag::LuauLazyAtoms ? ATOM_UNDEF : L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1;
ts->next = tb->hash[bucket]; // chain new entry ts->next = tb->hash[bucket]; // chain new entry
tb->hash[bucket] = ts; tb->hash[bucket] = ts;

View File

@ -8,6 +8,9 @@
/* string size limit */ /* string size limit */
#define MAXSSIZE (1 << 30) #define MAXSSIZE (1 << 30)
/* string atoms are not defined by default; the storage is 16-bit integer */
#define ATOM_UNDEF -32768
#define sizestring(len) (offsetof(TString, data) + len + 1) #define sizestring(len) (offsetof(TString, data) + len + 1)
#define luaS_new(L, s) (luaS_newlstr(L, s, strlen(s))) #define luaS_new(L, s) (luaS_newlstr(L, s, strlen(s)))

View File

@ -1,436 +0,0 @@
--[[
* AES Cipher function: encrypt 'input' with Rijndael algorithm
*
* takes byte-array 'input' (16 bytes)
* 2D byte-array key schedule 'w' (Nr+1 x Nb bytes)
*
* applies Nr rounds (10/12/14) using key schedule w for 'add round key' stage
*
* returns byte-array encrypted value (16 bytes)
*/]]
local bench = script and require(script.Parent.bench_support) or require("bench_support")
function test()
-- Sbox is pre-computed multiplicative inverse in GF(2^8) used in SubBytes and KeyExpansion [§5.1.1]
local Sbox = { 0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76,
0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0,
0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15,
0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75,
0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84,
0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf,
0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8,
0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2,
0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73,
0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb,
0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79,
0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08,
0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a,
0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e,
0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf,
0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16 };
-- Rcon is Round Constant used for the Key Expansion [1st col is 2^(r-1) in GF(2^8)] [§5.2]
local Rcon = { { 0x00, 0x00, 0x00, 0x00 },
{0x01, 0x00, 0x00, 0x00},
{0x02, 0x00, 0x00, 0x00},
{0x04, 0x00, 0x00, 0x00},
{0x08, 0x00, 0x00, 0x00},
{0x10, 0x00, 0x00, 0x00},
{0x20, 0x00, 0x00, 0x00},
{0x40, 0x00, 0x00, 0x00},
{0x80, 0x00, 0x00, 0x00},
{0x1b, 0x00, 0x00, 0x00},
{0x36, 0x00, 0x00, 0x00} };
function Cipher(input, w) -- main Cipher function [§5.1]
local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES)
local Nr = #w / Nb - 1; -- no of rounds: 10/12/14 for 128/192/256-bit keys
local state = {{},{},{},{}}; -- initialise 4xNb byte-array 'state' with input [§3.4]
for i = 0,4*Nb-1 do state[(i % 4) + 1][math.floor(i/4) + 1] = input[i + 1]; end
state = AddRoundKey(state, w, 0, Nb);
for round = 1,Nr-1 do
state = SubBytes(state, Nb);
state = ShiftRows(state, Nb);
state = MixColumns(state, Nb);
state = AddRoundKey(state, w, round, Nb);
end
state = SubBytes(state, Nb);
state = ShiftRows(state, Nb);
state = AddRoundKey(state, w, Nr, Nb);
local output = {} -- convert state to 1-d array before returning [§3.4]
for i = 0,4*Nb-1 do output[i + 1] = state[(i % 4) + 1][math.floor(i / 4) + 1]; end
return output;
end
function SubBytes(s, Nb) -- apply SBox to state S [§5.1.1]
for r = 0,3 do
for c = 0,Nb-1 do s[r + 1][c + 1] = Sbox[s[r + 1][c + 1] + 1]; end
end
return s;
end
function ShiftRows(s, Nb) -- shift row r of state S left by r bytes [§5.1.2]
local t = {};
for r = 1,3 do
for c = 0,3 do t[c + 1] = s[r + 1][((c + r) % Nb) + 1] end; -- shift into temp copy
for c = 0,3 do s[r + 1][c + 1] = t[c + 1]; end -- and copy back
end -- note that this will work for Nb=4,5,6, but not 7,8 (always 4 for AES):
return s; -- see fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf
end
function MixColumns(s, Nb) -- combine bytes of each col of state S [§5.1.3]
for c = 0,3 do
local a = {}; -- 'a' is a copy of the current column from 's'
local b = {}; -- 'b' is a•{02} in GF(2^8)
for i = 0,3 do
a[i + 1] = s[i + 1][c + 1];
if bit32.band(s[i + 1][c + 1], 0x80) ~= 0 then
b[i + 1] = bit32.bxor(bit32.lshift(s[i + 1][c + 1], 1), 0x011b);
else
b[i + 1] = bit32.lshift(s[i + 1][c + 1], 1);
end
end
-- a[n] ^ b[n] is a•{03} in GF(2^8)
s[1][c + 1] = bit32.bxor(bit32.bxor(bit32.bxor(b[1], a[2]), bit32.bxor(b[2], a[3])), a[4]); -- 2*a0 + 3*a1 + a2 + a3
s[2][c + 1] = bit32.bxor(bit32.bxor(bit32.bxor(a[1], b[2]), bit32.bxor(a[3], b[3])), a[4]); -- a0 * 2*a1 + 3*a2 + a3
s[3][c + 1] = bit32.bxor(bit32.bxor(bit32.bxor(a[1], a[2]), bit32.bxor(b[3], a[4])), b[4]); -- a0 + a1 + 2*a2 + 3*a3
s[4][c + 1] = bit32.bxor(bit32.bxor(bit32.bxor(a[1], b[1]), bit32.bxor(a[2], a[3])), b[4]); -- 3*a0 + a1 + a2 + 2*a3
end
return s;
end
function AddRoundKey(state, w, rnd, Nb) -- xor Round Key into state S [§5.1.4]
for r = 0,3 do
for c = 0,Nb-1 do state[r + 1][c + 1] = bit32.bxor(state[r + 1][c + 1], w[rnd*4+c + 1][r + 1]); end
end
return state;
end
function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from Key [§5.2]
local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES)
local Nk = #key / 4 -- key length (in words): 4/6/8 for 128/192/256-bit keys
local Nr = Nk + 6; -- no of rounds: 10/12/14 for 128/192/256-bit keys
local w = {};
local temp = {};
for i = 0,Nk do
local r = { key[4*i + 1], key[4*i + 2], key[4*i + 3], key[4*i + 4] };
w[i + 1] = r;
end
for i = Nk,(Nb*(Nr+1)) - 1 do
w[i + 1] = {};
for t = 0,3 do temp[t + 1] = w[i-1 + 1][t + 1]; end
if (i % Nk == 0) then
temp = SubWord(RotWord(temp));
for t = 0,3 do temp[t + 1] = bit32.bxor(temp[t + 1], Rcon[i/Nk + 1][t + 1]); end
elseif (Nk > 6 and i % Nk == 4) then
temp = SubWord(temp);
end
for t = 0,3 do w[i + 1][t + 1] = bit32.bxor(w[i - Nk + 1][t + 1], temp[t + 1]); end
end
return w;
end
function SubWord(w) -- apply SBox to 4-byte word w
for i = 0,3 do w[i + 1] = Sbox[w[i + 1] + 1]; end
return w;
end
function RotWord(w) -- rotate 4-byte word w left by one byte
w[5] = w[1];
for i = 0,3 do w[i + 1] = w[i + 2]; end
return w;
end
--[[
* Use AES to encrypt 'plaintext' with 'password' using 'nBits' key, in 'Counter' mode of operation
* - see http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf
* for each block
* - outputblock = cipher(counter, key)
* - cipherblock = plaintext xor outputblock
]]
function AESEncryptCtr(plaintext, password, nBits)
if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys
-- for this example script, generate the key by applying Cipher to 1st 16/24/32 chars of password;
-- for real-world applications, a higher security approach would be to hash the password e.g. with SHA-1
local nBytes = nBits/8; -- no bytes in key
local pwBytes = {};
for i = 0,nBytes-1 do pwBytes[i + 1] = bit32.band(string.byte(password, i + 1), 0xff); end
local key = Cipher(pwBytes, KeyExpansion(pwBytes));
-- key is now 16/24/32 bytes long
for i = 1,nBytes-16 do
table.insert(key, key[i])
end
-- initialise counter block (NIST SP800-38A §B.2): millisecond time-stamp for nonce in 1st 8 bytes,
-- block counter in 2nd 8 bytes
local blockSize = 16; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES
local counterBlock = {}; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES
local nonce = 12564231564 -- (new Date()).getTime(); -- milliseconds since 1-Jan-1970
-- encode nonce in two stages to cater for JavaScript 32-bit limit on bitwise ops
for i = 0,3 do counterBlock[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff); end
for i = 0,3 do counterBlock[i + 4 + 1] = bit32.band(bit32.rshift(math.floor(nonce / 0x100000000), i*8), 0xff); end
-- generate key schedule - an expansion of the key into distinct Key Rounds for each round
local keySchedule = KeyExpansion(key);
local blockCount = math.ceil(#plaintext / blockSize);
local ciphertext = {}; -- ciphertext as array of strings
for b = 0,blockCount-1 do
-- set counter (block #) in last 8 bytes of counter block (leaving nonce in 1st 8 bytes)
-- again done in two stages for 32-bit ops
for c = 0,3 do counterBlock[15-c + 1] = bit32.band(bit32.rshift(b, c*8), 0xff); end
for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.rshift(math.floor(b/0x100000000), c*8) end
local cipherCntr = Cipher(counterBlock, keySchedule); -- -- encrypt counter block --
-- calculate length of final block:
local blockLength = nil
if b<blockCount-1 then
blockLength = blockSize;
else
blockLength = (#plaintext - 1) % blockSize+1;
end
local ct = '';
for i = 0,blockLength-1 do -- -- xor plaintext with ciphered counter byte-by-byte --
local plaintextByte = string.byte(plaintext, b*blockSize+i + 1);
local cipherByte = bit32.bxor(plaintextByte, cipherCntr[i + 1]);
ct = ct .. string.char(cipherByte);
end
-- ct is now ciphertext for this block
ciphertext[b + 1] = escCtrlChars(ct); -- escape troublesome characters in ciphertext
end
-- convert the nonce to a string to go on the front of the ciphertext
local ctrTxt = '';
for i = 0,7 do ctrTxt = ctrTxt .. string.char(counterBlock[i + 1]); end
ctrTxt = escCtrlChars(ctrTxt);
-- use '-' to separate blocks, use Array.join to concatenate arrays of strings for efficiency
return ctrTxt .. '-' .. table.concat(ciphertext, '-');
end
--[[
* Use AES to decrypt 'ciphertext' with 'password' using 'nBits' key, in Counter mode of operation
*
* for each block
* - outputblock = cipher(counter, key)
* - cipherblock = plaintext xor outputblock
]]
function AESDecryptCtr(ciphertext, password, nBits)
if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys
local nBytes = nBits/8; -- no bytes in key
local pwBytes = {};
for i = 0,nBytes-1 do pwBytes[i + 1] = bit32.band(string.byte(password, i + 1), 0xff); end
local pwKeySchedule = KeyExpansion(pwBytes);
local key = Cipher(pwBytes, pwKeySchedule);
-- key is now 16/24/32 bytes long
for i = 1,nBytes-16 do
table.insert(key, key[i])
end
local keySchedule = KeyExpansion(key);
-- split ciphertext into array of block-length strings
local tmp = {}
for token in string.gmatch(ciphertext, "[^-]+") do
table.insert(tmp, token)
end
ciphertext = tmp;
-- recover nonce from 1st element of ciphertext
local blockSize = 16; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES
local counterBlock = {};
local ctrTxt = unescCtrlChars(ciphertext[1]);
for i = 0,7 do counterBlock[i + 1] = string.byte(ctrTxt, i + 1); end
local plaintext = {};
for b = 1,#ciphertext-1 do
-- set counter (block #) in last 8 bytes of counter block (leaving nonce in 1st 8 bytes)
for c = 0,3 do counterBlock[15-c + 1] = bit32.band(bit32.rshift((b-1), c*8), 0xff); end
for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.band(bit32.rshift(math.floor((b-1)/0x100000000), c*8), 0xff); end
local cipherCntr = Cipher(counterBlock, keySchedule); -- encrypt counter block
ciphertext[b + 1] = unescCtrlChars(ciphertext[b + 1]);
local pt = '';
for i = 0,#ciphertext[b + 1]-1 do
-- -- xor plaintext with ciphered counter byte-by-byte --
local ciphertextByte = string.byte(ciphertext[b + 1], i + 1);
local plaintextByte = bit32.bxor(ciphertextByte, cipherCntr[i + 1]);
pt = pt .. string.char(plaintextByte);
end
-- pt is now plaintext for this block
plaintext[b] = pt; -- b-1 'cos no initial nonce block in plaintext
end
return table.concat(plaintext)
end
function escCtrlChars(str) -- escape control chars which might cause problems handling ciphertext
return string.gsub(str, "[\0\t\n\v\f\r\'\"!-]", function(c) return '!' .. string.byte(c, 1) .. '!'; end);
end
function unescCtrlChars(str) -- unescape potentially problematic control characters
return string.gsub(str, "!%d%d?%d?!", function(c)
local sc = string.sub(c, 2,-2)
return string.char(tonumber(sc));
end);
end
--[[
* if escCtrlChars()/unescCtrlChars() still gives problems, use encodeBase64()/decodeBase64() instead
]]
local b64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";
function encodeBase64(str) -- http://tools.ietf.org/html/rfc4648
local o1, o2, o3, h1, h2, h3, h4, bits
local i=0
local enc='';
str = encodeUTF8(str); -- encode multi-byte chars into UTF-8 for byte-array
repeat -- pack three octets into four hexets
o1 = string.byte(str, i + 1); i = i + 1
o2 = string.byte(str, i + 1); i = i + 1
o3 = string.byte(str, i + 1); i = i + 1
bits = bit32.bor(bit32.bor(bit32.lshift(o1, 16), bit32.lshift(o2, 8)), o3);
h1 = bit32.band(bit32.rshit(bits, 18), 0x3f);
h2 = bit32.band(bit32.rshit(bits, 12), 0x3f);
h3 = bit32.band(bit32.rshit(bits, 6), 0x3f);
h4 = bit32.band(bits, 0x3f);
-- end of string? index to '=' in b64
if (isNaN(o3)) then h4 = 64; end
if (isNaN(o2)) then h3 = 64; end
-- use hexets to index into b64, and append result to encoded string
enc = enc .. (b64.charAt(h1) + b64.charAt(h2) + b64.charAt(h3) + b64.charAt(h4));
until not (i < #str);
return enc;
end
function decodeBase64(str)
local o1, o2, o3, h1, h2, h3, h4, bits
local i=0
local enc='';
repeat -- unpack four hexets into three octets using index points in b64
h1 = b64.indexOf(str.charAt(i)); i = i + 1
h2 = b64.indexOf(str.charAt(i)); i = i + 1
h3 = b64.indexOf(str.charAt(i)); i = i + 1
h4 = b64.indexOf(str.charAt(i)); i = i + 1
bits = bit32.bor(bit32.bor(bit32.bor(bit32.lshift(h1, 18), bit32.lshift(h2, 12)), bit32.lshift(h3, 6)), h4);
o1 = bit32.band(bit32.rshift(bits, 16), 0xff);
o2 = bit32.band(bit32.rshift(bits, 8), 0xff);
o3 = bit32.band(bits, 0xff);
if (h3 == 64) then enc = enc .. string.char(o1);
elseif (h4 == 64) then enc = enc .. string.char(o1, o2);
else enc = enc .. string.char(o1, o2, o3); end
until not (i < #str);
return decodeUTF8(enc); -- decode UTF-8 byte-array back to Unicode
end
function encodeUTF8(str) -- encode multi-byte string into utf-8 multiple single-byte characters
return str;
end
function decodeUTF8(str) -- decode utf-8 encoded string back into multi-byte characters
return str;
end
local plainText = "ROMEO: But, soft! what light through yonder window breaks?\n\
It is the east, and Juliet is the sun.\n\
Arise, fair sun, and kill the envious moon,\n\
Who is already sick and pale with grief,\n\
That thou her maid art far more fair than she:\n\
Be not her maid, since she is envious;\n\
Her vestal livery is but sick and green\n\
And none but fools do wear it; cast it off.\n\
It is my lady, O, it is my love!\n\
O, that she knew she were!\n\
She speaks yet she says nothing: what of that?\n\
Her eye discourses; I will answer it.\n\
I am too bold, 'tis not to me she speaks:\n\
Two of the fairest stars in all the heaven,\n\
Having some business, do entreat her eyes\n\
To twinkle in their spheres till they return.\n\
What if her eyes were there, they in her head?\n\
The brightness of her cheek would shame those stars,\n\
As daylight doth a lamp; her eyes in heaven\n\
Would through the airy region stream so bright\n\
That birds would sing and think it were not night.\n\
See, how she leans her cheek upon her hand!\n\
O, that I were a glove upon that hand,\n\
That I might touch that cheek!\n\
JULIET: Ay me!\n\
ROMEO: She speaks:\n\
O, speak again, bright angel! for thou art\n\
As glorious to this night, being o'er my head\n\
As is a winged messenger of heaven\n\
Unto the white-upturned wondering eyes\n\
Of mortals that fall back to gaze on him\n\
When he bestrides the lazy-pacing clouds\n\
And sails upon the bosom of the air.";
local password = "O Romeo, Romeo! wherefore art thou Romeo?";
local t = ""
for i = 0,10000 do
t = t.."a"
end
local cipherText = AESEncryptCtr(plainText, password, 256);
local decryptedText = AESDecryptCtr(cipherText, password, 256);
if (decryptedText ~= plainText) then
assert(false, "ERROR: bad result: expected " .. plainText .. " but got " .. decryptedText);
end
end
bench.runCode(test, "crypto-aes")

View File

@ -185,7 +185,7 @@ local function AESEncryptCtr(plaintext, password, nBits)
-- for real-world applications, a higher security approach would be to hash the password e.g. with SHA-1 -- for real-world applications, a higher security approach would be to hash the password e.g. with SHA-1
local nBytes = nBits/8; -- no bytes in key local nBytes = nBits/8; -- no bytes in key
local pwBytes = {}; local pwBytes = {};
for i = 0,nBytes-1 do pwBytes[i + 1] = bit32.band(string.byte(password, i + 1), 0xff); end for i = 0,nBytes-1 do pwBytes[i + 1] = string.byte(password, i + 1); end
local key = Cipher(pwBytes, KeyExpansion(pwBytes)); local key = Cipher(pwBytes, KeyExpansion(pwBytes));
-- key is now 16/24/32 bytes long -- key is now 16/24/32 bytes long
@ -197,11 +197,11 @@ local function AESEncryptCtr(plaintext, password, nBits)
-- block counter in 2nd 8 bytes -- block counter in 2nd 8 bytes
local blockSize = 16; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES local blockSize = 16; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES
local counterBlock = {}; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES local counterBlock = {}; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES
local nonce = 12564231564 -- (new Date()).getTime(); -- milliseconds since 1-Jan-1970 local nonce = os.clock() * 1000 -- (new Date()).getTime(); -- milliseconds since 1-Jan-1970
-- encode nonce in two stages to cater for JavaScript 32-bit limit on bitwise ops -- encode nonce in two stages to cater for JavaScript 32-bit limit on bitwise ops
for i = 0,3 do counterBlock[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff); end for i = 0,3 do counterBlock[i + 1] = bit32.extract(nonce, i * 8, 8); end
for i = 0,3 do counterBlock[i + 4 + 1] = bit32.band(bit32.rshift(math.floor(nonce / 0x100000000), i*8), 0xff); end for i = 0,3 do counterBlock[i + 4 + 1] = bit32.extract(math.floor(nonce / 0x100000000), i*8, 8); end
-- generate key schedule - an expansion of the key into distinct Key Rounds for each round -- generate key schedule - an expansion of the key into distinct Key Rounds for each round
local keySchedule = KeyExpansion(key); local keySchedule = KeyExpansion(key);
@ -212,8 +212,8 @@ local function AESEncryptCtr(plaintext, password, nBits)
for b = 0,blockCount-1 do for b = 0,blockCount-1 do
-- set counter (block #) in last 8 bytes of counter block (leaving nonce in 1st 8 bytes) -- set counter (block #) in last 8 bytes of counter block (leaving nonce in 1st 8 bytes)
-- again done in two stages for 32-bit ops -- again done in two stages for 32-bit ops
for c = 0,3 do counterBlock[15-c + 1] = bit32.band(bit32.rshift(b, c*8), 0xff); end for c = 0,3 do counterBlock[15-c + 1] = bit32.extract(b, c*8, 8); end
for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.rshift(math.floor(b/0x100000000), c*8) end for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.extract(math.floor(b/0x100000000), c*8, 8); end
local cipherCntr = Cipher(counterBlock, keySchedule); -- -- encrypt counter block -- local cipherCntr = Cipher(counterBlock, keySchedule); -- -- encrypt counter block --
@ -260,7 +260,7 @@ local function AESDecryptCtr(ciphertext, password, nBits)
local nBytes = nBits/8; -- no bytes in key local nBytes = nBits/8; -- no bytes in key
local pwBytes = {}; local pwBytes = {};
for i = 0,nBytes-1 do pwBytes[i + 1] = bit32.band(string.byte(password, i + 1), 0xff); end for i = 0,nBytes-1 do pwBytes[i + 1] = string.byte(password, i + 1); end
local pwKeySchedule = KeyExpansion(pwBytes); local pwKeySchedule = KeyExpansion(pwBytes);
local key = Cipher(pwBytes, pwKeySchedule); local key = Cipher(pwBytes, pwKeySchedule);
@ -290,8 +290,8 @@ local function AESDecryptCtr(ciphertext, password, nBits)
for b = 1,#ciphertext-1 do for b = 1,#ciphertext-1 do
-- set counter (block #) in last 8 bytes of counter block (leaving nonce in 1st 8 bytes) -- set counter (block #) in last 8 bytes of counter block (leaving nonce in 1st 8 bytes)
for c = 0,3 do counterBlock[15-c + 1] = bit32.band(bit32.rshift((b-1), c*8), 0xff); end for c = 0,3 do counterBlock[15-c + 1] = bit32.extract(b-1, c*8, 8); end
for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.band(bit32.rshift(math.floor((b-1)/0x100000000), c*8), 0xff); end for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.extract(math.floor((b-1)/0x100000000), c*8, 8); end
local cipherCntr = Cipher(counterBlock, keySchedule); -- encrypt counter block local cipherCntr = Cipher(counterBlock, keySchedule); -- encrypt counter block

3089
extern/doctest.h vendored

File diff suppressed because it is too large Load Diff

View File

@ -2240,43 +2240,18 @@ local a: aaa.do
CHECK(ac.entryMap.count("other")); CHECK(ac.entryMap.count("other"));
} }
TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteSource")
TEST_CASE_FIXTURE(ACFixture, "comments")
{ {
std::string_view source = R"( fileResolver.source["Comments"] = "--!str";
local a = table. -- Line 1
-- | Column 23
)";
auto ac = autocompleteSource(frontend, source, Position{1, 24}, nullCallback).result; auto ac = Luau::autocomplete(frontend, "Comments", Position{0, 6}, nullCallback);
CHECK_EQ(17, ac.entryMap.size());
CHECK(ac.entryMap.count("find"));
CHECK(ac.entryMap.count("pack"));
CHECK(!ac.entryMap.count("math"));
}
TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_require")
{
std::string_view source = R"(
local a = require(w -- Line 1
-- | Column 27
)";
// CLI-43699 require shouldn't crash inside autocompleteSource
auto ac = autocompleteSource(frontend, source, Position{1, 27}, nullCallback).result;
}
TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_comments")
{
std::string_view source = "--!str";
auto ac = autocompleteSource(frontend, source, Position{0, 6}, nullCallback).result;
CHECK_EQ(0, ac.entryMap.size()); CHECK_EQ(0, ac.entryMap.size());
} }
TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteProp_index_function_metamethod_is_variadic") TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteProp_index_function_metamethod_is_variadic")
{ {
std::string_view source = R"( fileResolver.source["Module/A"] = R"(
type Foo = {x: number} type Foo = {x: number}
local t = {} local t = {}
setmetatable(t, { setmetatable(t, {
@ -2289,7 +2264,7 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocompleteProp_index_function_metamethod
-- | Column 20 -- | Column 20
)"; )";
auto ac = autocompleteSource(frontend, source, Position{9, 20}, nullCallback).result; auto ac = Luau::autocomplete(frontend, "Module/A", Position{9, 20}, nullCallback);
REQUIRE_EQ(1, ac.entryMap.size()); REQUIRE_EQ(1, ac.entryMap.size());
CHECK(ac.entryMap.count("x")); CHECK(ac.entryMap.count("x"));
} }
@ -2378,35 +2353,36 @@ end
CHECK(ac.entryMap.count("elsewhere")); CHECK(ac.entryMap.count("elsewhere"));
} }
TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_not_the_var_we_are_defining") TEST_CASE_FIXTURE(ACFixture, "not_the_var_we_are_defining")
{ {
std::string_view source = "abc,de"; fileResolver.source["Module/A"] = "abc,de";
auto ac = autocompleteSource(frontend, source, Position{0, 6}, nullCallback).result; auto ac = Luau::autocomplete(frontend, "Module/A", Position{0, 6}, nullCallback);
CHECK(!ac.entryMap.count("de")); CHECK(!ac.entryMap.count("de"));
} }
TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_recursive_function") TEST_CASE_FIXTURE(ACFixture, "recursive_function_global")
{ {
{ fileResolver.source["global"] = R"(function abc()
std::string_view global = R"(function abc()
end end
)"; )";
auto ac = autocompleteSource(frontend, global, Position{1, 0}, nullCallback).result; auto ac = Luau::autocomplete(frontend, "global", Position{1, 0}, nullCallback);
CHECK(ac.entryMap.count("abc")); CHECK(ac.entryMap.count("abc"));
} }
{
std::string_view local = R"(local function abc()
TEST_CASE_FIXTURE(ACFixture, "recursive_function_local")
{
fileResolver.source["local"] = R"(local function abc()
end end
)"; )";
auto ac = autocompleteSource(frontend, local, Position{1, 0}, nullCallback).result; auto ac = Luau::autocomplete(frontend, "local", Position{1, 0}, nullCallback);
CHECK(ac.entryMap.count("abc")); CHECK(ac.entryMap.count("abc"));
}
} }
TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys") TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys")

View File

@ -165,8 +165,8 @@ LOADN R1 1
FASTCALL2K 18 R1 K0 L0 FASTCALL2K 18 R1 K0 L0
LOADK R2 K0 LOADK R2 K0
GETIMPORT R0 3 GETIMPORT R0 3
L0: CALL R0 2 -1 CALL R0 2 -1
RETURN R0 -1 L0: RETURN R0 -1
)"); )");
} }
@ -2100,12 +2100,12 @@ FASTCALL2 18 R0 R1 L0
MOVE R5 R0 MOVE R5 R0
MOVE R6 R1 MOVE R6 R1
GETIMPORT R4 2 GETIMPORT R4 2
L0: CALL R4 2 1 CALL R4 2 1
FASTCALL2 19 R4 R2 L1 L0: FASTCALL2 19 R4 R2 L1
MOVE R5 R2 MOVE R5 R2
GETIMPORT R3 4 GETIMPORT R3 4
L1: CALL R3 2 -1 CALL R3 2 -1
RETURN R3 -1 L1: RETURN R3 -1
)"); )");
} }
@ -2511,8 +2511,8 @@ return
5: MOVE R3 R0 5: MOVE R3 R0
5: MOVE R4 R1 5: MOVE R4 R1
5: GETIMPORT R2 2 5: GETIMPORT R2 2
5: L0: CALL R2 2 -1 5: CALL R2 2 -1
5: RETURN R2 -1 5: L0: RETURN R2 -1
)"); )");
} }
@ -2828,8 +2828,8 @@ TEST_CASE("FastcallBytecode")
LOADN R1 -5 LOADN R1 -5
FASTCALL1 2 R1 L0 FASTCALL1 2 R1 L0
GETIMPORT R0 2 GETIMPORT R0 2
L0: CALL R0 1 -1 CALL R0 1 -1
RETURN R0 -1 L0: RETURN R0 -1
)"); )");
// call through a local variable // call through a local variable
@ -2838,8 +2838,8 @@ GETIMPORT R0 2
LOADN R2 -5 LOADN R2 -5
FASTCALL1 2 R2 L0 FASTCALL1 2 R2 L0
MOVE R1 R0 MOVE R1 R0
L0: CALL R1 1 -1 CALL R1 1 -1
RETURN R1 -1 L0: RETURN R1 -1
)"); )");
// call through an upvalue // call through an upvalue
@ -2847,8 +2847,8 @@ RETURN R1 -1
LOADN R1 -5 LOADN R1 -5
FASTCALL1 2 R1 L0 FASTCALL1 2 R1 L0
GETUPVAL R0 0 GETUPVAL R0 0
L0: CALL R0 1 -1 CALL R0 1 -1
RETURN R0 -1 L0: RETURN R0 -1
)"); )");
// mutating the global in the script breaks the optimization // mutating the global in the script breaks the optimization
@ -2893,8 +2893,8 @@ LOADK R1 K0
FASTCALL1 57 R1 L0 FASTCALL1 57 R1 L0
GETIMPORT R0 2 GETIMPORT R0 2
GETVARARGS R2 -1 GETVARARGS R2 -1
L0: CALL R0 -1 1 CALL R0 -1 1
RETURN R0 1 L0: RETURN R0 1
)"); )");
// more complex example: select inside a for loop bound + select from a iterator // more complex example: select inside a for loop bound + select from a iterator
@ -2912,16 +2912,16 @@ LOADK R5 K0
FASTCALL1 57 R5 L0 FASTCALL1 57 R5 L0
GETIMPORT R4 2 GETIMPORT R4 2
GETVARARGS R6 -1 GETVARARGS R6 -1
L0: CALL R4 -1 1 CALL R4 -1 1
MOVE R1 R4 L0: MOVE R1 R4
LOADN R2 1 LOADN R2 1
FORNPREP R1 L3 FORNPREP R1 L3
L1: FASTCALL1 57 R3 L2 L1: FASTCALL1 57 R3 L2
GETIMPORT R4 2 GETIMPORT R4 2
MOVE R5 R3 MOVE R5 R3
GETVARARGS R6 -1 GETVARARGS R6 -1
L2: CALL R4 -1 1 CALL R4 -1 1
ADD R0 R0 R4 L2: ADD R0 R0 R4
FORNLOOP R1 L1 FORNLOOP R1 L1
L3: RETURN R0 1 L3: RETURN R0 1
)"); )");
@ -3242,7 +3242,7 @@ LOADN R2 -1
FASTCALL1 2 R2 L0 FASTCALL1 2 R2 L0
GETGLOBAL R3 K1024 GETGLOBAL R3 K1024
GETTABLEKS R1 R3 K1025 GETTABLEKS R1 R3 K1025
L0: CALL R1 1 -1 CALL R1 1 -1
)"); )");
} }
@ -4063,8 +4063,8 @@ LOADN R2 2
LOADN R3 3 LOADN R3 3
FASTCALL 54 L0 FASTCALL 54 L0
GETIMPORT R0 2 GETIMPORT R0 2
L0: CALL R0 3 -1 CALL R0 3 -1
RETURN R0 -1 L0: RETURN R0 -1
)"); )");
} }
@ -4351,6 +4351,8 @@ TEST_CASE("LoopUnrollControlFlow")
{"LuauCompileLoopUnrollThresholdMaxBoost", 300}, {"LuauCompileLoopUnrollThresholdMaxBoost", 300},
}; };
ScopedFastFlag sff("LuauCompileFoldBuiltins", true);
// break jumps to the end // break jumps to the end
CHECK_EQ("\n" + compileFunction(R"( CHECK_EQ("\n" + compileFunction(R"(
for i=1,3 do for i=1,3 do
@ -4414,7 +4416,7 @@ L2: RETURN R0 0
// continue needs to properly close upvalues // continue needs to properly close upvalues
CHECK_EQ("\n" + compileFunction(R"( CHECK_EQ("\n" + compileFunction(R"(
for i=1,1 do for i=1,1 do
local j = math.abs(i) local j = global(i)
print(function() return j end) print(function() return j end)
if math.random() < 0.5 then if math.random() < 0.5 then
continue continue
@ -4424,21 +4426,20 @@ end
)", )",
1, 2), 1, 2),
R"( R"(
GETIMPORT R0 1
LOADN R1 1 LOADN R1 1
FASTCALL1 2 R1 L0 CALL R0 1 1
GETIMPORT R0 2 GETIMPORT R1 3
L0: CALL R0 1 1
GETIMPORT R1 4
NEWCLOSURE R2 P0 NEWCLOSURE R2 P0
CAPTURE REF R0 CAPTURE REF R0
CALL R1 1 0 CALL R1 1 0
GETIMPORT R1 6 GETIMPORT R1 6
CALL R1 0 1 CALL R1 0 1
LOADK R2 K7 LOADK R2 K7
JUMPIFNOTLT R1 R2 L1 JUMPIFNOTLT R1 R2 L0
CLOSEUPVALS R0 CLOSEUPVALS R0
RETURN R0 0 RETURN R0 0
L1: ADDK R0 R0 K8 L0: ADDK R0 R0 K8
CLOSEUPVALS R0 CLOSEUPVALS R0
RETURN R0 0 RETURN R0 0
)"); )");
@ -4625,11 +4626,11 @@ FORNPREP R1 L3
L0: FASTCALL1 24 R3 L1 L0: FASTCALL1 24 R3 L1
MOVE R6 R3 MOVE R6 R3
GETIMPORT R5 2 GETIMPORT R5 2
L1: CALL R5 1 -1 CALL R5 1 -1
FASTCALL 2 L2 L1: FASTCALL 2 L2
GETIMPORT R4 4 GETIMPORT R4 4
L2: CALL R4 -1 1 CALL R4 -1 1
SETTABLE R4 R0 R3 L2: SETTABLE R4 R0 R3
FORNLOOP R1 L0 FORNLOOP R1 L0
L3: RETURN R0 1 L3: RETURN R0 1
)"); )");
@ -4660,6 +4661,133 @@ L1: RETURN R0 0
)"); )");
} }
TEST_CASE("LoopUnrollCostBuiltins")
{
ScopedFastInt sfis[] = {
{"LuauCompileLoopUnrollThreshold", 25},
{"LuauCompileLoopUnrollThresholdMaxBoost", 300},
};
ScopedFastFlag sff("LuauCompileModelBuiltins", true);
// this loop uses builtins and is close to the cost budget so it's important that we model builtins as cheaper than regular calls
CHECK_EQ("\n" + compileFunction(R"(
function cipher(block, nonce)
for i = 0,3 do
block[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff)
end
end
)",
0, 2),
R"(
FASTCALL2K 39 R1 K0 L0
MOVE R4 R1
LOADK R5 K0
GETIMPORT R3 3
CALL R3 2 1
L0: FASTCALL2K 29 R3 K4 L1
LOADK R4 K4
GETIMPORT R2 6
CALL R2 2 1
L1: SETTABLEN R2 R0 1
FASTCALL2K 39 R1 K7 L2
MOVE R4 R1
LOADK R5 K7
GETIMPORT R3 3
CALL R3 2 1
L2: FASTCALL2K 29 R3 K4 L3
LOADK R4 K4
GETIMPORT R2 6
CALL R2 2 1
L3: SETTABLEN R2 R0 2
FASTCALL2K 39 R1 K8 L4
MOVE R4 R1
LOADK R5 K8
GETIMPORT R3 3
CALL R3 2 1
L4: FASTCALL2K 29 R3 K4 L5
LOADK R4 K4
GETIMPORT R2 6
CALL R2 2 1
L5: SETTABLEN R2 R0 3
FASTCALL2K 39 R1 K9 L6
MOVE R4 R1
LOADK R5 K9
GETIMPORT R3 3
CALL R3 2 1
L6: FASTCALL2K 29 R3 K4 L7
LOADK R4 K4
GETIMPORT R2 6
CALL R2 2 1
L7: SETTABLEN R2 R0 4
RETURN R0 0
)");
// note that if we break compiler's ability to reason about bit32 builtin the loop is no longer unrolled as it's too expensive
CHECK_EQ("\n" + compileFunction(R"(
bit32 = {}
function cipher(block, nonce)
for i = 0,3 do
block[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff)
end
end
)",
0, 2),
R"(
LOADN R4 0
LOADN R2 3
LOADN R3 1
FORNPREP R2 L1
L0: ADDK R5 R4 K0
GETGLOBAL R7 K1
GETTABLEKS R6 R7 K2
GETGLOBAL R8 K1
GETTABLEKS R7 R8 K3
MOVE R8 R1
MULK R9 R4 K4
CALL R7 2 1
LOADN R8 255
CALL R6 2 1
SETTABLE R6 R0 R5
FORNLOOP R2 L0
L1: RETURN R0 0
)");
// additionally, if we pass too many constants the builtin stops being cheap because of argument setup
CHECK_EQ("\n" + compileFunction(R"(
function cipher(block, nonce)
for i = 0,3 do
block[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff, 0xff, 0xff, 0xff, 0xff)
end
end
)",
0, 2),
R"(
LOADN R4 0
LOADN R2 3
LOADN R3 1
FORNPREP R2 L3
L0: ADDK R5 R4 K0
MULK R9 R4 K1
FASTCALL2 39 R1 R9 L1
MOVE R8 R1
GETIMPORT R7 4
CALL R7 2 1
L1: LOADN R8 255
LOADN R9 255
LOADN R10 255
LOADN R11 255
LOADN R12 255
FASTCALL 29 L2
GETIMPORT R6 6
CALL R6 6 1
L2: SETTABLE R6 R0 R5
FORNLOOP R2 L0
L3: RETURN R0 0
)");
}
TEST_CASE("InlineBasic") TEST_CASE("InlineBasic")
{ {
// inline function that returns a constant // inline function that returns a constant
@ -5216,8 +5344,8 @@ DUPCLOSURE R0 K0
LOADK R3 K1 LOADK R3 K1
FASTCALL1 20 R3 L0 FASTCALL1 20 R3 L0
GETIMPORT R2 4 GETIMPORT R2 4
L0: CALL R2 1 2 CALL R2 1 2
ADD R1 R2 R3 L0: ADD R1 R2 R3
RETURN R1 1 RETURN R1 1
)"); )");
@ -5483,14 +5611,14 @@ NEWTABLE R2 0 0
FASTCALL2K 49 R2 K1 L0 FASTCALL2K 49 R2 K1 L0
LOADK R3 K1 LOADK R3 K1
GETIMPORT R1 3 GETIMPORT R1 3
L0: CALL R1 2 0 CALL R1 2 0
NEWTABLE R1 0 0 L0: NEWTABLE R1 0 0
NEWTABLE R3 0 0 NEWTABLE R3 0 0
FASTCALL2 49 R3 R1 L1 FASTCALL2 49 R3 R1 L1
MOVE R4 R1 MOVE R4 R1
GETIMPORT R2 3 GETIMPORT R2 3
L1: CALL R2 2 0 CALL R2 2 0
RETURN R0 0 L1: RETURN R0 0
)"); )");
} }
@ -5762,4 +5890,271 @@ RETURN R0 2
)"); )");
} }
TEST_CASE("OptimizationLevel")
{
ScopedFastFlag sff("LuauAlwaysCaptureHotComments", true);
// at optimization level 1, no inlining is performed
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a)
return a
end
return foo(42)
)",
1, 1),
R"(
DUPCLOSURE R0 K0
MOVE R1 R0
LOADN R2 42
CALL R1 1 -1
RETURN R1 -1
)");
// you can override the level from 1 to 2 to force it
CHECK_EQ("\n" + compileFunction(R"(
--!optimize 2
local function foo(a)
return a
end
return foo(42)
)",
1, 1),
R"(
DUPCLOSURE R0 K0
LOADN R1 42
RETURN R1 1
)");
// you can also override it externally
CHECK_EQ("\n" + compileFunction(R"(
local function foo(a)
return a
end
return foo(42)
)",
1, 2),
R"(
DUPCLOSURE R0 K0
LOADN R1 42
RETURN R1 1
)");
// ... after which you can downgrade it back via hot comment
CHECK_EQ("\n" + compileFunction(R"(
--!optimize 1
local function foo(a)
return a
end
return foo(42)
)",
1, 2),
R"(
DUPCLOSURE R0 K0
MOVE R1 R0
LOADN R2 42
CALL R1 1 -1
RETURN R1 -1
)");
}
TEST_CASE("BuiltinFolding")
{
ScopedFastFlag sff("LuauCompileFoldBuiltins", true);
CHECK_EQ("\n" + compileFunction(R"(
return
math.abs(-42),
math.acos(1),
math.asin(0),
math.atan2(0, 1),
math.atan(0),
math.ceil(1.5),
math.cosh(0),
math.cos(0),
math.deg(3.14159265358979323846),
math.exp(0),
math.floor(-1.5),
math.fmod(7, 3),
math.ldexp(0.5, 3),
math.log10(100),
math.log(1),
math.log(4, 2),
math.log(27, 3),
math.max(1, 2, 3),
math.min(1, 2, 3),
math.pow(3, 3),
math.floor(math.rad(180)),
math.sinh(0),
math.sin(0),
math.sqrt(9),
math.tanh(0),
math.tan(0),
bit32.arshift(-10, 1),
bit32.arshift(10, 1),
bit32.band(1, 3),
bit32.bnot(-2),
bit32.bor(1, 2),
bit32.bxor(3, 7),
bit32.btest(1, 3),
bit32.extract(100, 1, 3),
bit32.lrotate(100, -1),
bit32.lshift(100, 1),
bit32.replace(100, 5, 1, 3),
bit32.rrotate(100, -1),
bit32.rshift(100, 1),
type(100),
string.byte("a"),
string.byte("abc", 2),
string.len("abc"),
typeof(true),
math.clamp(-1, 0, 1),
math.sign(77),
math.round(7.6),
(type("fin"))
)",
0, 2),
R"(
LOADN R0 42
LOADN R1 0
LOADN R2 0
LOADN R3 0
LOADN R4 0
LOADN R5 2
LOADN R6 1
LOADN R7 1
LOADN R8 180
LOADN R9 1
LOADN R10 -2
LOADN R11 1
LOADN R12 4
LOADN R13 2
LOADN R14 0
LOADN R15 2
LOADN R16 3
LOADN R17 3
LOADN R18 1
LOADN R19 27
LOADN R20 3
LOADN R21 0
LOADN R22 0
LOADN R23 3
LOADN R24 0
LOADN R25 0
LOADK R26 K0
LOADN R27 5
LOADN R28 1
LOADN R29 1
LOADN R30 3
LOADN R31 4
LOADB R32 1
LOADN R33 2
LOADN R34 50
LOADN R35 200
LOADN R36 106
LOADN R37 200
LOADN R38 50
LOADK R39 K1
LOADN R40 97
LOADN R41 98
LOADN R42 3
LOADK R43 K2
LOADN R44 0
LOADN R45 1
LOADN R46 8
LOADK R47 K3
RETURN R0 48
)");
}
TEST_CASE("BuiltinFoldingProhibited")
{
ScopedFastFlag sff("LuauCompileFoldBuiltins", true);
CHECK_EQ("\n" + compileFunction(R"(
return
math.abs(),
math.max(1, true),
string.byte("abc", 42),
bit32.rshift(10, 42)
)",
0, 2),
R"(
FASTCALL 2 L0
GETIMPORT R0 2
CALL R0 0 1
L0: LOADN R2 1
FASTCALL2K 18 R2 K3 L1
LOADK R3 K3
GETIMPORT R1 5
CALL R1 2 1
L1: LOADK R3 K6
FASTCALL2K 41 R3 K7 L2
LOADK R4 K7
GETIMPORT R2 10
CALL R2 2 1
L2: LOADN R4 10
FASTCALL2K 39 R4 K7 L3
LOADK R5 K7
GETIMPORT R3 13
CALL R3 2 -1
L3: RETURN R0 -1
)");
}
TEST_CASE("BuiltinFoldingMultret")
{
ScopedFastFlag sff1("LuauCompileFoldBuiltins", true);
ScopedFastFlag sff2("LuauCompileBetterMultret", true);
CHECK_EQ("\n" + compileFunction(R"(
local NoLanes: Lanes = --[[ ]] 0b0000000000000000000000000000000
local OffscreenLane: Lane = --[[ ]] 0b1000000000000000000000000000000
local function getLanesToRetrySynchronouslyOnError(root: FiberRoot): Lanes
local everythingButOffscreen = bit32.band(root.pendingLanes, bit32.bnot(OffscreenLane))
if everythingButOffscreen ~= NoLanes then
return everythingButOffscreen
end
if bit32.band(everythingButOffscreen, OffscreenLane) ~= 0 then
return OffscreenLane
end
return NoLanes
end
)",
0, 2),
R"(
GETTABLEKS R2 R0 K0
FASTCALL2K 29 R2 K1 L0
LOADK R3 K1
GETIMPORT R1 4
CALL R1 2 1
L0: JUMPIFEQK R1 K5 L1
RETURN R1 1
L1: FASTCALL2K 29 R1 K6 L2
MOVE R3 R1
LOADK R4 K6
GETIMPORT R2 4
CALL R2 2 1
L2: JUMPIFEQK R2 K5 L3
LOADK R2 K6
RETURN R2 1
L3: LOADN R2 0
RETURN R2 1
)");
// Note: similarly, here we should have folded the return value but haven't because it's the last call in the sequence
CHECK_EQ("\n" + compileFunction(R"(
return math.abs(-42)
)",
0, 2),
R"(
LOADN R0 42
RETURN R0 1
)");
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -17,6 +17,16 @@
#include <math.h> #include <math.h>
extern bool verbose; extern bool verbose;
extern int optimizationLevel;
static lua_CompileOptions defaultOptions()
{
lua_CompileOptions copts = {};
copts.optimizationLevel = optimizationLevel;
copts.debugLevel = 1;
return copts;
}
static int lua_collectgarbage(lua_State* L) static int lua_collectgarbage(lua_State* L)
{ {
@ -127,7 +137,7 @@ int lua_silence(lua_State* L)
using StateRef = std::unique_ptr<lua_State, void (*)(lua_State*)>; using StateRef = std::unique_ptr<lua_State, void (*)(lua_State*)>;
static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr,
lua_State* initialLuaState = nullptr, lua_CompileOptions* copts = nullptr) lua_State* initialLuaState = nullptr, lua_CompileOptions* options = nullptr)
{ {
std::string path = __FILE__; std::string path = __FILE__;
path.erase(path.find_last_of("\\/")); path.erase(path.find_last_of("\\/"));
@ -189,8 +199,11 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n
std::string chunkname = "=" + std::string(name); std::string chunkname = "=" + std::string(name);
// note: luau_compile supports nullptr options, but we need to customize our defaults to improve test coverage
lua_CompileOptions opts = options ? *options : defaultOptions();
size_t bytecodeSize = 0; size_t bytecodeSize = 0;
char* bytecode = luau_compile(source.data(), source.size(), copts, &bytecodeSize); char* bytecode = luau_compile(source.data(), source.size(), &opts, &bytecodeSize);
int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0);
free(bytecode); free(bytecode);
@ -383,9 +396,7 @@ TEST_CASE("Pack")
TEST_CASE("Vector") TEST_CASE("Vector")
{ {
lua_CompileOptions copts = {}; lua_CompileOptions copts = defaultOptions();
copts.optimizationLevel = 1;
copts.debugLevel = 1;
copts.vectorCtor = "vector"; copts.vectorCtor = "vector";
runConformance( runConformance(
@ -519,8 +530,7 @@ TEST_CASE("Debugger")
breakhits = 0; breakhits = 0;
interruptedthread = nullptr; interruptedthread = nullptr;
lua_CompileOptions copts = {}; lua_CompileOptions copts = defaultOptions();
copts.optimizationLevel = 1;
copts.debugLevel = 2; copts.debugLevel = 2;
runConformance( runConformance(
@ -850,6 +860,43 @@ TEST_CASE("ApiCalls")
} }
} }
TEST_CASE("ApiAtoms")
{
ScopedFastFlag sff("LuauLazyAtoms", true);
StateRef globalState(luaL_newstate(), lua_close);
lua_State* L = globalState.get();
lua_callbacks(L)->useratom = [](const char* s, size_t l) -> int16_t {
if (strcmp(s, "string") == 0)
return 0;
if (strcmp(s, "important") == 0)
return 1;
return -1;
};
lua_pushstring(L, "string");
lua_pushstring(L, "import");
lua_pushstring(L, "ant");
lua_concat(L, 2);
lua_pushstring(L, "unimportant");
int a1, a2, a3;
const char* s1 = lua_tostringatom(L, -3, &a1);
const char* s2 = lua_tostringatom(L, -2, &a2);
const char* s3 = lua_tostringatom(L, -1, &a3);
CHECK(strcmp(s1, "string") == 0);
CHECK(a1 == 0);
CHECK(strcmp(s2, "important") == 0);
CHECK(a2 == 1);
CHECK(strcmp(s3, "unimportant") == 0);
CHECK(a3 == -1);
}
static bool endsWith(const std::string& str, const std::string& suffix) static bool endsWith(const std::string& str, const std::string& suffix)
{ {
if (suffix.length() > str.length()) if (suffix.length() > str.length())
@ -957,9 +1004,8 @@ TEST_CASE("TagMethodError")
TEST_CASE("Coverage") TEST_CASE("Coverage")
{ {
lua_CompileOptions copts = {}; lua_CompileOptions copts = defaultOptions();
copts.optimizationLevel = 1; copts.optimizationLevel = 1; // disable inlining to get fixed expected hit results
copts.debugLevel = 1;
copts.coverageLevel = 2; copts.coverageLevel = 2;
runConformance( runConformance(
@ -1059,6 +1105,9 @@ TEST_CASE("GCDump")
TEST_CASE("Interrupt") TEST_CASE("Interrupt")
{ {
lua_CompileOptions copts = defaultOptions();
copts.optimizationLevel = 1; // disable loop unrolling to get fixed expected hit results
static const int expectedhits[] = { static const int expectedhits[] = {
2, 2,
9, 9,
@ -1109,7 +1158,8 @@ TEST_CASE("Interrupt")
}, },
[](lua_State* L) { [](lua_State* L) {
CHECK(index == 5); // a single yield point CHECK(index == 5); // a single yield point
}); },
nullptr, &copts);
CHECK(index == int(std::size(expectedhits))); CHECK(index == int(std::size(expectedhits)));
} }

View File

@ -10,7 +10,7 @@ namespace Luau
namespace Compile namespace Compile
{ {
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount); uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount, const DenseHashMap<AstExprCall*, int>& builtins);
int computeCost(uint64_t model, const bool* varsConst, size_t varCount); int computeCost(uint64_t model, const bool* varsConst, size_t varCount);
} // namespace Compile } // namespace Compile
@ -29,7 +29,7 @@ static uint64_t modelFunction(const char* source)
AstStatFunction* func = result.root->body.data[0]->as<AstStatFunction>(); AstStatFunction* func = result.root->body.data[0]->as<AstStatFunction>();
REQUIRE(func); REQUIRE(func);
return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size); return Luau::Compile::modelCost(func->func->body, func->func->args.data, func->func->args.size, {nullptr});
} }
TEST_CASE("Expression") TEST_CASE("Expression")

View File

@ -756,26 +756,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "test_lint_uses_correct_config")
CHECK_EQ(0, result4.warnings.size()); CHECK_EQ(0, result4.warnings.size());
} }
TEST_CASE_FIXTURE(FrontendFixture, "lintFragment")
{
LintOptions lintOptions;
lintOptions.enableWarning(LintWarning::Code_ForRange);
auto [_sourceModule, result] = frontend.lintFragment(R"(
local t = {}
for i=#t,1 do
end
for i=#t,1,-1 do
end
)",
lintOptions);
CHECK_EQ(1, result.warnings.size());
CHECK_EQ(0, result.errors.size());
}
TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs")
{ {
Frontend fe{&fileResolver, &configResolver, {false}}; Frontend fe{&fileResolver, &configResolver, {false}};

View File

@ -1658,4 +1658,21 @@ end
CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2");
} }
TEST_CASE_FIXTURE(Fixture, "WrongCommentOptimize")
{
LintResult result = lint(R"(
--!optimize
--!optimize
--!optimize me
--!optimize 100500
--!optimize 2
)");
REQUIRE_EQ(result.warnings.size(), 4);
CHECK_EQ(result.warnings[0].text, "optimize directive requires an optimization level");
CHECK_EQ(result.warnings[1].text, "optimize directive requires an optimization level");
CHECK_EQ(result.warnings[2].text, "optimize directive uses unknown optimization level 'me', 0..2 expected");
CHECK_EQ(result.warnings[3].text, "optimize directive uses unknown optimization level '100500', 0..2 expected");
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -756,6 +756,65 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_table_normalizes_sensibly")
CHECK_EQ("t1 where t1 = { get: () -> t1 }", toString(ty, {true})); CHECK_EQ("t1 where t1 = { get: () -> t1 }", toString(ty, {true}));
} }
TEST_CASE_FIXTURE(Fixture, "cyclic_union")
{
ScopedFastFlag sff[] = {
{"LuauLowerBoundsCalculation", true},
{"LuauFixNormalizationOfCyclicUnions", true},
};
CheckResult result = check(R"(
type T = {T?}?
local a: T
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK("t1? where t1 = {t1?}" == toString(requireType("a")));
}
TEST_CASE_FIXTURE(Fixture, "cyclic_intersection")
{
ScopedFastFlag sff[] = {
{"LuauLowerBoundsCalculation", true},
{"LuauFixNormalizationOfCyclicUnions", true},
};
CheckResult result = check(R"(
type T = {T & {}}
local a: T
)");
LUAU_REQUIRE_NO_ERRORS(result);
// FIXME: We are not properly normalizing this type, but we are at least not improperly discarding information
CHECK("t1 where t1 = {{t1 & {| |}}}" == toString(requireType("a"), {true}));
}
TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_indexers")
{
ScopedFastFlag sff[] = {
{"LuauLowerBoundsCalculation", true},
{"LuauFixNormalizationOfCyclicUnions", true},
};
CheckResult result = check(R"(
type A = {number}
type B = {string}
type C = A & B
local a: C
)");
LUAU_REQUIRE_NO_ERRORS(result);
// FIXME: We are not properly normalizing this type, but we are at least not improperly discarding information
CHECK("{number & string}" == toString(requireType("a"), {true}));
}
TEST_CASE_FIXTURE(BuiltinsFixture, "union_of_distinct_free_types") TEST_CASE_FIXTURE(BuiltinsFixture, "union_of_distinct_free_types")
{ {
ScopedFastFlag flags[] = { ScopedFastFlag flags[] = {

View File

@ -62,7 +62,6 @@ TEST_CASE_FIXTURE(Fixture, "named_table")
TEST_CASE_FIXTURE(Fixture, "empty_table") TEST_CASE_FIXTURE(Fixture, "empty_table")
{ {
ScopedFastFlag LuauToStringTableBracesNewlines("LuauToStringTableBracesNewlines", true);
CheckResult result = check(R"( CheckResult result = check(R"(
local a: {} local a: {}
)"); )");
@ -77,7 +76,6 @@ TEST_CASE_FIXTURE(Fixture, "empty_table")
TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break") TEST_CASE_FIXTURE(Fixture, "table_respects_use_line_break")
{ {
ScopedFastFlag LuauToStringTableBracesNewlines("LuauToStringTableBracesNewlines", true);
CheckResult result = check(R"( CheckResult result = check(R"(
local a: { prop: string, anotherProp: number, thirdProp: boolean } local a: { prop: string, anotherProp: number, thirdProp: boolean }
)"); )");

View File

@ -1143,4 +1143,114 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gmatch_capture_types_invalid_pattern_fallbac
CHECK_EQ(toString(requireType("foo")), "() -> (...string)"); CHECK_EQ(toString(requireType("foo")), "() -> (...string)");
} }
TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types")
{
ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true};
CheckResult result = check(R"END(
local a, b, c = string.match("This is a string", "(.()(%a+))")
)END");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "string");
CHECK_EQ(toString(requireType("b")), "number");
CHECK_EQ(toString(requireType("c")), "string");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "match_capture_types2")
{
ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true};
CheckResult result = check(R"END(
local a, b, c = string.match("This is a string", "(.()(%a+))", "this should be a number")
)END");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ(toString(tm->wantedType), "number?");
CHECK_EQ(toString(tm->givenType), "string");
CHECK_EQ(toString(requireType("a")), "string");
CHECK_EQ(toString(requireType("b")), "number");
CHECK_EQ(toString(requireType("c")), "string");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types")
{
ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true};
CheckResult result = check(R"END(
local d, e, a, b, c = string.find("This is a string", "(.()(%a+))")
)END");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("a")), "string");
CHECK_EQ(toString(requireType("b")), "number");
CHECK_EQ(toString(requireType("c")), "string");
CHECK_EQ(toString(requireType("d")), "number?");
CHECK_EQ(toString(requireType("e")), "number?");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types2")
{
ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true};
CheckResult result = check(R"END(
local d, e, a, b, c = string.find("This is a string", "(.()(%a+))", "this should be a number")
)END");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ(toString(tm->wantedType), "number?");
CHECK_EQ(toString(tm->givenType), "string");
CHECK_EQ(toString(requireType("a")), "string");
CHECK_EQ(toString(requireType("b")), "number");
CHECK_EQ(toString(requireType("c")), "string");
CHECK_EQ(toString(requireType("d")), "number?");
CHECK_EQ(toString(requireType("e")), "number?");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3")
{
ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true};
CheckResult result = check(R"END(
local d, e, a, b, c = string.find("This is a string", "(.()(%a+))", 1, "this should be a bool")
)END");
LUAU_REQUIRE_ERROR_COUNT(1, result);
TypeMismatch* tm = get<TypeMismatch>(result.errors[0]);
REQUIRE(tm);
CHECK_EQ(toString(tm->wantedType), "boolean?");
CHECK_EQ(toString(tm->givenType), "string");
CHECK_EQ(toString(requireType("a")), "string");
CHECK_EQ(toString(requireType("b")), "number");
CHECK_EQ(toString(requireType("c")), "string");
CHECK_EQ(toString(requireType("d")), "number?");
CHECK_EQ(toString(requireType("e")), "number?");
}
TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3")
{
ScopedFastFlag sffs{"LuauDeduceFindMatchReturnTypes", true};
CheckResult result = check(R"END(
local d, e, a, b = string.find("This is a string", "(.()(%a+))", 1, true)
)END");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CountMismatch* acm = get<CountMismatch>(result.errors[0]);
REQUIRE(acm);
CHECK_EQ(acm->context, CountMismatch::Result);
CHECK_EQ(acm->expected, 2);
CHECK_EQ(acm->actual, 4);
CHECK_EQ(toString(requireType("d")), "number?");
CHECK_EQ(toString(requireType("e")), "number?");
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -11,6 +11,8 @@
#include "doctest.h" #include "doctest.h"
LUAU_FASTFLAG(LuauDeduceFindMatchReturnTypes)
using namespace Luau; using namespace Luau;
TEST_SUITE_BEGIN("TypeInferPrimitives"); TEST_SUITE_BEGIN("TypeInferPrimitives");
@ -80,7 +82,10 @@ TEST_CASE_FIXTURE(Fixture, "string_function_other")
)"); )");
CHECK_EQ(0, result.errors.size()); CHECK_EQ(0, result.errors.size());
CHECK_EQ(toString(requireType("p")), "string?"); if (FFlag::LuauDeduceFindMatchReturnTypes)
CHECK_EQ(toString(requireType("p")), "string");
else
CHECK_EQ(toString(requireType("p")), "string?");
} }
TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber")

View File

@ -476,4 +476,21 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton")
CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 23}))); CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 23})));
} }
TEST_CASE_FIXTURE(Fixture, "no_widening_from_callsites")
{
ScopedFastFlag sff{"LuauReturnsFromCallsitesAreNotWidened", true};
CheckResult result = check(R"(
type Direction = "North" | "East" | "West" | "South"
local function direction(): Direction
return "North"
end
local d: Direction = direction()
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -3084,4 +3084,86 @@ local b = a.x
CHECK_EQ("Type 'a' does not have key 'x'", toString(result.errors[1])); CHECK_EQ("Type 'a' does not have key 'x'", toString(result.errors[1]));
} }
TEST_CASE_FIXTURE(Fixture, "scalar_is_a_subtype_of_a_compatible_polymorphic_shape_type")
{
ScopedFastFlag sff{"LuauScalarShapeSubtyping", true};
CheckResult result = check(R"(
local function f(s)
return s:lower()
end
f("foo" :: string)
f("bar" :: "bar")
f("baz" :: "bar" | "baz")
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type")
{
ScopedFastFlag sff{"LuauScalarShapeSubtyping", true};
CheckResult result = check(R"(
local function f(s)
return s:absolutely_no_scalar_has_this_method()
end
f("foo" :: string)
f("bar" :: "bar")
f("baz" :: "bar" | "baz")
)");
LUAU_REQUIRE_ERROR_COUNT(3, result);
CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}'
caused by:
The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')",
toString(result.errors[0]));
CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}'
caused by:
The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')",
toString(result.errors[1]));
CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}'
caused by:
Not all union options are compatible. Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}'
caused by:
The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')",
toString(result.errors[2]));
}
TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compatible")
{
ScopedFastFlag sff{"LuauScalarShapeSubtyping", true};
CheckResult result = check(R"(
local function f(s): string
local foo = s:lower()
return s
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("(string) -> string", toString(requireType("f")));
}
TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible")
{
ScopedFastFlag sff{"LuauScalarShapeSubtyping", true};
CheckResult result = check(R"(
local function f(s): string
local foo = s:absolutely_no_scalar_has_this_method()
return s
end
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string'
caused by:
The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')",
toString(result.errors[0]));
CHECK_EQ("<a, b...>(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f")));
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -268,13 +268,49 @@ TEST_CASE_FIXTURE(Fixture, "unary_minus_of_never")
TEST_CASE_FIXTURE(Fixture, "length_of_never") TEST_CASE_FIXTURE(Fixture, "length_of_never")
{ {
ScopedFastFlag sff{"LuauNeverTypesAndOperatorsInference", true};
CheckResult result = check(R"( CheckResult result = check(R"(
local x = #({} :: never) local x = #({} :: never)
)"); )");
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("never", toString(requireType("x"))); CHECK_EQ("number", toString(requireType("x")));
}
TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_in_any_ordering_operators")
{
ScopedFastFlag sff[]{
{"LuauUnknownAndNeverType", true},
{"LuauNeverTypesAndOperatorsInference", true},
};
CheckResult result = check(R"(
local function ord(x: nil, y)
return x ~= nil and x > y
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("<a>(nil, a) -> boolean", toString(requireType("ord")));
}
TEST_CASE_FIXTURE(Fixture, "math_operators_and_never")
{
ScopedFastFlag sff[]{
{"LuauUnknownAndNeverType", true},
{"LuauNeverTypesAndOperatorsInference", true},
};
CheckResult result = check(R"(
local function mul(x: nil, y)
return x ~= nil and x * y -- infers boolean | never, which is normalized into boolean
end
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("<a>(nil, a) -> boolean", toString(requireType("mul")));
} }
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -727,16 +727,20 @@ assert((function() local abs = math.abs function foo(...) return abs(...) end re
-- NOTE: getfenv breaks fastcalls for the remainder of the source! hence why this is delayed until the end -- NOTE: getfenv breaks fastcalls for the remainder of the source! hence why this is delayed until the end
function testgetfenv() function testgetfenv()
getfenv() getfenv()
-- declare constant so that at O2 this test doesn't interfere with constant folding which we can't deoptimize
local negfive negfive = -5
-- getfenv breaks fastcalls (we assume we can't rely on knowing the semantics), but behavior shouldn't change -- getfenv breaks fastcalls (we assume we can't rely on knowing the semantics), but behavior shouldn't change
assert((function() return math.abs(-5) end)() == 5) assert((function() return math.abs(negfive) end)() == 5)
assert((function() local abs = math.abs return abs(-5) end)() == 5) assert((function() local abs = math.abs return abs(negfive) end)() == 5)
assert((function() local abs = math.abs function foo() return abs(-5) end return foo() end)() == 5) assert((function() local abs = math.abs function foo() return abs(negfive) end return foo() end)() == 5)
-- ... unless you actually reassign the function :D -- ... unless you actually reassign the function :D
getfenv().math = { abs = function(n) return n*n end } getfenv().math = { abs = function(n) return n*n end }
assert((function() return math.abs(-5) end)() == 25) assert((function() return math.abs(negfive) end)() == 25)
assert((function() local abs = math.abs return abs(-5) end)() == 25) assert((function() local abs = math.abs return abs(negfive) end)() == 25)
assert((function() local abs = math.abs function foo() return abs(-5) end return foo() end)() == 25) assert((function() local abs = math.abs function foo() return abs(negfive) end return foo() end)() == 25)
end end
-- you need to have enough arguments and arguments of the right type; if you don't, we'll fallback to the regular code. This checks coercions -- you need to have enough arguments and arguments of the right type; if you don't, we'll fallback to the regular code. This checks coercions

View File

@ -23,10 +23,13 @@
#include <optional> #include <optional>
// Indicates if verbose output is enabled. // Indicates if verbose output is enabled; can be overridden via --verbose
// Currently, this enables output from lua's 'print', but other verbose output could be enabled eventually. // Currently, this enables output from 'print', but other verbose output could be enabled eventually.
bool verbose = false; bool verbose = false;
// Default optimization level for conformance test; can be overridden via -On
int optimizationLevel = 1;
static bool skipFastFlag(const char* flagName) static bool skipFastFlag(const char* flagName)
{ {
if (strncmp(flagName, "Test", 4) == 0) if (strncmp(flagName, "Test", 4) == 0)
@ -249,6 +252,15 @@ int main(int argc, char** argv)
verbose = true; verbose = true;
} }
int level = -1;
if (doctest::parseIntOption(argc, argv, "-O", doctest::option_int, level))
{
if (level < 0 || level > 2)
std::cerr << "Optimization level must be between 0 and 2 inclusive." << std::endl;
else
optimizationLevel = level;
}
if (std::vector<doctest::String> flags; doctest::parseCommaSepArgs(argc, argv, "--fflags=", flags)) if (std::vector<doctest::String> flags; doctest::parseCommaSepArgs(argc, argv, "--fflags=", flags))
setFastFlags(flags); setFastFlags(flags);
@ -279,6 +291,7 @@ int main(int argc, char** argv)
if (doctest::parseFlag(argc, argv, "--help") || doctest::parseFlag(argc, argv, "-h")) if (doctest::parseFlag(argc, argv, "--help") || doctest::parseFlag(argc, argv, "-h"))
{ {
printf("Additional command line options:\n"); printf("Additional command line options:\n");
printf(" -O[n] Changes default optimization level (1) for conformance runs\n");
printf(" --verbose Enables verbose output (e.g. lua 'print' statements)\n"); printf(" --verbose Enables verbose output (e.g. lua 'print' statements)\n");
printf(" --fflags= Sets specified fast flags\n"); printf(" --fflags= Sets specified fast flags\n");
printf(" --list-fflags List all fast flags\n"); printf(" --list-fflags List all fast flags\n");

View File

@ -16,7 +16,11 @@ state = 0
# parse input into errors[] with the state machine; this is using doctest output and expects multi-line match failures # parse input into errors[] with the state machine; this is using doctest output and expects multi-line match failures
for line in input: for line in input:
if state == 0: if state == 0:
match = re.match("tests/[^:]+:(\d+): ERROR: CHECK_EQ", line) if sys.platform == "win32":
match = re.match("[^(]+\((\d+)\): ERROR: CHECK_EQ", line)
else:
match = re.match("tests/[^:]+:(\d+): ERROR: CHECK_EQ", line)
if match: if match:
error_line = int(match[1]) error_line = int(match[1])
state = 1 state = 1
@ -52,12 +56,16 @@ result = []
current = 0 current = 0
index = 0 index = 0
target = 0
while index < len(source): while index < len(source):
line = source[index] line = source[index]
error = errors[current] if current < len(errors) else None error = errors[current] if current < len(errors) else None
if not error or index < error[0] or line != error[1][0]: if error:
target = error[0] if sys.platform != "win32" else error[0] - len(error[1]) - 1
if not error or index < target or line != error[1][0]:
result.append(line) result.append(line)
index += 1 index += 1
else: else: