# update-version.py
import argparse
import atexit
import contextlib
import os
import pipes
import re
import shlex
import subprocess
import sys
import tempfile


# Set to True by main() when --verbose is given; read by verbose().
_VERBOSE = False

# Program name used to prefix diagnostic messages. main() replaces it with
# sys.argv[0]; this default keeps verbose() and fatal() from failing with a
# NameError if they are called before main() has run (e.g. when imported).
_PROGNAME = sys.argv[0] if sys.argv and sys.argv[0] else 'update-version.py'

# Dotted version number: MAJOR.MINOR, optionally followed by .PATCH, which may
# itself be followed by .BUILD. Components are decimal without leading zeros.
# The pattern is not anchored at the end, so .match() accepts trailing text.
_VERSION_REGEXP = re.compile(
    r"""
        (?P<major>0|[1-9]\d*)
        \.
        (?P<minor>0|[1-9]\d*)
        (?:
            \.
            (?P<patch>0|[1-9]\d*)
            (?:
                \.
                (?P<build>0|[1-9]\d*)
            )?
        )?
    """,
    re.VERBOSE)

# Version component names, ordered from most to least significant. main()
# compares positions in this tuple to check that the incremented field is
# part of the chosen output format.
_VERSION_FIELDS = ('major', 'minor', 'patch', 'build')

# Maps each field name to a str.format template rendering the version up to
# and including that field, e.g. 'patch' -> '{major}.{minor}.{patch}'.
_VERSION_FORMATS = {
    field: '.'.join(f'{{{name}}}' for name in _VERSION_FIELDS[:position + 1])
    for position, field in enumerate(_VERSION_FIELDS)
    }


class Version:
    """Class representing a version number.

    A version has four integer components (major, minor, patch, build) that
    are compared lexicographically in that order. patch and build default to
    0, so Version(1, 2) equals Version(1, 2, 0, 0).
    """

    def __init__(self, major, minor, patch=0, build=0):
        self.major = major
        self.minor = minor
        self.patch = patch
        self.build = build

    def increment(self, what, amount):
        """Increase component *what* by *amount* in place and return self.

        Every component less significant than *what* is reset to 0. When
        *amount* is not positive the version is returned unchanged (and *what*
        is not validated).

        Raises:
            AttributeError: if *amount* is positive and *what* is not one of
                'major', 'minor', 'patch' or 'build'.
        """
        if amount <= 0:
            return self
        if what == 'major':
            self.major += amount
            self.minor = 0
            self.patch = 0
            self.build = 0
        elif what == 'minor':
            self.minor += amount
            self.patch = 0
            self.build = 0
        elif what == 'patch':
            self.patch += amount
            self.build = 0
        elif what == 'build':
            self.build += amount
        else:
            raise AttributeError(what)
        return self

    def _tuple(self):
        """Return the components, most significant first, for comparisons."""
        return (self.major, self.minor, self.patch, self.build)

    def __eq__(self, other):
        # Compare by value. Without this, '==' falls back to identity and two
        # separately parsed tags for the same version never compare equal.
        if not isinstance(other, Version):
            return NotImplemented
        return self._tuple() == other._tuple()

    def __hash__(self):
        # Must agree with __eq__. increment() mutates the instance, so do not
        # increment a Version while it is used as a dict key or set member.
        return hash(self._tuple())

    def __lt__(self, other):
        if not isinstance(other, Version):
            return NotImplemented
        return self._tuple() < other._tuple()

    def __repr__(self):
        return (f'{type(self).__name__}('
                f'{self.major}, {self.minor}, {self.patch}, {self.build})')

    def format(self, fmt):
        """Render the version with the template _VERSION_FORMATS[fmt].

        *fmt* is a field name ('major', 'minor', 'patch' or 'build'); the
        result contains the components up to and including that field.
        """
        return _VERSION_FORMATS[fmt].format(major=self.major,
                                            minor=self.minor,
                                            patch=self.patch,
                                            build=self.build)

    def __str__(self):
        """Return the full four-component form, e.g. '1.2.3.4'."""
        return self.format('build')


