2018-03-14 16:37:46 +07:00
|
|
|
import collections
import collections.abc
import functools
import itertools
import logging

import requests

from . import exceptions
|
|
|
|
|
|
|
|
# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)
|
|
|
|
class GH(object):
    """Small client for the GitHub REST API, bound to a single repository.

    Every request is sent to ``https://api.github.com/repos/<repo>/<path>``
    through one shared ``requests`` session which carries the token in its
    ``Authorization`` header. Instances are callable: ``gh(method, path)``
    performs a raw request, the named methods wrap specific endpoints.
    """

    def __init__(self, token, repo):
        """
        :param str token: API token, sent as ``Authorization: token <token>``
        :param str repo: repository identifier inserted after ``/repos/`` in
                         every URL (presumably ``owner/name`` -- confirm
                         against callers)
        """
        self._url = 'https://api.github.com'
        self._repo = repo
        session = self._session = requests.Session()
        session.headers['Authorization'] = 'token {}'.format(token)

    def __call__(self, method, path, params=None, json=None, check=True):
        """Perform a request on ``path`` inside the repository.

        :param str method: HTTP method, handed to the session as-is
        :param str path: path relative to ``/repos/<repo>/``
        :param params: query-string parameters, handed to the session as-is
        :param json: JSON payload, handed to the session as-is
        :param check: if falsy the response is returned whatever its status;
                      if a mapping, a response whose status code is a key of
                      the mapping raises the associated exception,
                      instantiated with the raw response body; in every
                      truthy case the response then goes through
                      ``raise_for_status()``
        :type check: bool | dict[int:Exception]
        :returns: the response object
        :raises requests.HTTPError: via ``raise_for_status()`` when ``check``
                                    is truthy
        """
        r = self._session.request(
            method,
            '{}/repos/{}/{}'.format(self._url, self._repo, path),
            params=params,
            json=json
        )
        if check:
            # the ABCs only live in collections.abc: the aliases which used to
            # be exposed by ``collections`` itself were removed in Python 3.10
            if isinstance(check, collections.abc.Mapping):
                exc = check.get(r.status_code)
                if exc:
                    raise exc(r.content)
            r.raise_for_status()
        return r

    def head(self, branch):
        """Return the sha of the commit at the head of ``branch``.

        :raises AssertionError: if the ref returned is not exactly
                                ``refs/heads/<branch>`` or does not point to
                                a commit
        """
        d = self('get', 'git/refs/heads/{}'.format(branch)).json()

        assert d['ref'] == 'refs/heads/{}'.format(branch)
        assert d['object']['type'] == 'commit'
        _logger.debug("head(%s, %s) -> %s", self._repo, branch, d['object']['sha'])
        return d['object']['sha']

    def commit(self, sha):
        """Fetch the git commit object ``sha`` and return its decoded JSON."""
        c = self('GET', 'git/commits/{}'.format(sha)).json()
        _logger.debug('commit(%s, %s) -> %s', self._repo, sha, shorten(c['message']))
        return c

    def comment(self, pr, message):
        """Post ``message`` as a new comment on issue / PR number ``pr``."""
        self('POST', 'issues/{}/comments'.format(pr), json={'body': message})
        _logger.debug('comment(%s, %s, %s)', self._repo, pr, shorten(message))

    def close(self, pr, message):
        """Comment on PR number ``pr`` with ``message``, then close it."""
        self.comment(pr, message)
        self('PATCH', 'pulls/{}'.format(pr), json={'state': 'closed'})

    def change_tags(self, pr, from_, to_):
        """Move the labels of issue / PR ``pr`` from ``from_`` to ``to_``.

        Labels only in ``from_`` are deleted one by one, labels only in
        ``to_`` are added in a single request.

        :param pr: issue / PR number
        :param from_: current labels; must support ``-`` (e.g. a set)
        :param to_: wanted labels; must support ``-`` (e.g. a set)
        """
        to_add, to_remove = to_ - from_, from_ - to_
        for t in to_remove:
            r = self('DELETE', 'issues/{}/labels/{}'.format(pr, t), check=False)
            # successful deletion or attempt to delete a tag which isn't there
            # is fine, otherwise trigger an error
            if r.status_code not in (200, 404):
                r.raise_for_status()

        if to_add:
            self('POST', 'issues/{}/labels'.format(pr), json=list(to_add))

        _logger.debug('change_tags(%s, %s, remove=%s, add=%s)', self._repo, pr, to_remove, to_add)

    def fast_forward(self, branch, sha):
        """Update ``branch`` to ``sha`` without forcing the update.

        :raises exceptions.FastForwardError: if the request fails with an
                                             HTTP error
        """
        try:
            self('patch', 'git/refs/heads/{}'.format(branch), json={'sha': sha})
            _logger.debug('fast_forward(%s, %s, %s) -> OK', self._repo, branch, sha)
        except requests.HTTPError:
            _logger.debug('fast_forward(%s, %s, %s) -> ERROR', self._repo, branch, sha, exc_info=True)
            raise exceptions.FastForwardError()

    def set_ref(self, branch, sha):
        """Make ``branch`` point to ``sha``, creating the branch if needed.

        :raises AssertionError: if neither the forced update nor the fallback
                                creation succeeds; the message holds the
                                status code and JSON body of the last
                                response
        """
        # force-update ref
        r = self('patch', 'git/refs/heads/{}'.format(branch), json={
            'sha': sha,
            'force': True,
        }, check=False)
        if r.status_code == 200:
            _logger.debug('set_ref(update, %s, %s, %s) -> OK', self._repo, branch, sha)
            return

        # 422 makes no sense but that's what github returns, leaving 404 just
        # in case
        if r.status_code in (404, 422):
            # fallback: create ref
            r = self('post', 'git/refs', json={
                'ref': 'refs/heads/{}'.format(branch),
                'sha': sha,
            }, check=False)
            if r.status_code == 201:
                _logger.debug('set_ref(create, %s, %s, %s) -> OK', self._repo, branch, sha)
                return
        raise AssertionError("{}: {}".format(r.status_code, r.json()))

    def merge(self, sha, dest, message):
        """Merge ``sha`` into the branch ``dest`` with ``message`` as commit
        message.

        :returns: the ``commit`` entry of the response, as a dict extended
                  with the ``sha`` of the merge commit
        :raises exceptions.MergeError: if the API answers with a 409
        """
        r = self('post', 'merges', json={
            'base': dest,
            'head': sha,
            'commit_message': message,
        }, check={409: exceptions.MergeError})
        r = r.json()
        _logger.debug("merge(%s, %s, %s) -> %s", self._repo, dest, shorten(message), r['sha'])
        return dict(r['commit'], sha=r['sha'])

    def rebase(self, pr, dest, reset=False, commits=None):
        """ Rebase pr's commits on top of dest, updates dest unless ``reset``
        is set.

        The commits are first merged into ``dest`` one after the other, only
        to learn which tree each of them results in, and ``dest`` is moved
        back to its original head. New commits are then created on top of one
        another from these trees, reusing the message, author and committer
        of the originals.

        :param pr: PR number; used to fetch the commits when ``commits`` is
                   not provided, and in the temporary merge messages
        :param dest: name of the branch to rebase on
        :param bool reset: if set, ``dest`` is left at its original head
                           instead of being moved to the rebased head
        :param commits: commits to rebase, in the format returned by
                        :meth:`commits`; a ``new_tree`` key is added to each
                        of them
        :raises AssertionError: if there is no commit to rebase, or if a
                                commit does not have exactly one parent
        :raises exceptions.MergeError: if a merge or a commit creation is
                                       answered with a 409

        Returns the hash of the rebased head.
        """
        original_head = self.head(dest)
        if commits is None:
            commits = self.commits(pr)

        assert commits, "can't rebase a PR with no commits"
        for c in commits:
            assert len(c['parents']) == 1, "can't rebase commits with more than one parent"
            tmp_msg = 'temp rebasing PR %s (%s)' % (pr, c['sha'])
            c['new_tree'] = self.merge(c['sha'], dest, tmp_msg)['tree']['sha']
        # the temporary merges were only needed for their trees, drop them
        self.set_ref(dest, original_head)

        prev = original_head
        for c in commits:
            copy = self('post', 'git/commits', json={
                'message': c['commit']['message'],
                'tree': c['new_tree'],
                'parents': [prev],
                'author': c['commit']['author'],
                'committer': c['commit']['committer'],
            }, check={409: exceptions.MergeError}).json()
            prev = copy['sha']

        if reset:
            self.set_ref(dest, original_head)
        else:
            self.set_ref(dest, prev)

        _logger.debug('rebase(%s, %s, %s, reset=%s, commits=%s) -> %s',
                      self._repo, pr, dest, reset, commits and len(commits),
                      prev)
        # prev is updated after each copy so it's the rebased PR head
        return prev

    def _paginate(self, path):
        """Yield the items of every page of the listing at ``path``, following
        pages for as long as the response advertises a ``next`` link.
        """
        for page in itertools.count(1):
            r = self('get', path, params={'page': page})
            yield from r.json()
            if not r.links.get('next'):
                return

    # fetch various bits of issues / prs to load them
    def pr(self, number):
        """Return the ``(issue, pull request)`` pair of JSON documents for
        PR ``number``.
        """
        return (
            self('get', 'issues/{}'.format(number)).json(),
            self('get', 'pulls/{}'.format(number)).json()
        )

    def comments(self, number):
        """Lazily iterate over all the comments of issue / PR ``number``."""
        yield from self._paginate('issues/{}/comments'.format(number))

    def reviews(self, number):
        """Lazily iterate over all the reviews of PR ``number``."""
        yield from self._paginate('pulls/{}/reviews'.format(number))

    def commits(self, pr):
        """ Returns a PR's commits oldest first (that's what GH does &
        is what we want)

        :raises AssertionError: if the PR has more than ``PR_COMMITS_MAX``
                                commits (the listing has a second page)
        """
        r = self('get', 'pulls/{}/commits'.format(pr), params={'per_page': PR_COMMITS_MAX})
        assert not r.links.get('next'), "more than {} commits".format(PR_COMMITS_MAX)
        return r.json()

    def statuses(self, h):
        """Return the statuses of the combined status of commit ``h``.

        :returns: a list of dicts, one per status, each extended with a
                  ``sha`` key holding the commit's sha (a ``sha`` key already
                  present in a status takes precedence)
        """
        r = self('get', 'commits/{}/status'.format(h)).json()
        return [{
            'sha': r['sha'],
            **s,
        } for s in r['statuses']]
|
2018-08-28 20:42:28 +07:00
|
|
|
|
|
|
|
# Page size requested by GH.commits() when listing a PR's commits; a PR whose
# listing has a further page trips the assertion there instead of being
# paginated.
PR_COMMITS_MAX = 50
|
2018-09-20 14:25:13 +07:00
|
|
|
def shorten(s):
    """Return a one-line summary of ``s``, at most 50 characters long.

    Only the first line of ``s`` is kept. If that line is longer than 50
    characters, it is cut down to its first 47 characters followed by
    ``'...'``.

    :param s: text to summarise; a falsy value (``None``, ``''``) is returned
              unchanged
    :returns: the summary
    """
    if not s:
        return s

    line1 = s.split('\n', 1)[0]
    # a line of exactly 50 characters already fits: truncating it would still
    # produce 50 characters while throwing away the end of the line
    if len(line1) <= 50:
        return line1

    return line1[:47] + '...'
|