Contextual Dispatch

In the previous section, we saw how, within a given execution trace, Cassette's overdub mechanism transforms every method invocation of the form f(args...) into statements similar to the following:

begin
    Cassette.prehook(context, f, args...)
    tmp = Cassette.execute(context, f, args...)
    tmp = isa(tmp, Cassette.OverdubInstead) ? overdub(context, f, args...) : tmp
    Cassette.posthook(context, tmp, f, args...)
    tmp
end

This transformation yields several extra points of overloadability in the form of the Cassette methods prehook, execute, and posthook. Together, these methods form Cassette's "contextual dispatch" interface, so called because they allow the extra context parameter to participate in what would normally be a plain dispatch to the underlying method.
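To make the control flow of the expansion above concrete, here is a toy, Cassette-free model of it in plain Julia. Every name in it (ToyCtx, OverdubSentinel, toy_prehook, toy_execute, toy_posthook, toy_call, LOG) is hypothetical and invented for this sketch; none of them is part of Cassette's API. Unlike the real overdub, this toy cannot recurse into the body of f, so its "overdub" branch simply calls f directly.

```julia
# Toy model of contextual dispatch; all names here are hypothetical, not Cassette API.
struct ToyCtx end
struct OverdubSentinel end  # stands in for Cassette.OverdubInstead

# Default hooks: no-op pre/post hooks, and an execute that requests the "overdub" path.
toy_prehook(ctx, f, args...) = nothing
toy_posthook(ctx, out, f, args...) = nothing
toy_execute(ctx, f, args...) = OverdubSentinel()

# Mirrors the begin...end expansion above. A real overdub would recurse into f's
# body; this toy cannot, so the sentinel branch just calls f(args...) directly.
function toy_call(ctx, f, args...)
    toy_prehook(ctx, f, args...)
    tmp = toy_execute(ctx, f, args...)
    tmp = isa(tmp, OverdubSentinel) ? f(args...) : tmp
    toy_posthook(ctx, tmp, f, args...)
    return tmp
end

# Overloading a hook on the context type changes behavior for that context only.
const LOG = Any[]
toy_prehook(::ToyCtx, f, args...) = push!(LOG, (f, args))
toy_posthook(::ToyCtx, out, f, args...) = push!(LOG, (:out, out))

result = toy_call(ToyCtx(), +, 1, 2)  # 3; LOG is now Any[(+, (1, 2)), (:out, 3)]
```

Calling toy_call with any other context object only reaches the default no-op hooks, which is the same "opt in per context type" behavior the real interface relies on.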

In this section of the documentation, we'll go over these functions in a bit more detail.

To begin, let's define a simple contextual prehook by overloading the prehook method with respect to a dummy context:

julia> using Cassette

julia> Cassette.@context Ctx
Cassette.Context{nametype(Ctx),M,P,T,B} where B<:Union{Nothing, IdDict{Module,Dict{Symbol,BindingMeta}}} where P<:Cassette.AbstractPass where T<:Union{Nothing, Tag} where M

# this prehook implements simple trace logging for overdubbed functions
julia> Cassette.prehook(::Ctx, f, args...) = println(f, args)

julia> Cassette.overdub(Ctx(), /, 1, 2)
float(1,)
AbstractFloat(1,)
Float64(1,)
sitofp(Float64, 1)
float(2,)
AbstractFloat(2,)
Float64(2,)
sitofp(Float64, 2)
/(1.0, 2.0)
div_float(1.0, 2.0)
0.5

Cool beans!

Actually, there's a subtlety about overdub here we should address before moving on. Why wasn't the first line in the trace log /(1, 2)? I'll leave the answer as an exercise to the reader; just recall the definition of overdub from the previous section. If the boundary between overdub and the contextual dispatch interface seems confusing, try comparing the output from the above example with the output generated via overdub(Ctx(), () -> 1/2).

For pedagogy's sake, let's make our prehook slightly more complicated; let's only print calls whose first argument matches a specific type. A nice configurable way to do this is as follows:

# reset our prehook to a no-op
julia> Cassette.prehook(::Ctx, f, args...) = nothing

# parameterize our prehook on the type of metadata stored in our context instance
julia> Cassette.prehook(::Ctx{Val{T}}, f, arg::T, rest...) where {T} = println(f, (arg, rest...))

# construct our context instance with metadata to configure the prehook
julia> Cassette.overdub(Ctx(metadata=Val(Int)), /, 1, 2)
float(1,)
AbstractFloat(1,)
Float64(1,)
float(2,)
AbstractFloat(2,)
Float64(2,)
0.5

julia> Cassette.overdub(Ctx(metadata=Val(DataType)), /, 1, 2)
sitofp(Float64, 1)
sitofp(Float64, 2)
0.5
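The filtering above is nothing more than ordinary Julia method dispatch: T is fixed by the context's type parameter, and the arg::T constraint then decides which prehook method applies. A Cassette-free sketch of that mechanism might look like this; MetaCtx, filter_prehook, and SEEN are hypothetical names, with MetaCtx standing in for a real Cassette context that carries its metadata type as a parameter.

```julia
# Hypothetical stand-in for a context whose metadata type appears in its type parameter.
struct MetaCtx{M}
    metadata::M
end

const SEEN = Any[]

# Fallback method: ignore the call.
filter_prehook(::MetaCtx, f, args...) = nothing
# Only applicable when the first argument is an instance of the T stored in Val{T}.
filter_prehook(::MetaCtx{Val{T}}, f, arg::T, rest...) where {T} = push!(SEEN, (f, (arg, rest...)))

int_ctx = MetaCtx(Val(Int))
filter_prehook(int_ctx, float, 1)             # recorded: 1 isa Int
filter_prehook(int_ctx, /, 1.0, 2.0)          # ignored: 1.0 is not an Int

type_ctx = MetaCtx(Val(DataType))
filter_prehook(type_ctx, convert, Float64, 1) # recorded: Float64 isa DataType
```

Because the Val{T} method's signature is a strict subtype of the fallback's, Julia always prefers it whenever it is applicable, so no ambiguity arises.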

Also of note is prehook's long-lost cousin posthook, with which prehook shares many similarities. In fact, these functions are so similar that we won't spend much time on posthook individually. The key difference is that posthook runs after the overdubbed invocation has executed, so it also receives that invocation's output, passed as the argument immediately after the context.

For example, here we use posthook and prehook together to accumulate a trace that preserves nesting information:

using Cassette

Cassette.@context TraceCtx

mutable struct Trace
    current::Vector{Any}
    stack::Vector{Any}
    Trace() = new(Any[], Any[])
end

function enter!(t::Trace, args...)
    pair = args => Any[]
    push!(t.current, pair)
    push!(t.stack, t.current)
    t.current = pair.second
    return nothing
end

function exit!(t::Trace)
    t.current = pop!(t.stack)
    return nothing
end

Cassette.prehook(ctx::TraceCtx, args...) = enter!(ctx.metadata, args...)
Cassette.posthook(ctx::TraceCtx, args...) = exit!(ctx.metadata)

trace = Trace()
x, y, z = rand(3)
f(x, y, z) = x*y + y*z
Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z))

# returns `true`
trace.current == Any[
    (f,x,y,z) => Any[
        (*,x,y) => Any[(Base.mul_float,x,y)=>Any[]],
        (*,y,z) => Any[(Base.mul_float,y,z)=>Any[]],
        (+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]]
    ]
]
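Since the hooks only ever call enter! and exit!, the nesting logic can be checked without Cassette at all. The sketch below copies Trace, enter!, and exit! from above and replays by hand the hook sequence that overdubbing a call to a hypothetical g(a, b) = a * b + b would roughly produce (ignoring the deeper Base.mul_float and Base.add_float levels), with each prehook replaced by enter! and each posthook by exit!. The function g, the variable demo_trace, and the argument values are made up for illustration.

```julia
# Copies of Trace, enter!, and exit! from the example above.
mutable struct Trace
    current::Vector{Any}
    stack::Vector{Any}
    Trace() = new(Any[], Any[])
end

function enter!(t::Trace, args...)
    pair = args => Any[]
    push!(t.current, pair)     # record this call at the current nesting level
    push!(t.stack, t.current)  # remember the parent level
    t.current = pair.second    # descend into this call's (empty) child list
    return nothing
end

function exit!(t::Trace)
    t.current = pop!(t.stack)  # return to the parent level
    return nothing
end

# Hypothetical function whose hook sequence we replay by hand.
g(a, b) = a * b + b

demo_trace = Trace()
enter!(demo_trace, g, 2.0, 3.0)   # prehook for g(2.0, 3.0)
enter!(demo_trace, *, 2.0, 3.0)   # prehook for 2.0 * 3.0
exit!(demo_trace)                 # posthook for *
enter!(demo_trace, +, 6.0, 3.0)   # prehook for 6.0 + 3.0
exit!(demo_trace)                 # posthook for +
exit!(demo_trace)                 # posthook for g
```

After the final exit!, demo_trace.current is back at the top level and holds Any[(g, 2.0, 3.0) => Any[(*, 2.0, 3.0) => Any[], (+, 6.0, 3.0) => Any[]]], the same shape as the trace checked above.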

Next, let's tackle the meatiest part of the contextual dispatch interface: contextual primitives, as defined via the execute method. Here's Cassette's default definition of execute:

execute(::Context, ::Vararg{Any}) = OverdubInstead()

With this definition in mind, the default case for the above contextual dispatch transformation can be reduced to:

begin
    Cassette.prehook(context, f, args...)
    tmp = overdub(context, f, args...)
    Cassette.posthook(context, tmp, f, args...)
    tmp
end

In other words, execute's default behavior is to not interfere with the recursive application of overdub at all. If execute is overloaded to return something other than OverdubInstead for a given call, however, that return value is used as the call's result and overdub does not recurse into the call. Thus, in Cassette terminology, overloading execute defines a "contextual primitive" with respect to the overdubbing mechanism.
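The effect of a contextual primitive on recursion can also be modeled without Cassette. In the following sketch, every name (PrimSentinel, LogCtx, PrimCtx, prim_execute, prim_overdub, prim_call, sq, sumsq) is hypothetical. prim_overdub plays the role of overdub by hand-writing each toy function's "body" so that its inner calls go back through prim_call, and PrimCtx overloads prim_execute for sq. That makes sq a primitive for PrimCtx: its result is returned directly, and the * inside it is never visited.

```julia
# Toy model of contextual primitives; every name here is hypothetical.
struct PrimSentinel end  # stands in for OverdubInstead

sq(x) = x * x
sumsq(a, b) = sq(a) + sq(b)

struct LogCtx;  log::Vector{Any}; end  # only the default (non-intercepting) execute
struct PrimCtx; log::Vector{Any}; end  # treats sq as a contextual primitive

# Default: don't intercept, keep recursing.
prim_execute(ctx, f, args...) = PrimSentinel()
# Overload for PrimCtx: compute sq without routing its inner * through prim_call.
prim_execute(::PrimCtx, ::typeof(sq), x) = x * x

# Hand-written "overdubbed bodies": inner calls are routed back through prim_call.
prim_overdub(ctx, ::typeof(sumsq), a, b) = prim_call(ctx, +, prim_call(ctx, sq, a), prim_call(ctx, sq, b))
prim_overdub(ctx, ::typeof(sq), x) = prim_call(ctx, *, x, x)
prim_overdub(ctx, f, args...) = f(args...)  # everything else is a leaf

function prim_call(ctx, f, args...)
    push!(ctx.log, f)  # acts like a logging prehook
    tmp = prim_execute(ctx, f, args...)
    return isa(tmp, PrimSentinel) ? prim_overdub(ctx, f, args...) : tmp
end

log_ctx = LogCtx(Any[])
prim_call(log_ctx, sumsq, 2, 3)   # 13; log_ctx.log == Any[sumsq, sq, *, sq, *, +]
prim_ctx = PrimCtx(Any[])
prim_call(prim_ctx, sumsq, 2, 3)  # 13; prim_ctx.log == Any[sumsq, sq, sq, +]
```

Both contexts compute the same result; the only difference is how deep the "recursion" goes, which is exactly what overloading execute controls.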

Note

A bunch of reasonable default contextual primitives are generated automatically upon context definition (via @context). It is possible, of course, to simply override these defaults if necessary. For more details, see @context.

One might wonder why the default definition of execute isn't simply execute(context, args...) = overdub(context, args...). The reason is that this definition is harder on the compiler, since it adds an extra cycle (i.e. execute -> overdub -> execute) to the recursion inherent in Cassette's overdubbing mechanism. It is much cheaper for the compiler to evaluate isa(tmp, OverdubInstead) than it is to infer through deep multi-cycle recursion.

Furthermore, it is often convenient to use OverdubInstead in your own contextual primitive definitions. For example, OverdubInstead is used in the implementation below, which memoizes the computation of Fibonacci numbers (many thanks to the illustrious Simon Byrne, the original author of this example):

using Cassette: Cassette, @context, OverdubInstead

fib(x) = x < 3 ? 1 : fib(x - 2) + fib(x - 1)
fibtest(n) = fib(2 * n) + n

@context MemoizeCtx
Cassette.execute(ctx::MemoizeCtx, ::typeof(fib), x) = get(ctx.metadata, x, OverdubInstead())
Cassette.posthook(ctx::MemoizeCtx, fibx, ::typeof(fib), x) = (ctx.metadata[x] = fibx)

Then (skipping the warm-up calls used to compile both functions):

julia> ctx = MemoizeCtx(metadata = Dict{Int,Int}());

julia> @time Cassette.overdub(ctx, fibtest, 20)
  0.000011 seconds (8 allocations: 1.547 KiB)
102334175

julia> @time Cassette.overdub(ctx, fibtest, 20)
  0.000006 seconds (5 allocations: 176 bytes)
102334175

julia> @time fibtest(20)
  0.276069 seconds (5 allocations: 176 bytes)
102334175
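For comparison, the same memoization strategy can be written directly into a plain Julia function, without Cassette. In this sketch, MemoMiss, memo_fib, memo_fibtest, and fib_cache are hypothetical names: the get(cache, x, MemoMiss()) lookup plays the role of the overloaded execute, and storing the result after computing it plays the role of the posthook.

```julia
# Plain-Julia analogue of the memoizing context above; all names are hypothetical.
struct MemoMiss end  # plays the role of OverdubInstead: "not cached, compute it"

function memo_fib(cache::Dict{Int,Int}, x::Int)
    hit = get(cache, x, MemoMiss())  # like the overloaded execute
    hit isa MemoMiss || return hit
    fibx = x < 3 ? 1 : memo_fib(cache, x - 2) + memo_fib(cache, x - 1)
    cache[x] = fibx                  # like the posthook storing the result
    return fibx
end

memo_fibtest(cache, n) = memo_fib(cache, 2n) + n

fib_cache = Dict{Int,Int}()
memo_fibtest(fib_cache, 20)  # 102334175, matching the output above
```

The trade-off is visible here: the Cassette version leaves fib and fibtest untouched and attaches caching from the outside through the context, whereas this version has to thread the cache through the function itself.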

Finally, to get a sense of the interaction between execute and overdub, let's reimplement our previous nested tracing example using recursion instead of maintaining a stack:

using Cassette

Cassette.@context TraceCtx

function Cassette.execute(ctx::TraceCtx, args...)
    subtrace = Any[]
    push!(ctx.metadata, args => subtrace)
    if Cassette.canoverdub(ctx, args...)
        newctx = Cassette.similarcontext(ctx, metadata = subtrace)
        return Cassette.overdub(newctx, args...)
    else
        return Cassette.fallback(ctx, args...)
    end
end

trace = Any[]
x, y, z = rand(3)
f(x, y, z) = x*y + y*z
Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z))

# returns `true`
trace == Any[
    (f,x,y,z) => Any[
        (*,x,y) => Any[(Base.mul_float,x,y)=>Any[]],
        (*,y,z) => Any[(Base.mul_float,y,z)=>Any[]],
        (+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]]
    ]
]
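The recursive strategy can likewise be sketched without Cassette. In this toy model, RecTraceCtx, rec_call, rec_body, tsq, tsumsq, and rec_trace are hypothetical names: constructing a fresh RecTraceCtx(subtrace) stands in for similarcontext, and rec_body stands in for overdubbing the callee, with each toy function's body written by hand so that its inner calls route back through rec_call. No explicit stack is needed, because each level of nesting is simply a level of Julia recursion.

```julia
# Toy model of recursion-based tracing; all names here are hypothetical.
struct RecTraceCtx
    metadata::Vector{Any}  # the subtrace this context appends to
end

tsq(x) = x * x
tsumsq(a, b) = tsq(a) + tsq(b)

# Hand-written "overdubbed bodies": inner calls are routed back through rec_call.
rec_body(ctx, ::typeof(tsumsq), a, b) = rec_call(ctx, +, rec_call(ctx, tsq, a), rec_call(ctx, tsq, b))
rec_body(ctx, ::typeof(tsq), x) = rec_call(ctx, *, x, x)
rec_body(ctx, f, args...) = f(args...)  # everything else is a leaf

function rec_call(ctx::RecTraceCtx, f, args...)
    subtrace = Any[]
    push!(ctx.metadata, (f, args...) => subtrace)
    # like similarcontext(ctx, metadata = subtrace): nested calls land in subtrace
    return rec_body(RecTraceCtx(subtrace), f, args...)
end

rec_trace = Any[]
rec_result = rec_call(RecTraceCtx(rec_trace), tsumsq, 2, 3)  # 13
```

Afterwards rec_trace holds one entry for tsumsq whose subtrace contains the two tsq calls (each with its nested *) followed by the final +.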