diff --git a/tools/mdbook_zh_preprocessor.py b/tools/mdbook_zh_preprocessor.py index 42d05f6..4a3cec7 100644 --- a/tools/mdbook_zh_preprocessor.py +++ b/tools/mdbook_zh_preprocessor.py @@ -4,7 +4,7 @@ import json import sys from pathlib import Path -from prepare_mdbook_zh import build_title_cache, rewrite_markdown +from prepare_mdbook_zh import build_title_cache, parse_bib, rewrite_markdown def iter_chapters(items: list[dict]) -> list[dict]: @@ -26,13 +26,15 @@ def main() -> int: root = Path(context["root"]).resolve() source_dir = root / context["config"]["book"]["src"] title_cache = build_title_cache(source_dir) + bib_path = root / "mlsys.bib" + bib_db = parse_bib(bib_path) if bib_path.exists() else {} for chapter in iter_chapters(book.get("items", [])): source_path = chapter.get("source_path") or chapter.get("path") if not source_path: continue current_file = (source_dir / source_path).resolve() - chapter["content"] = rewrite_markdown(chapter["content"], current_file, title_cache) + chapter["content"] = rewrite_markdown(chapter["content"], current_file, title_cache, bib_db) json.dump(book, sys.stdout, ensure_ascii=False) return 0 diff --git a/tools/prepare_mdbook_zh.py b/tools/prepare_mdbook_zh.py index 31eb054..b0b5a86 100644 --- a/tools/prepare_mdbook_zh.py +++ b/tools/prepare_mdbook_zh.py @@ -11,6 +11,7 @@ OPTION_LINE_RE = re.compile(r"^:(width|label):`[^`]+`\s*$", re.MULTILINE) NUMREF_RE = re.compile(r":numref:`([^`]+)`") EQREF_RE = re.compile(r":eqref:`([^`]+)`") CITE_RE = re.compile(r":cite:`([^`]+)`") +BIB_ENTRY_RE = re.compile(r"@(\w+)\{([^,]+),") RAW_HTML_FILE_RE = re.compile(r"^\s*:file:\s*([^\s]+)\s*$") HEAD_TAG_RE = re.compile(r"?head>", re.IGNORECASE) STYLE_BLOCK_RE = re.compile(r"", re.IGNORECASE | re.DOTALL) @@ -134,7 +135,6 @@ def normalize_directives(markdown: str) -> str: normalized = OPTION_LINE_RE.sub("", markdown) normalized = NUMREF_RE.sub(lambda match: f"`{match.group(1)}`", normalized) normalized = EQREF_RE.sub(lambda match: 
f"`{match.group(1)}`", normalized) - normalized = CITE_RE.sub(lambda match: f"[{match.group(1)}]", normalized) lines = [line.rstrip() for line in normalized.splitlines()] collapsed: list[str] = [] @@ -152,6 +152,193 @@ def normalize_directives(markdown: str) -> str: return "\n".join(collapsed) + "\n" +# ── BibTeX parsing ──────────────────────────────────────────────────────────── + + +def clean_bibtex(value: str) -> str: + """Remove BibTeX formatting (braces, LaTeX accents) from a string.""" + value = re.sub(r"\{\\[`'^\"~=.](\w)\}", r"\1", value) + value = re.sub(r"\\[`'^\"~=.](\w)", r"\1", value) + value = value.replace("{", "").replace("}", "") + return value.strip() + + +def _parse_bib_fields(body: str) -> dict[str, str]: + """Parse field=value pairs inside a BibTeX entry body.""" + fields: dict[str, str] = {} + i = 0 + while i < len(body): + while i < len(body) and body[i] in " \t\n\r,": + i += 1 + if i >= len(body): + break + start = i + while i < len(body) and body[i] not in "= \t\n\r": + i += 1 + name = body[start:i].strip().lower() + while i < len(body) and body[i] != "=": + i += 1 + if i >= len(body): + break + i += 1 + while i < len(body) and body[i] in " \t\n\r": + i += 1 + if i >= len(body): + break + if body[i] == "{": + depth = 1 + i += 1 + vstart = i + while i < len(body) and depth > 0: + if body[i] == "{": + depth += 1 + elif body[i] == "}": + depth -= 1 + i += 1 + value = body[vstart : i - 1] + elif body[i] == '"': + i += 1 + vstart = i + while i < len(body) and body[i] != '"': + i += 1 + value = body[vstart:i] + i += 1 + else: + vstart = i + while i < len(body) and body[i] not in ", \t\n\r}": + i += 1 + value = body[vstart:i] + if name: + fields[name] = value.strip() + return fields + + +def parse_bib(bib_path: Path) -> dict[str, dict[str, str]]: + """Parse a BibTeX file and return a dict keyed by citation key.""" + text = bib_path.read_text(encoding="utf-8") + entries: dict[str, dict[str, str]] = {} + for match in BIB_ENTRY_RE.finditer(text): + 
key = match.group(2).strip() + start = match.end() + depth = 1 + pos = start + while pos < len(text) and depth > 0: + if text[pos] == "{": + depth += 1 + elif text[pos] == "}": + depth -= 1 + pos += 1 + fields = _parse_bib_fields(text[start : pos - 1]) + fields["_type"] = match.group(1).lower() + entries[key] = fields + return entries + + +# ── Citation formatting ─────────────────────────────────────────────────────── + + +def _first_author_surname(author_str: str) -> str: + """Extract the first author's surname from a BibTeX author string.""" + author_str = clean_bibtex(author_str) + authors = [a.strip() for a in author_str.split(" and ")] + if not authors or not authors[0]: + return "" + first = authors[0] + if "," in first: + return first.split(",")[0].strip() + parts = first.split() + return parts[-1] if parts else first + + +def _format_cite_label(author: str, year: str) -> str: + """Format an inline citation label like 'Surname et al., Year'.""" + surname = _first_author_surname(author) + if not surname: + return year or "?" + authors = [a.strip() for a in clean_bibtex(author).split(" and ")] + if len(authors) > 2: + name_part = f"{surname} et al." + elif len(authors) == 2: + second = authors[1] + if second.lower() == "others": + name_part = f"{surname} et al." + else: + if "," in second: + surname2 = second.split(",")[0].strip() + else: + parts = second.split() + surname2 = parts[-1] if parts else second + name_part = f"{surname} and {surname2}" + else: + name_part = surname + if year: + return f"{name_part}, {year}" + return name_part + + +def _render_bibliography( + cited_keys: list[str], bib_db: dict[str, dict[str, str]] +) -> list[str]: + """Render a bibliography section for the cited keys.""" + lines: list[str] = ["---", "", "## 参考文献", ""] + for key in cited_keys: + entry = bib_db.get(key) + if not entry: + lines.append(f'
def _render_bibliography(
    cited_keys: list[str], bib_db: dict[str, dict[str, str]]
) -> list[str]:
    """Render a trailing bibliography section for *cited_keys*, in first-use order.

    NOTE(review): the HTML wrapper markup was stripped from this copy of the
    patch (the literal tags are missing around the f-strings).  The
    ``<p id="bib-…">`` / ``<em>`` markup below is a reconstruction chosen to
    match the anchors emitted by ``process_citations`` — confirm against the
    original commit before merging.
    """
    lines: list[str] = ["---", "", "## 参考文献", ""]
    for key in cited_keys:
        entry = bib_db.get(key)
        if not entry:
            # Key missing from the .bib database: emit a stub paragraph so the
            # in-text anchor still has a target.
            lines.append(f'<p id="bib-{key}">[{key}] {key}.</p>')
            lines.append("")
            continue
        author = clean_bibtex(entry.get("author", ""))
        title = clean_bibtex(entry.get("title", ""))
        year = entry.get("year", "")
        # Journal articles carry "journal"; conference papers carry "booktitle".
        venue = clean_bibtex(entry.get("journal", "") or entry.get("booktitle", ""))
        label = _format_cite_label(entry.get("author", ""), year)
        parts: list[str] = []
        if author:
            parts.append(author)
        if title:
            # NOTE(review): original markup around the title was lost; <em> is
            # a plausible reconstruction — verify.
            parts.append(f"<em>{title}</em>")
        if venue:
            parts.append(venue)
        if year:
            parts.append(year)
        text = ". ".join(parts) + "." if parts else f"{key}."
        lines.append(f'<p id="bib-{key}">[{label}] {text}</p>')
        lines.append("")
    return lines


