Skip to content

Commit

Permalink
made proxima faster
Browse files Browse the repository at this point in the history
  • Loading branch information
djrakita committed Dec 22, 2023
1 parent d0d007f commit 3281b21
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ pub enum ParryPairSelector {
HalfPairsSubcomponents,
PairsByIdxs(Vec<ParryPairIdxs>)
}
impl ParryPairSelector {
pub fn len(&self) -> usize {
match self {
ParryPairSelector::PairsByIdxs(v) => { v.len() }
_ => { usize::MAX }
}
}
}

pub struct ParryPairGroupOutputWrapper<O> {
data: O,
Expand Down Expand Up @@ -889,10 +897,11 @@ impl OPairGroupQryTrait for ParryDistanceGroupFilter {
let f = | output: &Box<ParryDistanceGroupOutput<T>> | -> Vec<ParryPairIdxs> {
let mut a = vec![];
output.outputs.iter().for_each(|x| {
if x.data.raw_distance < args.distance_threshold {
if x.data.distance() < args.distance_threshold {
a.push(x.pair_idxs.clone());
}
});
// println!("{:?}, {:?}", output.outputs.len(), a.len());
a
};

Expand Down
3 changes: 3 additions & 0 deletions optima_refactor/crates/optima_proximity/src/pair_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ impl<T: AD, P: O3DPose<T>> OPairQryTrait<T, P> for ParryProximaDistanceUpperBoun
type Args<'a> = ParryProximaDistanceUpperBoundArgs<'a, T, P, <P::RotationType as O3DRotation<T>>::Native3DVecType>;
type Output = ParryProximaDistanceUpperBoundOutput<T>;

#[inline(always)]
fn query<'a>(_shape_a: &Self::ShapeTypeA, _shape_b: &Self::ShapeTypeB, pose_a: &P, pose_b: &P, args: &Self::Args<'a>) -> Self::Output {
let start = Instant::now();
// let shapes = get_shapes_from_parry_qry_shape_type_and_parry_shape_rep(shape_a, shape_b, &args.0, &args.1);
Expand Down Expand Up @@ -291,6 +292,7 @@ impl<T: AD, P: O3DPose<T>> OPairQryTrait<T, P> for ParryProximaDistanceLowerBoun
type Args<'a> = ParryProximaDistanceLowerBoundArgs<'a, T, P>;
type Output = ParryProximaDistanceLowerBoundOutput<T, P>;

#[inline(always)]
fn query<'a>(shape_a: &Self::ShapeTypeA, shape_b: &Self::ShapeTypeB, pose_a: &P, pose_b: &P, args: &Self::Args<'a>) -> Self::Output {
let start = Instant::now();
let shapes = get_shapes_from_parry_qry_shape_type_and_parry_shape_rep(shape_a, shape_b, &args.parry_qry_shape_type, &args.parry_shape_rep);
Expand Down Expand Up @@ -337,6 +339,7 @@ impl<T: AD, P: O3DPose<T>> OPairQryTrait<T, P> for ParryProximaDistanceBoundsQry
type Args<'a> = ParryProximaDistanceBoundsArgs<'a, T, P, <P::RotationType as O3DRotation<T>>::Native3DVecType>;
type Output = ParryProximaDistanceBoundsOutputOption<T, P>;

#[inline(always)]
fn query<'a>(shape_a: &Self::ShapeTypeA, shape_b: &Self::ShapeTypeB, pose_a: &P, pose_b: &P, args: &Self::Args<'a>) -> Self::Output {
let start = Instant::now();

Expand Down
62 changes: 34 additions & 28 deletions optima_refactor/crates/optima_proximity/src/proxima.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,19 @@ impl OPairGroupQryTrait for ParryProximaQry {
proximity_lower_bound_sum += lower_bound_through_loss_and_powf;
proximity_upper_bound_sum += upper_bound_through_loss_and_powf;
// &(T, T, T, &P, (u64, u64), &ParryPairIdxs)
let idx = v.binary_search_by(|x: &(T, T, T, &P, (u64, u64), &ParryPairIdxs)| diff.partial_cmp(&x.0).unwrap());
let idx = match idx {
Ok(idx) => { idx }
Err(idx) => { idx }
};
v.insert(idx, (diff, lower_bound_through_loss_and_powf, upper_bound_through_loss_and_powf, &data.displacement_between_a_and_b_k, x.pair_ids(), x.pair_idxs()));
// let idx = v.binary_search_by(|x: &(T, T, T, &P, (u64, u64), &ParryPairIdxs)| diff.partial_cmp(&x.0).expect(&format!("{:?}, {:?}", diff, x.0)));
// let idx = match idx {
// Ok(idx) => { idx }
// Err(idx) => { idx }
// };
// v.insert(idx, (diff, lower_bound_through_loss_and_powf, upper_bound_through_loss_and_powf, &data.displacement_between_a_and_b_k, x.pair_ids(), x.pair_idxs()));
v.push((diff, lower_bound_through_loss_and_powf, upper_bound_through_loss_and_powf, &data.displacement_between_a_and_b_k, x.pair_ids(), x.pair_idxs()));
}
}
});

