-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
differentiable kdes #77
Changes from all commits
e476e2f
1390975
d96d47d
0ba2565
3e425d4
86d8f32
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,15 +5,39 @@ using StatsBase | |
using Distributions | ||
using Optim | ||
using Interpolations | ||
using Flux.Tracker: TrackedReal | ||
using Flux | ||
|
||
import StatsBase: RealVector, RealMatrix | ||
import Distributions: twoπ, pdf | ||
import FFTW: rfft, irfft | ||
import Flux.Tracker: conv | ||
import Base.round | ||
import Core.Integer | ||
|
||
export kde, kde_lscv, UnivariateKDE, BivariateKDE, InterpKDE, pdf | ||
|
||
abstract type AbstractKDE end | ||
|
||
"""n-dimensional convolution""" | ||
function conv(x::AbstractArray{T,N}, w::AbstractArray{T,N}) where {T,N} | ||
wdim = Int.(ceil.((size(w).-1)./2)) | ||
padding = Iterators.flatten([ (wdim[i],wdim[i]) for i=1:length(wdim) ]) |> collect | ||
|
||
dims = DenseConvDims((size(x)...,1,1),(size(w)...,1,1); padding=padding ) | ||
result = Tracker.conv( reshape(x,(size(x)...,1,1)), reshape(w,(size(w)...,1,1)), dims) | ||
return dropdims(result, dims = (1+N,2+N)) | ||
end | ||
|
||
# patches for TrackedReal and Vector{TrackedReal} | ||
conv(x::AbstractArray{TrackedReal{T},N}, w::AbstractArray) where {T,N} = conv(Tracker.collect(x),w) | ||
conv(x::AbstractArray, w::AbstractArray{TrackedReal{T},N}) where {T,N} = conv(x,Tracker.collect(w)) | ||
conv(x::AbstractArray{TrackedReal{T},N}, w::AbstractArray{TrackedReal{T},N}) where {T,N} = conv(Tracker.collect(x),Tracker.collect(w)) | ||
|
||
round(::Type{R}, t::TrackedReal) where {R<:Real} = round(R, t.data) | ||
round(t::TrackedReal, mode::RoundingMode) = round(t.data, mode) | ||
Integer(x::TrackedReal) = Integer(x.data) | ||
|
||
Comment on lines
+22
to
+40
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. bypasses the need for |
||
include("univariate.jl") | ||
include("bivariate.jl") | ||
include("interp.jl") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,7 +19,7 @@ mutable struct BivariateKDE{Rx<:AbstractRange,Ry<:AbstractRange} <: AbstractKDE | |
"Second coordinate of gridpoints for evaluating the density." | ||
y::Ry | ||
"Kernel density at corresponding gridpoints `Tuple.(x, permutedims(y))`." | ||
density::Matrix{Float64} | ||
density::AbstractMatrix{} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. relax type casting for future compatibility.. will make a PR to |
||
end | ||
|
||
function kernel_dist(::Type{D},w::Tuple{Real,Real}) where D<:UnivariateDistribution | ||
|
@@ -54,7 +54,7 @@ function tabulate(data::Tuple{RealVector, RealVector}, midpoints::Tuple{Rx, Ry}, | |
sx, sy = step(xmid), step(ymid) | ||
|
||
# Set up a grid for discretized data | ||
grid = zeros(Float64, nx, ny) | ||
grid = zeros(eltype(xdata),nx,ny) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. relaxing types |
||
ainc = 1.0 / (sum(weights)*(sx*sy)^2) | ||
|
||
# weighted discretization (cf. Jones and Lotwick) | ||
|
@@ -77,28 +77,16 @@ end | |
|
||
# convolution with product distribution of two univariates distributions | ||
function conv(k::BivariateKDE, dist::Tuple{UnivariateDistribution,UnivariateDistribution}) | ||
# Transform to Fourier basis | ||
Kx, Ky = size(k.density) | ||
ft = rfft(k.density) | ||
|
||
distx, disty = dist | ||
|
||
# Convolve fft with characteristic function of kernel | ||
cx = -twoπ/(step(k.x)*Kx) | ||
cy = -twoπ/(step(k.y)*Ky) | ||
for j = 0:size(ft,2)-1 | ||
for i = 0:size(ft,1)-1 | ||
ft[i+1,j+1] *= cf(distx,i*cx)*cf(disty,min(j,Ky-j)*cy) | ||
end | ||
end | ||
dens = irfft(ft, Kx) | ||
half_gridx = range(step(k.x),5*std(distx),step=step(k.x)) | ||
gridx = [-reverse(half_gridx);0;half_gridx] | ||
|
||
for i = 1:length(dens) | ||
dens[i] = max(0.0,dens[i]) | ||
end | ||
half_gridy = range(step(k.y),5*std(disty),step=step(k.y)) | ||
gridy = [-reverse(half_gridy);0;half_gridy] | ||
|
||
# Invert the Fourier transform to get the KDE | ||
BivariateKDE(k.x, k.y, dens) | ||
density = conv(k.density, pdf.(distx,gridx)*pdf.(disty,gridy)')' * step(k.x) * step(k.y) | ||
BivariateKDE(k.x, k.y, density) | ||
end | ||
|
||
const BivariateDistribution = Union{MultivariateDistribution,Tuple{UnivariateDistribution,UnivariateDistribution}} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,7 @@ mutable struct UnivariateKDE{R<:AbstractRange} <: AbstractKDE | |
"Gridpoints for evaluating the density." | ||
x::R | ||
"Kernel density at corresponding gridpoints `x`." | ||
density::Vector{Float64} | ||
density::AbstractVector{} | ||
end | ||
|
||
# construct kernel from bandwidth | ||
|
@@ -101,7 +101,7 @@ function tabulate(data::RealVector, midpoints::R, weights::Weights=default_weigh | |
s = step(midpoints) | ||
|
||
# Set up a grid for discretized data | ||
grid = zeros(Float64, npoints) | ||
grid = zeros(eltype(data),npoints) | ||
ainc = 1.0 / (sum(weights)*s*s) | ||
|
||
# weighted discretization (cf. Jones and Lotwick) | ||
|
@@ -119,31 +119,11 @@ function tabulate(data::RealVector, midpoints::R, weights::Weights=default_weigh | |
end | ||
|
||
# convolve raw KDE with kernel | ||
# TODO: use in-place fft | ||
function conv(k::UnivariateKDE, dist::UnivariateDistribution) | ||
# Transform to Fourier basis | ||
K = length(k.density) | ||
ft = rfft(k.density) | ||
|
||
# Convolve fft with characteristic function of kernel | ||
# empirical cf | ||
# = \sum_{n=1}^N e^{i*t*X_n} / N | ||
# = \sum_{k=0}^K e^{i*t*(a+k*s)} N_k / N | ||
# = e^{i*t*a} \sum_{k=0}^K e^{-2pi*i*k*(-t*s*K/2pi)/K} N_k / N | ||
# = A * fft(N_k/N)[-t*s*K/2pi + 1] | ||
c = -twoπ/(step(k.x)*K) | ||
for j = 0:length(ft)-1 | ||
ft[j+1] *= cf(dist,j*c) | ||
end | ||
|
||
dens = irfft(ft, K) | ||
# fix rounding error. | ||
for i = 1:K | ||
dens[i] = max(0.0,dens[i]) | ||
end | ||
|
||
# Invert the Fourier transform to get the KDE | ||
UnivariateKDE(k.x, dens) | ||
half_grid = range(step(k.x),5*std(dist),step=step(k.x)) | ||
grid = [-reverse(half_grid);0;half_grid] | ||
density = conv(k.density, pdf.(dist,grid)) * step(k.x) | ||
UnivariateKDE(k.x, density) | ||
end | ||
Comment on lines
122
to
127
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what would your recommendation be for this? given that the methods in |
||
|
||
# main kde interface methods | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
all of these methods below are type piracy, so probably better to make a PR to Tracker.jl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see.. suppose these methods are added to
Tracker.jl
how would you rewrite theconv
methods inunivariate.jl
andbivariate.jl
so that they are compatible?