704 lines
22 KiB
Python
Executable File
704 lines
22 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Much of the logging code here was forked from https://github.com/ezyang/ghstack
|
|
# Copyright (c) Edward Z. Yang <ezyang@mit.edu>
|
|
"""Checks out the nightly development version of PyTorch and installs pre-built
|
|
binaries into the repo.
|
|
|
|
You can use this script to check out a new nightly branch with the following::
|
|
|
|
$ ./tools/nightly.py checkout -b my-nightly-branch
|
|
$ conda activate pytorch-deps
|
|
|
|
Or if you would like to re-use an existing conda environment, you can pass in
|
|
the regular environment parameters (--name or --prefix)::
|
|
|
|
$ ./tools/nightly.py checkout -b my-nightly-branch -n my-env
|
|
$ conda activate my-env
|
|
|
|
You can also use this tool to pull the nightly commits into the current branch as
|
|
well. This can be done with::
|
|
|
|
$ ./tools/nightly.py pull -n my-env
|
|
$ conda activate my-env
|
|
|
|
Pulling will reinstall the conda dependencies as well as the nightly binaries into
|
|
the repo directory.
|
|
"""
|
|
import contextlib
|
|
import datetime
|
|
import functools
|
|
import glob
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
import uuid
|
|
from argparse import ArgumentParser
|
|
from ast import literal_eval
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
cast,
|
|
Dict,
|
|
Generator,
|
|
Iterable,
|
|
Iterator,
|
|
List,
|
|
Optional,
|
|
Sequence,
|
|
Set,
|
|
Tuple,
|
|
TypeVar,
|
|
)
|
|
|
|
# Module-wide logger; main() assigns it once logging_manager() is active.
LOGGER: Optional[logging.Logger] = None
# Template turning a conda-solve "LINK" record into a package URL.
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
# strftime pattern used for the timestamp part of per-run log directory names.
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
# A full 40-hex-digit git SHA-1 (captured as group 1).
SHA1_RE = re.compile("([0-9a-fA-F]{40})")
# Credentials embedded in a URL, e.g. the "user:pass" in https://user:pass@host.
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
# Names of per-run log directories: "<DATETIME_FORMAT timestamp>_<uuid>".
LOG_DIRNAME_RE = re.compile(
    r"(\d{4}-\d\d-\d\d_\d\dh\d\dm\d\ds)_[0-9a-f]{8}-(?:[0-9a-f]{4}-){3}[0-9a-f]{12}"
)
# Package specs handed to conda when solving the environment.
SPECS_TO_INSTALL = (
    "pytorch",
    "mypy",
    "pytest",
    "hypothesis",
    "ipython",
    "sphinx",
)
|
|
|
|
|
|
class Formatter(logging.Formatter):
    """Formatter that prints INFO/DEBUG messages bare and scrubs secrets.

    Every formatted line is passed through ``_filter``, which masks
    ``user:password`` credentials inside URLs and any string registered
    with :meth:`redact`.
    """

    # secret text -> replacement text, applied to every formatted line
    redactions: Dict[str, str]

    def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None):
        super().__init__(fmt, datefmt)
        self.redactions = {}

    def _filter(self, s: str) -> str:
        """Return *s* with URL credentials and registered redactions masked."""
        masked = USERNAME_PASSWORD_RE.sub(r"://<USERNAME>:<PASSWORD>@", s)
        for secret, substitute in self.redactions.items():
            masked = masked.replace(secret, substitute)
        return masked

    def formatMessage(self, record: logging.LogRecord) -> str:
        # INFO and DEBUG are shown verbatim; more severe levels use the
        # configured format string (e.g. "WARNING: ...").
        if record.levelno in (logging.INFO, logging.DEBUG):
            return record.getMessage()
        return super().formatMessage(record)

    def format(self, record: logging.LogRecord) -> str:
        return self._filter(super().format(record))

    def redact(self, needle: str, replace: str = "<REDACTED>") -> None:
        """Register *needle* to be replaced by *replace* in future output.

        Only affects lines formatted after this call; anything already
        emitted is not rewritten, so register secrets as early as possible.
        An empty *needle* is ignored: replacing "" would interleave the
        replacement between every character.
        """
        if needle != "":
            self.redactions[needle] = replace
