For transports #244

Open
cscherrer opened this issue Oct 31, 2022 · 1 comment

@cscherrer
Collaborator

I got this working, sort of:

julia> d = For(j -> Normal(j, 2.0), 1:3)
For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,))

julia> test_transport(d, Normal() ^ 3)
Test Summary:                                                                                             | Pass  Total  Time
transport_to Normal() ^ 3 to For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,)) |    8      8  0.0s
DefaultTestSet("transport_to Normal() ^ 3 to For{Normal{(:μ, :σ), Tuple{Int64, Float64}}}(j->Main.Normal(j, 2.0), (1:3,))", Any[], 8, false, false, true, 1.66725e9, 1.66725e9)

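To do this, I added for_constructor, which is like For but a little smarter: if the inferred return type of f is a singleton type (so every marginal is the same measure), it collapses to a power measure instead of building a For: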

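# Wrap a single index collection in a tuple
for_constructor(f, x) = for_constructor(f, (x,))

# Infer the return type of f from the element types of the index collections,
# then dispatch on whether that type is a singleton
@generated function for_constructor(f::F, inds::I) where {F,I<:Tuple}
    eltypes = Tuple{eltype.(I.types)...}
    quote
        T = Core.Compiler.return_type(f, $eltypes)
        _for(T, f, inds, static(Base.issingletontype(T)))
    end
end

# Singleton return type: every marginal is the same measure, so use a power measure
function _for(::Type{T}, f::F, inds::I, ::True) where {T,F,I}
    instance(T) ^ size(first(inds))
end

# Otherwise, fall back to an ordinary For
function _for(::Type{T}, f::F, inds::I, ::False) where {T,F,I}
    For{T,F,I}(f, inds)
end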
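The collapse rule can be illustrated without MeasureBase. This is a minimal, self-contained sketch under stated assumptions: StdNormalToy, NormalToy, ForToy, PowerToy and toy_for are hypothetical stand-ins (not MeasureBase or MeasureTheory types), and Val(true)/Val(false) replace Static.jl's True/False so nothing outside Base is needed.

```julia
# Hypothetical toy types, standing in for MeasureBase measures and combinators.
struct StdNormalToy end            # singleton: no fields, exactly one instance

struct NormalToy                   # not a singleton: carries parameters
    μ::Float64
    σ::Float64
end

struct ForToy{F,I}                 # lazy "For"-like wrapper
    f::F
    inds::I
end

struct PowerToy{M,N}               # "power measure": one marginal repeated
    parent::M
    dims::NTuple{N,Int}
end

toy_for(f, x) = toy_for(f, (x,))

function toy_for(f, inds::Tuple)
    # Inferred return type of f, given one element from each index collection
    T = Core.Compiler.return_type(f, Tuple{map(eltype, inds)...})
    # Val(true)/Val(false) play the role of Static.jl's True/False here
    return _toy_for(T, f, inds, Val(Base.issingletontype(T)))
end

# Singleton return type: every marginal is the same value, so collapse to a power
_toy_for(::Type{T}, f, inds, ::Val{true}) where {T} = PowerToy(T.instance, size(first(inds)))

# Otherwise keep the general wrapper
_toy_for(::Type{T}, f, inds, ::Val{false}) where {T} = ForToy(f, inds)

p = toy_for(j -> StdNormalToy(), 1:3)     # a PowerToy with parent StdNormalToy() and dims (3,)
q = toy_for(j -> NormalToy(j, 2.0), 1:3)  # a ForToy, since NormalToy is not a singleton
```

The decision relies on type inference, so a closure whose return type cannot be inferred concretely simply falls back to the general wrapper.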

Then we just need the standard transport methods for product measures:

function MeasureBase.transport_origin(d::AbstractProductMeasure)
    for_constructor(MeasureBase.transport_origin, marginals(d))
end

function MeasureBase.to_origin(d::AbstractProductMeasure, x)
    map(MeasureBase.to_origin, marginals(d), x)
end

function MeasureBase.from_origin(d::AbstractProductMeasure, x)
    map(MeasureBase.from_origin, marginals(d), x)
end
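As a rough illustration of the to_origin/from_origin pair, here is a toy sketch in which each marginal is a hypothetical affine transform of a standard normal. AffineToy, to_origin_toy and from_origin_toy are made-up names, not MeasureBase API; the product-level methods just map over the marginals, like the methods above.

```julia
# Hypothetical marginal: x = μ + σ * z, with z from the standard normal origin.
struct AffineToy
    μ::Float64
    σ::Float64
end

# Per-marginal transports to and from the origin.
to_origin_toy(d::AffineToy, x) = (x - d.μ) / d.σ
from_origin_toy(d::AffineToy, z) = d.μ + d.σ * z

# Product-level versions map over the marginals (each call allocates a new Vector).
to_origin_toy(ds::AbstractVector{AffineToy}, x) = map(to_origin_toy, ds, x)
from_origin_toy(ds::AbstractVector{AffineToy}, z) = map(from_origin_toy, ds, z)

ds = [AffineToy(j, 2.0) for j in 1:3]   # analogous to For(j -> Normal(j, 2.0), 1:3)
x = [1.0, 4.0, 7.0]
z = to_origin_toy(ds, x)                # [0.0, 1.0, 2.0]
from_origin_toy(ds, z) == x             # true: the round trip recovers x
```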

Well, almost. There's also this bug:

julia> MeasureBase._origin_depth(Normal() ^ 3)
ERROR: MethodError: no method matching ^(::MeasureBase.NoTransportOrigin{StdNormal}, ::Tuple{Int64})
Closest candidates are:
  ^(::AbstractMeasure, ::Tuple) at ~/git/MeasureBase.jl/src/combinators/power.jl:55
  ^(::AbstractMeasure, ::Any) at ~/git/MeasureBase.jl/src/combinators/power.jl:56
Stacktrace:
 [1] _for(#unused#::Type{MeasureBase.NoTransportOrigin{StdNormal}}, f::typeof(MeasureBase.transport_origin), inds::Tuple{FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}}}, #unused#::Static.True)
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:37
 [2] macro expansion
   @ ~/git/MeasureTheory.jl/src/combinators/for.jl:32 [inlined]
 [3] for_constructor(f::typeof(MeasureBase.transport_origin), inds::Tuple{FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:28
 [4] for_constructor(f::Function, x::FillArrays.Fill{StdNormal, 1, Tuple{Base.OneTo{Int64}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:26
 [5] transport_origin(d::PowerMeasure{StdNormal, Tuple{Base.OneTo{Int64}}})
   @ MeasureTheory ~/git/MeasureTheory.jl/src/combinators/for.jl:305
 [6] _origin_depth(ν::PowerMeasure{Normal{(), Tuple{}}, Tuple{Base.OneTo{Int64}}})
   @ MeasureBase ~/git/MeasureBase.jl/src/transport.jl:130
 [7] top-level scope
   @ REPL[60]:1

We end up taking a power of a NoTransportOrigin, which makes no sense. As a quick fix, I temporarily changed MeasureBase._origin_depth to

@inline function _origin_depth(ν::NU) where {NU}
    ν_0 = ν
    Base.Cartesian.@nexprs 10 i -> begin  # 10 is just some "big enough" number
        ν_{i} = transport_origin(ν_{i - 1})
        if ν_{i} isa PowerMeasure
            ν_{i} = ν_{i}.parent
        elseif ν_{i} isa NoTransportOrigin
            return static(i - 1)
        end
    end
    return static(10)
end
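For readers unfamiliar with _origin_depth, the idea is to count how many transport_origin steps can be taken before hitting NoTransportOrigin. The sketch below is a simplified toy version under assumptions: the measure types, NoOriginToy and origin_toy are hypothetical, it uses an ordinary runtime loop instead of Base.Cartesian.@nexprs (so it returns a plain Int rather than a static integer), and it does not unwrap power measures.

```julia
# Hypothetical chain of toy measures: each one's origin is the next, and
# StdNormalToy has no further origin.
struct StdNormalToy end
struct NormalOriginToy end
struct LogNormalOriginToy end
struct NoOriginToy{M}      # plays the role of NoTransportOrigin
    m::M
end

origin_toy(::LogNormalOriginToy) = NormalOriginToy()
origin_toy(::NormalOriginToy) = StdNormalToy()
origin_toy(m::StdNormalToy) = NoOriginToy(m)

# Number of origin steps before reaching NoOriginToy, capped at maxdepth
# (like the "big enough" 10 in _origin_depth).
function origin_depth_toy(m; maxdepth::Int = 10)
    for i in 1:maxdepth
        m = origin_toy(m)
        m isa NoOriginToy && return i - 1
    end
    return maxdepth
end

origin_depth_toy(LogNormalOriginToy())  # 2
origin_depth_toy(StdNormalToy())        # 0
```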

This change to _origin_depth feels kind of hacky. There's also the problem that map allocates a new array on every call to to_origin and from_origin. It would be nice to use mappedarray (from MappedArrays.jl) instead, but that doesn't infer properly. Maybe a modified version of it could?

Also, it seems like a problem if we have a product whose marginals have different "origin depths" (different numbers of transport_origin steps before reaching NoTransportOrigin). A fixpoint approach would handle this, but I think the current approach will break. Any ideas for this, @oschulz?
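To make the "different origin depths" problem concrete, here is a hypothetical sketch of one possible reading of a fixpoint approach: follow each marginal's origins independently until none remains. The toy types, next_origin, depth and final_origin are illustrative names only, and nothing stands in for NoTransportOrigin.

```julia
# Hypothetical toy measures with different origin depths.
struct StdNormalToy end
struct NormalOriginToy end
struct LogNormalOriginToy end

# `nothing` stands in for "no transport origin".
next_origin(::LogNormalOriginToy) = NormalOriginToy()
next_origin(::NormalOriginToy) = StdNormalToy()
next_origin(::StdNormalToy) = nothing

# Depth of a single marginal: origin steps until there is none.
depth(m) = (nxt = next_origin(m); nxt === nothing ? 0 : 1 + depth(nxt))

# Fixpoint per marginal: keep following origins until none remains.
function final_origin(m)
    while true
        nxt = next_origin(m)
        nxt === nothing && return m
        m = nxt
    end
end

ms = (LogNormalOriginToy(), StdNormalToy())
map(depth, ms)         # (2, 0): no single depth works for the whole product
map(final_origin, ms)  # (StdNormalToy(), StdNormalToy())
```

Stepping all marginals in lockstep a fixed number of times would overshoot the shallower one; iterating each marginal to its own fixpoint avoids that.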

@oschulz
Collaborator

oschulz commented Oct 31, 2022

a problem if we have a product with different "origin depths"

Well, if it's a tuple-based product, the transport for each marginal should generate separate code and everything should infer. And if it's array-based and the marginals have different depths, then they also have different types, so type inference is probably hopeless anyway, right?
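A small sketch of the tuple-versus-array distinction, using hypothetical toy types rather than MeasureBase measures. It only shows the types involved, not what the compiler makes of a real product measure.

```julia
# Two hypothetical marginal types with different origin depths.
struct ShallowToy end
struct DeepToy end
origin_depth(::ShallowToy) = 0
origin_depth(::DeepToy) = 2

# Tuple-based product: the tuple type records each marginal's concrete type,
# so each call in map can be dispatched and inferred separately.
tup = (ShallowToy(), DeepToy())
isconcretetype(typeof(tup))   # true: Tuple{ShallowToy, DeepToy}
map(origin_depth, tup)        # (0, 2)

# Array-based product: a single element type has to cover both marginals.
arr = [ShallowToy(), DeepToy()]
eltype(arr)                   # Any
isconcretetype(eltype(arr))   # false: dispatch happens at run time
map(origin_depth, arr)        # [0, 2]
```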
