/*
 * SPDX-FileCopyrightText: 2020 Stalwart Labs LLC <hello@stalw.art>
 *
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 */

use std::{borrow::Cow, slice::Iter};

use crate::{Error, Request};

/// Maximum number of bytes the receivers in this module will buffer for a single line.
pub const MAX_LINE_LENGTH: usize = 4 * 1024;

/// Incremental reader for SMTP command lines.
///
/// Commands are parsed straight from the caller's input whenever possible; only
/// a line that is split across reads gets copied into the internal buffer.
#[derive(Default)]
pub struct RequestReceiver {
    /// `true` when `buf` holds a line that has already been handed to the parser.
    buf_used: bool,
    /// Bytes of an unterminated (or just-parsed) command line.
    buf: Vec<u8>,
}

/// State for receiving a `DATA` payload terminated by `<CRLF>.<CRLF>`.
///
/// The last two bytes seen are remembered so the terminator is recognised even
/// when it is split across reads.
pub struct DataReceiver {
    /// Byte seen before `last_ch`.
    prev_last_ch: u8,
    /// Most recently seen byte.
    last_ch: u8,
    /// `true` while a line-leading `.` may be the start of the end-of-data marker.
    crlf_dot: bool,
}

/// State for receiving one `BDAT` chunk of a known size.
pub struct BdatReceiver {
    /// Number of chunk bytes still to be read.
    bytes_left: usize,
    /// Caller-supplied flag marking the final chunk (presumably `BDAT ... LAST`; confirm at call sites).
    pub is_last: bool,
}

/// Consumes a `DATA` or `BDAT` payload without storing any of it.
pub struct DummyDataReceiver {
    /// `true` to skip a BDAT chunk by length, `false` to scan for the DATA terminator.
    is_bdat: bool,
    /// BDAT mode: bytes of the chunk still to be skipped.
    bdat_bytes_left: usize,
    /// DATA mode: `true` while a line-leading `.` may start the end-of-data marker.
    crlf_dot: bool,
    /// DATA mode: byte seen before `last_ch`.
    prev_last_ch: u8,
    /// DATA mode: most recently seen byte.
    last_ch: u8,
}

/// Discards input up to and including the next `\n`.
#[derive(Default)]
pub struct DummyLineReceiver {}

/// Collects a single line of input together with caller-defined state `T`.
#[derive(Default)]
pub struct LineReceiver<T> {
    /// Line bytes received so far, with `\r` and `\n` removed, capped at `MAX_LINE_LENGTH`.
    pub buf: Vec<u8>,
    /// Arbitrary state carried alongside the line.
    pub state: T,
}

impl RequestReceiver {
    pub fn buf(&mut self) -> &mut Vec<u8> {
        if self.buf_used {
            self.buf.clear();
            self.buf_used = false;
        }

        &mut self.buf
    }

    pub fn ingest<'this, 'bytes, 'out>(
        &'this mut self,
        bytes: &mut Iter<'bytes, u8>,
    ) -> Result<Request<Cow<'out, str>>, Error>
    where
        'this: 'out,
        'bytes: 'out,
    {
        self.buf();

        if self.buf.is_empty() {
            let buf = bytes.as_slice();
            match Request::parse(bytes) {
                Err(Error::NeedsMoreData { bytes_left }) => {
                    if bytes_left > 0 {
                        if bytes_left < MAX_LINE_LENGTH {
                            self.buf = buf[buf.len().saturating_sub(bytes_left)..].to_vec();
                        } else {
                            return Err(Error::ResponseTooLong);
                        }
                    }
                }
                result => return result,
            }
        } else {
            for &ch in bytes {
                self.buf.push(ch);
                if ch == b'\n' {
                    self.buf_used = true;
                    return Request::parse(&mut self.buf.iter());
                } else if self.buf.len() == MAX_LINE_LENGTH {
                    self.buf.clear();
                    return Err(Error::ResponseTooLong);
                }
            }
        }

        Err(Error::NeedsMoreData { bytes_left: 0 })
    }
}