|
|
|
|
|
|
@functools.lru_cache
def logging_base_dir() -> str:
    """Return ``<cwd>/nightly/log``, creating the directory if necessary.

    The result is cached, so the path is resolved (and created) only once
    per process.
    """
    log_root = os.path.join(os.getcwd(), "nightly", "log")
    os.makedirs(log_root, exist_ok=True)
    return log_root
|
|
|
|
|
|
@functools.lru_cache
def logging_run_dir() -> str:
    """Return this run's log directory, creating it on first use.

    The directory sits under ``logging_base_dir()`` and is named
    ``<timestamp>_<uuid1>``. Caching keeps the same directory for the whole
    process.
    """
    base = logging_base_dir()
    run_name = f"{datetime.datetime.now().strftime(DATETIME_FORMAT)}_{uuid.uuid1()}"
    run_dir = os.path.join(base, run_name)
    os.makedirs(run_dir, exist_ok=True)
    return run_dir
|
|
|
|
|
|
@functools.lru_cache
def logging_record_argv() -> None:
    """Save ``sys.argv`` (joined by ``subprocess.list2cmdline``) to ``argv``.

    The file goes into the run's log directory. Caching turns every call
    after the first into a no-op.
    """
    command_line = subprocess.list2cmdline(sys.argv)
    argv_path = os.path.join(logging_run_dir(), "argv")
    with open(argv_path, "w") as argv_file:
        argv_file.write(command_line)
|
|
|
|
|
|
def logging_record_exception(e: BaseException) -> None:
    """Write the class name of *e* to ``exception`` in the run's log directory."""
    exc_name = type(e).__name__
    exc_path = os.path.join(logging_run_dir(), "exception")
    with open(exc_path, "w") as out:
        out.write(exc_name)
|
|
|
|
|
|
def logging_rotate(keep: int = 1000) -> None:
    """Delete the oldest per-run log directories, keeping the newest *keep*.

    Only entries whose names fully match ``LOG_DIRNAME_RE`` are considered.
    Those names start with a zero-padded timestamp, so reverse lexicographic
    order is newest first.

    Args:
        keep: number of most recent log directories to retain (default 1000).
    """
    log_base = logging_base_dir()
    # Filter before slicing, so unrelated files in the log directory do not
    # count against the retention limit.
    run_logs = sorted(
        (entry for entry in os.listdir(log_base) if LOG_DIRNAME_RE.fullmatch(entry)),
        reverse=True,
    )
    for stale_log in run_logs[max(keep, 0):]:
        shutil.rmtree(os.path.join(log_base, stale_log))
|
|
|
|
|
|
@contextlib.contextmanager
def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]:
    """Set up logging for one run of this script and yield the logger.

    If a failure starts here we won't be able to save the user in a
    reasonable way.

    Logging structure: a single "conda-pytorch" logger processes all events.
    It has two handlers:
    - stderr, at INFO (DEBUG when *debug* is true);
    - a file handler writing ``nightly.log`` in the run directory, with no
      level of its own, so it receives everything the logger passes (DEBUG
      and up).

    Any exception escaping the ``with`` body is logged with its traceback,
    its type is recorded via ``logging_record_exception``, and the process
    exits with status 1. Both handlers are detached and closed on exit.
    """
    formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="")
    root_logger = logging.getLogger("conda-pytorch")
    root_logger.setLevel(logging.DEBUG)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG if debug else logging.INFO)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    log_file = os.path.join(logging_run_dir(), "nightly.log")

    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    root_logger.addHandler(file_handler)
    logging_record_argv()

    try:
        logging_rotate()
        print(f"log file: {log_file}")
        yield root_logger
    except Exception as e:
        # Log through our configured logger. The module-level
        # logging.exception() goes to the stdlib root logger, which has
        # neither of our handlers, so the traceback would miss nightly.log.
        root_logger.exception("Fatal exception")
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
    except BaseException as e:
        # e.g. KeyboardInterrupt. You could log at DEBUG here to keep the
        # backtrace off the console, but there is no reason to hide it from
        # technically savvy users. This must use root_logger: the stdlib
        # root logger defaults to WARNING and would drop this INFO record.
        root_logger.info("", exc_info=True)
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
    finally:
        # Flush/close the log file and avoid duplicated output if this
        # context manager is entered again in the same process.
        for handler in (console_handler, file_handler):
            root_logger.removeHandler(handler)
            handler.close()
