Program Instances

Finch relies heavily on Julia's metaprogramming capabilities ( macros and generated functions in particular) to produce code. To review briefly, a macro allows us to inspect the syntax of it's arguments and generate replacement syntax. A generated function allows us to inspect the type of the function arguments and produce code for a function body.

In normal Finch usage, we might call Finch as follows:

julia> C = Tensor(SparseList(Element(0)));

julia> A = Tensor(SparseList(Element(0)), [0, 2, 0, 0, 3]);

julia> B = Tensor(Dense(Element(0)), [11, 12, 13, 14, 15]);

julia> @finch (C .= 0; for i=_; C[i] = A[i] * B[i] end);

julia> C
5-Tensor
└─ SparseList (0) [1:5]
   ├─ [2]: 24
   └─ [5]: 45

The @macroexpand macro allows us to see the result of applying a macro. Let's examine what happens when we use the @finch macro (we've stripped line numbers from the result to clean it up):

julia> (@macroexpand @finch (C .= 0; for i=_; C[i] = A[i] * B[i] end)) |> Finch.striplines |> Finch.regensym
quote
    _res_1 = (Finch.execute)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.block_instance)((Finch.FinchNotation.declare_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(0)), begin
                        let i = index_instance(i)
                            (Finch.FinchNotation.loop_instance)(i, Finch.FinchNotation.Dimensionless(), (Finch.FinchNotation.assign_instance)((Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:C), (Finch.FinchNotation.finch_leaf_instance)(C)), literal_instance(Finch.FinchNotation.Updater()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.literal_instance)(Finch.FinchNotation.initwrite), (Finch.FinchNotation.call_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:*), (Finch.FinchNotation.finch_leaf_instance)(*)), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:A), (Finch.FinchNotation.finch_leaf_instance)(A)), literal_instance(Finch.FinchNotation.Reader()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))), (Finch.FinchNotation.access_instance)((Finch.FinchNotation.tag_instance)(variable_instance(:B), (Finch.FinchNotation.finch_leaf_instance)(B)), literal_instance(Finch.FinchNotation.Reader()), (Finch.FinchNotation.tag_instance)(variable_instance(:i), (Finch.FinchNotation.finch_leaf_instance)(i))))))
                        end
                    end), (Finch.FinchNotation.yieldbind_instance)(variable_instance(:C))); )
    begin
        C = _res_1[:C]
    end
    begin
        _res_1
    end
end

In the above output, @finch creates an AST of program instances, then calls Finch.execute on it. A program instance is a struct that contains the program to be executed along with its arguments. Although we can use the above constructors (e.g. loop_instance) to make our own program instance, it is most convenient to use the unexported macro Finch.finch_program_instance:

julia> using Finch: @finch_program_instance

julia> prgm = Finch.@finch_program_instance (C .= 0; for i=_; C[i] = A[i] * B[i] end; return C)
Finch program instance: begin
  tag(C, Tensor(SparseList(Element(0)))) .= 0
  for i = Dimensionless()
    tag(C, Tensor(SparseList(Element(0))))[tag(i, i)] <<initwrite>>= tag(*, *)(tag(A, Tensor(SparseList(Element(0))))[tag(i, i)], tag(B, Tensor(Dense(Element(0))))[tag(i, i)])
  end
  return (tag(C, Tensor(SparseList(Element(0)))))
end

As we can see, our program instance contains not only the AST to be executed, but also the data to execute the program with. The type of the program instance contains only the program portion; there may be many program instances with different inputs, but the same program type. We can run our program using Finch.execute, which returns a NamedTuple of outputs.

julia> typeof(prgm)
Finch.FinchNotation.BlockInstance{Tuple{Finch.FinchNotation.DeclareInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{0}}, Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Dimensionless, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Updater()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.initwrite}, Finch.FinchNotation.CallInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:*}, Finch.FinchNotation.LiteralInstance{*}}, Tuple{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:B}, Tensor{DenseLevel{Int64, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}}}, Finch.FinchNotation.YieldBindInstance{Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:C}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}}}}}

julia> C = Finch.execute(prgm).C
5-Tensor
└─ SparseList (0) [1:5]
   ├─ [2]: 24
   └─ [5]: 45

This functionality is sufficient for building finch kernels programatically. For example, if we wish to define a function pointwise_sum() that takes the pointwise sum of a variable number of vector inputs, we might implement it as follows:

