/*
 * Decompiled with CFR 0.152.
 */
package org.jgroups.protocols;

import java.security.Key;
import java.security.KeyStore;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import org.jgroups.Address;
import org.jgroups.Event;
import org.jgroups.Message;
import org.jgroups.View;
import org.jgroups.annotations.ManagedAttribute;
import org.jgroups.annotations.ManagedOperation;
import org.jgroups.annotations.Property;
import org.jgroups.protocols.EncryptHeader;
import org.jgroups.stack.Protocol;
import org.jgroups.util.AsciiString;
import org.jgroups.util.BoundedHashMap;
import org.jgroups.util.MessageBatch;
import org.jgroups.util.Util;

/**
 * Base class for symmetric encryption protocols.
 * <p>
 * Messages passed down the stack have their payload encrypted with the current group key (see {@link #secret_key})
 * and carry an {@link EncryptHeader} holding that key's version. Messages passed up are decrypted with the current
 * key. If the header's version differs and the sender is a member of the current view, a cached cipher for that
 * older version is looked up in {@link #key_map} instead. Only the payload is encrypted by this class; headers are
 * not.
 * <p>
 * Key material is supplied by subclasses through {@link #setKeyStoreEntry}.
 *
 * @param <E> the kind of key store entry accepted by {@link #setKeyStoreEntry}
 */