impl DataReceiver {
    /// Creates a receiver positioned at the start of a `DATA` payload.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        Self {
            crlf_dot: false,
            last_ch: 0,
            prev_last_ch: 0,
        }
    }

    /// Appends payload bytes from `bytes` to `buf` until the `<CRLF>.<CRLF>`
    /// end-of-data marker is seen.
    ///
    /// A `.` at the start of a line is never copied, which undoes dot-stuffing.
    /// When the marker is found:
    /// - the `<CRLF>` before the dot and the `<CR>` after it are removed from `buf`;
    /// - `bytes` is left positioned just after the final `<LF>`;
    /// - `true` is returned.
    ///
    /// Otherwise all input is consumed and `false` is returned. The receiver
    /// keeps the state needed to recognise a marker split across calls.
    pub fn ingest(&mut self, bytes: &mut Iter<'_, u8>, buf: &mut Vec<u8>) -> bool {
        for &ch in bytes {
            match ch {
                b'.' if self.last_ch == b'\n' && self.prev_last_ch == b'\r' => {
                    // Line-leading dot: either dot-stuffing or the start of the
                    // end-of-data marker. In both cases it is not stored.
                    self.crlf_dot = true;
                }
                // Only the exact `<CRLF>.<CRLF>` sequence ends the data, so the
                // `<CR>` must directly follow the dot. Otherwise
                // `<CRLF>.<CR><CR><LF>` would also end it (an SMTP smuggling
                // vector), and the truncation below would drop the wrong bytes.
                b'\n' if self.crlf_dot && self.last_ch == b'\r' && self.prev_last_ch == b'.' => {
                    // `buf` ends with the `<CR><LF>` before the dot plus the
                    // `<CR>` after it. `saturating_sub` avoids a panic if the
                    // caller drained `buf` between calls.
                    buf.truncate(buf.len().saturating_sub(3));
                    return true;
                }
                b'\r' => {
                    buf.push(ch);
                }
                _ => {
                    buf.push(ch);
                    self.crlf_dot = false;
                }
            }
            self.prev_last_ch = self.last_ch;
            self.last_ch = ch;
        }

        false
    }
}

impl BdatReceiver {
    pub fn new(chunk_size: usize, is_last: bool) -> Self {
        Self {
            bytes_left: chunk_size,
            is_last,
        }
    }

    pub fn ingest(&mut self, bytes: &mut Iter<'_, u8>, buf: &mut Vec<u8>) -> bool {
        while self.bytes_left > 0 {
            if let Some(&ch) = bytes.next() {
                buf.push(ch);
                self.bytes_left -= 1;
            } else {
                return false;
            }
        }
        true
    }
}

impl DummyDataReceiver {
    /// Creates a receiver that discards a BDAT chunk of `chunk_size` bytes.
    pub fn new_bdat(chunk_size: usize) -> Self {
        Self {
            bdat_bytes_left: chunk_size,
            is_bdat: true,
            crlf_dot: false,
            last_ch: 0,
            prev_last_ch: 0,
        }
    }

    /// Creates a receiver that discards the rest of a `DATA` payload.
    ///
    /// It continues from the terminator-detection state that `data` has
    /// reached so far.
    pub fn new_data(data: &DataReceiver) -> Self {
        Self {
            is_bdat: false,
            bdat_bytes_left: 0,
            crlf_dot: data.crlf_dot,
            last_ch: data.last_ch,
            prev_last_ch: data.prev_last_ch,
        }
    }

