edq.util.code

Utilities for extracting and working with Python source code.

  1"""
  2Utilities for extracting and working with Python source code.
  3"""
  4
  5import ast
  6import os
  7import sys
  8import types
  9import typing
 10
 11import edq.util.dirent
 12import edq.util.json
 13
 14DEFAULT_MODULE_AST_ALLOWED_NODES: typing.List[typing.Type] = [
 15    ast.Import,
 16    ast.ImportFrom,
 17    ast.FunctionDef,
 18    ast.ClassDef,
 19]
 20
 21def extract_code(path: str) -> str:
 22    """
 23    Gets the source code out of a path (to either a notebook or vanilla python).
 24    All code will be cleaned in some way (for uncleaned code, just read the file normally).
 25    """
 26
 27    code = None
 28
 29    if (path.lower().endswith('.ipynb')):
 30        code = extract_notebook_code(path)
 31    elif (path.lower().endswith('.py')):
 32        code = extract_python_code(path)
 33    else:
 34        raise ValueError(f"Unknown extension for extracting code: '{path}'.")
 35
 36    return code.strip()
 37
 38def extract_python_code(path: str) -> str:
 39    """
 40    Gets the source code out of a Python code file.
 41    Each line will be stripped of trailing whitespace and joined with a single newline.
 42    This may change the contents of multiline strings.
 43    """
 44
 45    with open(path, 'r', encoding = edq.util.dirent.DEFAULT_ENCODING) as file:
 46        lines = file.readlines()
 47
 48    lines = [line.rstrip() for line in lines]
 49    code = "\n".join(lines)
 50
 51    return code.strip()
 52
 53def extract_notebook_code(path: str) -> str:
 54    """
 55    Extract all the code cells from an iPython notebook.
 56    A concatenation of all the cells (with a newline between each cell) will be output.
 57    """
 58
 59    notebook = edq.util.json.load_path(path, strict = True)
 60
 61    contents = []
 62    for cell in notebook['cells']:
 63        if (cell['cell_type'] != 'code'):
 64            continue
 65
 66        cell_code = ''.join(cell['source']).strip()
 67
 68        # Ignore empty cells.
 69        if (cell_code == ''):
 70            continue
 71
 72        contents.append(cell_code)
 73
 74    return "\n".join(contents) + "\n"
 75
 76def sanitize_and_import_path(path: str, **kwargs: typing.Any) -> typing.Any:
 77    """ Get the code from a source file and call sanitize_and_import_code() with it. """
 78
 79    source_code = extract_code(path)
 80    return sanitize_and_import_code(source_code, code_path = path, **kwargs)
 81
 82def sanitize_and_import_code(
 83        source_code: str,
 84        code_path: str = '<unspecified>',
 85        syspath: typing.Union[str, None] = None,
 86        **kwargs: typing.Any) -> typing.Any:
 87    """
 88    Sanitize the given code, exec it, and return it as a namespace or dict.
 89    The code is assumed to be a module.
 90    See parse_module_code() for sanitization details and kwargs.
 91    Prefer sanitize_and_import_path() over this function, because file and path information will be automatically set.
 92    """
 93
 94    if (syspath is None):
 95        syspath = os.path.dirname(os.path.abspath(os.getcwd()))
 96
 97    module_ast = parse_module_code(source_code, **kwargs)
 98
 99    globals_defs: typing.Dict[str, typing.Any] = {}
