Specific function for calculating batched logPDF #48
Labels: enhancement (New feature or request)
Comments
Check out the attached PR; it allows you to run the following code, which seems close to your suggestion, using BayesBase:
```julia
using Distributions
using ExponentialFamilyProjection
using ExponentialFamily
using LinearAlgebra
using Random
using StableRNGs
using Test
using BenchmarkTools

import BayesBase: InplaceLogpdf

struct BatchLogpdf{F,N}
    batch_logpdf::F

    function BatchLogpdf{N}(batch_logpdf::F) where {F,N}
        return new{F,N}(batch_logpdf)
    end
end

# Constructor with batch size
function BatchLogpdf(batch_logpdf::F; batch_size::Int=100) where {F}
    return BatchLogpdf{batch_size}(batch_logpdf)
end

# Handle batch operations
function (b::BatchLogpdf{F,N})(out::AbstractVector, samples) where {F,N}
    n_samples = length(samples)
    n_batches = ceil(Int, n_samples / N)
    # Process samples in batches
    for i in 1:n_batches
        start_idx = (i - 1) * N + 1
        end_idx = min(i * N, n_samples)
        batch_slice = start_idx:end_idx
        # Process current batch
        view(out, batch_slice) .= b.batch_logpdf(view(samples, batch_slice))
    end
    return out
end

# Handle single sample operations (vector input)
function (b::BatchLogpdf{F,N})(x::AbstractVector) where {F,N}
    out = zeros(length(x))
    b(out, x) # Use the batch operation method
    return out
end

# Convert a regular logpdf function to BatchLogpdf
function Base.convert(::Type{BatchLogpdf}, logpdf_fn)
    return BatchLogpdf{100}(logpdf_fn) # Default batch size
end

function Base.convert(::Type{BatchLogpdf{N}}, logpdf_fn) where {N}
    return BatchLogpdf{N}(logpdf_fn)
end

function Base.convert(::Type{BatchLogpdf}, logpdf_fn::BatchLogpdf{F,N}) where {F,N}
    return logpdf_fn
end

# Create a delayed normal distribution
function create_delayed_normal(delay_seconds=0.1)
    dist = Normal(0.0, 1.0)
    return function delayed_logpdf(x)
        sleep(delay_seconds) # expensive operation (for example, moving data to a GPU)
        return logpdf(dist, x)
    end
end

# Test setup
nsamples = 100
delay = 0.001    # 1ms delay
batch_size = 20  # Process 20 samples at a time

# Test with ControlVariateStrategy
println("\nTesting ControlVariateStrategy performance:")

# Add a counter to track the number of logpdf calls
ncalls = 0
target_logpdf = function (x)
    global ncalls += 1
    sleep(delay)
    return logpdf(Normal(0.0, 1.0), x)
end

# Create strategies with different base_logpdf_type
strategy_batch = ExponentialFamilyProjection.ControlVariateStrategy(
    nsamples=nsamples,
    base_logpdf_type=BatchLogpdf{batch_size}, # Specify batch size in the type
)
strategy_inplace = ExponentialFamilyProjection.ControlVariateStrategy(
    nsamples=nsamples,
    base_logpdf_type=InplaceLogpdf,
)

# Create projections with minimal iterations to focus on logpdf performance
projection_batch = ProjectedTo(
    NormalMeanVariance,
    parameters=ProjectionParameters(
        niterations=3,
        tolerance=1e-1,
        strategy=strategy_batch
    )
)
projection_inplace = ProjectedTo(
    NormalMeanVariance,
    parameters=ProjectionParameters(
        niterations=3,
        tolerance=1e-1,
        strategy=strategy_inplace
    )
)

# Reset the counter and time the batch strategy
global ncalls = 0
time_proj_batch = @elapsed begin
    result_batch = project_to(projection_batch, target_logpdf)
end
ncalls_batch = ncalls

# Reset the counter and time the inplace strategy
global ncalls = 0
time_proj_inplace = @elapsed begin
    result_inplace = project_to(projection_inplace, target_logpdf)
end
ncalls_inplace = ncalls

println("Time with BatchLogpdf strategy: ", time_proj_batch)
println("Time with InplaceLogpdf strategy: ", time_proj_inplace)
println("Number of calls (batch): ", ncalls_batch)
println("Number of calls (inplace): ", ncalls_inplace)
```
When using a neural network as a black box within a factor graph, it would be computationally more efficient to calculate the loss function on batches of samples instead of processing them one by one.
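The intuition can be made concrete with a toy cost model: if each call to the black-box logpdf carries a fixed overhead (e.g. a GPU transfer or framework round-trip), batching amortizes that overhead across many samples. All numbers below are hypothetical:

```julia
# Toy cost model (hypothetical units): each call pays a fixed overhead,
# plus a marginal cost per sample it evaluates.
overhead   = 1.0    # fixed cost per call, e.g. moving data to a GPU
per_sample = 0.25   # marginal cost per evaluated sample
nsamples   = 1000
batch_size = 100

# One call per sample: the fixed overhead is paid nsamples times.
cost_one_by_one = nsamples * (overhead + per_sample)

# One call per batch: the overhead is paid only ceil(nsamples / batch_size) times.
cost_batched = ceil(Int, nsamples / batch_size) * overhead + nsamples * per_sample

println("one-by-one: ", cost_one_by_one)  # 1250.0
println("batched:    ", cost_batched)     # 260.0
```

Under this model the batched evaluation pays the fixed overhead 10 times instead of 1000, which is exactly the saving a `BatchLogpdf`-style wrapper targets.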