Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
keras-team
GitHub Repository: keras-team/keras-io
Path: blob/master/scripts/docstrings.py
5282 views
1
"""Lightweight fork of Keras-Autodocs.
2
"""
3
4
import warnings
5
import black
6
import re
7
import inspect
8
import importlib
9
import itertools
10
import copy
11
12
import render_presets
13
14
15
class KerasDocumentationGenerator:
    """Render Python objects (or their dotted import paths) as Markdown.

    Args:
        project_url: Value forwarded to `make_source_link` when rendering;
            when falsy, no source link is requested for rendered objects.
    """

    # Google-style section headers and the "# Title" markers they become.
    # The pairs are applied sequentially, in this order.
    _SECTION_MARKERS = (
        ("Args:", "# Arguments"),
        ("Arguments:", "# Arguments"),
        ("Attributes:", "# Attributes"),
        ("Returns:", "# Returns"),
        ("Raises:", "# Raises"),
        ("Input shape:", "# Input shape"),
        ("Output shape:", "# Output shape"),
        ("Call arguments:", "# Call arguments"),
        ("Example:", "# Example\n"),
        ("Examples:", "# Examples\n"),
    )

    def __init__(self, project_url=None):
        self.project_url = project_url

    @staticmethod
    def _fence_doctest(doctest_lines):
        """Wrap doctest lines in a python code fence followed by a blank line."""
        return ["```python", *doctest_lines, "```", ""]

    def process_docstring(self, docstring):
        """Convert a raw docstring into Markdown.

        Section headers are rewritten to "# Title" markers, "Reference(s):"
        headers become bold text, runs of doctest lines are wrapped in
        python code fences, and the result is handed to the module-level
        `process_docstring` function.

        Args:
            docstring: The docstring text to convert.

        Returns:
            The value returned by the module-level `process_docstring`.
        """
        for header, marker in self._SECTION_MARKERS:
            docstring = docstring.replace(header, marker)

        docstring = re.sub(r"\nReference:\n\s*", "\n**Reference**\n\n", docstring)
        docstring = re.sub(r"\nReferences:\n\s*", "\n**References**\n\n", docstring)

        # Fix typo: move indented doctest prompts back to column 0 so the
        # loop below recognises them.
        # NOTE(review): the pattern matches exactly one leading space;
        # confirm that is the intended indentation width.
        docstring = docstring.replace("\n >>> ", "\n>>> ")

        usable_lines = []
        doctest_lines = []
        for line in docstring.split("\n"):
            if doctest_lines:
                # Inside a doctest: an empty or spaces-only line closes it
                # (that line itself is consumed by the fence's blank line).
                if not line or set(line) == {" "}:
                    usable_lines.extend(self._fence_doctest(doctest_lines))
                    doctest_lines = []
                else:
                    doctest_lines.append(line)
            elif line.startswith(">>>"):
                doctest_lines.append(line)
            else:
                usable_lines.append(line)
        if doctest_lines:
            # The docstring ended while a doctest was still open.
            usable_lines.extend(self._fence_doctest(doctest_lines))

        # Bare name: this resolves to the module-level function, not to
        # this method.
        return process_docstring("\n".join(usable_lines))

    def process_signature(self, signature):
        """Shorten `tensorflow.keras` to `tf.keras` and drop `*args, **kwargs`."""
        signature = signature.replace("tensorflow.keras", "tf.keras")
        signature = signature.replace("*args, **kwargs", "")
        return signature

    def render(self, element):
        """Render `element` as a Markdown block.

        Args:
            element: Either a dotted import path (string) or the object
                itself.

        Returns:
            The Markdown produced by `render_from_object`.
        """
        if isinstance(element, str):
            object_ = import_object(element)
            if ismethod(object_):
                # we remove the modules when displaying the methods
                signature_override = ".".join(element.split(".")[-2:])
            else:
                signature_override = element
        else:
            signature_override = None
            object_ = element
        return self.render_from_object(object_, signature_override, element)

    def render_from_object(self, object_, signature_override: str, element):
        """Build the Markdown block for an already-imported object.

        Args:
            object_: The object to document.
            signature_override: Text to use as the start of the signature,
                or `None` to derive it from `object_`.
            element: What the caller originally asked to render: a dotted
                path string or the object itself.

        Returns:
            The sub-blocks (optional source link, title, signature, optional
            docstring, optional preset table) joined by blank lines and
            terminated by a `----` rule.
        """
        subblocks = []
        source_link = make_source_link(object_, self.project_url)
        if source_link is not None:
            subblocks.append(source_link)
        signature = get_signature(object_, signature_override)
        signature = self.process_signature(signature)
        subblocks.append(f"### `{get_name(object_)}` {get_type(object_)}\n")
        subblocks.append(code_snippet(signature))

        docstring = inspect.getdoc(object_)
        if docstring:
            docstring = self.process_docstring(docstring)
            subblocks.append(docstring)
        # Render preset table for KerasCV and KerasHub.
        # `element` is only a dotted path when the caller passed a string;
        # `render` also forwards plain objects, which have no `endswith`.
        if isinstance(element, str) and element.endswith("from_preset"):
            table = render_presets.render_table(import_object(element.rsplit(".", 1)[0]))
            if table is not None:
                subblocks.append(table)
        return "\n\n".join(subblocks) + "\n\n----\n\n"