100
101    try:
102        sys.path.append(syspath)
103        exec(compile(module_ast, filename = code_path, mode = "exec"), globals_defs)  # pylint: disable=exec-used
104    finally:
105        sys.path.pop()
106
107    return types.SimpleNamespace(**globals_defs)
108
def parse_module_code(
        source_code: str,
        sanitize: bool = True,
        allowed_module_nodes: typing.Union[typing.List[typing.Type], None] = None,
        **kwargs: typing.Any) -> ast.Module:
    """
    Parse a Python module's (file's) code, optionally sanitize it, and return an AST.
    Sanitization in this context means removing top-level statements that are not
    imports, functions, constants, and classes.
    A "constant" will be considered an assignment where the LHS is a single variable all in caps.

    `allowed_module_nodes` lists the exact AST node types that are always kept during sanitization,
    it defaults to DEFAULT_MODULE_AST_ALLOWED_NODES.
    Constants are kept regardless of this list.
    When `sanitize` is false, the parsed AST is returned untouched.
    Any additional kwargs are accepted and ignored.
    """

    if (allowed_module_nodes is None):
        allowed_module_nodes = DEFAULT_MODULE_AST_ALLOWED_NODES

    module_ast = ast.parse(source_code)

    if (not isinstance(module_ast, ast.Module)):
        raise ValueError(f"Provided code for parsing is not a Python module, found: '{type(module_ast)}'.")

    if (not sanitize):
        return module_ast

    keep_nodes = []
    for node in module_ast.body:
        # Honor the caller's list (an exact type match, subclasses are not considered).
        if (type(node) in allowed_module_nodes):
            keep_nodes.append(node)
            continue

        # Everything else is only kept if it looks like a constant definition.
        if (not isinstance(node, ast.Assign)):
            continue

        # Chained (a = b = 1) and non-name (tuple, attribute, subscript) targets are not constants.
        if ((len(node.targets) != 1) or (not isinstance(node.targets[0], ast.Name))):
            continue

        # The name must be unchanged by upper-casing.
        if (node.targets[0].id != node.targets[0].id.upper()):
            continue

        keep_nodes.append(node)

    module_ast.body = keep_nodes
    return module_ast
151
def ast_to_source(code_ast: ast.AST) -> str:
    """ Convert a Python AST back into source code text. """

    source = ast.unparse(code_ast)
    return source
DEFAULT_MODULE_AST_ALLOWED_NODES: List[Type] = [<class 'ast.Import'>, <class 'ast.ImportFrom'>, <class 'ast.FunctionDef'>, <class 'ast.ClassDef'>]
def extract_code(path: str) -> str:
def extract_code(path: str) -> str:
    """
    Read the code stored at a path, picking the reader from the file extension.
    Notebooks ('.ipynb') and plain Python files ('.py') are supported (the extension is matched case-insensitively),
    any other extension raises a ValueError.
    The returned code has leading and trailing whitespace removed.
    """

    lowered_path = path.lower()

    if (lowered_path.endswith('.ipynb')):
        return extract_notebook_code(path).strip()

    if (lowered_path.endswith('.py')):
        return extract_python_code(path).strip()

    raise ValueError(f"Unknown extension for extracting code: '{path}'.")

Gets the source code out of a path (to either a notebook or vanilla python). All code will be cleaned in some way (for uncleaned code, just read the file normally).

def extract_python_code(path: str) -> str:
def extract_python_code(path: str) -> str:
    """
    Read a Python source file and return its cleaned contents.
    Trailing whitespace is removed from every line, the lines are re-joined with a single newline,
    and the final result is stripped of leading and trailing whitespace.
    Note that this can alter the contents of multiline strings.
    """

    with open(path, 'r', encoding = edq.util.dirent.DEFAULT_ENCODING) as file:
        cleaned_lines = [raw_line.rstrip() for raw_line in file.readlines()]

    return "\n".join(cleaned_lines).strip()

Gets the source code out of a Python code file. Each line will be stripped of trailing whitespace and joined with a single newline. This may change the contents of multiline strings.

def extract_notebook_code(path: str) -> str:
def extract_notebook_code(path: str) -> str:
    """
    Collect the source of every code cell in an IPython notebook.
    Cells that are not code cells, or that are empty after stripping, are skipped.
    The remaining cells are joined with a newline between each cell, and the output ends with a newline.
    """

    notebook = edq.util.json.load_path(path, strict = True)

    # A cell's source is a sequence of text fragments, stitch them together before stripping.
    code_cells = (''.join(cell['source']).strip() for cell in notebook['cells'] if (cell['cell_type'] == 'code'))

    # Ignore empty cells.
    non_empty_cells = [cell_code for cell_code in code_cells if (cell_code != '')]

    return "\n".join(non_empty_cells) + "\n"

Extract all the code cells from an IPython notebook. A concatenation of all the cells (with a newline between each cell) will be output.

def sanitize_and_import_path(path: str, **kwargs: Any) -> Any:
def sanitize_and_import_path(path: str, **kwargs: typing.Any) -> typing.Any:
    """
    Extract the code from a source file and hand it to sanitize_and_import_code(),
    using the file's path as the code path.
    Any additional kwargs are forwarded unchanged.
    """

    return sanitize_and_import_code(extract_code(path), code_path = path, **kwargs)

