Skip to content

Commit

Permalink
Add doctest to non_overlapping_qspace_samples, made general-purpose
Browse files Browse the repository at this point in the history
  • Loading branch information
dPys committed Mar 24, 2020
1 parent f5ca8c1 commit 66f8b48
Showing 1 changed file with 61 additions and 10 deletions.
71 changes: 61 additions & 10 deletions dmriprep/utils/vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,26 +377,70 @@ def bvecs2ras(affine, bvecs, norm=True, bvec_norm_epsilon=0.2):
return rotated_bvecs


def _nonoverlapping_qspace_samples(
prediction_bval, prediction_bvec, all_bvals, all_bvecs, cutoff
):
def nonoverlapping_qspace_samples(sample_bval, sample_bvec, all_bvals,
all_bvecs, cutoff=2):
"""Ensure that none of the training samples are too close to the sample to predict.
Parameters
Parameters
----------
sample_bval : int
A single b-value sampled along the sphere.
sample_bvec : int
A single b-vector sampled along the sphere.
Should correspond to `sample_bval`.
all_bvals : ndarray
A 1D vector of all b-values from the diffusion series.
all_bvecs: ndarray
A 3 x n vector of all vectors from the diffusion series,
where n is the total number of samples.
cutoff : float
A minimal allowable q-space distance between points on
the sphere.
Returns
-------
ok_samples : boolean ndarray
True for q-vectors whose spatial distribution along
the sphere is non-overlapping, else False.
Examples
--------
>>> bvec1 = np.array([1, 0, 0])
>>> bvec2 = np.array([1, 0, 0])
>>> bvec3 = np.array([0, 1, 0])
>>> bval1 = 1000
>>> bval2 = 1000
>>> bval3 = 1000
>>> all_bvals = np.array([0, bval2, bval3])
>>> all_bvecs = np.array([np.zeros(3), bvec2, bvec3])
>>> # Case 1: overlapping
>>> nonoverlapping_qspace_samples(bval1, bvec1, all_bvals, all_bvecs, cutoff=2)
array([ True, False, True])
>>> all_bvals = np.array([0, bval1, bval2])
>>> all_bvecs = np.array([np.zeros(3), bvec1, bvec2])
>>> # Case 2: non-overlapping
>>> nonoverlapping_qspace_samples(bval3, bvec3, all_bvals, all_bvecs, cutoff=2)
array([ True, True, True])
"""
min_bval = min(min(all_bvals), prediction_bval)
min_bval = min(min(all_bvals), sample_bval)
max_bval = max(max(all_bvals), sample_bval)
if min_bval == max_bval:
raise ValueError('All b-values are identical')

all_qvals = np.sqrt(all_bvals - min_bval)
prediction_qval = np.sqrt(prediction_bval - min_bval)
sample_qval = np.sqrt(sample_bval - min_bval)

# Convert q values to percent of maximum qval
max_qval = max(max(all_qvals), prediction_qval)
max_qval = max(max(all_qvals), sample_qval)
all_qvals_scaled = all_qvals / max_qval * 100
scaled_qvecs = all_bvecs * all_qvals_scaled[:, np.newaxis]
scaled_prediction_qvec = prediction_bvec * (prediction_qval / max_qval * 100)
scaled_sample_qvec = sample_bvec * (sample_qval / max_qval * 100)

# Calculate the distance between the sampled qvecs and the prediction qvec
# Calculate the distance between all qvecs and the sample qvec
ok_samples = (
np.linalg.norm(scaled_qvecs - scaled_prediction_qvec, axis=1) > cutoff
) * (np.linalg.norm(scaled_qvecs + scaled_prediction_qvec, axis=1) > cutoff)
np.linalg.norm(scaled_qvecs - scaled_sample_qvec, axis=1) > cutoff
) * (np.linalg.norm(scaled_qvecs + scaled_sample_qvec, axis=1) > cutoff)

return ok_samples

Expand All @@ -409,6 +453,9 @@ def _rasb_to_bvec_list(in_rasb):
----------
in_rasb : str or os.pathlike
File path to a RAS-B gradient table.
Returns
-------
List of b-vectors as floats.
"""
import numpy as np

Expand All @@ -425,6 +472,10 @@ def _rasb_to_bval_floats(in_rasb):
----------
in_rasb : str or os.pathlike
File path to a RAS-B gradient table.
Returns
-------
List of b-values as floats.
"""
import numpy as np

Expand Down

0 comments on commit 66f8b48

Please sign in to comment.