|
|
|
|
|
|
def check_in_repo() -> Optional[str]:
    """Check that the current directory is the root of a PyTorch checkout.

    Returns:
        ``None`` if ``./setup.py`` exists and contains "PyTorch";
        otherwise a message describing which check failed.
    """
    setup_py = "setup.py"
    if os.path.isfile(setup_py):
        with open(setup_py) as fh:
            contents = fh.read()
        if "PyTorch" in contents:
            return None
        return "Not in PyTorch repo, 'PyTorch' not found in setup.py"
    return "Not in root-level PyTorch repo, no setup.py found"
|
|
|
|
|
|
def check_branch(subcommand: str, branch: Optional[str]) -> Optional[str]:
    """Verify that the ``checkout`` subcommand can create *branch*.

    Returns ``None`` when the subcommand is not ``checkout`` or all checks
    pass. Otherwise returns a human-readable error message. The checks are:
    a branch name was supplied, the working tree is clean (untracked files
    ignored), and no local branch of that name already exists.
    """
    if subcommand != "checkout":
        return None
    if branch is None:
        return "Branch name to checkout must be supplied with '-b' option"

    status = subprocess.run(
        ["git", "status", "--untracked-files=no", "--porcelain"],
        capture_output=True,
        check=True,
        text=True,
    )
    if status.stdout.strip():
        return "Need to have clean working tree to checkout!\n\n" + status.stdout

    # `git show-ref --verify --quiet` exits with 0 exactly when the ref exists.
    show_ref = subprocess.run(
        ["git", "show-ref", "--verify", "--quiet", "refs/heads/" + branch],
        capture_output=True,
        check=False,
    )
    if show_ref.returncode == 0:
        return f"Branch {branch!r} already exists"
    return None
|
|
|
|
|
|
@contextlib.contextmanager
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
    """Time the ``with`` body and log "<prefix> took N [s]" at INFO.

    Uses ``time.perf_counter`` (monotonic, high resolution) rather than
    ``time.time``, whose wall-clock value can jump when the system clock is
    adjusted. Nothing is logged if the body raises.
    """
    start_time = time.perf_counter()
    yield
    logger.info("%s took %.3f [s]", prefix, time.perf_counter() - start_time)
|
|
|
|
|
|
F = TypeVar("F", bound=Callable[..., Any])


def timed(prefix: str) -> Callable[[F], F]:
    """Decorator factory: log *prefix*, run the function, then log its runtime.

    The logger is the module-level ``LOGGER`` that main() assigns. If it has
    not been set (e.g. a decorated function is called outside main()), the
    "conda-pytorch" logger is used instead of failing on ``None``.
    """

    def dec(f: F) -> F:
        @functools.wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Read LOGGER at call time, not decoration time: it is assigned
            # only after the decorated functions are defined. Reading a
            # global needs no `global` statement.
            logger = LOGGER if LOGGER is not None else logging.getLogger("conda-pytorch")
            logger.info(prefix)
            with timer(logger, prefix):
                return f(*args, **kwargs)

        return cast(F, wrapper)

    return dec
