Skip to content

Commit

Permalink
Compute aggregate xpub for musig() in descriptors in the python clien…
Browse files Browse the repository at this point in the history
…t library
  • Loading branch information
bigspider committed Mar 11, 2024
1 parent d8ca7d4 commit d789442
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 1 deletion.
177 changes: 177 additions & 0 deletions bitcoin_client/ledger_bitcoin/bip0327.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# extracted from the BIP327 reference implementation: https://github.com/bitcoin/bips/blob/b3701faef2bdb98a0d7ace4eedbeefa2da4c89ed/bip-0327/reference.py

# Only contains the key aggregation part of the library

# The code in this source file is distributed under the BSD-3-Clause.

# autopep8: off

from typing import List, Optional, Tuple, NewType, NamedTuple
import hashlib

#
# The following helper functions were copied from the BIP-340 reference implementation:
# https://github.com/bitcoin/bips/blob/master/bip-0340/reference.py
#

p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141

# Points are tuples of X and Y coordinates and the point at infinity is
# represented by the None keyword.
G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8)

Point = Tuple[int, int]

# This implementation can be sped up by storing the midstate after hashing
# tag_hash instead of rehashing it all the time.
def tagged_hash(tag: str, msg: bytes) -> bytes:
tag_hash = hashlib.sha256(tag.encode()).digest()
return hashlib.sha256(tag_hash + tag_hash + msg).digest()

def is_infinite(P: Optional[Point]) -> bool:
return P is None

def x(P: Point) -> int:
assert not is_infinite(P)
return P[0]

def y(P: Point) -> int:
assert not is_infinite(P)
return P[1]

def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]:
if P1 is None:
return P2
if P2 is None:
return P1
if (x(P1) == x(P2)) and (y(P1) != y(P2)):
return None
if P1 == P2:
lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p
else:
lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p
x3 = (lam * lam - x(P1) - x(P2)) % p
return (x3, (lam * (x(P1) - x3) - y(P1)) % p)

def point_mul(P: Optional[Point], n: int) -> Optional[Point]:
R = None
for i in range(256):
if (n >> i) & 1:
R = point_add(R, P)
P = point_add(P, P)
return R

def bytes_from_int(x: int) -> bytes:
return x.to_bytes(32, byteorder="big")

def lift_x(b: bytes) -> Optional[Point]:
x = int_from_bytes(b)
if x >= p:
return None
y_sq = (pow(x, 3, p) + 7) % p
y = pow(y_sq, (p + 1) // 4, p)
if pow(y, 2, p) != y_sq:
return None
return (x, y if y & 1 == 0 else p-y)

def int_from_bytes(b: bytes) -> int:
return int.from_bytes(b, byteorder="big")

def has_even_y(P: Point) -> bool:
assert not is_infinite(P)
return y(P) % 2 == 0

#
# End of helper functions copied from BIP-340 reference implementation.
#

PlainPk = NewType('PlainPk', bytes)
XonlyPk = NewType('XonlyPk', bytes)

# There are two types of exceptions that can be raised by this implementation:
# - ValueError for indicating that an input doesn't conform to some function
# precondition (e.g. an input array is the wrong length, a serialized
# representation doesn't have the correct format).
# - InvalidContributionError for indicating that a signer (or the
# aggregator) is misbehaving in the protocol.
#
# Assertions are used to (1) satisfy the type-checking system, and (2) check for
# inconvenient events that can't happen except with negligible probability (e.g.
# output of a hash function is 0) and can't be manually triggered by any
# signer.

# This exception is raised if a party (signer or nonce aggregator) sends invalid
# values. Actual implementations should not crash when receiving invalid
# contributions. Instead, they should hold the offending party accountable.
class InvalidContributionError(Exception):
def __init__(self, signer, contrib):
self.signer = signer
# contrib is one of "pubkey", "pubnonce", "aggnonce", or "psig".
self.contrib = contrib

infinity = None

def xbytes(P: Point) -> bytes:
return bytes_from_int(x(P))

def cbytes(P: Point) -> bytes:
a = b'\x02' if has_even_y(P) else b'\x03'
return a + xbytes(P)

def point_negate(P: Optional[Point]) -> Optional[Point]:
if P is None:
return P
return (x(P), p - y(P))

def cpoint(x: bytes) -> Point:
if len(x) != 33:
raise ValueError('x is not a valid compressed point.')
P = lift_x(x[1:33])
if P is None:
raise ValueError('x is not a valid compressed point.')
if x[0] == 2:
return P
elif x[0] == 3:
P = point_negate(P)
assert P is not None
return P
else:
raise ValueError('x is not a valid compressed point.')

KeyAggContext = NamedTuple('KeyAggContext', [('Q', Point),
('gacc', int),
('tacc', int)])

def key_agg(pubkeys: List[PlainPk]) -> KeyAggContext:
pk2 = get_second_key(pubkeys)
u = len(pubkeys)
Q = infinity
for i in range(u):
try:
P_i = cpoint(pubkeys[i])
except ValueError:
raise InvalidContributionError(i, "pubkey")
a_i = key_agg_coeff_internal(pubkeys, pubkeys[i], pk2)
Q = point_add(Q, point_mul(P_i, a_i))
# Q is not the point at infinity except with negligible probability.
assert(Q is not None)
gacc = 1
tacc = 0
return KeyAggContext(Q, gacc, tacc)

def hash_keys(pubkeys: List[PlainPk]) -> bytes:
return tagged_hash('KeyAgg list', b''.join(pubkeys))

def get_second_key(pubkeys: List[PlainPk]) -> PlainPk:
u = len(pubkeys)
for j in range(1, u):
if pubkeys[j] != pubkeys[0]:
return pubkeys[j]
return PlainPk(b'\x00'*33)

def key_agg_coeff_internal(pubkeys: List[PlainPk], pk_: PlainPk, pk2: PlainPk) -> int:
L = hash_keys(pubkeys)
if pk_ == pk2:
return 1
return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk_)) % n
56 changes: 55 additions & 1 deletion bitcoin_client/ledger_bitcoin/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import base64
from io import BytesIO, BufferedReader

