Skip to content

Commit

Permalink
Improve catalog function interface
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jan 18, 2025
1 parent af44675 commit 631d7be
Show file tree
Hide file tree
Showing 22 changed files with 98 additions and 148 deletions.
2 changes: 1 addition & 1 deletion extension/fts/test/test_files/fts_small.test
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ Binder exception: BM25 model requires the Term Frequency Saturation(k) value to
-RELOADDB
-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'Alice') RETURN _node.ID, score
---- error
Catalog exception: QUERY_FTS_INDEX function does not exist.
Catalog exception: function QUERY_FTS_INDEX does not exist.
-STATEMENT load extension "${KUZU_ROOT_DIRECTORY}/extension/fts/build/libfts.kuzu_extension"
---- ok
-STATEMENT CALL QUERY_FTS_INDEX('doc', 'docIdx', 'Alice') RETURN _node.ID, score
Expand Down
5 changes: 1 addition & 4 deletions src/binder/bind/bind_standalone_call_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#include "binder/bound_standalone_call_function.h"
#include "catalog/catalog.h"
#include "common/exception/binder.h"
#include "function/built_in_function_utils.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/standalone_call_function.h"

Expand All @@ -17,9 +16,7 @@ std::unique_ptr<BoundStatement> Binder::bindStandaloneCallFunction(
auto& funcExpr =
callStatement.getFunctionExpression()->constCast<parser::ParsedFunctionExpression>();
auto funcName = funcExpr.getFunctionName();
auto catalogSet = clientContext->getCatalog()->getFunctions(clientContext->getTransaction());
auto entry = function::BuiltInFunctionsUtils::getFunctionCatalogEntry(
clientContext->getTransaction(), funcName, catalogSet);
auto entry = clientContext->getCatalog()->getFunctionEntry(clientContext->getTransaction(), funcName);
if (entry->getType() != catalog::CatalogEntryType::STANDALONE_TABLE_FUNCTION_ENTRY) {
throw common::BinderException(
"Only standalone table functions can be called without return statement.");
Expand Down
6 changes: 2 additions & 4 deletions src/binder/bind/bind_table_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ static void validateParameterType(const expression_vector& positionalParams) {

BoundTableFunction Binder::bindTableFunc(std::string tableFuncName,
const parser::ParsedExpression& expr, expression_vector& columns) {
auto functions = clientContext->getCatalog()->getFunctions(clientContext->getTransaction());
auto entry = BuiltInFunctionsUtils::getFunctionCatalogEntry(clientContext->getTransaction(),
tableFuncName, functions);
auto entry = clientContext->getCatalog()->getFunctionEntry(clientContext->getTransaction(), tableFuncName);
expression_vector positionalParams;
std::vector<LogicalType> positionalParamTypes;
optional_params_t optionalParams;
Expand All @@ -38,7 +36,7 @@ BoundTableFunction Binder::bindTableFunc(std::string tableFuncName,
optionalParams.emplace(childExpr.getAlias(), literalExpr->getValue());
}
}
auto func = BuiltInFunctionsUtils::matchFunction(tableFuncName, positionalParamTypes, entry);
auto func = BuiltInFunctionsUtils::matchFunction(tableFuncName, positionalParamTypes, entry->ptrCast<catalog::FunctionCatalogEntry>());
validateParameterType(positionalParams);
auto tableFunc = func->constPtrCast<TableFunction>();
for (auto i = 0u; i < positionalParams.size(); ++i) {
Expand Down
5 changes: 2 additions & 3 deletions src/binder/bind/copy/bind_copy_to.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ std::unique_ptr<BoundStatement> Binder::bindCopyToClause(const Statement& statem
auto parsedQuery = copyToStatement.getStatement()->constPtrCast<RegularQuery>();
auto query = bindQuery(*parsedQuery);
auto columns = query->getStatementResult()->getColumns();
auto functions = clientContext->getCatalog()->getFunctions(clientContext->getTransaction());
auto fileTypeStr = fileTypeInfo.fileTypeStr;
auto name = common::stringFormat("COPY_{}", fileTypeStr);
auto exportFunc = function::BuiltInFunctionsUtils::matchFunction(
clientContext->getTransaction(), name, functions)
auto entry = clientContext->getCatalog()->getFunctionEntry(clientContext->getTransaction(), name);
auto exportFunc = function::BuiltInFunctionsUtils::matchFunction(name, entry->ptrCast<catalog::FunctionCatalogEntry>())
->constPtrCast<function::ExportFunction>();
for (auto& column : columns) {
auto columnName = column->hasAlias() ? column->getAlias() : column->toString();
Expand Down
6 changes: 2 additions & 4 deletions src/binder/bind/read/bind_in_query_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
auto functionName = functionExpr->getFunctionName();
std::unique_ptr<BoundReadingClause> boundReadingClause;
expression_vector columns;
auto catalogSet = clientContext->getCatalog()->getFunctions(clientContext->getTransaction());
auto entry = BuiltInFunctionsUtils::getFunctionCatalogEntry(clientContext->getTransaction(),
functionName, catalogSet);
auto entry = clientContext->getCatalog()->getFunctionEntry(clientContext->getTransaction(), functionName);
switch (entry->getType()) {
case CatalogEntryType::TABLE_FUNCTION_ENTRY: {
auto boundTableFunction = bindTableFunc(functionName, *functionExpr, columns);
Expand All @@ -51,7 +49,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
literalExpr->getValue());
}
}
auto func = BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes, entry);
auto func = BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes, entry->ptrCast<FunctionCatalogEntry>());
auto gdsFunc = func->constPtrCast<GDSFunction>()->copy();
auto input = GDSBindInput();
input.params = children;
Expand Down
6 changes: 2 additions & 4 deletions src/binder/bind_expression/bind_comparison_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
ExpressionType expressionType, const expression_vector& children) {
auto catalog = context->getCatalog();
auto transaction = context->getTransaction();
auto functions = catalog->getFunctions(transaction);
auto functionName = ExpressionTypeUtil::toString(expressionType);
LogicalType combinedType(LogicalTypeID::ANY);
if (!ExpressionUtil::tryCombineDataType(children, combinedType)) {
Expand All @@ -42,9 +41,8 @@ std::shared_ptr<Expression> ExpressionBinder::bindComparisonExpression(
for (auto i = 0u; i < children.size(); i++) {
childrenTypes.push_back(combinedType.copy());
}
auto function =
BuiltInFunctionsUtils::matchFunction(transaction, functionName, childrenTypes, functions)
->ptrCast<ScalarFunction>();
auto entry = catalog->getFunctionEntry(transaction, functionName);
auto function = BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes, entry->ptrCast<catalog::FunctionCatalogEntry>())->ptrCast<ScalarFunction>();
expression_vector childrenAfterCast;
for (auto i = 0u; i < children.size(); ++i) {
if (children[i]->dataType != combinedType) {
Expand Down
19 changes: 9 additions & 10 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ std::shared_ptr<Expression> ExpressionBinder::bindScalarFunctionExpression(
auto catalog = context->getCatalog();
auto transaction = context->getTransaction();
auto childrenTypes = getTypes(children);
auto functions = catalog->getFunctions(transaction);

auto entry = catalog->getFunctionEntry(transaction, functionName);

auto function =
BuiltInFunctionsUtils::matchFunction(transaction, functionName, childrenTypes, functions)
->ptrCast<ScalarFunction>()
->copy();
BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes, entry->ptrCast<FunctionCatalogEntry>())->ptrCast<ScalarFunction>()->copy();
if (children.size() == 2 && children[1]->expressionType == ExpressionType::LAMBDA) {
if (!function->isListLambda) {
throw BinderException(stringFormat("{} does not support lambda input.", functionName));
Expand Down Expand Up @@ -133,9 +133,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindRewriteFunctionExpression(
children.push_back(bindExpression(*expr.getChild(i)));
}
auto childrenTypes = getTypes(children);
auto functions = context->getCatalog()->getFunctions(context->getTransaction());
auto match = BuiltInFunctionsUtils::matchFunction(context->getTransaction(),
funcExpr.getNormalizedFunctionName(), childrenTypes, functions);
auto functionName = funcExpr.getNormalizedFunctionName();
auto entry = context->getCatalog()->getFunctionEntry(context->getTransaction(), functionName);
auto match = BuiltInFunctionsUtils::matchFunction(functionName, childrenTypes, entry->ptrCast<FunctionCatalogEntry>());
auto function = match->constPtrCast<RewriteFunction>();
KU_ASSERT(function->rewriteFunc != nullptr);
return function->rewriteFunc(children, this);
Expand All @@ -150,10 +150,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindAggregateFunctionExpression(
childrenTypes.push_back(child->dataType.copy());
children.push_back(std::move(child));
}
auto functions = context->getCatalog()->getFunctions(context->getTransaction());
auto entry = context->getCatalog()->getFunctionEntry(context->getTransaction(), functionName);
auto function = BuiltInFunctionsUtils::matchAggregateFunction(functionName, childrenTypes,
isDistinct, functions)
->copy();
isDistinct, entry->ptrCast<FunctionCatalogEntry>())->copy();
if (function.paramRewriteFunc) {
function.paramRewriteFunc(children);
}
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind_expression/bind_subquery_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ std::shared_ptr<Expression> ExpressionBinder::bindSubqueryExpression(
std::move(boundGraphPattern.queryGraphCollection), uniqueName, std::move(rawName));
boundSubqueryExpr->setWhereExpression(boundGraphPattern.where);
// Bind projection
auto functions = context->getCatalog()->getFunctions(context->getTransaction());
auto entry = context->getCatalog()->getFunctionEntry(context->getTransaction(), CountStarFunction::name);
auto function = BuiltInFunctionsUtils::matchAggregateFunction(CountStarFunction::name,
std::vector<LogicalType>{}, false, functions);
std::vector<LogicalType>{}, false, entry->ptrCast<catalog::FunctionCatalogEntry>());
auto bindData = std::make_unique<FunctionBindData>(LogicalType(function->returnTypeID));
auto countStarExpr =
std::make_shared<AggregateFunctionExpression>(function->copy(), std::move(bindData),
Expand Down
29 changes: 16 additions & 13 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

using namespace kuzu::catalog;
using namespace kuzu::common;
using namespace kuzu::function;
using namespace kuzu::parser;
using namespace kuzu::processor;

Expand Down Expand Up @@ -228,31 +229,33 @@ void Binder::restoreScope(BinderScope prevScope) {
scope = std::move(prevScope);
}

function::TableFunction Binder::getScanFunction(FileTypeInfo typeInfo,
TableFunction Binder::getScanFunction(FileTypeInfo typeInfo,
const FileScanInfo& fileScanInfo) {
function::Function* func = nullptr;
Function* func = nullptr;
std::vector<LogicalType> inputTypes;
inputTypes.push_back(LogicalType::STRING());
auto functions = clientContext->getCatalog()->getFunctions(clientContext->getTransaction());
auto catalog = clientContext->getCatalog();
auto transaction = clientContext->getTransaction();
switch (typeInfo.fileType) {
case FileType::PARQUET: {
func = function::BuiltInFunctionsUtils::matchFunction(clientContext->getTransaction(),
ParquetScanFunction::name, inputTypes, functions);
auto entry = catalog->getFunctionEntry(transaction, ParquetScanFunction::name);
func = BuiltInFunctionsUtils::matchFunction(ParquetScanFunction::name, inputTypes, entry->ptrCast<FunctionCatalogEntry>());
} break;
case FileType::NPY: {
func = function::BuiltInFunctionsUtils::matchFunction(clientContext->getTransaction(),
NpyScanFunction::name, inputTypes, functions);
auto entry = catalog->getFunctionEntry(transaction, NpyScanFunction::name);
func = BuiltInFunctionsUtils::matchFunction(NpyScanFunction::name, inputTypes, entry->ptrCast<FunctionCatalogEntry>());
} break;
case FileType::CSV: {
auto csvConfig = CSVReaderConfig::construct(fileScanInfo.options);
func = function::BuiltInFunctionsUtils::matchFunction(clientContext->getTransaction(),
csvConfig.parallel ? ParallelCSVScan::name : SerialCSVScan::name, inputTypes,
functions);
auto name = csvConfig.parallel ? ParallelCSVScan::name : SerialCSVScan::name;
auto entry = catalog->getFunctionEntry(transaction, name);
func = BuiltInFunctionsUtils::matchFunction(name, inputTypes, entry->ptrCast<FunctionCatalogEntry>());
} break;
case FileType::UNKNOWN: {
try {
func = function::BuiltInFunctionsUtils::matchFunction(clientContext->getTransaction(),
common::stringFormat("{}_SCAN", typeInfo.fileTypeStr), inputTypes, functions);
auto name = common::stringFormat("{}_SCAN", typeInfo.fileTypeStr);
auto entry = catalog->getFunctionEntry(transaction, name);
func = BuiltInFunctionsUtils::matchFunction(name, inputTypes, entry->ptrCast<FunctionCatalogEntry>());
} catch (...) {
if (typeInfo.fileTypeStr == "") {
throw common::BinderException{
Expand All @@ -266,7 +269,7 @@ function::TableFunction Binder::getScanFunction(FileTypeInfo typeInfo,
default:
KU_UNREACHABLE;
}
return *func->ptrCast<function::TableFunction>();
return *func->ptrCast<TableFunction>();
}

} // namespace binder
Expand Down
37 changes: 20 additions & 17 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "storage/storage_utils.h"
#include "storage/storage_version_info.h"
#include "transaction/transaction.h"
#include "function/function_collection.h"

using namespace kuzu::binder;
using namespace kuzu::common;
Expand Down Expand Up @@ -348,6 +349,10 @@ void Catalog::dropIndex(Transaction* transaction, table_id_t tableID,
indexes->dropEntry(transaction, uniqueName, entry->getOID());
}

// Returns whether the function catalog set has an entry registered under `name`.
// Forwards `transaction` and `name` unchanged to functions->containsEntry, so callers
// can test for a function without reaching into the underlying catalog set themselves.
bool Catalog::containsFunction(const Transaction* transaction, const std::string& name) {
return functions->containsEntry(transaction, name);
}

void Catalog::addFunction(Transaction* transaction, CatalogEntryType entryType, std::string name,
function::function_set functionSet) {
if (functions->containsEntry(transaction, name)) {
Expand All @@ -358,29 +363,19 @@ void Catalog::addFunction(Transaction* transaction, CatalogEntryType entryType,
}

void Catalog::dropFunction(Transaction* transaction, const std::string& name) {
const auto entry = functions->getEntry(transaction, name);
if (entry == nullptr) {
if (!containsFunction(transaction, name)) {
throw CatalogException{stringFormat("function {} doesn't exist.", name)};
}
auto entry = getFunctionEntry(transaction, name);
functions->dropEntry(transaction, name, entry->getOID());
}

void Catalog::addBuiltInFunction(CatalogEntryType entryType, std::string name,
function::function_set functionSet) {
addFunction(&DUMMY_TRANSACTION, entryType, std::move(name), std::move(functionSet));
}

CatalogSet* Catalog::getFunctions(Transaction*) const {
return functions.get();
}

CatalogEntry* Catalog::getFunctionEntry(const Transaction* transaction,
const std::string& name) const {
const auto catalogSet = functions.get();
if (!catalogSet->containsEntry(transaction, name)) {
if (!functions->containsEntry(transaction, name)) {
throw CatalogException(stringFormat("function {} does not exist.", name));
}
return catalogSet->getEntry(transaction, name);
return functions->getEntry(transaction, name);
}

std::vector<FunctionCatalogEntry*> Catalog::getFunctionEntries(
Expand All @@ -403,11 +398,12 @@ function::ScalarMacroFunction* Catalog::getScalarMacroFunction(const Transaction
.getMacroFunction();
}

// Wraps `macro` in a ScalarMacroCatalogEntry named `name` and adds it to the function catalog set.
void Catalog::addScalarMacroFunction(Transaction* transaction, std::string name,
std::unique_ptr<function::ScalarMacroFunction> macro) {
auto scalarMacroCatalogEntry =
auto entry =
std::make_unique<ScalarMacroCatalogEntry>(std::move(name), std::move(macro));
functions->createEntry(transaction, std::move(scalarMacroCatalogEntry));
functions->createEntry(transaction, std::move(entry));
}

std::vector<std::string> Catalog::getMacroNames(const Transaction* transaction) const {
Expand Down Expand Up @@ -494,7 +490,14 @@ void Catalog::readFromFile(const std::string& directory, VirtualFileSystem* fs,
}

void Catalog::registerBuiltInFunctions() {
function::BuiltInFunctionsUtils::createFunctions(&DUMMY_TRANSACTION, functions.get());
auto functionCollection = function::FunctionCollection::getFunctions();
for (auto i = 0u; functionCollection[i].name != nullptr; ++i) {
auto& f = functionCollection[i];
auto functionSet = f.getFunctionSetFunc();
functions->createEntry(&DUMMY_TRANSACTION,
std::make_unique<FunctionCatalogEntry>(f.catalogEntryType, f.name,
std::move(functionSet)));
}
}

std::unique_ptr<CatalogEntry> Catalog::createNodeTableEntry(Transaction*,
Expand Down
6 changes: 2 additions & 4 deletions src/extension/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,7 @@ void ExtensionUtils::registerTableFunction(main::Database& database,
function::function_set functionSet;
functionSet.push_back(std::move(function));
auto catalog = database.getCatalog();
if (catalog->getFunctions(&transaction::DUMMY_TRANSACTION)
->containsEntry(&transaction::DUMMY_TRANSACTION, name)) {
if (catalog->containsFunction(&transaction::DUMMY_TRANSACTION, name)) {
return;
}
catalog->addFunction(&transaction::DUMMY_TRANSACTION,
Expand Down Expand Up @@ -153,8 +152,7 @@ std::string ExtensionUtils::getLocalPathForSharedLib(main::ClientContext* contex
void ExtensionUtils::registerFunctionSet(main::Database& database, std::string name,
function::function_set functionSet, catalog::CatalogEntryType functionType) {
auto catalog = database.getCatalog();
if (catalog->getFunctions(&transaction::DUMMY_TRANSACTION)
->containsEntry(&transaction::DUMMY_TRANSACTION, name)) {
if (catalog->containsFunction(&transaction::DUMMY_TRANSACTION, name)) {
return;
}
catalog->addFunction(&transaction::DUMMY_TRANSACTION, functionType, std::move(name),
Expand Down
Loading

0 comments on commit 631d7be

Please sign in to comment.