|
|
|
|
|
|
def _make_channel_args(
|
|
channels: Iterable[str] = ("pytorch-nightly",),
|
|
override_channels: bool = False,
|
|
) -> List[str]:
|
|
args = []
|
|
for channel in channels:
|
|
args.append("--channel")
|
|
args.append(channel)
|
|
if override_channels:
|
|
args.append("--override-channels")
|
|
return args
|
|
|
|
|
|
@timed("Solving conda environment")
def conda_solve(
    name: Optional[str] = None,
    prefix: Optional[str] = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> Tuple[List[str], str, str, bool, List[str]]:
    """Performs the conda solve and splits the deps from the package.

    Runs a ``--dry-run --json`` conda solve for ``SPECS_TO_INSTALL``:
    ``conda install`` into the environment given by *prefix* or *name*, or
    ``conda create`` when neither is given. The packages conda would link
    are then split into the ``pytorch`` package and everything else.

    Returns:
        ``(deps, pytorch, platform, existing_env, env_opts)``:
        - deps: URLs of the dependency packages;
        - pytorch: URL of the pytorch package;
        - platform: conda platform string of the pytorch package;
        - existing_env: whether an existing environment was requested;
        - env_opts: the ``--prefix``/``--name`` options for the target env
          (``--name pytorch-deps`` when neither was given).

    Raises:
        RuntimeError: if the solve links no ``pytorch`` package, for
            instance when conda reports nothing to install.
    """
    # compute what environment to use
    if prefix is not None:
        existing_env = True
        env_opts = ["--prefix", prefix]
    elif name is not None:
        existing_env = True
        env_opts = ["--name", name]
    else:
        # create new environment
        existing_env = False
        env_opts = ["--name", "pytorch-deps"]

    # run solve
    if existing_env:
        cmd = ["conda", "install", "--yes", "--dry-run", "--json"]
        cmd.extend(env_opts)
    else:
        # dry-run create under a placeholder name; nothing is created
        cmd = ["conda", "create", "--yes", "--dry-run", "--json", "--name", "__pytorch__"]
    cmd.extend(_make_channel_args(channels=channels, override_channels=override_channels))
    cmd.extend(SPECS_TO_INSTALL)
    p = subprocess.run(cmd, capture_output=True, check=True)

    # parse solution; "actions"/"LINK" may be absent when there is nothing to do
    solve = json.loads(p.stdout)
    link = solve.get("actions", {}).get("LINK", [])
    deps: List[str] = []
    pytorch: Optional[str] = None
    platform: Optional[str] = None
    for pkg in link:
        url = URL_FORMAT.format(**pkg)
        if pkg["name"] == "pytorch":
            pytorch = url
            platform = pkg["platform"]
        else:
            deps.append(url)
    if pytorch is None or platform is None:
        raise RuntimeError(
            "conda solve did not include a 'pytorch' package to install; "
            f"conda said: {solve.get('message', '<no message>')}"
        )
    return deps, pytorch, platform, existing_env, env_opts
|
|
|
|
|
|
@timed("Installing dependencies")
def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> None:
    """Install the solved dependency URLs into the target environment.

    For an existing environment the packages are ``conda install``-ed into
    it with ``--no-deps``. Otherwise the environment named by *env_opts* is
    removed first (conda_solve passes ``--name pytorch-deps`` in this case)
    and recreated from *deps*.
    """
    if not existing_env:
        # drop the previous pytorch-deps env before recreating it
        subprocess.run(["conda", "env", "remove", "--yes", *env_opts], check=True)
    verb = "install" if existing_env else "create"
    subprocess.run(["conda", verb, "--yes", "--no-deps", *env_opts, *deps], check=True)
|
|
|
|
|
|
@timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> "tempfile.TemporaryDirectory[str]":
    """Install the pytorch package at *url* into a new temporary prefix.

    Returns the ``TemporaryDirectory`` holding the prefix. The caller owns it
    and must call ``cleanup()``. If the install fails, the directory is
    removed here before the error propagates, so it does not leak.
    """
    pytdir = tempfile.TemporaryDirectory()
    cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url]
    try:
        subprocess.run(cmd, check=True)
    except BaseException:
        pytdir.cleanup()
        raise
    return pytdir