104
105
106
def ismethod(function):
    """Return True when `get_class_from_method` finds an owning class."""
    owner = get_class_from_method(function)
    return owner is not None
108
109
110
def import_object(string: str):
    """Resolve a dotted path to the object it names.

    Each prefix of the path is first tried as a module import; once a
    prefix is not importable as a module, the component is looked up as an
    attribute of whatever was resolved so far. So functions, classes and
    methods all work, e.g. `'keras.layers.Dense.get_weights'`.

    Args:
        string: The dotted path to resolve.

    Returns:
        The module or attribute named by the full path.

    Raises:
        AssertionError: If the very first component cannot be imported.
    """
    resolved = None
    components = string.split(".")
    for depth, name in enumerate(components, start=1):
        dotted_prefix = ".".join(components[:depth])
        try:
            resolved = importlib.import_module(dotted_prefix)
        except ModuleNotFoundError:
            # Not a module: it must be an attribute of the previous object.
            assert resolved is not None, f"Failed to import path {string}"
            resolved = getattr(resolved, name)
    return resolved
126
127
128
def make_source_link(cls, project_url):
    """Build the right-floated "[source]" link for an object.

    Args:
        cls: The object to link to; it needs a `__module__` attribute.
        project_url: Mapping from a top-level package name to the base URL
            of its source tree. Each URL must end with "/" and its last
            path component is read as the version (an optional leading "v"
            is dropped).

    Returns:
        An HTML `<span>` wrapping a Markdown link to the object's source
        line, or `None` when `cls` has no `__module__` or `project_url` is
        falsy.

    Raises:
        RuntimeError: If the version in the URL differs from the imported
            package's `__version__` (not enforced for URLs containing
            "keras-rs").
    """
    if not hasattr(cls, "__module__") or not project_url:
        return None

    module_name = cls.__module__
    base_module = module_name.split(".")[0]
    project_url = project_url[base_module]
    assert project_url.endswith("/"), f"{base_module} not found"
    project_url_version = project_url.split("/")[-2].removeprefix("v")
    module_version = copy.copy(importlib.import_module(base_module).__version__)
    dev_marker = module_version.find(".dev")
    if dev_marker != -1:
        # NOTE(review): this slices the URL version, not the module
        # version, so for dev builds the comparison below passes whenever
        # the URL version is no longer than the text before ".dev".
        # Confirm whether `module_version[:dev_marker]` was intended.
        module_version = project_url_version[:dev_marker]
    # TODO: Remove keras-rs condition, this is just a temporary thing.
    versions_differ = module_version != project_url_version
    if versions_differ and "keras-rs" not in project_url:
        raise RuntimeError(
            f"For project {base_module}, URL {project_url} "
            f"has version number {project_url_version} which does not match the "
            f"current imported package version {module_version}"
        )

    path = module_name.replace(".", "/")
    if base_module == "tf_keras":
        path = path.replace("/src/", "/")
    _, line = inspect.getsourcelines(cls)
    link = f"[[source]]({project_url}{path}.py#L{line})"
    return f'<span style="float:right;">{link}</span>'
