Cassette API Documentation

Cassette.Context (Type)
Context{N<:Cassette.AbstractContextName,
        M<:Any,
        P<:Cassette.AbstractPass,
        T<:Union{Nothing,Cassette.Tag},
        B<:Union{Nothing,Cassette.BindingMetaDictCache},
        H<:Union{Nothing,Cassette.DisableHooks}}

A type representing a Cassette execution context. This type is normally interacted with through type aliases constructed via Cassette.@context:

julia> Cassette.@context MyCtx
Cassette.Context{nametype(MyCtx),M,P,T,B,H} where H<:Union{Nothing,DisableHooks}
                                            where B<:Union{Nothing,IdDict{Module,Dict{Symbol,BindingMeta}}}
                                            where P<:AbstractPass
                                            where T<:Union{Nothing,Tag}
                                            where M

Constructors

Given a context type alias named e.g. MyCtx, an instance of the type can be constructed via:

MyCtx(; metadata = nothing, pass = Cassette.NoPass())

To construct a new context instance using an existing context instance as a template, see the similarcontext function.

To enable contextual tagging for a given context instance, see the enabletagging function.

Fields

  • name::N<:Cassette.AbstractContextName: a parameter used to disambiguate different contexts for overloading purposes (e.g. distinguishes MyCtx from other Context type aliases).

  • metadata::M<:Any: trace-local metadata as provided to the context constructor

  • tag::T<:Union{Nothing,Tag}: the tag object that is attached to values when they are tagged w.r.t. the context instance

  • pass::P<:Cassette.AbstractPass: the Cassette pass that will be applied to all method bodies encountered during contextual execution (see the @pass macro for details).

  • bindingscache::B<:Union{Nothing,BindingMetaDictCache}: storage for metadata associated with tagged module bindings

  • hooktoggle::H<:Union{Nothing,DisableHooks}: configuration toggle for disabling the overdub pass's prehook/posthook injection (see disablehooks for details)

Cassette.similarcontext (Function)
similarcontext(context::Context;
               metadata = context.metadata,
               pass = context.pass)

Return a copy of the given context, where the copy's metadata and/or pass fields are replaced with those provided via the corresponding keyword arguments.
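A minimal sketch of similarcontext in use. The context name SimCtx and the integer metadata values are illustrative assumptions. Only the metadata keyword is overridden, so the copy is expected to inherit the original pass.

```julia
using Cassette

Cassette.@context SimCtx   # illustrative context name

ctx  = SimCtx(metadata = 1)
ctx2 = Cassette.similarcontext(ctx; metadata = 2)

@assert ctx.metadata == 1        # the original context is unchanged
@assert ctx2.metadata == 2       # the copy carries the replacement metadata
@assert ctx2.pass === ctx.pass   # `pass` was not given, so it is inherited
```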

Cassette.disablehooks (Function)
disablehooks(context::Cassette.Context)

Return a copy of the given context with prehook/posthook injection disabled for the context. Disabling hook injection can reduce IR bloat in scenarios where these hooks are not being utilized.
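The following sketch contrasts a context with hooks enabled against one returned by disablehooks. It assumes a hypothetical HookCtx whose metadata is a Ref{Int} counter and a hypothetical function addone; with hooks disabled, the overloaded prehook is expected never to run.

```julia
using Cassette

Cassette.@context HookCtx   # hypothetical context name

# count every injected prehook call in the Ref stored as the context's metadata
Cassette.prehook(ctx::HookCtx, f, args...) = (ctx.metadata[] += 1; nothing)

addone(x) = sin(x) + 1      # hypothetical function to overdub

ctx = HookCtx(metadata = Ref(0))
Cassette.overdub(ctx, addone, 1.0)
@assert ctx.metadata[] > 0          # hooks were injected and called

quiet = Cassette.disablehooks(HookCtx(metadata = Ref(0)))
Cassette.overdub(quiet, addone, 1.0)
@assert quiet.metadata[] == 0       # no prehook calls were injected
```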

Cassette.enabletagging (Function)
enabletagging(context::Cassette.Context, f)

Return a copy of the given context with the tagging system enabled for the contextual execution of f.

Cassette uses the type of f to generate the tag field of the returned instance.

Note that it is generally unsafe to use the returned instance to contextually execute functions other than f. Specifically, in cases of nested contextual execution where both inner and outer contexts employ the tagging system, improper application of the tagging system could cause (for example) separate contexts to erroneously interfere with each other's metadata propagation.

See also: hastagging

Cassette.hastagging (Function)
hastagging(::Type{<:Cassette.Context})

Return true if the given type indicates that the contextual tagging system is enabled for context instances of the type, return false otherwise.

Example

julia> Cassette.@context MyCtx;

julia> ctx = MyCtx();

julia> Cassette.hastagging(typeof(ctx))
false

julia> ctx = Cassette.enabletagging(ctx, sum);

julia> Cassette.hastagging(typeof(ctx))
true

See also: enabletagging

Cassette.@context (Macro)
Cassette.@context Ctx

Define a new Cassette context type with the name Ctx. In reality, Ctx is simply a type alias for Cassette.Context{Cassette.nametype(Ctx)}.

Note that Cassette.overdub is automatically overloaded w.r.t. Ctx to define several primitives by default. A full list of these default primitives can be obtained by running:

methods(Cassette.overdub, (Ctx, Vararg{Any}))

Note also that many of the default primitives' signatures only match when contextual tagging is enabled.
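A short sketch of what the alias means in practice; DemoCtx and OtherCtx are illustrative names. Instances of different context aliases are distinct types, even though both are Cassette.Context instances.

```julia
using Cassette

# illustrative context names
Cassette.@context DemoCtx
Cassette.@context OtherCtx

ctx = DemoCtx()

@assert ctx isa Cassette.Context   # every context alias is a Cassette.Context
@assert ctx isa DemoCtx
@assert !(ctx isa OtherCtx)        # the name parameter keeps aliases distinct

# overdub methods applicable to DemoCtx, default primitives included
prims = methods(Cassette.overdub, (DemoCtx, Vararg{Any}))
@assert !isempty(prims)
```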

See also: Context

Cassette.overdub (Function)
overdub(context::Context, f, args...)

Execute f(args...) overdubbed with respect to context.

More specifically, execute f(args...), but with every internal method invocation g(x...) replaced by statements similar to the following:

begin
    prehook(context, g, x...)
    overdub(context, g, x...) # %n
    posthook(context, %n, g, x...)
    %n
end

Otherwise, if Cassette cannot retrieve lowered IR for the method body of f(args...), then fallback(context, f, args...) will be called instead. Cassette's canrecurse function is a useful utility for checking if this will occur.

If the injected prehook/posthook statements are not needed for your use case, you can disable their injection via the disablehooks function.

Additionally, for every method body encountered in the execution trace, apply the compiler pass associated with context if one exists. Note that this user-provided pass is performed on the method IR before method invocations are transformed into the form specified above. See the @pass macro for further details.

If Cassette.hastagging(typeof(context)), then a number of additional passes are run in order to accommodate tagged value propagation:

  • Expr(:new) is replaced with a call to Cassette.tagged_new
  • Expr(:splatnew) is replaced with a call to Cassette.tagged_splatnew
  • conditional values passed to Expr(:gotoifnot) are untagged
  • arguments to Expr(:foreigncall) are untagged
  • load/stores to external module bindings are intercepted by the tagging system

The default definition of overdub is to recursively enter the given function and continue overdubbing, but one can interrupt/redirect this recursion by overloading overdub w.r.t. a given context and/or method signature to define new contextual execution primitives. For example:

julia> using Cassette

julia> Cassette.@context Ctx;

julia> Cassette.overdub(::Ctx, ::typeof(sin), x) = cos(x)

julia> Cassette.overdub(Ctx(), x -> sin(x) + cos(x), 1) == 2 * cos(1)
true
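Because the default overdub recurses, a primitive also takes effect when the overloaded call happens deeper in the call tree. A sketch, assuming the hypothetical context SinCtx and helper functions inner and outer:

```julia
using Cassette

Cassette.@context SinCtx   # hypothetical context name
Cassette.overdub(::SinCtx, ::typeof(sin), x) = cos(x)

inner(x) = 2 * sin(x)      # `sin` is called one level below the entry point
outer(x) = inner(x) + 1

@assert Cassette.overdub(SinCtx(), outer, 1.0) == 2 * cos(1.0) + 1
@assert outer(1.0) == 2 * sin(1.0) + 1   # outside the context nothing changes
```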

See also: recurse, prehook, posthook

Cassette.@overdub (Macro)
Cassette.@overdub(ctx, expression)

A convenience macro for executing expression within the context ctx. This macro roughly expands to Cassette.recurse(ctx, () -> expression).
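A minimal sketch of the macro, reusing the sin-to-cos primitive idea from the overdub documentation; MacroCtx is an illustrative name.

```julia
using Cassette

Cassette.@context MacroCtx   # illustrative context name
Cassette.overdub(::MacroCtx, ::typeof(sin), x) = cos(x)

# the whole expression runs inside the context, so `sin` hits the primitive
y = Cassette.@overdub(MacroCtx(), sin(1.0) + 1)
@assert y == cos(1.0) + 1
```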

See also: overdub, recurse

Cassette.recurse (Function)
recurse(context::Context, f, args...)

Execute f(args...) overdubbed with respect to context.

This method performs exactly the same transformation as the default overdub transformation, but is not meant to be overloaded. Thus, one can call recurse to "continue" recursively overdubbing a function in situations where calling overdub directly on that function would dispatch to a contextual primitive (for example, from inside that primitive's own definition).

To illustrate why recurse might be useful, consider the following example which utilizes recurse as part of a Cassette-based memoization implementation for the classic Fibonacci function:

using Cassette: Cassette, @context, overdub, recurse

fib(x) = x < 3 ? 1 : fib(x - 2) + fib(x - 1)
fibtest(n) = fib(2 * n) + n

@context MemoizeCtx

function Cassette.overdub(ctx::MemoizeCtx, ::typeof(fib), x)
    result = get(ctx.metadata, x, 0)
    if result === 0
        result = recurse(ctx, fib, x)
        ctx.metadata[x] = result
    end
    return result
end

See Cassette's Contextual Dispatch documentation for more details and examples.
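A second, smaller sketch of the same pattern, with the hypothetical functions square and sumsquares and the hypothetical context LogCtx: the primitive logs each argument passed to square, then calls recurse so that square's own body is still overdubbed. Calling overdub there instead would dispatch straight back to the primitive itself.

```julia
using Cassette: Cassette, @context, overdub, recurse

@context LogCtx   # hypothetical context name

square(x) = x * x

function sumsquares(xs)
    s = 0
    for x in xs
        s += square(x)
    end
    return s
end

# Record each argument passed to `square`, then keep overdubbing its body.
# Calling `overdub(ctx, square, x)` here would re-enter this same method forever.
function Cassette.overdub(ctx::LogCtx, ::typeof(square), x)
    push!(ctx.metadata, x)
    return recurse(ctx, square, x)
end

ctx = LogCtx(metadata = Int[])
@assert overdub(ctx, sumsquares, [1, 2, 3]) == 14
@assert ctx.metadata == [1, 2, 3]
```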

Cassette.prehook (Function)
prehook(context::Context, f, args...)

Overload this Cassette method w.r.t. a given context in order to define a new contextual prehook for that context.

To understand when/how this method is called, see the documentation for overdub.

Invoking prehook is a no-op by default (it immediately returns nothing).

See also: overdub, posthook, recurse, fallback

Examples

Simple trace logging:

julia> Cassette.@context PrintCtx;

julia> Cassette.prehook(::PrintCtx, f, args...) = println(f, args)

julia> Cassette.overdub(PrintCtx(), /, 1, 2)
float(1,)
AbstractFloat(1,)
Float64(1,)
sitofp(Float64, 1)
float(2,)
AbstractFloat(2,)
Float64(2,)
sitofp(Float64, 2)
/(1.0, 2.0)
div_float(1.0, 2.0)
0.5

Counting the number of method invocations with one or more arguments of a given type:

julia> mutable struct Count{T}
           count::Int
       end

julia> Cassette.@context CountCtx;

julia> Cassette.prehook(ctx::CountCtx{Count{T}}, f, arg::T, args::T...) where {T} = (ctx.metadata.count += 1)

# count the number of calls of the form `f(::Float64, ::Float64...)`
julia> ctx = CountCtx(metadata = Count{Float64}(0));

julia> Cassette.overdub(ctx, /, 1, 2)
0.5

julia> ctx.metadata.count
2

Cassette.posthook (Function)
posthook(context::Context, output, f, args...)

Overload this Cassette method w.r.t. a given context in order to define a new contextual posthook for that context.

To understand when/how this method is called, see the documentation for overdub.

Invoking posthook is a no-op by default (it immediately returns nothing).

See also: overdub, prehook, recurse, fallback

Examples

Simple trace logging:

julia> Cassette.@context PrintCtx;

julia> Cassette.posthook(::PrintCtx, output, f, args...) = println(output, " = ", f, args)

julia> Cassette.overdub(PrintCtx(), /, 1, 2)
1.0 = sitofp(Float64, 1)
1.0 = Float64(1,)
1.0 = AbstractFloat(1,)
1.0 = float(1,)
2.0 = sitofp(Float64, 2)
2.0 = Float64(2,)
2.0 = AbstractFloat(2,)
2.0 = float(2,)
0.5 = div_float(1.0, 2.0)
0.5 = /(1.0, 2.0)
0.5

Accumulate the sum of all numeric scalar outputs encountered in the trace:

julia> mutable struct Accum
           x::Number
       end

julia> Cassette.@context AccumCtx;

julia> Cassette.posthook(ctx::AccumCtx{Accum}, out::Number, f, args...) = (ctx.metadata.x += out)

julia> ctx = AccumCtx(metadata = Accum(0));

julia> Cassette.overdub(ctx, /, 1, 2)
0.5

julia> ctx.metadata.x
13.0

Cassette.fallback (Function)
fallback(context::Context, f, args...)

Overload this Cassette method w.r.t. a given context in order to define a new contextual execution fallback for that context.

To understand when/how this method is called, see the documentation for overdub and canrecurse.

By default, invoking fallback(context, f, args...) will simply call f(args...) (with all arguments automatically untagged, if hastagging(typeof(context))).
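A sketch of a counting fallback, assuming a hypothetical FallbackCtx whose metadata is a Ref{Int}. It mirrors the documented default by calling f(args...) directly, which is only correct because tagging is not enabled here. Integer + is implemented with a compiler intrinsic, which has no lowered IR, so at least one fallback call is expected.

```julia
using Cassette

Cassette.@context FallbackCtx   # hypothetical context name

# Count fallback invocations, then make the plain call as the default does.
# Assumes tagging is NOT enabled, so no arguments need untagging.
function Cassette.fallback(ctx::FallbackCtx, f, args...)
    ctx.metadata[] += 1
    return f(args...)
end

ctx = FallbackCtx(metadata = Ref(0))
@assert Cassette.overdub(ctx, x -> x + 1, 2) == 3
@assert ctx.metadata[] > 0   # e.g. the intrinsic behind Int `+`
```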

See also: canrecurse, overdub, recurse, prehook, posthook

Cassette.canrecurse (Function)
canrecurse(context::Context, f, args...)

Return true if f(args...) has a lowered IR representation that Cassette can overdub, return false otherwise.

Alternatively, but equivalently:

Return false if recurse(context, f, args...) directly translates to fallback(context, f, args...), return true otherwise.

Note that unlike overdub, fallback, etc., this function is not intended to be overloaded.
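A sketch contrasting an ordinary method with a compiler intrinsic. CanCtx and double are illustrative names, and Core.Intrinsics.add_int stands in for any function without a Julia method body.

```julia
using Cassette

Cassette.@context CanCtx   # illustrative context name
ctx = CanCtx()

double(x) = 2x

# an ordinary Julia method has lowered IR, so Cassette can recurse into it
@assert Cassette.canrecurse(ctx, double, 1)

# an intrinsic has no Julia method body, so overdubbing it ends in `fallback`
@assert !Cassette.canrecurse(ctx, Core.Intrinsics.add_int, 1, 2)
```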

See also: overdub, fallback, recurse

Cassette.@pass (Macro)
Cassette.@pass transform

Return a Cassette pass that can be provided to the Context constructor's pass keyword argument in order to apply transform to the lowered IR representations of all methods invoked during contextual execution.

transform must be a Julia object that is callable with the following signature:

transform(::Type{<:Context}, ::Cassette.Reflection)::Union{Expr,CodeInfo}

If isa(transform(...), Expr), then the returned Expr will be emitted immediately without any additional processing. Otherwise, if isa(transform(...), CodeInfo), then the returned CodeInfo will undergo the rest of Cassette's overdubbing transformation before being emitted from the overdub generator.

Two special Expr heads are available to Cassette pass authors that are not normally valid in Julia IR. Exprs with these heads can be used to interact with the downstream built-in Cassette passes that consume them.

  • :nooverdub: Wrap an Expr with this head value around the first argument in an Expr(:call) to tell downstream built-in Cassette passes not to overdub that call. For example, Expr(:call, Expr(:nooverdub, GlobalRef(MyModule, :myfunc)), args...).

  • :contextslot: Cassette will replace any Expr(:contextslot) with the actual SlotNumber corresponding to the context object associated with the execution trace. For example, one could construct an IR element that accesses the context's metadata field by emitting: Expr(:call, Expr(:nooverdub, GlobalRef(Core, :getfield)), Expr(:contextslot), QuoteNode(:metadata))

Cassette provides a few IR-munging utility functions of interest to pass authors; for details, see insert_statements!, replace_match!, and is_ir_element.

Note that the @pass macro expands to an eval call and thus should only be called at top-level. Furthermore, to avoid world-age issues, transform should not be overloaded after it has been registered with @pass.

Note also that transform should be "relatively pure." More specifically, Julia's compiler has license to apply transform multiple times, even if only compiling a single method invocation once. Thus, it is required that transform always return a generally "equivalent" CodeInfo for a given context, method body, and signature. If your transform implementation is not naturally "pure" in this sense, then it is still possible to guarantee this property by memoizing your implementation (i.e. maintaining a cache of previously computed IR results, instead of recomputing results every time).
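The following is a minimal sketch of a pass, not a canonical one. The transform sin2cos and the context PassCtx are hypothetical. The transform assumes reflection.code_info may be modified in place, uses replace_match! to rewrite every GlobalRef named sin in each method's lowered IR into GlobalRef(Base, :cos), and returns the CodeInfo so the rest of the overdub transformation still runs. It is "relatively pure" in the sense above: the same input always yields the same rewritten IR.

```julia
using Cassette

Cassette.@context PassCtx   # hypothetical context name

# Hypothetical transform: rewrite references to a global named `sin` into
# `Base.cos`, then hand the CodeInfo back for the remaining overdub pass.
function sin2cos(::Type{<:PassCtx}, reflection::Cassette.Reflection)
    ci = reflection.code_info
    Cassette.replace_match!(x -> GlobalRef(Base, :cos),
                            x -> x isa GlobalRef && x.name === :sin,
                            ci.code)
    return ci
end

# must be evaluated at top level (the macro expands to an eval call)
const Sin2CosPass = Cassette.@pass sin2cos

ctx = PassCtx(pass = Sin2CosPass)
@assert Cassette.overdub(ctx, x -> sin(x), 1.0) == cos(1.0)
```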

See also: Context, overdub

Cassette.replace_match! (Function)
replace_match!(replace, ismatch, x)

Return x with all subelements y replaced with replace(y) if ismatch(y). If !ismatch(y), but y is of type Expr, Array, or SubArray, then replace y in x with replace_match!(replace, ismatch, y).

Generally, x should be of type Expr, Array, or SubArray.

Note that this function modifies x (and potentially its subelements) in-place.
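A small sketch on a plain quoted expression (rather than lowered IR), showing the recursive, in-place replacement; the symbols a, b, and c are arbitrary.

```julia
using Cassette

ex = :(a + b * a)

# replace every occurrence of the symbol `a`, at any depth, with `c`
Cassette.replace_match!(x -> :c, x -> x === :a, ex)

@assert ex == :(c + b * c)   # `ex` itself was modified
```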

See also: insert_statements!, is_ir_element

Cassette.insert_statements! (Function)
insert_statements!(code::Vector, codelocs::Vector, stmtcount, newstmts)

For every statement stmt at position i in code for which stmtcount(stmt, i) returns an Int, remove stmt, and in its place, insert the statements returned by newstmts(stmt, i). If stmtcount(stmt, i) returns nothing, leave stmt alone.

For every insertion, all downstream SSAValues, label indices, etc. are incremented appropriately according to the number of inserted statements.

Proper usage of this function requires that the following properties hold:

  • code is expected to be a valid value for the code field of a CodeInfo object.
  • codelocs is expected to be a valid value for the codelocs field of a CodeInfo object.
  • newstmts(stmt, i) should return a Vector of valid IR statements.
  • stmtcount and newstmts must obey stmtcount(stmt, i) == length(newstmts(stmt, i)) if isa(stmtcount(stmt, i), Int).

To gain a mental model for this function's behavior, consider the following scenario. Let's say our code object contains several statements:

code = Any[oldstmt1, oldstmt2, oldstmt3, oldstmt4, oldstmt5, oldstmt6]
codelocs = Int[1, 2, 3, 4, 5, 6]

Let's also say that our stmtcount returns 2 for stmtcount(oldstmt2, 2), returns 3 for stmtcount(oldstmt5, 5), and returns nothing for all other inputs. From this setup, we can think of code/codelocs being modified in the following manner:

newstmts2 = newstmts(oldstmt2, 2)
newstmts5 = newstmts(oldstmt5, 5)
code = Any[oldstmt1,
           newstmts2[1], newstmts2[2],
           oldstmt3, oldstmt4,
           newstmts5[1], newstmts5[2], newstmts5[3],
           oldstmt6]
codelocs = Int[1, 2, 2, 3, 4, 5, 5, 5, 6]
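The same behavior on a tiny concrete input, as a sketch: plain assignment expressions stand in for real IR statements, so there are no SSAValues or labels to renumber, and the statement contents are illustrative only.

```julia
using Cassette

# illustrative stand-ins for IR statements and their code locations
code = Any[:(a = 1), :(b = 2), :(c = 3)]
codelocs = Int[1, 2, 3]

# Replace statement 2 with two statements: a new one, then the original.
Cassette.insert_statements!(code, codelocs,
                            (stmt, i) -> i == 2 ? 2 : nothing,
                            (stmt, i) -> Any[:(b0 = 0), stmt])

@assert code == Any[:(a = 1), :(b0 = 0), :(b = 2), :(c = 3)]
@assert codelocs == [1, 2, 2, 3]   # the replaced statement's location is repeated
```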

See also: replace_match!, is_ir_element

Cassette.OVERDUB_ARGUMENTS_NAME (Constant)
Cassette.OVERDUB_ARGUMENTS_NAME

The variable name bound to overdub's tuple of non-Context arguments in its @generated method definition.

This binding can be used to manually reference/destructure overdub arguments within Expr thunks emitted by user-provided passes.

See also: OVERDUB_CONTEXT_NAME, @pass, overdub

Cassette.Reflection (Type)
Cassette.Reflection

A struct representing the information retrieved via Cassette.reflect.

A Reflection is essentially just a convenient bundle of information about a specific method invocation.

Fields

  • signature: the invocation signature (in Tuple{...} type form) for the invoked method.

  • method: the Method object associated with the invoked method.

  • static_params: a Vector representing the invoked method's static parameter list.

  • code_info: the CodeInfo object associated with the invoked method.
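One way to observe Reflection objects is from inside a pass, which receives one per method body. The sketch below assumes a hypothetical identity pass, record, that stores every Reflection it is given in a global vector. That side effect is tolerable only because the returned CodeInfo is unchanged; the generator may run more than once, so entries can repeat.

```julia
using Cassette

Cassette.@context ReflCtx   # hypothetical context name

# hypothetical identity pass: remember each Reflection, return IR unchanged
const SEEN = Cassette.Reflection[]

function record(::Type{<:ReflCtx}, r::Cassette.Reflection)
    push!(SEEN, r)
    return r.code_info
end

const RecordPass = Cassette.@pass record

triple(x) = 3x
Cassette.overdub(ReflCtx(pass = RecordPass), triple, 2)

# pick the Reflection for `triple` among all method bodies the pass saw
r = first(filter(r -> r.method.name === :triple, SEEN))
@assert r.method isa Method
@assert r.code_info isa Core.CodeInfo
@assert isempty(r.static_params)   # `triple` has no static parameters
```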

Cassette.tag (Function)
tag(value, context::Context, metadata = Cassette.NoMetaData())

Return value tagged w.r.t. context, optionally associating metadata with the returned Tagged instance.

Any provided metadata must obey the type constraints determined by Cassette's metadatatype method.

Note that hastagging(typeof(context)) must be true for a value to be tagged w.r.t. context.

See also: untag, enabletagging, hastagging

Cassette.untag (Function)
untag(x, context::Context)

Return x untagged w.r.t. context if istagged(x, context), otherwise return x directly.

In other words, untag(tag(x, context), context) === x is always true.

If !istagged(x, context), then untag(x, context) === x is true.
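A sketch of the round trip described above, combined with metadata. TagCtx, g, and the choice of Float64 metadata are assumptions made for illustration. The metadatatype overload is needed because attaching metadata is disallowed by default.

```julia
using Cassette

Cassette.@context TagCtx   # hypothetical context name

# allow any Number to carry metadata of its own type (see metadatatype)
Cassette.metadatatype(::Type{<:TagCtx}, ::Type{T}) where {T<:Number} = T

g(x) = x                   # hypothetical function the tagging is enabled for
ctx = Cassette.enabletagging(TagCtx(), g)

t = Cassette.tag(1.0, ctx, 2.0)   # value 1.0 carrying metadata 2.0

@assert Cassette.istagged(t, ctx)
@assert Cassette.hasmetadata(t, ctx)
@assert Cassette.metadata(t, ctx) === 2.0
@assert Cassette.untag(t, ctx) === 1.0
@assert Cassette.untag(1.0, ctx) === 1.0   # an untagged input comes back as-is
```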

See also: tag, istagged

Cassette.untagtype (Function)
untagtype(::Type{T}, ::Type{C<:Context})

Return typeof(untag(::T, ::C)).

In other words, untagtype(typeof(tag(x, context)), typeof(context)) === typeof(x) is always true.

If !istaggedtype(T, C), then untagtype(T, C) === T is true.

Cassette.metadata (Function)
metadata(x, context::Context)

Return the metadata attached to x if hasmetadata(x, context), otherwise return Cassette.NoMetaData().

In other words, metadata(tag(x, context, m), context) === m is always true.

If !hasmetadata(x, context), then metadata(x, context) === Cassette.NoMetaData() is true.

Cassette.metadatatype (Function)
metadatatype(::Type{<:Context}, ::Type{T})

Overload this Cassette method w.r.t. a given context to define the type of metadata that can be tagged to values of type T within that context.

By default, this method is set such that associating metadata with any tagged value is disallowed.

Cassette uses metadatatype to statically compute a context-specific metadata type hierarchy for all tagged values within overdubbed programs. To gain a mental model for this mechanism, consider a simple struct definition as follows:

struct Foo
    x::Int
    y::Complex{Int}
end

Now, Cassette can use metadatatype to determine type constraints for metadata structures associated with tagged values of type Foo. In pseudo-Julia code, these metadata structures might look something like the following for Foo:

struct IntMeta
    data::metadatatype(Ctx, Int)
    meta::Cassette.NoMetaMeta
end

struct ComplexIntMeta
    data::metadatatype(Ctx, Complex{Int})
    meta::NamedTuple{(:re,:im),Tuple{IntMeta,IntMeta}}
end

struct FooMeta
    data::metadatatype(Ctx, Foo)
    meta::NamedTuple{(:x,:y),Tuple{IntMeta,ComplexIntMeta}}
end

Examples

julia> Cassette.@context Ctx;

# any value of type `Number` can now be tagged with metadata of type `Number`
julia> Cassette.metadatatype(::Type{<:Ctx}, ::Type{<:Number}) = Number

# any value of type `T<:Number` can now be tagged with metadata of type `T`
julia> Cassette.metadatatype(::Type{<:Ctx}, ::Type{T}) where {T<:Number} = T

# any value of type `T<:Number` can now be tagged with metadata of type `promote_type(T, M)`
# where `M` is the type of the trace-local metadata associated with the context
julia> Cassette.metadatatype(::Type{<:Ctx{M}}, ::Type{T}) where {M<:Number,T<:Number} = promote_type(T, M)

Cassette.hasmetadata (Function)
hasmetadata(x, context::Context)

Return true if !isa(metadata(x, context), Cassette.NoMetaData), return false otherwise.

In other words, hasmetadata(tag(x, context, m), context) is always true and hasmetadata(tag(x, context), context) is always false.

See also: metadata

Cassette.istagged (Function)
istagged(x, context::Context)

Return true if x is tagged w.r.t. context, return false otherwise.

In other words, istagged(tag(x, context), context) is always true.

See also: tag, istaggedtype

Cassette.istaggedtype (Function)
istaggedtype(::Type{T}, ::Type{C<:Context})

Return the value of istagged(::T, ::C), i.e. true if values of type T are tagged w.r.t. contexts of type C, false otherwise.

In other words, istaggedtype(typeof(tag(x, context)), typeof(context)) is always true.
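A companion sketch for the type-level queries istaggedtype and untagtype. TypeCtx and h are illustrative names; no metadata is attached, so no metadatatype overload is needed.

```julia
using Cassette

Cassette.@context TypeCtx   # illustrative context name

h(x) = x                    # illustrative function the tagging is enabled for
ctx = Cassette.enabletagging(TypeCtx(), h)
t = Cassette.tag(1.0, ctx)  # tagged, with no metadata attached

T, C = typeof(t), typeof(ctx)

@assert Cassette.istaggedtype(T, C)
@assert !Cassette.istaggedtype(Float64, C)
@assert Cassette.untagtype(T, C) === Float64
@assert Cassette.untagtype(Float64, C) === Float64   # untagged types pass through
```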

See also: tag, istagged