    /// Consumes and discards payload bytes from `bytes`.
    ///
    /// Returns `true` once the payload has ended, leaving `bytes` positioned
    /// just after it:
    /// - in DATA mode, after the `<CRLF>.<CRLF>` marker;
    /// - in BDAT mode, after the remaining chunk bytes.
    ///
    /// Returns `false` if the input runs out first.
    pub fn ingest(&mut self, bytes: &mut Iter<'_, u8>) -> bool {
        if !self.is_bdat {
            for &ch in bytes {
                match ch {
                    b'.' if self.last_ch == b'\n' && self.prev_last_ch == b'\r' => {
                        self.crlf_dot = true;
                    }
                    // Only the exact `<CRLF>.<CRLF>` sequence ends the data: the
                    // `<CR>` must directly follow the line-leading dot, so
                    // `<CRLF>.<CR><CR><LF>` is treated as ordinary content.
                    b'\n'
                        if self.crlf_dot && self.last_ch == b'\r' && self.prev_last_ch == b'.' =>
                    {
                        return true;
                    }
                    b'\r' => {}
                    _ => {
                        self.crlf_dot = false;
                    }
                }
                self.prev_last_ch = self.last_ch;
                self.last_ch = ch;
            }

            false
        } else {
            let skip = self.bdat_bytes_left.min(bytes.as_slice().len());
            if skip > 0 {
                // `nth(skip - 1)` consumes exactly `skip` bytes.
                let _ = bytes.nth(skip - 1);
            }
            self.bdat_bytes_left -= skip;
            self.bdat_bytes_left == 0
        }
    }
}

impl<T> LineReceiver<T> {
    pub fn new(state: T) -> Self {
        Self {
            buf: Vec::with_capacity(32),
            state,
        }
    }

    pub fn ingest(&mut self, bytes: &mut Iter<'_, u8>) -> bool {
        for &ch in bytes {
            match ch {
                b'\n' => return true,
                b'\r' => (),
                _ => {
                    if self.buf.len() < MAX_LINE_LENGTH {
                        self.buf.push(ch);
                    }
                }
            }
        }
        false
    }
}

impl DummyLineReceiver {
    /// Discards bytes up to and including the next `\n`.
    ///
    /// Returns `true` once the `\n` has been consumed, leaving `bytes`
    /// positioned just after it. Returns `false` when the input runs out first.
    pub fn ingest(&mut self, bytes: &mut Iter<'_, u8>) -> bool {
        bytes.any(|&byte| byte == b'\n')
    }
}

// Unit tests for the receivers above. Inputs are fed in the listed chunks so
// that boundaries between reads are exercised as well.
#[cfg(test)]
mod tests {
    use super::DataReceiver;
    use crate::{Error, MailFrom, RcptTo, Request, request::receiver::RequestReceiver};

