Skip to content

Commit

Permalink
Clear catalog interface
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Jan 17, 2025
1 parent cba13a1 commit 71c1650
Show file tree
Hide file tree
Showing 20 changed files with 95 additions and 148 deletions.
3 changes: 2 additions & 1 deletion extension/fts/src/catalog/fts_index_catalog_entry.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "catalog/fts_index_catalog_entry.h"

#include "catalog/catalog.h"
#include "catalog/catalog_entry/table_catalog_entry.h"
#include "common/serializer/buffered_reader.h"
#include "common/serializer/buffered_serializer.h"
#include "main/client_context.h"
Expand Down Expand Up @@ -36,7 +37,7 @@ std::string FTSIndexAuxInfo::toCypher(const catalog::IndexCatalogEntry& indexEnt
main::ClientContext* context) {
std::string cypher;
auto catalog = context->getCatalog();
auto tableName = catalog->getTableName(context->getTransaction(), indexEntry.getTableID());
auto tableName = catalog->getTableCatalogEntry(context->getTransaction(), indexEntry.getTableID())->getName();
std::string propertyStr;
for (auto i = 0u; i < properties.size(); i++) {
propertyStr +=
Expand Down
3 changes: 2 additions & 1 deletion extension/fts/src/function/query_fts_bind_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "binder/expression/expression_util.h"
#include "catalog/fts_index_catalog_entry.h"
#include "catalog/catalog_entry/table_catalog_entry.h"
#include "common/string_utils.h"
#include "function/fts_utils.h"
#include "function/stem.h"
Expand Down Expand Up @@ -57,7 +58,7 @@ struct StopWordsChecker {
StopWordsChecker::StopWordsChecker(main::ClientContext& context)
: termsVector{LogicalType::STRING(), context.getMemoryManager()}, tx{context.getTransaction()} {
termsVector.state = common::DataChunkState::getSingleValueDataChunkState();
auto tableID = context.getCatalog()->getTableID(tx, FTSUtils::getStopWordsTableName());
auto tableID = context.getCatalog()->getTableCatalogEntry(tx, FTSUtils::getStopWordsTableName())->getTableID();
stopWordsTable = context.getStorageManager()->getTable(tableID)->ptrCast<storage::NodeTable>();
}

Expand Down
4 changes: 2 additions & 2 deletions src/binder/bind/bind_export_database.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ static void bindExportRelTableDataQuery(const TableCatalogEntry& entry, std::str
auto relTableEntry = entry.constPtrCast<RelTableCatalogEntry>();
auto srcPrimaryKeyName = getPrimaryKeyName(relTableEntry->getSrcTableID(), catalog, tx);
auto dstPrimaryKeyName = getPrimaryKeyName(relTableEntry->getDstTableID(), catalog, tx);
auto srcName = catalog.getTableName(tx, relTableEntry->getSrcTableID());
auto dstName = catalog.getTableName(tx, relTableEntry->getDstTableID());
auto srcName = catalog.getTableCatalogEntry(tx, relTableEntry->getSrcTableID())->getName();
auto dstName = catalog.getTableCatalogEntry(tx, relTableEntry->getDstTableID())->getName();
auto relName = relTableEntry->getName();
exportQuery = stringFormat("match (a:{})-[r:{}]->(b:{}) return a.{},b.{},r.*;", srcName,
relName, dstName, srcPrimaryKeyName, dstPrimaryKeyName);
Expand Down
3 changes: 1 addition & 2 deletions src/binder/bind/bind_file_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ static TableFunction getObjectScanFunc(const std::string& dbName, const std::str
throw BinderException{stringFormat("No database named {} has been attached.", dbName)};
}
auto attachedCatalog = attachedDB->getCatalog();
auto tableID = attachedCatalog->getTableID(clientContext->getTransaction(), tableName);
auto entry = attachedCatalog->getTableCatalogEntry(clientContext->getTransaction(), tableID);
auto entry = attachedCatalog->getTableCatalogEntry(clientContext->getTransaction(), tableName);
return entry->ptrCast<TableCatalogEntry>()->getScanFunction();
}

Expand Down
3 changes: 1 addition & 2 deletions src/binder/bind/copy/bind_copy_from.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& stat
validateTableExist(tableName);
// Bind to table schema.
auto catalog = clientContext->getCatalog();
auto tableID = catalog->getTableID(clientContext->getTransaction(), tableName);
auto tableEntry = catalog->getTableCatalogEntry(clientContext->getTransaction(), tableID);
auto tableEntry = catalog->getTableCatalogEntry(clientContext->getTransaction(), tableName);
switch (tableEntry->getTableType()) {
case TableType::REL_GROUP: {
throw BinderException(stringFormat("Cannot copy into {} table with type {}.", tableName,
Expand Down
35 changes: 15 additions & 20 deletions src/binder/bind_expression/bind_function_expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,20 @@ std::shared_ptr<Expression> ExpressionBinder::bindEndNodeExpression(const Expres
return rel.getDstNode();
}

static std::vector<std::unique_ptr<Value>> populateLabelValues(std::vector<table_id_t> tableIDs,
const catalog::Catalog& catalog, transaction::Transaction* tx) {
auto tableIDsSet = std::unordered_set<table_id_t>(tableIDs.begin(), tableIDs.end());
table_id_t maxTableID = *std::max_element(tableIDsSet.begin(), tableIDsSet.end());
static std::vector<std::unique_ptr<Value>> populateLabelValues(std::vector<TableCatalogEntry*> entries) {
std::unordered_map<table_id_t, std::string> map;
common::table_id_t maxTableID = 0;
for (auto& entry : entries) {
map.insert({entry->getTableID(), entry->getName()});
if (entry->getTableID() > maxTableID) {
maxTableID = entry->getTableID();
}
}
std::vector<std::unique_ptr<Value>> labels;
labels.resize(maxTableID + 1);
for (auto i = 0u; i < labels.size(); ++i) {
if (tableIDsSet.contains(i)) {
labels[i] = std::make_unique<Value>(LogicalType::STRING(), catalog.getTableName(tx, i));
if (map.contains(i)) {
labels[i] = std::make_unique<Value>(LogicalType::STRING(), map.at(i));
} else {
// TODO(Xiyang/Guodong): change to null literal once we support null in LIST type.
labels[i] = std::make_unique<Value>(LogicalType::STRING(), std::string(""));
Expand All @@ -257,7 +262,6 @@ static std::vector<std::unique_ptr<Value>> populateLabelValues(std::vector<table
}

std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression& expression) {
auto catalog = context->getCatalog();
auto listType = LogicalType::LIST(LogicalType::STRING());
expression_vector children;
switch (expression.getDataType().getLogicalTypeID()) {
Expand All @@ -267,16 +271,11 @@ std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression
return createLiteralExpression("");
}
if (!node.isMultiLabeled()) {
auto labelName = catalog->getTableName(context->getTransaction(),
node.getSingleEntry()->getTableID());
auto labelName = node.getSingleEntry()->getName();
return createLiteralExpression(Value(LogicalType::STRING(), labelName));
}
// Internal tables should be invisible to the label function.
auto nodeTableIDs =
catalog->getNodeTableIDs(context->getTransaction(), false /* useInternalTable */);
children.push_back(node.getInternalID());
auto labelsValue = Value(std::move(listType),
populateLabelValues(std::move(nodeTableIDs), *catalog, context->getTransaction()));
auto labelsValue = Value(std::move(listType), populateLabelValues(node.getEntries()));
children.push_back(createLiteralExpression(labelsValue));
} break;
case LogicalTypeID::REL: {
Expand All @@ -285,15 +284,11 @@ std::shared_ptr<Expression> ExpressionBinder::bindLabelFunction(const Expression
return createLiteralExpression("");
}
if (!rel.isMultiLabeled()) {
auto labelName = catalog->getTableName(context->getTransaction(),
rel.getSingleEntry()->getTableID());
auto labelName = rel.getSingleEntry()->getName();
return createLiteralExpression(Value(LogicalType::STRING(), labelName));
}
auto relTableIDs =
catalog->getRelTableIDs(context->getTransaction(), false /* useInternalTable */);
children.push_back(rel.getInternalIDProperty());
auto labelsValue = Value(std::move(listType),
populateLabelValues(std::move(relTableIDs), *catalog, context->getTransaction()));
auto labelsValue = Value(std::move(listType), populateLabelValues(rel.getEntries()));
children.push_back(createLiteralExpression(labelsValue));
} break;
default:
Expand Down
39 changes: 3 additions & 36 deletions src/catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,39 +73,6 @@ bool Catalog::containsTable(const Transaction* transaction, const std::string& t
return contains;
}

table_id_t Catalog::getTableID(const Transaction* transaction, const std::string& tableName,
bool useInternal) const {
return getTableCatalogEntry(transaction, tableName, useInternal)->getTableID();
}

std::vector<table_id_t> Catalog::getNodeTableIDs(const Transaction* transaction,
bool useInternal) const {
std::vector<table_id_t> tableIDs;
tables->iterateEntriesOfType(transaction, CatalogEntryType::NODE_TABLE_ENTRY,
[&](const CatalogEntry* entry) { tableIDs.push_back(entry->getOID()); });
if (useInternal) {
internalTables->iterateEntriesOfType(transaction, CatalogEntryType::NODE_TABLE_ENTRY,
[&](const CatalogEntry* entry) { tableIDs.push_back(entry->getOID()); });
}
return tableIDs;
}

std::vector<table_id_t> Catalog::getRelTableIDs(const Transaction* transaction,
bool useInternal) const {
std::vector<table_id_t> tableIDs;
tables->iterateEntriesOfType(transaction, CatalogEntryType::REL_TABLE_ENTRY,
[&](const CatalogEntry* entry) { tableIDs.push_back(entry->getOID()); });
if (useInternal) {
internalTables->iterateEntriesOfType(transaction, CatalogEntryType::REL_TABLE_ENTRY,
[&](const CatalogEntry* entry) { tableIDs.push_back(entry->getOID()); });
}
return tableIDs;
}

std::string Catalog::getTableName(const Transaction* transaction, table_id_t tableID) const {
return getTableCatalogEntry(transaction, tableID)->getName();
}

TableCatalogEntry* Catalog::getTableCatalogEntry(const Transaction* transaction,
table_id_t tableID) const {
auto result = tables->getEntryOfOID(transaction, tableID);
Expand Down Expand Up @@ -237,7 +204,7 @@ table_id_t Catalog::createTableSchema(Transaction* transaction, const BoundCreat
}

void Catalog::dropTableEntry(Transaction* transaction, const std::string& name) {
const auto tableID = getTableID(transaction, name);
auto tableID = getTableCatalogEntry(transaction, name)->getTableID();
dropAllIndexes(transaction, tableID);
dropTableEntry(transaction, tableID);
}
Expand Down Expand Up @@ -288,8 +255,8 @@ void Catalog::alterTableEntry(Transaction* transaction, const BoundAlterInfo& in
}

bool Catalog::containsSequence(const Transaction* transaction,
const std::string& sequenceName) const {
return sequences->containsEntry(transaction, sequenceName);
const std::string& name) const {
return sequences->containsEntry(transaction, name);
}

sequence_id_t Catalog::getSequenceID(const Transaction* transaction,
Expand Down
3 changes: 2 additions & 1 deletion src/catalog/catalog_entry/hnsw_index_catalog_entry.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "catalog/catalog_entry/hnsw_index_catalog_entry.h"

#include "catalog/catalog.h"
#include "catalog/catalog_entry/table_catalog_entry.h"
#include "common/serializer/buffered_reader.h"
#include "common/serializer/buffered_serializer.h"
#include "common/serializer/deserializer.h"
Expand Down Expand Up @@ -43,7 +44,7 @@ std::string HNSWIndexAuxInfo::toCypher(const IndexCatalogEntry& indexEntry,
main::ClientContext* context) {
std::string cypher;
auto catalog = context->getCatalog();
auto tableName = catalog->getTableName(context->getTransaction(), indexEntry.getTableID());
auto tableName = catalog->getTableCatalogEntry(context->getTransaction(), indexEntry.getTableID())->getName();
auto distFuncName = storage::HNSWIndexConfig::distFuncToString(config.distFunc);
cypher += common::stringFormat("CALL CREATE_HNSW_INDEX('{}', '{}', '{}', mu := {}, ml := {}, "
"pu := {}, distFunc := '{}', alpha := {}, efc := {});\n",
Expand Down
4 changes: 2 additions & 2 deletions src/catalog/catalog_entry/rel_group_catalog_entry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ static std::string getFromToStr(common::table_id_t tableID, ClientContext* conte
auto transaction = context->getTransaction();
auto& entry =
catalog->getTableCatalogEntry(transaction, tableID)->constCast<RelTableCatalogEntry>();
auto srcTableName = catalog->getTableName(transaction, entry.getSrcTableID());
auto dstTableName = catalog->getTableName(transaction, entry.getDstTableID());
auto srcTableName = catalog->getTableCatalogEntry(transaction, entry.getSrcTableID())->getName();
auto dstTableName = catalog->getTableCatalogEntry(transaction, entry.getDstTableID())->getName();
return stringFormat("FROM {} TO {}", srcTableName, dstTableName);
}

Expand Down
5 changes: 3 additions & 2 deletions src/catalog/catalog_entry/rel_table_catalog_entry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ std::unique_ptr<TableCatalogEntry> RelTableCatalogEntry::copy() const {
std::string RelTableCatalogEntry::toCypher(main::ClientContext* clientContext) const {
std::stringstream ss;
auto catalog = clientContext->getCatalog();
auto srcTableName = catalog->getTableName(clientContext->getTransaction(), srcTableID);
auto dstTableName = catalog->getTableName(clientContext->getTransaction(), dstTableID);
auto transaction = clientContext->getTransaction();
auto srcTableName = catalog->getTableCatalogEntry(transaction, srcTableID)->getName();
auto dstTableName = catalog->getTableCatalogEntry(transaction, dstTableID)->getName();
auto srcMultiStr = srcMultiplicity == common::RelMultiplicity::MANY ? "MANY" : "ONE";
auto dstMultiStr = dstMultiplicity == common::RelMultiplicity::MANY ? "MANY" : "ONE";
std::string tableInfo =
Expand Down
31 changes: 16 additions & 15 deletions src/function/table/call/hnsw/create_hnsw_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,25 @@ static offset_t tableFunc(const TableFuncInput& input, TableFuncOutput&) {

static void finalizeFunc(const processor::ExecutionContext* context,
TableFuncSharedState* sharedState) {
auto clientContext = context->clientContext;
auto transaction = clientContext->getTransaction();
const auto hnswSharedState = sharedState->ptrCast<CreateHNSWSharedState>();
hnswSharedState->hnswIndex->shrink(context->clientContext->getTransaction());
hnswSharedState->hnswIndex->finalize(*context->clientContext->getMemoryManager(),
*hnswSharedState->partitionerSharedState);
const auto index = hnswSharedState->hnswIndex.get();
index->shrink(transaction);
index->finalize(*clientContext->getMemoryManager(), *hnswSharedState->partitionerSharedState);

const auto bindData = hnswSharedState->bindData->constPtrCast<CreateHNSWIndexBindData>();
const auto catalog = context->clientContext->getCatalog();
auto upperRelTableID = catalog->getTableID(context->clientContext->getTransaction(),
storage::HNSWIndexUtils::getUpperGraphTableName(bindData->indexName));
auto lowerRelTableID = catalog->getTableID(context->clientContext->getTransaction(),
storage::HNSWIndexUtils::getLowerGraphTableName(bindData->indexName));
catalog->createIndex(context->clientContext->getTransaction(),
std::make_unique<catalog::IndexCatalogEntry>(catalog::HNSWIndexCatalogEntry::TYPE_NAME,
bindData->tableEntry->getTableID(), bindData->indexName,
std::make_unique<catalog::HNSWIndexAuxInfo>(upperRelTableID, lowerRelTableID,
hnswSharedState->nodeTable.getColumn(bindData->columnID).getName(),
hnswSharedState->hnswIndex->getUpperEntryPoint(),
hnswSharedState->hnswIndex->getLowerEntryPoint(), bindData->config.copy())));
const auto catalog = clientContext->getCatalog();
auto upperRelTableID = catalog->getTableCatalogEntry(transaction,
storage::HNSWIndexUtils::getUpperGraphTableName(bindData->indexName))->getTableID();
auto lowerRelTableID = catalog->getTableCatalogEntry(transaction,
storage::HNSWIndexUtils::getLowerGraphTableName(bindData->indexName))->getTableID();
auto auxInfo = std::make_unique<catalog::HNSWIndexAuxInfo>(upperRelTableID, lowerRelTableID,
hnswSharedState->nodeTable.getColumn(bindData->columnID).getName(),
index->getUpperEntryPoint(), index->getLowerEntryPoint(), bindData->config.copy());
auto indexEntry = std::make_unique<catalog::IndexCatalogEntry>(catalog::HNSWIndexCatalogEntry::TYPE_NAME,
bindData->tableEntry->getTableID(), bindData->indexName, std::move(auxInfo));
catalog->createIndex(transaction,std::move(indexEntry));
}

static std::string createHNSWIndexTables(main::ClientContext& context,
Expand Down
9 changes: 3 additions & 6 deletions src/function/table/call/show_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ static void outputRelTableConnection(const DataChunk& outputDataChunk, uint64_t
// Get src and dst name
const auto srcTableID = relTableEntry->getSrcTableID();
const auto dstTableID = relTableEntry->getDstTableID();
const auto srcTableName = catalog->getTableName(context->getTransaction(), srcTableID);
const auto dstTableName = catalog->getTableName(context->getTransaction(), dstTableID);
// Get src and dst primary key
const auto srcTableEntry = catalog->getTableCatalogEntry(context->getTransaction(), srcTableID);
const auto dstTableEntry = catalog->getTableCatalogEntry(context->getTransaction(), dstTableID);
Expand All @@ -47,8 +45,8 @@ static void outputRelTableConnection(const DataChunk& outputDataChunk, uint64_t
const auto dstTablePrimaryKey =
dstTableEntry->constCast<NodeTableCatalogEntry>().getPrimaryKeyName();
// Write result to dataChunk
outputDataChunk.getValueVectorMutable(0).setValue(outputPos, srcTableName);
outputDataChunk.getValueVectorMutable(1).setValue(outputPos, dstTableName);
outputDataChunk.getValueVectorMutable(0).setValue(outputPos, srcTableEntry->getName());
outputDataChunk.getValueVectorMutable(1).setValue(outputPos, dstTableEntry->getName());
outputDataChunk.getValueVectorMutable(2).setValue(outputPos, srcTablePrimaryKey);
outputDataChunk.getValueVectorMutable(3).setValue(outputPos, dstTablePrimaryKey);
}
Expand Down Expand Up @@ -90,8 +88,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(const ClientContext* context,
std::vector<LogicalType> columnTypes;
const auto tableName = input->getLiteralVal<std::string>(0);
const auto catalog = context->getCatalog();
const auto tableID = catalog->getTableID(context->getTransaction(), tableName);
auto tableEntry = catalog->getTableCatalogEntry(context->getTransaction(), tableID);
auto tableEntry = catalog->getTableCatalogEntry(context->getTransaction(), tableName);
const auto tableType = tableEntry->getTableType();
if (tableType != TableType::REL && tableType != TableType::REL_GROUP) {
throw BinderException{"Show connection can only be called on a rel table!"};
Expand Down
5 changes: 2 additions & 3 deletions src/function/table/call/stats_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(const ClientContext* context,
if (!catalog->containsTable(context->getTransaction(), tableName)) {
throw BinderException{"Table " + tableName + " does not exist!"};
}
const auto tableID = catalog->getTableID(context->getTransaction(), tableName);
auto tableEntry = catalog->getTableCatalogEntry(context->getTransaction(), tableID);
auto tableEntry = catalog->getTableCatalogEntry(context->getTransaction(), tableName);
if (tableEntry->getTableType() != TableType::NODE) {
throw BinderException{
"Stats from a non-node table " + tableName + " is not supported yet!"};
Expand All @@ -80,7 +79,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(const ClientContext* context,
columnTypes.push_back(LogicalType::INT64());
}
const auto storageManager = context->getStorageManager();
auto table = storageManager->getTable(tableID);
auto table = storageManager->getTable(tableEntry->getTableID());
auto columns = input->binder->createVariables(columnNames, columnTypes);
return std::make_unique<StatsInfoBindData>(columns, tableEntry, table, context);
}
Expand Down
5 changes: 2 additions & 3 deletions src/function/table/call/storage_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,9 @@ static std::unique_ptr<TableFuncBindData> bindFunc(const ClientContext* context,
if (!catalog->containsTable(context->getTransaction(), tableName)) {
throw BinderException{"Table " + tableName + " does not exist!"};
}
auto tableID = catalog->getTableID(context->getTransaction(), tableName);
auto tableEntry = catalog->getTableCatalogEntry(context->getTransaction(), tableID);
auto tableEntry = catalog->getTableCatalogEntry(context->getTransaction(), tableName);
auto storageManager = context->getStorageManager();
auto table = storageManager->getTable(tableID);
auto table = storageManager->getTable(tableEntry->getTableID());
auto columns = input->binder->createVariables(columnNames, columnTypes);
return std::make_unique<StorageInfoBindData>(columns, tableEntry, table, context);
}
Expand Down
Loading

0 comments on commit 71c1650

Please sign in to comment.