/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sshd.common.cipher;

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import javax.crypto.AEADBadTagException;
import org.apache.sshd.common.cipher.Cipher;
import org.apache.sshd.common.mac.Mac;
import org.apache.sshd.common.mac.Poly1305Mac;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.BufferUtils;

/**
 * The {@code chacha20-poly1305} AEAD cipher for SSH packets.
 * <p>
 * Two independent ChaCha20 instances are keyed from a 64-byte key:
 * <ul>
 * <li>The first 32 bytes key the "body" engine. It encrypts the packet payload and also derives the
 * one-time Poly1305 key used for each packet.</li>
 * <li>The last 32 bytes key the "header" engine. It only encrypts the 4-byte packet length passed to
 * {@link #updateAAD(byte[], int, int)}.</li>
 * </ul>
 * Both engines take their nonce from the IV, which must encode the SSH packet sequence number. The nonce
 * is advanced after every call to {@link #update(byte[], int, int)}.
 * <p>
 * The Poly1305 tag covers the encrypted length followed by the encrypted payload. It is stored directly
 * after the payload. When decrypting, the tag is verified <em>before</em> the payload is decrypted.
 */
public class ChaCha20Cipher implements Cipher {
    /** Encrypts/decrypts only the 4-byte packet length. */
    protected final ChaChaEngine headerEngine = new ChaChaEngine();
    /** Encrypts/decrypts the packet payload and derives the per-packet Poly1305 key. */
    protected final ChaChaEngine bodyEngine = new ChaChaEngine();
    /** Authenticates the encrypted length followed by the encrypted payload. */
    protected final Mac mac = new Poly1305Mac();
    /** Direction set by {@link #init}; {@code null} until the cipher has been initialized. */
    protected Cipher.Mode mode;

    @Override
    public String getAlgorithm() {
        return "ChaCha20";
    }

    /**
     * Initializes both ChaCha20 engines and the Poly1305 MAC for the first packet.
     *
     * @param  mode      whether packets will be encrypted or decrypted
     * @param  key       at least 64 bytes of key material: bytes 0-31 key the body engine, bytes 32-63 key
     *                   the header engine
     * @param  iv        8-byte nonce whose first 4 bytes must be zero and whose last 4 bytes hold the
     *                   packet sequence number
     * @throws Exception if the key is too short or the IV is not a valid sequence number
     */
    @Override
    public void init(Cipher.Mode mode, byte[] key, byte[] iv) throws Exception {
        // Arrays.copyOfRange zero-pads past the end of the source array, so a short key would otherwise
        // silently produce a (partially) all-zero ChaCha20 key.
        ValidateUtils.checkTrue(key.length >= 2 * ChaChaEngine.KEY_BYTES,
                "ChaCha20-Poly1305 requires at least 64 bytes of key material");
        this.mode = mode;

        bodyEngine.initKey(Arrays.copyOfRange(key, 0, ChaChaEngine.KEY_BYTES));
        bodyEngine.initNonce(iv);
        mac.init(bodyEngine.polyKey());

        headerEngine.initKey(Arrays.copyOfRange(key, ChaChaEngine.KEY_BYTES, 2 * ChaChaEngine.KEY_BYTES));
        headerEngine.initNonce(iv);
        headerEngine.initCounter(0L);
    }

    /**
     * Encrypts or decrypts the 4-byte packet length in place and feeds its <em>encrypted</em> form to the
     * MAC.
     *
     * @param  data      buffer holding the packet length
     * @param  offset    offset of the packet length in {@code data}
     * @param  length    must be exactly 4
     * @throws Exception if the cipher is not initialized or {@code length != 4}
     */
    @Override
    public void updateAAD(byte[] data, int offset, int length) throws Exception {
        ValidateUtils.checkState(mode != null, "Cipher not initialized");
        ValidateUtils.checkTrue(length == 4, "AAD only supported for encrypted packet length");

        // The MAC always covers ciphertext, so it is updated before decrypting / after encrypting.
        if (mode == Cipher.Mode.Decrypt) {
            mac.update(data, offset, length);
        }
        headerEngine.crypt(data, offset, length, data, offset);
        if (mode == Cipher.Mode.Encrypt) {
            mac.update(data, offset, length);
        }
    }