157
158
159
def code_snippet(snippet):
    """Wrap `snippet` in a Markdown python code fence."""
    return "```python\n{}\n```\n".format(snippet)
161
162
163
def get_type(object_) -> str:
    """Classify `object_` as "class", "method", "function" or "property".

    The checks run in that order and the first one that matches wins.

    Raises:
        TypeError: If none of the checks match.
    """
    checks = (
        (inspect.isclass, "class"),
        (ismethod, "method"),
        (inspect.isfunction, "function"),
        # Objects exposing `fget` are treated as properties.
        (lambda candidate: hasattr(candidate, "fget"), "property"),
    )
    for matches, label in checks:
        if matches(object_):
            return label
    raise TypeError(
        f"{object_} is detected as not a class, a method, "
        f"a property, nor a function."
    )
177
178
179
def get_name(object_) -> str:
    """Return the `__name__` of `object_`, or of its getter if it has `fget`."""
    named = object_.fget if hasattr(object_, "fget") else object_
    return named.__name__
183
184
185
def get_function_name(function):
    """Return the `__name__` of the innermost function in a `__wrapped__` chain."""
    innermost = function
    while hasattr(innermost, "__wrapped__"):
        innermost = innermost.__wrapped__
    return innermost.__name__
189
190
191
def get_default_value_for_repr(value):
    """Return a substitute for rendering the default value of a function arg.

    Function and object instances are rendered as <Foo object at 0x00000000>
    which can't be parsed by black. We substitute functions with the function
    name, classes with `module.ClassName`, and objects exposing `get_config`
    with a rendered version of the constructor like `Foo(a=2, b="bar")`.

    Args:
        value: The value to find a better rendering of.

    Returns:
        An object whose `repr()` is the substitute text, or `None` if no
        substitution is needed.
    """

    class ReprWrapper:
        """Object whose `repr()` is exactly the text it was built with."""

        def __init__(self, representation):
            self.representation = representation

        def __repr__(self):
            return self.representation

    if value is inspect.Parameter.empty:
        return None

    if inspect.isfunction(value):
        # Render the function name instead
        return ReprWrapper(value.__name__)

    if inspect.isclass(value):
        # Render classes as module.ClassName to produce a valid python
        # dotted-name expression in the fake signature (black can parse it).
        return ReprWrapper(value.__module__ + "." + value.__name__)

    if (
        repr(value).startswith("<")  # <Foo object at 0x00000000>
        and hasattr(value, "__class__")  # it is an object
        and hasattr(value, "get_config")  # it is a Keras object
    ):
        config = value.get_config()
        init_args = []  # The __init__ arguments to render
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        for p in inspect.signature(value.__class__.__init__).parameters.values():
            if p.name == "self":
                continue
            if p.kind in variadic_kinds:
                # `*args` / `**kwargs` are not looked up in the config;
                # indexing `config` with their name would raise KeyError.
                continue
            if p.kind == inspect.Parameter.POSITIONAL_ONLY:
                # Positional-only, render without a name
                init_args.append(repr(config[p.name]))
            elif p.default is inspect.Parameter.empty or p.default != config[p.name]:
                # Required arg, or keyword arg with non-default value, render
                init_args.append(p.name + "=" + repr(config[p.name]))
            # else don't render that argument
        return ReprWrapper(
            value.__class__.__module__
            + "."
            + value.__class__.__name__
            + "("
            + ", ".join(init_args)
            + ")"
        )

    return None