|
|
|
|
|
|
def _site_packages(dirname: str, platform: str) -> str:
|
|
if platform.startswith("win"):
|
|
template = os.path.join(dirname, "Lib", "site-packages")
|
|
else:
|
|
template = os.path.join(dirname, "lib", "python*.*", "site-packages")
|
|
spdir = glob.glob(template)[0]
|
|
return spdir
|
|
|
|
|
|
def _ensure_commit(git_sha1: str) -> None:
    """Make sure commit *git_sha1* exists locally, fetching it from GitHub if not."""
    probe = subprocess.run(
        ["git", "cat-file", "-e", git_sha1 + "^{commit}"],
        capture_output=True,
        check=False,
    )
    if probe.returncode != 0:
        # not available locally; fetch it from the upstream repository
        subprocess.run(
            ["git", "fetch", "https://github.com/pytorch/pytorch.git", git_sha1],
            check=True,
        )
|
|
|
|
|
|
def _nightly_version(spdir: str) -> str:
    """Find the commit to check out for the nightly installed under *spdir*.

    Reads the first ``git_version`` assignment from ``torch/version.py``.
    It then takes the first 40-hex-digit SHA-1 in that commit's subject line
    as the nightly version. Both commits are fetched if not present locally.

    Raises:
        RuntimeError: if ``git_version`` is missing, or the subject line
            contains no SHA-1.
    """
    version_fname = os.path.join(spdir, "torch", "version.py")
    with open(version_fname) as f:
        version_lines = f.read().splitlines()
    git_version_line = next(
        (line for line in version_lines if line.startswith("git_version")), None
    )
    if git_version_line is None:
        raise RuntimeError(f"Could not find git_version in {version_fname}")
    _, _, rhs = git_version_line.partition("=")
    git_version = literal_eval(rhs.strip())
    print(f"Found released git version {git_version}")

    # cross reference with the SHA-1 named in that commit's subject
    _ensure_commit(git_version)
    subject = subprocess.run(
        ["git", "show", "--no-patch", "--format=%s", git_version],
        capture_output=True,
        check=True,
        text=True,
    ).stdout
    match = SHA1_RE.search(subject)
    if match is None:
        raise RuntimeError(
            f"Could not find nightly release in git history:\n {subject}"
        )
    nightly_version = match.group(1)
    print(f"Found nightly release version {nightly_version}")
    _ensure_commit(nightly_version)
    return nightly_version
|
|
|
|
|
|
@timed("Checking out nightly PyTorch")
def checkout_nightly_version(branch: str, spdir: str) -> None:
    """Create and switch to *branch*, starting at the nightly commit for *spdir*."""
    target_commit = _nightly_version(spdir)
    subprocess.run(["git", "checkout", "-b", branch, target_commit], check=True)
|
|
|
|
|
|
@timed("Pulling nightly PyTorch")
def pull_nightly_version(spdir: str) -> None:
    """Merge the nightly commit for the install in *spdir* into the current branch."""
    target_commit = _nightly_version(spdir)
    subprocess.run(["git", "merge", target_commit], check=True)
|
|
|
|
|
|
def _get_listing_linux(source_dir: str) -> List[str]:
|
|
listing = glob.glob(os.path.join(source_dir, "*.so"))
|
|
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so")))
|
|
return listing
|
|
|
|
|
|
def _get_listing_osx(source_dir: str) -> List[str]:
|
|
# oddly, these are .so files even on Mac
|
|
listing = glob.glob(os.path.join(source_dir, "*.so"))
|
|
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib")))
|
|
return listing
|
|
|
|
|
|
def _get_listing_win(source_dir: str) -> List[str]:
|
|
listing = glob.glob(os.path.join(source_dir, "*.pyd"))
|
|
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib")))
|
|
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll")))
|
|
return listing
|
|
|
|
|
|
def _glob_pyis(d: str) -> Set[str]:
|
|
search = os.path.join(d, "**", "*.pyi")
|
|
pyis = {os.path.relpath(p, d) for p in glob.iglob(search)}
|
|
return pyis
|
|
|
|
|
|
def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
    """Return the ``.pyi`` stubs in *source_dir* missing from *target_dir*.

    Stubs are compared by their path relative to each directory. The
    result holds the sorted paths joined onto *source_dir*.
    """
    missing = _glob_pyis(source_dir).difference(_glob_pyis(target_dir))
    return sorted(os.path.join(source_dir, rel) for rel in missing)