    /**
     * Encrypts or decrypts the packet payload in place, then prepares both engines and the MAC for the
     * next packet.
     * <p>
     * When encrypting, the tag is written at {@code input[inputOffset + inputLen]}. When decrypting, the
     * tag is read from that position and verified before the payload is decrypted.
     *
     * @param  input       buffer holding the payload followed by room for (or the value of) the tag
     * @param  inputOffset offset of the payload
     * @param  inputLen    payload length, excluding the tag
     * @throws Exception   {@link AEADBadTagException} on tag mismatch while decrypting; also thrown if the
     *                     cipher is not initialized or the sequence number wraps back to its initial value
     */
    @Override
    public void update(byte[] input, int inputOffset, int inputLen) throws Exception {
        ValidateUtils.checkState(mode != null, "Cipher not initialized");

        if (mode == Cipher.Mode.Decrypt) {
            mac.update(input, inputOffset, inputLen);
            byte[] actual = mac.doFinal();
            if (!Mac.equals(input, inputOffset + inputLen, actual, 0, actual.length)) {
                throw new AEADBadTagException("Tag mismatch");
            }
        }

        bodyEngine.crypt(input, inputOffset, inputLen, input, inputOffset);

        if (mode == Cipher.Mode.Encrypt) {
            mac.update(input, inputOffset, inputLen);
            mac.doFinal(input, inputOffset + inputLen);
        }

        // Move both engines to the next sequence number and derive the next packet's Poly1305 key.
        headerEngine.advanceNonce();
        headerEngine.initCounter(0L);
        bodyEngine.advanceNonce();
        mac.init(bodyEngine.polyKey());
    }

    @Override
    public String getTransformation() {
        return "ChaCha20";
    }

    @Override
    public int getIVSize() {
        return 8;
    }

    @Override
    public int getAuthenticationTagSize() {
        return 16;
    }

    @Override
    public int getCipherBlockSize() {
        return 8;
    }

    @Override
    public int getKdfSize() {
        return 64;
    }

    @Override
    public int getKeySize() {
        return 512;
    }

    @Override
    public String toString() {
        return "chacha20-poly1305";
    }

    /**
     * Minimal ChaCha20 stream-cipher engine with a 16-word state.
     * <p>
     * State layout:
     * <ul>
     * <li>words 0-3: the "expand 32-byte k" constant</li>
     * <li>words 4-11: the key</li>
     * <li>words 12-13: the block counter</li>
     * <li>words 14-15: the nonce</li>
     * </ul>
     */
    protected static class ChaChaEngine {
        private static final int BLOCK_BYTES = 64;
        private static final int BLOCK_INTS = 16;
        private static final int KEY_OFFSET = 4;
        private static final int KEY_BYTES = 32;
        private static final int KEY_INTS = 8;
        private static final int COUNTER_OFFSET = 12;
        private static final int NONCE_OFFSET = 14;
        private static final int[] ENGINE_STATE_HEADER
                = ChaChaEngine.unpackSigmaString("expand 32-byte k".getBytes(StandardCharsets.US_ASCII));

        /** The 16-word ChaCha20 input state. */
        protected final int[] engineState = new int[BLOCK_INTS];
        /** Output of the most recent {@link #setKeyStream(int[])} call. */
        protected final byte[] keyStream = new byte[BLOCK_BYTES];
        /** Scratch buffer for the big-endian encoding of the current 32-bit sequence number. */
        protected final byte[] nonce = new byte[4];
        /** Sequence number set by {@link #initNonce(byte[])}; reaching it again means nonce reuse. */
        protected long initialNonce;
        /** Current 32-bit sequence number. */
        protected long nonceVal;

        protected ChaChaEngine() {
            System.arraycopy(ENGINE_STATE_HEADER, 0, engineState, 0, ENGINE_STATE_HEADER.length);
        }

        /** Loads the first 32 bytes of {@code key} into state words 4-11 (little-endian words). */
        protected void initKey(byte[] key) {
            ChaChaEngine.unpackIntsLE(key, 0, KEY_INTS, engineState, KEY_OFFSET);
        }

