--- lua-tikz3dtools-scene.lua
--- Scene management, TeX command registration, and rendering for lua-tikz3dtools.

local Vector
local Matrix
local Geometry

local Scene = {}
local lua_tikz3dtools = {}

--- Store the class references this file works with.
-- Meant to be called once all modules are loaded; the three file-level locals
-- stay nil until then.
-- @param V value stored in the file-level `Vector` local
-- @param M value stored in the file-level `Matrix` local
-- @param G value stored in the file-level `Geometry` local
function Scene._set_classes(V, M, G)
    Vector, Matrix, Geometry = V, M, G
end


-- ================================================================
-- Sandboxed evaluation environment
-- ================================================================

--- Guard a table in place so that assigning a new key raises an error.
-- The metatable is attached to `value` itself (no copy is made) and hidden
-- behind `__metatable = false`. Only `__newindex` is intercepted, so keys that
-- already exist in `value` can still be overwritten.
-- @param value table to guard
-- @param label name reported in the error message
-- @return the same table, now carrying the guarding metatable
local function make_readonly_table(value, label)
    local function reject_write(_, key)
        error(("sandbox value '%s' is read-only; cannot assign '%s'")
            :format(label, tostring(key)), 2)
    end

    local guard = {
        __newindex = reject_write,
        __metatable = false,
    }
    return setmetatable(value, guard)
end

--- Create an empty stand-in table that forwards reads to `source`.
-- Writes through the stand-in raise an error; `source` itself is untouched and
-- later changes to it stay visible through the proxy.
-- @param source table that answers every lookup
-- @param label name reported in the error message
-- @return a new proxy table with a hidden metatable
local function make_readonly_proxy(source, label)
    local function reject_write(_, key)
        error(("sandbox value '%s' is read-only; cannot assign '%s'")
            :format(label, tostring(key)), 2)
    end

    local forwarding = {
        __index = source,
        __newindex = reject_write,
        __metatable = false,
    }
    return setmetatable({}, forwarding)
end

--- Inputs for the base environment used by all sandbox evaluations.
--- User-defined objects are resolved separately so expressions cannot mutate
--- shared global state or shadow built-in names.
-- `blocked_globals` is the set (name -> true) of globals that are left out of
-- the sandbox when the base environment is assembled.
local blocked_globals = {}
for _, name in ipairs({ "debug", "dofile", "load", "loadfile", "package", "require" }) do
    blocked_globals[name] = true
end

-- Set (name -> true) of library tables that the sandbox sees through a
-- read-only proxy instead of the real table.
local proxied_tables = {}
for _, name in ipairs({ "coroutine", "io", "math", "os", "string", "table", "utf8" }) do
    proxied_tables[name] = true
end

--- Assemble the table of names every sandboxed chunk may read.
-- Walks `_G` once: names listed in `blocked_globals` are dropped, library
-- tables listed in `proxied_tables` are exposed through read-only proxies, and
-- every other global is copied by reference.
--
-- `_G` itself is special-cased. Copying it verbatim would hand sandboxed code
-- the real global table, through which the blocked names (`_G.load`,
-- `_G.require`, ...) stay reachable and real globals can be assigned. The
-- sandbox instead sees `_G` as a read-only view of this environment.
--
-- NOTE(review): this closes one escape hatch only; it does not make the
-- sandbox airtight (for example, the functions inside `io` and `os` remain
-- callable through their proxies).
-- @return table the freshly built environment
local function build_base_env()
    local env = {}

    for key, value in pairs(_G) do
        -- `_G` is handled after the loop; see the header comment.
        if key ~= "_G" and not blocked_globals[key] then
            if proxied_tables[key] and type(value) == "table" then
                env[key] = make_readonly_proxy(value, key)
            else
                env[key] = value
            end
        end
    end

    -- Guarantee the two libraries the built-in expressions rely on.
    env.table = env.table or make_readonly_proxy(table, "table")
    env.math = env.math or make_readonly_proxy(math, "math")

    -- Reads through `_G` resolve against the sandbox names; writes are rejected.
    env._G = make_readonly_proxy(env, "_G")
    return env
end

--- Produce a short one-line rendering of user source for error messages.
-- Whitespace runs are collapsed to a single space, leading whitespace is
-- removed, and anything longer than 160 bytes is cut to 157 bytes plus "...".
-- @param str source text; nil or false is treated as the empty string
-- @return string the preview
local function source_preview(str)
    local limit = 160

    local text = tostring(str or "")
    text = text:gsub("%s+", " ")
    text = text:gsub("^%s+", "")

    if #text <= limit then
        return text
    end
    return text:sub(1, limit - 3) .. "..."
end

--- Build the two-line error text reported for failing user code.
-- @param kind short description of the failure category
-- @param label where the source came from; defaults to "expression"
-- @param source the offending source, shown through `source_preview`
-- @param err the underlying error value
-- @return string the formatted message
local function format_eval_error(kind, label, source, err)
    local template = "%s in %s: %s\nSource: %s"
    local origin = label or "expression"
    return template:format(kind, origin, tostring(err), source_preview(source))
end

--- Create the environment table handed to a sandboxed chunk.
-- Global reads are resolved, in order, against `bindings`, the user objects
-- in `lua_tikz3dtools.objects`, and `lua_tikz3dtools.base_env`. Global writes
-- raise an error.
-- @param bindings optional table of per-evaluation names
-- @return table an empty table whose metatable performs the lookups
local function make_eval_env(bindings)
    local scoped = bindings or {}

    local function resolve(_, key)
        local found = scoped[key]
        if found == nil then
            found = lua_tikz3dtools.objects[key]
        end
        if found == nil then
            found = lua_tikz3dtools.base_env[key]
        end
        return found
    end

    local function reject_write(_, key)
        error(("sandbox is read-only; cannot assign global '%s'")
            :format(tostring(key)), 2)
    end

    return setmetatable({}, {
        __index = resolve,
        __newindex = reject_write,
        __metatable = false,
    })
end

--- Compile and run user source inside the sandbox.
-- Compilation failures and runtime failures are both re-raised as messages
-- built by `format_eval_error`, without position information.
-- @param source Lua source text (text chunks only)
-- @param label chunk name, also used in error messages
-- @param bindings optional per-evaluation names
-- @return the first value returned by the chunk
local function evaluate_chunk(source, label, bindings)
    local env = make_eval_env(bindings)
    local compiled, compile_err = load(source, label or "expression", "t", env)
    if not compiled then
        error(format_eval_error("Lua syntax error", label, source, compile_err), 0)
    end

    local succeeded, outcome = pcall(compiled)
    if succeeded then
        return outcome
    end
    error(format_eval_error("Lua evaluation error", label, source, outcome), 0)
