From f36bad5532439b9f62a310cb4d3456be6d7f593a Mon Sep 17 00:00:00 2001 From: Matthieu Gomez Date: Wed, 19 Jun 2024 10:35:29 -0400 Subject: [PATCH] better type for derivatives --- Project.toml | 2 +- src/derivatives.jl | 28 ++++++++++++++-------------- src/jointoperator.jl | 6 ++---- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index f832214..a5b0cdc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "InfinitesimalGenerators" uuid = "2fce0c6f-5f0b-5c85-85c9-2ffe1d5ee30d" -version = "2.1.0" +version = "2.2.0" [deps] Arpack = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" diff --git a/src/derivatives.jl b/src/derivatives.jl index 520d91d..a8143c6 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -1,19 +1,19 @@ -struct FirstDerivative{T} <: AbstractVector{T} - x::AbstractVector{<:Real} - y::AbstractVector{T} +struct FirstDerivative{T, X <: AbstractVector{<:Real}, Y <: AbstractVector{<: Real}} <: AbstractVector{T} + x::X + y::Y bc::NTuple{2}{T} direction::Symbol - function FirstDerivative{T}(x, y, bc, direction) where {T} + function FirstDerivative(x, y, bc, direction) size(x) == size(y) || throw(DimensionMismatch( "cannot match grid of length $(length(x)) with vector of length $(length(y))")) direction ∈ (:upward, :downward) || throw(ArgumentError("direction must be :upward or :downward")) - return new(x, y, bc, direction) + return new{float(eltype(y)), typeof(x), typeof(y)}(x, y, bc, direction) end end -function FirstDerivative(x::AbstractVector, y::AbstractVector; bc = (0, 0), direction = :upward) - FirstDerivative{eltype(y)}(x, y, bc, direction) +function FirstDerivative(x, y; bc = (0, 0), direction = :upward) + FirstDerivative(x, y, bc, direction) end Base.size(d::FirstDerivative) = (length(d.x), 1) @@ -40,19 +40,19 @@ function Base.getindex(d::FirstDerivative{T}, i::Int) where {T} end -struct SecondDerivative{T} <: AbstractVector{T} - x::AbstractVector{<:Real} - y::AbstractVector{T} +struct SecondDerivative{T, X <: AbstractVector{<:Real}, Y <: AbstractVector{<: Real}} <: AbstractVector{T} + x::X + y::Y bc::NTuple{2}{T} - function SecondDerivative{T}(x, y, bc) where {T} + function SecondDerivative(x, y, bc) length(x) == length(y) || throw(DimensionMismatch( "cannot match grid of length $(length(x)) with vector of length $(length(y))")) - return new(x, y, bc) + return new{float(eltype(y)), typeof(x), typeof(y)}(x, y, bc) end end -function SecondDerivative(x::AbstractVector, y::AbstractVector; bc = (0, 0)) - SecondDerivative{eltype(y)}(x, y, bc) +function SecondDerivative(x, y; bc = (0, 0)) + SecondDerivative(x, y, bc) end Base.size(d::SecondDerivative) = (length(d.x), 1) diff --git a/src/jointoperator.jl b/src/jointoperator.jl index d89383c..d9b9905 100644 --- a/src/jointoperator.jl +++ b/src/jointoperator.jl @@ -1,11 +1,9 @@ -function jointoperator(operators, Q::Array) +function jointoperator(operators::AbstractVector{<:Tridiagonal}, Q::Array) N = length(operators) wn = size(operators[1], 1) - @assert all(o isa Tridiagonal for o in operators) # check if all os have same size @assert all(size(o) == (wn, wn) for o in operators) - # check if the size of transition matrix is - # same as the number of operators + # check if the size of transition matrix is same as the number of operators @assert size(Q,1) == size(Q,2) == N J = BandedBlockBandedMatrix(Zeros(wn * N, wn * N), fill(wn, N) ,fill(wn, N), (N-1, N-1), (1, 1)) for i in 1:N