// Copyright 2020-2025 Consensys Software Inc.
// Licensed under the Apache License, Version 2.0. See the LICENSE file for details.

// Code generated by consensys/gnark-crypto DO NOT EDIT

package poseidon2

import (
	"errors"
	"fmt"

	"golang.org/x/crypto/sha3"

	"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
)

// ErrInvalidSizebuffer is returned when the number of elements handed to the
// permutation differs from the width it was configured with.
var ErrInvalidSizebuffer = errors.New("the size of the input should match the size of the hash buffer")

// Background material for this implementation:
//   - reference implementation: https://github.com/HorizenLabs/poseidon2/blob/main/plain_implementations/src/poseidon2/poseidon2.rs
//   - specifications: https://github.com/argumentcomputer/neptune/blob/main/spec/poseidon_spec.pdf
//   - original paper: https://eprint.iacr.org/2023/323.pdf
//
// d is the degree of the sBox, i.e. the sBox maps x to x^d.
const d = 17

// DegreeSBox returns the degree d of the sBox function x -> x^d used in the
// Poseidon2 permutation. For this package the value is the constant 17.
func DegreeSBox() int {
	return d
}

// Parameters describes an instance of the Poseidon2 permutation. Use
// [NewParameters] or [NewParametersWithSeed] to create a set of parameters
// whose round keys are precomputed deterministically.
type Parameters struct {
	// Width is the number of field elements the permutation operates on:
	// len(preimage)+len(digest)=len(preimage)+ceil(log(2*<security_level>/r))
	Width int

	// NbFullRounds is the number of full rounds (even number). The first
	// NbFullRounds/2 of them run before the partial rounds, the remaining
	// ones after.
	NbFullRounds int

	// NbPartialRounds is the number of partial rounds, in which the sBox is
	// applied to the first element only.
	NbPartialRounds int

	// RoundKeys holds one slice of keys per round, derived from the parameter
	// seed and curve ID. Full rounds carry Width keys, partial rounds carry a
	// single key.
	RoundKeys [][]fr.Element
}

// NewParameters builds the parameters of a Poseidon2 permutation with the
// given width and round counts. The round keys are derived deterministically,
// using the string representation of the parameters (which also encodes the
// curve, see [Parameters.String]) as seed.
func NewParameters(width, nbFullRounds, nbPartialRounds int) *Parameters {
	params := &Parameters{
		Width:           width,
		NbFullRounds:    nbFullRounds,
		NbPartialRounds: nbPartialRounds,
	}
	params.initRC(params.String())
	return params
}

// NewParametersWithSeed builds the parameters of a Poseidon2 permutation with
// the given width and round counts. Unlike [NewParameters], the round keys are
// derived deterministically from the caller-supplied seed.
func NewParametersWithSeed(width, nbFullRounds, nbPartialRounds int, seed string) *Parameters {
	params := &Parameters{
		Width:           width,
		NbFullRounds:    nbFullRounds,
		NbPartialRounds: nbPartialRounds,
	}
	params.initRC(seed)
	return params
}

// String renders the parameters as a tag made of the curve name, the width
// (t), the number of full (rF) and partial (rP) rounds and the sBox degree
// (d). Two parameter sets that differ in any of those yield different strings.
func (p *Parameters) String() string {
	const layout = "Poseidon2-BLS12_377[t=%d,rF=%d,rP=%d,d=%d]"
	return fmt.Sprintf(layout, p.Width, p.NbFullRounds, p.NbPartialRounds, d)
}