from .embit import base58
from .embit.base import EmbitError
from .embit.descriptor import Descriptor
from .embit.networks import NETWORKS
Expand All @@ -17,9 +18,10 @@
from .merkle import get_merkleized_map_commitment
from .wallet import WalletPolicy, WalletType
from .psbt import PSBT, normalize_psbt
from . import segwit_addr
from ._serialize import deser_string

from .bip0327 import key_agg, cbytes


def parse_stream_to_map(f: BufferedReader) -> Mapping[bytes, bytes]:
result = {}
Expand All @@ -39,6 +41,53 @@ def parse_stream_to_map(f: BufferedReader) -> Mapping[bytes, bytes]:
return result


def aggr_xpub(pubkeys: List[bytes], chain: Chain) -> str:
BIP_MUSIG_CHAINCODE = bytes.fromhex(
"868087ca02a6f974c4598924c36b57762d32cb45717167e300622c7167e38965")
ctx = key_agg(pubkeys)
compressed_pubkey = cbytes(ctx.Q)

# Serialize according to BIP-32
if chain == Chain.MAIN:
version = 0x0488B21E
else:
version = 0x043587CF

return base58.encode_check(b''.join([
version.to_bytes(4, byteorder='big'),
b'\x00', # depth
b'\x00\x00\x00\x00', # parent fingerprint
b'\x00\x00\x00\x00', # child number
BIP_MUSIG_CHAINCODE,
compressed_pubkey
]))


# Given a valid descriptor, replaces each musig() (if any) with the
# corresponding synthetic xpub/tpub.
def replace_musigs(desc: str, chain: Chain) -> str:
while True:
musig_start = desc.find("musig(")
if musig_start == -1:
break
musig_end = desc.find(")", musig_start)
if musig_end == -1:
raise ValueError("Invalid descriptor template")

key_and_origs = desc[musig_start+6:musig_end].split(",")
pubkeys = []
for key_orig in key_and_origs:
orig_end = key_orig.find("]")
xpub = key_orig if orig_end == -1 else key_orig[orig_end+1:]
pubkeys.append(base58.decode_check(xpub)[-33:])

# replace with the aggregate xpub
desc = desc[:musig_start] + \
aggr_xpub(pubkeys, chain) + desc[musig_end+1:]

return desc


def _make_partial_signature(pubkey_augm: bytes, signature: bytes) -> PartialSignature:
if len(pubkey_augm) == 64:
# tapscript spend: pubkey_augm is the concatenation of:
Expand Down Expand Up @@ -273,6 +322,11 @@ def sign_message(self, message: Union[str, bytes], bip32_path: str) -> str:

def _derive_address_for_policy(self, wallet: WalletPolicy, change: bool, address_index: int) -> Optional[str]:
desc_str = wallet.get_descriptor(change)

# Since embit does not support musig() in descriptors, we replace each
# occurrence with the corresponding aggregated xpub
desc_str = replace_musigs(desc_str, self.chain)

try:
desc = Descriptor.from_string(desc_str)

Expand Down

0 comments on commit d789442

Please sign in to comment.