let mut indices = (0..v.len()).collect::<Vec<_>>();
indices.sort_by(|x, y| v[*x].0.partial_cmp(&v[*y].0).unwrap());
// v.sort_by(|x, y| { y.0.partial_cmp(&x.0).expect(&format!("y.1 {:?}, y.0 {:?}, x.1 {:?}, x.0 {:?}", y.1, y.0, x.1, x.0)) });

let mut proximity_lower_bound_output;
Expand All @@ -128,27 +131,30 @@ impl OPairGroupQryTrait for ParryProximaQry {

let mut num_queries = 0;

let mut idx = indices.len();
'l: loop {
proximity_lower_bound_output = proximity_lower_bound_sum.powf(args.p_norm.recip());
proximity_upper_bound_output = proximity_upper_bound_sum.powf(args.p_norm.recip());
// assert!(proximity_upper_bound_output >= proximity_lower_bound_output);
max_possible_error = (proximity_upper_bound_output - proximity_lower_bound_output).abs();

if num_queries >= v.len() { break 'l; }
if idx == 0 { break 'l; }
idx -= 1;

let mut terminate = false;
match &args.termination {
ProximaTermination::MaxTime(t) => {
if start.elapsed() > *t { terminate = true; }
}
ProximaTermination::MaxError(e) => {
if max_possible_error.to_constant() < * e { terminate = true; }
if max_possible_error.to_constant() < *e { terminate = true; }
}
}

if terminate { break 'l; }

let curr_entry = &v[num_queries];
let curr_entry = &v[idx];
let lower_bound_through_loss_and_powf = curr_entry.1;
let upper_bound_through_loss_and_powf = curr_entry.2;
let displacement_between_a_and_b_k = curr_entry.3;
Expand Down Expand Up @@ -186,27 +192,27 @@ impl OPairGroupQryTrait for ParryProximaQry {
block.displacement_between_a_and_b_j_staging = displacement_between_a_and_b_k.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().o3dpose_to_constant_ads();

/*
match block {
None => {
binding.hashmap.insert(ids, ProximaGenericBlock {
pose_a_j: poses.0.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().clone(),
pose_b_j: poses.1.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().clone(),
closest_point_a_j: contact.point1.o3dvec_downcast_or_convert::<Vector3<T>>().as_ref().clone(),
closest_point_b_j: contact.point2.o3dvec_downcast_or_convert::<Vector3<T>>().as_ref().clone(),
raw_distance_j: *raw_distance,
displacement_between_a_and_b_j: displacement_between_a_and_b_k.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().clone(),
});
}
Some(block) => {
block.displacement_between_a_and_b_j = displacement_between_a_and_b_k.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().clone();
block.raw_distance_j = *raw_distance;
block.closest_point_a_j = contact.point1.o3dvec_downcast_or_convert::<Vector3<T>>().as_ref().clone();
block.closest_point_b_j = contact.point2.o3dvec_downcast_or_convert::<Vector3<T>>().as_ref().clone();
block.pose_a_j = poses.0.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().clone();
block.pose_b_j = poses.1.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().clone();
}
match block {
None => {
binding.hashmap.insert(ids, ProximaGenericBlock {
pose_a_j: poses.0.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().clone(),
pose_b_j: poses.1.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().clone(),
closest_point_a_j: contact.point1.o3dvec_downcast_or_convert::<Vector3<T>>().as_ref().clone(),
closest_point_b_j: contact.point2.o3dvec_downcast_or_convert::<Vector3<T>>().as_ref().clone(),
raw_distance_j: *raw_distance,
displacement_between_a_and_b_j: displacement_between_a_and_b_k.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().clone(),
});
}
*/
Some(block) => {
block.displacement_between_a_and_b_j = displacement_between_a_and_b_k.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().clone();
block.raw_distance_j = *raw_distance;
block.closest_point_a_j = contact.point1.o3dvec_downcast_or_convert::<Vector3<T>>().as_ref().clone();
block.closest_point_b_j = contact.point2.o3dvec_downcast_or_convert::<Vector3<T>>().as_ref().clone();
block.pose_a_j = poses.0.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().clone();
block.pose_b_j = poses.1.o3dpose_downcast_or_convert::<Isometry3<T>>().as_ref().clone();
}
}
*/
}

Box::new(ParryProximaGroupOutput {
Expand Down
42 changes: 38 additions & 4 deletions optima_refactor/src/bin/test7.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,42 @@
use ad_trait::differentiable_function::{FiniteDifferencing2, ForwardAD2, ForwardADMulti2, ReverseAD2};
use ad_trait::forward_ad::adf::adf_f64x8;
use ad_trait::forward_ad::adfn::adfn;
use nalgebra::Isometry3;
use optima_3d_spatial::optima_3d_pose::O3DPose;
use optima_bevy::optima_bevy_utils::robotics::BevyRoboticsTrait;
use optima_robotics::robot::{ORobotDefault, SaveRobot};
use optima_interpolation::InterpolatorTrait;
use optima_interpolation::splines::{InterpolatingSpline, InterpolatingSplineType};
use optima_optimization2::{DiffBlockOptimizerTrait, OptimizerOutputTrait};
use optima_optimization2::open::SimpleOpEnOptimizer;
use optima_proximity::pair_group_queries::{OwnedEmptyParryFilter, OwnedEmptyToProximityQry, OwnedParryDistanceAsProximityGroupQry, OwnedParryDistanceGroupSequenceFilter, ParryDistanceGroupArgs, ParryDistanceGroupSequenceFilterArgs, ParryPairSelector, ProximityLossFunction};
use optima_proximity::pair_queries::{ParryDisMode, ParryShapeRep};
use optima_proximity::proxima::{OwnedParryProximaAsProximityQry, PairGroupQryArgsParryProxima, ProximaTermination};
use optima_robotics::robot::{ORobotDefault};
use optima_robotics::robotics_optimization2::robotics_optimization_ik::{DifferentiableBlockIKObjectiveTrait, IKGoalUpdateMode};

fn main() {
let mut r = ORobotDefault::load_from_saved_robot("xarm7_with_gripper_and_rail");
r.preprocess(SaveRobot::Save(None));
r.bevy_self_collision_visualization();
let mut r = ORobotDefault::load_from_saved_robot("xarm7_with_gripper_and_rail_8dof");
// r.set_joint_as_fixed(12, &[0.0]);
// r.save_robot(Some("xarm7_with_gripper_and_rail_8dof"));

let init_condition = vec![0.1; 8];
let fq = OwnedParryDistanceGroupSequenceFilter::new(ParryDistanceGroupSequenceFilterArgs::new(vec![ParryShapeRep::BoundingSphere], vec![], 0.6, true, ParryDisMode::ContactDis));
let q = OwnedParryProximaAsProximityQry::new(PairGroupQryArgsParryProxima::new(ParryShapeRep::Full, true, false, ProximaTermination::MaxError(0.2), ProximityLossFunction::Hinge, 15.0, 0.6));
// let q = OwnedParryDistanceAsProximityGroupQry::new(ParryDistanceGroupArgs::new(ParryShapeRep::Full, ParryDisMode::ContactDis, true, false, -1000.0, false));
let db = r.get_ik_differentiable_block(ForwardADMulti2::<adfn<8>>::new(), fq, q, None, &init_condition, vec![19], 0.09, 0.6, 1.0, 0.05, 1.0, 0.1, 0.1);
let o = SimpleOpEnOptimizer::new(r.get_dof_lower_bounds(), r.get_dof_upper_bounds(), 0.001);

let mut solutions = vec![];
let mut curr_solution = init_condition.clone();
for _ in 0..3000 {
solutions.push(curr_solution.clone());
let solution = o.optimize_unconstrained(&curr_solution, &db);
println!("{:?}", solution.solver_status().solve_time());
// println!("{:?}", solution.solver_status().iterations());
curr_solution = solution.x_star().to_vec();
db.update_prev_states(curr_solution.clone());
db.update_ik_pose(0, Isometry3::from_constructors(&[0.,0.,0.0001], &[0.,0.,0.]), IKGoalUpdateMode::GlobalRelative);
}
let spline = InterpolatingSpline::new(solutions, InterpolatingSplineType::Linear).to_timed_interpolator(6.0);
r.bevy_motion_playback(&spline);
}
8 changes: 8 additions & 0 deletions optima_refactor/src/bin/test8.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@


fn main() {
let a = vec![1,5,8,2,5,3,7,9,0,6,4,1];
let mut indices = (0..a.len()).collect::<Vec<_>>();
indices.sort_unstable_by_key(|i| &a[*i]);
println!("{:?}", indices);
}

0 comments on commit 3281b21

Please sign in to comment.