Skip to content

Commit

Permalink
Add an iterator on the sampling attached data
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris00 committed Apr 2, 2024
1 parent b70050a commit 0d3d2ca
Showing 1 changed file with 53 additions and 7 deletions.
60 changes: 53 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,18 @@ impl<D> Sampling<D> {
/// by ... `[f64::NAN; 2]`, `p`, `None`,...
pub fn iter(&self) -> SamplingIter<'_, D> {
SamplingIter {
iter: SamplingIterData {
path: self.path.iter(),
prev_is_cut: true,
guess_len: self.guess_len.get(),
}
}
}

/// Same as [`Self::iter`] but also provides access to the data of
/// each point.
pub fn iter_data(&self) -> SamplingIterData<'_, D> {
SamplingIterData {
path: self.path.iter(),
prev_is_cut: true,
guess_len: self.guess_len.get(),
Expand Down Expand Up @@ -298,17 +310,18 @@ impl<D> Sampling<D> {
}
}

/// Iterator on the curve points (and cuts).

/// Iterator on the curve points (and cuts) together with the attached data.
///
/// Created by [`Sampling::iter`].
pub struct SamplingIter<'a, D> {
pub struct SamplingIterData<'a, D> {
path: list::Iter<'a, Point<D>>,
prev_is_cut: bool,
guess_len: usize,
}

impl<'a, D> Iterator for SamplingIter<'a, D> {
type Item = [f64; 2];
impl<'a, D> Iterator for SamplingIterData<'a, D> {
type Item = ([f64; 2], &'a D);

fn next(&mut self) -> Option<Self::Item> {
match self.path.next() {
Expand All @@ -317,7 +330,7 @@ impl<'a, D> Iterator for SamplingIter<'a, D> {
self.guess_len -= 1;
if p.is_valid() {
self.prev_is_cut = false;
Some(p.xy)
Some((p.xy, &p.data))
} else if self.prev_is_cut {
// Find the next valid point.
let r = self.path.try_fold(0, |n, p| {
Expand All @@ -336,12 +349,12 @@ impl<'a, D> Iterator for SamplingIter<'a, D> {
ControlFlow::Break((n, p)) => {
self.guess_len -= n;
self.prev_is_cut = false;
Some(p.xy)
Some((p.xy, &p.data))
}
}
} else {
self.prev_is_cut = true;
Some([f64::NAN; 2])
Some(([f64::NAN; 2], &p.data))
}
}
}
Expand All @@ -352,6 +365,27 @@ impl<'a, D> Iterator for SamplingIter<'a, D> {
}
}

/// Iterator on the curve points (and cuts).
///
/// Created by [`Sampling::iter`].
pub struct SamplingIter<'a, D> {
iter: SamplingIterData<'a, D>,
}

impl<'a, D> Iterator for SamplingIter<'a, D> {
type Item = [f64; 2];

#[inline]
fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|(xy, _)| xy)
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.iter.size_hint()
}
}


/// Mutable iterator on the curve points (and cuts).
///
/// Created by [`Sampling::iter_mut`].
Expand Down Expand Up @@ -1771,6 +1805,18 @@ mod tests {
Some((4., 0.)), None]);
}

#[test]
fn iter_data() {
let s = Sampling::uniform(|x| (x, x as i32), 0., 4.).n(3)
.init(&[1.]).init_pt([(3., (0., -1))]).build();
let expected = vec![
([0.,0.], &0), ([1.,1.], &1), ([2.,2.], &2), ([3., 0.], &-1),
([4.,4.], &4)];
for (i, d) in s.iter_data().enumerate() {
assert_eq!(d, expected[i]);
}
}

/// In order the judge the quality of the sampling, we save it
/// with the internal cost data.
fn write_with_point_costs<D>(
Expand Down

0 comments on commit 0d3d2ca

Please sign in to comment.