|
|
|
|
|
|
def _get_listing(source_dir: str, target_dir: str, platform: str) -> List[str]:
    """List everything to transfer from *source_dir* into the repo.

    The list contains:
    - the platform's compiled libraries;
    - ``.pyi`` stubs absent from *target_dir*;
    - ``version.py``;
    - the ``testing/_internal/generated``, ``bin`` and ``include`` paths.

    Raises:
        RuntimeError: if *platform* does not start with linux, osx or win.
    """
    platform_listers = (
        ("linux", _get_listing_linux),
        ("osx", _get_listing_osx),
        ("win", _get_listing_win),
    )
    for platform_prefix, lister in platform_listers:
        if platform.startswith(platform_prefix):
            listing = lister(source_dir)
            break
    else:
        raise RuntimeError(f"Platform {platform!r} not recognized")
    listing.extend(_find_missing_pyi(source_dir, target_dir))
    extra_parts = (
        ("version.py",),
        ("testing", "_internal", "generated"),
        ("bin",),
        ("include",),
    )
    listing.extend(os.path.join(source_dir, *parts) for parts in extra_parts)
    return listing
|
|
|
|
|
|
def _remove_existing(trg: str, is_dir: bool) -> None:
|
|
if os.path.exists(trg):
|
|
if is_dir:
|
|
shutil.rmtree(trg)
|
|
else:
|
|
os.remove(trg)
|
|
|
|
|
|
def _move_single(
    src: str,
    source_dir: str,
    target_dir: str,
    mover: Callable[[str, str], None],
    verb: str,
) -> None:
    """Place *src* at the same relative location under *target_dir*.

    Whatever sits at the destination is removed first. A plain file is
    handed to *mover* directly. For a directory, the tree is recreated under
    the destination and each file inside it is handed to *mover*. Every
    transfer is printed as "<verb> <from> -> <to>".
    """
    is_dir = os.path.isdir(src)
    trg = os.path.join(target_dir, os.path.relpath(src, source_dir))
    _remove_existing(trg, is_dir)
    if not is_dir:
        print(f"{verb} {src} -> {trg}")
        mover(src, trg)
        return
    os.makedirs(trg, exist_ok=True)
    for root, subdirs, filenames in os.walk(src):
        relroot = os.path.relpath(root, src)
        for fname in filenames:
            rel = os.path.join(relroot, fname)
            from_path = os.path.join(src, rel)
            to_path = os.path.join(trg, rel)
            print(f"{verb} {from_path} -> {to_path}")
            mover(from_path, to_path)
        for dname in subdirs:
            os.makedirs(os.path.join(trg, os.path.join(relroot, dname)), exist_ok=True)
|
|
|
|
|
|
def _copy_files(listing: List[str], source_dir: str, target_dir: str) -> None:
    """Copy each path in *listing* (via shutil.copy2) from source_dir to target_dir."""
    for entry in listing:
        _move_single(entry, source_dir, target_dir, shutil.copy2, "Copying")
|
|
|
|
|
|
def _link_files(listing: List[str], source_dir: str, target_dir: str) -> None:
    """Hard-link each path in *listing* (via os.link) from source_dir into target_dir."""
    for entry in listing:
        _move_single(entry, source_dir, target_dir, os.link, "Linking")
