Specific function for calculating batched logPDF #48

Open
StanleyGreen opened this issue Feb 10, 2025 · 1 comment
Labels
enhancement New feature or request

Comments

@StanleyGreen

When a neural network is used as a black box within a factor graph, it would be computationally more efficient to evaluate the loss function on batches of samples instead of processing samples one by one. This would provide the following benefits:

  • Efficient Handling of Out-of-Range Values: Enables disregarding multiple out-of-range values at once, reducing unnecessary computations.
  • Better GPU Utilization: Allows evaluating many samples simultaneously, leveraging the parallel processing capabilities of modern GPUs.
  • Customizable Behavior: Having a dedicated function for batch-based loss calculation would allow users to modify it to fit their specific use case, such as defining custom handling for out-of-range values.
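The following minimal sketch shows batched evaluation with custom out-of-range handling. Out-of-range samples are masked in a single pass, and every in-range sample goes to the model in one call. The names `network_logpdf` and `batch_logpdf_masked` are hypothetical, as are the `lo`/`hi` bounds. None of them belong to any package API. `network_logpdf` stands in for the neural network and here is simply a standard-normal log-density.

```julia
# Hypothetical stand-in for a neural network that scores a whole batch in one call.
# Here it is just the standard-normal log-density, evaluated elementwise.
network_logpdf(xs::AbstractVector{<:Real}) = -0.5 .* xs .^ 2 .- 0.5 * log(2π)

# Hypothetical batched log-PDF with custom out-of-range handling:
# samples outside [lo, hi] get -Inf without ever reaching the network,
# and all in-range samples are sent to the network in a single call.
function batch_logpdf_masked(xs::AbstractVector{<:Real}; lo::Real = -5.0, hi::Real = 5.0)
    out = fill(-Inf, length(xs))                  # default value for out-of-range samples
    inrange = (lo .<= xs) .& (xs .<= hi)          # boolean mask for the whole batch
    out[inrange] .= network_logpdf(xs[inrange])   # one network call per batch
    return out
end

batch_logpdf_masked([0.0, 10.0, -1.0])
```

The second sample lies outside the default bounds, so its entry is `-Inf`. The other two samples are scored together in one call. Replacing the mask logic is how a user would plug in their own out-of-range policy.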
@Nimrais Nimrais self-assigned this Feb 10, 2025
@Nimrais Nimrais added the enhancement New feature or request label Feb 10, 2025
@Nimrais
Member

Nimrais commented Feb 10, 2025

Check out the attached PR; it lets you run the following code, which closely matches your suggestion:

using BayesBase
using Distributions
using ExponentialFamilyProjection
using ExponentialFamily
using LinearAlgebra
using Random
using StableRNGs
using Test
using BenchmarkTools

import BayesBase: InplaceLogpdf

# Batch size N comes first so that `BatchLogpdf{N}` fixes N and leaves F free
struct BatchLogpdf{N,F}
    batch_logpdf::F
    
    # Inner constructor: BatchLogpdf{N}(f) fixes the batch size and infers F
    function BatchLogpdf{N}(batch_logpdf::F) where {N,F}
        return new{N,F}(batch_logpdf)
    end
end

# Constructor with batch size as a keyword argument
function BatchLogpdf(batch_logpdf::F; batch_size::Int=100) where F
    return BatchLogpdf{batch_size}(batch_logpdf)
end

# In-place evaluation: process `samples` in chunks of at most N, writing into `out`
function (b::BatchLogpdf{N})(out::AbstractVector, samples) where {N}
    n_samples = length(samples)
    n_batches = ceil(Int, n_samples / N)
    
    # Process samples in batches
    for i in 1:n_batches
        start_idx = (i-1) * N + 1
        end_idx = min(i * N, n_samples)
        batch_slice = start_idx:end_idx
        
        # One call to the user-supplied function per batch
        view(out, batch_slice) .= b.batch_logpdf(view(samples, batch_slice))
    end
    
    return out
end

# Allocating evaluation: score a vector of samples and return a new vector
function (b::BatchLogpdf)(x::AbstractVector)
    out = zeros(length(x))
    b(out, x)  # Use the batch operation method
    return out
end

# Convert regular logpdf function to BatchLogpdf
function Base.convert(::Type{BatchLogpdf}, logpdf_fn)
    return BatchLogpdf{100}(logpdf_fn)  # Default batch size
end

function Base.convert(::Type{BatchLogpdf{N}}, logpdf_fn) where N
    return BatchLogpdf{N}(logpdf_fn)
end

function Base.convert(::Type{BatchLogpdf}, logpdf_fn::BatchLogpdf)
    return logpdf_fn
end

# Create a delayed normal distribution
function create_delayed_normal(delay_seconds=0.1)
    dist = Normal(0.0, 1.0)
    return function delayed_logpdf(x)
        sleep(delay_seconds) # expensive operation (for example, moving data to the GPU)
        return logpdf(dist, x)
    end
end

# Test setup
nsamples = 100
delay = 0.001  # 1ms delay
batch_size = 20  # Process 20 samples at a time

# Test with ControlVariateStrategy
println("\nTesting ControlVariateStrategy performance:")

# Add counter to track number of logpdf calls
ncalls = 0
target_logpdf = function(x)
    global ncalls += 1
    sleep(delay)
    # Broadcast so the same function accepts a single sample or a whole batch
    return logpdf.(Normal(0.0, 1.0), x)
end

# Create strategies with different base_logpdf_type
strategy_batch = ExponentialFamilyProjection.ControlVariateStrategy(
    nsamples=nsamples,
    base_logpdf_type=BatchLogpdf{batch_size},  # Specify batch size in type
)

strategy_inplace = ExponentialFamilyProjection.ControlVariateStrategy(
    nsamples=nsamples,
    base_logpdf_type=InplaceLogpdf,
)

# Create projection with minimal iterations to focus on logpdf performance
projection_batch = ProjectedTo(
    NormalMeanVariance, 
    parameters=ProjectionParameters(
        niterations=3,
        tolerance=1e-1,
        strategy=strategy_batch
    )
)

projection_inplace = ProjectedTo(
    NormalMeanVariance, 
    parameters=ProjectionParameters(
        niterations=3,
        tolerance=1e-1,
        strategy=strategy_inplace
    )
)

# Reset counter and time batch strategy
global ncalls = 0
time_proj_batch = @elapsed begin
    result_batch = project_to(projection_batch, target_logpdf)
end
ncalls_batch = ncalls

# Reset counter and time inplace strategy
global ncalls = 0
time_proj_inplace = @elapsed begin
    result_inplace = project_to(projection_inplace, target_logpdf)
end
ncalls_inplace = ncalls

println("Time with BatchLogpdf strategy: ", time_proj_batch)
println("Time with InplaceLogpdf strategy: ", time_proj_inplace)
println("Number of calls (batch): ", ncalls_batch)
println("Number of calls (inplace): ", ncalls_inplace)
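You can reproduce the call-count effect of `BatchLogpdf` without the PR checked out, using a dependency-free sketch of the same chunking loop. The names `batched_eval!`, `call_count`, and `counting_logpdf` are hypothetical. The loop mirrors the one in `BatchLogpdf` above but does not touch `ExponentialFamilyProjection`.

```julia
# Hypothetical, dependency-free version of the chunking loop inside BatchLogpdf:
# call `f` once per chunk of at most `batch_size` samples and write into `out`.
function batched_eval!(f, out::AbstractVector, samples::AbstractVector, batch_size::Int)
    n = length(samples)
    for start in 1:batch_size:n
        r = start:min(start + batch_size - 1, n)   # indices of the current chunk
        out[r] .= f(view(samples, r))
    end
    return out
end

# Count how many times the (vectorized) log-density is invoked.
const call_count = Ref(0)
counting_logpdf(xs) = (call_count[] += 1; -0.5 .* xs .^ 2 .- 0.5 * log(2π))

samples = collect(range(-2.0, 2.0; length = 100))

# Batched: 100 samples in chunks of 20.
call_count[] = 0
out_batched = batched_eval!(counting_logpdf, zeros(length(samples)), samples, 20)
calls_batched = call_count[]

# Per-sample: chunk size 1, i.e. one call per sample.
call_count[] = 0
out_single = batched_eval!(counting_logpdf, zeros(length(samples)), samples, 1)
calls_single = call_count[]

println("calls (batched): ", calls_batched, ", calls (per sample): ", calls_single)
# prints "calls (batched): 5, calls (per sample): 100"
```

Both paths produce the same values; only the number of calls differs. The `delay` above adds a fixed 1 ms overhead per call. For 100 samples, that means about 5 ms of overhead when batched and 100 ms when evaluated per sample. The end-to-end difference in `project_to` also depends on how many times the strategy evaluates the target.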
