Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions src/interpolator/one/strategies.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use strategy::cubic::*;
use strategy::*;

impl<D> Strategy1D<D> for Linear
Expand Down Expand Up @@ -37,6 +38,143 @@ where
}
}

impl<D> Strategy1D<D> for Cubic<D::Elem>
where
D: Data + RawDataClone + Clone,
D::Elem: Float + Euclid + Debug,
{
fn init(&mut self, data: &InterpData1D<D>) -> Result<(), ValidateError> {
// Number of segments
let n = data.grid[0].len() - 1;

let zero = D::Elem::zero();
let one = D::Elem::one();
let two = <D::Elem as NumCast>::from(2.).unwrap();
let six = <D::Elem as NumCast>::from(6.).unwrap();

let h = Array1::from_shape_fn(n, |i| data.grid[0][i + 1] - data.grid[0][i]);
let v = Array1::from_shape_fn(n - 1, |i| two * (h[i + 1] + h[i]));
let b = Array1::from_shape_fn(n, |i| (data.values[i + 1] - data.values[i]) / h[i]);
let u = Array1::from_shape_fn(n - 1, |i| six * (b[i + 1] - b[i]));

let (sub, diag, sup, rhs) = match &self.boundary_condition {
CubicBC::Natural => {
let zero = array![zero];
let one = array![one];
(
&ndarray::concatenate(Axis(0), &[h.slice(s![0..n - 1]), zero.view()]).unwrap(),
&ndarray::concatenate(Axis(0), &[one.view(), v.view(), one.view()]).unwrap(),
&ndarray::concatenate(Axis(0), &[zero.view(), h.slice(s![1..n])]).unwrap(),
&ndarray::concatenate(Axis(0), &[zero.view(), u.view(), zero.view()]).unwrap(),
)
}
CubicBC::Clamped(l, r) => {
let diag_0 = array![two * h[0]];
let diag_n = array![two * h[n - 1]];
let rhs_0 = array![six * (b[0] - *l)];
let rhs_n = array![six * (*r - b[n - 1])];
(
&h,
&ndarray::concatenate(Axis(0), &[diag_0.view(), v.view(), diag_n.view()])
.unwrap(),
&h,
&ndarray::concatenate(Axis(0), &[rhs_0.view(), u.view(), rhs_n.view()])
.unwrap(),
)
}
CubicBC::NotAKnot => {
let three = two + one;
let sub_n =
array![two * h[n - 1].powi(2) + three * h[n - 1] * h[n - 2] + h[n - 2].powi(2)];
let diag_0 = array![h[0].powi(2) - h[1].powi(2)];
let diag_n = array![h[n - 1].powi(2) - h[n - 2].powi(2)];
let sup_0 = array![two * h[0].powi(2) + three * h[0] * h[1] + h[1].powi(2)];
let rhs_0 = array![h[0] * u[0]];
let rhs_n = array![h[n - 1] * u[n - 2]];

println!(
"sub {:?}",
&ndarray::concatenate(Axis(0), &[h.slice(s![0..n - 1]), sub_n.view()]).unwrap()
);
println!(
"dia {:?}",
&ndarray::concatenate(Axis(0), &[diag_0.view(), v.view(), diag_n.view()])
.unwrap()
);
println!(
"sup {:?}",
&ndarray::concatenate(Axis(0), &[sup_0.view(), h.slice(s![1..n])]).unwrap()
);
println!(
"rhs {:?}",
&ndarray::concatenate(Axis(0), &[rhs_0.view(), u.view(), rhs_n.view()])
.unwrap()
);
(
&ndarray::concatenate(Axis(0), &[h.slice(s![0..n - 1]), sub_n.view()]).unwrap(),
&ndarray::concatenate(Axis(0), &[diag_0.view(), v.view(), diag_n.view()])
.unwrap(),
&ndarray::concatenate(Axis(0), &[sup_0.view(), h.slice(s![1..n])]).unwrap(),
&ndarray::concatenate(Axis(0), &[rhs_0.view(), u.view(), rhs_n.view()])
.unwrap(),
)
}
_ => unreachable!(),
};

self.z = Self::thomas(sub.view(), diag.view(), sup.view(), rhs.view()).into_dyn();

Ok(())
}

fn interpolate(
&self,
data: &InterpData1D<D>,
point: &[<D>::Elem; 1],
) -> Result<<D>::Elem, InterpolateError> {
let last = data.grid[0].len() - 1;
let l = if point[0] < data.grid[0][0] {
match &self.extrapolate {
CubicExtrapolate::Linear => {
let h0 = data.grid[0][1] - data.grid[0][0];
let k0 = (data.values[1] - data.values[0]) / h0
- h0 * self.z[1] / <D::Elem as NumCast>::from(6.).unwrap();
return Ok(k0 * (point[0] - data.grid[0][0]) + data.values[0]);
}
CubicExtrapolate::Spline => 0,
CubicExtrapolate::Wrap => {
let point = [wrap(point[0], data.grid[0][0], data.grid[0][last])];
let l = find_nearest_index(data.grid[0].view(), &point[0]);
return self.evaluate_1d(&point, l, data);
}
}
} else if point[0] > data.grid[0][last] {
match &self.extrapolate {
CubicExtrapolate::Linear => {
let hn = data.grid[0][last] - data.grid[0][last - 1];
let kn = (data.values[last] - data.values[last - 1]) / hn
+ hn * self.z[last - 1] / <D::Elem as NumCast>::from(6.).unwrap();
return Ok(kn * (point[0] - data.grid[0][last]) + data.values[last]);
}
CubicExtrapolate::Spline => last - 1,
CubicExtrapolate::Wrap => {
let point = [wrap(point[0], data.grid[0][0], data.grid[0][last])];
let l = find_nearest_index(data.grid[0].view(), &point[0]);
return self.evaluate_1d(&point, l, data);
}
}
} else {
find_nearest_index(data.grid[0].view(), &point[0])
};
self.evaluate_1d(point, l, data)
}

/// Returns `true`
fn allow_extrapolate(&self) -> bool {
true
}
}

