diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 50299704..2904ef8d 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -427,6 +427,36 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos) return findVisitor.result; } +static std::optional checkOverloadedDocumentationSymbol( + const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional documentationSymbol) +{ + if (!documentationSymbol) + return std::nullopt; + + // This might be an overloaded function. + if (get(follow(ty))) + { + TypeId matchingOverload = nullptr; + if (parentExpr && parentExpr->is()) + { + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + { + matchingOverload = *it; + } + } + + if (matchingOverload) + { + std::string overloadSymbol = *documentationSymbol + "/overload/"; + // Default toString options are fine for this purpose. + overloadSymbol += toString(matchingOverload); + return overloadSymbol; + } + } + + return documentationSymbol; +} + std::optional getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) { std::vector ancestry = findAstAncestryOfPosition(source, position); @@ -436,31 +466,7 @@ std::optional getDocumentationSymbolAtPosition(const Source if (std::optional binding = findBindingAtPosition(module, source, position)) { - if (binding->documentationSymbol) - { - // This might be an overloaded function binding. - if (get(follow(binding->typeId))) - { - TypeId matchingOverload = nullptr; - if (parentExpr && parentExpr->is()) - { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) - { - matchingOverload = *it; - } - } - - if (matchingOverload) - { - std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; - // Default toString options are fine for this purpose. - overloadSymbol += toString(matchingOverload); - return overloadSymbol; - } - } - } - - return binding->documentationSymbol; + return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol); } if (targetExpr) @@ -474,14 +480,14 @@ std::optional getDocumentationSymbolAtPosition(const Source { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) { - return propIt->second.documentationSymbol; + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); } } else if (const ClassTypeVar* ctv = get(parentTy)) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) { - return propIt->second.documentationSymbol; + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); } } } diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 4b21c443..a642334a 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -72,6 +72,73 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_fn") CHECK_EQ(symbol, "@test/global/foo/overload/(string) -> number"); } +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "class_method") +{ + loadDefinition(R"( + declare class Foo + function bar(self, x: string): number + end + )"); + + std::optional symbol = getDocSymbol(R"( + local x: Foo + x:bar("asdf") + )", + Position(2, 11)); + + CHECK_EQ(symbol, "@test/globaltype/Foo.bar"); +} + +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_class_method") +{ + loadDefinition(R"( + declare class Foo + function bar(self, x: string): number + function bar(self, x: number): string + end + )"); + + std::optional symbol = getDocSymbol(R"( + local x: Foo + x:bar("asdf") + )", + Position(2, 11)); + + CHECK_EQ(symbol, "@test/globaltype/Foo.bar/overload/(Foo, string) -> number"); +} + +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_function_prop") +{ + loadDefinition(R"( + declare Foo: { + new: (number) -> string + } + )"); + + std::optional symbol = getDocSymbol(R"( + Foo.new("asdf") + )", + Position(1, 13)); + + CHECK_EQ(symbol, "@test/global/Foo.new"); +} + +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_overloaded_function_prop") +{ + loadDefinition(R"( + declare Foo: { + new: ((number) -> string) & ((string) -> number) + } + )"); + + std::optional symbol = getDocSymbol(R"( + Foo.new("asdf") + )", + Position(1, 13)); + + CHECK_EQ(symbol, "@test/global/Foo.new/overload/(string) -> number"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("AstQuery");