diff --git a/src/encryption.py b/src/encryption.py index bc8e19a..a9b07c0 100755 --- a/src/encryption.py +++ b/src/encryption.py @@ -28,13 +28,11 @@ from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from collections import OrderedDict -class PasswordMismatch(Exception): +class EncryptionFail(Exception): pass - -# Salts -def generate_salt(): - return os.urandom(16) +class PasswordMismatch(Exception): + pass # Return string from bytes @@ -72,8 +70,7 @@ def derive_key(password, salt): iterations=100000, backend=default_backend() ) - r_key = base64.urlsafe_b64encode(kdf.derive(password)) - return r_key + return base64.urlsafe_b64encode(kdf.derive(password)) # Encryption functions @@ -99,94 +96,62 @@ def decrypt(token, key): # encryption_function: encrypt(message, key) : decrypt(token, key): # Returns settings_server_decrypted dictionary with Byte() values. Will need to use # ChangeEncodingDict to make them strings (recommended cfg file friendly) -def __settings_server(password, salt, settings_server, encryption_function): +def encrypt_settings(settings_server, password, salt, encryption_function): key = derive_key(password, salt) settings_server_decrypted = OrderedDict() for setting in settings_server: settings_server_decrypted[setting] = encryption_function(settings_server[setting], key) return settings_server_decrypted - -# Returns (salt, settings_server) -def _settings_server_encrypt(settings_server): - salt = generate_salt() - password = getpass.getpass("Enter password: ") - password2 = getpass.getpass("Retype Password: ") - - if password != password2: - raise PasswordMismatch - - settings_server_encrypted = __settings_server(password.encode(), salt, encode_dict(settings_server), encrypt) - - return salt, settings_server_encrypted +def get_keyfile(keyfile=None): + if keyfile is not None: + with open(keyfile, "rb") as f: + return f.read() + return None -# Returns (settings_server) -def _settings_server_decrypt(settings_server, settings_encrypt): - settings_server_encoded = encode_dict(settings_server) - if settings_encrypt["encrypt"]: - salt = salt_decode(settings_encrypt["salt"]) - password = getpass.getpass("Enter password: ") - return __settings_server(password.encode(), salt, settings_server_encoded, decrypt) - else: - return settings_server_encoded - - -# Wrapper function that will catch exceptions and exit -def settings_server_new(function, **kwargs): +def get_pass(q): try: - return function(**kwargs) - - # If the user cancels the login + return getpass.getpass(q).encode() except KeyboardInterrupt: - print("\nQuitting...") + raise EncryptionFail("\nQuitting...") - # If the user passwords do not match (encrypt) - except PasswordMismatch: - print("Passwords do not match...") - # Incorrect password entered (decrypt) - except InvalidToken: - print("Password or Token Incorrect...") - - # Probably the salt value got modified - except base64.binascii.Error: - print("Salt is invalid...") - - # Some other kind of fuck up +def settings_server_encrypt(settings_server, keyfile=None): + try: + settings_server = encode_dict(settings_server) + salt = os.urandom(16) + password = get_keyfile(keyfile) + if password is None: + password = get_pass("Enter password: ") + password2 = get_pass("Enter password: ") + if password != password2: + raise PasswordMismatch("Passwords do not match") + encrypted = encrypt_settings(settings_server, password, salt, encrypt) + return salt_encode(salt), decode_dict(encrypted) + except PasswordMismatch as e: + raise EncryptionFail(str(e)) except Exception as e: - print("Unknown exception occurred...") - print(e) - - # Exit if an exception was thrown - sys.exit(1) + err = str(e) + raise EncryptionFail("Encrypt Error: {}".format(err)) -# Glue functions that package **kwargs automatically -def settings_server_encrypt(settings_server): - kwargs = {"settings_server": settings_server} - return settings_server_new(_settings_server_encrypt, **kwargs) - - -def settings_server_decrypt(settings_server, settings_encrypt): - kwargs = { - "settings_server": settings_server, - "settings_encrypt": settings_encrypt - } - return settings_server_new(_settings_server_decrypt, **kwargs) - - -# The _cfg functions should return a regular string -# These are the functions that should interface with the bot a return a plain string -# settings_server ordered dictionary -def settings_server_encrypt_cfg(settings_server): - salt, settings_server = settings_server_encrypt(settings_server) - return salt_encode(salt), decode_dict(settings_server) - - -def settings_server_decrypt_cfg(settings_server, settings_encrypt): - settings_server = settings_server_decrypt(settings_server, settings_encrypt) - return decode_dict(settings_server) +def settings_server_decrypt(settings_server, settings_encrypt, keyfile=None): + try: + if not settings_encrypt["encrypt"]: + return settings_server + settings_server = encode_dict(settings_server) + password = get_keyfile(keyfile or settings_encrypt["keyfile"]) or get_pass("Enter password: ") + salt = salt_decode(settings_encrypt["salt"]) + decrypted = encrypt_settings(settings_server, password, salt, decrypt) + return decode_dict(decrypted) + except base64.binascii.Error: + raise EncryptionFail("Salt is invalid") + except InvalidToken: + raise EncryptionFail("Password or token is incorrect") + except Exception as e: + err = str(e) + raise EncryptionFail("Decrypt Error: {}".format(err)) def main(): @@ -202,6 +167,7 @@ def main(): parser.add_argument("--encrypt", help="Generate encrypted authentication.", action="store_true") parser.add_argument("--decrypt", help="Decrypt encrypted authentication", action="store_true") parser.add_argument("--recrypt", help="Recrypt encrypted authentication", action="store_true") + parser.add_argument("-k", "--keyfile", help="Keyfile used for decryption", default=None) parser.add_argument("-c", "--cfg", help="Specify config file.", default=default_cfg) arguments = parser.parse_args() @@ -214,25 +180,28 @@ def main(): import importlib cfg = importlib.import_module(arguments.cfg) settings_server = cfg.settings_server - settings_encrypt = None + settings_encrypt = cfg.settings_encrypt + keyfile = arguments.keyfile or settings_encrypt["keyfile"] if arguments.decrypt and arguments.encrypt: print("Re-encrypting") if arguments.decrypt: # arguments.decrypt print("Decrypt...") - settings_server = settings_server_decrypt_cfg(cfg.settings_server, cfg.settings_encrypt) + settings_server = settings_server_decrypt(settings_server, settings_encrypt, arguments.keyfile) settings_encrypt = OrderedDict([ - ("encrypt", False), - ("salt", cfg.settings_encrypt["encrypt"]) + ("encrypt", False), + ("salt", settings_encrypt["salt"]), + ("keyfile", arguments.keyfile) ]) if arguments.encrypt: print("Encrypt...") - salt, settings_server = settings_server_encrypt_cfg(settings_server) + salt, settings_server = settings_server_encrypt(settings_server, keyfile) settings_encrypt = OrderedDict([ - ("encrypt", True), - ("salt", salt) + ("encrypt", True), + ("salt", salt), + ("keyfile", arguments.keyfile) ]) print("settings_server = {}".format(pformat(settings_server))) @@ -241,4 +210,8 @@ def main(): if __name__ == "__main__": - sys.exit(main()) + try: + sys.exit(main()) + except EncryptionFail as e: + print(e) + sys.exit(1)