impl<D> Strategy1D<D> for Nearest
where
D: Data + RawDataClone + Clone,
Expand Down
231 changes: 231 additions & 0 deletions src/interpolator/one/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,234 @@ fn test_extrapolate() {
assert_approx_eq!(interp.interpolate(&[-0.75]).unwrap(), 0.05);
assert_eq!(interp.interpolate(&[5.]).unwrap(), 1.2);
}

#[test]
fn test_cubic_natural() {
let x = array![1., 2.4, 3.1, 5., 7.6, 8.3, 10., 10.1];
let f_x = array![3., -90., 19., 99., 291., 444., 222., 250.];

let interp = Interp1D::new(
x.view(),
f_x.view(),
strategy::Cubic::natural(),
Extrapolate::Enable,
)
.unwrap();

// Interpolating at knots returns values
for i in 0..x.len() {
assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]);
}

let x0 = x.first().unwrap();
let xn = x.last().unwrap();
let y0 = f_x.first().unwrap();
let yn = f_x.last().unwrap();

let range = xn - x0;

let x_low = x0 - 0.2 * range;
let y_low = interp.interpolate(&[x_low]).unwrap();
let slope_low = (y0 - y_low) / (x0 - x_low);

let x_high = xn + 0.2 * range;
let y_high = interp.interpolate(&[x_high]).unwrap();
let slope_high = (y_high - yn) / (x_high - xn);

let xs_left = Array1::linspace(*x0, x0 + 2e-6, 50);
let xs_right = Array1::linspace(xn - 2e-6, *xn, 50);

// Left extrapolation is linear
let ys: Array1<f64> = xs_left
.iter()
.map(|&x| interp.interpolate(&[x]).unwrap())
.collect();
let slopes: Array1<f64> = xs_left
.windows(2)
.into_iter()
.zip(ys.windows(2))
.map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0]))
.collect();
assert_approx_eq!(slopes.mean().unwrap(), slope_low);

// Right extrapolation is linear
let ys: Array1<f64> = xs_right
.iter()
.map(|&x| interp.interpolate(&[x]).unwrap())
.collect();
let slopes: Array1<f64> = xs_right
.windows(2)
.into_iter()
.zip(ys.windows(2))
.map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0]))
.collect();
assert_approx_eq!(slopes.mean().unwrap(), slope_high);
}

#[test]
fn test_cubic_clamped() {
let x = array![1., 2.4, 3.1, 5., 7.6, 8.3, 10., 10.1];
let f_x = array![3., -90., 19., 99., 291., 444., 222., 250.];

let xs_left = Array1::linspace(x.first().unwrap() - 1e-6, x.first().unwrap() + 1e-6, 50);
let xs_right = Array1::linspace(x.last().unwrap() - 1e-6, x.last().unwrap() + 1e-6, 50);

for (a, b) in [(-5., 10.), (0., 0.), (2.4, -5.2)] {
let interp = Interp1D::new(
x.view(),
f_x.view(),
strategy::Cubic::clamped(a, b),
Extrapolate::Enable,
)
.unwrap();

// Interpolating at knots returns values
for i in 0..x.len() {
assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]);
}

// Left slope = a
let ys: Array1<f64> = xs_left
.iter()
.map(|&x| interp.interpolate(&[x]).unwrap())
.collect();
let slopes: Array1<f64> = xs_left
.windows(2)
.into_iter()
.zip(ys.windows(2))
.map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0]))
.collect();
assert_approx_eq!(slopes.mean().unwrap(), a);

// Right slope = b
let ys: Array1<f64> = xs_right
.iter()
.map(|&x| interp.interpolate(&[x]).unwrap())
.collect();
let slopes: Array1<f64> = xs_right
.windows(2)
.into_iter()
.zip(ys.windows(2))
.map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0]))
.collect();
assert_approx_eq!(slopes.mean().unwrap(), b);
}
}