def parse_version(version):
    """Parse a version number into its components.

    Only the start of *version* has to match _VERSION_REGEXP; any text after
    the number is ignored (so '1.2.3-rc1' parses as 1.2.3). Components that
    are absent become 0.

    Raises:
        ValueError: if *version* does not start with a valid version number.
    """
    match = _VERSION_REGEXP.match(version)
    if not match:
        raise ValueError(f'invalid version format: {version}')
    components = {name: int(match.group(name) or 0)
                  for name in ('major', 'minor', 'patch', 'build')}
    return Version(**components)


def ssh_add_host_keys(host):
    """Append the public host keys of *host* to ~/.ssh/known_hosts.

    If known_hosts already exists it is left untouched and only a verbose
    message is emitted. ssh-keyscan's stderr is discarded and its exit status
    is not checked; its output is appended only when non-empty.
    """
    known_hosts_path = os.path.expanduser('~/.ssh/known_hosts')
    if not os.path.exists(known_hosts_path):
        verbose(f'ssh-keyscan {host} >> {known_hosts_path}')
        scanned_keys = subprocess.run(['ssh-keyscan', host],
                                      stdout=subprocess.PIPE,
                                      stderr=subprocess.DEVNULL,
                                      text=True).stdout
        if scanned_keys:
            with open(known_hosts_path, 'a') as known_hosts:
                known_hosts.write(scanned_keys)
        return
    verbose(f'{known_hosts_path} already exists: not adding {host}')


def ssh_create_directory():
    """Create ~/.ssh with mode 0700 unless it already exists."""
    ssh_directory = os.path.expanduser('~/.ssh')
    if not os.path.exists(ssh_directory):
        os.mkdir(ssh_directory)
        # Set the mode explicitly so the result does not depend on the umask.
        os.chmod(ssh_directory, 0o700)
    else:
        verbose(f'{ssh_directory} already exists: not creating')


def ssh_kill_agent():
    """Kill the running ssh agent and drop its environment variables.

    ssh_start_agent() registers this with atexit. ``ssh-agent -k`` finds the
    agent through SSH_AGENT_PID, so the variables are removed only afterwards.
    """
    verbose('ssh-agent -k')
    subprocess.run(['ssh-agent', '-k'])
    # Remove the variables rather than setting them to '': an empty value still
    # counts as "set" for anything that checks membership in os.environ.
    for name in ('SSH_AUTH_SOCK', 'SSH_AGENT_PID'):
        os.environ.pop(name, None)
        

def ssh_start_agent():
    """Start an ssh agent and export its socket and PID into os.environ.

    Parses the shell assignments printed by ``ssh-agent -s`` and registers
    ssh_kill_agent() to stop the agent at interpreter exit. Calls fatal() if
    the output cannot be parsed.
    """
    verbose('ssh-agent -s')
    agent_output = subprocess.run(['ssh-agent', '-s'], stdout=subprocess.PIPE, text=True).stdout
    parsed = re.match(r'SSH_AUTH_SOCK=(?P<socket>[^;]+).*SSH_AGENT_PID=(?P<pid>\d+)',
                      agent_output,
                      re.MULTILINE | re.DOTALL)
    if parsed is None:
        fatal(f'unable to parse ssh-agent output:\n{agent_output}')
    os.environ.update(SSH_AUTH_SOCK=parsed.group('socket'),
                      SSH_AGENT_PID=parsed.group('pid'))
    atexit.register(ssh_kill_agent)


def ssh_add_key(ssh_key):
    """Add a private key, given as text, to the running ssh agent.

    The key is passed to ``ssh-add -`` on stdin. Calls fatal() with ssh-add's
    combined stdout/stderr if it exits with a non-zero status.
    """
    # Key files (and keys pasted into CI variables) often lack a final newline,
    # and OpenSSH may then refuse to load the key ("error in libcrypto").
    if ssh_key and not ssh_key.endswith('\n'):
        ssh_key += '\n'
    verbose('ssh-add -')
    process = subprocess.run(['ssh-add', '-'],
                             input=ssh_key,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.STDOUT,
                             text=True)
    if process.returncode != 0:
        fatal(f'unable to add ssh key:\n{process.stdout}')


