Contextual Dispatch

Contextual Dispatch

In the previous section, we saw how, within a given execution trace, Cassette's overdub mechanism transforms every method invocation of the form f(args...) into statements similar to the following:

begin
    Cassette.prehook(context, f, args...)
    tmp = Cassette.execute(context, f, args...)
    tmp = isa(tmp, Cassette.OverdubInstead) ? overdub(context, f, args...) : tmp
    Cassette.posthook(context, tmp, f, args...)
    tmp
end

This transformation yields several extra points of overloadability in the form of various Cassette methods, such as prehook, posthook, and execute. Together, these methods form Cassette's "contextual dispatch" interface, so called because it allows the extra context parameter to participate in what would normally be a simple dispatch to the underlying method call.

In this section of the documentation, we'll go over these functions in a bit more detail.

To begin, let's define a simple contextual prehook by overloading the prehook method w.r.t. to a dummy context:

julia> using Cassette

julia> Cassette.@context Ctx
Cassette.Context{nametype(Ctx),M,P,T,B} where B<:Union{Nothing, IdDict{Module,Dict{Symbol,BindingMeta}}} where P<:Cassette.AbstractPass where T<:Union{Nothing, Tag} where M

# this prehook implements simple trace logging for overdubbed functions
julia> Cassette.prehook(::Ctx, f, args...) = println(f, args)

julia> Cassette.overdub(Ctx(), /, 1, 2)
float(1,)
AbstractFloat(1,)
Float64(1,)
sitofp(Float64, 1)
float(2,)
AbstractFloat(2,)
Float64(2,)
sitofp(Float64, 2)
/(1.0, 2.0)
div_float(1.0, 2.0)
0.5

Cool beans!

Actually, there's a subtlety about overdub here we should address before moving on. Why wasn't the first line in the trace log /(1, 2)? I'll leave the answer as an exercise to the reader - just recall the definition of overdub from the previous section. If this the barrier between the overdub and the contextual dispatch interface seems confusing, try comparing the output from the above example with the output generated via overdub(Ctx(), () -> 1/2).

For pedagogy's sake, let's make our prehook slightly more complicated; let's only print calls whose first argument matches a specific type. A nice configurable way to do this is as follows:

# reset our prehook to a no-op
julia> Cassette.prehook(::Ctx, f, args...) = nothing

# parameterize our prehook on the type of metadata stored in our context instance
julia> Cassette.prehook(::Ctx{Val{T}}, f, arg::T, rest...) where {T} = println(f, (arg, rest...))

# construct our context instance with metadata to configure the prehook
julia> Cassette.overdub(Ctx(metadata=Val(Int)), /, 1, 2)
float(1,)
AbstractFloat(1,)
Float64(1,)
float(2,)
AbstractFloat(2,)
Float64(2,)
0.5

julia> Cassette.overdub(Ctx(metadata=Val(DataType)), /, 1, 2)
sitofp(Float64, 1)
sitofp(Float64, 2)
0.5

Also of note is prehook's long-lost cousin posthook, with which prehook shares many similarities. In fact, these functions are so similar that we won't be spending too much time on posthook individually. The key difference between prehook and posthook is that posthook runs after the overdubbed invocation is executed, such that it has access to the output of the overdubbed invocation.

For example, here we use posthook and prehook together to accumulate a trace that preserves nesting information:

using Cassette

Cassette.@context TraceCtx

mutable struct Trace
    current::Vector{Any}
    stack::Vector{Any}
    Trace() = new(Any[], Any[])
end

function enter!(t::Trace, args...)
    pair = args => Any[]
    push!(t.current, pair)
    push!(t.stack, t.current)
    t.current = pair.second
    return nothing
end

function exit!(t::Trace)
    t.current = pop!(t.stack)
    return nothing
end

Cassette.prehook(ctx::TraceCtx, args...) = enter!(ctx.metadata, args...)
Cassette.posthook(ctx::TraceCtx, args...) = exit!(ctx.metadata)

trace = Trace()
x, y, z = rand(3)
f(x, y, z) = x*y + y*z
Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z))

# returns `true`
trace.current == Any[
    (f,x,y,z) => Any[
        (*,x,y) => Any[(Base.mul_float,x,y)=>Any[]]
        (*,y,z) => Any[(Base.mul_float,y,z)=>Any[]]
        (+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]]
    ]
]

Next, let's tackle the meatiest part of the contextual dispatch interface: contextual primitives, as defined by the execute. Here's Cassette's default definition of execute:

execute(::Context, ::Vararg{Any}) = OverdubInstead()

With this definition in mind, the default case for the above contextual dispatch transformation can be reduced to:

begin
    Cassette.prehook(context, f, args...)
    tmp = overdub(context, f, args...)
    Cassette.posthook(context, tmp, f, args...)
    tmp
end

In other words, the execute's default behavior is to not interfere with the recursive application of overdub at all. If execute is ever overloaded to return something other than OverdubInstead, however, then it means the recursive overdubbing stops. Thus, in Cassette terminology, overloading execute defines a "contextual primitive" w.r.t. the overdubbing mechanism.

Note that upon context definition (via @context) a bunch of reasonable default contextual primitives are generated automatically. It is possible, of course, to simply override these defaults if necessary. For more details, see @context.

As an aside, one might wonder why the default definition of execute isn't simply execute(context, args...) = overdub(context, args...). The reason is that this definition is a bit harder on the compiler, since it adds an extra cycle (e.g. execute -> overdub -> execute) in the recursion inherent to Cassette's overdubbing mechanism. The branching based definition is much nicer, since it is much cheaper to evaluate a trivial isa check at compile time than it is to determine the worth of inferring through deep multi-cycle recursion.

To get a sense of the interaction between execute and overdub, let's reimplement our previous nested tracing example using recursion instead of maintaining a stack:

using Cassette

Cassette.@context TraceCtx

function Cassette.execute(ctx::TraceCtx, args...)
    subtrace = Any[]
    push!(ctx.metadata, args => subtrace)
    if Cassette.canoverdub(ctx, args...)
        newctx = Cassette.similarcontext(ctx, metadata = subtrace)
        return Cassette.overdub(newctx, args...)
    else
        return Cassette.fallback(ctx, args...)
    end
end

trace = Any[]
x, y, z = rand(3)
f(x, y, z) = x*y + y*z
Cassette.overdub(TraceCtx(metadata = trace), () -> f(x, y, z))

# returns `true`
trace == Any[
   (f,x,y,z) => Any[
       (*,x,y) => Any[(Base.mul_float,x,y)=>Any[]]
       (*,y,z) => Any[(Base.mul_float,y,z)=>Any[]]
       (+,x*y,y*z) => Any[(Base.add_float,x*y,y*z)=>Any[]]
   ]
]