Skip to content

Commit

Permalink
new interpolating motion generation wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
djrakita committed Dec 31, 2023
1 parent fcc8142 commit 34e83d7
Show file tree
Hide file tree
Showing 8 changed files with 172 additions and 225 deletions.
48 changes: 46 additions & 2 deletions optima_refactor/crates/optima_interpolation/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@ pub trait InterpolatorTraitLite<T: AD, V: OVec<T>> {
let ratio = self.max_t() * u;
return self.interpolate(ratio);
}
fn interpolate_points_by_num_points(&self, num_points: usize) -> Vec<V> {
let mut out = vec![];

let ts = get_interpolation_range_num_steps(T::zero(), T::one(), num_points);
for t in ts { out.push(self.interpolate_normalized(t)); }

out
}
fn interpolate_points_by_normalized_stride(&self, stride_length: T) -> Vec<V> {
let mut out = vec![];

println!("{:?}", stride_length);
let ts = get_interpolation_range(T::zero(), T::one(), stride_length);
for t in ts { out.push(self.interpolate_normalized(t)); }

out
}
}
impl<T: AD, V: OVec<T>, U: InterpolatorTraitLite<T, V>> InterpolatorTraitLite<T, V> for Box<U> {
fn interpolate(&self, t: T) -> V {
Expand Down Expand Up @@ -79,6 +96,11 @@ impl<T: AD, V: OVec<T>, I: InterpolatorTrait<T, V>> ArclengthParameterizedInterp

Self { interpolator, arclength_markers, total_arclength: accumulated_distance, phantom_data: PhantomData::default() }
}

pub fn interpolate_points_by_arclength_absolute_stride(&self, arclength_stride: T) -> Vec<V> {
let normalized_stride = arclength_stride / self.total_arclength;
self.interpolate_points_by_normalized_stride(normalized_stride)
}
}
impl<T: AD, V: OVec<T>, I: InterpolatorTrait<T, V>> InterpolatorTraitLite<T, V> for ArclengthParameterizedInterpolator<T, V, I> {
fn interpolate(&self, s: T) -> V {
Expand Down Expand Up @@ -132,6 +154,15 @@ impl<T: AD, V: OVec<T>, I: InterpolatorTrait<T, V>> TimedInterpolator<T, V, I> {
pub fn new(interpolator: I, max_time: T) -> Self {
Self { interpolator, max_time, phantom_data: PhantomData::default() }
}

pub fn interpolate_points_by_time_stride(&self, time_stride: T) -> Vec<V> {
let mut out = vec![];

let ts = get_interpolation_range(T::zero(), self.max_time, time_stride);
for t in ts { out.push(self.interpolate_normalized(t)); }

out
}
}
impl<T: AD, V: OVec<T>, I: InterpolatorTrait<T, V>> InterpolatorTraitLite<T, V> for TimedInterpolator<T, V, I> {
fn interpolate(&self, t: T) -> V {
Expand Down Expand Up @@ -175,11 +206,14 @@ impl<T: AD, V: OVec<T>, SI: InterpolatorTrait<T, V>, TI: InterpolatorTrait<T, V>
*/

pub fn get_interpolation_range<T: AD>(range_start: T, range_stop: T, step_size: T) -> Vec<T> {
assert!(range_stop >= range_start);

let mut out_range = Vec::new();
out_range.push(range_start);
let mut curr_val = range_start;
/*
let mut last_added_val = range_start;
while !( (range_stop - last_added_val).abs() < step_size ) {
while !( (range_stop - last_added_val).abs() <= step_size ) {
if range_stop > range_start {
last_added_val = last_added_val + step_size;
} else {
Expand All @@ -189,6 +223,16 @@ pub fn get_interpolation_range<T: AD>(range_start: T, range_stop: T, step_size:
}
// out_range.push(range_stop);
*/

'l: loop {
out_range.push(curr_val);
curr_val += step_size;
if curr_val > range_stop { break 'l; }
}

let diff = (range_stop - *out_range.last().unwrap()).abs();
if diff > T::constant(0.001) { out_range.push(range_stop) }

out_range
}
Expand Down
1 change: 0 additions & 1 deletion optima_refactor/crates/optima_proximity/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,5 @@ pub mod shapes;
pub mod pair_group_queries;
pub mod shape_scene;
pub mod proxima;
pub mod proxima2;

pub extern crate parry_ad;
47 changes: 43 additions & 4 deletions optima_refactor/crates/optima_proximity/src/pair_group_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1252,7 +1252,7 @@ impl ADConvertableTrait for PairGroupQryArgsCategoryParryDistanceSequenceFilterC

////////////////////////////////////////////////////////////////////////////////////////////////////

// DISTANCE SEQUENCE FILTER //
// EMPTY FILTER //

pub struct EmptyParryFilter;
impl OPairGroupQryTrait for EmptyParryFilter {
Expand All @@ -1275,6 +1275,45 @@ pub type OwnedEmptyParryFilter<'a, T> = OwnedPairGroupQry<'a, T, EmptyParryFilte

////////////////////////////////////////////////////////////////////////////////////////////////////

/*
// ROBOT STATE DEPENDENT FILTER
pub struct ParryRobotStateDependentFilter<FQ>(PhantomData<FQ>) where FQ: OPairGroupQryTrait<ShapeCategory=ShapeCategoryOParryShape, SelectorType=ParryPairSelector, OutputCategory=PairGroupQryOutputCategoryParryFilter>;
impl<FQ> OPairGroupQryTrait for ParryRobotStateDependentFilter<FQ>
where FQ: OPairGroupQryTrait<ShapeCategory=ShapeCategoryOParryShape, SelectorType=ParryPairSelector, OutputCategory=PairGroupQryOutputCategoryParryFilter>
{
type ShapeCategory = ShapeCategoryOParryShape;
type SelectorType = ParryPairSelector;
type ArgsCategory = PairGroupQryArgsCategoryParryRobotStateDependentFilter<FQ>;
type OutputCategory = ParryFilterOutput;
fn query<'a, T: AD, P: O3DPose<T>, S: PairSkipsTrait, A: PairAverageDistanceTrait<T>>(shape_group_a: &Vec<<Self::ShapeCategory as ShapeCategoryTrait>::ShapeType<T, P>>, shape_group_b: &Vec<<Self::ShapeCategory as ShapeCategoryTrait>::ShapeType<T, P>>, poses_a: &Vec<P>, poses_b: &Vec<P>, pair_selector: &Self::SelectorType, pair_skips: &S, pair_average_distances: &A, freeze: bool, args: &<Self::ArgsCategory as PairGroupQryArgsCategory>::Args<'a, T>) -> <Self::OutputCategory as PairGroupQryOutputCategory>::Output<T, P> {
todo!()
}
}
pub struct ParryRobotStateDependentFilterArgs<'a, T: AD, FQ>
where FQ: OPairGroupQryTrait<ShapeCategory=ShapeCategoryOParryShape, SelectorType=ParryPairSelector, OutputCategory=PairGroupQryOutputCategoryParryFilter>
{
curr_state: RwLock<Vec<f64>>,
last_checked_state: RwLock<Vec<f64>>,
linf_cutoff_for_check: f64,
curr_selector: RwLock<ParryPairSelector>,
filter_query: OwnedPairGroupQry<'a, T, FQ>
}
pub struct PairGroupQryArgsCategoryParryRobotStateDependentFilter<FQ>(PhantomData<FQ>)
where FQ: OPairGroupQryTrait<ShapeCategory=ShapeCategoryOParryShape, SelectorType=ParryPairSelector, OutputCategory=PairGroupQryOutputCategoryParryFilter>;
impl<FQ> PairGroupQryArgsCategory for PairGroupQryArgsCategoryParryRobotStateDependentFilter<FQ>
where FQ: OPairGroupQryTrait<ShapeCategory=ShapeCategoryOParryShape, SelectorType=ParryPairSelector, OutputCategory=PairGroupQryOutputCategoryParryFilter>
{
type Args<'a, T: AD> = ParryRobotStateDependentFilterArgs<'a, T, FQ>;
type QueryType = ParryRobotStateDependentFilter<FQ>;
}
*/

////////////////////////////////////////////////////////////////////////////////////////////////////

// PROXIMITY LOSS FUNCTIONS //

/*
Expand Down Expand Up @@ -1464,7 +1503,7 @@ pub(crate) fn parry_generic_pair_group_query<T: AD, P: O3DPose<T>, S: PairSkipsT
let idx0 = *idx_pair0;
let idx1 = *idx_pair1;
let shape_a = &shape_group_a[idx0];
let shape_b = &shape_group_a[idx1];
let shape_b = &shape_group_b[idx1];
let pose_a = &poses_a[idx0];
let pose_b = &poses_b[idx1];

Expand All @@ -1487,12 +1526,12 @@ pub(crate) fn parry_generic_pair_group_query<T: AD, P: O3DPose<T>, S: PairSkipsT
let shape_b_subcomponent_idx = idxs1.1;

let shape_a = &shape_group_a[shape_a_idx];
let shape_b = &shape_group_a[shape_b_idx];
let shape_b = &shape_group_b[shape_b_idx];
let pose_a = &poses_a[shape_a_idx];
let pose_b = &poses_b[shape_b_idx];

let id_a = shape_a.convex_subcomponents[shape_a_subcomponent_idx].id_from_shape_rep(parry_shape_rep1);
let id_b = shape_b.convex_subcomponents[shape_b_subcomponent_idx].id_from_shape_rep(parry_shape_rep1);
let id_b = shape_b.convex_subcomponents[shape_b_subcomponent_idx].id_from_shape_rep(parry_shape_rep2);
if decide_skip_generic(id_a, id_b, pair_skips, for_filter) { continue 'l; }

count += 1;
Expand Down
216 changes: 0 additions & 216 deletions optima_refactor/crates/optima_proximity/src/proxima2.rs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use ad_trait::AD;
use ad_trait::differentiable_block::DifferentiableBlock;
use ad_trait::differentiable_function::{DifferentiableFunctionClass, DifferentiableFunctionTrait};
use optima_3d_spatial::optima_3d_pose::O3DPoseCategory;
use optima_3d_spatial::optima_3d_vec::O3DVec;
use optima_linalg::{OLinalgCategory, OVec};
use optima_optimization::loss_functions::{GrooveLossGaussianDirection, OptimizationLossFunctionTrait, OptimizationLossGroove};
use optima_proximity::pair_group_queries::{OPairGroupQryTrait, OwnedPairGroupQry, ParryPairSelector, ProximityLossFunction, ToParryProximityOutputCategory};
Expand Down
Loading

0 comments on commit 34e83d7

Please sign in to comment.