// initRC initializes the round keys from seed and stores them in p.RoundKeys.
//
// The keys are read off a Keccak-256 hash chain: the seed is hashed once, and
// every subsequent key is the digest of the previous digest. Keys are drawn in
// round order. Full rounds (the first NbFullRounds/2 and the last ones) get
// Width keys each; partial rounds get a single key, since only one entry is
// non zero for the internal rounds, cf https://eprint.iacr.org/2023/323.pdf
// page 9.
func (p *Parameters) initRC(seed string) {
	hasher := sha3.NewLegacyKeccak256()

	// pre hash the seed before use; the hasher is left primed with the digest
	_, _ = hasher.Write([]byte(seed))
	digest := hasher.Sum(nil)
	hasher.Reset()
	_, _ = hasher.Write(digest)

	// nextKeys draws n consecutive keys from the hash chain, advancing the
	// chain by one link per key.
	nextKeys := func(n int) []fr.Element {
		keys := make([]fr.Element, n)
		for j := range keys {
			digest = hasher.Sum(nil)
			keys[j].SetBytes(digest)
			hasher.Reset()
			_, _ = hasher.Write(digest)
		}
		return keys
	}

	halfFull := p.NbFullRounds / 2
	partialEnd := p.NbPartialRounds + halfFull
	nbRounds := p.NbPartialRounds + p.NbFullRounds

	roundKeys := make([][]fr.Element, p.NbFullRounds+p.NbPartialRounds)

	// first half of the full rounds
	for i := 0; i < halfFull; i++ {
		roundKeys[i] = nextKeys(p.Width)
	}
	// partial rounds
	for i := halfFull; i < partialEnd; i++ {
		roundKeys[i] = nextKeys(1)
	}
	// remaining full rounds
	for i := partialEnd; i < nbRounds; i++ {
		roundKeys[i] = nextKeys(p.Width)
	}

	p.RoundKeys = roundKeys
}

// Permutation is an instance of the Poseidon2 permutation. It holds the
// parameters of the instance (width, round counts, round keys); its methods
// operate in place on a buffer supplied by the caller.
type Permutation struct {
	// params parameters describing the instance
	params *Parameters
}

// NewPermutation returns a new Poseidon2 permutation instance of width t with
// rf full rounds and rp partial rounds. The round keys are derived from the
// parameters themselves, see [NewParameters].
//
// It panics if t is neither 2 nor 3.
func NewPermutation(t, rf, rp int) *Permutation {
	switch t {
	case 2, 3:
		return &Permutation{params: NewParameters(t, rf, rp)}
	default:
		panic("only t=2,3 is supported")
	}
}

// NewPermutationWithSeed returns a new Poseidon2 permutation instance of width
// t with rf full rounds and rp partial rounds, whose round keys are derived
// from the given seed, see [NewParametersWithSeed].
//
// It panics if t is neither 2 nor 3.
func NewPermutationWithSeed(t, rf, rp int, seed string) *Permutation {
	switch t {
	case 2, 3:
		return &Permutation{params: NewParametersWithSeed(t, rf, rp, seed)}
	default:
		panic("only t=2,3 is supported")
	}
}

// sBox replaces input[index] by its 17th power, in place.
func (h *Permutation) sBox(index int, input []fr.Element) {
	x := &input[index]
	base := *x // keep the original value for the final multiplication

	// four successive squarings: x -> x^2 -> x^4 -> x^8 -> x^16
	for i := 0; i < 4; i++ {
		x.Square(x)
	}

	// x^16 * x = x^17
	x.Mul(x, &base)
}

// matMulExternalInPlace multiplies input in place by the external matrix:
// circ(2,1) when Width=2 and circ(2,1,1) when Width=3. In both cases each
// entry ends up as itself plus the sum of all entries. Any other width panics.
// see https://eprint.iacr.org/2023/323.pdf page 15, case T=2,3
func (h *Permutation) matMulExternalInPlace(input []fr.Element) {
	var sum fr.Element

	switch h.params.Width {
	case 2:
		sum.Add(&input[0], &input[1])
		input[0].Add(&sum, &input[0])
		input[1].Add(&sum, &input[1])
	case 3:
		sum.Add(&input[0], &input[1])
		sum.Add(&sum, &input[2])
		input[0].Add(&sum, &input[0])
		input[1].Add(&sum, &input[1])
		input[2].Add(&sum, &input[2])
	default:
		panic("only Width=2,3 are supported")
	}
}

