From 3281b218cccb2bba746099624344a12a689cd5f1 Mon Sep 17 00:00:00 2001 From: Danny Rakita Date: Fri, 22 Dec 2023 00:06:15 -0500 Subject: [PATCH] made proxima faster --- .../src/pair_group_queries.rs | 11 +++- .../optima_proximity/src/pair_queries.rs | 3 + .../crates/optima_proximity/src/proxima.rs | 62 ++++++++++--------- optima_refactor/src/bin/test7.rs | 42 +++++++++++-- optima_refactor/src/bin/test8.rs | 8 +++ 5 files changed, 93 insertions(+), 33 deletions(-) create mode 100644 optima_refactor/src/bin/test8.rs diff --git a/optima_refactor/crates/optima_proximity/src/pair_group_queries.rs b/optima_refactor/crates/optima_proximity/src/pair_group_queries.rs index b6dc3ee..49275b1 100644 --- a/optima_refactor/crates/optima_proximity/src/pair_group_queries.rs +++ b/optima_refactor/crates/optima_proximity/src/pair_group_queries.rs @@ -201,6 +201,14 @@ pub enum ParryPairSelector { HalfPairsSubcomponents, PairsByIdxs(Vec) } +impl ParryPairSelector { + pub fn len(&self) -> usize { + match self { + ParryPairSelector::PairsByIdxs(v) => { v.len() } + _ => { usize::MAX } + } + } +} pub struct ParryPairGroupOutputWrapper { data: O, @@ -889,10 +897,11 @@ impl OPairGroupQryTrait for ParryDistanceGroupFilter { let f = | output: &Box> | -> Vec { let mut a = vec![]; output.outputs.iter().for_each(|x| { - if x.data.raw_distance < args.distance_threshold { + if x.data.distance() < args.distance_threshold { a.push(x.pair_idxs.clone()); } }); + // println!("{:?}, {:?}", output.outputs.len(), a.len()); a }; diff --git a/optima_refactor/crates/optima_proximity/src/pair_queries.rs b/optima_refactor/crates/optima_proximity/src/pair_queries.rs index 1c33020..a88dc98 100644 --- a/optima_refactor/crates/optima_proximity/src/pair_queries.rs +++ b/optima_refactor/crates/optima_proximity/src/pair_queries.rs @@ -226,6 +226,7 @@ impl> OPairQryTrait for ParryProximaDistanceUpperBoun type Args<'a> = ParryProximaDistanceUpperBoundArgs<'a, T, P, >::Native3DVecType>; type Output = ParryProximaDistanceUpperBoundOutput; + #[inline(always)] fn query<'a>(_shape_a: &Self::ShapeTypeA, _shape_b: &Self::ShapeTypeB, pose_a: &P, pose_b: &P, args: &Self::Args<'a>) -> Self::Output { let start = Instant::now(); // let shapes = get_shapes_from_parry_qry_shape_type_and_parry_shape_rep(shape_a, shape_b, &args.0, &args.1); @@ -291,6 +292,7 @@ impl> OPairQryTrait for ParryProximaDistanceLowerBoun type Args<'a> = ParryProximaDistanceLowerBoundArgs<'a, T, P>; type Output = ParryProximaDistanceLowerBoundOutput; + #[inline(always)] fn query<'a>(shape_a: &Self::ShapeTypeA, shape_b: &Self::ShapeTypeB, pose_a: &P, pose_b: &P, args: &Self::Args<'a>) -> Self::Output { let start = Instant::now(); let shapes = get_shapes_from_parry_qry_shape_type_and_parry_shape_rep(shape_a, shape_b, &args.parry_qry_shape_type, &args.parry_shape_rep); @@ -337,6 +339,7 @@ impl> OPairQryTrait for ParryProximaDistanceBoundsQry type Args<'a> = ParryProximaDistanceBoundsArgs<'a, T, P, >::Native3DVecType>; type Output = ParryProximaDistanceBoundsOutputOption; + #[inline(always)] fn query<'a>(shape_a: &Self::ShapeTypeA, shape_b: &Self::ShapeTypeB, pose_a: &P, pose_b: &P, args: &Self::Args<'a>) -> Self::Output { let start = Instant::now(); diff --git a/optima_refactor/crates/optima_proximity/src/proxima.rs b/optima_refactor/crates/optima_proximity/src/proxima.rs index 1662230..2962d35 100644 --- a/optima_refactor/crates/optima_proximity/src/proxima.rs +++ b/optima_refactor/crates/optima_proximity/src/proxima.rs @@ -110,16 +110,19 @@ impl OPairGroupQryTrait for ParryProximaQry { proximity_lower_bound_sum += lower_bound_through_loss_and_powf; proximity_upper_bound_sum += upper_bound_through_loss_and_powf; // &(T, T, T, &P, (u64, u64), &ParryPairIdxs) - let idx = v.binary_search_by(|x: &(T, T, T, &P, (u64, u64), &ParryPairIdxs)| diff.partial_cmp(&x.0).unwrap()); - let idx = match idx { - Ok(idx) => { idx } - Err(idx) => { idx } - }; - v.insert(idx, (diff, lower_bound_through_loss_and_powf, upper_bound_through_loss_and_powf, &data.displacement_between_a_and_b_k, x.pair_ids(), x.pair_idxs())); + // let idx = v.binary_search_by(|x: &(T, T, T, &P, (u64, u64), &ParryPairIdxs)| diff.partial_cmp(&x.0).expect(&format!("{:?}, {:?}", diff, x.0))); + // let idx = match idx { + // Ok(idx) => { idx } + // Err(idx) => { idx } + // }; + // v.insert(idx, (diff, lower_bound_through_loss_and_powf, upper_bound_through_loss_and_powf, &data.displacement_between_a_and_b_k, x.pair_ids(), x.pair_idxs())); + v.push((diff, lower_bound_through_loss_and_powf, upper_bound_through_loss_and_powf, &data.displacement_between_a_and_b_k, x.pair_ids(), x.pair_idxs())); } } }); + let mut indices = (0..v.len()).collect::>(); + indices.sort_by(|x, y| v[*x].0.partial_cmp(&v[*y].0).unwrap()); // v.sort_by(|x, y| { y.0.partial_cmp(&x.0).expect(&format!("y.1 {:?}, y.0 {:?}, x.1 {:?}, x.0 {:?}", y.1, y.0, x.1, x.0)) }); let mut proximity_lower_bound_output; @@ -128,6 +131,7 @@ impl OPairGroupQryTrait for ParryProximaQry { let mut num_queries = 0; + let mut idx = indices.len(); 'l: loop { proximity_lower_bound_output = proximity_lower_bound_sum.powf(args.p_norm.recip()); proximity_upper_bound_output = proximity_upper_bound_sum.powf(args.p_norm.recip()); @@ -135,6 +139,8 @@ impl OPairGroupQryTrait for ParryProximaQry { max_possible_error = (proximity_upper_bound_output - proximity_lower_bound_output).abs(); if num_queries >= v.len() { break 'l; } + if idx == 0 { break 'l; } + idx -= 1; let mut terminate = false; match &args.termination { @@ -142,13 +148,13 @@ impl OPairGroupQryTrait for ParryProximaQry { if start.elapsed() > *t { terminate = true; } } ProximaTermination::MaxError(e) => { - if max_possible_error.to_constant() < * e { terminate = true; } + if max_possible_error.to_constant() < *e { terminate = true; } } } if terminate { break 'l; } - let curr_entry = &v[num_queries]; + let curr_entry = &v[idx]; let lower_bound_through_loss_and_powf = curr_entry.1; let upper_bound_through_loss_and_powf = curr_entry.2; let displacement_between_a_and_b_k = curr_entry.3; @@ -186,27 +192,27 @@ impl OPairGroupQryTrait for ParryProximaQry { block.displacement_between_a_and_b_j_staging = displacement_between_a_and_b_k.o3dpose_downcast_or_convert::>().as_ref().o3dpose_to_constant_ads(); /* - match block { - None => { - binding.hashmap.insert(ids, ProximaGenericBlock { - pose_a_j: poses.0.o3dpose_downcast_or_convert::>().as_ref().clone(), - pose_b_j: poses.1.o3dpose_downcast_or_convert::>().as_ref().clone(), - closest_point_a_j: contact.point1.o3dvec_downcast_or_convert::>().as_ref().clone(), - closest_point_b_j: contact.point2.o3dvec_downcast_or_convert::>().as_ref().clone(), - raw_distance_j: *raw_distance, - displacement_between_a_and_b_j: displacement_between_a_and_b_k.o3dpose_downcast_or_convert::>().as_ref().clone(), - }); - } - Some(block) => { - block.displacement_between_a_and_b_j = displacement_between_a_and_b_k.o3dpose_downcast_or_convert::>().as_ref().clone(); - block.raw_distance_j = *raw_distance; - block.closest_point_a_j = contact.point1.o3dvec_downcast_or_convert::>().as_ref().clone(); - block.closest_point_b_j = contact.point2.o3dvec_downcast_or_convert::>().as_ref().clone(); - block.pose_a_j = poses.0.o3dpose_downcast_or_convert::>().as_ref().clone(); - block.pose_b_j = poses.1.o3dpose_downcast_or_convert::>().as_ref().clone(); - } + match block { + None => { + binding.hashmap.insert(ids, ProximaGenericBlock { + pose_a_j: poses.0.o3dpose_downcast_or_convert::>().as_ref().clone(), + pose_b_j: poses.1.o3dpose_downcast_or_convert::>().as_ref().clone(), + closest_point_a_j: contact.point1.o3dvec_downcast_or_convert::>().as_ref().clone(), + closest_point_b_j: contact.point2.o3dvec_downcast_or_convert::>().as_ref().clone(), + raw_distance_j: *raw_distance, + displacement_between_a_and_b_j: displacement_between_a_and_b_k.o3dpose_downcast_or_convert::>().as_ref().clone(), + }); } - */ + Some(block) => { + block.displacement_between_a_and_b_j = displacement_between_a_and_b_k.o3dpose_downcast_or_convert::>().as_ref().clone(); + block.raw_distance_j = *raw_distance; + block.closest_point_a_j = contact.point1.o3dvec_downcast_or_convert::>().as_ref().clone(); + block.closest_point_b_j = contact.point2.o3dvec_downcast_or_convert::>().as_ref().clone(); + block.pose_a_j = poses.0.o3dpose_downcast_or_convert::>().as_ref().clone(); + block.pose_b_j = poses.1.o3dpose_downcast_or_convert::>().as_ref().clone(); + } + } + */ } Box::new(ParryProximaGroupOutput { diff --git a/optima_refactor/src/bin/test7.rs b/optima_refactor/src/bin/test7.rs index 78928e0..95cab9b 100644 --- a/optima_refactor/src/bin/test7.rs +++ b/optima_refactor/src/bin/test7.rs @@ -1,8 +1,42 @@ +use ad_trait::differentiable_function::{FiniteDifferencing2, ForwardAD2, ForwardADMulti2, ReverseAD2}; +use ad_trait::forward_ad::adf::adf_f64x8; +use ad_trait::forward_ad::adfn::adfn; +use nalgebra::Isometry3; +use optima_3d_spatial::optima_3d_pose::O3DPose; use optima_bevy::optima_bevy_utils::robotics::BevyRoboticsTrait; -use optima_robotics::robot::{ORobotDefault, SaveRobot}; +use optima_interpolation::InterpolatorTrait; +use optima_interpolation::splines::{InterpolatingSpline, InterpolatingSplineType}; +use optima_optimization2::{DiffBlockOptimizerTrait, OptimizerOutputTrait}; +use optima_optimization2::open::SimpleOpEnOptimizer; +use optima_proximity::pair_group_queries::{OwnedEmptyParryFilter, OwnedEmptyToProximityQry, OwnedParryDistanceAsProximityGroupQry, OwnedParryDistanceGroupSequenceFilter, ParryDistanceGroupArgs, ParryDistanceGroupSequenceFilterArgs, ParryPairSelector, ProximityLossFunction}; +use optima_proximity::pair_queries::{ParryDisMode, ParryShapeRep}; +use optima_proximity::proxima::{OwnedParryProximaAsProximityQry, PairGroupQryArgsParryProxima, ProximaTermination}; +use optima_robotics::robot::{ORobotDefault}; +use optima_robotics::robotics_optimization2::robotics_optimization_ik::{DifferentiableBlockIKObjectiveTrait, IKGoalUpdateMode}; fn main() { - let mut r = ORobotDefault::load_from_saved_robot("xarm7_with_gripper_and_rail"); - r.preprocess(SaveRobot::Save(None)); - r.bevy_self_collision_visualization(); + let mut r = ORobotDefault::load_from_saved_robot("xarm7_with_gripper_and_rail_8dof"); + // r.set_joint_as_fixed(12, &[0.0]); + // r.save_robot(Some("xarm7_with_gripper_and_rail_8dof")); + + let init_condition = vec![0.1; 8]; + let fq = OwnedParryDistanceGroupSequenceFilter::new(ParryDistanceGroupSequenceFilterArgs::new(vec![ParryShapeRep::BoundingSphere], vec![], 0.6, true, ParryDisMode::ContactDis)); + let q = OwnedParryProximaAsProximityQry::new(PairGroupQryArgsParryProxima::new(ParryShapeRep::Full, true, false, ProximaTermination::MaxError(0.2), ProximityLossFunction::Hinge, 15.0, 0.6)); + // let q = OwnedParryDistanceAsProximityGroupQry::new(ParryDistanceGroupArgs::new(ParryShapeRep::Full, ParryDisMode::ContactDis, true, false, -1000.0, false)); + let db = r.get_ik_differentiable_block(ForwardADMulti2::>::new(), fq, q, None, &init_condition, vec![19], 0.09, 0.6, 1.0, 0.05, 1.0, 0.1, 0.1); + let o = SimpleOpEnOptimizer::new(r.get_dof_lower_bounds(), r.get_dof_upper_bounds(), 0.001); + + let mut solutions = vec![]; + let mut curr_solution = init_condition.clone(); + for _ in 0..3000 { + solutions.push(curr_solution.clone()); + let solution = o.optimize_unconstrained(&curr_solution, &db); + println!("{:?}", solution.solver_status().solve_time()); + // println!("{:?}", solution.solver_status().iterations()); + curr_solution = solution.x_star().to_vec(); + db.update_prev_states(curr_solution.clone()); + db.update_ik_pose(0, Isometry3::from_constructors(&[0.,0.,0.0001], &[0.,0.,0.]), IKGoalUpdateMode::GlobalRelative); + } + let spline = InterpolatingSpline::new(solutions, InterpolatingSplineType::Linear).to_timed_interpolator(6.0); + r.bevy_motion_playback(&spline); } \ No newline at end of file diff --git a/optima_refactor/src/bin/test8.rs b/optima_refactor/src/bin/test8.rs new file mode 100644 index 0000000..f8202e3 --- /dev/null +++ b/optima_refactor/src/bin/test8.rs @@ -0,0 +1,8 @@ + + +fn main() { + let a = vec![1,5,8,2,5,3,7,9,0,6,4,1]; + let mut indices = (0..a.len()).collect::>(); + indices.sort_unstable_by_key(|i| &a[*i]); + println!("{:?}", indices); +} \ No newline at end of file