|
|
|
|
|
|
@timed("Moving nightly files into repo")
def move_nightly_files(spdir: str, platform: str) -> None:
    """Moves PyTorch files from temporary installed location to repo.

    The entries chosen by ``_get_listing`` are taken from ``<spdir>/torch``
    and placed under ``./torch``. They are copied on Windows. Elsewhere
    hard links are tried first, falling back to copies when linking fails
    with an OSError; hard links cannot cross filesystems, and the temporary
    install may live on a different one than the repo.
    """
    source_dir = os.path.join(spdir, "torch")
    target_dir = os.path.abspath("torch")
    listing = _get_listing(source_dir, target_dir, platform)
    if platform.startswith("win"):
        _copy_files(listing, source_dir, target_dir)
        return
    try:
        _link_files(listing, source_dir, target_dir)
    except OSError:
        # Only filesystem errors trigger the fallback; programming errors
        # propagate. _move_single clears each destination first, so a
        # partially linked tree is safely overwritten by the copies.
        _copy_files(listing, source_dir, target_dir)
|
|
|
|
|
|
def _available_envs() -> Dict[str, str]:
    """Map each named conda environment to its prefix directory.

    Parses the text output of ``conda env list``:
    - blank lines and ``#`` comment lines are skipped;
    - unnamed environments (lines holding only a path) are skipped;
    - the ``*`` active-environment marker is dropped.

    Everything after the name is kept as the path, so prefixes that contain
    spaces (for example under a user directory with a space in its name)
    are returned intact. The previous ``split()[-1]`` kept only the last
    word.
    """
    p = subprocess.run(
        ["conda", "env", "list"],
        check=True,
        capture_output=True,
        text=True,
    )
    envs: Dict[str, str] = {}
    for line in map(str.strip, p.stdout.splitlines()):
        if not line or line.startswith("#"):
            continue
        parts = line.split(maxsplit=1)
        if len(parts) == 1:
            # unnamed env
            continue
        name, rest = parts
        if os.sep in name or "/" in name:
            # unnamed env whose path contains whitespace; not addressable by name
            continue
        path = rest.lstrip("*").strip()
        if path:
            envs[name] = path
    return envs
|
|
|
|
|
|
@timed("Writing pytorch-nightly.pth")
def write_pth(env_opts: List[str], platform: str) -> None:
    """Write ``pytorch-nightly.pth`` into the target environment's site-packages.

    *env_opts* is a two-element list: an option (``--name`` or ``--prefix``)
    and its value. A name is resolved to a prefix via ``_available_envs()``.
    The .pth file lists the current working directory so that this
    checkout is importable from the environment.
    """
    env_type, env_dir = env_opts
    env_prefix = _available_envs()[env_dir] if env_type == "--name" else env_dir
    pth = os.path.join(_site_packages(env_prefix, platform), "pytorch-nightly.pth")
    contents = (
        "# This file was autogenerated by PyTorch's tools/nightly.py\n"
        "# Please delete this file if you no longer need the following development\n"
        "# version of PyTorch to be importable\n"
        f"{os.getcwd()}\n"
    )
    with open(pth, "w") as pth_file:
        pth_file.write(contents)