end

--- Wrap a compiled user function so its failures carry source context.
-- The wrapper forwards all arguments, returns only the first result, and
-- turns any error into a message built by `format_eval_error`.
-- @param fn the function to protect
-- @param label name used in the error message
-- @param source original source text shown in the error message
-- @return function the protected wrapper
local function wrap_user_function(fn, label, source)
    local function guarded(...)
        local succeeded, outcome = pcall(fn, ...)
        if succeeded then
            return outcome
        end
        error(format_eval_error("Lua function error", label, source, outcome), 0)
    end

    return guarded
end

-- Keywords that mark user source as a statement block rather than a bare
-- expression.
local statement_keywords = {
    "return",
    "local",
    "if",
    "for",
    "while",
    "repeat",
    "do",
}

--- Report whether `trimmed` begins with one of `statement_keywords`.
-- The keyword must stand alone as a whole word. Lua identifiers may contain
-- `_`, which the `%w` class does not cover, so the boundary is tested with
-- `%f[^%w_]`: identifiers such as `return_value` or `do_it()` are expressions,
-- not statements.
-- @param trimmed source text with leading whitespace already removed
-- @return boolean true when the text starts with a statement keyword
local function starts_with_statement(trimmed)
    for _, keyword in ipairs(statement_keywords) do
        if trimmed:find("^" .. keyword .. "%f[^%w_]") then
            return true
        end
    end

    return false
end

-- ================================================================
-- TeX command registration helper
-- https://tex.stackexchange.com/a/747040
-- ================================================================

