diff --git a/src/convnets/unet.jl b/src/convnets/unet.jl index 9a06a040..d78756f5 100644 --- a/src/convnets/unet.jl +++ b/src/convnets/unet.jl @@ -71,18 +71,44 @@ Backbone of any Metalhead ResNet-like model can be used as encoder - `final`: final block as described in original paper - `fdownscale`: downscale factor """ -function unet(encoder_backbone, imgdims, outplanes::Integer, - final::Any = unet_final_block, fdownscale::Integer = 0) - backbonelayers = collect(flatten_chains(encoder_backbone)) - layers = unetlayers(backbonelayers, imgdims; m_middle = unet_middle_block, - skip_upscale = fdownscale) +function unet(encoder_backbone, imgdims, inchannels::Integer,outplanes::Integer, + final::Any = unet_final_block, fdownscale::Integer = 0) +backbonelayers = collect(flatten_chains(encoder_backbone)) - outsz = Flux.outputsize(layers, imgdims) - layers = Chain(layers, final(outsz[end - 1], outplanes)) +# Adjusting input size to include channels +adjusted_imgdims = (imgdims..., inchannels, 1) - return layers +layers = unetlayers(backbonelayers, adjusted_imgdims; m_middle = unet_middle_block, + skip_upscale = fdownscale) + +outsz = Flux.outputsize(layers, adjusted_imgdims) +layers = Chain(layers, final(outsz[end - 1], outplanes)) + +return layers +end +function modify_first_conv_layer_advanced(encoder_backbone, inchannels) + layers = [layer for layer in encoder_backbone.layers] # Create a mutable array from the layers + modified = false + for index in 1:length(layers) + if isa(layers[index], Flux.Conv) && !modified + layer = layers[index] + outchannels = size(layer.weight, 1) # The number of output channels + kernel_size = (size(layer.weight, 3), size(layer.weight, 4)) # Kernel size + stride = layer.stride + pad = layer.pad + + + new_conv_layer = Flux.Conv(kernel_size, inchannels => outchannels, stride=stride, pad=pad) + layers[index] = new_conv_layer # Replace the old layer with the new one + + modified = true # Mark as modified to avoid changing any other Conv layer + end + end + return Flux.Chain(layers...) # Reconstruct the model with the modified layers end + + """ UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3, encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false) @@ -110,17 +136,25 @@ See also [`Metalhead.unet`](@ref). struct UNet layers::Any end -@functor UNet +@functor UNet function UNet(imsize::Dims{2} = (256, 256), inchannels::Integer = 3, outplanes::Integer = 3, encoder_backbone = Metalhead.backbone(DenseNet(121)); pretrain::Bool = false) + + if inchannels != 3 + encoder_backbone = modify_first_conv_layer_advanced(encoder_backbone, inchannels) + + end + layers = unet(encoder_backbone, (imsize..., inchannels, 1), outplanes) model = UNet(layers) if pretrain + artifact_name = "UNet" loadpretrain!(model, artifact_name) end return model end + (m::UNet)(x::AbstractArray) = m.layers(x)