feat: support merge fragment with dataset (#3256)
This PR adds a fragment-level merge API so that a dataset can be merged concurrently, fragment by fragment (see the sketch below).
chenkovsky authored Dec 23, 2024
1 parent c40164b commit ae70478
Showing 6 changed files with 236 additions and 3 deletions.
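A minimal sketch of the concurrent usage this enables, reusing the toy dataset from the merge docstring below. The thread pool is illustrative and not part of this diff; only Fragment.merge, LanceOperation.Merge, and LanceDataset.commit come from Lance:

```python
from concurrent.futures import ThreadPoolExecutor

import lance
import pyarrow as pa

dataset = lance.write_dataset(
    pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]}), "dataset"
)
new_df = pa.table({"x": [1, 2, 3], "z": ["d", "e", "f"]})

# Each fragment is joined against new_df independently, so the merges
# can run in parallel; only the final commit is serialized.
with ThreadPoolExecutor() as pool:
    results = list(pool.map(lambda f: f.merge(new_df, "x"), dataset.get_fragments()))

merged = [fragment for fragment, _ in results]
schema = results[0][1]
dataset = lance.LanceDataset.commit(
    "dataset",
    lance.LanceOperation.Merge(merged, schema),
    read_version=dataset.version,
)
```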
7 changes: 7 additions & 0 deletions python/python/lance/dataset.py
@@ -490,6 +490,13 @@ def data_storage_version(self) -> str:
"""
return self._ds.data_storage_version

@property
def max_field_id(self) -> int:
"""
The maximum field id assigned in the manifest.
"""
return self._ds.max_field_id
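For illustration, the new property simply surfaces the manifest value that Fragment.merge passes through to the Rust side (dataset path hypothetical):

```python
import lance

ds = lance.dataset("dataset")  # hypothetical path
print(ds.max_field_id)  # largest field id currently assigned in the manifest
```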

def to_table(
self,
columns: Optional[Union[List[str], Dict[str, str]]] = None,
75 changes: 75 additions & 0 deletions python/python/lance/fragment.py
Expand Up @@ -35,6 +35,7 @@
_write_fragments,
)
from .progress import FragmentWriteProgress, NoopFragmentWriteProgress
from .types import _coerce_reader
from .udf import BatchUDF, normalize_transform

if TYPE_CHECKING:
@@ -406,6 +407,7 @@ def scanner(
limit: Optional[int] = None,
offset: Optional[int] = None,
with_row_id: bool = False,
with_row_address: bool = False,
batch_readahead: int = 16,
) -> "LanceScanner":
"""See Dataset::scanner for details"""
@@ -424,6 +426,7 @@
limit=limit,
offset=offset,
with_row_id=with_row_id,
with_row_address=with_row_address,
batch_readahead=batch_readahead,
**columns_arg,
)
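A short sketch, adapted from test_fragment_merge in this PR: join precomputed values back to a fragment by its physical row address (`fragment` is assumed to be a LanceFragment):

```python
# Scan only the row addresses, compute a new column positionally, then
# merge it back using the _rowaddr key.
table = fragment.scanner(with_row_address=True, columns=[]).to_table()
table = table.add_column(0, "c", [list(range(len(table)))])
new_fragment, new_schema = fragment.merge(table, "_rowaddr")
```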
@@ -475,6 +478,78 @@ def to_table(
with_row_id=with_row_id,
).to_table()

def merge(
self,
data_obj: ReaderLike,
left_on: str,
right_on: Optional[str] = None,
schema=None,
) -> Tuple[FragmentMetadata, LanceSchema]:
"""
Merge another dataset into this fragment.

Performs a left join, where the fragment is the left side and data_obj
is the right side. Rows present in the fragment but missing from
data_obj will be filled with null values, unless Lance doesn't support
null values for some types, in which case an error will be raised.

Parameters
----------
data_obj: Reader-like
The data to be merged. Acceptable types are:
- Pandas DataFrame, Pyarrow Table, Dataset, Scanner,
Iterator[RecordBatch], or RecordBatchReader
left_on: str
The name of the column in the dataset to join on.
right_on: str or None
The name of the column in data_obj to join on. If None, defaults to
left_on.
schema: pa.Schema, optional
The schema of data_obj; only needed when it cannot be inferred from
data_obj itself.

Examples
--------
>>> import lance
>>> import pyarrow as pa
>>> df = pa.table({'x': [1, 2, 3], 'y': ['a', 'b', 'c']})
>>> dataset = lance.write_dataset(df, "dataset")
>>> dataset.to_table().to_pandas()
x y
0 1 a
1 2 b
2 3 c
>>> fragments = dataset.get_fragments()
>>> new_df = pa.table({'x': [1, 2, 3], 'z': ['d', 'e', 'f']})
>>> merged = []
>>> schema = None
>>> for f in fragments:
... f, schema = f.merge(new_df, 'x')
... merged.append(f)
>>> merge = lance.LanceOperation.Merge(merged, schema)
>>> dataset = lance.LanceDataset.commit("dataset", merge, read_version=1)
>>> dataset.to_table().to_pandas()
x y z
0 1 a d
1 2 b e
2 3 c f

Returns
-------
Tuple[FragmentMetadata, LanceSchema]
A new fragment with the merged column(s) and the final schema.

See Also
--------
LanceDataset.merge_columns :
Add new columns, computed from existing data, to the dataset.
"""
if right_on is None:
right_on = left_on

reader = _coerce_reader(data_obj, schema)
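# The manifest's current max field id is forwarded so the merged columns
# receive fresh, non-conflicting field ids.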
max_field_id = self._ds.max_field_id
metadata, schema = self._fragment.merge(reader, left_on, right_on, max_field_id)
return metadata, schema

def merge_columns(
self,
value_func: Dict[str, str]
61 changes: 61 additions & 0 deletions python/python/tests/test_fragment.py
@@ -361,3 +361,64 @@ def test_create_from_file(tmp_path):
assert dataset.count_rows() == 1600
assert len(dataset.get_fragments()) == 1
assert dataset.get_fragments()[0].fragment_id == 2


def test_fragment_merge(tmp_path):
schema = pa.schema([pa.field("a", pa.string())])
batches = pa.RecordBatchReader.from_batches(
schema,
[
pa.record_batch([pa.array(["0" * 1024] * 1024 * 8)], names=["a"]),
pa.record_batch([pa.array(["0" * 1024] * 1024 * 8)], names=["a"]),
],
)

progress = ProgressForTest()
fragments = write_fragments(
batches,
tmp_path,
max_rows_per_group=512,
max_bytes_per_file=1024,
progress=progress,
)

operation = lance.LanceOperation.Overwrite(schema, fragments)
dataset = lance.LanceDataset.commit(tmp_path, operation)
merged = []
schema = None
for fragment in dataset.get_fragments():
table = fragment.scanner(with_row_id=True, columns=[]).to_table()
table = table.add_column(0, "b", [[i for i in range(len(table))]])
fragment, schema = fragment.merge(table, "_rowid")
merged.append(fragment)

merge = lance.LanceOperation.Merge(merged, schema)
dataset = lance.LanceDataset.commit(
tmp_path, merge, read_version=dataset.latest_version
)

merged = []
schema = None
for fragment in dataset.get_fragments():
table = fragment.scanner(with_row_address=True, columns=[]).to_table()
table = table.add_column(0, "c", [[i + 1 for i in range(len(table))]])
fragment, schema = fragment.merge(table, "_rowaddr")
merged.append(fragment)

merge = lance.LanceOperation.Merge(merged, schema)
dataset = lance.LanceDataset.commit(
tmp_path, merge, read_version=dataset.latest_version
)

merged = []
for fragment in dataset.get_fragments():
table = fragment.scanner(columns=["b"]).to_table()
table = table.add_column(0, "d", [[i + 2 for i in range(len(table))]])
fragment, schema = fragment.merge(table, "b")
merged.append(fragment)

merge = lance.LanceOperation.Merge(merged, schema)
dataset = lance.LanceDataset.commit(
tmp_path, merge, read_version=dataset.latest_version
)
assert [f.name for f in dataset.schema] == ["a", "b", "c", "d"]
5 changes: 5 additions & 0 deletions python/src/dataset.rs
@@ -332,6 +332,11 @@ impl Dataset {
self.clone()
}

#[getter(max_field_id)]
fn max_field_id(self_: PyRef<'_, Self>) -> PyResult<i32> {
Ok(self_.ds.manifest().max_field_id())
}

#[getter(schema)]
fn schema(self_: PyRef<'_, Self>) -> PyResult<PyObject> {
let arrow_schema = ArrowSchema::from(self_.ds.schema());
27 changes: 25 additions & 2 deletions python/src/fragment.rs
@@ -16,7 +16,7 @@ use std::fmt::Write as _;
use std::sync::Arc;

use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
use arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow};
use arrow_array::RecordBatchReader;
use arrow_schema::Schema as ArrowSchema;
use futures::TryFutureExt;
@@ -163,7 +163,7 @@ impl FileFragment {
}

#[allow(clippy::too_many_arguments)]
#[pyo3(signature=(columns=None, columns_with_transform=None, batch_size=None, filter=None, limit=None, offset=None, with_row_id=None, batch_readahead=None))]
#[pyo3(signature=(columns=None, columns_with_transform=None, batch_size=None, filter=None, limit=None, offset=None, with_row_id=None, with_row_address=None, batch_readahead=None))]
fn scanner(
self_: PyRef<'_, Self>,
columns: Option<Vec<String>>,
@@ -173,6 +173,7 @@
limit: Option<i64>,
offset: Option<i64>,
with_row_id: Option<bool>,
with_row_address: Option<bool>,
batch_readahead: Option<usize>,
) -> PyResult<Scanner> {
let mut scanner = self_.fragment.scan();
@@ -212,6 +213,9 @@
if with_row_id.unwrap_or(false) {
scanner.with_row_id();
}
if with_row_address.unwrap_or(false) {
scanner.with_row_address();
}
if let Some(batch_readahead) = batch_readahead {
scanner.batch_readahead(batch_readahead);
}
@@ -261,6 +265,25 @@
Ok((PyLance(fragment), LanceSchema(schema)))
}

fn merge(
&mut self,
reader: PyArrowType<ArrowArrayStreamReader>,
left_on: String,
right_on: String,
max_field_id: i32,
) -> PyResult<(PyLance<Fragment>, LanceSchema)> {
let mut fragment = self.fragment.clone();
let (fragment, schema) = RT
.spawn(None, async move {
fragment
.merge_columns(reader.0, &left_on, &right_on, max_field_id)
.await
})?
.infer_error()?;

Ok((PyLance(fragment), LanceSchema(schema)))
}

fn delete(&self, predicate: &str) -> PyResult<Option<Self>> {
let old_fragment = self.fragment.clone();
let updated_fragment = RT
64 changes: 63 additions & 1 deletion rust/lance/src/dataset/fragment.rs
@@ -12,7 +12,9 @@ use std::sync::Arc;

use arrow::compute::concat_batches;
use arrow_array::cast::as_primitive_array;
use arrow_array::{new_null_array, RecordBatch, StructArray, UInt32Array, UInt64Array};
use arrow_array::{
new_null_array, RecordBatch, RecordBatchReader, StructArray, UInt32Array, UInt64Array,
};
use arrow_schema::Schema as ArrowSchema;
use datafusion::logical_expr::Expr;
use datafusion::scalar::ScalarValue;
@@ -1331,6 +1333,66 @@ impl FileFragment {
Updater::try_new(self.clone(), reader, deletion_vector, schemas, batch_size)
}

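/// Merge the columns produced by `stream` into this fragment.
///
/// Performs a left join of this fragment's `left_on` column against the
/// stream's `right_on` column. `max_field_id` is the dataset's current
/// maximum field id, used as the starting point for assigning ids to the
/// newly added fields.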
pub async fn merge_columns(
&mut self,
stream: impl RecordBatchReader + Send + 'static,
left_on: &str,
right_on: &str,
max_field_id: i32,
) -> Result<(Fragment, Schema)> {
let stream = Box::new(stream);
if self.schema().field(left_on).is_none() && left_on != ROW_ID && left_on != ROW_ADDR {
return Err(Error::invalid_input(
format!(
"Column {} does not exist in the left side fragment",
left_on
),
location!(),
));
};
let right_schema = stream.schema();
if right_schema.field_with_name(right_on).is_err() {
return Err(Error::invalid_input(
format!(
"Column {} does not exist in the right side fragment",
right_on
),
location!(),
));
};

for field in right_schema.fields() {
if field.name() == right_on {
// right_on is allowed to exist in the dataset, since it may be
// the same as left_on.
continue;
}
if self.schema().field(field.name()).is_some() {
return Err(Error::invalid_input(
format!(
"Column {} exists in left side fragment and right side dataset",
field.name()
),
location!(),
));
}
}
// Hash join
let joiner = Arc::new(HashJoiner::try_new(stream, right_on).await?);
// Final schema is union of current schema, plus the RHS schema without
// the right_on key.
let mut new_schema: Schema = self.schema().merge(joiner.out_schema().as_ref())?;
new_schema.set_field_id(Some(max_field_id));

let new_fragment = self
.clone()
.merge(left_on, &joiner)
.await
.map(|f| f.metadata)?;

Ok((new_fragment, new_schema))
}

pub(crate) async fn merge(mut self, join_column: &str, joiner: &HashJoiner) -> Result<Self> {
let mut updater = self.updater(Some(&[join_column]), None, None).await?;