        /**
         * Loads an 8-byte nonce and records it as the initial sequence number.
         * <p>
         * The first 4 bytes must be zero; the last 4 bytes are the sequence number.
         *
         * @param nonce the IV passed to {@link ChaCha20Cipher#init}
         */
        protected void initNonce(byte[] nonce) {
            long hiBits = BufferUtils.getUInt(nonce, 0, 4);
            ValidateUtils.checkState(hiBits == 0L, "ChaCha20 nonce is not a valid SSH packet sequence number");
            this.nonceVal = this.initialNonce = BufferUtils.getUInt(nonce, 4, 4);
            engineState[NONCE_OFFSET] = 0;
            engineState[NONCE_OFFSET + 1] = Poly1305Mac.unpackIntLE(nonce, 4);
        }

        /**
         * Increments the 32-bit sequence number (wrapping at 2^32) and loads it into the state.
         * <p>
         * Fails if the sequence number comes back around to its initial value, since that would reuse a
         * nonce under the same key.
         */
        protected void advanceNonce() {
            nonceVal = (nonceVal + 1L) & 0xFFFFFFFFL;
            ValidateUtils.checkState(nonceVal != initialNonce, "Packet sequence number cannot be reused with the same key");
            BufferUtils.putUInt(nonceVal, nonce, 0, 4);
            engineState[NONCE_OFFSET + 1] = Poly1305Mac.unpackIntLE(nonce, 0);
        }

        /** Sets the low counter word to {@code (int) counter} and clears the high counter word. */
        protected void initCounter(long counter) {
            engineState[COUNTER_OFFSET] = (int) counter;
            engineState[COUNTER_OFFSET + 1] = 0;
        }

        /**
         * XORs {@code length} bytes of {@code in} with the keystream into {@code out}.
         * <p>
         * One 64-byte keystream block is generated per (possibly partial) chunk. The counter is advanced
         * once per block, so the unused tail of a partial final block is discarded.
         */
        protected void crypt(byte[] in, int offset, int length, byte[] out, int outOffset) {
            while (length > 0) {
                setKeyStream(engineState);
                int want = Math.min(BLOCK_BYTES, length);
                for (int i = 0; i < want; ++i) {
                    out[outOffset++] = (byte) (in[offset++] ^ keyStream[i]);
                }
                length -= want;
                // Words 12-13 form a 64-bit block counter: carry into the high word on wrap-around.
                if (++engineState[COUNTER_OFFSET] == 0) {
                    engineState[COUNTER_OFFSET + 1]++;
                }
            }
        }

        /**
         * Derives the 32-byte Poly1305 key from keystream block 0 of the current nonce.
         * <p>
         * Leaves the counter at 1, so subsequent {@link #crypt} output starts at block 1.
         */
        protected byte[] polyKey() {
            byte[] block = new byte[KEY_BYTES];
            initCounter(0L);
            crypt(block, 0, block.length, block, 0);
            initCounter(1L);
            return block;
        }

