// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0

// ----------------------------------------------------------------------------
// Montgomery square, z := (x^2 / 2^256) mod p_sm2
// Input x[4]; output z[4]
//
//    extern void bignum_montsqr_sm2_alt
//     (uint64_t z[static 4], uint64_t x[static 4]);
//
// Does z := (x^2 / 2^256) mod p_sm2, assuming x^2 <= 2^256 * p_sm2, which is
// guaranteed in particular if x < p_sm2 initially (the "intended" case).
//
// Standard x86-64 ABI: RDI = z, RSI = x
// Microsoft x64 ABI:   RCX = z, RDX = x
// ----------------------------------------------------------------------------

#include "_internal_s2n_bignum.h"


        S2N_BN_SYM_VISIBILITY_DIRECTIVE(bignum_montsqr_sm2_alt)
        S2N_BN_SYM_PRIVACY_DIRECTIVE(bignum_montsqr_sm2_alt)
        .text

// Symbolic names for the two pointer arguments. These are the registers
// that carry them under the standard x86-64 ABI; under the Microsoft x64
// ABI the function entry code copies %rcx and %rdx into them first.

#define z %rdi
#define x %rsi

// Add %rbx * m into a register-pair (high,low) maintaining consistent
// carry-catching with carry (negated, as bitmask) and using %rax and %rdx
// as temporaries
//
// Step by step:
//   - m is loaded into %rax and multiplied by %rbx, leaving the 128-bit
//     product in %rdx:%rax.
//   - "carry" holds 0 or all-ones (-1) as left by a previous "sbbq
//     carry, carry", so subtracting it from %rdx adds the pending
//     carry-in (0 or 1) to the high product word.
//   - The product is then added into (high,low) with an add/adc pair.
//   - The final sbbq turns the carry-out of that addition back into a
//     0 / all-ones mask in "carry" for the next macro in the chain.
//
// Clobbers %rax, %rdx and the flags. The caller must have loaded the
// common multiplier into %rbx beforehand.

#define mulpadd(carry,high,low,m)       \
        movq    m, %rax ;                 \
        mulq    %rbx;                    \
        subq    carry, %rdx ;             \
        addq    %rax, low ;               \
        adcq    %rdx, high ;              \
        sbbq    carry, carry

// Initial version assuming no carry-in
//
// Same as mulpadd except that there is no "subq carry, %rdx" step, so
// the incoming contents of "carry" are ignored. On exit "carry" is
// overwritten with the 0 / all-ones mask for the carry-out of the
// add/adc pair, which starts a chain that mulpadd and mulpade continue.
// Clobbers %rax, %rdx and the flags; the multiplier is taken from %rbx.

#define mulpadi(carry,high,low,m)       \
        movq    m, %rax ;                 \
        mulq    %rbx;                    \
        addq    %rax, low ;               \
        adcq    %rdx, high ;              \
        sbbq    carry, carry

// End version not catching the top carry-out
//
// Same as mulpadd (the carry-in mask in "carry" is folded into the high
// product word) except that the trailing "sbbq carry, carry" is omitted.
// The carry-out of the final adcq is therefore left in the CPU carry
// flag, and "carry" itself is not updated; the caller is expected to
// consume the flag directly with a following adc. Clobbers %rax, %rdx
// and the flags; the multiplier is taken from %rbx.

#define mulpade(carry,high,low,m)       \
        movq    m, %rax ;                 \
        mulq    %rbx;                    \
        subq    carry, %rdx ;             \
        addq    %rax, low ;               \
        adcq    %rdx, high

// ---------------------------------------------------------------------------
// Core one-step "short" Montgomery reduction macro. Takes input in
// [d3;d2;d1;d0] and returns result in [d0;d3;d2;d1], adding to the
// existing contents of [d3;d2;d1], and using %rax, %rcx, %rdx and %rbx
// as temporaries.
//
// How the instruction sequence works, writing w for the incoming d0:
//   - (%rcx:%rax) is set to w * 2^32 via the shift-left / shift-right
//     pair, and a copy of that pair is kept in (%rbx:%rdx).
//   - w is subtracted from (%rcx:%rax), giving w * (2^32 - 1).
//   - The four-word value [%rbx;%rdx;%rcx;%rax], which equals
//     w * (2^160 + 2^32 - 1), is subtracted from [d0;d3;d2;d1] with a
//     single borrow chain. The top word d0 still holds w at that point,
//     so it contributes w * 2^192.
//
// Net effect: [d0;d3;d2;d1] := [d3;d2;d1] + w * (2^192 - 2^160 - 2^32 + 1),
// which is ([d3;d2;d1;w] + w * p_sm2) / 2^64 for
// p_sm2 = 2^256 - 2^224 - 2^96 + 2^64 - 1. The multiplier is simply w
// because the lowest word of p_sm2 is all ones, i.e. p_sm2 == -1 mod 2^64.
//
// The final borrow out of the chain is not captured by the macro.
// ---------------------------------------------------------------------------