def git(*args, check=False):
    """Run a git command and return its standard output as text.

    Args:
        *args: arguments passed to git directly (no shell is involved).
        check: when true, raise subprocess.CalledProcessError if git exits
            with a non-zero status. By default the exit status is ignored and
            whatever git wrote to stdout (possibly '') is returned.
    """
    # shlex.quote is what pipes.quote re-exported; the pipes module is
    # deprecated (PEP 594) and no longer exists in Python 3.13.
    verbose(f'git {" ".join(shlex.quote(arg) for arg in args)}')
    return subprocess.run(['git', *args],
                          encoding='utf-8',
                          stdout=subprocess.PIPE,
                          check=check).stdout


def parse_tag_list(tag_list):
    """Parse newline-separated tag names and return their versions, highest first.

    Tags that parse_version() rejects with ValueError are skipped.
    """
    versions = []
    for line in tag_list.split('\n'):
        try:
            versions.append(parse_version(line))
        except ValueError:
            continue
    versions.sort(reverse=True)
    return versions


def get_commit_versions(commit):
    """Return the versions of the tags pointing directly at *commit*, highest first."""
    tags_on_commit = git('tag', '--points-at', commit)
    return parse_tag_list(tags_on_commit)


def get_all_versions(commit):
    """Return the versions of all tags reachable from *commit*, highest first."""
    reachable_tags = git('tag', '--merged', commit)
    return parse_tag_list(reachable_tags)


def change_origin_to_ssh():
    """Set origin's push URL to the SSH form of CI_REPOSITORY_URL.

    Rewrites e.g. 'https://user:token@host/group/project.git' to
    'git@host:group/project.git'. Does nothing when CI_REPOSITORY_URL is unset
    or empty, or when it already starts with 'git'.
    """
    repository_url = os.environ.get('CI_REPOSITORY_URL')
    if not repository_url or repository_url.startswith('git'):
        return
    push_url = re.sub(r'.+@([^/]+)/', r'git@\1:', repository_url)
    verbose(f'repository url: {push_url}')
    git('remote', 'set-url', '--push', 'origin', push_url)
    

def tag_repository(commit, tag):
    """Create *tag* on *commit* and push it to origin.

    origin's push URL is first switched to SSH by change_origin_to_ssh().
    git's exit status is not checked, so a failed tag or push goes unreported.
    """
    change_origin_to_ssh()
    for git_args in (('tag', tag, commit), ('push', 'origin', tag)):
        git(*git_args)


def verbose(message):
    """Print *message* on stdout, prefixed with the program name, in verbose mode."""
    if not _VERBOSE:
        return
    print(f'{_PROGNAME}: {message}')


def fatal(message):
    """Print *message* on stderr, prefixed with the program name, and exit with status 1."""
    error_line = f'{_PROGNAME}: {message}'
    print(error_line, file=sys.stderr)
    sys.exit(1)