|
|
|
|
|
|
def install(
    *,
    logger: logging.Logger,
    subcommand: str = "checkout",
    branch: Optional[str] = None,
    name: Optional[str] = None,
    prefix: Optional[str] = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> None:
    """Development install of PyTorch.

    Steps:
    1. Solve the conda environment and install the dependencies.
    2. Install the nightly pytorch package into a temporary prefix.
    3. Check out (``checkout``) or merge (``pull``) the matching commit.
    4. Move the binaries into ``./torch``.
    5. Write a .pth file into the environment.

    The temporary prefix is removed even if a later step fails.

    Raises:
        ValueError: if *subcommand* is neither "checkout" nor "pull". This
            is checked before any conda work is done.
    """
    if subcommand not in ("checkout", "pull"):
        raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.")
    deps, pytorch, platform, existing_env, env_opts = conda_solve(
        name=name, prefix=prefix, channels=channels, override_channels=override_channels
    )
    if deps:
        deps_install(deps, existing_env, env_opts)
    pytdir = pytorch_install(pytorch)
    try:
        spdir = _site_packages(pytdir.name, platform)
        if subcommand == "checkout":
            checkout_nightly_version(cast(str, branch), spdir)
        else:
            pull_nightly_version(spdir)
        move_nightly_files(spdir, platform)
        write_pth(env_opts, platform)
    finally:
        pytdir.cleanup()
    logger.info(
        "-------\nPyTorch Development Environment set up!\nPlease activate to "
        "enable this environment:\n  $ conda activate %s",
        env_opts[1],
    )
|
|
|
|
|
|
def make_parser() -> ArgumentParser:
    """Build the command-line parser for this script.

    The parser has two subcommands, stored in ``subcmd``:
    - ``checkout``, which also takes ``-b/--branch``;
    - ``pull``.

    Both accept the same environment, verbosity and channel options.
    """
    parser = ArgumentParser("nightly")
    subparsers = parser.add_subparsers(dest="subcmd", help="subcommand to execute")

    checkout_parser = subparsers.add_parser("checkout", help="checkout a new branch")
    checkout_parser.add_argument(
        "-b",
        "--branch",
        help="Branch name to checkout",
        dest="branch",
        default=None,
        metavar="NAME",
    )
    pull_parser = subparsers.add_parser(
        "pull", help="pulls the nightly commits into the current branch"
    )

    # (flags, keyword arguments) shared by every subcommand, in registration order
    shared_options: List[Tuple[Tuple[str, ...], Dict[str, Any]]] = [
        (
            ("-n", "--name"),
            {
                "help": "Name of environment",
                "dest": "name",
                "default": None,
                "metavar": "ENVIRONMENT",
            },
        ),
        (
            ("-p", "--prefix"),
            {
                "help": "Full path to environment location (i.e. prefix)",
                "dest": "prefix",
                "default": None,
                "metavar": "PATH",
            },
        ),
        (
            ("-v", "--verbose"),
            {
                "help": "Provide debugging info",
                "dest": "verbose",
                "default": False,
                "action": "store_true",
            },
        ),
        (
            ("--override-channels",),
            {
                "help": "Do not search default or .condarc channels.",
                "dest": "override_channels",
                "default": False,
                "action": "store_true",
            },
        ),
        (
            ("-c", "--channel"),
            {
                "help": "Additional channel to search for packages. 'pytorch-nightly' will always be prepended to this list.",
                "dest": "channels",
                "action": "append",
                "metavar": "CHANNEL",
            },
        ),
    ]
    for subparser in (checkout_parser, pull_parser):
        for flags, options in shared_options:
            subparser.add_argument(*flags, **options)
    return parser
|
|
|
|
|
|
def main(args: Optional[Sequence[str]] = None) -> None:
    """Main entry point.

    Parses *args* (``sys.argv[1:]`` when ``None``), then checks that we are
    in a PyTorch checkout and that the requested branch can be created.
    Finally it runs ``install()`` inside ``logging_manager()``, publishing
    the logger as the module-level ``LOGGER`` used by ``@timed``.
    """
    global LOGGER
    p = make_parser()
    ns = p.parse_args(args)
    if ns.subcmd is None:
        # Options such as --verbose/--name/--channel exist only on the
        # subparsers. Without a subcommand, reading them below would raise
        # AttributeError, so report a usage error instead (exits with 2).
        p.error("a subcommand is required: checkout or pull")
    ns.branch = getattr(ns, "branch", None)
    status = check_in_repo() or check_branch(ns.subcmd, ns.branch)
    if status:
        sys.exit(status)
    # 'pytorch-nightly' always comes first, followed by any -c channels
    channels = ["pytorch-nightly", *(ns.channels or [])]
    with logging_manager(debug=ns.verbose) as logger:
        LOGGER = logger
        install(
            subcommand=ns.subcmd,
            branch=ns.branch,
            name=ns.name,
            prefix=ns.prefix,
            logger=logger,
            channels=channels,
            override_channels=ns.override_channels,
        )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|