#define montreds(d3,d2,d1,d0)                                               \
        movq    d0, %rax ;                                                    \
        shlq    $32, %rax ;                                                    \
        movq    d0, %rcx ;                                                    \
        shrq    $32, %rcx ;                                                    \
        movq    %rax, %rdx ;                                                   \
        movq    %rcx, %rbx ;                                                   \
        subq    d0, %rax ;                                                    \
        sbbq    $0, %rcx ;                                                     \
        subq    %rax, d1 ;                                                    \
        sbbq    %rcx, d2 ;                                                    \
        sbbq    %rdx, d3 ;                                                    \
        sbbq    %rbx, d0

// ----------------------------------------------------------------------------
// Function body. Outline:
//   1. Form the six cross products x[i] * x[j] (i < j) in [%r14;...;%r9].
//   2. Double that window and add the four squares x[i]^2, giving the full
//      512-bit square in [%r15;%r14;%r13;%r12;%r11;%r10;%r9;%r8].
//   3. Apply four one-word Montgomery steps to the low half.
//   4. Add the reduced low half to the high half and perform one
//      conditional subtraction of p_sm2.
//
// The body is straight-line code: there are no jumps, and every load is
// from a fixed offset off x. All of %rax, %rcx, %rdx, %r8-%r11 are used as
// scratch; %rbx and %r12-%r15 are saved and restored around the body.
// ----------------------------------------------------------------------------

S2N_BN_SYMBOL(bignum_montsqr_sm2_alt):
        _CET_ENDBR

// Under the Microsoft x64 ABI the arguments arrive in %rcx and %rdx, so
// save %rdi and %rsi (restored before returning) and move the arguments
// into the registers the rest of the code expects.

#if WINDOWS_ABI
        pushq   %rdi
        pushq   %rsi
        movq    %rcx, %rdi
        movq    %rdx, %rsi
#endif

// Save more registers to play with

        pushq   %rbx
        pushq   %r12
        pushq   %r13
        pushq   %r14
        pushq   %r15

// Compute [%r15;%r8] = [00] which we use later, but mainly
// set up an initial window [%r14;...;%r9] = [23;03;01]
//
// Here [ij] denotes the 128-bit product x[i] * x[j]. %rbx first holds
// x[0] as the common multiplier for 01 and 03, while %r13 temporarily
// holds x[3] so that 23 can be formed once %rbx has been reloaded with
// x[2]. On exit from this block %rbx = x[2], ready for the mulpad chain.

        movq    (x), %rax
        movq    %rax, %rbx
        mulq    %rax
        movq    %rax, %r8
        movq    %rdx, %r15
        movq    8(x), %rax
        mulq    %rbx
        movq    %rax, %r9
        movq    %rdx, %r10
        movq    24(x), %rax
        movq    %rax, %r13
        mulq    %rbx
        movq    %rax, %r11
        movq    %rdx, %r12
        movq    16(x), %rax
        movq    %rax, %rbx
        mulq    %r13
        movq    %rax, %r13
        movq    %rdx, %r14

// Chain in the addition of 02 + 12 + 13 to that window (no carry-out possible)
// This gives all the "heterogeneous" terms of the squaring ready to double
//
// With %rbx = x[2] the first two macros add 02 at word offset 2 and 12 at
// word offset 3; %rbx is then switched to x[3] so the last macro adds 13
// at word offset 4. %rcx carries the 0 / all-ones carry mask between the
// macros, and the carry flag left by mulpade is absorbed into %r14.

        mulpadi(%rcx,%r11,%r10,(x))
        mulpadd(%rcx,%r12,%r11,8(x))
        movq    24(x), %rbx
        mulpade(%rcx,%r13,%r12,8(x))
        adcq    $0, %r14