def main():
    """Compute the next version, tag the repository with it and report it.

    Repository state comes from the CI variables CI_COMMIT_BRANCH,
    CI_COMMIT_SHA, CI_PIPELINE_SOURCE and CI_REPOSITORY_URL, falling back to
    the current branch, HEAD, 'push' and origin's URL when they are unset. A
    new version is created only for 'push' pipelines, and only when the commit
    does not already carry the highest version tag reachable from it. Outside
    --dry-run the resulting version is written to --output, or printed on
    stdout when no output file is given.

    Returns:
        0. Unrecoverable errors exit with status 1 through fatal().
    """
    global _PROGNAME
    global _VERBOSE

    # The first positional argument of ArgumentParser is `prog`, so the text
    # must be passed as `description` to keep the usage line correct.
    parser = argparse.ArgumentParser(
        description='automatically update version in git repo')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='verbose output')
    parser.add_argument('--output', '-o', metavar='PATH',
                        help='output version number to this file')
    parser.add_argument('--ssh-key', '-k', metavar='PATH',
                        help='file containing ssh key')
    parser.add_argument('--format', '-f', choices=list(_VERSION_FORMATS.keys()), default='build',
                        help="version number format")
    parser.add_argument('--increment', '-i', choices=list(_VERSION_FIELDS),
                        default='build',
                        help="what to increment")
    parser.add_argument('--dry-run', '-n', action='store_true',
                        help="don't change anything")
    opts = parser.parse_args()

    _PROGNAME = sys.argv[0]
    if opts.verbose:
        _VERBOSE = True

    # The incremented field must be part of the formatted version; otherwise
    # the new version would render identically to the old one.
    if _VERSION_FIELDS.index(opts.increment) > _VERSION_FIELDS.index(opts.format):
        fatal('incremented field not included in formatted version, exiting')

    ssh_key = None
    if opts.ssh_key:
        with open(opts.ssh_key, 'r') as ssh_key_file:
            ssh_key = ssh_key_file.read()

    increment = 1

    # Current branch (used only in log messages)
    commit_branch = os.environ.get('CI_COMMIT_BRANCH')
    if not commit_branch:
        commit_branch = git('branch', '--show-current').strip()
        verbose(f'CI_COMMIT_BRANCH not set, assuming {commit_branch}')

    # Commit to inspect and tag
    commit_ref = os.environ.get('CI_COMMIT_SHA')
    if not commit_ref:
        commit_ref = 'HEAD'
        verbose('CI_COMMIT_SHA not set, assuming HEAD')

    # What triggered the pipeline; only 'push' creates a new version
    pipeline_source = os.environ.get('CI_PIPELINE_SOURCE')
    if not pipeline_source:
        pipeline_source = 'push'
        verbose('CI_PIPELINE_SOURCE not set, assuming push')

    # Remote URL, used to find the host whose ssh keys must be trusted
    remote_url = os.environ.get('CI_REPOSITORY_URL')
    if not remote_url:
        remote_url = git('remote', 'get-url', 'origin').strip()
        verbose(f'CI_REPOSITORY_URL not set, assuming {remote_url}')

    if re.match(r'^http', remote_url, re.I):
        # scheme://[user[:password]@]host[:port]/path -> host
        repository_host = re.sub(r'^[^:]+://(?:[^@/]+@)?([^/:]+).*', r'\1', remote_url)
    else:
        # user@host:path (scp-like syntax) -> host
        repository_host = re.sub(r'.+@([^:]+):.*', r'\1', remote_url)

    verbose(f'commit ref:      {commit_ref}')
    verbose(f'commit branch:   {commit_branch}')
    verbose(f'pipeline source: {pipeline_source}')
    verbose(f'remote url:      {remote_url}')
    verbose(f'repository host: {repository_host}')

    # Check the pipeline source
    if pipeline_source != 'push':
        verbose('pipeline source is not push: not incrementing version')
        increment = 0

    commit_versions = get_commit_versions(commit_ref)
    all_versions = get_all_versions(commit_ref)

    verbose(f'current version tag: {", ".join(str(version) for version in commit_versions)}')
    verbose(f'recent versions:     {", ".join(str(version) for version in all_versions[:4])}')

    if not all_versions:
        fatal(f'no version tags reachable from {commit_ref}: create an initial version tag first')

    # Don't increment if the commit already carries the highest reachable
    # version. Compare the full rendered versions so the check is by value,
    # not by object identity.
    if commit_versions and str(commit_versions[0]) == str(all_versions[0]):
        verbose('branch is already tagged: not incrementing')
        increment = 0

    verbose(f'increment: {opts.increment} by {increment}')
    new_version = all_versions[0].increment(opts.increment, increment).format(opts.format)

    if not opts.dry_run:
        # Prepare ssh only when changes are allowed: --dry-run must not create
        # ~/.ssh or append to known_hosts.
        ssh_create_directory()
        ssh_add_host_keys(repository_host)
        if increment:
            verbose(f'tagging {commit_branch}: {new_version}')
            if ssh_key:
                ssh_start_agent()
                ssh_add_key(ssh_key)
            tag_repository(commit_ref, new_version)
        if opts.output:
            with open(opts.output, 'w') as output_file:
                print(new_version, file=output_file)
        else:
            print(new_version)
    else:
        if increment:
            verbose(f'git tag {new_version} {commit_ref}')
            verbose(f'git push origin {new_version}')
        if opts.output:
            verbose(f'echo "{new_version}" > {opts.output}')
        # verbose() also writes to stdout; presumably the bare version is
        # printed only in quiet mode so stdout stays machine-readable.
        if not opts.verbose:
            print(new_version)

    return 0


if __name__ == '__main__':
    # Exit with main()'s return value as the process status.
    raise SystemExit(main())