Get the code from a source file and call sanitize_and_import_code() with it.

def sanitize_and_import_code( source_code: str, code_path: str = '<unspecified>', syspath: Optional[str] = None, **kwargs: Any) -> Any:
 83def sanitize_and_import_code(
 84        source_code: str,
 85        code_path: str = '<unspecified>',
 86        syspath: typing.Union[str, None] = None,
 87        **kwargs: typing.Any) -> typing.Any:
 88    """
 89    Sanitize the given code, exec it, and return it as a namespace or dict.
 90    The code is assumed to be a module.
 91    See parse_module_code() for sanitization details and kwargs.
 92    Prefer sanitize_and_import_path() over this function, because file and path information will be automatically set.
 93    """
 94
 95    if (syspath is None):
 96        syspath = os.path.dirname(os.path.abspath(os.getcwd()))
 97
 98    module_ast = parse_module_code(source_code, **kwargs)
 99
100    globals_defs: typing.Dict[str, typing.Any] = {}
101
102    try:
103        sys.path.append(syspath)
104        exec(compile(module_ast, filename = code_path, mode = "exec"), globals_defs)  # pylint: disable=exec-used
105    finally:
106        sys.path.pop()
107
108    return types.SimpleNamespace(**globals_defs)

Sanitize the given code, exec it, and return the resulting globals as a types.SimpleNamespace. The code is assumed to be a module. See parse_module_code() for sanitization details and kwargs. Prefer sanitize_and_import_path() over this function, because file and path information will be automatically set.

def parse_module_code( source_code: str, sanitize: bool = True, allowed_module_nodes: Optional[List[Type]] = None, **kwargs: Any) -> ast.Module:
def parse_module_code(
        source_code: str,
        sanitize: bool = True,
        allowed_module_nodes: typing.Union[typing.List[typing.Type], None] = None,
        **kwargs: typing.Any) -> ast.Module:
    """
    Parse a Python module's (file's) code, optionally sanitize it, and return an AST.
    Sanitization in this context means removing top-level statements that are not
    imports, functions, constants, and classes.
    A "constant" will be considered an assignment where the LHS is a single variable all in caps.

    `allowed_module_nodes` lists the exact AST node types that are always kept during sanitization,
    it defaults to DEFAULT_MODULE_AST_ALLOWED_NODES.
    Constants are kept regardless of this list.
    When `sanitize` is false, the parsed AST is returned untouched.
    Any additional kwargs are accepted and ignored.
    """

    if (allowed_module_nodes is None):
        allowed_module_nodes = DEFAULT_MODULE_AST_ALLOWED_NODES

    module_ast = ast.parse(source_code)

    if (not isinstance(module_ast, ast.Module)):
        raise ValueError(f"Provided code for parsing is not a Python module, found: '{type(module_ast)}'.")

    if (not sanitize):
        return module_ast

    keep_nodes = []
    for node in module_ast.body:
        # Honor the caller's list (an exact type match, subclasses are not considered).
        if (type(node) in allowed_module_nodes):
            keep_nodes.append(node)
            continue

        # Everything else is only kept if it looks like a constant definition.
        if (not isinstance(node, ast.Assign)):
            continue

        # Chained (a = b = 1) and non-name (tuple, attribute, subscript) targets are not constants.
        if ((len(node.targets) != 1) or (not isinstance(node.targets[0], ast.Name))):
            continue

        # The name must be unchanged by upper-casing.
        if (node.targets[0].id != node.targets[0].id.upper()):
            continue

        keep_nodes.append(node)

    module_ast.body = keep_nodes
    return module_ast

Parse a Python module's (file's) code, optionally sanitize it, and return an AST. Sanitization in this context means removing things that are not imports, functions, constants, and classes. A "constant" will be considered an assignment where the LHS is a single variable all in caps.

def ast_to_source(code_ast: ast.AST) -> str:
def ast_to_source(code_ast: ast.AST) -> str:
    """ Convert a Python AST back into source code text. """

    source = ast.unparse(code_ast)
    return source

Get code from a Python AST.