
Project #5

Open · wants to merge 2 commits into master
Conversation

@akmathur1

Abstract
Medical image segmentation, the precise delineation of anatomical structures and pathological regions, is a core task in medical image analysis. Evaluating segmentation algorithms requires robust and efficient metrics to ensure accuracy and reliability. In this paper, we present MedEval3D, a CUDA-accelerated package for calculating a comprehensive set of medical segmentation metrics. Our implementation leverages GPU acceleration to significantly reduce computation time, enabling rapid assessment of segmentation performance on large 3D medical datasets. The metric calculations follow the mathematical framework proposed by Taha et al., giving the package a solid theoretical foundation. We provide a detailed explanation of the two-phase metric evaluation process, covering the preparation and execution steps. Benchmarking results demonstrate substantial performance gains over existing CPU-based solutions. Additionally, we highlight the versatility of MedEval3D in computing multiple metrics simultaneously with minimal overhead. An example implementation shows the integration of MedEval3D with the nnunet_utilities Python module, allowing seamless calculation of Dice loss with both Julia and Python functions. Our work aims to advance medical image segmentation by providing researchers and practitioners with a powerful tool for accurate and efficient evaluation of segmentation algorithms.
Introduction
The evaluation of medical image segmentation algorithms is a pivotal aspect of medical image analysis, playing a crucial role in applications ranging from diagnosis to treatment planning. Traditional metrics such as the Dice coefficient, the Jaccard index, and the Mahalanobis distance are essential for assessing the accuracy of segmentation algorithms. However, the computational demands of these metrics, particularly for large 3D datasets, pose significant challenges. The advent of GPU computing offers a promising solution by enabling accelerated computations.
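For reference, the overlap metrics named above have the standard definitions (G = gold-standard mask, A = algorithm output); these are textbook formulas, not specific to MedEval3D:

Dice(G, A) = 2|G ∩ A| / (|G| + |A|),  with Dice loss = 1 − Dice(G, A)
Jaccard(G, A) = |G ∩ A| / |G ∪ A|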
MedEval3D is a novel package designed to harness the power of CUDA for the efficient calculation of medical segmentation metrics. Built on the mathematical foundations laid out by Taha et al., MedEval3D incorporates a wide range of metrics essential for comprehensive evaluation. The package's programming model is based on a two-phase metric evaluation process. The first phase involves the preparation step, where constants and configurations are precomputed based on the image array dimensions and GPU hardware specifications. This step ensures optimal performance by leveraging the Occupancy API to maximize GPU occupancy. The second phase involves the actual metric computation, where the prepared configurations are applied to the image data to produce the final metrics.
To demonstrate the flexibility and efficiency of MedEval3D, we integrate it with the nnunet_utilities Python module, a popular tool for medical image segmentation. This integration allows users to calculate metrics such as the Dice loss using both Julia and Python functions, facilitating a seamless workflow for researchers who utilize tools from both ecosystems. For example, we define the Dice loss function in Julia and compare it with the implementation in Python, showcasing the ease of cross-language functionality and the performance benefits of CUDA acceleration.
Our benchmarking experiments, conducted on a high-performance Windows PC equipped with an Intel Core i9 processor and an NVIDIA GeForce RTX 3080 GPU, reveal that MedEval3D achieves up to 214 times faster execution compared to traditional CPU-based methods. These results underscore the efficiency of our CUDA-accelerated approach. Notably, MedEval3D can calculate multiple metrics simultaneously with negligible additional computation time, a significant advantage over existing solutions.

This paper provides a comprehensive overview of MedEval3D, including its implementation details, benchmarking results, and practical applications. By offering a powerful and efficient tool for segmentation evaluation, we aim to contribute to the advancement of medical image analysis and support the development of more accurate and reliable segmentation algorithms.

# Install and load PyCall for Julia-to-Python interop
using Pkg
Pkg.add("PyCall")
using PyCall

np = pyimport("numpy")
sitk = pyimport("SimpleITK")

# Dice loss for binary 3D masks: 1 - 2|G∩A| / (|G| + |A|)
function dice_loss(arrGold::Array{Int64,3}, arrAlgo::Array{Int64,3})
    intersection = sum(arrGold .& arrAlgo)
    union = sum(arrGold) + sum(arrAlgo)
    if union == 0
        return 0.0  # both masks empty: perfect overlap, zero loss
    end
    return 1.0 - (2.0 * intersection / union)
end

# Build 3×3×3 binary masks by stacking three slices along the third dimension
# (a vector of matrices cannot simply be reshaped into a 3D array)
arrGold = cat([1 0 0; 0 1 0; 0 0 1],
              [1 1 0; 0 1 1; 0 0 1],
              [1 0 1; 0 1 0; 1 0 1]; dims=3)

arrAlgo = cat([1 1 0; 1 0 0; 0 1 1],
              [1 0 1; 0 1 0; 1 1 0],
              [0 1 0; 1 0 1; 1 1 0]; dims=3)

julia_loss = dice_loss(arrGold, arrAlgo)
println("Julia Dice Loss: ", julia_loss)

# PyCall converts Julia arrays to NumPy arrays automatically
arrGold_np = PyObject(arrGold)
arrAlgo_np = PyObject(arrAlgo)

# A multi-line py"""...""" block is executed with exec and returns nothing,
# so the computed value is read back afterwards with py"result".
py"""
import numpy as np
def dice_loss(arrGold, arrAlgo):
    intersection = np.sum(arrGold * arrAlgo)
    union = np.sum(arrGold) + np.sum(arrAlgo)
    if union == 0:
        return 0.0  # both masks empty: zero loss
    return 1.0 - (2.0 * intersection / union)

result = dice_loss($arrGold_np, $arrAlgo_np)
"""
dice_loss_py = py"result"
println("Python Dice Loss: ", dice_loss_py)

# SimpleITK reference implementation for cross-checking
# (NumPy and Julia use different axis orders, but overlap measures are unaffected)
arrGold_sitk = sitk.GetImageFromArray(arrGold_np)
arrAlgo_sitk = sitk.GetImageFromArray(arrAlgo_np)
dice_filter = sitk.LabelOverlapMeasuresImageFilter()
dice_filter.Execute(arrGold_sitk, arrAlgo_sitk)
sitk_dice = dice_filter.GetDiceCoefficient()
println("SimpleITK Dice Coefficient: ", sitk_dice)

  1. Cross-entropy loss implementation

function cross_entropy_loss(arrGold::Array{Int64,3}, arrAlgo::Array{Float64,3})
    ε = 1e-15  # small epsilon to avoid log(0)
    arrAlgo = clamp.(arrAlgo, ε, 1 - ε)  # keep predictions within [ε, 1-ε]
    loss = -sum(arrGold .* log.(arrAlgo) .+ (1 .- arrGold) .* log.(1 .- arrAlgo))
    return loss / length(arrGold)
end

  2. Top-k loss implementation

function topk_loss(arrGold::Array{Int64,3}, arrAlgo::Array{Float64,3}; k=5)
    # sortperm needs a vector, so flatten the 3D prediction array first
    sorted_indices = sortperm(vec(arrAlgo), rev=true)
    top_k_indices = sorted_indices[1:k]
    loss = -sum(vec(arrGold)[top_k_indices] .* log.(vec(arrAlgo)[top_k_indices]))
    return loss / k
end

  3. Focal loss implementation (added here so the calls below resolve; identical to the definition in the consolidated script later in this thread)

function focal_loss(arrGold::Array{Int64,3}, arrAlgo::Array{Float64,3}; γ=2.0, α=0.25)
    ε = 1e-15
    arrAlgo = clamp.(arrAlgo, ε, 1 - ε)
    loss = -sum(α .* arrGold .* (1 .- arrAlgo).^γ .* log.(arrAlgo) .+
                (1 .- α) .* (1 .- arrGold) .* arrAlgo.^γ .* log.(1 .- arrAlgo))
    return loss / length(arrGold)
end
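For reference, the quantity computed by cross_entropy_loss above is the per-voxel binary cross-entropy averaged over the N voxels of the volume,

BCE = −(1/N) · Σᵢ [ gᵢ · log(pᵢ) + (1 − gᵢ) · log(1 − pᵢ) ]

while focal_loss additionally weights the two terms by α·(1 − pᵢ)^γ and (1 − α)·pᵢ^γ to down-weight easy voxels, and topk_loss keeps only the k most confident predictions.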

using BenchmarkTools

# Gold-standard binary mask and probabilistic prediction, stacked into 3×3×3 arrays
arrGold = cat([1 0 0; 0 1 0; 0 0 1],
              [1 1 0; 0 1 1; 0 0 1],
              [1 0 1; 0 1 0; 1 0 1]; dims=3)

arrAlgo = cat([0.9 0.1 0.2; 0.4 0.8 0.3; 0.3 0.4 0.9],
              [0.7 0.6 0.5; 0.2 0.9 0.6; 0.8 0.4 0.5],
              [0.3 0.7 0.2; 0.5 0.6 0.8; 0.9 0.8 0.7]; dims=3)

println("Julia Dice Loss: ", dice_loss(arrGold, arrAlgo))
println("Julia Cross-Entropy Loss: ", cross_entropy_loss(arrGold, arrAlgo))
println("Julia Focal Loss: ", focal_loss(arrGold, arrAlgo))
println("Julia Top-k Loss: ", topk_loss(arrGold, arrAlgo))

@benchmark dice_loss(arrGold, arrAlgo)
@benchmark cross_entropy_loss(arrGold, arrAlgo)
@benchmark focal_loss(arrGold, arrAlgo)
@benchmark topk_loss(arrGold, arrAlgo)

arrGold_np = PyObject(arrGold)
arrAlgo_np = PyObject(arrAlgo)

# Each py"""...""" block below is executed with exec and returns nothing,
# so the computed value is read back afterwards with py"result".
py"""
import numpy as np
def cross_entropy_loss(arrGold, arrAlgo):
    ε = 1e-15
    arrAlgo = np.clip(arrAlgo, ε, 1 - ε)
    return -np.mean(arrGold * np.log(arrAlgo) + (1 - arrGold) * np.log(1 - arrAlgo))

result = cross_entropy_loss($arrGold_np, $arrAlgo_np)
"""
cross_entropy_loss_py = py"result"

py"""
import numpy as np
def focal_loss(arrGold, arrAlgo, γ=2.0, α=0.25):
    ε = 1e-15
    arrAlgo = np.clip(arrAlgo, ε, 1 - ε)
    return -np.mean(α * arrGold * np.power(1 - arrAlgo, γ) * np.log(arrAlgo) +
                    (1 - α) * (1 - arrGold) * np.power(arrAlgo, γ) * np.log(1 - arrAlgo))

result = focal_loss($arrGold_np, $arrAlgo_np)
"""
focal_loss_py = py"result"

py"""
import numpy as np
def topk_loss(arrGold, arrAlgo, k=5):
    # indices of the k largest predictions in the flattened array
    sorted_indices = np.argsort(arrAlgo, axis=None)[::-1]
    top_k_indices = sorted_indices[:k]
    return -np.mean(arrGold.flat[top_k_indices] * np.log(arrAlgo.flat[top_k_indices]))

result = topk_loss($arrGold_np, $arrAlgo_np)
"""
topk_loss_py = py"result"

println("Python Cross-Entropy Loss: ", cross_entropy_loss_py)
println("Python Focal Loss: ", focal_loss_py)
println("Python Top-k Loss: ", topk_loss_py)

User-friendly code:

using Pkg
Pkg.add(["PyCall", "BenchmarkTools"])  # BenchmarkTools is needed for @belapsed below
using PyCall, BenchmarkTools

np = pyimport("numpy")
sitk = pyimport("SimpleITK")

# Dice loss for binary masks
function dice_loss(arrGold::Array{Int64,3}, arrAlgo::Array{Int64,3})
    intersection = sum(arrGold .& arrAlgo)
    union = sum(arrGold) + sum(arrAlgo)
    if union == 0
        return 0.0  # both masks empty: perfect overlap, zero loss
    end
    return 1.0 - (2.0 * intersection / union)
end

# Binary cross-entropy averaged over all voxels
function cross_entropy_loss(arrGold::Array{Int64,3}, arrAlgo::Array{Float64,3})
    ε = 1e-15  # small epsilon to avoid log(0)
    arrAlgo = clamp.(arrAlgo, ε, 1 - ε)  # keep predictions within [ε, 1-ε]
    loss = -sum(arrGold .* log.(arrAlgo) .+ (1 .- arrGold) .* log.(1 .- arrAlgo))
    return loss / length(arrGold)
end

# Focal loss: down-weights well-classified voxels via the (1 - p)^γ factor
function focal_loss(arrGold::Array{Int64,3}, arrAlgo::Array{Float64,3}; γ=2.0, α=0.25)
    ε = 1e-15
    arrAlgo = clamp.(arrAlgo, ε, 1 - ε)
    loss = -sum(α .* arrGold .* (1 .- arrAlgo).^γ .* log.(arrAlgo) .+
                (1 .- α) .* (1 .- arrGold) .* arrAlgo.^γ .* log.(1 .- arrAlgo))
    return loss / length(arrGold)
end

# Cross-entropy restricted to the k most confident predictions
function topk_loss(arrGold::Array{Int64,3}, arrAlgo::Array{Float64,3}; k=5)
    sorted_indices = sortperm(vec(arrAlgo), rev=true)  # sortperm needs a vector
    top_k_indices = sorted_indices[1:k]
    loss = -sum(vec(arrGold)[top_k_indices] .* log.(vec(arrAlgo)[top_k_indices]))
    return loss / k
end

function run_benchmarks(arrGold, arrAlgo)
    # Dice expects binary masks, so threshold the probabilistic predictions at 0.5
    arrAlgoBin = Int64.(arrAlgo .> 0.5)

    println("Julia Dice Loss: ", dice_loss(arrGold, arrAlgoBin))
    println("Julia Cross-Entropy Loss: ", cross_entropy_loss(arrGold, arrAlgo))
    println("Julia Focal Loss: ", focal_loss(arrGold, arrAlgo))
    println("Julia Top-k Loss: ", topk_loss(arrGold, arrAlgo))

    println("\nBenchmarking Julia implementations:")
    println("Dice Loss: ", @belapsed dice_loss($arrGold, $arrAlgoBin))
    println("Cross-Entropy Loss: ", @belapsed cross_entropy_loss($arrGold, $arrAlgo))
    println("Focal Loss: ", @belapsed focal_loss($arrGold, $arrAlgo))
    println("Top-k Loss: ", @belapsed topk_loss($arrGold, $arrAlgo))

    arrGold_np = PyObject(arrGold)
    arrAlgo_np = PyObject(arrAlgo)
    arrAlgoBin_np = PyObject(arrAlgoBin)  # binarized predictions for the Dice comparison

    # The py"""...""" block is executed with exec and returns nothing, so the
    # computed values are read back afterwards with py"..." lookups.
    py"""
import numpy as np

def dice_loss(arrGold, arrAlgo):
    intersection = np.sum(arrGold * arrAlgo)
    union = np.sum(arrGold) + np.sum(arrAlgo)
    if union == 0:
        return 0.0  # both masks empty: zero loss
    return 1.0 - (2.0 * intersection / union)

def cross_entropy_loss(arrGold, arrAlgo):
    ε = 1e-15
    arrAlgo = np.clip(arrAlgo, ε, 1 - ε)
    return -np.mean(arrGold * np.log(arrAlgo) + (1 - arrGold) * np.log(1 - arrAlgo))

def focal_loss(arrGold, arrAlgo, γ=2.0, α=0.25):
    ε = 1e-15
    arrAlgo = np.clip(arrAlgo, ε, 1 - ε)
    return -np.mean(α * arrGold * np.power(1 - arrAlgo, γ) * np.log(arrAlgo) +
                    (1 - α) * (1 - arrGold) * np.power(arrAlgo, γ) * np.log(1 - arrAlgo))

def topk_loss(arrGold, arrAlgo, k=5):
    sorted_indices = np.argsort(arrAlgo, axis=None)[::-1]
    top_k_indices = sorted_indices[:k]
    return -np.mean(arrGold.flat[top_k_indices] * np.log(arrAlgo.flat[top_k_indices]))

dice_py = dice_loss($arrGold_np, $arrAlgoBin_np)
ce_py = cross_entropy_loss($arrGold_np, $arrAlgo_np)
focal_py = focal_loss($arrGold_np, $arrAlgo_np)
topk_py = topk_loss($arrGold_np, $arrAlgo_np)
"""

    println("\nPython Dice Loss: ", py"dice_py")
    println("Python Cross-Entropy Loss: ", py"ce_py")
    println("Python Focal Loss: ", py"focal_py")
    println("Python Top-k Loss: ", py"topk_py")
end

function main()
    # Example test data: 3×3×3 gold-standard mask and probabilistic prediction
    arrGold = cat([1 0 0; 0 1 0; 0 0 1],
                  [1 1 0; 0 1 1; 0 0 1],
                  [1 0 1; 0 1 0; 1 0 1]; dims=3)

    arrAlgo = cat([0.9 0.1 0.2; 0.4 0.8 0.3; 0.3 0.4 0.9],
                  [0.7 0.6 0.5; 0.2 0.9 0.6; 0.8 0.4 0.5],
                  [0.3 0.7 0.2; 0.5 0.6 0.8; 0.9 0.8 0.7]; dims=3)

    run_benchmarks(arrGold, arrAlgo)
end

main()
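To run the consolidated script end to end (the filename losses.jl is illustrative, not taken from the repository):

julia --project=. losses.jl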

@jakubMitura14 (Collaborator)

Please add the code as 3 separate files (loss.jl, benchmark.jl, test_loss.jl), and add a section to the README on how to use it. Put the benchmark code in its own file as well. Also, does the repository Manifest include all the packages you use, e.g. for benchmarking?
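For reference, one standard Pkg workflow to record those dependencies in the repository's Project.toml and Manifest.toml (a sketch, run from the repo root):

using Pkg
Pkg.activate(".")                              # activate the repository environment
Pkg.add(["PyCall", "BenchmarkTools", "Test"])  # recorded in Project.toml / Manifest.toml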

@akmathur1 (Author)

loss.jl

using Statistics  # for mean

function dice_loss(arrGold::Array{Int64,3}, arrAlgo::Array{Int64,3})
    intersection = sum(arrGold .& arrAlgo)
    union = sum(arrGold) + sum(arrAlgo)
    if union == 0
        return 0.0  # both masks empty: perfect overlap, zero loss
    end
    return 1.0 - (2.0 * intersection / union)
end

function jaccard_index(arrGold::Array{Int64,3}, arrAlgo::Array{Int64,3})
    intersection = sum(arrGold .& arrAlgo)
    union = sum(arrGold .| arrAlgo)
    if union == 0
        return 1.0  # two empty masks are conventionally a perfect match
    end
    return intersection / union
end

function cross_entropy_loss(arrGold::Array{Float64,3}, arrAlgo::Array{Float64,3})
    epsilon = 1e-12
    arrAlgo = clamp.(arrAlgo, epsilon, 1.0 - epsilon)
    return -mean(arrGold .* log.(arrAlgo) .+ (1 .- arrGold) .* log.(1 .- arrAlgo))
end

@akmathur1 (Author)

benchmark.jl

include("loss.jl")

arrGold = rand(0:1, 256, 256, 256)
arrAlgo = rand(0:1, 256, 256, 256)

println("Benchmarking Dice Loss...")
@Btime dice_loss($arrGold, $arrAlgo)

println("Benchmarking Jaccard Index...")
@Btime jaccard_index($arrGold, $arrAlgo)

arrGold_float = convert(Array{Float64,3}, arrGold)
arrAlgo_float = convert(Array{Float64,3}, arrAlgo)
println("Benchmarking Cross-Entropy Loss...")
@Btime cross_entropy_loss($arrGold_float, $arrAlgo_float)

@akmathur1 (Author)

test_loss.jl

using Test

include("loss.jl")

# Shared 3×3×3 test fixtures
arrGold = reshape([1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1], (3,3,3))
arrAlgo = reshape([1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0], (3,3,3))

@testset "Dice Loss" begin
    # |G| = 13, |A| = 15, |G∩A| = 5, so loss = 1 - 10/28 ≈ 0.6429
    @test dice_loss(arrGold, arrAlgo) ≈ 0.6429 atol=1e-4
end

@testset "Jaccard Index" begin
    # |G∩A| = 5, |G∪A| = 23, so index = 5/23 ≈ 0.2174
    @test jaccard_index(arrGold, arrAlgo) ≈ 0.2174 atol=1e-4
end

@testset "Cross-Entropy Loss" begin
    arrGold_float = convert(Array{Float64,3}, arrGold)
    arrAlgo_float = convert(Array{Float64,3}, arrAlgo)
    # 18 of 27 voxels disagree; each contributes -log(1e-12) ≈ 27.631, so mean ≈ 18.4207
    @test cross_entropy_loss(arrGold_float, arrAlgo_float) ≈ 18.4207 atol=1e-3
end
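Assuming the repository carries a Project.toml with the Test dependency recorded, the tests can be run from the repo root with:

julia --project=. test_loss.jl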

@jakubMitura14 (Collaborator)

Good, but you added the files as comments, not as actual separate files :) Please do that, and also propose a change to the README telling people how to use it :)
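A minimal README section along those lines might read as follows (a sketch; the filenames match the three files proposed above):

Usage

include("loss.jl") to get dice_loss, jaccard_index, and cross_entropy_loss for 3D arrays.
Run julia --project=. benchmark.jl to benchmark the losses on 256×256×256 random volumes.
Run julia --project=. test_loss.jl to execute the unit tests.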