252
253
254
def get_signature_start(function):
    """Return the qualified name that opens a rendered signature.

    Methods are prefixed with their class name, anything else with its
    module. For the Dense layer, it should return the string
    'keras.layers.Dense'.
    """
    owner = get_class_from_method(function)
    if owner is not None:
        prefix = owner.__name__ + "."
    else:
        try:
            prefix = function.__module__ + "."
        except AttributeError:
            warnings.warn(
                f"function {function} has no module. "
                f"It will not be included in the signature."
            )
            prefix = ""
    return prefix + get_function_name(function)
268
269
270
def get_signature_end(function):
    """Render the parenthesised parameter list of `function`.

    Defaults that `get_default_value_for_repr` can improve are swapped for
    their substitute before formatting. For methods, a leading `self` is
    removed.

    Args:
        function: The callable whose parameters are rendered.

    Returns:
        A string such as `"(a, b=1)"`.
    """
    formatted_params = []
    for param in inspect.signature(function).parameters.values():
        default = get_default_value_for_repr(param.default)
        if default is not None:
            param = inspect.Parameter(
                param.name, param.kind, default=default, annotation=param.annotation
            )
        formatted_params.append(str(param))
    signature_end = "(" + ", ".join(formatted_params) + ")"

    if ismethod(function):
        signature_end = signature_end.replace("(self, ", "(")
        signature_end = signature_end.replace("(self)", "()")
    # work around case-specific bug: these enum reprs are not valid Python.
    # The aggregation default belongs to the aggregation enum, so it is
    # rendered as `tf.VariableAggregation.NONE`.
    signature_end = signature_end.replace(
        "synchronization=<VariableSynchronization.AUTO: 0>, aggregation=<VariableAggregationV2.NONE: 0>",
        "synchronization=tf.VariableSynchronization.AUTO, aggregation=tf.VariableAggregation.NONE",
    )
    return signature_end
293
294
295
def get_function_signature(function, override=None):
    """Return the formatted signature of `function`.

    Args:
        function: The function or method to describe.
        override: Text to use in place of the derived qualified name; when
            `None`, `get_signature_start` supplies it.
    """
    start = get_signature_start(function) if override is None else override
    end = get_signature_end(function)
    return format_signature(start, end)
302
303
304
def get_class_signature(cls, override=None):
    """Return the formatted constructor signature of `cls`.

    Args:
        cls: The class to describe; parameters come from `cls.__init__`.
        override: Text to use in place of `module.ClassName`.
    """
    start = f"{cls.__module__}.{cls.__name__}" if override is None else override
    end = get_signature_end(cls.__init__)
    return format_signature(start, end)
311
312
313
def get_signature(object_, override):
    """Return the rendered signature for a class, function, method or property.

    Args:
        object_: The object to describe.
        override: Replacement for the derived qualified name (may be `None`).

    Raises:
        ValueError: If `object_` is none of the supported kinds.
    """
    if inspect.isclass(object_):
        return get_class_signature(object_, override)
    if inspect.isfunction(object_) or inspect.ismethod(object_):
        return get_function_signature(object_, override)
    if hasattr(object_, "fget"):
        # properties: a truthy override is returned as the whole
        # signature; otherwise the getter is described.
        return override if override else get_function_signature(object_.fget)
    raise ValueError(f"Not able to retrieve signature for object {object_}")
324
325
326
def format_signature(signature_start: str, signature_end: str):
    """Pretty-print a signature so long ones wrap over several lines.

    A throw-away function definition is built whose name has the same
    length as `signature_start`, formatted with black at a line length of
    90, and its parameter list is glued back onto the real name.
    """
    # A name of equal length makes black wrap exactly where the real one would.
    placeholder_name = "x" * len(signature_start)
    fake_python_code = f"def {placeholder_name}{signature_end}:\n pass\n"
    formatted = black.format_str(fake_python_code, mode=black.FileMode(line_length=90))
    return signature_start + extract_signature_end(formatted)
338
339
340
def extract_signature_end(function_definition):
    """Return the text from the first "(" through the last ")" inclusive."""
    opening = function_definition.find("(")
    closing = function_definition.rfind(")") + 1
    return function_definition[opening:closing]
