diff --git a/cli-v1-to-v2 b/cli-v1-to-v2 index 74eac50..1ff4892 100755 --- a/cli-v1-to-v2 +++ b/cli-v1-to-v2 @@ -36,7 +36,7 @@ def main(sysargs=sys.argv[:]): parser = argparse.ArgumentParser( description=_DESCRIPTION, formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('basedir', nargs='?', metavar='BASEDIR', + parser.add_argument('path', nargs='+', type=os.path.abspath, default=os.getcwd()) parser.add_argument('-w', '--write', help='write changes back to file', action='store_true', default=False) @@ -60,7 +60,7 @@ def main(sysargs=sys.argv[:]): format='%(message)s' ) - for filepath in _find_candidate_files(args.basedir): + for filepath in _find_candidate_files(args.path): updated_source = _update_filepath(filepath) if args.write: logging.info('Updating %s', filepath) @@ -74,21 +74,26 @@ def main(sysargs=sys.argv[:]): return 0 -def _find_candidate_files(basedir): - for curdir, dirs, files in os.walk(basedir): - for i, dirname in enumerate(dirs[:]): - if dirname.startswith('.'): - dirs.pop(i) +def _find_candidate_files(paths): + for path in paths: + if not os.path.isdir(path): + yield path + continue - for filename in files: - if not filename.endswith('.go'): - continue + for curdir, dirs, files in os.walk(path): + for i, dirname in enumerate(dirs[:]): + if dirname.startswith('.'): + dirs.pop(i) - filepath = os.path.join(curdir, filename) - if not os.access(filepath, os.R_OK | os.W_OK): - continue + for filename in files: + if not filename.endswith('.go'): + continue - yield filepath + filepath = os.path.join(curdir, filename) + if not os.access(filepath, os.R_OK | os.W_OK): + continue + + yield filepath def _update_filepath(filepath):