--- lua-tikz3dtools-matrix.lua
--- Matrix class for the lua-tikz3dtools package.

-- Reference to the Vector class. It is not required here, to avoid a circular
-- dependency; it starts out nil and is assigned by Matrix._set_Vector(), which
-- must run before any function below that tests for or constructs Vectors.
local Vector -- set via Matrix._set_Vector()

-- Method table for Matrix objects. Instances get it as their metatable, and
-- __index = Matrix makes every method below reachable from an instance.
local Matrix = {}
Matrix.__index = Matrix

--- Create a new Matrix object (validated).
--- The input must be a non-empty array of rows, all with the same number of
--- columns, whose entries are numbers or Vectors. When a column's first entry
--- is a Vector, every entry of that column must be a Vector of the same size.
--- The table is adopted in place (its metatable is set), not copied.
--- @param matrix table The matrix data (row-major)
--- @return Matrix
function Matrix:new(matrix)
    assert(type(matrix) == "table", "A Matrix object must be a table of rows.")
    local l = #matrix
    -- Guard before indexing matrix[1]: an empty table would otherwise fail
    -- with a raw "attempt to get length of a nil value" error.
    assert(l > 0, "A Matrix must have at least one row.")
    assert(type(matrix[1]) == "table", "A Matrix row must be a table.")
    local p = #matrix[1]
    for i = 1, l do
        local row = matrix[i]
        assert(type(row) == "table", "A Matrix row must be a table.")
        assert(#row == p, "All Matrix rows must have the same number of columns.")
        for j = 1, p do
            local entry = row[j]
            local is_vector = getmetatable(entry) == Vector
            assert(
                type(entry) == "number" or is_vector,
                "A Matrix can only contain numbers or Vectors"
            )
            if is_vector and i == 1 then
                -- Validate the whole column once, from its first row. Check
                -- the type before taking the length so that a number (or
                -- nil) below a Vector yields the validation message rather
                -- than a raw "attempt to get length" error.
                for k = 2, l do
                    local other_row = matrix[k]
                    assert(type(other_row) == "table", "A Matrix row must be a table.")
                    local cell = other_row[j]
                    assert(
                        getmetatable(cell) == Vector and #cell == #entry,
                        "All Vectors in a Matrix column must have the same size."
                    )
                end
            end
        end
    end
    return setmetatable(matrix, Matrix)
end

--- Wrap a table as a Matrix without any validation (internal fast path).
--- The given table itself becomes the Matrix: its metatable is set in place
--- and the same table is returned.
--- @param matrix table
--- @return Matrix
function Matrix:_new(matrix)
    local wrapped = setmetatable(matrix, Matrix)
    return wrapped
end

--- Render the Matrix as a multi-line string, one "\t{...}" line per row.
--- Numbers are printed with 12 decimal places; any other entry (e.g. a
--- Vector) uses its own tostring. Pieces are collected and joined with
--- table.concat rather than repeated `..`, which is quadratic in output size.
--- @return string
function Matrix:__tostring()
    local lines = {}
    for i = 1, #self do
        local row = self[i]
        local cells = {}
        for j = 1, #row do
            local value = row[j]
            if type(value) == "number" then
                cells[j] = string.format("%.12f", value)
            else
                cells[j] = tostring(value)
            end
        end
        lines[i] = "\t{" .. table.concat(cells, ",") .. "}"
    end
    return "Matrix{\n" .. table.concat(lines, ",\n") .. "\n}"
end

--- Exchange columns c1 and c2 in every row (mutates self).
--- @param c1 number The first column
--- @param c2 number The second column
function Matrix:swap_columns(c1, c2)
    for r = 1, #self do
        local row = self[r]
        row[c1], row[c2] = row[c2], row[c1]
    end
end

--- Elementary column operation (mutates self):
--- column dest_col <- column dest_col + scalar * column src_col.
--- @param src_col number The source column
--- @param scalar number The scalar multiple
--- @param dest_col number The destination column
function Matrix:add_scalar_multiple_of_column(src_col, scalar, dest_col)
    for r = 1, #self do
        local row = self[r]
        local increment = scalar * row[src_col]
        row[dest_col] = row[dest_col] + increment
    end
end

--- Multiply every entry of one column by a scalar (mutates self).
--- @param col number The column
--- @param scalar number The scalar
function Matrix:scale_column(col, scalar)
    for r = 1, #self do
        local row = self[r]
        row[col] = row[col] * scalar
    end
end

--- Return an independent copy of the Matrix.
--- Numeric entries are copied by value; Vector entries are rebuilt as fresh
--- Vectors from their unpacked components. The column count of every copied
--- row is taken from the first row.
--- @return Matrix A deep copy of self
function Matrix:deep_copy()
    local height = #self
    local width = #self[1]
    local copy = {}
    for r = 1, height do
        local source = self[r]
        local target = {}
        for c = 1, width do
            local entry = source[c]
            if getmetatable(entry) == Vector then
                target[c] = Vector:_new{table.unpack(entry)}
            else
                target[c] = entry
            end
        end
        copy[r] = target
    end
    return Matrix:_new(copy)
end

--- Convert to a plain nested Lua table with no Matrix metatable.
--- Vector entries are converted through their own to_table(); other entries
--- are copied as-is. The column count is taken from the first row.
--- @return table
function Matrix:to_table()
    local height = #self
    local width = #self[1]
    local plain = {}
    for r = 1, height do
        local out = {}
        for c = 1, width do
            local entry = self[r][c]
            if getmetatable(entry) ~= Vector then
                out[c] = entry
            else
                out[c] = entry:to_table()
            end
        end
        plain[r] = out
    end
    return plain
end

--- Solve a linear system by Gauss-Jordan reduction of its augmented matrix.
--- The data is read column-major: each row of self is one column of the
--- augmented matrix [A | b]. The last row of self holds the right-hand side
--- b, and the other rows are the coefficient columns. Work is done on a deep
--- copy, so self is not modified.
--- Variables whose column receives no pivot (free variables) are set to 0,
--- so an under-determined system yields one particular solution.
--- @return Vector|nil The solution vector. It is nil if self has fewer than
---   two rows, if the system is inconsistent, or if the result contains NaN.
function Matrix:column_reduction()
    -- cols[c][r] is entry (r, c) of the augmented matrix.
    local cols = self:deep_copy()
    local numcols = #cols
    -- NOTE(review): deep_copy() already indexes self[1], so an empty Matrix
    -- errors before reaching this check.
    if numcols == 0 then return nil end
    local numrows = #cols[1]
    -- All columns except the last (the right-hand side) are unknowns.
    local vars = numcols - 1
    if vars < 1 then return nil end
    -- Entries with magnitude <= eps are treated as zero.
    local eps = 1e-12

    local rank = 0
    local row = 1
    local pivot_cols = {}

    for col = 1, vars do
        if row > numrows then break end

        -- Partial pivoting: choose the largest-magnitude entry (above eps) in
        -- this column among the rows not yet used as pivots.
        local pivot_row = nil
        local maxval = eps
        for r = row, numrows do
            local value = math.abs(cols[col][r])
            if value > maxval then
                maxval = value
                pivot_row = r
            end
        end

        if pivot_row then
            -- Row swap: exchange the two entries inside every column.
            if pivot_row ~= row then
                for c = 1, numcols do
                    cols[c][row], cols[c][pivot_row] = cols[c][pivot_row], cols[c][row]
                end
            end

            -- Normalize the pivot row so the pivot becomes 1.
            local pivot = cols[col][row]
            for c = 1, numcols do
                cols[c][row] = cols[c][row] / pivot
            end

            -- Eliminate this column from every other row (above and below).
            for r = 1, numrows do
                if r ~= row then
                    local factor = cols[col][r]
                    if math.abs(factor) > eps then
                        for c = 1, numcols do
                            cols[c][r] = cols[c][r] - factor * cols[c][row]
                        end
                        -- Force an exact zero to discard rounding residue.
                        cols[col][r] = 0
                    end
                end
            end

            pivot_cols[#pivot_cols + 1] = col
            rank = rank + 1
            row = row + 1
        end
    end

    -- Consistency check: a row whose coefficients are all ~0 but whose
    -- right-hand side is not means the system has no solution.
    for r = rank + 1, numrows do
        local all_zero = true
        for c = 1, vars do
            if math.abs(cols[c][r]) > eps then
                all_zero = false
                break
            end
        end
        if all_zero and math.abs(cols[numcols][r]) > eps then
            return nil
        end
    end

    -- Free variables default to 0; the k-th pivot row gives the value of the
    -- k-th pivot variable.
    local sol = {}
    for i = 1, vars do
        sol[i] = 0
    end
    for k, pcol in ipairs(pivot_cols) do
        sol[pcol] = cols[numcols][k]
    end

    -- v ~= v is true only for NaN.
    for _, v in ipairs(sol) do
        if v ~= v then return nil end
    end

    return Vector:_new(sol)
end

--- Return the transpose as a new Matrix (self is not modified).
--- Entries are shared, not copied, so Vector entries are the same objects.
--- @return Matrix The transposed matrix
function Matrix:transpose()
    local height = #self
    local width = #self[1]
    local flipped = {}
    for c = 1, width do
        flipped[c] = {}
    end
    for r = 1, height do
        local row = self[r]
        for c = 1, width do
            flipped[c][r] = row[c]
        end
    end
    return Matrix:_new(flipped)
end

--- Inverse of the Matrix (row-vector convention, column Gauss-Jordan).
--- The same elementary column operations are applied to a working copy A of
--- self and to an identity matrix L. Once A has been reduced to the
--- identity, L holds self's inverse.
--- At step `pivot` the pivot is taken from row `pivot`, choosing the
--- largest-magnitude entry among columns pivot..n for numerical stability.
--- @return Matrix|nil The inverse matrix, or nil if not invertible
function Matrix:inverse()
    local n = #self
    assert(n == #self[1], "Matrix must be square.")

    -- Entries with magnitude <= eps are treated as zero.
    local eps = 1e-12

    -- Working copy (plain tables, so self is untouched).
    local A = Matrix:_new(self:to_table())

    local L_table = {}
    for i = 1, n do
        L_table[i] = {}
        for j = 1, n do
            L_table[i][j] = (i == j) and 1 or 0
        end
    end
    local L = Matrix:_new(L_table)

    for pivot = 1, n do
        -- Rows 1..pivot-1 of A are already unit rows. Later column operations
        -- leave them unchanged, because those rows are zero in every column
        -- that gets swapped, scaled or added. So the pivot must be a nonzero
        -- entry of row `pivot` itself. Picking a column with a nonzero entry
        -- in some lower row would divide by A[pivot][pivot], which may be 0.
        -- If row `pivot` has no such entry, it is a combination of the
        -- earlier rows and A is singular.
        local pivot_row = A[pivot]
        local pivot_col = nil
        local best = eps
        for c = pivot, n do
            local magnitude = math.abs(pivot_row[c])
            if magnitude > best then
                best = magnitude
                pivot_col = c
            end
        end

        if not pivot_col then
            return nil
        end

        if pivot_col ~= pivot then
            A:swap_columns(pivot_col, pivot)
            L:swap_columns(pivot_col, pivot)
        end

        -- Scale the pivot column so the pivot becomes 1.
        local p = A[pivot][pivot]
        A:scale_column(pivot, 1 / p)
        L:scale_column(pivot, 1 / p)

        -- Clear the rest of row `pivot` using the pivot column.
        for c = 1, n do
            if c ~= pivot then
                local k = -A[pivot][c]
                if math.abs(k) > eps then
                    A:add_scalar_multiple_of_column(pivot, k, c)
                    L:add_scalar_multiple_of_column(pivot, k, c)
                end
            end
        end
    end

    return L
end

--- Convert simplex vertices (one per row) to origin + edge vectors.
--- Row 1 is kept. Row 1 is subtracted from every later row in all columns
--- except the last, which is left untouched. Works on a deep copy.
--- @return Matrix The basis matrix
function Matrix:hto_basis()
    local height = #self
    local width = #self[1]
    local basis = self:deep_copy()
    local origin = basis[1]
    for r = 2, height do
        local row = basis[r]
        for c = 1, width - 1 do
            row[c] = row[c] - origin[c]
        end
    end
    return basis
end

--- Convert origin + edge vectors (one per row) back to simplex vertices.
--- Row 1 is kept. Row 1 is added to every later row in all columns except
--- the last, which is left untouched. Works on a deep copy.
--- @return Matrix The simplex matrix
function Matrix:hto_simplex()
    local height = #self
    local width = #self[1]
    local simplex = self:deep_copy()
    local origin = simplex[1]
    for r = 2, height do
        local row = simplex[r]
        for c = 1, width - 1 do
            row[c] = row[c] + origin[c]
        end
    end
    return simplex
end

--- Determinant of the submatrix left after deleting row ii and column jj.
--- The submatrix is evaluated with det("row", 1).
--- @param ii number The row to remove
--- @param jj number The column to remove
--- @return number|Vector The minor determinant
function Matrix:minor(ii, jj)
    local sub = {}
    for r = 1, #self do
        if r ~= ii then
            local source = self[r]
            local kept = {}
            for c = 1, #source do
                if c ~= jj then
                    kept[#kept + 1] = source[c]
                end
            end
            sub[#sub + 1] = kept
        end
    end
    return Matrix:_new(sub):det("row", 1)
end

--- Compute the determinant of the matrix by cofactor (Laplace) expansion.
--- Entries are numbers, or Vectors confined to one column. In that case the
--- result is combined with Vector:scale/sub/add and is itself a Vector.
--- 1x1 and 2x2 matrices are evaluated directly, ignoring switch/pos. Larger
--- ones are expanded along the chosen row or column, recursing via minor().
--- @param switch string|nil "row" or "column" (default "row")
--- @param pos number|nil Index of the row/column to expand along (default 1)
--- @return number|Vector|nil The determinant; nil for an unrecognized switch
function Matrix:det(switch, pos)
    local numrows = #self
    local numcols = #self[1]
    assert(numrows == numcols, "You tried to take the determinant of a non-square matrix.")
    -- Defaults: previously a missing switch silently returned nil and a
    -- missing pos raised an arithmetic error on nil.
    switch = switch or "row"
    pos = pos or 1
    if numrows == 1 then
        -- A 1x1 determinant is its only entry.
        return self[1][1]
    end
    if numrows == 2 then
        local a, b = self[1][1], self[1][2]
        local c, d = self[2][1], self[2][2]
        if getmetatable(a) == Vector and getmetatable(c) == Vector then
            -- Column 1 holds Vectors: det = a*d - c*b
            return a:scale(d):sub(c:scale(b))
        end
        if getmetatable(b) == Vector and getmetatable(d) == Vector then
            -- Column 2 holds Vectors: det = a*d - b*c, i.e. d scaled by a
            -- minus b scaled by c.
            return d:scale(a):sub(b:scale(c))
        end
        return a * d - b * c
    end
    if switch == "row" then
        local i = pos
        local det = 0
        for j = 1, numcols do
            local sign = (-1)^(i + j)
            local minor = self:minor(i, j)
            local s = self[i][j]
            det = det + s * minor * sign
        end
        return det
    elseif switch == "column" then
        local j = pos
        local first = self[1][j]
        local det
        if getmetatable(first) == Vector then
            -- Start from a zero Vector as long as the column's entries.
            local zero = {}
            for k = 1, #first do zero[k] = 0 end
            det = Vector:_new(zero)
        else
            det = 0
        end
        for i = 1, numrows do
            local sign = (-1)^(i + j)
            local minor = self:minor(i, j)
            local s = self[i][j]
            if getmetatable(s) == Vector then
                det = det:add(s:scale(minor * sign))
            else
                det = det + s * minor * sign
            end
        end
        return det
    end
    -- Unrecognized switch: no expansion performed.
    return nil
end

--- Divide each row by its own last entry (homogeneous normalization).
--- Every entry except the last is divided by that row's last entry, and the
--- last entry is set to 1. Returns a new Matrix; self is not modified.
--- A zero last entry yields inf/NaN per IEEE division.
--- @return Matrix The homogeneous reciprocation of self
function Matrix:reciprocate_by_homogeneous()
    local normalized = {}
    for r = 1, #self do
        local row = self[r]
        local last = #row
        local w = row[last]
        local out = {}
        for c = 1, last - 1 do
            out[c] = row[c] / w
        end
        out[last] = 1
        normalized[r] = out
    end
    return Matrix:_new(normalized)
end

--- Matrix product self * other, returned as a new Matrix.
--- The inner dimension is self's column count; rows of other beyond it are
--- not read. If other is a Vector, it is wrapped as a one-row Matrix and the
--- first row of the product is returned as a Vector. Because of the inner
--- dimension above, that path only succeeds when self has a single column.
--- NOTE(review): confirm the intended semantics of the Vector path.
--- @param other Matrix|Vector The right-hand operand
--- @return Matrix|Vector The product of self and other
function Matrix:multiply(other)
    local vector_input = getmetatable(other) == Vector
    if vector_input then
        other = Matrix:_new{other:to_table()}
    end
    local height = #self
    local inner = #self[1]
    local width = #other[1]
    local result = {}
    for r = 1, height do
        local lhs = self[r]
        local out = {}
        for c = 1, width do
            local acc = 0
            for k = 1, inner do
                acc = acc + lhs[k] * other[k][c]
            end
            out[c] = acc
        end
        result[r] = out
    end
    if vector_input then
        return Vector:_new(result[1])
    end
    return Matrix:_new(result)
end

--- Axis-aligned 3D bounding box of the rows, read as points whose first
--- three entries are x, y, z. An empty Matrix gives min = +inf, max = -inf.
--- @return table A table with 'min' and 'max' Vector objects (4th entry 1)
function Matrix:get_bbox3()
    local lo = {math.huge, math.huge, math.huge}
    local hi = {-math.huge, -math.huge, -math.huge}
    for _, point in ipairs(self) do
        for axis = 1, 3 do
            local value = point[axis]
            if value < lo[axis] then lo[axis] = value end
            if value > hi[axis] then hi[axis] = value end
        end
    end
    return {
        min = Vector:_new{lo[1], lo[2], lo[3], 1},
        max = Vector:_new{hi[1], hi[2], hi[3], 1},
    }
end

--- Axis-aligned 2D bounding box of the rows, using only the first two
--- entries (x, y). The returned Vectors have z = 0 and a 4th entry of 1.
--- @return table A table with 'min' and 'max' Vector objects
function Matrix:get_bbox2()
    local lo = {math.huge, math.huge}
    local hi = {-math.huge, -math.huge}
    for _, point in ipairs(self) do
        for axis = 1, 2 do
            local value = point[axis]
            if value < lo[axis] then lo[axis] = value end
            if value > hi[axis] then hi[axis] = value end
        end
    end
    return {
        min = Vector:_new{lo[1], lo[2], 0, 1},
        max = Vector:_new{hi[1], hi[2], 0, 1},
    }
end

--- Test whether the 3D bounding boxes of self and other intersect.
--- Boxes that merely touch count as overlapping. Precomputed boxes are used
--- when supplied; otherwise they are computed with get_bbox3().
--- @param other Matrix The other Matrix
--- @param a_bbox table|nil Pre-computed bbox for self
--- @param b_bbox table|nil Pre-computed bbox for other
--- @return boolean True if they overlap, false otherwise
function Matrix:bboxes_overlap3(other, a_bbox, b_bbox)
    local a = a_bbox or self:get_bbox3()
    local b = b_bbox or other:get_bbox3()
    for axis = 1, 3 do
        -- A gap on any single axis separates the boxes.
        if a.max[axis] < b.min[axis] or a.min[axis] > b.max[axis] then
            return false
        end
    end
    return true
end

--- Test whether the 2D (x, y) bounding boxes of self and other intersect.
--- Boxes that merely touch count as overlapping. Precomputed boxes are used
--- when supplied; otherwise they are computed with get_bbox2().
--- @param other Matrix The other Matrix
--- @param a_bbox table|nil Pre-computed bbox for self
--- @param b_bbox table|nil Pre-computed bbox for other
--- @return boolean True if they overlap, false otherwise
function Matrix:bboxes_overlap2(other, a_bbox, b_bbox)
    local a = a_bbox or self:get_bbox2()
    local b = b_bbox or other:get_bbox2()
    for axis = 1, 2 do
        -- A gap on either axis separates the boxes.
        if a.max[axis] < b.min[axis] or a.min[axis] > b.max[axis] then
            return false
        end
    end
    return true
end

--- Compute the homogeneous centroid of the Matrix points.
--- The rows are summed with Vector:hadd and the result is hscale'd by
--- 1/#self.
--- Each row is copied into a fresh table before it is wrapped as a Vector.
--- Vector:_new presumably adopts the table it is given, as Matrix:_new does
--- (deep_copy copies before wrapping for the same reason). Wrapping self's
--- rows directly would attach Vector's metatable to the Matrix's own row
--- tables and expose them to any in-place Vector arithmetic.
--- @return Vector The homogeneous centroid
function Matrix:hcentroid()
    local count = #self
    local centroid = Vector:_new{table.unpack(self[1])}
    for i = 2, count do
        centroid = centroid:hadd(Vector:_new{table.unpack(self[i])})
    end
    return centroid:hscale(1 / count)
end

--- Sort the Matrix points by angle around their centroid.
--- Angles are measured in a frame built from the first three points:
--- u = normalize(point2 - point1), and v is derived from u and point3 through
--- two hhypercross calls (presumably the in-plane direction orthogonal to u).
--- Sorting is done on a deep copy, so self is not reordered.
--- @return table A plain Lua array (not a Matrix) of the deep copy's rows, in
---   increasing angle order
function Matrix:hcentroid_sort()
    local num = #self
    local I = self:deep_copy()
    -- Average of the rows.
    -- NOTE(review): the accumulator starts at {0,0,0,1} and uses plain
    -- add/scale, so its 4th component is (1 + sum of w) / num rather than
    -- the mean w. That is harmless only if hsub/hinner below ignore w.
    -- Confirm.
    local sum = Vector:_new{0,0,0,1}
    for i = 1, num do
        sum = sum:add(Vector:_new(I[i]))
    end
    local centroid = sum:scale(1/num)
    local P = Vector:_new(I[1])
    local u = Vector:_new(I[2]):hsub(P):hnormalize()
    local v = Vector:_new(I[3]):hsub(P)
    local normal = u:hhypercross(v):hnormalize()
    v = normal:hhypercross(u):hnormalize()
    -- math.atan2 exists in Lua 5.1/5.2, but in 5.3+ only in builds with
    -- LUA_COMPAT_MATHLIB. From 5.3 on, math.atan(y, x) is the two-argument form.
    local atan2 = math.atan2 or math.atan
    local angles = {}
    for i, p in ipairs(I) do
        local rel = Vector:_new(p):hsub(centroid)
        local x = rel:hinner(u)
        local y = rel:hinner(v)
        angles[i] = {angle = atan2(y, x), index = i}
    end
    table.sort(angles, function(a, b) return a.angle < b.angle end)
    local sorted = {}
    for i, a in ipairs(angles) do
        sorted[i] = I[a.index]
    end
    return sorted
end

--- Report whether every entry is a plain Lua number (no Vectors etc.).
--- @return boolean True if all elements are numeric, false otherwise
function Matrix:is_numeric()
    local all_numbers = true
    for _, row in ipairs(self) do
        for _, entry in ipairs(row) do
            if type(entry) ~= "number" then
                all_numbers = false
                break
            end
        end
        if not all_numbers then
            break
        end
    end
    return all_numbers
end

--- Homogeneous 4x4 rotation by theta radians about an axis through the origin.
--- The axis is normalized with hnormalize(). The upper-left 3x3 block is the
--- transpose of the usual column-vector Rodrigues matrix, matching the
--- row-vector (point * M) layout that translate() uses.
--- @param axis Vector The axis direction
--- @param theta number The rotation angle in radians
--- @return Matrix The rotation matrix
function Matrix.axis_angle(axis, theta)
    assert(getmetatable(axis) == Vector, "axis must be a Vector.")
    local unit = axis:hnormalize()
    local x, y, z = unit[1], unit[2], unit[3]
    local c, s = math.cos(theta), math.sin(theta)
    local t = 1 - c
    local sx, sy, sz = s * x, s * y, s * z
    local txy, txz, tyz = t * x * y, t * x * z, t * y * z
    return Matrix:_new{
        {t * x * x + c, txy + sz,      txz - sy,      0},
        {txy - sz,      t * y * y + c, tyz + sx,      0},
        {txz + sy,      tyz - sx,      t * z * z + c, 0},
        {0, 0, 0, 1},
    }
end

--- Homogeneous 4x4 shear matrix; any omitted coefficient defaults to 0.
--- With points as row vectors (point * M), for example, x' = x + kxy*y + kxz*z.
--- @param kxy number|nil x shear from y
--- @param kxz number|nil x shear from z
--- @param kyx number|nil y shear from x
--- @param kyz number|nil y shear from z
--- @param kzx number|nil z shear from x
--- @param kzy number|nil z shear from y
--- @return Matrix The shear matrix
function Matrix.shear(kxy, kxz, kyx, kyz, kzx, kzy)
    return Matrix:_new{
        {1,        kyx or 0, kzx or 0, 0},
        {kxy or 0, 1,        kzy or 0, 0},
        {kxz or 0, kyz or 0, 1,        0},
        {0, 0, 0, 1},
    }
end

--- Homogeneous 4x4 reflection through a line through the origin.
--- With u = hnormalize(axis), the upper-left block is 2*u*u^T - I, which
--- keeps u fixed and negates directions perpendicular to it.
--- @param axis Vector The axis direction
--- @return Matrix The reflection matrix
function Matrix.reflect_axis(axis)
    assert(getmetatable(axis) == Vector, "axis must be a Vector.")
    local unit = axis:hnormalize()
    local x, y, z = unit[1], unit[2], unit[3]
    local xy, xz, yz = 2 * x * y, 2 * x * z, 2 * y * z
    return Matrix:_new{
        {2 * x * x - 1, xy,            xz,            0},
        {xy,            2 * y * y - 1, yz,            0},
        {xz,            yz,            2 * z * z - 1, 0},
        {0, 0, 0, 1},
    }
end



--- Homogeneous 4x4 translation by delta's first three components.
--- The offset sits in the bottom row (row-vector convention, point * M).
--- @param delta Vector The translation vector
--- @return Matrix The translation matrix
function Matrix.translate(delta)
    assert(getmetatable(delta) == Vector, "delta must be a Vector.")
    local dx, dy, dz = delta[1], delta[2], delta[3]
    return Matrix:_new{
        {1,  0,  0,  0},
        {0,  1,  0,  0},
        {0,  0,  1,  0},
        {dx, dy, dz, 1},
    }
end

--- Homogeneous 4x4 per-axis scaling by scale's first three components.
--- @param scale Vector The scaling vector
--- @return Matrix The scaling matrix
function Matrix.scale_axis(scale)
    assert(getmetatable(scale) == Vector, "scale must be a Vector.")
    local sx, sy, sz = scale[1], scale[2], scale[3]
    return Matrix:_new{
        {sx, 0,  0,  0},
        {0,  sy, 0,  0},
        {0,  0,  sz, 0},
        {0,  0,  0,  1},
    }
end


--- ZYZ Euler rotation: Rz(alpha) * Ry(beta) * Rz(gamma), where each factor
--- is built with axis_angle().
--- @param angles Vector The Euler angles (alpha, beta, gamma)
--- @return Matrix The rotation matrix
function Matrix.zyzrotation(angles)
    assert(getmetatable(angles) == Vector, "angles must be a Vector.")
    local first = Matrix.axis_angle(Vector:_new{0, 0, 1, 1}, angles[1])
    local second = Matrix.axis_angle(Vector:_new{0, 1, 0, 1}, angles[2])
    local third = Matrix.axis_angle(Vector:_new{0, 0, 1, 1}, angles[3])
    local combined = first:multiply(second)
    return combined:multiply(third)
end

-- Compatibility helpers retained for legacy documents.

--- Legacy wrapper: rotation by theta radians about the x axis.
function Matrix.xrotation3(theta)
    local x_axis = Vector:_new{1, 0, 0, 1}
    return Matrix.axis_angle(x_axis, theta)
end

--- Legacy wrapper: rotation by theta radians about the y axis.
function Matrix.yrotation3(theta)
    local y_axis = Vector:_new{0, 1, 0, 1}
    return Matrix.axis_angle(y_axis, theta)
end

--- Legacy wrapper: rotation by theta radians about the z axis.
function Matrix.zrotation3(theta)
    local z_axis = Vector:_new{0, 0, 1, 1}
    return Matrix.axis_angle(z_axis, theta)
end

--- Legacy wrapper: translation by scalar offsets (dx, dy, dz).
function Matrix.translate3(dx, dy, dz)
    local offset = Vector:_new{dx, dy, dz, 1}
    return Matrix.translate(offset)
end

--- Legacy wrapper: per-axis scaling by scalar factors (sx, sy, sz).
function Matrix.scale3(sx, sy, sz)
    local factors = Vector:_new{sx, sy, sz, 1}
    return Matrix.scale_axis(factors)
end

--- Legacy wrapper: ZYZ Euler rotation from three scalar angles.
function Matrix.zyzrotation3(alpha, beta, gamma)
    local euler = Vector:_new{alpha, beta, gamma}
    return Matrix.zyzrotation(euler)
end

--- The 4x4 identity matrix (a fresh Matrix on every call).
--- @return Matrix The identity matrix
function Matrix.identity()
    local rows = {}
    for i = 1, 4 do
        local row = {0, 0, 0, 0}
        row[i] = 1
        rows[i] = row
    end
    return Matrix:_new(rows)
end

--- Legacy alias for Matrix.identity() (the 4x4 homogeneous identity).
function Matrix.identity3()
    local eye = Matrix.identity()
    return eye
end

--- Homogeneous 4x4 perspective matrix along an arbitrary axis.
--- axis's first three components go into the last column. With points as
--- row vectors (point * M), the new w is therefore w + x*a1 + y*a2 + z*a3.
--- @param axis Vector The perspective axis and strength
--- @return Matrix The perspective matrix
function Matrix.perspective(axis)
    assert(getmetatable(axis) == Vector, "axis must be a Vector.")
    local ax, ay, az = axis[1], axis[2], axis[3]
    return Matrix:_new{
        {1, 0, 0, ax},
        {0, 1, 0, ay},
        {0, 0, 1, az},
        {0, 0, 0, 1},
    }
end

--- Conjugate a transformation so that it fixes a given point.
--- Builds translate(-point) * transformation * translate(point). Under the
--- row-vector convention this moves point to the origin, applies the
--- transformation, then moves back.
--- @param point Vector The fixed point
--- @param transformation Matrix The transformation to apply
--- @return Matrix The composed transformation
function Matrix.transform_about(point, transformation)
    assert(getmetatable(point) == Vector, "point must be a Vector.")
    local to_origin = Matrix.translate(Vector:_new{-point[1], -point[2], -point[3], 1})
    local from_origin = Matrix.translate(point)
    local centered = to_origin:multiply(transformation)
    return centered:multiply(from_origin)
end

--- Set the Vector class reference (breaks circular dependency).
--- The module-local `Vector` is nil until this is called, so it must run
--- before any Matrix function that tests for or constructs Vectors.
--- @param V table The Vector class
function Matrix._set_Vector(V)
    Vector = V
end

-- Module export: the Matrix class table.
return Matrix