344
345
346
def get_code_blocks(docstring):
    """Swap every fenced code block in `docstring` for a placeholder token.

    Returns:
        A `(code_blocks, docstring)` tuple: a dict mapping each token to
        the snippet it stands for (fences included), and the docstring
        with the snippets replaced by their tokens.
    """
    fence = "```"
    code_blocks = {}
    remaining = docstring
    while fence in remaining:
        # Jump to the next opening fence, then cut just past the closing one.
        remaining = remaining[remaining.find(fence) :]
        cut = remaining[len(fence) :].find(fence) + 2 * len(fence)
        snippet, remaining = remaining[:cut], remaining[cut:]
        # Place marker in docstring for later reinjection.
        # Print the index with 4 digits so we know the symbol is unique.
        token = f"$KERAS_AUTODOC_CODE_BLOCK_{len(code_blocks):04d}"
        docstring = docstring.replace(snippet, token)
        code_blocks[token] = snippet
    return code_blocks, docstring
360
361
362
def get_section_end(docstring, section_start):
    """Return the index in `docstring` where the section at `section_start` ends.

    A section ends where a non-whitespace character is followed by one or
    more newlines and then either a non-whitespace character (the next
    unindented line) or the end of the text.

    Args:
        docstring: The full docstring text.
        section_start: Index at which the section begins.

    Returns:
        The end index (exclusive) of the section. When no boundary can be
        found, the section is taken to run to the end of the docstring.
    """
    boundary = re.search(r"\S\n+(\S|$)", docstring[section_start:])
    if boundary is None:
        # E.g. only whitespace-only lines follow the section body.
        return len(docstring)
    section_end = section_start + boundary.end()
    if section_end == len(docstring):
        return section_end
    # The match swallowed the first character of the next block and the
    # newline before it; step back over both.
    return section_end - 2
370
371
372
def get_google_style_sections_without_code(docstring):
    """Pull every "# Title" section out of `docstring`.

    Each section (header line plus body) is replaced in the text by a
    numbered placeholder token.

    Returns:
        A `(sections, docstring)` tuple: a dict mapping each token to the
        section text it replaced, and the docstring containing the tokens.
    """
    header_pattern = re.compile(r"\n# .+?\n")
    sections = {}
    counter = 0
    while (header := header_pattern.search(docstring)) is not None:
        # Skip the newline that precedes the "# " header.
        start = header.start() + 1
        end = get_section_end(docstring, start)
        token = f"KERAS_AUTODOC_GOOGLE_STYLE_SECTION_{counter}"
        sections[token] = docstring[start:end]
        docstring = insert_in_string(docstring, token, start, end)
        counter += 1
    return sections, docstring
386
387
388
def get_google_style_sections(docstring):
    """Extract the "# Title" sections of `docstring`, keeping code blocks intact.

    Returns:
        A `(sections, docstring)` tuple: a dict mapping placeholder tokens
        to section text, and the docstring with each section replaced by
        its token. Code blocks are restored in both.
    """
    # First, extract code blocks and process them.
    # The parsing is easier if the #, : and other symbols aren't there.
    code_blocks, docstring = get_code_blocks(docstring)
    google_style_sections, docstring = get_google_style_sections_without_code(docstring)
    docstring = reinject_strings(docstring, code_blocks)
    # One reinjection pass per section is enough: the snippets were cut
    # from the original text and so contain no tokens themselves.
    google_style_sections = {
        section_token: reinject_strings(section, code_blocks)
        for section_token, section in google_style_sections.items()
    }
    return google_style_sections, docstring
398
399
400
def to_markdown(google_style_section: str) -> str:
    """Turn one "# Title" section into Markdown with a bold title.

    The body is dedented; for argument-like sections it is additionally
    formatted as a Markdown list.
    """
    list_like_titles = (
        "Arguments",
        "Attributes",
        "Raises",
        "Call arguments",
        "Returns",
    )
    first_newline = google_style_section.find("\n")
    # The header is "# Title": drop the leading "# ".
    title = google_style_section[2:first_newline]
    body = remove_indentation(google_style_section[first_newline:])
    if title in list_like_titles:
        body = format_as_markdown_list(body)
    return f"__{title}__\n\n{body}\n" if body else f"__{title}__\n"
