Skip to content

Commit

Permalink
Slight fix for GSL interpolation parallelization (#195)
Browse files Browse the repository at this point in the history
* fix for interpolation parallelization

* revert and fix parallel omp code

---------

Co-authored-by: Michael McCrackan <[email protected]>
Co-authored-by: Michael McCrackan <[email protected]>
  • Loading branch information
3 people authored Jan 15, 2025
1 parent e30cca2 commit a5af0e6
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/array_ops.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,7 @@ void _interp1d(const bp::object & x, const bp::object & y, const bp::object & x_
gsl_interp_accel* acc = gsl_interp_accel_alloc();
gsl_spline* spline = gsl_spline_alloc(interp_type, n_x);

#pragma omp parallel for
#pragma omp for
for (int row = 0; row < n_rows; ++row) {

int y_row_start = row * y_data_stride;
Expand All @@ -910,10 +910,10 @@ void _interp1d(const bp::object & x, const bp::object & y, const bp::object & x_
T* y_row = y_data + y_row_start;
T* y_interp_row = y_interp_data + y_interp_row_start;

interp_func(x_data, y_row, x_interp_data, y_interp_row,
interp_func(x_data, y_row, x_interp_data, y_interp_row,
n_x, n_x_interp, spline, acc);
}

// Free gsl objects
gsl_spline_free(spline);
gsl_interp_accel_free(acc);
Expand All @@ -930,7 +930,7 @@ void _interp1d(const bp::object & x, const bp::object & y, const bp::object & x_
std::transform(x_data, x_data + n_x, x_dbl,
[](float value) { return static_cast<double>(value); });

std::transform(x_interp_data, x_interp_data + n_x_interp, x_interp_dbl,
std::transform(x_interp_data, x_interp_data + n_x_interp, x_interp_dbl,
[](float value) { return static_cast<double>(value); });

#pragma omp parallel
Expand All @@ -939,7 +939,7 @@ void _interp1d(const bp::object & x, const bp::object & y, const bp::object & x_
gsl_interp_accel* acc = gsl_interp_accel_alloc();
gsl_spline* spline = gsl_spline_alloc(interp_type, n_x);

#pragma omp parallel for
#pragma omp for
for (int row = 0; row < n_rows; ++row) {

int y_row_start = row * y_data_stride;
Expand All @@ -949,13 +949,13 @@ void _interp1d(const bp::object & x, const bp::object & y, const bp::object & x_
// Transform y row to double array for gsl
double y_dbl[n_x];

std::transform(y_data + y_row_start, y_data + y_row_end, y_dbl,
std::transform(y_data + y_row_start, y_data + y_row_end, y_dbl,
[](float value) { return static_cast<double>(value); });

T* y_interp_row = y_interp_data + y_interp_row_start;

// Don't copy y_interp to doubles as it is cast during assignment
interp_func(x_dbl, y_dbl, x_interp_dbl, y_interp_row,
interp_func(x_dbl, y_dbl, x_interp_dbl, y_interp_row,
n_x, n_x_interp, spline, acc);
}

Expand All @@ -977,15 +977,15 @@ void interp1d_linear(const bp::object & x, const bp::object & y,
const gsl_interp_type* interp_type = gsl_interp_linear;
// Pointer to interpolation function
_interp_func_pointer<float> interp_func = &_linear_interp<float>;

_interp1d<float>(x, y, x_interp, y_interp, interp_type, interp_func);
}
else if (dtype == NPY_DOUBLE) {
// GSL interpolation type
const gsl_interp_type* interp_type = gsl_interp_linear;
// Pointer to interpolation function
_interp_func_pointer<double> interp_func = &_linear_interp<double>;

_interp1d<double>(x, y, x_interp, y_interp, interp_type, interp_func);
}
else {
Expand Down

0 comments on commit a5af0e6

Please sign in to comment.