// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/Config.h" #include "Luau/FileResolver.h" #include "Luau/Frontend.h" #include "Luau/IostreamHelpers.h" #include "Luau/Linter.h" #include "Luau/Location.h" #include "Luau/ModuleResolver.h" #include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "IostreamOptional.h" #include "ScopedFlags.h" #include #include #include #include namespace Luau { struct TestFileResolver : FileResolver , ModuleResolver { std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override { if (auto name = pathExprToModuleName(currentModuleName, pathExpr)) return {{*name, false}}; return std::nullopt; } const ModulePtr getModule(const ModuleName& moduleName) const override { LUAU_ASSERT(false); return nullptr; } bool moduleExists(const ModuleName& moduleName) const override { auto it = source.find(moduleName); return (it != source.end()); } std::optional readSource(const ModuleName& name) override { auto it = source.find(name); if (it == source.end()) return std::nullopt; SourceCode::Type sourceType = SourceCode::Module; auto it2 = sourceTypes.find(name); if (it2 != sourceTypes.end()) sourceType = it2->second; return SourceCode{it->second, sourceType}; } std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override; std::string getHumanReadableModuleName(const ModuleName& name) const override; std::optional getEnvironmentForModule(const ModuleName& name) const override; std::unordered_map source; std::unordered_map sourceTypes; std::unordered_map environments; }; struct TestConfigResolver : ConfigResolver { Config defaultConfig; std::unordered_map configFiles; const Config& getConfig(const ModuleName& name) const override { auto it = configFiles.find(name); if (it != configFiles.end()) return it->second; return defaultConfig; } }; struct Fixture { explicit Fixture(bool freeze = true, bool prepareAutocomplete = false); ~Fixture(); // Throws Luau::ParseErrors if the parse fails. AstStatBlock* parse(const std::string& source, const ParseOptions& parseOptions = {}); CheckResult check(Mode mode, std::string source); CheckResult check(const std::string& source); LintResult lint(const std::string& source, const std::optional& lintOptions = {}); LintResult lintTyped(const std::string& source, const std::optional& lintOptions = {}); /// Parse with all language extensions enabled ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult tryParse(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult matchParseError(const std::string& source, const std::string& message, std::optional location = std::nullopt); // Verify a parse error occurs and the parse error message has the specified prefix ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix); ModulePtr getMainModule(); SourceModule* getMainSourceModule(); std::optional getPrimitiveType(TypeId ty); std::optional getType(const std::string& name); TypeId requireType(const std::string& name); TypeId requireType(const ModuleName& moduleName, const std::string& name); TypeId requireType(const ModulePtr& module, const std::string& name); TypeId requireType(const ScopePtr& scope, const std::string& name); std::optional findTypeAtPosition(Position position); TypeId requireTypeAtPosition(Position position); std::optional findExpectedTypeAtPosition(Position position); std::optional lookupType(const std::string& name); std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); ScopedFastFlag sff_DebugLuauFreezeArena; ScopedFastFlag sff_UnknownNever{"LuauUnknownAndNeverType", true}; TestFileResolver fileResolver; TestConfigResolver configResolver; std::unique_ptr sourceModule; Frontend frontend; InternalErrorReporter ice; TypeChecker& typeChecker; std::string decorateWithTypes(const std::string& code); void dumpErrors(std::ostream& os, const std::vector& errors); void dumpErrors(const CheckResult& cr); void dumpErrors(const ModulePtr& module); void dumpErrors(const Module& module); void validateErrors(const std::vector& errors); std::string getErrors(const CheckResult& cr); void registerTestTypes(); LoadDefinitionFileResult loadDefinition(const std::string& source); }; struct BuiltinsFixture : Fixture { BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); }; struct ConstraintGraphBuilderFixture : Fixture { TypeArena arena; ModulePtr mainModule; ConstraintGraphBuilder cgb; ScopedFastFlag forceTheFlag; ConstraintGraphBuilderFixture(); }; ModuleName fromString(std::string_view name); template std::optional get(const std::map& map, const Name& name) { auto it = map.find(name); if (it != map.end()) return std::optional(it->second); else return std::nullopt; } std::string rep(const std::string& s, size_t n); bool isInArena(TypeId t, const TypeArena& arena); void dumpErrors(const ModulePtr& module); void dumpErrors(const Module& module); void dump(const std::string& name, TypeId ty); void dump(const std::vector& constraints); std::optional lookupName(ScopePtr scope, const std::string& name); // Warning: This function runs in O(n**2) std::optional linearSearchForBinding(Scope* scope, const char* name); struct Nth { int classIndex; int nth; }; template Nth nth(int nth = 1) { static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); LUAU_ASSERT(nth > 0); // Did you mean to use `nth(1)`? return Nth{T::ClassIndex(), nth}; } struct FindNthOccurenceOf : public AstVisitor { Nth requestedNth; size_t currentOccurrence = 0; AstNode* theNode = nullptr; FindNthOccurenceOf(Nth nth); bool checkIt(AstNode* n); bool visit(AstNode* n) override; bool visit(AstType* n) override; bool visit(AstTypePack* n) override; }; /** DSL querying of the AST. * * Given an AST, one can query for a particular node directly without having to manually unwrap the tree, for example: * * ``` * if a and b then * print(a + b) * end * * function f(x, y) * return x + y * end * ``` * * There are numerous ways to access the second AstExprBinary. * 1. Luau::query(block, {nth(), nth()}) * 2. Luau::query(Luau::query(block)) * 3. Luau::query(block, {nth(2)}) */ template T* query(AstNode* node, const std::vector& nths = {nth(N)}) { static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); // If a nested query call fails to find the node in question, subsequent calls can propagate rather than trying to do more. // This supports `query(query(...))` for (Nth nth : nths) { if (!node) return nullptr; FindNthOccurenceOf finder{nth}; node->visit(&finder); node = finder.theNode; } return node ? node->as() : nullptr; } } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ do \ { \ auto&& r = (result); \ validateErrors(r.errors); \ REQUIRE(!r.errors.empty()); \ } while (false) #define LUAU_REQUIRE_ERROR_COUNT(count, result) \ do \ { \ auto&& r = (result); \ validateErrors(r.errors); \ REQUIRE_MESSAGE(count == r.errors.size(), getErrors(r)); \ } while (false) #define LUAU_REQUIRE_NO_ERRORS(result) LUAU_REQUIRE_ERROR_COUNT(0, result)