417
418
419
def format_as_markdown_list(section_body):
    """Format a dedented section body as a Markdown bullet list.

    Every line that does not start with a space and contains a colon
    becomes a bullet whose name (the text before the first colon) is bold.
    Continuation lines keep their place under the bullet.

    Args:
        section_body: The dedented text of a section.

    Returns:
        The body formatted as a Markdown list.
    """
    section_body = re.sub(r"\n([^ ].*?):", r"\n- __\1__:", section_body)
    section_body = re.sub(r"^([^ ].*?):", r"- __\1__:", section_body)
    # Switch to 2-space indent so we can render nested lists: continuation
    # lines indented by four spaces are re-indented to two.
    section_body = section_body.replace("\n    ", "\n  ")
    return section_body
425
426
427
def reinject_strings(target, strings_to_inject):
    """Replace each token found in `target` with the string it maps to.

    Args:
        target: Text containing placeholder tokens.
        strings_to_inject: Mapping from token to replacement text.

    Returns:
        `target` with every token occurrence substituted.
    """
    restored = target
    for placeholder in strings_to_inject:
        restored = restored.replace(placeholder, strings_to_inject[placeholder])
    return restored
431
432
433
def process_docstring(docstring):
    """Convert the "# Title" sections of `docstring` to Markdown.

    Args:
        docstring: Docstring text whose section headers are already
            written as "# Title" lines.

    Returns:
        The docstring with each section replaced by its Markdown rendering.
    """
    # `endswith` also copes with an empty docstring, unlike indexing [-1].
    if not docstring.endswith("\n"):
        docstring += "\n"

    google_style_sections, docstring = get_google_style_sections(docstring)
    # Substitute the highest-numbered tokens first: token "..._1" is a
    # prefix of "..._10", so replacing in ascending order would corrupt
    # docstrings that have more than ten sections.
    for token, google_style_section in reversed(list(google_style_sections.items())):
        markdown_section = to_markdown(google_style_section)
        docstring = docstring.replace(token, markdown_section)
    return docstring
442
443
444
def get_class_from_method(meth):
    """Return the class that `meth` belongs to, or `None` if there is none.

    Bound methods are first looked up along the MRO of their instance's
    class; functions are resolved through their `__qualname__`; anything
    else falls back to an `__objclass__` attribute.
    """
    if inspect.ismethod(meth):
        method_name = meth.__name__
        for klass in inspect.getmro(meth.__self__.__class__):
            if klass.__dict__.get(method_name) is meth:
                return klass
        meth = meth.__func__  # fallback to __qualname__ parsing
    if inspect.isfunction(meth):
        # Drop anything from ".<locals>" on, then the trailing function name.
        outer_qualname = meth.__qualname__.split(".<locals>", 1)[0]
        owner_name = outer_qualname.rsplit(".", 1)[0]
        owner = getattr(inspect.getmodule(meth), owner_name, None)
        if isinstance(owner, type):
            return owner
    return getattr(meth, "__objclass__", None)  # handle special descriptor objects
456
457
458
def insert_in_string(target, string_to_insert, start, end):
    """Return `target` with the slice `[start:end]` replaced by `string_to_insert`."""
    return "".join((target[:start], string_to_insert, target[end:]))
462
463
464
def remove_indentation(string):
    """Strip the common leading indentation from every line of `string`.

    The indentation removed is the smallest leading-space count among the
    non-empty lines. Leading and trailing whitespace of the result is
    stripped as well.
    """
    rows = string.split("\n")
    indents = [count_leading_spaces(row) for row in rows if row]
    if indents:
        shared_indent = min(indents)
        string = "\n".join(row[shared_indent:] for row in rows)
    return string.strip()  # Drop leading/closing empty lines
471
472
473
def count_leading_spaces(s):
    """Return the index of the first non-whitespace character of `s`.

    Returns 0 when `s` is empty or contains only whitespace.
    """
    first_visible = re.search(r"\S", s)
    return first_visible.start() if first_visible else 0
478
479