julia> function pointwise_sum(As...)
           B = Tensor(Dense(Element(0)))
           isempty(As) && return B
           i = Finch.FinchNotation.index_instance(:i)
           A_vars = [Finch.FinchNotation.tag_instance(Finch.FinchNotation.variable_instance(Symbol(:A, n)), As[n]) for n in 1:length(As)]
           #create a list of variable instances with different names to hold the input tensors
           ex = @finch_program_instance 0
           for A_var in A_vars
               ex = @finch_program_instance $A_var[i] + $ex
           end
           prgm = @finch_program_instance (B .= 0; for i=_; B[i] = $ex end; return B)
           return Finch.execute(prgm).B
       end
pointwise_sum (generic function with 1 method)

julia> pointwise_sum([1, 2], [3, 4])
2-Tensor
└─ Dense [1:2]
   ├─ [1]: 4
   └─ [2]: 6

Virtualization

Finch generates different code depending on the types of the arguments to the program. For example, in the following program, Finch generates different code depending on the types of A and B. In order to execute a program, Finch builds a typed AST (Abstract Syntax Tree), then calls Finch.execute on it. The AST object is just an instance of a program to execute, and contains the program to execute along with the data to execute it. The type of the program instance contains only the program portion; there may be many program instances with different inputs, but the same program type. During compilation, Finch uses the type of the program to construct a more ergonomic representation, which is then used to generate code. This process is called "virtualization". All of the Finch AST nodes have both instance and virtual representations. For example, the literal 42 is represented as Finch.FinchNotation.LiteralInstance(42) and then virtualized to literal(42). The virtualization process is implemented by the virtualize function.

julia> A = Tensor(SparseList(Element(0)), [0, 2, 0, 0, 3]);

julia> B = Tensor(Dense(Element(0)), [11, 12, 13, 14, 15]);

julia> s = Scalar(0);

julia> typeof(A)
Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}

julia> typeof(B)
Tensor{DenseLevel{Int64, ElementLevel{0, Int64, Int64, Vector{Int64}}}}

julia> inst = Finch.@finch_program_instance begin
           for i = _
               s[] += A[i]
           end
       end
Finch program instance: for i = Dimensionless()
  tag(s, Scalar{0, Int64})[] <<tag(+, +)>>= tag(A, Tensor(SparseList(Element(0))))[tag(i, i)]
end

julia> typeof(inst)
Finch.FinchNotation.LoopInstance{Finch.FinchNotation.IndexInstance{:i}, Finch.FinchNotation.Dimensionless, Finch.FinchNotation.AssignInstance{Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:s}, Scalar{0, Int64}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Updater()}, Tuple{}}, Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:+}, Finch.FinchNotation.LiteralInstance{+}}, Finch.FinchNotation.AccessInstance{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:A}, Tensor{SparseListLevel{Int64, Vector{Int64}, Vector{Int64}, ElementLevel{0, Int64, Int64, Vector{Int64}}}}}, Finch.FinchNotation.LiteralInstance{Finch.FinchNotation.Reader()}, Tuple{Finch.FinchNotation.TagInstance{Finch.FinchNotation.VariableInstance{:i}, Finch.FinchNotation.IndexInstance{:i}}}}}}

julia> Finch.virtualize(Finch.JuliaContext(), :inst, typeof(inst))
Finch program: for i = virtual(Finch.FinchNotation.Dimensionless)
  tag(s, virtual(Finch.VirtualScalar))[] <<tag(+, +)>>= tag(A, virtual(Finch.VirtualFiber{Finch.VirtualSparseListLevel}))[tag(i, i)]
end

julia> @finch_code begin
           for i = _
               s[] += A[i]
           end
       end
quote
    s = (ex.bodies[1]).body.lhs.tns.bind
    s_val = s.val
    A_lvl = (ex.bodies[1]).body.rhs.tns.bind.lvl
    A_lvl_ptr = A_lvl.ptr
    A_lvl_idx = A_lvl.idx
    A_lvl_val = A_lvl.lvl.val
    A_lvl_q = A_lvl_ptr[1]
    A_lvl_q_stop = A_lvl_ptr[1 + 1]
    if A_lvl_q < A_lvl_q_stop
        A_lvl_i1 = A_lvl_idx[A_lvl_q_stop - 1]
    else
        A_lvl_i1 = 0
    end
    phase_stop = min(A_lvl_i1, A_lvl.shape)
    if phase_stop >= 1
        if A_lvl_idx[A_lvl_q] < 1
            A_lvl_q = Finch.scansearch(A_lvl_idx, 1, A_lvl_q, A_lvl_q_stop - 1)
        end
        while true
            A_lvl_i = A_lvl_idx[A_lvl_q]
            if A_lvl_i < phase_stop
                A_lvl_2_val = A_lvl_val[A_lvl_q]
                s_val = A_lvl_2_val + s_val
                A_lvl_q += 1
            else
                phase_stop_3 = min(phase_stop, A_lvl_i)
                if A_lvl_i == phase_stop_3
                    A_lvl_2_val = A_lvl_val[A_lvl_q]
                    s_val += A_lvl_2_val
                    A_lvl_q += 1
                end
                break
            end
        end
    end
    result = ()
    s.val = s_val
    result