public abstract class Encrypt<E extends KeyStore.Entry>
extends Protocol {
    protected static final String DEFAULT_SYM_ALGO = "AES";
    @Property(description="Cryptographic Service Provider")
    protected String provider;
    @Property(description="Cipher engine transformation for asymmetric algorithm. Default is RSA")
    protected String asym_algorithm = "RSA";
    @Property(description="Cipher engine transformation for symmetric algorithm. Default is AES")
    protected String sym_algorithm = "AES";
    @Property(description="Initial public/private key length. Default is 2048")
    protected int asym_keylength = 2048;
    @Property(description="Initial key length for matching symmetric algorithm. Default is 128")
    protected int sym_keylength = 128;
    @Property(description="Number of ciphers in the pool to parallelize encrypt and decrypt requests", writable=false)
    protected int cipher_pool_size = 8;
    @Property(description="If true, the entire message (including payload and headers) is encrypted, else only the payload", deprecatedMessage="ignored (always false)")
    @Deprecated
    protected boolean encrypt_entire_message;
    @Property(description="If true, all messages are digitally signed by adding an encrypted checksum of the encrypted message to the header. Ignored if encrypt_entire_message is false", deprecatedMessage="ignored (always false)")
    @Deprecated
    protected boolean sign_msgs;
    @Property(description="When sign_msgs is true, by default CRC32 is used to create the checksum. If use_adler is true, Adler32 will be used", deprecatedMessage="ignored as sign_msgs has been deprecated")
    @Deprecated
    protected boolean use_adler;
    @Property(description="Max number of keys in key_map")
    protected int key_map_max_size = 20;

    /** This member's address; set by a SET_LOCAL_ADDRESS event or via {@link #localAddress(Address)}. */
    protected volatile Address local_addr;
    /** Most recently installed view; used to vet senders of messages encrypted with another key version. */
    protected volatile View view;
    /** Pool of ciphers initialized for encryption with the current secret key (borrowed/returned in {@link #code}). */
    protected volatile BlockingQueue<Cipher> encoding_ciphers;
    /** Pool of ciphers initialized for decryption with the current secret key (borrowed/returned in {@link #code}). */
    protected volatile BlockingQueue<Cipher> decoding_ciphers;
    /** Version of the current secret key: the MD5 digest of its encoded form (see {@link #initSymCiphers}). */
    protected volatile byte[] sym_version;
    /** Current symmetric group key; while null, outgoing messages and incoming batches are discarded. */
    protected volatile Key secret_key;
    /**
     * Ciphers used to decrypt messages whose header version differs from {@link #sym_version}, keyed by that version.
     * Created in {@link #init()} and bounded to {@link #key_map_max_size} entries; presumably populated by subclasses.
     */
    protected Map<AsciiString, Cipher> key_map;

    /**
     * Installs key material from a key store entry. The effect on {@link #secret_key} and the cipher pools is
     * defined by the subclass.
     */
    public abstract void setKeyStoreEntry(E var1);

    public int asymKeylength() {
        return this.asym_keylength;
    }

    @SuppressWarnings("unchecked") // fluent setter: callers choose T as their concrete subclass
    public <T extends Encrypt<E>> T asymKeylength(int len) {
        this.asym_keylength = len;
        return (T)this;
    }

    public int symKeylength() {
        return this.sym_keylength;
    }

    @SuppressWarnings("unchecked") // fluent setter: callers choose T as their concrete subclass
    public <T extends Encrypt<E>> T symKeylength(int len) {
        this.sym_keylength = len;
        return (T)this;
    }

    public Key secretKey() {
        return this.secret_key;
    }

    public String symAlgorithm() {
        return this.sym_algorithm;
    }

    @SuppressWarnings("unchecked") // fluent setter: callers choose T as their concrete subclass
    public <T extends Encrypt<E>> T symAlgorithm(String alg) {
        this.sym_algorithm = alg;
        return (T)this;
    }

    public String asymAlgorithm() {
        return this.asym_algorithm;
    }

    @SuppressWarnings("unchecked") // fluent setter: callers choose T as their concrete subclass
    public <T extends Encrypt<E>> T asymAlgorithm(String alg) {
        this.asym_algorithm = alg;
        return (T)this;
    }

    public byte[] symVersion() {
        return this.sym_version;
    }

    @SuppressWarnings("unchecked") // fluent setter: callers choose T as their concrete subclass
    public <T extends Encrypt<E>> T localAddress(Address addr) {
        this.local_addr = addr;
        return (T)this;
    }

    /** Returns the current key version ({@link #sym_version}) as a hex string. */
    @ManagedAttribute
    public String version() {
        return Util.byteArrayToHexString(this.sym_version);
    }

    /**
     * Lists the versions (hex) of the keys cached in {@link #key_map}, comma-separated.
     * Returns an empty string if the protocol has not been initialized yet.
     */
    @ManagedOperation(description="Prints the versions of the shared group keys cached in the key map")
    public String printCachedGroupKeys() {
        Map<AsciiString, Cipher> keys = this.key_map;
        if (keys == null) {
            return ""; // init() has not run yet
        }
        return keys.keySet().stream().map(v -> Util.byteArrayToHexString(v.chars())).collect(Collectors.joining(", "));
    }

    /**
     * Rounds {@link #cipher_pool_size} up to a power of two, creates the bounded {@link #key_map} and, if a secret key
     * is already present, builds the cipher pools for it.
     */
    @Override
    public void init() throws Exception {
        int tmp = Util.getNextHigherPowerOfTwo(this.cipher_pool_size);
        if (tmp != this.cipher_pool_size) {
            this.log.warn("%s: setting cipher_pool_size (%d) to %d (power of 2) for faster modulo operation", this.local_addr, this.cipher_pool_size, tmp);
            this.cipher_pool_size = tmp;
        }
        this.key_map = new BoundedHashMap<>(this.key_map_max_size);
        this.initSymCiphers(this.sym_algorithm, this.secret_key);
    }

    /**
     * Tracks view changes (after passing them down) and the local address; all events are forwarded down.
     */
    @Override
    public Object down(Event evt) {
        switch (evt.getType()) {
            case Event.VIEW_CHANGE: {
                // Let the layers below install the view first, then record it here.
                Object retval = this.down_prot.down(evt);
                this.handleView((View)evt.getArg());
                return retval;
            }
            case Event.SET_LOCAL_ADDRESS:
                this.local_addr = (Address)evt.getArg();
                break;
            default:
                break;
        }
        return this.down_prot.down(evt);
    }

    /**
     * Encrypts and sends a message. Messages are discarded (trace-logged) while no secret key is available; any
     * failure is logged and the message is dropped. Always returns null.
     */
    @Override
    public Object down(Message msg) {
        try {
            if (this.secret_key == null) {
                this.log.trace("%s: discarded %s message to %s as secret key is null, hdrs: %s", this.local_addr, msg.dest() == null ? "mcast" : "unicast", msg.dest(), msg.printHeaders());
                return null;
            }
            this.encryptAndSend(msg);
        }
        catch (Exception e) {
            this.log.warn("%s: unable to send message down", this.local_addr, e);
        }
        return null;
    }

    /** Records view changes, then forwards every event up. */
    @Override
    public Object up(Event evt) {
        if (evt.getType() == Event.VIEW_CHANGE) {
            this.handleView((View)evt.getArg());
        }
        return this.up_prot.up(evt);
    }

    /**
     * Decrypts a received message and passes it up. Messages without an {@link EncryptHeader}, or that fail to
     * decrypt, are logged and dropped (null is returned).
     */
    @Override
    public Object up(Message msg) {
        EncryptHeader hdr = (EncryptHeader)msg.getHeader(this.id);
        if (hdr == null) {
            this.log.error("%s: received message without encrypt header from %s; dropping it", this.local_addr, msg.src());
            return null;
        }
        try {
            return this.handleEncryptedMessage(msg);
        }
        catch (Exception e) {
            this.log.warn("%s: exception occurred decrypting message", this.local_addr, e);
            return null;
        }
    }

    /**
     * Decrypts every message of a batch in place with one pooled decryption cipher, removing messages that cannot
     * be decrypted, and passes the batch up if anything remains. Batches are discarded while no secret key is set.
     */
    @Override
    public void up(MessageBatch batch) {
        if (this.secret_key == null) {
            this.log.trace("%s: discarded %s batch from %s as secret key is null", this.local_addr, batch.dest() == null ? "mcast" : "unicast", batch.sender());
            return;
        }
        BlockingQueue<Cipher> cipherQueue = this.decoding_ciphers;
        try {
            Cipher cipher = cipherQueue.take();
            try {
                batch.forEach(new Decrypter(cipher));
            }
            finally {
                cipherQueue.offer(cipher);
            }
        }
        catch (InterruptedException e) {
            // Preserve the interrupt for code further up this thread's stack.
            Thread.currentThread().interrupt();
            // Drop the whole batch: passing it up undecrypted would bypass decryption.
            this.log.error("%s: failed processing batch; discarding batch", this.local_addr, e);
            return;
        }
        if (!batch.isEmpty()) {
            this.up_prot.up(batch);
        }
    }

    /**
     * Builds fresh encryption/decryption cipher pools of size {@link #cipher_pool_size} for {@code secret} and
     * computes its version (MD5 of the encoded key). Does nothing if {@code secret} is null. The fields are only
     * replaced once all ciphers have been created successfully.
     */
    protected void initSymCiphers(String algorithm, Key secret) throws Exception {
        if (secret == null) {
            return;
        }
        BlockingQueue<Cipher> tmp_encoding_ciphers = new ArrayBlockingQueue<>(this.cipher_pool_size);
        BlockingQueue<Cipher> tmp_decoding_ciphers = new ArrayBlockingQueue<>(this.cipher_pool_size);
        for (int i = 0; i < this.cipher_pool_size; ++i) {
            tmp_encoding_ciphers.offer(this.createCipher(Cipher.ENCRYPT_MODE, secret, algorithm));
            tmp_decoding_ciphers.offer(this.createCipher(Cipher.DECRYPT_MODE, secret, algorithm));
        }
        // The version is compared against EncryptHeader.version() on receivers, so the digest must stay stable.
        MessageDigest digest = MessageDigest.getInstance("MD5");
        byte[] tmp_sym_version = digest.digest(secret.getEncoded());
        this.encoding_ciphers = tmp_encoding_ciphers;
        this.decoding_ciphers = tmp_decoding_ciphers;
        this.sym_version = tmp_sym_version;
    }

    /**
     * Creates a cipher for {@code algorithm}, using {@link #provider} when one is configured, and initializes it with
     * {@code mode} ({@link Cipher#ENCRYPT_MODE} or {@link Cipher#DECRYPT_MODE}) and {@code secret_key}.
     */
    protected Cipher createCipher(int mode, Key secret_key, String algorithm) throws Exception {
        Cipher cipher = this.provider != null && !this.provider.trim().isEmpty() ? Cipher.getInstance(algorithm, this.provider) : Cipher.getInstance(algorithm);
        cipher.init(mode, secret_key);
        return cipher;
    }

    /**
     * Decrypts a copy of {@code msg} (so the original is left untouched) and passes it up; logs and returns null if
     * no matching cipher was found.
     */
    protected Object handleEncryptedMessage(Message msg) throws Exception {
        Message tmpMsg = this.decryptMessage(null, msg.copy());
        if (tmpMsg != null) {
            return this.up_prot.up(tmpMsg);
        }
        this.log.warn("%s: unrecognized cipher; discarding message from %s", this.local_addr, msg.src());
        return null;
    }

    /** Records the current view. */
    protected void handleView(View view) {
        this.view = view;
    }

    /**
     * Returns true if there is no view yet or {@code sender} is a member of it; otherwise logs {@code error_msg} and
     * returns false. {@code error_msg} is used as a log format string with {@code sender} and the current view as its
     * arguments.
     */
    protected boolean inView(Address sender, String error_msg) {
        View curr_view = this.view;
        if (curr_view == null || curr_view.containsMember(sender)) {
            return true;
        }
        this.log.error(error_msg, sender, curr_view);
        return false;
    }

    /**
     * Decrypts {@code msg} in place. If its header version matches the current key, {@code cipher} is used (or a
     * pooled cipher when {@code cipher} is null). Otherwise, provided the sender is in the current view, the cipher
     * cached in {@link #key_map} for that version is used.
     *
     * @return the decrypted message, or null if the sender is not a member or no cipher for its version is cached
     */
    protected Message decryptMessage(Cipher cipher, Message msg) throws Exception {
        EncryptHeader hdr = (EncryptHeader)msg.getHeader(this.id);
        if (!Arrays.equals(hdr.version(), this.sym_version)) {
            // Sender and view are left as placeholders (%%s) so inView() substitutes them as log arguments,
            // rather than embedding a remote-influenced address into the format string itself.
            String error_msg = String.format("%s: rejected decryption of %s message from non-member %%s (view: %%s)", this.local_addr, msg.dest() == null ? "multicast" : "unicast");
            if (!this.inView(msg.src(), error_msg)) {
                return null;
            }
            cipher = this.key_map.get(new AsciiString(hdr.version()));
            if (cipher == null) {
                this.log.trace("%s: message from %s (version: %s) dropped, as a cipher matching that version wasn't found (current version: %s)", this.local_addr, msg.src(), Util.byteArrayToHexString(hdr.version()), Util.byteArrayToHexString(this.sym_version));
                return null;
            }
            this.log.trace("%s: decrypting msg from %s using previous cipher version %s", this.local_addr, msg.src(), Util.byteArrayToHexString(hdr.version()));
        }
        return this._decrypt(cipher, msg);
    }

    /**
     * Replaces the payload of {@code msg} with its decryption. Empty payloads are returned unchanged. When
     * {@code cipher} is null a pooled decryption cipher is used.
     */
    protected Message _decrypt(Cipher cipher, Message msg) throws Exception {
        if (msg.getLength() == 0) {
            return msg;
        }
        byte[] decrypted_msg;
        if (cipher == null) {
            decrypted_msg = this.code(msg.getRawBuffer(), msg.getOffset(), msg.getLength(), true);
        } else {
            try {
                decrypted_msg = cipher.doFinal(msg.getRawBuffer(), msg.getOffset(), msg.getLength());
            }
            catch (BadPaddingException | IllegalBlockSizeException e) {
                // A cipher may need to be re-initialized after doFinal() throws.
                // NOTE(review): if this cipher came from key_map (an older key version), re-initializing it with the
                // *current* secret_key re-keys the cached cipher, so later messages of that version will fail.
                cipher.init(Cipher.DECRYPT_MODE, this.secret_key);
                throw e;
            }
        }
        return msg.setBuffer(decrypted_msg);
    }

    /**
     * Returns a copy of {@code msg} (headers but not payload are copied) tagged with an {@link EncryptHeader} for the
     * current key version, whose payload is the encryption of the original payload. A non-null empty payload is
     * carried over as is.
     */
    protected Message encrypt(Message msg) throws Exception {
        EncryptHeader hdr = new EncryptHeader(this.symVersion());
        Message msgEncrypted = msg.copy(false).putHeader(this.id, hdr);
        byte[] payload = msg.getRawBuffer();
        if (payload != null) {
            if (msg.getLength() > 0) {
                msgEncrypted.setBuffer(this.code(payload, msg.getOffset(), msg.getLength(), false));
            } else {
                msgEncrypted.setBuffer(payload, msg.getOffset(), msg.getLength());
            }
        }
        return msgEncrypted;
    }

    /** Encrypts {@code msg} and passes the encrypted copy down the stack. */
    protected void encryptAndSend(Message msg) throws Exception {
        this.down_prot.down(this.encrypt(msg));
    }

    /**
     * Encrypts or decrypts {@code length} bytes of {@code buf} starting at {@code offset}, using a cipher borrowed
     * from the matching pool (blocking until one is available) and always returned afterwards.
     *
     * @param decode true to decrypt, false to encrypt
     * @throws InterruptedException if interrupted while waiting for a pooled cipher
     */
    protected byte[] code(byte[] buf, int offset, int length, boolean decode) throws Exception {
        BlockingQueue<Cipher> queue = decode ? this.decoding_ciphers : this.encoding_ciphers;
        Cipher cipher = queue.take();
        try {
            return cipher.doFinal(buf, offset, length);
        }
        catch (BadPaddingException | IllegalBlockSizeException e) {
            // Re-initialize so the cipher is usable again before it goes back into the pool.
            cipher.init(decode ? Cipher.DECRYPT_MODE : Cipher.ENCRYPT_MODE, this.secret_key);
            throw e;
        }
        finally {
            queue.offer(cipher);
        }
    }

    /** Returns the algorithm part of a transformation such as {@code "AES/CBC/PKCS5Padding"} (text before '/'). */
    protected static String getAlgorithm(String s) {
        int index = s.indexOf('/');
        return index == -1 ? s : s.substring(0, index);
    }

    /**
     * Per-batch decryption callback: replaces each message in the batch with its decrypted copy, or removes it if it
     * lacks an {@link EncryptHeader}, cannot be matched to a cipher, or fails to decrypt.
     */
    protected class Decrypter
    implements BiConsumer<Message, MessageBatch> {
        /** Pooled decryption cipher used for messages encrypted with the current key version. */
        protected final Cipher cipher;

        public Decrypter(Cipher cipher) {
            this.cipher = cipher;
        }

        @Override
        public void accept(Message msg, MessageBatch batch) {
            if (msg.getHeader(Encrypt.this.id) == null) {
                Encrypt.this.log.error("%s: received message without encrypt header from %s; dropping it", Encrypt.this.local_addr, batch.sender());
                batch.remove(msg);
                return;
            }
            try {
                Message tmpMsg = Encrypt.this.decryptMessage(this.cipher, msg.copy());
                if (tmpMsg != null) {
                    batch.replace(msg, tmpMsg);
                } else {
                    batch.remove(msg);
                }
            }
            catch (Exception e) {
                // Guard against a null buffer so logging cannot itself throw and abort the rest of the batch.
                byte[] raw = msg.getRawBuffer();
                Encrypt.this.log.error("%s: failed decrypting message from %s (offset=%d, length=%d, buf.length=%d): %s, headers are %s", Encrypt.this.local_addr, msg.getSrc(), msg.getOffset(), msg.getLength(), raw == null ? 0 : raw.length, e, msg.printHeaders());
                batch.remove(msg);
            }
        }
    }
}

