Skip to content

Commit

Permalink
More refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jan 17, 2025
1 parent 71c1650 commit 6c39667
Show file tree
Hide file tree
Showing 13 changed files with 121 additions and 104 deletions.
4 changes: 1 addition & 3 deletions extension/fts/src/fts_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ namespace kuzu {
namespace fts_extension {

static void initFTSEntries(const transaction::Transaction* transaction, catalog::Catalog& catalog) {
auto indexEntries = catalog.getIndexes()->getEntries(transaction);
for (auto& [_, entry] : indexEntries) {
auto indexEntry = entry->ptrCast<catalog::IndexCatalogEntry>();
for (auto& indexEntry : catalog.getIndexEntries(transaction)) {
if (indexEntry->getIndexType() == FTSIndexCatalogEntry::TYPE_NAME) {
indexEntry->setAuxInfo(FTSIndexAuxInfo::deserialize(indexEntry->getAuxBufferReader()));
}
Expand Down
5 changes: 3 additions & 2 deletions src/binder/bind/bind_ddl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "catalog/catalog_entry/rel_group_catalog_entry.h"
#include "catalog/catalog_entry/rel_table_catalog_entry.h"
#include "catalog/catalog_entry/sequence_catalog_entry.h"
#include "common/exception/binder.h"
#include "common/exception/message.h"
#include "common/string_format.h"
Expand Down Expand Up @@ -84,7 +85,7 @@ std::vector<PropertyDefinition> Binder::bindPropertyDefinitions(
if (type.getLogicalTypeID() == LogicalTypeID::SERIAL) {
validateSerialNoDefault(*boundExpr);
expr = ParsedExpressionUtils::getSerialDefaultExpr(
Catalog::genSerialName(tableName, parsedDefinition.getName()));
SequenceCatalogEntry::genSerialName(tableName, parsedDefinition.getName()));
}
auto columnDefinition = ColumnDefinition(parsedDefinition.getName(), std::move(type));
definitions.emplace_back(std::move(columnDefinition), std::move(expr));
Expand Down Expand Up @@ -496,7 +497,7 @@ std::unique_ptr<BoundStatement> Binder::bindAddProperty(const Statement& stateme
if (dataType.getLogicalTypeID() == LogicalTypeID::SERIAL) {
validateSerialNoDefault(*boundDefault);
defaultValue = ParsedExpressionUtils::getSerialDefaultExpr(
Catalog::genSerialName(tableName, propertyName));
SequenceCatalogEntry::genSerialName(tableName, propertyName));
boundDefault = expressionBinder.implicitCastIfNecessary(
expressionBinder.bindExpression(*defaultValue), dataType);
}
Expand Down
69 changes: 19 additions & 50 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,6 @@ std::vector<TableCatalogEntry*> Catalog::getTableEntries(const Transaction* tran
return result;
}

std::vector<TableCatalogEntry*> Catalog::getTableEntries(const Transaction* transaction,
const table_id_vector_t& tableIDs) const {
std::vector<TableCatalogEntry*> result;
for (const auto tableID : tableIDs) {
result.push_back(getTableCatalogEntry(transaction, tableID));
}
return result;
}

bool Catalog::tableInRelGroup(Transaction* transaction, table_id_t tableID) const {
for (const auto& entry : getRelTableGroupEntries(transaction)) {
if (entry->isParent(tableID)) {
Expand All @@ -148,29 +139,7 @@ bool Catalog::tableInRelGroup(Transaction* transaction, table_id_t tableID) cons
return false;
}

table_id_set_t Catalog::getFwdRelTableIDs(Transaction* transaction, table_id_t nodeTableID) const {
KU_ASSERT(getTableCatalogEntry(transaction, nodeTableID)->getTableType() == TableType::NODE);
table_id_set_t result;
for (const auto& relEntry : getRelTableEntries(transaction)) {
if (relEntry->getSrcTableID() == nodeTableID) {
result.insert(relEntry->getTableID());
}
}
return result;
}

table_id_set_t Catalog::getBwdRelTableIDs(Transaction* transaction, table_id_t nodeTableID) const {
KU_ASSERT(getTableCatalogEntry(transaction, nodeTableID)->getTableType() == TableType::NODE);
table_id_set_t result;
for (const auto& relEntry : getRelTableEntries(transaction)) {
if (relEntry->getDstTableID() == nodeTableID) {
result.insert(relEntry->getTableID());
}
}
return result;
}

table_id_t Catalog::createTableSchema(Transaction* transaction, const BoundCreateTableInfo& info) {
table_id_t Catalog::createTableEntry(Transaction* transaction, const BoundCreateTableInfo& info) {
std::unique_ptr<CatalogEntry> entry;
switch (info.type) {
case TableType::NODE: {
Expand All @@ -188,7 +157,7 @@ table_id_t Catalog::createTableSchema(Transaction* transaction, const BoundCreat
const auto tableEntry = entry->constPtrCast<TableCatalogEntry>();
for (auto& definition : tableEntry->getProperties()) {
if (definition.getType().getLogicalTypeID() == LogicalTypeID::SERIAL) {
const auto seqName = genSerialName(tableEntry->getName(), definition.getName());
const auto seqName = SequenceCatalogEntry::genSerialName(tableEntry->getName(), definition.getName());
auto seqInfo =
BoundCreateSequenceInfo(seqName, 0, 1, 0, std::numeric_limits<int64_t>::max(),
false, ConflictAction::ON_CONFLICT_THROW, info.isInternal);
Expand Down Expand Up @@ -224,7 +193,7 @@ void Catalog::dropTableEntry(Transaction* transaction, table_id_t tableID) {
}
for (auto& definition : tableEntry->getProperties()) {
if (definition.getType().getLogicalTypeID() == LogicalTypeID::SERIAL) {
auto seqName = genSerialName(tableEntry->getName(), definition.getName());
auto seqName = SequenceCatalogEntry::genSerialName(tableEntry->getName(), definition.getName());
dropSequence(transaction, seqName);
}
}
Expand Down Expand Up @@ -259,7 +228,7 @@ bool Catalog::containsSequence(const Transaction* transaction,
return sequences->containsEntry(transaction, name);
}

sequence_id_t Catalog::getSequenceID(const Transaction* transaction,
SequenceCatalogEntry* Catalog::getSequenceEntry(const Transaction* transaction,
const std::string& sequenceName, bool useInternalSeq) const {
CatalogEntry* entry = nullptr;
if (!sequences->containsEntry(transaction, sequenceName) && useInternalSeq) {
Expand All @@ -268,10 +237,10 @@ sequence_id_t Catalog::getSequenceID(const Transaction* transaction,
entry = sequences->getEntry(transaction, sequenceName);
}
KU_ASSERT(entry);
return entry->getOID();
return entry->ptrCast<SequenceCatalogEntry>();
}

SequenceCatalogEntry* Catalog::getSequenceCatalogEntry(const Transaction* transaction,
SequenceCatalogEntry* Catalog::getSequenceEntry(const Transaction* transaction,
sequence_id_t sequenceID) const {
auto entry = internalSequences->getEntryOfOID(transaction, sequenceID);
if (entry == nullptr) {
Expand Down Expand Up @@ -302,22 +271,18 @@ sequence_id_t Catalog::createSequence(Transaction* transaction,
}

void Catalog::dropSequence(Transaction* transaction, const std::string& name) {
const auto sequenceID = getSequenceID(transaction, name);
dropSequence(transaction, sequenceID);
const auto entry = getSequenceEntry(transaction, name);
dropSequence(transaction, entry->getOID());
}

void Catalog::dropSequence(Transaction* transaction, sequence_id_t sequenceID) {
const auto sequenceEntry = getSequenceCatalogEntry(transaction, sequenceID);
const auto sequenceEntry = getSequenceEntry(transaction, sequenceID);
CatalogSet* set = nullptr;
set = sequences->containsEntry(transaction, sequenceEntry->getName()) ? sequences.get() :
internalSequences.get();
set->dropEntry(transaction, sequenceEntry->getName(), sequenceEntry->getOID());
}

std::string Catalog::genSerialName(const std::string& tableName, const std::string& propertyName) {
return std::string(tableName).append("_").append(propertyName).append("_").append("serial");
}

void Catalog::createType(Transaction* transaction, std::string name, LogicalType type) {
if (types->containsEntry(transaction, name)) {
return;
Expand Down Expand Up @@ -348,13 +313,17 @@ void Catalog::createIndex(Transaction* transaction,

IndexCatalogEntry* Catalog::getIndex(const Transaction* transaction, table_id_t tableID,
const std::string& indexName) const {
return indexes
->getEntry(transaction, IndexCatalogEntry::getInternalIndexName(tableID, indexName))
->ptrCast<IndexCatalogEntry>();
auto internalName = IndexCatalogEntry::getInternalIndexName(tableID, indexName);
return indexes->getEntry(transaction, internalName)->ptrCast<IndexCatalogEntry>();
}

CatalogSet* Catalog::getIndexes() const {
return indexes.get();
std::vector<IndexCatalogEntry*> Catalog::getIndexEntries(
const Transaction* transaction) const {
std::vector<IndexCatalogEntry*> result;
for (auto& [_, entry] : indexes->getEntries(transaction)) {
result.push_back(entry->ptrCast<IndexCatalogEntry>());
}
return result;
}

bool Catalog::containsIndex(const Transaction* transaction, table_id_t tableID,
Expand Down Expand Up @@ -560,7 +529,7 @@ std::unique_ptr<CatalogEntry> Catalog::createRelTableGroupEntry(Transaction* tra
relTableIDs.reserve(extraInfo->infos.size());
for (auto& childInfo : extraInfo->infos) {
childInfo.hasParent = true;
relTableIDs.push_back(createTableSchema(transaction, childInfo));
relTableIDs.push_back(createTableEntry(transaction, childInfo));
}
return std::make_unique<RelGroupCatalogEntry>(tables.get(), info.tableName,
std::move(relTableIDs));
Expand Down
22 changes: 22 additions & 0 deletions src/catalog/catalog_entry/node_table_catalog_entry.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
#include "catalog/catalog_entry/node_table_catalog_entry.h"

#include "catalog/catalog_entry/rel_table_catalog_entry.h"
#include "binder/ddl/bound_create_table_info.h"
#include "catalog/catalog_set.h"
#include "catalog/catalog.h"
#include "common/serializer/deserializer.h"

using namespace kuzu::binder;

namespace kuzu {
namespace catalog {

common::table_id_set_t NodeTableCatalogEntry::getFwdRelTableIDs(Catalog* catalog, transaction::Transaction* transaction) const {
common::table_id_set_t result;
for (const auto& relEntry : catalog->getRelTableEntries(transaction)) {
if (relEntry->getSrcTableID() == getTableID()) {
result.insert(relEntry->getTableID());
}
}
return result;
}

common::table_id_set_t NodeTableCatalogEntry::getBwdRelTableIDs(Catalog* catalog, transaction::Transaction* transaction) const {
common::table_id_set_t result;
for (const auto& relEntry : catalog->getRelTableEntries(transaction)) {
if (relEntry->getDstTableID() == getTableID()) {
result.insert(relEntry->getTableID());
}
}
return result;
}

void NodeTableCatalogEntry::serialize(common::Serializer& serializer) const {
TableCatalogEntry::serialize(serializer);
serializer.writeDebuggingInfo("primaryKeyName");
Expand Down
8 changes: 2 additions & 6 deletions src/function/sequence/sequence_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ struct CurrVal {
auto ctx = reinterpret_cast<FunctionBindData*>(dataPtr)->clientContext;
auto catalog = ctx->getCatalog();
auto sequenceName = input.getAsString();
auto sequenceID = catalog->getSequenceID(ctx->getTransaction(), sequenceName,
ctx->shouldUseInternalCatalogEntry());
auto sequenceEntry = catalog->getSequenceCatalogEntry(ctx->getTransaction(), sequenceID);
auto sequenceEntry = catalog->getSequenceEntry(ctx->getTransaction(), sequenceName, ctx->shouldUseInternalCatalogEntry());
result.setValue(0, sequenceEntry->currVal());
}
};
Expand All @@ -28,9 +26,7 @@ struct NextVal {
auto cnt = reinterpret_cast<FunctionBindData*>(dataPtr)->count;
auto catalog = ctx->getCatalog();
auto sequenceName = input.getAsString();
auto sequenceID = catalog->getSequenceID(ctx->getTransaction(), sequenceName,
ctx->shouldUseInternalCatalogEntry());
auto sequenceEntry = catalog->getSequenceCatalogEntry(ctx->getTransaction(), sequenceID);
auto sequenceEntry = catalog->getSequenceEntry(ctx->getTransaction(), sequenceName, ctx->shouldUseInternalCatalogEntry());
sequenceEntry->nextKVal(ctx->getTransaction(), cnt, result);
result.state->getSelVectorUnsafe().setSelSize(cnt);
}
Expand Down
10 changes: 6 additions & 4 deletions src/graph/on_disk_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "common/enums/rel_direction.h"
#include "common/types/types.h"
#include "common/vector/value_vector.h"
#include "catalog/catalog_entry/node_table_catalog_entry.h"
#include "expression_evaluator/expression_evaluator.h"
#include "graph/graph.h"
#include "main/client_context.h"
Expand Down Expand Up @@ -129,20 +130,21 @@ OnDiskGraph::OnDiskGraph(ClientContext* context, GraphEntry entry)
auto storage = context->getStorageManager();
auto catalog = context->getCatalog();
auto transaction = context->getTransaction();
for (auto& nodeEntry : graphEntry.nodeEntries) {
auto nodeTableID = nodeEntry->getTableID();
for (auto& entry : graphEntry.nodeEntries) {
auto& nodeEntry = entry->constCast<NodeTableCatalogEntry>();
auto nodeTableID = nodeEntry.getTableID();
nodeIDToNodeTable.insert(
{nodeTableID, storage->getTable(nodeTableID)->ptrCast<NodeTable>()});
table_id_map_t<RelTable*> fwdRelTables;
for (auto& relTableID : catalog->getFwdRelTableIDs(transaction, nodeTableID)) {
for (auto& relTableID : nodeEntry.getFwdRelTableIDs(catalog, transaction)) {
if (!graphEntry.hasRelEntry(relTableID)) {
continue;
}
fwdRelTables.insert({relTableID, storage->getTable(relTableID)->ptrCast<RelTable>()});
}
nodeTableIDToFwdRelTables.insert({nodeTableID, std::move(fwdRelTables)});
table_id_map_t<RelTable*> bwdRelTables;
for (auto& relTableID : catalog->getBwdRelTableIDs(transaction, nodeTableID)) {
for (auto& relTableID : nodeEntry.getBwdRelTableIDs(catalog, transaction)) {
if (!graphEntry.hasRelEntry(relTableID)) {
continue;
}
Expand Down
67 changes: 43 additions & 24 deletions src/include/catalog/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,62 +74,78 @@ class KUZU_API Catalog {
// Get all rel group entries.
std::vector<RelGroupCatalogEntry*> getRelTableGroupEntries(
transaction::Transaction* transaction) const;
// Get all table entries.
std::vector<TableCatalogEntry*> getTableEntries(
const transaction::Transaction* transaction) const;
std::vector<TableCatalogEntry*> getTableEntries(const transaction::Transaction* transaction,
const common::table_id_vector_t& tableIDs) const;
bool tableInRelGroup(transaction::Transaction* transaction, common::table_id_t tableID) const;

// TODO fix me
common::table_id_set_t getFwdRelTableIDs(transaction::Transaction* transaction,
common::table_id_t nodeTableID) const;
common::table_id_set_t getBwdRelTableIDs(transaction::Transaction* transaction,
common::table_id_t nodeTableID) const;

common::table_id_t createTableSchema(transaction::Transaction* transaction,

// Create table entry.
common::table_id_t createTableEntry(transaction::Transaction* transaction,
const binder::BoundCreateTableInfo& info);
// Drop table entry with name.
void dropTableEntry(transaction::Transaction* transaction, const std::string& name);
// Drop table entry with id.
void dropTableEntry(transaction::Transaction* transaction, common::table_id_t tableID);
// Alter table entry.
void alterTableEntry(transaction::Transaction* transaction, const binder::BoundAlterInfo& info);

// ----------------------------- Sequences ----------------------------

// Check if sequence entry exists.
bool containsSequence(const transaction::Transaction* transaction,
const std::string& name) const;

common::sequence_id_t getSequenceID(const transaction::Transaction* transaction,
const std::string& sequenceName, bool useInternalSeq = true) const;
SequenceCatalogEntry* getSequenceCatalogEntry(const transaction::Transaction* transaction,
// Get sequence entry with name.
SequenceCatalogEntry* getSequenceEntry(const transaction::Transaction* transaction,
const std::string& name, bool useInternalSeq = true) const;
// Get sequence entry with id.
SequenceCatalogEntry* getSequenceEntry(const transaction::Transaction* transaction,
common::sequence_id_t sequenceID) const;
// Get all sequence entries.
std::vector<SequenceCatalogEntry*> getSequenceEntries(
const transaction::Transaction* transaction) const;

// Create sequence entry.
common::sequence_id_t createSequence(transaction::Transaction* transaction,
const binder::BoundCreateSequenceInfo& info);
// Drop sequence entry with name.
void dropSequence(transaction::Transaction* transaction, const std::string& name);
// Drop sequence entry with id.
void dropSequence(transaction::Transaction* transaction, common::sequence_id_t sequenceID);

static std::string genSerialName(const std::string& tableName, const std::string& propertyName);

// ----------------------------- Types ----------------------------

// Check if type entry exists.
bool containsType(const transaction::Transaction* transaction,
const std::string& name) const;
// Get type entry with name.
common::LogicalType getType(const transaction::Transaction*, const std::string& name) const;

// Create type entry.
void createType(transaction::Transaction* transaction, std::string name,
common::LogicalType type);
common::LogicalType getType(const transaction::Transaction*, const std::string& name) const;
bool containsType(const transaction::Transaction* transaction,
const std::string& typeName) const;

// ----------------------------- Indexes ----------------------------
void createIndex(transaction::Transaction* transaction,
std::unique_ptr<IndexCatalogEntry> indexCatalogEntry);
IndexCatalogEntry* getIndex(const transaction::Transaction*, common::table_id_t tableID,
const std::string& indexName) const;
CatalogSet* getIndexes() const;

// Check if index entry exists.
bool containsIndex(const transaction::Transaction* transaction, common::table_id_t tableID,
const std::string& indexName) const;
// Get index entry with name.
IndexCatalogEntry* getIndex(const transaction::Transaction* transaction, common::table_id_t tableID,
const std::string& indexName) const;
// Get all index entries.
std::vector<IndexCatalogEntry*> getIndexEntries(const transaction::Transaction* transaction) const;

// Create index entry.
void createIndex(transaction::Transaction* transaction,
std::unique_ptr<IndexCatalogEntry> indexCatalogEntry);
// Drop all index entries within a table.
void dropAllIndexes(transaction::Transaction* transaction, common::table_id_t tableID);
// Drop index entry with name.
void dropIndex(transaction::Transaction* transaction, common::table_id_t tableID,
const std::string& indexName) const;

// ----------------------------- Functions ----------------------------

void addFunction(transaction::Transaction* transaction, CatalogEntryType entryType,
std::string name, function::function_set functionSet);
void dropFunction(transaction::Transaction* transaction, const std::string& name);
Expand All @@ -141,6 +157,9 @@ class KUZU_API Catalog {
std::vector<FunctionCatalogEntry*> getFunctionEntries(
const transaction::Transaction* transaction) const;

// ----------------------------- Macro ----------------------------

// Check if macro entry exists.
bool containsMacro(const transaction::Transaction* transaction,
const std::string& macroName) const;
void addScalarMacroFunction(transaction::Transaction* transaction, std::string name,
Expand Down
Loading

0 comments on commit 6c39667

Please sign in to comment.