diff --git a/extension/fts/src/catalog/fts_index_catalog_entry.cpp b/extension/fts/src/catalog/fts_index_catalog_entry.cpp index f11046e7ab0..1c80713dc43 100644 --- a/extension/fts/src/catalog/fts_index_catalog_entry.cpp +++ b/extension/fts/src/catalog/fts_index_catalog_entry.cpp @@ -1,6 +1,7 @@ #include "catalog/fts_index_catalog_entry.h" #include "catalog/catalog.h" +#include "catalog/catalog_entry/table_catalog_entry.h" #include "common/serializer/buffered_reader.h" #include "common/serializer/buffered_serializer.h" #include "main/client_context.h" @@ -36,7 +37,9 @@ std::string FTSIndexAuxInfo::toCypher(const catalog::IndexCatalogEntry& indexEnt main::ClientContext* context) { std::string cypher; auto catalog = context->getCatalog(); - auto tableName = catalog->getTableName(context->getTransaction(), indexEntry.getTableID()); + auto tableName = + catalog->getTableCatalogEntry(context->getTransaction(), indexEntry.getTableID()) + ->getName(); std::string propertyStr; for (auto i = 0u; i < properties.size(); i++) { propertyStr += diff --git a/extension/fts/src/fts_extension.cpp b/extension/fts/src/fts_extension.cpp index 4dd6360937d..4abcdcb2b06 100644 --- a/extension/fts/src/fts_extension.cpp +++ b/extension/fts/src/fts_extension.cpp @@ -15,9 +15,7 @@ namespace kuzu { namespace fts_extension { static void initFTSEntries(const transaction::Transaction* transaction, catalog::Catalog& catalog) { - auto indexEntries = catalog.getIndexes()->getEntries(transaction); - for (auto& [_, entry] : indexEntries) { - auto indexEntry = entry->ptrCast(); + for (auto& indexEntry : catalog.getIndexEntries(transaction)) { if (indexEntry->getIndexType() == FTSIndexCatalogEntry::TYPE_NAME) { indexEntry->setAuxInfo(FTSIndexAuxInfo::deserialize(indexEntry->getAuxBufferReader())); } diff --git a/extension/fts/src/function/query_fts_bind_data.cpp b/extension/fts/src/function/query_fts_bind_data.cpp index d0d8a54170c..068dd2200a3 100644 --- a/extension/fts/src/function/query_fts_bind_data.cpp +++ b/extension/fts/src/function/query_fts_bind_data.cpp @@ -1,6 +1,7 @@ #include "function/query_fts_bind_data.h" #include "binder/expression/expression_util.h" +#include "catalog/catalog_entry/table_catalog_entry.h" #include "catalog/fts_index_catalog_entry.h" #include "common/string_utils.h" #include "function/fts_utils.h" @@ -57,7 +58,9 @@ struct StopWordsChecker { StopWordsChecker::StopWordsChecker(main::ClientContext& context) : termsVector{LogicalType::STRING(), context.getMemoryManager()}, tx{context.getTransaction()} { termsVector.state = common::DataChunkState::getSingleValueDataChunkState(); - auto tableID = context.getCatalog()->getTableID(tx, FTSUtils::getStopWordsTableName()); + auto tableID = context.getCatalog() + ->getTableCatalogEntry(tx, FTSUtils::getStopWordsTableName()) + ->getTableID(); stopWordsTable = context.getStorageManager()->getTable(tableID)->ptrCast(); } diff --git a/src/binder/bind/bind_ddl.cpp b/src/binder/bind/bind_ddl.cpp index 9cb4acc4f72..f81c7449120 100644 --- a/src/binder/bind/bind_ddl.cpp +++ b/src/binder/bind/bind_ddl.cpp @@ -9,6 +9,7 @@ #include "catalog/catalog_entry/node_table_catalog_entry.h" #include "catalog/catalog_entry/rel_group_catalog_entry.h" #include "catalog/catalog_entry/rel_table_catalog_entry.h" +#include "catalog/catalog_entry/sequence_catalog_entry.h" #include "common/exception/binder.h" #include "common/exception/message.h" #include "common/string_format.h" @@ -84,7 +85,7 @@ std::vector Binder::bindPropertyDefinitions( if (type.getLogicalTypeID() == LogicalTypeID::SERIAL) { validateSerialNoDefault(*boundExpr); expr = ParsedExpressionUtils::getSerialDefaultExpr( - Catalog::genSerialName(tableName, parsedDefinition.getName())); + SequenceCatalogEntry::genSerialName(tableName, parsedDefinition.getName())); } auto columnDefinition = ColumnDefinition(parsedDefinition.getName(), std::move(type)); definitions.emplace_back(std::move(columnDefinition), std::move(expr)); @@ -496,7 +497,7 @@ std::unique_ptr Binder::bindAddProperty(const Statement& stateme if (dataType.getLogicalTypeID() == LogicalTypeID::SERIAL) { validateSerialNoDefault(*boundDefault); defaultValue = ParsedExpressionUtils::getSerialDefaultExpr( - Catalog::genSerialName(tableName, propertyName)); + SequenceCatalogEntry::genSerialName(tableName, propertyName)); boundDefault = expressionBinder.implicitCastIfNecessary( expressionBinder.bindExpression(*defaultValue), dataType); } diff --git a/src/binder/bind/bind_export_database.cpp b/src/binder/bind/bind_export_database.cpp index f08b91c101a..93cae966761 100644 --- a/src/binder/bind/bind_export_database.cpp +++ b/src/binder/bind/bind_export_database.cpp @@ -61,8 +61,8 @@ static void bindExportRelTableDataQuery(const TableCatalogEntry& entry, std::str auto relTableEntry = entry.constPtrCast(); auto srcPrimaryKeyName = getPrimaryKeyName(relTableEntry->getSrcTableID(), catalog, tx); auto dstPrimaryKeyName = getPrimaryKeyName(relTableEntry->getDstTableID(), catalog, tx); - auto srcName = catalog.getTableName(tx, relTableEntry->getSrcTableID()); - auto dstName = catalog.getTableName(tx, relTableEntry->getDstTableID()); + auto srcName = catalog.getTableCatalogEntry(tx, relTableEntry->getSrcTableID())->getName(); + auto dstName = catalog.getTableCatalogEntry(tx, relTableEntry->getDstTableID())->getName(); auto relName = relTableEntry->getName(); exportQuery = stringFormat("match (a:{})-[r:{}]->(b:{}) return a.{},b.{},r.*;", srcName, relName, dstName, srcPrimaryKeyName, dstPrimaryKeyName); diff --git a/src/binder/bind/bind_file_scan.cpp b/src/binder/bind/bind_file_scan.cpp index ae898c34e17..32424b7f6a1 100644 --- a/src/binder/bind/bind_file_scan.cpp +++ b/src/binder/bind/bind_file_scan.cpp @@ -165,8 +165,7 @@ static TableFunction getObjectScanFunc(const std::string& dbName, const std::str throw BinderException{stringFormat("No database named {} has been attached.", dbName)}; } auto attachedCatalog = attachedDB->getCatalog(); - auto tableID = attachedCatalog->getTableID(clientContext->getTransaction(), tableName); - auto entry = attachedCatalog->getTableCatalogEntry(clientContext->getTransaction(), tableID); + auto entry = attachedCatalog->getTableCatalogEntry(clientContext->getTransaction(), tableName); return entry->ptrCast()->getScanFunction(); } diff --git a/src/binder/bind/copy/bind_copy_from.cpp b/src/binder/bind/copy/bind_copy_from.cpp index 3f71dddf57f..e9b542498fa 100644 --- a/src/binder/bind/copy/bind_copy_from.cpp +++ b/src/binder/bind/copy/bind_copy_from.cpp @@ -24,8 +24,7 @@ std::unique_ptr Binder::bindCopyFromClause(const Statement& stat validateTableExist(tableName); // Bind to table schema. auto catalog = clientContext->getCatalog(); - auto tableID = catalog->getTableID(clientContext->getTransaction(), tableName); - auto tableEntry = catalog->getTableCatalogEntry(clientContext->getTransaction(), tableID); + auto tableEntry = catalog->getTableCatalogEntry(clientContext->getTransaction(), tableName); switch (tableEntry->getTableType()) { case TableType::REL_GROUP: { throw BinderException(stringFormat("Cannot copy into {} table with type {}.", tableName, diff --git a/src/binder/bind_expression/bind_function_expression.cpp b/src/binder/bind_expression/bind_function_expression.cpp index 06b1b09df06..9f6099da37b 100644 --- a/src/binder/bind_expression/bind_function_expression.cpp +++ b/src/binder/bind_expression/bind_function_expression.cpp @@ -239,15 +239,21 @@ std::shared_ptr ExpressionBinder::bindEndNodeExpression(const Expres return rel.getDstNode(); } -static std::vector> populateLabelValues(std::vector tableIDs, - const catalog::Catalog& catalog, transaction::Transaction* tx) { - auto tableIDsSet = std::unordered_set(tableIDs.begin(), tableIDs.end()); - table_id_t maxTableID = *std::max_element(tableIDsSet.begin(), tableIDsSet.end()); +static std::vector> populateLabelValues( + std::vector entries) { + std::unordered_map map; + common::table_id_t maxTableID = 0; + for (auto& entry : entries) { + map.insert({entry->getTableID(), entry->getName()}); + if (entry->getTableID() > maxTableID) { + maxTableID = entry->getTableID(); + } + } std::vector> labels; labels.resize(maxTableID + 1); for (auto i = 0u; i < labels.size(); ++i) { - if (tableIDsSet.contains(i)) { - labels[i] = std::make_unique(LogicalType::STRING(), catalog.getTableName(tx, i)); + if (map.contains(i)) { + labels[i] = std::make_unique(LogicalType::STRING(), map.at(i)); } else { // TODO(Xiyang/Guodong): change to null literal once we support null in LIST type. labels[i] = std::make_unique(LogicalType::STRING(), std::string("")); @@ -257,7 +263,6 @@ static std::vector> populateLabelValues(std::vector ExpressionBinder::bindLabelFunction(const Expression& expression) { - auto catalog = context->getCatalog(); auto listType = LogicalType::LIST(LogicalType::STRING()); expression_vector children; switch (expression.getDataType().getLogicalTypeID()) { @@ -267,16 +272,11 @@ std::shared_ptr ExpressionBinder::bindLabelFunction(const Expression return createLiteralExpression(""); } if (!node.isMultiLabeled()) { - auto labelName = catalog->getTableName(context->getTransaction(), - node.getSingleEntry()->getTableID()); + auto labelName = node.getSingleEntry()->getName(); return createLiteralExpression(Value(LogicalType::STRING(), labelName)); } - // Internal tables should be invisible to the label function. - auto nodeTableIDs = - catalog->getNodeTableIDs(context->getTransaction(), false /* useInternalTable */); children.push_back(node.getInternalID()); - auto labelsValue = Value(std::move(listType), - populateLabelValues(std::move(nodeTableIDs), *catalog, context->getTransaction())); + auto labelsValue = Value(std::move(listType), populateLabelValues(node.getEntries())); children.push_back(createLiteralExpression(labelsValue)); } break; case LogicalTypeID::REL: { @@ -285,15 +285,11 @@ std::shared_ptr ExpressionBinder::bindLabelFunction(const Expression return createLiteralExpression(""); } if (!rel.isMultiLabeled()) { - auto labelName = catalog->getTableName(context->getTransaction(), - rel.getSingleEntry()->getTableID()); + auto labelName = rel.getSingleEntry()->getName(); return createLiteralExpression(Value(LogicalType::STRING(), labelName)); } - auto relTableIDs = - catalog->getRelTableIDs(context->getTransaction(), false /* useInternalTable */); children.push_back(rel.getInternalIDProperty()); - auto labelsValue = Value(std::move(listType), - populateLabelValues(std::move(relTableIDs), *catalog, context->getTransaction())); + auto labelsValue = Value(std::move(listType), populateLabelValues(rel.getEntries())); children.push_back(createLiteralExpression(labelsValue)); } break; default: diff --git a/src/catalog/catalog.cpp b/src/catalog/catalog.cpp index 10599ee30d3..b88ee69b784 100644 --- a/src/catalog/catalog.cpp +++ b/src/catalog/catalog.cpp @@ -73,39 +73,6 @@ bool Catalog::containsTable(const Transaction* transaction, const std::string& t return contains; } -table_id_t Catalog::getTableID(const Transaction* transaction, const std::string& tableName, - bool useInternal) const { - return getTableCatalogEntry(transaction, tableName, useInternal)->getTableID(); -} - -std::vector Catalog::getNodeTableIDs(const Transaction* transaction, - bool useInternal) const { - std::vector tableIDs; - tables->iterateEntriesOfType(transaction, CatalogEntryType::NODE_TABLE_ENTRY, - [&](const CatalogEntry* entry) { tableIDs.push_back(entry->getOID()); }); - if (useInternal) { - internalTables->iterateEntriesOfType(transaction, CatalogEntryType::NODE_TABLE_ENTRY, - [&](const CatalogEntry* entry) { tableIDs.push_back(entry->getOID()); }); - } - return tableIDs; -} - -std::vector Catalog::getRelTableIDs(const Transaction* transaction, - bool useInternal) const { - std::vector tableIDs; - tables->iterateEntriesOfType(transaction, CatalogEntryType::REL_TABLE_ENTRY, - [&](const CatalogEntry* entry) { tableIDs.push_back(entry->getOID()); }); - if (useInternal) { - internalTables->iterateEntriesOfType(transaction, CatalogEntryType::REL_TABLE_ENTRY, - [&](const CatalogEntry* entry) { tableIDs.push_back(entry->getOID()); }); - } - return tableIDs; -} - -std::string Catalog::getTableName(const Transaction* transaction, table_id_t tableID) const { - return getTableCatalogEntry(transaction, tableID)->getName(); -} - TableCatalogEntry* Catalog::getTableCatalogEntry(const Transaction* transaction, table_id_t tableID) const { auto result = tables->getEntryOfOID(transaction, tableID); @@ -163,15 +130,6 @@ std::vector Catalog::getTableEntries(const Transaction* tran return result; } -std::vector Catalog::getTableEntries(const Transaction* transaction, - const table_id_vector_t& tableIDs) const { - std::vector result; - for (const auto tableID : tableIDs) { - result.push_back(getTableCatalogEntry(transaction, tableID)); - } - return result; -} - bool Catalog::tableInRelGroup(Transaction* transaction, table_id_t tableID) const { for (const auto& entry : getRelTableGroupEntries(transaction)) { if (entry->isParent(tableID)) { @@ -181,29 +139,7 @@ bool Catalog::tableInRelGroup(Transaction* transaction, table_id_t tableID) cons return false; } -table_id_set_t Catalog::getFwdRelTableIDs(Transaction* transaction, table_id_t nodeTableID) const { - KU_ASSERT(getTableCatalogEntry(transaction, nodeTableID)->getTableType() == TableType::NODE); - table_id_set_t result; - for (const auto& relEntry : getRelTableEntries(transaction)) { - if (relEntry->getSrcTableID() == nodeTableID) { - result.insert(relEntry->getTableID()); - } - } - return result; -} - -table_id_set_t Catalog::getBwdRelTableIDs(Transaction* transaction, table_id_t nodeTableID) const { - KU_ASSERT(getTableCatalogEntry(transaction, nodeTableID)->getTableType() == TableType::NODE); - table_id_set_t result; - for (const auto& relEntry : getRelTableEntries(transaction)) { - if (relEntry->getDstTableID() == nodeTableID) { - result.insert(relEntry->getTableID()); - } - } - return result; -} - -table_id_t Catalog::createTableSchema(Transaction* transaction, const BoundCreateTableInfo& info) { +table_id_t Catalog::createTableEntry(Transaction* transaction, const BoundCreateTableInfo& info) { std::unique_ptr entry; switch (info.type) { case TableType::NODE: { @@ -221,7 +157,8 @@ table_id_t Catalog::createTableSchema(Transaction* transaction, const BoundCreat const auto tableEntry = entry->constPtrCast(); for (auto& definition : tableEntry->getProperties()) { if (definition.getType().getLogicalTypeID() == LogicalTypeID::SERIAL) { - const auto seqName = genSerialName(tableEntry->getName(), definition.getName()); + const auto seqName = + SequenceCatalogEntry::genSerialName(tableEntry->getName(), definition.getName()); auto seqInfo = BoundCreateSequenceInfo(seqName, 0, 1, 0, std::numeric_limits::max(), false, ConflictAction::ON_CONFLICT_THROW, info.isInternal); @@ -237,7 +174,7 @@ table_id_t Catalog::createTableSchema(Transaction* transaction, const BoundCreat } void Catalog::dropTableEntry(Transaction* transaction, const std::string& name) { - const auto tableID = getTableID(transaction, name); + auto tableID = getTableCatalogEntry(transaction, name)->getTableID(); dropAllIndexes(transaction, tableID); dropTableEntry(transaction, tableID); } @@ -257,7 +194,8 @@ void Catalog::dropTableEntry(Transaction* transaction, table_id_t tableID) { } for (auto& definition : tableEntry->getProperties()) { if (definition.getType().getLogicalTypeID() == LogicalTypeID::SERIAL) { - auto seqName = genSerialName(tableEntry->getName(), definition.getName()); + auto seqName = + SequenceCatalogEntry::genSerialName(tableEntry->getName(), definition.getName()); dropSequence(transaction, seqName); } } @@ -287,12 +225,11 @@ void Catalog::alterTableEntry(Transaction* transaction, const BoundAlterInfo& in tables->alterEntry(transaction, info); } -bool Catalog::containsSequence(const Transaction* transaction, - const std::string& sequenceName) const { - return sequences->containsEntry(transaction, sequenceName); +bool Catalog::containsSequence(const Transaction* transaction, const std::string& name) const { + return sequences->containsEntry(transaction, name); } -sequence_id_t Catalog::getSequenceID(const Transaction* transaction, +SequenceCatalogEntry* Catalog::getSequenceEntry(const Transaction* transaction, const std::string& sequenceName, bool useInternalSeq) const { CatalogEntry* entry = nullptr; if (!sequences->containsEntry(transaction, sequenceName) && useInternalSeq) { @@ -301,10 +238,10 @@ sequence_id_t Catalog::getSequenceID(const Transaction* transaction, entry = sequences->getEntry(transaction, sequenceName); } KU_ASSERT(entry); - return entry->getOID(); + return entry->ptrCast(); } -SequenceCatalogEntry* Catalog::getSequenceCatalogEntry(const Transaction* transaction, +SequenceCatalogEntry* Catalog::getSequenceEntry(const Transaction* transaction, sequence_id_t sequenceID) const { auto entry = internalSequences->getEntryOfOID(transaction, sequenceID); if (entry == nullptr) { @@ -335,22 +272,18 @@ sequence_id_t Catalog::createSequence(Transaction* transaction, } void Catalog::dropSequence(Transaction* transaction, const std::string& name) { - const auto sequenceID = getSequenceID(transaction, name); - dropSequence(transaction, sequenceID); + const auto entry = getSequenceEntry(transaction, name); + dropSequence(transaction, entry->getOID()); } void Catalog::dropSequence(Transaction* transaction, sequence_id_t sequenceID) { - const auto sequenceEntry = getSequenceCatalogEntry(transaction, sequenceID); + const auto sequenceEntry = getSequenceEntry(transaction, sequenceID); CatalogSet* set = nullptr; set = sequences->containsEntry(transaction, sequenceEntry->getName()) ? sequences.get() : internalSequences.get(); set->dropEntry(transaction, sequenceEntry->getName(), sequenceEntry->getOID()); } -std::string Catalog::genSerialName(const std::string& tableName, const std::string& propertyName) { - return std::string(tableName).append("_").append(propertyName).append("_").append("serial"); -} - void Catalog::createType(Transaction* transaction, std::string name, LogicalType type) { if (types->containsEntry(transaction, name)) { return; @@ -381,13 +314,16 @@ void Catalog::createIndex(Transaction* transaction, IndexCatalogEntry* Catalog::getIndex(const Transaction* transaction, table_id_t tableID, const std::string& indexName) const { - return indexes - ->getEntry(transaction, IndexCatalogEntry::getInternalIndexName(tableID, indexName)) - ->ptrCast(); + auto internalName = IndexCatalogEntry::getInternalIndexName(tableID, indexName); + return indexes->getEntry(transaction, internalName)->ptrCast(); } -CatalogSet* Catalog::getIndexes() const { - return indexes.get(); +std::vector Catalog::getIndexEntries(const Transaction* transaction) const { + std::vector result; + for (auto& [_, entry] : indexes->getEntries(transaction)) { + result.push_back(entry->ptrCast()); + } + return result; } bool Catalog::containsIndex(const Transaction* transaction, table_id_t tableID, @@ -593,7 +529,7 @@ std::unique_ptr Catalog::createRelTableGroupEntry(Transaction* tra relTableIDs.reserve(extraInfo->infos.size()); for (auto& childInfo : extraInfo->infos) { childInfo.hasParent = true; - relTableIDs.push_back(createTableSchema(transaction, childInfo)); + relTableIDs.push_back(createTableEntry(transaction, childInfo)); } return std::make_unique(tables.get(), info.tableName, std::move(relTableIDs)); diff --git a/src/catalog/catalog_entry/hnsw_index_catalog_entry.cpp b/src/catalog/catalog_entry/hnsw_index_catalog_entry.cpp index ddd256c822e..c0432d3d68b 100644 --- a/src/catalog/catalog_entry/hnsw_index_catalog_entry.cpp +++ b/src/catalog/catalog_entry/hnsw_index_catalog_entry.cpp @@ -1,6 +1,7 @@ #include "catalog/catalog_entry/hnsw_index_catalog_entry.h" #include "catalog/catalog.h" +#include "catalog/catalog_entry/table_catalog_entry.h" #include "common/serializer/buffered_reader.h" #include "common/serializer/buffered_serializer.h" #include "common/serializer/deserializer.h" @@ -43,7 +44,9 @@ std::string HNSWIndexAuxInfo::toCypher(const IndexCatalogEntry& indexEntry, main::ClientContext* context) { std::string cypher; auto catalog = context->getCatalog(); - auto tableName = catalog->getTableName(context->getTransaction(), indexEntry.getTableID()); + auto tableName = + catalog->getTableCatalogEntry(context->getTransaction(), indexEntry.getTableID()) + ->getName(); auto distFuncName = storage::HNSWIndexConfig::distFuncToString(config.distFunc); cypher += common::stringFormat("CALL CREATE_HNSW_INDEX('{}', '{}', '{}', mu := {}, ml := {}, " "pu := {}, distFunc := '{}', alpha := {}, efc := {});\n", diff --git a/src/catalog/catalog_entry/node_table_catalog_entry.cpp b/src/catalog/catalog_entry/node_table_catalog_entry.cpp index eb46b15c3f9..859c6db2ebd 100644 --- a/src/catalog/catalog_entry/node_table_catalog_entry.cpp +++ b/src/catalog/catalog_entry/node_table_catalog_entry.cpp @@ -1,6 +1,8 @@ #include "catalog/catalog_entry/node_table_catalog_entry.h" #include "binder/ddl/bound_create_table_info.h" +#include "catalog/catalog.h" +#include "catalog/catalog_entry/rel_table_catalog_entry.h" #include "catalog/catalog_set.h" #include "common/serializer/deserializer.h" @@ -9,6 +11,28 @@ using namespace kuzu::binder; namespace kuzu { namespace catalog { +common::table_id_set_t NodeTableCatalogEntry::getFwdRelTableIDs(Catalog* catalog, + transaction::Transaction* transaction) const { + common::table_id_set_t result; + for (const auto& relEntry : catalog->getRelTableEntries(transaction)) { + if (relEntry->getSrcTableID() == getTableID()) { + result.insert(relEntry->getTableID()); + } + } + return result; +} + +common::table_id_set_t NodeTableCatalogEntry::getBwdRelTableIDs(Catalog* catalog, + transaction::Transaction* transaction) const { + common::table_id_set_t result; + for (const auto& relEntry : catalog->getRelTableEntries(transaction)) { + if (relEntry->getDstTableID() == getTableID()) { + result.insert(relEntry->getTableID()); + } + } + return result; +} + void NodeTableCatalogEntry::serialize(common::Serializer& serializer) const { TableCatalogEntry::serialize(serializer); serializer.writeDebuggingInfo("primaryKeyName"); diff --git a/src/catalog/catalog_entry/rel_group_catalog_entry.cpp b/src/catalog/catalog_entry/rel_group_catalog_entry.cpp index c2417a64f1a..843468476b6 100644 --- a/src/catalog/catalog_entry/rel_group_catalog_entry.cpp +++ b/src/catalog/catalog_entry/rel_group_catalog_entry.cpp @@ -79,8 +79,10 @@ static std::string getFromToStr(common::table_id_t tableID, ClientContext* conte auto transaction = context->getTransaction(); auto& entry = catalog->getTableCatalogEntry(transaction, tableID)->constCast(); - auto srcTableName = catalog->getTableName(transaction, entry.getSrcTableID()); - auto dstTableName = catalog->getTableName(transaction, entry.getDstTableID()); + auto srcTableName = + catalog->getTableCatalogEntry(transaction, entry.getSrcTableID())->getName(); + auto dstTableName = + catalog->getTableCatalogEntry(transaction, entry.getDstTableID())->getName(); return stringFormat("FROM {} TO {}", srcTableName, dstTableName); } diff --git a/src/catalog/catalog_entry/rel_table_catalog_entry.cpp b/src/catalog/catalog_entry/rel_table_catalog_entry.cpp index 5a9e25ca3e6..9ad5ca3e0bd 100644 --- a/src/catalog/catalog_entry/rel_table_catalog_entry.cpp +++ b/src/catalog/catalog_entry/rel_table_catalog_entry.cpp @@ -102,8 +102,9 @@ std::unique_ptr RelTableCatalogEntry::copy() const { std::string RelTableCatalogEntry::toCypher(main::ClientContext* clientContext) const { std::stringstream ss; auto catalog = clientContext->getCatalog(); - auto srcTableName = catalog->getTableName(clientContext->getTransaction(), srcTableID); - auto dstTableName = catalog->getTableName(clientContext->getTransaction(), dstTableID); + auto transaction = clientContext->getTransaction(); + auto srcTableName = catalog->getTableCatalogEntry(transaction, srcTableID)->getName(); + auto dstTableName = catalog->getTableCatalogEntry(transaction, dstTableID)->getName(); auto srcMultiStr = srcMultiplicity == common::RelMultiplicity::MANY ? "MANY" : "ONE"; auto dstMultiStr = dstMultiplicity == common::RelMultiplicity::MANY ? "MANY" : "ONE"; std::string tableInfo = diff --git a/src/function/sequence/sequence_functions.cpp b/src/function/sequence/sequence_functions.cpp index d73b48084d8..64732c11e09 100644 --- a/src/function/sequence/sequence_functions.cpp +++ b/src/function/sequence/sequence_functions.cpp @@ -15,9 +15,8 @@ struct CurrVal { auto ctx = reinterpret_cast(dataPtr)->clientContext; auto catalog = ctx->getCatalog(); auto sequenceName = input.getAsString(); - auto sequenceID = catalog->getSequenceID(ctx->getTransaction(), sequenceName, + auto sequenceEntry = catalog->getSequenceEntry(ctx->getTransaction(), sequenceName, ctx->shouldUseInternalCatalogEntry()); - auto sequenceEntry = catalog->getSequenceCatalogEntry(ctx->getTransaction(), sequenceID); result.setValue(0, sequenceEntry->currVal()); } }; @@ -28,9 +27,8 @@ struct NextVal { auto cnt = reinterpret_cast(dataPtr)->count; auto catalog = ctx->getCatalog(); auto sequenceName = input.getAsString(); - auto sequenceID = catalog->getSequenceID(ctx->getTransaction(), sequenceName, + auto sequenceEntry = catalog->getSequenceEntry(ctx->getTransaction(), sequenceName, ctx->shouldUseInternalCatalogEntry()); - auto sequenceEntry = catalog->getSequenceCatalogEntry(ctx->getTransaction(), sequenceID); sequenceEntry->nextKVal(ctx->getTransaction(), cnt, result); result.state->getSelVectorUnsafe().setSelSize(cnt); } diff --git a/src/function/table/call/hnsw/create_hnsw_index.cpp b/src/function/table/call/hnsw/create_hnsw_index.cpp index 13f299142c1..e8573744a5a 100644 --- a/src/function/table/call/hnsw/create_hnsw_index.cpp +++ b/src/function/table/call/hnsw/create_hnsw_index.cpp @@ -74,24 +74,32 @@ static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput&) { static void finalizeFunc(const processor::ExecutionContext* context, TableFuncSharedState* sharedState) { + auto clientContext = context->clientContext; + auto transaction = clientContext->getTransaction(); const auto hnswSharedState = sharedState->ptrCast(); - hnswSharedState->hnswIndex->shrink(context->clientContext->getTransaction()); - hnswSharedState->hnswIndex->finalize(*context->clientContext->getMemoryManager(), - *hnswSharedState->partitionerSharedState); + const auto index = hnswSharedState->hnswIndex.get(); + index->shrink(transaction); + index->finalize(*clientContext->getMemoryManager(), *hnswSharedState->partitionerSharedState); const auto bindData = hnswSharedState->bindData->constPtrCast(); - const auto catalog = context->clientContext->getCatalog(); - auto upperRelTableID = catalog->getTableID(context->clientContext->getTransaction(), - storage::HNSWIndexUtils::getUpperGraphTableName(bindData->indexName)); - auto lowerRelTableID = catalog->getTableID(context->clientContext->getTransaction(), - storage::HNSWIndexUtils::getLowerGraphTableName(bindData->indexName)); - catalog->createIndex(context->clientContext->getTransaction(), + const auto catalog = clientContext->getCatalog(); + auto upperRelTableID = + catalog + ->getTableCatalogEntry(transaction, + storage::HNSWIndexUtils::getUpperGraphTableName(bindData->indexName)) + ->getTableID(); + auto lowerRelTableID = + catalog + ->getTableCatalogEntry(transaction, + storage::HNSWIndexUtils::getLowerGraphTableName(bindData->indexName)) + ->getTableID(); + auto auxInfo = std::make_unique(upperRelTableID, lowerRelTableID, + hnswSharedState->nodeTable.getColumn(bindData->columnID).getName(), + index->getUpperEntryPoint(), index->getLowerEntryPoint(), bindData->config.copy()); + auto indexEntry = std::make_unique(catalog::HNSWIndexCatalogEntry::TYPE_NAME, - bindData->tableEntry->getTableID(), bindData->indexName, - std::make_unique(upperRelTableID, lowerRelTableID, - hnswSharedState->nodeTable.getColumn(bindData->columnID).getName(), - hnswSharedState->hnswIndex->getUpperEntryPoint(), - hnswSharedState->hnswIndex->getLowerEntryPoint(), bindData->config.copy()))); + bindData->tableEntry->getTableID(), bindData->indexName, std::move(auxInfo)); + catalog->createIndex(transaction, std::move(indexEntry)); } static std::string createHNSWIndexTables(main::ClientContext& context, diff --git a/src/function/table/call/show_connection.cpp b/src/function/table/call/show_connection.cpp index 706adeb0a29..96b3592175e 100644 --- a/src/function/table/call/show_connection.cpp +++ b/src/function/table/call/show_connection.cpp @@ -37,8 +37,6 @@ static void outputRelTableConnection(const DataChunk& outputDataChunk, uint64_t // Get src and dst name const auto srcTableID = relTableEntry->getSrcTableID(); const auto dstTableID = relTableEntry->getDstTableID(); - const auto srcTableName = catalog->getTableName(context->getTransaction(), srcTableID); - const auto dstTableName = catalog->getTableName(context->getTransaction(), dstTableID); // Get src and dst primary key const auto srcTableEntry = catalog->getTableCatalogEntry(context->getTransaction(), srcTableID); const auto dstTableEntry = catalog->getTableCatalogEntry(context->getTransaction(), dstTableID); @@ -47,8 +45,8 @@ static void outputRelTableConnection(const DataChunk& outputDataChunk, uint64_t const auto dstTablePrimaryKey = dstTableEntry->constCast().getPrimaryKeyName(); // Write result to dataChunk - outputDataChunk.getValueVectorMutable(0).setValue(outputPos, srcTableName); - outputDataChunk.getValueVectorMutable(1).setValue(outputPos, dstTableName); + outputDataChunk.getValueVectorMutable(0).setValue(outputPos, srcTableEntry->getName()); + outputDataChunk.getValueVectorMutable(1).setValue(outputPos, dstTableEntry->getName()); outputDataChunk.getValueVectorMutable(2).setValue(outputPos, srcTablePrimaryKey); outputDataChunk.getValueVectorMutable(3).setValue(outputPos, dstTablePrimaryKey); } @@ -90,8 +88,7 @@ static std::unique_ptr bindFunc(const ClientContext* context, std::vector columnTypes; const auto tableName = input->getLiteralVal(0); const auto catalog = context->getCatalog(); - const auto tableID = catalog->getTableID(context->getTransaction(), tableName); - auto tableEntry = catalog->getTableCatalogEntry(context->getTransaction(), tableID); + auto tableEntry = catalog->getTableCatalogEntry(context->getTransaction(), tableName); const auto tableType = tableEntry->getTableType(); if (tableType != TableType::REL && tableType != TableType::REL_GROUP) { throw BinderException{"Show connection can only be called on a rel table!"}; diff --git a/src/function/table/call/stats_info.cpp b/src/function/table/call/stats_info.cpp index f5205f3535f..76ac754035f 100644 --- a/src/function/table/call/stats_info.cpp +++ b/src/function/table/call/stats_info.cpp @@ -65,8 +65,7 @@ static std::unique_ptr bindFunc(const ClientContext* context, if (!catalog->containsTable(context->getTransaction(), tableName)) { throw BinderException{"Table " + tableName + " does not exist!"}; } - const auto tableID = catalog->getTableID(context->getTransaction(), tableName); - auto tableEntry = catalog->getTableCatalogEntry(context->getTransaction(), tableID); + auto tableEntry = catalog->getTableCatalogEntry(context->getTransaction(), tableName); if (tableEntry->getTableType() != TableType::NODE) { throw BinderException{ "Stats from a non-node table " + tableName + " is not supported yet!"}; @@ -80,7 +79,7 @@ static std::unique_ptr bindFunc(const ClientContext* context, columnTypes.push_back(LogicalType::INT64()); } const auto storageManager = context->getStorageManager(); - auto table = storageManager->getTable(tableID); + auto table = storageManager->getTable(tableEntry->getTableID()); auto columns = input->binder->createVariables(columnNames, columnTypes); return std::make_unique(columns, tableEntry, table, context); } diff --git a/src/function/table/call/storage_info.cpp b/src/function/table/call/storage_info.cpp index 613a9a186db..0dd304fe899 100644 --- a/src/function/table/call/storage_info.cpp +++ b/src/function/table/call/storage_info.cpp @@ -344,10 +344,9 @@ static std::unique_ptr bindFunc(const ClientContext* context, if (!catalog->containsTable(context->getTransaction(), tableName)) { throw BinderException{"Table " + tableName + " does not exist!"}; } - auto tableID = catalog->getTableID(context->getTransaction(), tableName); - auto tableEntry = catalog->getTableCatalogEntry(context->getTransaction(), tableID); + auto tableEntry = catalog->getTableCatalogEntry(context->getTransaction(), tableName); auto storageManager = context->getStorageManager(); - auto table = storageManager->getTable(tableID); + auto table = storageManager->getTable(tableEntry->getTableID()); auto columns = input->binder->createVariables(columnNames, columnTypes); return std::make_unique(columns, tableEntry, table, context); } diff --git a/src/function/table/call/table_info.cpp b/src/function/table/call/table_info.cpp index 4cffeccbf47..3f408381f2a 100644 --- a/src/function/table/call/table_info.cpp +++ b/src/function/table/call/table_info.cpp @@ -72,8 +72,7 @@ static std::unique_ptr getTableCatalogEntry(const main::Clien auto transaction = context->getTransaction(); auto tableInfo = common::StringUtils::split(tableName, "."); if (tableInfo.size() == 1) { - auto tableID = context->getCatalog()->getTableID(transaction, tableName); - return context->getCatalog()->getTableCatalogEntry(transaction, tableID)->copy(); + return context->getCatalog()->getTableCatalogEntry(transaction, tableName)->copy(); } else { auto catalogName = tableInfo[0]; auto attachedTableName = tableInfo[1]; @@ -82,8 +81,9 @@ static std::unique_ptr getTableCatalogEntry(const main::Clien throw common::RuntimeException{ common::stringFormat("Database: {} doesn't exist.", catalogName)}; } - auto tableID = attachedDatabase->getCatalog()->getTableID(transaction, attachedTableName); - return attachedDatabase->getCatalog()->getTableCatalogEntry(transaction, tableID)->copy(); + return attachedDatabase->getCatalog() + ->getTableCatalogEntry(transaction, attachedTableName) + ->copy(); } } diff --git a/src/graph/on_disk_graph.cpp b/src/graph/on_disk_graph.cpp index b37d956124d..a49d069d70a 100644 --- a/src/graph/on_disk_graph.cpp +++ b/src/graph/on_disk_graph.cpp @@ -1,6 +1,7 @@ #include "graph/on_disk_graph.h" #include "binder/expression/property_expression.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" #include "common/assert.h" #include "common/cast.h" #include "common/constants.h" @@ -129,12 +130,13 @@ OnDiskGraph::OnDiskGraph(ClientContext* context, GraphEntry entry) auto storage = context->getStorageManager(); auto catalog = context->getCatalog(); auto transaction = context->getTransaction(); - for (auto& nodeEntry : graphEntry.nodeEntries) { - auto nodeTableID = nodeEntry->getTableID(); + for (auto& entry : graphEntry.nodeEntries) { + auto& nodeEntry = entry->constCast(); + auto nodeTableID = nodeEntry.getTableID(); nodeIDToNodeTable.insert( {nodeTableID, storage->getTable(nodeTableID)->ptrCast()}); table_id_map_t fwdRelTables; - for (auto& relTableID : catalog->getFwdRelTableIDs(transaction, nodeTableID)) { + for (auto& relTableID : nodeEntry.getFwdRelTableIDs(catalog, transaction)) { if (!graphEntry.hasRelEntry(relTableID)) { continue; } @@ -142,7 +144,7 @@ OnDiskGraph::OnDiskGraph(ClientContext* context, GraphEntry entry) } nodeTableIDToFwdRelTables.insert({nodeTableID, std::move(fwdRelTables)}); table_id_map_t bwdRelTables; - for (auto& relTableID : catalog->getBwdRelTableIDs(transaction, nodeTableID)) { + for (auto& relTableID : nodeEntry.getBwdRelTableIDs(catalog, transaction)) { if (!graphEntry.hasRelEntry(relTableID)) { continue; } diff --git a/src/include/catalog/catalog.h b/src/include/catalog/catalog.h index 68a463b8014..f992f554eca 100644 --- a/src/include/catalog/catalog.h +++ b/src/include/catalog/catalog.h @@ -55,83 +55,97 @@ class KUZU_API Catalog { virtual ~Catalog() = default; // ----------------------------- Table Schemas ---------------------------- - bool containsTable(const transaction::Transaction* transaction, const std::string& tableName, - bool useInternal = true) const; - common::table_id_t getTableID(const transaction::Transaction* transaction, - const std::string& tableName, bool useInternal = true) const; - std::vector getNodeTableIDs(const transaction::Transaction* transaction, - bool useInternal = true) const; - std::vector getRelTableIDs(const transaction::Transaction* transaction, + // Check if table entry exists. + bool containsTable(const transaction::Transaction* transaction, const std::string& tableName, bool useInternal = true) const; - - // TODO: Should remove this. - std::string getTableName(const transaction::Transaction* transaction, - common::table_id_t tableID) const; + // Get table entry with name. TableCatalogEntry* getTableCatalogEntry(const transaction::Transaction* transaction, const std::string& tableName, bool useInternal = true) const; + // Get table entry with id. TableCatalogEntry* getTableCatalogEntry(const transaction::Transaction* transaction, common::table_id_t tableID) const; + // Get all node table entries. std::vector getNodeTableEntries(transaction::Transaction* transaction, bool useInternal = true) const; + // Get all rel table entries. std::vector getRelTableEntries(transaction::Transaction* transaction, bool useInternal = true) const; + // Get all rel group entries. std::vector getRelTableGroupEntries( transaction::Transaction* transaction) const; + // Get all table entries. std::vector getTableEntries( const transaction::Transaction* transaction) const; - std::vector getTableEntries(const transaction::Transaction* transaction, - const common::table_id_vector_t& tableIDs) const; bool tableInRelGroup(transaction::Transaction* transaction, common::table_id_t tableID) const; - common::table_id_set_t getFwdRelTableIDs(transaction::Transaction* transaction, - common::table_id_t nodeTableID) const; - common::table_id_set_t getBwdRelTableIDs(transaction::Transaction* transaction, - common::table_id_t nodeTableID) const; - common::table_id_t createTableSchema(transaction::Transaction* transaction, + // Create table entry. + common::table_id_t createTableEntry(transaction::Transaction* transaction, const binder::BoundCreateTableInfo& info); + // Drop table entry with name. void dropTableEntry(transaction::Transaction* transaction, const std::string& name); + // Drop table entry with id. void dropTableEntry(transaction::Transaction* transaction, common::table_id_t tableID); + // Alter table entry. void alterTableEntry(transaction::Transaction* transaction, const binder::BoundAlterInfo& info); // ----------------------------- Sequences ---------------------------- - bool containsSequence(const transaction::Transaction* transaction, - const std::string& sequenceName) const; - common::sequence_id_t getSequenceID(const transaction::Transaction* transaction, - const std::string& sequenceName, bool useInternalSeq = true) const; - SequenceCatalogEntry* getSequenceCatalogEntry(const transaction::Transaction* transaction, + // Check if sequence entry exists. + bool containsSequence(const transaction::Transaction* transaction, + const std::string& name) const; + // Get sequence entry with name. + SequenceCatalogEntry* getSequenceEntry(const transaction::Transaction* transaction, + const std::string& name, bool useInternalSeq = true) const; + // Get sequence entry with id. + SequenceCatalogEntry* getSequenceEntry(const transaction::Transaction* transaction, common::sequence_id_t sequenceID) const; + // Get all sequence entries. std::vector getSequenceEntries( const transaction::Transaction* transaction) const; + // Create sequence entry. common::sequence_id_t createSequence(transaction::Transaction* transaction, const binder::BoundCreateSequenceInfo& info); + // Drop sequence entry with name. void dropSequence(transaction::Transaction* transaction, const std::string& name); + // Drop sequence entry with id. void dropSequence(transaction::Transaction* transaction, common::sequence_id_t sequenceID); - static std::string genSerialName(const std::string& tableName, const std::string& propertyName); - // ----------------------------- Types ---------------------------- + + // Check if type entry exists. + bool containsType(const transaction::Transaction* transaction, const std::string& name) const; + // Get type entry with name. + common::LogicalType getType(const transaction::Transaction*, const std::string& name) const; + + // Create type entry. void createType(transaction::Transaction* transaction, std::string name, common::LogicalType type); - common::LogicalType getType(const transaction::Transaction*, const std::string& name) const; - bool containsType(const transaction::Transaction* transaction, - const std::string& typeName) const; // ----------------------------- Indexes ---------------------------- - void createIndex(transaction::Transaction* transaction, - std::unique_ptr indexCatalogEntry); - IndexCatalogEntry* getIndex(const transaction::Transaction*, common::table_id_t tableID, - const std::string& indexName) const; - CatalogSet* getIndexes() const; + + // Check if index entry exists. bool containsIndex(const transaction::Transaction* transaction, common::table_id_t tableID, const std::string& indexName) const; + // Get index entry with name. + IndexCatalogEntry* getIndex(const transaction::Transaction* transaction, + common::table_id_t tableID, const std::string& indexName) const; + // Get all index entries. + std::vector getIndexEntries( + const transaction::Transaction* transaction) const; + + // Create index entry. + void createIndex(transaction::Transaction* transaction, + std::unique_ptr indexCatalogEntry); + // Drop all index entries within a table. void dropAllIndexes(transaction::Transaction* transaction, common::table_id_t tableID); + // Drop index entry with name. void dropIndex(transaction::Transaction* transaction, common::table_id_t tableID, const std::string& indexName) const; // ----------------------------- Functions ---------------------------- + void addFunction(transaction::Transaction* transaction, CatalogEntryType entryType, std::string name, function::function_set functionSet); void dropFunction(transaction::Transaction* transaction, const std::string& name); @@ -143,6 +157,9 @@ class KUZU_API Catalog { std::vector getFunctionEntries( const transaction::Transaction* transaction) const; + // ----------------------------- Macro ---------------------------- + + // Check if macro entry exists. bool containsMacro(const transaction::Transaction* transaction, const std::string& macroName) const; void addScalarMacroFunction(transaction::Transaction* transaction, std::string name, @@ -187,9 +204,6 @@ class KUZU_API Catalog { return result; } - std::vector getTableIDs(transaction::Transaction* transaction, - CatalogEntryType catalogType) const; - std::unique_ptr createNodeTableEntry(transaction::Transaction* transaction, const binder::BoundCreateTableInfo& info) const; std::unique_ptr createRelTableEntry(transaction::Transaction* transaction, @@ -197,14 +211,6 @@ class KUZU_API Catalog { std::unique_ptr createRelTableGroupEntry(transaction::Transaction* transaction, const binder::BoundCreateTableInfo& info); - // ----------------------------- Sequence entries ---------------------------- - void iterateSequenceCatalogEntries(const transaction::Transaction* transaction, - const std::function& func) const { - for (auto& [_, entry] : sequences->getEntries(transaction)) { - func(entry); - } - } - protected: std::unique_ptr tables; diff --git a/src/include/catalog/catalog_entry/node_table_catalog_entry.h b/src/include/catalog/catalog_entry/node_table_catalog_entry.h index a20db433b4d..bf9636545c5 100644 --- a/src/include/catalog/catalog_entry/node_table_catalog_entry.h +++ b/src/include/catalog/catalog_entry/node_table_catalog_entry.h @@ -9,6 +9,7 @@ class Transaction; namespace catalog { +class Catalog; class CatalogSet; class KUZU_API NodeTableCatalogEntry final : public TableCatalogEntry { static constexpr CatalogEntryType entryType_ = CatalogEntryType::NODE_TABLE_ENTRY; @@ -28,6 +29,11 @@ class KUZU_API NodeTableCatalogEntry final : public TableCatalogEntry { return getProperty(primaryKeyName); } + common::table_id_set_t getFwdRelTableIDs(Catalog* catalog, + transaction::Transaction* transaction) const; + common::table_id_set_t getBwdRelTableIDs(Catalog* catalog, + transaction::Transaction* transaction) const; + void serialize(common::Serializer& serializer) const override; static std::unique_ptr deserialize(common::Deserializer& deserializer); diff --git a/src/include/catalog/catalog_entry/sequence_catalog_entry.h b/src/include/catalog/catalog_entry/sequence_catalog_entry.h index e980f8efe0f..ffd87e4f5af 100644 --- a/src/include/catalog/catalog_entry/sequence_catalog_entry.h +++ b/src/include/catalog/catalog_entry/sequence_catalog_entry.h @@ -73,10 +73,16 @@ class KUZU_API SequenceCatalogEntry final : public CatalogEntry { //===--------------------------------------------------------------------===// void serialize(common::Serializer& serializer) const override; static std::unique_ptr deserialize(common::Deserializer& deserializer); + std::string toCypher(main::ClientContext* /*clientContext*/) const override; binder::BoundCreateSequenceInfo getBoundCreateSequenceInfo(bool isInternal) const; + static std::string genSerialName(const std::string& tableName, + const std::string& propertyName) { + return std::string(tableName).append("_").append(propertyName).append("_").append("serial"); + } + private: void nextValNoLock(); diff --git a/src/main/storage_driver.cpp b/src/main/storage_driver.cpp index fc6e61b23c5..9b1fcdea4ee 100644 --- a/src/main/storage_driver.cpp +++ b/src/main/storage_driver.cpp @@ -21,16 +21,12 @@ StorageDriver::StorageDriver(Database* database) : database{database} { StorageDriver::~StorageDriver() = default; -static Table* getTable(const ClientContext& context, const std::string& tableName) { - auto tableID = context.getCatalog()->getTableID(context.getTransaction(), tableName); - return context.getStorageManager()->getTable(tableID); +static TableCatalogEntry* getEntry(const ClientContext& context, const std::string& tableName) { + return context.getCatalog()->getTableCatalogEntry(context.getTransaction(), tableName); } -static TableCatalogEntry* getEntry(const ClientContext& context, const std::string& tableName) { - auto catalog = context.getCatalog(); - auto transaction = context.getTransaction(); - auto tableID = catalog->getTableID(transaction, tableName); - return catalog->getTableCatalogEntry(transaction, tableID); +static Table* getTable(const ClientContext& context, const std::string& tableName) { + return context.getStorageManager()->getTable(getEntry(context, tableName)->getTableID()); } static bool validateNumericalType(const LogicalType& type) { diff --git a/src/processor/map/map_delete.cpp b/src/processor/map/map_delete.cpp index dcbb422989b..ab5d0827c9d 100644 --- a/src/processor/map/map_delete.cpp +++ b/src/processor/map/map_delete.cpp @@ -1,5 +1,6 @@ #include "binder/expression/node_expression.h" #include "binder/expression/rel_expression.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" #include "planner/operator/persistent/logical_delete.h" #include "processor/operator/persistent/delete.h" #include "processor/plan_mapper.h" @@ -23,10 +24,11 @@ NodeTableDeleteInfo PlanMapper::getNodeTableDeleteInfo(const TableCatalogEntry& auto table = storageManager->getTable(tableID)->ptrCast(); std::unordered_set fwdRelTables; std::unordered_set bwdRelTables; - for (auto id : catalog->getFwdRelTableIDs(transaction, tableID)) { + auto& nodeEntry = entry.constCast(); + for (auto id : nodeEntry.getFwdRelTableIDs(catalog, transaction)) { fwdRelTables.insert(storageManager->getTable(id)->ptrCast()); } - for (auto id : catalog->getBwdRelTableIDs(transaction, tableID)) { + for (auto id : nodeEntry.getBwdRelTableIDs(catalog, transaction)) { bwdRelTables.insert(storageManager->getTable(id)->ptrCast()); } return NodeTableDeleteInfo(table, std::move(fwdRelTables), std::move(bwdRelTables), pkPos); diff --git a/src/processor/operator/ddl/create_table.cpp b/src/processor/operator/ddl/create_table.cpp index 395e78f0fda..bcd0bb816bd 100644 --- a/src/processor/operator/ddl/create_table.cpp +++ b/src/processor/operator/ddl/create_table.cpp @@ -10,20 +10,21 @@ namespace kuzu { namespace processor { void CreateTable::executeDDLInternal(ExecutionContext* context) { - auto catalog = context->clientContext->getCatalog(); + auto clientContext = context->clientContext; + auto catalog = clientContext->getCatalog(); switch (info.onConflict) { case common::ConflictAction::ON_CONFLICT_DO_NOTHING: { - if (catalog->containsTable(context->clientContext->getTransaction(), info.tableName)) { + if (catalog->containsTable(clientContext->getTransaction(), info.tableName)) { return; } } default: break; } - auto newTableID = catalog->createTableSchema(context->clientContext->getTransaction(), info); + auto newTableID = catalog->createTableEntry(clientContext->getTransaction(), info); tableCreated = true; - auto storageManager = context->clientContext->getStorageManager(); - storageManager->createTable(newTableID, catalog, context->clientContext); + auto storageManager = clientContext->getStorageManager(); + storageManager->createTable(newTableID, catalog, clientContext); } std::string CreateTable::getOutputMsg() { diff --git a/src/processor/operator/simple/export_db.cpp b/src/processor/operator/simple/export_db.cpp index 13cd053a56d..b9ed51aea78 100644 --- a/src/processor/operator/simple/export_db.cpp +++ b/src/processor/operator/simple/export_db.cpp @@ -113,10 +113,9 @@ std::string getCopyCypher(const Catalog* catalog, Transaction* transaction, std::string getIndexCypher(ClientContext* clientContext) { stringstream ss; - for (auto [_, entry] : - clientContext->getCatalog()->getIndexes()->getEntries(clientContext->getTransaction())) { - auto& index = entry->constCast(); - ss << index.toCypher(clientContext) << std::endl; + for (auto entry : + clientContext->getCatalog()->getIndexEntries(clientContext->getTransaction())) { + ss << entry->toCypher(clientContext) << std::endl; } return ss.str(); } diff --git a/src/storage/index/index_utils.cpp b/src/storage/index/index_utils.cpp index 9a20faca7c3..40d1682f588 100644 --- a/src/storage/index/index_utils.cpp +++ b/src/storage/index/index_utils.cpp @@ -8,23 +8,23 @@ namespace kuzu { namespace storage { -static void validateIndexExistence(const main::ClientContext& context, common::table_id_t tableID, - const std::string& indexName, IndexOperation indexOperation) { +static void validateIndexExistence(const main::ClientContext& context, + catalog::TableCatalogEntry* tableEntry, const std::string& indexName, + IndexOperation indexOperation) { switch (indexOperation) { case IndexOperation::CREATE: { - if (context.getCatalog()->containsIndex(context.getTransaction(), tableID, indexName)) { - throw common::BinderException{ - common::stringFormat("Index {} already exists in table {}.", indexName, - context.getCatalog()->getTableName(context.getTransaction(), tableID))}; + if (context.getCatalog()->containsIndex(context.getTransaction(), tableEntry->getTableID(), + indexName)) { + throw common::BinderException{common::stringFormat( + "Index {} already exists in table {}.", indexName, tableEntry->getName())}; } } break; case IndexOperation::DROP: case IndexOperation::QUERY: { - if (!context.getCatalog()->containsIndex(context.getTransaction(), tableID, indexName)) { - const auto tableName = - context.getCatalog()->getTableName(context.getTransaction(), tableID); + if (!context.getCatalog()->containsIndex(context.getTransaction(), tableEntry->getTableID(), + indexName)) { throw common::BinderException{common::stringFormat( - "Table {} doesn't have an index with name {}.", tableName, indexName)}; + "Table {} doesn't have an index with name {}.", tableEntry->getName(), indexName)}; } } break; default: { @@ -54,7 +54,7 @@ catalog::NodeTableCatalogEntry* IndexUtils::bindTable(const main::ClientContext& const auto tableEntry = context.getCatalog()->getTableCatalogEntry(context.getTransaction(), tableName); validateNodeTable(tableEntry); - validateIndexExistence(context, tableEntry->getTableID(), indexName, indexOperation); + validateIndexExistence(context, tableEntry, indexName, indexOperation); return tableEntry->ptrCast(); } diff --git a/src/storage/wal_replayer.cpp b/src/storage/wal_replayer.cpp index 169666ca52e..81ce2110776 100644 --- a/src/storage/wal_replayer.cpp +++ b/src/storage/wal_replayer.cpp @@ -126,11 +126,11 @@ void WALReplayer::replayWALRecord(const WALRecord& walRecord) const { void WALReplayer::replayCreateTableEntryRecord(const WALRecord& walRecord) const { auto& createTableEntryRecord = walRecord.constCast(); KU_ASSERT(clientContext.getCatalog()); - const auto tableID = clientContext.getCatalog()->createTableSchema( - clientContext.getTransaction(), createTableEntryRecord.boundCreateTableInfo); + auto catalog = clientContext.getCatalog(); + const auto tableID = catalog->createTableEntry(clientContext.getTransaction(), + createTableEntryRecord.boundCreateTableInfo); KU_ASSERT(clientContext.getStorageManager()); - clientContext.getStorageManager()->createTable(tableID, clientContext.getCatalog(), - &clientContext); + clientContext.getStorageManager()->createTable(tableID, catalog, &clientContext); } void WALReplayer::replayCreateCatalogEntryRecord(const WALRecord& walRecord) const { @@ -362,8 +362,8 @@ void WALReplayer::replayCopyTableRecord(const WALRecord&) const { void WALReplayer::replayUpdateSequenceRecord(const WALRecord& walRecord) const { auto& sequenceEntryRecord = walRecord.constCast(); const auto sequenceID = sequenceEntryRecord.sequenceID; - const auto entry = clientContext.getCatalog()->getSequenceCatalogEntry( - clientContext.getTransaction(), sequenceID); + const auto entry = + clientContext.getCatalog()->getSequenceEntry(clientContext.getTransaction(), sequenceID); entry->nextKVal(clientContext.getTransaction(), sequenceEntryRecord.kCount); } diff --git a/src/transaction/transaction.cpp b/src/transaction/transaction.cpp index 4f8d80408ba..9774f38ccba 100644 --- a/src/transaction/transaction.cpp +++ b/src/transaction/transaction.cpp @@ -1,6 +1,6 @@ #include "transaction/transaction.h" -#include "catalog/catalog_entry/table_catalog_entry.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" #include "common/exception/runtime.h" #include "main/client_context.h" #include "main/db_config.h" @@ -23,9 +23,10 @@ Transaction::Transaction(main::ClientContext& clientContext, TransactionType tra undoBuffer = std::make_unique(this); currentTS = common::Timestamp::getCurrentTimestamp().value; // Note that the use of `this` should be safe here as there is no inheritance. - for (auto tableID : clientContext.getCatalog()->getNodeTableIDs(this)) { - minUncommittedNodeOffsets[tableID] = - clientContext.getStorageManager()->getTable(tableID)->getNumTotalRows(this); + for (auto entry : clientContext.getCatalog()->getNodeTableEntries(this)) { + auto id = entry->getTableID(); + minUncommittedNodeOffsets[id] = + clientContext.getStorageManager()->getTable(id)->getNumTotalRows(this); } } diff --git a/test/storage/rel_scan_test.cpp b/test/storage/rel_scan_test.cpp index 8747a256787..6a4b4291372 100644 --- a/test/storage/rel_scan_test.cpp +++ b/test/storage/rel_scan_test.cpp @@ -2,6 +2,8 @@ #include #include "catalog/catalog.h" +#include "catalog/catalog_entry/node_table_catalog_entry.h" +#include "catalog/catalog_entry/rel_table_catalog_entry.h" #include "common/types/date_t.h" #include "common/types/ku_string.h" #include "common/types/types.h" @@ -26,10 +28,15 @@ class RelScanTest : public PrivateApiTest { context = getClientContext(*conn); catalog = context->getCatalog(); auto transaction = context->getTransaction(); - auto nodeTableIDs = catalog->getNodeTableIDs(transaction); - auto relTableIDs = catalog->getRelTableIDs(transaction); - auto entry = graph::GraphEntry(catalog->getTableEntries(transaction, nodeTableIDs), - catalog->getTableEntries(transaction, relTableIDs)); + std::vector nodeEntries; + for (auto& entry : catalog->getNodeTableEntries(transaction)) { + nodeEntries.push_back(entry); + } + std::vector relEntries; + for (auto& entry : catalog->getRelTableEntries(transaction)) { + relEntries.push_back(entry); + } + auto entry = graph::GraphEntry(nodeEntries, relEntries); graph = std::make_unique(context, std::move(entry)); fwdStorageOnly = (common::ExtendDirectionUtil::getDefaultExtendDirection() == @@ -51,7 +58,7 @@ class RelScanTestAmazon : public RelScanTest { // Test correctness of scan fwd TEST_F(RelScanTest, ScanFwd) { - auto tableID = catalog->getTableID(context->getTransaction(), "person"); + auto tableID = catalog->getTableCatalogEntry(context->getTransaction(), "person")->getTableID(); auto relEntry = catalog->getTableCatalogEntry(context->getTransaction(), "knows"); auto scanState = graph->prepareRelScan(relEntry, "date");