#[test]
fn test_cubic_not_a_knot() {
let x = array![1., 2.4, 3.1, 5., 7.6, 8.3, 10., 10.1];
let f_x = array![3., -90., 19., 99., 291., 444., 222., 250.];

let x = array![1., 2., 3., 5., 7., 8., 10.];
let f_x = array![3., -90., 19., 99., 291., 444., 222.];

let interp = Interp1D::new(
x.view(),
f_x.view(),
strategy::Cubic::not_a_knot(),
Extrapolate::Enable,
)
.unwrap();

// Interpolating at knots returns values
for i in 0..x.len() {
assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]);
}

// // Left slope = a
// let ys: Array1<f64> = xs_left
// .iter()
// .map(|&x| interp.interpolate(&[x]).unwrap())
// .collect();
// let slopes: Array1<f64> = xs_left
// .windows(2)
// .into_iter()
// .zip(ys.windows(2))
// .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0]))
// .collect();
// assert_approx_eq!(slopes.mean().unwrap(), a);

// // Right slope = b
// let ys: Array1<f64> = xs_right
// .iter()
// .map(|&x| interp.interpolate(&[x]).unwrap())
// .collect();
// let slopes: Array1<f64> = xs_right
// .windows(2)
// .into_iter()
// .zip(ys.windows(2))
// .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0]))
// .collect();
// assert_approx_eq!(slopes.mean().unwrap(), b);
}

// #[test]
// fn test_cubic_periodic() {
// let x = array![1., 2., 3., 5., 7., 8.];
// let f_x = array![3., -90., 19., 99., 291., 444.];
//
// let interp_extrap_enable =
// Interp1D::new(x.view(), f_x.view(), strategy::Cubic::periodic(), Extrapolate::Enable).unwrap();
// let interp_extrap_wrap =
// Interp1D::new(x.view(), f_x.view(), strategy::Cubic::periodic(), Extrapolate::Wrap).unwrap();
//
// // Interpolating at knots returns values
// for i in 0..x.len() {
// assert_approx_eq!(interp_extrap_enable.interpolate(&[x[i]]).unwrap(), f_x[i]);
// assert_approx_eq!(interp_extrap_wrap.interpolate(&[x[i]]).unwrap(), f_x[i]);
// }
//
// // Extrapolate::Enable is equivalent to Extrapolate::Wrap for Cubic::periodic()
// let x0 = x.first().unwrap();
// let xn = x.last().unwrap();
// let range = xn - x0;
// let x_low = x0 - 0.2 * range;
// let x_high = x0 + 0.2 * range;
// let xs_left = Array1::linspace(x_low, *x0, 50);
// let xs_right = Array1::linspace(*xn, x_high, 50);
// for x in xs_left {
// assert_eq!(
// interp_extrap_enable.interpolate(&[x]).unwrap(),
// interp_extrap_wrap.interpolate(&[x]).unwrap()
// );
// }
// for x in xs_right {
// assert_eq!(
// interp_extrap_enable.interpolate(&[x]).unwrap(),
// interp_extrap_wrap.interpolate(&[x]).unwrap()
// );
// }
//
// // Slope left
// let xs_left = Array1::linspace(x_low, x_low + 2e6, 50);
// let ys_left: Array1<f64> = xs_left
// .iter()
// .map(|&x| interp_extrap_enable.interpolate(&[x]).unwrap())
// .collect();
// let slopes_left: Array1<f64> = xs_left
// .windows(2)
// .into_iter()
// .zip(ys_left.windows(2))
// .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0]))
// .collect();
// let slope_left = slopes_left.mean().unwrap();
// // Slope right
// let xs_right = Array1::linspace(x_high - 2e6, x_high, 50);
// let ys_right: Array1<f64> = xs_right
// .iter()
// .map(|&x| interp_extrap_enable.interpolate(&[x]).unwrap())
// .collect();
// let slopes_right: Array1<f64> = xs_right
// .windows(2)
// .into_iter()
// .zip(ys_right.windows(2))
// .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0]))
// .collect();
// let slope_right = slopes_right.mean().unwrap();
// // Slopes at left and right are equal
// assert_approx_eq!(slope_left, slope_right);
// // Second derivatives at left and right are equal
// let z = interp_extrap_enable.strategy.z;
// assert_approx_eq!(z.first().unwrap(), z.last().unwrap());
// }
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ pub use ndarray;
pub(crate) use ndarray::prelude::*;
pub(crate) use ndarray::{Data, Ix, RawDataClone};

pub(crate) use num_traits::{clamp, Euclid, Num, One};
pub(crate) use num_traits::{clamp, Euclid, Float, Num, NumCast, One, Zero};

pub(crate) use dyn_clone::*;

Expand Down
Loading
Loading