diff --git a/test/functional/test_framework/key.py b/test/functional/test_framework/key.py index fc3ef17f7d..d2f0c934b1 100644 --- a/test/functional/test_framework/key.py +++ b/test/functional/test_framework/key.py @@ -287,24 +287,33 @@ def sign_schnorr(key, msg, aux=None, flip_p=False, flip_r=False): class TestFrameworkKey(unittest.TestCase): - def test_schnorr(self): - """Test the Python Schnorr implementation.""" + def test_ecdsa_and_schnorr(self): + """Test the Python ECDSA and Schnorr implementations.""" + def random_bitflip(sig): + sig = list(sig) + sig[random.randrange(len(sig))] ^= (1 << (random.randrange(8))) + return bytes(sig) + byte_arrays = [generate_privkey() for _ in range(3)] + [v.to_bytes(32, 'big') for v in [0, ORDER - 1, ORDER, 2**256 - 1]] keys = {} - for privkey in byte_arrays: # build array of key/pubkey pairs - pubkey, _ = compute_xonly_pubkey(privkey) - if pubkey is not None: - keys[privkey] = pubkey + for privkey_bytes in byte_arrays: # build array of key/pubkey pairs + privkey = ECKey() + privkey.set(privkey_bytes, compressed=True) + if privkey.is_valid: + keys[privkey] = privkey.get_pubkey() for msg in byte_arrays: # test every combination of message, signing key, verification key for sign_privkey, _ in keys.items(): - sig = sign_schnorr(sign_privkey, msg) + sig_ecdsa = sign_privkey.sign_ecdsa(msg) + sig_schnorr = sign_schnorr(sign_privkey.get_bytes(), msg) for verify_privkey, verify_pubkey in keys.items(): + verify_xonly_pubkey = verify_pubkey.get_bytes()[1:] if verify_privkey == sign_privkey: - self.assertTrue(verify_schnorr(verify_pubkey, sig, msg)) - sig = list(sig) - sig[random.randrange(64)] ^= (1 << (random.randrange(8))) # damaging signature should break things - sig = bytes(sig) - self.assertFalse(verify_schnorr(verify_pubkey, sig, msg)) + self.assertTrue(verify_pubkey.verify_ecdsa(sig_ecdsa, msg)) + self.assertTrue(verify_schnorr(verify_xonly_pubkey, sig_schnorr, msg)) + sig_ecdsa = random_bitflip(sig_ecdsa) # damaging signature should break things + sig_schnorr = random_bitflip(sig_schnorr) + self.assertFalse(verify_pubkey.verify_ecdsa(sig_ecdsa, msg)) + self.assertFalse(verify_schnorr(verify_xonly_pubkey, sig_schnorr, msg)) def test_schnorr_testvectors(self): """Implement the BIP340 test vectors (read from bip340_test_vectors.csv)."""