Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clear catalog interface #4726

Merged
merged 3 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion extension/fts/src/catalog/fts_index_catalog_entry.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "catalog/fts_index_catalog_entry.h"

#include "catalog/catalog.h"
#include "catalog/catalog_entry/table_catalog_entry.h"
#include "common/serializer/buffered_reader.h"
#include "common/serializer/buffered_serializer.h"
#include "main/client_context.h"
Expand Down Expand Up @@ -36,7 +37,9 @@ std::string FTSIndexAuxInfo::toCypher(const catalog::IndexCatalogEntry& indexEnt
main::ClientContext* context) {
std::string cypher;
auto catalog = context->getCatalog();
auto tableName = catalog->getTableName(context->getTransaction(), indexEntry.getTableID());
auto tableName =
catalog->getTableCatalogEntry(context->getTransaction(), indexEntry.getTableID())
->getName();
std::string propertyStr;
for (auto i = 0u; i < properties.size(); i++) {
propertyStr +=
Expand Down
4 changes: 1 addition & 3 deletions extension/fts/src/fts_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ namespace kuzu {
namespace fts_extension {

static void initFTSEntries(const transaction::Transaction* transaction, catalog::Catalog& catalog) {
auto indexEntries = catalog.getIndexes()->getEntries(transaction);
for (auto& [_, entry] : indexEntries) {
auto indexEntry = entry->ptrCast<catalog::IndexCatalogEntry>();
for (auto& indexEntry : catalog.getIndexEntries(transaction)) {
if (indexEntry->getIndexType() == FTSIndexCatalogEntry::TYPE_NAME) {
indexEntry->setAuxInfo(FTSIndexAuxInfo::deserialize(indexEntry->getAuxBufferReader()));
}
Expand Down
5 changes: 4 additions & 1 deletion extension/fts/src/function/query_fts_bind_data.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "function/query_fts_bind_data.h"

#include "binder/expression/expression_util.h"
#include "catalog/catalog_entry/table_catalog_entry.h"
#include "catalog/fts_index_catalog_entry.h"
#include "common/string_utils.h"
#include "function/fts_utils.h"
Expand Down Expand Up @@ -57,7 +58,9 @@ struct StopWordsChecker {
StopWordsChecker::StopWordsChecker(main::ClientContext& context)
: termsVector{LogicalType::STRING(), context.getMemoryManager()}, tx{context.getTransaction()} {
termsVector.state = common::DataChunkState::getSingleValueDataChunkState();
auto tableID = context.getCatalog()->getTableID(tx, FTSUtils::getStopWordsTableName());
auto tableID = context.getCatalog()
->getTableCatalogEntry(tx, FTSUtils::getStopWordsTableName())
->getTableID();
stopWordsTable = context.getStorageManager()->getTable(tableID)->ptrCast<storage::NodeTable>();
}

Expand Down
5 changes: 3 additions & 2 deletions src/binder/bind/bind_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "catalog/catalog_entry/rel_table_catalog_entry.h"
#include "catalog/catalog_entry/sequence_catalog_entry.h"
#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "common/string_format.h"
Expand Down Expand Up @@ -84,7 +85,7 @@ std::vector<PropertyDefinition> Binder::bindPropertyDefinitions(
if (type.getLogicalTypeID() == LogicalTypeID::SERIAL) {
validateSerialNoDefault(*boundExpr);
expr = ParsedExpressionUtils::getSerialDefaultExpr(
Catalog::genSerialName(tableName, parsedDefinition.getName()));
SequenceCatalogEntry::genSerialName(tableName, parsedDefinition.getName()));
}
auto columnDefinition = ColumnDefinition(parsedDefinition.getName(), std::move(type));
definitions.emplace_back(std::move(columnDefinition), std::move(expr));
Expand Down Expand Up @@ -496,7 +497,7 @@ std::unique_ptr<BoundStatement> Binder::bindAddProperty(const Statement& stateme
if (dataType.getLogicalTypeID() == LogicalTypeID::SERIAL) {
validateSerialNoDefault(*boundDefault);
defaultValue = ParsedExpressionUtils::getSerialDefaultExpr(
Catalog::genSerialName(tableName, propertyName));
SequenceCatalogEntry::genSerialName(tableName, propertyName));
boundDefault = expressionBinder.implicitCastIfNecessary(
expressionBinder.bindExpression(*defaultValue), dataType);
}
Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind/bind_export_database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ static void bindExportRelTableDataQuery(const TableCatalogEntry& entry, std::str
auto relTableEntry = entry.constPtrCast<RelTableCatalogEntry>();
auto srcPrimaryKeyName = getPrimaryKeyName(relTableEntry->getSrcTableID(), catalog, tx);
auto dstPrimaryKeyName = getPrimaryKeyName(relTableEntry->getDstTableID(), catalog, tx);
auto srcName = catalog.getTableName(tx, relTableEntry->getSrcTableID());
auto dstName = catalog.getTableName(tx, relTableEntry->getDstTableID());
auto srcName = catalog.getTableCatalogEntry(tx, relTableEntry->getSrcTableID())->getName();
auto dstName = catalog.getTableCatalogEntry(tx, relTableEntry->getDstTableID())->getName();
auto relName = relTableEntry->getName();
exportQuery = stringFormat("match (a:{})-[r:{}]->(b:{}) return a.{},b.{},r.*;", srcName,
relName, dstName, srcPrimaryKeyName, dstPrimaryKeyName);
Expand Down
3 changes: 1 addition & 2 deletions src/binder/bind/bind_file_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ static TableFunction getObjectScanFunc(const std::string& dbName, const std::str
throw BinderException{stringFormat("No database named {} has been attached.", dbName)};
}
auto attachedCatalog = attachedDB->getCatalog();
auto tableID = attachedCatalog->getTableID(clientContext->getTransaction(), tableName);
auto entry = attachedCatalog->getTableCatalogEntry(clientContext->getTransaction(), tableID);
auto entry = attachedCatalog->getTableCatalogEntry(clientContext->getTransaction(), tableName);
return entry->ptrCast<TableCatalogEntry>()->getScanFunction();
}

Expand Down
3 changes: 1 addition & 2 deletions src/binder/bind/copy/bind_copy_from.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& stat
validateTableExist(tableName);
// Bind to table schema.
auto catalog = clientContext->getCatalog();
auto tableID = catalog->getTableID(clientContext->getTransaction(), tableName);
auto tableEntry = catalog->getTableCatalogEntry(clientContext->getTransaction(), tableID);
auto tableEntry = catalog->getTableCatalogEntry(clientContext->getTransaction(), tableName);
switch (tableEntry->getTableType()) {
case TableType::REL_GROUP: {
throw BinderException(stringFormat("Cannot copy into {} table with type {}.", tableName,
Expand Down
36 changes: 16 additions & 20 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,21 @@ std::shared_ptr<Expression> ExpressionBinder::bindEndNodeExpression(const Expres
return rel.getDstNode();
}

static std::vector<std::unique_ptr<Value>> populateLabelValues(std::vector<table_id_t> tableIDs,
const catalog::Catalog& catalog, transaction::Transaction* tx) {
auto tableIDsSet = std::unordered_set<table_id_t>(tableIDs.begin(), tableIDs.end());
table_id_t maxTableID = *std::max_element(tableIDsSet.begin(), tableIDsSet.end());
static std::vector<std::unique_ptr<Value>> populateLabelValues(
std::vector<TableCatalogEntry*> entries) {
std::unordered_map<table_id_t, std::string> map;
common::table_id_t maxTableID = 0;
for (auto& entry : entries) {
map.insert({entry->getTableID(), entry->getName()});
if (entry->getTableID() > maxTableID) {
maxTableID = entry->getTableID();
}
}
std::vector<std::unique_ptr<Value>> labels;
labels.resize(maxTableID + 1);
for (auto i = 0u; i < labels.size(); ++i) {
if (tableIDsSet.contains(i)) {
labels[i] = std::make_unique<Value>(LogicalType::STRING(), catalog.getTableName(tx, i));
if (map.contains(i)) {
labels[i] = std::make_unique<Value>(LogicalType::STRING(), map.at(i));
} else {
// TODO(Xiyang/Guodong): change to null literal once we support null in LIST type.
labels[i] = std::make_unique<Value>(LogicalType::STRING(), std::string(""));
Expand All @@ -257,7 +263,6 @@ static std::vector<std::unique_ptr<Value>> populateLabelValues(std::vector<table
}

std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression& expression) {
auto catalog = context->getCatalog();
auto listType = LogicalType::LIST(LogicalType::STRING());
expression_vector children;
switch (expression.getDataType().getLogicalTypeID()) {
Expand All @@ -267,16 +272,11 @@ std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression
return createLiteralExpression("");
}
if (!node.isMultiLabeled()) {
auto labelName = catalog->getTableName(context->getTransaction(),
node.getSingleEntry()->getTableID());
auto labelName = node.getSingleEntry()->getName();
return createLiteralExpression(Value(LogicalType::STRING(), labelName));
}
// Internal tables should be invisible to the label function.
auto nodeTableIDs =
catalog->getNodeTableIDs(context->getTransaction(), false /* useInternalTable */);
children.push_back(node.getInternalID());
auto labelsValue = Value(std::move(listType),
populateLabelValues(std::move(nodeTableIDs), *catalog, context->getTransaction()));
auto labelsValue = Value(std::move(listType), populateLabelValues(node.getEntries()));
children.push_back(createLiteralExpression(labelsValue));
} break;
case LogicalTypeID::REL: {
Expand All @@ -285,15 +285,11 @@ std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression
return createLiteralExpression("");
}
if (!rel.isMultiLabeled()) {
auto labelName = catalog->getTableName(context->getTransaction(),
rel.getSingleEntry()->getTableID());
auto labelName = rel.getSingleEntry()->getName();
return createLiteralExpression(Value(LogicalType::STRING(), labelName));
}
auto relTableIDs =
catalog->getRelTableIDs(context->getTransaction(), false /* useInternalTable */);
children.push_back(rel.getInternalIDProperty());
auto labelsValue = Value(std::move(listType),
populateLabelValues(std::move(relTableIDs), *catalog, context->getTransaction()));
auto labelsValue = Value(std::move(listType), populateLabelValues(rel.getEntries()));
children.push_back(createLiteralExpression(labelsValue));
} break;
default:
Expand Down
110 changes: 23 additions & 87 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,39 +73,6 @@ bool Catalog::containsTable(const Transaction* transaction, const std::string& t
return contains;
}

table_id_t Catalog::getTableID(const Transaction* transaction, const std::string& tableName,
bool useInternal) const {
return getTableCatalogEntry(transaction, tableName, useInternal)->getTableID();
}

std::vector<table_id_t> Catalog::getNodeTableIDs(const Transaction* transaction,
bool useInternal) const {
std::vector<table_id_t> tableIDs;
tables->iterateEntriesOfType(transaction, CatalogEntryType::NODE_TABLE_ENTRY,
[&](const CatalogEntry* entry) { tableIDs.push_back(entry->getOID()); });
if (useInternal) {
internalTables->iterateEntriesOfType(transaction, CatalogEntryType::NODE_TABLE_ENTRY,
[&](const CatalogEntry* entry) { tableIDs.push_back(entry->getOID()); });
}
return tableIDs;
}

std::vector<table_id_t> Catalog::getRelTableIDs(const Transaction* transaction,
bool useInternal) const {
std::vector<table_id_t> tableIDs;
tables->iterateEntriesOfType(transaction, CatalogEntryType::REL_TABLE_ENTRY,
[&](const CatalogEntry* entry) { tableIDs.push_back(entry->getOID()); });
if (useInternal) {
internalTables->iterateEntriesOfType(transaction, CatalogEntryType::REL_TABLE_ENTRY,
[&](const CatalogEntry* entry) { tableIDs.push_back(entry->getOID()); });
}
return tableIDs;
}

std::string Catalog::getTableName(const Transaction* transaction, table_id_t tableID) const {
return getTableCatalogEntry(transaction, tableID)->getName();
}

TableCatalogEntry* Catalog::getTableCatalogEntry(const Transaction* transaction,
table_id_t tableID) const {
auto result = tables->getEntryOfOID(transaction, tableID);
Expand Down Expand Up @@ -163,15 +130,6 @@ std::vector<TableCatalogEntry*> Catalog::getTableEntries(const Transaction* tran
return result;
}

std::vector<TableCatalogEntry*> Catalog::getTableEntries(const Transaction* transaction,
const table_id_vector_t& tableIDs) const {
std::vector<TableCatalogEntry*> result;
for (const auto tableID : tableIDs) {
result.push_back(getTableCatalogEntry(transaction, tableID));
}
return result;
}

bool Catalog::tableInRelGroup(Transaction* transaction, table_id_t tableID) const {
for (const auto& entry : getRelTableGroupEntries(transaction)) {
if (entry->isParent(tableID)) {
Expand All @@ -181,29 +139,7 @@ bool Catalog::tableInRelGroup(Transaction* transaction, table_id_t tableID) cons
return false;
}

table_id_set_t Catalog::getFwdRelTableIDs(Transaction* transaction, table_id_t nodeTableID) const {
KU_ASSERT(getTableCatalogEntry(transaction, nodeTableID)->getTableType() == TableType::NODE);
table_id_set_t result;
for (const auto& relEntry : getRelTableEntries(transaction)) {
if (relEntry->getSrcTableID() == nodeTableID) {
result.insert(relEntry->getTableID());
}
}
return result;
}

table_id_set_t Catalog::getBwdRelTableIDs(Transaction* transaction, table_id_t nodeTableID) const {
KU_ASSERT(getTableCatalogEntry(transaction, nodeTableID)->getTableType() == TableType::NODE);
table_id_set_t result;
for (const auto& relEntry : getRelTableEntries(transaction)) {
if (relEntry->getDstTableID() == nodeTableID) {
result.insert(relEntry->getTableID());
}
}
return result;
}

table_id_t Catalog::createTableSchema(Transaction* transaction, const BoundCreateTableInfo& info) {
table_id_t Catalog::createTableEntry(Transaction* transaction, const BoundCreateTableInfo& info) {
std::unique_ptr<CatalogEntry> entry;
switch (info.type) {
case TableType::NODE: {
Expand All @@ -221,7 +157,8 @@ table_id_t Catalog::createTableSchema(Transaction* transaction, const BoundCreat
const auto tableEntry = entry->constPtrCast<TableCatalogEntry>();
for (auto& definition : tableEntry->getProperties()) {
if (definition.getType().getLogicalTypeID() == LogicalTypeID::SERIAL) {
const auto seqName = genSerialName(tableEntry->getName(), definition.getName());
const auto seqName =
SequenceCatalogEntry::genSerialName(tableEntry->getName(), definition.getName());
auto seqInfo =
BoundCreateSequenceInfo(seqName, 0, 1, 0, std::numeric_limits<int64_t>::max(),
false, ConflictAction::ON_CONFLICT_THROW, info.isInternal);
Expand All @@ -237,7 +174,7 @@ table_id_t Catalog::createTableSchema(Transaction* transaction, const BoundCreat
}

void Catalog::dropTableEntry(Transaction* transaction, const std::string& name) {
const auto tableID = getTableID(transaction, name);
auto tableID = getTableCatalogEntry(transaction, name)->getTableID();
dropAllIndexes(transaction, tableID);
dropTableEntry(transaction, tableID);
}
Expand All @@ -257,7 +194,8 @@ void Catalog::dropTableEntry(Transaction* transaction, table_id_t tableID) {
}
for (auto& definition : tableEntry->getProperties()) {
if (definition.getType().getLogicalTypeID() == LogicalTypeID::SERIAL) {
auto seqName = genSerialName(tableEntry->getName(), definition.getName());
auto seqName =
SequenceCatalogEntry::genSerialName(tableEntry->getName(), definition.getName());
dropSequence(transaction, seqName);
}
}
Expand Down Expand Up @@ -287,12 +225,11 @@ void Catalog::alterTableEntry(Transaction* transaction, const BoundAlterInfo& in
tables->alterEntry(transaction, info);
}

bool Catalog::containsSequence(const Transaction* transaction,
const std::string& sequenceName) const {
return sequences->containsEntry(transaction, sequenceName);
bool Catalog::containsSequence(const Transaction* transaction, const std::string& name) const {
return sequences->containsEntry(transaction, name);
}

sequence_id_t Catalog::getSequenceID(const Transaction* transaction,
SequenceCatalogEntry* Catalog::getSequenceEntry(const Transaction* transaction,
const std::string& sequenceName, bool useInternalSeq) const {
CatalogEntry* entry = nullptr;
if (!sequences->containsEntry(transaction, sequenceName) && useInternalSeq) {
Expand All @@ -301,10 +238,10 @@ sequence_id_t Catalog::getSequenceID(const Transaction* transaction,
entry = sequences->getEntry(transaction, sequenceName);
}
KU_ASSERT(entry);
return entry->getOID();
return entry->ptrCast<SequenceCatalogEntry>();
}

SequenceCatalogEntry* Catalog::getSequenceCatalogEntry(const Transaction* transaction,
SequenceCatalogEntry* Catalog::getSequenceEntry(const Transaction* transaction,
sequence_id_t sequenceID) const {
auto entry = internalSequences->getEntryOfOID(transaction, sequenceID);
if (entry == nullptr) {
Expand Down Expand Up @@ -335,22 +272,18 @@ sequence_id_t Catalog::createSequence(Transaction* transaction,
}

void Catalog::dropSequence(Transaction* transaction, const std::string& name) {
const auto sequenceID = getSequenceID(transaction, name);
dropSequence(transaction, sequenceID);
const auto entry = getSequenceEntry(transaction, name);
dropSequence(transaction, entry->getOID());
}

void Catalog::dropSequence(Transaction* transaction, sequence_id_t sequenceID) {
const auto sequenceEntry = getSequenceCatalogEntry(transaction, sequenceID);
const auto sequenceEntry = getSequenceEntry(transaction, sequenceID);
CatalogSet* set = nullptr;
set = sequences->containsEntry(transaction, sequenceEntry->getName()) ? sequences.get() :
internalSequences.get();
set->dropEntry(transaction, sequenceEntry->getName(), sequenceEntry->getOID());
}

std::string Catalog::genSerialName(const std::string& tableName, const std::string& propertyName) {
return std::string(tableName).append("_").append(propertyName).append("_").append("serial");
}

void Catalog::createType(Transaction* transaction, std::string name, LogicalType type) {
if (types->containsEntry(transaction, name)) {
return;
Expand Down Expand Up @@ -381,13 +314,16 @@ void Catalog::createIndex(Transaction* transaction,

IndexCatalogEntry* Catalog::getIndex(const Transaction* transaction, table_id_t tableID,
const std::string& indexName) const {
return indexes
->getEntry(transaction, IndexCatalogEntry::getInternalIndexName(tableID, indexName))
->ptrCast<IndexCatalogEntry>();
auto internalName = IndexCatalogEntry::getInternalIndexName(tableID, indexName);
return indexes->getEntry(transaction, internalName)->ptrCast<IndexCatalogEntry>();
}

CatalogSet* Catalog::getIndexes() const {
return indexes.get();
std::vector<IndexCatalogEntry*> Catalog::getIndexEntries(const Transaction* transaction) const {
std::vector<IndexCatalogEntry*> result;
for (auto& [_, entry] : indexes->getEntries(transaction)) {
result.push_back(entry->ptrCast<IndexCatalogEntry>());
}
return result;
}

bool Catalog::containsIndex(const Transaction* transaction, table_id_t tableID,
Expand Down Expand Up @@ -593,7 +529,7 @@ std::unique_ptr<CatalogEntry> Catalog::createRelTableGroupEntry(Transaction* tra
relTableIDs.reserve(extraInfo->infos.size());
for (auto& childInfo : extraInfo->infos) {
childInfo.hasParent = true;
relTableIDs.push_back(createTableSchema(transaction, childInfo));
relTableIDs.push_back(createTableEntry(transaction, childInfo));
}
return std::make_unique<RelGroupCatalogEntry>(tables.get(), info.tableName,
std::move(relTableIDs));
Expand Down
Loading