end

julia> @finch_code begin
           for i = _
               s[] += B[i]
           end
       end
quote
    s = (ex.bodies[1]).body.lhs.tns.bind
    s_val = s.val
    B_lvl = (ex.bodies[1]).body.rhs.tns.bind.lvl
    B_lvl_val = B_lvl.lvl.val
    for i_3 = 1:B_lvl.shape
        B_lvl_q = (1 - 1) * B_lvl.shape + i_3
        B_lvl_2_val = B_lvl_val[B_lvl_q]
        s_val = B_lvl_2_val + s_val
    end
    result = ()
    s.val = s_val
    result
end

The "virtual" IR Node

Users can also create their own virtual nodes to represent their custom types. While most calls to virtualize result in a Finch IR Node, some objects, such as tensors and dimensions, are virtualized to a virtual object, which holds the custom virtual type. These types may contain constants and other virtuals, as well as reference variables in the scope of the executing context. Any aspect of virtuals visible to Finch should be considered immutable, but virtuals may reference mutable variables in the scope of the executing context.

Finch.virtualizeFunction
virtualize(ctx, ex, T, [tag])

Return the virtual program corresponding to the Julia expression ex of type T in the JuliaContext ctx. Implementaters may support the optional tag argument is used to name the resulting virtual variable.

source
Finch.FinchNotation.virtualConstant
virtual(val)

Finch AST expression for an object val which has special meaning to the compiler. This type is typically used for tensors, as it allows users to specify the tensor's shape and data type.

source

Virtual Methods

Many methods have analogues we can call on the virtual version of the object. For example, we can call size an an array, and virtual_size on a virtual array. The virtual methods are used to generate code, so if they are pure they may return an expression which computes the results, and if they have side effects they may accept a context argument into which they can emit their side-effecting code.

In addition to the special compiler methods which are prefixed virtual_, there is also a function virtual_call, which is used to evaluate function calls on Finch IR when it would result in a virtual object. The behavior should mirror the concrete behavior of the corresponding function.

Finch.virtual_callFunction
virtual_call(ctx, f, a...)

Given the virtual arguments a..., and a literal function f, return a virtual object representing the result of the function call. If the function is not foldable, return nothing. This function is used so that we can call e.g. tensor constructors in finch code.

source

Working with Finch IR

Calling print on a finch program or program instance will print the structure of the program as one would call constructors to build it. For example,

julia> prgm_inst = Finch.@finch_program_instance for i = _
            s[] += A[i]
        end;

julia> println(prgm_inst)
loop_instance(index_instance(i), Finch.FinchNotation.Dimensionless(), assign_instance(access_instance(tag_instance(variable_instance(:s), Scalar{0, Int64}(0)), literal_instance(Finch.FinchNotation.Updater())), tag_instance(variable_instance(:+), literal_instance(+)), access_instance(tag_instance(variable_instance(:A), Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal_instance(Finch.FinchNotation.Reader()), tag_instance(variable_instance(:i), index_instance(i)))))

julia> prgm_inst
Finch program instance: for i = Dimensionless()
  tag(s, Scalar{0, Int64})[] <<tag(+, +)>>= tag(A, Tensor(SparseList(Element(0))))[tag(i, i)]
end

julia> prgm = Finch.@finch_program for i = _
               s[] += A[i]
           end;


julia> println(prgm)
loop(index(i), virtual(Finch.FinchNotation.Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Finch.FinchNotation.Updater())), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Finch.FinchNotation.Reader()), index(i))))

julia> prgm
Finch program: for i = virtual(Finch.FinchNotation.Dimensionless)
  Scalar{0, Int64}(0)[] <<+>>= Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))[i]
end

Both the virtual and instance representations of Finch IR define SyntaxInterface.jl and AbstractTrees.jl representations, so you can use the standard operation, arguments, istree, and children functions to inspect the structure of the program, as well as the rewriters defined by RewriteTools.jl

julia> using Finch.FinchNotation;


julia> PostOrderDFS(prgm)
PostOrderDFS{FinchNode}(loop(index(i), virtual(Dimensionless()), assign(access(literal(Scalar{0, Int64}(0)), literal(Updater())), literal(+), access(literal(Tensor(SparseList{Int64}(Element{0, Int64, Int64}([2, 3]), 5, [1, 3], [2, 5]))), literal(Reader()), index(i)))))

julia> (@capture prgm loop(~idx, ~ext, ~val))
true

julia> idx
Finch program: i