// Double the window [%r14;...;%r9], catching top carry in %rcx
//
// %rcx is zeroed first (which also clears the carry flag), so the final
// "adcq %rcx, %rcx" leaves exactly the bit shifted out of %r14 in %rcx.

        xorl    %ecx, %ecx
        addq    %r9, %r9
        adcq    %r10, %r10
        adcq    %r11, %r11
        adcq    %r12, %r12
        adcq    %r13, %r13
        adcq    %r14, %r14
        adcq    %rcx, %rcx

// Add to the 00 + 11 + 22 + 33 terms
//
// The low word of 00 is already in %r8 and its high word (%r15) is added
// at word offset 1. Each mulq overwrites the flags, so the running carry
// is parked in %r15 as a 0 / all-ones mask by "sbbq %r15, %r15" before
// the multiply and turned back into the carry flag afterwards by
// "negq %r15" (which sets carry exactly when %r15 is nonzero). The last
// adcq folds both that carry and the doubling carry in %rcx into the high
// word of 33, which becomes the top word %r15 of the 8-word square.

        movq    8(x), %rax
        mulq    %rax
        addq    %r15, %r9
        adcq    %rax, %r10
        adcq    %rdx, %r11
        sbbq    %r15, %r15
        movq    16(x), %rax
        mulq    %rax
        negq    %r15
        adcq    %rax, %r12
        adcq    %rdx, %r13
        sbbq    %r15, %r15
        movq    24(x), %rax
        mulq    %rax
        negq    %r15
        adcq    %rax, %r14
        adcq    %rcx, %rdx
        movq    %rdx, %r15

// Squaring complete. Perform 4 Montgomery steps to rotate the lower half
//
// Each step consumes the current lowest word and writes a new top word
// into the same register, so the register roles rotate by one position
// per step. After four steps the reduced low half is back in the order
// [%r11;%r10;%r9;%r8]. The macro clobbers %rax, %rcx, %rdx and %rbx, none
// of which hold live values at this point.

        montreds(%r11,%r10,%r9,%r8)
        montreds(%r8,%r11,%r10,%r9)
        montreds(%r9,%r8,%r11,%r10)
        montreds(%r10,%r9,%r8,%r11)

// Add high and low parts, catching carry in %rax
//
// The sum [%rax;%r15;%r14;%r13;%r12] is the pre-reduced result r, with
// %rax (0 or 1) as its fifth word.

        xorl    %eax, %eax
        addq    %r8, %r12
        adcq    %r9, %r13
        adcq    %r10, %r14
        adcq    %r11, %r15
        adcq    %rax, %rax

// Load [%r8;%r11;%rbx;%rdx;%rcx] = 2^320 - p_sm2 then do
// [%r8;%r11;%rbx;%rdx;%rcx] = [%rax;%r15;%r14;%r13;%r12] + (2^320 - p_sm2)
//
// The constant words are 1, 0x00000000FFFFFFFF, 0, 2^32 and all-ones
// (low to high). The last two are derived with leaq from the registers
// holding the second and third constants; each leaq is placed before the
// adcq that overwrites its source register, and leaq does not disturb the
// carry chain running through the interleaved addq/adcq instructions.

        movl    $1, %ecx
        movl    $0x00000000FFFFFFFF, %edx
        xorl    %ebx, %ebx
        addq    %r12, %rcx
        leaq    1(%rdx), %r11
        adcq    %r13, %rdx
        leaq    -1(%rbx), %r8
        adcq    %r14, %rbx
        adcq    %r15, %r11
        adcq    %rax, %r8

// Now carry is set if r + (2^320 - p_sm2) >= 2^320, i.e. r >= p_sm2
// where r is the pre-reduced form. So conditionally select the
// output accordingly.
//
// When carry is set the low four words of the sum, which equal r - p_sm2,
// replace r; otherwise r is kept. The top word in %r8 is only needed for
// the carry it produces and is discarded.

        cmovcq  %rcx, %r12
        cmovcq  %rdx, %r13
        cmovcq  %rbx, %r14
        cmovcq  %r11, %r15

// Write back reduced value

        movq    %r12, (z)
        movq    %r13, 8(z)
        movq    %r14, 16(z)
        movq    %r15, 24(z)

// Restore saved registers and return
// (pops are in the reverse order of the pushes at entry)

        popq    %r15
        popq    %r14
        popq    %r13
        popq    %r12
        popq    %rbx

#if WINDOWS_ABI
        popq   %rsi
        popq   %rdi
#endif
        ret

#if defined(__linux__) && defined(__ELF__)
.section .note.GNU-stack,"",%progbits
#endif