def process_citations(
    markdown: str, bib_db: dict[str, dict[str, str]]
) -> str:
    """Replace ``:cite:`a,b``` directives with linked labels and append a bibliography.

    Each directive may carry several comma-separated keys.  With an empty
    *bib_db* the directive degrades to a plain ``[a, b]`` text marker and no
    bibliography section is appended.  Keys are collected in first-use order,
    which drives the order of the rendered bibliography.
    """
    cited_keys: list[str] = []  # first-use order across the whole document

    def _replace_cite(match: re.Match[str]) -> str:
        keys = [k.strip() for k in match.group(1).split(",")]
        for key in keys:
            if key not in cited_keys:
                cited_keys.append(key)
        if not bib_db:
            return "[" + ", ".join(keys) + "]"
        parts: list[str] = []
        for key in keys:
            entry = bib_db.get(key)
            if entry:
                label = _format_cite_label(
                    entry.get("author", ""), entry.get("year", "")
                )
                # NOTE(review): anchor markup reconstructed (stripped in this
                # copy of the patch); must match the id scheme used in
                # _render_bibliography.
                parts.append(f'<a href="#bib-{key}">{label}</a>')
            else:
                # Unknown key: fall back to the bare key text, unlinked.
                parts.append(key)
        return "[" + "; ".join(parts) + "]"

    processed = CITE_RE.sub(_replace_cite, markdown)
    if cited_keys and bib_db:
        bib_lines = _render_bibliography(cited_keys, bib_db)
        processed = processed.rstrip("\n") + "\n\n" + "\n".join(bib_lines) + "\n"
    return processed