# Patch summary (commentary only: these lines precede the first "diff --git"
# header and are not part of any hunk).
#
# src/datachain/data_storage/warehouse.py, dataset_stats:
#   The size aggregate now sums the "file__size" column when the table has
#   one, and only falls back to summing the "size" column otherwise.
#
# tests/func/test_datachain.py:
#   Imports DatasetStats and adds test_from_storage_dataset_stats. The test
#   writes four text files of five characters each ("file0".."file3"), saves
#   them as the dataset "test-data" via DataChain.from_storage(...).save(...),
#   and expects catalog.dataset_stats to equal
#   DatasetStats(num_objects=4, size=20).
#
# NOTE(review): this patch arrived collapsed onto a single line. Line breaks
# and blank lines below were restored from the hunk line counts. Leading
# indentation of the code lines was lost and is assumed from Python
# convention (method body at 8 spaces, test body at 4) -- confirm against the
# repository before applying.
diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py
index 7c396a981..18aa359a6 100644
--- a/src/datachain/data_storage/warehouse.py
+++ b/src/datachain/data_storage/warehouse.py
@@ -390,7 +390,9 @@ def dataset_stats(
         expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
             sa.func.count(table.c.sys__id),
         )
-        if "size" in table.columns:
+        if "file__size" in table.columns:
+            expressions = (*expressions, sa.func.sum(table.c.file__size))
+        elif "size" in table.columns:
             expressions = (*expressions, sa.func.sum(table.c.size))
         query = select(*expressions)
         ((nrows, *rest),) = self.db.execute(query)
diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py
index aabdd57d4..a4356dadd 100644
--- a/tests/func/test_datachain.py
+++ b/tests/func/test_datachain.py
@@ -4,6 +4,7 @@
 import pandas as pd
 import pytest
 
+from datachain.dataset import DatasetStats
 from datachain.lib.dc import DataChain
 from datachain.lib.file import File
 
@@ -205,3 +206,12 @@ def test_show_no_truncate(capsys, catalog):
     for i in range(3):
         assert client[i] in normalized_output
         assert details[i] in normalized_output
+
+
+def test_from_storage_dataset_stats(tmp_dir, catalog):
+    for i in range(4):
+        (tmp_dir / f"file{i}.txt").write_text(f"file{i}")
+
+    dc = DataChain.from_storage(tmp_dir.as_uri(), catalog=catalog).save("test-data")
+    stats = catalog.dataset_stats(dc.name, dc.version)
+    assert stats == DatasetStats(num_objects=4, size=20)