    // Each case pairs a payload (split into chunks) with the message that
    // `DataReceiver` must store. A line-leading `.` is dropped, so `..` becomes
    // `.`, and the `\r\n` right before the terminating `.` is not stored. A case
    // fails if none of its chunks completes the payload.
    #[test]
    fn data_receiver() {
        'outer: for (data, message) in [
            (
                vec!["hi\r\n", "..\r\n", ".a\r\n", "\r\n.\r\n"],
                "hi\r\n.\r\na\r\n",
            ),
            (
                vec!["\r\na\rb\nc\r\n.d\r\n..\r\n", "\r\n.\r\n"],
                "\r\na\rb\nc\r\nd\r\n.\r\n",
            ),
            // Test SMTP smuggling attempts
            // A dot next to a bare `\n` or a bare `\r` is ordinary content; only
            // the final `\r\n.\r\n` chunk may end each of these payloads.
            (
                vec![
                    "\n.\r\n",
                    "MAIL FROM:<hello@world.com>\r\n",
                    "RCPT TO:<test@domain.com\r\n",
                    "DATA\r\n",
                    "\r\n.\r\n",
                ],
                concat!(
                    "\n.\r\n",
                    "MAIL FROM:<hello@world.com>\r\n",
                    "RCPT TO:<test@domain.com\r\n",
                    "DATA\r\n",
                ),
            ),
            (
                vec![
                    "\n.\n",
                    "MAIL FROM:<hello@world.com>\r\n",
                    "RCPT TO:<test@domain.com\r\n",
                    "DATA\r\n",
                    "\r\n.\r\n",
                ],
                concat!(
                    "\n.\n",
                    "MAIL FROM:<hello@world.com>\r\n",
                    "RCPT TO:<test@domain.com\r\n",
                    "DATA\r\n",
                ),
            ),
            (
                vec![
                    "\r.\r\n",
                    "MAIL FROM:<hello@world.com>\r\n",
                    "RCPT TO:<test@domain.com\r\n",
                    "DATA\r\n",
                    "\r\n.\r\n",
                ],
                concat!(
                    "\r.\r\n",
                    "MAIL FROM:<hello@world.com>\r\n",
                    "RCPT TO:<test@domain.com\r\n",
                    "DATA\r\n",
                ),
            ),
            (
                vec![
                    "\r.\r",
                    "MAIL FROM:<hello@world.com>\r\n",
                    "RCPT TO:<test@domain.com\r\n",
                    "DATA\r\n",
                    "\r\n.\r\n",
                ],
                concat!(
                    "\r.\r",
                    "MAIL FROM:<hello@world.com>\r\n",
                    "RCPT TO:<test@domain.com\r\n",
                    "DATA\r\n",
                ),
            ),
        ] {
            let mut r = DataReceiver::new();
            let mut buf = Vec::new();
            for data in &data {
                if r.ingest(&mut data.as_bytes().iter(), &mut buf) {
                    assert_eq!(message, String::from_utf8(buf).unwrap());
                    continue 'outer;
                }
            }
            panic!("Failed for {data:?}");
        }
    }

    // Feeds command text in arbitrary chunks and checks that every complete
    // line comes out exactly once, in order. `NeedsMoreData` means the current
    // chunk is exhausted; any other error fails the test.
    #[test]
    fn request_receiver() {
        for (data, expected_requests) in [
            (
                vec![
                    "data\n",
                    "start",
                    "tls\n",
                    "quit\nnoop",
                    " hello\nehlo test\nvrfy name\n",
                    "mail from:<hello",
                    "@world.com>\nrcpt to:<",
                    "test@domain.com>\n",
                ],
                vec![
                    Request::Data,
                    Request::StartTls,
                    Request::Quit,
                    Request::Noop {
                        value: "hello".to_string(),
                    },
                    Request::Ehlo {
                        host: "test".to_string(),
                    },
                    Request::Vrfy {
                        value: "name".to_string(),
                    },
                    Request::Mail {
                        from: MailFrom {
                            address: "hello@world.com".to_string(),
                            flags: 0,
                            size: 0,
                            trans_id: None,
                            by: 0,
                            env_id: None,
                            solicit: None,
                            mtrk: None,
                            auth: None,
                            hold_for: 0,
                            hold_until: 0,
                            mt_priority: 0,
                        },
                    },
                    Request::Rcpt {
                        to: RcptTo {
                            address: "test@domain.com".to_string(),
                            orcpt: None,
                            rrvs: 0,
                            flags: 0,
                        },
                    },
                ],
            ),
            (
                // One byte per read: every request is assembled in the internal buffer.
                vec!["d", "a", "t", "a", "\n", "quit", "\n"],
                vec![Request::Data, Request::Quit],
            ),
        ] {
            let mut requests = Vec::new();
            let mut r = RequestReceiver::default();
            for data in &data {
                let mut bytes = data.as_bytes().iter();
                // Drain every complete request from this chunk before moving on.
                loop {
                    match r.ingest(&mut bytes) {
                        Ok(request) => {
                            requests.push(request.into_owned());
                            continue;
                        }
                        Err(Error::NeedsMoreData { .. }) => {
                            break;
                        }
                        err => panic!("Unexpected error for {data:?}: {err:?}"),
                    }
                }
            }
            assert_eq!(expected_requests, requests);
        }
    }
}