// matMulInternalInPlace multiplies input in place by the internal matrix,
// which is respectively [[2,1][1,3]] for Width=2 and [[2,1,1][1,2,1][1,1,3]]
// for Width=3: each entry gets the sum of all entries added to it, and the
// last entry is doubled beforehand. Any other width panics.
func (h *Permutation) matMulInternalInPlace(input []fr.Element) {
	var total fr.Element

	switch h.params.Width {
	case 2:
		a, b := &input[0], &input[1]
		total.Add(a, b)
		a.Add(a, &total)
		b.Double(b)
		b.Add(b, &total)
	case 3:
		a, b, c := &input[0], &input[1], &input[2]
		total.Add(a, b)
		total.Add(&total, c)
		a.Add(a, &total)
		b.Add(b, &total)
		c.Double(c)
		c.Add(c, &total)
	default:
		panic("only T=2,3 is supported")
	}
}

// addRoundKeyInPlace adds the keys of the round-th round to the buffer, entry
// by entry. Only as many entries as the round has keys are touched, so a round
// carrying a single key only modifies input[0].
func (h *Permutation) addRoundKeyInPlace(round int, input []fr.Element) {
	keys := h.params.RoundKeys[round]
	for i := range keys {
		input[i].Add(&input[i], &keys[i])
	}
}

// BlockSize returns fr.Bytes, the number of bytes used to encode one
// fr.Element.
func (h *Permutation) BlockSize() int {
	return fr.Bytes
}

// Permutation applies the permutation on input, and stores the result in input.
//
// input must contain exactly Width elements, otherwise ErrInvalidSizebuffer is
// returned and input is left untouched. The schedule is: one external matrix
// multiplication, NbFullRounds/2 full rounds, NbPartialRounds partial rounds,
// then the remaining full rounds.
func (h *Permutation) Permutation(input []fr.Element) error {
	if len(input) != h.params.Width {
		return ErrInvalidSizebuffer
	}

	halfFull := h.params.NbFullRounds / 2
	partialEnd := halfFull + h.params.NbPartialRounds
	nbRounds := h.params.NbFullRounds + h.params.NbPartialRounds

	// one full round = matMulExternal(sBox_Full(addRoundKey))
	fullRound := func(round int) {
		h.addRoundKeyInPlace(round, input)
		for j := range input {
			h.sBox(j, input)
		}
		h.matMulExternalInPlace(input)
	}

	// one partial round = matMulInternal(sBox_sparse(addRoundKey)), the sBox
	// being applied to the first entry only
	partialRound := func(round int) {
		h.addRoundKeyInPlace(round, input)
		h.sBox(0, input)
		h.matMulInternalInPlace(input)
	}

	// external matrix multiplication, cf https://eprint.iacr.org/2023/323.pdf page 14 (part 6)
	h.matMulExternalInPlace(input)

	for round := 0; round < halfFull; round++ {
		fullRound(round)
	}
	for round := halfFull; round < partialEnd; round++ {
		partialRound(round)
	}
	for round := partialEnd; round < nbRounds; round++ {
		fullRound(round)
	}

	return nil
}

// Compress uses the permutation to compress the left and right input in a collision resistant manner.
//
// left and right are decoded as field elements with SetBytesCanonical, the
// pair is permuted, and the result is the second permuted element plus the
// decoded right input (feed forward), in marshalled form.
//
// Returns an error if the permutation instance is not initialized with a width
// of 2, or if left or right is rejected by SetBytesCanonical.
func (h *Permutation) Compress(left []byte, right []byte) ([]byte, error) {
	if h.params.Width != 2 {
		return nil, errors.New("need a 2-1 function")
	}

	var l, r fr.Element
	if err := l.SetBytesCanonical(left); err != nil {
		return nil, err
	}
	if err := r.SetBytesCanonical(right); err != nil {
		return nil, err
	}

	state := [2]fr.Element{l, r}
	if err := h.Permutation(state[:]); err != nil {
		return nil, err
	}

	// feed forward: add the untouched right input to the permuted one
	var digest fr.Element
	digest.Add(&r, &state[1])
	return digest.Marshal(), nil
}