        /**
         * Runs the 20-round ChaCha block function on {@code engine} and writes the 64-byte result into
         * {@link #keyStream}.
         * <p>
         * Each loop iteration is one column round plus one diagonal round. The input state is not modified.
         */
        protected void setKeyStream(int[] engine) {
            int x0 = engine[0];
            int x1 = engine[1];
            int x2 = engine[2];
            int x3 = engine[3];
            int x4 = engine[4];
            int x5 = engine[5];
            int x6 = engine[6];
            int x7 = engine[7];
            int x8 = engine[8];
            int x9 = engine[9];
            int x10 = engine[10];
            int x11 = engine[11];
            int x12 = engine[12];
            int x13 = engine[13];
            int x14 = engine[14];
            int x15 = engine[15];
            for (int i = 0; i < 10; ++i) {
                // Column rounds: QR(0,4,8,12) QR(1,5,9,13) QR(2,6,10,14) QR(3,7,11,15)
                x12 = Integer.rotateLeft(x12 ^ (x0 += x4), 16);
                x4 = Integer.rotateLeft(x4 ^ (x8 += x12), 12);
                x12 = Integer.rotateLeft(x12 ^ (x0 += x4), 8);
                x4 = Integer.rotateLeft(x4 ^ (x8 += x12), 7);
                x13 = Integer.rotateLeft(x13 ^ (x1 += x5), 16);
                x5 = Integer.rotateLeft(x5 ^ (x9 += x13), 12);
                x13 = Integer.rotateLeft(x13 ^ (x1 += x5), 8);
                x5 = Integer.rotateLeft(x5 ^ (x9 += x13), 7);
                x14 = Integer.rotateLeft(x14 ^ (x2 += x6), 16);
                x6 = Integer.rotateLeft(x6 ^ (x10 += x14), 12);
                x14 = Integer.rotateLeft(x14 ^ (x2 += x6), 8);
                x6 = Integer.rotateLeft(x6 ^ (x10 += x14), 7);
                x15 = Integer.rotateLeft(x15 ^ (x3 += x7), 16);
                x7 = Integer.rotateLeft(x7 ^ (x11 += x15), 12);
                x15 = Integer.rotateLeft(x15 ^ (x3 += x7), 8);
                x7 = Integer.rotateLeft(x7 ^ (x11 += x15), 7);
                // Diagonal rounds: QR(0,5,10,15) QR(1,6,11,12) QR(2,7,8,13) QR(3,4,9,14)
                x15 = Integer.rotateLeft(x15 ^ (x0 += x5), 16);
                x5 = Integer.rotateLeft(x5 ^ (x10 += x15), 12);
                x15 = Integer.rotateLeft(x15 ^ (x0 += x5), 8);
                x5 = Integer.rotateLeft(x5 ^ (x10 += x15), 7);
                x12 = Integer.rotateLeft(x12 ^ (x1 += x6), 16);
                x6 = Integer.rotateLeft(x6 ^ (x11 += x12), 12);
                x12 = Integer.rotateLeft(x12 ^ (x1 += x6), 8);
                x6 = Integer.rotateLeft(x6 ^ (x11 += x12), 7);
                x13 = Integer.rotateLeft(x13 ^ (x2 += x7), 16);
                x7 = Integer.rotateLeft(x7 ^ (x8 += x13), 12);
                x13 = Integer.rotateLeft(x13 ^ (x2 += x7), 8);
                x7 = Integer.rotateLeft(x7 ^ (x8 += x13), 7);
                x14 = Integer.rotateLeft(x14 ^ (x3 += x4), 16);
                x4 = Integer.rotateLeft(x4 ^ (x9 += x14), 12);
                x14 = Integer.rotateLeft(x14 ^ (x3 += x4), 8);
                x4 = Integer.rotateLeft(x4 ^ (x9 += x14), 7);
            }
            // Final feed-forward: add the input state and serialize as little-endian words.
            Poly1305Mac.packIntLE(engine[0] + x0, this.keyStream, 0);
            Poly1305Mac.packIntLE(engine[1] + x1, this.keyStream, 4);
            Poly1305Mac.packIntLE(engine[2] + x2, this.keyStream, 8);
            Poly1305Mac.packIntLE(engine[3] + x3, this.keyStream, 12);
            Poly1305Mac.packIntLE(engine[4] + x4, this.keyStream, 16);
            Poly1305Mac.packIntLE(engine[5] + x5, this.keyStream, 20);
            Poly1305Mac.packIntLE(engine[6] + x6, this.keyStream, 24);
            Poly1305Mac.packIntLE(engine[7] + x7, this.keyStream, 28);
            Poly1305Mac.packIntLE(engine[8] + x8, this.keyStream, 32);
            Poly1305Mac.packIntLE(engine[9] + x9, this.keyStream, 36);
            Poly1305Mac.packIntLE(engine[10] + x10, this.keyStream, 40);
            Poly1305Mac.packIntLE(engine[11] + x11, this.keyStream, 44);
            Poly1305Mac.packIntLE(engine[12] + x12, this.keyStream, 48);
            Poly1305Mac.packIntLE(engine[13] + x13, this.keyStream, 52);
            Poly1305Mac.packIntLE(engine[14] + x14, this.keyStream, 56);
            Poly1305Mac.packIntLE(engine[15] + x15, this.keyStream, 60);
        }

        /** Reads {@code nrInts} little-endian 32-bit words from {@code buf} into {@code dst}. */
        private static void unpackIntsLE(byte[] buf, int off, int nrInts, int[] dst, int dstOff) {
            for (int i = 0; i < nrInts; ++i) {
                dst[dstOff++] = Poly1305Mac.unpackIntLE(buf, off);
                off += Integer.BYTES;
            }
        }

        /** Converts the 16-byte "expand 32-byte k" constant into its four state words. */
        private static int[] unpackSigmaString(byte[] buf) {
            int[] values = new int[KEY_OFFSET];
            ChaChaEngine.unpackIntsLE(buf, 0, values.length, values, 0);
            return values;
        }
    }
}

