Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import tempfile
import logging
from collections import namedtuple
from datetime import datetime, timedelta

from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa

LOGGER = logging.getLogger(__name__)
LOGGER.addHandler(logging.NullHandler())


class CertAttributes(namedtuple('CertAttributes', [
'country',
'state',
'city',
'org',
'org_name',
'common'
])):
def to_x509(self):
"""Convert this namedtuple to an x509 named object."""
return x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, self.country),
x509.NameAttribute(
NameOID.STATE_OR_PROVINCE_NAME, self.state),
x509.NameAttribute(NameOID.LOCALITY_NAME, self.city),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, self.org),
x509.NameAttribute(
NameOID.ORGANIZATIONAL_UNIT_NAME, self.org_name),
x509.NameAttribute(NameOID.COMMON_NAME, self.common),
])


class Cert():
"""Cert class helps in the generation of SSL certificates."""

def __init__(self, attributes=None, private_key=None):

# Accept the attributes that define the certs.
if attributes:
self.attributes = attributes

if private_key:
self.private_key = private_key

def _private_key():
"""Generates a private key."""
return rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
backend=default_backend()
)

def _certificate(self, certattrs, ca_cert=False, self_sign=True):
"""Generate a certificate."""
if not certattrs:
certattrs = self.attributes

if self_sign:
key = self._private_key()
else:
key = self.private_key

# Create certificate and sign it.
subject = issuer = certattrs.to_x509()
cert = x509.CertificateBuilder()\
.subject_name(subject)\
.issuer_name(issuer)\
.public_key(key.public_key())\
.serial_number(x509.random_serial_number())\
.not_valid_before(datetime.utcnow())\
.not_valid_after(datetime.utcnow() + timedelta(days=3650))\
.add_extension(
x509.SubjectAlternativeName([x509.DNSName(certattrs.common)]),
critical=False
)

if ca_cert:
cert = cert.add_extension(
x509.BasicConstraints(
ca=True,
path_length=1
),
critical=True
)
cert = cert.sign(key, hashes.SHA256(), default_backend())
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Signing the certificate with it's own key should probably be an option (self_sign=True) since it is not always desirable.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I made the proper modifications to meet this need. LMK what you think of my implementation.


return key, cert

def gen_self_signed_cert(self):
"""Generates a self signed cert."""

return self._certificate(self.attributes, self_sign=True)

def gen_ca_cert(self):
"""
Generate a key + certificate authority certificate.

Creates a CA cert for the proxy to use when bumping SSL. This certificate
is also used to sign an SSL server certificate for the proxy web interface.
This leverages the trust that must be in place for the ca certificate.
"""
return self._certificate(self.attributes, ca_cert=True)

def gen_csr(self):
"""
Generates a CSR that is based on the passed in key.

Returns a private key and a CSR.
"""
certattrs = self.attributes

key = self.private_key()

# Generate a CSR.
csr = x509.CertificateSigningRequestBuilder().subject_name(
certattrs.to_x509()
).add_extension(
x509.SubjectAlternativeName([
x509.DNSName(certattrs.common),
]),
critical=False,
)

# Sign CSR with our private key and return it.
return key, csr.sign(key, hashes.SHA256(), default_backend())

# DISK SAVING FUNCTIONS

def save_key(fobj, key, format='key'):
"""Save a key file."""
if format != 'key':
raise NotImplementedError(
'%s format unsupported, use .key' % format)
LOGGER.debug('Writing PEM encoded key to %s', fobj.name)
fobj.write(
key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
)
)
fobj.write(b'\n')

def save_cert(fobj, cert, format='pem'):
"""Save a cert."""
if format != 'pem':
raise NotImplementedError(
'%s format unsupported, use .pem' % format)
LOGGER.debug('Writing PEM encoded certificate to %s', fobj.name)
fobj.write(cert.public_bytes(serialization.Encoding.PEM))
fobj.write(b'\n')

def save_csr(fobj, csr, format='pem'):
"""Save a csr."""
if format != 'pem':
raise NotImplementedError(
'%s format unsupported, use .pem' % format)
LOGGER.debug('Writing PEM encoded CSR to %s', fobj.name)
fobj.write(csr.public_bytes(serialization.Encoding.PEM))
fobj.write(b'\n')