--- Bind a Lua callback to a TeX control sequence.
-- The control sequence is named `__lua_tikztdtools_<name>:` followed by one
-- `n` per entry of `args`. When the command runs, each `token.scan_<arg>`
-- scanner is called in order and the scanned values are passed to `func`.
-- @param name command name (without prefix)
-- @param func callback receiving the scanned values
-- @param args list of scanner suffixes, e.g. "int" for `token.scan_int`
-- @param protected truthy to register the command as "protected"
local function register_tex_cmd(name, func, args, protected)
    local csname = "__lua_tikztdtools_" .. name .. ":" .. ("n"):rep(#args)

    -- Look the scanners up once, at registration time.
    local scanners = {}
    for _, suffix in ipairs(args) do
        scanners[#scanners + 1] = token["scan_" .. suffix]
    end

    local function run_command()
        local scanned = {}
        for _, scan in ipairs(scanners) do
            scanned[#scanned + 1] = scan()
        end
        func(table.unpack(scanned))
    end

    local slot = luatexbase.new_luafunction(csname)
    lua.get_functions_table()[slot] = run_command

    if protected then
        token.set_lua(csname, slot, "protected")
    else
        token.set_lua(csname, slot)
    end
end

-- ================================================================
-- Global scene state
-- ================================================================

-- Start with empty collections for pending simplices, lights and named user
-- objects, then build the sandbox base environment.
for _, field in ipairs({ "simplices", "lights", "objects" }) do
    lua_tikz3dtools[field] = {}
end
lua_tikz3dtools.base_env = build_base_env()

--- Publish the math helpers into the sandbox base environment.
-- Adds read-only proxies for the current `Vector` and `Matrix` values and the
-- constant `tau` (2 * pi).
local function refresh_base_env()
    local env = lua_tikz3dtools.base_env

    env.Vector = make_readonly_proxy(Vector, "Vector")
    env.Matrix = make_readonly_proxy(Matrix, "Matrix")
    env.tau = 2 * math.pi
end

--- Late initialisation hook: run once Vector/Matrix have been set so the
--- sandbox base environment picks them up.
Scene._init_math_env = function()
    refresh_base_env()
end

-- ================================================================
-- Expression evaluators (sandboxed)
-- ================================================================

--- Evaluate `str` as a single Lua expression by prefixing it with `return`.
-- @param str expression source
-- @param label name used in error messages
-- @param bindings optional per-evaluation names
-- @return whatever `evaluate_chunk` returns for the wrapped source
local function single_string_expression(str, label, bindings)
    local wrapped = string.format("return %s", str)
    return evaluate_chunk(wrapped, label, bindings)
end

--- Evaluate `str` as a complete chunk body; the source supplies its own
--- `return`.
-- @param str chunk source
-- @param label name used in error messages
-- @param bindings optional per-evaluation names
-- @return the value produced by `evaluate_chunk`
local function body_expression(str, label, bindings)
    local outcome = evaluate_chunk(str, label, bindings)
    return outcome
end

--- Evaluate user source that may be either a statement body or an expression.
-- Source whose trimmed text starts with a statement keyword is run as a body;
-- everything else is treated as a single expression. The untrimmed source is
-- what gets evaluated.
-- @param str source text
-- @param label name used in error messages
-- @param bindings optional per-evaluation names
-- @return the evaluated value
local function object_expression(str, label, bindings)
    local trimmed = str:match("^%s*(.-)%s*$") or ""

    local evaluator = single_string_expression
    if starts_with_statement(trimmed) then
        evaluator = body_expression
    end
    return evaluator(str, label, bindings)
end

--- Compile `str` as the body of a one-argument function `function(u) ... end`.
-- @param str function body source
-- @param label name used in error messages
-- @param bindings optional per-evaluation names
-- @return function the compiled function, wrapped for error reporting
local function single_string_function(str, label, bindings)
    local wrapped = string.format("return function(u) %s end", str)
    local compiled = evaluate_chunk(wrapped, label, bindings)
    return wrap_user_function(compiled, label, str)
end

--- Compile `str` as the body of a two-argument function
--- `function(u,v) ... end`.
-- @param str function body source
-- @param label name used in error messages
-- @param bindings optional per-evaluation names
-- @return function the compiled function, wrapped for error reporting
local function double_string_function(str, label, bindings)
    local wrapped = string.format("return function(u,v) %s end", str)
    local compiled = evaluate_chunk(wrapped, label, bindings)
    return wrap_user_function(compiled, label, str)
end

--- Compile `str` as the body of a three-argument function
--- `function(u,v,w) ... end`.
-- @param str function body source
-- @param label name used in error messages
-- @param bindings optional per-evaluation names
-- @return function the compiled function, wrapped for error reporting
local function triple_string_function(str, label, bindings)
    local wrapped = string.format("return function(u,v,w) %s end", str)
    local compiled = evaluate_chunk(wrapped, label, bindings)
    return wrap_user_function(compiled, label, str)
end

--- Check that `value` is a number that is neither NaN nor infinite.
-- @param value anything
-- @return boolean
local function is_finite_number(value)
    if type(value) ~= "number" then
        return false
    end
    if value ~= value then
        -- NaN is the only value that differs from itself.
        return false
    end
    return not (value == math.huge or value == -math.huge)
end

--- Check that a Vector or Matrix contains only finite numbers.
-- A value whose metatable is `Vector` is scanned as a flat list; one whose
-- metatable is `Matrix` is scanned row by row. Anything else is rejected.
-- @param simplex value to inspect
-- @return boolean true only for a Vector/Matrix with all-finite entries
local function is_finite_simplex(simplex)
    local function all_finite(list)
        for index = 1, #list do
            if not is_finite_number(list[index]) then
                return false
            end
        end
        return true
    end

    local kind = getmetatable(simplex)

    if kind == Vector then
        return all_finite(simplex)
    elseif kind == Matrix then
        for row = 1, #simplex do
            if not all_finite(simplex[row]) then
                return false
            end
        end
        return true
    end

    return false
end

--- Apply `transformation` to a point and keep the result only if it is finite.
-- @param v point object exposing `:multiply`; a falsy value yields nil
-- @param transformation value passed to `v:multiply`
-- @param label name used in the error message; defaults to "point"
-- @return the projected point, or nil when `v` is falsy or the result is not
--         a finite simplex
-- Raises an error (without position) when the multiplication itself fails.
local function project_point(v, transformation, label)
    if not v then
        return nil
    end

    local function apply()
        return v:multiply(transformation)
    end

    local succeeded, outcome = pcall(apply)
    if not succeeded then
        error(("Projection failed for %s: %s")
            :format(label or "point", tostring(outcome)), 0)
    end

    if not is_finite_simplex(outcome) then
        return nil
    end
    return outcome
end

--- Apply `transformation` to a simplex, then call
--- `reciprocate_by_homogeneous` on the product, keeping only finite results.
-- @param simplex object exposing `:multiply`; a falsy value yields nil
-- @param transformation value passed to `simplex:multiply`
-- @param label name used in the error message; defaults to "simplex"
-- @return the projected simplex, or nil when the input is falsy or the result
--         is not a finite simplex
-- Raises an error (without position) when either call fails.
local function project_simplex(simplex, transformation, label)
    if not simplex then
        return nil
    end

    local function apply()
        local product = simplex:multiply(transformation)
        return product:reciprocate_by_homogeneous()
    end

    local succeeded, outcome = pcall(apply)
    if not succeeded then
        error(("Projection failed for %s: %s")
            :format(label or "simplex", tostring(outcome)), 0)
    end

    if not is_finite_simplex(outcome) then
        return nil
    end
    return outcome
end

--- Queue a scene entry when it carries a finite simplex.
-- @param entry table with a `simplex` field plus rendering attributes
-- @return boolean true when the entry was appended to
--         `lua_tikz3dtools.simplices`, false when it was rejected
local function push_simplex(entry)
    local candidate = entry.simplex
    if not (candidate and is_finite_simplex(candidate)) then
        return false
    end

    table.insert(lua_tikz3dtools.simplices, entry)
    return true
end

--- Tell whether `value` is a string with at least one character.
-- @param value anything
-- @return boolean
local function is_nonempty_string(value)
    if type(value) ~= "string" then
        return false
    end
    return value ~= ""
end

--- Turn a Vector with at least two finite leading components into the
--- three-component Vector `{u, v, 1}`.
-- @param value candidate Vector
-- @return a new Vector, or nil when `value` is not a Vector, is too short, or
--         has a non-finite first or second component
local function uv_curve_point(value)
    local is_vector = getmetatable(value) == Vector
    if not is_vector or #value < 2 then
        return nil
    end

    local u, v = value[1], value[2]
    if is_finite_number(u) and is_finite_number(v) then
        return Vector:_new{u, v, 1}
    end
    return nil
end

--- Accept either a Vector or a plain table and convert it with
--- `uv_curve_point`; plain tables are first passed through `Vector:_new`.
-- @param value Vector, table, or anything else
-- @return the converted point, or nil for unsupported or invalid input
local function uv_curve_point_value(value)
    local candidate

    if getmetatable(value) == Vector then
        candidate = value
    elseif type(value) == "table" then
        candidate = Vector:_new(value)
    else
        return nil
    end

    return uv_curve_point(candidate)
end

--- Evaluate user source describing line segments in (u, v) parameter space.
-- The source must return a list; each usable item supplies its end points as
-- `start`/`stop` fields or as positions 1 and 2, plus optional `drawoptions`.
-- Items with an invalid end point, or whose end points are closer than 1e-12
-- according to `hdistance`, are dropped.
-- @param str chunk source returning the list of segments
-- @param default_drawoptions draw options for items that specify none
-- @param label name used in error messages
-- @return list of `{simplex = Matrix, drawoptions = ...}`, or nil when the
--         source does not return a table or no item is usable
local function explicit_uv_curve_segments(str, default_drawoptions, label)
    local raw_segments = body_expression(str, label)
    if type(raw_segments) ~= "table" then
        return nil
    end

    local collected = {}
    for _, item in ipairs(raw_segments) do
        if type(item) == "table" or getmetatable(item) == Matrix then
            local head = uv_curve_point_value(item.start or item[1])
            local tail = uv_curve_point_value(item.stop or item[2])

            if head and tail and head:hdistance(tail) > 1e-12 then
                collected[#collected + 1] = {
                    simplex = Matrix:_new{head:to_table(), tail:to_table()},
                    drawoptions = item.drawoptions or default_drawoptions
                }
            end
        end
    end

    if #collected > 0 then
        return collected
    end
    return nil
end

--- Clip (u, v) segments to one parameter-space triangle and express the
--- surviving end points in that triangle's barycentric coordinates.
-- @param uv_segments list from `explicit_uv_curve_segments`, or nil
-- @param uv_triangle Matrix holding the triangle's (u, v, 1) corners
-- @return list of `{start = table, stop = table, drawoptions = ...}`, or nil
--         when there is no input or nothing falls inside the triangle
local function embedded_segments_in_triangle(uv_segments, uv_triangle)
    if uv_segments == nil then
        return nil
    end

    local inside = {}
    for _, segment in ipairs(uv_segments) do
        local clipped = Geometry.hclip_line_segment_to_triangle(segment.simplex, uv_triangle)
        if clipped ~= nil then
            local head = Geometry.hpoint_triangle_barycentric(Vector:_new(clipped[1]), uv_triangle)
            local tail = Geometry.hpoint_triangle_barycentric(Vector:_new(clipped[2]), uv_triangle)
            if head ~= nil and tail ~= nil then
                inside[#inside + 1] = {
                    start = head:to_table(),
                    stop = tail:to_table(),
                    drawoptions = segment.drawoptions
                }
            end
        end
    end

    if #inside > 0 then
        return inside
    end
    return nil
end

--- Emit TikZ paths for the curve pieces embedded in a triangle entry.
-- Each piece is stored in barycentric coordinates and is mapped back onto the
-- entry's simplex before being written with `tex.sprint`. Pieces whose mapped
-- end points are closer than 1e-12 (per `hdistance`) are skipped.
-- @param simplex scene entry; does nothing when it has no `embedded_segments`
local function render_embedded_segments(simplex)
    local pieces = simplex.embedded_segments
    if pieces == nil then
        return
    end

    local path_template = "\\path[%s] (%f,%f) -- (%f,%f);"

    for _, piece in ipairs(pieces) do
        local head = Geometry.hpoint_from_triangle_barycentric(
            simplex.simplex,
            Vector:_new(piece.start)
        )
        local tail = Geometry.hpoint_from_triangle_barycentric(
            simplex.simplex,
            Vector:_new(piece.stop)
        )

        if head:hdistance(tail) > 1e-12 then
            local options = piece.drawoptions or ""
            tex.sprint(path_template:format(options, head[1], head[2], tail[1], tail[2]))
        end
    end
end

-- ================================================================
-- Append functions
-- ================================================================

--- Add a point to the scene.
-- `hash.v` is evaluated as a chunk body and `hash.transformation` as an
-- object expression; the projected point is queued only when the projection
-- yields a finite result.
-- @param hash table with `v`, `transformation`, `filloptions`, `filter`
local function append_point(hash)
    local position = body_expression(hash.v, "appendpoint.v")
    local transformation = object_expression(hash.transformation, "appendpoint.transformation")
    if not position then
        return
    end

    local projected = project_point(position, transformation, "appendpoint")
    if not projected then
        return
    end

    push_simplex({
        simplex     = projected,
        filloptions = hash.filloptions,
        type        = "point",
        filter      = hash.filter
    })
end

--- Unpack a sampling specification given as a three-component Vector.
-- @param value Vector holding start, stop and sample count
-- @param key_name name of the option, used in assertion messages
-- @return start, stop, samples
-- Fails an assertion when `value` is not a Vector or lacks a component.
local function param_triplet(value, key_name)
    local type_message = key_name .. " must return a Vector"
    assert(value and getmetatable(value) == Vector, type_message)

    local shape_message = key_name .. " must contain start, stop, and samples"
    local first, last, count = value[1], value[2], value[3]
    assert(first ~= nil and last ~= nil and count ~= nil, shape_message)

    return first, last, count
end

--- Evaluate a sampling-specification source and unpack its three components.
-- @param params_src chunk source returning the specification Vector
-- @param key_name option name, used as evaluation label and in messages
-- @return start, stop, samples
local function resolve_axis_params(params_src, key_name)
    local specification = body_expression(params_src, key_name)
    return param_triplet(specification, key_name)
end

--- Add a parametric surface to the scene as a grid of triangles.
-- The (u, v) domain is sampled on a regular grid. Every cell with four truthy
-- corners A, B, C, D is split into the triangles (A, B, C) and (A, D, C). A
-- triangle is skipped when any two of its corners coincide according to
-- `Geometry.hpoint_point_intersecting`, or when its projection is not finite.
-- When `hash.curve` is a non-empty string, the (u, v) segments it describes
-- are clipped to each triangle and stored with it as `embedded_segments`.
-- @param hash table with `uparams`, `vparams`, `v`, `transformation`,
--        `filloptions`, `filter` and optional `curve`
local function append_surface(hash)
    local ustart, ustop, usamples = resolve_axis_params(hash.uparams, "appendsurface.uparams")
    local vstart, vstop, vsamples = resolve_axis_params(hash.vparams, "appendsurface.vparams")
    local transformation = object_expression(hash.transformation, "appendsurface.transformation")
    local surface = double_string_function(hash.v, "appendsurface.v")
    local filloptions, filter = hash.filloptions, hash.filter

    assert(usamples and usamples >= 2, "usamples must be >= 2, got: " .. tostring(usamples))
    assert(vsamples and vsamples >= 2, "vsamples must be >= 2, got: " .. tostring(vsamples))

    local ustep = (ustop - ustart) / (usamples - 1)
    local vstep = (vstop - vstart) / (vsamples - 1)

    local uv_curve_segments
    if is_nonempty_string(hash.curve) then
        uv_curve_segments = explicit_uv_curve_segments(hash.curve, nil, "appendsurface.curve")
    end

    -- Project and queue one triangle (P, Q, R) with its parameter-space twin.
    local function emit_triangle(P, Q, R, uvP, uvQ, uvR, label)
        if Geometry.hpoint_point_intersecting(P, Q)
            or Geometry.hpoint_point_intersecting(Q, R)
            or Geometry.hpoint_point_intersecting(P, R)
        then
            return
        end

        local projected = project_simplex(
            Matrix:_new{P:to_table(), Q:to_table(), R:to_table()},
            transformation,
            label
        )
        if not projected then
            return
        end

        push_simplex({
            simplex           = projected,
            filloptions       = filloptions,
            type              = "triangle",
            filter            = filter,
            embedded_segments = embedded_segments_in_triangle(
                uv_curve_segments,
                Matrix:_new{uvP:to_table(), uvQ:to_table(), uvR:to_table()}
            )
        })
    end

    for i = 0, usamples - 2 do
        local u = ustart + i * ustep
        for j = 0, vsamples - 2 do
            local v = vstart + j * vstep

            local A = surface(u, v)
            local B = surface(u + ustep, v)
            local C = surface(u + ustep, v + vstep)
            local D = surface(u, v + vstep)

            if A and B and C and D then
                local uvA = Vector:_new{u, v, 1}
                local uvB = Vector:_new{u + ustep, v, 1}
                local uvC = Vector:_new{u + ustep, v + vstep, 1}
                local uvD = Vector:_new{u, v + vstep, 1}

                emit_triangle(A, B, C, uvA, uvB, uvC, "appendsurface.triangle1")
                emit_triangle(A, D, C, uvA, uvD, uvC, "appendsurface.triangle2")
            end
        end
    end
end

--- Add a single triangle to the scene.
-- `hash.m` must evaluate to a 3-row Matrix. The triangle is skipped when any
-- two corners coincide according to `Geometry.hpoint_point_intersecting`, or
-- when its projection is not finite.
-- @param hash table with `m`, `transformation`, `filloptions`, `filter`
local function append_triangle(hash)
    local transformation = object_expression(hash.transformation, "appendtriangle.transformation")

    local source = hash.m
    assert(source and source ~= "", "appendtriangle.m must return a 3-row Matrix")

    local triangle = object_expression(source, "appendtriangle.m")
    assert(getmetatable(triangle) == Matrix, "appendtriangle.m must return a Matrix")
    assert(#triangle == 3, "appendtriangle.m must return a 3-row Matrix")

    local A = Vector:_new(triangle[1])
    local B = Vector:_new(triangle[2])
    local C = Vector:_new(triangle[3])

    local degenerate = Geometry.hpoint_point_intersecting(A, B)
        or Geometry.hpoint_point_intersecting(B, C)
        or Geometry.hpoint_point_intersecting(C, A)
    if degenerate then
        return
    end

    local projected = project_simplex(triangle, transformation, "appendtriangle")
    if not projected then
        return
    end

    push_simplex({
        simplex     = projected,
        filloptions = hash.filloptions,
        type        = "triangle",
        filter      = hash.filter
    })
end

--- Add a text label anchored at a point.
-- `hash.v` is evaluated as a chunk body and `hash.transformation` as an
-- object expression; the label is queued only when the projected anchor is
-- finite.
-- @param hash table with `v`, `text`, `transformation`, `filter`
local function append_label(hash)
    local anchor = body_expression(hash.v, "appendlabel.v")
    local transformation = object_expression(hash.transformation, "appendlabel.transformation")
    if not anchor then
        return
    end

    local projected = project_point(anchor, transformation, "appendlabel")
    if not projected then
        return
    end

    push_simplex({
        simplex     = projected,
        text        = hash.text,
        type        = "label",
        filter      = hash.filter
    })
end

--- Register a light for shading.
-- `hash.v` is evaluated as a chunk body and must yield a Vector.
-- @param hash table with field `v`
-- Raises an error when the evaluated value is not a Vector.
local function append_light(hash)
    local light = body_expression(hash.v, "appendlight.v")

    if not (light and getmetatable(light) == Vector) then
        error("Invalid light vector: " .. tostring(light))
    end

    table.insert(lua_tikz3dtools.lights, light)
end

--- Add a parametric curve to the scene as a chain of line segments, with
--- optional markers at its two ends.
-- The parameter range is sampled uniformly; each pair of neighbouring samples
-- that are both truthy becomes one projected "line segment" entry. When
-- `hash.arrowtail` / `hash.arrowtip` are set, a small surface is appended at
-- the first / last sample, oriented along the curve's direction there.
--
-- The tail marker's parameter bounds are written with `math.pi`. As far as
-- this file shows, the sandbox exposes `math` and `tau` but no bare `pi`, so
-- a bare `pi` would resolve to nil inside the evaluated source.
-- @param hash table with `uparams`, `v`, `transformation`, `drawoptions`,
--        `filter` and optional `arrowtip` / `arrowtail` fill options
local function append_curve(hash)
    local ustart, ustop, usamples = resolve_axis_params(hash.uparams, "appendcurve.uparams")
    local transformation = object_expression(hash.transformation, "appendcurve.transformation")
    local f              = single_string_function(hash.v, "appendcurve.v")
    local filter         = hash.filter
    local drawoptions    = hash.drawoptions
    local arrowoptions   = hash.arrowtip
    local tailoptions    = hash.arrowtail

    assert(usamples and usamples >= 2, "usamples must be >= 2, got: " .. tostring(usamples))

    local ustep = (ustop - ustart) / (usamples - 1)

    -- Append a marker surface at projected point P, oriented along P - Q.
    -- P and Q are already projected, so the marker's own transformation is
    -- just the local frame (W, V, U) translated to P.
    local function append_marker(P, Q, marker)
        if not (P and Q and P:hdistance(Q) > 1e-12) then
            return
        end

        local U = P:hsub(Q):hnormalize()
        local V = U:orthogonal_vector():hnormalize()
        local W = U:hhypercross(V):hnormalize()

        append_surface{
            uparams = marker.uparams,
            vparams = marker.vparams,
            v = marker.v,
            filloptions = marker.filloptions,
            transformation = ([[
                return Matrix:new{
                    {%f,%f,%f,0},{%f,%f,%f,0},{%f,%f,%f,0},{%f,%f,%f,1}
                }
            ]]):format(
                W[1],W[2],W[3], V[1],V[2],V[3], U[1],U[2],U[3], P[1],P[2],P[3]
            ),
            filter = filter
        }
    end

    for i = 0, usamples - 2 do
        local u = ustart + i * ustep
        local A = f(u)
        local B = f(u + ustep)
        if A and B then
            local simplex = project_simplex(
                Matrix:_new{A:to_table(), B:to_table()},
                transformation,
                "appendcurve.segment"
            )
            if simplex then
                push_simplex({
                    simplex      = simplex,
                    drawoptions  = drawoptions,
                    type         = "line segment",
                    filter       = filter
                })
            end

            if i == 0 and tailoptions then
                append_marker(
                    project_point(f(ustart), transformation, "appendcurve.arrowtail.start"),
                    project_point(f(ustart + ustep), transformation, "appendcurve.arrowtail.stop"),
                    {
                        uparams = "return Vector:new{0, tau, 6}",
                        vparams = "return Vector:new{math.pi/2, math.pi, 2}",
                        v = "return Vector:_new{0.1*math.sin(v)*math.cos(u), 0.1*math.sin(v)*math.sin(u), 0*0.1*math.cos(v), 1}",
                        filloptions = tailoptions,
                    }
                )
            end

            if i == usamples - 2 and arrowoptions then
                append_marker(
                    project_point(f(ustop), transformation, "appendcurve.arrowtip.start"),
                    project_point(f(ustop - ustep), transformation, "appendcurve.arrowtip.stop"),
                    {
                        uparams = "return Vector:new{0, 0.1, 2}",
                        vparams = "return Vector:new{0, 1, 4}",
                        v = "return Vector:_new{u*math.cos(v*tau), u*math.sin(v*tau), -u, 1}",
                        filloptions = arrowoptions,
                    }
                )
            end
        end
    end
end

--- Add the boundary of a parametric solid to the scene.
-- The solid is given by a function of (u, v, w). Only its six boundary faces
-- are tessellated: for each axis, the faces at that axis' start and stop
-- values are sampled over the other two axes. Every grid cell with corners
-- A, B, C, D yields the triangles (A, B, D) and (B, C, D), each queued when
-- its corners are truthy and its projection is finite.
-- @param hash table with `uparams`, `vparams`, `wparams`, `v`,
--        `transformation`, `filloptions`, `filter`
local function append_solid(hash)
    local ustart, ustop, usamples = resolve_axis_params(hash.uparams, "appendsolid.uparams")
    local vstart, vstop, vsamples = resolve_axis_params(hash.vparams, "appendsolid.vparams")
    local wstart, wstop, wsamples = resolve_axis_params(hash.wparams, "appendsolid.wparams")
    local filloptions, filter = hash.filloptions, hash.filter
    local transformation = object_expression(hash.transformation, "appendsolid.transformation")
    local solid = triple_string_function(hash.v, "appendsolid.v")

    assert(usamples and usamples >= 2, "usamples must be >= 2, got: " .. tostring(usamples))
    assert(vsamples and vsamples >= 2, "vsamples must be >= 2, got: " .. tostring(vsamples))
    assert(wsamples and wsamples >= 2, "wsamples must be >= 2, got: " .. tostring(wsamples))

    local ustep = (ustop - ustart) / (usamples - 1)
    local vstep = (vstop - vstart) / (vsamples - 1)
    local wstep = (wstop - wstart) / (wsamples - 1)

    -- For each fixed axis: how (fixed value, s1, s2) maps onto (u, v, w).
    local sample_face = {
        u = function(fixed, s1, s2) return solid(fixed, s1, s2) end,
        v = function(fixed, s1, s2) return solid(s1, fixed, s2) end,
        w = function(fixed, s1, s2) return solid(s1, s2, fixed) end,
    }

    -- Project and queue one face triangle when all three corners exist.
    local function emit_triangle(P, Q, R, label)
        if not (P and Q and R) then
            return
        end
        local projected = project_simplex(
            Matrix:_new{P:to_table(), Q:to_table(), R:to_table()},
            transformation,
            label
        )
        if projected then
            push_simplex({
                simplex     = projected,
                filloptions = filloptions,
                type        = "triangle",
                filter      = filter
            })
        end
    end

    local function tessellate_face(fixed_var, fixed_val, s1_start, s1_step, s1_count, s2_start, s2_step, s2_count)
        local sample = sample_face[fixed_var]
        for i = 0, s1_count - 2 do
            local s1 = s1_start + i * s1_step
            for j = 0, s2_count - 2 do
                local s2 = s2_start + j * s2_step
                local A, B, C, D
                if sample then
                    A = sample(fixed_val, s1, s2)
                    B = sample(fixed_val, s1 + s1_step, s2)
                    C = sample(fixed_val, s1 + s1_step, s2 + s2_step)
                    D = sample(fixed_val, s1, s2 + s2_step)
                end
                emit_triangle(A, B, D, "appendsolid.face1")
                emit_triangle(B, C, D, "appendsolid.face2")
            end
        end
    end

    tessellate_face("u", ustart, vstart, vstep, vsamples, wstart, wstep, wsamples)
    tessellate_face("u", ustop,  vstart, vstep, vsamples, wstart, wstep, wsamples)
    tessellate_face("v", vstart, ustart, ustep, usamples, wstart, wstep, wsamples)
    tessellate_face("v", vstop,  ustart, ustep, usamples, wstart, wstep, wsamples)
    tessellate_face("w", wstart, ustart, ustep, usamples, vstart, vstep, vsamples)
    tessellate_face("w", wstop,  ustart, ustep, usamples, vstart, vstep, vsamples)
end

-- ================================================================
-- Filters (sandboxed)
-- ================================================================

--- Keep only the simplices whose filter source evaluates to a truthy value.
-- The filter body runs inside the sandbox with the entry's corners bound as
-- `A` (points and labels), `A`, `B` (line segments) or `A`, `B`, `C`
-- (triangles). Entries without a filter are kept.
--
-- Each distinct (label, source) pair is compiled once. The compiled function
-- reads `A`, `B`, `C` from a bindings table owned by that cache entry, which
-- is refreshed before every call, instead of recompiling the same source for
-- every simplex.
-- @param simplices list of scene entries
-- @return a new list holding the entries that passed, in their original order
local function apply_filters(simplices)
    local new_simplices = {}
    local compiled = {}  -- filter_label -> filter_body -> { bindings, run }

    for _, simplex in ipairs(simplices) do
        -- Corners are built before compiling so errors surface in the same
        -- order as they would with per-simplex compilation.
        local A, B, C
        if simplex.type == "point" or simplex.type == "label" then
            A = Vector:_new(simplex.simplex:to_table())
        elseif simplex.type == "line segment" then
            A = Vector:_new(simplex.simplex[1])
            B = Vector:_new(simplex.simplex[2])
        elseif simplex.type == "triangle" then
            A = Vector:_new(simplex.simplex[1])
            B = Vector:_new(simplex.simplex[2])
            C = Vector:_new(simplex.simplex[3])
        end

        local filter_body = simplex.filter or "return true"
        local filter_label = ("filter[%s]"):format(simplex.type)

        local by_body = compiled[filter_label]
        if by_body == nil then
            by_body = {}
            compiled[filter_label] = by_body
        end

        local entry = by_body[filter_body]
        if entry == nil then
            local bindings = {}
            local filter_fn = evaluate_chunk(
                ("return function()\n%s\nend"):format(filter_body),
                filter_label,
                bindings
            )
            entry = {
                bindings = bindings,
                run = wrap_user_function(filter_fn, filter_label, filter_body),
            }
            by_body[filter_body] = entry
        end

        -- Assign all three so no corner from a previous simplex lingers.
        entry.bindings.A, entry.bindings.B, entry.bindings.C = A, B, C

        if entry.run() then
            table.insert(new_simplices, simplex)
        end
    end

    return new_simplices
end

-- ================================================================
-- Display / render
-- ================================================================

--- Render every queued simplex to TikZ and reset the scene.
-- Stages, in order: compute `bbox2` for line segments and triangles, partition
-- the simplices against each other, apply the per-entry filters, fill in any
-- missing `bbox2`, sort with `Geometry.scc`, then emit TikZ code through
-- `tex.sprint`. Labels are emitted last, after all other entries. Progress is
-- written with `print`. Finally the simplex and light lists are emptied.
--
-- Missing `filloptions`, `drawoptions` or label `text` are emitted as an empty
-- string, as `render_embedded_segments` already does for its draw options;
-- formatting nil with `%s` would otherwise put the literal `nil` into the
-- TikZ code.
local function display_simplices()
    print("Time:" .. os.date("%X") .. " Displaying " .. #lua_tikz3dtools.simplices .. " simplices.")

    -- Pre-compute bbox2 for all simplices
    for _, s in ipairs(lua_tikz3dtools.simplices) do
        if s.type ~= "point" and s.type ~= "label" then
            s.bbox2 = s.simplex:get_bbox2()
        end
    end

    lua_tikz3dtools.simplices = Geometry.partition_simplices_by_parents(
        lua_tikz3dtools.simplices,
        lua_tikz3dtools.simplices
    )
    print("Time:" .. os.date("%X") .. " After partitioning, " .. #lua_tikz3dtools.simplices .. " simplices remain.")

    lua_tikz3dtools.simplices = apply_filters(lua_tikz3dtools.simplices)
    print("Time:" .. os.date("%X") .. " After filtering, " .. #lua_tikz3dtools.simplices .. " simplices remain.")

    -- Fill in bbox2 for any entry that does not carry one at this point.
    for _, s in ipairs(lua_tikz3dtools.simplices) do
        if s.type ~= "point" and s.type ~= "label" and not s.bbox2 then
            s.bbox2 = s.simplex:get_bbox2()
        end
    end

    lua_tikz3dtools.simplices = Geometry.scc(lua_tikz3dtools.simplices)
    print("Time:" .. os.date("%X") .. " After occlusion sorting, " .. #lua_tikz3dtools.simplices .. " simplices remain.")

    local labels = {}
    for _, simplex in ipairs(lua_tikz3dtools.simplices) do
        if simplex.type == "point" then
            tex.sprint(
                ("\\path[%s] (%f,%f) circle[radius = 0.06];")
                :format(simplex.filloptions or "", simplex.simplex[1], simplex.simplex[2])
            )
        elseif simplex.type == "line segment" then
            tex.sprint(
                ("\\path[%s] (%f,%f) -- (%f,%f);")
                :format(
                    simplex.drawoptions or "",
                    simplex.simplex[1][1], simplex.simplex[1][2],
                    simplex.simplex[2][1], simplex.simplex[2][2]
                )
            )
        elseif simplex.type == "triangle" then
            -- Define the colour `ltdtbrightness` before drawing the triangle
            -- so its fill options can refer to it.
            local num_lights = #lua_tikz3dtools.lights
            if num_lights > 0 then
                local A = Vector:_new(simplex.simplex[1])
                local B = Vector:_new(simplex.simplex[2])
                local C = Vector:_new(simplex.simplex[3])
                local normal = (B:hsub(A)):hhypercross(C:hsub(A)):hnormalize()

                local total_intensity = 0
                for _, light in ipairs(lua_tikz3dtools.lights) do
                    local light_dir = light:hnormalize()
                    -- abs(): the side of the triangle facing the light does
                    -- not matter.
                    local cos_theta = math.abs(normal:hinner(light_dir))
                    if cos_theta > 1 then cos_theta = 1 end
                    -- Linear falloff: 0° → 1.0, 90° → 0.0
                    local theta = math.deg(math.acos(cos_theta))
                    total_intensity = total_intensity + (1 - theta / 90)
                end

                -- Average over all lights, as a whole-number percentage.
                local avg_intensity = math.floor((total_intensity / num_lights) * 100 + 0.01)
                tex.sprint(("\\colorlet{ltdtbrightness}{white!%f!black}"):format(avg_intensity))
            else
                tex.sprint(("\\colorlet{ltdtbrightness}{white!%f!black}"):format(0))
            end

            tex.sprint(
                ("\\path[%s] (%f,%f) -- (%f,%f) -- (%f,%f) -- cycle;")
                :format(
                    simplex.filloptions or "",
                    simplex.simplex[1][1], simplex.simplex[1][2],
                    simplex.simplex[2][1], simplex.simplex[2][2],
                    simplex.simplex[3][1], simplex.simplex[3][2]
                )
            )
            render_embedded_segments(simplex)
        elseif simplex.type == "label" then
            -- Deferred so that labels are emitted after every other entry.
            table.insert(labels, simplex)
        end
    end

    for _, simplex in ipairs(labels) do
        tex.sprint(
            ("\\node at (%f,%f) {%s};")
            :format(simplex.simplex[1], simplex.simplex[2], simplex.text or "")
        )
    end

    -- The scene is consumed: start the next picture from an empty state.
    lua_tikz3dtools.simplices = {}
    lua_tikz3dtools.lights = {}
end

-- ================================================================
-- set_object
-- ================================================================

--- Evaluate an object expression and store it under a user-chosen name.
-- The stored value becomes visible to later sandbox evaluations through
-- `lua_tikz3dtools.objects`.
-- @param hash table with `object` (source text) and `name`
-- @return the evaluated object
-- Fails an assertion when the name is not a non-empty string or is already
-- taken by the base environment.
local function set_object(hash)
    local value = object_expression(hash.object, "setobject.object")
    local name = hash.name

    local usable_name = type(name) == "string" and name ~= ""
    assert(usable_name, "setobject.name must be a non-empty string")

    local reserved_message =
        ("setobject.name '%s' is reserved and cannot be rebound"):format(name)
    assert(lua_tikz3dtools.base_env[name] == nil, reserved_message)

    lua_tikz3dtools.objects[name] = value
    return value
end

-- ================================================================
-- Register all TeX commands
-- ================================================================

--- Fetch a TeX macro's text, substituting `fallback` when the macro is
--- undefined or empty.
-- @param name macro name passed to `token.get_macro`
-- @param fallback value returned for a missing or empty macro
-- @return the macro text or `fallback`
local function get_macro_or(name, fallback)
    local text = token.get_macro(name)
    if text ~= nil and text ~= "" then
        return text
    end
    return fallback
end

--- Resolve the sampling specification source for one parameter axis.
-- Preference order: the `<prefix>@<axis>params` macro; otherwise the legacy
-- trio `<axis>start`, `<axis>stop`, `<axis>samples`, combined into a
-- `Vector:new{...}` source; otherwise `fallback`.
-- @param prefix macro name prefix
-- @param axis axis letter, e.g. "u"
-- @param fallback source returned when none of the macros is set
-- @return string source text for the axis specification
-- Fails an assertion when only some of the legacy macros are set.
local function get_axis_params_or_legacy(prefix, axis, fallback)
    local stem = prefix .. "@" .. axis

    local function is_set(text)
        return text ~= nil and text ~= ""
    end

    local params_value = token.get_macro(stem .. "params")
    if is_set(params_value) then
        return params_value
    end

    local start = token.get_macro(stem .. "start")
    local stop = token.get_macro(stem .. "stop")
    local samples = token.get_macro(stem .. "samples")

    if not (is_set(start) or is_set(stop) or is_set(samples)) then
        return fallback
    end

    -- Legacy keys are all-or-nothing.
    local requirement = " must be set when using legacy " .. axis .. "* keys"
    assert(is_set(start), axis .. "start" .. requirement)
    assert(is_set(stop), axis .. "stop" .. requirement)
    assert(is_set(samples), axis .. "samples" .. requirement)

    return ("return Vector:new{%s,%s,%s}"):format(start, stop, samples)
end

--- Register every lua-tikz3dtools TeX command through `register_tex_cmd`.
-- Each callback reads its options from the `luatikztdtools@p@...` macros when
-- the command runs, substitutes defaults for the optional keys, and passes the
-- resulting option table to the matching handler.
function Scene.register_commands()
    -- Default expressions shared by several commands.
    local identity_transform = "return Matrix.identity()"
    local accept_all = "return true"
    local unit_range = "return Vector:new{0,1,10}"

    register_tex_cmd("appendpoint", function()
        local opts = {}
        opts.v = token.get_macro("luatikztdtools@p@p@v")
        opts.filloptions = get_macro_or("luatikztdtools@p@p@filloptions", "")
        opts.transformation = get_macro_or("luatikztdtools@p@p@transformation", identity_transform)
        opts.filter = get_macro_or("luatikztdtools@p@p@filter", accept_all)
        append_point(opts)
    end, {})

    register_tex_cmd("appendsurface", function()
        local opts = {}
        opts.uparams = get_axis_params_or_legacy("luatikztdtools@p@s", "u", unit_range)
        opts.vparams = get_axis_params_or_legacy("luatikztdtools@p@s", "v", unit_range)
        opts.v = token.get_macro("luatikztdtools@p@s@v")
        opts.curve = token.get_macro("luatikztdtools@p@s@curve")
        opts.transformation = get_macro_or("luatikztdtools@p@s@transformation", identity_transform)
        opts.filloptions = get_macro_or("luatikztdtools@p@s@filloptions", "")
        opts.filter = get_macro_or("luatikztdtools@p@s@filter", accept_all)
        append_surface(opts)
    end, {})

    register_tex_cmd("appendtriangle", function()
        local opts = {}
        opts.m = token.get_macro("luatikztdtools@p@t@m")
        opts.transformation = get_macro_or("luatikztdtools@p@t@transformation", identity_transform)
        opts.filloptions = get_macro_or("luatikztdtools@p@t@filloptions", "")
        opts.filter = get_macro_or("luatikztdtools@p@t@filter", accept_all)
        append_triangle(opts)
    end, {})

    register_tex_cmd("appendlabel", function()
        local opts = {}
        opts.v = token.get_macro("luatikztdtools@p@l@v")
        opts.text = token.get_macro("luatikztdtools@p@l@text")
        opts.transformation = get_macro_or("luatikztdtools@p@l@transformation", identity_transform)
        opts.filter = get_macro_or("luatikztdtools@p@l@filter", accept_all)
        append_label(opts)
    end, {})

    register_tex_cmd("appendlight", function()
        local opts = {}
        opts.v = token.get_macro("luatikztdtools@p@la@v")
        append_light(opts)
    end, {})

    register_tex_cmd("appendcurve", function()
        local opts = {}
        opts.uparams = get_axis_params_or_legacy("luatikztdtools@p@c", "u", unit_range)
        opts.v = token.get_macro("luatikztdtools@p@c@v")
        opts.transformation = get_macro_or("luatikztdtools@p@c@transformation", identity_transform)
        opts.drawoptions = get_macro_or("luatikztdtools@p@c@drawoptions", "")
        -- Arrow tips have no default: an unset macro leaves the field absent.
        opts.arrowtip = token.get_macro("luatikztdtools@p@c@arrowtip")
        opts.arrowtail = token.get_macro("luatikztdtools@p@c@arrowtail")
        opts.filter = get_macro_or("luatikztdtools@p@c@filter", accept_all)
        append_curve(opts)
    end, {})

    register_tex_cmd("appendsolid", function()
        local opts = {}
        opts.uparams = get_axis_params_or_legacy("luatikztdtools@p@solid", "u", unit_range)
        opts.vparams = get_axis_params_or_legacy("luatikztdtools@p@solid", "v", unit_range)
        opts.wparams = get_axis_params_or_legacy("luatikztdtools@p@solid", "w", unit_range)
        opts.v = token.get_macro("luatikztdtools@p@solid@v")
        opts.transformation = get_macro_or("luatikztdtools@p@solid@transformation", identity_transform)
        opts.filloptions = get_macro_or("luatikztdtools@p@solid@filloptions", "")
        opts.filter = get_macro_or("luatikztdtools@p@solid@filter", accept_all)
        append_solid(opts)
    end, {})

    -- Wrapped so the handler is always invoked with no arguments.
    register_tex_cmd("displaysimplices", function()
        display_simplices()
    end, {})

    register_tex_cmd("setobject", function()
        local opts = {}
        opts.name = token.get_macro("luatikztdtools@p@m@name")
        opts.object = token.get_macro("luatikztdtools@p@m@object")
        set_object(opts)
    end, {})
end

-- Module export: hand the Scene table back to whoever loads this file.
return Scene
